Skip to content

Commit 576202d

Browse files
committed
Fix CI build failure.
1 parent b5f4119 commit 576202d

File tree

8 files changed

+53
-12
lines changed

8 files changed

+53
-12
lines changed

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@
157157
},
158158
"dependencies": {
159159
"@huggingface/transformers": "github:mybigday/transformers.js-rn#merge",
160+
"onnxruntime-react-native": "^1.21.0",
160161
"patch-package": "^8.0.0",
161162
"postinstall-postinstall": "^2.1.0",
162163
"text-encoding-polyfill": "^0.6.7"

src/models/base.tsx

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,16 +109,22 @@ export class Base {
109109

110110
protected argmax(t: Tensor): number {
111111
const arr = t.data;
112-
const start = t.dims[2] * (t.dims[1] - 1);
112+
const dims = t.dims;
113+
114+
if (!dims || dims.length < 3 || !dims[1] || !dims[2]) {
115+
throw new Error('Invalid tensor dimensions');
116+
}
117+
118+
const start = dims[2] * (dims[1] - 1);
113119
let max = arr[start];
114120
let maxidx = 0;
115121

116-
for (let i = 0; i < t.dims[2]; i++) {
122+
for (let i = 0; i < dims[2]; i++) {
117123
const val = arr[i + start];
118124
if (!isFinite(val as number)) {
119125
throw new Error('found infinitive in logits');
120126
}
121-
if (val > max) {
127+
if (val !== undefined && max !== undefined && val > max) {
122128
max = val;
123129
maxidx = i;
124130
}
@@ -138,7 +144,10 @@ export class Base {
138144
if (t !== undefined && t.location === 'gpu-buffer') {
139145
t.dispose();
140146
}
141-
feed[newName] = outputs[name];
147+
const outputTensor = outputs[name];
148+
if (outputTensor) {
149+
feed[newName] = outputTensor;
150+
}
142151
}
143152
}
144153
}

src/models/text-embedding.tsx

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,20 @@ export class TextEmbedding extends Base {
4646
// Calculate mean across token dimension (dim 1) to get a single embedding vector
4747
const data = embeddings.data as Float32Array;
4848
const [, seqLen, hiddenSize] = embeddings.dims;
49+
50+
if (!seqLen || !hiddenSize || !data) {
51+
throw new Error('Invalid embedding dimensions or data');
52+
}
53+
4954
const result = new Float32Array(hiddenSize);
5055

5156
for (let h = 0; h < hiddenSize; h++) {
5257
let sum = 0;
5358
for (let s = 0; s < seqLen; s++) {
54-
sum += data[s * hiddenSize + h];
59+
const index = s * hiddenSize + h;
60+
if (data[index] !== undefined) {
61+
sum += data[index];
62+
}
5563
}
5664
result[h] = sum / seqLen;
5765
}

src/pipelines/text-embedding.tsx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import {
22
env,
33
AutoTokenizer,
4-
PreTrainedTokenizer,
54
} from '@huggingface/transformers';
5+
import type { PreTrainedTokenizer } from '@huggingface/transformers';
66
import { TextEmbedding as Model } from '../models/text-embedding';
7-
import { LoadOptions } from '../models/base';
7+
import type { LoadOptions } from '../models/base';
88

99
/** Initialization Options for Text Embedding */
1010
export interface TextEmbeddingOptions extends LoadOptions {
@@ -48,7 +48,7 @@ async function embed(text: string): Promise<Float32Array> {
4848
max_length: _options.max_tokens,
4949
});
5050

51-
return await model.embed(input_ids);
51+
return await model.embed(input_ids.map(BigInt));
5252
}
5353

5454
/**

src/pipelines/text-generation.tsx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import {
22
env,
33
AutoTokenizer,
4-
PreTrainedTokenizer,
54
} from '@huggingface/transformers';
5+
import type { PreTrainedTokenizer } from '@huggingface/transformers';
66
import { TextGeneration as Model } from '../models/text-generation';
7-
import { LoadOptions } from '../models/base';
7+
import type { LoadOptions } from '../models/base';
88

99
/** Initialization Options */
1010
export interface InitOptions extends LoadOptions {
@@ -78,7 +78,7 @@ async function generate(
7878

7979
const output_index = model.outputTokens.length + input_ids.length;
8080
const output_tokens = await model.generate(
81-
input_ids,
81+
input_ids.map(BigInt),
8282
(tokens) => {
8383
callback(record_output(token_to_text(tokens, output_index)));
8484
},
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
declare module '@huggingface/transformers' {
2+
export interface PreTrainedTokenizer {
3+
(text: string, options?: {
4+
return_tensor?: boolean;
5+
padding?: boolean;
6+
truncation?: boolean;
7+
max_length?: number;
8+
}): Promise<{ input_ids: number[] }>;
9+
decode(tokens: number[], options?: { skip_special_tokens?: boolean }): string;
10+
}
11+
12+
export class AutoTokenizer {
13+
static from_pretrained(model_name: string): Promise<PreTrainedTokenizer>;
14+
}
15+
16+
export const env: {
17+
allowRemoteModels: boolean;
18+
allowLocalModels: boolean;
19+
logLevel?: string;
20+
};
21+
}

tsconfig.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"paths": {
55
"react-native-transformers": ["./src/index"]
66
},
7+
"typeRoots": ["./node_modules/@types", "./src/types"],
78
"allowUnreachableCode": false,
89
"allowUnusedLabels": false,
910
"esModuleInterop": true,

yarn.lock

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11778,7 +11778,7 @@ __metadata:
1177811778
languageName: node
1177911779
linkType: hard
1178011780

11781-
"onnxruntime-react-native@npm:^1.22.0":
11781+
"onnxruntime-react-native@npm:^1.21.0, onnxruntime-react-native@npm:^1.22.0":
1178211782
version: 1.22.0
1178311783
resolution: "onnxruntime-react-native@npm:1.22.0"
1178411784
dependencies:
@@ -12759,6 +12759,7 @@ __metadata:
1275912759
eslint-plugin-ft-flow: ^3.0.11
1276012760
eslint-plugin-prettier: ^5.2.3
1276112761
jest: ^29.7.0
12762+
onnxruntime-react-native: ^1.21.0
1276212763
patch-package: ^8.0.0
1276312764
postinstall-postinstall: ^2.1.0
1276412765
prettier: ^3.0.3

0 commit comments

Comments (0)