Skip to content

Commit b3649a7

Browse files
authored
Merge branch 'magnus-branch' into main
2 parents b429559 + f68a7dd commit b3649a7

File tree

13 files changed

+199
-4
lines changed

13 files changed

+199
-4
lines changed

.gitignore

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Data/
2-
Results/
3-
Experiments/
4-
_build/
1+
Data/*
2+
Results/*
3+
Experiments/*
4+
env/*

environment.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@ dependencies:
99
- sphinx-autobuild
1010
- sphinx-rtd-theme
1111
- pip
12+
- pytest
1213
prefix: /opt/miniconda3/envs/cc-exam
14+

main.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import torch as th
2+
import torch.nn as nn
3+
4+
import argparse
5+
from utils import load_metric, load_model, createfolders
6+
7+
8+
9+
10+
11+
12+
13+
14+
15+
16+
17+
18+
def main():
19+
'''
20+
21+
Parameters
22+
----------
23+
24+
Returns
25+
-------
26+
27+
Raises
28+
------
29+
30+
'''
31+
parser = argparse.ArgumentParser(
32+
prog='',
33+
description='',
34+
epilog='',
35+
)
36+
#Structuture related values
37+
parser.add_argument('--datafolder', type=str, default='Data/', help='Path to where data will be saved during training.')
38+
parser.add_argument('--resultfolder', type=str, default='Results/', help='Path to where results will be saved during evaluation.')
39+
parser.add_argument('--modelfolder', type=str, default='Experiments/', help='Path to where model weights will be saved at the end of training.')
40+
parser.add_argument('--savemodel', type=bool, default=False, help='Whether model should be saved or not.')
41+
parser.add_argument('--download-data', type=bool, default=False, help='Whether the data should be downloaded or not. Might cause code to start a bit slowly.')
42+
43+
#Data/Model specific values
44+
parser.add_argument('--modelname', type=str, default='MagnusModel',
45+
choices = ['MagnusModel'], help="Model which to be trained on")
46+
47+
#Training specific values
48+
parser.add_argument('--epoch', type=int, default=20, help='Amount of training epochs the model will do.')
49+
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate parameter for model training.')
50+
args = parser.parse_args()
51+
52+
createfolders(args)
53+
54+
device = 'cuda' if th.cuda.is_available() else 'cpu'
55+
56+
#load model
57+
model = load_model()
58+
model.to(device)
59+
60+
61+
criterion = nn.CrossEntropyLoss()
62+
optimizer = th.optim.Adam(model.parameters(), lr = args.learning_rate)
63+
64+
65+
66+
for epoch in range(args.epoch):
67+
68+
#Training loop start
69+
70+
#Eval loop start
71+
72+
pass
73+
74+
75+
if __name__ == '__main__':
76+
main()

utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .load_metric import load_metric
2+
from .load_model import load_model
3+
from .load_data import load_data
4+
from .createfolders import createfolders

utils/createfolders.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import os
2+
from tempfile import TemporaryDirectory
3+
import argparse
4+
5+
def createfolders(args) -> None:
6+
'''
7+
Creates folders for storing data, results, model weights.
8+
9+
Parameters
10+
----------
11+
args
12+
ArgParse object containing string paths to be created
13+
14+
'''
15+
16+
if not os.path.exists(args.datafolder):
17+
os.makedirs(args.datafolder)
18+
print(f'Created a folder at {args.datafolder}')
19+
20+
if not os.path.exists(args.resultfolder):
21+
os.makedirs(args.resultfolder)
22+
print(f'Created a folder at {args.resultfolder}')
23+
24+
if not os.path.exists(args.modelfolder):
25+
os.makedirs(args.modelfolder)
26+
print(f'Created a folder at {args.modelfolder}')
27+
28+
29+
30+
def test_createfolders():
31+
with TemporaryDirectory(dir = 'tmp/') as temp_dir:
32+
parser = argparse.ArgumentParser()
33+
#Structuture related values
34+
parser.add_argument('--datafolder', type=str, default=os.path.join(temp_dir, 'Data/'), help='Path to where data will be saved during training.')
35+
parser.add_argument('--resultfolder', type=str, default=os.path.join(temp_dir, 'Results/'), help='Path to where results will be saved during evaluation.')
36+
parser.add_argument('--modelfolder', type=str, default=os.path.join(temp_dir, 'Experiments/'), help='Path to where model weights will be saved at the end of training.')
37+
38+
args = parser.parse_args()
39+
createfolders(args)
40+
41+
return

utils/dataloaders/svhn.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from torch.utils.data import Dataset
2+
3+
class SVHN(Dataset):
4+
def __init__(self):
5+
super().__init__()
6+
7+
def __len__(self):
8+
return
9+
10+
def __getitem__(self, index):
11+
return

utils/load_data.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from torch.utils.data import Dataset
2+
3+
def load_data(dataset:str) -> Dataset:
4+
5+
raise ValueError(f'Dataset: {dataset} not implemented. \nCheck the documentation for implemented metrics, or check your spelling')

utils/load_metric.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import torch.nn as nn
2+
from metrics import EntropyPrediction
3+
4+
def load_metric(metricname:str) -> nn.Module:
5+
'''
6+
This function returns an instance of a class inhereting from nn.Module.
7+
This class returns the given metric given a set of label - prediction pairs.
8+
9+
Parameters
10+
----------
11+
metricname: string
12+
string naming the metric to return.
13+
14+
Returns
15+
-------
16+
Class
17+
Returns an instance of a class inhereting from nn.Module.
18+
19+
Raises
20+
------
21+
ValueError
22+
When the metricname parameter does not match any implemented metric, raise this error along with a descriptive message.
23+
'''
24+
if metricname == 'EntropyPrediction':
25+
return EntropyPrediction()
26+
else:
27+
raise ValueError(f'Metric: {metricname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling')

utils/load_model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import torch.nn as nn
2+
from models import MagnusModel
3+
4+
def load_model(modelname:str) -> nn.Module:
5+
6+
if modelname == 'MagnusModel':
7+
return MagnusModel()
8+
else:
9+
raise ValueError(f'Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling')

utils/metrics/EntropyPred.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import torch.nn as nn
2+
3+
4+
class EntropyPrediction(nn.Module):
5+
def __init__(self):
6+
super().__init__()
7+
8+
def __call__(self, y_true, y_false):
9+
10+
return

0 commit comments

Comments
 (0)