Skip to content

Commit 923f75f

Browse files
yuvaltassa authored and copybara-github committed
Use only lower triangle in CSR back-substitution.
PiperOrigin-RevId: 711686157
Change-Id: I4fc98cdfb927e5608ce3a99ea34890ce556fcfd7
1 parent 6c880eb commit 923f75f

File tree

2 files changed

+10
-67
lines changed

2 files changed

+10
-67
lines changed

src/engine/engine_core_smooth.c

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1578,16 +1578,19 @@ void mj_solveLDs(mjtNum* restrict x, const mjtNum* qLDs, const mjtNum* qLDiagInv
15781578
const int* rownnz, const int* rowadr, const int* diagind, const int* diagnum,
15791579
const int* colind) {
15801580
// x <- L^-T x
1581-
for (int i=nv-2; i >= 0; i--) {
1582-
// skip diagonal (simple) rows
1583-
if (diagnum[i]) {
1581+
for (int i=nv-1; i > 0; i--) {
1582+
// skip diagonal (simple) rows, exploit sparsity of input vector
1583+
if (diagnum[i] || x[i] == 0) {
15841584
continue;
15851585
}
15861586

1587-
int d1 = diagind[i] + 1;
1588-
int nnz = rownnz[i] - d1;
1589-
int adr = rowadr[i] + d1;
1590-
x[i] -= mju_dotSparse(qLDs+adr, x, nnz, colind+adr, /*flg_unc1=*/0);
1587+
int d = diagind[i];
1588+
int adr_i = rowadr[i];
1589+
mjtNum x_i = x[i];
1590+
for (int j=0; j < d; j++) {
1591+
int adr = adr_i + j;
1592+
x[colind[adr]] -= qLDs[adr] * x_i;
1593+
}
15911594
}
15921595

15931596
// x(i) /= D(i,i)

test/engine/engine_core_smooth_test.cc

Lines changed: 0 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -456,53 +456,6 @@ TEST_F(CoreSmoothTest, FactorI) {
456456
mj_deleteModel(model);
457457
}
458458

459-
// in-place sparse backsubstitution: x = inv(L'*D*L)*x
460-
// like mj_solveLD, but using the CSR representation of L
461-
// variant that only uses the lower triangle of qLDs
462-
static void mj_solveLDsLower(mjtNum* x, const mjtNum* qLDs,
463-
const mjtNum* qLDiagInv, int nv, const int* rownnz,
464-
const int* rowadr, const int* diagind,
465-
const int* diagnum, const int* colind,
466-
int* scratch) {
467-
int* marker = scratch;
468-
for (int i=1; i < nv; i++) {
469-
marker[i] = rowadr[i] + diagind[i] - 1;
470-
}
471-
472-
// x <- L^-T x
473-
for (int i=nv-2; i >= 0; i--) {
474-
// skip diagonal (simple) rows
475-
if (diagnum[i]) {
476-
continue;
477-
}
478-
479-
for (int j=i+1; j < nv; j++) {
480-
if (colind[marker[j]] == i) {
481-
x[i] -= qLDs[marker[j]--] * x[j];
482-
}
483-
}
484-
}
485-
486-
// x(i) /= D(i,i)
487-
for (int i=0; i < nv; i++) {
488-
x[i] *= qLDiagInv[i];
489-
}
490-
491-
// x <- L^-1 x
492-
for (int i=1; i < nv; i++) {
493-
// skip diagonal (simple) rows
494-
if (diagnum[i]) {
495-
i += diagnum[i] - 1; // when iterating forward we can skip ahead
496-
continue;
497-
}
498-
499-
int d = diagind[i];
500-
int adr = rowadr[i];
501-
x[i] -= mju_dotSparse(qLDs+adr, x, d, colind+adr, /*flg_unc1=*/0);
502-
}
503-
}
504-
505-
506459
TEST_F(CoreSmoothTest, SolveLDs) {
507460
const std::string xml_path = GetTestDataFilePath(kInertiaPath);
508461
char error[1024];
@@ -537,7 +490,6 @@ TEST_F(CoreSmoothTest, SolveLDs) {
537490
for (int i=0; i < nv; i++) vec[i] = vec2[i] = 20 + 30*i;
538491
for (int i=0; i < nv; i+=2) vec[i] = vec2[i] = 0;
539492

540-
// use upper triangle
541493
mj_solveLD(m, vec.data(), 1, d->qLD, d->qLDiagInv);
542494
mj_solveLDs(vec2.data(), LDs.data(), d->qLDiagInv, nv,
543495
d->C_rownnz, d->C_rowadr, d->C_diag, m->dof_simplenum,
@@ -548,18 +500,6 @@ TEST_F(CoreSmoothTest, SolveLDs) {
548500
EXPECT_FLOAT_EQ(vec[i], vec2[i]);
549501
}
550502

551-
// don't use use upper triangle
552-
mj_solveLD(m, vec.data(), 1, d->qLD, d->qLDiagInv);
553-
vector<int> scratch(nv);
554-
mj_solveLDsLower(vec2.data(), LDs.data(), d->qLDiagInv, nv, d->C_rownnz,
555-
d->C_rowadr, d->C_diag, m->dof_simplenum, d->C_colind,
556-
scratch.data());
557-
558-
// expect vectors to match up to floating point precision
559-
for (int i=0; i < nv; i++) {
560-
EXPECT_FLOAT_EQ(vec[i], vec2[i]);
561-
}
562-
563503
mj_deleteData(d);
564504
mj_deleteModel(m);
565505
}

0 commit comments

Comments
 (0)