Skip to content

Commit c11b4e0

Browse files
committed
Add e2e test for failing MIGraphX CI case
1 parent fd77066 commit c11b4e0

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// This is a design that was in the MIGraphX CI that was previously failing
2+
// here: https://ontrack-internal.amd.com/browse/SWDEV-558297
3+
4+
// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_attention_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
5+
6+
module {
7+
// CHECK: [1 1 1]
8+
func.func @mlir_attention(
9+
%arg0: !migraphx.shaped<1x1x12xf16, 12x12x1>,
10+
%arg1: !migraphx.shaped<1x2x4x2xf16, 16x8x2x1>,
11+
%arg2: !migraphx.shaped<1x2x4x2xf16, 16x8x2x1>,
12+
%arg3: !migraphx.shaped<1x1xsi32, 1x1>
13+
) -> !migraphx.shaped<1x1x4xf16, 4x4x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
14+
%0 = migraphx.literal(dense<[0, 1, 2, 3]> : tensor<4xsi32>) : <4xsi32, 1>
15+
%1 = migraphx.literal(dense<0xFC00> : tensor<1xf16>) : <1xf16, 1>
16+
%2 = migraphx.literal(dense<1.000000e+00> : tensor<1xf16>) : <1xf16, 1>
17+
%3 = migraphx.reshape %arg0 {dims = [1, 1, 6, 2]} : <1x1x12xf16, 12x12x1> -> <1x1x6x2xf16, 12x12x2x1>
18+
%4 = migraphx.transpose %3 {permutation = [0, 2, 1, 3]} : <1x1x6x2xf16, 12x12x2x1> -> <1x6x1x2xf16, 12x2x12x1>
19+
%5 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [1, 2]} : <1x1xsi32, 1x1> -> <1x2xsi32, 1x0>
20+
%6 = migraphx.slice %4 {axes = [1], ends = [2], starts = [0]} : <1x6x1x2xf16, 12x2x12x1> -> <1x2x1x2xf16, 12x2x12x1>
21+
%7 = migraphx.transpose %arg1 {permutation = [0, 1, 3, 2]} : <1x2x4x2xf16, 16x8x2x1> -> <1x2x2x4xf16, 16x8x1x2>
22+
%8 = migraphx.dot %6, %7 : <1x2x1x2xf16, 12x2x12x1>, <1x2x2x4xf16, 16x8x1x2> -> <1x2x1x4xf16, 8x4x4x1>
23+
%9 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 2, 1, 4]} : <4xsi32, 1> -> <1x2x1x4xsi32, 0x0x0x1>
24+
%10 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [1, 2, 1, 4]} : <1xf16, 1> -> <1x2x1x4xf16, 0x0x0x0>
25+
%11 = migraphx.multibroadcast %2 {out_dyn_dims = [], out_lens = [1, 2, 1, 4]} : <1xf16, 1> -> <1x2x1x4xf16, 0x0x0x0>
26+
%12 = migraphx.mul %8, %11 : <1x2x1x4xf16, 8x4x4x1>, <1x2x1x4xf16, 0x0x0x0> -> <1x2x1x4xf16, 8x4x4x1>
27+
%13 = migraphx.reshape %5 {dims = [1, 2, 1, 1]} : <1x2xsi32, 1x0> -> <1x2x1x1xsi32, 2x1x1x1>
28+
%14 = migraphx.multibroadcast %13 {out_dyn_dims = [], out_lens = [1, 2, 1, 4]} : <1x2x1x1xsi32, 2x1x1x1> -> <1x2x1x4xsi32, 2x1x1x0>
29+
%15 = migraphx.greater %9, %14 : <1x2x1x4xsi32, 0x0x0x1>, <1x2x1x4xsi32, 2x1x1x0> -> <1x2x1x4xsi32, 8x4x4x1>
30+
%16 = migraphx.convert %15 {target_type = 0 : i64} : <1x2x1x4xsi32, 8x4x4x1> to <1x2x1x4xsi8, 8x4x4x1>
31+
%17 = migraphx.where %16, %10, %12 : <1x2x1x4xsi8, 8x4x4x1>, <1x2x1x4xf16, 0x0x0x0>, <1x2x1x4xf16, 8x4x4x1> -> <1x2x1x4xf16, 8x4x4x1>
32+
%18 = migraphx.softmax %17 {axis = 3 : i64} : <1x2x1x4xf16, 8x4x4x1> -> <1x2x1x4xf16, 8x4x4x1>
33+
%19 = migraphx.dot %18, %arg2 : <1x2x1x4xf16, 8x4x4x1>, <1x2x4x2xf16, 16x8x2x1> -> <1x2x1x2xf16, 4x2x2x1>
34+
%20 = migraphx.transpose %19 {permutation = [0, 2, 1, 3]} : <1x2x1x2xf16, 4x2x2x1> -> <1x1x2x2xf16, 4x2x2x1>
35+
%21 = migraphx.reshape %20 {dims = [1, 1, 4]} : <1x1x2x2xf16, 4x2x2x1> -> <1x1x4xf16, 4x4x1>
36+
return %21 : !migraphx.shaped<1x1x4xf16, 4x4x1>
37+
}
38+
}
39+

0 commit comments

Comments
 (0)