
Commit 4adc4b0

SS-JIA authored and facebook-github-bot committed
Rewrite Memory Metadata Tagging Pass (#12927)
Summary:

## Context

In ET-VK, tensors may be stored with either a GPU buffer or a GPU texture. They may also be stored with a specific memory layout: width packed, height packed, or channels packed. The memory layout controls which dimension will have its elements adjacent in physical memory. In this way, the "representation" of a tensor in ET-VK may be described by a (storage type, memory layout) pair.

Operator implementations may only support certain tensor representations for inputs and outputs. Furthermore, implementations typically have expectations around which input/output tensors will share the same representation. Some examples:

* Binary operators:
  * I/O tensors may use any representation; however, all tensors in the op must use the same representation, i.e. if the first input tensor uses buffer storage, so must the other input tensor and the output tensor.
* Native group norm:
  * Input tensors must be a channels packed texture. However, the op produces 3 outputs: the normalized tensor, the running mean, and the running stddev. The normalized tensor must use the same representation as the first input, but the mean and stddev tensors are expected to be contiguous buffers.
* Choose qparams:
  * The input tensor can use any representation. However, the two output tensors (zero points and scales) will always be contiguous buffers.
* Dynamically quantized linear:
  * The input tensor can be either buffer or texture, but must be contiguous/width packed. The scales and zeros tensors for the inputs and weights must all be contiguous buffers. The output tensor must use the same representation as the input tensor.

The operator registry (`op_registry.py`) is responsible for denoting these representational requirements for each op, and the `tag_memory_metadata_pass.py` graph pass is responsible for determining what representation each tensor in each operator should use. The graph pass is also responsible for inserting nodes to move input arguments to a required representation if they were created with an unsupported representation.
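For concreteness, here is a minimal sketch of the (storage type, memory layout) pair concept described above. The enum values and class names are illustrative placeholders, not the actual ET-VK definitions.

```
from enum import Enum
from typing import NamedTuple


# Illustrative placeholders only; ET-VK defines its own storage type and
# memory layout types in the Vulkan delegate sources.
class StorageType(Enum):
    BUFFER = 0
    TEXTURE = 1


class MemoryLayout(Enum):
    WIDTH_PACKED = 0      # elements along the width dim are adjacent in memory
    HEIGHT_PACKED = 1     # elements along the height dim are adjacent in memory
    CHANNELS_PACKED = 2   # elements along the channels dim are adjacent in memory


class Representation(NamedTuple):
    storage: StorageType
    layout: MemoryLayout


# e.g. the two representations contrasted in the examples above
channels_packed_texture = Representation(StorageType.TEXTURE, MemoryLayout.CHANNELS_PACKED)
contiguous_buffer = Representation(StorageType.BUFFER, MemoryLayout.WIDTH_PACKED)
```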
## Current Method

Currently, the operator registry will indicate the following:

* Whether texture inputs are supported for the op
* If yes, which texture memory layouts are supported for inputs to the op
* Whether buffer inputs are supported for the op
* An "optimal" storage type and memory layout to use for inputs/outputs of the operator

The underlying assumption is that all tensors participating in an operator will use the same representation. Although this assumption holds true for most operators, it is clearly insufficient for some of the example operators described above, where certain input tensors require specific representations that differ from the other tensors.

During export, the memory metadata tagging pass will go through each op and mark the tensors participating in the op with a valid representation for that op. It will ensure that all tensors participating in an op use the same representation. To determine the representation to use, it accounts for three things, in order of priority:

* The "optimal" storage type and memory layout marked for the op in the operator registry
* Any existing representations that have already been determined for input tensors
* Which representations are supported by users of the output tensor of the current op

## Goals of this diff

The main goal of this diff is to address the problem that the current method of annotating tensor representation requirements for operators is insufficient for describing the requirements of operator implementations. Critically, for operators like choose_qparams and dynamically quantized linear, the current system cannot ensure that all input/output tensors use representations that are supported by the op impl, since it tries to make all tensors participating in an operator use the same representation.

## Changes

### `utils.py`

First, in `utils.py` I introduce several classes to abstract the concept of tensor representations and sets of possible tensor representations.

`TensorRepr` represents a (storage type, memory layout) pair which describes the representation to use for a single tensor. `TensorRepSet` represents the set of possible representations that may be used for a single tensor. `OpRepSet` manages the set of possible representations (i.e. `TensorRepSet`s) for all tensors participating in an operation. To do this, it accounts for 3 things:

* The supported tensor representations for inputs/outputs that are denoted by the operator registration
* The actual sizes of the tensors - some tensors may have dims that are too large to fit into a texture
* Synchronization requirements, i.e. which tensors in the operation must use the same representation

For the last point, `OpRepSet` accounts for three "rules" internally:

* All input tensors must use the same representation
* All output tensors must use the same representation
* The "primary" (i.e. first) input and output tensors must use the same representation

I have settled on these three rules for now since they adequately describe the possible requirements of all operators. These three rules are validated to be true at all times within `OpRepSet`.

Since `TensorRepSet`s may be ambiguous (i.e. there are multiple possible representations that could be used), `OpRepSet` also provides utility functions to constrain the possible representation sets of an op's tensors while maintaining the synchronization rules.

I have also defined `TensorRepSet` instances like:

* `utils.ANY_STORAGE`
* `utils.CONTIGUOUS_BUFFER`
* `utils.CHANNELS_PACKED_TEXTURE`

as convenience definitions for common representation set configurations.
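To make these abstractions concrete, here is a rough, simplified sketch of how `TensorRepr`, `TensorRepSet`, and `OpRepSet` relate to each other. This is not the actual `utils.py` code; the fields and methods shown are assumptions for illustration only.

```
from dataclasses import dataclass
from typing import List, Set


# Simplified sketch; the real classes in utils.py have a richer API.
@dataclass(frozen=True)
class TensorRepr:
    # A single (storage type, memory layout) pair for one tensor.
    storage: str  # e.g. "buffer" or "texture"
    layout: str   # e.g. "width_packed", "height_packed", "channels_packed"


@dataclass
class TensorRepSet:
    # The set of representations that a single tensor could legally use.
    valid: Set[TensorRepr]

    def constrain(self, other: "TensorRepSet") -> "TensorRepSet":
        # Intersect with another set, e.g. to respect a representation that
        # has already been chosen for this tensor elsewhere in the graph.
        return TensorRepSet(self.valid & other.valid)

    def is_ambiguous(self) -> bool:
        return len(self.valid) > 1


@dataclass
class OpRepSet:
    # One TensorRepSet per input and per output tensor of a single operator.
    # The real class also enforces the three synchronization rules (all
    # inputs same, all outputs same, primary input == primary output)
    # whenever any of the sets is constrained.
    input_sets: List[TensorRepSet]
    output_sets: List[TensorRepSet]
```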
### `op_registry.py`

Now, in `op_registry.py` operator registrations only need to define 2 things: `inputs_storage` and optionally `outputs_storage`, which describe the possible representation sets that may be used for input and output tensors. The registrations for each example operator would be:

```
# binary ops
def register_binary_op():
    return OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        supports_resize=True,
    )


# group norm
def register_native_group_norm():
    return OpFeatures(
        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
        outputs_storage=[
            utils.CHANNELS_PACKED_TEXTURE,
            utils.CONTIGUOUS_BUFFER,
            utils.CONTIGUOUS_BUFFER,
        ],
        supports_prepacking=True,
    )


# choose qparams
@update_features(
    [
        exir_ops.edge.torchao.choose_qparams_affine.default,
    ]
)
def register_torchao_quantization_op():
    return OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        outputs_storage=utils.CONTIGUOUS_BUFFER,
        supports_resize=True,
    )


# DQ-Linear
def register_linear_qta8a_qga4w_op():
    return OpFeatures(
        inputs_storage=[
            utils.CONTIGUOUS_ANY,     # input
            utils.CONTIGUOUS_BUFFER,  # mat1 scales
            utils.CONTIGUOUS_BUFFER,  # mat1 zeros
            utils.NO_STORAGE,         # weight (prepacked)
            utils.NO_STORAGE,         # group size (non tensor)
            utils.CONTIGUOUS_BUFFER,  # mat2 scales
            utils.CONTIGUOUS_BUFFER,  # mat2 zeros
        ],
        supports_resize=True,
        supports_prepacking=True,
    )
```

The 3 synchronization rules are inferred from the defined `inputs_storage` and `outputs_storage`:

* If no `outputs_storage` is defined, then assume that the `outputs_storage` is the same as the first `TensorRepSet` in `inputs_storage`. This also implies that the primary input and output need to be synced.
* If `inputs_storage` only contains a single `TensorRepSet`, it is assumed that all input tensors need to be synchronized.
* Similarly, if `outputs_storage` only contains a single `TensorRepSet`, it is assumed that all output tensors need to be synchronized.
* If the first entry in `inputs_storage` and `outputs_storage` are the same, assume that the primary input and output need to be synced.

### `tag_memory_metadata_pass.py`

The `tag_memory_metadata_pass.py` pass maintains the same scope and behaviour as before, but has been almost completely rewritten to use the `OpRepSet` utility class. It goes through the same steps as before:

* For each operator, determine the initial `OpRepSet`
* Constrain the initial `OpRepSet` by checking any existing representations of input tensors, and by checking future uses of the output tensor(s), to try to reduce the number of representation transitions needed
* Set the representation of each input/output tensor in the operator. If an input tensor requires a different representation than it currently has, insert a clone node to transition the arg to the required representation.
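For orientation only, here is a rough sketch of how the pass's main loop could be structured. This is not the actual `tag_memory_metadata_pass.py` code; the function names, the `op_registry.get_op_features` accessor, and all helpers are hypothetical placeholders for the steps listed above.

```
# Illustrative sketch only -- not the actual pass implementation. The helper
# functions are hypothetical placeholders for the OpRepSet-based logic
# described above.
def make_initial_op_repsets(node, features):
    raise NotImplementedError  # build OpRepSet from registration + tensor sizes


def constrain_with_existing_inputs(op_repsets, node):
    raise NotImplementedError  # narrow sets using representations already chosen


def constrain_with_output_users(op_repsets, node):
    raise NotImplementedError  # narrow sets using what output users can accept


def apply_representations_and_insert_clones(op_repsets, node, graph_module):
    raise NotImplementedError  # tag tensors; insert clone nodes for mismatched inputs


def tag_memory_metadata(graph_module, op_registry):
    for node in graph_module.graph.nodes:
        if node.op != "call_function":
            continue
        features = op_registry.get_op_features(node.target)

        # 1. Determine the initial OpRepSet for this operator.
        op_repsets = make_initial_op_repsets(node, features)

        # 2. Constrain it using existing input representations and future
        #    uses of the output tensor(s), to minimize representation transitions.
        op_repsets = constrain_with_existing_inputs(op_repsets, node)
        op_repsets = constrain_with_output_users(op_repsets, node)

        # 3. Choose a concrete representation for each input/output tensor,
        #    inserting clone nodes for inputs that need a different one.
        apply_representations_and_insert_clones(op_repsets, node, graph_module)
```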
Differential Revision: D79116560

1 parent 339e95f commit 4adc4b0

File tree

8 files changed: +1367 -707 lines changed


backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
         # Mark that this node is going to be represented as a TensorRef type in the
         # Vulkan compute graph. This annotation is used in later graph passes.
-        node.meta["vkdg_tensorref"] = True
+        node.meta["etvk_tensorref"] = True
 
         # Get the list of node users that do not handle their own prepacking
         nodes_to_replace_input = []

backends/vulkan/_passes/remove_local_scalar_dense_ops.py

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None:
 
     for user in node.users:
         if node_is_local_scalar_dense_chain(user):
-            node.meta["vkdg_is_scalar_tensor"] = True
+            node.meta["etvk_is_scalar_tensor"] = True
 
 
 def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None:
@@ -74,7 +74,7 @@ def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node)
         if replace_node.args[0].meta["val"].numel() == 1:
             replace_node = replace_node.args[0]
             assert isinstance(replace_node, torch.fx.Node)
-            assert replace_node.meta.get("vkdg_is_scalar_tensor", True)
+            assert replace_node.meta.get("etvk_is_scalar_tensor", True)
 
     with graph.inserting_after(node):
         node.replace_all_uses_with(replace_node)
