File tree: 2 files changed, +25 −4 (tests/quantization/torchao_integration)

@@ -943,13 +943,14 @@ def _load_state_dict_into_meta_model(
         old_param = model
         splits = param_name.split(".")
         for split in splits:
-            old_param = getattr(old_param, split)
-            # Not all the attributes of a module are Parameters/Tensor
-            if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
-                old_param = None
+            # We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
+            old_param = getattr(old_param, split, None)
             if old_param is None:
                 break

+        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
+            old_param = None
+
         if old_param is not None:
             if dtype is None:
                 param = param.to(old_param.dtype)
@@ -208,6 +208,26 @@ def test_int4wo_offload(self):
 
         self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
 
+    def test_int8_dynamic_activation_int8_weight_quant(self):
+        """
+        Simple LLM model testing int8_dynamic_activation_int8_weight
+        """
+        quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
+
+        # Note: we quantize the bfloat16 model on the fly to int8 (dynamic activation, int8 weight)
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=torch_device,
+            quantization_config=quant_config,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+        self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
+
 
 if __name__ == "__main__":
     unittest.main()
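A rough usage sketch of the configuration exercised by the new test, outside the test harness. This assumes torchao is installed and a CUDA device is available; the checkpoint name is an illustrative placeholder, not the model used in the test suite.

# Hedged sketch: int8 dynamic-activation / int8-weight quantization via TorchAoConfig.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_name = "facebook/opt-125m"  # placeholder checkpoint, not from the PR
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")

# The bfloat16 weights are quantized on the fly while loading.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("What are we having for dinner?", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))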