5 changes: 5 additions & 0 deletions backends/cadence/aot/functions_hifi.yaml
@@ -548,3 +548,8 @@
  kernels:
    - arg_meta: null
      kernel_name: impl::HiFi::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out

- func: cadence::quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: impl::HiFi::quantized_w8a32_linear_out
27 changes: 27 additions & 0 deletions backends/cadence/aot/ops_registrations.py
@@ -564,6 +564,14 @@
"_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
"quantized_w8a32_linear(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale) -> Tensor"
)
lib.define(
"quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
)


# Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
aten_lib = Library("aten", "FRAGMENT")
aten_lib.define(
@@ -2562,3 +2570,22 @@ def quantized_softmax_per_tensor_meta(
    out_zero_point: int,
) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=input.dtype)


@register_fake("cadence::quantized_w8a32_linear")
def quantized_w8a32_linear_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    w_scale: float,
    bias: torch.Tensor,
    b_scale: float,
) -> torch.Tensor:
    # src comes in shape [leading_dims, in_dim]
    # weight comes in shape [out_dim, in_dim]
    # output is allocated empty with shape [leading_dims, out_dim]
    src_shape = list(src.shape)
    weight_shape = weight.shape
    assert len(weight_shape) == 2
    assert src_shape[-1] == weight_shape[-1]
    src_shape[-1] = weight_shape[0]
    return src.new_empty(src_shape, dtype=src.dtype)
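As a quick illustration of the shape contract implemented by the fake kernel above, here is a minimal sketch (not part of this diff; tensor sizes and scales are made up) that mirrors its inference on plain tensors, assuming the [out_dim, in_dim] weight layout:

import torch

# Hypothetical shapes: fp32 activations [1, 64], int8 weights [32, 64].
src = torch.randn(1, 64)                                       # [leading_dims, in_dim]
weight = torch.randint(-128, 128, (32, 64), dtype=torch.int8)  # [out_dim, in_dim]
bias = torch.randint(-128, 128, (32,), dtype=torch.int8)
w_scale, b_scale = 0.02, 0.01                                  # made-up scales

# Same shape inference as quantized_w8a32_linear_meta:
out_shape = list(src.shape)
assert weight.dim() == 2
assert out_shape[-1] == weight.shape[-1]  # contraction dims must match
out_shape[-1] = weight.shape[0]           # in_dim -> out_dim
out = src.new_empty(out_shape, dtype=src.dtype)
print(out.shape)  # torch.Size([1, 32])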
31 changes: 31 additions & 0 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -27,6 +27,7 @@
    ReluPattern0,
    ReluPattern1,
    SoftmaxPattern,
    MixedW8A32LinearPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
    check_out_zero_point_is_min_range,
Expand Down Expand Up @@ -389,6 +390,27 @@ def get_args_and_kwargs_relu(
    }
    return args, kwargs


def get_args_and_kwargs_mixed_w8a32_linear(
    graph_module: GraphModule,
    other_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    dequants_biases: List[fx.Node],
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    w_scale_ = dequants_weights[0].args[1]
    b_scale_ = dequants_biases[0].args[1]

    args = (
        other_inputs[0],
        weights_inputs[0],
        w_scale_,
        bias_inputs[0],
        b_scale_,
    )
    kwargs = {}

    return args, kwargs


def get_args_and_kwargs_softmax(
    graph_module: GraphModule,
Expand Down Expand Up @@ -617,6 +639,15 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
quant_node,
op_node,
)
elif isinstance(pattern, MixedW8A32LinearPattern):
args, kwargs = get_args_and_kwargs_mixed_w8a32_linear(
graph_module,
other_inputs,
weights_inputs,
dequants_weights,
bias_inputs,
dequants_biases,
)

fused = graph_module.graph.call_function(
pattern.replacement_op(),
43 changes: 42 additions & 1 deletion backends/cadence/aot/quantizer/patterns.py
@@ -524,7 +524,6 @@ def partition_types(self) -> List[OpOverload]:


class SoftmaxPattern(QuantizationPattern):

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten._softmax.default]

@@ -546,3 +545,45 @@ def get_anchors(

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_softmax.default


class MixedW8A32LinearPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.linear.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        # pyre-ignore[29]
        linear_layer = fused_partition[0].nodes[-1]

        # Bail if the node does not have the expected (input, weight, bias)
        # positional args or carries any kwargs.
        if len(linear_layer.args) != 3 or len(linear_layer.kwargs) > 0:
            return (PartitionAnchors(empty=True), linear_layer)

        input_node = linear_layer.args[0]
        input_shape = input_node.meta["tensor_meta"].shape

        # Bail if the contraction dim (in_features) is not a multiple of 4 (SIMD width)
        if input_shape[-1] % 4 != 0:
            return (PartitionAnchors(empty=True), linear_layer)
        # Currently only vector-matrix multiplication is supported, so any
        # second-to-last (batch-like) dim must be 1.
        if len(input_shape) > 1 and input_shape[-2] != 1:
            return (PartitionAnchors(empty=True), linear_layer)

        return (
            PartitionAnchors(
                inputs=[],
                weights=[(linear_layer, 1)],
                biases=[(linear_layer, 2)],
                output=[],
                others=[(linear_layer, 0)],
            ),
            linear_layer,
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_w8a32_linear.default
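For illustration only, a hedged sketch of the kind of graph the pattern above will and will not match; the module and sizes are hypothetical, not taken from this diff:

import torch
from torch import nn

class TinyProjection(nn.Module):
    # A single aten.linear whose in_features (64) is a multiple of 4,
    # as required by the SIMD check in get_anchors.
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(64, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

ok_input = torch.randn(1, 1, 64)   # dim[-2] == 1: vector-matrix, accepted
bad_input = torch.randn(1, 8, 64)  # dim[-2] != 1: pattern bails out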
18 changes: 18 additions & 0 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -28,6 +28,7 @@
    ReluPattern0,
    ReluPattern1,
    SoftmaxPattern,
    MixedW8A32LinearPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
    find_sequential_partitions_aten,
@@ -109,6 +110,13 @@
    None,
)

qconfig_A32W8sym = QuantizationConfig(
    input_activation=None,
    output_activation=None,
    weight=wgt_qspec_sym8s,
    bias=wgt_qspec_sym8s,
)


class CadenceAtenQuantizer(Quantizer):
    def __init__(
@@ -301,6 +309,16 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
        super().__init__(quantizers)


class CadenceW8A32MixedQuantizer(CadenceQuantizer):
    """
    Quantizer for Conversational Focus: mixed W8A32 linear layers
    (int8 symmetric weights, fp32 activations).
    """

    def __init__(self) -> None:
        quantizers = []
        quantizers.append(
            CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym)
        )
        super().__init__(quantizers)
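To show where the new quantizer is meant to sit in a PT2E flow, here is a hedged usage sketch; the model, example inputs, and the torch.ao.quantization.quantize_pt2e helpers are assumptions, and the actual Cadence AoT flow may wrap these steps differently:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceW8A32MixedQuantizer,
)

# Hypothetical model: one linear with in_features divisible by 4.
model = torch.nn.Linear(64, 32).eval()
example_inputs = (torch.randn(1, 1, 64),)

exported = torch.export.export(model, example_inputs).module()
quantizer = CadenceW8A32MixedQuantizer()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)          # calibration pass
converted = convert_pt2e(prepared)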


class CadenceWithSoftmaxQuantizer(CadenceQuantizer):
"""