Skip to content

Commit 98cbe57

Browse files
authored
Merge pull request #24 from iqtree/piyumal
Fix issue for MCMCTree dating when there is a leaf at the first child of the super tree, and format code for Hessian calculation.
2 parents 1b00526 + 44a70a0 commit 98cbe57

File tree

3 files changed

+85
-106
lines changed

3 files changed

+85
-106
lines changed

main/phyloanalysis.cpp

Lines changed: 78 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -2862,38 +2862,31 @@ void printMCMCFileFormat(PhyloTree *tree, MatrixXd &hessian, stringstream &tree_
28622862
counter++;
28632863
}
28642864
// here we check whether the first child of the partition tree is a leaf in the case of missing taxa.
2865-
if (leftSingle && counter==1 && (nei->link_neighbors[part_id]->node->neighbors.size()==1 || nei1->link_neighbors[part_id]->node->neighbors.size()==1))
2866-
{
2865+
if (leftSingle && counter == 1 && (nei->link_neighbors[part_id]->node->neighbors.size() == 1 || nei1->
2866+
link_neighbors[part_id]->node->neighbors.size() == 1)){
28672867
leftSingleRoot = true;
28682868
counter++;
28692869
}
28702870
}
28712871
// Note: if we get an unrooted tree whose left child is a leaf, we need to push it to the back as the last branch
28722872
// This is a special case in MCMCTree
2873-
if (leftSingleRoot)
2874-
{
2873+
if (leftSingleRoot){
28752874
auto partBrM = partBrMmap.front();
28762875
vector<pair<int, int>> partBrMapleftSingle;
28772876
vector<pair<int, int>> partBrMmap2;
28782877
int leftSinglePartId = partBrM.second;
28792878

2880-
for (auto it : partBrMmap)
2881-
{
2882-
if (it.second != leftSinglePartId)
2883-
{
2879+
// Note: multiples of same id could be mapped to the original tree
2880+
for (auto it : partBrMmap){
2881+
if (it.second != leftSinglePartId){
28842882
partBrMmap2.emplace_back(it.first, it.second);
2885-
}
2886-
else
2887-
{
2883+
}else{
28882884
partBrMapleftSingle.emplace_back(it.first, it.second);
28892885
}
28902886
}
2891-
2892-
for (auto it : partBrMapleftSingle)
2893-
{
2887+
for (auto it : partBrMapleftSingle){
28942888
partBrMmap2.emplace_back(it.first, it.second);
28952889
}
2896-
28972890
partBrMmap = partBrMmap2;
28982891
}
28992892
int numStates = tree->getModel()->num_states;
@@ -2910,21 +2903,17 @@ void printMCMCFileFormat(PhyloTree *tree, MatrixXd &hessian, stringstream &tree_
29102903
auto *hessian_diagonal_part = aligned_alloc<double>(branchNum);
29112904
memset(hessian_diagonal_part, 0, branchNum * sizeof(double));
29122905

2913-
// double *G_matrix_sub_tree = tree->G_matrix;
2906+
double *G_matrix_sub_tree = tree->G_matrix;
29142907
if (superTree->params->partition_type != BRLEN_OPTIMIZE) {
29152908
for (auto mapping: partBrMmap) {
29162909
int stree_branch_id = mapping.first;
29172910
int part_branch_id = mapping.second;
2918-
for (int i = 0; i < nPtn; i++) {
2919-
G_matrix_part[stree_branch_id * nPtn + i] = tree->G_matrix[part_branch_id * nPtn + i];
2920-
}
2921-
//memcpy(G_matrix_part+(sizeof(double)*stree_branch_id*nPtn), G_matrix_sub_tree+sizeof(double)*(part_branch_id*nPtn), sizeof(double)*nPtn);
2911+
memmove(G_matrix_part+stree_branch_id*nPtn, G_matrix_sub_tree+part_branch_id*nPtn, sizeof(double)*nPtn);
29222912
gradient_vector_part[stree_branch_id] = tree->gradient_vector[part_branch_id];
29232913
hessian_diagonal_part[stree_branch_id] = tree->hessian_diagonal[part_branch_id];
29242914
}
29252915
saveTreeMCMCTree(branchLengths, branch_lengths_vector, tree, tree_stream);
29262916
} else {
2927-
29282917
DoubleVector branchLengths2;
29292918
int branchCounter = 0;
29302919
std::unordered_map<int, int> brVisitedMap;
@@ -2935,29 +2924,65 @@ void printMCMCFileFormat(PhyloTree *tree, MatrixXd &hessian, stringstream &tree_
29352924
continue;
29362925
}
29372926
brVisitedMap[part_branch_id] = 1;
2938-
for (int i = 0; i < nPtn; i++) {
2939-
G_matrix_part[branchCounter * nPtn + i] = tree->G_matrix[part_branch_id * nPtn + i];
2940-
}
2941-
//memcpy(G_matrix_part+(sizeof(double)*stree_branch_id*nPtn), G_matrix_sub_tree+sizeof(double)*(part_branch_id*nPtn), sizeof(double)*nPtn);
2927+
memmove(G_matrix_part+branchCounter*nPtn, G_matrix_sub_tree+part_branch_id*nPtn, sizeof(double)*nPtn);
29422928
gradient_vector_part[branchCounter] = tree->gradient_vector[part_branch_id];
29432929
hessian_diagonal_part[branchCounter] = tree->hessian_diagonal[part_branch_id];
29442930
branchLengths2.push_back(branchLengths[part_branch_id]);
29452931
branchCounter++;
29462932
}
29472933
saveTreeMCMCTree(branchLengths2, branch_lengths_vector, tree, tree_stream);
29482934
}
2949-
29502935
processDervMCMCTree(gradient_vector_part, branchNum, nPtn, G_matrix_part, ptn_freq_diagonal,
29512936
gradient_vector_eigen, hessian, hessian_diagonal_part);
29522937
aligned_free(G_matrix_part);
29532938
aligned_free(gradient_vector_part);
29542939
aligned_free(hessian_diagonal_part);
29552940

29562941
} else {
2957-
processDervMCMCTree(tree->gradient_vector, nBranches, nPtn, tree->G_matrix, ptn_freq_diagonal,
2958-
gradient_vector_eigen, hessian, tree->hessian_diagonal);
2959-
saveTreeMCMCTree(branchLengths, branch_lengths_vector, tree, tree_stream);
29602942

2943+
int numStates = tree->getModel()->num_states;
2944+
size_t mem_size = get_safe_upper_limit(tree->getAlnNSite()) + max(get_safe_upper_limit(numStates),
2945+
get_safe_upper_limit(tree->getModelFactory()->unobserved_ptns.size()));
2946+
BranchVector singleAlnBranches;
2947+
tree->getBranches(singleAlnBranches);
2948+
vector<int> branch_ids;
2949+
for (Branch branch: singleAlnBranches){
2950+
auto nei = branch.first->findNeighbor(branch.second);
2951+
branch_ids.push_back(nei->id);
2952+
}
2953+
if (leftSingle){
2954+
int single_root = branch_ids.front();
2955+
branch_ids.erase(branch_ids.begin());
2956+
branch_ids.push_back(single_root);
2957+
}
2958+
size_t branchNum = tree->branchNum;
2959+
size_t g_matrix_size = branchNum * mem_size;
2960+
auto *G_matrix_new = aligned_alloc<double>(g_matrix_size);
2961+
memset(G_matrix_new, 0, g_matrix_size * sizeof(double));
2962+
auto *gradient_vector_new = aligned_alloc<double>(branchNum);
2963+
memset(gradient_vector_new, 0, branchNum * sizeof(double));
2964+
auto *hessian_diagonal_new = aligned_alloc<double>(branchNum);
2965+
memset(hessian_diagonal_new, 0, branchNum * sizeof(double));
2966+
2967+
double *G_matrix_tree = tree->G_matrix;
2968+
int br_counter = 0;
2969+
DoubleVector branchLengthsNew;
2970+
for (int br_id: branch_ids) {
2971+
memmove(G_matrix_new+br_counter*nPtn, G_matrix_tree+br_id*nPtn, sizeof(double)*nPtn);
2972+
gradient_vector_new[br_counter] = tree->gradient_vector[br_id];
2973+
hessian_diagonal_new[br_counter] = tree->hessian_diagonal[br_id];
2974+
branchLengthsNew.push_back(branchLengths[br_id]);
2975+
br_counter++;
2976+
}
2977+
2978+
2979+
processDervMCMCTree(gradient_vector_new, nBranches, nPtn, G_matrix_new, ptn_freq_diagonal,
2980+
gradient_vector_eigen, hessian, hessian_diagonal_new);
2981+
saveTreeMCMCTree(branchLengthsNew, branch_lengths_vector, tree, tree_stream);
2982+
2983+
aligned_free(G_matrix_new);
2984+
aligned_free(gradient_vector_new);
2985+
aligned_free(hessian_diagonal_new);
29612986
}
29622987
}
29632988

@@ -2968,7 +2993,7 @@ void generateDummyAlignment(PhyloTree* tree, ofstream &dummyAlignment){
29682993
// auto alignment = tree->aln->getPattern(0);
29692994
tree->getOrderedTaxa(nodeVector);
29702995
dummyAlignment << " " << nodeVector.size() << " " << 1 << endl;
2971-
for (Node *node: nodeVector) {
2996+
for (const Node *node: nodeVector) {
29722997
dummyAlignment << node->name << " " << "A" << endl;
29732998
}
29742999
dummyAlignment << endl;
@@ -2982,7 +3007,6 @@ void printMCMCTreeCtlFile(IQTree *iqtree, ofstream &ctl, ofstream &dummyAlignmen
29823007
} else {
29833008
ndata = 1;
29843009
}
2985-
29863010
StrVector mcmc_iter_vec;
29873011
convert_string_vec(Params::getInstance().mcmc_iter.c_str(), mcmc_iter_vec, ',');
29883012
ctl << "seed = -1" << endl
@@ -3094,38 +3118,10 @@ void printHessian(IQTree *iqtree, int partition_type) {
30943118
RowVectorXd partition_branch_lengths_vector;
30953119
RowVectorXd partition_gradient_vector_eigen;
30963120

3097-
if (it->traversal_starting_node)
3098-
{
3121+
if (it->traversal_starting_node){
30993122
auto part_traversal_starting_nei = (Neighbor*)(((Node*)it->traversal_starting_node)->neighbors[0]->node)
31003123
->findNeighbor((Node*)it->traversal_starting_node);
31013124
it->root = part_traversal_starting_nei->node;
3102-
3103-
// This is special case in MCMCtree: if the fist son of root has only one taxa, then select the next neighbour.
3104-
// if (part_traversal_starting_nei->node->neighbors.size() == 1 && it->leftSingleRoot)
3105-
// {
3106-
// // leftSingle = true;
3107-
// part_traversal_starting_nei = (PhyloNeighbor*)((part_traversal_starting_nei->node)->
3108-
// findNeighbor((part_traversal_starting_nei->node->neighbors[0]->node)));
3109-
// it->root = part_traversal_starting_nei->node;
3110-
//
3111-
// NeighborVec neighbors = part_traversal_starting_nei->node->neighbors;
3112-
// auto nei1 = part_traversal_starting_nei->node->neighbors[0];
3113-
// auto nei2 = part_traversal_starting_nei->node->neighbors[1];
3114-
// auto nei3 = part_traversal_starting_nei->node->neighbors[2];
3115-
// if (part_traversal_starting_nei->node->neighbors[0]->node == it->traversal_starting_node)
3116-
// {
3117-
// part_traversal_starting_nei->node->neighbors[0] = nei2;
3118-
// part_traversal_starting_nei->node->neighbors[1] = nei3;
3119-
// part_traversal_starting_nei->node->neighbors[2] = nei1;
3120-
// }
3121-
// else if (part_traversal_starting_nei->node->neighbors[1]->node == it->traversal_starting_node)
3122-
// {
3123-
// part_traversal_starting_nei->node->neighbors[0] = nei1;
3124-
// part_traversal_starting_nei->node->neighbors[1] = nei3;
3125-
// part_traversal_starting_nei->node->neighbors[2] = nei2;
3126-
// }
3127-
//
3128-
// }
31293125
}
31303126

31313127
printMCMCFileFormat(it, partition_hessian, partition_tree_stream, partition_branch_lengths_vector,
@@ -3179,36 +3175,20 @@ void printHessian(IQTree *iqtree, int partition_type) {
31793175
size_t nBranches = iqtree->branchNum;
31803176

31813177
// iqtree->sortTaxa();
3182-
3183-
if (iqtree->traversal_starting_node)
3184-
{
3178+
if (iqtree->traversal_starting_node){
31853179
auto traversal_starting_nei = (Neighbor*)(((Node*)iqtree->traversal_starting_node)->neighbors[0]->node)
31863180
->findNeighbor((Node*)iqtree->traversal_starting_node);
31873181
iqtree->root = traversal_starting_nei->node;
3188-
iqtree->sortTaxa();
3182+
// iqtree->sortTaxa();
31893183
iqtree->initializeTree();
31903184

31913185
// This is a special case in MCMCTree: if the first son of the root has only one taxon, then select the next neighbour.
3192-
if (traversal_starting_nei->node->neighbors.size() == 1)
3193-
{
3194-
traversal_starting_nei = (PhyloNeighbor*)((traversal_starting_nei->node)->findNeighbor(
3195-
(traversal_starting_nei->node->neighbors[0]->node)));
3196-
iqtree->root = traversal_starting_nei->node;
3197-
// iqtree->sortTaxa();
3198-
// iqtree->initializeTree();
3199-
NeighborVec neighbors = traversal_starting_nei->node->neighbors;
3200-
auto nei1 = traversal_starting_nei->node->neighbors[0];
3201-
auto nei2 = traversal_starting_nei->node->neighbors[1];
3202-
auto nei3 = traversal_starting_nei->node->neighbors[2];
3203-
traversal_starting_nei->node->neighbors[0] = nei2;
3204-
traversal_starting_nei->node->neighbors[1] = nei3;
3205-
traversal_starting_nei->node->neighbors[2] = nei1;
3206-
// iqtree->sortTaxa();
3207-
iqtree->initializeTree();
3186+
if (traversal_starting_nei->node->neighbors.size() == 1){
3187+
iqtree->leftSingleRoot = true;
32083188
}
32093189
}
32103190

3211-
printMCMCFileFormat(iqtree, hessian, tree_stream, branch_lengths_vector, gradient_vector_eigen);
3191+
printMCMCFileFormat(iqtree, hessian, tree_stream, branch_lengths_vector, gradient_vector_eigen, NULL, 0, iqtree->leftSingleRoot);
32123192

32133193
outfile << endl << iqtree->aln->getNSeq() << endl << endl;
32143194
outfile << tree_stream.str() << endl << endl;
@@ -3225,7 +3205,7 @@ void printHessian(IQTree *iqtree, int partition_type) {
32253205
alnFile.close();
32263206

32273207
cout << endl << "Gradients and Hessians written to: " << outFileName << endl;
3228-
cout << "Ctl file for MCMCTree written to: " << ctlFileName << endl;
3208+
cout << "Ctl file for MCMCTree written to: " << ctlFileName << endl << endl;
32293209
// cout << "Add time records calibrations to: " << iqtree->params->user_file << endl;
32303210
}
32313211

@@ -3260,12 +3240,10 @@ SuperNeighbor* findRootedNeighbour(SuperNeighbor* super_root, int part_id) {
32603240
if (!super_root) {
32613241
return nullptr;
32623242
}
3263-
32643243
// Check if the root itself satisfies the condition
32653244
if (super_root->link_neighbors[part_id]) {
32663245
return super_root;
32673246
}
3268-
32693247
// Use a queue for level-order traversal
32703248
std::queue<SuperNeighbor*> q;
32713249
q.push(super_root);
@@ -3274,19 +3252,16 @@ SuperNeighbor* findRootedNeighbour(SuperNeighbor* super_root, int part_id) {
32743252
while (!q.empty()) {
32753253
auto current = q.front();
32763254
q.pop();
3277-
32783255
// Iterate over the current node's neighbors
32793256
for (auto nei : current->node->neighbors) {
32803257
auto super_nei = dynamic_cast<SuperNeighbor*>(nei);
32813258
if (!super_nei) {
32823259
continue; // Skip invalid neighbors
32833260
}
3284-
32853261
// Check if this neighbor satisfies the condition
32863262
if (super_nei->link_neighbors[part_id]) {
32873263
return super_nei;
32883264
}
3289-
32903265
// Add the neighbor to the queue for further exploration
32913266
q.push(super_nei);
32923267
}
@@ -3355,25 +3330,23 @@ void startTreeReconstruction(Params &params, IQTree* &iqtree, ModelCheckpoint &m
33553330
// if users want to perform tree dating (with mcmc)
33563331
// and if ModelFinder was run, the traversal starting node was incidentally deleted (after copyTree and restoreCheckpoint)
33573332
// we have to delete tree nodes to force IQ-TREE to re-read the tree from the treefile
3358-
if (params.dating_method == "mcmctree" && params.dating_mf)
3359-
{
3333+
3334+
if (params.dating_method == "mcmctree" && params.dating_mf){
33603335
// if it's a supertree, delete all tree members
3361-
if (iqtree->isSuperTree())
3362-
{
3336+
if (iqtree->isSuperTree()){
33633337
PhyloSuperTree* stree = (PhyloSuperTree*) iqtree;
3338+
33643339
// delete member trees one by one
3365-
for (PhyloSuperTree::iterator it = stree->begin(); it != stree->end(); it++)
3366-
{
3367-
if ((*it)->root)
3368-
{
3340+
for (PhyloSuperTree::iterator it = stree->begin(); it != stree->end(); it++){
3341+
if ((*it)->root){
3342+
33693343
(*it)->freeNode();
33703344
(*it)->root = NULL;
33713345
}
33723346
}
33733347
}
33743348
// delete the tree itself
3375-
if (iqtree->root)
3376-
{
3349+
if (iqtree->root){
33773350
iqtree->freeNode();
33783351
iqtree->root = NULL;
33793352
}
@@ -3960,15 +3933,14 @@ void runTreeReconstruction(Params &params, IQTree* &iqtree) {
39603933
//check for dating with MCMCTree
39613934
if (params.dating_method == "mcmctree")
39623935
{
3963-
39643936
cout << endl << "--- Generating the gradients and the Hessian for MCMCTree ---" << endl;
39653937
if (iqtree->isSuperTree())
39663938
{
39673939
auto* stree = (PhyloSuperTree*)iqtree;
39683940
int part_id = 0;
39693941
for (auto& it : *stree)
39703942
{
3971-
auto* partition_tree = (PhyloTree*)it;
3943+
auto* partition_tree = it;
39723944
bool leftSingleRoot = false;
39733945

39743946
// If we memorized the traversal starting node -> find the corresponding traversal starting node for the current partition
@@ -3982,13 +3954,17 @@ void runTreeReconstruction(Params &params, IQTree* &iqtree) {
39823954
// identify whether there is missing taxa at the root. This is needed if we have a clade of 2 taxa at
39833955
// the root in the rooted tree. When unrooted, they become two children of the root which are leaves.
39843956
// if leftSingleRoot == true, that means the leaves at the root are not from a 2 leaves clade.
3985-
for(auto linkedNei: nei->node->neighbors){
3957+
for (auto linkedNei : nei->node->neighbors)
3958+
{
39863959
auto superLinkedNei = (SuperNeighbor*)linkedNei;
3987-
if(!superLinkedNei->link_neighbors[part_id]){
3960+
if (!superLinkedNei->link_neighbors[part_id])
3961+
{
39883962
leftSingleRoot = true;
3963+
}else
3964+
{
3965+
it->root_available = false;
39893966
}
39903967
}
3991-
39923968
// this is the other case where we don't miss any nodes at the root but a leaf at the root
39933969
if (!leftSingleRoot && linked_super_nei->node->neighbors.size() == 1)
39943970
{

main/timetree.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -367,9 +367,8 @@ void computeHessian(PhyloTree *tree) {
367367
if (tree->traversal_starting_node && tree->root != tree->traversal_starting_node){
368368
tree->root = (Node *) tree->traversal_starting_node;
369369
}
370-
371370
// sort the internal nodes according to their smallest taxon id
372-
// tree->sortTaxa();
371+
// tree->sortTaxa();
373372
tree->clearBranchDirection();
374373
tree->initializeTree();
375374
tree->computeBranchDirection();
@@ -388,7 +387,6 @@ void computeHessian(PhyloTree *tree) {
388387
cout << "lh: " << lh << " df: " << df << " ddf: " << ddf << endl;
389388
}
390389
}
391-
392390
}
393391

394392
void runMCMCTree(PhyloTree *tree) {

tree/phylotree.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ class PhyloTree : public MTree, public Optimization, public CheckpointFactory {
841841
int num_packets;
842842

843843
/** flag to identify partition-trees with missing root for MCMCTree branch traversal order*/
844-
bool leftSingleRoot;
844+
bool leftSingleRoot = false;
845845

846846
/****************************************************************************
847847
helper functions for computing tree traversal
@@ -1586,6 +1586,11 @@ class PhyloTree : public MTree, public Optimization, public CheckpointFactory {
15861586
* */
15871587
double *hessian_diagonal;
15881588

1589+
/**
1590+
Flag to check root node availability in case of missing data in partitions
1591+
* */
1592+
bool root_available = true;
1593+
15891594

15901595
/****************************************************************************
15911596
Nearest Neighbor Interchange by maximum likelihood

0 commit comments

Comments
 (0)