File tree Expand file tree Collapse file tree 1 file changed +11
-3
lines changed Expand file tree Collapse file tree 1 file changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -1231,12 +1231,20 @@ def load(path: str | BytesIO, **configs: Unpack[_LoadOptions]) -> Any:
12311231 load_result = load_file (path )
12321232 load_result = _pack_loaded_dict (load_result )
12331233 else :
1234+ import safetensors
12341235 from safetensors .paddle import load_file
12351236
12361237 if isinstance (_current_expected_place (), core .CUDAPlace ):
1237- load_result = load_file (
1238- path , device = _current_expected_place ()
1239- )
1238+ if (
1239+ safetensors .__version__ > "0.6.2"
1240+ and paddle .__version__ >= "3.2.0"
1241+ ):
1242+ load_result = load_file (path , device = 'cuda' )
1243+ else :
1244+ load_result = load_file (
1245+ path , device = _current_expected_place ()
1246+ )
1247+
12401248 else :
12411249 load_result = load_file (path , device = 'cpu' )
12421250
You can’t perform that action at this time.
0 commit comments