@@ -19,8 +19,6 @@ namespace mlir {
 namespace triton {
 namespace gpu {
 
-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
@@ -43,8 +41,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }
 
-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
@@ -57,9 +55,8 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
   auto slices = multiRootGetSlice(dotOp, {filter}, {filter});
   bool hasChainedDot = false;
   for (Operation *op : slices) {
-    if (isa<DotOp>(op) && (op != dotOp)) {
-      auto chainedDot = cast<DotOp>(op);
-      auto resTy = chainedDot.getResult().getType();
+    if (dotOp->getName() == op->getName() && op != dotOp) {
+      auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
       if (resTy.getRank() != rank) {
         continue;
       }
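Note how the typed `isa<DotOp>` test becomes a name comparison against `dotOp` itself: since the function now takes a generic `Operation *`, the loop recognizes a chained dot of whatever concrete op kind the caller passed in (presumably the sparse dot handled by the sparsity extension mentioned further down), and the result type is fetched through the generic `getResult(0)` accessor instead of the typed `DotOp` API.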
@@ -108,12 +105,17 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
-  if (llvm::find_if(slices, [](Operation *op) { return isa<DotOp>(op); }) !=
-      slices.end())
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
+  if (llvm::find_if(slices, [&](Operation *op) {
+        return dotOp->getName() == op->getName() ||
+               // Contains a chained dot. We prefer to assign warps to one axis
+               // to facilitate use cases like flash attention, allowing
+               // reductions within the same warp.
+               op->hasTrait<OpTrait::DotLike>();
+      }) != slices.end())
    return {(unsigned)numWarps, 1};
 
   // For MMAv3, the smallest indivisible unit of warp shape is (4, 1).
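As a worked example of the early return above: in a flash-attention-style kernel the first dot's forward slice contains the second dot, so with numWarps = 8 the function yields warpsPerCTA = {8, 1}. All eight warps then tile the M axis only, and each row's softmax-style reduction stays inside a single warp instead of requiring cross-warp data movement.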
@@ -162,11 +164,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding,
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
@@ -179,18 +196,32 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
   }
 }
 
+// Move the anonymous namespace down so getWarpsPerTile remains visible to the
+// sparsity extension.
+namespace {
+
 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
 
   static bool bwdFilter(Operation *op) {
+    // Assigning a dot-operand layout to predicates is not currently supported
+    // when lowering from TritonGPU to LLVM for MMA cases. This check hides
+    // the original bit-width of i1 sources so that predicates are never
+    // considered; hence kwidth can never be 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||
            op->getDialect()->getTypeID() ==
                mlir::TypeID::get<arith::ArithDialect>());
  }
 
+public:
   // Finds the first different bitwidth in the chain of shape-preserving
   // unary ops that x depends on.
   // There are two primary scenarios:
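For context, a filter like bwdFilter is typically consumed by a backward-slice walk that looks for the narrowest element type feeding the dot operand; the i1 guard added above keeps uitofp-from-predicate chains out of that search. A minimal sketch of such a walk using the standard MLIR slice API (the helper name and exact reduction are assumptions, not this commit's implementation):

// Sketch only: assumed helper, not part of this commit.
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/TypeUtilities.h"
#include <algorithm>

static unsigned origBitWidthSketch(mlir::Value x) {
  unsigned width =
      mlir::getElementTypeOrSelf(x.getType()).getIntOrFloatBitWidth();
  mlir::BackwardSliceOptions opts;
  opts.omitBlockArguments = true;
  opts.filter = bwdFilter; // the predicate above, assuming it is in scope
  llvm::SetVector<mlir::Operation *> slice;
  mlir::getBackwardSlice(x, &slice, opts);
  // Keep the narrowest element bit-width seen along the filtered slice.
  for (mlir::Operation *op : slice)
    for (mlir::Value operand : op->getOperands())
      if (auto ty = mlir::dyn_cast<mlir::RankedTensorType>(operand.getType()))
        width = std::min(width, ty.getElementType().getIntOrFloatBitWidth());
  return width;
}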
@@ -595,6 +626,15 @@ class TritonGPUAccelerateMatmulPass
   }
 };
 
+// Expose helper functions from BlockedToMMA to be reused for sparse matmul.
+int computeOrigBitWidth(Value x) {
+  return BlockedToMMA::computeOrigBitWidth(x);
+}
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
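To illustrate why these symbols are exported, here is a hypothetical caller from a sparsity extension in another translation unit. Everything below except getWarpsPerTile and the two forwarding helpers is an assumed name for illustration, not part of this commit:

#include "mlir/IR/PatternMatch.h"

namespace mlir::triton::gpu {

// Redeclarations matching the helpers exported above.
SmallVector<unsigned, 3>
getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                int numWarps, const SmallVector<unsigned, 3> &instrShape);
int computeOrigBitWidth(Value x);
Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
                             int opIdx, bool allowTranspose);

// Hypothetical sparse-dot rewrite reusing the dense heuristics.
LogicalResult rewriteSparseDotSketch(Operation *dotOp,
                                     PatternRewriter &rewriter) {
  auto resTy = cast<RankedTensorType>(dotOp->getResult(0).getType());
  SmallVector<unsigned, 3> instrShape = {16, 8, 16}; // assumed MMA shape
  auto warpsPerCTA = getWarpsPerTile(dotOp, resTy.getShape(), /*version=*/3,
                                     /*numWarps=*/4, instrShape);
  // Derive kwidth from the original operand bit-width (cf. the bwdFilter
  // comment above: a 1-bit source would otherwise give kwidth = 32).
  int kWidth = 32 / computeOrigBitWidth(dotOp->getOperand(0));
  // Stage operand A through shared memory with the shared helper.
  Value a = getSharedMemMMAOperand(dotOp->getOperand(0), rewriter,
                                   /*opIdx=*/0, /*allowTranspose=*/true);
  (void)warpsPerCTA; (void)kWidth; (void)a;
  return success();
}

} // namespace mlir::triton::gpu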