Skip to content

Commit db2ae26

Browse files
committed
Cleanup and better temporary data passing to sweep function
1 parent a436e66 commit db2ae26

File tree

2 files changed

+107
-86
lines changed

2 files changed

+107
-86
lines changed

include/graphblas/algorithms/simulated_annealing_re.hpp

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,16 @@ namespace grb {
112112
* @tparam StateType The state variable type.
113113
* @tparam EnergyType The energy type.
114114
* @tparam TempType The inverse temperature type.
115+
* @tparam SweepDataType Type of data to be passed on to the sweep function (e.g. a tuple of references to temporary vectors).
115116
* @tparam Ring The semiring under which to make the sweeps.
116117
*
117118
*/
118119
template<
119120
typename QType, // type of coupling matrix values
120-
typename StateType, // type of state, ideally 0/1
121+
typename StateType, // type of state, possibly 0/1
121122
typename EnergyType,
122123
typename TempType,
124+
typename SweepDataType, // type of data to be passed through to the sweep function
123125
typename RSI, typename CSI, typename NZI, Backend backend,
124126
class Ring = Semiring<
125127
grb::operators::add< QType >, grb::operators::mul< QType >,
@@ -133,10 +135,7 @@ namespace grb {
133135
const grb::Vector< QType, backend >&,
134136
grb::Vector< StateType, backend >&,
135137
const TempType&,
136-
grb::Vector< QType >&,
137-
grb::Vector< QType >&,
138-
grb::Vector< StateType >&,
139-
const std::vector< grb::Vector< bool > >&,
138+
SweepDataType&,
140139
const Ring&
141140
)
142141
> &sweep,
@@ -147,36 +146,35 @@ namespace grb {
147146
grb::Vector< TempType > &betas,
148147
std::vector< grb::Vector< StateType, backend > > &temp_states,
149148
grb::Vector< EnergyType > &temp_energies,
150-
grb::Vector< QType > &temp_sweep1,
151-
grb::Vector< QType > &temp_sweep2,
152-
grb::Vector< StateType > &temp_sweep3,
153-
const std::vector< grb::Vector< bool > >& masks,
149+
SweepDataType& temp_sweep,
154150
const size_t &n_sweeps = 1,
155151
const bool &use_pt = false,
156152
const Ring &ring = Ring()
157153
){
158154

159-
size_t n_replicas = states.size();
155+
const size_t n_replicas = states.size();
156+
const size_t n = grb::size(states[0]);
160157

161158
assert( n_replicas > 0 );
162159
assert( n_replicas == grb::size( betas ) );
163-
assert( grb::ncols( couplings ) == grb::nrows( couplings ) );
164-
assert( grb::size( states[0] ) == grb::nrows( couplings ) );
165-
assert( grb::size( states[0] ) == grb::size( local_fields ) );
160+
assert( n == grb::ncols( couplings ) );
161+
assert( n == grb::nrows( couplings ) );
162+
assert( n == grb::size( local_fields ) );
166163

167-
for(size_t i = 1; i < n_replicas ; ++i ){
168-
assert( grb::size( states[0] ) == grb::size( states[ i ] ) );
164+
for(size_t i = 0; i < n_replicas ; ++i ){
165+
assert( n == grb::size( states[ i ] ) );
169166
}
170167

171-
const size_t n = grb::size(states[0]);
172168

173169
#ifndef NDEBUG
174-
std::cerr << "DEBUG: Called simulated_annealing_RE with parameters: "
175-
<< "\n\t n = " << n
176-
<< "\n\t n_replicas = " << n_replicas
177-
<< "\n\t n_sweeps = " << n_sweeps
178-
<< "\n\t use_pt = " << use_pt
179-
<< std::endl;
170+
if( grb::spmd<>::pid() == 0 ) {
171+
std::cerr << "DEBUG: Called simulated_annealing_RE with parameters: "
172+
<< "\n\t n = " << n
173+
<< "\n\t n_replicas = " << n_replicas
174+
<< "\n\t n_sweeps = " << n_sweeps
175+
<< "\n\t use_pt = " << use_pt
176+
<< std::endl;
177+
}
180178
#endif
181179

182180
grb::RC rc = grb::SUCCESS;
@@ -185,12 +183,9 @@ namespace grb {
185183
temp_states = states;
186184

187185
for( size_t i_sweep = 0 ; rc == grb::SUCCESS && i_sweep < n_sweeps ; ++i_sweep ){
188-
// randomize order of replicas
189-
// std::random_shuffle( states.begin(), states.end() );
190-
191-
for( size_t j = 0 ; rc == grb::SUCCESS && j < n_replicas ; ++j ){
186+
for( size_t j = 0 ; j < n_replicas ; ++j ){
192187

193-
energies[j] += sweep( couplings, local_fields, states[j], betas[j], temp_sweep1, temp_sweep2 , temp_sweep3, masks , ring );
188+
energies[j] += sweep( couplings, local_fields, states[j], betas[j], temp_sweep, ring );
194189

195190
// update_best state and energy
196191
if( energies[j] < temp_energies[j] ){
@@ -204,7 +199,9 @@ namespace grb {
204199
rc = pt( states, energies, betas );
205200
}
206201
#ifndef NDEBUG
207-
std::cerr << "Energy at iteration " << i_sweep << " = " << energies[ 0 ] << std::endl;
202+
if( grb::spmd<>::pid() == 0 ) {
203+
std::cerr << "Energy at iteration " << i_sweep << " = " << energies[ 0 ] << std::endl;
204+
}
208205
#endif
209206
} // n_sweeps
210207

0 commit comments

Comments
 (0)