We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 54c7d01 commit f60e6c3Copy full SHA for f60e6c3
modalx_v2/models/gesture_stgcn.py
@@ -243,12 +243,14 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
243
intensity: (batch, 1) movement intensity
244
"""
245
batch_size = x.size(0)
246
-
247
- # Input batch norm - use reshape instead of view for non-contiguous tensors
248
- x = x.contiguous()
249
- x_flat = x.reshape(batch_size, -1, x.size(2))
250
- x_flat = self.input_bn(x_flat.permute(0, 2, 1)).permute(0, 2, 1)
251
- x = x_flat.reshape(batch_size, x.size(1), x.size(2), x.size(3))
+ C, T, V = x.size(1), x.size(2), x.size(3)
+
+ # Input batch norm: (B, C, T, V) -> (B*T, C*V) -> BN -> reshape back
+ x = x.permute(0, 2, 1, 3).contiguous() # (B, T, C, V)
+ x = x.view(batch_size * T, C * V) # (B*T, C*V=99)
+ x = self.input_bn(x) # BatchNorm1d(99)
252
+ x = x.view(batch_size, T, C, V) # (B, T, C, V)
253
+ x = x.permute(0, 2, 1, 3).contiguous() # Back to (B, C, T, V)
254
255
# ST-GCN layers
256
for layer in self.layers:
0 commit comments