-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
100 lines (83 loc) · 2.82 KB
/
train.py
File metadata and controls
100 lines (83 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Standard library
import argparse
import datetime
import json
import os
import sys
import random
from pprint import pprint

# Third-party
import numpy as np
import torch
import torch.optim as optim
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers import ModelCheckpoint
from ignite.metrics import Loss
from ignite.utils import setup_logger

# Local
from data import get_ds
from mapping import mapping
from metrics import DiceMetric
from utils import *
def train(args):
    """Run a full supervised training loop for a segmentation model.

    Resolves the model and loss classes from the project ``mapping`` table,
    builds ignite trainer/evaluators, evaluates on the train and validation
    loaders after every epoch, and checkpoints the two models with the best
    validation Dice score into ``args.exp_dir``.

    Args:
        args: argparse.Namespace with at least ``seed``, ``idx`` (CUDA device
            index), ``ds``, ``model``, ``loss``, ``lr``, ``epochs``.
            ``args.exp_dir`` is set here as a side effect.

    Raises:
        ValueError: if ``args.ds``/``args.model``/``args.loss`` is not a
            valid key of the project ``mapping`` table.
    """
    # Reproducibility: seed every RNG the pipeline touches.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True

    # Device setup. Only attach a device index when CUDA is available:
    # torch.device("cpu", index=N) is invalid for N != 0, so the original
    # one-liner would break on CPU-only machines with args.idx > 0.
    if torch.cuda.is_available():
        device = torch.device("cuda", index=args.idx)
    else:
        device = torch.device("cpu")

    # Experiment folder setup and persist the run configuration.
    # (folder_setup / save_cfg / Logging presumably come from `utils` —
    # they are star-imported at module level.)
    args.exp_dir = folder_setup(args)
    save_cfg(args, args.exp_dir)

    # Data setup. get_ds returns (data, args): it may augment args with
    # dataset-derived fields, so the returned namespace replaces the input.
    data, args = get_ds(args)
    _, _, _, train_dl, valid_dl, _ = data

    # Logging setup
    logger = Logging(args)

    # Resolve model / loss classes from the project mapping table; surface
    # a clear error (with the original KeyError chained) on bad CLI choices.
    try:
        model_class = mapping[args.ds]["model"][args.model]
        loss_class = mapping[args.ds]["loss"][args.loss]
    except KeyError as e:
        raise ValueError(f"Invalid key in mapping: {e}") from e

    # Model setup — the unet variant needs the device at construction time.
    if args.model == "unet":
        model = model_class(device=device)
    else:
        model = model_class()
    model.to(device)

    criterion = loss_class()
    metrics = {
        'dsc': DiceMetric(device=device),
        'loss': Loss(criterion)
    }
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Trainer and evaluators (one evaluator per split so their states
    # don't clobber each other).
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    trainer.logger = setup_logger('Trainer')
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

    best_dsc = 0

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        # Evaluate both splits after every epoch; track the best validation
        # Dice for bookkeeping (checkpointing itself is handled by
        # ModelCheckpoint below).
        nonlocal best_dsc
        train_evaluator.run(train_dl)
        validation_evaluator.run(valid_dl)
        curr_dsc = validation_evaluator.state.metrics["dsc"]
        if curr_dsc > best_dsc:
            best_dsc = curr_dsc
        # NOTE(review): source indentation was lost; logger.step is assumed
        # to run every epoch (not only on improvement) — confirm.
        logger.step(engine.state.epoch)

    def score_function(engine):
        # Higher Dice is better; ModelCheckpoint keeps the top-n_saved.
        return engine.state.metrics["dsc"]

    model_checkpoint = ModelCheckpoint(
        args.exp_dir,
        n_saved=2,
        filename_prefix='best',
        score_function=score_function,
        score_name='dsc',
        global_step_transform=None,
        require_empty=False
    )
    # Save after every validation run, keyed on the Dice score.
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {'model': model})

    trainer.run(train_dl, max_epochs=args.epochs)