Skip to content

Commit 6c880eb

Browse files
yuvaltassacopybara-github
authored andcommitted
Clean up mju_sqrMatTDSparse
PiperOrigin-RevId: 711486739 Change-Id: I5f9145dcd8522a9cb372cbdf60611d7d05cb524e
1 parent 4510c6d commit 6c880eb

File tree

1 file changed

+34
-20
lines changed

1 file changed

+34
-20
lines changed

src/engine/engine_util_sparse.c

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,8 @@ void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT,
742742
int* markers = mjSTACKALLOC(d, nc, int);
743743

744744
for (int i=0; i < nc; i++) {
745-
int* cols = res_colind+res_rowadr[i];
745+
int rowadr_i = res_rowadr[i];
746+
int* cols = res_colind + rowadr_i;
746747

747748
res_rownnz[i] = 0;
748749
buffer[i] = 0;
@@ -755,18 +756,26 @@ void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT,
755756
}
756757

757758
// iterate through each row of M'
758-
int end = rowadrT[i] + rownnzT[i];
759-
for (int r = rowadrT[i]; r < end; r++) {
759+
int adrT = rowadrT[i];
760+
int end_r = adrT + rownnzT[i];
761+
for (int r = adrT; r < end_r; r++) {
760762
int t = colindT[r];
761-
mjtNum v = diag ? matT[r] * diag[t] : matT[r];
762-
for (int c=rowadr[t]; c < rowadr[t]+rownnz[t]; c++) {
763+
int adr = rowadr[t];
764+
int end_c = adr + rownnz[t];
765+
for (int c=adr; c < end_c; c++) {
763766
int cc = colind[c];
767+
764768
// ignore upper triangle
765769
if (cc > i) {
766770
break;
767771
}
768772

769-
buffer[cc] += v*mat[c];
773+
// add value to buffer
774+
if (diag) {
775+
buffer[cc] += matT[r] * diag[t] * mat[c];
776+
} else {
777+
buffer[cc] += matT[r] * mat[c];
778+
}
770779

771780
// only need to insert nnz if not marked
772781
if (!markers[cc]) {
@@ -810,31 +819,36 @@ void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT,
810819
}
811820
}
812821

813-
end = res_rownnz[i];
822+
end_r = res_rownnz[i];
814823

815824
// rowsuperT: reuse sparsity, copy into res
816825
if (rowsuperT && rowsuperT[i]) {
817-
for (int r=0; r < end; r++) {
818-
res[res_rowadr[i] + r] = buffer[cols[r]];
819-
buffer[cols[r]] = 0;
826+
for (int r=0; r < end_r; r++) {
827+
int c = cols[r];
828+
res[rowadr_i + r] = buffer[c];
829+
buffer[c] = 0;
820830
}
821-
} else {
822-
// clear out buffers since sparsity cannot be reused
823-
for (int r=0; r < end; r++) {
824-
int cc = cols[r];
825-
res[res_rowadr[i] + r] = buffer[cc];
826-
res_colind[res_rowadr[i] + r] = cc;
827-
buffer[cc] = 0;
828-
markers[cc] = 0;
831+
}
832+
833+
// clear out buffers, sparsity cannot be reused
834+
else {
835+
for (int r=0; r < end_r; r++) {
836+
int c = cols[r];
837+
int adr = rowadr_i + r;
838+
res[adr] = buffer[c];
839+
res_colind[adr] = c;
840+
buffer[c] = 0;
841+
markers[c] = 0;
829842
}
830843
}
831844
}
832845

833846

834847
// fill upper triangle
835848
for (int i=0; i < nc; i++) {
836-
int end = res_rowadr[i] + res_rownnz[i] - 1;
837-
for (int j=res_rowadr[i]; j < end; j++) {
849+
int start = res_rowadr[i];
850+
int end = start + res_rownnz[i] - 1;
851+
for (int j=start; j < end; j++) {
838852
int adr = res_rowadr[res_colind[j]] + res_rownnz[res_colind[j]]++;
839853
res[adr] = res[j];
840854
res_colind[adr] = i;

0 commit comments

Comments
 (0)