3838
3939tune_logger = logging .getLogger ("tune" )
4040
41- # TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'.
41+
4242def apply_configuration (
4343 template : list [str ],
4444 configuration : Configuration ,
45- workgroup_sizes : list [int ],
46- reduction_sizes : list [int ],
4745) -> str :
48- intrinsic = get_intrinsic (configuration )
49- subgroup_m_count = get_subgroup_m_count (configuration )
50- subgroup_n_count = get_subgroup_n_count (configuration )
46+ lowering_config = configuration .lowering_config
47+ intrinsic = lowering_config .mma_kind
48+ (
49+ subgroup_m_count ,
50+ subgroup_n_count ,
51+ ) = lowering_config .subgroup_count_mn
52+ workgroup_sizes = lowering_config .workgroup_tile_sizes
53+ reduction_sizes = lowering_config .reduction_tile_sizes
5154 tune_logger .info (f"Applying: { configuration } " )
5255 expr0 = re .compile (
5356 r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
@@ -125,9 +128,12 @@ class MmtTuner(DispatchTuner, MmtParser):
125128 def get_transform_function_mmt (
126129 self , problem_size : ProblemSize , functionName : str , configuration : Configuration
127130 ) -> str :
128- intrinsic = get_intrinsic (configuration )
129- subgroup_m_count = get_subgroup_m_count (configuration )
130- subgroup_n_count = get_subgroup_n_count (configuration )
131+ lowering_config = configuration .lowering_config
132+ intrinsic = lowering_config .mma_kind
133+ (
134+ subgroup_m_count ,
135+ subgroup_n_count ,
136+ ) = lowering_config .subgroup_count_mn
131137
132138 wg_x , wg_y , wg_z = configuration .workgroup_size
133139 extra_config = get_pipeline_config (configuration )
@@ -167,8 +173,6 @@ def apply_params(
167173 modified += apply_configuration (
168174 template ,
169175 configuration ,
170- get_mmt_workgroup_sizes (configuration ),
171- get_mmt_reduction_sizes (configuration ),
172176 )
173177 embeddable = indent (
174178 self .get_transform_function_mmt (problem_size , f"match_op" , configuration ),
@@ -193,15 +197,12 @@ def get_transform_function_conv(
193197 filter = f"tensor<{ problem_size .rhs_type } >"
194198 output = f"tensor<{ dynamic_batch_output_ty } >"
195199
196- workgroup_sizes = ", " .join (
197- map (str , self .get_conv_workgroup_sizes (configuration ))
198- )
199- reduction_sizes = ", " .join (
200- map (str , self .get_conv_reduction_sizes (configuration ))
201- )
202- intrinsic = get_intrinsic (configuration )
203- subgroup_m_count = get_subgroup_m_count (configuration )
204- subgroup_n_count = get_subgroup_n_count (configuration )
200+ lowering_config = configuration .lowering_config
201+ intrinsic = lowering_config .mma_kind
202+ (
203+ subgroup_m_count ,
204+ subgroup_n_count ,
205+ ) = lowering_config .subgroup_count_mn
205206
206207 wg_x , wg_y , wg_z = configuration .workgroup_size
207208 extra_config = get_pipeline_config (configuration )
@@ -246,8 +247,6 @@ def apply_params(
246247 modified += apply_configuration (
247248 template ,
248249 configuration ,
249- self .get_conv_workgroup_sizes (configuration ),
250- self .get_conv_reduction_sizes (configuration ),
251250 )
252251 embeddable = indent (
253252 self .get_transform_function_conv (problem_size , f"match_op" , configuration ),
@@ -263,15 +262,12 @@ def get_transform_function_broadcast_rhs_mmt(
263262 functionName : str ,
264263 configuration : Configuration ,
265264 ) -> str :
266- workgroup_sizes = ", " .join (
267- map (str , get_batch_mmt_workgroup_sizes (configuration ))
268- )
269- reduction_sizes = ", " .join (
270- map (str , get_batch_mmt_reduction_sizes (configuration ))
271- )
272- intrinsic = get_intrinsic (configuration )
273- subgroup_m_count = get_subgroup_m_count (configuration )
274- subgroup_n_count = get_subgroup_n_count (configuration )
265+ lowering_config = configuration .lowering_config
266+ intrinsic = lowering_config .mma_kind
267+ (
268+ subgroup_m_count ,
269+ subgroup_n_count ,
270+ ) = lowering_config .subgroup_count_mn
275271
276272 wg_x , wg_y , wg_z = configuration .workgroup_size
277273 extra_config = get_pipeline_config (configuration )
@@ -316,8 +312,6 @@ def apply_params_broadcast_rhs_mmt(
316312 modified += apply_configuration (
317313 template ,
318314 configuration ,
319- get_batch_mmt_workgroup_sizes (configuration ),
320- get_batch_mmt_reduction_sizes (configuration ),
321315 )
322316
323317 embeddable = indent (
@@ -345,8 +339,6 @@ def apply_params(
345339 apply_configuration (
346340 template ,
347341 configuration ,
348- get_contract_workgroup_sizes (configuration , self .tile_dims ),
349- get_contract_reduction_sizes (configuration , self .tile_dims ),
350342 ),
351343 "" ,
352344 )
@@ -359,9 +351,12 @@ def get_transform_function_batch_mmt(
359351 functionName : str ,
360352 configuration : Configuration ,
361353 ) -> str :
362- intrinsic = get_intrinsic (configuration )
363- subgroup_m_count = get_subgroup_m_count (configuration )
364- subgroup_n_count = get_subgroup_n_count (configuration )
354+ lowering_config = configuration .lowering_config
355+ intrinsic = lowering_config .mma_kind
356+ (
357+ subgroup_m_count ,
358+ subgroup_n_count ,
359+ ) = lowering_config .subgroup_count_mn
365360
366361 wg_x , wg_y , wg_z = configuration .workgroup_size
367362 extra_config = get_pipeline_config (configuration )
@@ -403,8 +398,6 @@ def apply_params(
403398 modified += apply_configuration (
404399 template ,
405400 configuration ,
406- get_batch_mmt_workgroup_sizes (configuration ),
407- get_batch_mmt_reduction_sizes (configuration ),
408401 )
409402
410403 embeddable = indent (
@@ -428,9 +421,12 @@ def get_transform_function_batch_matmul(
428421 input1 = f"tensor<{ problem_size .rhs_type } >"
429422 output = f"tensor<{ problem_size .res_type } >"
430423
431- intrinsic = get_intrinsic (configuration )
432- subgroup_m_count = get_subgroup_m_count (configuration )
433- subgroup_n_count = get_subgroup_n_count (configuration )
424+ lowering_config = configuration .lowering_config
425+ intrinsic = lowering_config .mma_kind
426+ (
427+ subgroup_m_count ,
428+ subgroup_n_count ,
429+ ) = lowering_config .subgroup_count_mn
434430
435431 wg_x , wg_y , wg_z = configuration .workgroup_size
436432 extra_config = get_pipeline_config (configuration )
@@ -476,8 +472,6 @@ def apply_params(
476472 modified += apply_configuration (
477473 template ,
478474 configuration ,
479- get_contract_workgroup_sizes (configuration , self .tile_dims ),
480- get_contract_reduction_sizes (configuration , self .tile_dims ),
481475 )
482476
483477 embeddable = indent (
0 commit comments