Skip to content

Commit 55327c3

Browse files
committed
more mlir tests
1 parent 4b2d023 commit 55327c3

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,93 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
258258
llvm.return
259259
}
260260

261+
262+
llvm.func @rocdl.smfmac(%arg0 : i32,
263+
%arg1 : vector<4 x f16>,
264+
%arg2 : vector<8 x f16>,
265+
%arg3 : vector<4 x f32>,
266+
%arg4 : vector<16 x f32>,
267+
%arg5 : vector<4 x i16>,
268+
%arg6 : vector<8 x i16>,
269+
%arg7 : vector<2xi32>,
270+
%arg8 : vector<4xi32>,
271+
%arg9 : vector<16xi32>) -> vector<4 x f32> {
272+
%csti32 = llvm.mlir.constant(42 : i32) : i32
273+
274+
// CHECK-LABEL: rocdl.smfmac
275+
// CHECK: rocdl.smfmac.f32.16x16x32.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
276+
%r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %csti32, %csti32, %csti32 :
277+
(vector<4xf16>, vector<8xf16>, vector<4xf32>,
278+
i32, i32, i32) -> vector<4xf32>
279+
280+
// CHECK: rocdl.smfmac.f32.32x32x16.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
281+
%r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %csti32, %csti32, %csti32 :
282+
(vector<4xf16>, vector<8xf16>, vector<16xf32>,
283+
i32, i32, i32) -> vector<16xf32>
284+
285+
// CHECK: rocdl.smfmac.f32.16x16x32.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
286+
%r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %csti32, %csti32, %csti32 :
287+
(vector<4xi16>, vector<8xi16>, vector<4xf32>,
288+
i32, i32, i32) -> vector<4xf32>
289+
290+
// CHECK: rocdl.smfmac.f32.32x32x16.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
291+
%r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %csti32, %csti32, %csti32 :
292+
(vector<4xi16>, vector<8xi16>, vector<16xf32>,
293+
i32, i32, i32) -> vector<16xf32>
294+
295+
// CHECK: rocdl.smfmac.i32.16x16x64.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
296+
%r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %csti32, %csti32, %csti32 :
297+
(vector<2xi32>, vector<4xi32>, vector<4xi32>,
298+
i32, i32, i32) -> vector<4xi32>
299+
300+
// CHECK: rocdl.smfmac.i32.32x32x32.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
301+
%r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %csti32, %csti32, %csti32 :
302+
(vector<2xi32>, vector<4xi32>, vector<16xi32>,
303+
i32, i32, i32) -> vector<16xi32>
304+
305+
// CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
306+
%r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
307+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
308+
i32, i32, i32) -> vector<4xf32>
309+
310+
// CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
311+
%r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
312+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
313+
i32, i32, i32) -> vector<4xf32>
314+
315+
// CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
316+
%r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
317+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
318+
i32, i32, i32) -> vector<4xf32>
319+
320+
// CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
321+
%r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
322+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
323+
i32, i32, i32) -> vector<4xf32>
324+
325+
// CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
326+
%r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
327+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
328+
i32, i32, i32) -> vector<16xf32>
329+
330+
// CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
331+
%r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
332+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
333+
i32, i32, i32) -> vector<16xf32>
334+
335+
// CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
336+
%r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
337+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
338+
i32, i32, i32) -> vector<16xf32>
339+
340+
// CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
341+
%r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
342+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
343+
i32, i32, i32) -> vector<16xf32>
344+
345+
llvm.return %r0 : vector<4 x f32>
346+
}
347+
261348
llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
262349
%arg1 : vector<16 x f32>, %arg2 : vector<8xi32>,
263350
%arg3 : vector<6xi32>, %arg4 : vector<4xi32>) {

0 commit comments

Comments
 (0)