Skip to content

Commit 6a934dc

Browse files
committed
serialization of dense qp wrapper (cpp/py)
- save/load model, results, settings - ignore ruiz eq. and workspace
1 parent b64f522 commit 6a934dc

File tree

6 files changed

+77
-1
lines changed

6 files changed

+77
-1
lines changed

bindings/python/src/expose-qpobject.hpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
#include <pybind11/pybind11.h>
99
#include <pybind11/eigen.h>
1010
#include <pybind11/stl.h>
11+
#ifdef PROXSUITE_WITH_SERIALIZATION
12+
#include <proxsuite/proxqp/serialization/archive.hpp>
13+
#include <proxsuite/proxqp/serialization/wrapper.hpp>
14+
#endif
1115

1216
namespace proxsuite {
1317
namespace proxqp {
@@ -114,7 +118,22 @@ exposeQpObjectDense(pybind11::module_ m)
114118
.def("cleanup",
115119
&dense::QP<T>::cleanup,
116120
"function used for cleaning the workspace and result "
117-
"classes.");
121+
"classes.")
122+
.def(pybind11::self == pybind11::self)
123+
.def(pybind11::self != pybind11::self)
124+
#ifdef PROXSUITE_WITH_SERIALIZATION
125+
.def(pybind11::pickle(
126+
127+
[](const dense::QP<T>& qp) {
128+
return pybind11::bytes(proxsuite::serialization::saveToString(qp));
129+
},
130+
[](pybind11::bytes& s) {
131+
proxsuite::proxqp::dense::QP<T> qp(1, 1, 1);
132+
proxsuite::serialization::loadFromString(qp, s);
133+
return qp;
134+
}));
135+
#endif
136+
;
118137
}
119138
} // namespace python
120139
} // namespace dense

include/proxsuite/proxqp/dense/wrapper.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,23 @@ solve(
459459

460460
return Qp.results;
461461
}
462+
463+
template<typename T>
464+
bool
465+
operator==(const QP<T>& qp1, const QP<T>& qp2)
466+
{
467+
bool value = qp1.model == qp2.model && qp1.settings == qp2.settings &&
468+
qp1.results == qp2.results;
469+
return value;
470+
}
471+
472+
template<typename T>
473+
bool
474+
operator!=(const QP<T>& qp1, const QP<T>& qp2)
475+
{
476+
return !(qp1 == qp2);
477+
}
478+
462479
} // namespace dense
463480
} // namespace proxqp
464481
} // namespace proxsuite
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//
2+
// Copyright (c) 2022 INRIA
3+
//
4+
/**
5+
* @file wrapper.hpp
6+
*/
7+
8+
#ifndef PROXSUITE_SERIALIZATION_WRAPPER_HPP
9+
#define PROXSUITE_SERIALIZATION_WRAPPER_HPP
10+
11+
#include <cereal/cereal.hpp>
12+
#include <proxsuite/proxqp/dense/wrapper.hpp>
13+
14+
namespace cereal {
15+
16+
template<class Archive, typename T>
17+
void
18+
serialize(Archive& archive, proxsuite::proxqp::dense::QP<T>& qp)
19+
{
20+
archive(
21+
CEREAL_NVP(qp.model), CEREAL_NVP(qp.results), CEREAL_NVP(qp.settings));
22+
}
23+
} // namespace cereal
24+
#endif /* end of include guard PROXSUITE_SERIALIZATION_WRAPPER_HPP */

test/src/serialization.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ DOCTEST_TEST_CASE("test serialization of qp model, results and settings")
3838
generic_test(qp.model, TEST_SERIALIZATION_FOLDER "/qp_model");
3939
generic_test(qp.settings, TEST_SERIALIZATION_FOLDER "/qp_settings");
4040
generic_test(qp.results, TEST_SERIALIZATION_FOLDER "/qp_results");
41+
42+
generic_test(qp, TEST_SERIALIZATION_FOLDER "/qp_wrapper");
4143
}
4244

4345
DOCTEST_TEST_CASE(

test/src/serialization.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <proxsuite/proxqp/serialization/model.hpp>
66
#include <proxsuite/proxqp/serialization/results.hpp>
77
#include <proxsuite/proxqp/serialization/settings.hpp>
8+
#include <proxsuite/proxqp/serialization/wrapper.hpp>
89

910
template<typename object>
1011
struct init;
@@ -45,6 +46,18 @@ struct init<proxsuite::proxqp::Settings<T>>
4546
}
4647
};
4748

49+
template<typename T>
50+
struct init<proxsuite::proxqp::dense::QP<T>>
51+
{
52+
typedef proxsuite::proxqp::dense::QP<T> QP;
53+
54+
static QP run()
55+
{
56+
QP qp(1, 0, 0);
57+
return qp;
58+
}
59+
};
60+
4861
template<typename T>
4962
void
5063
generic_test(const T& object, const std::string& filename)

test/src/serialization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def test_pickle(self):
9090
generic_test(qp.model, "qp_model")
9191
generic_test(qp.settings, "qp_settings")
9292
generic_test(qp.results, "qp_results")
93+
generic_test(qp, "qp_wrapper")
9394

9495

9596
if __name__ == "__main__":

0 commit comments

Comments
 (0)