Skip to content

Commit 9033d3c

Browse files
committed
Implement custom Pybind11 type casters to improve the API readability.
Improve conversion implementation by removing intermediate conversions.
1 parent ac1c886 commit 9033d3c

17 files changed

+973
-707
lines changed

bindings/Python/conversion_from_python.hpp

Lines changed: 0 additions & 453 deletions
This file was deleted.

bindings/Python/conversion_to_python.hpp

Lines changed: 0 additions & 95 deletions
This file was deleted.

bindings/Python/data_set/classification_data_set.cpp

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
#include "plssvm/detail/type_traits.hpp" // plssvm::detail::remove_cvref_t
1313
#include "plssvm/file_format_types.hpp" // plssvm::file_format_type
1414

15-
#include "bindings/Python/conversion_from_python.hpp" // plssvm::bindings::python::util::{pyobject_to_matrix, pyobject_to_vector}
16-
#include "bindings/Python/conversion_to_python.hpp" // plssvm::bindings::python::util::{matrix_to_pyarray, vector_to_pyarray}
17-
#include "bindings/Python/data_set/variant_wrapper.hpp" // plssvm::bindings::python::util::classification_data_set_wrapper
18-
#include "bindings/Python/utility.hpp" // plssvm::bindings::python::util::{create_instance, python_type_name_mapping}
15+
#include "bindings/Python/data_set/variant_wrapper.hpp" // plssvm::bindings::python::util::classification_data_set_wrapper
16+
#include "bindings/Python/type_caster/label_vector_wrapper_caster.hpp" // a custom Pybind11 type caster for a plssvm::bindings::python::util::label_vector_wrapper
17+
#include "bindings/Python/type_caster/matrix_type_caster.hpp" // a custom Pybind11 type caster for a plssvm::matrix
18+
#include "bindings/Python/utility.hpp" // plssvm::bindings::python::util::{create_instance, python_type_name_mapping, vector_to_pyarray}
1919

2020
#include "fmt/format.h" // fmt::format
2121
#include "fmt/ranges.h" // fmt::join
2222
#include "pybind11/numpy.h" // py::array_t, py::array
23-
#include "pybind11/pybind11.h" // py::module_, py::class_, py::init, py::arg, py::pos_only, py::object, py::attribute_error
23+
#include "pybind11/pybind11.h" // py::module_, py::class_, py::init, py::arg, py::pos_only, py::attribute_error
2424
#include "pybind11/pytypes.h" // py::type
2525
#include "pybind11/stl.h" // support for STL types
2626

@@ -57,21 +57,18 @@ void init_classification_data_set(py::module_ &m) {
5757
py::arg("type") = std::nullopt,
5858
py::arg("format") = plssvm::file_format_type::libsvm,
5959
py::arg("scaler") = std::nullopt)
60-
.def(py::init([](py::object data, const std::optional<py::type> type, const std::optional<plssvm::min_max_scaler> scaler) {
61-
// convert the data py::object to a plssvm::aos_matrix
62-
auto [data_matrix, opt_feature_names] = plssvm::bindings::python::util::pyobject_to_matrix(data);
63-
60+
.def(py::init([](plssvm::soa_matrix<plssvm::real_type> data, const std::optional<py::type> type, const std::optional<plssvm::min_max_scaler> scaler) {
6461
if (type.has_value()) {
6562
if (scaler.has_value()) {
66-
return std::make_unique<classification_data_set_wrapper>(plssvm::bindings::python::util::create_instance<plssvm::classification_data_set, typename classification_data_set_wrapper::possible_data_set_types>(type.value(), std::move(data_matrix), scaler.value()));
63+
return std::make_unique<classification_data_set_wrapper>(plssvm::bindings::python::util::create_instance<plssvm::classification_data_set, typename classification_data_set_wrapper::possible_data_set_types>(type.value(), std::move(data), scaler.value()));
6764
} else {
68-
return std::make_unique<classification_data_set_wrapper>(plssvm::bindings::python::util::create_instance<plssvm::classification_data_set, typename classification_data_set_wrapper::possible_data_set_types>(type.value(), std::move(data_matrix)));
65+
return std::make_unique<classification_data_set_wrapper>(plssvm::bindings::python::util::create_instance<plssvm::classification_data_set, typename classification_data_set_wrapper::possible_data_set_types>(type.value(), std::move(data)));
6966
}
7067
} else {
7168
if (scaler.has_value()) {
72-
return std::make_unique<classification_data_set_wrapper>(plssvm::classification_data_set<std::string>{ std::move(data_matrix), scaler.value() });
69+
return std::make_unique<classification_data_set_wrapper>(plssvm::classification_data_set<std::string>{ std::move(data), scaler.value() });
7370
} else {
74-
return std::make_unique<classification_data_set_wrapper>(plssvm::classification_data_set<std::string>{ std::move(data_matrix) });
71+
return std::make_unique<classification_data_set_wrapper>(plssvm::classification_data_set<std::string>{ std::move(data) });
7572
}
7673
}
7774
}),
@@ -80,29 +77,24 @@ void init_classification_data_set(py::module_ &m) {
8077
py::pos_only(),
8178
py::arg("type") = std::nullopt,
8279
py::arg("scaler") = std::nullopt)
83-
.def(py::init([](py::object data, py::object labels, const std::optional<plssvm::min_max_scaler> scaler) {
84-
// convert the data py::object to a plssvm::aos_matrix
85-
auto [data_matrix, opt_feature_names] = plssvm::bindings::python::util::pyobject_to_matrix(data);
86-
// convert the labels to a std::vector
87-
auto [labels_vector_variant, dtype] = plssvm::bindings::python::util::pyobject_to_vector<typename classification_data_set_wrapper::possible_vector_types>(labels);
88-
89-
return std::visit([&data_matrix = data_matrix, &dtype = dtype, &scaler](auto &&labels_vector) {
80+
.def(py::init([](plssvm::soa_matrix<plssvm::real_type> data, plssvm::bindings::python::util::label_vector_wrapper<typename classification_data_set_wrapper::possible_vector_types> labels, const std::optional<plssvm::min_max_scaler> scaler) {
81+
return std::visit([&](auto &&labels_vector) {
9082
using label_type = typename plssvm::detail::remove_cvref_t<decltype(labels_vector)>::value_type;
9183
if (scaler.has_value()) {
92-
return std::make_unique<classification_data_set_wrapper>(plssvm::classification_data_set<label_type>(std::move(data_matrix), std::move(labels_vector), scaler.value()));
84+
return std::make_unique<classification_data_set_wrapper>(plssvm::classification_data_set<label_type>(std::move(data), std::move(labels_vector), scaler.value()));
9385
} else {
94-
return std::make_unique<classification_data_set_wrapper>(plssvm::classification_data_set<label_type>(std::move(data_matrix), std::move(labels_vector)));
86+
return std::make_unique<classification_data_set_wrapper>(plssvm::classification_data_set<label_type>(std::move(data), std::move(labels_vector)));
9587
}
9688
},
97-
labels_vector_variant);
89+
labels.labels);
9890
}),
9991
"create a new data set from the provided data and labels and additional optional parameters",
10092
py::arg("X"),
10193
py::arg("y"),
10294
py::pos_only(),
10395
py::arg("scaler") = std::nullopt)
10496
.def("save", [](const classification_data_set_wrapper &self, const std::string &filename, const plssvm::file_format_type format) { std::visit([&filename, format](auto &&data) { data.save(filename, format); }, self.data_set); }, "save the data set to a file using the provided file format type", py::arg("filename"), py::pos_only(), py::arg("format") = plssvm::file_format_type::libsvm)
105-
.def("data", [](const classification_data_set_wrapper &self) { return std::visit([](auto &&data) { return plssvm::bindings::python::util::matrix_to_pyarray(data.data()); }, self.data_set); }, "the data saved as 2D vector")
97+
.def("data", [](const classification_data_set_wrapper &self) { return std::visit([](auto &&data) { return py::cast(data.data()); }, self.data_set); }, "the data saved as 2D vector")
10698
.def("has_labels", [](const classification_data_set_wrapper &self) { return std::visit([](auto &&data) { return data.has_labels(); }, self.data_set); }, "check whether the data set has labels")
10799
// clang-format off
108100
.def("labels", [](const classification_data_set_wrapper &self) {

bindings/Python/data_set/min_max_scaler.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
#include "plssvm/constants.hpp" // plssvm::real_type
1212

13-
#include "bindings/Python/conversion_to_python.hpp" // plssvm::bindings::python::util::vector_to_pyarray
13+
#include "bindings/Python/utility.hpp" // plssvm::bindings::python::util::vector_to_pyarray
1414

1515
#include "fmt/format.h" // fmt::format
1616
#include "pybind11/numpy.h" // py::array
@@ -54,7 +54,8 @@ void init_min_max_scaler(py::module_ &m) {
5454
throw py::value_error{ fmt::format("MinMaxScaler can only be created from two interval values (lower, upper), but {} were provided!", interval.size()) };
5555
}
5656
return plssvm::min_max_scaler{ interval[0].cast<plssvm::real_type>(), interval[1].cast<plssvm::real_type>() };
57-
}), "create new scaling factors for the range [lower, upper]")
57+
}),
58+
"create new scaling factors for the range [lower, upper]")
5859
.def("save", &plssvm::min_max_scaler::save, "save the scaling factors to a file")
5960
.def("scaling_interval", &plssvm::min_max_scaler::scaling_interval, "the interval to which the data points are scaled")
6061
.def(

0 commit comments

Comments
 (0)