Skip to content

Commit 5a99e15

Browse files
authored
WIP variadic (#2226)
1 parent 5c632cc commit 5a99e15

File tree

5 files changed

+63
-10
lines changed

5 files changed

+63
-10
lines changed

enzyme/Enzyme/MLIR/Implementations/Common.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def AssertingInactiveArg : InactiveArgSpec {
2727
bit asserting = 1;
2828
}
2929

30+
class Variadic<string getter_> {
31+
string getter = getter_;
32+
}
3033

3134
def Unimplemented {
3235

enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class MEnzymeLogic {
5555
unsigned width;
5656
mlir::Type additionalType;
5757
const MFnTypeInfo typeInfo;
58+
bool omp;
5859

5960
inline bool operator<(const MForwardCacheKey &rhs) const {
6061
if (todiff < rhs.todiff)
@@ -100,6 +101,12 @@ class MEnzymeLogic {
100101
return true;
101102
if (rhs.typeInfo < typeInfo)
102103
return false;
104+
105+
if (omp < rhs.omp)
106+
return true;
107+
if (rhs.omp < omp)
108+
return false;
109+
103110
// equal
104111
return false;
105112
}
@@ -117,6 +124,7 @@ class MEnzymeLogic {
117124
mlir::Type additionalType;
118125
const MFnTypeInfo typeInfo;
119126
const std::vector<bool> volatileArgs;
127+
bool omp;
120128

121129
inline bool operator<(const MReverseCacheKey &rhs) const {
122130
if (todiff < rhs.todiff)
@@ -182,6 +190,11 @@ class MEnzymeLogic {
182190
if (rhs.volatileArgs < volatileArgs)
183191
return false;
184192

193+
if (omp < rhs.omp)
194+
return true;
195+
if (rhs.omp < omp)
196+
return false;
197+
185198
// equal
186199
return false;
187200
}
@@ -196,7 +209,7 @@ class MEnzymeLogic {
196209
std::vector<bool> returnPrimals, DerivativeMode mode,
197210
bool freeMemory, size_t width, mlir::Type addedType,
198211
MFnTypeInfo type_args, std::vector<bool> volatile_args,
199-
void *augmented, llvm::StringRef postpasses);
212+
void *augmented, bool omp, llvm::StringRef postpasses);
200213

201214
FunctionOpInterface
202215
CreateReverseDiff(FunctionOpInterface fn, std::vector<DIFFE_TYPE> retType,
@@ -205,7 +218,7 @@ class MEnzymeLogic {
205218
std::vector<bool> returnShadows, DerivativeMode mode,
206219
bool freeMemory, size_t width, mlir::Type addedType,
207220
MFnTypeInfo type_args, std::vector<bool> volatile_args,
208-
void *augmented, llvm::StringRef postpasses);
221+
void *augmented, bool omp, llvm::StringRef postpasses);
209222

210223
void
211224
initializeShadowValues(SmallVector<mlir::Block *> &dominatorToposortBlocks,

enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
185185
std::vector<bool> returnPrimals, std::vector<bool> returnShadows,
186186
DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType,
187187
MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented,
188-
llvm::StringRef postpasses) {
188+
bool omp, llvm::StringRef postpasses) {
189189

190190
if (fn.getFunctionBody().empty()) {
191191
llvm::errs() << fn << "\n";
@@ -217,7 +217,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
217217

218218
MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone(
219219
*this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
220-
retType, constants, addedType, postpasses);
220+
retType, constants, addedType, omp, postpasses);
221221

222222
ReverseCachedFunctions[tup] = gutils->newFunc;
223223

enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse(
3737
ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
3838
IRMapping &originalToNewFn_,
3939
std::map<Operation *, Operation *> &originalToNewFnOps_,
40-
DerivativeMode mode_, unsigned width, StringRef postpasses)
40+
DerivativeMode mode_, unsigned width, bool omp, StringRef postpasses)
4141
: MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {},
4242
invertedPointers_, returnPrimals, returnShadows,
4343
constantvalues_, activevals_, ReturnActivity,
4444
ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_,
45-
mode_, width, /*omp*/ false, postpasses) {}
45+
mode_, width, omp, postpasses) {}
4646

4747
Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() {
4848
Type indexType = getIndexType();
@@ -138,7 +138,7 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
138138
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
139139
const ArrayRef<bool> returnPrimals, const ArrayRef<bool> returnShadows,
140140
ArrayRef<DIFFE_TYPE> retType, ArrayRef<DIFFE_TYPE> constant_args,
141-
mlir::Type additionalArg, llvm::StringRef postpasses) {
141+
mlir::Type additionalArg, bool omp, llvm::StringRef postpasses) {
142142
std::string prefix;
143143

144144
switch (mode_) {
@@ -174,5 +174,6 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
174174
return new MGradientUtilsReverse(
175175
Logic, newFunc, todiff, TA, invertedPointers, returnPrimals,
176176
returnShadows, constant_values, nonconstant_values, retType,
177-
constant_args, originalToNew, originalToNewOps, mode_, width, postpasses);
177+
constant_args, originalToNew, originalToNewOps, mode_, width, omp,
178+
postpasses);
178179
}

enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,8 +1394,30 @@ static VariableSetting parseVariables(const DagInit *tree, ActionType intrinsic,
13941394
for (auto tree : ptree->getArgs()) {
13951395
SmallVector<unsigned, 2> next(prev.begin(), prev.end());
13961396
next.push_back(i);
1397-
if (auto dg = dyn_cast<DagInit>(tree))
1397+
if (auto dg = dyn_cast<DagInit>(tree)) {
1398+
if (ptree->getArgNameStr(i).size()) {
1399+
auto opName = dg->getOperator()->getAsString();
1400+
auto Def = cast<DefInit>(dg->getOperator())->getDef();
1401+
if (opName == "Variadic" || Def->isSubClassOf("Variadic")) {
1402+
auto expr = Def->getValueAsString("getter");
1403+
std::string op;
1404+
if (intrinsic != MLIRDerivatives)
1405+
op = (origName + "." + expr + "()").str();
1406+
else
1407+
op = (origName + "->" + expr + "()").str();
1408+
std::vector<int> extractions;
1409+
if (prev.size() > 0) {
1410+
for (unsigned i = 1; i < next.size(); i++) {
1411+
extractions.push_back(next[i]);
1412+
}
1413+
}
1414+
nameToOrdinal.insert(ptree->getArgNameStr(i), op, false,
1415+
extractions);
1416+
continue;
1417+
}
1418+
}
13981419
insert(dg, next);
1420+
}
13991421

14001422
if (ptree->getArgNameStr(i).size()) {
14011423
std::string op;
@@ -1580,8 +1602,22 @@ static void emitMLIRReverse(raw_ostream &os, const Record *pattern,
15801602
auto name = ptree->getArgNameStr(treeEn.index());
15811603
SmallVector<unsigned, 2> next(prev.begin(), prev.end());
15821604
next.push_back(treeEn.index());
1583-
if (auto dg = dyn_cast<DagInit>(tree))
1605+
if (auto dg = dyn_cast<DagInit>(tree)) {
1606+
if (name.size()) {
1607+
auto opName = dg->getOperator()->getAsString();
1608+
auto Def = cast<DefInit>(dg->getOperator())->getDef();
1609+
if (opName == "Variadic" || Def->isSubClassOf("Variadic")) {
1610+
auto expr = Def->getValueAsString("getter");
1611+
varNameToCondition[name] = std::make_tuple(
1612+
("llvm::is_contained(op->getOperand(idx), op." + expr +
1613+
"())")
1614+
.str(),
1615+
"", false);
1616+
continue;
1617+
}
1618+
}
15841619
insert(dg, next);
1620+
}
15851621

15861622
if (name.size()) {
15871623
varNameToCondition[name] = std::make_tuple(

0 commit comments

Comments
 (0)