
Commit 9056f76

Fix setting of model_file_name (#1114)
* Fix setting of model_file_name
* Add CLIP-like unit test for image feature extraction pipeline
1 parent da2c1e9 commit 9056f76
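For context on the fix: the pre-fix code listed a hard-coded `model_file_name` default before spreading `...options`, so a caller that forwarded `model_file_name: undefined` (which can happen when the option is passed through without ever being set) silently overwrote the default with `undefined`. The sketch below is a minimal standalone illustration of that spread-ordering pitfall, not code taken verbatim from the library:

    // Minimal sketch of the spread-ordering pitfall (illustrative only).

    // Before: the default comes first, so an explicitly-forwarded
    // `model_file_name: undefined` in `options` clobbers it.
    const before = (options = {}) => ({
        model_file_name: 'text_model',
        ...options,
    });
    console.log(before({ model_file_name: undefined }).model_file_name); // undefined

    // After: spread first, then apply the default with nullish coalescing;
    // a real caller value still wins, but undefined/null falls back.
    const after = (options = {}) => ({
        ...options,
        model_file_name: options.model_file_name ?? 'text_model',
    });
    console.log(after({ model_file_name: undefined }).model_file_name); // 'text_model'
    console.log(after({ model_file_name: 'custom_model' }).model_file_name); // 'custom_model'

Note that the two orderings only differ when `options` contains the key with a nullish value; an absent key or a concrete value behaves the same either way, which is why the bug only surfaced on certain call paths.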

File tree

2 files changed: +101 -57 lines changed

src/models.js

Lines changed: 21 additions & 21 deletions
@@ -3421,7 +3421,7 @@ export class MoonshinePreTrainedModel extends PreTrainedModel {
  */
 export class MoonshineModel extends MoonshinePreTrainedModel { }
 
-export class MoonshineForConditionalGeneration extends MoonshinePreTrainedModel { }
+export class MoonshineForConditionalGeneration extends MoonshinePreTrainedModel { }
 //////////////////////////////////////////////////
 
 
@@ -3821,9 +3821,9 @@ export class CLIPTextModel extends CLIPPreTrainedModel {
     /** @type {typeof PreTrainedModel.from_pretrained} */
     static async from_pretrained(pretrained_model_name_or_path, options = {}) {
         return super.from_pretrained(pretrained_model_name_or_path, {
-            // Update default model file name if not provided
-            model_file_name: 'text_model',
             ...options,
+            // Update default model file name if not provided
+            model_file_name: options.model_file_name ?? 'text_model',
         });
     }
 }
@@ -3858,9 +3858,9 @@ export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
     /** @type {typeof PreTrainedModel.from_pretrained} */
     static async from_pretrained(pretrained_model_name_or_path, options = {}) {
         return super.from_pretrained(pretrained_model_name_or_path, {
-            // Update default model file name if not provided
-            model_file_name: 'text_model',
             ...options,
+            // Update default model file name if not provided
+            model_file_name: options.model_file_name ?? 'text_model',
         });
     }
 }
@@ -3872,9 +3872,9 @@ export class CLIPVisionModel extends CLIPPreTrainedModel {
     /** @type {typeof PreTrainedModel.from_pretrained} */
     static async from_pretrained(pretrained_model_name_or_path, options = {}) {
         return super.from_pretrained(pretrained_model_name_or_path, {
-            // Update default model file name if not provided
-            model_file_name: 'vision_model',
             ...options,
+            // Update default model file name if not provided
+            model_file_name: options.model_file_name ?? 'vision_model',
         });
     }
 }
@@ -3909,9 +3909,9 @@ export class CLIPVisionModelWithProjection extends CLIPPreTrainedModel {
     /** @type {typeof PreTrainedModel.from_pretrained} */
     static async from_pretrained(pretrained_model_name_or_path, options = {}) {
         return super.from_pretrained(pretrained_model_name_or_path, {
-            // Update default model file name if not provided
-            model_file_name: 'vision_model',
             ...options,
+            // Update default model file name if not provided
+            model_file_name: options.model_file_name ?? 'vision_model',
         });
     }
 }
@@ -3997,9 +3997,9 @@ export class SiglipTextModel extends SiglipPreTrainedModel {
     /** @type {typeof PreTrainedModel.from_pretrained} */
     static async from_pretrained(pretrained_model_name_or_path, options = {}) {
         return super.from_pretrained(pretrained_model_name_or_path, {
-            // Update default model file name if not provided
-            model_file_name: 'text_model',
             ...options,
+            // Update default model file name if not provided
+            model_file_name: options.model_file_name ?? 'text_model',
        });
    }
 }
@@ -4034,9 +4034,9 @@ export class SiglipVisionModel extends CLIPPreTrainedModel {
     /** @type {typeof PreTrainedModel.from_pretrained} */
     static async from_pretrained(pretrained_model_name_or_path, options = {}) {
         return super.from_pretrained(pretrained_model_name_or_path, {
-            // Update default model file name if not provided
-            model_file_name: 'vision_model',
             ...options,
+            // Update default model file name if not provided
+            model_file_name: options.model_file_name ?? 'vision_model',
         });
     }
 }
@@ -4093,9 +4093,9 @@ export class JinaCLIPTextModel extends JinaCLIPPreTrainedModel {
     /** @type {typeof PreTrainedModel.from_pretrained} */
     static async from_pretrained(pretrained_model_name_or_path, options = {}) {
         return super.from_pretrained(pretrained_model_name_or_path, {
-            // Update default model file name if not provided
-            model_file_name: 'text_model',
             ...options,
+            // Update default model file name if not provided
+            model_file_name: options.model_file_name ?? 'text_model',
         });
     }
 }
@@ -4104,9 +4104,9 @@ export class JinaCLIPVisionModel extends JinaCLIPPreTrainedModel {
     /** @type {typeof PreTrainedModel.from_pretrained} */
     static async from_pretrained(pretrained_model_name_or_path, options = {}) {
         return super.from_pretrained(pretrained_model_name_or_path, {
-            // Update default model file name if not provided
-            model_file_name: 'vision_model',
             ...options,
+            // Update default model file name if not provided
+            model_file_name: options.model_file_name ?? 'vision_model',
         });
     }
 }
@@ -6338,9 +6338,9 @@ export class ClapTextModelWithProjection extends ClapPreTrainedModel {
     /** @type {typeof PreTrainedModel.from_pretrained} */
     static async from_pretrained(pretrained_model_name_or_path, options = {}) {
         return super.from_pretrained(pretrained_model_name_or_path, {
-            // Update default model file name if not provided
-            model_file_name: 'text_model',
             ...options,
+            // Update default model file name if not provided
+            model_file_name: options.model_file_name ?? 'text_model',
         });
     }
 }
@@ -6375,9 +6375,9 @@ export class ClapAudioModelWithProjection extends ClapPreTrainedModel {
     /** @type {typeof PreTrainedModel.from_pretrained} */
     static async from_pretrained(pretrained_model_name_or_path, options = {}) {
         return super.from_pretrained(pretrained_model_name_or_path, {
-            // Update default model file name if not provided
-            model_file_name: 'audio_model',
             ...options,
+            // Update default model file name if not provided
+            model_file_name: options.model_file_name ?? 'audio_model',
         });
     }
 }
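The affected classes are the multi-file checkpoints (CLIP, SigLIP, JinaCLIP, CLAP) whose text/vision/audio towers live in separate ONNX files, each with its own default file name. A hedged usage sketch of how a caller interacts with the fixed default; the checkpoint id and the override file name are illustrative, not prescribed by this commit:

    import { CLIPVisionModelWithProjection } from '@huggingface/transformers';

    // No model_file_name given: the class default ('vision_model') now
    // survives even if upstream code forwards the option as undefined.
    const visionModel = await CLIPVisionModelWithProjection.from_pretrained(
        'Xenova/clip-vit-base-patch32', // illustrative checkpoint
    );

    // An explicit value still takes precedence over the default.
    const customModel = await CLIPVisionModelWithProjection.from_pretrained(
        'Xenova/clip-vit-base-patch32',
        { model_file_name: 'my_custom_vision_model' }, // hypothetical file name
    );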

tests/pipelines/test_pipelines_image_feature_extraction.js

Lines changed: 80 additions & 36 deletions
@@ -7,45 +7,89 @@ const PIPELINE_ID = "image-feature-extraction";
 
 export default () => {
   describe("Image Feature Extraction", () => {
-    const model_id = "hf-internal-testing/tiny-random-ViTMAEModel";
-    /** @type {ImageFeatureExtractionPipeline} */
-    let pipe;
-    let images;
-    beforeAll(async () => {
-      pipe = await pipeline(PIPELINE_ID, model_id, DEFAULT_MODEL_OPTIONS);
-      images = await Promise.all([load_cached_image("white_image"), load_cached_image("blue_image")]);
-    }, MAX_MODEL_LOAD_TIME);
-
-    it("should be an instance of ImageFeatureExtractionPipeline", () => {
-      expect(pipe).toBeInstanceOf(ImageFeatureExtractionPipeline);
-    });
+    describe("Default", () => {
+      const model_id = "hf-internal-testing/tiny-random-ViTMAEModel";
+      /** @type {ImageFeatureExtractionPipeline} */
+      let pipe;
+      let images;
+      beforeAll(async () => {
+        pipe = await pipeline(PIPELINE_ID, model_id, DEFAULT_MODEL_OPTIONS);
+        images = await Promise.all([load_cached_image("white_image"), load_cached_image("blue_image")]);
+      }, MAX_MODEL_LOAD_TIME);
 
-    describe("batch_size=1", () => {
-      it(
-        "default",
-        async () => {
-          const output = await pipe(images[0]);
-          expect(output.dims).toEqual([1, 91, 32]);
-          expect(output.mean().item()).toBeCloseTo(-8.507473614471905e-10, 6);
-        },
-        MAX_TEST_EXECUTION_TIME,
-      );
-    });
+      it("should be an instance of ImageFeatureExtractionPipeline", () => {
+        expect(pipe).toBeInstanceOf(ImageFeatureExtractionPipeline);
+      });
+
+      describe("batch_size=1", () => {
+        it(
+          "default",
+          async () => {
+            const output = await pipe(images[0]);
+            expect(output.dims).toEqual([1, 91, 32]);
+            expect(output.mean().item()).toBeCloseTo(-8.507473614471905e-10, 6);
+          },
+          MAX_TEST_EXECUTION_TIME,
+        );
+      });
+
+      describe("batch_size>1", () => {
+        it(
+          "default",
+          async () => {
+            const output = await pipe(images);
+            expect(output.dims).toEqual([images.length, 91, 32]);
+            expect(output.mean().item()).toBeCloseTo(-5.997602414709036e-10, 6);
+          },
+          MAX_TEST_EXECUTION_TIME,
+        );
+      });
 
-    describe("batch_size>1", () => {
-      it(
-        "default",
-        async () => {
-          const output = await pipe(images);
-          expect(output.dims).toEqual([images.length, 91, 32]);
-          expect(output.mean().item()).toBeCloseTo(-5.997602414709036e-10, 6);
-        },
-        MAX_TEST_EXECUTION_TIME,
-      );
+      afterAll(async () => {
+        await pipe.dispose();
+      }, MAX_MODEL_DISPOSE_TIME);
     });
+    describe("CLIP-like", () => {
+      const model_id = "hf-internal-testing/tiny-random-CLIPModel";
+      /** @type {ImageFeatureExtractionPipeline} */
+      let pipe;
+      let images;
+      beforeAll(async () => {
+        pipe = await pipeline(PIPELINE_ID, model_id, DEFAULT_MODEL_OPTIONS);
+        images = await Promise.all([load_cached_image("white_image"), load_cached_image("blue_image")]);
+      }, MAX_MODEL_LOAD_TIME);
 
-    afterAll(async () => {
-      await pipe.dispose();
-    }, MAX_MODEL_DISPOSE_TIME);
+      it("should be an instance of ImageFeatureExtractionPipeline", () => {
+        expect(pipe).toBeInstanceOf(ImageFeatureExtractionPipeline);
+      });
+
+      describe("batch_size=1", () => {
+        it(
+          "default",
+          async () => {
+            const output = await pipe(images[0]);
+            expect(output.dims).toEqual([1, 64]);
+            expect(output.mean().item()).toBeCloseTo(-0.11340035498142242, 6);
+          },
+          MAX_TEST_EXECUTION_TIME,
+        );
+      });
+
+      describe("batch_size>1", () => {
+        it(
+          "default",
+          async () => {
+            const output = await pipe(images);
+            expect(output.dims).toEqual([images.length, 64]);
+            expect(output.mean().item()).toBeCloseTo(-0.11006651818752289, 6);
+          },
+          MAX_TEST_EXECUTION_TIME,
+        );
+      });
+
+      afterAll(async () => {
+        await pipe.dispose();
+      }, MAX_MODEL_DISPOSE_TIME);
+    });
   });
 };
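The new "CLIP-like" block exercises the path the fix repairs: the image-feature-extraction pipeline forwards its construction options to the CLIP model classes above, which previously could lose the 'vision_model' default along the way. A minimal sketch of the same scenario outside the test harness (the image URL is a placeholder, not from this commit):

    import { pipeline } from '@huggingface/transformers';

    // Load a CLIP checkpoint through the generic pipeline; with the fix,
    // the vision tower's default ONNX file name resolves correctly.
    const extractor = await pipeline(
        'image-feature-extraction',
        'hf-internal-testing/tiny-random-CLIPModel',
    );

    const features = await extractor('https://example.com/cat.png'); // placeholder URL
    console.log(features.dims); // [1, 64] for this tiny test model, per the test above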
