diff --git a/dwave/plugins/torch/nn/modules/__init__.py b/dwave/plugins/torch/nn/modules/__init__.py
index 598a19c..b628264 100755
--- a/dwave/plugins/torch/nn/modules/__init__.py
+++ b/dwave/plugins/torch/nn/modules/__init__.py
@@ -15,3 +15,4 @@
 from dwave.plugins.torch.nn.modules.linear import *
 from dwave.plugins.torch.nn.modules.utils import *
+from dwave.plugins.torch.nn.modules.conv import *
 
diff --git a/dwave/plugins/torch/nn/modules/conv.py b/dwave/plugins/torch/nn/modules/conv.py
new file mode 100755
index 0000000..c1f907a
--- /dev/null
+++ b/dwave/plugins/torch/nn/modules/conv.py
@@ -0,0 +1,101 @@
+# Copyright 2025 D-Wave
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import torch
+from torch import nn
+from dwave.plugins.torch.nn.modules.utils import store_config
+
+__all__ = ["SkipConv2d", "ConvolutionBlock"]
+
+
+class SkipConv2d(nn.Module):
+    """A 2D convolution or the identity, depending on whether the channel counts match.
+
+    This module is the identity when ``cin == cout``; otherwise it applies a bias-free
+    1×1 convolution to match channels, as in ResNet-style residual (skip) connections.
+
+    Args:
+        cin (int): Number of input channels.
+        cout (int): Number of output channels.
+    """
+
+    @store_config
+    def __init__(self, cin: int, cout: int) -> None:
+        super().__init__()
+        if cin == cout:
+            self.conv = nn.Identity()
+        else:
+            self.conv = nn.Conv2d(cin, cout, kernel_size=1, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply the skip connection transformation.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape ``(N, cin, H, W)``.
+
+        Returns:
+            torch.Tensor: Output tensor of shape ``(N, cout, H, W)``.
+        """
+        return self.conv(x)
+
+
+class ConvolutionBlock(nn.Module):
+    """A residual convolutional block with normalization, convolutions, and a skip connection.
+
+    The block consists of:
+
+    1. layer normalization over the input,
+    2. a 3×3 convolution,
+    3. a ReLU activation,
+    4. a second layer normalization,
+    5. a second 3×3 convolution, and
+    6. a skip connection from input to output.
+
+    This block preserves spatial resolution and follows the residual learning
+    principle introduced in the ResNet paper.
+
+    Args:
+        input_shape (tuple[int, int, int]): Input shape ``(channels, height, width)``.
+        cout (int): Number of output channels.
+
+    Raises:
+        NotImplementedError: If input height and width are not equal.
+ """ + + @store_config + def __init__(self, input_shape: tuple[int, int, int], cout: int) -> None: + super().__init__() + + cin, hx, wx = tuple(input_shape) + if hx != wx: + raise NotImplementedError("Only square inputs are currently supported.") + + self._block = nn.Sequential( + nn.LayerNorm((cin, hx, wx)), + nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.LayerNorm((cout, hx, wx)), + nn.Conv2d(cout, cout, kernel_size=3, stride=1, padding=1), + ) + self._skip = SkipConv2d(cin, cout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the convolutional block and skip connection. + + Args: + x (torch.Tensor): Input tensor of shape ``(N, cin, H, W)``. + + Returns: + torch.Tensor: Output tensor of shape ``(N, cout, H, W)``. + """ + return self._block(x) + self._skip(x) \ No newline at end of file diff --git a/releasenotes/notes/add-conv-modules-3c01964dc6205555.yaml b/releasenotes/notes/add-conv-modules-3c01964dc6205555.yaml new file mode 100644 index 0000000..28ae02f --- /dev/null +++ b/releasenotes/notes/add-conv-modules-3c01964dc6205555.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add ``ConvolutionBlock`` and ``SkipConv2d`` Modules. diff --git a/tests/test_nn.py b/tests/test_nn.py index c84929d..aef3e42 100755 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -4,6 +4,8 @@ from parameterized import parameterized from dwave.plugins.torch.nn import LinearBlock, SkipLinear, store_config +from dwave.plugins.torch.nn.modules.utils import store_config +from dwave.plugins.torch.nn.modules.conv import ConvolutionBlock, SkipConv2d from tests.helper_functions import model_probably_good @@ -95,6 +97,65 @@ def test_SkipLinear_identity(self): self.assertTrue((x == y).all()) self.assertTrue(model_probably_good(model, (dim,), (dim, ))) +class TestConv(unittest.TestCase): + """Tests for convolutional residual modules. + + The tests focus on: + 1. Output shape correctness, and + 2. Identity behavior of skip connections when possible. + """ + + @parameterized.expand([ + ((3, 16, 16), 8), + ((8, 32, 32), 8), + ((16, 64, 64), 32), + ]) + def test_ConvolutionBlock(self, input_shape, cout): + model = ConvolutionBlock(input_shape, cout) + self.assertTrue( + model_probably_good( + model, + input_shape, + (cout, input_shape[1], input_shape[2]), + ) + ) + + def test_SkipConv2d_different_channels(self): + cin = 5 + cout = 13 + h = w = 17 + model = SkipConv2d(cin, cout) + self.assertTrue( + model_probably_good( + model, + (cin, h, w), + (cout, h, w), + ) + ) + + def test_SkipConv2d_identity(self): + # SkipConv2d behaves as identity when cin == cout + c = 11 + h = w = 23 + model = SkipConv2d(c, c) + + x = torch.randn((c, h, w)) + y = model(x) + + self.assertTrue(torch.equal(x, y)) + self.assertTrue( + model_probably_good( + model, + (c, h, w), + (c, h, w), + ) + ) + + def test_ConvolutionBlock_non_square_raises(self): + # The block explicitly does not support non-square inputs + with self.assertRaises(NotImplementedError): + ConvolutionBlock((3, 16, 32), 8) + if __name__ == "__main__": unittest.main()