Skip to content

Commit ae9fb64

Browse files
committed
Add safe-guard for masked version
1 parent 89c40e2 commit ae9fb64

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

include/graphblas/reference/blas3.hpp

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,44 +1372,42 @@ namespace grb {
13721372
auto B_k = B_raw.col_start[ i ];
13731373
auto mask_k = mask_raw.col_start[ i ];
13741374
for( auto k = A_crs_raw.col_start[ i ]; k < A_crs_raw.col_start[ i + 1 ]; ++k ) {
1375-
auto k_col = A_crs_raw.row_index[ k ];
1375+
auto j = A_crs_raw.row_index[ k ];
13761376

13771377
// Increment the pointer of mask until we find the right column, or a lower column (since the storage withing a row is sorted in a descending order)
1378-
while( mask_k < mask_raw.col_start[ i + 1 ] && mask_raw.row_index[ mask_k ] > k_col ) {
1378+
while( mask_k < mask_raw.col_start[ i + 1 ] && mask_raw.row_index[ mask_k ] > j ) {
13791379
_DEBUG_THREADESAFE_PRINT( "NEquals MASK coordinate: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
13801380
mask_k++;
13811381
}
1382-
1383-
if( mask_raw.row_index[ B_k ] != k_col ) {
1384-
mask_k++;
1385-
_DEBUG_THREADESAFE_PRINT( "Skip MASK value at: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ B_k ] ) + " )\n" );
1386-
continue;
1382+
if( mask_k >= mask_raw.col_start[ i + 1 ] ) {
1383+
_DEBUG_THREADESAFE_PRINT( "Not value left in mask for this column\n" );
1384+
break;
13871385
}
1388-
1389-
if( not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
1386+
1387+
if( mask_raw.row_index[ B_k ] != j || not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
13901388
_DEBUG_THREADESAFE_PRINT( "Skip MASK value at: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ B_k ] ) + " )\n" );
13911389
continue;
13921390
}
13931391

13941392
// Increment the pointer of B until we find the right column, or a lower column (since the storage withing a row is sorted in a descending order)
1395-
while( B_k < B_raw.col_start[ i + 1 ] && B_raw.row_index[ B_k ] > k_col ) {
1393+
while( B_k < B_raw.col_start[ i + 1 ] && B_raw.row_index[ B_k ] > j ) {
13961394
_DEBUG_THREADESAFE_PRINT( "NEquals B coordinate: ( " + std::to_string( i ) + ";" + std::to_string( B_raw.row_index[ B_k ] ) + " )\n" );
13971395
B_k++;
13981396
}
13991397

14001398
// Get B value (or identity if not found)
14011399
auto B_val = B_identity;
1402-
if( B_k < B_raw.col_start[ i + 1 ] && B_raw.row_index[ B_k ] == k_col ) {
1400+
if( B_k < B_raw.col_start[ i + 1 ] && B_raw.row_index[ B_k ] == j ) {
14031401
_DEBUG_THREADESAFE_PRINT( "Found B value at: ( " + std::to_string( i ) + ";" + std::to_string( B_raw.row_index[ B_k ] ) + " )\n" );
14041402
B_val = B_raw.values[ B_k ];
14051403
B_k++;
14061404
} else {
1407-
_DEBUG_THREADESAFE_PRINT( "Not found B, using identity: ( " + std::to_string( i ) + ";" + std::to_string( k_col ) + " ) = " + std::to_string( B_val ) + "\n" );
1405+
_DEBUG_THREADESAFE_PRINT( "Not found B, using identity: ( " + std::to_string( i ) + ";" + std::to_string( j ) + " ) = " + std::to_string( B_val ) + "\n" );
14081406
}
14091407

14101408
// Get A value
14111409
const auto a_val_before = A_crs_raw.values[ k ];
1412-
_DEBUG_THREADESAFE_PRINT( "A( " + std::to_string( i ) + ";" + std::to_string( k_col ) + " ) = " + std::to_string( a_val_before ) + "\n" );
1410+
_DEBUG_THREADESAFE_PRINT( "A( " + std::to_string( i ) + ";" + std::to_string( j ) + " ) = " + std::to_string( a_val_before ) + "\n" );
14131411
// Compute the fold for this coordinate
14141412
local_rc = local_rc ? local_rc : grb::apply< descr >( A_crs_raw.values[ k ], a_val_before, B_val, op );
14151413
local_rc = local_rc ? local_rc : grb::apply< descr >( A_ccs_raw.values[ k ], a_val_before, B_val, op );

0 commit comments

Comments
 (0)