Skip to content

Conversation

@Hardcode84
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Mar 16, 2025

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/131521.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/AffineExpr.h (+12)
  • (modified) mlir/lib/Bindings/Python/IRAffine.cpp (+19)
  • (modified) mlir/lib/CAPI/IR/AffineExpr.cpp (+12)
  • (modified) mlir/test/python/ir/affine_expr.py (+11)
diff --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h
index 14e951ddee9ad..ab768eb2ec870 100644
--- a/mlir/include/mlir-c/AffineExpr.h
+++ b/mlir/include/mlir-c/AffineExpr.h
@@ -92,6 +92,18 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr,
 MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineExprCompose(
     MlirAffineExpr affineExpr, struct MlirAffineMap affineMap);
 
+/// Replace dims[offset ... numDims)
+/// by dims[offset + shift ... shift + numDims).
+MLIR_CAPI_EXPORTED MlirAffineExpr
+mlirAffineExprShiftDims(MlirAffineExpr affineExpr, uint32_t numDims,
+                        uint32_t shift, uint32_t offset);
+
+/// Replace symbols[offset ... numSymbols)
+/// by symbols[offset + shift ... shift + numSymbols).
+MLIR_CAPI_EXPORTED MlirAffineExpr
+mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, uint32_t numSymbols,
+                           uint32_t shift, uint32_t offset);
+
 //===----------------------------------------------------------------------===//
 // Affine Dimension Expression.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index a2df824f59a53..3c95d29c4bcca 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -580,6 +580,25 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
              return PyAffineExpr(self.getContext(),
                                  mlirAffineExprCompose(self, other));
            })
+      .def(
+          "shift_dims",
+          [](PyAffineExpr &self, uint32_t numDims, uint32_t shift,
+             uint32_t offset) {
+            return PyAffineExpr(
+                self.getContext(),
+                mlirAffineExprShiftDims(self, numDims, shift, offset));
+          },
+          nb::arg("num_dims"), nb::arg("shift"), nb::arg("offset").none() = 0)
+      .def(
+          "shift_symbols",
+          [](PyAffineExpr &self, uint32_t numSymbols, uint32_t shift,
+             uint32_t offset) {
+            return PyAffineExpr(
+                self.getContext(),
+                mlirAffineExprShiftSymbols(self, numSymbols, shift, offset));
+          },
+          nb::arg("num_symbols"), nb::arg("shift"),
+          nb::arg("offset").none() = 0)
       .def_static(
           "get_add", &PyAffineAddExpr::get,
           "Gets an affine expression containing a sum of two expressions.")
diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp
index 6e3328b65cb08..bc3dcd4174736 100644
--- a/mlir/lib/CAPI/IR/AffineExpr.cpp
+++ b/mlir/lib/CAPI/IR/AffineExpr.cpp
@@ -61,6 +61,18 @@ MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr,
   return wrap(unwrap(affineExpr).compose(unwrap(affineMap)));
 }
 
+MlirAffineExpr mlirAffineExprShiftDims(MlirAffineExpr affineExpr,
+                                       uint32_t numDims, uint32_t shift,
+                                       uint32_t offset) {
+  return wrap(unwrap(affineExpr).shiftDims(numDims, shift, offset));
+}
+
+MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr,
+                                          uint32_t numSymbols, uint32_t shift,
+                                          uint32_t offset) {
+  return wrap(unwrap(affineExpr).shiftSymbols(numSymbols, shift, offset));
+}
+
 //===----------------------------------------------------------------------===//
 // Affine Dimension Expression.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py
index c7861c1acfe12..2f64aff143420 100644
--- a/mlir/test/python/ir/affine_expr.py
+++ b/mlir/test/python/ir/affine_expr.py
@@ -405,3 +405,14 @@ def testHash():
         dictionary[s1] = 1
         assert d0 in dictionary
         assert s1 in dictionary
+
+
+# CHECK-LABEL: TEST: testAffineExprShift
+@run
+def testAffineExprShift():
+    with Context() as ctx:
+        dims = [AffineExpr.get_dim(i) for i in range(4)]
+        syms = [AffineExpr.get_symbol(i) for i in range(4)]
+
+        assert (dims[2] + dims[3]) == (dims[0] + dims[1]).shift_dims(2, 2)
+        assert (syms[2] + syms[3]) == (syms[0] + syms[1]).shift_symbols(2, 2, 0)

@Hardcode84 Hardcode84 merged commit 7c98cdd into llvm:main Mar 16, 2025
13 checks passed
@Hardcode84 Hardcode84 deleted the pyaffine-shift branch March 16, 2025 16:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants