Skip to content

Commit 8cc17ac

Browse files
committed
Implement Monoid variant of BLAS3::eWiseApply
1 parent 3dcfb02 commit 8cc17ac

File tree

1 file changed

+295
-23
lines changed

1 file changed

+295
-23
lines changed

include/graphblas/reference/blas3.hpp

Lines changed: 295 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -928,22 +928,20 @@ namespace grb {
928928
* \a allow_void is true; otherwise, will be ignored.
929929
* \endinternal
930930
*/
931-
932931
template<
933932
bool allow_void,
934933
Descriptor descr,
935-
class MulMonoid, class Operator,
934+
class Operator,
936935
typename OutputType, typename InputType1, typename InputType2,
937936
typename RIT1, typename CIT1, typename NIT1,
938937
typename RIT2, typename CIT2, typename NIT2,
939938
typename RIT3, typename CIT3, typename NIT3
940939
>
941-
RC eWiseApply_matrix_generic(
940+
RC eWiseApply_matrix_generic_intersection(
942941
Matrix< OutputType, reference, RIT1, CIT1, NIT1 > &C,
943942
const Matrix< InputType1, reference, RIT2, CIT2, NIT2 > &A,
944943
const Matrix< InputType2, reference, RIT3, CIT3, NIT3 > &B,
945944
const Operator &oper,
946-
const MulMonoid &mulMonoid,
947945
const Phase &phase,
948946
const typename std::enable_if<
949947
!grb::is_object< OutputType >::value &&
@@ -958,15 +956,14 @@ namespace grb {
958956
std::is_same< InputType1, void >::value ||
959957
std::is_same< InputType2, void >::value
960958
) ),
961-
"grb::internal::eWiseApply_matrix_generic: the non-monoid version of "
959+
"grb::internal::eWiseApply_matrix_generic_intersection: the non-monoid version of "
962960
"elementwise mxm can only be used if neither of the input matrices "
963961
"is a pattern matrix (of type void)" );
964962
assert( phase != TRY );
965963

966964
#ifdef _DEBUG
967-
std::cout << "In grb::internal::eWiseApply_matrix_generic\n";
965+
std::cout << "In grb::internal::eWiseApply_matrix_generic_intersection\n";
968966
#endif
969-
970967
// get whether the matrices should be transposed prior to execution
971968
constexpr bool trans_left = descr & descriptors::transpose_left;
972969
constexpr bool trans_right = descr & descriptors::transpose_right;
@@ -1146,11 +1143,9 @@ namespace grb {
11461143
for( size_t k = A_raw.col_start[ i ]; k < A_raw.col_start[ i + 1 ]; ++k ) {
11471144
const size_t k_col = A_raw.row_index[ k ];
11481145
coors1.assign( k_col );
1149-
valbuf[ k_col ] = A_raw.getValue( k,
1150-
mulMonoid.template getIdentity< typename Operator::D1 >() );
1146+
valbuf[ k_col ] = A_raw.values[ k ];
11511147
#ifdef _DEBUG
1152-
std::cout << "A( " << i << ", " << k_col << " ) = " << A_raw.getValue( k,
1153-
mulMonoid.template getIdentity< typename Operator::D1 >() ) << ", ";
1148+
std::cout << "A( " << i << ", " << k_col << " ) = " << A_raw.values[ k ] << ", ";
11541149
#endif
11551150
}
11561151
#ifdef _DEBUG
@@ -1160,11 +1155,9 @@ namespace grb {
11601155
const size_t l_col = B_raw.row_index[ l ];
11611156
if( coors1.assigned( l_col ) ) {
11621157
coors2.assign( l_col );
1163-
(void)grb::apply( valbuf[ l_col ], valbuf[ l_col ], B_raw.getValue( l,
1164-
mulMonoid.template getIdentity< typename Operator::D2 >() ), oper );
1158+
(void)grb::apply( valbuf[ l_col ], valbuf[ l_col ], B_raw.values[ l ], oper );
11651159
#ifdef _DEBUG
1166-
std::cout << "B( " << i << ", " << l_col << " ) = " << B_raw.getValue( l,
1167-
mulMonoid.template getIdentity< typename Operator::D2 >() )
1160+
std::cout << "B( " << i << ", " << l_col << " ) = " << B_raw.values[ l ]
11681161
<< " to yield C( " << i << ", " << l_col << " ), ";
11691162
#endif
11701163
}
@@ -1190,6 +1183,289 @@ namespace grb {
11901183
#endif
11911184
}
11921185

1186+
#ifndef NDEBUG
1187+
for( size_t j = 0; j < n; ++j ) {
1188+
assert( CCS_raw.col_start[ j + 1 ] - CCS_raw.col_start[ j ] == C_col_index[ j ] );
1189+
}
1190+
#endif
1191+
1192+
// set final number of nonzeroes in output matrix
1193+
internal::setCurrentNonzeroes( C, nzc );
1194+
}
1195+
1196+
// done
1197+
return SUCCESS;
1198+
}
1199+
1200+
/**
1201+
* \internal general elementwise matrix application that all eWiseApply
1202+
* variants refer to.
1203+
* @param[in] oper The operator corresponding to \a mulMonoid if
1204+
* \a allow_void is true; otherwise, an arbitrary operator
1205+
* under which to perform the eWiseApply.
1206+
* @param[in] mulMonoid The monoid under which to perform the eWiseApply if
1207+
* \a allow_void is true; otherwise, will be ignored.
1208+
* \endinternal
1209+
*/
1210+
template<
1211+
bool allow_void,
1212+
Descriptor descr,
1213+
class Monoid,
1214+
typename OutputType, typename InputType1, typename InputType2,
1215+
typename RIT1, typename CIT1, typename NIT1,
1216+
typename RIT2, typename CIT2, typename NIT2,
1217+
typename RIT3, typename CIT3, typename NIT3
1218+
>
1219+
RC eWiseApply_matrix_generic_union(
1220+
Matrix< OutputType, reference, RIT1, CIT1, NIT1 > &C,
1221+
const Matrix< InputType1, reference, RIT2, CIT2, NIT2 > &A,
1222+
const Matrix< InputType2, reference, RIT3, CIT3, NIT3 > &B,
1223+
const Monoid &monoid,
1224+
const Phase &phase,
1225+
const typename std::enable_if<
1226+
!grb::is_object< OutputType >::value &&
1227+
!grb::is_object< InputType1 >::value &&
1228+
!grb::is_object< InputType2 >::value &&
1229+
grb::is_monoid< Monoid >::value,
1230+
void >::type * const = nullptr
1231+
) {
1232+
assert( !(descr & descriptors::force_row_major ) );
1233+
static_assert( allow_void ||
1234+
( !(
1235+
std::is_same< InputType1, void >::value ||
1236+
std::is_same< InputType2, void >::value
1237+
) ),
1238+
"grb::internal::eWiseApply_matrix_generic_union: the non-monoid version of "
1239+
"elementwise mxm can only be used if neither of the input matrices "
1240+
"is a pattern matrix (of type void)" );
1241+
assert( phase != TRY );
1242+
1243+
#ifdef _DEBUG
1244+
std::cout << "In grb::internal::eWiseApply_matrix_generic_union\n";
1245+
#endif
1246+
// get whether the matrices should be transposed prior to execution
1247+
constexpr bool trans_left = descr & descriptors::transpose_left;
1248+
constexpr bool trans_right = descr & descriptors::transpose_right;
1249+
1250+
// run-time checks
1251+
const size_t m = grb::nrows( C );
1252+
const size_t n = grb::ncols( C );
1253+
const size_t m_A = !trans_left ? grb::nrows( A ) : grb::ncols( A );
1254+
const size_t n_A = !trans_left ? grb::ncols( A ) : grb::nrows( A );
1255+
const size_t m_B = !trans_right ? grb::nrows( B ) : grb::ncols( B );
1256+
const size_t n_B = !trans_right ? grb::ncols( B ) : grb::nrows( B );
1257+
1258+
// Identities
1259+
const auto identity_A = monoid.template getIdentity< OutputType >();
1260+
const auto identity_B = monoid.template getIdentity< OutputType >();
1261+
1262+
if( m != m_A || m != m_B || n != n_A || n != n_B ) {
1263+
return MISMATCH;
1264+
}
1265+
1266+
const auto oper = monoid.getOperator();
1267+
const auto &A_raw = !trans_left ?
1268+
internal::getCRS( A ) :
1269+
internal::getCCS( A );
1270+
const auto &B_raw = !trans_right ?
1271+
internal::getCRS( B ) :
1272+
internal::getCCS( B );
1273+
auto &C_raw = internal::getCRS( C );
1274+
auto &CCS_raw = internal::getCCS( C );
1275+
1276+
#ifdef _DEBUG
1277+
std::cout << "\t\t A offset array = { ";
1278+
for( size_t i = 0; i <= m_A; ++i ) {
1279+
std::cout << A_raw.col_start[ i ] << " ";
1280+
}
1281+
std::cout << "}\n";
1282+
for( size_t i = 0; i < m_A; ++i ) {
1283+
for( size_t k = A_raw.col_start[ i ]; k < A_raw.col_start[ i + 1 ]; ++k ) {
1284+
std::cout << "\t\t ( " << i << ", " << A_raw.row_index[ k ] << " ) = "
1285+
<< A_raw.getPrintValue( k ) << "\n";
1286+
}
1287+
}
1288+
std::cout << "\t\t B offset array = { ";
1289+
for( size_t j = 0; j <= m_B; ++j ) {
1290+
std::cout << B_raw.col_start[ j ] << " ";
1291+
}
1292+
std::cout << "}\n";
1293+
for( size_t j = 0; j < m_B; ++j ) {
1294+
for( size_t k = B_raw.col_start[ j ]; k < B_raw.col_start[ j + 1 ]; ++k ) {
1295+
std::cout << "\t\t ( " << B_raw.row_index[ k ] << ", " << j << " ) = "
1296+
<< B_raw.getPrintValue( k ) << "\n";
1297+
}
1298+
}
1299+
#endif
1300+
1301+
// retrieve buffers
1302+
char * arr1, * arr2, * arr3, * buf1, * buf2, * buf3;
1303+
arr1 = arr2 = buf1 = buf2 = nullptr;
1304+
InputType1 * vbuf1 = nullptr;
1305+
InputType2 * vbuf2 = nullptr;
1306+
OutputType * valbuf = nullptr;
1307+
internal::getMatrixBuffers( arr1, buf1, vbuf1, 1, A );
1308+
internal::getMatrixBuffers( arr2, buf2, vbuf2, 1, B );
1309+
internal::getMatrixBuffers( arr3, buf3, valbuf, 1, C );
1310+
// end buffer retrieval
1311+
1312+
// initialisations
1313+
internal::Coordinates< reference > coors1, coors2;
1314+
coors1.set( arr1, false, buf1, n );
1315+
coors2.set( arr2, false, buf2, n );
1316+
#ifdef _H_GRB_REFERENCE_OMP_BLAS3
1317+
#pragma omp parallel
1318+
{
1319+
size_t start, end;
1320+
config::OMP::localRange( start, end, 0, n + 1 );
1321+
#else
1322+
const size_t start = 0;
1323+
const size_t end = n + 1;
1324+
#endif
1325+
for( size_t j = start; j < end; ++j ) {
1326+
CCS_raw.col_start[ j ] = 0;
1327+
}
1328+
#ifdef _H_GRB_REFERENCE_OMP_BLAS3
1329+
}
1330+
#endif
1331+
// end initialisations
1332+
1333+
// nonzero count
1334+
size_t nzc = 0;
1335+
1336+
// symbolic phase
1337+
if( phase == RESIZE ) {
1338+
for( size_t i = 0; i < m; ++i ) {
1339+
coors1.clear();
1340+
for( size_t k = A_raw.col_start[ i ]; k < A_raw.col_start[ i + 1 ]; ++k ) {
1341+
const size_t k_col = A_raw.row_index[ k ];
1342+
coors1.assign( k_col );
1343+
(void)++nzc;
1344+
}
1345+
for( size_t l = B_raw.col_start[ i ]; l < B_raw.col_start[ i + 1 ]; ++l ) {
1346+
const size_t l_col = B_raw.row_index[ l ];
1347+
if( not coors1.assigned( l_col ) ) {
1348+
(void)++nzc;
1349+
}
1350+
}
1351+
}
1352+
1353+
const RC ret = grb::resize( C, nzc );
1354+
if( ret != SUCCESS ) {
1355+
return ret;
1356+
}
1357+
}
1358+
1359+
// computational phase
1360+
if( phase == EXECUTE ) {
1361+
// retrieve additional buffer
1362+
config::NonzeroIndexType * const C_col_index = internal::template
1363+
getReferenceBuffer< typename config::NonzeroIndexType >( n + 1 );
1364+
1365+
// perform column-wise nonzero count
1366+
for( size_t i = 0; i < m; ++i ) {
1367+
coors1.clear();
1368+
for( size_t k = A_raw.col_start[ i ]; k < A_raw.col_start[ i + 1 ]; ++k ) {
1369+
const size_t k_col = A_raw.row_index[ k ];
1370+
coors1.assign( k_col );
1371+
(void) ++nzc;
1372+
(void) ++CCS_raw.col_start[ k_col + 1 ];
1373+
}
1374+
for( size_t l = B_raw.col_start[ i ]; l < B_raw.col_start[ i + 1 ]; ++l ) {
1375+
const size_t l_col = B_raw.row_index[ l ];
1376+
if( not coors1.assigned( l_col ) ) {
1377+
(void) ++nzc;
1378+
(void) ++CCS_raw.col_start[ l_col + 1 ];
1379+
}
1380+
}
1381+
}
1382+
1383+
// check capacity
1384+
if( nzc > capacity( C ) ) {
1385+
#ifdef _DEBUG
1386+
std::cout << "\t detected insufficient capacity "
1387+
<< "for requested operation\n";
1388+
#endif
1389+
const RC clear_rc = clear( C );
1390+
if( clear_rc != SUCCESS ) {
1391+
return PANIC;
1392+
} else {
1393+
return FAILED;
1394+
}
1395+
}
1396+
1397+
// prefix sum for CCS_raw.col_start
1398+
assert( CCS_raw.col_start[ 0 ] == 0 );
1399+
for( size_t j = 1; j < n; ++j ) {
1400+
CCS_raw.col_start[ j + 1 ] += CCS_raw.col_start[ j ];
1401+
}
1402+
assert( CCS_raw.col_start[ n ] == nzc );
1403+
1404+
// set C_col_index to all zero
1405+
#ifdef _H_GRB_REFERENCE_OMP_BLAS3
1406+
#pragma omp parallel for simd
1407+
#endif
1408+
for( size_t j = 0; j < n; j++ ) {
1409+
C_col_index[ j ] = 0;
1410+
}
1411+
1412+
// do computations
1413+
size_t nzc = 0;
1414+
C_raw.col_start[ 0 ] = 0;
1415+
for( size_t i = 0; i < m; ++i ) {
1416+
coors1.clear();
1417+
coors2.clear();
1418+
#ifdef _DEBUG
1419+
std::cout << "\t The elements ";
1420+
#endif
1421+
for( size_t k = A_raw.col_start[ i ]; k < A_raw.col_start[ i + 1 ]; ++k ) {
1422+
const size_t k_col = A_raw.row_index[ k ];
1423+
coors1.assign( k_col );
1424+
valbuf[ k_col ] = A_raw.getValue( k, identity_A );
1425+
#ifdef _DEBUG
1426+
std::cout << "A( " << i << ", " << k_col << " ) = " << A_raw.values[ k ] << ", ";
1427+
#endif
1428+
}
1429+
#ifdef _DEBUG
1430+
std::cout << "are multiplied pairwise with ";
1431+
#endif
1432+
for( size_t l = B_raw.col_start[ i ]; l < B_raw.col_start[ i + 1 ]; ++l ) {
1433+
const size_t l_col = B_raw.getValue( l, identity_B );
1434+
if( coors1.assigned( l_col ) ) { // Intersection case
1435+
(void)grb::apply( valbuf[ l_col ], valbuf[ l_col ], B_raw.getValue( l, identity_B ), oper );
1436+
#ifdef _DEBUG
1437+
std::cout << "B( " << i << ", " << l_col << " ) = " << B_raw.getValue( l, identity_B )
1438+
<< " to yield C( " << i << ", " << l_col << " ), ";
1439+
#endif
1440+
} else { // Union case
1441+
coors1.assign( l_col );
1442+
valbuf[ l_col ] = B_raw.getValue( l, identity_B );
1443+
#ifdef _DEBUG
1444+
std::cout << "B( " << i << ", " << l_col << " ) = " << B_raw.getValue( l, identity_B ) << ", ";
1445+
#endif
1446+
}
1447+
}
1448+
#ifdef _DEBUG
1449+
std::cout << "\n";
1450+
#endif
1451+
for( size_t k = 0; k < coors1.nonzeroes(); ++k ) {
1452+
const size_t j = coors1.index( k );
1453+
// update CRS
1454+
C_raw.row_index[ nzc ] = j;
1455+
C_raw.setValue( nzc, valbuf[ j ] );
1456+
// update CCS
1457+
const size_t CCS_index = C_col_index[ j ]++ + CCS_raw.col_start[ j ];
1458+
CCS_raw.row_index[ CCS_index ] = i;
1459+
CCS_raw.setValue( CCS_index, valbuf[ j ] );
1460+
// update count
1461+
(void)++nzc;
1462+
}
1463+
C_raw.col_start[ i + 1 ] = nzc;
1464+
#ifdef _DEBUG
1465+
std::cout << "\n";
1466+
#endif
1467+
}
1468+
11931469
#ifndef NDEBUG
11941470
for( size_t j = 0; j < n; ++j ) {
11951471
assert( CCS_raw.col_start[ j + 1 ] - CCS_raw.col_start[ j ] == C_col_index[ j ] );
@@ -1257,8 +1533,8 @@ namespace grb {
12571533
std::cout << "In grb::eWiseApply_matrix_generic (reference, monoid)\n";
12581534
#endif
12591535

1260-
return internal::eWiseApply_matrix_generic< true, descr >(
1261-
C, A, B, mulmono.getOperator(), mulmono, phase
1536+
return internal::eWiseApply_matrix_generic_union< true, descr >(
1537+
C, A, B, mulmono, phase
12621538
);
12631539
}
12641540

@@ -1317,12 +1593,8 @@ namespace grb {
13171593
"input matrices is a pattern matrix (of type void)"
13181594
);
13191595

1320-
typename grb::Monoid<
1321-
grb::operators::mul< double >,
1322-
grb::identities::one
1323-
> dummyMonoid;
1324-
return internal::eWiseApply_matrix_generic< false, descr >(
1325-
C, A, B, mulOp, dummyMonoid, phase
1596+
return internal::eWiseApply_matrix_generic_intersection< false, descr >(
1597+
C, A, B, mulOp, phase
13261598
);
13271599
}
13281600

0 commit comments

Comments
 (0)