Skip to content

Commit 00010f5

Browse files
yuvaltassacopybara-github
authored andcommitted
Compute diagonal indices in mj_sqrMatTDSparse
PiperOrigin-RevId: 713714087 Change-Id: Icc8eae74e6e47ba1d12a6e9aa0d774ab006cfe38
1 parent daba00c commit 00010f5

File tree

6 files changed

+53
-34
lines changed

6 files changed

+53
-34
lines changed

src/engine/engine_core_constraint.c

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2174,23 +2174,20 @@ void mj_projectConstraint(const mjModel* m, mjData* d) {
21742174
// construct supernodes
21752175
mju_superSparse(nefc, rowsuper, rownnz, rowadr, colind);
21762176

2177-
// AR = JM2 * JM2'
2177+
// pre-count efc_AR_rownnz, efc_AR_rowadr
21782178
mju_sqrMatTDSparseCount(d->efc_AR_rownnz, d->efc_AR_rowadr, nefc, rownnzT,
21792179
rowadrT, colindT, rownnz, rowadr, colind, rowsuper, d, /*flg_upper=*/1);
21802180

2181+
// AR = JM2 * JM2'
2182+
int* diagind = mjSTACKALLOC(d, nefc, int);
21812183
mju_sqrMatTDSparse(d->efc_AR, JM2T, JM2, NULL, nv, nefc,
21822184
d->efc_AR_rownnz, d->efc_AR_rowadr, d->efc_AR_colind,
21832185
rownnzT, rowadrT, colindT, NULL,
2184-
rownnz, rowadr, colind, rowsuper, d, /*flg_upper=*/1);
2186+
rownnz, rowadr, colind, rowsuper, d, diagind);
21852187

21862188
// add R to diagonal of AR
21872189
for (int i=0; i < nefc; i++) {
2188-
for (int j=0; j < d->efc_AR_rownnz[i]; j++) {
2189-
if (i == d->efc_AR_colind[d->efc_AR_rowadr[i]+j]) {
2190-
d->efc_AR[d->efc_AR_rowadr[i]+j] += d->efc_R[i];
2191-
break;
2192-
}
2193-
}
2190+
d->efc_AR[diagind[i]] += d->efc_R[i];
21942191
}
21952192
}
21962193

src/engine/engine_solver.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,7 +1424,7 @@ static void MakeHessian(const mjModel* m, mjData* d, mjCGContext* ctx) {
14241424
ctx->H_rownnz, ctx->H_rowadr, ctx->H_colind,
14251425
d->efc_J_rownnz, d->efc_J_rowadr, d->efc_J_colind, NULL,
14261426
d->efc_JT_rownnz, d->efc_JT_rowadr, d->efc_JT_colind, d->efc_JT_rowsuper,
1427-
d, /*flg_upper=*/0);
1427+
d, /*diagind=*/NULL);
14281428

14291429
// add mass matrix: H = J'*D*J + C
14301430
mj_addMSparse(m, d, ctx->H, ctx->H_rownnz, ctx->H_rowadr, ctx->H_colind,
@@ -1518,7 +1518,7 @@ static void FactorizeHessian(const mjModel* m, mjData* d, mjCGContext* ctx,
15181518
ctx->H_rownnz, ctx->H_rowadr, ctx->H_colind,
15191519
d->efc_J_rownnz, d->efc_J_rowadr, d->efc_J_colind, NULL,
15201520
d->efc_JT_rownnz, d->efc_JT_rowadr, d->efc_JT_colind, d->efc_JT_rowsuper,
1521-
d, /*flg_upper=*/0);
1521+
d, /*diagind=*/NULL);
15221522

15231523
// add mass matrix: H = J'*D*J + C
15241524
mj_addMSparse(m, d, ctx->H, ctx->H_rownnz, ctx->H_rowadr, ctx->H_colind,

src/engine/engine_util_sparse.c

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,7 @@ void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT,
724724
const int* colind, const int* rowsuper,
725725
const int* rownnzT, const int* rowadrT,
726726
const int* colindT, const int* rowsuperT,
727-
mjData* d, int flg_upper) {
727+
mjData* d, int* diagind) {
728728
// allocate space for accumulation buffer and matT
729729
mj_markStack(d);
730730

@@ -838,8 +838,14 @@ void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT,
838838
}
839839

840840

841-
// fill upper triangle
842-
if (flg_upper) {
841+
// diagonal indices requested: fill upper triangle
842+
if (diagind) {
843+
// save diagonal indices
844+
for (int i=0; i < nc; i++) {
845+
diagind[i] = res_rowadr[i] + res_rownnz[i] - 1;
846+
}
847+
848+
// fill upper triangle
843849
for (int i=0; i < nc; i++) {
844850
int start = res_rowadr[i];
845851
int end = start + res_rownnz[i] - 1;

src/engine/engine_util_sparse.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,15 @@ MJAPI void mju_transposeSparse(mjtNum* res, const mjtNum* mat, int nr, int nc,
8989
MJAPI void mju_superSparse(int nr, int* rowsuper,
9090
const int* rownnz, const int* rowadr, const int* colind);
9191

92-
// compute sparse M'*diag*M (diag=NULL: compute M'*M), res has uncompressed layout
93-
// res_rowadr is required to be precomputed
92+
// compute sparse M'*diag*M (diag=NULL: compute M'*M), res_rowadr must be precomputed
9493
MJAPI void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT,
9594
const mjtNum* diag, int nr, int nc,
9695
int* res_rownnz, const int* res_rowadr, int* res_colind,
9796
const int* rownnz, const int* rowadr,
9897
const int* colind, const int* rowsuper,
9998
const int* rownnzT, const int* rowadrT,
10099
const int* colindT, const int* rowsuperT,
101-
mjData* d, int flg_upper);
100+
mjData* d, int* diagind);
102101

103102
// precount res_rownnz and precompute res_rowadr for mju_sqrMatTDSparse
104103
MJAPI void mju_sqrMatTDSparseCount(int* res_rownnz, int* res_rowadr, int nr,

test/benchmark/engine_util_sparse_benchmark_test.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ void ABSL_ATTRIBUTE_NOINLINE mju_sqrMatTDSparse_baseline(
4343
int nr, int nc, int* res_rownnz, int* res_rowadr, int* res_colind,
4444
const int* rownnz, const int* rowadr, const int* colind,
4545
const int* rowsuper, const int* rownnzT, const int* rowadrT,
46-
const int* colindT, const int* rowsuperT, mjData* d, int unused) {
46+
const int* colindT, const int* rowsuperT, mjData* d, int* unused) {
4747
mj_markStack(d);
4848
int* chain = mj_stackAllocInt(d, 2 * nc);
4949
mjtNum* buffer = mj_stackAllocNum(d, nc);
@@ -435,6 +435,7 @@ static void BM_combineSparse(benchmark::State& state, CombineFuncPtr func) {
435435
int* rownnz = mj_stackAllocInt(d, m->nv);
436436
int* rowadr = mj_stackAllocInt(d, m->nv);
437437
int* colind = mj_stackAllocInt(d, m->nv*m->nv);
438+
int* diagind = mj_stackAllocInt(d, m->nv);
438439

439440
// compute D corresponding to quad states
440441
mjtNum* D = mj_stackAllocNum(d, d->nefc);
@@ -454,7 +455,7 @@ static void BM_combineSparse(benchmark::State& state, CombineFuncPtr func) {
454455
d->efc_J_colind, d->efc_J_rowsuper,
455456
d->efc_JT_rownnz, d->efc_JT_rowadr,
456457
d->efc_JT_colind, d->efc_JT_rowsuper, d,
457-
/*flg_upper=*/1);
458+
diagind);
458459

459460
// compute H = M + J'*D*J
460461
mj_addM(m, d, H, rownnz, rowadr, colind);
@@ -559,6 +560,7 @@ static void BM_sqrMatTDSparse(benchmark::State& state, SqrMatTDFuncPtr func) {
559560
int* rownnz = mj_stackAllocInt(d, m->nv);
560561
int* rowadr = mj_stackAllocInt(d, m->nv);
561562
int* colind = mj_stackAllocInt(d, m->nv * m->nv);
563+
int* diagind = mj_stackAllocInt(d, m->nv);
562564

563565
// compute D corresponding to quad states
564566
mjtNum* D = mj_stackAllocNum(d, d->nefc);
@@ -579,7 +581,7 @@ static void BM_sqrMatTDSparse(benchmark::State& state, SqrMatTDFuncPtr func) {
579581
func(H, d->efc_J, d->efc_JT, D, d->nefc, m->nv, rownnz, rowadr, colind,
580582
d->efc_J_rownnz, d->efc_J_rowadr, d->efc_J_colind, NULL,
581583
d->efc_JT_rownnz, d->efc_JT_rowadr, d->efc_JT_colind,
582-
d->efc_JT_rowsuper, d, /*flg_upper=*/1);
584+
d->efc_JT_rowsuper, d, diagind);
583585
}
584586
} else {
585587
for (auto s : state) {
@@ -592,7 +594,7 @@ static void BM_sqrMatTDSparse(benchmark::State& state, SqrMatTDFuncPtr func) {
592594
H, d->efc_J, d->efc_JT, D, d->nefc, m->nv, rownnz, rowadr, colind,
593595
d->efc_J_rownnz, d->efc_J_rowadr, d->efc_J_colind, d->efc_J_rowsuper,
594596
d->efc_JT_rownnz, d->efc_JT_rowadr, d->efc_JT_colind,
595-
d->efc_JT_rowsuper, d, /*unused=*/0);
597+
d->efc_JT_rowsuper, d, /*unused=*/nullptr);
596598
}
597599
}
598600

test/engine/engine_util_sparse_test.cc

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse1) {
324324
int colindH[] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
325325
int rownnzH[] = {0, 0, 0};
326326
int rowadrH[] = {0, 0, 0};
327+
int diagindH[] = {0, 0, 0};
327328

328329
// test precount
329330
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
@@ -336,7 +337,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse1) {
336337
mju_sqrMatTDUncompressedInit(rowadrH, 3);
337338
mju_sqrMatTDSparse(matH, mat, matT, nullptr, 3, 3, rownnzH, rowadrH, colindH,
338339
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
339-
nullptr, data, 1);
340+
nullptr, data, diagindH);
340341

341342
EXPECT_THAT(matH, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0));
342343
EXPECT_THAT(colindH, ElementsAre(0, 1, 2, 0, 1, 2, 0, 1, 2));
@@ -369,6 +370,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse2) {
369370
int colindH[] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
370371
int rownnzH[] = {0, 0, 0};
371372
int rowadrH[] = {0, 0, 0};
373+
int diagindH[] = {0, 0, 0};
372374

373375
// test precount
374376
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
@@ -382,7 +384,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse2) {
382384
mju_sqrMatTDUncompressedInit(rowadrH, 3);
383385
mju_sqrMatTDSparse(matH, mat, matT, nullptr, 3, 3, rownnzH, rowadrH, colindH,
384386
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
385-
nullptr, data, 1);
387+
nullptr, data, diagindH);
386388

387389
EXPECT_THAT(matH, ElementsAre(12, 0, 12, 0, 6, 3, 12, 3, 14));
388390
EXPECT_THAT(colindH, ElementsAre(0, 1, 2, 0, 1, 2, 0, 1, 2));
@@ -415,6 +417,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse3) {
415417
int colindH[] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
416418
int rownnzH[] = {0, 0, 0};
417419
int rowadrH[] = {0, 0, 0};
420+
int diagindH[] = {0, 0, 0};
418421

419422
mjtNum diag[] = {2, 3, 4};
420423

@@ -429,7 +432,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse3) {
429432
mju_sqrMatTDUncompressedInit(rowadrH, 3);
430433
mju_sqrMatTDSparse(matH, mat, matT, diag, 3, 3, rownnzH, rowadrH, colindH,
431434
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
432-
nullptr, data, 1);
435+
nullptr, data, diagindH);
433436

434437
EXPECT_THAT(matH, ElementsAre(66, 4, 0, 4, 35, 0, 0, 0, 0));
435438
EXPECT_THAT(colindH, ElementsAre(0, 1, 0, 0, 1, 0, 0, 0, 0));
@@ -462,6 +465,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse4) {
462465
int colindH[] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
463466
int rownnzH[] = {0, 0, 0};
464467
int rowadrH[] = {0, 0, 0};
468+
int diagindH[] = {0, 0, 0};
465469

466470
mjtNum diag[] = {2, 3, 4};
467471

@@ -477,7 +481,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse4) {
477481
mju_sqrMatTDUncompressedInit(rowadrH, 3);
478482
mju_sqrMatTDSparse(matH, mat, matT, diag, 3, 3, rownnzH, rowadrH, colindH,
479483
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
480-
nullptr, data, 1);
484+
nullptr, data, diagindH);
481485

482486
EXPECT_THAT(matH, ElementsAre(66, 4, 0, 0, 0, 0, 4, 35, 0));
483487
EXPECT_THAT(colindH, ElementsAre(0, 2, 0, 0, 0, 0, 0, 2, 0));
@@ -510,6 +514,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse5) {
510514
int colindH[] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
511515
int rownnzH[] = {0, 0, 0};
512516
int rowadrH[] = {0, 0, 0};
517+
int diagindH[] = {0, 0, 0};
513518

514519

515520
// test precount
@@ -523,7 +528,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse5) {
523528
mju_sqrMatTDUncompressedInit(rowadrH, 3);
524529
mju_sqrMatTDSparse(matH, mat, matT, nullptr, 3, 3, rownnzH, rowadrH, colindH,
525530
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
526-
nullptr, data, 1);
531+
nullptr, data, diagindH);
527532

528533
EXPECT_THAT(matH, ElementsAre(5, 6, 4, 6, 9, 0, 4, 16, 0));
529534
EXPECT_THAT(colindH, ElementsAre(0, 1, 2, 0, 1, 0, 0, 2, 0));
@@ -556,6 +561,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse6) {
556561
int colindH[] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
557562
int rownnzH[] = {0, 0, 0};
558563
int rowadrH[] = {0, 0, 0};
564+
int diagindH[] = {0, 0, 0};
559565

560566
// test precount
561567
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 3, rownnz, rowadr, colind,
@@ -568,12 +574,13 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse6) {
568574
mju_sqrMatTDUncompressedInit(rowadrH, 3);
569575
mju_sqrMatTDSparse(matH, mat, matT, nullptr, 3, 3, rownnzH, rowadrH, colindH,
570576
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
571-
nullptr, data, 1);
577+
nullptr, data, diagindH);
572578

573579
EXPECT_THAT(matH, ElementsAre(1, 2, 0, 4, 0, 0, 2, 13, 0));
574580
EXPECT_THAT(colindH, ElementsAre(0, 2, 0, 1, 0, 0, 0, 2, 0));
575581
EXPECT_THAT(rownnzH, ElementsAre(2, 1, 2));
576582
EXPECT_THAT(rowadrH, ElementsAre(0, 3, 6));
583+
EXPECT_THAT(diagindH, ElementsAre(0, 3, 7));
577584

578585
mj_deleteData(data);
579586
mj_deleteModel(model);
@@ -601,6 +608,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse7) {
601608
int colindH[] = {0, 0, 0, 0};
602609
int rownnzH[] = {0, 0};
603610
int rowadrH[] = {0, 0};
611+
int diagindH[] = {0, 0};
604612

605613
mjtNum diag[] = {2, 3, 4};
606614

@@ -615,7 +623,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse7) {
615623
mju_sqrMatTDUncompressedInit(rowadrH, 2);
616624
mju_sqrMatTDSparse(matH, mat, matT, diag, 3, 2, rownnzH, rowadrH, colindH,
617625
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
618-
nullptr, data, 1);
626+
nullptr, data, diagindH);
619627

620628
EXPECT_THAT(matH, ElementsAre(66, 4, 4, 35));
621629
EXPECT_THAT(colindH, ElementsAre(0, 1, 0, 1));
@@ -647,6 +655,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse8) {
647655
int colindH[] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
648656
int rownnzH[] = {0, 0, 0};
649657
int rowadrH[] = {0, 0, 0};
658+
int diagindH[] = {0, 0, 0};
650659

651660
mjtNum diag[] = {2, 3};
652661

@@ -661,7 +670,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse8) {
661670
mju_sqrMatTDUncompressedInit(rowadrH, 3);
662671
mju_sqrMatTDSparse(matH, mat, matT, diag, 2, 3, rownnzH, rowadrH, colindH,
663672
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
664-
nullptr, data, 1);
673+
nullptr, data, diagindH);
665674

666675
EXPECT_THAT(matH, ElementsAre(14, 18, 8, 18, 27, 0, 8, 32, 0));
667676
EXPECT_THAT(colindH, ElementsAre(0, 1, 2, 0, 1, 0, 0, 2, 0));
@@ -694,6 +703,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse9) {
694703
int colindH[] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
695704
int rownnzH[] = {0, 0, 0};
696705
int rowadrH[] = {0, 0, 0};
706+
int diagindH[] = {0, 0, 0};
697707

698708
mjtNum diag[] = {2, 3, 4};
699709

@@ -708,7 +718,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse9) {
708718
mju_sqrMatTDUncompressedInit(rowadrH, 3);
709719
mju_sqrMatTDSparse(matH, mat, matT, diag, 3, 3, rownnzH, rowadrH, colindH,
710720
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
711-
nullptr, data, 1);
721+
nullptr, data, diagindH);
712722

713723
EXPECT_THAT(matH, ElementsAre(69, 77, 80, 77, 99, 108, 80, 108, 120));
714724
EXPECT_THAT(colindH, ElementsAre(0, 1, 2, 0, 1, 2, 0, 1, 2));
@@ -742,6 +752,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse10) {
742752
int colindH[] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
743753
int rownnzH[] = {0, 0, 0};
744754
int rowadrH[] = {0, 0, 0};
755+
int diagindH[] = {0, 0, 0};
745756

746757
mjtNum diag[] = {1, 1, 1};
747758

@@ -756,7 +767,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse10) {
756767
mju_sqrMatTDUncompressedInit(rowadrH, 3);
757768
mju_sqrMatTDSparse(matH, mat, matT, diag, 3, 3, rownnzH, rowadrH, colindH,
758769
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
759-
rowsuperT, data, 1);
770+
rowsuperT, data, diagindH);
760771

761772
EXPECT_THAT(matH, ElementsAre(14, 14, 14, 14, 14, 14, 14, 14, 14));
762773
EXPECT_THAT(colindH, ElementsAre(0, 1, 2, 0, 1, 2, 0, 1, 2));
@@ -790,6 +801,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse11) {
790801
int colindH[] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
791802
int rownnzH[] = {0, 0, 0};
792803
int rowadrH[] = {0, 0, 0};
804+
int diagindH[] = {0, 0, 0};
793805

794806
mjtNum diag[] = {1, 1, 1};
795807

@@ -804,7 +816,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse11) {
804816
mju_sqrMatTDUncompressedInit(rowadrH, 3);
805817
mju_sqrMatTDSparse(matH, mat, matT, diag, 3, 3, rownnzH, rowadrH, colindH,
806818
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
807-
rowsuperT, data, 1);
819+
rowsuperT, data, diagindH);
808820

809821
EXPECT_THAT(matH, ElementsAre(1, 1, 1, 1, 10, 10, 1, 10, 10));
810822
EXPECT_THAT(colindH, ElementsAre(0, 1, 2, 0, 1, 2, 0, 1, 2));
@@ -838,6 +850,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse12) {
838850
int colindH[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
839851
int rownnzH[] = {0, 0, 0, 0};
840852
int rowadrH[] = {0, 0, 0, 0};
853+
int diagindH[] = {0, 0, 0, 0};
841854

842855
mjtNum diag[] = {1, 1, 1};
843856

@@ -852,7 +865,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse12) {
852865
mju_sqrMatTDUncompressedInit(rowadrH, 4);
853866
mju_sqrMatTDSparse(matH, mat, matT, diag, 3, 4, rownnzH, rowadrH, colindH,
854867
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
855-
rowsuperT, data, 1);
868+
rowsuperT, data, diagindH);
856869

857870
EXPECT_THAT(matH,
858871
ElementsAre(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 10, 1, 1, 10, 10));
@@ -890,6 +903,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse13) {
890903
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
891904
int rownnzH[] = {0, 0, 0, 0, 0};
892905
int rowadrH[] = {0, 0, 0, 0, 0};
906+
int diagindH[] = {0, 0, 0, 0, 0};
893907

894908
mjtNum diag[] = {1, 1, 1};
895909

@@ -904,7 +918,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse13) {
904918
mju_sqrMatTDUncompressedInit(rowadrH, 5);
905919
mju_sqrMatTDSparse(matH, mat, matT, diag, 3, 5, rownnzH, rowadrH, colindH,
906920
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
907-
rowsuperT, data, 1);
921+
rowsuperT, data, diagindH);
908922

909923
EXPECT_THAT(matH, ElementsAre(3, 3, 0, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0,
910924
0, 0, 0, 0, 0, 0, 0, 0, 0));
@@ -942,6 +956,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse14) {
942956
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
943957
int rownnzH[] = {0, 0, 0, 0, 0, 0, 0};
944958
int rowadrH[] = {0, 0, 0, 0, 0, 0, 0};
959+
int diagindH[] = {0, 0, 0, 0, 0, 0, 0};
945960

946961
// test precount
947962
mju_sqrMatTDSparseCount(rownnzH, rowadrH, 7, rownnz, rowadr, colind,
@@ -954,7 +969,7 @@ TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse14) {
954969
mju_sqrMatTDUncompressedInit(rowadrH, 7);
955970
mju_sqrMatTDSparse(matH, mat, matT, nullptr, 1, 7, rownnzH, rowadrH, colindH,
956971
rownnz, rowadr, colind, nullptr, rownnzT, rowadrT, colindT,
957-
rowsuperT, data, 1);
972+
rowsuperT, data, diagindH);
958973

959974
EXPECT_THAT(
960975
matH, ElementsAre(1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 2,

0 commit comments

Comments
 (0)