Skip to content

Commit a490152

Browse files
support RankZeroFirst (#1163)
1 parent 65a8798 commit a490152

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

docs/zh/api/utils/misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- PrettyOrderedDict
99
- Prettydefaultdict
1010
- RankZeroOnly
11+
- RankZeroFirst
1112
- Timer
1213
- all_gather
1314
- concat_dict_list

ppsci/utils/misc.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"PrettyOrderedDict",
4141
"Prettydefaultdict",
4242
"RankZeroOnly",
43+
"RankZeroFirst",
4344
"Timer",
4445
"all_gather",
4546
"concat_dict_list",
@@ -189,6 +190,40 @@ def __exit__(self, exc_type, exc_value, traceback):
189190
dist.barrier()
190191

191192

193+
class RankZeroFirst(ContextDecorator):
    """Let the master (rank-0) process run the wrapped code before other ranks.

    Every rank executes the code inside the context, but not at the same time:
    non-master ranks block on `dist.barrier()` when entering the context, while
    the master process runs the body immediately and calls `dist.barrier()` when
    leaving it, which lets the waiting ranks proceed and run the body themselves.
    This is useful when the first run produces something the other runs can
    reuse (e.g. a downloaded or cached file).

    No barrier is issued when the distributed environment is not initialized or
    when the world size is 1.

    Args:
        rank (Optional[int]): The rank of the current process. If not provided,
            it will be obtained from `dist.get_rank()`. It is ignored (and the
            process is treated as rank 0) when the distributed environment is
            not initialized.

    Examples:
        >>> import paddle.distributed as dist
        >>> with RankZeroFirst(dist.get_rank()):
        ...     # code here which should be executed first in the master(rank-0) process
        ...     pass
    """

    def __init__(self, rank: Optional[int] = None) -> None:
        if dist.is_initialized():
            self.rank = rank if rank is not None else dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            # Single-process run: behave as the one and only master process.
            self.rank = 0
            self.world_size = 1
        self.is_master = self.rank == 0

    def __enter__(self) -> None:
        if self.world_size > 1 and not self.is_master:
            dist.barrier()  # Non-master processes wait for master to finish

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # The barrier is reached even if the body raised on the master, so the
        # waiting ranks are still allowed to proceed. Returning None keeps any
        # exception propagating to the caller.
        if self.world_size > 1 and self.is_master:
            dist.barrier()  # Allow others to proceed
225+
226+
192227
class Timer(ContextDecorator):
193228
"""Count time cost for code block within context.
194229

0 commit comments

Comments
 (0)