Skip to content

Commit 653883e

Browse files
add safetensor version (#75049)
1 parent 2ba7e79 commit 653883e

File tree

1 file changed

+11
-3
lines changed
  • python/paddle/framework

1 file changed

+11
-3
lines changed

python/paddle/framework/io.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,12 +1231,20 @@ def load(path: str | BytesIO, **configs: Unpack[_LoadOptions]) -> Any:
12311231
load_result = load_file(path)
12321232
load_result = _pack_loaded_dict(load_result)
12331233
else:
1234+
import safetensors
12341235
from safetensors.paddle import load_file
12351236

12361237
if isinstance(_current_expected_place(), core.CUDAPlace):
1237-
load_result = load_file(
1238-
path, device=_current_expected_place()
1239-
)
1238+
if (
1239+
safetensors.__version__ > "0.6.2"
1240+
and paddle.__version__ >= "3.2.0"
1241+
):
1242+
load_result = load_file(path, device='cuda')
1243+
else:
1244+
load_result = load_file(
1245+
path, device=_current_expected_place()
1246+
)
1247+
12401248
else:
12411249
load_result = load_file(path, device='cpu')
12421250

0 commit comments

Comments
 (0)