Givens orthogonal layer #57
@@ -0,0 +1,204 @@

```python
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque

import torch
import torch.nn as nn
from einops import einsum

from dwave.plugins.torch.nn.modules.utils import store_config

__all__ = ["GivensRotation"]


class _RoundRobinGivens(torch.autograd.Function):
    """Custom forward and backward passes implementing the parallel algorithms in
    https://arxiv.org/abs/2106.00003.

    .. note::
        We adopt the notation of the paper, but instead of using the rows of U to compute
        rotations, we follow the standard convention of using the columns of U. Since U is
        orthogonal, this does not affect the result.
    """

    @staticmethod
    def forward(ctx, angles: torch.Tensor, blocks: torch.Tensor, n: int) -> torch.Tensor:
        """Creates a rotation matrix in n dimensions using parallel Givens transformations
        applied by blocks.

        Args:
            ctx (context): Stores information for backward propagation.
            angles (torch.Tensor): A ``((n - 1) * n // 2,)`` shaped tensor containing all
                rotation angles between pairs of dimensions.
            blocks (torch.Tensor): A ``(n - 1, n // 2, 2)`` shaped tensor containing the indices
                that specify rotations between pairs of dimensions. Each of the ``n - 1`` blocks
                contains ``n // 2`` pairs of independent rotations.
            n (int): Dimension of the space.

        Returns:
            torch.Tensor: The ``n x n`` rotation matrix.
        """
        # blocks has shape (n_blocks, n/2, 2) and holds the index pairs for the angles.
        # Within each block the Givens rotations commute, so they can be applied in parallel.
        U = torch.eye(n, device=angles.device, dtype=angles.dtype)
```
> **Contributor:** Slight preference to keep variables lower-case.

> **Author:** I changed this in the main

> **Collaborator:** I also favour variable names in lower case with upper case reserved for constants, primarily because it is a widely adopted convention. I agree having a 1-1 correspondence between paper notation and implementation is important for readability, but I think making exceptions paper-by-paper can become messy. I suggest noting the correspondence between variable names and paper notation in the docstring.

> **Collaborator:** Playing devil's advocate against myself here: sometimes descriptive variable names are unnecessarily verbose and unhelpful in describing the algorithm.

> **Author:** I think that for some algorithms, readers are more used to certain notations, for example, that U is orthogonal (actually orthonormal). I would make a vote for picking a convention 😆
```python
        block_size = n // 2
        idx_block = torch.arange(block_size, device=angles.device)
        B = blocks  # to keep the same notation as in the paper
        for b, block in enumerate(B):
            # angles has shape (n_angles,) and stores the angles of contiguous blocks.
            angles_in_block = angles[idx_block + b * block_size]  # shape (n/2,)
            c = torch.cos(angles_in_block).unsqueeze(0)
            s = torch.sin(angles_in_block).unsqueeze(0)
            i_idx = block[:, 0]
            j_idx = block[:, 1]
            r_i = c * U[:, i_idx] + s * U[:, j_idx]
            r_j = -s * U[:, i_idx] + c * U[:, j_idx]
            U[:, i_idx] = r_i
            U[:, j_idx] = r_j
        ctx.save_for_backward(angles, B, U)
        return U

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
        """Computes the vector-Jacobian product needed for backward propagation.

        Args:
            ctx (context): Contains information for backward propagation.
            grad_output (torch.Tensor): A tensor containing the partial derivatives of the loss
                with respect to the output of the forward pass, i.e., dL/dU.

        Returns:
            tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the input
                angles. No gradients with respect to ``blocks`` or ``n`` are needed (cf. the
                forward method), so ``None`` is returned for these.
        """
        angles, B, Ufwd_saved = ctx.saved_tensors
        Ufwd = Ufwd_saved.clone()
        M = grad_output.t()  # dL/dU, i.e., grad_output is of shape (n, n)
        n = M.size(1)
        block_size = n // 2
        A = torch.zeros((block_size, n), device=angles.device, dtype=angles.dtype)
        grad_theta = torch.zeros_like(angles, dtype=angles.dtype)
        idx_block = torch.arange(block_size, device=angles.device)
        for b, block in enumerate(B):
            i_idx = block[:, 0]
            j_idx = block[:, 1]
            angles_in_block = angles[idx_block + b * block_size]  # shape (n/2,)
            c = torch.cos(angles_in_block)
            s = torch.sin(angles_in_block)
            # Rotate the rows of Ufwd for this block.
            r_i = c.unsqueeze(1) * Ufwd[i_idx] + s.unsqueeze(1) * Ufwd[j_idx]
            r_j = -s.unsqueeze(1) * Ufwd[i_idx] + c.unsqueeze(1) * Ufwd[j_idx]
            Ufwd[i_idx] = r_i
            Ufwd[j_idx] = r_j
            # Rotate the columns of M (the incoming gradient) for this block.
            r_i = c.unsqueeze(0) * M[:, i_idx] + s.unsqueeze(0) * M[:, j_idx]
            r_j = -s.unsqueeze(0) * M[:, i_idx] + c.unsqueeze(0) * M[:, j_idx]
            M[:, i_idx] = r_i
            M[:, j_idx] = r_j
            A[:] = M[:, j_idx].T * Ufwd[i_idx] - M[:, i_idx].T * Ufwd[j_idx]
            grad_theta[idx_block + b * block_size] = A.sum(dim=1)
        return grad_theta, None, None
```
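A custom `backward` like this is typically validated against PyTorch's numerical gradient checker. The snippet below is a minimal sketch of such a check, not part of this diff; it assumes `_RoundRobinGivens` and `GivensRotation` (defined later in this file) are importable, and it captures the non-differentiable `blocks` and `n` arguments in a closure.

```python
# Sketch of a numerical gradient check for the custom backward (not part of this PR).
import torch

n = 6
blocks = GivensRotation._get_blocks_edges(n)  # round-robin index pairs, shape (n - 1, n // 2, 2)
# gradcheck requires double-precision inputs with requires_grad=True.
angles = torch.randn(n * (n - 1) // 2, dtype=torch.double, requires_grad=True)

# blocks and n are non-differentiable, so pass only angles to gradcheck.
torch.autograd.gradcheck(lambda a: _RoundRobinGivens.apply(a, blocks, n), (angles,))
```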
```python
class GivensRotation(nn.Module):
    """An orthogonal layer implementing a rotation using a sequence of Givens rotations arranged
    in a round-robin fashion.

    Angles are arranged into blocks, where each block references rotations that can be applied in
    parallel because these rotations commute.

    Args:
        n (int): Dimension of the input and output space. Must be at least 2.
        bias (bool): If True, adds a learnable bias to the output. Default: True.
    """

    @store_config
    def __init__(self, n: int, bias: bool = True):
        super().__init__()
        if not isinstance(n, int) or n <= 1:
            raise ValueError(f"n must be an integer greater than 1, {n} was passed")
        if not isinstance(bias, bool):
            raise ValueError(f"bias must be a boolean, {bias} was passed")
        self.n = n
        self.n_angles = n * (n - 1) // 2
        self.angles = nn.Parameter(torch.randn(self.n_angles))
        blocks_edges = self._get_blocks_edges(n)
        self.register_buffer("blocks", blocks_edges)
        if bias:
            self.bias = nn.Parameter(torch.zeros(n))
        else:
            self.register_parameter("bias", None)

    @staticmethod
    def _get_blocks_edges(n: int) -> torch.Tensor:
        """Uses the circle method for round-robin pairing to create blocks of edges for parallel
        Givens rotations.

        A block is a list of pairs of indices indicating which coordinates to rotate together.
        Pairs in the same block can be rotated in parallel since they commute.

        Args:
            n (int): Dimension of the vector space on which the orthogonal layer is built.

        Returns:
            torch.Tensor: Blocks of edges for parallel Givens rotations stored in a tensor of
                shape ``(n - 1, n // 2, 2)``.

        .. note::
            If ``n`` is odd, a dummy dimension is added to make it even. When using the resulting
            blocks to build an orthogonal transformation, rotations involving the dummy dimension
            should be ignored. In that case the returned tensor has shape ``(n, n // 2, 2)``.
        """
        is_odd = bool(n % 2 != 0)
        if is_odd:
            # The circle method requires an even number of nodes, so we add a dummy dimension;
            # the additional rotations involving this dimension will be ignored later.
            n += 1

        def circle_method(sequence):
            seq_first_half = sequence[: len(sequence) // 2]
            seq_second_half = sequence[len(sequence) // 2 :][::-1]
            return list(zip(seq_first_half, seq_second_half))

        blocks = []
        sequence = list(range(n))
        sequence_deque = deque(sequence[1:])
        for _ in range(n - 1):
            pairs = circle_method(sequence)
            if is_odd:
                # Remove pairs involving the dummy dimension:
                pairs = [pair for pair in pairs if n - 1 not in pair]
            blocks.append(pairs)
            sequence_deque.rotate(1)
            sequence[1:] = list(sequence_deque)
        return torch.tensor(blocks, dtype=torch.long)
```
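For concreteness, tracing the circle method above for `n = 4` gives three blocks of two disjoint pairs each, covering all six coordinate pairs exactly once:

```python
# Round-robin blocks produced by the circle method for n = 4 (traced from the code above):
blocks = GivensRotation._get_blocks_edges(4)
print(blocks)
# tensor([[[0, 3], [1, 2]],
#         [[0, 2], [3, 1]],
#         [[0, 1], [2, 3]]])
# The two pairs within a block share no index, so their Givens
# rotations commute and can be applied in parallel.
```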
```python
    def _create_rotation_matrix(self) -> torch.Tensor:
        """Computes the Givens rotation matrix."""
        return _RoundRobinGivens.apply(self.angles, self.blocks, self.n)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the Givens rotation to the input tensor ``x``.

        Args:
            x (torch.Tensor): Input tensor of shape ``(..., n)``.

        Returns:
            torch.Tensor: Rotated tensor of shape ``(..., n)``.
        """
        unitary = self._create_rotation_matrix()
        rotated_x = einsum(x, unitary, "... i, o i -> ... o")
        if self.bias is not None:
            rotated_x += self.bias
        return rotated_x
```
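As a quick illustration of the layer in use (a sketch, not part of the diff): since the layer applies a pure rotation, the constructed matrix should be orthogonal and the transformation should preserve vector norms.

```python
# Illustrative usage sketch (not part of the PR).
import torch

layer = GivensRotation(n=8, bias=False)
x = torch.randn(32, 8)   # a batch of 32 vectors
y = layer(x)             # rotated batch, shape (32, 8)

U = layer._create_rotation_matrix()
print(torch.allclose(U @ U.T, torch.eye(8), atol=1e-5))            # U is orthogonal
print(torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5))   # norms preserved
```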
```diff
@@ -32,6 +32,7 @@ dependencies = [
     "dimod",
     "dwave-system",
     "dwave-hybrid",
+    "einops",
 ]

 [project.readme]
```
@@ -0,0 +1,4 @@

```yaml
---
features:
  - |
    Add orthogonal rotation layer using Givens rotations.
```
@@ -0,0 +1,87 @@

```python
import torch
import torch.nn as nn
from einops import einsum


class NaiveGivensRotationLayer(nn.Module):
```
> **Contributor:** I'm not very keen on having a full-on separate implementation here just to compare with/test the

> **Author:** We discussed this in our one-on-one but, just for the record, there is no difference between the
```python
    """Naive implementation of a Givens rotation layer.

    Sequentially applies all Givens rotations to implement an orthogonal transformation in the
    order given by ``blocks``, which has shape ``(n_blocks, n/2, 2)``. Usually each block
    contains pairs of indices such that no index appears more than once within a block; however,
    this implementation does not rely on that assumption, so indices may appear multiple times in
    a block. Every pair of indices must appear exactly once in the entire blocks tensor.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool): If True, adds a learnable bias to the output. Default: True.

    Note:
        This layer defines an ``n x n`` SO(n) rotation matrix, so ``in_features`` must be equal
        to ``out_features``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        assert in_features == out_features, (
            "This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to "
            "out_features."
        )
        self.n = in_features
        self.angles = nn.Parameter(torch.randn(in_features * (in_features - 1) // 2))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

    def _create_rotation_matrix(self, angles, blocks: torch.Tensor | None = None):
        """Creates the rotation matrix from the Givens angles by applying the Givens rotations
        sequentially, in the order specified by ``blocks``.

        Args:
            angles (torch.Tensor): Givens rotation angles.
            blocks (torch.Tensor | None, optional): Blocks specifying the order of rotations. If
                None, all possible pairs of dimensions are reshaped into ``(n - 1, n/2, 2)`` to
                create the blocks. Defaults to None.

        Returns:
            torch.Tensor: Rotation matrix.
        """
        block_size = self.n // 2
        if blocks is None:
            # Create dummy blocks from triu indices:
            triu_indices = torch.triu_indices(self.n, self.n, offset=1)
            blocks = triu_indices.t().view(-1, block_size, 2)
        U = torch.eye(self.n, dtype=angles.dtype)
        for b, block in enumerate(blocks):
            for k in range(block_size):
                i = block[k, 0].item()
                j = block[k, 1].item()
                angle = angles[b * block_size + k]
                c = torch.cos(angle)
                s = torch.sin(angle)
                # Explicit Givens rotation acting on coordinates i and j.
                Ge = torch.eye(self.n, dtype=angles.dtype)
                Ge[i, i] = c
                Ge[j, j] = c
                Ge[i, j] = -s
                Ge[j, i] = s
                U = U @ Ge
        return U

    def forward(self, x: torch.Tensor, blocks: torch.Tensor) -> torch.Tensor:
        """Applies the Givens rotation to the input tensor ``x``.

        Args:
            x (torch.Tensor): Input tensor of shape ``(..., n)``.
            blocks (torch.Tensor): Blocks specifying the order of rotations.

        Returns:
            torch.Tensor: Rotated tensor of shape ``(..., n)``.
        """
        W = self._create_rotation_matrix(self.angles, blocks)
        x = einsum(x, W, "... i, o i -> ... o")
        if self.bias is not None:
            x = x + self.bias
        return x
```
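Given the author's point that the two constructions should agree, an equivalence check along the following lines would be the natural test. This is a sketch, not the PR's actual test; it assumes `GivensRotation` is importable from the plugin and that both layers share the same angles and blocks.

```python
# Sketch of an equivalence check between the naive and parallel implementations.
import torch

n = 6
parallel = GivensRotation(n, bias=False)
naive = NaiveGivensRotationLayer(n, n, bias=False)
# Use the same angles in both layers so they build the same matrix.
naive.angles.data.copy_(parallel.angles.data)

U_parallel = parallel._create_rotation_matrix()
U_naive = naive._create_rotation_matrix(naive.angles, parallel.blocks)
print(torch.allclose(U_parallel, U_naive, atol=1e-5))  # expected: True
```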
> Similar to `_get_blocks_edges`, should this be in a class instead? Should this function only ever be used in the context of `GivensRotation`? cc @thisac

> It is a class, though.