diff --git a/exir/dialects/_ops.py b/exir/dialects/_ops.py index ec4d71395bf..fc25158e07c 100644 --- a/exir/dialects/_ops.py +++ b/exir/dialects/_ops.py @@ -100,7 +100,7 @@ def __getattr__(self, op_name): parent_packet = getattr(self._op_namespace, op_name) except AttributeError as e: # Turn this into AttributeError so getattr(obj, key, default) - # works (this is called by TorchScript with __origin__) + # works raise AttributeError( f"'_OpNamespace' '{self._dialect}.{self._name}' object has no attribute '{op_name}'" ) from e diff --git a/exir/serde/TARGETS b/exir/serde/TARGETS index 7bede435359..ec3db22aac6 100644 --- a/exir/serde/TARGETS +++ b/exir/serde/TARGETS @@ -12,7 +12,6 @@ python_library( "schema_check.py", "serialize.py", "union.py", - "upgrade.py", ], deps = [ "fbsource//third-party/pypi/sympy:sympy", diff --git a/exir/serde/export_serialize.py b/exir/serde/export_serialize.py index 08cd03adcea..7a1d35c432e 100644 --- a/exir/serde/export_serialize.py +++ b/exir/serde/export_serialize.py @@ -114,7 +114,6 @@ "ExportedProgramDeserializer", ] -from .upgrade import GraphModuleOpUpgrader log = logging.getLogger(__name__) @@ -2220,12 +2219,8 @@ def deserialize_module_call_graph( class ExportedProgramDeserializer: - def __init__(self, expected_opset_version: Optional[Dict[str, int]] = None): - self.expected_opset_version: Dict[str, int] = {} - if expected_opset_version: - self.expected_opset_version.update(expected_opset_version) - if "aten" not in self.expected_opset_version: - self.expected_opset_version["aten"] = torch._C._get_max_operator_version() + def __init__(self): + pass def deserialize_range_constraints( self, @@ -2278,13 +2273,6 @@ def deserialize( symbol_name_to_range, res.names_to_symbols, ) - model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version - self._validate_model_opset_version(model_opset_version) - - upgrader = GraphModuleOpUpgrader( - self.expected_opset_version, model_opset_version - ) - exported_program = ep.ExportedProgram( root=res.graph_module, graph=res.graph_module.graph, @@ -2296,56 +2284,7 @@ def deserialize( verifier=load_verifier(exported_program.dialect), constants=res.constants, ) - return upgrader.upgrade(exported_program) - - def _validate_model_opset_version( - self, model_opset_version: Optional[Dict[str, int]] - ): - """Compare model_opset_version with expected_opset_version and raise error if we can't resolve the version - difference. - E.g., model_opset_version = {"aten": 3, "custom": 4} - expected_opset_version = {"aten": 4, "custom": 4} - This means we can use an upgrader for ATen to reconcile the deserialized model. - - The logic of this method: - - For common op namespaces: - 1. if model version < expected version, this case can be handled by upgraders. - 2. if model version > expected version, we need downgraders but not implemented yet. - 3. if model version == expected version, we don't need extra handling. - - For op namespace only in model_opset_version, we should give a warning because it is missing from - expected_opset_version. - """ - if not model_opset_version: - raise RuntimeError("Serialized model should have opset version.") - common_namespaces = { - key for key in model_opset_version if key in self.expected_opset_version - } - for namespace in common_namespaces: - model_version = model_opset_version[namespace] - assert isinstance( - model_version, int - ), f"model_opset_version value should be int, got {model_version}" - - compiler_version = self.expected_opset_version[namespace] - assert isinstance( - compiler_version, int - ), f"expected_opset_version value should be int, got {compiler_version}" - - # TODO(larryliu0820): Add support for upgrader & downgrader - if model_version != compiler_version: - raise NotImplementedError( - f"Model opset version {model_opset_version} doesn't match to compiler opset version " - f"{self.expected_opset_version}! Upgrader/downgrader is not implemented yet." - ) - for namespace in model_opset_version: - if namespace in common_namespaces: - continue - log.warning( - "Compiler doesn't have a version table for op namespace: {ns}. ", - extra={"ns": namespace}, - ) + return exported_program class EnumEncoder(json.JSONEncoder): @@ -2435,7 +2374,6 @@ def _dict_to_dataclass(cls, data): def deserialize( artifact: SerializedArtifact, - expected_opset_version: Optional[Dict[str, int]] = None, ) -> ep.ExportedProgram: assert isinstance(artifact.exported_program, bytes) exported_program_str = artifact.exported_program.decode("utf-8") @@ -2443,7 +2381,7 @@ def deserialize( serialized_exported_program = _dict_to_dataclass( ExportedProgram, exported_program_dict ) - return ExportedProgramDeserializer(expected_opset_version).deserialize( + return ExportedProgramDeserializer().deserialize( serialized_exported_program, artifact.state_dict, artifact.constants, diff --git a/exir/serde/serialize.py b/exir/serde/serialize.py index c9605018c4a..b587813c72c 100644 --- a/exir/serde/serialize.py +++ b/exir/serde/serialize.py @@ -32,7 +32,7 @@ from executorch.exir.lowered_backend_module import ( LoweredBackendModule as ExirLoweredBackendModule, ) -from executorch.exir.serde.export_serialize import GraphModuleOpUpgrader, SerializeError +from executorch.exir.serde.export_serialize import SerializeError from executorch.exir.serde.schema import ( CompileSpec, LoweredBackendModule as SerdeLoweredBackendModule, @@ -617,12 +617,6 @@ def deserialize( symbol_name_to_range, res.names_to_symbols, ) - model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version - self._validate_model_opset_version(model_opset_version) - - upgrader = GraphModuleOpUpgrader( - self.expected_opset_version, model_opset_version - ) dummy_g = torch.fx.Graph() dummy_g.output(()) @@ -656,7 +650,7 @@ def deserialize( node.target, getattr(res.graph_module, node.target), ) - return upgrader.upgrade(exported_program) + return exported_program def serialize( @@ -683,7 +677,6 @@ def serialize( def deserialize( artifact: export_serialize.SerializedArtifact, - expected_opset_version: Optional[Dict[str, int]] = None, ) -> ep.ExportedProgram: assert isinstance(artifact.exported_program, bytes) exported_program_str = artifact.exported_program.decode("utf-8") @@ -691,7 +684,7 @@ def deserialize( serialized_exported_program = export_serialize._dict_to_dataclass( schema.ExportedProgram, exported_program_dict ) - return ExportedProgramDeserializer(expected_opset_version).deserialize( + return ExportedProgramDeserializer().deserialize( serialized_exported_program, artifact.state_dict, artifact.constants, @@ -735,7 +728,6 @@ def load( f: Union[str, os.PathLike[str], io.BytesIO], *, extra_files: Optional[Dict[str, Any]] = None, - expected_opset_version: Optional[Dict[str, int]] = None, ) -> ep.ExportedProgram: if isinstance(f, (str, os.PathLike)): f = os.fspath(str(f)) @@ -796,6 +788,6 @@ def load( ) # Deserialize ExportedProgram - ep = deserialize(artifact, expected_opset_version) + ep = deserialize(artifact) return ep diff --git a/exir/serde/upgrade.py b/exir/serde/upgrade.py deleted file mode 100644 index 5f0ffbf6818..00000000000 --- a/exir/serde/upgrade.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import logging -import re -from collections import defaultdict -from typing import Dict, List, Optional, Tuple - -import torch -from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse -from torch._export.pass_infra.node_metadata import NodeMetadata -from torch._export.pass_infra.proxy_value import ProxyValue -from torch.fx.node import Argument, Target -from torch.library import Library - -lib = Library("aten", "FRAGMENT") -impl_lib = Library("aten", "IMPL") - -log = logging.getLogger(__name__) - - -def get_target_version(versioned_upgrader_name: str) -> int: - """div_Scalar_0_3 is the name of the upgrader, meaning it applies to div.Scalar of version 0 to 3 and is - upgrading to version 4.""" - if not re.match("^.*_[0-9]+_[0-9]+$", versioned_upgrader_name): - raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid") - - return int(versioned_upgrader_name.split("_")[-1]) + 1 - - -def get_upgraders() -> Dict[str, Tuple[str, str]]: - """Getting upgraders entry map and operator version map and merge them into one dict.""" - upgraders = torch._C._get_upgraders_entry_map() - op_version_map = torch._C._get_operator_version_map() - output: Dict[str, Tuple[str, str]] = defaultdict(tuple) # type: ignore[arg-type] - for opname, entry_list in op_version_map.items(): - if not entry_list: - raise RuntimeError(f"Op version map has an empty entry for opname {opname}") - entry = entry_list[0] - old_schema = entry.old_schema - upgrader_name = entry.upgrader_name - upgrader_str = upgraders.get(upgrader_name, None) - if not upgrader_str: - raise RuntimeError( - f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}" - ) - output[upgrader_name] = (old_schema, upgrader_str) - return output - - -class GraphModuleOpUpgrader: - """This upgrader is able to upgrade the old version of ops in a given GraphModule, if all upgraders are available. - To use it, retrieve upgraders from somewhere (TorchScript API or new API) and pass it into this upgrader. In - __init__() it does the following: - 1. parse the upgrader list and reorder for upgrading purpose. - 2. register old versions of operators as custom ops. - 3. prepare upgrader passes. - - In `upgrade()` API run these upgrader passes. - - An example of op_upgraders input: - { - "aten::div__Scalar_0_3": ( # versioned op name - "div._Scalar(self: Tensor, other: Scalar)", # old schema - ''' - def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor: # upgrader in literal string - if (self.is_floating_point() or isinstance(other, float)): - return self.true_divide_(other) - return self.divide_(other, rounding_mode='trunc') - ''', - ), - }, - - Note that we require the upgrader function to be runnable in Python (which is a stricter requirement than the - original TorchScript upgrader). - """ - - class UpgraderPass(_ExportPassBaseDeprecatedDoNotUse): - def __init__(self, old_target: Target, new_target: Target): - super().__init__() - self.old_target = old_target - self.new_target = new_target - - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op == self.old_target: - return super().call_operator(self.new_target, args, kwargs, meta) - return super().call_operator(op, args, kwargs, meta) - - def __init__( - self, - compiler_opset_version: Optional[Dict[str, int]] = None, - model_opset_version: Optional[Dict[str, int]] = None, - op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None, - ): - self.op_upgraders: Dict[str, Tuple[str, str]] = ( - get_upgraders() if not op_upgraders else op_upgraders - ) - self.compiler_opset_version = ( - compiler_opset_version if compiler_opset_version else {} - ) - self.model_opset_version = model_opset_version if model_opset_version else {} - self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = ( - GraphModuleOpUpgrader._populate_passes( - self._parse_upgraders(self.op_upgraders) - ) - ) - - def _parse_upgraders( - self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None - ) -> List[Tuple[str, str]]: - """Reorder op_upgraders by version number, return an ordered list of tuples, containing old op schema as well - as the upgrader function string literal.""" - # TODO(larryliu0820): Add support for custom ops - op_namespace = "aten" - if ( - not op_upgraders - or op_namespace not in self.model_opset_version - or op_namespace not in self.compiler_opset_version - ): - return [] - model_ver = self.model_opset_version[op_namespace] - curr_ver = self.compiler_opset_version[op_namespace] - - # key is the target version. div__Scalar_0_3 should have a key of 4. - versioned_upgraders: Dict[int, Tuple[str, str]] = { - get_target_version(name): v for name, v in op_upgraders.items() - } - target_upgraders: List[Tuple[str, str]] = [] - # we need all upgraders from model_ver + 1 to curr_ver, inclusively - for ver in range(model_ver + 1, curr_ver + 1): - if ver in versioned_upgraders: - target_upgraders.append(versioned_upgraders[ver]) - else: - # we may be able to get away with missing upgraders, if that operator is missing from given graph - # module. - log.warning( - "Missing an upgrader to upgrade to version {ver}.", - extra={"ver": ver}, - ) - - return target_upgraders - - @staticmethod - def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]: - """Given a list of upgraders, loop through it from lower version to higher version and create passes for all - upgraders. se torch.Library API to register old ops. Op name will be - __. Register upgraders as CompositeImplicitAutograd kernels. For example: - - lib = Library("aten", "FRAGMENT") - lib.define(old_schema) - - impl_lib = Library("aten", "IMPL") - impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd") - - @:var upgraders: a list of tuples. The first element of the tuple is the old schema and the second is the - upgrader function literal text. - @:return upgrader passes, order matters - """ - - upgrader_passes = [] - - def register_old_op(name: str, schema: str, impl_str: str): - """Registers an old version operator using impl_name as old op name.""" - lib.define(schema) - try: - exec(impl_str) - except Exception as e: - raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e - impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd") - - for schema, upgrader_str in upgraders: - upgrader_name = upgrader_str.split("(")[0].split(" ")[-1] - op_name = schema.split("(")[0].split("::")[-1] - schema = schema.replace(op_name, upgrader_name) - try: - register_old_op( - name=upgrader_name, schema=schema, impl_str=upgrader_str - ) - except RuntimeError as e: - if "with the same name and overload name multiple times" in str(e): - print(f"Registering {upgrader_name} multiple times") - else: - raise RuntimeError from e - old_op_target = getattr(torch.ops.aten, upgrader_name).default - # for example, the operator instance of "aten::div" is torch.op.aten.div.default. We need to append the - # "default" at the end. - op_name, overload_name = ( - (op_name, "default") - if "." not in op_name - else tuple(op_name.split(".")[:2]) - ) - new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name) - # Note that the graph will have op names in the graph, but actually they are of old versions. - upgrader_passes.append( - GraphModuleOpUpgrader.UpgraderPass( - old_target=new_op_target, new_target=old_op_target - ) - ) - - return upgrader_passes - - def upgrade(self, exported_program): - return exported_program