@@ -207,7 +207,7 @@ struct VariableSetting {
207207 StringMap<std::vector<int >> extractions;
208208
209209 std::tuple<std::string, bool , std::vector<int >>
210- lookup (StringRef name, const Record *pattern, const Init *resultRoot) {
210+ lookup (StringRef name, const Record *pattern, const Init *resultRoot) const {
211211 auto ord = nameToOrdinal.find (name);
212212 if (ord == nameToOrdinal.end ())
213213 PrintFatalError (pattern->getLoc (), Twine (" unknown named operand '" ) +
@@ -1192,14 +1192,16 @@ void handleUse(
11921192 const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse,
11931193 std::string &foundShadowUse, bool &foundDiffRet, std::string precondition,
11941194 const DagInit *tree,
1195- StringMap<std::tuple<std::string, std::string, bool >> &varNameToCondition);
1195+ StringMap<std::tuple<std::string, std::string, bool >> &varNameToCondition,
1196+ const VariableSetting &nameToOrdinal);
11961197
11971198void handleUseArgument (
11981199 StringRef name, const Init *arg, bool usesPrimal, bool usesShadow,
11991200 const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse,
12001201 std::string &foundShadowUse, bool &foundDiffRet, std::string precondition,
12011202 const DagInit *tree,
1202- StringMap<std::tuple<std::string, std::string, bool >> &varNameToCondition) {
1203+ StringMap<std::tuple<std::string, std::string, bool >> &varNameToCondition,
1204+ const VariableSetting &nameToOrdinal) {
12031205
12041206 auto arg2 = dyn_cast<DagInit>(arg);
12051207
@@ -1218,7 +1220,8 @@ void handleUseArgument(
12181220 handleUse (root, arg2, name.size () ? foundPrimalUse2 : foundPrimalUse,
12191221 name.size () ? foundShadowUse2 : foundShadowUse,
12201222 name.size () ? foundDiffRet2 : foundDiffRet,
1221- usesPrimal ? precondition : " " , tree, varNameToCondition);
1223+ usesPrimal ? precondition : " " , tree, varNameToCondition,
1224+ nameToOrdinal);
12221225
12231226 if (name.size ()) {
12241227 if (foundPrimalUse2.size () &&
@@ -1306,7 +1309,8 @@ void handleUse(
13061309 const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse,
13071310 std::string &foundShadowUse, bool &foundDiffRet, std::string precondition,
13081311 const DagInit *tree,
1309- StringMap<std::tuple<std::string, std::string, bool >> &varNameToCondition) {
1312+ StringMap<std::tuple<std::string, std::string, bool >> &varNameToCondition,
1313+ const VariableSetting &nameToOrdinal) {
13101314 auto opName = resultTree->getOperator ()->getAsString ();
13111315 auto Def = cast<DefInit>(resultTree->getOperator ())->getDef ();
13121316 if (opName == " DiffeRetIndex" || Def->isSubClassOf (" DiffeRetIndex" )) {
@@ -1339,7 +1343,9 @@ void handleUse(
13391343 if (numArgs == 3 ) {
13401344 if (isa<UnsetInit>(resultTree->getArg (0 )) && resultTree->getArgName (0 )) {
13411345 auto name = resultTree->getArgName (0 )->getAsUnquotedString ();
1342- conditionStr = ReplaceAll (conditionStr, " imVal" , name);
1346+ auto [ord, isVec, ext] = nameToOrdinal.lookup (name, nullptr , nullptr );
1347+ assert (!isVec);
1348+ conditionStr = ReplaceAll (conditionStr, " imVal" , ord);
13431349 } else
13441350 assert (" Requires name for arg" );
13451351 }
@@ -1362,7 +1368,7 @@ void handleUse(
13621368 auto arg = resultTree->getArg (i);
13631369 handleUseArgument (name, arg, true , false , root, resultTree,
13641370 foundPrimalUse, foundShadowUse, foundDiffRet,
1365- precondition2, tree, varNameToCondition);
1371+ precondition2, tree, varNameToCondition, nameToOrdinal );
13661372 }
13671373
13681374 return ;
@@ -1375,16 +1381,57 @@ void handleUse(
13751381 auto name = resultTree->getArgNameStr (argEn.index ());
13761382 handleUseArgument (name, argEn.value (), usesPrimal, usesShadow, root,
13771383 resultTree, foundPrimalUse, foundShadowUse, foundDiffRet,
1378- precondition, tree, varNameToCondition);
1384+ precondition, tree, varNameToCondition, nameToOrdinal );
13791385 }
13801386}
13811387
1388+ static VariableSetting parseVariables (const DagInit *tree, ActionType intrinsic,
1389+ StringRef origName) {
1390+ VariableSetting nameToOrdinal;
1391+ std::function<void (const DagInit *, ArrayRef<unsigned >)> insert =
1392+ [&](const DagInit *ptree, ArrayRef<unsigned > prev) {
1393+ unsigned i = 0 ;
1394+ for (auto tree : ptree->getArgs ()) {
1395+ SmallVector<unsigned , 2 > next (prev.begin (), prev.end ());
1396+ next.push_back (i);
1397+ if (auto dg = dyn_cast<DagInit>(tree))
1398+ insert (dg, next);
1399+
1400+ if (ptree->getArgNameStr (i).size ()) {
1401+ std::string op;
1402+ if (intrinsic != MLIRDerivatives)
1403+ op = (origName + " .getOperand(" + Twine (next[0 ]) + " )" ).str ();
1404+ else
1405+ op = (origName + " ->getOperand(" + Twine (next[0 ]) + " )" ).str ();
1406+ std::vector<int > extractions;
1407+ if (prev.size () > 0 ) {
1408+ for (unsigned i = 1 ; i < next.size (); i++) {
1409+ extractions.push_back (next[i]);
1410+ }
1411+ }
1412+ nameToOrdinal.insert (ptree->getArgNameStr (i), op, false ,
1413+ extractions);
1414+ }
1415+ i++;
1416+ }
1417+ };
1418+
1419+ insert (tree, {});
1420+
1421+ if (tree->getNameStr ().size ())
1422+ nameToOrdinal.insert (tree->getNameStr (),
1423+ (Twine (" (&" ) + origName + " )" ).str (), false , {});
1424+ return nameToOrdinal;
1425+ }
1426+
13821427void printDiffUse (
13831428 raw_ostream &os, Twine prefix, const ListInit *argOps, StringRef origName,
13841429 ActionType intrinsic, const DagInit *tree,
13851430 StringMap<std::tuple<std::string, std::string, bool >> &varNameToCondition) {
13861431 os << prefix << " // Rule " << *tree << " \n " ;
13871432
1433+ VariableSetting nameToOrdinal = parseVariables (tree, intrinsic, origName);
1434+
13881435 for (auto argOpEn : enumerate(*argOps)) {
13891436 size_t argIdx = argOpEn.index ();
13901437 if (auto resultRoot = dyn_cast<DagInit>(argOpEn.value ())) {
@@ -1417,7 +1464,8 @@ void printDiffUse(
14171464
14181465 // hasDiffeRet(resultTree)
14191466 handleUse (resultTree, resultTree, foundPrimalUse, foundShadowUse,
1420- foundDiffRet, /* precondition*/ " true" , tree, varNameToCondition);
1467+ foundDiffRet, /* precondition*/ " true" , tree, varNameToCondition,
1468+ nameToOrdinal);
14211469
14221470 os << prefix << " // Arg " << argIdx << " : " << *resultTree << " \n " ;
14231471
@@ -1587,45 +1635,6 @@ static void emitMLIRReverse(raw_ostream &os, const Record *pattern,
15871635 os << " mlir::Value dif = nullptr;\n " ;
15881636}
15891637
1590- static VariableSetting parseVariables (const DagInit *tree, ActionType intrinsic,
1591- StringRef origName) {
1592- VariableSetting nameToOrdinal;
1593- std::function<void (const DagInit *, ArrayRef<unsigned >)> insert =
1594- [&](const DagInit *ptree, ArrayRef<unsigned > prev) {
1595- unsigned i = 0 ;
1596- for (auto tree : ptree->getArgs ()) {
1597- SmallVector<unsigned , 2 > next (prev.begin (), prev.end ());
1598- next.push_back (i);
1599- if (auto dg = dyn_cast<DagInit>(tree))
1600- insert (dg, next);
1601-
1602- if (ptree->getArgNameStr (i).size ()) {
1603- std::string op;
1604- if (intrinsic != MLIRDerivatives)
1605- op = (origName + " .getOperand(" + Twine (next[0 ]) + " )" ).str ();
1606- else
1607- op = (origName + " ->getOperand(" + Twine (next[0 ]) + " )" ).str ();
1608- std::vector<int > extractions;
1609- if (prev.size () > 0 ) {
1610- for (unsigned i = 1 ; i < next.size (); i++) {
1611- extractions.push_back (next[i]);
1612- }
1613- }
1614- nameToOrdinal.insert (ptree->getArgNameStr (i), op, false ,
1615- extractions);
1616- }
1617- i++;
1618- }
1619- };
1620-
1621- insert (tree, {});
1622-
1623- if (tree->getNameStr ().size ())
1624- nameToOrdinal.insert (tree->getNameStr (),
1625- (Twine (" (&" ) + origName + " )" ).str (), false , {});
1626- return nameToOrdinal;
1627- }
1628-
16291638static void emitReverseCommon (raw_ostream &os, const Record *pattern,
16301639 const DagInit *tree, ActionType intrinsic,
16311640 StringRef origName, const ListInit *argOps) {
0 commit comments