File tree Expand file tree Collapse file tree 2 files changed +19
-3
lines changed
python/paddle/distributed/fleet/base Expand file tree Collapse file tree 2 files changed +19
-3
lines changed Original file line number Diff line number Diff line change @@ -141,9 +141,9 @@ message DistributedStrategy {
141
141
optional bool fuse_all_reduce_ops = 18 [ default = true ];
142
142
optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ];
143
143
optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ];
144
- optional bool cudnn_exhaustive_search = 21 [ default = true ];
145
- optional int32 conv_workspace_size_limit = 22 [ default = 4000 ];
146
- optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
144
+ optional bool cudnn_exhaustive_search = 21 [ default = false ];
145
+ optional int32 conv_workspace_size_limit = 22 [ default = 512 ];
146
+ optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = false ];
147
147
optional bool adaptive_localsgd = 24 [ default = false ];
148
148
optional bool fp16_allreduce = 25 [ default = false ];
149
149
optional bool sharding = 26 [ default = false ];
Original file line number Diff line number Diff line change @@ -118,6 +118,22 @@ def __init__(self):
118
118
119
119
"""
120
120
self.strategy = distributed_strategy_pb2.DistributedStrategy()
121
+
122
+ # Set the default values of the following flags to the ones set by users
123
+ key = 'FLAGS_cudnn_batchnorm_spatial_persistent'
124
+ if core.globals().is_public(key):
125
+ self.strategy.cudnn_batchnorm_spatial_persistent = bool(
126
core.globals()[key])
127
+ key = 'FLAGS_conv_workspace_size_limit'
128
+ if core.globals().is_public(key):
129
+ self.strategy.conv_workspace_size_limit = int(core.globals()[key])
130
+ key = 'FLAGS_cudnn_exhaustive_search'
131
+ if core.globals().is_public(key):
132
+ self.strategy.cudnn_exhaustive_search = bool(core.globals()[key])
133
+ key = 'FLAGS_sync_nccl_allreduce'
134
+ if core.globals().is_public(key):
135
+ self.strategy.sync_nccl_allreduce = bool(core.globals()[key])
136
+
121
137
self.__lock_attr = True
122
138
123
139
def __setattr__(self, key, value):
You can’t perform that action at this time.
0 commit comments