@@ -83,10 +83,34 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.cons
8383 torch .testing .assert_close (z_tri , z_ref )
8484
8585
86- @pytest .mark .parametrize ("M, N" , [[128 , 16 ], [32 , 128 ], [32 , 32 ], [16 , 16 ]])
87- @pytest .mark .parametrize (
88- "src_layout" ,
89- _filter_layouts ([
86+ def _reduce_linear_layouts ():
87+ if THREADS_PER_WARP == 32 :
88+ return [
89+ ttgl .DistributedLinearLayout (
90+ reg_bases = [[0 , 16 ], [1 , 0 ], [2 , 0 ], [4 , 0 ], [8 , 0 ], [16 , 0 ]],
91+ lane_bases = [[0 , 0 ], [0 , 1 ], [0 , 2 ], [0 , 4 ], [0 , 8 ]],
92+ warp_bases = [[32 , 0 ], [0 , 32 ]],
93+ block_bases = [],
94+ shape = [64 , 64 ],
95+ )
96+ ]
97+ elif THREADS_PER_WARP == 64 :
98+ return [
99+ ttgl .DistributedLinearLayout (
100+ reg_bases = [[0 , 16 ], [1 , 0 ], [2 , 0 ], [4 , 0 ], [8 , 0 ], [16 , 0 ]],
101+ lane_bases = [[0 , 0 ], [0 , 1 ], [0 , 2 ], [0 , 4 ], [0 , 8 ], [0 , 64 ]],
102+ warp_bases = [[32 , 0 ], [0 , 32 ]],
103+ block_bases = [],
104+ shape = [64 , 128 ],
105+ )
106+ ]
107+ else :
108+ raise RuntimeError (f"Unsupported THREADS_PER_WARP: { THREADS_PER_WARP } " )
109+
110+
111+ def _reduce_layouts ():
112+ shapes = [(128 , 16 ), (32 , 128 ), (32 , 32 ), (16 , 16 )]
113+ layouts = _filter_layouts ([
90114 # FIXME: Do not enable these tests until the SLPVectorizor problem with nvptx target has been resolved
91115 # SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 1, 4], [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2])),
92116 # SliceLayout(dim=0, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 4, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2])),
@@ -117,83 +141,50 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.cons
117141 ttgl .amd .AMDMFMALayout (version = 2 , warps_per_cta = [1 , 4 ], tiles_per_warp = [1 , 1 ], instr_shape = [32 , 32 ],
118142 transposed = True ),
119143 # TODO: AMDWMMA layouts
120- # WmmaLayout(version=1, warps_per_cta=[4, 1]),
121- # WmmaLayout(version=1, warps_per_cta=[1, 4]),
122144 ttgl .DotOperandLayout (
123- parent = ttgl .NVMMADistributedLayout (
124- version = [2 , 0 ],
125- warps_per_cta = [2 , 4 ],
126- ctas_per_cga = [1 , 1 ], #
127- cta_split_num = [1 , 1 ],
128- cta_order = [0 , 1 ],
129- instr_shape = [16 , 8 ],
130- ), #
131- operand_index = 1 ,
132- k_width = 8 ,
133- ),
145+ parent = ttgl .NVMMADistributedLayout (version = [2 , 0 ], warps_per_cta = [2 , 4 ], ctas_per_cga = [1 , 1 ],
146+ cta_split_num = [1 , 1 ], cta_order = [0 , 1 ], instr_shape = [16 , 8 ]),
147+ operand_index = 1 , k_width = 8 ),
134148 ttgl .DotOperandLayout (
135- parent = ttgl .NVMMADistributedLayout (
136- version = [3 , 0 ],
137- warps_per_cta = [8 , 1 ],
138- ctas_per_cga = [1 , 1 ], #
139- cta_split_num = [1 , 1 ],
140- cta_order = [1 , 0 ],
141- instr_shape = [16 , 32 , 16 ],
142- ), #
143- operand_index = 0 ,
144- k_width = 2 ,
145- ),
149+ parent = ttgl .NVMMADistributedLayout (version = [3 , 0 ], warps_per_cta = [8 , 1 ], ctas_per_cga = [1 , 1 ],
150+ cta_split_num = [1 , 1 ], cta_order = [1 , 0 ], instr_shape = [16 , 32 , 16 ]),
151+ operand_index = 0 , k_width = 2 ),
146152 ttgl .SliceLayout (
147- dim = 0 ,
148- parent = ttgl .NVMMADistributedLayout (
149- version = [2 , 0 ],
150- warps_per_cta = [4 , 1 , 1 ],
151- ctas_per_cga = [1 , 1 , 1 ], #
152- cta_split_num = [1 , 1 , 1 ],
153- cta_order = [2 , 1 , 0 ],
154- instr_shape = [1 , 16 , 8 ],
155- ),
156- ), #
153+ dim = 0 , parent = ttgl .NVMMADistributedLayout (version = [2 , 0 ], warps_per_cta = [4 , 1 , 1 ], ctas_per_cga = [1 , 1 , 1 ],
154+ cta_split_num = [1 , 1 , 1 ], cta_order = [2 , 1 ,
155+ 0 ], instr_shape = [1 , 16 , 8 ])),
157156 ttgl .SliceLayout (
158- dim = 1 ,
159- parent = ttgl .DotOperandLayout (
160- parent = ttgl .NVMMADistributedLayout (
161- version = [2 , 0 ],
162- warps_per_cta = [4 , 1 , 1 ],
163- ctas_per_cga = [1 , 1 , 1 ], #
164- cta_split_num = [1 , 1 , 1 ],
165- cta_order = [2 , 1 , 0 ],
166- instr_shape = [1 , 16 , 8 ],
167- ), #
168- operand_index = 1 ,
169- k_width = 2 ,
170- ),
171- ),
172- "linear_layout" ,
173- ]),
174- )
157+ dim = 1 , parent = ttgl .DotOperandLayout (
158+ parent = ttgl .NVMMADistributedLayout (version = [2 , 0 ], warps_per_cta = [4 , 1 , 1 ], ctas_per_cga = [1 , 1 , 1 ],
159+ cta_split_num = [1 , 1 , 1 ], cta_order = [2 , 1 , 0 ],
160+ instr_shape = [1 , 16 , 8 ]), operand_index = 1 , k_width = 2 )),
161+ ])
162+
163+ rets = []
164+ for (M , N ) in shapes :
165+ for layout in layouts :
166+ if isinstance (layout , (ttgl .amd .AMDMFMALayout , ttgl .NVMMADistributedLayout )):
167+ instr_shape = layout .instr_shape
168+ if M < instr_shape [0 ] or N < instr_shape [1 ]:
169+ continue
170+ rets .append ((M , N , layout ))
171+ return rets
172+
173+
174+ def _reduce_cases ():
175+ for layout in _reduce_linear_layouts ():
176+ yield (layout .shape [0 ], layout .shape [1 ], layout )
177+ for M , N , layout in _reduce_layouts ():
178+ yield (M , N , layout )
179+
180+
181+ @pytest .mark .parametrize ("M, N, src_layout" , _reduce_cases ())
175182@pytest .mark .parametrize ("axis" , [0 , 1 ])
176183@pytest .mark .parametrize ("epilogue_kind" , ["reduce1d" , "reduce2d" , "expand_reduce2d" ])
177184@pytest .mark .parametrize ("dtype_str, sanitize_overflow" , [("int32" , False ), ("int32" , True ), ("float32" , False ),
178185 ("float16" , False )])
179186@pytest .mark .parametrize ("reduce_op" , ["sum" , "max" ])
180187def test_reduce_layouts (M , N , src_layout , axis , epilogue_kind , dtype_str , sanitize_overflow , reduce_op , device ):
181- if src_layout == "linear_layout" :
182- src_layout = ttgl .DistributedLinearLayout (
183- reg_bases = [[0 , 16 ], [1 , 0 ], [2 , 0 ], [4 , 0 ], [8 , 0 ], [16 , 0 ]], #
184- lane_bases = [[0 , 0 ], [0 , 1 ], [0 , 2 ], [0 , 4 ], [0 , 8 ]], #
185- warp_bases = [[32 , 0 ], [0 , 32 ]],
186- block_bases = [],
187- shape = [M , N ],
188- )
189- if THREADS_PER_WARP != (1 << len (src_layout .lane_bases )):
190- pytest .skip (f"Skipping. This LinearLayout assumes { 1 << len (src_layout .lane_bases )} threads per warp" )
191- elif M < 64 or N < 64 :
192- pytest .skip (f"Skipping. This LinearLayout assumes M >= 64 and N >= 64, got M={ M } , N={ N } " )
193- if isinstance (src_layout ,
194- (ttgl .amd .AMDMFMALayout , ttgl .NVMMADistributedLayout )) and (M < src_layout .instr_shape [0 ]
195- or N < src_layout .instr_shape [1 ]):
196- pytest .skip ("Skipping because tensor shape is smaller than M(f)maLayout instr_shape" )
197188
198189 @gluon .jit
199190 def _add (a , b ):
@@ -341,9 +332,33 @@ def kernel(x_ptr, y_ptr, M: ttgl.constexpr, layout: ttgl.constexpr):
341332])
342333
343334
344- @pytest .mark .parametrize ("M, bins" , [[2048 , 2 ], [8 , 512 ], [32 , 32 ]])
345- @pytest .mark .parametrize ("src_layout" , [ttgl .BlockedLayout ([1 ], [THREADS_PER_WARP ], [4 ], [0 ]), "linear_layout" ])
346- @pytest .mark .parametrize ("dst_layout" , [ttgl .BlockedLayout ([1 ], [THREADS_PER_WARP ], [4 ], [0 ])])
335+ def _histogram_cases ():
336+ if THREADS_PER_WARP not in (32 , 64 ):
337+ raise RuntimeError (f"Unsupported THREADS_PER_WARP: { THREADS_PER_WARP } " )
338+
339+ m_bins = [(2048 , 2 ), (8 , 512 ), (32 , 32 )]
340+ layouts = [(ttgl .BlockedLayout ([1 ], [THREADS_PER_WARP ], [4 ],
341+ [0 ]), ttgl .BlockedLayout ([1 ], [THREADS_PER_WARP ], [4 ], [0 ]))]
342+ for m , bins in m_bins :
343+ for src_layout , dst_layout in layouts :
344+ yield (m , bins , src_layout , dst_layout )
345+ import math
346+
347+ linear_layouts = [(
348+ ttgl .DistributedLinearLayout (
349+ reg_bases = [[1 << (5 + i )] for i in range (int (math .log2 (m )) - 5 )],
350+ lane_bases = [[0 ], [16 ], [4 ], [2 ], [1 ]] + ([[0 ]] if THREADS_PER_WARP == 64 else []),
351+ warp_bases = [[0 ], [8 ]],
352+ block_bases = [],
353+ shape = (m , ),
354+ ),
355+ bins ,
356+ ) for (m , bins ) in m_bins if m >= 32 ]
357+ for linear_layout , bins in linear_layouts :
358+ yield (linear_layout .shape [0 ], bins , linear_layout , ttgl .BlockedLayout ([1 ], [THREADS_PER_WARP ], [4 ], [0 ]))
359+
360+
361+ @pytest .mark .parametrize ("M, bins, src_layout, dst_layout" , _histogram_cases ())
347362def test_histogram (M , bins , src_layout , dst_layout , device ):
348363
349364 @gluon .jit
@@ -355,18 +370,6 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, B: ttgl.constexpr, src_layout: ttgl.
355370 z_offs = ttgl .arange (0 , B , layout = dst_layout )
356371 ttgl .store (z_ptr + z_offs , h )
357372
358- if src_layout == "linear_layout" :
359- if M == 32 :
360- src_layout = ttgl .DistributedLinearLayout (
361- reg_bases = [],
362- lane_bases = [[0 ], [16 ], [4 ], [2 ], [1 ]] + [[0 ]] * (THREADS_PER_WARP >> 6 ),
363- warp_bases = [[0 ], [8 ]],
364- block_bases = [],
365- shape = (M , ),
366- )
367- else :
368- pytest .skip ("Linear layout is specialized for 32 elements" )
369-
370373 torch .manual_seed (0 )
371374 x = torch .randint (0 , bins , (M , ), dtype = torch .int32 , device = device )
372375 z = torch .zeros ((bins , ), dtype = torch .int32 , device = device )
0 commit comments