Skip to content

Commit 5ae4d22

Browse files
committed
Format after conflicts resolved
1 parent b42b139 commit 5ae4d22

File tree

6 files changed

+48
-48
lines changed

6 files changed

+48
-48
lines changed

main.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,16 @@ def main():
152152
)
153153
model.to(device)
154154

155-
trainloader = DataLoader(traindata,
156-
batch_size=args.batchsize,
157-
shuffle=True,
158-
pin_memory=True,
159-
drop_last=True)
160-
valiloader = DataLoader(validata,
161-
batch_size=args.batchsize,
162-
shuffle=False,
163-
pin_memory=True)
155+
trainloader = DataLoader(
156+
traindata,
157+
batch_size=args.batchsize,
158+
shuffle=True,
159+
pin_memory=True,
160+
drop_last=True,
161+
)
162+
valiloader = DataLoader(
163+
validata, batch_size=args.batchsize, shuffle=False, pin_memory=True
164+
)
164165

165166
criterion = nn.CrossEntropyLoss()
166167
optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
@@ -170,12 +171,10 @@ def main():
170171
print("Dry run completed")
171172
exit(0)
172173

173-
wandb.init(project='',
174-
tags=[])
174+
wandb.init(project="", tags=[])
175175
wandb.watch(model)
176176

177177
for epoch in range(args.epoch):
178-
179178
# Training loop start
180179
trainingloss = []
181180
model.train()
@@ -200,12 +199,14 @@ def main():
200199
loss = criterion(y, pred)
201200
evalloss.append(loss.item())
202201

203-
wandb.log({
204-
'Epoch': epoch,
205-
'Train loss': np.mean(trainingloss),
206-
'Evaluation Loss': np.mean(evalloss)
207-
})
202+
wandb.log(
203+
{
204+
"Epoch": epoch,
205+
"Train loss": np.mean(trainingloss),
206+
"Evaluation Loss": np.mean(evalloss),
207+
}
208+
)
208209

209210

210-
if __name__ == '__main__':
211+
if __name__ == "__main__":
211212
main()

tests/test_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from utils.metrics import Recall, F1Score
1+
from utils.metrics import F1Score, Recall
22

33

44
def test_recall():

utils/dataloaders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
from .mnist_0_3 import MNISTDataset0_3
44
from .usps_0_6 import USPSDataset0_6
5-
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
5+
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset

utils/load_data.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from torch.utils.data import Dataset
22

3-
from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset
3+
from .dataloaders import (MNISTDataset0_3, USPSDataset0_6,
4+
USPSH5_Digit_7_9_Dataset)
45

56

67
def load_data(dataset: str, *args, **kwargs) -> Dataset:
@@ -10,6 +11,6 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
1011
case "mnist_0-3":
1112
return MNISTDataset0_3(*args, **kwargs)
1213
case "usps_7-9":
13-
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
14+
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
1415
case _:
1516
raise ValueError(f"Dataset: {dataset} not implemented.")

utils/metrics/F1.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,3 @@ def compute(self):
8484
)
8585

8686
return f1_score
87-

utils/models/solveig_model.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,26 @@
44

55
class SolveigModel(nn.Module):
66
"""
7-
A Convolutional Neural Network model for classification.
8-
9-
Args
10-
----
11-
image_shape : tuple(int, int, int)
12-
Shape of the input image (C, H, W).
13-
num_classes : int
14-
Number of classes in the dataset.
15-
16-
Attributes:
17-
-----------
18-
conv_block1 : nn.Sequential
19-
First convolutional block containing a convolutional layer, ReLU activation, and max-pooling.
20-
conv_block2 : nn.Sequential
21-
Second convolutional block containing a convolutional layer and ReLU activation.
22-
conv_block3 : nn.Sequential
23-
Third convolutional block containing a convolutional layer and ReLU activation.
24-
fc1 : nn.Linear
25-
Fully connected layer that outputs the final classification scores.
26-
"""
7+
A Convolutional Neural Network model for classification.
8+
9+
Args
10+
----
11+
image_shape : tuple(int, int, int)
12+
Shape of the input image (C, H, W).
13+
num_classes : int
14+
Number of classes in the dataset.
15+
16+
Attributes:
17+
-----------
18+
conv_block1 : nn.Sequential
19+
First convolutional block containing a convolutional layer, ReLU activation, and max-pooling.
20+
conv_block2 : nn.Sequential
21+
Second convolutional block containing a convolutional layer and ReLU activation.
22+
conv_block3 : nn.Sequential
23+
Third convolutional block containing a convolutional layer and ReLU activation.
24+
fc1 : nn.Linear
25+
Fully connected layer that outputs the final classification scores.
26+
"""
2727

2828
def __init__(self, image_shape, num_classes):
2929
super().__init__()
@@ -34,19 +34,19 @@ def __init__(self, image_shape, num_classes):
3434
self.conv_block1 = nn.Sequential(
3535
nn.Conv2d(in_channels=C, out_channels=25, kernel_size=3, padding=1),
3636
nn.ReLU(),
37-
nn.MaxPool2d(kernel_size=2, stride=2)
37+
nn.MaxPool2d(kernel_size=2, stride=2),
3838
)
3939

4040
# Define the second convolutional block (conv + relu)
4141
self.conv_block2 = nn.Sequential(
4242
nn.Conv2d(in_channels=25, out_channels=50, kernel_size=3, padding=1),
43-
nn.ReLU()
43+
nn.ReLU(),
4444
)
4545

4646
# Define the third convolutional block (conv + relu)
4747
self.conv_block3 = nn.Sequential(
4848
nn.Conv2d(in_channels=50, out_channels=100, kernel_size=3, padding=1),
49-
nn.ReLU()
49+
nn.ReLU(),
5050
)
5151

5252
self.fc1 = nn.Linear(100 * 8 * 8, num_classes)
@@ -64,8 +64,7 @@ def forward(self, x):
6464

6565

6666
if __name__ == "__main__":
67-
68-
x = torch.randn(1,3, 16, 16)
67+
x = torch.randn(1, 3, 16, 16)
6968

7069
model = SolveigModel(x.shape[1:], 3)
7170

0 commit comments

Comments
 (0)