Skip to content

Commit aecef92

Browse files
cccclaifacebook-github-bot
authored andcommitted
passes on edge dialect will be done via tranform (#236)
Summary: Pull Request resolved: #236 For graph transform in edge dialect, let's use `transform`. Also we don't expect any transform from aten to edge as discussed {F1081902441} Reviewed By: angelayi Differential Revision: D48911694 fbshipit-source-id: 3f895b6e8266511b38c0521066bea5c681080a30
1 parent 29bbab8 commit aecef92

File tree

14 files changed

+62
-40
lines changed

14 files changed

+62
-40
lines changed

backends/xnnpack/test/TARGETS

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ python_unittest(
104104
"//executorch/backends/xnnpack/passes:xnnpack_passes",
105105
"//executorch/backends/xnnpack/utils:xnnpack_utils",
106106
"//executorch/exir:lib",
107+
"//executorch/exir:pass_base",
107108
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
108109
"//executorch/exir/dialects:lib",
109110
],
@@ -130,7 +131,7 @@ python_unittest(
130131
"//caffe2:torch",
131132
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
132133
"//executorch/backends/xnnpack/test/tester:tester",
133-
"//executorch/backends/xnnpack/utils:xnnpack_utils",
134+
"//executorch/exir:lib",
134135
"//pytorch/vision:torchvision",
135136
],
136137
)

backends/xnnpack/test/test_xnnpack_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
XnnpackQuantizedPartitioner,
1919
)
2020
from executorch.backends.xnnpack.utils.configs import (
21+
get_transform_passes,
2122
get_xnnpack_edge_compile_config,
2223
get_xnnpack_executorch_backend_config,
2324
)
@@ -323,7 +324,9 @@ def quantize_and_test_model_with_quantizer(
323324
config=exir.CaptureConfig(enable_aot=True, _unlift=True),
324325
)
325326

326-
edge_program = captured_program.to_edge(get_xnnpack_edge_compile_config())
327+
edge_program = captured_program.to_edge(
328+
get_xnnpack_edge_compile_config()
329+
).transform(*get_transform_passes())
327330
delegated_module = self.lower_module_and_test_output(
328331
module=edge_program,
329332
sample_inputs=example_inputs,

backends/xnnpack/utils/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ python_library(
66
deps = [
77
"//caffe2:torch",
88
"//executorch/exir:lib",
9+
"//executorch/exir:pass_manager",
910
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
1011
"//executorch/exir/dialects:lib",
1112
],

backends/xnnpack/utils/configs.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,28 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Optional
7+
from typing import List, Optional
88

99
import executorch.exir as exir
1010
from executorch.exir import CaptureConfig
1111
from executorch.exir.backend.canonical_partitioners.duplicate_dequant_node_pass import (
1212
DuplicateDequantNodePass,
1313
)
14+
from executorch.exir.pass_manager import PassType
1415

1516
### XNNPACK Configs ###
16-
def get_xnnpack_edge_compile_config(additional_passes=None) -> exir.EdgeCompileConfig:
17-
additional_passes = additional_passes if additional_passes else []
18-
passes = additional_passes + [DuplicateDequantNodePass()]
17+
def get_xnnpack_edge_compile_config() -> exir.EdgeCompileConfig:
1918
return exir.EdgeCompileConfig(
20-
passes=passes,
2119
_check_ir_validity=False,
2220
)
2321

2422

23+
def get_transform_passes(additional_passes=None) -> List[PassType]:
24+
additional_passes = additional_passes if additional_passes else []
25+
passes = additional_passes + [DuplicateDequantNodePass()]
26+
return passes
27+
28+
2529
def get_xnnpack_executorch_backend_config(
2630
additional_passes=None,
2731
) -> exir.ExecutorchBackendConfig:

backends/xnnpack/utils/utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111

1212
from executorch.backends.xnnpack.utils.configs import (
13+
get_transform_passes,
1314
get_xnnpack_capture_config,
1415
get_xnnpack_edge_compile_config,
1516
)
@@ -25,11 +26,15 @@ def capture_graph_for_xnnpack(
2526
enable_aot: Optional[bool] = None,
2627
unlift: Optional[bool] = None,
2728
) -> exir.ExirExportedProgram:
28-
return exir.capture(
29-
module,
30-
inputs,
31-
get_xnnpack_capture_config(enable_aot=enable_aot, unlift=unlift),
32-
).to_edge(get_xnnpack_edge_compile_config())
29+
return (
30+
exir.capture(
31+
module,
32+
inputs,
33+
get_xnnpack_capture_config(enable_aot=enable_aot, unlift=unlift),
34+
)
35+
.to_edge(get_xnnpack_edge_compile_config())
36+
.transform(*get_transform_passes())
37+
)
3338

3439

3540
### XNNPACK Utils ###

examples/backend/xnnpack_examples.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,16 @@
9191
capture_config=CaptureConfig(enable_aot=True),
9292
edge_compile_config=EdgeCompileConfig(
9393
# TODO(T162080278): Duplicated Dequant nodes will be in quantizer spec
94-
_check_ir_validity=False if args.quantize else True,
95-
passes=[DuplicateDequantNodePass()],
94+
_check_ir_validity=False
95+
if args.quantize
96+
else True,
9697
),
9798
)
9899
logging.info(f"Exported graph:\n{edge.exported_program.graph}")
99100

100-
edge.exported_program = to_backend(edge.exported_program, partitioner)
101+
edge.exported_program = to_backend(
102+
edge.transform(DuplicateDequantNodePass()).exported_program, partitioner
103+
)
101104
logging.info(f"Lowered graph:\n{edge.exported_program.graph}")
102105

103106
exec_prog = edge.to_executorch()

exir/capture/_config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ class CaptureConfig:
2929
@compatibility(is_backward_compatible=False)
3030
@dataclass
3131
class EdgeCompileConfig:
32-
passes: List[PassType] = field(default_factory=list)
3332
# TODO(qihan): remove ability to opt out
3433
_check_ir_validity: bool = True
3534
# TODO(larryliu): remove this

exir/emit/test/test_emit.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ def f(x: torch.Tensor) -> torch.Tensor:
179179

180180
program = (
181181
exir.capture(f, (torch.randn(100),), exir.CaptureConfig())
182-
.to_edge(exir.EdgeCompileConfig(passes=[ConstPropPass()]))
182+
.to_edge()
183+
.transform(ConstPropPass())
183184
.to_executorch()
184185
.program
185186
)
@@ -641,7 +642,8 @@ def forward(self, x):
641642
x = (torch.randn(1, 1, 2, 2),)
642643
program = (
643644
exir.capture(M(), x, exir.CaptureConfig())
644-
.to_edge(exir.EdgeCompileConfig(passes=[ConstPropPass()]))
645+
.to_edge()
646+
.transform(ConstPropPass())
645647
.to_executorch()
646648
.program
647649
)

exir/program/_program.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
283283
aten_to_edge_passes.passes[:-2]
284284
+ op_replace_pass
285285
+ aten_to_edge_passes.passes[-2:]
286-
) + config.passes
286+
)
287287
new_ep = copy.deepcopy(ep).transform(*passes)
288288
if config._check_ir_validity:
289289
EXIREdgeDialectVerifier(check_edge_ops=config._use_edge_ops)(

exir/tests/test_memory_planning.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -462,19 +462,20 @@ def quantize(self, eager_model: nn.Module) -> nn.Module:
462462
def test_asr_joiner(self) -> None:
463463
eager_model = self.quantize(ASRJoiner())
464464
inputs = eager_model.get_random_inputs()
465-
edge_program = exir.capture(
466-
eager_model,
467-
inputs,
468-
exir.CaptureConfig(
469-
enable_dynamic_shape=True,
470-
),
471-
).to_edge(
472-
exir.EdgeCompileConfig(
473-
passes=[
474-
ConstPropPass(),
475-
],
476-
_check_ir_validity=False,
465+
edge_program = (
466+
exir.capture(
467+
eager_model,
468+
inputs,
469+
exir.CaptureConfig(
470+
enable_dynamic_shape=True,
471+
),
472+
)
473+
.to_edge(
474+
exir.EdgeCompileConfig(
475+
_check_ir_validity=False,
476+
)
477477
)
478+
.transform(ConstPropPass())
478479
)
479480
with validation_disabled():
480481
backend_module = to_backend(

0 commit comments

Comments
 (0)