Skip to content

Commit 58c991b

Browse files
authored
Update TosaToRock for new attention patterns (#1973)
* Add additional LIT tests * isMaskGeneric code * Partially working version * Fix errors and add LIT test * Make tests agnostic * Add newline * More changes for causal idenfication with nth no select * Add newlines * Remove TODO * Attend to Copilot comment * Attend to additional review comments * Add broadcast in instances where it doesn't exist * Remove additional whitespace * Add triangular check * Rebase fixes and update element types in MIGraphXToTosa * More rebase changes for switch to mul for broadcast * Clang-format * Update LIT test after rebase * Fix when we broadcast * Convert over some functions to using generic function * Clang-format * Fix comment * More review comments * Fix causal validation check * Use easier method for getting numHeads and batch * Update LIT test invocation * Attend to more review comments * More clang-format
1 parent 16fab03 commit 58c991b

File tree

8 files changed

+674
-138
lines changed

8 files changed

+674
-138
lines changed

mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ BroadcastConverter::matchAndRewrite(migraphx::BroadcastOp op, OpAdaptor adaptor,
584584
// because tosa does not have an explicit broadcast op
585585
auto oneTensor = rock::tosa::getOneTensor(rewriter, loc, outType);
586586
auto mulWithOne = rock::tosa::getMulOp(rewriter, loc, sameRankReshapedOp,
587-
oneTensor, elemType);
587+
oneTensor, newOutElementTy);
588588
rewriter.replaceOp(op, mulWithOne);
589589
return success();
590590
}

mlir/lib/Conversion/TosaToRock/TosaToRock.cpp

Lines changed: 363 additions & 137 deletions
Large diffs are not rendered by default.

mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-causal.mlir

Lines changed: 125 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand_min_int 0 -rand_max_int 1024 -rand_type_int_for_inputs=3 -rand 1 -rand_type float -fut mlir_attention_wrapper -RMS_threshold 0.01 --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
2+
// CHECK: [1 1 1]
3+
4+
module {
5+
func.func @mlir_attention(%arg0: !migraphx.shaped<1x1x1xsi32, 1x1x1>, %arg1: !migraphx.shaped<1x96x1x128xf16, 12288x128x128x1>, %arg2: !migraphx.shaped<1x32x256x128xf16, 1048576x32768x128x1>, %arg3: !migraphx.shaped<1x32x256x128xf16, 1048576x32768x128x1>) -> !migraphx.shaped<1x1x4096xf16, 4096x4096x1> {
6+
%0 = migraphx.literal(dense<"0x000000000100000002000000030000000400000005000000060000000700000008000000090000000A0000000B0000000C0000000D0000000E0000000F000000100000001100000012000000130000001400000015000000160000001700000018000000190000001A0000001B0000001C0000001D0000001E0000001F000000200000002100000022000000230000002400000025000000260000002700000028000000290000002A0000002B0000002C0000002D0000002E0000002F000000300000003100000032000000330000003400000035000000360000003700000038000000390000003A0000003B0000003C0000003D0000003E0000003F000000400000004100000042000000430000004400000045000000460000004700000048000000490000004A0000004B0000004C0000004D0000004E0000004F000000500000005100000052000000530000005400000055000000560000005700000058000000590000005A0000005B0000005C0000005D0000005E0000005F000000600000006100000062000000630000006400000065000000660000006700000068000000690000006A0000006B0000006C0000006D0000006E0000006F000000700000007100000072000000730000007400000075000000760000007700000078000000790000007A0000007B0000007C0000007D0000007E0000007F000000800000008100000082000000830000008400000085000000860000008700000088000000890000008A0000008B0000008C0000008D0000008E0000008F000000900000009100000092000000930000009400000095000000960000009700000098000000990000009A0000009B0000009C0000009D0000009E0000009F000000A0000000A1000000A2000000A3000000A4000000A5000000A6000000A7000000A8000000A9000000AA000000AB000000AC000000AD000000AE000000AF000000B0000000B1000000B2000000B3000000B4000000B5000000B6000000B7000000B8000000B9000000BA000000BB000000BC000000BD000000BE000000BF000000C0000000C1000000C2000000C3000000C4000000C5000000C6000000C7000000C8000000C9000000CA000000CB000000CC000000CD000000CE000000CF000000D0000000D1000000D2000000D3000000D4000000D5000000D6000000D7000000D8000000D9000000DA000000DB000000DC000000DD000000DE000000DF000000E0000000E1000000E2000000E3000000E4000000E5000000E6000000E7000000E8000000E9000000EA000000EB000000EC000000ED000000EE000000EF000000F0000000F1000000F2000000F3000000F4000000F5000000F6000000F7000000F8000000F9000000FA000000FB000000FC000000FD000000FE000000FF000000"> : tensor<256xsi32>) : <256xsi32, 1>
7+
%1 = migraphx.literal(dense<0xFC00> : tensor<1xf16>) : <1xf16, 1>
8+
%2 = migraphx.literal(dense<8.837890e-02> : tensor<1xf16>) : <1xf16, 1>
9+
%3 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 1, 1, 256]} : <256xsi32, 1> -> <1x1x1x256xsi32, 0x0x0x1>
10+
%4 = migraphx.multibroadcast %2 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1xf16, 1> -> <1x32x1x256xf16, 0x0x0x0>
11+
%5 = migraphx.broadcast %arg0 {axis = 0 : i64, out_lens = [1, 1, 1, 256]} : <1x1x1xsi32, 1x1x1> -> <1x1x1x256xsi32, 1x1x1x0>
12+
%6 = migraphx.greater %3, %5 : <1x1x1x256xsi32, 0x0x0x1>, <1x1x1x256xsi32, 1x1x1x0> -> <1x1x1x256xsi32, 0x0x0x1>
13+
%7 = migraphx.convert %6 {target_type = 0 : i64} : <1x1x1x256xsi32, 0x0x0x1> to <1x1x1x256xsi8, 0x0x0x1>
14+
%8 = migraphx.multibroadcast %7 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x1x1x256xsi8, 0x0x0x1> -> <1x32x1x256xsi8, 0x0x0x1>
15+
%9 = migraphx.slice %arg1 {axes = [1], ends = [32], starts = [0]} : <1x96x1x128xf16, 12288x128x128x1> -> <1x32x1x128xf16, 12288x128x128x1>
16+
%10 = migraphx.transpose %arg2 {permutation = [0, 1, 3, 2]} : <1x32x256x128xf16, 1048576x32768x128x1> -> <1x32x128x256xf16, 1048576x32768x1x128>
17+
%11 = migraphx.dot %9, %10 : <1x32x1x128xf16, 12288x128x128x1>, <1x32x128x256xf16, 1048576x32768x1x128> -> <1x32x1x256xf16, 8192x256x256x1>
18+
%12 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1xf16, 1> -> <1x32x1x256xf16, 0x0x0x0>
19+
%13 = migraphx.mul %11, %4 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 0x0x0x0> -> <1x32x1x256xf16, 8192x256x256x1>
20+
%14 = migraphx.where %8, %12, %13 : <1x32x1x256xsi8, 0x0x0x1>, <1x32x1x256xf16, 0x0x0x0>, <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
21+
%15 = migraphx.reshape %14 {dims = [1, 32, 1, 256]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
22+
%16 = migraphx.reduce_max %15 {axes = [3]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x1xf16, 32x1x1x1>
23+
%17 = migraphx.reshape %16 {dims = [1, 32, 1, 1]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x1xf16, 32x1x1x1>
24+
%18 = migraphx.multibroadcast %17 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x256xf16, 32x1x1x0>
25+
%19 = migraphx.sub %14, %18 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 32x1x1x0> -> <1x32x1x256xf16, 8192x256x256x1>
26+
%20 = migraphx.exp %19 : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
27+
%21 = migraphx.reshape %20 {dims = [1, 32, 1, 256]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
28+
%22 = migraphx.reduce_sum %21 {axes = [3]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x1xf16, 32x1x1x1>
29+
%23 = migraphx.reshape %22 {dims = [1, 32, 1, 1]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x1xf16, 32x1x1x1>
30+
%24 = migraphx.multibroadcast %23 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x256xf16, 32x1x1x0>
31+
%25 = migraphx.div %20, %24 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 32x1x1x0> -> <1x32x1x256xf16, 8192x256x256x1>
32+
%26 = migraphx.dot %25, %arg3 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x256x128xf16, 1048576x32768x128x1> -> <1x32x1x128xf16, 4096x128x128x1>
33+
%27 = migraphx.transpose %26 {permutation = [0, 2, 1, 3]} : <1x32x1x128xf16, 4096x128x128x1> -> <1x1x32x128xf16, 4096x128x128x1>
34+
%28 = migraphx.reshape %27 {dims = [1, 1, 4096]} : <1x1x32x128xf16, 4096x128x128x1> -> <1x1x4096xf16, 4096x4096x1>
35+
return %28 : !migraphx.shaped<1x1x4096xf16, 4096x4096x1>
36+
}
37+
}
38+
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand_min_int 0 -rand_max_int 1024 -rand_type_int_for_inputs=3 -rand 1 -rand_type float -fut mlir_attention_wrapper -RMS_threshold 0.01 --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
2+
// CHECK: [1 1 1]
3+
4+
module {
5+
func.func @mlir_attention(%arg0: !migraphx.shaped<1x96x1x128xf16, 12288x128x128x1>, %arg1: !migraphx.shaped<1x32x256x128xf16, 1048576x32768x128x1>, %arg2: !migraphx.shaped<1x1x1xsi32, 1x1x1>, %arg3: !migraphx.shaped<1x32x256x128xf16, 1048576x32768x128x1>) -> !migraphx.shaped<1x1x4096xf16, 4096x4096x1> {
6+
%0 = migraphx.literal(dense<"0x000000000100000002000000030000000400000005000000060000000700000008000000090000000A0000000B0000000C0000000D0000000E0000000F000000100000001100000012000000130000001400000015000000160000001700000018000000190000001A0000001B0000001C0000001D0000001E0000001F000000200000002100000022000000230000002400000025000000260000002700000028000000290000002A0000002B0000002C0000002D0000002E0000002F000000300000003100000032000000330000003400000035000000360000003700000038000000390000003A0000003B0000003C0000003D0000003E0000003F000000400000004100000042000000430000004400000045000000460000004700000048000000490000004A0000004B0000004C0000004D0000004E0000004F000000500000005100000052000000530000005400000055000000560000005700000058000000590000005A0000005B0000005C0000005D0000005E0000005F000000600000006100000062000000630000006400000065000000660000006700000068000000690000006A0000006B0000006C0000006D0000006E0000006F000000700000007100000072000000730000007400000075000000760000007700000078000000790000007A0000007B0000007C0000007D0000007E0000007F000000800000008100000082000000830000008400000085000000860000008700000088000000890000008A0000008B0000008C0000008D0000008E0000008F000000900000009100000092000000930000009400000095000000960000009700000098000000990000009A0000009B0000009C0000009D0000009E0000009F000000A0000000A1000000A2000000A3000000A4000000A5000000A6000000A7000000A8000000A9000000AA000000AB000000AC000000AD000000AE000000AF000000B0000000B1000000B2000000B3000000B4000000B5000000B6000000B7000000B8000000B9000000BA000000BB000000BC000000BD000000BE000000BF000000C0000000C1000000C2000000C3000000C4000000C5000000C6000000C7000000C8000000C9000000CA000000CB000000CC000000CD000000CE000000CF000000D0000000D1000000D2000000D3000000D4000000D5000000D6000000D7000000D8000000D9000000DA000000DB000000DC000000DD000000DE000000DF000000E0000000E1000000E2000000E3000000E4000000E5000000E6000000E7000000E8000000E9000000EA000000EB000000EC000000ED000000EE000000EF000000F0000000F1000000F2000000F3000000F4000000F5000000F6000000F7000000F8000000F9000000FA000000FB000000FC000000FD000000FE000000FF000000"> : tensor<256xsi32>) : <256xsi32, 1>
7+
%1 = migraphx.literal(dense<0xFC00> : tensor<1xf16>) : <1xf16, 1>
8+
%2 = migraphx.literal(dense<8.837890e-02> : tensor<1xf16>) : <1xf16, 1>
9+
%3 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 1, 1, 256]} : <256xsi32, 1> -> <1x1x1x256xsi32, 0x0x0x1>
10+
%4 = migraphx.slice %arg0 {axes = [1], ends = [32], starts = [0]} : <1x96x1x128xf16, 12288x128x128x1> -> <1x32x1x128xf16, 12288x128x128x1>
11+
%5 = migraphx.transpose %arg1 {permutation = [0, 1, 3, 2]} : <1x32x256x128xf16, 1048576x32768x128x1> -> <1x32x128x256xf16, 1048576x32768x1x128>
12+
%6 = migraphx.dot %4, %5 : <1x32x1x128xf16, 12288x128x128x1>, <1x32x128x256xf16, 1048576x32768x1x128> -> <1x32x1x256xf16, 8192x256x256x1>
13+
%7 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1xf16, 1> -> <1x32x1x256xf16, 0x0x0x0>
14+
%8 = migraphx.multibroadcast %2 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1xf16, 1> -> <1x32x1x256xf16, 0x0x0x0>
15+
%9 = migraphx.mul %6, %8 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 0x0x0x0> -> <1x32x1x256xf16, 8192x256x256x1>
16+
%10 = migraphx.broadcast %arg2 {axis = 0 : i64, out_lens = [1, 1, 1, 256]} : <1x1x1xsi32, 1x1x1> -> <1x1x1x256xsi32, 1x1x1x0>
17+
%11 = migraphx.greater %3, %10 : <1x1x1x256xsi32, 0x0x0x1>, <1x1x1x256xsi32, 1x1x1x0> -> <1x1x1x256xsi32, 0x0x0x1>
18+
%12 = migraphx.convert %11 {target_type = 0 : i64} : <1x1x1x256xsi32, 0x0x0x1> to <1x1x1x256xsi8, 0x0x0x1>
19+
%13 = migraphx.multibroadcast %12 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x1x1x256xsi8, 0x0x0x1> -> <1x32x1x256xsi8, 0x0x0x1>
20+
%14 = migraphx.where %13, %7, %9 : <1x32x1x256xsi8, 0x0x0x1>, <1x32x1x256xf16, 0x0x0x0>, <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
21+
%15 = migraphx.reshape %14 {dims = [1, 32, 1, 256]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
22+
%16 = migraphx.reduce_max %15 {axes = [3]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x1xf16, 32x1x1x1>
23+
%17 = migraphx.reshape %16 {dims = [1, 32, 1, 1]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x1xf16, 32x1x1x1>
24+
%18 = migraphx.multibroadcast %17 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x256xf16, 32x1x1x0>
25+
%19 = migraphx.sub %14, %18 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 32x1x1x0> -> <1x32x1x256xf16, 8192x256x256x1>
26+
%20 = migraphx.exp %19 : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
27+
%21 = migraphx.reshape %20 {dims = [1, 32, 1, 256]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
28+
%22 = migraphx.reduce_sum %21 {axes = [3]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x1xf16, 32x1x1x1>
29+
%23 = migraphx.reshape %22 {dims = [1, 32, 1, 1]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x1xf16, 32x1x1x1>
30+
%24 = migraphx.multibroadcast %23 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x256xf16, 32x1x1x0>
31+
%25 = migraphx.div %20, %24 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 32x1x1x0> -> <1x32x1x256xf16, 8192x256x256x1>
32+
%26 = migraphx.dot %25, %arg3 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x256x128xf16, 1048576x32768x128x1> -> <1x32x1x128xf16, 4096x128x128x1>
33+
%27 = migraphx.transpose %26 {permutation = [0, 2, 1, 3]} : <1x32x1x128xf16, 4096x128x128x1> -> <1x1x32x128xf16, 4096x128x128x1>
34+
%28 = migraphx.reshape %27 {dims = [1, 1, 4096]} : <1x1x32x128xf16, 4096x128x128x1> -> <1x1x4096xf16, 4096x4096x1>
35+
return %28 : !migraphx.shaped<1x1x4096xf16, 4096x4096x1>
36+
}
37+
}
38+
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand_min_int 0 -rand_max_int 1024 -rand_type_int_for_inputs=3 -rand 1 -rand_type float -fut mlir_attention_wrapper -RMS_threshold 0.01 --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
2+
// CHECK: [1 1 1]
3+
module {
4+
func.func @mlir_attention(%arg0: !migraphx.shaped<1x96x1x128xf16, 12288x128x128x1>, %arg1: !migraphx.shaped<1x32x256x128xf16, 1048576x32768x128x1>, %arg2: !migraphx.shaped<1x1x1xsi32, 1x1x1>, %arg3: !migraphx.shaped<1x32x256x128xf16, 1048576x32768x128x1>) -> !migraphx.shaped<1x1x4096xf16, 4096x4096x1> {
5+
%0 = migraphx.literal(dense<0xFC00> : tensor<1xf16>) : <1xf16, 1>
6+
%1 = migraphx.literal(dense<"0x000000000100000002000000030000000400000005000000060000000700000008000000090000000A0000000B0000000C0000000D0000000E0000000F000000100000001100000012000000130000001400000015000000160000001700000018000000190000001A0000001B0000001C0000001D0000001E0000001F000000200000002100000022000000230000002400000025000000260000002700000028000000290000002A0000002B0000002C0000002D0000002E0000002F000000300000003100000032000000330000003400000035000000360000003700000038000000390000003A0000003B0000003C0000003D0000003E0000003F000000400000004100000042000000430000004400000045000000460000004700000048000000490000004A0000004B0000004C0000004D0000004E0000004F000000500000005100000052000000530000005400000055000000560000005700000058000000590000005A0000005B0000005C0000005D0000005E0000005F000000600000006100000062000000630000006400000065000000660000006700000068000000690000006A0000006B0000006C0000006D0000006E0000006F000000700000007100000072000000730000007400000075000000760000007700000078000000790000007A0000007B0000007C0000007D0000007E0000007F000000800000008100000082000000830000008400000085000000860000008700000088000000890000008A0000008B0000008C0000008D0000008E0000008F000000900000009100000092000000930000009400000095000000960000009700000098000000990000009A0000009B0000009C0000009D0000009E0000009F000000A0000000A1000000A2000000A3000000A4000000A5000000A6000000A7000000A8000000A9000000AA000000AB000000AC000000AD000000AE000000AF000000B0000000B1000000B2000000B3000000B4000000B5000000B6000000B7000000B8000000B9000000BA000000BB000000BC000000BD000000BE000000BF000000C0000000C1000000C2000000C3000000C4000000C5000000C6000000C7000000C8000000C9000000CA000000CB000000CC000000CD000000CE000000CF000000D0000000D1000000D2000000D3000000D4000000D5000000D6000000D7000000D8000000D9000000DA000000DB000000DC000000DD000000DE000000DF000000E0000000E1000000E2000000E3000000E4000000E5000000E6000000E7000000E8000000E9000000EA000000EB000000EC000000ED000000EE000000EF000000F0000000F1000000F2000000F3000000F4000000F5000000F6000000F7000000F8000000F9000000FA000000FB000000FC000000FD000000FE000000FF000000"> : tensor<256xsi32>) : <256xsi32, 1>
7+
%2 = migraphx.literal(dense<8.837890e-02> : tensor<1xf16>) : <1xf16, 1>
8+
%3 = migraphx.slice %arg0 {axes = [1], ends = [32], starts = [0]} : <1x96x1x128xf16, 12288x128x128x1> -> <1x32x1x128xf16, 12288x128x128x1>
9+
%4 = migraphx.transpose %arg1 {permutation = [0, 1, 3, 2]} : <1x32x256x128xf16, 1048576x32768x128x1> -> <1x32x128x256xf16, 1048576x32768x1x128>
10+
%5 = migraphx.dot %3, %4 : <1x32x1x128xf16, 12288x128x128x1>, <1x32x128x256xf16, 1048576x32768x1x128> -> <1x32x1x256xf16, 8192x256x256x1>
11+
%6 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1xf16, 1> -> <1x32x1x256xf16, 0x0x0x0>
12+
%7 = migraphx.multibroadcast %2 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1xf16, 1> -> <1x32x1x256xf16, 0x0x0x0>
13+
%8 = migraphx.mul %5, %7 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 0x0x0x0> -> <1x32x1x256xf16, 8192x256x256x1>
14+
%9 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [1, 1, 1, 256]} : <256xsi32, 1> -> <1x1x1x256xsi32, 0x0x0x1>
15+
%10 = migraphx.broadcast %arg2 {axis = 0 : i64, out_lens = [1, 1, 1, 256]} : <1x1x1xsi32, 1x1x1> -> <1x1x1x256xsi32, 1x1x1x0>
16+
%11 = migraphx.greater %9, %10 : <1x1x1x256xsi32, 0x0x0x1>, <1x1x1x256xsi32, 1x1x1x0> -> <1x1x1x256xsi32, 0x0x0x1>
17+
%12 = migraphx.convert %11 {target_type = 0 : i64} : <1x1x1x256xsi32, 0x0x0x1> to <1x1x1x256xsi8, 0x0x0x1>
18+
%13 = migraphx.multibroadcast %12 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x1x1x256xsi8, 0x0x0x1> -> <1x32x1x256xsi8, 0x0x0x1>
19+
%14 = migraphx.where %13, %6, %8 : <1x32x1x256xsi8, 0x0x0x1>, <1x32x1x256xf16, 0x0x0x0>, <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
20+
%15 = migraphx.softmax %14 {axis = 3 : i64} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
21+
%16 = migraphx.dot %15, %arg3 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x256x128xf16, 1048576x32768x128x1> -> <1x32x1x128xf16, 4096x128x128x1>
22+
%17 = migraphx.transpose %16 {permutation = [0, 2, 1, 3]} : <1x32x1x128xf16, 4096x128x128x1> -> <1x1x32x128xf16, 4096x128x128x1>
23+
%18 = migraphx.reshape %17 {dims = [1, 1, 4096]} : <1x1x32x128xf16, 4096x128x128x1> -> <1x1x4096xf16, 4096x4096x1>
24+
return %18 : !migraphx.shaped<1x1x4096xf16, 4096x4096x1>
25+
}
26+
}
27+

0 commit comments

Comments
 (0)