Skip to content

Commit effa9a9

Browse files
authored
Refactor per-model unit testing (#1083)
* Set up per-model unit tests
* Rename tests
* Do not modify original object when updating model file name
* Distribute unit tests across separate files
* Update comments
* Update tokenization test file names
* Refactor: use asset cache
* Destructuring for code deduplication
* Remove empty file
* Rename deberta-v2 -> deberta_v2
* Rename
* Support casting between number and bigint types
* Use fp32 tiny models
* Move image processing tests to separate folders + auto-detection
1 parent 14bf689 commit effa9a9

File tree

87 files changed

+3930
-3497
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+3930
-3497
lines changed

src/models.js

Lines changed: 50 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -3666,9 +3666,11 @@ export class CLIPModel extends CLIPPreTrainedModel { }
36663666
export class CLIPTextModel extends CLIPPreTrainedModel {
36673667
/** @type {typeof PreTrainedModel.from_pretrained} */
36683668
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3669-
// Update default model file name if not provided
3670-
options.model_file_name ??= 'text_model';
3671-
return super.from_pretrained(pretrained_model_name_or_path, options);
3669+
return super.from_pretrained(pretrained_model_name_or_path, {
3670+
// Update default model file name if not provided
3671+
model_file_name: 'text_model',
3672+
...options,
3673+
});
36723674
}
36733675
}
36743676

@@ -3701,9 +3703,11 @@ export class CLIPTextModel extends CLIPPreTrainedModel {
37013703
export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
37023704
/** @type {typeof PreTrainedModel.from_pretrained} */
37033705
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3704-
// Update default model file name if not provided
3705-
options.model_file_name ??= 'text_model';
3706-
return super.from_pretrained(pretrained_model_name_or_path, options);
3706+
return super.from_pretrained(pretrained_model_name_or_path, {
3707+
// Update default model file name if not provided
3708+
model_file_name: 'text_model',
3709+
...options,
3710+
});
37073711
}
37083712
}
37093713

@@ -3713,9 +3717,11 @@ export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
37133717
export class CLIPVisionModel extends CLIPPreTrainedModel {
37143718
/** @type {typeof PreTrainedModel.from_pretrained} */
37153719
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3716-
// Update default model file name if not provided
3717-
options.model_file_name ??= 'vision_model';
3718-
return super.from_pretrained(pretrained_model_name_or_path, options);
3720+
return super.from_pretrained(pretrained_model_name_or_path, {
3721+
// Update default model file name if not provided
3722+
model_file_name: 'vision_model',
3723+
...options,
3724+
});
37193725
}
37203726
}
37213727

@@ -3748,9 +3754,11 @@ export class CLIPVisionModel extends CLIPPreTrainedModel {
37483754
export class CLIPVisionModelWithProjection extends CLIPPreTrainedModel {
37493755
/** @type {typeof PreTrainedModel.from_pretrained} */
37503756
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3751-
// Update default model file name if not provided
3752-
options.model_file_name ??= 'vision_model';
3753-
return super.from_pretrained(pretrained_model_name_or_path, options);
3757+
return super.from_pretrained(pretrained_model_name_or_path, {
3758+
// Update default model file name if not provided
3759+
model_file_name: 'vision_model',
3760+
...options,
3761+
});
37543762
}
37553763
}
37563764
//////////////////////////////////////////////////
@@ -3834,9 +3842,11 @@ export class SiglipModel extends SiglipPreTrainedModel { }
38343842
export class SiglipTextModel extends SiglipPreTrainedModel {
38353843
/** @type {typeof PreTrainedModel.from_pretrained} */
38363844
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3837-
// Update default model file name if not provided
3838-
options.model_file_name ??= 'text_model';
3839-
return super.from_pretrained(pretrained_model_name_or_path, options);
3845+
return super.from_pretrained(pretrained_model_name_or_path, {
3846+
// Update default model file name if not provided
3847+
model_file_name: 'text_model',
3848+
...options,
3849+
});
38403850
}
38413851
}
38423852

@@ -3869,9 +3879,11 @@ export class SiglipTextModel extends SiglipPreTrainedModel {
38693879
export class SiglipVisionModel extends CLIPPreTrainedModel {
38703880
/** @type {typeof PreTrainedModel.from_pretrained} */
38713881
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3872-
// Update default model file name if not provided
3873-
options.model_file_name ??= 'vision_model';
3874-
return super.from_pretrained(pretrained_model_name_or_path, options);
3882+
return super.from_pretrained(pretrained_model_name_or_path, {
3883+
// Update default model file name if not provided
3884+
model_file_name: 'vision_model',
3885+
...options,
3886+
});
38753887
}
38763888
}
38773889
//////////////////////////////////////////////////
@@ -3926,18 +3938,22 @@ export class JinaCLIPModel extends JinaCLIPPreTrainedModel {
39263938
export class JinaCLIPTextModel extends JinaCLIPPreTrainedModel {
39273939
/** @type {typeof PreTrainedModel.from_pretrained} */
39283940
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3929-
// Update default model file name if not provided
3930-
options.model_file_name ??= 'text_model';
3931-
return super.from_pretrained(pretrained_model_name_or_path, options);
3941+
return super.from_pretrained(pretrained_model_name_or_path, {
3942+
// Update default model file name if not provided
3943+
model_file_name: 'text_model',
3944+
...options,
3945+
});
39323946
}
39333947
}
39343948

39353949
export class JinaCLIPVisionModel extends JinaCLIPPreTrainedModel {
39363950
/** @type {typeof PreTrainedModel.from_pretrained} */
39373951
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3938-
// Update default model file name if not provided
3939-
options.model_file_name ??= 'vision_model';
3940-
return super.from_pretrained(pretrained_model_name_or_path, options);
3952+
return super.from_pretrained(pretrained_model_name_or_path, {
3953+
// Update default model file name if not provided
3954+
model_file_name: 'vision_model',
3955+
...options,
3956+
});
39413957
}
39423958
}
39433959
//////////////////////////////////////////////////
@@ -6159,9 +6175,11 @@ export class ClapModel extends ClapPreTrainedModel { }
61596175
export class ClapTextModelWithProjection extends ClapPreTrainedModel {
61606176
/** @type {typeof PreTrainedModel.from_pretrained} */
61616177
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
6162-
// Update default model file name if not provided
6163-
options.model_file_name ??= 'text_model';
6164-
return super.from_pretrained(pretrained_model_name_or_path, options);
6178+
return super.from_pretrained(pretrained_model_name_or_path, {
6179+
// Update default model file name if not provided
6180+
model_file_name: 'text_model',
6181+
...options,
6182+
});
61656183
}
61666184
}
61676185

@@ -6194,9 +6212,11 @@ export class ClapTextModelWithProjection extends ClapPreTrainedModel {
61946212
export class ClapAudioModelWithProjection extends ClapPreTrainedModel {
61956213
/** @type {typeof PreTrainedModel.from_pretrained} */
61966214
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
6197-
// Update default model file name if not provided
6198-
options.model_file_name ??= 'audio_model';
6199-
return super.from_pretrained(pretrained_model_name_or_path, options);
6215+
return super.from_pretrained(pretrained_model_name_or_path, {
6216+
// Update default model file name if not provided
6217+
model_file_name: 'audio_model',
6218+
...options,
6219+
});
62006220
}
62016221
}
62026222
//////////////////////////////////////////////////

src/utils/tensor.js

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -772,8 +772,21 @@ export class Tensor {
772772
if (!DataTypeMap.hasOwnProperty(type)) {
773773
throw new Error(`Unsupported type: ${type}`);
774774
}
775+
776+
// Handle special cases where a mapping function is needed (e.g., where one type is a bigint and the other is a number)
777+
let map_fn;
778+
const is_source_bigint = ['int64', 'uint64'].includes(this.type);
779+
const is_dest_bigint = ['int64', 'uint64'].includes(type);
780+
if (is_source_bigint && !is_dest_bigint) {
781+
// TypeError: Cannot convert a BigInt value to a number
782+
map_fn = Number;
783+
} else if (!is_source_bigint && is_dest_bigint) {
784+
// TypeError: Cannot convert [x] to a BigInt
785+
map_fn = BigInt;
786+
}
787+
775788
// @ts-ignore
776-
return new Tensor(type, DataTypeMap[type].from(this.data), this.dims);
789+
return new Tensor(type, DataTypeMap[type].from(this.data, map_fn), this.dims);
777790
}
778791
}
779792

tests/asset_cache.js

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,43 @@
1+
import { RawImage } from "../src/transformers.js";
2+
3+
const BASE_URL = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/";
4+
const TEST_IMAGES = Object.freeze({
5+
white_image: BASE_URL + "white-image.png",
6+
pattern_3x3: BASE_URL + "pattern_3x3.png",
7+
pattern_3x5: BASE_URL + "pattern_3x5.png",
8+
checkerboard_8x8: BASE_URL + "checkerboard_8x8.png",
9+
checkerboard_64x32: BASE_URL + "checkerboard_64x32.png",
10+
gradient_1280x640: BASE_URL + "gradient_1280x640.png",
11+
receipt: BASE_URL + "receipt.png",
12+
tiger: BASE_URL + "tiger.jpg",
13+
paper: BASE_URL + "nougat_paper.png",
14+
cats: BASE_URL + "cats.jpg",
15+
16+
// grayscale image
17+
skateboard: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/ml-web-games/skateboard.png",
18+
19+
vitmatte_image: BASE_URL + "vitmatte_image.png",
20+
vitmatte_trimap: BASE_URL + "vitmatte_trimap.png",
21+
22+
beetle: BASE_URL + "beetle.png",
23+
book_cover: BASE_URL + "book-cover.png",
24+
});
25+
26+
/** @type {Map<string, RawImage>} */
27+
const IMAGE_CACHE = new Map();
28+
const load_image = async (url) => {
29+
const cached = IMAGE_CACHE.get(url);
30+
if (cached) {
31+
return cached;
32+
}
33+
const image = await RawImage.fromURL(url);
34+
IMAGE_CACHE.set(url, image);
35+
return image;
36+
};
37+
38+
/**
39+
* Load a cached image.
40+
* @param {keyof typeof TEST_IMAGES} name The name of the image to load.
41+
* @returns {Promise<RawImage>} The loaded image.
42+
*/
43+
export const load_cached_image = (name) => load_image(TEST_IMAGES[name]);

tests/init.js

Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,8 +57,66 @@ export function init() {
5757
registerBackend("test", onnxruntimeBackend, Number.POSITIVE_INFINITY);
5858
}
5959

60+
export const MAX_PROCESSOR_LOAD_TIME = 10_000; // 10 seconds
6061
export const MAX_MODEL_LOAD_TIME = 15_000; // 15 seconds
6162
export const MAX_TEST_EXECUTION_TIME = 60_000; // 60 seconds
6263
export const MAX_MODEL_DISPOSE_TIME = 1_000; // 1 second
6364

6465
export const MAX_TEST_TIME = MAX_MODEL_LOAD_TIME + MAX_TEST_EXECUTION_TIME + MAX_MODEL_DISPOSE_TIME;
66+
67+
export const DEFAULT_MODEL_OPTIONS = {
68+
dtype: "fp32",
69+
};
70+
71+
expect.extend({
72+
toBeCloseToNested(received, expected, numDigits = 2) {
73+
const compare = (received, expected, path = "") => {
74+
if (typeof received === "number" && typeof expected === "number" && !Number.isInteger(received) && !Number.isInteger(expected)) {
75+
const pass = Math.abs(received - expected) < Math.pow(10, -numDigits);
76+
return {
77+
pass,
78+
message: () => (pass ? `✓ At path '${path}': expected ${received} not to be close to ${expected} with tolerance of ${numDigits} decimal places` : `✗ At path '${path}': expected ${received} to be close to ${expected} with tolerance of ${numDigits} decimal places`),
79+
};
80+
} else if (Array.isArray(received) && Array.isArray(expected)) {
81+
if (received.length !== expected.length) {
82+
return {
83+
pass: false,
84+
message: () => `✗ At path '${path}': array lengths differ. Received length ${received.length}, expected length ${expected.length}`,
85+
};
86+
}
87+
for (let i = 0; i < received.length; i++) {
88+
const result = compare(received[i], expected[i], `${path}[${i}]`);
89+
if (!result.pass) return result;
90+
}
91+
} else if (typeof received === "object" && typeof expected === "object" && received !== null && expected !== null) {
92+
const receivedKeys = Object.keys(received);
93+
const expectedKeys = Object.keys(expected);
94+
if (receivedKeys.length !== expectedKeys.length) {
95+
return {
96+
pass: false,
97+
message: () => `✗ At path '${path}': object keys length differ. Received keys: ${JSON.stringify(receivedKeys)}, expected keys: ${JSON.stringify(expectedKeys)}`,
98+
};
99+
}
100+
for (const key of receivedKeys) {
101+
if (!expected.hasOwnProperty(key)) {
102+
return {
103+
pass: false,
104+
message: () => `✗ At path '${path}': key '${key}' found in received but not in expected`,
105+
};
106+
}
107+
const result = compare(received[key], expected[key], `${path}.${key}`);
108+
if (!result.pass) return result;
109+
}
110+
} else {
111+
const pass = received === expected;
112+
return {
113+
pass,
114+
message: () => (pass ? `✓ At path '${path}': expected ${JSON.stringify(received)} not to equal ${JSON.stringify(expected)}` : `✗ At path '${path}': expected ${JSON.stringify(received)} to equal ${JSON.stringify(expected)}`),
115+
};
116+
}
117+
return { pass: true };
118+
};
119+
120+
return compare(received, expected);
121+
},
122+
});

0 commit comments

Comments (0)