Skip to content

Commit 28405e9

Browse files
authored
style(autojac): Add missing type hints (#367)
* Add type hint to union in _stack * Add type hint to result in _get_descendant_accumulate_grads * Be more specific about the type of Select when instantiating it
1 parent b381581 commit 28405e9

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

src/torchjd/_autojac/_mtl_backward.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,10 @@ def _create_task_transform(
179179

180180
# Transform that accumulates the gradients w.r.t. the task-specific parameters into their
181181
# .grad fields.
182-
accumulate = Accumulate() << Select(task_params)
182+
accumulate = Accumulate() << Select[Gradients](task_params)
183183

184184
# Transform that backpropagates the gradients of the losses w.r.t. the features.
185-
backpropagate = Select(features)
185+
backpropagate = Select[Gradients](features)
186186

187187
# Transform that accumulates the gradient of the losses w.r.t. the task-specific parameters into
188188
# their .grad fields and backpropagates the gradient of the losses w.r.t. to the features.

src/torchjd/_autojac/_transform/_stack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _stack(gradient_dicts: list[Gradients]) -> Jacobians:
3636
# It is important to first remove duplicate keys before computing their associated
3737
# stacked tensor. Otherwise, some computations would be duplicated. Therefore, we first compute
3838
# unique_keys, and only then, we compute the stacked tensors.
39-
union = {}
39+
union: dict[Tensor, Tensor] = {}
4040
for d in gradient_dicts:
4141
union |= d
4242
unique_keys = union.keys()

src/torchjd/_autojac/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def _get_descendant_accumulate_grads(
6767
"""
6868

6969
excluded_nodes = set(excluded_nodes) # Re-instantiate set to avoid modifying input
70-
result = OrderedSet([])
70+
result: OrderedSet[Node] = OrderedSet([])
7171
roots.difference_update(excluded_nodes)
7272
nodes_to_traverse = deque(roots)
7373

0 commit comments

Comments
 (0)