@@ -928,22 +928,20 @@ namespace grb {
928928 * \a allow_void is true; otherwise, will be ignored.
929929 * \endinternal
930930 */
931-
932931 template <
933932 bool allow_void,
934933 Descriptor descr,
935- class MulMonoid , class Operator ,
934+ class Operator ,
936935 typename OutputType, typename InputType1, typename InputType2,
937936 typename RIT1, typename CIT1, typename NIT1,
938937 typename RIT2, typename CIT2, typename NIT2,
939938 typename RIT3, typename CIT3, typename NIT3
940939 >
941- RC eWiseApply_matrix_generic (
940+ RC eWiseApply_matrix_generic_intersection (
942941 Matrix< OutputType, reference, RIT1, CIT1, NIT1 > &C,
943942 const Matrix< InputType1, reference, RIT2, CIT2, NIT2 > &A,
944943 const Matrix< InputType2, reference, RIT3, CIT3, NIT3 > &B,
945944 const Operator &oper,
946- const MulMonoid &mulMonoid,
947945 const Phase &phase,
948946 const typename std::enable_if<
949947 !grb::is_object< OutputType >::value &&
@@ -958,15 +956,14 @@ namespace grb {
958956 std::is_same< InputType1, void >::value ||
959957 std::is_same< InputType2, void >::value
960958 ) ),
961- " grb::internal::eWiseApply_matrix_generic : the non-monoid version of "
959+ " grb::internal::eWiseApply_matrix_generic_intersection : the non-monoid version of "
962960 " elementwise mxm can only be used if neither of the input matrices "
963961 " is a pattern matrix (of type void)" );
964962 assert ( phase != TRY );
965963
966964#ifdef _DEBUG
967- std::cout << " In grb::internal::eWiseApply_matrix_generic \n " ;
965+ std::cout << " In grb::internal::eWiseApply_matrix_generic_intersection \n " ;
968966#endif
969-
970967 // get whether the matrices should be transposed prior to execution
971968 constexpr bool trans_left = descr & descriptors::transpose_left;
972969 constexpr bool trans_right = descr & descriptors::transpose_right;
@@ -1146,11 +1143,9 @@ namespace grb {
11461143 for ( size_t k = A_raw.col_start [ i ]; k < A_raw.col_start [ i + 1 ]; ++k ) {
11471144 const size_t k_col = A_raw.row_index [ k ];
11481145 coors1.assign ( k_col );
1149- valbuf[ k_col ] = A_raw.getValue ( k,
1150- mulMonoid.template getIdentity < typename Operator::D1 >() );
1146+ valbuf[ k_col ] = A_raw.values [ k ];
11511147#ifdef _DEBUG
1152- std::cout << " A( " << i << " , " << k_col << " ) = " << A_raw.getValue ( k,
1153- mulMonoid.template getIdentity < typename Operator::D1 >() ) << " , " ;
1148+ std::cout << " A( " << i << " , " << k_col << " ) = " << A_raw.values [ k ] << " , " ;
11541149#endif
11551150 }
11561151#ifdef _DEBUG
@@ -1160,11 +1155,9 @@ namespace grb {
11601155 const size_t l_col = B_raw.row_index [ l ];
11611156 if ( coors1.assigned ( l_col ) ) {
11621157 coors2.assign ( l_col );
1163- (void )grb::apply ( valbuf[ l_col ], valbuf[ l_col ], B_raw.getValue ( l,
1164- mulMonoid.template getIdentity < typename Operator::D2 >() ), oper );
1158+ (void )grb::apply ( valbuf[ l_col ], valbuf[ l_col ], B_raw.values [ l ], oper );
11651159#ifdef _DEBUG
1166- std::cout << " B( " << i << " , " << l_col << " ) = " << B_raw.getValue ( l,
1167- mulMonoid.template getIdentity < typename Operator::D2 >() )
1160+ std::cout << " B( " << i << " , " << l_col << " ) = " << B_raw.values [ l ]
11681161 << " to yield C( " << i << " , " << l_col << " ), " ;
11691162#endif
11701163 }
@@ -1190,6 +1183,289 @@ namespace grb {
11901183#endif
11911184 }
11921185
1186+ #ifndef NDEBUG
1187+ for ( size_t j = 0 ; j < n; ++j ) {
1188+ assert ( CCS_raw.col_start [ j + 1 ] - CCS_raw.col_start [ j ] == C_col_index[ j ] );
1189+ }
1190+ #endif
1191+
1192+ // set final number of nonzeroes in output matrix
1193+ internal::setCurrentNonzeroes ( C, nzc );
1194+ }
1195+
1196+ // done
1197+ return SUCCESS;
1198+ }
1199+
1200+ /**
1201+ * \internal Generic element-wise matrix application over the union of the
1202+ * nonzero patterns of \a A and \a B, which the monoid-based
1203+ * eWiseApply variant refers to.
1204+ * @param[in] monoid The monoid under which to perform the eWiseApply. Its
1205+ * operator combines entries that are present in both
1206+ * \a A and \a B; an entry that is present in only one
1207+ * of the inputs is carried over to \a C.
1208+ * \endinternal
1209+ */
1210+ template <
1211+ bool allow_void,
1212+ Descriptor descr,
1213+ class Monoid ,
1214+ typename OutputType, typename InputType1, typename InputType2,
1215+ typename RIT1, typename CIT1, typename NIT1,
1216+ typename RIT2, typename CIT2, typename NIT2,
1217+ typename RIT3, typename CIT3, typename NIT3
1218+ >
1219+ RC eWiseApply_matrix_generic_union (
1220+ Matrix< OutputType, reference, RIT1, CIT1, NIT1 > &C,
1221+ const Matrix< InputType1, reference, RIT2, CIT2, NIT2 > &A,
1222+ const Matrix< InputType2, reference, RIT3, CIT3, NIT3 > &B,
1223+ const Monoid &monoid,
1224+ const Phase &phase,
1225+ const typename std::enable_if<
1226+ !grb::is_object< OutputType >::value &&
1227+ !grb::is_object< InputType1 >::value &&
1228+ !grb::is_object< InputType2 >::value &&
1229+ grb::is_monoid< Monoid >::value,
1230+ void >::type * const = nullptr
1231+ ) {
1232+ assert ( !(descr & descriptors::force_row_major ) );
1233+ static_assert ( allow_void ||
1234+ ( !(
1235+ std::is_same< InputType1, void >::value ||
1236+ std::is_same< InputType2, void >::value
1237+ ) ),
1238+ " grb::internal::eWiseApply_matrix_generic_union: the non-monoid version of "
1239+ " elementwise mxm can only be used if neither of the input matrices "
1240+ " is a pattern matrix (of type void)" );
1241+ assert ( phase != TRY );
1242+
1243+ #ifdef _DEBUG
1244+ std::cout << " In grb::internal::eWiseApply_matrix_generic_union\n " ;
1245+ #endif
1246+ // get whether the matrices should be transposed prior to execution
1247+ constexpr bool trans_left = descr & descriptors::transpose_left;
1248+ constexpr bool trans_right = descr & descriptors::transpose_right;
1249+
1250+ // run-time checks
1251+ const size_t m = grb::nrows ( C );
1252+ const size_t n = grb::ncols ( C );
1253+ const size_t m_A = !trans_left ? grb::nrows ( A ) : grb::ncols ( A );
1254+ const size_t n_A = !trans_left ? grb::ncols ( A ) : grb::nrows ( A );
1255+ const size_t m_B = !trans_right ? grb::nrows ( B ) : grb::ncols ( B );
1256+ const size_t n_B = !trans_right ? grb::ncols ( B ) : grb::nrows ( B );
1257+
1258+ // Identities
1259+ const auto identity_A = monoid.template getIdentity < OutputType >();
1260+ const auto identity_B = monoid.template getIdentity < OutputType >();
1261+
1262+ if ( m != m_A || m != m_B || n != n_A || n != n_B ) {
1263+ return MISMATCH;
1264+ }
1265+
1266+ const auto oper = monoid.getOperator ();
1267+ const auto &A_raw = !trans_left ?
1268+ internal::getCRS ( A ) :
1269+ internal::getCCS ( A );
1270+ const auto &B_raw = !trans_right ?
1271+ internal::getCRS ( B ) :
1272+ internal::getCCS ( B );
1273+ auto &C_raw = internal::getCRS ( C );
1274+ auto &CCS_raw = internal::getCCS ( C );
1275+
1276+ #ifdef _DEBUG
1277+ std::cout << " \t\t A offset array = { " ;
1278+ for ( size_t i = 0 ; i <= m_A; ++i ) {
1279+ std::cout << A_raw.col_start [ i ] << " " ;
1280+ }
1281+ std::cout << " }\n " ;
1282+ for ( size_t i = 0 ; i < m_A; ++i ) {
1283+ for ( size_t k = A_raw.col_start [ i ]; k < A_raw.col_start [ i + 1 ]; ++k ) {
1284+ std::cout << " \t\t ( " << i << " , " << A_raw.row_index [ k ] << " ) = "
1285+ << A_raw.getPrintValue ( k ) << " \n " ;
1286+ }
1287+ }
1288+ std::cout << " \t\t B offset array = { " ;
1289+ for ( size_t j = 0 ; j <= m_B; ++j ) {
1290+ std::cout << B_raw.col_start [ j ] << " " ;
1291+ }
1292+ std::cout << " }\n " ;
1293+ for ( size_t j = 0 ; j < m_B; ++j ) {
1294+ for ( size_t k = B_raw.col_start [ j ]; k < B_raw.col_start [ j + 1 ]; ++k ) {
1295+ std::cout << " \t\t ( " << B_raw.row_index [ k ] << " , " << j << " ) = "
1296+ << B_raw.getPrintValue ( k ) << " \n " ;
1297+ }
1298+ }
1299+ #endif
1300+
1301+ // retrieve buffers
1302+ char * arr1, * arr2, * arr3, * buf1, * buf2, * buf3;
1303+ arr1 = arr2 = buf1 = buf2 = nullptr ;
1304+ InputType1 * vbuf1 = nullptr ;
1305+ InputType2 * vbuf2 = nullptr ;
1306+ OutputType * valbuf = nullptr ;
1307+ internal::getMatrixBuffers ( arr1, buf1, vbuf1, 1 , A );
1308+ internal::getMatrixBuffers ( arr2, buf2, vbuf2, 1 , B );
1309+ internal::getMatrixBuffers ( arr3, buf3, valbuf, 1 , C );
1310+ // end buffer retrieval
1311+
1312+ // initialisations
1313+ internal::Coordinates< reference > coors1, coors2;
1314+ coors1.set ( arr1, false , buf1, n );
1315+ coors2.set ( arr2, false , buf2, n );
1316+ #ifdef _H_GRB_REFERENCE_OMP_BLAS3
1317+ #pragma omp parallel
1318+ {
1319+ size_t start, end;
1320+ config::OMP::localRange ( start, end, 0 , n + 1 );
1321+ #else
1322+ const size_t start = 0 ;
1323+ const size_t end = n + 1 ;
1324+ #endif
1325+ for ( size_t j = start; j < end; ++j ) {
1326+ CCS_raw.col_start [ j ] = 0 ;
1327+ }
1328+ #ifdef _H_GRB_REFERENCE_OMP_BLAS3
1329+ }
1330+ #endif
1331+ // end initialisations
1332+
1333+ // nonzero count
1334+ size_t nzc = 0 ;
1335+
1336+ // symbolic phase
1337+ if ( phase == RESIZE ) {
1338+ for ( size_t i = 0 ; i < m; ++i ) {
1339+ coors1.clear ();
1340+ for ( size_t k = A_raw.col_start [ i ]; k < A_raw.col_start [ i + 1 ]; ++k ) {
1341+ const size_t k_col = A_raw.row_index [ k ];
1342+ coors1.assign ( k_col );
1343+ (void )++nzc;
1344+ }
1345+ for ( size_t l = B_raw.col_start [ i ]; l < B_raw.col_start [ i + 1 ]; ++l ) {
1346+ const size_t l_col = B_raw.row_index [ l ];
1347+ if ( not coors1.assigned ( l_col ) ) {
1348+ (void )++nzc;
1349+ }
1350+ }
1351+ }
1352+
1353+ const RC ret = grb::resize ( C, nzc );
1354+ if ( ret != SUCCESS ) {
1355+ return ret;
1356+ }
1357+ }
1358+
1359+ // computational phase
1360+ if ( phase == EXECUTE ) {
1361+ // retrieve additional buffer
1362+ config::NonzeroIndexType * const C_col_index = internal::template
1363+ getReferenceBuffer< typename config::NonzeroIndexType >( n + 1 );
1364+
1365+ // perform column-wise nonzero count
1366+ for ( size_t i = 0 ; i < m; ++i ) {
1367+ coors1.clear ();
1368+ for ( size_t k = A_raw.col_start [ i ]; k < A_raw.col_start [ i + 1 ]; ++k ) {
1369+ const size_t k_col = A_raw.row_index [ k ];
1370+ coors1.assign ( k_col );
1371+ (void ) ++nzc;
1372+ (void ) ++CCS_raw.col_start [ k_col + 1 ];
1373+ }
1374+ for ( size_t l = B_raw.col_start [ i ]; l < B_raw.col_start [ i + 1 ]; ++l ) {
1375+ const size_t l_col = B_raw.row_index [ l ];
1376+ if ( not coors1.assigned ( l_col ) ) {
1377+ (void ) ++nzc;
1378+ (void ) ++CCS_raw.col_start [ l_col + 1 ];
1379+ }
1380+ }
1381+ }
1382+
1383+ // check capacity
1384+ if ( nzc > capacity ( C ) ) {
1385+ #ifdef _DEBUG
1386+ std::cout << " \t detected insufficient capacity "
1387+ << " for requested operation\n " ;
1388+ #endif
1389+ const RC clear_rc = clear ( C );
1390+ if ( clear_rc != SUCCESS ) {
1391+ return PANIC;
1392+ } else {
1393+ return FAILED;
1394+ }
1395+ }
1396+
1397+ // prefix sum for CCS_raw.col_start
1398+ assert ( CCS_raw.col_start [ 0 ] == 0 );
1399+ for ( size_t j = 1 ; j < n; ++j ) {
1400+ CCS_raw.col_start [ j + 1 ] += CCS_raw.col_start [ j ];
1401+ }
1402+ assert ( CCS_raw.col_start [ n ] == nzc );
1403+
1404+ // set C_col_index to all zero
1405+ #ifdef _H_GRB_REFERENCE_OMP_BLAS3
1406+ #pragma omp parallel for simd
1407+ #endif
1408+ for ( size_t j = 0 ; j < n; j++ ) {
1409+ C_col_index[ j ] = 0 ;
1410+ }
1411+
1412+ // do computations
1413+ size_t nzc = 0 ;
1414+ C_raw.col_start [ 0 ] = 0 ;
1415+ for ( size_t i = 0 ; i < m; ++i ) {
1416+ coors1.clear ();
1417+ coors2.clear ();
1418+ #ifdef _DEBUG
1419+ std::cout << " \t The elements " ;
1420+ #endif
1421+ for ( size_t k = A_raw.col_start [ i ]; k < A_raw.col_start [ i + 1 ]; ++k ) {
1422+ const size_t k_col = A_raw.row_index [ k ];
1423+ coors1.assign ( k_col );
1424+ valbuf[ k_col ] = A_raw.getValue ( k, identity_A );
1425+ #ifdef _DEBUG
1426+ std::cout << " A( " << i << " , " << k_col << " ) = " << A_raw.values [ k ] << " , " ;
1427+ #endif
1428+ }
1429+ #ifdef _DEBUG
1430+ std::cout << " are multiplied pairwise with " ;
1431+ #endif
1432+ for ( size_t l = B_raw.col_start [ i ]; l < B_raw.col_start [ i + 1 ]; ++l ) {
1433+ const size_t l_col = B_raw.getValue ( l, identity_B );
1434+ if ( coors1.assigned ( l_col ) ) { // Intersection case
1435+ (void )grb::apply ( valbuf[ l_col ], valbuf[ l_col ], B_raw.getValue ( l, identity_B ), oper );
1436+ #ifdef _DEBUG
1437+ std::cout << " B( " << i << " , " << l_col << " ) = " << B_raw.getValue ( l, identity_B )
1438+ << " to yield C( " << i << " , " << l_col << " ), " ;
1439+ #endif
1440+ } else { // Union case
1441+ coors1.assign ( l_col );
1442+ valbuf[ l_col ] = B_raw.getValue ( l, identity_B );
1443+ #ifdef _DEBUG
1444+ std::cout << " B( " << i << " , " << l_col << " ) = " << B_raw.getValue ( l, identity_B ) << " , " ;
1445+ #endif
1446+ }
1447+ }
1448+ #ifdef _DEBUG
1449+ std::cout << " \n " ;
1450+ #endif
1451+ for ( size_t k = 0 ; k < coors1.nonzeroes (); ++k ) {
1452+ const size_t j = coors1.index ( k );
1453+ // update CRS
1454+ C_raw.row_index [ nzc ] = j;
1455+ C_raw.setValue ( nzc, valbuf[ j ] );
1456+ // update CCS
1457+ const size_t CCS_index = C_col_index[ j ]++ + CCS_raw.col_start [ j ];
1458+ CCS_raw.row_index [ CCS_index ] = i;
1459+ CCS_raw.setValue ( CCS_index, valbuf[ j ] );
1460+ // update count
1461+ (void )++nzc;
1462+ }
1463+ C_raw.col_start [ i + 1 ] = nzc;
1464+ #ifdef _DEBUG
1465+ std::cout << " \n " ;
1466+ #endif
1467+ }
1468+
11931469#ifndef NDEBUG
11941470 for ( size_t j = 0 ; j < n; ++j ) {
11951471 assert ( CCS_raw.col_start [ j + 1 ] - CCS_raw.col_start [ j ] == C_col_index[ j ] );
@@ -1257,8 +1533,8 @@ namespace grb {
12571533 std::cout << " In grb::eWiseApply_matrix_generic (reference, monoid)\n " ;
12581534#endif
12591535
1260- return internal::eWiseApply_matrix_generic < true , descr >(
1261- C, A, B, mulmono. getOperator (), mulmono , phase
1536+ return internal::eWiseApply_matrix_generic_union < true , descr >(
1537+ C, A, B, mulmono, phase
12621538 );
12631539 }
12641540
@@ -1317,12 +1593,8 @@ namespace grb {
13171593 " input matrices is a pattern matrix (of type void)"
13181594 );
13191595
1320- typename grb::Monoid<
1321- grb::operators::mul< double >,
1322- grb::identities::one
1323- > dummyMonoid;
1324- return internal::eWiseApply_matrix_generic< false , descr >(
1325- C, A, B, mulOp, dummyMonoid, phase
1596+ return internal::eWiseApply_matrix_generic_intersection< false , descr >(
1597+ C, A, B, mulOp, phase
13261598 );
13271599 }
13281600
0 commit comments