113 changes: 74 additions & 39 deletions tensorrt_llm/quantization/utils/fp8_utils.py
@@ -51,45 +51,80 @@ def per_token_cast_to_fp8_e8m0(
g, m, n), sf


def per_block_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    if x.dim() == 2:
        m, n = x.shape
        x_padded = torch.zeros((align(m, 128), align(n, 128)),
                               dtype=x.dtype,
                               device=x.device)
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        sf = ceil_to_ue8m0(x_amax / 448.0)
        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
            x_view.size(0), x_view.size(2))
    else:
        g, m, n = x.shape
        x_padded = torch.zeros((g, align(m, 128), align(n, 128)),
                               dtype=x.dtype,
                               device=x.device)
        x_padded[:, :m, :n] = x
        x_view = x_padded.view(g, -1, 128, x_padded.size(-1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(2, 4), keepdim=True).clamp(1e-4)
        sf = ceil_to_ue8m0(x_amax / 448.0)
        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:, :m, :n].contiguous(), sf.view(
            x_view.size(0), x_view.size(1), x_view.size(3))


def resmooth_to_fp8_e8m0(weight: torch.Tensor,
                         sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    weight = weight.cuda()
    sf = sf.cuda()
    if weight.dim() == 2:
        x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
            128, dim=1)[:weight.shape[0], :weight.shape[1]]
    else:
        x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
            128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
    return per_block_cast_to_fp8_e8m0(x)
@triton.jit
def _resmooth_kernel(
    w_ptr,
    s_ptr,
    M,
    K,
    stride_wm,
    stride_wk,
    stride_sm,
    stride_sk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(0)
    num_pid_k = tl.cdiv(K, BLOCK_K)
    pid_m = pid // num_pid_k
    pid_k = pid % num_pid_k

    s_offset = pid_m * stride_sm + pid_k * stride_sk
    old_scale = tl.load(s_ptr + s_offset)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    w_mask = (rm[:, None] < M) & (rk[None, :] < K)
    w_offsets = rm[:, None] * stride_wm + rk[None, :] * stride_wk
    w_fp8 = tl.load(w_ptr + w_offsets, mask=w_mask, other=0.0)
    w_fp32 = w_fp8.to(tl.float32)

    w_val = w_fp32 * old_scale
    block_amax = tl.maximum(tl.max(tl.abs(w_val)), 1e-4)

    # E8M0 sf = 2 ^ ceil(log2(sf))
    new_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(block_amax / 448.0)))
    w_requant = w_val * (1.0 / new_scale)

    tl.store(w_ptr + w_offsets, w_requant, mask=w_mask)
    tl.store(s_ptr + s_offset, new_scale)

Comment on lines +54 to +92

⚠️ Potential issue | 🔴 Critical

Critical: Batch dimension handling is broken in the kernel.

The kernel is launched on batched 3D tensors (w_view of shape (num_batches, M, K) and the matching batched s_view of block scales) but is only passed stride(1) and stride(2). The grid on Lines 111-112 multiplies by num_batches, yet the kernel never extracts a batch id from pid or applies a batch stride, so indexing is incorrect for multi-batch inputs.
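As a side illustration (not part of the PR), a plain-Python sketch of the index arithmetic; the sizes are made up:

```python
import math

# Hypothetical sizes, for illustration only.
num_batches, M, K = 4, 256, 384
BLOCK_M, BLOCK_K = 128, 128

num_pid_m = math.ceil(M / BLOCK_M)          # 2 row tiles per matrix
num_pid_k = math.ceil(K / BLOCK_K)          # 3 column tiles per matrix
tiles_per_batch = num_pid_m * num_pid_k     # 6 programs per matrix

for pid in range(num_batches * tiles_per_batch):
    # Current kernel: no batch id, so pid_m keeps growing across batches.
    pid_m = pid // num_pid_k
    pid_k = pid % num_pid_k
    if pid >= tiles_per_batch:
        # rm = pid_m * BLOCK_M + arange(BLOCK_M) is then >= M, so the
        # (rm < M) mask drops the whole block for every batch after the first.
        assert pid_m * BLOCK_M >= M

    # Batch-aware decomposition (what the proposed fix computes).
    pid_batch, rem = divmod(pid, tiles_per_batch)
    fixed_pid_m, fixed_pid_k = divmod(rem, num_pid_k)
    assert 0 <= fixed_pid_m < num_pid_m and 0 <= fixed_pid_k < num_pid_k
```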

🔎 Proposed fix

Add batch handling to the kernel:

 @triton.jit
 def _resmooth_kernel(
     w_ptr,
     s_ptr,
     M,
     K,
+    stride_wb,
     stride_wm,
     stride_wk,
+    stride_sb,
     stride_sm,
     stride_sk,
     BLOCK_M: tl.constexpr,
     BLOCK_K: tl.constexpr,
 ):
     pid = tl.program_id(0)
+    num_pid_m = tl.cdiv(M, BLOCK_M)
     num_pid_k = tl.cdiv(K, BLOCK_K)
-    pid_m = pid // num_pid_k
-    pid_k = pid % num_pid_k
+    num_pid_per_batch = num_pid_m * num_pid_k
+    
+    pid_batch = pid // num_pid_per_batch
+    pid_remainder = pid % num_pid_per_batch
+    pid_m = pid_remainder // num_pid_k
+    pid_k = pid_remainder % num_pid_k
 
-    s_offset = pid_m * stride_sm + pid_k * stride_sk
+    s_offset = pid_batch * stride_sb + pid_m * stride_sm + pid_k * stride_sk
     old_scale = tl.load(s_ptr + s_offset)
 
     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
     rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
 
     w_mask = (rm[:, None] < M) & (rk[None, :] < K)
-    w_offsets = rm[:, None] * stride_wm + rk[None, :] * stride_wk
+    w_offsets = pid_batch * stride_wb + rm[:, None] * stride_wm + rk[None, :] * stride_wk
     w_fp8 = tl.load(w_ptr + w_offsets, mask=w_mask, other=0.0)

Update the kernel call on Line 114-125:

 _resmooth_kernel[grid](
     w_view,
     s_view,
     M,
     K,
+    w_view.stride(0),
     w_view.stride(1),
     w_view.stride(2),
+    s_view.stride(0),
     s_view.stride(1),
     s_view.stride(2),
     BLOCK_M=BLOCK_M,
     BLOCK_K=BLOCK_K,
 )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@triton.jit
def _resmooth_kernel(
    w_ptr,
    s_ptr,
    M,
    K,
    stride_wb,
    stride_wm,
    stride_wk,
    stride_sb,
    stride_sm,
    stride_sk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_k = tl.cdiv(K, BLOCK_K)
    num_pid_per_batch = num_pid_m * num_pid_k

    pid_batch = pid // num_pid_per_batch
    pid_remainder = pid % num_pid_per_batch
    pid_m = pid_remainder // num_pid_k
    pid_k = pid_remainder % num_pid_k

    s_offset = pid_batch * stride_sb + pid_m * stride_sm + pid_k * stride_sk
    old_scale = tl.load(s_ptr + s_offset)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    w_mask = (rm[:, None] < M) & (rk[None, :] < K)
    w_offsets = pid_batch * stride_wb + rm[:, None] * stride_wm + rk[None, :] * stride_wk
    w_fp8 = tl.load(w_ptr + w_offsets, mask=w_mask, other=0.0)
    w_fp32 = w_fp8.to(tl.float32)

    w_val = w_fp32 * old_scale
    block_amax = tl.maximum(tl.max(tl.abs(w_val)), 1e-4)

    # E8M0 sf = 2 ^ ceil(log2(sf))
    new_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(block_amax / 448.0)))
    w_requant = w_val * (1.0 / new_scale)

    tl.store(w_ptr + w_offsets, w_requant, mask=w_mask)
    tl.store(s_ptr + s_offset, new_scale)
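
One way to exercise the batched path before committing might be a check along these lines (a sketch only: it assumes the batch-aware kernel above has been applied and uses hypothetical shapes):

```python
import torch

def check_batched_resmooth():
    # Hypothetical shapes: 3 experts, 256x384 weights, 128x128 scale blocks.
    g, m, k = 3, 256, 384
    weight = torch.randn(g, m, k, device="cuda").to(torch.float8_e4m3fn)
    scale = torch.rand(g, m // 128, k // 128,
                       device="cuda", dtype=torch.float32) + 0.5

    # Batched call on clones (resmooth_to_fp8_e8m0 works in place).
    w_b, s_b = resmooth_to_fp8_e8m0(weight.clone(), scale.clone())

    # Reference: run each matrix separately and stack the results.
    refs = [resmooth_to_fp8_e8m0(weight[i].clone(), scale[i].clone())
            for i in range(g)]
    w_ref = torch.stack([r[0] for r in refs])
    s_ref = torch.stack([r[1] for r in refs])

    torch.testing.assert_close(w_b.float(), w_ref.float())
    torch.testing.assert_close(s_b, s_ref)
```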
🤖 Prompt for AI Agents
In tensorrt_llm/quantization/utils/fp8_utils.py around lines 54 to 92, the
kernel ignores the batch dimension so multi-batch inputs index incorrectly;
update the kernel to extract a batch id from tl.program_id(0) and compute
pid_m/pid_k relative to the per‑batch grid, add batch strides (e.g. stride_wb
and stride_sb) as additional kernel arguments, and incorporate them into
s_offset and w_offsets (compute a batch_base for w_ptr and s_ptr using
batch*stride_wb / batch*stride_sb before adding pid_m/pid_k offsets). Also
update the kernel launch to pass num_batches and the new batch strides so each
program id indexes the correct (batch, m, k) block.


def resmooth_to_fp8_e8m0(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    block_size: tuple[int, int] = (128, 128),
):
    assert weight.dtype == torch.float8_e4m3fn
    assert weight_scale.dtype == torch.float32

    orig_shape = weight.shape
    M, K = orig_shape[-2:]
    w_view = weight.view(-1, M, K)
    s_view = weight_scale.view(-1, weight_scale.shape[-2],
                               weight_scale.shape[-1])

    num_batches = w_view.shape[0]
    BLOCK_M, BLOCK_K = block_size

    grid = (num_batches * (triton.cdiv(M, BLOCK_M)) *
            (triton.cdiv(K, BLOCK_K)), )

    _resmooth_kernel[grid](
        w_view,
        s_view,
        M,
        K,
        w_view.stride(1),
        w_view.stride(2),
        s_view.stride(1),
        s_view.stride(2),
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
    )
    # This operation is in place; the tensors are returned for convenience.
    return weight, weight_scale
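
For reference, a minimal usage sketch of the new signature (the shapes are illustrative, and the scales are assumed to be per-128x128-block float32 values as the default block_size implies):

```python
import torch

from tensorrt_llm.quantization.utils.fp8_utils import resmooth_to_fp8_e8m0

# Hypothetical FP8 weight with one float32 scale per 128x128 block.
weight = torch.randn(512, 1024, device="cuda").to(torch.float8_e4m3fn)
weight_scale = torch.rand(512 // 128, 1024 // 128,
                          device="cuda", dtype=torch.float32) + 0.5

# Requantizes in place so that every scale becomes a power of two (E8M0-friendly).
weight, weight_scale = resmooth_to_fp8_e8m0(weight, weight_scale)
assert torch.equal(weight_scale,
                   torch.exp2(torch.round(torch.log2(weight_scale))))
```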
Comment on lines +94 to +127

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find all calls to resmooth_to_fp8_e8m0
echo "=== Searching for resmooth_to_fp8_e8m0 calls ==="
rg -n 'resmooth_to_fp8_e8m0' --type=py -A 3 -B 1

echo ""
echo "=== Checking file header for copyright ==="
head -20 tensorrt_llm/quantization/utils/fp8_utils.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 3302


Add NVIDIA copyright header to the file.

The file tensorrt_llm/quantization/utils/fp8_utils.py is missing the required NVIDIA copyright header. According to coding guidelines, all TensorRT-LLM code must contain an NVIDIA copyright header with the year of latest meaningful modification. Add the header at the top of the file before the imports.

Regarding the function signature change: All call sites in the codebase correctly use the new signature and are compatible with the default block_size=(128, 128) parameter. No signature updates are required at call sites.
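
For illustration only, a header of the general shape used in NVIDIA open-source projects; the exact wording, years, and license identifier must be copied from an existing file in this repo, so treat the text below as an assumption rather than the canonical header:

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```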

🤖 Prompt for AI Agents
In tensorrt_llm/quantization/utils/fp8_utils.py around lines 94 to 127: add the
required NVIDIA copyright header at the very top of the file (before any
imports) with the correct year of latest meaningful modification and the
standard NVIDIA license phrasing used in this repo; ensure the header exactly
matches the project's header format and encoding, then save—no changes to the
function signature or any call sites are required since they are already
compatible with the default block_size.



def get_m_alignment_for_contiguous_layout():