Skip to content

Commit 3597ff1

Browse files
authored
[Gluon] Fix warp_specialize when the default function has no results (#7145)
1 parent 9e7dfc6 commit 3597ff1

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

python/test/gluon/test_frontend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def anchor(x):
321321
@filecheck_test
322322
@gluon.jit
323323
def test_warp_specialize():
324-
# CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
324+
# CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
325325
# CHECK-LABEL: test_warp_specialize
326326
# CHECK-NEXT: [[A:%.*]] = tt.make_range {end = 1 : i32, start = 0 : i32}
327327
# CHECK-NEXT: [[B:%.*]] = tt.make_range {end = 2 : i32, start = 0 : i32}
@@ -352,6 +352,10 @@ def test_warp_specialize():
352352
anchor(a)
353353
anchor(b)
354354

355+
# CHECK: ttg.warp_specialize([[A]], [[B]], [[C]])
356+
# CHECK: (tensor<1xi32, [[BLOCKED]]>, tensor<2xi32, [[BLOCKED]]>, tensor<4xi32, [[BLOCKED]]>) -> ()
357+
ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, [warp_specialize_worker1], [4], [48])
358+
355359

356360
@gluon.jit
357361
def mbarrier_kernel():

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,9 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num
257257
default_block = builder.new_block()
258258
builder.set_insertion_point_to_start(default_block)
259259
default_results = generator.call_JitFunction(default_partition, args, kwargs={})
260-
mlir_results = flatten_values_to_ir(default_results)
260+
mlir_results = []
261+
if default_results is not None:
262+
mlir_results = flatten_values_to_ir(default_results)
261263
builder.create_warp_yield(mlir_results)
262264
result_types = [r.get_type() for r in mlir_results]
263265

@@ -281,4 +283,6 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num
281283

282284
builder.set_insertion_point_after(ws_op.get_operation())
283285
mlir_results = [ws_op.get_result(i) for i in range(len(result_types))]
286+
if default_results is None:
287+
return
284288
return tuple(unflatten_ir_values(mlir_results, [r.type for r in default_results]))

0 commit comments

Comments
 (0)