Improve realquant gemm impl #368
@@ -32,35 +32,38 @@
 FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


-def fp8_per_tensor_gemm(quant_module, input, bias=None):
-    """GEMM function for fp8 per tensor quantization."""
-
-    @torch.compile(dynamic=True)
-    def _to_fp8(x, scale):
-        return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
-
-    @torch.compile(dynamic=True)
-    def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
-        input_shape = input.shape
-        input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1])
-        weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
-        output = torch._scaled_mm(
-            input_fp8,
-            weight_fp8_t,
-            scale_a=scale_a,
-            scale_b=scale_b,
-            bias=bias,
-            out_dtype=input.dtype,
-            use_fast_accum=True,
-        )
-        return output.reshape(*input_shape[:-1], output.shape[-1])
Review comment: qq, could you share why moving these functions to the module level reduces the CPU overheads?

Reply: Sure. My suspicion is that if they are not at the module level, torch.compile gets invoked on every call, which shows up in the nsys trace.
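For context, here is a minimal sketch (not code from this PR) of the pattern the reply describes: a @torch.compile-decorated closure defined inside the GEMM function is re-created, and re-wrapped by torch.compile, on every call, whereas a module-level compiled function is wrapped once and later calls only pay the guard check.

import torch

def gemm_nested(x):
    # A new function object is built and handed to torch.compile on *every* call,
    # so the compile wrapper itself is reconstructed each time (extra CPU work).
    @torch.compile(dynamic=True)
    def _inner(t):
        return t * 2.0

    return _inner(x)


# Module-level: torch.compile wraps the function once at import time; later calls
# reuse the already-wrapped function and only evaluate guards.
@torch.compile(dynamic=True)
def _inner_module_level(t):
    return t * 2.0


def gemm_module_level(x):
    return _inner_module_level(x)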
+@torch.compile(dynamic=True)
+def _to_fp8(x, scale):
+    return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
+
+
+@torch.compile(dynamic=True)
+def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
+    input_shape = input.shape
+    input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1])
+    weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
+    output = torch._scaled_mm(
+        input_fp8,
+        weight_fp8_t,
+        scale_a=scale_a,
+        scale_b=scale_b,
+        bias=bias,
+        out_dtype=input.dtype,
+        use_fast_accum=True,
+    )
+    return output.reshape(*input_shape[:-1], output.shape[-1])
+
+
+def fp8_per_tensor_gemm(quant_module, input, bias=None):
+    """GEMM function for fp8 per tensor quantization."""
     cached_scale_a = (
         hasattr(quant_module, "_scale_a") and quant_module.input_quantizer.amax is not None
     )

     if not cached_scale_a:
-        input_amax = quant_module.input_quantizer.amax or reduce_amax(input)
+        input_amax = quant_module.input_quantizer.amax
+        if input_amax is None:
+            input_amax = reduce_amax(input)
         assert input_amax != 0
         quant_module._scale_a = (input_amax.float() / 448.0).to(device=input.device)
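A side note on the `or` → explicit `is None` rewrite (the rationale here is my assumption, not stated in the PR): when amax is a tensor, `amax or fallback` has to evaluate the tensor's truth value, which synchronizes a one-element GPU tensor with the host and raises for multi-element tensors, while `is None` is a pure Python identity check that never touches the device.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
amax = torch.tensor(3.5, device=device)

# Old-style pattern: `or` calls Tensor.__bool__(), forcing a device->host sync for
# a one-element tensor (and a RuntimeError if amax ever had more than one element).
value = amax or torch.tensor(1.0, device=device)

# New-style pattern: identity check only; the tensor's value is never read on the host.
value = amax if amax is not None else torch.tensor(1.0, device=device)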
@@ -69,7 +72,9 @@ def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
     )

     if not cached_scale_b:
-        weight_amax = quant_module.weight_quantizer.amax or reduce_amax(quant_module.weight)
+        weight_amax = quant_module.weight_quantizer.amax
+        if weight_amax is None:
+            weight_amax = reduce_amax(quant_module.weight)
         assert weight_amax != 0
         quant_module._scale_b = (weight_amax.float() / 448.0).to(device=quant_module.weight.device)
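To make the data flow concrete, here is a hypothetical standalone use of the module-level helpers above (a sketch under assumptions, not code from the PR: it presumes an FP8-capable GPU and that _fp8_gemm_impl is importable; the tensor shapes are illustrative). Both scales follow the diff's amax / 448.0 convention, 448 being the maximum magnitude representable in float8_e4m3fn.

import torch

FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

x = torch.randn(4, 128, 512, device="cuda", dtype=torch.bfloat16)  # activations
w = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)    # linear weight

# Per-tensor scales: amax / 448.0, as in the diff.
scale_a = x.abs().amax().float() / 448.0
scale_b = w.abs().amax().float() / 448.0
w_fp8 = (w / scale_b).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)

y = _fp8_gemm_impl(x, w_fp8, scale_a, scale_b)  # shape: (4, 128, 1024)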
@@ -146,9 +151,9 @@ def forward(
         ctx.save_for_backward(
             input_tensor if weight.requires_grad else None,
             weight if input_tensor.requires_grad else None,
-            torch.empty(0, dtype=torch.uint8) if bias is not None and bias.requires_grad else None,
             getattr(quant_module.weight_quantizer, "_scale", None),
         )
+        ctx.compute_bias_grad = bias is not None and bias.requires_grad
         ctx.block_sizes = getattr(quant_module.weight_quantizer, "_block_sizes", None)

         ctx.allreduce_dgrad = allreduce_dgrad
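The change above stops smuggling the "compute bias grad" flag through save_for_backward as an empty uint8 sentinel tensor and instead stores the boolean directly on ctx. save_for_backward is intended for tensors the backward pass needs autograd bookkeeping for; plain Python values can simply be attached as context attributes. A minimal generic sketch of the pattern (my own example, not this PR's autograd function):

import torch

class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, record_stats):
        # Tensors needed in backward go through save_for_backward; plain Python
        # flags are stored as attributes on ctx.
        ctx.save_for_backward(x)
        ctx.record_stats = record_stats
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        if ctx.record_stats:
            _ = grad_output.norm()  # e.g. log a statistic; illustrative only
        # One gradient per forward input; the non-tensor flag gets None.
        return grad_output * 2.0, None

y = ScaleByTwo.apply(torch.randn(3, requires_grad=True), True)
y.sum().backward()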
@@ -166,7 +171,7 @@ def backward(ctx, grad_outputs):
         dequantize it to compute the input gradient. If the weight is not compressed, we will save
         the unquantized weight and use it directly to compute the input gradient.
         """
-        input_tensor, weight, compute_bias_grad, scale = ctx.saved_tensors
+        input_tensor, weight, scale = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         if weight is not None:
             if isinstance(weight, QTensorWrapper):
@@ -175,8 +180,10 @@ def backward(ctx, grad_outputs):
                 weight = weight.dequantize(scale=scale, block_sizes=ctx.block_sizes)
             grad_input = grad_outputs @ weight
         if input_tensor is not None:
-            grad_weight = grad_outputs.transpose(-2, 1) @ input_tensor
-        if compute_bias_grad is not None:
+            grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
+                -1, input_tensor.shape[-1]
+            )
+        if ctx.compute_bias_grad:
             # Sum all dimensions except the last one
             grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))
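The weight-gradient rewrite is the substantive fix in this hunk: grad_outputs.transpose(-2, 1) @ input_tensor only produces the expected (out_features, in_features) gradient when the activations are 2-D, whereas flattening both tensors to 2-D first also accumulates correctly over batch and sequence dimensions. A quick standalone numerical check of the new formula against autograd (my own sketch, not test code from the PR):

import torch

B, S, I, O = 2, 5, 8, 4
x = torch.randn(B, S, I)
w = torch.randn(O, I, requires_grad=True)
grad_out = torch.randn(B, S, O)

# Reference weight gradient from autograd for y = x @ w.T.
y = x @ w.T
y.backward(grad_out)

# The flattened formula used by the new code.
grad_w = grad_out.reshape(-1, O).T @ x.reshape(-1, I)
print(torch.allclose(grad_w, w.grad, atol=1e-5))  # True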