Skip to content

Commit ac11e5f

Browse files
yuvaltassacopybara-github
authored andcommitted
Add CSR implementation of mj_factorI
PiperOrigin-RevId: 712498431 Change-Id: I13b52e53482ed97da8788875d4d95e2beb5ca7c1
1 parent 7eb8231 commit ac11e5f

File tree

9 files changed

+514
-14
lines changed

9 files changed

+514
-14
lines changed

src/engine/engine_core_smooth.c

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,10 +1440,9 @@ void mj_factorI(const mjModel* m, mjData* d, const mjtNum* M, mjtNum* qLD, mjtNu
14401440
}
14411441
}
14421442

1443-
// compute 1/diag(D), 1/sqrt(diag(D))
1443+
// compute 1/diag(D)
14441444
for (int i=0; i < nv; i++) {
1445-
mjtNum qLDi = qLD[dof_Madr[i]];
1446-
qLDiagInv[i] = 1.0/qLDi;
1445+
qLDiagInv[i] = 1.0 / qLD[dof_Madr[i]];
14471446
}
14481447
}
14491448

@@ -1458,6 +1457,40 @@ void mj_factorM(const mjModel* m, mjData* d) {
14581457

14591458

14601459

1460+
// sparse L'*D*L factorization of inertia-like matrix M, assumed spd
1461+
// like mj_factorI, but using CSR representation
//   mat:     CSR values of the matrix; factorized in place (overwritten)
//   diaginv: optional output (may be NULL); diaginv[k] receives 1/mat[diagonal of row k]
//   nv:      number of rows
//   rownnz:  number of stored entries in each row
//   rowadr:  address in mat/colind of the first entry of each row
//   diagnum: nonzero marks a "simple" row, whose off-diagonal entries are not processed
//   colind:  column index of each stored entry
// the diagonal of row k is taken to be the last stored entry of that row
1462+
void mj_factorIs(mjtNum* mat, mjtNum* diaginv, int nv,
1463+
const int* rownnz, const int* rowadr, const int* diagnum, const int* colind) {
1464+
// backward loop over rows
1465+
for (int k=nv-1; k >= 0; k--) {
1466+
// get row k's address, diagonal index, inverse diagonal value
// (diagonal index = last entry of the row: rowadr + rownnz - 1)
1467+
int rowadr_k = rowadr[k];
1468+
int diag_k = rowadr_k + rownnz[k] - 1;
1469+
mjtNum invD = 1 / mat[diag_k];
1470+
if (diaginv) diaginv[k] = invD;
1471+
1472+
// skip if simple
// (diaginv[k] has already been written above, so simple rows still get their inverse diagonal)
1473+
if (diagnum[k]) {
1474+
continue;
1475+
}
1476+
1477+
// update triangle above row k, inclusive
// (walk the off-diagonal entries of row k from right to left)
1478+
for (int adr=diag_k - 1; adr >= rowadr_k; adr--) {
1479+
// tmp = L(k, i) / L(k, k)
1480+
mjtNum tmp = mat[adr] * invD;
1481+
1482+
// update row i < k: L(i, 0..i) -= L(k, 0..i) * L(k, i) / L(k, k)
// (source is row k, destination is row i; presumably mju_addToScl computes
//  res += scl*vec over n entries -- confirm against its definition)
// NOTE(review): the same offsets 0..rownnz[i]-1 are used in both rows, so this
//  assumes the leading rownnz[i] columns of row k coincide with the columns of row i
1483+
int i = colind[adr];
1484+
mju_addToScl(mat + rowadr[i], mat + rowadr_k, -tmp, rownnz[i]);
1485+
1486+
// update ith element of row k: L(k, i) /= L(k, k)
// (stored only after the row-i update, which therefore reads row k before this entry is rescaled)
1487+
mat[adr] = tmp;
1488+
}
1489+
}
1490+
}
1491+
1492+
1493+
14611494
// in-place sparse backsubstitution: x = inv(L'*D*L)*x
14621495
// L is in lower triangle of qLD; D is on diagonal of qLD
14631496
// handle n vectors at once
@@ -1575,16 +1608,15 @@ void mj_solveLD(const mjModel* m, mjtNum* restrict x, int n,
15751608
// in-place sparse backsubstitution: x = inv(L'*D*L)*x
15761609
// like mj_solveLD, but using the CSR representation of L
15771610
void mj_solveLDs(mjtNum* restrict x, const mjtNum* qLDs, const mjtNum* qLDiagInv, int nv,
1578-
const int* rownnz, const int* rowadr, const int* diagind, const int* diagnum,
1579-
const int* colind) {
1611+
const int* rownnz, const int* rowadr, const int* diagnum, const int* colind) {
15801612
// x <- L^-T x
15811613
for (int i=nv-1; i > 0; i--) {
15821614
// skip diagonal (simple) rows, exploit sparsity of input vector
15831615
if (diagnum[i] || x[i] == 0) {
15841616
continue;
15851617
}
15861618

1587-
int d = diagind[i];
1619+
int d = rownnz[i] - 1;
15881620
int adr_i = rowadr[i];
15891621
mjtNum x_i = x[i];
15901622
for (int j=0; j < d; j++) {
@@ -1607,7 +1639,7 @@ void mj_solveLDs(mjtNum* restrict x, const mjtNum* qLDs, const mjtNum* qLDiagInv
16071639
}
16081640

16091641
int adr = rowadr[i];
1610-
x[i] -= mju_dotSparse(qLDs+adr, x, diagind[i], colind+adr, /*flg_unc1=*/0);
1642+
x[i] -= mju_dotSparse(qLDs+adr, x, rownnz[i] - 1, colind+adr, /*flg_unc1=*/0);
16111643
}
16121644
}
16131645

src/engine/engine_core_smooth.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ MJAPI void mj_crb(const mjModel* m, mjData* d);
5151
// sparse L'*D*L factorization of inertia-like matrix M, assumed spd
5252
MJAPI void mj_factorI(const mjModel* m, mjData* d, const mjtNum* M, mjtNum* qLD, mjtNum* qLDiagInv);
5353

54+
// sparse L'*D*L factorization of inertia-like matrix
55+
// like mj_factorI, but using CSR representation
56+
MJAPI void mj_factorIs(mjtNum* mat, mjtNum* diaginv, int nv,
57+
const int* rownnz, const int* rowadr, const int* diagnum, const int* colind);
58+
5459
// sparse L'*D*L factorization of the inertia matrix M, assumed spd
5560
MJAPI void mj_factorM(const mjModel* m, mjData* d);
5661

@@ -61,8 +66,7 @@ MJAPI void mj_solveLD(const mjModel* m, mjtNum* x, int n,
6166
// in-place sparse backsubstitution: x = inv(L'*D*L)*x
6267
// like mj_solveLD, but using the CSR representation of L
6368
MJAPI void mj_solveLDs(mjtNum* x, const mjtNum* qLDs, const mjtNum* qLDiagInv, int nv,
64-
const int* rownnz, const int* rowadr, const int* diagind, const int* diagnum,
65-
const int* colind);
69+
const int* rownnz, const int* rowadr, const int* diagnum, const int* colind);
6670

6771
// sparse backsubstitution: x = inv(L'*D*L)*y, use factorization in d
6872
MJAPI void mj_solveM(const mjModel* m, mjData* d, mjtNum* x, const mjtNum* y, int n);

test/benchmark/CMakeLists.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,18 @@ mujoco_test(
4343
ADDITIONAL_LINK_LIBRARIES benchmark::benchmark absl::core_headers
4444
)
4545

46+
mujoco_test(
47+
factorI_benchmark_test
48+
MAIN_TARGET benchmark::benchmark_main
49+
ADDITIONAL_LINK_LIBRARIES benchmark::benchmark absl::core_headers
50+
)
51+
52+
mujoco_test(
53+
inertia_benchmark_test
54+
MAIN_TARGET benchmark::benchmark_main
55+
ADDITIONAL_LINK_LIBRARIES benchmark::benchmark absl::core_headers
56+
)
57+
4658
mujoco_test(
4759
solveLD_benchmark_test
4860
MAIN_TARGET benchmark::benchmark_main
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Copyright 2025 DeepMind Technologies Limited
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// A benchmark for comparing different implementations of mj_factorI.
16+
17+
#include <benchmark/benchmark.h>
18+
#include <absl/base/attributes.h>
19+
#include <mujoco/mjdata.h>
20+
#include <mujoco/mujoco.h>
21+
#include "src/engine/engine_core_smooth.h"
22+
#include "test/fixture.h"
23+
24+
namespace mujoco {
25+
namespace {
26+
27+
// number of steps to benchmark
28+
static const int kNumBenchmarkSteps = 50;
29+
30+
// ----------------------------- benchmark ------------------------------------
31+
32+
// shared benchmark body for inertia factorization
//   legacy: true  -> time mj_factorI on d->qM / d->qLD
//           false -> time mj_factorIs on a CSR copy of the inertia matrix
//   coil:   true  -> load plugin/elasticity/coil.xml
//           false -> load humanoid/humanoid100.xml
static void BM_factorI(benchmark::State& state, bool legacy, bool coil) {
33+
// NOTE(review): m has static storage but is reassigned on every call and
//  deleted at the end of this function; static does not appear to be required
static mjModel* m;
34+
if (coil) {
35+
m = LoadModelFromPath("plugin/elasticity/coil.xml");
36+
} else {
37+
m = LoadModelFromPath("humanoid/humanoid100.xml");
38+
}
39+
40+
// NOTE(review): the result of LoadModelFromPath is used without a null check
mjData* d = mj_makeData(m);
41+
mj_forward(m, d);
42+
43+
// allocate inputs and outputs
44+
mj_markStack(d);
45+
46+
// CSR matrices
// Ms holds the inertia values gathered from d->qM through d->mapM2C,
// LDs is the scratch buffer that mj_factorIs overwrites
47+
mjtNum* Ms = mj_stackAllocNum(d, m->nC);
48+
mjtNum* LDs = mj_stackAllocNum(d, m->nC);
49+
for (int i=0; i < m->nC; i++) {
50+
Ms[i] = d->qM[d->mapM2C[i]];
51+
}
52+
53+
// benchmark
// each pass of the outer loop performs kNumBenchmarkSteps factorizations
54+
while (state.KeepRunningBatch(kNumBenchmarkSteps)) {
55+
for (int i=0; i < kNumBenchmarkSteps; i++) {
56+
if (legacy) {
57+
mj_factorI(m, d, d->qM, d->qLD, d->qLDiagInv);
58+
} else {
// mj_factorIs works in place, so LDs is refreshed from Ms before every call;
// this copy is inside the timed loop and counts toward the CSR timing
59+
mju_copy(LDs, Ms, m->nC);
60+
mj_factorIs(LDs, d->qLDiagInv, m->nv,
61+
d->C_rownnz, d->C_rowadr, m->dof_simplenum, d->C_colind);
62+
}
63+
}
64+
}
65+
66+
// finalize
// release the stack allocations, then the data and model created above
67+
mj_freeStack(d);
68+
mj_deleteData(d);
69+
mj_deleteModel(m);
70+
state.SetItemsProcessed(state.iterations());
71+
}
72+
73+
// registered benchmark: coil model, legacy mj_factorI path
void ABSL_ATTRIBUTE_NO_TAIL_CALL
74+
BM_factorI_COIL_LEGACY(benchmark::State& state) {
75+
MujocoErrorTestGuard guard;
76+
BM_factorI(state, /*legacy=*/true, /*coil=*/true);
77+
}
78+
BENCHMARK(BM_factorI_COIL_LEGACY);
79+
80+
// registered benchmark: coil model, CSR mj_factorIs path
void ABSL_ATTRIBUTE_NO_TAIL_CALL
81+
BM_factorI_COIL_CSR(benchmark::State& state) {
82+
MujocoErrorTestGuard guard;
83+
BM_factorI(state, /*legacy=*/false, /*coil=*/true);
84+
}
85+
BENCHMARK(BM_factorI_COIL_CSR);
86+
87+
// registered benchmark: humanoid100 model, legacy mj_factorI path
void ABSL_ATTRIBUTE_NO_TAIL_CALL
88+
BM_factorI_H100_LEGACY(benchmark::State& state) {
89+
MujocoErrorTestGuard guard;
90+
BM_factorI(state, /*legacy=*/true, /*coil=*/false);
91+
}
92+
BENCHMARK(BM_factorI_H100_LEGACY);
93+
94+
// registered benchmark: humanoid100 model, CSR mj_factorIs path
void ABSL_ATTRIBUTE_NO_TAIL_CALL
95+
BM_factorI_H100_CSR(benchmark::State& state) {
96+
MujocoErrorTestGuard guard;
97+
BM_factorI(state, /*legacy=*/false, /*coil=*/false);
98+
}
99+
BENCHMARK(BM_factorI_H100_CSR);
100+
101+
} // namespace
102+
} // namespace mujoco
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Copyright 2025 DeepMind Technologies Limited
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// A benchmark for comparing the legacy and CSR implementations of inertia
16+
// factorization followed by a solve.
17+
18+
#include <benchmark/benchmark.h>
19+
#include <absl/base/attributes.h>
20+
#include <mujoco/mjdata.h>
21+
#include <mujoco/mujoco.h>
22+
#include "src/engine/engine_core_smooth.h"
23+
#include "test/fixture.h"
24+
25+
namespace mujoco {
26+
namespace {
27+
28+
// number of steps to benchmark
29+
static const int kNumBenchmarkSteps = 50;
30+
31+
// ----------------------------- benchmark ------------------------------------
32+
33+
// selects which factor-and-solve implementation BM_solve exercises
enum class SolveType {
34+
// mj_factorI followed by mj_solveM
kLegacy = 0,
35+
// mj_factorIs followed by mj_solveLDs
kCsr,
36+
};
37+
38+
// shared benchmark body: factorize the inertia matrix, then solve for one vector
//   type: kLegacy -> mj_factorI + mj_solveM
//         kCsr    -> mj_factorIs + mj_solveLDs on a CSR copy of the inertia matrix
static void BM_solve(benchmark::State& state, SolveType type) {
39+
// NOTE(review): m has static storage but is reassigned on every call and
//  deleted at the end of this function; static does not appear to be required
static mjModel* m;
40+
m = LoadModelFromPath("../test/benchmark/testdata/inertia.xml");
41+
42+
// NOTE(review): the result of LoadModelFromPath is used without a null check
mjData* d = mj_makeData(m);
43+
mj_forward(m, d);
44+
45+
// allocate input and output vectors
46+
mj_markStack(d);
47+
48+
// make CSR matrix
// Ms holds the inertia values gathered from d->qM through d->mapM2C,
// LDs is the scratch buffer that mj_factorIs overwrites
49+
mjtNum* Ms = mj_stackAllocNum(d, m->nC);
50+
mjtNum* LDs = mj_stackAllocNum(d, m->nC);
51+
for (int i=0; i < m->nC; i++) {
52+
Ms[i] = d->qM[d->mapM2C[i]];
53+
}
54+
55+
// arbitrary input vector
// vec is the fixed right-hand side, res receives the solution
56+
mjtNum *res = mj_stackAllocNum(d, m->nv);
57+
mjtNum *vec = mj_stackAllocNum(d, m->nv);
58+
for (int i=0; i < m->nv; i++) {
59+
vec[i] = 0.2 + 0.3*i;
60+
}
61+
62+
// benchmark
// each pass of the outer loop performs kNumBenchmarkSteps factor+solve rounds
63+
while (state.KeepRunningBatch(kNumBenchmarkSteps)) {
64+
for (int i=0; i < kNumBenchmarkSteps; i++) {
65+
switch (type) {
66+
case SolveType::kLegacy:
67+
mj_factorI(m, d, d->qM, d->qLD, d->qLDiagInv);
68+
mj_solveM(m, d, res, vec, 1);
69+
break;
70+
case SolveType::kCsr:
// both CSR routines work in place, so LDs and res are refreshed from
// Ms and vec on every iteration; these copies are part of the timed work
71+
mju_copy(LDs, Ms, m->nC);
72+
mj_factorIs(LDs, d->qLDiagInv, m->nv,
73+
d->C_rownnz, d->C_rowadr, m->dof_simplenum, d->C_colind);
74+
mju_copy(res, vec, m->nv);
75+
mj_solveLDs(res, LDs, d->qLDiagInv, m->nv,
76+
d->C_rownnz, d->C_rowadr, m->dof_simplenum, d->C_colind);
// last case: no break, control leaves the switch here
77+
}
78+
}
79+
}
80+
81+
// finalize
// release the stack allocations, then the data and model created above
82+
mj_freeStack(d);
83+
mj_deleteData(d);
84+
mj_deleteModel(m);
85+
state.SetItemsProcessed(state.iterations());
86+
}
87+
88+
// registered benchmark: legacy factor-and-solve path
void ABSL_ATTRIBUTE_NO_TAIL_CALL BM_solve_LEGACY(benchmark::State& state) {
89+
MujocoErrorTestGuard guard;
90+
BM_solve(state, SolveType::kLegacy);
91+
}
92+
BENCHMARK(BM_solve_LEGACY);
93+
94+
// registered benchmark: CSR factor-and-solve path
void ABSL_ATTRIBUTE_NO_TAIL_CALL BM_solve_CSR(benchmark::State& state) {
95+
MujocoErrorTestGuard guard;
96+
BM_solve(state, SolveType::kCsr);
97+
}
98+
BENCHMARK(BM_solve_CSR);
99+
100+
} // namespace
101+
} // namespace mujoco

test/benchmark/solveLD_benchmark_test.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ static void BM_solveLD(benchmark::State& state, bool featherstone, bool coil) {
6464
} else {
6565
mju_copy(res, vec, m->nv);
6666
mj_solveLDs(res, LDs, d->qLDiagInv, m->nv,
67-
d->C_rownnz, d->C_rowadr, d->C_diag, m->dof_simplenum,
68-
d->C_colind);
67+
d->C_rownnz, d->C_rowadr, m->dof_simplenum, d->C_colind);
6968
}
7069
}
7170
}

0 commit comments

Comments
 (0)