diff --git a/.github/workflows/Short_runs.yml b/.github/workflows/Short_runs.yml new file mode 100644 index 00000000..b9e23158 --- /dev/null +++ b/.github/workflows/Short_runs.yml @@ -0,0 +1,31 @@ +name: Short runs + +on: + workflow_dispatch: + push: + branches: [ "master", "sc_*" ] + pull_request: + branches: [ "master" ] + +permissions: + contents: read + +jobs: + integration-tests: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.8 + uses: actions/setup-python@v3 + with: + python-version: "3.8" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest + sudo apt-get install -y minimap2 samtools + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Run integration tests + run: | + pytest tests/console_test.py -v diff --git a/.github/workflows/Stereo_toy.yml b/.github/workflows/Stereo_toy.yml new file mode 100644 index 00000000..60f731bf --- /dev/null +++ b/.github/workflows/Stereo_toy.yml @@ -0,0 +1,44 @@ +name: Stereo toy data test + +on: + workflow_dispatch: + schedule: + - cron: '0 2 * * 0,4' + +env: + RUN_NAME: Stereo_toy + LAUNCHER: ${{github.workspace}}/tests/github/run_pipeline.py + CFG_DIR: /abga/work/andreyp/ci_isoquant/data + BIN_PATH: /abga/work/andreyp/ci_isoquant/bin/ + OUTPUT_BASE: /abga/work/andreyp/ci_isoquant/output/${{github.ref_name}}/ + +concurrency: + group: ${{github.workflow}} + cancel-in-progress: false + +jobs: + launch-runner: + runs-on: + labels: [isoquant] + name: 'Running IsoQuant and QC' + + steps: + - name: 'Cleanup' + run: > + set -e && + shopt -s dotglob && + rm -rf * + + - name: 'Checkout' + uses: actions/checkout@v3 + with: + fetch-depth: 1 + + - name: 'Stereo toy data test' + if: always() + shell: bash + env: + STEP_NAME: STEREO.TOY + run: | + export PATH=$PATH:${{env.BIN_PATH}} + python3 ${{env.LAUNCHER}} ${{env.CFG_DIR}}/${{env.STEP_NAME}}.cfg -o ${{env.OUTPUT_BASE}} diff --git a/.gitignore b/.gitignore index f0d3b892..1be3a988 100644 --- a/.gitignore +++ b/.gitignore @@ -104,3 +104,6 @@ venv.bak/ .mypy_cache/ .idea + +# Claude Code documentation (private) +.claude/ diff --git a/VERSION b/VERSION index e06d07af..30291cba 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -3.10.0 \ No newline at end of file +3.10.0 diff --git a/detect_barcodes.py b/detect_barcodes.py new file mode 100755 index 00000000..d3797027 --- /dev/null +++ b/detect_barcodes.py @@ -0,0 +1,507 @@ +#!/usr/bin/env python3 +# +############################################################################ +# Copyright (c) 2023 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details. 
+############################################################################ +import concurrent.futures +import os +import random +import sys +import argparse +import gzip +from traceback import print_exc +import shutil +from concurrent.futures import ProcessPoolExecutor +from collections import defaultdict +import numpy +import pysam +from Bio import SeqIO, Seq, SeqRecord +import logging + +from src.modes import IsoQuantMode +from src.barcode_calling.common import bit_to_str, reverese_complement +from src.barcode_calling.barcode_callers import ( + TenXBarcodeDetector, + DoubleBarcodeDetector, + SharedMemoryStereoBarcodeDetector, + SharedMemoryStereoSplttingBarcodeDetector, + ReadStats, + VisiumHDBarcodeDetector +) + +logger = logging.getLogger('IsoQuant') + + +READ_CHUNK_SIZE = 100000 + +BARCODE_CALLING_MODES = {IsoQuantMode.tenX_v3: TenXBarcodeDetector, + IsoQuantMode.curio: DoubleBarcodeDetector, + IsoQuantMode.stereoseq_nosplit: SharedMemoryStereoBarcodeDetector, + IsoQuantMode.stereoseq: SharedMemoryStereoSplttingBarcodeDetector, + IsoQuantMode.visium_5prime: TenXBarcodeDetector, + IsoQuantMode.visium_hd: VisiumHDBarcodeDetector + } + +BARCODE_FILES_REQUIRED = {IsoQuantMode.tenX_v3: [1], + IsoQuantMode.curio: [1, 2], + IsoQuantMode.stereoseq_nosplit: [1], + IsoQuantMode.stereoseq: [1], + IsoQuantMode.visium_5prime: [1], + IsoQuantMode.visium_hd: [2] + } + + +def stats_file_name(file_name): + return file_name + ".stats" + + +def get_umi_length(isoquant_mode: IsoQuantMode): + if isoquant_mode not in BARCODE_CALLING_MODES: + return 0 + try: + return BARCODE_CALLING_MODES[isoquant_mode].UMI_LEN + except AttributeError: + return 0 + + +class SimpleReadStorage: + def __init__(self): + self.read_ids = [] + self.sequences = [] + + def add(self, read_id, seq): + self.read_ids.append(read_id) + self.sequences.append(seq) + + def clear(self): + self.read_ids.clear() + self.sequences.clear() + + def __len__(self): + return len(self.read_ids) + + def __iter__(self): + for i in range(len(self.read_ids)): + yield self.read_ids[i], self.sequences[i] + + def __getstate__(self): + return self.read_ids, self.sequences + + def __setstate__(self, state): + self.read_ids = state[0] + self.sequences = state[1] + + +class BarcodeCaller: + def __init__(self, output_file_name, barcode_detector, header=False, output_sequences=None): + self.barcode_detector = barcode_detector + self.output_file_name = output_file_name + self.output_file = open(self.output_file_name, "w") + self.output_sequences = output_sequences + self.output_sequences_file = None + self.process_function = self._process_read_normal + if self.output_sequences: + self.output_sequences_file = open(self.output_sequences, "w") + self.process_function = self._process_read_split + if header: + self.output_file.write(barcode_detector.result_type().header() + "\n") + self.read_stat = ReadStats() + + def get_stats(self): + return self.read_stat + + def dump_stats(self, file_name=None): + if not file_name: + file_name = stats_file_name(self.output_file_name) + stat_out = open(file_name, "w") + stat_out.write(str(self.read_stat)) + stat_out.close() + + def close(self): + self.output_file.close() + if self.output_sequences_file: + self.output_sequences_file.close() + + def __del__(self): + if not self.output_file.closed: + self.output_file.close() + if self.output_sequences_file and not self.output_sequences_file.closed: + self.output_sequences_file.close() + + def process(self, input_file): + logger.info("Processing " + input_file) + fname, outer_ext = 
os.path.splitext(os.path.basename(input_file)) + low_ext = outer_ext.lower() + + handle = input_file + if low_ext in ['.gz', '.gzip']: + handle = gzip.open(input_file, "rt") + input_file = fname + fname, outer_ext = os.path.splitext(os.path.basename(input_file)) + low_ext = outer_ext.lower() + + if low_ext in ['.fq', '.fastq']: + self._process_fastx(SeqIO.parse(handle, "fastq")) + elif low_ext in ['.fa', '.fasta']: + self._process_fastx(SeqIO.parse(handle, "fasta")) + elif low_ext in ['.bam', '.sam']: + self._process_bam(pysam.AlignmentFile(input_file, "rb", check_sq=False)) + else: + logger.error("Unknown file format " + input_file) + + logger.info("Finished " + input_file) + + def _process_fastx(self, read_handler): + counter = 0 + for r in read_handler: + if counter % 100 == 0: + sys.stdout.write("Processed %d reads\r" % counter) + counter += 1 + read_id = r.id + seq = str(r.seq) + self.process_function(read_id, seq) + + def _process_bam(self, read_handler): + counter = 0 + for r in read_handler: + if counter % 100 == 0: + sys.stdout.write("Processed %d reads\r" % counter) + counter += 1 + read_id = r.query_name + seq = r.query_sequence + self.process_function(read_id, seq) + + # split read and find multiple barcodes + def _process_read_split(self, read_id, read_sequence): + logger.debug("==== %s ====" % read_id) + barcode_result = self.barcode_detector.find_barcode_umi(read_id, read_sequence) + + seq_records = [] + require_tso = len(barcode_result.detected_patterns) > 1 + strands = set() + for r in barcode_result.detected_patterns: + self.read_stat.add_read(r) + if not r.is_valid(): + self.output_file.write("%s\n" % str(r)) + continue + + read_segment_start = max(0, r.primer - 25, r.polyT - 75) + read_segment_end = len(read_sequence) if r.tso5 == -1 else min(len(read_sequence), r.tso5 + 25) + r.read_id = read_id + ("_%d_%d_%s" % (read_segment_start, read_segment_end, r.strand)) + if r.strand == "+": + new_read_seq = read_sequence[read_segment_start:read_segment_end] + else: + new_read_seq = reverese_complement(read_sequence)[read_segment_start:read_segment_end] + strands.add(r.strand) + self.output_file.write("%s\n" % str(r)) + if self.output_sequences and (not require_tso or r.tso5 != -1): + seq_records.append(SeqRecord.SeqRecord(seq=Seq.Seq(new_read_seq), id=r.read_id, description="")) + + self.read_stat.add_custom_stats("Splits", len(barcode_result.detected_patterns)) + # self.read_stat.add_custom_stats("Splits %d %s" % (len(barcode_result.detected_patterns), "".join(list(sorted(strands)))), 1) + if self.output_sequences_file: + SeqIO.write(seq_records, self.output_sequences_file, "fasta") + + def _process_read_normal(self, read_id, read_sequence): + logger.debug("==== %s ====" % read_id) + if read_sequence is None: return + barcode_result = self.barcode_detector.find_barcode_umi(read_id, read_sequence) + + self.output_file.write("%s\n" % str(barcode_result)) + self.read_stat.add_read(barcode_result) + + def process_chunk(self, read_chunk): + counter = 0 + for read_id, seq in read_chunk: + self.process_function(read_id, seq) + counter += 1 + return counter + + +def fastx_file_chunk_reader(handler): + current_chunk = SimpleReadStorage() + for r in handler: + current_chunk.add(r.id, str(r.seq)) + if len(current_chunk) >= READ_CHUNK_SIZE: + yield current_chunk + current_chunk = SimpleReadStorage() + yield current_chunk + + +def bam_file_chunk_reader(handler): + current_chunk = SimpleReadStorage() + for r in handler: + if r.is_secondary or r.is_supplementary: + continue + 
current_chunk.add(r.query_name, r.query_sequence) + if len(current_chunk) >= READ_CHUNK_SIZE: + yield current_chunk + current_chunk = SimpleReadStorage() + yield current_chunk + + +def process_chunk(barcode_detector, read_chunk, output_file, num, out_fasta=None, min_score=None): + output_file += "_" + str(num) + if out_fasta: + out_fasta += "_" + str(num) + counter = 0 + if min_score: + barcode_detector.min_score = min_score + barcode_caller = BarcodeCaller(output_file, barcode_detector, output_sequences=out_fasta) + counter += barcode_caller.process_chunk(read_chunk) + read_chunk.clear() + barcode_caller.dump_stats() + barcode_caller.close() + + return output_file, out_fasta, counter + + +def prepare_barcodes(args): + logger.info("Using barcodes from %s" % ", ".join(args.barcodes)) + barcode_files = len(args.barcodes) + if barcode_files not in BARCODE_FILES_REQUIRED[args.mode]: + logger.critical("Barcode calling mode %s requires %s files, %d provided" % + (args.mode.name, " or ".join([str(x) for x in BARCODE_FILES_REQUIRED[args.mode]]), barcode_files)) + exit(-3) + barcodes = [] + for bc in args.barcodes: + barcodes.append(load_barcodes(bc, needs_iterator=args.mode.needs_barcode_iterator())) + + if len(barcodes) == 1: + barcodes = barcodes[0] + if not args.mode.needs_barcode_iterator(): + logger.info("Loaded %d barcodes" % len(barcodes)) + else: + if not args.mode.needs_barcode_iterator(): + for i, bc in enumerate(barcodes): + logger.info("Loaded %d barcodes from %s" % (len(bc), args.barcodes[i])) + barcodes = tuple(barcodes) + return barcodes + + +def process_single_thread(args): + logger.info("Preparing barcodes indices") + barcodes = prepare_barcodes(args) + barcode_detector = BARCODE_CALLING_MODES[args.mode](barcodes) + if args.min_score: + barcode_detector.min_score = args.min_score + barcode_caller = BarcodeCaller(args.output_tsv, barcode_detector, header=True, output_sequences=args.out_fasta) + barcode_caller.process(args.input) + barcode_caller.dump_stats() + for stat_line in barcode_caller.get_stats(): + logger.info(" " + stat_line) + barcode_caller.close() + logger.info("Finished barcode calling") + + +def process_in_parallel(args): + input_file = args.input + logger.info("Processing " + input_file) + fname, outer_ext = os.path.splitext(os.path.basename(input_file)) + low_ext = outer_ext.lower() + + handle = input_file + if low_ext in ['.gz', '.gzip']: + handle = gzip.open(input_file, "rt") + input_file = fname + fname, outer_ext = os.path.splitext(os.path.basename(input_file)) + low_ext = outer_ext.lower() + + if low_ext in ['.fq', '.fastq']: + read_chunk_gen = fastx_file_chunk_reader(SeqIO.parse(handle, "fastq")) + elif low_ext in ['.fa', '.fasta']: + read_chunk_gen = fastx_file_chunk_reader(SeqIO.parse(handle, "fasta")) + elif low_ext in ['.bam', '.sam']: + read_chunk_gen = bam_file_chunk_reader(pysam.AlignmentFile(input_file, "rb", check_sq=False)) + else: + logger.error("Unknown file format " + input_file) + exit(-1) + + tmp_dir = "barcode_calling_%x" % random.randint(0, 1 << 32) + while os.path.exists(tmp_dir): + tmp_dir = "barcode_calling_%x" % random.randint(0, 1 << 32) + if args.tmp_dir: + tmp_dir = os.path.join(args.tmp_dir, tmp_dir) + os.makedirs(tmp_dir) + + barcodes = prepare_barcodes(args) + barcode_detector = BARCODE_CALLING_MODES[args.mode](barcodes) + logger.info("Barcode caller created") + + min_score = None + if args.min_score: + min_score = args.min_score + + tmp_barcode_file = os.path.join(tmp_dir, "bc") + tmp_fasta_file = os.path.join(tmp_dir, "subreads") 
if args.out_fasta else None + chunk_counter = 0 + future_results = [] + output_files = [] + + with ProcessPoolExecutor(max_workers=args.threads) as proc: + for chunk in read_chunk_gen: + future_results.append(proc.submit(process_chunk, + barcode_detector, + chunk, + tmp_barcode_file, + chunk_counter, + tmp_fasta_file, + min_score)) + chunk_counter += 1 + if chunk_counter >= args.threads: + break + + reads_left = True + read_counter = 0 + while future_results: + completed_features, _ = concurrent.futures.wait(future_results, + return_when=concurrent.futures.FIRST_COMPLETED) + for c in completed_features: + if c.exception() is not None: + raise c.exception() + res = c.result() + out_file, out_fasta, read_count = res + read_counter += read_count + sys.stdout.write("Processed %d reads\r" % read_counter) + output_files.append((out_file, out_fasta)) + future_results.remove(c) + if reads_left: + try: + chunk = next(read_chunk_gen) + future_results.append(proc.submit(process_chunk, + barcode_detector, + chunk, + tmp_barcode_file, + chunk_counter, + tmp_fasta_file, + min_score)) + chunk_counter += 1 + except StopIteration: + reads_left = False + + with open(args.output_tsv, "w") as final_output_tsv: + final_output_fasta = open(args.out_fasta, "w") if args.out_fasta else None + header = BARCODE_CALLING_MODES[args.mode].result_type().header() + final_output_tsv.write(header + "\n") + stat_dict = defaultdict(int) + for tmp_file, tmp_fasta in output_files: + shutil.copyfileobj(open(tmp_file, "r"), final_output_tsv) + if tmp_fasta and final_output_fasta: + shutil.copyfileobj(open(tmp_fasta, "r"), final_output_fasta) + for l in open(stats_file_name(tmp_file), "r"): + v = l.strip().split("\t") + if len(v) != 2: + continue + stat_dict[v[0]] += int(v[1]) + + if final_output_fasta is not None: + final_output_fasta.close() + + with open(stats_file_name(args.output_tsv), "w") as out_stats: + for k, v in stat_dict.items(): + logger.info(" %s: %d" % (k, v)) + out_stats.write("%s\t%d\n" % (k, v)) + shutil.rmtree(tmp_dir) + logger.info("Finished barcode calling") + + +def load_barcodes(inf, needs_iterator=False): + if inf.endswith("h5") or inf.endswith("hdf5"): + return load_h5_barcodes_bit(inf) + + if inf.endswith("gz") or inf.endswith("gzip"): + handle = gzip.open(inf, "rt") + else: + handle = open(inf, "r") + + barcode_iterator = iter(l.strip().split()[0] for l in handle) + if needs_iterator: + return barcode_iterator + + return [b for b in barcode_iterator] + + +def load_h5_barcodes_bit(h5_file_path, dataset_name='bpMatrix_1'): + raise NotImplementedError() + import h5py + barcode_list = [] + with h5py.File(h5_file_path, 'r') as h5_file: + dataset = numpy.array(h5_file[dataset_name]) + for row in dataset: + for col in row: + barcode_list.append(bit_to_str(int(col[0]))) + return barcode_list + + +def set_logger(logger_instance, args): + logger_instance.setLevel(logging.INFO) + if args.debug: + logger_instance.setLevel(logging.DEBUG) + ch = logging.StreamHandler(sys.stdout) + ch.setLevel(logging.INFO) + if args.debug: + ch.setLevel(logging.DEBUG) + + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + ch.setFormatter(formatter) + logger_instance.addHandler(ch) + + +def parse_args(sys_argv): + def add_hidden_option(*args, **kwargs): # show command only with --full-help + kwargs['help'] = argparse.SUPPRESS + parser.add_argument(*args, **kwargs) + + parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--output", "-o", type=str, 
help="output prefix name", required=True) + parser.add_argument("--barcodes", "-b", nargs='+', type=str, help="barcode whitelist(s)", required=False) + # parser.add_argument("--umi", "-u", type=str, help="potential UMIs, detected de novo if not set") + parser.add_argument("--mode", type=str, help="mode to be used", choices=[x.name for x in BARCODE_CALLING_MODES.keys()], + default=IsoQuantMode.stereoseq.name) + parser.add_argument("--input", "-i", type=str, help="input reads in [gzipped] FASTA, FASTQ, BAM, SAM", + required=True) + parser.add_argument("--threads", "-t", type=int, help="threads to use (16)", default=16) + parser.add_argument("--tmp_dir", type=str, help="folder for temporary files") + parser.add_argument("--min_score", type=int, help="minimal barcode score " + "(scoring system is +1, -1, -1, -1)") + add_hidden_option('--debug', action='store_true', default=False, help='Debug log output.') + + args = parser.parse_args(sys_argv) + args.mode = IsoQuantMode[args.mode] + args.out_fasta = None + args.output_tsv = None + return args + + +def check_args(args): + if args.out_fasta is None and args.mode.produces_new_fasta(): + args.out_fasta = args.output + ".split_reads.fasta" + if args.output_tsv is None: + args.output_tsv = args.output + ".barcoded_reads.tsv" + + +def main(sys_argv): + args = parse_args(sys_argv) + set_logger(logger, args) + check_args(args) + out_dir = os.path.dirname(args.output) + if out_dir and not os.path.exists(out_dir): + os.makedirs(out_dir, exist_ok=True) + + if args.threads == 1 or args.mode.enforces_single_thread(): + process_single_thread(args) + else: + process_in_parallel(args) + + +if __name__ == "__main__": + # stuff only to run when not called via 'import' here + try: + main(sys.argv[1:]) + except SystemExit: + raise + except: + print_exc() + sys.exit(-1) diff --git a/docs/cmd.md b/docs/cmd.md index 12e17350..10520543 100644 --- a/docs/cmd.md +++ b/docs/cmd.md @@ -96,23 +96,68 @@ Use this option at your own risk. Input file names are used as labels if not set. `--read_group` - Sets a way to group feature counts (e.g. by cell type). Available options are: + Sets one or more ways to group feature counts (e.g. by cell type, file, or barcode). + Multiple grouping strategies can be combined (space-separated). + Available grouping options: - * `file_name`: groups reads by their original file names (or file name labels) within an experiment. + * `file_name` - groups reads by their original file names (or file name labels) within an experiment. This option makes sense when multiple files are provided. -This option is designed for obtaining expression tables with a separate column for each file. -If multiple BAM/FASTQ files are provided and `--read_group` option is not set, IsoQuant will set `--read_group:file_name` +If multiple BAM/FASTQ files are provided and `--read_group` option is not set, IsoQuant will set `--read_group file_name` by default. - * `tag`: groups reads by BAM file read tag: set `tag:TAG`, where `TAG` is the desired tag name -(e.g. `tag:RG` with use `RG` values as groups, `RG` will be used if unset); - * `read_id`: groups reads by read name suffix: set `read_id:DELIM` where `DELIM` is the + + * `tag:TAG` - groups reads by BAM file read tag, where `TAG` is the tag name +(e.g. `tag:CB` uses `CB` tag values as groups, commonly used for cell barcodes in single-cell data). + + * `read_id:DELIM` - groups reads by read name suffix, where `DELIM` is the symbol/string by which the read id will be split -(e.g. 
if `DELIM` is `_`, for read `m54158_180727_042959_59310706_ccs_NEU` the group will set as `NEU`);
- * `file`: uses additional file with group information for every read: `file:FILE:READ_COL:GROUP_COL:DELIM`,
-where `FILE` is the file name, `READ_COL` is column with read ids (0 if not set),
-`GROUP_COL` is column with group ids (1 if not set),
-`DELIM` is separator symbol (tab if not set). File can be gzipped.
+(e.g. if `DELIM` is `_`, for read `m54158_180727_042959_59310706_ccs_NEU` the group will be `NEU`).
+
+ * `file:FILE:READ_COL:GROUP_COL(S):DELIM` - uses an additional TSV file with group information for every read,
+where `FILE` is the file path, `READ_COL` is the column with read ids (default: 0),
+`GROUP_COL(S)` is the column(s) with group ids (default: 1; use comma-separated columns for multi-column grouping, e.g., `1,2,3`),
+`DELIM` is the separator symbol (default: tab). File can be gzipped.
+
+ * `barcode_spot` or `barcode_spot:FILE` - maps barcodes to spots/cell types.
+Uses the barcode-to-spot mapping from the `--barcode2spot` file by default, or from the explicit `FILE` if specified.
+Useful for grouping single-cell/spatial data by cell type or spatial region instead of individual barcodes.
+
+**Example**: `--read_group tag:CB file_name barcode_spot` creates multi-level grouping by cell barcode tag, file name, and cell type.
+
+
+## Single-cell and spatial transcriptomics options
+
+`--mode` or `-m`
+IsoQuant mode for processing single-cell or spatial transcriptomics data. Available modes:
+
+* `bulk` - standard bulk RNA-seq mode (default)
+* `tenX_v3` - 10x Genomics single-cell 3' gene expression
+* `curio` - Curio Bioscience single-cell
+* `visium_hd` - 10x Genomics Visium HD spatial transcriptomics
+* `visium_5prime` - 10x Genomics Visium 5' spatial transcriptomics
+* `stereoseq` - Stereo-seq spatial transcriptomics (BGI)
+* `stereoseq_nosplit` - Stereo-seq without barcode splitting
+
+Single-cell and spatial modes enable automatic barcode calling and UMI-based deduplication.
+
+`--barcode_whitelist`
+Path to file(s) with a barcode whitelist for barcode calling.
+Required for single-cell/spatial modes unless `--barcoded_reads` is provided.
+The file should contain one barcode per line.
+
+`--barcoded_reads`
+Path to TSV file(s) with barcoded reads.
+Format: `read_id <TAB> barcode <TAB> UMI` (one read per line).
+If not provided, barcodes will be called automatically from raw reads using `--barcode_whitelist`.
+
+`--barcode_column`
+Column index for barcodes in the `--barcoded_reads` file (default: 1).
+Read ID column is 0, barcode column is 1, UMI column is 2.
+
+`--barcode2spot`
+Path to TSV file(s) mapping barcodes to cell types or spatial spots.
+Format: `barcode <TAB> cell_type` (one barcode per line).
+Used with `--read_group barcode_spot` to group counts by cell type/spot instead of individual barcodes.
+Useful for reducing output dimensions from thousands of barcodes to tens of cell types.
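+
+**Example**: a possible single-cell/spatial invocation combining these options (a sketch only; file names are placeholders, and `--reference`, `--genedb`, `--fastq`, `--data_type` and `-o` are IsoQuant's standard input/output options described elsewhere in this document):
+
+```bash
+isoquant.py --data_type nanopore \
+    --reference genome.fa --genedb annotation.gtf \
+    --fastq reads.fastq.gz \
+    --mode stereoseq \
+    --barcode_whitelist barcode_whitelist.tsv \
+    --barcode2spot barcode_to_celltype.tsv \
+    --read_group barcode_spot \
+    -o isoquant_stereoseq_output
+```
+
+Here barcodes are called automatically against the whitelist, and feature counts are grouped by cell type/spot via `--barcode2spot` rather than by individual barcodes.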
### Output options diff --git a/isoquant.py b/isoquant.py index 79a85ada..5603ffd4 100755 --- a/isoquant.py +++ b/isoquant.py @@ -19,11 +19,14 @@ from collections import namedtuple from io import StringIO from traceback import print_exc +from concurrent.futures import ProcessPoolExecutor +import concurrent.futures import pysam import gffutils import pyfaidx +from src.modes import IsoQuantMode, ISOQUANT_MODES from src.gtf2db import convert_gtf_to_db from src.read_mapper import ( DATA_TYPE_ALIASES, @@ -34,14 +37,16 @@ NANOPORE_DATA, DataSetReadMapper ) -from src.dataset_processor import DatasetProcessor, PolyAUsageStrategies from src.alignment_processor import PolyATrimmed +from src.dataset_processor import DatasetProcessor, PolyAUsageStrategies from src.graph_based_model_construction import StrandnessReportingLevel from src.long_read_assigner import AmbiguityResolvingMethod from src.long_read_counter import COUNTING_STRATEGIES, CountingStrategy, GroupedOutputFormat, NormalizationMethod from src.input_data_storage import InputDataStorage, InputDataType from src.multimap_resolver import MultimapResolvingStrategy from src.stats import combine_counts +from detect_barcodes import process_single_thread, process_in_parallel, get_umi_length + logger = logging.getLogger('IsoQuant') @@ -63,6 +68,7 @@ def parse_args(cmd_args=None, namespace=None): output_setup_args_group = parser.add_argument_group('Output configuration') align_args_group = parser.add_argument_group('Aligner settings') filer_args_group = parser.add_argument_group('Read filtering options') + sc_args_group = parser.add_argument_group('Single-cell/spatial-related options:') other_options = parser.add_argument_group("Additional options:") show_full_help = '--full_help' in cmd_args @@ -127,11 +133,16 @@ def add_hidden_option(*args, **kwargs): # show command only with --full-help input_args_group.add_argument('--illumina_bam', nargs='+', type=str, help='sorted and indexed file(s) with Illumina reads from the same sample') - input_args_group.add_argument("--read_group", help="a way to group feature counts (no grouping by default): " - "by BAM file tag (tag:TAG); " - "using additional file (file:FILE:READ_COL:GROUP_COL:DELIM); " - "using read id (read_id:DELIM); " - "by original file name (file_name)", type=str) + input_args_group.add_argument("--read_group", nargs='+', type=str, + help="one or more ways to group feature counts (no grouping by default); " + "multiple grouping strategies can be specified (space-separated); " + "supported formats: " + "tag:TAG (BAM tag), " + "file:FILE:READ_COL:GROUP_COL(S):DELIM (TSV file, use comma-separated columns for multi-column grouping, e.g., file:table.tsv:0:1,2,3), " + "read_id:DELIM (read ID suffix), " + "file_name (original filename), " + "barcode_spot[:FILE] (map barcodes to spots/cell types using --barcode2spot or explicit file); " + "example: --read_group tag:CB file_name barcode_spot") add_additional_option_to_group(input_args_group, "--read_assignments", nargs='+', type=str, help="reuse read assignments (binary format)", default=None) @@ -147,6 +158,22 @@ def add_hidden_option(*args, **kwargs): # show command only with --full-help input_args_group.add_argument('--fl_data', action='store_true', default=False, help="reads represent FL transcripts; both ends of the read are considered to be reliable") + # SC ARGUMENTS + sc_args_group.add_argument("--mode", "-m", type=str, choices=ISOQUANT_MODES, + help="IsoQuant modes: " + ", ".join(ISOQUANT_MODES) + + "; default:%s" % IsoQuantMode.bulk.name, 
default=IsoQuantMode.bulk.name) + sc_args_group.add_argument('--barcode_whitelist', type=str, nargs='+', + help='file with barcode whitelist for barcode calling') + sc_args_group.add_argument("--barcoded_reads", type=str, nargs='+', + help='TSV file with barcoded reads; barcodes will be called automatically if not provided') + # TODO: add UMI column, support various formats + sc_args_group.add_argument("--barcode_column", type=str, + help='column with barcodes in barcoded_reads file, default=1; read id column is 0', + default=1) + # TODO: add multiple columns + sc_args_group.add_argument("--barcode2spot", type=str, nargs='+', + help='TSV file barcode to cell type / spot id information') + # ALGORITHM add_additional_option_to_group(algo_args_group, "--report_novel_unspliced", "-u", type=bool_str, help="report novel monoexonic transcripts (true/false), " @@ -199,6 +226,8 @@ def add_hidden_option(*args, **kwargs): # show command only with --full-help help='Do not use previously generated index, feature db or alignments.') add_additional_option_to_group(pipeline_args_group, "--no_model_construction", action="store_true", default=False, help="run only read assignment and quantification") + add_additional_option_to_group(pipeline_args_group, "--no_large_files", action="store_true", + default=False, help="do not output files containing all reads (bed and tsv)") add_additional_option_to_group(pipeline_args_group, "--run_aligner_only", action="store_true", default=False, help="align reads to reference without running further analysis") add_additional_option_to_group(pipeline_args_group, "--no_gtf_check", help="do not perform GTF checks", @@ -397,10 +426,23 @@ def save_params(args): args.__dict__[file_opt] = os.path.abspath(args.__dict__[file_opt]) if "read_group" in args.__dict__ and args.__dict__["read_group"]: - vals = args.read_group.split(":") - if len(vals) > 1 and vals[0] == 'file': - vals[1] = os.path.abspath(vals[1]) - args.read_group = ":".join(vals) + # Handle both list (nargs='+') and string (backward compatibility) + if isinstance(args.read_group, list): + updated_specs = [] + for spec in args.read_group: + vals = spec.split(":") + if len(vals) > 1 and vals[0] == 'file': + vals[1] = os.path.abspath(vals[1]) + updated_specs.append(":".join(vals)) + else: + updated_specs.append(spec) + args.read_group = updated_specs + else: + # Backward compatibility with string format + vals = args.read_group.split(":") + if len(vals) > 1 and vals[0] == 'file': + vals[1] = os.path.abspath(vals[1]) + args.read_group = ":".join(vals) pickler = pickle.Pickler(open(args.param_file, "wb"), -1) pickler.dump(args) @@ -423,7 +465,7 @@ def check_input_params(args): if not args.fastq and not args.bam and not args.unmapped_bam and not args.read_assignments and not args.yaml: logger.error("No input data was provided") return False - + if args.yaml and args.illumina_bam: logger.error("When providing a yaml file it should include all input files, including the illumina bam file.") return False @@ -460,7 +502,18 @@ def check_input_params(args): if args.process_only_chr and args.discard_chr: args.discard_chr = [] logger.warning("--discard_chr has not effect when --process_only_chr is set and will be ignored") - + + if not isinstance(args.mode, IsoQuantMode): + args.mode = IsoQuantMode[args.mode] + + args.umi_length = 0 + if args.mode.needs_barcode_calling(): + if not args.barcode_whitelist and not args.barcoded_reads: + logger.critical("You have chosen single-cell/spatial mode %s, please specify barcode whitelist or 
file with " + "barcoded reads" % args.mode.name) + exit(-3) + args.umi_length = get_umi_length(args.mode) + check_input_files(args) return True @@ -580,9 +633,16 @@ def set_data_dependent_options(args): args.resolve_ambiguous = 'monoexon_and_fsm' if args.fl_data else 'default' args.requires_polya_for_construction = False - if args.read_group is None and args.input_data.has_replicas(): - args.read_group = "file_name" - args.use_technical_replicas = args.read_group == "file_name" + + # Automatically add file_name grouping when multiple files are present + if args.input_data.has_replicas(): + if args.read_group is None: + # No read grouping specified, use file_name + args.read_group = ["file_name"] + else: + # Read grouping specified, ensure file_name is included + if "file_name" not in args.read_group: + args.read_group.append("file_name") def set_matching_options(args): @@ -802,12 +862,77 @@ def prepare_reference_genome(args): args.reference = gunzipped_reference +class BarcodeCallingArgs: + def __init__(self, input, barcode_whitelist, mode, output, out_fasta, tmp_dir, threads): + self.input = input + self.barcodes = barcode_whitelist + self.mode = mode + self.output_tsv = output + self.out_fasta = out_fasta + self.tmp_dir = tmp_dir + self.threads = threads + self.min_score = None + + +def call_barcodes(args): + if not args.barcoded_reads: + for sample in args.input_data.samples: + new_reads = [] + for i, files in enumerate(sample.file_list): + output_barcodes = sample.barcodes_tsv + "_%d.tsv" % i + barcodes_done = sample.barcodes_done + "_%d.tsv" % i + + output_fasta = None + if args.mode.produces_new_fasta(): + output_fasta = sample.split_reads_fasta + "_%d.fa" % i + new_reads.append([output_fasta]) + bc_threads = 1 if args.mode.enforces_single_thread() else args.threads + if os.path.exists(barcodes_done): + if args.resume: + logger.info("Barcodes were called during the previous run, skipping") + sample.barcoded_reads.append(output_barcodes) + continue + os.remove(barcodes_done) + + bc_args = BarcodeCallingArgs(files[0], args.barcode_whitelist, args.mode, + output_barcodes, output_fasta, sample.aux_dir, bc_threads) + # Launching barcode calling in a separate process has the following reason: + # Read chunks are not cleared by the GC in the end of barcode calling, leaving the main + # IsoQuant process to consume ~2,5 GB even when barcode calling is done. + # Once 16 child processes are created later, IsoQuant instantly takes threads x 2,5 GB for nothing. 
+ with ProcessPoolExecutor(max_workers=1) as proc: + logger.info("Detecting barcodes") + if bc_threads == 1: + future_res = proc.submit(process_single_thread, bc_args) + else: + future_res = proc.submit(process_in_parallel, bc_args) + + concurrent.futures.wait([future_res], return_when=concurrent.futures.ALL_COMPLETED) + if future_res.exception() is not None: + raise future_res.exception() + + sample.barcoded_reads.append(output_barcodes) + open(barcodes_done, "w").close() + logger.info("Processed %s, barcodes are stored in %s" % (files[0], output_barcodes)) + + if args.mode.produces_new_fasta(): + logger.info("Reads were split during barcode calling") + logger.info("The following files will be used instead of original reads %s " % ", ".join(map(lambda x: x[0], new_reads))) + sample.file_list = new_reads + else: + # TODO barcoded files via YAML + args.input_data.samples[0].barcoded_reads = args.barcoded_reads + + def run_pipeline(args): logger.info(" === IsoQuant pipeline started === ") logger.info("Python version: %s" % sys.version) logger.info("gffutils version: %s" % gffutils.__version__) logger.info("pysam version: %s" % pysam.__version__) logger.info("pyfaidx version: %s" % pyfaidx.__version__) + if args.mode.needs_barcode_calling(): + # call barcodes + call_barcodes(args) # gunzip refernece genome if needed prepare_reference_genome(args) @@ -826,14 +951,15 @@ def run_pipeline(args): if args.run_aligner_only: logger.info("Isoform assignment step is skipped because --run-aligner-only option was used") - else: - # run isoform assignment - dataset_processor = DatasetProcessor(args) - dataset_processor.process_all_samples(args.input_data) + return + + # run isoform assignment + dataset_processor = DatasetProcessor(args) + dataset_processor.process_all_samples(args.input_data) - # aggregate counts for all samples - if len(args.input_data.samples) > 1 and args.genedb: - combine_counts(args.input_data, args.output) + # aggregate counts for all samples + if len(args.input_data.samples) > 1 and args.genedb: + combine_counts(args.input_data, args.output) logger.info(" === IsoQuant pipeline finished === ") diff --git a/requirements.txt b/requirements.txt index 1c7c69a7..4cb544be 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,9 +3,12 @@ pandas>=1.0.1 pysam>=0.15 packaging pyfaidx>=0.7 +ssw-py>=1.0.0 pyyaml>=5.4 matplotlib>=3.1.3 numpy>=1.21.0 scipy>=1.4.1 seaborn>=0.10.0 +editdistance>=0.8.1 +biopython>=1.76 diff --git a/requirements_tests.txt b/requirements_tests.txt index f24fc7b2..34fd571b 100644 --- a/requirements_tests.txt +++ b/requirements_tests.txt @@ -11,3 +11,9 @@ simplejson>=3.17.0 six>=1.14.0 pyfaidx>=0.7 pyyaml>=5.4 +ssw-py>=1.0.0 +matplotlib>=3.1.3 +scipy>=1.4.1 +seaborn>=0.10.0 +editdistance>=0.8.1 +biopython>=1.76 diff --git a/src/alignment_processor.py b/src/alignment_processor.py index 0f0caee7..ba27877c 100644 --- a/src/alignment_processor.py +++ b/src/alignment_processor.py @@ -244,6 +244,7 @@ class AlignmentCollector: def __init__(self, chr_id, bam_pairs, params, illumina_bam, genedb=None, chr_record=None, read_groupper=DefaultReadGrouper(), + barcode_dict=None, small_chr_max_coverage=1000000, usual_gene_max_coverage=-1): self.chr_id = chr_id @@ -258,6 +259,7 @@ def __init__(self, chr_id, bam_pairs, params, illumina_bam, multiple_iterators=not self.params.high_memory) self.strand_detector = StrandDetector(self.chr_record) self.read_groupper = read_groupper + self.barcode_dict = barcode_dict # read_id -> (barcode, umi) self.small_chr_max_coverage = 
small_chr_max_coverage self.usual_gene_max_coverage = usual_gene_max_coverage self.polya_finder = PolyAFinder(self.params.polya_window, self.params.polya_fraction) @@ -276,9 +278,14 @@ def process(self): self.alignment_stat_counter.add(AlignmentType.primary) if alignment_storage.alignment_is_not_adjacent(alignment): - for res in self.forward_alignments(alignment_storage): - yield res - alignment_storage.reset() + preceding_genes = self.get_genes_in_region(alignment_storage.region) + next_genes = self.get_genes_in_region((alignment.reference_start, alignment.reference_end - 1)) + + if len(preceding_genes.intersection(next_genes)) == 0: + for res in self.forward_alignments(alignment_storage): + yield res + alignment_storage.reset() + alignment_storage.add_alignment(bam_index, alignment) if alignment_storage.region: @@ -383,7 +390,13 @@ def process_intergenic(self, alignment_storage, region, skip_read_fraction=1): read_assignment.exons = alignment_info.read_exons read_assignment.corrected_exons = corrector.correct_read(alignment_info) read_assignment.corrected_introns = junctions_from_blocks(read_assignment.corrected_exons) - read_assignment.read_group = self.read_groupper.get_group_id(alignment, self.bam_merger.bam_pairs[bam_index][1]) + + group_ids = self.read_groupper.get_group_id(alignment, self.bam_merger.bam_pairs[bam_index][1]) + # Ensure read_group is always a list + read_assignment.read_group = group_ids if isinstance(group_ids, list) else [group_ids] + # Populate barcode and UMI if available + if read_id in self.barcode_dict: + read_assignment.barcode, read_assignment.umi = self.barcode_dict[read_id] read_assignment.mapped_strand = "-" if alignment.is_reverse else "+" read_assignment.strand = self.get_assignment_strand(read_assignment) read_assignment.chr_id = self.chr_id @@ -448,7 +461,12 @@ def process_genic(self, alignment_storage, gene_info, region, skip_read_fraction read_assignment) read_assignment.corrected_introns = junctions_from_blocks(read_assignment.corrected_exons) - read_assignment.read_group = self.read_groupper.get_group_id(alignment, self.bam_merger.bam_pairs[bam_index][1]) + group_ids = self.read_groupper.get_group_id(alignment, self.bam_merger.bam_pairs[bam_index][1]) + # Ensure read_group is always a list + read_assignment.read_group = group_ids if isinstance(group_ids, list) else [group_ids] + # Populate barcode and UMI if available + if read_id in self.barcode_dict: + read_assignment.barcode, read_assignment.umi = self.barcode_dict[read_id] read_assignment.mapped_strand = "-" if alignment.is_reverse else "+" read_assignment.strand = self.get_assignment_strand(read_assignment) AlignmentCollector.check_antisense(read_assignment) @@ -515,6 +533,14 @@ def count_indel_stats(self, alignment): return indel_count, junctions_with_indels + def get_genes_in_region(self, current_region): + if not self.genedb: + return set() + return set(g.id for g in self.genedb.region(seqid=self.chr_id, + start=current_region[0], + end=current_region[1], + featuretype="gene")) + def get_gene_info_for_region(self, current_region): if not self.genedb: return GeneInfo.from_region(self.chr_id, current_region[0], current_region[1], diff --git a/src/assignment_io.py b/src/assignment_io.py index aaa250de..1398f985 100644 --- a/src/assignment_io.py +++ b/src/assignment_io.py @@ -27,7 +27,7 @@ BasicReadAssignment, MatchClassification, ReadAssignmentType) -from .gene_info import GeneInfo +from .gene_info import GeneInfo, GeneList logger = logging.getLogger('IsoQuant') @@ -71,7 +71,8 @@ def 
__init__(self, output_file_name, params, assignment_checker=PrintAllFunctor( AbstractAssignmentPrinter.__init__(self, output_file_name, params, assignment_checker) self.gzipped = gzipped if gzipped: - self.output_file = gzip.open(output_file_name + ".gz", "wt") + self.output_file_name += ".gz" + self.output_file = gzip.open(self.output_file_name, "wt") else: self.output_file = open(self.output_file_name, "w") @@ -161,6 +162,7 @@ def add_read_info(self, read_assignment): def flush(self): pass + class BaseTmpFileAssignmentLoader: def __init__(self, input_file_name): self.loader = open(input_file_name, "rb") @@ -208,6 +210,25 @@ def get_object(self): return None +class GeneListTmpFileAssignmentLoader(BaseTmpFileAssignmentLoader): + def __init__(self, input_file_name): + BaseTmpFileAssignmentLoader.__init__(self, input_file_name) + self.current_gene_info = None + + def get_object(self): + if self.is_gene_info(): + self.current_gene_info = GeneList.deserialize(self.loader) + self._read_id() + return self.current_gene_info + elif self.is_read_assignment(): + assert self.current_gene_info is not None + assignment = ReadAssignment.deserialize(self.loader, self.current_gene_info) + self._read_id() + return assignment + else: + return None + + class QuickTmpFileAssignmentLoader(BaseTmpFileAssignmentLoader): def __init__(self, input_file_name): BaseTmpFileAssignmentLoader.__init__(self, input_file_name) @@ -248,7 +269,7 @@ def unmatched_line(read_assignment, additional_info): line += "\t" + " ".join(additional_info) else: line += "\t*" - line += "\t%s\n" % read_assignment.read_group + line += "\t%s\n" % ",".join(read_assignment.read_group) return line def add_read_info(self, read_assignment): @@ -310,7 +331,7 @@ def add_read_info(self, read_assignment): line += "\t" + " ".join(additional_info) else: line += "\t*" - line += "\t%s\n" % read_assignment.read_group + line += "\t%s\n" % ",".join(read_assignment.read_group) self.output_file.write(line) diff --git a/src/assignment_loader.py b/src/assignment_loader.py new file mode 100644 index 00000000..77194964 --- /dev/null +++ b/src/assignment_loader.py @@ -0,0 +1,239 @@ +############################################################################ +# Copyright (c) 2022-2025 University of Helsinki +# Copyright (c) 2019-2022 Saint Petersburg State University +# # All Rights Reserved +# See file LICENSE for details. 
+############################################################################ + +import logging +from collections import defaultdict + +import gffutils +from pyfaidx import Fasta, UnsupportedCompressionFormat + +from .serialization import * +from .file_naming import * +from .isoform_assignment import BasicReadAssignment, ReadAssignmentType +from .assignment_io import ( + NormalTmpFileAssignmentLoader, + QuickTmpFileAssignmentLoader, GeneListTmpFileAssignmentLoader +) +from .gene_info import GeneList, GeneInfo + +logger = logging.getLogger('IsoQuant') + + +class BasicAssignmentLoader: + def __init__(self, save_file_name): + logger.info("Loading read assignments from " + save_file_name) + assert os.path.exists(save_file_name) + self.save_file_name = save_file_name + + def has_next(self): + raise NotImplementedError() + + def get_next(self): + raise NotImplementedError() + + +class FullAssignmentLoader(BasicAssignmentLoader): + def __init__(self, save_file_name, multimapped_chr_dict, filtered_read_set=None): + BasicAssignmentLoader.__init__(self, save_file_name) + self.multimapped_chr_dict = multimapped_chr_dict + self.filtered_read_set = filtered_read_set + + def load_read_assignment(self, read_assignment): + if self.filtered_read_set is not None and read_assignment.read_id not in self.filtered_read_set: + return None + if self.multimapped_chr_dict is not None and read_assignment.read_id in self.multimapped_chr_dict: + resolved_assignment = None + for a in self.multimapped_chr_dict[read_assignment.read_id]: + if a.assignment_id == read_assignment.assignment_id and a.chr_id == read_assignment.chr_id: + if resolved_assignment is not None: + logger.info("Duplicate read: %s %s %s" % (read_assignment.read_id, a.gene_id, a.chr_id)) + resolved_assignment = a + + if not resolved_assignment: + logger.warning("Incomplete information on read %s" % read_assignment.read_id) + return None + elif resolved_assignment.assignment_type == ReadAssignmentType.suspended: + return None + else: + read_assignment.assignment_type = resolved_assignment.assignment_type + read_assignment.gene_assignment_type = resolved_assignment.gene_assignment_type + read_assignment.multimapper = resolved_assignment.multimapper + return read_assignment + + +class ReadAssignmentLoader(FullAssignmentLoader): + def __init__(self, save_file_name, gffutils_db, chr_record, multimapped_chr_dict, filtered_read_set=None): + FullAssignmentLoader.__init__(self, save_file_name, multimapped_chr_dict, filtered_read_set) + self.genedb = gffutils_db + self.chr_record = chr_record + self.unpickler = NormalTmpFileAssignmentLoader(save_file_name, gffutils_db, chr_record) + + def has_next(self): + return self.unpickler.has_next() + + def get_next(self): + if not self.unpickler.has_next(): + return None, None + + assert self.unpickler.is_gene_info() + gene_info = self.unpickler.get_object() + assignment_storage = [] + while self.unpickler.is_read_assignment(): + read_assignment = self.load_read_assignment(self.unpickler.get_object()) + if read_assignment is not None: + assignment_storage.append(read_assignment) + + return gene_info, assignment_storage + + +class MergingSimpleReadAssignmentLoader(FullAssignmentLoader): + def __init__(self, save_file_name, multimapped_chr_dict, filtered_read_set=None): + FullAssignmentLoader.__init__(self, save_file_name, multimapped_chr_dict, filtered_read_set) + self.unpickler = GeneListTmpFileAssignmentLoader(save_file_name) + self.current_gene_list = None + + def has_next(self): + return self.unpickler.has_next() + + 
def get_next(self): + if not self.unpickler.has_next(): + return None, None + + assignment_storage = [] + while self.unpickler.has_next(): + if self.current_gene_list is None: + assert self.unpickler.is_gene_info() + self.current_gene_list = self.unpickler.get_object() + elif self.unpickler.is_gene_info(): + gene_list = self.unpickler.get_object() + if self.current_gene_list.overlaps(gene_list): + self.current_gene_list.merge(gene_list) + else: + self.current_gene_list = gene_list + return None, assignment_storage + + while self.unpickler.is_read_assignment(): + read_assignment = self.load_read_assignment(self.unpickler.get_object()) + if read_assignment is not None: + assignment_storage.append(read_assignment) + + return None, assignment_storage + + +class MergingReadAssignmentLoader(MergingSimpleReadAssignmentLoader): + def __init__(self, save_file_name, gffutils_db, chr_record, multimapped_chr_dict, filtered_read_set=None): + MergingSimpleReadAssignmentLoader.__init__(save_file_name, multimapped_chr_dict, filtered_read_set) + self.genedb = gffutils_db + self.chr_record = chr_record + + def _create_gene_info(self): + if not self.current_gene_list.gene_id_set: + gene_info = GeneInfo.from_region(self.current_gene_list.chr_id, self.current_gene_list.start, + self.current_gene_list.end, self.current_gene_list.delta) + else: + gene_info = GeneInfo([self.genedb[gene_id] for gene_id in self.current_gene_list.gene_id_set], self.genedb, + self.current_gene_list.delta) + if self.chr_record: + gene_info.set_reference_sequence(gene_info.all_read_region_start, + gene_info.all_read_region_end, + self.chr_record) + return gene_info + + def get_next(self): + if not self.unpickler.has_next(): + return None, None + + assignment_storage = [] + while self.unpickler.has_next(): + if self.current_gene_list is None: + assert self.unpickler.is_gene_info() + self.current_gene_list = self.unpickler.get_object() + elif self.unpickler.is_gene_info(): + gene_list = self.unpickler.get_object() + if self.current_gene_list.overlaps(gene_list): + self.current_gene_list.merge(gene_list) + else: + gene_info = self._create_gene_info() + for a in assignment_storage: + a.gene_info = gene_info + self.current_gene_list = gene_list + return gene_info, assignment_storage + + while self.unpickler.is_read_assignment(): + read_assignment = self.load_read_assignment(self.unpickler.get_object()) + if read_assignment is not None: + assignment_storage.append(read_assignment) + + gene_info = self._create_gene_info() + for a in assignment_storage: + a.gene_info = gene_info + return gene_info, assignment_storage + + +class BasicReadAssignmentLoader(BasicAssignmentLoader): + def __init__(self, save_file_name): + BasicAssignmentLoader.__init__(self, save_file_name) + self.unpickler = QuickTmpFileAssignmentLoader(save_file_name) + + def has_next(self): + return self.unpickler.has_next() + + def get_next(self): + if not self.unpickler.has_next(): + return + + assert self.unpickler.is_gene_info() + self.unpickler.get_object() + + while self.unpickler.is_read_assignment(): + yield self.unpickler.get_object() + + +def prepare_multimapped_reads(saves_prefix ,chr_id): + multimapped_reads = defaultdict(list) + multimap_loader = open(multimappers_file_name(saves_prefix ,chr_id), "rb") + list_size = read_int(multimap_loader) + while list_size != TERMINATION_INT: + for i in range(list_size): + a = BasicReadAssignment.deserialize(multimap_loader) + if a.chr_id == chr_id: + multimapped_reads[a.read_id].append(a) + list_size = read_int(multimap_loader) + 
return multimapped_reads + + +def prepare_read_filter(chr_id, saves_prefix, use_filtered_reads): + if not use_filtered_reads: + return None + filtered_reads = set() + for l in open(filtered_reads_file_name(saves_prefix, chr_id), "r"): + filtered_reads.add(l.rstrip()) + return filtered_reads + + +def load_genedb(genedb): + if genedb: + return gffutils.FeatureDB(genedb) + return None + + +def create_assignment_loader(chr_id, saves_prefix, genedb, reference_fasta, reference_fai, use_filtered_reads=False): + current_chr_record = Fasta(reference_fasta, indexname=reference_fai)[chr_id] + multimapped_reads = prepare_multimapped_reads(saves_prefix, chr_id) + filtered_reads = prepare_read_filter(chr_id, saves_prefix, use_filtered_reads) + gffutils_db = load_genedb(genedb) + chr_dump_file = saves_file_name(saves_prefix, chr_id) + + return ReadAssignmentLoader(chr_dump_file, gffutils_db, current_chr_record, multimapped_reads, filtered_reads) + + +def create_merging_assignment_loader(chr_id, saves_prefix, use_filtered_reads=False): + multimapped_reads = prepare_multimapped_reads(saves_prefix, chr_id) + filtered_reads = prepare_read_filter(chr_id, saves_prefix, use_filtered_reads) + chr_dump_file = saves_file_name(saves_prefix, chr_id) + + return MergingSimpleReadAssignmentLoader(chr_dump_file, multimapped_reads, filtered_reads) diff --git a/src/barcode_calling/__init__.py b/src/barcode_calling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/barcode_calling/barcode_callers.py b/src/barcode_calling/barcode_callers.py new file mode 100644 index 00000000..6a5adc66 --- /dev/null +++ b/src/barcode_calling/barcode_callers.py @@ -0,0 +1,1460 @@ +########################################################################### +# Copyright (c) 2023 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + +""" +Barcode detection and matching for single-cell and spatial transcriptomics. + +Implements barcode calling for various platforms: +- Stereo-seq (standard and splitting modes) +- 10x Genomics (v3, Visium HD) +- Curio (double barcode) + +Uses k-mer indexing and Smith-Waterman alignment for approximate matching +of barcodes against whitelists. +""" + +import os +import logging +from collections import defaultdict +from typing import List, Tuple, Optional, Set, Dict, Iterable + +from .kmer_indexer import KmerIndexer, ArrayKmerIndexer, Array2BitKmerIndexer +from .common import find_polyt_start, reverese_complement, find_candidate_with_max_score_ssw, detect_exact_positions, \ + detect_first_exact_positions, str_to_2bit, bit_to_str, find_candidate_with_max_score_ssw_var_len +from .shared_mem_index import SharedMemoryArray2BitKmerIndexer + +logger = logging.getLogger('IsoQuant') + + +def increase_if_valid(val: Optional[int], delta: int) -> Optional[int]: + """ + Increment a coordinate value if it's valid. + + Args: + val: Position value (-1 or None indicates invalid) + delta: Amount to increment + + Returns: + Incremented value if valid, otherwise original value + """ + if val and val != -1: + return val + delta + return val + + +class BarcodeDetectionResult: + """ + Base class for barcode detection results. + + Stores detected barcode, UMI, and quality scores for a single read. 
+ """ + + NOSEQ = "*" # Sentinel for missing/undetected sequence + + def __init__(self, read_id: str, barcode: str = NOSEQ, UMI: str = NOSEQ, + BC_score: int = -1, UMI_good: bool = False, strand: str = ".", + additional_info: Optional[Dict] = None): + """ + Initialize barcode detection result. + + Args: + read_id: Read identifier + barcode: Detected barcode sequence (NOSEQ if not found) + UMI: Detected UMI sequence (NOSEQ if not found) + BC_score: Barcode alignment score + UMI_good: Whether UMI passes quality filters + strand: Detected strand ('+', '-', or '.') + additional_info: Optional platform-specific metadata + """ + self.read_id: str = read_id + self.barcode: str = barcode + self.UMI: str = UMI + self.BC_score: int = BC_score + self.UMI_good: bool = UMI_good + self.strand: str = strand + + def is_valid(self) -> bool: + """Check if a valid barcode was detected.""" + return self.barcode != BarcodeDetectionResult.NOSEQ + + def update_coordinates(self, delta: int) -> None: + """ + Shift all genomic coordinates by delta. + + Used when processing read subsequences. + + Args: + delta: Amount to shift coordinates + """ + pass + + def more_informative_than(self, that: 'BarcodeDetectionResult') -> bool: + """ + Compare two results to determine which is more informative. + + Args: + that: Another detection result + + Returns: + True if this result is more informative + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError() + + def get_additional_attributes(self) -> List[str]: + """ + Get list of detected additional features (primer, linker, etc.). + + Returns: + List of detected feature names + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError() + + def set_strand(self, strand: str) -> None: + """Set the detected strand.""" + self.strand = strand + + def __str__(self) -> str: + """Format result as TSV line.""" + return "%s\t%s\t%s\t%d\t%s\t%s" % (self.read_id, self.barcode, self.UMI, + self.BC_score, self.UMI_good, self.strand) + + @staticmethod + def header() -> str: + """Get TSV header for result output.""" + return "#read_id\tbarcode\tUMI\tBC_score\tvalid_UMI\tstrand" + + +class DoubleBarcodeDetectionResult(BarcodeDetectionResult): + """ + Detection result for platforms with double barcodes (e.g., Curio, Stereo-seq). + + Extends base result with positions of additional features: + polyT tail, primer, and linker sequences. + """ + + def __init__(self, read_id: str, barcode: str = BarcodeDetectionResult.NOSEQ, + UMI: str = BarcodeDetectionResult.NOSEQ, + BC_score: int = -1, UMI_good: bool = False, strand: str = ".", + polyT: int = -1, primer: int = -1, linker_start: int = -1, linker_end: int = -1): + """ + Initialize double barcode detection result. 
+ + Args: + read_id: Read identifier + barcode: Detected barcode (concatenated if split by linker) + UMI: Detected UMI sequence + BC_score: Barcode alignment score + UMI_good: Whether UMI passes quality filters + strand: Detected strand + polyT: Position of polyT tail start (-1 if not found) + primer: Position of primer end (-1 if not found) + linker_start: Position of linker start (-1 if not found) + linker_end: Position of linker end (-1 if not found) + """ + BarcodeDetectionResult.__init__(self, read_id, barcode, UMI, BC_score, UMI_good, strand) + self.primer: int = primer + self.linker_start: int = linker_start + self.linker_end: int = linker_end + self.polyT: int = polyT + + def is_valid(self): + return self.barcode != BarcodeDetectionResult.NOSEQ + + def update_coordinates(self, delta): + self.primer = increase_if_valid(self.primer, delta) + self.linker_start = increase_if_valid(self.linker_start, delta) + self.linker_end = increase_if_valid(self.linker_end, delta) + self.polyT = increase_if_valid(self.polyT, delta) + + def more_informative_than(self, that): + if self.BC_score != that.BC_score: + return self.BC_score > that.BC_score + if self.linker_start != that.linker_start: + return self.linker_start > that.linker_start + if self.primer != that.primer: + return self.primer > that.primer + return self.polyT > that.polyT + + def get_additional_attributes(self): + attr = [] + if self.polyT != -1: + attr.append("PolyT detected") + if self.primer != -1: + attr.append("Primer detected") + if self.linker_start != -1: + attr.append("Linker detected") + return attr + + def set_strand(self, strand): + self.strand = strand + + def __str__(self): + return (BarcodeDetectionResult.__str__(self) + + "\t%d\t%d\t%d\t%d" % (self.polyT, self.primer, self.linker_start, self.linker_end)) + + @staticmethod + def header(): + return BarcodeDetectionResult.header() + "\tpolyT_start\tprimer_end\tlinker_start\tlinker_end" + + +class StereoBarcodeDetectionResult(DoubleBarcodeDetectionResult): + def __init__(self, read_id, barcode=BarcodeDetectionResult.NOSEQ, UMI=BarcodeDetectionResult.NOSEQ, + BC_score=-1, UMI_good=False, strand=".", + polyT=-1, primer=-1, linker_start=-1, linker_end=-1, tso=-1): + DoubleBarcodeDetectionResult.__init__(self, read_id, barcode, UMI, BC_score, UMI_good, strand, + polyT, primer, linker_start, linker_end) + self.tso5 = tso + + def update_coordinates(self, delta): + self.tso5 = increase_if_valid(self.tso5, delta) + DoubleBarcodeDetectionResult.update_coordinates(self, delta) + + def __str__(self): + return (DoubleBarcodeDetectionResult.__str__(self) + + "\t%d" % self.tso5) + + def get_additional_attributes(self): + attr = [] + if self.polyT != -1: + attr.append("PolyT detected") + if self.primer != -1: + attr.append("Primer detected") + if self.linker_start != -1: + attr.append("Linker detected") + if self.tso5 != -1: + attr.append("TSO detected") + return attr + + @staticmethod + def header(): + return DoubleBarcodeDetectionResult.header() + "\tTSO5" + + +class TenXBarcodeDetectionResult(BarcodeDetectionResult): + def __init__(self, read_id, barcode=BarcodeDetectionResult.NOSEQ, UMI=BarcodeDetectionResult.NOSEQ, + BC_score=-1, UMI_good=False, strand=".", + polyT=-1, r1=-1): + BarcodeDetectionResult.__init__(self, read_id, barcode, UMI, BC_score, UMI_good, strand) + self.r1 = r1 + self.polyT = polyT + + def is_valid(self): + return self.barcode != BarcodeDetectionResult.NOSEQ + + def update_coordinates(self, delta): + self.r1 = increase_if_valid(self.r1, delta) + self.polyT = 
increase_if_valid(self.polyT, delta) + + def more_informative_than(self, that): + if self.polyT != that.polyT: + return self.polyT > that.polyT + if self.r1 != that.r1: + return self.r1 > that.r1 + return self.BC_score > that.BC_score + + def get_additional_attributes(self): + attr = [] + if self.polyT != -1: + attr.append("PolyT detected") + if self.r1 != -1: + attr.append("R1 detected") + return attr + + def set_strand(self, strand): + self.strand = strand + + def __str__(self): + return (BarcodeDetectionResult.__str__(self) + + "\t%d\t%d" % (self.polyT, self.r1)) + + @staticmethod + def header(): + return BarcodeDetectionResult.header() + "\tpolyT_start\tR1_end" + + +class SplittingBarcodeDetectionResult: + def __init__(self, read_id): + self.read_id = read_id + self.detected_patterns = [] + + def append(self, barcode_detection_result): + self.detected_patterns.append(barcode_detection_result) + + def empty(self): + return not self.detected_patterns + + def filter(self): + if not self.detected_patterns: return + barcoded_results = [] + for r in self.detected_patterns: + if r.barcode != BarcodeDetectionResult.NOSEQ: + barcoded_results.append(r) + + if not barcoded_results: + self.detected_patterns = [self.detected_patterns[0]] + else: + self.detected_patterns = barcoded_results + + @staticmethod + def header(): + return StereoBarcodeDetectionResult.header() + + +class ReadStats: + """ + Statistics tracker for barcode detection results. + + Accumulates counts of processed reads, detected barcodes, valid UMIs, + and platform-specific features (primers, linkers, polyT tails, etc.). + """ + + def __init__(self): + """Initialize empty statistics.""" + self.read_count: int = 0 + self.bc_count: int = 0 + self.umi_count: int = 0 + self.additional_attributes_counts: Dict[str, int] = defaultdict(int) + + def add_read(self, barcode_detection_result: BarcodeDetectionResult) -> None: + """ + Add a read result to statistics. + + Args: + barcode_detection_result: Detection result to accumulate + """ + self.read_count += 1 + # Count detected features (primer, linker, etc.) + for a in barcode_detection_result.get_additional_attributes(): + self.additional_attributes_counts[a] += 1 + # Count valid barcode + if barcode_detection_result.barcode != BarcodeDetectionResult.NOSEQ: + self.bc_count += 1 + # Count valid UMI + if barcode_detection_result.UMI_good: + self.umi_count += 1 + + def add_custom_stats(self, stat_name: str, val: int) -> None: + """ + Add custom statistic value. 
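+        The value accumulates under stat_name in the same counter table as the
+        automatically detected attributes, so custom statistics appear next to
+        them in the report produced by __str__().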
+ + Args: + stat_name: Name of statistic + val: Count to add + """ + self.additional_attributes_counts[stat_name] += val + + def __str__(self) -> str: + """Format statistics as human-readable string.""" + human_readable_str = ("Total reads\t%d\nBarcode detected\t%d\nReliable UMI\t%d\n" % + (self.read_count, self.bc_count, self.umi_count)) + for a in self.additional_attributes_counts: + human_readable_str += "%s\t%d\n" % (a, self.additional_attributes_counts[a]) + return human_readable_str + + def __iter__(self) -> Iterable[str]: + """Iterate over statistics as formatted strings.""" + yield "Total reads: %d" % self.read_count + yield "Barcode detected: %d" % self.bc_count + yield "Reliable UMI: %d" % self.umi_count + for a in self.additional_attributes_counts: + yield "%s: %d" % (a, self.additional_attributes_counts[a]) + + +class StereoBarcodeDetector: + LINKER = "TTGTCTTCCTAAGAC" + TSO_PRIMER = "ACTGAGAGGCATGGCGACCTTATCAG" + PC1_PRIMER = "CTTCCGATCTATGGCGACCTTATCAG" + BC_LENGTH = 25 + UMI_LEN = 10 + NON_T_UMI_BASES = 0 + UMI_LEN_DELTA = 4 + TERMINAL_MATCH_DELTA = 3 + STRICT_TERMINAL_MATCH_DELTA = 1 + + def __init__(self, barcodes, min_score=21): + self.main_primer = StereoBarcodeDetector.PC1_PRIMER + self.pcr_primer_indexer = ArrayKmerIndexer([self.main_primer], kmer_size=6) + self.linker_indexer = ArrayKmerIndexer([StereoBarcodeDetector.LINKER], kmer_size=5) + self.strict_linker_indexer = ArrayKmerIndexer([StereoBarcodeDetector.LINKER], kmer_size=6) + + self.barcode_indexer = None + if barcodes: + bit_barcodes = map(str_to_2bit, barcodes) + self.barcode_indexer = Array2BitKmerIndexer(bit_barcodes, kmer_size=14, seq_len=self.BC_LENGTH) + logger.info("Indexed %d barcodes" % self.barcode_indexer.total_sequences) + + self.umi_set = None + self.min_score = min_score + + def find_barcode_umi_multiple(self, read_id, sequence): + read_result = [] + r = self._find_barcode_umi_fwd(read_id, sequence) + current_start = 0 + while r.polyT != -1: + r.set_strand("+") + read_result.append(r) + new_start = r.polyT + 50 + current_start += new_start + if len(sequence) - current_start < 50: + break + seq = sequence[current_start:] + new_id = read_id + "_%d" % current_start + r = self._find_barcode_umi_fwd(new_id, seq) + + rev_seq = reverese_complement(sequence) + read_id += "_R" + rr = self._find_barcode_umi_fwd(read_id, rev_seq) + current_start = 0 + while rr.polyT != -1: + rr.set_strand("-") + read_result.append(rr) + new_start = rr.polyT + 50 + current_start += new_start + if len(rev_seq) - current_start < 50: + break + seq = rev_seq[current_start:] + new_id = read_id + "_%d" % current_start + rr = self._find_barcode_umi_fwd(new_id, seq) + + if not read_result: + read_result.append(r) + return read_result + + def find_barcode_umi(self, read_id, sequence): + read_result = self._find_barcode_umi_fwd(read_id, sequence) + if read_result.polyT != -1: + read_result.set_strand("+") + if read_result.is_valid(): + return read_result + + rev_seq = reverese_complement(sequence) + read_rev_result = self._find_barcode_umi_fwd(read_id, rev_seq) + if read_rev_result.polyT != -1: + read_rev_result.set_strand("-") + if read_rev_result.is_valid(): + return read_rev_result + + return read_result if read_result.more_informative_than(read_rev_result) else read_rev_result + + def _find_barcode_umi_fwd(self, read_id, sequence): + polyt_start = find_polyt_start(sequence) + + linker_start, linker_end = None, None + if polyt_start != -1: + # use relaxed parameters is polyA is found + linker_occurrences = 
self.linker_indexer.get_occurrences(sequence[0:polyt_start + 1]) + linker_start, linker_end = detect_exact_positions(sequence, 0, polyt_start + 1, + self.linker_indexer.k, StereoBarcodeDetector.LINKER, + linker_occurrences, min_score=12, + start_delta=self.TERMINAL_MATCH_DELTA, + end_delta=self.TERMINAL_MATCH_DELTA) + + if linker_start is None: + # if polyT was not found, or linker was not found to the left of polyT, look for linker in the entire read + linker_occurrences = self.strict_linker_indexer.get_occurrences(sequence) + linker_start, linker_end = detect_exact_positions(sequence, 0, len(sequence), + self.linker_indexer.k, StereoBarcodeDetector.LINKER, + linker_occurrences, min_score=12, + start_delta=self.STRICT_TERMINAL_MATCH_DELTA, + end_delta=self.STRICT_TERMINAL_MATCH_DELTA) + + if linker_start is None: + return DoubleBarcodeDetectionResult(read_id, polyT=polyt_start) + logger.debug("LINKER: %d-%d" % (linker_start, linker_end)) + + if polyt_start == -1: + # if polyT was not detected earlier, use relaxed parameters once the linker is found + presumable_polyt_start = linker_end + self.UMI_LEN + search_start = presumable_polyt_start - 4 + search_end = min(len(sequence), presumable_polyt_start + 10) + polyt_start = find_polyt_start(sequence[search_start:search_end], window_size=5, polya_fraction=1.0) + if polyt_start != -1: + polyt_start += search_start + + primer_occurrences = self.pcr_primer_indexer.get_occurrences(sequence[:linker_start]) + primer_start, primer_end = detect_exact_positions(sequence, 0, linker_start, + self.pcr_primer_indexer.k, self.main_primer, + primer_occurrences, min_score=12, + end_delta=self.TERMINAL_MATCH_DELTA) + if primer_start is not None: + logger.debug("PRIMER: %d-%d" % (primer_start, primer_end)) + else: + primer_start = -1 + primer_end = linker_start - self.BC_LENGTH - 1 + + if primer_end < 0: + return DoubleBarcodeDetectionResult(read_id, polyT=polyt_start, primer=-1, + linker_start=linker_start, linker_end=linker_end) + + barcode_start = primer_end + 1 + barcode_end = linker_start - 1 + bc_len = barcode_end - barcode_start + if abs(bc_len - self.BC_LENGTH) > 10: + return DoubleBarcodeDetectionResult(read_id, polyT=polyt_start, primer=primer_end, + linker_start=linker_start, linker_end=linker_end) + + potential_barcode = sequence[barcode_start:barcode_end + 1] + logger.debug("Barcode: %s" % (potential_barcode)) + matching_barcodes = self.barcode_indexer.get_occurrences(potential_barcode, max_hits=10, min_kmers=2) + barcode, bc_score, bc_start, bc_end = \ + find_candidate_with_max_score_ssw(matching_barcodes, potential_barcode, + min_score=self.min_score, sufficient_score=self.BC_LENGTH - 1) + + if barcode is None: + return DoubleBarcodeDetectionResult(read_id, polyT=polyt_start, primer=primer_end, + linker_start=linker_start, linker_end=linker_end) + logger.debug("Found: %s %d-%d" % (barcode, bc_start, bc_end)) + + potential_umi_start = linker_end + 1 + potential_umi_end = polyt_start - 1 + umi = None + good_umi = False + if potential_umi_start + 2 * self.UMI_LEN > potential_umi_end > potential_umi_start: + umi = sequence[potential_umi_start:potential_umi_end + 1] + logger.debug("Potential UMI: %s" % umi) + good_umi = abs(len(umi) - self.UMI_LEN) <= self.UMI_LEN_DELTA + + if not umi: + return DoubleBarcodeDetectionResult(read_id, barcode, BC_score=bc_score, + polyT=polyt_start, primer=primer_end, + linker_start=linker_start, linker_end=linker_end) + return DoubleBarcodeDetectionResult(read_id, barcode, umi, bc_score, good_umi, + polyT=polyt_start, 
primer=primer_end, + linker_start=linker_start, linker_end=linker_end) + + @staticmethod + def result_type(): + return DoubleBarcodeDetectionResult + + +class SharedMemoryStereoBarcodeDetector(StereoBarcodeDetector): + MIN_BARCODES_FOR_SHARED_MEM = 1000000 + + def __init__(self, barcodes, min_score=21): + super().__init__([], min_score=min_score) + bit_barcodes = list(map(str_to_2bit, barcodes)) + self.barcode_count = len(bit_barcodes) + self.barcodes = [] + if self.barcode_count < self.MIN_BARCODES_FOR_SHARED_MEM: + # workaround for --test and small barcode sets (which cannot happen in practice) + self.barcode_indexer = KmerIndexer(list(map(lambda x: bit_to_str(x, self.BC_LENGTH), bit_barcodes)), kmer_size=14) + self.barcodes = self.barcode_indexer.seq_list + else: + self.barcode_indexer = SharedMemoryArray2BitKmerIndexer(bit_barcodes, kmer_size=14, + seq_len=super().BC_LENGTH) + + logger.info("Indexed %d barcodes" % self.barcode_count) + + def __getstate__(self): + if self.barcode_count < self.MIN_BARCODES_FOR_SHARED_MEM: + return (self.min_score, + self.barcode_count, + self.barcodes) + else: + return (self.min_score, + self.barcode_count, + self.barcode_indexer.get_sharable_info()) + + def __setstate__(self, state): + self.min_score = state[0] + super().__init__([], min_score=self.min_score) + self.barcodes = [] + self.barcode_count = state[1] + if self.barcode_count < self.MIN_BARCODES_FOR_SHARED_MEM: + self.barcodes = state[2] + self.barcode_indexer = KmerIndexer(self.barcodes, kmer_size=14) + else: + self.barcode_indexer = SharedMemoryArray2BitKmerIndexer.from_sharable_info(state[2]) + + +class StereoSplttingBarcodeDetector: + TSO5 = "CCCGCCTCTCAGTACGTCAGCAG" + LINKER = "TTGTCTTCCTAAGAC" + TSO_PRIMER = "ACTGAGAGGCATGGCGACCTTATCAG" + PC1_PRIMER = "CTTCCGATCTATGGCGACCTTATCAG" + BC_LENGTH = 25 + UMI_LEN = 10 + NON_T_UMI_BASES = 0 + UMI_LEN_DELTA = 3 + TERMINAL_MATCH_DELTA = 3 + STRICT_TERMINAL_MATCH_DELTA = 1 + + def __init__(self, barcodes, min_score=21): + self.main_primer = self.PC1_PRIMER + self.tso5_indexer = ArrayKmerIndexer([self.TSO5], kmer_size=8) + self.pcr_primer_indexer = ArrayKmerIndexer([self.main_primer], kmer_size=6) + self.linker_indexer = ArrayKmerIndexer([self.LINKER], kmer_size=5) + self.strict_linker_indexer = ArrayKmerIndexer([StereoBarcodeDetector.LINKER], kmer_size=7) + + self.barcode_indexer = None + if barcodes: + bit_barcodes = map(str_to_2bit, barcodes) + self.barcode_indexer = Array2BitKmerIndexer(bit_barcodes, kmer_size=14, seq_len=self.BC_LENGTH) + logger.info("Indexed %d barcodes" % self.barcode_indexer.total_sequences) + self.umi_set = None + self.min_score = min_score + + def find_barcode_umi(self, read_id, sequence): + read_result = SplittingBarcodeDetectionResult(read_id) + logger.debug("Looking in forward direction") + r = self._find_barcode_umi_fwd(read_id, sequence) + prev_start = 0 + while r.polyT != -1: + r.set_strand("+") + read_result.append(r) + if r.tso5 != -1: + current_start = r.tso5 + 15 + else: + current_start = r.polyT + 100 + # always make a step + current_start = max(prev_start + 150, current_start) + prev_start = current_start + if len(sequence) - current_start < 50: + break + + logger.debug("Looking further from %d" % current_start) + seq = sequence[current_start:] + r = self._find_barcode_umi_fwd(read_id, seq) + r.update_coordinates(current_start) + + logger.debug("Looking in reverse direction") + rev_seq = reverese_complement(sequence) + r = self._find_barcode_umi_fwd(read_id, rev_seq) + prev_start = 0 + while r.polyT != -1: + 
r.set_strand("-") + read_result.append(r) + if r.tso5 != -1: + current_start = r.tso5 + 15 + else: + current_start = r.polyT + 100 + # always make a step + current_start = max(prev_start + 150, current_start) + prev_start = current_start + if len(rev_seq) - current_start < 50: + break + + logger.debug("Looking further from %d" % current_start) + seq = rev_seq[current_start:] + r = self._find_barcode_umi_fwd(read_id, seq) + r.update_coordinates(current_start) + + if read_result.empty(): + # add empty result anyway + read_result.append(r) + + read_result.filter() + logger.debug("Total barcodes detected %d" % len(read_result.detected_patterns)) + return read_result + + def find_barcode_umi_single(self, read_id, sequence): + read_result = self._find_barcode_umi_fwd(read_id, sequence) + if read_result.polyT != -1: + read_result.set_strand("+") + if read_result.is_valid(): + return read_result + + rev_seq = reverese_complement(sequence) + read_rev_result = self._find_barcode_umi_fwd(read_id, rev_seq) + if read_rev_result.polyT != -1: + read_rev_result.set_strand("-") + if read_rev_result.is_valid(): + return read_rev_result + + return read_result if read_result.more_informative_than(read_rev_result) else read_rev_result + + def _find_barcode_umi_fwd(self, read_id, sequence): + polyt_start = find_polyt_start(sequence) + logger.debug("PolyT found right away %d" % polyt_start) + + linker_start, linker_end = None, None + tso5_start = None + if polyt_start != -1: + # use relaxed parameters is polyA is found + logger.debug("Looking for linker in %d" % len(sequence[0:polyt_start + 1])) + linker_occurrences = self.linker_indexer.get_occurrences(sequence[0:polyt_start + 1]) + linker_start, linker_end = detect_exact_positions(sequence, 0, polyt_start + 1, + self.linker_indexer.k, self.LINKER, + linker_occurrences, min_score=10, + start_delta=self.TERMINAL_MATCH_DELTA, + end_delta=self.TERMINAL_MATCH_DELTA) + + tso5_occurrences = self.tso5_indexer.get_occurrences(sequence[polyt_start + 1:]) + tso5_start, tso5_end = detect_first_exact_positions(sequence, polyt_start + 1, len(sequence), + self.tso5_indexer.k, self.TSO5, + tso5_occurrences, min_score=15, + start_delta=self.TERMINAL_MATCH_DELTA, + end_delta=self.TERMINAL_MATCH_DELTA) + + if linker_start is None: + # if polyT was not found, or linker was not found to the left of polyT, look for linker in the entire read + linker_occurrences = self.strict_linker_indexer.get_occurrences(sequence) + linker_start, linker_end = detect_first_exact_positions(sequence, 0, len(sequence), + self.linker_indexer.k, StereoBarcodeDetector.LINKER, + linker_occurrences, min_score=12, + start_delta=self.STRICT_TERMINAL_MATCH_DELTA, + end_delta=self.STRICT_TERMINAL_MATCH_DELTA) + + if linker_start is None: + return StereoBarcodeDetectionResult(read_id, polyT=polyt_start) + logger.debug("LINKER: %d-%d" % (linker_start, linker_end)) + + if polyt_start == -1 or polyt_start < linker_start: + # if polyT was not detected earlier, use relaxed parameters once the linker is found + presumable_polyt_start = linker_end + self.UMI_LEN + search_start = presumable_polyt_start - 4 + search_end = min(len(sequence), presumable_polyt_start + 10) + polyt_start = find_polyt_start(sequence[search_start:search_end], window_size=5, polya_fraction=1.0) + if polyt_start != -1: + polyt_start += search_start + logger.debug("PolyT found later %d" % polyt_start) + else: + logger.debug("PolyT was not found %d" % polyt_start) + + tso5_occurrences = self.tso5_indexer.get_occurrences(sequence[polyt_start + 1:]) 
+ tso5_start, tso5_end = detect_first_exact_positions(sequence, polyt_start + 1, len(sequence), + self.tso5_indexer.k, self.TSO5, + tso5_occurrences, min_score=15, + start_delta=self.TERMINAL_MATCH_DELTA, + end_delta=self.TERMINAL_MATCH_DELTA) + + if tso5_start: + logger.debug("TSO found %d" % tso5_start) + # check that no another linker is found inbetween polyA and TSO 5' + linker_occurrences = self.strict_linker_indexer.get_occurrences(sequence[polyt_start + 1: tso5_start]) + new_linker_start, new_linker_end = detect_exact_positions(sequence, polyt_start + 1, tso5_start, + self.linker_indexer.k, self.LINKER, + linker_occurrences, min_score=12, + start_delta=self.STRICT_TERMINAL_MATCH_DELTA, + end_delta=self.STRICT_TERMINAL_MATCH_DELTA) + + if new_linker_start is not None and new_linker_start != -1 and new_linker_start - polyt_start > 100: + # another linker found inbetween polyT and TSO + logger.debug("Another linker was found before TSO: %d" % new_linker_start) + tso5_start = new_linker_start - self.BC_LENGTH - len(self.main_primer) - len(self.TSO5) + logger.debug("TSO updated %d" % tso5_start) + else: + tso5_start = -1 + + primer_occurrences = self.pcr_primer_indexer.get_occurrences(sequence[:linker_start]) + primer_start, primer_end = detect_exact_positions(sequence, 0, linker_start, + self.pcr_primer_indexer.k, self.main_primer, + primer_occurrences, min_score=12, + end_delta=self.TERMINAL_MATCH_DELTA) + if primer_start is not None: + logger.debug("PRIMER: %d-%d" % (primer_start, primer_end)) + else: + primer_end = linker_start - self.BC_LENGTH - 1 + + if primer_end < 0: + return StereoBarcodeDetectionResult(read_id, polyT=polyt_start, primer=-1, + linker_start=linker_start, linker_end=linker_end, tso=tso5_start) + + barcode_start = primer_end + 1 + barcode_end = linker_start - 1 + bc_len = barcode_end - barcode_start + if abs(bc_len - self.BC_LENGTH) > 10: + return StereoBarcodeDetectionResult(read_id, polyT=polyt_start, primer=primer_end, + linker_start=linker_start, linker_end=linker_end, tso=tso5_start) + + potential_barcode = sequence[barcode_start:barcode_end + 1] + logger.debug("Barcode: %s" % (potential_barcode)) + if not self.barcode_indexer: + return StereoBarcodeDetectionResult(read_id, potential_barcode, BC_score=0, + polyT=polyt_start, primer=primer_end, + linker_start=linker_start, linker_end=linker_end, tso=tso5_start) + + matching_barcodes = self.barcode_indexer.get_occurrences(potential_barcode, max_hits=10, min_kmers=2) + barcode, bc_score, bc_start, bc_end = \ + find_candidate_with_max_score_ssw(matching_barcodes, potential_barcode, + min_score=self.min_score, sufficient_score=self.BC_LENGTH - 1) + + if barcode is None: + return StereoBarcodeDetectionResult(read_id, polyT=polyt_start, primer=primer_end, + linker_start=linker_start, linker_end=linker_end, tso=tso5_start) + logger.debug("Found: %s %d-%d" % (barcode, bc_start, bc_end)) + + potential_umi_start = linker_end + 1 + potential_umi_end = polyt_start - 1 + umi = None + good_umi = False + if potential_umi_start + 2 * self.UMI_LEN > potential_umi_end > potential_umi_start: + umi = sequence[potential_umi_start:potential_umi_end + 1] + logger.debug("Potential UMI: %s" % umi) + good_umi = abs(len(umi) - self.UMI_LEN) <= self.UMI_LEN_DELTA + + if not umi: + return StereoBarcodeDetectionResult(read_id, barcode, BC_score=bc_score, + polyT=polyt_start, primer=primer_end, + linker_start=linker_start, linker_end=linker_end, tso=tso5_start) + return StereoBarcodeDetectionResult(read_id, barcode, umi, bc_score, good_umi, 
+ polyT=polyt_start, primer=primer_end, + linker_start=linker_start, linker_end=linker_end, tso=tso5_start) + + @staticmethod + def result_type(): + return SplittingBarcodeDetectionResult + + +class SharedMemoryStereoSplttingBarcodeDetector(StereoSplttingBarcodeDetector): + MIN_BARCODES_FOR_SHARED_MEM = 1000000 + + def __init__(self, barcodes, min_score=21): + super().__init__([], min_score=min_score) + bit_barcodes = list(map(str_to_2bit, barcodes)) + self.barcode_count = len(bit_barcodes) + self.barcodes = [] + if self.barcode_count < self.MIN_BARCODES_FOR_SHARED_MEM: + # workaround for --test and small barcode sets (which cannot happen in practice) + self.barcode_indexer = KmerIndexer(list(map(lambda x: bit_to_str(x, self.BC_LENGTH), bit_barcodes)), kmer_size=14) + self.barcodes = self.barcode_indexer.seq_list + else: + self.barcode_indexer = SharedMemoryArray2BitKmerIndexer(bit_barcodes, kmer_size=10, + seq_len=super().BC_LENGTH) + + logger.info("Indexed %d barcodes" % self.barcode_count) + + def __getstate__(self): + if self.barcode_count < self.MIN_BARCODES_FOR_SHARED_MEM: + return (self.min_score, + self.barcode_count, + self.barcodes) + else: + return (self.min_score, + self.barcode_count, + self.barcode_indexer.get_sharable_info()) + + def __setstate__(self, state): + self.min_score = state[0] + super().__init__([], min_score=self.min_score) + self.barcodes = [] + self.barcode_count = state[1] + if self.barcode_count < self.MIN_BARCODES_FOR_SHARED_MEM: + self.barcodes = state[2] + self.barcode_indexer = KmerIndexer(self.barcodes, kmer_size=14) + else: + self.barcode_indexer = SharedMemoryArray2BitKmerIndexer.from_sharable_info(state[2]) + + +class SharedMemoryWrapper: + def __init__(self, barcode_detector_class, barcodes, min_score=21): + self.barcode_detector_class = barcode_detector_class + self.min_score = min_score + self.barcode_detector = self.barcode_detector_class([], self.min_score) + bit_barcodes = list(map(str_to_2bit, barcodes)) + self.barcode_detector.barcode_indexer = SharedMemoryArray2BitKmerIndexer(bit_barcodes, kmer_size=14, + seq_len=self.barcode_detector_class.BC_LENGTH) + + logger.info("Indexed %d barcodes" % self.barcode_detector.barcode_indexer.total_sequences) + + def __getstate__(self): + return (self.barcode_detector_class, + self.min_score, + self.barcode_detector.barcode_indexer.get_sharable_info()) + + def __setstate__(self, state): + self.barcode_detector_class = state[0] + self.min_score = state[1] + self.barcode_detector = self.barcode_detector_class([], self.min_score) + self.barcode_detector.barcode_indexer = SharedMemoryArray2BitKmerIndexer.from_sharable_info(state[2]) + + +class DoubleBarcodeDetector: + LINKER = "TCTTCAGCGTTCCCGAGA" + PCR_PRIMER = "TACACGACGCTCTTCCGATCT" + LEFT_BC_LENGTH = 8 + RIGHT_BC_LENGTH = 6 + BC_LENGTH = LEFT_BC_LENGTH + RIGHT_BC_LENGTH + UMI_LEN = 9 + NON_T_UMI_BASES = 2 + UMI_LEN_DELTA = 2 + TERMINAL_MATCH_DELTA = 2 + STRICT_TERMINAL_MATCH_DELTA = 1 + + def __init__(self, barcode_list, umi_list=None, min_score=13): + self.pcr_primer_indexer = ArrayKmerIndexer([DoubleBarcodeDetector.PCR_PRIMER], kmer_size=6) + self.linker_indexer = ArrayKmerIndexer([DoubleBarcodeDetector.LINKER], kmer_size=5) + joint_barcode_list = [] + if isinstance(barcode_list, tuple): + # barcode provided separately + for b1 in barcode_list[0]: + for b2 in barcode_list[1]: + joint_barcode_list.append(b1 + b2) + else: + joint_barcode_list = barcode_list + self.barcode_indexer = ArrayKmerIndexer(joint_barcode_list, kmer_size=6) + self.umi_set = None + if 
umi_list: + self.umi_set = set(umi_list) + logger.debug("Loaded %d UMIs" % len(umi_list)) + self.umi_indexer = KmerIndexer(umi_list, kmer_size=5) + self.min_score = min_score + + def find_barcode_umi(self, read_id, sequence): + read_result = self._find_barcode_umi_fwd(read_id, sequence) + if read_result.polyT != -1: + read_result.set_strand("+") + if read_result.is_valid(): + return read_result + + rev_seq = reverese_complement(sequence) + read_rev_result = self._find_barcode_umi_fwd(read_id, rev_seq) + if read_rev_result.polyT != -1: + read_rev_result.set_strand("-") + if read_rev_result.is_valid(): + return read_rev_result + + return read_result if read_result.more_informative_than(read_rev_result) else read_rev_result + + def _find_barcode_umi_fwd(self, read_id, sequence): + polyt_start = find_polyt_start(sequence) + + linker_start, linker_end = None, None + if polyt_start != -1: + # use relaxed parameters is polyA is found + linker_occurrences = self.linker_indexer.get_occurrences(sequence[0:polyt_start + 1]) + linker_start, linker_end = detect_exact_positions(sequence, 0, polyt_start + 1, + self.linker_indexer.k, self.LINKER, + linker_occurrences, min_score=11, + start_delta=self.TERMINAL_MATCH_DELTA, + end_delta=self.TERMINAL_MATCH_DELTA) + + if linker_start is None: + # if polyT was not found, or linker was not found to the left of polyT, look for linker in the entire read + linker_occurrences = self.linker_indexer.get_occurrences(sequence) + linker_start, linker_end = detect_exact_positions(sequence, 0, len(sequence), + self.linker_indexer.k, self.LINKER, + linker_occurrences, min_score=14, + start_delta=self.STRICT_TERMINAL_MATCH_DELTA, + end_delta=self.STRICT_TERMINAL_MATCH_DELTA) + + if linker_start is None: + return DoubleBarcodeDetectionResult(read_id, polyT=polyt_start) + logger.debug("LINKER: %d-%d" % (linker_start, linker_end)) + + if polyt_start == -1 or polyt_start - linker_end > self.RIGHT_BC_LENGTH + self.UMI_LEN + 10: + # if polyT was not detected earlier, use relaxed parameters once the linker is found + presumable_polyt_start = linker_end + self.RIGHT_BC_LENGTH + self.UMI_LEN + search_start = presumable_polyt_start - 4 + search_end = min(len(sequence), presumable_polyt_start + 10) + polyt_start = find_polyt_start(sequence[search_start:search_end], window_size=5, polya_fraction=1.0) + if polyt_start != -1: + polyt_start += search_start + + primer_occurrences = self.pcr_primer_indexer.get_occurrences(sequence[:linker_start]) + primer_start, primer_end = detect_exact_positions(sequence, 0, linker_start, + self.pcr_primer_indexer.k, self.PCR_PRIMER, + primer_occurrences, min_score=5, + end_delta=self.TERMINAL_MATCH_DELTA) + if primer_start is not None: + logger.debug("PRIMER: %d-%d" % (primer_start, primer_end)) + else: + primer_start = -1 + primer_end = -1 + + barcode_start = primer_end + 1 if primer_start != -1 else linker_start - self.LEFT_BC_LENGTH + barcode_end = linker_end + self.RIGHT_BC_LENGTH + 1 + potential_barcode = sequence[barcode_start:linker_start] + sequence[linker_end + 1:barcode_end + 1] + logger.debug("Barcode: %s" % (potential_barcode)) + matching_barcodes = self.barcode_indexer.get_occurrences(potential_barcode) + barcode, bc_score, bc_start, bc_end = \ + find_candidate_with_max_score_ssw(matching_barcodes, potential_barcode, min_score=self.min_score) + + if barcode is None: + return DoubleBarcodeDetectionResult(read_id, polyT=polyt_start, primer=primer_end, + linker_start=linker_start, linker_end=linker_end) + logger.debug("Found: %s %d-%d" % 
(barcode, bc_start, bc_end)) + # position of barcode end in the reference: end of potential barcode minus bases to the alignment end + read_barcode_end = barcode_start + bc_end - 1 + (linker_end - linker_start + 1) + + potential_umi_start = read_barcode_end + 1 + potential_umi_end = polyt_start - 1 + if potential_umi_end - potential_umi_start <= 5: + potential_umi_end = potential_umi_start + self.UMI_LEN - 1 + potential_umi = sequence[potential_umi_start:potential_umi_end + 1] + logger.debug("Potential UMI: %s" % potential_umi) + + umi = None + good_umi = False + if self.umi_set: + matching_umis = self.umi_indexer.get_occurrences(potential_umi) + umi, umi_score, umi_start, umi_end = \ + find_candidate_with_max_score_ssw(matching_umis, potential_umi, min_score=7) + logger.debug("Found UMI %s %d-%d" % (umi, umi_start, umi_end)) + + if not umi : + umi = potential_umi + if self.UMI_LEN - self.UMI_LEN_DELTA <= len(umi) <= self.UMI_LEN + self.UMI_LEN_DELTA and \ + all(x != "T" for x in umi[-self.NON_T_UMI_BASES:]): + good_umi = True + + if not umi: + return DoubleBarcodeDetectionResult(read_id, barcode, BC_score=bc_score, + polyT=polyt_start, primer=primer_end, + linker_start=linker_start, linker_end=linker_end) + return DoubleBarcodeDetectionResult(read_id, barcode, umi, bc_score, good_umi, + polyT=polyt_start, primer=primer_end, + linker_start=linker_start, linker_end=linker_end) + + @staticmethod + def result_type(): + return DoubleBarcodeDetectionResult + + +class BruteForceDoubleBarcodeDetector: + LINKER = "TCTTCAGCGTTCCCGAGA" + LEFT_BC_LENGTH = 2 + RIGHT_BC_LENGTH = 12 + BC_LENGTH = LEFT_BC_LENGTH + RIGHT_BC_LENGTH + + def __init__(self, joint_barcode_list): + self.barcode_set = set(joint_barcode_list) + + def find_barcode_umi(self, read_id, sequence): + linker_found, barcode = self._find_barcode_umi_fwd(read_id, sequence) + if linker_found: + return linker_found, barcode + + rev_seq = reverese_complement(sequence) + return self._find_barcode_umi_fwd(read_id, rev_seq) + + def _find_barcode_umi_fwd(self, read_id, sequence): + pos = sequence.find(self.LINKER) + if pos == -1: + return DoubleBarcodeDetectionResult(read_id) + + bc_start = max(0, pos - self.LEFT_BC_LENGTH) + barcode = sequence[bc_start:pos] + linker_end = pos + len(self.LINKER) + bc_end = min(len(sequence), linker_end + 1 + self.RIGHT_BC_LENGTH) + barcode += sequence[linker_end + 1:bc_end] + if len(barcode) != self.BC_LENGTH or barcode not in self.barcode_set: + return DoubleBarcodeDetectionResult(read_id, linker_start=pos) + return DoubleBarcodeDetectionResult(read_id, barcode, BC_score=len(barcode), linker_start=pos) + + @staticmethod + def result_type(): + return DoubleBarcodeDetectionResult + + +class IlluminaDoubleBarcodeDetector: + LINKER = "TCTTCAGCGTTCCCGAGA" + PCR_PRIMER = "TACACGACGCTCTTCCGATCT" + LEFT_BC_LENGTH = 8 + RIGHT_BC_LENGTH = 6 + BC_LENGTH = LEFT_BC_LENGTH + RIGHT_BC_LENGTH + MIN_BC_LEN = BC_LENGTH - 4 + UMI_LEN = 9 + NON_T_UMI_BASES = 2 + UMI_LEN_DELTA = 2 + SCORE_DIFF = 1 + + TERMINAL_MATCH_DELTA = 1 + STRICT_TERMINAL_MATCH_DELTA = 0 + + def __init__(self, joint_barcode_list, umi_list=None, min_score=14): + self.pcr_primer_indexer = KmerIndexer([DoubleBarcodeDetector.PCR_PRIMER], kmer_size=6) + self.linker_indexer = KmerIndexer([DoubleBarcodeDetector.LINKER], kmer_size=5) + self.barcode_indexer = KmerIndexer(joint_barcode_list, kmer_size=5) + self.umi_set = None + if umi_list: + self.umi_set = set(umi_list) + logger.debug("Loaded %d UMIs" % len(umi_list)) + self.umi_indexer = KmerIndexer(umi_list, 
kmer_size=5) + self.min_score = min_score + + def find_barcode_umi(self, read_id, sequence): + read_result = self._find_barcode_umi_fwd(read_id, sequence) + if read_result.polyT != -1: + read_result.set_strand("+") + if read_result.is_valid(): + return read_result + + rev_seq = reverese_complement(sequence) + read_rev_result = self._find_barcode_umi_fwd(read_id, rev_seq) + if read_rev_result.polyT != -1: + read_rev_result.set_strand("-") + if read_rev_result.is_valid(): + return read_rev_result + + return read_result if read_result.more_informative_than(read_rev_result) else read_rev_result + + def _find_barcode_umi_fwd(self, read_id, sequence): + # look for linker in the entire read + linker_occurrences = self.linker_indexer.get_occurrences(sequence) + linker_start, linker_end = detect_exact_positions(sequence, 0, len(sequence), + self.linker_indexer.k, self.LINKER, + linker_occurrences, min_score=14, + start_delta=self.TERMINAL_MATCH_DELTA, + end_delta=self.TERMINAL_MATCH_DELTA) + + if linker_start is None: + return DoubleBarcodeDetectionResult(read_id) + logger.debug("LINKER: %d-%d" % (linker_start, linker_end)) + primer_end = -1 # forget about primer + + # use relaxed parameters once the linker is found + presumable_polyt_start = linker_end + self.RIGHT_BC_LENGTH + self.UMI_LEN + search_start = presumable_polyt_start - 4 + search_end = min(len(sequence), presumable_polyt_start + 10) + polyt_start = find_polyt_start(sequence[search_start:search_end], window_size=5, polya_fraction=1.0) + if polyt_start != -1: + polyt_start += search_start + + barcode_start = linker_start - self.LEFT_BC_LENGTH + if barcode_start < 0: + barcode_start = 0 + barcode_end = linker_end + self.RIGHT_BC_LENGTH + 1 + if barcode_end > len(sequence): + barcode_end = len(sequence) + + potential_barcode = sequence[barcode_start:linker_start] + sequence[linker_end + 1:barcode_end] + logger.debug("Barcode: %s" % (potential_barcode)) + if len(potential_barcode) < self.MIN_BC_LEN: + return DoubleBarcodeDetectionResult(read_id, linker_start=linker_start, linker_end=linker_end) + matching_barcodes = self.barcode_indexer.get_occurrences(potential_barcode) + barcode, bc_score, bc_start, bc_end = \ + find_candidate_with_max_score_ssw(matching_barcodes, potential_barcode, + min_score=len(potential_barcode) - 1, score_diff=self.SCORE_DIFF) + + if barcode is None: + return DoubleBarcodeDetectionResult(read_id, polyT=polyt_start, primer=-1, + linker_start=linker_start, linker_end=linker_end) + logger.debug("Found: %s %d-%d" % (barcode, bc_start, bc_end)) + # position of barcode end in the reference: end of potential barcode minus bases to the alignment end + read_barcode_end = barcode_start + bc_end - 1 + (linker_end - linker_start + 1) + + potential_umi_start = read_barcode_end + 1 + potential_umi_end = polyt_start - 1 + if polyt_start != -1 or potential_umi_end - potential_umi_start <= 5: + potential_umi_end = potential_umi_start + self.UMI_LEN - 1 + potential_umi = sequence[potential_umi_start:min(potential_umi_end + 1, len(sequence))] + logger.debug("Potential UMI: %s" % potential_umi) + + umi = None + good_umi = False + if self.umi_set: + matching_umis = self.umi_indexer.get_occurrences(potential_umi) + umi, umi_score, umi_start, umi_end = \ + find_candidate_with_max_score_ssw(matching_umis, potential_umi, min_score=7) + logger.debug("Found UMI %s %d-%d" % (umi, umi_start, umi_end)) + + if not umi : + umi = potential_umi + if self.UMI_LEN - self.UMI_LEN_DELTA <= len(umi) <= self.UMI_LEN + self.UMI_LEN_DELTA and \ + all(x != 
"T" for x in umi[-self.NON_T_UMI_BASES:]): + good_umi = True + + if not umi: + return DoubleBarcodeDetectionResult(read_id, barcode, BC_score=bc_score, + polyT=polyt_start, primer=primer_end, + linker_start=linker_start, linker_end=linker_end) + return DoubleBarcodeDetectionResult(read_id, barcode, umi, bc_score, good_umi, + polyT=polyt_start, primer=primer_end, + linker_start=linker_start, linker_end=linker_end) + + @staticmethod + def result_type(): + return DoubleBarcodeDetectionResult + + +class TenXBarcodeDetector: + TSO = "CCCATGTACTCTGCGTTGATACCACTGCTT" + # R1 = "ACACTCTTTCCCTACACGACGCTCTTCCGATCT" # + R1 = "CTACACGACGCTCTTCCGATCT" # 10x 3' + BARCODE_LEN_10X = 16 + UMI_LEN = 12 + + UMI_LEN_DELTA = 2 + TERMINAL_MATCH_DELTA = 2 + STRICT_TERMINAL_MATCH_DELTA = 1 + + def __init__(self, barcode_list, umi_list=None): + self.r1_indexer = KmerIndexer([TenXBarcodeDetector.R1], kmer_size=7) + self.barcode_indexer = KmerIndexer(barcode_list, kmer_size=6) + self.umi_set = None + if umi_list: + self.umi_set = set(umi_list) + logger.debug("Loaded %d UMIs" % len(umi_list)) + self.umi_indexer = KmerIndexer(umi_list, kmer_size=5) + self.min_score = 14 + if len(self.barcode_indexer.seq_list) > 100000: + self.min_score = 16 + logger.debug("Min score set to %d" % self.min_score) + + def find_barcode_umi(self, read_id, sequence): + read_result = self._find_barcode_umi_fwd(read_id, sequence) + if read_result.polyT != -1: + read_result.set_strand("+") + if read_result.is_valid(): + return read_result + + rev_seq = reverese_complement(sequence) + read_rev_result = self._find_barcode_umi_fwd(read_id, rev_seq) + if read_rev_result.polyT != -1: + read_rev_result.set_strand("-") + if read_rev_result.is_valid(): + return read_rev_result + + return read_result if read_result.more_informative_than(read_rev_result) else read_rev_result + + def _find_barcode_umi_fwd(self, read_id, sequence): + polyt_start = find_polyt_start(sequence) + + r1_start, r1_end = None, None + if polyt_start != -1: + # use relaxed parameters is polyA is found + r1_occurrences = self.r1_indexer.get_occurrences(sequence[0:polyt_start + 1]) + r1_start, r1_end = detect_exact_positions(sequence, 0, polyt_start + 1, + self.r1_indexer.k, self.R1, + r1_occurrences, min_score=11, + end_delta=self.TERMINAL_MATCH_DELTA) + + if r1_start is None: + # if polyT was not found, or linker was not found to the left of polyT, look for linker in the entire read + r1_occurrences = self.r1_indexer.get_occurrences(sequence) + r1_start, r1_end = detect_exact_positions(sequence, 0, len(sequence), + self.r1_indexer.k, self.R1, + r1_occurrences, min_score=18, + start_delta=self.STRICT_TERMINAL_MATCH_DELTA, + end_delta=self.STRICT_TERMINAL_MATCH_DELTA) + + if r1_start is None: + return TenXBarcodeDetectionResult(read_id, polyT=polyt_start) + logger.debug("LINKER: %d-%d" % (r1_start, r1_end)) + + if polyt_start == -1 or polyt_start - r1_end > self.BARCODE_LEN_10X + self.UMI_LEN + 10: + # if polyT was not detected earlier, use relaxed parameters once the linker is found + presumable_polyt_start = r1_end + self.BARCODE_LEN_10X + self.UMI_LEN + search_start = presumable_polyt_start - 4 + search_end = min(len(sequence), presumable_polyt_start + 10) + polyt_start = find_polyt_start(sequence[search_start:search_end], window_size=5, polya_fraction=1.0) + if polyt_start != -1: + polyt_start += search_start + + barcode_start = r1_end + 1 + barcode_end = r1_end + self.BARCODE_LEN_10X + 1 + potential_barcode = sequence[barcode_start:barcode_end + 1] + logger.debug("Barcode: %s" % 
(potential_barcode)) + matching_barcodes = self.barcode_indexer.get_occurrences(potential_barcode) + barcode, bc_score, bc_start, bc_end = \ + find_candidate_with_max_score_ssw(matching_barcodes, potential_barcode, min_score=self.min_score) + + if barcode is None: + return TenXBarcodeDetectionResult(read_id, polyT=polyt_start, r1=r1_end) + logger.debug("Found: %s %d-%d" % (barcode, bc_start, bc_end)) + # position of barcode end in the reference: end of potential barcode minus bases to the alignment end + read_barcode_end = barcode_start + bc_end - 1 + potential_umi_start = read_barcode_end + 1 + potential_umi_end = polyt_start - 1 + if potential_umi_end - potential_umi_start <= 5: + potential_umi_end = potential_umi_start + self.UMI_LEN - 1 + potential_umi = sequence[potential_umi_start:potential_umi_end + 1] + logger.debug("Potential UMI: %s" % potential_umi) + + umi = None + good_umi = False + if self.umi_set: + matching_umis = self.umi_indexer.get_occurrences(potential_umi) + umi, umi_score, umi_start, umi_end = \ + find_candidate_with_max_score_ssw(matching_umis, potential_umi, min_score=7) + logger.debug("Found UMI %s %d-%d" % (umi, umi_start, umi_end)) + + if not umi : + umi = potential_umi + if self.UMI_LEN - self.UMI_LEN_DELTA <= len(umi) <= self.UMI_LEN + self.UMI_LEN_DELTA: + good_umi = True + + if not umi: + return TenXBarcodeDetectionResult(read_id, barcode, BC_score=bc_score, polyT=polyt_start, r1=r1_end) + return TenXBarcodeDetectionResult(read_id, barcode, umi, bc_score, good_umi, polyT=polyt_start, r1=r1_end) + + def find_barcode_umi_no_polya(self, read_id, sequence): + read_result = self._find_barcode_umi_fwd(read_id, sequence) + if read_result.polyT != -1: + read_result.set_strand("+") + if read_result.is_valid(): + return read_result + + rev_seq = reverese_complement(sequence) + read_rev_result = self._find_barcode_umi_fwd(read_id, rev_seq) + if read_rev_result.polyT != -1: + read_rev_result.set_strand("-") + if read_rev_result.is_valid(): + return read_rev_result + + return read_result if read_result.more_informative_than(read_rev_result) else read_rev_result + + @staticmethod + def result_type(): + return TenXBarcodeDetectionResult + + +class VisiumHDBarcodeDetector: + R1 = "CTACACGACGCTCTTCCGATCT" # 10x 3' + BARCODE1_LEN_VIS = 16 + BARCODE2_LEN_VIS = 15 + SEPARATOR_BASES = 2 + TOTAL_BARCODE_LEN_VIS = BARCODE1_LEN_VIS + BARCODE2_LEN_VIS + UMI_LEN_VIS=9 + + UMI_LEN_DELTA = 2 + TERMINAL_MATCH_DELTA = 2 + STRICT_TERMINAL_MATCH_DELTA = 1 + + def __init__(self, barcode_pair_list): + assert len(barcode_pair_list) == 2 + self.r1_indexer = KmerIndexer([VisiumHDBarcodeDetector.R1], kmer_size=7) + self.part1_list = barcode_pair_list[0] + self.part2_list = barcode_pair_list[1] + self.part1_barcode_indexer = KmerIndexer( self.part1_list, kmer_size=7) + self.part2_barcode_indexer = KmerIndexer(self.part2_list, kmer_size=7) + self.umi_set = None + self.min_score = 13 + logger.debug("Min score set to %d" % self.min_score) + + def find_barcode_umi(self, read_id, sequence): + read_result = self._find_barcode_umi_fwd(read_id, sequence) + if read_result.polyT != -1: + read_result.set_strand("+") + if read_result.is_valid(): + return read_result + + rev_seq = reverese_complement(sequence) + read_rev_result = self._find_barcode_umi_fwd(read_id, rev_seq) + if read_rev_result.polyT != -1: + read_rev_result.set_strand("-") + if read_rev_result.is_valid(): + return read_rev_result + + return read_result if read_result.more_informative_than(read_rev_result) else read_rev_result + + def 
_find_barcode_umi_fwd(self, read_id, sequence): + logger.debug("===== " + read_id) + polyt_start = find_polyt_start(sequence) + + r1_start, r1_end = None, None + if polyt_start != -1: + # use relaxed parameters is polyA is found + r1_occurrences = self.r1_indexer.get_occurrences(sequence[0:polyt_start + 1]) + r1_start, r1_end = detect_exact_positions(sequence, 0, polyt_start + 1, + self.r1_indexer.k, self.R1, + r1_occurrences, min_score=10, + end_delta=self.TERMINAL_MATCH_DELTA) + + if r1_start is None: + # if polyT was not found, or linker was not found to the left of polyT, look for linker in the entire read + r1_occurrences = self.r1_indexer.get_occurrences(sequence) + r1_start, r1_end = detect_exact_positions(sequence, 0, len(sequence), + self.r1_indexer.k, self.R1, + r1_occurrences, min_score=17, + start_delta=self.STRICT_TERMINAL_MATCH_DELTA, + end_delta=self.STRICT_TERMINAL_MATCH_DELTA) + + if r1_start is not None: + # return TenXBarcodeDetectionResult(read_id, polyT=polyt_start) + logger.debug("PRIMER: %d-%d" % (r1_start, r1_end)) + + if r1_end is not None and (polyt_start == -1 or polyt_start - r1_end > self.TOTAL_BARCODE_LEN_VIS + self.UMI_LEN_VIS + 10): + # if polyT was not detected earlier, use relaxed parameters once the linker is found + presumable_polyt_start = r1_end + self.TOTAL_BARCODE_LEN_VIS + self.UMI_LEN_VIS + search_start = presumable_polyt_start - 4 + search_end = min(len(sequence), presumable_polyt_start + 10) + polyt_start = find_polyt_start(sequence[search_start:search_end], window_size=5, polya_fraction=1.0) + if polyt_start != -1: + polyt_start += search_start + + if polyt_start == -1: + if r1_start is None: + return TenXBarcodeDetectionResult(read_id, polyT=polyt_start) + # no polyT, start from the left + potential_umi_start = r1_end + 1 + potential_umi_end = potential_umi_start + self.UMI_LEN_VIS - 1 + potential_umi = sequence[potential_umi_start:potential_umi_end + 1] + logger.debug("Potential UMI: %s" % potential_umi) + + barcode1_start = r1_end + self.UMI_LEN_VIS + 1 + barcode1_end = barcode1_start + self.BARCODE1_LEN_VIS - 1 + potential_barcode1 = sequence[barcode1_start:barcode1_end + 1] + matching_barcodes1 = self.part1_barcode_indexer.get_occurrences(potential_barcode1) + barcode1, bc1_score, bc1_start, bc1_end = \ + find_candidate_with_max_score_ssw_var_len(matching_barcodes1, potential_barcode1, min_score=self.min_score) + logger.debug("Barcode 1: %s, %s" % (potential_barcode1, barcode1)) + real_bc1_end = barcode1_start + bc1_end + + barcode2_start = real_bc1_end + 1 + barcode2_end = barcode2_start + self.BARCODE2_LEN_VIS - 1 + potential_barcode2 = sequence[barcode2_start:barcode2_end + 1] + matching_barcodes2 = self.part2_barcode_indexer.get_occurrences(potential_barcode2) + barcode2, bc2_score, bc2_start, bc2_end = \ + find_candidate_with_max_score_ssw_var_len(matching_barcodes2, potential_barcode2, min_score=self.min_score) + logger.debug("Barcode 2: %s, %s" % (potential_barcode2, barcode2)) + + if barcode1 is None or barcode2 is None: + return TenXBarcodeDetectionResult(read_id, polyT=polyt_start, r1=r1_end) + + return TenXBarcodeDetectionResult(read_id, barcode1 + "|" + barcode2, potential_umi, bc1_score+bc2_score, + UMI_good=True, polyT=polyt_start, r1=r1_end) + + barcode2_end = polyt_start - 1 - self.SEPARATOR_BASES + barcode2_start = barcode2_end - self.BARCODE2_LEN_VIS + 1 + potential_barcode2 = sequence[barcode2_start:barcode2_end + 1] + matching_barcodes2 = self.part2_barcode_indexer.get_occurrences(potential_barcode2) + barcode2, 
bc2_score, bc2_start, bc2_end = \ + find_candidate_with_max_score_ssw_var_len(matching_barcodes2, potential_barcode2, min_score=self.min_score) + logger.debug("Barcode 2: %s, %s" % (potential_barcode2, barcode2)) + + real_bc2_start = barcode2_start + bc2_start + barcode1_end = real_bc2_start - 1 + barcode1_start = barcode1_end - self.BARCODE1_LEN_VIS + 1 + potential_barcode1 = sequence[barcode1_start:barcode1_end + 1] + matching_barcodes1 = self.part1_barcode_indexer.get_occurrences(potential_barcode1) + barcode1, bc1_score, bc1_start, bc1_end = \ + find_candidate_with_max_score_ssw_var_len(matching_barcodes1, potential_barcode1, min_score=self.min_score) + logger.debug("Barcode 1: %s, %s" % (potential_barcode1, barcode1)) + real_bc1_start = barcode1_start + bc1_start + + potential_umi_end = real_bc1_start - 1 + if r1_end is not None: + potential_umi_start = r1_end + 1 + else: + potential_umi_start = max(0, potential_umi_end - self.UMI_LEN_VIS) + umi_good = abs(potential_umi_end - potential_umi_start + 1 - self.UMI_LEN_VIS) <= self.UMI_LEN_DELTA + potential_umi = sequence[potential_umi_start:potential_umi_end + 1] + logger.debug("Potential UMI: %s" % potential_umi) + + if barcode1 is None or barcode2 is None: + return TenXBarcodeDetectionResult(read_id, polyT=polyt_start, r1=r1_end if r1_end is not None else -1) + + return TenXBarcodeDetectionResult(read_id, barcode1 + "|" + barcode2, potential_umi, bc1_score + bc2_score, + UMI_good=umi_good, polyT=polyt_start, r1=r1_end if r1_end is not None else -1) + + + def find_barcode_umi_no_polya(self, read_id, sequence): + read_result = self._find_barcode_umi_fwd(read_id, sequence) + if read_result.polyT != -1: + read_result.set_strand("+") + if read_result.is_valid(): + return read_result + + rev_seq = reverese_complement(sequence) + read_rev_result = self._find_barcode_umi_fwd(read_id, rev_seq) + if read_rev_result.polyT != -1: + read_rev_result.set_strand("-") + if read_rev_result.is_valid(): + return read_rev_result + + return read_result if read_result.more_informative_than(read_rev_result) else read_rev_result + + @staticmethod + def result_type(): + return TenXBarcodeDetectionResult diff --git a/src/barcode_calling/common.py b/src/barcode_calling/common.py new file mode 100644 index 00000000..225928c8 --- /dev/null +++ b/src/barcode_calling/common.py @@ -0,0 +1,217 @@ +############################################################################ +# Copyright (c) 2023 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details. 
+############################################################################ +import copy + +from ssw import AlignmentMgr + + +def find_polyt_start(seq, window_size = 16, polya_fraction = 0.75): + polyA_count = int(window_size * polya_fraction) + + if len(seq) < window_size: + return -1 + i = 0 + a_count = seq[0:window_size].count('T') + while i < len(seq) - window_size: + if a_count >= polyA_count: + break + first_base_a = seq[i] == 'T' + new_base_a = i + window_size < len(seq) and seq[i + window_size] == 'T' + if first_base_a and not new_base_a: + a_count -= 1 + elif not first_base_a and new_base_a: + a_count += 1 + i += 1 + + if i >= len(seq) - window_size: + return -1 + + return i + max(0, seq[i:].find('TTTT')) + + +base_comp = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A', 'N': 'N', " ": " "} + + +def reverese_complement(my_seq): ## obtain reverse complement of a sequence + lms = list(map(lambda x: base_comp[x], my_seq))[::-1] + return ''.join(lms) + + +def align_pattern_ssw(sequence, start, end, pattern, min_score=0): + seq = sequence[start:end] + align_mgr = AlignmentMgr(match_score=1, mismatch_penalty=1) + align_mgr.set_read(pattern) + align_mgr.set_reference(seq) + alignment = align_mgr.align(gap_open=1, gap_extension=1) + if alignment.optimal_score < min_score: + return None, None, None, None, None + return start + alignment.reference_start, start + alignment.reference_end, \ + alignment.read_start, alignment.read_end, alignment.optimal_score + + +def find_candidate_with_max_score_ssw(barcode_matches: list, read_sequence, min_score=10, score_diff=0, sufficient_score=0): + best_match = [0, 0, 0] + best_barcode = None + second_best_score = 0 + + align_mgr = AlignmentMgr(match_score=1, mismatch_penalty=1) + align_mgr.set_reference(read_sequence) + for barcode_match in barcode_matches: + barcode = barcode_match[0] + align_mgr.set_read(barcode) + alignment = align_mgr.align(gap_open=1, gap_extension=1) + if alignment.optimal_score < min_score: + continue + + if alignment.optimal_score > best_match[0]: + best_barcode = barcode + second_best_score = best_match[0] + best_match[0] = alignment.optimal_score + best_match[1] = alignment.reference_start - alignment.read_start + best_match[2] = alignment.reference_end + (len(barcode) - alignment.read_end) + elif alignment.optimal_score == best_match[0] and alignment.reference_start < best_match[1]: + best_barcode = barcode + second_best_score = best_match[0] + best_match[1] = alignment.reference_start - alignment.read_start + best_match[2] = alignment.reference_end + (len(barcode) - alignment.read_end) + + if alignment.optimal_score > sufficient_score > 0: + # dirty hack to select first "sufficiently good" alignment + break + + if best_match[0] - second_best_score < score_diff: + return None, 0, 0, 0 + + return best_barcode, best_match[0], best_match[1], best_match[2] + + +def find_candidate_with_max_score_ssw_var_len(barcode_matches: list, read_sequence, min_score=14, score_diff=1): + best_match = [0, 0, 0, 0] + second_best_match = [0, 0, 0, 0] + best_barcode = None + + align_mgr = AlignmentMgr(match_score=1, mismatch_penalty=1) + align_mgr.set_reference(read_sequence) + for barcode_match in barcode_matches: + barcode = barcode_match[0] + align_mgr.set_read(barcode) + alignment = align_mgr.align(gap_open=1, gap_extension=1) + if alignment.optimal_score < min_score: + continue + + ed = len(barcode) - alignment.optimal_score + if alignment.optimal_score > best_match[0]: + second_best_match = copy.copy(best_match) + best_barcode = barcode + 
best_match[0] = alignment.optimal_score + best_match[1] = alignment.reference_start - alignment.read_start + best_match[2] = alignment.reference_end + (len(barcode) - alignment.read_end) + best_match[3] = ed + elif alignment.optimal_score == best_match[0] and ed <= best_match[3]: + second_best_match = copy.copy(best_match) + best_barcode = barcode + best_match[1] = alignment.reference_start - alignment.read_start + best_match[2] = alignment.reference_end + (len(barcode) - alignment.read_end) + best_match[3] = ed + + if best_barcode and best_match[0] < len(best_barcode) and best_match[0] - second_best_match[0] < score_diff: + return None, 0, 0, 0 + + return best_barcode, best_match[0], best_match[1], best_match[2] + + +def detect_exact_positions(sequence, start, end, kmer_size, pattern, pattern_occurrences: list, + min_score=0, start_delta=-1, end_delta=-1): + pattern_index = None + for i, p in enumerate(pattern_occurrences): + if p[0] == pattern: + pattern_index = i + break + if pattern_index is None: + return None, None + + start_pos, end_pos, pattern_start, pattern_end, score = None, None, None, None, 0 + last_potential_pos = -2*len(pattern) + for match_position in pattern_occurrences[pattern_index][2]: + if match_position - last_potential_pos < len(pattern): + continue + + potential_start = start + match_position - len(pattern) + kmer_size + potential_start = max(start, potential_start) + potential_end = start + match_position + len(pattern) + 1 + potential_end = min(end, potential_end) + alignment = \ + align_pattern_ssw(sequence, potential_start, potential_end, pattern, min_score) + if alignment[4] is not None and alignment[4] > score: + start_pos, end_pos, pattern_start, pattern_end, score = alignment + + if start_pos is None: + return None, None + + if start_delta >= 0 and pattern_start > start_delta: + return None, None + if end_delta >= 0 and len(pattern) - pattern_end - 1 > end_delta: + return None, None + leftover_bases = len(pattern) - pattern_end - 1 + skipped_bases = pattern_start + return start_pos - skipped_bases, end_pos + leftover_bases + + +def detect_first_exact_positions(sequence, start, end, kmer_size, pattern, pattern_occurrences: list, + min_score=0, start_delta=-1, end_delta=-1): + pattern_index = None + for i, p in enumerate(pattern_occurrences): + if p[0] == pattern: + pattern_index = i + break + if pattern_index is None: + return None, None + + start_pos, end_pos, pattern_start, pattern_end, score = None, None, None, None, 0 + last_potential_pos = -2*len(pattern) + for match_position in pattern_occurrences[pattern_index][2]: + if match_position - last_potential_pos < len(pattern): + continue + + potential_start = start + match_position - len(pattern) + kmer_size + potential_start = max(start, potential_start) + potential_end = start + match_position + len(pattern) + 1 + potential_end = min(end, potential_end) + alignment = \ + align_pattern_ssw(sequence, potential_start, potential_end, pattern, min_score) + if alignment[4] is not None: + start_pos, end_pos, pattern_start, pattern_end, score = alignment + break + + if start_pos is None: + return None, None + + if start_delta >= 0 and pattern_start > start_delta: + return None, None + if end_delta >= 0 and len(pattern) - pattern_end - 1 > end_delta: + return None, None + leftover_bases = len(pattern) - pattern_end - 1 + skipped_bases = pattern_start + return start_pos - skipped_bases, end_pos + leftover_bases + + +NUCL2BIN = {'A': 0, 'C': 1, 'G': 3, 'T': 2, 'a': 0, 'c': 1, 'g': 3, 't': 2} +BIN2NUCL = ["A", "C", 
"T", "G"] + + +def str_to_2bit(seq): + kmer_idx = 0 + seq_len = len(seq) + for i in range(seq_len): + kmer_idx |= ((ord(seq[i]) & 6) >> 1) << ((seq_len - i - 1) * 2) + return kmer_idx + + +def bit_to_str(seq, seq_len): + str_seq = "" + for i in range(seq_len): + str_seq += BIN2NUCL[(seq >> ((seq_len - i - 1) * 2)) & 3] + return str_seq diff --git a/src/barcode_calling/kmer_indexer.py b/src/barcode_calling/kmer_indexer.py new file mode 100644 index 00000000..0b1c8af8 --- /dev/null +++ b/src/barcode_calling/kmer_indexer.py @@ -0,0 +1,342 @@ +############################################################################ +# Copyright (c) 2023 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + +""" +K-mer indexers for fast approximate string matching. + +Used for barcode calling in single-cell data. Indexes known barcodes/UMIs +and allows fast lookup of similar sequences based on shared k-mers. +""" + +import math +from typing import List, Tuple, Iterable, Dict, DefaultDict +from collections import defaultdict +from .common import bit_to_str, str_to_2bit + + +class KmerIndexer: + """ + Basic k-mer indexer using dictionary-based storage. + + Indexes sequences by their k-mers for fast approximate matching. + Best for small to medium barcode sets. + """ + + def __init__(self, known_strings: Iterable[str], kmer_size: int = 6): + """ + Initialize k-mer index. + + Args: + known_strings: Collection of reference sequences (barcodes/UMIs) + kmer_size: Length of k-mers to use for indexing + """ + self.seq_list: List[str] = list(known_strings) + self.k: int = kmer_size + self.index: DefaultDict[str, List[int]] = defaultdict(list) + self._index() + + def _get_kmers(self, seq: str) -> Iterable[str]: + """ + Generate all k-mers from a sequence using sliding window. + + Args: + seq: Input sequence + + Yields: + K-mer strings + """ + if len(seq) < self.k: + return + kmer = seq[:self.k] + yield kmer + # Slide window by removing first char and adding next + for i in range(self.k, len(seq)): + kmer = kmer[1:] + seq[i] + yield kmer + + def _index(self) -> None: + """Build k-mer index from all sequences.""" + for i, barcode in enumerate(self.seq_list): + for kmer in self._get_kmers(barcode): + self.index[kmer].append(i) + + def append(self, barcode: str) -> None: + """ + Add a new barcode to the index. + + Args: + barcode: Sequence to add to index + """ + self.seq_list.append(barcode) + index = len(self.seq_list) - 1 + for kmer in self._get_kmers(barcode): + self.index[kmer].append(index) + + def empty(self) -> bool: + """Check if index is empty.""" + return len(self.seq_list) == 0 + + def get_occurrences(self, sequence: str, max_hits: int = 0, min_kmers: int = 1, + hits_delta: int = 1, ignore_equal: bool = False) -> List[Tuple[str, int, List[int]]]: + """ + Find indexed sequences with shared k-mers. 
+ + Args: + sequence: Query sequence to search + max_hits: Maximum number of results (0 = unlimited) + min_kmers: Minimum shared k-mers required + hits_delta: Include results within this many k-mers of top hit + ignore_equal: Skip exact matches + + Returns: + List of (sequence, shared_kmer_count, kmer_positions) tuples, + sorted by shared k-mer count (descending) + """ + # Count shared k-mers for each indexed sequence + barcode_counts: DefaultDict[int, int] = defaultdict(int) + barcode_positions: DefaultDict[int, List[int]] = defaultdict(list) + + for pos, kmer in enumerate(self._get_kmers(sequence)): + for i in self.index[kmer]: + barcode_counts[i] += 1 + barcode_positions[i].append(pos) + + # Filter and build results + result = [] + for i in barcode_counts.keys(): + count = barcode_counts[i] + if count < min_kmers: + continue + if ignore_equal and self.seq_list[i] == sequence: + continue + result.append((self.seq_list[i], count, barcode_positions[i])) + + if not result: + return [] + + # Keep only top hits within hits_delta + top_hits = max(result, key=lambda x: x[1])[1] + result = filter(lambda x: x[1] >= top_hits - hits_delta, result) + result = list(sorted(result, reverse=True, key=lambda x: x[1])) + + if max_hits == 0: + return result + return result[:max_hits] + + +class ArrayKmerIndexer: + """ + Optimized k-mer indexer using array-based storage and binary encoding. + + Converts k-mers to integers (2 bits per nucleotide) for faster lookup. + Memory usage: O(4^k) array entries. Best for k <= 8. + """ + + # Nucleotide to 2-bit encoding (A=00, C=01, G=10, T=11) + NUCL2BIN: Dict[str, int] = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'a': 0, 'c': 1, 'g': 2, 't': 3} + + def __init__(self, known_strings: Iterable[str], kmer_size: int = 6): + """ + Initialize array-based k-mer index. + + Args: + known_strings: Collection of reference sequences + kmer_size: Length of k-mers (recommended <= 8) + """ + self.seq_list: List[str] = list(known_strings) + self.k: int = kmer_size + total_kmers = int(math.pow(4, kmer_size)) # 4^k possible k-mers + self.index: List[List[int]] = [[] for _ in range(total_kmers)] + self.mask: int = (1 << (2 * self.k)) - 1 # Bit mask for k-mer + self._index() + + def _get_kmer_indexes(self, seq: str) -> Iterable[int]: + """ + Generate binary-encoded k-mer indices using sliding window. + + Converts k-mers to integers (2 bits per nucleotide) for fast lookup. 
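+
+        Worked toy example (k=3): for "ACGT", the first window "ACG" encodes to
+        0b000110 (= 6); sliding to "CGT" gives 0b011011 (= 27).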
+ + Args: + seq: Input sequence + + Yields: + Integer k-mer indices + """ + if len(seq) < self.k: + return + + # Initialize first k-mer + kmer_idx = 0 + for i in range(self.k): + # Encode nucleotide as 2 bits and shift into position + kmer_idx |= ArrayKmerIndexer.NUCL2BIN[seq[i]] << ((self.k - i - 1) * 2) + yield kmer_idx + + # Slide window: shift left 2 bits, add new nucleotide + for i in range(self.k, len(seq)): + kmer_idx <<= 2 # Shift left to drop leftmost nucleotide + kmer_idx &= self.mask # Keep only k nucleotides + kmer_idx |= ArrayKmerIndexer.NUCL2BIN[seq[i]] # Add rightmost nucleotide + yield kmer_idx + + def _index(self) -> None: + """Build k-mer index from all sequences.""" + for i, barcode in enumerate(self.seq_list): + for kmer_idx in self._get_kmer_indexes(barcode): + self.index[kmer_idx].append(i) + + def append(self, barcode: str) -> None: + """Add a new barcode to the index.""" + self.seq_list.append(barcode) + index = len(self.seq_list) - 1 + for kmer_idx in self._get_kmer_indexes(barcode): + self.index[kmer_idx].append(index) + + def empty(self) -> bool: + """Check if index is empty.""" + return len(self.seq_list) == 0 + + def get_occurrences(self, sequence: str, max_hits: int = 0, min_kmers: int = 1, + hits_delta: int = 1, ignore_equal: bool = False) -> List[Tuple[str, int, List[int]]]: + """Find indexed sequences with shared k-mers (same as KmerIndexer.get_occurrences).""" + barcode_counts: DefaultDict[int, int] = defaultdict(int) + barcode_positions: DefaultDict[int, List[int]] = defaultdict(list) + + for pos, kmer_idx in enumerate(self._get_kmer_indexes(sequence)): + for i in self.index[kmer_idx]: + barcode_counts[i] += 1 + barcode_positions[i].append(pos) + + result = [] + for i in barcode_counts.keys(): + count = barcode_counts[i] + if count < min_kmers: + continue + if ignore_equal and self.seq_list[i] == sequence: + continue + result.append((self.seq_list[i], count, barcode_positions[i])) + + if not result: + return [] + + top_hits = max(result, key=lambda x: x[1])[1] + result = filter(lambda x: x[1] >= top_hits - hits_delta, result) + result = list(sorted(result, reverse=True, key=lambda x: x[1])) + + if max_hits == 0: + return result + return result[:max_hits] + + +class Array2BitKmerIndexer: + """ + Memory-efficient k-mer indexer using 2-bit encoding for both k-mers and sequences. + + Stores sequences as integers (2 bits per nucleotide) to minimize memory. + Uses flat array with range indexing for better cache performance. + Best for large barcode sets (e.g., single-cell whitelists). + """ + + def __init__(self, known_bin_seq: Iterable[int], kmer_size: int = 12, seq_len: int = 25): + """ + Initialize 2-bit k-mer index. + + Args: + known_bin_seq: Pre-encoded sequences as integers + kmer_size: Length of k-mers + seq_len: Length of sequences (all must be same length) + """ + self.k: int = kmer_size + total_kmers = int(math.pow(4, kmer_size)) + tmp_index: List[List[int]] = [[] for _ in range(total_kmers)] + self.mask: int = (1 << (2 * self.k)) - 1 # K-mer bit mask + self.seq_len: int = seq_len + self.seq_mask: int = (1 << (2 * self.seq_len)) - 1 # Sequence bit mask + self.total_sequences: int = 0 + self._index(known_bin_seq, tmp_index) + + # Flatten index for better cache performance + self.index: List[int] = [] + self.index_ranges: List[int] = [0] + for l in tmp_index: + self.index += l + self.index_ranges.append(len(self.index)) + + + def _get_kmer_bin_indexes(self, bin_seq: int) -> Iterable[int]: + """ + Extract k-mer indices from a 2-bit encoded sequence. 
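+
+        Worked toy example (seq_len=4, k=3): "ACGT" encoded with str_to_2bit is
+        0b00011110 (= 30; that helper uses the A=0, C=1, T=2, G=3 bit scheme), and
+        the extracted k-mer indices are 0b000111 (= 7, "ACG") and 0b011110 (= 30, "CGT").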
+ + Args: + bin_seq: Sequence encoded as integer (2 bits per nucleotide) + + Yields: + Integer k-mer indices + """ + for i in range(self.seq_len - self.k + 1): + # Extract k-mer by shifting and masking + yield (bin_seq >> ((self.seq_len - self.k - i) * 2)) & self.mask + + def _index(self, known_bin_seq: Iterable[int], tmp_index: List[List[int]]) -> None: + """Build k-mer index from 2-bit encoded sequences.""" + for bin_seq in known_bin_seq: + self.total_sequences += 1 + for kmer_idx in self._get_kmer_bin_indexes(bin_seq): + tmp_index[kmer_idx].append(bin_seq) + + def empty(self) -> bool: + """Check if index is empty.""" + return self.total_sequences == 0 + + def get_occurrences(self, sequence: str, max_hits: int = 0, min_kmers: int = 1, + hits_delta: int = 1, ignore_equal: bool = False) -> List[Tuple[str, int, List[int]]]: + """ + Find indexed sequences with shared k-mers. + + Args: + sequence: Query sequence (string, will be converted to 2-bit) + max_hits: Maximum number of results (0 = unlimited) + min_kmers: Minimum shared k-mers required + hits_delta: Include results within this many k-mers of top hit + ignore_equal: Skip exact matches + + Returns: + List of (sequence_str, shared_kmer_count, kmer_positions) tuples + """ + barcode_counts: DefaultDict[int, int] = defaultdict(int) + barcode_positions: DefaultDict[int, List[int]] = defaultdict(list) + + seq = str_to_2bit(sequence) + for pos, kmer_idx in enumerate(self._get_kmer_bin_indexes(seq)): + # Use flat index with ranges for cache-friendly access + start_index = self.index_ranges[kmer_idx] + end_index = self.index_ranges[kmer_idx + 1] + for barcode_index in range(start_index, end_index): + barcode = self.index[barcode_index] + barcode_counts[barcode] += 1 + barcode_positions[barcode].append(pos) + + result = [] + for barcode in barcode_counts.keys(): + count = barcode_counts[barcode] + if count < min_kmers: + continue + if ignore_equal and barcode == seq: + continue + result.append((barcode, count, barcode_positions[barcode])) + + if not result: + return [] + + # Filter top hits + top_hits = max(result, key=lambda x: x[1])[1] + result = filter(lambda x: x[1] >= top_hits - hits_delta, result) + result = sorted(result, reverse=True, key=lambda x: x[1]) + + # Convert 2-bit sequences back to strings + if max_hits == 0: + return [(bit_to_str(x[0], self.seq_len), x[1], x[2]) for x in result] + return [(bit_to_str(x[0], self.seq_len), x[1], x[2]) for x in list(result)[:max_hits]] \ No newline at end of file diff --git a/src/barcode_calling/shared_mem_index.py b/src/barcode_calling/shared_mem_index.py new file mode 100644 index 00000000..d9e0890a --- /dev/null +++ b/src/barcode_calling/shared_mem_index.py @@ -0,0 +1,164 @@ +############################################################################ +# Copyright (c) 2023 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details. 
+############################################################################ +import math +import numpy +from collections import defaultdict +from multiprocessing import shared_memory +from .common import bit_to_str, str_to_2bit + + +class SharedMemoryIndexInfo: + def __init__(self, barcode_count, kmer_size, seq_len, index_size, barcodes_sm_name, index_sm_name, index_range_sm_name): + self.barcode_count = barcode_count + self.kmer_size = kmer_size + self.seq_len = seq_len + self.index_size = index_size + self.barcodes_sm_name = barcodes_sm_name + self.index_sm_name = index_sm_name + self.index_range_sm_name = index_range_sm_name + + def __getstate__(self): + return (self.barcode_count, + self.kmer_size, + self.seq_len, + self.index_size, + self.barcodes_sm_name, + self.index_sm_name, + self.index_range_sm_name) + + def __setstate__(self, state): + self.barcode_count = state[0] + self.kmer_size = state[1] + self.seq_len = state[2] + self.index_size = state[3] + self.barcodes_sm_name = state[4] + self.index_sm_name = state[5] + self.index_range_sm_name = state[6] + + +class SharedMemoryArray2BitKmerIndexer: + # @params: + # known_bin_seq: collection of strings in binary or string format (barcodes or UMI) + # kmer_size: K to use for indexing + SEQ_DTYPE = numpy.uint64 + KMER_DTYPE = numpy.uint32 + INDEX_DTYPE = numpy.uint32 + + def __init__(self, known_bin_seq: list, kmer_size=12, seq_len=25): + self.main_instance = True + self.k = kmer_size + total_kmers = int(math.pow(4, self.k)) + tmp_index = [] + for i in range(total_kmers): + tmp_index.append([]) + self.mask = (1 << (2 * self.k)) - 1 + self.seq_len = seq_len + self.seq_mask = (1 << (2 * self.seq_len)) - 1 + self._index(known_bin_seq, tmp_index) + self.index_size = sum(len(x) for x in tmp_index) + self.total_sequences = len(known_bin_seq) + self.barcodes_shared_memory = shared_memory.SharedMemory(create=True, size=self.total_sequences*self.SEQ_DTYPE().nbytes) + self.known_bin_seq = numpy.ndarray(shape=(self.total_sequences, ), dtype=self.SEQ_DTYPE, buffer=self.barcodes_shared_memory.buf) + self.known_bin_seq[:] = known_bin_seq[:] + self.index_shared_memory = shared_memory.SharedMemory(create=True, size=self.index_size*self.KMER_DTYPE().nbytes) + self.index = numpy.ndarray(shape=(self.index_size, ), dtype=self.KMER_DTYPE, buffer=self.index_shared_memory.buf) + self.index_ranges_shared_memory = shared_memory.SharedMemory(create=True, size=(total_kmers+1)*self.INDEX_DTYPE().nbytes) + self.index_ranges = numpy.ndarray(shape=(total_kmers + 1, ), dtype=self.INDEX_DTYPE, buffer=self.index_ranges_shared_memory.buf) + self.index_ranges[0] = 0 + i = 0 + index_i = 1 + for l in tmp_index: + for e in l: + self.index[i] = e + i += 1 + self.index_ranges[index_i] = i + index_i += 1 + + def __del__(self): + self.barcodes_shared_memory.close() + self.index_shared_memory.close() + self.index_ranges_shared_memory.close() + if self.main_instance: + self.barcodes_shared_memory.unlink() + self.index_shared_memory.unlink() + self.index_ranges_shared_memory.unlink() + + @classmethod + def from_sharable_info(cls, shared_mem_index_info: SharedMemoryIndexInfo): + kmer_index = cls.__new__(cls) + kmer_index.main_instance = False + kmer_index.k = shared_mem_index_info.kmer_size + kmer_index.mask = (1 << (2 * kmer_index.k)) - 1 + kmer_index.seq_len = shared_mem_index_info.seq_len + kmer_index.seq_mask = (1 << (2 * kmer_index.seq_len)) - 1 + kmer_index.index_size = shared_mem_index_info.index_size + kmer_index.total_sequences = shared_mem_index_info.barcode_count + 
total_kmers = int(math.pow(4, kmer_index.k)) + kmer_index.barcodes_shared_memory = shared_memory.SharedMemory(create=False, name=shared_mem_index_info.barcodes_sm_name) + kmer_index.known_bin_seq = numpy.ndarray(shape=(kmer_index.total_sequences, ), dtype=SharedMemoryArray2BitKmerIndexer.SEQ_DTYPE, buffer=kmer_index.barcodes_shared_memory.buf) + kmer_index.index_shared_memory = shared_memory.SharedMemory(create=False, name=shared_mem_index_info.index_sm_name) + kmer_index.index = numpy.ndarray(shape=(kmer_index.index_size, ), dtype=kmer_index.KMER_DTYPE, buffer=kmer_index.index_shared_memory.buf) + kmer_index.index_ranges_shared_memory = shared_memory.SharedMemory(create=False, name=shared_mem_index_info.index_range_sm_name) + kmer_index.index_ranges = numpy.ndarray(shape=(total_kmers + 1, ), dtype=kmer_index.INDEX_DTYPE, buffer=kmer_index.index_ranges_shared_memory.buf) + return kmer_index + + def get_sharable_info(self): + return SharedMemoryIndexInfo(self.total_sequences, self.k, self.seq_len, self.index_size, + self.barcodes_shared_memory.name, self.index_shared_memory.name, + self.index_ranges_shared_memory.name) + + def _get_kmer_bin_indexes(self, bin_seq): + for i in range(self.seq_len - self.k + 1): + yield (bin_seq >> ((self.seq_len - self.k - i) * 2)) & self.mask + + def _index(self, known_bin_seq, tmp_index): + for i, bin_seq in enumerate(known_bin_seq): + for kmer_idx in self._get_kmer_bin_indexes(bin_seq): + tmp_index[kmer_idx].append(i) + + def empty(self): + return len(self.known_bin_seq) == 0 + + # @params: + # sequence: a string to be searched against known strings + # max_hits: return at most max_hits candidates + # min_kmers: minimal number of matching k-mers + # @return + # a list of (pattern: str, number of shared kmers: int, their positions: list) + # sorted descending by the number of shared k-m + def get_occurrences(self, sequence, max_hits=0, min_kmers=1, hits_delta=1, ignore_equal=False): + barcode_counts = defaultdict(int) + barcode_positions = defaultdict(list) + + seq = str_to_2bit(sequence) + for pos, kmer_idx in enumerate(self._get_kmer_bin_indexes(seq)): + start_index = self.index_ranges[kmer_idx] + end_index = self.index_ranges[kmer_idx + 1] + for barcode_index in range(start_index, end_index): + barcode = int(self.known_bin_seq[self.index[barcode_index]]) + barcode_counts[barcode] += 1 + barcode_positions[barcode].append(pos) + + result = [] + for barcode in barcode_counts.keys(): + count = barcode_counts[barcode] + if count < min_kmers: + continue + if ignore_equal and barcode == seq: + continue + result.append((barcode, count, barcode_positions[barcode])) + + if not result: + return [] + + top_hits = max(result, key=lambda x: x[1])[1] + result = filter(lambda x: x[1] >= top_hits - hits_delta, result) + result = sorted(result, reverse=True, key=lambda x: x[1]) + + if max_hits == 0: + return [(bit_to_str(x[0], self.seq_len), x[1], x[2]) for x in result] + return [(bit_to_str(x[0], self.seq_len), x[1], x[2]) for x in list(result)[:max_hits]] + \ No newline at end of file diff --git a/src/barcode_calling/umi_filtering.py b/src/barcode_calling/umi_filtering.py new file mode 100644 index 00000000..bd7d2757 --- /dev/null +++ b/src/barcode_calling/umi_filtering.py @@ -0,0 +1,608 @@ +############################################################################ +# Copyright (c) 2023 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details +############################################################################ + +import pysam +import sys 
+import gzip +import os.path +from collections import defaultdict +from typing import Dict, List, Tuple, Set, Optional +import logging +import editdistance +import gffutils + +from src.assignment_loader import create_merging_assignment_loader +from src.isoform_assignment import MatchEventSubtype, ReadAssignment +from src.common import overlaps, junctions_from_blocks + +logger = logging.getLogger('IsoQuant') + + +def format_read_assignment_for_output(read_assignment: ReadAssignment) -> str: + """ + Format ReadAssignment for UMI filtering output file. + + Expects additional_attributes to contain: + - 'transcript_type': Type of transcript (e.g., 'protein_coding') + - 'polya_site': PolyA site position (-1 if none) + - 'cell_type': Cell type/spot from barcode mapping + + Args: + read_assignment: ReadAssignment object with additional_attributes set + + Returns: + Formatted string for output file + """ + exon_blocks = read_assignment.corrected_exons + chr_id = read_assignment.chr_id + strand = read_assignment.strand + + intron_blocks = junctions_from_blocks(exon_blocks) + exons_str = ";%;" + ";%;".join(["%s_%d_%d_%s" % (chr_id, e[0], e[1], strand) for e in exon_blocks]) + introns_str = ";%;" + ";%;".join(["%s_%d_%d_%s" % (chr_id, e[0], e[1], strand) for e in intron_blocks]) + + # Determine read type from assignment type + if read_assignment.assignment_type.is_unique(): + read_type = "known" + elif read_assignment.assignment_type.is_inconsistent(): + read_type = "novel" + elif read_assignment.assignment_type.is_ambiguous(): + read_type = "known_ambiguous" + else: + read_type = "none" + + # Get gene and transcript IDs + if read_assignment.isoform_matches: + gene_id = read_assignment.isoform_matches[0].assigned_gene + transcript_id = read_assignment.isoform_matches[0].assigned_transcript + matching_events = [e.event_type for e in read_assignment.isoform_matches[0].match_subclassifications] + else: + gene_id = "." + transcript_id = "." 
+ matching_events = [] + + # Get additional attributes + transcript_type = read_assignment.additional_attributes.get('transcript_type', 'unknown') + polya_site = read_assignment.additional_attributes.get('polya_site', -1) + cell_type = read_assignment.additional_attributes.get('cell_type', 'None') + + # Format TSS and polyA + polyA = "NoPolyA" + TSS = "NoTSS" + + if isinstance(matching_events, str): + # Legacy string format + if "tss_match" in matching_events: + tss_pos = exon_blocks[-1][1] if strand == "-" else exon_blocks[0][0] + TSS = "%s_%d_%d_%s" % (chr_id, tss_pos, tss_pos, strand) + if "correct_polya" in matching_events and polya_site != -1: + polyA_pos = polya_site + polyA = "%s_%d_%d_%s" % (chr_id, polyA_pos, polyA_pos, strand) + else: + # Event subtype format + if strand == '+': + if any(x in [MatchEventSubtype.terminal_site_match_left, + MatchEventSubtype.terminal_site_match_left_precise] for x in matching_events): + tss_pos = exon_blocks[0][0] + TSS = "%s_%d_%d_%s" % (chr_id, tss_pos, tss_pos, strand) + if (polya_site != -1 and + any(x == MatchEventSubtype.correct_polya_site_right for x in matching_events)): + polyA_pos = polya_site + polyA = "%s_%d_%d_%s" % (chr_id, polyA_pos, polyA_pos, strand) + elif strand == '-': + if any(x in [MatchEventSubtype.terminal_site_match_right, + MatchEventSubtype.terminal_site_match_right_precise] for x in matching_events): + tss_pos = exon_blocks[-1][1] + TSS = "%s_%d_%d_%s" % (chr_id, tss_pos, tss_pos, strand) + if (polya_site != -1 and + any(x == MatchEventSubtype.correct_polya_site_left for x in matching_events)): + polyA_pos = polya_site + polyA = "%s_%d_%d_%s" % (chr_id, polyA_pos, polyA_pos, strand) + + return "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%d\t%s\t%s" % ( + read_assignment.read_id, gene_id, cell_type, read_assignment.barcode, read_assignment.umi, + introns_str, TSS, polyA, exons_str, read_type, len(intron_blocks), transcript_id, transcript_type) + + +class UMIFilter: + """ + UMI-based deduplication filter for barcoded reads. + + Processes reads with UMI (Unique Molecular Identifier) tags to remove PCR/RT duplicates. + Groups reads by gene, barcode, and UMI, selecting representative reads from each molecule. + """ + + def __init__(self, umi_length: int = 0, edit_distance: int = 3, disregard_length_diff: bool = True, + only_unique_assignments: bool = False, only_spliced_reads: bool = False): + """ + Initialize UMI filter. + + Args: + umi_length: Expected UMI length (0 = variable length) + edit_distance: Maximum edit distance for UMI clustering + disregard_length_diff: Whether to ignore length differences in edit distance calculation + only_unique_assignments: Only process uniquely assigned reads + only_spliced_reads: Only process spliced reads + """ + self.umi_length = umi_length + self.max_edit_distance = edit_distance + self.disregard_length_diff = disregard_length_diff + self.only_unique_assignments = only_unique_assignments + self.only_spliced_reads = only_spliced_reads + + self.selected_reads: Set[str] = set() + self.stats: Dict[str, int] = defaultdict(int) + self.unique_gene_barcode: Set[Tuple[str, str]] = set() + self.total_assignments = 0 + self.duplicated_molecule_counts: Dict[int, int] = defaultdict(int) + + self.umi_len_dif_func = (lambda x: 0) if self.umi_length == 0 else (lambda x: -abs(self.umi_length - x)) + + def _find_similar_umi(self, umi: str, trusted_umi_list: List[str]) -> Optional[str]: + """ + Find similar UMI in trusted list within edit distance threshold. 
+ + Args: + umi: Query UMI sequence + trusted_umi_list: List of already accepted UMIs + + Returns: + Most similar UMI if found within threshold, None otherwise + """ + if self.max_edit_distance == -1: + return None if not trusted_umi_list else trusted_umi_list[0] + + similar_umi = None + best_dist = 100 + + for occ in trusted_umi_list: + if self.max_edit_distance == 0: + # Exact match mode + if self.disregard_length_diff: + similar, ed = occ == umi, 0 + elif len(occ) < len(umi): + similar, ed = occ in umi, abs(len(occ) - len(umi)) + else: + similar, ed = umi in occ, abs(len(occ) - len(umi)) + elif occ == umi: + similar, ed = True, 0 + else: + ed = editdistance.eval(occ, umi) + if not self.disregard_length_diff: + ed -= abs(len(occ) - len(umi)) + similar, ed = ed <= self.max_edit_distance, ed + + if similar and ed < best_dist: + similar_umi = occ + best_dist = ed + if best_dist == 0: + break + + return similar_umi + + def _construct_umi_dict(self, molecule_list: List[ReadAssignment]) -> Dict[str, List[ReadAssignment]]: + """ + Group molecules by UMI, clustering similar UMIs. + + Uses greedy clustering: processes UMIs by frequency, merging similar low-frequency + UMIs into high-frequency ones. + + Args: + molecule_list: List of read assignments with UMI information + + Returns: + Dict mapping representative UMI to list of reads + """ + # Count reads per UMI + umi_counter: Dict[str, Set[str]] = defaultdict(set) + for m in molecule_list: + umi_counter[m.umi].add(m.read_id) + + # Create sorting keys: (count, length_penalty, umi_sequence) + umi_sorting_keys = {} + for umi in umi_counter: + if umi is None or umi == "None": + umi_sorting_keys[umi] = (-1, 0, "") + else: + umi_sorting_keys[umi] = (len(umi_counter[umi]), self.umi_len_dif_func(len(umi)), umi) + + # Sort molecules by UMI frequency (process high-frequency UMIs first) + # This ensures reads with same UMI are processed consecutively + molecule_list = sorted(molecule_list, key=lambda ml: umi_sorting_keys[ml.umi], reverse=True) + + umi_dict: Dict[str, List[ReadAssignment]] = defaultdict(list) + trusted_umi_list: List[str] = [] + + for m in molecule_list: + if m.umi is None or m.umi == "None": + # Collect untrusted UMIs together + umi_dict["None"].append(m) + continue + + similar_umi = self._find_similar_umi(m.umi, trusted_umi_list) + if similar_umi is None: + # New distinct UMI + umi_dict[m.umi].append(m) + trusted_umi_list.append(m.umi) + else: + # Merge with existing UMI + umi_dict[similar_umi].append(m) + + return umi_dict + + def _process_duplicates(self, molecule_list: List[ReadAssignment]) -> List[ReadAssignment]: + """ + Process PCR/RT duplicates for a gene-barcode pair. + + Groups reads by UMI and selects representative read for each molecule. + Selection criteria (in order): + 1. Unique assignment > ambiguous + 2. More exons + 3. 
Longer transcript + + Args: + molecule_list: All reads for this gene-barcode pair + + Returns: + List of selected representative reads + """ + if not molecule_list: + return [] + + if len(molecule_list) == 1: + self.duplicated_molecule_counts[1] += 1 + logger.debug("Unique " + molecule_list[0].read_id) + return molecule_list + + resulting_reads = [] + umi_dict = self._construct_umi_dict(molecule_list) + + for umi in umi_dict.keys(): + duplicate_count = len(umi_dict[umi]) + self.duplicated_molecule_counts[duplicate_count] += 1 + + if duplicate_count == 1: + resulting_reads.append((umi_dict[umi][0], umi)) + continue + + # Select best read from duplicates + best_read = umi_dict[umi][0] + logger.debug("Selecting from:") + for m in umi_dict[umi]: + logger.debug("%s %s" % (m.read_id, m.umi)) + # Prefer unique assignments + if not best_read.assignment_type.is_unique() and m.assignment_type.is_unique(): + best_read = m + # Prefer more exons + elif len(m.corrected_exons) > len(best_read.corrected_exons): + best_read = m + # Prefer longer transcripts + elif (len(m.corrected_exons) == len(best_read.corrected_exons) and + m.corrected_exons[-1][1] - m.corrected_exons[0][0] > + best_read.corrected_exons[-1][1] - best_read.corrected_exons[0][0]): + best_read = m + + # Check for ambiguity in transcript annotation among duplicates with same read ID + polyas = set() + transcript_types = set() + isoform_ids = set() + for m in umi_dict[umi]: + if m.read_id != best_read.read_id: + continue + transcript_types.add(m.additional_attributes.get('transcript_type', 'unknown')) + polyas.add(m.additional_attributes.get('polya_site', -1)) + if m.isoform_matches: + isoform_ids.add(m.isoform_matches[0].assigned_transcript) + + # Resolve ambiguities by updating additional_attributes + if len(transcript_types) > 1: + best_read.set_additional_attribute('transcript_type', 'None') + if len(polyas) > 1: + best_read.set_additional_attribute('polya_site', -1) + if len(isoform_ids) > 1: + # Clear the transcript ID by removing all matches except the gene + if best_read.isoform_matches: + best_read.isoform_matches[0].assigned_transcript = "None" + + # Clear annotation for inconsistent assignments + if best_read.assignment_type.is_inconsistent(): + if best_read.isoform_matches: + best_read.isoform_matches[0].assigned_transcript = "None" + best_read.set_additional_attribute('polya_site', -1) + best_read.set_additional_attribute('transcript_type', 'None') + + logger.debug("Selected %s %s" % (best_read.read_id, best_read.umi)) + resulting_reads.append((best_read, umi)) + + # If single UMI, trust it regardless of whether it's marked as trusted + if len(resulting_reads) == 1: + return [resulting_reads[0][0]] + + # With multiple UMIs, ignore untrusted ones + return [x[0] for x in filter(lambda x: x[1] != "None", resulting_reads)] + + def _process_gene(self, gene_dict: Dict[str, List[ReadAssignment]]): + """ + Process all barcodes for a gene. + + Args: + gene_dict: Dict mapping barcode to list of reads + + Yields: + Representative read for each molecule + """ + for barcode in gene_dict: + for r in self._process_duplicates(gene_dict[barcode]): + yield r + + def _process_chunk(self, gene_barcode_dict: Dict[str, Dict[str, List[ReadAssignment]]], + allinfo_outf, read_ids_outf=None) -> Tuple[int, int]: + """ + Process a chunk of reads grouped by gene and barcode. 
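+
+        The expected input shape is, e.g. (hypothetical IDs),
+        {"ENSG00000001": {"AAACCCGGGTTT": [read_assignment_1, read_assignment_2]}}.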
+ + Args: + gene_barcode_dict: Nested dict[gene_id][barcode] -> list of reads + allinfo_outf: Output file for detailed assignment info + read_ids_outf: Optional output file for selected read IDs + + Returns: + Tuple of (total_read_count, spliced_read_count) + """ + read_count = 0 + spliced_count = 0 + + for gene_id in gene_barcode_dict: + for read_assignment in self._process_gene(gene_barcode_dict[gene_id]): + # Skip non-unique reads if already selected + if (not read_assignment.assignment_type.is_unique() and + read_assignment.read_id in self.selected_reads): + continue + + read_count += 1 + if read_assignment.isoform_matches: + gene_id = read_assignment.isoform_matches[0].assigned_gene + self.unique_gene_barcode.add((gene_id, read_assignment.barcode)) + + if len(read_assignment.corrected_exons) > 1: + spliced_count += 1 + + if read_ids_outf: + read_ids_outf.write(read_assignment.read_id + "\n") + + self.selected_reads.add(read_assignment.read_id) + + if allinfo_outf: + allinfo_outf.write(format_read_assignment_for_output(read_assignment) + "\n") + + return read_count, spliced_count + + def add_stats_for_read(self, read_assignment: ReadAssignment): + """ + Update statistics for a processed read. + + Args: + read_assignment: ReadAssignment object + """ + gene_id = "." + if read_assignment.isoform_matches: + gene_id = read_assignment.isoform_matches[0].assigned_gene + + assigned = gene_id != "." + spliced = len(read_assignment.corrected_exons) > 1 + barcoded = read_assignment.barcode is not None + unique = assigned # In simplified logic, we only process unique/consistent assignments + + if assigned: + self.stats["Assigned to any gene"] += 1 + if spliced: + self.stats["Spliced"] += 1 + if unique: + self.stats["Uniquely assigned"] += 1 + if unique and spliced: + self.stats["Uniquely assigned and spliced"] += 1 + if barcoded: + if assigned: + self.stats["Assigned to any gene and barcoded"] += 1 + if spliced: + self.stats["Spliced and barcoded"] += 1 + if unique: + self.stats["Uniquely assigned and barcoded"] += 1 + if unique and spliced: + self.stats["Uniquely assigned and spliced and barcoded"] += 1 + + def process_single_chr(self, chr_id: str, saves_prefix: str, transcript_type_dict: Dict[str, Tuple[str, int]], + barcode_feature_table: Dict[str, str], + all_info_file_name: str, filtered_reads_file_name: Optional[str], + stats_output_file_name: str) -> Tuple[str, str]: + """ + Process UMI filtering for a single chromosome. 
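+
+        For example (hypothetical entries), transcript_type_dict may look like
+        {"ENST00000001": ("protein_coding", 104523)} and barcode_feature_table
+        like {"AAACCCGGGTTT": "neuron"}.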
+ + Args: + chr_id: Chromosome ID + saves_prefix: Prefix for saved assignment files + transcript_type_dict: Dict mapping transcript_id to (type, polya_site) + barcode_feature_table: Dict mapping barcode to cell type + all_info_file_name: Output file for detailed assignment info + filtered_reads_file_name: Optional output file for filtered read IDs + stats_output_file_name: Output file for statistics + + Returns: + Tuple of (all_info_file_name, stats_output_file_name) + """ + with open(all_info_file_name, "w") as allinfo_outf: + filtered_reads_outf = open(filtered_reads_file_name, "w") if filtered_reads_file_name else None + read_count = 0 + spliced_count = 0 + self.unique_gene_barcode = set() + + loader = create_merging_assignment_loader(chr_id, saves_prefix) + while loader.has_next(): + gene_barcode_dict: Dict[str, Dict[str, List[ReadAssignment]]] = defaultdict(lambda: defaultdict(list)) + _, assignment_storage = loader.get_next() + logger.debug("Processing %d reads" % len(assignment_storage)) + + for read_assignment in assignment_storage: + # Skip unassigned reads + if read_assignment.gene_assignment_type.is_unassigned(): + continue + + # Count ambiguous assignments but don't process them + if read_assignment.gene_assignment_type.is_ambiguous(): + self.stats["Assigned to any gene"] += 1 + if len(read_assignment.corrected_exons) > 1: + self.stats["Spliced"] += 1 + continue + + exon_blocks = read_assignment.corrected_exons + barcode = read_assignment.barcode + umi = read_assignment.umi + spliced = len(exon_blocks) > 1 + barcoded = barcode is not None + + # Get cell type from barcode feature table + cell_type = "None" if barcode is None or barcode not in barcode_feature_table else barcode_feature_table[barcode] + read_assignment.set_additional_attribute('cell_type', cell_type) + + # Process based on number of isoform matches + if len(read_assignment.isoform_matches) == 1: + # Single match - simplest case + isoform_match = read_assignment.isoform_matches[0] + if read_assignment.assignment_type.is_consistent(): + transcript_id = isoform_match.assigned_transcript + transcript_type, polya_site = ( + transcript_type_dict[transcript_id] if transcript_id in transcript_type_dict + else ("unknown_type", -1)) + read_assignment.set_additional_attribute('transcript_type', transcript_type) + read_assignment.set_additional_attribute('polya_site', polya_site) + else: + # Inconsistent single match + read_assignment.set_additional_attribute('transcript_type', 'unknown_type') + read_assignment.set_additional_attribute('polya_site', -1) + + elif read_assignment.assignment_type.is_consistent(): + # Multiple consistent matches - need to resolve ambiguity + transcript_types = set() + polya_sites = set() + gene_ids = set() + + for m in read_assignment.isoform_matches: + transcript_id = m.assigned_transcript + gene_ids.add(m.assigned_gene) + transcript_type, polya_site = ( + transcript_type_dict[transcript_id] if transcript_id in transcript_type_dict + else ("unknown_type", -1)) + transcript_types.add(transcript_type) + polya_sites.add(polya_site) + + assert len(gene_ids) == 1, "Multiple genes assigned to a single read" + + # Use consensus if unique, otherwise mark as ambiguous + transcript_type = "None" if len(transcript_types) != 1 else transcript_types.pop() + polya_site = -1 if len(polya_sites) != 1 else polya_sites.pop() + + read_assignment.set_additional_attribute('transcript_type', transcript_type) + read_assignment.set_additional_attribute('polya_site', polya_site) + + # Set transcript_id to "None" for 
ambiguous case + if read_assignment.isoform_matches: + read_assignment.isoform_matches[0].assigned_transcript = "None" + else: + # Multiple inconsistent matches - skip + self.total_assignments += 1 + continue + + self.total_assignments += 1 + self.add_stats_for_read(read_assignment) + + # Filter for deduplication + if not barcoded: + continue + if not spliced and self.only_spliced_reads: + continue + + # Add to gene-barcode dict using gene from first isoform match + if read_assignment.isoform_matches: + gene_id = read_assignment.isoform_matches[0].assigned_gene + gene_barcode_dict[gene_id][barcode].append(read_assignment) + + # Process chunk + processed_read_count, processed_spliced_count = self._process_chunk( + gene_barcode_dict, allinfo_outf, filtered_reads_outf) + read_count += processed_read_count + spliced_count += processed_spliced_count + + if filtered_reads_outf: + filtered_reads_outf.close() + + # Write statistics + with open(stats_output_file_name, "w") as count_hist_file: + count_hist_file.write("Unique gene-barcodes pairs\t%d\n" % len(self.unique_gene_barcode)) + count_hist_file.write("Total reads saved\t%d\n" % read_count) + count_hist_file.write("Spliced reads saved\t%d\n" % spliced_count) + count_hist_file.write("Total assignments processed\t%d\n" % self.total_assignments) + for k in sorted(self.stats.keys()): + count_hist_file.write("%s\t%d\n" % (k, self.stats[k])) + + return all_info_file_name, stats_output_file_name + + +def filter_bam(in_file_name: str, out_file_name: str, read_set: Set[str]): + """ + Filter BAM file to keep only reads in the provided set. + + Args: + in_file_name: Input BAM file path + out_file_name: Output BAM file path + read_set: Set of read IDs to keep + """ + inf = pysam.AlignmentFile(in_file_name, "rb") + outf = pysam.AlignmentFile(out_file_name, "wb", template=inf) + + count = 0 + passed = 0 + + for read in inf: + if read.reference_id == -1 or read.is_secondary: + continue + + count += 1 + if count % 10000 == 0: + sys.stdout.write("Processed " + str(count) + " reads\r") + + if read.query_name in read_set: + outf.write(read) + passed += 1 + + print("Processed " + str(count) + " reads, written " + str(passed)) + inf.close() + outf.close() + pysam.index(out_file_name) + + +def create_transcript_info_dict(genedb: str, chr_ids: Optional[List[str]] = None) -> Dict[str, Tuple[str, int]]: + """ + Create dictionary of transcript types and polyA sites from gene database. 
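+
+    Illustrative usage (hypothetical database path, transcript ID and values):
+
+    >>> info = create_transcript_info_dict("annotation.db", chr_ids=["chr1"])  # doctest: +SKIP
+    >>> info.get("ENST00000000001")  # doctest: +SKIP
+    ('protein_coding', 120345)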
+ + Args: + genedb: Path to gffutils database + chr_ids: Optional list of chromosome IDs to filter + + Returns: + Dict mapping transcript_id to (transcript_type, polya_site) tuple + """ + gffutils_db = gffutils.FeatureDB(genedb) + transcript_type_dict = {} + + for t in gffutils_db.features_of_type(('transcript', 'mRNA')): + if chr_ids and t.seqid not in chr_ids: + continue + polya_site = t.start - 1 if t.strand == '-' else t.end + 1 + if "transcript_type" in t.attributes.keys(): + transcript_type_dict[t.id] = (t.attributes["transcript_type"][0], polya_site) + else: + transcript_type_dict[t.id] = ("unknown_type", polya_site) + + return transcript_type_dict diff --git a/src/common.py b/src/common.py index 527cc615..6f843c07 100644 --- a/src/common.py +++ b/src/common.py @@ -113,10 +113,6 @@ def proper_plural_form(name, count): return str(count) + " " + name + ("" if count == 1 else "s") -def convert_chr_id_to_file_name_str(chr_id: str): - return chr_id.replace('/', '_') - - # check whether genes overlap and should be processed together def genes_overlap(gene_db1, gene_db2): if gene_db1.seqid != gene_db2.seqid: diff --git a/src/dataset_processor.py b/src/dataset_processor.py index 2aee1c2d..a9fbae43 100644 --- a/src/dataset_processor.py +++ b/src/dataset_processor.py @@ -6,22 +6,22 @@ ############################################################################ import glob -import gzip import itertools import logging -import os import shutil +import sys from enum import Enum, unique from collections import defaultdict from concurrent.futures import ProcessPoolExecutor +from functools import partial import gffutils import pysam from pyfaidx import Fasta -from .common import proper_plural_form, convert_chr_id_to_file_name_str +from .modes import IsoQuantMode +from .common import proper_plural_form from .serialization import * -from .isoform_assignment import BasicReadAssignment, ReadAssignmentType, ReadAssignment from .stats import EnumStats from .file_utils import merge_files, merge_counts from .input_data_storage import SampleData @@ -33,10 +33,11 @@ create_gene_counter, create_transcript_counter, ) -from .multimap_resolver import MultimapResolver from .read_groups import ( create_read_grouper, - prepare_read_groups + prepare_read_groups, + load_table, + get_grouping_strategy_names ) from .assignment_io import ( IOSupport, @@ -46,36 +47,20 @@ BasicTSVAssignmentPrinter, TmpFileAssignmentPrinter, ) -from .read_assignment_loader import BasicReadAssignmentLoader, ReadAssignmentLoader +from .read_assignment_loader import BasicReadAssignmentLoader from .processed_read_manager import ProcessedReadsManagerHighMemory, ProcessedReadsManagerNoSecondary, ProcessedReadsManagerNormalMemory from .id_policy import SimpleIDDistributor, ExcludingIdDistributor, FeatureIdStorage +from .file_naming import * from .transcript_printer import GFFPrinter, VoidTranscriptPrinter, create_extended_storage from .graph_based_model_construction import GraphBasedModelConstructor from .gene_info import TranscriptModelType, get_all_chromosome_genes, get_all_chromosome_transcripts +from .assignment_loader import create_assignment_loader, BasicReadAssignmentLoader +from .barcode_calling.umi_filtering import create_transcript_info_dict, UMIFilter +from .table_splitter import split_read_table_parallel logger = logging.getLogger('IsoQuant') -def reads_collected_lock_file_name(sample_out_raw, chr_id): - return "{}_{}_collected".format(sample_out_raw, convert_chr_id_to_file_name_str(chr_id)) - - -def 
reads_processed_lock_file_name(dump_filename, chr_id): - chr_dump_file = dump_filename + "_" + convert_chr_id_to_file_name_str(chr_id) - return "{}_processed".format(chr_dump_file) - - -def read_group_lock_filename(sample): - return sample.read_group_file + "_lock" - - -def clean_locks(chr_ids, base_name, fname_function): - for chr_id in chr_ids: - fname = fname_function(base_name, chr_id) - if os.path.exists(fname): - os.remove(fname) - - @unique class PolyAUsageStrategies(Enum): auto = 1 @@ -98,7 +83,7 @@ def collect_reads_in_parallel(sample, chr_id, args, processed_read_manager_type) current_chr_record = str(current_chr_record) read_grouper = create_read_grouper(args, sample, chr_id) lock_file = reads_collected_lock_file_name(sample.out_raw_file, chr_id) - save_file = "{}_{}".format(sample.out_raw_file, convert_chr_id_to_file_name_str(chr_id)) + save_file = saves_file_name(sample.out_raw_file, chr_id) group_file = save_file + "_groups" bamstat_file = save_file + "_bamstat" processed_reads_manager = processed_read_manager_type(sample, args.multimap_strategy) @@ -125,16 +110,32 @@ def collect_reads_in_parallel(sample, chr_id, args, processed_read_manager_type) logger.warning("%s does not exist" % save_file) os.remove(lock_file) + if os.path.exists(lock_file): + os.remove(lock_file) + tmp_printer = TmpFileAssignmentPrinter(save_file, args) bam_files = list(map(lambda x: x[0], sample.file_list)) bam_file_pairs = [(pysam.AlignmentFile(bam, "rb", require_index=True), bam) for bam in bam_files] gffutils_db = gffutils.FeatureDB(args.genedb) if args.genedb else None illumina_bam = sample.illumina_bam + # Load barcode dict for this chromosome if available + barcode_dict = {} + if sample.barcodes_split_reads: + barcode_file = sample.barcodes_split_reads + "_" + chr_id + if os.path.exists(barcode_file): + logger.debug(f"Loading barcodes from {barcode_file}") + for line in open(barcode_file): + if line.startswith("#"): + continue + parts = line.split() + barcode_dict[parts[0]] = (parts[1], parts[2]) + logger.debug("Loaded %d barcodes" % len(barcode_dict)) + logger.info("Processing chromosome " + chr_id) alignment_collector = \ AlignmentCollector(chr_id, bam_file_pairs, args, illumina_bam, gffutils_db, current_chr_record, read_grouper, - args.max_coverage_small_chr, args.max_coverage_normal_chr) + barcode_dict, args.max_coverage_small_chr, args.max_coverage_normal_chr) for gene_info, assignment_storage in alignment_collector.process(): tmp_printer.add_gene_info(gene_info) @@ -157,45 +158,37 @@ def collect_reads_in_parallel(sample, chr_id, args, processed_read_manager_type) return chr_id, read_grouper.read_groups, alignment_collector.alignment_stat_counter, processed_reads_manager -def construct_models_in_parallel(sample, chr_id, dump_filename, args, read_groups): +def construct_models_in_parallel(sample, chr_id, saves_prefix, args, read_groups): logger.info("Processing chromosome " + chr_id) - construct_models = not args.no_model_construction - current_chr_record = Fasta(args.reference, indexname=args.fai_file_name)[chr_id] + use_filtered_reads = args.mode.needs_pcr_deduplication() + loader = create_assignment_loader(chr_id, saves_prefix, args.genedb, args.reference, args.fai_file_name, use_filtered_reads) - multimapped_reads = defaultdict(list) - multimap_loader = open(dump_filename + "_multimappers_" + convert_chr_id_to_file_name_str(chr_id), "rb") - list_size = read_int(multimap_loader) - while list_size != TERMINATION_INT: - for i in range(list_size): - a = 
BasicReadAssignment.deserialize(multimap_loader) - if a.chr_id == chr_id: - multimapped_reads[a.read_id].append(a) - list_size = read_int(multimap_loader) - - chr_dump_file = dump_filename + "_" + convert_chr_id_to_file_name_str(chr_id) - lock_file = reads_processed_lock_file_name(dump_filename, chr_id) + chr_dump_file = saves_file_name(saves_prefix, chr_id) + lock_file = reads_processed_lock_file_name(saves_prefix, chr_id) read_stat_file = "{}_read_stat".format(chr_dump_file) transcript_stat_file = "{}_transcript_stat".format(chr_dump_file) + construct_models = not args.no_model_construction - if os.path.exists(lock_file) and args.resume: - logger.info("Processed assignments from chromosome " + chr_id + " detected") - read_stat = EnumStats(read_stat_file) - transcript_stat = EnumStats(transcript_stat_file) if construct_models else EnumStats() - return read_stat, transcript_stat + if os.path.exists(lock_file): + if args.resume: + logger.info("Processed assignments from chromosome " + chr_id + " detected") + read_stat = EnumStats(read_stat_file) + transcript_stat = EnumStats(transcript_stat_file) if construct_models else EnumStats() + return read_stat, transcript_stat + os.remove(lock_file) - if args.genedb: - gffutils_db = gffutils.FeatureDB(args.genedb) - else: - gffutils_db = None - aggregator = ReadAssignmentAggregator(args, sample, read_groups, gffutils_db, chr_id) + grouping_strategy_names = get_grouping_strategy_names(args) + aggregator = ReadAssignmentAggregator(args, sample, read_groups, loader.genedb, chr_id, + grouping_strategy_names=grouping_strategy_names) transcript_stat_counter = EnumStats() io_support = IOSupport(args) - transcript_id_distributor = ExcludingIdDistributor(gffutils_db, chr_id) - exon_id_storage = FeatureIdStorage(SimpleIDDistributor(), gffutils_db, chr_id, "exon") + transcript_id_distributor = ExcludingIdDistributor(loader.genedb, chr_id) + exon_id_storage = FeatureIdStorage(SimpleIDDistributor(), loader.genedb, chr_id, "exon") if construct_models: tmp_gff_printer = GFFPrinter(sample.out_dir, sample.prefix, exon_id_storage, + output_r2t=not args.no_large_files, check_canonical=args.check_canonical) else: tmp_gff_printer = VoidTranscriptPrinter() @@ -210,8 +203,6 @@ def construct_models_in_parallel(sample, chr_id, dump_filename, args, read_group if args.sqanti_output else VoidTranscriptPrinter() novel_model_storage = [] - - loader = ReadAssignmentLoader(chr_dump_file, gffutils_db, current_chr_record, multimapped_reads) while loader.has_next(): gene_info, assignment_storage = loader.get_next() logger.debug("Processing %d reads" % len(assignment_storage)) @@ -223,10 +214,17 @@ def construct_models_in_parallel(sample, chr_id, dump_filename, args, read_group aggregator.global_counter.add_read_info(read_assignment) if construct_models: - model_constructor = GraphBasedModelConstructor(gene_info, current_chr_record, args, + transcript_grouped = aggregator.transcript_model_grouped_counters if hasattr(aggregator, 'transcript_model_grouped_counters') else [] + gene_grouped = aggregator.gene_model_grouped_counters if hasattr(aggregator, 'gene_model_grouped_counters') else [] + strategy_names = aggregator.grouping_strategy_names if hasattr(aggregator, 'grouping_strategy_names') else [] + model_constructor = GraphBasedModelConstructor(gene_info, loader.chr_record, args, aggregator.transcript_model_global_counter, aggregator.gene_model_global_counter, - transcript_id_distributor) + transcript_id_distributor, + transcript_grouped_counters=transcript_grouped, + 
gene_grouped_counters=gene_grouped, + grouping_strategy_names=strategy_names, + use_technical_replicas=sample.use_technical_replicas) model_constructor.process(assignment_storage) if args.check_canonical: io_support.add_canonical_info(model_constructor.transcript_model_storage, gene_info) @@ -243,8 +241,8 @@ def construct_models_in_parallel(sample, chr_id, dump_filename, args, read_group aggregator.global_counter.dump() aggregator.read_stat_counter.dump(read_stat_file) if construct_models: - if gffutils_db: - all_models, gene_info = create_extended_storage(gffutils_db, chr_id, current_chr_record, novel_model_storage) + if loader.genedb: + all_models, gene_info = create_extended_storage(loader.genedb, chr_id, loader.chr_record, novel_model_storage) if args.check_canonical: io_support.add_canonical_info(all_models, gene_info) tmp_extended_gff_printer.dump(gene_info, all_models) @@ -257,10 +255,42 @@ def construct_models_in_parallel(sample, chr_id, dump_filename, args, read_group return aggregator.read_stat_counter, transcript_stat_counter +def filter_umis_in_parallel(sample, chr_id, args, edit_distance, output_filtered_reads=False): + transcript_type_dict = create_transcript_info_dict(args.genedb, [chr_id]) + umi_filtered_done = umi_filtered_lock_file_name(sample.out_umi_filtered_done, chr_id, edit_distance) + all_info_file_name = allinfo_file_name(sample.out_umi_filtered_done, chr_id, edit_distance) + stats_output_file_name = allinfo_stats_file_name(sample.out_umi_filtered_done, chr_id, edit_distance) + + if os.path.exists(umi_filtered_done): + if args.resume: + return all_info_file_name, stats_output_file_name, umi_filtered_done + os.remove(umi_filtered_done) + + logger.info("Filtering PCR duplicates for chromosome " + chr_id) + barcode_feature_table = {} + if args.barcode2spot: + for barcode2spot_file in args.barcode2spot: + barcode_feature_table.update(load_table(barcode2spot_file, 0, 1, '\t')) + + umi_filter = UMIFilter(args.umi_length, edit_distance) + filtered_reads = filtered_reads_file_name(sample.out_raw_file, chr_id) if output_filtered_reads else None + umi_filter.process_single_chr(chr_id, sample.out_raw_file, + transcript_type_dict, + barcode_feature_table, + all_info_file_name, + filtered_reads, + stats_output_file_name) + open(umi_filtered_done, "w").close() + logger.info("PCR duplicates filtered for chromosome " + chr_id) + + return all_info_file_name, stats_output_file_name, umi_filtered_done + + class ReadAssignmentAggregator: - def __init__(self, args, sample, read_groups, gffutils_db=None, chr_id=None, gzipped=False): + def __init__(self, args, sample, read_groups, gffutils_db=None, chr_id=None, gzipped=False, grouping_strategy_names=None): self.args = args self.read_groups = read_groups + self.grouping_strategy_names = grouping_strategy_names if grouping_strategy_names else ["default"] self.common_header = "# Command line: " + args._cmd_line + "\n# IsoQuant version: " + args._version + "\n" self.io_support = IOSupport(self.args) @@ -271,14 +301,20 @@ def __init__(self, args, sample, read_groups, gffutils_db=None, chr_id=None, gzi self.transcript_set = set(get_all_chromosome_transcripts(gffutils_db, chr_id)) self.read_stat_counter = EnumStats() - self.corrected_bed_printer = BEDPrinter(sample.out_corrected_bed, - self.args, - print_corrected=True, - gzipped=gzipped) - printer_list = [self.corrected_bed_printer] - if self.args.genedb: + + printer_list = [] + self.corrected_bed_printer = None + if not self.args.no_large_files: + self.corrected_bed_printer = 
BEDPrinter(sample.out_corrected_bed, + self.args, + print_corrected=True, + gzipped=gzipped) + printer_list.append(self.corrected_bed_printer) + self.basic_printer = None + if self.args.genedb and not self.args.no_large_files: self.basic_printer = BasicTSVAssignmentPrinter(sample.out_assigned_tsv, self.args, self.io_support, additional_header=self.common_header, gzipped=gzipped) + sample.out_assigned_tsv_result = self.basic_printer.output_file_name printer_list.append(self.basic_printer) self.t2t_sqanti_printer = VoidTranscriptPrinter() if self.args.sqanti_output: @@ -312,47 +348,83 @@ def __init__(self, args, sample, read_groups, gffutils_db=None, chr_id=None, gzi self.global_counter.add_counters([self.exon_counter, self.intron_counter]) if self.args.read_group and self.args.genedb: - self.gene_grouped_counter = create_gene_counter(sample.out_gene_grouped_counts_tsv, - self.args.gene_quantification, - complete_feature_list=self.gene_set, - read_groups=self.read_groups) - self.transcript_grouped_counter = create_transcript_counter(sample.out_transcript_grouped_counts_tsv, - self.args.transcript_quantification, - complete_feature_list=self.transcript_set, - read_groups=self.read_groups) - self.global_counter.add_counters([self.gene_grouped_counter, self.transcript_grouped_counter]) - - if self.args.count_exons: - self.exon_grouped_counter = ExonCounter(sample.out_exon_grouped_counts_tsv) - self.intron_grouped_counter = IntronCounter(sample.out_intron_grouped_counts_tsv) - self.global_counter.add_counters([self.exon_grouped_counter, self.intron_grouped_counter]) + self.gene_grouped_counters = [] + self.transcript_grouped_counters = [] + self.exon_grouped_counters = [] + self.intron_grouped_counters = [] + + for group_idx, strategy_name in enumerate(self.grouping_strategy_names): + # Add strategy name as suffix to output file + gene_out_file = f"{sample.out_gene_grouped_counts_tsv}.{strategy_name}" + transcript_out_file = f"{sample.out_transcript_grouped_counts_tsv}.{strategy_name}" + + gene_counter = create_gene_counter(gene_out_file, + self.args.gene_quantification, + complete_feature_list=self.gene_set, + read_groups=self.read_groups, + group_index=group_idx) + transcript_counter = create_transcript_counter(transcript_out_file, + self.args.transcript_quantification, + complete_feature_list=self.transcript_set, + read_groups=self.read_groups, + group_index=group_idx) + + self.gene_grouped_counters.append(gene_counter) + self.transcript_grouped_counters.append(transcript_counter) + self.global_counter.add_counters([gene_counter, transcript_counter]) + + if self.args.count_exons: + exon_out_file = f"{sample.out_exon_grouped_counts_tsv}.{strategy_name}" + intron_out_file = f"{sample.out_intron_grouped_counts_tsv}.{strategy_name}" + exon_counter = ExonCounter(exon_out_file, group_index=group_idx) + intron_counter = IntronCounter(intron_out_file, group_index=group_idx) + self.exon_grouped_counters.append(exon_counter) + self.intron_grouped_counters.append(intron_counter) + self.global_counter.add_counters([exon_counter, intron_counter]) if self.args.read_group and not self.args.no_model_construction: - self.transcript_model_grouped_counter = create_transcript_counter( - sample.out_transcript_model_grouped_counts_tsv, - self.args.transcript_quantification, - read_groups=self.read_groups) - self.gene_model_grouped_counter = create_gene_counter( - sample.out_gene_model_grouped_counts_tsv, - self.args.gene_quantification, - read_groups=self.read_groups) - 
self.transcript_model_global_counter.add_counters([self.transcript_model_grouped_counter]) - self.gene_model_global_counter.add_counters([self.gene_model_grouped_counter]) + self.transcript_model_grouped_counters = [] + self.gene_model_grouped_counters = [] + + for group_idx, strategy_name in enumerate(self.grouping_strategy_names): + transcript_model_out_file = f"{sample.out_transcript_model_grouped_counts_tsv}.{strategy_name}" + gene_model_out_file = f"{sample.out_gene_model_grouped_counts_tsv}.{strategy_name}" + + transcript_model_counter = create_transcript_counter( + transcript_model_out_file, + self.args.transcript_quantification, + read_groups=self.read_groups, + group_index=group_idx) + gene_model_counter = create_gene_counter( + gene_model_out_file, + self.args.gene_quantification, + read_groups=self.read_groups, + group_index=group_idx) + + self.transcript_model_grouped_counters.append(transcript_model_counter) + self.gene_model_grouped_counters.append(gene_model_counter) + self.transcript_model_global_counter.add_counters([transcript_model_counter]) + self.gene_model_global_counter.add_counters([gene_model_counter]) # Class for processing all samples against gene database class DatasetProcessor: def __init__(self, args): self.args = args + self.input_data = args.input_data self.args.gunzipped_reference = None self.common_header = "# Command line: " + args._cmd_line + "\n# IsoQuant version: " + args._version + "\n" self.io_support = IOSupport(self.args) self.all_read_groups = set() self.alignment_stat_counter = EnumStats() + self.transcript_type_dict = {} if args.genedb: logger.info("Loading gene database from " + self.args.genedb) self.gffutils_db = gffutils.FeatureDB(self.args.genedb) + # TODO remove + if self.args.mode.needs_pcr_deduplication(): + self.transcript_type_dict = create_transcript_info_dict(self.args.genedb) else: self.gffutils_db = None @@ -371,13 +443,22 @@ def clean_up(self): if os.path.exists(self.args.gunzipped_reference): os.remove(self.args.gunzipped_reference) + for sample in self.input_data.samples: + if not self.args.read_assignments and not self.args.keep_tmp: + for f in glob.glob(sample.out_raw_file + "_*"): + os.remove(f) + for f in glob.glob(sample.read_group_file + "*"): + os.remove(f) + if self.args.mode.needs_pcr_deduplication(): + os.remove(umi_filtered_global_lock_file_name(sample.out_umi_filtered_done)) + def process_all_samples(self, input_data): logger.info("Processing " + proper_plural_form("experiment", len(input_data.samples))) logger.info("Secondary alignments will%s be used" % ("" if self.args.use_secondary else " not")) for sample in input_data.samples: self.process_sample(sample) self.clean_up() - logger.info("Processed " + proper_plural_form("experiment", len(input_data.samples))) + logger.info("Processed " + proper_plural_form("experiment", len(self.input_data.samples))) # Run through all genes in db and count stats according to alignments given in bamfile_name def process_sample(self, sample): @@ -385,20 +466,41 @@ def process_sample(self, sample): logger.info("Experiment has " + proper_plural_form("BAM file", len(sample.file_list)) + ": " + ", ".join( map(lambda x: x[0], sample.file_list))) self.chr_ids = self.get_chromosome_ids(sample) + logger.info("Total number of chromosomes to be processed %d: %s " % (len(self.chr_ids), ", ".join(map(lambda x: str(x), sorted(self.chr_ids))))) - self.args.use_technical_replicas = self.args.read_group == "file_name" and len(sample.file_list) > 1 + + # Check if file_name grouping is enabled for 
this sample (for technical replicas) + sample.use_technical_replicas = (len(sample.file_list) > 1 and + self.args.read_group is not None and + "file_name" in self.args.read_group) self.all_read_groups = set() - if self.args.resume and os.path.exists(sample.read_group_file + "_lock"): + fname = read_group_lock_filename(sample) + if self.args.resume and os.path.exists(fname): logger.info("Read group table was split during the previous run, existing files will be used") else: - fname = read_group_lock_filename(sample) if os.path.exists(fname): os.remove(fname) prepare_read_groups(self.args, sample) open(fname, "w").close() + if self.args.mode.needs_pcr_deduplication(): + if self.args.barcoded_reads: + sample.barcoded_reads = self.args.barcoded_reads + + split_barcodes_dict = {} + for chr_id in self.get_chr_list(): + split_barcodes_dict[chr_id] = sample.barcodes_split_reads + "_" + chr_id + barcode_split_done = split_barcodes_lock_filename(sample) + if self.args.resume and os.path.exists(barcode_split_done): + logger.info("Barcode table was split during the previous run, existing files will be used") + else: + if os.path.exists(barcode_split_done): + os.remove(barcode_split_done) + self.split_read_barcode_table(sample, split_barcodes_dict) + open(barcode_split_done, "w").close() + if self.args.read_assignments: saves_file = self.args.read_assignments[0] logger.info('Using read assignments from {}*'.format(saves_file)) @@ -410,8 +512,17 @@ def process_sample(self, sample): if not self.args.keep_tmp: logger.info("To keep these intermediate files for debug purposes use --keep_tmp flag") + if self.args.mode.needs_pcr_deduplication(): + self.filter_umis(sample) + total_assignments, polya_found, self.all_read_groups = self.load_read_info(saves_file) + if self.args.mode.needs_pcr_deduplication(): + # move clean-up somewhere else + for bc_split_file in split_barcodes_dict.values(): + os.remove(bc_split_file) + os.remove(barcode_split_done) + polya_fraction = polya_found / total_assignments if total_assignments > 0 else 0.0 logger.info("Total assignments used for analysis: %d, polyA tail detected in %d (%.1f%%)" % (total_assignments, polya_found, polya_fraction * 100.0)) @@ -433,11 +544,6 @@ def process_sample(self, sample): self.args.polya_requirement_strategy) self.process_assigned_reads(sample, saves_file) - if not self.args.read_assignments and not self.args.keep_tmp: - for f in glob.glob(saves_file + "_*"): - os.remove(f) - for f in glob.glob(sample.read_group_file + "*"): - os.remove(f) logger.info("Processed experiment " + sample.prefix) def keep_only_defined_chromosomes(self, chr_set: set): @@ -492,6 +598,7 @@ def get_chromosome_ids(self, sample): logger.warning("Chromosome list from the gene annotation is not the same as the chromosome list from" " the reference genomes or BAM file(s). Please, check you input data.") logger.warning("Only %d overlapping chromosomes will be processed." 
% len(common_overlap)) + return list(sorted( common_overlap, key=lambda x: len(self.reference_record_dict[x]), @@ -577,7 +684,9 @@ def process_assigned_reads(self, sample, dump_filename): ("off" if self.args.no_model_construction else "on")) # set up aggregators and outputs - aggregator = ReadAssignmentAggregator(self.args, sample, self.all_read_groups, gzipped=self.args.gzipped) + grouping_strategy_names = get_grouping_strategy_names(self.args) + aggregator = ReadAssignmentAggregator(self.args, sample, self.all_read_groups, gzipped=self.args.gzipped, + grouping_strategy_names=grouping_strategy_names) transcript_stat_counter = EnumStats() gff_printer = VoidTranscriptPrinter() @@ -599,7 +708,10 @@ def process_assigned_reads(self, sample, dump_filename): # not intended for dumping transcript models directly exon_id_storage = FeatureIdStorage(SimpleIDDistributor()) gff_printer = GFFPrinter( - sample.out_dir, sample.prefix, exon_id_storage, header=self.common_header, gzipped=self.args.gzipped + sample.out_dir, sample.prefix, exon_id_storage, + output_r2t=not self.args.no_large_files, + header=self.common_header, + gzipped=self.args.gzipped ) if self.args.genedb: extended_gff_printer = GFFPrinter( @@ -610,7 +722,7 @@ def process_assigned_reads(self, sample, dump_filename): model_gen = ( construct_models_in_parallel, - (SampleData(sample.file_list, f"{sample.prefix}_{chr_id}", sample.out_dir, sample.readable_names_dict, sample.illumina_bam) for chr_id in chr_ids), + (SampleData(sample.file_list, f"{sample.prefix}_{chr_id}", sample.out_dir, sample.readable_names_dict, sample.illumina_bam, sample.barcoded_reads) for chr_id in chr_ids), chr_ids, itertools.repeat(dump_filename), itertools.repeat(self.args), @@ -648,27 +760,108 @@ def process_assigned_reads(self, sample, dump_filename): logger.info("Counts for generated transcript models are saves to: " + aggregator.transcript_model_counter.output_counts_file_name) if self.args.read_group: - logger.info("Grouped counts for generated transcript models are saves to: " + - aggregator.transcript_model_grouped_counter.output_counts_file_name) + for counter in aggregator.transcript_model_grouped_counters: + logger.info("Grouped counts for generated transcript models are saves to: " + + counter.output_counts_file_name) aggregator.transcript_model_global_counter.finalize(self.args) aggregator.gene_model_global_counter.finalize(self.args) def finalize(self, aggregator): - if self.args.genedb: + if aggregator.basic_printer: logger.info("Read assignments are stored in " + aggregator.basic_printer.output_file_name + (".gz" if self.args.gzipped else "")) + if self.args.genedb: aggregator.read_stat_counter.print_start("Read assignment statistics") logger.info("Gene counts are stored in " + aggregator.gene_counter.output_counts_file_name) logger.info("Transcript counts are stored in " + aggregator.transcript_counter.output_counts_file_name) if self.args.read_group: - logger.info("Grouped gene counts are saves to: " + - aggregator.gene_grouped_counter.output_counts_file_name) - logger.info("Grouped transcript counts are saves to: " + - aggregator.transcript_grouped_counter.output_counts_file_name) + for counter in aggregator.gene_grouped_counters: + logger.info("Grouped gene counts are saves to: " + counter.output_counts_file_name) + for counter in aggregator.transcript_grouped_counters: + logger.info("Grouped transcript counts are saves to: " + counter.output_counts_file_name) logger.info("Counts can be converted to other formats using 
src/convert_grouped_counts.py") aggregator.global_counter.finalize(self.args) + def filter_umis(self, sample): + umi_filtering_done = umi_filtered_global_lock_file_name(sample.out_umi_filtered_done) + if os.path.exists(umi_filtering_done): + if self.args.resume: + logger.info("UMI filtering detecting, skipping") + return + os.remove(umi_filtering_done) + + # edit distances for UMI filtering, first one will be used for counts + umi_ed_dict = {IsoQuantMode.bulk: [], + IsoQuantMode.tenX_v3: [3], + IsoQuantMode.visium_5prime: [3], + IsoQuantMode.curio: [3], + IsoQuantMode.visium_hd: [4], + IsoQuantMode.stereoseq: [4], + IsoQuantMode.stereoseq_nosplit: [4]} + + for i, edit_distance in enumerate(umi_ed_dict[self.args.mode]): + logger.info("Filtering PCR duplicates with edit distance %d" % edit_distance) + umi_ed_filtering_done = umi_filtered_lock_file_name(sample.out_umi_filtered_done, "", edit_distance) + if os.path.exists(umi_ed_filtering_done): + if self.args.resume: + logger.info("Filtering was done previously, skipping edit distance %d" % edit_distance) + return + os.remove(umi_ed_filtering_done) + + output_prefix = sample.out_umi_filtered + (".ALL" if edit_distance < 0 else ".ED%d" % edit_distance) + logger.info("Results will be saved to %s" % output_prefix) + output_filtered_reads = i == 0 + + umi_gen = ( + filter_umis_in_parallel, + itertools.repeat(sample), + self.get_chr_list(), + itertools.repeat(self.args), + itertools.repeat(edit_distance), + itertools.repeat(output_filtered_reads), + ) + + if self.args.threads > 1: + with ProcessPoolExecutor(max_workers=self.args.threads) as proc: + results = proc.map(*umi_gen, chunksize=1) + else: + results = map(*umi_gen) + + stat_dict = defaultdict(int) + files_to_remove = [] + with open(output_prefix + ".allinfo", "w") as outf: + for all_info_file_name, stats_output_file_name, umi_filter_done in results: + shutil.copyfileobj(open(all_info_file_name, "r"), outf) + for l in open(stats_output_file_name, "r"): + v = l.strip().split("\t") + if len(v) != 2: + continue + stat_dict[v[0]] += int(v[1]) + files_to_remove.append(all_info_file_name) + files_to_remove.append(stats_output_file_name) + files_to_remove.append(umi_filter_done) + + logger.info("PCR duplicates filtered with edit distance %d, filtering stats:" % edit_distance) + with open(output_prefix + ".stats.tsv", "w") as outf: + for k, v in stat_dict.items(): + logger.info(" %s: %d" % (k, v)) + outf.write("%s\t%d\n" % (k, v)) + + open(umi_ed_filtering_done, "w").close() + for f in files_to_remove: + os.remove(f) + + open(umi_filtering_done, "w").close() + + def split_read_barcode_table(self, sample, split_barcodes_file_names): + logger.info("Splitting read barcode table") + # TODO: untrusted UMIs and third party format, both can be done by passing parsing function instead of columns + split_read_table_parallel(sample, sample.barcoded_reads, split_barcodes_file_names, self.args.threads, + read_column=0, group_columns=(1, 2, 3, 4), delim='\t') + logger.info("Read barcode table was split") + + @staticmethod def load_read_info(dump_filename): info_loader = open(dump_filename + "_info", "rb") @@ -679,11 +872,12 @@ def load_read_info(dump_filename): return total_assignments, polya_assignments, all_read_groups def merge_assignments(self, sample, aggregator, chr_ids): - if self.args.genedb: + if self.args.genedb and aggregator.basic_printer: merge_files(sample.out_assigned_tsv, sample.prefix, chr_ids, aggregator.basic_printer.output_file, copy_header=False) - merge_files(sample.out_corrected_bed, 
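filter_umis above aggregates per-chromosome results: each worker returns an allinfo file and a small two-column stats TSV, the parent concatenates the former and sums the latter before writing a single stats file per edit distance. A self-contained sketch of that stats-merging step (file names are illustrative):

from collections import defaultdict


def merge_stat_files(stat_file_names, output_tsv):
    # Sum "<category>\t<count>" rows produced by per-chromosome workers.
    stat_dict = defaultdict(int)
    for file_name in stat_file_names:
        with open(file_name) as inf:
            for line in inf:
                v = line.strip().split("\t")
                if len(v) != 2:
                    continue
                stat_dict[v[0]] += int(v[1])
    with open(output_tsv, "w") as outf:
        for category, count in stat_dict.items():
            outf.write("%s\t%d\n" % (category, count))
    return stat_dict


# e.g. merge_stat_files(["chr1.stats.tsv", "chr2.stats.tsv"], "sample.UMI_filtered.ED4.stats.tsv")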
sample.prefix, chr_ids, - aggregator.corrected_bed_printer.output_file, copy_header=False) + if aggregator.corrected_bed_printer: + merge_files(sample.out_corrected_bed, sample.prefix, chr_ids, + aggregator.corrected_bed_printer.output_file, copy_header=False) for counter in aggregator.global_counter.counters: unaligned = self.alignment_stat_counter.stats_dict[AlignmentType.unaligned] @@ -692,7 +886,8 @@ def merge_assignments(self, sample, aggregator, chr_ids): def merge_transcript_models(self, label, aggregator, chr_ids, gff_printer): merge_files(gff_printer.model_fname, label, chr_ids, gff_printer.out_gff, copy_header=False) - merge_files(gff_printer.r2t_fname, label, chr_ids, gff_printer.out_r2t, copy_header=False) + if gff_printer.output_r2t: + merge_files(gff_printer.r2t_fname, label, chr_ids, gff_printer.out_r2t, copy_header=False) for counter in aggregator.transcript_model_global_counter.counters: unaligned = self.alignment_stat_counter.stats_dict[AlignmentType.unaligned] merge_counts(counter, label, chr_ids, unaligned) diff --git a/src/file_naming.py b/src/file_naming.py new file mode 100644 index 00000000..48cdf568 --- /dev/null +++ b/src/file_naming.py @@ -0,0 +1,68 @@ +############################################################################ +# Copyright (c) 2025 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + + +import os + + +def convert_chr_id_to_file_name_str(chr_id: str): + return chr_id.replace('/', '_') + + +def reads_collected_lock_file_name(sample_out_raw, chr_id): + return "{}_{}_collected".format(sample_out_raw, convert_chr_id_to_file_name_str(chr_id)) + + +def reads_processed_lock_file_name(dump_filename, chr_id): + chr_dump_file = dump_filename + "_" + convert_chr_id_to_file_name_str(chr_id) + return "{}_processed".format(chr_dump_file) + + +def read_group_lock_filename(sample): + return sample.read_group_file + "_lock" + + +def split_barcodes_lock_filename(sample): + return sample.barcodes_split_reads + "_lock" + + +def clean_locks(chr_ids, base_name, fname_function): + for chr_id in chr_ids: + fname = fname_function(base_name, chr_id) + if os.path.exists(fname): + os.remove(fname) + + +def saves_file_name(out_raw_file: str, chr_id: str): + return out_raw_file + "_" + convert_chr_id_to_file_name_str(chr_id) + + +def multimappers_file_name(out_raw_file: str, chr_id: str): + return out_raw_file + "_multimappers_" + convert_chr_id_to_file_name_str(chr_id) + + +def filtered_reads_file_name(out_raw_file: str, chr_id: str): + return out_raw_file + "_filtered_" + chr_id + + +def umi_filtered_reads_file_name(out_umi_filtered_tmp: str, chr_id:str, edit_distance:int): + return out_umi_filtered_tmp + ("_%s_ED%d" % (chr_id, edit_distance)) + + +def umi_filtered_lock_file_name(out_umi_filtered_done: str, chr_id: str, edit_distance: int): + return out_umi_filtered_done + ("_%s_ED%d" % (chr_id, edit_distance)) + + +def allinfo_file_name(out_umi_filtered_tmp: str, chr_id: str, edit_distance: int): + return umi_filtered_reads_file_name(out_umi_filtered_tmp, chr_id, edit_distance) + ".allinfo" + + +def allinfo_stats_file_name(out_umi_filtered_tmp: str, chr_id: str, edit_distance: int): + return umi_filtered_reads_file_name(out_umi_filtered_tmp, chr_id, edit_distance) + ".stats.tsv" + + +def umi_filtered_global_lock_file_name(out_umi_filtered_done: str): + return out_umi_filtered_done + ".lock" \ No newline at end of file diff --git a/src/gene_info.py b/src/gene_info.py 
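The helpers introduced in src/file_naming.py above are pure string composition; a quick illustration of the intermediate names they yield (the prefix below is made up):

def convert_chr_id_to_file_name_str(chr_id):
    return chr_id.replace('/', '_')


def umi_filtered_reads_file_name(out_umi_filtered_tmp, chr_id, edit_distance):
    return out_umi_filtered_tmp + ("_%s_ED%d" % (chr_id, edit_distance))


prefix = "aux/sample.UMI_filtered"  # hypothetical temporary prefix
print(umi_filtered_reads_file_name(prefix, "chr1", 4))               # aux/sample.UMI_filtered_chr1_ED4
print(umi_filtered_reads_file_name(prefix, "chr1", 4) + ".allinfo")  # as composed by allinfo_file_name
print(convert_chr_id_to_file_name_str("chrUn/scaffold_1"))           # chrUn_scaffold_1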
index e3bc3955..98530361 100644 --- a/src/gene_info.py +++ b/src/gene_info.py @@ -139,6 +139,49 @@ def to_str(self): return "%s\t%d\t%d\t%s\t%s\t%s" % (self.chr_id, self.start, self.end, self.strand, self.type, ",".join(self.gene_ids)) +class GeneList: + def __init__(self, gene_id_list, delta, chr_id, start, end): + self.gene_id_set = set(gene_id_list) + self.delta = delta + self.chr_id = chr_id + self.start = start + self.end = end + + def merge(self, other): + assert self.chr_id == other.chr_id + self.gene_id_set.update(other.gene_id_set) + self.start = min(self.start, other.start) + self.end = max(self.end, other.end) + + def overlaps(self, other): + return self.chr_id == other.chr_id and len(self.gene_id_set.intersection(other.gene_id_set)) > 0 + + @classmethod + def deserialize(cls, infile): + gene_info = cls.__new__(cls) + gene_info.delta = read_int(infile) + gene_info.gene_id_set = set() + + gene_count = read_int(infile) + for i in range(gene_count): + gene_id = read_string(infile) + gene_info.gene_id_set.add(gene_id) + gene_info.chr_id = read_string(infile) + gene_info.start = read_int(infile) + gene_info.end = read_int(infile) + + return gene_info + + def serialize(self, outfile): + write_int(self.delta, outfile) + write_int(len(self.gene_id_set), outfile) + for g in self.gene_id_set: + write_string(g.id, outfile) + write_string(self.chr_id, outfile) + write_int(self.start, outfile) + write_int(self.end, outfile) + + # All gene(s) information class GeneInfo: EXTRA_BASES_FOR_SEQ = 20 @@ -149,6 +192,8 @@ class GeneInfo: def __init__(self, gene_db_list, db, delta=0, prepare_profiles=True): if db is None: return + + assert gene_db_list # gffutils main structure self.db = db # list of genes in cluster diff --git a/src/graph_based_model_construction.py b/src/graph_based_model_construction.py index 731468a8..fc7df298 100644 --- a/src/graph_based_model_construction.py +++ b/src/graph_based_model_construction.py @@ -50,11 +50,19 @@ class GraphBasedModelConstructor: detected_known_isoforms = set() extended_transcript_ids = set() - def __init__(self, gene_info, chr_record, params, transcript_counter, gene_counter, id_distributor): + def __init__(self, gene_info, chr_record, args, transcript_counter, gene_counter, id_distributor, + transcript_grouped_counters=None, gene_grouped_counters=None, grouping_strategy_names=None, + use_technical_replicas=False): self.gene_info = gene_info self.chr_record = chr_record - self.params = params + self.args = args self.id_distributor = id_distributor + self.transcript_grouped_counters = transcript_grouped_counters if transcript_grouped_counters else [] + self.gene_grouped_counters = gene_grouped_counters if gene_grouped_counters else [] + self.grouping_strategy_names = grouping_strategy_names if grouping_strategy_names else [] + self.use_technical_replicas = use_technical_replicas + # Find file_name group index for technical replicas check + self.file_name_group_idx = self.grouping_strategy_names.index("file_name") if "file_name" in self.grouping_strategy_names else -1 self.strand_detector = StrandDetector(self.chr_record) self.intron_genes = defaultdict(set) @@ -67,8 +75,8 @@ def __init__(self, gene_info, chr_record, params, transcript_counter, gene_count self.known_isoforms_in_graph = {} self.known_introns = set() self.known_isoforms_in_graph_ids = {} - self.assigner = LongReadAssigner(self.gene_info, self.params) - self.profile_constructor = CombinedProfileConstructor(self.gene_info, self.params) + self.assigner = LongReadAssigner(self.gene_info, 
self.args) + self.profile_constructor = CombinedProfileConstructor(self.gene_info, self.args) self.transcript_model_storage = [] self.transcript_read_ids = defaultdict(list) @@ -117,9 +125,9 @@ def select_reference_gene(self, transcript_introns, transcript_range, transcript return None def process(self, read_assignment_storage): - self.intron_graph = IntronGraph(self.params, self.gene_info, read_assignment_storage) - self.path_processor = IntronPathProcessor(self.params, self.intron_graph) - self.path_storage = IntronPathStorage(self.params, self.path_processor) + self.intron_graph = IntronGraph(self.args, self.gene_info, read_assignment_storage) + self.path_processor = IntronPathProcessor(self.args, self.intron_graph) + self.path_storage = IntronPathStorage(self.args, self.path_processor) self.path_storage.fill(read_assignment_storage) self.known_isoforms_in_graph = self.get_known_spliced_isoforms(self.gene_info) self.known_introns = set(self.gene_info.intron_profiles.features) @@ -140,7 +148,7 @@ def process(self, read_assignment_storage): self.transcript_model_storage = transcript_joiner.join_transcripts() self.forward_counts(read_assignment_storage) - if self.params.sqanti_output: + if self.args.sqanti_output: self.compare_models_with_known() def forward_counts(self, read_assignments): @@ -154,8 +162,16 @@ def forward_counts(self, read_assignments): for read_assignment in self.transcript_read_ids[transcript_id]: read_id = read_assignment.read_id if self.read_assignment_counts[read_id] == 1: - self.transcript_counter.add_read_info_raw(read_id, [transcript_id], read_assignment.read_group) - self.gene_counter.add_read_info_raw(read_id, [gene_id], read_assignment.read_group) + # Add to ungrouped counters + self.transcript_counter.add_read_info_raw(read_id, [transcript_id], read_assignment.read_group[0]) + self.gene_counter.add_read_info_raw(read_id, [gene_id], read_assignment.read_group[0]) + # Add to each grouped counter with its corresponding group + for idx, counter in enumerate(self.transcript_grouped_counters): + if idx < len(read_assignment.read_group): + counter.add_read_info_raw(read_id, [transcript_id], read_assignment.read_group[idx]) + for idx, counter in enumerate(self.gene_grouped_counters): + if idx < len(read_assignment.read_group): + counter.add_read_info_raw(read_id, [gene_id], read_assignment.read_group[idx]) continue if read_id not in ambiguous_assignments: @@ -163,10 +179,19 @@ def forward_counts(self, read_assignments): ambiguous_assignments[read_id].append(transcript_id) for read_id in ambiguous_assignments.keys(): + read_groups = ambiguous_assignments[read_id][0] transcript_ids = ambiguous_assignments[read_id][1:] gene_ids = [transcript2gene[transcript_id] for transcript_id in transcript_ids] - self.transcript_counter.add_read_info_raw(read_id, transcript_ids, ambiguous_assignments[read_id][0]) - self.gene_counter.add_read_info_raw(read_id, gene_ids, ambiguous_assignments[read_id][0]) + # Add to ungrouped counters + self.transcript_counter.add_read_info_raw(read_id, transcript_ids, read_groups[0] if read_groups else "NA") + self.gene_counter.add_read_info_raw(read_id, gene_ids, read_groups[0] if read_groups else "NA") + # Add to each grouped counter with its corresponding group + for idx, counter in enumerate(self.transcript_grouped_counters): + if idx < len(read_groups): + counter.add_read_info_raw(read_id, transcript_ids, read_groups[idx]) + for idx, counter in enumerate(self.gene_grouped_counters): + if idx < len(read_groups): + counter.add_read_info_raw(read_id, 
gene_ids, read_groups[idx]) for r in read_assignments: if self.read_assignment_counts[r.read_id] > 0: continue @@ -264,7 +289,7 @@ def pre_filter_transcripts(self): internal_count_values.append(self.internal_counter[model.transcript_id]) internal_count_values = sorted(internal_count_values, reverse=True) - coverage_cutoff = self.params.min_novel_count + coverage_cutoff = self.args.min_novel_count # dirty hack to avoid slow assignment in chrM if len(internal_count_values) > 50: coverage_cutoff = internal_count_values[50] @@ -281,7 +306,7 @@ def pre_filter_transcripts(self): continue if (model.transcript_type != TranscriptModelType.known and - self.mapping_quality(model.transcript_id) < self.params.simple_models_mapq_cutoff): + self.mapping_quality(model.transcript_id) < self.args.simple_models_mapq_cutoff): self.delete_from_storage(model.transcript_id) continue filtered_storage.append(model) @@ -307,11 +332,11 @@ def filter_transcripts(self): if component_coverage == 0 or len(model.intron_path) == 0 or \ (len(model.intron_path) == 1 and self.intron_graph.is_monointron(model.intron_path[0])): component_coverage = self.intron_graph.get_overlapping_component_max_coverage((model.get_start(), model.get_end())) - novel_isoform_cutoff = max(self.params.min_novel_count, - self.params.min_mono_count_rel * component_coverage) + novel_isoform_cutoff = max(self.args.min_novel_count, + self.args.min_mono_count_rel * component_coverage) else: - novel_isoform_cutoff = max(self.params.min_novel_count, - self.params.min_novel_count_rel * component_coverage) + novel_isoform_cutoff = max(self.args.min_novel_count, + self.args.min_novel_count_rel * component_coverage) if model.transcript_id in to_substitute: #logger.debug("Novel model %s has a similar isoform %s" % (model.transcript_id, to_substitute[model.transcript_id])) @@ -328,7 +353,7 @@ def filter_transcripts(self): if len(model.exon_blocks) <= 2: mapq = self.mapping_quality(model.transcript_id) #logger.debug("Novel model %s has quality %.2f" % (model.transcript_id, mapq)) - if mapq < self.params.simple_models_mapq_cutoff: + if mapq < self.args.simple_models_mapq_cutoff: #logger.debug("Novel model %s has poor quality" % model.transcript_id) self.delete_from_storage(model.transcript_id) continue @@ -366,9 +391,9 @@ def detect_similar_isoforms(self, model_storage): for model in model_storage: if len(model.exon_blocks) <= 2 or model.transcript_id in to_substitute: continue - transcript_model_gene_info = GeneInfo.from_models([model], self.params.delta) - assigner = LongReadAssigner(transcript_model_gene_info, self.params) - profile_constructor = CombinedProfileConstructor(transcript_model_gene_info, self.params) + transcript_model_gene_info = GeneInfo.from_models([model], self.args.delta) + assigner = LongReadAssigner(transcript_model_gene_info, self.args) + profile_constructor = CombinedProfileConstructor(transcript_model_gene_info, self.args) for m in model_storage: if m.transcript_type == TranscriptModelType.known or m.transcript_id == model.transcript_id or \ @@ -449,7 +474,7 @@ def construct_fl_isoforms(self): # adding FL reference isoform if reference_isoform in GraphBasedModelConstructor.detected_known_isoforms: pass - elif count < self.params.min_known_count: + elif count < self.args.min_known_count: pass # logger.debug("uuu Isoform %s has low coverage %d" % (reference_isoform, count)) else: new_model = self.transcript_from_reference(reference_isoform) @@ -460,7 +485,7 @@ def construct_fl_isoforms(self): else: # adding FL novel isoform # 
component_coverage = self.intron_graph.get_max_component_coverage(intron_path) - novel_isoform_cutoff = self.params.min_novel_count + novel_isoform_cutoff = self.args.min_novel_count has_polyt = path[0][0] == VERTEX_polyt has_polya = path[-1][0] == VERTEX_polya @@ -474,20 +499,26 @@ def construct_fl_isoforms(self): # logger.debug("uuu Novel isoform %s has low coverage: %d\t%d" % (new_transcript_id, count, novel_isoform_cutoff)) pass elif (len(novel_exons) == 2 and - ((self.params.require_monointronic_polya and not polya_site) or transcript_clean_strand == '.')): + ((self.args.require_monointronic_polya and not polya_site) or transcript_clean_strand == '.')): # logger.debug("uuu Avoiding single intron %s isoform: %d\t%s" % (new_transcript_id, count, str(path))) pass - elif ((self.params.report_canonical_strategy == StrandnessReportingLevel.only_canonical + elif ((self.args.report_canonical_strategy == StrandnessReportingLevel.only_canonical and transcript_clean_strand == '.') - or (self.params.report_canonical_strategy == StrandnessReportingLevel.only_stranded + or (self.args.report_canonical_strategy == StrandnessReportingLevel.only_stranded and transcript_strand == '.')): logger.debug("Avoiding unreliable transcript with %d exons (strand cannot be detected)" % len(novel_exons)) pass else: - if self.params.use_technical_replicas and \ - len(set([a.read_group for a in self.path_storage.paths_to_reads[path]])) <= 1: - #logger.debug("%s was suspended due to technical replicas check" % new_transcript_id) - continue + if self.use_technical_replicas and self.file_name_group_idx >= 0: + # Check if reads come from same file (technical replicates) + read_assignments = self.path_storage.paths_to_reads[path] + if read_assignments: + file_groups = set([a.read_group[self.file_name_group_idx] + for a in read_assignments + if self.file_name_group_idx < len(a.read_group)]) + if len(file_groups) <= 1: + #logger.debug("%s was suspended due to technical replicas check" % new_transcript_id) + continue transcript_gene = self.select_reference_gene(intron_path, transcript_range, transcript_strand) if transcript_gene is None: @@ -525,7 +556,7 @@ def construct_assignment_based_isoforms(self, read_assignment_storage): for read_assignment in read_assignment_storage: if len(read_assignment.corrected_exons) <= 2 and \ - (read_assignment.multimapper or read_assignment.mapping_quality < self.params.simple_alignments_mapq_cutoff): + (read_assignment.multimapper or read_assignment.mapping_quality < self.args.simple_alignments_mapq_cutoff): continue if not read_assignment: @@ -571,23 +602,23 @@ def construct_assignment_based_isoforms(self, read_assignment_storage): # (read_assignment.read_id, refrenence_isoform_id)) spliced_isoform_reads[refrenence_isoform_id].append(read_assignment) - if self.params.requires_polya_for_construction and self.gene_info.isoform_strands[refrenence_isoform_id] == '-': + if self.args.requires_polya_for_construction and self.gene_info.isoform_strands[refrenence_isoform_id] == '-': if any(x.event_type == MatchEventSubtype.correct_polya_site_left for x in events): isoform_left_support[refrenence_isoform_id] += 1 - elif abs(self.gene_info.all_isoforms_exons[refrenence_isoform_id][0][0] - read_assignment.corrected_exons[0][0]) <= self.params.apa_delta: + elif abs(self.gene_info.all_isoforms_exons[refrenence_isoform_id][0][0] - read_assignment.corrected_exons[0][0]) <= self.args.apa_delta: isoform_left_support[refrenence_isoform_id] += 1 - if self.params.requires_polya_for_construction and 
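The technical-replica guard above now looks up the file_name strategy slot inside each read's read_group list rather than a single string. A tiny sketch of that check in isolation (the data values are invented):

def passes_technical_replica_check(read_groups_per_read, file_name_group_idx):
    # Keep a novel transcript only if its supporting reads come from more than one input file.
    file_groups = set(groups[file_name_group_idx]
                      for groups in read_groups_per_read
                      if file_name_group_idx < len(groups))
    return len(file_groups) > 1


print(passes_technical_replica_check([["CB1", "run1.bam"], ["CB2", "run2.bam"]], 1))  # True
print(passes_technical_replica_check([["CB1", "run1.bam"], ["CB2", "run1.bam"]], 1))  # False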
self.gene_info.isoform_strands[refrenence_isoform_id] == '+': + if self.args.requires_polya_for_construction and self.gene_info.isoform_strands[refrenence_isoform_id] == '+': if any(x.event_type == MatchEventSubtype.correct_polya_site_right for x in events): isoform_right_support[refrenence_isoform_id] += 1 - elif abs(self.gene_info.all_isoforms_exons[refrenence_isoform_id][-1][1] - read_assignment.corrected_exons[-1][1]) <= self.params.apa_delta: + elif abs(self.gene_info.all_isoforms_exons[refrenence_isoform_id][-1][1] - read_assignment.corrected_exons[-1][1]) <= self.args.apa_delta: isoform_right_support[refrenence_isoform_id] += 1 self.construct_monoexon_isoforms(mono_exon_isoform_reads, mono_exon_isoform_coverage, polya_sites) - if not self.params.fl_only: + if not self.args.fl_only: logger.debug("Constructing nonFL isoforms") self.construct_nonfl_isoforms(spliced_isoform_reads, isoform_left_support, isoform_right_support) - if self.params.report_novel_unspliced: + if self.args.report_novel_unspliced: self.construct_monoexon_novel(novel_mono_exon_reads) def collect_terminal_exons_from_graph(self): @@ -609,13 +640,13 @@ def is_internal_monoexonic_read(self, alignment, terminal_exons, forward=True): read_coordinates = alignment.corrected_exons[0] if forward: for e in terminal_exons: - if abs(e[1] - alignment.corrected_exons[-1][1]) <= self.params.apa_delta and \ - read_coordinates[0] >= e[0] - self.params.delta: + if abs(e[1] - alignment.corrected_exons[-1][1]) <= self.args.apa_delta and \ + read_coordinates[0] >= e[0] - self.args.delta: return True else: for e in terminal_exons: - if abs(e[0] - alignment.corrected_exons[0][0]) <= self.params.apa_delta and \ - read_coordinates[1] <= e[1] + self.params.delta: + if abs(e[0] - alignment.corrected_exons[0][0]) <= self.args.apa_delta and \ + read_coordinates[1] <= e[1] + self.args.delta: return True return False @@ -639,7 +670,7 @@ def construct_monoexon_novel(self, novel_mono_exon_reads): novel_monoexon.update(self.generate_monoexon_from_clustered(clustered_polyt_reads, False)) def generate_monoexon_from_clustered(self, clustered_reads, forward=True): - cutoff = self.params.min_novel_count + cutoff = self.args.min_novel_count result = set() for three_prime_pos in clustered_reads.keys(): count = len(clustered_reads[three_prime_pos]) @@ -684,7 +715,7 @@ def cluster_monoexons(self, grouped_reads): while grouped_reads: best_pair = max(grouped_reads.items(), key=lambda x:len(x[1])) top_position = best_pair[0] - for pos in range(top_position - self.params.apa_delta, top_position + self.params.apa_delta + 1): + for pos in range(top_position - self.args.apa_delta, top_position + self.args.apa_delta + 1): if pos in grouped_reads: clustered_counts[top_position] += grouped_reads[pos] del grouped_reads[pos] @@ -699,8 +730,8 @@ def construct_monoexon_isoforms(self, mono_exon_isoform_reads, mono_exon_isoform polya_support = polya_sites[isoform_id] # logger.debug(">> Monoexon transcript %s: %d\t%d\t%.4f\t%d" % (isoform_id, self.intron_graph.max_coverage, count, coverage, polya_support)) - if (count < self.params.min_known_count or coverage < self.params.min_mono_exon_coverage or - (self.params.require_monoexonic_polya and polya_support == 0)): + if (count < self.args.min_known_count or coverage < self.args.min_mono_exon_coverage or + (self.args.require_monoexonic_polya and polya_support == 0)): pass # logger.debug(">> Will NOT be added, abs cutoff=%d" % (self.params.min_known_count)) elif isoform_id not in 
GraphBasedModelConstructor.detected_known_isoforms: new_model = self.transcript_from_reference(isoform_id) @@ -724,7 +755,7 @@ def construct_nonfl_isoforms(self, spliced_isoform_reads, spliced_isoform_left_s intron_path = self.known_isoforms_in_graph_ids[isoform_id] #logger.debug("Known non-FL spliced isoform %s" % isoform_id) - if count < self.params.min_known_count or \ + if count < self.args.min_known_count or \ spliced_isoform_left_support[isoform_id] < 1 or \ spliced_isoform_right_support[isoform_id] < 1: pass @@ -751,9 +782,9 @@ def assign_reads_to_models(self, read_assignments): return logger.debug("Creating artificial GeneInfo from %d transcript models" % len(self.transcript_model_storage)) - transcript_model_gene_info = GeneInfo.from_models(self.transcript_model_storage, self.params.delta) - assigner = LongReadAssigner(transcript_model_gene_info, self.params, quick_mode=True) - profile_constructor = CombinedProfileConstructor(transcript_model_gene_info, self.params) + transcript_model_gene_info = GeneInfo.from_models(self.transcript_model_storage, self.args.delta) + assigner = LongReadAssigner(transcript_model_gene_info, self.args, quick_mode=True) + profile_constructor = CombinedProfileConstructor(transcript_model_gene_info, self.args) for assignment in read_assignments: read_id = assignment.read_id @@ -765,7 +796,7 @@ def assign_reads_to_models(self, read_assignments): # logger.debug("# Checking read %s: %s" % (assignment.read_id, str(read_exons))) model_combined_profile = profile_constructor.construct_profiles(read_exons, assignment.polya_info, []) model_assignment = assigner.assign_to_isoform(assignment.read_id, model_combined_profile) - model_assignment.read_group = assignment.read_group + model_assignment.read_group = assignment.read_group # Full list, not just [0] # check that no serious contradiction occurs if model_assignment.assignment_type.is_consistent(): matched_isoforms = [m.assigned_transcript for m in model_assignment.isoform_matches] @@ -789,11 +820,11 @@ def correct_novel_transcript_ends(self, transcript_model, assigned_reads): for assignment in assigned_reads: read_exons = assignment.corrected_exons - if abs(read_exons[0][0] - transcript_start) <= self.params.apa_delta: + if abs(read_exons[0][0] - transcript_start) <= self.args.apa_delta: start_supported = True if not start_supported and read_exons[0][0] < transcript_model.exon_blocks[0][1]: read_starts.add(read_exons[0][0]) - if abs(read_exons[-1][1] - transcript_end) <= self.params.apa_delta: + if abs(read_exons[-1][1] - transcript_end) <= self.args.apa_delta: end_supported = True if not end_supported and read_exons[-1][1] > transcript_model.exon_blocks[-1][0]: read_ends[read_exons[-1][1]] += 1 diff --git a/src/input_data_storage.py b/src/input_data_storage.py index cdff3164..381e0d8f 100644 --- a/src/input_data_storage.py +++ b/src/input_data_storage.py @@ -30,7 +30,7 @@ def needs_mapping(self): class SampleData: - def __init__(self, file_list, prefix, out_dir, readable_names_dict, illumina_bam): + def __init__(self, file_list, prefix, out_dir, readable_names_dict, illumina_bam, barcoded_reads=None): # list of lists, since each sample may contain several libraries, and each library may contain 2 files (paired) self.file_list = file_list self.readable_names_dict = readable_names_dict @@ -38,6 +38,11 @@ def __init__(self, file_list, prefix, out_dir, readable_names_dict, illumina_bam self.prefix = prefix self.out_dir = out_dir self.aux_dir = os.path.join(self.out_dir, "aux") + if not barcoded_reads: + 
self.barcoded_reads = [] + else: + self.barcoded_reads = barcoded_reads + self.use_technical_replicas = False # Will be set by DatasetProcessor self._init_paths() def _make_path(self, name): @@ -48,6 +53,7 @@ def _make_aux_path(self, name): def _init_paths(self): self.out_assigned_tsv = self._make_path(self.prefix + ".read_assignments.tsv") + self.out_assigned_tsv_result = self.out_assigned_tsv self.out_raw_file = self._make_aux_path(self.prefix + ".save") self.read_group_file = self._make_aux_path(self.prefix + ".read_group") self.out_corrected_bed = self._make_path(self.prefix + ".corrected_reads.bed") @@ -65,6 +71,13 @@ def _init_paths(self): self.out_exon_grouped_counts_tsv = self._make_path(self.prefix + ".exon_grouped") self.out_intron_grouped_counts_tsv = self._make_path(self.prefix + ".intron_grouped") self.out_t2t_tsv = self._make_path(self.prefix + ".novel_vs_known.SQANTI-like.tsv") + self.barcodes_tsv = self._make_path(self.prefix + ".barcoded_reads") + self.barcodes_done = self._make_aux_path(self.prefix + ".barcodes_done") + self.barcodes_split_reads = self._make_aux_path(self.prefix + ".split_barcodes") + self.out_umi_filtered = self._make_path(self.prefix + ".UMI_filtered") + self.out_umi_filtered_tmp = self._make_aux_path(self.prefix + ".UMI_filtered") + self.out_umi_filtered_done= self._make_aux_path(self.prefix + ".UMI_filtered.done") + self.split_reads_fasta = self._make_path(self.prefix + ".split_reads") class InputDataStorage: diff --git a/src/isoform_assignment.py b/src/isoform_assignment.py index db1ba011..c07825e3 100644 --- a/src/isoform_assignment.py +++ b/src/isoform_assignment.py @@ -486,6 +486,8 @@ def __init__(self, read_assignment): self.penalty_score = 0.0 self.isoforms = [] self.genes = [] + self.barcode = read_assignment.barcode if hasattr(read_assignment, 'barcode') else None + self.umi = read_assignment.umi if hasattr(read_assignment, 'umi') else None if read_assignment.isoform_matches: gene_set = set() @@ -575,7 +577,9 @@ def deserialize_from_read_assignment(cls, infile): read_int_neg(infile) read_int_neg(infile) read_int_neg(infile) - read_string(infile) + read_list(infile, read_string) # read_group is now a list + read_string_or_none(infile) + read_string_or_none(infile) read_string(infile) read_string(infile) read_assignment.chr_id = read_string(infile) @@ -634,7 +638,9 @@ def __init__(self, read_id, assignment_type, match=None): self.polyA_found = False self.cage_found = False self.polya_info = None - self.read_group = "NA" + self.read_group = [] + self.barcode = None # Cell/spatial barcode + self.umi = None # Unique molecular identifier self.mapped_strand = "." self.strand = "." self.chr_id = "." 
@@ -678,7 +684,9 @@ def deserialize(cls, infile, gene_info): read_assignment.polyA_found = bool_arr[1] read_assignment.cage_found = bool_arr[2] read_assignment.polya_info = PolyAInfo(read_int_neg(infile), read_int_neg(infile), read_int_neg(infile), read_int_neg(infile)) - read_assignment.read_group = read_string(infile) + read_assignment.read_group = read_list(infile, read_string) + read_assignment.barcode = read_string_or_none(infile) + read_assignment.umi = read_string_or_none(infile) read_assignment.mapped_strand = read_string(infile) read_assignment.strand = read_string(infile) read_assignment.chr_id = read_string(infile) @@ -705,7 +713,9 @@ def serialize(self, outfile): write_int_neg(self.polya_info.external_polyt_pos, outfile) write_int_neg(self.polya_info.internal_polya_pos, outfile) write_int_neg(self.polya_info.internal_polyt_pos, outfile) - write_string(self.read_group, outfile) + write_list(self.read_group, outfile, write_string) + write_string_or_none(self.barcode, outfile) + write_string_or_none(self.umi, outfile) write_string(self.mapped_strand, outfile) write_string(self.strand, outfile) write_string(self.chr_id, outfile) diff --git a/src/long_read_assigner.py b/src/long_read_assigner.py index 81630c85..fa5f62b7 100644 --- a/src/long_read_assigner.py +++ b/src/long_read_assigner.py @@ -447,6 +447,7 @@ def assign_to_isoform(self, read_id, combined_read_profile): assignment = ReadAssignment(read_id, ReadAssignmentType.noninformative, IsoformMatch(MatchClassification.intergenic)) elif all(el != 1 for el in read_split_exon_profile.gene_profile): # logger.debug("EMPTY - intronic") + # TODO: match to a gene assignment = ReadAssignment(read_id, ReadAssignmentType.noninformative, IsoformMatch(MatchClassification.genic_intron)) else: # logger.debug("EMPTY - genic") diff --git a/src/long_read_counter.py b/src/long_read_counter.py index 107810d0..e73e7f51 100644 --- a/src/long_read_counter.py +++ b/src/long_read_counter.py @@ -232,10 +232,11 @@ def add_unaligned(self, n_reads=1): # get_feature_id --- function that returns feature id form IsoformMatch object class AssignedFeatureCounter(AbstractCounter): def __init__(self, output_prefix, assignment_extractor, read_groups, read_counter, - all_features=None): + all_features=None, group_index: int = 0): AbstractCounter.__init__(self, output_prefix, not read_groups) self.assignment_extractor = assignment_extractor self.all_features = set(all_features) if all_features is not None else set() + self.group_index = group_index # Index in read_group list to use if not read_groups: self.group_numeric_ids = {AbstractReadGrouper.default_group_id: 0} self.ordered_groups = [AbstractReadGrouper.default_group_id] @@ -258,7 +259,7 @@ def __init__(self, output_prefix, assignment_extractor, read_groups, read_counte self.usable_file_name = self.output_counts_file_name + ".usable" def add_read_info(self, read_assignment=None): - group_id = AbstractReadGrouper.default_group_id if self.ignore_read_groups else read_assignment.read_group + group_id = AbstractReadGrouper.default_group_id if self.ignore_read_groups else read_assignment.read_group[self.group_index] group_id = self.group_numeric_ids[group_id] if not read_assignment: @@ -322,7 +323,7 @@ def add_read_info_raw(self, read_id, feature_ids, group_id=AbstractReadGrouper.d self.reads_for_tpm[group_id] += 1 def add_unassigned(self, read_assignment): - group_id = AbstractReadGrouper.default_group_id if self.ignore_read_groups else read_assignment.read_group + group_id = AbstractReadGrouper.default_group_id if 
self.ignore_read_groups else read_assignment.read_group[self.group_index] group_id = self.group_numeric_ids[group_id] self.not_assigned_reads += 1 self.reads_for_tpm[group_id] += 1 @@ -429,17 +430,17 @@ def finalize(self, args=None): convert_to_mtx(self.output_file, self.output_tpm_prefix, convert_to_tpm=True, usable_reads_per_group=reads_for_tpm) -def create_gene_counter(output_file_name, strategy, complete_feature_list=None, read_groups=None): +def create_gene_counter(output_file_name, strategy, complete_feature_list=None, read_groups=None, group_index: int = 0): read_weight_counter = ReadWeightCounter(strategy) return AssignedFeatureCounter(output_file_name, GeneAssignmentExtractor, - read_groups, read_weight_counter, complete_feature_list) + read_groups, read_weight_counter, complete_feature_list, group_index) def create_transcript_counter(output_file_name, strategy, complete_feature_list=None, - read_groups=None): + read_groups=None, group_index: int = 0): read_weight_counter = ReadWeightCounter(strategy) return AssignedFeatureCounter(output_file_name, TranscriptAssignmentExtractor, - read_groups, read_weight_counter, complete_feature_list) + read_groups, read_weight_counter, complete_feature_list, group_index) @unique @@ -487,8 +488,9 @@ def convert_ungrouped_to_tpm(counts_file_name, output_tpm_file_name, normalizati # count simple features inclusion/exclusion (exons / introns) class ProfileFeatureCounter(AbstractCounter): - def __init__(self, output_prefix, ignore_read_groups=False): + def __init__(self, output_prefix, ignore_read_groups=False, group_index: int = 0): AbstractCounter.__init__(self, output_prefix, ignore_read_groups) + self.group_index = group_index # Index in read_group list to use # feature_id -> (group_id -> count) self.inclusion_feature_counter = defaultdict(lambda: IncrementalDict(int)) self.exclusion_feature_counter = defaultdict(lambda: IncrementalDict(int)) @@ -551,25 +553,25 @@ def is_assigned_to_gene(assignment): class ExonCounter(ProfileFeatureCounter): - def __init__(self, output_prefix, ignore_read_groups=False): - ProfileFeatureCounter.__init__(self, output_prefix, ignore_read_groups) + def __init__(self, output_prefix, ignore_read_groups=False, group_index: int = 0): + ProfileFeatureCounter.__init__(self, output_prefix, ignore_read_groups, group_index) def add_read_info(self, read_assignment): if not ProfileFeatureCounter.is_valid(read_assignment) or not ProfileFeatureCounter.is_assigned_to_gene(read_assignment): return - group_id = AbstractReadGrouper.default_group_id if self.ignore_read_groups else read_assignment.read_group + group_id = AbstractReadGrouper.default_group_id if self.ignore_read_groups else read_assignment.read_group[self.group_index] self.add_read_info_from_profile(read_assignment.exon_gene_profile, read_assignment.strand, read_assignment.gene_info.exon_property_map, group_id) class IntronCounter(ProfileFeatureCounter): - def __init__(self, output_prefix, ignore_read_groups=False): - ProfileFeatureCounter.__init__(self, output_prefix, ignore_read_groups) + def __init__(self, output_prefix, ignore_read_groups=False, group_index: int = 0): + ProfileFeatureCounter.__init__(self, output_prefix, ignore_read_groups, group_index) def add_read_info(self, read_assignment): if not ProfileFeatureCounter.is_valid(read_assignment) or not ProfileFeatureCounter.is_assigned_to_gene(read_assignment): return - group_id = AbstractReadGrouper.default_group_id if self.ignore_read_groups else read_assignment.read_group + group_id = 
AbstractReadGrouper.default_group_id if self.ignore_read_groups else read_assignment.read_group[self.group_index] self.add_read_info_from_profile(read_assignment.intron_gene_profile, read_assignment.strand, read_assignment.gene_info.intron_property_map, group_id) diff --git a/src/modes.py b/src/modes.py new file mode 100644 index 00000000..77305cf8 --- /dev/null +++ b/src/modes.py @@ -0,0 +1,46 @@ +############################################################################ +# Copyright (c) 2025 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + +from enum import Enum, unique + + +@unique +class IsoQuantMode(Enum): + bulk = 1 + tenX_v3 = 2 + curio = 3 + stereoseq_nosplit = 4 + stereoseq = 5 + visium_hd = 6 + visium_5prime = 7 + + def needs_barcode_calling(self): + return self in [IsoQuantMode.tenX_v3, + IsoQuantMode.curio, + IsoQuantMode.stereoseq_nosplit, + IsoQuantMode.stereoseq, + IsoQuantMode.visium_hd, + IsoQuantMode.visium_5prime] + + def needs_pcr_deduplication(self): + return self in [IsoQuantMode.tenX_v3, + IsoQuantMode.curio, + IsoQuantMode.stereoseq_nosplit, + IsoQuantMode.stereoseq, + IsoQuantMode.visium_hd, + IsoQuantMode.visium_5prime] + + def produces_new_fasta(self): + return self in [IsoQuantMode.stereoseq] + + def needs_barcode_iterator(self): + return self in [IsoQuantMode.stereoseq_nosplit, IsoQuantMode.stereoseq] + + def enforces_single_thread(self): + return False + + +ISOQUANT_MODES = [x.name for x in IsoQuantMode] \ No newline at end of file diff --git a/src/processed_read_manager.py b/src/processed_read_manager.py index 92270601..2ebdab79 100644 --- a/src/processed_read_manager.py +++ b/src/processed_read_manager.py @@ -8,7 +8,7 @@ import logging from collections import defaultdict -from .common import convert_chr_id_to_file_name_str +from .file_naming import saves_file_name, multimappers_file_name from .serialization import * from .isoform_assignment import BasicReadAssignment, ReadAssignmentType, ReadAssignment from .multimap_resolver import MultimapResolver @@ -23,7 +23,7 @@ def prepare_multimapper_dict(chr_ids, sample, multimappers_counts): polya_unique_assignments = 0 for chr_id in chr_ids: - chr_dump_file = sample.out_raw_file + "_" + convert_chr_id_to_file_name_str(chr_id) + chr_dump_file = saves_file_name(sample.out_raw_file, chr_id) loader = BasicReadAssignmentLoader(chr_dump_file) while loader.has_next(): for read_assignment in loader.get_next(): @@ -43,8 +43,7 @@ def resolve_multimappers(chr_ids, sample, multimapped_reads, strategy): multimap_resolver = MultimapResolver(strategy) multimap_dumper = {} for chr_id in chr_ids: - multimap_dumper[chr_id] = open(sample.out_raw_file + "_multimappers_" + convert_chr_id_to_file_name_str(chr_id), - "wb") + multimap_dumper[chr_id] = open(multimappers_file_name(sample.out_raw_file, chr_id), 'wb') total_assignments = 0 polya_assignments = 0 diff --git a/src/read_groups.py b/src/read_groups.py index 3129dbd0..abff9630 100644 --- a/src/read_groups.py +++ b/src/read_groups.py @@ -11,6 +11,8 @@ import pysam from collections import defaultdict +from .table_splitter import split_read_table_parallel + logger = logging.getLogger('IsoQuant') @@ -60,7 +62,7 @@ def get_group_id(self, alignment, filename=None): values = read_id.split(self.delim) if len(values) == 1: logger.warning("Delimiter %s is not present in read id %s, skipping" % (self.delim, read_id)) - return + return "" self.read_groups.add(values[-1]) return 
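src/modes.py above centralizes per-protocol behaviour behind small predicates. A trimmed-down sketch of how such predicates can drive pipeline branching (the reduced enum and step names are illustrative only, not the actual DatasetProcessor logic):

from enum import Enum, unique


@unique
class Mode(Enum):
    # Reduced stand-in for IsoQuantMode; the real enum lists all supported protocols.
    bulk = 1
    tenX_v3 = 2
    stereoseq = 5

    def needs_pcr_deduplication(self):
        return self != Mode.bulk

    def produces_new_fasta(self):
        return self == Mode.stereoseq


def plan_steps(mode):
    steps = ["map", "assign", "count"]
    if mode.needs_pcr_deduplication():
        steps.append("filter_umis")            # cf. DatasetProcessor.filter_umis
    if mode.produces_new_fasta():
        steps.append("emit_split_reads_fasta")
    return steps


print(plan_steps(Mode.tenX_v3))  # ['map', 'assign', 'count', 'filter_umis']
print(plan_steps(Mode.bulk))     # ['map', 'assign', 'count']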
values[-1] @@ -105,6 +107,119 @@ def get_group_id(self, alignment, filename=None): return filename +class BarcodeSpotGrouper(AbstractReadGrouper): + """Grouper that maps reads to spots/cell types via barcodes""" + def __init__(self, barcode_file, barcode2spot_files): + """ + Initialize barcode-to-spot grouper. + + Args: + barcode_file: Path to split barcode file for chromosome (read_id -> barcode, umi) + barcode2spot_files: List of TSV files mapping barcode -> spot/cell type + """ + AbstractReadGrouper.__init__(self) + logger.debug(f"Reading barcodes from {barcode_file}") + + # Load barcode dict: read_id -> (barcode, umi) + self.read_to_barcode = {} + if os.path.exists(barcode_file): + for line in open(barcode_file): + if line.startswith("#"): + continue + parts = line.split() + if len(parts) >= 2: + # Store just the barcode (second column) + self.read_to_barcode[parts[0]] = parts[1] + + # Load barcode2spot mapping: barcode -> spot/cell type + self.barcode_to_spot = {} + for barcode2spot_file in barcode2spot_files: + logger.debug(f"Reading barcode-to-spot mapping from {barcode2spot_file}") + self.barcode_to_spot.update(load_table(barcode2spot_file, 0, 1, '\t')) + + logger.info(f"Loaded {len(self.read_to_barcode)} read-barcode mappings and " + f"{len(self.barcode_to_spot)} barcode-spot mappings") + + def get_group_id(self, alignment, filename=None): + """Map read to spot via barcode""" + read_id = alignment.query_name + + # Look up barcode for this read + if read_id not in self.read_to_barcode: + self.read_groups.add(self.default_group_id) + return self.default_group_id + + barcode = self.read_to_barcode[read_id] + + # Look up spot for this barcode + if barcode not in self.barcode_to_spot: + self.read_groups.add(self.default_group_id) + return self.default_group_id + + spot = self.barcode_to_spot[barcode] + self.read_groups.add(spot) + return spot + + +class MultiColumnReadTableGrouper(AbstractReadGrouper): + """Grouper that handles TSV files with multiple group columns""" + def __init__(self, table_tsv_file, read_id_column_index=0, group_id_column_indices=None, delim='\t'): + AbstractReadGrouper.__init__(self) + if group_id_column_indices is None: + group_id_column_indices = [1] + logger.debug("Reading read groups from " + table_tsv_file) + self.read_map = load_multicolumn_table(table_tsv_file, read_id_column_index, group_id_column_indices, delim) + self.num_groups = len(group_id_column_indices) + + def get_group_id(self, alignment, filename=None): + """Returns a list of group IDs, one per column""" + if alignment.query_name not in self.read_map: + default_groups = [self.default_group_id] * self.num_groups + for g in default_groups: + self.read_groups.add(g) + return default_groups + group_ids = self.read_map[alignment.query_name] + for g in group_ids: + self.read_groups.add(g) + return group_ids + + +class MultiReadGrouper: + """Manages multiple read groupers and returns list of group IDs""" + def __init__(self, groupers): + """ + Initialize with a list of groupers + Args: + groupers: list of AbstractReadGrouper objects + """ + if not isinstance(groupers, list): + groupers = [groupers] + self.groupers = groupers + self.read_groups = [set() for _ in groupers] + + def get_group_id(self, alignment, filename=None): + """Returns a list of group IDs from all groupers""" + group_ids = [] + for i, grouper in enumerate(self.groupers): + gid = grouper.get_group_id(alignment, filename) + # Handle both single group IDs and lists (for MultiColumnReadTableGrouper) + if isinstance(gid, list): + 
group_ids.extend(gid) + for g in gid: + self.read_groups[i].add(g) + else: + group_ids.append(gid) + self.read_groups[i].add(gid) + return group_ids + + def get_all_groups(self): + """Returns all unique groups across all groupers""" + all_groups = [] + for groups in self.read_groups: + all_groups.append(list(groups)) + return all_groups + + def get_file_grouping_properties(values): assert len(values) >= 2 if len(values) > 4: @@ -116,25 +231,96 @@ def get_file_grouping_properties(values): def prepare_read_groups(args, sample): + """ + Prepare read group files by splitting them by chromosome for memory efficiency. + Handles both single and multiple grouping specifications. + Uses improved parallel algorithm for better performance. + + args.read_group should be a list of grouping specifications (nargs='+'). + """ if not hasattr(args, "read_group") or args.read_group is None: return - option = args.read_group - values = option.split(':') - if values[0] != 'file': - return - table_filename, read_id_column_index, group_id_column_index, delim = get_file_grouping_properties(values) - logger.info("Splitting read group file %s for better memory consumption" % table_filename) - split_read_group_table(table_filename, sample, read_id_column_index, group_id_column_index, delim) + # Handle both list (nargs='+') and string (backward compatibility) + if isinstance(args.read_group, str): + specs = args.read_group.split(';') + else: + specs = args.read_group -def create_read_grouper(args, sample, chr_id): - if not hasattr(args, "read_group") or args.read_group is None: - return DefaultReadGrouper() + # Collect chromosome names from BAM files to create split files + bam_files = list(map(lambda x: x[0], sample.file_list)) + chromosomes = set() + for bam_file in bam_files: + try: + bam = pysam.AlignmentFile(bam_file, "rb") + chromosomes.update(bam.references) + bam.close() + except: + pass + chromosomes = list(chromosomes) + + for spec in specs: + spec = spec.strip() + if not spec: + continue + + values = spec.split(':') + if values[0] != 'file': + continue + + # Parse specification + table_filename = values[1] + read_id_column_index = int(values[2]) if len(values) > 2 else 0 + + if len(values) >= 4 and ',' in values[3]: + # Multi-column TSV: file:filename:read_col:group_cols:delim + group_id_column_indices = [int(x) for x in values[3].split(',')] + delim = values[4] if len(values) > 4 else '\t' + logger.info("Splitting multi-column read group file %s for better memory consumption" % table_filename) + else: + # Single column TSV + group_id_column_index = int(values[3]) if len(values) > 3 else 1 + group_id_column_indices = [group_id_column_index] + delim = values[4] if len(values) > 4 else '\t' + logger.info("Splitting read group file %s for better memory consumption" % table_filename) + + # Build output file names for each chromosome + split_reads_file_names = {chr_id: sample.read_group_file + "_" + chr_id for chr_id in chromosomes} + + # Use improved parallel splitting with line-by-line streaming + num_threads = args.threads if hasattr(args, 'threads') else 4 + + split_read_table_parallel(sample, table_filename, split_reads_file_names, + num_threads, + read_column=read_id_column_index, + group_columns=tuple(group_id_column_indices), + delim=delim) + + +def parse_grouping_spec(spec_string, args, sample, chr_id): + """Parse a single grouping specification and return the appropriate grouper""" + values = spec_string.split(':') - option = args.read_group - values = option.split(':') if values[0] == "file_name": 
return FileNameGrouper(args, sample) + elif values[0] == 'barcode_spot': + # Format: barcode_spot:file1.tsv or just use --barcode2spot files + if len(values) >= 2: + # Explicit file(s) specified + barcode2spot_files = values[1:] + elif hasattr(args, 'barcode2spot') and args.barcode2spot: + # Use --barcode2spot files + barcode2spot_files = args.barcode2spot + else: + logger.critical("barcode_spot grouping requires --barcode2spot or explicit file specification") + return None + + # Get split barcode file for this chromosome + if not hasattr(sample, 'barcodes_split_reads') or not sample.barcodes_split_reads: + logger.critical("barcode_spot grouping requires barcoded reads (use --barcoded_reads)") + return None + barcode_file = sample.barcodes_split_reads + "_" + chr_id + return BarcodeSpotGrouper(barcode_file, barcode2spot_files) elif values[0] == 'tag': if len(values) < 2: return AlignmentTagReadGrouper(tag="RG") @@ -142,11 +328,127 @@ def create_read_grouper(args, sample, chr_id): elif values[0] == 'read_id': return ReadIdSplitReadGrouper(delim=values[1]) elif values[0] == 'file': - read_group_chr_filename = sample.read_group_file + "_" + chr_id - return ReadTableGrouper(read_group_chr_filename, 0, 1, '\t') + # Format: file:filename:read_col:group_cols:delim + # group_cols can be comma-separated like "1,2,3" + if len(values) >= 4: + # Check if multiple columns are specified + group_col_spec = values[3] + if ',' in group_col_spec: + # Multiple columns - use MultiColumnReadTableGrouper + group_id_column_indices = [int(x) for x in group_col_spec.split(',')] + read_id_column_index = int(values[2]) if len(values) > 2 else 0 + delim = values[4] if len(values) > 4 else '\t' + read_group_chr_filename = sample.read_group_file + "_" + chr_id + return MultiColumnReadTableGrouper(read_group_chr_filename, read_id_column_index, + group_id_column_indices, delim) + else: + # Single column - use ReadTableGrouper + read_id_column_index = int(values[2]) if len(values) > 2 else 0 + group_id_column_index = int(values[3]) + delim = values[4] if len(values) > 4 else '\t' + read_group_chr_filename = sample.read_group_file + "_" + chr_id + return ReadTableGrouper(read_group_chr_filename, read_id_column_index, + group_id_column_index, delim) + else: + # Default format + read_group_chr_filename = sample.read_group_file + "_" + chr_id + return ReadTableGrouper(read_group_chr_filename, 0, 1, '\t') + else: + logger.critical("Unsupported read grouping option: %s" % values[0]) + return None + + +def create_read_grouper(args, sample, chr_id): + """ + Create read grouper(s) based on args.read_group specification. 
+ + args.read_group should be a list of grouping specifications (nargs='+'): + - Single grouper: ["tag:RG"] or ["file_name"] + - Multiple groupers: ["tag:RG", "file_name", "read_id:_"] + - Multi-column TSV: ["file:table.tsv:0:1,2,3"] (columns 1,2,3 as separate groups) + + Returns: + MultiReadGrouper if multiple specifications, otherwise single grouper + """ + if not hasattr(args, "read_group") or args.read_group is None: + return DefaultReadGrouper() + + # Handle both list (nargs='+') and string (backward compatibility) + if isinstance(args.read_group, str): + # Backward compatibility: semicolon-separated string + specs = args.read_group.split(';') else: - logger.critical("Unsupported read grouping option") + # Native list from nargs='+' + specs = args.read_group + + groupers = [] + for spec in specs: + spec = spec.strip() + if spec: + grouper = parse_grouping_spec(spec, args, sample, chr_id) + if grouper: + groupers.append(grouper) + + if not groupers: + logger.warning("No valid groupers specified, using default") return DefaultReadGrouper() + elif len(groupers) == 1: + return groupers[0] + else: + return MultiReadGrouper(groupers) + + +def get_grouping_strategy_names(args) -> list: + """ + Extract descriptive names for each grouping strategy from args.read_group. + + Returns a list of strategy names like: ["tag_CB", "tag_UB", "file_name", "read_id"] + If no read_group is specified, returns ["default"]. + + For multi-column TSV files with N columns, returns N separate names like: + ["file_col1", "file_col2", "file_col3"] + """ + if not hasattr(args, "read_group") or args.read_group is None: + return ["default"] + + # Handle both list (nargs='+') and string (backward compatibility) + if isinstance(args.read_group, str): + specs = args.read_group.split(';') + else: + specs = args.read_group + + strategy_names = [] + for spec in specs: + spec = spec.strip() + if not spec: + continue + + values = spec.split(':') + spec_type = values[0] + + if spec_type == "file_name": + strategy_names.append("file_name") + elif spec_type == 'barcode_spot': + strategy_names.append("barcode_spot") + elif spec_type == 'tag': + tag_name = values[1] if len(values) > 1 else "RG" + strategy_names.append(f"tag_{tag_name}") + elif spec_type == 'read_id': + delim = values[1] if len(values) > 1 else "_" + # Sanitize delimiter for filename + safe_delim = delim.replace('/', '_').replace('\\', '_') + strategy_names.append(f"read_id_{safe_delim}") + elif spec_type == 'file': + # Check if multi-column + if len(values) >= 4 and ',' in values[3]: + group_col_indices = values[3].split(',') + for col_idx in group_col_indices: + strategy_names.append(f"file_col{col_idx}") + else: + col_idx = values[3] if len(values) > 3 else "1" + strategy_names.append(f"file_col{col_idx}") + + return strategy_names if strategy_names else ["default"] def load_table(table_tsv_file, read_id_column_index, group_id_column_index, delim): @@ -201,3 +503,64 @@ def split_read_group_table(table_file, sample, read_id_column_index, group_id_co for f in read_group_files.values(): f.close() + + +def load_multicolumn_table(table_tsv_file, read_id_column_index, group_id_column_indices, delim): + min_columns = max(read_id_column_index, max(group_id_column_indices)) + _, outer_ext = os.path.splitext(table_tsv_file) + if outer_ext.lower() in ['.gz', '.gzip']: + handle = gzip.open(table_tsv_file, "rt") + else: + handle = open(table_tsv_file, 'r') + + read_map = {} + for line in handle: + line = line.strip() + if line.startswith('#') or not line: + continue + + 
column_values = line.split(delim) + if len(column_values) <= min_columns: + logger.warning("Malformed input read information table, minimum, of %d columns expected, " + "file %s, line: %s" % (min_columns, table_tsv_file, line)) + continue + + read_id = column_values[read_id_column_index] + if read_id in read_map: + logger.warning("Duplicate information for read %s" % read_id) + + column_data = [column_values[i] for i in group_id_column_indices] + read_map[read_id] = column_data + return read_map + + +def split_table(table_file, sample, out_prefix, read_id_column_index, group_id_column_indices, delim): + read_groups = load_multicolumn_table(table_file, read_id_column_index, group_id_column_indices, delim) + read_group_files = {} + processed_reads = defaultdict(set) + bam_files = list(map(lambda x: x[0], sample.file_list)) + + for bam_file in bam_files: + bam = pysam.AlignmentFile(bam_file, "rb") + for chr_id in bam.references: + if chr_id not in read_group_files: + read_group_files[chr_id] = open(out_prefix + "_" + chr_id, "w") + for read_alignment in bam: + chr_id = read_alignment.reference_name + if not chr_id: + continue + + read_id = read_alignment.query_name + if read_id in read_groups and read_id not in processed_reads[chr_id]: + # read_groups[read_id] is a list, join with delimiter + group_values = delim.join(read_groups[read_id]) + read_group_files[chr_id].write("%s\t%s\n" % (read_id, group_values)) + processed_reads[chr_id].add(read_id) + + for f in read_group_files.values(): + f.close() + + +def prepare_barcoded_reads(args, sample): + logger.info("Splitting barcoded reads %s for better memory consumption" % sample.out_barcodes_tsv) + split_table(sample.out_barcodes_tsv, sample, sample.barcodes_split_reads, 0, [1, 2, 3, 4], '\t') diff --git a/src/read_mapper.py b/src/read_mapper.py index 38828e73..f2f5cefc 100644 --- a/src/read_mapper.py +++ b/src/read_mapper.py @@ -71,8 +71,8 @@ def map_reads(self, args): bam_files.append([bam_file]) if fastq_file in sample.readable_names_dict: readable_names_dict[bam_file] = sample.readable_names_dict[fastq_file] + samples.append(SampleData(bam_files, sample.prefix, sample.out_dir, readable_names_dict, sample.illumina_bam, sample.barcoded_reads)) - samples.append(SampleData(bam_files, sample.prefix, sample.out_dir, readable_names_dict, sample.illumina_bam)) args.input_data.samples = samples args.input_data.input_type = InputDataType.bam return args.input_data diff --git a/src/stats.py b/src/stats.py index 7bb93eb1..05f2ae2c 100644 --- a/src/stats.py +++ b/src/stats.py @@ -32,7 +32,7 @@ def print_start(self, header_string=""): if header_string: logger.info(header_string) for e in sorted(self.stats_dict.keys(), key=lambda x: x.name): - logger.info("%s: %d" % (e.name, self.stats_dict[e])) + logger.info(" %s: %d" % (e.name, self.stats_dict[e])) def dump(self, out_file): pickler = pickle.Pickler(open(out_file, "wb"), -1) diff --git a/src/table_splitter.py b/src/table_splitter.py new file mode 100644 index 00000000..2432ec07 --- /dev/null +++ b/src/table_splitter.py @@ -0,0 +1,282 @@ +############################################################################ +# Copyright (c) 2022-2024 University of Helsinki +# Copyright (c) 2019-2022 Saint Petersburg State University +# # All Rights Reserved +# See file LICENSE for details. 
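The grouping specifications handled by prepare_read_groups, parse_grouping_spec and get_grouping_strategy_names above can be exercised directly. A minimal usage sketch, run from the repository root; the file name groups.tsv and the tag CB are illustrative only, and the printed list is the result expected from the naming rules shown in the diff:

from types import SimpleNamespace
from src.read_groups import get_grouping_strategy_names

# three independent grouping strategies: a BAM tag, the input file name,
# and columns 1, 2 and 3 of a read-to-group table (column 0 holds the read id)
args = SimpleNamespace(read_group=["tag:CB", "file_name", "file:groups.tsv:0:1,2,3"])
print(get_grouping_strategy_names(args))
# expected: ['tag_CB', 'file_name', 'file_col1', 'file_col2', 'file_col3']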
+############################################################################ + + +import pysam +from concurrent.futures import ProcessPoolExecutor +import logging +import gzip +import os + + +logger = logging.getLogger('IsoQuant') + + +def get_chromosome_read_counts(bam_files, chromosomes): + """ + Quickly get read counts per chromosome from BAM indices. + Uses pysam.get_index_statistics() which reads from .bai without scanning BAM. + + Returns: dict mapping chr_id -> total_read_count + """ + chr_counts = {chr_id: 0 for chr_id in chromosomes} + + for bam_file in bam_files: + try: + bam = pysam.AlignmentFile(bam_file, "rb") + try: + # Get statistics from index + stats = bam.get_index_statistics() + for stat in stats: + chr_id = stat.contig + if chr_id in chr_counts: + # Use mapped + unmapped counts + chr_counts[chr_id] += stat.mapped + stat.unmapped + finally: + bam.close() + except Exception as e: + logger.warning(f"Could not read index statistics from {bam_file}: {e}") + # Fallback: estimate from BAM header + try: + bam = pysam.AlignmentFile(bam_file, "rb") + for chr_id in chromosomes: + if chr_id in bam.references: + # Rough estimate: count via fetch (slower but works without index stats) + try: + count = bam.count(chr_id) + chr_counts[chr_id] += count + except: + pass + bam.close() + except: + pass + + return chr_counts + + +def distribute_chromosomes_weighted(chromosomes, num_workers, chr_read_counts): + """ + Distribute chromosomes to workers using greedy bin packing. + Balances load by assigning large chromosomes first. + + Args: + chromosomes: list of chromosome IDs + num_workers: number of worker processes + chr_read_counts: dict mapping chr_id -> read_count + + Returns: + List of lists, where workers[i] contains chromosome IDs for worker i + """ + if not chromosomes: + return [[] for _ in range(num_workers)] + + # Sort chromosomes by read count (descending) + sorted_chrs = sorted(chromosomes, key=lambda c: chr_read_counts.get(c, 0), reverse=True) + + # Initialize workers + workers = [[] for _ in range(num_workers)] + worker_loads = [0] * num_workers + + # Greedy assignment: assign each chromosome to least-loaded worker + for chr_id in sorted_chrs: + min_worker_idx = worker_loads.index(min(worker_loads)) + workers[min_worker_idx].append(chr_id) + worker_loads[min_worker_idx] += chr_read_counts.get(chr_id, 0) + + # Log distribution for debugging + for i, (chrs, load) in enumerate(zip(workers, worker_loads)): + if chrs: + logger.debug(f"Worker {i}: {len(chrs)} chromosomes, {load:,} reads") + + return workers + + +def collect_chromosome_reads(chr_id, bam_files): + """ + Collect all read IDs for a specific chromosome by scanning BAM files. + """ + read_ids = set() + for bam_file in bam_files: + try: + bam = pysam.AlignmentFile(bam_file, "rb") + try: + for read in bam.fetch(chr_id): + read_ids.add(read.query_name) + finally: + bam.close() + except Exception as e: + logger.warning(f"Could not fetch reads from {bam_file} for chromosome {chr_id}: {e}") + + return read_ids + + +def process_table_for_chromosomes(worker_id, input_tsvs, my_chromosomes, bam_files, + output_files, read_column, group_columns, delim): + """ + Worker function: processes entire table for assigned chromosomes. + + Each worker: + 1. Builds read ID cache for its assigned chromosomes + 2. Streams through table line-by-line (memory efficient) + 3. 
Writes matching reads to chromosome-specific output files + + Args: + worker_id: Worker identifier for logging + input_tsvs: TSV file(s) to read from + my_chromosomes: List of chromosome IDs assigned to this worker + bam_files: List of BAM files to scan for read IDs + output_files: Dict mapping chr_id -> output_file_path + read_column: Column index for read ID + group_columns: Tuple of column indices for group values + delim: Column delimiter + """ + if not my_chromosomes: + return 0, 0 + + logger.debug(f"Worker {worker_id}: processing {len(my_chromosomes)} chromosomes: {', '.join(my_chromosomes)}") + + # Step 1: Build read ID cache for assigned chromosomes + logger.debug(f"Worker {worker_id}: building read ID cache...") + read_cache = {} + for chr_id in my_chromosomes: + read_cache[chr_id] = collect_chromosome_reads(chr_id, bam_files) + logger.debug(f"Worker {worker_id}: cached {len(read_cache[chr_id]):,} reads for {chr_id}") + + # Step 2: Open output files + out_handles = {} + for chr_id in my_chromosomes: + out_handles[chr_id] = open(output_files[chr_id], 'w') + + # Step 3: Stream through table line-by-line + total_reads_processed = 0 + total_reads_written = 0 + min_columns = max(read_column, max(group_columns)) + 1 + + try: + input_files = input_tsvs if isinstance(input_tsvs, list) else [input_tsvs] + + for input_file in input_files: + # Handle gzipped files + _, ext = os.path.splitext(input_file) + if ext.lower() in ['.gz', '.gzip']: + file_handle = gzip.open(input_file, 'rt') + else: + file_handle = open(input_file) + + try: + for line in file_handle: + if line.startswith("#") or not line.strip(): + continue + + total_reads_processed += 1 + columns = line.rstrip('\n').split(delim) + + if len(columns) < min_columns: + continue + + read_id = columns[read_column] + + # Check if this read belongs to any of my chromosomes + for chr_id in my_chromosomes: + if read_id in read_cache[chr_id]: + # Extract group values + group_vals = delim.join(columns[c] for c in group_columns) + out_handles[chr_id].write(f"{read_id}\t{group_vals}\n") + total_reads_written += 1 + break # Read can only be on one chromosome + finally: + file_handle.close() + finally: + # Close output files + for f in out_handles.values(): + f.close() + + logger.debug(f"Worker {worker_id}: processed {total_reads_processed:,} table entries, " + f"wrote {total_reads_written:,} reads") + + return total_reads_processed, total_reads_written + + +def split_read_table_parallel(sample, input_tsvs, split_reads_file_names, num_threads, + read_column=0, group_columns=(1,), delim='\t'): + """ + Improved parallel table splitting algorithm. + + Strategy: Assign chromosomes to workers, each worker streams table line-by-line for its chromosomes. 
+ + Advantages over old algorithm: + - No redundant work (each line processed once by each worker) + - No chunking overhead (line-by-line streaming) + - Minimal memory usage (no intermediate dicts) + - Better memory distribution (each worker caches only its chromosomes) + - True parallelism (no generator bottleneck) + + Args: + sample: Sample object with file_list + input_tsvs: TSV file(s) containing read groups + split_reads_file_names: Dict mapping chr_id -> output_file_path + num_threads: Number of worker processes + read_column: Column index for read ID (default: 0) + group_columns: Tuple of column indices for group values (default: (1,)) + delim: Column delimiter (default: '\t') + """ + logger.info(f"Splitting table {input_tsvs} across {num_threads} workers") + + bam_files = list(map(lambda x: x[0], sample.file_list)) + chromosomes = list(split_reads_file_names.keys()) + + if not chromosomes: + logger.warning("No chromosomes to process") + return + + # Step 1: Get chromosome read counts from BAM indices for load balancing + logger.info("Analyzing chromosome sizes from BAM indices...") + chr_read_counts = get_chromosome_read_counts(bam_files, chromosomes) + + total_reads = sum(chr_read_counts.values()) + logger.info(f"Total reads across {len(chromosomes)} chromosomes: {total_reads:,}") + + # Step 2: Distribute chromosomes to workers (weighted by read count) + num_workers = min(num_threads, len(chromosomes)) # No point having more workers than chromosomes + chr_assignments = distribute_chromosomes_weighted(chromosomes, num_workers, chr_read_counts) + + # Filter out empty assignments + chr_assignments = [chrs for chrs in chr_assignments if chrs] + num_workers = len(chr_assignments) + + logger.info(f"Using {num_workers} workers to process {len(chromosomes)} chromosomes") + + # Step 3: Initialize output files (truncate) + for chr_id in chromosomes: + open(split_reads_file_names[chr_id], 'w').close() + + # Step 4: Launch workers + with ProcessPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for worker_id, my_chrs in enumerate(chr_assignments): + future = executor.submit( + process_table_for_chromosomes, + worker_id, + input_tsvs, + my_chrs, + bam_files, + split_reads_file_names, + read_column, + group_columns, + delim + ) + futures.append(future) + + # Wait for all workers and collect results + total_processed = 0 + total_written = 0 + for future in futures: + processed, written = future.result() + total_processed += processed + total_written += written + + logger.info(f"Table splitting complete: {total_written:,} reads written to {len(chromosomes)} chromosome files") diff --git a/src/umi_filter.py b/src/umi_filter.py new file mode 100644 index 00000000..3bebaa37 --- /dev/null +++ b/src/umi_filter.py @@ -0,0 +1,12 @@ +############################################################################ +# Copyright (c) 2023 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details. 
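The weighted assignment step used by split_read_table_parallel can be illustrated on its own by calling distribute_chromosomes_weighted with made-up read counts; a small sketch, assuming the repository is on PYTHONPATH and using hypothetical numbers (e.g. as get_chromosome_read_counts might return them):

from src.table_splitter import distribute_chromosomes_weighted

# hypothetical per-chromosome read counts
chr_read_counts = {"chr1": 5_000_000, "chr2": 4_000_000, "chr3": 1_000_000, "chrM": 10_000}
assignments = distribute_chromosomes_weighted(list(chr_read_counts), 2, chr_read_counts)
print(assignments)
# greedy bin packing assigns the largest chromosomes to the least-loaded worker first:
# [['chr1', 'chrM'], ['chr2', 'chr3']]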
+############################################################################ + +import logging +from .read_groups import load_multicolumn_table + +logger = logging.getLogger('IsoQuant') + + diff --git a/tests/console_test.py b/tests/console_test.py index a10d5ef5..582108a9 100644 --- a/tests/console_test.py +++ b/tests/console_test.py @@ -50,11 +50,12 @@ def test_clean_start(): assert result.returncode == 0 sample_folder = os.path.join(out_dir, sample_name) assert os.path.isdir(sample_folder) - resulting_files = ["exon_counts.tsv", "exon_grouped_counts.linear.tsv", "gene_counts.tsv", "gene_grouped_counts.tsv", - "intron_counts.tsv", "intron_grouped_counts.linear.tsv", + resulting_files = ["exon_counts.tsv", "exon_grouped.file_col1_counts.linear.tsv", + "gene_counts.tsv", "gene_grouped.file_col1_counts.tsv", + "intron_counts.tsv", "intron_grouped.file_col1_counts.linear.tsv", "corrected_reads.bed.gz", "read_assignments.tsv.gz", "novel_vs_known.SQANTI-like.tsv", - "transcript_counts.tsv", "transcript_grouped_counts.tsv", + "transcript_counts.tsv", "transcript_grouped.file_col1_counts.tsv", "discovered_transcript_counts.tsv", "transcript_models.gtf", "transcript_model_reads.tsv.gz", "discovered_transcript_tpm.tsv"] for f in resulting_files: diff --git a/tests/test_barcode_callers.py b/tests/test_barcode_callers.py new file mode 100644 index 00000000..4856406a --- /dev/null +++ b/tests/test_barcode_callers.py @@ -0,0 +1,484 @@ +############################################################################ +# Copyright (c) 2025 University of Helsinki +# All Rights Reserved +# See file LICENSE for details. +############################################################################ + +import pytest +from src.barcode_calling.barcode_callers import ( + BarcodeDetectionResult, + DoubleBarcodeDetectionResult, + StereoBarcodeDetectionResult, + TenXBarcodeDetectionResult, + SplittingBarcodeDetectionResult, + ReadStats, + increase_if_valid +) + + +class TestIncreaseIfValid: + """Test coordinate increment utility function.""" + + def test_increment_valid_positive(self): + """Test incrementing valid positive value.""" + assert increase_if_valid(10, 5) == 15 + + def test_increment_valid_zero(self): + """Test incrementing zero (treated as invalid).""" + assert increase_if_valid(0, 5) == 0 + + def test_increment_invalid_negative(self): + """Test incrementing -1 (invalid marker).""" + assert increase_if_valid(-1, 5) == -1 + + def test_increment_none(self): + """Test incrementing None.""" + assert increase_if_valid(None, 5) is None + + +class TestBarcodeDetectionResult: + """Test base barcode detection result class.""" + + def test_init_default(self): + """Test initialization with defaults.""" + result = BarcodeDetectionResult("read_001") + + assert result.read_id == "read_001" + assert result.barcode == BarcodeDetectionResult.NOSEQ + assert result.UMI == BarcodeDetectionResult.NOSEQ + assert result.BC_score == -1 + assert result.UMI_good is False + assert result.strand == "." 
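These tests pin down the contract of increase_if_valid. A minimal implementation consistent with them (a sketch only, not necessarily the actual code in src.barcode_calling.barcode_callers) would be:

def increase_if_valid(value, shift):
    # None, 0 and negative markers (e.g. -1) all mean "position not detected";
    # only genuinely detected (positive) coordinates are shifted
    if value is None or value <= 0:
        return value
    return value + shift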
+ + def test_init_with_values(self): + """Test initialization with all values.""" + result = BarcodeDetectionResult( + read_id="read_001", + barcode="ACTGACTG", + UMI="GGGGGGGG", + BC_score=16, + UMI_good=True, + strand="+" + ) + + assert result.read_id == "read_001" + assert result.barcode == "ACTGACTG" + assert result.UMI == "GGGGGGGG" + assert result.BC_score == 16 + assert result.UMI_good is True + assert result.strand == "+" + + def test_is_valid_with_barcode(self): + """Test validity check with detected barcode.""" + result = BarcodeDetectionResult("read_001", barcode="ACTG") + assert result.is_valid() is True + + def test_is_valid_without_barcode(self): + """Test validity check without barcode.""" + result = BarcodeDetectionResult("read_001") + assert result.is_valid() is False + + def test_set_strand(self): + """Test setting strand.""" + result = BarcodeDetectionResult("read_001") + result.set_strand("+") + assert result.strand == "+" + + def test_str_format(self): + """Test string formatting.""" + result = BarcodeDetectionResult( + "read_001", "ACTG", "GGGG", 15, True, "+" + ) + output = str(result) + + assert "read_001" in output + assert "ACTG" in output + assert "GGGG" in output + assert "15" in output + assert "True" in output + assert "+" in output + + def test_header(self): + """Test TSV header.""" + header = BarcodeDetectionResult.header() + assert "#read_id" in header + assert "barcode" in header + assert "UMI" in header + assert "BC_score" in header + + +class TestDoubleBarcodeDetectionResult: + """Test double barcode detection result class.""" + + def test_init_default(self): + """Test initialization with defaults.""" + result = DoubleBarcodeDetectionResult("read_001") + + assert result.read_id == "read_001" + assert result.polyT == -1 + assert result.primer == -1 + assert result.linker_start == -1 + assert result.linker_end == -1 + + def test_init_with_positions(self): + """Test initialization with all positions.""" + result = DoubleBarcodeDetectionResult( + read_id="read_001", + barcode="ACTGACTG", + UMI="GGGG", + BC_score=14, + UMI_good=True, + strand="+", + polyT=100, + primer=50, + linker_start=60, + linker_end=75 + ) + + assert result.polyT == 100 + assert result.primer == 50 + assert result.linker_start == 60 + assert result.linker_end == 75 + + def test_update_coordinates(self): + """Test coordinate shifting.""" + result = DoubleBarcodeDetectionResult( + "read_001", + polyT=100, + primer=50, + linker_start=60, + linker_end=75 + ) + + result.update_coordinates(10) + + assert result.polyT == 110 + assert result.primer == 60 + assert result.linker_start == 70 + assert result.linker_end == 85 + + def test_update_coordinates_invalid(self): + """Test coordinate shifting with invalid values.""" + result = DoubleBarcodeDetectionResult("read_001") + result.update_coordinates(10) + + # Invalid values (-1) should remain unchanged + assert result.polyT == -1 + assert result.primer == -1 + + def test_more_informative_than_by_score(self): + """Test comparison by barcode score.""" + result1 = DoubleBarcodeDetectionResult("read_001", BC_score=16) + result2 = DoubleBarcodeDetectionResult("read_001", BC_score=14) + + assert result1.more_informative_than(result2) is True + assert result2.more_informative_than(result1) is False + + def test_more_informative_than_by_linker(self): + """Test comparison by linker position (tie-breaker).""" + result1 = DoubleBarcodeDetectionResult( + "read_001", BC_score=16, linker_start=80 + ) + result2 = DoubleBarcodeDetectionResult( + "read_001", 
BC_score=16, linker_start=60 + ) + + # Higher linker position is more informative + assert result1.more_informative_than(result2) is True + + def test_get_additional_attributes_all(self): + """Test attribute detection with all features.""" + result = DoubleBarcodeDetectionResult( + "read_001", + polyT=100, + primer=50, + linker_start=60 + ) + + attrs = result.get_additional_attributes() + + assert "PolyT detected" in attrs + assert "Primer detected" in attrs + assert "Linker detected" in attrs + + def test_get_additional_attributes_partial(self): + """Test attribute detection with some features.""" + result = DoubleBarcodeDetectionResult( + "read_001", + polyT=100, + # No primer or linker + ) + + attrs = result.get_additional_attributes() + + assert "PolyT detected" in attrs + assert "Primer detected" not in attrs + assert "Linker detected" not in attrs + + def test_str_format(self): + """Test string formatting includes positions.""" + result = DoubleBarcodeDetectionResult( + "read_001", "ACTG", "GGGG", 16, True, "+", + polyT=100, primer=50, linker_start=60, linker_end=75 + ) + output = str(result) + + assert "100" in output # polyT + assert "50" in output # primer + assert "60" in output # linker_start + assert "75" in output # linker_end + + +class TestStereoBarcodeDetectionResult: + """Test Stereo-seq detection result class.""" + + def test_init_with_tso(self): + """Test initialization with TSO position.""" + result = StereoBarcodeDetectionResult( + "read_001", + polyT=100, + tso=150 + ) + + assert result.tso5 == 150 + + def test_update_coordinates_includes_tso(self): + """Test coordinate shifting includes TSO.""" + result = StereoBarcodeDetectionResult( + "read_001", + polyT=100, + tso=150 + ) + + result.update_coordinates(10) + + assert result.tso5 == 160 + assert result.polyT == 110 + + def test_get_additional_attributes_with_tso(self): + """Test attribute detection includes TSO.""" + result = StereoBarcodeDetectionResult( + "read_001", + polyT=100, + tso=150 + ) + + attrs = result.get_additional_attributes() + + assert "PolyT detected" in attrs + assert "TSO detected" in attrs + + +class TestTenXBarcodeDetectionResult: + """Test 10x Genomics detection result class.""" + + def test_init_with_r1(self): + """Test initialization with R1 position.""" + result = TenXBarcodeDetectionResult( + "read_001", + polyT=100, + r1=20 + ) + + assert result.r1 == 20 + assert result.polyT == 100 + + def test_update_coordinates(self): + """Test coordinate shifting.""" + result = TenXBarcodeDetectionResult( + "read_001", + polyT=100, + r1=20 + ) + + result.update_coordinates(10) + + assert result.r1 == 30 + assert result.polyT == 110 + + def test_more_informative_than_by_polyt(self): + """Test comparison prioritizes polyT.""" + result1 = TenXBarcodeDetectionResult("read_001", polyT=100) + result2 = TenXBarcodeDetectionResult("read_001", polyT=80) + + assert result1.more_informative_than(result2) is True + + def test_more_informative_than_by_r1(self): + """Test comparison uses R1 as tie-breaker.""" + result1 = TenXBarcodeDetectionResult("read_001", polyT=100, r1=30) + result2 = TenXBarcodeDetectionResult("read_001", polyT=100, r1=20) + + assert result1.more_informative_than(result2) is True + + def test_get_additional_attributes(self): + """Test attribute detection.""" + result = TenXBarcodeDetectionResult( + "read_001", + polyT=100, + r1=20 + ) + + attrs = result.get_additional_attributes() + + assert "PolyT detected" in attrs + assert "R1 detected" in attrs + + +class 
TestSplittingBarcodeDetectionResult: + """Test splitting barcode detection result class.""" + + def test_init(self): + """Test initialization.""" + result = SplittingBarcodeDetectionResult("read_001") + + assert result.read_id == "read_001" + assert result.detected_patterns == [] + + def test_append(self): + """Test appending detection patterns.""" + result = SplittingBarcodeDetectionResult("read_001") + + pattern1 = StereoBarcodeDetectionResult("read_001", barcode="ACTG") + pattern2 = StereoBarcodeDetectionResult("read_001", barcode="TGCA") + + result.append(pattern1) + result.append(pattern2) + + assert len(result.detected_patterns) == 2 + + def test_empty_true(self): + """Test empty detection.""" + result = SplittingBarcodeDetectionResult("read_001") + assert result.empty() is True + + def test_empty_false(self): + """Test non-empty detection.""" + result = SplittingBarcodeDetectionResult("read_001") + pattern = StereoBarcodeDetectionResult("read_001", barcode="ACTG") + result.append(pattern) + + assert result.empty() is False + + def test_filter_keeps_barcoded(self): + """Test filter keeps results with barcodes.""" + result = SplittingBarcodeDetectionResult("read_001") + + barcoded = StereoBarcodeDetectionResult("read_001", barcode="ACTG") + unbarcoded = StereoBarcodeDetectionResult("read_001") # No barcode + + result.append(barcoded) + result.append(unbarcoded) + result.filter() + + # Should keep only barcoded result + assert len(result.detected_patterns) == 1 + assert result.detected_patterns[0].barcode == "ACTG" + + def test_filter_keeps_first_if_none_barcoded(self): + """Test filter keeps first result if none have barcodes.""" + result = SplittingBarcodeDetectionResult("read_001") + + unbarcoded1 = StereoBarcodeDetectionResult("read_001") + unbarcoded2 = StereoBarcodeDetectionResult("read_001") + + result.append(unbarcoded1) + result.append(unbarcoded2) + result.filter() + + # Should keep only first result + assert len(result.detected_patterns) == 1 + + +class TestReadStats: + """Test statistics tracker class.""" + + def test_init(self): + """Test initialization.""" + stats = ReadStats() + + assert stats.read_count == 0 + assert stats.bc_count == 0 + assert stats.umi_count == 0 + assert len(stats.additional_attributes_counts) == 0 + + def test_add_read_with_barcode(self): + """Test adding read with valid barcode.""" + stats = ReadStats() + result = DoubleBarcodeDetectionResult( + "read_001", + barcode="ACTG", + UMI_good=True, + polyT=100 + ) + + stats.add_read(result) + + assert stats.read_count == 1 + assert stats.bc_count == 1 + assert stats.umi_count == 1 + assert stats.additional_attributes_counts["PolyT detected"] == 1 + + def test_add_read_without_barcode(self): + """Test adding read without barcode.""" + stats = ReadStats() + result = DoubleBarcodeDetectionResult("read_001") # No barcode + + stats.add_read(result) + + assert stats.read_count == 1 + assert stats.bc_count == 0 + assert stats.umi_count == 0 + + def test_add_multiple_reads(self): + """Test adding multiple reads.""" + stats = ReadStats() + + result1 = DoubleBarcodeDetectionResult("read_001", barcode="ACTG", polyT=100) + result2 = DoubleBarcodeDetectionResult("read_002", barcode="TGCA", primer=50) + result3 = DoubleBarcodeDetectionResult("read_003") # No barcode + + stats.add_read(result1) + stats.add_read(result2) + stats.add_read(result3) + + assert stats.read_count == 3 + assert stats.bc_count == 2 + assert stats.additional_attributes_counts["PolyT detected"] == 1 + assert 
stats.additional_attributes_counts["Primer detected"] == 1 + + def test_add_custom_stats(self): + """Test adding custom statistics.""" + stats = ReadStats() + + stats.add_custom_stats("Custom feature", 5) + stats.add_custom_stats("Custom feature", 3) + + assert stats.additional_attributes_counts["Custom feature"] == 8 + + def test_str_format(self): + """Test string formatting.""" + stats = ReadStats() + result = DoubleBarcodeDetectionResult("read_001", barcode="ACTG", UMI_good=True) + stats.add_read(result) + + output = str(stats) + + assert "Total reads\t1" in output + assert "Barcode detected\t1" in output + assert "Reliable UMI\t1" in output + + def test_iter(self): + """Test iteration over statistics.""" + stats = ReadStats() + result = DoubleBarcodeDetectionResult("read_001", barcode="ACTG", polyT=100) + stats.add_read(result) + + lines = list(stats) + + assert "Total reads: 1" in lines + assert "Barcode detected: 1" in lines + assert "PolyT detected: 1" in lines + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/test_common.py b/tests/test_common.py index 607b2d4d..640d25b7 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,6 +5,7 @@ ############################################################################ import pytest +import unittest import src.common as c @@ -279,3 +280,96 @@ def test_empty(self): @pytest.mark.parametrize("element_list, sep, expected", [([1], ",", "1"), ([1, 0, 2], ":", "1:0:2")]) def test_not_empty(self, element_list, sep, expected): assert c.list_to_str(element_list, sep) == expected + + +class TestRangeOperations(unittest.TestCase): + def test_overlaps(self): + self.assertTrue(c.overlaps((100, 200), (150, 250))) + self.assertFalse(c.overlaps((100, 200), (201, 300))) + self.assertTrue(c.overlaps((100, 200), (200, 300))) + + def test_contains(self): + self.assertTrue(c.contains((100, 300), (150, 250))) + self.assertFalse(c.contains((150, 250), (100, 300))) + self.assertTrue(c.contains((100, 300), (100, 300))) + + def test_left_of(self): + self.assertTrue(c.left_of((100, 200), (201, 300))) + self.assertFalse(c.left_of((100, 200), (200, 300))) + + def test_interval_len(self): + self.assertEqual(c.interval_len((100, 200)), 101) + self.assertEqual(c.interval_len((100, 100)), 1) + + def test_intervals_total_length(self): + ranges = [(100, 200), (300, 400), (500, 600)] + expected = 101 + 101 + 101 + self.assertEqual(c.intervals_total_length(ranges), expected) + + def test_max_range(self): + range1 = (100, 200) + range2 = (150, 250) + result = c.max_range(range1, range2) + self.assertEqual(result, (100, 250)) + + def test_intersection_len(self): + self.assertEqual(c.intersection_len((100, 200), (150, 250)), 51) + self.assertEqual(c.intersection_len((100, 200), (300, 400)), 0) + + +class TestJunctions(unittest.TestCase): + def test_junctions_from_blocks(self): + blocks = [(100, 200), (300, 400), (500, 600)] + junctions = c.junctions_from_blocks(blocks) + expected = [(201, 299), (401, 499)] + self.assertEqual(junctions, expected) + + def test_junctions_from_single_exon(self): + blocks = [(100, 200)] + junctions = c.junctions_from_blocks(blocks) + self.assertEqual(junctions, []) + + def test_get_exons_from_junctions(self): + region = (100, 600) + introns = [(201, 299), (401, 499)] + exons = c.get_exons(region, introns) + expected = [(100, 200), (300, 400), (500, 600)] + self.assertEqual(exons, expected) + + +class TestProfileFunctions(unittest.TestCase): + def test_difference_in_present_features(self): + profile1 = [1, 1, 
0, 1, -1] + profile2 = [1, -1, 0, 1, -1] + diff = c.difference_in_present_features(profile1, profile2) + self.assertEqual(diff, 1) # Only position 1 differs + + def test_has_overlapping_features(self): + profile1 = [1, 0, 1, 0] + profile2 = [0, 1, 1, 0] + self.assertTrue(c.has_overlapping_features(profile1, profile2)) + + profile3 = [1, 0, 0, 0] + profile4 = [0, 1, 1, 0] + self.assertFalse(c.has_overlapping_features(profile3, profile4)) + + +class TestUtilityFunctions(unittest.TestCase): + def test_get_first_best_empty_list(self): + result = c.get_first_best_from_sorted([]) + self.assertEqual(result, []) + + def test_rindex(self): + lst = [1, 2, 3, 2, 1] + self.assertEqual(c.rindex(lst, 2), 3) + self.assertEqual(c.rindex(lst, 1), 4) + + def test_argmin(self): + lst = [5, 2, 8, 1, 9] + self.assertEqual(c.argmin(lst), 3) + + def test_reverse_complement(self): + seq = "ATCG" + self.assertEqual(c.reverse_complement(seq), "CGAT") + seq = "AAAA" + self.assertEqual(c.reverse_complement(seq), "TTTT") \ No newline at end of file diff --git a/tests/test_common_utilities.py b/tests/test_common_utilities.py new file mode 100644 index 00000000..bec5ed20 --- /dev/null +++ b/tests/test_common_utilities.py @@ -0,0 +1,33 @@ + +############################################################################ +# Copyright (c) 2022-2024 University of Helsinki +# Copyright (c) 2019-2022 Saint Petersburg State University +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + +import unittest +import os +import tempfile +from src.common import * + + +class TestAdditionalCommon(unittest.TestCase): + def test_proper_plural_form(self): + self.assertEqual(proper_plural_form("read", 1), "1 read") + self.assertEqual(proper_plural_form("read", 2), "2 reads") + + def test_rreplace(self): + self.assertEqual(rreplace("test_string", "string", "replaced"), "test_replaced") + self.assertEqual(rreplace("test_string_string", "string", "replaced"), "test_string_replaced") + self.assertEqual(rreplace("no_match", "other", "replaced"), "no_match") + + def test_list_to_str(self): + self.assertEqual(list_to_str([1, 2, 3]), "1,2,3") + self.assertEqual(list_to_str(["a", "b", "c"]), "a,b,c") + self.assertEqual(list_to_str([]), ".") + + def test_get_path_to_program(self): + # Test existing program (assuming 'python' exists) + path = get_path_to_program("python") + self.assertTrue(os.path.exists(path)) diff --git a/tests/test_enum_stats.py b/tests/test_enum_stats.py new file mode 100644 index 00000000..0acee5f9 --- /dev/null +++ b/tests/test_enum_stats.py @@ -0,0 +1,73 @@ +############################################################################ +# Copyright (c) 2022-2024 University of Helsinki +# Copyright (c) 2019-2022 Saint Petersburg State University +# # All Rights Reserved +# See file LICENSE for details. 
+############################################################################ + +import unittest +import tempfile +import os +from src.stats import EnumStats +from enum import Enum + + +class TestEnum(Enum): + value1 = 1 + value2 = 2 + value3 = 3 + + +class TestEnumStats(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.test_file = os.path.join(self.temp_dir, "test_stats.txt") + + def tearDown(self): + if os.path.exists(self.temp_dir): + import shutil + shutil.rmtree(self.temp_dir) + + def test_add_single_value(self): + stats = EnumStats() + stats.add(TestEnum.value1) + self.assertEqual(stats.stats_dict[TestEnum.value1], 1) + + def test_add_multiple_values(self): + stats = EnumStats() + stats.add(TestEnum.value1) + stats.add(TestEnum.value1) + stats.add(TestEnum.value2) + self.assertEqual(stats.stats_dict[TestEnum.value1], 2) + self.assertEqual(stats.stats_dict[TestEnum.value2], 1) + + def test_add_with_count(self): + stats = EnumStats() + stats.add(TestEnum.value1, 5) + self.assertEqual(stats.stats_dict[TestEnum.value1], 5) + + def test_merge(self): + stats1 = EnumStats() + stats1.add(TestEnum.value1, 3) + stats1.add(TestEnum.value2, 2) + + stats2 = EnumStats() + stats2.add(TestEnum.value1, 2) + stats2.add(TestEnum.value3, 1) + + stats1.merge(stats2) + self.assertEqual(stats1.stats_dict[TestEnum.value1], 5) + self.assertEqual(stats1.stats_dict[TestEnum.value2], 2) + self.assertEqual(stats1.stats_dict[TestEnum.value3], 1) + + def test_dump_and_load(self): + stats = EnumStats() + stats.add(TestEnum.value1, 3) + stats.add(TestEnum.value2, 5) + + stats.dump(self.test_file) + self.assertTrue(os.path.exists(self.test_file)) + + loaded_stats = EnumStats(self.test_file) + self.assertEqual(loaded_stats.stats_dict[TestEnum.value1], 3) + self.assertEqual(loaded_stats.stats_dict[TestEnum.value2], 5) diff --git a/tests/test_file_naming.py b/tests/test_file_naming.py new file mode 100644 index 00000000..528d9533 --- /dev/null +++ b/tests/test_file_naming.py @@ -0,0 +1,39 @@ +############################################################################ +# Copyright (c) 2022-2024 University of Helsinki +# Copyright (c) 2019-2022 Saint Petersburg State University +# # All Rights Reserved +# See file LICENSE for details. 
+############################################################################ + +import unittest +from src.file_naming import * + + +class TestFileNaming(unittest.TestCase): + def test_saves_file_name(self): + base_name = "test_output" + chr_id = "chr1" + result = saves_file_name(base_name, chr_id) + self.assertIn(base_name, result) + + def test_multimappers_file_name(self): + base_name = "test_output" + chr_id = "chr1" + result = multimappers_file_name(base_name, chr_id) + self.assertIn(base_name, result) + self.assertIn("multimappers", result) + + def test_filtered_reads_file_name(self): + base_name = "test_output" + chr_id = "chr1" + result = filtered_reads_file_name(base_name, chr_id) + self.assertIn(base_name, result) + self.assertIn("filtered", result) + + def test_chromosome_id_sanitization(self): + # Test that special characters in chromosome IDs are handled + chr_id = "chr1:1000-2000" + result = saves_file_name("test", chr_id) + # Should not contain colons or dashes in unexpected places + self.assertIsInstance(result, str) + self.assertTrue(len(result) > 0) diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py new file mode 100644 index 00000000..bc782f80 --- /dev/null +++ b/tests/test_file_utils.py @@ -0,0 +1,62 @@ + +############################################################################ +# Copyright (c) 2022-2024 University of Helsinki +# Copyright (c) 2019-2022 Saint Petersburg State University +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + +import unittest +import os +import shutil +import tempfile +from src.file_utils import * + + +class TestFileUtils(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + self.chr_ids = ["chr1", "chr2", "chr3"] + + def tearDown(self): + if os.path.exists(self.test_dir): + shutil.rmtree(self.test_dir) + + def test_merge_file_list(self): + fname = "test_file.txt" + label = "file" + expected = ["test_file_chr1.txt", "test_file_chr2.txt", "test_file_chr3.txt"] + result = merge_file_list(fname, label, self.chr_ids) + self.assertEqual(result, expected) + + def test_merge_files(self): + # Create test files + test_files = [] + contents = ["#header\nchr1_content\n", "#header\nchr2_content\n", "#header\nchr3_content\n"] + for i, chr_id in enumerate(self.chr_ids): + fname = os.path.join(self.test_dir, f"test_file_{chr_id}.txt") + with open(fname, 'w') as f: + f.write(contents[i]) + test_files.append(fname) + + # Test merging with header + merged_file = os.path.join(self.test_dir, "test_file.txt") + with open(merged_file, 'w') as f: + merge_files(merged_file, "file", self.chr_ids, f, copy_header=True) + + with open(merged_file) as f: + content = f.read() + expected = "#header\nchr1_content\nchr2_content\nchr3_content\n" + self.assertEqual(content, expected) + + def test_normalize_path(self): + config_path = "/path/to/config/config.txt" + rel_path = "data/file.txt" + abs_path = "/absolute/path/file.txt" + + # Test relative path + expected_rel = "/path/to/config/data/file.txt" + self.assertEqual(normalize_path(config_path, rel_path), expected_rel) + + # Test absolute path + self.assertEqual(normalize_path(config_path, abs_path), abs_path) diff --git a/tests/test_gene_info.py b/tests/test_gene_info.py index 8674299c..80615faa 100644 --- a/tests/test_gene_info.py +++ b/tests/test_gene_info.py @@ -1,13 +1,14 @@ - ############################################################################ # Copyright (c) 2022-2024 University of 
Helsinki -# All Rights Reserved +# Copyright (c) 2019-2022 Saint Petersburg State University +# # All Rights Reserved # See file LICENSE for details. ############################################################################ -import gffutils import os -from src.gene_info import GeneInfo +import pytest +import gffutils +from src.gene_info import * class TestGeneInfo: @@ -41,11 +42,95 @@ def test_exon_profiles(self): assert gene_info.exon_profiles.features[0] == (1000, 1100) assert gene_info.exon_profiles.features[-1] == (9500, 10000) assert len(gene_info.exon_profiles.features) == 11 - assert gene_info.exon_profiles.profiles["ENSMUST00000001712.7"] == [1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1] + assert gene_info.exon_profiles.profiles["ENSMUST00000001712.7"] == [1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1] assert gene_info.exon_profiles.profiles["ENSMUST00000001714.7"] == [-2, -2, -2, -2, -2, -2, -2, 1, -2, -2, -2] assert gene_info.split_exon_profiles.features[2] == (2101, 2200) assert gene_info.split_exon_profiles.features[-1] == (9500, 10000) assert len(gene_info.split_exon_profiles.features) == 11 - assert gene_info.split_exon_profiles.profiles["ENSMUST00000001712.7"] == [1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1] - assert gene_info.split_exon_profiles.profiles["ENSMUST00000001714.7"] == [-2, -2, -2, -2, -2, -2, -2, 1, -2, -2, -2] \ No newline at end of file + assert gene_info.split_exon_profiles.profiles["ENSMUST00000001712.7"] == [1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1] + assert gene_info.split_exon_profiles.profiles["ENSMUST00000001714.7"] == [-2, -2, -2, -2, -2, -2, -2, 1, -2, -2, + -2] + + def test_from_region(self): + chr_id = "chr10" + start, end = 1000, 2000 + gene_info = GeneInfo.from_region(chr_id, start, end) + + assert gene_info.chr_id == chr_id + assert gene_info.start == start + assert gene_info.end == end + assert gene_info.db is None + assert gene_info.gene_db_list == [] + assert gene_info.empty() + + def test_from_models(self): + from src.gene_info import TranscriptModel + + # Create a simple transcript model + transcript_model = TranscriptModel("chr1", "+", "test_transcript", "test_gene", [(100, 200), (300, 400)], TranscriptModelType.known) + + gene_info = GeneInfo.from_models([transcript_model]) + + assert gene_info.chr_id == "chr1" + assert gene_info.start == 100 + assert gene_info.end == 400 + assert gene_info.gene_id_map["test_transcript"] == "test_gene" + assert gene_info.isoform_strands["test_transcript"] == "+" + assert gene_info.all_isoforms_exons["test_transcript"] == [(100, 200), (300, 400)] + + def test_serialization(self): + import io + from src.serialization import write_int, write_string + + # Create test gene info + gene_info = GeneInfo([self.gene_db], self.gffutils_db) + + # Serialize + outfile = io.BytesIO() + gene_info.serialize(outfile) + outfile.seek(0) + + # Deserialize + new_gene_info = GeneInfo.deserialize(outfile, self.gffutils_db) + + # Check key attributes + assert new_gene_info.chr_id == gene_info.chr_id + assert new_gene_info.start == gene_info.start + assert new_gene_info.end == gene_info.end + assert new_gene_info.delta == gene_info.delta + + def test_get_gene_regions(self): + gene_info = GeneInfo([self.gene_db], self.gffutils_db) + regions = gene_info.get_gene_regions() + + assert isinstance(regions, dict) + assert self.gene_db.id in regions + assert regions[self.gene_db.id] == (self.gene_db.start, self.gene_db.end) + + def test_empty(self): + gene_info = GeneInfo([self.gene_db], self.gffutils_db) + assert not gene_info.empty() + + # Test empty gene 
info + empty_gene_info = GeneInfo.from_region("chr1", 1, 100) + assert empty_gene_info.empty() + + def test_feature_properties(self): + gene_info = GeneInfo([self.gene_db], self.gffutils_db) + + # Test exon properties + exon_props = gene_info.exon_property_map + assert len(exon_props) == len(gene_info.exon_profiles.features) + for prop in exon_props: + assert isinstance(prop, FeatureInfo) + assert prop.chr_id == gene_info.chr_id + assert len(prop.gene_ids) > 0 + + # Test intron properties + intron_props = gene_info.intron_property_map + assert len(intron_props) == len(gene_info.intron_profiles.features) + for prop in intron_props: + assert isinstance(prop, FeatureInfo) + assert prop.chr_id == gene_info.chr_id + assert len(prop.gene_ids) > 0 diff --git a/tests/test_id_distributor.py b/tests/test_id_distributor.py new file mode 100644 index 00000000..ec1a772b --- /dev/null +++ b/tests/test_id_distributor.py @@ -0,0 +1,18 @@ +############################################################################ +# Copyright (c) 2022-2024 University of Helsinki +# Copyright (c) 2019-2022 Saint Petersburg State University +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + +import unittest +from src.id_policy import * + + +class TestSimpleIDDistributor(unittest.TestCase): + def test_get_next_id(self): + distributor = SimpleIDDistributor() + self.assertEqual(distributor.increment(), 1) + self.assertEqual(distributor.increment(), 2) + self.assertEqual(distributor.increment(), 3) + diff --git a/tests/test_intron_graph.py b/tests/test_intron_graph.py new file mode 100644 index 00000000..48f9f192 --- /dev/null +++ b/tests/test_intron_graph.py @@ -0,0 +1,374 @@ +# ############################################################################ +# # Copyright (c) 2022-2024 University of Helsinki +# # Copyright (c) 2019-2022 Saint Petersburg State University +# # # All Rights Reserved +# # See file LICENSE for details. +# ############################################################################ +# +# import unittest +# from unittest.mock import Mock +# from src.intron_graph import * +# from src.gene_info import GeneInfo +# from src.isoform_assignment import ReadAssignment, ReadAssignmentType, IsoformMatch, PolyAInfo +# from src.common import junctions_from_blocks +# +# +# class MockParams: +# """Mock parameters object for testing""" +# def __init__(self): +# self.delta = 6 +# self.graph_clustering_distance = 10 +# self.min_novel_intron_count = 2 +# self.min_novel_isolated_intron_abs = 3 +# self.min_novel_isolated_intron_rel = 0.02 +# self.terminal_position_abs = 1 +# self.terminal_position_rel = 0.05 +# self.singleton_adjacent_cov = 100 +# +# +# def create_mock_gene_info(exons_list, chr_id="chr1"): +# """ +# Create a mock GeneInfo object with given exons. 
+# exons_list: list of tuples representing exon coordinates +# """ +# gene_info = Mock(spec=GeneInfo) +# gene_info.chr_id = chr_id +# gene_info.start = min(e[0] for e in exons_list) +# gene_info.end = max(e[1] for e in exons_list) +# gene_info.all_isoforms_introns = {} +# gene_info.all_isoforms_exons = {"transcript1": exons_list} +# +# # Create intron profiles +# introns = junctions_from_blocks(exons_list) +# gene_info.intron_profiles = Mock() +# gene_info.intron_profiles.features = introns +# +# return gene_info +# +# +# def create_read_assignment(read_id, exon_blocks, chr_id="chr1", strand="+", +# assignment_type=ReadAssignmentType.unique, polyA_found=False): +# """ +# Create a ReadAssignment object for testing. +# """ +# match = IsoformMatch(None, None, None) +# match.assigned_gene = "gene1" +# match.assigned_transcript = "transcript1" +# +# read_assignment = ReadAssignment(read_id, assignment_type, match) +# read_assignment.chr_id = chr_id +# read_assignment.strand = strand +# read_assignment.mapped_strand = strand +# read_assignment.exons = exon_blocks +# read_assignment.corrected_exons = exon_blocks +# read_assignment.corrected_introns = junctions_from_blocks(exon_blocks) +# read_assignment.genomic_region = (exon_blocks[0][0], exon_blocks[-1][1]) +# read_assignment.polyA_found = polyA_found +# read_assignment.polya_info = PolyAInfo(-1, -1, -1, -1) +# read_assignment.read_group = "test_group" +# read_assignment.mapping_quality = 60 +# +# return read_assignment +# +# +# class TestIntronCollector(unittest.TestCase): +# """Test the IntronCollector class""" +# +# def setUp(self): +# self.exons = [(100, 200), (300, 400), (500, 600)] +# self.gene_info = create_mock_gene_info(self.exons) +# self.params = MockParams() +# +# def test_collector(self): +# """Test IntronCollector initialization""" +# collector = IntronCollector(self.gene_info, self.params) +# self.assertEqual(collector.gene_info, self.gene_info) +# self.assertEqual(collector.delta, self.params.delta) +# self.assertIsNotNone(collector.known_introns) +# +# def test_collect_introns_simple(self): +# """Test collecting introns from reads matching known structure""" +# collector = IntronCollector(self.gene_info, self.params) +# +# # Create reads matching the gene structure +# read1 = create_read_assignment("read1", self.exons) +# read2 = create_read_assignment("read2", self.exons) +# +# read_assignments = [read1, read2] +# all_introns = collector.collect_introns(read_assignments) +# +# # Should have collected the known introns +# expected_introns = junctions_from_blocks(self.exons) +# for intron in expected_introns: +# self.assertIn(intron, all_introns) +# +# def test_collect_novel_introns(self): +# """Test collecting novel introns""" +# collector = IntronCollector(self.gene_info, self.params) +# +# # Create read with novel intron +# novel_exons = [(100, 200), (350, 400), (500, 600)] # Different middle exon +# read1 = create_read_assignment("read1", novel_exons) +# read2 = create_read_assignment("read2", novel_exons) +# read3 = create_read_assignment("read3", novel_exons) +# +# read_assignments = [read1, read2, read3] +# all_introns = collector.collect_introns(read_assignments) +# +# # Should have collected the novel intron +# novel_introns = junctions_from_blocks(novel_exons) +# for intron in novel_introns: +# self.assertIn(intron, all_introns) +# +# +# class TestIntronGraph(unittest.TestCase): +# """Test the IntronGraph class""" +# +# def setUp(self): +# self.exons = [(100, 200), (300, 400), (500, 600)] +# self.gene_info = 
create_mock_gene_info(self.exons) +# self.params = MockParams() +# +# def test_initialization(self): +# """Test IntronGraph initialization""" +# read_assignments = [] +# graph = IntronGraph(read_assignments, self.gene_info, self.params) +# +# self.assertEqual(graph.gene_info, self.gene_info) +# self.assertEqual(graph.params, self.params) +# self.assertIsNotNone(graph.intron_collector) +# +# def test_simple_linear_graph(self): +# """Test graph construction for simple linear transcript""" +# # Create reads matching the gene structure +# reads = [ +# create_read_assignment("read1", self.exons), +# create_read_assignment("read2", self.exons), +# create_read_assignment("read3", self.exons), +# ] +# +# graph = IntronGraph(reads, self.gene_info, self.params) +# graph.construct() +# +# # Graph should have vertices for each intron boundary +# introns = junctions_from_blocks(self.exons) +# self.assertGreater(len(graph.outgoing_edges), 0) +# +# def test_graph_with_alternative_splicing(self): +# """Test graph construction with alternative splicing""" +# # Create reads with different exon structures +# variant1_exons = [(100, 200), (300, 400), (500, 600)] +# variant2_exons = [(100, 200), (350, 450), (500, 600)] # Alternative middle exon +# +# reads = [ +# create_read_assignment("read1", variant1_exons), +# create_read_assignment("read2", variant1_exons), +# create_read_assignment("read3", variant2_exons), +# create_read_assignment("read4", variant2_exons), +# ] +# +# graph = IntronGraph(reads, self.gene_info, self.params) +# graph.construct() +# +# # Graph should have branches for alternative splicing +# self.assertGreater(len(graph.outgoing_edges), 0) +# +# def test_monoexonic_transcript(self): +# """Test graph with monoexonic (single exon) transcript""" +# mono_exons = [(100, 500)] +# mono_gene_info = create_mock_gene_info(mono_exons) +# +# reads = [ +# create_read_assignment("read1", mono_exons), +# create_read_assignment("read2", mono_exons), +# ] +# +# graph = IntronGraph(reads, mono_gene_info, self.params) +# graph.construct() +# +# # Monoexonic should have minimal graph structure +# # No introns, so graph should be simple or empty +# self.assertIsNotNone(graph.outgoing_edges) +# +# def test_add_edge(self): +# """Test adding edges to the graph""" +# reads = [create_read_assignment("read1", self.exons)] +# graph = IntronGraph(reads, self.gene_info, self.params) +# +# # Add an edge +# v1 = (200, 'R') # Right end of first exon +# v2 = (300, 'L') # Left end of second exon +# graph.add_edge(v1, v2) +# +# self.assertIn(v1, graph.outgoing_edges) +# self.assertIn(v2, graph.outgoing_edges[v1]) +# self.assertEqual(graph.edge_weights[(v1, v2)], 1) +# +# def test_get_outgoing_edges(self): +# """Test retrieving outgoing edges""" +# reads = [ +# create_read_assignment("read1", self.exons), +# create_read_assignment("read2", self.exons), +# ] +# +# graph = IntronGraph(reads, self.gene_info, self.params) +# graph.construct() +# +# # Find a vertex with outgoing edges +# if graph.outgoing_edges: +# vertex = list(graph.outgoing_edges.keys())[0] +# outgoing = graph.get_outgoing(vertex) +# self.assertIsInstance(outgoing, list) +# +# def test_is_isolated(self): +# """Test checking if a vertex is isolated""" +# reads = [create_read_assignment("read1", self.exons)] +# graph = IntronGraph(reads, self.gene_info, self.params) +# graph.construct() +# +# # Create an isolated vertex +# isolated_vertex = (9999, 'L') +# result = graph.is_isolated(isolated_vertex) +# # An isolated vertex has no connections +# 
self.assertTrue(result or isolated_vertex not in graph.outgoing_edges) +# +# +# class TestTerminalVertexFunctions(unittest.TestCase): +# """Test terminal vertex helper functions""" +# +# def test_is_terminal_vertex(self): +# """Test identifying terminal vertices""" +# self.assertTrue(is_terminal_vertex(VERTEX_polya)) +# self.assertTrue(is_terminal_vertex(VERTEX_read_end)) +# self.assertFalse(is_terminal_vertex((100, 'L'))) +# +# def test_is_starting_vertex(self): +# """Test identifying starting vertices""" +# self.assertTrue(is_starting_vertex(VERTEX_polyt)) +# self.assertTrue(is_starting_vertex(VERTEX_read_start)) +# self.assertFalse(is_starting_vertex((100, 'R'))) +# +# +# class TestComplexGraphScenarios(unittest.TestCase): +# """Test complex graph construction scenarios""" +# +# def setUp(self): +# self.params = MockParams() +# +# def test_multiple_isoforms(self): +# """Test graph with multiple isoforms from same gene""" +# # Create gene with two isoforms +# isoform1_exons = [(100, 200), (300, 400), (500, 600), (700, 800)] +# isoform2_exons = [(100, 200), (300, 400), (700, 800)] # Skips middle exon +# +# gene_info = create_mock_gene_info(isoform1_exons) +# +# reads = [ +# create_read_assignment("read1", isoform1_exons), +# create_read_assignment("read2", isoform1_exons), +# create_read_assignment("read3", isoform2_exons), +# create_read_assignment("read4", isoform2_exons), +# ] +# +# graph = IntronGraph(reads, gene_info, self.params) +# graph.construct() +# +# # Should have vertices for both isoforms +# self.assertGreater(len(graph.outgoing_edges), 0) +# +# def test_reads_with_polya(self): +# """Test graph construction with polyA-containing reads""" +# exons = [(100, 200), (300, 400), (500, 600)] +# gene_info = create_mock_gene_info(exons) +# +# reads = [ +# create_read_assignment("read1", exons, polyA_found=True), +# create_read_assignment("read2", exons, polyA_found=True), +# create_read_assignment("read3", exons, polyA_found=False), +# ] +# +# graph = IntronGraph(reads, gene_info, self.params) +# graph.construct() +# +# # PolyA reads should influence terminal positions +# self.assertIsNotNone(graph.terminal_known_positions) +# +# def test_partial_read_coverage(self): +# """Test reads that don't cover full transcript""" +# full_exons = [(100, 200), (300, 400), (500, 600), (700, 800)] +# gene_info = create_mock_gene_info(full_exons) +# +# # Reads covering different parts +# reads = [ +# create_read_assignment("read1", [(100, 200), (300, 400)]), # First half +# create_read_assignment("read2", [(500, 600), (700, 800)]), # Second half +# create_read_assignment("read3", full_exons), # Full length +# ] +# +# graph = IntronGraph(reads, gene_info, self.params) +# graph.construct() +# +# # Should handle partial coverage +# self.assertIsNotNone(graph.outgoing_edges) +# +# def test_noisy_introns(self): +# """Test graph with low-coverage (noisy) introns""" +# main_exons = [(100, 200), (300, 400), (500, 600)] +# gene_info = create_mock_gene_info(main_exons) +# +# # Most reads follow main pattern +# reads = [create_read_assignment(f"read{i}", main_exons) for i in range(10)] +# +# # One read has a noisy intron +# noisy_exons = [(100, 200), (320, 380), (500, 600)] +# reads.append(create_read_assignment("noisy_read", noisy_exons)) +# +# graph = IntronGraph(reads, gene_info, self.params) +# graph.construct() +# +# # Noisy intron should be filtered or have low weight +# self.assertIsNotNone(graph.edge_weights) +# +# +# class TestGraphSimplification(unittest.TestCase): +# """Test graph simplification 
methods""" +# +# def setUp(self): +# self.params = MockParams() +# self.exons = [(100, 200), (300, 400), (500, 600)] +# self.gene_info = create_mock_gene_info(self.exons) +# +# def test_simplify_graph(self): +# """Test graph simplification""" +# reads = [create_read_assignment(f"read{i}", self.exons) for i in range(5)] +# +# graph = IntronGraph(reads, self.gene_info, self.params) +# graph.construct() +# graph.simplify() +# +# # After simplification, graph should still be valid +# self.assertIsNotNone(graph.outgoing_edges) +# +# def test_remove_low_coverage_edges(self): +# """Test removal of low-coverage edges""" +# main_exons = [(100, 200), (300, 400), (500, 600)] +# +# # High coverage for main path +# reads = [create_read_assignment(f"read{i}", main_exons) for i in range(20)] +# +# # Low coverage alternative +# alt_exons = [(100, 200), (350, 450), (500, 600)] +# reads.append(create_read_assignment("alt_read", alt_exons)) +# +# graph = IntronGraph(reads, self.gene_info, self.params) +# graph.construct() +# +# # Low coverage edges may be removed during simplification +# initial_edge_count = len(graph.edge_weights) +# graph.simplify() +# # Simplification may reduce edges +# self.assertLessEqual(len(graph.edge_weights), initial_edge_count + 10) +# +# +# if __name__ == '__main__': +# unittest.main() diff --git a/tests/test_iso_quant_mode.py b/tests/test_iso_quant_mode.py new file mode 100644 index 00000000..617e616f --- /dev/null +++ b/tests/test_iso_quant_mode.py @@ -0,0 +1,52 @@ +############################################################################ +# Copyright (c) 2022-2024 University of Helsinki +# Copyright (c) 2019-2022 Saint Petersburg State University +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + +import unittest +from src.modes import * + + +class TestModes(unittest.TestCase): + def test_needs_barcode_calling(self): + # Test modes that need barcode calling + self.assertTrue(IsoQuantMode.tenX_v3.needs_barcode_calling()) + self.assertTrue(IsoQuantMode.curio.needs_barcode_calling()) + self.assertTrue(IsoQuantMode.stereoseq.needs_barcode_calling()) + + # Test mode that doesn't need barcode calling + self.assertFalse(IsoQuantMode.bulk.needs_barcode_calling()) + + def test_needs_pcr_deduplication(self): + # Test modes that need PCR deduplication + self.assertTrue(IsoQuantMode.tenX_v3.needs_pcr_deduplication()) + self.assertTrue(IsoQuantMode.visium_hd.needs_pcr_deduplication()) + self.assertTrue(IsoQuantMode.stereoseq.needs_pcr_deduplication()) + + # Test mode that doesn't need PCR deduplication + self.assertFalse(IsoQuantMode.bulk.needs_pcr_deduplication()) + + def test_produces_new_fasta(self): + # Only stereoseq mode produces new fasta + self.assertTrue(IsoQuantMode.stereoseq.produces_new_fasta()) + + # Test other modes + self.assertFalse(IsoQuantMode.bulk.produces_new_fasta()) + self.assertFalse(IsoQuantMode.tenX_v3.produces_new_fasta()) + self.assertFalse(IsoQuantMode.curio.produces_new_fasta()) + + def test_needs_barcode_iterator(self): + # Test modes that need barcode iterator + self.assertTrue(IsoQuantMode.stereoseq.needs_barcode_iterator()) + self.assertTrue(IsoQuantMode.stereoseq_nosplit.needs_barcode_iterator()) + + # Test modes that don't need barcode iterator + self.assertFalse(IsoQuantMode.bulk.needs_barcode_iterator()) + self.assertFalse(IsoQuantMode.tenX_v3.needs_barcode_iterator()) + + def test_enforces_single_thread(self): + # All modes should return False for 
enforces_single_thread + for mode in IsoQuantMode: + self.assertFalse(mode.enforces_single_thread()) \ No newline at end of file diff --git a/tests/test_junction_comparator.py b/tests/test_junction_comparator.py new file mode 100644 index 00000000..8182b804 --- /dev/null +++ b/tests/test_junction_comparator.py @@ -0,0 +1,33 @@ +############################################################################ +# Copyright (c) 2022-2024 University of Helsinki +# Copyright (c) 2019-2022 Saint Petersburg State University +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + +import unittest +from src.common import junctions_from_blocks +from src.junction_comparator import * + + +class TestJunctionComparator(unittest.TestCase): + def test_overlaps_at_least(self): + junction1 = (100, 200) + junction2 = (150, 250) + self.assertTrue(overlaps_at_least(junction1, junction2, 10)) + + def test_equal_ranges_with_delta(self): + junction1 = (100, 200) + junction2 = (102, 198) + self.assertTrue(equal_ranges(junction1, junction2, delta=5)) + self.assertFalse(equal_ranges(junction1, junction2, delta=1)) + + def test_junctions_from_blocks(self): + blocks = [(100, 200), (300, 400), (500, 600)] + junctions = junctions_from_blocks(blocks) + self.assertEqual(junctions, [(201, 299), (401, 499)]) + + def test_junctions_from_single_block(self): + blocks = [(100, 200)] + junctions = junctions_from_blocks(blocks) + self.assertEqual(junctions, []) diff --git a/tests/test_kmer_indexer.py b/tests/test_kmer_indexer.py new file mode 100644 index 00000000..bed72aed --- /dev/null +++ b/tests/test_kmer_indexer.py @@ -0,0 +1,263 @@ +############################################################################ +# Copyright (c) 2025 University of Helsinki +# All Rights Reserved +# See file LICENSE for details. 
+############################################################################ + +import pytest +from src.barcode_calling.kmer_indexer import KmerIndexer, ArrayKmerIndexer, Array2BitKmerIndexer +from src.barcode_calling.common import str_to_2bit + + +class TestKmerIndexer: + """Test basic dictionary-based KmerIndexer.""" + + def test_init(self): + """Test indexer initialization.""" + barcodes = ["ACTG", "TGCA", "GGGG"] + indexer = KmerIndexer(barcodes, kmer_size=3) + + assert len(indexer.seq_list) == 3 + assert indexer.k == 3 + assert not indexer.empty() + + def test_empty(self): + """Test empty index detection.""" + indexer = KmerIndexer([], kmer_size=3) + assert indexer.empty() + + def test_get_kmers(self): + """Test k-mer generation.""" + indexer = KmerIndexer([], kmer_size=3) + kmers = list(indexer._get_kmers("ACTGAC")) + + assert len(kmers) == 4 + assert kmers == ["ACT", "CTG", "TGA", "GAC"] + + def test_get_kmers_short_sequence(self): + """Test k-mer generation for short sequence.""" + indexer = KmerIndexer([], kmer_size=5) + kmers = list(indexer._get_kmers("ACT")) + + assert len(kmers) == 0 + + def test_exact_match(self): + """Test finding exact match.""" + barcodes = ["ACTGACTG", "TGCATGCA", "GGGGGGGG"] + indexer = KmerIndexer(barcodes, kmer_size=4) + + results = indexer.get_occurrences("ACTGACTG") + + assert len(results) > 0 + assert results[0][0] == "ACTGACTG" + assert results[0][1] > 0 # Shared k-mer count + + def test_similar_match(self): + """Test finding similar sequence.""" + barcodes = ["ACTGACTG", "TGCATGCA"] + indexer = KmerIndexer(barcodes, kmer_size=4) + + # Query with 1 mismatch + results = indexer.get_occurrences("ACTGACTT") + + assert len(results) > 0 + # Should find ACTGACTG as closest match + assert results[0][0] == "ACTGACTG" + + def test_min_kmers_filter(self): + """Test minimum k-mer threshold.""" + barcodes = ["ACTGACTG", "TGCATGCA"] + indexer = KmerIndexer(barcodes, kmer_size=4) + + # Very different sequence + results = indexer.get_occurrences("CCCCCCCC", min_kmers=3) + + # Should have fewer results with higher threshold + assert len(results) <= 2 + + def test_max_hits(self): + """Test maximum hits limit.""" + barcodes = ["ACTG", "ACTC", "ACTA", "ACTT"] + indexer = KmerIndexer(barcodes, kmer_size=3) + + results = indexer.get_occurrences("ACTG", max_hits=2) + + assert len(results) <= 2 + + def test_ignore_equal(self): + """Test ignoring exact matches.""" + barcodes = ["ACTG", "TGCA"] + indexer = KmerIndexer(barcodes, kmer_size=3) + + results = indexer.get_occurrences("ACTG", ignore_equal=True) + + # Should not return the exact match + for seq, _, _ in results: + assert seq != "ACTG" + + def test_append(self): + """Test adding barcode dynamically.""" + barcodes = ["ACTG"] + indexer = KmerIndexer(barcodes, kmer_size=3) + + indexer.append("TGCA") + + assert len(indexer.seq_list) == 2 + results = indexer.get_occurrences("TGCA") + assert any(seq == "TGCA" for seq, _, _ in results) + + def test_result_format(self): + """Test result tuple format.""" + barcodes = ["ACTGACTG"] + indexer = KmerIndexer(barcodes, kmer_size=4) + + results = indexer.get_occurrences("ACTGACTG") + + assert len(results) > 0 + # Result should be (sequence, count, positions) + seq, count, positions = results[0] + assert isinstance(seq, str) + assert isinstance(count, int) + assert isinstance(positions, list) + + +class TestArrayKmerIndexer: + """Test optimized array-based KmerIndexer.""" + + def test_init(self): + """Test indexer initialization.""" + barcodes = ["ACTG", "TGCA"] + indexer = 
ArrayKmerIndexer(barcodes, kmer_size=3) + + assert len(indexer.seq_list) == 2 + assert indexer.k == 3 + assert len(indexer.index) == 64 # 4^3 + + def test_binary_encoding(self): + """Test nucleotide to binary conversion.""" + assert ArrayKmerIndexer.NUCL2BIN['A'] == 0 + assert ArrayKmerIndexer.NUCL2BIN['C'] == 1 + assert ArrayKmerIndexer.NUCL2BIN['G'] == 2 + assert ArrayKmerIndexer.NUCL2BIN['T'] == 3 + + def test_get_kmer_indexes(self): + """Test k-mer index generation.""" + indexer = ArrayKmerIndexer([], kmer_size=2) + # "AC" = 00 01 = 1 + # "CT" = 01 11 = 7 + kmer_idxs = list(indexer._get_kmer_indexes("ACT")) + + assert len(kmer_idxs) == 2 + assert kmer_idxs[0] == 1 # AC + assert kmer_idxs[1] == 7 # CT + + def test_exact_match(self): + """Test finding exact match.""" + barcodes = ["ACTGACTG", "TGCATGCA"] + indexer = ArrayKmerIndexer(barcodes, kmer_size=4) + + results = indexer.get_occurrences("ACTGACTG") + + assert len(results) > 0 + assert results[0][0] == "ACTGACTG" + + def test_consistency_with_basic(self): + """Test that results match basic KmerIndexer.""" + barcodes = ["ACTGACTG", "TGCATGCA", "GGGGGGGG"] + basic = KmerIndexer(barcodes, kmer_size=4) + array = ArrayKmerIndexer(barcodes, kmer_size=4) + + query = "ACTGACTT" + basic_results = basic.get_occurrences(query) + array_results = array.get_occurrences(query) + + # Should find same sequences (order/counts might differ slightly) + basic_seqs = set(seq for seq, _, _ in basic_results) + array_seqs = set(seq for seq, _, _ in array_results) + assert basic_seqs == array_seqs + + +class TestArray2BitKmerIndexer: + """Test memory-efficient 2-bit KmerIndexer.""" + + def test_init(self): + """Test indexer initialization with 2-bit sequences.""" + barcodes = ["ACTGACTGACTGACTGACTGACTGA"] # 25-mer + bin_seqs = [str_to_2bit(b) for b in barcodes] + indexer = Array2BitKmerIndexer(bin_seqs, kmer_size=6, seq_len=25) + + assert indexer.total_sequences == 1 + assert indexer.seq_len == 25 + assert indexer.k == 6 + assert not indexer.empty() + + def test_empty(self): + """Test empty index detection.""" + indexer = Array2BitKmerIndexer([], kmer_size=6, seq_len=25) + assert indexer.empty() + + def test_get_kmer_bin_indexes(self): + """Test k-mer extraction from 2-bit sequence.""" + # Create simple sequence + bin_seq = str_to_2bit("ACTGAA") + indexer = Array2BitKmerIndexer([], kmer_size=3, seq_len=6) + + kmer_idxs = list(indexer._get_kmer_bin_indexes(bin_seq)) + + assert len(kmer_idxs) == 4 # 6 - 3 + 1 + + def test_exact_match(self): + """Test finding exact match.""" + seq = "ACTGACTGACTGACTGACTGACTGA" + bin_seqs = [str_to_2bit(seq)] + indexer = Array2BitKmerIndexer(bin_seqs, kmer_size=8, seq_len=25) + + results = indexer.get_occurrences(seq) + + assert len(results) > 0 + assert results[0][0] == seq + + def test_similar_match(self): + """Test finding similar sequence.""" + barcode1 = "ACTGACTGACTGACTGACTGACTGA" + barcode2 = "TGCATGCATGCATGCATGCATGCAT" + bin_seqs = [str_to_2bit(b) for b in [barcode1, barcode2]] + indexer = Array2BitKmerIndexer(bin_seqs, kmer_size=8, seq_len=25) + + # Query with 1 mismatch from barcode1 + query = "ACTGACTGACTGACTGACTGACTGT" + results = indexer.get_occurrences(query) + + assert len(results) > 0 + # Should find barcode1 as closest + assert results[0][0] == barcode1 + + def test_max_hits(self): + """Test maximum hits limit.""" + barcodes = [ + "ACTGACTGACTGACTGACTGACTGA", + "ACTGACTGACTGACTGACTGACTGC", + "ACTGACTGACTGACTGACTGACTGG" + ] + bin_seqs = [str_to_2bit(b) for b in barcodes] + indexer = Array2BitKmerIndexer(bin_seqs, 
kmer_size=8, seq_len=25) + + results = indexer.get_occurrences("ACTGACTGACTGACTGACTGACTGA", max_hits=2) + + assert len(results) <= 2 + + def test_flat_index_structure(self): + """Test that index uses flat array structure.""" + barcodes = ["ACTGACTGACTGACTGACTGACTGA"] + bin_seqs = [str_to_2bit(b) for b in barcodes] + indexer = Array2BitKmerIndexer(bin_seqs, kmer_size=6, seq_len=25) + + # Index should be flat list + assert isinstance(indexer.index, list) + # Index ranges should map to flat index + assert isinstance(indexer.index_ranges, list) + assert len(indexer.index_ranges) == 4**6 + 1 # 4^k + 1 + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/test_read_groups.py b/tests/test_read_groups.py new file mode 100644 index 00000000..913eefd4 --- /dev/null +++ b/tests/test_read_groups.py @@ -0,0 +1,145 @@ +############################################################################ +# Copyright (c) 2025 University of Helsinki +# All Rights Reserved +# See file LICENSE for details. +############################################################################ + +import pytest +import tempfile +import os +from src.read_groups import ( + AlignmentTagReadGrouper, + ReadIdSplitReadGrouper, + ReadTableGrouper, +) + + +# Mock alignment object +class MockAlignment: + def __init__(self, read_id="", tags=None): + self.query_name = read_id + if tags is None: + tags = {} + self._tags = tags + + def get_tag(self, tag_name): + return self._tags.get(tag_name) + + def has_tag(self, tag_name): + return tag_name in self._tags + + +class TestAlignmentTagReadGrouper: + """Test AlignmentTagReadGrouper class.""" + + def test_init_default_tag(self): + """Test AlignmentTagReadGrouper with default RG tag.""" + grouper = AlignmentTagReadGrouper() + assert grouper.tag == "RG" + + def test_init_custom_tag(self): + """Test AlignmentTagReadGrouper with custom tag.""" + grouper = AlignmentTagReadGrouper("CB") + assert grouper.tag == "CB" + + def test_get_group_with_tag(self): + """Test getting group when read has tag.""" + grouper = AlignmentTagReadGrouper("CB") + alignment = MockAlignment(tags={"CB": "ACTGACTG"}) + assert grouper.get_group_id(alignment, None) == "ACTGACTG" + + def test_get_group_missing_tag(self): + """Test getting group when read lacks tag.""" + grouper = AlignmentTagReadGrouper("CB") + alignment = MockAlignment() + result = grouper.get_group_id(alignment, None) + # Should return None + assert result is None + + +class TestReadIdSplitReadGrouper: + """Test ReadIdSplitReadGrouper class.""" + + def test_init(self): + """Test ReadIdSplitReadGrouper initialization.""" + grouper = ReadIdSplitReadGrouper("_") + assert grouper.delim == "_" + + def test_get_group_with_delimiter(self): + """Test getting group from read ID with delimiter.""" + grouper = ReadIdSplitReadGrouper("_") + + alignment1 = MockAlignment(read_id="read_001_groupA") + alignment2 = MockAlignment(read_id="read_002_groupB") + + assert grouper.get_group_id(alignment1) == "groupA" + assert grouper.get_group_id(alignment2) == "groupB" + + def test_get_group_no_delimiter(self): + """Test getting group from read ID without delimiter.""" + grouper = ReadIdSplitReadGrouper("_") + alignment = MockAlignment(read_id="read001") + + # Returns empty string if no delimiter found + result = grouper.get_group_id(alignment) + assert result == "" + + def test_get_group_multiple_delimiters(self): + """Test getting group with multiple delimiters.""" + grouper = ReadIdSplitReadGrouper("_") + alignment = 
MockAlignment(read_id="prefix_middle_suffix") + + # Should return last part after delimiter + assert grouper.get_group_id(alignment) == "suffix" + + +class TestReadTableGrouper: + """Test ReadTableGrouper class.""" + + def test_init_and_get_group(self): + """Test ReadTableGrouper initialization and file loading.""" + # Create temporary file with read-group mapping + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.tsv') as f: + f.write("read_001\tgroupA\n") + f.write("read_002\tgroupB\n") + f.write("read_003\tgroupA\n") + temp_file = f.name + + try: + grouper = ReadTableGrouper(temp_file, + read_id_column_index=0, + group_id_column_index=1, + delim='\t') + + alignment1 = MockAlignment(read_id="read_001") + alignment2 = MockAlignment(read_id="read_002") + alignment3 = MockAlignment(read_id="read_003") + + assert grouper.get_group_id(alignment1) == "groupA" + assert grouper.get_group_id(alignment2) == "groupB" + assert grouper.get_group_id(alignment3) == "groupA" + finally: + os.unlink(temp_file) + + def test_missing_read(self): + """Test getting group for read not in file.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.tsv') as f: + f.write("read_001\tgroupA\n") + temp_file = f.name + + try: + grouper = ReadTableGrouper(temp_file, + read_id_column_index=0, + group_id_column_index=1, + delim='\t') + + alignment = MockAlignment(read_id="read_999") + # Should return "NA" for missing reads + result = grouper.get_group_id(alignment) + assert result == "NA" + finally: + os.unlink(temp_file) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index a8b05c3c..477320ac 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,47 +1,46 @@ -import os +############################################################################ +# Copyright (c) 2022-2024 University of Helsinki +# Copyright (c) 2019-2022 Saint Petersburg State University +# # All Rights Reserved +# See file LICENSE for details. 
+############################################################################ import pytest - +import io from src import serialization - -filename = "ser_test_file" +from src.serialization import * -@pytest.fixture(scope="class") +@pytest.fixture def setup_class(request): - request.cls.filehandler = open(filename, "xb") - yield - os.remove(filename) + request.cls.filehandler = io.BytesIO() + def teardown(): + request.cls.filehandler.close() -@pytest.fixture(autouse=True) -def run_after_each_test(request): - request.cls.filehandler = open(filename, "r+b") - yield - request.cls.filehandler.close() + request.addfinalizer(teardown) @pytest.mark.usefixtures("setup_class") class TestSerialization: - @pytest.mark.parametrize( "func_wr, func_rd, value", [ - (serialization.write_int, serialization.read_int, 1), - (serialization.write_int, serialization.read_int, 0), - (serialization.write_int, serialization.read_int, 65536), - (serialization.write_short_int, serialization.read_short_int, 0), - (serialization.write_short_int, serialization.read_short_int, 65535), - (serialization.write_string, serialization.read_string, ""), - (serialization.write_string, serialization.read_string, "quite a long string dont you think"), - (serialization.write_string_or_none, serialization.read_string_or_none, None), - (serialization.write_string_or_none, serialization.read_string_or_none, "None"), - (serialization.write_int_neg, serialization.read_int_neg, -1), - (serialization.write_int_neg, serialization.read_int_neg, -65536), - (serialization.write_bool_array, serialization.read_bool_array, [False]), - (serialization.write_dict, serialization.read_dict, {'key1': 'value1', 'key2': 'value2'}), - (serialization.write_dict, serialization.read_dict, {'key1': 0, 'key2': (655536, 1)}), - (serialization.write_dict, serialization.read_dict, {}) + (write_int, read_int, 1), + (write_int, read_int, 0), + (write_int, read_int, 65536), + (write_short_int, read_short_int, 0), + (write_short_int, read_short_int, 65535), + (write_string, read_string, ""), + (write_string, read_string, "quite a long string dont you think"), + (write_string_or_none, read_string_or_none, None), + (write_string_or_none, read_string_or_none, "None"), + (write_int_neg, read_int_neg, -1), + (write_int_neg, read_int_neg, -65536), + (write_bool_array, read_bool_array, [False]), + (write_dict, read_dict, {'key1': 'value1', 'key2': 'value2'}), + (write_dict, read_dict, {'key1': 0, 'key2': (655536, 1)}), + (write_dict, read_dict, {}) ] ) def test_write_read_positive(self, value, func_wr, func_rd): @@ -54,21 +53,21 @@ def test_write_read_positive(self, value, func_wr, func_rd): @pytest.mark.parametrize( "func_wr, func_rd, value", [ - (serialization.write_int, serialization.read_int, [65539, 0, 1]), - (serialization.write_int, serialization.read_int, []), - (serialization.write_string_or_none, serialization.read_string_or_none, ["str", "", None]), - (serialization.write_string_or_none, serialization.read_string_or_none, []), - (serialization.write_string, serialization.read_string, ["str", "quite a long string dont you think", ""]), - (serialization.write_string, serialization.read_string, []), - (serialization.write_int_neg, serialization.read_int_neg, [-6, -33333, 0]), - (serialization.write_int_neg, serialization.read_int_neg, []) + (write_int, read_int, [65539, 0, 1]), + (write_int, read_int, []), + (write_string_or_none, read_string_or_none, ["str", "", None]), + (write_string_or_none, read_string_or_none, []), + (write_string, read_string, ["str", "quite 
a long string dont you think", ""]), + (write_string, read_string, []), + (write_int_neg, read_int_neg, [-6, -33333, 0]), + (write_int_neg, read_int_neg, []) ] ) def test_write_read_lists_positive(self, func_wr, func_rd, value): - serialization.write_list(value, self.filehandler, func_wr) + write_list(value, self.filehandler, func_wr) self.filehandler.flush() self.filehandler.seek(0) - actual = serialization.read_list(self.filehandler, func_rd) + actual = read_list(self.filehandler, func_rd) assert actual == value @pytest.mark.parametrize( @@ -80,24 +79,56 @@ def test_write_read_lists_positive(self, func_wr, func_rd, value): ] ) def test_write_read_bool_array_positive(self, value): - serialization.write_bool_array(value, self.filehandler) + write_bool_array(value, self.filehandler) self.filehandler.flush() self.filehandler.seek(0) - actual = serialization.read_bool_array(self.filehandler, len(value)) + actual = read_bool_array(self.filehandler, len(value)) assert actual == value @pytest.mark.parametrize( "func_wr, func_rd, list_of_pairs", [ - (serialization.write_int, serialization.read_int, []), - (serialization.write_int, serialization.read_int, [(1225, 78854)]), - (serialization.write_int, serialization.read_int, [(1225, 78854), (1, 0), (65536, 65536)]), - (serialization.write_string, serialization.read_string, [("1225", "78854"), ("", " "), ("65536", "65536")]) + (write_int, read_int, []), + (write_int, read_int, [(1225, 78854)]), + (write_int, read_int, [(1225, 78854), (1, 0), (65536, 65536)]), + (write_string, read_string, [("1225", "78854"), ("", " "), ("65536", "65536")]) ] ) def test_write_read_list_of_pairs_positive(self, func_wr, func_rd, list_of_pairs): - serialization.write_list_of_pairs(list_of_pairs, self.filehandler, func_wr) + write_list_of_pairs(list_of_pairs, self.filehandler, func_wr) self.filehandler.flush() self.filehandler.seek(0) - actual = serialization.read_list_of_pairs(self.filehandler, func_rd) + actual = read_list_of_pairs(self.filehandler, func_rd) assert actual == list_of_pairs + + def test_read_write_list_with_custom_serializer(self): + class TestClass: + def __init__(self, value): + self.value = value + + def serialize(self, outfile): + write_int(self.value, outfile) + + @staticmethod + def deserialize(infile): + obj = TestClass(0) + obj.value = read_int(infile) + return obj + + test_objects = [TestClass(1), TestClass(2), TestClass(3)] + write_list(test_objects, self.filehandler, TestClass.serialize) + self.filehandler.flush() + self.filehandler.seek(0) + read_objects = read_list(self.filehandler, TestClass.deserialize) + + assert len(test_objects) == len(read_objects) + for orig, read in zip(test_objects, read_objects): + assert orig.value == read.value + + def test_termination_int(self): + write_int(TERMINATION_INT, self.filehandler) + self.filehandler.flush() + self.filehandler.seek(0) + value = read_int(self.filehandler) + assert value == TERMINATION_INT + diff --git a/tests/test_umi_filtering.py b/tests/test_umi_filtering.py new file mode 100644 index 00000000..ee49bca4 --- /dev/null +++ b/tests/test_umi_filtering.py @@ -0,0 +1,161 @@ +############################################################################ +# Copyright (c) 2025 University of Helsinki +# All Rights Reserved +# See file LICENSE for details. 
+############################################################################ + +import pytest +from collections import defaultdict +from src.barcode_calling.umi_filtering import ( + UMIFilter, + format_read_assignment_for_output, + create_transcript_info_dict +) +from src.isoform_assignment import ReadAssignment, ReadAssignmentType +from src.gene_info import GeneInfo + + +class TestFormatReadAssignmentForOutput: + """Test read assignment output formatting.""" + + def test_format_basic(self): + """Test basic formatting with all attributes.""" + # Create a ReadAssignment + read_assignment = ReadAssignment( + read_id="read_001", + assignment_type=ReadAssignmentType.unique + ) + read_assignment.chr_id = "chr1" + read_assignment.start = 1000 + read_assignment.end = 2000 + read_assignment.corrected_exons = [(1000, 1500), (1600, 2000)] + read_assignment.barcode = "ACTG" + read_assignment.umi = "GGGG" + read_assignment.strand = '+' + read_assignment.set_additional_attribute('transcript_type', 'known') + read_assignment.set_additional_attribute('polya_site', 1950) + read_assignment.set_additional_attribute('cell_type', 'neuron') + + output = format_read_assignment_for_output(read_assignment) + + # Verify all fields are in output + assert "read_001" in output + assert "ACTG" in output + assert "GGGG" in output + assert "known" in output + assert "neuron" in output + # polyA site formatting depends on matching_events, so just check output is non-empty + assert len(output) > 0 + + def test_format_missing_attributes(self): + """Test formatting with missing optional attributes.""" + read_assignment = ReadAssignment( + read_id="read_002", + assignment_type=ReadAssignmentType.ambiguous + ) + read_assignment.chr_id = "chr1" + read_assignment.start = 1000 + read_assignment.end = 2000 + read_assignment.corrected_exons = [(1000, 2000)] + read_assignment.barcode = "ACTG" + read_assignment.umi = "CCCC" + read_assignment.strand = '-' + + output = format_read_assignment_for_output(read_assignment) + + # Should handle missing attributes gracefully + assert "read_002" in output + assert "unknown" in output or "None" in output + + +class TestUMIFilter: + """Test UMIFilter class.""" + + @pytest.fixture + def umi_filter(self): + """Create a UMIFilter instance.""" + return UMIFilter(umi_length=10, edit_distance=3) + + def test_init(self, umi_filter): + """Test UMIFilter initialization.""" + assert umi_filter.umi_length == 10 + assert umi_filter.max_edit_distance == 3 + assert isinstance(umi_filter.stats, dict) + assert isinstance(umi_filter.unique_gene_barcode, set) + assert isinstance(umi_filter.selected_reads, set) + + def test_construct_umi_dict(self, umi_filter): + """Test UMI dictionary construction.""" + # Create mock ReadAssignments + read1 = ReadAssignment(read_id="read_001", assignment_type=ReadAssignmentType.unique) + read1.umi = "AAAA" + + read2 = ReadAssignment(read_id="read_002", assignment_type=ReadAssignmentType.unique) + read2.umi = "AAAA" + + read3 = ReadAssignment(read_id="read_003", assignment_type=ReadAssignmentType.unique) + read3.umi = "TTTT" + + molecule_list = [read1, read2, read3] + umi_dict = umi_filter._construct_umi_dict(molecule_list) + + # Should group by UMI + assert len(umi_dict) == 2 + assert "AAAA" in umi_dict + assert "TTTT" in umi_dict + assert len(umi_dict["AAAA"]) == 2 + assert len(umi_dict["TTTT"]) == 1 + + def test_construct_umi_dict_untrusted(self, umi_filter): + """Test UMI dict construction with untrusted UMIs.""" + read1 = ReadAssignment(read_id="read_001", 
assignment_type=ReadAssignmentType.unique) + read1.umi = "" # Untrusted UMI + + molecule_list = [read1] + umi_dict = umi_filter._construct_umi_dict(molecule_list) + + # Untrusted UMIs should be grouped separately + assert len(umi_dict) == 1 + assert "" in umi_dict or "read_001" in umi_dict + + def test_select_best_read(self, umi_filter): + """Test selecting best read from duplicates.""" + # Create reads with different qualities + read1 = ReadAssignment(read_id="read_001", assignment_type=ReadAssignmentType.unique) + read1.corrected_exons = [(1000, 1500), (1600, 2000)] + + read2 = ReadAssignment(read_id="read_002", assignment_type=ReadAssignmentType.ambiguous) + read2.corrected_exons = [(1000, 1500), (1600, 2000)] + + read3 = ReadAssignment(read_id="read_003", assignment_type=ReadAssignmentType.unique) + read3.corrected_exons = [(1000, 2000)] # Single exon, less informative + + duplicates = [read2, read1, read3] + best = umi_filter._process_duplicates(duplicates) + + # Should select unique assignment with more exons + assert len(best) >= 1 + # The best read should be unique type + assert any(r.assignment_type == ReadAssignmentType.unique for r in best) + + +class TestCreateTranscriptInfoDict: + """Test transcript info dictionary creation.""" + + @pytest.mark.skip(reason="Requires actual genedb file, complex to mock") + def test_create_empty_dict(self): + """Test creating dict with no genes.""" + # This function requires a real genedb file path, not a dict + # Skipping as it requires complex setup + pass + + @pytest.mark.skip(reason="Requires actual genedb file, complex to mock") + def test_create_dict_with_genes(self): + """Test creating dict with mock genes.""" + # This function requires a real genedb file path, not a dict + # Skipping as it requires complex setup + pass + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tox.ini b/tox.ini index 5414f3af..a9985505 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ skipsdist=True [testenv] deps = -r{toxinidir}/requirements_tests.txt -commands = pytest --cov --cov-branch +commands = pytest --cov --cov-branch --ignore=tests/console_test.py [coverage:run] -omit = venv/*, .tox/*, tests/* +omit = venv/*, .tox/*, tests/*, setup.py