1010from xlumina .__init__ import um , nm , cm
1111from xlumina .wave_optics import *
1212from xlumina .optical_elements import SLM
13- from xlumina .loss_functions import mean_batch_MSE_Intensity
1413from xlumina .toolbox import space
1514from jax import vmap
1615import jax .numpy as jnp
2827
2928
3029# 2. Define the optical functions:
31- def batch_dualSLM_4f (input_field , x , y , wavelength , parameters ):
30+ def batch_dualSLM_4f (input_mask , x , y , wavelength , parameters ):
3231 """
3332 [4f system coded exclusively for batch optimization purposes].
3433
3534 Define an optical table with a 4f system composed by 2 SLMs (to be used with ScalarLight).
3635
3736 Illustrative scheme:
38- U(x,y) --> SLM(phase1) --> Propagate: RS(z1) --> SLM(phase2) --> Propagate: RS(z2) --> Detect
37+ U(x,y) --> input_mask --> SLM(phase1) --> Propagate: RS(z1) --> SLM(phase2) --> Propagate: RS(z2) --> Detect
3938
4039 Parameters:
41- input_field (jnp.array): Light to be modulated. Comes in the form of an array. Not ScalarLight.
40+ input_mask (jnp.array): Input mask, comes in the form of an array
4241 parameters (list): Parameters to pass to the optimizer [z1, z2, z3, phase1 and phase2] for RS propagation and the two SLMs.
4342
44- Returns the field (jnp.array) after second propagation, and phase masks slm1 and slm2.
43+ Returns the intensity (jnp.array) after second propagation, and phase masks slm1 and slm2.
4544
4645 + Parameters in the optimizer are (0,1). We need to convert them back [Offset is determined by .get_RS_minimum_z() for the corresponding pixel resolution].
4746 Convert (0,1) to distance in cm. Conversion factor (offset, 100) -> (offset/100, 1).
@@ -52,63 +51,85 @@ def batch_dualSLM_4f(input_field, x, y, wavelength, parameters):
5251 # From get_RS_minimum_z()
5352 offset = 1.2
5453
55- # Restore input_field as ScalarLight object - comes from vmap (AbstractTracer).
56- input_light = ScalarLight (x , y , wavelength )
57- input_light .field = jnp .array (input_field )
54+ # Apply input mask (comes from vmap)
55+ input_light .field = input_light .field * input_mask
5856
5957 """ Stage 0: Propagation """
6058 # Propagate light from mask
61- light_stage0 , quality_0 = input_light .RS_propagation (z = (jnp .abs (parameters [0 ]) * 100 + offset )* cm )
59+ light_stage0 , _ = input_light .RS_propagation (z = (jnp .abs (parameters [0 ]) * 100 + offset )* cm )
6260
6361 """ Stage 0: Modulation """
6462 # Feed SLM_1 with parameters[2] and apply the mask to the forward beam
6563 modulated_slm1 , slm_1 = SLM (light_stage0 , parameters [3 ] * (2 * jnp .pi ) - jnp .pi , shape )
6664
6765 """ Stage 1: Propagation """
6866 # Propagate the SLM_1 output beam to another distance z
69- light_stage1 , quality_1 = modulated_slm1 .RS_propagation (z = (jnp .abs (parameters [1 ]) * 100 + offset )* cm )
67+ light_stage1 , _ = modulated_slm1 .RS_propagation (z = (jnp .abs (parameters [1 ]) * 100 + offset )* cm )
7068
7169 """ Stage 1: Modulation """
7270 # Apply the SLM_2 to the forward beam
7371 modulated_slm2 , slm_2 = SLM (light_stage1 , parameters [4 ] * (2 * jnp .pi ) - jnp .pi , shape )
7472
7573 """ Stage 2: Propagation """
7674 # Propagate the SLM_2 output beam to another distance z
77- fw_to_detector , quality_2 = modulated_slm2 .RS_propagation (z = (jnp .abs (parameters [2 ]) * 100 + offset )* cm )
75+ fw_to_detector , _ = modulated_slm2 .RS_propagation (z = (jnp .abs (parameters [2 ]) * 100 + offset )* cm )
7876
79- return fw_to_detector .field , slm_1 , slm_2
77+ return jnp . abs ( fw_to_detector .field ) ** 2 , slm_1 , slm_2
8078
81- def vector_dualSLM_4f_system (input_fields , x , y , wavelength , parameters ):
79+ def vector_dualSLM_4f_system (input_masks , x , y , wavelength , parameters ):
8280 """
8381 [Coded exclusively for the batch optimization].
8482
8583 Vectorized (efficient) version of 4f system for batch optimization.
8684
8785 Parameters:
88- input_fields (jnp.array): Array with input fields.
86+ input_masks (jnp.array): Array with input masks
8987 x, y, wavelength (jnp.arrays and float): Light specs to pass to batch_dualSLM_4f.
9088 parameters (list): Parameters to pass to the optimizer [z1, z2, z3, phase1 and phase2] for RS propagation and the two SLMs.
9189
92- Returns vectorized version of detected light.
90+ Returns vectorized version of detected light (intensity) .
9391 """
94- detected_light , _ , _ = vmap (batch_dualSLM_4f , in_axes = (0 , None , None , None , None ))(input_fields , x , y , wavelength , parameters )
95- return detected_light
92+ detected_intensity , _ , _ = vmap (batch_dualSLM_4f , in_axes = (0 , None , None , None , None ))(input_masks , x , y , wavelength , parameters )
93+ return detected_intensity
9694
9795
9896# 3. Define the loss function for batch optimization.
99- def loss_dualSLM (parameters , input_fields , target_fields ):
97+ def loss_dualSLM (parameters , input_masks , target_intensities ):
10098 """
10199 Loss function for 4f system batch optimization. It computes the MSE between the optimized light and the target field.
102100
103101 Parameters:
104102 parameters (list): Optimized parameters.
105- input_fields (jnp.array): Array with input light fields .
106- target_fields (jnp.array): Array with target light fields .
103+ input_masks (jnp.array): Array with input masks .
104+ target_intensities (jnp.array): Array with target intensities .
107105
108- Returns the mean value of the loss computed for all the input fields .
106+ Returns the mean value of the loss computed for all the inputs .
109107 """
110108 global x , y , wavelength
111109 # Input fields and target fields are arrays with synthetic data. Global variables defined in the optical table script.
112- optimized_fields = vector_dualSLM_4f_system (input_fields , x , y , wavelength , parameters )
113- mean_loss , loss_array = mean_batch_MSE_Intensity (optimized_fields , target_fields )
114- return mean_loss
110+ optimized_intensities = vector_dualSLM_4f_system (input_masks , x , y , wavelength , parameters )
111+ mean_loss , loss_array = mean_batch_MSE_Intensity (optimized_intensities , target_intensities )
112+ return mean_loss
113+
114+ def mean_batch_MSE_Intensity (optimized , target ):
115+ """
116+ [Computed for batch optimization in 4f system]. Vectorized version of MSE_Intensity.
117+
118+ Returns the mean value of all the MSE for each (optimized, target) pairs and a jnp.array with MSE values from each pair.
119+ """
120+ MSE = vmap (MSE_Intensity , in_axes = (0 , 0 ))(optimized , target )
121+ return jnp .mean (MSE ), MSE
122+
123+ @jit
124+ def MSE_Intensity (input_light , target_light ):
125+ """
126+ Computes the Mean Squared Error (in Intensity) for a given electric field component Ex, Ey or Ez.
127+
128+ Parameters:
129+ input_light (array): intensity: input_light = jnp.abs(input_light.field)**2
130+ target_light (array): Ground truth - intensity in the detector: target_light = jnp.abs(target_light.field)**2
131+
132+ Returns the MSE (jnp.array).
133+ """
134+ num_pix = jnp .shape (input_light )[0 ] * jnp .shape (input_light )[1 ]
135+ return jnp .sum ((input_light - target_light )** 2 ) / num_pix
0 commit comments