@@ -31,6 +31,8 @@ THE POSSIBILITY OF SUCH DAMAGE.
3131#include < string.h>
3232#include < stdexcept> // std::lenght_error
3333#include < vector>
34+ #include < algorithm>
35+ #include < numeric>
3436#include " coreneuron/nrnconf.h"
3537#include " coreneuron/nrniv/nrniv_decl.h"
3638#include " coreneuron/nrniv/output_spikes.h"
@@ -65,6 +67,74 @@ void spikevec_unlock() {
6567}
6668
6769#if NRNMPI
70+
71+ void local_spikevec_sort (std::vector<double >& isvect,
72+ std::vector<int >& isvecg,
73+ std::vector<double >& osvect,
74+ std::vector<int >& osvecg) {
75+ osvect.resize (isvect.size ());
76+ osvecg.resize (isvecg.size ());
77+ // first build a permutation vector
78+ std::vector<std::size_t > perm (isvect.size ());
79+ std::iota (perm.begin (), perm.end (), 0 );
80+ // sort by gid (second predicate first)
81+ std::stable_sort (perm.begin (), perm.end (),
82+ [&](std::size_t i, std::size_t j) { return isvecg[i] < isvecg[j]; });
83+ // then sort by time
84+ std::stable_sort (perm.begin (), perm.end (),
85+ [&](std::size_t i, std::size_t j) { return isvect[i] < isvect[j]; });
86+ // now apply permutation to time and gid output vectors
87+ std::transform (perm.begin (), perm.end (), osvect.begin (),
88+ [&](std::size_t i) { return isvect[i]; });
89+ std::transform (perm.begin (), perm.end (), osvecg.begin (),
90+ [&](std::size_t i) { return isvecg[i]; });
91+ }
92+
93+ void sort_spikes (std::vector<double >& spikevec_time, std::vector<int >& spikevec_gid) {
94+ double lmin_time = *(std::min_element (spikevec_time.begin (), spikevec_time.end ()));
95+ double lmax_time = *(std::max_element (spikevec_time.begin (), spikevec_time.end ()));
96+ double min_time = nrnmpi_dbl_allmin (lmin_time);
97+ double max_time = nrnmpi_dbl_allmax (lmax_time);
98+
99+ // allocate send and receive counts and displacements for MPI_Alltoallv
100+ std::vector<int > snd_cnts (nrnmpi_numprocs);
101+ std::vector<int > rcv_cnts (nrnmpi_numprocs);
102+ std::vector<int > snd_dsps (nrnmpi_numprocs);
103+ std::vector<int > rcv_dsps (nrnmpi_numprocs);
104+
105+ double bin_t = (max_time - min_time) / nrnmpi_numprocs;
106+ // first find number of spikes in each time window
107+ for (const auto & st : spikevec_time) {
108+ int idx = (int )(st - min_time) / bin_t ;
109+ snd_cnts[idx]++;
110+ }
111+ for (int i = 1 ; i < nrnmpi_numprocs; i++) {
112+ snd_dsps[i] = snd_dsps[i - 1 ] + snd_cnts[i - 1 ];
113+ }
114+
115+ // now let each rank know how many spikes they will receive
116+ // and get in turn all the buffer sizes to receive
117+ nrnmpi_int_alltoall (&snd_cnts[0 ], &rcv_cnts[0 ], 1 );
118+ for (int i = 1 ; i < nrnmpi_numprocs; i++) {
119+ rcv_dsps[i] = rcv_dsps[i - 1 ] + rcv_cnts[i - 1 ];
120+ }
121+ std::size_t new_sz = 0 ;
122+ for (const auto & r : rcv_cnts) {
123+ new_sz += r;
124+ }
125+ // prepare new sorted vectors
126+ std::vector<double > svt_buf (new_sz, 0.0 );
127+ std::vector<int > svg_buf (new_sz, 0 );
128+
129+ // now exchange data
130+ nrnmpi_dbl_alltoallv (spikevec_time.data (), &snd_cnts[0 ], &snd_dsps[0 ], svt_buf.data (),
131+ &rcv_cnts[0 ], &rcv_dsps[0 ]);
132+ nrnmpi_int_alltoallv (spikevec_gid.data (), &snd_cnts[0 ], &snd_dsps[0 ], svg_buf.data (),
133+ &rcv_cnts[0 ], &rcv_dsps[0 ]);
134+
135+ local_spikevec_sort (svt_buf, svg_buf, spikevec_time, spikevec_gid);
136+ }
137+
68138/* * Write generated spikes to out.dat using mpi parallel i/o.
69139 * \todo : MPI related code should be factored into nrnmpi.c
70140 * Check spike record length which is set to 64 chars
@@ -78,6 +148,7 @@ void output_spikes_parallel(const char* outpath) {
78148 if (nrnmpi_myid == 0 ) {
79149 remove (fname.c_str ());
80150 }
151+ sort_spikes (spikevec_time, spikevec_gid);
81152 nrnmpi_barrier ();
82153
83154 // each spike record in the file is time + gid (64 chars sufficient)
@@ -136,6 +207,11 @@ void output_spikes_serial(const char* outpath) {
136207 ss << outpath << " /out.dat" ;
137208 std::string fname = ss.str ();
138209
210+ // reserve some space for sorted spikevec buffers
211+ std::vector<double > sorted_spikevec_time (spikevec_time.size ());
212+ std::vector<int > sorted_spikevec_gid (spikevec_gid.size ());
213+ local_spikevec_sort (spikevec_time, spikevec_gid, sorted_spikevec_time, sorted_spikevec_gid);
214+
139215 // remove if file already exist
140216 remove (fname.c_str ());
141217
@@ -145,9 +221,9 @@ void output_spikes_serial(const char* outpath) {
145221 return ;
146222 }
147223
148- for (unsigned i = 0 ; i < spikevec_gid .size (); ++i)
149- if (spikevec_gid [i] > -1 )
150- fprintf (f, " %.8g\t %d\n " , spikevec_time [i], spikevec_gid [i]);
224+ for (std:: size_t i = 0 ; i < sorted_spikevec_gid .size (); ++i)
225+ if (sorted_spikevec_gid [i] > -1 )
226+ fprintf (f, " %.8g\t %d\n " , sorted_spikevec_time [i], sorted_spikevec_gid [i]);
151227
152228 fclose (f);
153229}
0 commit comments