diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 0ab201e158033..c57d193e62d25 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -22,6 +22,18 @@ extern "C" { MLIR_CAPI_EXPORTED void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp); +MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op); + +struct MlirLinalgContractionDimensions { + MlirAttribute batch; + MlirAttribute m; + MlirAttribute n; + MlirAttribute k; +}; + +MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions +mlirLinalgInferContractionDimensions(MlirOperation op); + MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); #ifdef __cplusplus diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index 548df4ee100aa..978ea8664b6b9 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -8,10 +8,25 @@ #include "mlir-c/Dialect/Linalg.h" #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; +using namespace mlir::python::nanobind_adaptors; + +static std::optional +InferContractionDimensions(MlirOperation op) { + MlirLinalgContractionDimensions dims = + mlirLinalgInferContractionDimensions(op); + + // Detect "empty" result. This occurs when `op` is not a contraction op, + // or when `linalg::inferContractionDims` fails. + if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) && + mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) { + return std::nullopt; + } + return dims; +} static void populateDialectLinalgSubmodule(nb::module_ m) { m.def( @@ -20,6 +35,30 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { nb::arg("op"), "Fill the region for `op`, which is assumed to be a builtin named Linalg " "op."); + + m.def("isa_contraction_op", &mlirLinalgIsContractionOp, + "Checks if the given operation is a Linalg contraction operation.", + nb::arg("op")); + + nb::class_(m, "ContractionDimensions") + .def_prop_ro("batch", + [](const MlirLinalgContractionDimensions &self) { + return self.batch; + }) + .def_prop_ro( + "m", + [](const MlirLinalgContractionDimensions &self) { return self.m; }) + .def_prop_ro( + "n", + [](const MlirLinalgContractionDimensions &self) { return self.n; }) + .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) { + return self.k; + }); + + m.def("infer_contraction_dimensions", &InferContractionDimensions, + "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction " + "op.", + nb::arg("op")); } NB_MODULE(_mlirDialectsLinalg, m) { diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 2fb5bc651de07..362b89bdef6c9 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -41,4 +41,38 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { fun(b, *body, op->getAttrs()); } +MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) { + auto linalgOp = llvm::dyn_cast(unwrap(op)); + // isaContractionOpInterface handles null linalgOp internally. + return linalg::isaContractionOpInterface(linalgOp); +} + +MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions +mlirLinalgInferContractionDimensions(MlirOperation op) { + MlirLinalgContractionDimensions result{}; + auto linalgOp = dyn_cast(unwrap(op)); + if (!linalgOp) + return result; + + FailureOr maybeDims = + linalg::inferContractionDims(linalgOp); + if (failed(maybeDims)) + return result; + + linalg::ContractionDimensions contractionDims = *maybeDims; + MLIRContext *ctx = linalgOp.getContext(); + + auto toAttr = [&ctx](const SmallVector &vals) -> MlirAttribute { + return wrap( + DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); + }; + + result.batch = toAttr(contractionDims.batch); + result.m = toAttr(contractionDims.m); + result.n = toAttr(contractionDims.n); + result.k = toAttr(contractionDims.k); + + return result; +} + MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py new file mode 100644 index 0000000000000..a48aa90fa5836 --- /dev/null +++ b/mlir/test/python/dialects/linalg/utils.py @@ -0,0 +1,97 @@ +# RUN: %PYTHON %s + +from mlir.dialects import arith, func, linalg +from mlir.dialects.linalg.opdsl.lang import * +from mlir.ir import * + + +def run(f): + print("\nTEST:", f.__name__) + f() + return f + + +@run +def test_infer_contraction_dimensions_from_ops(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + # === Static shapes === + m, n, k = 4, 4, 4 + a_type = RankedTensorType.get((m, k), f32) + b_type = RankedTensorType.get((k, n), f32) + c_type = RankedTensorType.get((m, n), f32) + + @func.FuncOp.from_py_func(a_type, b_type, c_type) + def contraction_fn(a, b, c): + zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32) + filled = linalg.fill(zero, outs=[c]) + fill_op = filled.owner + + assert not linalg.isa_contraction_op(zero.operation) + assert not linalg.isa_contraction_op(fill_op) + assert linalg.infer_contraction_dimensions(fill_op) is None + + dim_m = AffineDimExpr.get(0) + dim_n = AffineDimExpr.get(1) + dim_k = AffineDimExpr.get(2) + + a_map = AffineMap.get(3, 0, [dim_m, dim_k]) + b_map = AffineMap.get(3, 0, [dim_k, dim_n]) + c_map = AffineMap.get(3, 0, [dim_m, dim_n]) + result = linalg.contract( + a, + b, + outs=(filled,), + indexing_maps=[a_map, b_map, c_map], + ) + contraction_op = result.owner + + assert linalg.isa_contraction_op(contraction_op) + dims = linalg.infer_contraction_dimensions(contraction_op) + assert dims is not None + + # Expect m=[0], n=[1], k=[2] as per standard matmul + assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" + assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" + assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" + assert ( + list(dims.batch) == [] + ), f"Expected batch=[], got {list(dims.batch)}" + + # === Dynamic shape case === + dyn = ShapedType.get_dynamic_size() + a_dyn_type = RankedTensorType.get((4, dyn), f32) + b_dyn_type = RankedTensorType.get((dyn, 4), f32) + c_type = RankedTensorType.get((4, 4), f32) + + @func.FuncOp.from_py_func(a_dyn_type, b_dyn_type, c_type) + def dynamic_contraction_fn(a, b, c): + zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32) + filled = linalg.fill(zero, outs=[c]) + dim_m = AffineDimExpr.get(0) + dim_n = AffineDimExpr.get(1) + dim_k = AffineDimExpr.get(2) + + a_map = AffineMap.get(3, 0, [dim_m, dim_k]) + b_map = AffineMap.get(3, 0, [dim_k, dim_n]) + c_map = AffineMap.get(3, 0, [dim_m, dim_n]) + + result = linalg.contract( + a, + b, + outs=(filled,), + indexing_maps=[a_map, b_map, c_map], + ) + contraction_op = result.owner + + assert linalg.isa_contraction_op(contraction_op) + dims = linalg.infer_contraction_dimensions(contraction_op) + assert dims is not None + assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" + assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" + assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" + assert ( + list(dims.batch) == [] + ), f"Expected batch=[], got {list(dims.batch)}"