
Commit 9628061

[mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp (#155951)
The ScaledMFMAOp accepts scales as a vector of 4 bytes (`vector<4xf8E8M0FNU>`) that can be stored in a single register, with a particular scale selected via the `OpSel` attribute. Currently we use only one byte of this 4-byte vector, resulting in 3 wasted registers. This is fixed by identifying when single-byte extractions are performed and rewriting them into extractions of 4-byte vectors.

Example:

```
%unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
%scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
amdgpu.scaled_mfma(%scale[0] * ...
```

to

```
%reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
%scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
amdgpu.scaled_mfma(%scale[0-3] * ...
```

---------

Signed-off-by: Muzammiluddin Syed <[email protected]>
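The heart of the rewrite is the arithmetic that maps a scale's linearized index to a 4-byte group offset plus the `OpSel` byte select. A minimal standalone sketch of that arithmetic, mirroring the pattern's logic (the function name `packScaleIndex` is illustrative, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Given the linearized index of one scale within numElements scales, return
// {offset, opsel}: where the extracted 4-byte group starts and which byte
// OpSel must select from it. The pattern bails out unless numElements > 4.
std::pair<int64_t, int64_t> packScaleIndex(int64_t idx, int64_t numElements) {
  assert(numElements > 4 && idx >= 0 && idx < numElements);
  int64_t offset = idx - (idx % 4); // start of the aligned 4-byte group
  int64_t opsel = idx - offset;     // byte position within that group
  // Tail case: fewer than 4 scales remain after offset, so extract the last
  // 4 elements instead and re-derive the byte select.
  if (numElements - offset < 4) {
    opsel = 4 - (numElements - idx);
    offset = numElements - 4;
  }
  return {offset, opsel};
}
```

For instance, index 160 in a 161-element source (the `vector<7x23xf8E8M0FNU>` case in the tests below) yields offset 157 and opsel 3, matching the `idx = 160 => opsel = 3` comment in the test file.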
1 parent 8c41859 commit 9628061

File tree

4 files changed: +224 -0 lines changed

- mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
- mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
- mlir/test/Dialect/AMDGPU/canonicalize.mlir

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 1 addition & 0 deletions
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
     attr-dict
     `:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
   }];
+  let hasCanonicalizer = 1;
 }
 #endif // AMDGPU
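Declaring `let hasCanonicalizer = 1` causes ODS to emit a `getCanonicalizationPatterns` hook declaration on the op, which the dialect library must then define. Roughly, the generated declaration looks like this (a sketch of the conventional MLIR signature, not the literal tablegen output):

```cpp
// Added to the ScaledMFMAOp class by ODS when hasCanonicalizer is set:
static void getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context);
```

The matching definition, which registers the `PackScales` pattern, appears at the end of the AMDGPUDialect.cpp diff below.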

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 137 additions & 0 deletions
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -26,8 +27,11 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <algorithm>
+#include <cstdint>
 #include <limits>
 #include <optional>

@@ -631,6 +635,139 @@ LogicalResult TransposeLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Check if the scales input is used in other scaled MFMAs while they exist.
+/// If they're unused, then pack the scales.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto setOpsel = [&op](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        op.setScalesIdxA(val);
+        break;
+      case 4:
+        op.setScalesIdxB(val);
+        break;
+      default:
+        break;
+      }
+    };
+
+    // For every scale operand of this ScaledMFMAOp, if the scale is produced
+    // by the extraction of a single scale from some vector, then attempt to
+    // extract 4 values from that vector instead.
+    //
+    // Example: (f8 here means f8E8M0FNU)
+    // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
+    // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
+    // amdgpu.scaled_mfma(%scale[0] * ...
+    //
+    // rewrite to:
+    //
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
+    // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
+    // amdgpu.scaled_mfma(%scale[0-3] * ...
+    //
+    // This creates duplicate shape_casts for every use, but these will be
+    // removed by CSE.
+    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
+      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+      if (!insertOp) {
+        return rewriter.notifyMatchFailure(op,
+                                           "defining op not a vector.insert");
+      }
+      // If the inserted value is not a single scalar, then it has already
+      // been packed.
+      if (isa<VectorType>(insertOp.getValueToStore().getType())) {
+        return rewriter.notifyMatchFailure(
+            op, "scaled mfma operand already packed");
+      }
+
+      auto extractOp =
+          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
+      if (!extractOp) {
+        return rewriter.notifyMatchFailure(op,
+                                           "defining op not a vector.extract");
+      }
+
+      Value scaleSrc = extractOp.getOperand(0);
+      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
+      if (!scaleSrcType) {
+        return rewriter.notifyMatchFailure(op, "not a vector type");
+      }
+
+      // We do not handle dynamic dims yet; assume that the input has been
+      // padded to a static shape.
+      if (!scaleSrcType.hasStaticShape()) {
+        return rewriter.notifyMatchFailure(op,
+                                           "dynamic dims not yet supported");
+      }
+
+      int64_t numElements = scaleSrcType.getNumElements();
+      if (numElements <= 4) {
+        return rewriter.notifyMatchFailure(
+            op, "no packing if # of scales less than four");
+      }
+
+      // Find a linearized idx using the sizes and offsets of the extract op.
+      auto extractedPos = llvm::to_vector_of<int64_t>(
+          llvm::reverse(extractOp.getStaticPosition()));
+      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
+      int64_t scaleSrcRank = scaleSrcType.getRank();
+      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
+      for (int64_t i = 1; i < scaleSrcRank; ++i) {
+        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
+      }
+      int64_t idx = linearize(extractedPos, extractSizes);
+
+      // All n scales (where n is the total number of scales) must now be
+      // extracted in chunks of 4 elements. This is done by dividing the
+      // original vector of scales into groups of 4 elements
+      // at offsets 0, 4, ..., m (where m = n/4). All extractions of a
+      // scale at a particular index are now replaced with an extraction
+      // of the entire group of 4 elements to which that index belongs.
+      //
+      // If the number of scales happens to be indivisible by 4, extract
+      // the remaining n - m scales in a chunk of 4 elements starting at
+      // offset n - 4.
+      int64_t offset = idx - (idx % 4);
+      int64_t opsel = idx - offset;
+      int64_t size = 4l;
+      // Accommodate remaining elements in the case of non-4-divisible vectors.
+      if (numElements - offset < size) {
+        opsel = size - (numElements - idx);
+        offset = numElements - 4l;
+      }
+      Type scaleSrcElemType = scaleSrcType.getElementType();
+      auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
+                                        scaleSrcElemType);
+      Value newScaleSrc =
+          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
+      auto extract = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset},
+          ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1});
+      rewriter.modifyOpInPlace(op, [&] {
+        op->setOperand(opIdx, extract);
+        setOpsel(opIdx, opsel);
+      });
+    }
+    return success();
+  }
+};
+} // namespace
+
+void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<PackScales>(context);
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
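The pattern's linearization step reverses the extract op's static position and pairs it with running products of the trailing dimensions before calling `linearize` (from IndexingUtils). A self-contained sketch of the equivalent row-major computation (the helper name is illustrative, not from the patch):

```cpp
#include <cstdint>
#include <vector>

// Equivalent of linearize(extractedPos, extractSizes) as used above: strides
// are suffix products of the shape, and the position is consumed in reverse
// so each coordinate is paired with its stride.
int64_t linearizeScalePos(const std::vector<int64_t> &pos,
                          const std::vector<int64_t> &shape) {
  int64_t rank = static_cast<int64_t>(shape.size());
  std::vector<int64_t> strides(rank, 1);
  for (int64_t i = 1; i < rank; ++i)
    strides[i] = strides[i - 1] * shape[rank - i];
  int64_t idx = 0;
  for (int64_t i = 0; i < rank; ++i)
    idx += pos[rank - 1 - i] * strides[i];
  return idx;
}
```

For example, `linearizeScalePos({6, 22}, {7, 23})` returns 6 * 23 + 22 = 160, the `idx = 160` case exercised by `@scaled_mfma_ugly_shapes` in the test diff below.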

mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
   MLIRROCDLDialect
   # Needed for GPU address space enum definition
   MLIRGPUDialect
+  MLIRVectorDialect
   MLIRIR
   MLIRSideEffectInterfaces
   MLIRMemRefUtils

mlir/test/Dialect/AMDGPU/canonicalize.mlir

Lines changed: 85 additions & 0 deletions
@@ -159,3 +159,88 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
     : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma
+// CHECK: %[[SCALE_1:.*]] = vector.extract_strided_slice %0 {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: %[[SCALE_2:.*]] = vector.extract_strided_slice %2 {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
+// CHECK: %[[SCALE_3:.*]] = vector.extract_strided_slice %5 {offsets = [8], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: %[[SCALE_4:.*]] = vector.extract_strided_slice %7 {offsets = [12], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
+func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma_less_than_4
+// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
+// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
+// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
+// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
+func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2xf8E8M0FNU>, %scalesB: vector<2xf8E8M0FNU>) -> vector<4xf32> {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA = vector.extract %scalesA[0] : f8E8M0FNU from vector<2xf8E8M0FNU>
+  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleB = vector.extract %scalesB[1] : f8E8M0FNU from vector<2xf8E8M0FNU>
+  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma_ugly_shapes
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA_0_4 = vector.extract %scalesA[4, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_5 = vector.extract %scalesA[4, 1] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_6 = vector.extract %scalesA[4, 2] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_7 = vector.extract %scalesA[4, 3] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+
+  // idx = 160 => opsel = 3 (last idx of last 4 bytes)
+  %scaleB_6_22 = vector.extract %scalesB[6, 22] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 159 => opsel = 3
+  %scaleB_6_21 = vector.extract %scalesB[6, 21] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 158 => opsel = 2
+  %scaleB_6_20 = vector.extract %scalesB[6, 20] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 157 => opsel = 1
+  %scaleB_6_19 = vector.extract %scalesB[6, 19] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+
+  %sA_0_4 = vector.insert %scaleA_0_4, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_5 = vector.insert %scaleA_0_5, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_6 = vector.insert %scaleA_0_6, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_7 = vector.insert %scaleA_0_7, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+
+  %sB_6_22 = vector.insert %scaleB_6_22, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+
+  %res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+}
