Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 190b37f

Browse files
hawkinspjpienaar
andauthored
[mlir python] Port in-tree dialects to nanobind. (#119924)
This is a companion to #118583, although it can be landed independently because since #117922 dialects do not have to use the same Python binding framework as the Python core code. This PR ports all of the in-tree dialect and pass extensions to nanobind, with the exception of those that remain for testing pybind11 support. This PR also: * removes CollectDiagnosticsToStringScope from NanobindAdaptors.h. This was overlooked in a previous PR and it is duplicated in Diagnostics.h. --------- Co-authored-by: Jacques Pienaar <[email protected]>
1 parent a054b3e commit 190b37f

26 files changed

+263
-299
lines changed

AsyncPasses.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88

99
#include "mlir-c/Dialect/Async.h"
1010

11-
#include <pybind11/detail/common.h>
12-
#include <pybind11/pybind11.h>
11+
#include "mlir/Bindings/Python/Nanobind.h"
1312

1413
// -----------------------------------------------------------------------------
1514
// Module initialization.
1615
// -----------------------------------------------------------------------------
1716

18-
PYBIND11_MODULE(_mlirAsyncPasses, m) {
17+
NB_MODULE(_mlirAsyncPasses, m) {
1918
m.doc() = "MLIR Async Dialect Passes";
2019

2120
// Register all Async passes on load.

DialectGPU.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,21 @@
99
#include "mlir-c/Dialect/GPU.h"
1010
#include "mlir-c/IR.h"
1111
#include "mlir-c/Support.h"
12-
#include "mlir/Bindings/Python/PybindAdaptors.h"
12+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
13+
#include "mlir/Bindings/Python/Nanobind.h"
1314

14-
#include <pybind11/detail/common.h>
15-
#include <pybind11/pybind11.h>
15+
namespace nb = nanobind;
16+
using namespace nanobind::literals;
1617

17-
namespace py = pybind11;
1818
using namespace mlir;
1919
using namespace mlir::python;
20-
using namespace mlir::python::adaptors;
20+
using namespace mlir::python::nanobind_adaptors;
2121

2222
// -----------------------------------------------------------------------------
2323
// Module initialization.
2424
// -----------------------------------------------------------------------------
2525

26-
PYBIND11_MODULE(_mlirDialectsGPU, m) {
26+
NB_MODULE(_mlirDialectsGPU, m) {
2727
m.doc() = "MLIR GPU Dialect";
2828
//===-------------------------------------------------------------------===//
2929
// AsyncTokenType
@@ -34,11 +34,11 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
3434

3535
mlirGPUAsyncTokenType.def_classmethod(
3636
"get",
37-
[](py::object cls, MlirContext ctx) {
37+
[](nb::object cls, MlirContext ctx) {
3838
return cls(mlirGPUAsyncTokenTypeGet(ctx));
3939
},
40-
"Gets an instance of AsyncTokenType in the same context", py::arg("cls"),
41-
py::arg("ctx") = py::none());
40+
"Gets an instance of AsyncTokenType in the same context", nb::arg("cls"),
41+
nb::arg("ctx").none() = nb::none());
4242

4343
//===-------------------------------------------------------------------===//
4444
// ObjectAttr
@@ -47,12 +47,12 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
4747
mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
4848
.def_classmethod(
4949
"get",
50-
[](py::object cls, MlirAttribute target, uint32_t format,
51-
py::bytes object, std::optional<MlirAttribute> mlirObjectProps,
50+
[](nb::object cls, MlirAttribute target, uint32_t format,
51+
nb::bytes object, std::optional<MlirAttribute> mlirObjectProps,
5252
std::optional<MlirAttribute> mlirKernelsAttr) {
53-
py::buffer_info info(py::buffer(object).request());
54-
MlirStringRef objectStrRef =
55-
mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
53+
MlirStringRef objectStrRef = mlirStringRefCreate(
54+
static_cast<char *>(const_cast<void *>(object.data())),
55+
object.size());
5656
return cls(mlirGPUObjectAttrGetWithKernels(
5757
mlirAttributeGetContext(target), target, format, objectStrRef,
5858
mlirObjectProps.has_value() ? *mlirObjectProps
@@ -61,7 +61,7 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
6161
: MlirAttribute{nullptr}));
6262
},
6363
"cls"_a, "target"_a, "format"_a, "object"_a,
64-
"properties"_a = py::none(), "kernels"_a = py::none(),
64+
"properties"_a.none() = nb::none(), "kernels"_a.none() = nb::none(),
6565
"Gets a gpu.object from parameters.")
6666
.def_property_readonly(
6767
"target",
@@ -73,18 +73,18 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) {
7373
"object",
7474
[](MlirAttribute self) {
7575
MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
76-
return py::bytes(stringRef.data, stringRef.length);
76+
return nb::bytes(stringRef.data, stringRef.length);
7777
})
7878
.def_property_readonly("properties",
79-
[](MlirAttribute self) {
79+
[](MlirAttribute self) -> nb::object {
8080
if (mlirGPUObjectAttrHasProperties(self))
81-
return py::cast(
81+
return nb::cast(
8282
mlirGPUObjectAttrGetProperties(self));
83-
return py::none().cast<py::object>();
83+
return nb::none();
8484
})
85-
.def_property_readonly("kernels", [](MlirAttribute self) {
85+
.def_property_readonly("kernels", [](MlirAttribute self) -> nb::object {
8686
if (mlirGPUObjectAttrHasKernels(self))
87-
return py::cast(mlirGPUObjectAttrGetKernels(self));
88-
return py::none().cast<py::object>();
87+
return nb::cast(mlirGPUObjectAttrGetKernels(self));
88+
return nb::none();
8989
});
9090
}

DialectLLVM.cpp

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,19 @@
1212
#include "mlir-c/IR.h"
1313
#include "mlir-c/Support.h"
1414
#include "mlir/Bindings/Python/Diagnostics.h"
15-
#include "mlir/Bindings/Python/PybindAdaptors.h"
15+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
16+
#include "mlir/Bindings/Python/Nanobind.h"
17+
18+
namespace nb = nanobind;
19+
20+
using namespace nanobind::literals;
1621

17-
namespace py = pybind11;
1822
using namespace llvm;
1923
using namespace mlir;
2024
using namespace mlir::python;
21-
using namespace mlir::python::adaptors;
25+
using namespace mlir::python::nanobind_adaptors;
2226

23-
void populateDialectLLVMSubmodule(const pybind11::module &m) {
27+
void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
2428

2529
//===--------------------------------------------------------------------===//
2630
// StructType
@@ -31,58 +35,58 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
3135

3236
llvmStructType.def_classmethod(
3337
"get_literal",
34-
[](py::object cls, const std::vector<MlirType> &elements, bool packed,
38+
[](nb::object cls, const std::vector<MlirType> &elements, bool packed,
3539
MlirLocation loc) {
3640
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
3741

3842
MlirType type = mlirLLVMStructTypeLiteralGetChecked(
3943
loc, elements.size(), elements.data(), packed);
4044
if (mlirTypeIsNull(type)) {
41-
throw py::value_error(scope.takeMessage());
45+
throw nb::value_error(scope.takeMessage().c_str());
4246
}
4347
return cls(type);
4448
},
45-
"cls"_a, "elements"_a, py::kw_only(), "packed"_a = false,
46-
"loc"_a = py::none());
49+
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
50+
"loc"_a.none() = nb::none());
4751

4852
llvmStructType.def_classmethod(
4953
"get_identified",
50-
[](py::object cls, const std::string &name, MlirContext context) {
54+
[](nb::object cls, const std::string &name, MlirContext context) {
5155
return cls(mlirLLVMStructTypeIdentifiedGet(
5256
context, mlirStringRefCreate(name.data(), name.size())));
5357
},
54-
"cls"_a, "name"_a, py::kw_only(), "context"_a = py::none());
58+
"cls"_a, "name"_a, nb::kw_only(), "context"_a.none() = nb::none());
5559

5660
llvmStructType.def_classmethod(
5761
"get_opaque",
58-
[](py::object cls, const std::string &name, MlirContext context) {
62+
[](nb::object cls, const std::string &name, MlirContext context) {
5963
return cls(mlirLLVMStructTypeOpaqueGet(
6064
context, mlirStringRefCreate(name.data(), name.size())));
6165
},
62-
"cls"_a, "name"_a, "context"_a = py::none());
66+
"cls"_a, "name"_a, "context"_a.none() = nb::none());
6367

6468
llvmStructType.def(
6569
"set_body",
6670
[](MlirType self, const std::vector<MlirType> &elements, bool packed) {
6771
MlirLogicalResult result = mlirLLVMStructTypeSetBody(
6872
self, elements.size(), elements.data(), packed);
6973
if (!mlirLogicalResultIsSuccess(result)) {
70-
throw py::value_error(
74+
throw nb::value_error(
7175
"Struct body already set to different content.");
7276
}
7377
},
74-
"elements"_a, py::kw_only(), "packed"_a = false);
78+
"elements"_a, nb::kw_only(), "packed"_a = false);
7579

7680
llvmStructType.def_classmethod(
7781
"new_identified",
78-
[](py::object cls, const std::string &name,
82+
[](nb::object cls, const std::string &name,
7983
const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
8084
return cls(mlirLLVMStructTypeIdentifiedNewGet(
8185
ctx, mlirStringRefCreate(name.data(), name.length()),
8286
elements.size(), elements.data(), packed));
8387
},
84-
"cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false,
85-
"context"_a = py::none());
88+
"cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
89+
"context"_a.none() = nb::none());
8690

8791
llvmStructType.def_property_readonly(
8892
"name", [](MlirType type) -> std::optional<std::string> {
@@ -93,12 +97,12 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
9397
return StringRef(stringRef.data, stringRef.length).str();
9498
});
9599

96-
llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
100+
llvmStructType.def_property_readonly("body", [](MlirType type) -> nb::object {
97101
// Don't crash in absence of a body.
98102
if (mlirLLVMStructTypeIsOpaque(type))
99-
return py::none();
103+
return nb::none();
100104

101-
py::list body;
105+
nb::list body;
102106
for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
103107
++i) {
104108
body.append(mlirLLVMStructTypeGetElementType(type, i));
@@ -119,24 +123,24 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
119123
mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
120124
.def_classmethod(
121125
"get",
122-
[](py::object cls, std::optional<unsigned> addressSpace,
126+
[](nb::object cls, std::optional<unsigned> addressSpace,
123127
MlirContext context) {
124128
CollectDiagnosticsToStringScope scope(context);
125129
MlirType type = mlirLLVMPointerTypeGet(
126130
context, addressSpace.has_value() ? *addressSpace : 0);
127131
if (mlirTypeIsNull(type)) {
128-
throw py::value_error(scope.takeMessage());
132+
throw nb::value_error(scope.takeMessage().c_str());
129133
}
130134
return cls(type);
131135
},
132-
"cls"_a, "address_space"_a = py::none(), py::kw_only(),
133-
"context"_a = py::none())
136+
"cls"_a, "address_space"_a.none() = nb::none(), nb::kw_only(),
137+
"context"_a.none() = nb::none())
134138
.def_property_readonly("address_space", [](MlirType type) {
135139
return mlirLLVMPointerTypeGetAddressSpace(type);
136140
});
137141
}
138142

139-
PYBIND11_MODULE(_mlirDialectsLLVM, m) {
143+
NB_MODULE(_mlirDialectsLLVM, m) {
140144
m.doc() = "MLIR LLVM Dialect";
141145

142146
populateDialectLLVMSubmodule(m);

DialectLinalg.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,21 @@
88

99
#include "mlir-c/Dialect/Linalg.h"
1010
#include "mlir-c/IR.h"
11-
#include "mlir/Bindings/Python/PybindAdaptors.h"
11+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
12+
#include "mlir/Bindings/Python/Nanobind.h"
1213

13-
namespace py = pybind11;
14+
namespace nb = nanobind;
1415

15-
static void populateDialectLinalgSubmodule(py::module m) {
16+
static void populateDialectLinalgSubmodule(nb::module_ m) {
1617
m.def(
1718
"fill_builtin_region",
1819
[](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); },
19-
py::arg("op"),
20+
nb::arg("op"),
2021
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
2122
"op.");
2223
}
2324

24-
PYBIND11_MODULE(_mlirDialectsLinalg, m) {
25+
NB_MODULE(_mlirDialectsLinalg, m) {
2526
m.doc() = "MLIR Linalg dialect.";
2627

2728
populateDialectLinalgSubmodule(m);

DialectNVGPU.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,33 +8,33 @@
88

99
#include "mlir-c/Dialect/NVGPU.h"
1010
#include "mlir-c/IR.h"
11-
#include "mlir/Bindings/Python/PybindAdaptors.h"
12-
#include <pybind11/pybind11.h>
11+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
12+
#include "mlir/Bindings/Python/Nanobind.h"
1313

14-
namespace py = pybind11;
14+
namespace nb = nanobind;
1515
using namespace llvm;
1616
using namespace mlir;
1717
using namespace mlir::python;
18-
using namespace mlir::python::adaptors;
18+
using namespace mlir::python::nanobind_adaptors;
1919

20-
static void populateDialectNVGPUSubmodule(const pybind11::module &m) {
20+
static void populateDialectNVGPUSubmodule(const nb::module_ &m) {
2121
auto nvgpuTensorMapDescriptorType = mlir_type_subclass(
2222
m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType);
2323

2424
nvgpuTensorMapDescriptorType.def_classmethod(
2525
"get",
26-
[](py::object cls, MlirType tensorMemrefType, int swizzle, int l2promo,
26+
[](nb::object cls, MlirType tensorMemrefType, int swizzle, int l2promo,
2727
int oobFill, int interleave, MlirContext ctx) {
2828
return cls(mlirNVGPUTensorMapDescriptorTypeGet(
2929
ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave));
3030
},
3131
"Gets an instance of TensorMapDescriptorType in the same context",
32-
py::arg("cls"), py::arg("tensor_type"), py::arg("swizzle"),
33-
py::arg("l2promo"), py::arg("oob_fill"), py::arg("interleave"),
34-
py::arg("ctx") = py::none());
32+
nb::arg("cls"), nb::arg("tensor_type"), nb::arg("swizzle"),
33+
nb::arg("l2promo"), nb::arg("oob_fill"), nb::arg("interleave"),
34+
nb::arg("ctx").none() = nb::none());
3535
}
3636

37-
PYBIND11_MODULE(_mlirDialectsNVGPU, m) {
37+
NB_MODULE(_mlirDialectsNVGPU, m) {
3838
m.doc() = "MLIR NVGPU dialect.";
3939

4040
populateDialectNVGPUSubmodule(m);

0 commit comments

Comments
 (0)