Skip to content

Commit a5a78d0

Browse files
authored
[mlir][linalg][python] Add Python Bindings for Inferring Contraction Dimensions from Affine Maps (#167587)
This PR exposes `linalg::inferContractionDims(ArrayRef<AffineMap>)` to Python, allowing users to infer contraction dimensions (batch/m/n/k) directly from a list of affine maps without needing an operation. --------- Signed-off-by: Bangtian Liu <[email protected]>
1 parent a22834a commit a5a78d0

File tree

4 files changed

+102
-3
lines changed

4 files changed

+102
-3
lines changed

mlir/include/mlir-c/Dialect/Linalg.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#ifndef MLIR_C_DIALECT_LINALG_H
1111
#define MLIR_C_DIALECT_LINALG_H
1212

13+
#include "mlir-c/AffineMap.h"
1314
#include "mlir-c/IR.h"
1415
#include "mlir-c/Support.h"
1516

@@ -34,6 +35,10 @@ typedef struct MlirLinalgContractionDimensions {
3435
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
3536
mlirLinalgInferContractionDimensions(MlirOperation op);
3637

38+
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
39+
mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps,
40+
size_t numMaps);
41+
3742
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op);
3843

3944
typedef struct MlirLinalgConvolutionDimensions {

mlir/lib/Bindings/Python/DialectLinalg.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,28 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
8080
"op.",
8181
nb::arg("op"));
8282

83+
m.def(
84+
"infer_contraction_dimensions_from_maps",
85+
[](std::vector<MlirAffineMap> indexingMaps)
86+
-> std::optional<MlirLinalgContractionDimensions> {
87+
if (indexingMaps.empty())
88+
return std::nullopt;
89+
90+
MlirLinalgContractionDimensions dims =
91+
mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(),
92+
indexingMaps.size());
93+
94+
// Detect "empty" result from invalid input or failed inference.
95+
if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
96+
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
97+
return std::nullopt;
98+
}
99+
return dims;
100+
},
101+
"Infers contraction dimensions (batch/m/n/k) from a list of affine "
102+
"maps.",
103+
nb::arg("indexing_maps"));
104+
83105
m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
84106
"Checks if the given operation is a Linalg convolution operation.",
85107
nb::arg("op"));

mlir/lib/CAPI/Dialect/Linalg.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir-c/Dialect/Linalg.h"
10+
#include "mlir/CAPI/AffineMap.h"
1011
#include "mlir/CAPI/Registration.h"
1112
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1213

@@ -62,9 +63,8 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
6263
const linalg::ContractionDimensions &contractionDims = *maybeDims;
6364
MLIRContext *ctx = linalgOp.getContext();
6465

65-
auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
66-
return wrap(
67-
DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
66+
auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute {
67+
return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
6868
};
6969

7070
result.batch = toAttr(contractionDims.batch);
@@ -75,6 +75,38 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
7575
return result;
7676
}
7777

78+
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
79+
mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps,
80+
size_t numMaps) {
81+
MlirLinalgContractionDimensions result{};
82+
if (!indexingMaps || numMaps == 0)
83+
return result;
84+
85+
SmallVector<AffineMap, 3> maps;
86+
maps.reserve(numMaps);
87+
for (size_t i = 0; i < numMaps; ++i) {
88+
maps.push_back(unwrap(indexingMaps[i]));
89+
}
90+
91+
FailureOr<linalg::ContractionDimensions> maybeDims =
92+
linalg::inferContractionDims(maps);
93+
if (failed(maybeDims))
94+
return result;
95+
96+
MLIRContext *ctx = maps[0].getContext();
97+
98+
auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute {
99+
return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
100+
};
101+
102+
result.batch = toAttr(maybeDims->batch);
103+
result.m = toAttr(maybeDims->m);
104+
result.n = toAttr(maybeDims->n);
105+
result.k = toAttr(maybeDims->k);
106+
107+
return result;
108+
}
109+
78110
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) {
79111
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
80112
if (!linalgOp)

mlir/test/python/dialects/linalg/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,43 @@ def matmul_func(a, b, c):
208208
assert maps[0] == a_map
209209
assert maps[1] == b_map
210210
assert maps[2] == c_map
211+
212+
213+
@run
214+
def test_infer_contraction_dimensions_from_maps():
215+
with Context(), Location.unknown():
216+
module = Module.create()
217+
with InsertionPoint(module.body):
218+
# === Test valid contraction (matmul) ===
219+
dim_m = AffineDimExpr.get(0)
220+
dim_n = AffineDimExpr.get(1)
221+
dim_k = AffineDimExpr.get(2)
222+
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
223+
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
224+
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
225+
226+
dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map, c_map])
227+
assert dims is not None
228+
229+
# Expect m=[0], n=[1], k=[2] as per standard matmul.
230+
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
231+
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
232+
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
233+
assert list(dims.batch) == [], f"Expected batch=[], got {list(dims.batch)}"
234+
235+
# === Test invalid input (wrong number of maps) ===
236+
invalid_dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map])
237+
assert invalid_dims is None
238+
239+
# === Test element-wise operation ===
240+
dim_i = AffineDimExpr.get(0)
241+
dim_j = AffineDimExpr.get(1)
242+
elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j])
243+
elementwise_dims = linalg.infer_contraction_dimensions_from_maps(
244+
[elementwise_map, elementwise_map, elementwise_map]
245+
)
246+
assert elementwise_dims is not None
247+
assert len(elementwise_dims.m) == 0
248+
assert len(elementwise_dims.n) == 0
249+
assert len(elementwise_dims.k) == 0
250+
assert list(elementwise_dims.batch) == [0, 1]

0 commit comments

Comments
 (0)