Skip to content

Commit 8eda577

Browse files
authored
Adapt to upstream (#2237)
* Adapt to upstream * fix * fix
1 parent df197be commit 8eda577

File tree

11 files changed

+142
-84
lines changed

11 files changed

+142
-84
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,14 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
161161
AttributeList AL;
162162
AL = AL.addParamAttribute(DT->getContext(), 0,
163163
Attribute::AttrKind::ReadOnly);
164-
AL = AL.addParamAttribute(DT->getContext(), 0,
165-
Attribute::AttrKind::NoCapture);
164+
AL = addFunctionNoCapture(DT->getContext(), AL, 0);
166165
AL =
167166
AL.addParamAttribute(DT->getContext(), 0, Attribute::AttrKind::NoAlias);
168167
AL =
169168
AL.addParamAttribute(DT->getContext(), 0, Attribute::AttrKind::NonNull);
170169
AL = AL.addParamAttribute(DT->getContext(), 1,
171170
Attribute::AttrKind::WriteOnly);
172-
AL = AL.addParamAttribute(DT->getContext(), 1,
173-
Attribute::AttrKind::NoCapture);
171+
AL = addFunctionNoCapture(DT->getContext(), AL, 1);
174172
AL =
175173
AL.addParamAttribute(DT->getContext(), 1, Attribute::AttrKind::NoAlias);
176174
AL =
@@ -208,11 +206,11 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
208206
auto alloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(rankTy);
209207
AttributeList AL;
210208
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::ReadOnly);
211-
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NoCapture);
209+
AL = addFunctionNoCapture(context, AL, 0);
212210
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NoAlias);
213211
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NonNull);
214212
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::WriteOnly);
215-
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NoCapture);
213+
AL = addFunctionNoCapture(context, AL, 1);
216214
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NoAlias);
217215
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NonNull);
218216
AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex,
@@ -241,11 +239,11 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
241239
auto alloc = IRBuilder<>(gutils->inversionAllocs).CreateAlloca(rankTy);
242240
AttributeList AL;
243241
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::ReadOnly);
244-
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NoCapture);
242+
AL = addFunctionNoCapture(context, AL, 0);
245243
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NoAlias);
246244
AL = AL.addParamAttribute(context, 0, Attribute::AttrKind::NonNull);
247245
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::WriteOnly);
248-
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NoCapture);
246+
AL = addFunctionNoCapture(context, AL, 1);
249247
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NoAlias);
250248
AL = AL.addParamAttribute(context, 1, Attribute::AttrKind::NonNull);
251249
AL = AL.addAttributeAtIndex(context, AttributeList::FunctionIndex,

enzyme/Enzyme/CacheUtility.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ std::pair<PHINode *, Instruction *> FindCanonicalIV(Loop *L, Type *Ty) {
194194
continue;
195195
if (!Inc)
196196
continue;
197-
if (Inc != Header->getFirstNonPHIOrDbg())
198-
Inc->moveBefore(Header->getFirstNonPHIOrDbg());
197+
if (Inc != getFirstNonPHIOrDbg(Header))
198+
Inc->moveBefore(getFirstNonPHIOrDbg(Header));
199199
return std::make_pair(PN, Inc);
200200
}
201201
llvm::errs() << *Header << "\n";

enzyme/Enzyme/Enzyme.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ bool attributeKnownFunctions(llvm::Function &F) {
125125
if (F.getName() == "fprintf") {
126126
for (auto &arg : F.args()) {
127127
if (arg.getType()->isPointerTy()) {
128-
arg.addAttr(Attribute::NoCapture);
128+
addFunctionNoCapture(&F, arg.getArgNo());
129129
changed = true;
130130
}
131131
}
@@ -148,7 +148,7 @@ bool attributeKnownFunctions(llvm::Function &F) {
148148
for (auto &arg : F.args()) {
149149
if (arg.getType()->isPointerTy()) {
150150
arg.addAttr(Attribute::ReadNone);
151-
arg.addAttr(Attribute::NoCapture);
151+
addFunctionNoCapture(&F, arg.getArgNo());
152152
}
153153
}
154154
}
@@ -168,7 +168,7 @@ bool attributeKnownFunctions(llvm::Function &F) {
168168
F.addFnAttr(Attribute::NoSync);
169169
for (int i = 0; i < 2; i++)
170170
if (F.getFunctionType()->getParamType(i)->isPointerTy()) {
171-
F.addParamAttr(i, Attribute::NoCapture);
171+
addFunctionNoCapture(&F, i);
172172
F.addParamAttr(i, Attribute::WriteOnly);
173173
}
174174
}
@@ -192,7 +192,7 @@ bool attributeKnownFunctions(llvm::Function &F) {
192192
F.addFnAttr(Attribute::NoSync);
193193
F.addParamAttr(0, Attribute::WriteOnly);
194194
if (F.getFunctionType()->getParamType(2)->isPointerTy()) {
195-
F.addParamAttr(2, Attribute::NoCapture);
195+
addFunctionNoCapture(&F, 2);
196196
F.addParamAttr(2, Attribute::WriteOnly);
197197
}
198198
F.addParamAttr(6, Attribute::WriteOnly);
@@ -211,7 +211,7 @@ bool attributeKnownFunctions(llvm::Function &F) {
211211
F.addFnAttr(Attribute::NoSync);
212212
F.addParamAttr(0, Attribute::ReadOnly);
213213
if (F.getFunctionType()->getParamType(2)->isPointerTy()) {
214-
F.addParamAttr(2, Attribute::NoCapture);
214+
addFunctionNoCapture(&F, 2);
215215
F.addParamAttr(2, Attribute::ReadOnly);
216216
}
217217
F.addParamAttr(6, Attribute::WriteOnly);
@@ -231,12 +231,12 @@ bool attributeKnownFunctions(llvm::Function &F) {
231231
F.addFnAttr(Attribute::NoSync);
232232

233233
if (F.getFunctionType()->getParamType(0)->isPointerTy()) {
234-
F.addParamAttr(0, Attribute::NoCapture);
234+
addFunctionNoCapture(&F, 0);
235235
F.addParamAttr(0, Attribute::ReadOnly);
236236
}
237237
if (F.getFunctionType()->getParamType(1)->isPointerTy()) {
238238
F.addParamAttr(1, Attribute::WriteOnly);
239-
F.addParamAttr(1, Attribute::NoCapture);
239+
addFunctionNoCapture(&F, 1);
240240
}
241241
}
242242
if (F.getName() == "MPI_Wait" || F.getName() == "PMPI_Wait") {
@@ -246,9 +246,9 @@ bool attributeKnownFunctions(llvm::Function &F) {
246246
F.addFnAttr(Attribute::WillReturn);
247247
F.addFnAttr(Attribute::NoFree);
248248
F.addFnAttr(Attribute::NoSync);
249-
F.addParamAttr(0, Attribute::NoCapture);
249+
addFunctionNoCapture(&F, 0);
250250
F.addParamAttr(1, Attribute::WriteOnly);
251-
F.addParamAttr(1, Attribute::NoCapture);
251+
addFunctionNoCapture(&F, 1);
252252
}
253253
if (F.getName() == "MPI_Waitall" || F.getName() == "PMPI_Waitall") {
254254
changed = true;
@@ -257,9 +257,9 @@ bool attributeKnownFunctions(llvm::Function &F) {
257257
F.addFnAttr(Attribute::WillReturn);
258258
F.addFnAttr(Attribute::NoFree);
259259
F.addFnAttr(Attribute::NoSync);
260-
F.addParamAttr(1, Attribute::NoCapture);
260+
addFunctionNoCapture(&F, 1);
261261
F.addParamAttr(2, Attribute::WriteOnly);
262-
F.addParamAttr(2, Attribute::NoCapture);
262+
addFunctionNoCapture(&F, 2);
263263
}
264264
// Map of MPI function name to the arg index of its type argument
265265
std::map<std::string, int> MPI_TYPE_ARGS = {
@@ -2347,7 +2347,7 @@ class EnzymeBase {
23472347
for (size_t i = 0; i < num_args; ++i) {
23482348
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
23492349
CI->addParamAttr(i, Attribute::ReadNone);
2350-
CI->addParamAttr(i, Attribute::NoCapture);
2350+
addCallSiteNoCapture(CI, i);
23512351
}
23522352
}
23532353
}
@@ -2361,7 +2361,7 @@ class EnzymeBase {
23612361
for (size_t i = 0; i < num_args; ++i) {
23622362
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
23632363
CI->addParamAttr(i, Attribute::ReadNone);
2364-
CI->addParamAttr(i, Attribute::NoCapture);
2364+
addCallSiteNoCapture(CI, i);
23652365
}
23662366
}
23672367
}
@@ -2375,7 +2375,7 @@ class EnzymeBase {
23752375
for (size_t i = 0; i < num_args; ++i) {
23762376
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
23772377
CI->addParamAttr(i, Attribute::ReadNone);
2378-
CI->addParamAttr(i, Attribute::NoCapture);
2378+
addCallSiteNoCapture(CI, i);
23792379
}
23802380
}
23812381
}
@@ -2389,7 +2389,7 @@ class EnzymeBase {
23892389
for (size_t i = 0; i < num_args; ++i) {
23902390
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
23912391
CI->addParamAttr(i, Attribute::ReadNone);
2392-
CI->addParamAttr(i, Attribute::NoCapture);
2392+
addCallSiteNoCapture(CI, i);
23932393
}
23942394
}
23952395
}
@@ -2439,9 +2439,9 @@ class EnzymeBase {
24392439
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly);
24402440
#endif
24412441
CI->addParamAttr(1, Attribute::ReadOnly);
2442-
CI->addParamAttr(1, Attribute::NoCapture);
2442+
addCallSiteNoCapture(CI, 1);
24432443
CI->addParamAttr(3, Attribute::ReadOnly);
2444-
CI->addParamAttr(3, Attribute::NoCapture);
2444+
addCallSiteNoCapture(CI, 3);
24452445
}
24462446
if (Fn->getName() == "frexp" || Fn->getName() == "frexpf" ||
24472447
Fn->getName() == "frexpl") {
@@ -2502,7 +2502,7 @@ class EnzymeBase {
25022502
for (size_t i : {0, 1}) {
25032503
if (i < num_args &&
25042504
CI->getArgOperand(i)->getType()->isPointerTy()) {
2505-
CI->addParamAttr(i, Attribute::NoCapture);
2505+
addCallSiteNoCapture(CI, i);
25062506
}
25072507
}
25082508
}
@@ -2527,7 +2527,7 @@ class EnzymeBase {
25272527
for (size_t i : {0, 2}) {
25282528
if (i < num_args &&
25292529
CI->getArgOperand(i)->getType()->isPointerTy()) {
2530-
CI->addParamAttr(i, Attribute::NoCapture);
2530+
addCallSiteNoCapture(CI, i);
25312531
}
25322532
}
25332533
}
@@ -2553,7 +2553,7 @@ class EnzymeBase {
25532553
for (size_t i : {0, 1, 2, 3}) {
25542554
if (i < num_args &&
25552555
CI->getArgOperand(i)->getType()->isPointerTy()) {
2556-
CI->addParamAttr(i, Attribute::NoCapture);
2556+
addCallSiteNoCapture(CI, i);
25572557
}
25582558
}
25592559
}
@@ -2579,7 +2579,7 @@ class EnzymeBase {
25792579
for (size_t i : {0}) {
25802580
if (i < num_args &&
25812581
CI->getArgOperand(i)->getType()->isPointerTy()) {
2582-
CI->addParamAttr(i, Attribute::NoCapture);
2582+
addCallSiteNoCapture(CI, i);
25832583
}
25842584
}
25852585
}
@@ -2601,7 +2601,7 @@ class EnzymeBase {
26012601
for (size_t i = 0; i < num_args; ++i) {
26022602
if (CI->getArgOperand(i)->getType()->isPointerTy()) {
26032603
CI->addParamAttr(i, Attribute::ReadOnly);
2604-
CI->addParamAttr(i, Attribute::NoCapture);
2604+
addCallSiteNoCapture(CI, i);
26052605
}
26062606
}
26072607
}

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2713,9 +2713,17 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
27132713
auto i = nf->arg_begin(), j = NewF->arg_begin();
27142714
while (i != nf->arg_end()) {
27152715
VMap[i] = j;
2716+
#if LLVM_VERSION_MAJOR > 20
2717+
if (nf->hasParamAttribute(attrIndex, Attribute::Captures)) {
2718+
NewF->addParamAttr(attrIndex,
2719+
nf->getParamAttribute(attrIndex, Attribute::Captures));
2720+
}
2721+
#else
27162722
if (nf->hasParamAttribute(attrIndex, Attribute::NoCapture)) {
2717-
NewF->addParamAttr(attrIndex, Attribute::NoCapture);
2723+
NewF->addParamAttr(
2724+
attrIndex, nf->getParamAttribute(attrIndex, Attribute::NoCapture));
27182725
}
2726+
#endif
27192727
if (nf->hasParamAttribute(attrIndex, Attribute::NoAlias)) {
27202728
NewF->addParamAttr(attrIndex, Attribute::NoAlias);
27212729
}

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ OldAllocationSize(Value *Ptr, CallInst *Loc, Function *NewF, IntegerType *T,
698698
AttributeList list;
699699
list = list.addFnAttribute(NewF->getContext(), Attribute::ReadOnly);
700700
list = list.addParamAttribute(NewF->getContext(), 0, Attribute::ReadNone);
701-
list = list.addParamAttribute(NewF->getContext(), 0, Attribute::NoCapture);
701+
list = addFunctionNoCapture(NewF->getContext(), list, 0);
702702
auto allocSize = NewF->getParent()->getOrInsertFunction(
703703
allocName,
704704
FunctionType::get(
@@ -1109,7 +1109,7 @@ static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM) {
11091109
B.SetInsertPoint(Bound->getNextNode());
11101110
}
11111111
B.CreateStore(B.CreateLoad(AI2->getAllocatedType(), AI2), AI);
1112-
Bound->addParamAttr(i, Attribute::NoCapture);
1112+
addCallSiteNoCapture(Bound, i);
11131113
}
11141114
}
11151115
PreservedAnalyses PA;

enzyme/Enzyme/TraceGenerator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ void TraceGenerator::visitFunction(Function &F) {
6363
return;
6464

6565
auto fn = tutils->newFunc;
66-
auto entry = fn->getEntryBlock().getFirstNonPHIOrDbgOrLifetime();
66+
auto entry = getFirstNonPHIOrDbgOrLifetime(&fn->getEntryBlock());
6767

6868
while (isa<AllocaInst>(entry) && entry->getNextNode()) {
6969
entry = entry->getNextNode();

enzyme/Enzyme/TraceInterface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ DynamicTraceInterface::DynamicTraceInterface(Value *dynamicInterface,
311311
assert(dynamicInterface);
312312

313313
auto &M = *F->getParent();
314-
IRBuilder<> Builder(F->getEntryBlock().getFirstNonPHIOrDbg());
314+
IRBuilder<> Builder(getFirstNonPHIOrDbg(&F->getEntryBlock()));
315315

316316
getTraceFunction = MaterializeInterfaceFunction(
317317
Builder, dynamicInterface, getTraceTy(), 0, M, "get_trace");

enzyme/Enzyme/TraceUtils.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,8 @@ TraceUtils::ValueToVoidPtrAndSize(IRBuilder<> &Builder, Value *val,
196196
Builder.CreateIntToPtr(cast, getInt8PtrTy(cast->getContext()));
197197
return {retval, ConstantInt::get(size_type, valsize / 8)};
198198
} else {
199-
auto insertPoint = Builder.GetInsertBlock()
200-
->getParent()
201-
->getEntryBlock()
202-
.getFirstNonPHIOrDbgOrLifetime();
203-
IRBuilder<> AllocaBuilder(insertPoint);
199+
IRBuilder<> AllocaBuilder(getFirstNonPHIOrDbgOrLifetime(
200+
&Builder.GetInsertBlock()->getParent()->getEntryBlock()));
204201
auto alloca = AllocaBuilder.CreateAlloca(val->getType(), nullptr,
205202
val->getName() + ".ptr");
206203
Builder.CreateStore(val, alloca);
@@ -248,7 +245,8 @@ CallInst *TraceUtils::InsertChoice(IRBuilder<> &Builder, Value *address,
248245
auto call = Builder.CreateCall(interface->insertChoiceTy(),
249246
interface->insertChoice(Builder), args);
250247
call->addParamAttr(1, Attribute::ReadOnly);
251-
call->addParamAttr(1, Attribute::NoCapture);
248+
249+
addCallSiteNoCapture(call, 1);
252250
return call;
253251
}
254252

@@ -259,7 +257,7 @@ CallInst *TraceUtils::InsertCall(IRBuilder<> &Builder, Value *address,
259257
auto call = Builder.CreateCall(interface->insertCallTy(),
260258
interface->insertCall(Builder), args);
261259
call->addParamAttr(1, Attribute::ReadOnly);
262-
call->addParamAttr(1, Attribute::NoCapture);
260+
addCallSiteNoCapture(call, 1);
263261
#if LLVM_VERSION_MAJOR >= 14
264262
call->addAttributeAtIndex(
265263
AttributeList::FunctionIndex,
@@ -283,7 +281,7 @@ CallInst *TraceUtils::InsertArgument(IRBuilder<> &Builder, Value *name,
283281
auto call = Builder.CreateCall(interface->insertArgumentTy(),
284282
interface->insertArgument(Builder), args);
285283
call->addParamAttr(1, Attribute::ReadOnly);
286-
call->addParamAttr(1, Attribute::NoCapture);
284+
addCallSiteNoCapture(call, 1);
287285
return call;
288286
}
289287

@@ -322,7 +320,7 @@ CallInst *TraceUtils::InsertChoiceGradient(IRBuilder<> &Builder,
322320

323321
auto call = Builder.CreateCall(interface_type, interface_function, args);
324322
call->addParamAttr(1, Attribute::ReadOnly);
325-
call->addParamAttr(1, Attribute::NoCapture);
323+
addCallSiteNoCapture(call, 1);
326324
return call;
327325
}
328326

@@ -339,7 +337,7 @@ CallInst *TraceUtils::InsertArgumentGradient(IRBuilder<> &Builder,
339337

340338
auto call = Builder.CreateCall(interface_type, interface_function, args);
341339
call->addParamAttr(1, Attribute::ReadOnly);
342-
call->addParamAttr(1, Attribute::NoCapture);
340+
addCallSiteNoCapture(call, 1);
343341
return call;
344342
}
345343

@@ -352,16 +350,14 @@ CallInst *TraceUtils::GetTrace(IRBuilder<> &Builder, Value *address,
352350
auto call = Builder.CreateCall(interface->getTraceTy(),
353351
interface->getTrace(Builder), args, Name);
354352
call->addParamAttr(1, Attribute::ReadOnly);
355-
call->addParamAttr(1, Attribute::NoCapture);
353+
addCallSiteNoCapture(call, 1);
356354
return call;
357355
}
358356

359357
Instruction *TraceUtils::GetChoice(IRBuilder<> &Builder, Value *address,
360358
Type *choiceType, const Twine &Name) {
361-
IRBuilder<> AllocaBuilder(Builder.GetInsertBlock()
362-
->getParent()
363-
->getEntryBlock()
364-
.getFirstNonPHIOrDbgOrLifetime());
359+
IRBuilder<> AllocaBuilder(getFirstNonPHIOrDbgOrLifetime(
360+
&Builder.GetInsertBlock()->getParent()->getEntryBlock()));
365361
AllocaInst *store_dest =
366362
AllocaBuilder.CreateAlloca(choiceType, nullptr, Name + ".ptr");
367363
auto preallocated_size = choiceType->getPrimitiveSizeInBits() / 8;
@@ -385,7 +381,7 @@ Instruction *TraceUtils::GetChoice(IRBuilder<> &Builder, Value *address,
385381
Attribute::get(call->getContext(), "enzyme_inactive"));
386382
#endif
387383
call->addParamAttr(1, Attribute::ReadOnly);
388-
call->addParamAttr(1, Attribute::NoCapture);
384+
addCallSiteNoCapture(call, 1);
389385
return Builder.CreateLoad(choiceType, store_dest, "from.trace." + Name);
390386
}
391387

@@ -396,7 +392,7 @@ Instruction *TraceUtils::HasChoice(IRBuilder<> &Builder, Value *address,
396392
auto call = Builder.CreateCall(interface->hasChoiceTy(),
397393
interface->hasChoice(Builder), args, Name);
398394
call->addParamAttr(1, Attribute::ReadOnly);
399-
call->addParamAttr(1, Attribute::NoCapture);
395+
addCallSiteNoCapture(call, 1);
400396
return call;
401397
}
402398

@@ -407,7 +403,7 @@ Instruction *TraceUtils::HasCall(IRBuilder<> &Builder, Value *address,
407403
auto call = Builder.CreateCall(interface->hasCallTy(),
408404
interface->hasCall(Builder), args, Name);
409405
call->addParamAttr(1, Attribute::ReadOnly);
410-
call->addParamAttr(1, Attribute::NoCapture);
406+
addCallSiteNoCapture(call, 1);
411407
return call;
412408
}
413409

0 commit comments

Comments
 (0)