2323
2424#include " llvm/ADT/STLExtras.h"
2525#include " llvm/ADT/TypeSwitch.h"
26+ #include " llvm/Support/Casting.h"
2627#include < optional>
2728
2829namespace mlir {
@@ -826,19 +827,20 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
826827}
827828
828829static std::optional<std::tuple<StringRef, uint32_t , uint32_t >>
829- mfmaOpToScaledIntrinsic (MFMAOp mfma, Chipset chipset) {
830- return mfmaOpToScaledIntrinsic (
831- mfma.getSourceA ().getType (), mfma.getSourceB ().getType (),
832- mfma.getDestC ().getType (), mfma.getM (), mfma.getN (), mfma.getK (),
833- mfma.getBlocks (), chipset);
834- }
835-
836- static std::optional<std::tuple<StringRef, uint32_t , uint32_t >>
837- mfmaOpToScaledIntrinsic (ScaledMFMAOp smfma, Chipset chipset) {
838- return mfmaOpToScaledIntrinsic (smfma.getSourceA ().getType (),
839- smfma.getSourceB ().getType (),
840- smfma.getDestC ().getType (), smfma.getM (),
841- smfma.getN (), smfma.getK (), 1u , chipset);
830+ mfmaOpToScaledIntrinsic (Operation *op, Chipset chipset) {
831+ if (auto mfma = llvm::dyn_cast_or_null<MFMAOp>(op)) {
832+ return mfmaOpToScaledIntrinsic (
833+ mfma.getSourceA ().getType (), mfma.getSourceB ().getType (),
834+ mfma.getDestC ().getType (), mfma.getM (), mfma.getN (), mfma.getK (),
835+ mfma.getBlocks (), chipset);
836+ }
837+ if (auto smfma = llvm::dyn_cast_or_null<ScaledMFMAOp>(op)) {
838+ return mfmaOpToScaledIntrinsic (smfma.getSourceA ().getType (),
839+ smfma.getSourceB ().getType (),
840+ smfma.getDestC ().getType (), smfma.getM (),
841+ smfma.getN (), smfma.getK (), 1u , chipset);
842+ }
843+ return std::nullopt ;
842844}
843845
844846// / Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
@@ -964,7 +966,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
964966
965967struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern <ScaledMFMAOp> {
966968 ScaledMFMAOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
967- : ConvertOpToLLVMPattern<ScaledMFMAOp> (converter), chipset(chipset) {}
969+ : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
968970
969971 Chipset chipset;
970972
@@ -986,7 +988,7 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
986988 return op.emitOpError (
987989 " no intrinsic matching Scaled MFMA size on given chipset" );
988990
989- StringRef intrinsicName = std::get< 0 >( *maybeScaledIntrinsic) ;
991+ auto [ intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
990992 OperationState loweredOp (loc, intrinsicName);
991993 loweredOp.addTypes (intrinsicOutType);
992994 loweredOp.addOperands (
@@ -997,7 +999,6 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
997999 Value scaleB = createI32Constant (rewriter, loc, adaptor.getScaleB ());
9981000 Value opselA = createI32Constant (rewriter, loc, adaptor.getOpselA ());
9991001 Value opselB = createI32Constant (rewriter, loc, adaptor.getOpselB ());
1000- auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
10011002 loweredOp.addOperands ({createI32Constant (rewriter, loc, aTypeCode),
10021003 createI32Constant (rewriter, loc, bTypeCode),
10031004 /* scale A byte=*/ opselA, /* scale A=*/ scaleA,
0 commit comments