diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 4f2ee0d434222..339e63d667c5e 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -50,6 +50,9 @@ typedef struct MlirLinalgConvolutionDimensions { MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions mlirLinalgInferConvolutionDimensions(MlirOperation op); +MLIR_CAPI_EXPORTED MlirAttribute +mlirLinalgGetIndexingMapsAttribute(MlirOperation op); + MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); #ifdef __cplusplus diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index ce1102a3b3498..015502371c65b 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -120,6 +120,16 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { m.def("infer_convolution_dimensions", &InferConvolutionDimensions, "Infers convolution dimensions", nb::arg("op")); + + m.def( + "get_indexing_maps", + [](MlirOperation op) -> std::optional { + MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op); + if (mlirAttributeIsNull(attr)) + return std::nullopt; + return attr; + }, + "Returns the indexing_maps attribute for a linalg op."); } NB_MODULE(_mlirDialectsLinalg, m) { diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 7c456102a2c0c..0c4f6e88e7078 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -120,4 +120,14 @@ mlirLinalgInferConvolutionDimensions(MlirOperation op) { return result; } +MLIR_CAPI_EXPORTED MlirAttribute +mlirLinalgGetIndexingMapsAttribute(MlirOperation op) { + auto linalgOp = llvm::dyn_cast(unwrap(op)); + if (!linalgOp) + return MlirAttribute{nullptr}; + + ArrayAttr attr = linalgOp.getIndexingMaps(); + return wrap(attr); +} + MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py index 98157b0e443cf..5f7cb6a6c83cb 100644 --- a/mlir/test/python/dialects/linalg/utils.py +++ b/mlir/test/python/dialects/linalg/utils.py @@ -159,3 +159,52 @@ def dyn_conv_fn(input, filter, output): assert list(dims.depth) == [] assert list(dims.strides) == [1, 1] assert list(dims.dilations) == [1, 1] + + +@run +def test_get_indexing_maps_attr(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + a_type = RankedTensorType.get((4, 8), f32) + b_type = RankedTensorType.get((8, 16), f32) + c_type = RankedTensorType.get((4, 16), f32) + + dim_m = AffineDimExpr.get(0) + dim_n = AffineDimExpr.get(1) + dim_k = AffineDimExpr.get(2) + + a_map = AffineMap.get(3, 0, [dim_m, dim_k]) + b_map = AffineMap.get(3, 0, [dim_k, dim_n]) + c_map = AffineMap.get(3, 0, [dim_m, dim_n]) + + @func.FuncOp.from_py_func(a_type, b_type, c_type) + def matmul_func(a, b, c): + zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32) + assert not linalg.get_indexing_maps( + zero.operation + ), "Expected no indexing_maps on non-linalg op" + + init = linalg.fill(zero, outs=[c]) + fill_op = init.owner + fill_maps = linalg.get_indexing_maps(fill_op) + assert fill_maps is not None + assert len(fill_maps) == 2 + + # The fill op should have maps like (d0, d1) -> () and (d0, d1). + fill_input_map = fill_maps[0].value + fill_output_map = fill_maps[1].value + assert fill_input_map == AffineMap.get(2, 0, []) + assert fill_output_map == AffineMap.get(2, 0, [dim_m, dim_n]) + + result = linalg.matmul(a, b, outs=(init,)) + matmul_op = result.owner + matmul_maps = linalg.get_indexing_maps(matmul_op) + assert matmul_maps is not None + assert len(matmul_maps) == 3 + + maps = [map_attr.value for map_attr in matmul_maps] + assert maps[0] == a_map + assert maps[1] == b_map + assert maps[2] == c_map