- 
                Notifications
    You must be signed in to change notification settings 
- Fork 14.9k
[mlir][rocdl] Add gfx1250+ cvt scale intrinsics #159649
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][rocdl] Add gfx1250+ cvt scale intrinsics #159649
Conversation
| @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Erick Ochoa Lopez (amd-eochoalo) ChangesFull diff: https://github.com/llvm/llvm-project/pull/159649.diff 3 Files Affected: 
 diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 9fa3ec1fc4b21..1252d8589cc63 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -835,10 +835,17 @@ class ROCDL_ConcreteVector<Type elem, int length> :
 
 def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
 def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
+def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
 def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
 def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
+def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
 def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
 def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
+def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
+def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
+def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
+def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
+def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
 def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
 def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
 def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
@@ -975,6 +982,61 @@ class ScaleArgInfo<TypeConstraint argTyVal, string typeName> {
   string nameForOp = typeName;
 }
 
+//===---------------------------------------------------------------------===//
+// Scaled {fp4,bf8,fp8} to {bf16,f16,f32} conversion intrinsics
+//===---------------------------------------------------------------------===//
+
+foreach smallT = [
+  ScaleArgInfo<I32, "Fp4">,
+  ScaleArgInfo<ROCDL_V2I32Type, "Fp8">,
+  ScaleArgInfo<ROCDL_V2I32Type, "Bf8">
+] in {
+  foreach largeT = [
+    ScaleArgInfo<ROCDL_V8F16Type, "F16">,
+    ScaleArgInfo<ROCDL_V8BF16Type, "Bf16">,
+    ScaleArgInfo<ROCDL_V8F32Type, "F32">,
+  ] in {
+    def ROCDL_CvtPkScalePk8 # largeT.nameForOp # smallT.nameForOp # Op :
+          ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk8." # largeT.name # "." # smallT.name,
+          [Pure], 1, [2], ["scaleSel"]>,
+        Arguments<(ins smallT.type:$src, I32:$scale, I32Attr:$scaleSel)> {
+
+      let summary = "Scales 8 " # smallT.name # " and converts them to 8 " # largeT.name # ".";
+      let results = (outs largeT.type:$res);
+      let assemblyFormat = [{
+        attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res)
+      }];
+    }
+  } // foreach largeT
+} // foreach smallTOp
+
+//===---------------------------------------------------------------------===//
+// Scaled {bf6,fp6} to {bf16,f16,f32} conversion intrinsics
+//===---------------------------------------------------------------------===//
+
+foreach smallT = [
+  ScaleArgInfo<ROCDL_V3I32Type, "Fp6">,
+  ScaleArgInfo<ROCDL_V3I32Type, "Bf6">
+] in {
+  foreach largeT = [
+    ScaleArgInfo<ROCDL_V16F16Type, "F16">,
+    ScaleArgInfo<ROCDL_V16BF16Type, "Bf16">,
+    ScaleArgInfo<ROCDL_V16F32Type, "F32">,
+  ] in {
+    def ROCDL_CvtPkScalePk16 # largeT.nameForOp # smallT.nameForOp # Op :
+          ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk16." # largeT.name # "." # smallT.name,
+          [Pure], 1, [2], ["scaleSel"]>,
+        Arguments<(ins smallT.type:$src, I32:$scale, I32Attr:$scaleSel)> {
+
+      let summary = "Scales 16 " # smallT.name # " and converts them to 16 " # largeT.name # ".";
+      let results = (outs largeT.type:$res);
+      let assemblyFormat = [{
+        attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res)
+      }];
+    }
+  } // foreach largeT
+} // foreach smallTOp
+
 //===---------------------------------------------------------------------===//
 // Scaled 32x6-bit float float conversion intrinsics
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 782ef4e154440..959bb35302b20 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -1025,6 +1025,57 @@ llvm.func @rocdl.permlane32.swap(%src : i32) -> !llvm.struct<(i32, i32)> {
 
 // -----
 
+// CHECK-LABEL: rocdl.cvt.scale.pk8
+llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {
+
+  // CHECK: rocdl.cvt.scale.pk8.f16.fp4
+  %0 =      rocdl.cvt.scale.pk8.f16.fp4 %i32, %scale[0] : vector<8xf16>
+  // CHECK: rocdl.cvt.scale.pk8.bf16.fp4
+  %1 =      rocdl.cvt.scale.pk8.bf16.fp4 %i32, %scale[0] : vector<8xbf16>
+  // CHECK: rocdl.cvt.scale.pk8.f32.fp4
+  %2 =      rocdl.cvt.scale.pk8.f32.fp4 %i32, %scale[0] : vector<8xf32>
+
+  // CHECK: rocdl.cvt.scale.pk8.f16.fp8
+  %3 =      rocdl.cvt.scale.pk8.f16.fp8 %v2xi32, %scale[0] : vector<8xf16>
+  // CHECK: rocdl.cvt.scale.pk8.bf16.fp8
+  %4 =      rocdl.cvt.scale.pk8.bf16.fp8 %v2xi32, %scale[0] : vector<8xbf16>
+  // CHECK: rocdl.cvt.scale.pk8.f32.fp8
+  %5 =      rocdl.cvt.scale.pk8.f32.fp8 %v2xi32, %scale[0] : vector<8xf32>
+
+  // CHECK: rocdl.cvt.scale.pk8.f16.bf8
+  %6 =      rocdl.cvt.scale.pk8.f16.bf8 %v2xi32, %scale[0] : vector<8xf16>
+  // CHECK: rocdl.cvt.scale.pk8.bf16.bf8
+  %7 =      rocdl.cvt.scale.pk8.bf16.bf8 %v2xi32, %scale[0] : vector<8xbf16>
+  // CHECK: rocdl.cvt.scale.pk8.f32.bf8
+  %8 =      rocdl.cvt.scale.pk8.f32.bf8 %v2xi32, %scale[0] : vector<8xf32>
+
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: rocdl.cvt.scale.pk16
+llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
+
+  // CHECK: rocdl.cvt.scale.pk16.f16.fp6
+  %0 =      rocdl.cvt.scale.pk16.f16.fp6 %v3xi32, %scale[0] : vector<16xf16>
+  // CHECK: rocdl.cvt.scale.pk16.bf16.fp6
+  %1 =      rocdl.cvt.scale.pk16.bf16.fp6 %v3xi32, %scale[0] : vector<16xbf16>
+  // CHECK: rocdl.cvt.scale.pk16.f32.fp6
+  %2 =      rocdl.cvt.scale.pk16.f32.fp6 %v3xi32, %scale[0] : vector<16xf32>
+
+  // CHECK: rocdl.cvt.scale.pk16.f16.bf6
+  %3 =      rocdl.cvt.scale.pk16.f16.bf6 %v3xi32, %scale[0] : vector<16xf16>
+  // CHECK: rocdl.cvt.scale.pk16.bf16.bf6
+  %4 =      rocdl.cvt.scale.pk16.bf16.bf6 %v3xi32, %scale[0] : vector<16xbf16>
+  // CHECK: rocdl.cvt.scale.pk16.f32.bf6
+  %5 =      rocdl.cvt.scale.pk16.f32.bf6 %v3xi32, %scale[0] : vector<16xf32>
+
+  llvm.return
+}
+
+// -----
+
 // expected-error@below {{attribute attached to unexpected op}}
 func.func private @expected_llvm_func() attributes { rocdl.kernel }
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index a464358250c38..bf18db99f6cf2 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1298,6 +1298,54 @@ llvm.func @rocdl_last_use(%ptr: !llvm.ptr<1>) -> i32 {
   llvm.return %ret : i32
 }
 
+// CHECK-LABEL: rocdl.cvt.scale.pk8
+// CHECK-SAME:(i32 %[[I32:.+]], <2 x i32> %[[V2I32:.+]], i32 [[SCALE:.+]])
+llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {
+
+  // CHECK:   call <8 x half> @llvm.amdgcn.cvt.scale.pk8.f16.fp4(i32 [[I32]], i32 [[SCALE]], i32 0)
+  %0 =                               rocdl.cvt.scale.pk8.f16.fp4 %i32, %scale[0] : vector<8xf16>
+  // CHECK: call <8 x bfloat> @llvm.amdgcn.cvt.scale.pk8.bf16.fp4(i32 [[I32]], i32 [[SCALE]], i32 0)
+  %1 =                               rocdl.cvt.scale.pk8.bf16.fp4 %i32, %scale[0] : vector<8xbf16>
+  // CHECK:  call <8 x float> @llvm.amdgcn.cvt.scale.pk8.f16.fp4(i32 [[I32]], i32 [[SCALE]], i32 0)
+  %2 =                               rocdl.cvt.scale.pk8.f32.fp4 %i32, %scale[0] : vector<8xf32>
+
+  // CHECK:   call <8 x half> @llvm.amdgcn.cvt.scale.pk8.f16.fp8(<2 x i32> [[V2I32]], i32 [[SCALE]], i32 0)
+  %3 =                               rocdl.cvt.scale.pk8.f16.fp8 %v2xi32, %scale[0] : vector<8xf16>
+  // CHECK: call <8 x bfloat> @llvm.amdgcn.cvt.scale.pk8.bf16.fp8(<2 x i32> [[V2I32]], i32 [[SCALE]], i32 0)
+  %4 =                               rocdl.cvt.scale.pk8.bf16.fp8 %v2xi32, %scale[0] : vector<8xbf16>
+  // CHECK:   call <8 x float> @llvm.amdgcn.cvt.scale.pk8.f16.fp8(<2 x i32> [[V2I32]], i32 [[SCALE]], i32 0)
+  %5 =                                rocdl.cvt.scale.pk8.f32.fp8 %v2xi32, %scale[0] : vector<8xf32>
+
+  // CHECK:   call <8 x half> @llvm.amdgcn.cvt.scale.pk8.f16.bf8(<2 x i32> [[V2I32]], i32 [[SCALE]], i32 0)
+  %6 =                               rocdl.cvt.scale.pk8.f16.bf8 %v2xi32, %scale[0] : vector<8xf16>
+  // CHECK: call <8 x bfloat> @llvm.amdgcn.cvt.scale.pk8.bf16.bf8(<2 x i32> [[V2I32]], i32 [[SCALE]], i32 0)
+  %7 =                               rocdl.cvt.scale.pk8.bf16.bf8 %v2xi32, %scale[0] : vector<8xbf16>
+  // CHECK:   call <8 x float> @llvm.amdgcn.cvt.scale.pk8.f16.bf8(<2 x i32> [[V2I32]], i32 [[SCALE]], i32 0)
+  %8 =                                rocdl.cvt.scale.pk8.f32.bf8 %v2xi32, %scale[0] : vector<8xf32>
+
+  llvm.return
+}
+
+// CHECK-LABEL: @rocdl.cvt.scale.pk16
+// CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]])
+llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
+
+  // CHECK:   call <16 x half> @llvm.amdgcn.cvt.scale.pk16.f16.fp6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
+  %0 =                                rocdl.cvt.scale.pk16.f16.fp6 %v3xi32, %scale[0] : vector<16xf16>
+  // CHECK: call <16 x bfloat> @llvm.amdgcn.cvt.scale.pk16.bf16.fp6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
+  %1 =                                rocdl.cvt.scale.pk16.bf16.fp6 %v3xi32, %scale[0] : vector<16xbf16>
+  // CHECK:  call <16 x float> @llvm.amdgcn.cvt.scale.pk16.f32.fp6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
+  %2 =                                rocdl.cvt.scale.pk16.f32.fp6 %v3xi32, %scale[0] : vector<16xf32>
+  // CHECK:   call <16 x half> @llvm.amdgcn.cvt.scale.pk16.f16.bf6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
+  %3 =                                rocdl.cvt.scale.pk16.f16.bf6 %v3xi32, %scale[0] : vector<16xf16>
+  // CHECK: call <16 x bfloat> @llvm.amdgcn.cvt.scale.pk16.bf16.bf6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
+  %4 =                                rocdl.cvt.scale.pk16.bf16.bf6 %v3xi32, %scale[0] : vector<16xbf16>
+  // CHECK:  call <16 x float> @llvm.amdgcn.cvt.scale.pk16.f32.bf6(<3 x i32> %[[SRC0]], i32 %[[SCALE]], i32 0)
+  %5 =                                rocdl.cvt.scale.pk16.f32.bf6 %v3xi32, %scale[0] : vector<16xf32>
+
+  llvm.return
+}
+
 // CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "uniform-work-group-size"="true" }
 // CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
 // CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"
 | 
b1c72d4    to
    a99f2fb      
    Compare
  
    There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add to the description that these are available on gfx1250+, similar to https://mlir.llvm.org/docs/Dialects/ROCDLDialect/#rocdlswaitdscnt-rocdlwaitdscntop ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea to document the supported chips, as was already mentioned, but overall this all makes sense
| @kuhar added availability to the description. Thanks! | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Maybe also add it to the PR subject line? [mlir][rocdl] Add gfx1250+ cvt scale intrinsics
| It would be nice to be able to find all the gfx1250 bits by searching git log | 
No description provided.