Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ executorch_generated_lib(
"//executorch/backends/cadence/generic/operators:dequantize_per_tensor",
"//executorch/backends/cadence/generic/operators:quantize_per_tensor",
"//executorch/backends/cadence/generic/operators:quantized_add_out",
"//executorch/backends/cadence/generic/operators:quantized_conv1d_ncl_out",
"//executorch/backends/cadence/generic/operators:quantized_conv1d_nlc_out",
"//executorch/backends/cadence/generic/operators:quantized_conv2d_nchw_out",
"//executorch/backends/cadence/generic/operators:quantized_conv2d_nhwc_out",
"//executorch/backends/cadence/generic/operators:quantized_fully_connected_out",
Expand Down
20 changes: 20 additions & 0 deletions backends/cadence/aot/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,16 @@
- arg_meta: null
kernel_name: impl::generic::dequantize_per_tensor_asym32s_out

- func: cadence::quantized_conv1d_ncl.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::generic::quantized_conv1d_ncl_out

- func: cadence::quantized_conv1d_nlc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::generic::quantized_conv1d_nlc_out

- func: cadence::quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
Expand Down Expand Up @@ -419,6 +429,16 @@
- arg_meta: null
kernel_name: impl::generic::quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_conv1d_ncl.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::generic::quantized_conv1d_ncl_per_tensor_out

- func: cadence::quantized_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::generic::quantized_conv1d_nlc_per_tensor_out

- func: cadence::quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
Expand Down
20 changes: 20 additions & 0 deletions backends/cadence/aot/functions_hifi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,16 @@
- arg_meta: null
kernel_name: impl::HiFi::dequantize_per_tensor_asym16s_out

- func: cadence::quantized_conv1d_ncl.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::HiFi::quantized_conv1d_ncl_out

- func: cadence::quantized_conv1d_nlc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::HiFi::quantized_conv1d_nlc_out

- func: cadence::quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
Expand Down Expand Up @@ -430,6 +440,16 @@
- arg_meta: null
kernel_name: impl::HiFi::quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_conv1d_ncl.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::HiFi::quantized_conv1d_ncl_per_tensor_out

- func: cadence::quantized_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::HiFi::quantized_conv1d_nlc_per_tensor_out

- func: cadence::quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
Expand Down
174 changes: 174 additions & 0 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,30 @@
lib.define(
"quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_conv1d_nlc(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)"
)
lib.define(
"quantized_conv1d_nlc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_conv1d_ncl(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)"
)
lib.define(
"quantized_conv1d_ncl.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_conv1d_ncl.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
"quantized_conv1d_ncl.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_conv1d_nlc.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
"quantized_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
Expand Down Expand Up @@ -934,6 +958,94 @@ def quantized_conv2d_nhwc_meta(
return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_conv1d_nlc")
def quantized_conv1d_nlc_meta(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: Tuple[int],
padding: Tuple[int],
dilation: Tuple[int],
groups: int,
in_zero_point: int,
weight_zero_point: torch.Tensor,
bias_scale: torch.Tensor,
output_scale: float,
output_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
) -> torch.Tensor:
out_channels, *kernel_size, _ = weight.shape

in_size = input.shape
# Assert that the input tensor has at least 3 dimensions, and at most 6
assert len(in_size) > 2
assert len(in_size) < 6

# Compute the output tensor size
output_size = (
get_conv1d_output_size(
in_size,
out_channels,
stride[1],
padding[1],
dilation[1],
kernel_size[0],
True,
)
if len(in_size) == 3
else get_conv2d_output_size(
in_size, out_channels, stride, padding, dilation, kernel_size, True
)
)

return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_conv1d_ncl")
def quantized_conv1d_ncl_meta(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: Tuple[int],
padding: Tuple[int],
dilation: Tuple[int],
groups: int,
in_zero_point: int,
weight_zero_point: torch.Tensor,
bias_scale: torch.Tensor,
output_scale: float,
output_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
) -> torch.Tensor:
out_channels, _, *kernel_size = weight.shape

in_size = input.shape
# Assert that the input tensor has at least 3 dimensions, and at most 6
assert len(in_size) > 2
assert len(in_size) < 6

# Compute the output tensor size
output_size = (
get_conv1d_output_size(
in_size,
out_channels,
stride[1],
padding[1],
dilation[1],
kernel_size[0],
False,
)
if len(in_size) == 3
else get_conv2d_output_size(
in_size, out_channels, stride, padding, dilation, kernel_size, False
)
)

return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_conv2d_nchw")
def quantized_conv2d_nchw_meta(
input: torch.Tensor,
Expand Down Expand Up @@ -2371,6 +2483,68 @@ def roi_align_box_processor_meta(
return rois.new_empty((rois.shape[0], 80), dtype=torch.uint8)


@register_fake("cadence::quantized_conv1d_ncl.per_tensor")
def quantized_conv1d_ncl_per_tensor_meta(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: Tuple[int],
padding: Tuple[int],
dilation: Tuple[int],
groups: int,
in_zero_point: int,
weight_zero_point: int,
bias_scale: float,
output_scale: float,
output_zero_point: int,
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert input.dim() == 3 and weight.dim() == 3
out_channels, _, kernel_size = weight.shape
output_size = get_conv1d_output_size(
input.shape,
out_channels,
stride[1],
padding[1],
dilation[1],
kernel_size,
False,
)
return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_conv1d_nlc.per_tensor")
def quantized_conv1d_nlc_per_tensor_meta(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: Tuple[int],
padding: Tuple[int],
dilation: Tuple[int],
groups: int,
in_zero_point: int,
weight_zero_point: int,
bias_scale: float,
output_scale: float,
output_zero_point: int,
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
assert input.dim() == 3 and weight.dim() == 3
out_channels, _, kernel_size = weight.shape
output_size = get_conv1d_output_size(
input.shape,
out_channels,
stride[1],
padding[1],
dilation[1],
kernel_size,
True,
)
return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_conv1d_ncl_asym8sxsym8s_asym8s.per_tensor")
def quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_meta(
input: torch.Tensor,
Expand Down
14 changes: 10 additions & 4 deletions backends/cadence/aot/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def get_anchors(
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_conv2d_nchw.default
return torch.ops.cadence.quantized_conv1d_ncl.default


class Conv2dPattern(QuantizationPattern):
Expand Down Expand Up @@ -459,29 +459,35 @@ def get_anchors(
output=[(relu_node,)], # Output is from the relu node
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_conv2d_nchw.default


# Conv1d + regular relu op fusion
class Conv1dReluPattern0(ConvReluBasePattern):
    """Matches aten.conv1d followed by out-of-place aten.relu for fusion."""

    def partition_types(self) -> List[OpOverload]:
        # The pair of ops this pattern partitions: conv1d then relu.
        return [torch.ops.aten.conv1d.default, torch.ops.aten.relu.default]

    def replacement_op(self) -> OpOverload:
        # Quantized channel-first (NCL) 1d conv op that replaces the pair.
        return torch.ops.cadence.quantized_conv1d_ncl.default

# Conv1d + alternate relu op fusion
class Conv1dReluPattern1(ConvReluBasePattern):
    """Matches aten.conv1d followed by in-place aten.relu_ for fusion."""

    def partition_types(self) -> List[OpOverload]:
        # The pair of ops this pattern partitions: conv1d then in-place relu.
        return [torch.ops.aten.conv1d.default, torch.ops.aten.relu_.default]

    def replacement_op(self) -> OpOverload:
        # Quantized channel-first (NCL) 1d conv op that replaces the pair.
        return torch.ops.cadence.quantized_conv1d_ncl.default

# Conv2d + regular relu op fusion
class Conv2dReluPattern0(ConvReluBasePattern):
    """Matches aten.conv2d followed by out-of-place aten.relu for fusion."""

    def partition_types(self) -> List[OpOverload]:
        # The pair of ops this pattern partitions: conv2d then relu.
        return [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default]

    def replacement_op(self) -> OpOverload:
        # Quantized channel-first (NCHW) 2d conv op that replaces the pair.
        return torch.ops.cadence.quantized_conv2d_nchw.default

# Conv2d + alternate relu op fusion
class Conv2dReluPattern1(ConvReluBasePattern):
    """Matches aten.conv2d followed by in-place aten.relu_ for fusion."""

    def partition_types(self) -> List[OpOverload]:
        # The pair of ops this pattern partitions: conv2d then in-place relu.
        return [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default]

    def replacement_op(self) -> OpOverload:
        # Quantized channel-first (NCHW) 2d conv op that replaces the pair.
        return torch.ops.cadence.quantized_conv2d_nchw.default
Loading
Loading