1515from typing import Any , Callable , Dict , List , Tuple
1616from collections .abc import Iterable
1717import scipy .optimize as sci_opt
18- import matplotlib
19- from matplotlib import pyplot as plt
2018
2119
2220class Algo (Enum ):
@@ -157,7 +155,8 @@ def potential(self, x):
157155 """
158156 Potential function V(x) = x^2 + l*x^4
159157 """
160- return x ** 2 + self .l * x ** 4
158+ x_sq = x ** 2
159+ return x_sq + (self .l * x_sq * x_sq )
161160
162161 def analytic_energy (n ):
163162 """
@@ -166,16 +165,15 @@ def analytic_energy(n):
166165 return np .sqrt (2.0 ) * (n - 0.5 )
167166
168167 def w_gen (self , energy ):
169- return lambda x : np .sqrt (2 * self .m * (complex (energy ) - self .potential (x )))
168+ return lambda x : np .sqrt (2.0 * self .m * (complex (energy ) - self .potential (x )))
170169
171170 def g_gen (self ):
172171 return lambda x : np .zeros_like (x )
173172
174173 def f_gen (self , energy ):
175174 def f (x , y ):
176175 psi , dpsi = y
177- return [dpsi , - 2 * self .m * (complex (energy ) - self .potential (x )) * psi ]
178-
176+ return [dpsi , (self .potential (x ) - energy ) * psi ]
179177 return f
180178
181179 def yi_init (self ):
@@ -297,7 +295,7 @@ def flatten_tuple(x):
297295 base_output_path = "./benchmarks/output/"
298296all_algo_pl_lst : List [pl .DataFrame ] = []
299297first_write = True
300- with open (base_output_path + "schrodinger_times2 .csv" , mode = "a" ) as time_file :
298+ with open (base_output_path + "schrodinger_times .csv" , mode = "a" ) as time_file :
301299 for algo , algo_params in algorithm_dict .items ():
302300 algo_evals_pl_lst = []
303301 for benchmark_run in range (1 ):
@@ -339,7 +337,7 @@ def flatten_tuple(x):
339337 algo_evals_pl_lst .append (algo_pl_tmp )
340338 algo_pl = pl .concat (algo_evals_pl_lst )
341339 print (algo_pl )
342- algo_pl .write_csv (base_output_path + f"schrod2_ { str (algo )} .csv" )
340+ algo_pl .write_csv (base_output_path + f"schrod_ { str (algo )} .csv" )
343341 all_algo_pl_lst .append (algo_pl )
344342 time_pl_lst = []
345343 for algo_key , time_st in global_timer .execs .items ():
@@ -361,7 +359,7 @@ def flatten_tuple(x):
361359 time_pl .write_csv (time_file , include_header = False )
362360
363361all_algo_pl = pl .concat (all_algo_pl_lst )
364- all_algo_pl .write_csv (f"{ base_output_path } schrod2 .csv" )
362+ all_algo_pl .write_csv (f"{ base_output_path } schrod .csv" )
365363# %%
366364# %%
367365time_pl_lst = []
@@ -376,112 +374,3 @@ def flatten_tuple(x):
376374time_pl .write_csv (base_output_path + "schrodinger_times2.csv" )
377375
378376
379- # %%
380- if False :
381- ns = [50 , 100 ]
382- energies = solution_lst [:2 ]
383-
384- x_plot = np .linspace (- 6 , 6 , 500 )
385- plt .figure (figsize = (10 , 5 ))
386- plt .plot (x_plot , V (x_plot ), color = "black" , label = "V(x)" )
387-
388- default_init_step = 1e-12
389-
390- for j , (n , current_energy ) in enumerate (zip (ns , energies )):
391- # Boundaries of integration
392- left_boundary = - ((current_energy ) ** 0.25 ) - 1.0
393- right_boundary = - left_boundary
394- midpoint = 0.0
395- chebyshev_order = 32
396-
397- # Initialize Riccati solver
398- riccati_info = ric .Init (
399- w_gen (current_energy ),
400- g ,
401- 8 ,
402- max (32 , chebyshev_order ),
403- chebyshev_order ,
404- chebyshev_order ,
405- )
406- # Tolerances
407- eps = 1e-12
408- eps_h = eps * 1e-1
409- # First integration range
410- first_range = (left_boundary , right_boundary / 2.0 )
411- init_step = ric .choose_nonosc_stepsize (riccati_info , * first_range , eps_h )
412- if init_step == 0 :
413- init_step = default_init_step
414- print ("iteration:" , j )
415- print ("quantum_number:" , n )
416- print ("left_boundary:" , left_boundary )
417- print ("right_boundary:" , right_boundary )
418- print ("midpoint:" , midpoint )
419- print ("current_energy:" , current_energy )
420- print ("init_step:" , init_step )
421- # Solve from left_boundary up to right_boundary/2
422- full_range = (left_boundary , right_boundary )
423- x_values = np .linspace (* full_range , 50_000 )
424- first_slice = x_values [x_values <= (right_boundary / 2.0 )]
425- left_solution = ric .evolve (
426- riccati_info ,
427- * first_range ,
428- complex (0 ),
429- complex (1e-8 ),
430- eps ,
431- eps_h ,
432- init_step ,
433- first_slice ,
434- True ,
435- )
436- left_times = left_solution [0 ]
437- left_wavefunction = left_solution [6 ]
438- left_step_types = left_solution [5 ]
439- # Print debug info
440- for i_val in range (len (left_solution )):
441- print ("i:" , i_val , "\t " , left_solution [i_val ])
442- # Find first Riccati index
443- first_riccati_index = len (left_step_types ) - 1
444- for idx , step_type in enumerate (left_step_types ):
445- if step_type == 1 and 0 not in left_step_types [idx :]:
446- first_riccati_index = idx
447- break
448- print ("first_riccati_index:" , first_riccati_index )
449- print ("range:" , (left_times [first_riccati_index ], midpoint ))
450- # Solve from right_boundary back to right_boundary/2 (or full range, whichever you need)
451- init_step = ric .choose_nonosc_stepsize (riccati_info , * full_range , eps_h )
452- if init_step == 0 :
453- init_step = default_init_step
454- if full_range [0 ] > full_range [1 ]:
455- init_step = - init_step
456- print ("init_step:" , init_step )
457- second_slice = x_values [x_values >= (right_boundary / 2.0 )]
458- right_solution = ric .evolve (
459- riccati_info ,
460- * full_range ,
461- complex (0 ),
462- complex (1e-8 ),
463- eps ,
464- eps_h ,
465- init_step ,
466- second_slice ,
467- True ,
468- )
469- right_wavefunction = right_solution [6 ]
470- # Combine left and right solutions for plotting
471- combined_wavefunction = np .concatenate ((left_wavefunction , right_wavefunction ))
472- max_val = np .max (np .real (combined_wavefunction ))
473- scaled_wavefunction = (
474- combined_wavefunction / max_val * 4.0 * np .sqrt (current_energy )
475- )
476- plt .plot (
477- x_values ,
478- scaled_wavefunction + current_energy ,
479- color = f"C{ j } " ,
480- label = f"$\\ Psi_n(x)$, n={ n } , $E_n$={ current_energy :.4f} " ,
481- )
482-
483- plt .xlabel ("x" )
484- plt .legend (loc = "lower left" )
485- plt .show ()
486-
487- # %%
0 commit comments