@@ -460,7 +460,7 @@ class OpLowerer {
460460 });
461461 }
462462
463- [[nodiscard]] bool lowerCtpopToCBits (Function &F) {
463+ [[nodiscard]] bool lowerCtpopToCountBits (Function &F) {
464464 IRBuilder<> &IRB = OpBuilder.getIRB ();
465465 Type *Int32Ty = IRB.getInt32Ty ();
466466
@@ -471,13 +471,11 @@ class OpLowerer {
471471
472472 Type *RetTy = Int32Ty;
473473 Type *FRT = F.getReturnType ();
474- if (FRT->isVectorTy ()) {
475- VectorType *VT = cast<VectorType>(FRT);
474+ if (const auto *VT = dyn_cast<VectorType>(FRT))
476475 RetTy = VectorType::get (RetTy, VT);
477- }
478476
479477 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp (
480- dxil::OpCode::CBits , Args, CI->getName (), RetTy);
478+ dxil::OpCode::CountBits , Args, CI->getName (), RetTy);
481479 if (Error E = OpCall.takeError ())
482480 return E;
483481
@@ -491,31 +489,36 @@ class OpLowerer {
491489 unsigned CastOp;
492490 if (FRT->isIntOrIntVectorTy (16 ))
493491 CastOp = Instruction::ZExt;
494- else // must be 64 bits
492+ else { // must be 64 bits
493+ assert (FRT->isIntOrIntVectorTy (64 ) &&
494+ " Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
495+ is supported." );
495496 CastOp = Instruction::Trunc;
497+ }
496498
497499 // It is correct to replace the ctpop with the dxil op and
498- // remove an existing cast iff the cast is the only usage of
499- // the ctpop
500- // can use hasOneUse instead of hasOneUser, because the user
501- // we care about should have one operand
502- if (CI->hasOneUse ()) {
503- User *U = CI->user_back ();
500+ // remove every cast of the ctpop result to i32
501+ bool nonCastInstr = false ;
502+ for (User *User : make_early_inc_range (CI->users ())) {
504503 Instruction *I;
505- if (isa<Instruction>(U) && ( I = cast <Instruction>(U)) &&
504+ if (( I = dyn_cast <Instruction>(User)) != NULL &&
506505 I->getOpcode () == CastOp && I->getType () == RetTy) {
507506 I->replaceAllUsesWith (*OpCall);
508507 I->eraseFromParent ();
509- CI->eraseFromParent ();
510- return Error::success ();
511- }
508+ } else
509+ nonCastInstr = true ;
510+ }
511+
512+ // It is correct to replace a ctpop with the dxil op followed by
513+ // a cast from i32 to the return type of the ctpop.
514+ // The cast is emitted here only if the ctpop has at least one
515+ // user that is not a cast to i32.
516+ if (nonCastInstr) {
517+ Value *Cast =
518+ IRB.CreateZExtOrTrunc (*OpCall, F.getReturnType (), " ctpop.cast" );
519+ CI->replaceAllUsesWith (Cast);
512520 }
513521
514- // It is always correct to replace a ctpop with the dxil op and
515- // a cast
516- Value *Cast =
517- IRB.CreateZExtOrTrunc (*OpCall, F.getReturnType (), " ctpop.cast" );
518- CI->replaceAllUsesWith (Cast);
519522 CI->eraseFromParent ();
520523 return Error::success ();
521524 });
@@ -550,7 +553,7 @@ class OpLowerer {
550553 HasErrors |= lowerTypedBufferStore (F);
551554 break ;
552555 case Intrinsic::ctpop:
553- HasErrors |= lowerCtpopToCBits (F);
556+ HasErrors |= lowerCtpopToCountBits (F);
554557 break ;
555558 }
556559 Updated = true ;
0 commit comments