Commit f4789ef

[triton_kernels][matmul] support mxfp8 x (triton-lang#8062)
1 parent 01d3c87 commit f4789ef

File tree · 5 files changed · +103 −48 lines
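
Context for the change: MX ("microscaling") formats such as mxfp8 store fp8 values (float8_e4m3fn here) in blocks of 32 elements that share one uint8 E8M0 scale; a scale byte s encodes the factor 2**(s - 127), which is why 127 appears below as the identity scale. A minimal decode sketch, assuming that 32-element block size (NumPy used for illustration only):

    import numpy as np

    MXFP_BLOCK_SIZE = 32  # elements that share one E8M0 scale byte

    def mx_dequant(vals, scales_u8):
        # vals: (..., K) fp8 values already widened to float32
        # scales_u8: (..., K // 32) biased-exponent scale bytes
        exps = scales_u8.astype(np.int32) - 127
        return vals * np.exp2(np.repeat(exps, MXFP_BLOCK_SIZE, axis=-1))

This commit wires such scales through the matmul kernels for the activations (the x operand) in addition to the weights, and lets the epilogue emit an mxfp8-scaled output.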

python/triton_kernels/tests/test_matmul.py (1 addition, 3 deletions)

@@ -260,6 +260,7 @@ class Case:
         (False, False, False),
         (True, False, False),
         (False, True, False),
+        (False, True, True),
         (True, True, False),
         (True, True, True),
     ])
@@ -277,9 +278,6 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     if weight_dtype_str.startswith("mx"):
         if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10:
             pytest.skip("float8 x mx not supported with cuda capability < 10")
-        if act_dtype_str == "mxfloat8_e4m3fn":
-            if is_persistent:
-                pytest.skip("mx x mx not supported with persistent kernel")
     if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9:
         pytest.skip("Not enough memory on A100")
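
With the persistent-kernel skip removed, the mxfp8-activation cases now run against both the regular and the persistent kernel, and the new parametrize entry enables one more flag combination. To exercise just those cases locally, standard pytest selection should work; the ID substring is an assumption about how the parametrized test names are generated:

    pytest python/triton_kernels/tests/test_matmul.py -k mxfloat8 -x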

python/triton_kernels/triton_kernels/matmul_ogs.py (3 additions, 3 deletions)

@@ -332,6 +332,7 @@ def matmul_ogs(x, w, bias,
     w_scale = precision_config.weight_scale
     w_has_mx = w_scale is not None
     is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
+    if w_has_mx: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp"
     if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
     if not isinstance(w, Tensor):
         # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
@@ -456,7 +457,7 @@ def matmul_ogs(x, w, bias,
     w_scale_strides = w_scale.stride() if w_has_mx and not w_scale_has_tma else (None, None, None)
     w_scale_strides = (0, ) * (3 - len(w_scale_strides)) + w_scale_strides
     out_matmul_scale_strides = out_matmul_scale.stride() if out_matmul_has_mx else (None, None, None, None)
-    out_matmul_scale_strides = (0, ) * (3 - len(out_matmul_scale_strides)) + out_matmul_scale_strides
+    out_matmul_scale_strides = (0, ) * (4 - len(out_matmul_scale_strides)) + out_matmul_scale_strides
     # launch kernel
     kernels = get_kernels(epilogue.specs, matmul_fused_activation.specs)
     # When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed
@@ -467,15 +468,14 @@ def matmul_ogs(x, w, bias,
     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
         y_tensor_or_tma, y_storage.data, *out_matmul.stride(),
         *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex),
-        *out_matmul_scale_strides[-3:],
+        *out_matmul_scale_strides[-4:],
         x_tensor_or_tma, x_storage.data, *x_strides,
         flex.lhs_data.scale,
         None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
         w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_transpose,
         flex.rhs_data.scale,
         w_scale_tensor_or_tma, *w_scale_strides,
         bias, bias_stride,
-        x.shape[-2],
         x.shape[-2] if routing_data.expt_hist is None else None,
         N, K,
         betas, gammas,
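
A worked sketch of the stride-padding change: with split-K the output MX-scale tensor gains a leading K axis, so the tuple is left-padded to four entries rather than three (the concrete strides here are hypothetical):

    # 3-D scale tensor, e.g. shape (Z, M, N // 32) with strides (12, 4, 1)
    out_matmul_scale_strides = (12, 4, 1)
    padded = (0,) * (4 - len(out_matmul_scale_strides)) + out_matmul_scale_strides
    assert padded == (0, 12, 4, 1)  # absent K stride defaults to 0
    # the launch then always passes four strides via *padded[-4:]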

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py (5 additions, 5 deletions)

@@ -33,15 +33,15 @@ def _zero_masked_rows(
 def _matmul_ogs(
              Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
              YExpectedScale, YActualScale, YChecksumScale,
-             stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
+             stride_y_mx_k, stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
              X, XPtr, stride_x_z, stride_x_m, stride_x_k,
              XScale,
              XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
              W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
              WScale,
              WMxScale, stride_w_mx_e, stride_w_mx_k, stride_w_mx_n,
              B, stride_b_e, # Bias
-             NRows, M, N, K, # shapes
+             M, N, K, # shapes
              # expt data
              Betas, Gammas,
              GatherIndx,
@@ -165,7 +165,7 @@ def _matmul_ogs(
     if SPLIT_K > 1:
         Y += pid_k.to(index_type) * stride_y_k
         if is_out_microscaled:
-            YActualScale += pid_k.to(index_type) * stride_x_mx_k
+            YActualScale += pid_k.to(index_type) * stride_y_mx_k
     # set masked out rows to 0
     if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
         _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
@@ -408,7 +408,7 @@ def _matmul_ogs(
     YPtrs = Y + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n.to(index_type)[None, :] * stride_y_n
     mask = mask_m[:, None] & mask_n[None, :]
     if is_out_microscaled:
-        MX_SCALE_BLOCK_N: tl.constexpr = BLOCK_N // MXFP_BLOCK_SIZE
+        MX_SCALE_BLOCK_N: tl.constexpr = OUT_BLOCK_N // MXFP_BLOCK_SIZE
        N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE)
        tl.static_assert(EPILOGUE_FN is not None)
        out, out_scale = EPILOGUE_FN(out, mask, *epilogue_fn_args)
@@ -420,7 +420,7 @@ def _matmul_ogs(
             YActualScale += start_m * stride_y_mx_m
             YActualScalePtrs = YActualScale + offs_y_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
         else:
-            YActualScalePtrs = YActualScale + (offs_y_m - NRows).to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
+            YActualScalePtrs = YActualScale + (offs_y_m - num_idxs // N_EXPTS_ACT).to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
         tl.store(YActualScalePtrs, out_scale, mask=mask_m[:, None] & mask_n_scale[None, :])
     else:
         out = float_to_flex(out, YExpectedScale, YActualScale, YChecksumScale, mask, Y, FLEXPOINT_SATURATE_INF)
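
The MX_SCALE_BLOCK_N fix matters when a fused activation reduces the N dimension (OUT_BLOCK_N < BLOCK_N): one scale is emitted per 32 output columns, so the scale tile must be sized from the output block, not the accumulator block. A worked example, assuming MXFP_BLOCK_SIZE == 32 and hypothetical tile sizes:

    import math
    MXFP_BLOCK_SIZE = 32
    BLOCK_N, ACTIVATION_REDUCTION_N = 128, 2            # hypothetical
    OUT_BLOCK_N = BLOCK_N // ACTIVATION_REDUCTION_N     # 64
    MX_SCALE_BLOCK_N = OUT_BLOCK_N // MXFP_BLOCK_SIZE   # 2 scale columns, not 4
    N_MX_BLOCK = math.ceil(2880 / MXFP_BLOCK_SIZE)      # 90 scale columns for N=2880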

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py (94 additions, 34 deletions)
@@ -81,15 +81,15 @@ def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask):
 def _p_matmul_ogs(
              Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
              YExpectedScale, YActualScale, YChecksumScale,
-             stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
+             stride_y_mx_k, stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
              X, XPtr, stride_x_z, stride_x_m, stride_x_k,
              XScale,
              XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
              W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
              WScale,
-             MxScale, stride_mx_e, stride_mx_k, stride_mx_n,
+             WMxScale, stride_w_mx_e, stride_w_mx_k, stride_w_mx_n,
              B, stride_b_e, # Bias
-             NRows, M, N, K, # shapes
+             M, N, K, # shapes
              # expt data
              Betas, Gammas,
              GatherIndx,
@@ -133,14 +133,14 @@ def _p_matmul_ogs(
     if Y_TMA_MODE is not None:
         Y = tl.make_tensor_descriptor(YPtr, Y.shape, Y.strides[:-1] + (1,), Y.block_shape)

-    is_microscaled_format: tl.constexpr = MxScale is not None
-    tl.static_assert(not is_microscaled_format or W_TRANSPOSE, "NYI. Non-transposed mxfp4 weights")
+    is_w_microscaled: tl.constexpr = WMxScale is not None
+    tl.static_assert(not is_w_microscaled or W_TRANSPOSE, "NYI. Non-transposed mxfp4 weights")
     MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
-    if is_microscaled_format:
+    if is_w_microscaled:
         w_type: tl.constexpr = get_dtype(W)
         tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
-                         "mx_weight_ptr must be uint8")
-        tl.static_assert(get_dtype(MxScale) == tl.uint8, "mx_scale_ptr must be uint8")
+                         "mx_weight_ptr must be uint8 or fp8")
+        tl.static_assert(get_dtype(WMxScale) == tl.uint8, "mx_scale_ptr must be uint8")
         tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
         tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales")

@@ -153,6 +153,13 @@ def _p_matmul_ogs(
         MX_SCALE_BLOCK_K: tl.constexpr = 1
         PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
         tl.static_assert(SWIZZLE_MX_SCALE is None)
+    is_x_microscaled: tl.constexpr = XMxScale is not None
+    if is_x_microscaled:
+        x_type: tl.constexpr = get_dtype(X)
+        tl.static_assert(x_type == tl.float8e4nv, "mx_act_ptr must be float8e4nv")
+        tl.static_assert(XMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
+        tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
+    is_out_microscaled: tl.constexpr = stride_y_mx_z is not None

     if ExptOffsSum is not None:
         # Determine how much padding there is on the expert data. This allows us to
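
For orientation, the K-side bookkeeping these asserts protect, as plain arithmetic (a sketch assuming MXFP_BLOCK_SIZE == 32 and a hypothetical BLOCK_K):

    MX_PACK_DIVISOR = 32                           # = MXFP_BLOCK_SIZE
    BLOCK_K = 128                                  # hypothetical tile depth
    assert BLOCK_K % MX_PACK_DIVISOR == 0          # mirrors the static_assert
    MX_SCALE_BLOCK_K = BLOCK_K // MX_PACK_DIVISOR  # 4 scale bytes per row per K tile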
@@ -214,7 +221,7 @@ def _p_matmul_ogs(
         THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads()
         local_absmax = tl.full([THREADS_PER_BLOCK], 0.0, tl.uint32)

-    DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_microscaled_format and BLOCK_M * BLOCK_N >= 128 * 256
+    DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_w_microscaled and BLOCK_M * BLOCK_N >= 128 * 256

     for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=True):
         expt_id, start_z, start_m, eM, off_m, off_n, pid_k = _load_tile_attrs(
@@ -241,25 +248,42 @@ def _p_matmul_ogs(
             else:
                 offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m,
                                    mask=mask_m, other=-N_EXPTS_ACT) // N_EXPTS_ACT
-        elif X_TMA_MODE is None:
-            tl.static_assert(HAS_GATHER)
+        elif X_TMA_MODE is None or is_x_microscaled:
             offs_m = off_m + tl.arange(0, BLOCK_M)
             if M is not None:
                 offs_m = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M)
             else:
                 offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M)
             # no needs to bounds-check here because `offs_m` wraps around M dim
-            offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) // N_EXPTS_ACT
+            if GatherIndx is not None:
+                tl.static_assert(HAS_GATHER)
+                offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) // N_EXPTS_ACT
             offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m


+        if is_x_microscaled:
+            XMxScalePtrs = XMxScale + start_z.to(index_type) * stride_x_mx_z
+            if GatherIndx is None:
+                XMxScalePtrs += start_m * stride_x_mx_m
+            offs_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K)
+            XMxScalePtrs += (offs_x_m if USE_GATHER_TMA else offs_m).to(index_type)[:, None] * stride_x_mx_m
+            XMxScalePtrs += offs_k_scale.to(index_type)[None, :] * stride_x_mx_k
+        else:
+            XMxScalePtrs = None
+
         acc = tl.zeros((BLOCK_N, BLOCK_M) if SWAP_XW else (BLOCK_M, BLOCK_N), dtype=tl.float32)
         for ki in tl.range(k_tiles, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER):
             off_k = pid_k * BLOCK_K + ki * BLOCK_K * SPLIT_K
             off_k_w = pid_k * PACKED_BLOCK_K_W + ki * PACKED_BLOCK_K_W * SPLIT_K
             off_k_mx = pid_k * MX_SCALE_BLOCK_K + ki * MX_SCALE_BLOCK_K * SPLIT_K

             # --- load x ---
+            if is_x_microscaled:
+                if EVEN_K:
+                    mask_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1)
+                else:
+                    mask_k_scale = offs_k_scale < tl.cdiv(K, MX_PACK_DIVISOR)
+
             if USE_GATHER_TMA:
                 x = X.gather(offs_x_m, off_k)
             elif X_TMA_MODE == "dense":
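
The XMxScalePtrs setup above, together with the per-iteration advance after tl.dot_scaled, walks a (row, k-block) grid of uint8 scales alongside the fp8 activations; a NumPy restatement of that indexing (hypothetical sizes, gather/TMA paths ignored):

    import numpy as np

    BLOCK_M, MX_SCALE_BLOCK_K, SPLIT_K, k_tiles = 4, 2, 2, 3
    x_scale = np.zeros((8, 64), dtype=np.uint8)  # (M, K // 32), hypothetical

    rows = 0 + np.arange(BLOCK_M)                              # off_m + tl.arange(0, BLOCK_M)
    cols = MX_SCALE_BLOCK_K * 1 + np.arange(MX_SCALE_BLOCK_K)  # pid_k = 1
    for ki in range(k_tiles):
        tile = x_scale[rows[:, None], cols[None, :]]  # what tl.load reads via XMxScalePtrs
        cols = cols + MX_SCALE_BLOCK_K * SPLIT_K      # advance after the scaled dot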
@@ -288,28 +312,33 @@ def _p_matmul_ogs(
             w = tl.reshape(W.load([expt_id, off_k_w, off_n]), W.block_shape[1:])

             # --- load w_scale ---
-            if is_microscaled_format:
+            if is_w_microscaled:
                 x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
-                mx_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
-                if x_format == "fp16" or x_format == "bf16":
+                w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
+
+                if is_x_microscaled:
+                    x_scales = tl.load(XMxScalePtrs, mask=mask_k_scale[None, :], other=0.0)
+                elif x_format == "fp16" or x_format == "bf16":
                     x_scales: tl.constexpr = None
                 else:
                     x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8)
                 if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
                     flattened_expt_n_idx = expt_id * ((N + 127) // 128) + (off_n // 128)
-                    w_scales = MxScale.load([0, flattened_expt_n_idx, pid_k * MX_SCALE_BLOCK_K // 4 + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K), 0, 0])
+                    w_scales = WMxScale.load([0, flattened_expt_n_idx, pid_k * MX_SCALE_BLOCK_K // 4 + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K), 0, 0])
                     w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1]))
                     w_scales = unswizzle_mx_scale_bw(w_scales)
                 else:
-                    w_scales = MxScale.load([expt_id, off_k_mx, off_n])
+                    w_scales = WMxScale.load([expt_id, off_k_mx, off_n])
                     w_scales = tl.reshape(w_scales, *w_scales.shape[1:]).T

             # --- update accumulator ---
-            if is_microscaled_format:
+            if is_w_microscaled:
                 if SWAP_XW:
-                    acc = tl.dot_scaled(w.T, w_scales, mx_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
+                    acc = tl.dot_scaled(w.T, w_scales, w_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
                 else:
-                    acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, mx_format, acc=acc, fast_math=True)
+                    acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True)
+                if is_x_microscaled:
+                    XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k
             else:
                 if SWAP_XW:
                     acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
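
With per-block scales, tl.dot_scaled computes, in effect, (x * 2**(x_scales - 127)) @ (w * 2**(w_scales - 127)) with each scale broadcast over its 32-element K block. A NumPy emulation of that contract (a sketch of the semantics, not the kernel's implementation; the kernel additionally transposes the loaded w_scales to match this orientation):

    import numpy as np

    M, N, K, B = 4, 8, 64, 32
    x = np.random.randn(M, K).astype(np.float32)  # stands in for decoded fp8
    w = np.random.randn(K, N).astype(np.float32)
    x_scales = np.full((M, K // B), 127, dtype=np.uint8)  # 127 == identity scale
    w_scales = np.full((K // B, N), 128, dtype=np.uint8)  # 128 == factor 2.0

    def expand(s, axis):
        return np.exp2(np.repeat(s.astype(np.int32) - 127, B, axis=axis))

    acc = (x * expand(x_scales, 1)) @ (w * expand(w_scales, 0))  # == 2.0 * (x @ w)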
@@ -392,6 +421,10 @@ def _p_matmul_ogs(
         tl.static_assert(EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR)
         tl.static_assert(len(accs) == SUBTILE_FACTOR)

+        if is_out_microscaled:
+            MX_SCALE_BLOCK_N: tl.constexpr = OUT_BLOCK_N // MXFP_BLOCK_SIZE
+            N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE)
+
         for a_i in tl.static_range(len(accs)):
             acc_tile = accs[a_i]
             acc_tile *= x_scale * w_scale
@@ -414,20 +447,47 @@ def _p_matmul_ogs(

             if MASK_ACC:
                 out = tl.where(mask_m[:, None], out, 0.0)
-            # Flexpoint
-            out_view = tl.reshape(out, [out.numel // THREADS_PER_BLOCK, THREADS_PER_BLOCK], can_reorder=True)
-            local_absmax = tl.maximum(local_absmax, nan_propagating_absmax_reduce(out_view, axis=0))
-            out = float_to_flex(
-                out, YExpectedScale,
-                None, # ActualScale: local absmax is tracked and updated after the loop
-                YChecksumScale,
-                None, # mask: out is manually masked to 0
-                YPtr, FLEXPOINT_SATURATE_INF
-            )
-            if EPILOGUE_FN is not None:
-                out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtr.dtype.element_ty, pid=len(accs)*tile_id1 + a_i)
-
             out_off_n = off_n1 // ACTIVATION_REDUCTION_N + a_i * OUT_BLOCK_N
+            if is_out_microscaled:
+                tl.static_assert(EPILOGUE_FN is not None)
+                offs_y_n = out_off_n + tl.arange(0, OUT_BLOCK_N)
+                mask_n = offs_y_n < yN
+                out, out_scale = EPILOGUE_FN(out, mask_m[:, None] & mask_n[None, :], *epilogue_fn_args)
+                tl.static_assert(BLOCK_N % MX_SCALE_BLOCK_N == 0, "")
+                offs_y_n_scale = off_n1 // ACTIVATION_REDUCTION_N // MXFP_BLOCK_SIZE + a_i * MX_SCALE_BLOCK_N + tl.arange(0, MX_SCALE_BLOCK_N)
+                mask_n_scale = offs_y_n_scale < N_MX_BLOCK
+                offs_y_mx_k = 0
+                if USE_SCATTER_TMA:
+                    # Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that
+                    # there shouldn't be any other negative values.
+                    offs_y_mx_z = 0
+                    offs_y_mx_m = (offs_y_m.to(tl.uint32, bitcast=True) & 0x7FFFFFFF).to(tl.int32, bitcast=True)
+                elif Y_TMA_MODE == "dense":
+                    offs_y_mx_z = pid_k * batch_size + start_z1
+                    offs_y_mx_m = off_m1 + tl.arange(0, BLOCK_M)
+                elif Y_TMA_MODE == "ragged":
+                    offs_y_mx_z = pid_k
+                    offs_y_mx_m = start_m1 + off_m1 + tl.arange(0, BLOCK_M)
+                else:
+                    tl.static_assert(Y_TMA_MODE is None)
+                    offs_y_mx_k = pid_k1
+                    offs_y_mx_z = start_z1
+                YActualScalePtrs = YActualScale + offs_y_mx_k.to(index_type) * stride_y_mx_k + offs_y_mx_z.to(index_type) * stride_y_mx_z + offs_y_mx_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
+                tl.store(YActualScalePtrs, out_scale, mask=mask_m[:, None] & mask_n_scale[None, :])
+            else:
+                # Flexpoint
+                out_view = tl.reshape(out, [out.numel // THREADS_PER_BLOCK, THREADS_PER_BLOCK], can_reorder=True)
+                local_absmax = tl.maximum(local_absmax, nan_propagating_absmax_reduce(out_view, axis=0))
+                out = float_to_flex(
+                    out, YExpectedScale,
+                    None, # ActualScale: local absmax is tracked and updated after the loop
+                    YChecksumScale,
+                    None, # mask: out is manually masked to 0
+                    YPtr, FLEXPOINT_SATURATE_INF
+                )
+                if EPILOGUE_FN is not None and not IS_EPILOGUE_QUANT_MXFP8:
+                    out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtr.dtype.element_ty, pid=len(accs)*tile_id1 + a_i)
+
             out = out.to(YPtr.dtype.element_ty)
             if USE_SCATTER_TMA:
                 # Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that
@@ -452,7 +512,7 @@ def _p_matmul_ogs(


     # Update the flexpoint scales
-    if YActualScale is not None:
+    if YActualScale is not None and not is_out_microscaled:
         tl.atomic_max(YActualScale, compute_scale(local_absmax.to(tl.float32, bitcast=True), YPtr), sem="relaxed")

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py (0 additions, 3 deletions)

@@ -177,9 +177,6 @@ def make_default_opt_flags_nvidia(
     else:
         has_simple_epilogue = precision_config.max_num_imprecise_acc is None
     is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4
-    # TEMP CHANGE
-    if precision_config.act_scale is not None or precision_config.out_scale is not None:
-        is_persistent = False
     # TMA is slower for batched matmuls with small m/n/k.
     if m * n * k < 131072:
         is_persistent = False
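
With the temporary guard gone, kernel selection for MX-scaled activations and outputs again follows the general heuristic; condensed into plain Python with hypothetical inputs:

    supports_persistent = has_simple_epilogue = True
    tiles_per_sm, lhs_itemsize, out_itemsize = 2.5, 1, 2
    m = n = k = 256

    is_persistent = (supports_persistent and has_simple_epilogue
                     and (tiles_per_sm >= 2.0 or lhs_itemsize <= 1)
                     and out_itemsize < 4)
    if m * n * k < 131072:  # TMA is slower for small batched matmuls
        is_persistent = False
    assert is_persistent    # the persistent path is now reachable for mxfp8 x mxfp8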
