Skip to content

Commit 58c9521

Browse files
yuvaltassacopybara-github
authored andcommitted
Modify mjSORT to remove some unnecessary copies.
PiperOrigin-RevId: 781573097 Change-Id: I60148bc88fb64ea6dfee9390efb3fd161ad3f1d8
1 parent 365cac4 commit 58c9521

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

src/engine/engine_sort.h

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,18 @@
3232
}
3333

3434
// sub-macro that merges two sub-sorted arrays [start, ..., mid), [mid, ..., end) together
35-
#define _mjMERGE(type, arr, buf, start, mid, end, cmp, context) \
35+
#define _mjMERGE(type, src, dest, start, mid, end, cmp, context) \
3636
{ \
37-
int len1 = mid - start, len2 = end - mid; \
38-
type* left = buf, *right = buf + len1; \
39-
for (int i = 0; i < len1; i++) left[i] = arr[start + i]; \
40-
for (int i = 0; i < len2; i++) right[i] = arr[mid + i]; \
41-
int i = 0, j = 0, k = start; \
42-
while (i < len1 && j < len2) { \
43-
if (cmp(left + i, right + j, context) <= 0) { \
44-
arr[k++] = left[i++]; \
37+
int i = start, j = mid, k = start; \
38+
while (i < mid && j < end) { \
39+
if (cmp(src + i, src + j, context) <= 0) { \
40+
dest[k++] = src[i++]; \
4541
} else { \
46-
arr[k++] = right[j++]; \
42+
dest[k++] = src[j++]; \
4743
} \
4844
} \
49-
while (i < len1) arr[k++] = left[i++]; \
50-
while (j < len2) arr[k++] = right[j++]; \
45+
if (i < mid) memcpy(dest + k, src + i, (mid - i) * sizeof(type)); \
46+
else if (j < end) memcpy(dest + k, src + j, (end - j) * sizeof(type)); \
5147
}
5248

5349
// defines an inline stable sorting function via tiled merge sorting (timsort)
@@ -60,15 +56,20 @@
6056
int end = (start + _mjRUNSIZE < n) ? start + _mjRUNSIZE : n; \
6157
_mjINSERTION_SORT(type, arr, start, end, cmp, context); \
6258
} \
59+
type* src = arr, *dest = buf, *tmp; \
6360
for (int len = _mjRUNSIZE; len < n; len *= 2) { \
6461
for (int start = 0; start < n; start += 2*len) { \
6562
int mid = start + len; \
6663
int end = (start + 2*len < n) ? start + 2*len : n; \
6764
if (mid < end) { \
68-
_mjMERGE(type, arr, buf, start, mid, end, cmp, context); \
65+
_mjMERGE(type, src, dest, start, mid, end, cmp, context); \
66+
} else { \
67+
memcpy(dest + start, src + start, (end - start) * sizeof(type)); \
6968
} \
7069
} \
70+
tmp = src; src = dest; dest = tmp; \
7171
} \
72+
if (src != arr) memcpy(arr, src, n * sizeof(type)); \
7273
}
7374

7475
#endif // MUJOCO_SRC_ENGINE_ENGINE_SORT_H_

0 commit comments

Comments
 (0)