-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtrain.py
More file actions
185 lines (151 loc) · 6.86 KB
/
train.py
File metadata and controls
185 lines (151 loc) · 6.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import os
import torch.nn as nn
from utils.utils import *  # NOTE(review): wildcard import — presumably supplies torch, np, one_label_text, temporal_proposals used below; verify
from utils.args import Args
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter
from utils.dataset import Dataset
from model.attention_module import Model
from utils.loss import *  # NOTE(review): presumably supplies Adaptive_Margin_Loss and euclidean_distance; verify
# Make newly-created tensors default to CUDA float tensors for the whole process.
torch.set_default_tensor_type('torch.cuda.FloatTensor')
def main():
    """Train an attention model aligning video I3D features with text embeddings.

    Bootstraps the visual branch from a pre-trained classification checkpoint
    (``args.t_cam``), then optimizes an adaptive margin loss plus a margin
    ranking loss with SGD. Periodically decays the learning rate / temperature,
    saves checkpoints, and plots the learned attention weights against T-CAM
    temporal proposals for visual inspection.

    Side effects: writes checkpoints under ``args.checkpoint_path``, PNG plots
    under ``./visualization/``, and TensorBoard scalars via ``SummaryWriter``.
    """
    # Map each category id to one or more natural-language prompts describing it.
    class_name = {0: ["baseball pitch", "throw a baseball", "baseball throw"],
                  1: ["basketball dunk", "dunk a basketball", "slam dunk basketball"],
                  2: ["billiards"],
                  3: ["clean and jerk", "weight lifting movement"],
                  4: ["cliff diving", "high diving", "diving"],
                  5: ["cricket shot"],
                  6: ["cricket bowling", "cricket movement", "bowl cricket"],
                  7: ["diving", "jumping into water", "falling into water"],
                  8: ["frisbee catch", "catch frisbee"],
                  9: ["golf swing", "golf stroke"],
                  10: ["hammer throw", "throw a hammer"],
                  11: ["high jump"],
                  12: ["javelin throw", "throw a spear"],
                  13: ["long jump", "jump contest"],
                  14: ["pole vault", "a person uses a long flexible pole to jump over a bar"],
                  15: ["shot put"],
                  16: ["soccer penalty"],
                  17: ["tennis swing"],
                  18: ["throw discus", "discus"],
                  19: ["volleyball spiking", "volleyball", ]}
    # Initialize the arguments
    args = Args()
    checkpoint_model_name = args.model_name
    # Specify GPU
    device = torch.device(args.gpu)
    args.device = device
    # Initialize dataset
    dataset = Dataset(args)
    # Model hyper-parameters (LSTM over 300-d word embeddings).
    lstm_input_size = 300
    hidden_dim = 300
    output_dim = 100
    batch_size = 1
    num_layers = 2
    model = Model(lstm_input_size, hidden_dim, batch_size=batch_size,
                  time_steps=8, args=args, output_dim=output_dim, num_layers=num_layers)
    # Load the pre-trained classification checkpoint ONCE (the original loaded it
    # twice); it is reused both for the final fc weights (T-CAM prediction) and
    # to warm-start the visual branch.
    checkpoint = torch.load(args.t_cam)
    pretrained_dict = checkpoint['state_dict']
    fc_weight = pretrained_dict['fc3.weight']
    # Keep only parameters that exist in the visual sub-model, then load them.
    model_dict = model.visual_model.state_dict()
    valid_params = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(valid_params)
    model.visual_model.load_state_dict(model_dict)
    model.to(device)
    model.textual_model.to(device)
    model.visual_model.to(device)
    print('model created')
    # Store the fc weights for t-cam prediction (kept on CPU for temporal_proposals).
    args.weights = fc_weight.cpu()
    args.tao = 0.9
    args.lr = 0.01
    # Losses: ranking loss with margin 0.1, plus the project's adaptive margin loss.
    marginrankingloss = nn.MarginRankingLoss(0.1)
    adaptive_margin_loss = Adaptive_Margin_Loss()
    # Optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    writer = SummaryWriter()
    for epoch in range(args.max_iter):
        # Randomly extract 10 video clips' I3D feature
        features, labels = dataset.load_data()
        # Features are zero-padded to a fixed 750 frames; truncate the batch to
        # the longest non-padding sequence (a frame is padding iff all-zero).
        seq_len = np.sum(np.max(np.abs(features), axis=2) > 0, axis=1)
        features = features[:, :np.max(seq_len), :]
        # Convert to CUDA tensors
        features = torch.from_numpy(features).float().to(device)
        labels = torch.from_numpy(labels).float().to(device)
        # Generate texts from categories
        text_list, labels = one_label_text(class_name, labels)
        # T-CAM temporal proposals, used only to visualize against learned attention.
        clses = [[idx for idx, cls in enumerate(label) if cls == 1.] for label in labels]
        t_proposals = temporal_proposals(args.weights, features.detach().cpu(), clses)
        # Forward pass
        attention_weights, visual_feature, textual_feature, pos_feature, neg_feature, neg_features_mean \
            = model(features, text_list, None)
        # Ranking: positive clips should be closer to the text than the negatives.
        pos_distance = euclidean_distance(pos_feature, textual_feature, 1)
        neg_distance = euclidean_distance(neg_features_mean, textual_feature, 1)
        # target == -1 tells MarginRankingLoss the first input should be SMALLER.
        target = torch.zeros(visual_feature.shape[0]).cuda() - 1
        margin_loss = marginrankingloss(pos_distance, neg_distance, target)
        loss = adaptive_margin_loss(pos_feature, neg_feature, textual_feature, args.device) + 0.01 * margin_loss
        # Back Propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Every 3000 epochs (except epoch 0): decay lr/temperature and checkpoint.
        # Fix: original used `epoch is not 0` — identity comparison on an int,
        # which is implementation-dependent and a SyntaxWarning in Python 3.8+.
        if epoch % 3000 == 0 and epoch != 0:
            args.tao /= 3
            args.lr /= 1.5
            # Rebuild the optimizer with the decayed learning rate.
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
            model_state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            directory = args.checkpoint_path + checkpoint_model_name + "/"
            os.makedirs(directory, exist_ok=True)
            # Fix: join path components properly instead of concatenating.
            torch.save(model_state,
                       os.path.join(directory, 'epoch_{:03}.pth'.format(epoch)))
        # Log training loss
        loss_value = loss.detach().cpu().tolist()
        writer.add_scalar('runs/', loss_value, epoch)
        if epoch % 20 == 0:
            print('Epoch:{:03}, Loss: {:02}'.format(epoch, loss_value))
            # Plot learned attention (blue) against the T-CAM proposal (red).
            plt.plot(attention_weights[0].tolist(), c='b')
            plt.plot(t_proposals[0], c='r')
            directory = "./visualization/" + checkpoint_model_name + "/"
            os.makedirs(directory, exist_ok=True)
            plt.savefig(directory + str(epoch) + ".png")
            plt.clf()
            plt.close("all")
    writer.export_scalars_to_json(args.json)
    writer.close()
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()