Skip to content

Commit de34197

Browse files
chunnienccopybara-github
authored andcommitted
add AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER to ai_edge_torch.config
PiperOrigin-RevId: 718950902
1 parent b215780 commit de34197

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

ai_edge_torch/_config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,14 @@ def enable_group_norm_composite(self) -> bool:
6565
def enable_group_norm_composite(self, value: bool):
6666
os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
6767

68+
@property
69+
def layout_optimize_partitioner(self) -> str:
70+
"""The algorithm to use for layout optimization."""
71+
return os.environ.get("AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", "DEFAULT")
72+
73+
@layout_optimize_partitioner.setter
74+
def layout_optimize_partitioner(self, value: str):
75+
os.environ["AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER"] = str(value).upper()
76+
6877

6978
config = _Config()

ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919
from typing import Union
2020

21+
import ai_edge_torch
2122
from ai_edge_torch import fx_infra
2223
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
2324
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
@@ -261,10 +262,8 @@ def call(self, exported_program: torch.export.ExportedProgram):
261262
self.mark_const_nodes(exported_program)
262263

263264
graph_module = exported_program.graph_module
264-
partitioner = os.environ.get(
265-
"AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", None
266-
)
267-
if partitioner == "MINCUT":
265+
partitioner = ai_edge_torch.config.layout_optimize_partitioner
266+
if partitioner in ("MINCUT", "OPTIMAL"):
268267
graph_module = layout_partitioners.min_cut.partition(graph_module)
269268
elif partitioner == "GREEDY":
270269
graph_module = layout_partitioners.greedy.partition(graph_module)

0 commit comments

Comments
 (0)