1717#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1818#include " mlir/Dialect/LLVMIR/NVVMDialect.h"
1919#include " mlir/IR/TypeUtilities.h"
20- #include " mlir/IR/Types.h"
2120
2221using namespace mlir ;
2322
@@ -58,8 +57,7 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
5857 if (type.getElementType ().isF32 ())
5958 return type.getOperand () == " COp" ? NVVM::MMATypes::f32
6059 : NVVM::MMATypes::tf32;
61- if (type.getElementType ().isF64 ())
62- return NVVM::MMATypes::f64 ;
60+
6361 if (type.getElementType ().isSignedInteger (8 ))
6462 return NVVM::MMATypes::s8;
6563 if (type.getElementType ().isUnsignedInteger (8 ))
@@ -214,13 +212,8 @@ struct WmmaMmaOpToNVVMLowering
214212 // then passed on to the intrinsic call. Emit llvm ops to extract individual
215213 // values form lowered memrefs.
216214 SmallVector<Value> unpackedOps;
215+
217216 auto unpackOp = [&](Value operand) {
218- // f64 a and b fragments are not structs but scalars.
219- if (!isa<LLVM::LLVMStructType>(operand.getType ())) {
220- unpackedOps.push_back (operand);
221- return ;
222- }
223- // every other type is lowered to an LLVM struct, extract the values.
224217 auto structType = cast<LLVM::LLVMStructType>(operand.getType ());
225218 for (size_t i = 0 , e = structType.getBody ().size (); i < e; ++i) {
226219 Value toUse = LLVM::ExtractValueOp::create (rewriter, loc, operand, i);
@@ -283,16 +276,10 @@ struct WmmaConstantOpToNVVMLowering
283276 return failure ();
284277 Location loc = subgroupMmaConstantOp.getLoc ();
285278 Value cst = adaptor.getOperands ()[0 ];
286- Type type = convertMMAToLLVMType (
279+ LLVM::LLVMStructType type = convertMMAToLLVMType (
287280 cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType ()));
288- // If the element is not a struct, it means it's a scalar f64.
289- auto structType = dyn_cast<LLVM::LLVMStructType>(type);
290- if (!structType) {
291- rewriter.replaceOp (subgroupMmaConstantOp, cst);
292- return success ();
293- }
294281 // If the element type is a vector create a vector from the operand.
295- if (auto vecType = dyn_cast<VectorType>(structType .getBody ()[0 ])) {
282+ if (auto vecType = dyn_cast<VectorType>(type .getBody ()[0 ])) {
296283 Value vecCst = LLVM::PoisonOp::create (rewriter, loc, vecType);
297284 for (int64_t vecEl = 0 ; vecEl < vecType.getNumElements (); vecEl++) {
298285 Value idx = LLVM::ConstantOp::create (rewriter, loc,
@@ -302,8 +289,8 @@ struct WmmaConstantOpToNVVMLowering
302289 }
303290 cst = vecCst;
304291 }
305- Value matrixStruct = LLVM::PoisonOp::create (rewriter, loc, structType );
306- for (size_t i : llvm::seq (size_t (0 ), structType .getBody ().size ())) {
292+ Value matrixStruct = LLVM::PoisonOp::create (rewriter, loc, type );
293+ for (size_t i : llvm::seq (size_t (0 ), type .getBody ().size ())) {
307294 matrixStruct =
308295 LLVM::InsertValueOp::create (rewriter, loc, matrixStruct, cst, i);
309296 }
@@ -367,24 +354,10 @@ struct WmmaElementwiseOpToNVVMLowering
367354 return failure ();
368355 Location loc = subgroupMmaElementwiseOp.getLoc ();
369356 size_t numOperands = adaptor.getOperands ().size ();
370- Type destType = convertMMAToLLVMType (
357+ LLVM::LLVMStructType destType = convertMMAToLLVMType (
371358 cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType ()));
372-
373- // If the element is not a struct, it means it's a scalar f64.
374- LLVM::LLVMStructType structDestTy =
375- dyn_cast<LLVM::LLVMStructType>(destType);
376- if (!structDestTy) {
377- SmallVector<Value> operands;
378- for (auto operand : adaptor.getOperands ()) {
379- operands.push_back (operand);
380- }
381- Value element = createScalarOp (
382- rewriter, loc, subgroupMmaElementwiseOp.getOpType (), operands);
383- rewriter.replaceOp (subgroupMmaElementwiseOp, element);
384- return success ();
385- }
386- Value matrixStruct = LLVM::PoisonOp::create (rewriter, loc, structDestTy);
387- for (size_t i = 0 , e = structDestTy.getBody ().size (); i < e; ++i) {
359+ Value matrixStruct = LLVM::PoisonOp::create (rewriter, loc, destType);
360+ for (size_t i = 0 , e = destType.getBody ().size (); i < e; ++i) {
388361 SmallVector<Value> extractedOperands;
389362 for (size_t opIdx = 0 ; opIdx < numOperands; opIdx++) {
390363 extractedOperands.push_back (LLVM::ExtractValueOp::create (
@@ -404,18 +377,13 @@ struct WmmaElementwiseOpToNVVMLowering
404377} // namespace
405378
406379// / Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
407- Type mlir::convertMMAToLLVMType (gpu::MMAMatrixType type) {
380+ LLVM::LLVMStructType mlir::convertMMAToLLVMType (gpu::MMAMatrixType type) {
408381 NVVM::MMAFrag frag = convertOperand (type.getOperand ());
409382 NVVM::MMATypes eltType = getElementType (type);
410383 auto nRow = type.getShape ()[0 ];
411384 auto nCol = type.getShape ()[1 ];
412385 std::pair<Type, unsigned > typeInfo =
413386 NVVM::inferMMAType (eltType, frag, nRow, nCol, type.getContext ());
414- // Special handling for f64 a and b fragments
415- Type f64Ty = Float64Type::get (type.getContext ());
416- if (typeInfo.first == f64Ty && typeInfo.second == 1 ) {
417- return f64Ty;
418- }
419387 return LLVM::LLVMStructType::getLiteral (
420388 type.getContext (), SmallVector<Type, 8 >(typeInfo.second , typeInfo.first ));
421389}
0 commit comments