|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +from typing import List |
| 16 | + |
| 17 | +import torch |
| 18 | + |
| 19 | + |
# Names exported by `from <module> import *`; the underscore-prefixed helpers
# defined further down in this file are not listed here.
__all__ = [
    "partial_contraction",
    "apply_kronecker_factors",
    "apply_preconditioner",
]
| 25 | + |
| 26 | + |
@torch.compile  # type: ignore[misc]
def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.Tensor:
    """Contract G1 with G2 over every axis except `axis`.

    All dimensions other than `axis` are summed out pairwise, so entry ``(i, j)`` of the
    result is the inner product of slice ``i`` of G1 and slice ``j`` of G2 taken along `axis`.

    Args:
        G1: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N)
        G2: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N)
        axis: int, the axis that is kept (i.e. NOT contracted); negative values index from the end

    Returns:
        Tensor of shape (d_{axis}, d_{axis})
    """
    # Every dimension index of G1, with the preserved one removed.
    contracted = [*range(G1.dim())]
    del contracted[axis]
    # The same index list is used for both operands. The result is symmetric only
    # when G1 and G2 are the same tensor.
    return torch.tensordot(G1, G2, dims=(contracted, contracted))
| 45 | + |
| 46 | + |
@torch.compile  # type: ignore[misc]
def apply_kronecker_factors(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
    """Apply every Kronecker factor once to tensor :math:`X`, one factor per dimension.

    Factor ``Q_list[i]`` acts on dimension ``i`` of :math:`X`. In the 2D case with matrix
    factors the result is :math:`Q_1 X Q_2^T`.

    Args:
        Q_list: List of :math:`Q` factors, one per dimension of X, each of shape `(d_i, d_i)`
            or `(d_i,)`. A 1-D factor scales X elementwise along its dimension.
        X: Tensor of shape `(d_0, d_1, ..., d_N)`.

    Returns:
        Tensor of shape `(d_0, d_1, ..., d_N)`.

    Raises:
        ValueError: If the number of factors differs from the number of dimensions of X.
    """
    num_factors = len(Q_list)
    if num_factors != X.dim():
        raise ValueError(
            f"Number of Kronecker factors {num_factors} must match the number of dimensions of X {X.dim()}"
        )

    # Thread the intermediate result through one factor per axis, in axis order.
    result = X
    for axis in range(num_factors):
        result = _apply_single_kronecker_factor(Q_list, result, axis)
    return result
| 69 | + |
| 70 | + |
@torch.compile  # type: ignore[misc]
def apply_preconditioner(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
    """Apply the full PSGD preconditioner to X.

    Each dimension of X is acted on by :math:`Q^T Q`, built from that dimension's Kronecker
    factor. In the 2D case with matrix factors:

    :math:`P X = (Q_1^T Q_1) X (Q_2^T Q_2)`

    This is done in two passes over all dimensions: first every :math:`Q` is applied, then
    every :math:`Q^T`. A 1-D factor is used unchanged in both passes.

    Args:
        Q_list: List of :math:`Q` (the Kronecker factors), each of shape `(d_i, d_i)` or `(d_i,)`.
        X: Tensor of shape `(d_0, d_1, ..., d_N)`.

    Returns:
        Tensor of shape `(d_0, d_1, ..., d_N)`.
    """
    # Pass 1: Q along every dimension.
    once_applied = apply_kronecker_factors(Q_list, X)
    # Pass 2: Q^T along every dimension; 1-D factors have nothing to transpose.
    transposed_factors = [q.T if q.dim() != 1 else q for q in Q_list]
    return apply_kronecker_factors(transposed_factors, once_applied)
| 92 | + |
| 93 | + |
def _dim_n_mul_and_permute(X: torch.Tensor, M: torch.Tensor, contract_dim: int) -> torch.Tensor:
    """Multiply tensor X along axis `contract_dim` by 2D matrix M.

    Helper function for `_apply_single_kronecker_factor`.
    If M is (d_out, d_in) we contract M's second index with X's `contract_dim` index.
    `torch.tensordot` places the new d_out axis first, so it is then moved back to
    position `contract_dim`. The result has the same rank as X, with size[contract_dim]
    replaced by d_out. Note that d_{contract_dim} == d_in.

    Args:
        X: Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_{contract_dim}, d_{contract_dim+1}, ..., d_N)
        M: Tensor of shape (d_out, d_in)
        contract_dim: int, the dimension to contract with M, with d_{contract_dim} == d_in.
            Negative values index from the end, as in regular tensor indexing.

    Returns:
        Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_out, d_{contract_dim+1}, ..., d_N)

    Raises:
        ValueError: If X.shape[contract_dim] differs from M.shape[1].

    Examples
    --------
    >>> X = torch.randn(2, 3, 6)
    >>> M = torch.randn(5, 6)
    >>> contract_dim = 2
    >>> result = _dim_n_mul_and_permute(X, M, contract_dim)
    >>> print(result.shape)
    torch.Size([2, 3, 5])
    >>> print(_dim_n_mul_and_permute(X, M, -1).shape)
    torch.Size([2, 3, 5])

    """
    if X.shape[contract_dim] != M.shape[1]:
        raise ValueError(
            f"Shape mismatch: X.shape[{contract_dim}] = {X.shape[contract_dim]}, M.shape[1] = {M.shape[1]}"
        )
    # Contract M's 2nd dim (idx=1) with X's `contract_dim` dim.
    # Y has shape (d_out, d_0, ..., d_{contract_dim-1}, d_{contract_dim+1}, ...).
    Y = torch.tensordot(M, X, dims=([1], [contract_dim]))
    # Move the leading d_out axis back to `contract_dim`. `movedim` keeps the other axes in
    # order and accepts negative positions, which a hand-built permutation list did not.
    return Y.movedim(0, contract_dim)
| 132 | + |
| 133 | + |
@torch.compile  # type: ignore[misc]
def _apply_single_kronecker_factor(Q_list: List[torch.Tensor], X: torch.Tensor, axis: int) -> torch.Tensor:
    """Apply the Kronecker factor ``Q_list[axis]`` to X along dimension `axis`.

    Helper function for apply_kronecker_factors.

    If the factor is a vector, X is scaled elementwise along `axis` by it.
    Otherwise the factor's second index is contracted with X's `axis` index.

    Args:
        Q_list: List of Q (e.g. the Kronecker factors); only entry `axis` is used.
        X: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N)
        axis: int, the dimension of X the factor acts on.

    Returns:
        Tensor with the factor applied along `axis`. For a vector factor the shape follows
        broadcasting of X against the reshaped factor.
    """
    factor = Q_list[axis]
    if factor.dim() != 1:
        # Matrix factor: contract along `axis`.
        return _dim_n_mul_and_permute(X, factor, contract_dim=axis)

    # Vector factor: reshape to (1, ..., d_axis, ..., 1) so it broadcasts along `axis` only.
    broadcast_shape = [1] * X.dim()
    broadcast_shape[axis] = factor.size(0)
    return X * factor.view(broadcast_shape)
0 commit comments