Skip to content

Commit 365cac4

Browse files
yuvaltassacopybara-github
authored andcommitted
Add benchmarks for mjSORT macro.
PiperOrigin-RevId: 781555531 Change-Id: Ia33b5512fbf65b08931037859d8c6062b018d760
1 parent 0b11e3e commit 365cac4

File tree

1 file changed

+155
-0
lines changed

1 file changed

+155
-0
lines changed
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// Copyright 2025 DeepMind Technologies Limited
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <algorithm>
16+
#include <random>
17+
#include <vector>
18+
19+
#include "benchmark/benchmark.h"
20+
#include <absl/base/attributes.h>
21+
#include <mujoco/mujoco.h>
22+
#include "src/engine/engine_sort.h"
23+
24+
namespace mujoco {
25+
namespace {
26+
27+
// A struct with data to be sorted.
28+
struct Sortable {
29+
mjtNum value;
30+
int id1;
31+
int id2;
32+
};
33+
34+
// Comparison function for Sortable struct.
35+
int CompareSortable(const Sortable* a, const Sortable* b, void* context) {
36+
if (a->value < b->value) {
37+
return -1;
38+
} else if (a->value == b->value) {
39+
return 0;
40+
} else {
41+
return 1;
42+
}
43+
}
44+
45+
// Instantiate the sort function.
46+
mjSORT(SortSortable, Sortable, CompareSortable)
47+
48+
// Generate data for sorting benchmarks.
49+
std::vector<Sortable> GenerateData(int n, double unsorted_fraction) {
50+
std::vector<Sortable> data(n);
51+
for (int i = 0; i < n; ++i) {
52+
data[i] = {static_cast<mjtNum>(i), i, -i};
53+
}
54+
55+
if (unsorted_fraction > 0.0) {
56+
std::mt19937 g(12345);
57+
if (unsorted_fraction >= 1.0) {
58+
std::shuffle(data.begin(), data.end(), g);
59+
} else {
60+
int num_to_shuffle = n * unsorted_fraction;
61+
for (int i = 0; i < num_to_shuffle; ++i) {
62+
int j = std::uniform_int_distribution<int>(i, n - 1)(g);
63+
std::swap(data[i], data[j]);
64+
}
65+
}
66+
}
67+
return data;
68+
}
69+
70+
// Run a sorting benchmark with the given size and unsorted fraction.
71+
ABSL_ATTRIBUTE_NO_TAIL_CALL static void SortBenchmark(
72+
benchmark::State& state, int n, double unsorted_fraction) {
73+
auto data = GenerateData(n, unsorted_fraction);
74+
std::vector<Sortable> buf(n);
75+
std::vector<Sortable> copy = data;
76+
77+
for (auto s : state) {
78+
state.PauseTiming();
79+
copy = data;
80+
state.ResumeTiming();
81+
SortSortable(copy.data(), buf.data(), n, nullptr);
82+
}
83+
84+
state.counters["items/s"] =
85+
benchmark::Counter(state.iterations() * n, benchmark::Counter::kIsRate);
86+
}
87+
88+
// Define benchmarks.
89+
void BM_Sort_1k_AlmostSorted(benchmark::State& state) {
90+
SortBenchmark(state, 1000, 0.05);
91+
}
92+
BENCHMARK(BM_Sort_1k_AlmostSorted);
93+
94+
void BM_Sort_1k_Random(benchmark::State& state) {
95+
SortBenchmark(state, 1000, 1.0);
96+
}
97+
BENCHMARK(BM_Sort_1k_Random);
98+
99+
void BM_Sort_100k_AlmostSorted(benchmark::State& state) {
100+
SortBenchmark(state, 100000, 0.05);
101+
}
102+
BENCHMARK(BM_Sort_100k_AlmostSorted);
103+
104+
void BM_Sort_100k_Random(benchmark::State& state) {
105+
SortBenchmark(state, 100000, 1.0);
106+
}
107+
BENCHMARK(BM_Sort_100k_Random);
108+
109+
// Run a sorting benchmark with std::stable_sort
110+
ABSL_ATTRIBUTE_NO_TAIL_CALL static void StdSortBenchmark(
111+
benchmark::State& state, int n, double unsorted_fraction) {
112+
auto data = GenerateData(n, unsorted_fraction);
113+
std::vector<Sortable> copy = data;
114+
115+
for (auto s : state) {
116+
state.PauseTiming();
117+
copy = data;
118+
state.ResumeTiming();
119+
std::stable_sort(
120+
copy.begin(), copy.end(),
121+
[](const Sortable& a, const Sortable& b) { return a.value < b.value; });
122+
}
123+
124+
state.counters["items/s"] =
125+
benchmark::Counter(state.iterations() * n, benchmark::Counter::kIsRate);
126+
}
127+
128+
void BM_StdSort_1k_AlmostSorted(benchmark::State& state) {
129+
StdSortBenchmark(state, 1000, 0.05);
130+
}
131+
BENCHMARK(BM_StdSort_1k_AlmostSorted);
132+
133+
void BM_StdSort_1k_Random(benchmark::State& state) {
134+
StdSortBenchmark(state, 1000, 1.0);
135+
}
136+
BENCHMARK(BM_StdSort_1k_Random);
137+
138+
void BM_StdSort_100k_AlmostSorted(benchmark::State& state) {
139+
StdSortBenchmark(state, 100000, 0.05);
140+
}
141+
BENCHMARK(BM_StdSort_100k_AlmostSorted);
142+
143+
void BM_StdSort_100k_Random(benchmark::State& state) {
144+
StdSortBenchmark(state, 100000, 1.0);
145+
}
146+
BENCHMARK(BM_StdSort_100k_Random);
147+
148+
} // namespace
149+
} // namespace mujoco
150+
151+
int main(int argc, char** argv) {
152+
benchmark::Initialize(&argc, argv);
153+
benchmark::RunSpecifiedBenchmarks();
154+
return 0;
155+
}

0 commit comments

Comments
 (0)