Skip to content

Commit 114ccf7

Browse files
committed
Simplify tests
1 parent f580db2 commit 114ccf7

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

tests/models.test.js

Lines changed: 16 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -34,12 +34,12 @@ describe("Loading different architecture types", () => {
3434
["hf-internal-testing/tiny-random-GPT2LMHeadModel", [AutoModelForCausalLM, GPT2LMHeadModel], [AutoTokenizer, GPT2Tokenizer]], // Decoder-only
3535
["hf-internal-testing/tiny-random-T5ForConditionalGeneration", [AutoModelForSeq2SeqLM, T5ForConditionalGeneration], [AutoTokenizer, T5Tokenizer]], // Encoder-decoder
3636
["onnx-internal-testing/tiny-random-LlamaForCausalLM-ONNX_external", [AutoModelForCausalLM, LlamaForCausalLM], [AutoTokenizer, LlamaTokenizer]], // Decoder-only w/ external data
37-
["onnx-internal-testing/tiny-random-WhisperForConditionalGeneration-ONNX_external", [AutoModelForSpeechSeq2Seq, WhisperForConditionalGeneration], [AutoProcessor, WhisperProcessor], {}, "audio"], // Encoder-decoder-only w/ external data
37+
["onnx-internal-testing/tiny-random-WhisperForConditionalGeneration-ONNX_external", [AutoModelForSpeechSeq2Seq, WhisperForConditionalGeneration], [AutoProcessor, WhisperProcessor], {}], // Encoder-decoder-only w/ external data
3838
];
3939

4040
const texts = ["Once upon a time", "I like to eat apples"];
4141

42-
for (const [model_id, models, processors, modelOptions, modality] of models_to_test) {
42+
for (const [model_id, models, processors, modelOptions] of models_to_test) {
4343
// Test that both the auto model and the specific model work
4444
for (let i = 0; i < processors.length; ++i) {
4545
const processorClassToTest = processors[i];
@@ -52,22 +52,24 @@ describe("Loading different architecture types", () => {
5252
const processor = await processorClassToTest.from_pretrained(model_id);
5353
const model = await modelClassToTest.from_pretrained(model_id, modelOptions ?? DEFAULT_MODEL_OPTIONS);
5454

55-
const tests = modality === "audio"
56-
? [new Float32Array(16000)]
57-
: [
58-
texts[0], // single
59-
texts, // batched
60-
]
55+
const tests = [
56+
texts[0], // single
57+
texts, // batched
58+
]
59+
60+
const { model_type } = model.config;
61+
const tokenizer = model_type === "whisper" ? processor.tokenizer : processor;
62+
const feature_extractor = model_type === "whisper" ? processor.feature_extractor : null;
6163

6264
for (const test of tests) {
63-
const inputs = await processor(test, { truncation: true, padding: true });
65+
const inputs = await tokenizer(test, { truncation: true, padding: true });
6466
if (model.config.is_encoder_decoder) {
65-
if (model.config.model_type === "whisper") {
66-
inputs.decoder_input_ids = processor.tokenizer(texts[0]).input_ids;
67-
} else {
68-
inputs.decoder_input_ids = inputs.input_ids;
69-
}
67+
inputs.decoder_input_ids = inputs.input_ids;
68+
}
69+
if (feature_extractor) {
70+
Object.assign(inputs, await feature_extractor(new Float32Array(16000)));
7071
}
72+
7173
const output = await model(inputs);
7274

7375
if (output.logits) {

0 commit comments

Comments (0)