@@ -219,7 +219,9 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
219219 %arg4 : vector <16 x f32 >, %arg5 : vector <4 xf32 >,
220220 %arg6 : vector <4 xf16 >, %arg7 : vector <32 x i32 >,
221221 %arg8 : vector <16 x i32 >, %arg9 : vector <4 xi32 >,
222- %arg10 : vector <2 xi16 >, %arg11 : i64 ) -> vector <32 x f32 > {
222+ %arg10 : vector <2 xi16 >, %arg11 : i64 ,
223+ %arg12 : vector <8 xbf16 >, %arg13 : vector <4 xi32 >,
224+ %arg14 : vector <8 xf16 >) -> vector <32 x f32 > {
223225 %csti32 = llvm.mlir.constant (42 : i32 ) : i32
224226
225227 // CHECK-LABEL: rocdl.xdlops
@@ -362,6 +364,37 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
362364 %r27 = rocdl.mfma.f32.32x32x16.bf8.bf8 %arg11 , %arg11 , %arg4 , %csti32 , %csti32 , %csti32 :
363365 (i64 , i64 , vector <16 xf32 >,
364366 i32 , i32 , i32 ) -> vector <16 xf32 >
367+
368+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.bf16(<8 x bfloat> %{{.*}}, <8 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
369+ %r28 = rocdl.mfma.f32.16x16x32.bf16 %arg12 , %arg12 , %arg5 , %csti32 , %csti32 , %csti32 :
370+ (vector <8 xbf16 >, vector <8 xbf16 >, vector <4 xf32 >,
371+ i32 , i32 , i32 ) -> vector <4 xf32 >
372+
373+ // CHECK: call <4 x i32> @llvm.amdgcn.mfma.i32.16x16x64.i8(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
374+ %r29 = rocdl.mfma.i32.16x16x64.i8 %arg9 , %arg9 , %arg9 , %csti32 , %csti32 , %csti32 :
375+ (vector <4 xi32 >, vector <4 xi32 >, vector <4 xi32 >,
376+ i32 , i32 , i32 ) -> vector <4 xi32 >
377+
378+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.f16(<8 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
379+ %r30 = rocdl.mfma.f32.16x16x32.f16 %arg14 , %arg14 , %arg5 , %csti32 , %csti32 , %csti32 :
380+ (vector <8 xf16 >, vector <8 xf16 >, vector <4 xf32 >,
381+ i32 , i32 , i32 ) -> vector <4 xf32 >
382+
383+ // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf16(<8 x bfloat> %{{.*}}, <8 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
384+ %r31 = rocdl.mfma.f32.32x32x16.bf16 %arg12 , %arg12 , %arg4 , %csti32 , %csti32 , %csti32 :
385+ (vector <8 xbf16 >, vector <8 xbf16 >, vector <16 xf32 >,
386+ i32 , i32 , i32 ) -> vector <16 xf32 >
387+
388+ // CHECK: call <16 x i32> @llvm.amdgcn.mfma.i32.32x32x32.i8(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
389+ %r32 = rocdl.mfma.i32.32x32x32.i8 %arg9 , %arg9 , %arg8 , %csti32 , %csti32 , %csti32 :
390+ (vector <4 xi32 >, vector <4 xi32 >, vector <16 xi32 >,
391+ i32 , i32 , i32 ) -> vector <16 xi32 >
392+
393+ // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.f16(<8 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
394+ %r33 = rocdl.mfma.f32.32x32x16.f16 %arg14 , %arg14 , %arg4 , %csti32 , %csti32 , %csti32 :
395+ (vector <8 xf16 >, vector <8 xf16 >, vector <16 xf32 >,
396+ i32 , i32 , i32 ) -> vector <16 xf32 >
397+
365398 llvm.return %r0 : vector <32 x f32 >
366399}
367400
0 commit comments