@@ -241,7 +241,6 @@ def test_compare_to_transformers(self, model_arch):
241241        model_id  =  MODEL_NAMES [model_arch ]
242242        set_seed (SEED )
243243        dtype  =  torch .float16  if  IS_XPU_AVAILABLE  else  torch .float32 
244-         # Test model forward do not need cache. 
245244        ipex_model  =  IPEXModelForCausalLM .from_pretrained (model_id , torch_dtype = dtype , device_map = DEVICE )
246245        self .assertIsInstance (ipex_model .config , PretrainedConfig )
247246        tokenizer  =  AutoTokenizer .from_pretrained (model_id )
@@ -275,6 +274,38 @@ def test_compare_to_transformers(self, model_arch):
275274        self .assertTrue (torch .allclose (outputs .logits , loaded_model_outputs .logits , atol = 1e-7 ))
276275        self .assertTrue (torch .allclose (outputs .logits , init_model_outputs .logits , atol = 1e-7 ))
277276
277+     @parameterized .expand (SUPPORTED_ARCHITECTURES ) 
278+     def  test_forward (self , model_arch ):
279+         model_id  =  MODEL_NAMES [model_arch ]
280+         set_seed (SEED )
281+         dtype  =  torch .float16  if  IS_XPU_AVAILABLE  else  torch .float32 
282+         ipex_model  =  IPEXModelForCausalLM .from_pretrained (model_id , torch_dtype = dtype , device_map = DEVICE )
283+         self .assertIsInstance (ipex_model .config , PretrainedConfig )
284+         input_ids  =  torch .Tensor ([[1 , 2 , 3 ], [4 , 5 , 6 ]]).to (torch .long )
285+         outputs  =  ipex_model (input_ids )
286+ 
287+         self .assertIsInstance (outputs .logits , torch .Tensor )
288+ 
289+         transformers_model  =  AutoModelForCausalLM .from_pretrained (model_id , torch_dtype = dtype , device_map = DEVICE )
290+         with  torch .no_grad ():
291+             transformers_outputs  =  transformers_model (input_ids )
292+ 
293+         # Test re-load model 
294+         with  tempfile .TemporaryDirectory () as  tmpdirname :
295+             ipex_model .save_pretrained (tmpdirname )
296+             loaded_model  =  self .IPEX_MODEL_CLASS .from_pretrained (tmpdirname , torch_dtype = dtype , device_map = DEVICE )
297+             loaded_model_outputs  =  loaded_model (input_ids )
298+ 
299+         # Test init method 
300+         init_model  =  self .IPEX_MODEL_CLASS (transformers_model )
301+         init_model_outputs  =  init_model (input_ids )
302+ 
303+         # Compare tensor outputs 
304+         self .assertTrue (torch .allclose (outputs .logits , transformers_outputs .logits , atol = 1e-4 ))
305+         # To avoid float pointing error 
306+         self .assertTrue (torch .allclose (outputs .logits , loaded_model_outputs .logits , atol = 1e-7 ))
307+         self .assertTrue (torch .allclose (outputs .logits , init_model_outputs .logits , atol = 1e-7 ))
308+ 
278309    @parameterized .expand (SUPPORTED_ARCHITECTURES ) 
279310    def  test_pipeline (self , model_arch ):
280311        dtype  =  torch .float16  if  IS_XPU_AVAILABLE  else  torch .float32 
0 commit comments