Skip to content

Commit 5e3ab17

Browse files
committed
Started adding files and folders we need. Initial structure, example of numpydoc docstring style, and some other in-progress work
1 parent 7d50ba6 commit 5e3ab17

File tree

7 files changed

+91
-3
lines changed

7 files changed

+91
-3
lines changed

.gitignore

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
Data/
2-
Results/
3-
Experiments/
1+
Data/*
2+
Results/*
3+
Experiments/*
4+
env/*

main.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import torch as th
2+
import torch.nn as nn
3+
4+
import argparse
5+
from utils import load_metric
6+
7+
8+
9+
10+
11+
12+
13+
14+
15+
16+
17+
18+
def main():
19+
20+
parser = argparse.ArgumentParser(
21+
prog='',
22+
description='',
23+
epilog='',
24+
)
25+
26+
parser.add_argument('--epoch', type=int, default=20, help='Amount of training epochs the model will do')
27+
28+
args = parser.parse_args()
29+
30+
31+
32+
33+
for epoch in range(args.epoch):
34+
35+
#Training loop start
36+
37+
#Eval loop start
38+
39+
40+
41+
42+
if __name__ == '__main__':
43+
main()

utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .load_metric import load_metric

utils/load_metric.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import torch.nn as nn
2+
from metrics import EntropyPrediction
3+
4+
def load_metric(metricname:str) -> nn.Module:
5+
'''
6+
This function returns an instance of a class inhereting from nn.Module.
7+
This class returns the given metric given a set of label - prediction pairs.
8+
9+
Parameters
10+
----------
11+
metricname: string
12+
string naming the metric to return.
13+
14+
Returns
15+
-------
16+
Class
17+
Returns an instance of a class inhereting from nn.Module.
18+
19+
Raises
20+
------
21+
ValueError
22+
When the metricname parameter does not match any implemented metric, raise this error along with a descriptive message.
23+
'''
24+
if metricname == 'EntropyPrediction':
25+
return EntropyPrediction()
26+
else:
27+
raise ValueError(f'Metric: {metricname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling')

utils/load_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import torch.nn as nn
2+
3+
def load_model(modelname:str) -> nn.Module:
4+
5+
raise ValueError(f'Metric: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling')

utils/metrics/EntropyPred.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import torch.nn as nn
2+
3+
4+
class EntropyPrediction(nn.Module):
5+
def __init__(self):
6+
super().__init__()
7+
8+
def __call__(self, y_true, y_false):
9+
10+
return

utils/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .EntropyPred import EntropyPrediction

0 commit comments

Comments
 (0)