Skip to content

Commit 970aa1a

Browse files
Perform packing for inputs whose shapes are not divisible by 4
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 2404d99 commit 970aa1a

File tree

2 files changed

+118
-35
lines changed

2 files changed

+118
-35
lines changed

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/IR/PatternMatch.h"
2828
#include "mlir/IR/TypeUtilities.h"
2929
#include "llvm/ADT/DenseMap.h"
30+
#include "llvm/ADT/SmallVector.h"
3031
#include "llvm/ADT/TypeSwitch.h"
3132

3233
#include <cstdint>
@@ -688,36 +689,22 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
688689
return idx;
689690
};
690691

691-
// Obtain offsets for new shape from flat index.
692-
auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
693-
SmallVector<int64_t> res;
694-
ShapedType shapedty = static_cast<ShapedType>(ty);
695-
int64_t numElements = shapedty.getNumElements();
696-
for (unsigned size : shapedty.getShape()) {
697-
numElements /= size;
698-
res.push_back(idx / numElements);
699-
idx -= (idx / numElements) * size;
700-
}
701-
return res;
702-
};
703-
704692
// For every scale operand of this ScaledMFMAOp, if the scale follows the
705693
// following pattern:
706-
//
707-
// %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from
708-
// vector<?x?x?xf8E8M0FNU> %scale = vector.insert %unit, ... : f8E8M0FNU
709-
// into vector<4xf8E8M0FNU> amdgpu.scaled_mfma(%scale[0] * ...
694+
// (f8 here means f8E8M0FNU)
695+
// %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
696+
// %scale = vector.insert %unit, ... : f8 into vector<4xf8>
697+
// amdgpu.scaled_mfma(%scale[0] * ...
710698
//
711699
// rewrite to:
712700
//
713-
// %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to
714-
// vector<?x4xf8E8M0FNU> %scale = vector.extract %reshaped[?] :
715-
// vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
701+
// %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
702+
// %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8>
716703
// amdgpu.scaled_mfma(%scale[0-3] * ...
717704
//
718705
// This creates duplicate shape_casts for every use but these will be
719706
// removed in CSE.
720-
for (auto opIdx : SmallVector<int64_t>({3, 4})) {
707+
for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
721708
auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
722709
if (!insertOp) {
723710
return rewriter.notifyMatchFailure(op,
@@ -738,7 +725,7 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
738725
Value scaleSrc = extractOp.getOperand(0);
739726
auto stype = dyn_cast<VectorType>(scaleSrc.getType());
740727
if (!stype) {
741-
return rewriter.notifyMatchFailure(op, "not a shaped type");
728+
return rewriter.notifyMatchFailure(op, "not a vector type");
742729
}
743730
// We do not handle dynamic dims yet, assume that the input is padded to
744731
// a static shape now.
@@ -748,25 +735,32 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
748735
}
749736

750737
int64_t numElements = stype.getNumElements();
751-
if (numElements <= 4 || !(numElements % 4)) {
738+
if (numElements <= 4) {
752739
return rewriter.notifyMatchFailure(
753-
op, "no packing if # of scales less than or indivisible by four");
740+
op, "no packing if # of scales less than four");
741+
}
742+
int64_t idx = getIdxFromExtract(extractOp);
743+
int64_t offset = idx - (idx % 4);
744+
int64_t size = std::min(4l, numElements - offset);
745+
int64_t opsel = idx - offset;
746+
if (size != 4l) {
747+
opsel += 4l - size;
748+
offset = numElements - 4l;
749+
size = 4l;
754750
}
755751

756-
Type newSrcType = VectorType::get(
757-
SmallVector<int64_t>({numElements / 4, 4}), stype.getElementType());
752+
Type newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
753+
stype.getElementType());
758754
Value newScaleSrc =
759755
rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
760-
int64_t idx = getIdxFromExtract(extractOp);
761-
SmallVector<int64_t> offsets(getOffsetsFromIdx(idx, newSrcType));
762756
auto scaleTy = VectorType::get({4}, stype.getElementType());
763757
Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
764-
loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
765-
SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
758+
loc, newScaleSrc, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{size},
759+
ArrayRef<int64_t>{1});
766760
Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
767761
rewriter.modifyOpInPlace(
768762
op, [&op, opIdx, scale] { op->setOperand(opIdx, scale); });
769-
setOpsel(opIdx, offsets[1]);
763+
setOpsel(opIdx, opsel);
770764
}
771765
return success();
772766
}

mlir/test/Dialect/AMDGPU/canonicalize.mlir

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,11 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
163163
// -----
164164

165165
// CHECK-LABEL: func @scaled_mfma
166-
// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
167-
// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
166+
// CHECK: %[[SCALE_1:.*]] = vector.extract_strided_slice %0 {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
167+
// CHECK: %[[SCALE_2:.*]] = vector.extract_strided_slice %2 {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
168168
// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
169-
// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
170-
// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
169+
// CHECK: %[[SCALE_3:.*]] = vector.extract_strided_slice %5 {offsets = [8], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
170+
// CHECK: %[[SCALE_4:.*]] = vector.extract_strided_slice %7 {offsets = [12], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
171171
// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
172172
func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
173173
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
@@ -184,3 +184,92 @@ func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %sc
184184
%res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
185185
return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
186186
}
187+
188+
// -----
189+
190+
// CHECK-LABEL: func @scaled_mfma_less_than_4
191+
// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
192+
// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
193+
// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
194+
// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
195+
// CHECK: amdgpu.scaled_mfma({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
196+
func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2xf8E8M0FNU>, %scalesB: vector<2xf8E8M0FNU>) -> vector<4xf32> {
197+
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
198+
%cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
199+
%scaleA = vector.extract %scalesA[0] : f8E8M0FNU from vector<2xf8E8M0FNU>
200+
%sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
201+
%scaleB = vector.extract %scalesB[1] : f8E8M0FNU from vector<2xf8E8M0FNU>
202+
%sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
203+
%res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
204+
return %res_0 : vector<4xf32>
205+
}
206+
207+
208+
// -----
209+
210+
// CHECK-LABEL: func @scaled_mfma_ugly_shapes
211+
// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
212+
// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
213+
// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[0] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
214+
// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
215+
// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
216+
// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
217+
// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
218+
// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
219+
func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
220+
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
221+
%cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
222+
%scaleA_0_0 = vector.extract %scalesA[0, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
223+
%scaleA_0_1 = vector.extract %scalesA[1, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
224+
%scaleA_0_2 = vector.extract %scalesA[2, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
225+
%scaleA_0_3 = vector.extract %scalesA[3, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
226+
%scaleA_0_4 = vector.extract %scalesA[4, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
227+
%scaleA_0_5 = vector.extract %scalesA[4, 1] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
228+
%scaleA_0_6 = vector.extract %scalesA[4, 2] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
229+
%scaleA_0_7 = vector.extract %scalesA[4, 3] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
230+
231+
// idx = 138 + 8 = 146 => opsel = 2
232+
%scaleB_6_8 = vector.extract %scalesB[6, 8] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
233+
// idx = 147 => opsel = 3
234+
%scaleB_6_9 = vector.extract %scalesB[6, 9] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
235+
// idx = 148 => opsel = 0
236+
%scaleB_6_10 = vector.extract %scalesB[6, 10] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
237+
// idx = 149 => opsel = 1
238+
%scaleB_6_11 = vector.extract %scalesB[6, 11] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
239+
// idx = 160 => opsel = 3 (final element; window clamps to the last 4 bytes, [157..160])
240+
%scaleB_6_22 = vector.extract %scalesB[6, 22] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
241+
// idx = 159 => opsel = 3
242+
%scaleB_6_21 = vector.extract %scalesB[6, 21] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
243+
// idx = 158 => opsel = 2
244+
%scaleB_6_20 = vector.extract %scalesB[6, 20] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
245+
// idx = 157 => opsel = 1
246+
%scaleB_6_19 = vector.extract %scalesB[6, 19] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
247+
248+
%sA_0_0 = vector.insert %scaleA_0_0, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
249+
%sA_0_1 = vector.insert %scaleA_0_1, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
250+
%sA_0_2 = vector.insert %scaleA_0_2, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
251+
%sA_0_3 = vector.insert %scaleA_0_3, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
252+
%sA_0_4 = vector.insert %scaleA_0_4, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
253+
%sA_0_5 = vector.insert %scaleA_0_5, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
254+
%sA_0_6 = vector.insert %scaleA_0_6, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
255+
%sA_0_7 = vector.insert %scaleA_0_7, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
256+
257+
%sB_6_8 = vector.insert %scaleB_6_8, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
258+
%sB_6_9 = vector.insert %scaleB_6_9, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
259+
%sB_6_10 = vector.insert %scaleB_6_10, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
260+
%sB_6_11 = vector.insert %scaleB_6_11, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
261+
%sB_6_22 = vector.insert %scaleB_6_22, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
262+
%sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
263+
%sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
264+
%sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
265+
266+
%res_0 = amdgpu.scaled_mfma(%sA_0_0[0] * %opA) * (%sB_6_8[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
267+
%res_1 = amdgpu.scaled_mfma(%sA_0_1[0] * %opA) * (%sB_6_9[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
268+
%res_2 = amdgpu.scaled_mfma(%sA_0_2[0] * %opA) * (%sB_6_10[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
269+
%res_3 = amdgpu.scaled_mfma(%sA_0_3[0] * %opA) * (%sB_6_11[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
270+
%res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
271+
%res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
272+
%res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
273+
%res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
274+
return %res_0, %res_1, %res_2, %res_3, %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
275+
}

0 commit comments

Comments
 (0)