
Commit a0add7f

Marco Giordano authored and facebook-github-bot committed
Including mixed quant Linear op in Jarvis (#14820)
Summary: This diff includes a general and a HiFi4-optimized Linear operator. Specifically, it adds both a standard Linear implementation and a version optimized for HiFi4 DSPs, ensuring better performance on supported hardware.

Reviewed By: mcremon-meta, skrtskrtfb

Differential Revision: D81605171
1 parent 740fe14 commit a0add7f
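To frame the change: "w8a32" denotes 8-bit quantized weights with 32-bit float activations, and the op schema added below takes the fp32 input, an int8 weight with a per-tensor w_scale, and an int8 bias with a per-tensor b_scale. A hedged fp32 reference sketch of that contract, with hypothetical shapes and scales (the actual numerics come from the generic and HiFi4 kernels, which are not shown in this diff):

import torch

def w8a32_linear_ref(x, w_int8, w_scale, b_int8, b_scale):
    # Dequantize the symmetric int8 weight and bias, then run a plain fp32 linear.
    w = w_int8.to(torch.float32) * w_scale
    b = b_int8.to(torch.float32) * b_scale
    return torch.nn.functional.linear(x, w, b)

x = torch.randn(1, 64)                                          # fp32 activations
w_int8 = torch.randint(-128, 128, (32, 64), dtype=torch.int8)   # [out_dim, in_dim]
b_int8 = torch.randint(-128, 128, (32,), dtype=torch.int8)
print(w8a32_linear_ref(x, w_int8, 0.02, b_int8, 0.001).shape)   # torch.Size([1, 32])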

File tree

5 files changed: +140 −1 lines changed

backends/cadence/aot/functions_hifi.yaml
backends/cadence/aot/ops_registrations.py
backends/cadence/aot/quantizer/fusion_pass.py
backends/cadence/aot/quantizer/patterns.py
backends/cadence/aot/quantizer/quantizer.py

backends/cadence/aot/functions_hifi.yaml

Lines changed: 5 additions & 0 deletions
@@ -548,3 +548,8 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::HiFi::quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out
+
+- func: cadence::quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantized_w8a32_linear_out

backends/cadence/aot/ops_registrations.py

Lines changed: 27 additions & 0 deletions
@@ -564,6 +564,14 @@
     "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float, *, Tensor(a!) out) -> Tensor(a!)"
 )

+lib.define(
+    "quantized_w8a32_linear(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale) -> Tensor"
+)
+lib.define(
+    "quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
+)
+
+
 # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
 aten_lib = Library("aten", "FRAGMENT")
 aten_lib.define(
@@ -2562,3 +2570,22 @@ def quantized_softmax_per_tensor_meta(
     out_zero_point: int,
 ) -> torch.Tensor:
     return input.new_empty(input.size(), dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_w8a32_linear")
+def quantized_w8a32_linear_meta(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    w_scale: float,
+    bias: torch.Tensor,
+    b_scale: float,
+) -> torch.Tensor:
+    # src comes in shape [leading_dims, in_dim]
+    # weight comes in shape [in_dim, out_dim]
+    # output comes in empty with shape [leading_dims, out_dim]
+    src_shape = list(src.shape)
+    weight_shape = weight.shape
+    assert len(weight_shape) == 2
+    assert src_shape[-1] == weight_shape[-1]
+    src_shape[-1] = weight_shape[0]
+    return src.new_empty(src_shape, dtype=src.dtype)
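The registered fake kernel only propagates shapes: the output keeps the input's leading dimensions and takes its last dimension from weight.shape[0]; note the asserts imply a weight laid out [out_dim, in_dim] (the standard torch.nn.Linear layout), despite the [in_dim, out_dim] wording in the inline comment. A small sketch of that shape math with hypothetical sizes:

import torch

src = torch.randn(1, 64)                                        # [leading_dims, in_dim]
weight = torch.randint(-128, 128, (32, 64), dtype=torch.int8)   # [out_dim, in_dim]

src_shape = list(src.shape)
assert len(weight.shape) == 2
assert src_shape[-1] == weight.shape[-1]   # in_dim must line up
src_shape[-1] = weight.shape[0]            # output takes out_dim
out = src.new_empty(src_shape, dtype=src.dtype)
print(out.shape)                           # torch.Size([1, 32])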

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 33 additions & 0 deletions
@@ -24,6 +24,7 @@
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
+    MixedW8A32LinearPattern,
     ReluPattern0,
     ReluPattern1,
     SoftmaxPattern,
@@ -390,6 +391,29 @@ def get_args_and_kwargs_relu(
     return args, kwargs


+def get_args_and_kwargs_mixed_w8a32_linear(
+    graph_module: GraphModule,
+    other_inputs: List[fx.Node],
+    weights_inputs: List[fx.Node],
+    dequants_weights: List[fx.Node],
+    bias_inputs: List[fx.Node],
+    dequants_biases: List[fx.Node],
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    w_scale_ = dequants_weights[0].args[1]
+    b_scale_ = dequants_biases[0].args[1]
+
+    args = (
+        other_inputs[0],
+        weights_inputs[0],
+        w_scale_,
+        bias_inputs[0],
+        b_scale_,
+    )
+    kwargs = {}
+
+    return args, kwargs
+
+
 def get_args_and_kwargs_softmax(
     graph_module: GraphModule,
     inputs_inputs: List[fx.Node],
@@ -617,6 +641,15 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                     quant_node,
                     op_node,
                 )
+            elif isinstance(pattern, MixedW8A32LinearPattern):
+                args, kwargs = get_args_and_kwargs_mixed_w8a32_linear(
+                    graph_module,
+                    other_inputs,
+                    weights_inputs,
+                    dequants_weights,
+                    bias_inputs,
+                    dequants_biases,
+                )

             fused = graph_module.graph.call_function(
                 pattern.replacement_op(),
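The new branch reads the per-tensor scales straight off the matched dequantize nodes (the scale is the second positional argument of quantized_decomposed.dequantize_per_tensor) and packs the op arguments in schema order: input, weight, w_scale, bias, b_scale. A toy illustration with plain values standing in for the fx.Node objects:

import torch

# Stand-ins for the dequantize nodes' args:
# (input, scale, zero_point, quant_min, quant_max, dtype)
dequant_weight_args = ("w_int8_node", 0.02, 0, -128, 127, torch.int8)
dequant_bias_args = ("b_int8_node", 0.001, 0, -128, 127, torch.int8)

w_scale = dequant_weight_args[1]   # what get_args_and_kwargs_mixed_w8a32_linear extracts
b_scale = dequant_bias_args[1]

# Argument tuple handed to cadence::quantized_w8a32_linear:
args = ("activation_node", "w_int8_node", w_scale, "b_int8_node", b_scale)
print(args)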

backends/cadence/aot/quantizer/patterns.py

Lines changed: 54 additions & 1 deletion
@@ -524,7 +524,6 @@ def partition_types(self) -> List[OpOverload]:


 class SoftmaxPattern(QuantizationPattern):
-
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten._softmax.default]

@@ -546,3 +545,57 @@ def get_anchors(

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_softmax.default
+
+
+class MixedW8A32LinearPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.linear.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Tuple[PartitionAnchors, fx.Node]:
+        # pyre-ignore[29]
+        linear_layer = fused_partition[0].nodes[-1]
+
+        # Bail if the arguments have different shapes than expected
+        if len(linear_layer.args) != 3 or len(linear_layer.kwargs) > 0:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                linear_layer,
+            )
+
+        input_node = linear_layer.args[0]
+        input_shape = input_node.meta["tensor_meta"].shape
+
+        # Bail if the weights are not multiple of 4 (SIMD)
+        if input_shape[-1] % 4 != 0:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                linear_layer,
+            )
+        # Currenly only supporting vector-matrix multiplication
+        if len(input_shape) > 0 and input_shape[-2] != 1:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                linear_layer,
+            )
+
+        return (
+            PartitionAnchors(
+                inputs=[],
+                weights=[(linear_layer, 1)],
+                biases=[(linear_layer, 2)],
+                output=[],
+                others=[(linear_layer, 0)],
+            ),
+            linear_layer,
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_w8a32_linear.default
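MixedW8A32LinearPattern only matches an aten.linear call with exactly (input, weight, bias) and no kwargs, an activation whose innermost dimension is a multiple of 4 (SIMD width), and an effectively vector-matrix product (second-to-last input dimension of 1); anything else returns empty anchors and stays unfused. A quick sketch of that shape gating on hypothetical inputs (the length guard is adjusted slightly so this standalone snippet never indexes out of range):

def matches_w8a32_shape(input_shape) -> bool:
    # Mirrors the bail-outs in MixedW8A32LinearPattern.get_anchors.
    if input_shape[-1] % 4 != 0:                        # SIMD-friendly in_dim only
        return False
    if len(input_shape) > 1 and input_shape[-2] != 1:   # vector-matrix only
        return False
    return True

print(matches_w8a32_shape((1, 64)))   # True: 1x64 row vector, 64 % 4 == 0
print(matches_w8a32_shape((1, 30)))   # False: 30 is not a multiple of 4
print(matches_w8a32_shape((8, 64)))   # False: batched matrix input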

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 21 additions & 0 deletions
@@ -24,6 +24,7 @@
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
+    MixedW8A32LinearPattern,
     QuantizationPattern,
     ReluPattern0,
     ReluPattern1,
@@ -109,6 +110,13 @@
     None,
 )

+qconfig_A32W8sym = QuantizationConfig(
+    input_activation=None,
+    output_activation=None,
+    weight=wgt_qspec_sym8s,
+    bias=wgt_qspec_sym8s,
+)
+

 class CadenceAtenQuantizer(Quantizer):
     def __init__(
@@ -302,6 +310,19 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         super().__init__(quantizers)


+class CadenceW8A32MixedQuantizer(CadenceQuantizer):
+    """
+    Quantizer for Conversational Focus
+    """
+
+    def __init__(self) -> None:
+        quantizers = []
+        quantizers.append(
+            CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym)
+        )
+        super().__init__(quantizers)
+
+
 class CadenceWithSoftmaxQuantizer(CadenceQuantizer):
     """
     Quantizer including A16 softmax
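A hedged sketch of how the new quantizer might slot into the standard PT2E flow; the import path, toy module, shapes, and the export/prepare/convert calls are assumptions based on the usual torch.ao.quantization.quantize_pt2e workflow rather than anything in this diff. After convert_pt2e, the Cadence fusion pass above rewrites the matched aten.linear subgraph into cadence::quantized_w8a32_linear.

import torch
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceW8A32MixedQuantizer,  # added in this diff; import path assumed
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Hypothetical vector-matrix model: in_dim is a multiple of 4, a single row per call.
model = torch.nn.Linear(64, 32).eval()
example_inputs = (torch.randn(1, 64),)

exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, CadenceW8A32MixedQuantizer())
prepared(*example_inputs)          # single pass; only weights/bias are quantized here
converted = convert_pt2e(prepared)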
