Commit 5c4c6ce (merge of 2 parents: 8230848 + b35e7b1)

Update

[ghstack-poisoned]

17 files changed: +150 −68 lines

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 14 additions & 8 deletions

@@ -23,25 +23,27 @@
 from torch.fx.passes.operator_support import OperatorSupportBase

 logger = logging.getLogger(__name__)
-logger.setLevel(logging.WARNING)
+logger.setLevel(logging.INFO)


-class OperatorsSupportedForCoreMLBackend(OperatorSupportBase):
+class _OperatorsSupportedForCoreMLBackend(OperatorSupportBase):
     def __init__(
         self,
         skip_ops_for_coreml_delegation: Optional[List[str]] = None,
         lower_full_graph: bool = False,
+        log: bool = False,
     ) -> None:
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
         super().__init__()
         self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation
         self.lower_full_graph = lower_full_graph
         self._logged_msgs = set()
+        self._log = log

     def log_once(self, msg: str) -> None:
-        if msg not in self._logged_msgs:
-            logging.info(msg)
+        if self._log and msg not in self._logged_msgs:
+            logger.info(msg)
         self._logged_msgs.add(msg)

     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:

@@ -154,8 +156,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:

         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
-            OperatorsSupportedForCoreMLBackend(
-                self.skip_ops_for_coreml_delegation, self.lower_full_graph
+            _OperatorsSupportedForCoreMLBackend(
+                self.skip_ops_for_coreml_delegation,
+                self.lower_full_graph,
+                log=True,
             ),
             allows_single_node_partition=True,
         )

@@ -191,8 +195,10 @@ def ops_to_not_decompose(
         self, ep: ExportedProgram
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
         do_not_decompose = []
-        op_support = OperatorsSupportedForCoreMLBackend(
-            self.skip_ops_for_coreml_delegation, self.lower_full_graph
+        op_support = _OperatorsSupportedForCoreMLBackend(
+            self.skip_ops_for_coreml_delegation,
+            self.lower_full_graph,
+            log=False,
         )

         # CoreML prevents certain ops (like triu) from lowering to CoreML when put in the ExecuTorch op namespace
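This change does two things: it routes messages through the module-level logger instead of the root logger (a bare logging.info goes to the root logger, whose default WARNING level silently drops it), and it makes logging opt-in so partition() can report skipped ops while ops_to_not_decompose() reuses the same checker silently. A minimal standalone sketch of the opt-in, log-once pattern; the class and messages here are hypothetical, not part of the commit:

import logging

logging.basicConfig(level=logging.INFO)  # give the demo a handler so INFO prints
logger = logging.getLogger(__name__)


class SupportChecker:
    """Hypothetical checker demonstrating the opt-in, log-once pattern."""

    def __init__(self, log: bool = False) -> None:
        self._log = log
        self._logged_msgs = set()

    def log_once(self, msg: str) -> None:
        # Emit only when logging was requested and the message is new;
        # record it either way so the dedup set stays consistent.
        if self._log and msg not in self._logged_msgs:
            logger.info(msg)
        self._logged_msgs.add(msg)


verbose = SupportChecker(log=True)                    # like partition(): user-facing
verbose.log_once("skipping op: unsupported::linear")  # logged once
verbose.log_once("skipping op: unsupported::linear")  # deduplicated
SupportChecker().log_once("never shown")              # like ops_to_not_decompose(): silent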

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 7 additions & 2 deletions

@@ -16,7 +16,6 @@
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
 from executorch.exir.backend.utils import format_delegated_graph
-from executorch.runtime import Runtime


 @torch.library.custom_op("unsupported::linear", mutates_args=())

@@ -37,7 +36,13 @@ def _(
     return torch.ops.aten.linear.default(x, w, b)


-_TEST_RUNTIME = sys.platform == "darwin"
+def is_fbcode():
+    return not hasattr(torch.version, "git_version")
+
+
+_TEST_RUNTIME = (sys.platform == "darwin") and not is_fbcode()
+if _TEST_RUNTIME:
+    from executorch.runtime import Runtime


 class TestCoreMLPartitioner(unittest.TestCase):
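The gate relies on a detail of PyTorch's build metadata: OSS builds expose torch.version.git_version, while Meta-internal (fbcode) builds do not, so the attribute's absence marks an fbcode build. Deferring the executorch.runtime import behind the gate keeps the test module importable where the runtime is unavailable. A hedged sketch of the pattern; the test class is illustrative, not from the commit:

import sys
import unittest

import torch


def is_fbcode() -> bool:
    # OSS PyTorch records the git revision it was built from; fbcode
    # builds omit it, so a missing attribute marks an internal build.
    return not hasattr(torch.version, "git_version")


_TEST_RUNTIME = (sys.platform == "darwin") and not is_fbcode()
if _TEST_RUNTIME:
    # Import deferred until we know the runtime exists in this build.
    from executorch.runtime import Runtime  # noqa: F401


class RuntimeGateTest(unittest.TestCase):
    @unittest.skipUnless(_TEST_RUNTIME, "ExecuTorch runtime not available here")
    def test_runtime_importable(self):
        # Only reached on macOS OSS builds, where the import above ran.
        self.assertIn("Runtime", globals())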

backends/apple/coreml/test/test_torch_ops.py

Lines changed: 12 additions & 4 deletions

@@ -14,12 +14,20 @@

 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
-from executorch.runtime import Runtime
 from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_

-_TEST_RUNTIME = sys.platform == "darwin" and tuple(
-    map(int, platform.mac_ver()[0].split("."))
-) >= (15, 0)
+
+def is_fbcode():
+    return not hasattr(torch.version, "git_version")
+
+
+_TEST_RUNTIME = (
+    (sys.platform == "darwin")
+    and not is_fbcode()
+    and tuple(map(int, platform.mac_ver()[0].split("."))) >= (15, 0)
+)
+if _TEST_RUNTIME:
+    from executorch.runtime import Runtime


 class TestTorchOps(unittest.TestCase):
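This file additionally requires macOS 15 or newer. platform.mac_ver()[0] yields a version string such as "15.3.1"; parsing it into an integer tuple keeps the comparison numeric, whereas comparing strings would rank "9.1" above "15.0". A small sketch of that check; the helper name is hypothetical:

import platform


def macos_at_least(major: int, minor: int = 0) -> bool:
    # mac_ver() returns ("", ...) on non-macOS; filtering out empty parts
    # yields an empty tuple, which compares as "too old" instead of crashing.
    version = platform.mac_ver()[0]
    parsed = tuple(int(part) for part in version.split(".") if part)
    return parsed >= (major, minor)


print(macos_at_least(15))  # True only on macOS 15.0 or newer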

examples/models/llama/CMakeLists.txt

Lines changed: 5 additions & 3 deletions

@@ -77,9 +77,11 @@ find_package(gflags REQUIRED)
 # llama_main: test binary to run llama, with tokenizer and sampler integrated
 #

-# find `executorch` libraries Same as for gflags
-set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../lib/cmake/ExecuTorch)
-find_package(executorch CONFIG REQUIRED)
+# find `executorch` libraries. CMAKE_PREFIX_PATH would work for host
+# compilation, but CMAKE_FIND_ROOT_PATH appears to be necessary for
+# cross-compiling (e.g., to Android) to work as well.
+list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..)
+find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
 target_link_options_shared_lib(executorch)

 # llama_runner library

examples/models/llava/CMakeLists.txt

Lines changed: 2 additions & 2 deletions

@@ -76,8 +76,8 @@ find_package(gflags REQUIRED)
 #

 # find `executorch` libraries Same as for gflags
-set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../lib/cmake/ExecuTorch)
-find_package(executorch CONFIG REQUIRED)
+list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..)
+find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
 target_link_options_shared_lib(executorch)

 # llava_runner library

examples/models/phi-3-mini/CMakeLists.txt

Lines changed: 2 additions & 2 deletions

@@ -24,8 +24,8 @@ set(EXECUTORCH_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../../..")
 set(_common_include_directories
   ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10
 )
-set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../lib/cmake/ExecuTorch)
-find_package(executorch CONFIG REQUIRED)
+list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..)
+find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)

 target_link_options_shared_lib(executorch)

exir/backend/test/test_backends.py

Lines changed: 1 addition & 1 deletion

@@ -1033,7 +1033,7 @@ def false_fn(x, y):

 def f(x, y):
     x = x + y
-    x = torch.ops.higher_order.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
+    x = torch.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
     x = x - y
     return x
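This and the test changes below swap torch.ops.higher_order.cond, the internal higher-order-op handle, for torch.cond, the public functional control-flow API; the public wrapper validates its inputs and is the form the torch.export documentation uses. A minimal export sketch using torch.cond, under the assumption of a recent PyTorch with torch.export available:

import torch
from torch.export import export


class CondModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        # Both branches are traced; the predicate is a scalar boolean tensor.
        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])


ep = export(CondModule(), (torch.randn(3),))
# The exported graph still carries torch.ops.higher_order.cond under the hood.
print(ep.graph_module.code)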

exir/program/_program.py

Lines changed: 59 additions & 14 deletions

@@ -1076,6 +1076,28 @@ def keep(op):
     return list(filter(keep, preserve_ops))


+def _can_skip_using_EDGE_DO_NOT_DECOMP(
+    partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram]
+) -> bool:
+    # The current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
+    # has long-standing issues. _remove_invalid_ops_for_not_decompose was a band-aid
+    # to fix some of them, but more keep coming up over time, including a new issue
+    # with SDPA and contiguous views:
+    # https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
+    # EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support.
+    # As a temporary fix, we give a more reliable path to backends that do not
+    # specify check_op_support.
+    can_skip_using_EDGE_DO_NOT_DECOMP = True
+    for name, program in aten_programs.items():
+        if partitioner is not None:
+            for curr_partitioner in partitioner.get(name, []):
+                (
+                    curr_ops_no_decomp,
+                    check_op_support,
+                ) = curr_partitioner.ops_to_not_decompose(program)
+                if check_op_support is not None:
+                    can_skip_using_EDGE_DO_NOT_DECOMP = False
+    return can_skip_using_EDGE_DO_NOT_DECOMP
+
+
 def _gen_edge_manager_for_partitioners(
     partitioner: Dict[str, List[Partitioner]],
     aten_programs: Dict[str, ExportedProgram],

@@ -1095,37 +1117,56 @@ def _gen_edge_manager_for_partitioners(
     on nodes with preserved aten targets. They are then replaced with transformed ops to
     keep them through the second pass of decompositions
     """
+    can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
+        partitioner, aten_programs
+    )
     ops_set_to_not_decompose_by_program = {}
     edge_programs: Dict[str, ExportedProgram] = {}
     for name, program in aten_programs.items():
+        # Functionalize program before asking partitioners to preserve ops
+        program = program.run_decompositions({})
+
         if partitioner is not None:
             # preserve all ops listed by all partitioners first
             all_ops_no_decomp = set()
+            all_ops_no_decomp_needing_preservation = []
            for curr_partitioner in partitioner.get(name, []):
                 curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
-                curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
-                    curr_ops_no_decomp
-                )
                 all_ops_no_decomp |= set(curr_ops_no_decomp)

-            table = _default_decomposition_table()
+            # If not taking the can_skip_using_EDGE_DO_NOT_DECOMP path, we must
+            # remove invalid ops; otherwise there will be issues.
+            if not can_skip_using_EDGE_DO_NOT_DECOMP:
+                all_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
+                    list(all_ops_no_decomp)
+                )
+                all_ops_no_decomp = set(all_ops_no_decomp)

+            # Run default decompositions, except for those in all_ops_no_decomp
+            table = _default_decomposition_table()
             for op in all_ops_no_decomp:
-                table.pop(op, None)
-
+                if table.pop(op, None) is not None:
+                    all_ops_no_decomp_needing_preservation.append(op)
             program = program.run_decompositions(table)
+
             # Among all the preserved aten ops, use the check_op_fn to do an additional
             # check on which ops need to be preserved and which ops need to be decomposed
             # Those which are truly preserved will be replaced with transformed ops
-            ops_set_to_not_decompose_by_program[name] = (
-                _replace_aten_ops_with_transformed_ops(name, program, partitioner) or []
-            )
-            program = program.run_decompositions(_default_decomposition_table())
+            if can_skip_using_EDGE_DO_NOT_DECOMP:
+                ops_set_to_not_decompose_by_program[name] = (
+                    all_ops_no_decomp_needing_preservation
+                )
+            else:
+                ops_set_to_not_decompose_by_program[name] = (
+                    _replace_aten_ops_with_transformed_ops(name, program, partitioner)
+                    or []
+                )

-        _restore_transformed_ops_to_aten_ops(program)
+            if not can_skip_using_EDGE_DO_NOT_DECOMP:
+                program = program.run_decompositions(_default_decomposition_table())
+                _restore_transformed_ops_to_aten_ops(program)

         edge_programs[name] = program
-
         edge_programs[name] = _generate_edge_program(
             config,
             program,

@@ -1169,7 +1210,7 @@ def collect_named_data_store_outputs(


 @et_logger("to_edge_transform_and_lower")
-def to_edge_transform_and_lower(
+def to_edge_transform_and_lower(  # noqa: C901
     programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
     transform_passes: Optional[
         Union[Sequence[PassType], Dict[str, Sequence[PassType]]]

@@ -1234,6 +1275,9 @@ def to_edge_transform_and_lower(
     elif partitioner is None:
         partitioner = {name: [] for name in aten_programs.keys()}

+    can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
+        partitioner, aten_programs
+    )
     edge_manager = _gen_edge_manager_for_partitioners(
         partitioner, aten_programs, config, constant_methods
     )

@@ -1259,7 +1303,8 @@ def to_edge_transform_and_lower(
             curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
                 program
             )
-            curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
+            if not can_skip_using_EDGE_DO_NOT_DECOMP:
+                curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
             ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
             _sanity_check_graph_for_non_decomp_ops(
                 name,
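The whole fast path pivots on the second element of the ops_to_not_decompose contract: each partitioner returns (ops, check_op_fn), and only when every partitioner returns check_op_fn as None can the EDGE_DO_NOT_DECOMP round trip (replace aten ops with transformed ops, re-decompose, restore) be skipped in favor of simply popping the preserved ops from the decomposition table. A sketch of the two contract shapes as read from this diff; both partitioner classes are hypothetical, though the return signature matches the CoreML diff above:

from typing import Callable, List, Optional, Tuple

import torch
from torch.export import ExportedProgram


class UnconditionalPartitioner:
    """Hypothetical backend that preserves ops unconditionally (fast path)."""

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        # check_op_fn is None, so _can_skip_using_EDGE_DO_NOT_DECOMP stays True.
        return [torch.ops.aten.linear.default], None


class FilteringPartitioner:
    """Hypothetical backend whose per-node filter forces the legacy path."""

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        def check_op_fn(node: torch.fx.Node) -> bool:
            # Illustrative per-node criterion; needs the transformed-op round trip.
            return node.target == torch.ops.aten.linear.default

        # A non-None check function disables the skip for every program.
        return [torch.ops.aten.linear.default], check_op_fn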

exir/tests/control_flow_models.py

Lines changed: 3 additions & 9 deletions

@@ -20,9 +20,7 @@ def true_branch(x):
         def false_branch(x):
             return x * x

-        return torch.ops.higher_order.cond(
-            inp.sum() > 4, true_branch, false_branch, [inp]
-        )
+        return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp])

     def get_random_inputs(self):
         return (torch.rand(5),)

@@ -39,9 +37,7 @@ def true_branch(x):
         def false_branch(x):
             return x * x * x

-        return torch.ops.higher_order.cond(
-            inp.sum() > 4, true_branch, false_branch, [inp]
-        )
+        return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp])

     def get_upper_bound_inputs(self):
         return (torch.rand(8),)

@@ -72,9 +68,7 @@ def true_branch(x):
         def false_branch(x):
             return x * 2

-        return torch.ops.higher_order.cond(
-            inp.sum() > 4, true_branch, false_branch, [inp]
-        )
+        return torch.cond(inp.sum() > 4, true_branch, false_branch, [inp])

     def get_random_inputs(self):
         return (torch.eye(5) * 2,)

exir/tests/test_passes.py

Lines changed: 1 addition & 3 deletions

@@ -1463,9 +1463,7 @@ def forward(self, pred, x):
         out = torch.nn.functional.linear(
             x, self.w.to(torch.float16).to(torch.float32)
         )
-        return torch.ops.higher_order.cond(
-            pred, self.true_fn, self.false_fn, [out]
-        )
+        return torch.cond(pred, self.true_fn, self.false_fn, [out])

     mod = Module()
     x = torch.randn([3, 3])
