Skip to content

Commit f4e5591

Browse files
committed
Update model to input image_shape rather than in_channels
1 parent 9ad01e4 commit f4e5591

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

utils/models/christian_model.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ class ChristianModel(nn.Module):
2727
2828
Args
2929
----
30-
in_channels : int
31-
Number of input channels.
30+
image_shape : tuple(int, int, int)
31+
Shape of the input image (C, H, W).
3232
num_classes : int
3333
Number of classes in the dataset.
3434
@@ -49,10 +49,12 @@ class ChristianModel(nn.Module):
4949
FC Output Shape: (5, num_classes)
5050
"""
5151

52-
def __init__(self, in_channels, num_classes):
52+
def __init__(self, image_shape, num_classes):
5353
super().__init__()
5454

55-
self.cnn1 = CNNBlock(in_channels, 50)
55+
C, *_ = image_shape
56+
57+
self.cnn1 = CNNBlock(C, 50)
5658
self.cnn2 = CNNBlock(50, 100)
5759

5860
self.fc1 = nn.Linear(100 * 4 * 4, num_classes)

0 commit comments

Comments
 (0)