Skip to content

Commit 36f72b7

Browse files
y3tsengy3tseng
authored andcommitted
improve speed for masking gappy columns
1 parent f466e50 commit 36f72b7

File tree

6 files changed

+132
-56
lines changed

6 files changed

+132
-56
lines changed

install/installIterative.sh

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@ conda config --add channels bioconda
77
conda config --add channels conda-forge
88
conda config --set channel_priority strict
99

10-
conda install snakemake -y
11-
conda install ete3 -y
12-
conda install numpy -y
10+
conda install -y snakemake ete3 numpy numba
11+
1312

1413
# Get system architecture
1514
ARCH=$(uname -m)

workflow/config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ gapextend: -5
3434
matrix: ""
3535

3636
# Iterative Mode
37-
mask_gappy: 0.995 # Minimum proportion of gappy sites that would be mask before proceed to tree inference step
37+
mask_gappy: 0.95 # Minimum proportion of gappy sites that will be masked before proceeding to the tree inference step
3838

3939

4040
# IQ-TREE Model Selection
4141
# Automatically determine the best model without specifying one, though this may take some time.
4242
# If you would like to speed up the process, please specify a model (e.g., "-m GTR").
4343
# See "https://iqtree.github.io/doc/Substitution-Models" for more details on available models.
44-
iqtree_model: ""
44+
iqtree_model: ""

workflow/rules/fasttree.smk

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ rule fasttree:
1111
threads: config["num_threads"]
1212
shell:
1313
'''
14-
python3 scripts/reduceLen.py {input.msa} {params.tempFile} {params.threshold}
14+
python3 scripts/reduceLen.py --threads {threads} {input.msa} {params.tempFile} {params.threshold}
1515
export OMP_NUM_THREADS={threads}
1616
{params.fasttree_exe} {params.model} -fastest {params.tempFile} > {params.tempTree}
1717
python3 scripts/resolveTree.py {params.tempTree} {output.tree}

workflow/rules/iqtree.smk

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ rule iqtree:
1111
threads: config["num_threads"]
1212
shell:
1313
'''
14-
python3 scripts/reduceLen.py {input.msa} {params.tempFile} {params.threshold}
14+
python3 scripts/reduceLen.py --threads {threads} {input.msa} {params.tempFile} {params.threshold}
1515
{params.iqtree_exe} -s {params.tempFile} {params.model} --threads-max {threads}
1616
mv {params.temp}/msa.mask.fa.treefile {output}
1717
rm {params.temp}/msa.mask.fa.*

workflow/rules/raxml.smk

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ rule raxml:
1010
threads: config["num_threads"]
1111
shell:
1212
'''
13-
python3 scripts/reduceLen.py {input.msa} {params.tempFile} {params.threshold}
13+
python3 scripts/reduceLen.py --threads {threads} {input.msa} {params.tempFile} {params.threshold}
1414
{params.raxml_exe} -s {params.tempFile} -m {params.model} -n raxml.tree -T {threads} -p 235813
1515
mv RAxML_bestTree.raxml.tree {output}
1616
rm *.raxml.tree

workflow/scripts/reduceLen.py

Lines changed: 125 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,127 @@
1-
from argparse import ArgumentParser
2-
import numpy as np
1+
import argparse
32
import time
3+
import sys
4+
import gzip
5+
import numpy as np
6+
from numba import jit, prange, set_num_threads
7+
8+
# --- Helper: Efficient FASTA Reader (Supports .gz) ---
def read_fasta_to_numpy(filename):
    """
    Read a FASTA (optionally gzip-compressed) file into a 2D NumPy array.

    Parameters
    ----------
    filename : str
        Path to the alignment. A '.gz' suffix triggers transparent
        decompression.

    Returns
    -------
    (headers, seq_matrix) : (list[str], numpy.ndarray)
        The '>' header lines, and an (n_seqs, n_cols) array of dtype 'S1'
        (1-byte strings, for C-like per-character access).

    Raises
    ------
    ValueError
        If no sequences are found, sequence data precedes the first
        header, the header/sequence counts disagree, or the sequences
        differ in length.
    """
    headers = []
    sequences = []
    current_seq = []

    # 'rt' opens gzip files as text, handling decompression transparently;
    # for plain files it is equivalent to 'r'.
    opener = gzip.open if filename.endswith('.gz') else open
    with opener(filename, 'rt') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith('>'):
                if current_seq:
                    sequences.append("".join(current_seq))
                    current_seq = []
                headers.append(line)
            else:
                if not headers:
                    # Otherwise headers/sequences silently go out of sync.
                    raise ValueError("Sequence data found before the first FASTA header.")
                current_seq.append(line)
        # Flush the final record (no trailing '>' to trigger the append).
        if current_seq:
            sequences.append("".join(current_seq))

    if not sequences:
        raise ValueError("No sequences found in input file.")
    if len(headers) != len(sequences):
        raise ValueError("Number of headers does not match number of sequences.")

    # Validate lengths explicitly: np.array on ragged input raises an
    # opaque error (or builds an object array on older NumPy versions),
    # so fail early with a clear message instead.
    if len({len(s) for s in sequences}) != 1:
        raise ValueError("Sequences must all be the same length for this operation.")

    # 'S1' (1-byte strings) gives per-character performance equivalent to
    # a C char matrix.
    seq_matrix = np.array([list(s) for s in sequences], dtype='S1')
    return headers, seq_matrix
53+
54+
# --- Core Logic: Parallel Filter (Numba) ---
# nopython=True: Compile to machine code (no Python interpreter slow-down)
# parallel=True: Enable automatic parallelization (OpenMP/TBB backend)
@jit(nopython=True, parallel=True)
def get_column_mask(seq_matrix, threshold):
    """
    Flag alignment columns whose gap proportion is within the threshold.

    seq_matrix is an (n_seqs, n_cols) array of 1-byte characters (dtype
    'S1', as produced by read_fasta_to_numpy); threshold is the maximum
    allowed gap fraction per column. Returns a boolean array of length
    n_cols where True marks a column to KEEP.
    """
    n_seqs, n_cols = seq_matrix.shape
    # floor(threshold * n_seqs) combined with '<=' below reproduces the
    # old per-column test (gap_count / n_seqs <= threshold) exactly,
    # since gap_count is an integer.
    max_allowed_gaps = int(np.floor(threshold * n_seqs))

    keep_mask = np.zeros(n_cols, dtype=np.bool_)

    # prange: columns are independent, so Numba parallelizes this loop
    # across threads.
    for col_idx in prange(n_cols):
        gap_count = 0
        for seq_idx in range(n_seqs):
            # b'-' is the byte representation of a dash
            if seq_matrix[seq_idx, col_idx] == b'-':
                gap_count += 1

        if gap_count <= max_allowed_gaps:
            keep_mask[col_idx] = True

    return keep_mask
75+
76+
# --- Main Execution ---
def main():
    """
    CLI entry point: read an alignment, drop gappy columns, write FASTA.

    Usage: reduceLen.py [--threads N] INALN OUTALN THRESHOLD
    Exits with status 1 (message on stderr) on any failure.
    """
    parser = argparse.ArgumentParser(description='Reduce alignment length to speedup tree inference process')
    parser.add_argument('inaln', help='Input alignment (FASTA or .gz)')
    parser.add_argument('outaln', help='Output alignment (Uncompressed FASTA)')
    parser.add_argument('threshold', type=float, help='Minimum gap proportion for a column to be removed')
    parser.add_argument('--threads', type=int, default=1, help='Number of threads to use')
    args = parser.parse_args()

    # 1. Configure Threads. Always pin Numba: without this, a request for
    #    --threads 1 would still let Numba use every core (its default is
    #    all available cores).
    set_num_threads(max(1, args.threads))
    if args.threads > 1:
        print(f"Using {args.threads} threads.")

    try:
        print(f"Reading alignment from {args.inaln}...")
        headers, seq_matrix = read_fasta_to_numpy(args.inaln)

        num_seqs, num_cols = seq_matrix.shape
        print(f"Original dimensions: {num_seqs} sequences, {num_cols} columns")

        # Time only the masking work (excludes file I/O).
        start_time = time.perf_counter()

        # 2. Parallel Analysis: boolean mask of columns to keep
        keep_mask = get_column_mask(seq_matrix, args.threshold)

        # 3. Filtering (Slicing)
        filtered_matrix = seq_matrix[:, keep_mask]

        end_time = time.perf_counter()
        elapsed_ms = (end_time - start_time) * 1000

        new_cols = filtered_matrix.shape[1]
        print(f"Original length: {num_cols}, length after removing gappy columns: {new_cols}")
        print(f"Remove gappy columns in {elapsed_ms:.2f} ms")

        # 4. Write Output (Uncompressed)
        print(f"Writing output to {args.outaln}...")
        with open(args.outaln, 'w') as f:
            for i, header in enumerate(headers):
                # Each row is 'S1' bytes; tobytes() copies it into one
                # contiguous line regardless of the slice's strides.
                seq_str = filtered_matrix[i].tobytes().decode('utf-8')
                f.write(f"{header}\n{seq_str}\n")

        print("Done.")

    except Exception as e:
        # Top-level CLI boundary: report and exit non-zero rather than
        # dumping a traceback at the workflow user.
        sys.stderr.write(f"Error: {e}\n")
        sys.exit(1)
4125

5-
parser = ArgumentParser(description='Reduce alignment length to speedup tree inference process')
6-
parser.add_argument('inaln', help='Input alignment')
7-
parser.add_argument('outaln', help='Output alignment')
8-
parser.add_argument('threshold', type=float, help='Minimum gap porpotion for a column be removed')
9-
args = parser.parse_args()
10-
11-
st = time.time()
12-
13-
threshold = args.threshold
14-
name = []
15-
aln = []
16-
17-
with open(args.inaln, "r") as alnFile:
18-
inContent = alnFile.read().splitlines()
19-
for c in inContent:
20-
if c[0] == '>':
21-
name.append(c)
22-
else:
23-
aln.append(c)
24-
25-
allAln = np.array([list(a) for a in aln])
26-
lb = len(allAln[0])
27-
allAln = np.transpose(allAln)
28-
stayedRows = []
29-
rowID = 0
30-
for row in allAln:
31-
num_gap = (row == '-').sum()
32-
if num_gap/len(name) <= threshold:
33-
stayedRows.append(rowID)
34-
rowID += 1
35-
newAln = []
36-
for r in stayedRows:
37-
newAln.append(allAln[r])
38-
newAln = np.array(newAln)
39-
newAln = np.transpose(newAln)
40-
la = len(newAln[0])
41-
outFile = []
42-
with open(args.outaln, "w") as outFile:
43-
n = 0
44-
for a in newAln:
45-
outFile.write(name[n]+'\n')
46-
n += 1
47-
outFile.write("".join(a)+'\n')
48-
en = time.time()
49-
50-
print("Masked gappy site. Length before/after: "+str(lb)+"/"+str(la)+". Total time: ", en-st, "seconds.")
126+
# Script entry point; allows the module to be imported without side effects.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)