Skip to content

Commit 14c59b9

Browse files
committed
fixed null handling in fused seq ctable
1 parent 407090d commit 14c59b9

File tree

2 files changed

+26
-31
lines changed

2 files changed

+26
-31
lines changed

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,24 +191,24 @@ private static Pair<Integer, Integer> partialMapping(int[] map, MatrixBlock A, i
191191

192192
int maxCol = 0;
193193
int zeros = 0;
194-
int notHandledNulls = 0;
195194
final double[] aVals = A.getDenseBlockValues();
196195

197196
for(int i = start; i < end; i++) {
198197
final double v2 = aVals[i];
199-
int colUnsafe = UtilFunctions.toInt(v2);
200-
if(colUnsafe <= 0)
198+
final int colUnsafe = UtilFunctions.toInt(v2);
199+
if(!Double.isNaN(v2) && colUnsafe < 0)
201200
throw new DMLRuntimeException(
202201
"Erroneous input while computing the contingency table (value <= zero): " + v2);
203-
boolean invalid = Double.isNaN(v2) || (maxOutCol != -1 && colUnsafe > maxOutCol);
204-
final int colSafe = invalid ? maxOutCol : colUnsafe - 1;
205-
zeros += invalid ? 1 : 0;
206-
notHandledNulls += Double.isNaN(v2) ? maxOutCol : 0;
202+
// Boolean to int conversion to avoid branch
203+
final int invalid = Double.isNaN(v2) || (maxOutCol != -1 && colUnsafe > maxOutCol) ? 1 : 0;
204+
// if invalid -> maxOutCol else -> colUnsafe - 1
205+
final int colSafe = maxOutCol*invalid + (colUnsafe - 1)*(1 - invalid);
206+
zeros += invalid;
207207
maxCol = Math.max(colUnsafe, maxCol);
208208
map[i] = colSafe;
209209
}
210210

211-
if (notHandledNulls < 0){
211+
if (maxOutCol == -1 && zeros > 0){
212212
maxCol *= -1;
213213
}
214214

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,7 +1069,7 @@ private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, d
10691069
boolean containsNull = false;
10701070
for(int i = 0; i < seqHeight; i += blkz) {
10711071
// blocked execution for earlier JIT compilation
1072-
int t = fusedSeqRexpandSparseBlock(csr, A, w, i, Math.min(i + blkz, seqHeight), (int) outCols);
1072+
int t = fusedSeqRexpandSparseBlock(csr, A, w, i, Math.min(i + blkz, seqHeight), updateClen,outCols);
10731073
if(t < 0) {
10741074
t = Math.abs(t);
10751075
containsNull = true;
@@ -1088,7 +1088,7 @@ private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, d
10881088
}
10891089

10901090
private static int fusedSeqRexpandSparseBlock(final SparseBlockCSR csr, final MatrixBlock A, final double w, int rl,
1091-
int ru, int maxOutCol) {
1091+
int ru, boolean updateClen,int maxOutCol) {
10921092

10931093
// prepare allocation of CSR sparse block
10941094
final int[] rowPointers = csr.rowPointers();
@@ -1099,11 +1099,9 @@ private static int fusedSeqRexpandSparseBlock(final SparseBlockCSR csr, final Ma
10991099
int maxCol = 0;
11001100

11011101
for(int i = rl; i < ru; i++) {
1102-
int c = rexpandSingleRow(i, A.get(i, 0), w, indexes, values, maxOutCol);
1103-
if(c < 0)
1104-
containsNull = true;
1105-
else
1106-
maxCol = Math.max(c, maxCol);
1102+
int c = rexpandSingleRow(i, A.get(i, 0), w, indexes, values, updateClen, maxOutCol);
1103+
containsNull |= c < 0;
1104+
maxCol = Math.max(c, maxCol);
11071105
rowPointers[i] = i;
11081106
}
11091107

@@ -1117,25 +1115,22 @@ private static void updateClenRexpand(MatrixBlock ret, int maxCol, boolean updat
11171115
ret.clen = maxCol;
11181116
}
11191117

1120-
public static int rexpandSingleRow(int row, double v2, double w, int[] retIx, double[] retVals, int maxOutCol) {
1121-
// If any of the values are NaN (i.e., missing) then
1122-
// we skip this tuple, proceed to the next tuple
1123-
if(Double.isNaN(v2))
1124-
return -1;
1118+
public static int rexpandSingleRow(int row, double v2, double w, int[] retIx, double[] retVals,
1119+
boolean updateClen, int maxOutCol) {
11251120

1126-
// safe casts to long for consistent behavior with indexing
1127-
int col = UtilFunctions.toInt(v2);
1128-
if(col <= 0)
1121+
final int colUnsafe = UtilFunctions.toInt(v2); // colUnsafe = 0 for Nan
1122+
int isNan = (Double.isNaN(v2) ? 1 : 0); // avoid branching by boolean to int conversion
1123+
int col = colUnsafe - isNan; // col = -1 for Nan
1124+
1125+
// use branch prediction for rare case
1126+
if(!Double.isNaN(v2) && colUnsafe <= 0)
11291127
throw new DMLRuntimeException("Erroneous input while computing the contingency table (value <= zero): " + v2);
1130-
// maxOutCol = - 1 if not specified --> TRUE
1131-
if(col <= maxOutCol){
1132-
// set weight as value (expand is guaranteed to address different cells)
1133-
retIx[row] = col - 1;
1134-
retVals[row] = w;
1135-
}
11361128

1137-
// maintain max seen col
1138-
return col;
1129+
// avoid branching again by boolean to int conversion
1130+
int valid = !Double.isNaN(v2) && (updateClen || col <= maxOutCol) ? 1 : 0;
1131+
retIx[row] = (col - 1)*valid; // use valid as switch
1132+
retVals[row] = w*valid;
1133+
return valid*col + valid - 1; // -1 if invalid else col
11391134
}
11401135

11411136
/**

0 commit comments

Comments
 (0)