Skip to content

Commit 674d227

Browse files
yuvaltassacopybara-github
authored andcommitted
Standardize names of sparse fill-in pre-counting functions
PiperOrigin-RevId: 712835830 Change-Id: I8e90dfa52af56ede917e1fd90b8d551898c244e5
1 parent 0be7e6a commit 674d227

File tree

5 files changed

+75
-71
lines changed

5 files changed

+75
-71
lines changed

src/engine/engine_core_constraint.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,8 +2175,8 @@ void mj_projectConstraint(const mjModel* m, mjData* d) {
21752175
mju_superSparse(nefc, rowsuper, rownnz, rowadr, colind);
21762176

21772177
// AR = JM2 * JM2'
2178-
mju_sqrMatTDSparseInit(d->efc_AR_rownnz, d->efc_AR_rowadr, nefc, rownnzT,
2179-
rowadrT, colindT, rownnz, rowadr, colind, rowsuper, d, /*flg_upper=*/1);
2178+
mju_sqrMatTDSparseCount(d->efc_AR_rownnz, d->efc_AR_rowadr, nefc, rownnzT,
2179+
rowadrT, colindT, rownnz, rowadr, colind, rowsuper, d, /*flg_upper=*/1);
21802180

21812181
mju_sqrMatTDSparse(d->efc_AR, JM2T, JM2, NULL, nv, nefc,
21822182
d->efc_AR_rownnz, d->efc_AR_rowadr, d->efc_AR_colind,

src/engine/engine_solver.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,10 +1400,10 @@ static void MakeHessian(const mjModel* m, mjData* d, mjCGContext* ctx) {
14001400
}
14011401

14021402
// initialize Hessian rowadr, rownnz
1403-
mju_sqrMatTDSparseInit(ctx->H_rownnz, ctx->H_rowadr, nv,
1404-
d->efc_J_rownnz, d->efc_J_rowadr, d->efc_J_colind,
1405-
d->efc_JT_rownnz, d->efc_JT_rowadr, d->efc_JT_colind, d->efc_JT_rowsuper,
1406-
d, /*flg_upper=*/0);
1403+
mju_sqrMatTDSparseCount(ctx->H_rownnz, ctx->H_rowadr, nv,
1404+
d->efc_J_rownnz, d->efc_J_rowadr, d->efc_J_colind,
1405+
d->efc_JT_rownnz, d->efc_JT_rowadr, d->efc_JT_colind,
1406+
d->efc_JT_rowsuper, d, /*flg_upper=*/0);
14071407

14081408
// add nC to Hessian total nonzeros (unavoidable overcounting since H_colind is still unknown)
14091409
ctx->nH = m->nC + ctx->H_rowadr[nv - 1] + ctx->H_rownnz[nv - 1];
@@ -1440,7 +1440,7 @@ static void MakeHessian(const mjModel* m, mjData* d, mjCGContext* ctx) {
14401440
ctx->H_rownnz, ctx->H_rowadr, ctx->H_colind);
14411441

14421442
// count total and row non-zeros of reverse-Cholesky factor L
1443-
ctx->nL = mju_cholFactorNNZ(ctx->L_rownnz, HT_rownnz, HT_rowadr, HT_colind, nv, d);
1443+
ctx->nL = mju_cholFactorCount(ctx->L_rownnz, HT_rownnz, HT_rowadr, HT_colind, nv, d);
14441444
mj_freeStack(d);
14451445

14461446
// compute L row adresses: rowadr = cumsum(rownnz)

src/engine/engine_util_sparse.c

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -625,10 +625,10 @@ void mju_superSparse(int nr, int* rowsuper,
625625

626626

627627
// precount res_rownnz and precompute res_rowadr for mju_sqrMatTDSparse
628-
void mju_sqrMatTDSparseInit(int* res_rownnz, int* res_rowadr, int nr,
629-
const int* rownnz, const int* rowadr, const int* colind,
630-
const int* rownnzT, const int* rowadrT, const int* colindT,
631-
const int* rowsuperT, mjData* d, int flg_upper) {
628+
void mju_sqrMatTDSparseCount(int* res_rownnz, int* res_rowadr, int nr,
629+
const int* rownnz, const int* rowadr, const int* colind,
630+
const int* rownnzT, const int* rowadrT, const int* colindT,
631+
const int* rowsuperT, mjData* d, int flg_upper) {
632632
mj_markStack(d);
633633
int* chain = mjSTACKALLOC(d, 2*nr, int);
634634
int nchain = 0;
@@ -865,11 +865,11 @@ void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT,
865865
mj_freeStack(d);
866866
}
867867

868-
// compute row non-zeros of reverse-Cholesky factor L, return total non-zeros
869-
// based on ldl_symbolic from 'Algorithm 8xx: a concise sparse Cholesky factorization package'
870-
// note: reads pattern from upper triangle
871-
int mju_cholFactorNNZ(int* L_rownnz, const int* rownnz, const int* rowadr, const int* colind,
872-
int n, mjData* d) {
868+
// precount row non-zeros of reverse-Cholesky factor L, return total non-zeros
869+
// based on ldl_symbolic from 'Algorithm 8xx: a concise sparse Cholesky factorization package'
870+
// reads pattern from upper triangle
871+
int mju_cholFactorCount(int* L_rownnz, const int* rownnz, const int* rowadr, const int* colind,
872+
int n, mjData* d) {
873873
mj_markStack(d);
874874
int* parent = mjSTACKALLOC(d, n, int);
875875
int* flag = mjSTACKALLOC(d, n, int);
@@ -880,24 +880,28 @@ int mju_cholFactorNNZ(int* L_rownnz, const int* rownnz, const int* rowadr, const
880880
flag[r] = r;
881881
L_rownnz[r] = 1; // start with 1 for diagonal
882882

883-
// loop over non-zero columns
883+
// loop over non-zero columns of upper triangle
884884
int start = rowadr[r];
885885
int end = start + rownnz[r];
886886
for (int c = start; c < end; c++) {
887887
int i = colind[c];
888-
if (i > r) {
889-
// traverse from i to ancestor, stop when row is flagged
890-
while (flag[i] != r) {
891-
// if not yet set, set parent to current row
892-
if (parent[i] == -1) {
893-
parent[i] = r;
894-
}
895888

896-
// increment non-zeros, flag row i, advance to parent
897-
L_rownnz[i]++;
898-
flag[i] = r;
899-
i = parent[i];
889+
// skip lower triangle
890+
if (i <= r) {
891+
continue;
892+
}
893+
894+
// traverse from i to ancestor, stop when row is flagged
895+
while (flag[i] != r) {
896+
// if not yet set, set parent to current row
897+
if (parent[i] == -1) {
898+
parent[i] = r;
900899
}
900+
901+
// increment non-zeros, flag row i, advance to parent
902+
L_rownnz[i]++;
903+
flag[i] = r;
904+
i = parent[i];
901905
}
902906
}
903907
}

src/engine/engine_util_sparse.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,17 @@ MJAPI void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT
9999
mjData* d, int flg_upper);
100100

101101
// precount res_rownnz and precompute res_rowadr for mju_sqrMatTDSparse
102-
MJAPI void mju_sqrMatTDSparseInit(int* res_rownnz, int* res_rowadr, int nr,
103-
const int* rownnz, const int* rowadr, const int* colind,
104-
const int* rownnzT, const int* rowadrT, const int* colindT,
105-
const int* rowsuperT, mjData* d, int flg_upper);
102+
MJAPI void mju_sqrMatTDSparseCount(int* res_rownnz, int* res_rowadr, int nr,
103+
const int* rownnz, const int* rowadr, const int* colind,
104+
const int* rownnzT, const int* rowadrT, const int* colindT,
105+
const int* rowsuperT, mjData* d, int flg_upper);
106106

107107
// precompute res_rowadr for mju_sqrMatTDSparse using uncompressed memory
108108
MJAPI void mju_sqrMatTDUncompressedInit(int* res_rowadr, int nc);
109109

110-
// compute row non-zeros of reverse-Cholesky factor L, return total
111-
MJAPI int mju_cholFactorNNZ(int* L_rownnz, const int* rownnz, const int* rowadr, const int* colind,
112-
int n, mjData* d);
110+
// precount row non-zeros of reverse-Cholesky factor L, return total
111+
MJAPI int mju_cholFactorCount(int* L_rownnz, const int* rownnz, const int* rowadr,
112+
const int* colind, int n, mjData* d);
113113

114114
// ------------------------------ inlined functions ------------------------------------------------
115115

test/engine/engine_util_sparse_test.cc

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse1) {
326326
int rowadrH[] = {0, 0, 0};
327327

328328
// test precount
329-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
330-
rownnzT, rowadrT, colindT, nullptr, data, 1);
329+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
330+
rownnzT, rowadrT, colindT, nullptr, data, 1);
331331

332332
EXPECT_THAT(rownnzH, ElementsAre(3, 3, 3));
333333
EXPECT_THAT(rowadrH, ElementsAre(0, 3, 6));
@@ -371,8 +371,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse2) {
371371
int rowadrH[] = {0, 0, 0};
372372

373373
// test precount
374-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
375-
rownnzT, rowadrT, colindT, nullptr, data, 1);
374+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
375+
rownnzT, rowadrT, colindT, nullptr, data, 1);
376376

377377
EXPECT_THAT(rownnzH, ElementsAre(3, 3, 3));
378378
EXPECT_THAT(rowadrH, ElementsAre(0, 3, 6));
@@ -419,8 +419,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse3) {
419419
mjtNum diag[] = {2, 3, 4};
420420

421421
// test precount
422-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
423-
rownnzT, rowadrT, colindT, nullptr, data, 1);
422+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
423+
rownnzT, rowadrT, colindT, nullptr, data, 1);
424424

425425
EXPECT_THAT(rownnzH, ElementsAre(2, 2, 0));
426426
EXPECT_THAT(rowadrH, ElementsAre(0, 2, 4));
@@ -467,8 +467,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse4) {
467467

468468

469469
// test precount
470-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
471-
rownnzT, rowadrT, colindT, nullptr, data, 1);
470+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
471+
rownnzT, rowadrT, colindT, nullptr, data, 1);
472472

473473
EXPECT_THAT(rownnzH, ElementsAre(2, 0, 2));
474474
EXPECT_THAT(rowadrH, ElementsAre(0, 2, 2));
@@ -513,8 +513,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse5) {
513513

514514

515515
// test precount
516-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
517-
rownnzT, rowadrT, colindT, nullptr, data, 1);
516+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
517+
rownnzT, rowadrT, colindT, nullptr, data, 1);
518518

519519
EXPECT_THAT(rownnzH, ElementsAre(3, 2, 2));
520520
EXPECT_THAT(rowadrH, ElementsAre(0, 3, 5));
@@ -558,8 +558,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse6) {
558558
int rowadrH[] = {0, 0, 0};
559559

560560
// test precount
561-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
562-
rownnzT, rowadrT, colindT, nullptr, data, 1);
561+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
562+
rownnzT, rowadrT, colindT, nullptr, data, 1);
563563

564564
EXPECT_THAT(rownnzH, ElementsAre(2, 1, 2));
565565
EXPECT_THAT(rowadrH, ElementsAre(0, 2, 3));
@@ -605,8 +605,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse7) {
605605
mjtNum diag[] = {2, 3, 4};
606606

607607
// test precount
608-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 2, rownnz, rowadr, colind,
609-
rownnzT, rowadrT, colindT, nullptr, data, 1);
608+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 2, rownnz, rowadr, colind,
609+
rownnzT, rowadrT, colindT, nullptr, data, 1);
610610

611611
EXPECT_THAT(rownnzH, ElementsAre(2, 2));
612612
EXPECT_THAT(rowadrH, ElementsAre(0, 2));
@@ -651,8 +651,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse8) {
651651
mjtNum diag[] = {2, 3};
652652

653653
// test precount
654-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
655-
rownnzT, rowadrT, colindT, nullptr, data, 1);
654+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
655+
rownnzT, rowadrT, colindT, nullptr, data, 1);
656656

657657
EXPECT_THAT(rownnzH, ElementsAre(3, 2, 2));
658658
EXPECT_THAT(rowadrH, ElementsAre(0, 3, 5));
@@ -698,8 +698,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse9) {
698698
mjtNum diag[] = {2, 3, 4};
699699

700700
// test precount
701-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
702-
rownnzT, rowadrT, colindT, nullptr, data, 1);
701+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
702+
rownnzT, rowadrT, colindT, nullptr, data, 1);
703703

704704
EXPECT_THAT(rownnzH, ElementsAre(3, 3, 3));
705705
EXPECT_THAT(rowadrH, ElementsAre(0, 3, 6));
@@ -746,8 +746,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse10) {
746746
mjtNum diag[] = {1, 1, 1};
747747

748748
// test precount
749-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
750-
rownnzT, rowadrT, colindT, rowsuperT, data, 1);
749+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
750+
rownnzT, rowadrT, colindT, rowsuperT, data, 1);
751751

752752
EXPECT_THAT(rownnzH, ElementsAre(3, 3, 3));
753753
EXPECT_THAT(rowadrH, ElementsAre(0, 3, 6));
@@ -794,8 +794,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse11) {
794794
mjtNum diag[] = {1, 1, 1};
795795

796796
// test precount
797-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
798-
rownnzT, rowadrT, colindT, rowsuperT, data, 1);
797+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
798+
rownnzT, rowadrT, colindT, rowsuperT, data, 1);
799799

800800
EXPECT_THAT(rownnzH, ElementsAre(3, 3, 3));
801801
EXPECT_THAT(rowadrH, ElementsAre(0, 3, 6));
@@ -842,8 +842,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse12) {
842842
mjtNum diag[] = {1, 1, 1};
843843

844844
// test precount
845-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 4, rownnz, rowadr, colind,
846-
rownnzT, rowadrT, colindT, rowsuperT, data, 1);
845+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 4, rownnz, rowadr, colind,
846+
rownnzT, rowadrT, colindT, rowsuperT, data, 1);
847847

848848
EXPECT_THAT(rownnzH, ElementsAre(4, 4, 4, 4));
849849
EXPECT_THAT(rowadrH, ElementsAre(0, 4, 8, 12));
@@ -894,8 +894,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse13) {
894894
mjtNum diag[] = {1, 1, 1};
895895

896896
// test precount
897-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 5, rownnz, rowadr, colind,
898-
rownnzT, rowadrT, colindT, rowsuperT, data, 1);
897+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 5, rownnz, rowadr, colind,
898+
rownnzT, rowadrT, colindT, rowsuperT, data, 1);
899899

900900
EXPECT_THAT(rownnzH, ElementsAre(2, 2, 0, 0, 0));
901901
EXPECT_THAT(rowadrH, ElementsAre(0, 2, 4, 4, 4));
@@ -944,8 +944,8 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse14) {
944944
int rowadrH[] = {0, 0, 0, 0, 0, 0, 0};
945945

946946
// test precount
947-
mju_sqrMatTDSparseInit(rownnzH, rowadrH, 7, rownnz, rowadr, colind,
948-
rownnzT, rowadrT, colindT, rowsuperT, data, 1);
947+
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 7, rownnz, rowadr, colind,
948+
rownnzT, rowadrT, colindT, rowsuperT, data, 1);
949949

950950
EXPECT_THAT(rownnzH, ElementsAre(7, 7, 7, 7, 7, 7, 7));
951951
EXPECT_THAT(rowadrH, ElementsAre(0, 7, 14, 21, 28, 35, 42));
@@ -985,8 +985,8 @@ TEST_F(EngineUtilSparseTest, MjuCholFactorNNZ) {
985985
int colindA[4];
986986
int rownnzA_factor[2];
987987
mju_dense2sparse(sparseA, matA, nA, nA, rownnzA, rowadrA, colindA, 4);
988-
int nnzA = mju_cholFactorNNZ(rownnzA_factor,
989-
rownnzA, rowadrA, colindA, nA, d);
988+
int nnzA = mju_cholFactorCount(rownnzA_factor,
989+
rownnzA, rowadrA, colindA, nA, d);
990990

991991
EXPECT_EQ(nnzA, 2);
992992
EXPECT_THAT(AsVector(rownnzA_factor, 2), ElementsAre(1, 1));
@@ -1001,8 +1001,8 @@ TEST_F(EngineUtilSparseTest, MjuCholFactorNNZ) {
10011001
int colindB[9];
10021002
int rownnzB_factor[3];
10031003
mju_dense2sparse(sparseB, matB, nB, nB, rownnzB, rowadrB, colindB, 9);
1004-
int nnzB = mju_cholFactorNNZ(rownnzB_factor,
1005-
rownnzB, rowadrB, colindB, nB, d);
1004+
int nnzB = mju_cholFactorCount(rownnzB_factor,
1005+
rownnzB, rowadrB, colindB, nB, d);
10061006

10071007
EXPECT_EQ(nnzB, 5);
10081008
EXPECT_THAT(AsVector(rownnzB_factor, 3), ElementsAre(1, 2, 2));
@@ -1017,8 +1017,8 @@ TEST_F(EngineUtilSparseTest, MjuCholFactorNNZ) {
10171017
int colindC[9];
10181018
int rownnzC_factor[3];
10191019
mju_dense2sparse(sparseC, matC, nC, nC, rownnzC, rowadrC, colindC, 9);
1020-
int nnzC = mju_cholFactorNNZ(rownnzC_factor,
1021-
rownnzC, rowadrC, colindC, nC, d);
1020+
int nnzC = mju_cholFactorCount(rownnzC_factor,
1021+
rownnzC, rowadrC, colindC, nC, d);
10221022

10231023
EXPECT_EQ(nnzC, 4);
10241024
EXPECT_THAT(AsVector(rownnzC_factor, 3), ElementsAre(1, 2, 1));
@@ -1034,8 +1034,8 @@ TEST_F(EngineUtilSparseTest, MjuCholFactorNNZ) {
10341034
int colindD[16];
10351035
int rownnzD_factor[4];
10361036
mju_dense2sparse(sparseD, matD, nD, nD, rownnzD, rowadrD, colindD, 16);
1037-
int nnzD = mju_cholFactorNNZ(rownnzD_factor,
1038-
rownnzD, rowadrD, colindD, nD, d);
1037+
int nnzD = mju_cholFactorCount(rownnzD_factor,
1038+
rownnzD, rowadrD, colindD, nD, d);
10391039

10401040
EXPECT_EQ(nnzD, 8);
10411041
EXPECT_THAT(AsVector(rownnzD_factor, 4), ElementsAre(1, 2, 2, 3));

0 commit comments

Comments
 (0)