1 change: 1 addition & 0 deletions dwave/plugins/torch/nn/modules/__init__.py
@@ -14,4 +14,5 @@
#

from dwave.plugins.torch.nn.modules.linear import *
from dwave.plugins.torch.nn.modules.orthogonal import *
from dwave.plugins.torch.nn.modules.utils import *
204 changes: 204 additions & 0 deletions dwave/plugins/torch/nn/modules/orthogonal.py
@@ -0,0 +1,204 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque

import torch
import torch.nn as nn
from einops import einsum

from dwave.plugins.torch.nn.modules.utils import store_config

__all__ = ["GivensRotation"]


class _RoundRobinGivens(torch.autograd.Function):
Collaborator:
Similar to _get_blocks_edges, should this be inside the class instead? Should this function only ever be used in the context of GivensRotation?
cc @thisac

Collaborator (Author):

"""Implements custom forward and backward passes to implement the parallel algorithms in
https://arxiv.org/abs/2106.00003

.. note::
We adopt the notation from the paper, but instead of using the rows of U to compute
rotations, we follow the standard convention of using the columns of U. Since U is
orthogonal, this does not affect the result.
"""

@staticmethod
def forward(ctx, angles: torch.Tensor, blocks: torch.Tensor, n: int) -> torch.Tensor:
"""Creates a rotation matrix in n dimensions using parallel Givens transformations by
blocks.

Args:
ctx (context): Stores information for backward propagation.
angles (torch.Tensor): A ``((n - 1) * n // 2,)`` shaped tensor containing all rotations
between pairs of dimensions.
blocks (torch.Tensor): A ``(n - 1, n // 2, 2)`` shaped tensor containing the indices
that specify rotations between pairs of dimensions. Each of the ``n - 1`` blocks
contains ``n // 2`` pairs of independent rotations.
n (int): Dimension of the space.

Returns:
torch.Tensor: The nxn rotation matrix.
"""
# Blocks is of shape (n_blocks, n/2, 2) containing indices for angles
# Within each block, each Givens rotation is commuting, so we can apply them in parallel
U = torch.eye(n, device=angles.device, dtype=angles.dtype)
Contributor:
Slight preference to keep variables lower-case.

Collaborator (Author):
I changed this in the main GivensRotationLayer class. In the other code, I kept the capital letters just so that if someone is reading the algorithm in the paper alongside the code, each part of the algorithm is more easily understood.

Collaborator:
I also favour variable names in lower case, with upper case reserved for constants, primarily because it is a widely adopted convention. I agree that having a 1-1 correspondence between paper notation and implementation is important for readability, but I think making exceptions paper-by-paper can become messy. I suggest noting the correspondence between variable names and paper notation in the docstring.
--
I think I snuck some upper-case variable names into the codebase... should track those down at some point 😅

Collaborator (@kevinchern, Jan 8, 2026):
Playing devil's advocate against myself here: sometimes descriptive variable names are unnecessarily verbose and unhelpful in describing the algorithm.
🤷‍♀️

Collaborator (Author):
I think that for some algorithms, readers are more used to certain notations, for example, that U is orthogonal (actually orthonormal). I would vote for picking a convention 😆

block_size = n // 2
idx_block = torch.arange(block_size, device=angles.device)
B = blocks # to keep the same notation as in the paper
for b, block in enumerate(B):
# angles is of shape (n_angles,) containing all angles for contiguous blocks.
angles_in_block = angles[idx_block + b * block_size] # shape (n/2,)
c = torch.cos(angles_in_block).unsqueeze(0)
s = torch.sin(angles_in_block).unsqueeze(0)
i_idx = block[:, 0]
j_idx = block[:, 1]
r_i = c * U[:, i_idx] + s * U[:, j_idx]
r_j = -s * U[:, i_idx] + c * U[:, j_idx]
U[:, i_idx] = r_i
U[:, j_idx] = r_j
ctx.save_for_backward(angles, B, U)
return U

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
"""Computes the vector-Jacobian product needed for backward propagation.

Args:
ctx (context): Contains information for backward propagation.
grad_output (torch.Tensor): A tensor containing the partial derivatives for the loss
with respect to the output of the forward pass, i.e., dL/dU.

Returns:
tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the input
angles. No calculation of gradients with respect to blocks or n is needed (cf.
forward method), so None is returned for these.
"""
angles, B, Ufwd_saved = ctx.saved_tensors
Ufwd = Ufwd_saved.clone()
M = grad_output.t() # dL/dU, i.e., grad_output is of shape (n, n)
n = M.size(1)
block_size = n // 2
A = torch.zeros((block_size, n), device=angles.device, dtype=angles.dtype)
grad_theta = torch.zeros_like(angles, dtype=angles.dtype)
idx_block = torch.arange(block_size, device=angles.device)
for b, block in enumerate(B):
i_idx = block[:, 0]
j_idx = block[:, 1]
angles_in_block = angles[idx_block + b * block_size] # shape (n/2,)
c = torch.cos(angles_in_block)
s = torch.sin(angles_in_block)
r_i = c.unsqueeze(1) * Ufwd[i_idx] + s.unsqueeze(1) * Ufwd[j_idx]
r_j = -s.unsqueeze(1) * Ufwd[i_idx] + c.unsqueeze(1) * Ufwd[j_idx]
Ufwd[i_idx] = r_i
Ufwd[j_idx] = r_j
r_i = c.unsqueeze(0) * M[:, i_idx] + s.unsqueeze(0) * M[:, j_idx]
r_j = -s.unsqueeze(0) * M[:, i_idx] + c.unsqueeze(0) * M[:, j_idx]
M[:, i_idx] = r_i
M[:, j_idx] = r_j
A[:] = M[:, j_idx].T * Ufwd[i_idx] - M[:, i_idx].T * Ufwd[j_idx]
grad_theta[idx_block + b * block_size] = A.sum(dim=1)
return grad_theta, None, None


class GivensRotation(nn.Module):
"""An orthogonal layer implementing a rotation using a sequence of Givens rotations arranged in
a round-robin fashion.

Angles are arranged into blocks, where each block references rotations that can be applied in
parallel because these rotations commute.

Args:
n (int): Dimension of the input and output space. Must be at least 2.
bias (bool): If True, adds a learnable bias to the output. Default: True.
"""

@store_config
def __init__(self, n: int, bias: bool = True):
super().__init__()
if not isinstance(n, int) or n <= 1:
raise ValueError(f"n must be an integer greater than 1, {n} was passed")
if not isinstance(bias, bool):
raise ValueError(f"bias must be a boolean, {bias} was passed")
self.n = n
self.n_angles = n * (n - 1) // 2
self.angles = nn.Parameter(torch.randn(self.n_angles))
blocks_edges = self._get_blocks_edges(n)
self.register_buffer("blocks", blocks_edges)
if bias:
self.bias = nn.Parameter(torch.zeros(n))
else:
self.register_parameter("bias", None)

@staticmethod
def _get_blocks_edges(n: int) -> torch.Tensor:
"""Uses the circle method for Round Robin pairing to create blocks of edges for parallel
Givens rotations.

A block is a list of pairs of indices indicating which coordinates to rotate together. Pairs
in the same block can be rotated in parallel since they commute.

Args:
n (int): Dimension of the vector space onto which an orthogonal layer will be built.

Returns:
torch.Tensor: Blocks of edges for parallel Givens rotations stored in a tensor of shape
``(n - 1, n // 2, 2)``.

.. note::
If n is odd, a dummy dimension is added to make it even. When using the resulting blocks
to build an orthogonal transformation, rotations involving the dummy dimension should be
ignored.
"""
is_odd = bool(n % 2 != 0)
if is_odd:
# The circle method requires an even number of nodes, so we add a dummy dimension; the
# additional rotations involving this dimension will be ignored later.
n += 1

def circle_method(sequence):
seq_first_half = sequence[: len(sequence) // 2]
seq_second_half = sequence[len(sequence) // 2 :][::-1]
return list(zip(seq_first_half, seq_second_half))

blocks = []
sequence = list(range(n))
sequence_deque = deque(sequence[1:])
for _ in range(n - 1):
pairs = circle_method(sequence)
if is_odd:
# Remove pairs involving the dummy dimension:
pairs = [pair for pair in pairs if n - 1 not in pair]
blocks.append(pairs)
sequence_deque.rotate(1)
sequence[1:] = list(sequence_deque)
return torch.tensor(blocks, dtype=torch.long)

def _create_rotation_matrix(self) -> torch.Tensor:
"""Computes the Givens rotation matrix."""
return _RoundRobinGivens.apply(self.angles, self.blocks, self.n)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies the Givens rotation to the input tensor ``x``.

Args:
x (torch.Tensor): Input tensor of shape ``(..., n)``.

Returns:
torch.Tensor: Rotated tensor of shape ``(..., n)``.
"""
unitary = self._create_rotation_matrix()
rotated_x = einsum(x, unitary, "... i, o i -> ... o")
if self.bias is not None:
rotated_x += self.bias
return rotated_x
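
A minimal, hypothetical usage sketch of the new layer (not part of this diff), assuming the import path exposed by the __init__.py change above. It checks two properties implied by the implementation: the assembled matrix is orthogonal, and the forward pass preserves Euclidean norms when bias is disabled. The call to the private _create_rotation_matrix is for illustration only.

```python
import torch

from dwave.plugins.torch.nn.modules.orthogonal import GivensRotation

# A 6-dimensional rotation layer; bias is disabled so the map is purely orthogonal.
layer = GivensRotation(n=6, bias=False)

# The round-robin schedule yields (n - 1) blocks of n // 2 commuting rotations.
print(layer.blocks.shape)  # torch.Size([5, 3, 2])

# An orthogonal map preserves Euclidean norms.
x = torch.randn(4, 6)
y = layer(x)
assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)

# The assembled matrix should satisfy U @ U.T = I (private method, illustration only).
u = layer._create_rotation_matrix()
assert torch.allclose(u @ u.T, torch.eye(6), atol=1e-5)
```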
1 change: 1 addition & 0 deletions pyproject.toml
@@ -32,6 +32,7 @@ dependencies = [
"dimod",
"dwave-system",
"dwave-hybrid",
"einops",
]

[project.readme]
@@ -0,0 +1,4 @@
---
features:
- |
Add orthogonal rotation layer using Givens rotations.
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,6 +6,7 @@ torch==2.9.1
dimod==0.12.21
dwave-system==1.34.0
dwave-hybrid==0.6.14
einops==0.8.1

# Development requirements
reno==4.1.0
87 changes: 87 additions & 0 deletions tests/helper_models.py
@@ -0,0 +1,87 @@
import torch
import torch.nn as nn
from einops import einsum


class NaiveGivensRotationLayer(nn.Module):
Contributor:
I'm not very keen on having a full-on separate implementation here just to compare with/test the GivensRotationLayer. If this NaiveGivensRotationLayer is useful, should it be part of the package instead?

Collaborator (Author):
We discussed this in our one-on-one but, just for the record, there is no difference between the NaiveGivensRotationLayer and the GivensRotationLayer in the forward or backward passes. The naïve implementation is there to make sure that the forward and backward passes indeed match. The GivensRotationLayer should always be used because it has a substantially better runtime complexity. Thus, the naïve implementation is not useful other than as a sanity check (a sketch of such a check follows this file).

"""Naive implementation of a Givens rotation layer.

Sequentially applies all Givens rotations to implement an orthogonal transformation in the order
given by blocks, a tensor of shape (n_blocks, n/2, 2). Usually each block contains pairs of
indices such that no index appears more than once within a block, but this implementation does
not rely on that assumption: an index may appear multiple times in a block, as long as every
pair of indices appears exactly once in the entire blocks tensor.

Args:
in_features (int): Number of input features.
out_features (int): Number of output features.
bias (bool): If True, adds a learnable bias to the output. Default: True.

Note:
This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to
out_features.
"""

def __init__(self, in_features: int, out_features: int, bias: bool = True):
super().__init__()
assert in_features == out_features, (
"This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to "
"out_features."
)
self.n = in_features
self.angles = nn.Parameter(torch.randn(in_features * (in_features - 1) // 2))
if bias:
self.bias = nn.Parameter(torch.zeros(out_features))
else:
self.register_parameter("bias", None)

def _create_rotation_matrix(self, angles, blocks: torch.Tensor | None = None):
"""Creates the rotation matrix from the Givens angles by applying the Givens rotations in
order and sequentially, as specified by blocks.

Args:
angles (torch.Tensor): Givens rotation angles.
blocks (torch.Tensor | None, optional): Blocks specifying the order of rotations. If
None, all possible pairs of dimensions will be shaped into (n-1, n/2, 2) to create
the blocks. Defaults to None.

Returns:
torch.Tensor: Rotation matrix.
"""
block_size = self.n // 2
if blocks is None:
# Create dummy blocks from triu indices:
triu_indices = torch.triu_indices(self.n, self.n, offset=1)
blocks = triu_indices.t().view(-1, block_size, 2)
U = torch.eye(self.n, dtype=angles.dtype)
for b, block in enumerate(blocks):
for k in range(block_size):
i = block[k, 0].item()
j = block[k, 1].item()
angle = angles[b * block_size + k]
c = torch.cos(angle)
s = torch.sin(angle)
Ge = torch.eye(self.n, dtype=angles.dtype)
Ge[i, i] = c
Ge[j, j] = c
Ge[i, j] = -s
Ge[j, i] = s
# Explicit Givens rotation
U = U @ Ge
return U

def forward(self, x: torch.Tensor, blocks: torch.Tensor) -> torch.Tensor:
"""Applies the Givens rotation to the input tensor ``x``.

Args:
x (torch.Tensor): Input tensor of shape (..., n).
blocks (torch.Tensor): Blocks specifying the order of rotations.

Returns:
torch.Tensor: Rotated tensor of shape (..., n).
"""
W = self._create_rotation_matrix(self.angles, blocks)
x = einsum(x, W, "... i, o i -> ... o")
if self.bias is not None:
x = x + self.bias
return x
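
A hypothetical sketch of the sanity check discussed in the thread above (not part of this diff): share the angles between GivensRotation and NaiveGivensRotationLayer, give the naive layer the fast layer's round-robin blocks, and confirm that outputs and angle gradients agree. The tests.helper_models import path assumes the repository root is on sys.path.

```python
import torch

from dwave.plugins.torch.nn.modules.orthogonal import GivensRotation
from tests.helper_models import NaiveGivensRotationLayer

n = 8
fast = GivensRotation(n=n, bias=False)
naive = NaiveGivensRotationLayer(n, n, bias=False)

# Share the same angles so both layers realize the same rotation; the naive layer
# is given the fast layer's round-robin blocks so the rotation order also matches.
with torch.no_grad():
    naive.angles.copy_(fast.angles)

x = torch.randn(3, n)
y_fast = fast(x)
y_naive = naive(x, fast.blocks)
assert torch.allclose(y_fast, y_naive, atol=1e-5)

# The custom backward pass should agree with plain autograd through the naive layer.
(g_fast,) = torch.autograd.grad(y_fast.sum(), fast.angles)
(g_naive,) = torch.autograd.grad(y_naive.sum(), naive.angles)
assert torch.allclose(g_fast, g_naive, atol=1e-4)
```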