Commit 311fca2

Use torch._scaled_mm instead of torch.matmul for fp8_gemm and fp8_attention
1 parent 3e4cf98 commit 311fca2
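
The kernels and baselines changed below all share one call pattern: both operands quantized to FP8 e4m3, the second operand laid out column-major, explicit FP32 scale tensors, and an FP32 output that is cast afterwards. As orientation, here is a minimal standalone sketch of that pattern — not code from this commit. It assumes PyTorch 2.3+ (where `torch._scaled_mm` takes the scales positionally) and an FP8-capable CUDA GPU; the helper name `scaled_mm_fp8` and the shapes are invented for the example.

```python
import torch


def scaled_mm_fp8(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Quantize both operands to FP8 e4m3 (the dtype this commit switches to).
    a_fp8 = a.to(torch.float8_e4m3fn)  # row-major [M, K]
    # torch._scaled_mm wants its second operand column-major: transpose,
    # force contiguity, transpose back (same trick used in the diff below).
    b_fp8 = b.to(torch.float8_e4m3fn).t().contiguous().t()  # column-major [K, N]
    scale_a = torch.tensor(1.0, device=a.device)  # unit scales, as in the examples
    scale_b = torch.tensor(1.0, device=b.device)
    return torch._scaled_mm(
        a_fp8,
        b_fp8,
        scale_a,
        scale_b,
        use_fast_accum=False,
        out_dtype=torch.float32,
    )


if __name__ == "__main__":
    a = torch.randn(64, 128, device="cuda")
    b = torch.randn(128, 64, device="cuda")
    out = scaled_mm_fp8(a, b)
    print((out - a @ b).abs().max())  # nonzero: FP8 quantization error
```

With unit scales this is simply an FP8-quantized matmul; production code would derive the scales from each tensor's dynamic range.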

4 files changed (+362, -155 lines)

examples/fp8_attention.py

Lines changed: 62 additions & 24 deletions
@@ -23,7 +23,7 @@ def fp8_attention_kernel(
 
     # Output tensor with 4D shape in FP8 format
     out = torch.empty(
-        [batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device
+        [batch, heads, seq_len, head_dim], dtype=torch.float8_e4m3fn, device=q.device
     )
 
     # Scale factor for attention
@@ -54,8 +54,15 @@ def fp8_attention_kernel(
                 k_tile_t = k_tile.transpose(0, 1)  # [dim, tile_n]
 
                 # Compute Q @ K^T with FP8 inputs, result in FP32
-                qk = torch.matmul(q_tile, k_tile_t).to(
-                    torch.float32
+                scale_a = hl.full([], 1.0, dtype=torch.float32)
+                scale_b = hl.full([], 1.0, dtype=torch.float32)
+                qk = torch._scaled_mm(
+                    q_tile,
+                    k_tile_t,
+                    scale_a,
+                    scale_b,
+                    use_fast_accum=False,
+                    out_dtype=torch.float32,
                 )  # [tile_m, tile_n]
 
                 # Scale QK scores first
@@ -91,7 +98,16 @@ def fp8_attention_kernel(
 
                 # Accumulate attention @ V with FP8 GEMM
                 v_t = v_tile.transpose(0, 1)  # [tile_n, dim]
-                pv = torch.matmul(p_fp8, v_t).to(torch.float32)  # [tile_m, dim]
+                scale_p = hl.full([], 1.0, dtype=torch.float32)
+                scale_v = hl.full([], 1.0, dtype=torch.float32)
+                pv = torch._scaled_mm(
+                    p_fp8,
+                    v_t,
+                    scale_p,
+                    scale_v,
+                    use_fast_accum=False,
+                    out_dtype=torch.float32,
+                )  # [tile_m, dim]
                 acc = acc + pv
 
                 # Update max tracker
@@ -100,18 +116,18 @@ def fp8_attention_kernel(
             # Final normalization
             acc = acc / l_i[:, None]
             # Convert to FP8 before writing to output
-            out[b, h, tile_m, :] = acc.to(torch.float8_e5m2)
+            out[b, h, tile_m, :] = acc.to(torch.float8_e4m3fn)
 
     return out
 
 
 def preprocess_fp8_attention_inputs(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    q_fp8 = q.to(torch.float8_e5m2)
-    k_fp8 = k.to(torch.float8_e5m2)
+    q_fp8 = q.to(torch.float8_e4m3fn)
+    k_fp8 = k.to(torch.float8_e4m3fn)
     v = v.permute(0, 1, 3, 2)
-    v_fp8 = v.to(torch.float8_e5m2)
+    v_fp8 = v.to(torch.float8_e4m3fn)
     batch, heads, seq_len, head_dim = q.shape
     q_fp8_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim)
     k_fp8_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim)
@@ -147,13 +163,25 @@ def _fp8_attention_pytorch_impl(
         k_i = k_fp8[i]  # [seq, dim] - already FP8
         v_i = v_fp8[i]  # [dim, seq] - pre-transposed, already FP8
 
-        # For Q @ K^T, we need K^T to be column-major
-        kt_fp8 = k_i.t()  # column-major [dim, seq]
-
-        # Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm
-        q_deq = q_i.to(torch.float32)
-        kt_deq = kt_fp8.to(torch.float32)
-        qk = torch.matmul(q_deq, kt_deq)
+        # For Q @ K^T using torch._scaled_mm
+        # torch._scaled_mm requires column-major for second operand
+        # k_i is [seq, dim], we need K^T as [dim, seq] in column-major
+        # Direct conversion: k_i -> contiguous -> transpose view
+        kt_fp8_col_major = k_i.contiguous().t()  # [dim, seq] in column-major
+
+        # Create scale tensors
+        scale_q = torch.tensor(1.0, device=q_i.device)
+        scale_k = torch.tensor(1.0, device=k_i.device)
+
+        # Q @ K^T using torch._scaled_mm
+        qk = torch._scaled_mm(
+            q_i,
+            kt_fp8_col_major,
+            scale_q,
+            scale_k,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )
 
         # Compute max before scaling
         qk_max = torch.amax(qk, dim=-1, keepdim=True)
@@ -168,16 +196,26 @@ def _fp8_attention_pytorch_impl(
         # Step 2: Attention @ V using FP8
         # P is [seq, seq], V is [dim, seq]
         # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
-        p_fp8 = p_norm.to(torch.float8_e5m2)  # row-major [seq, seq]
+        p_fp8 = p_norm.to(torch.float8_e4m3fn)  # row-major [seq, seq]
 
         # v_i is [dim, seq], already FP8
-        vt_fp8 = v_i.t()  # column-major [seq, dim]
-
-        # P @ V^T - dequantize and use regular matmul since e5m2 not supported by torch._scaled_mm
-        p_deq = p_fp8.to(torch.float32)
-        vt_deq = vt_fp8.to(torch.float32)
-        out_i = torch.matmul(p_deq, vt_deq)
-        out_i = out_i.to(torch.float8_e5m2)  # convert back to FP8
+        # Direct conversion: v_i -> contiguous -> transpose view
+        vt_fp8_col_major = v_i.contiguous().t()  # [seq, dim] in column-major
+
+        # Create scale tensors for P @ V^T
+        scale_p = torch.tensor(1.0, device=p_fp8.device)
+        scale_v = torch.tensor(1.0, device=v_i.device)
+
+        # P @ V^T using torch._scaled_mm
+        out_i = torch._scaled_mm(
+            p_fp8,
+            vt_fp8_col_major,
+            scale_p,
+            scale_v,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )
+        out_i = out_i.to(torch.float8_e4m3fn)  # convert back to FP8 to match kernel
 
         outputs.append(out_i)
 
@@ -192,7 +230,7 @@ def fp8_attention_pytorch(
     v: torch.Tensor,  # [batch, heads, seq, dim]
 ) -> Callable[[], torch.Tensor]:
     """
-    Baseline PyTorch implementation of FP8 attention using FP8 e5m2.
+    Baseline PyTorch implementation of FP8 attention using torch._scaled_mm.
     """
     batch, heads, seq_len, head_dim = q.shape
     q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
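
Two details of the attention changes are worth noting. Every cast moves from `torch.float8_e5m2` to `torch.float8_e4m3fn`; the comments removed above explain why the old baseline dequantized and fell back to `torch.matmul` — e5m2 inputs were not supported by `torch._scaled_mm`. The baseline also builds its column-major second operands with `k_i.contiguous().t()` and `v_i.contiguous().t()`. The snippet below is a hypothetical layout check, not part of the commit (shapes invented), showing that this idiom yields the stride pattern `torch._scaled_mm` expects.

```python
import torch

seq, dim = 256, 64
k_i = torch.randn(seq, dim).to(torch.float8_e4m3fn)  # [seq, dim], row-major

# contiguous() pins row-major strides; t() then swaps them, so the resulting
# [dim, seq] view has stride 1 along its first dimension, i.e. it is
# column-major, which is what torch._scaled_mm expects for its second operand.
kt_fp8_col_major = k_i.contiguous().t()
print(kt_fp8_col_major.shape)            # torch.Size([64, 256])
print(kt_fp8_col_major.stride())         # (1, 64)
print(kt_fp8_col_major.is_contiguous())  # False (transposed view)
```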

examples/fp8_gemm.py

Lines changed: 35 additions & 9 deletions
@@ -1,13 +1,21 @@
 from __future__ import annotations
 
+import os
+
 import torch
 
 import helion
 from helion._testing import run_example
 import helion.language as hl
 
+# Override default config to work around Triton tl.dot requirement:
+# `AssertionError: Input shapes should have M >= 16, N >= 16 and K >= 32`
+config = None
+if os.environ.get("HELION_USE_DEFAULT_CONFIG") == "1":
+    config = helion.Config(block_sizes=[32, 32, 32])
+
 
-@helion.kernel(static_shapes=True)
+@helion.kernel(static_shapes=True, config=config)
 def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     """FP8 General Matrix Multiplication (GEMM).
 
@@ -37,11 +45,24 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             x_tile = x[tile_m, tile_k]
             y_tile = y[tile_k, tile_n]
 
-            # Use torch.matmul which will be lowered to tl.dot
-            # When the inputs are FP8, tl.dot handles them natively
-            # The result needs to be converted to FP32 for accumulation
-            result = torch.matmul(x_tile, y_tile).to(torch.float32)
-            acc = acc + result
+            # torch._scaled_mm(A, B) requires B to be column-major
+            # We make y_tile column-major by transposing twice
+            y_tile_col_major = y_tile.transpose(0, 1).contiguous().transpose(0, 1)
+
+            # Create scale tensors
+            scale_a = hl.full([], 1.0, dtype=torch.float32)
+            scale_b = hl.full([], 1.0, dtype=torch.float32)
+
+            # Use torch._scaled_mm for FP8 GEMM, then accumulate result in FP32
+            mm_out = torch._scaled_mm(
+                x_tile,
+                y_tile_col_major,
+                scale_a,
+                scale_b,
+                use_fast_accum=False,
+                out_dtype=torch.float32,
+            )
+            acc = acc + mm_out
         out[tile_m, tile_n] = acc.to(torch.float16)
 
     return out
@@ -52,12 +73,17 @@ def reference_fp8_gemm_pytorch(
 ) -> torch.Tensor:
     """Reference implementation using torch._scaled_mm."""
     # torch._scaled_mm requires column-major for second operand
-    y_fp8_t = y_fp8.T.contiguous().T
+    y_fp8_col_major = y_fp8.T.contiguous().T
     scale_a = torch.tensor(1.0, device=x_fp8.device)
     scale_b = torch.tensor(1.0, device=x_fp8.device)
     return torch._scaled_mm(
-        x_fp8, y_fp8_t, scale_a, scale_b, use_fast_accum=False, out_dtype=torch.float16
-    )
+        x_fp8,
+        y_fp8_col_major,
+        scale_a,
+        scale_b,
+        use_fast_accum=False,
+        out_dtype=torch.float32,
+    ).to(torch.float16)
 
 
 def fp8_gemm_tritonbench(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
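
As a sanity check on the reference path above, here is a hypothetical comparison — not part of the commit, shapes invented, assuming PyTorch 2.3+ and an FP8-capable CUDA GPU — between the `torch._scaled_mm` path and the dequantize-then-`torch.matmul` fallback this commit replaces. With unit scales, both paths consume the same FP8-quantized operands, so their results should agree closely.

```python
import torch

m, k, n = 64, 128, 64
x = torch.randn(m, k, device="cuda", dtype=torch.float16)
y = torch.randn(k, n, device="cuda", dtype=torch.float16)

x_fp8 = x.to(torch.float8_e4m3fn)                             # row-major first operand
y_fp8_col_major = y.to(torch.float8_e4m3fn).T.contiguous().T  # column-major second operand
one = torch.tensor(1.0, device="cuda")                        # unit scales

# New path: native FP8 matmul via torch._scaled_mm.
scaled = torch._scaled_mm(
    x_fp8,
    y_fp8_col_major,
    one,
    one,
    use_fast_accum=False,
    out_dtype=torch.float32,
)

# Old path: dequantize to FP32 and use a regular matmul.
dequant = x_fp8.to(torch.float32) @ y_fp8_col_major.to(torch.float32)

print((scaled - dequant).abs().max())  # small; both paths see the same FP8 inputs
```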
