Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,11 @@ class IntegerRelation {
IntMatrix inequalities;
};

inline raw_ostream &operator<<(raw_ostream &os, const IntegerRelation &rel) {
rel.print(os);
return os;
}

/// An IntegerPolyhedron represents the set of points from a PresburgerSpace
/// that satisfy a list of affine constraints. Affine constraints can be
/// inequalities or equalities in the form:
Expand Down
78 changes: 37 additions & 41 deletions mlir/lib/Dialect/Affine/Analysis/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

Expand Down Expand Up @@ -241,7 +242,7 @@ addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
}

bool MemRefDependenceGraph::init() {
LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
LDBG() << "--- Initializing MDG ---";
// Map from a memref to the set of ids of the nodes that have ops accessing
// the memref.
DenseMap<Value, SetVector<unsigned>> memrefAccesses;
Expand Down Expand Up @@ -288,16 +289,15 @@ bool MemRefDependenceGraph::init() {
// Return false if non-handled/unknown region-holding ops are found. We
// won't know what such ops do or what its regions mean; for e.g., it may
// not be an imperative op.
LLVM_DEBUG(llvm::dbgs()
<< "MDG init failed; unknown region-holding op found!\n");
LDBG() << "MDG init failed; unknown region-holding op found!";
return false;
}
// We aren't creating nodes for memory-effect free ops either with no
// regions (unless it has results being used) or those with branch op
// interface.
}

LLVM_DEBUG(llvm::dbgs() << "Created " << nodes.size() << " nodes\n");
LDBG() << "Created " << nodes.size() << " nodes";

// Add dependence edges between nodes which produce SSA values and their
// users. Load ops can be considered as the ones producing SSA values.
Expand Down Expand Up @@ -556,9 +556,8 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
gatherDefiningNodes(dstId, definingNodes);
if (llvm::any_of(definingNodes,
[&](unsigned id) { return hasDependencePath(srcId, id); })) {
LLVM_DEBUG(llvm::dbgs()
<< "Can't fuse: a defining op with a user in the dst "
"loop has dependence from the src loop\n");
LDBG() << "Can't fuse: a defining op with a user in the dst "
<< "loop has dependence from the src loop";
return nullptr;
}

Expand Down Expand Up @@ -957,28 +956,28 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
FlatAffineValueConstraints srcConstraints;
// TODO: Store the source's domain to avoid computation at each depth.
if (failed(getSourceAsConstraints(srcConstraints))) {
LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
LDBG() << "Unable to compute source's domain";
return std::nullopt;
}
// As the set difference utility currently cannot handle symbols in its
// operands, validity of the slice cannot be determined.
if (srcConstraints.getNumSymbolVars() > 0) {
LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
LDBG() << "Cannot handle symbols in source domain";
return std::nullopt;
}
// TODO: Handle local vars in the source domains while using the 'projectOut'
// utility below. Currently, aligning is not done assuming that there will be
// no local vars in the source domain.
if (srcConstraints.getNumLocalVars() != 0) {
LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
LDBG() << "Cannot handle locals in source domain";
return std::nullopt;
}

// Create constraints for the slice loop nest that would be created if the
// fusion succeeds.
FlatAffineValueConstraints sliceConstraints;
if (failed(getAsConstraints(&sliceConstraints))) {
LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
LDBG() << "Unable to compute slice's domain";
return std::nullopt;
}

Expand All @@ -987,19 +986,19 @@ std::optional<bool> ComputationSliceState::isSliceValid() const {
sliceConstraints.projectOut(ivs.size(),
sliceConstraints.getNumVars() - ivs.size());

LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
LLVM_DEBUG(srcConstraints.dump());
LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
"(expressed in terms of its source's IVs):\n");
LLVM_DEBUG(sliceConstraints.dump());
LDBG() << "Domain of the source of the slice:\n"
<< "Source constraints:" << srcConstraints
<< "\nDomain of the slice if this fusion succeeds "
<< "(expressed in terms of its source's IVs):\n"
<< "Slice constraints:" << sliceConstraints;

// TODO: Store 'srcSet' to avoid recalculating for each depth.
PresburgerSet srcSet(srcConstraints);
PresburgerSet sliceSet(sliceConstraints);
PresburgerSet diffSet = sliceSet.subtract(srcSet);

if (!diffSet.isIntegerEmpty()) {
LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
LDBG() << "Incorrect slice";
return false;
}
return true;
Expand Down Expand Up @@ -1172,8 +1171,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,

unsigned rank = access.getRank();

LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
<< "\ndepth: " << loopDepth << "\n";);
LDBG() << "MemRefRegion::compute: " << *op << " depth: " << loopDepth;

// 0-d memrefs.
if (rank == 0) {
Expand Down Expand Up @@ -1236,7 +1234,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
if (auto constVal = getConstantIntValue(symbol))
cst.addBound(BoundType::EQ, symbol, constVal.value());
} else {
LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value");
LDBG() << "unknown affine dimensional value";
return failure();
}
}
Expand All @@ -1260,7 +1258,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
// Add access function equalities to connect loop IVs to data dimensions.
if (failed(cst.composeMap(&accessValueMap))) {
op->emitError("getMemRefRegion: compose affine map failed");
LLVM_DEBUG(accessValueMap.getAffineMap().dump());
LDBG() << "Access map: " << accessValueMap.getAffineMap();
return failure();
}

Expand Down Expand Up @@ -1317,8 +1315,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
}
cst.removeTrivialRedundancy();

LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
LLVM_DEBUG(cst.dump());
LDBG() << "Memory region: " << cst;
return success();
}

Expand Down Expand Up @@ -1346,14 +1343,14 @@ std::optional<int64_t> MemRefRegion::getRegionSize() {
auto memRefType = cast<MemRefType>(memref.getType());

if (!memRefType.getLayout().isIdentity()) {
LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
LDBG() << "Non-identity layout map not yet supported";
return false;
}

// Compute the extents of the buffer.
std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
if (!numElements) {
LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
LDBG() << "Dynamic shapes not yet supported";
return std::nullopt;
}
auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
Expand Down Expand Up @@ -1397,8 +1394,7 @@ LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
/*addMemRefDimBounds=*/false)))
return success();

LLVM_DEBUG(llvm::dbgs() << "Memory region");
LLVM_DEBUG(region.getConstraints()->dump());
LDBG() << "Memory region: " << region.getConstraints();

bool outOfBounds = false;
unsigned rank = loadOrStoreOp.getMemRefType().getRank();
Expand Down Expand Up @@ -1558,7 +1554,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// Check if 'loopDepth' exceeds nesting depth of src/dst ops.
if ((!isBackwardSlice && loopDepth > getNestingDepth(a)) ||
(isBackwardSlice && loopDepth > getNestingDepth(b))) {
LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
LDBG() << "Invalid loop depth";
return SliceComputationResult::GenericFailure;
}

Expand All @@ -1571,7 +1567,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
&dependenceConstraints, /*dependenceComponents=*/nullptr,
/*allowRAR=*/readReadAccesses);
if (result.value == DependenceResult::Failure) {
LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
LDBG() << "Dependence check failed";
return SliceComputationResult::GenericFailure;
}
if (result.value == DependenceResult::NoDependence)
Expand All @@ -1586,8 +1582,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
// Initialize 'sliceUnionCst' with the bounds computed in previous step.
if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
LLVM_DEBUG(llvm::dbgs()
<< "Unable to compute slice bound constraints\n");
LDBG() << "Unable to compute slice bound constraints";
return SliceComputationResult::GenericFailure;
}
assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
Expand All @@ -1597,8 +1592,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
FlatAffineValueConstraints tmpSliceCst;
if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
LLVM_DEBUG(llvm::dbgs()
<< "Unable to compute slice bound constraints\n");
LDBG() << "Unable to compute slice bound constraints";
return SliceComputationResult::GenericFailure;
}

Expand Down Expand Up @@ -1630,16 +1624,15 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
if (sliceUnionCst.getNumLocalVars() > 0 ||
tmpSliceCst.getNumLocalVars() > 0 ||
failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
LLVM_DEBUG(llvm::dbgs()
<< "Unable to compute union bounding box of slice bounds\n");
LDBG() << "Unable to compute union bounding box of slice bounds";
return SliceComputationResult::GenericFailure;
}
}
}

// Empty union.
if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
LLVM_DEBUG(llvm::dbgs() << "empty slice union - unexpected\n");
LDBG() << "empty slice union - unexpected";
return SliceComputationResult::GenericFailure;
}

Expand All @@ -1652,7 +1645,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
unsigned innermostCommonLoopDepth =
getInnermostCommonLoopDepth(ops, &surroundingLoops);
if (loopDepth > innermostCommonLoopDepth) {
LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
LDBG() << "Exceeds max loop depth";
return SliceComputationResult::GenericFailure;
}

Expand Down Expand Up @@ -1696,7 +1689,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
// that the slice is valid, otherwise return appropriate failure status.
std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
if (!isSliceValid) {
LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
LDBG() << "Cannot determine if the slice is valid";
return SliceComputationResult::GenericFailure;
}
if (!*isSliceValid)
Expand Down Expand Up @@ -2050,17 +2043,20 @@ static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
if (failed(
region->compute(opInst,
/*loopDepth=*/getNestingDepth(&*block.begin())))) {
LLVM_DEBUG(opInst->emitError("error obtaining memory region"));
LDBG() << "Error obtaining memory region";
opInst->emitError("error obtaining memory region");
return failure();
}

auto [it, inserted] = regions.try_emplace(region->memref);
if (inserted) {
it->second = std::move(region);
} else if (failed(it->second->unionBoundingBox(*region))) {
LLVM_DEBUG(opInst->emitWarning(
LDBG() << "getMemoryFootprintBytes: unable to perform a union on a "
"memory region";
opInst->emitWarning(
"getMemoryFootprintBytes: unable to perform a union on a memory "
"region"));
"region");
return failure();
}
return WalkResult::advance();
Expand Down