
Commit f0b6c06

[ET-VK] Introduce memory metadata tagging pass
## Context

As title; implements the memory metadata tagging graph transform described in the dependent diff. See the comments for more details.

Differential Revision: [D65428842](https://our.internmc.facebook.com/intern/diff/D65428842/)

[ghstack-poisoned]
1 parent ba4bb54 commit f0b6c06
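
For orientation, here is a minimal usage sketch of the new pass. This is illustrative only and not part of the commit; it assumes a lowered torch.fx.GraphModule (with prepack nodes already inserted) is in hand.

    import torch

    from executorch.backends.vulkan._passes import TagMemoryMetaPass

    def tag_memory_metadata(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # ExportPass instances are callable: invoking one runs `call` and
        # returns a PassResult wrapping the processed graph module.
        result = TagMemoryMetaPass()(graph_module)
        return result.graph_module

After the pass runs, every FakeTensor-producing node carries vk_storage_type and vk_memory_layout annotations on its spec, with clone transitions inserted wherever a consumer needs a different (storage type, memory layout) setting.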

File tree

11 files changed: +339 -15 lines changed

backends/vulkan/_passes/TARGETS

Lines changed: 23 additions & 7 deletions

@@ -16,6 +16,20 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "int4_weight_only_quantizer",
+    srcs = [
+        "int4_weight_only_quantizer.py",
+    ],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//executorch/backends/vulkan:custom_ops_lib",
+        "//pytorch/ao:torchao",
+    ]
+)
+
 runtime.python_library(
     name = "remove_local_scalar_dense",
     srcs = ["remove_local_scalar_dense_ops.py"],
@@ -30,17 +44,18 @@ runtime.python_library(
 )
 
 runtime.python_library(
-    name = "int4_weight_only_quantizer",
-    srcs = [
-        "int4_weight_only_quantizer.py",
-    ],
+    name = "tag_memory_meta_pass",
+    srcs = ["tag_memory_meta_pass.py"],
     visibility = [
         "//executorch/backends/...",
     ],
     deps = [
-        "//executorch/backends/vulkan:custom_ops_lib",
-        "//pytorch/ao:torchao",
-    ]
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+        "//executorch/backends/vulkan:utils_lib",
+        "//executorch/backends/vulkan/serialization:lib",
+    ],
 )
 
 runtime.python_library(
@@ -56,5 +71,6 @@ runtime.python_library(
         ":insert_prepack_nodes",
         ":int4_weight_only_quantizer",
         ":remove_local_scalar_dense",
+        ":tag_memory_meta_pass"
     ]
 )

backends/vulkan/_passes/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -5,9 +5,11 @@
 from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
     RemoveLocalScalarDenseOpsTransform,
 )
+from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
 
 __all__ = [
     "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
     "RemoveLocalScalarDenseOpsTransform",
+    "TagMemoryMetaPass",
 ]

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 231 additions & 0 deletions

@@ -0,0 +1,231 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from copy import deepcopy
+from typing import Set
+
+import executorch.backends.vulkan.utils as utils
+
+import torch
+
+from executorch.backends.vulkan.op_registry import get_op_features, has_impl
+
+from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
+    VkMemoryLayout,
+    VkStorageType,
+)
+
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.pass_base import ExportPass, PassResult
+
+from torch._subclasses.fake_tensor import FakeTensor
+
+from torch.fx.passes.tools_common import NodeList
+from torch.fx.passes.utils.fuser_utils import topo_sort
+
+logger: logging.Logger = logging.getLogger("")
+logger.setLevel(logging.INFO)
+
+
+def set_memory_metadata(
+    node: torch.fx.Node, storage: VkStorageType, layout: VkMemoryLayout
+) -> None:
+    utils.set_node_spec_attr(node, "vk_storage_type", storage)
+    utils.set_node_spec_attr(node, "vk_memory_layout", layout)
+
+
+class TagMemoryMetaPass(ExportPass):
+    """
+    There are a variety of ways that tensors can be represented in Vulkan. The two
+    main descriptors for how a tensor is laid out in memory are:
+
+    1. Storage Type (buffer or texture)
+    2. Memory Layout (which dim is packed along a texel / has a stride of 1, etc.)
+
+    Due to the differences between buffers and textures, and the differences between
+    different memory layouts, an implementation for an operator may only support a
+    specific set of (storage type, memory layout) combinations.
+
+    Furthermore, if an operator implementation supports multiple (storage type,
+    memory layout) combinations, there may be a "preferred" setting which results
+    in optimal performance.
+
+    This pass is responsible for ensuring that every tensor participating in an
+    operator call has a valid/optimal (storage type, memory layout) setting, and
+    for inserting transition operators to transfer input tensors to the correct
+    memory settings when necessary.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.default_storage: VkStorageType = VkStorageType.DEFAULT_STORAGE
+        self.default_layout: VkMemoryLayout = VkMemoryLayout.DEFAULT_LAYOUT
+        self.texture_limits = (16384, 16384, 2048)
+
+    def propose_node_storage(
+        self,
+        node: torch.fx.Node,
+    ) -> VkStorageType:
+        """
+        Uses the operator registry to determine the storage type that should be used
+        for a given node. The storage type is determined with the following priority:
+        1. In some cases, a tensor involved in the computation may be too large to be
+           represented as a texture. If this is the case, the node is "opinionated"
+           and buffer representation must be used.
+        2. If the operator called by the node indicates an optimal storage type, or
+           only supports a single storage type, use that storage type. If either is
+           true, then the node is considered to be opinionated as well. If multiple
+           storage types are supported and no preferred storage type is indicated,
+           then the node is not opinionated; go to the next step.
+        3. If the node's arguments already have memory metadata annotations, then
+           preserve the settings of the first argument. Otherwise, proceed to the
+           next step.
+        4. Recursively search the node's uses to see if any subsequent uses are
+           opinionated; inherit the settings of the first opinionated node. If no
+           opinionated user can be found, then proceed to the last step.
+        5. Use the default storage type setting.
+        """
+        # The node may have an input/output tensor that is too big to be stored in a
+        # texture. In this case, buffer storage must be used. Note that the partitioner
+        # has already checked for the fact that buffer storage is supported by the
+        # operator.
+        if len(utils.possible_node_memory_layouts(node, self.texture_limits)) == 0:
+            return VkStorageType.BUFFER
+
+        valid_storage_types: Set[VkStorageType] = utils.all_storage_types
+
+        # pyre-ignore
+        if has_impl(node.target):
+            # pyre-ignore
+            features = get_op_features(node.target)
+            valid_storage_types = features.supported_storage_types()
+            storage = features.propose_storage_type()
+            if storage is not None:
+                return storage
+
+        for arg in node.args:
+            if isinstance(arg, torch.fx.Node) and isinstance(
+                arg.meta["val"], FakeTensor
+            ):
+                storage = utils.get_node_storage_type(arg)
+                if storage is not None and storage in valid_storage_types:
+                    return storage
+
+        # If no storage type has been resolved yet, assume the optimal storage type of
+        # the first opinionated user. This search is recursive.
+        for user in node.users:
+            optimal_storage = self.propose_node_storage(user)
+            if optimal_storage is not None:
+                return optimal_storage
+
+        if self.default_storage in valid_storage_types:
+            return self.default_storage
+        else:
+            return next(iter(valid_storage_types))
+
+    def propose_node_layout(
+        self,
+        node: torch.fx.Node,
+        storage: VkStorageType,
+    ) -> VkMemoryLayout:
+        """
+        Performs the same steps as propose_node_storage, but detects the memory layout
+        that should be used for the specific storage type. The same prioritization logic
+        is applied.
+        """
+        valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts
+        # pyre-ignore
+        if has_impl(node.target):
+            # pyre-ignore
+            features = get_op_features(node.target)
+            valid_layouts = features.supported_memory_layouts(storage)
+            layout = features.propose_memory_layout(storage)
+            if layout is not None:
+                return layout
+
+        for arg in node.args:
+            if isinstance(arg, torch.fx.Node) and isinstance(
+                arg.meta["val"], FakeTensor
+            ):
+                layout = utils.get_node_memory_layout(arg)
+                if layout is not None and layout in valid_layouts:
+                    return layout
+
+        # If no memory layout has been resolved yet, assume the optimal layout of
+        # the first opinionated user. This search is recursive.
+        for user in node.users:
+            optimal_layout = self.propose_node_layout(user, storage)
+            if optimal_layout is not None:
+                return optimal_layout
+
+        # As a last resort, return the default memory layout that should be used.
+        if self.default_layout in valid_layouts:
+            return self.default_layout
+        else:
+            return next(iter(valid_layouts))
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))
+
+        for node in sorted_nodes:
+            if not isinstance(node.meta["val"], FakeTensor):
+                continue
+
+            if node.target == exir_ops.edge.et_vk.prepack.default:
+                continue
+
+            storage = self.propose_node_storage(node)
+            layout = self.propose_node_layout(node, storage)
+
+            set_memory_metadata(node, storage, layout)
+
+            inserting_transitions_for_node = False
+            for i, arg in enumerate(node.args):
+                if not isinstance(arg, torch.fx.Node):
+                    continue
+                if not isinstance(arg.meta["val"], FakeTensor):
+                    continue
+
+                arg_storage = utils.get_node_storage_type(arg)
+                arg_layout = utils.get_node_memory_layout(arg)
+
+                if arg_storage is None:
+                    utils.set_node_spec_attr(arg, "vk_storage_type", storage)
+                    arg_storage = storage
+                if arg_layout is None:
+                    utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
+                    arg_layout = layout
+
+                if arg_storage == storage and arg_layout == layout:
+                    continue
+
+                if not inserting_transitions_for_node:
+                    inserting_transitions_for_node = True
+                    logger.info(
+                        f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
+                    )
+
+                logger.info(
+                    f"  arg {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
+                )
+
+                # Insert a clone node to copy the original tensor to a tensor with the
+                # desired storage type and memory layout.
+                with graph_module.graph.inserting_before(node):
+                    clone_node = graph_module.graph.create_node(
+                        "call_function",
+                        exir_ops.edge.aten.clone.default,
+                        (arg,),
+                    )
+                    clone_node.meta["val"] = arg.meta["val"]
+                    clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
+                    clone_node.meta["spec"].const = False
+                    set_memory_metadata(clone_node, storage, layout)
+                    arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
+
+        return PassResult(graph_module, True)
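
The heart of `call` above is the targeted clone insertion: a clone of the argument is created immediately before the consuming node, tagged with the consumer's (storage, layout) preference, and swapped in for that single use only. Here is the same torch.fx pattern in isolation, runnable without the Vulkan backend (the traced function is a made-up toy):

    import operator

    import torch
    import torch.fx

    def f(x):
        y = torch.relu(x)
        return y + y

    gm = torch.fx.symbolic_trace(f)

    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is operator.add:
            src = node.args[0]
            # Insert a clone just before the consumer, then rewrite only this
            # consumer's use of `src` to route through the clone. The callback
            # limits the rewrite to that single user, mirroring the pass's
            # `lambda x, y=node: x == y`.
            with gm.graph.inserting_before(node):
                clone = gm.graph.call_function(torch.clone, (src,))
            src.replace_all_uses_with(clone, delete_user_cb=lambda u, n=node: u is n)

    gm.recompile()

Note that the callback also sees the freshly created clone node (which is itself a user of `src`) and returns False for it; that is what keeps the clone from being rewired to consume itself.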

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 5 additions & 3 deletions

@@ -95,9 +95,11 @@ def op_node_is_compatible(
         # If there are no valid texture memory layouts, then buffer storage must be
         # supported by the operator implementation.
         if len(valid_texture_layouts) == 0:
-            # TODO: once memory metadata tagging pass is implemented, check that the
-            # op impl supports buffers instead
-            return False, "requires buffer representation"
+            compatible = VkStorageType.BUFFER in op_features.supported_storage_types()
+            reason = "op is compatible"
+            if not compatible:
+                reason = "op requires buffers which is not supported by op impl"
+            return compatible, reason
 
         op_available_layouts = op_features.supported_memory_layouts(
             VkStorageType.TEXTURE_3D
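
Put differently, an op whose tensors fit in no texture layout is no longer rejected outright: it stays partitionable whenever its implementation also supports buffer storage, since the new tagging pass can assign buffer representations downstream. A hypothetical one-line restatement of the new check (not code from this diff):

    def buffer_fallback_ok(op_features) -> bool:
        # op_features is the features object queried in the snippet above.
        return VkStorageType.BUFFER in op_features.supported_storage_types()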

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 16 additions & 0 deletions

@@ -12,6 +12,11 @@
 import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
 
 import torch
+
+from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
+    VkMemoryLayout,
+    VkStorageType,
+)
 from executorch.backends.vulkan.utils import (
     is_constant,
     is_get_attr_node,
@@ -169,6 +174,15 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
         if spec.mem_obj_id is not None:
             mem_obj_id = spec.mem_obj_id
 
+        storage_type = VkStorageType.DEFAULT_STORAGE
+        memory_layout = VkMemoryLayout.DEFAULT_LAYOUT
+        if hasattr(spec, "vk_storage_type"):
+            # pyre-ignore[16]
+            storage_type = spec.vk_storage_type
+        if hasattr(spec, "vk_memory_layout"):
+            # pyre-ignore[16]
+            memory_layout = spec.vk_memory_layout
+
         new_id = len(self.values)
         self.values.append(
             vk_graph_schema.VkValue(
@@ -177,6 +191,8 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
                     dims=spec.shape,
                     constant_id=constant_id,
                     mem_obj_id=mem_obj_id,
+                    storage_type=storage_type,
+                    memory_layout=memory_layout,
                 )
             )
         )
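
Since specs only carry vk_storage_type / vk_memory_layout after TagMemoryMetaPass has set them, the hasattr checks above fall back to the DEFAULT_* sentinels for untagged graphs. A behaviorally equivalent sketch of that fallback (spec being the TensorSpec from the surrounding method):

    storage_type = getattr(spec, "vk_storage_type", VkStorageType.DEFAULT_STORAGE)
    memory_layout = getattr(spec, "vk_memory_layout", VkMemoryLayout.DEFAULT_LAYOUT)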

backends/vulkan/serialization/vulkan_graph_schema.py

Lines changed: 6 additions & 0 deletions

@@ -37,13 +37,19 @@ class VkStorageType(IntEnum):
     TEXTURE_2D = 2
     DEFAULT_STORAGE = 255
 
+    def __str__(self) -> str:
+        return self.name
+
 
 class VkMemoryLayout(IntEnum):
     TENSOR_WIDTH_PACKED = 0
     TENSOR_HEIGHT_PACKED = 1
     TENSOR_CHANNELS_PACKED = 2
     DEFAULT_LAYOUT = 255
 
+    def __str__(self) -> str:
+        return self.name
+
 
 @dataclass
 class VkTensor:
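
The __str__ overrides are what keep the tagging pass's transition logs readable, since its f-strings interpolate these enums directly. A quick sanity check of the behavior, using only members visible in this file:

    from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
        VkMemoryLayout,
        VkStorageType,
    )

    # IntEnum's default str() renders either the qualified member name or,
    # on newer Python versions, the raw integer value; with the override,
    # logs show just the bare name.
    assert str(VkStorageType.DEFAULT_STORAGE) == "DEFAULT_STORAGE"
    assert str(VkMemoryLayout.TENSOR_WIDTH_PACKED) == "TENSOR_WIDTH_PACKED"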

backends/vulkan/targets.bzl

Lines changed: 2 additions & 0 deletions

@@ -223,6 +223,8 @@ def define_common_targets(is_fbcode = False):
         ],
         deps = [
             "//caffe2:torch",
+            "//executorch/exir:tensor",
+            "//executorch/backends/vulkan/serialization:lib",
         ]
     )
 