1 change: 1 addition & 0 deletions dwave/plugins/torch/nn/modules/__init__.py
@@ -15,3 +15,4 @@

from dwave.plugins.torch.nn.modules.linear import *
from dwave.plugins.torch.nn.modules.utils import *
from dwave.plugins.torch.nn.modules.conv import *
101 changes: 101 additions & 0 deletions dwave/plugins/torch/nn/modules/conv.py
@@ -0,0 +1,101 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from torch import nn
from dwave.plugins.torch.nn.modules.utils import store_config

__all__ = ["SkipConv2d", "ConvolutionBlock"]

class SkipConv2d(nn.Module):
"""A 2D convolution or the identity depending on whether input/output channels match.

    This module is the identity when ``cin == cout``; otherwise, it applies a 1×1
    convolution without bias to match the channel dimensions. It is used for residual
    (skip) connections, as described in the ResNet architecture.

Args:
cin (int): Number of input channels.
cout (int): Number of output channels.
"""

@store_config
def __init__(self, cin: int, cout: int) -> None:
super().__init__()
if cin == cout:
self.conv = nn.Identity()
else:
self.conv = nn.Conv2d(cin, cout, kernel_size=1, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply the skip connection transformation.

Args:
x (torch.Tensor): Input tensor of shape ``(N, cin, H, W)``.

Returns:
torch.Tensor: Output tensor of shape ``(N, cout, H, W)``.
"""
return self.conv(x)
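
For reviewers, a minimal usage sketch of SkipConv2d (illustrative only, not part of this diff; the shapes are arbitrary example values):

import torch
from dwave.plugins.torch.nn.modules.conv import SkipConv2d

x = torch.randn(4, 3, 16, 16)           # (N, cin, H, W)

# cin != cout: a bias-free 1x1 convolution maps 3 -> 8 channels
skip = SkipConv2d(3, 8)
assert skip(x).shape == (4, 8, 16, 16)

# cin == cout: the module is the identity and returns the input unchanged
identity = SkipConv2d(3, 3)
assert torch.equal(identity(x), x)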

class ConvolutionBlock(nn.Module):
"""A residual convolutional block with normalization, convolutions, and a skip connection.

The block consists of:

1. Layer normalization over the input,
2. a 3×3 convolution,
3. a ReLU activation,
4. a second layer normalization,
5. a second 3×3 convolution, and
6. a skip connection from input to output.

This block preserves spatial resolution and follows the residual learning
principle introduced in the ResNet paper.

Args:
input_shape (tuple[int, int, int]): Input shape ``(channels, height, width)``.
cout (int): Number of output channels.

Raises:
NotImplementedError: If input height and width are not equal.
"""

@store_config
def __init__(self, input_shape: tuple[int, int, int], cout: int) -> None:
super().__init__()

cin, hx, wx = tuple(input_shape)
if hx != wx:
raise NotImplementedError("Only square inputs are currently supported.")

self._block = nn.Sequential(
nn.LayerNorm((cin, hx, wx)),
nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.LayerNorm((cout, hx, wx)),
nn.Conv2d(cout, cout, kernel_size=3, stride=1, padding=1),
)
self._skip = SkipConv2d(cin, cout)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply the convolutional block and skip connection.

Args:
x (torch.Tensor): Input tensor of shape ``(N, cin, H, W)``.

Returns:
torch.Tensor: Output tensor of shape ``(N, cout, H, W)``.
"""
return self._block(x) + self._skip(x)
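
Likewise, a sketch of the block end to end (illustrative only, not part of this diff):

import torch
from dwave.plugins.torch.nn.modules.conv import ConvolutionBlock

block = ConvolutionBlock((3, 16, 16), 8)   # 3 channels in, 8 out, 16x16 input
x = torch.randn(4, 3, 16, 16)              # (N, cin, H, W)
y = block(x)                               # _block(x) + _skip(x)
assert y.shape == (4, 8, 16, 16)           # spatial resolution preserved

Note that because nn.LayerNorm is constructed with the full (channels, height, width) shape, a given block instance only accepts inputs of exactly the spatial size passed at construction.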
4 changes: 4 additions & 0 deletions releasenotes/notes/add-conv-modules-3c01964dc6205555.yaml
@@ -0,0 +1,4 @@
---
features:
- |
Add the ``ConvolutionBlock`` and ``SkipConv2d`` modules.
61 changes: 61 additions & 0 deletions tests/test_nn.py
@@ -4,6 +4,8 @@
from parameterized import parameterized

from dwave.plugins.torch.nn import LinearBlock, SkipLinear, store_config
from dwave.plugins.torch.nn.modules.conv import ConvolutionBlock, SkipConv2d
from tests.helper_functions import model_probably_good


@@ -95,6 +97,65 @@ def test_SkipLinear_identity(self):
self.assertTrue((x == y).all())
self.assertTrue(model_probably_good(model, (dim,), (dim, )))

class TestConv(unittest.TestCase):
"""Tests for convolutional residual modules.

The tests focus on:
1. Output shape correctness, and
    2. Identity behavior of the skip connection when ``cin == cout``.
"""

@parameterized.expand([
((3, 16, 16), 8),
((8, 32, 32), 8),
((16, 64, 64), 32),
])
def test_ConvolutionBlock(self, input_shape, cout):
model = ConvolutionBlock(input_shape, cout)
self.assertTrue(
model_probably_good(
model,
input_shape,
(cout, input_shape[1], input_shape[2]),
)
)

def test_SkipConv2d_different_channels(self):
cin = 5
cout = 13
h = w = 17
model = SkipConv2d(cin, cout)
self.assertTrue(
model_probably_good(
model,
(cin, h, w),
(cout, h, w),
)
)

def test_SkipConv2d_identity(self):
# SkipConv2d behaves as identity when cin == cout
c = 11
h = w = 23
model = SkipConv2d(c, c)

x = torch.randn((c, h, w))
y = model(x)

self.assertTrue(torch.equal(x, y))
self.assertTrue(
model_probably_good(
model,
(c, h, w),
(c, h, w),
)
)

def test_ConvolutionBlock_non_square_raises(self):
# The block explicitly does not support non-square inputs
with self.assertRaises(NotImplementedError):
ConvolutionBlock((3, 16, 32), 8)


if __name__ == "__main__":
unittest.main()