Skip to content

Commit 97750d8

Browse files
committed
Fixed bug in JohanModel
1 parent 8922263 commit 97750d8

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

utils/models/johan_model.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ class JohanModel(nn.Module):
2626
Numer of input features.
2727
num_classes : int
2828
Number of classes in the dataset.
29-
3029
"""
3130

3231
def __init__(self, image_shape, num_classes):
@@ -44,18 +43,15 @@ def __init__(self, image_shape, num_classes):
4443
self.fc3 = nn.Linear(77, 77)
4544
self.fc4 = nn.Linear(77, num_classes)
4645
self.relu = nn.ReLU()
46+
self.flatten = nn.Flatten()
4747

4848
def forward(self, x):
49-
x = x.flatten()
49+
x = self.flatten(x)
5050
for layer in [self.fc1, self.fc2, self.fc3, self.fc4]:
5151
x = layer(x)
5252
x = self.relu(x)
5353
return x
5454

5555

56-
# TODO
57-
# Add your tests here
58-
59-
6056
if __name__ == "__main__":
61-
pass # Add your tests here
57+
print("This is JohanModel")

0 commit comments

Comments
 (0)