Commit 05924f9

Add PKV generation unit test with onnxruntime-genai GQA model
1 parent c701ccf commit 05924f9

File tree: 1 file changed (+61, -0 lines)

tests/utils/generation.test.js

Lines changed: 61 additions & 0 deletions
@@ -282,6 +282,67 @@ describe("PKV caching", () => {
     }, MAX_MODEL_DISPOSE_TIME);
   });
 
+  describe("LlamaForCausalLM (onnxruntime-genai)", () => {
+    const model_id = "onnx-internal-testing/tiny-random-LlamaForCausalLM-GQA";
+    /** @type {LlamaForCausalLM} */
+    let model;
+    /** @type {LlamaTokenizer} */
+    let tokenizer;
+    beforeAll(async () => {
+      model = await LlamaForCausalLM.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
+      tokenizer = await LlamaTokenizer.from_pretrained(model_id);
+    }, MAX_MODEL_LOAD_TIME);
+
+    it(
+      "batch_size=1",
+      async () => {
+        const inputs = tokenizer("1");
+
+        // Generate first sequence w/o PKV
+        // NOTE: `return_dict_in_generate=true` is required to get PKV
+        const { past_key_values, sequences } = await model.generate({
+          ...inputs,
+          max_new_tokens: 5,
+          do_sample: false,
+          return_dict_in_generate: true,
+        });
+
+        // Update output with new text
+        const decoded = tokenizer.batch_decode(sequences, {
+          skip_special_tokens: false,
+        })[0];
+        const new_inputs = tokenizer(decoded + "2", {
+          add_special_tokens: false,
+        });
+
+        // Run w/o PKV
+        const generated_ids = await model.generate({
+          ...new_inputs,
+          max_new_tokens: 3,
+          do_sample: false,
+        });
+
+        // Run w/ PKV
+        const generated_ids_pkv = await model.generate({
+          ...new_inputs,
+          past_key_values,
+          max_new_tokens: 3,
+          do_sample: false,
+        });
+
+        const target = [[128000n, 16n, 34732n, 98805n, 116404n, 68265n, 99392n, 17n, 21855n, 60933n, 14285n]];
+
+        expect(generated_ids.tolist()).toEqual(target);
+        expect(generated_ids_pkv.tolist()).toEqual(target);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    afterAll(async () => {
+      await model?.dispose();
+    }, MAX_MODEL_DISPOSE_TIME);
+  });
+
   describe("LlavaForConditionalGeneration", () => {
     const model_id = "Xenova/tiny-random-LlavaForConditionalGeneration";
     /** @type {LlavaForConditionalGeneration} */
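
For context, the test above boils down to a prompt-continuation pattern: generate once with `return_dict_in_generate: true` so the key/value cache (`past_key_values`) is returned alongside the tokens, then pass that cache into a follow-up `generate` call so the shared prompt prefix is not re-encoded. Below is a minimal standalone sketch of that pattern (not part of the commit), assuming the `@huggingface/transformers` package; it reuses the tiny GQA fixture model from the diff and omits the test's `DEFAULT_MODEL_OPTIONS` (e.g. a specific dtype) for brevity.

// Minimal sketch of the PKV-reuse pattern exercised by the test above (not part of the commit).
import { AutoTokenizer, AutoModelForCausalLM } from "@huggingface/transformers";

const model_id = "onnx-internal-testing/tiny-random-LlamaForCausalLM-GQA";
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await AutoModelForCausalLM.from_pretrained(model_id);

// First pass: ask for a dict-style result so the key/value cache comes back with the tokens.
const { past_key_values, sequences } = await model.generate({
  ...tokenizer("1"),
  max_new_tokens: 5,
  do_sample: false,
  return_dict_in_generate: true,
});

// Extend the prompt with new text appended to the previously generated output.
const decoded = tokenizer.batch_decode(sequences, { skip_special_tokens: false })[0];
const new_inputs = tokenizer(decoded + "2", { add_special_tokens: false });

// Second pass: reuse the cache so the shared prefix is not recomputed.
const output = await model.generate({
  ...new_inputs,
  past_key_values,
  max_new_tokens: 3,
  do_sample: false,
});
console.log(output.tolist());

The commit's assertion is exactly that `output` here matches the tokens produced by an identical call without `past_key_values`, i.e. cache reuse must not change greedy decoding results.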
