Skip to content

Commit 25126e8

Browse files
yuvaltassacopybara-github
authored andcommitted
Move mju_cholFactorCount to engine_util_solve.
PiperOrigin-RevId: 746011856 Change-Id: If9812251420053644f4eca81adb5116d370ee524
1 parent 3c53097 commit 25126e8

File tree

7 files changed

+150
-152
lines changed

7 files changed

+150
-152
lines changed

src/engine/engine_util_solve.c

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,60 @@ int mju_cholFactorSparse(mjtNum* mat, int n, mjtNum mindiag,
191191

192192

193193

194+
// precount row non-zeros of reverse-Cholesky factor L, return total non-zeros
195+
// based on ldl_symbolic from 'Algorithm 8xx: a concise sparse Cholesky factorization package'
196+
// reads pattern from upper triangle
197+
int mju_cholFactorCount(int* L_rownnz, const int* rownnz, const int* rowadr, const int* colind,
198+
int n, mjData* d) {
199+
mj_markStack(d);
200+
int* parent = mjSTACKALLOC(d, n, int);
201+
int* flag = mjSTACKALLOC(d, n, int);
202+
203+
// loop over rows in reverse order
204+
for (int r = n - 1; r >= 0; r--) {
205+
parent[r] = -1;
206+
flag[r] = r;
207+
L_rownnz[r] = 1; // start with 1 for diagonal
208+
209+
// loop over non-zero columns of upper triangle
210+
int start = rowadr[r];
211+
int end = start + rownnz[r];
212+
for (int c = start; c < end; c++) {
213+
int i = colind[c];
214+
215+
// skip lower triangle
216+
if (i <= r) {
217+
continue;
218+
}
219+
220+
// traverse from i to ancestor, stop when row is flagged
221+
while (flag[i] != r) {
222+
// if not yet set, set parent to current row
223+
if (parent[i] == -1) {
224+
parent[i] = r;
225+
}
226+
227+
// increment non-zeros, flag row i, advance to parent
228+
L_rownnz[i]++;
229+
flag[i] = r;
230+
i = parent[i];
231+
}
232+
}
233+
}
234+
235+
mj_freeStack(d);
236+
237+
// sum up all row non-zeros
238+
int nnz = 0;
239+
for (int r = 0; r < n; r++) {
240+
nnz += L_rownnz[r];
241+
}
242+
243+
return nnz;
244+
}
245+
246+
247+
194248
// sparse reverse-order Cholesky solve
195249
void mju_cholSolveSparse(mjtNum* res, const mjtNum* mat, const mjtNum* vec, int n,
196250
const int* rownnz, const int* rowadr, const int* colind) {

src/engine/engine_util_solve.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ int mju_cholFactorSparse(mjtNum* mat, int n, mjtNum mindiag,
3838
int* rownnz, const int* rowadr, int* colind,
3939
mjData* d);
4040

41+
// precount row non-zeros of reverse-Cholesky factor L, return total
42+
MJAPI int mju_cholFactorCount(int* L_rownnz, const int* rownnz, const int* rowadr,
43+
const int* colind, int n, mjData* d);
44+
4145
// sparse reverse-order Cholesky solve
4246
void mju_cholSolveSparse(mjtNum* res, const mjtNum* mat, const mjtNum* vec, int n,
4347
const int* rownnz, const int* rowadr, const int* colind);

src/engine/engine_util_sparse.c

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -795,55 +795,3 @@ void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT,
795795

796796
mj_freeStack(d);
797797
}
798-
799-
// precount row non-zeros of reverse-Cholesky factor L, return total non-zeros
800-
// based on ldl_symbolic from 'Algorithm 8xx: a concise sparse Cholesky factorization package'
801-
// reads pattern from upper triangle
802-
int mju_cholFactorCount(int* L_rownnz, const int* rownnz, const int* rowadr, const int* colind,
803-
int n, mjData* d) {
804-
mj_markStack(d);
805-
int* parent = mjSTACKALLOC(d, n, int);
806-
int* flag = mjSTACKALLOC(d, n, int);
807-
808-
// loop over rows in reverse order
809-
for (int r = n - 1; r >= 0; r--) {
810-
parent[r] = -1;
811-
flag[r] = r;
812-
L_rownnz[r] = 1; // start with 1 for diagonal
813-
814-
// loop over non-zero columns of upper triangle
815-
int start = rowadr[r];
816-
int end = start + rownnz[r];
817-
for (int c = start; c < end; c++) {
818-
int i = colind[c];
819-
820-
// skip lower triangle
821-
if (i <= r) {
822-
continue;
823-
}
824-
825-
// traverse from i to ancestor, stop when row is flagged
826-
while (flag[i] != r) {
827-
// if not yet set, set parent to current row
828-
if (parent[i] == -1) {
829-
parent[i] = r;
830-
}
831-
832-
// increment non-zeros, flag row i, advance to parent
833-
L_rownnz[i]++;
834-
flag[i] = r;
835-
i = parent[i];
836-
}
837-
}
838-
}
839-
840-
mj_freeStack(d);
841-
842-
// sum up all row non-zeros
843-
int nnz = 0;
844-
for (int r = 0; r < n; r++) {
845-
nnz += L_rownnz[r];
846-
}
847-
848-
return nnz;
849-
}

src/engine/engine_util_sparse.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,6 @@ MJAPI void mju_sqrMatTDSparseCount(int* res_rownnz, int* res_rowadr, int nr,
103103
// precompute res_rowadr for mju_sqrMatTDSparse using uncompressed memory
104104
MJAPI void mju_sqrMatTDUncompressedInit(int* res_rowadr, int nc);
105105

106-
// precount row non-zeros of reverse-Cholesky factor L, return total
107-
MJAPI int mju_cholFactorCount(int* L_rownnz, const int* rownnz, const int* rowadr,
108-
const int* colind, int n, mjData* d);
109106

110107
// ------------------------------ inlined functions ------------------------------------------------
111108

test/engine/engine_util_solve_test.cc

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include <random>
2222
#include <iomanip>
2323
#include <string>
24-
#include <vector>
2524

2625
#include <gmock/gmock.h>
2726
#include <gtest/gtest.h>
@@ -36,6 +35,7 @@ namespace {
3635
using ::testing::DoubleEq;
3736
using ::testing::Pointwise;
3837
using ::testing::DoubleNear;
38+
using ::testing::ElementsAre;
3939
using ::std::string;
4040
using ::std::setw;
4141
using QCQP2Test = MujocoTest;
@@ -649,5 +649,79 @@ TEST_F(BandMatrixTest, Solve) {
649649
}
650650
}
651651

652+
using EngineUtilSolveTest = MujocoTest;
653+
654+
TEST_F(EngineUtilSolveTest, MjuCholFactorNNZ) {
655+
mjModel* model = LoadModelFromString("<mujoco/>");
656+
mjData* d = mj_makeData(model);
657+
658+
int nA = 2;
659+
mjtNum matA[4] = {1, 0,
660+
0, 1};
661+
mjtNum sparseA[4];
662+
int rownnzA[2];
663+
int rowadrA[2];
664+
int colindA[4];
665+
int rownnzA_factor[2];
666+
mju_dense2sparse(sparseA, matA, nA, nA, rownnzA, rowadrA, colindA, 4);
667+
int nnzA = mju_cholFactorCount(rownnzA_factor,
668+
rownnzA, rowadrA, colindA, nA, d);
669+
670+
EXPECT_EQ(nnzA, 2);
671+
EXPECT_THAT(AsVector(rownnzA_factor, 2), ElementsAre(1, 1));
672+
673+
int nB = 3;
674+
mjtNum matB[9] = {10, 1, 0,
675+
0, 10, 1,
676+
0, 0, 10};
677+
mjtNum sparseB[9];
678+
int rownnzB[3];
679+
int rowadrB[3];
680+
int colindB[9];
681+
int rownnzB_factor[3];
682+
mju_dense2sparse(sparseB, matB, nB, nB, rownnzB, rowadrB, colindB, 9);
683+
int nnzB = mju_cholFactorCount(rownnzB_factor,
684+
rownnzB, rowadrB, colindB, nB, d);
685+
686+
EXPECT_EQ(nnzB, 5);
687+
EXPECT_THAT(AsVector(rownnzB_factor, 3), ElementsAre(1, 2, 2));
688+
689+
int nC = 3;
690+
mjtNum matC[9] = {10, 1, 0,
691+
0, 10, 0,
692+
0, 0, 10};
693+
mjtNum sparseC[9];
694+
int rownnzC[3];
695+
int rowadrC[3];
696+
int colindC[9];
697+
int rownnzC_factor[3];
698+
mju_dense2sparse(sparseC, matC, nC, nC, rownnzC, rowadrC, colindC, 9);
699+
int nnzC = mju_cholFactorCount(rownnzC_factor,
700+
rownnzC, rowadrC, colindC, nC, d);
701+
702+
EXPECT_EQ(nnzC, 4);
703+
EXPECT_THAT(AsVector(rownnzC_factor, 3), ElementsAre(1, 2, 1));
704+
705+
int nD = 4;
706+
mjtNum matD[16] = {10, 1, 2, 3,
707+
0, 10, 0, 0,
708+
0, 0, 10, 1,
709+
0, 0, 0, 10};
710+
mjtNum sparseD[16];
711+
int rownnzD[4];
712+
int rowadrD[4];
713+
int colindD[16];
714+
int rownnzD_factor[4];
715+
mju_dense2sparse(sparseD, matD, nD, nD, rownnzD, rowadrD, colindD, 16);
716+
int nnzD = mju_cholFactorCount(rownnzD_factor,
717+
rownnzD, rowadrD, colindD, nD, d);
718+
719+
EXPECT_EQ(nnzD, 8);
720+
EXPECT_THAT(AsVector(rownnzD_factor, 4), ElementsAre(1, 2, 2, 3));
721+
722+
mj_deleteData(d);
723+
mj_deleteModel(model);
724+
}
725+
652726
} // namespace
653727
} // namespace mujoco

0 commit comments

Comments
 (0)