File tree Expand file tree Collapse file tree 1 file changed +11
-3
lines changed Expand file tree Collapse file tree 1 file changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -1231,12 +1231,20 @@ def load(path: str | BytesIO, **configs: Unpack[_LoadOptions]) -> Any:
1231
1231
load_result = load_file (path )
1232
1232
load_result = _pack_loaded_dict (load_result )
1233
1233
else :
1234
+ import safetensors
1234
1235
from safetensors .paddle import load_file
1235
1236
1236
1237
if isinstance (_current_expected_place (), core .CUDAPlace ):
1237
- load_result = load_file (
1238
- path , device = _current_expected_place ()
1239
- )
1238
+ if (
1239
+ safetensors .__version__ > "0.6.2"
1240
+ and paddle .__version__ >= "3.2.0"
1241
+ ):
1242
+ load_result = load_file (path , device = 'cuda' )
1243
+ else :
1244
+ load_result = load_file (
1245
+ path , device = _current_expected_place ()
1246
+ )
1247
+
1240
1248
else :
1241
1249
load_result = load_file (path , device = 'cpu' )
1242
1250
You can’t perform that action at this time.
0 commit comments