Skip to content

Commit b7b8245

Browse files
committed
Scipy: added support for bounds
1 parent 1c4bbc9 commit b7b8245

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

math/scipy/inc/Math/ScipyMinimizer.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// @(#)root/math/scipy:$Id$
2-
// Author: [email protected] 2022
2+
// Author: [email protected] 2023
33

44
/*************************************************************************
55
* Copyright (C) 1995-2022, Rene Brun and Fons Rademakers. *
@@ -66,7 +66,7 @@ class ScipyMinimizer : public BasicMinimizer {
6666
PyObject *fTarget;
6767
PyObject *fJacobian;
6868
PyObject *fHessian;
69-
69+
PyObject *fBoundsMod;
7070
GenAlgoOptions fExtraOpts;
7171
std::function<bool(const std::vector<double> &, double *)> fHessianFunc;
7272

math/scipy/src/ScipyMinimizer.cxx

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1+
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
12
#include <Python.h> // Needs to be included first to avoid redefinition of _POSIX_C_SOURCE
23
#include <Math/ScipyMinimizer.h>
34
#include <Fit/ParameterSettings.h>
45
#include <Math/IFunction.h>
56
#include <Math/FitMethodFunction.h>
6-
#include "Math/GenAlgoOptions.h"
7+
#include <Math/GenAlgoOptions.h>
78
#include <TString.h>
89
#include <iostream>
910

1011
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
1112
#include <numpy/arrayobject.h>
13+
#include <numpy/npy_math.h>
1214

1315
using namespace ROOT;
1416
using namespace ROOT::Math;
@@ -148,8 +150,10 @@ void ScipyMinimizer::PyInitialize()
148150
_import_array(); // Numpy initialization
149151
}
150152
// Scipy initialization
151-
PyRunString("from scipy.optimize import minimize");
153+
154+
PyRunString("from scipy.optimize import minimize, Bounds");
152155
fMinimize = PyDict_GetItemString(fLocalNS, "minimize");
156+
fBoundsMod = PyDict_GetItemString(fLocalNS, "Bounds");
153157
PyRunString("from params import target_function, jac_function, hessian_function");
154158
fTarget = PyDict_GetItemString(fLocalNS, "target_function");
155159
fJacobian = PyDict_GetItemString(fLocalNS, "jac_function");
@@ -162,6 +166,8 @@ void ScipyMinimizer::PyFinalize()
162166
{
163167
if (fMinimize)
164168
Py_DECREF(fMinimize);
169+
if (fBoundsMod)
170+
Py_DECREF(fBoundsMod);
165171
Py_Finalize();
166172
}
167173

@@ -212,19 +218,47 @@ bool ScipyMinimizer::Minimize()
212218
npy_intp dims[1] = {NDim()};
213219
PyObject *x0 = PyArray_SimpleNewFromData(1, dims, NPY_DOUBLE, values);
214220

221+
PyObject *pybounds_args = PyTuple_New(2);
222+
PyObject *pylimits_lower = PyList_New(NDim());
223+
PyObject *pylimits_upper = PyList_New(NDim());
224+
for (unsigned int i = 0; i < NDim(); i++) {
225+
ParameterSettings varsettings;
226+
227+
if (GetVariableSettings(i, varsettings)) {
228+
if (varsettings.HasLowerLimit()) {
229+
PyList_SetItem(pylimits_lower, i, PyFloat_FromDouble(varsettings.LowerLimit()));
230+
} else {
231+
PyList_SetItem(pylimits_lower, i, PyFloat_FromDouble(-NPY_INFINITY));
232+
}
233+
if (varsettings.HasUpperLimit()) {
234+
PyList_SetItem(pylimits_upper, i, PyFloat_FromDouble(varsettings.UpperLimit()));
235+
} else {
236+
PyList_SetItem(pylimits_upper, i, PyFloat_FromDouble(NPY_INFINITY));
237+
}
238+
} else {
239+
MATH_ERROR_MSG("ScipyMinimizer::Minimize", Form("Variable index = %d not found", i));
240+
}
241+
}
242+
PyTuple_SetItem(pybounds_args, 0, pylimits_lower);
243+
PyTuple_SetItem(pybounds_args, 1, pylimits_upper);
244+
245+
PyObject *pybounds = PyObject_CallObject(fBoundsMod, pybounds_args);
246+
215247
// minimize(fun, x0, args=(), method=None, jac=None, hess=None, hessp=None, bounds=None, constraints=(), tol=None,
216248
// callback=None, options=None)
217-
auto args = Py_BuildValue("(OO)", fTarget, x0);
218-
auto kw = Py_BuildValue("{s:s,s:O,,s:O,s:d,s:O}", "method", method.c_str(), "jac", fJacobian, "hess", fHessian,
219-
"tol", Tolerance(), "options", pyoptions);
249+
PyObject *args = Py_BuildValue("(OO)", fTarget, x0);
250+
PyObject *kw = Py_BuildValue("{s:s,s:O,,s:O,s:O,s:d,s:O}", "method", method.c_str(), "jac", fJacobian, "hess",
251+
fHessian, "bounds", pybounds, "tol", Tolerance(), "options", pyoptions);
220252

221-
// PyPrint(kw);
222253
PyObject *result = PyObject_Call(fMinimize, args, kw);
223254
if (result == NULL) {
224255
PyErr_Print();
225256
return false;
226257
}
227258
// PyPrint(result);
259+
Py_DECREF(pylimits_lower);
260+
Py_DECREF(pylimits_upper);
261+
Py_DECREF(pybounds);
228262
Py_DECREF(args);
229263
Py_DECREF(kw);
230264
Py_DECREF(x0);

0 commit comments

Comments
 (0)