@@ -897,3 +897,61 @@ def kernel(x, y):
 
     compiled_kernel = kernel.warmup(input, output, grid=(1, ))
     assert compiled_kernel.asm["ttgir"].count("tt.func private") == 0
+
+
+@pytest.mark.parametrize("interval_pairs", [[[32, 4]], [[16, 4]], [[16, 4], [64, 8]]])
+@pytest.mark.parametrize(
+    "shared_layout",
+    [{"order": [0, 1]}, {"order": [1, 0]},
+     {"offsets": [[0, 1], [0, 2], [0, 8], [0, 4], [0, 16], [0, 32], [2, 0], [1, 0], [4, 0], [8, 0], [16, 0], [32, 0]]}])
+@pytest.mark.parametrize("slice_m_offset, slice_n_offset, slice_m, slice_n", [(48, 16, 16, 16), (32, 48, 32, 16),
+                                                                              (48, 32, 16, 32)])
+def test_padded_shared_layout_subslice(interval_pairs, shared_layout, slice_m_offset, slice_n_offset, slice_m, slice_n):
+    m = 64
+    n = 64
+    num_warps = 1
+    num_warps_cst = ttgl.constexpr(num_warps)
+    warp_size_cst = ttgl.constexpr(THREADS_PER_WARP)
+
+    shape = [m, n]
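+    # Build the padded shared memory layout either from a dimension order (identity mapping) or from explicit offset bases.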
+    if "order" in shared_layout:
+        order = shared_layout["order"]
+        smem_layout = ttgl.constexpr(ttgl.PaddedSharedLayout.with_identity_for(interval_pairs, shape, order))
+    elif "offsets" in shared_layout:
+        offsets = shared_layout["offsets"]
+        blocks = []
+        smem_layout = ttgl.constexpr(ttgl.PaddedSharedLayout(interval_pairs, offsets, blocks, shape))
+
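+    # The kernel loads the full MxN tile, stores it to padded shared memory, and reads back only the requested sub-slice.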
+    @gluon.jit
+    def kernel(in_ptr, out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, SLICE_M_OFFSET: ttgl.constexpr,
+               SLICE_N_OFFSET: ttgl.constexpr, SLICE_M: ttgl.constexpr, SLICE_N: ttgl.constexpr):
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [warp_size_cst, 1], [1, num_warps_cst], [1, 0])
+        offs_m_load = ttgl.arange(0, M, ttgl.SliceLayout(1, blocked))
+        offs_n_load = ttgl.arange(0, N, ttgl.SliceLayout(0, blocked))
+        in_offs = offs_m_load[:, None] * N + offs_n_load[None, :]
+
+        in_data = ttgl.load(in_ptr + in_offs)
+
+        smem = ttgl.allocate_shared_memory(ttgl.int32, [M, N], smem_layout)
+        smem_slice0 = smem.slice(SLICE_M_OFFSET, SLICE_M, dim=0)
+        smem_slice1 = smem_slice0.slice(SLICE_N_OFFSET, SLICE_N, dim=1)
+
+        smem.store(in_data)
+
+        out_data = smem_slice1.load(blocked)
+
+        offs_m_store = ttgl.arange(0, SLICE_M, ttgl.SliceLayout(1, blocked))
+        offs_n_store = ttgl.arange(0, SLICE_N, ttgl.SliceLayout(0, blocked))
+        out_offs = offs_m_store[:, None] * SLICE_N + offs_n_store[None, :]
+        ttgl.store(out_ptr + out_offs, out_data)
+
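+    # The data read from the shared-memory sub-slice must match the corresponding region of the input tile.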
+    input = torch.arange(m * n, device="cuda").reshape(m, n).to(torch.int32)
+    output = torch.zeros((slice_m, slice_n), dtype=torch.int32, device="cuda")
+    ref_output = input[slice_m_offset:slice_m_offset + slice_m, slice_n_offset:slice_n_offset + slice_n]
+
+    kernel[(1, )](input, output, m, n, slice_m_offset, slice_n_offset, slice_m, slice_n, num_warps=num_warps)
+
+    assert (output == ref_output).all()