@@ -441,98 +441,74 @@ bool RISCVABIInfo::detectVLSCCEligibleStruct(QualType Ty, unsigned ABIVLen,
441441 // __attribute__((vector_size(64))) int d;
442442 // }
443443 //
444- // Struct of 1 fixed-length vector is passed as a scalable vector.
445- // Struct of >1 fixed-length vectors are passed as vector tuple.
446- // Struct of 1 array of fixed-length vectors is passed as a scalable vector.
447- // Otherwise, pass the struct indirectly.
448-
449- if (llvm::StructType *STy = dyn_cast<llvm::StructType>(CGT.ConvertType (Ty))) {
450- unsigned NumElts = STy->getStructNumElements ();
451- if (NumElts > 8 )
452- return false ;
444+ // 1. Struct of 1 fixed-length vector is passed as a scalable vector.
445+ // 2. Struct of >1 fixed-length vectors are passed as vector tuple.
446+ // 3. Struct of an array with 1 element of fixed-length vectors is passed as a
447+ // scalable vector.
448+ // 4. Struct of an array with >1 elements of fixed-length vectors is passed as
449+ // vector tuple.
450+ // 5. Otherwise, pass the struct indirectly.
451+
452+ llvm::StructType *STy = dyn_cast<llvm::StructType>(CGT.ConvertType (Ty));
453+ if (!STy)
454+ return false ;
453455
454- auto *FirstEltTy = STy->getElementType ( 0 );
455- if (!STy-> containsHomogeneousTypes () )
456- return false ;
456+ unsigned NumElts = STy->getStructNumElements ( );
457+ if (NumElts > 8 )
458+ return false ;
457459
458- // Check structure of fixed-length vectors and turn them into vector tuple
459- // type if legal.
460- if (auto *FixedVecTy = dyn_cast<llvm::FixedVectorType>(FirstEltTy)) {
461- if (NumElts == 1 ) {
462- // Handle single fixed-length vector.
463- VLSType = llvm::ScalableVectorType::get (
464- FixedVecTy->getElementType (),
465- llvm::divideCeil (FixedVecTy->getNumElements () *
466- llvm::RISCV::RVVBitsPerBlock,
467- ABIVLen));
468- // Check registers needed <= 8.
469- return llvm::divideCeil (
470- FixedVecTy->getNumElements () *
471- FixedVecTy->getElementType ()->getScalarSizeInBits (),
472- ABIVLen) <= 8 ;
473- }
474- // LMUL
475- // = fixed-length vector size / ABIVLen
476- // = 8 * I8EltCount / RVVBitsPerBlock
477- // =>
478- // I8EltCount
479- // = (fixed-length vector size * RVVBitsPerBlock) / (ABIVLen * 8)
480- unsigned I8EltCount = llvm::divideCeil (
481- FixedVecTy->getNumElements () *
482- FixedVecTy->getElementType ()->getScalarSizeInBits () *
483- llvm::RISCV::RVVBitsPerBlock,
484- ABIVLen * 8 );
485- VLSType = llvm::TargetExtType::get (
486- getVMContext (), " riscv.vector.tuple" ,
487- llvm::ScalableVectorType::get (llvm::Type::getInt8Ty (getVMContext ()),
488- I8EltCount),
489- NumElts);
490- // Check registers needed <= 8.
491- return NumElts *
492- llvm::divideCeil (
493- FixedVecTy->getNumElements () *
494- FixedVecTy->getElementType ()->getScalarSizeInBits (),
495- ABIVLen) <=
496- 8 ;
497- }
460+ auto *FirstEltTy = STy->getElementType (0 );
461+ if (!STy->containsHomogeneousTypes ())
462+ return false ;
498463
499- // If elements are not fixed-length vectors, it should be an array.
464+ if (auto *ArrayTy = dyn_cast<llvm::ArrayType>(FirstEltTy)) {
465+ // Only struct of single array is accepted
500466 if (NumElts != 1 )
501467 return false ;
468+ FirstEltTy = ArrayTy->getArrayElementType ();
469+ NumElts = ArrayTy->getNumElements ();
470+ }
502471
503- // Check array of fixed-length vector and turn it into scalable vector type
504- // if legal.
505- if (auto *ArrTy = dyn_cast<llvm::ArrayType>(FirstEltTy)) {
506- unsigned NumArrElt = ArrTy->getNumElements ();
507- if (NumArrElt > 8 )
508- return false ;
472+ auto *FixedVecTy = dyn_cast<llvm::FixedVectorType>(FirstEltTy);
473+ if (!FixedVecTy)
474+ return false ;
509475
510- auto *ArrEltTy = dyn_cast<llvm::FixedVectorType>(ArrTy->getElementType ());
511- if (!ArrEltTy)
512- return false ;
476+ // Check registers needed <= 8.
477+ if (NumElts * llvm::divideCeil (
478+ FixedVecTy->getNumElements () *
479+ FixedVecTy->getElementType ()->getScalarSizeInBits (),
480+ ABIVLen) >
481+ 8 )
482+ return false ;
513483
514- // LMUL
515- // = NumArrElt * fixed-length vector size / ABIVLen
516- // = fixed-length vector elt size * ScalVecNumElts / RVVBitsPerBlock
517- // =>
518- // ScalVecNumElts
519- // = (NumArrElt * fixed-length vector size * RVVBitsPerBlock) /
520- // (ABIVLen * fixed-length vector elt size)
521- // = NumArrElt * num fixed-length vector elt * RVVBitsPerBlock /
522- // ABIVLen
523- unsigned ScalVecNumElts = llvm::divideCeil (
524- NumArrElt * ArrEltTy->getNumElements () * llvm::RISCV::RVVBitsPerBlock,
525- ABIVLen);
526- VLSType = llvm::ScalableVectorType::get (ArrEltTy->getElementType (),
527- ScalVecNumElts);
528- // Check registers needed <= 8.
529- return llvm::divideCeil (
530- ScalVecNumElts *
531- ArrEltTy->getElementType ()->getScalarSizeInBits (),
532- llvm::RISCV::RVVBitsPerBlock) <= 8 ;
533- }
484+ // Turn them into scalable vector type or vector tuple type if legal.
485+ if (NumElts == 1 ) {
486+ // Handle single fixed-length vector.
487+ VLSType = llvm::ScalableVectorType::get (
488+ FixedVecTy->getElementType (),
489+ llvm::divideCeil (FixedVecTy->getNumElements () *
490+ llvm::RISCV::RVVBitsPerBlock,
491+ ABIVLen));
492+ return true ;
534493 }
535- return false ;
494+
495+ // LMUL
496+ // = fixed-length vector size / ABIVLen
497+ // = 8 * I8EltCount / RVVBitsPerBlock
498+ // =>
499+ // I8EltCount
500+ // = (fixed-length vector size * RVVBitsPerBlock) / (ABIVLen * 8)
501+ unsigned I8EltCount =
502+ llvm::divideCeil (FixedVecTy->getNumElements () *
503+ FixedVecTy->getElementType ()->getScalarSizeInBits () *
504+ llvm::RISCV::RVVBitsPerBlock,
505+ ABIVLen * 8 );
506+ VLSType = llvm::TargetExtType::get (
507+ getVMContext (), " riscv.vector.tuple" ,
508+ llvm::ScalableVectorType::get (llvm::Type::getInt8Ty (getVMContext ()),
509+ I8EltCount),
510+ NumElts);
511+ return true ;
536512}
537513
538514// Fixed-length RVV vectors are represented as scalable vectors in function
0 commit comments