From fe2bf1cff2611cfb200605c303f34a3b6ef10720 Mon Sep 17 00:00:00 2001 From: Jon Roelofs Date: Sun, 1 Jun 2025 10:51:20 -0700 Subject: [PATCH] [Matrix] Use FixedVectorType everywhere in the LowerMatrixIntrinsics pass. NFC These matrix ops do not support scalable vectors, so we should be really explicit about that and avoid casting mistakes. --- .../Scalar/LowerMatrixIntrinsics.cpp | 46 +++++++++---------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 756a72e6d97bc..787e107464c0a 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -383,25 +383,25 @@ class LowerMatrixIntrinsics { return Vectors.size(); else { assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); - return cast(Vectors[0]->getType())->getNumElements(); + return getVectorTy()->getNumElements(); } } unsigned getNumRows() const { if (isColumnMajor()) { assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); - return cast(Vectors[0]->getType())->getNumElements(); + return getVectorTy()->getNumElements(); } else return Vectors.size(); } void addVector(Value *V) { Vectors.push_back(V); } - VectorType *getColumnTy() { + FixedVectorType *getColumnTy() { assert(isColumnMajor() && "only supported for column-major matrixes"); return getVectorTy(); } - VectorType *getVectorTy() const { - return cast(Vectors[0]->getType()); + FixedVectorType *getVectorTy() const { + return cast(Vectors[0]->getType()); } iterator_range::iterator> columns() { @@ -514,7 +514,7 @@ class LowerMatrixIntrinsics { : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {} unsigned getNumOps(Type *VT) { - assert(isa(VT) && "Expected vector type"); + assert(isa(VT) && "Expected vector type"); return getNumOps(VT->getScalarType(), cast(VT)->getNumElements()); } @@ -540,10 +540,8 @@ class LowerMatrixIntrinsics { /// into vectors. MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, IRBuilder<> &Builder) { - VectorType *VType = dyn_cast(MatrixVal->getType()); - assert(VType && "MatrixVal must be a vector type"); - assert(cast(VType)->getNumElements() == - SI.NumRows * SI.NumColumns && + FixedVectorType *VType = cast(MatrixVal->getType()); + assert(VType->getNumElements() == SI.NumRows * SI.NumColumns && "The vector size must match the number of matrix elements"); // Check if we lowered MatrixVal using shape information. In that case, @@ -563,8 +561,7 @@ class LowerMatrixIntrinsics { // Otherwise split MatrixVal. SmallVector SplitVecs; - for (unsigned MaskStart = 0; - MaskStart < cast(VType)->getNumElements(); + for (unsigned MaskStart = 0; MaskStart < VType->getNumElements(); MaskStart += SI.getStride()) { Value *V = Builder.CreateShuffleVector( MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0), @@ -1157,7 +1154,7 @@ class LowerMatrixIntrinsics { /// vectors. MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { - auto *VType = cast(Ty); + auto *VType = cast(Ty); Type *EltTy = VType->getElementType(); Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); Value *EltPtr = Ptr; @@ -1239,7 +1236,7 @@ class LowerMatrixIntrinsics { MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, MaybeAlign MAlign, Value *Stride, bool IsVolatile, IRBuilder<> &Builder) { - auto VType = cast(Ty); + auto *VType = cast(Ty); Value *EltPtr = Ptr; for (auto Vec : enumerate(StoreVal.vectors())) { Value *GEP = computeVectorAddr( @@ -1377,7 +1374,7 @@ class LowerMatrixIntrinsics { Value *LHS = MatMul->getArgOperand(0); Value *RHS = MatMul->getArgOperand(1); - Type *ElementType = cast(LHS->getType())->getElementType(); + Type *ElementType = cast(LHS->getType())->getElementType(); bool IsIntVec = ElementType->isIntegerTy(); // Floating point reductions require reassocation. @@ -1475,7 +1472,7 @@ class LowerMatrixIntrinsics { int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul; InstructionCost ReductionCost = TTI.getArithmeticReductionCost( - AddOpCode, cast(LHS->getType()), + AddOpCode, cast(LHS->getType()), IsIntVec ? std::nullopt : std::optional(FMF)) + TTI.getArithmeticInstrCost(MulOpCode, LHS->getType()); InstructionCost SequentialAddCost = @@ -1535,8 +1532,8 @@ class LowerMatrixIntrinsics { Result = Builder.CreateAddReduce(Mul); else { Result = Builder.CreateFAddReduce( - ConstantFP::get(cast(LHS->getType())->getElementType(), - 0.0), + ConstantFP::get( + cast(LHS->getType())->getElementType(), 0.0), Mul); cast(Result)->setFastMathFlags(FMF); } @@ -1735,7 +1732,7 @@ class LowerMatrixIntrinsics { const unsigned R = LShape.NumRows; const unsigned C = RShape.NumColumns; const unsigned M = LShape.NumColumns; - auto *EltType = cast(MatMul->getType())->getElementType(); + auto *EltType = cast(MatMul->getType())->getElementType(); const unsigned VF = std::max( TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) @@ -1771,7 +1768,7 @@ class LowerMatrixIntrinsics { void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape, Value *RPtr, ShapeInfo RShape, StoreInst *Store) { - auto *EltType = cast(MatMul->getType())->getElementType(); + auto *EltType = cast(MatMul->getType())->getElementType(); // Create the main tiling loop nest. TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize); @@ -1842,7 +1839,7 @@ class LowerMatrixIntrinsics { const unsigned R = LShape.NumRows; const unsigned C = RShape.NumColumns; const unsigned M = LShape.NumColumns; - auto *EltType = cast(MatMul->getType())->getElementType(); + auto *EltType = cast(MatMul->getType())->getElementType(); Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); @@ -1914,7 +1911,8 @@ class LowerMatrixIntrinsics { ? match(B, m_Intrinsic(m_Value(T))) : match(A, m_Intrinsic(m_Value(T)))) { IRBuilder<> Builder(MatMul); - auto *EltType = cast(MatMul->getType())->getElementType(); + auto *EltType = + cast(MatMul->getType())->getElementType(); ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); const unsigned R = LShape.NumRows; @@ -2045,7 +2043,7 @@ class LowerMatrixIntrinsics { /// Lowers llvm.matrix.multiply. void LowerMultiply(CallInst *MatMul) { IRBuilder<> Builder(MatMul); - auto *EltType = cast(MatMul->getType())->getElementType(); + auto *EltType = cast(MatMul->getType())->getElementType(); ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); @@ -2073,7 +2071,7 @@ class LowerMatrixIntrinsics { MatrixTy Result; IRBuilder<> Builder(Inst); Value *InputVal = Inst->getArgOperand(0); - VectorType *VectorTy = cast(InputVal->getType()); + FixedVectorType *VectorTy = cast(InputVal->getType()); ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);