Skip to content

Commit da3b66a

Browse files
y3tseng
authored and committed
improve speed for progressive alignment
1 parent bd741b9 commit da3b66a

File tree

4 files changed

+71
-7
lines changed

4 files changed

+71
-7
lines changed

src/alignment-helper.cpp

Lines changed: 24 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -376,12 +376,11 @@ void msa::alignment_helper::addGappyColumnsBack(alnPath &aln_before, alnPath &al
376376
void msa::alignment_helper::updateAlignment(NodePair &nodes, SequenceDB *database, Option* option, alnPath &aln)
377377
{
378378
int totalLen = aln.size(), startLen = 0, endLen = 0;
379-
bool updateSeq = (database->currentTask != 2);
380379
tbb::this_task_arena::isolate([&] {
381380
tbb::parallel_for(tbb::blocked_range<int>(0, nodes.first->seqsIncluded.size()), [&](tbb::blocked_range<int> range) {
382381
for (int idx = range.begin(); idx < range.end(); ++idx) {
383382
int sIdx = nodes.first->seqsIncluded[idx];
384-
if (updateSeq) {
383+
if (database->currentTask != 2 && sIdx >= 0) {
385384
database->id_map[sIdx]->memCheck(totalLen);
386385
int storeFrom = database->id_map[sIdx]->storage;
387386
int storeTo = 1 - storeFrom;
@@ -420,7 +419,7 @@ void msa::alignment_helper::updateAlignment(NodePair &nodes, SequenceDB *databas
420419
tbb::parallel_for(tbb::blocked_range<int>(0, nodes.second->seqsIncluded.size()), [&](tbb::blocked_range<int> range) {
421420
for (int idx = range.begin(); idx < range.end(); ++idx) {
422421
int sIdx = nodes.second->seqsIncluded[idx];
423-
if (updateSeq) {
422+
if (database->currentTask != 2 && sIdx >= 0) {
424423
database->id_map[sIdx]->memCheck(totalLen);
425424
int orgIdx = 0;
426425
int storeFrom = database->id_map[sIdx]->storage;
@@ -455,11 +454,31 @@ void msa::alignment_helper::updateAlignment(NodePair &nodes, SequenceDB *databas
455454
}
456455
});
457456
});
458-
for (auto idx: nodes.second->seqsIncluded) nodes.first->seqsIncluded.push_back(idx);
459457
nodes.first->alnNum += nodes.second->alnNum;
460-
nodes.second->seqsIncluded.clear();
461458
nodes.first->alnLen = totalLen;
462459
nodes.first->alnWeight += nodes.second->alnWeight;
460+
for (auto idx: nodes.second->seqsIncluded) nodes.first->seqsIncluded.push_back(idx);
461+
nodes.second->seqsIncluded.clear();
462+
if (nodes.first->seqsIncluded.size() > _UPDATE_SEQ_TH && !nodes.first->msaFreq.empty() && database->currentTask != 2) {
463+
int seqCount = 0, firstSeqID = 0;
464+
for (auto idx: nodes.first->seqsIncluded) {
465+
if (idx > 1) {
466+
if (firstSeqID == 0) firstSeqID = -idx;
467+
seqCount++;
468+
}
469+
}
470+
if (seqCount >= _UPDATE_SEQ_TH) {
471+
database->subtreeAln[firstSeqID] = alnPath(totalLen, 0);
472+
std::vector<int> new_seqsIncluded;
473+
new_seqsIncluded.push_back(firstSeqID);
474+
for (auto idx: nodes.first->seqsIncluded) {
475+
if (idx >= 0) database->id_map[idx]->subtreeIdx = firstSeqID;
476+
else new_seqsIncluded.push_back(idx);
477+
}
478+
nodes.first->seqsIncluded = new_seqsIncluded;
479+
}
480+
}
481+
463482
return;
464483
}
465484

src/msa.hpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -175,6 +175,7 @@ namespace msa
175175
namespace alignment_helper
176176
{
177177
constexpr int _CAL_PROFILE_TH = 1000;
178+
constexpr int _UPDATE_SEQ_TH = 1000;
178179

179180
void calculateProfile(float *profile, NodePair &nodes, SequenceDB *database, Option *option, int32_t memLen);
180181
void removeGappyColumns(float *hostFreq, NodePair &nodes, Option *option, std::pair<IntPairVec, IntPairVec> &gappyColumns, int32_t memLen, IntPair &lens, int currentTask);
@@ -199,6 +200,7 @@ namespace msa
199200
void updateNode(Tree *tree, NodePairVec &nodes, SequenceDB *database);
200201
void progressiveAlignment(Tree *T, SequenceDB *database, Option *option, std::vector<NodePairVec> &alnPairPerLevel, Params &param, alnFunction alignmentKernel);
201202
void msaOnSubtree(Tree *T, SequenceDB *database, Option *option, Params &param, alnFunction alignmentKernel, int subtree=-1);
203+
void updateAlignment(Node* node, SequenceDB *database);
202204

203205
namespace cpu
204206
{

src/phylogeny.hpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,8 @@ namespace phylogeny {
4343
int alnNum = {0};
4444
float alnWeight = {0};
4545
int getAlnNum(int currentTask) {
46-
return (currentTask == 2) ? alnNum : seqsIncluded.size();
46+
// return (currentTask == 2) ? alnNum : seqsIncluded.size();
47+
return alnNum;
4748
};
4849
int getAlnLen(int currentTask) {
4950
return alnLen;

src/progressive.cpp

Lines changed: 43 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
#include <chrono>
66
#include <functional>
7+
#include <tbb/parallel_for.h>
78

89

910
void msa::progressive::getProgressivePairs(std::vector<std::pair<NodePair,int>>& alnOrder, std::stack<Node*> postStack, int grpID, int mode) {
@@ -190,6 +191,43 @@ void msa::progressive::progressiveAlignment(Tree *T, SequenceDB *database, Optio
190191
}
191192
}
192193

194+
/**
 * Apply the stored subtree alignment paths to the sequences that reference
 * them, then rebuild the sequence list of `node`.
 *
 * For every sequence in `database->sequences` whose `subtreeIdx` is below -1,
 * the path stored in `database->subtreeAln` under that key is expanded into
 * the sequence's inactive `alnStorage` buffer: a path entry equal to 0 takes
 * the next character from the currently active buffer, any other entry
 * writes a gap ('-'). The sequence's `len` is then set to the path length and
 * `changeStorage()` is called on it.
 *
 * Afterwards `node->seqsIncluded` is replaced by its non-negative entries
 * followed by the id of every sequence whose `subtreeIdx` is negative.
 *
 * @param node      Node whose `seqsIncluded` list is rebuilt.
 * @param database  Sequence database holding `sequences`, `id_map` and
 *                  `subtreeAln`.
 */
void msa::progressive::updateAlignment(Node* node, SequenceDB *database) {
    tbb::this_task_arena::isolate([&] {
        tbb::parallel_for(tbb::blocked_range<int>(0, database->sequences.size()), [&](tbb::blocked_range<int> range) {
            for (int idx = range.begin(); idx < range.end(); ++idx) {
                auto seq = database->sequences[idx];
                // Only sequences tagged with a subtree key below -1 have a path to apply.
                if (seq->subtreeIdx >= -1) continue;

                // Look the path up without inserting: operator[] would create a
                // missing entry while other tasks read the same container, and
                // taking the result by value copied the whole path per sequence.
                auto pathIt = database->subtreeAln.find(seq->subtreeIdx);
                // NOTE(review): a missing key is skipped here; previously it
                // produced an empty path (sequence truncated to length 0) —
                // confirm that a missing key can never legitimately occur.
                if (pathIt == database->subtreeAln.end()) continue;
                const auto &aln = pathIt->second;
                const int alnLen = static_cast<int>(aln.size());

                auto sIdx = seq->id;
                // Resolve the id_map entry once instead of on every access.
                auto &&target = database->id_map[sIdx];
                target->memCheck(alnLen);
                int orgIdx = 0;
                const int storeFrom = target->storage;
                const int storeTo = 1 - storeFrom;
                for (int k = 0; k < alnLen; ++k) {
                    if (aln[k] == 0) {
                        // Path entry 0: consume the next character of the current buffer.
                        target->alnStorage[storeTo][k] = target->alnStorage[storeFrom][orgIdx];
                        orgIdx++;
                    }
                    else {
                        target->alnStorage[storeTo][k] = '-';
                    }
                }
                target->len = alnLen;
                target->changeStorage();
            }
        });
    });

    std::vector<int> new_seqsIncludes;
    // Keep the entries of the node that are already plain sequence ids.
    for (auto sIdx: node->seqsIncluded) {
        if (sIdx >= 0) new_seqsIncludes.push_back(sIdx);
    }
    // NOTE(review): this test is `< 0` while the loop above uses `< -1`; if -1
    // is the "no subtree" default, such sequences are appended here as well —
    // confirm the intended threshold.
    for (const auto &seq: database->sequences) {
        if (seq->subtreeIdx < 0) new_seqsIncludes.push_back(seq->id);
    }
    node->seqsIncluded = new_seqsIncludes;
    return;
}
193231

194232
void msa::progressive::msaOnSubtree(Tree *T, SequenceDB *database, Option *option, Params &param, alnFunction alignmentKernel, int subtree) {
195233
auto progressiveStart = std::chrono::high_resolution_clock::now();
@@ -228,8 +266,11 @@ void msa::progressive::msaOnSubtree(Tree *T, SequenceDB *database, Option *optio
228266
if (database->currentTask != 2) std::cerr << "Alignment (length: " << T->root->alnLen << ") completed in " << progressiveTime.count() / 1000000000 << " s\n";
229267
else std::cerr<< "Alignment on " << T->allNodes.size() << " subalignments (length: " << T->root->getAlnLen(database->currentTask) << ") in " << progressiveTime.count() / 1000000 << " ms\n";
230268
}
231-
if (database->fallback_nodes.empty())
269+
if (database->fallback_nodes.empty()) {
270+
if (option->alnMode == DEFAULT_ALN || option->alnMode == PLACE_W_TREE) updateAlignment(T->root, database);
232271
return;
272+
}
273+
233274

234275
// Adding bad sequences back
235276
auto badStart = std::chrono::high_resolution_clock::now();
@@ -248,6 +289,7 @@ void msa::progressive::msaOnSubtree(Tree *T, SequenceDB *database, Option *optio
248289
std::cerr << "Realign profiles that have been deferred. Total profiles/sequences: " << database->fallback_nodes.size() << " / " << badSeqBefore << '\n';
249290
database->fallback_nodes.clear();
250291
progressiveAlignment(T, database, option, alnPairsPerLevel,param, cpu::alignmentKernel_CPU);
292+
if (option->alnMode == DEFAULT_ALN || option->alnMode == PLACE_W_TREE) updateAlignment(T->root, database);
251293
// Reset currentTask
252294
database->currentTask = 0;
253295
auto badEnd = std::chrono::high_resolution_clock::now();

0 commit comments

Comments
 (0)