diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index ed736438cbb..c45ed4ea25d 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -35,7 +35,7 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: # Mark that this node is going to be represented as a TensorRef type in the # Vulkan compute graph. This annotation is used in later graph passes. - node.meta["vkdg_tensorref"] = True + node.meta["etvk_tensorref"] = True # Get the list of node users that do not handle their own prepacking nodes_to_replace_input = [] diff --git a/backends/vulkan/_passes/remove_local_scalar_dense_ops.py b/backends/vulkan/_passes/remove_local_scalar_dense_ops.py index 4c4b8c265af..6ce3572ec0c 100644 --- a/backends/vulkan/_passes/remove_local_scalar_dense_ops.py +++ b/backends/vulkan/_passes/remove_local_scalar_dense_ops.py @@ -52,7 +52,7 @@ def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None: for user in node.users: if node_is_local_scalar_dense_chain(user): - node.meta["vkdg_is_scalar_tensor"] = True + node.meta["etvk_is_scalar_tensor"] = True def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None: @@ -74,7 +74,7 @@ def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) if replace_node.args[0].meta["val"].numel() == 1: replace_node = replace_node.args[0] assert isinstance(replace_node, torch.fx.Node) - assert replace_node.meta.get("vkdg_is_scalar_tensor", True) + assert replace_node.meta.get("etvk_is_scalar_tensor", True) with graph.inserting_after(node): node.replace_all_uses_with(replace_node) diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 0bd8dae0b66..db53cc666a8 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -5,13 +5,15 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Any, Optional, Set +import operator + +from typing import Any import executorch.backends.vulkan.utils as utils import torch -from executorch.backends.vulkan.op_registry import get_op_features, has_impl +from executorch.backends.vulkan.op_registry import get_op_features, has_impl, OpFeatures from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( VkMemoryLayout, @@ -27,23 +29,16 @@ logger.setLevel(logging.INFO) -def set_memory_metadata( - node: torch.fx.Node, storage: VkStorageType, layout: VkMemoryLayout -) -> None: - utils.set_node_spec_attr(node, "vk_storage_type", storage) - utils.set_node_spec_attr(node, "vk_memory_layout", layout) - - def insert_transition_node( graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg: torch.fx.Node, - storage: VkStorageType, - layout: VkMemoryLayout, + arg_node_repr: utils.TensorRepr, ) -> None: """ - Insert a clone node to copy the original tensor to a tensor with the desired storage - type and memory layout. + Insert a clone node to transition the tensor associated with `arg` to a tensor with + the requested representation `arg_node_repr`, and use the cloned node as an argument + to `node` instead of `arg`. 
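+
+    Illustrative usage (a sketch only; `TensorRepr` is assumed here to pair a
+    VkStorageType with a VkMemoryLayout, matching how its fields are read in
+    vulkan_graph_builder.py):
+
+        buffer_repr = utils.TensorRepr(
+            VkStorageType.BUFFER, VkMemoryLayout.TENSOR_WIDTH_PACKED
+        )
+        insert_transition_node(graph_module, op_node, arg, buffer_repr)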
""" with graph_module.graph.inserting_before(node): clone_node = graph_module.graph.create_node( @@ -54,30 +49,80 @@ def insert_transition_node( clone_node.meta["val"] = arg.meta["val"] clone_node.meta["spec"] = TensorSpec.from_tensor(clone_node.meta["val"]) clone_node.meta["spec"].const = False - set_memory_metadata(clone_node, storage, layout) + utils.set_node_repr(clone_node, arg_node_repr) arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y) -class TagMemoryMetaPass(ExportPass): +def set_arg_node_repr_or_transition( + graph_module: torch.fx.GraphModule, + op_node: torch.fx.Node, + arg_i: int, + arg_node_repr: utils.TensorRepr, + dirty: bool, +) -> bool: """ - There are a variety of ways that tensors can be represented in Vulkan. The two main - descriptors for how a tensor is laid out in memory is: + Does one of following: + 1. Sets the `node_repr` of the argument at `arg_i` of `op_node` if the argument node + does not currently have a `node_repr` + 2. No-op if the current `node_repr` is already the same as the requested represetnation. + 3. Insert a transition node to create a copy of the argument with the desired `node_repr` + if the current `node_repr` is different than what is needed. + """ + arg_node = op_node.args[arg_i] + + def single_node_impl(node: torch.fx.Node) -> bool: + # Case where the arg node has not been touched yet; in this case, simply set it and + # return. + if not utils.has_node_repr(node): + utils.set_node_repr(node, arg_node_repr) + return False + + # Case where the current node representation is the same as the new one. + cur_node_repr = utils.get_node_repr(node) + assert isinstance(cur_node_repr, utils.TensorRepr) + + if cur_node_repr == arg_node_repr: + return False + + if not dirty: + logger.info( + f"[Vulkan Delegate] Inserting transition(s) for {op_node.format_node()}:" + ) + + # Existing node representation is different; insert a transition node + # Currently, the transition node insertion logic can only handle single tensor nodes + assert utils.is_single_tensor_node(node) + insert_transition_node(graph_module, op_node, node, arg_node_repr) + + logger.info(f" arg {arg_i} ({node}): ({cur_node_repr}) -> ({arg_node_repr})") + + return True + + if isinstance(arg_node, torch.fx.Node): + return single_node_impl(arg_node) + elif isinstance(arg_node, (list, tuple)): + ret: bool = False + for n in arg_node: + assert isinstance(n, torch.fx.Node) + assert utils.is_single_tensor_node(n) + ret = single_node_impl(n) or ret - 1. Storage Type (buffer or texture) - 2. Memory Layout (which dim is packed along a texel / has a stride of 1, etc.) + return ret - Due to the differences between buffers and textures, and the differences between - different memory layouts, an implementation for an operator may only support a - specific set of (storage type, memory layout) combinations. + raise NotImplementedError(f"Unhandled node type {arg_node}") - Furthermore, if an operator implementation supports multiple (storage type, memory - layout) combinations, there may be a "preferred" setting which results in optimal - performance. - This pass is responsible for ensuring that all tensors participating in an operator - call have a valid/optimal (storage type, memory layout) setting, and insert - transition operators to transfer input tensors to the correct memory settings when - necessary. +class TagMemoryMetaPass(ExportPass): + """ + Operator implementations in the Vulkan delegate may require that input and output + tensors use a specific representation. 
Representation in this case refers to a + combination of storage type (buffer or texture) and memory layout (width, height, or + channels packed). + + The tag memory metadata pass is responsible for marking each tensor in the graph + with the appropriate representation to use. It is also responsible for inserting + operators to transition argument tensors to a required/compatible representation if + a mismatch has been detected. """ def __init__( @@ -91,241 +136,331 @@ def __init__( self.default_layout: VkMemoryLayout = default_memory_layout self.texture_limits = texture_limits - def propose_node_storage( # noqa: C901 - self, - node: torch.fx.Node, - ) -> Optional[VkStorageType]: + # Magic number to limit "lookahead" when tracing through users of an operator + # to constrain the representation of its arguments/outputs. + self.max_trace_search_depth = 20 + + def is_valid_op_node(self, node: Any) -> bool: """ - Uses the operator registry to determine the storage type that should be used for - a given node. The storage type is determined with the following priorities: - 1. In some cases, a tensor involved in the computation may be too large to be - represented as a texture. If this is the case, the node is "opinionated" and - buffer representation must be used. - 1. If the operator called by the node indicates an optimal storage type, or only - supports a single storage type, use that storage type. If either is true, - then the node is considered to be opinionated as well. If multiple storage - and no preferred storage type is indicated, then the node is not opinionated; - go to the next step. - 2. If the node's arguments already have memory metadata annotations, then - preserve the settings of the first argument. Otherwise, proceed to the next - step. - 3. Recursively search the node's uses to see if any subsequent uses are - opinionated; inherit the settings of the first opinionated node. If no - opinionated user can be found, then proceed to the last step. - 4. Use the default storage type setting. + Fails the check for: + * nodes that are not associated with a tensor + * nodes that are associated with a constant tensor + * nodes that are not associated with a supported operator """ - if not utils.is_tensor_node(node): - return None - - # The node may have an input/output tensor that is too big to be stored in a - # texture. In this case, buffer storage must be used. Note that the partitioner - # has already checked for the fact that buffer storage is supported by the - # operator. - if len(utils.possible_node_memory_layouts(node, self.texture_limits)) == 0: - return VkStorageType.BUFFER - - valid_storage_types: Set[VkStorageType] = utils.all_storage_types - - # pyre-ignore - if has_impl(node.target): - # pyre-ignore - features = get_op_features(node.target) - valid_storage_types = features.supported_storage_types() - storage = features.propose_storage_type() - if storage is not None: - return storage - - for arg in node.args: - if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg): - storage = utils.get_node_storage_type(arg) - # Some operators which return multiple output tensors may specify a - # different storage type for each output. In this case, the storage type - # for the first output is used. - if isinstance(storage, (list, tuple)): - storage = storage[0] - if storage is not None and storage in valid_storage_types: - return storage - - # If no storage type has been resolved yet, assume the optimal storage type of - # the first opinionated user. This search is recursive. 
-        for user in node.users:
-            storage = self.propose_node_storage(user)
-            # See above
-            if isinstance(storage, (list, tuple)):
-                storage = storage[0]
-            if storage is not None:
-                return storage
-
-        if self.default_storage in valid_storage_types:
-            return self.default_storage
-        else:
-            return next(iter(valid_storage_types))
+        if not isinstance(node, torch.fx.Node) or not utils.is_tensor_node(node):
+            return False
+        if node.meta.get("etvk_tensorref", False):
+            return False
+        if not has_impl(node.target):
+            return False
 
-    def propose_node_layout(
-        self,
-        node: torch.fx.Node,
-        storage: VkStorageType,
-    ) -> Optional[VkMemoryLayout]:
+        return True
+
+    def is_non_constant_tensor_node(self, node: Any) -> bool:
         """
-        Performs the same steps as propose_node_storage, but detects the memory layout
-        that should be used for the specific storage type. The same prioritization logic
-        is applied.
+        Fails the check for:
+        * Nodes that are not associated with tensor values
+        * Nodes associated with constant tensors
         """
-        if not utils.is_tensor_node(node):
-            return None
-
-        valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts
-        # pyre-ignore
-        if has_impl(node.target):
-            # pyre-ignore
-            features = get_op_features(node.target)
-            valid_layouts = features.supported_memory_layouts(storage)
-            layout = features.propose_memory_layout(storage)
-            if layout is not None:
-                return layout
-
-        for arg in node.args:
-            if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
-                layout = utils.get_node_memory_layout(arg)
-                # Some operators which return multiple output tensors may specify a
-                # different memory layout for each output. In this case, the storage
-                # type for the first output is used.
-                if isinstance(layout, (list, tuple)):
-                    layout = layout[0]
-                if layout is not None and layout in valid_layouts:
-                    return layout
-
-        # If no memory layout has been resolved yet, assume the optimal layout of the
-        # first opinionated user. This search is recursive.
-        for user in node.users:
-            layout = self.propose_node_layout(user, storage)
-            # See above comment
-            if isinstance(layout, (list, tuple)):
-                layout = layout[0]
-            if layout is not None:
-                return layout
-
-        # As a last resort, return the default storage type that should be used.
-        if self.default_layout in valid_layouts:
-            return self.default_layout
-        else:
-            return next(iter(valid_layouts))
-
-    def should_annotate(self, node) -> bool:
         if isinstance(node, torch.fx.Node):
             if not utils.is_tensor_node(node):
                 return False
-
-            # Storage type and memory layout for tensorref will be determined at runtime
-            # so there's no use in setting those attributes ahead of time.
-            if node.meta.get("vkdg_tensorref", False):
+            if node.meta.get("etvk_tensorref", False):
                 return False
+            return True
 
-            # Skip annotating output node. The output tensors should be annotated by the
-            # time the output node is observed.
-            if node.op == "output":
-                return False
-        elif isinstance(node, (list, tuple)):
-            return all(
-                isinstance(n, torch.fx.Node) and self.should_annotate(n) for n in node
-            )
+        if isinstance(node, (tuple, list)):
+            for n in node:
+                if not isinstance(n, torch.fx.Node):
+                    return False
+                if not self.is_non_constant_tensor_node(n):
+                    return False
+
+            return True
+
+        # Return False by default
+        return False
+
+    def get_node_cached_repsets(self, op_node: torch.fx.Node) -> utils.OpRepSets:
+        """
+        Implements a cache layer for getting the OpRepSets for a given operator node.
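+
+        The computed OpRepSets is cached in `op_node.meta["etvk_node_repsets"]`, so
+        repeated lookups for the same node return the same object:
+
+            repsets_a = self.get_node_cached_repsets(op_node)
+            repsets_b = self.get_node_cached_repsets(op_node)  # served from node meta
+            assert repsets_a is repsets_b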
+ """ + assert self.is_valid_op_node(op_node) + + if "etvk_node_repsets" in op_node.meta: + op_repsets = op_node.meta["etvk_node_repsets"] + assert isinstance(op_repsets, utils.OpRepSets) + return op_repsets else: - return False + # Special case for getitem - set the input and output to the repset of the + # tensor value being extracted + if op_node.target == operator.getitem: + src_node = op_node.args[0] + assert isinstance(src_node, torch.fx.Node) + idx = op_node.args[1] + assert isinstance(idx, int) + + arg_node_repsets = self.get_node_cached_repsets(src_node) + out_tensor_repset = arg_node_repsets.get_out_repset(idx) + + op_repsets = utils.OpRepSets( + utils.TensorRepSetList(out_tensor_repset), + utils.TensorRepSetList(out_tensor_repset), + op_node, + self.texture_limits, + ) + else: + features: OpFeatures = get_op_features(op_node.target) # noqa + op_repsets = features.make_op_repsets(op_node, self.texture_limits) - return True + op_node.meta["etvk_node_repsets"] = op_repsets + return op_repsets - def should_delay_annotation(self, node: torch.fx.Node) -> bool: - # For prepack nodes, delay setting the storage type and memory layout as long as - # possible. This is to minimize the number of transitions, since it can be - # difficult to predict what storage type and memory layout should be used at the - # time the prepack node is observed. - return node.target == exir_ops.edge.et_vk.prepack.default + def get_arg_tensor_source_repset( + self, op_node: torch.fx.Node, arg_i: int + ) -> utils.TensorRepSet: + """ + Get the "source RepSet" for the tensor argument at index `arg_i` of `op_node`. + The source repset is obtained in one of two ways: - def set_or_transition_arg_node( + 1. If the tensor argument already has a representation determined for it, return + a repset that contains that representation. + 2. 
Otherwise, return the output repset of the operator that produces the tensor + """ + arg_node = op_node.args[arg_i] + + # Special case for cat - use the first tensor in the list as representative + if isinstance(arg_node, list): + arg_node = arg_node[0] + + if utils.has_node_repr(arg_node): + arg_node_repr = utils.get_node_repr(arg_node) + assert isinstance(arg_node_repr, utils.TensorRepr) + return utils.make_tensor_repset(arg_node_repr) + elif self.is_valid_op_node(arg_node): + # Special case for getitem - propagate the node representation of the original node + if op_node.target == operator.getitem: + src_node = op_node.args[0] + assert isinstance(src_node, torch.fx.Node) + idx = op_node.args[1] + assert isinstance(idx, int) + + src_node_repsets = self.get_node_cached_repsets(src_node) + return src_node_repsets.get_out_repset(idx) + + src_node_repsets = self.get_node_cached_repsets(arg_node) + return src_node_repsets.get_out_repset(0) + + # default return + return utils.ANY_STORAGE + + def constrain_repset_with_user( self, - i: int, - arg: torch.fx.Node, - node: torch.fx.Node, - graph_module: torch.fx.GraphModule, - dirty: bool, - ) -> bool: - assert isinstance(arg, torch.fx.Node) - - storage = utils.get_node_storage_type(node) - assert storage is not None - layout = utils.get_node_memory_layout(node) - assert layout is not None - - arg_storage = utils.get_node_storage_type(arg) - arg_layout = utils.get_node_memory_layout(arg) - - if arg_storage is None: - utils.set_node_spec_attr(arg, "vk_storage_type", storage) - arg_storage = storage - if arg_layout is None: - utils.set_node_spec_attr(arg, "vk_memory_layout", layout) - arg_layout = layout - - if arg_storage == storage and arg_layout == layout: - return False + current_node: torch.fx.Node, + arg_i: int, + arg_repset: utils.TensorRepSet, + search_depth: int = 0, + ) -> utils.TensorRepSet: + """ + Attempts to constrain `arg_repset` based on the required repset of the argument + at index `arg_i` of `current_node`. This tries to find a representation for the + argument that can be used for as long as possible without needing a transition. + """ + # The repset is already constrained; return it + if arg_repset.is_constrained(): + return arg_repset + + # The current node is not a valid op node, so no OpRepSets object can be created + # for it. + if not self.is_valid_op_node(current_node): + return arg_repset + + cur_node_repsets = self.get_node_cached_repsets(current_node) + + # Intersect with the repset required by the current operator; otherwise, return + # since a transition will be required anyways + req_arg_repset = cur_node_repsets.get_arg_repset(arg_i) + if req_arg_repset.any_in_common(arg_repset): + arg_repset = arg_repset.make_intersect(req_arg_repset) + else: + return arg_repset - if not dirty: - logger.info( - f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:" - ) + # Check if the argument at `arg_i` will influence the output representation of + # the current operator. 
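+        # For example, if the op synchronizes its primary input with its output (and
+        # either synchronizes all of its arguments or `arg_i` is the primary argument),
+        # then any constraint applied to this argument carries through to the output.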
+ repset_propagates_to_output = cur_node_repsets.sync_primary_io_repr and ( + cur_node_repsets.sync_args_repr or arg_i == cur_node_repsets.primary_arg_idx + ) - insert_transition_node(graph_module, node, arg, storage, layout) + # If not, then no point in continuing to trace the users of the current node + if not repset_propagates_to_output: + return arg_repset - logger.info( - f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})" + return self.trace_node_users_to_constrain_repset( + current_node, arg_repset, search_depth ) - return True - - def set_or_transition_arg( + def trace_node_users_to_constrain_repset( self, - i: int, - arg: Any, - node: torch.fx.Node, - graph_module: torch.fx.GraphModule, - dirty: bool, - ) -> bool: - if isinstance(arg, torch.fx.Node): - return self.set_or_transition_arg_node(i, arg, node, graph_module, dirty) - elif isinstance(arg, (list, tuple)): - need_transition = False - for arg_node in arg: - need_transition = ( - self.set_or_transition_arg_node( - i, arg_node, node, graph_module, need_transition - ) - or need_transition + origin_node: torch.fx.Node, + repset: utils.TensorRepSet, + search_depth: int = 0, + ) -> utils.TensorRepSet: + """ + For an ambiguous repset, try to constrain the repset by tracing the required + repsets of the users of `origin_node`. The idea is to try to find a representation + that can be used the longest without needing user nodes to insert a transition + for its arguments. + """ + # Optionally limit the search depth to improve export time + if self.max_trace_search_depth is not None: + if search_depth > self.max_trace_search_depth: + return repset + + users_to_trace = origin_node.users + + sync_outs_repr = True + if self.is_valid_op_node(origin_node): + sync_outs_repr = self.get_node_cached_repsets(origin_node).sync_outs_repr + + if utils.num_tensors_in_node(origin_node) > 1 and not sync_outs_repr: + users_to_trace = [] + for usage_node in origin_node.users: + if usage_node.target == operator.getitem and usage_node.args[1] == 1: + users_to_trace.append(usage_node) + + for usage_node in users_to_trace: + arg_i_in_user = None + for i in range(len(usage_node.args)): + if origin_node == usage_node.args[i]: + arg_i_in_user = i + break + + if arg_i_in_user is not None: + repset = self.constrain_repset_with_user( + usage_node, arg_i_in_user, repset, search_depth + 1 ) - return need_transition - else: - return False - # noqa - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - for node in graph_module.graph.nodes: - if not self.should_annotate(node) or self.should_delay_annotation(node): - continue + if repset.is_constrained(): + return repset + + return repset + + def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> None: + """ + Attempts to constrain the repset of the argument at index `arg_i` of the op + associated with `op_repsets`. Does this with two stages: + + 1. First, account for any existing representation that has already been determined + for the argument. If no existing representation has been determined, then use + the output repset of the operator that produces the argument. + 2. Then, try to trace through the users of the argument to find a representation + that can be used for as long as possible without needing a transition. 
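+
+        For example, an argument produced by a linear op (which prefers width-packed,
+        contiguous representations) but later consumed by a convolution (which
+        requires channels packing) is constrained here toward whichever representation
+        avoids the most transitions.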
+ """ + arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i) + op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset) + + arg_repset = op_repsets.get_arg_repset(arg_i) + if arg_repset.is_constrained(): + return arg_repset + + arg_node = op_repsets.op_node.args[arg_i] + + if isinstance(arg_node, list): + arg_node = arg_node[0] + + arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset) + op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset) + + def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None: + # For most ops, constraining the argument repsets will also contrain the output + # repset due to OpRepSets maintaining synchronization rules. + for i in range(len(op_repsets.op_node.args)): + if utils.is_tensor_arg_node(op_repsets.op_node.args[i]): + self.constrain_op_arg_repset(i, op_repsets) + + # TODO(ssjia): For most ops, inputs and outputs must be synchronized, so there + # is no need to constrain output repsets explicitly. Currently, the exceptions + # (i.e. choose qparams) already define constrined repsets for the output, so + # there is again no need to explicitly constrain the outputs. If an operator + # appears later on that does not sync input and output representations, and + # defines ambiguous repsets for the output tensor(s), then we will need to add + # additional logic to this function to constrain the output repsets separately + # from the input repsets. + + def set_op_node_tensor_reprs( + self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node + ) -> None: + """ + For an operator representated by `op_node`, get the OpRepSets associated with + the operation and try to constrain the repsets by accounting for existing + representations and tracing through the users of the operator. + + Then, determine a tensor representation for all tensors participating in the + operation and mark it in the node metadata. If the requested representation is + different than an already determined representation, then insert a transition + node to create a copy of the tensor with the desired representation. + """ + if not self.is_valid_op_node(op_node): + return + + # Special case for getitem - propagate the node representation of the original node + if op_node.target == operator.getitem: + src_node = op_node.args[0] + assert isinstance(src_node, torch.fx.Node) + idx = op_node.args[1] + assert isinstance(idx, int) - storage = self.propose_node_storage(node) - layout = self.propose_node_layout(node, storage) + arg_node_repr = utils.get_node_repr(src_node) + assert isinstance(arg_node_repr, list) + utils.set_node_repr(op_node, arg_node_repr[idx]) + return - set_memory_metadata(node, storage, layout) + # Get a "fresh" OpRepSets object instead of using the cache. Do this because this + # class instance will go through the constraining process which may modify it. 
+ features: OpFeatures = get_op_features(op_node.target) + op_repsets = features.make_op_repsets(op_node, self.texture_limits) - need_transition = False - for i, arg in enumerate(node.args): - if not self.should_annotate(arg): - continue + self.constrain_op_repsets(op_repsets) - need_transition = ( - self.set_or_transition_arg( - i, arg, node, graph_module, need_transition + args_repr_list, outs_repr_list = op_repsets.pick_representations() + + if len(outs_repr_list) == 1: + utils.set_node_repr(op_node, outs_repr_list[0]) + else: + utils.set_node_repr(op_node, outs_repr_list) + + transitions_inserted = False + for i, arg_node in enumerate(op_node.args): + if not self.is_non_constant_tensor_node(arg_node): + continue + + arg_node_repr = args_repr_list[i] + + if isinstance(arg_node, torch.fx.Node): + transitions_inserted = ( + set_arg_node_repr_or_transition( + graph_module, op_node, i, arg_node_repr, transitions_inserted ) - or need_transition + or transitions_inserted ) + elif isinstance(arg_node, (list, tuple)): + for n in arg_node: + assert isinstance(n, torch.fx.Node) + assert utils.is_single_tensor_node(n) + transitions_inserted = ( + set_arg_node_repr_or_transition( + graph_module, + op_node, + i, + arg_node_repr, + transitions_inserted, + ) + or transitions_inserted + ) + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + self.set_op_node_tensor_reprs(graph_module, node) return PassResult(graph_module, True) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 33ed3150535..2e0be1d68d7 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -8,22 +8,14 @@ import operator -from typing import Callable, Dict, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Union import executorch.backends.vulkan.custom_ops_lib # noqa -import torch +import executorch.backends.vulkan.utils as utils -from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( - VkMemoryLayout, - VkStorageType, -) +import torch -from executorch.backends.vulkan.utils import ( - all_memory_layouts, - all_packed_dims, - PackedDim, -) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload @@ -38,156 +30,60 @@ def allow_node(node: torch.fx.Node) -> bool: return True -class TextureImplFeatures: - __slots__ = [ - "valid_packed_dims", - "uses_axis_map", - ] - - def __init__( - self, - uses_axis_map: bool = False, - valid_packed_dims: Optional[Set[PackedDim]] = None, - ): - self.uses_axis_map: bool = uses_axis_map - self.valid_packed_dims = set() - if valid_packed_dims is not None: - self.valid_packed_dims = valid_packed_dims - - def valid_memory_layouts(self) -> Set[VkMemoryLayout]: - """ - Derive the set of memory layouts supported by the texture implementation based - on the valid packed dimensions. - """ - layouts = set() - - if PackedDim.WIDTH in self.valid_packed_dims: - layouts.add(VkMemoryLayout.TENSOR_WIDTH_PACKED) - - if PackedDim.HEIGHT in self.valid_packed_dims: - layouts.add(VkMemoryLayout.TENSOR_HEIGHT_PACKED) - - if PackedDim.CHANNELS in self.valid_packed_dims: - layouts.add(VkMemoryLayout.TENSOR_CHANNELS_PACKED) - - return layouts - - class OpFeatures: __slots__ = [ - # None or TextureImplFeatures to specify implementation details of the texture - # based operator implementation. - "texture_impl", - # bool indicating if the operator has a buffer based implementation. 
- "buffer_impl", + # Sets of possible (storage types, memory layouts) to use for the input tensor(s) + "inputs_storage", + # Sets of possible (storage types, memory layouts) to use for the output tensor(s) + "outputs_storage", # bool indicating if the operator has a resize function, which allows it to - # support dynamic shape tensors. - "resize_fn", - # Optimal - "optimal_storage", - "optimal_layout", + # support models with dynamic shape + "supports_resize", # bool indicating if the operator handles its own prepacking. If this is True, # then the insert_prepack_nodes pass will not insert prepack nodes for the args # of the op. - "handles_own_prepacking", - # Optional dictionary to specify a custom function to calculate the required - # image extents for a particular argument index. - "skip_limits_check", + "supports_prepacking", # Optional check function used during partitioning to determine if a node's # inputs are supported by the operator implementation. - "check_node_fn", + "are_node_inputs_supported_fn", ] def __init__( self, - texture_impl: Optional[TextureImplFeatures] = None, - buffer_impl: bool = False, - resize_fn: bool = False, - optimal_storage: Optional[VkStorageType] = None, - optimal_layout: Optional[VkMemoryLayout] = None, - handles_own_prepacking: bool = False, - skip_limits_check: Optional[Set[int]] = None, - check_node_fn: Optional[Callable] = None, + inputs_storage: Optional[ + Union[utils.TensorRepSet, List[utils.TensorRepSet]] + ] = None, + outputs_storage: Optional[ + Union[utils.TensorRepSet, List[utils.TensorRepSet]] + ] = None, + supports_resize: bool = False, + supports_prepacking: bool = False, + are_node_inputs_supported_fn: Optional[Callable] = allow_node, ): - self.texture_impl: Optional[TextureImplFeatures] = texture_impl - self.buffer_impl: bool = buffer_impl - self.resize_fn: bool = resize_fn - self.optimal_storage: Optional[VkStorageType] = optimal_storage - self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout - self.handles_own_prepacking: bool = handles_own_prepacking - - self.skip_limits_check: Set[int] = set() - if skip_limits_check is not None: - self.skip_limits_check = skip_limits_check - - self.check_node_fn: Callable = allow_node - if check_node_fn is not None: - self.check_node_fn = check_node_fn - - def propose_storage_type(self) -> Optional[VkStorageType]: - """ - Propose a storage type that should be used for this operator. A proposal can be - made if one of the following is true: - 1. The operator specifies an optimal storage type - 2. Only one storage type is supported. - - If both storage types are supported and no optimal storage type is specified, - then None is returned to indicate that there is no preference in storage type. - """ - if self.optimal_storage is not None: - return self.optimal_storage - - if self.texture_impl is not None and not self.buffer_impl: - return VkStorageType.TEXTURE_3D - elif self.buffer_impl and self.texture_impl is None: - return VkStorageType.BUFFER - - return None - - def supported_storage_types(self) -> Set[VkStorageType]: - """ - Return the set of storage types supported by this operator. - """ - storage_types = set() - if self.texture_impl is not None: - storage_types.add(VkStorageType.TEXTURE_3D) - if self.buffer_impl: - storage_types.add(VkStorageType.BUFFER) - - return storage_types - - def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]: - """ - Given a storage type as a precondition, propose a memory layout that should be - used for this operator. 
A proposal can be made if one of the following is true: - 1. The operator specifies an optimal memory layout - 2. Only one memory layout is supported. - - If multiple memory layouts are supported and no optimal memory layout is - specified then return None to indicate that the "best" memory layout for the - operator is ambiguous. - """ - if self.optimal_layout is not None: - return self.optimal_layout - - if storage == VkStorageType.TEXTURE_3D: - assert self.texture_impl is not None - possible_layouts = self.texture_impl.valid_memory_layouts() - if len(possible_layouts) == 1: - return next(iter(possible_layouts)) - - return None - - def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]: - """ - Return the set of memory layouts supported by this operator for a given storage - type. - """ - if storage == VkStorageType.TEXTURE_3D: - assert self.texture_impl is not None - return self.texture_impl.valid_memory_layouts() - else: - return all_memory_layouts + self.inputs_storage: utils.TensorRepSetList = utils.TensorRepSetList( + inputs_storage if inputs_storage is not None else [] + ) + self.outputs_storage: utils.TensorRepSetList = utils.TensorRepSetList( + outputs_storage if outputs_storage is not None else [] + ) + + # If output storage is not set, assume that it is derived from the first input + if self.outputs_storage.any_is_empty(): + self.outputs_storage = utils.TensorRepSetList(self.inputs_storage[0]) + + self.supports_resize = supports_resize + self.supports_prepacking = supports_prepacking + + self.are_node_inputs_supported_fn = are_node_inputs_supported_fn + + def make_op_repsets( + self, + op_node: torch.fx.Node, + texture_limits: utils.ImageExtents = utils.DEFAULT_TEXTURE_LIMITS, + ) -> utils.OpRepSets: + return utils.OpRepSets( + self.inputs_storage, self.outputs_storage, op_node, texture_limits + ) ####################### @@ -204,8 +100,7 @@ def features_decorator(fn: Callable): def update_features_impl(op: OpKey): if op in vulkan_supported_ops: raise RuntimeError(f"[Vulkan delegate] duplicate registration of {op}!") - vulkan_supported_ops[op] = OpFeatures() - vulkan_supported_ops[op] = fn(vulkan_supported_ops[op]) + vulkan_supported_ops[op] = fn() if isinstance(aten_op, list): for op in aten_op: @@ -233,14 +128,11 @@ def update_features_impl(op: OpKey): torch.ops.aten.sym_constrain_range_for_size.default, ] ) -def register_ephemeral_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - uses_axis_map=True, - valid_packed_dims=all_packed_dims, +def register_ephemeral_op(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + supports_resize=True, ) - features.buffer_impl = True - features.resize_fn = True - return features @update_features( @@ -253,23 +145,13 @@ def register_ephemeral_op(features: OpFeatures): exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_token.default, exir_ops.edge.quantized_decomposed.dequantize_per_token.default, - exir_ops.edge.quantized_decomposed.choose_qparams.tensor, - exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default, ] ) -def register_quantization_op(features: OpFeatures): - # Quantization requires buffer storage and width packing for scales/zero_points - # but we need to provide texture impl features for the partitioner to work properly - features.texture_impl = TextureImplFeatures( - uses_axis_map=True, - valid_packed_dims={ - PackedDim.WIDTH, - }, +def register_quantization_op(): + return OpFeatures( 
+ inputs_storage=utils.CONTIGUOUS_BUFFER, + supports_resize=True, ) - features.buffer_impl = True - features.resize_fn = True - features.optimal_storage = VkStorageType.BUFFER - return features @update_features( @@ -278,39 +160,25 @@ def register_quantization_op(features: OpFeatures): exir_ops.edge.torchao.dequantize_affine.default, ] ) -def register_affine_quantization_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - uses_axis_map=False, - valid_packed_dims={PackedDim.WIDTH}, +def register_affine_quantization_op(): + return OpFeatures( + inputs_storage=utils.CONTIGUOUS_BUFFER, + supports_resize=True, ) - features.buffer_impl = True - features.resize_fn = True - features.optimal_storage = VkStorageType.TEXTURE_3D - features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED - features.handles_own_prepacking = True - - return features @update_features( [ exir_ops.edge.torchao.choose_qparams_affine.default, + exir_ops.edge.quantized_decomposed.choose_qparams.tensor, + exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default, ] ) -def register_choose_qparams_affine_op(features: OpFeatures): - # Currently only created a rudimentary buffer implementation for choose_qparams_affine - # since the reduction logic for blocks in texture3d is not trivial to implement in vulkan. - features.texture_impl = TextureImplFeatures( - uses_axis_map=False, - valid_packed_dims={ - PackedDim.WIDTH, - }, +def register_torchao_quantization_op(): + return OpFeatures( + inputs_storage=utils.CONTIGUOUS_BUFFER, + supports_resize=True, ) - features.buffer_impl = True - features.resize_fn = True - features.optimal_storage = VkStorageType.BUFFER - - return features @update_features( @@ -329,13 +197,11 @@ def register_choose_qparams_affine_op(features: OpFeatures): exir_ops.edge.aten.ge.Tensor, ] ) -def register_binary_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - uses_axis_map=True, - valid_packed_dims=all_packed_dims, +def register_binary_op(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + supports_resize=True, ) - features.resize_fn = True - return features @update_features( @@ -358,24 +224,15 @@ def register_binary_op(features: OpFeatures): exir_ops.edge.aten.leaky_relu.default, ] ) -def register_unary_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - uses_axis_map=True, - valid_packed_dims=all_packed_dims, +def register_unary_op(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + supports_resize=True, ) - features.buffer_impl = True - features.resize_fn = True - return features @update_features(exir_ops.edge.aten._to_copy.default) -def register_to_copy_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - uses_axis_map=True, - valid_packed_dims=all_packed_dims, - ) - features.resize_fn = True - +def register_to_copy_op(): def check_to_copy_node(node: torch.fx.Node) -> bool: float_dtypes = [torch.float16, torch.float32] @@ -395,20 +252,15 @@ def check_to_copy_node(node: torch.fx.Node) -> bool: return False - features.check_node_fn = check_to_copy_node - - return features + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + supports_resize=True, + are_node_inputs_supported_fn=check_to_copy_node, + ) @update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default) -def register_to_copy_dim_order_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - uses_axis_map=True, - valid_packed_dims=all_packed_dims, - ) - features.buffer_impl = True - 
features.resize_fn = True - +def register_to_copy_dim_order_op(): # Currently there is no "real" implementation for to_dim_order_copy, but it can be # removed as long as the operator is not changing the dtype, i.e. the operator call # is modifying the dim order only. Therefore, check that the input and output dtypes @@ -426,9 +278,11 @@ def check_dim_order_copy_node(node: torch.fx.Node) -> bool: return True - features.check_node_fn = check_dim_order_copy_node - - return features + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + supports_resize=True, + are_node_inputs_supported_fn=check_dim_order_copy_node, + ) @update_features( @@ -439,20 +293,12 @@ def check_dim_order_copy_node(node: torch.fx.Node) -> bool: exir_ops.edge.aten.linear.default, ] ) -def register_mm_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - uses_axis_map=True, - valid_packed_dims={ - PackedDim.WIDTH, - PackedDim.CHANNELS, - }, +def register_mm_op(): + return OpFeatures( + inputs_storage=utils.CONTIGUOUS_ANY, + supports_resize=True, + supports_prepacking=True, ) - features.buffer_impl = True - features.resize_fn = True - features.optimal_storage = VkStorageType.TEXTURE_3D - features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED - features.handles_own_prepacking = True - return features @update_features( @@ -461,37 +307,46 @@ def register_mm_op(features: OpFeatures): exir_ops.edge.et_vk.linear_qcs4w.default, ] ) -def register_int8_mm_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - uses_axis_map=False, - valid_packed_dims={PackedDim.WIDTH}, +def register_int8_mm_op(): + return OpFeatures( + inputs_storage=utils.CONTIGUOUS_ANY, + supports_resize=True, + supports_prepacking=True, ) - features.buffer_impl = True - features.resize_fn = True - features.optimal_storage = VkStorageType.TEXTURE_3D - features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED - features.handles_own_prepacking = True - return features @update_features( [ exir_ops.edge.et_vk.linear_weight_int4.default, + ] +) +def register_int4_mm_op(): + return OpFeatures( + inputs_storage=utils.CONTIGUOUS_ANY, + supports_resize=True, + supports_prepacking=True, + ) + + +@update_features( + [ exir_ops.edge.et_vk.linear_qta8a_qga4w.default, ] ) -def register_int4_mm_op(features: OpFeatures): - features.buffer_impl = True - features.texture_impl = TextureImplFeatures( - uses_axis_map=False, - valid_packed_dims={PackedDim.WIDTH}, +def register_dqlinear_op(): + return OpFeatures( + inputs_storage=[ + utils.CONTIGUOUS_ANY, # input + utils.CONTIGUOUS_BUFFER, # mat1 scales + utils.CONTIGUOUS_BUFFER, # mat1 zeros + utils.NO_STORAGE, # weight (prepacked) + utils.NO_STORAGE, # group size (non tensor) + utils.CONTIGUOUS_BUFFER, # mat2 scales + utils.CONTIGUOUS_BUFFER, # mat2 zeros + ], + supports_resize=True, + supports_prepacking=True, ) - features.resize_fn = True - features.optimal_storage = VkStorageType.TEXTURE_3D - features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED - features.handles_own_prepacking = True - features.skip_limits_check = {1} - return features @update_features( @@ -500,12 +355,11 @@ def register_int4_mm_op(features: OpFeatures): exir_ops.edge.aten._softmax.default, ] ) -def register_softmax_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims=all_packed_dims, +def register_softmax_op(): + return OpFeatures( + inputs_storage=utils.ANY_TEXTURE, + supports_resize=True, ) - features.resize_fn = True - return features @update_features( @@ 
-516,25 +370,24 @@ def register_softmax_op(features: OpFeatures): exir_ops.edge.aten.amin.default, ] ) -def register_reduce_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims=all_packed_dims, - ) - features.resize_fn = True - +def register_reduce_op(): def check_reduce_node(node: torch.fx.Node) -> bool: dim_list = node.args[1] if isinstance(dim_list, list) and len(dim_list) != 1: return False - keepdim = node.args[2] - if isinstance(keepdim, bool) and not keepdim: - return False + if len(node.args) > 2: + keepdim = node.args[2] + if isinstance(keepdim, bool) and not keepdim: + return False return True - features.check_node_fn = check_reduce_node - return features + return OpFeatures( + inputs_storage=utils.ANY_TEXTURE, + supports_resize=True, + are_node_inputs_supported_fn=check_reduce_node, + ) @update_features( @@ -543,12 +396,11 @@ def check_reduce_node(node: torch.fx.Node) -> bool: exir_ops.edge.aten.max_pool2d_with_indices.default, ] ) -def register_2d_pool_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims={PackedDim.CHANNELS}, +def register_2d_pool_op(): + return OpFeatures( + inputs_storage=utils.CHANNELS_PACKED_TEXTURE, + supports_resize=True, ) - features.resize_fn = True - return features @update_features( @@ -557,28 +409,21 @@ def register_2d_pool_op(features: OpFeatures): exir_ops.edge.et_vk.conv_with_clamp.default, ] ) -def register_convolution_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims={PackedDim.CHANNELS}, +def register_convolution_op(): + return OpFeatures( + inputs_storage=utils.CHANNELS_PACKED_TEXTURE, + supports_resize=True, + supports_prepacking=True, ) - features.resize_fn = True - features.optimal_storage = VkStorageType.TEXTURE_3D - features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED - features.handles_own_prepacking = True - features.skip_limits_check = {1, 2} - return features @update_features("llama::sdpa_with_kv_cache") -def register_sdpa_with_kv_cache_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims={PackedDim.WIDTH}, +def register_sdpa_with_kv_cache_op(): + return OpFeatures( + inputs_storage=utils.WIDTH_PACKED_TEXTURE, + supports_resize=True, + supports_prepacking=True, ) - features.resize_fn = True - features.optimal_storage = VkStorageType.TEXTURE_3D - features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED - features.handles_own_prepacking = True - return features @update_features( @@ -587,23 +432,19 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures): "llama::custom_sdpa", ] ) -def register_sdpa_ops(features: OpFeatures): - features.resize_fn = False - features.buffer_impl = False - features.texture_impl = TextureImplFeatures( - valid_packed_dims={PackedDim.WIDTH}, +def register_sdpa_ops(): + return OpFeatures( + inputs_storage=utils.WIDTH_PACKED_TEXTURE, + supports_resize=True, ) - features.resize_fn = True - return features @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default) -def register_rotary_emb_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims={PackedDim.WIDTH}, +def register_rotary_emb_op(): + return OpFeatures( + inputs_storage=utils.WIDTH_PACKED_TEXTURE, + supports_resize=True, ) - features.resize_fn = True - return features @update_features( @@ -614,25 +455,18 @@ def register_rotary_emb_op(features: OpFeatures): exir_ops.edge.aten.view_copy.default, ] ) -def register_view_ops(features: 
OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims=all_packed_dims, +def register_view_ops(): + return OpFeatures( + inputs_storage=utils.ANY_TEXTURE, + supports_resize=True, ) - features.resize_fn = True - return features # Fully featured transfer operators (i.e. operators that copy data from the input # tensor(s) to the output tensor(s)), which have memory layout agnostic implementations # for both texture and buffer storage types. @update_features(exir_ops.edge.aten.cat.default) -def register_cat_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims=all_packed_dims, - ) - features.buffer_impl = True - features.resize_fn = True - +def register_cat_op(): def check_cat_node(node: torch.fx.Node) -> bool: inputs = node.args[0] if isinstance(inputs, (list, tuple)) and len(inputs) <= 3: @@ -640,9 +474,11 @@ def check_cat_node(node: torch.fx.Node) -> bool: return False - features.check_node_fn = check_cat_node - - return features + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + supports_resize=True, + are_node_inputs_supported_fn=check_cat_node, + ) # Fully featured transfer operators (i.e. operators that copy data from the input @@ -654,14 +490,11 @@ def check_cat_node(node: torch.fx.Node) -> bool: exir_ops.edge.aten.slice_copy.Tensor, ] ) -def register_transfer_ops(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims=all_packed_dims, +def register_transfer_ops(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + supports_resize=True, ) - features.buffer_impl = True - features.resize_fn = True - - return features # Ops ported from PyTorch Vulkan backend. These ops commonly support channels @@ -688,14 +521,13 @@ def register_transfer_ops(features: OpFeatures): exir_ops.edge.et_vk.grid_priors.default, ] ) -def register_ported_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims={PackedDim.CHANNELS}, +def register_ported_op(): + return OpFeatures( + inputs_storage=utils.CHANNELS_PACKED_TEXTURE, ) - return features -# Ops ported from PyTorch Vulkan backend. These ops are in a separate registry becasue they support all packed dimensions +# Ops ported from PyTorch Vulkan backend. These ops are in a separate registry because they support all packed dimensions @update_features( [ # Shape Manipulation @@ -707,11 +539,10 @@ def register_ported_op(features: OpFeatures): exir_ops.edge.aten.split.Tensor, ] ) -def register_ported_op_all_packed_dims(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims=all_packed_dims, +def register_ported_op_all_packed_dims(): + return OpFeatures( + inputs_storage=utils.ANY_TEXTURE, ) - return features # Ported ops that support their own prepacking. 
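 
+# Note on the registration pattern used throughout this file: each decorated
+# function now constructs and returns an OpFeatures directly instead of mutating
+# a pre-made instance. An illustrative (hypothetical) registration:
+#
+#   @update_features(exir_ops.edge.aten.some_op.default)
+#   def register_some_op():
+#       return OpFeatures(
+#           inputs_storage=utils.ANY_TEXTURE,
+#           supports_resize=True,
+#       )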
@@ -721,12 +552,11 @@ def register_ported_op_all_packed_dims(features: OpFeatures): exir_ops.edge.aten._native_batch_norm_legit_no_training.default, ] ) -def register_ported_ops_with_prepacking(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims={PackedDim.CHANNELS}, +def register_ported_ops_with_prepacking(): + return OpFeatures( + inputs_storage=utils.CHANNELS_PACKED_TEXTURE, + supports_prepacking=True, ) - features.handles_own_prepacking = True - return features @update_features( @@ -734,25 +564,16 @@ def register_ported_ops_with_prepacking(features: OpFeatures): exir_ops.edge.aten.native_group_norm.default, ] ) -def register_native_group_norm(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims={PackedDim.CHANNELS}, +def register_native_group_norm(): + return OpFeatures( + inputs_storage=utils.CHANNELS_PACKED_TEXTURE, + outputs_storage=[ + utils.CHANNELS_PACKED_TEXTURE, + utils.CONTIGUOUS_BUFFER, + utils.CONTIGUOUS_BUFFER, + ], + supports_prepacking=True, ) - features.handles_own_prepacking = True - - features.optimal_storage = [ - VkStorageType.TEXTURE_3D, - VkStorageType.BUFFER, - VkStorageType.BUFFER, - ] - - features.optimal_layout = [ - VkMemoryLayout.TENSOR_CHANNELS_PACKED, - VkMemoryLayout.TENSOR_WIDTH_PACKED, - VkMemoryLayout.TENSOR_WIDTH_PACKED, - ] - - return features # Ported ops that support their own prepacking. @@ -761,12 +582,11 @@ def register_native_group_norm(features: OpFeatures): exir_ops.edge.aten.native_layer_norm.default, ] ) -def register_ported_ops_with_prepacking_all_dims(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - valid_packed_dims=all_packed_dims, +def register_ported_ops_with_prepacking_all_dims(): + return OpFeatures( + inputs_storage=utils.ANY_TEXTURE, + supports_prepacking=True, ) - features.handles_own_prepacking = True - return features ####################### @@ -774,7 +594,7 @@ def register_ported_ops_with_prepacking_all_dims(features: OpFeatures): ####################### -def has_impl(target: OpKey) -> bool: +def has_impl(target: Any) -> bool: if not isinstance(target, str): if target not in vulkan_supported_ops: return target.name() in vulkan_supported_ops @@ -783,7 +603,7 @@ def has_impl(target: OpKey) -> bool: return target in vulkan_supported_ops -def get_op_features(target: OpKey) -> OpFeatures: +def get_op_features(target: Any) -> OpFeatures: if not isinstance(target, str): if target not in vulkan_supported_ops: # Try the op's name @@ -795,4 +615,4 @@ def get_op_features(target: OpKey) -> OpFeatures: def handles_own_prepacking(target: OpKey) -> bool: - return get_op_features(target).handles_own_prepacking + return get_op_features(target).supports_prepacking diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 9b76f6acd33..776d1d6e168 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -83,61 +83,18 @@ def op_node_is_compatible( # noqa: C901: Function is too complex return False, "no operator implementation" features = get_op_features(target) - # Check for high dimensional tensors - if utils.is_tensor_node(node) and utils.tensor_node_is_high_dim(node): - return False, "contains high dim tensor" - - valid_texture_layouts = utils.possible_node_memory_layouts( + # Get the possible tensor representations for each tensor participating in the + # this operator. 
Then check that all tensors are representable as either a + # buffer or texture. + op_repsets: utils.OpRepSets = features.make_op_repsets( node, self.texture_limits ) - can_use_buffers = utils.within_buffer_limit(node, self.buffer_limit) - for i, arg in enumerate(node.args): - if ( - isinstance(arg, torch.fx.Node) - and utils.is_tensor_node(arg) - and i not in features.skip_limits_check - ): - # Check for bool inputs - if utils.tensor_node_is_bool(arg): - return False, "contains bool tensor" - - # Check for high dimensional tensors - if utils.tensor_node_is_high_dim(arg): - return False, "contains high dim tensor" - - arg_texture_layouts = utils.possible_node_memory_layouts( - arg, self.texture_limits - ) - valid_texture_layouts = valid_texture_layouts.intersection( - arg_texture_layouts - ) - can_use_buffers = can_use_buffers and utils.within_buffer_limit( - arg, self.buffer_limit - ) - - op_available_layouts = features.supported_memory_layouts( - VkStorageType.TEXTURE_3D - ) - - can_use_texture = any( - layout in op_available_layouts for layout in valid_texture_layouts - ) - - # If there are no valid texture memory layouts, then buffer storage must be - # supported by the operator implementation. - if not can_use_texture: - if not can_use_buffers: - return ( - False, - f"op requires buffers that exceed the buffer limit ({self.buffer_limit})", - ) - - compatible = VkStorageType.BUFFER in features.supported_storage_types() - reason = "op is compatible" - if not compatible: - reason = "op requires buffers which is not supported by op impl" - return compatible, reason + if op_repsets.any_is_empty(): + return ( + False, + "No valid representations for a tensor in the operation", + ) return True, "Op is compatible" @@ -266,11 +223,11 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: assert features is not None - if not features.check_node_fn(node): + if not features.are_node_inputs_supported_fn(node): self.log_skip(node, "op args not supported") return False - if self.require_dynamic_shapes and not features.resize_fn: + if self.require_dynamic_shapes and not features.supports_resize: self.log_skip(node, "no dynamic shape support") return False @@ -331,7 +288,10 @@ def __init__( def ops_to_not_decompose( self, ep: ExportedProgram ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: - return (ops_not_to_decompose, None) + def filter_fn(node: torch.fx.Node) -> bool: + return True + + return (ops_not_to_decompose, filter_fn) def partition(self, exported_program: ExportedProgram) -> PartitionResult: # Run the CapabilityBasedPartitioner to return the largest possible diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index cd876bd6305..b74a7fb1f8e 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -23,6 +23,7 @@ is_mutable_buffer_node, is_param_node, is_symint_node, + TensorRepr, ) from executorch.exir.backend.utils import DelegateMappingBuilder @@ -135,7 +136,7 @@ def maybe_add_constant_tensor(self, node: Node) -> int: def create_node_value(self, node: Node) -> int: # If the node has been marked as a scalar tensor, create a SymInt instead of a tensor - if is_symint_node(node) or node.meta.get("vkdg_is_scalar_tensor", False): + if is_symint_node(node) or node.meta.get("etvk_is_scalar_tensor", False): new_id = self.create_symint_value() self.node_to_value_ids[node] = new_id return new_id @@ -197,12 +198,11 @@ 
def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: storage_type = VkStorageType.DEFAULT_STORAGE memory_layout = VkMemoryLayout.DEFAULT_LAYOUT - if hasattr(spec, "vk_storage_type"): + if hasattr(spec, "etvk_node_repr"): # pyre-ignore[16] - storage_type = spec.vk_storage_type - if hasattr(spec, "vk_memory_layout"): - # pyre-ignore[16] - memory_layout = spec.vk_memory_layout + assert isinstance(spec.etvk_node_repr, TensorRepr) + storage_type = spec.etvk_node_repr.storage_type + memory_layout = spec.etvk_node_repr.memory_layout # Apply downcast logic before getting VK datatype effective_dtype = spec.dtype diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 926452dd388..4799a22882d 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -1790,25 +1790,21 @@ def forward(self, x): def test_vulkan_backend_large_linear_layer(self): class LinearModel(torch.nn.Module): - def __init__( - self, n_pca_basis: int, n_sh_basis: int, n_gaussians: int - ) -> None: + def __init__(self, large_out_channels: int) -> None: super(LinearModel, self).__init__() - self.fc1 = torch.nn.Linear( - n_pca_basis, (n_sh_basis + 3 + 3 + 4) * n_gaussians - ) + self.fc0 = torch.nn.Linear(1024, 128) + self.fc1 = torch.nn.Linear(128, large_out_channels) def forward(self, x: torch.Tensor): + x = self.fc0(x) out = self.fc1(x) return out - n_pca_basis = 64 - n_sh_basis = 6 - n_gaussians = 2**16 + large_out_channels = 2**16 self.lower_module_and_test_output( - LinearModel(n_pca_basis, n_sh_basis, n_gaussians), - (torch.ones(n_pca_basis),), + LinearModel(large_out_channels), + (torch.ones(1024),), ) def test_vulkan_backend_sym_size_int(self): @@ -2060,3 +2056,97 @@ def forward(self, x): self.lower_module_and_test_output( full_per_token_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3 ) + + def test_vulkan_backend_different_required_reprs(self): + class ComplexModule(torch.nn.Module): + """ + This Module tests the tag memory metadata pass. The first few ops executed + are binary ops, which don't require any specific representation for input + and output tensors. + + This is followed by a linear layer, which requires the input tensor to be + width packed. + + Three linear layer outputs are then concatenated, and the result is passed + to a convolution layer which requires channels packing. Finally, group norm + is called and the output is postprocessed by a binary op before returning. + + In addition to requiring memory layout transitions between the linear and + conv stages, the module also contains ops which have "non-standard" + torch.fx.Nodes; cat will contain an argument node that is a list of nodes, + and group norm's node will be associated with multiple output tensors. 
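+
+            As a result, the pass is expected to insert clone (transition) nodes
+            between the width-packed linear outputs and the channels-packed tensor
+            consumed by the convolution.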
+ """ + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + self.conv = torch.nn.Conv2d( + in_channels=3, # Assuming concatenation triples the channels + out_channels=16, + kernel_size=3, + padding=1, + ) + self.group_norm = torch.nn.GroupNorm(num_groups=4, num_channels=16) + + def forward(self, x, a, b, c, d): + w = a + b + y = a + c + z = a + d + + b1 = x + y + b2 = x + z + b3 = x + w + + l1 = self.linear(b1).unsqueeze(0) + l2 = self.linear(b2).unsqueeze(0) + l3 = self.linear(b3).unsqueeze(0) + + concat = torch.cat([l1, l2, l3], dim=0) # Concatenate along channels + conv = self.conv(concat + a) + g = self.group_norm(conv.unsqueeze(0)) + return g + x + + complex_module = ComplexModule() + sample_inputs = ( + torch.rand(size=(10, 10), dtype=torch.float32), # x + torch.rand(size=(10, 10), dtype=torch.float32), # a + torch.rand(size=(10, 10), dtype=torch.float32), # b + torch.rand(size=(10, 10), dtype=torch.float32), # c + torch.rand(size=(10, 10), dtype=torch.float32), # d + ) + + self.lower_module_and_test_output(complex_module, sample_inputs) + + def test_vulkan_backend_cat_different_reprs(self): + class CustomComplexModule(torch.nn.Module): + """ + This test validates that the memory metadata tagging pass can handle + transitioning arguments to the cat operator. Linear layers require width + packing, while conv layers require channels packing. Before executing the + cat operator, all input tensors should use the same representation. + """ + + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10, 10) + self.linear2 = torch.nn.Linear(10, 10) + self.conv = torch.nn.Conv2d( + in_channels=4, # Assuming input b has 3 channels + out_channels=8, + kernel_size=3, + padding=1, + ) + + def forward(self, a, b): + x1 = self.linear1(a).unsqueeze(0) + x2 = self.linear2(a).unsqueeze(0) + y = self.conv(b) + return torch.cat([x1, x2, y], dim=0) + + custom_complex_module = CustomComplexModule() + sample_inputs = ( + torch.rand(size=(10, 10), dtype=torch.float32), # a + torch.rand(size=(4, 10, 10), dtype=torch.float32), # b + ) + + self.lower_module_and_test_output(custom_complex_module, sample_inputs) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 9086b2d0792..fa45063a4d3 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -4,8 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from enum import IntEnum
-from typing import Optional, Set, Tuple
+import operator
+from typing import Any, List, Optional, Set, Tuple, Union
 
 import torch
 
@@ -50,6 +50,9 @@
 ## Node type determination
 ##
 
+# Convenience type
+MaybeNodeList = Union[torch.fx.Node, List[torch.fx.Node], Tuple[torch.fx.Node, ...]]
+
 
 def is_dequant_node(node: torch.fx.Node) -> bool:
     if node.op != "call_function":
@@ -121,10 +124,42 @@ def is_symint_node(node: torch.fx.Node) -> bool:
     return False
 
 
-def is_tensor_node(node: torch.fx.Node) -> bool:
+def is_single_tensor_node(node: torch.fx.Node) -> bool:
+    """
+    Returns true if the given node produces a single tensor value
+    """
+    if "val" not in node.meta:
+        return False
+
+    if isinstance(node.meta["val"], FakeTensor):
+        return True
+
+    return False
+
+
+def is_tensor_collection_node(node: Any) -> bool:
+    """
+    Returns true if the given node produces a collection of tensor values
+    """
+    if not isinstance(node, torch.fx.Node):
+        return False
+
+    if "val" not in node.meta:
+        return False
+
+    if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple):
+        return all(isinstance(x, FakeTensor) for x in node.meta["val"])
+
+    return False
+
+
+def is_tensor_node(node: Any) -> bool:
     """
     Returns true if the given node produces a tensor value, or a collection of tensor
     values
     """
+    if not isinstance(node, torch.fx.Node):
+        return False
+
     if "val" not in node.meta:
         return False
 
@@ -137,6 +172,47 @@ def is_tensor_node(node: torch.fx.Node) -> bool:
     return False
 
 
+def is_tensor_arg_node(node: Any) -> bool:
+    if isinstance(node, torch.fx.Node):
+        return is_tensor_node(node)
+    elif isinstance(node, (list, tuple)):
+        return all(is_tensor_node(n) for n in node)
+
+    return False
+
+
+def num_tensor_arg_nodes(node: torch.fx.Node) -> int:
+    """
+    For a given node, return the number of argument nodes that are associated with
+    tensors.
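+
+    For example, for a node representing aten.add.Tensor(x, y) where x and y are
+    tensor-producing nodes, this returns 2.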
+    """
+    count = 0
+    for arg_node in node.args:
+        if not isinstance(arg_node, torch.fx.Node):
+            continue
+        if is_tensor_node(arg_node):
+            count += 1
+
+    return count
+
+
+def num_tensors_in_node(node: torch.fx.Node) -> int:
+    """
+    Returns the number of tensors associated with a given node
+    """
+    if "val" not in node.meta:
+        return 0
+
+    if isinstance(node.meta["val"], FakeTensor):
+        return 1
+
+    if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple):
+        if all(isinstance(x, FakeTensor) for x in node.meta["val"]):
+            return len(node.meta["val"])
+
+    return 0
+
+
 def tensor_node_is_bool(node: torch.fx.Node) -> bool:
     """
     Returns true if a given node contains a tensor with bool dtype
@@ -151,6 +227,15 @@ def tensor_node_is_bool(node: torch.fx.Node) -> bool:
     return False
 
 
+def get_primary_arg_idx(node: torch.fx.Node) -> Optional[int]:
+    """
+    Return the index of the first tensor argument of the given node, or None if the
+    node has no tensor arguments.
+    """
+    for i, arg_node in enumerate(node.args):
+        if is_tensor_arg_node(arg_node):
+            return i
+
+    return None
+
+
 ##
 ## Memory Layout, Storage Type Determination
 ##
@@ -160,19 +245,6 @@ def tensor_node_is_bool(node: torch.fx.Node) -> bool:
 
 DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048)
 DEFAULT_BUFFER_LIMIT = 128 * (1024 * 1024)
 
-
-class PackedDim(IntEnum):
-    WIDTH = 0
-    HEIGHT = 1
-    CHANNELS = 2
-
-
-all_packed_dims: Set[PackedDim] = {
-    PackedDim.WIDTH,
-    PackedDim.HEIGHT,
-    PackedDim.CHANNELS,
-}
-
 all_storage_types: Set[VkStorageType] = {
     VkStorageType.BUFFER,
     VkStorageType.TEXTURE_3D,
@@ -184,6 +256,9 @@ class PackedDim(IntEnum):
     VkMemoryLayout.TENSOR_CHANNELS_PACKED,
 }
 
+MemoryLayoutSet = Set[VkMemoryLayout]
+MemoryLayoutSetList = Union[MemoryLayoutSet, List[MemoryLayoutSet]]
+
 
 def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> int:
     """
@@ -257,24 +332,622 @@ def valid_texture_memory_layouts(
     return valid_layouts
 
 
-def possible_node_memory_layouts(
-    node: torch.fx.Node, texture_limits: ImageExtents
-) -> Set[VkMemoryLayout]:
+class TensorRepr:
     """
-    Given a node, determine the set of memory layouts which can be used to represent all
-    tensors involved in the computation.
+    This class is a wrapper around a pair of VkStorageType and VkMemoryLayout which
+    describes how a tensor should be represented in the Vulkan Delegate.
     """
-    assert is_tensor_node(node)
-    if isinstance(node.meta["val"], FakeTensor):
-        return valid_texture_memory_layouts(node.meta["val"].shape, texture_limits)
 
-    valid_layouts = set()
-    if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple):
-        for fake_tensor in node.meta["val"]:
-            valid_layouts = valid_layouts.union(
-                valid_texture_memory_layouts(fake_tensor.shape, texture_limits)
+
+    def __init__(self, storage_type: VkStorageType, memory_layout: VkMemoryLayout):
+        self.storage_type = storage_type
+        self.memory_layout = memory_layout
+
+    def __str__(self) -> str:
+        return f"TensorRepr({self.storage_type}, {self.memory_layout})"
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, TensorRepr):
+            return NotImplemented
+        return (
+            self.storage_type == other.storage_type
+            and self.memory_layout == other.memory_layout
+        )
+
+    def __ne__(self, other: object) -> bool:
+        return not self.__eq__(other)
+
+
+class TensorReprList:
+    """
+    This class is a wrapper around a list of TensorRepr instances that automatically
+    applies a "broadcasting" mechanism. The broadcasting mechanism allows for a single
+    underlying TensorRepr to be used to represent multiple tensors.
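+
+    For example, a list wrapping a single TensorRepr serves that TensorRepr for
+    every index:
+
+        wp_buffer = TensorRepr(
+            VkStorageType.BUFFER, VkMemoryLayout.TENSOR_WIDTH_PACKED
+        )
+        reprs = TensorReprList(wp_buffer)
+        assert reprs[0] == reprs[5]  # len(reprs) == 1, so indexing broadcasts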
+    """
+
+    def __init__(self, tensor_reprs: Union[TensorRepr, List[TensorRepr]]):
+        self.vals: List[TensorRepr] = (
+            tensor_reprs if isinstance(tensor_reprs, list) else [tensor_reprs]
+        )
+
+    def __len__(self):
+        return len(self.vals)
+
+    def __getitem__(self, idx: int) -> TensorRepr:
+        if idx > 0 and len(self) == 1:
+            return self.vals[0]
+        else:
+            return self.vals[idx]
+
+    def __setitem__(self, idx: int, val: TensorRepr) -> None:
+        if idx > 0 and len(self) == 1:
+            self.vals[0] = val
+        else:
+            self.vals[idx] = val
+
+    def __str__(self) -> str:
+        return f"[{', '.join(str(ts) for ts in self.vals)}]"
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, TensorReprList):
+            return NotImplemented
+
+        if len(self) == len(other):
+            for self_val, other_val in zip(self.vals, other.vals):
+                if self_val != other_val:
+                    return False
+
+            return True
+
+        return False
+
+    def __ne__(self, other: object) -> bool:
+        return not self.__eq__(other)
+
+    def append(self, val: TensorRepr) -> None:
+        self.vals.append(val)
+
+    def storage_type(self, idx: int = 0) -> VkStorageType:
+        return self[idx].storage_type
+
+    def memory_layout(self, idx: int = 0) -> VkMemoryLayout:
+        return self[idx].memory_layout
+
+
+class TensorRepSet:
+    """
+    This class describes the possible set of representations (i.e. TensorRepr) that may
+    be used to represent a tensor. This set is determined by the implementation of the
+    operator that the tensor participates in as well as the texture extents of the GPU.
+    """
+
+    def __init__(
+        self,
+        buffer_memory_layouts: Set[VkMemoryLayout],
+        texture_memory_layouts: Set[VkMemoryLayout],
+    ):
+        self.valid_buffer_layouts = buffer_memory_layouts
+        self.valid_texture_layouts = texture_memory_layouts
+
+    def __str__(self) -> str:
+        buffer_layouts = ", ".join(layout.name for layout in self.valid_buffer_layouts)
+        texture_layouts = ", ".join(
+            layout.name for layout in self.valid_texture_layouts
+        )
+        return f"TensorRepSet(Buffer Layouts: [{buffer_layouts}], Texture Layouts: [{texture_layouts}])"
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, TensorRepSet):
+            return NotImplemented
+        return (
+            self.valid_buffer_layouts == other.valid_buffer_layouts
+            and self.valid_texture_layouts == other.valid_texture_layouts
+        )
+
+    def __ne__(self, other: object) -> bool:
+        return not self.__eq__(other)
+
+    def is_empty(self) -> bool:
+        """
+        A TensorRepSet is "empty" if there are no valid representations of the tensor.
+        """
+        return (
+            len(self.valid_buffer_layouts) == 0 and len(self.valid_texture_layouts) == 0
+        )
+
+    def make_intersect(self, other: "TensorRepSet") -> "TensorRepSet":
+        """
+        Merge this TensorRepSet with another TensorRepSet, returning a new TensorRepSet
+        containing the intersection of the two.
+        """
+        return TensorRepSet(
+            self.valid_buffer_layouts & other.valid_buffer_layouts,
+            self.valid_texture_layouts & other.valid_texture_layouts,
+        )
+
+    def is_compatible(self, storage: TensorRepr) -> bool:
+        """
+        Check if the given TensorRepr is contained in this TensorRepSet.
+        """
+        if storage.storage_type == VkStorageType.BUFFER:
+            return storage.memory_layout in self.valid_buffer_layouts
+        elif storage.storage_type == VkStorageType.TEXTURE_3D:
+            return storage.memory_layout in self.valid_texture_layouts
+        else:
+            raise RuntimeError(f"Unsupported storage type {storage.storage_type}")
+
+    def any_in_common(self, other: "TensorRepSet") -> bool:
+        """
+        Check if this TensorRepSet has any representations in common with another
+        TensorRepSet.
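+
+        For example, a repset with texture layouts {WIDTH_PACKED} and one with
+        texture layouts {WIDTH_PACKED, CHANNELS_PACKED} share a texture layout,
+        so this returns True.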
+        """
+        return (
+            len(self.valid_buffer_layouts & other.valid_buffer_layouts) > 0
+            or len(self.valid_texture_layouts & other.valid_texture_layouts) > 0
+        )
+
+    def texture_is_valid(self) -> bool:
+        return len(self.valid_texture_layouts) > 0
+
+    def buffer_is_valid(self) -> bool:
+        return len(self.valid_buffer_layouts) > 0
+
+    def first_valid_buffer_layout(self) -> VkMemoryLayout:
+        return list(self.valid_buffer_layouts)[0]
+
+    def first_valid_texture_layout(self) -> VkMemoryLayout:
+        return list(self.valid_texture_layouts)[0]
+
+    def make_tensor_repr(self) -> TensorRepr:
+        """
+        Pick a representation (i.e. TensorRepr) from the set of possible representations.
+        If there are multiple valid representations, then:
+        1. Prefer texture storage over buffer storage
+        2. Pick the first available memory layout.
+        """
+        if self.is_empty():
+            # An empty repset typically means that it is associated with a weight
+            # tensor or non-tensor argument. In this case, just return default
+            # storage and layout as a placeholder.
+            return TensorRepr(
+                VkStorageType.DEFAULT_STORAGE, VkMemoryLayout.DEFAULT_LAYOUT
            )
-    return valid_layouts
 
+        if self.texture_is_valid():
+            return TensorRepr(
+                VkStorageType.TEXTURE_3D, self.first_valid_texture_layout()
+            )
+
+        else:
+            return TensorRepr(VkStorageType.BUFFER, self.first_valid_buffer_layout())
+
+    def is_constrained(self) -> bool:
+        """
+        A "constrained" RepSet is one that has either:
+        1. A single valid texture memory layout, and no valid buffer memory layouts
+        2. No valid texture memory layouts, and a single valid buffer memory layout
+        3. Is empty
+
+        In this case, it is unambiguous which representation should be used for the
+        tensor.
+        """
+        if self.is_empty():
+            return True
+        elif (
+            len(self.valid_texture_layouts) == 1 and len(self.valid_buffer_layouts) == 0
+        ):
+            return True
+        elif (
+            len(self.valid_texture_layouts) == 0 and len(self.valid_buffer_layouts) == 1
+        ):
+            return True
+        else:
+            return False
+
+    def is_ambiguous(self) -> bool:
+        """
+        An "ambiguous" RepSet is one that is not constrained.
+        """
+        return not self.is_constrained()
+
+
+def make_tensor_repset(tensor_repr: TensorRepr) -> TensorRepSet:
+    """
+    Given a TensorRepr, return a TensorRepSet that contains only that TensorRepr
+    """
+    if tensor_repr.storage_type == VkStorageType.BUFFER:
+        return TensorRepSet({tensor_repr.memory_layout}, set())
+    elif tensor_repr.storage_type == VkStorageType.TEXTURE_3D:
+        return TensorRepSet(set(), {tensor_repr.memory_layout})
+    else:
+        raise RuntimeError(f"Unsupported storage type {tensor_repr.storage_type}")
+
+
+def make_filtered_tensor_repset(
+    tensor_val: FakeTensor,
+    tensor_repset: TensorRepSet,
+    texture_limits: ImageExtents,
+) -> TensorRepSet:
+    """
+    `tensor_val` represents an actual tensor participating in some operator computation.
+
+    `tensor_repset` represents the set of tensor representations supported by the op
+    implementation for that tensor.
+
+    `texture_limits` represents the maximum texture sizes that are supported by the GPU.
+
+    Given the above, return a new TensorRepSet that contains only texture layouts that
+    can be used to produce a valid image texture for the given tensor (i.e. fits within
+    texture limits).
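+
+    For illustration, assuming the default 16384x16384x2048 texture limits, a tensor
+    with a dimension of length 20000 will typically not fit in an image texture for
+    layouts that map that dimension to a texture axis, so those layouts are filtered
+    out; buffer layouts are unaffected by texture limits.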
+    """
+    valid_texture_layouts = set()
+    for memory_layout in tensor_repset.valid_texture_layouts:
+        extents = required_image_extents(tensor_val.shape, memory_layout)
+        if extents_are_valid(extents, texture_limits):
+            valid_texture_layouts.add(memory_layout)
+
+    # High dimensional tensors are currently not supported
+    if len(tensor_val.shape) > 4:
+        return NO_STORAGE
+
+    # Bool tensors are currently not supported
+    if tensor_val.dtype == torch.bool:
+        return NO_STORAGE
+
+    return TensorRepSet(tensor_repset.valid_buffer_layouts, valid_texture_layouts)
+
+
+## Convenience TensorRepSet definitions
+
+CONTIGUOUS_ANY = TensorRepSet(
+    {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_WIDTH_PACKED}
+)
+CONTIGUOUS_BUFFER = TensorRepSet({VkMemoryLayout.TENSOR_WIDTH_PACKED}, set())
+
+WIDTH_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_WIDTH_PACKED})
+CHANNELS_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED})
+
+ANY_TEXTURE = TensorRepSet(set(), all_memory_layouts)
+
+ANY_STORAGE = TensorRepSet(all_memory_layouts, all_memory_layouts)
+NO_STORAGE = TensorRepSet(set(), set())
+
+
+class TensorRepSetList:
+    """
+    This class is a wrapper around a list of TensorRepSet instances that automatically
+    applies a "broadcasting" mechanism. The broadcasting mechanism allows for a single
+    underlying TensorRepSet to be used for multiple tensors.
+    """
+
+    def __init__(
+        self,
+        tensor_repsets: Union[TensorRepSet, List[TensorRepSet]],
+    ):
+        self.vals: List[TensorRepSet] = (
+            tensor_repsets if isinstance(tensor_repsets, list) else [tensor_repsets]
+        )
+
+    def __len__(self):
+        return len(self.vals)
+
+    def __getitem__(self, idx: int) -> TensorRepSet:
+        if idx > 0 and len(self) == 1:
+            return self.vals[0]
+        else:
+            return self.vals[idx]
+
+    def __setitem__(self, idx: int, val: TensorRepSet) -> None:
+        if idx > 0 and len(self.vals) == 1:
+            self.vals[0] = val
+        else:
+            self.vals[idx] = val
+
+    def __str__(self) -> str:
+        return f"[{', '.join(str(ts) for ts in self.vals)}]"
+
+    def append(self, val: TensorRepSet) -> None:
+        self.vals.append(val)
+
+    def any_is_empty(self) -> bool:
+        if len(self.vals) == 0:
+            return True
+
+        return any(tensor_repr.is_empty() for tensor_repr in self.vals)
+
+
+class OpRepSets:
+    """
+    This class is responsible for representing and managing the set of valid tensor
+    representations that may be used for all input and output tensors of an operator.
+    It is also responsible for maintaining synchronization rules between tensors
+    participating in the computation.
+
+    Currently, three synchronization rules exist:
+    1. All input tensors must use the same representation (e.g. binary ops)
+    2. The "primary" input and output tensors must use the same representation
+       (e.g. group norm; the output is a tuple of out, mean, rstd; out must use the
+       same representation as the first input x, but mean and rstd may use different
+       representations than out)
+    3. All output tensors must use the same representation (e.g. choose qparams)
+
+    Here, "primary" input and output tensor refers to the first non-weight input
+    tensor and the first output tensor. Note that some operators (such as arange) do
+    not have any tensor inputs.
+
+    Currently, the above three synchronization rules are sufficient to describe the
+    representation requirements of all ET-VK operators.
+
+    This class also provides utilities to constrain the repsets; when applying the
+    constraints, the synchronization rules will be maintained.
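+
+    A rough usage sketch (the repset values involved are hypothetical;
+    `make_op_repsets` is the factory used by the partitioner and tagging pass):
+
+        op_repsets = features.make_op_repsets(node, texture_limits)
+        if not op_repsets.any_is_empty():
+            args_reprs, outs_reprs = op_repsets.pick_representations()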
+    """
+
+    def __init__(  # noqa: C901
+        self,
+        inputs_repsets: TensorRepSetList,
+        outputs_repsets: TensorRepSetList,
+        op_node: torch.fx.Node,
+        texture_limits: ImageExtents,
+    ):
+        self.op_node = op_node
+
+        # inputs_repsets is received from the operator registration. If a different
+        # repset is defined for each input tensor, then assume that the input tensor
+        # representations do not need to be synchronized.
+        if len(inputs_repsets) > 1:
+            self.sync_args_repr = False
+        # Otherwise, default to True
+        else:
+            self.sync_args_repr = True
+
+        # outputs_repsets is received from the operator registration. If a different
+        # repset is defined for each output tensor, then assume that the output tensor
+        # representations do not need to be synchronized.
+        if len(outputs_repsets) > 1:
+            self.sync_outs_repr = False
+        else:
+            self.sync_outs_repr = True
+
+        # Try to determine the index of the "primary" argument, i.e. the first non
+        # constant tensor argument. For the vast majority of operators with tensor
+        # arguments, this will be the first argument.
+        self.primary_arg_idx: Optional[int] = None
+        for i, arg_node in enumerate(self.op_node.args):
+            arg_node_repset = inputs_repsets[i]
+            if not is_tensor_arg_node(arg_node):
+                continue
+            if arg_node_repset is None:
+                continue
+            if arg_node_repset.is_empty():
+                continue
+
+            self.primary_arg_idx = i
+            break
+
+        # If the registered repsets of the primary input and the primary output are
+        # identical, then assume that their representations must be synchronized.
+        self.sync_primary_io_repr = self.primary_arg_idx is not None
+        if self.primary_arg_idx is not None:
+            if inputs_repsets[self.primary_arg_idx] != outputs_repsets[0]:
+                self.sync_primary_io_repr = False
+
+        # Now, go through the arguments of the operator and create a filtered repset
+        # for each based on the actual tensor value.
+        args_repset_list = TensorRepSetList([])
+        common_arg_repset = ANY_STORAGE
+        for i, arg_node in enumerate(op_node.args):
+            arg_repset = inputs_repsets[i]
+
+            # Use ANY_STORAGE for non-tensor nodes so they don't cause the op repsets
+            # to appear empty
+            if not is_tensor_arg_node(arg_node):
+                args_repset_list.append(ANY_STORAGE)
+            # NO_STORAGE is used to denote that an input is either a non-tensor arg
+            # or a weight tensor that is not prepacked. Similar to the above, use
+            # ANY_STORAGE in this case.
+            elif arg_repset.is_empty():
+                args_repset_list.append(ANY_STORAGE)
+            else:
+                assert not arg_repset.is_empty()
+
+                arg_repset = self.make_valid_tensor_repset_for_arg(
+                    arg_repset, arg_node, texture_limits
+                )
+
+                args_repset_list.append(arg_repset)
+                common_arg_repset = common_arg_repset.make_intersect(arg_repset)
+
+        # Repeat for output tensors.
+        outs_repset_list = TensorRepSetList([])
+        common_out_repset = ANY_STORAGE
+        if num_tensors_in_node(op_node) == 1:
+            common_out_repset = make_filtered_tensor_repset(
+                op_node.meta["val"], outputs_repsets[0], texture_limits
+            )
+            outs_repset_list.append(common_out_repset)
+        # Multiple output tensors
+        else:
+            for i, val in enumerate(op_node.meta["val"]):
+                assert isinstance(val, FakeTensor)
+                out_repset = make_filtered_tensor_repset(
+                    val, outputs_repsets[i], texture_limits
+                )
+
+                outs_repset_list.append(out_repset)
+                common_out_repset = common_out_repset.make_intersect(out_repset)
+
+        # Apply synchronization rules; if all inputs (or all outputs) must use the
+        # same representation, then use a single underlying repset for them.
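+        # For example, for a binary op with inputs (x, y), args_repset_list collapses
+        # from [repset_x, repset_y] to the single entry [repset_x & repset_y], which
+        # TensorRepSetList broadcasts to both arguments.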
+        if self.sync_args_repr:
+            args_repset_list = TensorRepSetList([common_arg_repset])
+
+        if self.sync_outs_repr:
+            outs_repset_list = TensorRepSetList([common_out_repset])
+
+        # Finally, apply synchronization rules that sync inputs and outputs. If input
+        # or output repsets are updated, then maintain synchronization rules.
+        if self.sync_primary_io_repr:
+            assert self.primary_arg_idx is not None
+
+            primary_in_repset = args_repset_list[self.primary_arg_idx]
+            primary_out_repset = outs_repset_list[0]
+
+            primary_repset = primary_in_repset.make_intersect(primary_out_repset)
+
+            if self.sync_args_repr:
+                args_repset_list = TensorRepSetList([primary_repset])
+            else:
+                assert self.primary_arg_idx is not None
+                args_repset_list[self.primary_arg_idx] = primary_repset
+
+            if self.sync_outs_repr:
+                outs_repset_list = TensorRepSetList([primary_repset])
+            else:
+                assert self.primary_arg_idx is not None
+                outs_repset_list[0] = primary_repset
+
+        # Save the resulting repsets
+        self.args_repset_list = args_repset_list
+        self.outs_repset_list = outs_repset_list
+
+        # Check that synchronization rules are respected.
+        self.assert_sync_constraints()
+
+    def __str__(self) -> str:
+        return f"OpRepSets(ins={self.args_repset_list}, outs={self.outs_repset_list})"
+
+    def make_valid_tensor_repset_for_node_list_arg(
+        self,
+        arg_repsets: TensorRepSet,
+        arg_node: List[torch.fx.Node],
+        texture_limits: ImageExtents,
+    ) -> TensorRepSet:
+        """
+        Wrapper around make_filtered_tensor_repset for a list of nodes. This will happen
+        for the cat operator, where the first argument is a list of nodes.
+        """
+        # For variable length args, assume that they all need to use the same
+        # representation; only one repset should be defined.
+        common_tensor_repsets = arg_repsets
+
+        for n in arg_node:
+            assert isinstance(n, torch.fx.Node)
+            common_tensor_repsets = common_tensor_repsets.make_intersect(
+                make_filtered_tensor_repset(
+                    n.meta["val"], common_tensor_repsets, texture_limits
+                )
+            )
+
+        return common_tensor_repsets
+
+    def make_valid_tensor_repset_for_arg(
+        self, arg_repsets: TensorRepSet, arg_node: Any, texture_limits: ImageExtents
+    ) -> TensorRepSet:
+        """
+        Helper function to call make_filtered_tensor_repset
+        """
+        if isinstance(arg_node, torch.fx.Node) and is_single_tensor_node(arg_node):
+            return make_filtered_tensor_repset(
+                arg_node.meta["val"], arg_repsets, texture_limits
+            )
+        elif isinstance(arg_node, list) and all(
+            is_single_tensor_node(n) for n in arg_node
+        ):
+            return self.make_valid_tensor_repset_for_node_list_arg(
+                arg_repsets, arg_node, texture_limits
+            )
+        # Special case for getitem; return the repset of the particular val in the
+        # list of tensors that is being extracted.
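+        # For example, if args[0] is a native_group_norm node producing
+        # (out, mean, rstd) and args[1] == 0, the repset is computed from out's
+        # FakeTensor only.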
+        elif (
+            self.op_node.target == operator.getitem and arg_node == self.op_node.args[0]
+        ):
+            idx = self.op_node.args[1]
+            assert isinstance(idx, int)
+            return make_filtered_tensor_repset(
+                arg_node.meta["val"][idx], arg_repsets, texture_limits
+            )
+
+        raise NotImplementedError(f"Unhandled node type {arg_node}")
+
+    def assert_sync_constraints(self) -> None:
+        if self.sync_args_repr:
+            assert len(self.args_repset_list) == 1
+
+        if self.sync_outs_repr:
+            assert len(self.outs_repset_list) == 1
+
+        if self.sync_primary_io_repr:
+            assert (
+                self.args_repset_list[self.primary_arg_idx] == self.outs_repset_list[0]
+            )
+
+    def any_is_empty(self) -> bool:
+        return (
+            self.args_repset_list.any_is_empty() or self.outs_repset_list.any_is_empty()
+        )
+
+    def get_arg_repset(self, i: int) -> TensorRepSet:
+        return self.args_repset_list[i]
+
+    def get_out_repset(self, i: int) -> TensorRepSet:
+        return self.outs_repset_list[i]
+
+    def try_constrain_with_arg_repset(
+        self, arg_i: int, source_repset: TensorRepSet
+    ) -> bool:
+        """
+        Attempt to constrain the repsets of the tensors participating in this operator
+        based on an "existing" repset of an argument. The existing repset can have two
+        sources:
+        * A representation may have been determined for the argument already from a
+          prior operator
+        * The output repset of the operator which produces the argument
+
+        If the existing repset of the argument is compatible with the current operator,
+        then constrain the repsets of this operator and apply synchronization rules.
+
+        This process tries to minimize the number of transition nodes that will need to
+        be inserted by tag_memory_meta_pass.py by maintaining existing representations
+        for as long as possible.
+        """
+        arg_current_repset = self.args_repset_list[arg_i]
+
+        if arg_current_repset == source_repset:
+            return False
+
+        if not arg_current_repset.any_in_common(source_repset):
+            return False
+
+        if self.sync_primary_io_repr:
+            if not self.get_out_repset(0).any_in_common(source_repset):
+                return False
+
+        # If this point is reached, then it is possible to constrain
+        self.args_repset_list[arg_i] = arg_current_repset.make_intersect(source_repset)
+        if self.sync_primary_io_repr and (
+            arg_i == self.primary_arg_idx or self.sync_args_repr
+        ):
+            self.outs_repset_list[0] = arg_current_repset.make_intersect(source_repset)
+
+        self.assert_sync_constraints()
+        return True
+
+    def pick_representations(self) -> Tuple[TensorReprList, TensorReprList]:
+        """
+        For each tensor participating in the op, pick a representation for it among
+        the possible representation sets.
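+
+        Selection follows TensorRepSet.make_tensor_repr: texture storage is preferred
+        over buffer storage, and the first valid memory layout encountered is used.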
+        """
+        args_repr_list = TensorReprList([])
+        outs_repr_list = TensorReprList([])
+
+        for i in range(len(self.op_node.args)):
+            arg_repset = self.args_repset_list[i]
+            args_repr_list.append(arg_repset.make_tensor_repr())
+
+        for i in range(num_tensors_in_node(self.op_node)):
+            out_repset = self.outs_repset_list[i]
+            outs_repr_list.append(out_repset.make_tensor_repr())
+
+        return args_repr_list, outs_repr_list
 
 
 ##
@@ -282,6 +955,10 @@ def possible_node_memory_layouts(
 ##
 
 
+def has_node_spec_attr(node: torch.fx.Node, attr: str) -> bool:
+    return "spec" in node.meta and hasattr(node.meta["spec"], attr)
+
+
 def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
     assert "spec" in node.meta
     spec = node.meta["spec"]
@@ -327,6 +1004,30 @@ def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]:
     return get_node_spec_attr(node, "vk_memory_layout")
 
 
+def has_node_repr(node) -> bool:
+    if isinstance(node, (list, tuple)):
+        return all(has_node_spec_attr(n, "etvk_node_repr") for n in node)
+    else:
+        return has_node_spec_attr(node, "etvk_node_repr")
+
+
+def set_node_repr(node: torch.fx.Node, node_repr: Union[TensorRepr, TensorReprList]):
+    if isinstance(node_repr, TensorReprList):
+        # Convert to a regular list so that `set_node_spec_attr` can attach each entry
+        # to a separate TensorSpec
+        node_repr_list = [node_repr[i] for i in range(num_tensors_in_node(node))]
+        set_node_spec_attr(node, "etvk_node_repr", node_repr_list)
+    else:
+        set_node_spec_attr(node, "etvk_node_repr", node_repr)
+
+
+def get_node_repr(node) -> Union[TensorRepr, TensorReprList]:
+    if isinstance(node, (list, tuple)):
+        raise NotImplementedError("get_node_repr not implemented for list of nodes")
+    else:
+        return get_node_spec_attr(node, "etvk_node_repr", False)
+
+
 ##
 ## Misc
 ##
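
For reference, a minimal sketch of how the new representation utilities in this diff fit together. The chosen repsets are illustrative only, and the node-related calls are commented out because they assume a torch.fx.Node whose meta already carries a "spec" entry:

    from executorch.backends.vulkan import utils
    from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
        VkMemoryLayout,
        VkStorageType,
    )

    # Start from the representations an op implementation could accept
    # (width packed, as either buffer or texture).
    repset = utils.CONTIGUOUS_ANY

    # Narrow against another requirement (texture-only width packed).
    repset = repset.make_intersect(utils.WIDTH_PACKED_TEXTURE)

    # Pick a concrete representation; texture storage is preferred over buffer.
    chosen = repset.make_tensor_repr()
    assert chosen == utils.TensorRepr(
        VkStorageType.TEXTURE_3D, VkMemoryLayout.TENSOR_WIDTH_PACKED
    )

    # Attach the representation to a node's TensorSpec and read it back
    # (requires node.meta["spec"] to be populated):
    # utils.set_node_repr(node, chosen)
    # assert utils.get_node_repr(node) == chosen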