Skip to content

Commit 11adcdf

Browse files
committed
Extend test + small improvement for masked variant
1 parent c7fc283 commit 11adcdf

File tree

2 files changed

+63
-26
lines changed

2 files changed

+63
-26
lines changed

include/graphblas/reference/blas3.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,6 +1292,10 @@ namespace grb {
12921292
_DEBUG_THREADESAFE_PRINT( "In grb::internal::scale_masked_generic( reference )\n" );
12931293
RC rc = SUCCESS;
12941294

1295+
if(grb::nnz(mask) == 0) {
1296+
return rc;
1297+
}
1298+
12951299
const auto &A_crs_raw = internal::getCRS( A );
12961300
const auto &A_ccs_raw = internal::getCCS( A );
12971301
const auto &mask_raw = descr & grb::descriptors::transpose_right ?

tests/unit/fold_matrix_scalar_to_matrix.cpp

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ struct input {
105105
const grb::Matrix< M > & mask;
106106
const S scalar;
107107
const grb::Matrix< T > & expected;
108+
const bool skip_masked, skip_unmasked;
108109
const OpFoldl & opFoldl;
109110
const OpFoldr & opFoldr = OpFoldr();
110111

@@ -114,10 +115,13 @@ struct input {
114115
const grb::Matrix< M > & mask,
115116
const S scalar,
116117
const grb::Matrix< T > & expected,
118+
bool skip_masked = false,
119+
bool skip_unmasked = false,
117120
const OpFoldl & opFoldl = OpFoldl(),
118121
const OpFoldr & opFoldr = OpFoldr() ) :
119122
test_label( test_label ),
120-
test_description( test_description ), initial( initial ), mask( mask ), scalar( scalar ), expected( expected ), opFoldl( opFoldl ), opFoldr( opFoldr ) {}
123+
test_description( test_description ), initial( initial ), mask( mask ), scalar( scalar ), expected( expected ), skip_masked( skip_masked ), skip_unmasked( skip_unmasked ), opFoldl( opFoldl ),
124+
opFoldr( opFoldr ) {}
121125
};
122126

123127
template< typename T, typename M, typename S, class OpFoldl, class OpFoldr >
@@ -127,7 +131,7 @@ void grb_program( const input< T, M, S, OpFoldl, OpFoldr > & in, grb::RC & rc )
127131
printSparseMatrix( in.initial, "initial" );
128132
printSparseMatrix( in.expected, "expected" );
129133

130-
if( not SKIP_FOLDL && not SKIP_UNMASKED && rc == RC::SUCCESS ) { // Unmasked foldl
134+
if( not in.skip_unmasked && not SKIP_FOLDL && not SKIP_UNMASKED && rc == RC::SUCCESS ) { // Unmasked foldl
131135
grb::Matrix< T > result = in.initial;
132136
foldl( result, in.scalar, in.opFoldl );
133137
std::cout << "foldl (unmasked) \"" << in.test_label << "\": ";
@@ -139,7 +143,7 @@ void grb_program( const input< T, M, S, OpFoldl, OpFoldr > & in, grb::RC & rc )
139143
printSparseMatrix( result, "foldl (unmasked) result" );
140144
}
141145

142-
if( not SKIP_FOLDL && not SKIP_MASKED && rc == RC::SUCCESS ) { // Masked foldl
146+
if( not in.skip_masked && not SKIP_FOLDL && not SKIP_MASKED && rc == RC::SUCCESS ) { // Masked foldl
143147
grb::Matrix< T > result = in.initial;
144148
foldl( result, in.mask, in.scalar, in.opFoldl );
145149
std::cout << "foldl (masked) \"" << in.test_label << "\": ";
@@ -151,7 +155,7 @@ void grb_program( const input< T, M, S, OpFoldl, OpFoldr > & in, grb::RC & rc )
151155
printSparseMatrix( result, "foldl (masked) result" );
152156
}
153157

154-
if( not SKIP_FOLDR && not SKIP_UNMASKED && rc == RC::SUCCESS ) { // Unmasked foldr
158+
if( not in.skip_unmasked && not SKIP_FOLDR && not SKIP_UNMASKED && rc == RC::SUCCESS ) { // Unmasked foldr
155159
grb::Matrix< T > result = in.initial;
156160
foldr( result, in.scalar, in.opFoldr );
157161
std::cout << "foldr (unmasked) \"" << in.test_label << "\": ";
@@ -163,7 +167,7 @@ void grb_program( const input< T, M, S, OpFoldl, OpFoldr > & in, grb::RC & rc )
163167
printSparseMatrix( result, "foldr (unmasked) result" );
164168
}
165169

166-
if( not SKIP_FOLDR && not SKIP_MASKED && rc == RC::SUCCESS ) { // Masked foldr
170+
if( not in.skip_masked && not SKIP_FOLDR && not SKIP_MASKED && rc == RC::SUCCESS ) { // Masked foldr
167171
grb::Matrix< T > result = in.initial;
168172
foldr( result, in.mask, in.scalar, in.opFoldr );
169173
std::cout << "foldr (masked) \"" << in.test_label << "\": ";
@@ -201,35 +205,64 @@ int main( int argc, char ** argv ) {
201205

202206
if( ! rc ) { // Identity square * 2
203207
const int k = 2;
204-
const std::string label( "Test 01" );
205-
const std::string description( "Initial: Identity int [" + std::to_string( n ) + ";" + std::to_string( n ) +
206-
"]\n"
207-
"Mask: Identity void matrix.\n"
208-
"k = 2\n"
209-
"Operator: mul\n"
210-
"Expected: Identity int [" +
211-
std::to_string( n ) + ";" + std::to_string( n ) + "] * 2" );
212208
// Initial matrix
213209
Matrix< int > initial( n, n );
214210
std::vector< size_t > initial_rows( n ), initial_cols( n );
215211
std::vector< int > initial_values( n, 1 );
216212
std::iota( initial_rows.begin(), initial_rows.end(), 0 );
217213
std::iota( initial_cols.begin(), initial_cols.end(), 0 );
218214
buildMatrixUnique( initial, initial_rows.data(), initial_cols.data(), initial_values.data(), initial_values.size(), SEQUENTIAL );
219-
// Mask
220-
Matrix< void > mask( n, n );
221-
buildMatrixUnique( mask, initial_rows.data(), initial_cols.data(), initial_rows.size(), SEQUENTIAL );
222-
// Expected matrix
223-
Matrix< int > expected( n, n );
224-
std::vector< int > expected_values( n, 2 );
225-
buildMatrixUnique( expected, initial_rows.data(), initial_cols.data(), expected_values.data(), expected_values.size(), SEQUENTIAL );
226-
std::cout << "-- Running " << label << " --" << std::endl;
227-
input< int, void, int, grb::operators::mul< int >, grb::operators::mul< int > > in { label.c_str(), description.c_str(), initial, mask, k, expected };
228-
if( launcher.exec( &grb_program, in, rc, true ) != SUCCESS ) {
229-
std::cerr << "Launching " << label << " failed" << std::endl;
230-
return 255;
215+
216+
{
217+
const std::string label( "Test 01" );
218+
const std::string description( "Initial: Identity int [" + std::to_string( n ) + ";" + std::to_string( n ) +
219+
"]\n"
220+
"Mask: Identity void matrix (matching the input).\n"
221+
"k = 2\n"
222+
"Operator: mul\n"
223+
"Expected: Identity int [" +
224+
std::to_string( n ) + ";" + std::to_string( n ) + "] * 2" );
225+
// Mask (matching the input matrix)
226+
Matrix< void > mask( n, n );
227+
buildMatrixUnique( mask, initial_rows.data(), initial_cols.data(), initial_rows.size(), SEQUENTIAL );
228+
// Expected matrix
229+
Matrix< int > expected( n, n );
230+
std::vector< int > expected_values( n, 2 );
231+
buildMatrixUnique( expected, initial_rows.data(), initial_cols.data(), expected_values.data(), expected_values.size(), SEQUENTIAL );
232+
233+
std::cout << "-- Running " << label << " --" << std::endl;
234+
input< int, void, int, grb::operators::mul< int >, grb::operators::mul< int > > in { label.c_str(), description.c_str(), initial, mask, k, expected };
235+
if( launcher.exec( &grb_program, in, rc, true ) != SUCCESS ) {
236+
std::cerr << "Launching " << label << " failed" << std::endl;
237+
return 255;
238+
}
239+
std::cout << std::endl << std::flush;
240+
}
241+
242+
{
243+
const std::string label( "Test 02" );
244+
const std::string description( "Initial: Identity int [" + std::to_string( n ) + ";" + std::to_string( n ) +
245+
"]\n"
246+
"Mask: Identity void matrix (empty).\n"
247+
"k = 2\n"
248+
"Operator: mul\n"
249+
"Expected: Identity int [" +
250+
std::to_string( n ) + ";" + std::to_string( n ) + "]" );
251+
// Mask (matching the input matrix)
252+
Matrix< void > mask( n, n );
253+
buildMatrixUnique( mask, initial_rows.data(), initial_cols.data(), 0, SEQUENTIAL );
254+
// Expected matrix
255+
Matrix< int > expected( n, n );
256+
buildMatrixUnique( expected, initial_rows.data(), initial_cols.data(), initial_values.data(), initial_values.size(), SEQUENTIAL );
257+
258+
std::cout << "-- Running " << label << " --" << std::endl;
259+
input< int, void, int, grb::operators::mul< int >, grb::operators::mul< int > > in { label.c_str(), description.c_str(), initial, mask, k, expected, false, true };
260+
if( launcher.exec( &grb_program, in, rc, true ) != SUCCESS ) {
261+
std::cerr << "Launching " << label << " failed" << std::endl;
262+
return 255;
263+
}
264+
std::cout << std::endl << std::flush;
231265
}
232-
std::cout << std::endl << std::flush;
233266
}
234267

235268
if( rc != SUCCESS ) {

0 commit comments

Comments
 (0)