Skip to content

Commit 8e96966

Browse files
authored
MLIR variadic (#2227)
* WIP variadic * complete
1 parent 5a99e15 commit 8e96966

File tree

9 files changed

+126
-72
lines changed

9 files changed

+126
-72
lines changed

enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class AutoDiffCallFwd
7373
fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, mode,
7474
freeMemory, width,
7575
/* addedType */ nullptr, type_args, volatile_args,
76-
/* augmented */ nullptr, gutils->postpasses);
76+
/* augmented */ nullptr, gutils->omp, gutils->postpasses);
7777

7878
SmallVector<Value> fwdArguments;
7979

@@ -173,7 +173,7 @@ class AutoDiffCallRev
173173
auto revFn = gutils->Logic.CreateReverseDiff(
174174
fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, returnShadow,
175175
mode, freeMemory, width, /*addedType*/ nullptr, type_args,
176-
volatile_args, /*augmented*/ nullptr, gutils->postpasses);
176+
volatile_args, /*augmented*/ nullptr, gutils->omp, gutils->postpasses);
177177

178178
SmallVector<Value> revArguments;
179179

enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff(
8181
std::vector<DIFFE_TYPE> ArgActivity, MTypeAnalysis &TA,
8282
std::vector<bool> returnPrimals, DerivativeMode mode, bool freeMemory,
8383
size_t width, mlir::Type addedType, MFnTypeInfo type_args,
84-
std::vector<bool> volatile_args, void *augmented,
84+
std::vector<bool> volatile_args, void *augmented, bool omp,
8585
llvm::StringRef postpasses) {
8686
if (fn.getFunctionBody().empty()) {
8787
llvm::errs() << fn << "\n";
@@ -95,7 +95,8 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff(
9595
fn, RetActivity, ArgActivity,
9696
// std::map<Argument *, bool>(_uncacheable_args.begin(),
9797
// _uncacheable_args.end()),
98-
returnPrimals, mode, static_cast<unsigned>(width), addedType, type_args};
98+
returnPrimals, mode, static_cast<unsigned>(width), addedType, type_args,
99+
omp};
99100

100101
if (ForwardCachedFunctions.find(tup) != ForwardCachedFunctions.end()) {
101102
return ForwardCachedFunctions.find(tup)->second;

enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -192,19 +192,18 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
192192
llvm_unreachable("Differentiating empty function");
193193
}
194194

195-
MReverseCacheKey tup = {
196-
fn,
197-
retType,
198-
constants,
199-
returnPrimals,
200-
returnShadows,
201-
mode,
202-
freeMemory,
203-
static_cast<unsigned>(width),
204-
addedType,
205-
type_args,
206-
volatile_args,
207-
};
195+
MReverseCacheKey tup = {fn,
196+
retType,
197+
constants,
198+
returnPrimals,
199+
returnShadows,
200+
mode,
201+
freeMemory,
202+
static_cast<unsigned>(width),
203+
addedType,
204+
type_args,
205+
volatile_args,
206+
omp};
208207

209208
{
210209
auto cachedFn = ReverseCachedFunctions.find(tup);

enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ mlir::Value mlir::enzyme::MGradientUtils::getNewFromOriginal(
5858
return originalToNewFn.lookupOrNull(originst);
5959
}
6060

61+
SmallVector<mlir::Value, 1>
62+
mlir::enzyme::MGradientUtils::getNewFromOriginal(ValueRange originst) const {
63+
SmallVector<mlir::Value, 1> results;
64+
for (auto op : originst) {
65+
results.push_back(getNewFromOriginal(op));
66+
}
67+
return results;
68+
}
69+
6170
Block *
6271
mlir::enzyme::MGradientUtils::getNewFromOriginal(mlir::Block *originst) const {
6372
if (!originalToNewFn.contains(originst)) {

enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class MGradientUtils {
4444
ArrayRef<DIFFE_TYPE> ArgDiffeTypes;
4545
ArrayRef<DIFFE_TYPE> RetDiffeTypes;
4646

47+
SmallVector<mlir::Value, 1> getNewFromOriginal(ValueRange originst) const;
4748
mlir::Value getNewFromOriginal(const mlir::Value originst) const;
4849
mlir::Block *getNewFromOriginal(mlir::Block *originst) const;
4950
Operation *getNewFromOriginal(Operation *originst) const;
@@ -88,6 +89,16 @@ class MGradientUtils {
8889
auto iface = cast<AutoDiffTypeInterface>(T);
8990
return iface.getShadowType(width);
9091
}
92+
93+
static llvm::SmallVector<mlir::Value, 1>
94+
reindex_arguments(llvm::ArrayRef<mlir::Value> vals,
95+
mlir::OperandRange range) {
96+
llvm::SmallVector<mlir::Value, 1> results;
97+
for (size_t i = 0; i < range.size(); i++) {
98+
results.push_back(vals[range.getBeginOperandIndex() + i]);
99+
}
100+
return results;
101+
}
91102
};
92103

93104
class MDiffeGradientUtils : public MGradientUtils {

enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {
3636
ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
3737
IRMapping &originalToNewFn_,
3838
std::map<Operation *, Operation *> &originalToNewFnOps_,
39-
DerivativeMode mode_, unsigned width,
39+
DerivativeMode mode_, unsigned width, bool omp,
4040
llvm::StringRef postpasses);
4141

4242
IRMapping mapReverseModeBlocks;
@@ -65,14 +65,13 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {
6565

6666
void createReverseModeBlocks(Region &oldFunc, Region &newFunc);
6767

68-
static MGradientUtilsReverse *
69-
CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width,
70-
FunctionOpInterface todiff, MTypeAnalysis &TA,
71-
MFnTypeInfo &oldTypeInfo, const ArrayRef<bool> returnPrimals,
72-
const ArrayRef<bool> returnShadows,
73-
llvm::ArrayRef<DIFFE_TYPE> retType,
74-
llvm::ArrayRef<DIFFE_TYPE> constant_args,
75-
mlir::Type additionalArg, llvm::StringRef postpasses);
68+
static MGradientUtilsReverse *CreateFromClone(
69+
MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width,
70+
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
71+
const ArrayRef<bool> returnPrimals, const ArrayRef<bool> returnShadows,
72+
llvm::ArrayRef<DIFFE_TYPE> retType,
73+
llvm::ArrayRef<DIFFE_TYPE> constant_args, mlir::Type additionalArg,
74+
bool omp, llvm::StringRef postpasses);
7675
};
7776

7877
} // namespace enzyme

enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
152152
MTypeAnalysis TA;
153153
auto type_args = TA.getAnalyzedTypeInfo(fn);
154154
bool freeMemory = true;
155+
bool omp = false;
155156
size_t width = CI.getWidth();
156157

157158
std::vector<bool> volatile_args;
@@ -163,7 +164,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
163164
FunctionOpInterface newFunc = Logic.CreateForwardDiff(
164165
fn, retType, constants, TA, returnPrimals, mode, freeMemory, width,
165166
/*addedType*/ nullptr, type_args, volatile_args,
166-
/*augmented*/ nullptr, postpasses);
167+
/*augmented*/ nullptr, omp, postpasses);
167168
if (!newFunc)
168169
return failure();
169170

@@ -229,7 +230,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
229230

230231
auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr());
231232
auto fn = cast<FunctionOpInterface>(symbolOp);
232-
233+
bool omp = false;
233234
auto mode = DerivativeMode::ReverseModeCombined;
234235
std::vector<DIFFE_TYPE> retType;
235236
std::vector<bool> returnPrimals;
@@ -289,7 +290,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
289290
Logic.CreateReverseDiff(fn, retType, arg_activities, TA, returnPrimals,
290291
returnShadows, mode, freeMemory, width,
291292
/*addedType*/ nullptr, type_args, volatile_args,
292-
/*augmented*/ nullptr, postpasses);
293+
/*augmented*/ nullptr, omp, postpasses);
293294
if (!newFunc)
294295
return failure();
295296

enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ struct DifferentiateWrapperPass
7474
}
7575
}
7676
auto fn = cast<FunctionOpInterface>(symbolOp);
77+
bool omp = false;
78+
std::string postpasses = "";
7779

7880
std::vector<DIFFE_TYPE> ArgActivity =
7981
parseActivityString(argTys.getValue());
@@ -121,13 +123,13 @@ struct DifferentiateWrapperPass
121123
returnPrimal, mode, freeMemory, width,
122124
/*addedType*/ nullptr, type_args,
123125
volatile_args,
124-
/*augmented*/ nullptr, "");
126+
/*augmented*/ nullptr, omp, postpasses);
125127
} else {
126128
newFunc = Logic.CreateReverseDiff(
127129
fn, RetActivity, ArgActivity, TA, returnPrimal, returnShadow, mode,
128130
freeMemory, width,
129131
/*addedType*/ nullptr, type_args, volatile_args,
130-
/*augmented*/ nullptr, "");
132+
/*augmented*/ nullptr, omp, postpasses);
131133
}
132134
if (!newFunc) {
133135
signalPassFailure();

0 commit comments

Comments
 (0)