@@ -1230,7 +1230,7 @@ namespace grb {
12301230 typename RIT, typename CIT, typename NIT
12311231 >
12321232 RC scale_unmasked_generic (
1233- const Matrix< IOType, reference, RIT, CIT, NIT > &A,
1233+ Matrix< IOType, reference, RIT, CIT, NIT > &A,
12341234 const InputType &x,
12351235 const Operator &op = Operator ()
12361236 ) {
@@ -1284,7 +1284,7 @@ namespace grb {
12841284 typename RIT_M, typename CIT_M, typename NIT_M
12851285 >
12861286 RC scale_masked_generic (
1287- const Matrix< IOType, reference, RIT_A, CIT_A, NIT_A > &A,
1287+ Matrix< IOType, reference, RIT_A, CIT_A, NIT_A > &A,
12881288 const Matrix< MaskType, reference, RIT_M, CIT_M, NIT_M > &mask,
12891289 const InputType &x,
12901290 const Operator &op = Operator ()
@@ -1329,28 +1329,21 @@ namespace grb {
13291329 for ( auto k = A_crs_raw.col_start [ i ]; k < A_crs_raw.col_start [ i + 1 ]; ++k ) {
13301330 auto k_col = A_crs_raw.row_index [ k ];
13311331
1332- // Increment the mask pointer until we find the right column, or an higher one
1333- while ( mask_raw.row_index [ mask_k ] < k_col && mask_k < mask_raw.col_start [ i + 1 ] ) {
1334- _DEBUG_THREADESAFE_PRINT ( " Skipping masked coordinate: ( " + std::to_string ( i ) + " ;" + std::to_string ( mask_raw.row_index [ mask_k ] ) + " )\n " );
1332+ // Increment the mask pointer until we find the right column, or a lower column (since the storage withing a row is sorted in a descending order)
1333+ while ( mask_k < mask_raw.col_start [ i + 1 ] && mask_raw.row_index [ mask_k ] > k_col ) {
1334+ _DEBUG_THREADESAFE_PRINT ( " NEquals masked coordinate: ( " + std::to_string ( i ) + " ;" + std::to_string ( mask_raw.row_index [ mask_k ] ) + " )\n " );
13351335 mask_k++;
13361336 }
1337- // if there is no value for this coordinate, skip it
1338- if ( mask_raw.row_index [ mask_k ] != k_col ) {
1339- _DEBUG_THREADESAFE_PRINT ( " Skipped masked coordinate: ( " + std::to_string ( i ) + " ;" + std::to_string ( mask_raw.row_index [ mask_k ] ) + " )\n " );
1340- continue ;
1341- }
1342-
1343- // Get mask value
1344- if ( not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
1345- _DEBUG_THREADESAFE_PRINT ( " Skipped masked value at: ( " + std::to_string ( i ) + " ;" + std::to_string ( mask_raw.row_index [ mask_k ] ) + " )\n " );
1337+
1338+ if ( mask_raw.row_index [ mask_k ] < k_col || not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
1339+ mask_k++;
1340+ _DEBUG_THREADESAFE_PRINT ( " Skip masked value at: ( " + std::to_string ( i ) + " ;" + std::to_string ( mask_raw.row_index [ mask_k ] ) + " )\n " );
13461341 continue ;
13471342 }
13481343
1349- // Increment the mask pointer in order to skip the next while loop (best case)
1350- mask_k++;
1351-
1344+ _DEBUG_THREADESAFE_PRINT ( " Found masked value at: ( " + std::to_string ( i ) + " ;" + std::to_string ( mask_raw.row_index [ mask_k ] ) + " )\n " );
13521345 // Get A value
1353- const IOType a_val_before = A_crs_raw.values [ k ];
1346+ const auto a_val_before = A_crs_raw.values [ k ];
13541347 _DEBUG_THREADESAFE_PRINT ( " A( " + std::to_string ( i ) + " ;" + std::to_string ( k_col ) + " ) = " + std::to_string ( a_val_before ) + " \n " );
13551348 // Compute the fold for this coordinate
13561349 local_rc = local_rc ? local_rc : grb::apply< descr >( A_crs_raw.values [ k ], a_val_before, x, op );
0 commit comments