@@ -1209,3 +1209,48 @@ def test_gather_layouts(axis, src_layout, index_layout, src_shape, idx_shape, device):
 
     torch.testing.assert_close(out, ref, rtol=0, atol=0)
     assert ("nvvm.shfl.sync.idx" in obj.asm["llir"]) or ("llvm.amdgcn.ds.bpermute" in obj.asm["llir"])
+
+
+@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size",
+                         [[128, 128, 64, 64], [128, 128, 64, 32], [128, 64, 64, 32], [256, 128, 64, 64]])
+def test_memdesc_subslice(M, N, M_tile_size, N_tile_size, device):
+    if M % M_tile_size != 0 or N % N_tile_size != 0:
+        pytest.skip(f"Shape size ({M}, {N}) must be divisible by tile size ({M_tile_size}, {N_tile_size})")
+
+    num_rows_per_warp = THREADS_PER_WARP // 4
+    blocked_layout = ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[num_rows_per_warp, 4],
+                                        warps_per_cta=[4, 1], order=[1, 0])
+    shared_layout = ttgl.SwizzledSharedLayout(vec=8, per_phase=1, max_phase=8, order=[1, 0])
+
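+    # The kernel loads the full (M, N) block, stages it in shared memory, and
+    # overwrites every (BLOCK_SIZE_M, BLOCK_SIZE_N) tile with its global linear
+    # indices (plus the staged values) before copying the buffer back to `out`.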
+    @gluon.jit
+    def kernel(
+        out,
+        M: ttgl.constexpr,
+        N: ttgl.constexpr,
+        BLOCK_SIZE_M: ttgl.constexpr,
+        BLOCK_SIZE_N: ttgl.constexpr,
+        blocked_layout: ttgl.constexpr,
+        shared_layout: ttgl.constexpr,
+    ):
+        offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, blocked_layout))[:, None]
+        offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, blocked_layout))[None, :]
+        vals = ttgl.load(out + offs_m * N + offs_n)
+
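+        # Stage the whole (M, N) block in shared memory, then rewrite it tile by
+        # tile through memdesc sub-slices taken along both dimensions.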
+        smem: ttgl.shared_memory_descriptor = ttgl.allocate_shared_memory(vals.dtype, (M, N), shared_layout, value=vals)
+        for i in ttgl.static_range(M // BLOCK_SIZE_M):
+            for j in ttgl.static_range(N // BLOCK_SIZE_N):
+                tile = smem.slice(i * BLOCK_SIZE_M, BLOCK_SIZE_M, dim=0).slice(j * BLOCK_SIZE_N, BLOCK_SIZE_N, dim=1)
+                tile_vals = tile.load(blocked_layout)
+                tile_offs_m = ttgl.arange(0, BLOCK_SIZE_M, layout=ttgl.SliceLayout(1, blocked_layout))[:, None]
+                tile_offs_n = ttgl.arange(0, BLOCK_SIZE_N, layout=ttgl.SliceLayout(0, blocked_layout))[None, :]
+                linear_idx = tile_offs_m * N + tile_offs_n + i * BLOCK_SIZE_M * N + j * BLOCK_SIZE_N
+                tile.store(linear_idx + tile_vals)
+
+        vals = smem.load(blocked_layout)
+        ttgl.store(out + offs_m * N + offs_n, vals)
+
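+    # `out` starts as zeros, so every element written back equals its global
+    # linear index and the result must match torch.arange reshaped to (M, N).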
+    out = torch.zeros((M, N), device=device, dtype=torch.float16)
+    kernel[(1, )](out, M, N, M_tile_size, N_tile_size, blocked_layout, shared_layout)
+
+    out_ref = torch.arange(0, M * N, device=device).reshape((M, N)).to(torch.float16)
+    torch.testing.assert_close(out, out_ref, rtol=0, atol=0)