Skip to content

Commit 1ef5b94

Browse files
authored
[TensorParallel] Support naive split for lazy safetensors (#7018)
1 parent 6e6301e commit 1ef5b94

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

paddlenlp/transformers/conversion_utils.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -299,6 +299,32 @@ def naive_fuse_split_tp(
299299
300300
"""
301301
axis = -1 if is_column else 0
302+
if "PySafeSlice" in str(type(weight)):
303+
size = weight.get_shape()[axis]
304+
block_size = size // (fuse_tensor_parts * tensor_parallel_degree)
305+
306+
splited = []
307+
if tensor_parallel_rank is None:
308+
begin, end, step = 0, fuse_tensor_parts * tensor_parallel_degree, 1
309+
else:
310+
begin, end, step = tensor_parallel_rank, fuse_tensor_parts * tensor_parallel_degree, tensor_parallel_degree
311+
for rank in range(begin, end, step):
312+
start = rank * block_size
313+
stop = (rank + 1) * block_size
314+
if axis == 0 or len(weight.get_shape()) == 1:
315+
tensor = weight[start:stop]
316+
else:
317+
tensor = weight[:, start:stop]
318+
splited.append(tensor)
319+
320+
if tensor_parallel_rank is None:
321+
ret = []
322+
for tensor_parallel_rank in range(tensor_parallel_degree):
323+
ret.append(np.concatenate(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis))
324+
return ret
325+
326+
return np.concatenate(splited, axis=axis)
327+
302328
splited = np.split(weight, fuse_tensor_parts * tensor_parallel_degree, axis=axis)
303329

304330
if tensor_parallel_rank is None:

paddlenlp/transformers/model_utils.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1955,8 +1955,14 @@ def from_pretrained(
19551955
)
19561956
else:
19571957
# 4. loading non-sharded ckpt from the state dict
1958-
if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model_state.pdparams"):
1959-
state_dict = cls.convert_tensor_parallel(resolved_archive_file, config)
1958+
if config.tensor_parallel_degree > 1:
1959+
if resolved_archive_file.endswith("model_state.pdparams"):
1960+
state_dict = cls.convert_tensor_parallel(resolved_archive_file, config)
1961+
elif resolved_archive_file.endswith("model.safetensors"):
1962+
with safe_open(resolved_archive_file, framework="np", device="cpu") as f:
1963+
loaded_keys = f.keys()
1964+
tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
1965+
state_dict = load_state_dict(resolved_archive_file, tp_actions)
19601966
else:
19611967
state_dict = load_state_dict(resolved_archive_file)
19621968

0 commit comments

Comments
 (0)