@@ -1013,6 +1013,144 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
10131013 llvm.return %r0 : vector <8 xf32 >
10141014}
10151015
1016+ llvm.func @rocdl.wmma.scale.f32.16x16x128.f8f6f4 (%arg0 : i32 ,
1017+ %arg1 : vector <4 x f32 >, %arg2 : vector <8 xi32 >,
1018+ %arg3 : vector <6 xi32 >, %arg4 : vector <4 xi32 >) -> vector <4 x f32 > {
1019+ %cst0 = llvm.mlir.constant (0 : i32 ) : i32
1020+ %cst1 = llvm.mlir.constant (1 : i32 ) : i32
1021+ %cst2 = llvm.mlir.constant (2 : i32 ) : i32
1022+ %cst3 = llvm.mlir.constant (3 : i32 ) : i32
1023+ %cst4 = llvm.mlir.constant (4 : i32 ) : i32
1024+
1025+ // CHECK-LABEL: rocdl.mfma.scale.f32.16x16x128.f8f6f4
1026+ // fp8 * fp8
1027+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1028+ %r00 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2 , %arg2 , %arg1 , %cst0 , %cst0 , %cst0 , %arg0 , %cst0 , %arg0 :
1029+ (vector <8 xi32 >, vector <8 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1030+
1031+ // fp8 * bf8
1032+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1033+ %r01 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2 , %arg2 , %arg1 , %cst0 , %cst1 , %cst0 , %arg0 , %cst0 , %arg0 :
1034+ (vector <8 xi32 >, vector <8 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1035+
1036+ // fp8 * fp6
1037+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1038+ %r02 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2 , %arg3 , %arg1 , %cst0 , %cst2 , %cst0 , %arg0 , %cst0 , %arg0 :
1039+ (vector <8 xi32 >, vector <6 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1040+
1041+ // fp8 * bf6
1042+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1043+ %r03 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2 , %arg3 , %arg1 , %cst0 , %cst3 , %cst0 , %arg0 , %cst0 , %arg0 :
1044+ (vector <8 xi32 >, vector <6 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1045+
1046+ // fp8 * fp4
1047+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 0, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1048+ %r04 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2 , %arg4 , %arg1 , %cst0 , %cst4 , %cst0 , %arg0 , %cst0 , %arg0 :
1049+ (vector <8 xi32 >, vector <4 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1050+
1051+ // bf8 * fp8
1052+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1053+ %r10 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2 , %arg2 , %arg1 , %cst1 , %cst0 , %cst0 , %arg0 , %cst0 , %arg0 :
1054+ (vector <8 xi32 >, vector <8 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1055+
1056+ // bf8 * bf8
1057+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1058+ %r11 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2 , %arg2 , %arg1 , %cst1 , %cst1 , %cst0 , %arg0 , %cst0 , %arg0 :
1059+ (vector <8 xi32 >, vector <8 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1060+
1061+ // bf8 * fp6
1062+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1063+ %r12 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2 , %arg3 , %arg1 , %cst1 , %cst2 , %cst0 , %arg0 , %cst0 , %arg0 :
1064+ (vector <8 xi32 >, vector <6 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1065+
1066+ // bf8 * bf6
1067+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v6i32(<8 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1068+ %r13 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2 , %arg3 , %arg1 , %cst1 , %cst3 , %cst0 , %arg0 , %cst0 , %arg0 :
1069+ (vector <8 xi32 >, vector <6 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1070+
1071+ // bf8 * fp4
1072+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 1, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1073+ %r14 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg2 , %arg4 , %arg1 , %cst1 , %cst4 , %cst0 , %arg0 , %cst0 , %arg0 :
1074+ (vector <8 xi32 >, vector <4 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1075+
1076+ // fp6 * fp8
1077+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1078+ %r20 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3 , %arg2 , %arg1 , %cst2 , %cst0 , %cst0 , %arg0 , %cst0 , %arg0 :
1079+ (vector <6 xi32 >, vector <8 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1080+
1081+ // fp6 * bf8
1082+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1083+ %r21 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3 , %arg2 , %arg1 , %cst2 , %cst1 , %cst0 , %arg0 , %cst0 , %arg0 :
1084+ (vector <6 xi32 >, vector <8 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1085+
1086+ // fp6 * fp6
1087+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1088+ %r22 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3 , %arg3 , %arg1 , %cst2 , %cst2 , %cst0 , %arg0 , %cst0 , %arg0 :
1089+ (vector <6 xi32 >, vector <6 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1090+
1091+ // fp6 * bf6
1092+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1093+ %r23 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3 , %arg3 , %arg1 , %cst2 , %cst3 , %cst0 , %arg0 , %cst0 , %arg0 :
1094+ (vector <6 xi32 >, vector <6 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1095+
1096+ // fp6 * fp4
1097+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v4i32(<6 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 2, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1098+ %r24 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3 , %arg4 , %arg1 , %cst2 , %cst4 , %cst0 , %arg0 , %cst0 , %arg0 :
1099+ (vector <6 xi32 >, vector <4 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1100+
1101+ // bf6 * fp8
1102+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1103+ %r30 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3 , %arg2 , %arg1 , %cst3 , %cst0 , %cst0 , %arg0 , %cst0 , %arg0 :
1104+ (vector <6 xi32 >, vector <8 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1105+
1106+ // bf6 * bf8
1107+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v8i32(<6 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1108+ %r31 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3 , %arg2 , %arg1 , %cst3 , %cst1 , %cst0 , %arg0 , %cst0 , %arg0 :
1109+ (vector <6 xi32 >, vector <8 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1110+
1111+ // bf6 * fp6
1112+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1113+ %r32 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3 , %arg3 , %arg1 , %cst3 , %cst2 , %cst0 , %arg0 , %cst0 , %arg0 :
1114+ (vector <6 xi32 >, vector <6 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1115+
1116+ // bf6 * bf6
1117+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1118+ %r33 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3 , %arg3 , %arg1 , %cst3 , %cst3 , %cst0 , %arg0 , %cst0 , %arg0 :
1119+ (vector <6 xi32 >, vector <6 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1120+
1121+ // bf6 * fp4
1122+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v4i32(<6 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 3, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1123+ %r34 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg3 , %arg4 , %arg1 , %cst3 , %cst4 , %cst0 , %arg0 , %cst0 , %arg0 :
1124+ (vector <6 xi32 >, vector <4 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1125+
1126+ // fp4 * fp8
1127+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v8i32(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1128+ %r40 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4 , %arg2 , %arg1 , %cst4 , %cst0 , %cst0 , %arg0 , %cst0 , %arg0 :
1129+ (vector <4 xi32 >, vector <8 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1130+
1131+ // fp4 * bf8
1132+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v8i32(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 1, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1133+ %r41 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4 , %arg2 , %arg1 , %cst4 , %cst1 , %cst0 , %arg0 , %cst0 , %arg0 :
1134+ (vector <4 xi32 >, vector <8 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1135+
1136+ // fp4 * fp6
1137+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v6i32(<4 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 2, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1138+ %r42 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4 , %arg3 , %arg1 , %cst4 , %cst2 , %cst0 , %arg0 , %cst0 , %arg0 :
1139+ (vector <4 xi32 >, vector <6 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1140+
1141+ // fp4 * bf6
1142+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v6i32(<4 x i32> %{{.*}}, <6 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 3, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}})
1143+ %r43 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4 , %arg3 , %arg1 , %cst4 , %cst3 , %cst0 , %arg0 , %cst0 , %arg0 :
1144+ (vector <4 xi32 >, vector <6 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1145+
1146+ // fp4 * fp4
1147+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 4, i32 4, i32 0, i32 %{{.*}}, i32 0, i32 %{{.*}}
1148+ %r44 = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %arg4 , %arg4 , %arg1 , %cst4 , %cst4 , %cst0 , %arg0 , %cst0 , %arg0 :
1149+ (vector <4 xi32 >, vector <4 xi32 >, vector <4 xf32 >, i32 , i32 , i32 , i32 , i32 , i32 ) -> vector <4 xf32 >
1150+
1151+ llvm.return %r00 : vector <4 x f32 >
1152+ }
1153+
10161154llvm.func @rocdl.ds.read.tr (%ptr : !llvm.ptr <3 >) -> vector <4 xf16 > {
10171155 // CHECK-LABEL: rocdl.ds.read.tr
10181156 // CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr4.b64.v2i32(ptr addrspace(3) %0)
0 commit comments