Skip to content

Commit 3890d51

Browse files
authored
Merge pull request #10 from SteveBronder/fix/complex-omega-fun
Allow complex omega and gamma return values
2 parents 46068fb + be28071 commit 3890d51

26 files changed

+1468
-1057
lines changed

CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ project(
55
LANGUAGES C CXX)
66

77
include(FetchContent)
8-
set(CMAKE_CXX_STANDARD 20)
8+
set(CMAKE_CXX_STANDARD 17)
99
set(CMAKE_CXX_STANDARD_REQUIRED NO)
1010
cmake_policy(SET CMP0135 NEW)
1111
cmake_policy(SET CMP0077 NEW)
1212
set(CMAKE_CXX_EXTENSIONS NO)
13-
if (CMAKE_BUILD_TYPE MATCHES Debug)
13+
if (CMAKE_BUILD_TYPE MATCHES DEBUG)
1414
set(CMAKE_VERBOSE_MAKEFILE YES)
1515
endif()
1616
option(RICCATI_BUILD_TESTING "Build the test targets for the library" OFF)
@@ -69,7 +69,7 @@ endif()
6969

7070
if (NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
7171
set(CMAKE_CXX_FLAGS_RELEASE
72-
"-O3 -march=native -mtune=native -DRICCATI_DEBUG=false"
72+
"-O3 -march=native -mtune=native"
7373
CACHE STRING "Flags used by the C++ compiler during Release builds."
7474
FORCE)
7575
endif()

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ make -j4 riccati_test && ctest
9696
For the python tests and benchmarks it is recommended to setup a virtual environment
9797

9898
```bash
99-
git checkout feature/benchmarks
10099
# From the top level of this directory
101100
python -m venv ./.venv
102101
source ./.venv/bin/activate

benchmarks/schrodinger_eq.py

Lines changed: 7 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
from typing import Any, Callable, Dict, List, Tuple
1616
from collections.abc import Iterable
1717
import scipy.optimize as sci_opt
18-
import matplotlib
19-
from matplotlib import pyplot as plt
2018

2119

2220
class Algo(Enum):
@@ -157,7 +155,8 @@ def potential(self, x):
157155
"""
158156
Potential function V(x) = x^2 + l*x^4
159157
"""
160-
return x**2 + self.l * x**4
158+
x_sq = x**2
159+
return x_sq + (self.l * x_sq * x_sq)
161160

162161
def analytic_energy(n):
163162
"""
@@ -166,16 +165,15 @@ def analytic_energy(n):
166165
return np.sqrt(2.0) * (n - 0.5)
167166

168167
def w_gen(self, energy):
169-
return lambda x: np.sqrt(2 * self.m * (complex(energy) - self.potential(x)))
168+
return lambda x: np.sqrt(2.0 * self.m * (complex(energy) - self.potential(x)))
170169

171170
def g_gen(self):
172171
return lambda x: np.zeros_like(x)
173172

174173
def f_gen(self, energy):
175174
def f(x, y):
176175
psi, dpsi = y
177-
return [dpsi, -2 * self.m * (complex(energy) - self.potential(x)) * psi]
178-
176+
return [dpsi, (self.potential(x) - energy) * psi]
179177
return f
180178

181179
def yi_init(self):
@@ -297,7 +295,7 @@ def flatten_tuple(x):
297295
base_output_path = "./benchmarks/output/"
298296
all_algo_pl_lst: List[pl.DataFrame] = []
299297
first_write = True
300-
with open(base_output_path + "schrodinger_times2.csv", mode="a") as time_file:
298+
with open(base_output_path + "schrodinger_times.csv", mode="a") as time_file:
301299
for algo, algo_params in algorithm_dict.items():
302300
algo_evals_pl_lst = []
303301
for benchmark_run in range(1):
@@ -339,7 +337,7 @@ def flatten_tuple(x):
339337
algo_evals_pl_lst.append(algo_pl_tmp)
340338
algo_pl = pl.concat(algo_evals_pl_lst)
341339
print(algo_pl)
342-
algo_pl.write_csv(base_output_path + f"schrod2_{str(algo)}.csv")
340+
algo_pl.write_csv(base_output_path + f"schrod_{str(algo)}.csv")
343341
all_algo_pl_lst.append(algo_pl)
344342
time_pl_lst = []
345343
for algo_key, time_st in global_timer.execs.items():
@@ -361,7 +359,7 @@ def flatten_tuple(x):
361359
time_pl.write_csv(time_file, include_header=False)
362360

363361
all_algo_pl = pl.concat(all_algo_pl_lst)
364-
all_algo_pl.write_csv(f"{base_output_path}schrod2.csv")
362+
all_algo_pl.write_csv(f"{base_output_path}schrod.csv")
365363
# %%
366364
# %%
367365
time_pl_lst = []
@@ -376,112 +374,3 @@ def flatten_tuple(x):
376374
time_pl.write_csv(base_output_path + "schrodinger_times2.csv")
377375

378376

379-
# %%
380-
if False:
381-
ns = [50, 100]
382-
energies = solution_lst[:2]
383-
384-
x_plot = np.linspace(-6, 6, 500)
385-
plt.figure(figsize=(10, 5))
386-
plt.plot(x_plot, V(x_plot), color="black", label="V(x)")
387-
388-
default_init_step = 1e-12
389-
390-
for j, (n, current_energy) in enumerate(zip(ns, energies)):
391-
# Boundaries of integration
392-
left_boundary = -((current_energy) ** 0.25) - 1.0
393-
right_boundary = -left_boundary
394-
midpoint = 0.0
395-
chebyshev_order = 32
396-
397-
# Initialize Riccati solver
398-
riccati_info = ric.Init(
399-
w_gen(current_energy),
400-
g,
401-
8,
402-
max(32, chebyshev_order),
403-
chebyshev_order,
404-
chebyshev_order,
405-
)
406-
# Tolerances
407-
eps = 1e-12
408-
eps_h = eps * 1e-1
409-
# First integration range
410-
first_range = (left_boundary, right_boundary / 2.0)
411-
init_step = ric.choose_nonosc_stepsize(riccati_info, *first_range, eps_h)
412-
if init_step == 0:
413-
init_step = default_init_step
414-
print("iteration:", j)
415-
print("quantum_number:", n)
416-
print("left_boundary:", left_boundary)
417-
print("right_boundary:", right_boundary)
418-
print("midpoint:", midpoint)
419-
print("current_energy:", current_energy)
420-
print("init_step:", init_step)
421-
# Solve from left_boundary up to right_boundary/2
422-
full_range = (left_boundary, right_boundary)
423-
x_values = np.linspace(*full_range, 50_000)
424-
first_slice = x_values[x_values <= (right_boundary / 2.0)]
425-
left_solution = ric.evolve(
426-
riccati_info,
427-
*first_range,
428-
complex(0),
429-
complex(1e-8),
430-
eps,
431-
eps_h,
432-
init_step,
433-
first_slice,
434-
True,
435-
)
436-
left_times = left_solution[0]
437-
left_wavefunction = left_solution[6]
438-
left_step_types = left_solution[5]
439-
# Print debug info
440-
for i_val in range(len(left_solution)):
441-
print("i:", i_val, "\t", left_solution[i_val])
442-
# Find first Riccati index
443-
first_riccati_index = len(left_step_types) - 1
444-
for idx, step_type in enumerate(left_step_types):
445-
if step_type == 1 and 0 not in left_step_types[idx:]:
446-
first_riccati_index = idx
447-
break
448-
print("first_riccati_index:", first_riccati_index)
449-
print("range:", (left_times[first_riccati_index], midpoint))
450-
# Solve from right_boundary back to right_boundary/2 (or full range, whichever you need)
451-
init_step = ric.choose_nonosc_stepsize(riccati_info, *full_range, eps_h)
452-
if init_step == 0:
453-
init_step = default_init_step
454-
if full_range[0] > full_range[1]:
455-
init_step = -init_step
456-
print("init_step:", init_step)
457-
second_slice = x_values[x_values >= (right_boundary / 2.0)]
458-
right_solution = ric.evolve(
459-
riccati_info,
460-
*full_range,
461-
complex(0),
462-
complex(1e-8),
463-
eps,
464-
eps_h,
465-
init_step,
466-
second_slice,
467-
True,
468-
)
469-
right_wavefunction = right_solution[6]
470-
# Combine left and right solutions for plotting
471-
combined_wavefunction = np.concatenate((left_wavefunction, right_wavefunction))
472-
max_val = np.max(np.real(combined_wavefunction))
473-
scaled_wavefunction = (
474-
combined_wavefunction / max_val * 4.0 * np.sqrt(current_energy)
475-
)
476-
plt.plot(
477-
x_values,
478-
scaled_wavefunction + current_energy,
479-
color=f"C{j}",
480-
label=f"$\\Psi_n(x)$, n={n}, $E_n$={current_energy:.4f}",
481-
)
482-
483-
plt.xlabel("x")
484-
plt.legend(loc="lower left")
485-
plt.show()
486-
487-
# %%

include/riccati/arena_matrix.hpp

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ class arena_matrix : public Eigen::Map<MatrixType> {
6060
* @param allocator The allocator to receive memory from
6161
* @param size number of elements
6262
*/
63-
template <typename T>
64-
arena_matrix(arena_allocator<T, arena_alloc>& allocator, Eigen::Index size)
63+
template <typename T, typename Int,
64+
std::enable_if_t<std::is_integral_v<Int>>* = nullptr>
65+
arena_matrix(arena_allocator<T, arena_alloc>& allocator, Int size)
6566
: Base::Map(allocator.template allocate<Scalar>(size), size),
6667
allocator_(allocator) {}
6768

@@ -70,19 +71,19 @@ class arena_matrix : public Eigen::Map<MatrixType> {
7071
* @param allocator The allocator to receive memory from
7172
* @param other expression
7273
*/
73-
template <typename T, typename Expr>
74+
template <typename T, typename Expr, require_eigen<Expr>* = nullptr>
7475
arena_matrix(arena_allocator<T, arena_alloc>& allocator,
7576
const Expr& other) // NOLINT
7677
: Base::Map(
77-
allocator.template allocate<Scalar>(other.size()),
78-
(RowsAtCompileTime == 1 && Expr::ColsAtCompileTime == 1)
79-
|| (ColsAtCompileTime == 1 && Expr::RowsAtCompileTime == 1)
80-
? other.cols()
81-
: other.rows(),
82-
(RowsAtCompileTime == 1 && Expr::ColsAtCompileTime == 1)
83-
|| (ColsAtCompileTime == 1 && Expr::RowsAtCompileTime == 1)
84-
? other.rows()
85-
: other.cols()),
78+
allocator.template allocate<Scalar>(other.size()),
79+
(RowsAtCompileTime == 1 && Expr::ColsAtCompileTime == 1)
80+
|| (ColsAtCompileTime == 1 && Expr::RowsAtCompileTime == 1)
81+
? other.cols()
82+
: other.rows(),
83+
(RowsAtCompileTime == 1 && Expr::ColsAtCompileTime == 1)
84+
|| (ColsAtCompileTime == 1 && Expr::RowsAtCompileTime == 1)
85+
? other.rows()
86+
: other.cols()),
8687
allocator_(allocator) {
8788
allocator_.owns_alloc_ = false;
8889
(*this).noalias() = other;
@@ -168,18 +169,33 @@ inline auto to_arena(dummy_allocator& arena, const Expr& expr) noexcept {
168169
return eval(expr);
169170
}
170171

171-
template <typename T, typename Expr>
172-
inline auto empty_arena_matrix(arena_allocator<T, arena_alloc>& alloc, Expr&& expr) {
172+
template <typename Expr, typename T>
173+
inline auto empty_arena_matrix(arena_allocator<T, arena_alloc>& alloc,
174+
Expr&& expr) {
173175
using plain_type_t = typename std::decay_t<Expr>::PlainObject;
174176
return arena_matrix<plain_type_t>(alloc, expr.rows(), expr.cols());
175177
}
176178

177-
template <typename T, typename Expr>
179+
template <typename Expr>
178180
inline auto empty_arena_matrix(dummy_allocator& arena, Expr&& expr) {
179181
using plain_type_t = typename std::decay_t<Expr>::PlainObject;
180182
return plain_type_t(expr.rows(), expr.cols());
181183
}
182184

185+
template <typename Expr, typename T>
186+
inline auto empty_arena_matrix(arena_allocator<T, arena_alloc>& alloc,
187+
Eigen::Index rows, Eigen::Index cols) {
188+
using plain_type_t = typename std::decay_t<Expr>::PlainObject;
189+
return arena_matrix<plain_type_t>(alloc, rows, cols);
190+
}
191+
192+
template <typename Expr>
193+
inline auto empty_arena_matrix(dummy_allocator& arena, Eigen::Index rows,
194+
Eigen::Index cols) {
195+
using plain_type_t = typename std::decay_t<Expr>::PlainObject;
196+
return plain_type_t(rows, cols);
197+
}
198+
183199
template <typename T>
184200
inline void print(const char* name, const arena_matrix<T>& x) {
185201
#ifdef RICCATI_DEBUG

include/riccati/chebyshev.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -398,16 +398,18 @@ template <typename SolverInfo, typename Scalar, typename YScalar,
398398
RICCATI_ALWAYS_INLINE auto spectral_chebyshev(SolverInfo&& info, Scalar x0,
399399
Scalar h, YScalar y0, YScalar dy0,
400400
Integral niter) {
401-
using complex_t = std::complex<Scalar>;
401+
using complex_t = promote_complex_t<Scalar>;
402402
using vectorc_t = vector_t<complex_t>;
403403
auto x_scaled = eval(
404404
info.alloc_, riccati::scale(std::get<2>(info.chebyshev_[niter]), x0, h));
405405
auto&& D = info.Dn(niter);
406406
auto ws = omega(info, x_scaled);
407407
auto gs = gamma(info, x_scaled);
408-
auto D2 = eval(info.alloc_,
409-
((D * D) + h * (gs.asDiagonal() * D)));
410-
D2 += ((ws * h / 2.0).array().square()).matrix().asDiagonal();
408+
auto D2 = eval(info.alloc_, ((D * D) + h * (gs.asDiagonal() * D))
409+
+ ((ws * h / 2.0).array().square())
410+
.matrix()
411+
.asDiagonal()
412+
.toDenseMatrix());
411413
const auto n = std::round(std::get<0>(info.chebyshev_[niter]));
412414
auto D2ic = eval(info.alloc_, matrix_t<complex_t>::Zero(n + 3, n + 1));
413415
D2ic.topRows(n + 1) = D2;

0 commit comments

Comments
 (0)