Skip to content

Commit 3f62407

Browse files
authored
[ROCDL] Added rocdl.cvt.scale.sr.pk8 ops (#162244)
This patch introduces some missing FP conversion instructions in the ROCDL dialect for the GFX1250 arch. Specifically: Downscaling 8x packed F16, Bf16, Fp32 values to Fp8, Bf8, Fp4 with stochastic rounding Tests: Added lit-tests to check MLIR -> LLVM lowering
1 parent b256d0a commit 3f62407

File tree

3 files changed

+84
-0
lines changed

3 files changed

+84
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,6 +1029,24 @@ foreach smallT = [
10291029
attr-dict $src `,` $scale `:` type($res)
10301030
}];
10311031
}
1032+
1033+
1034+
def ROCDL_CvtScaleF32SrPk8 # smallT.nameForOp # largeT.nameForOp # Op :
1035+
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk8." # smallT.name # "." # largeT.name,
1036+
[Pure], 1>,
1037+
Arguments<(ins largeT.type:$src, I32:$seed, F32:$scale)> {
1038+
let results = (outs smallT.type:$res);
1039+
let summary = "Scale and convert packed "
1040+
# largeT.name # " to packed " # smallT.name # " with stochastic rounding";
1041+
let description = [{
1042+
Convert 8 packed }] # largeT.name # [{ values to packed }]
1043+
# smallT.name # [{, multiplying by the exponent part of `scale`
1044+
before doing so and apply stochastic rounding. This op is for gfx1250+ arch.
1045+
}];
1046+
let assemblyFormat = [{
1047+
attr-dict $src `,` $seed `,` $scale `:` type($res)
1048+
}];
1049+
}
10321050
} // foreach largeT
10331051
} // foreach smallTOp
10341052

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>,
11001100

11011101
// -----
11021102

1103+
// CHECK-LABEL: rocdl.cvt.scalef32.sr.pk8
1104+
llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>,
1105+
%v8xf16: vector<8xf16>,
1106+
%v8xbf16: vector<8xbf16>,
1107+
%seed: i32,
1108+
%scale: f32) {
1109+
1110+
// CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f32
1111+
%0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32>
1112+
// CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f32
1113+
%1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32>
1114+
// CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f32
1115+
%2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32
1116+
1117+
// CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f16
1118+
%3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32>
1119+
// CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f16
1120+
%4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32>
1121+
// CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f16
1122+
%5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32
1123+
1124+
// CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.bf16
1125+
%6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
1126+
// CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.bf16
1127+
%7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
1128+
// CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.bf16
1129+
%8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32
1130+
1131+
llvm.return
1132+
}
1133+
1134+
// -----
1135+
11031136
// CHECK-LABEL: rocdl.cvt.scale.pk16
11041137
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
11051138

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,6 +1368,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, %v8xf16: vector<8xf16>
13681368
llvm.return
13691369
}
13701370

1371+
// CHECK-LABEL: rocdl.cvt.scalef32.sr.pk8
1372+
// CHECK-SAME:(<8 x float> %[[V8F32:.+]], <8 x half> %[[V8F16:.+]], <8 x bfloat> %[[V8BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]])
1373+
llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>,
1374+
%v8xf16: vector<8xf16>,
1375+
%v8xbf16: vector<8xbf16>,
1376+
%seed: i32,
1377+
%scale: f32) {
1378+
1379+
// CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
1380+
%0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32>
1381+
// CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
1382+
%1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32>
1383+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
1384+
%2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32
1385+
1386+
// CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
1387+
%3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32>
1388+
// CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
1389+
%4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32>
1390+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
1391+
%5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32
1392+
1393+
// CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
1394+
%6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
1395+
// CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
1396+
%7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
1397+
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
1398+
%8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32
1399+
1400+
llvm.return
1401+
}
1402+
1403+
13711404
// CHECK-LABEL: @rocdl.cvt.scale.pk16
13721405
// CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]])
13731406
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {

0 commit comments

Comments
 (0)