Skip to content

Commit d5597e2

Browse files
committed
load scalers only for v2 fp4
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent c6f623d commit d5597e2

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

examples/llm_ptq/hf_ptq.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -797,19 +797,18 @@ def main(args):
797797
# quantize the model
798798
model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only)
799799

800-
# amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nano_nemotron_v2_vlm_fp8_ptq_amax.pt")
801-
802-
# model_keys = model.load_state_dict(amax_state_dict, strict=False)
803-
# print(f"Loaded amax_state_dict with keys: {model_keys}")
804-
# mtq.print_quant_summary(model)
800+
# amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict_scalers_only.pt")
805801

806802

807803
# For VL models, update full_model to use the quantized language model
808804
if is_nemotron_vl and hasattr(full_model, "language_model"):
809805
print("Updating full_model with quantized language_model...")
810806
full_model.language_model = model
811-
fullmodel_key = full_model.load_state_dict(torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict.pt"), strict=False)
812-
print(f"Loaded full_model_state_dict with keys: {fullmodel_key}")
807+
amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict_scalers_only.pt")
808+
model_keys = full_model.load_state_dict(amax_state_dict, strict=False)
809+
print(f"Loaded amax_state_dict with keys: {model_keys}")
810+
# fullmodel_key = full_model.load_state_dict(torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict.pt"), strict=False)
811+
# print(f"Loaded full_model_state_dict with keys: {fullmodel_key}")
813812
mtq.print_quant_summary(full_model.language_model)
814813
print("Loaded additional state dict into full_model.")
815814
if args.verbose:

0 commit comments

Comments
 (0)