Skip to content

Commit 0174765

Browse files
authored
refactor(autojac): Avoid reassignment in interface functions (#368)
1 parent 4b67d1d commit 0174765

File tree

2 files changed

+21

-21

lines changed

src/torchjd/_autojac/_backward.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,20 @@ def backward(
7777
"""
7878
check_optional_positive_chunk_size(parallel_chunk_size)
7979

80-
tensors = as_checked_ordered_set(tensors, "tensors")
80+
tensors_ = as_checked_ordered_set(tensors, "tensors")
8181

82-
if len(tensors) == 0:
82+
if len(tensors_) == 0:
8383
raise ValueError("`tensors` cannot be empty")
8484

8585
if inputs is None:
86-
inputs = get_leaf_tensors(tensors=tensors, excluded=set())
86+
inputs_ = get_leaf_tensors(tensors=tensors_, excluded=set())
8787
else:
88-
inputs = OrderedSet(inputs)
88+
inputs_ = OrderedSet(inputs)
8989

9090
backward_transform = _create_transform(
91-
tensors=tensors,
91+
tensors=tensors_,
9292
aggregator=aggregator,
93-
inputs=inputs,
93+
inputs=inputs_,
9494
retain_graph=retain_graph,
9595
parallel_chunk_size=parallel_chunk_size,
9696
)

src/torchjd/_autojac/_mtl_backward.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,35 +81,35 @@ def mtl_backward(
8181

8282
check_optional_positive_chunk_size(parallel_chunk_size)
8383

84-
losses = as_checked_ordered_set(losses, "losses")
85-
features = as_checked_ordered_set(features, "features")
84+
losses_ = as_checked_ordered_set(losses, "losses")
85+
features_ = as_checked_ordered_set(features, "features")
8686

8787
if shared_params is None:
88-
shared_params = get_leaf_tensors(tensors=features, excluded=[])
88+
shared_params_ = get_leaf_tensors(tensors=features_, excluded=[])
8989
else:
90-
shared_params = OrderedSet(shared_params)
90+
shared_params_ = OrderedSet(shared_params)
9191
if tasks_params is None:
92-
tasks_params = [get_leaf_tensors(tensors=[loss], excluded=features) for loss in losses]
92+
tasks_params_ = [get_leaf_tensors(tensors=[loss], excluded=features_) for loss in losses_]
9393
else:
94-
tasks_params = [OrderedSet(task_params) for task_params in tasks_params]
94+
tasks_params_ = [OrderedSet(task_params) for task_params in tasks_params]
9595

96-
if len(features) == 0:
96+
if len(features_) == 0:
9797
raise ValueError("`features` cannot be empty.")
9898

99-
_check_no_overlap(shared_params, tasks_params)
100-
_check_losses_are_scalar(losses)
99+
_check_no_overlap(shared_params_, tasks_params_)
100+
_check_losses_are_scalar(losses_)
101101

102-
if len(losses) == 0:
102+
if len(losses_) == 0:
103103
raise ValueError("`losses` cannot be empty")
104-
if len(losses) != len(tasks_params):
104+
if len(losses_) != len(tasks_params_):
105105
raise ValueError("`losses` and `tasks_params` should have the same size.")
106106

107107
backward_transform = _create_transform(
108-
losses=losses,
109-
features=features,
108+
losses=losses_,
109+
features=features_,
110110
aggregator=aggregator,
111-
tasks_params=tasks_params,
112-
shared_params=shared_params,
111+
tasks_params=tasks_params_,
112+
shared_params=shared_params_,
113113
retain_graph=retain_graph,
114114
parallel_chunk_size=parallel_chunk_size,
115115
)

0 commit comments

Comments
 (0)