diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index 5d0ac832237..a286bf8b1ae 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -23,6 +23,7 @@
 from .decompose_minmaxdim import DecomposeMinMaxDim
 from .decompose_roll import DecomposeRoll
 from .decompose_silu import DecomposeSilu
+from .decompose_threshold import DecomposeThreshold
 from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast
 from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
 from .fixed_linear_keep_dim import FixedLinearKeepDim
@@ -65,6 +66,7 @@
     DecomposeMinMaxDim,
     DecomposeRoll,
     DecomposeSilu,
+    DecomposeThreshold,
     DecomposeWrapWithAutocast,
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
diff --git a/backends/qualcomm/_passes/decompose_threshold.py b/backends/qualcomm/_passes/decompose_threshold.py
new file mode 100644
index 00000000000..0f0a1bc4ea8
--- /dev/null
+++ b/backends/qualcomm/_passes/decompose_threshold.py
@@ -0,0 +1,61 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from executorch.exir.pass_base import ExportPass, PassResult
+
+from .utils import merge_decomposed_graph
+
+
+class DecomposeModule(torch.nn.Module):
+    def __init__(self, threshold, value):
+        super().__init__()
+        self.threshold = threshold
+        self.value = value
+
+    def forward(self, x):
+        return torch.where(x <= self.threshold, self.value, x)
+
+
+class DecomposeThreshold(ExportPass):
+    """
+    Decompose threshold to less_equal and where.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        graph = graph_module.graph
+        for node in graph.nodes:
+            if node.target in {
+                torch.ops.aten.threshold_.default,
+                torch.ops.aten.threshold.default,
+            }:
+                input_node = node.args[0]
+                threshold = node.args[1]
+                value = node.args[2]
+
+                model = DecomposeModule(threshold, value)
+                decomposed_module = torch.export.export(
+                    model, (input_node.meta["val"],), strict=True
+                ).module()
+
+                with graph.inserting_before(node):
+                    # remap maps original node values to new node values, which
+                    # ensures that references to nodes are correctly updated in
+                    # the new graph
+                    remap = {"x": input_node}
+                    merge_decomposed_graph(
+                        remap=remap,
+                        target_node=node,
+                        target_graph=graph,
+                        decomposed_graph_module=decomposed_module,
+                    )
+                graph.erase_node(node)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
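The new pass relies on the identity behind `aten.threshold`: `threshold(x, t, v)` keeps `x` wherever `x > t` and substitutes `v` elsewhere, so `where(x <= t, v, x)` reproduces it exactly. Passing the scalar `v` straight to `torch.where` is also what produces the `aten.where.ScalarSelf` variant handled further down in this patch. A minimal standalone sketch of the equivalence, in plain PyTorch outside the pass infrastructure:

```python
import torch

# Sketch verifying the decomposition used by DecomposeThreshold:
# threshold(x, t, v) == where(x <= t, v, x) for every element.
x = torch.randn(4, 8)
t, v = 0.5, 3.0

reference = torch.nn.functional.threshold(x, t, v)
decomposed = torch.where(x <= t, v, x)
assert torch.equal(reference, decomposed)
```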
diff --git a/backends/qualcomm/_passes/lift_constant_scalar_operands.py b/backends/qualcomm/_passes/lift_constant_scalar_operands.py
index f5c5915cab2..52bdf7fa090 100644
--- a/backends/qualcomm/_passes/lift_constant_scalar_operands.py
+++ b/backends/qualcomm/_passes/lift_constant_scalar_operands.py
@@ -51,6 +51,7 @@ class TensorOpInfo:
     # The scalar number arg[1] is missing when using default. Result in a corner case to deal
     aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True, False),
     aten.leaky_relu_.default: TensorOpInfo(aten.prelu.default, True, False),
+    aten.where.ScalarSelf: TensorOpInfo(aten.where.self, False, True),
     aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
     aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
     aten.masked_fill.Scalar: TensorOpInfo(aten.masked_fill.Tensor, False, False),
diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index 6e1369326fa..a377f0f4eb4 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -28,6 +28,7 @@
     DecomposeMinMaxDim,
     DecomposeRoll,
     DecomposeSilu,
+    DecomposeThreshold,
     DecomposeWrapWithAutocast,
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
@@ -200,6 +201,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
         self.add_pass(DecomposeSilu())
+        self.add_pass(DecomposeThreshold())
         self.add_pass(DecomposeWrapWithAutocast())
         self.add_pass(DecomposeEinsum())
         self.add_pass(DecomposeExpM1())
@@ -216,6 +218,7 @@ def transform_for_export_pipeline(
         self.add_pass(DecomposeCDist())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
+        self.add_pass(DecomposeThreshold())
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
         self.add_pass(DecomposeExpM1())
         self.add_pass(DecomposeWrapWithAutocast())
diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py
index bc2b62c8c0b..8cbf3a50e22 100644
--- a/backends/qualcomm/builders/node_visitor.py
+++ b/backends/qualcomm/builders/node_visitor.py
@@ -176,7 +176,7 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
         user_0 = self.get_first_user(node)
         if "convolution" in user_0.target.__name__:
             # OIHW (pytorch) -> HWIO (QNN)
-            quant_config[QCOM_AXIS] = 3
+            quant_config[QCOM_AXIS] = node.meta["val"].dim() - 1
             quant_config[QCOM_AXIS_ORDER] = (2, 3, 1, 0)
         elif "linear" in user_0.target.__name__:
             # OI (pytorch) -> OI (QNN)
@@ -218,7 +218,7 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
         user_0 = self.get_first_user(node)
         # Memory layout of QNN conv weight always ends in Output. Like conv2d is HWIO
         if "convolution" in user_0.target.__name__:
-            quant_config[QCOM_AXIS] = 3
+            quant_config[QCOM_AXIS] = node.meta["val"].dim() - 1
         else:
             quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]
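Both quantization-config builders previously hardcoded axis 3, which only holds for conv2d weights: PyTorch stores them OIHW, QNN wants HWIO, and the output channel lands at index 3. With conv3d now delegated the weight is 5-D and the output channel ends up at index 4, so `node.meta["val"].dim() - 1` covers both ranks given that QNN always keeps the output channel last. A sketch of the invariant (the DHWIO order for conv3d is an assumption mirroring conv2d's HWIO):

```python
import torch

out_channels = 8

# conv2d: PyTorch OIHW -> QNN HWIO, output channel at axis 3 (= dim - 1).
w2d = torch.empty(out_channels, 4, 3, 3)
assert w2d.permute(2, 3, 1, 0).shape[w2d.dim() - 1] == out_channels

# conv3d: PyTorch OIDHW -> assumed DHWIO, output channel at axis 4 (= dim - 1).
w3d = torch.empty(out_channels, 4, 3, 3, 3)
assert w3d.permute(2, 3, 4, 1, 0).shape[w3d.dim() - 1] == out_channels
```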
diff --git a/backends/qualcomm/builders/op_transpose.py b/backends/qualcomm/builders/op_transpose.py
index dbed10ced46..e7fd84e8e79 100644
--- a/backends/qualcomm/builders/op_transpose.py
+++ b/backends/qualcomm/builders/op_transpose.py
@@ -42,6 +42,8 @@ def define_node(
         # permutation
         permute_order = cast(List[int], node.args[1])
+        # normalize the order to prevent negative axis values
+        permute_order = [x % len(permute_order) for x in permute_order]
         permute_order_shape = [len(permute_order)]
         output_tensor = input_tensor.permute(permute_order)
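Python's modulo maps a negative dimension index onto its non-negative counterpart, so the normalized order describes the same permutation; this is what lets the new `Permute([-1, -3, -2, -4])` test case later in this patch lower cleanly. A quick standalone check:

```python
import torch

permute_order = [-1, -3, -2, -4]
normalized = [x % len(permute_order) for x in permute_order]
assert normalized == [3, 1, 2, 0]

# Both orders describe the same permutation of a rank-4 tensor.
x = torch.randn(2, 3, 4, 5)
assert torch.equal(x.permute(permute_order), x.permute(normalized))
```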
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index 6f1ef47c2ee..cf403a1a76d 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -1358,7 +1358,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
     )


-@register_annotator([torch.ops.aten.where.self])
+@register_annotator([torch.ops.aten.where.self, torch.ops.aten.where.ScalarSelf])
 def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return
@@ -1368,7 +1368,6 @@ def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None:
         assert isinstance(input_node, Node)
         if _is_float_tensor(input_node):
             input_qspec_map[input_node] = quantization_config.input_activation
-
     node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
         output_qspec=(
diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py
index 5943b54d968..44d129d5544 100644
--- a/backends/qualcomm/quantizer/quantizer.py
+++ b/backends/qualcomm/quantizer/quantizer.py
@@ -161,6 +161,7 @@ def __post_init__(self):
             {
                 torch.ops.aten.conv1d.default,
                 torch.ops.aten.conv2d.default,
+                torch.ops.aten.conv3d.default,
                 torch.ops.aten.conv_transpose2d.input,
             }
         )
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index cf4b2f21aaa..7b1663d09f6 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -598,28 +598,6 @@ def forward(self, x):
         return self.second(self.first(x))


-class Conv3dSequential(torch.nn.Module):
-    def __init__(self, bias=True):
-        super().__init__()
-        self.first = torch.nn.Conv3d(
-            in_channels=1,
-            out_channels=3,
-            kernel_size=(3, 3, 3),
-            padding=1,
-            bias=bias,
-        )
-        self.second = torch.nn.Conv3d(
-            in_channels=3,
-            out_channels=2,
-            kernel_size=(3, 3, 3),
-            padding=1,
-            bias=bias,
-        )
-
-    def forward(self, x):
-        return self.second(self.first(x))
-
-
 class Conv2dSingle(torch.nn.Module):
     def __init__(
         self,
@@ -726,6 +704,28 @@ def forward(self, x):
         return topk_values


+class Conv3dSequential(torch.nn.Module):
+    def __init__(self, bias=True):
+        super().__init__()
+        self.first = torch.nn.Conv3d(
+            in_channels=1,
+            out_channels=3,
+            kernel_size=(3, 3, 3),
+            padding=1,
+            bias=bias,
+        )
+        self.second = torch.nn.Conv3d(
+            in_channels=3,
+            out_channels=2,
+            kernel_size=(3, 3, 3),
+            padding=1,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        return self.second(self.first(x))
+
+
 class ConvTranspose1dSingle(torch.nn.Module):
     def __init__(self, bias=True, dilation=1):
         super().__init__()
@@ -1507,6 +1507,15 @@ def forward(self, x):
         )


+class Permute(torch.nn.Module):
+    def __init__(self, dims: List[int]):
+        super().__init__()
+        self.dims = dims
+
+    def forward(self, x):
+        return x.permute(self.dims)
+
+
 class PixelShuffle(torch.nn.Module):
     def __init__(self, scale):
         super().__init__()
@@ -1540,11 +1549,12 @@ def forward(self, x):


 class PowTensorScalar(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, exponent=2):
         super().__init__()
+        self.exponent = exponent

     def forward(self, x):
-        return torch.pow(x, 2)
+        return torch.pow(x, self.exponent)


 class PReLUDefault(torch.nn.Module):
@@ -2001,6 +2011,19 @@ def forward(self, x):
         return torch.tanh(x)


+class Threshold(torch.nn.Module):
+    def __init__(self, threshold=0.0, value=0.0, inplace=False):
+        super().__init__()
+        self.threshold = threshold
+        self.value = value
+        self.inplace = inplace
+
+    def forward(self, x):
+        return torch.nn.functional.threshold(
+            x, threshold=self.threshold, value=self.value, inplace=self.inplace
+        )
+
+
 class TopKandIndex(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index e3cf52b9a6f..3d347137bf5 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1117,6 +1117,16 @@ def test_qnn_backend_pad(self):
         sample_input = (torch.randn([1, 8, 128]),)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_permute(self):
+        modules = [
+            Permute([0, 2, 3, 1]),  # noqa: F405
+            Permute([-1, -3, -2, -4]),  # noqa: F405
+        ]
+        sample_input = (torch.randn([2, 3, 4, 5]),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_pixel_shuffle(self):
         module = PixelShuffle(2)  # noqa: F405
         sample_input = (torch.ones([2, 4, 3, 3]),)
@@ -1128,9 +1138,28 @@ def test_qnn_backend_pixel_unshuffle(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_pow_tensor_scalar(self):
-        module = PowTensorScalar()  # noqa: F405
-        sample_input = (torch.rand([2, 4, 3, 3]),)
-        self.lower_module_and_test_output(module, sample_input)
+        test_comb = [
+            {
+                QCOM_MODULE: [
+                    PowTensorScalar(),  # noqa: F405
+                    PowTensorScalar(1),  # noqa: F405
+                    PowTensorScalar(-1),  # noqa: F405
+                    PowTensorScalar(0.5),  # noqa: F405
+                ],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.rand(10, 10) + 0.1,)],
+            },
+            {
+                QCOM_MODULE: [PowTensorScalar(10)],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.rand(10, 10) * 0.5 + 0.5,)],
+            },
+        ]
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_prelu(self):
         test_comb = [
@@ -1321,6 +1350,17 @@ def test_qnn_backend_tanh(self):
         sample_input = (torch.randn(2, 5, 1, 3),)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_threshold(self):
+        modules = [
+            Threshold(),  # noqa: F405
+            Threshold(threshold=0.5, value=3.0, inplace=True),  # noqa: F405
+            Threshold(threshold=0.5, value=3.0, inplace=False),  # noqa: F405
+        ]
+        sample_input = (torch.randn(2, 5, 1, 3),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_unflatten(self):
         module = Unflatten(dim=1, sizes=(2, 3, 4))  # noqa: F405
         sample_input = (torch.randn([1, 24]),)
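The input ranges in the parameterized pow cases look arbitrary but are chosen for numerical safety: `torch.rand(10, 10) + 0.1` keeps the base strictly positive so the `-1` and `0.5` exponents stay finite, while `torch.rand(10, 10) * 0.5 + 0.5` keeps a tenth power away from the vanishingly small values that quantization would crush. A standalone illustration of both constraints:

```python
import torch

base = torch.rand(10, 10) + 0.1          # strictly positive base
assert torch.isfinite(base.pow(-1)).all()
assert torch.isfinite(base.pow(0.5)).all()

narrow = torch.rand(10, 10) * 0.5 + 0.5  # base confined to [0.5, 1.0)
assert narrow.pow(10).min() >= 0.5**10   # outputs stay in a representable band
```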
@@ -2818,6 +2858,17 @@ def test_qnn_backend_pad(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_permute(self):
+        modules = [
+            Permute([0, 2, 3, 1]),  # noqa: F405
+            Permute([-1, -3, -2, -4]),  # noqa: F405
+        ]
+        sample_input = (torch.randn([2, 3, 4, 5]),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_pixel_shuffle(self):
         module = PixelShuffle(2)  # noqa: F405
         sample_input = (torch.ones([2, 4, 3, 3]),)
@@ -2831,10 +2882,29 @@ def test_qnn_backend_pixel_unshuffle(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_pow_tensor_scalar(self):
-        module = PowTensorScalar()  # noqa: F405
-        sample_input = (torch.rand([2, 4, 3, 3]),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        test_comb = [
+            {
+                QCOM_MODULE: [
+                    PowTensorScalar(),  # noqa: F405
+                    PowTensorScalar(1),  # noqa: F405
+                    PowTensorScalar(-1),  # noqa: F405
+                    PowTensorScalar(0.5),  # noqa: F405
+                ],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.rand(10, 10) + 0.1,)],
+            },
+            {
+                QCOM_MODULE: [PowTensorScalar(10)],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.rand(10, 10) * 0.5 + 0.5,)],
+            },
+        ]
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        qdq_module = self.get_qdq_module(module, sample_input)
+                        self.lower_module_and_test_output(qdq_module, sample_input)

     def test_qnn_backend_prelu(self):
         test_comb = [
@@ -2853,8 +2923,8 @@ def test_qnn_backend_prelu(self):
             for module in comb[QCOM_MODULE]:
                 for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                     with self.subTest(i=index):
-                        module = self.get_qdq_module(module, sample_input)
-                        self.lower_module_and_test_output(module, sample_input)
+                        qdq_module = self.get_qdq_module(module, sample_input)
+                        self.lower_module_and_test_output(qdq_module, sample_input)
                         index += 1

     def test_qnn_backend_relu(self):
@@ -3057,6 +3127,18 @@ def test_qnn_backend_tanh(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_threshold(self):
+        modules = [
+            Threshold(),  # noqa: F405
+            Threshold(threshold=0.5, value=3.0, inplace=True),  # noqa: F405
+            Threshold(threshold=0.5, value=3.0, inplace=False),  # noqa: F405
+        ]
+        sample_input = (torch.randn(2, 5, 1, 3),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                qdq_module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(qdq_module, sample_input)
+
     def test_qnn_backend_unflatten(self):
         module = Unflatten(dim=1, sizes=(2, 3, 4))  # noqa: F405
         sample_input = (torch.randn([1, 24]),)
diff --git a/backends/test/suite/flows/qualcomm.py b/backends/test/suite/flows/qualcomm.py
index 9998caa51b6..99deb3d4877 100644
--- a/backends/test/suite/flows/qualcomm.py
+++ b/backends/test/suite/flows/qualcomm.py
@@ -42,7 +42,7 @@ def create_quantize_stage() -> Quantize:

 QNN_TEST_FLOW = _create_qnn_flow("qnn")
 QNN_16A16W_TEST_FLOW = _create_qnn_flow(
-    "qnn_16a16w", quantize=True, quant_dtype=QuantDtype.use_8a8w, use_fp16=False
+    "qnn_16a16w", quantize=True, quant_dtype=QuantDtype.use_16a16w, use_fp16=False
 )
 QNN_16A8W_TEST_FLOW = _create_qnn_flow(
     "qnn_16a8w", quantize=True, quant_dtype=QuantDtype.use_16a8w, use_fp16=False
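A subtle fix rides along in the quantized test loops above: binding the result of `self.get_qdq_module(...)` back to the loop variable `module` means a later iteration over `QCOM_SAMPLE_INPUTS` would feed an already-quantized module back into `get_qdq_module`. Renaming the result to `qdq_module` keeps the original float module intact across iterations. A reduced illustration of the pitfall, with a hypothetical `quantize` standing in for `get_qdq_module`:

```python
# `quantize` is a hypothetical stand-in for get_qdq_module: it expects a float
# module and returns a new quantized copy.
def quantize(module: dict) -> dict:
    assert not module.get("quantized"), "re-quantizing an already quantized module"
    return {**module, "quantized": True}

module = {"op": "prelu"}
for sample_input in ("input_0", "input_1"):
    # Buggy pattern: `module = quantize(module)` raises on the second iteration.
    qdq_module = quantize(module)
    assert qdq_module["quantized"] and not module.get("quantized")
```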
diff --git a/backends/test/suite/operators/__init__.py b/backends/test/suite/operators/__init__.py
index fa5ec2566d4..7475af29e15 100644
--- a/backends/test/suite/operators/__init__.py
+++ b/backends/test/suite/operators/__init__.py
@@ -70,7 +70,9 @@ def __init__(self, test_runner):
         self._test_runner = test_runner

     def _test_op(self, model, args, flow, generate_random_test_inputs=True):
-        self._test_runner.lower_and_run_model(model, args)
+        self._test_runner.lower_and_run_model(
+            model, args, generate_random_test_inputs=generate_random_test_inputs
+        )


 def wrap_test(original_func, test_type):
diff --git a/backends/test/suite/operators/test_rsqrt.py b/backends/test/suite/operators/test_rsqrt.py
index 705833194fb..bb51b213dd4 100644
--- a/backends/test/suite/operators/test_rsqrt.py
+++ b/backends/test/suite/operators/test_rsqrt.py
@@ -37,15 +37,28 @@ def test_rsqrt_dtype(self, flow: TestFlow, dtype) -> None:

     def test_rsqrt_shapes(self, flow: TestFlow) -> None:
         # Test with different tensor shapes
-        # 1D tensor
-        self._test_op(RsqrtModel(), (torch.rand(20) + 0.01,), flow)
-
+        self._test_op(
+            RsqrtModel(),
+            (torch.rand(20) + 0.01,),
+            flow,
+            generate_random_test_inputs=False,
+        )
         # 2D tensor
-        self._test_op(RsqrtModel(), (torch.rand(5, 10) + 0.01,), flow)
+        self._test_op(
+            RsqrtModel(),
+            (torch.rand(5, 10) + 0.01,),
+            flow,
+            generate_random_test_inputs=False,
+        )

         # 3D tensor
-        self._test_op(RsqrtModel(), (torch.rand(3, 4, 5) + 0.01,), flow)
+        self._test_op(
+            RsqrtModel(),
+            (torch.rand(3, 4, 5) + 0.01,),
+            flow,
+            generate_random_test_inputs=False,
+        )

     @unittest.skip("NaN and Inf are not enforced for backends.")
     def test_rsqrt_edge_cases(self, flow: TestFlow) -> None:
diff --git a/backends/test/suite/operators/test_sqrt.py b/backends/test/suite/operators/test_sqrt.py
index 3d327ade6a5..92fbc64878e 100644
--- a/backends/test/suite/operators/test_sqrt.py
+++ b/backends/test/suite/operators/test_sqrt.py
@@ -39,13 +39,19 @@ def test_sqrt_shapes(self, flow: TestFlow) -> None:
         # Test with different tensor shapes

         # 1D tensor
-        self._test_op(SqrtModel(), (torch.rand(20),), flow)
+        self._test_op(
+            SqrtModel(), (torch.rand(20),), flow, generate_random_test_inputs=False
+        )

         # 2D tensor
-        self._test_op(SqrtModel(), (torch.rand(5, 10),), flow)
+        self._test_op(
+            SqrtModel(), (torch.rand(5, 10),), flow, generate_random_test_inputs=False
+        )

         # 3D tensor
-        self._test_op(SqrtModel(), (torch.rand(3, 4, 5),), flow)
+        self._test_op(
+            SqrtModel(), (torch.rand(3, 4, 5),), flow, generate_random_test_inputs=False
+        )

     @unittest.skip("NaN and Inf are not enforced for backends.")
     def test_sqrt_edge_cases(self, flow: TestFlow) -> None:
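Disabling `generate_random_test_inputs` matters for these two suites because sqrt and rsqrt are only defined for non-negative inputs: the hand-built `torch.rand(...)` tensors stay in range, but runner-generated random inputs could include negatives and turn the comparison into NaN-versus-NaN noise that says nothing about the backend. A quick demonstration of the hazard:

```python
import torch

# Signed inputs leave the domain of sqrt and produce NaN, so comparing backend
# output against eager output becomes meaningless.
signed = torch.tensor([-1.0, 0.25, 4.0])
print(signed.sqrt())                     # tensor([nan, 0.5000, 2.0000])

# The tests' inputs stay strictly positive, keeping rsqrt finite everywhere.
safe = torch.rand(20) + 0.01
assert torch.isfinite(safe.rsqrt()).all()
```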