```python
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before importing pyplot
import matplotlib.pyplot as plt
import cv2
import sys
import os
from PIL import Image, ImageDraw
from utils.utils import fan_NME, show_landmarks, get_preds_fromhm
import numpy as np
from skimage import io
import shutil
from torch.autograd import Variable
import time
import copy
from torch import nn
import torch
import math

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class AdaptiveWingLoss(nn.Module):
    def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1):
        super(AdaptiveWingLoss, self).__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha

    def forward(self, pred, weight_map, target):
        y = target
        y_hat = pred
        delta_y = (y - y_hat).abs()
        # Split residuals at theta: log branch for small errors, linear branch for large ones.
        delta_y1 = delta_y[delta_y < self.theta]
        delta_y2 = delta_y[delta_y >= self.theta]
        y1 = y[delta_y < self.theta]
        y2 = y[delta_y >= self.theta]
        loss1 = self.omega * torch.log(1 + torch.pow(
            delta_y1 / self.omega, self.alpha - y1)) * weight_map[delta_y < self.theta]
        A = self.omega * (1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))) * (self.alpha - y2) * (
            torch.pow(self.theta / self.epsilon, self.alpha - y2 - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * \
            torch.log(1 + torch.pow(self.theta / self.epsilon, self.alpha - y2))
        loss2 = (A * delta_y2 - C) * weight_map[delta_y >= self.theta]
        return (loss1.sum() + loss2.sum()) / (len(loss1) + len(loss2))


def train_model(model, dataloaders, dataset_sizes, use_gpu=True, epoches=5,
                save_path='./', num_landmarks=68, start_epoch=0):
    best_acc = 100
    optimizer = torch.optim.RMSprop(
        model.parameters(), lr=0.0000001, weight_decay=0)
    loss_AW = AdaptiveWingLoss()
    for epoch in range(start_epoch, epoches + start_epoch):
        running_loss = 0
        step = 0
        total_nme = 0
        total_count = 0
        fail_count = 0
        nmes = []
        # running_corrects = 0
        step_start = time.time()
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            # Iterate over data.
            # with torch.set_grad_enabled(True):
            for data in dataloaders[phase]:
                optimizer.zero_grad()
                total_runtime = 0
                run_count = 0
                step += 1
                # get the inputs
                inputs = data['image'].type(torch.FloatTensor)
                labels_heatmap = data['heatmap'].type(torch.FloatTensor)
                labels_boundary = data['boundary'].type(torch.FloatTensor)
                gt_landmarks = data['landmarks'].type(torch.FloatTensor)
                loss_weight_map = data['weight_map'].type(torch.FloatTensor)
                # wrap them in Variable
                if use_gpu:
                    inputs = inputs.to(device)
                    labels_heatmap = labels_heatmap.to(device)
                    labels_boundary = labels_boundary.to(device)
                    loss_weight_map = loss_weight_map.to(device)
                else:
                    inputs, labels_heatmap = Variable(
                        inputs), Variable(labels_heatmap)
                    labels_boundary = Variable(labels_boundary)
                labels = torch.cat((labels_heatmap, labels_boundary), 1)
                single_start = time.time()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs, boundary_channels = model(inputs)
                    pred_labels = torch.cat(
                        (outputs[-1][:, :-1, :, :], boundary_channels[-1][:, :-1, :, :]), 1)
                    ###
                    loss_total = loss_AW(
                        pred_labels, loss_weight_map * 10 + 1, labels)
                    ###
                    # print("Batch Loss: {:.6f}".format(loss_total.item()))
                    if phase == 'train':
                        loss_total.backward()
                        optimizer.step()
                batch_nme = fan_NME(
                    outputs[-1][:, :-1, :, :].detach().cpu(), gt_landmarks, num_landmarks)
                # print("Batch NME: {:.6f}".format(batch_nme))
                # batch_nme = 0
                total_nme += batch_nme
            epoch_nme = total_nme / dataset_sizes[phase]
            step_end = time.time()
            print(phase + ' NME: {:.6f}'.format(epoch_nme))
            if phase == 'val' and epoch_nme < best_acc:
                state = {
                    'next_epoch': epoch + 1,
                    'epoch_total_nme': epoch_nme,
                    'state_dict': model.state_dict(),
                    # 'scheduler': scheduler.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                torch.save(state, save_path + '{:02d}'.format(epoch) + '.pth')
                # nme_save_path = os.path.join(save_path, 'nme_log.npy')
                # np.save(nme_save_path, np.array(nmes))
                # print('NME: {:.6f} Failure Rate: {:.6f} Total Count: {:.6f} Fail Count: {:.6f}'.format(epoch_nme, fail_count/total_count, total_count, fail_count))
                # print('Average runtime for a single batch: {:.6f}'.format(total_runtime/run_count))
    return model
```
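A minimal sketch of how I exercise the loss on random tensors, just to confirm it runs end to end. The shapes (batch of 2, 68 landmark heatmaps plus 1 boundary map concatenated to 69 channels, 64×64 maps) and the `weight_map * 10 + 1` scaling are my assumptions mirroring `train_model`, not values taken from the repo:

```python
import torch

# Assumes the AdaptiveWingLoss class from the snippet above is in scope.
# Hypothetical shapes: B=2, 69 channels (68 heatmaps + 1 boundary map), 64x64.
pred = torch.rand(2, 69, 64, 64)
target = torch.rand(2, 69, 64, 64)
weight_map = torch.rand(2, 69, 64, 64)  # stand-in for the dataloader's foreground weight map

loss_fn = AdaptiveWingLoss()
loss = loss_fn(pred, weight_map * 10 + 1, target)  # same weighting as in train_model
print('AWing loss on random tensors: {:.6f}'.format(loss.item()))
```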
@protossw512 could you please check whether my training implementation is correct?