@@ -797,19 +797,18 @@ def main(args):
     # quantize the model
     model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only)

-    # amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nano_nemotron_v2_vlm_fp8_ptq_amax.pt")
-
-    # model_keys = model.load_state_dict(amax_state_dict, strict=False)
-    # print(f"Loaded amax_state_dict with keys: {model_keys}")
-    # mtq.print_quant_summary(model)
+    # amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict_scalers_only.pt")


     # For VL models, update full_model to use the quantized language model
     if is_nemotron_vl and hasattr(full_model, "language_model"):
         print("Updating full_model with quantized language_model...")
         full_model.language_model = model
-        fullmodel_key = full_model.load_state_dict(torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict.pt"), strict=False)
-        print(f"Loaded full_model_state_dict with keys: {fullmodel_key}")
+        amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict_scalers_only.pt")
+        model_keys = full_model.load_state_dict(amax_state_dict, strict=False)
+        print(f"Loaded amax_state_dict with keys: {model_keys}")
+        # fullmodel_key = full_model.load_state_dict(torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict.pt"), strict=False)
+        # print(f"Loaded full_model_state_dict with keys: {fullmodel_key}")
         mtq.print_quant_summary(full_model.language_model)
         print("Loaded additional state dict into full_model.")
     if args.verbose:
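The change swaps the full quantized checkpoint for a "scalers only" one loaded with `strict=False`, so only the quantizer calibration entries are restored while the rest of the weights stay untouched. A minimal sketch of how such a scalers-only file could be produced from a full checkpoint is below; the key filter (`"amax"` / `"_scale"`) and the file paths are assumptions based on ModelOpt-style quantizer naming, not something confirmed by this commit.

```python
# Sketch: extract quantizer scaler entries from a full quantized state dict.
# Assumption: scaler keys contain "amax" or "_scale" (ModelOpt-style naming);
# file names are illustrative.
import torch

full_state_dict = torch.load("llama_nemotron_v2_fp4_ptq_state_dict.pt")
scalers_only = {
    k: v for k, v in full_state_dict.items()
    if "amax" in k or "_scale" in k
}
torch.save(scalers_only, "llama_nemotron_v2_fp4_ptq_state_dict_scalers_only.pt")

# Loading the filtered dict with strict=False tolerates all the missing
# non-scaler keys and returns (missing_keys, unexpected_keys) for inspection,
# which is what the print(f"Loaded amax_state_dict ...") line surfaces above.
```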