diff --git a/mlir/unittests/IR/AffineMapTest.cpp b/mlir/unittests/IR/AffineMapTest.cpp index eaeb18d128ca5..166692f731d1c 100644 --- a/mlir/unittests/IR/AffineMapTest.cpp +++ b/mlir/unittests/IR/AffineMapTest.cpp @@ -76,3 +76,57 @@ TEST(AffineMapTest, isProjectedPermutation) { AffineMap map10 = AffineMap::get(6, 0, {d5, d3, d2, d4}, &ctx); EXPECT_TRUE(map10.isProjectedPermutation()); } + +TEST(AffineMapTest, getInversePermutation) { + MLIRContext ctx; + OpBuilder b(&ctx); + + // 0. Empty map + AffineMap map0 = AffineMap::get(0, 0, {}, &ctx); + AffineMap inverseMap0 = inversePermutation(map0); + EXPECT_TRUE(inverseMap0.isEmpty()); + + auto d0 = b.getAffineDimExpr(0); + auto d1 = b.getAffineDimExpr(1); + auto d2 = b.getAffineDimExpr(2); + + // 1. (d0, d1, d2) -> (d1, d1, d0, d2, d1, d2, d1, d0) + AffineMap map1 = AffineMap::get(3, 0, {d1, d1, d0, d2, d1, d2, d1, d0}, &ctx); + // (d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d0, d3) + AffineMap inverseMap1 = inversePermutation(map1); + auto resultsInv1 = inverseMap1.getResults(); + EXPECT_EQ(resultsInv1.size(), 3UL); + + // 1.1 Expect d2 + AffineDimExpr expr = llvm::dyn_cast(resultsInv1[0]); + EXPECT_TRUE(expr && expr.getPosition() == 2); + + // 1.2 Expect d0 + expr = llvm::dyn_cast(resultsInv1[1]); + EXPECT_TRUE(expr && expr.getPosition() == 0); + + // 1.3 Expect d3 + expr = llvm::dyn_cast(resultsInv1[2]); + EXPECT_TRUE(expr && expr.getPosition() == 3); + + // 2. (d0, d1, d2) -> (d1, d0 + d1, d0, d2, d1, d2, d1, d0) + auto sum = d0 + d1; + AffineMap map2 = + AffineMap::get(3, 0, {d1, sum, d0, d2, d1, d2, d1, d0}, &ctx); + // (d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d0, d3) + AffineMap inverseMap2 = inversePermutation(map2); + auto resultsInv2 = inverseMap2.getResults(); + EXPECT_EQ(resultsInv2.size(), 3UL); + + // 2.1 Expect d2 + expr = llvm::dyn_cast(resultsInv2[0]); + EXPECT_TRUE(expr && expr.getPosition() == 2); + + // 2.2 Expect d0 + expr = llvm::dyn_cast(resultsInv2[1]); + EXPECT_TRUE(expr && expr.getPosition() == 0); + + // 2.3 Expect d3 + expr = llvm::dyn_cast(resultsInv2[2]); + EXPECT_TRUE(expr && expr.getPosition() == 3); +}