Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir-c/Dialect/Linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ typedef struct MlirLinalgConvolutionDimensions {
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
mlirLinalgInferConvolutionDimensions(MlirOperation op);

MLIR_CAPI_EXPORTED MlirAttribute
mlirLinalgGetIndexingMapsAttribute(MlirOperation op);

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);

#ifdef __cplusplus
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/Bindings/Python/DialectLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {

m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
"Infers convolution dimensions", nb::arg("op"));

m.def(
"get_indexing_maps_attr",
[](MlirOperation op) -> std::optional<MlirAttribute> {
MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op);
if (mlirAttributeIsNull(attr))
return std::nullopt;
return attr;
},
"Returns the indexing_maps attribute for a linalg op.");
}

NB_MODULE(_mlirDialectsLinalg, m) {
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/CAPI/Dialect/Linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,14 @@ mlirLinalgInferConvolutionDimensions(MlirOperation op) {
return result;
}

MLIR_CAPI_EXPORTED MlirAttribute
mlirLinalgGetIndexingMapsAttribute(MlirOperation op) {
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
if (!linalgOp)
return MlirAttribute{nullptr};

ArrayAttr attr = linalgOp.getIndexingMaps();
return wrap(attr);
}

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
49 changes: 49 additions & 0 deletions mlir/test/python/dialects/linalg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,52 @@ def dyn_conv_fn(input, filter, output):
assert list(dims.depth) == []
assert list(dims.strides) == [1, 1]
assert list(dims.dilations) == [1, 1]


@run
def test_get_indexing_maps_attr():
with Context(), Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
a_type = RankedTensorType.get((4, 8), f32)
b_type = RankedTensorType.get((8, 16), f32)
c_type = RankedTensorType.get((4, 16), f32)

dim_m = AffineDimExpr.get(0)
dim_n = AffineDimExpr.get(1)
dim_k = AffineDimExpr.get(2)

a_map = AffineMap.get(3, 0, [dim_m, dim_k])
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
c_map = AffineMap.get(3, 0, [dim_m, dim_n])

@func.FuncOp.from_py_func(a_type, b_type, c_type)
def matmul_func(a, b, c):
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
assert not linalg.get_indexing_maps_attr(
zero.operation
), "Expected no indexing_maps on non-linalg op"

init = linalg.fill(zero, outs=[c])
fill_op = init.owner
fill_maps = linalg.get_indexing_maps_attr(fill_op)
assert fill_maps is not None
assert len(fill_maps) == 2

# The fill op should have maps like (d0, d1) -> () and (d0, d1).
fill_input_map = fill_maps[0].value
fill_output_map = fill_maps[1].value
assert fill_input_map == AffineMap.get(2, 0, [])
assert fill_output_map == AffineMap.get(2, 0, [dim_m, dim_n])

result = linalg.matmul(a, b, outs=(init,))
matmul_op = result.owner
matmul_maps = linalg.get_indexing_maps_attr(matmul_op)
assert matmul_maps is not None
assert len(matmul_maps) == 3

maps = [map_attr.value for map_attr in matmul_maps]
assert maps[0] == a_map
assert maps[1] == b_map
assert maps[2] == c_map
Loading