@@ -150,6 +150,11 @@ static cl::opt<bool> UseLIRCodeSizeHeurs(
150150 " with -Os/-Oz" ),
151151 cl::init(true ), cl::Hidden);
152152
// Hidden override flag, off by default. The code in this file ORs it with the
// TLI memset_pattern16 availability check, so setting it allows the
// memset.pattern intrinsic to be used even when that libcall is unavailable.
153+ static cl::opt<bool > ForceMemsetPatternIntrinsic (
154+ " loop-idiom-force-memset-pattern-intrinsic" ,
155+ cl::desc (" Use memset.pattern intrinsic whenever possible" ), cl::init(false ),
156+ cl::Hidden);
157+
153158namespace {
154159
155160class LoopIdiomRecognize {
@@ -323,10 +328,15 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) {
323328 L->getHeader ()->getParent ()->hasOptSize () && UseLIRCodeSizeHeurs;
324329
325330 HasMemset = TLI->has (LibFunc_memset);
331+ // TODO: Unconditionally enable use of the memset pattern intrinsic (or at
332+ // least, opt-in via target hook) once we are confident it will never result
333+ // in worse codegen than without. For now, use it only when the target
334+ // supports the memset_pattern16 libcall (or when this is overridden by
335+ // the command line option).
326336 HasMemsetPattern = TLI->has (LibFunc_memset_pattern16);
327337 HasMemcpy = TLI->has (LibFunc_memcpy);
328338
329- if (HasMemset || HasMemsetPattern || HasMemcpy)
339+ if (HasMemset || HasMemsetPattern || ForceMemsetPatternIntrinsic || HasMemcpy)
330340 if (SE->hasLoopInvariantBackedgeTakenCount (L))
331341 return runOnCountableLoop ();
332342
@@ -378,11 +388,13 @@ static APInt getStoreStride(const SCEVAddRecExpr *StoreEv) {
378388}
379389
380390// / getMemSetPatternValue - If a strided store of the specified value is safe to
381- // / turn into a memset_pattern16 , return a ConstantArray of 16 bytes that should
382- // / be passed in. Otherwise, return null.
391+ // / turn into a memset.pattern intrinsic , return the Constant that should
392+ // / be passed in. Otherwise, return null.
383393// /
384- // / Note that we don't ever attempt to use memset_pattern8 or 4, because these
385- // / just replicate their input array and then pass on to memset_pattern16.
394+ // / TODO: This function could allow more constants than it does today (e.g.
395+ // / those over 16 bytes) now that it has transitioned to being used for the
396+ // / memset.pattern intrinsic rather than directly for the memset_pattern16
397+ // / libcall.
386398static Constant *getMemSetPatternValue (Value *V, const DataLayout *DL) {
387399 // FIXME: This could check for UndefValue because it can be merged into any
388400 // other valid pattern.
@@ -411,14 +423,12 @@ static Constant *getMemSetPatternValue(Value *V, const DataLayout *DL) {
411423 if (Size > 16 )
412424 return nullptr ;
413425
414- // If the constant is exactly 16 bytes, just use it.
415- if (Size == 16 )
416- return C;
426+ // For now, don't handle types that aren't int, floats, or pointers.
427+ Type *CTy = C->getType ();
428+ if (!CTy->isIntOrPtrTy () && !CTy->isFloatingPointTy ())
429+ return nullptr ;
417430
418- // Otherwise, we'll use an array of the constants.
419- unsigned ArraySize = 16 / Size;
420- ArrayType *AT = ArrayType::get (V->getType (), ArraySize);
421- return ConstantArray::get (AT, std::vector<Constant *>(ArraySize, C));
431+ return C;
422432}
423433
424434LoopIdiomRecognize::LegalStoreKind
@@ -479,7 +489,8 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
479489 // It looks like we can use SplatValue.
480490 return LegalStoreKind::Memset;
481491 }
482- if (!UnorderedAtomic && HasMemsetPattern && !DisableLIRP::Memset &&
492+ if (!UnorderedAtomic && (HasMemsetPattern || ForceMemsetPatternIntrinsic) &&
493+ !DisableLIRP::Memset &&
483494 // Don't create memset_pattern16s with address spaces.
484495 StorePtr->getType ()->getPointerAddressSpace () == 0 &&
485496 getMemSetPatternValue (StoredVal, DL)) {
@@ -1061,50 +1072,81 @@ bool LoopIdiomRecognize::processLoopStridedStore(
10611072 return Changed;
10621073
10631074 // Okay, everything looks good, insert the memset.
1075+ Value *SplatValue = isBytewiseValue (StoredVal, *DL);
1076+ Constant *PatternValue = nullptr ;
1077+ if (!SplatValue)
1078+ PatternValue = getMemSetPatternValue (StoredVal, DL);
1079+
1080+ // MemsetArg is the number of bytes for the memset libcall, and the number
1081+ // of pattern repetitions if the memset.pattern intrinsic is being used.
1082+ Value *MemsetArg;
1083+ std::optional<int64_t > BytesWritten;
1084+
1085+ if (PatternValue && (HasMemsetPattern || ForceMemsetPatternIntrinsic)) {
1086+ const SCEV *TripCountS =
1087+ SE->getTripCountFromExitCount (BECount, IntIdxTy, CurLoop);
1088+ if (!Expander.isSafeToExpand (TripCountS))
1089+ return Changed;
1090+ const SCEVConstant *ConstStoreSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
1091+ if (!ConstStoreSize)
1092+ return Changed;
1093+ Value *TripCount = Expander.expandCodeFor (TripCountS, IntIdxTy,
1094+ Preheader->getTerminator ());
1095+ uint64_t PatternRepsPerTrip =
1096+ (ConstStoreSize->getValue ()->getZExtValue () * 8 ) /
1097+ DL->getTypeSizeInBits (PatternValue->getType ());
1098+ // If ConstStoreSize is not equal to the width of PatternValue, then
1099+ // MemsetArg is TripCount * (ConstStoreSize/PatternValueWidth). Else
1100+ // MemsetArg is just TripCount.
1101+ MemsetArg =
1102+ PatternRepsPerTrip == 1
1103+ ? TripCount
1104+ : Builder.CreateMul (TripCount,
1105+ Builder.getIntN (IntIdxTy->getIntegerBitWidth (),
1106+ PatternRepsPerTrip));
1107+ if (auto *CI = dyn_cast<ConstantInt>(TripCount))
1108+ BytesWritten =
1109+ CI->getZExtValue () * ConstStoreSize->getValue ()->getZExtValue ();
10641110
1065- const SCEV *NumBytesS =
1066- getNumBytes (BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
1067-
1068- // TODO: ideally we should still be able to generate memset if SCEV expander
1069- // is taught to generate the dependencies at the latest point.
1070- if (!Expander.isSafeToExpand (NumBytesS))
1071- return Changed;
1111+ } else {
1112+ const SCEV *NumBytesS =
1113+ getNumBytes (BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
10721114
1073- Value *NumBytes =
1074- Expander.expandCodeFor (NumBytesS, IntIdxTy, Preheader->getTerminator ());
1115+ // TODO: ideally we should still be able to generate memset if SCEV expander
1116+ // is taught to generate the dependencies at the latest point.
1117+ if (!Expander.isSafeToExpand (NumBytesS))
1118+ return Changed;
1119+ MemsetArg =
1120+ Expander.expandCodeFor (NumBytesS, IntIdxTy, Preheader->getTerminator ());
1121+ if (auto *CI = dyn_cast<ConstantInt>(MemsetArg))
1122+ BytesWritten = CI->getZExtValue ();
1123+ }
1124+ assert (MemsetArg && " MemsetArg should have been set" );
10751125
10761126 AAMDNodes AATags = TheStore->getAAMetadata ();
10771127 for (Instruction *Store : Stores)
10781128 AATags = AATags.merge (Store->getAAMetadata ());
1079- if (auto CI = dyn_cast<ConstantInt>(NumBytes) )
1080- AATags = AATags.extendTo (CI-> getZExtValue ());
1129+ if (BytesWritten )
1130+ AATags = AATags.extendTo (BytesWritten. value ());
10811131 else
10821132 AATags = AATags.extendTo (-1 );
10831133
10841134 CallInst *NewCall;
1085- if (Value * SplatValue = isBytewiseValue (StoredVal, *DL) ) {
1086- NewCall = Builder.CreateMemSet (BasePtr, SplatValue, NumBytes ,
1135+ if (SplatValue) {
1136+ NewCall = Builder.CreateMemSet (BasePtr, SplatValue, MemsetArg ,
10871137 MaybeAlign (StoreAlignment),
10881138 /* isVolatile=*/ false , AATags);
1089- } else if (isLibFuncEmittable (M, TLI, LibFunc_memset_pattern16)) {
1090- // Everything is emitted in default address space
1091- Type *Int8PtrTy = DestInt8PtrTy;
1092-
1093- StringRef FuncName = " memset_pattern16" ;
1094- FunctionCallee MSP = getOrInsertLibFunc (M, *TLI, LibFunc_memset_pattern16,
1095- Builder.getVoidTy (), Int8PtrTy, Int8PtrTy, IntIdxTy);
1096- inferNonMandatoryLibFuncAttrs (M, FuncName, *TLI);
1097-
1098- // Otherwise we should form a memset_pattern16. PatternValue is known to be
1099- // an constant array of 16-bytes. Plop the value into a mergable global.
1100- Constant *PatternValue = getMemSetPatternValue (StoredVal, DL);
1101- assert (PatternValue && " Expected pattern value." );
1102- GlobalVariable *GV = new GlobalVariable (*M, PatternValue->getType (), true ,
1103- GlobalValue::PrivateLinkage,
1104- PatternValue, " .memset_pattern" );
1105- GV->setUnnamedAddr (GlobalValue::UnnamedAddr::Global); // Ok to merge these.
1106- GV->setAlignment (Align (16 ));
1107- NewCall = Builder.CreateCall (MSP, {BasePtr, GV, NumBytes});
1139+ } else if (ForceMemsetPatternIntrinsic ||
1140+ isLibFuncEmittable (M, TLI, LibFunc_memset_pattern16)) {
1141+ assert (isa<SCEVConstant>(StoreSizeSCEV) && " Expected constant store size" );
1142+
1143+ NewCall = Builder.CreateIntrinsic (
1144+ Intrinsic::experimental_memset_pattern,
1145+ {DestInt8PtrTy, PatternValue->getType (), IntIdxTy},
1146+ {BasePtr, PatternValue, MemsetArg,
1147+ ConstantInt::getFalse (M->getContext ())});
1148+ if (StoreAlignment)
1149+ cast<MemSetPatternInst>(NewCall)->setDestAlignment (*StoreAlignment);
11081150 NewCall->setAAMetadata (AATags);
11091151 } else {
11101152 // Neither a memset, nor memset_pattern16
0 commit comments