Skip to content

Commit 16b9f8f

Browse files
committed
load scalers only for v2 fp4
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent 130e2e8 commit 16b9f8f

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

examples/llm_ptq/hf_ptq.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -768,19 +768,18 @@ def main(args):
768768
# quantize the model
769769
model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only)
770770

771-
# amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nano_nemotron_v2_vlm_fp8_ptq_amax.pt")
772-
773-
# model_keys = model.load_state_dict(amax_state_dict, strict=False)
774-
# print(f"Loaded amax_state_dict with keys: {model_keys}")
775-
# mtq.print_quant_summary(model)
771+
# amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict_scalers_only.pt")
776772

777773

778774
# For VL models, update full_model to use the quantized language model
779775
if is_nemotron_vl and hasattr(full_model, "language_model"):
780776
print("Updating full_model with quantized language_model...")
781777
full_model.language_model = model
782-
fullmodel_key = full_model.load_state_dict(torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict.pt"), strict=False)
783-
print(f"Loaded full_model_state_dict with keys: {fullmodel_key}")
778+
amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict_scalers_only.pt")
779+
model_keys = full_model.load_state_dict(amax_state_dict, strict=False)
780+
print(f"Loaded amax_state_dict with keys: {model_keys}")
781+
# fullmodel_key = full_model.load_state_dict(torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict.pt"), strict=False)
782+
# print(f"Loaded full_model_state_dict with keys: {fullmodel_key}")
784783
mtq.print_quant_summary(full_model.language_model)
785784
print("Loaded additional state dict into full_model.")
786785
if args.verbose:

0 commit comments

Comments
 (0)