Skip to content

Commit 5a98b10

Browse files
committed
Scipy: working in the hessian implementation
1 parent db5c3d5 commit 5a98b10

File tree

2 files changed

+45
-14
lines changed

2 files changed

+45
-14
lines changed

math/scipy/inc/Math/ScipyMinimizer.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "Rtypes.h"
2727
#include "TString.h"
2828

29+
#include <functional>
2930
#include <vector>
3031
#include <map>
3132

@@ -64,7 +65,10 @@ class ScipyMinimizer : public BasicMinimizer {
6465
PyObject *fMinimize;
6566
PyObject *fTarget;
6667
PyObject *fJacobian;
68+
PyObject *fHessian;
69+
6770
GenAlgoOptions fExtraOpts;
71+
std::function<bool(const std::vector<double> &, double *)> fHessianFunc;
6872

6973
protected:
7074
PyObject *fGlobalNS;
@@ -85,11 +89,6 @@ class ScipyMinimizer : public BasicMinimizer {
8589
*/
8690
virtual ~ScipyMinimizer();
8791

88-
/**
89-
Python eval function
90-
*/
91-
PyObject *Eval(TString code);
92-
9392
/**
9493
Python initialization
9594
*/
@@ -119,10 +118,10 @@ class ScipyMinimizer : public BasicMinimizer {
119118
void SetAlgoExtraOptions();
120119

121120
public:
122-
123121
/// method to perform the minimization
124122
virtual bool Minimize() override;
125123

124+
virtual void SetHessianFunction(std::function<bool(const std::vector<double> &, double *)>) override;
126125
template <class T>
127126
void SetExtraOption(const char *key, T value)
128127
{

math/scipy/src/ScipyMinimizer.cxx

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ const ROOT::Math::IMultiGenFunction *gFunction = nullptr;
2020
/// function wrapper for the gradient of the function to be minimized
2121
const ROOT::Math::IMultiGradFunction *gGradFunction = nullptr;
2222

23+
std::function<bool(const std::vector<double> &, double *)> gfHessianFunction;
24+
2325
#define PyPrint(pyo) PyObject_Print(pyo, stdout, Py_PRINT_RAW)
2426

2527
PyObject *target_function(PyObject * /*self*/, PyObject *args)
@@ -38,13 +40,30 @@ PyObject *jac_function(PyObject * /*self*/, PyObject *args)
3840

3941
uint size = PyArray_SIZE(arr);
4042
auto params = (double *)PyArray_DATA(arr);
41-
double *values=new double[size];
43+
double *values = new double[size];
4244
gGradFunction->Gradient(params, values);
4345
npy_intp dims[1] = {size};
4446
PyObject *py_array = PyArray_SimpleNewFromData(1, dims, NPY_DOUBLE, values);
4547
return py_array;
4648
};
4749

50+
PyObject *hessian_function(PyObject * /*self*/, PyObject *args)
51+
{
52+
PyArrayObject *arr = (PyArrayObject *)PyTuple_GetItem(args, 0);
53+
54+
uint size = PyArray_SIZE(arr);
55+
auto params = (double *)PyArray_DATA(arr);
56+
double *values = new double[size * size];
57+
std::vector<double> x(params, params + size);
58+
gfHessianFunction(x, values);
59+
npy_intp dims[2] = {size, size};
60+
PyObject *py_array = PyArray_SimpleNewFromData(2, dims, NPY_DOUBLE, values);
61+
// std::cout<<"---------------"<<std::endl;
62+
// PyPrint(py_array);
63+
// std::cout<<"---------------"<<std::endl;
64+
return py_array;
65+
};
66+
4867
//_______________________________________________________________________
4968
ScipyMinimizer::ScipyMinimizer() : BasicMinimizer()
5069
{
@@ -53,6 +72,7 @@ ScipyMinimizer::ScipyMinimizer() : BasicMinimizer()
5372
if (!PyIsInitialized()) {
5473
PyInitialize();
5574
}
75+
fHessianFunc = [](const std::vector<double> &, double *) -> bool { return false; };
5676
// set extra options
5777
SetAlgoExtraOptions();
5878
}
@@ -65,6 +85,7 @@ ScipyMinimizer::ScipyMinimizer(const char *type)
6585
if (!PyIsInitialized()) {
6686
PyInitialize();
6787
}
88+
fHessianFunc = [](const std::vector<double> &, double *) -> bool { return false; };
6889
// set extra options
6990
SetAlgoExtraOptions();
7091
}
@@ -83,6 +104,7 @@ void ScipyMinimizer::PyInitialize()
83104
static PyMethodDef ParamsMethods[] = {
84105
{"target_function", target_function, METH_VARARGS, "Target function to minimize."},
85106
{"jac_function", jac_function, METH_VARARGS, "Jacobian function."},
107+
{"hessian_function", hessian_function, METH_VARARGS, "Hessianfunction."},
86108
{NULL, NULL, 0, NULL} /* Sentinel */
87109
};
88110

@@ -128,9 +150,10 @@ void ScipyMinimizer::PyInitialize()
128150
// Scipy initialization
129151
PyRunString("from scipy.optimize import minimize");
130152
fMinimize = PyDict_GetItemString(fLocalNS, "minimize");
131-
PyRunString("from params import target_function, jac_function");
153+
PyRunString("from params import target_function, jac_function, hessian_function");
132154
fTarget = PyDict_GetItemString(fLocalNS, "target_function");
133155
fJacobian = PyDict_GetItemString(fLocalNS, "jac_function");
156+
fHessian = PyDict_GetItemString(fLocalNS, "hessian_function");
134157
}
135158

136159
//_______________________________________________________________________
@@ -158,9 +181,13 @@ bool ScipyMinimizer::Minimize()
158181
{
159182
(gFunction) = ObjFunction();
160183
(gGradFunction) = GradObjFunction();
184+
gfHessianFunction = fHessianFunc;
161185
if (gGradFunction == nullptr) {
162186
fJacobian = Py_None;
163187
}
188+
if (!gfHessianFunction) {
189+
fHessian = Py_None;
190+
}
164191
auto method = fOptions.MinimizerAlgorithm();
165192
PyObject *pyoptions = PyDict_New();
166193
if (method == "L-BFGS-B") {
@@ -188,17 +215,16 @@ bool ScipyMinimizer::Minimize()
188215
// minimize(fun, x0, args=(), method=None, jac=None, hess=None, hessp=None, bounds=None, constraints=(), tol=None,
189216
// callback=None, options=None)
190217
auto args = Py_BuildValue("(OO)", fTarget, x0);
191-
auto kw = Py_BuildValue("{s:s,s:O,s:d,s:O}", "method", method.c_str(), "jac", fJacobian, "tol", Tolerance(),
192-
"options", pyoptions);
218+
auto kw = Py_BuildValue("{s:s,s:O,,s:O,s:d,s:O}", "method", method.c_str(), "jac", fJacobian, "hess", fHessian,
219+
"tol", Tolerance(), "options", pyoptions);
193220

194-
//PyPrint(kw);
221+
// PyPrint(kw);
195222
PyObject *result = PyObject_Call(fMinimize, args, kw);
196-
if(result == NULL)
197-
{
223+
if (result == NULL) {
198224
PyErr_Print();
199225
return false;
200226
}
201-
//PyPrint(result);
227+
// PyPrint(result);
202228
Py_DECREF(args);
203229
Py_DECREF(kw);
204230
Py_DECREF(x0);
@@ -243,3 +269,9 @@ void ScipyMinimizer::PyRunString(TString code, TString errorMessage, int start)
243269
exit(1);
244270
}
245271
}
272+
273+
//_______________________________________________________________________
274+
void ScipyMinimizer::SetHessianFunction(std::function<bool(const std::vector<double> &, double *)> func)
275+
{
276+
fHessianFunc = func;
277+
}

0 commit comments

Comments
 (0)