Skip to content

Commit 929a6a1

Browse files
Switch from Pybind11 to Nanobind (#2015)
Co-authored-by: Jeff Fifield <[email protected]>
1 parent df46f74 commit 929a6a1

File tree

4 files changed

+71
-55
lines changed

4 files changed

+71
-55
lines changed

python/AIEMLIRModule.cpp

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,22 @@
1515
#include "mlir-c/IR.h"
1616
#include "mlir-c/Support.h"
1717
#include "mlir/Bindings/Python/Diagnostics.h"
18-
#include "mlir/Bindings/Python/PybindAdaptors.h"
18+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
19+
#include "llvm/ADT/Twine.h"
1920

20-
#include <pybind11/cast.h>
21-
#include <pybind11/detail/common.h>
22-
#include <pybind11/pybind11.h>
21+
#include <nanobind/nanobind.h>
2322

2423
#include <cstdlib>
2524
#include <stdexcept>
2625
#include <string>
2726
#include <unicodeobject.h>
2827
#include <vector>
2928

30-
using namespace mlir::python::adaptors;
31-
namespace py = pybind11;
32-
using namespace py::literals;
29+
using namespace mlir::python;
30+
namespace nb = nanobind;
31+
using namespace nb::literals;
3332

34-
PYBIND11_MODULE(_aie, m) {
33+
NB_MODULE(_aie, m) {
3534

3635
aieRegisterAllPasses();
3736

@@ -48,33 +47,35 @@ PYBIND11_MODULE(_aie, m) {
4847
"registry"_a);
4948

5049
// AIE types bindings
51-
mlir_type_subclass(m, "ObjectFifoType", aieTypeIsObjectFifoType)
50+
nanobind_adaptors::mlir_type_subclass(m, "ObjectFifoType",
51+
aieTypeIsObjectFifoType)
5252
.def_classmethod(
5353
"get",
54-
[](const py::object &cls, const MlirType type) {
54+
[](const nb::object &cls, const MlirType type) {
5555
return cls(aieObjectFifoTypeGet(type));
5656
},
5757
"Get an instance of ObjectFifoType with given element type.",
58-
"self"_a, "type"_a = py::none());
58+
"self"_a, "type"_a = nb::none());
5959

60-
mlir_type_subclass(m, "ObjectFifoSubviewType", aieTypeIsObjectFifoSubviewType)
60+
nanobind_adaptors::mlir_type_subclass(m, "ObjectFifoSubviewType",
61+
aieTypeIsObjectFifoSubviewType)
6162
.def_classmethod(
6263
"get",
63-
[](const py::object &cls, const MlirType type) {
64+
[](const nb::object &cls, const MlirType type) {
6465
return cls(aieObjectFifoSubviewTypeGet(type));
6566
},
6667
"Get an instance of ObjectFifoSubviewType with given element type.",
67-
"self"_a, "type"_a = py::none());
68+
"self"_a, "type"_a = nb::none());
6869

6970
auto stealCStr = [](MlirStringRef mlirString) {
7071
if (!mlirString.data || mlirString.length == 0)
7172
throw std::runtime_error("couldn't translate");
7273
std::string cpp(mlirString.data, mlirString.length);
7374
free((void *)mlirString.data);
74-
py::handle pyS = PyUnicode_DecodeLatin1(cpp.data(), cpp.length(), nullptr);
75+
nb::handle pyS = PyUnicode_DecodeLatin1(cpp.data(), cpp.length(), nullptr);
7576
if (!pyS)
76-
throw py::error_already_set();
77-
return py::reinterpret_steal<py::str>(pyS);
77+
throw nb::python_error();
78+
return nb::steal<nb::str>(pyS);
7879
};
7980

8081
m.def(
@@ -101,28 +102,31 @@ PYBIND11_MODULE(_aie, m) {
101102
if (mlirLogicalResultIsFailure(aieTranslateToCDODirect(
102103
op, {workDirPath.data(), workDirPath.size()}, bigendian,
103104
emitUnified, cdoDebug, aieSim, xaieDebug, enableCores)))
104-
throw py::value_error("Failed to generate cdo because: " +
105-
scope.takeMessage());
105+
throw nb::value_error(
106+
(llvm::Twine("Failed to generate cdo because: ") +
107+
llvm::Twine(scope.takeMessage()))
108+
.str()
109+
.c_str());
106110
},
107111
"module"_a, "work_dir_path"_a, "bigendian"_a = false,
108112
"emit_unified"_a = false, "cdo_debug"_a = false, "aiesim"_a = false,
109113
"xaie_debug"_a = false, "enable_cores"_a = true);
110114

111115
m.def(
112116
"transaction_binary_to_mlir",
113-
[](MlirContext ctx, py::bytes bytes) {
114-
std::string s = bytes;
115-
MlirStringRef bin = {s.data(), s.size()};
117+
[](MlirContext ctx, nb::bytes bytes) {
118+
MlirStringRef bin = {static_cast<const char *>(bytes.data()),
119+
bytes.size()};
116120
return aieTranslateBinaryToTxn(ctx, bin);
117121
},
118122
"ctx"_a, "binary"_a);
119123

120124
m.def(
121125
"npu_instgen",
122126
[&stealCStr](MlirOperation op) {
123-
py::str npuInstructions = stealCStr(aieTranslateToNPU(op));
127+
nb::str npuInstructions = stealCStr(aieTranslateToNPU(op));
124128
auto individualInstructions =
125-
npuInstructions.attr("split")().cast<py::list>();
129+
nb::cast<nb::list>(npuInstructions.attr("split")());
126130
for (size_t i = 0; i < individualInstructions.size(); ++i)
127131
individualInstructions[i] = individualInstructions[i].attr("strip")();
128132
return individualInstructions;
@@ -132,10 +136,10 @@ PYBIND11_MODULE(_aie, m) {
132136
m.def(
133137
"generate_control_packets",
134138
[&stealCStr](MlirOperation op) {
135-
py::str ctrlPackets =
139+
nb::str ctrlPackets =
136140
stealCStr(aieTranslateControlPacketsToUI32Vec(op));
137141
auto individualInstructions =
138-
ctrlPackets.attr("split")().cast<py::list>();
142+
nb::cast<nb::list>(ctrlPackets.attr("split")());
139143
for (size_t i = 0; i < individualInstructions.size(); ++i)
140144
individualInstructions[i] = individualInstructions[i].attr("strip")();
141145
return individualInstructions;
@@ -171,7 +175,7 @@ PYBIND11_MODULE(_aie, m) {
171175
m.def("get_target_model",
172176
[](uint32_t d) -> PyAieTargetModel { return aieGetTargetModel(d); });
173177

174-
py::class_<PyAieTargetModel>(m, "AIETargetModel", py::module_local())
178+
nb::class_<PyAieTargetModel>(m, "AIETargetModel")
175179
.def(
176180
"columns",
177181
[](PyAieTargetModel &self) {

python/AIERTModule.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313

1414
#include "aie/Bindings/PyTypes.h"
1515

16-
#include <pybind11/pybind11.h>
17-
#include <pybind11/pytypes.h>
16+
#include <nanobind/nanobind.h>
1817

1918
#include <algorithm>
2019

21-
namespace py = pybind11;
22-
using namespace py::literals;
20+
namespace nb = nanobind;
21+
using namespace nb::literals;
2322

2423
class PyAIERTControl {
2524
public:
@@ -31,10 +30,10 @@ class PyAIERTControl {
3130
AieRtControl ctl;
3231
};
3332

34-
PYBIND11_MODULE(_aiert, m) {
33+
NB_MODULE(_aiert, m) {
3534

36-
py::class_<PyAIERTControl>(m, "AIERTControl", py::module_local())
37-
.def(py::init<PyAieTargetModel>(), "target_model"_a)
35+
nb::class_<PyAIERTControl>(m, "AIERTControl")
36+
.def(nb::init<PyAieTargetModel>(), "target_model"_a)
3837
.def("start_transaction",
3938
[](PyAIERTControl &self) { aieRtStartTransaction(self.ctl); })
4039
.def("export_serialized_transaction",

python/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,9 @@ if (AIE_ENABLE_PYTHON_PASSES)
191191

192192
PRIVATE_LINK_LIBS
193193
${_py_libs}
194+
195+
PYTHON_BINDINGS_LIBRARY
196+
nanobind
194197
)
195198
target_include_directories(
196199
AIEPythonExtensions.MLIR
@@ -282,6 +285,8 @@ else ()
282285
AIECAPI
283286
PRIVATE_LINK_LIBS
284287
LLVMSupport
288+
PYTHON_BINDINGS_LIBRARY
289+
nanobind
285290
)
286291

287292
if(AIE_ENABLE_XRT_PYTHON_BINDINGS)
@@ -298,6 +303,8 @@ else ()
298303
LLVMSupport
299304
xrt_coreutil
300305
uuid
306+
PYTHON_BINDINGS_LIBRARY
307+
nanobind
301308
)
302309
target_include_directories(AIEPythonExtensions.XRT INTERFACE ${XRT_INCLUDE_DIR})
303310
target_link_directories(AIEPythonExtensions.XRT INTERFACE ${XRT_LIB_DIR})
@@ -314,6 +321,9 @@ else ()
314321

315322
PRIVATE_LINK_LIBS
316323
LLVMSupport
324+
325+
PYTHON_BINDINGS_LIBRARY
326+
nanobind
317327
)
318328

319329
add_mlir_python_common_capi_library(AIEAggregateCAPI

python/XRTModule.cpp

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,19 @@
1313
#include "xrt/xrt_device.h"
1414
#include "xrt/xrt_kernel.h"
1515

16-
#include <pybind11/numpy.h>
17-
#include <pybind11/pybind11.h>
18-
#include <pybind11/pytypes.h>
19-
#include <pybind11/stl.h>
16+
#include <nanobind/nanobind.h>
17+
#include <nanobind/ndarray.h>
18+
#include <nanobind/stl/string.h>
19+
#include <nanobind/stl/vector.h>
2020

2121
#include <algorithm>
22+
#include <numeric>
23+
#include <optional>
2224
#include <string>
2325
#include <vector>
2426

25-
namespace py = pybind11;
26-
using namespace py::literals;
27+
namespace nb = nanobind;
28+
using namespace nb::literals;
2729

2830
// group_id 0 is for npu instructions
2931
// group_id 1 is for number of npu instructions
@@ -55,16 +57,16 @@ class PyXCLBin {
5557
}
5658

5759
template <typename ElementT>
58-
std::vector<py::memoryview>
59-
mmapBuffers(std::vector<std::vector<int>> shapes) {
60+
std::vector<nb::ndarray<>>
61+
mmapBuffers(std::vector<std::vector<size_t>> shapes) {
6062
this->buffers.reserve(shapes.size());
61-
std::vector<py::memoryview> views;
63+
std::vector<nb::ndarray<>> views;
6264
views.reserve(shapes.size());
6365

6466
auto initAndViewBuffer = [this](
65-
std::vector<int> shape, int groupId,
67+
std::vector<size_t> shape, int groupId,
6668
std::vector<std::unique_ptr<xrt::bo>> &buffers,
67-
std::vector<py::memoryview> &views) {
69+
std::vector<nb::ndarray<>> &views) {
6870
int nElements =
6971
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
7072
int nBytes = nElements * sizeof(ElementT);
@@ -79,12 +81,13 @@ class PyXCLBin {
7981
std::vector strides_{1};
8082
for (int i = shape.size() - 1; i > 0; i--)
8183
strides_.push_back(strides_.back() * shape[i]);
82-
std::vector<int> strides;
84+
std::vector<int64_t> strides;
8385
// stride in bytes
8486
std::transform(strides_.rbegin(), strides_.rend(),
8587
std::back_inserter(strides),
8688
[](int s) { return s * sizeof(ElementT); });
87-
views.push_back(py::memoryview::from_buffer(buf, shape, strides));
89+
views.push_back(nb::ndarray(buf, shape.size(), shape.data(), nb::handle(),
90+
strides.data()));
8891
};
8992

9093
for (size_t i = 0; i < shapes.size(); ++i)
@@ -140,22 +143,22 @@ class PyXCLBin {
140143
std::unique_ptr<xrt::run> run_;
141144
};
142145

143-
PYBIND11_MODULE(_xrt, m) {
146+
NB_MODULE(_xrt, m) {
144147

145-
py::class_<PyXCLBin>(m, "XCLBin", py::module_local())
146-
.def(py::init<const std::string &, const std::string &, int>(),
148+
nb::class_<PyXCLBin>(m, "XCLBin")
149+
.def(nb::init<const std::string &, const std::string &, int>(),
147150
"xclbin_path"_a, "kernel_name"_a, "device_index"_a = 0)
148151
.def("load_npu_instructions", &PyXCLBin::loadNPUInstructions, "insts"_a)
149152
.def("sync_buffers_to_device", &PyXCLBin::syncBuffersToDevice)
150153
.def("sync_buffers_from_device", &PyXCLBin::syncBuffersFromDevice)
151154
.def("run", &PyXCLBin::run)
152155
.def("_run_only_npu_instructions", &PyXCLBin::_runOnlyNpuInstructions)
153-
.def("wait", &PyXCLBin::wait, "timeout"_a = py::none())
156+
.def("wait", &PyXCLBin::wait, "timeout"_a = nb::none())
154157
.def(
155158
"mmap_buffers",
156-
[](PyXCLBin &self, const std::vector<std::vector<int>> &shapes,
157-
const py::object &npFormat) {
158-
auto npy = py::module_::import("numpy");
159+
[](PyXCLBin &self, const std::vector<std::vector<size_t>> &shapes,
160+
const nb::object &npFormat) {
161+
auto npy = nb::module_::import_("numpy");
159162
if (npFormat.is(npy.attr("int16")))
160163
return self.mmapBuffers<int16_t>(shapes);
161164
if (npFormat.is(npy.attr("int32")))
@@ -167,7 +170,7 @@ PYBIND11_MODULE(_xrt, m) {
167170
if (npFormat.is(npy.attr("float64")))
168171
return self.mmapBuffers<double>(shapes);
169172
throw std::runtime_error("unsupported np format: " +
170-
py::repr(npFormat).cast<std::string>());
173+
nb::cast<std::string>(nb::repr(npFormat)));
171174
},
172175
"shapes"_a, "np_format"_a)
173176
.def("_get_buffer_host_address", [](PyXCLBin &self, size_t idx) {

0 commit comments

Comments
 (0)