@@ -214,8 +214,8 @@ def gen_metadata_and_prepare_source_state_dict(self):
         assert self.model_meta is not None
         global_model_state_shapes = []
         sharding_metas_keys = []
-        for i in range(self.pp_degree):
-            for j in range(self.tp_degree):
+        for i in range(self.tp_degree):
+            for j in range(self.pp_degree):
                 sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j))
         for key in sharding_metas_keys:
             param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
@@ -247,24 +247,25 @@ def gen_metadata_and_prepare_source_state_dict(self):
         # Generate the optimizer states corresponding to the model weights.
         logger.info("Requesting GPU memory space to concatenate tensors split by sharding1 v2.")
         optimizer_state_dict = {}
-        for key in cur_rank_need_load_model_state_keys:
-            for tp_rank in range(self.tp_degree):
-                tp_rank_suffix = "_tp{:02d}".format(tp_rank)
-                optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros(
-                    (param_flattened_shapes[key],), "float32"
-                )
-                optimizer_state_dict[key + ".moment2" + tp_rank_suffix] = paddle.zeros(
-                    (param_flattened_shapes[key],), "float32"
-                )
-                if self.optimizer_state_with_master_weights:
-                    optimizer_state_dict[key + ".master_weight" + tp_rank_suffix] = paddle.zeros(
+        with paddle.base.dygraph.guard(place=paddle.CPUPlace()):
+            for key in cur_rank_need_load_model_state_keys:
+                for tp_rank in range(self.tp_degree):
+                    tp_rank_suffix = "_tp{:02d}".format(tp_rank)
+                    optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros(
                         (param_flattened_shapes[key],), "float32"
                     )
-                # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned.
-                # Later, when these are compared with the global shape, we realize that they are replicated.
+                    optimizer_state_dict[key + ".moment2" + tp_rank_suffix] = paddle.zeros(
+                        (param_flattened_shapes[key],), "float32"
+                    )
+                    if self.optimizer_state_with_master_weights:
+                        optimizer_state_dict[key + ".master_weight" + tp_rank_suffix] = paddle.zeros(
+                            (param_flattened_shapes[key],), "float32"
+                        )
+                    # When handling tensor parallelism (TP), if some tensors are replicated, we initially assume that they are partitioned.
+                    # Later, when these are compared with the global shape, we realize that they are replicated.

-                optimizer_state_dict[key + ".beta1_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32")
-                optimizer_state_dict[key + ".beta2_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32")
+                    optimizer_state_dict[key + ".beta1_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32")
+                    optimizer_state_dict[key + ".beta2_pow_acc" + tp_rank_suffix] = paddle.zeros((1,), "float32")

         malloc_size = 0
         for opt_state_name, opt_state_value in optimizer_state_dict.items():
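The first hunk swaps the loop order so that the outer index `i` fills the `tp` slot and the inner index `j` fills the `pp` slot of the `"tp{:02d}_pp{:02d}"` key. A minimal standalone sketch (not part of the patch, using hypothetical degrees `tp_degree=2`, `pp_degree=4`) of the keys this enumeration produces:

```python
# Standalone sketch: enumerate sharding metadata keys the way the patched loop does.
# tp_degree and pp_degree are hypothetical example values, not taken from the PR.
tp_degree, pp_degree = 2, 4

sharding_metas_keys = []
for i in range(tp_degree):      # outer loop: tensor-parallel rank -> first format slot
    for j in range(pp_degree):  # inner loop: pipeline-parallel rank -> second format slot
        sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j))

print(sharding_metas_keys)
# ['tp00_pp00', 'tp00_pp01', 'tp00_pp02', 'tp00_pp03',
#  'tp01_pp00', 'tp01_pp01', 'tp01_pp02', 'tp01_pp03']
```

With the previous order (`pp_degree` in the outer loop), `i` would range over pipeline ranks while still being placed in the `tp` slot, so whenever the two degrees differ the generated keys would presumably not match the `tpXX_ppYY` entries stored in `sharding_metas`.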
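The second hunk wraps the buffer allocation in `paddle.base.dygraph.guard(place=paddle.CPUPlace())`, so the flattened moment and master-weight buffers are created in host memory rather than on the default GPU device. A minimal standalone sketch (not part of the patch, assuming a Paddle build where `paddle.base.dygraph.guard` is available, i.e. the same API the patch uses) of how the guard affects where `paddle.zeros` allocates:

```python
# Standalone sketch: paddle.zeros called inside the guard allocates on the
# place passed to the guard (here CPU), regardless of the default device.
import paddle

with paddle.base.dygraph.guard(place=paddle.CPUPlace()):
    buf = paddle.zeros((1024,), "float32")

print(buf.place)  # Place(cpu)
```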