Skip to content

Commit 3081ecf

Browse files
committed
[SYSTEMDS-3793] Fix transpose performance on ultra-sparse matrices
The multi-threaded implementation of ultra-sparse matrices has a couple of shortcomings (e.g., count column nnz, block allocation, too late fallback to single-threaded). On a large 85M x 85M graph with 90M non-zeros the transpose did not finish in hours. In this patch we now introduces a more sophisticated sparse row iterator (row and column lower/bounds) in order to facilitate a simple and fast transpose ultra sparse operation. However, this implementation was still much slower than falling back to single-threaded operations and thus use single-threaded transpose for all ultra-sparse matrices instead of if nnz < max(rows,cols). Now this operations completes in <9s.
1 parent b8b0c23 commit 3081ecf

File tree

3 files changed

+71
-18
lines changed

3 files changed

+71
-18
lines changed

src/main/java/org/apache/sysds/runtime/data/SparseBlock.java

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,22 @@ public Iterator<IJV> getIterator(int rl, int ru) {
561561
//default generic iterator, override if necessary
562562
return new SparseBlockIterator(rl, Math.min(ru,numRows()));
563563
}
564+
565+
/**
566+
* Get a non-zero iterator over the subblock [rl/cl, ru/cu). Note that
567+
* the returned IJV object is reused across next calls and should
568+
* be directly consumed or deep copied.
569+
*
570+
* @param rl inclusive lower row index starting at 0
571+
* @param ru exclusive upper row index starting at 0
572+
* @param cl inclusive lower column index starting at 0
573+
* @param cu exclusive upper column index starting at 0
574+
* @return IJV iterator
575+
*/
576+
public Iterator<IJV> getIterator(int rl, int ru, int cl, int cu) {
577+
//default generic iterator, override if necessary
578+
return new SparseBlockIterator(rl, Math.min(ru,numRows()), cl, cu);
579+
}
564580

565581
/**
566582
* Get an iterator over the indices of non-empty rows within the entire sparse block.
@@ -694,19 +710,29 @@ private class SparseBlockIterator implements Iterator<IJV>
694710
private int _curColIx = -1; //current col index pos
695711
private int[] _curIndexes = null; //current col indexes
696712
private double[] _curValues = null; //current col values
697-
private boolean _noNext = false; //end indicator
713+
private boolean _noNext = false; //end indicator
698714
private IJV retijv = new IJV(); //reuse output tuple
715+
private int _cl = 0;
716+
private int _cu = Integer.MAX_VALUE;
699717

700718
protected SparseBlockIterator(int ru) {
701719
_rlen = ru;
702720
_curRow = 0;
703-
findNextNonZeroRow();
721+
findNextNonZeroRow(0);
704722
}
705723

706724
protected SparseBlockIterator(int rl, int ru) {
707725
_rlen = ru;
708726
_curRow = rl;
709-
findNextNonZeroRow();
727+
findNextNonZeroRow(0);
728+
}
729+
730+
protected SparseBlockIterator(int rl, int ru, int cl, int cu) {
731+
_rlen = ru;
732+
_curRow = rl;
733+
_cl = cl;
734+
_cu = cu;
735+
findNextNonZeroRow(cl);
710736
}
711737

712738
@Override
@@ -717,14 +743,12 @@ public boolean hasNext() {
717743
@Override
718744
public IJV next( ) {
719745
retijv.set(_curRow, _curIndexes[_curColIx], _curValues[_curColIx]);
720-
721-
//NOTE: no preincrement on curcolix to avoid OpenJDK8 escape analysis bug, encountered
722-
//with tests SparsityRecompileTest/SparsityFunctionRecompileTest on parfor local result merge
723-
if( _curColIx < pos(_curRow)+size(_curRow)-1 )
746+
if( _curColIx < pos(_curRow)+size(_curRow)-1 && _curIndexes[_curColIx+1] < _cu ) {
724747
_curColIx++;
748+
}
725749
else {
726750
_curRow++;
727-
findNextNonZeroRow();
751+
findNextNonZeroRow(_cl);
728752
}
729753

730754
return retijv;
@@ -733,19 +757,21 @@ public IJV next( ) {
733757
@Override
734758
public void remove() {
735759
throw new RuntimeException("SparseBlockIterator is unsupported!");
736-
}
760+
}
737761

738762
/**
739763
* Moves cursor to next non-zero value or indicates that no more
740764
* values are available.
741765
*/
742-
private void findNextNonZeroRow() {
743-
while( _curRow<_rlen && isEmpty(_curRow))
766+
private void findNextNonZeroRow(int cl) {
767+
while( _curRow<_rlen && (isEmpty(_curRow)
768+
|| (cl>0 && posFIndexGTE(_curRow, cl) < 0)) )
744769
_curRow++;
745770
if(_curRow >= _rlen)
746771
_noNext = true;
747772
else {
748-
_curColIx = pos(_curRow);
773+
_curColIx = (cl==0) ?
774+
pos(_curRow) : posFIndexGTE(_curRow, cl);
749775
_curIndexes = indexes(_curRow);
750776
_curValues = values(_curRow);
751777
}

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ public static MatrixBlock transpose(MatrixBlock in, MatrixBlock out, int k, bool
234234
|| (SHALLOW_COPY_REORG && !in.sparse && !out.sparse && (in.rlen == 1 || in.clen == 1)) //
235235
|| (in.sparse && !out.sparse && in.rlen == 1) //
236236
|| (!in.sparse && out.sparse && in.rlen == 1) //
237-
|| (in.sparse && out.sparse && in.nonZeros < Math.max(in.rlen, in.clen)) // ultra-sparse
238-
) {
237+
|| (in.sparse && out.sparse && in.isUltraSparse(false)))
238+
{
239239
return transpose(in, out);
240240
}
241241
// set meta data and allocate output arrays (if required)
@@ -250,7 +250,9 @@ public static MatrixBlock transpose(MatrixBlock in, MatrixBlock out, int k, bool
250250
// Timing time = new Timing(true);
251251

252252
// CSR is only allowed in the transposed output if the number of non zeros is counted in the columns
253-
allowCSR = allowCSR && (in.clen <= 4096 || out.nonZeros < 10000000);
253+
// and the temporary count arrays are not larger than the entire input
254+
allowCSR = allowCSR && (in.clen <= 4096 || out.nonZeros < 10000000)
255+
&& (k*4*in.clen < in.getInMemorySize());
254256

255257
int[] cnt = null;
256258
final ExecutorService pool = CommonThreadPool.get(k);
@@ -276,12 +278,13 @@ else if(out.sparse)
276278
out.allocateSparseRowsBlock(false);
277279
else
278280
out.allocateDenseBlock(false);
279-
280281

281282
// compute actual transpose and check for errors
282283
ArrayList<TransposeTask> tasks = new ArrayList<>();
283-
boolean allowReturnBlock = out.sparse && in.sparse && in.rlen >= in.clen && cnt == null;
284-
boolean row = (in.sparse || in.rlen >= in.clen) && (!out.sparse || allowReturnBlock);
284+
boolean allowReturnBlock = out.sparse && in.sparse
285+
&& in.rlen >= in.clen && cnt == null && !in.isUltraSparse(false);
286+
boolean row = (in.sparse || in.rlen >= in.clen)
287+
&& (!out.sparse || allowReturnBlock) && !in.isUltraSparse(false);
285288
int len = row ? in.rlen : in.clen;
286289
int blklen = (int) (Math.ceil((double) len / k));
287290
blklen += (!out.sparse && (blklen % 8) != 0) ? 8 - blklen % 8 : 0;
@@ -1192,6 +1195,15 @@ private static void transposeUltraSparse(MatrixBlock in, MatrixBlock out) {
11921195
out.setNonZeros(in.getNonZeros());
11931196
}
11941197

1198+
private static void transposeUltraSparse(MatrixBlock in, MatrixBlock out, int rl, int ru, int cl, int cu) {
1199+
Iterator<IJV> iter = in.getSparseBlockIterator(rl, ru, cl, cu);
1200+
SparseBlock b = out.getSparseBlock();
1201+
while( iter.hasNext() ) {
1202+
IJV cell = iter.next();
1203+
b.append(cell.getJ(), cell.getI(), cell.getV());
1204+
}
1205+
}
1206+
11951207
private static void transposeSparseToSparse(MatrixBlock in, MatrixBlock out, int rl, int ru, int cl, int cu,
11961208
int[] cnt) {
11971209
// NOTE: called only in sequential or column-wise parallel execution
@@ -3861,6 +3873,8 @@ public MatrixBlock call() {
38613873
transposeDenseToDense( _in, _out, rl, ru, cl, cu );
38623874
else if( _in.sparse && _out.sparse && _out.sparseBlock instanceof SparseBlockCSR)
38633875
transposeSparseToSparseCSR(_in, _out, rl, ru, cl, cu, _cnt);
3876+
else if( _in.sparse && _out.sparse && _in.isUltraSparse(false) )
3877+
transposeUltraSparse(_in, _out, rl, ru, cl, cu);
38643878
else if( _in.sparse && _out.sparse ){
38653879
if(allowReturnBlock)
38663880
return transposeSparseToSparseBlock(_in, rl, ru);

src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,19 @@ public Iterator<IJV> getSparseBlockIterator(int rl, int ru) {
686686
//get iterator over sparse block
687687
return sparseBlock.getIterator(rl, ru);
688688
}
689+
690+
public Iterator<IJV> getSparseBlockIterator(int rl, int ru, int cl, int cu) {
691+
//check for valid format, should have been checked from outside
692+
if( !sparse )
693+
throw new RuntimeException("getSparseBlockInterator should not be called for dense format");
694+
695+
//check for existing sparse block: return empty list
696+
if( sparseBlock==null )
697+
return Collections.emptyListIterator();
698+
699+
//get iterator over sparse block
700+
return sparseBlock.getIterator(rl, ru, cl, cu);
701+
}
689702

690703
@Override
691704
public double get(int r, int c) {

0 commit comments

Comments
 (0)