Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -835,10 +835,17 @@ class ROCDL_ConcreteVector<Type elem, int length> :

def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
Expand Down Expand Up @@ -975,6 +982,61 @@ class ScaleArgInfo<TypeConstraint argTyVal, string typeName> {
string nameForOp = typeName;
}

//===---------------------------------------------------------------------===//
// Scaled {fp4,bf8,fp8} to {bf16,f16,f32} conversion intrinsics
//===---------------------------------------------------------------------===//

foreach smallT = [
ScaleArgInfo<I32, "Fp4">,
ScaleArgInfo<ROCDL_V2I32Type, "Fp8">,
ScaleArgInfo<ROCDL_V2I32Type, "Bf8">
] in {
foreach largeT = [
ScaleArgInfo<ROCDL_V8F16Type, "F16">,
ScaleArgInfo<ROCDL_V8BF16Type, "Bf16">,
ScaleArgInfo<ROCDL_V8F32Type, "F32">,
] in {
def ROCDL_CvtPkScalePk8 # largeT.nameForOp # smallT.nameForOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk8." # largeT.name # "." # smallT.name,
[Pure], 1, [2], ["scaleSel"]>,
Arguments<(ins smallT.type:$src, I32:$scale, I32Attr:$scaleSel)> {

let summary = "Scales 8 " # smallT.name # " and converts them to 8 " # largeT.name # ".";
let results = (outs largeT.type:$res);
let assemblyFormat = [{
attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res)
}];
}
} // foreach largeT
} // foreach smallTOp

//===---------------------------------------------------------------------===//
// Scaled {bf6,fp6} to {bf16,f16,f32} conversion intrinsics
//===---------------------------------------------------------------------===//

foreach smallT = [
ScaleArgInfo<ROCDL_V3I32Type, "Fp6">,
ScaleArgInfo<ROCDL_V3I32Type, "Bf6">
] in {
foreach largeT = [
ScaleArgInfo<ROCDL_V16F16Type, "F16">,
ScaleArgInfo<ROCDL_V16BF16Type, "Bf16">,
ScaleArgInfo<ROCDL_V16F32Type, "F32">,
] in {
def ROCDL_CvtPkScalePk16 # largeT.nameForOp # smallT.nameForOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk16." # largeT.name # "." # smallT.name,
[Pure], 1, [2], ["scaleSel"]>,
Arguments<(ins smallT.type:$src, I32:$scale, I32Attr:$scaleSel)> {

let summary = "Scales 16 " # smallT.name # " and converts them to 16 " # largeT.name # ".";
let results = (outs largeT.type:$res);
let assemblyFormat = [{
attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res)
}];
}
} // foreach largeT
} // foreach smallTOp

//===---------------------------------------------------------------------===//
// Scaled 32x6-bit float float conversion intrinsics
//===---------------------------------------------------------------------===//
Expand Down
51 changes: 51 additions & 0 deletions mlir/test/Dialect/LLVMIR/rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,57 @@ llvm.func @rocdl.permlane32.swap(%src : i32) -> !llvm.struct<(i32, i32)> {

// -----

// CHECK-LABEL: rocdl.cvt.scale.pk8
llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {

// CHECK: rocdl.cvt.scale.pk8.f16.fp4
%0 = rocdl.cvt.scale.pk8.f16.fp4 %i32, %scale[0] : vector<8xf16>
// CHECK: rocdl.cvt.scale.pk8.bf16.fp4
%1 = rocdl.cvt.scale.pk8.bf16.fp4 %i32, %scale[0] : vector<8xbf16>
// CHECK: rocdl.cvt.scale.pk8.f32.fp4
%2 = rocdl.cvt.scale.pk8.f32.fp4 %i32, %scale[0] : vector<8xf32>

// CHECK: rocdl.cvt.scale.pk8.f16.fp8
%3 = rocdl.cvt.scale.pk8.f16.fp8 %v2xi32, %scale[0] : vector<8xf16>
// CHECK: rocdl.cvt.scale.pk8.bf16.fp8
%4 = rocdl.cvt.scale.pk8.bf16.fp8 %v2xi32, %scale[0] : vector<8xbf16>
// CHECK: rocdl.cvt.scale.pk8.f32.fp8
%5 = rocdl.cvt.scale.pk8.f32.fp8 %v2xi32, %scale[0] : vector<8xf32>

// CHECK: rocdl.cvt.scale.pk8.f16.bf8
%6 = rocdl.cvt.scale.pk8.f16.bf8 %v2xi32, %scale[0] : vector<8xf16>
// CHECK: rocdl.cvt.scale.pk8.bf16.bf8
%7 = rocdl.cvt.scale.pk8.bf16.bf8 %v2xi32, %scale[0] : vector<8xbf16>
// CHECK: rocdl.cvt.scale.pk8.f32.bf8
%8 = rocdl.cvt.scale.pk8.f32.bf8 %v2xi32, %scale[0] : vector<8xf32>

llvm.return
}

// -----

// CHECK-LABEL: rocdl.cvt.scale.pk16
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {

// CHECK: rocdl.cvt.scale.pk16.f16.fp6
%0 = rocdl.cvt.scale.pk16.f16.fp6 %v3xi32, %scale[0] : vector<16xf16>
// CHECK: rocdl.cvt.scale.pk16.bf16.fp6
%1 = rocdl.cvt.scale.pk16.bf16.fp6 %v3xi32, %scale[0] : vector<16xbf16>
// CHECK: rocdl.cvt.scale.pk16.f32.fp6
%2 = rocdl.cvt.scale.pk16.f32.fp6 %v3xi32, %scale[0] : vector<16xf32>

// CHECK: rocdl.cvt.scale.pk16.f16.bf6
%3 = rocdl.cvt.scale.pk16.f16.bf6 %v3xi32, %scale[0] : vector<16xf16>
// CHECK: rocdl.cvt.scale.pk16.bf16.bf6
%4 = rocdl.cvt.scale.pk16.bf16.bf6 %v3xi32, %scale[0] : vector<16xbf16>
// CHECK: rocdl.cvt.scale.pk16.f32.bf6
%5 = rocdl.cvt.scale.pk16.f32.bf6 %v3xi32, %scale[0] : vector<16xf32>

llvm.return
}

// -----

// expected-error@below {{attribute attached to unexpected op}}
func.func private @expected_llvm_func() attributes { rocdl.kernel }

Expand Down
48 changes: 48 additions & 0 deletions mlir/test/Target/LLVMIR/rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,54 @@ llvm.func @rocdl_last_use(%ptr: !llvm.ptr<1>) -> i32 {
llvm.return %ret : i32
}

// CHECK-LABEL: rocdl.cvt.scale.pk8
// CHECK-SAME:(i32 %[[I32:.+]], <2 x i32> %[[V2I32:.+]], i32 %[[SCALE:.+]])
llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {

// CHECK: call <8 x half> @llvm.amdgcn.cvt.scale.pk8.f16.fp4(i32 %[[I32]], i32 %[[SCALE]], i32 0)
%0 = rocdl.cvt.scale.pk8.f16.fp4 %i32, %scale[0] : vector<8xf16>
// CHECK: call <8 x bfloat> @llvm.amdgcn.cvt.scale.pk8.bf16.fp4(i32 %[[I32]], i32 %[[SCALE]], i32 0)
%1 = rocdl.cvt.scale.pk8.bf16.fp4 %i32, %scale[0] : vector<8xbf16>
// CHECK: call <8 x float> @llvm.amdgcn.cvt.scale.pk8.f32.fp4(i32 %[[I32]], i32 %[[SCALE]], i32 0)
%2 = rocdl.cvt.scale.pk8.f32.fp4 %i32, %scale[0] : vector<8xf32>

// CHECK: call <8 x half> @llvm.amdgcn.cvt.scale.pk8.f16.fp8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0)
%3 = rocdl.cvt.scale.pk8.f16.fp8 %v2xi32, %scale[0] : vector<8xf16>
// CHECK: call <8 x bfloat> @llvm.amdgcn.cvt.scale.pk8.bf16.fp8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0)
%4 = rocdl.cvt.scale.pk8.bf16.fp8 %v2xi32, %scale[0] : vector<8xbf16>
// CHECK: call <8 x float> @llvm.amdgcn.cvt.scale.pk8.f32.fp8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0)
%5 = rocdl.cvt.scale.pk8.f32.fp8 %v2xi32, %scale[0] : vector<8xf32>

// CHECK: call <8 x half> @llvm.amdgcn.cvt.scale.pk8.f16.bf8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0)
%6 = rocdl.cvt.scale.pk8.f16.bf8 %v2xi32, %scale[0] : vector<8xf16>
// CHECK: call <8 x bfloat> @llvm.amdgcn.cvt.scale.pk8.bf16.bf8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0)
%7 = rocdl.cvt.scale.pk8.bf16.bf8 %v2xi32, %scale[0] : vector<8xbf16>
// CHECK: call <8 x float> @llvm.amdgcn.cvt.scale.pk8.f32.bf8(<2 x i32> %[[V2I32]], i32 %[[SCALE]], i32 0)
%8 = rocdl.cvt.scale.pk8.f32.bf8 %v2xi32, %scale[0] : vector<8xf32>

llvm.return
}

// CHECK-LABEL: @rocdl.cvt.scale.pk16
// CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]])
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {

// CHECK: call <16 x half> @llvm.amdgcn.cvt.scale.pk16.f16.fp6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
%0 = rocdl.cvt.scale.pk16.f16.fp6 %v3xi32, %scale[0] : vector<16xf16>
// CHECK: call <16 x bfloat> @llvm.amdgcn.cvt.scale.pk16.bf16.fp6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
%1 = rocdl.cvt.scale.pk16.bf16.fp6 %v3xi32, %scale[0] : vector<16xbf16>
// CHECK: call <16 x float> @llvm.amdgcn.cvt.scale.pk16.f32.fp6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
%2 = rocdl.cvt.scale.pk16.f32.fp6 %v3xi32, %scale[0] : vector<16xf32>
// CHECK: call <16 x half> @llvm.amdgcn.cvt.scale.pk16.f16.bf6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
%3 = rocdl.cvt.scale.pk16.f16.bf6 %v3xi32, %scale[0] : vector<16xf16>
// CHECK: call <16 x bfloat> @llvm.amdgcn.cvt.scale.pk16.bf16.bf6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
%4 = rocdl.cvt.scale.pk16.bf16.bf6 %v3xi32, %scale[0] : vector<16xbf16>
// CHECK: call <16 x float> @llvm.amdgcn.cvt.scale.pk16.f32.bf6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
%5 = rocdl.cvt.scale.pk16.f32.bf6 %v3xi32, %scale[0] : vector<16xf32>

llvm.return
}

// CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "uniform-work-group-size"="true" }
// CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
// CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"
Expand Down