Skip to content

Commit e9a66de

Browse files
Added BFGS_tjm.py (work in progress)
1 parent e29a9e4 commit e9a66de

File tree

4 files changed

+386
-148
lines changed

4 files changed

+386
-148
lines changed

src/yaqs/noise_char/BFGS_tjm.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
import qutip as qt
4+
import matplotlib.ticker as ticker
5+
6+
from yaqs.core.data_structures.networks import MPO, MPS
7+
from yaqs.core.data_structures.noise_model import NoiseModel
8+
from yaqs.core.data_structures.simulation_parameters import Observable, PhysicsSimParams
9+
10+
from yaqs import Simulator
11+
12+
13+
import time
14+
15+
import importlib
16+
import yaqs
17+
18+
from yaqs.noise_char.optimization import *
19+
# from yaqs.noise_char.propagation import *
20+
from yaqs.noise_char.analytical_gradient_tjm import *
21+
22+
importlib.reload(yaqs.noise_char.optimization)
23+
importlib.reload(yaqs.noise_char.propagation)
24+
25+
26+
27+
28+
29+
def BFGS_char(state, H_0, sim_params, noise_model, ref_traj, traj_der, learning_rate=0.01, max_iterations=200, tolerance=1e-8):
30+
"""
31+
Parameters:
32+
sim_params (object): Simulation parameters containing gamma_rel and gamma_deph.
33+
ref_traj (array-like): Reference trajectory data.
34+
traj_der (function): Function that runs the simulation and returns the time,
35+
expected values trajectory, and derivatives of the observables
36+
with respect to the noise parameters.
37+
learning_rate (float, optional): Learning rate for the BFGS optimizer. Default is 0.01.
38+
max_iterations (int, optional): Maximum number of iterations for the optimization. Default is 200.
39+
tolerance (float, optional): Tolerance for the convergence criterion. Default is 1e-8.
40+
Returns:
41+
tuple: A tuple containing:
42+
- loss_history (list): History of loss values during optimization.
43+
- gr_history (list): History of gamma_rel values during optimization.
44+
- gd_history (list): History of gamma_deph values during optimization.
45+
- dJ_dgr_history (list): History of gradients with respect to gamma_rel.
46+
- dJ_dgd_history (list): History of gradients with respect to gamma_deph.
47+
48+
Performs BFGS optimization to minimize the loss function.
49+
"""
50+
loss_history = []
51+
gr_history = []
52+
gd_history = []
53+
dJ_dgr_history = []
54+
dJ_dgd_history = []
55+
56+
gr_history.append(noise_model.strengths[0])
57+
gd_history.append(noise_model.strengths[1])
58+
59+
# Initial parameters
60+
params_old = np.array([noise_model.strengths[0], noise_model.strengths[1]])
61+
n_params = len(params_old)
62+
63+
# Initial inverse Hessian approximation
64+
H_inv = np.eye(n_params)
65+
66+
I = np.eye(n_params)
67+
68+
69+
# Calculate first loss and gradients
70+
loss, exp_vals_traj, grad_old = loss_function_char(state, H_0, sim_params, noise_model, ref_traj, traj_der)
71+
loss_history.append(loss)
72+
73+
74+
75+
for iteration in range(max_iterations):
76+
77+
# Store current parameters and gradients
78+
# params_old = params.copy()
79+
# grad_old = dJ_dg.copy()
80+
81+
# Update parameters
82+
params_new = params_old - learning_rate * H_inv.dot(grad_old)
83+
84+
for i in range(n_params):
85+
if params_new[i] < 0:
86+
params_new[i] = 0
87+
88+
# Update simulation parameters
89+
sim_params.gamma_rel, sim_params.gamma_deph = params_new
90+
91+
# Calculate new loss and gradients
92+
loss, exp_vals_traj, grad_new = loss_function_char(state, H_0, sim_params, noise_model, ref_traj, traj_der)
93+
loss_history.append(loss)
94+
95+
if loss < tolerance:
96+
print(f"Converged after {iteration} iterations.")
97+
break
98+
99+
100+
# Compute differences
101+
s = params_new - params_old
102+
y = grad_new - grad_old
103+
104+
# Update inverse Hessian approximation using BFGS formula
105+
rho = 1.0 / (y.dot(s))
106+
107+
H_inv = (I - rho * np.outer(s, y)).dot(H_inv).dot(I - rho * np.outer(y, s)) + rho * np.outer(s, s)
108+
109+
# Log history
110+
dJ_dgr_history.append(grad_new[0])
111+
dJ_dgd_history.append(grad_new[1])
112+
gr_history.append(noise_model.strengths[0])
113+
gd_history.append(noise_model.strengths[1])
114+
115+
116+
params_old = params_new
117+
grad_old = grad_new
118+
119+
print(f"Iteration {iteration}: Loss = {loss}")
120+
121+
return loss_history, gr_history, gd_history, dJ_dgr_history, dJ_dgd_history
122+
123+
124+
def loss_function_char(state, H_0, sim_params, noise_model, ref_traj, traj_der):
125+
"""
126+
Compute the loss function and its gradients for the given simulation parameters.
127+
Parameters:
128+
sim_params (dict): Dictionary containing the simulation parameters.
129+
ref_traj (list): List of reference trajectories for comparison.
130+
traj_der (function): Function that runs the simulation and returns the time,
131+
expected values trajectory, and derivatives of the observables
132+
with respect to the noise parameters.
133+
Returns:
134+
tuple: A tuple containing:
135+
- loss (float): The computed loss value.
136+
- exp_vals_traj (list): The expected values trajectory from the TJM simulation.
137+
- gradients (numpy.ndarray): Array containing the gradients of the loss with respect
138+
to gamma_relaxation and gamma_dephasing.
139+
"""
140+
141+
142+
# Run the TJM simulation with the given noise parameters
143+
144+
start_time = time.time()
145+
146+
traj_der(state, H_0, sim_params, noise_model)
147+
148+
t = sim_params.times
149+
exp_vals_traj = []
150+
for observable in sim_params.observables:
151+
exp_vals_traj.append(observable.results)
152+
d_On_d_gk = sim_params.d_On_d_gk
153+
154+
155+
end_time = time.time()
156+
tjm_time = end_time - start_time
157+
# print(f"TJM time -> {tjm_time:.4f}")
158+
159+
# Initialize loss
160+
loss = 0.0
161+
162+
# Ensure both lists have the same structure
163+
if len(ref_traj) != len(exp_vals_traj):
164+
raise ValueError("Mismatch in the number of sites between qt_exp_vals and tjm_exp_vals.")
165+
166+
# Compute squared distance for each site
167+
for ref_vals, tjm_vals in zip(ref_traj, exp_vals_traj):
168+
loss += np.sum((np.array(ref_vals) - np.array(tjm_vals)) ** 2)
169+
170+
171+
n_jump = len(d_On_d_gk)
172+
n_obs = len(d_On_d_gk[0])
173+
n_t = len(d_On_d_gk[0][0])
174+
175+
n_gr = n_jump//2
176+
177+
178+
dJ_d_gr = 0
179+
dJ_d_gd = 0
180+
181+
182+
for i in range(n_obs):
183+
for j in range(n_t):
184+
# I have to add all the derivatives with respect to the same gamma_relaxation and gamma_dephasing
185+
for k in range(n_gr):
186+
# The initial half of the jump operators are relaxation operators
187+
dJ_d_gr += 2*(exp_vals_traj[i][j] - ref_traj[i][j]) * d_On_d_gk[k][i][j]
188+
# The second half of the jump operators are dephasing operators
189+
dJ_d_gd += 2*(exp_vals_traj[i][j] - ref_traj[i][j]) * d_On_d_gk[n_gr + k][i][j]
190+
191+
192+
193+
194+
return loss, exp_vals_traj, np.array([dJ_d_gr, dJ_d_gd])
195+
196+
if __name__ == '__main__':
197+
198+
# @dataclass
199+
# class SimulationParameters:
200+
# T: float = 1
201+
# dt: float = 0.1
202+
# L: int = 2
203+
# J: float = 1
204+
# g: float = 0.5
205+
# gamma_rel: float = 0.1
206+
# gamma_deph: float = 0.1
207+
208+
209+
L = 4
210+
d = 2
211+
J = 1
212+
g = 0.5
213+
H_0 = MPO()
214+
H_0.init_Ising(L, d, J, g)
215+
216+
# Define the initial state
217+
state = MPS(L, state='zeros')
218+
219+
# Define the noise model
220+
gamma = 0.1
221+
noise_model = NoiseModel(['relaxation', 'dephasing'], [gamma, gamma])
222+
223+
224+
# Define the simulation parameters
225+
T = 5
226+
dt = 0.1
227+
sample_timesteps = True
228+
N = 100
229+
max_bond_dim = 4
230+
threshold = 1e-6
231+
order = 1
232+
measurements = [Observable('x', site) for site in range(L)] + [Observable('y', site) for site in range(L)] + [Observable('z', site) for site in range(L)]
233+
initial_params = PhysicsSimParams(measurements, T, dt, sample_timesteps, N, max_bond_dim, threshold, order)
234+
235+
236+
'''QUTIP calculation'''
237+
238+
qt_params = SimulationParameters()
239+
240+
qt_params.T = T
241+
qt_params.dt = dt
242+
qt_params.L = L
243+
qt_params.J = J
244+
qt_params.g = g
245+
qt_params.gamma_rel = gamma
246+
qt_params.gamma_deph = gamma
247+
qt_params.observables = ['x','y', 'z']
248+
249+
250+
# Generate reference trajectory
251+
sim_params = SimulationParameters()
252+
253+
# t, qt_ref_traj,dO, A_kn_exp_vals=qutip_traj(sim_params)
254+
t, qt_ref_traj,dO = qutip_traj_char(qt_params)
255+
256+
257+
258+
loss_history, gr_history, gd_history, dJ_dgr_history, dJ_dgd_history = BFGS_char(state, H_0, initial_params, noise_model, qt_ref_traj, run_char, learning_rate=0.2, max_iterations=10,tolerance=1e-8)
259+
260+
261+
262+
plt.plot(np.log(loss_history), label='log(J)')
263+
plt.legend()
264+
265+
266+
def exp_formatter(x, pos):
267+
return f"{np.exp(x):.2e}"
268+
269+
ax = plt.gca()
270+
ax.yaxis.set_major_formatter(ticker.FuncFormatter(exp_formatter))
271+
plt.ylabel('Loss J (exponentiated)')
272+
273+
plt.show()
274+
275+

src/yaqs/noise_char/analytical_gradient_tjm.py

Lines changed: 6 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737

3838
def qutip_traj_char(sim_params_class: SimulationParameters):
3939

40+
print('hello')
41+
4042
T = sim_params_class.T
4143
dt = sim_params_class.dt
4244
L = sim_params_class.L
@@ -179,6 +181,7 @@ def qutip_traj_char(sim_params_class: SimulationParameters):
179181
# d_On_d_gk = [ [trapezoidal(A_kn_exp_vals[i][j],t) for j in range(n_obs)] for i in range(n_jump) ]
180182

181183
# return t, original_exp_vals, d_On_d_gk, A_kn_exp_vals
184+
print('hello')
182185
return t, original_exp_vals, d_On_d_gk
183186

184187

@@ -404,7 +407,7 @@ def PhysicsTJM_1_analytical_gradient(args):
404407
T = 5
405408
dt = 0.1
406409
sample_timesteps = True
407-
N = 500
410+
N = 100
408411
max_bond_dim = 4
409412
threshold = 1e-6
410413
order = 1
@@ -428,10 +431,6 @@ def PhysicsTJM_1_analytical_gradient(args):
428431
t, qt_ref_traj, d_On_d_gk_qt =qutip_traj_char(qt_params)
429432

430433

431-
432-
433-
434-
435434
########## TJM Example #################
436435
run_char(state, H_0, sim_params, noise_model)
437436

@@ -441,47 +440,10 @@ def PhysicsTJM_1_analytical_gradient(args):
441440

442441

443442

444-
445-
# '''Restructure Qutip A_kn means into same structure as TJM A_kn means:'''
446-
447-
# n_sites = len(qt_A_kn_exp_vals)
448-
# n_types = len(qt_params.observables)
449-
# n_noise = len(noise_model.processes)
450-
# n_Akn_per_site = n_noise * n_types
451-
452-
# # Create a new dictionary to hold the Qutip data in the same structure as sim_params.avg_expvals.
453-
# qt_avg_dict = {}
454-
455-
# for site in range(n_sites):
456-
# for type_index in range(n_types):
457-
# key = (qt_params.observables[type_index], site)
458-
# qt_avg_dict[key] = {}
459-
# for noise_index in range(n_noise):
460-
# process = noise_model.processes[noise_index]
461-
# # Calculate the index within the sublist for the given noise process and observable type.
462-
# idx = noise_index * n_types + type_index
463-
# qt_avg_dict[key][process] = qt_A_kn_exp_vals[site][idx]
464-
465-
# # Print out the structure for inspection.
466-
# print("Structure of qt_avg_dict:")
467-
# for key, proc_dict in qt_avg_dict.items():
468-
# print(f"Key (Observable, site): {key}")
469-
# for process, arr in proc_dict.items():
470-
# print(f" Process: {process}, array shape: {np.shape(arr)}")
471-
472-
# '''Structure of TJM and Qutip A_kn means is equal now.'''
473-
474-
475443
# Convert both to numpy arrays
476444
array1 = np.array(sim_params.d_On_d_gk)
477445
array2 = np.array(d_On_d_gk_qt)
478446

479-
# Use np.allclose to compare with a tolerance for floating point differences
480-
if np.allclose(array1, array2, rtol=1e-2, atol=1e-2):
481-
print("sim_params.d_On_d_gk and d_On_d_gk_qt are the same!")
482-
else:
483-
print("They are different.")
484-
485447
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(12, 10))
486448

487449
# Loop over sites (assuming L is the number of sites)
@@ -501,12 +463,12 @@ def PhysicsTJM_1_analytical_gradient(args):
501463
ax1.set_title("TJM d_On_d_gk")
502464
ax1.set_xlabel("Time index (0-50)")
503465
ax1.set_ylabel("Integrated Value")
504-
ax1.legend()
466+
# ax1.legend()
505467

506468
ax2.set_title("Qutip d_On_d_gk")
507469
ax2.set_xlabel("Time index (0-50)")
508470
ax2.set_ylabel("Integrated Value")
509-
ax2.legend()
471+
# ax2.legend()
510472

511473
plt.tight_layout()
512474
plt.show()

0 commit comments

Comments
 (0)