Commit f763776
Rewrite Memory Metadata Tagging Pass
Summary:
## Context
Operator implementations in the Vulkan delegate may require that input and output tensors use a specific representation. Representation in this case refers to a combination of storage type (buffer or texture) and memory layout (width, height, or channels packed).
The tag memory metadata pass is responsible for marking each tensor in the graph with the appropriate representation to use. It is also responsible for inserting operators to transition argument tensors to a required/compatible representation if a mismatch has been detected.
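As a rough illustration of the two responsibilities above, the sketch below tags each tensor with a representation and records a transition wherever a consumer requires something different. The graph encoding (a list of `(op, inputs, output)` tuples), the `REQUIRED` table, and the `tag_and_transition` helper are all hypothetical and far simpler than the delegate's actual pass.

```python
from typing import Dict, List, Tuple

# A representation is modeled as (storage type, memory layout); the strings are illustrative.
Repr = Tuple[str, str]

# Hypothetical requirements: the single representation each op needs for its tensors.
REQUIRED: Dict[str, Repr] = {
    "conv2d": ("texture", "channels_packed"),
    "linear": ("buffer", "width_packed"),
}

Node = Tuple[str, List[str], str]  # (op name, input tensor names, output tensor name)


def tag_and_transition(nodes: List[Node], graph_inputs: Dict[str, Repr]):
    """Tag every tensor and list the transitions needed to satisfy each op."""
    tags: Dict[str, Repr] = dict(graph_inputs)
    transitions: List[Tuple[str, Repr, Repr, str]] = []
    for op, inputs, output in nodes:
        needed = REQUIRED[op]
        for name in inputs:
            # Untagged inputs simply adopt what the op needs.
            current = tags.setdefault(name, needed)
            if current != needed:
                # Mismatch: a real pass would insert a transition operator here.
                transitions.append((name, current, needed, op))
        tags[output] = needed
    return tags, transitions


nodes = [("conv2d", ["x"], "y"), ("linear", ["y"], "z")]
tags, transitions = tag_and_transition(nodes, {"x": ("texture", "channels_packed")})
print(tags["z"])
print(transitions)
```

Here `y` is produced as a channels-packed texture but consumed by an op that needs a width-packed buffer, so exactly one transition is recorded.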
The memory metadata tagging pass uses the operator registry to determine what tensor representations are valid for the inputs and outputs of a given operator. When operators are registered, fields like `has_buffer_impl`, `texture_impl`, `optimal_storage`, etc. are used to annotate what tensor representations are supported by a given operator.
However, the current implementation of the operator registry and the memory metadata tagging pass assumes that all tensors participating in a given operator must use the same representation. As of late, quantization and normalization operators have been added that break this assumption; their implementations require certain inputs/outputs to use specific tensor representations, which do not need to be the same as other tensors participating in the op.
The goal of this diff is to introduce a more flexible way to express the tensor representation requirements of an operator, and to re-implement the memory metadata tagging pass so that it can account for operators in which certain input/output tensors require a specific representation.
**More specifically, this is required to unblock dynamic quantization since some quantized operator implementations need scales/zeros to be contiguous buffers, regardless of the representation used for other tensors.**
## Changes
Introduce several utility classes to aid in expressing the possible representations of a tensor.
`TensorRepr` represents a pair of storage type + memory layout which describes the representation to use for a single tensor.
`TensorRepSet` represents the set of possible representations that may be used for a single tensor. This is needed because a given operator may support multiple different representations.
`OpRepSet` maintains the set of possible representations (i.e. `TensorRepSet`s) for all tensors participating in an operator.
Please see the docstrings for these new classes for more context.
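To make the relationship between the three classes concrete, here is a minimal sketch that reuses their names as toy dataclasses. The fields, the `intersect` and `is_empty` methods, and the `rep_set` helper are assumptions made for illustration; the real definitions and their docstrings in the diff are authoritative.

```python
from dataclasses import dataclass
from typing import Dict, FrozenSet


@dataclass(frozen=True)
class TensorRepr:
    """One concrete choice: a storage type paired with a memory layout."""
    storage: str  # e.g. "buffer" or "texture"
    layout: str   # e.g. "width_packed", "height_packed", "channels_packed"


@dataclass(frozen=True)
class TensorRepSet:
    """Every representation that is acceptable for a single tensor."""
    reprs: FrozenSet[TensorRepr]

    def intersect(self, other: "TensorRepSet") -> "TensorRepSet":
        return TensorRepSet(self.reprs & other.reprs)

    def is_empty(self) -> bool:
        return not self.reprs


@dataclass
class OpRepSet:
    """Acceptable representations for each tensor taking part in one operator."""
    per_tensor: Dict[str, TensorRepSet]


def rep_set(*pairs) -> TensorRepSet:
    """Hypothetical convenience constructor from (storage, layout) pairs."""
    return TensorRepSet(frozenset(TensorRepr(s, l) for s, l in pairs))


any_width = rep_set(("buffer", "width_packed"), ("texture", "width_packed"))
buffer_only = rep_set(("buffer", "width_packed"))

# Different tensors of the same operator may carry different sets.
op = OpRepSet({"input": any_width, "scales": buffer_only, "output": any_width})
common = op.per_tensor["input"].intersect(op.per_tensor["scales"])
print(common)
```

Modeling each tensor's options as a set keeps "does this operator accept that representation?" a plain membership or intersection check.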
All functionality related to determining or checking tensor representation is now centered around the new `OpRepSet` class, which automatically maintains rules about which tensors in an operator should use the same representation and provides utilities to constrain representation sets based on pre-existing input representations.
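One way to picture the two behaviors described here, sync rules and constraining by a pre-existing input representation, is the sketch below. The `constrain` helper, the sync-group encoding, and the string labels for representations are assumptions made for this illustration, not the `OpRepSet` interface.

```python
from typing import Dict, List, Set

RepSet = Set[str]  # representations abbreviated as labels such as "texture/channels"


def constrain(
    rep_sets: Dict[str, RepSet],
    sync_groups: List[List[str]],
    arg: str,
    existing: str,
) -> bool:
    """Narrow `arg` to an already-chosen representation, if the op supports it.

    Returns False when the existing representation is not supported, in which
    case the caller would need to insert a transition instead.
    """
    if existing not in rep_sets[arg]:
        return False
    rep_sets[arg] = {existing}
    # Tensors in the same sync group must end up with the same representation,
    # so each member is narrowed to what the whole group still has in common.
    for group in sync_groups:
        if arg in group:
            shared = set.intersection(*(rep_sets[name] for name in group))
            for name in group:
                rep_sets[name] = set(shared)
    return True


rep_sets = {
    "input": {"buffer/width", "texture/width", "texture/channels"},
    "output": {"buffer/width", "texture/width", "texture/channels"},
    "scales": {"buffer/width"},  # stays a buffer regardless of the input
}
sync_groups = [["input", "output"]]

ok = constrain(rep_sets, sync_groups, "input", "texture/channels")
print(ok, rep_sets)
```

After the call, `output` is narrowed together with `input`, while `scales` keeps its own buffer-only requirement because it is not in the sync group.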
The `tag_memory_metadata_pass.py` has been rewritten to use the `OpRepSet` utility class.
Another consequence of these changes is to simplify how operator implementations are registered. Instead of defining `texture_impl` and `buffer_impl` separately, registration now directly specifies what storage types are valid for inputs and outputs. Sync rules that require inputs/outputs to have the same representation are inferred.
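A registration in this style might look roughly like the following. Everything here is hypothetical: the `OpFeatures` record, the `register_op` function, the `quantized_linear` op name, and in particular the rule used to infer the sync group, which is only one plausible reading of "inferred" and is not taken from the diff.

```python
from dataclasses import dataclass
from typing import Dict, FrozenSet, List

BUFFER, TEXTURE = "buffer", "texture"
ANY_STORAGE = frozenset({BUFFER, TEXTURE})
BUFFER_ONLY = frozenset({BUFFER})


@dataclass
class OpFeatures:
    """Hypothetical registration record: valid storage types per tensor."""
    inputs: Dict[str, FrozenSet[str]]
    outputs: Dict[str, FrozenSet[str]]

    def inferred_sync_group(self) -> List[str]:
        """Assumed rule for this sketch: tensors that still have a choice of
        storage type, and offer the same choices, are kept in sync."""
        all_args = {**self.inputs, **self.outputs}
        flexible = [name for name, storage in all_args.items() if len(storage) > 1]
        if not flexible:
            return []
        first = all_args[flexible[0]]
        return [name for name in flexible if all_args[name] == first]


REGISTRY: Dict[str, OpFeatures] = {}


def register_op(name: str, features: OpFeatures) -> None:
    REGISTRY[name] = features


# Storage types are stated per tensor instead of via separate texture/buffer flags.
register_op(
    "quantized_linear",
    OpFeatures(
        inputs={"x": ANY_STORAGE, "scales": BUFFER_ONLY, "zeros": BUFFER_ONLY},
        outputs={"out": ANY_STORAGE},
    ),
)
print(REGISTRY["quantized_linear"].inferred_sync_group())
```

Under the assumed rule, `x` and `out` are grouped together while `scales` and `zeros` remain pinned to buffers, which mirrors the dynamic quantization case that motivated the change.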
Differential Revision: D791165601
Parent: 37e3003
8 files changed (+1368, -707)
- backends/vulkan
  - _passes
  - partitioner
  - serialization
  - test