[mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp #155951
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-backend-amdgpu

Author: Muzammil (Muzammiluddin-Syed-ECE)

Changes

The ScaledMFMAOp accepts scales as a vector of 4 bytes (vector<4xf8E8M0FNU>) that can be stored in a single register, with a particular scale accessed using the OpSel attribute. Currently we only use one byte of this 4-byte vector, resulting in 3 wasted registers. This is fixed by identifying when single-byte extractions are performed and rewriting them into extractions of 4-byte vectors.

Full diff: https://github.com/llvm/llvm-project/pull/155951.diff

4 Files Affected:
- mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
- mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
- mlir/test/Dialect/AMDGPU/canonicalize.mlir
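Example (a before/after sketch; the vector<2x1x8x1xf8E8M0FNU> scale source and SSA names are illustrative, taken from the test added in this diff):

%unit = vector.extract %scaleSrc[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%scale = vector.insert %unit, %cst [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%res = amdgpu.scaled_mfma(%scale[0] * %opA) * ...

to

%packed = vector.shape_cast %scaleSrc : vector<2x1x8x1xf8E8M0FNU> to vector<4x4xf8E8M0FNU>
%scale = vector.extract %packed[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
%res = amdgpu.scaled_mfma(%scale[3] * %opA) * ...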
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 2ccf350a359a8..a24a918357f2d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
attr-dict
`:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
}];
+ let hasCanonicalizer = 1;
}
#endif // AMDGPU
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 11a40d663a201..4107ec53a0988 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -28,6 +29,7 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <cstdint>
#include <limits>
#include <optional>
@@ -631,6 +633,146 @@ LogicalResult TransposeLoadOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pack the scales of a ScaledMFMAOp. Packing is skipped if any use of the
+/// scales in another scaled MFMA shows they have already been packed.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // If this use of a scale has a non-zero opsel, packing has already been
+    // done.
+    auto checkIfUnpackable = [&](OpOperand &operand) {
+      if (auto smfma = dyn_cast<ScaledMFMAOp>(operand.getOwner())) {
+        switch (operand.getOperandNumber()) {
+        case 3:
+          return smfma.getScalesIdxA() != 0;
+        case 4:
+          return smfma.getScalesIdxB() != 0;
+        default:
+          return true;
+        }
+      }
+      // Conservatively treat any non-ScaledMFMA user as unpackable.
+      return true;
+    };
+
+    auto setOpsel = [&](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        op.setScalesIdxA(val);
+        break;
+      case 4:
+        op.setScalesIdxB(val);
+        break;
+      default:
+        break;
+      }
+    };
+
+    // Obtain the flat index from the extract op's static position and source
+    // shape, e.g. position [1, 0, 4, 0] in vector<2x1x8x1xf8E8M0FNU>
+    // flattens to 1*8 + 0*8 + 4*1 + 0*1 = 12.
+    auto getIdxFromExtract = [](vector::ExtractOp op) {
+      ShapedType ty = cast<ShapedType>(op.getOperand(0).getType());
+      int64_t cumul = 1;
+      int64_t idx = 0;
+      for (auto [offset, size] :
+           reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
+        idx += offset * cumul;
+        cumul *= size;
+      }
+      return idx;
+    };
+
+    // Obtain the offsets into the new shape from a flat index, e.g. flat
+    // index 12 in vector<4x4xf8E8M0FNU> yields offsets [3, 0].
+    auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
+      SmallVector<int64_t> res;
+      ShapedType shapedTy = cast<ShapedType>(ty);
+      int64_t numElements = shapedTy.getNumElements();
+      for (int64_t size : shapedTy.getShape()) {
+        numElements /= size;
+        res.push_back(idx / numElements);
+        idx %= numElements;
+      }
+      return res;
+    };
+
+    // For every scale operand of this ScaledMFMAOp, if the scale is produced
+    // by the following pattern:
+    //
+    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
+    // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0] * ...
+    //
+    // rewrite it to:
+    //
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
+    // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0-3] * ...
+    //
+    // This creates a duplicate shape_cast for every use, but those are
+    // removed by CSE.
+    for (auto opIdx : SmallVector<int64_t>({3, 4})) {
+      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+      if (!insertOp) {
+        return failure();
+      }
+      if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
+        return failure();
+      }
+
+      auto extractOp =
+          insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+      if (!extractOp) {
+        return failure();
+      }
+
+      Value scaleSrc = extractOp.getOperand(0);
+      auto stype = dyn_cast<ShapedType>(scaleSrc.getType());
+      if (!stype) {
+        return failure();
+      }
+      // We do not handle dynamic dims yet; assume the input has been padded
+      // to a static shape.
+      if (!stype.hasStaticShape()) {
+        return failure();
+      }
+
+      // Do not pack if there are fewer than four scales, or if the scales
+      // cannot be grouped evenly into vectors of four.
+      int64_t numElements = stype.getNumElements();
+      if (numElements <= 4 || numElements % 4 != 0) {
+        return failure();
+      }
+
+      Type newSrcType = VectorType::get(
+          SmallVector<int64_t>({numElements / 4, 4}), stype.getElementType());
+      Value newScaleSrc =
+          rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
+      int64_t idx = getIdxFromExtract(extractOp);
+      SmallVector<int64_t> offsets(getOffsetsFromIdx(idx, newSrcType));
+      auto scaleTy = VectorType::get({4}, stype.getElementType());
+      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
+          SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
+      Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
+      // Mutate the op through the rewriter so the pattern driver is notified
+      // of the in-place update.
+      rewriter.modifyOpInPlace(op, [&] {
+        op.setOperand(opIdx, scale);
+        setOpsel(opIdx, offsets[1]);
+      });
+    }
+    return success();
+  }
+};
+} // namespace
+
+void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<PackScales>(context);
+}
+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 2a019954c8356..5d14a05945e95 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
MLIRROCDLDialect
# Needed for GPU address space enum definition
MLIRGPUDialect
+ MLIRVectorDialect
MLIRIR
MLIRSideEffectInterfaces
MLIRMemRefUtils
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 5501ad42dbd90..75cbf29c95f29 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -159,3 +159,28 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
: f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
func.return
}
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma
+// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
+// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
+func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+ %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+ %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+ %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+ %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+ %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+ %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+ %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+ %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+ %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+ %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+ %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+ return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
+}
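Note that the CHECK lines above match plain vector.extract even though PackScales emits a vector.extract_strided_slice followed by a vector.shape_cast; the vector dialect's own canonicalization patterns, running in the same pass, are expected to fold that pair away. A sketch of the fold, with illustrative SSA names:

// Emitted by PackScales:
%slice = vector.extract_strided_slice %packed {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xf8E8M0FNU> to vector<1x4xf8E8M0FNU>
%scale = vector.shape_cast %slice : vector<1x4xf8E8M0FNU> to vector<4xf8E8M0FNU>

// Folded form matched by the test:
%scale = vector.extract %packed[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>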
✅ With the latest revision this PR passed the C/C++ code formatter.
50% minor notes, some substantive problems
... Oh, and while I'm here, the
LGTM here
Follow-up #163770: Improving choice of class used, from SmallVector to ArrayRef (https://llvm.org/docs/ProgrammersManual.html#llvm-adt-arrayref-h). Also infer template types when possible. Leftover from #155951.
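A minimal sketch of the idiom that cleanup applies; the sumOffsets helper below is hypothetical, not code from either patch. Read-only sequence parameters are taken as ArrayRef, which binds to a SmallVector, std::vector, C array, or braced list without copying into a new container:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>

// Sums a read-only sequence of offsets; ArrayRef makes the parameter
// container-agnostic.
static int64_t sumOffsets(llvm::ArrayRef<int64_t> offsets) {
  int64_t total = 0;
  for (int64_t offset : offsets)
    total += offset;
  return total;
}

int main() {
  llvm::SmallVector<int64_t> offsets{3, 4};
  (void)sumOffsets(offsets); // SmallVector converts to ArrayRef implicitly.
  (void)sumOffsets({3, 4});  // So does a braced initializer list.
  return 0;
}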