diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml
index bcab980abd6..8c65e745c21 100644
--- a/backends/cadence/aot/functions_hifi.yaml
+++ b/backends/cadence/aot/functions_hifi.yaml
@@ -548,3 +548,8 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::HiFi::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out
+
+- func: cadence::quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantized_w8a32_linear_out
diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index f7d07018e59..9266cc72970 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -564,6 +564,14 @@
     "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_w8a32_linear(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale) -> Tensor"
+)
+lib.define(
+    "quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
+)
+
+
 # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
 aten_lib = Library("aten", "FRAGMENT")
 aten_lib.define(
@@ -2562,3 +2570,22 @@ def quantized_softmax_per_tensor_meta(
     out_zero_point: int,
 ) -> torch.Tensor:
     return input.new_empty(input.size(), dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_w8a32_linear")
+def quantized_w8a32_linear_meta(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    w_scale: float,
+    bias: torch.Tensor,
+    b_scale: float,
+) -> torch.Tensor:
+    # src comes in shape [leading_dims, in_dim]
+    # weight comes in shape [out_dim, in_dim]
+    # output is a new empty tensor with shape [leading_dims, out_dim]
+    src_shape = list(src.shape)
+    weight_shape = weight.shape
+    assert len(weight_shape) == 2
+    assert src_shape[-1] == weight_shape[-1]
+    src_shape[-1] = weight_shape[0]
+    return src.new_empty(src_shape, dtype=src.dtype)
diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index 0461c03ccb7..9fc09cb31cd 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -27,6 +27,7 @@
     ReluPattern0,
     ReluPattern1,
     SoftmaxPattern,
+    MixedW8A32LinearPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     check_out_zero_point_is_min_range,
@@ -389,6 +390,27 @@ def get_args_and_kwargs_relu(
     }
     return args, kwargs
 
+def get_args_and_kwargs_mixed_w8a32_linear(
+    graph_module: GraphModule,
+    other_inputs: List[fx.Node],
+    weights_inputs: List[fx.Node],
+    dequants_weights: List[fx.Node],
+    bias_inputs: List[fx.Node],
+    dequants_biases: List[fx.Node],
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    w_scale_ = dequants_weights[0].args[1]
+    b_scale_ = dequants_biases[0].args[1]
+
+    args = (
+        other_inputs[0],
+        weights_inputs[0],
+        w_scale_,
+        bias_inputs[0],
+        b_scale_,
+    )
+    kwargs = {}
+
+    return args, kwargs
 
 def get_args_and_kwargs_softmax(
     graph_module: GraphModule,
@@ -617,6 +639,15 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         quant_node,
                         op_node,
                     )
+                elif isinstance(pattern, MixedW8A32LinearPattern):
+                    args, kwargs = get_args_and_kwargs_mixed_w8a32_linear(
+                        graph_module,
+                        other_inputs,
+                        weights_inputs,
+                        dequants_weights,
+                        bias_inputs,
+                        dequants_biases,
+                    )
 
                 fused = graph_module.graph.call_function(
                     pattern.replacement_op(),
diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index 4eae55502d7..df3ccbb033d 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -524,7 +524,6 @@ def partition_types(self) -> List[OpOverload]:
 
 
 class SoftmaxPattern(QuantizationPattern):
-
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten._softmax.default]
 
@@ -546,3 +545,45 @@ def get_anchors(
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_softmax.default
+
+
+class MixedW8A32LinearPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.linear.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Tuple[PartitionAnchors, fx.Node]:
+        # pyre-ignore[29]
+        linear_layer = fused_partition[0].nodes[-1]
+
+        # Bail if the node does not have exactly (input, weight, bias) as positional args
+        if len(linear_layer.args) != 3 or len(linear_layer.kwargs) > 0:
+            return (PartitionAnchors(
+                empty=True,
+            ), linear_layer)
+
+        input_node = linear_layer.args[0]
+        input_shape = input_node.meta["tensor_meta"].shape
+
+        # Bail if the contraction dim (in_features) is not a multiple of 4 (SIMD)
+        if input_shape[-1] % 4 != 0:
+            return (PartitionAnchors(
+                empty=True,
+            ), linear_layer)
+        # Currently only vector-matrix multiplication is supported
+        if len(input_shape) > 1 and input_shape[-2] != 1:
+            return (PartitionAnchors(
+                empty=True,
+            ), linear_layer)
+
+        return (PartitionAnchors(
+            inputs=[],
+            weights=[(linear_layer, 1)],
+            biases=[(linear_layer, 2)],
+            output=[],
+            others=[(linear_layer, 0)],
+        ), linear_layer)
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_w8a32_linear.default
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index 536b28f5cec..39503c7a0dc 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -28,6 +28,7 @@
     ReluPattern0,
     ReluPattern1,
     SoftmaxPattern,
+    MixedW8A32LinearPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     find_sequential_partitions_aten,
@@ -109,6 +110,13 @@
     None,
 )
 
+qconfig_A32W8sym = QuantizationConfig(
+    input_activation=None,
+    output_activation=None,
+    weight=wgt_qspec_sym8s,
+    bias=wgt_qspec_sym8s,
+)
+
 
 class CadenceAtenQuantizer(Quantizer):
     def __init__(
@@ -301,6 +309,16 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
         super().__init__(quantizers)
 
+class CadenceW8A32MixedQuantizer(CadenceQuantizer):
+    """
+    Quantizer with a conversational-model focus: mixed W8A32 (int8 weights, fp32 activations) linear.
+    """
+
+    def __init__(self) -> None:
+        quantizers = []
+        quantizers.append(CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym))
+        super().__init__(quantizers)
+
 
 class CadenceWithSoftmaxQuantizer(CadenceQuantizer):
     """
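
Usage sketch (not part of the diff): the snippet below shows how the new CadenceW8A32MixedQuantizer could be exercised on a toy model through the standard PT2E prepare/convert flow. The toy module, tensor shapes, and the export/prepare entry points are illustrative assumptions; the exact import paths for prepare_pt2e/convert_pt2e may differ across PyTorch/ExecuTorch versions.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceW8A32MixedQuantizer,
)


class TinyLinear(torch.nn.Module):
    """Hypothetical toy model; only the linear layer matters for the pattern."""

    def __init__(self) -> None:
        super().__init__()
        # in_features = 16 is a multiple of 4, so the SIMD check in
        # MixedW8A32LinearPattern.get_anchors() does not bail out.
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyLinear().eval()
# A leading dim of 1 keeps this a vector-matrix multiply, matching the
# pattern's restriction that input_shape[-2] == 1.
example_inputs = (torch.randn(1, 16),)

# Export to an ATen-level graph that still contains aten.linear.default.
exported = torch.export.export_for_training(model, example_inputs).module()

quantizer = CadenceW8A32MixedQuantizer()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration; weights-only, so one pass suffices
converted = convert_pt2e(prepared)
# The Cadence quantizer fusion pass (patched above) then replaces the matched
# dequantize -> linear subgraph with cadence::quantized_w8a32_linear.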