Skip to content

Commit e39a5df

Browse files
yuvaltassacopybara-github
authored andcommitted
mju_compressSparse(): add option to compress away small elements.
PiperOrigin-RevId: 726010576 Change-Id: Ib6458037a5cf9a9d82abb364b7dda7aab2f36807
1 parent faed1a1 commit e39a5df

File tree

4 files changed

+62
-14
lines changed

4 files changed

+62
-14
lines changed

src/engine/engine_util_sparse.c

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -438,22 +438,34 @@ int mju_addChains(int* res, int n, int NV1, int NV2,
438438

439439

440440

441-
// compress layout of sparse matrix
442-
void mju_compressSparse(mjtNum* mat, int nr, int nc, int* rownnz, int* rowadr, int* colind) {
443-
rowadr[0] = 0;
444-
int adr = rownnz[0];
445-
for (int r=1; r < nr; r++) {
441+
// compress sparse matrix, remove elements with abs(value) <= minval, return total non-zeros
442+
int mju_compressSparse(mjtNum* mat, int nr, int nc, int* rownnz, int* rowadr, int* colind,
443+
mjtNum minval) {
444+
int remove_small = minval >= 0;
445+
int adr = 0;
446+
for (int r=0; r < nr; r++) {
446447
// save old rowadr, record new
447-
int rowadr1 = rowadr[r];
448+
int rowadr_old = rowadr[r];
448449
rowadr[r] = adr;
449450

450-
// shift mat and mat_colind
451-
for (int adr1=rowadr1; adr1 < rowadr1+rownnz[r]; adr1++) {
452-
mat[adr] = mat[adr1];
453-
colind[adr] = colind[adr1];
451+
// shift mat and colind
452+
int nnz = 0;
453+
int end = rowadr_old + rownnz[r];
454+
for (int adr_old=rowadr_old; adr_old < end; adr_old++) {
455+
if (remove_small && mju_abs(mat[adr_old]) <= minval) {
456+
continue;
457+
}
458+
if (adr != adr_old) {
459+
mat[adr] = mat[adr_old];
460+
colind[adr] = colind[adr_old];
461+
}
454462
adr++;
463+
if (remove_small) nnz++;
455464
}
465+
if (remove_small) rownnz[r] = nnz;
456466
}
467+
468+
return rowadr[nr-1] + rownnz[nr-1];
457469
}
458470

459471

src/engine/engine_util_sparse.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ MJAPI void mju_mulMatVecSparse(mjtNum* res, const mjtNum* mat, const mjtNum* vec
5050
MJAPI void mju_mulMatTVecSparse(mjtNum* res, const mjtNum* mat, const mjtNum* vec, int nr, int nc,
5151
const int* rownnz, const int* rowadr, const int* colind);
5252

53-
// compress layout of sparse matrix
54-
MJAPI void mju_compressSparse(mjtNum* mat, int nr, int nc,
55-
int* rownnz, int* rowadr, int* colind);
53+
// compress sparse matrix, remove elements with abs(value) <= minval, return total non-zeros
54+
MJAPI int mju_compressSparse(mjtNum* mat, int nr, int nc,
55+
int* rownnz, int* rowadr, int* colind, mjtNum minval);
5656

5757
// count the number of non-zeros in the sum of two sparse vectors
5858
MJAPI int mju_combineSparseCount(int a_nnz, int b_nnz, const int* a_ind, const int* b_ind);

test/benchmark/engine_util_sparse_benchmark_test.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ void ABSL_ATTRIBUTE_NOINLINE transposeSparse_baseline(
167167
}
168168
}
169169

170-
mju_compressSparse(res, nc, nr, res_rownnz, res_rowadr, res_colind);
170+
mju_compressSparse(res, nc, nr, res_rownnz, res_rowadr, res_colind,
171+
/*minval=*/-1);
171172
}
172173

173174
int compare_baseline(const int* vec1,

test/engine/engine_util_sparse_test.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,41 @@ TEST_F(EngineUtilSparseTest, MjuTransposeNullMatrix) {
300300
EXPECT_THAT(rowadrT, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0, 0));
301301
}
302302

303+
TEST_F(EngineUtilSparseTest, MjuCompressSparse) {
304+
// sparse matrix (uncompressed with spurious values between the rows):
305+
// [[1, 0, 2]
306+
// [0, O, 3] (second zero represented)
307+
mjtNum mat[] = {1, 2, 9, 0, 3}; // spurious 9 value
308+
int colind[] = {0, 2, -1, 1, 2}; // spurious -1 index
309+
int rownnz[] = {2, 2};
310+
int rowadr[] = {0, 3};
311+
312+
mjtNum dense_expected[] = {1, 0, 2, 0, 0, 3};
313+
mjtNum dense[6];
314+
mju_sparse2dense(dense, mat, 2, 3, rownnz, rowadr, colind);
315+
EXPECT_EQ(AsVector(dense, 6), AsVector(dense_expected, 6));
316+
317+
// check that spurious values are removed
318+
int nnz = mju_compressSparse(mat, 2, 3, rownnz, rowadr, colind,
319+
/*minval=*/-1);
320+
EXPECT_EQ(nnz, 4);
321+
mju_sparse2dense(dense, mat, 2, 3, rownnz, rowadr, colind);
322+
EXPECT_EQ(AsVector(dense, 6), AsVector(dense_expected, 6));
323+
324+
// check that represented zero gets compressed aways with minval=0
325+
nnz = mju_compressSparse(mat, 2, 3, rownnz, rowadr, colind, /*minval=*/0);
326+
EXPECT_EQ(nnz, 3);
327+
mju_sparse2dense(dense, mat, 2, 3, rownnz, rowadr, colind);
328+
EXPECT_EQ(AsVector(dense, 6), AsVector(dense_expected, 6));
329+
330+
// check that 1 gets compressed aways with minval=1
331+
nnz = mju_compressSparse(mat, 2, 3, rownnz, rowadr, colind, /*minval=*/1);
332+
EXPECT_EQ(nnz, 2);
333+
mju_sparse2dense(dense, mat, 2, 3, rownnz, rowadr, colind);
334+
mjtNum dense_expected_minval1[] = {0, 0, 2, 0, 0, 3};
335+
EXPECT_EQ(AsVector(dense, 6), AsVector(dense_expected_minval1, 6));
336+
}
337+
303338
static constexpr char modelStr[] = R"(<mujoco/>)";
304339

305340
TEST_F(EngineUtilSparseTest, MjuSqrMatTDSparse1) {

0 commit comments

Comments
 (0)