
Commit 7253fd6

Eashan Garg authored and facebook-github-bot committed
Quantized Softmax Kernel (#14096)
Summary:
Pull Request resolved: #14096

Generic implementation of quantized softmax; the DLA_V130 implementation is a dummy for now.

NOTE: The mask parameter is a no-op.

Reviewed By: mcremon-meta

Differential Revision: D78716203
1 parent d05a793 commit 7253fd6

File tree

4 files changed: 168 additions, 1 deletion

backends/cadence/aot/ops_registrations.py
backends/cadence/aot/quantizer/fusion_pass.py
backends/cadence/aot/quantizer/patterns.py
backends/cadence/aot/quantizer/quantizer.py


backends/cadence/aot/ops_registrations.py

Lines changed: 39 additions & 0 deletions
@@ -300,6 +300,19 @@
     "rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
+)
+lib.define(
+    "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
+)
+lib.define(
+    "quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
+)
+lib.define(
+    "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
+)
+
 # Load/store with iDMA. These only exist before memory planning.
 # Post memory planning, we check that outputs/inputs for the load/store are in
 # DTCM and replace idma_load/idma_store with idma_copy.
@@ -2097,3 +2110,29 @@ def softmax_f32_f32_meta(
     half_to_float: Optional[bool] = None,
 ) -> torch.Tensor:
     return self.new_empty(self.size(), dtype=self.dtype)
+
+
+@register_fake("cadence::quantized_softmax")
+def quantized_softmax_meta(
+    input: torch.Tensor,
+    mask: torch.Tensor,
+    dim: int,
+    in_scale: torch.Tensor,
+    in_zero_point: torch.Tensor,
+    out_scale: torch.Tensor,
+    out_zero_point: torch.Tensor,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_softmax.per_tensor")
+def quantized_softmax_per_tensor_meta(
+    input: torch.Tensor,
+    mask: torch.Tensor,
+    dim: int,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=input.dtype)
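The fake kernels above only fix the output shape and dtype. For intuition, the per_tensor overload's numerics can be sketched as dequantize, float softmax, requantize; the rounding/saturation details below are assumptions and the mask is ignored (the summary notes it is a no-op), so treat this as a reference sketch, not the Cadence kernel:

    import torch

    def quantized_softmax_per_tensor_reference(
        input: torch.Tensor,  # integer activations, e.g. int16 for the A16 config
        mask: torch.Tensor,   # ignored: the mask parameter is currently a no-op
        dim: int,
        in_scale: float,
        in_zero_point: int,
        out_scale: float,
        out_zero_point: int,
    ) -> torch.Tensor:
        # Dequantize to float, run softmax, then requantize with the output qparams.
        x = (input.to(torch.float32) - in_zero_point) * in_scale
        y = torch.softmax(x, dim=dim)
        q = torch.round(y / out_scale) + out_zero_point
        info = torch.iinfo(input.dtype)
        return q.clamp(info.min, info.max).to(input.dtype)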

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 78 additions & 1 deletion
@@ -6,9 +6,10 @@
 
 # pyre-strict
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, cast, Dict, List, Tuple
 
 import torch
+from executorch.backends.cadence.aot.compiler_utils import get_shape
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
     AddPattern,
@@ -21,6 +22,7 @@
     MatmulPattern,
     ReluPattern0,
     ReluPattern1,
+    SoftmaxPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     create_zero_bias_int32,
@@ -376,6 +378,73 @@ def get_args_and_kwargs_relu(
     return args, kwargs
 
 
+def get_args_and_kwargs_softmax(
+    graph_module: GraphModule,
+    inputs_inputs: List[fx.Node],
+    dequants_inputs: List[fx.Node],
+    quant_node: fx.Node,
+    op_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    # Make a dummy mask tensor
+    mask_shape = get_shape(graph_module, cast(fx.Node, quant_node.args[0]))
+    mask_shape = list(mask_shape) if mask_shape else []
+    mask_shape[-1] = mask_shape[-1] // 16
+    mask_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            mask_shape,
+            0.0,
+        ),
+        {"dtype": torch.int32},
+    )
+    # Make the scale and zero_point tensors
+    in_scale_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            dequants_inputs[0].args[1],
+        ),
+        {"dtype": torch.float32},
+    )
+    in_zero_point_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            dequants_inputs[0].args[2],
+        ),
+        {"dtype": torch.int32},
+    )
+    out_scale_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            quant_node.args[1],
+        ),
+        {"dtype": torch.float32},
+    )
+    out_zero_point_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            quant_node.args[2],
+        ),
+        {"dtype": torch.int32},
+    )
+
+    # Make the args and kwargs for the replacement op
+    args = (
+        inputs_inputs[0],
+        mask_tensor,
+        op_node.args[1],
+        in_scale_tensor,
+        in_zero_point_tensor,
+        out_scale_tensor,
+        out_zero_point_tensor,
+    )
+    kwargs = {}
+    return args, kwargs
+
+
 class QuantFusion(ExportPass):
     # pyre-ignore[2]: Parameter `patterns` has no type specified
     def __init__(self, patterns) -> None:
@@ -511,6 +580,14 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         dequants_inputs,
                         quant_node,
                     )
+                elif isinstance(pattern, SoftmaxPattern):
+                    args, kwargs = get_args_and_kwargs_softmax(
+                        graph_module,
+                        inputs_inputs,
+                        dequants_inputs,
+                        quant_node,
+                        op_node,
+                    )
                 fused = graph_module.graph.call_function(
                     pattern.replacement_op(),
                     args,
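To make the dummy-mask construction above concrete: for a hypothetical [2, 4, 64] input, get_args_and_kwargs_softmax emits a zero-filled int32 tensor whose last dimension is shrunk by a factor of 16. The shapes here are illustrative only; the pass builds an aten.full graph node, for which an all-zeros eager tensor is the equivalent:

    import torch

    # Hypothetical quantized-output shape feeding the softmax fusion.
    input_shape = [2, 4, 64]

    # Same arithmetic as the pass: keep all dims, divide the last one by 16.
    mask_shape = list(input_shape)
    mask_shape[-1] = mask_shape[-1] // 16

    # The pass emits aten.full with fill 0.0 and dtype int32; in eager mode
    # that is simply an all-zeros tensor.
    mask = torch.zeros(mask_shape, dtype=torch.int32)
    print(mask.shape)  # torch.Size([2, 4, 4]) -- a placeholder, since the mask is a no-op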

backends/cadence/aot/quantizer/patterns.py

Lines changed: 22 additions & 0 deletions
@@ -417,3 +417,25 @@ def partition_types(self) -> List[OpOverload]:
 class ReluPattern1(ReluBasePattern):
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten.relu_.default]
+
+
+class SoftmaxPattern(QuantizationPattern):
+
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten._softmax.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        softmax_node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(softmax_node, 0)],
+            weights=[],
+            biases=[],
+            output=[(softmax_node,)],
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_softmax.default
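The pattern targets aten._softmax.default, whose schema is (Tensor self, int dim, bool half_to_float) -> Tensor; that is why the fusion pass forwards op_node.args[1] as the dim of the fused op. A quick eager-mode sanity check (values are arbitrary):

    import torch

    x = torch.randn(2, 8)

    # aten._softmax.default takes (self, dim, half_to_float) and matches torch.softmax.
    y = torch.ops.aten._softmax.default(x, -1, False)
    print(torch.allclose(y, torch.softmax(x, dim=-1)))  # True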

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 29 additions & 0 deletions
@@ -23,6 +23,7 @@
     QuantizationPattern,
     ReluPattern0,
     ReluPattern1,
+    SoftmaxPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     find_sequential_partitions_aten,
@@ -54,6 +55,15 @@
     observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
 )
 
+act_qspec_asym16s = QuantizationSpec(
+    dtype=torch.int16,
+    quant_min=-32768,
+    quant_max=32767,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
+)
+
 wgt_qspec_asym8s = QuantizationSpec(
     dtype=torch.int8,
     quant_min=-128,
@@ -88,6 +98,13 @@
     None,
 )
 
+qconfig_A16 = QuantizationConfig(
+    act_qspec_asym16s,
+    act_qspec_asym16s,
+    wgt_qspec_asym8s,
+    None,
+)
+
 
 class CadenceAtenQuantizer(Quantizer):
     def __init__(
@@ -260,3 +277,15 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
         quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
         super().__init__(quantizers)
+
+
+class CadenceWithSoftmaxQuantizer(CadenceQuantizer):
+    """
+    Quantizer including A16 softmax
+    """
+
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = get_cadence_default_quantizers()
+        quantizers.append(CadenceAtenQuantizer(SoftmaxPattern(), qconfig_A16))
+        super().__init__(quantizers)
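For context, a minimal sketch of how CadenceWithSoftmaxQuantizer might be exercised through the generic PT2E prepare/convert flow. The capture entry point used here (torch.export.export_for_training) and the rest of the Cadence lowering, including the QuantFusion rewrite into cadence::quantized_softmax, depend on the installed torch/ExecuTorch versions and the backend's own compiler pipeline, so treat this as an illustrative assumption rather than the canonical recipe:

    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    from executorch.backends.cadence.aot.quantizer.quantizer import (
        CadenceWithSoftmaxQuantizer,
    )


    class TinySoftmax(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.nn.functional.softmax(x, dim=-1)


    model = TinySoftmax().eval()
    example_inputs = (torch.randn(1, 4, 32),)

    # Capture an ATen-level graph, annotate softmax with the A16 activation config,
    # run a calibration forward pass, then convert to a graph with quantize/dequantize
    # nodes around the softmax (which QuantFusion later folds into the fused op).
    gm = torch.export.export_for_training(model, example_inputs).module()
    quantizer = CadenceWithSoftmaxQuantizer()
    prepared = prepare_pt2e(gm, quantizer)
    prepared(*example_inputs)
    converted = convert_pt2e(prepared)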
