29
29
from paddle .fluid import layers
30
30
31
31
import logging
32
- logging .basicConfig (
33
- format = '%(asctime)s %(levelname)-8s %(message)s' ,
34
- datefmt = '%Y-%m-%d %H:%M:%S' )
32
+ logger = logging .getLogger (__name__ )
33
+ formatter = logging .Formatter (
34
+ fmt = '%(asctime)s %(levelname)-8s %(message)s' , datefmt = '%Y-%m-%d %H:%M:%S' )
35
+ ch = logging .StreamHandler ()
36
+ ch .setFormatter (formatter )
37
+ logger .addHandler (ch )
35
38
from functools import reduce
36
39
37
40
__all__ = ["ShardingOptimizer" ]
@@ -136,7 +139,7 @@ def minimize_impl(self,
136
139
137
140
# FIXME (JZ-LIANG) deprecated hybrid_dp
138
141
if self .user_defined_strategy .sharding_configs ["hybrid_dp" ]:
139
- logging .warning (
142
+ logger .warning (
140
143
"[hybrid_dp] API setting is deprecated. Now when dp_degree >= 2, its will be in hybrid dp mode automatically"
141
144
)
142
145
assert self .dp_degree >= 1
@@ -174,7 +177,7 @@ def minimize_impl(self,
174
177
self ._gradient_merge_acc_step = self .user_defined_strategy .pipeline_configs [
175
178
'accumulate_steps' ]
176
179
if self ._gradient_merge_acc_step > 1 :
177
- logging .info ("Gradient merge in [{}], acc step = [{}]" .format (
180
+ logger .info ("Gradient merge in [{}], acc step = [{}]" .format (
178
181
self .gradient_merge_mode , self ._gradient_merge_acc_step ))
179
182
180
183
# optimize offload
@@ -338,7 +341,7 @@ def minimize_impl(self,
338
341
# opt offload should be enabled only when gradient merge is enabled and acc_step is quite large (e.g. >> 100)
339
342
# since its memcpy cannot be overlapped with computation; otherwise it will slow down training severely.
340
343
if self .optimize_offload :
341
- logging .info ("Sharding with optimize offload !" )
344
+ logger .info ("Sharding with optimize offload !" )
342
345
offload_helper = OffloadHelper ()
343
346
offload_helper .offload (main_block , startup_block )
344
347
offload_helper .offload_fp32param (main_block , startup_block )
@@ -641,15 +644,15 @@ def _split_program(self, block):
641
644
for varname in sorted (
642
645
var2broadcast_time , key = var2broadcast_time .get ,
643
646
reverse = True ):
644
- logging .info ("Sharding broadcast: [{}] times [{}]" .format (
647
+ logger .info ("Sharding broadcast: [{}] times [{}]" .format (
645
648
var2broadcast_time [varname ], varname ))
646
649
for idx_ in range (len (self ._segments )):
647
- logging .info ("segment [{}] :" .format (idx_ ))
648
- logging .info ("start op: [{}] [{}]" .format (block .ops [
650
+ logger .info ("segment [{}] :" .format (idx_ ))
651
+ logger .info ("start op: [{}] [{}]" .format (block .ops [
649
652
self ._segments [idx_ ]._start_idx ].desc .type (), block .ops [
650
653
self ._segments [idx_ ]._start_idx ].desc .input_arg_names (
651
654
)))
652
- logging .info ("end op: [{}] [{}]" .format (block .ops [
655
+ logger .info ("end op: [{}] [{}]" .format (block .ops [
653
656
self ._segments [idx_ ]._end_idx ].desc .type (), block .ops [
654
657
self ._segments [idx_ ]._end_idx ].desc .input_arg_names ()))
655
658
return
@@ -1108,7 +1111,7 @@ def _build_groups(self):
1108
1111
self .dp_group_endpoints .append (self .global_endpoints [
1109
1112
dp_first_rank_idx + dp_offset * i ])
1110
1113
assert self .current_endpoint in self .dp_group_endpoints
1111
- logging .info ("Hybrid DP mode turn on !" )
1114
+ logger .info ("Hybrid DP mode turn on !" )
1112
1115
else :
1113
1116
self .dp_ring_id = - 1
1114
1117
self .dp_rank = - 1
@@ -1119,40 +1122,40 @@ def _build_groups(self):
1119
1122
# NOTE (JZ-LIANG) when the global ring is used to calculate the global norm and dp_degree > 1, the allreduce result should be divided by dp_degree
1120
1123
self .global_ring_id = 3
1121
1124
1122
- logging .info ("global word size: {}" .format (self .global_word_size ))
1123
- logging .info ("global rank: {}" .format (self .global_rank ))
1124
- logging .info ("global endpoints: {}" .format (self .global_endpoints ))
1125
- logging .info ("global ring id: {}" .format (self .global_ring_id ))
1126
- logging .info ("#####" * 6 )
1127
-
1128
- logging .info ("mp group size: {}" .format (self .mp_degree ))
1129
- logging .info ("mp rank: {}" .format (self .mp_rank ))
1130
- logging .info ("mp group id: {}" .format (self .mp_group_id ))
1131
- logging .info ("mp group endpoints: {}" .format (self .mp_group_endpoints ))
1132
- logging .info ("mp ring id: {}" .format (self .mp_ring_id ))
1133
- logging .info ("#####" * 6 )
1134
-
1135
- logging .info ("sharding group size: {}" .format (self .sharding_degree ))
1136
- logging .info ("sharding rank: {}" .format (self .sharding_rank ))
1137
- logging .info ("sharding group id: {}" .format (self .sharding_group_id ))
1138
- logging .info ("sharding group endpoints: {}" .format (
1125
+ logger .info ("global word size: {}" .format (self .global_word_size ))
1126
+ logger .info ("global rank: {}" .format (self .global_rank ))
1127
+ logger .info ("global endpoints: {}" .format (self .global_endpoints ))
1128
+ logger .info ("global ring id: {}" .format (self .global_ring_id ))
1129
+ logger .info ("#####" * 6 )
1130
+
1131
+ logger .info ("mp group size: {}" .format (self .mp_degree ))
1132
+ logger .info ("mp rank: {}" .format (self .mp_rank ))
1133
+ logger .info ("mp group id: {}" .format (self .mp_group_id ))
1134
+ logger .info ("mp group endpoints: {}" .format (self .mp_group_endpoints ))
1135
+ logger .info ("mp ring id: {}" .format (self .mp_ring_id ))
1136
+ logger .info ("#####" * 6 )
1137
+
1138
+ logger .info ("sharding group size: {}" .format (self .sharding_degree ))
1139
+ logger .info ("sharding rank: {}" .format (self .sharding_rank ))
1140
+ logger .info ("sharding group id: {}" .format (self .sharding_group_id ))
1141
+ logger .info ("sharding group endpoints: {}" .format (
1139
1142
self .sharding_group_endpoints ))
1140
- logging .info ("sharding ring id: {}" .format (self .sharding_ring_id ))
1141
- logging .info ("#####" * 6 )
1142
-
1143
- logging .info ("pp group size: {}" .format (self .pp_degree ))
1144
- logging .info ("pp rank: {}" .format (self .pp_rank ))
1145
- logging .info ("pp group id: {}" .format (self .pp_group_id ))
1146
- logging .info ("pp group endpoints: {}" .format (self .pp_group_endpoints ))
1147
- logging .info ("pp ring id: {}" .format (self .pp_ring_id ))
1148
- logging .info ("#####" * 6 )
1149
-
1150
- logging .info ("pure dp group size: {}" .format (self .dp_degree ))
1151
- logging .info ("pure dp rank: {}" .format (self .dp_rank ))
1152
- logging .info ("pure dp group endpoints: {}" .format (
1143
+ logger .info ("sharding ring id: {}" .format (self .sharding_ring_id ))
1144
+ logger .info ("#####" * 6 )
1145
+
1146
+ logger .info ("pp group size: {}" .format (self .pp_degree ))
1147
+ logger .info ("pp rank: {}" .format (self .pp_rank ))
1148
+ logger .info ("pp group id: {}" .format (self .pp_group_id ))
1149
+ logger .info ("pp group endpoints: {}" .format (self .pp_group_endpoints ))
1150
+ logger .info ("pp ring id: {}" .format (self .pp_ring_id ))
1151
+ logger .info ("#####" * 6 )
1152
+
1153
+ logger .info ("pure dp group size: {}" .format (self .dp_degree ))
1154
+ logger .info ("pure dp rank: {}" .format (self .dp_rank ))
1155
+ logger .info ("pure dp group endpoints: {}" .format (
1153
1156
self .dp_group_endpoints ))
1154
- logging .info ("pure dp ring id: {}" .format (self .dp_ring_id ))
1155
- logging .info ("#####" * 6 )
1157
+ logger .info ("pure dp ring id: {}" .format (self .dp_ring_id ))
1158
+ logger .info ("#####" * 6 )
1156
1159
1157
1160
return
1158
1161
0 commit comments