Commit f763776

SS-JIA authored and facebook-github-bot committed
Rewrite Memory Metadata Tagging Pass
Summary:

## Context

Operator implementations in the Vulkan delegate may require that input and output tensors use a specific representation. "Representation" here refers to a combination of storage type (buffer or texture) and memory layout (width, height, or channels packed).

The memory metadata tagging pass is responsible for marking each tensor in the graph with the appropriate representation to use. It is also responsible for inserting operators that transition argument tensors to a required/compatible representation when a mismatch is detected.

The pass uses the operator registry to determine which tensor representations are valid for the inputs and outputs of a given operator. When operators are registered, fields like `has_buffer_impl`, `texture_impl`, `optimal_storage`, etc. annotate which tensor representations a given operator supports. However, the current implementation of the operator registry and the memory metadata tagging pass assumes that all tensors participating in a given operator must use the same representation.

Recently added quantization and normalization operators break this assumption: their implementations require certain inputs/outputs to use specific tensor representations, which need not match the representations used by the other tensors participating in the op.

The goal of this diff is to introduce a more flexible way to express the tensor representation requirements of an operator, and to re-implement the memory metadata tagging pass so that it can account for individual inputs/outputs requiring a specific representation. **More specifically, this is required to unblock dynamic quantization, since some quantized operator implementations need scales/zeros to be contiguous buffers, regardless of the representation used for other tensors.**

## Changes

Introduce several utility classes to aid in expressing the possible representations of a tensor:

* `TensorRepr` represents a pair of storage type + memory layout, describing the representation to use for a single tensor.
* `TensorRepSet` represents the set of possible representations that may be used for a single tensor. This is needed because a given operator may support multiple representations.
* `OpRepSet` maintains the set of possible representations (i.e. `RepSet`s) for all tensors participating in an operator.

Please see the docstrings of these new classes for more context.

All functionality related to determining or checking tensor representations is now centered on the new `OpRepSet` class, which automatically maintains rules about which tensors in an operator must use the same representation, and provides utilities to constrain representation sets based on pre-existing input representations. `tag_memory_metadata_pass.py` has been rewritten to use the `OpRepSet` utility class.

Another consequence of these changes is a simpler way to register operator implementations. Instead of defining `texture_impl` and `buffer_impl` separately, registration now directly specifies which storage types are valid for inputs and outputs. Sync rules that require inputs/outputs to have the same representation are inferred.

Differential Revision: D79116560
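A minimal sketch of how these utility classes might fit together, based only on the summary above. The enum members, field names, and the `constrain_arg` helper are assumptions for illustration, not the actual delegate API:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict, FrozenSet

class StorageType(Enum):
    BUFFER = auto()
    TEXTURE = auto()

class MemoryLayout(Enum):
    WIDTH_PACKED = auto()
    HEIGHT_PACKED = auto()
    CHANNELS_PACKED = auto()

@dataclass(frozen=True)
class TensorRepr:
    """One concrete representation: a storage type + memory layout pair."""
    storage: StorageType
    layout: MemoryLayout

@dataclass(frozen=True)
class TensorRepSet:
    """The set of representations that are valid for a single tensor."""
    valid: FrozenSet[TensorRepr]

    def constrain(self, other: "TensorRepSet") -> "TensorRepSet":
        # Intersect with a pre-existing representation set, e.g. the one
        # already assigned to an input produced by an upstream operator.
        return TensorRepSet(self.valid & other.valid)

    def is_empty(self) -> bool:
        # An empty set means no shared representation exists, so a
        # transition operator would need to be inserted.
        return not self.valid

class OpRepSet:
    """Per-tensor representation sets for all tensors of one operator."""
    def __init__(self, arg_sets: Dict[str, TensorRepSet]) -> None:
        self.arg_sets = arg_sets

    def constrain_arg(self, name: str, existing: TensorRepSet) -> bool:
        # Returns True if the existing representation is compatible with
        # this operator's requirements; False signals that a representation
        # transition is needed for this argument.
        constrained = self.arg_sets[name].constrain(existing)
        if constrained.is_empty():
            return False
        self.arg_sets[name] = constrained
        return True
```

For example, a quantized operator whose `scales` argument must be a contiguous buffer would hold a single-element `TensorRepSet` for that argument; constraining it against a texture-backed input yields an empty set, signaling that a transition is required.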
1 parent 37e3003 commit f763776

File tree

8 files changed: +1368 −707 lines changed


backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -35,7 +35,7 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:

         # Mark that this node is going to be represented as a TensorRef type in the
         # Vulkan compute graph. This annotation is used in later graph passes.
-        node.meta["vkdg_tensorref"] = True
+        node.meta["etvk_tensorref"] = True

         # Get the list of node users that do not handle their own prepacking
         nodes_to_replace_input = []
```

backends/vulkan/_passes/remove_local_scalar_dense_ops.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -52,7 +52,7 @@ def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None:

     for user in node.users:
         if node_is_local_scalar_dense_chain(user):
-            node.meta["vkdg_is_scalar_tensor"] = True
+            node.meta["etvk_is_scalar_tensor"] = True


def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None:
@@ -74,7 +74,7 @@ def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node)
     if replace_node.args[0].meta["val"].numel() == 1:
         replace_node = replace_node.args[0]
     assert isinstance(replace_node, torch.fx.Node)
-    assert replace_node.meta.get("vkdg_is_scalar_tensor", True)
+    assert replace_node.meta.get("etvk_is_scalar_tensor", True)

     with graph.inserting_after(node):
         node.replace_all_uses_with(replace_node)
```
