@@ -283,17 +283,17 @@ def test_shared_memory_cast(fresh_knobs):
 
 
 @gluon.jit
-def warp_specialize_default(a, b):
+def warp_specialize_default(a, b, e: ttgl.constexpr):
     return b, a
 
 
 @gluon.jit
-def warp_specialize_worker0(a, b):
+def warp_specialize_worker0(a, b, e: ttgl.constexpr):
     pass
 
 
 @gluon.jit
-def warp_specialize_worker1(a, b):
+def warp_specialize_worker1(a, b, e: ttgl.constexpr):
     pass
 
 
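The hunk above threads a `ttgl.constexpr` parameter through all three warp-specialize helpers. As in `triton.language`, a constexpr argument is a compile-time constant rather than a runtime value, so each distinct value compiles a separately specialized function whose mangled name encodes the constant (which is what the FileCheck patterns in the next hunk pin down). A minimal sketch of the idea, using a hypothetical `scale` helper that is not part of this diff:

```python
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def scale(x, factor: ttgl.constexpr):
    # `factor` is folded at compile time: calling scale(x, 2) and
    # scale(x, 4) produces two separately specialized (and separately
    # name-mangled) instances of this function.
    return x * factor
```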
@@ -322,15 +322,15 @@ def test_warp_specialize():
     # CHECK-NEXT: [[C:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
     # CHECK-NEXT: [[OUTS:%.*]]:3 = ttg.warp_specialize([[A]], [[B]], [[C]]) {{.*}}requestedRegisters = array<i32: 24, 48>
     # CHECK-NEXT: default {
-    # CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @{{.*}}warp_specialize_default{{.*}}([[A]], [[B]], [[C]])
+    # CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @{{.*}}warp_specialize_default{{.*}}cconstexpr_42{{.*}}([[A]], [[B]], [[C]])
     # CHECK-NEXT: warp_yield [[RESULTS]]#0, [[RESULTS]]#1, [[RESULTS]]#2
     # CHECK-NEXT: }
     # CHECK-NEXT: partition0(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>) num_warps(4) {
-    # CHECK-NEXT: call @{{.*}}warp_specialize_worker0{{.*}}(%arg0, %arg1, %arg2)
+    # CHECK-NEXT: call @{{.*}}warp_specialize_worker0{{.*}}cconstexpr_42{{.*}}(%arg0, %arg1, %arg2)
     # CHECK-NEXT: warp_return
     # CHECK-NEXT: }
     # CHECK-NEXT: partition1(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>) num_warps(4) {
-    # CHECK-NEXT: call @{{.*}}warp_specialize_worker1{{.*}}(%arg0, %arg1, %arg2)
+    # CHECK-NEXT: call @{{.*}}warp_specialize_worker1{{.*}}cconstexpr_42{{.*}}(%arg0, %arg1, %arg2)
     # CHECK-NEXT: warp_return
     # CHECK-NEXT: }
     # CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#0)
@@ -340,8 +340,9 @@ def test_warp_specialize():
     b = ttgl.arange(0, 2, layout=layout)
     c = ttgl.arange(0, 4, layout=layout)
     pair = Pair(a, b)
-    a, b = ttgl.warp_specialize((pair, c), warp_specialize_default, [warp_specialize_worker0, warp_specialize_worker1],
-                                [4, 4], [24, 48])
+    e: ttgl.constexpr = 42
+    a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default,
+                                [warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
     anchor(a)
     anchor(b)
 
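Note how the constexpr travels: `e` is added to the explicit argument tuple, yet the `ttg.warp_specialize` op in the checks above still captures only the three tensors `[[A]]`, `[[B]]`, `[[C]]` (the `Pair` unpacks to two). Because `e` is a compile-time constant, it is baked into each partition's specialized function instead of being passed at runtime. A hedged, annotated restatement of the call shape, with argument roles inferred from this test rather than stated by the diff:

```python
e: ttgl.constexpr = 42
a, b = ttgl.warp_specialize(
    (pair, c, e),              # explicit args; `e` never becomes a runtime capture
    warp_specialize_default,   # default partition, returns (b, a)
    [warp_specialize_worker0,  # worker partitions
     warp_specialize_worker1],
    [4, 4],                    # num_warps per worker partition
    [24, 48],                  # registers per worker (see requestedRegisters above)
)
```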
@@ -781,3 +782,23 @@ def test_reduce(fresh_knobs):
 } loc(#loc)
 } loc(#loc)
 """)
+
+
+@filecheck_test
+@gluon.jit
+def test_elementwise_core():
+    # CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+    # CHECK: @test_elementwise_core
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0])
+    x = ttgl.arange(0, 16, layout)
+    y = ttgl.arange(16, 32, layout)
+
+    # CHECK: arith.select {{.*}} : tensor<16xi1, [[BLOCKED]]>, tensor<16xi32, [[BLOCKED]]>
+    a = ttgl.where(x > 8, x, y)
+    # CHECK: arith.maxsi {{.*}} : tensor<16xi32, [[BLOCKED]]>
+    b = ttgl.maximum(x, y)
+    # CHECK: arith.minsi {{.*}} : tensor<16xi32, [[BLOCKED]]>
+    c = ttgl.minimum(x, y)
+    ttgl.static_assert(a.type == x.type)
+    ttgl.static_assert(b.type == x.type)
+    ttgl.static_assert(c.type == x.type)
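The new test pins down the lowering of Gluon's elementwise core ops on i32 tensors: `ttgl.where` becomes `arith.select`, and `ttgl.maximum`/`ttgl.minimum` become `arith.maxsi`/`arith.minsi` (the signed variants), with the blocked layout and element type preserved. A small sketch combining the same three ops; the kernel name is illustrative, not from the diff:

```python
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def elementwise_demo():
    layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0])
    x = ttgl.arange(0, 16, layout)
    y = ttgl.arange(16, 32, layout)
    lo = ttgl.minimum(x, y)          # lowers to arith.minsi on signed i32
    hi = ttgl.maximum(x, y)          # lowers to arith.maxsi
    mid = ttgl.where(x > 8, lo, hi)  # lowers to arith.select, layout preserved
    ttgl.static_assert(mid.type == x.type)
```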