FP8 Block quantize onnx export support #324
@@ -114,9 +114,26 @@ def _quantized_sdpa(self, *args, **kwargs):
         key = self.k_bmm_quantizer(key)
         value = self.v_bmm_quantizer(value)
 
-        q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
-        k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
-        v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
+        if (
+            not self.q_bmm_quantizer._dynamic
+            and not self.k_bmm_quantizer._dynamic
+            and not self.v_bmm_quantizer._dynamic
+        ):
+            q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
+            k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
+            v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
+        else:
+            assert (
+                self.q_bmm_quantizer._dynamic
+                and self.k_bmm_quantizer._dynamic
+                and self.v_bmm_quantizer._dynamic
+            ), "QKV QDQs must be of the same type"
+            q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None
+
+        # Get block sizes lists for each quantizer if needed
+        q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)  # type: ignore[union-attr]
+        k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape)  # type: ignore[union-attr]
+        v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape)  # type: ignore[union-attr]
 
         # We don't need to calibrate the output of softmax
         return self.bmm2_output_quantizer(
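For context, `_get_block_sizes_list` itself is not part of this diff. Below is a minimal sketch of what such a helper could look like, assuming the quantizer stores a `{dim: block_size}` mapping and that unblocked dimensions default to 1; the signature, mapping layout, and defaults are assumptions for illustration, not taken from the PR.

```python
# Hypothetical sketch, not the actual implementation: expand a
# {dim: block_size} mapping into a per-dimension list for a given tensor shape.
def _get_block_sizes_list(block_sizes: dict[int, int] | None,
                          shape: tuple[int, ...]) -> list[int] | None:
    """Return e.g. [1, 1, 1, 128] for a 4-D tensor with block_sizes={-1: 128}."""
    if not block_sizes:
        return None  # quantizer is not block-quantized
    sizes = [1] * len(shape)  # 1 means "no blocking" along that dimension
    for dim, block in block_sizes.items():
        sizes[dim % len(shape)] = block  # normalize negative dims like -1
    return sizes
```

Whatever the real helper does, the result is a plain per-dimension int list, which is exactly what the new `"is"` (int-list) entries in the updated `parse_args` signature later in the diff carry into the symbolic function.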
@@ -132,6 +149,9 @@ def _quantized_sdpa(self, *args, **kwargs):
                 if hasattr(self.q_bmm_quantizer, "trt_high_precision_dtype")
                 else "Half",
                 self._disable_fp8_mha if hasattr(self, "_disable_fp8_mha") else True,
+                q_block_sizes,
+                k_block_sizes,
+                v_block_sizes,
             )
         )
 
@@ -185,6 +205,9 @@ def forward(
         v_quantized_scale=None,
         high_precision_flag=None,
         disable_fp8_mha=True,
+        q_block_shape: list | None = None,
+        k_block_shape: list | None = None,
+        v_block_shape: list | None = None,
     ):
         """Forward method."""
         ctx.save_for_backward(query, key, value, attn_mask)
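For readers unfamiliar with the pattern in this file: the `forward`/`symbolic` staticmethod pair is the standard `torch.autograd.Function` hook for the TorchScript-based ONNX exporter, where `forward` defines the eager-mode behavior and `symbolic` tells the exporter which graph nodes to emit. A minimal, self-contained sketch of the pattern (the op and names are illustrative, not from this PR):

```python
import torch


class ScaledIdentity(torch.autograd.Function):
    """Toy Function showing the forward/symbolic pairing used by the exporter."""

    @staticmethod
    def forward(ctx, x, scale: float = 1.0):
        # Eager-mode behavior; this is what runs during tracing.
        return x * scale

    @staticmethod
    def symbolic(g, x, scale: float = 1.0):
        # Export-time behavior: emit ONNX nodes instead of executing math.
        scale_const = g.op("Constant", value_t=torch.tensor(scale))
        return g.op("Mul", x, scale_const)
```

This is why the new block-shape parameters must be added to both methods here: the two signatures have to stay in sync for the exporter to map arguments across.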
@@ -203,7 +226,9 @@ def forward(
         )
 
     @staticmethod
-    @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b")
+    @symbolic_helper.parse_args(
+        "v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is"
+    )
     def symbolic(
         g: jit_utils.GraphContext,
         query: torch._C.Value,

Review comment on the `parse_args` change: "Why is this pattern changed?"
Author reply: "You can check `parse_args`: each value represents the input type, which helps Torch trace the graph more effectively."
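To unpack that reply: `torch.onnx.symbolic_helper.parse_args` declares how each argument of the symbolic function is unpacked before the function body runs. `"v"` is a graph value, `"f"` a float, `"b"` a bool, `"t"` a tensor, `"s"` a string, and `"is"` a list of ints, which is why the three new block-shape parameters each add an `"is"` entry. A minimal sketch of the mechanism (the op name and domain are made up for the example):

```python
from torch.onnx import symbolic_helper


@symbolic_helper.parse_args("v", "is")  # "v": graph value, "is": int list
def toy_symbolic(g, x, block_shape):
    # block_shape arrives as a plain Python list[int] rather than a graph
    # node, so it can be attached to the emitted op as a constant attribute
    # (the "_i" suffix marks an int/ints attribute).
    return g.op("custom_domain::ToyBlockOp", x, block_shape_i=block_shape)
```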
@@ -213,11 +238,14 @@ def symbolic(
         dropout_p: float = 0.0,
         is_causal: bool = False,
         scale: torch._C.Value | None = None,
-        q_quantized_scale: float = 1.0,
-        k_quantized_scale: float = 1.0,
-        v_quantized_scale: float = 1.0,
+        q_quantized_scale: float | None = 1.0,
+        k_quantized_scale: float | None = 1.0,
+        v_quantized_scale: float | None = 1.0,
         high_precision_flag: str = "Half",
         disable_fp8_mha: bool = True,
+        q_block_shape: list | None = None,
+        k_block_shape: list | None = None,
+        v_block_shape: list | None = None,
     ):
         """Symbolic method."""
         return export_fp8_mha(
@@ -234,4 +262,7 @@ def symbolic(
             v_quantized_scale,
             high_precision_flag,
             disable_fp8_mha,
+            q_block_shape,
+            k_block_shape,
+            v_block_shape,
         )