Commit 2d70b7c

cccclai authored and facebook-github-bot committed
Fix circular/reflection/replication padding and conv2d with different padding (#14860)
Summary: This PR changes how the Pad op and Conv2d with a padding option are handled.

1. Pad op: QNN supports pad natively, so we no longer decompose it. The `circular` mode is not supported by QNN, so it must be converted regardless. `replicate` should ideally map to `OpPad.Scheme.EDGE`, but that has been producing wrong results, so for now the new pass converts both `circular` and `replicate` modes. We can follow up on why `replicate` fails with `OpPad.Scheme.EDGE` and then drop that condition.
2. Conv2d with a padding option: works out of the box once the Pad op is fixed.

Also adds unit tests for pad and for conv2d with the padding option, plus tests for the new pass.

Differential Revision: D84071939
1 parent bba9d26 commit 2d70b7c
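
As a sanity check on the rewrite described above, circular padding is exactly slice+concat in eager mode (illustrative snippet, not part of the commit):

# Hypothetical standalone check: the slice+cat construction used by the new
# pass matches F.pad(..., mode="circular") for a 2D (NCHW) pad.
import torch
import torch.nn.functional as F

x = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
l, r, t, b = 1, 2, 1, 1

# wrap horizontally first, then wrap the widened tensor vertically
x_w = torch.cat([x[..., -l:], x, x[..., :r]], dim=-1)
y = torch.cat([x_w[..., -t:, :], x_w, x_w[..., :b, :]], dim=-2)

assert torch.equal(y, F.pad(x, (l, r, t, b), mode="circular"))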

File tree

8 files changed (+554, -28 lines)

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
 from .annotate_unbind import AnnotateUnbind
 from .canonicalize_conv import CanonicalizeConv
 from .convert_bmm_to_matmul import ConvertBmmToMatmul
+from .convert_pad_to_slice_concat import ConvertPadToSliceConcat
 from .convert_linear_to_conv2d import ConvertLinearToConv2d
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
@@ -48,6 +49,7 @@

 __all__ = [
     AnnotateAdaptiveAvgPool1D,
+    ConvertPadToSliceConcat,
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
backends/qualcomm/_passes/convert_pad_to_slice_concat.py

Lines changed: 236 additions & 0 deletions
@@ -0,0 +1,236 @@
+import operator
+
+import torch
+from torch.fx import GraphModule
+
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class ConvertPadToSliceConcat(ExportPass):
+    """
+    Replace aten.pad(..., mode in {'circular', 'replicate'}) with slice+cat
+    (+expand for replicate). Supports 1D/2D (NCL / NCHW-like) and is
+    SymInt-safe for torch.export graphs.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    # ---------- small helpers ----------
+
+    def _copy_meta(self, src, dst, val_transform=None):
+        dst.meta = dict(getattr(src, "meta", {}))
+        if "val" in getattr(src, "meta", {}) and isinstance(src.meta["val"], torch.Tensor):
+            v = src.meta["val"]
+            if val_transform is not None:
+                try:
+                    v = val_transform(v)
+                except Exception:
+                    pass
+            dst.meta["val"] = v
+
+    def _set_scalar_meta(self, node, dtype=torch.int64):
+        node.meta = getattr(node, "meta", {})
+        node.meta["val"] = torch.tensor(0, dtype=dtype)
+
+    def _sym_size(self, graph, x, dim):
+        if hasattr(torch.ops.aten, "sym_size"):
+            n = graph.create_node("call_function", torch.ops.aten.sym_size.int, (x, dim))
+        else:
+            n = graph.create_node("call_function", torch.ops.aten.size.int, (x, dim))
+        self._set_scalar_meta(n)
+        return n
+
+    def _sym_sub(self, graph, a, b):
+        n = graph.create_node("call_function", operator.sub, (a, b))
+        self._set_scalar_meta(n)
+        return n
+
+    def _rank_from_meta(self, t):
+        r = None
+        if hasattr(t, "meta") and isinstance(t.meta.get("val", None), torch.Tensor):
+            r = t.meta["val"].dim()
+        return r
+
+    def _expand_along_dim(self, graph, t, dim, new_len, before):
+        """
+        Build aten.expand(t, new_sizes) where only 'dim' changes to new_len.
+        Works with SymInt sizes. new_len is a python int.
+        """
+        with graph.inserting_before(before):
+            rank = self._rank_from_meta(t)
+            if rank is None:
+                # Fallback when meta is missing: most models here are 4D
+                rank = 4
+            sizes = []
+            # convert negative dim to positive
+            pdim = dim % rank
+            for d in range(rank):
+                if d == pdim:
+                    sizes.append(int(new_len))
+                else:
+                    sizes.append(self._sym_size(graph, t, d))
+            n = graph.create_node("call_function", torch.ops.aten.expand.default, (t, sizes))
+
+            # meta: broadcast view to the new shape if we have it
+            def _vt(v):
+                shape = list(v.shape)
+                shape[pdim] = int(new_len)
+                return v.expand(shape)
+
+            self._copy_meta(t, n, _vt)
+            return n
+
+    # ---------- main entry ----------
+
+    def call(self, gm: GraphModule) -> PassResult:
+        g = gm.graph
+        modified = False
+
+        for node in list(g.nodes):
+            if node.op == "call_function" and node.target == torch.ops.aten.pad.default:
+                # args: (x, pad, mode, [value])
+                if len(node.args) < 3 or not isinstance(node.args[2], str):
+                    continue
+                mode = node.args[2]
+                if mode not in ("circular", "replicate"):
+                    continue
+
+                x = node.args[0]
+                pad = list(node.args[1])
+                ndim = len(pad) // 2  # 1D: (l, r); 2D: (l, r, t, b)
+
+                if mode == "circular":
+                    new_val = self._insert_circular(g, x, pad, ndim, before=node)
+                else:
+                    new_val = self._insert_replicate(g, x, pad, ndim, before=node)
+
+                self._copy_meta(node, new_val)
+                node.replace_all_uses_with(new_val)
+                g.erase_node(node)
+                modified = True
+
+        if modified:
+            g.lint()
+            gm.recompile()
+        return PassResult(gm, modified)
+
+    # ---------- rewrites ----------
+
+    def _insert_circular(self, graph, x, pad, ndim, before):
+        with graph.inserting_before(before):
+            if ndim == 1:
+                left, right = pad
+                w = self._sym_size(graph, x, -1)
+                start = self._sym_sub(graph, w, left)
+                left_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, start, w))
+                right_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, 0, right))
+                self._copy_meta(x, left_slice)
+                self._copy_meta(x, right_slice)
+                out = graph.create_node("call_function", torch.ops.aten.cat.default, ((left_slice, x, right_slice), -1))
+                self._copy_meta(x, out, lambda t: torch.cat([t[..., -left:], t, t[..., :right]], dim=-1))
+                return out
+
+            if ndim == 2:
+                l, r, t, b = pad
+                # horizontal wrap
+                W = self._sym_size(graph, x, -1)
+                start_w = self._sym_sub(graph, W, l)
+                left_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, start_w, W))
+                right_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, 0, r))
+                self._copy_meta(x, left_slice)
+                self._copy_meta(x, right_slice)
+                x_cat = graph.create_node("call_function", torch.ops.aten.cat.default, ((left_slice, x, right_slice), -1))
+                self._copy_meta(x, x_cat, lambda T: torch.cat([T[..., -l:], T, T[..., :r]], dim=-1))
+
+                # vertical wrap on the widened tensor
+                H = self._sym_size(graph, x_cat, -2)
+                start_h = self._sym_sub(graph, H, t)
+                top_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x_cat, -2, start_h, H))
+                bot_slice = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x_cat, -2, 0, b))
+                self._copy_meta(x_cat, top_slice)
+                self._copy_meta(x_cat, bot_slice)
+                y_cat = graph.create_node("call_function", torch.ops.aten.cat.default, ((top_slice, x_cat, bot_slice), -2))
+                self._copy_meta(x_cat, y_cat, lambda T: torch.cat([T[..., -t:, :], T, T[..., :b, :]], dim=-2))
+                return y_cat
+
+        raise NotImplementedError(f"circular pad only supports 1D/2D, got pad={pad}")
+
+    def _insert_replicate(self, graph, x, pad, ndim, before):
+        """
+        Replicate: extend borders with edge values.
+        Implemented via slice (1-wide edge) + expand + cat.
+        """
+        with graph.inserting_before(before):
+            if ndim == 1:
+                left, right = pad
+                parts = []
+                if left > 0:
+                    left_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, 0, 1))
+                    self._copy_meta(x, left_edge)
+                    left_pad = self._expand_along_dim(graph, left_edge, -1, left, before)
+                    parts.append(left_pad)
+                parts.append(x)
+                if right > 0:
+                    right_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, -1, None))
+                    self._copy_meta(x, right_edge)
+                    right_pad = self._expand_along_dim(graph, right_edge, -1, right, before)
+                    parts.append(right_pad)
+
+                out = parts[0] if len(parts) == 1 else graph.create_node("call_function", torch.ops.aten.cat.default, (tuple(parts), -1))
+
+                # meta: mirror the rewrite on the fake tensor
+                def _vt(t):
+                    L, R = left, right
+                    if L or R:
+                        lp = t[..., :1].expand(*t.shape[:-1], L) if L else t[..., :0]
+                        rp = t[..., -1:].expand(*t.shape[:-1], R) if R else t[..., :0]
+                        return torch.cat([lp, t, rp], dim=-1)
+                    return t
+
+                self._copy_meta(x, out, _vt)
+                return out
+
+            if ndim == 2:
+                l, r, t, b = pad
+                # horizontal replicate first
+                parts = []
+                if l > 0:
+                    left_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, 0, 1))
+                    self._copy_meta(x, left_edge)
+                    left_pad = self._expand_along_dim(graph, left_edge, -1, l, before)
+                    parts.append(left_pad)
+                parts.append(x)
+                if r > 0:
+                    right_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x, -1, -1, None))
+                    self._copy_meta(x, right_edge)
+                    right_pad = self._expand_along_dim(graph, right_edge, -1, r, before)
+                    parts.append(right_pad)
+
+                x_w = parts[0] if len(parts) == 1 else graph.create_node("call_function", torch.ops.aten.cat.default, (tuple(parts), -1))
+                self._copy_meta(x, x_w, lambda T: torch.cat([
+                    T[..., :1].expand(*T.shape[:-1], l) if l else T[..., :0],
+                    T,
+                    T[..., -1:].expand(*T.shape[:-1], r) if r else T[..., :0],
+                ], dim=-1) if (l or r) else T)
+
+                # then vertical replicate on the widened tensor
+                parts2 = []
+                if t > 0:
+                    top_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x_w, -2, 0, 1))
+                    self._copy_meta(x_w, top_edge)
+                    top_pad = self._expand_along_dim(graph, top_edge, -2, t, before)
+                    parts2.append(top_pad)
+                parts2.append(x_w)
+                if b > 0:
+                    bot_edge = graph.create_node("call_function", torch.ops.aten.slice.Tensor, (x_w, -2, -1, None))
+                    self._copy_meta(x_w, bot_edge)
+                    bot_pad = self._expand_along_dim(graph, bot_edge, -2, b, before)
+                    parts2.append(bot_pad)
+
+                out = parts2[0] if len(parts2) == 1 else graph.create_node("call_function", torch.ops.aten.cat.default, (tuple(parts2), -2))
+                self._copy_meta(x_w, out, lambda T: torch.cat([
+                    T[..., :1, :].expand(*T.shape[:-2], t, T.shape[-1]) if t else T[..., :0, :],
+                    T,
+                    T[..., -1:, :].expand(*T.shape[:-2], b, T.shape[-1]) if b else T[..., :0, :],
+                ], dim=-2) if (t or b) else T)
+                return out
+
+        raise NotImplementedError(f"replicate pad only supports 1D/2D, got pad={pad}")

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,7 @@
     AnnotateUnbind,
     CanonicalizeConv,
     ConvertBmmToMatmul,
+    ConvertPadToSliceConcat,
     ConvertLinearToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
@@ -210,12 +211,14 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
         self.add_pass(ReplaceInfValues())
         self.add_pass(LiftConstantScalarOperands())
+        self.add_pass(ConvertPadToSliceConcat())
         self.add_pass(InsertReshapeForReduceOps())
         return self._transform(graph_module)

     def transform_for_export_pipeline(
         self, exported_program: ExportedProgram, convert_linear_to_conv2d: bool = False
     ):
+        self.add_pass(ConvertPadToSliceConcat())
         self.add_pass(DecomposeBinaryAlpha())
         self.add_pass(DecomposeCDist())
         self.add_pass(DecomposeScaledDotProductAttention())
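
For context, a minimal sketch of how the registered pass could be exercised, assuming the export keeps aten.pad.default whole (PadModel and the direct pass invocation are illustrative, not from this commit):

# Hedged usage sketch: run ConvertPadToSliceConcat over an exported module
# and confirm circular pads are rewritten away.
import torch
from executorch.backends.qualcomm._passes import ConvertPadToSliceConcat

class PadModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.pad(x, (1, 1, 2, 2), mode="circular")

gm = torch.export.export(PadModel(), (torch.randn(1, 3, 8, 8),)).module()
result = ConvertPadToSliceConcat()(gm)
remaining = [
    n for n in result.graph_module.graph.nodes
    if n.op == "call_function" and n.target == torch.ops.aten.pad.default
]
assert not remaining  # circular pads now lowered to slice + cat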

backends/qualcomm/builders/op_pad.py

Lines changed: 55 additions & 25 deletions
@@ -18,7 +18,13 @@

 @register_node_visitor
 class Pad(NodeVisitor):
-    target = ["aten.constant_pad_nd.default"]
+    target = [
+        "aten.constant_pad_nd.default",
+        "aten.pad.default",
+        # Add tests before enabling these two:
+        # "aten.reflection_pad2d.default",
+        # "aten.replication_pad2d.default",
+    ]

     def __init__(self, *args) -> None:
         super().__init__(*args)
@@ -49,48 +55,72 @@ def define_node(
         )
         pad_output_tensors = [output_tensor_wrapper]

+        # ---- Pad amount ([rank, 2], uint32) ----
         pad_amount_shape = [input_tensor.dim(), 2]
-        # pytorch padding start from the last index
-        pad_amount = np.reshape(cast(List[int], node.args[1]), (-1, 2))[::-1].astype(
-            np.uint32
-        )
-        # fulfill the pad amount for each idex of tensor
-        if zero_amounts := pad_amount_shape[0] - pad_amount.shape[0]:
-            pad_amount = np.concatenate(
-                (np.array([(0, 0)] * zero_amounts), pad_amount)
-            ).astype(np.uint32)
+        # PyTorch pad order runs from the *last* dim: e.g. 2D = [L, R, T, B]
+        pad_amount = np.reshape(
+            np.array(cast(List[int], node.args[1]), dtype=np.int64), (-1, 2)
+        )[::-1]  # reverse the rows so they run first dim -> last dim
+
+        # expand to all ranks if needed
+        if pad_amount_shape[0] - pad_amount.shape[0] > 0:
+            zeros = np.zeros((pad_amount_shape[0] - pad_amount.shape[0], 2), dtype=np.int64)
+            pad_amount = np.concatenate((zeros, pad_amount), axis=0)

+        # remap rows if a backend axis order is provided (backend_pos -> pt_dim)
         if QCOM_AXIS_ORDER in node.meta:
-            pad_amount = pad_amount[list(node.meta[QCOM_AXIS_ORDER])]
-        pad_amount_val = node.args[2]
+            axis_order = list(node.meta[QCOM_AXIS_ORDER])  # e.g. (0, 2, 3, 1)
+            pad_amount = pad_amount[axis_order]
+
+        pad_amount = pad_amount.astype(np.uint32, copy=False)
+
+        # ---- Mode/scheme ----
+        if len(node.args) >= 3 and isinstance(node.args[2], str):
+            mode = node.args[2]
+        else:
+            # aten.constant_pad_nd carries no mode string; default to constant
+            mode = "constant"

+        scheme_map = {
+            "constant": OpPad.Scheme.CONSTANT,
+            "reflect": OpPad.Scheme.MIRROR_REFLECT,
+            # EDGE should match replicate, but currently yields wrong results
+            # (see summary); replicate is converted by ConvertPadToSliceConcat.
+            "replicate": OpPad.Scheme.EDGE,
+        }
+        scheme_u32 = np.uint32(scheme_map[mode])
+
+        # ---- Build op ----
         pad_op = PyQnnWrapper.PyQnnOpWrapper(
-            node.name,
-            QNN_OP_PACKAGE_NAME_QTI_AISW,
-            OpPad.op_name,
+            node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpPad.op_name
         )
         pad_op.AddInputTensors(pad_input_tensors)
         pad_op.AddOutputTensors(pad_output_tensors)

-        # For now, we only support constant (0) padding due to torch implementation
         pad_op.AddScalarParam(
             OpPad.param_scheme,
             PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
-            {QCOM_DATA: np.uint32(OpPad.Scheme.CONSTANT)},
+            {QCOM_DATA: scheme_u32},  # scheme (UINT32)
         )

-        pad_op.AddScalarParam(
-            OpPad.param_pad_constant_value,
-            QNN_TENSOR_TYPE_MAP[type(pad_amount_val)],
-            {QCOM_DATA: pad_amount_val},
-        )
+        # pad_constant_value applies to constant mode only
+        if mode == "constant":
+            pad_value = None
+            if len(node.args) > 2 and not isinstance(node.args[2], str):
+                pad_value = node.args[2]
+            if pad_value is None:
+                pad_value = 0.0
+            pad_op.AddScalarParam(
+                OpPad.param_pad_constant_value,
+                QNN_TENSOR_TYPE_MAP[type(pad_value)],
+                {QCOM_DATA: pad_value},
+            )

+        # pad_amount tensor param (UINT32, shape [rank, 2])
         pad_op.AddTensorParam(
             OpPad.param_pad_amount,
             PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
             len(pad_amount_shape),
             pad_amount_shape,
             pad_amount,
             True,
         )
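
To make the pad-amount reordering above concrete, here is the same transformation as a standalone numpy example (the axis order is a hypothetical NCHW-to-NHWC remap, not taken from this diff):

# Worked example of the reshape/reverse/remap logic in define_node.
import numpy as np

pt_pad = [1, 2, 3, 4]  # PyTorch order for 2D: [L, R, T, B] (last dim first)
pad = np.reshape(np.array(pt_pad, dtype=np.int64), (-1, 2))[::-1]
# -> [[3, 4], [1, 2]]: rows now run H then W

rank = 4  # e.g. an NCHW input
zeros = np.zeros((rank - pad.shape[0], 2), dtype=np.int64)
pad = np.concatenate((zeros, pad), axis=0)
# -> [[0, 0], [0, 0], [3, 4], [1, 2]]: one (before, after) row per dim

axis_order = (0, 2, 3, 1)  # hypothetical NCHW -> NHWC remap
pad = pad[list(axis_order)].astype(np.uint32)
# -> [[0, 0], [3, 4], [1, 2], [0, 0]]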

backends/qualcomm/partition/utils.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
         torch.ops.aten.leaky_relu.default,
         torch.ops.aten.linear.default,
         torch.ops.aten.matmul.default,
+        torch.ops.aten.pad.default,
         torch.ops.aten.pixel_shuffle.default,
         torch.ops.aten.pixel_unshuffle.default,
         torch.ops.aten.prelu.default,
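
Why this entry matters: keeping aten.pad.default out of decomposition lets the op_pad.py visitor above receive a single pad node instead of decomposed pieces. A hedged, standalone way to inspect which pad ops survive export (not part of this commit):

# Illustrative sketch: print the call_function targets an export produces
# for a replicate pad, to see whether aten.pad.default is kept whole.
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.pad(x, (1, 1), mode="replicate")

ep = torch.export.export(M(), (torch.randn(1, 3, 8),))
print([n.target for n in ep.graph.nodes if n.op == "call_function"])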
