diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index 154a360689e..56aa1063674 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -10,6 +10,7 @@
 from .annotate_unbind import AnnotateUnbind
 from .canonicalize_conv import CanonicalizeConv
 from .convert_bmm_to_matmul import ConvertBmmToMatmul
+from .convert_pad_to_slice_concat import ConvertPadToSliceConcat
 from .convert_linear_to_conv2d import ConvertLinearToConv2d
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
@@ -49,6 +50,7 @@ __all__ = [
     AnnotateAdaptiveAvgPool1D,
+    ConvertPadToSliceConcat,
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
diff --git a/backends/qualcomm/_passes/convert_pad_to_slice_concat.py b/backends/qualcomm/_passes/convert_pad_to_slice_concat.py
new file mode 100644
index 00000000000..5bb62ae8a4c
--- /dev/null
+++ b/backends/qualcomm/_passes/convert_pad_to_slice_concat.py
@@ -0,0 +1,236 @@
+import operator
+
+import torch
+from torch.fx import GraphModule
+
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class ConvertPadToSliceConcat(ExportPass):
+    """
+    Replace aten.pad(..., mode in {'circular', 'replicate'}) with slice + cat
+    (plus expand for replicate). Supports 1D/2D padding (NCL / NCHW-like inputs)
+    and is SymInt-safe for torch.export graphs.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    # ---------- small helpers ----------
+
+    def _copy_meta(self, src, dst, val_transform=None):
+        dst.meta = dict(getattr(src, "meta", {}))
+        if "val" in getattr(src, "meta", {}) and isinstance(src.meta["val"], torch.Tensor):
+            v = src.meta["val"]
+            if val_transform is not None:
+                try:
+                    v = val_transform(v)
+                except Exception:
+                    pass
+            dst.meta["val"] = v
+
+    def _set_scalar_meta(self, node, dtype=torch.int64):
+        node.meta = getattr(node, "meta", {})
+        node.meta["val"] = torch.tensor(0, dtype=dtype)
+
+    def _sym_size(self, graph, x, dim):
+        if hasattr(torch.ops.aten, "sym_size"):
+            n = graph.create_node("call_function", torch.ops.aten.sym_size.int, (x, dim))
+        else:
+            n = graph.create_node("call_function", torch.ops.aten.size.int, (x, dim))
+        self._set_scalar_meta(n)
+        return n
+
+    def _sym_sub(self, graph, a, b):
+        n = graph.create_node("call_function", operator.sub, (a, b))
+        self._set_scalar_meta(n)
+        return n
+
+    def _rank_from_meta(self, t):
+        r = None
+        if hasattr(t, "meta") and isinstance(t.meta.get("val", None), torch.Tensor):
+            r = t.meta["val"].dim()
+        return r
+
+    def _expand_along_dim(self, graph, t, dim, new_len, before):
+        """
+        Build aten.expand(t, new_sizes) where only 'dim' changes to new_len.
+        Works with SymInt sizes; new_len is a Python int.
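+        For example (illustrative sizes): expanding a [1, 3, 1] edge slice along
+        dim=-1 to new_len=4 emits aten.expand(t, [sym_size(t, 0), sym_size(t, 1), 4]),
+        keeping every other dimension at its (possibly symbolic) size.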
+ """ + with graph.inserting_before(before): + rank = self._rank_from_meta(t) + if rank is None: + # Fallback: grab sizes with sym_size one-by-one assuming up to 8 dims + # (most models are 4D here; if meta is missing, 4 is reasonable) + rank = 4 + sizes = [] + # convert negative dim to pos + pdim = dim % rank + for d in range(rank): + if d == pdim: + sizes.append(int(new_len)) + else: + sizes.append(self._sym_size(graph, t, d)) + n = graph.create_node("call_function", torch.ops.aten.expand.default, (t, sizes)) + # meta: broadcast view to the new shape if we have it + def _vt(v): + shape = list(v.shape) + shape[pdim] = int(new_len) + return v.expand(shape) + self._copy_meta(t, n, _vt) + return n + + # ---------- main entry ---------- + + def call(self, gm: GraphModule) -> PassResult: + g = gm.graph + modified = False + + for node in list(g.nodes): + if node.op == "call_function" and node.target == torch.ops.aten.pad.default: + # args: (x, pad, mode, [value]) + if len(node.args) < 3 or not isinstance(node.args[2], str): + continue + mode = node.args[2] + if mode not in ("circular", "replicate"): + continue + + x = node.args[0] + pad = list(node.args[1]) + ndim = len(pad) // 2 # 1D: (l,r) 2D: (l,r,t,b) + + if mode == "circular": + new_val = self._insert_circular(g, x, pad, ndim, before=node) + else: + new_val = self._insert_replicate(g, x, pad, ndim, before=node) + + self._copy_meta(node, new_val) + node.replace_all_uses_with(new_val) + g.erase_node(node) + modified = True + + if modified: + g.lint() + gm.recompile() + return PassResult(gm, modified) + + # ---------- rewrites ---------- + def _insert_circular(self, graph, x, pad, ndim, before): + with graph.inserting_before(before): + if ndim == 1: + left, right = pad + w = self._sym_size(graph, x, -1) + start = self._sym_sub(graph, w, left) + left_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, start, w)) + right_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, 0, right)) + self._copy_meta(x, left_slice) + self._copy_meta(x, right_slice) + out = graph.create_node("call_function", torch.ops.aten.cat.default, ((left_slice, x, right_slice), -1)) + self._copy_meta(x, out, lambda t: torch.cat([t[..., -left:], t, t[..., :right]], dim=-1)) + return out + + if ndim == 2: + l, r, t, b = pad + # horiz + W = self._sym_size(graph, x, -1) + start_w = self._sym_sub(graph, W, l) + left_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, start_w, W)) + right_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, 0, r)) + self._copy_meta(x, left_slice) + self._copy_meta(x, right_slice) + x_cat = graph.create_node("call_function", torch.ops.aten.cat.default, ((left_slice, x, right_slice), -1)) + self._copy_meta(x, x_cat, lambda T: torch.cat([T[..., -l:], T, T[..., :r]], dim=-1)) + + # vert + H = self._sym_size(graph, x_cat, -2) + start_h = self._sym_sub(graph, H, t) + top_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x_cat, -2, start_h, H)) + bot_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x_cat, -2, 0, b)) + self._copy_meta(x_cat, top_slice) + self._copy_meta(x_cat, bot_slice) + y_cat = graph.create_node("call_function", torch.ops.aten.cat.default, ((top_slice, x_cat, bot_slice), -2)) + self._copy_meta(x_cat, y_cat, lambda T: torch.cat([T[..., -t:, :], T, T[..., :b, :]], dim=-2)) + return y_cat + + raise NotImplementedError(f"circular pad only supports 1D/2D, got pad={pad}") + + def 
_insert_replicate(self, graph, x, pad, ndim, before): + """ + Replicate: extend borders with edge values. + Implemented via slice (edge 1-wide) + expand + cat. + """ + with graph.inserting_before(before): + if ndim == 1: + left, right = pad + parts = [] + if left > 0: + left_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, 0, 1)) + self._copy_meta(x, left_edge) + left_pad = self._expand_along_dim(graph, left_edge, -1, left, before) + parts.append(left_pad) + parts.append(x) + if right > 0: + right_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, -1, None)) + self._copy_meta(x, right_edge) + right_pad = self._expand_along_dim(graph, right_edge, -1, right, before) + parts.append(right_pad) + + out = parts[0] if len(parts) == 1 else graph.create_node("call_function", torch.ops.aten.cat.default, (tuple(parts), -1)) + # meta + def _vt(t): + L = left; R = right + if L or R: + lp = t[..., :1].expand(*t.shape[:-1], L) if L else t[..., :0] + rp = t[..., -1:].expand(*t.shape[:-1], R) if R else t[..., :0] + return torch.cat([lp, t, rp], dim=-1) + return t + self._copy_meta(x, out, _vt) + return out + + if ndim == 2: + l, r, t, b = pad + # horizontal replicate first + parts = [] + if l > 0: + left_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, 0, 1)) + self._copy_meta(x, left_edge) + left_pad = self._expand_along_dim(graph, left_edge, -1, l, before) + parts.append(left_pad) + parts.append(x) + if r > 0: + right_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, -1, None)) + self._copy_meta(x, right_edge) + right_pad = self._expand_along_dim(graph, right_edge, -1, r, before) + parts.append(right_pad) + + x_w = parts[0] if len(parts) == 1 else graph.create_node("call_function", torch.ops.aten.cat.default, (tuple(parts), -1)) + self._copy_meta(x, x_w, lambda T: torch.cat([ + T[..., :1].expand(*T.shape[:-1], l) if l else T[..., :0], + T, + T[..., -1:].expand(*T.shape[:-1], r) if r else T[..., :0] + ], dim=-1) if (l or r) else T) + + # then vertical replicate on the widened tensor + parts2 = [] + if t > 0: + top_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x_w, -2, 0, 1)) + self._copy_meta(x_w, top_edge) + top_pad = self._expand_along_dim(graph, top_edge, -2, t, before) + parts2.append(top_pad) + parts2.append(x_w) + if b > 0: + bot_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x_w, -2, -1, None)) + self._copy_meta(x_w, bot_edge) + bot_pad = self._expand_along_dim(graph, bot_edge, -2, b, before) + parts2.append(bot_pad) + + out = parts2[0] if len(parts2) == 1 else graph.create_node("call_function", torch.ops.aten.cat.default, (tuple(parts2), -2)) + self._copy_meta(x_w, out, lambda T: torch.cat([ + T[..., :1, :].expand(*T.shape[:-2], t, T.shape[-1]) if t else T[..., :0, :], + T, + T[..., -1:, :].expand(*T.shape[:-2], b, T.shape[-1]) if b else T[..., :0, :] + ], dim=-2) if (t or b) else T) + return out + + raise NotImplementedError(f"replicate pad only supports 1D/2D, got pad={pad}") diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 360581a2929..34caea1a27c 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -15,6 +15,7 @@ AnnotateUnbind, CanonicalizeConv, ConvertBmmToMatmul, + ConvertPadToSliceConcat, ConvertLinearToConv2d, ConvertSquareToPow, DecomposeAny, @@ -211,12 +212,14 @@ def 
transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) self.add_pass(ReplaceInfValues()) self.add_pass(LiftConstantScalarOperands()) + self.add_pass(ConvertPadToSliceConcat()) self.add_pass(InsertReshapeForReduceOps()) return self._transform(graph_module) def transform_for_export_pipeline( self, exported_program: ExportedProgram, convert_linear_to_conv2d: bool = False ): + self.add_pass(ConvertPadToSliceConcat()) self.add_pass(DecomposeBinaryAlpha()) self.add_pass(DecomposeCDist()) self.add_pass(DecomposeScaledDotProductAttention()) diff --git a/backends/qualcomm/builders/op_pad.py b/backends/qualcomm/builders/op_pad.py index 7832e180ebb..71c1b309af1 100644 --- a/backends/qualcomm/builders/op_pad.py +++ b/backends/qualcomm/builders/op_pad.py @@ -18,7 +18,13 @@ @register_node_visitor class Pad(NodeVisitor): - target = ["aten.constant_pad_nd.default"] + target = [ + "aten.constant_pad_nd.default", + "aten.pad.default", + # Add tests before adding these two to the list + # "aten.reflection_pad2d.default", + # "aten.replication_pad2d.default", + ] def __init__(self, *args) -> None: super().__init__(*args) @@ -49,48 +55,72 @@ def define_node( ) pad_output_tensors = [output_tensor_wrapper] + # ---- Pad amount ([rank, 2], uint32) ---- pad_amount_shape = [input_tensor.dim(), 2] - # pytorch padding start from the last index - pad_amount = np.reshape(cast(List[int], node.args[1]), (-1, 2))[::-1].astype( - np.uint32 - ) - # fulfill the pad amount for each idex of tensor - if zero_amounts := pad_amount_shape[0] - pad_amount.shape[0]: - pad_amount = np.concatenate( - (np.array([(0, 0)] * zero_amounts), pad_amount) - ).astype(np.uint32) + # PyTorch pad order is from the *last* dim: e.g. 2D = [L, R, T, B] + pad_amount = np.reshape( + np.array(cast(List[int], node.args[1]), dtype=np.int64), (-1, 2) + )[:: -1] # reverse to go from last->first to first->last + + # expand to all ranks if needed + if pad_amount_shape[0] - pad_amount.shape[0] > 0: + zeros = np.zeros((pad_amount_shape[0] - pad_amount.shape[0], 2), dtype=np.int64) + pad_amount = np.concatenate((zeros, pad_amount), axis=0) + # remap rows if backend axis order is provided (backend_pos -> pt_dim) if QCOM_AXIS_ORDER in node.meta: - pad_amount = pad_amount[list(node.meta[QCOM_AXIS_ORDER])] - pad_amount_val = node.args[2] + axis_order = list(node.meta[QCOM_AXIS_ORDER]) # e.g. 
(0,2,3,1)
+            pad_amount = pad_amount[axis_order]
+
+        pad_amount = pad_amount.astype(np.uint32, copy=False)
+
+        # ---- Mode/scheme ----
+        if len(node.args) >= 3 and isinstance(node.args[2], str):
+            mode = node.args[2]
+        else:
+            # default to constant
+            mode = "constant"
+        scheme_map = {
+            "constant": OpPad.Scheme.CONSTANT,
+            "reflect": OpPad.Scheme.MIRROR_REFLECT,
+            # EDGE is the expected mapping for replicate, but QNN currently returns
+            # incorrect results with it; replicate pads are expected to be rewritten
+            # by ConvertPadToSliceConcat before they reach this builder.
+            "replicate": OpPad.Scheme.EDGE,
+        }
+        scheme_u32 = np.uint32(scheme_map[mode])
+
+        # ---- Build op ----
         pad_op = PyQnnWrapper.PyQnnOpWrapper(
-            node.name,
-            QNN_OP_PACKAGE_NAME_QTI_AISW,
-            OpPad.op_name,
+            node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpPad.op_name
         )
         pad_op.AddInputTensors(pad_input_tensors)
         pad_op.AddOutputTensors(pad_output_tensors)
 
-        # For now, we only support constant (0) padding due to torch implementation
         pad_op.AddScalarParam(
             OpPad.param_scheme,
             PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
-            {QCOM_DATA: np.uint32(OpPad.Scheme.CONSTANT)},
+            {QCOM_DATA: scheme_u32},  # scheme (UINT32)
         )
 
-        pad_op.AddScalarParam(
-            OpPad.param_pad_constant_value,
-            QNN_TENSOR_TYPE_MAP[type(pad_amount_val)],
-            {QCOM_DATA: pad_amount_val},
-        )
+        # pad_constant_value only for constant mode
+        if mode == "constant":
+            # constant_pad_nd carries the fill value in args[2];
+            # aten.pad.default carries it in args[3], after the mode string
+            pad_value = None
+            if len(node.args) > 2 and not isinstance(node.args[2], str):
+                pad_value = node.args[2]
+            elif len(node.args) > 3 and node.args[3] is not None:
+                pad_value = node.args[3]
+            if pad_value is None:
+                pad_value = 0.0
+            pad_op.AddScalarParam(
+                OpPad.param_pad_constant_value,
+                QNN_TENSOR_TYPE_MAP[type(pad_value)],
+                {QCOM_DATA: pad_value},
+            )
 
+        # pad_amount tensor param (UINT32, shape [rank, 2])
         pad_op.AddTensorParam(
             OpPad.param_pad_amount,
             PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
-            len(pad_amount_shape),
-            pad_amount_shape,
-            pad_amount,
+            len(pad_amount_shape),
+            pad_amount_shape,
+            pad_amount,
             True,
         )
diff --git a/backends/qualcomm/builders/op_slice_copy.py b/backends/qualcomm/builders/op_slice_copy.py
index 5923d438252..b5425ce40be 100644
--- a/backends/qualcomm/builders/op_slice_copy.py
+++ b/backends/qualcomm/builders/op_slice_copy.py
@@ -48,32 +55,61 @@ def define_node(
             nodes_to_wrappers,
         )
-        dim = cast(int, node.args[1])
-        if QCOM_AXIS_ORDER in node.meta:
-            dim = node.meta[QCOM_AXIS_ORDER].index(dim)
-        if dim < 0:
-            dim = dim % len(input_tensor.shape)
+        # --- parse & normalize pytorch dim ---
+        pytorch_dim = int(node.args[1])
+        rank = len(input_tensor.shape)
+        if pytorch_dim < 0:
+            pytorch_dim = pytorch_dim % rank
+
+        # --- map pytorch dim -> QNN dim ---
+        qnn_dim = pytorch_dim
+        if QCOM_AXIS_ORDER in node.meta:
+            axis_order = node.meta[QCOM_AXIS_ORDER]
+            qnn_dim = axis_order.index(pytorch_dim)
 
-        start = 0 if node.args[2] is None else cast(int, node.args[2])
-        if start < 0:
-            start = start % input_tensor.shape[dim]
+        # --- size on the QNN axis ---
+        size = int(input_tensor.shape[qnn_dim])
 
+        # --- get start/end/step ---
+        start = 0 if len(node.args) <= 2 or node.args[2] is None else int(node.args[2])
+        end = size
         if len(node.args) > 3 and node.args[3] is not None:
-            end = min(cast(int, node.args[3]), input_tensor.shape[dim])
-            if end < 0:
-                end = end % input_tensor.shape[dim]
-        else:
-            end = input_tensor.shape[dim]
-        input_tensor_rank = len(input_tensor.shape)
+            end = int(node.args[3])
+        step = 1 if len(node.args) <= 4 or node.args[4] is None else int(node.args[4])
+
+        # --- normalize negatives ---
+        if start < 0:
+            start = start % size
+        if end < 0:
+            end = end % size
+
+        # --- clamp into valid range ---
+        start = max(0, min(start, 
size)) + end = max(0, min(end, size)) + + # --- canonicalize for positive step --- + if step == 0: + step = 1 + if step > 0 and start > end: + # empty slice (like Python []): make it start=end + start = end + elif step < 0: + raise NotImplementedError("Negative step not supported in QNN StridedSlice") + + # --- build ranges in QNN axes --- ranges = [] - for i in range(input_tensor_rank): - if i == dim: - # find step - step = node.args[4] if len(node.args) > 4 else 1 + for q in range(rank): + if q == qnn_dim: ranges.extend([start, end, step]) else: - ranges.extend([0, input_tensor.shape[i], 1]) + ranges.extend([0, int(input_tensor.shape[q]), 1]) - range_shape = [input_tensor_rank, 3] + range_shape = [rank, 3] stride_slice_op = PyQnnWrapper.PyQnnOpWrapper( node.name, diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index 97e0b4bd109..bd1ca33bd13 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -52,6 +52,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]: torch.ops.aten.leaky_relu.default, torch.ops.aten.linear.default, torch.ops.aten.matmul.default, + torch.ops.aten.pad.default, torch.ops.aten.pixel_shuffle.default, torch.ops.aten.pixel_unshuffle.default, torch.ops.aten.prelu.default, diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 3240ad7a018..bfa41dcec0b 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -450,7 +450,34 @@ def example_inputs(self): "x": torch.randn((1, 3, 3, 3)), "y": torch.randn((2, 1, 5, 5)), } +class Conv1d(torch.nn.Module): + def __init__( + self, + in_channels=3, + out_channels=6, + kernel_size=3, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + ): + super().__init__() + self.conv = torch.nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + ) + def forward(self, x): + return self.conv(x) class Conv1dSequential(torch.nn.Module): def __init__(self, bias=True): @@ -474,7 +501,6 @@ def __init__(self, bias=True): def forward(self, x): return self.second(self.first(x)) - # small models class Conv1dReluLogSoftmax(torch.nn.Module): def __init__(self, dim): @@ -490,6 +516,35 @@ def forward(self, x): return x +class Conv2d(torch.nn.Module): + def __init__( + self, + in_channels=3, + out_channels=6, + kernel_size: Union[int, Tuple[int, int]] = 3, + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups=1, + bias=True, + padding_mode="zeros", + ): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + ) + + def forward(self, x): + return self.conv(x) + class Conv2dArgmin(torch.nn.Module): def __init__(self): super().__init__() @@ -1508,6 +1563,14 @@ def forward(self, x): x[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0, mode="constant" ) +class PadGeneric(torch.nn.Module): + def __init__(self, mode, pad = (1, 1, 1, 1)): + super().__init__() + self.mode = mode + self.pad = pad + + def forward(self, x): + return torch.ops.aten.pad.default(x, self.pad, mode=self.mode) class Permute(torch.nn.Module): def __init__(self, dims: List[int]): @@ 
-1879,6 +1942,15 @@ def forward(self, x, y):
         )
 
 
+class SliceConv2d(torch.nn.Module):
+    def __init__(self, in_ch=1, out_ch=2, k=1):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(in_ch, out_ch, kernel_size=k, bias=True)
+
+    def forward(self, x):
+        y = torch.ops.aten.slice.Tensor(x, -2, 0, 1)
+        return torch.ops.aten.conv2d.default(y, self.conv.weight, self.conv.bias)
+
+
 class Softmax(torch.nn.Module):
     def __init__(self, dim):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_passes.py b/backends/qualcomm/tests/test_passes.py
index 94a5d08acc1..ea1cfe95c0d 100644
--- a/backends/qualcomm/tests/test_passes.py
+++ b/backends/qualcomm/tests/test_passes.py
@@ -1,11 +1,11 @@
 import unittest
 
 import torch
-from executorch.backends.qualcomm._passes import InsertReshapeForReduceOps
+from executorch.backends.qualcomm._passes import ConvertPadToSliceConcat, InsertReshapeForReduceOps
 
 
 class TestPasses(unittest.TestCase):
-    def test_insert_reshape_for_argmax(self):
+    def test_insert_reshape_for_reduce_ops(self):
         class ArgmaxModule(torch.nn.Module):
             def forward(self, x):
                 return torch.argmax(x, dim=None)
@@ -49,6 +49,49 @@ def forward(self, x):
             torch.equal(*out, ref), f"Output mismatch: got {out}, expected {ref}"
         )
 
+    def test_convert_pad_to_slice_concat(self):
+        # For circular and replicate modes, the pass should remove the pad node
+        # and insert slice and cat nodes instead.
+        class Pad(torch.nn.Module):
+            def __init__(self, mode):
+                super().__init__()
+                self.mode = mode
+
+            def forward(self, x):
+                # pad order = [left, right, top, bottom]
+                return torch.ops.aten.pad.default(x, (1, 1, 1, 1), mode=self.mode)
+
+        modes = ["circular", "replicate"]
+        for mode in modes:
+            mod = Pad(mode)
+            x = torch.arange(1.0, 17.0).reshape(1, 1, 4, 4)
+            ep = torch.export.export(mod, (x,))
+            # Run original module for reference
+            ref = mod(x)
+
+            pad_nodes = [
+                n for n in ep.graph.nodes if n.target == torch.ops.aten.pad.default and mode in n.args
+            ]
+            self.assertTrue(len(pad_nodes) == 1, "Pad node missing")
+
+            ConvertPadToSliceConcat()(ep.graph_module)
+
+            # Check graph structure: the pad node should be replaced by slice/cat nodes
+            slice_nodes = [
+                n for n in ep.graph.nodes if n.target == torch.ops.aten.slice.Tensor
+            ]
+            pad_nodes = [
+                n for n in ep.graph.nodes if n.target == torch.ops.aten.pad.default and mode in n.args
+            ]
+            self.assertTrue(len(slice_nodes) >= 1, "Slice node should be inserted")
+            self.assertTrue(len(pad_nodes) == 0, "Pad node should be removed")
+
+            # Execute new graph and compare with reference
+            out = ep.graph_module(x)
+            self.assertTrue(
+                torch.equal(*out, ref), f"Output mismatch: got {out}, expected {ref}"
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 56983561e5f..908cb7fc85e 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -315,6 +322,14 @@ def test_qnn_backend_conv1d(self):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv1d_mode(self):
+        # N=4 batches, C=3 channels, length L=32 (L >= 2 so "reflect" is valid)
+        L = 32
+        sample_input = (torch.arange(4 * 3 * L, dtype=torch.float32).reshape(4, 3, L) / 1000.0,)
+        for mode in ["zeros", "reflect", "replicate", "circular"]:
+            module = Conv1d(padding=1, padding_mode=mode)  # noqa: F405
+            self.lower_module_and_test_output(module, sample_input)
+
     def 
test_qnn_backend_conv2d(self): modules = [Conv2dSequential(), Conv2dSequential(bias=False)] # noqa: F405 sample_input = (torch.randn([1, 1, 3, 3]),) @@ -322,6 +330,13 @@ def test_qnn_backend_conv2d(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_conv2d_mode(self): + sample_input = (torch.randn(4, 3, 16, 16),) + for mode in ["zeros", "reflect", "replicate", "circular"]: + module = Conv2d(padding=1, padding_mode=mode) # noqa: F405 + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv2d_channel_last(self): modules = [ Conv2dSequential(channel_last=True), # noqa: F405 @@ -1182,6 +1197,61 @@ def test_qnn_backend_pad(self): sample_input = (torch.randn([1, 8, 128]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_pad_generic(self): + test_comb = [ + # --- replicate --- + { + QCOM_MODULE: [PadGeneric("replicate", (1, 1, 1, 1))], + QCOM_SAMPLE_INPUTS: [(torch.arange(1.0, 17.0).reshape(1, 1, 4, 4),)], + }, + { + # Small 2x3 input with large padding (edge replication stress) + QCOM_MODULE: [PadGeneric("replicate", (2, 2, 2, 2))], # (L,R,T,B) + QCOM_SAMPLE_INPUTS: [(torch.tensor([[[[1., 2., 3.], + [4., 5., 6.]]]]),)], + }, + { + # Batch>1, Channels>1, asymmetric pads + QCOM_MODULE: [PadGeneric("replicate", (1, 0, 0, 2))], + QCOM_SAMPLE_INPUTS: [(torch.arange(2*3*4*5, dtype=torch.float32) + .reshape(2, 3, 4, 5),)], + }, + + # --- circular --- + { + QCOM_MODULE: [PadGeneric("circular", (1, 1, 1, 1))], + QCOM_SAMPLE_INPUTS: [(torch.arange(1.0, 17.0).reshape(1, 1, 4, 4),)], + }, + { + # Asymmetric circular pad + QCOM_MODULE: [PadGeneric("circular", (2, 0, 1, 0))], + QCOM_SAMPLE_INPUTS: [(torch.arange(1.0, 1.0 + 1*1*5*4) + .reshape(1, 1, 5, 4),)], + }, + + # --- reflect --- + # For reflect, each pad must be <= size-1 along that dim. 
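+            # e.g. for the (1, 1, 3, 4) input below, H=3 allows top/bottom pads
+            # up to 2 and W=4 allows left/right pads up to 3.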
+            {
+                QCOM_MODULE: [PadGeneric("reflect", (1, 1, 1, 0))],
+                QCOM_SAMPLE_INPUTS: [(torch.arange(1.0, 1.0 + 1*1*3*4)
+                                      .reshape(1, 1, 3, 4),)],
+            },
+
+            # --- constant (value defaults to 0.0 in aten.pad.default) ---
+            {
+                QCOM_MODULE: [PadGeneric("constant", (2, 0, 0, 1))],
+                QCOM_SAMPLE_INPUTS: [(torch.arange(1.0, 10.0).reshape(1, 1, 3, 3),)],
+            },
+        ]
+
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index, mode=module.mode, pad=module.pad):
+                        self.lower_module_and_test_output(module, sample_input)
+                        index += 1
+
     def test_qnn_backend_permute(self):
         modules = [
             Permute([0, 2, 3, 1]),  # noqa: F405
@@ -1331,11 +1401,13 @@ def test_qnn_backend_slice_copy(self):
             SliceCopyDefaultParameter(),  # noqa: F405
             SliceCopy(),  # noqa: F405
             SliceCopyWithStep(),  # noqa: F405
+            SliceConv2d()
         ]
         sample_inputs = [
             (torch.randn([2, 1, 320, 512]),),
             (torch.randn([1, 512]), torch.randn([1, 8])),
             (torch.randn([1, 512]), torch.randn([1, 8])),
+            (torch.randn(1, 1, 2, 6), )
         ]
         for module, sample_input in zip(modules, sample_inputs):
             self.lower_module_and_test_output(module, sample_input)
@@ -2011,6 +2083,22 @@ def test_qnn_backend_conv1d(self):
             module = self.get_qdq_module(module, sample_input)
             self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv1d_mode(self):
+        # N=4 batches, C=3 channels, length L=32 (L >= 2 so "reflect" is valid)
+        L = 32
+        sample_input = (torch.arange(4 * 3 * L, dtype=torch.float32).reshape(4, 3, L) / 1000.0,)
+        for mode in ["zeros", "reflect", "replicate", "circular"]:
+            module = Conv1d(padding=1, padding_mode=mode)  # noqa: F405
+            module = self.get_qdq_module(module, sample_input)
+            self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_conv2d(self):
         modules = [Conv2dSequential(), Conv2dSequential(bias=False)]  # noqa: F405
         sample_input = (torch.randn([1, 1, 3, 3]),)
@@ -2019,6 +2107,13 @@
             module = self.get_qdq_module(module, sample_input)
             self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv2d_mode(self):
+        sample_input = (torch.randn(4, 3, 16, 16),)
+        for mode in ["zeros", "reflect", "replicate", "circular"]:
+            module = Conv2d(padding=1, padding_mode=mode)  # noqa: F405
+            module = self.get_qdq_module(module, sample_input)
+            self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_conv2d_block(self):
         o_ch, i_ch, kernel, padding = 32, 512, (1, 1), 0
 
@@ -2991,6 +3086,61 @@ def test_qnn_backend_pad(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_pad_generic(self):
+        test_comb = [
+            # --- replicate ---
+            {
+                QCOM_MODULE: [PadGeneric("replicate", (1, 1, 1, 1))],
+                QCOM_SAMPLE_INPUTS: [(torch.arange(1.0, 17.0).reshape(1, 1, 4, 4),)],
+            },
+            {
+                # Small 2x3 input with large padding (edge replication stress)
+                QCOM_MODULE: [PadGeneric("replicate", (2, 2, 2, 2))],  # (L,R,T,B)
+                QCOM_SAMPLE_INPUTS: [(torch.tensor([[[[1., 2., 3.],
+                                                      [4., 5., 6.]]]]),)],
+            },
+            {
+                # Batch>1, Channels>1, asymmetric pads
+                QCOM_MODULE: 
[PadGeneric("replicate", (1, 0, 0, 2))], + QCOM_SAMPLE_INPUTS: [(torch.arange(2*3*4*5, dtype=torch.float32) + .reshape(2, 3, 4, 5),)], + }, + + # --- circular --- + { + QCOM_MODULE: [PadGeneric("circular", (1, 1, 1, 1))], + QCOM_SAMPLE_INPUTS: [(torch.arange(1.0, 17.0).reshape(1, 1, 4, 4),)], + }, + { + # Asymmetric circular pad + QCOM_MODULE: [PadGeneric("circular", (2, 0, 1, 0))], + QCOM_SAMPLE_INPUTS: [(torch.arange(1.0, 1.0 + 1*1*5*4) + .reshape(1, 1, 5, 4),)], + }, + + # --- reflect --- + # For reflect, each pad must be <= size-1 along that dim. + { + QCOM_MODULE: [PadGeneric("reflect", (1, 1, 1, 0))], + QCOM_SAMPLE_INPUTS: [(torch.arange(1.0, 1.0 + 1*1*3*4) + .reshape(1, 1, 3, 4),)], + }, + + # --- constant (value defaults to 0.0 in aten.pad.default) --- + { + QCOM_MODULE: [PadGeneric("constant", (2, 0, 0, 1))], + QCOM_SAMPLE_INPUTS: [(torch.arange(1.0, 10.0).reshape(1, 1, 3, 3),)], + }, + ] + + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index, mode=module.mode, pad=module.pad): + self.lower_module_and_test_output(module, sample_input) + index += 1 + def test_qnn_backend_permute(self): modules = [ Permute([0, 2, 3, 1]), # noqa: F405 @@ -3167,11 +3317,13 @@ def test_qnn_backend_slice_copy(self): SliceCopyDefaultParameter(), # noqa: F405 SliceCopy(), # noqa: F405 SliceCopyWithStep(), # noqa: F405 + SliceConv2d() ] sample_inputs = [ (torch.randn([2, 1, 320, 512]),), (torch.randn([1, 512]), torch.randn([1, 8])), (torch.randn([1, 512]), torch.randn([1, 8])), + (torch.randn(1, 1, 2, 6), ) ] for module, sample_input in zip(modules, sample_inputs): module = self.get_qdq_module(module, sample_input)