4 changes: 4 additions & 0 deletions backends/xnnpack/_passes/fuse_activation_pass.py
@@ -68,6 +68,10 @@ def call(self, graph_module: torch.fx.GraphModule):
preceding_op.op == "call_function"
and preceding_op.target in self.FUSEABLE_OPS
):
# Check that current activation is the only user of the preceding op
# so that we can fuse the activation into the preceding op
if len(preceding_op.users) > 1:
continue
# Delete activation, and embed metadata into preceding op
output_min_max = self.get_output_min_max_from_activation(
activation_node
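The new guard is easy to motivate with a toy graph: if the op being fused into still has another consumer, folding the activation's clamp range into it would also clamp what that other consumer sees. A minimal sketch with plain torch.fx (the real pass runs on the exported aten graph; the module and node names here are illustrative only):

import torch
from torch import nn
from torch.fx import symbolic_trace

class Branching(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        y = self.conv(x)
        return torch.relu(y) + y  # conv output has a second, non-ReLU consumer

gm = symbolic_trace(Branching())
conv_node = next(n for n in gm.graph.nodes if n.target == "conv")
# Two users (relu and add): fusing relu's [0, inf) range into the conv would
# also clamp the operand of the add, so the pass must skip this case.
assert len(conv_node.users) == 2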
2 changes: 2 additions & 0 deletions backends/xnnpack/_passes/fuse_batch_norm_with_conv.py
@@ -82,6 +82,7 @@ def call(self, graph_module: torch.fx.GraphModule):
# as an arg)
eps = bn.args[-1]

is_transpose = conv.args[6]
# Compute the updated weight and bias after fusing conv op
# with batchnorm op.
fused_weight, fused_bias = fuse_conv_bn_weights(
@@ -92,6 +93,7 @@ def call(self, graph_module: torch.fx.GraphModule):
eps,
bn_weight,
bn_bias,
is_transpose,
)

# Modify the graph by updating the weight and bias of conv op
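For context, the extra argument is the transpose flag of torch.nn.utils.fusion.fuse_conv_bn_weights, which tells the helper the weight is laid out as (inc, oc/groups, h, w) rather than (oc, inc/groups, h, w). A standalone sketch (not from the PR, assuming the fusion helper in current torch accepts the flag the pass passes) of fusing a BatchNorm2d into a ConvTranspose2d:

import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

deconv = torch.nn.ConvTranspose2d(4, 8, kernel_size=3, bias=True).eval()
bn = torch.nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-0.5, 0.5)  # make the comparison below non-trivial
bn.running_var.uniform_(0.5, 1.5)

# transpose=True corresponds to the is_transpose argument added in this pass
fused_w, fused_b = fuse_conv_bn_weights(
    deconv.weight, deconv.bias,
    bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias,
    transpose=True,
)

fused = torch.nn.ConvTranspose2d(4, 8, kernel_size=3, bias=True).eval()
fused.weight, fused.bias = fused_w, fused_b

x = torch.randn(1, 4, 10, 10)
torch.testing.assert_close(fused(x), bn(deconv(x)), rtol=1e-4, atol=1e-4)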
15 changes: 10 additions & 5 deletions backends/xnnpack/_passes/tag_implicit_q_dq_pass.py
@@ -83,11 +83,16 @@ def is_dynamically_quantized(self, node: torch.fx.Node) -> bool:
return is_dynamic_qdq(node)

def is_supported_quant_op(self, node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and cast(torch._ops.OpOverload, node.target).name()
in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET
)
if node.op != "call_function":
return False

op_name = cast(torch._ops.OpOverload, node.target).name()

# Weight and Input should both be quantized
if op_name == exir_ops.edge.aten.convolution.default.name():
Contributor: what was the reasoning for this? I imagine it should've returned true in the previous implementation as well?

Contributor Author: The issue arises when a non-quantized operation is interleaved between two quantized operations. That operation also matches the dequantize-op-quantize pattern. However, an operation with a quantized input and a float weight is not supported by XNNPACK.

return is_dequant(node.args[1])

return op_name in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET

def is_supported_quant_module(self, node: torch.fx.Node) -> bool:
is_supported = (
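To make the exchange above concrete: the problematic graph has an op whose input comes from a dequantize but whose weight is still fp32, so it matches the generic dequantize-op-quantize pattern even though XNNPACK has no kernel for that combination. A schematic sketch of the shape and the extra weight check (the standalone helper name below is hypothetical; is_dequant is assumed to be the same quant_utils helper the pass already uses):

from typing import cast

import torch
from executorch.backends.xnnpack.utils.quant_utils import is_dequant

# Schematic of the graph that slipped through the old check:
#
#   q(x) -> dq(x) -> aten.convolution(dq(x), fp32_weight, ...) -> q(y) -> dq(y)
#
# The convolution's activation input is dequantized, so the op "looks" quantized,
# but its weight never was. The fix also requires args[1] (the weight) to be
# produced by a dequantize node before tagging the surrounding q/dq as implicit.
def conv_weight_is_quantized(node: torch.fx.Node) -> bool:
    return is_dequant(cast(torch.fx.Node, node.args[1]))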
63 changes: 42 additions & 21 deletions backends/xnnpack/operators/node_visitor.py
@@ -337,15 +337,16 @@ def _check_per_channel_group_params(
# For now group quantization is only supported for 4b weights
assert quant_params.is_qc4w, "Only 4b group quantization is supported"

def define_tensor(
def define_tensor( # noqa: C901
self,
tensor: torch.fx.Node,
xnn_graph: XNNGraph,
vals_to_ids: Dict[torch.fx.Node, int],
convert_to_nhwc: bool = False,
swap_nc_for_depthwise_weights: bool = False,
swap_in_out_for_weights: bool = False,
quant_params: Optional[QuantParams] = None,
fp32_static_weights: bool = False,
groups: int = 1,
) -> None:
"""
Defines a tensor value into the XNNGraph
Expand All @@ -357,16 +358,21 @@ def define_tensor(
their corresponding ids in XNNGraph
convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
reflect the nhwc memory format.
swap_nc_for_depthwise_weights: bool to indicate whether tensor shape
should be permuted such that the N and C dimensions are
swapped, which should be used for depthwise convolution
swap_in_out_for_weights: bool to indicate whether tensor shape should be
permuted and reshaped from (inc, oc/groups, height, width) to
(oc, inc/groups, height, width), which should be used for depthwise/transpose convolution
weights. This is only valid for tensors which hold
constant data. If used along with convert_to_nhwc, this
swap will happen before converting to nhwc.
quant_params: Quantization meta data for this tensor, None if it is not quantized
fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
groups: number of groups for swap_in_out_for_weights
"""

assert (
swap_in_out_for_weights or groups == 1
), "groups != 1 is only supported with swap_in_out_for_weights"

if tensor in vals_to_ids:
return

@@ -394,15 +400,16 @@ def define_tensor(
xnn_graph,
vals_to_ids,
convert_to_nhwc,
swap_nc_for_depthwise_weights,
swap_in_out_for_weights,
quant_params,
fp32_static_weights,
groups,
)

# convert tensor shape must reflect memory format, default is contiguous, so
# only permute shape if we are converting the tensor to nhwc format
if swap_nc_for_depthwise_weights:
dims = [dims[1], dims[0]] + dims[2:]
if swap_in_out_for_weights:
dims = [dims[1] * groups, dims[0] // groups] + dims[2:]
if convert_to_nhwc:
check_or_raise(len(dims) == 4, "Converting to nhwc requires 4d tensor")
dims = [dims[i] for i in PERM_NCHW_TO_NHWC]
@@ -422,16 +429,16 @@ def define_tensor(
)

# Override the quant params axis since we have
# updated the weights for depthwise, with that the out_channels dim
# updated the weights for depthwise / transposed conv2d, so the out_channels dim
# will be dims[3] instead of dims[0]. Let's update the per_channel
# quant axis to match the new weight tensor before serializing
if swap_nc_for_depthwise_weights and (
quant_params and quant_params.per_channel
):
if swap_in_out_for_weights and (quant_params and quant_params.per_channel):
if quant_params.axis == 0:
quant_params.axis = len(dims) - 1
elif quant_params.axis == 1:
quant_params.axis = 0
else:
assert f"Unsupported weight per channel quantization axis for depthwise conv2d: {quant_params.axis}, expecting 0."
assert f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d: {quant_params.axis}, expecting 0 / 1."

# Serialize tensor value
ser_val = (
@@ -492,9 +499,10 @@ def get_serialized_buffer_index(
xnn_graph: XNNGraph,
vals_to_ids: Dict[torch.fx.Node, int],
convert_to_nhwc: bool,
swap_nc_for_depthwise_weights: bool,
swap_in_out_for_weights: bool,
quant_params: Optional[QuantParams],
fp32_static_weights: bool = False,
groups: int = 1,
) -> int:
"""
If tensor holds some constant data, serialize it and return the
@@ -507,24 +515,30 @@ def get_serialized_buffer_index(
their corresponding ids in XNNGraph
convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
reflect the nhwc memory format.
swap_nc_for_depthwise_weights: bool to indicate whether tensor shape
should be permuted such that the N and C dimensions are
swapped, which should be used for depthwise convolution
swap_in_out_for_weights: bool to indicate whether tensor shape should be
permuted and reshaped from (inc, oc/groups, height, width) to
(oc, inc/groups, height, width), which should be used for depthwise/transpose convolution
weights. This is only valid for tensors which hold
constant data. If used along with convert_to_nhwc, this
swap will happen before converting to nhwc.
quant_params: Quantization meta data for this tensor, None if it is not quantized
fp32_static_weights: bool to indicate whether tensor is fp32 static weights
groups: number of groups for swap_in_out_for_weights

Returns:
buffer_idx: idx of the serialized data. 0 If not associated constant
data
"""

assert (
swap_in_out_for_weights or groups == 1
), "groups != 1 is only supported with swap_in_out_for_weights"

# The get_attr node is the input to quant_params.
get_attr_node = tensor if quant_params is None else quant_params.q_input
if not is_param_node(self.exported_program, get_attr_node):
check_or_raise(
not swap_nc_for_depthwise_weights,
not swap_in_out_for_weights,
"Swapping N and C dimensions is only valid for constant data tensors",
)
return 0
@@ -541,9 +555,16 @@ def get_serialized_buffer_index(
# ensure that the const is fp32
const_val = const_val.to(dtype=torch.float32).contiguous()

if swap_nc_for_depthwise_weights:
const_val = const_val.permute(
dims=((1, 0) + tuple(range(2, const_val.dim())))
if swap_in_out_for_weights:
# Permute and reshape the tensor from (inc, oc/groups, height, width) to (oc, inc/groups, height, width)
# which should be used for depthwise/transpose convolution weights for XNNPACK
shape = const_val.shape
const_val = const_val.reshape(
(groups, const_val.shape[0] // groups) + const_val.shape[1:]
)
const_val = const_val.permute((0, 2, 1) + tuple(range(3, const_val.dim())))
const_val = const_val.reshape(
(shape[1] * groups, shape[0] // groups) + shape[2:]
).contiguous()

if convert_to_nhwc:
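A worked example of the reshape-permute-reshape above, with hypothetical sizes: a ConvTranspose2d with in_channels=4, out_channels=6, groups=2 stores its weight as (inc, oc/groups, h, w) = (4, 3, 5, 5), and the transform produces the (oc, inc/groups, h, w) = (6, 2, 5, 5) layout XNNPACK expects, before the separate NCHW-to-NHWC permute and the per-channel quant-axis remap handled in define_tensor:

import torch

groups = 2
w = torch.randn(4, 3, 5, 5)  # (inc, oc/groups, h, w)
shape = w.shape

w = w.reshape((groups, shape[0] // groups) + shape[1:])             # (g, inc/g, oc/g, h, w)
w = w.permute((0, 2, 1) + tuple(range(3, w.dim())))                 # (g, oc/g, inc/g, h, w)
w = w.reshape((shape[1] * groups, shape[0] // groups) + shape[2:])  # (oc, inc/g, h, w)

assert w.shape == (6, 2, 5, 5)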
45 changes: 33 additions & 12 deletions backends/xnnpack/operators/op_conv2d.py
@@ -16,6 +16,7 @@
from executorch.backends.xnnpack.operators.quant_params import QuantParams
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
XNNConv2d,
XNNConvTranspose2d,
XNNDepthwiseConv2d,
XNNGraph,
XNode,
@@ -52,35 +53,57 @@ def define_node(
) # NHWC input
kwargs["input1_id"] = vals_to_ids[get_input_node(node, 0)]

# filter shape for pytorch convolution is (oc, inc/groups, height, width)
# shape for xnnpack convolution is (oc, height, width, inc/groups), to convert
# to the proper shape, this is essentially a NCHW to NHWC conversion
# filter shape for pytorch convolution is (oc, inc/groups, height, width),
# filter shape for pytorch transpose convolution is (inc, oc/groups, height, width),
# shape for xnnpack convolution is (oc, height, width, inc/groups),
# shape for xnnpack transpose convolution is (oc, height, width, inc/groups),
# to convert to the proper shape, this is essentially a NCHW to NHWC conversion
kernel_node = get_input_node(node, 1)
kernel_shape = get_shape(kernel_node)
groups = cast(int, node.args[8])
group_input_channels = kernel_shape[1]
group_output_channels = int(kernel_shape[0] / groups)
is_transpose = node.args[6]

if is_transpose:
group_input_channels = int(kernel_shape[0] / groups)
group_output_channels = kernel_shape[1]
else:
group_input_channels = kernel_shape[1]
group_output_channels = int(kernel_shape[0] / groups)

# XNNPACK expects the kernel's N and C dimensions to be swapped for
# Depthwise Convolution, which occurs under the following conditions:
# 1) groups = input_channels (i.e. group_input_channels = 1)
# 2) output_channels is a positive integer multiple of input channels
is_depthwise_conv = (group_input_channels == 1) and (
group_output_channels % group_input_channels == 0
is_depthwise_conv = (
(group_input_channels == 1)
and (group_output_channels % group_input_channels == 0)
and not is_transpose
)
weight_quant_params = QuantParams.from_weights(
kernel_node, self._exported_program
)
fp32_static_weights = kernel_node.meta["val"].dtype == torch.float16

if weight_quant_params is not None and weight_quant_params.per_channel:
if is_transpose:
check_or_raise(
weight_quant_params.axis == 1 and groups == 1,
"XNNPACK currently only supports per output channel quantization with groups == 1 for transpose convolutions",
)
elif is_depthwise_conv:
check_or_raise(
weight_quant_params.axis == 0,
"XNNPACK currently only supports per input channel quantization for depthwise convolutions",
)
self.define_tensor(
kernel_node,
xnn_graph,
vals_to_ids,
convert_to_nhwc=True,
swap_nc_for_depthwise_weights=is_depthwise_conv,
swap_in_out_for_weights=is_depthwise_conv or is_transpose,
quant_params=weight_quant_params,
fp32_static_weights=fp32_static_weights,
groups=groups if is_transpose else 1,
)
kwargs["filter_id"] = vals_to_ids[get_input_node(node, 1)]

@@ -120,10 +143,6 @@ def define_node(
if len(padding) == 1:
padding = padding + padding

# args[6] = transposed
check_or_raise(
not cast(bool, node.args[6]), "No support for transposed convolution"
)
# args[7] = output padding
check_or_raise(
all(out_pad == 0 for out_pad in cast(List[int], node.args[7])),
@@ -152,6 +171,8 @@

if is_depthwise_conv:
conv_node_type = XNNDepthwiseConv2d
elif is_transpose:
conv_node_type = XNNConvTranspose2d
else:
conv_node_type = XNNConv2d

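A quick check of the channel bookkeeping above, with hypothetical module sizes: for a regular grouped conv the per-group counts come from the (oc, inc/groups, h, w) layout, for a transposed conv from (inc, oc/groups, h, w), and the depthwise case only arises for the non-transposed layout:

import torch

groups = 2
conv = torch.nn.Conv2d(8, 16, kernel_size=3, groups=groups)
deconv = torch.nn.ConvTranspose2d(8, 16, kernel_size=3, groups=groups)

# Conv2d weight: (oc, inc/groups, h, w) = (16, 4, 3, 3)
oc, icg = conv.weight.shape[:2]
assert (icg, oc // groups) == (4, 8)   # group_input_channels, group_output_channels

# ConvTranspose2d weight: (inc, oc/groups, h, w) = (8, 8, 3, 3)
inc, ocg = deconv.weight.shape[:2]
assert (inc // groups, ocg) == (4, 8)  # same per-group counts, read from swapped dims

# Depthwise example: groups == in_channels, so group_input_channels == 1
dw = torch.nn.Conv2d(8, 16, kernel_size=3, groups=8)
assert dw.weight.shape[1] == 1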
24 changes: 19 additions & 5 deletions backends/xnnpack/partition/config/gemm_configs.py
@@ -9,6 +9,7 @@
from typing import cast, List, Optional, Tuple

import torch
from executorch.backends.xnnpack.operators.quant_params import QuantParams
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
ConfigPrecisionType,
XNNPartitionerConfig,
@@ -317,7 +318,7 @@ def __init__(self, **kwargs):

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
"""
Currently we have no support for convolution 3d and transposed convolution
Currently we have no support for convolution 3d
"""
if not super().check_constraints(node, ep):
return False
@@ -327,11 +328,24 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
why(node, "Only support 1D + 2D Conv")
return False # Only support 1D + 2D Conv

transposed = cast(bool, node.args[6])
if transposed:
why(node, "Transposed Conv is not supported")
return False # Currently don't support transposed conv
kernel_node = get_input_node(node, 1)
weight_quant_params = QuantParams.from_weights(kernel_node, ep)

is_transpose = node.args[6]
groups = cast(int, node.args[8])

if (
is_transpose
and weight_quant_params is not None
and weight_quant_params.per_channel
and (groups > 1 or weight_quant_params.axis != 1)
):
why(
Contributor: pretty neat way of checking this constraint

node,
"XNNPACK does not support per input channel quantization "
"for transpose convolutions with groups > 1",
)
return False
Contributor: Let's add a test which triggers this constraint?

return True

def supported_precision_types(self):
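Regarding the reviewer's request for a test that triggers the constraint, a module along these lines should do it once its weights are quantized per channel (a sketch with hypothetical sizes, not a test from the PR): check_constraints above would return False and leave the transposed convolution un-delegated.

import torch

class GroupedDeconv(torch.nn.Module):
    """Grouped transposed conv: per-channel weight quantization plus groups > 1
    hits the new constraint, so XNNPACK must not claim this node."""

    def __init__(self):
        super().__init__()
        self.deconv = torch.nn.ConvTranspose2d(8, 8, kernel_size=3, groups=2)

    def forward(self, x):
        return self.deconv(x)

Wiring this through the backend's usual test harness (quantize per channel, export, partition, then assert the convolution stays un-delegated) is left out here, since the exact tester API is not shown in this diff.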
1 change: 1 addition & 0 deletions backends/xnnpack/partition/configs.py
@@ -73,6 +73,7 @@
torch.nn.BatchNorm2d,
torch.nn.BatchNorm1d,
torch.nn.Conv2d,
torch.nn.ConvTranspose2d,
torch.nn.Linear,
torch.nn.functional.linear,
torch.nn.PReLU, # Without this, the PReLU weight becomes not a get_attr