@@ -1450,25 +1450,35 @@ def test_compare_with_and_without_past_key_values(self):
         model_id = MODEL_NAMES["gpt2"]
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokens = tokenizer("This is a sample input", return_tensors="pt")
+
         model_with_pkv = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False)
         outputs_model_with_pkv = model_with_pkv.generate(
             **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
         )
+        del model_with_pkv
+
         model_without_pkv = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=False)
         outputs_model_without_pkv = model_without_pkv.generate(
             **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
         )
+        del model_without_pkv
+
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
         self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
         self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
+
         model_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
         outputs_model_stateful = model_stateful.generate(
             **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
         )
         self.assertTrue(torch.equal(outputs_model_without_pkv, outputs_model_stateful))

-        del model_with_pkv
-        del model_without_pkv
+        logits = model_stateful(**tokens).logits
+        copy_logits = copy.deepcopy(logits)
+        tokens = tokenizer("Input sample", return_tensors="pt")
+        model_stateful(**tokens).logits
+        self.assertTrue(torch.equal(copy_logits, logits))
+        del model_stateful
         gc.collect()

     def test_print_model_properties(self):
@@ -1496,7 +1506,7 @@ def test_auto_device_loading(self):

     def test_default_filling_attention_mask(self):
         model_id = MODEL_NAMES["gpt2"]
-        model_with_cache = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
+        model_with_cache = OVModelForCausalLM.from_pretrained(model_id, stateful=False, use_cache=True)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
         texts = ["this is a simple input"]
@@ -1519,7 +1529,7 @@ def test_default_filling_attention_mask(self):

     def test_default_filling_attention_mask_and_position_ids(self):
         model_id = MODEL_NAMES["llama"]
-        model_with_cache = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
+        model_with_cache = OVModelForCausalLM.from_pretrained(model_id, stateful=False, use_cache=True)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
         texts = ["this is a simple input"]
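For context, a minimal standalone sketch of the in-place-mutation check the new stateful test adds. The prompts and the from_pretrained keyword arguments are taken from the diff; the "gpt2" model id is an assumption (the test suite resolves it through MODEL_NAMES, which likely points at a tiny test checkpoint), and the import path assumes the public optimum-intel API:

    import copy
    import gc

    import torch
    from transformers import AutoTokenizer
    from optimum.intel import OVModelForCausalLM

    model_id = "gpt2"  # assumption: stands in for MODEL_NAMES["gpt2"]
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)

    # First forward pass; keep a deep copy of the returned logits.
    logits = model(**tokenizer("This is a sample input", return_tensors="pt")).logits
    saved = copy.deepcopy(logits)

    # A second forward pass on a different input must not mutate the tensor
    # returned earlier, even though the stateful model keeps its KV cache
    # inside the inference request rather than in past_key_values tensors.
    model(**tokenizer("Input sample", return_tensors="pt")).logits
    assert torch.equal(saved, logits)

    del model
    gc.collect()

The switch from export=True to stateful=False in the two attention-mask tests presumably reflects the same distinction: those tests exercise explicit past_key_values and attention_mask inputs, which only exist on the non-stateful export.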