@@ -23,7 +23,7 @@ def fp8_attention_kernel(

     # Output tensor with 4D shape in FP8 format
     out = torch.empty(
-        [batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device
+        [batch, heads, seq_len, head_dim], dtype=torch.float8_e4m3fn, device=q.device
     )

     # Scale factor for attention
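A note on the dtype switch: float8_e4m3fn gives up dynamic range for an extra mantissa bit compared with float8_e5m2, and (as the removed comments further down in this diff note) torch._scaled_mm's FP8 GEMM path generally does not accept e5m2 for both operands, hence the move to e4m3fn throughout. An illustrative snippet, not part of the commit, showing the trade-off:

    import torch

    # float8_e4m3fn: max 448, eps 0.125 (more precision, less range)
    # float8_e5m2:   max 57344, eps 0.25 (more range, less precision)
    for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
        info = torch.finfo(dt)
        print(dt, "max:", info.max, "eps:", info.eps)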
@@ -54,8 +54,15 @@ def fp8_attention_kernel(
            k_tile_t = k_tile.transpose(0, 1)  # [dim, tile_n]

            # Compute Q @ K^T with FP8 inputs, result in FP32
-            qk = torch.matmul(q_tile, k_tile_t).to(
-                torch.float32
+            scale_a = hl.full([], 1.0, dtype=torch.float32)
+            scale_b = hl.full([], 1.0, dtype=torch.float32)
+            qk = torch._scaled_mm(
+                q_tile,
+                k_tile_t,
+                scale_a,
+                scale_b,
+                use_fast_accum=False,
+                out_dtype=torch.float32,
            )  # [tile_m, tile_n]

            # Scale QK scores first
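The unit scales are materialized with hl.full so that they are tensors living on the device, which is what torch._scaled_mm expects in place of Python floats. For reference, the same call pattern can be exercised in eager mode outside the Helion kernel; this is a sketch only, since torch._scaled_mm is a private API whose signature has changed across PyTorch releases, and it assumes an FP8-capable GPU (SM89 or newer):

    import torch

    a = torch.randn(128, 64, device="cuda").to(torch.float8_e4m3fn)      # row-major [M, K]
    b = torch.randn(256, 64, device="cuda").to(torch.float8_e4m3fn).t()  # column-major [K, N]
    one = torch.tensor(1.0, device="cuda")
    out = torch._scaled_mm(a, b, one, one, use_fast_accum=False, out_dtype=torch.float32)
    print(out.shape)  # torch.Size([128, 256])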
@@ -91,7 +98,16 @@ def fp8_attention_kernel(

            # Accumulate attention @ V with FP8 GEMM
            v_t = v_tile.transpose(0, 1)  # [tile_n, dim]
-            pv = torch.matmul(p_fp8, v_t).to(torch.float32)  # [tile_m, dim]
+            scale_p = hl.full([], 1.0, dtype=torch.float32)
+            scale_v = hl.full([], 1.0, dtype=torch.float32)
+            pv = torch._scaled_mm(
+                p_fp8,
+                v_t,
+                scale_p,
+                scale_v,
+                use_fast_accum=False,
+                out_dtype=torch.float32,
+            )  # [tile_m, dim]
            acc = acc + pv

            # Update max tracker
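The probabilities are quantized to FP8 before this second GEMM, while the running accumulator acc stays in FP32 (out_dtype=torch.float32). Softmax outputs lie in [0, 1], so they fit well within e4m3fn's range; the only cost is rounding error. A small CPU-side illustration, assuming a PyTorch build with float8 casts enabled:

    import torch

    p = torch.softmax(torch.randn(4, 64), dim=-1)
    p8 = p.to(torch.float8_e4m3fn).to(torch.float32)  # round-trip through FP8
    # Relative rounding error of e4m3fn is at most roughly 2**-4 per element.
    print((p - p8).abs().max())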
@@ -100,18 +116,18 @@ def fp8_attention_kernel(
        # Final normalization
        acc = acc / l_i[:, None]
        # Convert to FP8 before writing to output
-        out[b, h, tile_m, :] = acc.to(torch.float8_e5m2)
+        out[b, h, tile_m, :] = acc.to(torch.float8_e4m3fn)

    return out


def preprocess_fp8_attention_inputs(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    q_fp8 = q.to(torch.float8_e5m2)
-    k_fp8 = k.to(torch.float8_e5m2)
+    q_fp8 = q.to(torch.float8_e4m3fn)
+    k_fp8 = k.to(torch.float8_e4m3fn)
    v = v.permute(0, 1, 3, 2)
-    v_fp8 = v.to(torch.float8_e5m2)
+    v_fp8 = v.to(torch.float8_e4m3fn)
    batch, heads, seq_len, head_dim = q.shape
    q_fp8_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim)
    k_fp8_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim)
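A quick shape sanity check for the preprocessing above. This is a hypothetical snippet: it assumes the function also reshapes v_fp8 to [batch*heads, dim, seq] and returns the three reshaped tensors, which is what the per-head indexing in the baseline below relies on:

    import torch

    q = torch.randn(2, 4, 128, 64)
    k, v = torch.randn_like(q), torch.randn_like(q)
    q8, k8, v8 = preprocess_fp8_attention_inputs(q, k, v)
    assert q8.shape == (2 * 4, 128, 64)   # [batch*heads, seq, dim]
    assert v8.shape == (2 * 4, 64, 128)   # [batch*heads, dim, seq], pre-transposed
    assert q8.dtype == torch.float8_e4m3fn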
@@ -147,13 +163,25 @@ def _fp8_attention_pytorch_impl(
        k_i = k_fp8[i]  # [seq, dim] - already FP8
        v_i = v_fp8[i]  # [dim, seq] - pre-transposed, already FP8

-        # For Q @ K^T, we need K^T to be column-major
-        kt_fp8 = k_i.t()  # column-major [dim, seq]
-
-        # Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm
-        q_deq = q_i.to(torch.float32)
-        kt_deq = kt_fp8.to(torch.float32)
-        qk = torch.matmul(q_deq, kt_deq)
+        # For Q @ K^T using torch._scaled_mm
+        # torch._scaled_mm requires column-major for second operand
+        # k_i is [seq, dim], we need K^T as [dim, seq] in column-major
+        # Direct conversion: k_i -> contiguous -> transpose view
+        kt_fp8_col_major = k_i.contiguous().t()  # [dim, seq] in column-major
+
+        # Create scale tensors
+        scale_q = torch.tensor(1.0, device=q_i.device)
+        scale_k = torch.tensor(1.0, device=k_i.device)
+
+        # Q @ K^T using torch._scaled_mm
+        qk = torch._scaled_mm(
+            q_i,
+            kt_fp8_col_major,
+            scale_q,
+            scale_k,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )

        # Compute max before scaling
        qk_max = torch.amax(qk, dim=-1, keepdim=True)
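The .contiguous().t() pattern is what satisfies the layout requirement stated in the comments: a transpose view of a contiguous tensor is exactly a column-major tensor, so no extra copy is needed. The strides make this visible (illustrative only):

    import torch

    k_i = torch.randn(128, 64).to(torch.float8_e4m3fn)  # [seq, dim], row-major
    kt = k_i.contiguous().t()                            # [dim, seq] transpose view
    print(k_i.stride())  # (64, 1) -> row-major
    print(kt.stride())   # (1, 64) -> column-major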
@@ -168,16 +196,26 @@ def _fp8_attention_pytorch_impl(
        # Step 2: Attention @ V using FP8
        # P is [seq, seq], V is [dim, seq]
        # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
-        p_fp8 = p_norm.to(torch.float8_e5m2)  # row-major [seq, seq]
+        p_fp8 = p_norm.to(torch.float8_e4m3fn)  # row-major [seq, seq]

        # v_i is [dim, seq], already FP8
-        vt_fp8 = v_i.t()  # column-major [seq, dim]
-
-        # P @ V^T - dequantize and use regular matmul since e5m2 not supported by torch._scaled_mm
-        p_deq = p_fp8.to(torch.float32)
-        vt_deq = vt_fp8.to(torch.float32)
-        out_i = torch.matmul(p_deq, vt_deq)
-        out_i = out_i.to(torch.float8_e5m2)  # convert back to FP8
+        # Direct conversion: v_i -> contiguous -> transpose view
+        vt_fp8_col_major = v_i.contiguous().t()  # [seq, dim] in column-major
+
+        # Create scale tensors for P @ V^T
+        scale_p = torch.tensor(1.0, device=p_fp8.device)
+        scale_v = torch.tensor(1.0, device=v_i.device)
+
+        # P @ V^T using torch._scaled_mm
+        out_i = torch._scaled_mm(
+            p_fp8,
+            vt_fp8_col_major,
+            scale_p,
+            scale_v,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )
+        out_i = out_i.to(torch.float8_e4m3fn)  # convert back to FP8 to match kernel

        outputs.append(out_i)

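Since both the Helion kernel and this baseline now return float8_e4m3fn, any correctness comparison has to upcast first; most eager-mode ops (subtraction, allclose) are generally not implemented for float8 tensors, and FP8 tolerances are necessarily loose. A hedged sketch of such a check, with the two outputs named hypothetically:

    import torch

    def fp8_outputs_close(out_kernel: torch.Tensor, out_ref: torch.Tensor) -> bool:
        # Upcast to float32 before comparing FP8 results.
        return torch.allclose(
            out_kernel.to(torch.float32), out_ref.to(torch.float32), atol=0.1, rtol=0.1
        )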
@@ -192,7 +230,7 @@ def fp8_attention_pytorch(
    v: torch.Tensor,  # [batch, heads, seq, dim]
) -> Callable[[], torch.Tensor]:
    """
-    Baseline PyTorch implementation of FP8 attention using FP8 e5m2.
+    Baseline PyTorch implementation of FP8 attention using torch._scaled_mm.
    """
    batch, heads, seq_len, head_dim = q.shape
    q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
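Per the signature, fp8_attention_pytorch returns a zero-argument callable rather than the result itself, presumably so the baseline can be invoked repeatedly for benchmarking. A hypothetical usage, with arbitrary shapes and assuming an FP8-capable GPU:

    import torch

    q = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    baseline_fn = fp8_attention_pytorch(q, k, v)  # Callable[[], torch.Tensor]
    out = baseline_fn()                           # float8_e4m3fn output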