@@ -873,11 +873,10 @@ llvm.func @rocdl.mfma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
873873
874874llvm.func @rocdl.wmma (%arg0 : vector <8 xf32 >, %arg1 : vector <16 x f16 >, %arg2 : vector <16 x i16 >, %arg3 : vector <8 x i32 >,
875875 %arg4 : vector <2 xi32 >, %arg5 : vector <4 xi32 >, %arg6 : vector <4 xf32 >, %arg7 : vector <8 xf16 >, %arg8 : vector <8 xi16 >,
876- %arg9 : vector <32 xf16 >, %arg10 : vector <16 xf32 >, %arg11 : vector <4 xf32 >, %arg12 : vector <32 xf32 >,
877- %arg13 : vector <16 x i32 >, %arg14 : vector <64 x f32 >, %arg15 : vector <64 x i32 >, %arg16 : i32 ) -> vector <8 xf32 > {
876+ %arg9 : vector <32 xf16 >, %arg10 : vector <16 xf32 >, %arg11 : vector <4 xf32 >, %arg12 : vector <32 xf32 >, %arg13 : vector < 64 x f32 >,
877+ %arg14 : vector <64 x i32 >, %arg15 : vector <64 x f16 >, %arg16 : vector <16 x bf16 >, %arg17 : vector < 32 x bf16 > ) -> vector <8 xf32 > {
878878 %zero = llvm.mlir.constant (false ) : i1
879879 %zero_i16 = llvm.mlir.constant (0 : i16 ) : i16
880- %zero_i32 = llvm.mlir.constant (0 : i32 ) : i32
881880 // ---- Wave32 -----
882881
883882 // f16 -> f32
@@ -909,36 +908,81 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
909908 %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %zero , %arg4 , %zero , %arg4 , %arg3 , %zero : (i1 , vector <2 xi32 >, i1 , vector <2 xi32 >, vector <8 xi32 >, i1 ) -> vector <8 xi32 >
910909
911910 // f32 -> f32
912- // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 false , <16 x float> %10 , i1 false , <16 x float> %10 , i16 0, <4 x float> %11 , i1 false , i1 false )
911+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 {{.*}} , <16 x float> %{{.*}} , i1 {{.*}} , <16 x float> %{{.*}} , i16 0, <4 x float> %{{.*}} , i1 {{.*}} , i1 {{.*}} )
913912 %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %zero , %arg10 , %zero , %arg10 , %zero_i16 , %arg11 , %zero , %zero : (i1 , vector <16 xf32 >, i1 , vector <16 xf32 >, i16 , vector <4 xf32 >, i1 , i1 ) -> vector <4 xf32 >
914913
915- // bf16 -> f32
916- // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16i16(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <32 x float> %12, i1 false, i1 false)
917- %r2.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero , %arg2 , %zero , %arg2 , %zero_i16 , %arg12 , %zero , %zero : (i1 , vector <16 xi16 >, i1 , vector <16 xi16 >, i16 , vector <32 xf32 >, i1 , i1 ) -> vector <32 xf32 >
918-
919914 // f16 -> f32
920- // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %1, i1 false, <16 x half> %1, i16 0, <32 x float> %12, i1 false, i1 false)
921- %r3.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero , %arg1 , %zero , %arg1 , %zero_i16 , %arg12 , %zero , %zero : (i1 , vector <16 xf16 >, i1 , vector <16 xf16 >, i16 , vector <32 xf32 >, i1 , i1 ) -> vector <32 xf32 >
915+ // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
916+ %r2.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero , %arg1 , %zero , %arg1 , %zero_i16 , %arg12 , %zero , %zero : (i1 , vector <16 xf16 >, i1 , vector <16 xf16 >, i16 , vector <32 xf32 >, i1 , i1 ) -> vector <32 xf32 >
917+
918+ // bf16 -> f32
919+ // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
920+ %r3.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero , %arg16 , %zero , %arg16 , %zero_i16 , %arg12 , %zero , %zero : (i1 , vector <16 xbf16 >, i1 , vector <16 xbf16 >, i16 , vector <32 xf32 >, i1 , i1 ) -> vector <32 xf32 >
922921
923922 // f16 -> f16
924- // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 false , <16 x half> %1 , i1 false , <16 x half> %1 , i16 0, <32 x half> %9 , i1 false , i1 false )
923+ // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 {{.*}} , <16 x half> %{{.*}} , i1 {{.*}} , <16 x half> %{{.*}} , i16 0, <32 x half> %{{.*}} , i1 {{.*}} , i1 {{.*}} )
925924 %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %zero , %arg1 , %zero , %arg1 , %zero_i16 , %arg9 , %zero , %zero : (i1 , vector <16 xf16 >, i1 , vector <16 xf16 >, i16 , vector <32 xf16 >, i1 , i1 ) -> vector <32 xf16 >
926925
927926 // bf16 -> bf16
928- // CHECK: call <16 x i32 > @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v16i32.v16i16 (i1 false , <16 x i16 > %2 , i1 false , <16 x i16 > %2 , i16 0, <16 x i32 > %13 , i1 false , i1 false )
929- %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero , %arg2 , %zero , %arg2 , %zero_i16 , %arg13 , %zero , %zero : (i1 , vector <16 x i16 >, i1 , vector <16 x i16 >, i16 , vector <16 x i32 >, i1 , i1 ) -> vector <16 x i32 >
927+ // CHECK: call <32 x bfloat > @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v32bf16.v16bf16 (i1 {{.*}} , <16 x bfloat > %{{.*}} , i1 {{.*}} , <16 x bfloat > %{{.*}} , i16 0, <32 x bfloat > %{{.*}} , i1 {{.*}} , i1 {{.*}} )
928+ %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero , %arg16 , %zero , %arg16 , %zero_i16 , %arg17 , %zero , %zero : (i1 , vector <16 x bf16 >, i1 , vector <16 x bf16 >, i16 , vector <32 x bf16 >, i1 , i1 ) -> vector <32 x bf16 >
930929
931930 // bf16 -> bf16 / f32
932- // CHECK: call <16 x i32> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v16i32.v16i16.v32f32(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <32 x float> %12, i1 false, i1 false)
933- %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero , %arg2 , %zero , %arg2 , %zero_i16 , %arg12 , %zero , %zero : (i1 , vector <16 xi16 >, i1 , vector <16 xi16 >, i16 , vector <32 xf32 >, i1 , i1 ) -> vector <16 xi32 >
931+ // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v32bf16.v16bf16.v32f32(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
932+ %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero , %arg16 , %zero , %arg16 , %zero_i16 , %arg12 , %zero , %zero : (i1 , vector <16 xbf16 >, i1 , vector <16 xbf16 >, i16 , vector <32 xf32 >, i1 , i1 ) -> vector <32 xbf16 >
933+
934+ // f8/bf8 -> f16/f32
935+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
936+ %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
937+
938+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
939+ %r8.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_bf8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
940+
941+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
942+ %r9.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_fp8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
943+
944+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
945+ %r10.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_bf8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
946+
947+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
948+ %r11.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_fp8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
949+
950+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
951+ %r12.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_bf8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
952+
953+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
954+ %r13.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_fp8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
955+
956+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
957+ %r14.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_bf8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
958+
959+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
960+ %r15.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_fp8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
961+
962+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
963+ %r16.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_bf8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
964+
965+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
966+ %r17.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_fp8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
967+
968+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
969+ %r18.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_bf8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
970+
971+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
972+ %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
973+
974+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
975+ %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
976+
977+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
978+ %r21.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_fp8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
934979
935- // f8 -> f32
936- // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %5, <4 x i32> %5, i16 0, <64 x float> %14, i1 false, i1 false)
937- %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5 , %arg5 , %zero_i16 , %arg14 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
980+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
981+ %r22.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_bf8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
938982
939983 // iu8 -> i32
940- // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false , <4 x i32> %5 , i1 false , <4 x i32> %5 , <64 x i32> %15 , i1 false , i1 false )
941- %r8 .gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero , %arg5 , %zero , %arg5 , %arg15 , %zero , %zero : (i1 , vector <4 xi32 >, i1 , vector <4 xi32 >, vector <64 xi32 >, i1 , i1 ) -> vector <64 xi32 >
984+ // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 {{.*}} , <4 x i32> %{{.*}} , i1 {{.*}} , <4 x i32> %{{.*}} , <64 x i32> %{{.*}} , i1 {{.*}} , i1 {{.*}} )
985+ %r23 .gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero , %arg5 , %zero , %arg5 , %arg14 , %zero , %zero : (i1 , vector <4 xi32 >, i1 , vector <4 xi32 >, vector <64 xi32 >, i1 , i1 ) -> vector <64 xi32 >
942986
943987 // ---- Wave64 -----
944988
0 commit comments