Skip to content

Commit e8bc187

Browse files
authored
Fix nametoordinal (#2221)
1 parent 495fde3 commit e8bc187

File tree

1 file changed

+57
-48
lines changed

1 file changed

+57
-48
lines changed

enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp

Lines changed: 57 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ struct VariableSetting {
207207
StringMap<std::vector<int>> extractions;
208208

209209
std::tuple<std::string, bool, std::vector<int>>
210-
lookup(StringRef name, const Record *pattern, const Init *resultRoot) {
210+
lookup(StringRef name, const Record *pattern, const Init *resultRoot) const {
211211
auto ord = nameToOrdinal.find(name);
212212
if (ord == nameToOrdinal.end())
213213
PrintFatalError(pattern->getLoc(), Twine("unknown named operand '") +
@@ -1192,14 +1192,16 @@ void handleUse(
11921192
const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse,
11931193
std::string &foundShadowUse, bool &foundDiffRet, std::string precondition,
11941194
const DagInit *tree,
1195-
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition);
1195+
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition,
1196+
const VariableSetting &nameToOrdinal);
11961197

11971198
void handleUseArgument(
11981199
StringRef name, const Init *arg, bool usesPrimal, bool usesShadow,
11991200
const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse,
12001201
std::string &foundShadowUse, bool &foundDiffRet, std::string precondition,
12011202
const DagInit *tree,
1202-
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition) {
1203+
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition,
1204+
const VariableSetting &nameToOrdinal) {
12031205

12041206
auto arg2 = dyn_cast<DagInit>(arg);
12051207

@@ -1218,7 +1220,8 @@ void handleUseArgument(
12181220
handleUse(root, arg2, name.size() ? foundPrimalUse2 : foundPrimalUse,
12191221
name.size() ? foundShadowUse2 : foundShadowUse,
12201222
name.size() ? foundDiffRet2 : foundDiffRet,
1221-
usesPrimal ? precondition : "", tree, varNameToCondition);
1223+
usesPrimal ? precondition : "", tree, varNameToCondition,
1224+
nameToOrdinal);
12221225

12231226
if (name.size()) {
12241227
if (foundPrimalUse2.size() &&
@@ -1306,7 +1309,8 @@ void handleUse(
13061309
const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse,
13071310
std::string &foundShadowUse, bool &foundDiffRet, std::string precondition,
13081311
const DagInit *tree,
1309-
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition) {
1312+
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition,
1313+
const VariableSetting &nameToOrdinal) {
13101314
auto opName = resultTree->getOperator()->getAsString();
13111315
auto Def = cast<DefInit>(resultTree->getOperator())->getDef();
13121316
if (opName == "DiffeRetIndex" || Def->isSubClassOf("DiffeRetIndex")) {
@@ -1339,7 +1343,9 @@ void handleUse(
13391343
if (numArgs == 3) {
13401344
if (isa<UnsetInit>(resultTree->getArg(0)) && resultTree->getArgName(0)) {
13411345
auto name = resultTree->getArgName(0)->getAsUnquotedString();
1342-
conditionStr = ReplaceAll(conditionStr, "imVal", name);
1346+
auto [ord, isVec, ext] = nameToOrdinal.lookup(name, nullptr, nullptr);
1347+
assert(!isVec);
1348+
conditionStr = ReplaceAll(conditionStr, "imVal", ord);
13431349
} else
13441350
assert("Requires name for arg");
13451351
}
@@ -1362,7 +1368,7 @@ void handleUse(
13621368
auto arg = resultTree->getArg(i);
13631369
handleUseArgument(name, arg, true, false, root, resultTree,
13641370
foundPrimalUse, foundShadowUse, foundDiffRet,
1365-
precondition2, tree, varNameToCondition);
1371+
precondition2, tree, varNameToCondition, nameToOrdinal);
13661372
}
13671373

13681374
return;
@@ -1375,16 +1381,57 @@ void handleUse(
13751381
auto name = resultTree->getArgNameStr(argEn.index());
13761382
handleUseArgument(name, argEn.value(), usesPrimal, usesShadow, root,
13771383
resultTree, foundPrimalUse, foundShadowUse, foundDiffRet,
1378-
precondition, tree, varNameToCondition);
1384+
precondition, tree, varNameToCondition, nameToOrdinal);
13791385
}
13801386
}
13811387

1388+
static VariableSetting parseVariables(const DagInit *tree, ActionType intrinsic,
1389+
StringRef origName) {
1390+
VariableSetting nameToOrdinal;
1391+
std::function<void(const DagInit *, ArrayRef<unsigned>)> insert =
1392+
[&](const DagInit *ptree, ArrayRef<unsigned> prev) {
1393+
unsigned i = 0;
1394+
for (auto tree : ptree->getArgs()) {
1395+
SmallVector<unsigned, 2> next(prev.begin(), prev.end());
1396+
next.push_back(i);
1397+
if (auto dg = dyn_cast<DagInit>(tree))
1398+
insert(dg, next);
1399+
1400+
if (ptree->getArgNameStr(i).size()) {
1401+
std::string op;
1402+
if (intrinsic != MLIRDerivatives)
1403+
op = (origName + ".getOperand(" + Twine(next[0]) + ")").str();
1404+
else
1405+
op = (origName + "->getOperand(" + Twine(next[0]) + ")").str();
1406+
std::vector<int> extractions;
1407+
if (prev.size() > 0) {
1408+
for (unsigned i = 1; i < next.size(); i++) {
1409+
extractions.push_back(next[i]);
1410+
}
1411+
}
1412+
nameToOrdinal.insert(ptree->getArgNameStr(i), op, false,
1413+
extractions);
1414+
}
1415+
i++;
1416+
}
1417+
};
1418+
1419+
insert(tree, {});
1420+
1421+
if (tree->getNameStr().size())
1422+
nameToOrdinal.insert(tree->getNameStr(),
1423+
(Twine("(&") + origName + ")").str(), false, {});
1424+
return nameToOrdinal;
1425+
}
1426+
13821427
void printDiffUse(
13831428
raw_ostream &os, Twine prefix, const ListInit *argOps, StringRef origName,
13841429
ActionType intrinsic, const DagInit *tree,
13851430
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition) {
13861431
os << prefix << " // Rule " << *tree << "\n";
13871432

1433+
VariableSetting nameToOrdinal = parseVariables(tree, intrinsic, origName);
1434+
13881435
for (auto argOpEn : enumerate(*argOps)) {
13891436
size_t argIdx = argOpEn.index();
13901437
if (auto resultRoot = dyn_cast<DagInit>(argOpEn.value())) {
@@ -1417,7 +1464,8 @@ void printDiffUse(
14171464

14181465
// hasDiffeRet(resultTree)
14191466
handleUse(resultTree, resultTree, foundPrimalUse, foundShadowUse,
1420-
foundDiffRet, /*precondition*/ "true", tree, varNameToCondition);
1467+
foundDiffRet, /*precondition*/ "true", tree, varNameToCondition,
1468+
nameToOrdinal);
14211469

14221470
os << prefix << " // Arg " << argIdx << " : " << *resultTree << "\n";
14231471

@@ -1587,45 +1635,6 @@ static void emitMLIRReverse(raw_ostream &os, const Record *pattern,
15871635
os << " mlir::Value dif = nullptr;\n";
15881636
}
15891637

1590-
static VariableSetting parseVariables(const DagInit *tree, ActionType intrinsic,
1591-
StringRef origName) {
1592-
VariableSetting nameToOrdinal;
1593-
std::function<void(const DagInit *, ArrayRef<unsigned>)> insert =
1594-
[&](const DagInit *ptree, ArrayRef<unsigned> prev) {
1595-
unsigned i = 0;
1596-
for (auto tree : ptree->getArgs()) {
1597-
SmallVector<unsigned, 2> next(prev.begin(), prev.end());
1598-
next.push_back(i);
1599-
if (auto dg = dyn_cast<DagInit>(tree))
1600-
insert(dg, next);
1601-
1602-
if (ptree->getArgNameStr(i).size()) {
1603-
std::string op;
1604-
if (intrinsic != MLIRDerivatives)
1605-
op = (origName + ".getOperand(" + Twine(next[0]) + ")").str();
1606-
else
1607-
op = (origName + "->getOperand(" + Twine(next[0]) + ")").str();
1608-
std::vector<int> extractions;
1609-
if (prev.size() > 0) {
1610-
for (unsigned i = 1; i < next.size(); i++) {
1611-
extractions.push_back(next[i]);
1612-
}
1613-
}
1614-
nameToOrdinal.insert(ptree->getArgNameStr(i), op, false,
1615-
extractions);
1616-
}
1617-
i++;
1618-
}
1619-
};
1620-
1621-
insert(tree, {});
1622-
1623-
if (tree->getNameStr().size())
1624-
nameToOrdinal.insert(tree->getNameStr(),
1625-
(Twine("(&") + origName + ")").str(), false, {});
1626-
return nameToOrdinal;
1627-
}
1628-
16291638
static void emitReverseCommon(raw_ostream &os, const Record *pattern,
16301639
const DagInit *tree, ActionType intrinsic,
16311640
StringRef origName, const ListInit *argOps) {

0 commit comments

Comments
 (0)