Skip to content

Commit 79f1484

Browse files
Removing unnecessary type changes
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 2ce03cb commit 79f1484

File tree

2 files changed

+67
-23
lines changed

2 files changed

+67
-23
lines changed

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4085,11 +4085,11 @@ class AMDGPUWmmaScaleF4IntrinsicModsC<LLVMType scale_ty> :
40854085

40864086
defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = {
40874087
def int_amdgcn_wmma_f32_16x16x4_f32 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
4088-
def int_amdgcn_wmma_f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyint_ty, llvm_anyfloat_ty>;
4088+
def int_amdgcn_wmma_f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
40894089
def int_amdgcn_wmma_f32_16x16x32_f16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
40904090
def int_amdgcn_wmma_f16_16x16x32_f16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
4091-
def int_amdgcn_wmma_bf16_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyint_ty, llvm_anyint_ty>;
4092-
def int_amdgcn_wmma_bf16f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllDiff<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>;
4091+
def int_amdgcn_wmma_bf16_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
4092+
def int_amdgcn_wmma_bf16f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllDiff<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty>;
40934093
def int_amdgcn_wmma_f32_16x16x64_fp8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
40944094
def int_amdgcn_wmma_f32_16x16x64_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
40954095
def int_amdgcn_wmma_f32_16x16x64_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -873,11 +873,10 @@ llvm.func @rocdl.mfma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
873873

874874
llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : vector<16 x i16>, %arg3 : vector<8 x i32>,
875875
%arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>,
876-
%arg9 : vector<32xf16>, %arg10 : vector<16xf32>, %arg11 : vector<4xf32>, %arg12 : vector<32xf32>,
877-
%arg13 : vector<16xi32>, %arg14 : vector<64xf32>, %arg15 : vector<64xi32>, %arg16 : i32) -> vector<8xf32> {
876+
%arg9 : vector<32xf16>, %arg10 : vector<16xf32>, %arg11 : vector<4xf32>, %arg12 : vector<32xf32>, %arg13 : vector<64xf32>,
877+
%arg14 : vector<64xi32>, %arg15 : vector<64xf16>, %arg16 : vector<16xbf16>, %arg17 : vector<32xbf16>) -> vector<8xf32> {
878878
%zero = llvm.mlir.constant(false) : i1
879879
%zero_i16 = llvm.mlir.constant(0 : i16) : i16
880-
%zero_i32 = llvm.mlir.constant(0 : i32) : i32
881880
// ---- Wave32 -----
882881

883882
// f16 -> f32
@@ -909,36 +908,81 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
909908
%r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
910909

911910
// f32 -> f32
912-
// CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 false, <16 x float> %10, i1 false, <16 x float> %10, i16 0, <4 x float> %11, i1 false, i1 false)
911+
// CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 {{.*}}, <16 x float> %{{.*}}, i1 {{.*}}, <16 x float> %{{.*}}, i16 0, <4 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
913912
%r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %zero, %arg10, %zero, %arg10, %zero_i16, %arg11, %zero, %zero : (i1, vector<16xf32>, i1, vector<16xf32>, i16, vector<4xf32>, i1, i1) -> vector<4xf32>
914913

915-
// bf16 -> f32
916-
// CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16i16(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <32 x float> %12, i1 false, i1 false)
917-
%r2.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero, %arg2, %zero, %arg2, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xi16>, i1, vector<16xi16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
918-
919914
// f16 -> f32
920-
// CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %1, i1 false, <16 x half> %1, i16 0, <32 x float> %12, i1 false, i1 false)
921-
%r3.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
915+
// CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
916+
%r2.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
917+
918+
// bf16 -> f32
919+
// CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
920+
%r3.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
922921

923922
// f16 -> f16
924-
// CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 false, <16 x half> %1, i1 false, <16 x half> %1, i16 0, <32 x half> %9, i1 false, i1 false)
923+
// CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
925924
%r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg9, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf16>, i1, i1) -> vector<32xf16>
926925

927926
// bf16 -> bf16
928-
// CHECK: call <16 x i32> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v16i32.v16i16(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <16 x i32> %13, i1 false, i1 false)
929-
%r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero, %arg2, %zero, %arg2, %zero_i16, %arg13, %zero, %zero : (i1, vector<16xi16>, i1, vector<16xi16>, i16, vector<16xi32>, i1, i1) -> vector<16xi32>
927+
// CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v32bf16.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x bfloat> %{{.*}}, i1 {{.*}}, i1 {{.*}})
928+
%r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg17, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xbf16>, i1, i1) -> vector<32xbf16>
930929

931930
// bf16 -> bf16 / f32
932-
// CHECK: call <16 x i32> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v16i32.v16i16.v32f32(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <32 x float> %12, i1 false, i1 false)
933-
%r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero, %arg2, %zero, %arg2, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xi16>, i1, vector<16xi16>, i16, vector<32xf32>, i1, i1) -> vector<16xi32>
931+
// CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v32bf16.v16bf16.v32f32(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
932+
%r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xf32>, i1, i1) -> vector<32xbf16>
933+
934+
// f8/bf8 -> f16/f32
935+
// CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
936+
%r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
937+
938+
// CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
939+
%r8.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
940+
941+
// CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
942+
%r9.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
943+
944+
// CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
945+
%r10.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
946+
947+
// CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
948+
%r11.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
949+
950+
// CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
951+
%r12.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
952+
953+
// CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
954+
%r13.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
955+
956+
// CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
957+
%r14.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
958+
959+
// CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
960+
%r15.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
961+
962+
// CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
963+
%r16.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
964+
965+
// CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
966+
%r17.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
967+
968+
// CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
969+
%r18.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
970+
971+
// CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
972+
%r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
973+
974+
// CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
975+
%r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
976+
977+
// CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
978+
%r21.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
934979

935-
// f8 -> f32
936-
// CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %5, <4 x i32> %5, i16 0, <64 x float> %14, i1 false, i1 false)
937-
%r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg14, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
980+
// CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
981+
%r22.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
938982

939983
// iu8 -> i32
940-
// CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %5, i1 false, <4 x i32> %5, <64 x i32> %15, i1 false, i1 false)
941-
%r8.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero, %arg5, %zero, %arg5, %arg15, %zero, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<64xi32>, i1, i1) -> vector<64xi32>
984+
// CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <64 x i32> %{{.*}}, i1 {{.*}}, i1 {{.*}})
985+
%r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero, %arg5, %zero, %arg5, %arg14, %zero, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<64xi32>, i1, i1) -> vector<64xi32>
942986

943987
// ---- Wave64 -----
944988

0 commit comments

Comments
 (0)