FP8 Block quantize onnx export support #324
base: main
@@ -619,14 +619,37 @@ def _real_quantize(self, inputs):
        self._dequantize = True
        return outputs

    def _get_block_sizes_list(self, shape):
        """Convert block_sizes dict to list format based on tensor shape.

        Args:
            shape: The tensor shape to use for conversion (can be tuple or torch.Size)

        Returns:
            List of block sizes for each dimension, or None if block_sizes is None

        Example:
            block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1]
        """
        if self.block_sizes is None:
            return None

        block_sizes_list = []
        for dim in range(len(shape)):
            # Check both positive and negative dimension indices
            dim_negative = dim - len(shape)
            block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None)
            block_sizes_list.append(block_size if block_size is not None else 1)
        return block_sizes_list

    def _fake_quantize(self, inputs):
        """Fake quantization."""
        amax = None
        if not self.is_mx_format:
            amax = self._get_amax(inputs)

        if self.block_sizes is not None and self.block_sizes.get("type", "static") == "dynamic":
-           # Block quantization, including dynamic and static block quantization
+           # Double scale Block quantization, including dynamic and static block quantization
            block_size = self.block_sizes.get(-1, None) or self.block_sizes.get(
                inputs.dim() - 1, None
            )
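As a side note, here is a minimal standalone sketch of the dict-to-list conversion the new helper performs. The function below is illustrative only (the real method reads self.block_sizes from the quantizer):

import torch

def block_sizes_dict_to_list(block_sizes: dict | None, shape: torch.Size) -> list[int] | None:
    """Illustration of the helper above: dims without an entry get block size 1."""
    if block_sizes is None:
        return None
    result = []
    for dim in range(len(shape)):
        # A dimension can be keyed positively (2) or negatively (-2).
        size = block_sizes.get(dim) or block_sizes.get(dim - len(shape))
        result.append(size if size is not None else 1)
    return result

# Matches the docstring example: {-2: 32} on [2, 24, 4608, 128] -> [1, 1, 32, 1]
print(block_sizes_dict_to_list({-2: 32}, torch.Size([2, 24, 4608, 128])))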
@@ -648,6 +671,10 @@ def _fake_quantize(self, inputs):
            # Float-point quantization, e.g., FP8
            E, M = self._num_bits  # noqa: N806

            # Convert block_sizes dict to list format
            # Use original input shape if available (before reshaping), otherwise use current shape
            shape_for_block_sizes = getattr(self, "_original_input_shape", inputs.shape)
            block_sizes_list = self._get_block_sizes_list(shape_for_block_sizes)
            outputs = scaled_e4m3(
                inputs,
                amax,
@@ -656,6 +683,7 @@ def _fake_quantize(self, inputs):
                M,
                self._trt_high_precision_dtype,
                self._pass_through_bwd,
                block_sizes_list,
            )

        else:
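For intuition only, a rough sketch of what per-block FP8 (E4M3) fake quantization amounts to numerically. This is not the scaled_e4m3 kernel; it ignores the double-scale path, padding, and dtype corner cases, and assumes the blocked dim is positive and divisible by the block size, plus a PyTorch build with float8_e4m3fn:

import torch

E4M3_MAX = 448.0  # largest finite E4M3 magnitude

def fake_quant_fp8_blockwise(x: torch.Tensor, dim: int, block: int) -> torch.Tensor:
    """Scale each block into the E4M3 range, round via the FP8 dtype, then scale back."""
    orig_dim = x.shape[dim]
    x = x.unflatten(dim, (orig_dim // block, block))  # split dim into (num_blocks, block)
    amax = x.abs().amax(dim=dim + 1, keepdim=True).clamp(min=1e-12)
    scale = E4M3_MAX / amax
    xq = (x * scale).to(torch.float8_e4m3fn).to(x.dtype) / scale
    return xq.flatten(dim, dim + 1)

x = torch.randn(2, 24, 4608, 128)
y = fake_quant_fp8_blockwise(x, dim=2, block=32)  # corresponds to block_sizes {-2: 32}
print((x - y).abs().max())  # small quantization error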
@@ -901,9 +929,10 @@ def forward(self, inputs):
            and self.block_sizes.get("type", None) != "dynamic"
            and self._fake_quant
        ):
-           # Tensor reshaping is required for static block quantization
-           # Tensor shapes are handled separately by the quantization kernels for dynamic block quantization
+           # Reshape is required if the logic is not handled in the simulation kernel
+           # Only MX format and NVFP4 reshape are currently supported by the kernel.
            self._setup_for_blockquant(inputs)
            setattr(self, "_original_input_shape", inputs.shape)
            inputs = self._process_for_blockquant(inputs)
Comment on lines +932 to 936

Storing _original_input_shape as an attribute creates state leakage risks. Consider using a local variable or a context manager to ensure cleanup even on exceptions.

        if (
            self.block_sizes is not None
            and self.block_sizes.get("type", None) != "dynamic"
            and self._fake_quant
        ):
            # Reshape is required if the logic is not handled in the simulation kernel
            # Only MX format and NVFP4 reshape are currently supported by the kernel.
            self._setup_for_blockquant(inputs)
-           setattr(self, "_original_input_shape", inputs.shape)
+           original_input_shape = inputs.shape
+           self._original_input_shape = original_input_shape
            inputs = self._process_for_blockquant(inputs)

Then wrap the forward logic in try-finally as noted in the next comment.
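A minimal sketch of the try/finally shape the reviewer is suggesting; _run_quant_forward is a stand-in for the rest of forward(), not an actual method in the codebase:

def forward(self, inputs):
    self._original_input_shape = inputs.shape
    try:
        inputs = self._process_for_blockquant(inputs)
        outputs = self._run_quant_forward(inputs)  # placeholder for the remaining forward logic
    finally:
        # Clean up the temporary attribute even if quantization raises.
        if hasattr(self, "_original_input_shape"):
            del self._original_input_shape
    return outputs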
        outputs = inputs
@@ -943,6 +972,8 @@ def forward(self, inputs):
        ):
            outputs = self._reset_to_original_shape(outputs)

        if hasattr(self, "_original_input_shape"):
            delattr(self, "_original_input_shape")

        return outputs

    def _short_amax(self, fmt=".4f"):
modelopt/torch/quantization/plugins/diffusers.py
@@ -117,9 +117,26 @@ def _quantized_sdpa(self, *args, **kwargs):
        key = self.k_bmm_quantizer(key)
        value = self.v_bmm_quantizer(value)

-       q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
-       k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
-       v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
+       if (
+           not self.q_bmm_quantizer._dynamic
+           and not self.k_bmm_quantizer._dynamic
+           and not self.v_bmm_quantizer._dynamic
+       ):
+           q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
+           k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
+           v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
+       else:
+           assert (
+               self.q_bmm_quantizer._dynamic
+               and self.k_bmm_quantizer._dynamic
+               and self.v_bmm_quantizer._dynamic
+           ), "QKV QDQS must be in the same type"
+           q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None

        # Get block sizes lists for each quantizer if needed
        q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)  # type: ignore[union-attr]
        k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape)  # type: ignore[union-attr]
        v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape)  # type: ignore[union-attr]

        # We don't need to calibrate the output of softmax
        return self.bmm2_output_quantizer(
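Condensed, the branch above enforces one invariant: either all three BMM quantizers are static (export explicit amax scales) or all three are dynamic (pass None and let the kernel compute scales). A sketch of that rule, with quantizers standing in for the q/k/v BMM quantizers:

def collect_sdpa_scales(quantizers, tensors):
    """All-static: return per-tensor amax scales. All-dynamic: return Nones. Mixed: reject."""
    dynamic_flags = [q._dynamic for q in quantizers]
    if not any(dynamic_flags):
        return [q._get_amax(t) for q, t in zip(quantizers, tensors)]
    assert all(dynamic_flags), "QKV QDQS must be in the same type"
    return [None] * len(quantizers)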
@@ -135,6 +152,9 @@ def _quantized_sdpa(self, *args, **kwargs):
                if hasattr(self.q_bmm_quantizer, "trt_high_precision_dtype")
                else "Half",
                self._disable_fp8_mha if hasattr(self, "_disable_fp8_mha") else True,
                q_block_sizes,
                k_block_sizes,
                v_block_sizes,
            )
        )
@@ -188,6 +208,9 @@ def forward(
        v_quantized_scale=None,
        high_precision_flag=None,
        disable_fp8_mha=True,
        q_block_shape: list | None = None,
        k_block_shape: list | None = None,
        v_block_shape: list | None = None,
    ):
        """Forward method."""
        ctx.save_for_backward(query, key, value, attn_mask)
@@ -206,7 +229,9 @@ def forward(
        )

    @staticmethod
-   @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b")
+   @symbolic_helper.parse_args(
+       "v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is"
+   )
    def symbolic(
Comment on lines 231 to 235

Repo-wide check: ensure no other symbolics still use 't' for optional scales. Run to verify and avoid future breakages.

Script executed:

#!/bin/bash
# Find symbolics that parse optional scales with 't' (should be 'v')
rg -nP '@symbolic_helper\.parse_args\([^)]*\)' -C2 | rg -n 'parse_args\((.*)\)' -o | sed -E 's/.*parse_args\((.*)\).*/\1/' | awk 'BEGIN{RS=","} /"t"/{print FILENAME ":" NR ": " $0}' FILENAME=/dev/stdin

# Narrow search for q/k/v scale params in this file
rg -n 'parse_args\(' modelopt/torch/quantization/plugins/diffusers.py -n -C1

Length of output: 200

Script executed:

#!/bin/bash
# List all @symbolic_helper.parse_args calls containing "t" across Python files
rg -nP '@symbolic_helper\.parse_args\([^)]*\"t\"[^)]*\)' -g '*.py' -C1

Length of output: 779

Script executed:

#!/bin/bash
# Show context around parse_args usages in tensor_quant.py
sed -n '320,335p' modelopt/torch/quantization/tensor_quant.py
sed -n '405,420p' modelopt/torch/quantization/tensor_quant.py
sed -n '500,515p' modelopt/torch/quantization/tensor_quant.py

Length of output: 1170

Replace 't' with 'v' for optional scale parameters in all symbolic definitions.
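For background on the format string: each character in parse_args tells the exporter how to unwrap one symbolic argument — "v" keeps it as a graph Value, "t" expects a constant tensor value (brittle when the scale can be None), and "is" parses a list of ints suitable for an ONNX attribute. A hedged sketch of the pattern with a made-up custom op, not the op emitted by this PR:

import torch
from torch.onnx import symbolic_helper

class ExampleQuantOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale=None, block_shape=None):
        return x  # identity in eager mode; only the exported graph matters here

    @staticmethod
    # "v" for the optional scale keeps it usable even when None is passed
    # (the review recommends "v" over "t" for optional scales); "is" bakes the
    # block shape into the node as an ints attribute.
    @symbolic_helper.parse_args("v", "v", "is")
    def symbolic(g, x, scale=None, block_shape=None):
        if scale is None:
            return g.op("custom::ExampleQuant", x, block_shape_i=block_shape or [])
        return g.op("custom::ExampleQuant", x, scale, block_shape_i=block_shape or [])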
        g: "GraphContext",
        query: "torch._C.Value",

@@ -215,12 +240,15 @@ def symbolic(
        attn_mask: "torch._C.Value | None" = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
-       scale: "torch._C.Value | None" = None,
-       q_quantized_scale: float = 1.0,
-       k_quantized_scale: float = 1.0,
-       v_quantized_scale: float = 1.0,
+       scale: torch._C.Value | None = None,
+       q_quantized_scale: float | None = 1.0,
+       k_quantized_scale: float | None = 1.0,
+       v_quantized_scale: float | None = 1.0,
        high_precision_flag: str = "Half",
        disable_fp8_mha: bool = True,
        q_block_shape: list | None = None,
        k_block_shape: list | None = None,
        v_block_shape: list | None = None,
    ):
        """Symbolic method."""
        return export_fp8_mha(
@@ -237,4 +265,7 @@ def symbolic(
            v_quantized_scale,
            high_precision_flag,
            disable_fp8_mha,
            q_block_shape,
            k_block_shape,
            v_block_shape,
        )