Skip to content

Commit cd0f191

Browse files
authored
[mlir][rocdl] Add gfx1250+ cvt scale intrinsics (#159649)
1 parent e2467cb commit cd0f191

File tree

3 files changed

+168
-0
lines changed

3 files changed

+168
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,10 +835,17 @@ class ROCDL_ConcreteVector<Type elem, int length> :
835835

836836
def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
837837
def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
838+
def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
838839
def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
839840
def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
841+
def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
840842
def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
841843
def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
844+
def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
845+
def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
846+
def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
847+
def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
848+
def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
842849
def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
843850
def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
844851
def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
@@ -975,6 +982,68 @@ class ScaleArgInfo<TypeConstraint argTyVal, string typeName> {
975982
string nameForOp = typeName;
976983
}
977984

985+
//===---------------------------------------------------------------------===//
986+
// Scaled {fp4,bf8,fp8} to {bf16,f16,f32} conversion intrinsics
987+
//===---------------------------------------------------------------------===//
988+
989+
// Generates one op for every (smallT, largeT) pair listed below: 3 source
// formats x 3 result formats = 9 ops. Each record is named
// ROCDL_CvtPkScalePk8<Large><Small>Op (from `nameForOp`) and its mnemonic is
// "cvt.scale.pk8.<large>.<small>" (from `name`, e.g. "cvt.scale.pk8.f16.fp4").
foreach smallT = [
990+
// Source operand type paired with the source format name. Fp4 uses a single
// i32; Fp8 and Bf8 use a 2 x i32 vector.
ScaleArgInfo<I32, "Fp4">,
991+
ScaleArgInfo<ROCDL_V2I32Type, "Fp8">,
992+
ScaleArgInfo<ROCDL_V2I32Type, "Bf8">
993+
] in {
994+
foreach largeT = [
995+
// Result type paired with the result format name; all results have 8 lanes.
ScaleArgInfo<ROCDL_V8F16Type, "F16">,
996+
ScaleArgInfo<ROCDL_V8BF16Type, "Bf16">,
997+
ScaleArgInfo<ROCDL_V8F32Type, "F32">,
998+
] in {
999+
// NOTE(review): the trailing `1, [2], ["scaleSel"]` template arguments
// presumably mean "one result" and "LLVM operand 2 is an immediate taken from
// the `scaleSel` attribute" -- confirm against ROCDL_ConcreteNonMemIntrOp.
// Operands are: packed source ($src), i32 scale ($scale), and the i32
// attribute $scaleSel.
def ROCDL_CvtPkScalePk8 # largeT.nameForOp # smallT.nameForOp # Op :
1000+
ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk8." # largeT.name # "." # smallT.name,
1001+
[Pure], 1, [2], ["scaleSel"]>,
1002+
Arguments<(ins smallT.type:$src, I32:$scale, I32Attr:$scaleSel)> {
1003+
1004+
let summary = "Scales 8 " # smallT.name # " and converts them to 8 " # largeT.name # ".";
1005+
let description = [{
1006+
Available on gfx1250+.
1007+
}];
1008+
let results = (outs largeT.type:$res);
1009+
// Textual form: `<op> %src, %scale[<scaleSel>] : <result type>`.
// Only the result type is spelled out in the assembly.
let assemblyFormat = [{
1010+
attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res)
1011+
}];
1012+
}
1013+
} // foreach largeT
1014+
} // foreach smallT
1015+
1016+
//===---------------------------------------------------------------------===//
1017+
// Scaled {bf6,fp6} to {bf16,f16,f32} conversion intrinsics
1018+
//===---------------------------------------------------------------------===//
1019+
1020+
// Generates one op for every (smallT, largeT) pair listed below: 2 source
// formats x 3 result formats = 6 ops. Each record is named
// ROCDL_CvtPkScalePk16<Large><Small>Op (from `nameForOp`) and its mnemonic is
// "cvt.scale.pk16.<large>.<small>" (from `name`, e.g. "cvt.scale.pk16.f16.fp6").
foreach smallT = [
1021+
// Both 6-bit source formats are passed as a 3 x i32 vector.
ScaleArgInfo<ROCDL_V3I32Type, "Fp6">,
1022+
ScaleArgInfo<ROCDL_V3I32Type, "Bf6">
1023+
] in {
1024+
foreach largeT = [
1025+
// Result type paired with the result format name; all results have 16 lanes.
ScaleArgInfo<ROCDL_V16F16Type, "F16">,
1026+
ScaleArgInfo<ROCDL_V16BF16Type, "Bf16">,
1027+
ScaleArgInfo<ROCDL_V16F32Type, "F32">,
1028+
] in {
1029+
// NOTE(review): the trailing `1, [2], ["scaleSel"]` template arguments
// presumably mean "one result" and "LLVM operand 2 is an immediate taken from
// the `scaleSel` attribute" -- confirm against ROCDL_ConcreteNonMemIntrOp.
// Operands are: packed source ($src), i32 scale ($scale), and the i32
// attribute $scaleSel.
def ROCDL_CvtPkScalePk16 # largeT.nameForOp # smallT.nameForOp # Op :
1030+
ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk16." # largeT.name # "." # smallT.name,
1031+
[Pure], 1, [2], ["scaleSel"]>,
1032+
Arguments<(ins smallT.type:$src, I32:$scale, I32Attr:$scaleSel)> {
1033+
1034+
let summary = "Scales 16 " # smallT.name # " and converts them to 16 " # largeT.name # ".";
1035+
let description = [{
1036+
Available on gfx1250+.
1037+
}];
1038+
let results = (outs largeT.type:$res);
1039+
// Textual form: `<op> %src, %scale[<scaleSel>] : <result type>`.
// Only the result type is spelled out in the assembly.
let assemblyFormat = [{
1040+
attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res)
1041+
}];
1042+
1043+
}
1044+
} // foreach largeT
1045+
} // foreach smallT
1046+
9781047
//===---------------------------------------------------------------------===//
9791048
// Scaled 32x6-bit float conversion intrinsics
9801049
//===---------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,57 @@ llvm.func @rocdl.permlane32.swap(%src : i32) -> !llvm.struct<(i32, i32)> {
10251025

10261026
// -----
10271027

1028+
// Exercises the textual form of all nine pk8 scaled-conversion ops
// ({f16, bf16, f32} results x {fp4, fp8, bf8} sources), each with scaleSel = 0.
// NOTE(review): presumably a parse/print round-trip test; the RUN line is
// outside this view.
// CHECK-LABEL: rocdl.cvt.scale.pk8
1029+
llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {
1030+
1031+
// fp4 sources are passed as a single i32.
// CHECK: rocdl.cvt.scale.pk8.f16.fp4
1032+
%0 = rocdl.cvt.scale.pk8.f16.fp4 %i32, %scale[0] : vector<8xf16>
1033+
// CHECK: rocdl.cvt.scale.pk8.bf16.fp4
1034+
%1 = rocdl.cvt.scale.pk8.bf16.fp4 %i32, %scale[0] : vector<8xbf16>
1035+
// CHECK: rocdl.cvt.scale.pk8.f32.fp4
1036+
%2 = rocdl.cvt.scale.pk8.f32.fp4 %i32, %scale[0] : vector<8xf32>
1037+
1038+
// fp8 sources are passed as vector<2xi32>.
// CHECK: rocdl.cvt.scale.pk8.f16.fp8
1039+
%3 = rocdl.cvt.scale.pk8.f16.fp8 %v2xi32, %scale[0] : vector<8xf16>
1040+
// CHECK: rocdl.cvt.scale.pk8.bf16.fp8
1041+
%4 = rocdl.cvt.scale.pk8.bf16.fp8 %v2xi32, %scale[0] : vector<8xbf16>
1042+
// CHECK: rocdl.cvt.scale.pk8.f32.fp8
1043+
%5 = rocdl.cvt.scale.pk8.f32.fp8 %v2xi32, %scale[0] : vector<8xf32>
1044+
1045+
// bf8 sources are passed as vector<2xi32>.
// CHECK: rocdl.cvt.scale.pk8.f16.bf8
1046+
%6 = rocdl.cvt.scale.pk8.f16.bf8 %v2xi32, %scale[0] : vector<8xf16>
1047+
// CHECK: rocdl.cvt.scale.pk8.bf16.bf8
1048+
%7 = rocdl.cvt.scale.pk8.bf16.bf8 %v2xi32, %scale[0] : vector<8xbf16>
1049+
// CHECK: rocdl.cvt.scale.pk8.f32.bf8
1050+
%8 = rocdl.cvt.scale.pk8.f32.bf8 %v2xi32, %scale[0] : vector<8xf32>
1051+
1052+
llvm.return
1053+
}
1054+
1055+
// -----
1056+
1057+
// Exercises the textual form of all six pk16 scaled-conversion ops
// ({f16, bf16, f32} results x {fp6, bf6} sources), each with scaleSel = 0.
// Every source is passed as vector<3xi32>.
// NOTE(review): presumably a parse/print round-trip test; the RUN line is
// outside this view.
// CHECK-LABEL: rocdl.cvt.scale.pk16
1058+
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
1059+
1060+
// fp6 sources.
// CHECK: rocdl.cvt.scale.pk16.f16.fp6
1061+
%0 = rocdl.cvt.scale.pk16.f16.fp6 %v3xi32, %scale[0] : vector<16xf16>
1062+
// CHECK: rocdl.cvt.scale.pk16.bf16.fp6
1063+
%1 = rocdl.cvt.scale.pk16.bf16.fp6 %v3xi32, %scale[0] : vector<16xbf16>
1064+
// CHECK: rocdl.cvt.scale.pk16.f32.fp6
1065+
%2 = rocdl.cvt.scale.pk16.f32.fp6 %v3xi32, %scale[0] : vector<16xf32>
1066+
1067+
// bf6 sources.
// CHECK: rocdl.cvt.scale.pk16.f16.bf6
1068+
%3 = rocdl.cvt.scale.pk16.f16.bf6 %v3xi32, %scale[0] : vector<16xf16>
1069+
// CHECK: rocdl.cvt.scale.pk16.bf16.bf6
1070+
%4 = rocdl.cvt.scale.pk16.bf16.bf6 %v3xi32, %scale[0] : vector<16xbf16>
1071+
// CHECK: rocdl.cvt.scale.pk16.f32.bf6
1072+
%5 = rocdl.cvt.scale.pk16.f32.bf6 %v3xi32, %scale[0] : vector<16xf32>
1073+
1074+
llvm.return
1075+
}
1076+
1077+
// -----
1078+
10281079
// expected-error@below {{attribute attached to unexpected op}}
10291080
func.func private @expected_llvm_func() attributes { rocdl.kernel }
10301081

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,6 +1298,54 @@ llvm.func @rocdl_last_use(%ptr: !llvm.ptr<1>) -> i32 {
12981298
llvm.return %ret : i32
12991299
}
13001300

1301+
// Expects each of the nine pk8 scaled-conversion ops to become a call to the
// matching @llvm.amdgcn.cvt.scale.pk8.<large>.<small> intrinsic, with operands
// (source, scale, scaleSel) and scaleSel emitted as the constant `i32 0`.
// NOTE(review): the label directive below omits the leading `@` that the
// pk16 label later in this file uses; harmless for matching but inconsistent.
// CHECK-LABEL: rocdl.cvt.scale.pk8
1302+
// CHECK-SAME:(i32 %[[I32:.+]], <2 x i32> %[[V2I32:.+]], i32 %[[SCALE:.+]])
1303+
llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {
1304+
1305+
// fp4 sources: the i32 argument is forwarded as the first call operand.
// CHECK: call <8 x half> @llvm.amdgcn.cvt.scale.pk8.f16.fp4(i32 %[[I32]], i32 %[[SCALE]], i32 0)
1306+
%0 = rocdl.cvt.scale.pk8.f16.fp4 %i32, %scale[0] : vector<8xf16>
1307+
// CHECK: call <8 x bfloat> @llvm.amdgcn.cvt.scale.pk8.bf16.fp4(i32 %[[I32]], i32 %[[SCALE]], i32 0)
1308+
%1 = rocdl.cvt.scale.pk8.bf16.fp4 %i32, %scale[0] : vector<8xbf16>
1309+
// CHECK: call <8 x float> @llvm.amdgcn.cvt.scale.pk8.f32.fp4(i32 %[[I32]], i32 %[[SCALE]], i32 0)
1310+
%2 = rocdl.cvt.scale.pk8.f32.fp4 %i32, %scale[0] : vector<8xf32>
1311+
1312+
// fp8 sources: the <2 x i32> argument is forwarded as the first call operand.
// CHECK: call <8 x half> @llvm.amdgcn.cvt.scale.pk8.f16.fp8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0)
1313+
%3 = rocdl.cvt.scale.pk8.f16.fp8 %v2xi32, %scale[0] : vector<8xf16>
1314+
// CHECK: call <8 x bfloat> @llvm.amdgcn.cvt.scale.pk8.bf16.fp8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0)
1315+
%4 = rocdl.cvt.scale.pk8.bf16.fp8 %v2xi32, %scale[0] : vector<8xbf16>
1316+
// CHECK: call <8 x float> @llvm.amdgcn.cvt.scale.pk8.f32.fp8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0)
1317+
%5 = rocdl.cvt.scale.pk8.f32.fp8 %v2xi32, %scale[0] : vector<8xf32>
1318+
1319+
// bf8 sources: the <2 x i32> argument is forwarded as the first call operand.
// CHECK: call <8 x half> @llvm.amdgcn.cvt.scale.pk8.f16.bf8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0)
1320+
%6 = rocdl.cvt.scale.pk8.f16.bf8 %v2xi32, %scale[0] : vector<8xf16>
1321+
// CHECK: call <8 x bfloat> @llvm.amdgcn.cvt.scale.pk8.bf16.bf8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0)
1322+
%7 = rocdl.cvt.scale.pk8.bf16.bf8 %v2xi32, %scale[0] : vector<8xbf16>
1323+
// CHECK: call <8 x float> @llvm.amdgcn.cvt.scale.pk8.f32.bf8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0)
1324+
%8 = rocdl.cvt.scale.pk8.f32.bf8 %v2xi32, %scale[0] : vector<8xf32>
1325+
1326+
llvm.return
1327+
}
1328+
1329+
// Expects each of the six pk16 scaled-conversion ops to become a call to the
// matching @llvm.amdgcn.cvt.scale.pk16.<large>.<small> intrinsic, with operands
// (source, scale, scaleSel) and scaleSel emitted as the constant `i32 0`.
// CHECK-LABEL: @rocdl.cvt.scale.pk16
1330+
// CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]])
1331+
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
1332+
1333+
// fp6 sources.
// CHECK: call <16 x half> @llvm.amdgcn.cvt.scale.pk16.f16.fp6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
1334+
%0 = rocdl.cvt.scale.pk16.f16.fp6 %v3xi32, %scale[0] : vector<16xf16>
1335+
// CHECK: call <16 x bfloat> @llvm.amdgcn.cvt.scale.pk16.bf16.fp6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
1336+
%1 = rocdl.cvt.scale.pk16.bf16.fp6 %v3xi32, %scale[0] : vector<16xbf16>
1337+
// CHECK: call <16 x float> @llvm.amdgcn.cvt.scale.pk16.f32.fp6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
1338+
%2 = rocdl.cvt.scale.pk16.f32.fp6 %v3xi32, %scale[0] : vector<16xf32>
1339+
// bf6 sources.
// CHECK: call <16 x half> @llvm.amdgcn.cvt.scale.pk16.f16.bf6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
1340+
%3 = rocdl.cvt.scale.pk16.f16.bf6 %v3xi32, %scale[0] : vector<16xf16>
1341+
// CHECK: call <16 x bfloat> @llvm.amdgcn.cvt.scale.pk16.bf16.bf6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
1342+
%4 = rocdl.cvt.scale.pk16.bf16.bf6 %v3xi32, %scale[0] : vector<16xbf16>
1343+
// CHECK: call <16 x float> @llvm.amdgcn.cvt.scale.pk16.f32.bf6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
1344+
%5 = rocdl.cvt.scale.pk16.f32.bf6 %v3xi32, %scale[0] : vector<16xf32>
1345+
1346+
llvm.return
1347+
}
1348+
13011349
// CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "uniform-work-group-size"="true" }
13021350
// CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
13031351
// CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"

0 commit comments

Comments
 (0)