Skip to content

Commit dcd52a4

Browse files
committed
MagnusModel fix
1 parent f0e9803 commit dcd52a4

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ wandb_api.py
1111
#Magnus specific
1212
docker/*
1313
job*
14-
14+
env2/*
1515
# Byte-compiled / optimized / DLL files
1616
__pycache__/
1717
*.py[cod]
@@ -150,6 +150,7 @@ ENV/
150150
env.bak/
151151
venv.bak/
152152

153+
153154
# Spyder project settings
154155
.spyderproject
155156
.spyproject

utils/models/magnus_model.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, image_shape, num_classes: int, nr_channels: int):
1717
MagnusModel (nn.Module): An instance of the MagnusModel neural network.
1818
"""
1919
super().__init__()
20-
_, H, W = image_shape
20+
*_, H, W = image_shape
2121

2222
self.layer1 = nn.Sequential(
2323
*(
@@ -55,12 +55,12 @@ def forward(self, x):
5555
if __name__ == "__main__":
5656
import torch as th
5757

58-
data_shape = [28, 28]
58+
image_shape = (3, 28, 28)
59+
n, c, h, w = 5, *image_shape
60+
model = MagnusModel([h, w], 10, c)
5961

60-
data_shape = (3, *data_shape)
61-
model = MagnusModel(data_shape, 10)
62-
63-
dummy_img = th.rand((5, *data_shape))
64-
print(dummy_img.size())
62+
x = th.rand((n, c, h, w))
6563
with th.no_grad():
66-
print(model(dummy_img).size())
64+
y = model(x)
65+
66+
assert y.shape == (n, 10), f"Shape: {y.shape}"

0 commit comments

Comments
 (0)