1+ # Test for optical elements
2+ import os
3+ import sys
4+
5+ # Setting the path for XLuminA modules:
6+ current_path = os .path .abspath (os .path .join ('..' ))
7+ module_path = os .path .join (current_path )
8+
9+ if module_path not in sys .path :
10+ sys .path .append (module_path )
11+
12+ import unittest
13+ import jax .numpy as jnp
14+ import numpy as np
15+ from jax import random
16+ from xlumina .toolbox import (
17+ space , wrap_phase , is_conserving_energy , softmin , delta_kronecker ,
18+ build_LCD_cell , rotate_mask , nearest ,
19+ extract_profile , gaussian , lorentzian , fwhm_1d_fit , spot_size ,
20+ compute_fwhm , find_max_min , gaussian_2d
21+ )
22+ from xlumina .vectorized_optics import VectorizedLight , PolarizedLightSource
23+
24+ class TestToolbox (unittest .TestCase ):
25+ def setUp (self ):
26+ seed = 9999
27+ self .key = random .PRNGKey (seed )
28+ self .resolution = 512
29+ self .x = np .linspace (- 1500 , 1500 , self .resolution )
30+ self .y = np .linspace (- 1500 , 1500 , self .resolution )
31+ self .wavelength = 633e-3
32+
33+ def test_space (self ):
34+ x , y = space (1500 , self .resolution )
35+ self .assertTrue (jnp .allclose (x , self .x ))
36+ self .assertTrue (jnp .allclose (y , self .y ))
37+
38+ def test_wrap_phase (self ):
39+ phase = jnp .array ([0 , jnp .pi , 2 * jnp .pi , 3 * jnp .pi , - 3 * jnp .pi ])
40+ wrapped = wrap_phase (phase )
41+ self .assertTrue (jnp .allclose (wrapped , jnp .array ([0 , jnp .pi , 0 , jnp .pi , - jnp .pi ])))
42+
43+ def test_is_conserving_energy (self ):
44+ light1 = VectorizedLight (self .x , self .y , self .wavelength )
45+ light2 = VectorizedLight (self .x , self .y , self .wavelength )
46+ light1 .Ex = jnp .ones ((self .resolution , self .resolution ))
47+ light2 .Ex = jnp .ones ((self .resolution , self .resolution ))
48+ conservation = is_conserving_energy (light1 , light2 )
49+ self .assertAlmostEqual (conservation , 1.0 , places = 6 )
50+ light2 .Ex = 0 * light2 .Ex
51+ conservation = is_conserving_energy (light1 , light2 )
52+ self .assertEqual (conservation , 0 )
53+
54+ def test_softmin (self ):
55+ result = softmin (jnp .array ([1.0 , 2.0 , 3.0 ]))
56+ self .assertTrue (result == 1.0 )
57+
58+ def test_delta_kronecker (self ):
59+ self .assertEqual (delta_kronecker (1 , 1 ), 1 )
60+ self .assertEqual (delta_kronecker (1 , 2 ), 0 )
61+
62+ def test_build_LCD_cell (self ):
63+ eta , theta = build_LCD_cell (jnp .pi / 2 , jnp .pi / 4 , self .resolution )
64+ self .assertTrue (jnp .allclose (eta , jnp .pi / 2 * jnp .ones ((self .resolution , self .resolution ))))
65+ self .assertTrue (jnp .allclose (theta , jnp .pi / 4 * jnp .ones ((self .resolution , self .resolution ))))
66+
67+ def test_rotate_mask (self ):
68+ X , Y = jnp .meshgrid (self .x , self .y )
69+ Xrot , Yrot = rotate_mask (X , Y , jnp .pi / 4 )
70+ self .assertEqual (Xrot .shape , (self .resolution , self .resolution ))
71+ self .assertEqual (Yrot .shape , (self .resolution , self .resolution ))
72+
73+ def test_nearest (self ):
74+ array = jnp .array ([1 , 2 , 3 , 4 , 5 ])
75+ idx , value , distance = nearest (array , 3.7 )
76+ self .assertEqual (idx , 3 )
77+ self .assertEqual (value , 4 )
78+ self .assertAlmostEqual (distance , 0.3 , places = 6 )
79+
80+ def test_extract_profile (self ):
81+ data_2d = jnp .ones ((10 , 10 ))
82+ x_points = jnp .array ([0 , 1 , 2 ])
83+ y_points = jnp .array ([0 , 1 , 2 ])
84+ profile = extract_profile (data_2d , x_points , y_points )
85+ self .assertEqual (profile .shape , x_points .shape )
86+
87+ def test_gaussian (self ):
88+ y = gaussian (self .x , 1 , 0 , 1 )
89+ self .assertEqual (y .shape , self .x .shape )
90+
91+ def test_lorentzian (self ):
92+ y = lorentzian (self .x , 0 , 1 )
93+ self .assertEqual (y .shape , self .x .shape )
94+
95+ def test_fwhm_1d_fit (self ):
96+ sigma = 120
97+ amplitude = 1000
98+ mean = 0
99+ y = gaussian (self .x , amplitude , mean , sigma )
100+ _ , fwhm , _ = fwhm_1d_fit (self .x , y , fit = 'gaussian' )
101+ fwhm_theoretical = 2 * sigma * jnp .sqrt (2 * jnp .log (2 )) # 2*sigma*sqrt(2*ln2) is the theoretical FWHM for a gaussian
102+ self .assertAlmostEqual (fwhm , fwhm_theoretical , places = 2 )
103+
104+ def test_spot_size (self ):
105+ size = spot_size (1 , 1 , self .wavelength )
106+ self .assertGreater (size , 0 )
107+
108+ def test_compute_fwhm (self ):
109+ sigma = 120
110+ light_1d = gaussian (self .x , 1000 , 0 , sigma )
111+ XY = jnp .meshgrid (self .x , self .y )
112+ light_2d = gaussian_2d (XY , 1000 , 0 , 0 , sigma , sigma )
113+
114+ popt , fwhm , r_squared = compute_fwhm (light_1d , [self .x , self .y ], field = 'Intensity' , fit = 'gaussian' , dimension = '1D' )
115+ fwhm_theoretical = 2 * sigma * jnp .sqrt (2 * jnp .log (2 ))
116+ self .assertAlmostEqual (fwhm , fwhm_theoretical , places = 4 )
117+
118+ popt , fwhm , r_squared = compute_fwhm (light_2d , [self .x , self .y ], field = 'Intensity' , fit = 'gaussian' , dimension = '2D' )
119+ fwhm_x , fwhm_y = fwhm
120+ fwhm_theoretical = 2 * sigma * jnp .sqrt (2 * jnp .log (2 ))
121+ self .assertAlmostEqual (fwhm_x , fwhm_theoretical , places = 4 )
122+ self .assertAlmostEqual (fwhm_y , fwhm_theoretical , places = 4 )
123+
124+ def test_find_max_min (self ):
125+ value = jnp .array ([[1 , 2 ], [3 , 4 ]])
126+ idx , xy , ext_value = find_max_min (value , self .x [:2 ], self .y [:2 ], kind = 'max' )
127+ self .assertEqual (idx .shape , (1 , 2 ))
128+ self .assertEqual (xy .shape , (1 , 2 ))
129+ self .assertEqual (ext_value , 4 )
130+ idx , xy , ext_value = find_max_min (value , self .x [:2 ], self .y [:2 ], kind = 'min' )
131+ self .assertEqual (ext_value , 1 )
132+
133+ if __name__ == '__main__' :
134+ unittest .main ()
0 commit comments