diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py
index 0776e2d4d7b..cdcd2ed864a 100644
--- a/tensorrt_llm/quantization/utils/fp8_utils.py
+++ b/tensorrt_llm/quantization/utils/fp8_utils.py
@@ -51,45 +51,88 @@ def per_token_cast_to_fp8_e8m0(
         g, m, n), sf


-def per_block_cast_to_fp8_e8m0(
-        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    if x.dim() == 2:
-        m, n = x.shape
-        x_padded = torch.zeros((align(m, 128), align(n, 128)),
-                               dtype=x.dtype,
-                               device=x.device)
-        x_padded[:m, :n] = x
-        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
-        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-        sf = ceil_to_ue8m0(x_amax / 448.0)
-        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
-        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
-            x_view.size(0), x_view.size(2))
-    else:
-        g, m, n = x.shape
-        x_padded = torch.zeros((g, align(m, 128), align(n, 128)),
-                               dtype=x.dtype,
-                               device=x.device)
-        x_padded[:, :m, :n] = x
-        x_view = x_padded.view(g, -1, 128, x_padded.size(-1) // 128, 128)
-        x_amax = x_view.abs().float().amax(dim=(2, 4), keepdim=True).clamp(1e-4)
-        sf = ceil_to_ue8m0(x_amax / 448.0)
-        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
-        return x_scaled.view_as(x_padded)[:, :m, :n].contiguous(), sf.view(
-            x_view.size(0), x_view.size(1), x_view.size(3))
-
-
-def resmooth_to_fp8_e8m0(weight: torch.Tensor,
-                         sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    weight = weight.cuda()
-    sf = sf.cuda()
-    if weight.dim() == 2:
-        x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
-            128, dim=1)[:weight.shape[0], :weight.shape[1]]
-    else:
-        x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
-            128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
-    return per_block_cast_to_fp8_e8m0(x)
+@triton.jit
+def _resmooth_kernel(
+    w_ptr,
+    s_ptr,
+    M,
+    K,
+    stride_wb,
+    stride_wm,
+    stride_wk,
+    stride_sb,
+    stride_sm,
+    stride_sk,
+    BLOCK_M: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    # One program per (batch, row block, column block).
+    pid_b = tl.program_id(0)
+    pid_m = tl.program_id(1)
+    pid_k = tl.program_id(2)
+
+    s_offset = pid_b * stride_sb + pid_m * stride_sm + pid_k * stride_sk
+    old_scale = tl.load(s_ptr + s_offset)
+
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
+
+    w_mask = (rm[:, None] < M) & (rk[None, :] < K)
+    w_offsets = (pid_b * stride_wb + rm[:, None] * stride_wm +
+                 rk[None, :] * stride_wk)
+    w_fp8 = tl.load(w_ptr + w_offsets, mask=w_mask, other=0.0)
+
+    # Dequantize with the old scale, then take the per-block amax.
+    w_val = w_fp8.to(tl.float32) * old_scale
+    block_amax = tl.maximum(tl.max(tl.abs(w_val)), 1e-4)
+
+    # E8M0 scale: round amax / 448 (the FP8 E4M3 max) up to the nearest
+    # power of two, i.e. new_scale = 2 ** ceil(log2(block_amax / 448.0)).
+    new_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(block_amax / 448.0)))
+    w_requant = w_val * (1.0 / new_scale)
+
+    # The store implicitly casts w_requant back to w_ptr's FP8 element type.
+    tl.store(w_ptr + w_offsets, w_requant, mask=w_mask)
+    tl.store(s_ptr + s_offset, new_scale)
+
+
+def resmooth_to_fp8_e8m0(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    block_size: Tuple[int, int] = (128, 128),
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert weight.dtype == torch.float8_e4m3fn
+    assert weight_scale.dtype == torch.float32
+
+    M, K = weight.shape[-2:]
+    w_view = weight.view(-1, M, K)
+    s_view = weight_scale.view(-1, weight_scale.shape[-2],
+                               weight_scale.shape[-1])
+
+    num_batches = w_view.shape[0]
+    BLOCK_M, BLOCK_K = block_size
+    # One scale per (BLOCK_M, BLOCK_K) tile of every weight matrix.
+    assert s_view.shape[-2:] == (triton.cdiv(M, BLOCK_M),
+                                 triton.cdiv(K, BLOCK_K))
+
+    grid = (num_batches, triton.cdiv(M, BLOCK_M), triton.cdiv(K, BLOCK_K))
+
+    _resmooth_kernel[grid](
+        w_view,
+        s_view,
+        M,
+        K,
+        w_view.stride(0),
+        w_view.stride(1),
+        w_view.stride(2),
+        s_view.stride(0),
+        s_view.stride(1),
+        s_view.stride(2),
+        BLOCK_M=BLOCK_M,
+        BLOCK_K=BLOCK_K,
+    )
+    # The kernel updates weight and weight_scale in place; they are also
+    # returned for convenience.
+    return weight, weight_scale


 def get_m_alignment_for_contiguous_layout():
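
For review: the kernel performs the same dequantize → per-block amax → power-of-two rescale pipeline that the removed torch code implemented. Below is a minimal pure-PyTorch sketch of that pipeline, useful as a unit-test cross-check against the Triton path; `reference_resmooth` is a hypothetical helper (not part of this change) and assumes a single 2D weight with 128x128 quantization blocks:

```python
import torch


def reference_resmooth(weight: torch.Tensor, weight_scale: torch.Tensor):
    """Hypothetical test helper mirroring the removed torch implementation."""
    m, k = weight.shape
    # Dequantize with the old 128x128 per-block scales.
    w = weight.float() * weight_scale.repeat_interleave(
        128, dim=0).repeat_interleave(128, dim=1)[:m, :k]
    # Pad to multiples of 128 so every block is full.
    pad_m, pad_k = -(-m // 128) * 128, -(-k // 128) * 128
    w_pad = torch.zeros(pad_m, pad_k, dtype=w.dtype, device=w.device)
    w_pad[:m, :k] = w
    v = w_pad.view(pad_m // 128, 128, pad_k // 128, 128)
    amax = v.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)
    # UE8M0: round the scale up to the nearest power of two (448 = E4M3 max).
    new_scale = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))
    w_q = (v / new_scale).to(torch.float8_e4m3fn)
    return (w_q.view(pad_m, pad_k)[:m, :k].contiguous(),
            new_scale.view(pad_m // 128, pad_k // 128))
```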
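The new entry point mutates `weight` and `weight_scale` in place on the GPU, which avoids materializing the full FP32 copy of the weight that the old torch path required; call sites can use the returned tensors or ignore them. A usage sketch (shapes are illustrative):

```python
import torch

from tensorrt_llm.quantization.utils.fp8_utils import resmooth_to_fp8_e8m0

w = torch.randn(4096, 7168, device="cuda").to(torch.float8_e4m3fn)
s = torch.rand(4096 // 128, 7168 // 128, device="cuda", dtype=torch.float32)

w, s = resmooth_to_fp8_e8m0(w, s)

# Every resmoothed scale is a power of two, i.e. representable in E8M0.
mantissa, _ = torch.frexp(s)
assert torch.all(mantissa == 0.5)
```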