@@ -483,6 +483,45 @@ def test_bnb(self):
483483 self .assertTrue (torch .allclose (outputs .logits , loaded_model_outputs .logits , atol = 1e-7 ))
484484 self .assertTrue (torch .allclose (outputs .logits , init_model_outputs .logits , atol = 1e-7 ))
485485
@unittest.skipIf(not is_auto_awq_available(), reason="Test requires autoawq")
def test_awq(self):
    # Exercise a quantized checkpoint (AWQ 4-bit, per its model id) through the
    # IPEX causal-LM wrapper and check its logits against three references:
    # the plain transformers model, a copy reloaded after save_pretrained, and
    # a wrapper constructed directly around the transformers model.
    checkpoint = "PrunaAI/JackFram-llama-68m-AWQ-4bit-smashed"
    set_seed(SEED)
    if IS_XPU_AVAILABLE:
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32
    load_kwargs = {"torch_dtype": torch_dtype, "device_map": DEVICE}

    # The model forward under test does not need a cache.
    ipex_model = IPEXModelForCausalLM.from_pretrained(checkpoint, **load_kwargs)
    self.assertIsInstance(ipex_model.config, PretrainedConfig)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoded = tokenizer("This is a sample", return_tensors="pt", return_token_type_ids=False).to(DEVICE)
    model_inputs = ipex_model.prepare_inputs_for_generation(**encoded)
    ipex_logits = ipex_model(**model_inputs).logits

    self.assertIsInstance(ipex_logits, torch.Tensor)

    # Reference: the same checkpoint run through plain transformers.
    transformers_model = AutoModelForCausalLM.from_pretrained(checkpoint, **load_kwargs)
    with torch.no_grad():
        reference_logits = transformers_model(**encoded).logits

    # Reference: save the IPEX model, load it back and run it again.
    with tempfile.TemporaryDirectory() as save_dir:
        ipex_model.save_pretrained(save_dir)
        reloaded_model = self.IPEX_MODEL_CLASS.from_pretrained(save_dir, **load_kwargs)
        reloaded_logits = reloaded_model(**model_inputs).logits

    # Reference: build the IPEX wrapper straight from the transformers model.
    wrapped_model = self.IPEX_MODEL_CLASS(transformers_model)
    wrapped_logits = wrapped_model(**model_inputs).logits

    # IPEX vs. transformers is compared with a loose tolerance.
    self.assertTrue(torch.allclose(ipex_logits, reference_logits, atol=5e-2))
    # The reloaded and directly wrapped models only get a tiny tolerance,
    # enough to absorb floating point error.
    for other_logits in (reloaded_logits, wrapped_logits):
        self.assertTrue(torch.allclose(ipex_logits, other_logits, atol=1e-7))
524+
486525
487526class IPEXModelForAudioClassificationTest (unittest .TestCase ):
488527 IPEX_MODEL_CLASS = IPEXModelForAudioClassification
0 commit comments