Skip to content

Commit 8946be7

Browse files
committed
Fix unit tests.
1 parent 3a8dfca commit 8946be7

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

bin/mfa/na_int8_matmul_bench.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ struct VariantConfig {
4747
struct Stats {
4848
double average_seconds = 0;
4949
double median_seconds = 0;
50+
double best3_average_seconds = 0;
5051
double min_seconds = 0;
5152
double max_seconds = 0;
5253
};
@@ -1070,6 +1071,9 @@ bool benchmark(const BenchmarkConfig& config, RunOnce&& run_once, Stats* stats)
10701071
stats->average_seconds = std::accumulate(samples.begin(), samples.end(), 0.0) / samples.size();
10711072
std::sort(samples.begin(), samples.end());
10721073
stats->median_seconds = samples[samples.size() / 2];
1074+
const size_t best_count = std::min<size_t>(3, samples.size());
1075+
stats->best3_average_seconds =
1076+
std::accumulate(samples.begin(), samples.begin() + best_count, 0.0) / best_count;
10731077
stats->min_seconds = samples.front();
10741078
stats->max_seconds = samples.back();
10751079
return true;
@@ -1194,6 +1198,7 @@ void print_stats(const char* label, const BenchmarkCase& bench, const Stats& sta
11941198
std::cout << label
11951199
<< " avg_ms=" << std::fixed << std::setprecision(3) << stats.average_seconds * 1e3
11961200
<< " median_ms=" << stats.median_seconds * 1e3
1201+
<< " best3_avg_ms=" << stats.best3_average_seconds * 1e3
11971202
<< " min_ms=" << stats.min_seconds * 1e3
11981203
<< " max_ms=" << stats.max_seconds * 1e3
11991204
<< " avg_gflops=" << flops / stats.average_seconds / 1e9
@@ -1648,14 +1653,18 @@ int main(int argc, char** argv)
16481653
std::cout << "speedup";
16491654
if (benchmark_raw)
16501655
std::cout << " raw_kernel_avg=" << baseline_stats.average_seconds / raw_stats.average_seconds
1651-
<< " raw_kernel_median=" << baseline_stats.median_seconds / raw_stats.median_seconds;
1656+
<< " raw_kernel_median=" << baseline_stats.median_seconds / raw_stats.median_seconds
1657+
<< " raw_kernel_best3=" << baseline_stats.best3_average_seconds / raw_stats.best3_average_seconds;
16521658
std::cout << " kernel_avg=" << baseline_stats.average_seconds / dynamic_stats.average_seconds
1653-
<< " kernel_median=" << baseline_stats.median_seconds / dynamic_stats.median_seconds;
1659+
<< " kernel_median=" << baseline_stats.median_seconds / dynamic_stats.median_seconds
1660+
<< " kernel_best3=" << baseline_stats.best3_average_seconds / dynamic_stats.best3_average_seconds;
16541661
if (benchmark_splitk)
16551662
std::cout << " splitk_kernel_avg=" << baseline_stats.average_seconds / splitk_stats.average_seconds
1656-
<< " splitk_kernel_median=" << baseline_stats.median_seconds / splitk_stats.median_seconds;
1663+
<< " splitk_kernel_median=" << baseline_stats.median_seconds / splitk_stats.median_seconds
1664+
<< " splitk_kernel_best3=" << baseline_stats.best3_average_seconds / splitk_stats.best3_average_seconds;
16571665
std::cout << " end_to_end_avg=" << baseline_stats.average_seconds / combined_stats.average_seconds
16581666
<< " end_to_end_median=" << baseline_stats.median_seconds / combined_stats.median_seconds
1667+
<< " end_to_end_best3=" << baseline_stats.best3_average_seconds / combined_stats.best3_average_seconds
16591668
<< '\n';
16601669
std::cout.flush();
16611670
std::_Exit(0);

test/int/nnc/mpsblas.tests.c

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3516,8 +3516,6 @@ TEST_CASE("scaled dot product attention gradient with quantized NA mps")
35163516
{
35173517
GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_MPS) &&
35183518
ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD, CCV_NNC_BACKEND_MPS));
3519-
ccv_nnc_mps_set_binary_artifacts(0, 0, 0);
3520-
ccv_nnc_mps_clear_graph_executable_cache();
35213519
const int B = 2;
35223520
const int R = 128;
35233521
const int C = 128;
@@ -3732,8 +3730,6 @@ TEST_CASE("scaled dot product attention gradient with quantized NA mps for recta
37323730
{
37333731
GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_MPS) &&
37343732
ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD, CCV_NNC_BACKEND_MPS));
3735-
ccv_nnc_mps_set_binary_artifacts(0, 0, 0);
3736-
ccv_nnc_mps_clear_graph_executable_cache();
37373733
typedef struct {
37383734
int R;
37393735
int C;

0 commit comments

Comments
 (0)