Skip to content

Commit 8644408

Browse files
committed
updated model according to #31
1 parent fb7353e commit 8644408

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

utils/models/johan_model.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,17 @@ class JohanModel(nn.Module):
2929
3030
"""
3131

32-
def __init__(self, in_features, num_classes):
32+
def __init__(self, image_shape, num_classes):
3333
super().__init__()
3434

35-
self.fc1 = nn.Linear(in_features, 77)
35+
# Extract features from image shape
36+
self.in_channels = image_shape[0]
37+
self.height = image_shape[1]
38+
self.width = image_shape[2]
39+
self.num_classes = num_classes
40+
self.in_features = self.in_channels * self.height * self.width
41+
42+
self.fc1 = nn.Linear(self.in_features, 77)
3643
self.fc2 = nn.Linear(77, 77)
3744
self.fc3 = nn.Linear(77, 77)
3845
self.fc4 = nn.Linear(77, num_classes)

0 commit comments

Comments
 (0)