Skip to content

Commit e6fb0f1

Browse files
authored
Merge pull request #96 from LemurPwned/feat/custom-drivers
driver updates
2 parents b9c4706 + f830f09 commit e6fb0f1

File tree

4 files changed

+106
-18
lines changed

4 files changed

+106
-18
lines changed

core/drivers.hpp

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
#define _USE_MATH_DEFINES
1111
#include "cvector.hpp" // for CVector
1212
#include <cmath> // for M_PI
13-
#include <stdexcept> // for runtime_error
14-
#include <utility> // for move
15-
#include <vector> // for vector
13+
#include <pybind11/pybind11.h>
14+
#include <stdexcept> // for runtime_error
15+
#include <utility> // for move
16+
#include <vector> // for vector
1617

1718
enum UpdateType {
1819
constant,
@@ -23,7 +24,8 @@ enum UpdateType {
2324
halfsine,
2425
trapezoid,
2526
gaussimpulse,
26-
gaussstep
27+
gaussstep,
28+
custom = 100
2729
};
2830

2931
template <typename T> class Driver {
@@ -65,24 +67,23 @@ template <typename T> class ScalarDriver : public Driver<T> {
6567
private:
6668
T edgeTime = 0;
6769
T steadyTime = 0;
70+
pybind11::function m_callback;
6871

6972
protected:
7073
T stepUpdate(T amplitude, T time, T timeStart, T timeStop) {
7174
if (time >= timeStart && time <= timeStop) {
7275
return amplitude;
73-
} else {
74-
return 0.0;
7576
}
77+
return 0.0;
7678
}
7779
T pulseTrain(T amplitude, T time, T period, T cycle) {
7880
const int n = static_cast<int>(time / period);
7981
const T dT = cycle * period;
8082
const T nT = n * period;
8183
if (nT <= time && time <= (nT + dT)) {
8284
return amplitude;
83-
} else {
84-
return 0;
8585
}
86+
return 0.0;
8687
}
8788

8889
T trapezoidalUpdate(T amplitude, T time, T timeStart, T edgeTime,
@@ -110,7 +111,8 @@ template <typename T> class ScalarDriver : public Driver<T> {
110111
explicit ScalarDriver(UpdateType update = constant, T constantValue = 0,
111112
T amplitude = 0, T frequency = -1, T phase = 0,
112113
T period = -1, T cycle = -1, T timeStart = -1,
113-
T timeStop = -1, T edgeTime = -1, T steadyTime = -1)
114+
T timeStop = -1, T edgeTime = -1, T steadyTime = -1,
115+
pybind11::function m_callback = pybind11::function())
114116
: Driver<T>(update, constantValue, amplitude, frequency, phase, period,
115117
cycle, timeStart, timeStop) {
116118
this->edgeTime = edgeTime;
@@ -122,6 +124,7 @@ template <typename T> class ScalarDriver : public Driver<T> {
122124
throw std::runtime_error(
123125
"Selected sine driver type but frequency was not set");
124126
}
127+
this->m_callback = m_callback;
125128
}
126129

127130
/**
@@ -241,6 +244,21 @@ template <typename T> class ScalarDriver : public Driver<T> {
241244
t0, -1, sigma);
242245
}
243246

247+
static ScalarDriver getCustomDriver(pybind11::function callback) {
248+
if (!callback) {
249+
throw std::runtime_error("Callback function is not set");
250+
}
251+
// check if the callback is callable has one argument
252+
// Check if the callback function has exactly one argument
253+
// Using cast to int to avoid type mismatch error
254+
if (pybind11::cast<int>(callback.attr("__code__").attr("co_argcount")) !=
255+
1) {
256+
throw std::runtime_error("Callback function must have one argument");
257+
}
258+
259+
return ScalarDriver(custom, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, callback);
260+
}
261+
244262
T getCurrentScalarValue(T &time) override {
245263
T returnValue = this->constantValue;
246264
if (this->update == pulse) {
@@ -273,6 +291,15 @@ template <typename T> class ScalarDriver : public Driver<T> {
273291
0.5 * this->amplitude *
274292
(1 + std::erf((time - this->timeStart) / (sqrt(2) * this->edgeTime)));
275293
returnValue += gaussStep;
294+
} else if (this->update == custom) {
295+
// If it is, call the Python function
296+
pybind11::gil_scoped_acquire gil;
297+
try {
298+
return pybind11::cast<double>(m_callback(time));
299+
} catch (pybind11::error_already_set &e) {
300+
std::cerr << "Error in Python callback: " << e.what() << std::endl;
301+
throw std::runtime_error("Error in Python callback");
302+
}
276303
}
277304
return returnValue;
278305
}

docs/api/drivers.md

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,11 @@ from cmtj import (
4545
import numpy as np
4646

4747

48-
class MyDriver(ScalarDriver):
49-
def getCurrentScalarValue(self, time: float) -> float:
50-
return time * np.random.choice([-1, 1])
48+
def my_custom_function(time: float) -> float:
49+
return time * np.random.choice([-1, 1])
5150

52-
53-
driver = MyDriver()
54-
for i in range(10):
55-
print(driver.getCurrentScalarValue(i * 1e-9))
51+
# Create a driver with this function
52+
driver = ScalarDriver.getCustomDriver(my_custom_function)
5653

5754
demag = [CVector(0, 0, 0), CVector(0, 0, 0), CVector(0, 0, 1)]
5855
layer = Layer(

python/cmtj.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,18 @@ PYBIND11_MODULE(cmtj, m) {
147147
.value("DormandPrice", DORMAND_PRICE)
148148
.export_values();
149149

150+
py::enum_<UpdateType>(m, "UpdateType")
151+
.value("constant", constant)
152+
.value("pulse", pulse)
153+
.value("sine", sine)
154+
.value("step", step)
155+
.value("posine", posine)
156+
.value("halfsine", halfsine)
157+
.value("trapezoid", trapezoid)
158+
.value("gaussimpulse", gaussimpulse)
159+
.value("gaussstep", gaussstep)
160+
.value("custom", custom)
161+
.export_values();
150162
// Driver Class
151163
py::class_<DScalarDriver>(m, "ScalarDriver")
152164
.def(py::init<>())
@@ -174,7 +186,9 @@ PYBIND11_MODULE(cmtj, m) {
174186
"amplitude"_a, "t0"_a, "sigma"_a)
175187
.def_static("getGaussianStepDriver",
176188
&DScalarDriver::getGaussianStepDriver, "constantValue"_a,
177-
"amplitude"_a, "t0"_a, "sigma"_a);
189+
"amplitude"_a, "t0"_a, "sigma"_a)
190+
.def_static("getCustomDriver", &DScalarDriver::getCustomDriver,
191+
"callback"_a);
178192

179193
py::class_<DNullDriver, DScalarDriver>(m, "NullDriver")
180194
.def(py::init<>())

tests/test_drivers.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from cmtj import AxialDriver, CVector, Junction, Layer
1+
from cmtj import AxialDriver, CVector, Junction, Layer, ScalarDriver
22
from cmtj import constantDriver, sineDriver
33
import pytest
4+
import math
5+
46

57
def test_cvector_operators():
68
vec1 = (1.0, 2.0, 3.0)
@@ -84,3 +86,51 @@ def test_junction_with_driver():
8486
AxialDriver(constantDriver(0), constantDriver(0), sineDriver(0, 1e3, 7e9, 0)),
8587
)
8688
junction.runSimulation(10e-9, 1e-13, 1e-13)
89+
90+
91+
def test_custom_driver():
92+
def my_custom_function(time: float) -> float:
93+
return math.sqrt(time)
94+
95+
driver = ScalarDriver.getCustomDriver(my_custom_function)
96+
assert driver.getCurrentScalarValue(1e-9) == math.sqrt(1e-9)
97+
assert driver.getCurrentScalarValue(3e-9) == math.sqrt(3e-9)
98+
99+
# Test with zero time value
100+
assert my_custom_function(0) == 0, "Zero time value should return 0"
101+
102+
# Test error condition: Passing None should raise a TypeError (since None cannot be multiplied by an int)
103+
with pytest.raises(TypeError):
104+
my_custom_function(None)
105+
106+
107+
def test_custom_driver_exception():
108+
# Test error condition: A callback that raises an exception
109+
def faulty_callback(time: float) -> float:
110+
raise ValueError("Intentional exception")
111+
112+
driver = ScalarDriver.getCustomDriver(faulty_callback)
113+
with pytest.raises(RuntimeError):
114+
driver.getCurrentScalarValue(1e-9)
115+
116+
117+
def test_no_argument_callback():
118+
# Test error condition: A callback that raises an exception
119+
def faulty_callback():
120+
raise ValueError("Intentional exception")
121+
122+
with pytest.raises(RuntimeError):
123+
driver = ScalarDriver.getCustomDriver(faulty_callback)
124+
125+
126+
def test_custom_on_junction(single_layer_mtj):
127+
junction, _ = single_layer_mtj
128+
129+
def my_custom_function(time: float) -> float:
130+
return 12345
131+
132+
driver = ScalarDriver.getCustomDriver(my_custom_function)
133+
junction.setLayerAnisotropyDriver("all", driver)
134+
junction.runSimulation(10e-9, 1e-13, 1e-13, calculateEnergies=True)
135+
log = junction.getLog()
136+
assert all(x == 12345 for x in log["free_K"])

0 commit comments

Comments
 (0)