
Commit f580db2

Add external data model architecture tests
1 parent c2f2bd4 commit f580db2

File tree

1 file changed (+45, -21 lines)


tests/models.test.js

Lines changed: 45 additions & 21 deletions
@@ -2,7 +2,24 @@
  * Test that models loaded outside of the `pipeline` function work correctly (e.g., `AutoModel.from_pretrained(...)`);
  */
 
-import { AutoTokenizer, AutoModel, BertModel, GPT2Model, T5ForConditionalGeneration, BertTokenizer, GPT2Tokenizer, T5Tokenizer } from "../src/transformers.js";
+import {
+  AutoTokenizer,
+  AutoProcessor,
+  BertForMaskedLM,
+  GPT2LMHeadModel,
+  T5ForConditionalGeneration,
+  BertTokenizer,
+  GPT2Tokenizer,
+  T5Tokenizer,
+  LlamaTokenizer,
+  LlamaForCausalLM,
+  WhisperForConditionalGeneration,
+  WhisperProcessor,
+  AutoModelForMaskedLM,
+  AutoModelForCausalLM,
+  AutoModelForSeq2SeqLM,
+  AutoModelForSpeechSeq2Seq,
+} from "../src/transformers.js";
 import { init, MAX_TEST_EXECUTION_TIME, DEFAULT_MODEL_OPTIONS } from "./init.js";
 import { compare, collect_and_execute_tests } from "./test_utils.js";
 
@@ -12,44 +29,51 @@ init();
 describe("Loading different architecture types", () => {
   // List all models which will be tested
   const models_to_test = [
-    // [name, modelClass, tokenizerClass]
-    ["hf-internal-testing/tiny-random-BertForMaskedLM", BertModel, BertTokenizer], // Encoder-only
-    ["hf-internal-testing/tiny-random-GPT2LMHeadModel", GPT2Model, GPT2Tokenizer], // Decoder-only
-    ["hf-internal-testing/tiny-random-T5ForConditionalGeneration", T5ForConditionalGeneration, T5Tokenizer], // Encoder-decoder
+    // [name, [AutoModelClass, ModelClass], [AutoProcessorClass, ProcessorClass], [modelOptions?], [modality?]]
+    ["hf-internal-testing/tiny-random-BertForMaskedLM", [AutoModelForMaskedLM, BertForMaskedLM], [AutoTokenizer, BertTokenizer]], // Encoder-only
+    ["hf-internal-testing/tiny-random-GPT2LMHeadModel", [AutoModelForCausalLM, GPT2LMHeadModel], [AutoTokenizer, GPT2Tokenizer]], // Decoder-only
+    ["hf-internal-testing/tiny-random-T5ForConditionalGeneration", [AutoModelForSeq2SeqLM, T5ForConditionalGeneration], [AutoTokenizer, T5Tokenizer]], // Encoder-decoder
+    ["onnx-internal-testing/tiny-random-LlamaForCausalLM-ONNX_external", [AutoModelForCausalLM, LlamaForCausalLM], [AutoTokenizer, LlamaTokenizer]], // Decoder-only w/ external data
+    ["onnx-internal-testing/tiny-random-WhisperForConditionalGeneration-ONNX_external", [AutoModelForSpeechSeq2Seq, WhisperForConditionalGeneration], [AutoProcessor, WhisperProcessor], {}, "audio"], // Encoder-decoder w/ external data
   ];
 
   const texts = ["Once upon a time", "I like to eat apples"];
 
-  for (const [model_id, modelClass, tokenizerClass] of models_to_test) {
+  for (const [model_id, models, processors, modelOptions, modality] of models_to_test) {
     // Test that both the auto model and the specific model work
-    const tokenizers = [AutoTokenizer, tokenizerClass];
-    const models = [AutoModel, modelClass];
-
-    for (let i = 0; i < tokenizers.length; ++i) {
-      const tokenizerClassToTest = tokenizers[i];
+    for (let i = 0; i < processors.length; ++i) {
+      const processorClassToTest = processors[i];
       const modelClassToTest = models[i];
 
       it(
         `${model_id} (${modelClassToTest.name})`,
         async () => {
-          // Load model and tokenizer
-          const tokenizer = await tokenizerClassToTest.from_pretrained(model_id);
-          const model = await modelClassToTest.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
+          // Load model and processor
+          const processor = await processorClassToTest.from_pretrained(model_id);
+          const model = await modelClassToTest.from_pretrained(model_id, modelOptions ?? DEFAULT_MODEL_OPTIONS);
+
+          const tests = modality === "audio"
+            ? [new Float32Array(16000)]
+            : [
+                texts[0], // single
+                texts, // batched
+              ];
 
-          const tests = [
-            texts[0], // single
-            texts, // batched
-          ];
           for (const test of tests) {
-            const inputs = await tokenizer(test, { truncation: true, padding: true });
+            const inputs = await processor(test, { truncation: true, padding: true });
             if (model.config.is_encoder_decoder) {
-              inputs.decoder_input_ids = inputs.input_ids;
+              if (model.config.model_type === "whisper") {
+                inputs.decoder_input_ids = processor.tokenizer(texts[0]).input_ids;
+              } else {
+                inputs.decoder_input_ids = inputs.input_ids;
+              }
             }
             const output = await model(inputs);
 
             if (output.logits) {
               // Ensure correct shapes
-              const expected_shape = [...inputs.input_ids.dims, model.config.vocab_size];
+              const input_ids = inputs.input_ids ?? inputs.decoder_input_ids;
+              const expected_shape = [...input_ids.dims, model.config.vocab_size];
               const actual_shape = output.logits.dims;
               compare(expected_shape, actual_shape);
             } else if (output.last_hidden_state) {
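For context, the text-model path these tests exercise can be reproduced standalone. Below is a minimal sketch, assuming the published @huggingface/transformers package exposes the same classes the test file imports from its relative source path; the model ID is taken from the diff above, and the shape check mirrors the test's compare call.

// Minimal sketch of the text-model path (assumed package name; the tests
// themselves import from "../src/transformers.js").
import { AutoTokenizer, AutoModelForCausalLM } from "@huggingface/transformers";

// Decoder-only model whose ONNX weights live in a separate external-data file;
// the test passes no special options, so from_pretrained is expected to
// resolve the extra weight file transparently.
const model_id = "onnx-internal-testing/tiny-random-LlamaForCausalLM-ONNX_external";
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await AutoModelForCausalLM.from_pretrained(model_id);

// Batched input, padded and truncated as in the test.
const inputs = await tokenizer(["Once upon a time", "I like to eat apples"], { truncation: true, padding: true });
const output = await model(inputs);

// Logits should have shape [...input_ids.dims, vocab_size].
console.log(output.logits.dims);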
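The audio path differs in two ways the diff highlights: the processor consumes raw audio rather than text, and Whisper's decoder must be seeded explicitly because there are no text input_ids. A minimal sketch under the same package-name assumption, using a silent one-second buffer as the test does:

import { AutoProcessor, AutoModelForSpeechSeq2Seq } from "@huggingface/transformers";

const model_id = "onnx-internal-testing/tiny-random-WhisperForConditionalGeneration-ONNX_external";
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForSpeechSeq2Seq.from_pretrained(model_id);

// 16000 zero samples, i.e. roughly one second of silence at 16 kHz.
const inputs = await processor(new Float32Array(16000));

// Seed the decoder by tokenizing a prompt with the processor's inner
// tokenizer, exactly as the test does for model_type === "whisper".
inputs.decoder_input_ids = processor.tokenizer("Once upon a time").input_ids;

const output = await model(inputs);

// With no input_ids, the test's shape check falls back to decoder_input_ids:
// [...decoder_input_ids.dims, vocab_size].
console.log(output.logits.dims);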
