@@ -395,41 +395,98 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
395395 (vector <8 xf16 >, vector <8 xf16 >, vector <16 xf32 >,
396396 i32 , i32 , i32 ) -> vector <16 xf32 >
397397
398- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
399- %r34 = rocdl.smfmac.f32.16x16x32.f16 %arg14 , %arg14 , %arg5 , %csti32 , %csti32 , %csti32 :
398+ llvm.return %r0 : vector <32 x f32 >
399+ }
400+
401+ llvm.func @rocdl.smfmac (%arg0 : i32 ,
402+ %arg1 : vector <4 x f16 >,
403+ %arg2 : vector <8 x f16 >,
404+ %arg3 : vector <4 x f32 >,
405+ %arg4 : vector <16 x f32 >,
406+ %arg5 : vector <4 x i16 >,
407+ %arg6 : vector <8 x i16 >,
408+ %arg7 : vector <2 xi32 >,
409+ %arg8 : vector <4 xi32 >,
410+ %arg9 : vector <16 xi32 >) -> vector <4 x f32 > {
411+ %csti32 = llvm.mlir.constant (42 : i32 ) : i32
412+
413+ // CHECK-LABEL: rocdl.smfmac
414+
415+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
416+ %r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1 , %arg2 , %arg3 , %csti32 , %csti32 , %csti32 :
400417 (vector <4 xf16 >, vector <8 xf16 >, vector <4 xf32 >,
401- i32 , i32 , i32 ) -> vector <16 x f32 >
418+ i32 , i32 , i32 ) -> vector <4 x f32 >
402419
403- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %a , <8 x half> %b , <16 x float> %c , i32 %idx , i32 0 , i32 0 )
404- %r35 = rocdl.smfmac.f32.32x32x16.f16 %arg14 , %arg14 , %arg5 , %csti32 , %csti32 , %csti32 :
420+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %{{.*}} , <8 x half> %{{.*}} , <16 x float> %{{.*}} , i32 42 , i32 42 , i32 42 )
421+ %r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1 , %arg2 , %arg4 , %csti32 , %csti32 , %csti32 :
405422 (vector <4 xf16 >, vector <8 xf16 >, vector <16 xf32 >,
406423 i32 , i32 , i32 ) -> vector <16 xf32 >
407424
408- // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %a , <8 x i16> %b , <4 x float> %c , i32 %idx , i32 0 , i32 0 )
409- %r36 = rocdl.smfmac.f32.16x16x32.bf16 %arg14 , %arg14 , %arg5 , %csti32 , %csti32 , %csti32 :
410- (vector <4 xi16 >, vector <8 xi16 >, vector <4 x i16 >,
425+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %{{.*}} , <8 x i16> %{{.*}} , <4 x float> %{{.*}} , i32 42 , i32 42 , i32 42 )
426+ %r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5 , %arg6 , %arg3 , %csti32 , %csti32 , %csti32 :
427+ (vector <4 xi16 >, vector <8 xi16 >, vector <4 x f32 >,
411428 i32 , i32 , i32 ) -> vector <4 xf32 >
412429
413- // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %a , <8 x i16> %b , <16 x float> %c , i32 %idx , i32 0 , i32 0 )
414- %r37 = rocdl.smfmac.f32.16x16x32 .bf16 %arg14 , %arg14 , %arg5 , %csti32 , %csti32 , %csti32 :
430+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %{{.*}} , <8 x i16> %{{.*}} , <16 x float> %{{.*}} , i32 42 , i32 42 , i32 42 )
431+ %r3 = rocdl.smfmac.f32.32x32x16 .bf16 %arg5 , %arg6 , %arg4 , %csti32 , %csti32 , %csti32 :
415432 (vector <4 xi16 >, vector <8 xi16 >, vector <16 xf32 >,
416433 i32 , i32 , i32 ) -> vector <16 xf32 >
417434
435+ // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x64.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
436+ %r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7 , %arg8 , %arg8 , %csti32 , %csti32 , %csti32 :
437+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xi32 >,
438+ i32 , i32 , i32 ) -> vector <4 xi32 >
418439
419- //def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x64.i8">;
420- //def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x32.i8">;
421- //def ROCDL_smfmac_f32_16x16x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.bf8">;
422- //def ROCDL_smfmac_f32_16x16x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.fp8">;
423- //def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.bf8">;
424- //def ROCDL_smfmac_f32_16x16x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.fp8">;
425- //def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.bf8">;
426- //def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">;
427- //def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">;
428- //def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">;
440+ // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x32.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
441+ %r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7 , %arg8 , %arg9 , %csti32 , %csti32 , %csti32 :
442+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xi32 >,
443+ i32 , i32 , i32 ) -> vector <16 xi32 >
429444
430- llvm.return %r0 : vector <32 x f32 >
445+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
446+ %r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7 , %arg8 , %arg3 , %csti32 , %csti32 , %csti32 :
447+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xf32 >,
448+ i32 , i32 , i32 ) -> vector <4 xf32 >
449+
450+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
451+ %r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7 , %arg8 , %arg3 , %csti32 , %csti32 , %csti32 :
452+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xf32 >,
453+ i32 , i32 , i32 ) -> vector <4 xf32 >
454+
455+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
456+ %r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7 , %arg8 , %arg3 , %csti32 , %csti32 , %csti32 :
457+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xf32 >,
458+ i32 , i32 , i32 ) -> vector <4 xf32 >
459+
460+ // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
461+ %r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7 , %arg8 , %arg3 , %csti32 , %csti32 , %csti32 :
462+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xf32 >,
463+ i32 , i32 , i32 ) -> vector <4 xf32 >
464+
465+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
466+ %r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7 , %arg8 , %arg4 , %csti32 , %csti32 , %csti32 :
467+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
468+ i32 , i32 , i32 ) -> vector <16 xf32 >
469+
470+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
471+ %r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7 , %arg8 , %arg4 , %csti32 , %csti32 , %csti32 :
472+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
473+ i32 , i32 , i32 ) -> vector <16 xf32 >
474+
475+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
476+ %r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7 , %arg8 , %arg4 , %csti32 , %csti32 , %csti32 :
477+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
478+ i32 , i32 , i32 ) -> vector <16 xf32 >
479+
480+
481+ // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
482+ %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7 , %arg8 , %arg4 , %csti32 , %csti32 , %csti32 :
483+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
484+ i32 , i32 , i32 ) -> vector <16 xf32 >
485+
486+ llvm.return %r0 : vector <4 x f32 >
431487}
432488
489+
433490llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4 (%arg0 : i32 ,
434491 %arg1 : vector <16 x f32 >, %arg2 : vector <8 xi32 >,
435492 %arg3 : vector <6 xi32 >, %arg4 : vector <4 xi32 >) -> vector <16 x f32 > {
0 commit comments