diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 2de057d1d0758..4063740a9acd1 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1339,6 +1339,15 @@ bool mlir::affine::isValidLoopInterchangePermutation( unsigned maxLoopDepth = loops.size(); if (maxLoopDepth == 1) return true; + + // We cannot guarantee the validity of the interchange if the loops have + // iter_args, since the dependence analysis does not take them into account. + // Conservatively return false in such cases. + if (llvm::any_of(loops, [](AffineForOp loop) { + return loop.getNumIterOperands() > 0; + })) + return false; + // Gather dependence components for dependences between all ops in loop nest // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth]. std::vector> depCompsVec; diff --git a/mlir/test/Dialect/Affine/loop-permute.mlir b/mlir/test/Dialect/Affine/loop-permute.mlir index 118165b2fb2a2..e38aeb543fceb 100644 --- a/mlir/test/Dialect/Affine/loop-permute.mlir +++ b/mlir/test/Dialect/Affine/loop-permute.mlir @@ -4,6 +4,7 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -test-loop-permutation="permutation-map=0,2,1" | FileCheck %s --check-prefix=CHECK-021 // RUN: mlir-opt -allow-unregistered-dialect %s -test-loop-permutation="permutation-map=2,0,1" | FileCheck %s --check-prefix=CHECK-201 // RUN: mlir-opt -allow-unregistered-dialect %s -test-loop-permutation="permutation-map=2,1,0" | FileCheck %s --check-prefix=CHECK-210 +// RUN: mlir-opt -allow-unregistered-dialect %s -test-loop-permutation="permutation-map=2,1,0 check-validity=1" | FileCheck %s --check-prefix=CHECK-210-VALID // CHECK-120-LABEL: func @permute func.func @permute(%U0 : index, %U1 : index, %U2 : index) { @@ -45,3 +46,34 @@ func.func @permute(%U0 : index, %U1 : index, %U2 : index) { // CHECK-201: "foo"(%arg5, %arg3) // CHECK-201-NEXT: "bar"(%arg4) + +// ----- + +// Tests that the permutation validation check utility conservatively returns false when the +// for loop has an iter_arg. + +// CHECK-210-VALID-LABEL: func @check_validity_with_iter_args +// CHECK-210-VALID-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index +func.func @check_validity_with_iter_args(%U0 : index, %U1 : index, %U2 : index) { + %buf = memref.alloc() : memref<100x100xf32> + %cst = arith.constant 1.0 : f32 + %c10 = arith.constant 10 : index + %c20 = arith.constant 20 : index + + // Check that the loops are not permuted. + // CHECK-210-VALID: affine.for %{{.*}} = 0 to %[[ARG0]] { + // CHECK-210-VALID-NEXT: affine.for %{{.*}} = 0 to %[[ARG1]] { + // CHECK-210-VALID-NEXT: affine.for %{{.*}} = 0 to %[[ARG2]] iter_args( + affine.for %arg0 = 0 to %U0 { + affine.for %arg1 = 0 to %U1 { + %res = affine.for %arg2 = 0 to %U2 iter_args(%iter1 = %cst) -> (f32) { + %val = affine.load %buf[%arg0 + 10, %arg1 + 20] : memref<100x100xf32> + %newVal = arith.addf %val, %cst : f32 + affine.store %newVal, %buf[%arg0 + 10, %arg1 + 20] : memref<100x100xf32> + %newVal2 = arith.addf %newVal, %iter1 : f32 + affine.yield %iter1 : f32 + } + } + } + return +} diff --git a/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp index e708b7de690ec..8bab9a0ef55b8 100644 --- a/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp +++ b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp @@ -42,6 +42,12 @@ struct TestLoopPermutation ListOption permList{*this, "permutation-map", llvm::cl::desc("Specify the loop permutation"), llvm::cl::OneOrMore}; + + /// Specify whether to check validity of loop permutation. + Option checkValidity{ + *this, "check-validity", + llvm::cl::desc("Check validity of the loop permutation"), + llvm::cl::init(false)}; }; } // namespace @@ -60,6 +66,9 @@ void TestLoopPermutation::runOnOperation() { // Permute if the nest's size is consistent with the specified // permutation. if (nest.size() >= 2 && nest.size() == permMap.size()) { + if (checkValidity.getValue() && + !isValidLoopInterchangePermutation(nest, permMap)) + continue; permuteLoops(nest, permMap); } }