@@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding,
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
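For context, here is roughly what this first hunk does at the IR level. The sketch below is illustrative only: the op spellings (triton_gpu.convert_layout, triton_gpu.local_alloc), shapes, and layout attribute names are assumptions, and the exact textual syntax varies across Triton versions.

    // Previously, local_alloc consumed the dot-operand-encoded tensor directly,
    // which the LocalAllocOp lowering does not support:
    //   %smem = triton_gpu.local_alloc %a
    //     : (tensor<128x64xf16, #dot_operand>) -> !tt.memdesc<128x64xf16, #shared>
    // With this patch, the operand is first converted to its parent blocked
    // layout, and the shared-memory allocation consumes the converted value:
    %cvt  = triton_gpu.convert_layout %a
      : tensor<128x64xf16, #dot_operand> -> tensor<128x64xf16, #blocked>
    %smem = triton_gpu.local_alloc %cvt
      : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared>
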
@@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
 
   static bool bwdFilter(Operation *op) {
+    // Dot operand layout assignment for predicates (i1) is not currently
+    // supported when lowering MMA cases from TritonGPU to LLVM. This check
+    // limits visibility of the original bit-width so that predicates are
+    // not considered; hence, kWidth can never be 32.
+    if (isa<arith::UIToFPOp>(op)) {
+      Type srcType = getElementTypeOrSelf(op->getOperand(0));
+      if (srcType.isInteger(1))
+        return false;
+    }
     return op->getNumOperands() == 1 &&
            (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
             isPureUnaryInlineAsm(op) ||
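To see what the bwdFilter change guards against, consider a chain where an i1 predicate is upcast before feeding a dot operand. The IR below is a hypothetical sketch (op spellings, shapes, and layouts are assumptions):

    // %pred has element type i1; uitofp upcasts it to f16.
    %pred = arith.cmpf olt, %x, %y : tensor<32x32xf16, #blocked>
    %fp   = arith.uitofp %pred : tensor<32x32xi1, #blocked> to tensor<32x32xf16, #blocked>
    %opA  = triton_gpu.convert_layout %fp
      : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #dot_operand>
    %acc  = tt.dot %opA, %opB, %init : ...

Without the new check, the backward slice from the dot operand would walk through the uitofp and take the original element bit-width to be 1, giving kWidth = 32 (a 32-bit register would hold 32 one-bit elements), which the TritonGPU-to-LLVM MMA lowering does not support. Returning false stops the slice at the upcast, so the 16-bit result type is what gets considered instead.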