|
39 | 39 | #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" |
40 | 40 | #include "llvm/Transforms/Utils/LowerVectorIntrinsics.h" |
41 | 41 |
|
42 | | -#include <set> |
43 | | - |
44 | 42 | using namespace llvm; |
45 | 43 |
|
46 | 44 | /// Threshold to leave statically sized memory intrinsic calls. Calls of known |
@@ -476,8 +474,8 @@ enum class PointerEncoding { |
476 | 474 | bool expandProtectedFieldPtr(Function &Intr) { |
477 | 475 | Module &M = *Intr.getParent(); |
478 | 476 |
|
479 | | - std::set<GlobalValue *> DSsToDeactivate; |
480 | | - std::set<Instruction *> LoadsStores; |
| 477 | + SmallPtrSet<GlobalValue *, 2> DSsToDeactivate; |
| 478 | + SmallPtrSet<Instruction *, 2> LoadsStores; |
481 | 479 |
|
482 | 480 | Type *Int8Ty = Type::getInt8Ty(M.getContext()); |
483 | 481 | Type *Int64Ty = Type::getInt64Ty(M.getContext()); |
@@ -520,111 +518,75 @@ bool expandProtectedFieldPtr(Function &Intr) { |
520 | 518 | for (User *U : Intr.users()) { |
521 | 519 | auto *Call = cast<CallInst>(U); |
522 | 520 | auto *DS = GetDeactivationSymbol(Call); |
523 | | - std::set<PHINode *> VisitedPhis; |
524 | | - |
525 | | - std::function<void(Instruction *)> FindLoadsStores; |
526 | | - FindLoadsStores = [&](Instruction *I) { |
527 | | - for (Use &U : I->uses()) { |
528 | | - if (auto *LI = dyn_cast<LoadInst>(U.getUser())) { |
529 | | - if (isa<PointerType>(LI->getType())) { |
530 | | - LoadsStores.insert(LI); |
531 | | - continue; |
532 | | - } |
533 | | - } |
534 | | - if (auto *SI = dyn_cast<StoreInst>(U.getUser())) { |
535 | | - if (U.getOperandNo() == 1 && |
536 | | - isa<PointerType>(SI->getValueOperand()->getType())) { |
537 | | - LoadsStores.insert(SI); |
538 | | - continue; |
539 | | - } |
540 | | - } |
541 | | - if (auto *P = dyn_cast<PHINode>(U.getUser())) { |
542 | | - if (VisitedPhis.insert(P).second) |
543 | | - FindLoadsStores(P); |
| 521 | + |
| 522 | + for (Use &U : Call->uses()) { |
| 523 | + if (auto *LI = dyn_cast<LoadInst>(U.getUser())) { |
| 524 | + if (isa<PointerType>(LI->getType())) { |
| 525 | + LoadsStores.insert(LI); |
544 | 526 | continue; |
545 | 527 | } |
546 | | - // Comparisons against null cannot be used to recover the original |
547 | | - // pointer so we allow them. |
548 | | - if (auto *CI = dyn_cast<ICmpInst>(U.getUser())) { |
549 | | - if (auto *Op = dyn_cast<Constant>(CI->getOperand(0))) |
550 | | - if (Op->isNullValue()) |
551 | | - continue; |
552 | | - if (auto *Op = dyn_cast<Constant>(CI->getOperand(1))) |
553 | | - if (Op->isNullValue()) |
554 | | - continue; |
555 | | - } |
556 | | - if (DS) |
557 | | - DSsToDeactivate.insert(DS); |
558 | 528 | } |
559 | | - }; |
560 | | - |
561 | | - FindLoadsStores(Call); |
562 | | - } |
563 | | - |
564 | | - for (Instruction *I : LoadsStores) { |
565 | | - std::set<Value *> Pointers; |
566 | | - std::set<Value *> Discs; |
567 | | - std::set<GlobalValue *> DSs; |
568 | | - std::set<PHINode *> VisitedPhis; |
569 | | - bool UseHWEncoding = false; |
570 | | - |
571 | | - std::function<void(Value *)> FindFields; |
572 | | - FindFields = [&](Value *V) { |
573 | | - if (auto *Call = dyn_cast<CallInst>(V)) { |
574 | | - if (Call->getCalledOperand() == &Intr) { |
575 | | - Pointers.insert(Call->getArgOperand(0)); |
576 | | - Discs.insert(Call->getArgOperand(1)); |
577 | | - if (cast<ConstantInt>(Call->getArgOperand(2))->getZExtValue()) |
578 | | - UseHWEncoding = true; |
579 | | - DSs.insert(GetDeactivationSymbol(Call)); |
580 | | - return; |
| 529 | + if (auto *SI = dyn_cast<StoreInst>(U.getUser())) { |
| 530 | + if (U.getOperandNo() == 1 && |
| 531 | + isa<PointerType>(SI->getValueOperand()->getType())) { |
| 532 | + LoadsStores.insert(SI); |
| 533 | + continue; |
581 | 534 | } |
582 | 535 | } |
583 | | - if (auto *P = dyn_cast<PHINode>(V)) { |
584 | | - if (VisitedPhis.insert(P).second) |
585 | | - for (Value *V : P->incoming_values()) |
586 | | - FindFields(V); |
587 | | - return; |
| 536 | + // Comparisons against null cannot be used to recover the original |
| 537 | + // pointer so we allow them. |
| 538 | + if (auto *CI = dyn_cast<ICmpInst>(U.getUser())) { |
| 539 | + if (auto *Op = dyn_cast<Constant>(CI->getOperand(0))) |
| 540 | + if (Op->isNullValue()) |
| 541 | + continue; |
| 542 | + if (auto *Op = dyn_cast<Constant>(CI->getOperand(1))) |
| 543 | + if (Op->isNullValue()) |
| 544 | + continue; |
588 | 545 | } |
589 | | - Pointers.insert(nullptr); |
590 | | - }; |
591 | | - FindFields(isa<StoreInst>(I) ? cast<StoreInst>(I)->getPointerOperand() |
592 | | - : cast<LoadInst>(I)->getPointerOperand()); |
593 | | - if (Pointers.size() != 1 || Discs.size() != 1 || DSs.size() != 1) { |
594 | | - for (GlobalValue *DS : DSs) |
595 | | - if (DS) |
596 | | - DSsToDeactivate.insert(DS); |
597 | | - continue; |
| 546 | + if (DS) |
| 547 | + DSsToDeactivate.insert(DS); |
598 | 548 | } |
| 549 | + } |
| 550 | + |
| 551 | + for (Instruction *I : LoadsStores) { |
| 552 | + auto *PointerOperand = isa<StoreInst>(I) |
| 553 | + ? cast<StoreInst>(I)->getPointerOperand() |
| 554 | + : cast<LoadInst>(I)->getPointerOperand(); |
| 555 | + auto *Call = cast<CallInst>(PointerOperand); |
| 556 | + |
| 557 | + auto *Disc = Call->getArgOperand(1); |
| 558 | + bool UseHWEncoding = cast<ConstantInt>(Call->getArgOperand(2))->getZExtValue(); |
599 | 559 |
|
600 | | - GlobalValue *DS = *DSs.begin(); |
| 560 | + GlobalValue *DS = GetDeactivationSymbol(Call); |
601 | 561 | OperandBundleDef DSBundle("deactivation-symbol", DS); |
602 | 562 |
|
603 | 563 | if (auto *LI = dyn_cast<LoadInst>(I)) { |
604 | 564 | IRBuilder<> B(LI->getNextNode()); |
605 | 565 | auto *LIInt = cast<Instruction>(B.CreatePtrToInt(LI, B.getInt64Ty())); |
606 | 566 | Value *Auth; |
607 | 567 | if (UseHWEncoding) { |
608 | | - Auth = CreateAuth(B, LIInt, *Discs.begin(), DSBundle); |
| 568 | + Auth = CreateAuth(B, LIInt, Disc, DSBundle); |
609 | 569 | } else { |
610 | | - Auth = B.CreateAdd(LIInt, *Discs.begin()); |
| 570 | + Auth = B.CreateAdd(LIInt, Disc); |
611 | 571 | Auth = B.CreateIntrinsic( |
612 | 572 | Auth->getType(), Intrinsic::fshr, |
613 | 573 | {Auth, Auth, ConstantInt::get(Auth->getType(), 16)}); |
614 | 574 | } |
615 | 575 | LI->replaceAllUsesWith(B.CreateIntToPtr(Auth, B.getPtrTy())); |
616 | 576 | LIInt->setOperand(0, LI); |
617 | | - } else if (auto *SI = dyn_cast<StoreInst>(I)) { |
| 577 | + } else { |
| 578 | + auto *SI = cast<StoreInst>(I); |
618 | 579 | IRBuilder<> B(SI); |
619 | 580 | auto *SIValInt = |
620 | 581 | B.CreatePtrToInt(SI->getValueOperand(), B.getInt64Ty()); |
621 | 582 | Value *Sign; |
622 | 583 | if (UseHWEncoding) { |
623 | | - Sign = CreateSign(B, SIValInt, *Discs.begin(), DSBundle); |
| 584 | + Sign = CreateSign(B, SIValInt, Disc, DSBundle); |
624 | 585 | } else { |
625 | 586 | Sign = B.CreateIntrinsic( |
626 | 587 | SIValInt->getType(), Intrinsic::fshl, |
627 | 588 | {SIValInt, SIValInt, ConstantInt::get(SIValInt->getType(), 16)}); |
| 589 | + Sign = B.CreateSub(Sign, Disc); |
628 | 590 | } |
629 | 591 | SI->setOperand(0, B.CreateIntToPtr(Sign, B.getPtrTy())); |
630 | 592 | } |
|
0 commit comments