
Commit 2cbfca7

Marco Giordano authored and facebook-github-bot committed
Including mixed quant Linear op in Jarvis (#14820)
Summary: This diff adds a mixed-quantization (int8 weight, fp32 activation) Linear operator: both a standard implementation and a version optimized for HiFi4 DSPs, ensuring better performance on supported hardware.

Reviewed By: mcremon-meta, skrtskrtfb

Differential Revision: D81605171
1 parent 270873f commit 2cbfca7

File tree

5 files changed: +123 -1 lines changed

backends/cadence/aot/functions_hifi.yaml

Lines changed: 5 additions & 0 deletions
@@ -548,3 +548,8 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::HiFi::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out
+
+- func: cadence::quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantized_w8a32_linear_out
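
The new cadence::quantized_w8a32_linear.out entry binds the op to the HiFi kernel. Going by the w8a32 naming and the schema, the op takes fp32 activations, an int8 weight with a per-tensor scale, and an int8 bias with its own scale. A minimal reference sketch of that arithmetic in plain PyTorch, assuming symmetric quantization with zero point 0 (the optimized HiFi4 kernel's exact math may differ):

import torch

def quantized_w8a32_linear_ref(
    src: torch.Tensor,      # fp32 activations, shape [..., in_dim]
    weight: torch.Tensor,   # int8 weights, shape [out_dim, in_dim]
    w_scale: float,         # per-tensor weight scale
    bias: torch.Tensor,     # int8 bias, shape [out_dim]
    b_scale: float,         # per-tensor bias scale
) -> torch.Tensor:
    # Dequantize the symmetric int8 weight and bias, then apply a regular linear.
    w_fp32 = weight.to(torch.float32) * w_scale
    b_fp32 = bias.to(torch.float32) * b_scale
    return torch.nn.functional.linear(src, w_fp32, b_fp32)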

backends/cadence/aot/ops_registrations.py

Lines changed: 27 additions & 0 deletions
@@ -564,6 +564,14 @@
     "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_w8a32_linear(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale) -> Tensor"
+)
+lib.define(
+    "quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
+)
+
+
 # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
 aten_lib = Library("aten", "FRAGMENT")
 aten_lib.define(

@@ -2562,3 +2570,22 @@ def quantized_softmax_per_tensor_meta(
     out_zero_point: int,
 ) -> torch.Tensor:
     return input.new_empty(input.size(), dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_w8a32_linear")
+def quantized_w8a32_linear_meta(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    w_scale: float,
+    bias: torch.Tensor,
+    b_scale: float,
+) -> torch.Tensor:
+    # src comes in shape [leading_dims, in_dim]
+    # weight comes in shape [out_dim, in_dim]
+    # output comes out empty with shape [leading_dims, out_dim]
+    src_shape = list(src.shape)
+    weight_shape = weight.shape
+    assert len(weight_shape) == 2
+    assert src_shape[-1] == weight_shape[-1]
+    src_shape[-1] = weight_shape[0]
+    return src.new_empty(src_shape, dtype=src.dtype)
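
The registered fake kernel above only propagates shapes: it returns an empty output that keeps the leading dims of src and uses the weight's out_dim as the last dim. A quick sanity-check sketch, assuming the Cadence AOT ops registrations are importable (shapes and scales below are illustrative):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Importing this module registers torch.ops.cadence.quantized_w8a32_linear
# along with the fake kernel shown in the diff above.
import executorch.backends.cadence.aot.ops_registrations  # noqa: F401

with FakeTensorMode():
    src = torch.randn(1, 64)                        # fp32 activations [1, 64]
    weight = torch.zeros(32, 64, dtype=torch.int8)  # int8 weights [out_dim, in_dim]
    bias = torch.zeros(32, dtype=torch.int8)        # int8 bias [out_dim]
    out = torch.ops.cadence.quantized_w8a32_linear(src, weight, 0.1, bias, 0.1)
    assert list(out.shape) == [1, 32]  # leading dims kept, last dim becomes out_dim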

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 31 additions & 0 deletions
@@ -27,6 +27,7 @@
     ReluPattern0,
     ReluPattern1,
     SoftmaxPattern,
+    MixedW8A32LinearPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     check_out_zero_point_is_min_range,

@@ -389,6 +390,27 @@ def get_args_and_kwargs_relu(
     }
     return args, kwargs
 
+def get_args_and_kwargs_mixed_w8a32_linear(
+    graph_module: GraphModule,
+    other_inputs: List[fx.Node],
+    weights_inputs: List[fx.Node],
+    dequants_weights: List[fx.Node],
+    bias_inputs: List[fx.Node],
+    dequants_biases: List[fx.Node],
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    w_scale_ = dequants_weights[0].args[1]
+    b_scale_ = dequants_biases[0].args[1]
+
+    args = (
+        other_inputs[0],
+        weights_inputs[0],
+        w_scale_,
+        bias_inputs[0],
+        b_scale_,
+    )
+    kwargs = {}
+
+    return args, kwargs
 
 def get_args_and_kwargs_softmax(
     graph_module: GraphModule,

@@ -617,6 +639,15 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     quant_node,
                     op_node,
                 )
+            elif isinstance(pattern, MixedW8A32LinearPattern):
+                args, kwargs = get_args_and_kwargs_mixed_w8a32_linear(
+                    graph_module,
+                    other_inputs,
+                    weights_inputs,
+                    dequants_weights,
+                    bias_inputs,
+                    dequants_biases,
+                )
 
             fused = graph_module.graph.call_function(
                 pattern.replacement_op(),

backends/cadence/aot/quantizer/patterns.py

Lines changed: 42 additions & 1 deletion
@@ -524,7 +524,6 @@ def partition_types(self) -> List[OpOverload]:
 
 
 class SoftmaxPattern(QuantizationPattern):
-
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten._softmax.default]
 

@@ -546,3 +545,45 @@ def get_anchors(
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_softmax.default
+
+
+class MixedW8A32LinearPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.linear.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Tuple[PartitionAnchors, fx.Node]:
+        # pyre-ignore[29]
+        linear_layer = fused_partition[0].nodes[-1]
+
+        # Bail if the arguments have different shapes than expected
+        if len(linear_layer.args) != 3 or len(linear_layer.kwargs) > 0:
+            return (PartitionAnchors(
+                empty=True,
+            ), linear_layer)
+
+        input_node = linear_layer.args[0]
+        input_shape = input_node.meta['tensor_meta'].shape
+
+        # Bail if the weights are not a multiple of 4 (SIMD)
+        if input_shape[-1] % 4 != 0:
+            return (PartitionAnchors(
+                empty=True,
+            ), linear_layer)
+        # Currently only supporting vector-matrix multiplication
+        if len(input_shape) > 0 and input_shape[-2] != 1:
+            return (PartitionAnchors(
+                empty=True,
+            ), linear_layer)
+
+        return (PartitionAnchors(
+            inputs=[],
+            weights=[(linear_layer, 1)],
+            biases=[(linear_layer, 2)],
+            output=[],
+            others=[(linear_layer, 0)],
+        ), linear_layer)
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_w8a32_linear.default
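
In effect, the pattern only claims bias-carrying aten.linear calls whose input feature count is divisible by 4 and whose activation is effectively a single row (vector-matrix multiply). A standalone sketch of the same eligibility test, using a hypothetical helper name and illustrative shapes:

import torch

def is_w8a32_linear_candidate(x: torch.Tensor, linear: torch.nn.Linear) -> bool:
    # Mirrors the bail-out conditions in MixedW8A32LinearPattern.get_anchors.
    if linear.bias is None:                 # pattern expects an explicit bias argument
        return False
    if x.shape[-1] % 4 != 0:                # input features must be a multiple of 4 (SIMD)
        return False
    if x.dim() >= 2 and x.shape[-2] != 1:   # only vector-matrix multiplication
        return False
    return True

assert is_w8a32_linear_candidate(torch.randn(1, 64), torch.nn.Linear(64, 32))
assert not is_w8a32_linear_candidate(torch.randn(8, 64), torch.nn.Linear(64, 32))
assert not is_w8a32_linear_candidate(torch.randn(1, 30), torch.nn.Linear(30, 16))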

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 18 additions & 0 deletions
@@ -28,6 +28,7 @@
     ReluPattern0,
     ReluPattern1,
     SoftmaxPattern,
+    MixedW8A32LinearPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     find_sequential_partitions_aten,

@@ -109,6 +110,13 @@
     None,
 )
 
+qconfig_A32W8sym = QuantizationConfig(
+    input_activation=None,
+    output_activation=None,
+    weight=wgt_qspec_sym8s,
+    bias=wgt_qspec_sym8s,
+)
+
 
 class CadenceAtenQuantizer(Quantizer):
     def __init__(

@@ -301,6 +309,16 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
         super().__init__(quantizers)
 
+class CadenceW8A32MixedQuantizer(CadenceQuantizer):
+    """
+    Quantizer for Conversational Focus
+    """
+
+    def __init__(self) -> None:
+        quantizers = []
+        quantizers.append(CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym))
+        super().__init__(quantizers)
+
 
 class CadenceWithSoftmaxQuantizer(CadenceQuantizer):
     """
