@@ -282,6 +282,67 @@ describe("PKV caching", () => {
282282 } , MAX_MODEL_DISPOSE_TIME ) ;
283283 } ) ;
284284
describe("LlamaForCausalLM (onnxruntime-genai)", () => {
  const model_id = "onnx-internal-testing/tiny-random-LlamaForCausalLM-GQA";

  /** @type {LlamaForCausalLM} */
  let model;
  /** @type {LlamaTokenizer} */
  let tokenizer;

  beforeAll(async () => {
    model = await LlamaForCausalLM.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
    tokenizer = await LlamaTokenizer.from_pretrained(model_id);
  }, MAX_MODEL_LOAD_TIME);

  it(
    "batch_size=1",
    async () => {
      const initial_inputs = tokenizer("1");

      // First pass: greedy generation with no cache, capturing the past key values.
      // NOTE: `return_dict_in_generate=true` is required to get PKV
      const { past_key_values, sequences } = await model.generate({
        ...initial_inputs,
        max_new_tokens: 5,
        do_sample: false,
        return_dict_in_generate: true,
      });

      // Re-tokenize the generated text with an extra character appended.
      const prior_text = tokenizer.batch_decode(sequences, {
        skip_special_tokens: false,
      })[0];
      const continued_inputs = tokenizer(prior_text + "2", {
        add_special_tokens: false,
      });

      // Second pass, computed from scratch (no cache reuse).
      const ids_without_cache = await model.generate({
        ...continued_inputs,
        max_new_tokens: 3,
        do_sample: false,
      });

      // Second pass again, this time reusing the cached past key values.
      const ids_with_cache = await model.generate({
        ...continued_inputs,
        past_key_values,
        max_new_tokens: 3,
        do_sample: false,
      });

      // Cached and uncached runs must agree token-for-token.
      const expected_ids = [[128000n, 16n, 34732n, 98805n, 116404n, 68265n, 99392n, 17n, 21855n, 60933n, 14285n]];
      expect(ids_without_cache.tolist()).toEqual(expected_ids);
      expect(ids_with_cache.tolist()).toEqual(expected_ids);
    },
    MAX_TEST_EXECUTION_TIME,
  );

  afterAll(async () => {
    await model?.dispose();
  }, MAX_MODEL_DISPOSE_TIME);
});
345+
285346 describe ( "LlavaForConditionalGeneration" , ( ) => {
286347 const model_id = "Xenova/tiny-random-LlavaForConditionalGeneration" ;
287348 /** @type {LlavaForConditionalGeneration } */
0 commit comments