|
| 1 | +// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand_min_int 0 -rand_max_int 1024 -rand_type_int_for_inputs=3 -rand 1 -rand_type float -fut mlir_attention_wrapper -RMS_threshold 0.01 --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s |
| 2 | +// CHECK: [1 1 1] |
| 3 | + |
| 4 | +module { |
| 5 | + func.func @mlir_attention(%arg0: !migraphx.shaped<1x1x1xsi32, 1x1x1>, %arg1: !migraphx.shaped<1x96x1x128xf16, 12288x128x128x1>, %arg2: !migraphx.shaped<1x32x256x128xf16, 1048576x32768x128x1>, %arg3: !migraphx.shaped<1x32x256x128xf16, 1048576x32768x128x1>) -> !migraphx.shaped<1x1x4096xf16, 4096x4096x1> { |
| 6 | + %0 = migraphx.literal(dense<"0x000000000100000002000000030000000400000005000000060000000700000008000000090000000A0000000B0000000C0000000D0000000E0000000F000000100000001100000012000000130000001400000015000000160000001700000018000000190000001A0000001B0000001C0000001D0000001E0000001F000000200000002100000022000000230000002400000025000000260000002700000028000000290000002A0000002B0000002C0000002D0000002E0000002F000000300000003100000032000000330000003400000035000000360000003700000038000000390000003A0000003B0000003C0000003D0000003E0000003F000000400000004100000042000000430000004400000045000000460000004700000048000000490000004A0000004B0000004C0000004D0000004E0000004F000000500000005100000052000000530000005400000055000000560000005700000058000000590000005A0000005B0000005C0000005D0000005E0000005F000000600000006100000062000000630000006400000065000000660000006700000068000000690000006A0000006B0000006C0000006D0000006E0000006F000000700000007100000072000000730000007400000075000000760000007700000078000000790000007A0000007B0000007C0000007D0000007E0000007F000000800000008100000082000000830000008400000085000000860000008700000088000000890000008A0000008B0000008C0000008D0000008E0000008F000000900000009100000092000000930000009400000095000000960000009700000098000000990000009A0000009B0000009C0000009D0000009E0000009F000000A0000000A1000000A2000000A3000000A4000000A5000000A6000000A7000000A8000000A9000000AA000000AB000000AC000000AD000000AE000000AF000000B0000000B1000000B2000000B3000000B4000000B5000000B6000000B7000000B8000000B9000000BA000000BB000000BC000000BD000000BE000000BF000000C0000000C1000000C2000000C3000000C4000000C5000000C6000000C7000000C8000000C9000000CA000000CB000000CC000000CD000000CE000000CF000000D0000000D1000000D2000000D3000000D4000000D5000000D6000000D7000000D8000000D9000000DA000000DB000000DC000000DD000000DE000000DF000000E0000000E1000000E2000000E3000000E4000000E5000000E6000000E7000000E8000000E9000000EA000000EB000000EC000000ED000000EE000000EF000000F0000000F1000000F2000000F3000000F4000000F5000000F6000000F7000000F8000000F9000000FA000000FB000000FC000000FD000000FE000000FF000000"> : tensor<256xsi32>) : <256xsi32, 1> |
| 7 | + %1 = migraphx.literal(dense<0xFC00> : tensor<1xf16>) : <1xf16, 1> |
| 8 | + %2 = migraphx.literal(dense<8.837890e-02> : tensor<1xf16>) : <1xf16, 1> |
| 9 | + %3 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 1, 1, 256]} : <256xsi32, 1> -> <1x1x1x256xsi32, 0x0x0x1> |
| 10 | + %4 = migraphx.multibroadcast %2 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1xf16, 1> -> <1x32x1x256xf16, 0x0x0x0> |
| 11 | + %5 = migraphx.broadcast %arg0 {axis = 0 : i64, out_lens = [1, 1, 1, 256]} : <1x1x1xsi32, 1x1x1> -> <1x1x1x256xsi32, 1x1x1x0> |
| 12 | + %6 = migraphx.greater %3, %5 : <1x1x1x256xsi32, 0x0x0x1>, <1x1x1x256xsi32, 1x1x1x0> -> <1x1x1x256xsi32, 0x0x0x1> |
| 13 | + %7 = migraphx.convert %6 {target_type = 0 : i64} : <1x1x1x256xsi32, 0x0x0x1> to <1x1x1x256xsi8, 0x0x0x1> |
| 14 | + %8 = migraphx.multibroadcast %7 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x1x1x256xsi8, 0x0x0x1> -> <1x32x1x256xsi8, 0x0x0x1> |
| 15 | + %9 = migraphx.slice %arg1 {axes = [1], ends = [32], starts = [0]} : <1x96x1x128xf16, 12288x128x128x1> -> <1x32x1x128xf16, 12288x128x128x1> |
| 16 | + %10 = migraphx.transpose %arg2 {permutation = [0, 1, 3, 2]} : <1x32x256x128xf16, 1048576x32768x128x1> -> <1x32x128x256xf16, 1048576x32768x1x128> |
| 17 | + %11 = migraphx.dot %9, %10 : <1x32x1x128xf16, 12288x128x128x1>, <1x32x128x256xf16, 1048576x32768x1x128> -> <1x32x1x256xf16, 8192x256x256x1> |
| 18 | + %12 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1xf16, 1> -> <1x32x1x256xf16, 0x0x0x0> |
| 19 | + %13 = migraphx.mul %11, %4 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 0x0x0x0> -> <1x32x1x256xf16, 8192x256x256x1> |
| 20 | + %14 = migraphx.where %8, %12, %13 : <1x32x1x256xsi8, 0x0x0x1>, <1x32x1x256xf16, 0x0x0x0>, <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1> |
| 21 | + %15 = migraphx.reshape %14 {dims = [1, 32, 1, 256]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1> |
| 22 | + %16 = migraphx.reduce_max %15 {axes = [3]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x1xf16, 32x1x1x1> |
| 23 | + %17 = migraphx.reshape %16 {dims = [1, 32, 1, 1]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x1xf16, 32x1x1x1> |
| 24 | + %18 = migraphx.multibroadcast %17 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x256xf16, 32x1x1x0> |
| 25 | + %19 = migraphx.sub %14, %18 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 32x1x1x0> -> <1x32x1x256xf16, 8192x256x256x1> |
| 26 | + %20 = migraphx.exp %19 : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1> |
| 27 | + %21 = migraphx.reshape %20 {dims = [1, 32, 1, 256]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1> |
| 28 | + %22 = migraphx.reduce_sum %21 {axes = [3]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x1xf16, 32x1x1x1> |
| 29 | + %23 = migraphx.reshape %22 {dims = [1, 32, 1, 1]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x1xf16, 32x1x1x1> |
| 30 | + %24 = migraphx.multibroadcast %23 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x256xf16, 32x1x1x0> |
| 31 | + %25 = migraphx.div %20, %24 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 32x1x1x0> -> <1x32x1x256xf16, 8192x256x256x1> |
| 32 | + %26 = migraphx.dot %25, %arg3 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x256x128xf16, 1048576x32768x128x1> -> <1x32x1x128xf16, 4096x128x128x1> |
| 33 | + %27 = migraphx.transpose %26 {permutation = [0, 2, 1, 3]} : <1x32x1x128xf16, 4096x128x128x1> -> <1x1x32x128xf16, 4096x128x128x1> |
| 34 | + %28 = migraphx.reshape %27 {dims = [1, 1, 4096]} : <1x1x32x128xf16, 4096x128x128x1> -> <1x1x4096xf16, 4096x4096x1> |
| 35 | + return %28 : !migraphx.shaped<1x1x4096xf16, 4096x4096x1> |
| 36 | + } |
| 37 | +} |
| 38 | + |
0 commit comments