-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
95 lines (69 loc) · 2.58 KB
/
train.py
File metadata and controls
95 lines (69 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#train.py
import os
import torch
import torch.nn as nn
from model import Unet
from tqdm import tqdm
from dataset import BasicDataset
from torch.utils.data import DataLoader, random_split
from torch import optim
from loss import DiceCELoss
from eval_net import eval_net
# --- Configuration ----------------------------------------------------------
# Sample-list files read by BasicDataset (file format is defined in dataset.py,
# not visible here).
train_file_path = r'C:\Users\f1995\Desktop\train_ex\trainfile.txt'
val_file_path = r'C:\Users\f1995\Desktop\train_ex\validationfile.txt'
# NOTE(review): test_file_path is not used by this script, and it points into
# "train\" while the other two lists live in "train_ex\" - confirm intended.
test_file_path = r'C:\Users\f1995\Desktop\train\testfile.txt'
# Directory for the best-model checkpoint; "" means the current working directory.
dir_checkpoint = ""
batch_size = 4
learn_rate = 0.0001
epochs = 10  # number of passes over the training set

# --- Model, data and optimiser ---------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Presumably Unet(in_channels=1, n_classes=4) - confirm against model.py.
net = Unet(1, 4)
net.to(device=device)

train_dataset = BasicDataset(train_file_path)
val_dataset = BasicDataset(val_file_path)
# Only training benefits from shuffling.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Validation is evaluated in a fixed order. The metric should not depend on
# sample order, and a fixed order keeps the score reproducible if eval_net
# averages per batch (the last batch may be smaller than batch_size).
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

global_step = 0  # incremented once per optimiser step by the training loop
optimizer = optim.Adam(net.parameters(), lr=learn_rate)
# Alternative optimiser: optimizer = optim.SGD(net.parameters(), lr=learn_rate)
init_score = 0.  # best validation score so far; a checkpoint is saved when it is matched or beaten
# --- Training loop -----------------------------------------------------------
# Trains for `epochs` epochs. After each epoch it scores the model on the
# validation set and saves the whole model whenever the score matches or beats
# the best seen so far.

# Built once and reused by every epoch (assumes DiceCELoss keeps no per-epoch
# state - confirm in loss.py).
lossfunc = DiceCELoss()
for epoch in range(epochs):
    net.train()
    epoch_loss = 0
    with tqdm(total=len(train_dataset), desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
        for batch in train_loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)
            # Masks are cast to long - presumably per-pixel class indices for the
            # cross-entropy part of DiceCELoss; confirm in loss.py.
            true_masks = batch['mask'].to(device=device, dtype=torch.long)

            masks_pred = net(imgs)
            loss = lossfunc(masks_pred, true_masks)
            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(**{'loss (batch)': loss.item()})
            pbar.update(imgs.shape[0])  # progress is counted in images, not batches
            global_step += 1

    val_score = eval_net(net, val_loader, device)
    print(f'epoch{epoch} validation dice is ', val_score)

    # Keep only the best model: ">=" also saves when the score ties the best.
    save_cp = val_score >= init_score
    if save_cp:
        init_score = val_score
        if dir_checkpoint:
            # Creates intermediate directories and tolerates an existing one;
            # real failures (e.g. permissions) propagate instead of being hidden.
            os.makedirs(dir_checkpoint, exist_ok=True)
        # os.path.join adds the separator, so dir_checkpoint works with or
        # without a trailing slash ("" -> current working directory).
        # torch.save(net, ...) pickles the whole module, so loading it later
        # requires model.Unet to be importable.
        torch.save(net, os.path.join(dir_checkpoint, 'mmodelWeight.pth'))
print("val dice is", init_score)