-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_left_foot.py
More file actions
107 lines (85 loc) · 4.39 KB
/
train_left_foot.py
File metadata and controls
107 lines (85 loc) · 4.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# the dataset we created in Notebook 1 is copied in the helper file `data_load.py`
from torchvision.transforms import CenterCrop
from data_load import FacialKeypointsDataset
# the transforms we defined in Notebook 1 are in the helper file `data_load.py`
from data_load import Rescale, RandomCrop, Normalize, ToTensor
#from models import Net
from model_leftfoot import NetLF
#from model_429_vgg import Net
import torch.nn as nn
import torch.nn.functional as F
# Network under training; the architecture is defined in model_leftfoot.py.
net = NetLF()

# Per-sample preprocessing pipeline. Compose applies the transforms in the
# order listed, so the crop runs before normalisation and tensor conversion.
data_transform = transforms.Compose([
    RandomCrop(429),
    Normalize(),
    ToTensor(),
])
assert data_transform is not None, 'Define a data_transform'

# Dataset built from the keypoint CSV and its image directory, with the
# pipeline above applied to every sample.
transformed_dataset = FacialKeypointsDataset(
    csv_file=r'C:\Users\Semanti Basu\Documents\OneDrive_2020-02-19\3D Ceaser dataset\Image and point generation\Image and point generation\frontalpoints.csv',
    root_dir=r'C:\Users\Semanti Basu\Documents\OneDrive_2020-02-19\3D Ceaser dataset\Image and point generation\Image and point generation\ceasar_mat_first100',
    transform=data_transform,
)

# Shuffled mini-batches; num_workers=0 keeps loading in the main process.
batch_size = 2
train_loader = DataLoader(
    transformed_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

# Regression objective (mean squared error) optimised with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=1e-4)
def train_net(n_epochs, *, model=None, loader=None, loss_fn=None, opt=None,
              print_every=10):
    """Train a keypoint-regression network and print a running mean loss.

    Every batch is expected to be a mapping with an ``'image'`` tensor and a
    ``'keypoints'`` tensor.  Only entries 20..25 along dimension 1 of the
    keypoints are used as the regression target (presumably the left-foot
    landmarks, going by the model name -- confirm against the CSV layout);
    they are flattened to one row per sample before the loss is computed.

    Args:
        n_epochs: Number of complete passes over ``loader``.
        model: Network to train.  Defaults to the module-level ``net``.
        loader: Iterable of batches.  Defaults to the module-level
            ``train_loader``.
        loss_fn: Callable ``(predictions, targets) -> scalar loss tensor``.
            Defaults to the module-level ``criterion``.
        opt: Optimizer that updates ``model``'s parameters.  Defaults to the
            module-level ``optimizer``.
        print_every: Number of batches between loss reports; the value
            printed is the mean loss over exactly that many batches.

    Raises:
        ValueError: If ``print_every`` is smaller than 1.
    """
    if print_every < 1:
        raise ValueError('print_every must be at least 1')

    # Fall back to the script-level objects so ``train_net(n_epochs)`` keeps
    # working exactly as before.
    model = net if model is None else model
    loader = train_loader if loader is None else loader
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt

    # Put the network in training mode.
    model.train()

    for epoch in range(n_epochs):
        running_loss = 0.0

        for batch_i, data in enumerate(loader):
            images = data['image']
            # Keep only the target subset of keypoints, one flat row each.
            key_pts = data['keypoints'][:, 20:26]
            key_pts = key_pts.view(key_pts.size(0), -1)

            # The regression loss needs float32 inputs and targets.
            key_pts = key_pts.type(torch.FloatTensor)
            images = images.type(torch.FloatTensor)

            # Forward pass and loss.
            output_pts = model(images)
            loss = loss_fn(output_pts, key_pts)

            # Clear stale gradients, backpropagate, update the weights.
            opt.zero_grad()
            loss.backward()
            opt.step()

            running_loss += loss.item()
            if batch_i % print_every == print_every - 1:
                # Divide by the number of batches actually accumulated; the
                # previous fixed divisor of 1000 under-reported the mean.
                print('Epoch: {}, Batch: {}, Avg. Loss: {}'.format(
                    epoch + 1, batch_i + 1, running_loss / print_every))
                running_loss = 0.0

    print('Finished Training')
# Number of passes over the training set; start small and raise it once the
# model structure and hyperparameters are settled.
n_epochs = 100

# Destination file for the trained weights.
PATH = r'C:\Users\Semanti Basu\Documents\OneDrive_2020-02-19\3D Ceaser dataset\Image and point generation\Image and point generation\frontaltrainedmodel_leftfoot_multifc_100epochs_first100_fc4.pth'

# Run the training loop, then save the network's state_dict (its parameter
# and buffer tensors) to PATH.
train_net(n_epochs)
torch.save(net.state_dict(), PATH)