Skip to content

Commit 8a9face

Browse files
PR review round 4
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 02f5d98 commit 8a9face

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,8 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
539539
/// Note that the type of `input` has already been LLVM type converted:
540540
/// therefore 8-bit and smaller floats are represented as their corresponding
541541
/// `iN` integers.
542-
static Value castScaledMFMAVectorOperand(ConversionPatternRewriter &rewriter,
543-
Location loc, Value input) {
542+
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
543+
Location loc, Value input) {
544544
Type inputType = input.getType();
545545
Type outputType = rewriter.getI32Type();
546546
if (auto intType = dyn_cast<IntegerType>(inputType))
@@ -1018,10 +1018,10 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
10181018
createI32Constant(rewriter, loc, bTypeCode),
10191019
/*scales idx A=*/scalesIdxA,
10201020
/*scales A*/
1021-
castScaledMFMAVectorOperand(rewriter, loc, adaptor.getScalesA()),
1021+
castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
10221022
/*scales idx B=*/scalesIdxB,
10231023
/*scales B*/
1024-
castScaledMFMAVectorOperand(rewriter, loc, adaptor.getScalesB())});
1024+
castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
10251025
Value lowered = rewriter.create(loweredOp)->getResult(0);
10261026
rewriter.replaceOp(op, lowered);
10271027
return success();

mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
5252
func.return
5353
}
5454

55+
// CHECK-LABEL: func @scaled_mfma_to_rocdl(
56+
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xi8>, %[[ARG8:.*]]: i8
5557
func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
5658
%arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
5759
%arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
@@ -60,42 +62,42 @@ func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
6062

6163
// CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
6264
// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
63-
// CHECK: %[[c2:.+]] = llvm.bitcast{{.*}} : vector<4xi8> to i32
64-
// CHECK: %[[c3:.+]] = llvm.zext{{.*}} : i8 to i32
65+
// CHECK: %[[b0:.+]] = llvm.bitcast %[[ARG7]] : vector<4xi8> to i32
66+
// CHECK: %[[z0:.+]] = llvm.zext %[[ARG8]] : i8 to i32
6567

66-
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
68+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
6769
amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<32xf8E4M3FN>, vector<16xf32>
68-
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
70+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
6971
amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<32xf8E4M3FN>, vector<4xf32>
7072

7173
// CHECK: llvm.bitcast
7274

73-
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
75+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
7476
amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf8E5M2>, i8, vector<32xf8E5M2>, vector<16xf32>
75-
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
77+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
7678
amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf8E5M2>, i8, vector<32xf8E5M2>, vector<4xf32>
7779

7880
// CHECK: llvm.bitcast
7981

80-
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
82+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
8183
amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<32xf6E2M3FN>, vector<16xf32>
82-
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
84+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
8385
amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<32xf6E2M3FN>, vector<4xf32>
8486

8587
// CHECK: llvm.bitcast
8688
// CHECK: llvm.mlir.constant(3 : i32) : i32
8789

88-
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
90+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
8991
amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<32xf6E3M2FN>, vector<16xf32>
90-
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
92+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
9193
amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<32xf6E3M2FN>, vector<4xf32>
9294

9395
// CHECK: llvm.bitcast
9496
// CHECK: llvm.mlir.constant(4 : i32) : i32
9597

96-
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
98+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
9799
amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<32xf4E2M1FN>, vector<16xf32>
98-
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
100+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
99101
amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<32xf4E2M1FN>, vector<4xf32>
100102

101103
func.return

0 commit comments

Comments
 (0)