FP8 block quantize ONNX export support #324
base: main

@@ -648,6 +648,29 @@ def _real_quantize(self, inputs):
        self._dequantize = True
        return outputs

+    def _get_block_sizes_list(self, shape):
+        """Convert block_sizes dict to list format based on tensor shape.
+
+        Args:
+            shape: The tensor shape to use for conversion (can be tuple or torch.Size)
+
+        Returns:
+            List of block sizes for each dimension, or None if block_sizes is None
+
+        Example:
+            block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1]
+        """
+        if self.block_sizes is None:
+            return None
+
+        block_sizes_list = []
+        for dim in range(len(shape)):
+            # Check both positive and negative dimension indices
+            dim_negative = dim - len(shape)
+            block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None)
+            block_sizes_list.append(block_size if block_size is not None else 1)
+        return block_sizes_list
+
    def _fake_quantize(self, inputs):
        """Fake quantization."""
        amax = None
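
For reference, a minimal standalone sketch of the dict-to-list conversion this helper performs (hypothetical function name, not part of the diff):

def block_sizes_dict_to_list(block_sizes, shape):
    # Dimensions may be keyed positively or negatively; any unkeyed dimension gets block size 1.
    if block_sizes is None:
        return None
    ndim = len(shape)
    return [
        block_sizes.get(d) or block_sizes.get(d - ndim) or 1
        for d in range(ndim)
    ]

# Matches the docstring example: {-2: 32} over [2, 24, 4608, 128] -> [1, 1, 32, 1]
assert block_sizes_dict_to_list({-2: 32}, [2, 24, 4608, 128]) == [1, 1, 32, 1]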

@@ -656,7 +679,7 @@ def _fake_quantize(self, inputs):
            self._validate_amax(amax)

        if self.block_sizes is not None and self.block_sizes.get("type", "static") == "dynamic":
-            # Block quantization, including dynamic and static block quantization
+            # Double scale Block quantization, including dynamic and static block quantization
            block_size = self.block_sizes.get(-1, None) or self.block_sizes.get(
                inputs.dim() - 1, None
            )

@@ -677,9 +700,14 @@ def _fake_quantize(self, inputs):
            # Float-point quantization, e.g., FP8
            E, M = self._num_bits  # noqa: N806

+            # Convert block_sizes dict to list format
+            # Use original input shape if available (before reshaping), otherwise use current shape
+            shape_for_block_sizes = getattr(self, "_original_input_shape", inputs.shape)
+            block_sizes_list = self._get_block_sizes_list(shape_for_block_sizes)
            outputs = scaled_e4m3(
                inputs,
                amax,
+                block_sizes_list,
                self._get_bias(inputs),
                E,
                M,
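
For intuition about what the per-dimension block-size list means for scaling, here is a rough, standalone sketch (assumed semantics, not the actual simulation kernel) of computing block-wise amax along a single blocked dimension:

import torch

# Illustrative only: per-block amax when exactly one dimension is blocked,
# e.g. block_sizes_list = [1, 1, 32, 1] for a [2, 24, 4608, 128] tensor.
def blockwise_amax(x: torch.Tensor, block_sizes_list: list[int]) -> torch.Tensor:
    blocked = [d for d, b in enumerate(block_sizes_list) if b > 1]
    assert len(blocked) == 1, "sketch only handles one blocked dimension"
    dim, block = blocked[0], block_sizes_list[blocked[0]]
    assert x.shape[dim] % block == 0, "dimension must divide evenly into blocks"
    # Split the blocked dimension into (num_blocks, block) and reduce over the block.
    new_shape = list(x.shape[:dim]) + [x.shape[dim] // block, block] + list(x.shape[dim + 1:])
    return x.reshape(new_shape).abs().amax(dim=dim + 1)

x = torch.randn(2, 24, 4608, 128)
print(blockwise_amax(x, [1, 1, 32, 1]).shape)  # torch.Size([2, 24, 144, 128])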

@@ -928,9 +956,9 @@ def forward(self, inputs):
            and self.block_sizes.get("type", None) != "dynamic"
            and self._fake_quant
        ):
-            # Tensor reshaping is required for static block quantization
-            # Tensor shapes are handled separately by the quantization kernels for dynamic block quantization
+            # Reshape is required if the logic isn't handled in the simulation kernel
            self._setup_for_blockquant(inputs)
+            setattr(self, "_original_input_shape", inputs.shape)
            inputs = self._process_for_blockquant(inputs)

        outputs = inputs

@@ -971,6 +999,8 @@ def forward(self, inputs):
        ):
            outputs = self._reset_to_original_shape(outputs)

+        if hasattr(self, "_original_input_shape"):
+            delattr(self, "_original_input_shape")
        return outputs

    def _short_amax(self, fmt=".4f"):
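
Net effect of the two forward() hunks: the pre-reshape shape is stashed on the module before _process_for_blockquant flattens the input, _fake_quantize reads it back (falling back to the current shape), and the attribute is removed before forward() returns. A simplified sketch of that lifecycle, with stand-in helper logic:

import torch

class BlockQuantSketch(torch.nn.Module):
    """Toy module showing only the attribute lifecycle, not real quantization."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        setattr(self, "_original_input_shape", x.shape)
        x2d = x.reshape(-1, x.shape[-1])  # stand-in for _process_for_blockquant
        # Stand-in for _fake_quantize: block sizes must be derived from the
        # original shape, not the flattened 2-D view.
        shape_for_blocks = getattr(self, "_original_input_shape", x2d.shape)
        out = x2d.reshape(shape_for_blocks)  # stand-in for _reset_to_original_shape
        if hasattr(self, "_original_input_shape"):
            delattr(self, "_original_input_shape")
        return out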

@@ -114,9 +114,26 @@ def _quantized_sdpa(self, *args, **kwargs):
        key = self.k_bmm_quantizer(key)
        value = self.v_bmm_quantizer(value)

-        q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
-        k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
-        v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
+        if (
+            not self.q_bmm_quantizer._dynamic
+            and not self.k_bmm_quantizer._dynamic
+            and not self.v_bmm_quantizer._dynamic
+        ):
+            q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
+            k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
+            v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
+        else:
+            assert (
+                self.q_bmm_quantizer._dynamic
+                and self.k_bmm_quantizer._dynamic
+                and self.v_bmm_quantizer._dynamic
+            ), "QKV quantizers must be of the same type"
+            q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None
+
+        # Get block sizes lists for each quantizer if needed
+        q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)  # type: ignore[union-attr]
+        k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape)  # type: ignore[union-attr]
+        v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape)  # type: ignore[union-attr]

        # We don't need to calibrate the output of softmax
        return self.bmm2_output_quantizer(
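
Reading of the change above: either all three BMM quantizers are static (a calibrated amax is available and passed down to the export path) or all three are dynamic (scales are computed at runtime, so None is passed). A hedged, standalone restatement of that check with hypothetical names:

def resolve_qkv_scales(quantizers, tensors):
    # Hypothetical helper; mirrors the all-static-or-all-dynamic logic in the diff.
    dynamic_flags = [q._dynamic for q in quantizers]
    if not any(dynamic_flags):
        # Static path: pass each quantizer's calibrated amax along for export.
        return [q._get_amax(t) for q, t in zip(quantizers, tensors)]
    assert all(dynamic_flags), "QKV quantizers must be of the same type"
    # Dynamic path: scales are computed inside the kernel, so pass none.
    return [None, None, None]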

@@ -132,6 +149,9 @@ def _quantized_sdpa(self, *args, **kwargs):
                if hasattr(self.q_bmm_quantizer, "trt_high_precision_dtype")
                else "Half",
                self._disable_fp8_mha if hasattr(self, "_disable_fp8_mha") else True,
+                q_block_sizes,
+                k_block_sizes,
+                v_block_sizes,
            )
        )

@@ -185,6 +205,9 @@ def forward(
        v_quantized_scale=None,
        high_precision_flag=None,
        disable_fp8_mha=True,
+        q_block_shape: list | None = None,
+        k_block_shape: list | None = None,
+        v_block_shape: list | None = None,
    ):
        """Forward method."""
        ctx.save_for_backward(query, key, value, attn_mask)

@@ -203,7 +226,9 @@ def forward(
        )

    @staticmethod
-    @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b")
+    @symbolic_helper.parse_args(
+        "v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is"
+    )
    def symbolic(
        g: jit_utils.GraphContext,
        query: torch._C.Value,

Review comment: why is this pattern changed?

Reply: You can check parse_args: each value represents the type of the corresponding input, which helps Torch trace the graph more effectively.
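
For context on that reply: each string passed to symbolic_helper.parse_args declares how the matching argument (after g) is unpacked before the symbolic function is called. As I understand the convention, "v" is a graph value, "f" a float, "b" a bool, "t" a tensor constant, "s" a string, and "is" a list of ints, which is why the three new block-shape arguments are declared as "is". A toy sketch of the same pattern (not the actual export function above):

from torch.onnx import symbolic_helper

# One descriptor per argument after g: "v" -> graph value, "f" -> float, "is" -> int list.
@symbolic_helper.parse_args("v", "f", "is")
def toy_symbolic(g, x, scale, block_shape):
    # x arrives as a torch._C.Value; scale as a Python float; block_shape as [int, ...].
    return g.op("Identity", x)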

@@ -213,11 +238,14 @@ def symbolic(
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: torch._C.Value | None = None,
-        q_quantized_scale: float = 1.0,
-        k_quantized_scale: float = 1.0,
-        v_quantized_scale: float = 1.0,
+        q_quantized_scale: float | None = 1.0,
+        k_quantized_scale: float | None = 1.0,
+        v_quantized_scale: float | None = 1.0,
        high_precision_flag: str = "Half",
        disable_fp8_mha: bool = True,
+        q_block_shape: list | None = None,
+        k_block_shape: list | None = None,
+        v_block_shape: list | None = None,
    ):
        """Symbolic method."""
        return export_fp8_mha(

@@ -234,4 +262,7 @@ def symbolic(
            v_quantized_scale,
            high_precision_flag,
            disable_fp8_mha,
+            q_block_shape,
+            k_block_shape,
+            v_block_shape,
        )