@@ -73,7 +73,7 @@ def create(
7373 for wf in weight_format_priority_order :
7474 if wf == "pytorch_state_dict" and weights .pytorch_state_dict is not None :
7575 try :
76- from ._pytorch_model_adapter import PytorchModelAdapter
76+ from .pytorch_backend import PytorchModelAdapter
7777
7878 return PytorchModelAdapter (
7979 outputs = model_description .outputs ,
@@ -87,7 +87,7 @@ def create(
8787 and weights .tensorflow_saved_model_bundle is not None
8888 ):
8989 try :
90- from ._tensorflow_model_adapter import TensorflowModelAdapter
90+ from .tensorflow_backend import TensorflowModelAdapter
9191
9292 return TensorflowModelAdapter (
9393 model_description = model_description , devices = devices
@@ -96,7 +96,7 @@ def create(
9696 errors .append ((wf , e ))
9797 elif wf == "onnx" and weights .onnx is not None :
9898 try :
99- from ._onnx_model_adapter import ONNXModelAdapter
99+ from .onnx_backend import ONNXModelAdapter
100100
101101 return ONNXModelAdapter (
102102 model_description = model_description , devices = devices
@@ -105,7 +105,7 @@ def create(
105105 errors .append ((wf , e ))
106106 elif wf == "torchscript" and weights .torchscript is not None :
107107 try :
108- from ._torchscript_model_adapter import TorchscriptModelAdapter
108+ from .torchscript_backend import TorchscriptModelAdapter
109109
110110 return TorchscriptModelAdapter (
111111 model_description = model_description , devices = devices
@@ -117,13 +117,10 @@ def create(
117117 # we try to first import the keras model adapter using the separate package and,
118118 # if it is not available, try to load the one using tf
119119 try :
120- from ._keras import (
121- KerasModelAdapter ,
122- keras , # type: ignore
123- )
124-
125- if keras is None :
126- from ._tensorflow_model_adapter import KerasModelAdapter
120+ try :
121+ from .keras_backend import KerasModelAdapter
122+ except Exception :
123+ from .tensorflow_backend import KerasModelAdapter
127124
128125 return KerasModelAdapter (
129126 model_description = model_description , devices = devices
@@ -134,10 +131,11 @@ def create(
134131 assert errors
135132 if len (weight_format_priority_order ) == 1 :
136133 assert len (errors ) == 1
134+ wf , e = errors [0 ]
137135 raise ValueError (
138- f"The '{ weight_format_priority_order [ 0 ] } ' model adapter could not be created"
139- + f" in this environment:\n { errors [ 0 ][ 1 ] .__class__ .__name__ } ({ errors [ 0 ][ 1 ] } ).\n \n "
140- ) from errors [ 0 ][ 1 ]
136+ f"The '{ wf } ' model adapter could not be created"
137+ + f" in this environment:\n { e .__class__ .__name__ } ({ e } ).\n \n "
138+ ) from e
141139
142140 else :
143141 error_list = "\n - " .join (
@@ -165,13 +163,3 @@ def unload(self):
165163 Unload model from any devices, freeing their memory.
166164 The moder adapter should be considered unusable afterwards.
167165 """
168-
169-
170- def get_weight_formats () -> List [str ]:
171- """
172- Return list of supported weight types
173- """
174- return list (DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER )
175-
176-
177- create_model_adapter = ModelAdapter .create
0 commit comments