diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py index 23481894f0d..c3c42ed483a 100644 --- a/backends/qualcomm/builders/op_index_put.py +++ b/backends/qualcomm/builders/op_index_put.py @@ -1,19 +1,14 @@ import warnings -from collections import OrderedDict from typing import Dict import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper import numpy as np import torch -from executorch.backends.qualcomm.utils.constants import ( - QCOM_DATA, - QCOM_DTYPE, - QCOM_QUANT_ATTRS, -) +from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS from executorch.exir.dialects._ops import ops as exir_ops -from .node_visitor import NodeVisitor, QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP +from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP from .node_visitor_manager import register_node_visitor from .qnn_constants import ( OpConcat, @@ -31,7 +26,7 @@ class IndexPutVisitor(NodeVisitor): def __init__(self, *args) -> None: super().__init__(*args) - def define_node( # noqa: C901 + def define_node( self, node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], @@ -42,7 +37,6 @@ def define_node( # noqa: C901 if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): quant_attrs = quant_attrs.copy() input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs - input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, @@ -52,110 +46,52 @@ def define_node( # noqa: C901 nodes_to_wrappers, ) - indices_nodes = ( - node.args[1] if isinstance(node.args[1], list) else [node.args[1]] - ) + indicies_node = node.args[1] + index_node_dim = None + index_nodes = [] + index_tensors = [] target_index = [] - all_range_index = OrderedDict() - index_dtype = [ - node.meta["val"].dtype for node in indices_nodes if node is not None - ][0] - - # preprocess: - # - broadcast dimension for multiple specified index - # - broadcast specified index if dimensions are not matched - max_indices_in_specified_index = 0 - for index, idx_node in enumerate(indices_nodes): - if isinstance(idx_node, torch.fx.Node): - last_specified_index_node = index - if max_indices_in_specified_index < idx_node.meta["val"].nelement(): - max_indices_in_specified_index = idx_node.meta["val"].nelement() # If there is None in a list, it means all range at that dimension - for index, idx_node in enumerate(indices_nodes): - # First, collect the index_node and index of None to construct the shape of index node - # E.g., shape of input: [1, 1024, 12, 64] - # For "None" axis (assume indices_node: [None, None, aten__to_copy_default_1]), - # target_index: [1, 1024, x], x is the shape of index_tensor, index_node_dim: 2 - if isinstance(idx_node, torch.fx.Node): - # e.g. 
for case [index_node_0, None, index_node_1], nodes will have the same number of indices - target_index.append( - self.get_tensor(idx_node, idx_node).nelement() - if last_specified_index_node == index - else 1 - ) - elif idx_node is None: - # E.g., indices_node: [None, None, aten__to_copy_default_1] - all_range_index[index] = torch.arange( - input_tensor.size(index), dtype=index_dtype - ) - target_index.append(input_tensor.size(index)) - else: - warnings.warn( - f"[QNN Delegate Op Builder]: Get the index {idx_node} that is neither a node nor None", - stacklevel=1, - ) - return - - # preprocess all range indices if any - if None in indices_nodes: - all_range_tensor = torch.cartesian_prod(*all_range_index.values()) - # repeat all_range_tensor interleavely for future concatenation - # e.g. input_node = [5, 4, 3, 2], indices = [index_0_node, None, index_2_node] - # index_0.shape == index_2.shape == 2 (will guarantee this condition) - # where user specified (3, 4) for index_0, (0, 1) for index_2 - # --- - # we should have all_range_tensor: [0, 1, 2, 3] - # repeat interleavely with 2 to match future tiled index_0_node & index_2_node - # we'll have 1(index_0 -> same as index_2)*4(index_1)*2(index_2) indices in total: - # | index_0_node | None | index_2_node | - # | 3 | 0 | 0 | - # | 4 | 0 | 1 | - # | 3 | 1 | 0 | - # | 4 | 1 | 1 | - # | 3 | 2 | 0 | - # | 4 | 2 | 1 | - # | 3 | 3 | 0 | - # | 4 | 3 | 1 | - all_range_tensor_aug = all_range_tensor.repeat_interleave( - max_indices_in_specified_index, dim=0 - ) - for index in all_range_index.keys(): - # Repeat index for "None" axis in indices_nodes - range_index_node = torch.fx.Node( - node.graph, - node.name + f"_all_range_index_{index}", - "call_function", - exir_ops.edge.aten.tensor.default, - (), # args - {}, # kwargs - ) - range_indices = ( - ( - all_range_tensor_aug[:, index] - if all_range_tensor_aug.dim() > 1 - else - # if there is only one None - all_range_tensor_aug + # E.g., indicies_node: [None, None, aten__to_copy_default_1] + if isinstance(indicies_node, list): + for index, idx_node in enumerate(indicies_node): + # First, collect the indice_node and index of None to construct the shape of index node + # E.g., shape of input: [1, 1024, 12, 64] + # For "None" axis (assume indicies_node: [None, None, aten__to_copy_default_1]), + # target_index: [1, 1024, x], x is the shape of index_tensor, index_node_dim: 2 + if isinstance(idx_node, torch.fx.Node): + index_nodes.append(idx_node) + index_tensors.append(self.get_tensor(idx_node, idx_node)) + target_index.extend(index_tensors[-1].size()) + index_node_dim = index + elif idx_node is None and index_node_dim is None: + # E.g., indicies_node: [None, aten__to_copy_default_1, None] + # Don't need to consider "None" after index_node. 
+ target_index.append(input_tensor.size(index)) + else: + warnings.warn( + f"[QNN Delegate Op Builder]: Get the index {idx_node} that is neither a node nor None", + stacklevel=1, ) - .reshape(-1, 1) - .contiguous() - ) - target_index_tensor_wrapper = self.define_tensor( - range_index_node, - node, - range_indices, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - ) - # store it for future concatenation - all_range_index[index] = (range_indices, target_index_tensor_wrapper) + return + # Assume that there is only one node in list + assert len(index_nodes) == 1, "Not support multiple indices tensor" + indice_node = index_nodes[0] + indice_tensor = index_tensors[0] + indices_tensor_wrapper = self.define_tensor( + indice_node, + node, + indice_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) # Need to reconstruct the index tensor. # E.g., based on ScatterND Op Def in QNN Docs. # Torch: # Given that # shape of input: [1, 12, 1024, 64] - # indices_node: [None, None, aten__to_copy_default_1] + # indicies_node: [None, None, aten__to_copy_default_1] # shape of aten__to_copy_default_1: [1] # QNN: # Index tensor: @@ -168,135 +104,113 @@ def define_node( # noqa: C901 # update_indices = indices.shape[:-1] # for idx in np.ndindex(update_indices): # output[indices[idx]] = updates[idx] - specified_index = OrderedDict() - for i, indices_node in enumerate(indices_nodes): - if indices_node is None: - continue - indices_tensor = self.get_tensor(indices_node, indices_node) - indices_tensor_wrapper = self.define_tensor( - indices_node, - node, - indices_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - nodes_to_wrappers, - ) - if indices_tensor.nelement() < max_indices_in_specified_index: - # broadcast the specified index - indices_tensor = indices_tensor.repeat(max_indices_in_specified_index) - indices_multiples = [max_indices_in_specified_index] - indices_multiples_shape = [len(indices_multiples)] - indices_tile_tensor_wrapper = self.define_custom_tensor_wrapper( - node_name=node.name + f"_indices_tile_{i}", - tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - dtype=QNN_TENSOR_TYPE_MAP[indices_tensor.dtype], - quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, - quant_configs={}, - dims=indices_tensor.size(), - tensor=indices_tensor, - is_fake_tensor=True, - nodes_to_wrappers=nodes_to_wrappers, - ) - tile_op = PyQnnWrapper.PyQnnOpWrapper( - node.name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpTile.op_name, - ) - tile_op.AddInputTensors([indices_tensor_wrapper]) - tile_op.AddOutputTensors([indices_tile_tensor_wrapper]) - tile_op.AddTensorParam( - OpTile.param_multiples, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - len(indices_multiples_shape), - indices_multiples_shape, - np.array(indices_multiples, dtype=np.uint32), - True, - ) - op_wrapper_list.append(tile_op) - indices_tensor_wrapper = indices_tile_tensor_wrapper + # Append one dimension to specify x-tuple + index_shape = target_index + [1] + # Reshape the index_node for tile op + reshape_shape = [ + shape if id == index_node_dim else 1 for id, shape in enumerate(index_shape) + ] + reshape_output_tensor = indice_tensor.reshape(reshape_shape) + reshape_output_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_reshape", + tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=QNN_TENSOR_TYPE_MAP[reshape_output_tensor.dtype], + 
quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, + quant_configs={}, + dims=reshape_output_tensor.size(), + tensor=reshape_output_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + reshape_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpReshape.op_name, + ) + reshape_op.AddInputTensors([indices_tensor_wrapper]) + reshape_op.AddOutputTensors([reshape_output_tensor_wrapper]) + op_wrapper_list.append(reshape_op) + index_put_index_input_tensor_wrapper = reshape_output_tensor_wrapper - # Append one dimension to specify x-tuple - # Reshape the index_node for tile op - reshape_shape = list(indices_tensor.shape) + [1] - reshape_output_tensor = indices_tensor.reshape(reshape_shape) - reshape_output_tensor_wrapper = self.define_custom_tensor_wrapper( - node_name=node.name + f"_reshape_{i}", + # Tile the index_node and concat the target index + if None in indicies_node: + tile_output_tensor = reshape_output_tensor.expand(index_shape) + # Tile the index_node to align with the shape of target_index + # Only need to tile the dim of None axis + # E.g., indicies_node: [None, None, aten__to_copy_default_1] + # Should tile the first two dimension. + multiples = [ + shape if id != index_node_dim else 1 + for id, shape in enumerate(index_shape) + ] + multiples_shape = [len(index_shape)] + tile_output_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_tile", tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - dtype=QNN_TENSOR_TYPE_MAP[reshape_output_tensor.dtype], + dtype=QNN_TENSOR_TYPE_MAP[tile_output_tensor.dtype], quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, quant_configs={}, - dims=reshape_output_tensor.size(), - tensor=reshape_output_tensor, + dims=tile_output_tensor.size(), + tensor=tile_output_tensor, is_fake_tensor=True, nodes_to_wrappers=nodes_to_wrappers, ) - reshape_op = PyQnnWrapper.PyQnnOpWrapper( + tile_op = PyQnnWrapper.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, - OpReshape.op_name, + OpTile.op_name, ) - reshape_op.AddInputTensors([indices_tensor_wrapper]) - reshape_op.AddOutputTensors([reshape_output_tensor_wrapper]) - op_wrapper_list.append(reshape_op) - index_tensor_wrapper = reshape_output_tensor_wrapper - index_tensor = reshape_output_tensor - - # Tile the index_node and concat the target index - if None in indices_nodes: - tile_output_tensor = reshape_output_tensor.repeat( - all_range_tensor.size(0), 1 - ) - # Tile the index_node to align with the shape of target_index - # Only need to tile the dim of None axis - # E.g., indices_node: [None, None, aten__to_copy_default_1] - # Should tile the number of indices combination of first two dimension - # times number of indices specified by aten__to_copy_default_1 - multiples = [all_range_tensor.size(0), 1] - multiples_shape = [len(multiples)] - tile_output_tensor_wrapper = self.define_custom_tensor_wrapper( - node_name=node.name + f"_tile_{i}", - tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - dtype=QNN_TENSOR_TYPE_MAP[tile_output_tensor.dtype], - quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, - quant_configs={}, - dims=tile_output_tensor.size(), - tensor=tile_output_tensor, - is_fake_tensor=True, - nodes_to_wrappers=nodes_to_wrappers, - ) - tile_op = PyQnnWrapper.PyQnnOpWrapper( - node.name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpTile.op_name, - ) - 
tile_op.AddInputTensors([reshape_output_tensor_wrapper]) - tile_op.AddOutputTensors([tile_output_tensor_wrapper]) - tile_op.AddTensorParam( - OpTile.param_multiples, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - len(multiples_shape), - multiples_shape, - np.array(multiples, dtype=np.uint32), - True, - ) - op_wrapper_list.append(tile_op) - index_tensor_wrapper = tile_output_tensor_wrapper - index_tensor = tile_output_tensor - - specified_index[i] = (index_tensor, index_tensor_wrapper) + tile_op.AddInputTensors([reshape_output_tensor_wrapper]) + tile_op.AddOutputTensors([tile_output_tensor_wrapper]) + tile_op.AddTensorParam( + OpTile.param_multiples, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(multiples_shape), + multiples_shape, + np.array(multiples, dtype=np.uint32), + True, + ) + op_wrapper_list.append(tile_op) - # Concat target_index and tile output to reconstruct index_node - # Cannot use QNN Pack (stack) since QNN Pack is not support int32 dtype - index_tensors, index_tensor_wrappers = [], [] - for i, arg in enumerate(indices_nodes): - tensor, tensor_wrapper = ( - all_range_index[i] if arg is None else specified_index[i] + # Repeat index for "None" axis in indicies_node + ranges = [ + torch.arange(dim, dtype=indice_tensor.dtype) + for dim in target_index[:-1] + ] + target_index_shape = target_index + [len(ranges)] + target_index_tensor = torch.cartesian_prod(*ranges) + reshape_target_index_shape = [ + shape if id != index_node_dim else 1 + for id, shape in enumerate(target_index_shape) + ] + target_index_tensor = target_index_tensor.reshape( + reshape_target_index_shape + ) + target_index_tensor = target_index_tensor.expand( + target_index_shape + ).contiguous() + target_index_node = torch.fx.Node( + node.graph, + node.name + "_target_index", + "call_function", + exir_ops.edge.aten.tensor.default, + (), # args + {}, # kwargs + ) + target_index_tensor_wrapper = self.define_tensor( + target_index_node, + node, + target_index_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, ) - index_tensors.append(tensor) - index_tensor_wrappers.append(tensor_wrapper) - if len(index_tensor_wrappers) > 1: - concat_output_tensor = torch.concat(index_tensors, dim=-1) + # Concat target_index and tile output to reconstruct index_node + # Cannot use QNN Pack (stack) since QNN Pack is not support int32 dtype + concat_output_tensor = torch.concat( + (target_index_tensor, tile_output_tensor), dim=-1 + ) concat_output_tensor_wrapper = self.define_custom_tensor_wrapper( node_name=node.name + "_concat", tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, @@ -313,7 +227,9 @@ def define_node( # noqa: C901 QNN_OP_PACKAGE_NAME_QTI_AISW, OpConcat.op_name, ) - concat_op.AddInputTensors(index_tensor_wrappers) + concat_op.AddInputTensors( + [target_index_tensor_wrapper, tile_output_tensor_wrapper] + ) concat_op.AddOutputTensors([concat_output_tensor_wrapper]) concat_op.AddScalarParam( OpConcat.param_axis, @@ -321,6 +237,7 @@ def define_node( # noqa: C901 {QCOM_DATA: np.uint32(concat_output_tensor.dim() - 1)}, ) op_wrapper_list.append(concat_op) + index_put_index_input_tensor_wrapper = concat_output_tensor_wrapper value_node = self.get_node(node.args[2]) value_tensor = self.get_tensor(value_node, node) @@ -331,94 +248,6 @@ def define_node( # noqa: C901 PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - # handle broadcast scenario - # e.g. 
input_tensor: (1, 12, 1024, 64), value_tensor: (1, 64) - # => value_reshape_tensor: (1, 1, 1, 64) - new_value_shape = ( - *([1] * (input_tensor.dim() - value_tensor.dim())), - *value_tensor.shape, - ) - # reshape the value_node for tile op - value_quant_encoding, value_quant_configs = self.get_quant_encoding_conf( - value_node, node - ) - value_dtype = ( - QNN_TENSOR_TYPE_MAP[value_tensor.dtype] - if value_quant_encoding - == PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED - else QNN_QUANT_TYPE_MAP[ - ( - torch.uint16 - if value_quant_configs[QCOM_DTYPE] == torch.int32 - else value_quant_configs[QCOM_DTYPE] - ) - ] - ) - value_reshape_tensor = value_tensor.reshape(new_value_shape) - value_reshape_tensor_wrapper = self.define_custom_tensor_wrapper( - node_name=node.name + "_value_reshape", - tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - dtype=value_dtype, - quant_encoding=value_quant_encoding, - quant_configs=value_quant_configs, - dims=value_reshape_tensor.size(), - tensor=value_reshape_tensor, - is_fake_tensor=True, - nodes_to_wrappers=nodes_to_wrappers, - ) - value_reshape_op = PyQnnWrapper.PyQnnOpWrapper( - node.name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpReshape.op_name, - ) - value_reshape_op.AddInputTensors([value_tensor_wrapper]) - value_reshape_op.AddOutputTensors([value_reshape_tensor_wrapper]) - op_wrapper_list.append(value_reshape_op) - - # e.g. input_tensor: (1, 12, 1024, 64), index_tensor: (None, None, 2), value_tensor: (1, 64) - # => multiples: [1, 12, 2, 1] - value_multiples = [] - for i in range(input_tensor.dim() - 1, -1, -1): - if i in specified_index: - # all user specified index node wil have the same dimension - multiplier = ( - indices_nodes[i].meta["val"].nelement() // new_value_shape[i] - if i == last_specified_index_node - else 1 - ) - else: - multiplier = input_tensor.shape[i] // new_value_shape[i] - value_multiples.insert(0, multiplier) - - value_tile_tensor = value_reshape_tensor.repeat(value_multiples) - value_multiples_shape = [len(value_multiples)] - value_tile_tensor_wrapper = self.define_custom_tensor_wrapper( - node_name=node.name + "_value_tile", - tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - dtype=value_dtype, - quant_encoding=value_quant_encoding, - quant_configs=value_quant_configs, - dims=value_tile_tensor.size(), - tensor=value_tile_tensor, - is_fake_tensor=True, - nodes_to_wrappers=nodes_to_wrappers, - ) - value_tile_op = PyQnnWrapper.PyQnnOpWrapper( - node.name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpTile.op_name, - ) - value_tile_op.AddInputTensors([value_reshape_tensor_wrapper]) - value_tile_op.AddOutputTensors([value_tile_tensor_wrapper]) - value_tile_op.AddTensorParam( - OpTile.param_multiples, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - len(value_multiples_shape), - value_multiples_shape, - np.array(value_multiples, dtype=np.uint32), - True, - ) - op_wrapper_list.append(value_tile_op) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( @@ -434,46 +263,11 @@ def define_node( # noqa: C901 QNN_OP_PACKAGE_NAME_QTI_AISW, OpScatterNd.op_name, ) - # accumulation - if len(node.args) > 3 and node.args[3]: - index_put_op.AddScalarParam( - OpScatterNd.param_reduction, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {QCOM_DATA: 1}, - ) - - # check final index_input tensor - index_input_tensor, index_input_tensor_wrapper = ( - (concat_output_tensor, concat_output_tensor_wrapper) - if len(index_tensor_wrappers) > 1 - else 
specified_index[last_specified_index_node] - ) - target_index_reshape_tensor = index_input_tensor.reshape((*target_index, -1)) - target_index_reshape_tensor_wrapper = self.define_custom_tensor_wrapper( - node_name=node.name + "_target_index_reshape", - tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - dtype=QNN_TENSOR_TYPE_MAP[target_index_reshape_tensor.dtype], - quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, - quant_configs={}, - dims=target_index_reshape_tensor.size(), - tensor=target_index_reshape_tensor, - is_fake_tensor=True, - nodes_to_wrappers=nodes_to_wrappers, - ) - target_index_reshape_op = PyQnnWrapper.PyQnnOpWrapper( - node.name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpReshape.op_name, - ) - target_index_reshape_op.AddInputTensors([index_input_tensor_wrapper]) - target_index_reshape_op.AddOutputTensors([target_index_reshape_tensor_wrapper]) - op_wrapper_list.append(target_index_reshape_op) - index_put_op.AddInputTensors( [ input_tensor_wrapper, - target_index_reshape_tensor_wrapper, - value_tile_tensor_wrapper, + index_put_index_input_tensor_wrapper, + value_tensor_wrapper, ] ) index_put_op.AddOutputTensors([output_tensor_wrapper]) diff --git a/backends/qualcomm/builders/op_mean_dim.py b/backends/qualcomm/builders/op_mean_dim.py index 10644e17c79..22cb47ee288 100644 --- a/backends/qualcomm/builders/op_mean_dim.py +++ b/backends/qualcomm/builders/op_mean_dim.py @@ -55,7 +55,7 @@ def define_node( mean_dims = [dim_arg] else: mean_dims = list(dim_arg) - + print("mean_dims: ", mean_dims, "rank: ", rank) mean_dims = [ mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims ] diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 5ea6caf54ad..3240ad7a018 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -1141,62 +1141,20 @@ def forward(self, input_pos, k_val): class IndexPut(torch.nn.Module): - def __init__(self, skip_mutable_buffer=False, mode=0): + def __init__(self, skip_mutable_buffer=False): super().__init__() self.skip_mutable_buffer = skip_mutable_buffer self.register_buffer( "k_cache", - torch.zeros((2, 1024, 12, 64), dtype=torch.float32), + torch.zeros((1, 1024, 12, 64), dtype=torch.float32), persistent=True, ) - self.mode = mode def forward(self, input_pos, k_val): - match self.mode: - case 0: - k_out = torch.ops.aten.index_put_(self.k_cache, [input_pos], k_val) - case 1: - k_out = torch.ops.aten.index_put_( - self.k_cache, [None, input_pos], k_val - ) - case 2: - k_out = torch.ops.aten.index_put_( - self.k_cache, [None, None, input_pos], k_val - ) - case 3: - k_out = torch.ops.aten.index_put_( - self.k_cache, [input_pos[0], input_pos[1]], k_val - ) - case 4: - k_out = torch.ops.aten.index_put_( - self.k_cache, [None, input_pos[0], input_pos[1]], k_val - ) - case 5: - k_out = torch.ops.aten.index_put_( - self.k_cache, [input_pos[0], None, input_pos[1]], k_val - ) - + k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val) return k_out + 0 -class IndexPutSuite(torch.nn.Module): - def __init__(self, accumulate=False, in_place=False): - super().__init__() - self.accumulate = accumulate - self.in_place = in_place - - def forward(self, x, indices, values): - if self.in_place: - # Clone the input to avoid modifying it in-place - result = x.clone() - # Apply index_put_ and return the modified tensor - result.index_put_(indices, values, self.accumulate) - return result - else: - # Use the non-in-place variant which 
returns a new tensor - return torch.index_put(x, indices, values, self.accumulate) - - class IndexSelect(torch.nn.Module): def __init__(self, dim): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 2641acc5a2d..56983561e5f 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import io -import itertools import json import subprocess import sys @@ -888,191 +887,28 @@ def test_qnn_backend_index_copy(self): ) def test_qnn_backend_index_put(self): - skip_mutable_buffer = [False, True] - total_test_combo = [] - # mode 0 - sample_inputs = [ - (torch.tensor([0], dtype=torch.int32), torch.randn([1, 1, 12, 64])), - (torch.tensor([0], dtype=torch.int32), torch.randn([1, 64])), - (torch.tensor([0, 1], dtype=torch.int32), torch.randn([2, 1, 12, 64])), - (torch.tensor([0, 1], dtype=torch.int32), torch.randn([1, 64])), - ] - total_test_combo.append( - list(itertools.product(skip_mutable_buffer, sample_inputs)) - ) - # mode 1 - sample_inputs = [ - (torch.tensor([2], dtype=torch.int32), torch.randn([1, 1, 12, 64])), - (torch.tensor([2], dtype=torch.int32), torch.randn([1, 64])), - (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 2, 12, 64])), - (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 64])), - ] - total_test_combo.append( - list(itertools.product(skip_mutable_buffer, sample_inputs)) - ) - # mode 2 - sample_inputs = [ - (torch.tensor([2], dtype=torch.int32), torch.randn([1, 1, 1, 64])), - (torch.tensor([2], dtype=torch.int32), torch.randn([1, 64])), - (torch.tensor([0, 1], dtype=torch.int32), torch.randn([1, 1, 2, 64])), - (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 64])), - ] - total_test_combo.append( - list(itertools.product(skip_mutable_buffer, sample_inputs)) - ) - # mode 3 - sample_inputs = [ - ( - ( - torch.tensor([0, 1], dtype=torch.int32), - torch.tensor([2, 3], dtype=torch.int32), - ), - torch.randn([2, 12, 64]), - ), - ( - ( - torch.tensor([0, 1], dtype=torch.int32), - torch.tensor([2, 3], dtype=torch.int32), - ), - torch.randn([1, 64]), - ), - ] - total_test_combo.append( - list(itertools.product(skip_mutable_buffer, sample_inputs)) - ) - # mode 4 - sample_inputs = [ - ( - ( - torch.tensor([0, 1], dtype=torch.int32), - torch.tensor([2, 3], dtype=torch.int32), - ), - torch.randn([2, 64]), - ), - ( - ( - torch.tensor([0, 1], dtype=torch.int32), - torch.tensor([2, 3], dtype=torch.int32), - ), - torch.randn([1, 64]), - ), - ] - total_test_combo.append( - list(itertools.product(skip_mutable_buffer, sample_inputs)) - ) - # mode 5 - sample_inputs = [ - ( - ( - torch.tensor([0, 1], dtype=torch.int32), - torch.tensor([2, 3], dtype=torch.int32), + test_comb = [ + { + QCOM_MODULE: IndexPut(skip_mutable_buffer=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), ), - torch.randn([64]), - ), - ( - ( - torch.tensor([0, 1], dtype=torch.int32), - torch.tensor([2, 3], dtype=torch.int32), + }, + { + QCOM_MODULE: IndexPut(skip_mutable_buffer=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), ), - torch.randn([1]), - ), - ] - total_test_combo.append( - list(itertools.product(skip_mutable_buffer, sample_inputs)) - ) - - for i, test_combo in enumerate(total_test_combo): - for 
j, combo in enumerate(test_combo): - with self.subTest(f"mode_{i}-{j}"): - self.lower_module_and_test_output( - IndexPut(skip_mutable_buffer=combo[0], mode=i), # noqa: F405 - combo[1], - skip_mutable_buffer=combo[0], - ) - - def test_qnn_backend_index_put_suite(self): - accumulate = [False, True] - in_place = [False, True] - sample_inputs = [ - # basic - ( - torch.rand(5, 2) * 100, - (torch.tensor([0, 2]),), - torch.tensor([10.0, 20.0]), - ), - (torch.rand(5, 2), (torch.tensor([0, 2]),), torch.tensor([10.0, 20.0])), - # shape - (torch.rand(5), (torch.tensor([0, 2]),), torch.tensor([10.0, 20.0])), - ( - torch.rand(5, 2), - (torch.tensor([0, 2]), torch.tensor([1, 1])), - torch.tensor([10.0, 20.0]), - ), - ( - torch.rand(5, 3, 2), - (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1])), - torch.tensor([10.0, 20.0]), - ), - # TODO: not supported by HTP - # ( - # torch.rand(5, 3, 2, 4), - # (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1]), torch.tensor([2, 3])), - # torch.tensor([10.0]), - # ), - # indices - (torch.rand(5, 2), (torch.tensor([2]),), torch.tensor([10.0])), - ( - torch.rand(5, 3), - (torch.tensor([0, 2, 4]),), - torch.tensor([10.0, 20.0, 30.0]), - ), - ( - torch.rand(5), - (torch.tensor([1, 1, 3, 3]),), - torch.tensor([10.0, 20.0, 30.0, 40.0]), - ), - # broadcasting - (torch.rand(5, 3), (torch.tensor([0, 2, 4]),), torch.tensor([42.0])), - ( - torch.rand(3, 4), - (torch.tensor([0, 1]), torch.tensor([1, 2])), - torch.tensor([10.0, 20.0]), - ), - (torch.rand(4, 2), (torch.tensor([0, 2]),), torch.tensor([5.0, 15.0])), - ( - torch.rand(3, 2, 2), - (torch.tensor([0, 1]),), - torch.tensor([[1.0, 2.0], [3.0, 4.0]]), - ), - (torch.rand(4, 2), (torch.tensor([1, 1, 1]),), torch.tensor([5.0])), - # two-index - ( - torch.rand(4, 3), - (torch.tensor([0, 1, 2]), torch.tensor([1, 0, 2])), - torch.tensor([10.0, 20.0, 30.0]), - ), - ( - torch.rand(3, 3), - (torch.tensor([0, 2]), torch.tensor([1, 1])), - torch.tensor([15.0, 25.0]), - ), - ( - torch.rand(3, 2), - (torch.tensor([1, 1, 2]), torch.tensor([0, 0, 1])), - torch.tensor([5.0, 10.0, 15.0]), - ), - ( - torch.rand(3, 2), - (torch.tensor([1]), torch.tensor([0, 0, 1])), - torch.tensor([5.0, 10.0, 15.0]), - ), + }, ] - test_combo = list(itertools.product(accumulate, in_place, sample_inputs)) - for i, combo in enumerate(test_combo): + for i, test in enumerate(test_comb): with self.subTest(i=i): self.lower_module_and_test_output( - IndexPutSuite(accumulate=combo[0], in_place=combo[1]), # noqa: F405 - combo[2], + test[QCOM_MODULE], + test[QCOM_SAMPLE_INPUTS], + skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, ) def test_qnn_backend_index_select(self): @@ -2806,197 +2642,32 @@ def test_qnn_backend_index_copy(self): ) def test_qnn_backend_index_put(self): - skip_mutable_buffer = [False, True] - total_test_combo = [] - # mode 0 - sample_inputs = [ - (torch.tensor([0], dtype=torch.int32), torch.randn([1, 1, 12, 64])), - (torch.tensor([0], dtype=torch.int32), torch.randn([1, 64])), - (torch.tensor([0, 1], dtype=torch.int32), torch.randn([2, 1, 12, 64])), - (torch.tensor([0, 1], dtype=torch.int32), torch.randn([1, 64])), - ] - total_test_combo.append( - list(itertools.product(skip_mutable_buffer, sample_inputs)) - ) - # mode 1 - sample_inputs = [ - (torch.tensor([2], dtype=torch.int32), torch.randn([1, 1, 12, 64])), - (torch.tensor([2], dtype=torch.int32), torch.randn([1, 64])), - (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 2, 12, 64])), - (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 64])), 
- ] - total_test_combo.append( - list(itertools.product(skip_mutable_buffer, sample_inputs)) - ) - # mode 2 - sample_inputs = [ - (torch.tensor([2], dtype=torch.int32), torch.randn([1, 1, 1, 64])), - (torch.tensor([2], dtype=torch.int32), torch.randn([1, 64])), - (torch.tensor([0, 1], dtype=torch.int32), torch.randn([1, 1, 2, 64])), - (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 64])), - ] - total_test_combo.append( - list(itertools.product(skip_mutable_buffer, sample_inputs)) - ) - # mode 3 - sample_inputs = [ - ( - ( - torch.tensor([0, 1], dtype=torch.int32), - torch.tensor([2, 3], dtype=torch.int32), - ), - torch.randn([2, 12, 64]), - ), - ( - ( - torch.tensor([0, 1], dtype=torch.int32), - torch.tensor([2, 3], dtype=torch.int32), - ), - torch.randn([1, 64]), - ), - ] - total_test_combo.append( - list(itertools.product(skip_mutable_buffer, sample_inputs)) - ) - # mode 4 - sample_inputs = [ - ( - ( - torch.tensor([0, 1], dtype=torch.int32), - torch.tensor([2, 3], dtype=torch.int32), - ), - torch.randn([2, 64]), - ), - ( - ( - torch.tensor([0, 1], dtype=torch.int32), - torch.tensor([2, 3], dtype=torch.int32), - ), - torch.randn([1, 64]), - ), - ] - total_test_combo.append( - list(itertools.product(skip_mutable_buffer, sample_inputs)) - ) - # mode 5 - sample_inputs = [ - ( - ( - torch.tensor([0, 1], dtype=torch.int32), - torch.tensor([2, 3], dtype=torch.int32), + test_comb = [ + { + QCOM_MODULE: IndexPut(skip_mutable_buffer=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), ), - torch.randn([64]), - ), - ( - ( - torch.tensor([0, 1], dtype=torch.int32), - torch.tensor([2, 3], dtype=torch.int32), + }, + { + QCOM_MODULE: IndexPut(skip_mutable_buffer=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), ), - torch.randn([1]), - ), - ] - total_test_combo.append( - list(itertools.product(skip_mutable_buffer, sample_inputs)) - ) - - for i, test_combo in enumerate(total_test_combo): - for j, combo in enumerate(test_combo): - with self.subTest(f"mode_{i}-{j}"): - module = self.get_qdq_module( - IndexPut(skip_mutable_buffer=combo[0], mode=i), # noqa: F405 - combo[1], - ) - self.lower_module_and_test_output( - module, - combo[1], - skip_mutable_buffer=combo[0], - ) - - def test_qnn_backend_index_put_suite(self): - accumulate = [False, True] - in_place = [False, True] - sample_inputs = [ - # basic - ( - torch.rand(5, 2) * 100, - (torch.tensor([0, 2]),), - torch.tensor([10.0, 20.0]), - ), - (torch.rand(5, 2), (torch.tensor([0, 2]),), torch.tensor([10.0, 20.0])), - # shape - (torch.rand(5), (torch.tensor([0, 2]),), torch.tensor([10.0, 20.0])), - ( - torch.rand(5, 2), - (torch.tensor([0, 2]), torch.tensor([1, 1])), - torch.tensor([10.0, 20.0]), - ), - ( - torch.rand(5, 3, 2), - (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1])), - torch.tensor([10.0, 20.0]), - ), - # TODO: not supported by HTP - # ( - # torch.rand(5, 3, 2, 4), - # (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1]), torch.tensor([2, 3])), - # torch.tensor([10.0]), - # ), - # indices - (torch.rand(5, 2), (torch.tensor([2]),), torch.tensor([10.0])), - ( - torch.rand(5, 3), - (torch.tensor([0, 2, 4]),), - torch.tensor([10.0, 20.0, 30.0]), - ), - ( - torch.rand(5), - (torch.tensor([1, 1, 3, 3]),), - torch.tensor([10.0, 20.0, 30.0, 40.0]), - ), - # broadcasting - (torch.rand(5, 3), (torch.tensor([0, 2, 4]),), torch.tensor([42.0])), - ( - torch.rand(3, 4), - (torch.tensor([0, 
1]), torch.tensor([1, 2])), - torch.tensor([10.0, 20.0]), - ), - (torch.rand(4, 2), (torch.tensor([0, 2]),), torch.tensor([5.0, 15.0])), - ( - torch.rand(3, 2, 2), - (torch.tensor([0, 1]),), - torch.tensor([[1.0, 2.0], [3.0, 4.0]]), - ), - (torch.rand(4, 2), (torch.tensor([1, 1, 1]),), torch.tensor([5.0])), - # two-index - ( - torch.rand(4, 3), - (torch.tensor([0, 1, 2]), torch.tensor([1, 0, 2])), - torch.tensor([10.0, 20.0, 30.0]), - ), - ( - torch.rand(3, 3), - (torch.tensor([0, 2]), torch.tensor([1, 1])), - torch.tensor([15.0, 25.0]), - ), - ( - torch.rand(3, 2), - (torch.tensor([1, 1, 2]), torch.tensor([0, 0, 1])), - torch.tensor([5.0, 10.0, 15.0]), - ), - ( - torch.rand(3, 2), - (torch.tensor([1]), torch.tensor([0, 0, 1])), - torch.tensor([5.0, 10.0, 15.0]), - ), + }, ] - test_combo = list(itertools.product(accumulate, in_place, sample_inputs)) - for i, combo in enumerate(test_combo): + for i, test in enumerate(test_comb): with self.subTest(i=i): module = self.get_qdq_module( - IndexPutSuite(accumulate=combo[0], in_place=combo[1]), # noqa: F405 - combo[2], + test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] + ) + self.lower_module_and_test_output( + module, + test[QCOM_SAMPLE_INPUTS], + skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, ) - self.lower_module_and_test_output(module, combo[2]) def test_qnn_backend_index_select(self): module = IndexSelect(dim=1) # noqa: F405 diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 036c5060b12..11b9ab88bfe 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -918,34 +918,24 @@ def generate_inputs(dest_path: str, file_name: str, inputs=None): input_list_file = None input_files = [] - def prepare_input_file(tensor, fd, index, sub_index): - # transform torch.Tensor to raw file - input_file_name = f"input_{index}_{sub_index}.raw" - input_file_path = f"{dest_path}/{input_file_name}" - if not isinstance(tensor, torch.Tensor): - tensor = torch.tensor(tensor) - tensor.detach().numpy().tofile(input_file_path) - input_files.append(input_file_path) - # prepare input_list - if sub_index > 0: - fd.write(" ") - fd.write(input_file_name) - # Prepare input data if inputs is not None: input_list_file = f"{dest_path}/{file_name}" with open(input_list_file, "w") as f: for idx, data in enumerate(inputs): - sub_index = 0 - for d in data: - if isinstance(d, (list, tuple)): - for sub_d in d: - prepare_input_file(sub_d, f, idx, sub_index) - sub_index += 1 - else: - prepare_input_file(d, f, idx, sub_index) - sub_index += 1 - + for i, d in enumerate(data): + # transform torch.Tensor to raw file + file_name = f"input_{idx}_{i}.raw" + file_path = f"{dest_path}/{file_name}" + if not isinstance(d, torch.Tensor): + d = torch.tensor(d) + d.detach().numpy().tofile(file_path) + input_files.append(file_path) + + # prepare input_list + if i > 0: + f.write(" ") + f.write(file_name) f.write("\n") return input_list_file, input_files
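
A note on the `.raw` files written by `generate_inputs` above: `tofile()` emits raw bytes in native byte order with no header, so whatever consumes the files must already know each tensor's dtype and shape. A minimal round-trip sketch, with an illustrative shape borrowed from the index_put tests rather than from any particular runner:

import numpy as np
import torch

k_val = torch.randn(1, 1, 12, 64)                    # illustrative sample input
k_val.detach().numpy().tofile("input_0_1.raw")       # same per-tensor call generate_inputs makes

# the consumer reconstructs the tensor from dtype + shape known out of band
restored = np.fromfile("input_0_1.raw", dtype=np.float32).reshape(1, 1, 12, 64)
assert np.allclose(restored, k_val.numpy())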
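
For reviewers unfamiliar with the ScatterND formulation referenced in the op_index_put.py comments (`output[indices[idx]] = updates[idx]` over `indices.shape[:-1]`), here is that pseudocode written out as a small NumPy reference. It is only a sketch of the op's semantics with made-up shapes, not QNN's implementation:

import numpy as np

def scatter_nd_reference(data, indices, updates):
    # the last dim of `indices` holds an index tuple; every leading position selects one update
    output = data.copy()
    update_indices = indices.shape[:-1]
    for idx in np.ndindex(update_indices):
        output[tuple(indices[idx])] = updates[idx]
    return output

data = np.zeros((4, 3), dtype=np.float32)
indices = np.array([[1], [3]])                        # index tuples of depth 1 -> pick rows 1 and 3
updates = np.array([[10.0, 20.0, 30.0],
                    [40.0, 50.0, 60.0]], dtype=np.float32)
out = scatter_nd_reference(data, indices, updates)    # rows 1 and 3 now hold the update rows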
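
To connect that back to the builder: for `index_put_(cache, [None, pos], val)`, the restored code roughly materializes the `None` axis as an explicit arange, tiles it against the user-specified index, and concatenates the pieces into the ScatterND index tensor. A small PyTorch sketch of that equivalence, using assumed toy shapes rather than the k_cache shapes from the tests (torch.stack here stands in for the Reshape-to-[..., 1] plus Concat the builder emits):

import torch

cache = torch.zeros(2, 4, 3)                          # stand-in for the mutable cache buffer
pos = torch.tensor([1, 3])                            # user-specified index for dim 1
val = torch.randn(2, 2, 3)                            # updates, one row per (dim0, pos) pair

# reference semantics, exactly as the IndexPut test module calls it
expected = torch.ops.aten.index_put_(cache.clone(), [None, pos], val)

# reconstruction: arange over the None axis, tile both pieces, combine on the last dim
dim0 = torch.arange(cache.size(0))
grid0 = dim0.view(-1, 1).expand(cache.size(0), pos.numel())   # "all range" index, tiled
grid1 = pos.view(1, -1).expand(cache.size(0), pos.numel())    # user index, tiled over dim 0
indices = torch.stack([grid0, grid1], dim=-1)                 # shape [2, 2, 2], last dim = index tuple

out = cache.clone()
out[indices[..., 0], indices[..., 1]] = val                   # ScatterND-style update
assert torch.equal(out, expected)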