@@ -107,7 +107,9 @@ def __init__(
107107 @property
108108 def is_distributed (self ) -> bool : # pragma: no-cover
109109 """Legacy property kept for backwards compatibility."""
110- rank_zero_deprecation (f"`{ type (self ).__name__ } .is_distributed` is deprecated. Use is discouraged." , stacklevel = 6 )
110+ rank_zero_deprecation (
111+ f"`{ type (self ).__name__ } .is_distributed` is deprecated. Use is discouraged." , stacklevel = 6
112+ )
111113 return True
112114
113115 @property
@@ -227,7 +229,9 @@ def _register_ddp_hooks(self) -> None:
227229 def _enable_model_averaging (self ) -> None :
228230 log .debug (f"{ self .__class__ .__name__ } : reinitializing optimizers with post localSGD" )
229231 if self ._model_averaging_period is None :
230- raise ValueError ("Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy." )
232+ raise ValueError (
233+ "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
234+ )
231235 from torch .distributed .optim import DistributedOptimizer , PostLocalSGDOptimizer , ZeroRedundancyOptimizer
232236
233237 for optimizer in self .optimizers :
@@ -236,7 +240,10 @@ def _enable_model_averaging(self) -> None:
236240
237241 is_distributed_optimizer = isinstance (optimizer , DistributedOptimizer ) if not _IS_WINDOWS else False
238242 if isinstance (optimizer , (ZeroRedundancyOptimizer , PostLocalSGDOptimizer )) or is_distributed_optimizer :
239- raise ValueError (f"Currently model averaging cannot work with a distributed optimizer of type " f"{ optimizer .__class__ .__name__ } ." )
243+ raise ValueError (
244+ f"Currently model averaging cannot work with a distributed optimizer of type "
245+ f"{ optimizer .__class__ .__name__ } ."
246+ )
240247
241248 assert self ._ddp_comm_state is not None
242249 self ._model_averager = torch .distributed .algorithms .model_averaging .averagers .PeriodicModelAverager (
@@ -316,7 +323,9 @@ def model_to_device(self) -> None:
316323 self .model .to (self .root_device )
317324
318325 @override
319- def reduce (self , tensor : Tensor , group : Optional [Any ] = None , reduce_op : Optional [Union [ReduceOp , str ]] = "mean" ) -> Tensor :
326+ def reduce (
327+ self , tensor : Tensor , group : Optional [Any ] = None , reduce_op : Optional [Union [ReduceOp , str ]] = "mean"
328+ ) -> Tensor :
320329 """Reduces a tensor from several distributed processes to one aggregated tensor.
321330
322331 Args:
0 commit comments