Skip to content

Commit c3c97b9

Browse files
committed
re-enable folding
1 parent ede50c6 commit c3c97b9

File tree

9 files changed

+88
-77
lines changed

9 files changed

+88
-77
lines changed

include/bin_key.h

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,28 @@ struct bin_key
77
template <typename Derived1>
88
static std::set<block_key> run(
99
const block_key &key,
10-
const Eigen::MatrixBase<Derived1> &na,
11-
const bool);
10+
const Eigen::MatrixBase<Derived1> &na);
1211

1312
template <typename Derived1, typename Derived2>
1413
static std::set<block_key> run(
1514
const Eigen::MatrixBase<Derived1> &key,
16-
const Eigen::MatrixBase<Derived2> &na,
17-
const bool);
15+
const Eigen::MatrixBase<Derived2> &na);
1816
};
1917

2018
template <size_t P>
2119
template <typename Derived1>
2220
std::set<block_key> bin_key<P>::run(
2321
const block_key &key,
24-
const Eigen::MatrixBase<Derived1> &na,
25-
const bool enabled)
22+
const Eigen::MatrixBase<Derived1> &na)
2623
{
27-
return bin_key<P>::run(key.vals, na, enabled);
24+
return bin_key<P>::run(key.vals, na);
2825
}
2926

3027
template <>
3128
template <typename Derived1, typename Derived2>
3229
std::set<block_key> bin_key<1>::run(
3330
const Eigen::MatrixBase<Derived1> &key,
34-
const Eigen::MatrixBase<Derived2> &na,
35-
const bool enabled)
31+
const Eigen::MatrixBase<Derived2> &na)
3632
{
3733
Vector<int> tmp = key;
3834
std::set<block_key> init, ret;
@@ -45,32 +41,17 @@ std::set<block_key> bin_key<1>::run(
4541
}
4642
else
4743
init.emplace(tmp);
48-
if (not enabled)
49-
return init;
50-
for (const block_key &k : init)
51-
{
52-
const int nseg = k(0) + k(1);
53-
const int nb = k(2);
54-
for (int aa = std::max(0, nseg - nb); aa <= std::min(na(0), nseg); ++aa)
55-
{
56-
const int bb = nseg - aa;
57-
tmp(0) = aa;
58-
tmp(1) = bb;
59-
ret.emplace(tmp);
60-
}
61-
}
62-
return ret;
44+
return init;
6345
}
6446

6547
template <size_t P>
6648
template <typename Derived1, typename Derived2>
6749
std::set<block_key> bin_key<P>::run(
6850
const Eigen::MatrixBase<Derived1> &key,
69-
const Eigen::MatrixBase<Derived2> &na,
70-
const bool enabled)
51+
const Eigen::MatrixBase<Derived2> &na)
7152
{
72-
std::set<block_key> bk1 = bin_key<1>::run(key.head(3), na.head(1), enabled);
73-
std::set<block_key> bk2 = bin_key<P - 1>::run(key.tail(3 * (P - 1)), na.tail(P - 1), enabled);
53+
std::set<block_key> bk1 = bin_key<1>::run(key.head(3), na.head(1));
54+
std::set<block_key> bk2 = bin_key<P - 1>::run(key.tail(3 * (P - 1)), na.tail(P - 1));
7455
std::set<block_key> ret;
7556
Vector<int> v(3 * P);
7657
for (const block_key& b1 : bk1)

include/block_key.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ struct block_key
1010
Vector<int> vals;
1111

1212
int operator()(int k) const { return vals(k); }
13+
int& operator()(int k) { return vals.coeffRef(k); }
1314

1415
int size() const { return vals.size(); }
1516

include/inference_manager.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class InferenceManager
3030

3131
void setParams(const ParameterVector &params);
3232

33-
bool saveGamma, folded;
33+
bool saveGamma;
3434
std::vector<double> hidden_states;
3535
std::map<block_key, Vector<adouble> > emission_probs;
3636
std::vector<Matrix<double>*> getXisums();
@@ -87,11 +87,11 @@ class NPopInferenceManager : public InferenceManager
8787
const std::vector<int*> observations,
8888
const std::vector<double> hidden_states,
8989
ConditionedSFS<adouble> *csfs,
90-
const bool binning) :
90+
const bool fold) :
9191
InferenceManager(P,
9292
(na.tail(na.size() - 1).array() + 1).prod() * (n.array() + 1).prod(),
9393
obs_lengths, observations, hidden_states, csfs),
94-
n(n), na(na), tensordims(make_tensordims()), bins(construct_bins(binning))
94+
n(n), na(na), tensordims(make_tensordims()), bins(construct_bins(fold))
9595
{}
9696

9797
virtual ~NPopInferenceManager() = default;
@@ -100,6 +100,8 @@ class NPopInferenceManager : public InferenceManager
100100
protected:
101101
// Virtual overrides
102102
void recompute_emission_probs();
103+
bool is_monomorphic(const block_key&);
104+
block_key folded_key(const block_key&);
103105
FixedVector<int, 2 * P> make_tensordims();
104106
block_key bk_to_map_key(const block_key &bk);
105107

smcpp/_smcpp.pxd

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ cdef extern from "inference_manager.h":
5050
vector[adouble] Q() except +
5151
bool debug
5252
bool saveGamma
53-
bool folded
5453
vector[double] hidden_states
5554
vector[pMatrixD] getGammas()
5655
vector[pMatrixD] getXisums()

smcpp/_smcpp.pyx

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,6 @@ cdef class _PyInferenceManager:
147147
def __dealloc__(self):
148148
del self._im
149149

150-
property folded:
151-
def __get__(self):
152-
return self._im.folded
153-
154-
def __set__(self, bint f):
155-
self._im.folded = f
156-
157150
property observations:
158151
def __get__(self):
159152
return self._observations
@@ -299,11 +292,11 @@ cdef class _PyInferenceManager:
299292

300293
cdef class PyOnePopInferenceManager(_PyInferenceManager):
301294

302-
def __cinit__(self, int n, observations, hidden_states, im_id):
295+
def __cinit__(self, int n, observations, hidden_states, im_id, bool fold):
303296
# This is needed because cinit cannot be inherited
304297
self.__my_cinit__(observations, hidden_states, im_id)
305298
with nogil:
306-
self._im = new OnePopInferenceManager(n, self._Ls, self._obs_ptrs, self._hs, False)
299+
self._im = new OnePopInferenceManager(n, self._Ls, self._obs_ptrs, self._hs, fold)
307300

308301
@property
309302
def pid(self):
@@ -322,7 +315,7 @@ cdef class PyTwoPopInferenceManager(_PyInferenceManager):
322315
cdef TwoPopInferenceManager* _im2
323316
cdef int _a1
324317

325-
def __cinit__(self, int n1, int n2, int a1, int a2, observations, hidden_states, im_id):
318+
def __cinit__(self, int n1, int n2, int a1, int a2, observations, hidden_states, im_id, bool fold):
326319
# This is needed because cinit cannot be inherited
327320
assert a1 + a2 == 2
328321
assert a1 in [1, 2]
@@ -331,7 +324,7 @@ cdef class PyTwoPopInferenceManager(_PyInferenceManager):
331324
self.__my_cinit__(observations, hidden_states, im_id)
332325
assert a1 in [1, 2], "a2=2 is not supported"
333326
with nogil:
334-
self._im2 = new TwoPopInferenceManager(n1, n2, a1, a2, self._Ls, self._obs_ptrs, self._hs, False)
327+
self._im2 = new TwoPopInferenceManager(n1, n2, a1, a2, self._Ls, self._obs_ptrs, self._hs, fold)
335328
self._im = self._im2
336329

337330
@targets("model update")

smcpp/analysis.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def _init_hidden_states(self, prior_model, M):
154154
)
155155
logger.debug("%d hidden states:\n%s" % (len(self._hidden_states), str(self._hidden_states)))
156156

157-
def _init_inference_manager(self, folded):
157+
def _init_inference_manager(self, fold):
158158
## Create inference object which will be used for all further calculations.
159159
logger.debug("Creating inference manager...")
160160
d = {}
@@ -166,10 +166,12 @@ def _init_inference_manager(self, folded):
166166
k = (pid, n, a)
167167
data = [contig.data for contig in d[k]]
168168
if len(pid) == 1:
169-
im = _smcpp.PyOnePopInferenceManager(n[0], data, self._hidden_states, k)
169+
im = _smcpp.PyOnePopInferenceManager(n[0], data,
170+
self._hidden_states, k, fold)
170171
else:
171172
assert len(pid) == 2
172-
im = _smcpp.PyTwoPopInferenceManager(n[0], n[1], a[0], a[1], data, self._hidden_states, k)
173+
im = _smcpp.PyTwoPopInferenceManager(n[0], n[1], a[0], a[1],
174+
data, self._hidden_states, k, fold)
173175
im.model = self._model
174176
im.theta = self._theta
175177
im.rho = self._rho
@@ -276,7 +278,7 @@ def __init__(self, files, args):
276278

277279
if not args.no_initialize:
278280
self._hidden_states = np.array([0., np.inf])
279-
self._init_inference_manager(False)
281+
self._init_inference_manager(args.fold)
280282
self._init_optimizer(args, files, args.outdir, args.block_size,
281283
args.algorithm, args.tolerance, learn_rho=False)
282284
self._optimizer.run(1)
@@ -286,7 +288,7 @@ def __init__(self, files, args):
286288

287289
# Continue initializing
288290
self._init_hidden_states(args.prior_model, args.M)
289-
self._init_inference_manager(False)
291+
self._init_inference_manager(args.fold)
290292
self._init_optimizer(args, files, args.outdir, args.block_size,
291293
args.algorithm, args.tolerance, learn_rho=True)
292294

smcpp/commands/command.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ def add_common_estimation_args(parser):
2121
data.add_argument('--no-filter', help="do not drop contigs with extreme heterozygosity. "
2222
"(not recommended unless data set is small)",
2323
action="store_true", default=False)
24-
# data.add_argument("--folded", action="store_true", default=False,
25-
# help="use folded SFS for emission probabilites. "
26-
# "useful if polarization is not known.")
24+
data.add_argument("--fold", action="store_true", default=False,
25+
help="use folded SFS for emission probabilites. "
26+
"(if polarization is not known.)")
2727

2828
optimizer = parser.add_argument_group("Optimization parameters")
2929
optimizer.add_argument(

src/inference_manager.cpp

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ InferenceManager::InferenceManager(
2323
const std::vector<int*> observations,
2424
const std::vector<double> hidden_states,
2525
ConditionedSFS<adouble> *csfs) :
26-
saveGamma(false), folded(false),
26+
saveGamma(false),
2727
hidden_states(hidden_states),
2828
npop(npop),
2929
sfs_dim(sfs_dim),
@@ -271,7 +271,42 @@ block_key NPopInferenceManager<P>::bk_to_map_key(const block_key &bk)
271271
}
272272

273273
template <size_t P>
274-
std::map<block_key, std::map<block_key, double> > NPopInferenceManager<P>::construct_bins(const bool binning)
274+
bool NPopInferenceManager<P>::is_monomorphic(const block_key &bk)
275+
{
276+
for (unsigned int p = 0; p < P; ++p)
277+
{
278+
const int ind = 3 * p;
279+
if (bk(ind) != na(p) or bk(ind + 1) != bk(ind + 2))
280+
return false;
281+
}
282+
return true;
283+
}
284+
285+
template <size_t P>
286+
block_key NPopInferenceManager<P>::folded_key(const block_key &bk)
287+
{
288+
block_key ret = bk;
289+
for (unsigned int p = 0; p < P; ++p)
290+
{
291+
const int ind = 3 * p;
292+
ret(ind) = na(p) - bk(ind);
293+
ret(ind + 1) = bk(ind + 2) - bk(ind + 1);
294+
ret(ind + 2) = bk(ind + 2);
295+
}
296+
if (is_monomorphic(ret))
297+
{
298+
for (unsigned int p = 0; p < P; ++p)
299+
{
300+
const int ind = 3 * p;
301+
ret(ind) = 0;
302+
ret(ind + 1) = ret(ind + 2);
303+
}
304+
}
305+
return ret;
306+
}
307+
308+
template <size_t P>
309+
std::map<block_key, std::map<block_key, double> > NPopInferenceManager<P>::construct_bins(const bool fold)
275310
{
276311
std::map<block_key, std::map<block_key, double> > ret;
277312
for (auto ob : obs)
@@ -283,7 +318,14 @@ std::map<block_key, std::map<block_key, double> > NPopInferenceManager<P>::const
283318
if (ret.count(bk) == 0)
284319
{
285320
std::map<block_key, double> m;
286-
for (const block_key &kbin : bin_key<P>::run(bk, na, binning))
321+
std::set<block_key> new_keys;
322+
for (const block_key &k : bin_key<P>::run(bk, na))
323+
{
324+
new_keys.emplace(k);
325+
if (fold)
326+
new_keys.emplace(folded_key(k));
327+
}
328+
for (const block_key &kbin : new_keys)
287329
for (const auto &p : marginalize_key<P>::run(kbin.vals, n, na))
288330
m[bk_to_map_key(p.first)] += p.second;
289331
ret[bk] = m;
@@ -341,23 +383,6 @@ void NPopInferenceManager<P>::recompute_emission_probs()
341383
#pragma omp parallel for
342384
for (auto it = bpm_keys.begin(); it < bpm_keys.end(); ++it)
343385
{
344-
// std::set<block_key> keys;
345-
// keys.insert(key);
346-
/*
347-
if (this->folded)
348-
{
349-
Vector<int> new_key(it->size());
350-
for (size_t p = 0; p < P; ++p)
351-
{
352-
int a = key(3 * p);
353-
int b = key(3 * p + 1);
354-
int nb = key(3 * p + 2);
355-
new_key(1 + 2 * p) = nb - b;
356-
new_key(2 + 2 * p) = nb;
357-
}
358-
keys.emplace(new_key);
359-
}
360-
*/
361386
const block_key k = *it;
362387
std::array<std::set<FixedVector<int, 3> >, P> s;
363388
Vector<adouble> tmp(M);
@@ -379,8 +404,12 @@ void NPopInferenceManager<P>::recompute_emission_probs()
379404
tmp = e2.col(a.sum() % 2);
380405
}
381406
else
407+
{
382408
for (const auto &p : bins.at(k))
409+
{
383410
tmp += p.second * tensorRef(p.first);
411+
}
412+
}
384413
if (tmp.maxCoeff() > 1.0 or tmp.minCoeff() <= 0.0)
385414
{
386415
std::cout << k << std::endl;
@@ -421,13 +450,13 @@ OnePopInferenceManager::OnePopInferenceManager(
421450
const std::vector<int> obs_lengths,
422451
const std::vector<int*> observations,
423452
const std::vector<double> hidden_states,
424-
const bool binning) :
453+
const bool fold) :
425454
NPopInferenceManager(
426455
FixedVector<int, 1>::Constant(n),
427456
FixedVector<int, 1>::Constant(2),
428457
obs_lengths, observations, hidden_states,
429458
new OnePopConditionedSFS<adouble>(n),
430-
binning) {}
459+
fold) {}
431460

432461
JointCSFS<adouble>* create_jcsfs(int n1, int n2, int a1, int a2, const std::vector<double> &hidden_states)
433462
{
@@ -442,13 +471,13 @@ TwoPopInferenceManager::TwoPopInferenceManager(
442471
const std::vector<int> obs_lengths,
443472
const std::vector<int*> observations,
444473
const std::vector<double> hidden_states,
445-
const bool binning) :
474+
const bool fold) :
446475
NPopInferenceManager(
447476
(FixedVector<int, 2>() << n1, n2).finished(),
448477
(FixedVector<int, 2>() << a1, a2).finished(),
449478
obs_lengths, observations, hidden_states,
450479
create_jcsfs(n1, n2, a1, a2, hidden_states),
451-
binning), a1(a1), a2(a2)
480+
fold), a1(a1), a2(a2)
452481
{
453482
if (a1 + a2 != 2)
454483
throw std::runtime_error("configuration not supported");

test/integration/test.sh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@ set -e
44
$SMC vcf2smc example/example.vcf.gz /tmp/example.1.smc.gz 1 msp1:msp_0,msp_1
55
$SMC vcf2smc example/example.vcf.gz /tmp/example.2.smc.gz 1 msp2:msp_2
66
$SMC vcf2smc example/example.vcf.gz /tmp/example.12.smc.gz 1 msp1:msp_0,msp_1 msp2:msp_2
7-
$SMC estimate -o /tmp/out/1 --theta .00025 --em-iterations 1 /tmp/example.1.smc.gz
7+
$SMC estimate -o /tmp/out/1 --theta .00025 --fold --em-iterations 1 /tmp/example.1.smc.gz
88
$SMC estimate -o /tmp/out/2 --theta .00025 --em-iterations 1 /tmp/example.2.smc.gz
9-
$SMC estimate -o /tmp/out/12 --theta .00025 --em-iterations 1 /tmp/example.12.smc.gz
9+
$SMC estimate -o /tmp/out/12 --fold --theta .00025 --em-iterations 1 /tmp/example.12.smc.gz
1010
$SMC split -o /tmp/out/split --em-iterations 1 \
1111
/tmp/out/1/model.final.json \
1212
/tmp/out/2/model.final.json \
1313
/tmp/example.*.smc.gz
14+
$SMC split --fold -o /tmp/out/split --em-iterations 1 \
15+
/tmp/out/1/model.final.json \
16+
/tmp/out/2/model.final.json \
17+
/tmp/example.*.smc.gz
1418
$SMC plot -c -g 29 --logy /tmp/1.png /tmp/out/1/model.final.json
1519
$SMC plot /tmp/2.pdf /tmp/out/2/model.final.json
1620
$SMC plot -c --logy /tmp/12.png /tmp/out/12/model.final.json

0 commit comments

Comments
 (0)