Skip to content

Commit 64abd74

Browse files
cccclaifacebook-github-bot
authored andcommitted
conv2d padding
Differential Revision: D84071939
1 parent 553be81 commit 64abd74

File tree

4 files changed

+98
-14
lines changed

4 files changed

+98
-14
lines changed

backends/qualcomm/builders/op_pad.py

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818

1919
@register_node_visitor
2020
class Pad(NodeVisitor):
21-
target = ["aten.constant_pad_nd.default"]
21+
target = [
22+
"aten.constant_pad_nd.default",
23+
"aten.pad.default", # handles reflect/replicate modes
24+
"aten.reflection_pad2d.default",
25+
"aten.replication_pad2d.default",
26+
]
2227

2328
def __init__(self, *args) -> None:
2429
super().__init__(*args)
@@ -28,6 +33,8 @@ def define_node(
2833
node: torch.fx.Node,
2934
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
3035
) -> PyQnnWrapper.PyQnnOpWrapper:
36+
37+
# ---- Input tensor ----
3138
input_node = self.get_node(node.args[0])
3239
input_tensor = self.get_tensor(input_node, node)
3340
pad_inp_tensor_wrapper = self.define_tensor(
@@ -39,6 +46,7 @@ def define_node(
3946
)
4047
pad_input_tensors = [pad_inp_tensor_wrapper]
4148

49+
# ---- Output tensor ----
4250
output_tensor = self.get_tensor(node, node)
4351
output_tensor_wrapper = self.define_tensor(
4452
node,
@@ -49,21 +57,43 @@ def define_node(
4957
)
5058
pad_output_tensors = [output_tensor_wrapper]
5159

60+
# ---- Pad amount handling ----
61+
pad_list = cast(List[int], node.args[1])
5262
pad_amount_shape = [input_tensor.dim(), 2]
53-
# pytorch padding start from the last index
54-
pad_amount = np.reshape(cast(List[int], node.args[1]), (-1, 2))[::-1].astype(
55-
np.uint32
56-
)
57-
# fulfill the pad amount for each idex of tensor
63+
64+
# PyTorch pad order: [last_dim, ..., first_dim]
65+
pad_amount = np.reshape(pad_list, (-1, 2))[::-1].astype(np.uint32)
66+
67+
# Expand to full rank if needed
5868
if zero_amounts := pad_amount_shape[0] - pad_amount.shape[0]:
5969
pad_amount = np.concatenate(
6070
(np.array([(0, 0)] * zero_amounts), pad_amount)
6171
).astype(np.uint32)
6272

73+
# Apply axis reordering if necessary
6374
if QCOM_AXIS_ORDER in node.meta:
6475
pad_amount = pad_amount[list(node.meta[QCOM_AXIS_ORDER])]
65-
pad_amount_val = node.args[2]
6676

77+
# ---- Determine mode ----
78+
if len(node.args) >= 3 and isinstance(node.args[2], str):
79+
mode = node.args[2]
80+
elif "reflection" in node.target:
81+
mode = "reflect"
82+
elif "replication" in node.target:
83+
mode = "replicate"
84+
else:
85+
mode = "constant"
86+
87+
scheme_map = {
88+
"constant": OpPad.Scheme.CONSTANT,
89+
"reflect": OpPad.Scheme.MIRROR_REFLECT,
90+
"replicate": OpPad.Scheme.EDGE,
91+
}
92+
93+
if mode not in scheme_map:
94+
raise ValueError(f"[QNN][Pad] Unsupported pad mode: {mode}")
95+
96+
# ---- Create QNN op ----
6797
pad_op = PyQnnWrapper.PyQnnOpWrapper(
6898
node.name,
6999
QNN_OP_PACKAGE_NAME_QTI_AISW,
@@ -72,19 +102,29 @@ def define_node(
72102
pad_op.AddInputTensors(pad_input_tensors)
73103
pad_op.AddOutputTensors(pad_output_tensors)
74104

75-
# For now, we only support constant (0) padding due to torch implementation
105+
# scheme param
76106
pad_op.AddScalarParam(
77107
OpPad.param_scheme,
78108
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
79-
{QCOM_DATA: np.uint32(OpPad.Scheme.CONSTANT)},
109+
{QCOM_DATA: np.uint32(scheme_map[mode])},
80110
)
81111

82-
pad_op.AddScalarParam(
83-
OpPad.param_pad_constant_value,
84-
QNN_TENSOR_TYPE_MAP[type(pad_amount_val)],
85-
{QCOM_DATA: pad_amount_val},
86-
)
112+
# pad_constant_value param (only for constant mode)
113+
if mode == "constant":
114+
# torch.constant_pad_nd takes optional pad value, default = 0.0
115+
pad_value = node.kwargs.get("value", None)
116+
if pad_value is None and len(node.args) > 2 and not isinstance(node.args[2], str):
117+
pad_value = node.args[2]
118+
if pad_value is None:
119+
pad_value = 0.0
120+
121+
pad_op.AddScalarParam(
122+
OpPad.param_pad_constant_value,
123+
QNN_TENSOR_TYPE_MAP[type(pad_value)],
124+
{QCOM_DATA: pad_value},
125+
)
87126

127+
# pad_amount tensor param
88128
pad_op.AddTensorParam(
89129
OpPad.param_pad_amount,
90130
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,

backends/qualcomm/partition/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
5252
torch.ops.aten.leaky_relu.default,
5353
torch.ops.aten.linear.default,
5454
torch.ops.aten.matmul.default,
55+
torch.ops.aten.pad.default,
5556
torch.ops.aten.pixel_shuffle.default,
5657
torch.ops.aten.pixel_unshuffle.default,
5758
torch.ops.aten.prelu.default,

backends/qualcomm/tests/models.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,35 @@ def forward(self, x):
490490
return x
491491

492492

493+
class Conv2d(torch.nn.Module):
494+
def __init__(
495+
self,
496+
in_channels=3,
497+
out_channels=6,
498+
kernel_size: Union[int, Tuple[int, int]] = 3,
499+
stride: Union[int, Tuple[int, int]] = 1,
500+
padding: Union[int, Tuple[int, int]] = 0,
501+
dilation: Union[int, Tuple[int, int]] = 1,
502+
groups=1,
503+
bias=True,
504+
padding_mode="zeros",
505+
):
506+
super().__init__()
507+
self.conv = torch.nn.Conv2d(
508+
in_channels=in_channels,
509+
out_channels=out_channels,
510+
kernel_size=kernel_size,
511+
stride=stride,
512+
padding=padding,
513+
dilation=dilation,
514+
groups=groups,
515+
bias=bias,
516+
padding_mode=padding_mode,
517+
)
518+
519+
def forward(self, x):
520+
return self.conv(x)
521+
493522
class Conv2dArgmin(torch.nn.Module):
494523
def __init__(self):
495524
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,13 @@ def test_qnn_backend_conv2d(self):
318318
with self.subTest(i=i):
319319
self.lower_module_and_test_output(module, sample_input)
320320

321+
322+
def test_qnn_backend_conv2d_mode(self):
323+
sample_input = (torch.randn(4, 3, 16, 16),)
324+
for mode in ["zeros", "reflect", "replicate", "circular"]:
325+
module = Conv2d(padding=1, padding_mode=mode) # noqa: F405
326+
self.lower_module_and_test_output(module, sample_input)
327+
321328
def test_qnn_backend_conv2d_channel_last(self):
322329
modules = [
323330
Conv2dSequential(channel_last=True), # noqa: F405
@@ -1996,6 +2003,13 @@ def test_qnn_backend_conv2d(self):
19962003
module = self.get_qdq_module(module, sample_input)
19972004
self.lower_module_and_test_output(module, sample_input)
19982005

2006+
def test_qnn_backend_conv2d_mode(self):
2007+
sample_input = (torch.randn(4, 3, 16, 16),)
2008+
for mode in ["zeros", "reflect", "replicate", "circular"]:
2009+
module = Conv2d(padding=1, padding_mode=mode) # noqa: F405
2010+
module = self.get_qdq_module(module, sample_input)
2011+
self.lower_module_and_test_output(module, sample_input)
2012+
19992013
def test_qnn_backend_conv2d_block(self):
20002014
o_ch, i_ch, kernel, padding = 32, 512, (1, 1), 0
20012015

0 commit comments

Comments
 (0)