Skip to content

Commit 58490d3

Browse files
committed
[SYSTEMDS-3799] Fix parfor result merge (all combinations)
This patch fixes recently discovered (see code coverage) issues of parfor result merge for combinations of different result merge implementations, dense/sparse inputs, with compare dense/sparse blocks, and most importantly += accumulation into the output.
1 parent 913edde commit 58490d3

File tree

6 files changed

+122
-123
lines changed

6 files changed

+122
-123
lines changed

src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ public abstract class ResultMerge<T extends CacheableData<?>> implements Seriali
3636
protected static final Log LOG = LogFactory.getLog(ResultMerge.class.getName());
3737
protected static final String NAME_SUFFIX = "_rm";
3838
protected static final BinaryOperator PLUS = InstructionUtils.parseBinaryOperator("+");
39+
protected static final BinaryOperator MINUS = InstructionUtils.parseBinaryOperator("-");
3940

4041
//inputs to result merge
4142
protected T _output = null;

src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalFile.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,10 @@ private void createBinaryBlockResultFile( String fnameStaging, String fnameStagi
258258
DenseBlock compare = DataConverter.convertToDenseBlock(mb, true);
259259
for( String lname : dir.list() ) {
260260
MatrixBlock tmp = LocalFileUtils.readMatrixBlockFromLocal( dir+"/"+lname );
261-
mergeWithComp(mb, tmp, compare);
261+
if( _isAccum )
262+
mergeWithoutComp(mb, tmp, compare, appendOnly);
263+
else
264+
mergeWithComp(mb, tmp, compare);
262265
}
263266

264267
//sort sparse due to append-only
@@ -279,7 +282,7 @@ private void createBinaryBlockResultFile( String fnameStaging, String fnameStagi
279282
}
280283
else {
281284
MatrixBlock tmp = LocalFileUtils.readMatrixBlockFromLocal( dir+"/"+lname );
282-
mergeWithoutComp(mb, tmp, appendOnly);
285+
mergeWithoutComp(mb, tmp, null, appendOnly);
283286
}
284287
}
285288

src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalMemory.java

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ public MatrixObject executeSerialMerge()
7373

7474
//create compare matrix if required (existing data in result)
7575
_compare = getCompareMatrix(outMB);
76-
if( _compare != null )
76+
if( _compare != null || _isAccum )
7777
outMBNew.copy(outMB);
7878

7979
//serial merge all inputs
@@ -90,7 +90,7 @@ public MatrixObject executeSerialMerge()
9090
MatrixBlock inMB = in.acquireRead();
9191

9292
//core merge
93-
merge( outMBNew, inMB, appendOnly );
93+
merge( outMBNew, inMB, _compare, appendOnly );
9494

9595
//unpin and clear in-memory input_i
9696
in.release();
@@ -169,7 +169,7 @@ public MatrixObject executeParallelMerge( int par )
169169

170170
//create compare matrix if required (existing data in result)
171171
_compare = getCompareMatrix(outMB);
172-
if( _compare != null )
172+
if( _compare != null || _isAccum )
173173
outMBNew.copy(outMB);
174174

175175
//parallel merge of all inputs
@@ -215,7 +215,7 @@ public MatrixObject executeParallelMerge( int par )
215215
return moNew;
216216
}
217217

218-
private static DenseBlock getCompareMatrix( MatrixBlock output ) {
218+
private DenseBlock getCompareMatrix( MatrixBlock output ) {
219219
//create compare matrix only if required
220220
if( !output.isEmptyBlock(false) )
221221
return DataConverter.convertToDenseBlock(output, false);
@@ -253,11 +253,12 @@ private MatrixObject createNewMatrixObject( MatrixBlock data ) {
253253
*
254254
* @param out output matrix block
255255
* @param in input matrix block
256+
* @param compare initialized output
256257
* @param appendOnly ?
257258
*/
258-
private void merge( MatrixBlock out, MatrixBlock in, boolean appendOnly ) {
259-
if( _compare == null )
260-
mergeWithoutComp(out, in, appendOnly, true);
259+
private void merge( MatrixBlock out, MatrixBlock in, DenseBlock compare, boolean appendOnly ) {
260+
if( _compare == null || _isAccum )
261+
mergeWithoutComp(out, in, _compare, appendOnly, true);
261262
else
262263
mergeWithComp(out, in, _compare);
263264
}
@@ -304,7 +305,7 @@ public void run()
304305
LOG.trace("ResultMerge (local, in-memory): Merge input "+_inMO.hashCode()+" (fname="+_inMO.getFileName()+")");
305306

306307
MatrixBlock inMB = _inMO.acquireRead(); //incl. implicit read from HDFS
307-
merge( _outMB, inMB, false );
308+
merge( _outMB, inMB, _compare, false );
308309
_inMO.release();
309310
_inMO.clearData();
310311
}

src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeMatrix.java

Lines changed: 44 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.util.List;
2323

2424
import org.apache.sysds.runtime.DMLRuntimeException;
25-
import org.apache.sysds.runtime.compress.utils.Util;
2625
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
2726
import org.apache.sysds.runtime.data.DenseBlock;
2827
import org.apache.sysds.runtime.data.SparseBlock;
@@ -58,17 +57,21 @@ public ResultMergeMatrix(MatrixObject out, MatrixObject[] in, String outputFilen
5857
super(out, in, outputFilename, accum);
5958
}
6059

61-
protected void mergeWithoutComp(MatrixBlock out, MatrixBlock in, boolean appendOnly) {
62-
mergeWithoutComp(out, in, appendOnly, false);
60+
protected void mergeWithoutComp(MatrixBlock out, MatrixBlock in, DenseBlock compare, boolean appendOnly) {
61+
mergeWithoutComp(out, in, compare, appendOnly, false);
6362
}
6463

65-
protected void mergeWithoutComp(MatrixBlock out, MatrixBlock in, boolean appendOnly, boolean par) {
64+
protected void mergeWithoutComp(MatrixBlock out, MatrixBlock in, DenseBlock compare, boolean appendOnly, boolean par) {
6665
// pass through to matrix block operations
67-
if(_isAccum)
66+
if(_isAccum) {
6867
out.binaryOperationsInPlace(PLUS, in);
68+
//compare block used for compensation here
69+
if( compare != null )
70+
out.binaryOperationsInPlace(MINUS,
71+
new MatrixBlock(out.getNumRows(),out.getNumColumns(), compare));
72+
}
6973
else {
7074
MatrixBlock out2 = out.merge(in, appendOnly, par);
71-
7275
if(out2 != out)
7376
throw new DMLRuntimeException("Failed merge need to allow returned MatrixBlock to be used");
7477
}
@@ -90,18 +93,13 @@ protected void mergeWithComp(MatrixBlock out, MatrixBlock in, DenseBlock compare
9093
// NaNs, since NaN != NaN, otherwise we would potentially overwrite results
9194
// * For the case of accumulation, we add out += (new-old) to ensure correct results
9295
// because all inputs have the old values replicated
93-
final int rows = in.getNumRows();
94-
final int cols = in.getNumColumns();
95-
if(in.isEmptyBlock(false)) {
96-
if(_isAccum)
97-
return; // nothing to do
96+
int rows = in.getNumRows();
97+
int cols = in.getNumColumns();
98+
if(in.isEmptyBlock(false))
9899
mergeWithCompEmpty(out, rows, cols, compare);
99-
}
100-
else if(in.isInSparseFormat() && _isAccum)
101-
mergeSparseAccumulative(out, in, rows, cols, compare);
102100
else if(in.isInSparseFormat())
103101
mergeSparse(out, in, rows, cols, compare);
104-
else // SPARSE/DENSE
102+
else // DENSE
105103
mergeGeneric(out, in, rows, cols, compare);
106104
}
107105

@@ -111,90 +109,62 @@ private void mergeWithCompEmpty(MatrixBlock out, int m, int n, DenseBlock compar
111109
}
112110

113111
private void mergeWithCompEmptyRow(MatrixBlock out, int m, int n, DenseBlock compare, int i) {
114-
115112
for(int j = 0; j < n; j++) {
116113
final double valOld = compare.get(i, j);
117-
if(!Util.eq(0.0, valOld)) // NaN awareness
114+
if(!equals(0.0, valOld)) // NaN awareness
118115
out.set(i, j, 0);
119116
}
120117
}
121118

122-
private void mergeSparseAccumulative(MatrixBlock out, MatrixBlock in, int m, int n, DenseBlock compare) {
123-
final SparseBlock a = in.getSparseBlock();
124-
for(int i = 0; i < m; i++) {
125-
if(a.isEmpty(i))
126-
continue;
127-
final int apos = a.pos(i);
128-
final int alen = a.size(i) + apos;
129-
final int[] aix = a.indexes(i);
130-
final double[] aval = a.values(i);
131-
mergeSparseRowAccumulative(out, apos, alen, aix, aval, compare, n, i);
132-
}
133-
}
134-
135-
private void mergeSparseRowAccumulative(MatrixBlock out, int apos, int alen, int[] aix, double[] aval,
136-
DenseBlock compare, int n, int i) {
137-
for(; apos < alen; apos++) { // inside
138-
final double valOld = compare.get(i, aix[apos]);
139-
final double valNew = aval[apos];
140-
if(!Util.eq(valNew, valOld)) { // NaN awareness
141-
double value = out.get(i, aix[apos]) + (valNew - valOld);
142-
out.set(i, aix[apos], value);
143-
}
144-
}
145-
}
146-
147119
private void mergeSparse(MatrixBlock out, MatrixBlock in, int m, int n, DenseBlock compare) {
148120
final SparseBlock a = in.getSparseBlock();
149121
for(int i = 0; i < m; i++) {
150122
if(a.isEmpty(i))
151123
mergeWithCompEmptyRow(out, m, n, compare, i);
152124
else {
153-
final int apos = a.pos(i);
154-
final int alen = a.size(i) + apos;
155-
final int[] aix = a.indexes(i);
156-
final double[] aval = a.values(i);
157-
mergeSparseRow(out, apos, alen, aix, aval, compare, n, i);
158-
}
159-
}
160-
}
161-
162-
private void mergeSparseRow(MatrixBlock out, int apos, int alen, int[] aix, double[] aval, DenseBlock compare, int n,
163-
int i) {
164-
int j = 0;
165-
for(; j < n && apos < alen; j++) { // inside
166-
final boolean aposValid = aix[apos] == j;
167-
final double valOld = compare.get(i, j);
168-
final double valNew = aix[apos] == j ? aval[apos] : 0.0;
169-
if(!Util.eq(valNew, valOld)) { // NaN awareness
170-
double value = !_isAccum ? valNew : (out.get(i, j) + (valNew - valOld));
171-
out.set(i, j, value);
172-
}
173-
if(aposValid)
174-
apos++;
175-
}
176-
for(; j < n; j++) {
177-
final double valOld = compare.get(i, j);
178-
if(valOld != 0) {
179-
double value = (out.get(i, j) - valOld);
180-
out.set(i, j, value);
125+
int apos = a.pos(i);
126+
int alen = a.size(i) + apos;
127+
int[] aix = a.indexes(i);
128+
double[] avals = a.values(i);
129+
int j = 0;
130+
for(; j < n && apos < alen; j++) { // inside
131+
boolean aposValid = (aix[apos] == j);
132+
double valOld = compare.get(i, j);
133+
double valNew = aposValid ? avals[apos] : 0.0;
134+
if(!equals(valNew, valOld)) { // NaN awareness
135+
double value = !_isAccum ? valNew : (out.get(i, j) + (valNew - valOld));
136+
out.set(i, j, value);
137+
}
138+
if(aposValid)
139+
apos++;
140+
}
141+
for(; j < n; j++) {
142+
double valOld = compare.get(i, j);
143+
if(valOld != 0) {
144+
double value = (out.get(i, j) - valOld);
145+
out.set(i, j, value);
146+
}
147+
}
181148
}
182149
}
183-
184150
}
185151

186152
private void mergeGeneric(MatrixBlock out, MatrixBlock in, int m, int n, DenseBlock compare) {
187153
for(int i = 0; i < m; i++) {
188154
for(int j = 0; j < n; j++) {
189155
final double valOld = compare.get(i, j);
190156
final double valNew = in.get(i, j); // input value
191-
if(!Util.eq(valNew, valOld)) { // NaN awareness
192-
double value = !_isAccum ? valNew : (out.get(i, j) + (valNew - valOld));
193-
out.set(i, j, value);
157+
if(!equals(valNew, valOld)) { // NaN awareness
158+
out.set(i, j, valNew);
194159
}
195160
}
196161
}
197162
}
163+
164+
private boolean equals(double valNew, double valOld) {
165+
return (valNew == valOld && !Double.isNaN(valNew)) //for changed values
166+
|| (Double.isNaN(valNew) && Double.isNaN(valOld)); //NaN awareness
167+
}
198168

199169
protected long computeNonZeros(MatrixObject out, List<MatrixObject> in) {
200170
// sum of nnz of input (worker result) - output var existing nnz

src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSparkWCompare.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,12 @@ public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<Iter
5252

5353
//merge all blocks into compare block
5454
MatrixBlock out = new MatrixBlock(cin);
55-
while( din.hasNext() )
56-
mergeWithComp(out, din.next(), compare);
55+
while( din.hasNext() ) {
56+
if( _isAccum )
57+
mergeWithoutComp(out, din.next(), compare, false);
58+
else
59+
mergeWithComp(out, din.next(), compare);
60+
}
5761

5862
//create output tuple
5963
return new Tuple2<>(new MatrixIndexes(ixin), out);

0 commit comments

Comments
 (0)