@@ -349,58 +349,82 @@ bool RISCVTargetLowering::lowerDeinterleaveIntrinsicToLoad(
349349}
350350
351351bool RISCVTargetLowering::lowerInterleaveIntrinsicToStore (
352- StoreInst *SI , ArrayRef<Value *> InterleaveValues) const {
352+ Instruction *Store, Value *Mask , ArrayRef<Value *> InterleaveValues) const {
353353 unsigned Factor = InterleaveValues.size ();
354354 if (Factor > 8 )
355355 return false ;
356356
357- assert (SI->isSimple ());
358- IRBuilder<> Builder (SI);
357+ IRBuilder<> Builder (Store);
359358
360359 auto *InVTy = cast<VectorType>(InterleaveValues[0 ]->getType ());
361- auto *PtrTy = SI-> getPointerOperandType ();
362- const DataLayout &DL = SI-> getDataLayout ( );
360+ const DataLayout &DL = Store-> getDataLayout ();
361+ Type *XLenTy = Type::getIntNTy (Store-> getContext (), Subtarget. getXLen () );
363362
364- if (!isLegalInterleavedAccessType (InVTy, Factor, SI->getAlign (),
365- SI->getPointerAddressSpace (), DL))
366- return false ;
363+ Value *Ptr, *VL;
364+ Align Alignment;
365+ if (auto *SI = dyn_cast<StoreInst>(Store)) {
366+ assert (SI->isSimple ());
367+ Ptr = SI->getPointerOperand ();
368+ Alignment = SI->getAlign ();
369+ assert (!Mask && " Unexpected mask on a store" );
370+ Mask = Builder.getAllOnesMask (InVTy->getElementCount ());
371+ VL = isa<FixedVectorType>(InVTy)
372+ ? Builder.CreateElementCount (XLenTy, InVTy->getElementCount ())
373+ : Constant::getAllOnesValue (XLenTy);
374+ } else {
375+ auto *VPStore = cast<VPIntrinsic>(Store);
376+ assert (VPStore->getIntrinsicID () == Intrinsic::vp_store &&
377+ " Unexpected intrinsic" );
378+ Ptr = VPStore->getMemoryPointerParam ();
379+ Alignment = VPStore->getPointerAlignment ().value_or (
380+ DL.getABITypeAlign (InVTy->getElementType ()));
381+
382+ assert (Mask && " vp.store needs a mask!" );
383+
384+ Value *WideEVL = VPStore->getVectorLengthParam ();
385+ // Conservatively check if EVL is a multiple of factor, otherwise some
386+ // (trailing) elements might be lost after the transformation.
387+ if (!isMultipleOfN (WideEVL, DL, Factor))
388+ return false ;
367389
368- Type *XLenTy = Type::getIntNTy (SI->getContext (), Subtarget.getXLen ());
390+ VL = Builder.CreateZExt (
391+ Builder.CreateUDiv (WideEVL,
392+ ConstantInt::get (WideEVL->getType (), Factor)),
393+ XLenTy);
394+ }
395+ Type *PtrTy = Ptr->getType ();
396+ unsigned AS = Ptr->getType ()->getPointerAddressSpace ();
397+ if (!isLegalInterleavedAccessType (InVTy, Factor, Alignment, AS, DL))
398+ return false ;
369399
370400 if (isa<FixedVectorType>(InVTy)) {
371401 Function *VssegNFunc = Intrinsic::getOrInsertDeclaration (
372- SI ->getModule (), FixedVssegIntrIds[Factor - 2 ], {InVTy, PtrTy, XLenTy});
373-
402+ Store ->getModule (), FixedVssegIntrIds[Factor - 2 ],
403+ {InVTy, PtrTy, XLenTy});
374404 SmallVector<Value *, 10 > Ops (InterleaveValues);
375- Value *VL = Builder.CreateElementCount (XLenTy, InVTy->getElementCount ());
376- Value *Mask = Builder.getAllOnesMask (InVTy->getElementCount ());
377- Ops.append ({SI->getPointerOperand (), Mask, VL});
378-
405+ Ops.append ({Ptr, Mask, VL});
379406 Builder.CreateCall (VssegNFunc, Ops);
380407 return true ;
381408 }
382409 unsigned SEW = DL.getTypeSizeInBits (InVTy->getElementType ());
383410 unsigned NumElts = InVTy->getElementCount ().getKnownMinValue ();
384411 Type *VecTupTy = TargetExtType::get (
385- SI ->getContext (), " riscv.vector.tuple" ,
386- ScalableVectorType::get (Type::getInt8Ty (SI ->getContext ()),
412+ Store ->getContext (), " riscv.vector.tuple" ,
413+ ScalableVectorType::get (Type::getInt8Ty (Store ->getContext ()),
387414 NumElts * SEW / 8 ),
388415 Factor);
389416
390- Value *VL = Constant::getAllOnesValue (XLenTy);
391- Value *Mask = Builder.getAllOnesMask (InVTy->getElementCount ());
392-
393417 Value *StoredVal = PoisonValue::get (VecTupTy);
394418 for (unsigned i = 0 ; i < Factor; ++i)
395419 StoredVal = Builder.CreateIntrinsic (
396420 Intrinsic::riscv_tuple_insert, {VecTupTy, InVTy},
397421 {StoredVal, InterleaveValues[i], Builder.getInt32 (i)});
398422
399423 Function *VssegNFunc = Intrinsic::getOrInsertDeclaration (
400- SI ->getModule (), ScalableVssegIntrIds[Factor - 2 ],
424+ Store ->getModule (), ScalableVssegIntrIds[Factor - 2 ],
401425 {VecTupTy, PtrTy, Mask->getType (), VL->getType ()});
402426
403- Value *Operands[] = {StoredVal, SI-> getPointerOperand () , Mask, VL,
427+ Value *Operands[] = {StoredVal, Ptr , Mask, VL,
404428 ConstantInt::get (XLenTy, Log2_64 (SEW))};
405429 Builder.CreateCall (VssegNFunc, Operands);
406430 return true ;
0 commit comments