Skip to content

Commit 777a470

Browse files
committed
add feature extractor
1 parent aa4ccf7 commit 777a470

File tree

4 files changed

+92
-9
lines changed

4 files changed

+92
-9
lines changed

SyncNetInstance.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class SyncNetInstance(torch.nn.Module):
3636
def __init__(self, dropout = 0, num_layers_in_fc_layers = 1024):
3737
super(SyncNetInstance, self).__init__();
3838

39-
self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers);
39+
self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers).cuda();
4040

4141
def evaluate(self, opt, videofile):
4242

@@ -139,6 +139,57 @@ def evaluate(self, opt, videofile):
139139
dists_npy = numpy.array([ dist.numpy() for dist in dists ])
140140
return offset.numpy(), conf.numpy(), dists_npy
141141

142+
def extract_feature(self, opt, videofile):
143+
144+
self.__S__.eval();
145+
146+
# ========== ==========
147+
# Load video
148+
# ========== ==========
149+
cap = cv2.VideoCapture(videofile)
150+
151+
frame_num = 1;
152+
images = []
153+
while frame_num:
154+
frame_num += 1
155+
ret, image = cap.read()
156+
if ret == 0:
157+
break
158+
159+
image_np = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
160+
images.append(image_np)
161+
162+
im = numpy.stack(images,axis=3)
163+
im = numpy.expand_dims(im,axis=0)
164+
im = numpy.transpose(im,(0,3,4,1,2))
165+
166+
imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
167+
168+
# ========== ==========
169+
# Generate video feats
170+
# ========== ==========
171+
172+
lastframe = len(images)-4
173+
im_feat = []
174+
175+
tS = time.time()
176+
for i in range(0,lastframe,opt.batch_size):
177+
178+
im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
179+
im_in = torch.cat(im_batch,0)
180+
im_out = self.__S__.forward_lip(im_in.cuda());
181+
im_feat.append(im_out.data.cpu())
182+
183+
im_feat = torch.cat(im_feat,0)
184+
185+
# ========== ==========
186+
# Compute offset
187+
# ========== ==========
188+
189+
print('Compute time %.3f sec.' % (time.time()-tS))
190+
191+
return im_feat
192+
142193

143194
def loadParameters(self, path):
144195
loaded_state = torch.load(path, map_location=lambda storage, loc: storage);

SyncNetModel.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,6 @@ def __init__(self, num_layers_in_fc_layers = 1024):
8989
nn.BatchNorm3d(2048),
9090
nn.ReLU(inplace=True),
9191
);
92-
93-
self.netcnnaud = self.netcnnaud.cuda();
94-
self.netcnnlip = self.netcnnlip.cuda();
95-
self.netfcaud = self.netfcaud.cuda();
96-
self.netfclip = self.netfclip.cuda();
97-
9892

9993
def forward_aud(self, x):
10094

@@ -110,4 +104,10 @@ def forward_lip(self, x):
110104
mid = mid.view((mid.size()[0], -1)); # N x (ch x 24)
111105
out = self.netfclip(mid);
112106

107+
return out;
108+
109+
def forward_lipfeat(self, x):
110+
111+
out = self.netcnnlip(x);
112+
113113
return out;

demo_feature.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#!/usr/bin/python
2+
#-*- coding: utf-8 -*-
3+
4+
import time, pdb, argparse, subprocess
5+
6+
from SyncNetInstance import *
7+
8+
# ==================== LOAD PARAMS ====================
9+
10+
11+
parser = argparse.ArgumentParser(description = "SyncNet");
12+
13+
parser.add_argument('--initial_model', type=str, default="data/syncnetl2.model", help='');
14+
parser.add_argument('--batch_size', type=int, default='20', help='');
15+
parser.add_argument('--vshift', type=int, default='15', help='');
16+
parser.add_argument('--videofile', type=str, default="data/example.avi", help='');
17+
parser.add_argument('--tmp_dir', type=str, default="data", help='');
18+
parser.add_argument('--save_as', type=str, default="data/features.pt", help='');
19+
20+
opt = parser.parse_args();
21+
22+
23+
# ==================== RUN EVALUATION ====================
24+
25+
s = SyncNetInstance();
26+
27+
s.loadParameters(opt.initial_model);
28+
print("Model %s loaded."%opt.initial_model);
29+
30+
feats = s.extract_feature(opt, videofile=opt.videofile)
31+
32+
torch.save(feats, opt.save_as)

demo_syncnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
parser.add_argument('--initial_model', type=str, default="data/syncnetl2.model", help='');
1414
parser.add_argument('--batch_size', type=int, default='20', help='');
1515
parser.add_argument('--vshift', type=int, default='15', help='');
16-
parser.add_argument('--videofile', type=str, default="", help='');
17-
parser.add_argument('--tmp_dir', type=str, default="~", help='');
16+
parser.add_argument('--videofile', type=str, default="data/example.avi", help='');
17+
parser.add_argument('--tmp_dir', type=str, default="data", help='');
1818

1919
opt = parser.parse_args();
2020

0 commit comments

Comments
 (0)