Skip to content

Commit feff9a3

Browse files
committed
Scipy: working in Constraint functions, it was possible to load the function before call Py_Initilaize()
1 parent 23109f0 commit feff9a3

File tree

2 files changed

+102
-16
lines changed

2 files changed

+102
-16
lines changed

math/scipy/inc/Math/ScipyMinimizer.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,12 @@ class ScipyMinimizer : public BasicMinimizer {
7474
PyObject *fJacobian;
7575
PyObject *fHessian;
7676
PyObject *fBoundsMod;
77+
PyObject *fConstraintsList; /// contraints functions
7778
GenAlgoOptions fExtraOpts;
7879
std::function<bool(const std::vector<double> &, double *)> fHessianFunc;
80+
unsigned int fConstN;
7981
unsigned int fCalls;
82+
8083
protected:
8184
PyObject *fGlobalNS;
8285
PyObject *fLocalNS;
@@ -118,7 +121,18 @@ class ScipyMinimizer : public BasicMinimizer {
118121
/*
119122
Number of function calls
120123
*/
121-
virtual unsigned int NCalls() const override;
124+
virtual unsigned int NCalls() const override;
125+
126+
/*
127+
Method to add Constraint function,
128+
multiples constraints functions can be added.
129+
type have to be a string "eq" or "ineq" where
130+
eq (means Equal, then fun() = 0)
131+
ineq (means that, it is to be non-negative. fun() >=0)
132+
https://kitchingroup.cheme.cmu.edu/f19-06623/13-constrained-optimization.html
133+
*/
134+
virtual void AddConstraintFunction(std::function<double(const std::vector<double> &)>, std::string type);
135+
122136
private:
123137
// usually copying is non trivial, so we make this unaccessible
124138

math/scipy/src/ScipyMinimizer.cxx

Lines changed: 87 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ std::function<bool(const std::vector<double> &, double *)> gfHessianFunction;
2727
/// simple function for debugging
2828
#define PyPrint(pyo) PyObject_Print(pyo, stdout, Py_PRINT_RAW)
2929

30-
3130
/// function to wrap into Python the C/C++ target function to be minimized
3231
PyObject *target_function(PyObject * /*self*/, PyObject *args)
3332
{
@@ -78,6 +77,8 @@ ScipyMinimizer::ScipyMinimizer() : BasicMinimizer()
7877
fHessianFunc = [](const std::vector<double> &, double *) -> bool { return false; };
7978
// set extra options
8079
SetAlgoExtraOptions();
80+
fConstraintsList = PyList_New(0);
81+
fConstN = 0;
8182
}
8283

8384
//_______________________________________________________________________
@@ -90,6 +91,8 @@ ScipyMinimizer::ScipyMinimizer(const char *type)
9091
fHessianFunc = [](const std::vector<double> &, double *) -> bool { return false; };
9192
// set extra options
9293
SetAlgoExtraOptions();
94+
fConstraintsList = PyList_New(0);
95+
fConstN = 0;
9396
}
9497

9598
//_______________________________________________________________________
@@ -110,16 +113,18 @@ void ScipyMinimizer::PyInitialize()
110113
{NULL, NULL, 0, NULL} /* Sentinel */
111114
};
112115

113-
static struct PyModuleDef paramsmodule = {PyModuleDef_HEAD_INIT, "params", /* name of module */
114-
"ROOT Scipy parameters", /* module documentation, may be NULL */
115-
-1, /* size of per-interpreter state of the module,
116-
or -1 if the module keeps state in global variables. */
117-
ParamsMethods,
118-
NULL, /* m_slots */
119-
NULL, /* m_traverse */
120-
0, /* m_clear */
121-
NULL /* m_free */
122-
};
116+
static struct PyModuleDef paramsmodule = {
117+
PyModuleDef_HEAD_INIT,
118+
"params", /* name of module */
119+
"ROOT Scipy parameters", /* module documentation, may be NULL */
120+
-1, /* size of per-interpreter state of the module,
121+
or -1 if the module keeps state in global variables. */
122+
ParamsMethods,
123+
NULL, /* m_slots */
124+
NULL, /* m_traverse */
125+
0, /* m_clear */
126+
NULL /* m_free */
127+
};
123128

124129
auto PyInit_params = [](void) -> PyObject * {
125130
PyObject *m;
@@ -209,8 +214,7 @@ bool ScipyMinimizer::Minimize()
209214
}
210215
}
211216
PyDict_SetItemString(pyoptions, "maxiter", PyLong_FromLong(MaxIterations()));
212-
if(PrintLevel()>0)
213-
{
217+
if (PrintLevel() > 0) {
214218
PyDict_SetItemString(pyoptions, "disp", Py_True);
215219
}
216220
std::cout << "=== Scipy Minimization" << std::endl;
@@ -281,7 +285,6 @@ bool ScipyMinimizer::Minimize()
281285
bool success = PyLong_AsLong(psuccess);
282286
Py_DECREF(psuccess);
283287

284-
285288
// the x values for the minimum
286289
PyArrayObject *pyx = (PyArrayObject *)PyObject_GetAttrString(result, "x");
287290
const double *x = (const double *)PyArray_DATA(pyx);
@@ -299,7 +302,7 @@ bool ScipyMinimizer::Minimize()
299302
SetFinalValues(x);
300303
auto obj_value = (*gFunction)(x);
301304
SetMinValue(obj_value);
302-
fCalls = nfev; //number of function evaluations
305+
fCalls = nfev; // number of function evaluations
303306

304307
std::cout << "=== Success: " << success << std::endl;
305308
std::cout << "=== Status: " << status << std::endl;
@@ -331,3 +334,72 @@ unsigned int ScipyMinimizer::NCalls() const
331334
{
332335
return fCalls;
333336
}
337+
338+
//_______________________________________________________________________
339+
void ScipyMinimizer::AddConstraintFunction(std::function<double(const std::vector<double> &)> func, std::string type)
340+
{
341+
if (type != "eq" && type != "ineq") {
342+
MATH_ERROR_MSG("ScipyMinimizer::AddConstraintFunction",
343+
Form("Error in constraint type %s, it have to be \"eq\" or \"ineq\"", type.c_str()));
344+
exit(1);
345+
}
346+
static std::function<double(const std::vector<double> &)> cfunt = func;
347+
auto const_function = [](PyObject * /*self*/, PyObject *args) -> PyObject * {
348+
PyArrayObject *arr = (PyArrayObject *)PyTuple_GetItem(args, 0);
349+
350+
uint size = PyArray_SIZE(arr);
351+
auto params = (double *)PyArray_DATA(arr);
352+
std::vector<double> x(params, params + size);
353+
auto r = cfunt(x);
354+
return PyFloat_FromDouble(r);
355+
};
356+
357+
static const char *name = Form("const_function%d", fConstN);
358+
static const char *name_error = Form("const_function%d.error", fConstN);
359+
static PyObject *ConstError;
360+
static PyMethodDef ConstMethods[] = {
361+
{name, const_function, METH_VARARGS, "Constraint function to minimize."}, {NULL, NULL, 0, NULL} /* Sentinel */
362+
};
363+
static struct PyModuleDef constmodule = {
364+
PyModuleDef_HEAD_INIT,
365+
name, /* name of module */
366+
"ROOT Scipy parameters", /* module documentation, may be NULL */
367+
-1, /* size of per-interpreter state of the module,
368+
or -1 if the module keeps state in global variables. */
369+
ConstMethods,
370+
NULL, /* m_slots */
371+
NULL, /* m_traverse */
372+
0, /* m_clear */
373+
NULL /* m_free */
374+
};
375+
376+
auto PyInit_const = [](void) -> PyObject * {
377+
PyObject *m;
378+
379+
m = PyModule_Create(&constmodule);
380+
if (m == NULL)
381+
return NULL;
382+
ConstError = PyErr_NewException(name_error, NULL, NULL);
383+
Py_XINCREF(ConstError);
384+
if (PyModule_AddObject(m, "error", ConstError) < 0) {
385+
Py_XDECREF(ConstError);
386+
Py_CLEAR(ConstError);
387+
Py_DECREF(m);
388+
return NULL;
389+
}
390+
return m;
391+
};
392+
PyImport_AddModule(name);
393+
PyObject *module = PyInit_const();
394+
PyObject *sys_modules = PyImport_GetModuleDict();
395+
PyDict_SetItemString(sys_modules, name, module);
396+
397+
PyRunString(Form("from %s import %s", name, name));
398+
PyObject *pyconstfun = PyDict_GetItemString(fLocalNS, name);
399+
400+
PyObject *pyconst = PyDict_New();
401+
PyDict_SetItemString(pyconst, "type", PyUnicode_FromString(type.c_str()));
402+
PyDict_SetItemString(pyconst, "fun", pyconstfun);
403+
PyList_Append(fConstraintsList, pyconst);
404+
PyPrint(fConstraintsList);
405+
}

0 commit comments

Comments
 (0)