@@ -70,35 +70,39 @@ def test_compare_to_transformers(self, model_arch):
7070 ipex_model = IPEXModelForCausalLM .from_pretrained (model_id , torch_dtype = dtype , device_map = DEVICE )
7171 self .assertIsInstance (ipex_model .config , PretrainedConfig )
7272 tokenizer = AutoTokenizer .from_pretrained (model_id )
73- tokens = tokenizer (
74- "This is a sample" ,
75- return_tensors = "pt" ,
76- return_token_type_ids = False if model_arch in ("llama2" ,) else None ,
77- ).to (DEVICE )
78- inputs = ipex_model .prepare_inputs_for_generation (** tokens )
79- outputs = ipex_model (** inputs )
73+ texts = ["This is a sample" , ["This is the first input" , "This is the second input" ]]
74+ for text in texts :
75+ tokens = tokenizer (
76+ text ,
77+ return_tensors = "pt" ,
78+ return_token_type_ids = False if model_arch in ("llama2" ,) else None ,
79+ ).to (DEVICE )
80+ outputs = ipex_model (** tokens )
81+ inputs = ipex_model .prepare_inputs_for_generation (** tokens )
82+ outputs_2 = ipex_model (** inputs )
83+ self .assertTrue (torch .allclose (outputs .logits , outputs_2 .logits , atol = 1e-3 ))
8084
81- self .assertIsInstance (outputs .logits , torch .Tensor )
85+ self .assertIsInstance (outputs .logits , torch .Tensor )
8286
83- transformers_model = AutoModelForCausalLM .from_pretrained (model_id , torch_dtype = dtype , device_map = DEVICE )
84- with torch .no_grad ():
85- transformers_outputs = transformers_model (** tokens )
87+ transformers_model = AutoModelForCausalLM .from_pretrained (model_id , torch_dtype = dtype , device_map = DEVICE )
88+ with torch .no_grad ():
89+ transformers_outputs = transformers_model (** tokens )
8690
87- # Test re-load model
88- with tempfile .TemporaryDirectory () as tmpdirname :
89- ipex_model .save_pretrained (tmpdirname )
90- loaded_model = self .IPEX_MODEL_CLASS .from_pretrained (tmpdirname , torch_dtype = dtype , device_map = DEVICE )
91- loaded_model_outputs = loaded_model (** inputs )
91+ # Test re-load model
92+ with tempfile .TemporaryDirectory () as tmpdirname :
93+ ipex_model .save_pretrained (tmpdirname )
94+ loaded_model = self .IPEX_MODEL_CLASS .from_pretrained (tmpdirname , torch_dtype = dtype , device_map = DEVICE )
95+ loaded_model_outputs = loaded_model (** inputs )
9296
93- # Test init method
94- init_model = self .IPEX_MODEL_CLASS (transformers_model )
95- init_model_outputs = init_model (** inputs )
97+ # Test init method
98+ init_model = self .IPEX_MODEL_CLASS (transformers_model )
99+ init_model_outputs = init_model (** inputs )
96100
97- # Compare tensor outputs
98- self .assertTrue (torch .allclose (outputs .logits , transformers_outputs .logits , atol = 1e-3 ))
99- # To avoid float pointing error
100- self .assertTrue (torch .allclose (outputs .logits , loaded_model_outputs .logits , atol = 1e-7 ))
101- self .assertTrue (torch .allclose (outputs .logits , init_model_outputs .logits , atol = 1e-7 ))
101+ # Compare tensor outputs
102+ self .assertTrue (torch .allclose (outputs .logits , transformers_outputs .logits , atol = 1e-3 ))
103+ # To avoid float pointing error
104+ self .assertTrue (torch .allclose (outputs .logits , loaded_model_outputs .logits , atol = 1e-7 ))
105+ self .assertTrue (torch .allclose (outputs .logits , init_model_outputs .logits , atol = 1e-7 ))
102106
103107 @parameterized .expand (SUPPORTED_ARCHITECTURES )
104108 @unittest .skip (reason = "Paged attention do not support assisted decoding for now" )
0 commit comments