@@ -21,8 +21,6 @@ namespace mlir {
 namespace triton {
 namespace gpu {
 
-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
@@ -47,8 +45,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }
 
-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
@@ -112,10 +110,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
   // Contains a chained dot. We prefer to assign warps to one axis
   // to facilitate use cases like flash attention, allowing reductions within
   // the same warp.
@@ -170,11 +168,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
@@ -188,6 +201,16 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
 }
 
 static bool bwdFilter(Operation *op) {
+  // Dot operand layout assignment to predicates is not currently supported
+  // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
+  // condition limits visibility of the original bit-width so that predicates
+  // are not considered; hence, kwidth can never be 32.
+  if (isa<arith::UIToFPOp>(op)) {
+    Type srcType = getElementTypeOrSelf(op->getOperand(0));
+    if (srcType.isInteger(1))
+      return false;
+  }
+
   return op->getNumOperands() == 1 &&
          (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
           isPureUnaryInlineAsm(op) ||
@@ -207,7 +230,7 @@ static bool bwdFilter(Operation *op) {
 // result, kwidth can be the bitwidth of the lower precision primitive.
 // Conversely, in the downcasting scenario, no reordering is performed,
 // making it directly use the lower precision primitive.
-static int computeOrigBitWidth(Value x) {
+int computeOrigBitWidth(Value x) {
   int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
   int origBitWidth = finalBitWidth;
   SetVector<Operation *> slice;
@@ -227,6 +250,9 @@ static int computeOrigBitWidth(Value x) {
   }
   return origBitWidth;
 }
+// Move the anonymous namespace down so that getWarpsPerTile is visible to the
+// sparsity extension.
+namespace {
 
 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
@@ -632,7 +658,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+      bool isNativeFP8 =
+          llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(AElType);
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||
@@ -1018,6 +1045,11 @@ class TritonGPUAccelerateMatmulPass
   }
 };
 
+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
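
For illustration only: the change moves getWarpsPerTile out of the anonymous namespace and adds the getSharedMemMMAOperand wrapper so that out-of-tree code (the sparsity extension mentioned in the comment) can call them. The sketch below is hypothetical and not part of this commit; only the two declarations mirror the signatures shown in the diff, while tileSparseDot, its operands, and the literal shapes are made-up placeholders.

// Hypothetical usage sketch from a separate translation unit, assuming the
// helpers are linked in from AccelerateMatmul.cpp as this change allows.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {
namespace triton {
namespace gpu {

// Declarations matching the now-external helpers from this change.
SmallVector<unsigned, 3>
getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                int numWarps, const SmallVector<unsigned, 3> &instrShape);
Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
                             int opIdx, bool allowTranspose);

// Placeholder pattern body: wrap one operand of a (sparse) dot in shared
// memory and query a warp tiling for it. Values passed for version, numWarps,
// and instrShape are illustrative, not taken from the commit.
static void tileSparseDot(Operation *sparseDotOp, PatternRewriter &rewriter) {
  Value a = sparseDotOp->getOperand(0);
  Value aShmem =
      getSharedMemMMAOperand(a, rewriter, /*opIdx=*/0, /*allowTranspose=*/true);
  auto shape = cast<RankedTensorType>(sparseDotOp->getResult(0).getType())
                   .getShape();
  SmallVector<unsigned, 3> warps = getWarpsPerTile(
      sparseDotOp, shape, /*version=*/3, /*numWarps=*/4,
      /*instrShape=*/{16, 8, 16});
  (void)aShmem;
  (void)warps;
}

} // namespace gpu
} // namespace triton
} // namespace mlir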