Skip to content

Commit bab9a50

Browse files
authored
refactor: Remove unnecessary casts to list (#425)
1 parent 7a71d50 commit bab9a50

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def forward(
163163
module: nn.Module,
164164
*xs: Tensor,
165165
) -> tuple[Tensor, ...]:
166-
return tuple([x.detach() for x in xs])
166+
return tuple(x.detach() for x in xs)
167167

168168
# For Python version > 3.10, the type of `inputs` should become
169169
# tuple[BoolRef, TreeSpec, VJPType, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]

src/torchjd/autojac/_transform/_grad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
5151
return tuple()
5252

5353
if len(self.outputs) == 0:
54-
return tuple([torch.zeros_like(input) for input in self.inputs])
54+
return tuple(torch.zeros_like(input) for input in self.inputs)
5555

5656
grads = self._get_vjp(grad_outputs, self.retain_graph)
5757
return grads

0 commit comments

Comments
 (0)