diff --git a/aion/codecs/modules/subsampler.py b/aion/codecs/modules/subsampler.py
index af66f7a..ae9b615 100644
--- a/aion/codecs/modules/subsampler.py
+++ b/aion/codecs/modules/subsampler.py
@@ -32,7 +32,7 @@ def _subsample_in(self, x, labels: Bool[torch.Tensor, " b c"]):
 
         # Normalize
         label_sizes = labels.sum(dim=1, keepdim=True)
-        scales = ((self.dim_in / label_sizes) ** 0.5).squeeze()
+        scales = ((self.dim_in / label_sizes) ** 0.5).squeeze(-1)
 
         # Apply linear layer
         return scales[:, None, None, None] * F.linear(x, self.weight, self.bias)
diff --git a/tests/codecs/test_image_codec.py b/tests/codecs/test_image_codec.py
index f60d2d1..76fdb36 100644
--- a/tests/codecs/test_image_codec.py
+++ b/tests/codecs/test_image_codec.py
@@ -81,3 +81,26 @@ def test_hf_previous_predictions(data_dir):
         rtol=1e-3,
         atol=1e-4,
     )
+
+
+def test_batch_size_one():
+    """Regression test: ImageCodec must encode and decode a batch of size 1."""
+    codec = ImageCodec.from_pretrained(HF_REPO_ID, modality=Image)
+
+    # Single-element batch: the case the subsampler's bare squeeze() mishandled
+    batch_size = 1
+    flux_tensor = torch.randn(batch_size, 4, 96, 96)
+    input_image_obj = Image(
+        flux=flux_tensor,
+        bands=["DES-G", "DES-R", "DES-I", "DES-Z"],
+    )
+
+    # Must not raise: bare squeeze() used to drop the size-1 batch dim of the scales
+    with torch.no_grad():
+        encoded = codec.encode(input_image_obj)
+        decoded_image_obj = codec.decode(
+            encoded, bands=["DES-G", "DES-R", "DES-I", "DES-Z"]
+        )
+
+    assert isinstance(decoded_image_obj, Image)
+    assert decoded_image_obj.flux.shape == flux_tensor.shape