2222import java .util .List ;
2323
2424import org .apache .sysds .runtime .DMLRuntimeException ;
25- import org .apache .sysds .runtime .compress .utils .Util ;
2625import org .apache .sysds .runtime .controlprogram .caching .MatrixObject ;
2726import org .apache .sysds .runtime .data .DenseBlock ;
2827import org .apache .sysds .runtime .data .SparseBlock ;
@@ -58,17 +57,21 @@ public ResultMergeMatrix(MatrixObject out, MatrixObject[] in, String outputFilen
5857 super (out , in , outputFilename , accum );
5958 }
6059
61- protected void mergeWithoutComp (MatrixBlock out , MatrixBlock in , boolean appendOnly ) {
62- mergeWithoutComp (out , in , appendOnly , false );
60+ protected void mergeWithoutComp (MatrixBlock out , MatrixBlock in , DenseBlock compare , boolean appendOnly ) {
61+ mergeWithoutComp (out , in , compare , appendOnly , false );
6362 }
6463
65- protected void mergeWithoutComp (MatrixBlock out , MatrixBlock in , boolean appendOnly , boolean par ) {
64+ protected void mergeWithoutComp (MatrixBlock out , MatrixBlock in , DenseBlock compare , boolean appendOnly , boolean par ) {
6665 // pass through to matrix block operations
67- if (_isAccum )
66+ if (_isAccum ) {
6867 out .binaryOperationsInPlace (PLUS , in );
68+ //compare block used for compensation here
69+ if ( compare != null )
70+ out .binaryOperationsInPlace (MINUS ,
71+ new MatrixBlock (out .getNumRows (),out .getNumColumns (), compare ));
72+ }
6973 else {
7074 MatrixBlock out2 = out .merge (in , appendOnly , par );
71-
7275 if (out2 != out )
7376 throw new DMLRuntimeException ("Failed merge need to allow returned MatrixBlock to be used" );
7477 }
@@ -90,18 +93,13 @@ protected void mergeWithComp(MatrixBlock out, MatrixBlock in, DenseBlock compare
9093 // NaNs, since NaN != NaN, otherwise we would potentially overwrite results
9194 // * For the case of accumulation, we add out += (new-old) to ensure correct results
9295 // because all inputs have the old values replicated
93- final int rows = in .getNumRows ();
94- final int cols = in .getNumColumns ();
95- if (in .isEmptyBlock (false )) {
96- if (_isAccum )
97- return ; // nothing to do
96+ int rows = in .getNumRows ();
97+ int cols = in .getNumColumns ();
98+ if (in .isEmptyBlock (false ))
9899 mergeWithCompEmpty (out , rows , cols , compare );
99- }
100- else if (in .isInSparseFormat () && _isAccum )
101- mergeSparseAccumulative (out , in , rows , cols , compare );
102100 else if (in .isInSparseFormat ())
103101 mergeSparse (out , in , rows , cols , compare );
104- else // SPARSE/ DENSE
102+ else // DENSE
105103 mergeGeneric (out , in , rows , cols , compare );
106104 }
107105
@@ -111,90 +109,62 @@ private void mergeWithCompEmpty(MatrixBlock out, int m, int n, DenseBlock compar
111109 }
112110
113111 private void mergeWithCompEmptyRow (MatrixBlock out , int m , int n , DenseBlock compare , int i ) {
114-
115112 for (int j = 0 ; j < n ; j ++) {
116113 final double valOld = compare .get (i , j );
117- if (!Util . eq (0.0 , valOld )) // NaN awareness
114+ if (!equals (0.0 , valOld )) // NaN awareness
118115 out .set (i , j , 0 );
119116 }
120117 }
121118
122- private void mergeSparseAccumulative (MatrixBlock out , MatrixBlock in , int m , int n , DenseBlock compare ) {
123- final SparseBlock a = in .getSparseBlock ();
124- for (int i = 0 ; i < m ; i ++) {
125- if (a .isEmpty (i ))
126- continue ;
127- final int apos = a .pos (i );
128- final int alen = a .size (i ) + apos ;
129- final int [] aix = a .indexes (i );
130- final double [] aval = a .values (i );
131- mergeSparseRowAccumulative (out , apos , alen , aix , aval , compare , n , i );
132- }
133- }
134-
135- private void mergeSparseRowAccumulative (MatrixBlock out , int apos , int alen , int [] aix , double [] aval ,
136- DenseBlock compare , int n , int i ) {
137- for (; apos < alen ; apos ++) { // inside
138- final double valOld = compare .get (i , aix [apos ]);
139- final double valNew = aval [apos ];
140- if (!Util .eq (valNew , valOld )) { // NaN awareness
141- double value = out .get (i , aix [apos ]) + (valNew - valOld );
142- out .set (i , aix [apos ], value );
143- }
144- }
145- }
146-
147119 private void mergeSparse (MatrixBlock out , MatrixBlock in , int m , int n , DenseBlock compare ) {
148120 final SparseBlock a = in .getSparseBlock ();
149121 for (int i = 0 ; i < m ; i ++) {
150122 if (a .isEmpty (i ))
151123 mergeWithCompEmptyRow (out , m , n , compare , i );
152124 else {
153- final int apos = a .pos (i );
154- final int alen = a .size (i ) + apos ;
155- final int [] aix = a .indexes (i );
156- final double [] aval = a .values (i );
157- mergeSparseRow (out , apos , alen , aix , aval , compare , n , i );
158- }
159- }
160- }
161-
162- private void mergeSparseRow (MatrixBlock out , int apos , int alen , int [] aix , double [] aval , DenseBlock compare , int n ,
163- int i ) {
164- int j = 0 ;
165- for (; j < n && apos < alen ; j ++) { // inside
166- final boolean aposValid = aix [apos ] == j ;
167- final double valOld = compare .get (i , j );
168- final double valNew = aix [apos ] == j ? aval [apos ] : 0.0 ;
169- if (!Util .eq (valNew , valOld )) { // NaN awareness
170- double value = !_isAccum ? valNew : (out .get (i , j ) + (valNew - valOld ));
171- out .set (i , j , value );
172- }
173- if (aposValid )
174- apos ++;
175- }
176- for (; j < n ; j ++) {
177- final double valOld = compare .get (i , j );
178- if (valOld != 0 ) {
179- double value = (out .get (i , j ) - valOld );
180- out .set (i , j , value );
125+ int apos = a .pos (i );
126+ int alen = a .size (i ) + apos ;
127+ int [] aix = a .indexes (i );
128+ double [] avals = a .values (i );
129+ int j = 0 ;
130+ for (; j < n && apos < alen ; j ++) { // inside
131+ boolean aposValid = (aix [apos ] == j );
132+ double valOld = compare .get (i , j );
133+ double valNew = aposValid ? avals [apos ] : 0.0 ;
134+ if (!equals (valNew , valOld )) { // NaN awareness
135+ double value = !_isAccum ? valNew : (out .get (i , j ) + (valNew - valOld ));
136+ out .set (i , j , value );
137+ }
138+ if (aposValid )
139+ apos ++;
140+ }
141+ for (; j < n ; j ++) {
142+ double valOld = compare .get (i , j );
143+ if (valOld != 0 ) {
144+ double value = (out .get (i , j ) - valOld );
145+ out .set (i , j , value );
146+ }
147+ }
181148 }
182149 }
183-
184150 }
185151
186152 private void mergeGeneric (MatrixBlock out , MatrixBlock in , int m , int n , DenseBlock compare ) {
187153 for (int i = 0 ; i < m ; i ++) {
188154 for (int j = 0 ; j < n ; j ++) {
189155 final double valOld = compare .get (i , j );
190156 final double valNew = in .get (i , j ); // input value
191- if (!Util .eq (valNew , valOld )) { // NaN awareness
192- double value = !_isAccum ? valNew : (out .get (i , j ) + (valNew - valOld ));
193- out .set (i , j , value );
157+ if (!equals (valNew , valOld )) { // NaN awareness
158+ out .set (i , j , valNew );
194159 }
195160 }
196161 }
197162 }
163+
164+ private boolean equals (double valNew , double valOld ) {
165+ return (valNew == valOld && !Double .isNaN (valNew )) //for changed values
166+ || (Double .isNaN (valNew ) && Double .isNaN (valOld )); //NaN awareness
167+ }
198168
199169 protected long computeNonZeros (MatrixObject out , List <MatrixObject > in ) {
200170 // sum of nnz of input (worker result) - output var existing nnz
0 commit comments