@@ -258,6 +258,93 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
258258 llvm.return
259259}
260260
261+
262+ llvm.func @rocdl.smfmac (%arg0 : i32 ,
263+ %arg1 : vector <4 x f16 >,
264+ %arg2 : vector <8 x f16 >,
265+ %arg3 : vector <4 x f32 >,
266+ %arg4 : vector <16 x f32 >,
267+ %arg5 : vector <4 x i16 >,
268+ %arg6 : vector <8 x i16 >,
269+ %arg7 : vector <2 xi32 >,
270+ %arg8 : vector <4 xi32 >,
271+ %arg9 : vector <16 xi32 >) -> vector <4 x f32 > {
272+ %csti32 = llvm.mlir.constant (42 : i32 ) : i32
273+
274+ // CHECK-LABEL: rocdl.smfmac
275+ // CHECK: rocdl.smfmac.f32.16x16x32.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
276+ %r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1 , %arg2 , %arg3 , %csti32 , %csti32 , %csti32 :
277+ (vector <4 xf16 >, vector <8 xf16 >, vector <4 xf32 >,
278+ i32 , i32 , i32 ) -> vector <4 xf32 >
279+
280+ // CHECK: rocdl.smfmac.f32.32x32x16.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
281+ %r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1 , %arg2 , %arg4 , %csti32 , %csti32 , %csti32 :
282+ (vector <4 xf16 >, vector <8 xf16 >, vector <16 xf32 >,
283+ i32 , i32 , i32 ) -> vector <16 xf32 >
284+
285+ // CHECK: rocdl.smfmac.f32.16x16x32.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
286+ %r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5 , %arg6 , %arg3 , %csti32 , %csti32 , %csti32 :
287+ (vector <4 xi16 >, vector <8 xi16 >, vector <4 xf32 >,
288+ i32 , i32 , i32 ) -> vector <4 xf32 >
289+
290+ // CHECK: rocdl.smfmac.f32.32x32x16.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
291+ %r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5 , %arg6 , %arg4 , %csti32 , %csti32 , %csti32 :
292+ (vector <4 xi16 >, vector <8 xi16 >, vector <16 xf32 >,
293+ i32 , i32 , i32 ) -> vector <16 xf32 >
294+
295+ // CHECK: rocdl.smfmac.i32.16x16x64.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
296+ %r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7 , %arg8 , %arg8 , %csti32 , %csti32 , %csti32 :
297+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xi32 >,
298+ i32 , i32 , i32 ) -> vector <4 xi32 >
299+
300+ // CHECK: rocdl.smfmac.i32.32x32x32.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
301+ %r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7 , %arg8 , %arg9 , %csti32 , %csti32 , %csti32 :
302+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xi32 >,
303+ i32 , i32 , i32 ) -> vector <16 xi32 >
304+
305+ // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
306+ %r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7 , %arg8 , %arg3 , %csti32 , %csti32 , %csti32 :
307+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xf32 >,
308+ i32 , i32 , i32 ) -> vector <4 xf32 >
309+
310+ // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
311+ %r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7 , %arg8 , %arg3 , %csti32 , %csti32 , %csti32 :
312+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xf32 >,
313+ i32 , i32 , i32 ) -> vector <4 xf32 >
314+
315+ // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
316+ %r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7 , %arg8 , %arg3 , %csti32 , %csti32 , %csti32 :
317+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xf32 >,
318+ i32 , i32 , i32 ) -> vector <4 xf32 >
319+
320+ // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
321+ %r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7 , %arg8 , %arg3 , %csti32 , %csti32 , %csti32 :
322+ (vector <2 xi32 >, vector <4 xi32 >, vector <4 xf32 >,
323+ i32 , i32 , i32 ) -> vector <4 xf32 >
324+
325+ // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
326+ %r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7 , %arg8 , %arg4 , %csti32 , %csti32 , %csti32 :
327+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
328+ i32 , i32 , i32 ) -> vector <16 xf32 >
329+
330+ // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
331+ %r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7 , %arg8 , %arg4 , %csti32 , %csti32 , %csti32 :
332+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
333+ i32 , i32 , i32 ) -> vector <16 xf32 >
334+
335+ // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
336+ %r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7 , %arg8 , %arg4 , %csti32 , %csti32 , %csti32 :
337+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
338+ i32 , i32 , i32 ) -> vector <16 xf32 >
339+
340+ // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
341+ %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7 , %arg8 , %arg4 , %csti32 , %csti32 , %csti32 :
342+ (vector <2 xi32 >, vector <4 xi32 >, vector <16 xf32 >,
343+ i32 , i32 , i32 ) -> vector <16 xf32 >
344+
345+ llvm.return %r0 : vector <4 x f32 >
346+ }
347+
261348llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4 (%arg0 : i32 ,
262349 %arg1 : vector <16 x f32 >, %arg2 : vector <8 xi32 >,
263350 %arg3 : vector <6 xi32 >, %arg4 : vector <4 xi32 >) {
0 commit comments