Skip to content

Commit e116a04

Browse files
authored
Fix MIXR Fp4 splitK Random float tests (#2102)
* Fix rand 1 -rand_type float tests
1 parent 5a601b9 commit e116a04

File tree

4 files changed

+35
-26
lines changed

4 files changed

+35
-26
lines changed

mlir/test/fusion/pr-e2e/mixr-gemm-splitk/f4/mixr-quant-dot-multi-output-add.mlir

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
// RUN: rocmlir-gen -fut quant_dot_multi_output_add --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand none -fut quant_dot_multi_output_add_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext --entry-point-result=void | FileCheck %s
2-
2+
// RUN: rocmlir-gen -fut quant_dot_multi_output_add --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut quant_dot_multi_output_add_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CLONE
33
// We need a check for each output as this test case has two outputs in it.
44
// CHECK: [1 1 1]
55
// CHECK: [1 1 1]
6+
// CLONE: [1 1 1]
7+
// CLONE: [1 1 1]
68
module {
7-
func.func @quant_dot_multi_output_add(%arg0: !migraphx.shaped<1x64x128xf4E2M1FN, 8192x128x1>, %arg1: !migraphx.shaped<1x128x64xf4E2M1FN, 8192x64x1>, %arg2: !migraphx.shaped<1x2x1x128xf8E8M0FNU, 256x128x128x1>, %arg3: !migraphx.shaped<1x128x2x1xf8E8M0FNU, 256x2x1x1>) -> (!migraphx.shaped<1x64x64xf32, 4096x64x1>, !migraphx.shaped<1x64x64xf32, 4096x64x1>) attributes{arch = "", enable_splitk_for_tuning, kernel = "mixr"} {
8-
%0 = migraphx.multibroadcast %arg2 {out_dyn_dims = [], out_lens = [1, 2, 32, 128]} : <1x2x1x128xf8E8M0FNU, 256x128x128x1> -> <1x2x32x128xf8E8M0FNU, 256x128x0x1>
9-
%1 = migraphx.reshape %0 {dims = [1, 64, 128]} : <1x2x32x128xf8E8M0FNU, 256x128x0x1> -> <1x64x128xf8E8M0FNU, 8192x128x1>
10-
%2 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [1, 128, 2, 32]} : <1x128x2x1xf8E8M0FNU, 256x2x1x1> -> <1x128x2x32xf8E8M0FNU, 256x2x1x0>
11-
%3 = migraphx.reshape %2 {dims = [1, 128, 64]} : <1x128x2x32xf8E8M0FNU, 256x2x1x0> -> <1x128x64xf8E8M0FNU, 8192x64x1>
9+
func.func @quant_dot_multi_output_add(%arg0: !migraphx.shaped<1x64x128xf4E2M1FN, 8192x128x1>, %arg1: !migraphx.shaped<1x128x64xf4E2M1FN, 8192x64x1>, %arg2: !migraphx.shaped<1x64x4x1xf8E8M0FNU, 256x4x1x1>, %arg3: !migraphx.shaped<1x4x1x64xf8E8M0FNU, 256x64x64x1>) -> (!migraphx.shaped<1x64x64xf32, 4096x64x1>, !migraphx.shaped<1x64x64xf32, 4096x64x1>) attributes{arch = "", enable_splitk_for_tuning, kernel = "mixr"} {
10+
%0 = migraphx.multibroadcast %arg2 {out_dyn_dims = [], out_lens = [1, 64, 4, 32]} : <1x64x4x1xf8E8M0FNU, 256x4x1x1> -> <1x64x4x32xf8E8M0FNU, 256x4x0x1>
11+
%1 = migraphx.reshape %0 {dims = [1, 64, 128]} : <1x64x4x32xf8E8M0FNU, 256x4x0x1> -> <1x64x128xf8E8M0FNU, 256x4x1>
12+
%2 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [1, 4, 32, 64]} : <1x4x1x64xf8E8M0FNU, 256x64x64x1> -> <1x4x32x64xf8E8M0FNU, 256x0x64x1>
13+
%3 = migraphx.reshape %2 {dims = [1, 128, 64]} : <1x4x32x64xf8E8M0FNU, 256x0x64x1> -> <1x128x64xf8E8M0FNU, 256x64x1>
1214
%4 = migraphx.literal(dense<1.0> : tensor<1xf32>) : <1xf32, 0>
1315
%5 = migraphx.literal(dense<2.0> : tensor<1xf32>) : <1xf32, 0>
14-
%6 = migraphx.quant_dot %arg0 scaled by %1, %arg1 scaled by %3 {perf_config="v3:64,64,16,32,32,32,3,1,2,1,1"} : <1x64x128xf4E2M1FN, 8192x128x1> scaled by !migraphx.shaped<1x64x128xf8E8M0FNU, 8192x128x1>, <1x128x64xf4E2M1FN, 8192x64x1> scaled by !migraphx.shaped<1x128x64xf8E8M0FNU, 8192x64x1> -> <1x64x64xf32, 4096x64x1>
16+
%6 = migraphx.quant_dot %arg0 scaled by %1, %arg1 scaled by %3 {perf_config="v3:64,64,16,32,32,32,3,1,2,1,1"} : <1x64x128xf4E2M1FN, 8192x128x1> scaled by !migraphx.shaped<1x64x128xf8E8M0FNU, 256x4x1>, <1x128x64xf4E2M1FN, 8192x64x1> scaled by !migraphx.shaped<1x128x64xf8E8M0FNU, 256x64x1> -> <1x64x64xf32, 4096x64x1>
1517
%7 = migraphx.multibroadcast %4 {out_dyn_dims = [], out_lens = [1, 64, 64]} : <1xf32, 0> -> <1x64x64xf32, 0x0x0>
1618
%8 = migraphx.multibroadcast %5 {out_dyn_dims = [], out_lens = [1, 64, 64]} : <1xf32, 0> -> <1x64x64xf32, 0x0x0>
1719
%9 = migraphx.add %6, %7 : <1x64x64xf32, 4096x64x1>, <1x64x64xf32, 0x0x0> -> <1x64x64xf32, 4096x64x1>

mlir/test/fusion/pr-e2e/mixr-gemm-splitk/f4/mixr-quant-dot-multi-reduce-splitk.e2e.mlir

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
// RUN: rocmlir-gen -fut quant_dot_multi_reduce --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand none -RMS_threshold=3e-3 -absDiff_threshold 7e-1 -relDiff_threshold 3e-3 -fut quant_dot_multi_reduce_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext --entry-point-result=void | FileCheck %s
2-
2+
// RUN: rocmlir-gen -fut quant_dot_multi_reduce --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_min 1 -rand_max 2 -rand_type float -RMS_threshold=3e-3 -absDiff_threshold 7e-1 -relDiff_threshold 3e-3 -fut quant_dot_multi_reduce_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CLONE
33
// We need a check for each output as this test case has two outputs in it.
44
// CHECK: [1 1 1]
55
// CHECK: [1 1 1]
6+
// CLONE: [1 1 1]
7+
// CLONE: [1 1 1]
68
module {
7-
func.func @quant_dot_multi_reduce(%arg0: !migraphx.shaped<2x32x10x64x64xf32, 0x10x1x0x0>, %arg1: !migraphx.shaped<2x320x320xf4E2M1FN, 102400x320x1>, %arg2: !migraphx.shaped<2x320x4096xf4E2M1FN, 1310720x4096x1>, %arg3: !migraphx.shaped<2x10x1x320xf8E8M0FNU, 3200x320x320x1>, %arg4: !migraphx.shaped<2x320x128x1xf8E8M0FNU, 40960x128x1x1>) -> (!migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x32x10x64x64xf32, 1310720x40960x4096x64x1>) attributes{arch = "", enable_splitk_for_tuning, kernel = "mixr"} {
8-
%0 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [2, 10, 32, 320]} : <2x10x1x320xf8E8M0FNU, 3200x320x320x1> -> <2x10x32x320xf8E8M0FNU, 3200x320x0x1>
9-
%1 = migraphx.reshape %0 {dims = [2, 320, 320]} : <2x10x32x320xf8E8M0FNU, 3200x320x0x1> -> <2x320x320xf8E8M0FNU, 102400x320x1>
10-
%2 = migraphx.multibroadcast %arg4 {out_dyn_dims = [], out_lens = [2, 320, 128, 32]} : <2x320x128x1xf8E8M0FNU, 40960x128x1x1> -> <2x320x128x32xf8E8M0FNU, 40960x128x1x0>
11-
%3 = migraphx.reshape %2 {dims = [2, 320, 4096]} : <2x320x128x32xf8E8M0FNU, 40960x128x1x0> -> <2x320x4096xf8E8M0FNU, 1310720x4096x1>
9+
func.func @quant_dot_multi_reduce(%arg0: !migraphx.shaped<2x32x10x64x64xf32, 0x10x1x0x0>, %arg1: !migraphx.shaped<2x320x320xf4E2M1FN, 102400x320x1>, %arg2: !migraphx.shaped<2x320x4096xf4E2M1FN, 1310720x4096x1>, %arg3: !migraphx.shaped<2x320x10x1xf8E8M0FNU, 3200x10x1x1>, %arg4: !migraphx.shaped<2x10x1x4096xf8E8M0FNU, 40960x4096x4096x1>) -> (!migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x32x10x64x64xf32, 1310720x40960x4096x64x1>) attributes{arch = "", enable_splitk_for_tuning, kernel = "mixr"} {
10+
%0 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [2, 320, 10, 32]} : <2x320x10x1xf8E8M0FNU, 3200x10x1x1> -> <2x320x10x32xf8E8M0FNU, 3200x10x0x1>
11+
%1 = migraphx.reshape %0 {dims = [2, 320, 320]} : <2x320x10x32xf8E8M0FNU, 3200x10x0x1> -> <2x320x320xf8E8M0FNU, 3200x10x1>
12+
%2 = migraphx.multibroadcast %arg4 {out_dyn_dims = [], out_lens = [2, 10, 32, 4096]} : <2x10x1x4096xf8E8M0FNU, 40960x4096x4096x1> -> <2x10x32x4096xf8E8M0FNU, 40960x0x4096x1>
13+
%3 = migraphx.reshape %2 {dims = [2, 320, 4096]} : <2x10x32x4096xf8E8M0FNU, 40960x0x4096x1> -> <2x320x4096xf8E8M0FNU, 40960x4096x1>
1214
%4 = migraphx.literal(dense<2.44140629E-5> : tensor<1xf32>) : <1xf32, 0>
13-
%5 = migraphx.quant_dot %arg1 scaled by %1, %arg2 scaled by %3 {perf_config="v3:64,64,16,32,32,32,4,1,2,1,1"} : <2x320x320xf4E2M1FN, 102400x320x1> scaled by !migraphx.shaped<2x320x320xf8E8M0FNU, 102400x320x1>, <2x320x4096xf4E2M1FN, 1310720x4096x1> scaled by !migraphx.shaped<2x320x4096xf8E8M0FNU, 1310720x4096x1> -> <2x320x4096xf32, 1310720x4096x1>
15+
%5 = migraphx.quant_dot %arg1 scaled by %1, %arg2 scaled by %3 {perf_config="v3:64,64,16,32,32,32,4,1,2,1,1"} : <2x320x320xf4E2M1FN, 102400x320x1> scaled by !migraphx.shaped<2x320x320xf8E8M0FNU, 3200x10x1>, <2x320x4096xf4E2M1FN, 1310720x4096x1> scaled by !migraphx.shaped<2x320x4096xf8E8M0FNU, 40960x4096x1> -> <2x320x4096xf32, 1310720x4096x1>
1416
%6 = migraphx.reshape %5 {dims = [2, 32, 10, 64, 64]} : <2x320x4096xf32, 1310720x4096x1> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1>
1517
%7 = migraphx.add %6, %arg0 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 0x10x1x0x0> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1>
1618
%8 = migraphx.multibroadcast %4 {out_dyn_dims = [], out_lens = [2, 32, 10, 64, 64]} : <1xf32, 0> -> <2x32x10x64x64xf32, 0x0x0x0x0>
@@ -21,4 +23,3 @@ module {
2123
return %12, %7 : !migraphx.shaped<2x32x1x1x1xf32, 32x1x1x1x1>, !migraphx.shaped<2x32x10x64x64xf32, 1310720x40960x4096x64x1>
2224
}
2325
}
24-

mlir/test/fusion/pr-e2e/mixr-gemm-splitk/f4/mixr-quant-dot-splitk-add.e2e.mlir

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
// RUN: rocmlir-gen -fut quant_dot_splitk_add --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand none -fut quant_dot_splitk_add_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext --entry-point-result=void | FileCheck %s
2+
3+
// RUN: rocmlir-gen -fut quant_dot_splitk_add --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -print-results -rand 1 -rand_type float -fut quant_dot_splitk_add_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CLONE
24
module {
35
// CHECK: [1 1 1]
4-
func.func @quant_dot_splitk_add(%arg0: !migraphx.shaped<1x128x256xf4E2M1FN, 32768x256x1>, %arg1: !migraphx.shaped<1x256x128xf4E2M1FN, 32768x128x1>, %arg2: !migraphx.shaped<1x4x1x256xf8E8M0FNU, 1024x256x256x1>, %arg3: !migraphx.shaped<1x256x4x1xf8E8M0FNU, 1024x4x1x1>, %arg4: !migraphx.shaped<1x128x128xf32, 16384x128x1>) -> !migraphx.shaped<1x128x128xf32, 16384x128x1> attributes{arch = "", enable_splitk_for_tuning, kernel = "mixr"} {
5-
%0 = migraphx.multibroadcast %arg2 {out_dyn_dims = [], out_lens = [1, 4, 32, 256]} : <1x4x1x256xf8E8M0FNU, 1024x256x256x1> -> <1x4x32x256xf8E8M0FNU, 1024x256x0x1>
6-
%1 = migraphx.reshape %0 {dims = [1, 128, 256]} : <1x4x32x256xf8E8M0FNU, 1024x256x0x1> -> <1x128x256xf8E8M0FNU, 32768x256x1>
7-
%2 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [1, 256, 4, 32]} : <1x256x4x1xf8E8M0FNU, 1024x4x1x1> -> <1x256x4x32xf8E8M0FNU, 1024x4x1x0>
8-
%3 = migraphx.reshape %2 {dims = [1, 256, 128]} : <1x256x4x32xf8E8M0FNU, 1024x4x1x0> -> <1x256x128xf8E8M0FNU, 32768x128x1>
9-
%4 = migraphx.quant_dot %arg0 scaled by %1, %arg1 scaled by %3 {perf_config="v3:64,64,16,32,32,32,4,1,2,1,1"} : <1x128x256xf4E2M1FN, 32768x256x1> scaled by !migraphx.shaped<1x128x256xf8E8M0FNU, 32768x256x1>, <1x256x128xf4E2M1FN, 32768x128x1> scaled by !migraphx.shaped<1x256x128xf8E8M0FNU, 32768x128x1> -> <1x128x128xf32, 16384x128x1>
6+
// CLONE: [1 1 1]
7+
func.func @quant_dot_splitk_add(%arg0: !migraphx.shaped<1x128x256xf4E2M1FN, 32768x256x1>, %arg1: !migraphx.shaped<1x256x128xf4E2M1FN, 32768x128x1>, %arg2: !migraphx.shaped<1x128x8x1xf8E8M0FNU, 1024x8x1x1>, %arg3: !migraphx.shaped<1x8x1x128xf8E8M0FNU, 1024x128x128x1>, %arg4: !migraphx.shaped<1x128x128xf32, 16384x128x1>) -> !migraphx.shaped<1x128x128xf32, 16384x128x1> attributes{arch = "", enable_splitk_for_tuning, kernel = "mixr"} {
8+
%0 = migraphx.multibroadcast %arg2 {out_dyn_dims = [], out_lens = [1, 128, 8, 32]} : <1x128x8x1xf8E8M0FNU, 1024x8x1x1> -> <1x128x8x32xf8E8M0FNU, 1024x8x0x1>
9+
%1 = migraphx.reshape %0 {dims = [1, 128, 256]} : <1x128x8x32xf8E8M0FNU, 1024x8x0x1> -> <1x128x256xf8E8M0FNU, 1024x8x1>
10+
%2 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [1, 8, 32, 128]} : <1x8x1x128xf8E8M0FNU, 1024x128x128x1> -> <1x8x32x128xf8E8M0FNU, 1024x0x128x1>
11+
%3 = migraphx.reshape %2 {dims = [1, 256, 128]} : <1x8x32x128xf8E8M0FNU, 1024x0x128x1> -> <1x256x128xf8E8M0FNU, 1024x128x1>
12+
%4 = migraphx.quant_dot %arg0 scaled by %1, %arg1 scaled by %3 {perf_config="v3:64,64,16,32,32,32,4,1,2,1,1"} : <1x128x256xf4E2M1FN, 32768x256x1> scaled by !migraphx.shaped<1x128x256xf8E8M0FNU, 1024x8x1>, <1x256x128xf4E2M1FN, 32768x128x1> scaled by !migraphx.shaped<1x256x128xf8E8M0FNU, 1024x128x1> -> <1x128x128xf32, 16384x128x1>
1013
%5 = migraphx.add %4, %arg4 {} : <1x128x128xf32, 16384x128x1>, <1x128x128xf32, 16384x128x1> -> <1x128x128xf32, 16384x128x1>
1114
return %5 : !migraphx.shaped<1x128x128xf32, 16384x128x1>
1215
}

mlir/test/fusion/pr-e2e/mixr-gemm-splitk/f4/mixr-quant-dot-splitk.e2e.mlir

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
// RUN: rocmlir-gen -fut quant_dot_splitk --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand none -fut quant_dot_splitk_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext --entry-point-result=void | FileCheck %s
2+
3+
// RUN: rocmlir-gen -fut quant_dot_splitk --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -print-results -rand 1 -rand_type float -fut quant_dot_splitk_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=CLONE
24
module {
35
// CHECK: [1 1 1]
4-
func.func @quant_dot_splitk(%arg0: !migraphx.shaped<1x64x128xf4E2M1FN, 8192x128x1>, %arg1: !migraphx.shaped<1x128x64xf4E2M1FN, 8192x64x1>, %arg2: !migraphx.shaped<1x2x1x128xf8E8M0FNU, 256x128x128x1>, %arg3: !migraphx.shaped<1x128x2x1xf8E8M0FNU, 256x2x1x1>) -> !migraphx.shaped<1x64x64xf32, 4096x64x1> attributes{arch = "gfx950", enable_splitk_for_tuning, kernel = "mixr"} {
5-
%0 = migraphx.multibroadcast %arg2 {out_dyn_dims = [], out_lens = [1, 2, 32, 128]} : <1x2x1x128xf8E8M0FNU, 256x128x128x1> -> <1x2x32x128xf8E8M0FNU, 256x128x0x1>
6-
%1 = migraphx.reshape %0 {dims = [1, 64, 128]} : <1x2x32x128xf8E8M0FNU, 256x128x0x1> -> <1x64x128xf8E8M0FNU, 8192x128x1>
7-
%2 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [1, 128, 2, 32]} : <1x128x2x1xf8E8M0FNU, 256x2x1x1> -> <1x128x2x32xf8E8M0FNU, 256x2x1x0>
8-
%3 = migraphx.reshape %2 {dims = [1, 128, 64]} : <1x128x2x32xf8E8M0FNU, 256x2x1x0> -> <1x128x64xf8E8M0FNU, 8192x64x1>
6+
// CLONE: [1 1 1]
7+
func.func @quant_dot_splitk(%arg0: !migraphx.shaped<1x64x128xf4E2M1FN, 8192x128x1>, %arg1: !migraphx.shaped<1x128x64xf4E2M1FN, 8192x64x1>, %arg2: !migraphx.shaped<1x64x4x1xf8E8M0FNU, 256x4x1x1>, %arg3: !migraphx.shaped<1x4x1x64xf8E8M0FNU, 256x64x64x1>) -> !migraphx.shaped<1x64x64xf32, 4096x64x1> attributes{arch = "gfx950", enable_splitk_for_tuning, kernel = "mixr"} {
8+
%0 = migraphx.multibroadcast %arg2 {out_dyn_dims = [], out_lens = [1, 64, 4, 32]} : <1x64x4x1xf8E8M0FNU, 256x4x1x1> -> <1x64x4x32xf8E8M0FNU, 256x4x1x0>
9+
%1 = migraphx.reshape %0 {dims = [1, 64, 128]} : <1x64x4x32xf8E8M0FNU, 256x4x1x0> -> <1x64x128xf8E8M0FNU, 8192x128x1>
10+
%2 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [1, 4, 32, 64]} : <1x4x1x64xf8E8M0FNU, 256x64x64x1> -> <1x4x32x64xf8E8M0FNU, 256x64x0x1>
11+
%3 = migraphx.reshape %2 {dims = [1, 128, 64]} : <1x4x32x64xf8E8M0FNU, 256x64x0x1> -> <1x128x64xf8E8M0FNU, 8192x64x1>
912
%4 = migraphx.quant_dot %arg0 scaled by %1, %arg1 scaled by %3 {perf_config="v3:64,64,16,32,32,32,2,1,2,1,1"} : <1x64x128xf4E2M1FN, 8192x128x1> scaled by !migraphx.shaped<1x64x128xf8E8M0FNU, 8192x128x1>, <1x128x64xf4E2M1FN, 8192x64x1> scaled by !migraphx.shaped<1x128x64xf8E8M0FNU, 8192x64x1> -> <1x64x64xf32, 4096x64x1>
1013
return %4 : !migraphx.shaped<1x64x64xf32, 4096x64x1>
1114
}

0 commit comments

Comments
 (0)