
Commit 03cdf58

Nekofish-L and karljang authored
[None][fix] impl fused triton kernel for e8m0 resmooth to reduce memory footprint (NVIDIA#10327)
Signed-off-by: Nekofish-L <[email protected]> Co-authored-by: Kanghwan <[email protected]>
1 parent f001c49 commit 03cdf58
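
In short, the fused kernel walks the FP8 weight one 128x128 block at a time: it dequantizes the block with its existing scale, takes the block's absolute maximum, rounds the scale up to the next power of two (UE8M0), and writes the requantized block and the new scale back in place, instead of first materializing a dequantized float32 copy of the whole weight as the removed PyTorch path did. A rough reference formulation of the per-block scale update, written in plain Python for illustration (the name ue8m0_block_scale is not from the commit; the actual kernel in the diff below does the same thing with tl.math on device):

import math

def ue8m0_block_scale(block_amax: float) -> float:
    # FP8 E4M3 holds magnitudes up to 448, so choose a scale such that
    # block_amax / scale <= 448, then round it up to a power of two.
    return 2.0 ** math.ceil(math.log2(max(block_amax, 1e-4) / 448.0))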

File tree

1 file changed: +79 -39 lines changed


tensorrt_llm/quantization/utils/fp8_utils.py

Lines changed: 79 additions & 39 deletions
@@ -51,45 +51,85 @@ def per_token_cast_to_fp8_e8m0(
         g, m, n), sf
 
 
-def per_block_cast_to_fp8_e8m0(
-        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    if x.dim() == 2:
-        m, n = x.shape
-        x_padded = torch.zeros((align(m, 128), align(n, 128)),
-                               dtype=x.dtype,
-                               device=x.device)
-        x_padded[:m, :n] = x
-        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
-        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-        sf = ceil_to_ue8m0(x_amax / 448.0)
-        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
-        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
-            x_view.size(0), x_view.size(2))
-    else:
-        g, m, n = x.shape
-        x_padded = torch.zeros((g, align(m, 128), align(n, 128)),
-                               dtype=x.dtype,
-                               device=x.device)
-        x_padded[:, :m, :n] = x
-        x_view = x_padded.view(g, -1, 128, x_padded.size(-1) // 128, 128)
-        x_amax = x_view.abs().float().amax(dim=(2, 4), keepdim=True).clamp(1e-4)
-        sf = ceil_to_ue8m0(x_amax / 448.0)
-        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
-        return x_scaled.view_as(x_padded)[:, :m, :n].contiguous(), sf.view(
-            x_view.size(0), x_view.size(1), x_view.size(3))
-
-
-def resmooth_to_fp8_e8m0(weight: torch.Tensor,
-                         sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    weight = weight.cuda()
-    sf = sf.cuda()
-    if weight.dim() == 2:
-        x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
-            128, dim=1)[:weight.shape[0], :weight.shape[1]]
-    else:
-        x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
-            128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
-    return per_block_cast_to_fp8_e8m0(x)
+@triton.jit
+def _resmooth_kernel(
+    w_ptr,
+    s_ptr,
+    M,
+    K,
+    stride_wb,
+    stride_wm,
+    stride_wk,
+    stride_sb,
+    stride_sm,
+    stride_sk,
+    BLOCK_M: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+    pid_m = tl.program_id(1)
+    pid_k = tl.program_id(2)
+
+    curr_w_ptr = w_ptr + batch_idx * stride_wb
+    curr_s_ptr = s_ptr + batch_idx * stride_sb
+
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
+
+    s_offset = pid_m * stride_sm + pid_k * stride_sk
+    old_scale = tl.load(curr_s_ptr + s_offset)
+
+    w_mask = (rm[:, None] < M) & (rk[None, :] < K)
+    w_offsets = rm[:, None] * stride_wm + rk[None, :] * stride_wk
+    w_fp8 = tl.load(curr_w_ptr + w_offsets, mask=w_mask, other=0.0)
+    w_fp32 = w_fp8.to(tl.float32)
+
+    w_val = w_fp32 * old_scale
+    block_amax = tl.maximum(tl.max(tl.abs(w_val)), 1e-4)
+
+    # UE8M0 sf = 2 ^ ceil(log2(sf))
+    new_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(block_amax / 448.0)))
+    w_requant = w_val * (1.0 / new_scale)
+
+    tl.store(curr_w_ptr + w_offsets, w_requant, mask=w_mask)
+    tl.store(curr_s_ptr + s_offset, new_scale)
+
+
+def resmooth_to_fp8_e8m0(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    block_size: tuple[int, int] = (128, 128),
+):
+    assert weight.dtype == torch.float8_e4m3fn
+    assert weight_scale.dtype == torch.float32
+
+    orig_shape = weight.shape
+    M, K = orig_shape[-2:]
+    w_view = weight.view(-1, M, K)
+    s_view = weight_scale.view(-1, weight_scale.shape[-2],
+                               weight_scale.shape[-1])
+
+    num_batches = w_view.shape[0]
+    BLOCK_M, BLOCK_K = block_size
+
+    grid = (num_batches, triton.cdiv(M, BLOCK_M), triton.cdiv(K, BLOCK_K))
+
+    _resmooth_kernel[grid](
+        w_view,
+        s_view,
+        M,
+        K,
+        w_view.stride(0),
+        w_view.stride(1),
+        w_view.stride(2),
+        s_view.stride(0),
+        s_view.stride(1),
+        s_view.stride(2),
+        BLOCK_M=BLOCK_M,
+        BLOCK_K=BLOCK_K,
+    )
+    # this is an in-place operation, however, we return for simplicity
+    return weight, weight_scale
 
 
 def get_m_alignment_for_contiguous_layout():
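
For reference, a minimal usage sketch of the new fused path (assumptions, not part of the commit: the function is imported from tensorrt_llm.quantization.utils.fp8_utils, a CUDA device with Triton is available, and the illustrative shapes are multiples of the 128x128 block size):

import torch

from tensorrt_llm.quantization.utils.fp8_utils import resmooth_to_fp8_e8m0

M, K = 4096, 7168  # illustrative weight shape, multiples of the block size
weight = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# One float32 scale per 128x128 block, e.g. produced by a block-wise FP8 quantizer.
weight_scale = torch.rand(M // 128, K // 128, device="cuda") + 0.5

# Requantizes the FP8 weight in place so every block scale becomes a power of
# two (UE8M0-representable); the returned tensors alias the inputs.
weight, weight_scale = resmooth_to_fp8_e8m0(weight, weight_scale)
assert torch.all(weight_scale == torch.exp2(torch.round(torch.log2(weight_scale))))

Compared with the removed path, which first expanded the scales with repeat_interleave into a full float32 copy of the weight and then re-cast it block-wise through a padded buffer, the Triton kernel dequantizes and requantizes one block at a time, which is where the memory-footprint reduction in the commit title comes from.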
