Skip to content

Commit d96fa66

Browse files
committed
[SYSTEMDS-3172] Performance improvement CSC sparse block
This patch makes some simple performance improvements in order to reduce the runtime of the sparse component tests (300+s -> 30s). In detail, the runtime of specific tests improved as follows: * SparseBlockMerge: 149s -> 14.7s * SparseBlockIndexRange: 110s -> 13.4s * SparseBlockGetFirstIndex: 29s -> 1.3s
1 parent 81efec8 commit d96fa66

File tree

1 file changed

+33
-101
lines changed

1 file changed

+33
-101
lines changed

src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java

Lines changed: 33 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import java.util.ArrayList;
2929
import java.util.Arrays;
3030
import java.util.BitSet;
31-
import java.util.Comparator;
3231
import java.util.Iterator;
3332
import java.util.List;
3433

@@ -119,10 +118,8 @@ else if(sblock instanceof SparseBlockMCSC) {
119118
for(SparseRow column : columns) {
120119
int rowIdx[] = column.indexes();
121120
double vals[] = column.values();
122-
for(int i = 0; i < column.size(); i++) {
123-
_indexes[valPos + i] = rowIdx[i];
124-
_values[valPos + i] = vals[i];
125-
}
121+
System.arraycopy(rowIdx, 0, _indexes, valPos, column.size());
122+
System.arraycopy(vals, 0, _values, valPos, column.size());
126123
_ptr[ptrPos] = _ptr[ptrPos - 1] + column.size();
127124
ptrPos++;
128125
valPos += column.size();
@@ -483,8 +480,9 @@ public int size(int r) {
483480
throw new RuntimeException("Row index has to be zero or larger.");
484481

485482
int nnz = 0;
486-
for(int i = 0; i < _size; i++) {
487-
if(_indexes[i] == r)
483+
for(int c=0; c<_ptr.length-1; c++) {
484+
int ix = Arrays.binarySearch(_indexes, _ptr[c], _ptr[c+1], r);
485+
if(ix >= 0)
488486
nnz++;
489487
}
490488
return nnz;
@@ -531,12 +529,13 @@ public long size(int rl, int ru, int cl, int cu) {
531529

532530
@Override
533531
public boolean isEmpty(int r) {
534-
boolean empty = true;
535-
for(int i = 0; i < _size; i++) {
536-
if(_indexes[i] == r)
532+
int clen = numCols();
533+
for(int c=0; c<clen; c++) {
534+
int ix = Arrays.binarySearch(_indexes, _ptr[c], _ptr[c+1], r);
535+
if(ix >= 0)
537536
return false;
538537
}
539-
return empty;
538+
return true;
540539
}
541540

542541
public boolean isEmptyCol(int c) {
@@ -609,26 +608,15 @@ public long getExactSizeInMemory() {
609608

610609
@Override
611610
public int[] indexes(int r) {
612-
//Count elements per row
613-
//int[] rowCounts = numElemPerRow();
614-
615-
// Compute csr pointers
616-
int[] csrPtr = rowPointerAll();
617-
618-
// Populate CSR indices array
619-
int[] csrIndices = new int[_size];
620-
// Temporary array to keep track of the current position in each row
621-
int[] currentPos = Arrays.copyOf(csrPtr, _rlen);
622-
623-
for(int col = 0; col < numCols(); col++) {
624-
for(int i = _ptr[col]; i < _ptr[col + 1]; i++) {
625-
int row = _indexes[i];
626-
int pos = currentPos[row]++;
627-
csrIndices[pos] = col;
628-
}
611+
int clen = numCols();
612+
int[] cix = new int[clen];
613+
int pos = 0;
614+
for(int c = 0; c < clen; c++) {
615+
int ix = Arrays.binarySearch(_indexes, _ptr[c], _ptr[c+1], r);
616+
if(ix >= 0)
617+
cix[pos++] = c;
629618
}
630-
631-
return csrIndices;
619+
return cix;
632620
}
633621

634622
public int[] indexesCol(int c) {
@@ -637,20 +625,15 @@ public int[] indexesCol(int c) {
637625

638626
@Override
639627
public double[] values(int r) {
640-
// Only use first _size elements for sorting
641-
Integer[] idx = new Integer[_size];
642-
for(int i = 0; i < _size; i++)
643-
idx[i] = i;
644-
645-
// Sort indices based on corresponding index values
646-
Arrays.sort(idx, Comparator.comparingInt(i -> _indexes[i]));
647-
648-
// Create values array sorted in row order
649-
double[] csrValues = new double[_size];
650-
for(int i = 0; i < _size; i++) {
651-
csrValues[i] = _values[idx[i]];
628+
int clen = numCols();
629+
double[] vals = new double[clen];
630+
int pos = 0;
631+
for(int c = 0; c < clen; c++) {
632+
int ix = Arrays.binarySearch(_indexes, _ptr[c], _ptr[c+1], r);
633+
if(ix >= 0)
634+
vals[pos++] = _values[ix];
652635
}
653-
return csrValues;
636+
return vals;
654637
}
655638

656639
public double[] valuesCol(int c) {
@@ -659,12 +642,7 @@ public double[] valuesCol(int c) {
659642

660643
@Override
661644
public int pos(int r) {
662-
int nnz = 0;
663-
for(int i = 0; i < _size; i++) {
664-
if(_indexes[i] < r)
665-
nnz++;
666-
}
667-
return nnz;
645+
return 0;
668646
}
669647

670648
public int posCol(int c) {
@@ -787,7 +765,6 @@ public void append(int r, int c, double v) {
787765
shiftRightAndInsert(pos + len, r, v);
788766
}
789767
incrPtr(c + 1);
790-
791768
}
792769

793770
@Override
@@ -1005,27 +982,12 @@ public double get(int r, int c) {
1005982

1006983
@Override
1007984
public SparseRow get(int r) {
1008-
int rowSize = size(r);
1009-
if(rowSize == 0)
1010-
return new SparseRowScalar();
1011-
1012-
//Create sparse row
1013-
SparseRowVector row = new SparseRowVector(rowSize);
1014-
1015-
for(int i = 0; i < _size; i++) {
1016-
if(_indexes[i] == r) {
1017-
//Search for index i in pointer array
1018-
for(int j = 0; j < _ptr.length; j++) {
1019-
// two possible cases
1020-
if(_ptr[j] < i && _ptr[j + 1] > i) {
1021-
row.set(j, _values[i]);
1022-
}
1023-
else if(_ptr[j] == i && _ptr[j + 1] > i) {
1024-
row.set(j, _values[i]);
1025-
break;
1026-
}
1027-
}
1028-
}
985+
int clen = numCols();
986+
SparseRowVector row = new SparseRowVector(clen);
987+
for(int c = 0; c < clen; c++) {
988+
int ix = Arrays.binarySearch(_indexes, _ptr[c], _ptr[c+1], r);
989+
if(ix >= 0)
990+
row.append(c, _values[ix]);
1029991
}
1030992
return row;
1031993
}
@@ -1329,16 +1291,6 @@ private void incrPtr(int cl, int cnt) {
13291291
_ptr[i] += cnt;
13301292
}
13311293

1332-
private void incrRowPtr(int rl, int[] csrPtr) {
1333-
incrRowPtr(rl, csrPtr, 1);
1334-
}
1335-
1336-
private void incrRowPtr(int rl, int[] csrPtr, int cnt) {
1337-
for(int i = rl; i < csrPtr.length; i++) {
1338-
csrPtr[i] += cnt;
1339-
}
1340-
}
1341-
13421294
private void decrPtr(int cl) {
13431295
decrPtr(cl, 1);
13441296
}
@@ -1348,26 +1300,6 @@ private void decrPtr(int cl, int cnt) {
13481300
_ptr[i] -= cnt;
13491301
}
13501302

1351-
@SuppressWarnings("unused")
1352-
private int[] numElemPerRow() {
1353-
int rlen = numRows();
1354-
int[] rowCount = new int[rlen];
1355-
for(int i = 0; i < _size; i++) {
1356-
rowCount[_indexes[i]] += 1;
1357-
}
1358-
return rowCount;
1359-
}
1360-
1361-
private int[] rowPointerAll() {
1362-
int rlen = numRows();
1363-
int[] csrPtr = new int[rlen + 1];
1364-
csrPtr[0] = 0;
1365-
for(int i = 0; i < _size; i++)
1366-
incrRowPtr(_indexes[i] + 1, csrPtr);
1367-
1368-
return csrPtr;
1369-
}
1370-
13711303
private int internPosFIndexLTECol(int r, int c) {
13721304
int pos = posCol(c);
13731305
int len = sizeCol(c);

0 commit comments

Comments
 (0)