1515
1616import logging
1717import os
18- from functools import wraps
19- from typing import Callable , Optional , TypeVar , overload
18+ from typing import Optional
2019
2120import lightning_utilities .core .rank_zero as rank_zero_module
2221
2928 rank_zero_info ,
3029 rank_zero_warn ,
3130)
32- from typing_extensions import ParamSpec
33-
34- from lightning .fabric .utilities .imports import _UTILITIES_GREATER_EQUAL_0_10
3531
3632rank_zero_module .log = logging .getLogger (__name__ )
3733
@@ -48,33 +44,7 @@ def _get_rank() -> Optional[int]:
4844 return None
4945
5046
51- if not _UTILITIES_GREATER_EQUAL_0_10 :
52- T = TypeVar ("T" )
53- P = ParamSpec ("P" )
54-
55- @overload
56- def rank_zero_only (fn : Callable [P , T ]) -> Callable [P , Optional [T ]]:
57- """Rank zero only."""
58-
59- @overload
60- def rank_zero_only (fn : Callable [P , T ], default : T ) -> Callable [P , T ]:
61- """Rank zero only."""
62-
63- def rank_zero_only (fn : Callable [P , T ], default : Optional [T ] = None ) -> Callable [P , Optional [T ]]:
64- @wraps (fn )
65- def wrapped_fn (* args : P .args , ** kwargs : P .kwargs ) -> Optional [T ]:
66- rank = getattr (rank_zero_only , "rank" , None )
67- if rank is None :
68- raise RuntimeError ("The `rank_zero_only.rank` needs to be set before use" )
69- if rank == 0 :
70- return fn (* args , ** kwargs )
71- return default
72-
73- return wrapped_fn
74-
75- rank_zero_module .rank_zero_only .rank = getattr (rank_zero_module .rank_zero_only , "rank" , _get_rank () or 0 )
76- else :
77- rank_zero_only = rank_zero_module .rank_zero_only
47+ rank_zero_only = rank_zero_module .rank_zero_only
7848
7949# add the attribute to the function but don't overwrite in case Trainer has already set it
8050rank_zero_only .rank = getattr (rank_zero_only , "rank" , _get_rank () or 0 )
0 commit comments