2424]
2525
2626
27+ @torch .compile # type: ignore[misc]
2728def partial_contraction (G1 : torch .Tensor , G2 : torch .Tensor , axis : int ) -> torch .Tensor :
2829 """Compute the partial contraction of G1 and G2 along axis `axis`.
2930 This is the contraction of the two tensors, but with all axes except `axis` contracted.
@@ -43,7 +44,7 @@ def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.
4344 return torch .tensordot (G1 , G2 , dims = (dims , dims ))
4445
4546
46- # @torch.compile # type: ignore[misc]
47+ @torch .compile # type: ignore[misc]
4748def apply_kronecker_factors (Q_list : List [torch .Tensor ], X : torch .Tensor ) -> torch .Tensor :
4849 """Apply all Kronecker factors once to tensor :math:`X`, each to its corresponding dimension.
4950
@@ -67,7 +68,7 @@ def apply_kronecker_factors(Q_list: List[torch.Tensor], X: torch.Tensor) -> torc
6768 return Y
6869
6970
70- # @torch.compile # type: ignore[misc]
71+ @torch .compile # type: ignore[misc]
7172def apply_preconditioner (Q_list : List [torch .Tensor ], X : torch .Tensor ) -> torch .Tensor :
7273 """Apply the full PSGD preconditioner to X.
7374
@@ -130,6 +131,7 @@ def _dim_n_mul_and_permute(X: torch.Tensor, M: torch.Tensor, contract_dim: int)
130131 return Y .permute (perm )
131132
132133
134+ @torch .compile # type: ignore[misc]
133135def _apply_single_kronecker_factor (Q_list : List [torch .Tensor ], X : torch .Tensor , axis : int ) -> torch .Tensor :
134136 """Apply a single Kronecker factor Q to X at dimension `axis`. Helper function for apply_kronecker_factors.
135137
0 commit comments