Skip to content

Commit 130e2e8

Browse files
committed
debug loading v2 converted nvfp4 weights from mcore
1 parent 659a2c2 commit 130e2e8

File tree

1 file changed

+13
-1
lines changed

examples/llm_ptq/hf_ptq.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -768,12 +768,24 @@ def main(args):
768768
# quantize the model
769769
model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only)
770770

771+
# amax_state_dict = torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nano_nemotron_v2_vlm_fp8_ptq_amax.pt")
772+
773+
# model_keys = model.load_state_dict(amax_state_dict, strict=False)
774+
# print(f"Loaded amax_state_dict with keys: {model_keys}")
775+
# mtq.print_quant_summary(model)
776+
777+
771778
# For VL models, update full_model to use the quantized language model
772779
if is_nemotron_vl and hasattr(full_model, "language_model"):
773780
print("Updating full_model with quantized language_model...")
774781
full_model.language_model = model
782+
fullmodel_key = full_model.load_state_dict(torch.load("/home/scratch.omniml_data_2/jingyux/models/llama_nemotron_v2_fp4_ptq_state_dict.pt"), strict=False)
783+
print(f"Loaded full_model_state_dict with keys: {fullmodel_key}")
784+
mtq.print_quant_summary(full_model.language_model)
785+
print("Loaded additional state dict into full_model.")
775786
if args.verbose:
776-
mtq.print_quant_summary(model)
787+
pass
788+
# mtq.print_quant_summary(model)
777789

778790
# Run some samples
779791
torch.cuda.empty_cache()

0 commit comments

Comments (0)