@@ -62,37 +62,57 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
6262
// Translation test for the 128-bit dpbf16ps intrinsic op. The function takes
// one 4 x f32 vector and two 8 x bf16 vectors, passes all three to the
// x86vector.avx512.intr.dpbf16ps.128 op, and returns the 4 x f32 result.
// The FileCheck directives below expect the emitted LLVM IR to contain the
// matching function definition and a call to
// @llvm.x86.avx512bf16.dpbf16ps.128.
// This hunk only renames the operands: %arg0/%arg1/%arg2 become %src/%a/%b.
6363// CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128
6464llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128 (
65- %arg0 : vector <4 xf32 >, %arg1 : vector <8 xbf16 >, %arg2 : vector <8 xbf16 >
65+ %src : vector <4 xf32 >, %a : vector <8 xbf16 >, %b : vector <8 xbf16 >
6666 ) -> vector <4 xf32 >
6767{
6868 // CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
69- %0 = " x86vector.avx512.intr.dpbf16ps.128" (%arg0 , %arg1 , %arg2 )
69+ %0 = " x86vector.avx512.intr.dpbf16ps.128" (%src , %a , %b )
7070 : (vector <4 xf32 >, vector <8 xbf16 >, vector <8 xbf16 >) -> vector <4 xf32 >
7171 llvm.return %0 : vector <4 xf32 >
7272}
7373
// Translation test for the 256-bit dpbf16ps intrinsic op. The function takes
// one 8 x f32 vector and two 16 x bf16 vectors, passes all three to the
// x86vector.avx512.intr.dpbf16ps.256 op, and returns the 8 x f32 result.
// The FileCheck directives below expect the emitted LLVM IR to contain the
// matching function definition and a call to
// @llvm.x86.avx512bf16.dpbf16ps.256.
// This hunk only renames the operands: %arg0/%arg1/%arg2 become %src/%a/%b.
7475// CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256
7575llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256 (
76- %arg0 : vector <8 xf32 >, %arg1 : vector <16 xbf16 >, %arg2 : vector <16 xbf16 >
76+ %src : vector <8 xf32 >, %a : vector <16 xbf16 >, %b : vector <16 xbf16 >
7777 ) -> vector <8 xf32 >
7878{
7979 // CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
80- %0 = " x86vector.avx512.intr.dpbf16ps.256" (%arg0 , %arg1 , %arg2 )
80+ %0 = " x86vector.avx512.intr.dpbf16ps.256" (%src , %a , %b )
8181 : (vector <8 xf32 >, vector <16 xbf16 >, vector <16 xbf16 >) -> vector <8 xf32 >
8282 llvm.return %0 : vector <8 xf32 >
8383}
8484
// Translation test for the 512-bit dpbf16ps intrinsic op. The function takes
// one 16 x f32 vector and two 32 x bf16 vectors, passes all three to the
// x86vector.avx512.intr.dpbf16ps.512 op, and returns the 16 x f32 result.
// The FileCheck directives below expect the emitted LLVM IR to contain the
// matching function definition and a call to
// @llvm.x86.avx512bf16.dpbf16ps.512.
// This hunk only renames the operands: %arg0/%arg1/%arg2 become %src/%a/%b.
8585// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512
8686llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512 (
87- %arg0 : vector <16 xf32 >, %arg1 : vector <32 xbf16 >, %arg2 : vector <32 xbf16 >
87+ %src : vector <16 xf32 >, %a : vector <32 xbf16 >, %b : vector <32 xbf16 >
8888 ) -> vector <16 xf32 >
8989{
9090 // CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
91- %0 = " x86vector.avx512.intr.dpbf16ps.512" (%arg0 , %arg1 , %arg2 )
91+ %0 = " x86vector.avx512.intr.dpbf16ps.512" (%src , %a , %b )
9292 : (vector <16 xf32 >, vector <32 xbf16 >, vector <32 xbf16 >) -> vector <16 xf32 >
9393 llvm.return %0 : vector <16 xf32 >
9494}
9595
// Newly added translation test for the 256-bit cvtneps2bf16 intrinsic op.
// The function takes a single 8 x f32 vector, passes it to the
// x86vector.avx512.intr.cvtneps2bf16.256 op, and returns the 8 x bf16 result.
// The FileCheck directives below expect the emitted LLVM IR to define a
// function returning <8 x bfloat> and to call
// @llvm.x86.avx512bf16.cvtneps2bf16.256.
96+ // CHECK-LABEL: define <8 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_256
97+ llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_256 (
98+ %a: vector <8 xf32 >) -> vector <8 xbf16 >
99+ {
100+ // CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
101+ %0 = " x86vector.avx512.intr.cvtneps2bf16.256" (%a )
102+ : (vector <8 xf32 >) -> vector <8 xbf16 >
103+ llvm.return %0 : vector <8 xbf16 >
104+ }
105+
// Newly added translation test for the 512-bit cvtneps2bf16 intrinsic op.
// The function takes a single 16 x f32 vector, passes it to the
// x86vector.avx512.intr.cvtneps2bf16.512 op, and returns the 16 x bf16 result.
// The FileCheck directives below expect the emitted LLVM IR to define a
// function returning <16 x bfloat> and to call
// @llvm.x86.avx512bf16.cvtneps2bf16.512.
106+ // CHECK-LABEL: define <16 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_512
107+ llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_512 (
108+ %a: vector <16 xf32 >) -> vector <16 xbf16 >
109+ {
110+ // CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
111+ %0 = " x86vector.avx512.intr.cvtneps2bf16.512" (%a )
112+ : (vector <16 xf32 >) -> vector <16 xbf16 >
113+ llvm.return %0 : vector <16 xbf16 >
114+ }
115+
96116// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
97117llvm.func @LLVM_x86_avx_rsqrt_ps_256 (%a: vector <8 xf32 >) -> vector <8 xf32 >
98118{
@@ -103,11 +123,11 @@ llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
103123
// Translation test for the 256-bit AVX dp.ps intrinsic op. The function takes
// two 8 x f32 vectors, materializes an i8 constant with value -1, passes all
// three to the x86vector.avx.intr.dp.ps.256 op, and returns the 8 x f32
// result. The FileCheck directives below expect the emitted LLVM IR to
// contain the matching function definition and a call to
// @llvm.x86.avx.dp.ps.256.
// This hunk only renames values: %arg0/%arg1 become %a/%b and the constant
// %0 becomes %c.
104124// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256
105125llvm.func @LLVM_x86_avx_dp_ps_256 (
106- %arg0 : vector <8 xf32 >, %arg1 : vector <8 xf32 >
126+ %a : vector <8 xf32 >, %b : vector <8 xf32 >
107127 ) -> vector <8 xf32 >
108128{
109129 // CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
110- %0 = llvm.mlir.constant (-1 : i8 ) : i8
111- %1 = " x86vector.avx.intr.dp.ps.256" (%arg0 , %arg1 , %0 ) : (vector <8 xf32 >, vector <8 xf32 >, i8 ) -> vector <8 xf32 >
130+ %c = llvm.mlir.constant (-1 : i8 ) : i8
131+ %1 = " x86vector.avx.intr.dp.ps.256" (%a , %b , %c ) : (vector <8 xf32 >, vector <8 xf32 >, i8 ) -> vector <8 xf32 >
112132 llvm.return %1 : vector <8 xf32 >
113133}
0 commit comments