@@ -28,6 +28,26 @@ InferContractionDimensions(MlirOperation op) {
2828 return dims;
2929}
3030
31+ static std::optional<MlirLinalgConvolutionDimensions>
32+ InferConvolutionDimensions (MlirOperation op) {
33+ MlirLinalgConvolutionDimensions dims =
34+ mlirLinalgInferConvolutionDimensions (op);
35+
36+ // Detect "empty" result. This occurs when `op` is not a convolution op,
37+ // or when `linalg::inferConvolutionDims` fails.
38+ if (mlirAttributeIsNull (dims.batch ) &&
39+ mlirAttributeIsNull (dims.outputImage ) &&
40+ mlirAttributeIsNull (dims.outputChannel ) &&
41+ mlirAttributeIsNull (dims.filterLoop ) &&
42+ mlirAttributeIsNull (dims.inputChannel ) &&
43+ mlirAttributeIsNull (dims.depth ) && mlirAttributeIsNull (dims.strides ) &&
44+ mlirAttributeIsNull (dims.dilations )) {
45+ return std::nullopt ;
46+ }
47+
48+ return dims;
49+ }
50+
3151static void populateDialectLinalgSubmodule (nb::module_ m) {
3252 m.def (
3353 " fill_builtin_region" ,
@@ -36,7 +56,7 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
3656 " Fill the region for `op`, which is assumed to be a builtin named Linalg "
3757 " op." );
3858
39- m.def (" isa_contraction_op" , &mlirLinalgIsContractionOp ,
59+ m.def (" isa_contraction_op" , &mlirLinalgIsAContractionOp ,
4060 " Checks if the given operation is a Linalg contraction operation." ,
4161 nb::arg (" op" ));
4262
@@ -59,6 +79,47 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
5979 " Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
6080 " op." ,
6181 nb::arg (" op" ));
82+
83+ m.def (" isa_convolution_op" , &mlirLinalgIsAConvolutionOp,
84+ " Checks if the given operation is a Linalg convolution operation." ,
85+ nb::arg (" op" ));
86+
87+ nb::class_<MlirLinalgConvolutionDimensions>(m, " ConvolutionDimensions" )
88+ .def_prop_ro (" batch" ,
89+ [](const MlirLinalgConvolutionDimensions &self) {
90+ return self.batch ;
91+ })
92+ .def_prop_ro (" output_image" ,
93+ [](const MlirLinalgConvolutionDimensions &self) {
94+ return self.outputImage ;
95+ })
96+ .def_prop_ro (" output_channel" ,
97+ [](const MlirLinalgConvolutionDimensions &self) {
98+ return self.outputChannel ;
99+ })
100+ .def_prop_ro (" filter_loop" ,
101+ [](const MlirLinalgConvolutionDimensions &self) {
102+ return self.filterLoop ;
103+ })
104+ .def_prop_ro (" input_channel" ,
105+ [](const MlirLinalgConvolutionDimensions &self) {
106+ return self.inputChannel ;
107+ })
108+ .def_prop_ro (" depth" ,
109+ [](const MlirLinalgConvolutionDimensions &self) {
110+ return self.depth ;
111+ })
112+ .def_prop_ro (" strides" ,
113+ [](const MlirLinalgConvolutionDimensions &self) {
114+ return self.strides ;
115+ })
116+ .def_prop_ro (" dilations" ,
117+ [](const MlirLinalgConvolutionDimensions &self) {
118+ return self.dilations ;
119+ });
120+
121+ m.def (" infer_convolution_dimensions" , &InferConvolutionDimensions,
122+ " Infers convolution dimensions" , nb::arg (" op" ));
62123}
63124
64125NB_MODULE (_mlirDialectsLinalg, m) {
0 commit comments