Commit 1b8f37f

[AutoParallel] Fix ckpt convert bug for sharding v2 (#9179)
1 parent 1ce2642 commit 1b8f37f

File tree

1 file changed: +18 -17 lines changed


paddlenlp/trainer/utils/ckpt_converter.py

Lines changed: 18 additions & 17 deletions
@@ -214,8 +214,8 @@ def gen_metadata_and_prepare_source_state_dict(self):
             assert self.model_meta is not None
             global_model_state_shapes = []
             sharding_metas_keys = []
-            for i in range(self.pp_degree):
-                for j in range(self.tp_degree):
+            for i in range(self.tp_degree):
+                for j in range(self.pp_degree):
                     sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j))
             for key in sharding_metas_keys:
                 param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
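
For context, a minimal sketch (not part of the commit; the degrees below are illustrative) of why the loop order matters: entries in sharding_metas are keyed as "tp{:02d}_pp{:02d}", so the outer index must range over tp_degree and the inner one over pp_degree; with unequal degrees the old order produced keys that do not exist.

# Minimal sketch (not from the repo): key generation with illustrative degrees.
tp_degree, pp_degree = 2, 4

# Old (buggy) order: the tp slot is filled from range(pp_degree) and vice versa.
buggy_keys = ["tp{:02d}_pp{:02d}".format(i, j) for i in range(pp_degree) for j in range(tp_degree)]
# Produces e.g. "tp03_pp01", which cannot exist when tp_degree == 2.

# Fixed order: the tp slot comes from tp_degree, the pp slot from pp_degree.
fixed_keys = ["tp{:02d}_pp{:02d}".format(i, j) for i in range(tp_degree) for j in range(pp_degree)]
# Produces "tp00_pp00" ... "tp01_pp03", matching the keys stored in sharding_metas.
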
@@ -247,24 +247,25 @@ def gen_metadata_and_prepare_source_state_dict(self):
             # Generate the optimizer states corresponding to the model weights.
             logger.info("Requesting GPU memory space to concatenate tensors split by sharding1 v2.")
             optimizer_state_dict = {}
-            for key in cur_rank_need_load_model_state_keys:
-                for tp_rank in range(self.tp_degree):
-                    tp_rank_suffix = "_tp{:02d}".format(tp_rank)
-                    optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros(
-                        (param_flattened_shapes[key],), "float32"
-                    )
-                    optimizer_state_dict[key + ".moment2" + tp_rank_suffix] = paddle.zeros(
-                        (param_flattened_shapes[key],), "float32"
-                    )
-                    if self.optimizer_state_with_master_weights:
-                        optimizer_state_dict[key + ".master_weight" + tp_rank_suffix] = paddle.zeros(
+            with paddle.base.dygraph.guard(place=paddle.CPUPlace()):
+                for key in cur_rank_need_load_model_state_keys:
+                    for tp_rank in range(self.tp_degree):
+                        tp_rank_suffix = "_tp{:02d}".format(tp_rank)
+                        optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros(
                             (param_flattened_shapes[key],), "float32"
                         )
-                    # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned.
-                    # Later, when these are compared with the global shape, we realize that they are replicated.
+                        optimizer_state_dict[key + ".moment2" + tp_rank_suffix] = paddle.zeros(
+                            (param_flattened_shapes[key],), "float32"
+                        )
+                        if self.optimizer_state_with_master_weights:
+                            optimizer_state_dict[key + ".master_weight" + tp_rank_suffix] = paddle.zeros(
+                                (param_flattened_shapes[key],), "float32"
+                            )
+                        # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned.
+                        # Later, when these are compared with the global shape, we realize that they are replicated.

-                    optimizer_state_dict[key + ".beta1_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32")
-                    optimizer_state_dict[key + ".beta2_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32")
+                        optimizer_state_dict[key + ".beta1_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32")
+                        optimizer_state_dict[key + ".beta2_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32")

             malloc_size = 0
             for opt_state_name, opt_state_value in optimizer_state_dict.items():
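
A minimal sketch (illustrative shape, assuming a working Paddle install) of what the added guard changes: tensors created under paddle.base.dygraph.guard(place=paddle.CPUPlace()) are allocated in host memory, so the flattened moment and master-weight buffers no longer occupy GPU memory while the sharded tensors are being reassembled.

import paddle

# Minimal sketch: the guard switches the default place, so paddle.zeros below
# allocates on the CPU instead of the current GPU device.
with paddle.base.dygraph.guard(place=paddle.CPUPlace()):
    moment1 = paddle.zeros((1024,), "float32")  # illustrative flattened shape

print(moment1.place)  # expected to report a CPU place rather than a GPU place
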
