Skip to content

Commit 5709f17

Browse files
committed
Added a simple test to make sure github actions work
1 parent 5e3ab17 commit 5709f17

File tree

9 files changed

+113
-6
lines changed

9 files changed

+113
-6
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
Data/*
22
Results/*
33
Experiments/*
4-
env/*
4+
env/*
5+
test.ipynb

main.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch.nn as nn
33

44
import argparse
5-
from utils import load_metric
5+
from utils import load_metric, load_model, createfolders
66

77

88

@@ -16,17 +16,50 @@
1616

1717

1818
def main():
19+
'''
1920
21+
Parameters
22+
----------
23+
24+
Returns
25+
-------
26+
27+
Raises
28+
------
29+
30+
'''
2031
parser = argparse.ArgumentParser(
2132
prog='',
2233
description='',
2334
epilog='',
2435
)
36+
#Structuture related values
37+
parser.add_argument('--datafolder', type=str, default='Data/', help='Path to where data will be saved during training.')
38+
parser.add_argument('--resultfolder', type=str, default='Results/', help='Path to where results will be saved during evaluation.')
39+
parser.add_argument('--modelfolder', type=str, default='Experiments/', help='Path to where model weights will be saved at the end of training.')
40+
parser.add_argument('--savemodel', type=bool, default=False, help='Whether model should be saved or not.')
41+
parser.add_argument('--download-data', type=bool, default=False, help='Whether the data should be downloaded or not. Might cause code to start a bit slowly.')
2542

26-
parser.add_argument('--epoch', type=int, default=20, help='Amount of training epochs the model will do')
43+
#Data/Model specific values
44+
parser.add_argument('--modelname', type=str, default='MagnusModel',
45+
choices = ['MagnusModel'], help="Model which to be trained on")
2746

47+
#Training specific values
48+
parser.add_argument('--epoch', type=int, default=20, help='Amount of training epochs the model will do.')
49+
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate parameter for model training.')
2850
args = parser.parse_args()
2951

52+
createfolders(args)
53+
54+
device = 'cuda' if th.cuda.is_available() else 'cpu'
55+
56+
#load model
57+
model = load_model()
58+
model.to(device)
59+
60+
61+
criterion = nn.CrossEntropyLoss()
62+
optimizer = th.optim.Adam(model.parameters(), lr = args.learning_rate)
3063

3164

3265

@@ -36,7 +69,7 @@ def main():
3669

3770
#Eval loop start
3871

39-
72+
pass
4073

4174

4275
if __name__ == '__main__':

utils/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1-
from .load_metric import load_metric
1+
from .load_metric import load_metric
2+
from .load_model import load_model
3+
from .load_data import load_data
4+
from .createfolders import createfolders

utils/createfolders.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import os
2+
from tempfile import TemporaryDirectory
3+
import argparse
4+
5+
def createfolders(args) -> None:
6+
'''
7+
Creates folders for storing data, results, model weights.
8+
9+
Parameters
10+
----------
11+
args
12+
ArgParse object containing string paths to be created
13+
14+
'''
15+
16+
if not os.path.exists(args.datafolder):
17+
os.makedirs(args.datafolder)
18+
print(f'Created a folder at {args.datafolder}')
19+
20+
if not os.path.exists(args.resultfolder):
21+
os.makedirs(args.resultfolder)
22+
print(f'Created a folder at {args.resultfolder}')
23+
24+
if not os.path.exists(args.modelfolder):
25+
os.makedirs(args.modelfolder)
26+
print(f'Created a folder at {args.modelfolder}')
27+
28+
29+
30+
def test_createfolders():
31+
with TemporaryDirectory(dir = 'tmp/') as temp_dir:
32+
parser = argparse.ArgumentParser()
33+
#Structuture related values
34+
parser.add_argument('--datafolder', type=str, default=os.path.join(temp_dir, 'Data/'), help='Path to where data will be saved during training.')
35+
parser.add_argument('--resultfolder', type=str, default=os.path.join(temp_dir, 'Results/'), help='Path to where results will be saved during evaluation.')
36+
parser.add_argument('--modelfolder', type=str, default=os.path.join(temp_dir, 'Experiments/'), help='Path to where model weights will be saved at the end of training.')
37+
38+
args = parser.parse_args()
39+
createfolders(args)
40+
41+
return

utils/dataloaders/svhn.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from torch.utils.data import Dataset
2+
3+
class SVHN(Dataset):
4+
def __init__(self):
5+
super().__init__()
6+
7+
def __len__(self):
8+
return
9+
10+
def __getitem__(self, index):
11+
return

utils/load_data.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from torch.utils.data import Dataset
2+
3+
def load_data(dataset:str) -> Dataset:
4+
5+
raise ValueError(f'Dataset: {dataset} not implemented. \nCheck the documentation for implemented metrics, or check your spelling')

utils/load_model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import torch.nn as nn
2+
from models import MagnusModel
23

34
def load_model(modelname:str) -> nn.Module:
45

5-
raise ValueError(f'Metric: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling')
6+
if modelname == 'MagnusModel':
7+
return MagnusModel()
8+
else:
9+
raise ValueError(f'Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling')

utils/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .magnus_model import MagnusModel

utils/models/magnus_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import torch.nn as nn
2+
3+
class MagnusModel(nn.Module):
4+
def __init__(self):
5+
super().__init__()
6+
7+
def forward(self, x):
8+
return

0 commit comments

Comments
 (0)