@@ -768,19 +768,18 @@ def main(args):
     # quantize the model
     model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only)
 
-    # amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nano_nemotron_v2_vlm_fp8_ptq_amax.pt")
-
-    # model_keys = model.load_state_dict(amax_state_dict, strict=False)
-    # print(f"Loaded amax_state_dict with keys: {model_keys}")
-    # mtq.print_quant_summary(model)
+    # amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict_scalers_only.pt")
 
 
     # For VL models, update full_model to use the quantized language model
     if is_nemotron_vl and hasattr(full_model, "language_model"):
         print("Updating full_model with quantized language_model...")
         full_model.language_model = model
-        fullmodel_key = full_model.load_state_dict(torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict.pt"), strict=False)
-        print(f"Loaded full_model_state_dict with keys: {fullmodel_key}")
+        amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict_scalers_only.pt")
+        model_keys = full_model.load_state_dict(amax_state_dict, strict=False)
+        print(f"Loaded amax_state_dict with keys: {model_keys}")
+        # fullmodel_key = full_model.load_state_dict(torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict.pt"), strict=False)
+        # print(f"Loaded full_model_state_dict with keys: {fullmodel_key}")
         mtq.print_quant_summary(full_model.language_model)
         print("Loaded additional state dict into full_model.")
     if args.verbose:
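
Note: the hunk above swaps a full FP4 PTQ checkpoint load for a scaler-only state dict, relying on load_state_dict(..., strict=False) to skip the weight keys that the scaler-only file does not contain. A minimal sketch of that pattern follows, using a toy module and a hypothetical file name rather than the actual Nemotron VL model or checkpoint path:

    import torch
    import torch.nn as nn

    class ToyQuantLinear(nn.Module):
        # Stand-in for a quantized layer carrying an "amax" calibration buffer.
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(16, 16)
            self.register_buffer("input_quantizer_amax", torch.tensor(1.0))

    model = ToyQuantLinear()

    # Keep only the scaler-like entries (amax calibration values) from the state dict.
    scalers_only = {k: v for k, v in model.state_dict().items() if "amax" in k}
    torch.save(scalers_only, "ptq_state_dict_scalers_only.pt")  # hypothetical path

    # strict=False lets load_state_dict skip the weight/bias keys the scaler-only
    # checkpoint does not contain; the return value reports what was skipped.
    result = model.load_state_dict(torch.load("ptq_state_dict_scalers_only.pt"), strict=False)
    print(f"missing: {result.missing_keys}, unexpected: {result.unexpected_keys}")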