@@ -21,15 +21,13 @@ namespace mlir {
 namespace triton {
 namespace gpu {
 
-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
   SmallVector<int> versionsSupported;
   if (computeCapability < 75) {
     versionsSupported = {1};
-  } else if (computeCapability < 90) {
+  } else if (computeCapability < 90 || computeCapability >= 100) {
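+    // Blackwell (sm_100+) also takes the v2 path here; the v3 path below
+    // emits wgmma instructions, which target Hopper (sm_90) only.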
     versionsSupported = {2};
   } else if (computeCapability < 100) {
     versionsSupported = {3, 2};
@@ -45,8 +43,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }
 
-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
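+// Generalized to take an Operation* instead of DotOp so that sparse dot
+// variants can share this helper (see the namespace move below).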
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
@@ -110,10 +108,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
   // Contains a chained dot. We prefer to assign warps to one axis
   // to facilitate use cases like flash attention, allowing reductions within
   // the same warp.
@@ -168,11 +166,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
@@ -185,18 +198,32 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
   }
 }
 
+// Move anonymous namespace down, so getWarpsPerTile is visible to the
+// sparsity extension.
+namespace {
+
 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
 
   static bool bwdFilter(Operation *op) {
+    // Dot operand layout assignment to predicates is not currently supported
+    // during lowering from TritonGPU to LLVM for MMA cases. This condition
+    // limits visibility of the original bit-width so that predicates are not
+    // considered; hence, kWidth can never be 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||
             op->getDialect()->getTypeID() ==
                 mlir::TypeID::get<arith::ArithDialect>());
   }
 
+public:
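+  // computeOrigBitWidth (below) is made public so that the sparse matmul
+  // extension can call it through the wrapper at the end of this file.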
   // Finds the first different bitwidth in the chain of shape-preserving
   // unary ops that x depends on.
   // There are two primary scenarios:
@@ -720,6 +747,15 @@ class TritonGPUAccelerateMatmulPass
   }
 };
 
+// Expose helper functions from BlockedToMMA to be reused for sparse matmul.
+int computeOrigBitWidth(Value x) {
+  return BlockedToMMA::computeOrigBitWidth(x);
+}
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
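+// Example use from a (hypothetical) sparse rewrite pattern:
+//   int bw = computeOrigBitWidth(dotOp.getA());
+//   Value opA = getSharedMemMMAOperand(dotOp.getA(), rewriter, /*opIdx=*/0,
+//                                      /*allowTranspose=*/true);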
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir