Skip to content

Commit d424ac4

Browse files
new optimizer for large-scale hybrid discovery -- it is common to all the optical tables that contain the hybrid and 6x6 setups
1 parent 0a9a332 commit d424ac4

File tree

1 file changed

+226
-0
lines changed

1 file changed

+226
-0
lines changed

experiments/hybrid_optimizer.py

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
import os
2+
import sys
3+
current_path = os.path.abspath(os.path.join('..'))
4+
dir_path = os.path.dirname(current_path)
5+
module_path = os.path.join(dir_path)
6+
if module_path not in sys.path:
7+
sys.path.append(module_path)
8+
9+
import time
10+
import jax
11+
from jax import grad, jit
12+
import optax
13+
import numpy as np
14+
import jax.numpy as jnp
15+
import gc # Garbage collector
16+
17+
# Use this for pure topological discovery:
18+
from xlumina.six_times_six_ansatz_with_fixed_PM import * #<--- use this for 6x6 ansatz
19+
# from xlumina.hybrid_with_fixed_PM import * # <--- use this for 3x3 with fixed masks
20+
21+
# Use this for hybrid optimization:
22+
# from xlumina.hybrid_sharp_optical_table import * # <--- use this for sharp focus
23+
# from xlumina.hybrid_sted_optical_table import * # <--- use this for sted
24+
25+
"""
26+
OPTIMIZER - LARGE-SCALE SETUPS
27+
"""
28+
29+
# Print device info (GPU or CPU)
30+
print(jax.devices(), flush=True)
31+
32+
# Global variable
33+
shape = jnp.array([sensor_lateral_size, sensor_lateral_size])
34+
35+
# Define the loss function and compute its gradients:
36+
# loss_function = jit(loss_hybrid_sharp_focus) # <--- use this for sharp focus
37+
# loss_function = jit(loss_hybrid_sted) # <--- use this for sted
38+
loss_function = jit(loss_hybrid_fixed_PM) # <--- use this for sharp focus with fixed phase masks
39+
40+
# ----------------------------------------------------
41+
42+
def clip_adamw(learning_rate, weight_decay) -> optax.GradientTransformation:
43+
"""
44+
Custom optimizer - adamw: applies several transformations in sequence
45+
1) Apply ADAM
46+
2) Apply weight decay
47+
"""
48+
return optax.adamw(learning_rate=learning_rate, weight_decay=weight_decay)
49+
50+
def fit(params: optax.Params, optimizer: optax.GradientTransformation, num_iterations) -> optax.Params:
51+
52+
# Init the optimizer with initial parameters
53+
opt_state = optimizer.init(params)
54+
55+
@jit
56+
def update(parameters, opt_state):
57+
# Define single update step:
58+
loss_value, grads = jax.value_and_grad(loss_function)(parameters)
59+
60+
# Update the state of the optimizer
61+
updates, state = optimizer.update(grads, opt_state, parameters)
62+
63+
# Update the parameters
64+
new_params = optax.apply_updates(parameters, updates)
65+
66+
return new_params, parameters, state, loss_value, updates
67+
68+
69+
# Initialize some parameters
70+
iteration_steps=[]
71+
loss_list=[]
72+
73+
n_best = 500
74+
best_loss = 3*1e2
75+
best_params = None
76+
best_step = 0
77+
78+
print('Starting Optimization', flush=True)
79+
80+
for step in range(num_iterations):
81+
82+
params, old_params, opt_state, loss_value, grads = update(params, opt_state)
83+
84+
print(f"Step {step}")
85+
print(f"Loss {loss_value}")
86+
iteration_steps.append(step)
87+
loss_list.append(loss_value)
88+
89+
# Update the `best_loss` value:
90+
if loss_value < best_loss:
91+
# Best loss value
92+
best_loss = loss_value
93+
# Best optimized parameters
94+
best_params = old_params
95+
best_step = step
96+
print('Best loss value is updated')
97+
98+
if step % 100 == 0:
99+
# Stopping criteria: if best_loss has not changed every 500 steps, stop.
100+
if step - best_step > n_best:
101+
print(f'Stopping criterion: no improvement in loss value for {n_best} steps')
102+
break
103+
104+
print(f'Best loss: {best_loss} at step {best_step}')
105+
print(f'Best parameters: {best_params}')
106+
return best_params, best_loss, iteration_steps, loss_list
107+
108+
# ----------------------------------------------------
109+
110+
# Optimizer settings
111+
num_iterations = 100000
112+
num_samples = 20
113+
114+
for i in range(num_samples):
115+
116+
STEP_SIZE = 0.05
117+
WEIGHT_DECAY = 0.0001
118+
119+
gc.collect()
120+
tic = time.perf_counter()
121+
122+
# Parameters -- know which ones to comment based on the setup you want to optimize:
123+
# super-SLM phase masks:
124+
phase1_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
125+
phase1_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
126+
phase2_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
127+
phase2_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
128+
phase3_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
129+
phase3_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
130+
131+
# Wave plate variables:
132+
eta1 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
133+
theta1 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
134+
eta2 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
135+
theta2 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
136+
eta3 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
137+
theta3 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
138+
eta4 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
139+
theta4 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
140+
141+
# Distances:
142+
z1_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
143+
z1_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
144+
z2_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
145+
z2_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
146+
z3_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
147+
z3_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
148+
z4_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
149+
z4_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
150+
z4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
151+
z5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
152+
z1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
153+
z2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
154+
z3 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
155+
z4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
156+
z5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
157+
z6 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
158+
z7 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
159+
z8 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
160+
z9 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
161+
z10 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
162+
z11 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
163+
z12 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
164+
165+
# Beam splitter ratios
166+
bs1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
167+
bs2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
168+
bs3 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
169+
bs4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
170+
bs5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
171+
bs6 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
172+
bs7 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
173+
bs8 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
174+
bs9 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
175+
bs10 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
176+
bs11 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
177+
bs12 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
178+
bs13 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
179+
bs14 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
180+
bs15 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
181+
bs16 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
182+
bs17 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
183+
bs18 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
184+
bs19 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
185+
bs20 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
186+
bs21 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
187+
bs22 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
188+
bs23 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
189+
bs24 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
190+
bs25 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
191+
bs26 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
192+
bs27 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
193+
bs28 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
194+
bs29 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
195+
bs30 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
196+
bs31 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
197+
bs32 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
198+
bs33 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
199+
bs34 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
200+
bs35 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
201+
bs36 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
202+
203+
# Set which set of init parameters to use:
204+
# REMEMBER TO COMMENT (#) THE VARIABLES YOU DON'T USE!
205+
206+
# 1. For 3x3 hybrid optimization (topology + optical parameters):
207+
# init_params = [phase1_1, phase1_2, eta1, theta1, z1_1, z1_2, phase2_1, phase2_2, eta2, theta2, z2_1, z2_2, phase3_1, phase3_2, eta3, theta3, z3_1, z3_2, bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9, z4, z5]
208+
209+
# 2. Parameters for pure topological optimization on 3x3 systems with fixed phase masks at random positions:
210+
# init_params = [z1_1, z1_2, z2_1, z2_2, z3_1, z3_2, z4_1, z4_2, bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9, eta1, theta1, eta2, theta2, eta3, theta3, eta4, theta4]
211+
212+
# 3. Parameters for pure topological optimization on the 6x6 system with fixed phase masks:
213+
init_params = [z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11, z12,
214+
bs1, bs2, bs3, bs4, bs5, bs6,
215+
bs7, bs8, bs9, bs10, bs11, bs12,
216+
bs13, bs14, bs15, bs16, bs17, bs18,
217+
bs19, bs20, bs21, bs22, bs23, bs24,
218+
bs25, bs26, bs27, bs28, bs29, bs30,
219+
bs31, bs32, bs33, bs34, bs35, bs36,
220+
eta1, theta1, eta2, theta2]
221+
222+
# Init optimizer:
223+
optimizer = clip_adamw(STEP_SIZE, WEIGHT_DECAY)
224+
225+
# Apply fit function:
226+
best_params, best_loss, iteration_steps, loss_list = fit(init_params, optimizer, num_iterations)

0 commit comments

Comments
 (0)