Skip to content

Commit 2b6d917

Browse files
Add packing of scales for ScaledMFMAOp
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 0d989b2 commit 2b6d917

File tree

4 files changed

+169
-0
lines changed

4 files changed

+169
-0
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
10481048
attr-dict
10491049
`:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
10501050
}];
1051+
let hasCanonicalizer = 1;
10511052
}
10521053
#endif // AMDGPU

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
1818
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1919
#include "mlir/Dialect/Utils/IndexingUtils.h"
20+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2021
#include "mlir/IR/Builders.h"
2122
#include "mlir/IR/BuiltinTypes.h"
2223
#include "mlir/IR/Diagnostics.h"
@@ -28,6 +29,7 @@
2829
#include "llvm/ADT/DenseMap.h"
2930
#include "llvm/ADT/TypeSwitch.h"
3031

32+
#include <cstdint>
3133
#include <limits>
3234
#include <optional>
3335

@@ -631,6 +633,146 @@ LogicalResult TransposeLoadOp::verify() {
631633
return success();
632634
}
633635

636+
//===----------------------------------------------------------------------===//
637+
// ScaledMFMAOp
638+
//===----------------------------------------------------------------------===//
639+
640+
namespace {
641+
/// Check if the scales input is used in other scaled mfma's while they exist.
642+
/// If theyre unused then pack the scales.
643+
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
644+
using OpRewritePattern::OpRewritePattern;
645+
646+
LogicalResult matchAndRewrite(ScaledMFMAOp op,
647+
PatternRewriter &rewriter) const override {
648+
Location loc = op.getLoc();
649+
// If this use of a scale has a non zero opsel, packing has already been
650+
// done.
651+
auto checkIfUnpackable = [&](OpOperand &op) {
652+
if (auto smfma = dyn_cast<ScaledMFMAOp>(op.getOwner())) {
653+
switch (op.getOperandNumber()) {
654+
case 3:
655+
return smfma.getScalesIdxA() != 0;
656+
break;
657+
case 4:
658+
return smfma.getScalesIdxB() != 0;
659+
break;
660+
default:
661+
return true;
662+
break;
663+
}
664+
}
665+
};
666+
667+
auto setOpsel = [&](unsigned idx, int64_t val) {
668+
switch (idx) {
669+
case 3:
670+
return op.setScalesIdxA(val);
671+
break;
672+
case 4:
673+
return op.setScalesIdxB(val);
674+
break;
675+
default:
676+
break;
677+
}
678+
};
679+
680+
// Obtain flat index from offsets and shape.
681+
auto getIdxFromExtract = [](vector::ExtractOp op) {
682+
ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
683+
int cumul = 1;
684+
int idx = 0;
685+
for (auto [offset, size] :
686+
reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
687+
idx += offset * cumul;
688+
cumul *= size;
689+
}
690+
return idx;
691+
};
692+
693+
// Obtain offsets for new shape from flat index.
694+
auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
695+
SmallVector<int64_t> res;
696+
ShapedType shapedty = static_cast<ShapedType>(ty);
697+
int64_t numElements = shapedty.getNumElements();
698+
for (auto size : shapedty.getShape()) {
699+
numElements /= size;
700+
res.push_back(idx / numElements);
701+
idx -= (idx / numElements) * size;
702+
}
703+
return res;
704+
};
705+
706+
// For every scale operand of this ScaledMFMAOp, if the scale follows the
707+
// following pattern:
708+
//
709+
// %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
710+
// %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
711+
// amdgpu.scaled_mfma(%scale[0] * ...
712+
//
713+
// rewrite to:
714+
//
715+
// %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
716+
// %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
717+
// amdgpu.scaled_mfma(%scale[0-3] * ...
718+
//
719+
// This creates duplicate shape_casts for every use but these will be removed in CSE.
720+
for (auto opIdx : SmallVector<int64_t>({3, 4})) {
721+
auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
722+
if (!insertOp) {
723+
return failure();
724+
}
725+
if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
726+
return failure();
727+
}
728+
729+
auto extractOp =
730+
insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
731+
if (!extractOp) {
732+
return failure();
733+
}
734+
735+
Value scaleSrc = extractOp.getOperand(0);
736+
auto stype = dyn_cast<ShapedType>(scaleSrc.getType());
737+
if (!stype) {
738+
return failure();
739+
}
740+
// We do not handle dynamic dims yet, assume that the input is padded to
741+
// a static shape now.
742+
if (llvm::any_of(llvm::seq<int64_t>(0, stype.getRank()),
743+
[&](int64_t i) { return stype.isDynamicDim(i); })) {
744+
return failure();
745+
}
746+
747+
int64_t numElements = stype.getNumElements();
748+
if (numElements <= 4) {
749+
return failure();
750+
}
751+
752+
Type newSrcType = VectorType::get(
753+
SmallVector<int64_t>({numElements / 4, 4}), stype.getElementType());
754+
Value newScaleSrc =
755+
rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
756+
int64_t idx = getIdxFromExtract(extractOp);
757+
SmallVector<int64_t> offsets(getOffsetsFromIdx(idx, newSrcType));
758+
auto scaleTy = VectorType::get({4}, stype.getElementType());
759+
Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
760+
loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
761+
SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
762+
Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
763+
op.setOperand(opIdx, scale);
764+
setOpsel(opIdx, offsets[1]);
765+
}
766+
return success();
767+
}
768+
};
769+
} // namespace
770+
771+
void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
772+
MLIRContext *context) {
773+
results.add<PackScales>(context);
774+
}
775+
634776
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
635777

636778
#define GET_ATTRDEF_CLASSES

mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
1414
MLIRROCDLDialect
1515
# Needed for GPU address space enum definition
1616
MLIRGPUDialect
17+
MLIRVectorDialect
1718
MLIRIR
1819
MLIRSideEffectInterfaces
1920
MLIRMemRefUtils

mlir/test/Dialect/AMDGPU/canonicalize.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,28 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
159159
: f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
160160
func.return
161161
}
162+
163+
// -----

// Canonicalization test for amdgpu.scaled_mfma scale packing: each scalar
// scale that is extracted from a 16-element (2x1x8x1) source and inserted
// into lane 0 of a vector<4xf8E8M0FNU> should be rewritten to extract a
// packed row of the reshaped 4x4 source, with the flat element index split
// into a row (vector.extract position) and a byte-select opsel index [0-3].
// E.g. %scalesA[0, 0, 3, 0] has flat index 3 -> row 0, opsel 3.

// CHECK-LABEL: func @scaled_mfma
// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
  // First scaled_mfma: scales at flat indices 3 (A) and 6 (B).
  %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  // Second scaled_mfma: scales at flat indices 9 (A) and 12 (B).
  %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
  %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
  %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
}

0 commit comments

Comments
 (0)