 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
     MemoryConfig,
-    model_is_quantized,
 )
 from executorch.devtools import generate_etrecord
 from executorch.exir import (
@@ -38,7 +37,6 @@
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from torch._inductor.decomposition import remove_decompositions
-from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

 from torch.export import export
@@ -158,26 +156,10 @@ def export_program(
 ) -> ExportedProgram:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

-    # We don't support training mode. Make the model inference mode by
-    # calling model.eval() or an equivalent call for quantized models.
-    # GraphModules cannot call eval(), so we skip them.
-    if not isinstance(model, torch.fx.GraphModule):
-        if hasattr(model, "eval"):
-            model.eval()
-    else:
-        # If the model is quantized, call the suggested torch.ao.quantization API
-        # which only does dropout and batchnorm.
-        if model_is_quantized(model):
-            torch.ao.quantization.move_exported_model_to_eval(model)
-        else:
-            # If we get a GraphModule which is _not_ quantized, then it should already
-            # have been exported.
-            assert model_is_exported(model), "model should be from an ExportedProgram"
-
     # Prevent mkldnn decompositions
     torch._C._set_mkldnn_enabled(False)

-    # else: capture the model and return it.
+    # Export the model and return it.
     expo_program = export(model, inputs, strict=True)

     if dump_graphs:
@@ -206,8 +188,8 @@ def export_to_edge(
             _skip_dim_order=True,
             # Allow specific non-core aten ops in the IR.
             _core_aten_ops_exception_list=[
+                torch.ops.aten._native_batch_norm_legit_functional.default,
                 torch.ops.aten.linear.default,
-                torch.ops.aten.native_batch_norm.default,
                 torch.ops.aten.linalg_vector_norm.default,
                 torch.ops.aten.unfold.default,
                 torch.ops.aten.angle.default,
@@ -226,10 +208,9 @@ def export_to_cadence(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
-    output_dir: Optional[str] = None,
     opt_level: int = 1,
 ) -> EdgeProgramManager:
-    edge_prog_manager = export_to_edge(model, inputs)
+    edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
     cadence_passes = get_cadence_passes(opt_level)

     # Run a couple required passes for quant/dequant ops
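
Usage note (a minimal sketch, not part of the patch): with the output_dir parameter removed and dump_graphs now forwarded to export_to_edge, a call to the updated export_to_cadence might look like the snippet below. The module path and the toy model are assumptions for illustration only.

    import torch
    from executorch.backends.cadence.aot.compiler import export_to_cadence  # import path assumed

    # Hypothetical toy model used purely to exercise the export entry point.
    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)

    # dump_graphs is threaded through to export_to_edge; output_dir is no longer accepted.
    edge_prog_manager = export_to_cadence(
        TinyModel(), (torch.randn(1, 16),), dump_graphs=True, opt_level=1
    )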