Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mlir/include/mlir-c/AffineExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr,
MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineExprCompose(
MlirAffineExpr affineExpr, struct MlirAffineMap affineMap);

/// Replace dims[offset ... numDims)
/// by dims[offset + shift ... shift + numDims).
MLIR_CAPI_EXPORTED MlirAffineExpr
mlirAffineExprShiftDims(MlirAffineExpr affineExpr, uint32_t numDims,
uint32_t shift, uint32_t offset);

/// Replace symbols[offset ... numSymbols)
/// by symbols[offset + shift ... shift + numSymbols).
MLIR_CAPI_EXPORTED MlirAffineExpr
mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, uint32_t numSymbols,
uint32_t shift, uint32_t offset);

//===----------------------------------------------------------------------===//
// Affine Dimension Expression.
//===----------------------------------------------------------------------===//
Expand Down
19 changes: 19 additions & 0 deletions mlir/lib/Bindings/Python/IRAffine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,25 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
return PyAffineExpr(self.getContext(),
mlirAffineExprCompose(self, other));
})
.def(
"shift_dims",
[](PyAffineExpr &self, uint32_t numDims, uint32_t shift,
uint32_t offset) {
return PyAffineExpr(
self.getContext(),
mlirAffineExprShiftDims(self, numDims, shift, offset));
},
nb::arg("num_dims"), nb::arg("shift"), nb::arg("offset").none() = 0)
.def(
"shift_symbols",
[](PyAffineExpr &self, uint32_t numSymbols, uint32_t shift,
uint32_t offset) {
return PyAffineExpr(
self.getContext(),
mlirAffineExprShiftSymbols(self, numSymbols, shift, offset));
},
nb::arg("num_symbols"), nb::arg("shift"),
nb::arg("offset").none() = 0)
.def_static(
"get_add", &PyAffineAddExpr::get,
"Gets an affine expression containing a sum of two expressions.")
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/CAPI/IR/AffineExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr,
return wrap(unwrap(affineExpr).compose(unwrap(affineMap)));
}

MlirAffineExpr mlirAffineExprShiftDims(MlirAffineExpr affineExpr,
uint32_t numDims, uint32_t shift,
uint32_t offset) {
return wrap(unwrap(affineExpr).shiftDims(numDims, shift, offset));
}

MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr,
uint32_t numSymbols, uint32_t shift,
uint32_t offset) {
return wrap(unwrap(affineExpr).shiftSymbols(numSymbols, shift, offset));
}

//===----------------------------------------------------------------------===//
// Affine Dimension Expression.
//===----------------------------------------------------------------------===//
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/python/ir/affine_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,14 @@ def testHash():
dictionary[s1] = 1
assert d0 in dictionary
assert s1 in dictionary


# CHECK-LABEL: TEST: testAffineExprShift
@run
def testAffineExprShift():
with Context() as ctx:
dims = [AffineExpr.get_dim(i) for i in range(4)]
syms = [AffineExpr.get_symbol(i) for i in range(4)]

assert (dims[2] + dims[3]) == (dims[0] + dims[1]).shift_dims(2, 2)
assert (syms[2] + syms[3]) == (syms[0] + syms[1]).shift_symbols(2, 2, 0)