Skip to content

Training code implementation  #14

@HassanAbbas92

Description

@HassanAbbas92

`
import matplotlib.pyplot as plt
import cv2
import sys
import os
from PIL import Image, ImageDraw
from utils.utils import fan_NME, show_landmarks, get_preds_fromhm
import numpy as np
from skimage import io
import shutil
from torch.autograd import Variable
import time
import copy
from torch import nn
import torch
import math
import matplotlib
# Select the non-interactive Agg backend so figures can be rendered without
# a display (e.g. on a headless training machine).
matplotlib.use('Agg')

# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class AdaptiveWingLoss(nn.Module):
    """Weighted Adaptive Wing loss for heatmap regression.

    For every element, with absolute error ``d = |target - pred|``::

        loss = omega * log(1 + (d / epsilon) ** (alpha - target))   if d < theta
        loss = A * d - C                                            otherwise

    ``A`` is the slope of the first piece at ``d == theta`` and ``C`` the
    offset that makes both pieces agree there, so the loss is continuous and
    smooth across the threshold.

    Args:
        omega: Scale of the logarithmic part.
        theta: Error threshold that separates the two pieces.
        epsilon: Normalizer applied to the error inside the power term.
        alpha: Base exponent; the effective exponent is ``alpha - target``.
    """

    def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1):
        super(AdaptiveWingLoss, self).__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha

    def forward(self, pred, weight_map, target):
        """Return the weighted mean Adaptive Wing loss.

        Args:
            pred: Predicted tensor.
            weight_map: Per-element weights; any tensor that broadcasts to the
                shape of ``target - pred``.
            target: Ground-truth tensor.

        Returns:
            A scalar tensor: the sum of weighted element losses divided by the
            number of elements (0 for empty input instead of NaN).
        """
        abs_err = (target - pred).abs()
        # Bring target and weights to the error's shape so boolean-mask
        # indexing works even when they were only broadcast-compatible.
        target = target.expand_as(abs_err)
        weights = weight_map.expand_as(abs_err)

        near = abs_err < self.theta
        far = abs_err >= self.theta

        # Small-error piece. The error is normalized by epsilon (not omega):
        # A and C below are derived for (d / epsilon), and any other
        # normalizer makes the loss jump at d == theta.
        exp_near = self.alpha - target[near]
        loss_near = self.omega * torch.log(
            1 + torch.pow(abs_err[near] / self.epsilon, exp_near)) * weights[near]

        # Large-error piece: linear continuation with matching value and slope.
        ratio = self.theta / self.epsilon
        exp_far = self.alpha - target[far]
        ratio_pow = torch.pow(ratio, exp_far)
        A = self.omega * (1 / (1 + ratio_pow)) * exp_far * \
            torch.pow(ratio, exp_far - 1) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(1 + ratio_pow)
        loss_far = (A * abs_err[far] - C) * weights[far]

        # Guard the denominator so an empty batch yields 0 rather than NaN.
        count = max(loss_near.numel() + loss_far.numel(), 1)
        return (loss_near.sum() + loss_far.sum()) / count


def train_model(model, dataloaders, dataset_sizes, use_gpu=True, epoches=5,
                save_path='./', num_landmarks=68, start_epoch=0,
                lr=0.0000001, weight_decay=0):
    """Train ``model`` with the Adaptive Wing loss and validate every epoch.

    Each epoch runs a 'train' phase followed by a 'val' phase. The NME is
    accumulated separately per phase, and a checkpoint is written whenever the
    validation NME improves on the best value seen so far.

    Args:
        model: Network called as ``model(inputs)`` and returning
            ``(outputs, boundary_channels)``; the last element of each is used.
        dataloaders: Mapping with 'train' and 'val' iterables. Every batch is
            a dict with the keys 'image', 'heatmap', 'boundary', 'landmarks'
            and 'weight_map'.
        dataset_sizes: Mapping from phase name to the divisor used to turn the
            accumulated NME into the epoch NME.
        use_gpu: When true, batch tensors are moved to the module-level
            ``device``. The model is not moved here.
        epoches: Number of epochs to run.
        save_path: Prefix for checkpoint files; the name is built by plain
            string concatenation as ``save_path + '<epoch>.pth'``.
        num_landmarks: Forwarded to ``fan_NME``.
        start_epoch: Index of the first epoch (used in checkpoint names).
        lr: RMSprop learning rate.
            NOTE(review): the default is very small; confirm it is intended.
        weight_decay: RMSprop weight decay.

    Returns:
        The same ``model`` instance after training.
    """
    best_nme = 100
    optimizer = torch.optim.RMSprop(
        model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_AW = AdaptiveWingLoss()

    for epoch in range(start_epoch, epoches + start_epoch):
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            # Reset per phase so the validation NME is not polluted by the
            # NME accumulated during the training phase of the same epoch.
            total_nme = 0

            for data in dataloaders[phase]:
                inputs = data['image'].type(torch.FloatTensor)
                labels_heatmap = data['heatmap'].type(torch.FloatTensor)
                labels_boundary = data['boundary'].type(torch.FloatTensor)
                gt_landmarks = data['landmarks'].type(torch.FloatTensor)
                loss_weight_map = data['weight_map'].type(torch.FloatTensor)

                if use_gpu:
                    inputs = inputs.to(device)
                    labels_heatmap = labels_heatmap.to(device)
                    labels_boundary = labels_boundary.to(device)
                    loss_weight_map = loss_weight_map.to(device)

                # Heatmap and boundary targets are stacked along dim 1 to
                # mirror the layout of the concatenated predictions below.
                labels = torch.cat((labels_heatmap, labels_boundary), 1)

                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs, boundary_channels = model(inputs)
                    # Use the last output and drop its final channel.
                    # NOTE(review): presumably that channel is not a landmark
                    # heatmap; confirm against the model definition.
                    pred_heatmap = outputs[-1][:, :-1, :, :]
                    pred_boundary = boundary_channels[-1][:, :-1, :, :]
                    pred_labels = torch.cat((pred_heatmap, pred_boundary), 1)

                    # The weight map is rescaled to 10 * w + 1, so every
                    # element keeps a weight of at least 1 when w >= 0.
                    loss_total = loss_AW(
                        pred_labels, loss_weight_map * 10 + 1, labels)

                    if is_train:
                        loss_total.backward()
                        optimizer.step()

                # NOTE(review): the total is later divided by the dataset
                # size, so fan_NME presumably returns a per-batch sum; confirm.
                total_nme += fan_NME(
                    pred_heatmap.detach().cpu(), gt_landmarks, num_landmarks)

            epoch_nme = total_nme / dataset_sizes[phase]
            print(phase + ' NME: {:.6f}'.format(epoch_nme))

            if phase == 'val' and epoch_nme < best_nme:
                # Remember the best score so only improvements are saved.
                best_nme = epoch_nme
                state = {
                    'next_epoch': epoch+1,
                    'epoch_total_nme': epoch_nme,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                torch.save(state, save_path+'{:02d}'.format(epoch)+'.pth')

    return model

`
@protossw512 could you please check if my training implementation is correct?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions