114114 "ExportedProgramDeserializer" ,
115115]
116116
117- from .upgrade import GraphModuleOpUpgrader
118117
119118log = logging .getLogger (__name__ )
120119
@@ -2220,12 +2219,8 @@ def deserialize_module_call_graph(
22202219
22212220
22222221class ExportedProgramDeserializer :
2223- def __init__ (self , expected_opset_version : Optional [Dict [str , int ]] = None ):
2224- self .expected_opset_version : Dict [str , int ] = {}
2225- if expected_opset_version :
2226- self .expected_opset_version .update (expected_opset_version )
2227- if "aten" not in self .expected_opset_version :
2228- self .expected_opset_version ["aten" ] = torch ._C ._get_max_operator_version ()
2222+ def __init__ (self ):
2223+ pass
22292224
22302225 def deserialize_range_constraints (
22312226 self ,
@@ -2278,13 +2273,6 @@ def deserialize(
22782273 symbol_name_to_range ,
22792274 res .names_to_symbols ,
22802275 )
2281- model_opset_version : Optional [Dict [str , int ]] = exported_program .opset_version
2282- self ._validate_model_opset_version (model_opset_version )
2283-
2284- upgrader = GraphModuleOpUpgrader (
2285- self .expected_opset_version , model_opset_version
2286- )
2287-
22882276 exported_program = ep .ExportedProgram (
22892277 root = res .graph_module ,
22902278 graph = res .graph_module .graph ,
@@ -2296,56 +2284,7 @@ def deserialize(
22962284 verifier = load_verifier (exported_program .dialect ),
22972285 constants = res .constants ,
22982286 )
2299- return upgrader .upgrade (exported_program )
2300-
2301- def _validate_model_opset_version (
2302- self , model_opset_version : Optional [Dict [str , int ]]
2303- ):
2304- """Compare model_opset_version with expected_opset_version and raise error if we can't resolve the version
2305- difference.
2306- E.g., model_opset_version = {"aten": 3, "custom": 4}
2307- expected_opset_version = {"aten": 4, "custom": 4}
2308- This means we can use an upgrader for ATen to reconcile the deserialized model.
2309-
2310- The logic of this method:
2311-
2312- For common op namespaces:
2313- 1. if model version < expected version, this case can be handled by upgraders.
2314- 2. if model version > expected version, we need downgraders but not implemented yet.
2315- 3. if model version == expected version, we don't need extra handling.
2316-
2317- For op namespace only in model_opset_version, we should give a warning because it is missing from
2318- expected_opset_version.
2319- """
2320- if not model_opset_version :
2321- raise RuntimeError ("Serialized model should have opset version." )
2322- common_namespaces = {
2323- key for key in model_opset_version if key in self .expected_opset_version
2324- }
2325- for namespace in common_namespaces :
2326- model_version = model_opset_version [namespace ]
2327- assert isinstance (
2328- model_version , int
2329- ), f"model_opset_version value should be int, got { model_version } "
2330-
2331- compiler_version = self .expected_opset_version [namespace ]
2332- assert isinstance (
2333- compiler_version , int
2334- ), f"expected_opset_version value should be int, got { compiler_version } "
2335-
2336- # TODO(larryliu0820): Add support for upgrader & downgrader
2337- if model_version != compiler_version :
2338- raise NotImplementedError (
2339- f"Model opset version { model_opset_version } doesn't match to compiler opset version "
2340- f"{ self .expected_opset_version } ! Upgrader/downgrader is not implemented yet."
2341- )
2342- for namespace in model_opset_version :
2343- if namespace in common_namespaces :
2344- continue
2345- log .warning (
2346- "Compiler doesn't have a version table for op namespace: {ns}. " ,
2347- extra = {"ns" : namespace },
2348- )
2287+ return exported_program
23492288
23502289
23512290class EnumEncoder (json .JSONEncoder ):
@@ -2435,15 +2374,14 @@ def _dict_to_dataclass(cls, data):
24352374
24362375def deserialize (
24372376 artifact : SerializedArtifact ,
2438- expected_opset_version : Optional [Dict [str , int ]] = None ,
24392377) -> ep .ExportedProgram :
24402378 assert isinstance (artifact .exported_program , bytes )
24412379 exported_program_str = artifact .exported_program .decode ("utf-8" )
24422380 exported_program_dict = json .loads (exported_program_str )
24432381 serialized_exported_program = _dict_to_dataclass (
24442382 ExportedProgram , exported_program_dict
24452383 )
2446- return ExportedProgramDeserializer (expected_opset_version ).deserialize (
2384+ return ExportedProgramDeserializer ().deserialize (
24472385 serialized_exported_program ,
24482386 artifact .state_dict ,
24492387 artifact .constants ,
0 commit comments