@@ -264,13 +264,16 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
264264    exported_program: The exported program to apply the pass. 
265265  """ 
266266
267+   is_modified  =  False 
268+ 
267269  def  in_i32 (x : int ):
268270    return  - 2147483648  <=  x  <=  2147483647 
269271
270272  def  to_int32 (x : torch .Tensor ):
271273    return  torch .ops .aten ._to_copy .default (x , dtype = torch .int32 )
272274
273275  def  rewrite_arange (node : torch .fx .Node ):
276+     nonlocal  is_modified 
274277    tensor_meta  =  node .meta .get ("tensor_meta" , None )
275278    if  not  tensor_meta :
276279      return 
@@ -282,12 +285,14 @@ def rewrite_arange(node: torch.fx.Node):
282285      return 
283286    op  =  node .target 
284287    node .target  =  lambda  * args , ** kwargs : to_int32 (op (* args , ** kwargs ))
288+     is_modified  =  True 
285289
286290  graph_module  =  exported_program .graph_module 
287291  for  node  in  graph_module .graph .nodes :
288292
289293    if  node .target  ==  torch .ops .aten .arange .start_step :
290294      rewrite_arange (node )
295+   return  is_modified 
291296
292297
293298# TODO(b/331481564) Make this a ai_edge_torch FX pass. 
@@ -351,9 +356,9 @@ def exported_program_to_mlir(
351356      exported_program ,
352357      fx_infra .decomp .pre_lower_decomp (),
353358  )
354-   _convert_i64_to_i32 (exported_program )
355-   # Run decompositions for retracing and cananicalization. 
356-   exported_program  =  fx_infra .safe_run_decompositions (exported_program , {})
359+   if   _convert_i64_to_i32 (exported_program ): 
360+      # Run decompositions for retracing and cananicalization, if modified . 
361+      exported_program  =  fx_infra .safe_run_decompositions (exported_program , {})
357362
358363  # Passes below mutate the exported program to a state not executable by torch. 
359364  # Do not call run_decompositions after applying the passes. 
0 commit comments