Skip to content

Commit 3fbdf77

Browse files
author
Vladimir Vargas Calderón
committed
Add Givens orthogonal layer
Add Givens orthogonal layer
1 parent eb65b87 commit 3fbdf77

File tree

6 files changed

+357
-20
lines changed

6 files changed

+357
-20
lines changed

dwave/plugins/torch/nn/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
#
1515

1616
from dwave.plugins.torch.nn.modules.linear import *
17+
from dwave.plugins.torch.nn.modules.orthogonal import *
1718
from dwave.plugins.torch.nn.modules.utils import *
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright 2025 D-Wave
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from collections import deque
16+
17+
import torch
18+
import torch.nn as nn
19+
from einops import einsum
20+
21+
__all__ = ["get_blocks_edges", "GivensRotationLayer"]
22+
23+
24+
def get_blocks_edges(n: int) -> list[list[tuple[int, int]]]:
25+
"""
26+
Uses the circle method for Round Robin pairing to create blocks of edges for parallel Givens
27+
rotations.
28+
29+
A block is a list of pairs of indices indicating which coordinates to rotate together. Pairs
30+
in the same block can be rotated in parallel since they commute.
31+
32+
Args:
33+
n (int): Dimension of the vector space onto which an orthogonal layer will be built.
34+
35+
Returns:
36+
list[list[tuple[int, int]]]: Blocks of edges for parallel Givens rotations.
37+
"""
38+
39+
assert n % 2 == 0, "n must be even" # TODO: discuss odd case with Firas
40+
41+
def circle_method(sequence):
42+
seq_first_half = sequence[: len(sequence) // 2]
43+
seq_second_half = sequence[len(sequence) // 2 :][::-1]
44+
return list(zip(seq_first_half, seq_second_half))
45+
46+
blocks = []
47+
sequence = list(range(n))
48+
seqdeque = deque(sequence[1:])
49+
for _ in range(n - 1):
50+
blocks.append(circle_method(sequence))
51+
seqdeque.rotate(1)
52+
sequence[1:] = list(seqdeque)
53+
return blocks
54+
55+
56+
class RoundRobinGivens(torch.autograd.Function):
57+
"""
58+
Implements custom forward and backward passes to implement the parallel algorithms in
59+
https://arxiv.org/abs/2106.00003
60+
"""
61+
62+
@staticmethod
63+
def forward(ctx, angles: torch.Tensor, blocks: torch.Tensor, n: int) -> torch.Tensor:
64+
"""
65+
Creates a rotation matrix in n dimensions using parallel Givens transformations by blocks.
66+
67+
Args:
68+
ctx (context): Stores information for backward propagation.
69+
angles (torch.Tensor): A ((n - 1) * n // 2,) shaped tensor containing all rotations
70+
between pairs of dimensions.
71+
blocks (torch.Tensor): A (n-1, n//2, 2) shaped tensor containing the indices that
72+
specify rotations between pairs of dimensions. Each of the n-1 blocks contains n//2
73+
pairs of independent rotations.
74+
n (int): Dimension of the space.
75+
76+
Returns:
77+
torch.Tensor: The nxn rotation matrix.
78+
"""
79+
# Blocks is of shape (n_blocks, n/2, 2) containing indices for angles
80+
# Within each block, each Givens rotation is commuting, so we can apply them in parallel
81+
U = torch.eye(n, device=angles.device)
82+
block_size = n // 2
83+
idx_block = torch.arange(block_size, device=angles.device)
84+
for b, block in enumerate(blocks):
85+
# angles is of shape (n_angles,) containing all angles for contiguous blocks.
86+
angles_in_block = angles[idx_block + b * blocks.size(1)] # shape (n/2,)
87+
c = torch.cos(angles_in_block)
88+
s = torch.sin(angles_in_block)
89+
i_idx = block[:, 0]
90+
j_idx = block[:, 1]
91+
r_i = c.unsqueeze(0) * U[:, i_idx] + s.unsqueeze(0) * U[:, j_idx]
92+
r_j = -s.unsqueeze(0) * U[:, i_idx] + c.unsqueeze(0) * U[:, j_idx]
93+
U[:, i_idx] = r_i
94+
U[:, j_idx] = r_j
95+
ctx.save_for_backward(angles, blocks, U)
96+
return U
97+
98+
@staticmethod
99+
def backward(ctx, grad_output: torch.Tensor):
100+
"""
101+
Computes the VJP needed for backward propagation.
102+
103+
Args:
104+
ctx (context): Contains information for backward propagation.
105+
grad_output (torch.Tensor): A tensor containing the partial derivatives for the loss
106+
with respect to the output of the forward pass, i.e., dL/dU.
107+
108+
Returns:
109+
torch.Tensor: The gradient of the loss with respect to the input angles.
110+
"""
111+
angles, blocks, Ufwd_saved = ctx.saved_tensors
112+
Ufwd = Ufwd_saved.clone()
113+
M = grad_output.t() # dL/dU, i.e., grad_output is of shape (n, n)
114+
n = M.size(1)
115+
block_size = n // 2
116+
A = torch.zeros((block_size, n), device=grad_output.device)
117+
grad_theta = torch.zeros_like(angles)
118+
idx_block = torch.arange(block_size, device=grad_output.device)
119+
for b, block in enumerate(blocks):
120+
i_idx = block[:, 0]
121+
j_idx = block[:, 1]
122+
angles_in_block = angles[idx_block + b * block_size] # shape (n/2,)
123+
c = torch.cos(angles_in_block)
124+
s = torch.sin(angles_in_block)
125+
r_i = c.unsqueeze(1) * Ufwd[i_idx] + s.unsqueeze(1) * Ufwd[j_idx]
126+
r_j = -s.unsqueeze(1) * Ufwd[i_idx] + c.unsqueeze(1) * Ufwd[j_idx]
127+
Ufwd[i_idx] = r_i
128+
Ufwd[j_idx] = r_j
129+
r_i = c.unsqueeze(0) * M[:, i_idx] + s.unsqueeze(0) * M[:, j_idx]
130+
r_j = -s.unsqueeze(0) * M[:, i_idx] + c.unsqueeze(0) * M[:, j_idx]
131+
M[:, i_idx] = r_i
132+
M[:, j_idx] = r_j
133+
A[:] = M[:, j_idx].T * Ufwd[i_idx] - M[:, i_idx].T * Ufwd[j_idx]
134+
grad_theta[idx_block + b * block_size] = A.sum(dim=1)
135+
return grad_theta, None, None
136+
137+
138+
class GivensRotationLayer(nn.Module):
139+
"""
140+
An orthogonal layer implementing a rotation using a sequence of Givens rotations arranged in a
141+
round-robin fashion.
142+
143+
Angles are arranged into blocks, where each block references rotations that can be applied in
144+
parallel because these rotations commute.
145+
146+
Args:
147+
n (int): Dimension of the input and output space.
148+
bias (bool): If True, adds a learnable bias to the output. Default: True.
149+
"""
150+
151+
def __init__(self, n: int, bias: bool = True):
152+
super().__init__()
153+
assert n % 2 == 0, "n must be even" # TODO: discuss odd case with Firas
154+
self.n = n
155+
self.n_angles = n * (n - 1) // 2
156+
self.angles = nn.Parameter(torch.randn(self.n_angles))
157+
blocks_edges = get_blocks_edges(n)
158+
self.register_buffer(
159+
"blocks",
160+
torch.tensor(blocks_edges, dtype=torch.long),
161+
)
162+
if bias:
163+
self.bias = nn.Parameter(torch.zeros(n))
164+
else:
165+
self.register_parameter("bias", None)
166+
167+
def _create_rotation_matrix(self) -> torch.Tensor:
168+
"""
169+
Computes the Givens rotation matrix.
170+
"""
171+
U = RoundRobinGivens.apply(self.angles, self.blocks, self.n)
172+
return U
173+
174+
def forward(self, x: torch.Tensor) -> torch.Tensor:
175+
"""
176+
Applies the Givens rotation to the input tensor ``x``.
177+
178+
Args:
179+
x (torch.Tensor): Input tensor of shape (..., n).
180+
181+
Returns:
182+
torch.Tensor: Rotated tensor of shape (..., n).
183+
"""
184+
U = self._create_rotation_matrix()
185+
rotated_x = einsum(x, U, "... i, o i -> ... o")
186+
if self.bias is not None:
187+
rotated_x = rotated_x + self.bias
188+
return rotated_x

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ dependencies = [
3232
"dimod",
3333
"dwave-system",
3434
"dwave-hybrid",
35+
"einops",
3536
]
3637

3738
[project.readme]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ torch==2.9.1
22
dimod==0.12.18
33
dwave-system==1.28.0
44
dwave-hybrid==0.6.13
5+
einops==0.8.1
56

67
# Development requirements
78
reno==4.1.0

tests/helper_models.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import torch
2+
import torch.nn as nn
3+
from einops import einsum
4+
5+
6+
class NaiveGivensRotationLayer(nn.Module):
7+
"""
8+
Naive implementation of a Givens rotation layer.
9+
10+
Sequentially applies all Givens rotations to implement an orthogonal transformation in an order
11+
provided by blocks, which are of shape (n_blocks, n/2, 2), and where usually each block contains
12+
pairs of indices such that no index appears more than once in a block. However, this
13+
implementation does not rely on that assumption, so that indeces can appear multiple times in a
14+
block; however, all pairs of indices must appear exactly once in the entire blocks tensor.
15+
16+
Args:
17+
in_features (int): Number of input features.
18+
out_features (int): Number of output features.
19+
bias (bool): If True, adds a learnable bias to the output. Default: True.
20+
21+
Note:
22+
This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to
23+
out_features.
24+
"""
25+
26+
def __init__(self, in_features: int, out_features: int, bias: bool = True):
27+
super().__init__()
28+
assert in_features == out_features, (
29+
"This layer defines an nxn SO(n) rotation matrix, so in_features must be equal to "
30+
"out_features."
31+
)
32+
self.n = in_features
33+
self.angles = nn.Parameter(torch.randn(in_features * (in_features - 1) // 2))
34+
if bias:
35+
self.bias = nn.Parameter(torch.zeros(out_features))
36+
else:
37+
self.register_parameter("bias", None)
38+
39+
def _create_rotation_matrix(self, angles, blocks: torch.Tensor | None = None):
40+
"""
41+
Creates the rotation matrix from the Givens angles by applying the Givens rotations in order
42+
and sequentially, as specified by blocks.
43+
44+
Args:
45+
angles (torch.Tensor): Givens rotation angles.
46+
blocks (torch.Tensor | None, optional): Blocks specifying the order of rotations. If
47+
None, all possible pairs of dimensions will be shaped into (n-1, n/2, 2) to create
48+
the blocks. Defaults to None.
49+
50+
Returns:
51+
torch.Tensor: Rotation matrix.
52+
"""
53+
block_size = self.n // 2
54+
if blocks is None:
55+
# Create dummy blocks from triu indices:
56+
triu_indices = torch.triu_indices(self.n, self.n, offset=1)
57+
blocks = triu_indices.t().view(-1, block_size, 2)
58+
U = torch.eye(self.n)
59+
for b, block in enumerate(blocks):
60+
for k in range(block_size):
61+
i = block[k, 0].item()
62+
j = block[k, 1].item()
63+
angle = angles[b * block_size + k]
64+
c = torch.cos(angle)
65+
s = torch.sin(angle)
66+
# Need to clone because of pytorch. (This wouldn't happen in JAX)
67+
r_i = c * U[:, i].clone() + s * U[:, j].clone()
68+
r_j = -s * U[:, i].clone() + c * U[:, j].clone()
69+
U[:, i] = r_i
70+
U[:, j] = r_j
71+
return U
72+
73+
def forward(self, x: torch.Tensor, blocks: torch.Tensor) -> torch.Tensor:
74+
"""
75+
Applies the Givens rotation to the input tensor ``x``.
76+
77+
Args:
78+
x (torch.Tensor): Input tensor of shape (..., n).
79+
blocks (torch.Tensor): Blocks specifying the order of rotations.
80+
81+
Returns:
82+
torch.Tensor: Rotated tensor of shape (..., n).
83+
"""
84+
W = self._create_rotation_matrix(self.angles, blocks)
85+
x = einsum(x, W, "... i, o i -> ... o")
86+
if self.bias is not None:
87+
x = x + self.bias
88+
return x

0 commit comments

Comments
 (0)