@@ -22,6 +22,8 @@
 )
 from paddle.distributed.fleet.utils.log_util import logger
 
+from paddleformers.utils.tools import get_env_device
+
 try:
     from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
         DygraphShardingOptimizerV2,
@@ -61,6 +63,49 @@ def get_sharding_strategy(optimizer):
     return SHARDING_STRATEGY_V1
 
 
+def convert_opt_name_to_tname(tensor_names, opt_names):
+    tensor_names = set(tensor_names)
+    all_names = []
+    all_names.extend(list(tensor_names))
+    all_names.extend(opt_names)
+    all_names.sort()
+    pre_t_name = ""
+    suffix = [
+        "_fp32_master_0_beta1_pow_acc_0",
+        "_fp32_master_0_beta2_pow_acc_0",
+        "_fp32_master_0_moment1_0",
+        "_fp32_master_0_moment2_0",
+        "_beta1_pow_acc_0",
+        "_beta2_pow_acc_0",
+        "_moment1_0",
+        "_moment2_0",
+    ]
+    opt_to_t = {}
+    for n in all_names:
+        if n in tensor_names:
+            # we get a param
+            pre_t_name = n
+        else:
+            assert pre_t_name
+            opt_to_t[n] = pre_t_name
+
+    for t in opt_names:
+        _find = False
+        for s in suffix:
+            if get_env_device() == "xpu" and t.endswith(s + ".SCALE_VALUE"):
+                # NOTE: for xpu adamw, every optimizer state has an extra attribute ending with SCALE_VALUE.
+                # This extra attribute is not used, so just skip it.
+                _find = True
+                break
+            if t.endswith(s):
+                logger.info(f"{t}-{t[:-len(s)]}--{t[:-len(s)] in tensor_names}")
+                opt_to_t[t] = t[: -len(s)]
+                _find = True
+                break
+        assert _find
+    return opt_to_t
+
+
 class NodeModelState:
     def __init__(self, group):
         self._model_weights = OrderedDict()
@@ -259,43 +304,6 @@ def pack_keys(self, structure_name_mapping=None):
         change the key of master weights dict from param_name to (structure_name, param_name)
         """
         # pack key for pp convert
-        def _opt_name_to_tname(tensor_names, opt_names):
-            tensor_names = set(tensor_names)
-            all_names = []
-            all_names.extend(list(tensor_names))
-            all_names.extend(opt_names)
-            all_names.sort()
-            pre_t_name = ""
-            suffix = [
-                "_fp32_master_0_beta1_pow_acc_0",
-                "_fp32_master_0_beta2_pow_acc_0",
-                "_fp32_master_0_moment1_0",
-                "_fp32_master_0_moment2_0",
-                "_beta1_pow_acc_0",
-                "_beta2_pow_acc_0",
-                "_moment1_0",
-                "_moment2_0",
-            ]
-            opt_to_t = {}
-            for n in all_names:
-                if n in tensor_names:
-                    # we get a param
-                    pre_t_name = n
-                else:
-                    assert pre_t_name
-                    opt_to_t[n] = pre_t_name
-
-            for t in opt_names:
-                _find = False
-                for s in suffix:
-                    if t.endswith(s):
-                        logger.info(f"{t}-{t[:-len(s)]}--{t[:-len(s)] in tensor_names}")
-                        opt_to_t[t] = t[: -len(s)]
-                        _find = True
-                        break
-                assert _find
-            return opt_to_t
-
         if structure_name_mapping is not None:
             tname_to_structure_name = {v: k for (k, v) in structure_name_mapping.items()}
         else:
@@ -304,7 +312,7 @@ def _opt_name_to_tname(tensor_names, opt_names):
 
         tensor_names = list(tname_to_structure_name.keys())
         opt_names = list(self._opt_state.keys())
-        opt_name_to_tname = _opt_name_to_tname(tensor_names, opt_names)
+        opt_name_to_tname = convert_opt_name_to_tname(tensor_names, opt_names)
 
         # model state
         model_weights_tmp = OrderedDict()
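For context, a minimal sketch (not part of the diff) of the mapping that convert_opt_name_to_tname is expected to produce. The parameter and optimizer-state names below are hypothetical, and the suffix-stripping logic is restated in simplified standalone form, without logger, get_env_device, the XPU SCALE_VALUE skip, or the sort-based pairing of remaining names, so the snippet runs on its own.

# Hypothetical parameter and optimizer-state names, for illustration only.
tensor_names = ["linear_0.w_0", "linear_0.b_0"]
opt_names = [
    "linear_0.w_0_fp32_master_0_moment1_0",
    "linear_0.w_0_fp32_master_0_beta1_pow_acc_0",
    "linear_0.b_0_moment2_0",
]

# Same suffix table as in the diff; the longer fp32-master suffixes are tried first.
suffixes = [
    "_fp32_master_0_beta1_pow_acc_0",
    "_fp32_master_0_beta2_pow_acc_0",
    "_fp32_master_0_moment1_0",
    "_fp32_master_0_moment2_0",
    "_beta1_pow_acc_0",
    "_beta2_pow_acc_0",
    "_moment1_0",
    "_moment2_0",
]


def strip_opt_suffix(opt_name):
    # Map an optimizer-state name back to its parameter (tensor) name by
    # removing the first matching Adam/AdamW accumulator suffix.
    for s in suffixes:
        if opt_name.endswith(s):
            return opt_name[: -len(s)]
    raise ValueError(f"unrecognized optimizer state name: {opt_name}")


mapping = {n: strip_opt_suffix(n) for n in opt_names}
assert all(v in set(tensor_names) for v in mapping.values())
print(mapping)
# {'linear_0.w_0_fp32_master_0_moment1_0': 'linear_0.w_0',
#  'linear_0.w_0_fp32_master_0_beta1_pow_acc_0': 'linear_0.w_0',
#  'linear_0.b_0_moment2_0': 'linear_0.b_0'}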