Skip to content

Commit c1e65fb

Browse files
committed
Code formatting changes
1 parent d23f755 commit c1e65fb

File tree

2 files changed

+28
-37
lines changed

2 files changed

+28
-37
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -685,23 +685,24 @@ static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
685685
/// Maps f8 scale element types to WMMA scale format codes.
686686
static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
687687
return TypeSwitch<Type, std::optional<uint32_t>>(elemType)
688-
.Case<Float8E8M0FNUType>([](auto) { return 0; })
689-
.Case<Float8E4M3FNType>([](auto) { return 2; })
690-
.Default([](Type) { return std::nullopt; });
688+
.Case([](Float8E8M0FNUType) { return 0; })
689+
.Case([](Float8E4M3FNType) { return 2; })
690+
.Default(std::nullopt);
691691
}
692692

693693
/// Determines the ROCDL intrinsic name for scaled WMMA based on dimensions
694694
/// and scale vector length.
695695
static std::optional<StringRef>
696696
getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16) {
697-
if (m == 16 && n == 16 && k == 128) {
697+
if (m == 16 && n == 16 && k == 128)
698698
return isScale16
699699
? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
700700
: ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
701-
} else if (m == 32 && n == 16 && k == 128) {
701+
702+
if (m == 32 && n == 16 && k == 128)
702703
return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
703704
: ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
704-
}
705+
705706
return std::nullopt;
706707
}
707708

@@ -1457,31 +1458,24 @@ struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
14571458
// The f4 variant does not have fmtA and fmtB attributes.
14581459
bool is32x16 = (m == 32 && n == 16 && k == 128);
14591460
if (!is32x16) {
1460-
attrs.emplace_back(
1461-
rewriter.getNamedAttr("fmtA", rewriter.getI32IntegerAttr(*aFmtCode)));
1462-
attrs.emplace_back(
1463-
rewriter.getNamedAttr("fmtB", rewriter.getI32IntegerAttr(*bFmtCode)));
1461+
attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
1462+
attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
14641463
}
14651464

14661465
// modC uses default value of 0.
1467-
attrs.emplace_back(
1468-
rewriter.getNamedAttr("modC", rewriter.getI16IntegerAttr(0)));
1466+
attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));
14691467

14701468
// Scale attributes.
1471-
attrs.emplace_back(rewriter.getNamedAttr(
1472-
"scaleAType", rewriter.getI32IntegerAttr(op.getScaleAIdx())));
1473-
attrs.emplace_back(rewriter.getNamedAttr(
1474-
"fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt)));
1475-
attrs.emplace_back(rewriter.getNamedAttr(
1476-
"scaleBType", rewriter.getI32IntegerAttr(op.getScaleBIdx())));
1477-
attrs.emplace_back(rewriter.getNamedAttr(
1478-
"fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt)));
1469+
attrs.emplace_back("scaleAType",
1470+
rewriter.getI32IntegerAttr(op.getScaleAIdx()));
1471+
attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
1472+
attrs.emplace_back("scaleBType",
1473+
rewriter.getI32IntegerAttr(op.getScaleBIdx()));
1474+
attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));
14791475

14801476
// Reuse flags use default value of false.
1481-
attrs.emplace_back(
1482-
rewriter.getNamedAttr("reuseA", rewriter.getBoolAttr(false)));
1483-
attrs.emplace_back(
1484-
rewriter.getNamedAttr("reuseB", rewriter.getBoolAttr(false)));
1477+
attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
1478+
attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));
14851479

14861480
// Convert typed float vectors to packed format.
14871481
Value sourceA =

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -448,17 +448,13 @@ LogicalResult WMMAOp::verify() {
448448

449449
LogicalResult ScaledWMMAOp::verify() {
450450
// Helper functions for type classification.
451-
auto isF8 = [](Type t) {
452-
return isa<Float8E4M3FNType, Float8E5M2Type, Float8E8M0FNUType,
453-
Float8E4M3FNUZType, Float8E5M2FNUZType>(t);
454-
};
455-
auto isF6 = [](Type t) { return isa<Float6E2M3FNType, Float6E3M2FNType>(t); };
456-
auto isF4 = [](Type t) { return isa<Float4E2M1FNType>(t); };
451+
auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
452+
auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
453+
auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
457454
auto isSmallFloat = [&](Type t) { return isF4(t) || isF6(t) || isF8(t); };
458-
auto isE8M0 = [](Type t) { return isa<Float8E8M0FNUType>(t); };
459-
auto isE4M3 = [](Type t) {
460-
return isa<Float8E4M3FNType, Float8E4M3FNUZType>(t);
461-
};
455+
auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
456+
auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
457+
auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;
462458

463459
auto sourceAType = cast<VectorType>(getSourceA().getType());
464460
auto sourceBType = cast<VectorType>(getSourceB().getType());
@@ -517,9 +513,10 @@ LogicalResult ScaledWMMAOp::verify() {
517513
Type scaleAElemType = scaleAType.getElementType();
518514
Type scaleBElemType = scaleBType.getElementType();
519515

520-
// Validate scale element types are valid f8 types.
521-
if (!isF8(scaleAElemType) || !isF8(scaleBElemType))
522-
return emitOpError("scale operands must have f8 element types");
516+
// Validate scale element types are valid scale f8 types (E8M0FNU or E4M3FN).
517+
if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType))
518+
return emitOpError(
519+
"scale operands must have f8 element types (E8M0FNU or E4M3FN)");
523520

524521
// Any matrices A/B (fp8|fp6|fp4) with E8M0 scales for matrix A/B are valid.
525522
if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))

0 commit comments

Comments
 (0)