Skip to content

Commit 54d7d9b

Browse files
committed
uvx ruff format.
1 parent bfe5fe9 commit 54d7d9b

File tree

9 files changed

+30
-32
lines changed

9 files changed

+30
-32
lines changed

examples/gradient_leakage_attacks/defense/Outpost/perturb.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ def noise(dy_dx: list, risk: list):
5252
noise_mask, device=grad.device, dtype=grad.dtype
5353
)
5454
gauss_noise = noise_base * noise_mask_tensor
55-
grad = torch.as_tensor(
56-
grad_tensor, device=grad.device, dtype=torch.float32
57-
) + gauss_noise
55+
grad = (
56+
torch.as_tensor(grad_tensor, device=grad.device, dtype=torch.float32)
57+
+ gauss_noise
58+
)
5859

5960
return dy_dx

examples/gradient_leakage_attacks/dlg_trainer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,7 @@ def on_train_step_end(self, trainer, config, batch, loss, **kwargs) -> None:
332332
flattened_weights, Config().algorithm.prune_pct
333333
)
334334
pruned = np.where(np.abs(grad_tensor) < thresh, 0, grad_tensor)
335-
gradient_list[index] = torch.from_numpy(pruned).to(
336-
trainer.device
337-
)
335+
gradient_list[index] = torch.from_numpy(pruned).to(trainer.device)
338336

339337
elif (
340338
defense_name == "DP"

examples/model_search/fedtp/fedtp_algorithm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ def __init__(self, trainer: Trainer):
3030
super().__init__(trainer)
3131
self.current_weights: OrderedDict[str, Tensor] | None = None
3232

33-
def generate_attention(self, hnet: Module, client_id: int) -> OrderedDict[str, Tensor]:
33+
def generate_attention(
34+
self, hnet: Module, client_id: int
35+
) -> OrderedDict[str, Tensor]:
3436
"""Generated the customized attention of each client."""
3537
weights = hnet(
3638
torch.tensor([client_id - 1], dtype=torch.long).to(Config().device())

examples/model_search/fedtp/hypernetworks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def forward(self, idx) -> OrderedDictType[str, Tensor]:
9393
else:
9494
layer_list = cast(nn.ModuleList, layer_d_qkv_value_hyper)
9595
layer_d_qkv_value = [
96-
layer(features).view(self.inner_dim, self.dim) for layer in layer_list
96+
layer(features).view(self.inner_dim, self.dim)
97+
for layer in layer_list
9798
]
9899
name = Config().parameters.hypernet.attention % (dep, dep, dep)
99100
names = name.split(",")

examples/server_aggregation/moon/moon_algorithm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ class Algorithm(fedavg.Algorithm):
2020
"""Algorithm providing MOON aggregation utilities."""
2121

2222
@staticmethod
23-
def _cast_tensor_like(tensor: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
23+
def _cast_tensor_like(
24+
tensor: torch.Tensor, reference: torch.Tensor
25+
) -> torch.Tensor:
2426
"""Cast a tensor to match a reference dtype (handles bool/int safely)."""
2527
if tensor.dtype == reference.dtype:
2628
return tensor

plato/servers/strategies/aggregation/pfedgraph.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,7 @@ def setup(self, context: ServerContext) -> None:
9595
if algorithm is not None and hasattr(algorithm, "extract_weights"):
9696
baseline = algorithm.extract_weights()
9797
self.initial_weights = {
98-
name: tensor.detach().cpu().clone()
99-
for name, tensor in baseline.items()
98+
name: tensor.detach().cpu().clone() for name, tensor in baseline.items()
10099
}
101100

102101
if self.total_clients > 0:
@@ -257,8 +256,7 @@ def _aggregate_client_weights(
257256
for client_id in selected_ids:
258257
row = graph[client_id, selected_ids]
259258
aggregated = {
260-
name: torch.zeros_like(value)
261-
for name, value in weights_cpu[0].items()
259+
name: torch.zeros_like(value) for name, value in weights_cpu[0].items()
262260
}
263261

264262
for neighbor_idx, weight in enumerate(row):
@@ -283,10 +281,13 @@ def _aggregate_global_weights(
283281
for idx, client_weights in enumerate(aggregated_weights):
284282
weight = float(sample_weights[idx])
285283
for name, tensor in client_weights.items():
286-
global_weights[name] += tensor.to(
287-
dtype=baseline_weights[name].dtype,
288-
device=baseline_weights[name].device,
289-
) * weight
284+
global_weights[name] += (
285+
tensor.to(
286+
dtype=baseline_weights[name].dtype,
287+
device=baseline_weights[name].device,
288+
)
289+
* weight
290+
)
290291

291292
return global_weights
292293

plato/trainers/huggingface.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,7 @@ def on_train_step_end(self, trainer, config, batch, loss, **kwargs):
415415

416416
class Trainer(ComposableTrainer):
417417
"""Composable HuggingFace trainer built on Plato's strategy API."""
418+
418419
training_args: TrainingArguments
419420

420421
def __init__(self, model=None, callbacks=None):

plato/trainers/strategies/algorithms/fedala_strategy.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,11 @@ def __init__(
8181
if eta <= 0:
8282
raise ValueError(f"eta must be positive, got {eta}")
8383
if rand_percent < 0 or rand_percent > 100:
84-
raise ValueError(
85-
f"rand_percent must be in [0, 100], got {rand_percent}"
86-
)
84+
raise ValueError(f"rand_percent must be in [0, 100], got {rand_percent}")
8785
if threshold < 0:
8886
raise ValueError(f"threshold must be non-negative, got {threshold}")
8987
if num_pre_loss < 1:
90-
raise ValueError(
91-
f"num_pre_loss must be >= 1, got {num_pre_loss}"
92-
)
88+
raise ValueError(f"num_pre_loss must be >= 1, got {num_pre_loss}")
9389
if max_ala_epochs is not None and max_ala_epochs < 1:
9490
raise ValueError(
9591
f"max_ala_epochs must be >= 1 or None, got {max_ala_epochs}"
@@ -378,9 +374,7 @@ def _move_to_device(value: Any, device: torch.device) -> Any:
378374
if isinstance(value, torch.Tensor):
379375
return value.to(device)
380376
if isinstance(value, tuple):
381-
return tuple(
382-
FedALAUpdateStrategy._move_to_device(v, device) for v in value
383-
)
377+
return tuple(FedALAUpdateStrategy._move_to_device(v, device) for v in value)
384378
if isinstance(value, list):
385379
return [FedALAUpdateStrategy._move_to_device(v, device) for v in value]
386380
if isinstance(value, dict):
@@ -450,9 +444,7 @@ def _adaptive_local_aggregation(
450444
examples = self._move_to_device(
451445
examples, next(model_t.parameters()).device
452446
)
453-
labels = self._move_to_device(
454-
labels, next(model_t.parameters()).device
455-
)
447+
labels = self._move_to_device(labels, next(model_t.parameters()).device)
456448

457449
optimizer.zero_grad()
458450
output = model_t(examples)
@@ -526,9 +518,7 @@ def _ensure_weights(self, params: Iterable[torch.Tensor]) -> None:
526518

527519
for weight, param in zip(self.weights, params_list):
528520
if weight.shape != param.data.shape:
529-
self.weights = [
530-
torch.ones_like(param.data) for param in params_list
531-
]
521+
self.weights = [torch.ones_like(param.data) for param in params_list]
532522
return
533523

534524

plato/utils/reinforcement_learning/policies/sac.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def forward(self, state, action):
107107
class GaussianPolicy(nn.Module):
108108
action_scale: torch.Tensor
109109
action_bias: torch.Tensor
110+
110111
def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
111112
super().__init__()
112113

@@ -160,6 +161,7 @@ def sample(self, state) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
160161
class DeterministicPolicy(nn.Module):
161162
action_scale: torch.Tensor
162163
action_bias: torch.Tensor
164+
163165
def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
164166
super().__init__()
165167
self.linear1 = nn.Linear(num_inputs, hidden_dim)

0 commit comments

Comments
 (0)