1+ # Test for optical elements
2+ import os
3+ import sys
4+
5+ # Setting the path for XLuminA modules:
6+ current_path = os .path .abspath (os .path .join ('..' ))
7+ module_path = os .path .join (current_path )
8+
9+ if module_path not in sys .path :
10+ sys .path .append (module_path )
11+
12+ import unittest
13+ import jax .numpy as jnp
14+ import numpy as np
15+ from xlumina .optical_elements import (
16+ SLM , jones_LP , jones_general_retarder , jones_sSLM , jones_LCD ,
17+ sSLM , LCD , linear_polarizer , BS_symmetric , high_NA_objective_lens ,
18+ VCZT_objective_lens , lens , cylindrical_lens , axicon_lens , building_block
19+ )
20+
21+ from xlumina .vectorized_optics import VectorizedLight , PolarizedLightSource
22+ from xlumina .wave_optics import LightSource
23+
24+ class TestOpticalElements (unittest .TestCase ):
25+ def setUp (self ):
26+ self .wavelength = 633e-3 #nm
27+ self .resolution = 512
28+ self .x = np .linspace (- 1500 , 1500 , self .resolution )
29+ self .y = np .linspace (- 1500 , 1500 , self .resolution )
30+ self .k = 2 * jnp .pi / self .wavelength
31+
32+ def test_slm (self ):
33+ light = LightSource (self .x , self .y , self .wavelength )
34+ light .gaussian_beam (w0 = (1200 , 1200 ), E0 = 1 )
35+ phase = jnp .zeros ((self .resolution , self .resolution ))
36+ slm_output , _ = SLM (light , phase , self .resolution )
37+ self .assertEqual (slm_output .field .shape , (self .resolution , self .resolution )) # Check output shape == input shape
38+ self .assertTrue (jnp .allclose (slm_output .field , light .field )) # Phase added by SLM is 0, field shouldn't change.
39+
40+ def test_shape_jones_matrices (self ):
41+ lp = jones_LP (jnp .pi / 4 )
42+ self .assertEqual (lp .shape , (2 , 2 ))
43+
44+ retarder = jones_general_retarder (jnp .pi / 2 , jnp .pi / 4 , 0 )
45+ self .assertEqual (retarder .shape , (2 , 2 ))
46+
47+ sslm = jones_sSLM (jnp .pi / 2 , jnp .pi / 4 )
48+ self .assertEqual (sslm .shape , (2 , 2 ))
49+
50+ lcd = jones_LCD (jnp .pi / 2 , jnp .pi / 4 )
51+ self .assertEqual (lcd .shape , (2 , 2 ))
52+
53+ def test_polarization_devices (self ):
54+ light = PolarizedLightSource (self .x , self .y , self .wavelength )
55+ light .gaussian_beam (w0 = (1200 , 1200 ), jones_vector = (1 , 1 ))
56+
57+ # super-SLM with zero phase -- input SoP is diagonal
58+ alpha = jnp .zeros ((self .resolution , self .resolution ))
59+ phi = jnp .zeros ((self .resolution , self .resolution ))
60+ sslm_output = sSLM (light , alpha , phi )
61+ self .assertEqual (sslm_output .Ex .shape , (self .resolution , self .resolution ))
62+ self .assertEqual (sslm_output .Ey .shape , (self .resolution , self .resolution ))
63+ self .assertEqual (sslm_output .Ez .shape , (self .resolution , self .resolution ))
64+ self .assertTrue (jnp .allclose (sslm_output .Ex , light .Ex ))
65+ self .assertTrue (jnp .allclose (sslm_output .Ey , light .Ey ))
66+ self .assertTrue (jnp .allclose (sslm_output .Ez , light .Ez ))
67+
68+ # super-SLM with pi phase in Ex and Ey -- input SoP is diagonal
69+ alpha = jnp .pi * jnp .ones ((self .resolution , self .resolution ))
70+ phi = jnp .pi * jnp .ones ((self .resolution , self .resolution ))
71+ sslm_output = sSLM (light , alpha , phi )
72+ self .assertTrue (jnp .allclose (sslm_output .Ex , light .Ex * jnp .exp (1j * jnp .pi )))
73+ self .assertTrue (jnp .allclose (sslm_output .Ey , light .Ey * jnp .exp (1j * jnp .pi )))
74+ self .assertTrue (jnp .allclose (sslm_output .Ez , light .Ez ))
75+
76+ # LCD
77+ light = PolarizedLightSource (self .x , self .y , self .wavelength )
78+ light .gaussian_beam (w0 = (1200 , 1200 ), jones_vector = (1 , 0 ))
79+ lcd_output = LCD (light , 0 , 0 )
80+ self .assertEqual (lcd_output .Ex .shape , (self .resolution , self .resolution ))
81+ self .assertEqual (lcd_output .Ey .shape , (self .resolution , self .resolution ))
82+ self .assertEqual (lcd_output .Ez .shape , (self .resolution , self .resolution ))
83+ self .assertTrue (jnp .allclose (lcd_output .Ex , light .Ex ))
84+ self .assertTrue (jnp .allclose (lcd_output .Ey , light .Ey ))
85+ self .assertTrue (jnp .allclose (lcd_output .Ez , light .Ez ))
86+
87+ # LP aligned with incident SoP
88+ empty = jnp .zeros ((self .resolution , self .resolution ))
89+ lp_output = linear_polarizer (light , empty )
90+ self .assertEqual (lp_output .Ex .shape , (self .resolution , self .resolution ))
91+ self .assertEqual (lp_output .Ey .shape , (self .resolution , self .resolution ))
92+ self .assertEqual (lp_output .Ez .shape , (self .resolution , self .resolution ))
93+ self .assertTrue (jnp .allclose (lp_output .Ex , light .Ex ))
94+ self .assertTrue (jnp .allclose (lp_output .Ey , light .Ey ))
95+ self .assertTrue (jnp .allclose (lp_output .Ez , light .Ez ))
96+
97+ # LP crossed to input SoP
98+ pi_half = jnp .pi / 2 * jnp .ones_like (empty )
99+ lp_output = linear_polarizer (light , pi_half )
100+ self .assertTrue (jnp .allclose (lp_output .Ex , empty ))
101+ self .assertTrue (jnp .allclose (lp_output .Ey , empty ))
102+
103+ def test_beam_splitter (self ):
104+ light1 = PolarizedLightSource (self .x , self .y , self .wavelength )
105+ light1 .gaussian_beam (w0 = (1200 , 1200 ), jones_vector = (1 , 0 ))
106+ light2 = PolarizedLightSource (self .x , self .y , self .wavelength )
107+ light2 .gaussian_beam (w0 = (1200 , 1200 ), jones_vector = (1 , 0 ))
108+
109+ c , d = BS_symmetric (light1 , light2 , 0 ) # fully transmissive
110+ # Adds a pi phase: jnp.exp(1j * pi) = 1j
111+ # Noise = T*0.01
112+ T = jnp .abs (jnp .cos (0 ))
113+ R = jnp .abs (jnp .sin (0 ))
114+ noise = 0.01
115+ self .assertTrue (jnp .allclose (c .Ex , (T - noise ) * 1j * light2 .Ex + (R - noise ) * light1 .Ex ))
116+ self .assertTrue (jnp .allclose (c .Ey , (T - noise ) * 1j * light2 .Ey + (R - noise ) * light1 .Ey ))
117+ self .assertTrue (jnp .allclose (d .Ex , (T - noise ) * 1j * light1 .Ex + (R - noise ) * light2 .Ex ))
118+ self .assertTrue (jnp .allclose (d .Ey , (T - noise ) * 1j * light1 .Ey + (R - noise ) * light2 .Ey ))
119+
120+ def test_high_na_objective_lens (self ):
121+ radius_lens = 3.6 * 1e3 / 2 # mm
122+ f_lens = radius_lens / 0.9
123+ light = PolarizedLightSource (self .x , self .y , self .wavelength )
124+ light .gaussian_beam (w0 = (1200 , 1200 ), jones_vector = (1 , 0 ))
125+ output , _ = high_NA_objective_lens (light , radius_lens , f_lens )
126+ self .assertEqual (output .shape , (3 , self .resolution , self .resolution ))
127+
128+ def test_vczt_objective_lens (self ):
129+ radius_lens = 3.6 * 1e3 / 2 # mm
130+ f_lens = radius_lens / 0.9
131+ light = PolarizedLightSource (self .x , self .y , self .wavelength )
132+ light .gaussian_beam (w0 = (1200 , 1200 ), jones_vector = (1 , 0 ))
133+ output = VCZT_objective_lens (light , radius_lens , f_lens , self .x , self .y )
134+ self .assertEqual (output .Ex .shape , (self .resolution , self .resolution ))
135+ self .assertEqual (output .Ey .shape , (self .resolution , self .resolution ))
136+ self .assertEqual (output .Ez .shape , (self .resolution , self .resolution ))
137+
138+ def test_lenses_scalar (self ):
139+ light = LightSource (self .x , self .y , self .wavelength )
140+ light .gaussian_beam (w0 = (1200 , 1200 ), E0 = 1 )
141+ lens_output , _ = lens (light , (50 , 50 ), (1000 , 1000 ))
142+ self .assertEqual (lens_output .field .shape , (self .resolution , self .resolution ))
143+ cyl_lens_output , _ = cylindrical_lens (light , 1000 )
144+ self .assertEqual (cyl_lens_output .field .shape , (self .resolution , self .resolution ))
145+ axicon_output , _ = axicon_lens (light , 0.1 )
146+ self .assertEqual (axicon_output .field .shape , (self .resolution , self .resolution ))
147+
148+ def test_lenses_vectorial (self ):
149+ ls = PolarizedLightSource (self .x , self .y , self .wavelength )
150+ ls .gaussian_beam (w0 = (1200 , 1200 ), jones_vector = (1 , 0 ))
151+ light = VectorizedLight (self .x , self .y , self .wavelength )
152+ light .Ex = ls .Ex
153+ light .Ey = ls .Ey
154+ light .Ez = ls .Ez
155+ lens_output , _ = lens (light , (50 , 50 ), (1000 , 1000 ))
156+ self .assertEqual (lens_output .Ex .shape , (self .resolution , self .resolution ))
157+ self .assertEqual (lens_output .Ey .shape , (self .resolution , self .resolution ))
158+ cyl_lens_output , _ = cylindrical_lens (light , 1000 )
159+ self .assertEqual (cyl_lens_output .Ex .shape , (self .resolution , self .resolution ))
160+ self .assertEqual (cyl_lens_output .Ey .shape , (self .resolution , self .resolution ))
161+ axicon_output , _ = axicon_lens (light , 0.1 )
162+ self .assertEqual (axicon_output .Ex .shape , (self .resolution , self .resolution ))
163+ self .assertEqual (axicon_output .Ey .shape , (self .resolution , self .resolution ))
164+
165+ def test_building_block (self ):
166+ light = PolarizedLightSource (self .x , self .y , self .wavelength )
167+ light .gaussian_beam (w0 = (1200 , 1200 ), jones_vector = (1 , 0 ))
168+ output = building_block (light , jnp .zeros ((self .resolution , self .resolution )), jnp .zeros ((self .resolution , self .resolution )), 1000 , jnp .pi / 2 , jnp .pi / 4 )
169+ self .assertEqual (output .Ex .shape , (self .resolution , self .resolution ))
170+ self .assertEqual (output .Ey .shape , (self .resolution , self .resolution ))
171+ self .assertEqual (output .Ez .shape , (self .resolution , self .resolution ))
172+
173+ if __name__ == '__main__' :
174+ unittest .main ()
0 commit comments