Skip to content

Commit e42370c

Browse files
yuvaltassacopybara-github
authored andcommitted
Avoid allocation and copying in implicit solver's addJTBJSparse
PiperOrigin-RevId: 710679558 Change-Id: Ie75fdcd1127ff619c2668f568ab6a8a11e079dda
1 parent 8cb253b commit e42370c

File tree

3 files changed

+14
-24
lines changed

3 files changed

+14
-24
lines changed

src/engine/engine_derivative.c

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -730,11 +730,6 @@ static void addJTBJSparse(
730730
const mjModel* m, mjData* d, const mjtNum* J,
731731
const mjtNum* B, int n, int offset,
732732
const int* J_rownnz, const int* J_rowadr, const int* J_colind) {
733-
int nv = m->nv;
734-
735-
// allocate row
736-
mj_markStack(d);
737-
mjtNum* row = mjSTACKALLOC(d, nv, mjtNum);
738733

739734
// compute qDeriv(k,p) += sum_{i,j} ( J(i,k)*B(i,j)*J(j,p) )
740735
for (int i = 0; i < n; i++) {
@@ -749,19 +744,14 @@ static void addJTBJSparse(
749744
int ik = J_rowadr[offset_i] + k;
750745
int colik = J_colind[ik];
751746

752-
// row = J(i,k)*B(i,j)*J(j,:)
753-
mju_scl(row, J + J_rowadr[offset_j], J[ik]*B[i*n+j], J_rownnz[offset_j]);
754-
755-
// qDeriv(k,:) += row
756-
mju_addToSparseInc(d->qDeriv + d->D_rowadr[colik], row,
757-
d->D_rownnz[colik], d->D_colind + d->D_rowadr[colik],
758-
J_rownnz[offset_j], J_colind + J_rowadr[offset_j]);
747+
// qDeriv(k,:) += J(j,:) * J(i,k)*B(i,j)
748+
mju_addToSclSparseInc(d->qDeriv + d->D_rowadr[colik], J + J_rowadr[offset_j],
749+
d->D_rownnz[colik], d->D_colind + d->D_rowadr[colik],
750+
J_rownnz[offset_j], J_colind + J_rowadr[offset_j],
751+
J[ik]*B[i*n+j]);
759752
}
760753
}
761754
}
762-
763-
// free space
764-
mj_freeStack(d);
765755
}
766756

767757

src/engine/engine_util_sparse.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -355,10 +355,10 @@ void mju_combineSparseInc(mjtNum* dst, const mjtNum* src, int n, mjtNum a, mjtNu
355355

356356

357357

358-
// dst += src, only at common non-zero indices
359-
void mju_addToSparseInc(mjtNum* dst, const mjtNum* src,
360-
int nnzdst, const int* inddst,
361-
int nnzsrc, const int* indsrc) {
358+
// dst += scl*src, only at common non-zero indices
359+
void mju_addToSclSparseInc(mjtNum* dst, const mjtNum* src,
360+
int nnzdst, const int* inddst,
361+
int nnzsrc, const int* indsrc, mjtNum scl) {
362362
if (!nnzdst || !nnzsrc) {
363363
return;
364364
}
@@ -368,7 +368,7 @@ void mju_addToSparseInc(mjtNum* dst, const mjtNum* src,
368368
// common non-zero index
369369
if (inds == indd) {
370370
// add
371-
dst[adrd] += src[adrs];
371+
dst[adrd] += scl * src[adrs];
372372

373373
// advance src
374374
if (++adrs < nnzsrc) {

src/engine/engine_util_sparse.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ int mju_combineSparse(mjtNum* dst, const mjtNum* src, mjtNum a, mjtNum b,
6363
void mju_combineSparseInc(mjtNum* dst, const mjtNum* src, int n, mjtNum a, mjtNum b,
6464
int dst_nnz, int src_nnz, int* dst_ind, const int* src_ind);
6565

66-
// dst += src, only at common non-zero indices
67-
void mju_addToSparseInc(mjtNum* dst, const mjtNum* src,
68-
int nnzdst, const int* inddst,
69-
int nnzsrc, const int* indsrc);
66+
// dst += scl * src, only at common non-zero indices
67+
void mju_addToSclSparseInc(mjtNum* dst, const mjtNum* src,
68+
int nnzdst, const int* inddst,
69+
int nnzsrc, const int* indsrc, mjtNum scl);
7070

7171
// add to sparse matrix: dst = dst + scl*src, return nnz of result
7272
int mju_addToSparseMat(mjtNum* dst, const mjtNum* src, int n, int nrow, mjtNum scl,

0 commit comments

Comments
 (0)