@@ -82,10 +82,34 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.cons
     torch.testing.assert_close(z_tri, z_ref)
 
 
-@pytest.mark.parametrize("M, N", [[128, 16], [32, 128], [32, 32], [16, 16]])
-@pytest.mark.parametrize(
-    "src_layout",
-    _filter_layouts([
+def _reduce_linear_layouts():
+    if THREADS_PER_WARP == 32:
+        return [
+            ttgl.DistributedLinearLayout(
+                reg_bases=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]],
+                lane_bases=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]],
+                warp_bases=[[32, 0], [0, 32]],
+                block_bases=[],
+                shape=[64, 64],
+            )
+        ]
+    elif THREADS_PER_WARP == 64:
+        return [
+            ttgl.DistributedLinearLayout(
+                reg_bases=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]],
+                lane_bases=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 64]],
+                warp_bases=[[32, 0], [0, 32]],
+                block_bases=[],
+                shape=[64, 128],
+            )
+        ]
+    else:
+        raise RuntimeError(f"Unsupported THREADS_PER_WARP: {THREADS_PER_WARP}")
+
+
+def _reduce_layouts():
+    shapes = [(128, 16), (32, 128), (32, 32), (16, 16)]
+    layouts = _filter_layouts([
         # FIXME: Do not enable these tests until the SLPVectorizor problem with nvptx target has been resolved
         # SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 1, 4], [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2])),
         # SliceLayout(dim=0, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 4, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2])),
@@ -104,47 +128,50 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.cons
         ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[1, 4], tiles_per_warp=[1, 1], instr_shape=[32, 32],
                                transposed=True),
         # TODO: AMDWMMA layouts
-        # WmmaLayout(version=1, warps_per_cta=[4, 1]),
-        # WmmaLayout(version=1, warps_per_cta=[1, 4]),
         ttgl.DotOperandLayout(
-            parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], ctas_per_cga=[1, 1],  #
-                                               cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]),  #
+            parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], ctas_per_cga=[1, 1],
+                                               cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]),
             operand_index=1, k_width=8),
         ttgl.DotOperandLayout(
-            parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 1], ctas_per_cga=[1, 1],  #
-                                               cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 32, 16]),  #
+            parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 1], ctas_per_cga=[1, 1],
+                                               cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 32, 16]),
             operand_index=0, k_width=2),
         ttgl.SliceLayout(
-            dim=0,
-            parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1],  #
-                                               cta_split_num=[1, 1, 1], cta_order=[2, 1, 0], instr_shape=[1, 16,
-                                                                                                          8])),  #
+            dim=0, parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1],
+                                                      cta_split_num=[1, 1, 1], cta_order=[2, 1,
+                                                                                          0], instr_shape=[1, 16, 8])),
         ttgl.SliceLayout(
             dim=1, parent=ttgl.DotOperandLayout(
-                parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1],  #
-                                                   cta_split_num=[1, 1, 1], cta_order=[2, 1, 0], instr_shape=[1, 16,
-                                                                                                              8]),  #
-                operand_index=1, k_width=2)),
-        "linear_layout",
-    ]))
+                parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1],
+                                                   cta_split_num=[1, 1, 1], cta_order=[2, 1, 0],
+                                                   instr_shape=[1, 16, 8]), operand_index=1, k_width=2)),
+    ])
+
+    rets = []
+    for (M, N) in shapes:
+        for layout in layouts:
+            if isinstance(layout, (ttgl.amd.AMDMFMALayout, ttgl.NVMMADistributedLayout)):
+                instr_shape = layout.instr_shape
+                if M < instr_shape[0] or N < instr_shape[1]:
+                    continue
+            rets.append((M, N, layout))
+    return rets
+
+
+def _reduce_cases():
+    for layout in _reduce_linear_layouts():
+        yield (layout.shape[0], layout.shape[1], layout)
+    for M, N, layout in _reduce_layouts():
+        yield (M, N, layout)
+
+
+@pytest.mark.parametrize("M, N, src_layout", _reduce_cases())
 @pytest.mark.parametrize("axis", [0, 1])
 @pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d'])
 @pytest.mark.parametrize("dtype_str, sanitize_overflow", [("int32", False), ("int32", True), ("float32", False),
                                                           ("float16", False)])
 @pytest.mark.parametrize("reduce_op", ["sum", "max"])
 def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, sanitize_overflow, reduce_op, device):
-    if src_layout == "linear_layout":
-        src_layout = ttgl.DistributedLinearLayout(reg_bases=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]],  #
-                                                  lane_bases=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]],  #
-                                                  warp_bases=[[32, 0], [0, 32]], block_bases=[], shape=[M, N])
-        if THREADS_PER_WARP != (1 << len(src_layout.lane_bases)):
-            pytest.skip(f"Skipping. This LinearLayout assumes {1 << len(src_layout.lane_bases)} threads per warp")
-        elif M < 64 or N < 64:
-            pytest.skip(f"Skipping. This LinearLayout assumes M >= 64 and N >= 64, got M={M}, N={N}")
-    if isinstance(src_layout,
-                  (ttgl.amd.AMDMFMALayout, ttgl.NVMMADistributedLayout)) and (M < src_layout.instr_shape[0]
-                                                                              or N < src_layout.instr_shape[1]):
-        pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
 
     @gluon.jit
     def _add(a, b):
@@ -240,9 +267,33 @@ def kernel(x_ptr, y_ptr, M: ttgl.constexpr, layout: ttgl.constexpr):
 ])
 
 
-@pytest.mark.parametrize("M, bins", [[2048, 2], [8, 512], [32, 32]])
-@pytest.mark.parametrize("src_layout", [ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0]), "linear_layout"])
-@pytest.mark.parametrize("dst_layout", [ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0])])
+def _histogram_cases():
+    if THREADS_PER_WARP not in (32, 64):
+        raise RuntimeError(f"Unsupported THREADS_PER_WARP: {THREADS_PER_WARP}")
+
+    m_bins = [(2048, 2), (8, 512), (32, 32)]
+    layouts = [(ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4],
+                                   [0]), ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0]))]
+    for m, bins in m_bins:
+        for src_layout, dst_layout in layouts:
+            yield (m, bins, src_layout, dst_layout)
+    import math
+
+    linear_layouts = [(
+        ttgl.DistributedLinearLayout(
+            reg_bases=[[1 << (5 + i)] for i in range(int(math.log2(m)) - 5)],
+            lane_bases=[[0], [16], [4], [2], [1]] + ([[0]] if THREADS_PER_WARP == 64 else []),
+            warp_bases=[[0], [8]],
+            block_bases=[],
+            shape=(m, ),
+        ),
+        bins,
+    ) for (m, bins) in m_bins if m >= 32]
+    for linear_layout, bins in linear_layouts:
+        yield (linear_layout.shape[0], bins, linear_layout, ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0]))
+
+
+@pytest.mark.parametrize("M, bins, src_layout, dst_layout", _histogram_cases())
 def test_histogram(M, bins, src_layout, dst_layout, device):
 
     @gluon.jit
@@ -254,18 +305,6 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, B: ttgl.constexpr, src_layout: ttgl.
         z_offs = ttgl.arange(0, B, layout=dst_layout)
         ttgl.store(z_ptr + z_offs, h)
 
-    if src_layout == "linear_layout":
-        if M == 32:
-            src_layout = ttgl.DistributedLinearLayout(
-                reg_bases=[],
-                lane_bases=[[0], [16], [4], [2], [1]] + [[0]] * (THREADS_PER_WARP >> 6),
-                warp_bases=[[0], [8]],
-                block_bases=[],
-                shape=(M, ),
-            )
-        else:
-            pytest.skip("Linear layout is specialized for 32 elements")
-
     torch.manual_seed(0)
     x = torch.randint(0, bins, (M, ), dtype=torch.int32, device=device)
     z = torch.zeros((bins, ), dtype=torch.int32, device=device)
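
For reference, the new helpers rely on the same invariant that the removed inline skip checked: a DistributedLinearLayout must supply exactly one lane basis per bit of the lane index, so 2 ** len(lane_bases) == THREADS_PER_WARP. A minimal sketch of that check, not part of the diff, with a hypothetical helper name used only for illustration:

# Sketch only; _layout_matches_warp is a hypothetical name.
def _layout_matches_warp(layout, threads_per_warp):
    # Each lane basis vector consumes one bit of the lane id, so the number of
    # lane bases determines the warp size the layout was built for.
    return threads_per_warp == (1 << len(layout.lane_bases))

# Usage against the helpers introduced above:
assert all(_layout_matches_warp(layout, THREADS_PER_WARP) for layout in _reduce_linear_layouts())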