Skip to content

Commit 6543b56

Browse files
new optical table for 4f system training
1 parent dce73b5 commit 6543b56

File tree

1 file changed

+45
-24
lines changed

1 file changed

+45
-24
lines changed

experiments/four_f_optical_table.py

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from xlumina.__init__ import um, nm, cm
1111
from xlumina.wave_optics import *
1212
from xlumina.optical_elements import SLM
13-
from xlumina.loss_functions import mean_batch_MSE_Intensity
1413
from xlumina.toolbox import space
1514
from jax import vmap
1615
import jax.numpy as jnp
@@ -28,20 +27,20 @@
2827

2928

3029
# 2. Define the optical functions:
31-
def batch_dualSLM_4f(input_field, x, y, wavelength, parameters):
30+
def batch_dualSLM_4f(input_mask, x, y, wavelength, parameters):
3231
"""
3332
[4f system coded exclusively for batch optimization purposes].
3433
3534
Define an optical table with a 4f system composed by 2 SLMs (to be used with ScalarLight).
3635
3736
Illustrative scheme:
38-
U(x,y) --> SLM(phase1) --> Propagate: RS(z1) --> SLM(phase2) --> Propagate: RS(z2) --> Detect
37+
U(x,y) --> input_mask --> SLM(phase1) --> Propagate: RS(z1) --> SLM(phase2) --> Propagate: RS(z2) --> Detect
3938
4039
Parameters:
41-
input_field (jnp.array): Light to be modulated. Comes in the form of an array. Not ScalarLight.
40+
input_mask (jnp.array): Input mask, comes in the form of an array
4241
parameters (list): Parameters to pass to the optimizer [z1, z2, z3, phase1 and phase2] for RS propagation and the two SLMs.
4342
44-
Returns the field (jnp.array) after second propagation, and phase masks slm1 and slm2.
43+
Returns the intensity (jnp.array) after second propagation, and phase masks slm1 and slm2.
4544
4645
+ Parameters in the optimizer are (0,1). We need to convert them back [Offset is determined by .get_RS_minimum_z() for the corresponding pixel resolution].
4746
Convert (0,1) to distance in cm. Conversion factor (offset, 100) -> (offset/100, 1).
@@ -52,63 +51,85 @@ def batch_dualSLM_4f(input_field, x, y, wavelength, parameters):
5251
# From get_RS_minimum_z()
5352
offset = 1.2
5453

55-
# Restore input_field as ScalarLight object - comes from vmap (AbstractTracer).
56-
input_light = ScalarLight(x, y, wavelength)
57-
input_light.field = jnp.array(input_field)
54+
# Apply input mask (comes from vmap)
55+
input_light.field = input_light.field * input_mask
5856

5957
""" Stage 0: Propagation """
6058
# Propagate light from mask
61-
light_stage0, quality_0 = input_light.RS_propagation(z=(jnp.abs(parameters[0]) * 100 + offset)*cm)
59+
light_stage0, _ = input_light.RS_propagation(z=(jnp.abs(parameters[0]) * 100 + offset)*cm)
6260

6361
""" Stage 0: Modulation """
6462
# Feed SLM_1 with parameters[2] and apply the mask to the forward beam
6563
modulated_slm1, slm_1 = SLM(light_stage0, parameters[3] * (2*jnp.pi) - jnp.pi, shape)
6664

6765
""" Stage 1: Propagation """
6866
# Propagate the SLM_1 output beam to another distance z
69-
light_stage1, quality_1 = modulated_slm1.RS_propagation(z=(jnp.abs(parameters[1]) * 100 + offset)*cm)
67+
light_stage1, _ = modulated_slm1.RS_propagation(z=(jnp.abs(parameters[1]) * 100 + offset)*cm)
7068

7169
""" Stage 1: Modulation """
7270
# Apply the SLM_2 to the forward beam
7371
modulated_slm2, slm_2 = SLM(light_stage1, parameters[4] * (2*jnp.pi) - jnp.pi, shape)
7472

7573
""" Stage 2: Propagation """
7674
# Propagate the SLM_2 output beam to another distance z
77-
fw_to_detector, quality_2 = modulated_slm2.RS_propagation(z=(jnp.abs(parameters[2]) * 100 + offset)*cm)
75+
fw_to_detector, _ = modulated_slm2.RS_propagation(z=(jnp.abs(parameters[2]) * 100 + offset)*cm)
7876

79-
return fw_to_detector.field, slm_1, slm_2
77+
return jnp.abs(fw_to_detector.field)**2, slm_1, slm_2
8078

81-
def vector_dualSLM_4f_system(input_fields, x, y, wavelength, parameters):
79+
def vector_dualSLM_4f_system(input_masks, x, y, wavelength, parameters):
8280
"""
8381
[Coded exclusively for the batch optimization].
8482
8583
Vectorized (efficient) version of 4f system for batch optimization.
8684
8785
Parameters:
88-
input_fields (jnp.array): Array with input fields.
86+
input_masks (jnp.array): Array with input masks
8987
x, y, wavelength (jnp.arrays and float): Light specs to pass to batch_dualSLM_4f.
9088
parameters (list): Parameters to pass to the optimizer [z1, z2, z3, phase1 and phase2] for RS propagation and the two SLMs.
9189
92-
Returns vectorized version of detected light.
90+
Returns vectorized version of detected light (intensity).
9391
"""
94-
detected_light, _, _ = vmap(batch_dualSLM_4f, in_axes=(0, None, None, None, None))(input_fields, x, y, wavelength, parameters)
95-
return detected_light
92+
detected_intensity, _, _ = vmap(batch_dualSLM_4f, in_axes=(0, None, None, None, None))(input_masks, x, y, wavelength, parameters)
93+
return detected_intensity
9694

9795

9896
# 3. Define the loss function for batch optimization.
99-
def loss_dualSLM(parameters, input_fields, target_fields):
97+
def loss_dualSLM(parameters, input_masks, target_intensities):
10098
"""
10199
Loss function for 4f system batch optimization. It computes the MSE between the optimized light and the target field.
102100
103101
Parameters:
104102
parameters (list): Optimized parameters.
105-
input_fields (jnp.array): Array with input light fields.
106-
target_fields (jnp.array): Array with target light fields.
103+
input_masks (jnp.array): Array with input masks.
104+
target_intensities (jnp.array): Array with target intensities.
107105
108-
Returns the mean value of the loss computed for all the input fields.
106+
Returns the mean value of the loss computed for all the inputs.
109107
"""
110108
global x, y, wavelength
111109
# Input fields and target fields are arrays with synthetic data. Global variables defined in the optical table script.
112-
optimized_fields = vector_dualSLM_4f_system(input_fields, x, y, wavelength, parameters)
113-
mean_loss, loss_array = mean_batch_MSE_Intensity(optimized_fields, target_fields)
114-
return mean_loss
110+
optimized_intensities = vector_dualSLM_4f_system(input_masks, x, y, wavelength, parameters)
111+
mean_loss, loss_array = mean_batch_MSE_Intensity(optimized_intensities, target_intensities)
112+
return mean_loss
113+
114+
def mean_batch_MSE_Intensity(optimized, target):
115+
"""
116+
[Computed for batch optimization in 4f system]. Vectorized version of MSE_Intensity.
117+
118+
Returns the mean value of all the MSE for each (optimized, target) pairs and a jnp.array with MSE values from each pair.
119+
"""
120+
MSE = vmap(MSE_Intensity, in_axes=(0, 0))(optimized, target)
121+
return jnp.mean(MSE), MSE
122+
123+
@jit
124+
def MSE_Intensity(input_light, target_light):
125+
"""
126+
Computes the Mean Squared Error (in Intensity) for a given electric field component Ex, Ey or Ez.
127+
128+
Parameters:
129+
input_light (array): intensity: input_light = jnp.abs(input_light.field)**2
130+
target_light (array): Ground truth - intensity in the detector: target_light = jnp.abs(target_light.field)**2
131+
132+
Returns the MSE (jnp.array).
133+
"""
134+
num_pix = jnp.shape(input_light)[0] * jnp.shape(input_light)[1]
135+
return jnp.sum((input_light - target_light)** 2) / num_pix

0 commit comments

Comments
 (0)