@@ -181,10 +181,17 @@ static void upcastI8AllocasAndUses(Instruction &I,
181181    if  (!Load)
182182      continue ;
183183    for  (User *LU : Load->users ()) {
184-       auto  *Cast = dyn_cast<CastInst>(LU);
185-       if  (!Cast)
184+       Type *Ty = nullptr ;
185+       if  (auto  *Cast = dyn_cast<CastInst>(LU))
186+         Ty = Cast->getType ();
187+       if  (CallInst *CI = dyn_cast<CallInst>(LU)) {
188+         if  (CI->getIntrinsicID () == Intrinsic::memset)
189+           Ty = Type::getInt32Ty (CI->getContext ());
190+       }
191+ 
192+       if  (!Ty)
186193        continue ;
187-       Type *Ty = Cast-> getType (); 
194+ 
188195      if  (!SmallestType ||
189196          Ty->getPrimitiveSizeInBits () < SmallestType->getPrimitiveSizeInBits ())
190197        SmallestType = Ty;
@@ -240,8 +247,9 @@ downcastI64toI32InsertExtractElements(Instruction &I,
240247  }
241248}
242249
243- void  emitMemset (IRBuilder<> &Builder, Value *Dst, Value *Val,
244-                 ConstantInt *SizeCI) {
250+ void  emitMemsetExpansion (IRBuilder<> &Builder, Value *Dst, Value *Val,
251+                          ConstantInt *SizeCI,
252+                          DenseMap<Value *, Value *> &ReplacedValues) {
245253  LLVMContext &Ctx = Builder.getContext ();
246254  [[maybe_unused]] DataLayout DL =
247255      Builder.GetInsertBlock ()->getModule ()->getDataLayout ();
@@ -266,9 +274,19 @@ void emitMemset(IRBuilder<> &Builder, Value *Dst, Value *Val,
266274  assert (OrigSize == ElemSize * Size && " Size in bytes must match"  );
267275
268276  Value *TypedVal = Val;
269-   if  (Val->getType () != ElemTy)
270-     TypedVal = Builder.CreateIntCast (Val, ElemTy,
271-                                      false ); //  Or use CreateBitCast for float
277+ 
278+   if  (Val->getType () != ElemTy) {
279+     //  Note for i8 replacements if we know them we should use them.
280+     //  Further if this is a constant ReplacedValues will return null
281+     //  so we will stick to TypedVal = Val
282+     if  (ReplacedValues[Val])
283+       TypedVal = ReplacedValues[Val];
284+     //  This case Val is a ConstantInt so the cast folds away.
285+     //  However if we don't do the cast the store below ends up being
286+     //  an i8.
287+     else 
288+       TypedVal = Builder.CreateIntCast (Val, ElemTy, false );
289+   }
272290
273291  for  (uint64_t  I = 0 ; I < Size; ++I) {
274292    Value *Offset = ConstantInt::get (Type::getInt32Ty (Ctx), I);
@@ -279,7 +297,7 @@ void emitMemset(IRBuilder<> &Builder, Value *Dst, Value *Val,
279297
280298static  void  removeMemSet (Instruction &I,
281299                         SmallVectorImpl<Instruction *> &ToRemove,
282-                          DenseMap<Value *, Value *>) {
300+                          DenseMap<Value *, Value *> &ReplacedValues ) {
283301  if  (CallInst *CI = dyn_cast<CallInst>(&I)) {
284302    Intrinsic::ID ID = CI->getIntrinsicID ();
285303    if  (ID == Intrinsic::memset) {
@@ -289,7 +307,7 @@ static void removeMemSet(Instruction &I,
289307      [[maybe_unused]] ConstantInt *Size =
290308          dyn_cast<ConstantInt>(CI->getArgOperand (2 ));
291309      assert (Size && " Expected Size to be a ConstantInt"  );
292-       emitMemset (Builder, Dst, Val, Size);
310+       emitMemsetExpansion (Builder, Dst, Val, Size, ReplacedValues );
293311      ToRemove.push_back (CI);
294312    }
295313  }
@@ -322,11 +340,11 @@ class DXILLegalizationPipeline {
322340      LegalizationPipeline;
323341
324342  void  initializeLegalizationPipeline () {
325-     LegalizationPipeline.push_back (removeMemSet);
326343    LegalizationPipeline.push_back (upcastI8AllocasAndUses);
327344    LegalizationPipeline.push_back (fixI8UseChain);
328345    LegalizationPipeline.push_back (downcastI64toI32InsertExtractElements);
329346    LegalizationPipeline.push_back (legalizeFreeze);
347+     LegalizationPipeline.push_back (removeMemSet);
330348  }
331349};
332350
0 commit comments