Skip to content

Commit 166177d

Browse files
tests: add tests for fileio, image sim, & variables
1 parent 96f44a2 commit 166177d

File tree

3 files changed

+325
-1
lines changed

3 files changed

+325
-1
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
2+
import pytest
3+
import numpy as np
4+
import optiland.backend as be
5+
from optiland.analysis.image_simulation import ImageSimulationEngine, DistortionWarper, PSFBasisGenerator
6+
from optiland.samples.objectives import CookeTriplet
7+
8+
class TestImageSimulation:
9+
@pytest.fixture
10+
def optic(self):
11+
return CookeTriplet()
12+
13+
@pytest.fixture
14+
def source_image(self):
15+
# Create a small dummy image (green square in black background)
16+
img = np.zeros((32, 32, 3), dtype=np.float32)
17+
img[10:22, 10:22, 1] = 1.0
18+
return img
19+
20+
def test_engine_init(self, optic, source_image):
21+
engine = ImageSimulationEngine(optic, source_image)
22+
assert engine.source_image.shape == (3, 32, 32) # Transposed to (C, H, W)
23+
assert engine.config is not None
24+
25+
def test_engine_run(self, optic, source_image):
26+
# Use very loose config for speed
27+
config = {
28+
"psf_grid_shape": (3, 3),
29+
"psf_size": 32,
30+
"num_rays": 32,
31+
"n_components": 2,
32+
"oversample": 1,
33+
"wavelengths": [0.55] # Mono
34+
}
35+
engine = ImageSimulationEngine(optic, source_image, config=config)
36+
result = engine.run()
37+
38+
# Let's check typical RGB case
39+
config["wavelengths"] = [0.65, 0.55, 0.45]
40+
engine = ImageSimulationEngine(optic, source_image, config=config)
41+
result = engine.run()
42+
43+
assert result.shape == (32, 32, 3)
44+
assert be.max(result) > 0 # Should have some signal
45+
46+
def test_distortion_warper(self, optic):
47+
warper = DistortionWarper(optic)
48+
H, W = 32, 32
49+
dist_map = warper.generate_distortion_map(wavelength=0.55, image_shape=(H, W))
50+
assert dist_map.shape == (1, H, W, 2)
51+
52+
# Warp a dummy image
53+
img = be.ones((1, H, W))
54+
warped = warper.warp_image(img, dist_map)
55+
assert warped.shape == (1, H, W)
56+
57+
def test_psf_basis_generator(self, optic):
58+
gen = PSFBasisGenerator(
59+
optic,
60+
wavelength=0.55,
61+
grid_shape=(3, 3),
62+
num_rays=32,
63+
psf_grid_size=32
64+
)
65+
eigen_psfs, coeffs, mean_psf = gen.generate_basis(n_components=2)
66+
67+
assert eigen_psfs.shape == (2, 32, 32)
68+
assert coeffs.shape == (2, 3, 3)
69+
assert mean_psf.shape == (32, 32)
70+
71+
resized_coeffs = gen.resize_coefficient_map(coeffs, (64, 64))
72+
assert resized_coeffs.shape == (2, 64, 64)
73+

tests/test_fileio.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
ZemaxFileSourceHandler,
1313
load_zemax_file,
1414
)
15+
from optiland.fileio.converters import ZemaxToOpticConverter
1516
from optiland.materials import Material
1617
from optiland.optic import Optic
1718
from optiland.samples.objectives import HeliarLens
19+
from optiland.geometries import ToroidalGeometry
1820
import optiland.backend as be
1921
from .utils import assert_allclose
2022

@@ -288,3 +290,148 @@ def test_remove_surface_after_load(set_test_backend, tmp_path):
288290
expected_positions_after_object = be.array([0.0, 10.0, 30.0])
289291

290292
assert_allclose(positions[1:], expected_positions_after_object)
293+
294+
295+
class TestZemaxToOpticConverterExtended:
296+
def test_configure_aperture_floating_stop_no_diameter(self):
297+
zemax_data = {
298+
"surfaces": {
299+
0: {"type": "standard", "is_stop": True, "radius": 0.0, "conic": 0.0, "thickness": 0.0, "material": "Air"},
300+
},
301+
"aperture": {"floating_stop": True},
302+
"fields": {"type": "angle", "x": [0], "y": [0]},
303+
"wavelengths": {"primary_index": 0, "data": [0.55]},
304+
}
305+
converter = ZemaxToOpticConverter(zemax_data)
306+
converter.optic = Optic()
307+
converter._configure_surfaces()
308+
309+
with pytest.raises(ValueError, match="Floating stop aperture specified but no stop diameter found"):
310+
converter._configure_aperture()
311+
312+
def test_configure_aperture_no_valid_type(self):
313+
zemax_data = {
314+
"surfaces": {},
315+
"aperture": {"floating_stop": False},
316+
"fields": {"type": "angle", "x": [0], "y": [0]},
317+
"wavelengths": {"primary_index": 0, "data": [0.55]},
318+
}
319+
converter = ZemaxToOpticConverter(zemax_data)
320+
converter.optic = Optic()
321+
322+
with pytest.raises(ValueError, match="No valid aperture type found"):
323+
converter._configure_aperture()
324+
325+
def test_configure_surface_coefficients_unsupported_type(self):
326+
zemax_data = {
327+
"surfaces": {
328+
0: {"type": "unsupported_surface_type", "radius": 10},
329+
},
330+
"aperture": {"EPD": 10},
331+
"fields": {"type": "angle", "x": [0], "y": [0]},
332+
"wavelengths": {"primary_index": 0, "data": [0.55]},
333+
}
334+
converter = ZemaxToOpticConverter(zemax_data)
335+
336+
with pytest.raises(ValueError, match="Unsupported Zemax surface type"):
337+
converter._configure_surface_coefficients({"type": "unsupported_surface_type"})
338+
339+
def test_configure_fields_vignette_warning(self, capsys):
340+
zemax_data = {
341+
"surfaces": {},
342+
"aperture": {"EPD": 10},
343+
"fields": {
344+
"type": "angle",
345+
"x": [0],
346+
"y": [0],
347+
"vignette_decenter_x": [0.1],
348+
"vignette_decenter_y": [0.0]
349+
},
350+
"wavelengths": {"primary_index": 0, "data": [0.55]},
351+
}
352+
converter = ZemaxToOpticConverter(zemax_data)
353+
converter.optic = Optic()
354+
converter._configure_fields()
355+
356+
captured = capsys.readouterr()
357+
assert "Warning: Vignette decentering is not supported." in captured.out
358+
359+
def test_configure_surfaces_coordinate_break(self):
360+
zemax_data = {
361+
"surfaces": {
362+
0: {
363+
"type": "coordinate_break",
364+
"param_0": 1.0, # dx
365+
"param_1": 2.0, # dy
366+
"thickness": 5.0, # dz (thickness)
367+
"param_2": 10.0, # rx deg
368+
"param_3": 20.0, # ry deg
369+
"param_4": 30.0, # rz deg
370+
"conic": 0.0,
371+
},
372+
1: {
373+
"type": "standard",
374+
"radius": 100.0,
375+
"thickness": 10.0,
376+
"conic": 0.0,
377+
"material": "N-BK7"
378+
}
379+
},
380+
"aperture": {"EPD": 10},
381+
"fields": {"type": "angle", "x": [0], "y": [0]},
382+
"wavelengths": {"primary_index": 0, "data": [0.55]},
383+
}
384+
converter = ZemaxToOpticConverter(zemax_data)
385+
optic = converter.convert()
386+
387+
surf = optic.surface_group.surfaces[0]
388+
assert surf.geometry.radius == 100.0
389+
390+
cs = surf.geometry.cs
391+
assert cs.x != 0 or cs.y != 0 or cs.z != 0 or cs.rx != 0 or cs.ry != 0 or cs.rz != 0
392+
393+
def test_configure_surfaces_toroidal(self):
394+
zemax_data = {
395+
"surfaces": {
396+
0: {
397+
"type": "toroidal",
398+
"radius": 50.0, # radius_y
399+
"param_1": 60.0, # radius_x
400+
"param_2": 0.1, # coeff start
401+
"thickness": 5.0,
402+
"conic": 0.0,
403+
"material": "Air"
404+
}
405+
},
406+
"aperture": {"EPD": 10},
407+
"fields": {"type": "angle", "x": [0], "y": [0]},
408+
"wavelengths": {"primary_index": 0, "data": [0.55]},
409+
}
410+
converter = ZemaxToOpticConverter(zemax_data)
411+
optic = converter.convert()
412+
413+
surf = optic.surface_group.surfaces[0]
414+
assert isinstance(surf.geometry, ToroidalGeometry)
415+
assert surf.geometry.R_yz == 50.0
416+
assert surf.geometry.R_rot == 60.0
417+
418+
def test_configure_surfaces_infinity_thickness(self):
419+
zemax_data = {
420+
"surfaces": {
421+
0: {
422+
"type": "standard",
423+
"radius": be.inf,
424+
"thickness": be.inf,
425+
"conic": 0.0,
426+
"material": "Air"
427+
}
428+
},
429+
"aperture": {"EPD": 10},
430+
"fields": {"type": "angle", "x": [0], "y": [0]},
431+
"wavelengths": {"primary_index": 0, "data": [0.55]},
432+
}
433+
converter = ZemaxToOpticConverter(zemax_data)
434+
optic = converter.convert()
435+
436+
surf = optic.surface_group.surfaces[0]
437+
assert be.isinf(surf.thickness)

tests/test_variable.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import optiland.backend as be
22
import pytest
33
from unittest.mock import patch
4+
import numpy as np
45

6+
from optiland.optic import Optic
57
from optiland.coordinate_system import CoordinateSystem
68
from optiland.geometries import (
79
ChebyshevPolynomialGeometry,
@@ -17,6 +19,13 @@
1719
from optiland.materials.abbe import AbbeMaterial
1820
from optiland.optimization.variable.material import MaterialVariable
1921
from optiland.optimization.scaling.identity import IdentityScaler
22+
from optiland.optimization.variable import (
23+
DecenterVariable,
24+
TiltVariable,
25+
AsphereCoeffVariable,
26+
ZernikeCoeffVariable,
27+
Variable
28+
)
2029
from .utils import assert_allclose
2130

2231

@@ -703,4 +712,99 @@ def test_delitem(self, set_test_backend):
703712
var_manager = variable.VariableManager()
704713
var_manager.add(optic, "radius", surface_number=1)
705714
del var_manager[0]
706-
assert len(var_manager) == 0
715+
assert len(var_manager) == 0
716+
717+
718+
class TestVariablesExtended:
719+
@pytest.fixture
720+
def optic(self):
721+
optic = Optic()
722+
optic.add_surface(index=0, surface_type="plane")
723+
optic.add_surface(index=1, surface_type="standard", radius=100)
724+
optic.add_surface(index=2, surface_type="even_asphere", radius=50, coefficients=[0.0, 0.1])
725+
optic.add_surface(index=3, surface_type="zernike", radius=200, coefficients=[0.0]*5)
726+
return optic
727+
728+
def test_decenter_variable(self, optic):
729+
# Surface 1 is standard
730+
var = DecenterVariable(optic, surface_number=1, axis="x")
731+
assert var.get_value() == 0.0
732+
733+
var.update_value(0.5)
734+
assert optic.surface_group.surfaces[1].geometry.cs.x == 0.5
735+
assert var.get_value() == 0.5
736+
737+
# Test y and z
738+
var_y = DecenterVariable(optic, surface_number=1, axis="y")
739+
var_y.update_value(-0.2)
740+
assert optic.surface_group.surfaces[1].geometry.cs.y == -0.2
741+
742+
var_z = DecenterVariable(optic, surface_number=1, axis="z")
743+
var_z.update_value(1.0)
744+
assert optic.surface_group.surfaces[1].geometry.cs.z == 1.0
745+
746+
def test_decenter_variable_invalid_axis(self, optic):
747+
with pytest.raises(ValueError, match='Invalid axis "r"'):
748+
DecenterVariable(optic, surface_number=1, axis="r")
749+
750+
var = DecenterVariable(optic, surface_number=1, axis="x")
751+
var.axis = "r"
752+
with pytest.raises(ValueError, match='Invalid axis "r"'):
753+
var.get_value()
754+
with pytest.raises(ValueError, match='Invalid axis "r"'):
755+
var.update_value(1.0)
756+
757+
def test_tilt_variable(self, optic):
758+
var = TiltVariable(optic, surface_number=1, axis="x")
759+
var.update_value(0.1) # radians
760+
assert optic.surface_group.surfaces[1].geometry.cs.rx == 0.1
761+
762+
var_y = TiltVariable(optic, surface_number=1, axis="y")
763+
var_y.update_value(0.2)
764+
assert optic.surface_group.surfaces[1].geometry.cs.ry == 0.2
765+
766+
767+
def test_tilt_variable_invalid_axis(self, optic):
768+
with pytest.raises(ValueError, match='Invalid axis "r"'):
769+
TiltVariable(optic, surface_number=1, axis="r")
770+
771+
def test_asphere_coeff_variable_get_padding(self, optic):
772+
# Surface 2 is even_asphere with 2 coeffs
773+
var = AsphereCoeffVariable(optic, surface_number=2, coeff_number=5)
774+
# Should return 0 and pad
775+
val = var.get_value()
776+
assert val == 0.0
777+
# Check padding happened
778+
assert len(optic.surface_group.surfaces[2].geometry.coefficients) >= 6
779+
780+
def test_asphere_coeff_variable_update(self, optic):
781+
var = AsphereCoeffVariable(optic, surface_number=2, coeff_number=1)
782+
var.update_value(0.5)
783+
assert optic.surface_group.surfaces[2].geometry.coefficients[1] == 0.5
784+
785+
def test_zernike_coeff_variable_padding(self, optic):
786+
# Surface 3 is zernike with 5 coeffs (indices 0-4)
787+
var = ZernikeCoeffVariable(optic, surface_number=3, coeff_index=10)
788+
789+
# Test update padding
790+
var.update_value(1.5)
791+
coeffs = optic.surface_group.surfaces[3].geometry.coefficients
792+
assert len(coeffs) >= 11
793+
assert coeffs[10] == 1.5
794+
795+
# Test get padding
796+
var2 = ZernikeCoeffVariable(optic, surface_number=3, coeff_index=15)
797+
val = var2.get_value()
798+
assert val == 0.0
799+
assert len(optic.surface_group.surfaces[3].geometry.coefficients) >= 16
800+
801+
def test_variable_wrapper(self, optic):
802+
# Test generic Variable wrapper
803+
v = Variable(optic, type_name="decenter", surface_number=1, axis="x")
804+
v.update(1.0)
805+
assert v.value == 1.0
806+
assert optic.surface_group.surfaces[1].geometry.cs.x == 1.0
807+
808+
def test_variable_wrapper_invalid_type(self, optic):
809+
with pytest.raises(ValueError, match='Invalid variable type "invalid"'):
810+
Variable(optic, type_name="invalid")

0 commit comments

Comments
 (0)