@@ -398,6 +398,95 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
398398 llvm.return %r0 : vector <32 x f32 >
399399}
400400
401+ llvm.func @rocdl.smfmac (%arg0 : i32 ,
402+ %arg1 : vector <4 x f16 >,
403+ %arg2 : vector <8 x f16 >,
404+ %arg3 : vector <4 x f32 >,
405+ %arg4 : vector <16 x f32 >,
406+ %arg5 : vector <4 x i16 >,
407+ %arg6 : vector <8 x i16 >,
408+ %arg7 : vector <2 xi32 >,
409+ %arg8 : vector <4 xi32 >,
410+ %arg9 : vector <16 xi32 >) -> vector <4 x f32 > {
411+ %csti32 = llvm.mlir.constant (42 : i32 ) : i32
412+
413+ // CHECK-LABEL: rocdl.smfmac
414+
415+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
416+ %r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1 , %arg2 , %arg3 , %csti32 , %csti32 , %csti32 :
417+ (vector <4 xf16 >, vector <8 xf16 >, vector <4 xf32 >,
418+ i32 , i32 , i32 ) -> vector <4 xf32 >
419+
420+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
421+ %r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1 , %arg2 , %arg4 , %csti32 , %csti32 , %csti32 :
422+ (vector <4 xf16 >, vector <8 xf16 >, vector <16 xf32 >,
423+ i32 , i32 , i32 ) -> vector <16 xf32 >
424+
425+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
426+ %r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5 , %arg6 , %arg3 , %csti32 , %csti32 , %csti32 :
427+ (vector <4 xi16 >, vector <8 xi16 >, vector <4 xf32 >,
428+ i32 , i32 , i32 ) -> vector <4 xf32 >
429+
430+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
431+ %r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5 , %arg6 , %arg4 , %csti32 , %csti32 , %csti32 :
432+ (vector <4 xi16 >, vector <8 xi16 >, vector <16 xf32 >,
433+ i32 , i32 , i32 ) -> vector <16 xf32 >
434+
435+ // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x64.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
436+ %r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7 , %arg8 , %arg8 , %csti32 , %csti32 , %csti32 :
437+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xi32 >,
438+ i32 , i32 , i32 ) -> vector <4 xi32 >
439+
440+ // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x32.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
441+ %r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7 , %arg8 , %arg9 , %csti32 , %csti32 , %csti32 :
442+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xi32 >,
443+ i32 , i32 , i32 ) -> vector <16 xi32 >
444+
445+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
446+ %r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7 , %arg8 , %arg3 , %csti32 , %csti32 , %csti32 :
447+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xf32 >,
448+ i32 , i32 , i32 ) -> vector <4 xf32 >
449+
450+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
451+ %r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7 , %arg8 , %arg3 , %csti32 , %csti32 , %csti32 :
452+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xf32 >,
453+ i32 , i32 , i32 ) -> vector <4 xf32 >
454+
455+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
456+ %r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7 , %arg8 , %arg3 , %csti32 , %csti32 , %csti32 :
457+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xf32 >,
458+ i32 , i32 , i32 ) -> vector <4 xf32 >
459+
460+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
461+ %r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7 , %arg8 , %arg3 , %csti32 , %csti32 , %csti32 :
462+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xf32 >,
463+ i32 , i32 , i32 ) -> vector <4 xf32 >
464+
465+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
466+ %r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7 , %arg8 , %arg4 , %csti32 , %csti32 , %csti32 :
467+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
468+ i32 , i32 , i32 ) -> vector <16 xf32 >
469+
470+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
471+ %r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7 , %arg8 , %arg4 , %csti32 , %csti32 , %csti32 :
472+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
473+ i32 , i32 , i32 ) -> vector <16 xf32 >
474+
475+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
476+ %r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7 , %arg8 , %arg4 , %csti32 , %csti32 , %csti32 :
477+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
478+ i32 , i32 , i32 ) -> vector <16 xf32 >
479+
480+
481+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
482+ %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7 , %arg8 , %arg4 , %csti32 , %csti32 , %csti32 :
483+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
484+ i32 , i32 , i32 ) -> vector <16 xf32 >
485+
486+ llvm.return %r0 : vector <4 x f32 >
487+ }
488+
489+
401490llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4 (%arg0 : i32 ,
402491 %arg1 : vector <16 x f32 >, %arg2 : vector <8 xi32 >,
403492 %arg3 : vector <6 xi32 >, %arg4 : vector <4 xi32 >) -> vector <16 x f32 > {
0 commit comments