Commit 0859efb

[ET-VK] Don't specify memory layouts when testing
Differential Revision: D67180897
Pull Request resolved: #7322

1 parent: 0b1c1e5

File tree: 3 files changed (+43, −86 lines)

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 8 additions & 9 deletions
@@ -23,8 +23,6 @@
 
 from executorch.exir.pass_base import ExportPass, PassResult
 
-from torch._subclasses.fake_tensor import FakeTensor
-
 from torch.fx.passes.tools_common import NodeList
 from torch.fx.passes.utils.fuser_utils import topo_sort
 
@@ -138,9 +136,7 @@ def propose_node_storage(
             return storage
 
         for arg in node.args:
-            if isinstance(arg, torch.fx.Node) and isinstance(
-                arg.meta["val"], FakeTensor
-            ):
+            if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
                 storage = utils.get_node_storage_type(arg)
                 if storage is not None and storage in valid_storage_types:
                     return storage
@@ -178,9 +174,7 @@ def propose_node_layout(
             return layout
 
         for arg in node.args:
-            if isinstance(arg, torch.fx.Node) and isinstance(
-                arg.meta["val"], FakeTensor
-            ):
+            if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
                 layout = utils.get_node_memory_layout(arg)
                 if layout is not None and layout in valid_layouts:
                     return layout
@@ -202,14 +196,19 @@ def should_annotate(self, node) -> bool:
         if not isinstance(node, torch.fx.Node):
             return False
 
-        if not isinstance(node.meta["val"], FakeTensor):
+        if not utils.is_tensor_node(node):
             return False
 
         # Storage type and memory layout for tensorref will be determined at runtime
         # so there's no use in setting those attributes ahead of time.
         if node.meta.get("vkdg_tensorref", False):
             return False
 
+        # Skip annotating output node. The output tensors should be annotated by the
+        # time the output node is observed.
+        if node.op == "output":
+            return False
+
         return True
 
     def should_delay_annotation(self, node: torch.fx.Node) -> bool:
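
Note on the helper used above: the diff swaps direct isinstance(node.meta["val"], FakeTensor) checks for utils.is_tensor_node(node), a predicate that presumably also accepts nodes whose "val" metadata is a collection of tensors (such as the graph's output node, which this commit now skips explicitly). The following is a minimal sketch of what such a predicate could look like; it is an illustrative reimplementation under that assumption, not the actual executorch.backends.vulkan.utils code.

# Hypothetical sketch of an is_tensor_node-style predicate (assumption, not
# the real executorch.backends.vulkan.utils implementation).
import torch
from torch._subclasses.fake_tensor import FakeTensor

def is_tensor_node(node: torch.fx.Node) -> bool:
    # Nodes without "val" metadata (e.g. non-tensor constants) are not tensor nodes.
    val = node.meta.get("val", None)
    if val is None:
        return False
    # Single-tensor nodes store one FakeTensor in "val".
    if isinstance(val, FakeTensor):
        return True
    # Multi-output nodes (including the graph's output node) may store a
    # list or tuple of FakeTensors in "val".
    if isinstance(val, (list, tuple)):
        return len(val) > 0 and all(isinstance(v, FakeTensor) for v in val)
    return False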
