Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mlir/include/mlir-c/Dialect/Linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ extern "C" {
MLIR_CAPI_EXPORTED void
mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);

/// Returns `true` if the given operation implements the Linalg contraction-op
/// interface (e.g. `linalg.matmul`), `false` otherwise.
MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);

/// Dimension groupings of a Linalg contraction op, as inferred from its
/// indexing maps. Each member is a DenseI32ArrayAttr of iteration-space
/// dimension indices, or a null attribute when inference failed.
///
/// NOTE: the typedef is required — this header is consumed from C, where the
/// bare name `MlirLinalgContractionDimensions` (used below as a return type)
/// does not name the type without it.
typedef struct MlirLinalgContractionDimensions {
  MlirAttribute batch;
  MlirAttribute m;
  MlirAttribute n;
  MlirAttribute k;
} MlirLinalgContractionDimensions;

/// Infers the contraction dimensions (batch/m/n/k) of `op`. If `op` is not a
/// Linalg contraction op, every member of the result is a null attribute.
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensions(MlirOperation op);

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);

#ifdef __cplusplus
Expand Down
62 changes: 61 additions & 1 deletion mlir/lib/Bindings/Python/DialectLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,45 @@
//
//===----------------------------------------------------------------------===//

#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"

#include <cstdint>
#include <optional>
#include <vector>

namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;

/// Owning wrapper around the C API result struct so nanobind can expose the
/// four dimension groupings as Python properties.
struct PyContractionDimensions {
  // Value-initialize: without `{}` the defaulted default constructor would
  // leave the four attribute handles indeterminate, and reading them (e.g.
  // via mlirAttributeIsNull) would be undefined behavior.
  MlirLinalgContractionDimensions value{};

  PyContractionDimensions() = default;
  PyContractionDimensions(const MlirLinalgContractionDimensions &v)
      : value(v) {}
};

/// Python-facing wrapper over the C API: yields std::nullopt (None in Python)
/// when `op` is not a Linalg contraction, otherwise the wrapped dimensions.
static std::optional<PyContractionDimensions>
mlirLinalgInferContractionDimensionsBinding(MlirOperation op) {
  const MlirLinalgContractionDimensions dims =
      mlirLinalgInferContractionDimensions(op);

  // The C API signals failure by returning all-null attribute handles.
  const bool inferenceFailed =
      mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
      mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k);
  if (inferenceFailed)
    return std::nullopt;

  return PyContractionDimensions{dims};
}

/// Materializes a DenseI32ArrayAttr handle as a std::vector<int32_t>, which
/// nanobind converts to a plain Python list of ints.
static std::vector<int32_t> convertDenseI32AttrToList(MlirAttribute attr) {
  const int64_t numElements = mlirDenseArrayGetNumElements(attr);
  std::vector<int32_t> elements;
  elements.reserve(numElements);
  for (int64_t idx = 0; idx < numElements; ++idx)
    elements.push_back(mlirDenseI32ArrayGetElement(attr, idx));
  return elements;
}

static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
Expand All @@ -20,6 +53,33 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
nb::arg("op"),
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");

  // Expose the contraction-interface check as a module-level predicate.
  m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
        "Checks if the given operation is a Linalg contraction operation.",
        nb::arg("op"));

  // Read-only result type: each property converts the corresponding
  // DenseI32ArrayAttr handle into a Python list of dimension indices.
  nb::class_<PyContractionDimensions>(m, "ContractionDimensions")
      .def_prop_ro("batch",
                   [](const PyContractionDimensions &self) {
                     return convertDenseI32AttrToList(self.value.batch);
                   })
      .def_prop_ro("m",
                   [](const PyContractionDimensions &self) {
                     return convertDenseI32AttrToList(self.value.m);
                   })
      .def_prop_ro("n",
                   [](const PyContractionDimensions &self) {
                     return convertDenseI32AttrToList(self.value.n);
                   })
      .def_prop_ro("k", [](const PyContractionDimensions &self) {
        return convertDenseI32AttrToList(self.value.k);
      });

  // The binding wrapper maps the C API's all-null failure result to None.
  m.def("infer_contraction_dimensions",
        &mlirLinalgInferContractionDimensionsBinding,
        "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
        "op.",
        nb::arg("op"));
}

NB_MODULE(_mlirDialectsLinalg, m) {
Expand Down
32 changes: 32 additions & 0 deletions mlir/lib/CAPI/Dialect/Linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,36 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
fun(b, *body, op->getAttrs());
}

/// Returns true if `op` implements the Linalg contraction-op interface.
MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
  // Guard the non-Linalg case explicitly rather than passing a null LinalgOp
  // into the interface query; also drop the redundant llvm::/mlir:: qualifiers
  // for consistency with mlirLinalgInferContractionDimensions below.
  auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
  if (!linalgOp)
    return false;
  return linalg::isaContractionOpInterface(linalgOp);
}

/// Infers batch/m/n/k dimension groupings for `op`. On any failure (not a
/// Linalg op, or inference fails) the returned struct keeps its
/// value-initialized null attribute handles.
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensions(MlirOperation op) {
  MlirLinalgContractionDimensions dims{};

  auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
  if (!linalgOp)
    return dims;

  auto inferred = linalg::inferContractionDims(linalgOp);
  if (failed(inferred))
    return dims;

  MLIRContext *ctx = linalgOp.getContext();
  // Widen the unsigned dimension indices to int32_t and wrap them in a
  // DenseI32ArrayAttr handle for the C API.
  auto makeI32Array = [ctx](const SmallVector<unsigned, 2> &indices) {
    SmallVector<int32_t> elements(indices.begin(), indices.end());
    return wrap(DenseI32ArrayAttr::get(ctx, elements));
  };

  dims.batch = makeI32Array(inferred->batch);
  dims.m = makeI32Array(inferred->m);
  dims.n = makeI32Array(inferred->n);
  dims.k = makeI32Array(inferred->k);
  return dims;
}

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
35 changes: 35 additions & 0 deletions mlir/test/python/dialects/linalg/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,3 +606,38 @@ def tensor_pack(src, dst):
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
# CHECK: }
print(module)


@run
def test_infer_contraction_dimensions():
    # Exercises linalg.isa_contraction_op and
    # linalg.infer_contraction_dimensions on a fill op (not a contraction)
    # and a matmul op (a contraction with m/n/k = [0]/[1]/[2]).
    with Context(), Location.unknown():
        module = ir.Module.parse(
            r"""
            module {
              func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>)
                  -> tensor<4x4xf32> {
                %cst = arith.constant 0.0 : f32
                %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
                %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>)
                    outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
                return %1 : tensor<4x4xf32>
              }
            }
            """
        )

        entry_block = module.body.operations[0].regions[0].blocks[0]
        fill_op = entry_block.operations[1]
        matmul_op = entry_block.operations[2]

        # Only the matmul implements the contraction interface.
        assert not linalg.isa_contraction_op(fill_op)
        assert linalg.isa_contraction_op(matmul_op)

        # Inference returns None for the non-contraction op ...
        dims = linalg.infer_contraction_dimensions(fill_op)
        assert dims is None

        # ... and the expected dimension groupings for the matmul.
        dims = linalg.infer_contraction_dimensions(matmul_op)
        assert dims

        assert dims.m == [0], f"Expected m=[0], got {dims.m}"
        assert dims.n == [1], f"Expected n=[1], got {dims.n}"
        assert dims.k == [2], f"Expected k=[2], got {dims.k}"
Loading