@@ -685,23 +685,24 @@ static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
685685// / Maps f8 scale element types to WMMA scale format codes.
686686static std::optional<uint32_t > getWmmaScaleFormat (Type elemType) {
687687 return TypeSwitch<Type, std::optional<uint32_t >>(elemType)
688- .Case <Float8E8M0FNUType> ([](auto ) { return 0 ; })
689- .Case <Float8E4M3FNType> ([](auto ) { return 2 ; })
690- .Default ([](Type) { return std::nullopt ; } );
688+ .Case ([](Float8E8M0FNUType ) { return 0 ; })
689+ .Case ([](Float8E4M3FNType ) { return 2 ; })
690+ .Default (std::nullopt );
691691}
692692
693693// / Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
694694// / and scale vector length.
695695static std::optional<StringRef>
696696getScaledWmmaIntrinsicName (int64_t m, int64_t n, int64_t k, bool isScale16) {
697- if (m == 16 && n == 16 && k == 128 ) {
697+ if (m == 16 && n == 16 && k == 128 )
698698 return isScale16
699699 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName ()
700700 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName ();
701- } else if (m == 32 && n == 16 && k == 128 ) {
701+
702+ if (m == 32 && n == 16 && k == 128 )
702703 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName ()
703704 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName ();
704- }
705+
705706 return std::nullopt ;
706707}
707708
@@ -1457,31 +1458,24 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
14571458 // The f4 variant does not have fmtA and fmtB attributes.
14581459 bool is32x16 = (m == 32 && n == 16 && k == 128 );
14591460 if (!is32x16) {
1460- attrs.emplace_back (
1461- rewriter.getNamedAttr (" fmtA" , rewriter.getI32IntegerAttr (*aFmtCode)));
1462- attrs.emplace_back (
1463- rewriter.getNamedAttr (" fmtB" , rewriter.getI32IntegerAttr (*bFmtCode)));
1461+ attrs.emplace_back (" fmtA" , rewriter.getI32IntegerAttr (*aFmtCode));
1462+ attrs.emplace_back (" fmtB" , rewriter.getI32IntegerAttr (*bFmtCode));
14641463 }
14651464
14661465 // modC uses default value of 0.
1467- attrs.emplace_back (
1468- rewriter.getNamedAttr (" modC" , rewriter.getI16IntegerAttr (0 )));
1466+ attrs.emplace_back (" modC" , rewriter.getI16IntegerAttr (0 ));
14691467
14701468 // Scale attributes.
1471- attrs.emplace_back (rewriter.getNamedAttr (
1472- " scaleAType" , rewriter.getI32IntegerAttr (op.getScaleAIdx ())));
1473- attrs.emplace_back (rewriter.getNamedAttr (
1474- " fmtScaleA" , rewriter.getI32IntegerAttr (*scaleAFmt)));
1475- attrs.emplace_back (rewriter.getNamedAttr (
1476- " scaleBType" , rewriter.getI32IntegerAttr (op.getScaleBIdx ())));
1477- attrs.emplace_back (rewriter.getNamedAttr (
1478- " fmtScaleB" , rewriter.getI32IntegerAttr (*scaleBFmt)));
1469+ attrs.emplace_back (" scaleAType" ,
1470+ rewriter.getI32IntegerAttr (op.getScaleAIdx ()));
1471+ attrs.emplace_back (" fmtScaleA" , rewriter.getI32IntegerAttr (*scaleAFmt));
1472+ attrs.emplace_back (" scaleBType" ,
1473+ rewriter.getI32IntegerAttr (op.getScaleBIdx ()));
1474+ attrs.emplace_back (" fmtScaleB" , rewriter.getI32IntegerAttr (*scaleBFmt));
14791475
14801476 // Reuse flags use default value of false.
1481- attrs.emplace_back (
1482- rewriter.getNamedAttr (" reuseA" , rewriter.getBoolAttr (false )));
1483- attrs.emplace_back (
1484- rewriter.getNamedAttr (" reuseB" , rewriter.getBoolAttr (false )));
1477+ attrs.emplace_back (" reuseA" , rewriter.getBoolAttr (false ));
1478+ attrs.emplace_back (" reuseB" , rewriter.getBoolAttr (false ));
14851479
14861480 // Convert typed float vectors to packed format.
14871481 Value sourceA =
0 commit comments