@@ -20,8 +20,6 @@ namespace mlir {
 namespace triton {
 namespace gpu {
 
-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
@@ -44,8 +42,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }
 
-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
@@ -109,10 +107,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
   // Contains a chained dot. We prefer to assign warps to one axis
   // to facilitate use cases like flash attention, allowing reductions within
   // the same warp.
@@ -167,11 +165,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
@@ -184,18 +197,32 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
   }
 }
 
+// Move the anonymous namespace down so that getWarpsPerTile is visible to the
+// sparsity extension.
+namespace {
+
 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
 
   static bool bwdFilter(Operation *op) {
+    // Dot operand layout assignment to predicates is not currently supported
+    // when lowering from TritonGPU to LLVM for MMA cases. This check hides
+    // the original bit-width from i1 predicates so they are not considered;
+    // hence kWidth can never be 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||
             op->getDialect()->getTypeID() ==
                 mlir::TypeID::get<arith::ArithDialect>());
   }
 
+public:
   // Finds the first different bitwidth in the chain of shape-preserving
   // unary ops that x depends on.
   // There are two primary scenarios:
@@ -813,6 +840,15 @@ class TritonGPUAccelerateMatmulPass
   }
 };
 
+// Expose helper functions from BlockedToMMA to be reused for sparse matmul.
+int computeOrigBitWidth(Value x) {
+  return BlockedToMMA::computeOrigBitWidth(x);
+}
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir