Skip to content

Commit d0d6a4b

Browse files
authored
[Gluon] Expose ttg.warp_specialize (#6989)
📚 Stack PRs 📚 1. triton-lang/triton#6988 2. ➡️ triton-lang/triton#6989 This PR adds a function to expose `ttg.warp_specialize` as a list of functions, where the default function is allowed to return values that are passed through the default region. It also moves the filecheck testing to a common package to be used across various unit tests.
1 parent 2bd6c6f commit d0d6a4b

File tree

11 files changed

+300
-107
lines changed

11 files changed

+300
-107
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,9 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [
468468
let builders = [
469469
OpBuilder<(ins "TypeRange":$resultTypes,
470470
"ArrayRef<int32_t>":$partitionNumWarps,
471-
"unsigned":$numPartitionRegions)>
471+
"unsigned":$numPartitionRegions)>,
472+
OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$explicitCaptures,
473+
"ArrayRef<int32_t>":$partitionNumWarps)>,
472474
];
473475

474476
let hasVerifier = 1;

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,13 @@ void WarpSpecializeOp::build(OpBuilder &builder, OperationState &state,
886886
partitionNumRegions);
887887
}
888888

889+
void WarpSpecializeOp::build(OpBuilder &builder, OperationState &state,
890+
TypeRange resultTypes, ValueRange explicitCaptures,
891+
ArrayRef<int32_t> partitionNumWarps) {
892+
build(builder, state, resultTypes, explicitCaptures, partitionNumWarps, {},
893+
{}, {});
894+
}
895+
889896
ParseResult WarpSpecializeOp::parse(OpAsmParser &p, OperationState &result) {
890897
SmallVector<OpAsmParser::UnresolvedOperand> operands;
891898
SMLoc operandLoc = p.getCurrentLocation();

python/src/gluon_ir.cc

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ namespace ttg = triton::gpu;
1515
struct GluonOpBuilder : public TritonOpBuilder {};
1616

1717
void init_gluon_ir(py::module &&m) {
18+
using ret = py::return_value_policy;
19+
1820
py::class_<GluonOpBuilder, TritonOpBuilder>(
1921
m, "GluonOpBuilder", py::module_local(), py::dynamic_attr())
2022
.def(py::init<MLIRContext *>())
@@ -82,5 +84,36 @@ void init_gluon_ir(py::module &&m) {
8284
.def("create_local_load",
8385
[](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value {
8486
return self.create<ttg::LocalLoadOp>(resultTy, memDesc);
85-
});
87+
})
88+
89+
.def("create_warp_return",
90+
[](GluonOpBuilder &self) -> Operation * {
91+
return self.create<ttg::WarpReturnOp>();
92+
})
93+
.def("create_warp_yield",
94+
[](GluonOpBuilder &self, std::vector<Value> &values) -> Operation * {
95+
return self.create<ttg::WarpYieldOp>(values);
96+
})
97+
.def("create_warp_specialize_partitions",
98+
[](GluonOpBuilder &self, int numPartitions) -> Operation * {
99+
return self.create<ttg::WarpSpecializePartitionsOp>(numPartitions);
100+
})
101+
.def("create_warp_specialize", [](GluonOpBuilder &self,
102+
std::vector<Type> &resultTypes,
103+
std::vector<Value> &explicitCaptures,
104+
std::vector<int> &partitionNumWarps) {
105+
return self.create<ttg::WarpSpecializeOp>(resultTypes, explicitCaptures,
106+
partitionNumWarps);
107+
});
108+
109+
py::class_<ttg::WarpSpecializeOp, OpState>(m, "WarpSpecializeOp",
110+
py::module_local())
111+
.def("get_default_region", &ttg::WarpSpecializeOp::getDefaultRegion,
112+
ret::reference)
113+
.def("get_partition_op_holder",
114+
&ttg::WarpSpecializeOp::getPartitionOpHolder, ret::reference)
115+
.def("set_requested_registers", [](ttg::WarpSpecializeOp &self,
116+
std::vector<int> &requestedRegisters) {
117+
self.setRequestedRegisters(requestedRegisters);
118+
});
86119
}

python/src/ir.cc

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,11 @@ void init_triton_ir(py::module &&m) {
382382
.def("get_parent_region", &Region::getParentRegion, ret::reference)
383383
.def("size", [](Region &self) { return self.getBlocks().size(); })
384384
.def("empty", &Region::empty)
385-
.def("id", [](Region &self) { return (uint64_t)&self; });
385+
.def("id", [](Region &self) { return (uint64_t)&self; })
386+
.def("push_back",
387+
[](Region &self, Block *block) { self.push_back(block); })
388+
.def("push_front",
389+
[](Region &self, Block *block) { self.push_front(block); });
386390

387391
py::class_<Block>(m, "block", py::module_local())
388392
.def("arg",
@@ -492,13 +496,23 @@ void init_triton_ir(py::module &&m) {
492496
self->print(os, printingFlags);
493497
return str;
494498
})
499+
.def("str_nodebug",
500+
[](OpState &self) -> std::string {
501+
std::string str;
502+
llvm::raw_string_ostream os(str);
503+
self->print(os);
504+
return str;
505+
})
495506
.def("append_operand",
496507
[](OpState &self, Value &val) {
497508
self->insertOperands(self->getNumOperands(), val);
498509
})
499-
.def("verify", [](OpState &self) -> bool {
500-
return succeeded(verify(self.getOperation()));
501-
});
510+
.def("verify",
511+
[](OpState &self) -> bool {
512+
return succeeded(verify(self.getOperation()));
513+
})
514+
.def("get_operation", [](OpState &self) { return self.getOperation(); });
515+
502516
// scf Ops
503517
py::class_<scf::ForOp, OpState>(m, "ForOp", py::module_local())
504518
.def("get_induction_var", &scf::ForOp::getInductionVar);

python/test/gluon/test_frontend.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from triton import knobs
44
from triton.experimental import gluon
55
from triton.experimental.gluon import language as ttgl
6+
from triton._filecheck import filecheck_test
7+
import triton.language as tl
68

79

810
@gluon.jit
@@ -68,3 +70,62 @@ def test_shared_memory(fresh_knobs):
6870
} loc(#loc)
6971
#loc = loc(unknown)
7072
""")
73+
74+
75+
@gluon.jit
76+
def warp_specialize_default(a, b):
77+
return b, a
78+
79+
80+
@gluon.jit
81+
def warp_specialize_worker0(a, b):
82+
pass
83+
84+
85+
@gluon.jit
86+
def warp_specialize_worker1(a, b):
87+
pass
88+
89+
90+
@tl.core._aggregate
91+
class Pair:
92+
first: tl.tensor
93+
second: tl.tensor
94+
95+
def __init__(self, first, second):
96+
self.first = first
97+
self.second = second
98+
99+
100+
@gluon.jit
101+
def anchor(x):
102+
pass
103+
104+
105+
@filecheck_test
106+
@gluon.jit
107+
def test_warp_specialize():
108+
# CHECK-LABEL: tt.func public @test_warp_specialize
109+
# CHECK-NEXT: [[A:%.*]] = tt.make_range {end = 1 : i32, start = 0 : i32}
110+
# CHECK-NEXT: [[B:%.*]] = tt.make_range {end = 2 : i32, start = 0 : i32}
111+
# CHECK-NEXT: [[C:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
112+
# CHECK-NEXT: [[OUTS:%.*]]:3 = ttg.warp_specialize([[A]], [[B]], [[C]]) {{.*}}requestedRegisters = array<i32: 24, 48>
113+
# CHECK-NEXT: default {
114+
# CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @"warp_specialize_default{{.*}}"([[A]], [[B]], [[C]])
115+
# CHECK-NEXT: warp_yield [[RESULTS]]#0, [[RESULTS]]#1, [[RESULTS]]#2
116+
# CHECK-NEXT: }
117+
# CHECK-NEXT: partition0(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>, %arg2: tensor<4xi32>) num_warps(4) {
118+
# CHECK-NEXT: call @"warp_specialize_worker0{{.*}}"(%arg0, %arg1, %arg2)
119+
# CHECK-NEXT: warp_return
120+
# CHECK-NEXT: }
121+
# CHECK-NEXT: partition1(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>, %arg2: tensor<4xi32>) num_warps(4) {
122+
# CHECK-NEXT: call @"warp_specialize_worker1{{.*}}"(%arg0, %arg1, %arg2)
123+
# CHECK-NEXT: warp_return
124+
# CHECK-NEXT: }
125+
# CHECK-NEXT: call @anchor{{.*}}([[OUTS]]#0)
126+
# CHECK-NEXT: call @"anchor{{.*}}"([[OUTS]]#1, [[OUTS]]#2)
127+
pair = Pair(tl.arange(0, 1), tl.arange(0, 2))
128+
a, b = ttgl.warp_specialize((pair, tl.arange(0, 4)), warp_specialize_default,
129+
[warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
130+
anchor(a)
131+
anchor(b)

python/test/unit/language/test_frontend.py

Lines changed: 2 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,116 +1,17 @@
1-
import sys
2-
import os
3-
import io
4-
import inspect
5-
6-
from filecheck.options import Options
7-
from filecheck.finput import FInput
8-
from filecheck.parser import Parser, pattern_for_opts
9-
from filecheck.matcher import Matcher
10-
111
import triton
122
import triton.language as tl
13-
from triton.compiler import ASTSource, make_backend
14-
from triton.backends.compiler import GPUTarget
15-
from triton._C.libtriton import ir
16-
17-
import pytest
3+
from triton._filecheck import filecheck_test
184

195
# ===-----------------------------------------------------------------------===#
20-
# filecheck_test
6+
# Unit Tests
217
# ===-----------------------------------------------------------------------===#
228

23-
# Stub target for testing the frontend.
24-
stub_target = GPUTarget("cuda", 100, 32)
25-
stub_backend = make_backend(stub_target)
26-
27-
llvm_bin_dir = os.path.join(os.path.dirname(sys.executable), "bin")
28-
filecheck_path = os.path.join(llvm_bin_dir, "FileCheck")
29-
30-
31-
def run_filecheck(name, module_str, check_template):
32-
options = Options(match_filename=name)
33-
fin = FInput(name, module_str)
34-
ops = io.StringIO(check_template)
35-
parser = Parser(options, ops, *pattern_for_opts(options))
36-
matcher = Matcher(options, fin, parser)
37-
matcher.stderr = io.StringIO()
38-
if matcher.run() != 0:
39-
raise ValueError(matcher.stderr.getvalue())
40-
41-
42-
def run_parser(kernel_fn):
43-
sigkeys = [x.name for x in kernel_fn.params]
44-
sigvals = [f"arg{i}" for i in range(len(sigkeys))]
45-
signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
46-
src = ASTSource(fn=kernel_fn, signature=signature)
47-
48-
context = ir.context()
49-
ir.load_dialects(context)
50-
stub_backend.load_dialects(context)
51-
52-
extra_options = src.parse_options()
53-
options = stub_backend.parse_options(dict(**extra_options))
54-
codegen_fns = stub_backend.get_codegen_implementation(options)
55-
module_map = stub_backend.get_module_map()
56-
return src.make_ir(options, codegen_fns, module_map, context)
57-
58-
59-
def run_filecheck_test(kernel_fn):
60-
assert isinstance(kernel_fn, triton.runtime.JITFunction)
61-
check_template = inspect.getsource(kernel_fn.fn)
62-
if check_template is None:
63-
raise ValueError("kernel function must have a docstring with FileCheck template")
64-
mlir_module = run_parser(kernel_fn)
65-
66-
run_filecheck("placeholder", str(mlir_module), check_template)
67-
689

6910
@triton.jit
7011
def anchor(v):
7112
pass
7213

7314

74-
# Smoke test to make sure filecheck is working correctly.
75-
def test_filecheck_positive():
76-
77-
@triton.jit
78-
def test_kernel():
79-
# CHECK-LABEL: test_kernel
80-
scalar = 42
81-
# CHECK: %c42_i32 = arith.constant 42 : i32
82-
# CHECK-NEXT: call @anchor{{.*}}(%c42_i32) : (i32) -> ()
83-
anchor(scalar)
84-
85-
run_filecheck_test(test_kernel)
86-
87-
88-
def test_filecheck_negative():
89-
90-
@triton.jit
91-
def test_kernel():
92-
# CHECK-LABEL: test_kernel
93-
scalar = 11
94-
# CHECK: %c42_i32
95-
anchor(scalar)
96-
97-
with pytest.raises(ValueError, match="Couldn't match \"%c42_i32\""):
98-
run_filecheck_test(test_kernel)
99-
100-
101-
def filecheck_test(fn):
102-
103-
def test_fn():
104-
run_filecheck_test(fn)
105-
106-
return test_fn
107-
108-
109-
# ===-----------------------------------------------------------------------===#
110-
# Unit Tests
111-
# ===-----------------------------------------------------------------------===#
112-
113-
11415
@tl.core._aggregate
11516
class Pair:
11617
first: tl.tensor

python/test/unit/test_filecheck.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import pytest
2+
import triton
3+
4+
from triton._filecheck import run_filecheck_test
5+
6+
7+
@triton.jit
8+
def anchor(v):
9+
pass
10+
11+
12+
# Smoke test to make sure filecheck is working correctly.
13+
def test_filecheck_positive():
14+
15+
@triton.jit
16+
def test_kernel():
17+
# CHECK-LABEL: test_kernel
18+
scalar = 42
19+
# CHECK: %c42_i32 = arith.constant 42 : i32
20+
# CHECK-NEXT: call @anchor{{.*}}(%c42_i32) : (i32) -> ()
21+
anchor(scalar)
22+
23+
run_filecheck_test(test_kernel)
24+
25+
26+
def test_filecheck_negative():
27+
28+
@triton.jit
29+
def test_kernel():
30+
# CHECK-LABEL: test_kernel
31+
scalar = 11
32+
# CHECK: %c42_i32
33+
anchor(scalar)
34+
35+
with pytest.raises(ValueError, match="Couldn't match \"%c42_i32\""):
36+
run_filecheck_test(test_kernel)

python/triton/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from . import language
2626
from . import testing
2727
from . import tools
28+
from ._filecheck import run_filecheck_test, filecheck_test, run_parser
2829

2930
must_use_result = language.core.must_use_result
3031

@@ -51,6 +52,9 @@
5152
"TritonError",
5253
"testing",
5354
"tools",
55+
"run_filecheck_test",
56+
"filecheck_test",
57+
"run_parser",
5458
]
5559

5660
# -------------------------------------

0 commit comments

Comments
 (0)