
Commit 0bb079c

Author: Feiyu Chan

avoid polluting logging's root logger (#32673) (#32706)

avoid polluting logging's root logger

1 parent a9d330a · commit 0bb079c
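
Note (not part of the diff): the removed pattern calls logging.basicConfig() at import time, which installs a handler and format string on Python's root logger for the whole process, so merely importing the module reconfigures logging for every user; the replacement attaches a handler only to a module-level logger. A minimal sketch of the two patterns follows; the message text is illustrative, and the setLevel call mirrors the variant used in utils.py and extension_utils.py below.

import logging

# Removed pattern: basicConfig() configures the ROOT logger, so importing the
# module silently changes logging behaviour for the entire process.
#   logging.basicConfig(
#       format='%(asctime)s %(levelname)-8s %(message)s',
#       datefmt='%Y-%m-%d %H:%M:%S')

# Introduced pattern: configure only this module's own logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)  # two of the touched files also set the level
formatter = logging.Formatter(
    fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)

logger.info("emitted through this module's handler")  # illustrative message
print(logging.getLogger().handlers)  # root logger stays unconfigured: []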

File tree

4 files changed: +64, -52 lines


python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py

Lines changed: 46 additions & 43 deletions
@@ -29,9 +29,12 @@
 from paddle.fluid import layers

 import logging
-logging.basicConfig(
-    format='%(asctime)s %(levelname)-8s %(message)s',
-    datefmt='%Y-%m-%d %H:%M:%S')
+logger = logging.getLogger(__name__)
+formatter = logging.Formatter(
+    fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+ch = logging.StreamHandler()
+ch.setFormatter(formatter)
+logger.addHandler(ch)
 from functools import reduce

 __all__ = ["ShardingOptimizer"]
@@ -136,7 +139,7 @@ def minimize_impl(self,

         # FIXME (JZ-LIANG) deprecated hybrid_dp
         if self.user_defined_strategy.sharding_configs["hybrid_dp"]:
-            logging.warning(
+            logger.warning(
                 "[hybrid_dp] API setting is deprecated. Now when dp_degree >= 2, its will be in hybrid dp mode automatically"
             )
         assert self.dp_degree >= 1
@@ -174,7 +177,7 @@ def minimize_impl(self,
         self._gradient_merge_acc_step = self.user_defined_strategy.pipeline_configs[
             'accumulate_steps']
         if self._gradient_merge_acc_step > 1:
-            logging.info("Gradient merge in [{}], acc step = [{}]".format(
+            logger.info("Gradient merge in [{}], acc step = [{}]".format(
                 self.gradient_merge_mode, self._gradient_merge_acc_step))

         # optimize offload
@@ -338,7 +341,7 @@ def minimize_impl(self,
         # opt offload should be enable while gradient merge is enable && acc_step is quite large (e.g. >> 100)
         # sync its memcpy could not be overlap with calc, otherwise it will slower down training severely.
         if self.optimize_offload:
-            logging.info("Sharding with optimize offload !")
+            logger.info("Sharding with optimize offload !")
             offload_helper = OffloadHelper()
             offload_helper.offload(main_block, startup_block)
             offload_helper.offload_fp32param(main_block, startup_block)
@@ -641,15 +644,15 @@ def _split_program(self, block):
         for varname in sorted(
                 var2broadcast_time, key=var2broadcast_time.get,
                 reverse=True):
-            logging.info("Sharding broadcast: [{}] times [{}]".format(
+            logger.info("Sharding broadcast: [{}] times [{}]".format(
                 var2broadcast_time[varname], varname))
         for idx_ in range(len(self._segments)):
-            logging.info("segment [{}] :".format(idx_))
-            logging.info("start op: [{}] [{}]".format(block.ops[
+            logger.info("segment [{}] :".format(idx_))
+            logger.info("start op: [{}] [{}]".format(block.ops[
                 self._segments[idx_]._start_idx].desc.type(), block.ops[
                     self._segments[idx_]._start_idx].desc.input_arg_names(
                     )))
-            logging.info("end op: [{}] [{}]".format(block.ops[
+            logger.info("end op: [{}] [{}]".format(block.ops[
                 self._segments[idx_]._end_idx].desc.type(), block.ops[
                     self._segments[idx_]._end_idx].desc.input_arg_names()))
         return
@@ -1108,7 +1111,7 @@ def _build_groups(self):
                 self.dp_group_endpoints.append(self.global_endpoints[
                     dp_first_rank_idx + dp_offset * i])
             assert self.current_endpoint in self.dp_group_endpoints
-            logging.info("Hybrid DP mode turn on !")
+            logger.info("Hybrid DP mode turn on !")
         else:
             self.dp_ring_id = -1
             self.dp_rank = -1
@@ -1119,40 +1122,40 @@ def _build_groups(self):
         # NOTE (JZ-LIANG) when use global ring for calc global norm and dp_degree > 1, the allreduce result should be devided by dp_degree
         self.global_ring_id = 3

-        logging.info("global word size: {}".format(self.global_word_size))
-        logging.info("global rank: {}".format(self.global_rank))
-        logging.info("global endpoints: {}".format(self.global_endpoints))
-        logging.info("global ring id: {}".format(self.global_ring_id))
-        logging.info("#####" * 6)
-
-        logging.info("mp group size: {}".format(self.mp_degree))
-        logging.info("mp rank: {}".format(self.mp_rank))
-        logging.info("mp group id: {}".format(self.mp_group_id))
-        logging.info("mp group endpoints: {}".format(self.mp_group_endpoints))
-        logging.info("mp ring id: {}".format(self.mp_ring_id))
-        logging.info("#####" * 6)
-
-        logging.info("sharding group size: {}".format(self.sharding_degree))
-        logging.info("sharding rank: {}".format(self.sharding_rank))
-        logging.info("sharding group id: {}".format(self.sharding_group_id))
-        logging.info("sharding group endpoints: {}".format(
+        logger.info("global word size: {}".format(self.global_word_size))
+        logger.info("global rank: {}".format(self.global_rank))
+        logger.info("global endpoints: {}".format(self.global_endpoints))
+        logger.info("global ring id: {}".format(self.global_ring_id))
+        logger.info("#####" * 6)
+
+        logger.info("mp group size: {}".format(self.mp_degree))
+        logger.info("mp rank: {}".format(self.mp_rank))
+        logger.info("mp group id: {}".format(self.mp_group_id))
+        logger.info("mp group endpoints: {}".format(self.mp_group_endpoints))
+        logger.info("mp ring id: {}".format(self.mp_ring_id))
+        logger.info("#####" * 6)
+
+        logger.info("sharding group size: {}".format(self.sharding_degree))
+        logger.info("sharding rank: {}".format(self.sharding_rank))
+        logger.info("sharding group id: {}".format(self.sharding_group_id))
+        logger.info("sharding group endpoints: {}".format(
             self.sharding_group_endpoints))
-        logging.info("sharding ring id: {}".format(self.sharding_ring_id))
-        logging.info("#####" * 6)
-
-        logging.info("pp group size: {}".format(self.pp_degree))
-        logging.info("pp rank: {}".format(self.pp_rank))
-        logging.info("pp group id: {}".format(self.pp_group_id))
-        logging.info("pp group endpoints: {}".format(self.pp_group_endpoints))
-        logging.info("pp ring id: {}".format(self.pp_ring_id))
-        logging.info("#####" * 6)
-
-        logging.info("pure dp group size: {}".format(self.dp_degree))
-        logging.info("pure dp rank: {}".format(self.dp_rank))
-        logging.info("pure dp group endpoints: {}".format(
+        logger.info("sharding ring id: {}".format(self.sharding_ring_id))
+        logger.info("#####" * 6)
+
+        logger.info("pp group size: {}".format(self.pp_degree))
+        logger.info("pp rank: {}".format(self.pp_rank))
+        logger.info("pp group id: {}".format(self.pp_group_id))
+        logger.info("pp group endpoints: {}".format(self.pp_group_endpoints))
+        logger.info("pp ring id: {}".format(self.pp_ring_id))
+        logger.info("#####" * 6)
+
+        logger.info("pure dp group size: {}".format(self.dp_degree))
+        logger.info("pure dp rank: {}".format(self.dp_rank))
+        logger.info("pure dp group endpoints: {}".format(
             self.dp_group_endpoints))
-        logging.info("pure dp ring id: {}".format(self.dp_ring_id))
-        logging.info("#####" * 6)
+        logger.info("pure dp ring id: {}".format(self.dp_ring_id))
+        logger.info("#####" * 6)

         return


python/paddle/distributed/fleet/utils/recompute.py

Lines changed: 7 additions & 4 deletions
@@ -19,9 +19,12 @@
 import contextlib

 import logging
-logging.basicConfig(
-    format='%(asctime)s %(levelname)-8s %(message)s',
-    datefmt='%Y-%m-%d %H:%M:%S')
+logger = logging.getLogger(__name__)
+formatter = logging.Formatter(
+    fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+ch = logging.StreamHandler()
+ch.setFormatter(formatter)
+logger.addHandler(ch)


 def detach_variable(inputs):
@@ -40,7 +43,7 @@ def detach_variable(inputs):
 def check_recompute_necessary(inputs):
     if not any(input_.stop_gradient == False for input_ in inputs
                if isinstance(input_, paddle.Tensor)):
-        logging.warn(
+        logger.warn(
            "[Recompute]: None of the inputs to current recompute block need grad, "
            "therefore there is NO need to recompute this block in backward !")


python/paddle/fluid/incubate/fleet/utils/utils.py

Lines changed: 5 additions & 2 deletions
@@ -34,9 +34,12 @@
     "graphviz"
 ]

-logging.basicConfig(
-    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(message)s')
+ch = logging.StreamHandler()
+ch.setFormatter(formatter)
+logger.addHandler(ch)

 persistable_vars_out_fn = "vars_persistable.log"
 all_vars_out_fn = "vars_all.log"

python/paddle/utils/cpp_extension/extension_utils.py

Lines changed: 6 additions & 3 deletions
@@ -32,9 +32,12 @@
 from ...fluid.framework import OpProtoHolder
 from ...sysconfig import get_include, get_lib

-logging.basicConfig(
-    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
 logger = logging.getLogger("utils.cpp_extension")
+logger.setLevel(logging.INFO)
+formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(message)s')
+ch = logging.StreamHandler()
+ch.setFormatter(formatter)
+logger.addHandler(ch)

 OS_NAME = sys.platform
 IS_WINDOWS = OS_NAME.startswith('win')
@@ -1125,4 +1128,4 @@ def log_v(info, verbose=True):
     Print log information on stdout.
     """
     if verbose:
-        logging.info(info)
+        logger.info(info)
