Skip to content

Commit 4b2d023

Browse files
committed
fix tests
1 parent bc26b45 commit 4b2d023

File tree

1 file changed

+78
-21
lines changed

1 file changed

+78
-21
lines changed

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 78 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -395,41 +395,98 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
395395
(vector<8xf16>, vector<8xf16>, vector<16xf32>,
396396
i32, i32, i32) -> vector<16xf32>
397397

398-
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
399-
%r34 = rocdl.smfmac.f32.16x16x32.f16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 :
398+
llvm.return %r0 : vector<32 x f32>
399+
}
400+
401+
llvm.func @rocdl.smfmac(%arg0 : i32,
402+
%arg1 : vector<4 x f16>,
403+
%arg2 : vector<8 x f16>,
404+
%arg3 : vector<4 x f32>,
405+
%arg4 : vector<16 x f32>,
406+
%arg5 : vector<4 x i16>,
407+
%arg6 : vector<8 x i16>,
408+
%arg7 : vector<2xi32>,
409+
%arg8 : vector<4xi32>,
410+
%arg9 : vector<16xi32>) -> vector<4 x f32> {
411+
%csti32 = llvm.mlir.constant(42 : i32) : i32
412+
413+
// CHECK-LABEL: rocdl.smfmac
414+
415+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
416+
%r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %csti32, %csti32, %csti32 :
400417
(vector<4xf16>, vector<8xf16>, vector<4xf32>,
401-
i32, i32, i32) -> vector<16xf32>
418+
i32, i32, i32) -> vector<4xf32>
402419

403-
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %a, <8 x half> %b, <16 x float> %c, i32 %idx, i32 0, i32 0)
404-
%r35 = rocdl.smfmac.f32.32x32x16.f16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 :
420+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
421+
%r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %csti32, %csti32, %csti32 :
405422
(vector<4xf16>, vector<8xf16>, vector<16xf32>,
406423
i32, i32, i32) -> vector<16xf32>
407424

408-
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %a, <8 x i16> %b, <4 x float> %c, i32 %idx, i32 0, i32 0)
409-
%r36 = rocdl.smfmac.f32.16x16x32.bf16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 :
410-
(vector<4xi16>, vector<8xi16>, vector<4xi16>,
425+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
426+
%r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %csti32, %csti32, %csti32 :
427+
(vector<4xi16>, vector<8xi16>, vector<4xf32>,
411428
i32, i32, i32) -> vector<4xf32>
412429

413-
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %a, <8 x i16> %b, <16 x float> %c, i32 %idx, i32 0, i32 0)
414-
%r37 = rocdl.smfmac.f32.16x16x32.bf16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 :
430+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
431+
%r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %csti32, %csti32, %csti32 :
415432
(vector<4xi16>, vector<8xi16>, vector<16xf32>,
416433
i32, i32, i32) -> vector<16xf32>
417434

435+
// CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x64.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
436+
%r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %csti32, %csti32, %csti32 :
437+
(vector<2xi32>, vector<4xi32>, vector<4xi32>,
438+
i32, i32, i32) -> vector<4xi32>
418439

419-
//def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x64.i8">;
420-
//def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x32.i8">;
421-
//def ROCDL_smfmac_f32_16x16x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.bf8">;
422-
//def ROCDL_smfmac_f32_16x16x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.fp8">;
423-
//def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.bf8">;
424-
//def ROCDL_smfmac_f32_16x16x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.fp8">;
425-
//def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.bf8">;
426-
//def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">;
427-
//def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">;
428-
//def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">;
440+
// CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x32.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
441+
%r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %csti32, %csti32, %csti32 :
442+
(vector<2xi32>, vector<4xi32>, vector<16xi32>,
443+
i32, i32, i32) -> vector<16xi32>
429444

430-
llvm.return %r0 : vector<32 x f32>
445+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
446+
%r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
447+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
448+
i32, i32, i32) -> vector<4xf32>
449+
450+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
451+
%r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
452+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
453+
i32, i32, i32) -> vector<4xf32>
454+
455+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
456+
%r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
457+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
458+
i32, i32, i32) -> vector<4xf32>
459+
460+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
461+
%r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
462+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
463+
i32, i32, i32) -> vector<4xf32>
464+
465+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
466+
%r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
467+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
468+
i32, i32, i32) -> vector<16xf32>
469+
470+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
471+
%r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
472+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
473+
i32, i32, i32) -> vector<16xf32>
474+
475+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
476+
%r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
477+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
478+
i32, i32, i32) -> vector<16xf32>
479+
480+
481+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
482+
%r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
483+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
484+
i32, i32, i32) -> vector<16xf32>
485+
486+
llvm.return %r0 : vector<4 x f32>
431487
}
432488

489+
433490
llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
434491
%arg1 : vector<16 x f32>, %arg2 : vector<8xi32>,
435492
%arg3 : vector<6xi32>, %arg4 : vector<4xi32>) -> vector<16 x f32> {

0 commit comments

Comments (0)