
Commit 2ff9ad3

Merge pull request #38 from SFI-Visual-Intelligence/Jan/softmax
Removing softmax, closes #37
2 parents: ed0eaf2 + c0e8d9d
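
The diff itself does not record the motivation (issue #37 presumably does), but dropping softmax from forward() is the usual convention when training with PyTorch's nn.CrossEntropyLoss, which applies log-softmax internally and expects raw logits; applying softmax before the loss squashes gradients and slows learning. A minimal sketch of that convention follows (the criterion, shapes, and values are illustrative, not taken from this repository):

import torch
import torch.nn as nn

# nn.CrossEntropyLoss fuses LogSoftmax with NLLLoss, so models should
# return raw logits; no nn.Softmax layer belongs in forward().
criterion = nn.CrossEntropyLoss()

logits = torch.randn(8, 10)           # stand-in for model(x): (batch, num_classes)
targets = torch.randint(0, 10, (8,))  # stand-in integer class labels

loss = criterion(logits, targets)     # softmax happens inside the loss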

4 files changed: +1 -8 lines changed

tests/test_models.py

Lines changed: 1 addition & 3 deletions
@@ -17,9 +17,6 @@ def test_christian_model(image_shape, num_classes):
     y = model(x)

     assert y.shape == (n, num_classes), f"Shape: {y.shape}"
-    assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), (
-        f"Softmax output should sum to 1, but got: {y.sum()}"
-    )


 @pytest.mark.parametrize(
@@ -35,3 +32,4 @@ def test_jan_model(image_shape, num_classes):
     y = model(x)

     assert y.shape == (n, num_classes), f"Shape: {y.shape}"
+
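
With softmax gone, the models return raw logits, so the removed sum-to-1 check no longer applies. If probabilities are wanted at evaluation time, softmax can still be applied outside the model; a minimal sketch, assuming the class scores live on dim=1 as in the removed nn.Softmax(dim=1) layers:

import torch

logits = torch.randn(4, 10)  # stand-in for model(x)

# Apply softmax outside the model only where probabilities are needed.
probs = torch.softmax(logits, dim=1)
assert probs.sum(dim=1).allclose(torch.ones(4), atol=1e-5)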

utils/models/christian_model.py

Lines changed: 0 additions & 2 deletions
@@ -58,15 +58,13 @@ def __init__(self, image_shape, num_classes):
         self.cnn2 = CNNBlock(50, 100)

         self.fc1 = nn.Linear(100 * 4 * 4, num_classes)
-        self.softmax = nn.Softmax(dim=1)

     def forward(self, x):
         x = self.cnn1(x)
         x = self.cnn2(x)

         x = x.view(x.size(0), -1)
         x = self.fc1(x)
-        x = self.softmax(x)

         return x

utils/models/johan_model.py

Lines changed: 0 additions & 2 deletions
@@ -43,14 +43,12 @@ def __init__(self, image_shape, num_classes):
         self.fc2 = nn.Linear(77, 77)
         self.fc3 = nn.Linear(77, 77)
         self.fc4 = nn.Linear(77, num_classes)
-        self.softmax = nn.Softmax(dim=1)
         self.relu = nn.ReLU()

     def forward(self, x):
         for layer in [self.fc1, self.fc2, self.fc3, self.fc4]:
             x = layer(x)
             x = self.relu(x)
-        x = self.softmax(x)
         return x

utils/models/solveig_model.py

Lines changed: 0 additions & 1 deletion
@@ -58,7 +58,6 @@ def forward(self, x):
         x = torch.flatten(x, 1)

         x = self.fc1(x)
-        x = nn.Softmax(x)

         return x
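
One note on the line removed above: x = nn.Softmax(x) was also a latent bug. nn.Softmax(x) constructs a Softmax module, silently accepting the tensor as its dim argument, instead of applying softmax to x, so this forward() rebound x to a module and returned that rather than a tensor. Had softmax been intended here, the functional form applies it directly; a short sketch:

import torch
import torch.nn.functional as F

x = torch.randn(4, 10)

# nn.Softmax(x) would build a module and leave the data untouched;
# F.softmax applies the transform to the tensor itself.
y = F.softmax(x, dim=1)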
