1+ import os
2+ import sys
# Make the project importable when the script is launched from its own folder.
# The path is resolved from the current working directory (not from this
# file's location): take the parent of the working directory, then go one
# level further up.
current_path = os.path.abspath(os.pardir)
dir_path = os.path.split(current_path)[0]
module_path = dir_path

# Register the directory only once so repeated runs do not duplicate the entry.
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)
8+
9+ import time
10+ import jax
11+ from jax import grad , jit
12+ import optax
13+ import numpy as np
14+ import jax .numpy as jnp
15+ import gc # Garbage collector
16+
17+ # Use this for pure topological discovery:
18+ from xlumina .six_times_six_ansatz_with_fixed_PM import * #<--- use this for 6x6 ansatz
19+ # from xlumina.hybrid_with_fixed_PM import * # <--- use this for 3x3 with fixed masks
20+
21+ # Use this for hybrid optimization:
22+ # from xlumina.hybrid_sharp_optical_table import * # <--- use this for sharp focus
23+ # from xlumina.hybrid_sted_optical_table import * # <--- use this for sted
24+
25+ """
26+ OPTIMIZER - LARGE-SCALE SETUPS
27+ """
28+
# Report which devices JAX will run on (GPU or CPU).
print(jax.devices(), flush=True)

# Global variable: two-element array holding the sensor lateral size twice.
# It is used below as the `size` argument when drawing the random phase masks.
shape = jnp.array([sensor_lateral_size] * 2)

# Loss function to optimize, compiled with jit. Pick exactly one:
# loss_function = jit(loss_hybrid_sharp_focus) # <--- use this for sharp focus
# loss_function = jit(loss_hybrid_sted) # <--- use this for sted
loss_function = jax.jit(loss_hybrid_fixed_PM)  # <--- use this for sharp focus with fixed phase masks
39+
40+ # ----------------------------------------------------
41+
def clip_adamw(learning_rate, weight_decay) -> optax.GradientTransformation:
    """
    Build the AdamW optimizer used for the optimization runs.

    This is a thin wrapper around `optax.adamw`. Despite the function name,
    no gradient clipping (or any other extra transformation) is chained here.

    Parameters:
        learning_rate: Step size forwarded to `optax.adamw`.
        weight_decay: Weight-decay strength forwarded to `optax.adamw`.

    Returns:
        optax.GradientTransformation: The AdamW gradient transformation.
    """
    adamw_settings = {
        'learning_rate': learning_rate,
        'weight_decay': weight_decay,
    }
    return optax.adamw(**adamw_settings)
49+
def fit(params: optax.Params, optimizer: optax.GradientTransformation, num_iterations,
        n_best=500, initial_best_loss=3e2) -> tuple:
    """
    Minimize the module-level `loss_function` starting from `params`.

    Each step evaluates the loss and its gradient, applies one `optimizer`
    update and records the loss. The parameters that produced the lowest
    loss so far are remembered and returned at the end.

    Parameters:
        params (optax.Params): Initial parameters (pytree) to optimize.
        optimizer (optax.GradientTransformation): Optimizer providing `init` and `update`.
        num_iterations (int): Maximum number of update steps.
        n_best (int): Patience of the stopping criterion. The run stops when more
            than `n_best` steps have passed since the last improvement. The check
            is only performed on steps that are multiples of 100.
        initial_best_loss (float): A loss has to be strictly below this value to be
            recorded as the first "best" one.

    Returns:
        tuple: `(best_params, best_loss, iteration_steps, loss_list)`.
            `best_params` is None (and `best_loss` is `initial_best_loss`) when no
            loss ever dropped below `initial_best_loss`.
    """
    # Init the optimizer with initial parameters
    opt_state = optimizer.init(params)

    @jit
    def update(parameters, state):
        # Loss at the current parameters together with its gradient
        loss_value, grads = jax.value_and_grad(loss_function)(parameters)

        # Turn the gradients into parameter updates and advance the optimizer state
        updates, new_state = optimizer.update(grads, state, parameters)

        # Apply the updates to obtain the next parameters
        new_params = optax.apply_updates(parameters, updates)

        return new_params, new_state, loss_value

    # Optimization history
    iteration_steps = []
    loss_list = []

    # Best result seen so far
    best_loss = initial_best_loss
    best_params = None
    best_step = 0

    print('Starting Optimization', flush=True)

    for step in range(num_iterations):

        # `loss_value` is evaluated at the parameters *before* this update,
        # so keep a reference to them: they are the ones matching the loss.
        old_params = params
        params, opt_state, loss_value = update(old_params, opt_state)

        print(f"Step {step} ")
        print(f"Loss {loss_value} ")
        iteration_steps.append(step)
        loss_list.append(loss_value)

        # Update the `best_loss` value:
        if loss_value < best_loss:
            best_loss = loss_value
            best_params = old_params
            best_step = step
            print('Best loss value is updated')

        # Stopping criterion (checked every 100 steps): no improvement for
        # more than `n_best` steps.
        if step % 100 == 0 and step - best_step > n_best:
            print(f'Stopping criterion: no improvement in loss value for {n_best} steps')
            break

    print(f'Best loss: {best_loss} at step {best_step} ')
    print(f'Best parameters: {best_params} ')
    return best_params, best_loss, iteration_steps, loss_list
107+
108+ # ----------------------------------------------------
109+
# Optimizer settings
num_iterations = 100000
num_samples = 20

# Hyperparameters are identical for every sample, so they are defined once.
STEP_SIZE = 0.05
WEIGHT_DECAY = 0.0001


def _random_init(size=1):
    """
    Draw `size` values uniformly from [0, 1) with NumPy and return them as a jax array.

    float64 is requested for the result.
    NOTE(review): JAX only honors float64 when x64 mode is enabled - confirm
    that the imported xlumina module enables it.
    """
    return jnp.array(np.random.uniform(0, 1, size), dtype=jnp.float64)


# Run `num_samples` independent optimizations, each from a fresh random start.
for i in range(num_samples):

    gc.collect()
    tic = time.perf_counter()

    # Parameters -- know which ones to comment based on the setup you want to optimize:
    # super-SLM phase masks (one random value per entry of `shape`):
    (phase1_1, phase1_2,
     phase2_1, phase2_2,
     phase3_1, phase3_2) = (_random_init(shape) for _ in range(6))

    # Wave plate variables (arrays with a single element):
    (eta1, theta1,
     eta2, theta2,
     eta3, theta3,
     eta4, theta4) = (_random_init() for _ in range(8))

    # Distances used by the 3x3 parameter sets (arrays with a single element):
    (z1_1, z1_2,
     z2_1, z2_2,
     z3_1, z3_2,
     z4_1, z4_2) = (_random_init() for _ in range(8))

    # Distances z1..z12; z4 and z5 are also the ones referenced by the 3x3 hybrid set:
    (z1, z2, z3, z4, z5, z6,
     z7, z8, z9, z10, z11, z12) = (_random_init() for _ in range(12))

    # Beam splitter ratios
    (bs1, bs2, bs3, bs4, bs5, bs6,
     bs7, bs8, bs9, bs10, bs11, bs12,
     bs13, bs14, bs15, bs16, bs17, bs18,
     bs19, bs20, bs21, bs22, bs23, bs24,
     bs25, bs26, bs27, bs28, bs29, bs30,
     bs31, bs32, bs33, bs34, bs35, bs36) = (_random_init() for _ in range(36))

    # Set which set of init parameters to use:
    # REMEMBER TO COMMENT (#) THE VARIABLES YOU DON'T USE!

    # 1. For 3x3 hybrid optimization (topology + optical parameters):
    # init_params = [phase1_1, phase1_2, eta1, theta1, z1_1, z1_2, phase2_1, phase2_2, eta2, theta2, z2_1, z2_2, phase3_1, phase3_2, eta3, theta3, z3_1, z3_2, bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9, z4, z5]

    # 2. Parameters for pure topological optimization on 3x3 systems with fixed phase masks at random positions:
    # init_params = [z1_1, z1_2, z2_1, z2_2, z3_1, z3_2, z4_1, z4_2, bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9, eta1, theta1, eta2, theta2, eta3, theta3, eta4, theta4]

    # 3. Parameters for pure topological optimization on the 6x6 system with fixed phase masks:
    init_params = [z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11, z12,
                   bs1, bs2, bs3, bs4, bs5, bs6,
                   bs7, bs8, bs9, bs10, bs11, bs12,
                   bs13, bs14, bs15, bs16, bs17, bs18,
                   bs19, bs20, bs21, bs22, bs23, bs24,
                   bs25, bs26, bs27, bs28, bs29, bs30,
                   bs31, bs32, bs33, bs34, bs35, bs36,
                   eta1, theta1, eta2, theta2]

    # Init optimizer:
    optimizer = clip_adamw(STEP_SIZE, WEIGHT_DECAY)

    # Apply fit function:
    best_params, best_loss, iteration_steps, loss_list = fit(init_params, optimizer, num_iterations)

    # Report how long this sample took (the timer was started above).
    print(f'Sample {i}: optimization time {time.perf_counter() - tic:.2f} s', flush=True)
0 commit comments