@@ -6,9 +6,10 @@
 
 # pyre-strict
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, cast, Dict, List, Tuple
 
 import torch
+from executorch.backends.cadence.aot.compiler_utils import get_shape
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
     AddPattern,
@@ -25,6 +26,7 @@
     MatmulPattern,
     ReluPattern0,
     ReluPattern1,
+    SoftmaxPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     check_out_zero_point_is_min_range,
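For context on the new `get_shape` import: it reads a node's tensor shape out of the FX graph. A minimal sketch of the idea, assuming shapes are carried as FakeTensor metadata in `node.meta["val"]` (the real helper lives in `compiler_utils.py` and may differ in detail):

import torch.fx as fx

def get_shape_sketch(node: fx.Node):
    # Exported FX graphs record a FakeTensor for each node in meta["val"].
    val = node.meta.get("val")
    # Returns None when the node carries no tensor metadata.
    return tuple(val.shape) if val is not None else None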
@@ -388,6 +390,73 @@ def get_args_and_kwargs_relu(
     return args, kwargs
 
 
+def get_args_and_kwargs_softmax(
+    graph_module: GraphModule,
+    inputs_inputs: List[fx.Node],
+    dequants_inputs: List[fx.Node],
+    quant_node: fx.Node,
+    op_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    # Make a dummy mask tensor
+    mask_shape = get_shape(graph_module, cast(fx.Node, quant_node.args[0]))
+    mask_shape = list(mask_shape) if mask_shape else []
+    mask_shape[-1] = mask_shape[-1] // 16
+    mask_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            mask_shape,
+            0.0,
+        ),
+        {"dtype": torch.int32},
+    )
+    # Make the scale and zero_point tensors
+    in_scale_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            dequants_inputs[0].args[1],
+        ),
+        {"dtype": torch.float32},
+    )
+    in_zero_point_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            dequants_inputs[0].args[2],
+        ),
+        {"dtype": torch.int32},
+    )
+    out_scale_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            quant_node.args[1],
+        ),
+        {"dtype": torch.float32},
+    )
+    out_zero_point_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            quant_node.args[2],
+        ),
+        {"dtype": torch.int32},
+    )
+
+    # Make the args and kwargs for the replacement op
+    args = (
+        inputs_inputs[0],
+        mask_tensor,
+        op_node.args[1],
+        in_scale_tensor,
+        in_zero_point_tensor,
+        out_scale_tensor,
+        out_zero_point_tensor,
+    )
+    kwargs = {}
+    return args, kwargs
+
+
 class QuantFusion(ExportPass):
     # pyre-ignore[2]: Parameter `patterns` has no type specified
     def __init__(self, patterns) -> None:
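A note on the arg indexing above: `dequants_inputs[0]` and `quant_node` are calls to the per-tensor (de)quantize ops, whose positional args put scale at index 1 and zero_point at index 2. A minimal runnable sketch of that layout, using the standard `quantized_decomposed` ops (the concrete values are illustrative only):

# Arg layout the fusion relies on:
#   quantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype)
#   dequantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype)
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # registers quantized_decomposed ops

x = torch.randn(2, 64)
scale, zero_point = 0.05, 128  # illustrative values
q = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, scale, zero_point, 0, 255, torch.uint8
)
# On the FX node for this call, args[1] is the scale and args[2] the
# zero_point -- exactly what get_args_and_kwargs_softmax reads off the
# dequant/quant nodes when building the [1]-shaped constant tensors.
dq = torch.ops.quantized_decomposed.dequantize_per_tensor(
    q, scale, zero_point, 0, 255, torch.uint8
)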
@@ -543,6 +612,14 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                             dequants_inputs,
                             quant_node,
                         )
+                    elif isinstance(pattern, SoftmaxPattern):
+                        args, kwargs = get_args_and_kwargs_softmax(
+                            graph_module,
+                            inputs_inputs,
+                            dequants_inputs,
+                            quant_node,
+                            anchor_output_node,
+                        )
                     fused = graph_module.graph.call_function(
                         pattern.replacement_op(),
                         args,
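Putting it together, the fused node receives the seven positional args built in `get_args_and_kwargs_softmax`. A hypothetical sketch of the emitted call, reusing the names from the diff above and assuming `SoftmaxPattern.replacement_op()` resolves to a Cadence quantized softmax op (the actual op name is defined in patterns.py, which is not part of this diff):

# Hypothetical shape of the fused call for a SoftmaxPattern match.
fused = graph_module.graph.call_function(
    torch.ops.cadence.quantized_softmax.default,  # assumed op name
    (
        inputs_inputs[0],       # quantized input activation
        mask_tensor,            # dummy int32 mask; last dim is input's // 16
        op_node.args[1],        # the softmax `dim` argument
        in_scale_tensor,        # input scale as a [1] float32 tensor
        in_zero_point_tensor,   # input zero_point as a [1] int32 tensor
        out_scale_tensor,       # output scale as a [1] float32 tensor
        out_zero_point_tensor,  # output zero_point as a [1] int32 tensor
    ),
)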