Skip to content

Commit 71e9bb5

Browse files
new optimizer for 4f system
1 parent 6543b56 commit 71e9bb5

File tree

1 file changed

+80
-53
lines changed

1 file changed

+80
-53
lines changed

experiments/four_f_optimizer.py

Lines changed: 80 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -17,64 +17,91 @@
1717
# Call the data loader and set batchsize
1818
dataloader = MultiHDF5DataLoader("training_data_4f", batch_size=10)
1919

20-
# Define the update for batch optimization:
21-
@jit
22-
def update(step_index, optimizer_state, input_fields, target_fields):
23-
parameters = get_params(optimizer_state)
24-
# Call the loss function and compute the gradients
25-
computed_loss = loss_fn(parameters, input_fields, target_fields)
26-
computed_gradients = grad(loss_fn, allow_int=True)(parameters, input_fields, target_fields)
27-
return opt_update(step_index, computed_gradients, optimizer_state), computed_loss, computed_gradients
20+
# Define the loss function and compute its gradients:
21+
loss_function = jit(loss_dualSLM)
2822

29-
# JIT the loss function:
30-
loss_fn = jit(loss_dualSLM)
23+
# ----------------------------------------------------
3124

32-
# Optimizer settings
33-
STEP_SIZE = 0.1
34-
num_iterations = 50000
35-
n_best = 500
36-
best_loss = 1e2
37-
best_params = None
38-
best_step = 0
39-
40-
# Init random parameters
41-
phase_mask_slm1 = jnp.array([np.random.uniform(0, 1, (shape, shape))], dtype=jnp.float64)[0]
42-
phase_mask_slm2 = jnp.array([np.random.uniform(0, 1, (shape, shape))], dtype=jnp.float64)[0]
43-
distance_0 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
44-
distance_1 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
45-
distance_2 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
46-
init_params = [distance_0, distance_1, distance_2, phase_mask_slm1, phase_mask_slm2]
25+
def fit(params: optax.Params, optimizer: optax.GradientTransformation, num_iterations) -> optax.Params:
4726

48-
# Define the optimizer and initialize it
49-
opt_init, opt_update, get_params = optimizers.adam(STEP_SIZE)
50-
opt_state = opt_init(init_params)
51-
52-
# Optimize in a loop:
53-
print('Starting Optimization', flush=True)
54-
tic = time.perf_counter()
27+
opt_state = optimizer.init(params)
28+
29+
@jit
30+
def update(params, opt_state, input_fields, target_fields):
31+
# Define single update step:
32+
# JIT the loss and compute
33+
loss_value, grads = jax.value_and_grad(loss_function, allow_int=True)(params, input_fields, target_fields)
34+
# Update the state of the optimizer
35+
updates, opt_state = optimizer.update(grads, opt_state, params)
36+
params = optax.apply_updates(params, updates)
37+
return params, opt_state, loss_value
5538

56-
for step in range(num_iterations):
39+
# Initialize some parameters
40+
iteration_steps=[]
41+
loss_list=[]
42+
43+
# Optimizer settings
44+
n_best = 500
45+
best_loss = 1e2
46+
best_params = None
47+
best_step = 0
5748

58-
# Load data:
59-
input_fields, target_fields = next(dataloader)
60-
# Perform an update step:
61-
opt_state, loss_value, gradients = update(step, opt_state, input_fields, target_fields)
49+
print('Starting Optimization', flush=True)
6250

63-
# Update the `best_loss` value:
64-
if loss_value < best_loss:
65-
# Best loss value
66-
best_loss = loss_value
67-
# Best optimized parameters
68-
best_params = get_params(opt_state)
69-
best_step = step
70-
print('Best loss value is updated')
51+
for step in range(num_iterations):
52+
# Load data:
53+
input_fields, target_fields = next(dataloader)
54+
params, opt_state, loss_value = update(params, opt_state, input_fields, target_fields)
55+
56+
print(f"Step {step}")
57+
print(f"Loss {loss_value}")
58+
59+
iteration_steps.append(step)
60+
loss_list.append(loss_value)
61+
62+
# Update the `best_loss` value:
63+
if loss_value < best_loss:
64+
# Best loss value
65+
best_loss = loss_value
66+
# Best optimized parameters
67+
best_params = params
68+
best_step = step
69+
print('Best loss value is updated')
70+
71+
if step % 100 == 0:
72+
# Stopping criteria: if best_loss has not changed every 500 steps, stop.
73+
if step - best_step > n_best:
74+
print(f'Stopping criterion: no improvement in loss value for {n_best} steps')
75+
break
76+
77+
print(f'Best loss: {best_loss} at step {best_step}')
78+
print(f'Best parameters: {best_params}')
79+
return best_params, best_loss, iteration_steps, loss_list
80+
81+
# ----------------------------------------------------
7182

72-
if step % 500 == 0:
73-
# Stopping criteria: if best_loss has not changed every 500 steps, stop.
74-
if step - best_step > n_best:
75-
print(f'Stopping criterion: no improvement in loss value for {n_best} steps')
76-
break
83+
# Optimizer settings
84+
num_iterations = 50000
85+
num_samples = 50
86+
# Step size engineering:
87+
STEP_SIZE = 0.01
88+
WEIGHT_DECAY = 0.0001
7789

78-
print(f'Best loss: {best_loss} at step {best_step}')
79-
print(f'Best parameters: {best_params}')
80-
print("Time taken to optimize - in seconds", time.perf_counter() - tic)
90+
for i in range(num_samples):
91+
tic = time.perf_counter()
92+
93+
# Init random parameters
94+
phase_mask_slm1 = jnp.array([np.random.uniform(0, 1, (shape, shape))], dtype=jnp.float64)[0]
95+
phase_mask_slm2 = jnp.array([np.random.uniform(0, 1, (shape, shape))], dtype=jnp.float64)[0]
96+
distance_0 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
97+
distance_1 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
98+
distance_2 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
99+
init_params = [distance_0, distance_1, distance_2, phase_mask_slm1, phase_mask_slm2]
100+
101+
# Init optimizer:
102+
optimizer = optax.adamw(STEP_SIZE, weight_decay=WEIGHT_DECAY)
103+
104+
# Apply fit function:
105+
best_params, best_loss, iteration_steps, loss_list = fit(init_params, optimizer, num_iterations)
106+
107+
print("Time taken to optimize one sample - in seconds", time.perf_counter() - tic)

0 commit comments

Comments
 (0)