Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 94b1ca0

Browse files
authored
[mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims (#134935)
This PR is mainly about exposing the python bindings for` linalg::isaContractionOpInterface` and` linalg::inferContractionDims`. --------- Signed-off-by: Bangtian Liu <[email protected]>
1 parent c224df4 commit 94b1ca0

File tree

3 files changed

+86
-1
lines changed

3 files changed

+86
-1
lines changed

mlir/include/mlir-c/Dialect/Linalg.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@ extern "C" {
2222
MLIR_CAPI_EXPORTED void
2323
mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);
2424

25+
MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
26+
27+
struct MlirLinalgContractionDimensions {
28+
MlirAttribute batch;
29+
MlirAttribute m;
30+
MlirAttribute n;
31+
MlirAttribute k;
32+
};
33+
34+
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
35+
mlirLinalgInferContractionDimensions(MlirOperation op);
36+
2537
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
2638

2739
#ifdef __cplusplus

mlir/lib/Bindings/Python/DialectLinalg.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,25 @@
88

99
#include "mlir-c/Dialect/Linalg.h"
1010
#include "mlir-c/IR.h"
11-
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1211
#include "mlir/Bindings/Python/Nanobind.h"
12+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1313

1414
namespace nb = nanobind;
15+
using namespace mlir::python::nanobind_adaptors;
16+
17+
static std::optional<MlirLinalgContractionDimensions>
18+
InferContractionDimensions(MlirOperation op) {
19+
MlirLinalgContractionDimensions dims =
20+
mlirLinalgInferContractionDimensions(op);
21+
22+
// Detect "empty" result. This occurs when `op` is not a contraction op,
23+
// or when `linalg::inferContractionDims` fails.
24+
if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
25+
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
26+
return std::nullopt;
27+
}
28+
return dims;
29+
}
1530

1631
static void populateDialectLinalgSubmodule(nb::module_ m) {
1732
m.def(
@@ -20,6 +35,30 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
2035
nb::arg("op"),
2136
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
2237
"op.");
38+
39+
m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
40+
"Checks if the given operation is a Linalg contraction operation.",
41+
nb::arg("op"));
42+
43+
nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
44+
.def_prop_ro("batch",
45+
[](const MlirLinalgContractionDimensions &self) {
46+
return self.batch;
47+
})
48+
.def_prop_ro(
49+
"m",
50+
[](const MlirLinalgContractionDimensions &self) { return self.m; })
51+
.def_prop_ro(
52+
"n",
53+
[](const MlirLinalgContractionDimensions &self) { return self.n; })
54+
.def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
55+
return self.k;
56+
});
57+
58+
m.def("infer_contraction_dimensions", &InferContractionDimensions,
59+
"Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
60+
"op.",
61+
nb::arg("op"));
2362
}
2463

2564
NB_MODULE(_mlirDialectsLinalg, m) {

mlir/lib/CAPI/Dialect/Linalg.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,38 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
4141
fun(b, *body, op->getAttrs());
4242
}
4343

44+
MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
45+
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
46+
// isaContractionOpInterface handles null linalgOp internally.
47+
return linalg::isaContractionOpInterface(linalgOp);
48+
}
49+
50+
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
51+
mlirLinalgInferContractionDimensions(MlirOperation op) {
52+
MlirLinalgContractionDimensions result{};
53+
auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
54+
if (!linalgOp)
55+
return result;
56+
57+
FailureOr<linalg::ContractionDimensions> maybeDims =
58+
linalg::inferContractionDims(linalgOp);
59+
if (failed(maybeDims))
60+
return result;
61+
62+
linalg::ContractionDimensions contractionDims = *maybeDims;
63+
MLIRContext *ctx = linalgOp.getContext();
64+
65+
auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
66+
return wrap(
67+
DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
68+
};
69+
70+
result.batch = toAttr(contractionDims.batch);
71+
result.m = toAttr(contractionDims.m);
72+
result.n = toAttr(contractionDims.n);
73+
result.k = toAttr(contractionDims.k);
74+
75+
return result;
76+
}
77+
4478
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)

0 commit comments

Comments
 (0)