@@ -252,6 +252,8 @@ def shared_memory_cast_kernel():
252252 smem = ttgl .allocate_shared_memory (ttgl .int8 , [2 , 256 , 128 ], layout_a )
253253 perm = smem .index (0 ).permute ((1 , 0 ))
254254 ttgl .static_assert (perm .type .layout == layout_T )
255+ # Check that the MLIR type and Gluon types match by emitting a call.
256+ anchor_noinline (perm )
255257
256258 layout_b : ttgl .constexpr = ttgl .NVMMASharedLayout (swizzle_byte_width = 64 , transposed = False , element_bitwidth = 16 ,
257259 rank = 4 , cta_order = [3 , 2 , 1 , 0 ])
@@ -279,11 +281,15 @@ def test_shared_memory_cast(fresh_knobs):
279281 %c0_i32_0 = arith.constant 0 : i32
280282 %1 = ttg.memdesc_subview %0[%c0_i32_0, %c0_i32, %c0_i32] : !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128>
281283 %2 = ttg.memdesc_trans %1 {order = array<i32: 1, 0>} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>
284+ tt.call @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) -> ()
282285 %3 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable>
283286 %4 = ttg.memdesc_reshape %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable, 32x1x4x64>
284287 %5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable>
285288 tt.return
286289 }
290+ tt.func private @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%arg0: !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) attributes {noinline = true} {
291+ tt.return
292+ }
287293}
288294""" )
289295
@@ -318,6 +324,11 @@ def anchor(x):
318324 pass
319325
320326
327+ @gluon .jit (noinline = True )
328+ def anchor_noinline (x ):
329+ pass
330+
331+
321332@filecheck_test
322333@gluon .jit
323334def test_warp_specialize ():
0 commit comments