113 changes: 74 additions & 39 deletions tensorrt_llm/quantization/utils/fp8_utils.py
@@ -51,45 +51,80 @@ def per_token_cast_to_fp8_e8m0(
        g, m, n), sf


def per_block_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    if x.dim() == 2:
        m, n = x.shape
        x_padded = torch.zeros((align(m, 128), align(n, 128)),
                               dtype=x.dtype,
                               device=x.device)
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        sf = ceil_to_ue8m0(x_amax / 448.0)
        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
            x_view.size(0), x_view.size(2))
    else:
        g, m, n = x.shape
        x_padded = torch.zeros((g, align(m, 128), align(n, 128)),
                               dtype=x.dtype,
                               device=x.device)
        x_padded[:, :m, :n] = x
        x_view = x_padded.view(g, -1, 128, x_padded.size(-1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(2, 4), keepdim=True).clamp(1e-4)
        sf = ceil_to_ue8m0(x_amax / 448.0)
        x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:, :m, :n].contiguous(), sf.view(
            x_view.size(0), x_view.size(1), x_view.size(3))


def resmooth_to_fp8_e8m0(weight: torch.Tensor,
                         sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    weight = weight.cuda()
    sf = sf.cuda()
    if weight.dim() == 2:
        x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
            128, dim=1)[:weight.shape[0], :weight.shape[1]]
    else:
        x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
            128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
    return per_block_cast_to_fp8_e8m0(x)
@triton.jit
def _resmooth_kernel(
    w_ptr,
    s_ptr,
    M,
    K,
    stride_wm,
    stride_wk,
    stride_sm,
    stride_sk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # One program per (BLOCK_M, BLOCK_K) tile: decompose the flat 1-D launch
    # grid into (tile row, tile column) coordinates.
    pid = tl.program_id(0)
    num_pid_k = tl.cdiv(K, BLOCK_K)
    pid_m = pid // num_pid_k
    pid_k = pid % num_pid_k

    # Each tile owns exactly one scale factor.
    s_offset = pid_m * stride_sm + pid_k * stride_sk
    old_scale = tl.load(s_ptr + s_offset)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    w_mask = (rm[:, None] < M) & (rk[None, :] < K)
    w_offsets = rm[:, None] * stride_wm + rk[None, :] * stride_wk
    w_fp8 = tl.load(w_ptr + w_offsets, mask=w_mask, other=0.0)
    w_fp32 = w_fp8.to(tl.float32)

    # Dequantize with the old scale, then take the tile's absolute maximum,
    # clamped so the log2 below stays finite.
    w_val = w_fp32 * old_scale
    block_amax = tl.maximum(tl.max(tl.abs(w_val)), 1e-4)
    # UE8M0 scale: round block_amax / 448 (the FP8 E4M3 max) up to the next
    # power of two, i.e. new_scale = 2 ** ceil(log2(block_amax / 448)).
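    # e.g. block_amax = 100.0: 100 / 448 ≈ 0.223, ceil(log2(0.223)) = -2,
    # so new_scale = 2 ** -2 = 0.25.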
    new_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(block_amax / 448.0)))
    w_requant = w_val * (1.0 / new_scale)

    # The store casts the float32 values back to the pointer's FP8 dtype.
    tl.store(w_ptr + w_offsets, w_requant, mask=w_mask)
    tl.store(s_ptr + s_offset, new_scale)
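

For intuition, here is a minimal PyTorch sketch of what each kernel instance above computes for a single tile; the helper name and signature are illustrative, not part of this change:

import torch

def _resmooth_tile_reference(tile: torch.Tensor, old_scale: torch.Tensor):
    # tile: one (BLOCK_M, BLOCK_K) slice of the FP8 weight; old_scale: its
    # current float32 block scale.
    w = tile.float() * old_scale                  # dequantize with the old scale
    amax = w.abs().max().clamp(min=1e-4)          # clamp so log2 stays finite
    new_scale = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))
    return (w / new_scale).to(torch.float8_e4m3fn), new_scale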


def resmooth_to_fp8_e8m0(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    block_size: tuple[int, int] = (128, 128),
):
    assert weight.dtype == torch.float8_e4m3fn
    assert weight_scale.dtype == torch.float32

    orig_shape = weight.shape
    M, K = orig_shape[-2:]
    w_view = weight.view(-1, M, K)
    s_view = weight_scale.view(-1, weight_scale.shape[-2],
                               weight_scale.shape[-1])

    num_batches = w_view.shape[0]
    BLOCK_M, BLOCK_K = block_size
    # Tiles must not straddle batch boundaries, so batched weights need
    # block-aligned rows.
    assert num_batches == 1 or M % BLOCK_M == 0

    # The views are contiguous, so the batch dimension folds into the rows
    # and a single 1-D grid covers every tile of every batch.
    total_m = num_batches * M
    grid = (triton.cdiv(total_m, BLOCK_M) * triton.cdiv(K, BLOCK_K), )

    _resmooth_kernel[grid](
        w_view,
        s_view,
        total_m,
        K,
        w_view.stride(1),
        w_view.stride(2),
        s_view.stride(1),
        s_view.stride(2),
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
    )
    # The requantization happens in place; the tensors are returned for
    # convenience.
    return weight, weight_scale
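
A quick usage sketch (shapes and values are illustrative; assumes a CUDA device with FP8 support):

import torch

w = torch.randn(256, 512, device="cuda").to(torch.float8_e4m3fn)
s = torch.rand(2, 4, device="cuda", dtype=torch.float32) + 0.5  # 128x128 block scales
w, s = resmooth_to_fp8_e8m0(w, s)
# w is requantized in place and every entry of s is now a power of two.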


def get_m_alignment_for_contiguous_layout():