1212# See the License for the specific language governing permissions and 
1313# limitations under the License. 
1414# ============================================================================== 
import os
import tempfile

from ai_edge_torch import fx_infra
from ai_edge_torch import odml_torch
import numpy as np
import torch
from absl.testing import absltest as googletest
2122
2223
24+ def  _reset_from_node_meta_and_lower (ep : torch .export .ExportedProgram ):
25+   """Lower the exported program with canonical history stack.""" 
26+   ep  =  fx_infra .graph_utils .reset_from_node_meta (ep )
27+   return  odml_torch .export .exported_program_to_mlir (ep )
28+ 
29+ 
30+ def  _is_aten_op (node : torch .fx .Node ) ->  bool :
31+   return  node .op  ==  "call_function"  and  not  node .name .startswith ("getitem" )
32+ 
33+ 
34+ class  AddModel (torch .nn .Module ):
35+   """A simple model that does addition.""" 
36+ 
37+   def  forward (self , x , y ):
38+     return  x  +  y  +  x  +  y 
39+ 
40+ 
2341class  TensorflowIntegrationTest (googletest .TestCase ):
2442
2543  def  setUp (self ):
@@ -29,11 +47,6 @@ def setUp(self):
2947  def  test_mlir_lowered_call (self ):
3048    """Test a simple model with MLIR lowered call.""" 
3149
32-     class  AddModel (torch .nn .Module ):
33- 
34-       def  forward (self , x , y ):
35-         return  x  +  y  +  x  +  y 
36- 
3750    model  =  AddModel ().eval ()
3851    forward_args  =  lambda : (torch .rand ((10 , 10 )), torch .rand ((10 , 10 )))
3952    ep  =  torch .export .export (model , forward_args ())
@@ -53,8 +66,7 @@ def test_resnet18(self):
5366    forward_args  =  lambda : (torch .rand ((1 , 3 , 224 , 224 )),)
5467
5568    ep  =  torch .export .export (model , forward_args ())
56- 
57-     lowered  =  odml_torch .export .exported_program_to_mlir (ep )
69+     lowered  =  _reset_from_node_meta_and_lower (ep )
5870
5971    args  =  forward_args ()
6072    torch_output  =  model (* args ).detach ().numpy ()
@@ -70,7 +82,7 @@ def test_debuginfo_from_export_lower(self):
7082    forward_args  =  lambda : (torch .rand ((1 , 3 , 224 , 224 )),)
7183
7284    ep  =  torch .export .export (model , forward_args ())
73-     lowered  =  odml_torch . export . exported_program_to_mlir (ep )
85+     lowered  =  _reset_from_node_meta_and_lower (ep )
7486
7587    lowered_text  =  lowered .get_text (enable_debug_info = True )
7688    # Check the file info. 
@@ -79,10 +91,35 @@ def test_debuginfo_from_export_lower(self):
7991    for  n  in  ep .graph .nodes :
8092      # Record all aten op nodes from the original graph and check if they 
8193      # are lowered to the same name in the lowered graph. 
82-       if  n . op   ==   "call_function"   and   not   n . name . startswith ( "getitem" ):
94+       if  _is_aten_op ( n ):
8395        # Ensure strings like `loc("relu__1"` are present in the lowered text. 
8496        self .assertIn (f'loc("{ n .name }  "' , lowered_text )
8597
98+   def  test_debuginfo_from_loaded_reexport_lower (self ):
99+     """Test the debuginfo with loaded reexport lower.""" 
100+ 
101+     model  =  AddModel ().eval ()
102+     forward_args  =  lambda : (torch .rand ((10 , 10 )), torch .rand ((10 , 10 )))
103+ 
104+     # Ensure the debuginfo is preserved after saving, loading and reexporting. 
105+     ep  =  torch .export .export (model , forward_args ())
106+     torch .export .save (ep , "/tmp/add_model.pt2" )
107+     loaded_ep  =  torch .export .load ("/tmp/add_model.pt2" )
108+     reexported_ep  =  torch .export .export (loaded_ep .module (), forward_args ())
109+     lowered  =  _reset_from_node_meta_and_lower (reexported_ep )
110+ 
111+     lowered_text  =  lowered .get_text (enable_debug_info = True )
112+     # Check the file info. 
113+     self .assertIn (
114+         "ai_edge_torch/odml_torch/test/test_tf_integration.py" , lowered_text 
115+     )
116+     # Check the fx node names. 
117+     for  n  in  reexported_ep .graph .nodes :
118+       # Record all aten op nodes from the original graph and check if they 
119+       # are lowered to the same name in the lowered graph. 
120+       if  _is_aten_op (n ):
121+         self .assertIn (f'loc("{ n .name }  "' , lowered_text )
122+ 
86123
87124if  __name__  ==  "__main__" :
88125  googletest .main ()
0 commit comments