Skip to content

Commit 9984adf

Browse files
committed
[mlir][linalg][python] Add Python bindings for inferring contraction dimensions from affine maps
Signed-off-by: Bangtian Liu <[email protected]>
1 parent 47a3ea4 commit 9984adf

File tree

4 files changed

+104
-0
lines changed

4 files changed

+104
-0
lines changed

mlir/include/mlir-c/Dialect/Linalg.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#ifndef MLIR_C_DIALECT_LINALG_H
1111
#define MLIR_C_DIALECT_LINALG_H
1212

13+
#include "mlir-c/AffineMap.h"
1314
#include "mlir-c/IR.h"
1415
#include "mlir-c/Support.h"
1516

@@ -34,6 +35,10 @@ typedef struct MlirLinalgContractionDimensions {
3435
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
3536
mlirLinalgInferContractionDimensions(MlirOperation op);
3637

38+
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
39+
mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps,
40+
intptr_t numMaps);
41+
3742
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op);
3843

3944
typedef struct MlirLinalgConvolutionDimensions {

mlir/lib/Bindings/Python/DialectLinalg.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,29 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
8080
"op.",
8181
nb::arg("op"));
8282

83+
m.def(
84+
"infer_contraction_dimensions_from_maps",
85+
[](std::vector<MlirAffineMap> indexingMaps)
86+
-> std::optional<MlirLinalgContractionDimensions> {
87+
if (indexingMaps.empty())
88+
return std::nullopt;
89+
90+
MlirLinalgContractionDimensions dims =
91+
mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(),
92+
indexingMaps.size());
93+
94+
// Detect "empty" result. This occurs when the input is invalid
95+
// or when `linalg::inferContractionDims` fails.
96+
if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
97+
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
98+
return std::nullopt;
99+
}
100+
return dims;
101+
},
102+
"Infers contraction dimensions (batch/m/n/k) from a list of affine "
103+
"maps.",
104+
nb::arg("indexing_maps"));
105+
83106
m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
84107
"Checks if the given operation is a Linalg convolution operation.",
85108
nb::arg("op"));

mlir/lib/CAPI/Dialect/Linalg.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir-c/Dialect/Linalg.h"
10+
#include "mlir/CAPI/AffineMap.h"
1011
#include "mlir/CAPI/Registration.h"
1112
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1213

@@ -75,6 +76,40 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
7576
return result;
7677
}
7778

79+
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
80+
mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps,
81+
intptr_t numMaps) {
82+
MlirLinalgContractionDimensions result{};
83+
if (!indexingMaps || numMaps <= 0)
84+
return result;
85+
86+
SmallVector<AffineMap> maps;
87+
maps.reserve(numMaps);
88+
for (intptr_t i = 0; i < numMaps; ++i) {
89+
maps.push_back(unwrap(indexingMaps[i]));
90+
}
91+
92+
FailureOr<linalg::ContractionDimensions> maybeDims =
93+
linalg::inferContractionDims(maps);
94+
if (failed(maybeDims))
95+
return result;
96+
97+
const linalg::ContractionDimensions &contractionDims = *maybeDims;
98+
MLIRContext *ctx = maps[0].getContext();
99+
100+
auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
101+
return wrap(
102+
DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
103+
};
104+
105+
result.batch = toAttr(contractionDims.batch);
106+
result.m = toAttr(contractionDims.m);
107+
result.n = toAttr(contractionDims.n);
108+
result.k = toAttr(contractionDims.k);
109+
110+
return result;
111+
}
112+
78113
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) {
79114
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
80115
if (!linalgOp)

mlir/test/python/dialects/linalg/utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,44 @@ def matmul_func(a, b, c):
208208
assert maps[0] == a_map
209209
assert maps[1] == b_map
210210
assert maps[2] == c_map
211+
212+
213+
@run
214+
def test_infer_contraction_dimensions_from_maps():
215+
with Context(), Location.unknown():
216+
module = Module.create()
217+
with InsertionPoint(module.body):
218+
# === Test valid contraction (matmul) ===
219+
dim_m = AffineDimExpr.get(0)
220+
dim_n = AffineDimExpr.get(1)
221+
dim_k = AffineDimExpr.get(2)
222+
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
223+
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
224+
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
225+
226+
dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map, c_map])
227+
assert dims is not None
228+
229+
# Expect m=[0], n=[1], k=[2] as per standard matmul.
230+
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
231+
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
232+
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
233+
assert list(dims.batch) == [], f"Expected batch=[], got {list(dims.batch)}"
234+
235+
# === Test invalid input (wrong number of maps) ===
236+
invalid_dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map])
237+
assert invalid_dims is None
238+
239+
# === Test element-wise operation ===
240+
# All dimensions appear in all operands, so they're batch dimensions.
241+
dim_i = AffineDimExpr.get(0)
242+
dim_j = AffineDimExpr.get(1)
243+
elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j])
244+
elementwise_dims = linalg.infer_contraction_dimensions_from_maps(
245+
[elementwise_map, elementwise_map, elementwise_map]
246+
)
247+
assert elementwise_dims is not None
248+
assert list(elementwise_dims.m) == []
249+
assert list(elementwise_dims.n) == []
250+
assert list(elementwise_dims.k) == []
251+
assert list(elementwise_dims.batch) == [0, 1]

0 commit comments

Comments
 (0)