@@ -29,16 +29,32 @@ def _module_lowering(
2929 output_type ,
3030 torch_mod ,
3131 extra_library_file_name = None ,
32+ backend_legal_ops = None ,
3233):
3334
3435 if output_type == OutputType .RAW :
3536 if verbose :
3637 print (torch_mod )
3738 return torch_mod
3839 # TODO: pass extra_library_file_name by caller
40+
41+ backend_legal_op_arg_str = ""
42+ if backend_legal_ops is not None :
43+ if not len (backend_legal_ops ) == 0 :
44+ backend_legal_op_arg_str = "backend-legal-ops=" + "," .join (
45+ backend_legal_ops
46+ )
47+
3948 if extra_library_file_name is None :
4049 extra_library_file_name = ""
41- option_string = "{extra-library=" + extra_library_file_name + "}"
50+ option_string = (
51+ "{"
52+ + backend_legal_op_arg_str
53+ + " extra-library="
54+ + extra_library_file_name
55+ + "}"
56+ )
57+
4258 run_pipeline_with_repro_report (
4359 torch_mod ,
4460 f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{ option_string } )" ,
@@ -61,6 +77,7 @@ def export_and_import(
6177 func_name : str = "main" ,
6278 enable_graph_printing : bool = False ,
6379 enable_ir_printing : bool = False ,
80+ backend_legal_ops : Optional [list [str ]] = None ,
6481 ** kwargs ,
6582):
6683 context = ir .Context ()
@@ -98,7 +115,10 @@ def export_and_import(
98115 )
99116
100117 return _module_lowering (
101- enable_ir_printing , OutputType .get (output_type ), fx_importer .module
118+ enable_ir_printing ,
119+ OutputType .get (output_type ),
120+ fx_importer .module ,
121+ backend_legal_ops = backend_legal_ops ,
102122 )
103123
104124
@@ -110,6 +130,7 @@ def stateless_fx_import(
110130 model_name : str = "main" ,
111131 enable_graph_printing : bool = False ,
112132 enable_ir_printing : bool = False ,
133+ backend_legal_ops : Optional [list [str ]] = None ,
113134):
114135 if enable_graph_printing :
115136 gm .print_readable ()
@@ -119,5 +140,8 @@ def stateless_fx_import(
119140 fx_importer = FxImporter (context = context , hooks = hooks )
120141 fx_importer .import_stateless_graph (gm .graph , func_name = model_name )
121142 return _module_lowering (
122- enable_ir_printing , OutputType .get (output_type ), fx_importer .module
143+ enable_ir_printing ,
144+ OutputType .get (output_type ),
145+ fx_importer .module ,
146+ backend_legal_ops = backend_legal_ops ,
123147 )
0 commit comments