@@ -252,6 +252,8 @@ def shared_memory_cast_kernel():
252
252
smem = ttgl .allocate_shared_memory (ttgl .int8 , [2 , 256 , 128 ], layout_a )
253
253
perm = smem .index (0 ).permute ((1 , 0 ))
254
254
ttgl .static_assert (perm .type .layout == layout_T )
255
+ # Check that the MLIR type and Gluon types match by emitting a call.
256
+ anchor_noinline (perm )
255
257
256
258
layout_b : ttgl .constexpr = ttgl .NVMMASharedLayout (swizzle_byte_width = 64 , transposed = False , element_bitwidth = 16 ,
257
259
rank = 4 , cta_order = [3 , 2 , 1 , 0 ])
@@ -279,11 +281,15 @@ def test_shared_memory_cast(fresh_knobs):
279
281
%c0_i32_0 = arith.constant 0 : i32
280
282
%1 = ttg.memdesc_subview %0[%c0_i32_0, %c0_i32, %c0_i32] : !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128>
281
283
%2 = ttg.memdesc_trans %1 {order = array<i32: 1, 0>} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>
284
+ tt.call @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) -> ()
282
285
%3 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable>
283
286
%4 = ttg.memdesc_reshape %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable, 32x1x4x64>
284
287
%5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable>
285
288
tt.return
286
289
}
290
+ tt.func private @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%arg0: !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) attributes {noinline = true} {
291
+ tt.return
292
+ }
287
293
}
288
294
""" )
289
295
@@ -318,6 +324,11 @@ def anchor(x):
318
324
pass
319
325
320
326
327
@gluon.jit(noinline=True)
def anchor_noinline(x):
    """No-op helper that is compiled as a separate, non-inlined function.

    Because the function is marked ``noinline``, calling it emits an actual
    ``tt.call`` whose operand type is spelled out in the generated IR, which
    lets a test check the type of ``x`` as seen by the compiler.
    """
    pass
330
+
331
+
321
332
@filecheck_test
322
333
@gluon .jit
323
334
def test_warp_specialize ():
0 commit comments