@@ -132,6 +132,11 @@ static cl::opt<bool> UseLIRCodeSizeHeurs(
132132 " with -Os/-Oz" ),
133133 cl::init(true ), cl::Hidden);
134134
135+ static cl::opt<bool > EnableMemsetPatternIntrinsic ( // Opt-in boolean flag; initialised to false and marked cl::Hidden.
136+ " loop-idiom-enable-memset-pattern-intrinsic" ,
137+ cl::desc (" Enable use of the memset_pattern intrinsic." ), cl::init(false ),
138+ cl::Hidden); // When set, pattern stores are emitted through Intrinsic::experimental_memset_pattern instead of the memset_pattern16 libcall.
139+
135140namespace {
136141
137142class LoopIdiomRecognize {
@@ -306,7 +311,8 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) {
306311 HasMemsetPattern = TLI->has (LibFunc_memset_pattern16);
307312 HasMemcpy = TLI->has (LibFunc_memcpy);
308313
309- if (HasMemset || HasMemsetPattern || HasMemcpy)
314+ if (HasMemset || HasMemsetPattern || EnableMemsetPatternIntrinsic ||
315+ HasMemcpy)
310316 if (SE->hasLoopInvariantBackedgeTakenCount (L))
311317 return runOnCountableLoop ();
312318
@@ -463,8 +469,10 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
463469 // It looks like we can use SplatValue.
464470 return LegalStoreKind::Memset;
465471 }
466- if (!UnorderedAtomic && HasMemsetPattern && !DisableLIRP::Memset &&
467- // Don't create memset_pattern16s with address spaces.
472+ if (!UnorderedAtomic && (HasMemsetPattern || EnableMemsetPatternIntrinsic) &&
473+ !DisableLIRP::Memset &&
474+ // Don't create memset_pattern16s / memset.pattern intrinsics with
475+ // address spaces.
468476 StorePtr->getType ()->getPointerAddressSpace () == 0 &&
469477 getMemSetPatternValue (StoredVal, DL)) {
470478 // It looks like we can use PatternValue!
@@ -1064,53 +1072,101 @@ bool LoopIdiomRecognize::processLoopStridedStore(
10641072 return Changed;
10651073
10661074 // Okay, everything looks good, insert the memset.
1075+ // MemsetArg is the number of bytes for the memset and memset_pattern16
1076+ // libcalls, and the number of pattern repetitions if the memset.pattern
1077+ // intrinsic is being used.
1078+ Value *MemsetArg;
1079+ std::optional<int64_t > BytesWritten = std::nullopt ;
1080+
1081+ if (PatternValue && EnableMemsetPatternIntrinsic) {
1082+ const SCEV *TripCountS =
1083+ SE->getTripCountFromExitCount (BECount, IntIdxTy, CurLoop);
1084+ if (!Expander.isSafeToExpand (TripCountS))
1085+ return Changed;
1086+ const SCEVConstant *ConstStoreSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
1087+ if (!ConstStoreSize)
1088+ return Changed;
10671089
1068- const SCEV *NumBytesS =
1069- getNumBytes (BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
1070-
1071- // TODO: ideally we should still be able to generate memset if SCEV expander
1072- // is taught to generate the dependencies at the latest point.
1073- if (!Expander.isSafeToExpand (NumBytesS))
1074- return Changed;
1090+ MemsetArg = Expander.expandCodeFor (TripCountS, IntIdxTy,
1091+ Preheader->getTerminator ());
1092+ if (auto CI = dyn_cast<ConstantInt>(MemsetArg))
1093+ BytesWritten =
1094+ CI->getZExtValue () * ConstStoreSize->getValue ()->getZExtValue ();
1095+ } else {
1096+ const SCEV *NumBytesS =
1097+ getNumBytes (BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
10751098
1076- Value *NumBytes =
1077- Expander.expandCodeFor (NumBytesS, IntIdxTy, Preheader->getTerminator ());
1099+ // TODO: ideally we should still be able to generate memset if SCEV expander
1100+ // is taught to generate the dependencies at the latest point.
1101+ if (!Expander.isSafeToExpand (NumBytesS))
1102+ return Changed;
1103+ MemsetArg =
1104+ Expander.expandCodeFor (NumBytesS, IntIdxTy, Preheader->getTerminator ());
1105+ if (auto CI = dyn_cast<ConstantInt>(MemsetArg))
1106+ BytesWritten = CI->getZExtValue ();
1107+ }
1108+ assert (MemsetArg && " MemsetArg should have been set" );
10781109
1079- if (!SplatValue && !isLibFuncEmittable (M, TLI, LibFunc_memset_pattern16))
1110+ if (!SplatValue && !(isLibFuncEmittable (M, TLI, LibFunc_memset_pattern16) ||
1111+ EnableMemsetPatternIntrinsic))
10801112 return Changed;
10811113
10821114 AAMDNodes AATags = TheStore->getAAMetadata ();
10831115 for (Instruction *Store : Stores)
10841116 AATags = AATags.merge (Store->getAAMetadata ());
1085- if (auto CI = dyn_cast<ConstantInt>(NumBytes) )
1086- AATags = AATags.extendTo (CI-> getZExtValue ());
1117+ if (BytesWritten )
1118+ AATags = AATags.extendTo (BytesWritten. value ());
10871119 else
10881120 AATags = AATags.extendTo (-1 );
10891121
10901122 CallInst *NewCall;
10911123 if (SplatValue) {
10921124 NewCall = Builder.CreateMemSet (
1093- BasePtr, SplatValue, NumBytes , MaybeAlign (StoreAlignment),
1125+ BasePtr, SplatValue, MemsetArg , MaybeAlign (StoreAlignment),
10941126 /* isVolatile=*/ false , AATags.TBAA , AATags.Scope , AATags.NoAlias );
10951127 } else {
1096- assert (isLibFuncEmittable (M, TLI, LibFunc_memset_pattern16));
1097- // Everything is emitted in default address space
1098- Type *Int8PtrTy = DestInt8PtrTy;
1099-
1100- StringRef FuncName = " memset_pattern16" ;
1101- FunctionCallee MSP = getOrInsertLibFunc (M, *TLI, LibFunc_memset_pattern16,
1102- Builder.getVoidTy (), Int8PtrTy, Int8PtrTy, IntIdxTy);
1103- inferNonMandatoryLibFuncAttrs (M, FuncName, *TLI);
1104-
1105- // Otherwise we should form a memset_pattern16. PatternValue is known to be
1106- // an constant array of 16-bytes. Plop the value into a mergable global.
1107- GlobalVariable *GV = new GlobalVariable (*M, PatternValue->getType (), true ,
1108- GlobalValue::PrivateLinkage,
1109- PatternValue, " .memset_pattern" );
1110- GV->setUnnamedAddr (GlobalValue::UnnamedAddr::Global); // Ok to merge these.
1111- GV->setAlignment (Align (16 ));
1112- Value *PatternPtr = GV;
1113- NewCall = Builder.CreateCall (MSP, {BasePtr, PatternPtr, NumBytes});
1128+ assert (isLibFuncEmittable (M, TLI, LibFunc_memset_pattern16) ||
1129+ EnableMemsetPatternIntrinsic);
1130+ if (EnableMemsetPatternIntrinsic) {
1131+ // Everything is emitted in default address space
1132+
1133+ assert (isa<SCEVConstant>(StoreSizeSCEV) &&
1134+ " Expected constant store size" );
1135+ llvm::Type *IntType = Builder.getIntNTy (
1136+ cast<SCEVConstant>(StoreSizeSCEV)->getValue ()->getZExtValue () * 8 );
1137+
1138+ llvm::Value *BitcastedValue = Builder.CreateBitCast (StoredVal, IntType);
1139+
1140+ // (Optional) Use the bitcasted value for further operations
1141+
1142+ // Create the call to the intrinsic
1143+ NewCall =
1144+ Builder.CreateIntrinsic (Intrinsic::experimental_memset_pattern,
1145+ {DestInt8PtrTy, IntType, IntIdxTy},
1146+ {BasePtr, BitcastedValue, MemsetArg,
1147+ ConstantInt::getFalse (M->getContext ())});
1148+ } else {
1149+ // Everything is emitted in default address space
1150+ Type *Int8PtrTy = DestInt8PtrTy;
1151+
1152+ StringRef FuncName = " memset_pattern16" ;
1153+ FunctionCallee MSP = getOrInsertLibFunc (M, *TLI, LibFunc_memset_pattern16,
1154+ Builder.getVoidTy (), Int8PtrTy,
1155+ Int8PtrTy, IntIdxTy);
1156+ inferNonMandatoryLibFuncAttrs (M, FuncName, *TLI);
1157+
1158+ // Otherwise we should form a memset_pattern16. PatternValue is known to
1159+ // be a constant array of 16 bytes. Plop the value into a mergeable
1160+ // global.
1161+ GlobalVariable *GV = new GlobalVariable (*M, PatternValue->getType (), true ,
1162+ GlobalValue::PrivateLinkage,
1163+ PatternValue, " .memset_pattern" );
1164+ GV->setUnnamedAddr (
1165+ GlobalValue::UnnamedAddr::Global); // Ok to merge these.
1166+ GV->setAlignment (Align (16 ));
1167+ Value *PatternPtr = GV;
1168+ NewCall = Builder.CreateCall (MSP, {BasePtr, PatternPtr, MemsetArg});
1169+ }
11141170
11151171 // Set the TBAA info if present.
11161172 if (AATags.TBAA )
0 commit comments