1+ # Setting the path for XLuminA modules:
2+ import os
3+ import sys
4+ current_path = os .path .abspath (os .path .join ('..' ))
5+ dir_path = os .path .dirname (current_path )
6+ module_path = os .path .join (dir_path )
7+ if module_path not in sys .path :
8+ sys .path .append (module_path )
9+
10+ from __init__ import um , nm , cm , mm
11+ from xlumina .vectorized_optics import *
12+ from xlumina .optical_elements import six_times_six_ansatz
13+ from xlumina .loss_functions import vectorized_loss_hybrid
14+ from xlumina .toolbox import space , softmin
15+ import jax .numpy as jnp
16+
17+ """
18+ Pure topological discovery within 6x6 ansatz for Dorn, Quabis and Leuchs (2004)
19+ """
20+
21+ # 1. System specs:
22+ sensor_lateral_size = 824 # Resolution
23+ wavelength_1 = 635.0 * nm
24+ x_total = 2500 * um
25+ x , y = space (x_total , sensor_lateral_size )
26+ shape = jnp .shape (x )[0 ]
27+
28+ # 2. Define the optical functions: two orthogonally polarized beams:
29+ w0 = (1200 * um , 1200 * um )
30+ ls1 = PolarizedLightSource (x , y , wavelength_1 )
31+ ls1 .gaussian_beam (w0 = w0 , jones_vector = (1 , - 1 ))
32+
33+ # 3. Define the output (High Resolution) detection:
34+ x_out , y_out = jnp .array (space (10 * um , 400 ))
35+ X , Y = jnp .meshgrid (x ,y )
36+
37+ # 4. High NA objective lens specs:
38+ NA = 0.9
39+ radius_lens = 3.6 * mm / 2
40+ f_lens = radius_lens / NA
41+
42+ # 4.1 Fixed phase masks:
43+ # Polarization converter in Dorn, Quabis, Leuchs (2004):
44+ pi_half = (jnp .pi - jnp .pi / 2 ) * jnp .ones (shape = (sensor_lateral_size // 2 , sensor_lateral_size // 2 ))
45+ minus_pi_half = - jnp .pi / 2 * jnp .ones (shape = (sensor_lateral_size // 2 , sensor_lateral_size // 2 ))
46+ PM1_1 = jnp .concatenate ((jnp .concatenate ((minus_pi_half , pi_half ), axis = 1 ), jnp .concatenate ((minus_pi_half , pi_half ), axis = 1 )), axis = 0 )
47+ PM1_2 = jnp .concatenate ((jnp .concatenate ((minus_pi_half , minus_pi_half ), axis = 1 ), jnp .concatenate ((pi_half , pi_half ), axis = 1 )), axis = 0 )
48+
49+ # Linear grating
50+ PM2_1 = jnp .sin (2 * jnp .pi * Y / 1000 ) * jnp .pi
51+ PM2_2 = jnp .sin (2 * jnp .pi * X / 1000 ) * jnp .pi
52+
53+ # 5. Static parameters - don't change during optimization:
54+ fixed_params = [radius_lens , f_lens , x_out , y_out , PM1_1 , PM1_2 , PM2_1 , PM2_2 ]
55+
56+ # 6. Define the loss function:
57+ def loss_hybrid_fixed_PM (parameters ):
58+ # Output from hybrid_setup is jnp.array(12, N, N): for 12 detectors
59+ i_effective = six_times_six_ansatz (ls1 , ls1 , ls1 , ls1 , ls1 , ls1 , ls1 , ls1 , ls1 , ls1 , ls1 , ls1 , parameters , fixed_params , distance_offset = 9.5 )
60+ # Get the minimum value within loss value array of shape (12, 1, 1)
61+ loss_val = softmin (vectorized_loss_hybrid (i_effective ))
62+ return loss_val
0 commit comments