Skip to content

Commit 7e34ba3

Browse files
committed
Add safe-guard for masked version
1 parent 89c40e2 commit 7e34ba3

File tree

1 file changed

+19
-18
lines changed

1 file changed

+19
-18
lines changed

include/graphblas/reference/blas3.hpp

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,16 +1274,19 @@ namespace grb {
12741274
for( auto i = start_row; i < end_row; ++i ) {
12751275
auto B_k = B_raw.col_start[ i ];
12761276
for( auto k = A_crs_raw.col_start[ i ]; k < A_crs_raw.col_start[ i + 1 ]; ++k ) {
1277-
auto k_col = A_crs_raw.row_index[ k ];
1277+
const auto j = A_crs_raw.row_index[ k ];
12781278

12791279
// Increment the mask pointer until we find the right column, or a lower column (since the storage withing a row is sorted in a descending order)
1280-
while( B_k < B_raw.col_start[ i + 1 ] && B_raw.row_index[ B_k ] > k_col ) {
1280+
while( B_k < B_raw.col_start[ i + 1 ] && B_raw.row_index[ B_k ] > j ) {
12811281
_DEBUG_THREADESAFE_PRINT( "NEquals B coordinate: ( " + std::to_string( i ) + ";" + std::to_string( B_raw.row_index[ B_k ] ) + " )\n" );
12821282
B_k++;
12831283
}
1284+
if( B_k >= B_raw.col_start[ i + 1 ] ) {
1285+
_DEBUG_THREADESAFE_PRINT( "Not value left in B for this column\n" );
1286+
break;
1287+
}
12841288

1285-
if( B_raw.row_index[ B_k ] < k_col ) {
1286-
B_k++;
1289+
if( B_raw.row_index[ B_k ] != j ) {
12871290
_DEBUG_THREADESAFE_PRINT( "Skip B value at: ( " + std::to_string( i ) + ";" + std::to_string( B_raw.row_index[ B_k ] ) + " )\n" );
12881291
continue;
12891292
}
@@ -1293,7 +1296,7 @@ namespace grb {
12931296
_DEBUG_THREADESAFE_PRINT( "B( " + std::to_string( i ) + ";" + std::to_string( B_raw.row_index[ B_k ] ) + " ) = " + std::to_string( B_val ) + "\n" );
12941297
// Get A value
12951298
const auto a_val_before = A_crs_raw.values[ k ];
1296-
_DEBUG_THREADESAFE_PRINT( "A( " + std::to_string( i ) + ";" + std::to_string( k_col ) + " ) = " + std::to_string( a_val_before ) + "\n" );
1299+
_DEBUG_THREADESAFE_PRINT( "A( " + std::to_string( i ) + ";" + std::to_string( j ) + " ) = " + std::to_string( a_val_before ) + "\n" );
12971300
// Compute the fold for this coordinate
12981301
local_rc = local_rc ? local_rc : grb::apply< descr >( A_crs_raw.values[ k ], a_val_before, B_val, op );
12991302
local_rc = local_rc ? local_rc : grb::apply< descr >( A_ccs_raw.values[ k ], a_val_before, B_val, op );
@@ -1372,44 +1375,42 @@ namespace grb {
13721375
auto B_k = B_raw.col_start[ i ];
13731376
auto mask_k = mask_raw.col_start[ i ];
13741377
for( auto k = A_crs_raw.col_start[ i ]; k < A_crs_raw.col_start[ i + 1 ]; ++k ) {
1375-
auto k_col = A_crs_raw.row_index[ k ];
1378+
auto j = A_crs_raw.row_index[ k ];
13761379

13771380
// Increment the pointer of mask until we find the right column, or a lower column (since the storage withing a row is sorted in a descending order)
1378-
while( mask_k < mask_raw.col_start[ i + 1 ] && mask_raw.row_index[ mask_k ] > k_col ) {
1381+
while( mask_k < mask_raw.col_start[ i + 1 ] && mask_raw.row_index[ mask_k ] > j ) {
13791382
_DEBUG_THREADESAFE_PRINT( "NEquals MASK coordinate: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
13801383
mask_k++;
13811384
}
1382-
1383-
if( mask_raw.row_index[ B_k ] != k_col ) {
1384-
mask_k++;
1385-
_DEBUG_THREADESAFE_PRINT( "Skip MASK value at: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ B_k ] ) + " )\n" );
1386-
continue;
1385+
if( mask_k >= mask_raw.col_start[ i + 1 ] ) {
1386+
_DEBUG_THREADESAFE_PRINT( "Not value left in mask for this column\n" );
1387+
break;
13871388
}
1388-
1389-
if( not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
1389+
1390+
if( mask_raw.row_index[ B_k ] != j || not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
13901391
_DEBUG_THREADESAFE_PRINT( "Skip MASK value at: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ B_k ] ) + " )\n" );
13911392
continue;
13921393
}
13931394

13941395
// Increment the pointer of B until we find the right column, or a lower column (since the storage withing a row is sorted in a descending order)
1395-
while( B_k < B_raw.col_start[ i + 1 ] && B_raw.row_index[ B_k ] > k_col ) {
1396+
while( B_k < B_raw.col_start[ i + 1 ] && B_raw.row_index[ B_k ] > j ) {
13961397
_DEBUG_THREADESAFE_PRINT( "NEquals B coordinate: ( " + std::to_string( i ) + ";" + std::to_string( B_raw.row_index[ B_k ] ) + " )\n" );
13971398
B_k++;
13981399
}
13991400

14001401
// Get B value (or identity if not found)
14011402
auto B_val = B_identity;
1402-
if( B_k < B_raw.col_start[ i + 1 ] && B_raw.row_index[ B_k ] == k_col ) {
1403+
if( B_k < B_raw.col_start[ i + 1 ] && B_raw.row_index[ B_k ] == j ) {
14031404
_DEBUG_THREADESAFE_PRINT( "Found B value at: ( " + std::to_string( i ) + ";" + std::to_string( B_raw.row_index[ B_k ] ) + " )\n" );
14041405
B_val = B_raw.values[ B_k ];
14051406
B_k++;
14061407
} else {
1407-
_DEBUG_THREADESAFE_PRINT( "Not found B, using identity: ( " + std::to_string( i ) + ";" + std::to_string( k_col ) + " ) = " + std::to_string( B_val ) + "\n" );
1408+
_DEBUG_THREADESAFE_PRINT( "Not found B, using identity: ( " + std::to_string( i ) + ";" + std::to_string( j ) + " ) = " + std::to_string( B_val ) + "\n" );
14081409
}
14091410

14101411
// Get A value
14111412
const auto a_val_before = A_crs_raw.values[ k ];
1412-
_DEBUG_THREADESAFE_PRINT( "A( " + std::to_string( i ) + ";" + std::to_string( k_col ) + " ) = " + std::to_string( a_val_before ) + "\n" );
1413+
_DEBUG_THREADESAFE_PRINT( "A( " + std::to_string( i ) + ";" + std::to_string( j ) + " ) = " + std::to_string( a_val_before ) + "\n" );
14131414
// Compute the fold for this coordinate
14141415
local_rc = local_rc ? local_rc : grb::apply< descr >( A_crs_raw.values[ k ], a_val_before, B_val, op );
14151416
local_rc = local_rc ? local_rc : grb::apply< descr >( A_ccs_raw.values[ k ], a_val_before, B_val, op );

0 commit comments

Comments
 (0)