Skip to content

Commit 1a2394a

Browse files
committed
update model input to image_shape
1 parent 38c4139 commit 1a2394a

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

utils/models/solveig_model.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ class SolveigModel(nn.Module):
66
"""
77
A Convolutional Neural Network model for classification.
88
9-
Args:
9+
Args
1010
----
11-
in_channels : int
12-
Number of input channels (e.g., 3 for RGB images, 1 for grayscale).
11+
image_shape : tuple(int, int, int)
12+
Shape of the input image (C, H, W).
1313
num_classes : int
14-
The number of output classes (e.g., 2 for binary classification).
14+
Number of classes in the dataset.
1515
1616
Attributes:
1717
-----------
@@ -25,12 +25,14 @@ class SolveigModel(nn.Module):
2525
Fully connected layer that outputs the final classification scores.
2626
"""
2727

28-
def __init__(self, in_channels, num_classes):
28+
def __init__(self, image_shape, num_classes):
2929
super().__init__()
3030

31+
C, *_ = image_shape
32+
3133
# Define the first convolutional block (conv + relu + maxpool)
3234
self.conv_block1 = nn.Sequential(
33-
nn.Conv2d(in_channels=in_channels, out_channels=25, kernel_size=3, padding=1),
35+
nn.Conv2d(in_channels=C, out_channels=25, kernel_size=3, padding=1),
3436
nn.ReLU(),
3537
nn.MaxPool2d(kernel_size=2, stride=2)
3638
)
@@ -62,9 +64,11 @@ def forward(self, x):
6264

6365

6466
if __name__ == "__main__":
65-
model = SolveigModel(3, 3)
6667

67-
x = torch.randn(1, 3, 16, 16)
68+
x = torch.randn(1,3, 16, 16)
69+
70+
model = SolveigModel(x.shape[1:], 3)
71+
6872
y = model(x)
6973

7074
print(y)

0 commit comments

Comments
 (0)