Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ bool isTopLevelValue(Value value, Region *region);
/// trait `AffineScope`; `nullptr` if there is no such region.
Region *getAffineScope(Operation *op);

/// Returns the closest region enclosing `op` that is held by a non-affine
/// operation; `nullptr` if there is no such region. For the purposes of this
/// lookup, the operations treated as affine (and therefore skipped over while
/// walking up the ancestors of `op`) are `AffineForOp`, `AffineIfOp`, and
/// `AffineParallelOp`; the walk stops at the first ancestor of any other
/// kind. This method is meant to
/// be used by affine analysis methods (e.g. dependence analysis) which are
/// only meaningful when performed among/between operations from the same
/// analysis scope.
Region *getAffineAnalysisScope(Operation *op);

/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
/// from a source memref to a destination memref. The source and destination
/// memref need not be of the same dimensionality, but need to have the same
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,8 @@ DependenceResult mlir::affine::checkMemrefAccessDependence(

// We can't analyze further if the ops lie in different affine analysis scopes
// or have no common block in an affine scope.
if (getAffineScope(srcAccess.opInst) != getAffineScope(dstAccess.opInst))
if (getAffineAnalysisScope(srcAccess.opInst) !=
getAffineAnalysisScope(dstAccess.opInst))
return DependenceResult::Failure;
if (!getCommonBlockInAffineScope(srcAccess.opInst, dstAccess.opInst))
return DependenceResult::Failure;
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Affine/Analysis/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2307,8 +2307,8 @@ FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(

Block *mlir::affine::findInnermostCommonBlockInScope(Operation *a,
Operation *b) {
Region *aScope = mlir::affine::getAffineScope(a);
Region *bScope = mlir::affine::getAffineScope(b);
Region *aScope = getAffineAnalysisScope(a);
Region *bScope = getAffineAnalysisScope(b);
if (aScope != bScope)
return nullptr;

Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,16 @@ Region *mlir::affine::getAffineScope(Operation *op) {
return nullptr;
}

/// Climbs the ancestor chain of `op`, stepping over enclosing `AffineForOp`,
/// `AffineIfOp`, and `AffineParallelOp` operations. As soon as an ancestor of
/// any other kind is reached, the region of that ancestor which contains the
/// current position is returned. Yields `nullptr` when the chain is exhausted
/// without meeting such an ancestor.
Region *mlir::affine::getAffineAnalysisScope(Operation *op) {
  for (Operation *inner = op, *outer = inner->getParentOp(); outer;
       inner = outer, outer = inner->getParentOp()) {
    // Affine structured ops are transparent to the analysis scope; keep
    // climbing past them.
    if (isa<AffineForOp, AffineIfOp, AffineParallelOp>(outer))
      continue;
    // `outer` is the first non-affine ancestor: the region of `outer` holding
    // `inner` is the scope.
    return inner->getParentRegion();
  }
  return nullptr;
}

// A Value can be used as a dimension id iff it meets one of the following
// conditions:
// *) It is valid as a symbol.
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/Affine/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,8 @@ static bool mustReachAtInnermost(const MemRefAccess &srcAccess,
const MemRefAccess &destAccess) {
// Affine dependence analysis is possible only if both ops are in the same
// affine analysis scope.
if (getAffineScope(srcAccess.opInst) != getAffineScope(destAccess.opInst))
if (getAffineAnalysisScope(srcAccess.opInst) !=
getAffineAnalysisScope(destAccess.opInst))
return false;

unsigned nsLoops =
Expand All @@ -659,9 +660,9 @@ static bool mayHaveEffect(Operation *srcMemOp, Operation *destMemOp,
// AffineScope. Also, we can only check if our affine scope is isolated from
// above; otherwise, values can come from outside of the affine scope that the
// check below cannot analyze.
Region *srcScope = getAffineScope(srcMemOp);
Region *srcScope = getAffineAnalysisScope(srcMemOp);
if (srcAccess.memref == destAccess.memref &&
srcScope == getAffineScope(destMemOp)) {
srcScope == getAffineAnalysisScope(destMemOp)) {
unsigned nsLoops = getNumCommonSurroundingLoops(*srcMemOp, *destMemOp);
FlatAffineValueConstraints dependenceConstraints;
for (unsigned d = nsLoops + 1; d > minSurroundingLoops; d--) {
Expand Down
41 changes: 36 additions & 5 deletions mlir/test/Dialect/Affine/scalrep.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -738,12 +738,11 @@ func.func @with_inner_ops(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: i1)
return
}

// CHECK: %[[pi:.+]] = arith.constant 3.140000e+00 : f64
// CHECK: %{{.*}} = scf.if %arg2 -> (f64) {
// CHECK: scf.yield %{{.*}} : f64
// Semantics of non-affine region ops would be unknown.

// CHECK: } else {
// CHECK: scf.yield %[[pi]] : f64
// CHECK: }
// CHECK-NEXT: %[[Y:.*]] = affine.load
// CHECK-NEXT: scf.yield %[[Y]] : f64

// Check if scalar replacement works correctly when affine memory ops are in the
// body of an scf.for.
Expand Down Expand Up @@ -952,3 +951,35 @@ func.func @consecutive_store() {
}
return
}

// CHECK-LABEL: func @scf_for_if
// Test case: affine accesses to the single-element buffer %0 appear both
// directly in the function body and inside an scf.if nested in an scf.for.
// The FileCheck directives below expect the load and store inside the scf.if,
// and the load after the loop, to all remain in the output.
func.func @scf_for_if(%arg0: memref<?xi32>, %arg1: i32) -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%c5_i32 = arith.constant 5 : i32
%c10_i32 = arith.constant 10 : i32
%0 = memref.alloca() : memref<1xi32>
%1 = llvm.mlir.undef : i32
// Two consecutive stores to %0[0] ahead of the loop: first an undef value,
// then the constant 0.
affine.store %1, %0[0] : memref<1xi32>
affine.store %c0_i32, %0[0] : memref<1xi32>
// Loop upper bound comes from the i32 function argument.
%2 = arith.index_cast %arg1 : i32 to index
scf.for %arg2 = %c0 to %2 step %c1 {
// Condition for the scf.if: (%arg0[%arg2] * 5) > 10 (signed compare).
%4 = memref.load %arg0[%arg2] : memref<?xi32>
%5 = arith.muli %4, %c5_i32 : i32
%6 = arith.cmpi sgt, %5, %c10_i32 : i32
// CHECK: scf.if
scf.if %6 {
// No forwarding should happen here since we have an scf.for around and we
// can't analyze the flow of values.
// CHECK: affine.load
%7 = affine.load %0[0] : memref<1xi32>
%8 = arith.addi %5, %7 : i32
// Read-modify-write of %0[0]: the sum is stored back into the same element.
// CHECK: affine.store
affine.store %8, %0[0] : memref<1xi32>
}
}
// The load following the loop is expected to be kept as well (presumably
// because the conditional store inside the loop may have updated %0[0]).
// CHECK: affine.load
%3 = affine.load %0[0] : memref<1xi32>
return %3 : i32
}