[None][fix] impl fused triton kernel for e8m0 resmooth to reduce memory footprint #10327
@@ -51,45 +51,80 @@ def per_token_cast_to_fp8_e8m0(

```python
        g, m, n), sf


def per_block_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    if x.dim() == 2:
        m, n = x.shape
        x_padded = torch.zeros((align(m, 128), align(n, 128)),
                               dtype=x.dtype,
                               device=x.device)
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        sf = ceil_to_ue8m0(x_amax / 448.0)
        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
            x_view.size(0), x_view.size(2))
    else:
        g, m, n = x.shape
        x_padded = torch.zeros((g, align(m, 128), align(n, 128)),
                               dtype=x.dtype,
                               device=x.device)
        x_padded[:, :m, :n] = x
        x_view = x_padded.view(g, -1, 128, x_padded.size(-1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(2, 4), keepdim=True).clamp(1e-4)
        sf = ceil_to_ue8m0(x_amax / 448.0)
        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:, :m, :n].contiguous(), sf.view(
            x_view.size(0), x_view.size(1), x_view.size(3))
```
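`ceil_to_ue8m0` is a helper defined earlier in `fp8_utils.py` and is not shown in this hunk. As a rough sketch of what it computes, assuming it rounds each positive scale up to the nearest power of two (the only values an exponent-only UE8M0 scale can represent, and consistent with the `exp2(ceil(log2(...)))` expression in the Triton kernel below):

```python
import torch


def ceil_to_ue8m0_sketch(x: torch.Tensor) -> torch.Tensor:
    # UE8M0 stores an 8-bit exponent and no mantissa, so representable
    # scales are exactly the powers of two. Rounding up (ceil of log2)
    # guarantees the rescaled weights never exceed the fp8 e4m3 range.
    return torch.exp2(torch.ceil(torch.log2(x)))
```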
The previous eager implementation of `resmooth_to_fp8_e8m0`, removed by this PR:

```python
def resmooth_to_fp8_e8m0(weight: torch.Tensor,
                         sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    weight = weight.cuda()
    sf = sf.cuda()
    if weight.dim() == 2:
        x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
            128, dim=1)[:weight.shape[0], :weight.shape[1]]
    else:
        x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
            128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
    return per_block_cast_to_fp8_e8m0(x)
```
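The memory footprint the PR title refers to is visible here: the eager path materializes the scale tensor expanded to the full weight shape via `repeat_interleave`, an fp32 dequantized copy of the weight, and another padded fp32 buffer inside `per_block_cast_to_fp8_e8m0`. A back-of-envelope sketch with a hypothetical 7168x7168 weight:

```python
# Hypothetical shapes, for illustration only.
M = K = 7168
fp8_weight = M * K         # 1 byte per float8_e4m3fn element
expanded_sf = M * K * 4    # sf.repeat_interleave(128, ...), fp32
dequant_copy = M * K * 4   # weight.float() * expanded scales, fp32
padded_copy = M * K * 4    # x_padded inside per_block_cast_to_fp8_e8m0
print(f"fp8 weight: {fp8_weight / 2**20:.0f} MiB, "
      f"transient fp32 buffers: "
      f"{(expanded_sf + dequant_copy + padded_copy) / 2**20:.0f} MiB")
# fp8 weight: 49 MiB, transient fp32 buffers: 588 MiB.
# The fused Triton kernel below instead rewrites the fp8 weight
# and its scales in place, avoiding all three fp32 buffers.
```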
The new fused Triton kernel added by this PR:

```python
@triton.jit
def _resmooth_kernel(
    w_ptr,
    s_ptr,
    M,
    K,
    stride_wm,
    stride_wk,
    stride_sm,
    stride_sk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(0)
    num_pid_k = tl.cdiv(K, BLOCK_K)
    pid_m = pid // num_pid_k
    pid_k = pid % num_pid_k

    s_offset = pid_m * stride_sm + pid_k * stride_sk
    old_scale = tl.load(s_ptr + s_offset)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    w_mask = (rm[:, None] < M) & (rk[None, :] < K)
    w_offsets = rm[:, None] * stride_wm + rk[None, :] * stride_wk
    w_fp8 = tl.load(w_ptr + w_offsets, mask=w_mask, other=0.0)
    w_fp32 = w_fp8.to(tl.float32)

    w_val = w_fp32 * old_scale
    block_amax = tl.maximum(tl.max(tl.abs(w_val)), 1e-4)

    # E8M0 sf = 2 ^ ceil(log2(sf))
    new_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(block_amax / 448.0)))
    w_requant = w_val * (1.0 / new_scale)

    tl.store(w_ptr + w_offsets, w_requant, mask=w_mask)
    tl.store(s_ptr + s_offset, new_scale)
```
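A note on the constant: `448.0` is the largest finite value of `torch.float8_e4m3fn`, so dividing the block amax by it yields a scale that maps the block's largest magnitude onto the top of the fp8 range (before the scale is rounded up to a power of two). A quick check:

```python
import torch

# Max finite value of the e4m3 fp8 format used for the weights.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
```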
The new wrapper, which flattens any leading batch dimensions, launches the kernel, and requantizes in place:

```python
def resmooth_to_fp8_e8m0(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    block_size: tuple[int, int] = (128, 128),
):
    assert weight.dtype == torch.float8_e4m3fn
    assert weight_scale.dtype == torch.float32

    orig_shape = weight.shape
    M, K = orig_shape[-2:]
    w_view = weight.view(-1, M, K)
    s_view = weight_scale.view(-1, weight_scale.shape[-2],
                               weight_scale.shape[-1])

    num_batches = w_view.shape[0]
    BLOCK_M, BLOCK_K = block_size

    grid = (num_batches * triton.cdiv(M, BLOCK_M) * triton.cdiv(K, BLOCK_K), )

    _resmooth_kernel[grid](
        w_view,
        s_view,
        M,
        K,
        w_view.stride(1),
        w_view.stride(2),
        s_view.stride(1),
        s_view.stride(2),
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
    )
    # This is an in-place operation; the tensors are returned for convenience.
    return weight, weight_scale
```
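A minimal usage sketch of the new entry point (shapes and values are hypothetical; `block_size` is left at its `(128, 128)` default):

```python
import torch

# A hypothetical 2D fp8 weight with one fp32 scale per 128x128 block.
weight = torch.randn(512, 1024, device="cuda").to(torch.float8_e4m3fn)
weight_scale = torch.rand(512 // 128, 1024 // 128,
                          dtype=torch.float32, device="cuda") + 0.5

# Requantizes in place; afterwards every scale is a power of two,
# as required for E8M0 scale factors.
weight, weight_scale = resmooth_to_fp8_e8m0(weight, weight_scale)
assert torch.all(weight_scale == torch.exp2(torch.log2(weight_scale).round()))
```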
Comment on lines +94 to +127 (Contributor)

🏁 Script executed:

```bash
#!/bin/bash
# Find all calls to resmooth_to_fp8_e8m0
echo "=== Searching for resmooth_to_fp8_e8m0 calls ==="
rg -n 'resmooth_to_fp8_e8m0' --type=py -A 3 -B 1
echo ""
echo "=== Checking file header for copyright ==="
head -20 tensorrt_llm/quantization/utils/fp8_utils.py
```

Repository: NVIDIA/TensorRT-LLM

Add NVIDIA copyright header to the file. The file `tensorrt_llm/quantization/utils/fp8_utils.py` is missing the NVIDIA copyright header.

Regarding the function signature change: all call sites in the codebase correctly use the new signature and are compatible with the default `block_size` of `(128, 128)`.
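For reference, a header in the SPDX style commonly used in NVIDIA repositories might look like the following; the exact years, license identifier, and wording are assumptions and should be copied from an existing file in the repository:

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
```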
Trailing context from the hunk:

```python
def get_m_alignment_for_contiguous_layout():
```
Comment (Contributor)

Critical: Batch dimension handling is broken in the kernel.

The kernel receives batched 3D tensors (`w_view` with shape `(num_batches, M, K)` and the corresponding `s_view`) but only receives `stride(1)` and `stride(2)`. The grid on lines 111-112 multiplies by `num_batches`, yet the kernel never extracts the batch ID from `pid` or uses batch strides. This causes incorrect indexing for multi-batch inputs.

🔎 Proposed fix — add batch handling to the kernel:

```diff
 @triton.jit
 def _resmooth_kernel(
     w_ptr,
     s_ptr,
     M,
     K,
+    stride_wb,
     stride_wm,
     stride_wk,
+    stride_sb,
     stride_sm,
     stride_sk,
     BLOCK_M: tl.constexpr,
     BLOCK_K: tl.constexpr,
 ):
     pid = tl.program_id(0)
+    num_pid_m = tl.cdiv(M, BLOCK_M)
     num_pid_k = tl.cdiv(K, BLOCK_K)
-    pid_m = pid // num_pid_k
-    pid_k = pid % num_pid_k
+    num_pid_per_batch = num_pid_m * num_pid_k
+
+    pid_batch = pid // num_pid_per_batch
+    pid_remainder = pid % num_pid_per_batch
+    pid_m = pid_remainder // num_pid_k
+    pid_k = pid_remainder % num_pid_k

-    s_offset = pid_m * stride_sm + pid_k * stride_sk
+    s_offset = pid_batch * stride_sb + pid_m * stride_sm + pid_k * stride_sk
     old_scale = tl.load(s_ptr + s_offset)

     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
     rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

     w_mask = (rm[:, None] < M) & (rk[None, :] < K)
-    w_offsets = rm[:, None] * stride_wm + rk[None, :] * stride_wk
+    w_offsets = pid_batch * stride_wb + rm[:, None] * stride_wm + rk[None, :] * stride_wk
     w_fp8 = tl.load(w_ptr + w_offsets, mask=w_mask, other=0.0)
```

Update the kernel call on lines 114-125:

```diff
 _resmooth_kernel[grid](
     w_view,
     s_view,
     M,
     K,
+    w_view.stride(0),
     w_view.stride(1),
     w_view.stride(2),
+    s_view.stride(0),
     s_view.stride(1),
     s_view.stride(2),
     BLOCK_M=BLOCK_M,
     BLOCK_K=BLOCK_K,
 )
```
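One way to validate the fix would be a hypothetical test that compares a single batched 3D call against independent 2D calls on each slice (names and shapes are illustrative, not from the PR):

```python
import torch


def check_batched_resmooth():
    g, M, K = 4, 256, 512
    w3d = torch.randn(g, M, K, device="cuda").to(torch.float8_e4m3fn)
    s3d = torch.rand(g, M // 128, K // 128,
                     dtype=torch.float32, device="cuda") + 0.5

    # Reference: requantize each batch slice independently via the 2D path.
    w_ref, s_ref = w3d.clone(), s3d.clone()
    for i in range(g):
        resmooth_to_fp8_e8m0(w_ref[i], s_ref[i])

    # One batched 3D call must reproduce the per-slice results.
    resmooth_to_fp8_e8m0(w3d, s3d)
    assert torch.equal(s3d, s_ref)
    assert torch.equal(w3d.to(torch.float32), w_ref.to(torch.float32))
```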