Skip to content

Commit 53d23e3

Browse files
Auto-format: Applied ruff format and isort
1 parent 8693d41 commit 53d23e3

File tree

10 files changed

+78
-58
lines changed

10 files changed

+78
-58
lines changed

utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1+
from .createfolders import createfolders
2+
from .load_data import load_data
13
from .load_metric import MetricWrapper
24
from .load_model import load_model
3-
from .load_data import load_data
4-
from .createfolders import createfolders

utils/createfolders.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,56 @@
1-
import os
2-
from tempfile import TemporaryDirectory
31
import argparse
2+
import os
3+
from tempfile import TemporaryDirectory
4+
5+
6+
def createfolders(args) -> None:
7+
"""
8+
Creates folders for storing data, results, model weights.
49
5-
def createfolders(args) -> None:
6-
'''
7-
Creates folders for storing data, results, model weights.
8-
910
Parameters
1011
----------
1112
args
1213
ArgParse object containing string paths to be created
13-
14-
'''
15-
14+
15+
"""
16+
1617
if not os.path.exists(args.datafolder):
1718
os.makedirs(args.datafolder)
18-
print(f'Created a folder at {args.datafolder}')
19-
19+
print(f"Created a folder at {args.datafolder}")
20+
2021
if not os.path.exists(args.resultfolder):
2122
os.makedirs(args.resultfolder)
22-
print(f'Created a folder at {args.resultfolder}')
23-
23+
print(f"Created a folder at {args.resultfolder}")
24+
2425
if not os.path.exists(args.modelfolder):
2526
os.makedirs(args.modelfolder)
26-
print(f'Created a folder at {args.modelfolder}')
27-
27+
print(f"Created a folder at {args.modelfolder}")
2828

2929

3030
def test_createfolders():
31-
with TemporaryDirectory(dir = 'tmp/') as temp_dir:
31+
with TemporaryDirectory(dir="tmp/") as temp_dir:
3232
parser = argparse.ArgumentParser()
33-
#Structuture related values
34-
parser.add_argument('--datafolder', type=str, default=os.path.join(temp_dir, 'Data/'), help='Path to where data will be saved during training.')
35-
parser.add_argument('--resultfolder', type=str, default=os.path.join(temp_dir, 'Results/'), help='Path to where results will be saved during evaluation.')
36-
parser.add_argument('--modelfolder', type=str, default=os.path.join(temp_dir, 'Experiments/'), help='Path to where model weights will be saved at the end of training.')
37-
33+
# Structuture related values
34+
parser.add_argument(
35+
"--datafolder",
36+
type=str,
37+
default=os.path.join(temp_dir, "Data/"),
38+
help="Path to where data will be saved during training.",
39+
)
40+
parser.add_argument(
41+
"--resultfolder",
42+
type=str,
43+
default=os.path.join(temp_dir, "Results/"),
44+
help="Path to where results will be saved during evaluation.",
45+
)
46+
parser.add_argument(
47+
"--modelfolder",
48+
type=str,
49+
default=os.path.join(temp_dir, "Experiments/"),
50+
help="Path to where model weights will be saved at the end of training.",
51+
)
52+
3853
args = parser.parse_args()
3954
createfolders(args)
40-
41-
return
55+
56+
return

utils/dataloaders/svhn.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from torch.utils.data import Dataset
22

3+
34
class SVHN(Dataset):
45
def __init__(self):
56
super().__init__()
6-
7+
78
def __len__(self):
8-
return
9-
9+
return
10+
1011
def __getitem__(self, index):
11-
return
12+
return

utils/load_data.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from torch.utils.data import Dataset
22

3-
def load_data(dataset:str) -> Dataset:
4-
5-
raise ValueError(f'Dataset: {dataset} not implemented. \nCheck the documentation for implemented metrics, or check your spelling')
3+
4+
def load_data(dataset: str) -> Dataset:
5+
raise ValueError(
6+
f"Dataset: {dataset} not implemented. \nCheck the documentation for implemented metrics, or check your spelling"
7+
)

utils/load_metric.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import copy
2-
import numpy as np
3-
import torch.nn as nn
1+
import copy
2+
3+
import numpy as np
4+
import torch.nn as nn
45
from metrics import EntropyPrediction
56

67

@@ -11,12 +12,11 @@ def __init__(self, *metrics):
1112

1213
for metric in metrics:
1314
self.metrics[metric] = self._get_metric(metric)
14-
15+
1516
self.tmp_scores = copy.deepcopy(self.metrics)
1617
for key in self.tmp_scores:
1718
self.tmp_scores[key] = []
1819

19-
2020
def _get_metric(self, key):
2121
"""
2222
Get the metric function based on the key
@@ -31,28 +31,28 @@ def _get_metric(self, key):
3131
"""
3232

3333
match key.lower():
34-
case 'entropy':
34+
case "entropy":
3535
return EntropyPrediction()
36-
case 'f1':
36+
case "f1":
3737
raise NotImplementedError("F1 score not implemented yet")
38-
case 'recall':
38+
case "recall":
3939
raise NotImplementedError("Recall score not implemented yet")
40-
case 'precision':
40+
case "precision":
4141
raise NotImplementedError("Precision score not implemented yet")
42-
case 'accuracy':
42+
case "accuracy":
4343
raise NotImplementedError("Accuracy score not implemented yet")
4444
case _:
4545
raise ValueError(f"Metric {key} not supported")
4646

4747
def __call__(self, y_true, y_pred):
4848
for key in self.metrics:
4949
self.tmp_scores[key].append(self.metrics[key](y_true, y_pred))
50-
50+
5151
def __getmetrics__(self):
5252
return_metrics = {}
5353
for key in self.metrics:
5454
return_metrics[key] = np.mean(self.tmp_scores[key])
55-
55+
5656
return return_metrics
5757

5858
def __resetvalues__(self):

utils/load_model.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
import torch.nn as nn
1+
import torch.nn as nn
22
from models import MagnusModel
33

4-
def load_model(modelname:str) -> nn.Module:
5-
6-
if modelname == 'MagnusModel':
4+
5+
def load_model(modelname: str) -> nn.Module:
6+
if modelname == "MagnusModel":
77
return MagnusModel()
88
else:
9-
raise ValueError(f'Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling')
9+
raise ValueError(
10+
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
11+
)

utils/metrics/EntropyPred.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import torch.nn as nn
1+
import torch.nn as nn
22

33

44
class EntropyPrediction(nn.Module):
55
def __init__(self):
66
super().__init__()
7-
7+
88
def __call__(self, y_true, y_false):
9-
10-
return
9+
return

utils/metrics/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .EntropyPred import EntropyPrediction
1+
from .EntropyPred import EntropyPrediction

utils/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .magnus_model import MagnusModel
1+
from .magnus_model import MagnusModel

utils/models/magnus_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import torch.nn as nn
1+
import torch.nn as nn
2+
23

34
class MagnusModel(nn.Module):
45
def __init__(self):
56
super().__init__()
6-
7+
78
def forward(self, x):
8-
return
9+
return

0 commit comments

Comments
 (0)