Commit 684b5fd

Quantized Softmax Kernel
Differential Revision: D82596569
Pull Request resolved: #14518
1 parent dcc3978 commit 684b5fd

File tree: 4 files changed (+168, -1 lines)
backends/cadence/aot/ops_registrations.py

Lines changed: 39 additions & 0 deletions
@@ -390,6 +390,19 @@
     "rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
+)
+lib.define(
+    "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
+)
+lib.define(
+    "quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
+)
+lib.define(
+    "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
+)
+
 # Load/store with iDMA. These only exist before memory planning.
 # Post memory planning, we check that outputs/inputs for the load/store are in
 # DTCM and replace idma_load/idma_store with idma_copy.

@@ -2523,3 +2536,29 @@ def softmax_f32_f32_meta(
     half_to_float: Optional[bool] = None,
 ) -> torch.Tensor:
     return self.new_empty(self.size(), dtype=self.dtype)
+
+
+@register_fake("cadence::quantized_softmax")
+def quantized_softmax_meta(
+    input: torch.Tensor,
+    mask: torch.Tensor,
+    dim: int,
+    in_scale: torch.Tensor,
+    in_zero_point: torch.Tensor,
+    out_scale: torch.Tensor,
+    out_zero_point: torch.Tensor,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_softmax.per_tensor")
+def quantized_softmax_per_tensor_meta(
+    input: torch.Tensor,
+    mask: torch.Tensor,
+    dim: int,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=input.dtype)
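The two @register_fake kernels only allocate an output with the input's shape and dtype, which is all export tracing needs. A hedged sketch of exercising the per_tensor overload under FakeTensorMode follows; it assumes the Cadence op library above has been imported, and the scale/zero-point values are made up.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(1, 8, 128, dtype=torch.int16)   # fake quantized activations
    mask = torch.empty(1, 8, 8, dtype=torch.int32)  # dummy mask (last dim // 16)
    out = torch.ops.cadence.quantized_softmax.per_tensor(
        x, mask, -1, 0.05, 0, 1.0 / 65536.0, -32768
    )
    # The fake kernel returns input.new_empty(input.size(), dtype=input.dtype),
    # so the output mirrors the input's shape and dtype.
    assert out.shape == x.shape and out.dtype == x.dtype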

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 78 additions & 1 deletion
@@ -6,9 +6,10 @@
 
 # pyre-strict
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, cast, Dict, List, Tuple
 
 import torch
+from executorch.backends.cadence.aot.compiler_utils import get_shape
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
     AddPattern,

@@ -25,6 +26,7 @@
     MatmulPattern,
     ReluPattern0,
     ReluPattern1,
+    SoftmaxPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     check_out_zero_point_is_min_range,

@@ -388,6 +390,73 @@ def get_args_and_kwargs_relu(
     return args, kwargs
 
 
+def get_args_and_kwargs_softmax(
+    graph_module: GraphModule,
+    inputs_inputs: List[fx.Node],
+    dequants_inputs: List[fx.Node],
+    quant_node: fx.Node,
+    op_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    # Make a dummy mask tensor
+    mask_shape = get_shape(graph_module, cast(fx.Node, quant_node.args[0]))
+    mask_shape = list(mask_shape) if mask_shape else []
+    mask_shape[-1] = mask_shape[-1] // 16
+    mask_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            mask_shape,
+            0.0,
+        ),
+        {"dtype": torch.int32},
+    )
+    # Make the scale and zero_point tensors
+    in_scale_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            dequants_inputs[0].args[1],
+        ),
+        {"dtype": torch.float32},
+    )
+    in_zero_point_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            dequants_inputs[0].args[2],
+        ),
+        {"dtype": torch.int32},
+    )
+    out_scale_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            quant_node.args[1],
+        ),
+        {"dtype": torch.float32},
+    )
+    out_zero_point_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            quant_node.args[2],
+        ),
+        {"dtype": torch.int32},
+    )
+
+    # Make the args and kwargs for the replacement op
+    args = (
+        inputs_inputs[0],
+        mask_tensor,
+        op_node.args[1],
+        in_scale_tensor,
+        in_zero_point_tensor,
+        out_scale_tensor,
+        out_zero_point_tensor,
+    )
+    kwargs = {}
+    return args, kwargs
+
+
 class QuantFusion(ExportPass):
     # pyre-ignore[2]: Parameter `patterns` has no type specified
     def __init__(self, patterns) -> None:

@@ -543,6 +612,14 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         dequants_inputs,
                         quant_node,
                     )
+                elif isinstance(pattern, SoftmaxPattern):
+                    args, kwargs = get_args_and_kwargs_softmax(
+                        graph_module,
+                        inputs_inputs,
+                        dequants_inputs,
+                        quant_node,
+                        anchor_output_node,
+                    )
                fused = graph_module.graph.call_function(
                    pattern.replacement_op(),
                    args,
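For reference, the helper packs the quantization parameters into rank-1 tensors built with aten.full and passes them positionally to cadence::quantized_softmax. A small eager-mode sketch of what those arguments look like follows; the values are illustrative, whereas in the pass they come from the dequantize and quantize nodes' args.

import torch

# Assumed example values: in the pass, in_scale/in_zero_point come from
# dequants_inputs[0].args[1:3] and out_scale/out_zero_point from quant_node.args[1:3].
in_scale, in_zero_point = 0.05, 0
out_scale, out_zero_point = 1.0 / 65536.0, -32768

seq_len = 128
mask = torch.zeros(1, seq_len // 16, dtype=torch.int32)              # dummy mask
in_scale_t = torch.full([1], in_scale, dtype=torch.float32)
in_zero_point_t = torch.full([1], in_zero_point, dtype=torch.int32)
out_scale_t = torch.full([1], out_scale, dtype=torch.float32)
out_zero_point_t = torch.full([1], out_zero_point, dtype=torch.int32)

# Positional order matches the quantized_softmax schema:
# (input, mask, dim, in_scale, in_zero_point, out_scale, out_zero_point)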

backends/cadence/aot/quantizer/patterns.py

Lines changed: 22 additions & 0 deletions
@@ -485,3 +485,25 @@ def partition_types(self) -> List[OpOverload]:
 class Conv2dReluPattern1(ConvReluBasePattern):
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default]
+
+
+class SoftmaxPattern(QuantizationPattern):
+
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten._softmax.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        softmax_node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(softmax_node, 0)],
+            weights=[],
+            biases=[],
+            output=[(softmax_node,)],
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_softmax.default
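As a usage sketch: a model whose exported graph contains the targeted softmax op yields a single-node partition, and that node is both the input anchor (argument 0) and the output anchor, so activation observers land on either side of it. The hypothetical module below illustrates this; whether the call surfaces as aten._softmax.default depends on the export/decomposition path the Cadence flow uses.

import torch

class ScoresToProbs(torch.nn.Module):  # hypothetical example module
    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        # This softmax is what SoftmaxPattern is meant to partition on.
        return torch.nn.functional.softmax(scores, dim=-1)

example_inputs = (torch.randn(1, 8, 128, 128),)
exported = torch.export.export(ScoresToProbs(), example_inputs)
print(exported.graph_module.graph)  # inspect which softmax overload appears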

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 29 additions & 0 deletions
@@ -27,6 +27,7 @@
     QuantizationPattern,
     ReluPattern0,
     ReluPattern1,
+    SoftmaxPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     find_sequential_partitions_aten,

@@ -58,6 +59,15 @@
     observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
 )
 
+act_qspec_asym16s = QuantizationSpec(
+    dtype=torch.int16,
+    quant_min=-32768,
+    quant_max=32767,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
+)
+
 wgt_qspec_asym8s = QuantizationSpec(
     dtype=torch.int8,
     quant_min=-128,

@@ -92,6 +102,13 @@
     None,
 )
 
+qconfig_A16 = QuantizationConfig(
+    act_qspec_asym16s,
+    act_qspec_asym16s,
+    wgt_qspec_asym8s,
+    None,
+)
+
 
 class CadenceAtenQuantizer(Quantizer):
     def __init__(

@@ -283,3 +300,15 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
         quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
         super().__init__(quantizers)
+
+
+class CadenceWithSoftmaxQuantizer(CadenceQuantizer):
+    """
+    Quantizer including A16 softmax
+    """
+
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = get_cadence_default_quantizers()
+        quantizers.append(CadenceAtenQuantizer(SoftmaxPattern(), qconfig_A16))
+        super().__init__(quantizers)
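A hypothetical end-to-end sketch of using the new quantizer in the standard PT2E flow follows; the model name and shapes are made up, and the Cadence flow's actual entry points may differ.

import torch
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceWithSoftmaxQuantizer,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = MyModel().eval()                    # assumed user model containing a softmax
example_inputs = (torch.randn(1, 8, 128),)  # illustrative shape

exported = torch.export.export_for_training(model, example_inputs).module()
quantizer = CadenceWithSoftmaxQuantizer()   # default A8W8 patterns + A16 softmax
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)                   # calibration pass
converted = convert_pt2e(prepared)          # softmax now carries int16 q/dq pairs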
