From 2b6d91708127fb3da9f648c778037981657c3ad1 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed
Date: Thu, 28 Aug 2025 17:01:09 -0700
Subject: [PATCH 1/7] Add packing of scales for ScaledMFMAOp

Signed-off-by: Muzammiluddin Syed
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |   1 +
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 142 ++++++++++++++++++
 mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt     |   1 +
 mlir/test/Dialect/AMDGPU/canonicalize.mlir    |  25 +++
 4 files changed, 169 insertions(+)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 2ccf350a359a8..a24a918357f2d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
     attr-dict
     `:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
   }];
+  let hasCanonicalizer = 1;
 }
 #endif // AMDGPU
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 11a40d663a201..4107ec53a0988 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -28,6 +29,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <cstdint>
 #include <limits>
 #include <optional>
 
@@ -631,6 +633,146 @@ LogicalResult TransposeLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Check if the scales input is used in other scaled mfma's while they exist.
+/// If they're unused then pack the scales.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // If this use of a scale has a non zero opsel, packing has already been
+    // done.
+    auto checkIfUnpackable = [&](OpOperand &op) {
+      if (auto smfma = dyn_cast<ScaledMFMAOp>(op.getOwner())) {
+        switch (op.getOperandNumber()) {
+        case 3:
+          return smfma.getScalesIdxA() != 0;
+          break;
+        case 4:
+          return smfma.getScalesIdxB() != 0;
+          break;
+        default:
+          return true;
+          break;
+        }
+      }
+    };
+
+    auto setOpsel = [&](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        return op.setScalesIdxA(val);
+        break;
+      case 4:
+        return op.setScalesIdxB(val);
+        break;
+      default:
+        break;
+      }
+    };
+
+    // Obtain flat index from offsets and shape.
+    auto getIdxFromExtract = [](vector::ExtractOp op) {
+      ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
+      int cumul = 1;
+      int idx = 0;
+      for (auto [offset, size] :
+           reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
+        idx += offset * cumul;
+        cumul *= size;
+      }
+      return idx;
+    };
+
+    // Obtain offsets for new shape from flat index.
+    auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
+      SmallVector<int64_t> res;
+      ShapedType shapedty = static_cast<ShapedType>(ty);
+      int64_t numElements = shapedty.getNumElements();
+      for (auto size : shapedty.getShape()) {
+        numElements /= size;
+        res.push_back(idx / numElements);
+        idx -= (idx / numElements) * size;
+      }
+      return res;
+    };
+
+    // For every scale operand of this ScaledMFMAOp, if the scale follows the
+    // following pattern:
+    //
+    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<...>
+    // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0] * ...
+    //
+    // rewrite to:
+    //
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?x4xf8E8M0FNU>
+    // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0-3] * ...
+    //
+    // This creates duplicate shape_casts for every use but these will be removed in CSE.
+    for (auto opIdx : SmallVector<int64_t>({3, 4})) {
+      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+      if (!insertOp) {
+        return failure();
+      }
+      if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
+        return failure();
+      }
+
+      auto extractOp =
+          insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+      if (!extractOp) {
+        return failure();
+      }
+
+      Value scaleSrc = extractOp.getOperand(0);
+      auto stype = dyn_cast<ShapedType>(scaleSrc.getType());
+      if (!stype) {
+        return failure();
+      }
+      // We do not handle dynamic dims yet, assume that the input is padded to
+      // a static shape now.
+      if (llvm::any_of(llvm::seq<int64_t>(0, stype.getRank()),
+                       [&](int64_t i) { return stype.isDynamicDim(i); })) {
+        return failure();
+      }
+
+      int64_t numElements = stype.getNumElements();
+      if (numElements <= 4) {
+        return failure();
+      }
+
+      Type newSrcType = VectorType::get(
+          SmallVector<int64_t>({numElements / 4, 4}), stype.getElementType());
+      Value newScaleSrc =
+          rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
+      int64_t idx = getIdxFromExtract(extractOp);
+      SmallVector<int64_t> offsets(getOffsetsFromIdx(idx, newSrcType));
+      auto scaleTy = VectorType::get({4}, stype.getElementType());
+      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
+          SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
+      Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
+      op.setOperand(opIdx, scale);
+      setOpsel(opIdx, offsets[1]);
+    }
+    return success();
+  }
+};
+} // namespace
+
+void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<PackScales>(context);
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 2a019954c8356..5d14a05945e95 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
   MLIRROCDLDialect
   # Needed for GPU address space enum definition
   MLIRGPUDialect
+  MLIRVectorDialect
   MLIRIR
   MLIRSideEffectInterfaces
   MLIRMemRefUtils
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 5501ad42dbd90..75cbf29c95f29 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -159,3 +159,28 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
     : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma
+// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
+// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
+func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
+}
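The net effect of PATCH 1/7, condensed from the test above into a plain before/after sketch (a minimal illustration; the SSA names and the elided right-hand operands are not part of the patch):

// Before: each scale reaches amdgpu.scaled_mfma as a single element parked in
// lane 0 of its own vector<4xf8E8M0FNU>, wasting the other three lanes.
%s = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%v = vector.insert %s, %cst [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%r = amdgpu.scaled_mfma(%v[0] * %opA) * ...

// After: the scale source is reshaped into rows of four, the whole row is
// extracted, and the opsel index picks the lane (flat index 3 = row 0, lane 3).
%reshaped = vector.shape_cast %scalesA : vector<2x1x8x1xf8E8M0FNU> to vector<4x4xf8E8M0FNU>
%row = vector.extract %reshaped[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
%r = amdgpu.scaled_mfma(%row[3] * %opA) * ...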
From 3873edac6f205ac98808103bfdb1251eedfadf99 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed
Date: Wed, 17 Sep 2025 14:36:47 -0500
Subject: [PATCH 2/7] PR review round 0

Signed-off-by: Muzammiluddin Syed
---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 26 ++++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 4107ec53a0988..2e3f95651902e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -653,26 +653,24 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
         switch (op.getOperandNumber()) {
         case 3:
           return smfma.getScalesIdxA() != 0;
-          break;
         case 4:
           return smfma.getScalesIdxB() != 0;
-          break;
         default:
-          return true;
           break;
         }
       }
+      return true;
     };
 
     auto setOpsel = [&](unsigned idx, int64_t val) {
       switch (idx) {
       case 3:
-        return op.setScalesIdxA(val);
+        op.setScalesIdxA(val);
         break;
       case 4:
-        return op.setScalesIdxB(val);
+        op.setScalesIdxB(val);
         break;
-      default: 
+      default:
         break;
       }
     };
@@ -695,7 +693,7 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       SmallVector<int64_t> res;
       ShapedType shapedty = static_cast<ShapedType>(ty);
       int64_t numElements = shapedty.getNumElements();
-      for (auto size : shapedty.getShape()) {
+      for (unsigned size : shapedty.getShape()) {
         numElements /= size;
         res.push_back(idx / numElements);
         idx -= (idx / numElements) * size;
@@ -706,17 +704,19 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
     // For every scale operand of this ScaledMFMAOp, if the scale follows the
     // following pattern:
     //
-    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<...>
-    // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
-    // amdgpu.scaled_mfma(%scale[0] * ...
+    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from
+    // vector<...> %scale = vector.insert %unit, ... : f8E8M0FNU
+    // into vector<4xf8E8M0FNU> amdgpu.scaled_mfma(%scale[0] * ...
     //
     // rewrite to:
     //
-    // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?x4xf8E8M0FNU>
-    // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to
+    // vector<?x4xf8E8M0FNU> %scale = vector.extract %reshaped[?] :
+    // vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
     // amdgpu.scaled_mfma(%scale[0-3] * ...
     //
-    // This creates duplicate shape_casts for every use but these will be removed in CSE.
+    // This creates duplicate shape_casts for every use but these will be
+    // removed in CSE.

From 2404d99880984e087cbe9d62faf1d3a4205effd8 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed
Date: Wed, 17 Sep 2025 15:31:56 -0500
Subject: [PATCH 3/7] PR review round 1

Signed-off-by: Muzammiluddin Syed
---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 33 +++++++++++++--------
 1 file changed, 19 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2e3f95651902e..1cc800ec92090 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -670,7 +670,7 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       case 4:
         op.setScalesIdxB(val);
         break;
-      default: 
+      default:
        break;
       }
     };
@@ -678,8 +678,8 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
     // Obtain flat index from offsets and shape.
     auto getIdxFromExtract = [](vector::ExtractOp op) {
       ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
-      int cumul = 1;
-      int idx = 0;
+      int64_t cumul = 1;
+      int64_t idx = 0;
       for (auto [offset, size] :
            reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
         idx += offset * cumul;
@@ -720,33 +720,37 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
     for (auto opIdx : SmallVector<int64_t>({3, 4})) {
       auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
       if (!insertOp) {
-        return failure();
+        return rewriter.notifyMatchFailure(op,
+                                           "defining op not a vector.insert");
       }
       if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
-        return failure();
+        return rewriter.notifyMatchFailure(op,
+                                           "some scaled mfma's already packed");
       }
 
       auto extractOp =
           insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
       if (!extractOp) {
-        return failure();
+        return rewriter.notifyMatchFailure(op,
+                                           "defining op not a vector.extract");
       }
 
       Value scaleSrc = extractOp.getOperand(0);
-      auto stype = dyn_cast<ShapedType>(scaleSrc.getType());
+      auto stype = dyn_cast<VectorType>(scaleSrc.getType());
       if (!stype) {
-        return failure();
+        return rewriter.notifyMatchFailure(op, "not a shaped type");
       }
       // We do not handle dynamic dims yet, assume that the input is padded to
       // a static shape now.
-      if (llvm::any_of(llvm::seq<int64_t>(0, stype.getRank()),
-                       [&](int64_t i) { return stype.isDynamicDim(i); })) {
-        return failure();
+      if (!stype.hasStaticShape()) {
+        return rewriter.notifyMatchFailure(op,
+                                           "dynamic dims not yet supported");
       }
 
       int64_t numElements = stype.getNumElements();
-      if (numElements <= 4) {
-        return failure();
+      if (numElements <= 4 || !(numElements % 4)) {
+        return rewriter.notifyMatchFailure(
+            op, "no packing if # of scales less than or indivisible by four");
       }
 
       Type newSrcType = VectorType::get(
@@ -760,7 +764,8 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
           loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
           SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
       Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
-      op.setOperand(opIdx, scale);
+      rewriter.modifyOpInPlace(
+          op, [&op, opIdx, scale] { op->setOperand(opIdx, scale); });
       setOpsel(opIdx, offsets[1]);
     }
     return success();

From 970aa1a9e55d00e5d9df108f4a0015a87234bd88 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed
Date: Wed, 17 Sep 2025 17:47:41 -0500
Subject: [PATCH 4/7] Perform packing for inputs with shapes non-divisible by 4

Signed-off-by: Muzammiluddin Syed
---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 56 +++++------
 mlir/test/Dialect/AMDGPU/canonicalize.mlir   | 97 +++++++++++++++++++-
 2 files changed, 118 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 1cc800ec92090..e04a1d75724fb 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 #include <cstdint>
@@ -688,36 +689,22 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       return idx;
     };
 
-    // Obtain offsets for new shape from flat index.
-    auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
-      SmallVector<int64_t> res;
-      ShapedType shapedty = static_cast<ShapedType>(ty);
-      int64_t numElements = shapedty.getNumElements();
-      for (unsigned size : shapedty.getShape()) {
-        numElements /= size;
-        res.push_back(idx / numElements);
-        idx -= (idx / numElements) * size;
-      }
-      return res;
-    };
-
     // For every scale operand of this ScaledMFMAOp, if the scale follows the
     // following pattern:
-    //
-    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from
-    // vector<...> %scale = vector.insert %unit, ... : f8E8M0FNU
-    // into vector<4xf8E8M0FNU> amdgpu.scaled_mfma(%scale[0] * ...
+    // (f8 here means f8E8M0FNU)
+    // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
+    // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
+    // amdgpu.scaled_mfma(%scale[0] * ...
     //
     // rewrite to:
     //
-    // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to
-    // vector<?x4xf8E8M0FNU> %scale = vector.extract %reshaped[?] :
-    // vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
+    // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
     // amdgpu.scaled_mfma(%scale[0-3] * ...
    //
     // This creates duplicate shape_casts for every use but these will be
     // removed in CSE.
-    for (auto opIdx : SmallVector<int64_t>({3, 4})) {
+    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
       auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
       if (!insertOp) {
         return rewriter.notifyMatchFailure(op,
                                            "defining op not a vector.insert");
@@ -738,7 +725,7 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       Value scaleSrc = extractOp.getOperand(0);
       auto stype = dyn_cast<VectorType>(scaleSrc.getType());
       if (!stype) {
-        return rewriter.notifyMatchFailure(op, "not a shaped type");
+        return rewriter.notifyMatchFailure(op, "not a vector type");
       }
       // We do not handle dynamic dims yet, assume that the input is padded to
       // a static shape now.
@@ -748,25 +735,32 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       }
 
       int64_t numElements = stype.getNumElements();
-      if (numElements <= 4 || !(numElements % 4)) {
+      if (numElements <= 4) {
         return rewriter.notifyMatchFailure(
-            op, "no packing if # of scales less than or indivisible by four");
+            op, "no packing if # of scales less than four");
       }
+      int64_t idx = getIdxFromExtract(extractOp);
+      int64_t offset = idx - (idx % 4);
+      int64_t size = std::min(4l, numElements - offset);
+      int64_t opsel = idx - offset;
+      if (size != 4l) {
+        opsel += 4l - size;
+        offset = numElements - 4l;
+        size = 4l;
+      }
 
-      Type newSrcType = VectorType::get(
-          SmallVector<int64_t>({numElements / 4, 4}), stype.getElementType());
+      Type newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
+                                        stype.getElementType());
       Value newScaleSrc =
           rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
-      int64_t idx = getIdxFromExtract(extractOp);
-      SmallVector<int64_t> offsets(getOffsetsFromIdx(idx, newSrcType));
       auto scaleTy = VectorType::get({4}, stype.getElementType());
       Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
-          SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
+          loc, newScaleSrc, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{size},
+          ArrayRef<int64_t>{1});
       Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
       rewriter.modifyOpInPlace(
           op, [&op, opIdx, scale] { op->setOperand(opIdx, scale); });
-      setOpsel(opIdx, offsets[1]);
+      setOpsel(opIdx, opsel);
     }
     return success();
   }
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 75cbf29c95f29..8179d8e0ce513 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -163,11 +163,11 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
 // -----
 
 // CHECK-LABEL: func @scaled_mfma
-// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
-// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_1:.*]] = vector.extract_strided_slice %0 {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: %[[SCALE_2:.*]] = vector.extract_strided_slice %2 {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
 // CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
-// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
-// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_3:.*]] = vector.extract_strided_slice %5 {offsets = [8], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: %[[SCALE_4:.*]] = vector.extract_strided_slice %7 {offsets = [12], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
 // CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
 func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
   %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
   %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
   %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
   %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
   %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
   %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma_less_than_4
+// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
+// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
+// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
+// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
+func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2xf8E8M0FNU>, %scalesB: vector<2xf8E8M0FNU>) -> vector<4xf32> {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA = vector.extract %scalesA[0] : f8E8M0FNU from vector<2xf8E8M0FNU>
+  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleB = vector.extract %scalesB[1] : f8E8M0FNU from vector<2xf8E8M0FNU>
+  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0 : vector<4xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma_ugly_shapes
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[0] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA_0_0 = vector.extract %scalesA[0, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_1 = vector.extract %scalesA[1, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_2 = vector.extract %scalesA[2, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_3 = vector.extract %scalesA[3, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_4 = vector.extract %scalesA[4, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_5 = vector.extract %scalesA[4, 1] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_6 = vector.extract %scalesA[4, 2] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_7 = vector.extract %scalesA[4, 3] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+
+  // idx = 138 + 8 = 146 => opsel = 2
+  %scaleB_6_8 = vector.extract %scalesB[6, 8] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 147 => opsel = 3
+  %scaleB_6_9 = vector.extract %scalesB[6, 9] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 148 => opsel = 0
+  %scaleB_6_10 = vector.extract %scalesB[6, 10] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 149 => opsel = 1
+  %scaleB_6_11 = vector.extract %scalesB[6, 11] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 160 => opsel = 3 (last idx of last 4 bytes)
+  %scaleB_6_22 = vector.extract %scalesB[6, 22] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 159 => opsel = 3
+  %scaleB_6_21 = vector.extract %scalesB[6, 21] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 158 => opsel = 2
+  %scaleB_6_20 = vector.extract %scalesB[6, 20] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 157 => opsel = 1
+  %scaleB_6_19 = vector.extract %scalesB[6, 19] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+
+  %sA_0_0 = vector.insert %scaleA_0_0, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_1 = vector.insert %scaleA_0_1, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_2 = vector.insert %scaleA_0_2, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_3 = vector.insert %scaleA_0_3, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_4 = vector.insert %scaleA_0_4, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_5 = vector.insert %scaleA_0_5, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_6 = vector.insert %scaleA_0_6, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_7 = vector.insert %scaleA_0_7, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+
+  %sB_6_8 = vector.insert %scaleB_6_8, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_9 = vector.insert %scaleB_6_9, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_10 = vector.insert %scaleB_6_10, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_11 = vector.insert %scaleB_6_11, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_22 = vector.insert %scaleB_6_22, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+
+  %res_0 = amdgpu.scaled_mfma(%sA_0_0[0] * %opA) * (%sB_6_8[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_1 = amdgpu.scaled_mfma(%sA_0_1[0] * %opA) * (%sB_6_9[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_2 = amdgpu.scaled_mfma(%sA_0_2[0] * %opA) * (%sB_6_10[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_3 = amdgpu.scaled_mfma(%sA_0_3[0] * %opA) * (%sB_6_11[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0, %res_1, %res_2, %res_3, %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+}
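To make the tail handling in PATCH 4/7 concrete, here is its offset/opsel arithmetic traced by hand for the vector<7x23xf8E8M0FNU> scales of @scaled_mfma_ugly_shapes (a worked example, not part of the patch):

// numElements = 7 * 23 = 161
// %scalesB[6, 8]:  idx = 6 * 23 + 8 = 146
//   offset = 146 - (146 % 4) = 144, size = min(4, 161 - 144) = 4, opsel = 146 - 144 = 2
// %scalesB[6, 22]: idx = 160, offset = 160, size = min(4, 161 - 160) = 1
//   size != 4, so opsel = 0 + (4 - 1) = 3 and offset = 161 - 4 = 157
// The trailing partial group is re-anchored so that a full vector<4xf8E8M0FNU>
// slice over elements [157, 161) is still extracted, with opsel selecting 160.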
From ab6b1aea263f4d3f1753b79f15c5fac35cfdb78c Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed
Date: Thu, 18 Sep 2025 00:36:59 -0500
Subject: [PATCH 5/7] PR Review round 2

Signed-off-by: Muzammiluddin Syed
---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 104 +++++++++----------
 mlir/test/Dialect/AMDGPU/canonicalize.mlir   |  33 +-----
 2 files changed, 52 insertions(+), 85 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e04a1d75724fb..49e0cec1373ec 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -30,6 +30,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <array>
 #include <cstdint>
 #include <limits>
 #include <optional>
@@ -647,22 +648,6 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
   LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    // If this use of a scale has a non zero opsel, packing has already been
-    // done.
-    auto checkIfUnpackable = [&](OpOperand &op) {
-      if (auto smfma = dyn_cast<ScaledMFMAOp>(op.getOwner())) {
-        switch (op.getOperandNumber()) {
-        case 3:
-          return smfma.getScalesIdxA() != 0;
-        case 4:
-          return smfma.getScalesIdxB() != 0;
-        default:
-          break;
-        }
-      }
-      return true;
-    };
-
     auto setOpsel = [&](unsigned idx, int64_t val) {
       switch (idx) {
       case 3:
@@ -676,22 +661,11 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       }
     };
 
-    // Obtain flat index from offsets and shape.
-    auto getIdxFromExtract = [](vector::ExtractOp op) {
-      ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
-      int64_t cumul = 1;
-      int64_t idx = 0;
-      for (auto [offset, size] :
-           reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
-        idx += offset * cumul;
-        cumul *= size;
-      }
-      return idx;
-    };
-
-    // For every scale operand of this ScaledMFMAOp, if the scale follows the
-    // following pattern:
-    // (f8 here means f8E8M0FNU)
+    // For every scale operand of this ScaledMFMAOp, if the scale is produced by
+    // the extraction of a single scale from some vector, then attempt to
+    // extract 4 values from that vector instead.
+    //
+    // Example: (f8 here means f8E8M0FNU)
     // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
     // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
     // amdgpu.scaled_mfma(%scale[0] * ...
     //
     // rewrite to:
     //
     // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
     // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
     // amdgpu.scaled_mfma(%scale[0-3] * ...
     //
     // This creates duplicate shape_casts for every use but these will be
     // removed in CSE.
     for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
       auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
       if (!insertOp) {
         return rewriter.notifyMatchFailure(op,
                                            "defining op not a vector.insert");
       }
-      if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
-        return rewriter.notifyMatchFailure(op,
-                                           "some scaled mfma's already packed");
+      // if the extracted value is not a single scalar, then it has been packed.
+      if (dyn_cast<VectorType>(insertOp.getValueToStore().getType())) {
+        return rewriter.notifyMatchFailure(
+            op, "scaled mfma operand already packed");
       }
 
       auto extractOp =
-          insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
       if (!extractOp) {
         return rewriter.notifyMatchFailure(op,
                                            "defining op not a vector.extract");
       }
 
       Value scaleSrc = extractOp.getOperand(0);
-      auto stype = dyn_cast<VectorType>(scaleSrc.getType());
-      if (!stype) {
+      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
+      if (!scaleSrcType) {
         return rewriter.notifyMatchFailure(op, "not a vector type");
       }
+
       // We do not handle dynamic dims yet, assume that the input is padded to
       // a static shape now.
-      if (!stype.hasStaticShape()) {
+      if (!scaleSrcType.hasStaticShape()) {
         return rewriter.notifyMatchFailure(op,
                                            "dynamic dims not yet supported");
       }
 
-      int64_t numElements = stype.getNumElements();
+      int64_t numElements = scaleSrcType.getNumElements();
       if (numElements <= 4) {
         return rewriter.notifyMatchFailure(
             op, "no packing if # of scales less than four");
       }
-      int64_t idx = getIdxFromExtract(extractOp);
+
+      // Find a linearized idx using the size and offsets of the extract op
+      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
+      int64_t scaleSrcRank = scaleSrcType.getRank();
+      SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
+      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
+      std::reverse(extractedPos.begin(), extractedPos.end());
+      for (int64_t i = 1; i < scaleSrcRank; i++) {
+        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
+      }
+      int64_t idx = linearize(extractedPos, extractSizes);
+
+      // All n scales (where n is the total number of scales) must now be
+      // extracted in chunks of 4 elements. This is done by dividing the
+      // original vector of scales into groups of 4 elements
+      // at offsets 0, 4, ..., m (where m = n/4). All extractions of a
+      // scale at a particular index are now replaced with an extraction
+      // of the entire group of 4 elements to which that index belongs.
+      //
+      // If the number of scales happens to be indivisible by 4, extract
+      // the remaining n - m scales in a chunk of 4 elements starting at
+      // offset n - 4.
       int64_t offset = idx - (idx % 4);
-      int64_t size = std::min(4l, numElements - offset);
       int64_t opsel = idx - offset;
-      if (size != 4l) {
-        opsel += 4l - size;
+      int64_t size = 4l;
+      // Accommodate remaining elements in the case of non-4-divisible vectors.
+      if (numElements - offset < size) {
+        opsel = size - (numElements - idx);
         offset = numElements - 4l;
-        size = 4l;
       }
-
-      Type newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
-                                        stype.getElementType());
+      Type scaleSrcElemType = scaleSrcType.getElementType();
+      auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
+                                        scaleSrcElemType);
       Value newScaleSrc =
           rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
-      auto scaleTy = VectorType::get({4}, stype.getElementType());
-      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+      auto extract = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, newScaleSrc, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{size},
          ArrayRef<int64_t>{1});
-      Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
-      rewriter.modifyOpInPlace(
-          op, [&op, opIdx, scale] { op->setOperand(opIdx, scale); });
-      setOpsel(opIdx, opsel);
+      rewriter.modifyOpInPlace(op, [&] {
+        op->setOperand(opIdx, extract);
+        setOpsel(opIdx, opsel);
+      });
     }
     return success();
   }
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 8179d8e0ce513..52d3275dab43b 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -204,38 +204,21 @@ func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
   return %res_0 : vector<4xf32>
 }
 
-
 // -----
 
 // CHECK-LABEL: func @scaled_mfma_ugly_shapes
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[0] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
 // CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
 // CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
 // CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
 // CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
+func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
   %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
-  %scaleA_0_0 = vector.extract %scalesA[0, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
-  %scaleA_0_1 = vector.extract %scalesA[1, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
-  %scaleA_0_2 = vector.extract %scalesA[2, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
-  %scaleA_0_3 = vector.extract %scalesA[3, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
   %scaleA_0_4 = vector.extract %scalesA[4, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
   %scaleA_0_5 = vector.extract %scalesA[4, 1] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
   %scaleA_0_6 = vector.extract %scalesA[4, 2] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
   %scaleA_0_7 = vector.extract %scalesA[4, 3] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
 
-  // idx = 138 + 8 = 146 => opsel = 2
-  %scaleB_6_8 = vector.extract %scalesB[6, 8] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
-  // idx = 147 => opsel = 3
-  %scaleB_6_9 = vector.extract %scalesB[6, 9] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
-  // idx = 148 => opsel = 0
-  %scaleB_6_10 = vector.extract %scalesB[6, 10] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
-  // idx = 149 => opsel = 1
-  %scaleB_6_11 = vector.extract %scalesB[6, 11] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
   // idx = 160 => opsel = 3 (last idx of last 4 bytes)
   %scaleB_6_22 = vector.extract %scalesB[6, 22] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
   // idx = 159 => opsel = 3
@@ -245,31 +228,19 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
   // idx = 157 => opsel = 1
   %scaleB_6_19 = vector.extract %scalesB[6, 19] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
 
-  %sA_0_0 = vector.insert %scaleA_0_0, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %sA_0_1 = vector.insert %scaleA_0_1, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %sA_0_2 = vector.insert %scaleA_0_2, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %sA_0_3 = vector.insert %scaleA_0_3, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sA_0_4 = vector.insert %scaleA_0_4, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sA_0_5 = vector.insert %scaleA_0_5, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sA_0_6 = vector.insert %scaleA_0_6, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sA_0_7 = vector.insert %scaleA_0_7, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
 
-  %sB_6_8 = vector.insert %scaleB_6_8, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %sB_6_9 = vector.insert %scaleB_6_9, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %sB_6_10 = vector.insert %scaleB_6_10, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %sB_6_11 = vector.insert %scaleB_6_11, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sB_6_22 = vector.insert %scaleB_6_22, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
 
-  %res_0 = amdgpu.scaled_mfma(%sA_0_0[0] * %opA) * (%sB_6_8[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  %res_1 = amdgpu.scaled_mfma(%sA_0_1[0] * %opA) * (%sB_6_9[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  %res_2 = amdgpu.scaled_mfma(%sA_0_2[0] * %opA) * (%sB_6_10[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  %res_3 = amdgpu.scaled_mfma(%sA_0_3[0] * %opA) * (%sB_6_11[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   %res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   %res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   %res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   %res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  return %res_0, %res_1, %res_2, %res_3, %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+  return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
 }

From 9d8ffbdd5387df486bb44c902845f4a5da10e3e5 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed
Date: Thu, 18 Sep 2025 13:35:45 -0500
Subject: [PATCH 6/7] PR Review round 3

Signed-off-by: Muzammiluddin Syed
---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 49e0cec1373ec..8afcdbdfee973 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -648,7 +648,7 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
   LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto setOpsel = [&](unsigned idx, int64_t val) {
+    auto setOpsel = [&op](unsigned idx, int64_t val) {
       switch (idx) {
       case 3:
         op.setScalesIdxA(val);
@@ -684,8 +684,8 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
         return rewriter.notifyMatchFailure(op,
                                            "defining op not a vector.insert");
       }
-      // if the extracted value is not a single scalar, then it has been packed.
-      if (dyn_cast<VectorType>(insertOp.getValueToStore().getType())) {
+      // If the extracted value is not a single scalar, then it has been packed.
+      if (isa<VectorType>(insertOp.getValueToStore().getType())) {
         return rewriter.notifyMatchFailure(
             op, "scaled mfma operand already packed");
       }
@@ -717,12 +717,12 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       }
 
       // Find a linearized idx using the size and offsets of the extract op
+      SmallVector<int64_t> extractedPos(llvm::to_vector_of<int64_t>(
+          llvm::reverse(extractOp.getStaticPosition())));
       ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
       int64_t scaleSrcRank = scaleSrcType.getRank();
-      SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
       SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
-      std::reverse(extractedPos.begin(), extractedPos.end());
-      for (int64_t i = 1; i < scaleSrcRank; i++) {
+      for (int64_t i = 1; i < scaleSrcRank; ++i) {
         extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
       }
       int64_t idx = linearize(extractedPos, extractSizes);
@@ -749,10 +749,10 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
                                         scaleSrcElemType);
       Value newScaleSrc =
-          rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
-      auto extract = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, newScaleSrc, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{size},
-          ArrayRef<int64_t>{1});
+          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
+      auto extract = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset},
+          ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1});
       rewriter.modifyOpInPlace(op, [&] {
         op->setOperand(opIdx, extract);
         setOpsel(opIdx, opsel);
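The linearization that PATCH 6/7 switches to (llvm::to_vector_of over a reversed position, then mlir::linearize) computes the same flat index as the getIdxFromExtract helper removed in PATCH 5/7; traced by hand for %scalesB[6, 8] (a worked example, not part of the patch):

// scaleSrcShape = [7, 23], scaleSrcRank = 2
// extractedPos  = reverse([6, 8])  = [8, 6]
// extractSizes  = [1, 23]          (suffix products of the shape)
// idx = linearize([8, 6], [1, 23]) = 8 * 1 + 6 * 23 = 146
// matching the "idx = 138 + 8 = 146 => opsel = 2" comment in the test.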
From f4ca5651bab7d82731b0b6e708353bb501bd5928 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed
Date: Thu, 18 Sep 2025 14:15:20 -0500
Subject: [PATCH 7/7] minor fixups

Signed-off-by: Muzammiluddin Syed
---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 8afcdbdfee973..eb122717e3a66 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -716,9 +716,9 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
             op, "no packing if # of scales less than four");
       }
 
-      // Find a linearized idx using the size and offsets of the extract op
-      SmallVector<int64_t> extractedPos(llvm::to_vector_of<int64_t>(
-          llvm::reverse(extractOp.getStaticPosition())));
+      // Find a linearized idx using the size and offsets of the extract op.
+      auto extractedPos = llvm::to_vector_of<int64_t>(
+          llvm::reverse(extractOp.getStaticPosition()));
       ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
       int64_t scaleSrcRank = scaleSrcType.getRank();
       SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
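Taken together, the series canonicalizes each single-scale use into one packed four-element slice plus an opsel lane index. A minimal end-state sketch in the spirit of the tests above (SSA names illustrative, trailing operands elided):

// Before:
%s = vector.extract %scalesB[6, 8] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
%v = vector.insert %s, %cst [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%r = amdgpu.scaled_mfma(%v[0] * %opA) * ...

// After (flat index 146 -> offset 144, opsel 2):
%flat  = vector.shape_cast %scalesB : vector<7x23xf8E8M0FNU> to vector<161xf8E8M0FNU>
%group = vector.extract_strided_slice %flat {offsets = [144], sizes = [4], strides = [1]} : vector<161xf8E8M0FNU> to vector<4xf8E8M0FNU>
%r = amdgpu.scaled_mfma(%group[2] * %opA) * ...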