Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 339b852

Browse files
authored
[mlir][CAPI][python] expose the python bindings for linalg::isaConvolutionOpInterface and linalg::inferConvolutionDims (#135253)
This PR is mainly about exposing the python bindings for `linalg::isaConvolutionOpInterface` and `linalg::inferConvolutionDims`. --------- Signed-off-by: Bangtian Liu <[email protected]>
1 parent 94b1ca0 commit 339b852

File tree

3 files changed

+125
-3
lines changed

3 files changed

+125
-3
lines changed

mlir/include/mlir-c/Dialect/Linalg.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ extern "C" {
2222
MLIR_CAPI_EXPORTED void
2323
mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);
2424

25-
MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
25+
MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op);
2626

2727
struct MlirLinalgContractionDimensions {
2828
MlirAttribute batch;
@@ -34,6 +34,22 @@ struct MlirLinalgContractionDimensions {
3434
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
3535
mlirLinalgInferContractionDimensions(MlirOperation op);
3636

37+
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op);
38+
39+
struct MlirLinalgConvolutionDimensions {
40+
MlirAttribute batch;
41+
MlirAttribute outputImage;
42+
MlirAttribute outputChannel;
43+
MlirAttribute filterLoop;
44+
MlirAttribute inputChannel;
45+
MlirAttribute depth;
46+
MlirAttribute strides;
47+
MlirAttribute dilations;
48+
};
49+
50+
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
51+
mlirLinalgInferConvolutionDimensions(MlirOperation op);
52+
3753
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
3854

3955
#ifdef __cplusplus

mlir/lib/Bindings/Python/DialectLinalg.cpp

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,26 @@ InferContractionDimensions(MlirOperation op) {
2828
return dims;
2929
}
3030

31+
static std::optional<MlirLinalgConvolutionDimensions>
32+
InferConvolutionDimensions(MlirOperation op) {
33+
MlirLinalgConvolutionDimensions dims =
34+
mlirLinalgInferConvolutionDimensions(op);
35+
36+
// Detect "empty" result. This occurs when `op` is not a convolution op,
37+
// or when `linalg::inferConvolutionDims` fails.
38+
if (mlirAttributeIsNull(dims.batch) &&
39+
mlirAttributeIsNull(dims.outputImage) &&
40+
mlirAttributeIsNull(dims.outputChannel) &&
41+
mlirAttributeIsNull(dims.filterLoop) &&
42+
mlirAttributeIsNull(dims.inputChannel) &&
43+
mlirAttributeIsNull(dims.depth) && mlirAttributeIsNull(dims.strides) &&
44+
mlirAttributeIsNull(dims.dilations)) {
45+
return std::nullopt;
46+
}
47+
48+
return dims;
49+
}
50+
3151
static void populateDialectLinalgSubmodule(nb::module_ m) {
3252
m.def(
3353
"fill_builtin_region",
@@ -36,7 +56,7 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
3656
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
3757
"op.");
3858

39-
m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
59+
m.def("isa_contraction_op", &mlirLinalgIsAContractionOp,
4060
"Checks if the given operation is a Linalg contraction operation.",
4161
nb::arg("op"));
4262

@@ -59,6 +79,47 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
5979
"Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
6080
"op.",
6181
nb::arg("op"));
82+
83+
m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
84+
"Checks if the given operation is a Linalg convolution operation.",
85+
nb::arg("op"));
86+
87+
nb::class_<MlirLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
88+
.def_prop_ro("batch",
89+
[](const MlirLinalgConvolutionDimensions &self) {
90+
return self.batch;
91+
})
92+
.def_prop_ro("output_image",
93+
[](const MlirLinalgConvolutionDimensions &self) {
94+
return self.outputImage;
95+
})
96+
.def_prop_ro("output_channel",
97+
[](const MlirLinalgConvolutionDimensions &self) {
98+
return self.outputChannel;
99+
})
100+
.def_prop_ro("filter_loop",
101+
[](const MlirLinalgConvolutionDimensions &self) {
102+
return self.filterLoop;
103+
})
104+
.def_prop_ro("input_channel",
105+
[](const MlirLinalgConvolutionDimensions &self) {
106+
return self.inputChannel;
107+
})
108+
.def_prop_ro("depth",
109+
[](const MlirLinalgConvolutionDimensions &self) {
110+
return self.depth;
111+
})
112+
.def_prop_ro("strides",
113+
[](const MlirLinalgConvolutionDimensions &self) {
114+
return self.strides;
115+
})
116+
.def_prop_ro("dilations",
117+
[](const MlirLinalgConvolutionDimensions &self) {
118+
return self.dilations;
119+
});
120+
121+
m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
122+
"Infers convolution dimensions", nb::arg("op"));
62123
}
63124

64125
NB_MODULE(_mlirDialectsLinalg, m) {

mlir/lib/CAPI/Dialect/Linalg.cpp

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
4141
fun(b, *body, op->getAttrs());
4242
}
4343

44-
MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
44+
MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) {
4545
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
4646
// isaContractionOpInterface handles null linalgOp internally.
4747
return linalg::isaContractionOpInterface(linalgOp);
@@ -75,4 +75,49 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
7575
return result;
7676
}
7777

78+
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) {
79+
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
80+
if (!linalgOp)
81+
return false;
82+
83+
return linalg::isaConvolutionOpInterface(linalgOp);
84+
}
85+
86+
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
87+
mlirLinalgInferConvolutionDimensions(MlirOperation op) {
88+
MlirLinalgConvolutionDimensions result{};
89+
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
90+
if (!linalgOp)
91+
return result;
92+
93+
FailureOr<linalg::ConvolutionDimensions> maybeDims =
94+
linalg::inferConvolutionDims(linalgOp);
95+
if (failed(maybeDims))
96+
return result;
97+
98+
linalg::ConvolutionDimensions dims = *maybeDims;
99+
MLIRContext *ctx = linalgOp.getContext();
100+
101+
auto toI32Attr =
102+
[&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
103+
return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
104+
};
105+
106+
auto toI64Attr =
107+
[&ctx](const SmallVector<int64_t, 2> &vals) -> MlirAttribute {
108+
return wrap(DenseI64ArrayAttr::get(ctx, vals));
109+
};
110+
111+
result.batch = toI32Attr(dims.batch);
112+
result.outputImage = toI32Attr(dims.outputImage);
113+
result.outputChannel = toI32Attr(dims.outputChannel);
114+
result.filterLoop = toI32Attr(dims.filterLoop);
115+
result.inputChannel = toI32Attr(dims.inputChannel);
116+
result.depth = toI32Attr(dims.depth);
117+
result.strides = toI64Attr(dims.strides);
118+
result.dilations = toI64Attr(dims.dilations);
119+
120+
return result;
121+
}
122+
78123
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)

0 commit comments

Comments
 (0)