Skip to content

Commit 8a5f092

Browse files
yuvaltassa and copybara-github
authored and committed
Allow CSR back-substitution to handle multiple vectors.
PiperOrigin-RevId: 713229673 Change-Id: I7a5b43fe966cf9e482bd41e2eea6c30dd3ffa1d4
1 parent 4d82ab5 commit 8a5f092

File tree

5 files changed

+116
-29
lines changed

5 files changed

+116
-29
lines changed

src/engine/engine_core_smooth.c

Lines changed: 73 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1607,39 +1607,88 @@ void mj_solveLD(const mjModel* m, mjtNum* restrict x, int n,
16071607

16081608
// in-place sparse backsubstitution: x = inv(L'*D*L)*x
16091609
// like mj_solveLD, but using the CSR representation of L
1610-
void mj_solveLDs(mjtNum* restrict x, const mjtNum* qLDs, const mjtNum* qLDiagInv, int nv,
1610+
void mj_solveLDs(mjtNum* restrict x, const mjtNum* qLDs, const mjtNum* qLDiagInv, int nv, int n,
16111611
const int* rownnz, const int* rowadr, const int* diagnum, const int* colind) {
1612-
// x <- L^-T x
1613-
for (int i=nv-1; i > 0; i--) {
1614-
// skip diagonal (simple) rows, exploit sparsity of input vector
1615-
if (diagnum[i] || x[i] == 0) {
1616-
continue;
1612+
// single vector
1613+
if (n == 1) {
1614+
// x <- L^-T x
1615+
for (int i=nv-1; i > 0; i--) {
1616+
// skip diagonal rows, zero elements in input vector
1617+
mjtNum x_i = x[i];
1618+
if (x_i == 0 || diagnum[i]) {
1619+
continue;
1620+
}
1621+
1622+
int start = rowadr[i];
1623+
int end = start + rownnz[i] - 1;
1624+
for (int adr=start; adr < end; adr++) {
1625+
x[colind[adr]] -= qLDs[adr] * x_i;
1626+
}
16171627
}
16181628

1619-
int d = rownnz[i] - 1;
1620-
int adr_i = rowadr[i];
1621-
mjtNum x_i = x[i];
1622-
for (int j=0; j < d; j++) {
1623-
int adr = adr_i + j;
1624-
x[colind[adr]] -= qLDs[adr] * x_i;
1629+
// x <- D^-1 x
1630+
for (int i=0; i < nv; i++) {
1631+
x[i] *= qLDiagInv[i];
16251632
}
1626-
}
16271633

1628-
// x(i) /= D(i,i)
1629-
for (int i=0; i < nv; i++) {
1630-
x[i] *= qLDiagInv[i];
1634+
// x <- L^-1 x
1635+
for (int i=1; i < nv; i++) {
1636+
// skip diagonal rows
1637+
if (diagnum[i]) {
1638+
i += diagnum[i] - 1; // iterating forward: skip ahead, adjust i
1639+
continue;
1640+
}
1641+
1642+
int adr = rowadr[i];
1643+
x[i] -= mju_dotSparse(qLDs+adr, x, rownnz[i] - 1, colind+adr, /*flg_unc1=*/0);
1644+
}
16311645
}
16321646

1633-
// x <- L^-1 x
1634-
for (int i=1; i < nv; i++) {
1635-
// skip diagonal (simple) rows
1636-
if (diagnum[i]) {
1637-
i += diagnum[i] - 1; // when iterating forward we can skip ahead
1638-
continue;
1647+
// multiple vectors
1648+
else {
1649+
// x <- L^-T x
1650+
for (int i=nv-1; i > 0; i--) {
1651+
// skip diagonal rows
1652+
if (diagnum[i]) {
1653+
continue;
1654+
}
1655+
1656+
int start = rowadr[i];
1657+
int end = start + rownnz[i] - 1;
1658+
for (int adr=start; adr < end; adr++) {
1659+
int j = colind[adr];
1660+
mjtNum val = qLDs[adr];
1661+
for (int offset=0; offset < n*nv; offset+=nv) {
1662+
mjtNum x_i;
1663+
if ((x_i = x[i+offset])) {
1664+
x[j+offset] -= val * x_i;
1665+
}
1666+
}
1667+
}
1668+
}
1669+
1670+
// x <- D^-1 x
1671+
for (int i=0; i < nv; i++) {
1672+
mjtNum invD_i = qLDiagInv[i];
1673+
for (int offset=0; offset < n*nv; offset+=nv) {
1674+
x[i+offset] *= invD_i;
1675+
}
16391676
}
16401677

1641-
int adr = rowadr[i];
1642-
x[i] -= mju_dotSparse(qLDs+adr, x, rownnz[i] - 1, colind+adr, /*flg_unc1=*/0);
1678+
// x <- L^-1 x
1679+
for (int i=1; i < nv; i++) {
1680+
// skip diagonal rows
1681+
if (diagnum[i]) {
1682+
i += diagnum[i] - 1; // iterating forward: skip ahead, adjust i
1683+
continue;
1684+
}
1685+
1686+
int adr = rowadr[i];
1687+
int d = rownnz[i] - 1;
1688+
for (int offset=0; offset < n*nv; offset+=nv) {
1689+
x[i+offset] -= mju_dotSparse(qLDs+adr, x+offset, d, colind+adr, /*flg_unc1=*/0);
1690+
}
1691+
}
16431692
}
16441693
}
16451694

src/engine/engine_core_smooth.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -64,8 +64,8 @@ MJAPI void mj_solveLD(const mjModel* m, mjtNum* x, int n,
6464
const mjtNum* qLD, const mjtNum* qLDiagInv);
6565

6666
// in-place sparse backsubstitution: x = inv(L'*D*L)*x
67-
// like mj_solveLD, but using the CSR representation of L
68-
MJAPI void mj_solveLDs(mjtNum* x, const mjtNum* qLDs, const mjtNum* qLDiagInv, int nv,
67+
// handle n vectors at once
68+
MJAPI void mj_solveLDs(mjtNum* x, const mjtNum* qLDs, const mjtNum* qLDiagInv, int nv, int n,
6969
const int* rownnz, const int* rowadr, const int* diagnum, const int* colind);
7070

7171
// sparse backsubstitution: x = inv(L'*D*L)*y, use factorization in d

test/benchmark/inertia_benchmark_test.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ static void BM_solve(benchmark::State& state, SolveType type) {
7272
mj_factorIs(LDs, d->qLDiagInv, m->nv,
7373
d->C_rownnz, d->C_rowadr, m->dof_simplenum, d->C_colind);
7474
mju_copy(res, vec, m->nv);
75-
mj_solveLDs(res, LDs, d->qLDiagInv, m->nv,
75+
mj_solveLDs(res, LDs, d->qLDiagInv, m->nv, 1,
7676
d->C_rownnz, d->C_rowadr, m->dof_simplenum, d->C_colind);
7777
}
7878
}

test/benchmark/solveLD_benchmark_test.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ static void BM_solveLD(benchmark::State& state, bool featherstone, bool coil) {
6363
mj_solveM(m, d, res, vec, 1);
6464
} else {
6565
mju_copy(res, vec, m->nv);
66-
mj_solveLDs(res, LDs, d->qLDiagInv, m->nv,
66+
mj_solveLDs(res, LDs, d->qLDiagInv, m->nv, 1,
6767
d->C_rownnz, d->C_rowadr, m->dof_simplenum, d->C_colind);
6868
}
6969
}

test/engine/engine_core_smooth_test.cc

Lines changed: 39 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -495,7 +495,7 @@ TEST_F(CoreSmoothTest, SolveLDs) {
495495
for (int i=0; i < nv; i+=2) vec[i] = vec2[i] = 0;
496496

497497
mj_solveLD(m, vec.data(), 1, d->qLD, d->qLDiagInv);
498-
mj_solveLDs(vec2.data(), LDs.data(), d->qLDiagInv, nv,
498+
mj_solveLDs(vec2.data(), LDs.data(), d->qLDiagInv, nv, 1,
499499
d->C_rownnz, d->C_rowadr, m->dof_simplenum, d->C_colind);
500500

501501
// expect vectors to match up to floating point precision
@@ -507,6 +507,44 @@ TEST_F(CoreSmoothTest, SolveLDs) {
507507
mj_deleteModel(m);
508508
}
509509

510+
TEST_F(CoreSmoothTest, SolveLDmultipleVectors) {
511+
const std::string xml_path = GetTestDataFilePath(kInertiaPath);
512+
char error[1024];
513+
mjModel* m = mj_loadXML(xml_path.c_str(), nullptr, error, sizeof(error));
514+
ASSERT_THAT(m, NotNull()) << "Failed to load model: " << error;
515+
516+
mjData* d = mj_makeData(m);
517+
mj_forward(m, d);
518+
519+
int nv = m->nv;
520+
int nC = m->nC;
521+
522+
// copy LD into LDs: CSR format
523+
vector<mjtNum> LDs(nC);
524+
for (int i=0; i < nC; i++) {
525+
LDs[i] = d->qLD[d->mapM2C[i]];
526+
}
527+
528+
// compare n LD and LDs vector solve
529+
int n = 3;
530+
vector<mjtNum> vec(nv*n);
531+
vector<mjtNum> vec2(nv*n);
532+
for (int i=0; i < nv*n; i++) vec[i] = vec2[i] = 2 + 3*i;
533+
for (int i=0; i < nv*n; i+=3) vec[i] = vec2[i] = 0;
534+
535+
mj_solveLD(m, vec.data(), n, d->qLD, d->qLDiagInv);
536+
mj_solveLDs(vec2.data(), LDs.data(), d->qLDiagInv, nv, n,
537+
d->C_rownnz, d->C_rowadr, m->dof_simplenum, d->C_colind);
538+
539+
// expect vectors to match up to floating point precision
540+
for (int i=0; i < nv*n; i++) {
541+
EXPECT_FLOAT_EQ(vec[i], vec2[i]);
542+
}
543+
544+
mj_deleteData(d);
545+
mj_deleteModel(m);
546+
}
547+
510548
TEST_F(CoreSmoothTest, FactorIs) {
511549
const std::string xml_path = GetTestDataFilePath(kInertiaPath);
512550
char error[1024];

0 commit comments

Comments
 (0)