@@ -209,7 +209,10 @@ def __repr__(self):
209209
210210 def get_text (self , enable_debug_info = False ):
211211 return str (
212- self .module .operation .get_asm (enable_debug_info = enable_debug_info )
212+ self .module .operation .get_asm (
213+ enable_debug_info = enable_debug_info ,
214+ large_elements_limit = 16 ,
215+ )
213216 )
214217
215218 @property
@@ -326,8 +329,24 @@ def _convert_q_dq_per_channel_args_to_list(
326329
327330def exported_program_to_mlir (
328331 exported_program : torch .export .ExportedProgram ,
332+ * ,
333+ ir_context : ir .Context | None = None ,
334+ _pre_lower_pass : (
335+ Callable [[torch .export .ExportedProgram ], None ] | None
336+ ) = None ,
329337) -> MlirLowered :
330- """Lower the exported program to MLIR."""
338+ """Lower the exported program to MLIR.
339+
340+ Args:
341+ exported_program: The exported program to lower.
342+ ir_context: The MLIR context to use. If not provided, a new context will be
343+ created.
344+ _pre_lower_pass: A function to run on exported program before lowering.
345+
346+ Returns:
347+ The lowered MLIR module, metadata, and weight tensors bundle from exported
348+ program.
349+ """
331350 exported_program = fx_infra .safe_run_decompositions (
332351 exported_program ,
333352 fx_infra .decomp .pre_lower_decomp (),
@@ -340,10 +359,16 @@ def exported_program_to_mlir(
340359 # Do not call run_decompositions after applying the passes.
341360 _convert_q_dq_per_channel_args_to_list (exported_program )
342361
343- with export_utils .create_ir_context () as context , ir .Location .unknown ():
362+ if _pre_lower_pass :
363+ _pre_lower_pass (exported_program )
364+
365+ if not ir_context :
366+ ir_context = export_utils .create_ir_context ()
367+
368+ with ir_context , ir .Location .unknown ():
344369
345370 module = ir .Module .create ()
346- lctx = LoweringContext (context , module )
371+ lctx = LoweringContext (ir_context , module )
347372 interpreter = LoweringInterpreter (exported_program .graph_module , lctx )
348373 ir_flat_inputs , export_flat_args , tensor_metas = _build_flat_inputs (
349374 exported_program
@@ -382,7 +407,6 @@ def exported_program_to_mlir(
382407
383408 main_func .attributes ["sym_visibility" ] = ir .StringAttr .get ("public" )
384409 temp_func .erase ()
385-
386410 module .operation .verify ()
387411
388412 input_signature = []
@@ -422,5 +446,5 @@ def exported_program_to_mlir(
422446 for tensor_meta in _get_output_metas (exported_program )
423447 ]
424448 return MlirLowered (
425- context , module , state_dict , input_signature , output_signature
449+ ir_context , module , state_dict , input_signature , output_signature
426450 )
0 commit comments