Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,17 @@ MCSymbol *MCResourceInfo::getMaxSGPRSymbol(MCContext &OutContext) {
return OutContext.getOrCreateSymbol("amdgpu.max_num_sgpr");
}

// The (partially complete) expression should have no recursion in it. After
// all, we're trying to avoid recursion using this codepath.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document return value

static bool findSymbolInExpr(MCSymbol *Sym, const MCExpr *Expr,
SmallVectorImpl<const MCExpr *> &Exprs) {
SmallVectorImpl<const MCExpr *> &Exprs,
SmallPtrSetImpl<const MCExpr *> &Visited) {
// Assert if any of the expressions is already visited (i.e., there is
// existing recursion).
assert(!Visited.contains(Expr) &&
"Expr should not exist in Visited as we're avoiding recursion");
Visited.insert(Expr);

switch (Expr->getKind()) {
default:
return false;
Expand All @@ -107,17 +116,17 @@ static bool findSymbolInExpr(MCSymbol *Sym, const MCExpr *Expr,
}
case MCExpr::ExprKind::Binary: {
const MCBinaryExpr *BExpr = cast<MCBinaryExpr>(Expr);
return findSymbolInExpr(Sym, BExpr->getLHS(), Exprs) ||
findSymbolInExpr(Sym, BExpr->getRHS(), Exprs);
return findSymbolInExpr(Sym, BExpr->getLHS(), Exprs, Visited) ||
findSymbolInExpr(Sym, BExpr->getRHS(), Exprs, Visited);
}
case MCExpr::ExprKind::Unary: {
const MCUnaryExpr *UExpr = cast<MCUnaryExpr>(Expr);
return findSymbolInExpr(Sym, UExpr->getSubExpr(), Exprs);
return findSymbolInExpr(Sym, UExpr->getSubExpr(), Exprs, Visited);
}
case MCExpr::ExprKind::Target: {
const AMDGPUMCExpr *AGVK = cast<AMDGPUMCExpr>(Expr);
for (const MCExpr *E : AGVK->getArgs()) {
if (findSymbolInExpr(Sym, E, Exprs))
if (findSymbolInExpr(Sym, E, Exprs, Visited))
return true;
}
return false;
Expand All @@ -132,11 +141,12 @@ static bool findSymbolInExpr(MCSymbol *Sym, const MCExpr *Expr,
// contains the symbol Expr is associated with.
static bool foundRecursiveSymbolDef(MCSymbol *Sym, const MCExpr *Expr) {
SmallVector<const MCExpr *, 8> WorkList;
SmallPtrSet<const MCExpr *, 8> Visited;
WorkList.push_back(Expr);

while (!WorkList.empty()) {
const MCExpr *CurExpr = WorkList.pop_back_val();
if (findSymbolInExpr(Sym, CurExpr, WorkList))
if (findSymbolInExpr(Sym, CurExpr, WorkList, Visited))
return true;
}

Expand Down
Loading