1- from argparse import ArgumentParser
2- import numpy as np
1+ import argparse
32import time
3+ import sys
4+ import gzip
5+ import numpy as np
6+ from numba import jit , prange , set_num_threads
7+
8+ # --- Helper: Efficient FASTA Reader (Supports .gz) ---
9+ def read_fasta_to_numpy (filename ):
10+ """
11+ Reads a FASTA (or .gz FASTA) file and converts it to a 2D NumPy character array.
12+ """
13+ headers = []
14+ sequences = []
15+ current_seq = []
16+
17+ # Determine if we need to open with gzip or standard open
18+ if filename .endswith ('.gz' ):
19+ # 'rt' mode opens it as text, handling the decompression automatically
20+ f = gzip .open (filename , 'rt' )
21+ else :
22+ f = open (filename , 'r' )
23+
24+ try :
25+ for line in f :
26+ line = line .strip ()
27+ if not line :
28+ continue
29+ if line .startswith ('>' ):
30+ if current_seq :
31+ sequences .append ("" .join (current_seq ))
32+ current_seq = []
33+ headers .append (line )
34+ else :
35+ current_seq .append (line )
36+ # Add the last sequence
37+ if current_seq :
38+ sequences .append ("" .join (current_seq ))
39+ finally :
40+ f .close ()
41+
42+ if not sequences :
43+ raise ValueError ("No sequences found in input file." )
44+
45+ # Convert to NumPy array of characters (S1 = 1-byte string)
46+ try :
47+ # Use 'S1' (bytes) for performance equivalent to C++ char
48+ seq_matrix = np .array ([list (s ) for s in sequences ], dtype = 'S1' )
49+ except ValueError :
50+ raise ValueError ("Sequences must all be the same length for this operation." )
51+
52+ return headers , seq_matrix
53+
54+ # --- Core Logic: Parallel Filter (Numba) ---
55+ # nopython=True: Compile to machine code (no Python interpreter slow-down)
56+ # parallel=True: Enable automatic parallelization (OpenMP/TBB backend)
57+ @jit (nopython = True , parallel = True )
58+ def get_column_mask (seq_matrix , threshold ):
59+ n_seqs , n_cols = seq_matrix .shape
60+ max_allowed_gaps = int (np .floor (threshold * n_seqs ))
61+
62+ keep_mask = np .zeros (n_cols , dtype = np .bool_ )
63+
64+ for col_idx in prange (n_cols ):
65+ gap_count = 0
66+ for seq_idx in range (n_seqs ):
67+ # b'-' is the byte representation of a dash
68+ if seq_matrix [seq_idx , col_idx ] == b'-' :
69+ gap_count += 1
70+
71+ if gap_count <= max_allowed_gaps :
72+ keep_mask [col_idx ] = True
73+
74+ return keep_mask
75+
76+ # --- Main Execution ---
77+ def main ():
78+ parser = argparse .ArgumentParser (description = 'Reduce alignment length to speedup tree inference process' )
79+ parser .add_argument ('inaln' , help = 'Input alignment (FASTA or .gz)' )
80+ parser .add_argument ('outaln' , help = 'Output alignment (Uncompressed FASTA)' )
81+ parser .add_argument ('threshold' , type = float , help = 'Minimum gap proportion for a column be removed' )
82+ parser .add_argument ('--threads' , type = int , default = 1 , help = 'Number of threads to use' )
83+ args = parser .parse_args ()
84+
85+ # 1. Configure Threads
86+ if args .threads > 1 :
87+ set_num_threads (args .threads )
88+ print (f"Using { args .threads } threads." )
89+
90+ try :
91+ print (f"Reading alignment from { args .inaln } ..." )
92+ headers , seq_matrix = read_fasta_to_numpy (args .inaln )
93+
94+ num_seqs , num_cols = seq_matrix .shape
95+ print (f"Original dimensions: { num_seqs } sequences, { num_cols } columns" )
96+
97+ # Start Timing
98+ start_time = time .perf_counter ()
99+
100+ # 2. Parallel Analysis
101+ keep_mask = get_column_mask (seq_matrix , args .threshold )
102+
103+ # 3. Filtering (Slicing)
104+ filtered_matrix = seq_matrix [:, keep_mask ]
105+
106+ end_time = time .perf_counter ()
107+ elapsed_ms = (end_time - start_time ) * 1000
108+
109+ new_cols = filtered_matrix .shape [1 ]
110+ print (f"Original length: { num_cols } , length after removing gappy columns: { new_cols } " )
111+ print (f"Remove gappy columns in { elapsed_ms :.2f} ms" )
112+
113+ # 4. Write Output (Uncompressed)
114+ print (f"Writing output to { args .outaln } ..." )
115+ with open (args .outaln , 'w' ) as f :
116+ for i , header in enumerate (headers ):
117+ seq_str = filtered_matrix [i ].tobytes ().decode ('utf-8' )
118+ f .write (f"{ header } \n { seq_str } \n " )
119+
120+ print ("Done." )
121+
122+ except Exception as e :
123+ sys .stderr .write (f"Error: { e } \n " )
124+ sys .exit (1 )
4125
5- parser = ArgumentParser (description = 'Reduce alignment length to speedup tree inference process' )
6- parser .add_argument ('inaln' , help = 'Input alignment' )
7- parser .add_argument ('outaln' , help = 'Output alignment' )
8- parser .add_argument ('threshold' , type = float , help = 'Minimum gap porpotion for a column be removed' )
9- args = parser .parse_args ()
10-
11- st = time .time ()
12-
13- threshold = args .threshold
14- name = []
15- aln = []
16-
17- with open (args .inaln , "r" ) as alnFile :
18- inContent = alnFile .read ().splitlines ()
19- for c in inContent :
20- if c [0 ] == '>' :
21- name .append (c )
22- else :
23- aln .append (c )
24-
25- allAln = np .array ([list (a ) for a in aln ])
26- lb = len (allAln [0 ])
27- allAln = np .transpose (allAln )
28- stayedRows = []
29- rowID = 0
30- for row in allAln :
31- num_gap = (row == '-' ).sum ()
32- if num_gap / len (name ) <= threshold :
33- stayedRows .append (rowID )
34- rowID += 1
35- newAln = []
36- for r in stayedRows :
37- newAln .append (allAln [r ])
38- newAln = np .array (newAln )
39- newAln = np .transpose (newAln )
40- la = len (newAln [0 ])
41- outFile = []
42- with open (args .outaln , "w" ) as outFile :
43- n = 0
44- for a in newAln :
45- outFile .write (name [n ]+ '\n ' )
46- n += 1
47- outFile .write ("" .join (a )+ '\n ' )
48- en = time .time ()
49-
50- print ("Masked gappy site. Length before/after: " + str (lb )+ "/" + str (la )+ ". Total time: " , en - st , "seconds." )
126+ if __name__ == "__main__" :
127+ main ()
0 commit comments