|
6 | 6 |
|
7 | 7 | # pyre-strict |
8 | 8 |
|
9 | | -from typing import Any, Dict, List, Tuple |
| 9 | +from typing import Any, cast, Dict, List, Tuple |
10 | 10 |
|
11 | 11 | import torch |
| 12 | +from executorch.backends.cadence.aot.compiler_utils import get_shape |
12 | 13 | from executorch.backends.cadence.aot.quantizer.patterns import ( |
13 | 14 | AddmmPattern, |
14 | 15 | AddPattern, |
|
25 | 26 | MatmulPattern, |
26 | 27 | ReluPattern0, |
27 | 28 | ReluPattern1, |
| 29 | + SoftmaxPattern, |
28 | 30 | ) |
29 | 31 | from executorch.backends.cadence.aot.quantizer.utils import ( |
30 | 32 | check_out_zero_point_is_min_range, |
@@ -388,6 +390,73 @@ def get_args_and_kwargs_relu( |
388 | 390 | return args, kwargs |
389 | 391 |
|
390 | 392 |
|
| 393 | +def get_args_and_kwargs_softmax( |
| 394 | + graph_module: GraphModule, |
| 395 | + inputs_inputs: List[fx.Node], |
| 396 | + dequants_inputs: List[fx.Node], |
| 397 | + quant_node: fx.Node, |
| 398 | + op_node: fx.Node, |
| 399 | +) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]: |
| 400 | + # Make a dummy all-zero mask tensor; its last dim is the input's last dim packed by a factor of 16
| 401 | + mask_shape = get_shape(graph_module, cast(fx.Node, quant_node.args[0]))
| 402 | + assert mask_shape, "expected a known, non-empty shape for the softmax input"
| 403 | + mask_shape = [*mask_shape[:-1], mask_shape[-1] // 16]
| 404 | + mask_tensor = graph_module.graph.call_function( |
| 405 | + torch.ops.aten.full.default, |
| 406 | + ( |
| 407 | + mask_shape, |
| 408 | + 0.0, |
| 409 | + ), |
| 410 | + {"dtype": torch.int32}, |
| 411 | + ) |
| 412 | + # Materialize the input/output scale and zero_point as 1-element tensors
| 413 | + in_scale_tensor = graph_module.graph.call_function( |
| 414 | + torch.ops.aten.full.default, |
| 415 | + ( |
| 416 | + [1], |
| 417 | + dequants_inputs[0].args[1], |
| 418 | + ), |
| 419 | + {"dtype": torch.float32}, |
| 420 | + ) |
| 421 | + in_zero_point_tensor = graph_module.graph.call_function( |
| 422 | + torch.ops.aten.full.default, |
| 423 | + ( |
| 424 | + [1], |
| 425 | + dequants_inputs[0].args[2], |
| 426 | + ), |
| 427 | + {"dtype": torch.int32}, |
| 428 | + ) |
| 429 | + out_scale_tensor = graph_module.graph.call_function( |
| 430 | + torch.ops.aten.full.default, |
| 431 | + ( |
| 432 | + [1], |
| 433 | + quant_node.args[1], |
| 434 | + ), |
| 435 | + {"dtype": torch.float32}, |
| 436 | + ) |
| 437 | + out_zero_point_tensor = graph_module.graph.call_function( |
| 438 | + torch.ops.aten.full.default, |
| 439 | + ( |
| 440 | + [1], |
| 441 | + quant_node.args[2], |
| 442 | + ), |
| 443 | + {"dtype": torch.int32}, |
| 444 | + ) |
| 445 | + |
| 446 | + # Assemble the args for the replacement op: input, mask, softmax dim, then the quant params
| 447 | + args = ( |
| 448 | + inputs_inputs[0], |
| 449 | + mask_tensor, |
| 450 | + op_node.args[1],  # softmax dim
| 451 | + in_scale_tensor, |
| 452 | + in_zero_point_tensor, |
| 453 | + out_scale_tensor, |
| 454 | + out_zero_point_tensor, |
| 455 | + ) |
| 456 | + kwargs: Dict[str, ArgsType] = {}
| 457 | + return args, kwargs |
| 458 | + |
| 459 | + |
391 | 460 | class QuantFusion(ExportPass): |
392 | 461 | # pyre-ignore[2]: Parameter `patterns` has no type specified |
393 | 462 | def __init__(self, patterns) -> None: |
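All four `aten.full` calls in the hunk above follow one pattern: each scalar quantization parameter is materialized as a one-element constant tensor node so the fused op can consume it as an ordinary graph input. A minimal self-contained sketch of that pattern, assuming nothing beyond `torch.fx`; the helper name `make_qparam_tensors` and the example values are illustrative, not part of this change:

```python
from typing import Tuple

import torch
from torch import fx


def make_qparam_tensors(
    graph: fx.Graph, scale: float, zero_point: int
) -> Tuple[fx.Node, fx.Node]:
    # Each scalar becomes a 1-element tensor node, mirroring the
    # aten.full calls in get_args_and_kwargs_softmax.
    scale_t = graph.call_function(
        torch.ops.aten.full.default, ([1], scale), {"dtype": torch.float32}
    )
    zp_t = graph.call_function(
        torch.ops.aten.full.default, ([1], zero_point), {"dtype": torch.int32}
    )
    return scale_t, zp_t


# Illustrative values only.
g = fx.Graph()
in_scale, in_zp = make_qparam_tensors(g, 0.05, -128)
```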
@@ -543,6 +612,14 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 |
543 | 612 | dequants_inputs, |
544 | 613 | quant_node, |
545 | 614 | ) |
| 615 | + elif isinstance(pattern, SoftmaxPattern): |
| 616 | + args, kwargs = get_args_and_kwargs_softmax( |
| 617 | + graph_module, |
| 618 | + inputs_inputs, |
| 619 | + dequants_inputs, |
| 620 | + quant_node, |
| 621 | + anchor_output_node, |
| 622 | + ) |
546 | 623 | fused = graph_module.graph.call_function( |
547 | 624 | pattern.replacement_op(), |
548 | 625 | args, |
|
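One detail worth calling out in `get_args_and_kwargs_softmax` is the mask shape: every leading dimension of the quantized input is kept, and only the last one is integer-divided by 16. The factor of 16 comes straight from the code above; the diff does not say why it was chosen. A small worked example of that arithmetic, with `packed_mask_shape` as a hypothetical stand-in helper:

```python
from typing import List, Sequence


def packed_mask_shape(input_shape: Sequence[int], factor: int = 16) -> List[int]:
    # Keep the leading dims; integer-divide the last dim by the packing
    # factor, mirroring the mask_shape arithmetic in the hunk above.
    assert input_shape, "a scalar input has no last dim to pack"
    return [*input_shape[:-1], input_shape[-1] // factor]


assert packed_mask_shape([2, 8, 64]) == [2, 8, 4]
assert packed_mask_shape([1, 128]) == [1, 8]
```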