
Commit d9fcc10

[Gluon] Fix warp_specialize with constexprs, add a few APIs (#7097)
* where, maximum, minimum
* add gluon_ir builder for fence_async_shared. Not sure what API to use
* fix `ttgl.warp_specialize` passing constexpr arguments
1 parent d57cbee commit d9fcc10
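For orientation, a minimal user-level sketch of the three elementwise APIs named in the first bullet, mirroring the signatures exercised by the new test_elementwise_core test further down (the import paths and kernel name here are illustrative, not part of the diff):

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl

@gluon.jit
def elementwise_demo():
    layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0])
    x = ttgl.arange(0, 16, layout)
    y = ttgl.arange(16, 32, layout)
    a = ttgl.where(x > 8, x, y)   # elementwise select (lowers to arith.select)
    b = ttgl.maximum(x, y)        # elementwise max (arith.maxsi for i32 tensors)
    c = ttgl.minimum(x, y)        # elementwise min (arith.minsi for i32 tensors)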

5 files changed (+39 / -16 lines)
python/src/gluon_ir.cc

Lines changed: 5 additions & 0 deletions
@@ -299,6 +299,11 @@ void init_gluon_ir(py::module &&m) {
             self.create<ttng::AsyncTMAScatterOp>(descPtr, xOffsets, yOffset,
                                                  src);
           })
+      .def("create_fence_async_shared",
+           [](GluonOpBuilder &self, bool bCluster) -> OpState {
+             return self.create<ttng::FenceAsyncSharedOp>(bCluster);
+           })
       .def("create_broadcast",
            [](TritonOpBuilder &self, Value &arg, Type retTy) -> Value {
              return self.create<tt::BroadcastOp>(retTy, arg);
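The diff only adds the builder entry point; per the commit message, the user-facing API for the shared-memory fence is still undecided. As a hedged sketch, a Python-side lowering helper could drive the new binding like this (fence_async_shared is a hypothetical wrapper name, and builder is assumed to be the GluonOpBuilder handle available during code generation):

def fence_async_shared(builder, cluster: bool = False):
    # Emits a ttng.fence_async_shared op; `cluster` maps to the bCluster
    # flag of the new create_fence_async_shared binding.
    return builder.create_fence_async_shared(cluster)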

python/src/ir.cc

Lines changed: 1 addition & 6 deletions
@@ -1425,12 +1425,7 @@ void init_triton_ir(py::module &&m) {
           })
       .def("create_expand_dims",
            [](TritonOpBuilder &self, Value &arg, int axis) -> Value {
-             auto argType = dyn_cast<RankedTensorType>(arg.getType());
-             auto argEltType = argType.getElementType();
-             std::vector<int64_t> retShape = argType.getShape();
-             retShape.insert(retShape.begin() + axis, 1);
-             return self.create<ExpandDimsOp>(
-                 RankedTensorType::get(retShape, argEltType), arg, axis);
+             return self.create<ExpandDimsOp>(arg, axis);
           })
       .def("create_cat",
            [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
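Here the binding stops rebuilding the result RankedTensorType by hand and lets ExpandDimsOp derive it (presumably via the op's own return-type inference, which the one-argument create<ExpandDimsOp>(arg, axis) form relies on). The Python-side call is unchanged; a minimal sketch, assuming x_handle is an i32 tensor value of shape [16] obtained from the usual builder context:

# Inserting a unit dimension at axis 0: tensor<16xi32> -> tensor<1x16xi32>.
expanded = builder.create_expand_dims(x_handle, 0)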

python/test/gluon/test_frontend.py

Lines changed: 29 additions & 8 deletions
@@ -283,17 +283,17 @@ def test_shared_memory_cast(fresh_knobs):
 @gluon.jit
-def warp_specialize_default(a, b):
+def warp_specialize_default(a, b, e: ttgl.constexpr):
     return b, a


 @gluon.jit
-def warp_specialize_worker0(a, b):
+def warp_specialize_worker0(a, b, e: ttgl.constexpr):
     pass


 @gluon.jit
-def warp_specialize_worker1(a, b):
+def warp_specialize_worker1(a, b, e: ttgl.constexpr):
     pass

@@ -322,15 +322,15 @@ def test_warp_specialize():
     # CHECK-NEXT: [[C:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
     # CHECK-NEXT: [[OUTS:%.*]]:3 = ttg.warp_specialize([[A]], [[B]], [[C]]) {{.*}}requestedRegisters = array<i32: 24, 48>
     # CHECK-NEXT: default {
-    # CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @{{.*}}warp_specialize_default{{.*}}([[A]], [[B]], [[C]])
+    # CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @{{.*}}warp_specialize_default{{.*}}cconstexpr_42{{.*}}([[A]], [[B]], [[C]])
     # CHECK-NEXT: warp_yield [[RESULTS]]#0, [[RESULTS]]#1, [[RESULTS]]#2
     # CHECK-NEXT: }
     # CHECK-NEXT: partition0(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>) num_warps(4) {
-    # CHECK-NEXT: call @{{.*}}warp_specialize_worker0{{.*}}(%arg0, %arg1, %arg2)
+    # CHECK-NEXT: call @{{.*}}warp_specialize_worker0{{.*}}cconstexpr_42{{.*}}(%arg0, %arg1, %arg2)
     # CHECK-NEXT: warp_return
     # CHECK-NEXT: }
     # CHECK-NEXT: partition1(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>) num_warps(4) {
-    # CHECK-NEXT: call @{{.*}}warp_specialize_worker1{{.*}}(%arg0, %arg1, %arg2)
+    # CHECK-NEXT: call @{{.*}}warp_specialize_worker1{{.*}}cconstexpr_42{{.*}}(%arg0, %arg1, %arg2)
     # CHECK-NEXT: warp_return
     # CHECK-NEXT: }
     # CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#0)
@@ -340,8 +340,9 @@ def test_warp_specialize():
     b = ttgl.arange(0, 2, layout=layout)
     c = ttgl.arange(0, 4, layout=layout)
     pair = Pair(a, b)
-    a, b = ttgl.warp_specialize((pair, c), warp_specialize_default, [warp_specialize_worker0, warp_specialize_worker1],
-                                [4, 4], [24, 48])
+    e: ttgl.constexpr = 42
+    a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default,
+                                [warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
     anchor(a)
     anchor(b)

@@ -781,3 +782,23 @@ def test_reduce(fresh_knobs):
 } loc(#loc)
 } loc(#loc)
 """)
+
+
+@filecheck_test
+@gluon.jit
+def test_elementwise_core():
+    # CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+    # CHECK: @test_elementwise_core
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0])
+    x = ttgl.arange(0, 16, layout)
+    y = ttgl.arange(16, 32, layout)
+
+    # CHECK: arith.select {{.*}} : tensor<16xi1, [[BLOCKED]]>, tensor<16xi32, [[BLOCKED]]>
+    a = ttgl.where(x > 8, x, y)
+    # CHECK: arith.maxsi {{.*}} : tensor<16xi32, [[BLOCKED]]>
+    b = ttgl.maximum(x, y)
+    # CHECK: arith.minsi {{.*}} : tensor<16xi32, [[BLOCKED]]>
+    c = ttgl.minimum(x, y)
+    ttgl.static_assert(a.type == x.type)
+    ttgl.static_assert(b.type == x.type)
+    ttgl.static_assert(c.type == x.type)

python/triton/experimental/gluon/language/_core.py

Lines changed: 3 additions & 1 deletion
@@ -49,6 +49,9 @@
     "static_assert",  # NOQA: F822
     "store",  # NOQA: F822
     "to_tensor",  # NOQA: F822
+    "where",  # NOQA: F822
+    "maximum",  # NOQA: F822
+    "minimum",  # NOQA: F822
 ]

 __all__ = [
@@ -303,6 +306,5 @@ def warp_specialize(args, default_partition, worker_partitions, worker_num_warps
                     _semantic=None, _generator=None):
     worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
     worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
-    args = [_unwrap_if_constexpr(arg) for arg in args]
     return _semantic.warp_specialize(args, default_partition, worker_partitions, worker_num_warps,  #
                                      worker_num_regs, _generator)
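Dropping the per-argument unwrap is the core of the `ttgl.warp_specialize` fix: constexpr arguments now reach the semantic layer intact, so each partition is specialized on them at compile time (hence the cconstexpr_42 suffix in the mangled call names checked above) instead of being unwrapped to plain Python values beforehand. The caller-side pattern, mirroring the updated test:

e: ttgl.constexpr = 42   # stays a compile-time constant in every partition
a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default,
                            [warp_specialize_worker0, warp_specialize_worker1],
                            [4, 4], [24, 48])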

python/triton/language/semantic.py

Lines changed: 1 addition & 1 deletion
@@ -1683,7 +1683,7 @@ def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder
         scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse)
         region_builder_fn(scan_op)
-        scan_op.verify()
+        assert scan_op.verify()

         return tuple(self.wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs)))
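A short note on the one-line change, assuming OpState.verify() returns a success flag and emits diagnostics rather than raising: the flag used to be discarded, so an invalid scan region produced by region_builder_fn would only fail later during lowering. Asserting the result makes the failure immediate at op-construction time:

# Fail fast: abort here if the freshly built tt.scan region does not verify.
assert scan_op.verify()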
