Commit 4adc4b0
Rewrite Memory Metadata Tagging Pass (#12927)
Summary:
## Context
In ET-VK, tensors may be stored with either a GPU buffer or a GPU texture. They may also be stored with a specific memory layout: width packed, height packed, or channels packed. The memory layout controls which dimension's elements will be adjacent in physical memory.
In this way, the "representation" of a tensor in ET-VK may be described as a (storage type, memory layout) pair.
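For illustration, here is a minimal sketch of how such a pair might be modeled; the enum and member names below are assumptions for this writeup, not ET-VK's actual schema types:
```
from enum import Enum, auto

# Sketch only: ET-VK's actual schema types may use different names.
class StorageType(Enum):
    BUFFER = auto()
    TEXTURE = auto()

class MemoryLayout(Enum):
    WIDTH_PACKED = auto()
    HEIGHT_PACKED = auto()
    CHANNELS_PACKED = auto()

# A tensor's "representation" is a (storage type, memory layout) pair,
# e.g. a channels packed texture:
representation = (StorageType.TEXTURE, MemoryLayout.CHANNELS_PACKED)
```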
Operator implementations may only support certain tensor representations for inputs and outputs. Furthermore, implementations typically have expectations around which input/output tensors will share the same representation.
Some examples:
* Binary Operators:
  * I/O tensors may use any representation; however, all tensors in the op must use the same representation, i.e. if the first input tensor uses buffer storage, so must the other input tensor and the output tensor.
* Native Group Norm:
  * Input tensors must be a channels packed texture. However, the op produces 3 outputs: the normalized tensor, the running mean, and the running stddev. The normalized tensor must use the same representation as the first input, while the mean and stddev tensors are expected to be contiguous buffers.
* Choose qparams:
  * The input tensor can use any representation; however, the two output tensors (zero points and scales) will always be contiguous buffers.
* Dynamically quantized linear:
  * The input tensor can be either buffer or texture, but must be contiguous/width packed. The scales and zeros tensors for the inputs and weights must all be contiguous buffers. The output tensor must use the same representation as the input tensor.
The operator registry (`op_registry.py`) is responsible for denoting these representational requirements for each op, and the `tag_memory_metadata_pass.py` graph pass is responsible for determining which representation each tensor in each operator should use. The graph pass is also responsible for inserting nodes to move input arguments to a required representation if they were created with an unsupported representation.
## Current Method
Currently, the operator registry will indicate the following:
* Whether texture inputs are supported for the op
  * If yes, which texture memory layouts are supported for inputs to the op
* Whether buffer inputs are supported for the op
* An "optimal" storage type and memory layout to use for inputs/outputs of the operator
The underlying assumption is that all tensors participating in an operator will use the same representation. Although this assumption holds true for most operators, it is clearly insufficient for some of the example operators described above, where certain inputs must use specific representations that differ from those of the other tensors.
During export, the memory metadata tagging pass will go through each op and mark the tensors participating in it with a valid representation for that op. It ensures that all tensors participating in an op use the same representation. To determine the representation to use, it accounts for three things in order of priority:
* The "optimal" storage type and memory layout marked for the op in the operator registry
* Any existing representations that have already been determined for input tensors
* What representations are supported by users of the output tensor of the current op
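A minimal sketch of how that priority ordering could play out; the function and its exact fallback behaviour are assumptions, not the pass's actual code:
```
# Illustrative sketch of the priority ordering; names and the exact
# fallback logic are assumptions, not the pass's actual code.
def pick_representation(optimal_repr, existing_input_reprs, user_supported_reprs):
    # 1. Highest priority: the "optimal" representation from the registry.
    if optimal_repr is not None:
        return optimal_repr
    # 2. Next: reuse a representation already assigned to an input tensor,
    #    which avoids inserting a transition node.
    if existing_input_reprs:
        return existing_input_reprs[0]
    # 3. Finally: pick a representation that users of the output support.
    return next(iter(user_supported_reprs))
```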
## Goals of this diff
The main goal of this diff is to address the problem that the current method of annotating tensor representation requirements for operators is insufficient for describing the tensor representation requirements of operator implementations.
Critically, for operators like choose_qparams and dynamically quantized linear, the current system cannot ensure that all input/output tensors are using representations that are supported by the op impl, since the current system tries to make all tensors participating in an operator use the same representation.
## Changes
### `utils.py`
First, in `utils.py` I introduce several classes to abstract the concept of tensor representations and sets of possible tensor representations.
`TensorRepr` represents a (storage type, memory layout) pair, which describes the representation to use for a single tensor.
`TensorRepSet` represents the set of possible representations that may be used for a single tensor.
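A minimal sketch of what these two abstractions might look like; the real classes in `utils.py` will differ in detail:
```
from dataclasses import dataclass
from typing import FrozenSet

@dataclass(frozen=True)
class TensorRepr:
    # One concrete representation: a (storage type, memory layout) pair.
    storage: str  # e.g. "buffer" or "texture"
    layout: str   # e.g. "width_packed" or "channels_packed"

@dataclass(frozen=True)
class TensorRepSet:
    # All representations a single tensor could legally use.
    reprs: FrozenSet[TensorRepr]

    def is_ambiguous(self) -> bool:
        return len(self.reprs) > 1

    def intersect(self, other: "TensorRepSet") -> "TensorRepSet":
        return TensorRepSet(self.reprs & other.reprs)
```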
`OpRepSet` manages the set of possible representations (i.e. `TensorRepSet`s) for all tensors participating in an operation. To do this, it accounts for 3 things:
* The supported tensor representations for input/output that are denoted by the operator registration
* The actual sizes of the tensors - some tensors may have dims that are too large to fit into a texture
* Sync requirements, i.e. requirements regarding which tensors in the operation must use the same representation
For the last point, `OpRepSet` accounts for three "rules" internally:
* All input tensors must use the same representation
* All output tensors must use the same representation
* The "primary" (i.e. first) input and output tensors must use the same representation
I have settled on these three rules for now since they adequately describe the possible requirements of all operators.
These three rules are validated to hold at all times within `OpRepSet`. Since a `TensorRepSet` may be ambiguous (i.e. there are multiple possible representations that could be used), `OpRepSet` also provides utility functions to constrain the possible representation sets of the op's tensors while maintaining the synchronization rules.
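A hedged sketch of how such a constraint utility could propagate a choice while keeping the sync rules intact, building on the `TensorRepSet` sketch above; the constructor and method names are assumptions:
```
# Illustrative only, building on the TensorRepSet sketch above; the real
# OpRepSet API in utils.py may differ.
class OpRepSet:
    def __init__(self, input_sets, output_sets,
                 sync_all_inputs=False, sync_all_outputs=False, sync_primary_io=False):
        self.input_sets = input_sets    # one TensorRepSet per input tensor
        self.output_sets = output_sets  # one TensorRepSet per output tensor
        self.sync_all_inputs = sync_all_inputs
        self.sync_all_outputs = sync_all_outputs
        self.sync_primary_io = sync_primary_io

    def constrain_input(self, idx, constraint):
        # Narrow one input's representation set...
        self.input_sets[idx] = self.input_sets[idx].intersect(constraint)
        # ...then re-apply the sync rules so the invariants keep holding.
        if self.sync_all_inputs:
            self.input_sets = [s.intersect(constraint) for s in self.input_sets]
        if self.sync_primary_io and (idx == 0 or self.sync_all_inputs):
            self.output_sets[0] = self.output_sets[0].intersect(constraint)
            if self.sync_all_outputs:
                self.output_sets = [s.intersect(constraint) for s in self.output_sets]
```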
I have also defined `TensorRepSet` instances like:
* `utils.ANY_STORAGE`
* `utils.CONTIGUOUS_BUFFER`
* `utils.CHANNELS_PACKED_TEXTURE`
as convenience definitions for common representation set configurations.
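In terms of the earlier sketch, these convenience instances could expand to something like the following; the exact definitions in `utils.py` may differ:
```
# Roughly what the convenience instances could expand to, in terms of the
# sketch above; the exact definitions in utils.py may differ.
ALL_LAYOUTS = ("width_packed", "height_packed", "channels_packed")

ANY_STORAGE = TensorRepSet(frozenset(
    TensorRepr(storage, layout)
    for storage in ("buffer", "texture")
    for layout in ALL_LAYOUTS
))
# "Contiguous" corresponds to a width packed buffer.
CONTIGUOUS_BUFFER = TensorRepSet(frozenset({TensorRepr("buffer", "width_packed")}))
CHANNELS_PACKED_TEXTURE = TensorRepSet(frozenset({TensorRepr("texture", "channels_packed")}))
```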
### `op_registry.py`
Now, in `op_registry.py` operator registrations only need to define 2 things: `inputs_storage` and optionally `outputs_storage`, which describe the possible representation sets that may be used for input and output tensors.
The registrations for each example operator would be:
```
# binary ops
def register_binary_op():
    return OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        supports_resize=True,
    )

# group norm
def register_native_group_norm():
    return OpFeatures(
        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
        outputs_storage=[
            utils.CHANNELS_PACKED_TEXTURE,
            utils.CONTIGUOUS_BUFFER,
            utils.CONTIGUOUS_BUFFER,
        ],
        supports_prepacking=True,
    )

# choose qparams
@update_features(
    [
        exir_ops.edge.torchao.choose_qparams_affine.default,
    ]
)
def register_torchao_quantization_op():
    return OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        outputs_storage=utils.CONTIGUOUS_BUFFER,
        supports_resize=True,
    )

# DQ-Linear
def register_linear_qta8a_qga4w_op():
    return OpFeatures(
        inputs_storage=[
            utils.CONTIGUOUS_ANY,     # input
            utils.CONTIGUOUS_BUFFER,  # mat1 scales
            utils.CONTIGUOUS_BUFFER,  # mat1 zeros
            utils.NO_STORAGE,         # weight (prepacked)
            utils.NO_STORAGE,         # group size (non tensor)
            utils.CONTIGUOUS_BUFFER,  # mat2 scales
            utils.CONTIGUOUS_BUFFER,  # mat2 zeros
        ],
        supports_resize=True,
        supports_prepacking=True,
    )
```
The 3 synchronization rules are inferred from the defined `inputs_storage` and `outputs_storage` as follows:
* If no `outputs_storage` is defined, it is assumed to be the same as the first `TensorRepSet` in `inputs_storage`. This also implies that the primary input and output need to be synced.
* If `inputs_storage` only contains a single `TensorRepSet`, it is assumed that all input tensors need to be synchronized.
* Similarly, if `outputs_storage` only contains a single `TensorRepSet`, it is assumed that all output tensors need to be synchronized.
* If the first entries of `inputs_storage` and `outputs_storage` are the same, it is assumed that the primary input and output need to be synced.
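A sketch of how those rules could be inferred from a registration, assuming `inputs_storage` and `outputs_storage` are passed as lists; this is illustrative, not the actual implementation:
```
# Illustrative inference of the three sync rules from a registration;
# the actual logic lives inside OpRepSet and may differ.
def infer_sync_rules(inputs_storage, outputs_storage):
    if outputs_storage is None:
        # Default the outputs to the first input rep set; this also ties
        # the primary input and output together via the equality check below.
        outputs_storage = [inputs_storage[0]]
    sync_all_inputs = len(inputs_storage) == 1
    sync_all_outputs = len(outputs_storage) == 1
    sync_primary_io = inputs_storage[0] == outputs_storage[0]
    return sync_all_inputs, sync_all_outputs, sync_primary_io
```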
### `tag_memory_metadata_pass.py`
The `tag_memory_metadata_pass.py` graph pass maintains the same scope and behaviour as before, but it has been almost completely rewritten to use the `OpRepSet` utility class. It still goes through the same steps as before:
* For each operator, determine the initial `OpRepSets`
* Constrain the initial `OpRepSets` by checking any existing representations of input tensors, and checking future uses of the output tensor(s) to try and reduce the number of representation transitions needed
* Set the representation of each input/output tensor in the operator. If an input tensor requires a different representation than it currently has, insert a clone node to transition the arg to the required representation.
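Putting the steps together, the rewritten pass's main loop is roughly as follows; every helper named here is hypothetical, introduced only to make the flow concrete:
```
# Hedged sketch of the rewritten pass's main loop; every helper name here
# is hypothetical and the real node/graph APIs differ.
def tag_memory_metadata(graph):
    for node in graph.nodes:
        if not is_op_node(node):
            continue
        # 1. Build the initial OpRepSet from the registry entry and sizes.
        op_repsets = make_initial_op_repset(node)
        # 2. Constrain it with reps already chosen for inputs, and with the
        #    reps supported by users of the outputs, to reduce transitions.
        for i, arg in enumerate(tensor_inputs(node)):
            if has_representation(arg):
                op_repsets.constrain_input(i, representation_of(arg))
        constrain_with_output_users(op_repsets, node)
        # 3. Tag a concrete representation on every I/O tensor; insert a
        #    clone node when an input must change representation.
        for i, arg in enumerate(tensor_inputs(node)):
            required = op_repsets.pick_input(i)
            if has_representation(arg) and representation_of(arg) != required:
                arg = insert_clone_to(graph, arg, required)
            tag_representation(arg, required)
```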
Differential Revision: D791165601