1010#define _USE_MATH_DEFINES
1111#include " cvector.hpp" // for CVector
1212#include < cmath> // for M_PI
13- #include < stdexcept> // for runtime_error
14- #include < utility> // for move
15- #include < vector> // for vector
13+ #include < pybind11/pybind11.h>
14+ #include < stdexcept> // for runtime_error
15+ #include < utility> // for move
16+ #include < vector> // for vector
1617
1718enum UpdateType {
1819 constant,
@@ -23,7 +24,8 @@ enum UpdateType {
2324 halfsine,
2425 trapezoid,
2526 gaussimpulse,
26- gaussstep
27+ gaussstep,
28+ custom = 100
2729};
2830
2931template <typename T> class Driver {
@@ -65,24 +67,23 @@ template <typename T> class ScalarDriver : public Driver<T> {
6567private:
6668 T edgeTime = 0 ;
6769 T steadyTime = 0 ;
70+ pybind11::function m_callback;
6871
6972protected:
7073 T stepUpdate (T amplitude, T time, T timeStart, T timeStop) {
7174 if (time >= timeStart && time <= timeStop) {
7275 return amplitude;
73- } else {
74- return 0.0 ;
7576 }
77+ return 0.0 ;
7678 }
7779 T pulseTrain (T amplitude, T time, T period, T cycle) {
7880 const int n = static_cast <int >(time / period);
7981 const T dT = cycle * period;
8082 const T nT = n * period;
8183 if (nT <= time && time <= (nT + dT)) {
8284 return amplitude;
83- } else {
84- return 0 ;
8585 }
86+ return 0.0 ;
8687 }
8788
8889 T trapezoidalUpdate (T amplitude, T time, T timeStart, T edgeTime,
@@ -110,7 +111,8 @@ template <typename T> class ScalarDriver : public Driver<T> {
110111 explicit ScalarDriver (UpdateType update = constant, T constantValue = 0 ,
111112 T amplitude = 0 , T frequency = -1 , T phase = 0 ,
112113 T period = -1 , T cycle = -1 , T timeStart = -1 ,
113- T timeStop = -1 , T edgeTime = -1 , T steadyTime = -1 )
114+ T timeStop = -1 , T edgeTime = -1 , T steadyTime = -1 ,
115+ pybind11::function m_callback = pybind11::function())
114116 : Driver<T>(update, constantValue, amplitude, frequency, phase, period,
115117 cycle, timeStart, timeStop) {
116118 this ->edgeTime = edgeTime;
@@ -122,6 +124,7 @@ template <typename T> class ScalarDriver : public Driver<T> {
122124 throw std::runtime_error (
123125 " Selected sine driver type but frequency was not set" );
124126 }
127+ this ->m_callback = m_callback;
125128 }
126129
127130 /* *
@@ -241,6 +244,21 @@ template <typename T> class ScalarDriver : public Driver<T> {
241244 t0, -1 , sigma);
242245 }
243246
247+ static ScalarDriver getCustomDriver (pybind11::function callback) {
248+ if (!callback) {
249+ throw std::runtime_error (" Callback function is not set" );
250+ }
251+ // check if the callback is callable has one argument
252+ // Check if the callback function has exactly one argument
253+ // Using cast to int to avoid type mismatch error
254+ if (pybind11::cast<int >(callback.attr (" __code__" ).attr (" co_argcount" )) !=
255+ 1 ) {
256+ throw std::runtime_error (" Callback function must have one argument" );
257+ }
258+
259+ return ScalarDriver (custom, 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , callback);
260+ }
261+
244262 T getCurrentScalarValue (T &time) override {
245263 T returnValue = this ->constantValue ;
246264 if (this ->update == pulse) {
@@ -273,6 +291,15 @@ template <typename T> class ScalarDriver : public Driver<T> {
273291 0.5 * this ->amplitude *
274292 (1 + std::erf ((time - this ->timeStart ) / (sqrt (2 ) * this ->edgeTime )));
275293 returnValue += gaussStep;
294+ } else if (this ->update == custom) {
295+ // If it is, call the Python function
296+ pybind11::gil_scoped_acquire gil;
297+ try {
298+ return pybind11::cast<double >(m_callback (time));
299+ } catch (pybind11::error_already_set &e) {
300+ std::cerr << " Error in Python callback: " << e.what () << std::endl;
301+ throw std::runtime_error (" Error in Python callback" );
302+ }
276303 }
277304 return returnValue;
278305 }
0 commit comments