Skip to content

Commit 3286b10

Browse files
authored
Smartlab mstcn fix (#3430)
* segment.py fix * segmentor fix
1 parent 01171fc commit 3286b10

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

demos/smartlab_demo/python/segmentor.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(self, core, device, i3d_path, mstcn_path):
5757
net.reshape({net.inputs[0]: PartialShape(
5858
[self.EmbedBatchSize, self.EmbedWindowLength, self.ImgSizeHeight, self.ImgSizeWidth, 3])})
5959
nodes = net.get_ops()
60-
net.add_outputs(nodes[13].output(0))
60+
net.add_outputs(nodes[11].output(0))
6161
self.i3d = core.compile_model(model=net, device_name=device)
6262

6363
self.mstcn_net = core.read_model(mstcn_path)
@@ -124,8 +124,7 @@ def feature_embedding(self, img_buffer, embedding_buffer, frame_index):
124124
[cv2.resize(img_buffer[start_index + i * self.EmbedWindowAtrous],
125125
(self.ImgSizeHeight, self.ImgSizeWidth)) for i in range(self.EmbedWindowLength)]
126126
for j in range(self.EmbedBatchSize)]
127-
input_data = np.asarray(input_data).transpose((0, 4, 1, 2, 3))
128-
input_data = input_data * 127.5 + 127.5
127+
input_data = np.asarray(input_data) * 127.5 + 127.5
129128

130129
input_dict = {self.i3d.inputs[0]: input_data}
131130
out_logits = infer_request.infer(input_dict)[self.i3d.outputs[1]]
@@ -139,14 +138,15 @@ def feature_embedding(self, img_buffer, embedding_buffer, frame_index):
139138

140139
def action_segmentation(self):
141140
# read buffer
141+
batch_size = self.SegBatchSize
142142
embed_buffer_top = self.EmbedBufferTop
143143
embed_buffer_front = self.EmbedBufferFront
144-
batch_size = self.SegBatchSize
145144
start_index = self.TemporalLogits.shape[0]
146145
end_index = min(embed_buffer_top.shape[-1], embed_buffer_front.shape[-1])
147146
num_batch = (end_index - start_index) // batch_size
148147

149148
infer_request = self.reshape_mstcn.create_infer_request()
149+
150150
if num_batch < 0:
151151
log.debug("Waiting for the next frame ...")
152152
elif num_batch == 0:
@@ -163,23 +163,28 @@ def action_segmentation(self):
163163
string = list(key.names)[0]
164164
feed_dict[string] = self.his_fea[string]
165165
feed_dict['input'] = input_mstcn
166-
if input_mstcn.shape == (1, 2048, 1):
167-
out = infer_request.infer(feed_dict)
168166

167+
# if the input shape is an unsupported shape, add a dummy result of noise and skip inference
168+
if input_mstcn.shape != (1, 2048, 1):
169+
tranposed_logits = np.zeros((2048, 16))
170+
self.TemporalLogits = np.concatenate([self.TemporalLogits, tranposed_logits], axis=0)
171+
return None
172+
173+
out = infer_request.infer(feed_dict)
169174
predictions = out[list(out.keys())[-1]]
170175
for key in self.mstcn_output_key:
171176
if 'fhis_in_' in str(key.names):
172177
string = list(key.names)[0]
173178
self.his_fea[string] = out[string]
174179

175180
"""
176-
predictions --> 4x1x64x24
177-
his_fea --> [12*[1x64x2048], 11*[1x64x2048], 11*[1x64x2048], 11*[1x64x2048]]
181+
predictions --> 4x64x24
182+
his_fea --> [12*[64x2048], 11*[64x2048], 11*[64x2048], 11*[64x2048]]
178183
"""
179-
temporal_logits = predictions[:, :, :len(self.ActionTerms), :] # 4x1x16xN
180-
temporal_logits = softmax(temporal_logits[-1], 1) # 1x16xN
181-
temporal_logits = temporal_logits.transpose((0, 2, 1)).squeeze(axis=0)
182-
self.TemporalLogits = np.concatenate([self.TemporalLogits, temporal_logits], axis=0)
184+
temporal_logits = predictions[:, :len(self.ActionTerms), :] # Nx16x2048
185+
softmaxed_logits = softmax(temporal_logits[-1], 1) # 16x2048
186+
tranposed_logits = softmaxed_logits.transpose((1, 0)) # 2048x16
187+
self.TemporalLogits = np.concatenate([self.TemporalLogits, tranposed_logits], axis=0)
183188
else:
184189
for batch_idx in range(num_batch):
185190
unit1 = embed_buffer_top[:,
@@ -202,7 +207,7 @@ def action_segmentation(self):
202207
string = list(key.names)[0]
203208
self.his_fea[string] = out[string]
204209

205-
temporal_logits = predictions[:, :, :len(self.ActionTerms), :] # 4x1x16xN
206-
temporal_logits = softmax(temporal_logits[-1], 1) # 1x16xN
207-
temporal_logits = temporal_logits.transpose((0, 2, 1)).squeeze(axis=0)
210+
temporal_logits = predictions[:, :len(self.ActionTerms), :]
211+
softmaxed_logits = softmax(temporal_logits[-1], 1)
212+
tranposed_logits = softmaxed_logits.transpose((1, 0))
208213
self.TemporalLogits = np.concatenate([self.TemporalLogits, temporal_logits], axis=0)

0 commit comments

Comments
 (0)