This repository was archived by the owner on Jan 9, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
89 lines (70 loc) · 2.3 KB
/
main.py
File metadata and controls
89 lines (70 loc) · 2.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from args import args
from init import init
from data import unpacking_s2s
from learn import train, evaluate
def training(model, optim, train_set, valid_set, unzip, args):
    """Train ``model`` and save its weights to ``args.checkpoint``.

    Parameters
    ----------
    model : :class:`torch.nn.Module`
        The initialized model.
    optim : :class:`torch.optim.Optimizer`
        The optimizer.
    train_set : :class:`torch.utils.data.DataLoader`
        Training dataset.
    valid_set : :class:`torch.utils.data.DataLoader`
        Validation dataset.
    unzip : callable
        Function that unpacks the minibatch.
    args : :class:`argparse.Namespace`
        Configurations; ``args.checkpoint`` is the path the trained
        weights are written to.

    Returns
    -------
    model : :class:`torch.nn.Module`
        The trained model, exactly as returned by ``train``.
    """
    trained = train(model, optim, train_set, valid_set, unzip, args)
    # The weights are stored under the 'model' key, which is the key
    # evaluation_on_dataset reads them back from.
    checkpoint = {'model': trained.state_dict()}
    torch.save(checkpoint, args.checkpoint)
    return trained
def evaluation(model, train_set, test_set, unzip, args):
    """Restore the checkpoint at ``args.checkpoint`` and evaluate on the test set.

    Parameters
    ----------
    model : :class:`torch.nn.Module`
        The initialized model; the saved weights are loaded into it.
    train_set : :class:`torch.utils.data.DataLoader`
        Training dataset. Not used by this function; kept so existing
        callers keep working.
    test_set : :class:`torch.utils.data.DataLoader`
        Test dataset.
    unzip : callable
        Function that unpacks the minibatch.
    args : :class:`argparse.Namespace`
        Configurations; ``args.checkpoint`` is the checkpoint path.

    Returns
    -------
    Whatever ``evaluation_on_dataset`` returns for ``test_set``.
    """
    checkpoint_path = args.checkpoint
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path)
    return evaluation_on_dataset(model, checkpoint, test_set, unzip, args)
def evaluation_on_dataset(model, saved_model, dataloader, unzip, args):
    """Load saved weights into ``model``, evaluate it, and print the metrics.

    Parameters
    ----------
    model : :class:`torch.nn.Module`
        The initialized model; its parameters are overwritten in place by
        ``load_state_dict``.
    saved_model : dict
        Saved checkpoint; the weights are read from its ``'model'`` key.
    dataloader : :class:`torch.utils.data.DataLoader`
        Dataset to evaluate on.
    unzip : callable
        Function that unpacks the minibatch.
    args : :class:`argparse.Namespace`
        Configurations.

    Returns
    -------
    res : dict
        The metrics reported by ``evaluate``. They must contain ``'mae'``,
        ``'rmse'`` and ``'acc'``, because all three are printed.
    """
    model.load_state_dict(saved_model['model'])
    # Copy into a fresh plain dict. This replaces the identity
    # comprehension and the dead placeholder dict the original built and
    # then discarded.
    res = dict(evaluate(model, dataloader, unzip, args).items())
    print('MAE: {:.4f}, RMSE: {:.4f}, ACC: {:.4f}'.format(
        res['mae'], res['rmse'], res['acc']))
    # Returned so that `evaluation`, which forwards this value, gives its
    # caller the metrics instead of None.
    return res
if __name__ == '__main__':
    # Guarded so that importing this module (e.g. to reuse `training` or
    # `evaluation`) does not start a full training run as a side effect.
    tr_set, val_set, te_set, model, optim = init(args)
    model = training(model, optim, tr_set, val_set, unpacking_s2s, args)
    evaluation(model, tr_set, te_set, unpacking_s2s, args)