File tree Expand file tree Collapse file tree 3 files changed +4
-4
lines changed Expand file tree Collapse file tree 3 files changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -179,10 +179,10 @@ def _create_task_transform(
179179
180180 # Transform that accumulates the gradients w.r.t. the task-specific parameters into their
181181 # .grad fields.
182- accumulate = Accumulate () << Select (task_params )
182+ accumulate = Accumulate () << Select [ Gradients ] (task_params )
183183
184184 # Transform that backpropagates the gradients of the losses w.r.t. the features.
185- backpropagate = Select (features )
185+ backpropagate = Select [ Gradients ] (features )
186186
187187 # Transform that accumulates the gradient of the losses w.r.t. the task-specific parameters into
188188 # their .grad fields and backpropagates the gradient of the losses w.r.t. to the features.
Original file line number Diff line number Diff line change @@ -36,7 +36,7 @@ def _stack(gradient_dicts: list[Gradients]) -> Jacobians:
3636 # It is important to first remove duplicate keys before computing their associated
3737 # stacked tensor. Otherwise, some computations would be duplicated. Therefore, we first compute
3838 # unique_keys, and only then, we compute the stacked tensors.
39- union = {}
39+ union : dict [ Tensor , Tensor ] = {}
4040 for d in gradient_dicts :
4141 union |= d
4242 unique_keys = union .keys ()
Original file line number Diff line number Diff line change @@ -67,7 +67,7 @@ def _get_descendant_accumulate_grads(
6767 """
6868
6969 excluded_nodes = set (excluded_nodes ) # Re-instantiate set to avoid modifying input
70- result = OrderedSet ([])
70+ result : OrderedSet [ Node ] = OrderedSet ([])
7171 roots .difference_update (excluded_nodes )
7272 nodes_to_traverse = deque (roots )
7373
You can’t perform that action at this time.
0 commit comments