@@ -1255,7 +1255,7 @@ namespace grb {
12551255
12561256 // Check mask dimensions
12571257 if ( m != m_B || n != n_B ) {
1258- _DEBUG_THREADESAFE_PRINT ( " Mask dimensions do not match input matrix dimensions \n " );
1258+ _DEBUG_THREADESAFE_PRINT ( " Dimensions of matrices do not match! \n " );
12591259 return MISMATCH;
12601260 }
12611261
@@ -1326,16 +1326,104 @@ namespace grb {
13261326 const Monoid &monoid = Monoid ()
13271327 ) {
13281328 _DEBUG_THREADESAFE_PRINT ( " In grb::internal::fold_matrix_matrix_masked_generic( reference )\n " );
1329- RC rc = UNSUPPORTED;
1330- (void ) A;
1331- (void ) mask;
1332- (void ) B;
1333- (void ) monoid;
1329+ RC rc = SUCCESS;
13341330
13351331 if ( grb::nnz (mask) == 0 || grb::nnz (A) == 0 ) {
13361332 return rc;
13371333 }
13381334
1335+ const auto &A_crs_raw = internal::getCRS ( A );
1336+ const auto &A_ccs_raw = internal::getCCS ( A );
1337+ const auto &mask_raw = descr & grb::descriptors::transpose_left ?
1338+ internal::getCCS ( mask ) : internal::getCRS ( mask );
1339+ const auto &B_raw = descr & grb::descriptors::transpose_right ?
1340+ internal::getCCS ( B ) : internal::getCRS ( B );
1341+ const size_t m = nrows ( A );
1342+ const size_t n = ncols ( A );
1343+ const size_t m_mask = descr & grb::descriptors::transpose_left || descr & grb::descriptors::transpose_left ?
1344+ ncols ( mask ) : nrows ( mask );
1345+ const size_t n_mask = descr & grb::descriptors::transpose_left || descr & grb::descriptors::transpose_left ?
1346+ nrows ( mask ) : ncols ( mask );
1347+ const size_t m_B = descr & grb::descriptors::transpose_right || descr & grb::descriptors::transpose_matrix ?
1348+ ncols ( B ) : nrows ( B );
1349+ const size_t n_B = descr & grb::descriptors::transpose_right || descr & grb::descriptors::transpose_matrix ?
1350+ nrows ( B ) : ncols ( B );
1351+
1352+ // Check mask dimensions
1353+ if ( m != m_B || n != n_B || m != m_mask || n != n_mask ) {
1354+ _DEBUG_THREADESAFE_PRINT ( " Dimensions of matrices do not match!\n " );
1355+ return MISMATCH;
1356+ }
1357+
1358+ RC local_rc = rc;
1359+ const auto & op = monoid.getOperator ();
1360+ const InputType B_identity = monoid.template getIdentity < InputType >();
1361+
1362+ #ifdef _H_GRB_REFERENCE_OMP_BLAS3
1363+ #pragma omp parallel default(none) shared(A_crs_raw, A_ccs_raw, mask_raw, B_raw, rc, std::cout) firstprivate(local_rc, m, op, B_identity)
1364+ #endif
1365+ {
1366+ size_t start_row = 0 ;
1367+ size_t end_row = m;
1368+ #ifdef _H_GRB_REFERENCE_OMP_BLAS3
1369+ config::OMP::localRange ( start_row, end_row, 0 , m );
1370+ #endif
1371+ for ( auto i = start_row; i < end_row; ++i ) {
1372+ auto B_k = B_raw.col_start [ i ];
1373+ auto mask_k = mask_raw.col_start [ i ];
1374+ for ( auto k = A_crs_raw.col_start [ i ]; k < A_crs_raw.col_start [ i + 1 ]; ++k ) {
1375+ auto k_col = A_crs_raw.row_index [ k ];
1376+
1377+ // Increment the pointer of mask until we find the right column, or a lower column (since the storage withing a row is sorted in a descending order)
1378+ while ( mask_k < mask_raw.col_start [ i + 1 ] && mask_raw.row_index [ mask_k ] > k_col ) {
1379+ _DEBUG_THREADESAFE_PRINT ( " NEquals MASK coordinate: ( " + std::to_string ( i ) + " ;" + std::to_string ( mask_raw.row_index [ mask_k ] ) + " )\n " );
1380+ mask_k++;
1381+ }
1382+
1383+ if ( mask_raw.row_index [ B_k ] != k_col ) {
1384+ mask_k++;
1385+ _DEBUG_THREADESAFE_PRINT ( " Skip MASK value at: ( " + std::to_string ( i ) + " ;" + std::to_string ( mask_raw.row_index [ B_k ] ) + " )\n " );
1386+ continue ;
1387+ }
1388+
1389+ if ( not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
1390+ _DEBUG_THREADESAFE_PRINT ( " Skip MASK value at: ( " + std::to_string ( i ) + " ;" + std::to_string ( mask_raw.row_index [ B_k ] ) + " )\n " );
1391+ continue ;
1392+ }
1393+
1394+ // Increment the pointer of B until we find the right column, or a lower column (since the storage withing a row is sorted in a descending order)
1395+ while ( B_k < B_raw.col_start [ i + 1 ] && B_raw.row_index [ B_k ] > k_col ) {
1396+ _DEBUG_THREADESAFE_PRINT ( " NEquals B coordinate: ( " + std::to_string ( i ) + " ;" + std::to_string ( B_raw.row_index [ B_k ] ) + " )\n " );
1397+ B_k++;
1398+ }
1399+
1400+ // Get B value (or identity if not found)
1401+ auto B_val = B_identity;
1402+ if ( B_k < B_raw.col_start [ i + 1 ] && B_raw.row_index [ B_k ] == k_col ) {
1403+ _DEBUG_THREADESAFE_PRINT ( " Found B value at: ( " + std::to_string ( i ) + " ;" + std::to_string ( B_raw.row_index [ B_k ] ) + " )\n " );
1404+ B_val = B_raw.values [ B_k ];
1405+ B_k++;
1406+ } else {
1407+ _DEBUG_THREADESAFE_PRINT ( " Not found B, using identity: ( " + std::to_string ( i ) + " ;" + std::to_string ( k_col ) + " ) = " + std::to_string ( B_val ) + " \n " );
1408+ }
1409+
1410+ // Get A value
1411+ const auto a_val_before = A_crs_raw.values [ k ];
1412+ _DEBUG_THREADESAFE_PRINT ( " A( " + std::to_string ( i ) + " ;" + std::to_string ( k_col ) + " ) = " + std::to_string ( a_val_before ) + " \n " );
1413+ // Compute the fold for this coordinate
1414+ local_rc = local_rc ? local_rc : grb::apply< descr >( A_crs_raw.values [ k ], a_val_before, B_val, op );
1415+ local_rc = local_rc ? local_rc : grb::apply< descr >( A_ccs_raw.values [ k ], a_val_before, B_val, op );
1416+ _DEBUG_THREADESAFE_PRINT ( " Computing: op(" + std::to_string ( a_val_before ) + " , " + std::to_string ( a_val_before ) + " ) = " + std::to_string ( A_ccs_raw.values [ k ] ) + " \n " );
1417+ }
1418+ }
1419+
1420+ #ifdef _H_GRB_REFERENCE_OMP_BLAS3
1421+ #pragma omp critical
1422+ #endif
1423+ { // Reduction with the global return code
1424+ rc = rc ? rc : local_rc;
1425+ }
1426+ }
13391427 return rc;
13401428 }
13411429
0 commit comments