diff --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h index 14e951ddee9ad..ab768eb2ec870 100644 --- a/mlir/include/mlir-c/AffineExpr.h +++ b/mlir/include/mlir-c/AffineExpr.h @@ -92,6 +92,18 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr, MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineExprCompose( MlirAffineExpr affineExpr, struct MlirAffineMap affineMap); +/// Replace dims[offset ... numDims) +/// by dims[offset + shift ... shift + numDims). +MLIR_CAPI_EXPORTED MlirAffineExpr +mlirAffineExprShiftDims(MlirAffineExpr affineExpr, uint32_t numDims, + uint32_t shift, uint32_t offset); + +/// Replace symbols[offset ... numSymbols) +/// by symbols[offset + shift ... shift + numSymbols). +MLIR_CAPI_EXPORTED MlirAffineExpr +mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, uint32_t numSymbols, + uint32_t shift, uint32_t offset); + //===----------------------------------------------------------------------===// // Affine Dimension Expression. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index a2df824f59a53..3c95d29c4bcca 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -580,6 +580,25 @@ void mlir::python::populateIRAffine(nb::module_ &m) { return PyAffineExpr(self.getContext(), mlirAffineExprCompose(self, other)); }) + .def( + "shift_dims", + [](PyAffineExpr &self, uint32_t numDims, uint32_t shift, + uint32_t offset) { + return PyAffineExpr( + self.getContext(), + mlirAffineExprShiftDims(self, numDims, shift, offset)); + }, + nb::arg("num_dims"), nb::arg("shift"), nb::arg("offset").none() = 0) + .def( + "shift_symbols", + [](PyAffineExpr &self, uint32_t numSymbols, uint32_t shift, + uint32_t offset) { + return PyAffineExpr( + self.getContext(), + mlirAffineExprShiftSymbols(self, numSymbols, shift, offset)); + }, + nb::arg("num_symbols"), nb::arg("shift"), + nb::arg("offset").none() = 0) .def_static( "get_add", &PyAffineAddExpr::get, "Gets an affine expression containing a sum of two expressions.") diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp index 6e3328b65cb08..bc3dcd4174736 100644 --- a/mlir/lib/CAPI/IR/AffineExpr.cpp +++ b/mlir/lib/CAPI/IR/AffineExpr.cpp @@ -61,6 +61,18 @@ MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr, return wrap(unwrap(affineExpr).compose(unwrap(affineMap))); } +MlirAffineExpr mlirAffineExprShiftDims(MlirAffineExpr affineExpr, + uint32_t numDims, uint32_t shift, + uint32_t offset) { + return wrap(unwrap(affineExpr).shiftDims(numDims, shift, offset)); +} + +MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, + uint32_t numSymbols, uint32_t shift, + uint32_t offset) { + return wrap(unwrap(affineExpr).shiftSymbols(numSymbols, shift, offset)); +} + //===----------------------------------------------------------------------===// // Affine Dimension Expression. //===----------------------------------------------------------------------===// diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py index c7861c1acfe12..2f64aff143420 100644 --- a/mlir/test/python/ir/affine_expr.py +++ b/mlir/test/python/ir/affine_expr.py @@ -405,3 +405,14 @@ def testHash(): dictionary[s1] = 1 assert d0 in dictionary assert s1 in dictionary + + +# CHECK-LABEL: TEST: testAffineExprShift +@run +def testAffineExprShift(): + with Context() as ctx: + dims = [AffineExpr.get_dim(i) for i in range(4)] + syms = [AffineExpr.get_symbol(i) for i in range(4)] + + assert (dims[2] + dims[3]) == (dims[0] + dims[1]).shift_dims(2, 2) + assert (syms[2] + syms[3]) == (syms[0] + syms[1]).shift_symbols(2, 2, 0)