Skip to content

Commit a0f2d1f

Browse files
committed
Scipy: added support for extra options
1 parent e0527b9 commit a0f2d1f

File tree

2 files changed

+59
-27
lines changed

2 files changed

+59
-27
lines changed

math/scipy/inc/Math/ScipyMinimizer.h

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
#include "Math/BasicMinimizer.h"
2323

24+
#include "Math/GenAlgoOptions.h"
25+
2426
#include "Rtypes.h"
2527
#include "TString.h"
2628

@@ -37,6 +39,8 @@ namespace ROOT {
3739

3840
namespace Math {
3941

42+
class GenAlgoOptions;
43+
4044
namespace Experimental {
4145
/**
4246
enumeration specifying the types of Scipy solvers
@@ -60,6 +64,7 @@ class ScipyMinimizer : public BasicMinimizer {
6064
PyObject *fMinimize;
6165
PyObject *fTarget;
6266
PyObject *fJacobian;
67+
GenAlgoOptions fExtraOpts;
6368

6469
protected:
6570
PyObject *fGlobalNS;
@@ -111,6 +116,7 @@ class ScipyMinimizer : public BasicMinimizer {
111116
Copy constructor
112117
*/
113118
ScipyMinimizer(const ScipyMinimizer &) : BasicMinimizer() {}
119+
void SetAlgoExtraOptions();
114120

115121
public:
116122
/// set the function to minimize
@@ -120,22 +126,13 @@ class ScipyMinimizer : public BasicMinimizer {
120126
// virtual void SetFunction(const ROOT::Math::IMultiGradFunction &func) { BasicMinimizer::SetFunction(func); }
121127

122128
/// method to perform the minimization
123-
virtual bool Minimize();
124-
125-
/// return expected distance reached from the minimum
126-
virtual double Edm() const { return 0; } // not impl. }
129+
virtual bool Minimize() override;
127130

128-
/// minimizer provides error and error matrix
129-
virtual bool ProvidesError() const { return false; }
130-
131-
/// return errors at the minimum
132-
virtual const double *Errors() const { return 0; }
133-
134-
/** return covariance matrices elements
135-
if the variable is fixed the matrix is zero
136-
The ordering of the variables is the same as in errors
137-
*/
138-
virtual double CovMatrix(unsigned int, unsigned int) const { return 0; }
131+
template <class T>
132+
void SetExtraOption(const char *key, T value)
133+
{
134+
fExtraOpts.SetValue(key, value);
135+
}
139136

140137
protected:
141138
ClassDef(ScipyMinimizer, 0) //

math/scipy/src/ScipyMinimizer.cxx

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <Fit/ParameterSettings.h>
44
#include <Math/IFunction.h>
55
#include <Math/FitMethodFunction.h>
6+
#include "Math/GenAlgoOptions.h"
67
#include <TString.h>
78
#include <iostream>
89

@@ -15,9 +16,11 @@ using namespace ROOT::Math::Experimental;
1516
using namespace ROOT::Fit;
1617

1718
/// function wrapper for the function to be minimized
18-
const ROOT::Math::IMultiGenFunction *gFunction;
19+
const ROOT::Math::IMultiGenFunction *gFunction = nullptr;
1920
/// function wrapper for the gradient of the function to be minimized
20-
const ROOT::Math::IMultiGradFunction *gGradFunction;
21+
const ROOT::Math::IMultiGradFunction *gGradFunction = nullptr;
22+
23+
#define PyPrint(pyo) PyObject_Print(pyo, stdout, Py_PRINT_RAW)
2124

2225
PyObject *target_function(PyObject * /*self*/, PyObject *args)
2326
{
@@ -46,10 +49,12 @@ PyObject *jac_function(PyObject * /*self*/, PyObject *args)
4649
ScipyMinimizer::ScipyMinimizer() : BasicMinimizer()
4750
{
4851
fOptions.SetMinimizerType("Scipy");
49-
fOptions.SetMinimizerAlgorithm("lbfgsb");
52+
fOptions.SetMinimizerAlgorithm("L-BFGS-B");
5053
if (!PyIsInitialized()) {
5154
PyInitialize();
5255
}
56+
// set extra options
57+
SetAlgoExtraOptions();
5358
}
5459

5560
//_______________________________________________________________________
@@ -60,6 +65,19 @@ ScipyMinimizer::ScipyMinimizer(const char *type)
6065
if (!PyIsInitialized()) {
6166
PyInitialize();
6267
}
68+
// set extra options
69+
SetAlgoExtraOptions();
70+
}
71+
72+
//_______________________________________________________________________
73+
void ScipyMinimizer::SetAlgoExtraOptions()
74+
{
75+
std::string type = fOptions.MinimizerAlgorithm();
76+
if (type == "L-BFGS-B") {
77+
fExtraOpts.SetValue("gtol", 1e-10);
78+
fExtraOpts.SetValue("eps", 1.0);
79+
}
80+
SetExtraOptions(fExtraOpts);
6381
}
6482

6583
//_______________________________________________________________________
@@ -144,7 +162,19 @@ bool ScipyMinimizer::Minimize()
144162
{
145163
(gFunction) = ObjFunction();
146164
(gGradFunction) = GradObjFunction();
165+
if (gGradFunction == nullptr) {
166+
fJacobian = Py_None;
167+
}
147168
auto method = fOptions.MinimizerAlgorithm();
169+
PyObject *pyoptions = PyDict_New();
170+
if (method == "L-BFGS-B") {
171+
for (std::string key : fExtraOpts.GetAllRealKeys()) {
172+
double value = 0;
173+
fExtraOpts.GetRealValue(key.c_str(), value);
174+
PyDict_SetItemString(pyoptions, key.c_str(), PyFloat_FromDouble(value));
175+
}
176+
}
177+
148178
std::cout << "=== Scipy Minimization" << std::endl;
149179
std::cout << "=== Method: " << method << std::endl;
150180
std::cout << "=== Initial value: (";
@@ -157,15 +187,20 @@ bool ScipyMinimizer::Minimize()
157187

158188
double *values = const_cast<double *>(X());
159189
npy_intp dims[1] = {NDim()};
160-
PyObject *py_array = PyArray_SimpleNewFromData(1, dims, NPY_DOUBLE, values);
161-
162-
PyObject *pargs = PyTuple_New(0);
163-
164-
auto pyvalues = Py_BuildValue("(OOOsO)", fTarget, py_array, pargs, method.c_str(), fJacobian);
165-
166-
PyObject *result = PyObject_CallObject(fMinimize, pyvalues);
167-
Py_DECREF(pyvalues);
168-
Py_DECREF(py_array);
190+
PyObject *x0 = PyArray_SimpleNewFromData(1, dims, NPY_DOUBLE, values);
191+
192+
// minimize(fun, x0, args=(), method=None, jac=None, hess=None, hessp=None, bounds=None, constraints=(), tol=None,
193+
// callback=None, options=None)
194+
auto args = Py_BuildValue("(OO)", fTarget, x0);
195+
auto kw = Py_BuildValue("{s:s,s:O,s:d,s:O}", "method", method.c_str(), "jac", fJacobian, "tol", Tolerance(),
196+
"options", pyoptions);
197+
198+
//PyPrint(kw);
199+
PyObject *result = PyObject_Call(fMinimize, args, kw);
200+
//PyPrint(result);
201+
Py_DECREF(args);
202+
Py_DECREF(kw);
203+
Py_DECREF(x0);
169204

170205
// if the minimization works
171206
PyObject *pstatus = PyObject_GetAttrString(result, "status");

0 commit comments

Comments
 (0)