2323
2424from  executorch .exir .pass_base  import  ExportPass , PassResult 
2525
26- from  torch ._subclasses .fake_tensor  import  FakeTensor 
27- 
2826from  torch .fx .passes .tools_common  import  NodeList 
2927from  torch .fx .passes .utils .fuser_utils  import  topo_sort 
3028
@@ -138,9 +136,7 @@ def propose_node_storage(
138136                return  storage 
139137
140138        for  arg  in  node .args :
141-             if  isinstance (arg , torch .fx .Node ) and  isinstance (
142-                 arg .meta ["val" ], FakeTensor 
143-             ):
139+             if  isinstance (arg , torch .fx .Node ) and  utils .is_tensor_node (arg ):
144140                storage  =  utils .get_node_storage_type (arg )
145141                if  storage  is  not None  and  storage  in  valid_storage_types :
146142                    return  storage 
@@ -178,9 +174,7 @@ def propose_node_layout(
178174                return  layout 
179175
180176        for  arg  in  node .args :
181-             if  isinstance (arg , torch .fx .Node ) and  isinstance (
182-                 arg .meta ["val" ], FakeTensor 
183-             ):
177+             if  isinstance (arg , torch .fx .Node ) and  utils .is_tensor_node (arg ):
184178                layout  =  utils .get_node_memory_layout (arg )
185179                if  layout  is  not None  and  layout  in  valid_layouts :
186180                    return  layout 
@@ -202,14 +196,19 @@ def should_annotate(self, node) -> bool:
202196        if  not  isinstance (node , torch .fx .Node ):
203197            return  False 
204198
205-         if  not  isinstance (node . meta [ "val" ],  FakeTensor ):
199+         if  not  utils . is_tensor_node (node ):
206200            return  False 
207201
208202        # Storage type and memory layout for tensorref will be determined at runtime 
209203        # so there's no use in setting those attributes ahead of time. 
210204        if  node .meta .get ("vkdg_tensorref" , False ):
211205            return  False 
212206
207+         # Skip annotating output node. The output tensors should be annotated by the 
208+         # time the output node is observed. 
209+         if  node .op  ==  "output" :
210+             return  False 
211+ 
213212        return  True 
214213
215214    def  should_delay_annotation (self , node : torch .fx .Node ) ->  bool :
0 commit comments