Skip to content

Commit fa8f2ff

Browse files
committed
Fix masked iteration pattern
1 parent 11adcdf commit fa8f2ff

File tree

1 file changed

+11
-18
lines changed

1 file changed

+11
-18
lines changed

include/graphblas/reference/blas3.hpp

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,7 +1230,7 @@ namespace grb {
12301230
typename RIT, typename CIT, typename NIT
12311231
>
12321232
RC scale_unmasked_generic(
1233-
const Matrix< IOType, reference, RIT, CIT, NIT > &A,
1233+
Matrix< IOType, reference, RIT, CIT, NIT > &A,
12341234
const InputType &x,
12351235
const Operator &op = Operator()
12361236
) {
@@ -1284,7 +1284,7 @@ namespace grb {
12841284
typename RIT_M, typename CIT_M, typename NIT_M
12851285
>
12861286
RC scale_masked_generic(
1287-
const Matrix< IOType, reference, RIT_A, CIT_A, NIT_A > &A,
1287+
Matrix< IOType, reference, RIT_A, CIT_A, NIT_A > &A,
12881288
const Matrix< MaskType, reference, RIT_M, CIT_M, NIT_M > &mask,
12891289
const InputType &x,
12901290
const Operator &op = Operator()
@@ -1329,28 +1329,21 @@ namespace grb {
13291329
for( auto k = A_crs_raw.col_start[ i ]; k < A_crs_raw.col_start[ i + 1 ]; ++k ) {
13301330
auto k_col = A_crs_raw.row_index[ k ];
13311331

1332-
// Increment the mask pointer until we find the right column, or an higher one
1333-
while( mask_raw.row_index[ mask_k ] < k_col && mask_k < mask_raw.col_start[ i + 1 ] ) {
1334-
_DEBUG_THREADESAFE_PRINT( "Skipping masked coordinate: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
1332+
// Increment the mask pointer until we find the right column, or a lower column (since the storage within a row is sorted in descending order)
1333+
while( mask_k < mask_raw.col_start[ i + 1 ] && mask_raw.row_index[ mask_k ] > k_col ) {
1334+
_DEBUG_THREADESAFE_PRINT( "NEquals masked coordinate: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
13351335
mask_k++;
13361336
}
1337-
// if there is no value for this coordinate, skip it
1338-
if( mask_raw.row_index[ mask_k ] != k_col ) {
1339-
_DEBUG_THREADESAFE_PRINT( "Skipped masked coordinate: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
1340-
continue;
1341-
}
1342-
1343-
// Get mask value
1344-
if( not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
1345-
_DEBUG_THREADESAFE_PRINT( "Skipped masked value at: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
1337+
1338+
if( mask_raw.row_index[ mask_k ] < k_col || not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
1339+
mask_k++;
1340+
_DEBUG_THREADESAFE_PRINT( "Skip masked value at: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
13461341
continue;
13471342
}
13481343

1349-
// Increment the mask pointer in order to skip the next while loop (best case)
1350-
mask_k++;
1351-
1344+
_DEBUG_THREADESAFE_PRINT( "Found masked value at: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
13521345
// Get A value
1353-
const IOType a_val_before = A_crs_raw.values[ k ];
1346+
const auto a_val_before = A_crs_raw.values[ k ];
13541347
_DEBUG_THREADESAFE_PRINT( "A( " + std::to_string( i ) + ";" + std::to_string( k_col ) + " ) = " + std::to_string( a_val_before ) + "\n" );
13551348
// Compute the fold for this coordinate
13561349
local_rc = local_rc ? local_rc : grb::apply< descr >( A_crs_raw.values[ k ], a_val_before, x, op );

0 commit comments

Comments
 (0)