Skip to content

Commit e50192f

Browse files
committed
added torch compile
Signed-off-by: mikail <[email protected]>
1 parent dab19bc commit e50192f

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

emerging_optimizers/psgd/procrustes_step.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
]
2424

2525

26+
@torch.compile # type: ignore[misc]
2627
def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125) -> torch.Tensor:
2728
r"""One step of an online solver for the orthogonal Procrustes problem.
2829

emerging_optimizers/psgd/psgd_kron_contractions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
]
2525

2626

27+
@torch.compile # type: ignore[misc]
2728
def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.Tensor:
2829
"""Compute the partial contraction of G1 and G2 along axis `axis`.
2930
This is the contraction of the two tensors, but with all axes except `axis` contracted.
@@ -43,7 +44,7 @@ def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.
4344
return torch.tensordot(G1, G2, dims=(dims, dims))
4445

4546

46-
# @torch.compile # type: ignore[misc]
47+
@torch.compile # type: ignore[misc]
4748
def apply_kronecker_factors(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
4849
"""Apply all Kronecker factors once to tensor :math:`X`, each to its corresponding dimension.
4950
@@ -67,7 +68,7 @@ def apply_kronecker_factors(Q_list: List[torch.Tensor], X: torch.Tensor) -> torc
6768
return Y
6869

6970

70-
# @torch.compile # type: ignore[misc]
71+
@torch.compile # type: ignore[misc]
7172
def apply_preconditioner(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
7273
"""Apply the full PSGD preconditioner to X.
7374
@@ -130,6 +131,7 @@ def _dim_n_mul_and_permute(X: torch.Tensor, M: torch.Tensor, contract_dim: int)
130131
return Y.permute(perm)
131132

132133

134+
@torch.compile # type: ignore[misc]
133135
def _apply_single_kronecker_factor(Q_list: List[torch.Tensor], X: torch.Tensor, axis: int) -> torch.Tensor:
134136
"""Apply a single Kronecker factor Q to X at dimension `axis`. Helper function for apply_kronecker_factors.
135137

0 commit comments

Comments
 (0)