// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_attention_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s

module {
  // CHECK: [1 1 1]
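  // End-to-end check of a small fused attention kernel: the clone harness builds a
  // host-side reference copy of mlir_attention, and FileCheck expects the verifier's
  // [1 1 1] output indicating the GPU result matches that reference.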
  func.func @mlir_attention(
      %arg0: !migraphx.shaped<1x1x12xf16, 12x12x1>,
      %arg1: !migraphx.shaped<1x2x4x2xf16, 16x8x2x1>,
      %arg2: !migraphx.shaped<1x2x4x2xf16, 16x8x2x1>,
      %arg3: !migraphx.shaped<1x1xsi32, 1x1>
  ) -> !migraphx.shaped<1x1x4xf16, 4x4x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
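    // Literals: the key positions [0, 1, 2, 3], -infinity (0xFC00 in f16) used to
    // mask out scores, and a 1.0 scale applied to the raw scores.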
    %0 = migraphx.literal(dense<[0, 1, 2, 3]> : tensor<4xsi32>) : <4xsi32, 1>
    %1 = migraphx.literal(dense<0xFC00> : tensor<1xf16>) : <1xf16, 1>
    %2 = migraphx.literal(dense<1.000000e+00> : tensor<1xf16>) : <1xf16, 1>
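    // Build the query from %arg0 (reshape into heads, transpose, slice), broadcast
    // the mask index from %arg3, and transpose %arg1 so the dot below computes Q * K^T.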
    %3 = migraphx.reshape %arg0 {dims = [1, 1, 6, 2]} : <1x1x12xf16, 12x12x1> -> <1x1x6x2xf16, 12x12x2x1>
    %4 = migraphx.transpose %3 {permutation = [0, 2, 1, 3]} : <1x1x6x2xf16, 12x12x2x1> -> <1x6x1x2xf16, 12x2x12x1>
    %5 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [1, 2]} : <1x1xsi32, 1x1> -> <1x2xsi32, 1x0>
    %6 = migraphx.slice %4 {axes = [1], ends = [2], starts = [0]} : <1x6x1x2xf16, 12x2x12x1> -> <1x2x1x2xf16, 12x2x12x1>
    %7 = migraphx.transpose %arg1 {permutation = [0, 1, 3, 2]} : <1x2x4x2xf16, 16x8x2x1> -> <1x2x2x4xf16, 16x8x1x2>
    %8 = migraphx.dot %6, %7 : <1x2x1x2xf16, 12x2x12x1>, <1x2x2x4xf16, 16x8x1x2> -> <1x2x1x4xf16, 8x4x4x1>
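    // Mask construction: positions greater than the index carried in %arg3 select
    // -infinity in migraphx.where, replacing those scores before the softmax.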
    %9 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 2, 1, 4]} : <4xsi32, 1> -> <1x2x1x4xsi32, 0x0x0x1>
    %10 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [1, 2, 1, 4]} : <1xf16, 1> -> <1x2x1x4xf16, 0x0x0x0>
    %11 = migraphx.multibroadcast %2 {out_dyn_dims = [], out_lens = [1, 2, 1, 4]} : <1xf16, 1> -> <1x2x1x4xf16, 0x0x0x0>
    %12 = migraphx.mul %8, %11 : <1x2x1x4xf16, 8x4x4x1>, <1x2x1x4xf16, 0x0x0x0> -> <1x2x1x4xf16, 8x4x4x1>
    %13 = migraphx.reshape %5 {dims = [1, 2, 1, 1]} : <1x2xsi32, 1x0> -> <1x2x1x1xsi32, 2x1x1x1>
    %14 = migraphx.multibroadcast %13 {out_dyn_dims = [], out_lens = [1, 2, 1, 4]} : <1x2x1x1xsi32, 2x1x1x1> -> <1x2x1x4xsi32, 2x1x1x0>
    %15 = migraphx.greater %9, %14 : <1x2x1x4xsi32, 0x0x0x1>, <1x2x1x4xsi32, 2x1x1x0> -> <1x2x1x4xsi32, 8x4x4x1>
    %16 = migraphx.convert %15 {target_type = 0 : i64} : <1x2x1x4xsi32, 8x4x4x1> to <1x2x1x4xsi8, 8x4x4x1>
    %17 = migraphx.where %16, %10, %12 : <1x2x1x4xsi8, 8x4x4x1>, <1x2x1x4xf16, 0x0x0x0>, <1x2x1x4xf16, 8x4x4x1> -> <1x2x1x4xf16, 8x4x4x1>
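    // Softmax over the key dimension, weighted sum against the values in %arg2,
    // then restore the 1x1x4 output layout.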
    %18 = migraphx.softmax %17 {axis = 3 : i64} : <1x2x1x4xf16, 8x4x4x1> -> <1x2x1x4xf16, 8x4x4x1>
    %19 = migraphx.dot %18, %arg2 : <1x2x1x4xf16, 8x4x4x1>, <1x2x4x2xf16, 16x8x2x1> -> <1x2x1x2xf16, 4x2x2x1>
    %20 = migraphx.transpose %19 {permutation = [0, 2, 1, 3]} : <1x2x1x2xf16, 4x2x2x1> -> <1x1x2x2xf16, 4x2x2x1>
    %21 = migraphx.reshape %20 {dims = [1, 1, 4]} : <1x1x2x2xf16, 4x2x2x1> -> <1x1x4xf16, 4x4x1>
    return %21 : !migraphx.shaped<1x1x4xf16, 4x4x1>
  }
}