Skip to content

Commit 4510c6d

Browse files
yuvaltassacopybara-github
authored andcommitted
Further speedup to CSR back-substitution using dof_simplenum.
PiperOrigin-RevId: 711440268 Change-Id: I81cd9a6a8b8ec78d08cfbc34f60a9216ea753833
1 parent 69c9ac0 commit 4510c6d

File tree

5 files changed

+85
-15
lines changed

5 files changed

+85
-15
lines changed

src/engine/engine_core_smooth.c

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,15 +1575,19 @@ void mj_solveLD(const mjModel* m, mjtNum* restrict x, int n,
15751575
// in-place sparse backsubstitution: x = inv(L'*D*L)*x
15761576
// like mj_solveLD, but using the CSR representation of L
15771577
void mj_solveLDs(mjtNum* restrict x, const mjtNum* qLDs, const mjtNum* qLDiagInv, int nv,
1578-
const int* rownnz, const int* rowadr, const int* diag, const int* colind) {
1578+
const int* rownnz, const int* rowadr, const int* diagind, const int* diagnum,
1579+
const int* colind) {
15791580
// x <- L^-T x
15801581
for (int i=nv-2; i >= 0; i--) {
1581-
int d1 = diag[i] + 1;
1582-
int nnz = rownnz[i] - d1;
1583-
if (nnz > 0) {
1584-
int adr = rowadr[i] + d1;
1585-
x[i] -= mju_dotSparse(qLDs+adr, x, nnz, colind+adr, /*flg_unc1=*/0);
1582+
// skip diagonal (simple) rows
1583+
if (diagnum[i]) {
1584+
continue;
15861585
}
1586+
1587+
int d1 = diagind[i] + 1;
1588+
int nnz = rownnz[i] - d1;
1589+
int adr = rowadr[i] + d1;
1590+
x[i] -= mju_dotSparse(qLDs+adr, x, nnz, colind+adr, /*flg_unc1=*/0);
15871591
}
15881592

15891593
// x(i) /= D(i,i)
@@ -1593,11 +1597,14 @@ void mj_solveLDs(mjtNum* restrict x, const mjtNum* qLDs, const mjtNum* qLDiagInv
15931597

15941598
// x <- L^-1 x
15951599
for (int i=1; i < nv; i++) {
1596-
int d = diag[i];
1597-
if (d > 0) {
1598-
int adr = rowadr[i];
1599-
x[i] -= mju_dotSparse(qLDs+adr, x, d, colind+adr, /*flg_unc1=*/0);
1600+
// skip diagonal (simple) rows
1601+
if (diagnum[i]) {
1602+
i += diagnum[i] - 1; // when iterating forward we can skip ahead
1603+
continue;
16001604
}
1605+
1606+
int adr = rowadr[i];
1607+
x[i] -= mju_dotSparse(qLDs+adr, x, diagind[i], colind+adr, /*flg_unc1=*/0);
16011608
}
16021609
}
16031610

src/engine/engine_core_smooth.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ MJAPI void mj_solveLD(const mjModel* m, mjtNum* x, int n,
6161
// in-place sparse backsubstitution: x = inv(L'*D*L)*x
6262
// like mj_solveLD, but using the CSR representation of L
6363
MJAPI void mj_solveLDs(mjtNum* x, const mjtNum* qLDs, const mjtNum* qLDiagInv, int nv,
64-
const int* rownnz, const int* rowadr, const int* diag, const int* colind);
64+
const int* rownnz, const int* rowadr, const int* diagind, const int* diagnum,
65+
const int* colind);
6566

6667
// sparse backsubstitution: x = inv(L'*D*L)*y, use factorization in d
6768
MJAPI void mj_solveM(const mjModel* m, mjData* d, mjtNum* x, const mjtNum* y, int n);

test/benchmark/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ mujoco_test(
4444
)
4545

4646
mujoco_test(
47-
engine_core_smooth_benchmark_test
47+
solveLD_benchmark_test
4848
MAIN_TARGET benchmark::benchmark_main
4949
ADDITIONAL_LINK_LIBRARIES benchmark::benchmark absl::core_headers
5050
)

test/benchmark/engine_core_smooth_benchmark_test.cc renamed to test/benchmark/solveLD_benchmark_test.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ static void BM_solveLD(benchmark::State& state, bool featherstone, bool coil) {
6464
} else {
6565
mju_copy(res, vec, m->nv);
6666
mj_solveLDs(res, LDs, d->qLDiagInv, m->nv,
67-
d->C_rownnz, d->C_rowadr, d->C_diag, d->C_colind);
67+
d->C_rownnz, d->C_rowadr, d->C_diag, m->dof_simplenum,
68+
d->C_colind);
6869
}
6970
}
7071
}

test/engine/engine_core_smooth_test.cc

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,54 @@ TEST_F(CoreSmoothTest, FactorI) {
456456
mj_deleteModel(model);
457457
}
458458

459-
TEST_F(CoreSmoothTest, SolveLD2) {
459+
// in-place sparse backsubstitution: x = inv(L'*D*L)*x
460+
// like mj_solveLD, but using the CSR representation of L
461+
// variant that only uses the lower triangle of qLDs
462+
static void mj_solveLDsLower(mjtNum* x, const mjtNum* qLDs,
463+
const mjtNum* qLDiagInv, int nv, const int* rownnz,
464+
const int* rowadr, const int* diagind,
465+
const int* diagnum, const int* colind,
466+
int* scratch) {
467+
int* marker = scratch;
468+
for (int i=1; i < nv; i++) {
469+
marker[i] = rowadr[i] + diagind[i] - 1;
470+
}
471+
472+
// x <- L^-T x
473+
for (int i=nv-2; i >= 0; i--) {
474+
// skip diagonal (simple) rows
475+
if (diagnum[i]) {
476+
continue;
477+
}
478+
479+
for (int j=i+1; j < nv; j++) {
480+
if (colind[marker[j]] == i) {
481+
x[i] -= qLDs[marker[j]--] * x[j];
482+
}
483+
}
484+
}
485+
486+
// x(i) /= D(i,i)
487+
for (int i=0; i < nv; i++) {
488+
x[i] *= qLDiagInv[i];
489+
}
490+
491+
// x <- L^-1 x
492+
for (int i=1; i < nv; i++) {
493+
// skip diagonal (simple) rows
494+
if (diagnum[i]) {
495+
i += diagnum[i] - 1; // when iterating forward we can skip ahead
496+
continue;
497+
}
498+
499+
int d = diagind[i];
500+
int adr = rowadr[i];
501+
x[i] -= mju_dotSparse(qLDs+adr, x, d, colind+adr, /*flg_unc1=*/0);
502+
}
503+
}
504+
505+
506+
TEST_F(CoreSmoothTest, SolveLDs) {
460507
const std::string xml_path = GetTestDataFilePath(kInertiaPath);
461508
char error[1024];
462509
mjModel* m = mj_loadXML(xml_path.c_str(), nullptr, error, sizeof(error));
@@ -490,9 +537,23 @@ TEST_F(CoreSmoothTest, SolveLD2) {
490537
for (int i=0; i < nv; i++) vec[i] = vec2[i] = 20 + 30*i;
491538
for (int i=0; i < nv; i+=2) vec[i] = vec2[i] = 0;
492539

540+
// use upper triangle
493541
mj_solveLD(m, vec.data(), 1, d->qLD, d->qLDiagInv);
494542
mj_solveLDs(vec2.data(), LDs.data(), d->qLDiagInv, nv,
495-
d->C_rownnz, d->C_rowadr, d->C_diag, d->C_colind);
543+
d->C_rownnz, d->C_rowadr, d->C_diag, m->dof_simplenum,
544+
d->C_colind);
545+
546+
// expect vectors to match up to floating point precision
547+
for (int i=0; i < nv; i++) {
548+
EXPECT_FLOAT_EQ(vec[i], vec2[i]);
549+
}
550+
551+
// don't use use upper triangle
552+
mj_solveLD(m, vec.data(), 1, d->qLD, d->qLDiagInv);
553+
vector<int> scratch(nv);
554+
mj_solveLDsLower(vec2.data(), LDs.data(), d->qLDiagInv, nv, d->C_rownnz,
555+
d->C_rowadr, d->C_diag, m->dof_simplenum, d->C_colind,
556+
scratch.data());
496557

497558
// expect vectors to match up to floating point precision
498559
for (int i=0; i < nv; i++) {

0 commit comments

Comments
 (0)