Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit a145441

Browse files
authored
[mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims (#134935)
This PR is mainly about exposing the python bindings for` linalg::isaContractionOpInterface` and` linalg::inferContractionDims`. --------- Signed-off-by: Bangtian Liu <[email protected]>
1 parent 6457178 commit a145441

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

DialectLinalg.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,25 @@
88

99
#include "mlir-c/Dialect/Linalg.h"
1010
#include "mlir-c/IR.h"
11-
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1211
#include "mlir/Bindings/Python/Nanobind.h"
12+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1313

1414
namespace nb = nanobind;
15+
using namespace mlir::python::nanobind_adaptors;
16+
17+
static std::optional<MlirLinalgContractionDimensions>
18+
InferContractionDimensions(MlirOperation op) {
19+
MlirLinalgContractionDimensions dims =
20+
mlirLinalgInferContractionDimensions(op);
21+
22+
// Detect "empty" result. This occurs when `op` is not a contraction op,
23+
// or when `linalg::inferContractionDims` fails.
24+
if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
25+
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
26+
return std::nullopt;
27+
}
28+
return dims;
29+
}
1530

1631
static void populateDialectLinalgSubmodule(nb::module_ m) {
1732
m.def(
@@ -20,6 +35,30 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
2035
nb::arg("op"),
2136
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
2237
"op.");
38+
39+
m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
40+
"Checks if the given operation is a Linalg contraction operation.",
41+
nb::arg("op"));
42+
43+
nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
44+
.def_prop_ro("batch",
45+
[](const MlirLinalgContractionDimensions &self) {
46+
return self.batch;
47+
})
48+
.def_prop_ro(
49+
"m",
50+
[](const MlirLinalgContractionDimensions &self) { return self.m; })
51+
.def_prop_ro(
52+
"n",
53+
[](const MlirLinalgContractionDimensions &self) { return self.n; })
54+
.def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
55+
return self.k;
56+
});
57+
58+
m.def("infer_contraction_dimensions", &InferContractionDimensions,
59+
"Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
60+
"op.",
61+
nb::arg("op"));
2362
}
2463

2564
NB_MODULE(_mlirDialectsLinalg, m) {

0 commit comments

Comments
 (0)