Skip to content

Commit aec1f8a

Browse files
committed
Replaced eWiseLambda with fold in Simulated Annealing
1 parent 28e9648 commit aec1f8a

File tree

2 files changed

+20
-27
lines changed

2 files changed

+20
-27
lines changed

include/graphblas/algorithms/simulated_annealing_re.hpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -640,11 +640,14 @@ namespace grb {
640640
rc = rc ? rc : grb::set< descr >( h, static_cast< QType >( 0.0 ) );
641641
}
642642
rc = rc ? rc : grb::mxv< dense_descr >( h, couplings, state , ring );
643-
std::uniform_real_distribution< QType > rand_gen ( 0.0, 1.0 );
643+
644+
std::exponential_distribution< EnergyType > rand_gen ( beta );
644645
for( size_t i = 0 ; i < n; ++i ){
645-
grb::setElement( rand, rand_gen( rng ), i );
646+
const auto rnd = -rand_gen( rng );
647+
grb::setElement( rand, rnd, i );
646648
}
647649

650+
const grb::operators::leq< EnergyType > leq_operator;
648651
#ifndef NDEBUG
649652
const grb::Vector< StateType > old_state = state;
650653
#endif
@@ -656,15 +659,10 @@ namespace grb {
656659
rc = rc ? rc : grb::foldl< descr >( dn, static_cast< EnergyType >( -1 ), ring.getAdditiveMonoid() );
657660
rc = rc ? rc : grb::foldl< descr >( dn, h, ring.getMultiplicativeMonoid() );
658661

659-
// ( dn >= 0 ) | ( rand < beta * dn )
660-
rc = rc ? rc : grb::set< descr >( accept, mask );
661-
rc = rc ? rc : grb::wait(); // needed to avoid ERROR: Segmentation Fault with nonblocking backend
662-
rc = rc ? rc : grb::eWiseLambda< descr >(
663-
[ &mask, &accept, &dn, &rand, beta ]( const size_t i ){
664-
if( mask[i] ){
665-
accept[i] = ( dn[i] >= 0 ) || ( internal::log( rand[i] ) < beta * dn[i] );
666-
}
667-
}, mask, rand, dn, accept );
662+
// Choose which changes to accept
663+
// ( dn >= 0 ) | ( rand/beta < dn )
664+
rc = rc ? rc : grb::foldl< descr >( dn, rand, leq_operator );
665+
rc = rc ? rc : grb::set< descr >( accept, dn, mask );
668666

669667
// new_state = np.where(accept, 1 - old, old)
670668
rc = rc ? rc : grb::foldl< descr >( state, accept, static_cast< StateType >( -1 ), ring.getMultiplicativeMonoid() );

tests/smoke/simulated_annealing_re_from_mpi.cpp

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ EnergyType sequential_sweep_immediate(
337337
const auto &couplings = std::get<0>(data);
338338
const auto &local_fields = std::get<1>(data);
339339
auto &h = std::get<2>(data);
340-
auto &log_rand = std::get<3>(data);
340+
auto &rand = std::get<3>(data);
341341
auto &delta = std::get<4>(data);
342342
const auto &masks = std::get<5>(data);
343343
auto &dn = std::get<6>(data);
@@ -346,20 +346,21 @@ EnergyType sequential_sweep_immediate(
346346

347347
rc = rc ? rc : grb::wait();
348348
rc = rc ? rc : grb::resize( h, n );
349-
rc = rc ? rc : grb::resize( log_rand, n );
349+
rc = rc ? rc : grb::resize( rand, n );
350350
rc = rc ? rc : grb::resize( delta, n );
351351
rc = rc ? rc : grb::resize( dn, n );
352352
rc = rc ? rc : grb::resize( accept, n );
353353

354354
rc = rc ? rc : grb::set< descr >( h, local_fields );
355355
rc = rc ? rc : grb::mxv< descr >( h, couplings, state , ring );
356356

357-
std::uniform_real_distribution< JType > rand ( 0.0, 1.0 );
358-
for( size_t j = 0 ; j < n ; ++j ){
359-
const auto rnd = rand( rng );
360-
rc = rc ? rc : grb::setElement(log_rand, std::log( rnd ), j );
357+
std::exponential_distribution< EnergyType > rand_gen ( beta );
358+
for( size_t i = 0 ; i < n; ++i ){
359+
const auto rnd = -rand_gen( rng );
360+
grb::setElement( rand, rnd, i );
361361
}
362362

363+
const grb::operators::leq< EnergyType > leq_operator;
363364
#ifndef NDEBUG
364365
const grb::Vector< IOType, backend > old_state = state;
365366
#endif
@@ -375,16 +376,10 @@ EnergyType sequential_sweep_immediate(
375376
rc = rc ? rc : grb::foldl< descr >( dn, static_cast< EnergyType >( -1 ), ring.getAdditiveMonoid() );
376377
rc = rc ? rc : grb::foldl< descr >( dn, h, ring.getMultiplicativeMonoid() );
377378

378-
// ( dn >= 0 ) | ( log_rand < beta * dn )
379-
rc = rc ? rc : grb::set< descr >( accept, mask );
380-
rc = rc ? rc : grb::wait(); // needed to avoid ERROR: Segmentation Fault with nonblocking backend
381-
rc = rc ? rc : grb::eWiseLambda< descr >(
382-
[ &mask, &accept, &dn, &log_rand, beta ]( const size_t i ){
383-
(void) i;
384-
if( mask[i] ){
385-
accept[i] = ( dn[i] >= 0 ) || ( log_rand[i] < beta * dn[i] );
386-
}
387-
}, mask, log_rand, dn, accept );
379+
// Choose which changes to accept
380+
// ( dn >= 0 ) | ( rand/beta < dn )
381+
rc = rc ? rc : grb::foldl< descr >( dn, rand, leq_operator );
382+
rc = rc ? rc : grb::set< descr >( accept, dn, mask );
388383

389384
// new_state = np.where(accept, 1 - old, old)
390385
rc = rc ? rc : grb::foldl< descr >( state, accept, static_cast< IOType >( -1 ), ring.getMultiplicativeMonoid() );

0 commit comments

Comments
 (0)