@@ -23,7 +23,7 @@ InferenceManager::InferenceManager(
2323 const std::vector<int *> observations,
2424 const std::vector<double > hidden_states,
2525 ConditionedSFS<adouble> *csfs) :
26- saveGamma(false ), folded( false ),
26+ saveGamma(false ),
2727 hidden_states(hidden_states),
2828 npop(npop),
2929 sfs_dim(sfs_dim),
@@ -271,7 +271,42 @@ block_key NPopInferenceManager<P>::bk_to_map_key(const block_key &bk)
271271}
272272
273273template <size_t P>
274- std::map<block_key, std::map<block_key, double > > NPopInferenceManager<P>::construct_bins(const bool binning)
274+ bool NPopInferenceManager<P>::is_monomorphic(const block_key &bk)
275+ {
276+ for (unsigned int p = 0 ; p < P; ++p)
277+ {
278+ const int ind = 3 * p;
279+ if (bk (ind) != na (p) or bk (ind + 1 ) != bk (ind + 2 ))
280+ return false ;
281+ }
282+ return true ;
283+ }
284+
285+ template <size_t P>
286+ block_key NPopInferenceManager<P>::folded_key(const block_key &bk)
287+ {
288+ block_key ret = bk;
289+ for (unsigned int p = 0 ; p < P; ++p)
290+ {
291+ const int ind = 3 * p;
292+ ret (ind) = na (p) - bk (ind);
293+ ret (ind + 1 ) = bk (ind + 2 ) - bk (ind + 1 );
294+ ret (ind + 2 ) = bk (ind + 2 );
295+ }
296+ if (is_monomorphic (ret))
297+ {
298+ for (unsigned int p = 0 ; p < P; ++p)
299+ {
300+ const int ind = 3 * p;
301+ ret (ind) = 0 ;
302+ ret (ind + 1 ) = ret (ind + 2 );
303+ }
304+ }
305+ return ret;
306+ }
307+
308+ template <size_t P>
309+ std::map<block_key, std::map<block_key, double > > NPopInferenceManager<P>::construct_bins(const bool fold)
275310{
276311 std::map<block_key, std::map<block_key, double > > ret;
277312 for (auto ob : obs)
@@ -283,7 +318,14 @@ std::map<block_key, std::map<block_key, double> > NPopInferenceManager<P>::const
283318 if (ret.count (bk) == 0 )
284319 {
285320 std::map<block_key, double > m;
286- for (const block_key &kbin : bin_key<P>::run (bk, na, binning))
321+ std::set<block_key> new_keys;
322+ for (const block_key &k : bin_key<P>::run (bk, na))
323+ {
324+ new_keys.emplace (k);
325+ if (fold)
326+ new_keys.emplace (folded_key (k));
327+ }
328+ for (const block_key &kbin : new_keys)
287329 for (const auto &p : marginalize_key<P>::run (kbin.vals , n, na))
288330 m[bk_to_map_key (p.first )] += p.second ;
289331 ret[bk] = m;
@@ -341,23 +383,6 @@ void NPopInferenceManager<P>::recompute_emission_probs()
341383#pragma omp parallel for
342384 for (auto it = bpm_keys.begin (); it < bpm_keys.end (); ++it)
343385 {
344- // std::set<block_key> keys;
345- // keys.insert(key);
346- /*
347- if (this->folded)
348- {
349- Vector<int> new_key(it->size());
350- for (size_t p = 0; p < P; ++p)
351- {
352- int a = key(3 * p);
353- int b = key(3 * p + 1);
354- int nb = key(3 * p + 2);
355- new_key(1 + 2 * p) = nb - b;
356- new_key(2 + 2 * p) = nb;
357- }
358- keys.emplace(new_key);
359- }
360- */
361386 const block_key k = *it;
362387 std::array<std::set<FixedVector<int , 3 > >, P> s;
363388 Vector<adouble> tmp (M);
@@ -379,8 +404,12 @@ void NPopInferenceManager<P>::recompute_emission_probs()
379404 tmp = e2 .col (a.sum () % 2 );
380405 }
381406 else
407+ {
382408 for (const auto &p : bins.at (k))
409+ {
383410 tmp += p.second * tensorRef (p.first );
411+ }
412+ }
384413 if (tmp.maxCoeff () > 1.0 or tmp.minCoeff () <= 0.0 )
385414 {
386415 std::cout << k << std::endl;
@@ -421,13 +450,13 @@ OnePopInferenceManager::OnePopInferenceManager(
421450 const std::vector<int > obs_lengths,
422451 const std::vector<int *> observations,
423452 const std::vector<double > hidden_states,
424- const bool binning ) :
453+ const bool fold ) :
425454 NPopInferenceManager(
426455 FixedVector<int , 1 >::Constant(n),
427456 FixedVector<int, 1>::Constant(2 ),
428457 obs_lengths, observations, hidden_states,
429458 new OnePopConditionedSFS<adouble>(n),
430- binning ) {}
459+ fold ) {}
431460
432461JointCSFS<adouble>* create_jcsfs (int n1, int n2, int a1, int a2, const std::vector<double > &hidden_states)
433462{
@@ -442,13 +471,13 @@ TwoPopInferenceManager::TwoPopInferenceManager(
442471 const std::vector<int > obs_lengths,
443472 const std::vector<int *> observations,
444473 const std::vector<double > hidden_states,
445- const bool binning ) :
474+ const bool fold ) :
446475 NPopInferenceManager(
447476 (FixedVector<int , 2 >() << n1, n2).finished(),
448477 (FixedVector<int , 2 >() << a1, a2).finished(),
449478 obs_lengths, observations, hidden_states,
450479 create_jcsfs(n1, n2, a1, a2, hidden_states),
451- binning ), a1(a1), a2(a2)
480+ fold ), a1(a1), a2(a2)
452481{
453482 if (a1 + a2 != 2 )
454483 throw std::runtime_error (" configuration not supported" );
0 commit comments