@@ -3140,6 +3140,42 @@ def test_amd_tdm_load(target):
31403140""" )
31413141
31423142
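+# Load a tile described by a host-created tensor descriptor into shared memory via TDM,
+# wait for the async copy to complete, then read the tile back into registers.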
+@gluon.jit
+def amd_host_tdm_load_kernel(desc):
+    buffer = ttgl.allocate_shared_memory(desc.dtype, shape=desc.block_shape, layout=desc.layout)
+    ttgl.amd.gfx1250.tdm.async_load(desc, offsets=[0, 2], dest=buffer)
+
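+    # Wait until 0 TDM operations remain in flight before reading the tile out of shared memory.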
+    ttgl.amd.gfx1250.tdm.async_wait(0)
+    buffer.load(layout=ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]))
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
+def test_amd_host_tdm_load(target):
+
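+    # Build a descriptor for (16, 64) blocks of a 32x128 fp16 tensor with a padded shared
+    # layout, parse the kernel, and compare the emitted TTGIR against the expected dump.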
+    ptr = MockTensor(ttgl.float16, shape=(32, 128))
+    layout = ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [16, 64], [1, 0])
+    desc = gluon.amd.gfx1250.TensorDescriptor.from_tensor(ptr, block_shape=(16, 64), layout=layout)
+    module = run_parser(amd_host_tdm_load_kernel, *make_args(desc), target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [16, 64]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @amd_host_tdm_load_kernel(%arg0: !tt.tensordesc<tensor<16x64xf16, #shared>>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable>
+    %c0_i32 = arith.constant 0 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %true = arith.constant true
+    %1 = amdg.async_tdm_copy_global_to_local %arg0[%c0_i32, %c2_i32] into %0, %true : !tt.tensordesc<tensor<16x64xf16, #shared>> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable>
+    %2 = amdg.async_tdm_wait {num = 0 : i32}
+    %3 = ttg.local_load %0 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> tensor<16x64xf16, #blocked>
+    tt.return
+  }
+}
+""")
+
+
 @gluon.jit
 def amd_tdm_store_kernel(ptr):
     SHARED_LAYOUT: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])