@@ -383,25 +383,25 @@ class LowerMatrixIntrinsics {
383383 return Vectors.size ();
384384 else {
385385 assert (Vectors.size () > 0 && " Cannot call getNumRows without columns" );
386- return cast<FixedVectorType>(Vectors[ 0 ]-> getType () )->getNumElements ();
386+ return getVectorTy ( )->getNumElements ();
387387 }
388388 }
389389 unsigned getNumRows () const {
390390 if (isColumnMajor ()) {
391391 assert (Vectors.size () > 0 && " Cannot call getNumRows without columns" );
392- return cast<FixedVectorType>(Vectors[ 0 ]-> getType () )->getNumElements ();
392+ return getVectorTy ( )->getNumElements ();
393393 } else
394394 return Vectors.size ();
395395 }
396396
397397 void addVector (Value *V) { Vectors.push_back (V); }
398- VectorType *getColumnTy () {
398+ FixedVectorType *getColumnTy () {
399399 assert (isColumnMajor () && " only supported for column-major matrixes" );
400400 return getVectorTy ();
401401 }
402402
403- VectorType *getVectorTy () const {
404- return cast<VectorType >(Vectors[0 ]->getType ());
403+ FixedVectorType *getVectorTy () const {
404+ return cast<FixedVectorType >(Vectors[0 ]->getType ());
405405 }
406406
407407 iterator_range<SmallVector<Value *, 8 >::iterator> columns () {
@@ -514,7 +514,7 @@ class LowerMatrixIntrinsics {
514514 : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}
515515
516516 unsigned getNumOps (Type *VT) {
517- assert (isa<VectorType >(VT) && " Expected vector type" );
517+ assert (isa<FixedVectorType >(VT) && " Expected vector type" );
518518 return getNumOps (VT->getScalarType (),
519519 cast<FixedVectorType>(VT)->getNumElements ());
520520 }
@@ -540,10 +540,8 @@ class LowerMatrixIntrinsics {
540540 // / into vectors.
541541 MatrixTy getMatrix (Value *MatrixVal, const ShapeInfo &SI,
542542 IRBuilder<> &Builder) {
543- VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType ());
544- assert (VType && " MatrixVal must be a vector type" );
545- assert (cast<FixedVectorType>(VType)->getNumElements () ==
546- SI.NumRows * SI.NumColumns &&
543+ FixedVectorType *VType = cast<FixedVectorType>(MatrixVal->getType ());
544+ assert (VType->getNumElements () == SI.NumRows * SI.NumColumns &&
547545 " The vector size must match the number of matrix elements" );
548546
549547 // Check if we lowered MatrixVal using shape information. In that case,
@@ -563,8 +561,7 @@ class LowerMatrixIntrinsics {
563561
564562 // Otherwise split MatrixVal.
565563 SmallVector<Value *, 16 > SplitVecs;
566- for (unsigned MaskStart = 0 ;
567- MaskStart < cast<FixedVectorType>(VType)->getNumElements ();
564+ for (unsigned MaskStart = 0 ; MaskStart < VType->getNumElements ();
568565 MaskStart += SI.getStride ()) {
569566 Value *V = Builder.CreateShuffleVector (
570567 MatrixVal, createSequentialMask (MaskStart, SI.getStride (), 0 ),
@@ -1157,7 +1154,7 @@ class LowerMatrixIntrinsics {
11571154 // / vectors.
11581155 MatrixTy loadMatrix (Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
11591156 bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1160- auto *VType = cast<VectorType >(Ty);
1157+ auto *VType = cast<FixedVectorType >(Ty);
11611158 Type *EltTy = VType->getElementType ();
11621159 Type *VecTy = FixedVectorType::get (EltTy, Shape.getStride ());
11631160 Value *EltPtr = Ptr;
@@ -1239,7 +1236,7 @@ class LowerMatrixIntrinsics {
12391236 MatrixTy storeMatrix (Type *Ty, MatrixTy StoreVal, Value *Ptr,
12401237 MaybeAlign MAlign, Value *Stride, bool IsVolatile,
12411238 IRBuilder<> &Builder) {
1242- auto VType = cast<VectorType >(Ty);
1239+ auto * VType = cast<FixedVectorType >(Ty);
12431240 Value *EltPtr = Ptr;
12441241 for (auto Vec : enumerate(StoreVal.vectors ())) {
12451242 Value *GEP = computeVectorAddr (
@@ -1377,7 +1374,7 @@ class LowerMatrixIntrinsics {
13771374 Value *LHS = MatMul->getArgOperand (0 );
13781375 Value *RHS = MatMul->getArgOperand (1 );
13791376
1380- Type *ElementType = cast<VectorType >(LHS->getType ())->getElementType ();
1377+ Type *ElementType = cast<FixedVectorType >(LHS->getType ())->getElementType ();
13811378 bool IsIntVec = ElementType->isIntegerTy ();
13821379
13831380 // Floating point reductions require reassocation.
@@ -1475,7 +1472,7 @@ class LowerMatrixIntrinsics {
14751472 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
14761473 InstructionCost ReductionCost =
14771474 TTI.getArithmeticReductionCost (
1478- AddOpCode, cast<VectorType >(LHS->getType ()),
1475+ AddOpCode, cast<FixedVectorType >(LHS->getType ()),
14791476 IsIntVec ? std::nullopt : std::optional (FMF)) +
14801477 TTI.getArithmeticInstrCost (MulOpCode, LHS->getType ());
14811478 InstructionCost SequentialAddCost =
@@ -1535,8 +1532,8 @@ class LowerMatrixIntrinsics {
15351532 Result = Builder.CreateAddReduce (Mul);
15361533 else {
15371534 Result = Builder.CreateFAddReduce (
1538- ConstantFP::get (cast<VectorType>(LHS-> getType ())-> getElementType (),
1539- 0.0 ),
1535+ ConstantFP::get (
1536+ cast<FixedVectorType>(LHS-> getType ())-> getElementType (), 0.0 ),
15401537 Mul);
15411538 cast<Instruction>(Result)->setFastMathFlags (FMF);
15421539 }
@@ -1735,7 +1732,7 @@ class LowerMatrixIntrinsics {
17351732 const unsigned R = LShape.NumRows ;
17361733 const unsigned C = RShape.NumColumns ;
17371734 const unsigned M = LShape.NumColumns ;
1738- auto *EltType = cast<VectorType >(MatMul->getType ())->getElementType ();
1735+ auto *EltType = cast<FixedVectorType >(MatMul->getType ())->getElementType ();
17391736
17401737 const unsigned VF = std::max<unsigned >(
17411738 TTI.getRegisterBitWidth (TargetTransformInfo::RGK_FixedWidthVector)
@@ -1771,7 +1768,7 @@ class LowerMatrixIntrinsics {
17711768
17721769 void createTiledLoops (CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
17731770 Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1774- auto *EltType = cast<VectorType >(MatMul->getType ())->getElementType ();
1771+ auto *EltType = cast<FixedVectorType >(MatMul->getType ())->getElementType ();
17751772
17761773 // Create the main tiling loop nest.
17771774 TileInfo TI (LShape.NumRows , RShape.NumColumns , LShape.NumColumns , TileSize);
@@ -1842,7 +1839,7 @@ class LowerMatrixIntrinsics {
18421839 const unsigned R = LShape.NumRows ;
18431840 const unsigned C = RShape.NumColumns ;
18441841 const unsigned M = LShape.NumColumns ;
1845- auto *EltType = cast<VectorType >(MatMul->getType ())->getElementType ();
1842+ auto *EltType = cast<FixedVectorType >(MatMul->getType ())->getElementType ();
18461843
18471844 Value *APtr = getNonAliasingPointer (LoadOp0, Store, MatMul);
18481845 Value *BPtr = getNonAliasingPointer (LoadOp1, Store, MatMul);
@@ -1914,7 +1911,8 @@ class LowerMatrixIntrinsics {
19141911 ? match (B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value (T)))
19151912 : match (A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value (T)))) {
19161913 IRBuilder<> Builder (MatMul);
1917- auto *EltType = cast<VectorType>(MatMul->getType ())->getElementType ();
1914+ auto *EltType =
1915+ cast<FixedVectorType>(MatMul->getType ())->getElementType ();
19181916 ShapeInfo LShape (MatMul->getArgOperand (2 ), MatMul->getArgOperand (3 ));
19191917 ShapeInfo RShape (MatMul->getArgOperand (3 ), MatMul->getArgOperand (4 ));
19201918 const unsigned R = LShape.NumRows ;
@@ -2045,7 +2043,7 @@ class LowerMatrixIntrinsics {
20452043 // / Lowers llvm.matrix.multiply.
20462044 void LowerMultiply (CallInst *MatMul) {
20472045 IRBuilder<> Builder (MatMul);
2048- auto *EltType = cast<VectorType >(MatMul->getType ())->getElementType ();
2046+ auto *EltType = cast<FixedVectorType >(MatMul->getType ())->getElementType ();
20492047 ShapeInfo LShape (MatMul->getArgOperand (2 ), MatMul->getArgOperand (3 ));
20502048 ShapeInfo RShape (MatMul->getArgOperand (3 ), MatMul->getArgOperand (4 ));
20512049
@@ -2073,7 +2071,7 @@ class LowerMatrixIntrinsics {
20732071 MatrixTy Result;
20742072 IRBuilder<> Builder (Inst);
20752073 Value *InputVal = Inst->getArgOperand (0 );
2076- VectorType *VectorTy = cast<VectorType >(InputVal->getType ());
2074+ FixedVectorType *VectorTy = cast<FixedVectorType >(InputVal->getType ());
20772075 ShapeInfo ArgShape (Inst->getArgOperand (1 ), Inst->getArgOperand (2 ));
20782076 MatrixTy InputMatrix = getMatrix (InputVal, ArgShape, Builder);
20792077
0 commit comments