
Commit 56659e4

Revert "Quantized Softmax Kernel" (#14364)

This reverts commit 94f62b7. It has not landed internally and is failing internal tests ([D82596569](https://www.internalfb.com/diff/D82596569)), which would require a fix-up patch.

1 parent facf35d

File tree

4 files changed: +1 -168 lines

backends/cadence/aot/ops_registrations.py

Lines changed: 0 additions & 39 deletions
@@ -324,19 +324,6 @@
     "rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
-lib.define(
-    "quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
-)
-lib.define(
-    "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
-)
-lib.define(
-    "quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
-)
-lib.define(
-    "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
-)
-
 # Load/store with iDMA. These only exist before memory planning.
 # Post memory planning, we check that outputs/inputs for the load/store are in
 # DTCM and replace idma_load/idma_store with idma_copy.
@@ -2342,29 +2329,3 @@ def softmax_f32_f32_meta(
     half_to_float: Optional[bool] = None,
 ) -> torch.Tensor:
     return self.new_empty(self.size(), dtype=self.dtype)
-
-
-@register_fake("cadence::quantized_softmax")
-def quantized_softmax_meta(
-    input: torch.Tensor,
-    mask: torch.Tensor,
-    dim: int,
-    in_scale: torch.Tensor,
-    in_zero_point: torch.Tensor,
-    out_scale: torch.Tensor,
-    out_zero_point: torch.Tensor,
-) -> torch.Tensor:
-    return input.new_empty(input.size(), dtype=input.dtype)
-
-
-@register_fake("cadence::quantized_softmax.per_tensor")
-def quantized_softmax_per_tensor_meta(
-    input: torch.Tensor,
-    mask: torch.Tensor,
-    dim: int,
-    in_scale: float,
-    in_zero_point: int,
-    out_scale: float,
-    out_zero_point: int,
-) -> torch.Tensor:
-    return input.new_empty(input.size(), dtype=input.dtype)
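Note: for reference, the removed `quantized_softmax.per_tensor` schema takes a quantized input with per-tensor scale/zero-point pairs and, per the removed meta kernels, returns a tensor of the same shape and dtype. Below is a minimal eager-mode sketch of the implied semantics, assuming the kernel dequantizes, applies a float softmax along `dim`, and requantizes with the output parameters; the `mask` argument is omitted because the fusion pass only ever passed a dummy zero mask, and the helper name is hypothetical, not part of the reverted kernel.

```python
import torch


def quantized_softmax_per_tensor_reference(
    input: torch.Tensor,
    dim: int,
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    # Dequantize the integer activations to float.
    x = (input.to(torch.float32) - in_zero_point) * in_scale
    # Float softmax along the requested dimension.
    y = torch.softmax(x, dim=dim)
    # Requantize to the same integer dtype as the input, matching the meta
    # kernel, which allocates the output with input.dtype.
    qmin = torch.iinfo(input.dtype).min
    qmax = torch.iinfo(input.dtype).max
    q = torch.round(y / out_scale) + out_zero_point
    return q.clamp(qmin, qmax).to(input.dtype)
```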

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 1 addition & 78 deletions
@@ -6,10 +6,9 @@
 
 # pyre-strict
 
-from typing import Any, cast, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple
 
 import torch
-from executorch.backends.cadence.aot.compiler_utils import get_shape
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
     AddPattern,
@@ -26,7 +25,6 @@
     MatmulPattern,
     ReluPattern0,
     ReluPattern1,
-    SoftmaxPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     check_out_zero_point_is_min_range,
@@ -390,73 +388,6 @@ def get_args_and_kwargs_relu(
     return args, kwargs
 
 
-def get_args_and_kwargs_softmax(
-    graph_module: GraphModule,
-    inputs_inputs: List[fx.Node],
-    dequants_inputs: List[fx.Node],
-    quant_node: fx.Node,
-    op_node: fx.Node,
-) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
-    # Make a dummy mask tensor
-    mask_shape = get_shape(graph_module, cast(fx.Node, quant_node.args[0]))
-    mask_shape = list(mask_shape) if mask_shape else []
-    mask_shape[-1] = mask_shape[-1] // 16
-    mask_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            mask_shape,
-            0.0,
-        ),
-        {"dtype": torch.int32},
-    )
-    # Make the scale and zero_point tensors
-    in_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    in_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[2],
-        ),
-        {"dtype": torch.int32},
-    )
-    out_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            quant_node.args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    out_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            quant_node.args[2],
-        ),
-        {"dtype": torch.int32},
-    )
-
-    # Make the args and kwargs for the replacement op
-    args = (
-        inputs_inputs[0],
-        mask_tensor,
-        op_node.args[1],
-        in_scale_tensor,
-        in_zero_point_tensor,
-        out_scale_tensor,
-        out_zero_point_tensor,
-    )
-    kwargs = {}
-    return args, kwargs
-
-
 class QuantFusion(ExportPass):
     # pyre-ignore[2]: Parameter `patterns` has no type specified
     def __init__(self, patterns) -> None:
@@ -612,14 +543,6 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         dequants_inputs,
                         quant_node,
                     )
-                elif isinstance(pattern, SoftmaxPattern):
-                    args, kwargs = get_args_and_kwargs_softmax(
-                        graph_module,
-                        inputs_inputs,
-                        dequants_inputs,
-                        quant_node,
-                        anchor_output_node,
-                    )
                 fused = graph_module.graph.call_function(
                     pattern.replacement_op(),
                     args,
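Note: the removed `get_args_and_kwargs_softmax` materialized its arguments as graph nodes via `torch.ops.aten.full.default`, reading the scale/zero-point scalars off the matched dequantize and quantize nodes (`dequants_inputs[0].args[1:3]` and `quant_node.args[1:3]`). A rough eager-mode equivalent of that construction, under the same shape and dtype choices, is sketched below; the helper name is hypothetical and for illustration only.

```python
import torch


def build_quantized_softmax_args_eager(
    q_in: torch.Tensor,
    dim: int,
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
):
    # Dummy mask: the input shape with the last dim divided by 16, filled with
    # zeros, as in the removed pass.
    mask_shape = list(q_in.shape)
    mask_shape[-1] = mask_shape[-1] // 16
    mask = torch.full(mask_shape, 0, dtype=torch.int32)

    # Scales and zero points are packed into 1-element tensors.
    in_scale_t = torch.full([1], in_scale, dtype=torch.float32)
    in_zp_t = torch.full([1], in_zero_point, dtype=torch.int32)
    out_scale_t = torch.full([1], out_scale, dtype=torch.float32)
    out_zp_t = torch.full([1], out_zero_point, dtype=torch.int32)

    # Argument order matches the quantized_softmax schema above.
    return (q_in, mask, dim, in_scale_t, in_zp_t, out_scale_t, out_zp_t)
```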

backends/cadence/aot/quantizer/patterns.py

Lines changed: 0 additions & 22 deletions
@@ -485,25 +485,3 @@ def partition_types(self) -> List[OpOverload]:
 class Conv2dReluPattern1(ConvReluBasePattern):
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default]
-
-
-class SoftmaxPattern(QuantizationPattern):
-
-    def partition_types(self) -> List[OpOverload]:
-        return [torch.ops.aten._softmax.default]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
-        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
-        softmax_node = fused_partition[0].nodes[-1]
-
-        return PartitionAnchors(
-            inputs=[(softmax_node, 0)],
-            weights=[],
-            biases=[],
-            output=[(softmax_node,)],
-        )
-
-    def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_softmax.default

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 0 additions & 29 deletions
@@ -27,7 +27,6 @@
     QuantizationPattern,
     ReluPattern0,
     ReluPattern1,
-    SoftmaxPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     find_sequential_partitions_aten,
@@ -59,15 +58,6 @@
     observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
 )
 
-act_qspec_asym16s = QuantizationSpec(
-    dtype=torch.int16,
-    quant_min=-32768,
-    quant_max=32767,
-    qscheme=torch.per_tensor_affine,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
-)
-
 wgt_qspec_asym8s = QuantizationSpec(
     dtype=torch.int8,
     quant_min=-128,
@@ -102,13 +92,6 @@
     None,
 )
 
-qconfig_A16 = QuantizationConfig(
-    act_qspec_asym16s,
-    act_qspec_asym16s,
-    wgt_qspec_asym8s,
-    None,
-)
-
 
 class CadenceAtenQuantizer(Quantizer):
     def __init__(
@@ -300,15 +283,3 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
         quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
         super().__init__(quantizers)
-
-
-class CadenceWithSoftmaxQuantizer(CadenceQuantizer):
-    """
-    Quantizer including A16 softmax
-    """
-
-    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
-        if quantizers is None:
-            quantizers = get_cadence_default_quantizers()
-        quantizers.append(CadenceAtenQuantizer(SoftmaxPattern(), qconfig_A16))
-        super().__init__(quantizers)
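Note: the removed `act_qspec_asym16s` describes a per-tensor affine int16 mapping with range [-32768, 32767], which `qconfig_A16` applied to both the softmax input and output activations. A minimal sketch of that mapping is shown below; `fake_quant_a16` is a hypothetical helper for illustration, not part of the quantizer.

```python
import torch


def fake_quant_a16(x: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # Per-tensor affine quantization with the removed act_qspec_asym16s range:
    # int16 values in [-32768, 32767].
    q = torch.clamp(torch.round(x / scale) + zero_point, -32768, 32767)
    # Dequantize back to float, which is what an observer/fake-quant models
    # during prepare; real int16 tensors only appear after convert.
    return (q - zero_point) * scale
```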
