@@ -216,8 +216,8 @@ static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
216216
217217// Attempt to rebuild a normalized splat vector constant of the requested splat
218218// width, built up of potentially smaller scalar values.
219- static Constant *rebuildSplatableConstant (const Constant *C,
220- unsigned SplatBitWidth) {
219+ static Constant *rebuildSplatCst (const Constant *C, unsigned /* NumElts */ ,
220+ unsigned SplatBitWidth) {
221221 std::optional<APInt> Splat = getSplatableConstant (C, SplatBitWidth);
222222 if (!Splat)
223223 return nullptr ;
@@ -238,8 +238,8 @@ static Constant *rebuildSplatableConstant(const Constant *C,
238238 return rebuildConstant (OriginalType->getContext (), SclTy, *Splat, NumSclBits);
239239}
240240
241- static Constant *rebuildZeroUpperConstant (const Constant *C,
242- unsigned ScalarBitWidth) {
241+ static Constant *rebuildZeroUpperCst (const Constant *C, unsigned /* NumElts */ ,
242+ unsigned ScalarBitWidth) {
243243 Type *Ty = C->getType ();
244244 Type *SclTy = Ty->getScalarType ();
245245 unsigned NumBits = Ty->getPrimitiveSizeInBits ();
@@ -265,8 +265,6 @@ static Constant *rebuildZeroUpperConstant(const Constant *C,
265265 return nullptr ;
266266}
267267
268- typedef std::function<Constant *(const Constant *, unsigned )> RebuildFn;
269-
270268bool X86FixupVectorConstantsPass::processInstruction (MachineFunction &MF,
271269 MachineBasicBlock &MBB,
272270 MachineInstr &MI) {
@@ -277,43 +275,42 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
277275 bool HasBWI = ST->hasBWI ();
278276 bool HasVLX = ST->hasVLX ();
279277
280- auto FixupConstant =
281- [&](unsigned OpBcst256, unsigned OpBcst128, unsigned OpBcst64,
282- unsigned OpBcst32, unsigned OpBcst16, unsigned OpBcst8,
283- unsigned OpUpper64, unsigned OpUpper32, unsigned OperandNo) {
284- assert (MI.getNumOperands () >= (OperandNo + X86::AddrNumOperands) &&
285- " Unexpected number of operands!" );
286-
287- if (auto *C = X86::getConstantFromPool (MI, OperandNo)) {
288- // Attempt to detect a suitable splat/vzload from increasing constant
289- // bitwidths.
290- // Prefer vzload vs broadcast for same bitwidth to avoid domain flips.
291- std::tuple<unsigned , unsigned , RebuildFn> FixupLoad[] = {
292- {8 , OpBcst8, rebuildSplatableConstant},
293- {16 , OpBcst16, rebuildSplatableConstant},
294- {32 , OpUpper32, rebuildZeroUpperConstant},
295- {32 , OpBcst32, rebuildSplatableConstant},
296- {64 , OpUpper64, rebuildZeroUpperConstant},
297- {64 , OpBcst64, rebuildSplatableConstant},
298- {128 , OpBcst128, rebuildSplatableConstant},
299- {256 , OpBcst256, rebuildSplatableConstant},
300- };
301- for (auto [BitWidth, Op, RebuildConstant] : FixupLoad) {
302- if (Op) {
303- // Construct a suitable constant and adjust the MI to use the new
304- // constant pool entry.
305- if (Constant *NewCst = RebuildConstant (C, BitWidth)) {
306- unsigned NewCPI =
307- CP->getConstantPoolIndex (NewCst, Align (BitWidth / 8 ));
308- MI.setDesc (TII->get (Op));
309- MI.getOperand (OperandNo + X86::AddrDisp).setIndex (NewCPI);
310- return true ;
311- }
312- }
278+ struct FixupEntry {
279+ int Op;
280+ int NumCstElts;
281+ int BitWidth;
282+ std::function<Constant *(const Constant *, unsigned , unsigned )>
283+ RebuildConstant;
284+ };
285+ auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned OperandNo) {
286+ #ifdef EXPENSIVE_CHECKS
287+ assert (llvm::is_sorted (Fixups,
288+ [](const FixupEntry &A, const FixupEntry &B) {
289+ return (A.NumCstElts * A.BitWidth ) <
290+ (B.NumCstElts * B.BitWidth );
291+ }) &&
292+ " Constant fixup table not sorted in ascending constant size" );
293+ #endif
294+ assert (MI.getNumOperands () >= (OperandNo + X86::AddrNumOperands) &&
295+ " Unexpected number of operands!" );
296+ if (auto *C = X86::getConstantFromPool (MI, OperandNo)) {
297+ for (const FixupEntry &Fixup : Fixups) {
298+ if (Fixup.Op ) {
299+ // Construct a suitable constant and adjust the MI to use the new
300+ // constant pool entry.
301+ if (Constant *NewCst =
302+ Fixup.RebuildConstant (C, Fixup.NumCstElts , Fixup.BitWidth )) {
303+ unsigned NewCPI =
304+ CP->getConstantPoolIndex (NewCst, Align (Fixup.BitWidth / 8 ));
305+ MI.setDesc (TII->get (Fixup.Op ));
306+ MI.getOperand (OperandNo + X86::AddrDisp).setIndex (NewCPI);
307+ return true ;
313308 }
314309 }
315- return false ;
316- };
310+ }
311+ }
312+ return false ;
313+ };
317314
318315 // Attempt to convert full width vector loads into broadcast/vzload loads.
319316 switch (Opc) {
@@ -323,82 +320,125 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
323320 case X86::MOVUPDrm:
324321 case X86::MOVUPSrm:
325322 // TODO: SSE3 MOVDDUP Handling
326- return FixupConstant (0 , 0 , 0 , 0 , 0 , 0 , X86::MOVSDrm, X86::MOVSSrm, 1 );
323+ return FixupConstant ({{X86::MOVSSrm, 1 , 32 , rebuildZeroUpperCst},
324+ {X86::MOVSDrm, 1 , 64 , rebuildZeroUpperCst}},
325+ 1 );
327326 case X86::VMOVAPDrm:
328327 case X86::VMOVAPSrm:
329328 case X86::VMOVUPDrm:
330329 case X86::VMOVUPSrm:
331- return FixupConstant (0 , 0 , X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0 , 0 ,
332- X86::VMOVSDrm, X86::VMOVSSrm, 1 );
330+ return FixupConstant ({{X86::VMOVSSrm, 1 , 32 , rebuildZeroUpperCst},
331+ {X86::VBROADCASTSSrm, 1 , 32 , rebuildSplatCst},
332+ {X86::VMOVSDrm, 1 , 64 , rebuildZeroUpperCst},
333+ {X86::VMOVDDUPrm, 1 , 64 , rebuildSplatCst}},
334+ 1 );
333335 case X86::VMOVAPDYrm:
334336 case X86::VMOVAPSYrm:
335337 case X86::VMOVUPDYrm:
336338 case X86::VMOVUPSYrm:
337- return FixupConstant (0 , X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm,
338- X86::VBROADCASTSSYrm, 0 , 0 , 0 , 0 , 1 );
339+ return FixupConstant ({{X86::VBROADCASTSSYrm, 1 , 32 , rebuildSplatCst},
340+ {X86::VBROADCASTSDYrm, 1 , 64 , rebuildSplatCst},
341+ {X86::VBROADCASTF128rm, 1 , 128 , rebuildSplatCst}},
342+ 1 );
339343 case X86::VMOVAPDZ128rm:
340344 case X86::VMOVAPSZ128rm:
341345 case X86::VMOVUPDZ128rm:
342346 case X86::VMOVUPSZ128rm:
343- return FixupConstant (0 , 0 , X86::VMOVDDUPZ128rm, X86::VBROADCASTSSZ128rm, 0 ,
344- 0 , X86::VMOVSDZrm, X86::VMOVSSZrm, 1 );
347+ return FixupConstant ({{X86::VMOVSSZrm, 1 , 32 , rebuildZeroUpperCst},
348+ {X86::VBROADCASTSSZ128rm, 1 , 32 , rebuildSplatCst},
349+ {X86::VMOVSDZrm, 1 , 64 , rebuildZeroUpperCst},
350+ {X86::VMOVDDUPZ128rm, 1 , 64 , rebuildSplatCst}},
351+ 1 );
345352 case X86::VMOVAPDZ256rm:
346353 case X86::VMOVAPSZ256rm:
347354 case X86::VMOVUPDZ256rm:
348355 case X86::VMOVUPSZ256rm:
349- return FixupConstant (0 , X86::VBROADCASTF32X4Z256rm, X86::VBROADCASTSDZ256rm,
350- X86::VBROADCASTSSZ256rm, 0 , 0 , 0 , 0 , 1 );
356+ return FixupConstant (
357+ {{X86::VBROADCASTSSZ256rm, 1 , 32 , rebuildSplatCst},
358+ {X86::VBROADCASTSDZ256rm, 1 , 64 , rebuildSplatCst},
359+ {X86::VBROADCASTF32X4Z256rm, 1 , 128 , rebuildSplatCst}},
360+ 1 );
351361 case X86::VMOVAPDZrm:
352362 case X86::VMOVAPSZrm:
353363 case X86::VMOVUPDZrm:
354364 case X86::VMOVUPSZrm:
355- return FixupConstant (X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm,
356- X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0 , 0 , 0 , 0 ,
365+ return FixupConstant ({{X86::VBROADCASTSSZrm, 1 , 32 , rebuildSplatCst},
366+ {X86::VBROADCASTSDZrm, 1 , 64 , rebuildSplatCst},
367+ {X86::VBROADCASTF32X4rm, 1 , 128 , rebuildSplatCst},
368+ {X86::VBROADCASTF64X4rm, 1 , 256 , rebuildSplatCst}},
357369 1 );
358370 /* Integer Loads */
359371 case X86::MOVDQArm:
360- case X86::MOVDQUrm:
361- return FixupConstant (0 , 0 , 0 , 0 , 0 , 0 , X86::MOVQI2PQIrm, X86::MOVDI2PDIrm,
372+ case X86::MOVDQUrm: {
373+ return FixupConstant ({{X86::MOVDI2PDIrm, 1 , 32 , rebuildZeroUpperCst},
374+ {X86::MOVQI2PQIrm, 1 , 64 , rebuildZeroUpperCst}},
362375 1 );
376+ }
363377 case X86::VMOVDQArm:
364- case X86::VMOVDQUrm:
365- return FixupConstant (0 , 0 , HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm,
366- HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm,
367- HasAVX2 ? X86::VPBROADCASTWrm : 0 ,
368- HasAVX2 ? X86::VPBROADCASTBrm : 0 , X86::VMOVQI2PQIrm,
369- X86::VMOVDI2PDIrm, 1 );
378+ case X86::VMOVDQUrm: {
379+ FixupEntry Fixups[] = {
380+ {HasAVX2 ? X86::VPBROADCASTBrm : 0 , 1 , 8 , rebuildSplatCst},
381+ {HasAVX2 ? X86::VPBROADCASTWrm : 0 , 1 , 16 , rebuildSplatCst},
382+ {X86::VMOVDI2PDIrm, 1 , 32 , rebuildZeroUpperCst},
383+ {HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm, 1 , 32 ,
384+ rebuildSplatCst},
385+ {X86::VMOVQI2PQIrm, 1 , 64 , rebuildZeroUpperCst},
386+ {HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm, 1 , 64 ,
387+ rebuildSplatCst},
388+ };
389+ return FixupConstant (Fixups, 1 );
390+ }
370391 case X86::VMOVDQAYrm:
371- case X86::VMOVDQUYrm:
372- return FixupConstant (
373- 0 , HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm,
374- HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm,
375- HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm,
376- HasAVX2 ? X86::VPBROADCASTWYrm : 0 , HasAVX2 ? X86::VPBROADCASTBYrm : 0 ,
377- 0 , 0 , 1 );
392+ case X86::VMOVDQUYrm: {
393+ FixupEntry Fixups[] = {
394+ {HasAVX2 ? X86::VPBROADCASTBYrm : 0 , 1 , 8 , rebuildSplatCst},
395+ {HasAVX2 ? X86::VPBROADCASTWYrm : 0 , 1 , 16 , rebuildSplatCst},
396+ {HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm, 1 , 32 ,
397+ rebuildSplatCst},
398+ {HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm, 1 , 64 ,
399+ rebuildSplatCst},
400+ {HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm, 1 , 128 ,
401+ rebuildSplatCst}};
402+ return FixupConstant (Fixups, 1 );
403+ }
378404 case X86::VMOVDQA32Z128rm:
379405 case X86::VMOVDQA64Z128rm:
380406 case X86::VMOVDQU32Z128rm:
381- case X86::VMOVDQU64Z128rm:
382- return FixupConstant (0 , 0 , X86::VPBROADCASTQZ128rm, X86::VPBROADCASTDZ128rm,
383- HasBWI ? X86::VPBROADCASTWZ128rm : 0 ,
384- HasBWI ? X86::VPBROADCASTBZ128rm : 0 ,
385- X86::VMOVQI2PQIZrm, X86::VMOVDI2PDIZrm, 1 );
407+ case X86::VMOVDQU64Z128rm: {
408+ FixupEntry Fixups[] = {
409+ {HasBWI ? X86::VPBROADCASTBZ128rm : 0 , 1 , 8 , rebuildSplatCst},
410+ {HasBWI ? X86::VPBROADCASTWZ128rm : 0 , 1 , 16 , rebuildSplatCst},
411+ {X86::VMOVDI2PDIZrm, 1 , 32 , rebuildZeroUpperCst},
412+ {X86::VPBROADCASTDZ128rm, 1 , 32 , rebuildSplatCst},
413+ {X86::VMOVQI2PQIZrm, 1 , 64 , rebuildZeroUpperCst},
414+ {X86::VPBROADCASTQZ128rm, 1 , 64 , rebuildSplatCst}};
415+ return FixupConstant (Fixups, 1 );
416+ }
386417 case X86::VMOVDQA32Z256rm:
387418 case X86::VMOVDQA64Z256rm:
388419 case X86::VMOVDQU32Z256rm:
389- case X86::VMOVDQU64Z256rm:
390- return FixupConstant (0 , X86::VBROADCASTI32X4Z256rm, X86::VPBROADCASTQZ256rm,
391- X86::VPBROADCASTDZ256rm,
392- HasBWI ? X86::VPBROADCASTWZ256rm : 0 ,
393- HasBWI ? X86::VPBROADCASTBZ256rm : 0 , 0 , 0 , 1 );
420+ case X86::VMOVDQU64Z256rm: {
421+ FixupEntry Fixups[] = {
422+ {HasBWI ? X86::VPBROADCASTBZ256rm : 0 , 1 , 8 , rebuildSplatCst},
423+ {HasBWI ? X86::VPBROADCASTWZ256rm : 0 , 1 , 16 , rebuildSplatCst},
424+ {X86::VPBROADCASTDZ256rm, 1 , 32 , rebuildSplatCst},
425+ {X86::VPBROADCASTQZ256rm, 1 , 64 , rebuildSplatCst},
426+ {X86::VBROADCASTI32X4Z256rm, 1 , 128 , rebuildSplatCst}};
427+ return FixupConstant (Fixups, 1 );
428+ }
394429 case X86::VMOVDQA32Zrm:
395430 case X86::VMOVDQA64Zrm:
396431 case X86::VMOVDQU32Zrm:
397- case X86::VMOVDQU64Zrm:
398- return FixupConstant (X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm,
399- X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
400- HasBWI ? X86::VPBROADCASTWZrm : 0 ,
401- HasBWI ? X86::VPBROADCASTBZrm : 0 , 0 , 0 , 1 );
432+ case X86::VMOVDQU64Zrm: {
433+ FixupEntry Fixups[] = {
434+ {HasBWI ? X86::VPBROADCASTBZrm : 0 , 1 , 8 , rebuildSplatCst},
435+ {HasBWI ? X86::VPBROADCASTWZrm : 0 , 1 , 16 , rebuildSplatCst},
436+ {X86::VPBROADCASTDZrm, 1 , 32 , rebuildSplatCst},
437+ {X86::VPBROADCASTQZrm, 1 , 64 , rebuildSplatCst},
438+ {X86::VBROADCASTI32X4rm, 1 , 128 , rebuildSplatCst},
439+ {X86::VBROADCASTI64X4rm, 1 , 256 , rebuildSplatCst}};
440+ return FixupConstant (Fixups, 1 );
441+ }
402442 }
403443
404444 auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
@@ -423,7 +463,9 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
423463
424464 if (OpBcst32 || OpBcst64) {
425465 unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
426- return FixupConstant (0 , 0 , OpBcst64, OpBcst32, 0 , 0 , 0 , 0 , OpNo);
466+ FixupEntry Fixups[] = {{(int )OpBcst32, 32 , 32 , rebuildSplatCst},
467+ {(int )OpBcst64, 64 , 64 , rebuildSplatCst}};
468+ return FixupConstant (Fixups, OpNo);
427469 }
428470 return false ;
429471 };
0 commit comments