-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
82 lines (61 loc) · 2.74 KB
/
main.py
File metadata and controls
82 lines (61 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import h5pickle as h5py
import matplotlib.pyplot as plt
from model import Modified3DUNet, SimpleModel
from UNetTrainer import PatchDataset, UNetTrainer
import torch
from loss import GeneralizedDiceLoss, WeightedCrossEntropyLoss
import torch.nn as nn
# Creating a main is necessary in windows for multiprocessing, which is used by the dataloader
def main():
    """Entry point: restore a trained 3D U-Net from its last checkpoint and
    run a single forward pass on the first validation patch.

    Loads train/validation patch datasets from HDF5 files, builds the model,
    optimizer and weighted cross-entropy loss, then resumes the trainer from
    the "WCEL_1_10_50_last_model" checkpoint instead of training from scratch.

    NOTE(review): a `main()` guard is required on Windows because the
    DataLoader's multiprocessing re-imports this module in worker processes.
    """
    # NOTE(review): earlier versions split one HDF5 file with
    # train_test_split(all_groups, test_size=0.2, shuffle=False); shuffle had
    # to stay False so the model is never validated on former training data.
    # Separate train/val files are used now.
    datapath = "Data/"
    train_file = datapath + "patches_dataset_test.h5"  # small file, for testing
    val_file = datapath + "val250.h5"

    # DataLoader parameters (shared by train and validation loaders).
    params = {'batch_size': 2,
              'shuffle': False,
              'num_workers': 0}

    train_dataset = PatchDataset(train_file, n_classes=3)
    print(len(train_dataset))
    val_dataset = PatchDataset(val_file, n_classes=3)
    print(len(val_dataset))

    train_loader = DataLoader(train_dataset, **params)
    val_loader = DataLoader(val_dataset, **params)
    loaders = {
        'train': train_loader,
        'val': val_loader
    }

    # Model and optimizer (Adam with default hyperparameters).
    model = Modified3DUNet(in_channels=1, n_classes=3)
    optimizer = optim.Adam(model.parameters())
    max_epochs = 10

    # Class weights for [background, pancreas, cancer]. Median foreground
    # percentage is ~0.2% (classes 1+2), cancer ~0.01% (class 2), so the
    # inverse-frequency weights would be ~[1, 525, 9980]; the softer
    # [1, 100, 500] is used in practice. GeneralizedDiceLoss and
    # WeightedCrossEntropyLoss were tried previously with the harder weights.
    weights = [1, 100, 500]
    class_weights = torch.FloatTensor(weights)
    loss_criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Resume from the last saved checkpoint rather than training anew
    # (UNetTrainer(...).train() is the from-scratch path).
    checkpoint_trainer = UNetTrainer.load_checkpoint(
        "WCEL_1_10_50_last_model", model, optimizer, loaders, max_epochs,
        loss_criterion=loss_criterion)
    pred = checkpoint_trainer.single_image_forward(val_dataset[0][0])


if __name__ == '__main__':
    main()