Skip to content

Commit d6d21b1

Browse files
committed
Add force_memmap support for large mutation files to increase speed
1 parent af1f9e6 commit d6d21b1

File tree

3 files changed

+178
-49
lines changed

3 files changed

+178
-49
lines changed

src/PrecisionProDB_Sqlite.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
file_sqlite = '/XCLabServer002_fastIO/examples/GENCODE/GENCODE.tsv.sqlite'
4040

4141

42-
def runSinglePerChromSqlite(file_sqlite, file_mutations, tempfolder, threads, chromosome, datatype, individual):
42+
def runSinglePerChromSqlite(file_sqlite, file_mutations, tempfolder, threads, chromosome, datatype, individual, force_memmap=False):
4343
'''
4444
run PerChrom_sqlite for a single chromosome
4545
'''
@@ -51,7 +51,8 @@ def runSinglePerChromSqlite(file_sqlite, file_mutations, tempfolder, threads, ch
5151
outprefix = outprefix,
5252
datatype = datatype,
5353
chromosome = chromosome,
54-
individual = individual)
54+
individual = individual,
55+
force_memmap = force_memmap)
5556
print('run perchrom_sqlite for chromosome', chromosome)
5657
# perchrom_sqlite.run_perChrom()
5758
# print('finished running perchrom_sqlite for chromosome:', chromosome)
@@ -65,7 +66,7 @@ def runSinglePerChromSqlite(file_sqlite, file_mutations, tempfolder, threads, ch
6566
return None
6667

6768

68-
def runPerChomSqlite(file_sqlite, file_mutations, threads, outprefix, protein_keyword, datatype, keep_all, individual, chromosomes_genome, chromosomes_genome_description, file_gtf):
69+
def runPerChomSqlite(file_sqlite, file_mutations, threads, outprefix, protein_keyword, datatype, keep_all, individual, chromosomes_genome, chromosomes_genome_description, file_gtf, force_memmap=False):
6970
'''
7071
'''
7172
# split file_mutations by chromosome
@@ -83,27 +84,43 @@ def runPerChomSqlite(file_sqlite, file_mutations, threads, outprefix, protein_ke
8384

8485
if individual == 'ALL_VARIANTS':
8586
individual = ''
87+
# only force memmap when sample information is retained
88+
use_force_memmap = force_memmap and individual not in ['', None]
8689
tempfolder = pergeno.tempfolder
8790
pergeno.chromosomes_genome = chromosomes_genome
88-
chromosomes_mutation = pergeno.splitMutationByChromosomeLarge(chromosomes_genome_description=chromosomes_genome_description, chromosomes_genome=chromosomes_genome)
91+
individual_columns = None
92+
if isinstance(individual, str) and individual not in ['', 'None', 'ALL_SAMPLES']:
93+
individual_columns = [i.strip() for i in individual.split(',') if i.strip()]
94+
elif isinstance(individual, (list, tuple)):
95+
individual_columns = [i for i in individual if i]
96+
chromosomes_mutation = pergeno.splitMutationByChromosomeLarge(
97+
chromosomes_genome_description=chromosomes_genome_description,
98+
chromosomes_genome=chromosomes_genome,
99+
individual_columns=individual_columns,
100+
enable_memmap=use_force_memmap
101+
)
89102

90103
# run tsv2memmap here with multiple threads to save time
91104
files_mutation_to_convert = [f'{tempfolder}/{i}.mutation.tsv' for i in chromosomes_mutation]
92105
files_mutation_to_convert = [i for i in files_mutation_to_convert if os.path.getsize(i) > 100000000]
93-
if len(files_mutation_to_convert) > 0 and threads > 1 and individual !='':
94-
if individual == 'ALL_SAMPLES':
95-
columns_in_file_mutation = openFile(files_mutation_to_convert[0], 'r').readline().strip().split('\t')
96-
individual_for_memmap = [i for i in columns_in_file_mutation if i not in ['chr', 'pos', '', 'ref', 'alt', 'pos_end']]
106+
if len(files_mutation_to_convert) > 0 and threads > 1 and individual !='' and not use_force_memmap:
107+
files_mutation_to_convert = [i for i in files_mutation_to_convert if not os.path.exists(i + '.memmap')]
108+
if len(files_mutation_to_convert) == 0:
109+
print('memmap files already present for large chromosomes, skip regeneration.')
97110
else:
98-
individual_for_memmap = individual
99-
from vcf2mutation import tsv2memmap
100-
pool = Pool(threads)
101-
pool.starmap(tsv2memmap, [(i, individual_for_memmap, i +'.memmap') for i in files_mutation_to_convert], chunksize=1)
102-
pool.close()
103-
pool.join()
111+
if individual == 'ALL_SAMPLES':
112+
columns_in_file_mutation = openFile(files_mutation_to_convert[0], 'r').readline().strip().split('\t')
113+
individual_for_memmap = [i for i in columns_in_file_mutation if i not in ['chr', 'pos', '', 'ref', 'alt', 'pos_end']]
114+
else:
115+
individual_for_memmap = individual
116+
from vcf2mutation import tsv2memmap
117+
pool = Pool(threads)
118+
pool.starmap(tsv2memmap, [(i, individual_for_memmap, i +'.memmap') for i in files_mutation_to_convert], chunksize=1)
119+
pool.close()
120+
pool.join()
104121

105122
# run runSinglePerChromSqlite
106-
chromosomes_mutated = [runSinglePerChromSqlite(file_sqlite, f'{tempfolder}/{chromosome}.mutation.tsv', tempfolder, threads, chromosome, datatype, individual) for chromosome in chromosomes_mutation]
123+
chromosomes_mutated = [runSinglePerChromSqlite(file_sqlite, f'{tempfolder}/{chromosome}.mutation.tsv', tempfolder, threads, chromosome, datatype, individual, force_memmap=use_force_memmap) for chromosome in chromosomes_mutation]
107124
# successful chromosomes
108125
chromosomes_mutated = [e for e in chromosomes_mutated if e is not None]
109126
# collect mutation annotations
@@ -115,8 +132,22 @@ def runPerChomSqlite(file_sqlite, file_mutations, threads, outprefix, protein_ke
115132

116133
# collect individual information for later use when adding unchanged proteins per individual
117134
all_individuals = []
135+
base_columns = {'chr', 'pos', 'ref', 'alt', 'pos_end'}
118136
if isinstance(individual, str):
119-
if individual not in ['', 'None']:
137+
if individual == 'ALL_SAMPLES':
138+
# Derive the actual sample columns from one of the split mutation files so
139+
# unchanged proteins can still be emitted when some individuals remain ref-only.
140+
sample_header_file = None
141+
for chromosome in chromosomes_mutation:
142+
candidate = f'{tempfolder}/{chromosome}.mutation.tsv'
143+
if os.path.exists(candidate):
144+
sample_header_file = candidate
145+
break
146+
if sample_header_file:
147+
with openFile(sample_header_file, 'r') as fo:
148+
header_columns = fo.readline().strip().split('\t')
149+
all_individuals = [col for col in header_columns if col not in base_columns]
150+
elif individual not in ['', 'None']:
120151
all_individuals = [i.strip() for i in individual.split(',') if i.strip()]
121152
elif isinstance(individual, (list, tuple)):
122153
all_individuals = [i for i in individual if i]
@@ -190,6 +221,7 @@ def runPerChomSqlite_vcf(file_mutations, file_sqlite, threads, outprefix, dataty
190221
# get two mutation files from vcf file
191222
print('start extracting mutation file from the vcf input')
192223
outprefix_vcf = outprefix + '.vcf2mutation'
224+
force_memmap = (individual == 'ALL_SAMPLES')
193225
individual = convertVCF2MutationComplex(file_vcf = file_mutations, outprefix = outprefix_vcf, individual_input=individual, filter_PASS = filter_PASS, chromosome_only = chromosome_only, info_field = info_field, info_field_thres=info_field_thres, threads = threads)
194226
individual = ','.join(individual)
195227
print('finished extracting mutations from the vcf file')
@@ -208,7 +240,8 @@ def runPerChomSqlite_vcf(file_mutations, file_sqlite, threads, outprefix, dataty
208240
individual,
209241
chromosomes_genome,
210242
chromosomes_genome_description,
211-
file_gtf
243+
file_gtf,
244+
force_memmap=force_memmap
212245
)
213246

214247

@@ -368,7 +401,8 @@ def main_PrecsionProDB_Sqlite(file_genome, file_gtf, file_mutations, file_protei
368401
individual=individual,
369402
chromosomes_genome=chromosomes_genome,
370403
chromosomes_genome_description=chromosomes_genome_description,
371-
file_gtf=file_gtf
404+
file_gtf=file_gtf,
405+
force_memmap=(individual == 'ALL_SAMPLES')
372406
)
373407

374408

src/PrecisionProDB_core.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from perChrom import PerChrom
88
import shutil
99
import re
10+
from array import array
1011

1112
def get_k_new(k, chromosomes_genome, chromosomes_genome_description):
1213
'''k is chromosome name. return chromosome name based on chromosomes_genome, chromosomes_genome_description
@@ -391,9 +392,12 @@ def splitMutationByChromosome(self, chromosomes_genome_description=None, chromos
391392
print('finish splitting the mutation file')
392393
return chromosomes_mutation
393394

394-
def splitMutationByChromosomeLarge(self, chromosomes_genome_description=None, chromosomes_genome=None):
395+
def splitMutationByChromosomeLarge(self, chromosomes_genome_description=None, chromosomes_genome=None, individual_columns=None, enable_memmap=False):
395396
'''split mutation file based on chromosomes
396397
file_mutations is generated from vcf file, no need to read with pandas or further processing
398+
individual_columns is an optional list of sample columns that should be written
399+
to memmap files while the TSV is being split. enable_memmap toggles this streamed
400+
writer to avoid re-reading the large TSV with pandas later on.
397401
'''
398402
tempfolder = self.tempfolder
399403
file_splitMutationByChromosomeLarge_done = os.path.join(tempfolder,'splitMutationByChromosomeLarge.done')
@@ -409,10 +413,25 @@ def splitMutationByChromosomeLarge(self, chromosomes_genome_description=None, ch
409413
chromosomes_genome_description = self.chromosomes_genome_description
410414

411415
dc_output = {}
416+
memmap_writers = {}
417+
sample_column_indices = []
412418
fo = openFile(file_mutations)
413419
for line in fo:
414420
break
415421
header = line
422+
header_columns = header.strip().split('\t')
423+
if enable_memmap:
424+
if individual_columns:
425+
name_to_index = {name: idx for idx, name in enumerate(header_columns)}
426+
sample_column_indices = [name_to_index[name] for name in individual_columns if name in name_to_index]
427+
if len(sample_column_indices) != len(individual_columns):
428+
missing = set(individual_columns) - set(name_to_index)
429+
if missing:
430+
print('warning: columns not found for memmap writing:', ','.join(missing))
431+
else:
432+
base_columns = {'chr', 'pos', 'ref', 'alt', 'pos_end'}
433+
# fallback: treat every column beyond the required fields as sample genotype
434+
sample_column_indices = [idx for idx, name in enumerate(header_columns) if name not in base_columns]
416435
for line in fo:
417436
k = line.split('\t', maxsplit=1)[0]
418437
k_new = get_k_new(k, chromosomes_genome, chromosomes_genome_description)
@@ -421,15 +440,65 @@ def splitMutationByChromosomeLarge(self, chromosomes_genome_description=None, ch
421440
dc_output[k_new] = open(tf,'w')
422441
dc_output[k_new].write(header)
423442
dc_output[k_new].write(line)
443+
if enable_memmap and sample_column_indices:
444+
if k_new not in memmap_writers:
445+
memmap_filename = tf + '.memmap'
446+
memmap_writers[k_new] = _ChromosomeMemmapWriter(memmap_filename, len(sample_column_indices))
447+
line_values = line.strip().split('\t')
448+
sample_values = [line_values[idx] if idx < len(line_values) else '0' for idx in sample_column_indices]
449+
memmap_writers[k_new].add_row(sample_values)
424450

451+
fo.close()
425452
for k_new in dc_output:
426453
dc_output[k_new].close()
454+
for writer in memmap_writers.values():
455+
writer.finalize()
427456
chromosomes_mutation = list(dc_output.keys())
428457

429458
print('finish splitting the mutation file')
430459
open(file_splitMutationByChromosomeLarge_done,'w').write('\n'.join(chromosomes_mutation))
431460
return chromosomes_mutation
432461

462+
class _ChromosomeMemmapWriter:
463+
"""Stream sample matrices into byte-aligned files for later memmap usage."""
464+
465+
def __init__(self, filename, n_cols):
466+
"""
467+
Args:
468+
filename (str): Destination path for the raw binary matrix.
469+
n_cols (int): Number of sample columns stored per row.
470+
"""
471+
self.filename = filename
472+
self.n_cols = n_cols
473+
self.handle = open(filename, 'wb')
474+
self.rows = 0
475+
476+
def add_row(self, values):
477+
"""
478+
Append a single row of sample indicators to the binary file.
479+
480+
Args:
481+
values (Iterable[str]): Raw string values from the TSV columns.
482+
"""
483+
row_array = array(
484+
'b',
485+
(1 if value not in ('', '0', '0.0', '.', 'False') else 0 for value in values)
486+
)
487+
if len(row_array) != self.n_cols:
488+
raise ValueError(f'mismatched memmap width: expected {self.n_cols}, got {len(row_array)}')
489+
row_array.tofile(self.handle)
490+
self.rows += 1
491+
492+
def finalize(self):
493+
"""Close file handle and create the companion .done flag."""
494+
self.handle.close()
495+
open(self.filename + '.done', 'w').close()
496+
497+
def __del__(self):
498+
"""Ensure file handle closes if finalize is not called explicitly."""
499+
if not self.handle.closed:
500+
self.handle.close()
501+
433502
def splitGtfByChromosomes(self,dc_protein2chr):
434503
'''split gtf file based on chromosome. only keep proteins in file_protein
435504
'''
@@ -740,4 +809,4 @@ def main():
740809
pergeno.runPerChom()
741810

742811
if __name__ == '__main__':
743-
main()
812+
main()

src/perChromSqlite.py

Lines changed: 55 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
except:
1818
print('Cannot import tqdm. Will not use tqdm.')
1919
tqdm = False
20+
21+
_MEMMAP_CACHE = {}
2022
# # Global variables
2123
# con = None
2224
# df_mutations = None
@@ -88,42 +90,54 @@ def get_protein_id_from_df_mutations(df_mutations, file_sqlite, cpu_counts=10):
8890

8991
return combined_results
9092

93+
def _load_memmap(file_memmap, shape):
94+
"""Cache numpy.memmap instances per-process to avoid reopening the file repeatedly."""
95+
global _MEMMAP_CACHE
96+
cache_key = (file_memmap, shape)
97+
if cache_key not in _MEMMAP_CACHE:
98+
_MEMMAP_CACHE[cache_key] = np.memmap(file_memmap, dtype='int8', mode='r', shape=shape)
99+
return _MEMMAP_CACHE[cache_key]
100+
101+
91102
def convert_df_transcript2_to_df_transcript3_helper(protein_id, df_transcript2, df_mutations, individual, kargs):
92103
'''
93-
protein_id is a protein_id in df_transcript2
94-
for each protein_id, get the mutations in each individual.
95-
return a list of tuple, each tuple is (tuple of variant_index, individuals with this variant_index joined by ',')
96-
97-
If shape and file_memmap are provided in kargs, it means we're dealing with an extra large mutation file
98-
where individual data is stored in a memory-mapped file instead of in the DataFrame.
104+
Convert mutation membership for a single protein into grouped variant patterns.
105+
106+
The function inspects all variants linked to the provided protein and determines,
107+
for every individual, which subset of those variants is carried. Individuals sharing
108+
the same variant combination are collapsed into a single entry so downstream
109+
translation only needs to be executed once per unique pattern. When a memory-mapped
110+
allele matrix is available, the lookups are vectorized to avoid per-sample Python
111+
loops.
99112
'''
100113
mutations = df_transcript2.loc[protein_id]['mutations']
101114

102115
# Check if we're using memory-mapped file for individual data
103116
if 'shape' in kargs and 'file_memmap' in kargs:
104-
# Using memory-mapped file for individual data
105117
shape = kargs['shape']
106118
file_memmap = kargs['file_memmap']
107-
108-
# Open the memory-mapped file in read mode
109-
mmap = np.memmap(file_memmap, dtype='int8', mode='r', shape=shape)
110-
111-
# Get the indices of mutations for this protein
112-
mutation_indices = mutations
113-
119+
mmap = _load_memmap(file_memmap, shape)
120+
mutation_indices = np.array(mutations, dtype=int)
121+
if mutation_indices.size == 0:
122+
return []
123+
allele_block = mmap[mutation_indices, :]
124+
allele_block = allele_block if allele_block.ndim == 2 else allele_block.reshape(1, -1)
125+
# Locate every (variant, sample) pair where the allele is present.
126+
presence_coords = np.argwhere(allele_block == 1)
127+
if presence_coords.size == 0:
128+
return []
129+
sample_to_variants = {}
130+
for variant_pos, sample_idx in presence_coords:
131+
variant_idx = mutation_indices[variant_pos]
132+
sample_to_variants.setdefault(sample_idx, []).append(variant_idx)
114133
tdc = {}
115-
for i, sample in enumerate(individual):
116-
# For each sample, check which mutations are valid (value is 1)
117-
valid_mutations = []
118-
for idx in mutation_indices:
119-
if idx < mmap.shape[0] and i < mmap.shape[1] and mmap[idx, i] == 1:
120-
valid_mutations.append(idx)
121-
122-
variant_index = tuple(valid_mutations)
123-
if len(variant_index) > 0:
124-
if variant_index not in tdc:
125-
tdc[variant_index] = []
126-
tdc[variant_index].append(sample)
134+
for sample_idx, variant_list in sample_to_variants.items():
135+
if not variant_list:
136+
continue
137+
variant_index = tuple(sorted(variant_list))
138+
if len(variant_index) == 0:
139+
continue
140+
tdc.setdefault(variant_index, []).append(individual[sample_idx])
127141
else:
128142
# Using regular DataFrame for individual data
129143
tdf_m = df_mutations.loc[mutations]
@@ -236,7 +250,8 @@ def __init__(
236250
outprefix,
237251
datatype,
238252
chromosome,
239-
individual = None
253+
individual = None,
254+
force_memmap = False
240255
):
241256
self.file_sqlite = file_sqlite # genome file location
242257
self.file_mutations = file_mutations # mutation file location
@@ -245,9 +260,12 @@ def __init__(
245260
self.datatype = datatype # input datatype, could be GENCODE_GTF, GENCODE_GFF3, RefSeq or gtf
246261
self.chromosome = chromosome # chromosome name
247262
self.extra_large_file_mutation = False # whether the mutation file is larger than 1G
263+
self.force_memmap = force_memmap
248264

249265
# if file_mutation is larger than 1G, only read ['chr', 'pos', 'ref', 'alt']
266+
file_memmap_path = None
250267
if isinstance(self.file_mutations, str):
268+
file_memmap_path = self.file_mutations + '.memmap'
251269
if os.path.exists(self.file_mutations):
252270
if os.path.getsize(self.file_mutations) > 100000000:
253271
self.extra_large_file_mutation = True
@@ -289,11 +307,19 @@ def __init__(
289307

290308
self.individual = individual
291309

310+
# Determine whether we already built memmap data for this file or if it should be forced.
311+
memmap_exists = bool(file_memmap_path and os.path.exists(file_memmap_path))
312+
if not self.individual:
313+
self.extra_large_file_mutation = False
314+
else:
315+
self.extra_large_file_mutation = bool(self.individual) and (self.extra_large_file_mutation or memmap_exists or self.force_memmap)
316+
292317
if self.extra_large_file_mutation:
293318
self.df_mutations = perChrom.parse_mutation(file_mutations, columns_to_include=['chr', 'pos', 'ref', 'alt'])
294319
from vcf2mutation import tsv2memmap
295320
shape = (self.df_mutations.shape[0], len(self.individual))
296-
tsv2memmap(file_mutations, individuals = self.individual, memmap_file=file_mutations + '.memmap')
321+
if not memmap_exists:
322+
tsv2memmap(file_mutations, individuals = self.individual, memmap_file=file_mutations + '.memmap')
297323
self.shape = shape
298324
self.file_memmap = file_mutations + '.memmap'
299325

@@ -445,4 +471,4 @@ def main():
445471
perchrom_sqlite.run_perChrom()
446472

447473
if __name__ == '__main__':
448-
main()
474+
main()

0 commit comments

Comments
 (0)