@@ -1069,7 +1069,7 @@ private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, d
10691069 boolean containsNull = false ;
10701070 for (int i = 0 ; i < seqHeight ; i += blkz ) {
10711071 // blocked execution for earlier JIT compilation
1072- int t = fusedSeqRexpandSparseBlock (csr , A , w , i , Math .min (i + blkz , seqHeight ), ( int ) outCols );
1072+ int t = fusedSeqRexpandSparseBlock (csr , A , w , i , Math .min (i + blkz , seqHeight ), updateClen , outCols );
10731073 if (t < 0 ) {
10741074 t = Math .abs (t );
10751075 containsNull = true ;
@@ -1088,7 +1088,7 @@ private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, d
10881088 }
10891089
10901090 private static int fusedSeqRexpandSparseBlock (final SparseBlockCSR csr , final MatrixBlock A , final double w , int rl ,
1091- int ru , int maxOutCol ) {
1091+ int ru , boolean updateClen , int maxOutCol ) {
10921092
10931093 // prepare allocation of CSR sparse block
10941094 final int [] rowPointers = csr .rowPointers ();
@@ -1099,11 +1099,9 @@ private static int fusedSeqRexpandSparseBlock(final SparseBlockCSR csr, final Ma
10991099 int maxCol = 0 ;
11001100
11011101 for (int i = rl ; i < ru ; i ++) {
1102- int c = rexpandSingleRow (i , A .get (i , 0 ), w , indexes , values , maxOutCol );
1103- if (c < 0 )
1104- containsNull = true ;
1105- else
1106- maxCol = Math .max (c , maxCol );
1102+ int c = rexpandSingleRow (i , A .get (i , 0 ), w , indexes , values , updateClen , maxOutCol );
1103+ containsNull |= c < 0 ;
1104+ maxCol = Math .max (c , maxCol );
11071105 rowPointers [i ] = i ;
11081106 }
11091107
@@ -1117,25 +1115,22 @@ private static void updateClenRexpand(MatrixBlock ret, int maxCol, boolean updat
11171115 ret .clen = maxCol ;
11181116 }
11191117
1120- public static int rexpandSingleRow (int row , double v2 , double w , int [] retIx , double [] retVals , int maxOutCol ) {
1121- // If any of the values are NaN (i.e., missing) then
1122- // we skip this tuple, proceed to the next tuple
1123- if (Double .isNaN (v2 ))
1124- return -1 ;
1118+ public static int rexpandSingleRow (int row , double v2 , double w , int [] retIx , double [] retVals ,
1119+ boolean updateClen , int maxOutCol ) {
11251120
1126- // safe casts to long for consistent behavior with indexing
1127- int col = UtilFunctions .toInt (v2 );
1128- if (col <= 0 )
1121+ final int colUnsafe = UtilFunctions .toInt (v2 ); // colUnsafe = 0 for Nan
1122+ int isNan = (Double .isNaN (v2 ) ? 1 : 0 ); // avoid branching by boolean to int conversion
1123+ int col = colUnsafe - isNan ; // col = -1 for Nan
1124+
1125+ // use branch prediction for rare case
1126+ if (!Double .isNaN (v2 ) && colUnsafe <= 0 )
11291127 throw new DMLRuntimeException ("Erroneous input while computing the contingency table (value <= zero): " + v2 );
1130- // maxOutCol = - 1 if not specified --> TRUE
1131- if (col <= maxOutCol ){
1132- // set weight as value (expand is guaranteed to address different cells)
1133- retIx [row ] = col - 1 ;
1134- retVals [row ] = w ;
1135- }
11361128
1137- // maintain max seen col
1138- return col ;
1129+ // avoid branching again by boolean to int conversion
1130+ int valid = !Double .isNaN (v2 ) && (updateClen || col <= maxOutCol ) ? 1 : 0 ;
1131+ retIx [row ] = (col - 1 )*valid ; // use valid as switch
1132+ retVals [row ] = w *valid ;
1133+ return valid *col + valid - 1 ; // -1 if invalid else col
11391134 }
11401135
11411136 /**
0 commit comments