Skip to content

Commit 0fdd859

Browse files
authored
Merge pull request #8 from SteveBronder/feature/logger
Adds a logger for the c++ and python code
2 parents a946154 + 4be3081 commit 0fdd859

File tree

13 files changed

+625
-77
lines changed

13 files changed

+625
-77
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ if (RICCATI_BUILD_PYTHON)
158158
target_link_libraries(pyriccaticpp PRIVATE pybind11::headers riccati Eigen3::Eigen)
159159

160160
# This is passing in the version as a define just as an example
161-
target_compile_definitions(pyriccaticpp PRIVATE VERSION_INFO=${PROJECT_VERSION})
161+
target_compile_definitions(pyriccaticpp PRIVATE VERSION_INFO=${PROJECT_VERSION} RICCATI_PYTHON)
162162
if (NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
163163
target_compile_options(pyriccaticpp PRIVATE -Wno-deprecated-declarations)
164164
endif()

include/riccati/chebyshev.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define INCLUDE_RICCATI_CHEBYSHEV_HPP
33

44
#include <riccati/arena_matrix.hpp>
5+
#include <riccati/logger.hpp>
56
#include <riccati/memory.hpp>
67
#include <riccati/utils.hpp>
78
#include <unsupported/Eigen/FFT>

include/riccati/evolve.hpp

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,8 @@ inline auto nonosc_evolve(SolverInfo &&info, Scalar xi, Scalar xf,
310310
* equations.
311311
* 6. @ref Eigen::Matrix<std::complex<Scalar>, -1, 1>: A vector containing the
312312
* interpolated solution at the specified `x_eval` The function returns these
313+
* 7. An array of logging information for the solver process. This can be added with the
314+
* info() from the `SolverInfo` object to get the total log information.
313315
* vectors encapsulated in a standard tuple, providing comprehensive information
314316
* about the solution process, including where the solution was evaluated, the
315317
* values and derivatives of the solution at those points, success status of
@@ -319,15 +321,16 @@ template <typename SolverInfo, typename Scalar, typename Vec>
319321
inline auto evolve(SolverInfo &info, Scalar xi, Scalar xf,
320322
std::complex<Scalar> yi, std::complex<Scalar> dyi,
321323
Scalar eps, Scalar epsilon_h, Scalar init_stepsize,
322-
Vec &&x_eval, bool hard_stop = false) {
324+
Vec &&x_eval, bool hard_stop = false,
325+
LogLevel log_level = riccati::LogLevel::ERROR) {
323326
static_assert(std::is_floating_point<Scalar>::value,
324327
"Scalar type must be a floating-point type.");
325328
using vectord_t = vector_t<Scalar>;
326329
Scalar direction = init_stepsize > 0 ? 1 : -1;
327330
if (init_stepsize * (xf - xi) < 0) {
328331
throw std::domain_error(
329332
"Direction of integration does not match stepsize sign,"
330-
" adjusting it so that integration happens from xi to xf.");
333+
" adjust the direction or stepsize so that integration happens from xi to xf.");
331334
}
332335
// Check that yeval and x_eval are right size
333336
constexpr bool dense_output = compile_size_v<Vec> != 0;
@@ -436,9 +439,9 @@ inline auto evolve(SolverInfo &info, Scalar xi, Scalar xf,
436439
matrixc_t y_eval;
437440
matrixc_t dy_eval;
438441
std::pair<complex_t, complex_t> a_pair;
442+
std::array<std::pair<LogInfo, std::size_t>, 5> solver_counts = info.info();
439443
while (std::abs(xcurrent - xf) > Scalar(1e-8)
440444
&& direction * xcurrent < direction * xf) {
441-
// std::cout << "t = " << xcurrent << std::endl;
442445
Scalar phase{0.0};
443446
bool success = false;
444447
bool steptype = true;
@@ -447,7 +450,6 @@ inline auto evolve(SolverInfo &info, Scalar xi, Scalar xf,
447450
arena_matrix<vectorc_t> un(info.alloc_, omega_n.size(), 1);
448451
arena_matrix<vectorc_t> d_un(info.alloc_, omega_n.size(), 1);
449452
if (direction * hosc > direction * hslo){
450-
// && (direction * hosc * wnext / (2.0 * pi<Scalar>()) > 1.0)) {
451453
if (hard_stop) {
452454
if (direction * (xcurrent + hosc) > direction * xf) {
453455
hosc = xf - xcurrent;
@@ -463,13 +465,27 @@ inline auto evolve(SolverInfo &info, Scalar xi, Scalar xf,
463465
std::tie(success, y, dy, err, phase, un, d_un, a_pair)
464466
= osc_step<dense_output>(info, omega_n, gamma_n, xcurrent, hosc,
465467
yprev, dyprev, eps);
466-
// std::cout << "Attempted osc step with hosc = " << hosc << ", successful? " << success << std::endl;
467468
steptype = 1;
469+
solver_counts[get_idx(LogInfo::RICCSTEP)].second++;
470+
}
471+
if (unlikely(log_level == LogLevel::INFO && !success)) {
472+
info.logger().template log<LogLevel::INFO>(
473+
std::string("[Non-oscillatory step][x = ") +
474+
std::to_string(xcurrent) + std::string("][stepsize =") +
475+
std::to_string(hslo) + std::string("]")
476+
);
477+
} else if (unlikely(log_level == LogLevel::INFO)) {
478+
info.logger().template log<LogLevel::INFO>(
479+
std::string("[Oscillatory step][x = ") +
480+
std::to_string(xcurrent) + std::string("][stepsize =") +
481+
std::to_string(hslo) + std::string("]")
482+
);
468483
}
469484
while (!success) {
470485
std::tie(success, y, dy, err, y_eval, dy_eval, cheb_N)
471486
= nonosc_step(info, xcurrent, hslo, yprev, dyprev, eps);
472-
// std::cout << "Attempted nonosc step with hslo = " << hslo << ", successful? " << success << std::endl;
487+
solver_counts[get_idx(LogInfo::CHEBSTEP)].second++;
488+
solver_counts[get_idx(LogInfo::LS)].second += cheb_N + 1;
473489
steptype = 0;
474490
if (!success) {
475491
hslo *= Scalar{0.5};
@@ -487,6 +503,13 @@ inline auto evolve(SolverInfo &info, Scalar xi, Scalar xf,
487503
std::tie(dense_start, dense_size)
488504
= get_slice(x_eval, xcurrent, (xcurrent + h));
489505
if (dense_size != 0) {
506+
if (unlikely(log_level == LogLevel::INFO)) {
507+
info.logger().template log<LogLevel::INFO>(
508+
std::string("[Dense output][x_start = ") +
509+
std::to_string(xcurrent) + std::string("][x_end =") +
510+
std::to_string(xcurrent + h) + std::string("]")
511+
);
512+
}
490513
auto x_eval_map
491514
= Eigen::Map<vectord_t>(x_eval.data() + dense_start, dense_size);
492515
auto y_eval_map
@@ -564,11 +587,29 @@ inline auto evolve(SolverInfo &info, Scalar xi, Scalar xf,
564587
}
565588
info.alloc_.recover_memory();
566589
}
567-
#ifdef RICCATI_DEBUG
568-
std::cout << "Total riccati steps: " << successes.size() << std::endl;
569-
#endif
590+
if constexpr (!std::is_same_v<std::decay_t<decltype(info.logger())>, EmptyLogger>) {
591+
if (unlikely(log_level == LogLevel::INFO)) {
592+
std::size_t riccati_steps = 0;
593+
std::size_t rk_steps = 0;
594+
for (auto& success : successes) {
595+
if (success) {
596+
riccati_steps++;
597+
} else {
598+
rk_steps++;
599+
}
600+
}
601+
info.logger().template log<LogLevel::INFO>(
602+
std::string("Total Steps = ") + std::to_string(successes.size())
603+
);
604+
for (auto&& info_pair : solver_counts) {
605+
info.logger().template log<LogLevel::INFO>(
606+
std::string(to_string(info_pair.first)) + " = " + std::to_string(info_pair.second)
607+
);
608+
}
609+
}
610+
}
570611
return std::make_tuple(xs, ys, dys, successes, phases, steptypes, yeval,
571-
dyeval);
612+
dyeval, 1.0);
572613
}
573614

574615
} // namespace riccati

include/riccati/logger.hpp

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
#ifndef INCLUDE_RICCATI_LOGGER_HPP
2+
#define INCLUDE_RICCATI_LOGGER_HPP
3+
4+
#include <riccati/macros.hpp>
5+
#include <riccati/utils.hpp>
6+
#include <ctime>
7+
#include <fstream>
8+
#include <iomanip>
9+
#include <iostream>
10+
#include <memory>
11+
#include <sstream>
12+
#include <array>
13+
#include <utility>
14+
#include <stdexcept>
15+
16+
namespace riccati {
17+
18+
/**
19+
* @brief Enumeration of different log levels for logging.
20+
*/
21+
enum class LogLevel {
22+
ERROR, // Error messages.
23+
WARNING, // Warning messages.
24+
INFO, // General information messages.
25+
DEBUG, // Detailed debug information.
26+
};
27+
28+
/**
29+
* @brief Enumeration of different logging information keys.
30+
*/
31+
enum class LogInfo {
32+
CHEBNODES, // Chebyshev nodes.
33+
CHEBSTEP, // Chebyshev steps.
34+
CHEBITS, // Chebyshev iterates.
35+
LS, // Linear system solves.
36+
RICCSTEP // Riccati steps.
37+
};
38+
39+
/**
40+
* @brief Converts a LogInfo enum value to its string representation.
41+
*
42+
* @param log_info The LogInfo enum value to convert.
43+
* @return The string representation of the LogInfo value.
44+
*/
45+
inline constexpr auto to_string(const LogInfo& log_info) noexcept {
46+
if (log_info == LogInfo::CHEBNODES) {
47+
return "chebyshev_nodes";
48+
} else if (log_info == LogInfo::CHEBSTEP) {
49+
return "chebyshev_steps";
50+
} else if (log_info == LogInfo::CHEBITS) {
51+
return "chebyshev_iterates";
52+
} else if (log_info == LogInfo::LS) {
53+
return "linear_system_solves";
54+
} else if (log_info == LogInfo::RICCSTEP) {
55+
return "riccati_steps";
56+
} else {
57+
return "UNKNOWN_INFO";
58+
}
59+
}
60+
/**
61+
* @brief Get the index value for the log info in the default logger
62+
*/
63+
inline constexpr auto get_idx(const LogInfo& log_info) {
64+
if (log_info == LogInfo::CHEBNODES) {
65+
return 0;
66+
} else if (log_info == LogInfo::CHEBSTEP) {
67+
return 1;
68+
} else if (log_info == LogInfo::CHEBITS) {
69+
return 2;
70+
} else if (log_info == LogInfo::LS) {
71+
return 3;
72+
} else if (log_info == LogInfo::RICCSTEP) {
73+
return 4;
74+
} else {
75+
throw std::invalid_argument("Invalid LogInfo key!");
76+
return 0;
77+
}
78+
}
79+
80+
/**
81+
* @brief Retrieves the string representation of a LogLevel value.
82+
*
83+
* @tparam Level The LogLevel value.
84+
* @return The string representation of the LogLevel.
85+
*/
86+
template <LogLevel Level>
87+
inline constexpr auto log_level() noexcept {
88+
if constexpr (Level == LogLevel::DEBUG) {
89+
return "[DEBUG]";
90+
} else if constexpr (Level == LogLevel::INFO) {
91+
return "[INFO]";
92+
} else if constexpr (Level == LogLevel::WARNING) {
93+
return "[WARNING]";
94+
} else if constexpr (Level == LogLevel::ERROR) {
95+
return "[ERROR]";
96+
} else {
97+
static_assert(1, "Invalid LogLevel!");
98+
}
99+
}
100+
101+
/**
102+
* @brief Base class template for loggers.
103+
*
104+
* This class provides a common interface for logging and updating log
105+
* information. All loggers in Riccati should inherit from this logger
106+
* using the CRTP form `class Logger : public LoggerBase<Derived>`.
107+
* See \ref `riccati::DefaultLogger` for an example.
108+
*
109+
* @tparam Derived The derived logger class.
110+
*/
111+
template <typename Derived>
112+
class LoggerBase {
113+
inline Derived& underlying() noexcept { return static_cast<Derived&>(*this); }
114+
inline Derived const& underlying() const noexcept {
115+
return static_cast<Derived const&>(*this);
116+
}
117+
118+
public:
119+
/**
120+
* @brief Logs a message with a specified log level.
121+
*
122+
* @tparam Level The log level.
123+
* @tparam Types Variadic template parameter pack for message arguments.
124+
* @param arg Message arguments to log.
125+
*/
126+
template <LogLevel Level, typename... Types>
127+
inline void log(Types&&... args) {
128+
this->underlying().template log<Level>(std::forward<Types>(args)...);
129+
}
130+
};
131+
132+
/**
133+
* @brief A deleter that performs no operation.
134+
*
135+
* This is used with smart pointers where no deletion is required.
136+
*/
137+
struct deleter_noop {
138+
/**
139+
* @brief No-op function call operator.
140+
*
141+
* @tparam T The type of the pointer.
142+
* @param arg The pointer to which the deleter is applied.
143+
*/
144+
template <typename T>
145+
constexpr void operator()(T* arg) const {}
146+
};
147+
148+
/**
149+
* A logger class where all the operations are noops
150+
*/
151+
class EmptyLogger : public LoggerBase<EmptyLogger> {
152+
public:
153+
template <LogLevel Level, typename... Types>
154+
inline constexpr void log(Types&&... args) const noexcept {}
155+
};
156+
157+
/**
158+
* @brief A simple logger class template.
159+
*
160+
* This class provides basic logging functionality with customizable output
161+
* streams.
162+
*
163+
* @tparam Stream The type of the output stream.
164+
* @tparam StreamDeleter The deleter type for the output stream.
165+
*/
166+
template <typename Ptr>
167+
class PtrLogger : public LoggerBase<PtrLogger<Ptr>> {
168+
public:
169+
/**
170+
* @brief The output stream for logging messages.
171+
*/
172+
Ptr output_{};
173+
174+
/**
175+
* @brief Default constructor.
176+
*/
177+
PtrLogger() = default;
178+
179+
/**
180+
* @brief Constructor with a custom output stream.
181+
*
182+
* @param output The output stream to use for logging.
183+
*/
184+
template <typename Stream, typename StreamDeleter>
185+
RICCATI_NO_INLINE explicit PtrLogger(
186+
std::unique_ptr<Stream, StreamDeleter>&& output)
187+
: output_(std::move(output)) {}
188+
template <typename Stream>
189+
RICCATI_NO_INLINE explicit PtrLogger(const std::shared_ptr<Stream>& output)
190+
: output_(output) {}
191+
192+
/**
193+
* @brief Logs a message with a specified log level.
194+
*
195+
* @tparam Level The log level.
196+
* @param msg The message to log.
197+
*/
198+
template <LogLevel Level>
199+
inline void log(std::string_view msg) {
200+
#ifdef RICCATI_DEBUG
201+
#define RICCATI_DEBUG_VAL true
202+
#else
203+
#define RICCATI_DEBUG_VAL false
204+
#endif
205+
if constexpr (!RICCATI_DEBUG_VAL && Level == LogLevel::DEBUG) {
206+
return;
207+
}
208+
std::string full_msg = log_level<Level>() + time_mi() + "[";
209+
full_msg += msg;
210+
full_msg += std::string("]");
211+
*output_ << full_msg + "\n";
212+
}
213+
};
214+
215+
template <typename Stream, typename StreamDeleter = std::default_delete<Stream>>
216+
using DefaultLogger = PtrLogger<std::unique_ptr<Stream, StreamDeleter>>;
217+
template <typename Stream>
218+
using SharedLogger = PtrLogger<std::shared_ptr<Stream>>;
219+
220+
} // namespace riccati
221+
222+
#endif

include/riccati/memory.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,12 @@ struct arena_allocator {
322322
template <typename U, typename UArena>
323323
RICCATI_NO_INLINE arena_allocator(const arena_allocator<U, UArena>& rhs)
324324
: alloc_(rhs.alloc_), owns_alloc_(false) {}
325+
template <typename U, typename UArena>
326+
RICCATI_NO_INLINE arena_allocator(arena_allocator&& rhs)
327+
: alloc_(rhs.alloc_), owns_alloc_(rhs.owns_alloc_ ? true : false) {
328+
rhs.alloc_ = nullptr;
329+
rhs.owns_alloc_ = false;
330+
}
325331

326332
~arena_allocator() {
327333
if (owns_alloc_) {

0 commit comments

Comments
 (0)