@@ -39,7 +39,6 @@ class DispatchTuner(DispatchParser):
3939 @abstractmethod
4040 def get_td_spec (
4141 self ,
42- ir_module : ir .Module ,
4342 compilation_info : iree_codegen .CompilationInfoAttr ,
4443 ) -> ir .Module :
4544 """Generate a transform dialect spec that applies the compilation info attr."""
@@ -62,12 +61,14 @@ def find_handler(self, op_name: str) -> DispatchTuner:
6261
6362
6463class ContractionOpInterfaceTuner (DispatchTuner , ContractionOpInterfaceParser ):
64+ def __init__ (self , root_op : ir .Operation ):
65+ super ().__init__ (root_op )
66+
6567 def get_td_spec (
6668 self ,
67- ir_module : ir .Module ,
6869 compilation_info : iree_codegen .CompilationInfoAttr ,
6970 ) -> ir .Module :
70- contraction_op: ir.Operation = self.get_contraction_operation(ir_module)
71+ contraction_op = self.get_root_op()
7172 lhs_type = ir .ShapedType (contraction_op .operands [0 ].type )
7273 rhs_type = ir .ShapedType (contraction_op .operands [1 ].type )
7374 acc_type = ir .ShapedType (contraction_op .operands [2 ].type )
@@ -77,17 +78,16 @@ def get_td_spec(
7778 # TODO(Max191): Get the function name from the func.func in the input module.
7879 func_name = f"match_contraction_{M}x{N}x{K}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}"
7980 return build_td_spec (
80- ir_module .context , contraction_op , compilation_info , func_name
81+ contraction_op .context , contraction_op , compilation_info , func_name
8182 )
8283
8384
8485class ConvolutionOpInterfaceTuner (DispatchTuner , ConvolutionOpInterfaceParser ):
8586 def get_td_spec (
8687 self ,
87- ir_module : ir .Module ,
8888 compilation_info : iree_codegen .CompilationInfoAttr ,
8989 ) -> ir .Module :
90- conv_op: ir.Operation = self.get_conv_operation(ir_module)
90+ conv_op = self.get_root_op()
9191 assert (
9292 conv_op .name == "linalg.conv_2d_nhwc_hwcf"
9393 ), "expected linalg.conv_2d_nhwc_hwcf"
@@ -104,7 +104,7 @@ def get_td_spec(
104104 conv_type = conv_op .name .split ("." )[- 1 ]
105105 # TODO(Max191): Get the function name from the func.func in the input module.
106106 func_name = f"match_{conv_type}_{N}x{H}x{W}x{C}x{P}x{Q}x{F}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}"
107- return build_td_spec (ir_module .context , conv_op , compilation_info , func_name )
107+ return build_td_spec (conv_op .context , conv_op , compilation_info , func_name )
108108
109109
110110@dataclass
@@ -156,21 +156,33 @@ def generate_configs_and_td_specs(
156156 pipeline_options_search_space : PipelineOptionsSearchSpace = PipelineOptionsSearchSpace (),
157157 codegen_pipeline : iree_codegen .DispatchLoweringPassPipeline = iree_codegen .DispatchLoweringPassPipeline .LLVMGPUVectorDistribute ,
158158) -> list [ir .Module ]:
159- dispatch_tuner_registry = DispatchTunerRegistry ()
160- dispatch_tuner_registry .register (
161- [
162- ContractionOpInterfaceTuner (),
163- ConvolutionOpInterfaceTuner (),
164- ]
165- )
159+ dispatch_tuners : list [type [DispatchTuner ]] = [
160+ ContractionOpInterfaceTuner ,
161+ ConvolutionOpInterfaceTuner ,
162+ ]
163+
164+ root_op_list = iree_codegen .get_tuner_root_ops (input_module )
165+ if len (root_op_list ) == 0 :
166+ tune_logger .error (
167+ "No root ops found. Did you forget to pass "
168+ "--iree-config-add-tuner-attributes during compilation?"
169+ )
170+ return []
171+ elif len (root_op_list ) > 1 :
172+ tune_logger .error ("Multiple root ops found. Only one is currently supported." )
173+ return []
166174
167- walk_result : OpWalkResult = walk_mlir_op (input_module , dispatch_tuner_registry )
175+ root_op = root_op_list [0 ]
176+
177+ dispatch_tuner : Optional [DispatchTuner ] = None
178+ for tuner_class in dispatch_tuners :
179+ tuner = tuner_class (root_op )
180+ if tuner .has_valid_root_op ():
181+ dispatch_tuner = tuner
182+ break
168183
169- dispatch_tuner = walk_result .dispatch_tuner
170184 assert dispatch_tuner , "No suitable dispatch tuner found"
171- problem_size: ProblemSize = dispatch_tuner.get_shapes(
172- str(input_module).splitlines()
173- )
185+ problem_size: ProblemSize = dispatch_tuner.get_problem_size()
174186 tune_logger .debug (str (problem_size ))
175187
176188 # Index 0 is reserved for default config, so it gets a placeholder spec.
@@ -196,7 +208,7 @@ def generate_configs_and_td_specs(
196208 if i >= limit :
197209 break
198210 tune_logger .debug (f"Solution #{ i + 1 } : { config } " )
199- td_spec_module = dispatch_tuner .get_td_spec (input_module , config )
211+ td_spec_module = dispatch_tuner .get_td_spec (config )
200212 assert td_spec_module , "Failed to generate transform dialect spec"
201213 config_specs .append (td_spec_module )
202214
@@ -263,7 +275,7 @@ def run_command(run_pack: RunPack) -> RunResult:
263275# info makes the inputs to compilation consistent, and allows for overwriting
264276# the compilation info with generated TD specs during codegen.
265277def strip_root_op_attr (module : ir .Module ):
266- root_ops : list [ir .Operation ] = get_ops_from_module (module , is_root_op )
278+ root_ops : list [ir .Operation ] = iree_codegen . get_tuner_root_ops (module )
267279 for root_op in root_ops :
268280 assert (
269281 ROOT_OP_ATTR_NAME in root_op .opview .attributes