@@ -816,9 +816,12 @@ llvm.func @rocdl.mfma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
816816}
817817
818818llvm.func @rocdl.wmma (%arg0 : vector <8 xf32 >, %arg1 : vector <16 x f16 >, %arg2 : vector <16 x i16 >, %arg3 : vector <8 x i32 >,
819- %arg4 : vector <2 xi32 >, %arg5 : vector <4 xi32 >, %arg6 : vector <4 xf32 >, %arg7 : vector <8 xf16 >, %arg8 : vector <8 xi16 >) -> vector <8 xf32 > {
819+ %arg4 : vector <2 xi32 >, %arg5 : vector <4 xi32 >, %arg6 : vector <4 xf32 >, %arg7 : vector <8 xf16 >, %arg8 : vector <8 xi16 >,
820+ %arg9 : vector <32 xf16 >, %arg10 : vector <16 xf32 >, %arg11 : vector <4 xf32 >, %arg12 : vector <32 xf32 >,
821+ %arg13 : vector <16 xi32 >, %arg14 : vector <64 xf32 >, %arg15 : vector <64 xi32 >, %arg16 : i32 ) -> vector <8 xf32 > {
820822 %zero = llvm.mlir.constant (false ) : i1
821-
823+ %zero_i16 = llvm.mlir.constant (0 : i16 ) : i16
824+ %zero_i32 = llvm.mlir.constant (0 : i32 ) : i32
822825 // ---- Wave32 -----
823826
824827 // f16 -> f32
@@ -849,6 +852,38 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
849852 // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
850853 %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %zero , %arg4 , %zero , %arg4 , %arg3 , %zero : (i1 , vector <2 xi32 >, i1 , vector <2 xi32 >, vector <8 xi32 >, i1 ) -> vector <8 xi32 >
851854
855+ // f32 -> f32
856+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 false, <16 x float> %10, i1 false, <16 x float> %10, i16 0, <4 x float> %11, i1 false, i1 false)
857+ %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %zero , %arg10 , %zero , %arg10 , %zero_i16 , %arg11 , %zero , %zero : (i1 , vector <16 xf32 >, i1 , vector <16 xf32 >, i16 , vector <4 xf32 >, i1 , i1 ) -> vector <4 xf32 >
858+
859+ // bf16 -> f32
860+ // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16i16(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <32 x float> %12, i1 false, i1 false)
861+ %r2.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero , %arg2 , %zero , %arg2 , %zero_i16 , %arg12 , %zero , %zero : (i1 , vector <16 xi16 >, i1 , vector <16 xi16 >, i16 , vector <32 xf32 >, i1 , i1 ) -> vector <32 xf32 >
862+
863+ // f16 -> f32
864+ // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %1, i1 false, <16 x half> %1, i16 0, <32 x float> %12, i1 false, i1 false)
865+ %r3.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero , %arg1 , %zero , %arg1 , %zero_i16 , %arg12 , %zero , %zero : (i1 , vector <16 xf16 >, i1 , vector <16 xf16 >, i16 , vector <32 xf32 >, i1 , i1 ) -> vector <32 xf32 >
866+
867+ // f16 -> f16
868+ // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 false, <16 x half> %1, i1 false, <16 x half> %1, i16 0, <32 x half> %9, i1 false, i1 false)
869+ %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %zero , %arg1 , %zero , %arg1 , %zero_i16 , %arg9 , %zero , %zero : (i1 , vector <16 xf16 >, i1 , vector <16 xf16 >, i16 , vector <32 xf16 >, i1 , i1 ) -> vector <32 xf16 >
870+
871+ // bf16 -> bf16
872+ // CHECK: call <16 x i32> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v16i32.v16i16(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <16 x i32> %13, i1 false, i1 false)
873+ %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero , %arg2 , %zero , %arg2 , %zero_i16 , %arg13 , %zero , %zero : (i1 , vector <16 xi16 >, i1 , vector <16 xi16 >, i16 , vector <16 xi32 >, i1 , i1 ) -> vector <16 xi32 >
874+
875+ // bf16 -> bf16 / f32
876+ // CHECK: call <16 x i32> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v16i32.v16i16.v32f32(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <32 x float> %12, i1 false, i1 false)
877+ %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero , %arg2 , %zero , %arg2 , %zero_i16 , %arg12 , %zero , %zero : (i1 , vector <16 xi16 >, i1 , vector <16 xi16 >, i16 , vector <32 xf32 >, i1 , i1 ) -> vector <16 xi32 >
878+
879+ // f8 -> f32
880+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %5, <4 x i32> %5, i16 0, <64 x float> %14, i1 false, i1 false)
881+ %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5 , %arg5 , %zero_i16 , %arg14 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
882+
883+ // iu8 -> i32
884+ // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %5, i1 false, <4 x i32> %5, <64 x i32> %15, i1 false, i1 false)
885+ %r8.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero , %arg5 , %zero , %arg5 , %arg15 , %zero , %zero : (i1 , vector <4 xi32 >, i1 , vector <4 xi32 >, vector <64 xi32 >, i1 , i1 ) -> vector <64 xi32 >
886+
852887 // ---- Wave64 -----
853888
854889 // f16 -> f32
0 commit comments