# Call the data loader and set the batch size
dataloader = MultiHDF5DataLoader("training_data_4f", batch_size=10)

- # Define the update for batch optimization:
- @jit
- def update(step_index, optimizer_state, input_fields, target_fields):
-     parameters = get_params(optimizer_state)
-     # Call the loss function and compute the gradients
-     computed_loss = loss_fn(parameters, input_fields, target_fields)
-     computed_gradients = grad(loss_fn, allow_int=True)(parameters, input_fields, target_fields)
-     return opt_update(step_index, computed_gradients, optimizer_state), computed_loss, computed_gradients
+ # JIT-compile the loss function (its gradients are taken inside `update` below):
+ loss_function = jit(loss_dualSLM)

- # JIT the loss function:
- loss_fn = jit(loss_dualSLM)
+ # ----------------------------------------------------

- # Optimizer settings
- STEP_SIZE = 0.1
- num_iterations = 50000
- n_best = 500
- best_loss = 1e2
- best_params = None
- best_step = 0
-
- # Init random parameters
- phase_mask_slm1 = jnp.array([np.random.uniform(0, 1, (shape, shape))], dtype=jnp.float64)[0]
- phase_mask_slm2 = jnp.array([np.random.uniform(0, 1, (shape, shape))], dtype=jnp.float64)[0]
- distance_0 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
- distance_1 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
- distance_2 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
- init_params = [distance_0, distance_1, distance_2, phase_mask_slm1, phase_mask_slm2]
+ def fit(params: optax.Params, optimizer: optax.GradientTransformation, num_iterations: int):

- # Define the optimizer and initialize it
- opt_init, opt_update, get_params = optimizers.adam(STEP_SIZE)
- opt_state = opt_init(init_params)
-
- # Optimize in a loop:
- print('Starting Optimization', flush=True)
- tic = time.perf_counter()
+     opt_state = optimizer.init(params)
+
+     @jit
+     def update(params, opt_state, input_fields, target_fields):
+         # Single update step: evaluate the loss and its gradients in one pass
+         loss_value, grads = jax.value_and_grad(loss_function, allow_int=True)(params, input_fields, target_fields)
+         # Compute the parameter updates and advance the optimizer state
+         updates, opt_state = optimizer.update(grads, opt_state, params)
+         params = optax.apply_updates(params, updates)
+         return params, opt_state, loss_value

- for step in range(num_iterations):
+     # Track the step indices and loss values over the run
+     iteration_steps = []
+     loss_list = []
+
+     # Best-loss tracking and early-stopping settings
+     n_best = 500
+     best_loss = 1e2
+     best_params = None
+     best_step = 0

-     # Load data:
-     input_fields, target_fields = next(dataloader)
-     # Perform an update step:
-     opt_state, loss_value, gradients = update(step, opt_state, input_fields, target_fields)
+     print('Starting Optimization', flush=True)

-     # Update the `best_loss` value:
-     if loss_value < best_loss:
-         # Best loss value
-         best_loss = loss_value
-         # Best optimized parameters
-         best_params = get_params(opt_state)
-         best_step = step
-         print('Best loss value is updated')
+     for step in range(num_iterations):
+         # Load data:
+         input_fields, target_fields = next(dataloader)
+         params, opt_state, loss_value = update(params, opt_state, input_fields, target_fields)
+
+         print(f"Step {step}")
+         print(f"Loss {loss_value}")
+
+         iteration_steps.append(step)
+         loss_list.append(loss_value)
+
+         # Update the `best_loss` value:
+         if loss_value < best_loss:
+             # Best loss value
+             best_loss = loss_value
+             # Best optimized parameters
+             best_params = params
+             best_step = step
+             print('Best loss value is updated')
+
+         if step % 100 == 0:
+             # Stopping criterion (checked every 100 steps): stop if `best_loss`
+             # has not improved for more than `n_best` steps.
+             if step - best_step > n_best:
+                 print(f'Stopping criterion: no improvement in loss value for {n_best} steps')
+                 break
+
+     print(f'Best loss: {best_loss} at step {best_step}')
+     print(f'Best parameters: {best_params}')
+     return best_params, best_loss, iteration_steps, loss_list
+
+ # ----------------------------------------------------

-     if step % 500 == 0:
-         # Stopping criterion: stop if `best_loss` has not improved for `n_best` steps.
-         if step - best_step > n_best:
-             print(f'Stopping criterion: no improvement in loss value for {n_best} steps')
-             break
+ # Optimizer settings
+ num_iterations = 50000
+ num_samples = 50
+ # Learning rate and weight decay:
+ STEP_SIZE = 0.01
+ WEIGHT_DECAY = 0.0001

- print(f'Best loss: {best_loss} at step {best_step}')
- print(f'Best parameters: {best_params}')
- print("Time taken to optimize - in seconds", time.perf_counter() - tic)
+ for i in range(num_samples):
+     tic = time.perf_counter()
+
+     # Init random parameters
+     phase_mask_slm1 = jnp.array([np.random.uniform(0, 1, (shape, shape))], dtype=jnp.float64)[0]
+     phase_mask_slm2 = jnp.array([np.random.uniform(0, 1, (shape, shape))], dtype=jnp.float64)[0]
+     distance_0 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
+     distance_1 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
+     distance_2 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
+     init_params = [distance_0, distance_1, distance_2, phase_mask_slm1, phase_mask_slm2]
+
+     # Init optimizer:
+     optimizer = optax.adamw(STEP_SIZE, weight_decay=WEIGHT_DECAY)
+
+     # Apply fit function:
+     best_params, best_loss, iteration_steps, loss_list = fit(init_params, optimizer, num_iterations)
+
+     print("Time taken to optimize one sample - in seconds", time.perf_counter() - tic)
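
For reference, the change above replaces the jax.example_libraries.optimizers Adam loop with optax. A rough side-by-side of the two APIs, written with the names used in this file (a sketch for orientation, not part of the commit):

# Old (jax.example_libraries.optimizers):
#     opt_init, opt_update, get_params = optimizers.adam(STEP_SIZE)
#     opt_state = opt_init(init_params)
#     opt_state = opt_update(step, grads, opt_state)      # grads from grad(loss_fn)
#     params = get_params(opt_state)
#
# New (optax):
#     optimizer = optax.adamw(STEP_SIZE, weight_decay=WEIGHT_DECAY)
#     opt_state = optimizer.init(init_params)
#     updates, opt_state = optimizer.update(grads, opt_state, params)
#     params = optax.apply_updates(params, updates)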
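To exercise the new training loop outside the full repository, the snippet below sketches the interfaces it assumes: the imports, 64-bit mode for the float64 arrays, a `shape` for the phase masks, a data loader that yields `(input_fields, target_fields)` batches, and a scalar `loss_dualSLM(params, input_fields, target_fields)`. The real `MultiHDF5DataLoader` and `loss_dualSLM` live elsewhere in the project; every definition here is a hypothetical stand-in.

import time

import numpy as np
import jax
import jax.numpy as jnp
from jax import jit
import optax

# Assumption: 64-bit mode is enabled somewhere in the project, since the script builds float64 arrays.
jax.config.update("jax_enable_x64", True)

shape = 64  # placeholder resolution for the SLM phase masks


def MultiHDF5DataLoader(path, batch_size=10):
    # Stand-in generator yielding (input_fields, target_fields) batches, as `next(dataloader)` expects.
    rng = np.random.default_rng(0)
    while True:
        fields = rng.normal(size=(batch_size, shape, shape)) + 1j * rng.normal(size=(batch_size, shape, shape))
        yield jnp.asarray(fields), jnp.asarray(fields)


def loss_dualSLM(params, input_fields, target_fields):
    # Stand-in scalar loss with the same call signature as the real one.
    distance_0, distance_1, distance_2, phase_mask_slm1, phase_mask_slm2 = params
    modulated = input_fields * jnp.exp(1j * 2 * jnp.pi * (phase_mask_slm1 + phase_mask_slm2))
    scale = distance_0[0] + distance_1[0] + distance_2[0]
    return jnp.mean(jnp.abs(scale * modulated - target_fields) ** 2)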