1717#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1818#include " mlir/Dialect/LLVMIR/NVVMDialect.h"
1919#include " mlir/IR/TypeUtilities.h"
20+ #include " mlir/IR/Types.h"
2021
2122using namespace mlir ;
2223
@@ -57,7 +58,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
5758 if (type.getElementType ().isF32 ())
5859 return type.getOperand () == " COp" ? NVVM::MMATypes::f32
5960 : NVVM::MMATypes::tf32;
60-
61+ if (type.getElementType ().isF64 ())
62+ return NVVM::MMATypes::f64 ;
6163 if (type.getElementType ().isSignedInteger (8 ))
6264 return NVVM::MMATypes::s8;
6365 if (type.getElementType ().isUnsignedInteger (8 ))
@@ -212,8 +214,13 @@ struct WmmaMmaOpToNVVMLowering
212214 // then passed on to the intrinsic call. Emit llvm ops to extract individual
213215 // values form lowered memrefs.
214216 SmallVector<Value> unpackedOps;
215-
216217 auto unpackOp = [&](Value operand) {
218+ // f64 a and b fragments are not structs but scalars.
219+ if (!isa<LLVM::LLVMStructType>(operand.getType ())) {
220+ unpackedOps.push_back (operand);
221+ return ;
222+ }
223+ // every other type is lowered to an LLVM struct, extract the values.
217224 auto structType = cast<LLVM::LLVMStructType>(operand.getType ());
218225 for (size_t i = 0 , e = structType.getBody ().size (); i < e; ++i) {
219226 Value toUse = LLVM::ExtractValueOp::create (rewriter, loc, operand, i);
@@ -276,10 +283,16 @@ struct WmmaConstantOpToNVVMLowering
276283 return failure ();
277284 Location loc = subgroupMmaConstantOp.getLoc ();
278285 Value cst = adaptor.getOperands ()[0 ];
279- LLVM::LLVMStructType type = convertMMAToLLVMType (
286+ Type type = convertMMAToLLVMType (
280287 cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType ()));
288+ // If the element is not a struct, it means it's a scalar f64.
289+ auto structType = dyn_cast<LLVM::LLVMStructType>(type);
290+ if (!structType) {
291+ rewriter.replaceOp (subgroupMmaConstantOp, cst);
292+ return success ();
293+ }
281294 // If the element type is a vector create a vector from the operand.
282- if (auto vecType = dyn_cast<VectorType>(type .getBody ()[0 ])) {
295+ if (auto vecType = dyn_cast<VectorType>(structType .getBody ()[0 ])) {
283296 Value vecCst = LLVM::PoisonOp::create (rewriter, loc, vecType);
284297 for (int64_t vecEl = 0 ; vecEl < vecType.getNumElements (); vecEl++) {
285298 Value idx = LLVM::ConstantOp::create (rewriter, loc,
@@ -289,8 +302,8 @@ struct WmmaConstantOpToNVVMLowering
289302 }
290303 cst = vecCst;
291304 }
292- Value matrixStruct = LLVM::PoisonOp::create (rewriter, loc, type );
293- for (size_t i : llvm::seq (size_t (0 ), type .getBody ().size ())) {
305+ Value matrixStruct = LLVM::PoisonOp::create (rewriter, loc, structType );
306+ for (size_t i : llvm::seq (size_t (0 ), structType .getBody ().size ())) {
294307 matrixStruct =
295308 LLVM::InsertValueOp::create (rewriter, loc, matrixStruct, cst, i);
296309 }
@@ -354,10 +367,24 @@ struct WmmaElementwiseOpToNVVMLowering
354367 return failure ();
355368 Location loc = subgroupMmaElementwiseOp.getLoc ();
356369 size_t numOperands = adaptor.getOperands ().size ();
357- LLVM::LLVMStructType destType = convertMMAToLLVMType (
370+ Type destType = convertMMAToLLVMType (
358371 cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType ()));
359- Value matrixStruct = LLVM::PoisonOp::create (rewriter, loc, destType);
360- for (size_t i = 0 , e = destType.getBody ().size (); i < e; ++i) {
372+
373+ // If the element is not a struct, it means it's a scalar f64.
374+ LLVM::LLVMStructType structDestTy =
375+ dyn_cast<LLVM::LLVMStructType>(destType);
376+ if (!structDestTy) {
377+ SmallVector<Value> operands;
378+ for (auto operand : adaptor.getOperands ()) {
379+ operands.push_back (operand);
380+ }
381+ Value element = createScalarOp (
382+ rewriter, loc, subgroupMmaElementwiseOp.getOpType (), operands);
383+ rewriter.replaceOp (subgroupMmaElementwiseOp, element);
384+ return success ();
385+ }
386+ Value matrixStruct = LLVM::PoisonOp::create (rewriter, loc, structDestTy);
387+ for (size_t i = 0 , e = structDestTy.getBody ().size (); i < e; ++i) {
361388 SmallVector<Value> extractedOperands;
362389 for (size_t opIdx = 0 ; opIdx < numOperands; opIdx++) {
363390 extractedOperands.push_back (LLVM::ExtractValueOp::create (
@@ -377,13 +404,18 @@ struct WmmaElementwiseOpToNVVMLowering
377404} // namespace
378405
379406// / Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
380- LLVM::LLVMStructType mlir::convertMMAToLLVMType (gpu::MMAMatrixType type) {
407+ Type mlir::convertMMAToLLVMType (gpu::MMAMatrixType type) {
381408 NVVM::MMAFrag frag = convertOperand (type.getOperand ());
382409 NVVM::MMATypes eltType = getElementType (type);
383410 auto nRow = type.getShape ()[0 ];
384411 auto nCol = type.getShape ()[1 ];
385412 std::pair<Type, unsigned > typeInfo =
386413 NVVM::inferMMAType (eltType, frag, nRow, nCol, type.getContext ());
414+ // Special handling for f64 a and b fragments
415+ Type f64Ty = Float64Type::get (type.getContext ());
416+ if (typeInfo.first == f64Ty && typeInfo.second == 1 ) {
417+ return f64Ty;
418+ }
387419 return LLVM::LLVMStructType::getLiteral (
388420 type.getContext (), SmallVector<Type, 8 >(typeInfo.second , typeInfo.first ));
389421}
0 commit comments