1212# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313# See the License for the specific language governing permissions and
1414# limitations under the License.
15- from typing import List
16-
1715import torch
1816
1917
@@ -43,7 +41,7 @@ def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.
4341
4442
4543@torch .compile # type: ignore[misc]
46- def apply_kronecker_factors (Q_list : List [torch .Tensor ], X : torch .Tensor ) -> torch .Tensor :
44+ def apply_kronecker_factors (Q_list : list [torch .Tensor ], X : torch .Tensor ) -> torch .Tensor :
4745 """Apply all Kronecker factors once to tensor :math:`X`, each to its corresponding dimension.
4846
4947 This applies each :math:`Q` factor once, for example in 2D case: :math:`Q_1 X Q_2^T`.
@@ -67,7 +65,7 @@ def apply_kronecker_factors(Q_list: List[torch.Tensor], X: torch.Tensor) -> torc
6765
6866
6967@torch .compile # type: ignore[misc]
70- def apply_preconditioner (Q_list : List [torch .Tensor ], X : torch .Tensor ) -> torch .Tensor :
68+ def apply_preconditioner (Q_list : list [torch .Tensor ], X : torch .Tensor ) -> torch .Tensor :
7169 """Apply the full PSGD preconditioner to X.
7270
7371 This is the full Kronecker product of PSGD's kronecker factors Q^T Q, applied to X.
@@ -130,7 +128,7 @@ def _dim_n_mul_and_permute(X: torch.Tensor, M: torch.Tensor, contract_dim: int)
130128
131129
132130@torch .compile # type: ignore[misc]
133- def _apply_single_kronecker_factor (Q_list : List [torch .Tensor ], X : torch .Tensor , axis : int ) -> torch .Tensor :
131+ def _apply_single_kronecker_factor (Q_list : list [torch .Tensor ], X : torch .Tensor , axis : int ) -> torch .Tensor :
134132 """Apply a single Kronecker factor Q to X at dimension `axis`. Helper function for apply_kronecker_factors.
135133
136134 If Q is a vector, we multiply X by Q.
0 commit comments