@@ -81,15 +81,11 @@ def __init__(
8181 if eta <= 0 :
8282 raise ValueError (f"eta must be positive, got { eta } " )
8383 if rand_percent < 0 or rand_percent > 100 :
84- raise ValueError (
85- f"rand_percent must be in [0, 100], got { rand_percent } "
86- )
84+ raise ValueError (f"rand_percent must be in [0, 100], got { rand_percent } " )
8785 if threshold < 0 :
8886 raise ValueError (f"threshold must be non-negative, got { threshold } " )
8987 if num_pre_loss < 1 :
90- raise ValueError (
91- f"num_pre_loss must be >= 1, got { num_pre_loss } "
92- )
88+ raise ValueError (f"num_pre_loss must be >= 1, got { num_pre_loss } " )
9389 if max_ala_epochs is not None and max_ala_epochs < 1 :
9490 raise ValueError (
9591 f"max_ala_epochs must be >= 1 or None, got { max_ala_epochs } "
@@ -378,9 +374,7 @@ def _move_to_device(value: Any, device: torch.device) -> Any:
378374 if isinstance (value , torch .Tensor ):
379375 return value .to (device )
380376 if isinstance (value , tuple ):
381- return tuple (
382- FedALAUpdateStrategy ._move_to_device (v , device ) for v in value
383- )
377+ return tuple (FedALAUpdateStrategy ._move_to_device (v , device ) for v in value )
384378 if isinstance (value , list ):
385379 return [FedALAUpdateStrategy ._move_to_device (v , device ) for v in value ]
386380 if isinstance (value , dict ):
@@ -450,9 +444,7 @@ def _adaptive_local_aggregation(
450444 examples = self ._move_to_device (
451445 examples , next (model_t .parameters ()).device
452446 )
453- labels = self ._move_to_device (
454- labels , next (model_t .parameters ()).device
455- )
447+ labels = self ._move_to_device (labels , next (model_t .parameters ()).device )
456448
457449 optimizer .zero_grad ()
458450 output = model_t (examples )
@@ -526,9 +518,7 @@ def _ensure_weights(self, params: Iterable[torch.Tensor]) -> None:
526518
527519 for weight , param in zip (self .weights , params_list ):
528520 if weight .shape != param .data .shape :
529- self .weights = [
530- torch .ones_like (param .data ) for param in params_list
531- ]
521+ self .weights = [torch .ones_like (param .data ) for param in params_list ]
532522 return
533523
534524
0 commit comments