
Commit 402d57c

Merge commit '607c50cc9fdd2541db88b5a8681164f081dd71ad'
2 parents 0c81300 + 607c50c commit 402d57c

File tree: 10 files changed (+102, −55 lines)


Makefile

Lines changed: 2 additions & 2 deletions
@@ -36,8 +36,8 @@ test-unit: all
 	$(PYTEST) -s -n 8 python/test/unit/test_debug.py --forked
 	$(PYTEST) -s -n 8 python/triton_kernels/tests/
 	TRITON_DISABLE_LINE_INFO=0 $(PYTEST) -s python/test/unit/language/test_line_info.py
-	# Run cuda/test_flashattention.py separately to avoid out of gpu memory
-	$(PYTEST) -s python/test/unit/cuda/test_flashattention.py
+	# Run attention separately to avoid out of gpu memory
+	TRITON_PRINT_AUTOTUNING=1 $(PYTEST) -vs python/tutorials/06-fused-attention.py
 	TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
 		$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
4343

python/test/unit/language/test_standard.py

Lines changed: 15 additions & 13 deletions
@@ -65,24 +65,26 @@ def sort_kernel(X, stride_xm, Z, stride_zm, M: tl.constexpr, N: tl.constexpr, k:
 
 
 @pytest.mark.interpreter
-@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
+@pytest.mark.parametrize("M, N, K", [[1, 16, 64], [8, 2, 256], [32, 1, 2], [128, 8, 1]])
 @pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
-def test_flip(M, N, dtype_str, device):
+@pytest.mark.parametrize("dim", [0, 1, 2, -2])
+def test_flip(M, N, K, dtype_str, dim, device):
 
     @triton.jit
-    def flip_kernel(X, Z, N: tl.constexpr, M: tl.constexpr):
-        offx = tl.arange(0, M)
-        offy = tl.arange(0, N) * M
-        off2d = offx[None, :] + offy[:, None]
-        x = tl.load(X + off2d)
-        x = tl.flip(x)
-        tl.store(Z + off2d, x)
-
-    x = numpy_random((N, M), dtype_str=dtype_str)
+    def flip_kernel(X, Z, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, dim: tl.constexpr):
+        offx = tl.arange(0, M) * N * K
+        offy = tl.arange(0, N) * K
+        offz = tl.arange(0, K)
+        off3d = offx[:, None, None] + offy[None, :, None] + offz[None, None, :]
+        x = tl.load(X + off3d)
+        x = tl.flip(x, dim)
+        tl.store(Z + off3d, x)
+
+    x = numpy_random((M, N, K), dtype_str=dtype_str)
     x = torch.from_numpy(x).to(device)
-    y = torch.flip(x, (1, ))
+    y = torch.flip(x, (dim, ))
     z = torch.empty_like(x, device=device)
-    flip_kernel[(1, )](x, z, N, M, num_warps=8)
+    flip_kernel[(1, )](x, z, M, N, K, dim, num_warps=8)
     assert (y == z).all(), (y, z)
 

python/test/unit/language/test_tuple.py

Lines changed: 35 additions & 0 deletions
@@ -162,3 +162,38 @@ def mul(x, a):
     ty = Tensor(y, y.shape, y.stride())
     _namedtuple_kernel[(1, )](function, tx, ty, 64, 64)
     assert torch.allclose(y, x[:16, :16] * a)
+
+
+@pytest.mark.interpreter
+def test_eq(device):
+
+    @triton.jit
+    def fn(ret_ptrs):
+        tl.store(ret_ptrs + 0, (1, 2) == (1, 2))
+        tl.store(ret_ptrs + 1, (1, 2) == (1, 1))
+        tl.store(ret_ptrs + 2, tl.tuple((1, 2)) == (1, 2))
+        tl.store(ret_ptrs + 3, tl.tuple((1, 2)) == (1, 3))
+
+    rets = torch.zeros((4, ), dtype=torch.int32, device=device)
+    fn[(1, )](rets)
+    assert rets[0].item() == 1
+    assert rets[1].item() == 0
+    assert rets[2].item() == 1
+    assert rets[3].item() == 0
+
+
+@pytest.mark.interpreter
+def test_add(device):
+
+    @triton.jit
+    def fn(ret_ptrs):
+        tuple0 = ((0, 1)) + (2, 3)
+        for i in tl.static_range(4):
+            tl.store(ret_ptrs + i, tuple0[i])
+        tuple1 = tl.tuple((4, 5)) + (6, 7)
+        for i in tl.static_range(4):
+            tl.store(ret_ptrs + 4 + i, tuple1[i])
+
+    rets = torch.zeros((8, ), dtype=torch.int32, device=device)
+    fn[(1, )](rets)
+    torch.testing.assert_close(rets.cpu(), torch.arange(8, dtype=torch.int32))

python/triton/language/core.py

Lines changed: 10 additions & 9 deletions
@@ -306,6 +306,13 @@ def _unwrap_if_constexpr(o):
     return o.value if isinstance(o, constexpr) else o
 
 
+def _normalize_tuple(t):
+    normalized_tuple = _unwrap_if_constexpr(t)
+    if isinstance(normalized_tuple, (list, builtins.tuple)):
+        normalized_tuple = tuple(normalized_tuple)
+    return normalized_tuple
+
+
 def check_bit_width(value, shift_value):
     if isinstance(value, tensor) and isinstance(shift_value, constexpr):
         bitwidth = value.type.scalar.primitive_bitwidth
@@ -1069,7 +1076,6 @@ def __not__(self, _builder=None):
 
     @builtin
     def __getitem__(self, slices, _builder=None):
-        import builtins
         if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None:
             slices = [slices]
         if isinstance(slices, tuple):
@@ -1237,7 +1243,7 @@ def flip(self, dim=None) -> tensor:
 
 class tuple(base_value):
 
-    def __init__(self, args: list, type: tuple_type = None):
+    def __init__(self, args: Sequence, type: tuple_type = None):
         self.values = [i for i in args]
 
         def get_type(x):
@@ -1255,7 +1261,6 @@ def __getitem__(self, idx: constexpr):
         if isinstance(idx, constexpr):
            return self.values[idx]
        else:
-            import builtins
            assert isinstance(idx, (slice, builtins.slice))
            return tuple(self.values[idx.start:idx.stop:idx.step])
 
@@ -1270,8 +1275,7 @@ def __setitem__(self, idx: constexpr, value):
         self.values[idx] = value
 
     def __add__(self, other):
-        if isinstance(other, list):
-            other = tuple(other)
+        other = _normalize_tuple(other)
         return tuple(self.values + other.values)
         # return tuple(a + b for a, b in zip(self.values, other.values))
 
@@ -1280,13 +1284,10 @@ def __mul__(self, other):
         return tuple(self.values * other.value)
 
     def __eq__(self, other):
-        import builtins
-        if isinstance(other, (list, builtins.tuple)):
-            other = tuple(other)
+        other = _normalize_tuple(other)
         return constexpr(self.values == other.values)
 
     def __hash__(self):
-        import builtins
         return hash(builtins.tuple(self.values))
 
     def __str__(self):
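
Note: as a standalone illustration of the pattern `_normalize_tuple` factors out of `__add__` and `__eq__` (unwrap a constexpr-style wrapper, then coerce lists and built-in tuples into the wrapper's tuple type so plain Python sequences can appear on the right-hand side), here is a minimal plain-Python sketch. `Constexpr` and `MyTuple` are illustrative stand-ins, not the Triton classes:

import operator

class Constexpr:
    # stand-in for triton.language.core.constexpr: a value wrapped for compile time
    def __init__(self, value):
        self.value = value

class MyTuple:
    # stand-in for the tl.tuple wrapper around a list of values
    def __init__(self, args):
        self.values = list(args)

    @staticmethod
    def _normalize(other):
        # unwrap the constexpr-style wrapper, then coerce list/tuple to MyTuple
        other = other.value if isinstance(other, Constexpr) else other
        if isinstance(other, (list, tuple)):
            other = MyTuple(other)
        return other

    def __add__(self, other):
        return MyTuple(self.values + MyTuple._normalize(other).values)

    def __eq__(self, other):
        return operator.eq(self.values, MyTuple._normalize(other).values)

assert (MyTuple((1, 2)) + [3, 4]) == (1, 2, 3, 4)
assert MyTuple((1, 2)) == Constexpr((1, 2))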

python/triton/language/standard.py

Lines changed: 11 additions & 11 deletions
@@ -475,7 +475,8 @@ def _get_flip_dim(dim, shape):
     shape = core._unwrap_if_constexpr(shape)
     if dim is None:
         dim = len(shape) - 1
-    assert dim == len(shape) - 1, "Currently only support flipping the last dimension"
+    if dim < 0:  # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
+        dim += len(shape)
     return core.constexpr(dim)
 
 
@@ -487,20 +488,19 @@ def flip(x, dim=None):
 
     :param x: the first input tensor
     :type x: Block
-    :param dim: the dimension to flip along (currently only final dimension supported)
+    :param dim: the dimension to flip along
     :type dim: int
     """
-    core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)]))
-    core.static_assert(_is_power_of_two(x.numel))
-    # reshape the tensor to have all dimensions be 2.
-    # TODO: We shouldn't have to change the dimensions not sorted.
-    steps: core.constexpr = _log2(x.numel)
-    start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)])
+    core.static_assert(-len(x.shape) <= dim and dim < len(x.shape))
+    _dim: core.constexpr = _get_flip_dim(dim, x.shape)
+    core.static_assert(_is_power_of_two(x.shape[_dim]))
+    steps: core.constexpr = _log2(x.shape[_dim])
 
+    # reshape the swap dimension to (2, 2, ..., 2)
     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
-    y = core.reshape(x.to(idtype, bitcast=True), [2] * steps)
-    for i in core.static_range(start, steps):
-        y = y ^ xor_sum(y, i, True)
+    y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + [2] * steps + x.shape[_dim + 1:])
+    for i in core.static_range(steps):
+        y = y ^ xor_sum(y, _dim + i, True)
     x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
     return x
 
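
Note: the generalized flip still relies on the xor-swap reduction: reshaping the flipped dimension of length 2**steps into `steps` axes of size 2 and xor-ing every element with the xor-reduction along each of those axes complements each bit of the index along that dimension, which reverses it. A minimal NumPy sketch of the idea (flipping only the last dimension, integer data, power-of-two length assumed; an illustration, not the Triton code):

import numpy as np

def flip_pow2(x: np.ndarray) -> np.ndarray:
    # Flip the last dimension of an integer array whose length is a power of two,
    # using the same xor-swap reduction as triton.language.standard.flip.
    n = x.shape[-1]
    steps = n.bit_length() - 1
    assert 1 << steps == n, "flipped dimension must have power-of-two length"
    # view the last dimension as `steps` axes of size 2
    y = x.reshape(x.shape[:-1] + (2,) * steps)
    for axis in range(x.ndim - 1, x.ndim - 1 + steps):
        # xor with the reduction over a size-2 axis swaps the two entries:
        # a ^ (a ^ b) == b and b ^ (a ^ b) == a
        y = y ^ np.bitwise_xor.reduce(y, axis=axis, keepdims=True)
    return y.reshape(x.shape)

x = np.arange(16, dtype=np.int32).reshape(2, 8)
assert (flip_pow2(x) == x[:, ::-1]).all()

Floating-point tensors go through the same path after a bitcast to an integer type of equal width, which is what the idtype/bitcast lines in the diff above do.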

python/triton_kernels/bench/bench_mlp.py

Lines changed: 12 additions & 9 deletions
@@ -38,7 +38,7 @@ def _query_gpu_specs():
 
 gpu_specs = {
     "NVIDIA H100 80GB HBM3": {"MAX_TFLOPS8": 1979, "MAX_TFLOPS16": 989, "MAX_TBPS": 3.35},
-    "HGX GB200": {"MAX_TFLOPS8": 4500, "MAX_TFLOPS16": 2250, "MAX_TBPS": 8.0},
+    "NVIDIA GB200": {"MAX_TFLOPS8": 4500, "MAX_TFLOPS16": 2250, "MAX_TBPS": 8.0},
     "AMD Instinct MI300X": {"MAX_TFLOPS8": 2615, "MAX_TFLOPS16": 1307, "MAX_TBPS": 5.3},
     "AMD Instinct MI325X": {"MAX_TFLOPS8": 2615, "MAX_TFLOPS16": 1307, "MAX_TBPS": 6.0},
 }
@@ -175,7 +175,8 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
     batches = list(chain(*[range(*r) for r in batch_ranges]))
     # collect performance data
     perfs = []
-    print(f"Benchmarking {name} ({x_dtype}x{w_dtype}, TP={TP}, EP={EP})...")
+    bench_case = f"{name} ({x_dtype}x{w_dtype}, TP={TP}, EP={EP})"
+    print(f"Benchmarking {bench_case}...")
     print("===============================================================")
     for batch in batches:
         perfs += [bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name)]
@@ -186,7 +187,7 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
     fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
     ax.set_xlabel("batch size (toks/expt)")
     ax.set_ylabel("performance [TFLOP/s]")
-    ax.set_title("roofline")
+    ax.set_title(f"{bench_case} roofline")
     # add a tiny margin so points are not flush with the frame
     xs = [batch * n_expts_act / n_expts_tot for batch in batches]
     perf = [p.tflops for p in perfs]
@@ -200,7 +201,8 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
     opints = [p.opint for p in perfs]
     knee = bisect_left(opints, max_tflops / max_tbps) - 1
     x_bw, x_comp = xs[:knee], xs[knee:]
-    y_bw = [op * max_tbps for op in opints[:knee]]
+    x_bw = [x_bw[0], x_comp[0]]
+    y_bw = [opints[0] * max_tbps, max_tflops]
     y_comp = [max_tflops] * len(x_comp)
     ax.plot(x_bw, y_bw, "--", label=f"BW-bound ({max_tbps:.0f} TB/s)")
     ax.plot(x_comp, y_comp, "--", label=f"Compute-bound ({max_tflops:.0f} TFLOP/s)")
@@ -217,10 +219,11 @@ def roofline_mlp(batch_ranges, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_
     has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
     if SPECS is None:
         print("Current GPU has no specs provided, utilization is N/A")
-    batch_ranges = [(1024, 32768, 1024)]
+    batch_ranges_dense = [(1024, 32768, 1024)]
+    batch_ranges_moe = [(128, 512, 32), (512, 32000, 128)]
     dense_dtypes = ["fp8", "fp8"]
     quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]
-    roofline_mlp(batch_ranges, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
-    roofline_mlp(batch_ranges, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
-    roofline_mlp(batch_ranges, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
-    roofline_mlp(batch_ranges, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")
+    roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *dense_dtypes, TP=1, EP=1, name="dense")
+    roofline_mlp(batch_ranges_dense, 8192, 8192, 1, 1, *quantized_dtypes, TP=1, EP=1, name="dense")
+    roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *dense_dtypes, TP=1, EP=1, name="llama4-maverick")
+    roofline_mlp(batch_ranges_moe, 5120, 8192, 128, 4, *quantized_dtypes, TP=1, EP=1, name="llama4-maverick")
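
Note: for reference, the roofline logic the plot encodes: a data point is bandwidth-bound while its operational intensity (FLOP/byte) is below max_tflops / max_tbps and compute-bound above it, and the attainable throughput is min(peak compute, intensity x peak bandwidth). A small sketch with made-up intensities (the spec numbers are the H100 entries from gpu_specs above; illustrative only, not a measurement):

from bisect import bisect_left

max_tflops, max_tbps = 1979.0, 3.35        # H100 fp8 peak compute / peak bandwidth
opints = [100, 300, 500, 700, 900]         # hypothetical operational intensities, sorted
knee = bisect_left(opints, max_tflops / max_tbps) - 1        # last bandwidth-bound point
attainable = [min(max_tflops, oi * max_tbps) for oi in opints]
print(knee, [round(t) for t in attainable])  # 2 [335, 1005, 1675, 1979, 1979]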

python/triton_kernels/triton_kernels/routing.py

Lines changed: 0 additions & 1 deletion
@@ -62,7 +62,6 @@ def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
     HIST_BLOCK_M = 64
     INDX_OFFS_BLOCK_M = 512
     MEMSET_BLOCK = 1024
-    assert logits.dtype.itemsize == 2
     n_tokens, n_expts_tot = logits.shape
     n_gates = n_tokens * n_expts_act
     device = logits.device

python/triton_kernels/triton_kernels/topk.py

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@ def topk(x, k, dim=1, return_bitmatrix=True):
     cdiv = lambda a, b: (a + b - 1) // b
     BLOCK_M = 8
     BLOCK_N = 128
-    assert x.dtype.itemsize == 2
     assert x.ndim == 2
     assert x.shape[-1] < 32768
     assert dim == 1

python/triton_kernels/triton_kernels/topk_details/_topk.py

Lines changed: 15 additions & 8 deletions
@@ -5,6 +5,10 @@
 @triton.jit
 def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
                    BLOCK_N: tl.constexpr):
+    x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth
+    x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}")
+    x_ultype: tl.constexpr = tl.dtype(f"uint{2*x_nbits}")
+    x_dbtype: tl.constexpr = tl.dtype(f"fp{2*x_nbits}")
 
     # subtract 1 from loop iterations because we peel the first (masked) iteration:
     loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1
@@ -15,8 +19,8 @@ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.co
     # first iteration:
     X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :]
     x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf"))
-    x = (x.to(tl.uint16, bitcast=True).to(tl.int32) << 16) | offs_x_n[None, :]
-    x = x.to(tl.float32, bitcast=True)
+    x = (x.to(x_utype, bitcast=True).to(x_ultype) << x_nbits) | offs_x_n[None, :]
+    x = x.to(x_dbtype, bitcast=True)
 
     acc = tl.topk(x, N_EXPTS_ACT, dim=1)
 
@@ -26,8 +30,8 @@ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.co
         X_ptrs -= BLOCK_N
         offs_x_n -= BLOCK_N
         x = tl.load(X_ptrs, mask=mask_m, other=float("-inf"))
-        x = (x.to(tl.uint16, bitcast=True).to(tl.int32) << 16) | offs_x_n[None, :]
-        x = x.to(tl.float32, bitcast=True)
+        x = (x.to(x_utype, bitcast=True).to(x_ultype) << x_nbits) | offs_x_n[None, :]
+        x = x.to(x_dbtype, bitcast=True)
         acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1))
 
     return acc
@@ -43,18 +47,21 @@ def _topk(X, stride_xm, # inputs
     tl.static_assert(BLOCK_N % 32 == 0)
     tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0)
     x_dtype: tl.constexpr = X.dtype.element_ty
+    x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth
+    x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}")
+    x_ultype: tl.constexpr = tl.dtype(f"uint{2*x_nbits}")
 
     # load logits
     offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
     mask_m = offs_m[:, None] < n_rows
     y = streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD, N_EXPTS_ACT, BLOCK_N)
-    y = y.to(tl.uint32, bitcast=True)
+    y = y.to(x_ultype, bitcast=True)
 
     # sort result in direction of ascending expert index
-    y = (y << 16) | (y >> 16)
+    y = (y << x_nbits) | (y >> x_nbits)
     y = tl.sort(y, dim=1)
-    y_indices = y >> 16
-    y_values = (y & 0x0000FFFF).to(tl.uint16).to(x_dtype, bitcast=True)
+    y_indices = y >> x_nbits
+    y_values = (y & ((1 << x_nbits) - 1)).to(x_utype).to(x_dtype, bitcast=True)
     y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype)
 
     # write back
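
Note: the hunks above generalize a pack/sort/unpack trick from hard-coded 16-bit logits to any element width: each value is bitcast to an unsigned integer of the same width, shifted into the high half of a double-width word, and OR-ed with its column index, so selecting the top-k packed keys orders by value while carrying the index along. A NumPy sketch of the packing step, assuming non-negative values so the unsigned bit pattern sorts in the same order as the floats (illustrative only, not the kernel):

import numpy as np

vals = np.array([0.5, 2.0, 1.25, 0.75], dtype=np.float16)      # hypothetical logits
idx = np.arange(vals.size, dtype=np.uint32)
packed = (vals.view(np.uint16).astype(np.uint32) << 16) | idx  # value in high bits, index in low bits
order = np.argsort(packed)[::-1][:2]                           # top-2 by packed key
top_idx = packed[order] & 0xFFFF                               # -> [1, 2]
top_vals = (packed[order] >> 16).astype(np.uint16).view(np.float16)  # -> [2.0, 1.25]

The kernel's later `(y << x_nbits) | (y >> x_nbits)` swap then moves the index into the high half so that `tl.sort` reorders the selected experts by ascending index before unpacking.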

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
@@ -428,8 +428,9 @@ struct FpToFpOpConversion
          Fp16_to_Fp8E5M2_RTNE(computeCapability >= 89)},
         {{F16TyID, F8E5M2TyID, RoundingMode::RTZ}, Fp16_to_Fp8E5M2_RTZ},
         // F8 -> BF16
+        // mul{.rnd}.bf16 and mul{.rnd}.bf16x2 requires sm_90 or higher.
         {{F8E5M2TyID, BF16TyID, undefRounding},
-         Fp8E5M2_to_Bf16(computeCapability >= 89)},
+         Fp8E5M2_to_Bf16(computeCapability >= 90)},
         {{F8E4M3TyID, BF16TyID, undefRounding},
          Fp8E4M3Nv_to_Bf16(computeCapability >= 89)},
         // BF16 -> F8

0 commit comments
