
Commit 46062e0

Author: Vladimir Vargas Calderón (committed)
Add Givens orthogonal layer
1 parent eb65b87 commit 46062e0

File tree

6 files changed: +425 −20 lines changed

dwave/plugins/torch/nn/modules/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -14,4 +14,5 @@
 #

 from dwave.plugins.torch.nn.modules.linear import *
+from dwave.plugins.torch.nn.modules.orthogonal import *
 from dwave.plugins.torch.nn.modules.utils import *
dwave/plugins/torch/nn/modules/orthogonal.py

Lines changed: 200 additions & 0 deletions

@@ -0,0 +1,200 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque

import torch
import torch.nn as nn
from einops import einsum

from dwave.plugins.torch.nn.modules.utils import store_config

__all__ = ["GivensRotationLayer"]


def _get_blocks_edges(n: int) -> list[list[tuple[int, int]]]:
    """Uses the circle method for round-robin pairing to create blocks of edges for parallel
    Givens rotations.

    A block is a list of pairs of indices indicating which coordinates to rotate together. Pairs
    in the same block can be rotated in parallel since they commute.

    Args:
        n (int): Dimension of the vector space on which an orthogonal layer will be built.

    Returns:
        list[list[tuple[int, int]]]: Blocks of edges for parallel Givens rotations.

    Note:
        If n is odd, a dummy dimension is added to make it even. When using the resulting blocks
        to build an orthogonal transformation, rotations involving the dummy dimension should be
        ignored.
    """
    if n % 2 != 0:
        n += 1  # Add a dummy dimension for odd n
        is_odd = True
    else:
        is_odd = False

    def circle_method(sequence):
        seq_first_half = sequence[: len(sequence) // 2]
        seq_second_half = sequence[len(sequence) // 2 :][::-1]
        return list(zip(seq_first_half, seq_second_half))

    blocks = []
    sequence = list(range(n))
    seqdeque = deque(sequence[1:])
    for _ in range(n - 1):
        pairs = circle_method(sequence)
        if is_odd:
            # Remove pairs involving the dummy dimension:
            pairs = [pair for pair in pairs if n - 1 not in pair]
        blocks.append(pairs)
        seqdeque.rotate(1)
        sequence[1:] = list(seqdeque)
    return blocks


class _RoundRobinGivens(torch.autograd.Function):
    """Implements custom forward and backward passes for the parallel algorithms in
    https://arxiv.org/abs/2106.00003
    """

    @staticmethod
    def forward(ctx, angles: torch.Tensor, blocks: torch.Tensor, n: int) -> torch.Tensor:
        """Creates a rotation matrix in n dimensions using parallel Givens transformations by
        blocks.

        Args:
            ctx (context): Stores information for backward propagation.
            angles (torch.Tensor): A ``((n - 1) * n // 2,)`` shaped tensor containing all
                rotations between pairs of dimensions.
            blocks (torch.Tensor): A ``(n - 1, n // 2, 2)`` shaped tensor containing the indices
                that specify rotations between pairs of dimensions. Each of the ``n - 1`` blocks
                contains ``n // 2`` pairs of independent rotations.
            n (int): Dimension of the space.

        Returns:
            torch.Tensor: The n-by-n rotation matrix.
        """
        # blocks is of shape (n_blocks, n/2, 2) containing indices for angles.
        # Within each block the Givens rotations commute, so we can apply them in parallel.
        U = torch.eye(n, device=angles.device, dtype=angles.dtype)
        block_size = n // 2
        idx_block = torch.arange(block_size, device=angles.device)
        for b, block in enumerate(blocks):
            # angles is of shape (n_angles,) containing all angles for contiguous blocks.
            angles_in_block = angles[idx_block + b * blocks.size(1)]  # shape (n/2,)
            c = torch.cos(angles_in_block)
            s = torch.sin(angles_in_block)
            i_idx = block[:, 0]
            j_idx = block[:, 1]
            r_i = c.unsqueeze(0) * U[:, i_idx] + s.unsqueeze(0) * U[:, j_idx]
            r_j = -s.unsqueeze(0) * U[:, i_idx] + c.unsqueeze(0) * U[:, j_idx]
            U[:, i_idx] = r_i
            U[:, j_idx] = r_j
        ctx.save_for_backward(angles, blocks, U)
        return U

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
        """Computes the VJP needed for backward propagation.

        Args:
            ctx (context): Contains information for backward propagation.
            grad_output (torch.Tensor): A tensor containing the partial derivatives of the loss
                with respect to the output of the forward pass, i.e., dL/dU.

        Returns:
            tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the input
                angles. No gradients with respect to blocks or n are needed (cf. the forward
                method), so None is returned for these.
        """
        angles, blocks, Ufwd_saved = ctx.saved_tensors
        Ufwd = Ufwd_saved.clone()
        # Clone so the in-place column updates below cannot mutate grad_output,
        # which a backward pass must never modify.
        M = grad_output.t().clone()  # grad_output is dL/dU, of shape (n, n)
        n = M.size(1)
        block_size = n // 2
        A = torch.zeros((block_size, n), device=angles.device, dtype=angles.dtype)
        grad_theta = torch.zeros_like(angles)
        idx_block = torch.arange(block_size, device=angles.device)
        for b, block in enumerate(blocks):
            i_idx = block[:, 0]
            j_idx = block[:, 1]
            angles_in_block = angles[idx_block + b * block_size]  # shape (n/2,)
            c = torch.cos(angles_in_block)
            s = torch.sin(angles_in_block)
            r_i = c.unsqueeze(1) * Ufwd[i_idx] + s.unsqueeze(1) * Ufwd[j_idx]
            r_j = -s.unsqueeze(1) * Ufwd[i_idx] + c.unsqueeze(1) * Ufwd[j_idx]
            Ufwd[i_idx] = r_i
            Ufwd[j_idx] = r_j
            r_i = c.unsqueeze(0) * M[:, i_idx] + s.unsqueeze(0) * M[:, j_idx]
            r_j = -s.unsqueeze(0) * M[:, i_idx] + c.unsqueeze(0) * M[:, j_idx]
            M[:, i_idx] = r_i
            M[:, j_idx] = r_j
            A[:] = M[:, j_idx].T * Ufwd[i_idx] - M[:, i_idx].T * Ufwd[j_idx]
            grad_theta[idx_block + b * block_size] = A.sum(dim=1)
        return grad_theta, None, None


class GivensRotationLayer(nn.Module):
    """An orthogonal layer implementing a rotation using a sequence of Givens rotations arranged
    in a round-robin fashion.

    Angles are arranged into blocks, where each block references rotations that can be applied in
    parallel because these rotations commute.

    Args:
        n (int): Dimension of the input and output space. Must be at least 2.
        bias (bool): If True, adds a learnable bias to the output. Default: True.
    """

    @store_config
    def __init__(self, n: int, bias: bool = True):
        super().__init__()
        if not isinstance(n, int) or n <= 1:
            raise ValueError(f"n must be an integer greater than 1, {n} was passed")
        if not isinstance(bias, bool):
            raise ValueError(f"bias must be a boolean, {bias} was passed")
        self.n = n
        self.n_angles = n * (n - 1) // 2
        self.angles = nn.Parameter(torch.randn(self.n_angles))
        blocks_edges = _get_blocks_edges(n)
        self.register_buffer(
            "blocks",
            torch.tensor(blocks_edges, dtype=torch.long),
        )
        if bias:
            self.bias = nn.Parameter(torch.zeros(n))
        else:
            self.register_parameter("bias", None)

    def _create_rotation_matrix(self) -> torch.Tensor:
        """Computes the Givens rotation matrix."""
        return _RoundRobinGivens.apply(self.angles, self.blocks, self.n)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the Givens rotation to the input tensor ``x``.

        Args:
            x (torch.Tensor): Input tensor of shape ``(..., n)``.

        Returns:
            torch.Tensor: Rotated tensor of shape ``(..., n)``.
        """
        unitary = self._create_rotation_matrix()
        rotated_x = einsum(x, unitary, "... i, o i -> ... o")
        if self.bias is not None:
            rotated_x += self.bias
        return rotated_x
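
A minimal usage sketch (editorial, not part of this commit; the import path follows the `__init__.py` change above) illustrates the layer's behavior. For `n = 4`, `_get_blocks_edges` packs the 6 rotation pairs into 3 blocks of 2 index-disjoint pairs, and the composed matrix is orthogonal:

import torch

from dwave.plugins.torch.nn.modules.orthogonal import GivensRotationLayer

torch.manual_seed(0)
layer = GivensRotationLayer(n=4, bias=False)

# Round-robin schedule for n = 4: three blocks of two index-disjoint pairs.
print(layer.blocks.tolist())  # [[[0, 3], [1, 2]], [[0, 2], [3, 1]], [[0, 1], [2, 3]]]

# The composed matrix is orthogonal: U @ U.T == I up to floating-point error.
U = layer._create_rotation_matrix()
assert torch.allclose(U @ U.T, torch.eye(4), atol=1e-5)

# Rotations preserve norms, so the layer leaves vector lengths unchanged.
x = torch.randn(8, 4)
assert torch.allclose(layer(x).norm(dim=-1), x.norm(dim=-1), atol=1e-5)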

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ dependencies = [
     "dimod",
     "dwave-system",
     "dwave-hybrid",
+    "einops",
 ]

 [project.readme]

requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -2,6 +2,7 @@ torch==2.9.1
 dimod==0.12.18
 dwave-system==1.28.0
 dwave-hybrid==0.6.13
+einops==0.8.1

 # Development requirements
 reno==4.1.0

tests/helper_models.py

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
import torch
import torch.nn as nn
from einops import einsum


class NaiveGivensRotationLayer(nn.Module):
    """Naive implementation of a Givens rotation layer.

    Sequentially applies all Givens rotations to implement an orthogonal transformation in the
    order given by blocks, which are of shape (n_blocks, n/2, 2). Usually each block contains
    pairs of indices such that no index appears more than once within a block, but this
    implementation does not rely on that assumption, so indices may repeat within a block.
    However, every pair of indices must appear exactly once in the entire blocks tensor.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool): If True, adds a learnable bias to the output. Default: True.

    Note:
        This layer defines an n-by-n SO(n) rotation matrix, so in_features must be equal to
        out_features.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        assert in_features == out_features, (
            "This layer defines an n-by-n SO(n) rotation matrix, so in_features must be equal "
            "to out_features."
        )
        self.n = in_features
        self.angles = nn.Parameter(torch.randn(in_features * (in_features - 1) // 2))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

    def _create_rotation_matrix(self, angles, blocks: torch.Tensor | None = None):
        """Creates the rotation matrix from the Givens angles by applying the Givens rotations
        sequentially, in the order specified by blocks.

        Args:
            angles (torch.Tensor): Givens rotation angles.
            blocks (torch.Tensor | None, optional): Blocks specifying the order of rotations. If
                None, all possible pairs of dimensions will be shaped into (n-1, n/2, 2) to
                create the blocks. Defaults to None.

        Returns:
            torch.Tensor: Rotation matrix.
        """
        block_size = self.n // 2
        if blocks is None:
            # Create dummy blocks from triu indices:
            triu_indices = torch.triu_indices(self.n, self.n, offset=1)
            blocks = triu_indices.t().view(-1, block_size, 2)
        U = torch.eye(self.n, dtype=angles.dtype)
        for b, block in enumerate(blocks):
            for k in range(block_size):
                i = block[k, 0].item()
                j = block[k, 1].item()
                angle = angles[b * block_size + k]
                c = torch.cos(angle)
                s = torch.sin(angle)
                # Explicit Givens rotation:
                Ge = torch.eye(self.n, dtype=angles.dtype)
                Ge[i, i] = c
                Ge[j, j] = c
                Ge[i, j] = -s
                Ge[j, i] = s
                U = U @ Ge
        return U

    def forward(self, x: torch.Tensor, blocks: torch.Tensor) -> torch.Tensor:
        """Applies the Givens rotation to the input tensor ``x``.

        Args:
            x (torch.Tensor): Input tensor of shape (..., n).
            blocks (torch.Tensor): Blocks specifying the order of rotations.

        Returns:
            torch.Tensor: Rotated tensor of shape (..., n).
        """
        W = self._create_rotation_matrix(self.angles, blocks)
        x = einsum(x, W, "... i, o i -> ... o")
        if self.bias is not None:
            x = x + self.bias
        return x
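
A sketch of how this helper might be exercised (hypothetical test code, not shown in this commit view): the naive sequential construction can be checked against the parallel autograd implementation on shared angles and blocks, and the custom backward can be validated against finite differences with `torch.autograd.gradcheck` in double precision. The `helper_models` import path assumes `tests/` is on `sys.path`.

import torch

from dwave.plugins.torch.nn.modules.orthogonal import (
    GivensRotationLayer,
    _RoundRobinGivens,
    _get_blocks_edges,
)
from helper_models import NaiveGivensRotationLayer  # assumes tests/ is on sys.path

n = 6
torch.manual_seed(0)
fast = GivensRotationLayer(n, bias=False)
naive = NaiveGivensRotationLayer(n, n, bias=False)
with torch.no_grad():
    naive.angles.copy_(fast.angles)

# Both constructions should agree when they share angles and blocks,
# since pairs within a round-robin block act on disjoint indices.
U_fast = fast._create_rotation_matrix()
U_naive = naive._create_rotation_matrix(naive.angles, fast.blocks)
assert torch.allclose(U_fast, U_naive, atol=1e-5)

# gradcheck compares the custom VJP against finite differences (float64).
blocks = torch.tensor(_get_blocks_edges(n), dtype=torch.long)
angles = torch.randn(n * (n - 1) // 2, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(
    lambda a: _RoundRobinGivens.apply(a, blocks, n), (angles,)
)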
