We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent fb7353e commit 8644408Copy full SHA for 8644408
utils/models/johan_model.py
@@ -29,10 +29,17 @@ class JohanModel(nn.Module):
29
30
"""
31
32
- def __init__(self, in_features, num_classes):
+ def __init__(self, image_shape, num_classes):
33
super().__init__()
34
35
- self.fc1 = nn.Linear(in_features, 77)
+ # Extract features from image shape
36
+ self.in_channels = image_shape[0]
37
+ self.height = image_shape[1]
38
+ self.width = image_shape[2]
39
+ self.num_classes = num_classes
40
+ self.in_features = self.in_channels * self.height * self.width
41
+
42
+ self.fc1 = nn.Linear(self.in_features, 77)
43
self.fc2 = nn.Linear(77, 77)
44
self.fc3 = nn.Linear(77, 77)
45
self.fc4 = nn.Linear(77, num_classes)
0 commit comments