@@ -1799,34 +1799,43 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
17991799 else :
18001800 self ._parse_devices ()
18011801
1802- if isinstance (file , str ) and file .endswith (".safetensors" ):
1803- from safetensors .torch import load_file
1804- state_dict = load_file (file , device = device )
1805- self .model .load_state_dict (state_dict )
1802+ if isinstance (file , str ):
1803+ if file .endswith (".safetensors" ):
1804+ from safetensors .torch import load_file
1805+ state_dict = load_file (file , device = device )
1806+ self .model .load_state_dict (state_dict )
18061807
1807- self .model_to_device ()
1808+ self .model_to_device ()
18081809
1809- if make_infrastructure :
1810- self .make_infrastructure ()
1810+ if make_infrastructure :
1811+ self .make_infrastructure ()
18111812
1812- self .set_model_mode (mode )
1813+ self .set_model_mode (mode )
18131814
1814- return
1815- elif isinstance ( file , str ) and file .endswith (".onnx" ):
1816- try :
1817- from onnx2torch import convert
1818- except ImportError as e :
1819- raise ImportError ('Loading model, stored in ONNX format, requires `onnx2torch` library.' ) from e
1815+ return
1816+ elif file .endswith (".onnx" ):
1817+ try :
1818+ from onnx2torch import convert
1819+ except ImportError as e :
1820+ raise ImportError ('Loading model, stored in ONNX format, requires `onnx2torch` library.' ) from e
18201821
1821- model = convert (file ).eval ()
1822- self .model = model
1822+ model = convert (file ).eval ()
1823+ self .model = model
18231824
1824- self .model_to_device ()
1825+ self .model_to_device ()
18251826
1826- if make_infrastructure :
1827- self .make_infrastructure ()
1827+ if make_infrastructure :
1828+ self .make_infrastructure ()
18281829
1829- self .set_model_mode (mode )
1830+ self .set_model_mode (mode )
1831+ elif file .endswith (".openvino" ):
1832+ model = OVModel (model_path = file , ** model_load_kwargs )
1833+ self .model = model
1834+
1835+ self ._loaded_from_openvino = True
1836+ self .disable_training = True
1837+
1838+ return
18301839
18311840 kwargs ['map_location' ] = self .device
18321841
@@ -1863,6 +1872,11 @@ def load(self, file, make_infrastructure=False, mode='eval', pickle_module=dill,
18631872 self ._loaded_from_onnx = True
18641873 self .disable_training = True
18651874
1875+ if "safetensors" in checkpoint :
1876+ from safetensors .torch import load_file
1877+ state_dict = load_file (checkpoint ['path_safetensors' ], device = device )
1878+ self .model .load_state_dict (state_dict )
1879+
18661880 self .model_to_device ()
18671881
18681882 if make_infrastructure :
0 commit comments