
Commit 36d5429

[Quantized DeConv Support] Dynamically Quantized Deconvolutions with groups == 1

This adds support for dynamically quantized deconvolutions. There is some refactoring of the previous diff, but in essence we remove the constraint in the dynamism check that the convolution must not be transposed. For the same reasons as before, this only supports channel_axis = 1 and groups = 1.

Differential Revision: [D76638904](https://our.internmc.facebook.com/intern/diff/D76638904/)

[ghstack-poisoned]
1 parent 5bdcb24 commit 36d5429
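
As a rough illustration of what this change enables, the sketch below dynamically quantizes a ConvTranspose2d with groups = 1 through the XNNPACKQuantizer. This is a minimal sketch, not code from this commit; the export and prepare/convert entry points (torch.export.export_for_training, prepare_pt2e, convert_pt2e) are the standard PT2E ones and may differ slightly across PyTorch/ExecuTorch versions.

    # Minimal sketch (not from this commit): dynamically quantize a ConvTranspose2d
    # with groups=1 now that the "conv_transpose" pattern accepts dynamic quantization.
    import torch
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    model = torch.nn.ConvTranspose2d(3, 10, kernel_size=3, groups=1).eval()
    example_inputs = (torch.randn(1, 3, 8, 8),)

    quantizer = XNNPACKQuantizer()
    # Dynamic activations, per-channel weights; deconv weights are quantized along ch_axis=1.
    quantizer.set_global(
        get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
    )

    exported = torch.export.export_for_training(model, example_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)  # annotate graph + insert observers
    prepared(*example_inputs)                     # run once; dynamic activations need no calibration
    quantized = convert_pt2e(prepared)            # materialize quantize/dequantize ops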

File tree

4 files changed: +133 -63 lines changed


backends/xnnpack/quantizer/xnnpack_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -274,7 +274,7 @@ class XNNPACKQuantizer(Quantizer):
         QuantPattern("linear_relu", False, False, LINEAR_TARGETS),
         QuantPattern("linear", True, False, LINEAR_TARGETS),
         QuantPattern("conv", True, False, CONV_TARGETS),
-        QuantPattern("conv_transpose", False, False, CONV_TARGETS),
+        QuantPattern("conv_transpose", True, False, CONV_TARGETS),
         QuantPattern("conv_relu", False, False, CONV_TARGETS),
         QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS),
         QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS),

backends/xnnpack/quantizer/xnnpack_quantizer_utils.py

Lines changed: 52 additions & 30 deletions
@@ -4,7 +4,10 @@
 
 import torch
 import torch.nn.functional as F
-from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
+from executorch.backends.xnnpack.utils.utils import (
+    get_groups_from_conv,
+    is_depthwise_conv,
+)
 from torch._subclasses import FakeTensor
 from torch.fx import Node
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
@@ -65,6 +68,28 @@ def decorator(annotator: AnnotatorType) -> None:
     return decorator
 
 
+def change_quantization_config(
+    original_qspec,
+    dtype=None,
+    quant_min=None,
+    quant_max=None,
+    qscheme=None,
+    ch_axis=None,
+    is_dynamic=None,
+    observer_or_fake_quant_ctr=None,
+):
+    return QuantizationSpec(
+        dtype=dtype or original_qspec.dtype,
+        quant_min=quant_min or original_qspec.quant_min,
+        quant_max=quant_max or original_qspec.quant_max,
+        qscheme=qscheme or original_qspec.qscheme,
+        ch_axis=ch_axis or original_qspec.ch_axis,
+        is_dynamic=is_dynamic or original_qspec.is_dynamic,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr
+        or original_qspec.observer_or_fake_quant_ctr,
+    )
+
+
 def is_relu_node(node: Node) -> bool:
     """
     Check if a given node is a relu node
@@ -231,6 +256,9 @@ def _do_annotate_conv(
         if is_relu_node(user):
             continue
 
+        # Tracks conditions for whether or not to skip
+        skip = False
+
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
@@ -239,35 +267,33 @@ def _do_annotate_conv(
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
         weight_qspec = get_weight_qspec(quantization_config)
+        num_groups = get_groups_from_conv(conv_node)
+
+        # skip if transposed conv has more than 1 group
+        skip = skip or (is_conv_transpose and num_groups != 1)
+        print(f"{skip} conv transpose and num_groups")
+
         if is_conv_transpose:
             # transposed convs per output channel quantization
-            weight_qspec = QuantizationSpec(
-                dtype=weight_qspec.dtype,
-                quant_min=weight_qspec.quant_min,
-                quant_max=weight_qspec.quant_max,
-                qscheme=weight_qspec.qscheme,
-                ch_axis=1,
-                is_dynamic=False,
-                observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr,
-            )
-        input_qspec_map[weight] = weight_qspec
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
 
-        # Only annotate dynamically quantized conv if it's 2D and not depthwise
-        if (
+        input_qspec_map[weight] = weight_qspec
+        is_dynamic = (
             quantization_config
             and quantization_config.input_activation
            and quantization_config.input_activation.is_dynamic
-        ):
+        )
+
+        # Only annotate dynamically quantized conv if it's 2D and not depthwise
+        if is_dynamic:
             weight_val = weight.meta.get("val", None)
             weight_shape = getattr(weight_val, "shape", None)
-
             # Skip if not a 4D weight tensor (i.e. not conv2d)
-            if weight_shape is not None and len(weight_shape) != 4:
-                continue
-
+            skip = skip or (weight_shape is not None and len(weight_shape) != 4)
             # Skip if depthwise (default to groups=1 since it's not an arg)
-            if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
-                continue
+            skip = skip or (
+                not is_conv_transpose and is_depthwise_conv(weight_shape, 1, False)
+            )
 
         # adding weight node to the partition as well
         partition = [conv_node, conv_node.args[1]]
@@ -277,7 +303,7 @@ def _do_annotate_conv(
             input_qspec_map[bias] = get_bias_qspec(quantization_config)
             partition.append(bias)
 
-        if _is_annotated(partition):
+        if _is_annotated(partition) or skip:
             continue
 
         if filter_fn and any(not filter_fn(n) for n in partition):
@@ -324,17 +350,10 @@ def _do_annotate_conv_relu(
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
         weight_qspec = get_weight_qspec(quantization_config)
+        groups = get_groups_from_conv(conv_node)
         if is_conv_transpose:
             # transposed convs per output channel quantization
-            weight_qspec = QuantizationSpec(
-                dtype=weight_qspec.dtype,
-                quant_min=weight_qspec.quant_min,
-                quant_max=weight_qspec.quant_max,
-                qscheme=weight_qspec.qscheme,
-                ch_axis=1,
-                is_dynamic=False,
-                observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr,
-            )
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
         input_qspec_map[weight] = weight_qspec
 
         # adding weight node to the partition as well
@@ -347,6 +366,9 @@ def _do_annotate_conv_relu(
         if _is_annotated(partition):
             continue
 
+        if is_conv_transpose and groups != 1:
+            continue
+
         if filter_fn and any(not filter_fn(n) for n in partition):
             continue
 
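For context on the ch_axis=1 override above: ConvTranspose2d stores its weight as (C_in, C_out/groups, kH, kW), so per-output-channel scales live along dim 1, whereas for a regular Conv2d they live along dim 0. A small illustrative snippet, independent of the code in this diff and not how the observers compute scales internally:

    # Illustration only: per-output-channel symmetric scales for a deconv weight.
    import torch

    deconv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=10, kernel_size=3)
    w = deconv.weight                         # shape (3, 10, 3, 3); output channels on dim 1

    # Reduce over every dim except the output-channel axis (dim 1).
    w_oc_first = w.movedim(1, 0).flatten(1)   # (10, 3*3*3)
    max_abs = w_oc_first.abs().amax(dim=1)    # one value per output channel
    scale = max_abs / 127.0                   # symmetric int8 scale, quant_max=127
    print(scale.shape)                        # torch.Size([10])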

backends/xnnpack/test/ops/test_conv2d.py

Lines changed: 49 additions & 32 deletions
@@ -174,14 +174,11 @@ def get_inputs(self):
 
 
 class Conv2dDQSeq(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, transpose=False):
         super().__init__()
-        self.first = torch.nn.Conv2d(
-            in_channels=3, out_channels=8, kernel_size=3, padding=1
-        )
-        self.second = torch.nn.Conv2d(
-            in_channels=8, out_channels=10, kernel_size=3, padding=1
-        )
+        op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d
+        self.first = op(in_channels=3, out_channels=8, kernel_size=3, padding=1)
+        self.second = op(in_channels=8, out_channels=10, kernel_size=3, padding=1)
 
     def forward(self, x):
         y = self.first(x)
@@ -192,14 +189,11 @@ def get_inputs(self):
 
 
 class Conv2dDQParallel(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, transpose=False):
         super().__init__()
-        self.first = torch.nn.Conv2d(
-            in_channels=3, out_channels=8, kernel_size=3, padding=1
-        )
-        self.second = torch.nn.Conv2d(
-            in_channels=3, out_channels=8, kernel_size=3, padding=1
-        )
+        op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d
+        self.first = op(in_channels=3, out_channels=8, kernel_size=3, padding=1)
+        self.second = op(in_channels=3, out_channels=10, kernel_size=3, padding=1)
 
     def forward(self, x):
         first = self.first(x)
@@ -266,8 +260,7 @@ def _test_dq(
         )
 
         DynamicallyQuantizedPartitioner = XnnpackPartitioner(
-            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
-            per_op_mode=True,
+            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, per_op_mode=True
         )
 
         tester = Tester(m, m.get_inputs(), dynamic_shapes=dynamic_shapes)
@@ -349,11 +342,10 @@ def test_fp32_conv2d_depthwise(self):
         )
 
     def test_qs8_conv2d_depthwise(self):
-        for transpose in (True, False):
-            self._test(
-                Conv2d(groups=2, in_channels=2, out_channels=6, transpose=transpose),
-                quant_config=get_symmetric_quantization_config(),
-            )
+        self._test(
+            Conv2d(groups=2, in_channels=2, out_channels=6),
+            quant_config=get_symmetric_quantization_config(),
+        )
 
     def test_fp32_conv2d_bn(self):
         class Conv2dBatchNorm(torch.nn.Module):
@@ -515,17 +507,14 @@ def forward(self, x):
            def get_inputs(self):
                return (torch.randn(batches, in_channels, height, width) * 11,)
 
-        for transpose in (True, False):
-            for per_channel_quant in (False, True):
-                if transpose and per_channel_quant:
-                    continue
-                model = ModelConvReLU(transpose=transpose)
-                self._test(
-                    model,
-                    quant_config=get_symmetric_quantization_config(
-                        is_per_channel=per_channel_quant
-                    ),
-                )
+        for per_channel_quant in (False, True):
+            model = ModelConvReLU()
+            self._test(
+                model,
+                quant_config=get_symmetric_quantization_config(
+                    is_per_channel=per_channel_quant
+                ),
+            )
 
     def test_qs8_conv2d_relu_seq(self):
         class ConvReLUSeq(torch.nn.Module):
@@ -728,3 +717,31 @@ def test_dq_conv2d_parallel(self) -> None:
         model = Conv2dDQParallel()
         conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
         self._test_dq(model, conv_count)
+
+    def test_dq_conv2d_transpose(self) -> None:
+        model = Conv2d(
+            in_channels=3,
+            out_channels=10,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(0, 0),
+            batches=1,
+            width=8,
+            height=8,
+            transpose=True,
+        )
+        self._test_dq(model)
+
+    def test_dq_conv2d_transpose_seq(self) -> None:
+        model = Conv2dDQSeq(transpose=True)
+        conv_count = sum(
+            1 for m in model.modules() if type(m) is torch.nn.ConvTranspose2d
+        )
+        self._test_dq(model, conv_count)
+
+    def test_dq_conv2d_transpose_parallel(self) -> None:
+        model = Conv2dDQParallel(transpose=True)
+        conv_count = sum(
+            1 for m in model.modules() if type(m) is torch.nn.ConvTranspose2d
+        )
+        self._test_dq(model, conv_count)

backends/xnnpack/utils/utils.py

Lines changed: 31 additions & 0 deletions
@@ -25,6 +25,7 @@
     is_lifted_tensor_constant,
     is_param,
 )
+from torchao.quantization.pt2e.utils import _is_conv_node, _is_conv_transpose_node
 
 
 ### XNNPACK Capture ###
@@ -160,6 +161,36 @@ def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]:
     return source_fn[1]
 
 
+def get_groups_from_conv(conv_node: torch.fx.Node) -> int:
+    if _is_conv_node(conv_node):
+        in_node = cast(torch.fx.Node, conv_node.args[0])
+        weight_node = cast(torch.fx.Node, conv_node.args[1])
+        # groups isn't given to us in the training graph so we deduce it from the weight shape
+        # and the input shape
+
+        # input shape is (N, C_in, H_in, W_in)
+        in_channels = in_node.meta["val"].shape[1]
+
+        # weight shape is (C_out, C_in/groups, kernel_size[0], kernel_size[1])
+        in_groups = weight_node.meta["val"].shape[1]
+
+        return in_channels // in_groups
+    elif _is_conv_transpose_node(conv_node):
+        weight_node = cast(torch.fx.Node, conv_node.args[1])
+        # groups isn't given to us in the training graph so we deduce it from the weight shape
+        # and the output shape
+
+        # weight shape is (C_in, C_out/groups, kernel_size[0], kernel_size[1])
+        out_groups = weight_node.meta["val"].shape[1]
+
+        # output shape is (N, C_out, H_out, W_out)
+        out_channels = conv_node.meta["val"].shape[1]
+
+        return out_channels // out_groups
+
+    raise RuntimeError(f"expected {conv_node} to be a conv or conv_transpose node")
+
+
 def is_depthwise_conv(
     kernel_shape: Tuple[int, ...], groups: int = 1, is_transpose: bool = False
 ) -> bool:
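The group deduction above leans on the weight layouts of the two op families. A quick standalone check of those shape relationships (hypothetical snippet, not part of the commit):

    import torch

    # Conv2d weight: (C_out, C_in/groups, kH, kW) -> groups = C_in // weight.shape[1]
    conv = torch.nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, groups=4)
    assert tuple(conv.weight.shape) == (16, 2, 3, 3)
    assert 8 // conv.weight.shape[1] == 4

    # ConvTranspose2d weight: (C_in, C_out/groups, kH, kW) -> groups = C_out // weight.shape[1]
    deconv = torch.nn.ConvTranspose2d(in_channels=8, out_channels=16, kernel_size=3, groups=4)
    assert tuple(deconv.weight.shape) == (8, 4, 3, 3)
    assert 16 // deconv.weight.shape[1] == 4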
