
Commit 26cd168: Apply suggestions

1 parent db8a015

10 files changed: +395, -577 lines


backends/xnnpack/_passes/fuse_activation_pass.py

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,8 @@ def call(self, graph_module: torch.fx.GraphModule):
             preceding_op.op == "call_function"
             and preceding_op.target in self.FUSEABLE_OPS
         ):
+            # Check that the current activation is the only user of the preceding op
+            # so that we can fuse the activation into the preceding op
             if len(preceding_op.users) > 1:
                 continue
             # Delete activation, and embed metadata into preceding op
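
The sole-user guard exists because fusing the activation into its producer removes the producer's standalone output. A minimal torch.fx sketch (hypothetical module, not part of this commit) showing how a second consumer trips the guard:

    import torch
    from torch import fx

    class M(torch.nn.Module):
        def forward(self, x):
            y = torch.add(x, x)
            # y has two users (relu and mul); fusing relu into add would
            # destroy the un-activated value that mul still needs.
            return torch.relu(y) * y

    gm = fx.symbolic_trace(M())
    add_node = next(n for n in gm.graph.nodes if n.target is torch.add)
    print(len(add_node.users))  # 2 -> the pass hits `continue` and skips fusion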

backends/xnnpack/_passes/fuse_batch_norm_with_conv.py

Lines changed: 2 additions & 0 deletions
@@ -140,6 +140,8 @@ def can_fuse(
         Determine whether a batch norm node can be fused with a preceding conv node.
         """
 
+        # Only fuse transposed convolutions if the kernel size matches the stride;
+        # otherwise the weights are not distributed equally across the spatial dimensions
         is_transpose = conv.args[6]
         kernel_node = get_input_node(conv, 1)
         kernel_shape = get_shape(kernel_node)
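
A minimal sketch of the shape test the new comment implies, with hypothetical helper names (the real pass reads the kernel shape and stride from the conv node's args): when the kernel size equals the stride, transposed-conv output windows do not overlap, so every output position is touched by each weight exactly once.

    from typing import List

    def transposed_conv_bn_fusable(kernel_shape: List[int], stride: List[int]) -> bool:
        # kernel_shape is (inc, oc/groups, *spatial); only the spatial dims matter
        return list(kernel_shape[2:]) == list(stride)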

backends/xnnpack/_passes/tag_implicit_q_dq_pass.py

Lines changed: 10 additions & 5 deletions
@@ -83,11 +83,16 @@ def is_dynamically_quantized(self, node: torch.fx.Node) -> bool:
         return is_dynamic_qdq(node)
 
     def is_supported_quant_op(self, node: torch.fx.Node) -> bool:
-        return (
-            node.op == "call_function"
-            and cast(torch._ops.OpOverload, node.target).name()
-            in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET
-        )
+        if node.op != "call_function":
+            return False
+
+        op_name = cast(torch._ops.OpOverload, node.target).name()
+
+        # Weight and input should both be quantized
+        if op_name == exir_ops.edge.aten.convolution.default.name():
+            return is_dequant(node.args[1])
+
+        return op_name in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET
 
     def is_supported_quant_module(self, node: torch.fx.Node) -> bool:
         is_supported = (
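
The new convolution branch inspects the weight argument instead of trusting the op-name set alone. A self-contained sketch of the intent, with an illustrative stand-in for the file's is_dequant helper:

    import torch

    def is_dequant_sketch(node: torch.fx.Node) -> bool:
        # Illustrative stand-in: treat any call_function whose target name
        # mentions "dequantize" as a dequantize op.
        return node.op == "call_function" and "dequantize" in str(node.target)

    def conv_weight_is_quantized(conv_node: torch.fx.Node) -> bool:
        # For aten.convolution, args[1] is the weight; the branch above returns
        # True only when that weight is produced by a dequantize node, i.e.
        # both the input activation and the weight are quantized.
        weight = conv_node.args[1]
        return isinstance(weight, torch.fx.Node) and is_dequant_sketch(weight)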

backends/xnnpack/operators/node_visitor.py

Lines changed: 30 additions & 38 deletions
@@ -343,10 +343,9 @@ def define_tensor(  # noqa: C901
         xnn_graph: XNNGraph,
         vals_to_ids: Dict[torch.fx.Node, int],
         convert_to_nhwc: bool = False,
-        swap_nc_for_depthwise_weights: bool = False,
+        swap_in_out_for_weights: bool = False,
         quant_params: Optional[QuantParams] = None,
         fp32_static_weights: bool = False,
-        swap_in_out_for_transpose_weights: bool = False,
         groups: int = 1,
     ) -> None:
         """
@@ -359,19 +358,21 @@ def define_tensor(  # noqa: C901
                 their corresponding ids in XNNGraph
             convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
                 reflect the nhwc memory format.
-            swap_nc_for_depthwise_weights: bool to indicate whether tensor shape
-                should be permuted such that the N and C dimensions are
-                swapped, which should be used for depthwise convolution
+            swap_in_out_for_weights: bool to indicate whether tensor shape should be
+                permuted and reshaped from (inc, oc/groups, height, width) to
+                (oc, inc/groups, height, width), which should be used for depthwise/transpose convolution
                 weights. This is only valid for tensors which hold
                 constant data. If used along with convert_to_nhwc, this
                 swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantized
             fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
-            swap_in_out_for_transpose_weights: bool to indicate whether tensor shape should be
-                permuted and reshape from (inc, oc/groups, height, width) to (oc, inc/groups, height, width)
-            groups: number of groups for swap_in_out_for_transpose_weights
+            groups: number of groups for swap_in_out_for_weights
         """
 
+        assert (
+            swap_in_out_for_weights or groups == 1
+        ), "groups is only valid with swap_in_out_for_weights"
+
         if tensor in vals_to_ids:
             return
 
@@ -399,18 +400,15 @@ def define_tensor(  # noqa: C901
                 xnn_graph,
                 vals_to_ids,
                 convert_to_nhwc,
-                swap_nc_for_depthwise_weights,
+                swap_in_out_for_weights,
                 quant_params,
                 fp32_static_weights,
-                swap_in_out_for_transpose_weights,
                 groups,
             )
 
         # convert tensor shape must reflect memory format, default is contiguous, so
         # only permute shape if we are converting the tensor to nhwc format
-        if swap_nc_for_depthwise_weights:
-            dims = [dims[1], dims[0]] + dims[2:]
-        if swap_in_out_for_transpose_weights:
+        if swap_in_out_for_weights:
             dims = [dims[1] * groups, dims[0] // groups] + dims[2:]
         if convert_to_nhwc:
             check_or_raise(len(dims) == 4, "Converting to nhwc requires 4d tensor")
@@ -431,24 +429,16 @@ def define_tensor(  # noqa: C901
         )
 
         # Override the quant params axis since we have
-        # updated the weights for depthwise, with that the out_channels dim
+        # updated the weights for depthwise / transposed conv2d, with that the out_channels dim
         # will be dims[3] instead of dims[0]. Let's update the per_channel
         # quant axis to match the new weight tensor before serializing
-        if swap_nc_for_depthwise_weights and (
-            quant_params and quant_params.per_channel
-        ):
-            if quant_params.axis == 0:
-                quant_params.axis = len(dims) - 1
-            else:
-                assert f"Unsupported weight per channel quantization axis for depthwise conv2d: {quant_params.axis}, expecting 0."
-
-        if swap_in_out_for_transpose_weights and (
-            quant_params and quant_params.per_channel
-        ):
+        if swap_in_out_for_weights and (quant_params and quant_params.per_channel):
             if quant_params.axis == 0:
                 quant_params.axis = len(dims) - 1
+            elif quant_params.axis == 1:
+                quant_params.axis = 0
             else:
-                assert f"Unsupported weight per channel quantization axis for conv_transpose2d: {quant_params.axis}, expecting 0."
+                raise AssertionError(f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d: {quant_params.axis}, expecting 0 / 1.")
 
         # Serialize tensor value
         ser_val = (
@@ -509,10 +499,9 @@ def get_serialized_buffer_index(
         xnn_graph: XNNGraph,
         vals_to_ids: Dict[torch.fx.Node, int],
         convert_to_nhwc: bool,
-        swap_nc_for_depthwise_weights: bool,
+        swap_in_out_for_weights: bool,
         quant_params: Optional[QuantParams],
         fp32_static_weights: bool = False,
-        swap_in_out_for_transpose_weights: bool = False,
         groups: int = 1,
     ) -> int:
         """
@@ -526,24 +515,30 @@ def get_serialized_buffer_index(
                 their corresponding ids in XNNGraph
             convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
                 reflect the nhwc memory format.
-            swap_nc_for_depthwise_weights: bool to indicate whether tensor shape
-                should be permuted such that the N and C dimensions are
-                swapped, which should be used for depthwise convolution
+            swap_in_out_for_weights: bool to indicate whether tensor shape should be
+                permuted and reshaped from (inc, oc/groups, height, width) to
+                (oc, inc/groups, height, width), which should be used for depthwise/transpose convolution
                 weights. This is only valid for tensors which hold
                 constant data. If used along with convert_to_nhwc, this
                 swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantized
             fp32_static_weights: bool to indicate whether tensor is fp32 static weights
+            groups: number of groups for swap_in_out_for_weights
 
         Returns:
             buffer_idx: idx of the serialized data. 0 If not associated constant
                 data
         """
+
+        assert (
+            swap_in_out_for_weights or groups == 1
+        ), "groups is only valid with swap_in_out_for_weights"
+
         # The get_attr node is the input to quant_params.
         get_attr_node = tensor if quant_params is None else quant_params.q_input
         if not is_param_node(self.exported_program, get_attr_node):
             check_or_raise(
-                not swap_nc_for_depthwise_weights,
+                not swap_in_out_for_weights,
                 "Swapping N and C dimensions is only valid for constant data tensors",
             )
             return 0
@@ -560,12 +555,9 @@ def get_serialized_buffer_index(
             # ensure that the const is fp32
             const_val = const_val.to(dtype=torch.float32).contiguous()
 
-        if swap_nc_for_depthwise_weights:
-            const_val = const_val.permute(
-                dims=((1, 0) + tuple(range(2, const_val.dim())))
-            ).contiguous()
-
-        if swap_in_out_for_transpose_weights:
+        if swap_in_out_for_weights:
+            # Permute and reshape the tensor from (inc, oc/groups, height, width) to
+            # (oc, inc/groups, height, width), which is the layout XNNPACK expects for
+            # depthwise/transpose convolution weights
             shape = const_val.shape
             const_val = const_val.reshape(
                 (groups, const_val.shape[0] // groups) + const_val.shape[1:]
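
The hunk continues with the transpose/reshape that realizes this layout change; the full sequence is easy to verify in isolation. A sketch with made-up sizes (groups = 2, weight of shape (inc=8, oc/groups=2, 3, 3)):

    import torch

    groups = 2
    w = torch.randn(8, 2, 3, 3)  # (inc, oc/groups, H, W); oc = 2 * groups = 4

    shape = w.shape
    # (inc, oc/g, H, W) -> (g, inc/g, oc/g, H, W): split the group dim out
    w = w.reshape((groups, shape[0] // groups) + shape[1:])
    # -> (g, oc/g, inc/g, H, W): swap the per-group input/output channel dims
    w = w.permute(0, 2, 1, *range(3, w.dim()))
    # -> (oc, inc/g, H, W): fold the group dim into the output channels
    w = w.reshape((shape[1] * groups, shape[0] // groups) + shape[2:]).contiguous()

    print(w.shape)  # torch.Size([4, 4, 3, 3]) == (oc, inc/groups, H, W)

This matches the dims bookkeeping above ([dims[1] * groups, dims[0] // groups] + dims[2:]), and with groups == 1 it reduces to the old depthwise N/C swap.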

backends/xnnpack/operators/op_conv2d.py

Lines changed: 13 additions & 3 deletions
@@ -84,16 +84,26 @@ def define_node(
         )
         fp32_static_weights = kernel_node.meta["val"].dtype == torch.float16
 
+        if weight_quant_params is not None and weight_quant_params.per_channel:
+            if is_transpose:
+                check_or_raise(
+                    weight_quant_params.axis == 1 and groups == 1,
+                    "XNNPACK currently only supports per output channel quantization with groups == 1 for transpose convolutions",
+                )
+            elif is_depthwise_conv:
+                check_or_raise(
+                    weight_quant_params.axis == 0,
+                    "XNNPACK currently only supports per input channel quantization for depthwise convolutions",
+                )
         self.define_tensor(
             kernel_node,
             xnn_graph,
             vals_to_ids,
             convert_to_nhwc=True,
-            swap_nc_for_depthwise_weights=is_depthwise_conv,
+            swap_in_out_for_weights=is_depthwise_conv or is_transpose,
             quant_params=weight_quant_params,
             fp32_static_weights=fp32_static_weights,
-            swap_in_out_for_transpose_weights=is_transpose,
-            groups=groups,
+            groups=groups if is_transpose else 1,
         )
         kwargs["filter_id"] = vals_to_ids[get_input_node(node, 1)]
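
The axis checks encode PyTorch's weight layouts: a transposed-conv weight is (in_channels, out_channels/groups, kH, kW), so per-output-channel scales live on axis 1. A quick check with made-up sizes:

    import torch

    m = torch.nn.ConvTranspose2d(in_channels=8, out_channels=4, kernel_size=3)
    print(m.weight.shape)  # torch.Size([8, 4, 3, 3]) == (inc, oc/groups, kH, kW)

    # One scale/zero-point per output channel -> quantize along axis 1
    scales = torch.full((4,), 0.02)
    zero_points = torch.zeros(4, dtype=torch.int64)
    qw = torch.quantize_per_channel(m.weight.detach(), scales, zero_points, axis=1, dtype=torch.qint8)
    print(qw.q_per_channel_axis())  # 1

After the in/out swap in define_tensor moves the output channels to dim 0, the per-channel axis is remapped from 1 to 0 accordingly.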

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 3 additions & 2 deletions
@@ -318,7 +318,7 @@ def __init__(self, **kwargs):
 
     def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         """
-        Currently we have no support for convolution 3d and transposed convolution
+        Currently we have no support for convolution 3d
         """
         if not super().check_constraints(node, ep):
             return False
@@ -333,11 +333,12 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
 
         is_transpose = node.args[6]
         groups = cast(int, node.args[8])
+
         if (
             is_transpose
             and weight_quant_params is not None
             and weight_quant_params.per_channel
-            and groups > 1
+            and (groups > 1 or weight_quant_params.axis != 1)
         ):
             why(
                 node,
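
As a standalone predicate (names hypothetical), the tightened partition constraint mirrors the runtime check added in op_conv2d.py:

    def delegate_per_channel_transposed(groups: int, quant_axis: int) -> bool:
        # Per-channel-quantized transposed convs are delegated only when
        # groups == 1 and the scales are per output channel (axis 1).
        return groups == 1 and quant_axis == 1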

backends/xnnpack/partition/configs.py

Lines changed: 0 additions & 3 deletions
@@ -131,11 +131,8 @@
     torch.nn.functional.conv1d,
     torch.ao.nn.quantized.reference.modules.conv.Conv1d,
     torch.nn.Conv2d,
-    torch.nn.ConvTranspose2d,
     torch.nn.functional.conv2d,
-    torch.nn.functional.conv_transpose2d,
     torch.ao.nn.quantized.reference.modules.conv.Conv2d,
-    torch.ao.nn.quantized.reference.modules.conv.ConvTranspose2d,
     torch.nn.BatchNorm1d,
     torch.nn.BatchNorm2d,
 ]
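
With the source-based list trimmed, ConvTranspose2d lowering flows through the config-based partitioner above. A hedged end-to-end sketch (entry points as of roughly this commit's vintage; treat the exact APIs as assumptions):

    import torch
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
    from executorch.exir import to_edge

    model = torch.nn.ConvTranspose2d(4, 4, kernel_size=2, stride=2).eval()
    ep = torch.export.export(model, (torch.randn(1, 4, 8, 8),))
    edge = to_edge(ep).to_backend(XnnpackPartitioner())
    print(edge.exported_program().graph_module)  # inspect for the delegate call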
