Commit d367112

Update base for Update on "[ET-VK] Adding push constant and UBO versions of select and slice ops to improve memory and performance."

Adds push constant and UBO versions of the select and slice ops to improve memory use and performance:

* Updated `transfer_buffer.yaml` and `transfer_texture.yaml` to include a `UBO_PARAMS` parameter and to generate variants of the `select` and `slice` ops with UBO parameters.
* Updated `transfer.glsl` to generate UBO and push constant versions of the `select` and `slice` ops.

Differential Revision: [D78095262](https://our.internmc.facebook.com/intern/diff/D78095262/)

[ghstack-poisoned]

2 parents 4cd537b + 97a61f4, commit d367112

File tree

25 files changed: +1166 −833 lines

backends/cadence/aot/compiler.py

Lines changed: 29 additions & 39 deletions
```diff
@@ -8,7 +8,7 @@
 
 import logging
 from pathlib import Path
-from typing import Callable, cast, Optional
+from typing import Optional
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -32,7 +32,6 @@
     ExecutorchBackendConfig,
     ExecutorchProgramManager,
 )
-from executorch.exir.pass_base import PassResult
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from executorch.exir.program._program import to_edge_with_preserved_ops
@@ -41,7 +40,7 @@
 from torch.export.exported_program import ExportedProgram
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from .passes import get_cadence_passes
+from .passes import apply_exir_ops_passes, apply_torch_ops_passes
 
 from .utils import print_ops_info
 
@@ -210,6 +209,21 @@ def quantize_pt2(
     return program
 
 
+TO_EDGE_OP_EXCEPTION_LIST: list[torch._ops.OpOverload] = [
+    torch.ops.aten._linalg_det.default,
+    torch.ops.aten._linalg_svd.default,
+    torch.ops.aten._native_batch_norm_legit_functional.default,
+    torch.ops.aten.linear.default,
+    torch.ops.aten.linalg_vector_norm.default,
+    torch.ops.aten.unfold.default,
+    torch.ops.aten.angle.default,
+    torch.ops.aten.rms_norm.default,
+]
+TO_EDGE_PRESERVE_OPS: tuple[torch._ops.OpOverload, ...] = (
+    torch.ops.aten.rms_norm.default,
+)
+
+
 def _lower_ep_to_edge(
     expo_program: ExportedProgram,
     dump_graphs: bool = False,
@@ -226,20 +240,11 @@ def _lower_ep_to_edge(
         compile_config=EdgeCompileConfig(
             _skip_dim_order=True,
             # Allow specific non-core aten ops in the IR.
-            _core_aten_ops_exception_list=[
-                torch.ops.aten._linalg_det.default,
-                torch.ops.aten._linalg_svd.default,
-                torch.ops.aten._native_batch_norm_legit_functional.default,
-                torch.ops.aten.linear.default,
-                torch.ops.aten.linalg_vector_norm.default,
-                torch.ops.aten.unfold.default,
-                torch.ops.aten.angle.default,
-                torch.ops.aten.rms_norm.default,
-            ]
+            _core_aten_ops_exception_list=TO_EDGE_OP_EXCEPTION_LIST
             + (core_aten_exceptions or []),
         ),
         constant_methods=constant_methods,
-        preserve_ops=(torch.ops.aten.rms_norm.default,),
+        preserve_ops=TO_EDGE_PRESERVE_OPS,
     )
 
     if dump_graphs:
@@ -256,14 +261,20 @@ def export_to_edge(
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
     constant_methods: Optional[dict[str, object]] = None,
+    core_aten_exceptions: Optional[list[torch._ops.OpOverload]] = None,
 ) -> EdgeProgramManager:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
     # Export the model into an ExportedProgram.
     expo_program = trace(model, inputs)
 
+    # Apply passes which transform the ExportedProgram before it gets lowered to edge.
+    expo_program = apply_torch_ops_passes(expo_program)
+
     # Lower the model to edge IR.
-    edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
+    edge_prog_manager = _lower_ep_to_edge(
+        expo_program, dump_graphs, constant_methods, core_aten_exceptions
+    )
 
     return edge_prog_manager
 
@@ -305,14 +316,7 @@ def _lower_ep_to_cadence(
     Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
     """
     edge_prog_manager = _lower_ep_to_edge(program, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager
 
 
@@ -323,14 +327,7 @@ def export_to_cadence(
     opt_level: int = 1,
 ) -> EdgeProgramManager:
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager
 
 
@@ -367,15 +364,8 @@ def export_to_executorch_gen_etrecord(
     memory_config: Optional[MemoryConfig] = None,
     dump_graphs: bool = False,
 ) -> ExecutorchProgramManager:
-    cadence_passes = get_cadence_passes(opt_level)
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
 
     # Print some information to terminal
     print_ops_info(
```
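
For context, here is a minimal sketch of how the new `core_aten_exceptions` parameter threads through `export_to_edge`: ops passed here are appended to the module-level `TO_EDGE_OP_EXCEPTION_LIST` inside `_lower_ep_to_edge`. The toy `MyModel` and its inputs are hypothetical, not part of this diff.

```python
import torch

# `export_to_edge` is defined in backends/cadence/aot/compiler.py (the file above).
from executorch.backends.cadence.aot.compiler import export_to_edge


class MyModel(torch.nn.Module):  # hypothetical toy model for illustration
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# Extra exceptions are merged via `TO_EDGE_OP_EXCEPTION_LIST + (core_aten_exceptions or [])`,
# so the listed ops survive core-ATen verification when lowering to edge IR.
edge_prog_manager = export_to_edge(
    MyModel(),
    (torch.randn(2, 4),),
    core_aten_exceptions=[torch.ops.aten.linear.default],
)
```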

backends/cadence/aot/fuse_ops.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -1127,6 +1127,7 @@ class CadenceFuseOpsInGraph:
         FuseCascadedTransposeOrPermuteOps,
         FuseCascadedViewOps,
         FuseQuantDequantToRequantizePass,
+        FuseMulTensorIntoQuantPass,
         FuseMulTensorIntoDequantPass,
         FuseMulScalarIntoDequantPass,
         FuseFullThenReshapePass,
```
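
The diff only registers `FuseMulTensorIntoQuantPass` in the pass list. Judging by its name and its sibling `FuseMulTensorIntoDequantPass`, it plausibly folds a constant multiply into the following quantize op by rescaling the quantization scale. A numeric sketch of that identity follows; the attribution to the pass is an assumption, not code from this commit.

```python
import torch

# For affine quantization, round(x * c / s) + zp == round(x / (s / c)) + zp,
# so `mul by constant c` followed by `quantize(scale=s)` can fold into a
# single `quantize(scale=s / c)` (assumed to be what the pass exploits).
x = torch.randn(8)
c, s, zp = 2.0, 0.1, 0

unfused = torch.quantize_per_tensor(x * c, scale=s, zero_point=zp, dtype=torch.qint8)
fused = torch.quantize_per_tensor(x, scale=s / c, zero_point=zp, dtype=torch.qint8)
assert torch.equal(unfused.int_repr(), fused.int_repr())
```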

backends/cadence/aot/passes.py

Lines changed: 36 additions & 8 deletions
```diff
@@ -6,7 +6,7 @@
 
 # pyre-strict
 
-from typing import Any, List, Optional
+from typing import Any, Callable, cast, List, Optional
 
 import torch
 import torch.fx
@@ -28,13 +28,18 @@
     RemoveRedundantOps,
 )
 from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
-from executorch.backends.cadence.aot.replace_ops import CadenceReplaceOpsInGraph
+from executorch.backends.cadence.aot.replace_ops import (
+    CadenceReplaceOpsInGraph,
+    ReplaceMulTensorWithMulAndFullOpsPass,
+)
 from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
+from executorch.exir import EdgeProgramManager
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.pass_manager import PassManager, PassType
 from executorch.exir.passes import dead_code_elimination_pass
 from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
+from torch.export.exported_program import ExportedProgram
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -89,14 +94,37 @@ def get_passes_in_default_order() -> List[ExportPass]:
     return pytree.tree_flatten(passes)[0]
 
 
-def get_cadence_passes(
+def apply_exir_ops_passes(
     opt_level: int,
-) -> List[Optional[PassResult]]:
+    edge_prog_manager: EdgeProgramManager,
+) -> EdgeProgramManager:
     passes = get_passes_in_default_order()
     pass_filter = create_cadence_pass_filter(opt_level)
-    filtered_passes = [
-        # pyre-ignore[20]: Expect argument graph_module
-        filtered_pass()
+    cadence_passes = [
+        (
+            lambda graph_module, filtered_pass=filtered_pass: filtered_pass()(
+                graph_module
+            )
+        )
         for filtered_pass in list(filter(pass_filter, passes))
     ]
-    return filtered_passes
+    cadence_prog_manager = edge_prog_manager.transform(
+        cast(
+            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
+        )
+    )
+    return cadence_prog_manager
+
+
+def apply_torch_ops_passes(expo_program: ExportedProgram) -> ExportedProgram:
+    """
+    Applies compiler passes on torch.ops IR, including torch.ops.aten, torch.ops.cadence, etc.
+    expo_program is expected to be the output of the torch.export.export().
+    """
+
+    aten_passes: List[Callable[[torch.fx.GraphModule], Optional[PassResult]]] = [
+        ReplaceMulTensorWithMulAndFullOpsPass()
+    ]
+    # TODO(T230417247): Use PassResult which is currently ignored.
+    PassManager(aten_passes)(expo_program.graph_module)
+    return expo_program
```
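
One detail worth noting in `apply_exir_ops_passes`: each wrapper lambda captures `filtered_pass` through a default argument (`filtered_pass=filtered_pass`). Without it, Python's late binding would make every lambda in the comprehension invoke the last pass in the list. A standalone illustration of the difference:

```python
# Late binding: all three closures share the loop variable and see its final value.
late = [lambda: i for i in range(3)]
print([f() for f in late])  # [2, 2, 2]

# Default-argument binding snapshots the value at definition time; this is the
# idiom the `cadence_passes` comprehension above relies on for `filtered_pass`.
early = [lambda i=i: i for i in range(3)]
print([f() for f in early])  # [0, 1, 2]
```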

backends/nxp/backend/edge_helper.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -5,6 +5,7 @@
 
 import torch
 from torch.fx import Node
+from torch.nn import Parameter
 
 
 def input_tensor(node: Node, input_index: int) -> torch.Tensor:
@@ -38,3 +39,35 @@ def input_tensor_safe(node: Node, input_index: int) -> torch.Tensor | None:
         return None
 
     return input_tensor(node, input_index)
+
+
+def node_is_static_tensor(node: Node, parameters_mapping: dict[str, Parameter]) -> bool:
+    """Return `True` if the given `node` has static data in the `parameters_mapping` dict.
+    :param node: Tensor node to check for data.
+    :param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the
+                                `state_dict` attribute of an edge program.
+    """
+    return node.name in parameters_mapping.keys()
+
+
+def node_is_effectively_static_tensor(
+    node: Node, parameters_mapping: dict[str, Parameter]
+) -> bool:
+    """Return `True` if the given `node` has static data, or follows after a `Dequantize` node with a static input.
+    In the IR, the `node` will be turned into a static quantized tensor.
+    :param node: Tensor node to check for data.
+    :param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the
+                                `state_dict` attribute of an edge program.
+    """
+    if node_is_static_tensor(node, parameters_mapping):
+        return True
+
+    def _is_dequantize(node_: Node) -> bool:
+        return node_.target.__name__ in {
+            "quantized_decomposed.dequantize_per_tensor.default",
+            "quantized_decomposed.dequantize_per_channel.default",
+        }
+
+    return _is_dequantize(node) and node_is_static_tensor(
+        node.args[0], parameters_mapping
+    )
```
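
A hedged sketch of how these helpers might be driven: the `collect_static_nodes` wrapper and the use of `ExportedProgram.state_dict` as the `parameters_mapping` source are assumptions based on the docstrings above, not code from this commit.

```python
from torch.export import ExportedProgram
from torch.fx import Node

from executorch.backends.nxp.backend.edge_helper import (
    node_is_effectively_static_tensor,
)


def collect_static_nodes(edge_program: ExportedProgram) -> list[Node]:
    """Hypothetical helper: list nodes that become static tensors in the IR."""
    # Per the docstrings, the mapping is inferred from the program's state_dict.
    parameters_mapping = dict(edge_program.state_dict)
    return [
        node
        for node in edge_program.graph.nodes
        # Restrict to call_function nodes so `node.target.__name__` is well defined.
        if node.op == "call_function"
        and node_is_effectively_static_tensor(node, parameters_mapping)
    ]
```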
