
Commit 2892882

anmyachev, yongjik, pawelszczerbuk, njriasan, and yiqian1 authored
Merge OpenAI Triton commit 272188c (#4612)
This PR changes the Triton base from e21efcb to 272188c (Jun 26). Pass rate: 97.14%

Please do not squash and merge this PR.

---------

Co-authored-by: Yongjik Kim <[email protected]>
Co-authored-by: pawelszczerbuk <[email protected]>
Co-authored-by: Nick Riasanovsky <[email protected]>
Co-authored-by: Yi Qian <[email protected]>
Co-authored-by: Jeff Niu <[email protected]>
2 parents c017cf7 + 46b6ede commit 2892882

File tree

18 files changed: +262, -30 lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ namespace mlir::triton {
 inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     // clang-format off
     "AMDGCN_ENABLE_DUMP",
+    "AMDGCN_USE_BUFFER_ATOMICS",
     "AMDGCN_USE_BUFFER_OPS",
     "DISABLE_LLVM_OPT",
     "DISABLE_MMA_V3",

python/src/gluon_ir.cc

Lines changed: 4 additions & 1 deletion
@@ -391,7 +391,10 @@ void init_gluon_ir(py::module &&m) {
             return self.create<ttng::WarpGroupDotOp>(
                 a, b, acc, useAcc, precision, maxNumImpreciseAcc, isAsync);
           })
-
+      .def("create_warpgroup_mma_wait",
+           [](GluonOpBuilder &self, std::vector<Value> &deps, int pendings) {
+             self.create<ttng::WarpGroupDotWaitOp>(deps, pendings);
+           })
       .def("create_tmem_alloc",
            [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
              return self.create<ttng::TMEMAllocOp>(resultTy, value);

python/test/gluon/test_core.py

Lines changed: 8 additions & 4 deletions
@@ -100,7 +100,7 @@ def test_async_copy_mbarrier():


 @gluon.jit
-def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr):
+def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, ASYNC: ttgl.constexpr):
     block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
     mma_layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1],
                                                              instr_shape=[16, 32, 16])
@@ -121,19 +121,23 @@ def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttg
     a_shmem = ttgl.allocate_shared_memory(ttgl.float16, [M, K], nvmma_layout, A)
     b_shmem = ttgl.allocate_shared_memory(ttgl.float16, [K, N], nvmma_layout, B)

-    acc = hopper.warpgroup_mma(a_shmem, b_shmem, acc)
+    acc = hopper.warpgroup_mma(a_shmem, b_shmem, acc, is_async=ASYNC)
+
+    if ASYNC:
+        hopper.warpgroup_mma_wait(num_outstanding=1, deps=[acc])

     ttgl.store(out + out_offs_m * N + out_offs_n, acc)


 @pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
-def test_warpgroup_mma():
+@pytest.mark.parametrize("ASYNC", [True, False])
+def test_warpgroup_mma(ASYNC):
     torch.manual_seed(0)
     M, N, K = 64, 32, 32
     a = torch.randn((M, K), device="cuda", dtype=torch.float16)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16)
     out = torch.zeros((M, N), device="cuda", dtype=torch.float16)
-    warpgroup_mma_kernel[(1, )](a, b, out, M, N, K)
+    warpgroup_mma_kernel[(1, )](a, b, out, M, N, K, ASYNC)

     ref = torch.matmul(a, b)

python/test/gluon/test_frontend.py

Lines changed: 27 additions & 0 deletions
@@ -482,6 +482,33 @@ def test_warpgroup_mma(fresh_knobs):
 """)


+@gluon.jit
+def warpgroup_mma_wait_kernel():
+    layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16])
+    acc = ttgl.full([128, 128], 0, dtype=ttgl.float16, layout=layout)
+    hopper.warpgroup_mma_wait(num_outstanding=1, deps=[acc])
+
+
+@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper WGMMA")
+def test_warpgroup_mma_wait(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    h = warpgroup_mma_wait_kernel.warmup(grid=(1, ))
+    expecttest.assert_expected_inline(
+        anonymize_ir(h.asm["source"]), """\
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @warpgroup_mma_wait_kernel() attributes {noinline = false} {
+    %cst = arith.constant 0.000000e+00 : f16 loc(#loc)
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #mma> loc(#loc)
+    %0 = ttng.warp_group_dot_wait %cst_0 {pendings = 1 : i32} : tensor<128x128xf16, #mma> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc = loc(unknown)
+""")
+
+
 @gluon.jit
 def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr):
     smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)

python/test/unit/language/test_frontend.py

Lines changed: 21 additions & 0 deletions
@@ -374,6 +374,7 @@ def test_constexpr_generator():
     generator(lhs)


+@tl.constexpr_function
 def Box(T):

     @tl.core._aggregate
@@ -401,3 +402,23 @@ def kernel():
         anchor(value)

     run_filecheck_test(kernel)
+
+
+@filecheck_test
+@triton.jit
+def test_modify_if_livein():
+    # CHECK-LABEL: test_modify_if_livein
+    none_livein = None  # noqa: F841
+
+    # CHECK: [[LOOP_OUT:%.*]] = scf.for {{.*}} iter_args([[BOX:%.*]] = %true)
+    # CHECK: [[LIVEOUT:%.*]] = scf.if [[BOX]]
+    # CHECK: yield %false
+    # CHECK: else
+    # CHECK: yield [[BOX]]
+    # CHECK: yield [[LIVEOUT]]
+    # CHECK: call @{{.*}}anchor{{.*}}([[LOOP_OUT]])
+    box = Box(tl.tensor)(tl.core.to_tensor(True))
+    for i in range(10):
+        if box.value:
+            box.value = False
+    anchor(box.value)

python/triton/compiler/code_generator.py

Lines changed: 19 additions & 14 deletions
@@ -719,35 +719,40 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block):
         self.visit_compound_statement(node.body)
         then_block = self.builder.get_insertion_block()
         then_defs = self.local_defs.copy()
+        then_vals = self.lscope.copy()
         # else block
         else_defs = {}
+        else_vals = liveins.copy()
         if node.orelse:
             self.builder.set_insertion_point_to_start(else_block)
             self.lscope = liveins.copy()
             self.local_defs = {}
             self.visit_compound_statement(node.orelse)
             else_defs = self.local_defs.copy()
             else_block = self.builder.get_insertion_block()
+            else_vals = self.lscope.copy()

         # update block arguments
         names = []
         # variables in livein whose value is updated in `if`
-        for name in liveins:
+        for name, value in liveins.items():
+            # livein variable changed value in either then or else
+            if not _is_triton_value(value):
+                continue
+            then_handles = flatten_values_to_ir([then_vals[name]])
+            else_handles = flatten_values_to_ir([else_vals[name]])
+            if then_handles == else_handles:
+                continue
+            names.append(name)
+            then_defs[name] = then_vals[name]
+            else_defs[name] = else_vals[name]
             # check type
             for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
-                if name in defs:
-                    type_equal = type(defs[name]) == type(liveins[name])  # noqa: E721
-                    assert type_equal and defs[name].type == liveins[name].type, \
-                        f'initial value for `{name}` is of type {liveins[name]}, '\
-                        f'but the {block_name} block redefines it as {defs[name]}'
-            if name in then_defs or name in else_defs:
-                names.append(name)
-            # variable defined in then but not in else
-            if name in then_defs and name not in else_defs:
-                else_defs[name] = liveins[name]
-            # variable defined in else but not in then
-            if name in else_defs and name not in then_defs:
-                then_defs[name] = liveins[name]
+                type_equal = type(defs[name]) == type(value)  # noqa: E721
+                assert type_equal and defs[name].type == value.type, \
+                    f'initial value for `{name}` is of type {value}, '\
+                    f'but the {block_name} block redefines it as {defs[name]}'
+
         # variables that are both in then and else but not in liveins
         # TODO: could probably be cleaned up
         for name in sorted(then_defs.keys() & else_defs.keys()):
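
The key behavioral change: a live-in is now promoted to an `scf.if` result whenever its underlying IR handles differ between the then and else branches, which also catches values mutated through an aggregate field (as exercised by `test_modify_if_livein` above), rather than only names that were explicitly reassigned. Below is a hypothetical, stripped-down model of that selection rule, not the real CodeGenerator API; `flatten` stands in for `flatten_values_to_ir` and the type check is omitted.

# Hypothetical, simplified model of the new promotion rule in
# visit_then_else_blocks (not the real CodeGenerator API).
def liveins_to_promote(liveins, then_vals, else_vals, flatten):
    names = []
    for name, value in liveins.items():
        # Promote iff the branch bodies left different IR handles behind,
        # regardless of whether `name` itself was reassigned.
        if flatten(then_vals[name]) != flatten(else_vals[name]):
            names.append(name)
    return names

# Toy usage: `box` was mutated in the then-branch, `x` was left untouched.
print(liveins_to_promote(
    liveins={"box": "h0", "x": "h1"},
    then_vals={"box": "h2", "x": "h1"},
    else_vals={"box": "h0", "x": "h1"},
    flatten=lambda v: [v],
))  # -> ['box']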

python/triton/experimental/gluon/language/nvidia/hopper/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -2,7 +2,7 @@
 from . import mbarrier, tma
 from ... import _core

-__all__ = ["async_copy", "fence_async_shared", "mbarrier", "tma", "warpgroup_mma"]
+__all__ = ["async_copy", "fence_async_shared", "mbarrier", "tma", "warpgroup_mma", "warpgroup_mma_wait"]


 @_core.builtin
@@ -25,3 +25,10 @@ def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_
     handle = _semantic.builder.create_warpgroup_mma(a.handle, b.handle, acc.handle, use_acc.handle, precision,
                                                     max_num_imprecise_acc, is_async)
     return _core.tensor(handle, acc.type)
+
+
+@_core.builtin
+def warpgroup_mma_wait(num_outstanding=0, deps=None, _semantic=None):
+    deps = [x.handle for x in deps] if deps is not None else []
+    num_outstanding = _core._unwrap_if_constexpr(num_outstanding)
+    _semantic.builder.create_warpgroup_mma_wait(deps, num_outstanding)
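
Together with the `is_async=` flag on `warpgroup_mma` (see the test_core.py change above), this gives Gluon kernels an explicit issue/wait split: launch the WGMMA asynchronously, then call `warpgroup_mma_wait` with the number of MMAs allowed to remain in flight, which lowers to the `pendings` attribute of `ttng.warp_group_dot_wait`. A minimal compile-only sketch follows; the import paths are assumed from this PR's file layout, and it mirrors the frontend test above rather than running a real matmul.

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia import hopper


@gluon.jit
def wait_two_outstanding_kernel():
    layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1],
                                                         instr_shape=[16, 32, 16])
    acc = ttgl.full([128, 128], 0, dtype=ttgl.float16, layout=layout)
    # Block until at most 2 asynchronous WGMMAs producing `acc` are still in flight.
    hopper.warpgroup_mma_wait(num_outstanding=2, deps=[acc])


# Compile-only check on a Hopper target; the wait shows up as
# `ttng.warp_group_dot_wait ... {pendings = 2 : i32}` in the generated IR.
h = wait_two_outstanding_kernel.warmup(grid=(1, ))
print(h.asm["source"])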

python/triton/knobs.py

Lines changed: 2 additions & 0 deletions
@@ -510,6 +510,8 @@ class intel_knobs(base_knobs):

 class amd_knobs(base_knobs):
     use_buffer_ops: env_bool = env_bool("AMDGCN_USE_BUFFER_OPS", True)
+    # Note: This requires use_buffer_ops be true to have any effect
+    use_buffer_atomics: env_bool = env_bool("AMDGCN_USE_BUFFER_ATOMICS", True)
     dump_amdgcn: env_bool = env_bool("AMDGCN_ENABLE_DUMP")
     libhip_path: env_opt_str = env_opt_str("TRITON_LIBHIP_PATH")
     lld_path: env_opt_str = env_opt_str("TRITON_HIP_LLD_PATH")
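
For reference, a small sketch of toggling the new knob; the `triton.knobs.amd` accessor is an assumption based on the `amd_knobs` class above, and per the in-source note the setting only matters while `AMDGCN_USE_BUFFER_OPS` is enabled.

import os

# Disable buffer atomics before Triton reads the knob (environment route).
os.environ["AMDGCN_USE_BUFFER_ATOMICS"] = "0"

import triton

# The knob can also be inspected (and, like other knobs, overridden) in code;
# `triton.knobs.amd` is assumed from the class name in the diff above.
print(triton.knobs.amd.use_buffer_atomics)  # expected: False
triton.knobs.amd.use_buffer_atomics = True  # opt back in programmatically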

python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py

Lines changed: 28 additions & 4 deletions
@@ -78,17 +78,37 @@ def convert_dtype(dtype):


 def matmul_launch_metadata(grid, kernel, args):
+    from ..proton_opts import launch_metadata_allow_sync
+
     ret = dict()
     M, N, K = args["M"], args["N"], args["K"]
     Y, X, W = [t.base if isinstance(t, TensorDescriptor) else t for t in [args["Y"], args["X"], args["W"]]]
+    tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
     hist = args["ExptHist"]
     if hist is not None:
-        n_tokens = float(hist.sum())
-        n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum()
+        # If annotation is given, use that to generate name for profiling.
+        if tokens_per_expt is not None:
+            n_rows = f"{tokens_per_expt}*"
+        elif launch_metadata_allow_sync():
+            n_rows = int(hist.float().mean())
+        else:
+            n_rows = "unknown"
+
+        if launch_metadata_allow_sync():
+            n_tokens = float(hist.sum())
+            n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum()
+        elif tokens_per_expt is not None:
+            n_tokens = tokens_per_expt * args["N_EXPTS_TOT"]
+            # This may not be totally correct (e.g., we might not be using all experts)
+            # but it's better than nothing.
+            n_w_bytes = W.numel() * W.element_size()
+        else:
+            n_tokens = None
+            n_w_bytes = 0

         # If annotation is given, use that to generate name for profiling.
         tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
-        n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else int(hist.float().mean())
+        n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else n_rows
     else:
         n_tokens = None
         n_w_bytes = W.numel() * W.element_size()
@@ -101,6 +121,10 @@ def matmul_launch_metadata(grid, kernel, args):
     ep_subtile = args["EPILOGUE_SUBTILE"]
     if ep_subtile is not None and ep_subtile > 1:
         ret["name"] += f" ep/{ep_subtile}"
+
+    if hist is not None and n_tokens is None:
+        return ret  # Don't fill metadata because we can't compute them properly.
+
     fM = M if M is not None else n_tokens
     fK = K if K is not None else n_tokens
     ret[f"flops{nbits}"] = 2.0 * fM * N * fK
@@ -115,7 +139,7 @@ def matmul_launch_metadata(grid, kernel, args):
     assert n_tokens is not None
     n_expts_act = args["N_EXPTS_ACT"]

-    if gindx is not None:
+    if (gindx is not None) and launch_metadata_allow_sync():
         # recreate inverse GatherIndx.
         dst = torch.full_like(gindx, -1)
         idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)
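
The gist of this change: when Proton disallows synchronizing on the device-side expert histogram, the token count is estimated from the `TOKENS_PER_EXPT_FOR_ANNOTATION` hint instead of `hist.sum()`, and when neither source is available the function returns early with only the kernel name filled in. A hypothetical, standalone model of that fallback arithmetic (not the real `matmul_launch_metadata` signature):

# Simplified, self-contained model of the new n_tokens selection.
def estimate_tokens(allow_sync, hist_sum, tokens_per_expt, n_expts_tot):
    if allow_sync:
        return float(hist_sum)                 # exact, but synchronizes with the GPU
    if tokens_per_expt is not None:
        return tokens_per_expt * n_expts_tot   # sync-free estimate from the annotation
    return None                                # caller skips flops/bytes metadata


print(estimate_tokens(False, hist_sum=4096, tokens_per_expt=128, n_expts_tot=8))  # -> 1024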

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,8 @@ def _zero_masked_rows(


 _matmul_ogs_repr = make_matmul_repr("_matmul_ogs", [0, 1, 2])
-@triton.jit(repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
+@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
+            repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
 def _matmul_ogs(
              Y, Out, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
              YExpectedScale, YActualScale, YChecksumScale,
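
Marking the annotation argument `do_not_specialize` keeps its runtime value out of Triton's value-specialization (e.g. equal-to-1 or divisibility hints), so changing the hint does not force a recompile of `_matmul_ogs`. A hypothetical minimal example of the same pattern; the kernel and argument names here are made up.

import torch
import triton
import triton.language as tl


@triton.jit(do_not_specialize=["annotation_hint"])
def _annotated_copy(x_ptr, y_ptr, n, annotation_hint, BLOCK: tl.constexpr):
    # `annotation_hint` is only meant for launch-metadata/repr hooks, so its
    # value should not influence code generation.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)


x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)
# Different hint values (e.g. 1 vs. 128) reuse the same compiled binary
# because the argument is excluded from value specialization.
_annotated_copy[(4, )](x, y, x.numel(), 1, BLOCK=256)
_annotated_copy[(4, )](x, y, x.numel(), 128, BLOCK=256)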
