Skip to content

Commit ffcbcd3

Browse files
committed
refactor: pre_conditioners_for_grad type
1 parent fb493ae commit ffcbcd3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytorch_optimizer/optimizer/shampoo_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def compute_pre_conditioners(self) -> None:
333333
def precondition_block(
334334
partitioned_grad: torch.Tensor,
335335
should_preconditioned_dims: List[bool],
336-
pre_conditioners_for_grad: List[torch.Tensor],
336+
pre_conditioners_for_grad: Union[List[torch.Tensor], torch.Tensor],
337337
) -> torch.Tensor:
338338
r"""Perform a preconditioning operation on a single gradient block.
339339

0 commit comments

Comments
 (0)