11// RUN: rocmlir-gen -fut quant_dot_multi_reduce --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand none -RMS_threshold=3e-3 -absDiff_threshold 7e-1 -relDiff_threshold 3e-3 -fut quant_dot_multi_reduce_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext --entry-point-result=void | FileCheck %s
2-
2+ // RUN: rocmlir-gen -fut quant_dot_multi_reduce --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_min 1 -rand_max 2 -rand_type float -RMS_threshold=3e-3 -absDiff_threshold 7e-1 -relDiff_threshold 3e-3 -fut quant_dot_multi_reduce_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CLONE
33// This test case returns two outputs, so each FileCheck prefix needs one check line per output.
44// CHECK: [1 1 1]
55// CHECK: [1 1 1]
6+ // CLONE: [1 1 1]
7+ // CLONE: [1 1 1]
68module {
7- func.func @quant_dot_multi_reduce (%arg0: !migraphx.shaped <2 x32 x10 x64 x64 xf32 , 0x10x1x0x0 >, %arg1: !migraphx.shaped <2 x320 x320 xf4 E2 M1 FN, 102400 x320 x1 >, %arg2: !migraphx.shaped <2 x320 x4096 xf4 E2 M1 FN, 1310720 x4096 x1 >, %arg3: !migraphx.shaped <2 x 10 x 1 x 320 xf 8 E 8 M 0 FNU, 3200 x 320 x 320 x 1 >, %arg4: !migraphx.shaped <2 x 320 x 128 x 1 xf 8 E 8 M 0 FNU, 40960 x 128 x 1 x 1 >) -> (!migraphx.shaped <2 x32 x1 x1 x1 xf32 , 32 x1 x1 x1 x1 >, !migraphx.shaped <2 x32 x10 x64 x64 xf32 , 1310720 x40960 x4096 x64 x1 >) attributes {arch = " " , enable_splitk_for_tuning , kernel = " mixr" } {
8- %0 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [2 , 10 , 32 , 320 ]} : <2 x 10 x 1 x 320 xf 8 E 8 M 0 FNU, 3200 x 320 x 320 x 1 > -> <2 x 10 x 32 x 320 xf 8 E 8 M 0 FNU, 3200 x 320 x 0 x 1 >
9- %1 = migraphx.reshape %0 {dims = [2 , 320 , 320 ]} : <2 x 10 x 32 x 320 xf 8 E 8 M 0 FNU, 3200 x 320 x 0 x 1 > -> <2 x320 x320 xf8 E8 M0 FNU, 102400 x 320 x 1 >
10- %2 = migraphx.multibroadcast %arg4 {out_dyn_dims = [], out_lens = [2 , 320 , 128 , 32 ]} : <2 x 320 x 128 x 1 xf 8 E 8 M 0 FNU, 40960 x 128 x 1 x 1 > -> <2 x 320 x 128 x 32 xf 8 E 8 M 0 FNU, 40960 x 128 x 1 x 0 >
11- %3 = migraphx.reshape %2 {dims = [2 , 320 , 4096 ]} : <2 x 320 x 128 x 32 xf 8 E 8 M 0 FNU, 40960 x 128 x 1 x 0 > -> <2 x320 x4096 xf8 E8 M0 FNU, 1310720 x 4096 x 1 >
9+ func.func @quant_dot_multi_reduce (%arg0: !migraphx.shaped <2 x32 x10 x64 x64 xf32 , 0x10x1x0x0 >, %arg1: !migraphx.shaped <2 x320 x320 xf4 E2 M1 FN, 102400 x320 x1 >, %arg2: !migraphx.shaped <2 x320 x4096 xf4 E2 M1 FN, 1310720 x4096 x1 >, %arg3: !migraphx.shaped <2 x 320 x 10 x 1 xf 8 E 8 M 0 FNU, 3200 x 10 x 1 x 1 >, %arg4: !migraphx.shaped <2 x 10 x 1 x 4096 xf 8 E 8 M 0 FNU, 40960 x 4096 x 4096 x 1 >) -> (!migraphx.shaped <2 x32 x1 x1 x1 xf32 , 32 x1 x1 x1 x1 >, !migraphx.shaped <2 x32 x10 x64 x64 xf32 , 1310720 x40960 x4096 x64 x1 >) attributes {arch = " " , enable_splitk_for_tuning , kernel = " mixr" } {
10+ %0 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [2 , 320 , 10 , 32 ]} : <2 x 320 x 10 x 1 xf 8 E 8 M 0 FNU, 3200 x 10 x 1 x 1 > -> <2 x 320 x 10 x 32 xf 8 E 8 M 0 FNU, 3200 x 10 x 0 x 1 >
11+ %1 = migraphx.reshape %0 {dims = [2 , 320 , 320 ]} : <2 x 320 x 10 x 32 xf 8 E 8 M 0 FNU, 3200 x 10 x 0 x 1 > -> <2 x320 x320 xf8 E8 M0 FNU, 3200 x 10 x 1 >
12+ %2 = migraphx.multibroadcast %arg4 {out_dyn_dims = [], out_lens = [2 , 10 , 32 , 4096 ]} : <2 x 10 x 1 x 4096 xf 8 E 8 M 0 FNU, 40960 x 4096 x 4096 x 1 > -> <2 x 10 x 32 x 4096 xf 8 E 8 M 0 FNU, 40960 x 0 x 4096 x 1 >
13+ %3 = migraphx.reshape %2 {dims = [2 , 320 , 4096 ]} : <2 x 10 x 32 x 4096 xf 8 E 8 M 0 FNU, 40960 x 0 x 4096 x 1 > -> <2 x320 x4096 xf8 E8 M0 FNU, 40960 x 4096 x 1 >
1214 %4 = migraphx.literal (dense <2.44140629E-5 > : tensor <1 xf32 >) : <1 xf32 , 0 >
13- %5 = migraphx.quant_dot %arg1 scaled by %1 , %arg2 scaled by %3 {perf_config =" v3:64,64,16,32,32,32,4,1,2,1,1" } : <2 x320 x320 xf4 E2 M1 FN, 102400 x320 x1 > scaled by !migraphx.shaped <2 x320 x320 xf8 E8 M0 FNU, 102400 x 320 x 1 >, <2 x320 x4096 xf4 E2 M1 FN, 1310720 x4096 x1 > scaled by !migraphx.shaped <2 x320 x4096 xf8 E8 M0 FNU, 1310720 x 4096 x 1 > -> <2 x320 x4096 xf32 , 1310720 x4096 x1 >
15+ %5 = migraphx.quant_dot %arg1 scaled by %1 , %arg2 scaled by %3 {perf_config =" v3:64,64,16,32,32,32,4,1,2,1,1" } : <2 x320 x320 xf4 E2 M1 FN, 102400 x320 x1 > scaled by !migraphx.shaped <2 x320 x320 xf8 E8 M0 FNU, 3200 x 10 x 1 >, <2 x320 x4096 xf4 E2 M1 FN, 1310720 x4096 x1 > scaled by !migraphx.shaped <2 x320 x4096 xf8 E8 M0 FNU, 40960 x 4096 x 1 > -> <2 x320 x4096 xf32 , 1310720 x4096 x1 >
1416 %6 = migraphx.reshape %5 {dims = [2 , 32 , 10 , 64 , 64 ]} : <2 x320 x4096 xf32 , 1310720 x4096 x1 > -> <2 x32 x10 x64 x64 xf32 , 1310720 x40960 x4096 x64 x1 >
1517 %7 = migraphx.add %6 , %arg0 : <2 x32 x10 x64 x64 xf32 , 1310720 x40960 x4096 x64 x1 >, <2 x32 x10 x64 x64 xf32 , 0x10x1x0x0 > -> <2 x32 x10 x64 x64 xf32 , 1310720 x40960 x4096 x64 x1 >
1618 %8 = migraphx.multibroadcast %4 {out_dyn_dims = [], out_lens = [2 , 32 , 10 , 64 , 64 ]} : <1 xf32 , 0 > -> <2 x32 x10 x64 x64 xf32 , 0x0x0x0x0 >
@@ -21,4 +23,3 @@ module {
2123 return %12 , %7 : !migraphx.shaped <2 x32 x1 x1 x1 xf32 , 32 x1 x1 x1 x1 >, !migraphx.shaped <2 x32 x10 x64 x64 xf32 , 1310720 x40960 x4096 x64 x1 >
2224 }
2325}
24-
0 commit comments