Commit 6ac742f

Merge OpenAI Triton commit 2b5505c (#4754)
This PR changes the Triton base from 91d58f5 to 2b5505c (Jul 11). Pass rate: 98.46%.
2 parents: 9275820 + 46156f3

32 files changed: +1258 −1733 lines

python/test/unit/language/test_core.py

Lines changed: 12 additions & 15 deletions
@@ -2546,6 +2546,13 @@ def get_reduced_dtype(dtype_str, op):
     return dtype_str
 
 
+def get_reduce_input(dtype_str, shape):
+    # limit the range of integers so that reduce ops do not overflow
+    low = 0 if dtype_str in uint_dtypes else -10 if dtype_str in integral_dtypes else None
+    high = 10 if dtype_str in integral_dtypes else None
+    return numpy_random(shape, dtype_str=dtype_str, low=low, high=high)
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [
     'min',
@@ -2579,14 +2586,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
     patch = f'z = tl.{op}(x, axis=0)'
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': patch})
     # input
-    rs = RandomState(17)
-    # limit the range of integers so that the sum does not overflow
-    if dtype_str in integral_dtypes:
-        low = 0 if dtype_str in uint_dtypes else -100
-        high = 100
-        x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs, low=low, high=high)
-    else:
-        x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs)
+    x = get_reduce_input(dtype_str, (shape, ))
     numpy_op = {
         'sum': np.sum,
         'max': np.max,
@@ -2611,7 +2611,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
     else:
         z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
     # triton result
-    z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str)
+    z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str), device=device, dst_type=z_tri_dtype_str)
     if is_xpu():
         kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas, num_warps=num_warps,
                       threads_per_warp=threads_per_warp)
@@ -2715,9 +2715,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
 
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS, keep_dims=KEEP_DIMS)'})
     # input
-    rs = RandomState(17)
-    # limit the range of integers so that the sum does not overflow
-    x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
+    x = get_reduce_input(dtype_str, shape)
     x_tri = to_triton(x, device=device)
     numpy_op = {
         'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax, 'xor_sum':
@@ -2742,7 +2740,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
 
     # triton result
     z_shape = z_ref.shape
-    z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str)
+    z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str), device=device, dst_type=z_tri_dtype_str)
     BLOCK_K = 1 if len(shape) == 2 else shape[2]
     IS_3D = bool(len(shape) == 3)
     USE_I1 = dtype_str == 'bool'
@@ -3398,8 +3396,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
     temp_file.write_text(ir)
     kernel = triton.compile(str(temp_file))
 
-    rs = RandomState(17)
-    x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10)
+    x = get_reduce_input(dtype_str, (M, N))
     reduce2d = 'reduce2d' in epilogue_kind
     z_shape = (1, 1) if reduce2d else (1, N) if axis == 0 else (M, 1)
     z = np.zeros(z_shape).astype(dtype_str)
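For reference, the new get_reduce_input helper centralizes the input-range clamping that each reduction test previously did by hand: integer inputs are kept in a small window so sum-like reductions cannot overflow, while floating-point inputs keep their default range. A minimal, self-contained NumPy sketch of the same idea; make_reduce_input and the dtype sets below are illustrative stand-ins for the test_core.py helpers (numpy_random, uint_dtypes, integral_dtypes), not part of this diff:

    import numpy as np

    # Illustrative stand-ins for the uint_dtypes / integral_dtypes sets in test_core.py.
    uint_dtypes = {'uint8', 'uint16', 'uint32', 'uint64'}
    integral_dtypes = uint_dtypes | {'int8', 'int16', 'int32', 'int64'}

    def make_reduce_input(dtype_str, shape, seed=17):
        # Integers stay in [-10, 10) (or [0, 10) for unsigned) so reductions cannot overflow;
        # floating-point inputs keep their default range.
        rng = np.random.default_rng(seed)
        if dtype_str in integral_dtypes:
            low = 0 if dtype_str in uint_dtypes else -10
            return rng.integers(low, 10, size=shape, dtype=dtype_str)
        return rng.standard_normal(size=shape).astype(dtype_str)

    x = make_reduce_input('int32', (32,))
    assert int(x.min()) >= -10 and int(x.max()) < 10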

python/triton_kernels/bench/bench_mlp.py

Lines changed: 12 additions & 13 deletions
@@ -6,7 +6,8 @@
 import torch
 import triton_kernels
 import triton_kernels.swiglu
-from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, SwizzlingType
+from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
+from triton_kernels.tensor import SwizzlingType, swizzle
 from triton_kernels.matmul_ogs import MicroscalingCtx, matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
 from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import routing
@@ -35,14 +36,12 @@ def quantize(w, dtype, dev, **opt):
     assert dtype == "mx4", f"{dtype=}"
     swizzle_mx_scale = opt.get("swizzle_mx_scale", None)
     swizzle_mx_value = opt.get("swizzle_mx_value", None)
-    swizzle_axis = 2 if swizzle_mx_scale else None
     w = w.to(torch.bfloat16)
-    w, mx_scales, weight_scale_shape = downcast_to_mxfp(w, torch.uint8, axis=1, swizzle_axis=swizzle_axis,
-                                                        swizzle_scale=swizzle_mx_scale,
-                                                        swizzle_value=swizzle_mx_value)
+    w, mx_scales = downcast_to_mxfp(w, torch.uint8, axis=1)
+    w = swizzle(w, swizzle_mx_value)
+    mx_scales = swizzle(mx_scales, swizzle_mx_scale)
     return w, InFlexData(), MicroscalingCtx(weight_scale=mx_scales, swizzle_scale=swizzle_mx_scale,
-                                            swizzle_value=swizzle_mx_value,
-                                            actual_weight_scale_shape=weight_scale_shape)
+                                            swizzle_value=swizzle_mx_value)
 
 
 @dataclass
@@ -111,11 +110,11 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
         swizzle_mx_value = None
         swizzle_mx_scale = None
     elif torch.cuda.get_device_capability()[0] < 10:
-        swizzle_mx_value = SwizzlingType.HOPPER
-        swizzle_mx_scale = SwizzlingType.HOPPER
+        swizzle_mx_value = SwizzlingType.HOPPER_VALUE
+        swizzle_mx_scale = SwizzlingType.HOPPER_SCALE
     else:
         swizzle_mx_value = None
-        swizzle_mx_scale = SwizzlingType.BLACKWELL
+        swizzle_mx_scale = SwizzlingType.BLACKWELL_SCALE
     opt1 = {"swizzle_mx_value": swizzle_mx_value, "swizzle_mx_scale": swizzle_mx_scale}
     opt2 = deepcopy(opt1)
     wg, wg_flex, wg_mx = quantize(wg, "bf16", dev, **optg)
@@ -216,7 +215,7 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
     batch_ranges_moe = [(128, 512, 32), (512, 32000, 128)]
     dense_dtypes = ["fp8", "fp8"]
     quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]
-    roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
+    # roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
     roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
-    roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
-    roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")
+    # roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
+    # roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")
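The change in quantize() reflects the split of downcast_to_mxfp into separate quantization and swizzling steps. A short sketch of the new call sequence, based only on the calls visible in this diff; it assumes a CUDA build of triton_kernels that includes these changes, and the weight tensor w and its shape are illustrative (Hopper-style layouts shown):

    import torch
    from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
    from triton_kernels.tensor import SwizzlingType, swizzle

    # Illustrative weight tensor (experts x K x N); shape and values are made up.
    w = torch.randn(8, 512, 512, dtype=torch.bfloat16, device="cuda")

    # New flow per this diff: quantize first, then swizzle the value and scale
    # tensors in two explicit steps.
    w_q, mx_scales = downcast_to_mxfp(w, torch.uint8, axis=1)
    w_q = swizzle(w_q, SwizzlingType.HOPPER_VALUE)
    mx_scales = swizzle(mx_scales, SwizzlingType.HOPPER_SCALE)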

python/triton_kernels/tests/test_matmul.py

Lines changed: 55 additions & 53 deletions
@@ -7,13 +7,14 @@
 from triton_kernels.routing import routing
 # matmul utilities
 import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
-from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, MicroscalingCtx, FusedActivation, FnSpecs
-from triton_kernels.matmul_ogs import can_use_persistent_tma
+from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs
 from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch
 from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig
+from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4
+from triton_kernels.tensor_details import layout
 # numerics utilities
 from triton_kernels.numerics import InFlexData, OutFlexData
-from triton_kernels.numerics_details.mxfp import SwizzlingType, downcast_to_mxfp, upcast_from_mxfp
+from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
@@ -53,20 +54,22 @@ def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_
 def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, act_dtype, weight_dtype,
                       has_y_gammas, requires_grad=True, device="cuda"):
     torch.manual_seed(0)
-    assert mode in {'batched', 'ragged'}
+    assert mode in {'batched', "plain", 'ragged'}
     in_m = m * (n_expts_act if gindx is None else 1)
     shape_x = (n_expts_tot, in_m, k) if mode == 'batched' else (in_m, k)
+    shape_batch = tuple() if mode == "plain" else (n_expts_tot // n_expt_shards, )
     x = alloc_rand(shape_x, device=device, dtype=act_dtype, requires_grad=requires_grad)
-    w = alloc_rand((n_expts_tot // n_expt_shards, k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad)
-    bias = alloc_rand((n_expts_tot // n_expt_shards, n), device=device, dtype=torch.float32,
-                      requires_grad=requires_grad)
+    w = alloc_rand(shape_batch + (k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad)
+    bias = alloc_rand(shape_batch + (n, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
     gs0 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
     gs1 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
     gs0 = gs0.detach().requires_grad_(requires_grad)
     gs1 = gs1.detach().requires_grad_(requires_grad)
     if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
         gs0 = None
         gs1 = None
+    if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
+        w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
     return x, w, bias, gs0, gs1
 
 
@@ -75,7 +78,7 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
 # ---------------
 
 
-def init_precision(out_dtype, weight_dtype, is_mixed_input, n_expts_tot=1, mx_ctx=MicroscalingCtx(), device="cuda"):
+def init_precision(out_dtype, weight_dtype, is_mixed_input, n_expts_tot=1, device="cuda"):
     act_use_flexpoint = out_dtype.itemsize == 1
     weight_use_flexpoint = weight_dtype.itemsize == 1 and not is_mixed_input
     # flexpoint
@@ -95,7 +98,7 @@ def init_precision(out_dtype, weight_dtype, is_mixed_input, n_expts_tot=1, mx_ct
         out_data=out_flex_data(4.00, act_use_flexpoint),
     )
     return PrecisionConfig(flex_ctx=flex_ctx, acc_scale=2.0 if act_use_flexpoint or weight_use_flexpoint else 1.0,
-                           mx_ctx=mx_ctx, out_dtype=out_dtype)
+                           out_dtype=out_dtype)
 
 
 def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config):
@@ -183,8 +186,10 @@ class Case:
     Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2),
     Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9),
     # mx types:
-    Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
-    Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
+    Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1),
+    Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
+    Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1),
+    Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
     Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
     Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
     Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9),
@@ -198,10 +203,10 @@ class Case:
     Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
     Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
     Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1),
-    Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
-    Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
-    Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
-    Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
+    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
+    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
+    Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
     Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4),
    Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
     Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4),
@@ -317,38 +322,32 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
                                                                 has_y_gammas, requires_grad=test_bwd, device=device)
     x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)
 
-    if is_mixed_input:
-        if hbm_swizzling:
-            swizzle_axis = 2
-            if torch.cuda.get_device_capability()[0] < 10:
-                swizzle_value = SwizzlingType.HOPPER
-                swizzle_scale = SwizzlingType.HOPPER
-            else:
-                swizzle_value = None
-                swizzle_scale = SwizzlingType.BLACKWELL
-        else:
-            swizzle_axis = None
-            swizzle_value = None
-            swizzle_scale = None
-        w_tri, mx_scales_tri, weight_scale_shape = downcast_to_mxfp(w_tri, weight_dtype, axis=1,
-                                                                    swizzle_axis=swizzle_axis,
-                                                                    swizzle_value=swizzle_value,
-                                                                    swizzle_scale=swizzle_scale)
-        w_ref = upcast_from_mxfp(w_tri, mx_scales_tri, torch.bfloat16, axis=1, swizzle_axis=swizzle_axis,
-                                 swizzle_value=swizzle_value, swizzle_scale=swizzle_scale)
-
-        precision_opt.mx_ctx = MicroscalingCtx(weight_scale=mx_scales_tri, swizzle_value=swizzle_value,
-                                               swizzle_scale=swizzle_scale,
-                                               actual_weight_scale_shape=weight_scale_shape)
-
-    if is_persistent and not can_use_persistent_tma(x_tri, w_tri, gindx, precision_opt):
-        pytest.skip("persistent TMAs not supported for this test")
-
     if w_tri.shape[0] == 1:
         # Test the case when weight has dim 2, i.e., shape (K, N).
         w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
         w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)
 
+    if is_mixed_input:
+        capability_major = torch.cuda.get_device_capability()[0]
+        w_layout = layout.StridedLayout
+        w_scale_layout = layout.StridedLayout
+        if hbm_swizzling and "float4" in weight_dtype_str:
+            # weight layout
+            w_layouts = {9: layout.HopperMXValueLayout}
+            w_layout = w_layouts.get(capability_major, layout.StridedLayout)
+            # weight scale layout
+            w_scales_layouts = {9: layout.HopperMXScaleLayout, 10: layout.BlackwellMXScaleLayout}
+            w_scale_layout = w_scales_layouts.get(capability_major, layout.StridedLayout)
+        w_tri, mx_scales_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=-2)
+        w_ref = upcast_from_mxfp(w_tri, mx_scales_tri, torch.bfloat16, axis=-2)
+        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+        w_tri = convert_layout(wrap_torch_tensor(w_tri, w_tri_dtype), w_layout)
+        mx_scales_tri = convert_layout(wrap_torch_tensor(mx_scales_tri), w_scale_layout)
+        precision_opt.weight_scale = mx_scales_tri
+
+    # if not is_persistent and precision_opt.weight_scale is not None:
+    #     pytest.skip("non-persistent not supported with mxfp")
+
     if test_launch_metadata:
 
         def _clobber(t, used_mask):
@@ -394,7 +393,10 @@ def _hook(launch_metadata):
     flex = precision_opt.flex_ctx
 
     # triton
-    tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref)
+    try:
+        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref)
+    except (opt_flags.InapplicableConstraint, NotImplementedError):
+        pytest.skip("inapplicable opt_flags constraint")
     # If split_k > 1, then the intermediate tensor is fp32.
     sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1
     sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
@@ -498,16 +500,16 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
     x, w, bias, _, _ = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode,
                                          act_dtype, weight_dtype, False, requires_grad=False, device=device)
 
-    if is_persistent and not can_use_persistent_tma(x.view(1, x.shape[-2], x.shape[-1]),
-                                                    w.view(1, w.shape[-2], w.shape[-1]), gindx, precision_opt):
-        pytest.skip("persistent TMAs not supported for this test")
-
     if mode == "batched":
         rdata, gindx, sindx = None, None, None
-    a = swiglu(matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt), swiglu_alpha,
-               precision_config=SwiGLUPrecisionConfig(swiglu_limit))
-    b = matmul_ogs(
-        x, w, bias, rdata, gindx, sindx, precision_opt,
-        fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (swiglu_alpha, swiglu_limit),
-                                         2))
+
+    try:
+        a = swiglu(matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt), swiglu_alpha,
+                   precision_config=SwiGLUPrecisionConfig(swiglu_limit))
+        b = matmul_ogs(
+            x, w, bias, rdata, gindx, sindx, precision_opt,
+            fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
+                                             (swiglu_alpha, swiglu_limit), 2))
+    except opt_flags.InapplicableConstraint:
+        pytest.skip("inapplicable constraint")
     assert_close(a, b)
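The test_op changes replace the MicroscalingCtx plumbing with the new tensor/layout API. A short sketch of that path, based only on the calls visible in this diff; it assumes a CUDA build of triton_kernels that includes these changes, and the weight tensor w and its shape are illustrative:

    import torch
    from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
    from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4
    from triton_kernels.tensor_details import layout

    # Illustrative bfloat16 weight (experts x K x N); shape and values are made up.
    w = torch.randn(8, 256, 256, dtype=torch.bfloat16, device="cuda")

    # Quantize along K, keep a bfloat16 reference, then pick hardware layouts by
    # compute capability, mirroring the updated test_op path above.
    w_q, w_scale = downcast_to_mxfp(w, torch.uint8, axis=-2)
    w_ref = upcast_from_mxfp(w_q, w_scale, torch.bfloat16, axis=-2)

    major = torch.cuda.get_device_capability()[0]
    value_layout = {9: layout.HopperMXValueLayout}.get(major, layout.StridedLayout)
    scale_layout = {9: layout.HopperMXScaleLayout, 10: layout.BlackwellMXScaleLayout}.get(major, layout.StridedLayout)

    w_q = convert_layout(wrap_torch_tensor(w_q, FP4), value_layout)
    w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout)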
