Skip to content

Commit 0a3da01

Browse files
committed
Masked variant implemented in reference+omp
1 parent 3bd364c commit 0a3da01

File tree

3 files changed

+129
-7
lines changed

3 files changed

+129
-7
lines changed

include/graphblas/blas0.hpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <type_traits> //enable_if
3333

3434
#include "graphblas/descriptors.hpp"
35+
#include "graphblas/identities.hpp"
3536
#include "graphblas/rc.hpp"
3637
#include "graphblas/type_traits.hpp"
3738

@@ -604,6 +605,38 @@ namespace grb {
604605

605606
};
606607

608+
template< typename MaskType >
609+
struct MaskHasValue {
610+
611+
public:
612+
template < Descriptor descr = descriptors::no_operation, typename MaskStruct >
613+
MaskHasValue( const MaskStruct& mask_raw, const size_t k ) {
614+
bool hasValue = (bool) mask_raw.getValue( k, grb::identities::logical_false<bool>() );
615+
if (descr & grb::descriptors::invert_mask) {
616+
hasValue = !hasValue;
617+
}
618+
value = hasValue;
619+
}
620+
621+
bool value;
622+
};
623+
624+
template<>
625+
struct MaskHasValue< void > {
626+
627+
public:
628+
template < Descriptor descr = descriptors::no_operation, typename MaskStruct >
629+
MaskHasValue( const MaskStruct& mask_raw, const size_t k ) :
630+
value(not (descr & grb::descriptors::invert_mask)){
631+
(void) mask_raw;
632+
(void) k;
633+
}
634+
635+
const bool value;
636+
637+
};
638+
639+
607640
} // namespace internal
608641

609642
} // namespace grb

include/graphblas/reference/blas3.hpp

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,7 +1255,7 @@ namespace grb {
12551255

12561256
// Check mask dimensions
12571257
if( m != m_B || n != n_B ) {
1258-
_DEBUG_THREADESAFE_PRINT( "Mask dimensions do not match input matrix dimensions\n" );
1258+
_DEBUG_THREADESAFE_PRINT( "Dimensions of matrices do not match!\n" );
12591259
return MISMATCH;
12601260
}
12611261

@@ -1326,16 +1326,104 @@ namespace grb {
13261326
const Monoid &monoid = Monoid()
13271327
) {
13281328
_DEBUG_THREADESAFE_PRINT( "In grb::internal::fold_matrix_matrix_masked_generic( reference )\n" );
1329-
RC rc = UNSUPPORTED;
1330-
(void) A;
1331-
(void) mask;
1332-
(void) B;
1333-
(void) monoid;
1329+
RC rc = SUCCESS;
13341330

13351331
if( grb::nnz(mask) == 0 || grb::nnz(A) == 0 ) {
13361332
return rc;
13371333
}
13381334

1335+
const auto &A_crs_raw = internal::getCRS( A );
1336+
const auto &A_ccs_raw = internal::getCCS( A );
1337+
const auto &mask_raw = descr & grb::descriptors::transpose_left ?
1338+
internal::getCCS( mask ) : internal::getCRS( mask );
1339+
const auto &B_raw = descr & grb::descriptors::transpose_right ?
1340+
internal::getCCS( B ) : internal::getCRS( B );
1341+
const size_t m = nrows( A );
1342+
const size_t n = ncols( A );
1343+
const size_t m_mask = descr & grb::descriptors::transpose_left || descr & grb::descriptors::transpose_left ?
1344+
ncols( mask ) : nrows( mask );
1345+
const size_t n_mask = descr & grb::descriptors::transpose_left || descr & grb::descriptors::transpose_left ?
1346+
nrows( mask ) : ncols( mask );
1347+
const size_t m_B = descr & grb::descriptors::transpose_right || descr & grb::descriptors::transpose_matrix ?
1348+
ncols( B ) : nrows( B );
1349+
const size_t n_B = descr & grb::descriptors::transpose_right || descr & grb::descriptors::transpose_matrix ?
1350+
nrows( B ) : ncols( B );
1351+
1352+
// Check mask dimensions
1353+
if( m != m_B || n != n_B || m != m_mask || n != n_mask ) {
1354+
_DEBUG_THREADESAFE_PRINT( "Dimensions of matrices do not match!\n" );
1355+
return MISMATCH;
1356+
}
1357+
1358+
RC local_rc = rc;
1359+
const auto& op = monoid.getOperator();
1360+
const InputType B_identity = monoid.template getIdentity< InputType >();
1361+
1362+
#ifdef _H_GRB_REFERENCE_OMP_BLAS3
1363+
#pragma omp parallel default(none) shared(A_crs_raw, A_ccs_raw, mask_raw, B_raw, rc, std::cout) firstprivate(local_rc, m, op, B_identity)
1364+
#endif
1365+
{
1366+
size_t start_row = 0;
1367+
size_t end_row = m;
1368+
#ifdef _H_GRB_REFERENCE_OMP_BLAS3
1369+
config::OMP::localRange( start_row, end_row, 0, m );
1370+
#endif
1371+
for( auto i = start_row; i < end_row; ++i ) {
1372+
auto B_k = B_raw.col_start[ i ];
1373+
auto mask_k = mask_raw.col_start[ i ];
1374+
for( auto k = A_crs_raw.col_start[ i ]; k < A_crs_raw.col_start[ i + 1 ]; ++k ) {
1375+
auto k_col = A_crs_raw.row_index[ k ];
1376+
1377+
// Increment the pointer of mask until we find the right column, or a lower column (since the storage withing a row is sorted in a descending order)
1378+
while( mask_k < mask_raw.col_start[ i + 1 ] && mask_raw.row_index[ mask_k ] > k_col ) {
1379+
_DEBUG_THREADESAFE_PRINT( "NEquals MASK coordinate: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
1380+
mask_k++;
1381+
}
1382+
1383+
if( mask_raw.row_index[ B_k ] != k_col ) {
1384+
mask_k++;
1385+
_DEBUG_THREADESAFE_PRINT( "Skip MASK value at: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ B_k ] ) + " )\n" );
1386+
continue;
1387+
}
1388+
1389+
if( not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
1390+
_DEBUG_THREADESAFE_PRINT( "Skip MASK value at: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ B_k ] ) + " )\n" );
1391+
continue;
1392+
}
1393+
1394+
// Increment the pointer of B until we find the right column, or a lower column (since the storage withing a row is sorted in a descending order)
1395+
while( B_k < B_raw.col_start[ i + 1 ] && B_raw.row_index[ B_k ] > k_col ) {
1396+
_DEBUG_THREADESAFE_PRINT( "NEquals B coordinate: ( " + std::to_string( i ) + ";" + std::to_string( B_raw.row_index[ B_k ] ) + " )\n" );
1397+
B_k++;
1398+
}
1399+
1400+
// Get B value (or identity if not found)
1401+
auto B_val = B_identity;
1402+
if( B_k < B_raw.col_start[ i + 1 ] && B_raw.row_index[ B_k ] == k_col ) {
1403+
_DEBUG_THREADESAFE_PRINT( "Found B value at: ( " + std::to_string( i ) + ";" + std::to_string( B_raw.row_index[ B_k ] ) + " )\n" );
1404+
B_val = B_raw.values[ B_k ];
1405+
B_k++;
1406+
} else {
1407+
_DEBUG_THREADESAFE_PRINT( "Not found B, using identity: ( " + std::to_string( i ) + ";" + std::to_string( k_col ) + " ) = " + std::to_string( B_val ) + "\n" );
1408+
}
1409+
1410+
// Get A value
1411+
const auto a_val_before = A_crs_raw.values[ k ];
1412+
_DEBUG_THREADESAFE_PRINT( "A( " + std::to_string( i ) + ";" + std::to_string( k_col ) + " ) = " + std::to_string( a_val_before ) + "\n" );
1413+
// Compute the fold for this coordinate
1414+
local_rc = local_rc ? local_rc : grb::apply< descr >( A_crs_raw.values[ k ], a_val_before, B_val, op );
1415+
local_rc = local_rc ? local_rc : grb::apply< descr >( A_ccs_raw.values[ k ], a_val_before, B_val, op );
1416+
_DEBUG_THREADESAFE_PRINT( "Computing: op(" + std::to_string( a_val_before ) + ", " + std::to_string( a_val_before ) + ") = " + std::to_string( A_ccs_raw.values[ k ] ) + "\n" );
1417+
}
1418+
}
1419+
1420+
#ifdef _H_GRB_REFERENCE_OMP_BLAS3
1421+
#pragma omp critical
1422+
#endif
1423+
{ // Reduction with the global return code
1424+
rc = rc ? rc : local_rc;
1425+
}
1426+
}
13391427
return rc;
13401428
}
13411429

tests/unit/fold_matrix_to_matrix.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ using namespace grb;
4343
constexpr bool SKIP_FOLDL = false;
4444
constexpr bool SKIP_FOLDR = false;
4545
constexpr bool SKIP_UNMASKED = false;
46-
constexpr bool SKIP_MASKED = true; // Not implemented yet
46+
constexpr bool SKIP_MASKED = false; // Not implemented yet
4747

4848
#define _DEBUG
4949

@@ -129,6 +129,7 @@ void grb_program( const input< T, M, S, MonoidFoldl, MonoidFoldr > & in, grb::RC
129129
rc = RC::SUCCESS;
130130

131131
printSparseMatrix( in.initial, "initial" );
132+
printSparseMatrix( in.B, "B" );
132133
printSparseMatrix( in.expected, "expected" );
133134

134135
if( not in.skip_unmasked && not SKIP_FOLDL && not SKIP_UNMASKED && rc == RC::SUCCESS ) { // Unmasked foldl

0 commit comments

Comments
 (0)