2323
2424from executorch .exir .pass_base import ExportPass , PassResult
2525
26- from torch ._subclasses .fake_tensor import FakeTensor
27-
2826from torch .fx .passes .tools_common import NodeList
2927from torch .fx .passes .utils .fuser_utils import topo_sort
3028
@@ -138,9 +136,7 @@ def propose_node_storage(
138136 return storage
139137
140138 for arg in node .args :
141- if isinstance (arg , torch .fx .Node ) and isinstance (
142- arg .meta ["val" ], FakeTensor
143- ):
139+ if isinstance (arg , torch .fx .Node ) and utils .is_tensor_node (arg ):
144140 storage = utils .get_node_storage_type (arg )
145141 if storage is not None and storage in valid_storage_types :
146142 return storage
@@ -178,9 +174,7 @@ def propose_node_layout(
178174 return layout
179175
180176 for arg in node .args :
181- if isinstance (arg , torch .fx .Node ) and isinstance (
182- arg .meta ["val" ], FakeTensor
183- ):
177+ if isinstance (arg , torch .fx .Node ) and utils .is_tensor_node (arg ):
184178 layout = utils .get_node_memory_layout (arg )
185179 if layout is not None and layout in valid_layouts :
186180 return layout
@@ -202,14 +196,19 @@ def should_annotate(self, node) -> bool:
202196 if not isinstance (node , torch .fx .Node ):
203197 return False
204198
205- if not isinstance (node . meta [ "val" ], FakeTensor ):
199+ if not utils . is_tensor_node (node ):
206200 return False
207201
208202 # Storage type and memory layout for tensorref will be determined at runtime
209203 # so there's no use in setting those attributes ahead of time.
210204 if node .meta .get ("vkdg_tensorref" , False ):
211205 return False
212206
207+ # Skip annotating output node. The output tensors should be annotated by the
208+ # time the output node is observed.
209+ if node .op == "output" :
210+ return False
211+
213212 return True
214213
215214 def should_delay_annotation (self , node : torch .fx .Node ) -> bool :
0 commit comments