Commit 5ade7b0

[NNBUG: 5701866] Update DS V3.2 PTQ code (NVIDIA#630)
## What does this PR do?

**Type of change:** Bug fix

**Overview:**

1) Update the DS V3.2 repo code reference to the latest version.
2) The new DS V3.2 model now includes fp32 layers; we cast them down to match the checkpoint format during loading, then restore them to fp32 afterwards.
3) Fix the get_quant_config API change.

## Testing

Generate the deepseek-ai/DeepSeek-V3.2 checkpoint.

## Before your PR is "Ready for review"

- **Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update the [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

Signed-off-by: Chenjie Luo <[email protected]>
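For context on change 2 above: a quick way to see which parameters a model keeps in fp32 is to scan `named_parameters()` by dtype. The helper below is a hypothetical inspection snippet, not part of this PR; in DS V3.2 the fp32 parameters include `head.weight` and `attn.indexer.weights_proj.weight` (the two names matched in the diff further down).

```python
import torch
import torch.nn as nn

def list_fp32_params(model: nn.Module) -> list[str]:
    """Return the names of parameters stored in torch.float32."""
    return [
        name for name, param in model.named_parameters()
        if param.dtype == torch.float32
    ]

# Toy example: a bf16 module with one parameter kept in fp32.
toy = nn.Linear(4, 4).to(torch.bfloat16)
toy.bias.data = toy.bias.data.to(torch.float32)
print(list_fp32_params(toy))  # -> ['bias']
```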
Parent commit: d0b0c0f

File tree: 2 files changed (+13, -2 lines)

examples/deepseek/README.md (1 addition, 1 deletion)

```diff
@@ -33,7 +33,7 @@ git clone https://github.com/deepseek-ai/DeepSeek-V3.git && cd DeepSeek-V3 && gi
 huggingface-cli download deepseek-ai/DeepSeek-V3.2-Exp --local-dir $HF_FP8_CKPT

 # clone DeepSeek-V3.2 Github repository for FP8 inference,
-git clone https://github.com/deepseek-ai/DeepSeek-V3.2-Exp.git && cd DeepSeek-V3.2-Exp && git checkout 3b99a53
+git clone https://github.com/deepseek-ai/DeepSeek-V3.2-Exp.git && cd DeepSeek-V3.2-Exp && git checkout 87e509a

 # Install requirements
 pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
```

examples/deepseek/ptq.py (12 additions, 1 deletion)

```diff
@@ -257,7 +257,18 @@ def load_deepseek_model(model_config: str, model_path: str, batch_size: int):
     # load model
     checkpoint_path = os.path.join(model_path, f"model{rank}-mp{world_size}.safetensors")
     print(f"Loading {checkpoint_path}")
+
+    # Temporary fix for fp32 params
+    fp32_params = {}
+    for name, param in model.named_parameters():
+        if param.dtype == torch.float32 and (
+            "head.weight" in name or "attn.indexer.weights_proj.weight" in name
+        ):
+            param.data = param.data.to(torch.get_default_dtype())
+            fp32_params[name] = param
     load_model(model, checkpoint_path)
+    for param in fp32_params.values():
+        param.data = param.data.to(torch.float32)
     print(f"Loaded {checkpoint_path}")
     return model
@@ -347,7 +358,7 @@ def state_dict_filter(state_dict):
     # counts = module.activated_expert_counts()
     # f.writelines(f"{name}: {count}\n" for count in counts)

-    quant_config = get_quant_config(model.named_modules())
+    quant_config = get_quant_config(model)

     if enable_fp8_kvcache:
         quant_config["quantization"]["kv_cache_quant_algo"] = KV_CACHE_FP8
```
