Skip to content

Commit aa92ce6

Browse files
committed
[XNNPACK Quantizer] Select between TConvs and Convs
Allow selecting between transposed convs and regular convs. Previously, we grouped all conv targets together (transposed and regular convs), but now we enable better per-operator selection. Differential Revision: [D76641838](https://our.internmc.facebook.com/intern/diff/D76641838/) [ghstack-poisoned]
1 parent 36d5429 commit aa92ce6

File tree

2 files changed

+123
-4
lines changed

2 files changed

+123
-4
lines changed

backends/xnnpack/quantizer/xnnpack_quantizer.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,15 @@ class QuantPattern:
251251
torch.ops.aten.convolution.default,
252252
}
253253

254+
CONV_TRANSPOSE_TARGETS = {
255+
torch.ops.aten.conv_transpose1d,
256+
torch.ops.aten.conv_transpose1d.default,
257+
torch.ops.aten.conv_transpose2d,
258+
torch.ops.aten.conv_transpose2d.input,
259+
torch.ops.aten.conv_transpose3d,
260+
torch.ops.aten.conv_transpose3d.input,
261+
}
262+
254263
LINEAR_TARGETS = {
255264
torch.ops.aten.linear.default,
256265
}
@@ -269,14 +278,14 @@ class XNNPACKQuantizer(Quantizer):
269278
SUPPORTED_PATTERNS = [
270279
QuantPattern("conv_bn_relu", False, True, CONV_TARGETS),
271280
QuantPattern("conv_bn", False, True, CONV_TARGETS),
272-
QuantPattern("conv_transpose_bn_relu", False, True, CONV_TARGETS),
273-
QuantPattern("conv_transpose_bn", False, True, CONV_TARGETS),
281+
QuantPattern("conv_transpose_bn_relu", False, True, CONV_TRANSPOSE_TARGETS),
282+
QuantPattern("conv_transpose_bn", False, True, CONV_TRANSPOSE_TARGETS),
274283
QuantPattern("linear_relu", False, False, LINEAR_TARGETS),
275284
QuantPattern("linear", True, False, LINEAR_TARGETS),
276285
QuantPattern("conv", True, False, CONV_TARGETS),
277-
QuantPattern("conv_transpose", True, False, CONV_TARGETS),
286+
QuantPattern("conv_transpose", True, False, CONV_TRANSPOSE_TARGETS),
278287
QuantPattern("conv_relu", False, False, CONV_TARGETS),
279-
QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS),
288+
QuantPattern("conv_transpose_relu", False, False, CONV_TRANSPOSE_TARGETS),
280289
QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS),
281290
QuantPattern("add_relu", False, False, ADD_TARGETS),
282291
QuantPattern("add", False, False, ADD_TARGETS),

backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,116 @@ def test_conv1d_with_conv2d(self):
120120
node_list,
121121
)
122122

123+
def test_q_tconv_and_conv2d(self):
124+
class TConv2dConv2d(torch.nn.Module):
125+
def __init__(self):
126+
super().__init__()
127+
self.first = torch.nn.ConvTranspose2d(
128+
in_channels=1,
129+
out_channels=3,
130+
kernel_size=(3, 3),
131+
padding=1,
132+
bias=False,
133+
)
134+
self.second = torch.nn.Conv2d(
135+
in_channels=3,
136+
out_channels=2,
137+
kernel_size=(3, 3),
138+
padding=1,
139+
bias=False,
140+
)
141+
142+
def forward(self, x):
143+
y = self.first(x)
144+
return self.second(y)
145+
146+
def example_inputs(self):
147+
return (torch.randn(1, 1, 3, 3),)
148+
149+
quantizer = XNNPACKQuantizer()
150+
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
151+
quantizer.set_operator_type(
152+
torch.ops.aten.conv_transpose2d.input, quantization_config
153+
)
154+
node_occurrence = {
155+
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
156+
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
157+
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
158+
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
159+
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
160+
}
161+
node_list = [
162+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
163+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
164+
torch.ops.aten.conv_transpose2d.input,
165+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
166+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
167+
torch.ops.aten.conv2d.default,
168+
]
169+
m = TConv2dConv2d()
170+
self._test_quantizer(
171+
m,
172+
m.example_inputs(),
173+
quantizer,
174+
node_occurrence,
175+
node_list,
176+
is_debug_mode=True,
177+
)
178+
179+
def test_q_conv2_and_tconv2d(self):
180+
class TConv2dConv2d(torch.nn.Module):
181+
def __init__(self):
182+
super().__init__()
183+
self.first = torch.nn.ConvTranspose2d(
184+
in_channels=1,
185+
out_channels=3,
186+
kernel_size=(3, 3),
187+
padding=1,
188+
bias=False,
189+
)
190+
self.second = torch.nn.Conv2d(
191+
in_channels=3,
192+
out_channels=2,
193+
kernel_size=(3, 3),
194+
padding=1,
195+
bias=False,
196+
)
197+
198+
def forward(self, x):
199+
y = self.first(x)
200+
return self.second(y)
201+
202+
def example_inputs(self):
203+
return (torch.randn(1, 1, 3, 3),)
204+
205+
quantizer = XNNPACKQuantizer()
206+
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
207+
quantizer.set_operator_type(torch.ops.aten.conv2d.default, quantization_config)
208+
node_occurrence = {
209+
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
210+
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
211+
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
212+
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
213+
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
214+
}
215+
node_list = [
216+
torch.ops.aten.conv_transpose2d.input,
217+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
218+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
219+
torch.ops.aten.conv2d.default,
220+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
221+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
222+
]
223+
m = TConv2dConv2d()
224+
self._test_quantizer(
225+
m,
226+
m.example_inputs(),
227+
quantizer,
228+
node_occurrence,
229+
node_list,
230+
is_debug_mode=True,
231+
)
232+
123233
def test_linear(self):
124234
quantizer = XNNPACKQuantizer()
125235
quantization_config = get_symmetric_quantization_config(is_per_channel=True)

0 commit comments

Comments
 (0)