Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit c53c3ce

Browse files
authored
[mlir][CAPI][python] expose the python bindings for linalg::isaConvolutionOpInterface and linalg::inferConvolutionDims (#135253)
This PR is mainly about exposing the python bindings for `linalg::isaConvolutionOpInterface` and `linalg::inferConvolutionDims`. --------- Signed-off-by: Bangtian Liu <[email protected]>
1 parent 5af8ed1 commit c53c3ce

File tree

1 file changed

+62
-1
lines changed

1 file changed

+62
-1
lines changed

mlir/lib/Bindings/Python/DialectLinalg.cpp

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,26 @@ InferContractionDimensions(MlirOperation op) {
2828
return dims;
2929
}
3030

31+
static std::optional<MlirLinalgConvolutionDimensions>
32+
InferConvolutionDimensions(MlirOperation op) {
33+
MlirLinalgConvolutionDimensions dims =
34+
mlirLinalgInferConvolutionDimensions(op);
35+
36+
// Detect "empty" result. This occurs when `op` is not a convolution op,
37+
// or when `linalg::inferConvolutionDims` fails.
38+
if (mlirAttributeIsNull(dims.batch) &&
39+
mlirAttributeIsNull(dims.outputImage) &&
40+
mlirAttributeIsNull(dims.outputChannel) &&
41+
mlirAttributeIsNull(dims.filterLoop) &&
42+
mlirAttributeIsNull(dims.inputChannel) &&
43+
mlirAttributeIsNull(dims.depth) && mlirAttributeIsNull(dims.strides) &&
44+
mlirAttributeIsNull(dims.dilations)) {
45+
return std::nullopt;
46+
}
47+
48+
return dims;
49+
}
50+
3151
static void populateDialectLinalgSubmodule(nb::module_ m) {
3252
m.def(
3353
"fill_builtin_region",
@@ -36,7 +56,7 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
3656
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
3757
"op.");
3858

39-
m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
59+
m.def("isa_contraction_op", &mlirLinalgIsAContractionOp,
4060
"Checks if the given operation is a Linalg contraction operation.",
4161
nb::arg("op"));
4262

@@ -59,6 +79,47 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
5979
"Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
6080
"op.",
6181
nb::arg("op"));
82+
83+
m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
84+
"Checks if the given operation is a Linalg convolution operation.",
85+
nb::arg("op"));
86+
87+
nb::class_<MlirLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
88+
.def_prop_ro("batch",
89+
[](const MlirLinalgConvolutionDimensions &self) {
90+
return self.batch;
91+
})
92+
.def_prop_ro("output_image",
93+
[](const MlirLinalgConvolutionDimensions &self) {
94+
return self.outputImage;
95+
})
96+
.def_prop_ro("output_channel",
97+
[](const MlirLinalgConvolutionDimensions &self) {
98+
return self.outputChannel;
99+
})
100+
.def_prop_ro("filter_loop",
101+
[](const MlirLinalgConvolutionDimensions &self) {
102+
return self.filterLoop;
103+
})
104+
.def_prop_ro("input_channel",
105+
[](const MlirLinalgConvolutionDimensions &self) {
106+
return self.inputChannel;
107+
})
108+
.def_prop_ro("depth",
109+
[](const MlirLinalgConvolutionDimensions &self) {
110+
return self.depth;
111+
})
112+
.def_prop_ro("strides",
113+
[](const MlirLinalgConvolutionDimensions &self) {
114+
return self.strides;
115+
})
116+
.def_prop_ro("dilations",
117+
[](const MlirLinalgConvolutionDimensions &self) {
118+
return self.dilations;
119+
});
120+
121+
m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
122+
"Infers convolution dimensions", nb::arg("op"));
62123
}
63124

64125
NB_MODULE(_mlirDialectsLinalg, m) {

0 commit comments

Comments
 (0)