Skip to content

Commit f60e6c3

Browse files
Fix gesture batch norm dimensions correctly
1 parent 54c7d01 commit f60e6c3

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

modalx_v2/models/gesture_stgcn.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,14 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
243243
intensity: (batch, 1) movement intensity
244244
"""
245245
batch_size = x.size(0)
246-
247-
# Input batch norm - use reshape instead of view for non-contiguous tensors
248-
x = x.contiguous()
249-
x_flat = x.reshape(batch_size, -1, x.size(2))
250-
x_flat = self.input_bn(x_flat.permute(0, 2, 1)).permute(0, 2, 1)
251-
x = x_flat.reshape(batch_size, x.size(1), x.size(2), x.size(3))
246+
C, T, V = x.size(1), x.size(2), x.size(3)
247+
248+
# Input batch norm: (B, C, T, V) -> (B*T, C*V) -> BN -> reshape back
249+
x = x.permute(0, 2, 1, 3).contiguous() # (B, T, C, V)
250+
x = x.view(batch_size * T, C * V) # (B*T, C*V=99)
251+
x = self.input_bn(x) # BatchNorm1d(99)
252+
x = x.view(batch_size, T, C, V) # (B, T, C, V)
253+
x = x.permute(0, 2, 1, 3).contiguous() # Back to (B, C, T, V)
252254

253255
# ST-GCN layers
254256
for layer in self.layers:

0 commit comments

Comments
 (0)