Skip to content

Commit ab6cfd9

Browse files
wmma scales intrinsics
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 048d50d commit ab6cfd9

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,10 @@ def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8
516516
// WMMA intrinsic ops: each record instantiates ROCDL_Wmma_IntrOp with an op
// mnemonic and an integer list.
// NOTE(review): the list ([0] / [1]) is presumably the overloaded-operand
// index list used for intrinsic name mangling -- confirm against the
// ROCDL_Wmma_IntrOp class definition, which is not visible in this hunk.
def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
517517
def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
518518
def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
519+
// Scaled WMMA intrinsic ops: "scale" and "scale16" variants for the
// 16x16x128 f8f6f4 and 32x16x128 f4 shapes, all producing f32.
// NOTE(review): unlike the neighbouring WMMA records, these four pass only the
// mnemonic to ROCDL_Wmma_IntrOp and omit the integer list argument that every
// sibling supplies ([0] / [1]). Confirm the class declares a default for that
// parameter; if it does not, these records will fail to build and each needs
// an explicit list matching the overloaded operands of its LLVM intrinsic.
def ROCDL_wmma_scale_f32_16x16x128_f8f6f4 : ROCDL_Wmma_IntrOp<"wmma.scale.f32.16x16x128.f8f6f4">;
520+
def ROCDL_wmma_scale16_f32_16x16x128_f8f6f4 : ROCDL_Wmma_IntrOp<"wmma.scale16.f32.16x16x128.f8f6f4">;
521+
def ROCDL_wmma_scale_f32_32x16x128_f4 : ROCDL_Wmma_IntrOp<"wmma.scale.f32.32x16x128.f4">;
522+
def ROCDL_wmma_scale16_f32_32x16x128_f4 : ROCDL_Wmma_IntrOp<"wmma.scale16.f32.32x16x128.f4">;
519523

520524
//===---------------------------------------------------------------------===//
521525
// LDS transpose intrinsics (available in GFX950)

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,10 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
884884
// CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %5, i1 false, <4 x i32> %5, <64 x i32> %15, i1 false, i1 false)
885885
%r8.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero, %arg5, %zero, %arg5, %arg15, %zero, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<64xi32>, i1, i1) -> vector<64xi32>
886886

887+
%r9.gfx1250 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %zero_i32, %arg5, %zero_i32, %arg5, %zero_i16, %arg11, %zero_i32, %zero_i32, %arg16, %zero_i32, %zero_i32, %arg16, %zero, %zero : (i32, vector<4xi32>, i32, vector<4xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
888+
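// TODO: add a `// CHECK:` line for the %r9.gfx1250 scale op above, and turn the
// placeholders below into real tests. Each needs its own result name; they all
// currently reuse %r7.gfx1250, which is presumably already defined earlier in this function.
// %r7.gfx1250 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4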
889+
// %r7.gfx1250 = rocdl.wmma.scale.f32.32x16x128.f4
890+
// %r7.gfx1250 = rocdl.wmma.scale16.f32.32x16x128.f4
887891
// ---- Wave64 -----
888892

889893
// f16 -> f32

0 commit comments

Comments
 (0)