Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)
from executorch.exir.passes import ToOutVarPass
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
from executorch.exir.program._program import to_edge_with_preserved_ops
from executorch.exir.program._program import to_edge
from torch._inductor.decomposition import remove_decompositions

from torch.export.exported_program import ExportedProgram
Expand Down Expand Up @@ -219,9 +219,9 @@ def quantize_pt2(
torch.ops.aten.angle.default,
torch.ops.aten.rms_norm.default,
]
TO_EDGE_PRESERVE_OPS: tuple[torch._ops.OpOverload, ...] = (
TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload, ...] = [
torch.ops.aten.rms_norm.default,
)
]


def _lower_ep_to_edge(
Expand All @@ -233,18 +233,18 @@ def _lower_ep_to_edge(
"""
Lower an ExportedProgram to an EdgeProgramManager (in edge IR).
"""
# Call to_edge_with_preserved_ops to convert the graph to edge IR.
# Call to_edge to convert the graph to edge IR.
# Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
edge_prog_manager = to_edge_with_preserved_ops(
edge_prog_manager = to_edge(
expo_program,
compile_config=EdgeCompileConfig(
_skip_dim_order=True,
# Allow specific non-core aten ops in the IR.
_core_aten_ops_exception_list=TO_EDGE_OP_EXCEPTION_LIST
+ (core_aten_exceptions or []),
preserve_ops=TO_EDGE_PRESERVE_OPS,
),
constant_methods=constant_methods,
preserve_ops=TO_EDGE_PRESERVE_OPS,
)

if dump_graphs:
Expand Down
16 changes: 8 additions & 8 deletions examples/apple/coreml/llama/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.program._program import to_edge_with_preserved_ops
from executorch.exir.program._program import to_edge
from executorch.extension.export_util.utils import save_pte_program


Expand Down Expand Up @@ -196,17 +196,17 @@ def main() -> None:
print("Exported program")
print(ep)

edge_manager = to_edge_with_preserved_ops(
edge_manager = to_edge(
ep,
preserve_ops=[
torch.ops.aten.scaled_dot_product_attention.default,
# preserve norm op for numerical stability
torch.ops.aten.linalg_vector_norm.default,
torch.ops.aten.reciprocal.default,
],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True,
preserve_ops=[
torch.ops.aten.scaled_dot_product_attention.default,
# preserve norm op for numerical stability
torch.ops.aten.linalg_vector_norm.default,
torch.ops.aten.reciprocal.default,
],
),
)
print("Edge program")
Expand Down
6 changes: 3 additions & 3 deletions exir/capture/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class EdgeCompileConfig:
_check_ir_validity: bool = True
# TODO(larryliu): remove this
_use_edge_ops: bool = True
# TODO(gasoonjia): remove this
_skip_dim_order: bool = False
# Allow core ATen ops check to be skipped for certain ops, but continue with the rest of the checks.
# Note: only use this for core ATen ops that are missing decompositions. This is temporary,
# enabling verification on the rest of the program until decomposition coverage is improved.
Expand All @@ -47,9 +49,7 @@ class EdgeCompileConfig:
)
# Allow ops to be preserved in the graph, i.e., prevent them from being decomposed.
# These may be core or non-core ATen ops; custom ops should not be here.
_preserve_ops: List[torch.torch._ops.OpOverload] = field(default_factory=list)
# TODO(gasoonjia): remove this
_skip_dim_order: bool = False
preserve_ops: List[torch.torch._ops.OpOverload] = field(default_factory=list)


@compatibility(is_backward_compatible=False)
Expand Down
4 changes: 2 additions & 2 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1382,8 +1382,8 @@ def to_edge(
table = _default_decomposition_table()
preserve_ops = []
if compile_config:
preserve_ops = compile_config._preserve_ops
for op in compile_config._preserve_ops:
preserve_ops = compile_config.preserve_ops
for op in compile_config.preserve_ops:
table.pop(op, None)
program = program.run_decompositions(table)
edge_programs[name] = _generate_edge_program(
Expand Down
2 changes: 1 addition & 1 deletion exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ def _test_to_edge_with_preserved_ops(
self, program, preserved_ops, expected_preserved_ops
):
edge = to_edge(
program, compile_config=EdgeCompileConfig(_preserve_ops=preserved_ops)
program, compile_config=EdgeCompileConfig(preserve_ops=preserved_ops)
)

def count_nodes(graph_module, target):
Expand Down
2 changes: 1 addition & 1 deletion exir/verification/test/test_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def forward(self, x):
return x.expand(2, 2, 2, 2)

model = TestExpand()
config = EdgeCompileConfig(_preserve_ops=[torch.ops.aten.expand.default])
config = EdgeCompileConfig(preserve_ops=[torch.ops.aten.expand.default])
export_model = export(model, (torch.randn(2, 2, 2, 2),), strict=True)
with self.assertRaises(RuntimeError):
to_edge(export_model, compile_config=config)
10 changes: 5 additions & 5 deletions exir/verification/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def EXIRATenDialectVerifier( # noqa: C901
_core_aten_ops_exception_list.extend(
edge_compile_config._core_aten_ops_exception_list
)
if edge_compile_config._preserve_ops:
_preserve_ops.extend(edge_compile_config._preserve_ops)
if edge_compile_config.preserve_ops:
_preserve_ops.extend(edge_compile_config.preserve_ops)

class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
dialect = "OLD_EXIR_ATEN"
Expand Down Expand Up @@ -181,7 +181,7 @@ def get_aten_verifier(config: EdgeCompileConfig):
EXIRATenDialectVerifier(
class_only=True,
core_aten_ops_exception_list=config._core_aten_ops_exception_list,
preserve_ops=config._preserve_ops,
preserve_ops=config.preserve_ops,
)
if config._check_ir_validity
else EXIRATenDialectVerifierBase
Expand Down Expand Up @@ -253,8 +253,8 @@ def EXIREdgeDialectVerifier( # noqa: C901
_core_aten_ops_exception_list.extend(
edge_compile_config._core_aten_ops_exception_list
)
if edge_compile_config._preserve_ops:
_preserve_ops.extend(edge_compile_config._preserve_ops)
if edge_compile_config.preserve_ops:
_preserve_ops.extend(edge_compile_config.preserve_ops)

class _EXIREdgeDialectVerifier(Verifier):
dialect = "EDGE"
Expand Down
Loading