Skip to content

Commit 3b3dd93

Browse files
committed
add unit tests
1 parent d3eed69 commit 3b3dd93

File tree

1 file changed

+193
-0
lines changed

1 file changed

+193
-0
lines changed
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
"""Tests for image_mapper module."""
2+
import pytest
3+
import numpy as np
4+
from ctapipe.instrument import CameraGeometry
5+
6+
from dl1_data_handler.image_mapper import (
7+
BilinearMapper,
8+
BicubicMapper,
9+
NearestNeighborMapper,
10+
RebinMapper,
11+
AxialMapper,
12+
OversamplingMapper,
13+
ShiftingMapper,
14+
SquareMapper,
15+
)
16+
17+
18+
@pytest.fixture
19+
def lstcam_geometry():
20+
"""Fixture to provide LSTCam geometry."""
21+
return CameraGeometry.from_name("LSTCam")
22+
23+
24+
@pytest.fixture
25+
def sample_image(lstcam_geometry):
26+
"""Fixture to provide a sample image for testing."""
27+
return np.random.rand(lstcam_geometry.n_pixels, 1).astype(np.float32)
28+
29+
30+
class TestInterpolationImageShape:
31+
"""Test that interpolation_image_shape parameter works correctly (issue #171)."""
32+
33+
@pytest.mark.parametrize(
34+
"mapper_class",
35+
[BilinearMapper, BicubicMapper, NearestNeighborMapper, RebinMapper],
36+
)
37+
def test_interpolation_image_shape_kwarg(self, lstcam_geometry, mapper_class):
38+
"""Test that interpolation_image_shape can be set via kwarg.
39+
40+
This is a regression test for issue #171 where passing
41+
interpolation_image_shape directly to mapper constructors
42+
was silently ignored.
43+
"""
44+
# Request a custom interpolation grid size
45+
custom_size = 138
46+
mapper = mapper_class(
47+
geometry=lstcam_geometry, interpolation_image_shape=custom_size
48+
)
49+
50+
# Verify the trait is set correctly
51+
assert (
52+
mapper.interpolation_image_shape == custom_size
53+
), f"{mapper_class.__name__}: interpolation_image_shape trait not set correctly"
54+
55+
# Verify the image_shape is updated
56+
assert (
57+
mapper.image_shape == custom_size
58+
), f"{mapper_class.__name__}: image_shape not updated to custom size"
59+
60+
# Verify the mapping table has the correct shape
61+
expected_mapping_cols = custom_size * custom_size
62+
assert (
63+
mapper.mapping_table.shape[1] == expected_mapping_cols
64+
), f"{mapper_class.__name__}: mapping_table shape incorrect"
65+
66+
@pytest.mark.parametrize(
67+
"mapper_class",
68+
[BilinearMapper, BicubicMapper, NearestNeighborMapper, RebinMapper],
69+
)
70+
def test_interpolation_image_shape_output(
71+
self, lstcam_geometry, sample_image, mapper_class
72+
):
73+
"""Test that the output image has the correct shape when interpolation_image_shape is set."""
74+
custom_size = 138
75+
mapper = mapper_class(
76+
geometry=lstcam_geometry, interpolation_image_shape=custom_size
77+
)
78+
79+
# Map the image
80+
mapped_image = mapper.map_image(sample_image)
81+
82+
# Verify output shape
83+
expected_shape = (custom_size, custom_size, 1)
84+
assert (
85+
mapped_image.shape == expected_shape
86+
), f"{mapper_class.__name__}: output shape incorrect. Expected {expected_shape}, got {mapped_image.shape}"
87+
88+
@pytest.mark.parametrize(
89+
"mapper_class",
90+
[BilinearMapper, BicubicMapper, NearestNeighborMapper, RebinMapper],
91+
)
92+
def test_default_image_shape(self, lstcam_geometry, mapper_class):
93+
"""Test that mappers use default image_shape when interpolation_image_shape is not set."""
94+
mapper = mapper_class(geometry=lstcam_geometry)
95+
96+
# Default for LSTCam should be 110
97+
default_size = 110
98+
assert (
99+
mapper.image_shape == default_size
100+
), f"{mapper_class.__name__}: default image_shape incorrect"
101+
assert (
102+
mapper.interpolation_image_shape is None
103+
), f"{mapper_class.__name__}: interpolation_image_shape should be None by default"
104+
105+
106+
class TestMapperBasicFunctionality:
107+
"""Test basic functionality of all mapper classes."""
108+
109+
@pytest.mark.parametrize(
110+
"mapper_class",
111+
[
112+
BilinearMapper,
113+
BicubicMapper,
114+
NearestNeighborMapper,
115+
RebinMapper,
116+
AxialMapper,
117+
OversamplingMapper,
118+
ShiftingMapper,
119+
],
120+
)
121+
def test_hexagonal_mapper_instantiation(self, lstcam_geometry, mapper_class):
122+
"""Test that hexagonal mappers can be instantiated."""
123+
mapper = mapper_class(geometry=lstcam_geometry)
124+
assert mapper is not None
125+
assert mapper.mapping_table is not None
126+
127+
def test_square_mapper_requires_square_pixels(self, lstcam_geometry):
128+
"""Test that SquareMapper raises error for non-square pixel cameras."""
129+
# LSTCam has hexagonal pixels, should raise ValueError
130+
with pytest.raises(ValueError, match="only available for square pixel cameras"):
131+
SquareMapper(geometry=lstcam_geometry)
132+
133+
@pytest.mark.parametrize(
134+
"mapper_class",
135+
[
136+
BilinearMapper,
137+
BicubicMapper,
138+
NearestNeighborMapper,
139+
RebinMapper,
140+
AxialMapper,
141+
OversamplingMapper,
142+
ShiftingMapper,
143+
],
144+
)
145+
def test_mapper_output_shape(self, lstcam_geometry, sample_image, mapper_class):
146+
"""Test that mappers produce correctly shaped output."""
147+
mapper = mapper_class(geometry=lstcam_geometry)
148+
mapped_image = mapper.map_image(sample_image)
149+
150+
# Output should be square image with 1 channel
151+
assert len(mapped_image.shape) == 3
152+
assert mapped_image.shape[0] == mapped_image.shape[1]
153+
assert mapped_image.shape[2] == 1
154+
155+
@pytest.mark.parametrize(
156+
"mapper_class",
157+
[
158+
BilinearMapper,
159+
BicubicMapper,
160+
NearestNeighborMapper,
161+
RebinMapper,
162+
AxialMapper,
163+
OversamplingMapper,
164+
ShiftingMapper,
165+
],
166+
)
167+
def test_mapper_multichannel(self, lstcam_geometry, mapper_class):
168+
"""Test that mappers work with multi-channel input."""
169+
# Create a 2-channel image
170+
multichannel_image = np.random.rand(lstcam_geometry.n_pixels, 2).astype(
171+
np.float32
172+
)
173+
mapper = mapper_class(geometry=lstcam_geometry)
174+
mapped_image = mapper.map_image(multichannel_image)
175+
176+
# Output should preserve the number of channels
177+
assert mapped_image.shape[2] == 2
178+
179+
180+
class TestAxialMapperSpecific:
181+
"""Test AxialMapper specific functionality."""
182+
183+
def test_set_index_matrix_false(self, lstcam_geometry):
184+
"""Test AxialMapper with set_index_matrix=False (default)."""
185+
mapper = AxialMapper(geometry=lstcam_geometry, set_index_matrix=False)
186+
assert mapper.index_matrix is None
187+
188+
def test_set_index_matrix_true(self, lstcam_geometry):
189+
"""Test AxialMapper with set_index_matrix=True."""
190+
mapper = AxialMapper(geometry=lstcam_geometry, set_index_matrix=True)
191+
assert mapper.index_matrix is not None
192+
# Index matrix should have the same shape as the output image
193+
assert mapper.index_matrix.shape == (mapper.image_shape, mapper.image_shape)

0 commit comments

Comments
 (0)