Skip to content

Commit d15e73b

Browse files
author
lilong12
authored
[CP] align fleet param (#31220)
* update, test=develop (#30692)
* align the default value of some configuration for fleet to that of single cards (#30740)
* update, test=develop
1 parent 98c4c78 commit d15e73b

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

paddle/fluid/framework/distributed_strategy.proto

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -141,9 +141,9 @@ message DistributedStrategy {
141141
optional bool fuse_all_reduce_ops = 18 [ default = true ];
142142
optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ];
143143
optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ];
144-
optional bool cudnn_exhaustive_search = 21 [ default = true ];
145-
optional int32 conv_workspace_size_limit = 22 [ default = 4000 ];
146-
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
144+
optional bool cudnn_exhaustive_search = 21 [ default = false ];
145+
optional int32 conv_workspace_size_limit = 22 [ default = 512 ];
146+
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = false ];
147147
optional bool adaptive_localsgd = 24 [ default = false ];
148148
optional bool fp16_allreduce = 25 [ default = false ];
149149
optional bool sharding = 26 [ default = false ];

python/paddle/distributed/fleet/base/distributed_strategy.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -118,6 +118,22 @@ def __init__(self):
118118
119119
"""
120120
self.strategy = distributed_strategy_pb2.DistributedStrategy()
121+
122+
# Set the default values of the following flags to the ones set by users
123+
key = 'FLAGS_cudnn_batchnorm_spatial_persistent'
124+
if core.globals().is_public(key):
125+
self.strategy.cudnn_batchnorm_spatial_persistent = bool(
126+
core.globals()[key])
127+
key = 'FLAGS_conv_workspace_size_limit'
128+
if core.globals().is_public(key):
129+
self.strategy.conv_workspace_size_limit = int(core.globals()[key])
130+
key = 'FLAGS_cudnn_exhaustive_search'
131+
if core.globals().is_public(key):
132+
self.strategy.cudnn_exhaustive_search = bool(core.globals()[key])
133+
key = 'FLAGS_sync_nccl_allreduce'
134+
if core.globals().is_public(key):
135+
self.strategy.sync_nccl_allreduce = bool(core.globals()[key])
136+
121137
self.__lock_attr = True
122138

123139
def __setattr__(self, key, value):

0 commit comments

Comments (0)