
Commit 19722c9

Support DeepSeek V3.2 models

Signed-off-by: Chenjie Luo <[email protected]>

1 parent: 32d168c

File tree

8 files changed (+117, −17 lines)


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -13,6 +13,7 @@ Model Optimizer Changelog (Linux)
 - Support PTQ and fakequant in vLLM for fast evaluation of arbitrary quantization formats. See ``examples/vllm_serve`` for more details.
 - Add support for ``nemotron-post-training-dataset-v2`` and ``nemotron-post-training-dataset-v1`` in ``examples/llm_ptq``. Default to a mix of ``cnn_dailymail`` and ``nemotron-post-training-dataset-v2`` if no dataset is specified.
 - Allow specifying ``calib_seq`` in ``examples/llm_ptq`` to set the maximum sequence length for calibration.
+- Support ``DeepSeek V3.2`` model quantization. See ``examples/deepseek`` for more details.
 
 **Documentation**
 
```

examples/deepseek/.gitignore

Lines changed: 1 addition & 0 deletions

```diff
@@ -1 +1,2 @@
 DeepSeek-V3/
+DeepSeek-V3.2-Exp/
```

examples/deepseek/README.md

Lines changed: 35 additions & 5 deletions

````diff
@@ -6,34 +6,64 @@ This example will demonstrate the steps to quantize DeepSeek R1 model to FP4 and
 
 Due to the model size, quantizing the FP8 model currently requires 8xH200 or 16xH100; we will use 8xH200 as the example.
 
-### Convert the HF checkpoint for deepseek FP8 inference
+## Convert the HF checkpoint for DeepSeek FP8 inference
 
 ```bash
 # set up variables to run the example
 export HF_FP8_CKPT={path_to_downloaded_hf_checkpoint}
 export DS_CKPT={path_to_save_converted_checkpoint}
 export FP4_QUANT_PATH={path_to_save_quantization_results}
 export HF_FP4_PATH={path_to_save_the_final_FP4_checkpoint}
+```
+
+### DeepSeek V3, R1 and V3.1
 
-# download the FP8 checkpoint from Hugginface
+```bash
+# download the FP8 checkpoint from Hugging Face; this example uses DeepSeek-R1
 huggingface-cli download deepseek-ai/DeepSeek-R1 --local-dir $HF_FP8_CKPT
 
 # clone the DeepSeek-V3 (base model of R1) GitHub repository for FP8 inference
 git clone https://github.com/deepseek-ai/DeepSeek-V3.git && cd DeepSeek-V3 && git checkout 1398800
+```
+
+### DeepSeek V3.2
 
+```bash
+# download the FP8 checkpoint from Hugging Face
+huggingface-cli download deepseek-ai/DeepSeek-V3.2-Exp --local-dir $HF_FP8_CKPT
+
+# clone the DeepSeek-V3.2 GitHub repository for FP8 inference
+git clone https://github.com/deepseek-ai/DeepSeek-V3.2-Exp.git && cd DeepSeek-V3.2-Exp && git checkout 3b99a53
+
+# install requirements
+pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
+pip install -r DeepSeek-V3.2-Exp/inference/requirements.txt
+```
+
+### Convert the Checkpoint
+
+```bash
 # convert the HF checkpoint to the DeepSeek-specific format
 python inference/convert.py --hf-ckpt-path $HF_FP8_CKPT --save-path $DS_CKPT --n-experts 256 --model-parallel 8
 ```
 
-### Post-training quantization
+## Post-training quantization
+
+### Run the calibration scripts
 
-#### Run the calibration scripts
+DeepSeek V3:
 
 ```bash
 torchrun --nproc-per-node 8 --master_port=12346 ptq.py --model_path $DS_CKPT --config DeepSeek-V3/inference/configs/config_671B.json --quant_cfg NVFP4_DEFAULT_CFG --output_path $FP4_QUANT_PATH
 ```
 
-#### Quantize the FP8 hf checkpoint to FP4
+DeepSeek V3.2:
+
+```bash
+torchrun --nproc-per-node 8 --master_port=12346 ptq.py --model_path $DS_CKPT --config DeepSeek-V3.2-Exp/inference/config_671B_v3.2.json --quant_cfg NVFP4_DEFAULT_CFG --output_path $FP4_QUANT_PATH
+```
+
+### Quantize the FP8 HF checkpoint to FP4
 
 We provide a one-step script which will:
````

examples/deepseek/ds_kernel.py

Lines changed: 58 additions & 0 deletions (new file, shown in full)

```python
import torch
import triton
import triton.language as tl

"""Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py"""


@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.

    Args:
        x_ptr (tl.pointer): Pointer to the quantized weights.
        s_ptr (tl.pointer): Pointer to the scaling factors.
        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    s = tl.load(s_ptr + pid_m * n + pid_n)
    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)


def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.

    Args:
        x (torch.Tensor): The quantized weight tensor of shape (M, N).
        s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size).
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
    """
    assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
    assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
    M, N = x.size()
    y = torch.empty_like(x, dtype=torch.get_default_dtype())
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y
```
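For reference, a minimal usage sketch of the new helper (shapes and values are illustrative, not from the commit; it assumes a CUDA device with Triton available and a PyTorch build that supports `float8_e4m3fn`):

```python
import torch

from ds_kernel import weight_dequant

# illustrative 256x512 FP8 weight with one scale per 128x128 block
torch.set_default_dtype(torch.bfloat16)  # weight_dequant allocates its output in the default dtype
x = torch.randn(256, 512, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn)
s = torch.rand(256 // 128, 512 // 128, device="cuda", dtype=torch.float32)

w = weight_dequant(x, s)  # each 128x128 block of x is multiplied by its scale in s
assert w.shape == x.shape and w.dtype == torch.bfloat16
```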

examples/deepseek/ptq.py

Lines changed: 17 additions & 5 deletions

```diff
@@ -64,9 +64,21 @@
 from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
 from modelopt.torch.utils.distributed import ParallelState
 
-sys.path.append(str(Path(__file__).resolve().parent / "DeepSeek-V3/inference"))
-import model as deekseep_model
-from kernel import act_quant, fp8_gemm, weight_dequant
+DS_V3_PATH = Path(__file__).resolve().parent / "DeepSeek-V3/inference"
+DS_V3_2_PATH = Path(__file__).resolve().parent / "DeepSeek-V3.2-Exp/inference"
+
+if DS_V3_2_PATH.exists():
+    sys.path.append(str(DS_V3_2_PATH))
+elif DS_V3_PATH.exists():
+    sys.path.append(str(DS_V3_PATH))
+else:
+    raise ValueError(
+        f"DeepSeek-V3 or DeepSeek-V3.2-Exp not found in {Path(__file__).resolve().parent}"
+    )
+
+import model as deekseep_model  # noqa: E402
+from ds_kernel import weight_dequant  # noqa: E402
+from kernel import act_quant, fp8_gemm  # noqa: E402
 
 
 def monkey_patch_deepseek_model():
@@ -243,10 +255,10 @@ def ptq(
     ## create dataset
     device = next(model.parameters()).device
     calib_dataset = get_dataset_dataloader(
-        dataset_name="cnn_dailymail",
+        dataset_name=["cnn_dailymail", "nemotron-post-training-dataset-v2"],
         tokenizer=tokenizer,
         batch_size=batch_size,
-        num_samples=calib_size,
+        num_samples=[calib_size, calib_size],
         device=device,
     )
 
```
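Two behavioral changes are worth noting here. First, `weight_dequant` is now imported from the vendored `ds_kernel.py` shown above rather than from the cloned repository's `kernel.py`, so the import works regardless of which inference repo ends up on `sys.path` (the V3.2 checkout takes precedence when both exist), while `act_quant` and `fp8_gemm` still come from the checked-out repo. Second, calibration now draws `calib_size` samples each from `cnn_dailymail` and `nemotron-post-training-dataset-v2` instead of `cnn_dailymail` alone, matching the new default mix noted in the changelog.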

examples/deepseek/quantize_fp8_to_nvfp4.sh

Lines changed: 3 additions & 1 deletion

```diff
@@ -78,7 +78,9 @@ fi
 
 # Copy miscellaneous files to the quantized checkpoint
 mkdir -p $FP4_PATH
-cp $FP8_HF_PATH/*.json $FP8_HF_PATH/*.py $FP4_PATH/
+cp $FP8_HF_PATH/*.json $FP4_PATH/
+cp $FP8_HF_PATH/*.py $FP4_PATH/ || true
+cp -r $FP8_HF_PATH/assets $FP4_PATH/ || true
 
 # Run the quantization command
 echo "Running quantization..."
```

examples/deepseek/quantize_to_nvfp4.py

Lines changed: 1 addition & 5 deletions

```diff
@@ -41,19 +41,15 @@
 import glob
 import json
 import os
-import sys
-from pathlib import Path
 from typing import Any
 
 import torch
+from ds_kernel import weight_dequant
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm
 
 from modelopt.torch.quantization.qtensor import NVFP4QTensor
 
-sys.path.append(str(Path(__file__).resolve().parent / "DeepSeek-V3/inference"))
-from kernel import weight_dequant
-
 
 def _remap_key(key_dict: dict[str, Any]):
     # renaming the module to match HF modeling
```

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -65,7 +65,7 @@ extend-ignore = [
 "*/_[a-zA-Z]*" = ["D"] # Private packages (_abc/*.py) or modules (_xyz.py)
 "*.ipynb" = ["D", "E501"] # Ignore missing docstrings or line length for Jupyter notebooks
 "modelopt/torch/quantization/triton/*" = ["N803", "N806", "E731"] # triton style
-
+"examples/deepseek/ds_kernel.py" = ["N803", "N806", "E731"] # triton style
 
 [tool.ruff.lint.pycodestyle]
 max-line-length = 120 # Line length limit for comments and docstrings
```
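The new per-file ignore mirrors the existing triton entry one line above: N803 and N806 permit the uppercase argument and variable names (`M`, `N`, `BLOCK_SIZE`) conventional in Triton kernels, and E731 permits the `grid = lambda ...` assignment in `ds_kernel.py`.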
