Skip to content

Commit c5bb7b4

Browse files
test for toolbox module
1 parent 443bc3d commit c5bb7b4

File tree

1 file changed

+134
-0
lines changed

1 file changed

+134
-0
lines changed

tests/test_toolbox.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Test for optical elements
2+
import os
3+
import sys
4+
5+
# Setting the path for XLuminA modules:
6+
current_path = os.path.abspath(os.path.join('..'))
7+
module_path = os.path.join(current_path)
8+
9+
if module_path not in sys.path:
10+
sys.path.append(module_path)
11+
12+
import unittest
13+
import jax.numpy as jnp
14+
import numpy as np
15+
from jax import random
16+
from xlumina.toolbox import (
17+
space, wrap_phase, is_conserving_energy, softmin, delta_kronecker,
18+
build_LCD_cell, rotate_mask, nearest,
19+
extract_profile, gaussian, lorentzian, fwhm_1d_fit, spot_size,
20+
compute_fwhm, find_max_min, gaussian_2d
21+
)
22+
from xlumina.vectorized_optics import VectorizedLight, PolarizedLightSource
23+
24+
class TestToolbox(unittest.TestCase):
25+
def setUp(self):
26+
seed = 9999
27+
self.key = random.PRNGKey(seed)
28+
self.resolution = 512
29+
self.x = np.linspace(-1500, 1500, self.resolution)
30+
self.y = np.linspace(-1500, 1500, self.resolution)
31+
self.wavelength = 633e-3
32+
33+
def test_space(self):
34+
x, y = space(1500, self.resolution)
35+
self.assertTrue(jnp.allclose(x, self.x))
36+
self.assertTrue(jnp.allclose(y, self.y))
37+
38+
def test_wrap_phase(self):
39+
phase = jnp.array([0, jnp.pi, 2*jnp.pi, 3*jnp.pi, -3*jnp.pi])
40+
wrapped = wrap_phase(phase)
41+
self.assertTrue(jnp.allclose(wrapped, jnp.array([0, jnp.pi, 0, jnp.pi, -jnp.pi])))
42+
43+
def test_is_conserving_energy(self):
44+
light1 = VectorizedLight(self.x, self.y, self.wavelength)
45+
light2 = VectorizedLight(self.x, self.y, self.wavelength)
46+
light1.Ex = jnp.ones((self.resolution, self.resolution))
47+
light2.Ex = jnp.ones((self.resolution, self.resolution))
48+
conservation = is_conserving_energy(light1, light2)
49+
self.assertAlmostEqual(conservation, 1.0, places=6)
50+
light2.Ex = 0*light2.Ex
51+
conservation = is_conserving_energy(light1, light2)
52+
self.assertEqual(conservation, 0)
53+
54+
def test_softmin(self):
55+
result = softmin(jnp.array([1.0, 2.0, 3.0]))
56+
self.assertTrue(result == 1.0)
57+
58+
def test_delta_kronecker(self):
59+
self.assertEqual(delta_kronecker(1, 1), 1)
60+
self.assertEqual(delta_kronecker(1, 2), 0)
61+
62+
def test_build_LCD_cell(self):
63+
eta, theta = build_LCD_cell(jnp.pi/2, jnp.pi/4, self.resolution)
64+
self.assertTrue(jnp.allclose(eta, jnp.pi/2 * jnp.ones((self.resolution, self.resolution))))
65+
self.assertTrue(jnp.allclose(theta, jnp.pi/4 * jnp.ones((self.resolution, self.resolution))))
66+
67+
def test_rotate_mask(self):
68+
X, Y = jnp.meshgrid(self.x, self.y)
69+
Xrot, Yrot = rotate_mask(X, Y, jnp.pi/4)
70+
self.assertEqual(Xrot.shape, (self.resolution, self.resolution))
71+
self.assertEqual(Yrot.shape, (self.resolution, self.resolution))
72+
73+
def test_nearest(self):
74+
array = jnp.array([1, 2, 3, 4, 5])
75+
idx, value, distance = nearest(array, 3.7)
76+
self.assertEqual(idx, 3)
77+
self.assertEqual(value, 4)
78+
self.assertAlmostEqual(distance, 0.3, places=6)
79+
80+
def test_extract_profile(self):
81+
data_2d = jnp.ones((10, 10))
82+
x_points = jnp.array([0, 1, 2])
83+
y_points = jnp.array([0, 1, 2])
84+
profile = extract_profile(data_2d, x_points, y_points)
85+
self.assertEqual(profile.shape, x_points.shape)
86+
87+
def test_gaussian(self):
88+
y = gaussian(self.x, 1, 0, 1)
89+
self.assertEqual(y.shape, self.x.shape)
90+
91+
def test_lorentzian(self):
92+
y = lorentzian(self.x, 0, 1)
93+
self.assertEqual(y.shape, self.x.shape)
94+
95+
def test_fwhm_1d_fit(self):
96+
sigma = 120
97+
amplitude = 1000
98+
mean = 0
99+
y = gaussian(self.x, amplitude, mean, sigma)
100+
_, fwhm, _ = fwhm_1d_fit(self.x, y, fit='gaussian')
101+
fwhm_theoretical = 2*sigma*jnp.sqrt(2*jnp.log(2)) # 2*sigma*sqrt(2*ln2) is the theoretical FWHM for a gaussian
102+
self.assertAlmostEqual(fwhm, fwhm_theoretical, places=2)
103+
104+
def test_spot_size(self):
105+
size = spot_size(1, 1, self.wavelength)
106+
self.assertGreater(size, 0)
107+
108+
def test_compute_fwhm(self):
109+
sigma = 120
110+
light_1d = gaussian(self.x, 1000, 0, sigma)
111+
XY = jnp.meshgrid(self.x, self.y)
112+
light_2d = gaussian_2d(XY, 1000, 0, 0, sigma, sigma)
113+
114+
popt, fwhm, r_squared = compute_fwhm(light_1d, [self.x, self.y], field='Intensity', fit = 'gaussian', dimension='1D')
115+
fwhm_theoretical = 2*sigma*jnp.sqrt(2*jnp.log(2))
116+
self.assertAlmostEqual(fwhm, fwhm_theoretical, places=4)
117+
118+
popt, fwhm, r_squared = compute_fwhm(light_2d, [self.x, self.y], field='Intensity', fit = 'gaussian', dimension='2D')
119+
fwhm_x, fwhm_y = fwhm
120+
fwhm_theoretical = 2*sigma*jnp.sqrt(2*jnp.log(2))
121+
self.assertAlmostEqual(fwhm_x, fwhm_theoretical, places=4)
122+
self.assertAlmostEqual(fwhm_y, fwhm_theoretical, places=4)
123+
124+
def test_find_max_min(self):
125+
value = jnp.array([[1, 2], [3, 4]])
126+
idx, xy, ext_value = find_max_min(value, self.x[:2], self.y[:2], kind='max')
127+
self.assertEqual(idx.shape, (1, 2))
128+
self.assertEqual(xy.shape, (1, 2))
129+
self.assertEqual(ext_value, 4)
130+
idx, xy, ext_value = find_max_min(value, self.x[:2], self.y[:2], kind='min')
131+
self.assertEqual(ext_value, 1)
132+
133+
if __name__ == '__main__':
134+
unittest.main()

0 commit comments

Comments
 (0)