In my forward pass I have:
y = mask * x
# mask.shape: [B, 1, H, W]
# x.shape: [B, 32, H, W]
When I try to prune, the dependency builder forces the channel count of mask (1 channel) to match the channel count of x (32 channels), even though broadcasting allows the two to differ along that dimension.
This couples the two branches unintentionally and prevents me from pruning only the x path. I believe the dependency analysis for element-wise multiplication should treat a broadcast channel dimension (size 1) as independent instead of forcing channel equality.
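
To make the behaviour I expect concrete, here is a minimal sketch in plain Python. It is not the library's code: `coupled_operands` is a hypothetical helper I made up for illustration, and it assumes all operands have the same rank with channels on dim 1. It only shows the rule I have in mind: an operand whose channel size is 1 is broadcast, so it should not be tied to the output channels.

```python
from typing import List, Sequence


def coupled_operands(shapes: Sequence[Sequence[int]], channel_dim: int = 1) -> List[int]:
    """Return the indices of operands whose channels must be pruned with the output.

    Hypothetical helper for illustration only, not part of any pruning library.
    Assumes every shape has the same rank and channels live on `channel_dim`.
    """
    channels = [shape[channel_dim] for shape in shapes]
    out_channels = max(channels)  # channel count of the broadcast result

    coupled = []
    for idx, c in enumerate(channels):
        if c == out_channels:
            # Same channel count as the output: pruning the output channels
            # requires pruning this operand too.
            coupled.append(idx)
        elif c != 1:
            # Neither equal nor broadcastable: the multiplication itself is invalid.
            raise ValueError(f"operand {idx} has {c} channels, expected 1 or {out_channels}")
        # c == 1 with out_channels > 1: broadcast operand, left uncoupled.
    return coupled


mask_shape = (8, 1, 16, 16)   # [B, 1, H, W]
x_shape = (8, 32, 16, 16)     # [B, 32, H, W]
print(coupled_operands([mask_shape, x_shape]))  # [1] -> only x follows the output
```

With a rule like this, pruning channels of y would propagate to x only, and the mask branch would stay untouched.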
Any ideas or suggested workarounds?
Thanks