diff --git a/.gitignore b/.gitignore index 21c6468..89ba335 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,8 @@ share/python-wheels/ *.egg target/ *.out + +# Local directories +stubs/ +test_files/ +test_scripts/ diff --git a/docs/source/notebooks/human_cerebellum.ipynb b/docs/source/notebooks/human_cerebellum.ipynb index 8be1529..776009a 100644 --- a/docs/source/notebooks/human_cerebellum.ipynb +++ b/docs/source/notebooks/human_cerebellum.ipynb @@ -803,7 +803,7 @@ "id": "405b984c-4131-4d17-a253-5d056caa922e", "metadata": {}, "source": [ - "Next, let's calculate the QC metrics using the `pycistopic qc` command." + "Next, let's calculate the QC metrics using the `pycistopic qc run` command." ] }, { @@ -813,7 +813,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pycistopic qc \\\n", + "!pycistopic qc run \\\n", " --fragments data/fragments.tsv.gz \\\n", " --regions outs/consensus_peak_calling/consensus_regions.bed \\\n", " --tss outs/qc/tss.bed \\\n", @@ -841,11 +841,11 @@ "\n", "pycistopic_qc_commands_filename = \"pycistopic_qc_commands.txt\"\n", "\n", - "# Create text file with all pycistopic qc command lines.\n", + "# Create text file with all pycistopic qc run command lines.\n", "with open(pycistopic_qc_commands_filename, \"w\") as fh:\n", " for sample, fragment_filename in fragments_dict.items():\n", " print(\n", - " \"pycistopic qc\",\n", + " \"pycistopic qc run\",\n", " f\"--fragments {fragment_filename}\",\n", " f\"--regions {regions_bed_filename}\",\n", " f\"--tss {tss_bed_filename}\",\n", @@ -935,7 +935,7 @@ "\n", "**Note:**\n", "\n", - "The `pycistopic qc` command will determine automatic thresholds for the minimum number of unique number of fragments and the minumum TSS enrichment.\n", + "The `pycistopic qc run` command will determine automatic thresholds for the minimum number of unique number of fragments and the minumum TSS enrichment.\n", "In case you want to change these thresholds or want to threhold based on FRIP, you can provide manually defined thresholds using the parameters:\n", "- unique_fragments_threshold\n", "- tss_enrichment_threshold\n", diff --git a/pyproject.toml b/pyproject.toml index 126de6c..e7c3db7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,13 +31,14 @@ classifiers = [ "Topic :: Scientific/Engineering :: Bio-Informatics", ] dependencies = [ - "numpy >= 1.20.3", + "numpy >= 1.20.3, < 2", "pandas == 1.5", - "polars >= 0.18.3", + "polars >= 1", "pyarrow >= 8.0.0", "pyranges < 0.0.128", + "numba", "ray", - "scatac_fragment_tools", + "scatac_fragment_tools >= 0.1.2", "scikit-learn", "lda", "matplotlib < 3.7", @@ -144,3 +145,7 @@ max-doc-length = 88 [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" + +[tool.mypy] +mypy_path = "$MYPY_CONFIG_FILE_DIR/stubs" +plugins = ["numpy.typing.mypy_plugin"] diff --git a/src/pycisTopic/cistopic_class.py b/src/pycisTopic/cistopic_class.py index 15da6ee..14a6145 100644 --- a/src/pycisTopic/cistopic_class.py +++ b/src/pycisTopic/cistopic_class.py @@ -14,10 +14,10 @@ get_position_index, non_zero_rows, prepare_tag_cells, - read_fragments_from_file, region_names_to_coordinates, subset_list, ) +from pycisTopic.fragments import read_fragments_to_pyranges from scipy import sparse if TYPE_CHECKING: @@ -813,7 +813,10 @@ def create_cistopic_object_from_fragments( if path_to_fragments is not None: log.info("Using fragments of provided pandas data frame") else: - fragments = read_fragments_from_file(path_to_fragments, use_polars=use_polars) + fragments = read_fragments_to_pyranges( + 
fragments_bed_filename=path_to_fragments, + engine = "polars" + ) if "Score" not in fragments.df: fragments_df = fragments.df diff --git a/src/pycisTopic/cli/subcommand/qc.py b/src/pycisTopic/cli/subcommand/qc.py index 3a97dee..1e749cb 100644 --- a/src/pycisTopic/cli/subcommand/qc.py +++ b/src/pycisTopic/cli/subcommand/qc.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import os +from typing import TYPE_CHECKING, Literal import polars as pl @@ -27,6 +28,7 @@ def qc( min_fragments_per_cb: int = 10, collapse_duplicates: bool = True, no_threads: int = 8, + engine: str | Literal["polars"] | Literal["pyarrow"] = "pyarrow", ) -> None: """ Compute quality check statistics from fragments file. @@ -85,6 +87,8 @@ def qc( probability density function (PDF) values for log10 unique fragments in peaks vs TSS enrichment, fractions of fragments in peaks and duplication ratio. Default: ``8`` + engine + Use Polars or pyarrow to read BED and fragment files (default: `pyarrow`). Returns ------- @@ -127,12 +131,13 @@ def format(self, record): regions_df_pl = read_bed_to_polars_df( bed_filename=regions_bed_filename, min_column_count=3, + engine=engine, ) logger.info(f'Loading fragments TSV file from "{fragments_tsv_filename}".') fragments_df_pl = read_fragments_to_polars_df( fragments_tsv_filename, - engine="pyarrow", + engine=engine, ) logger.info("Computing QC stats.") @@ -209,14 +214,14 @@ def format(self, record): fragments_stats_per_cb_for_otsu_threshold_df_pl.write_csv( f"{output_prefix}.fragments_stats_per_cb_for_otsu_thresholds.tsv", separator="\t", - has_header=True, + include_header=True, ) logger.info(f'Writing "{output_prefix}.cbs_for_otsu_thresholds.tsv".') fragments_stats_per_cb_for_otsu_threshold_df_pl.select(pl.col("CB")).write_csv( f"{output_prefix}.cbs_for_otsu_thresholds.tsv", separator="\t", - has_header=False, + include_header=False, ) logger.info(f'Writing "{output_prefix}.otsu_thresholds.tsv".') @@ -229,7 +234,45 @@ def format(self, record): logger.info("pycisTopic QC finished.") -def run_qc(args): +def qc_filter_barcodes( + sample_id: str, + pycistopic_qc_output_dir: str | Path, + unique_fragments_threshold: int | None = None, + tss_enrichment_threshold: float | None = None, + frip_threshold: float | None = None, +): + from pycisTopic.qc import get_barcodes_passing_qc_for_sample + + selected_cbs_filename = os.path.join( + pycistopic_qc_output_dir, + f"{sample_id}.min_fragments_{unique_fragments_threshold}_min_tss_{tss_enrichment_threshold}_min_frip_{frip_threshold}.cbs.txt", + ) + + print( + f'Writing selected cell barcodes for "{sample_id}" based on QC statistics with:\n' + f" - minimum unqiue fragments:\t{unique_fragments_threshold}\n" + f" - minimum TSS threshold:\t{tss_enrichment_threshold}\n" + f" - minimum FRiP threshold:\t{frip_threshold}\n" + f'to "{selected_cbs_filename}".' 
+ ) + + barcodes_passing_filters, _ = get_barcodes_passing_qc_for_sample( + sample_id=sample_id, + pycistopic_qc_output_dir=pycistopic_qc_output_dir, + unique_fragments_threshold=unique_fragments_threshold, + tss_enrichment_threshold=tss_enrichment_threshold, + frip_threshold=frip_threshold, + use_automatic_thresholds=False, + ) + + pl.Series("CB", barcodes_passing_filters).to_frame().select(pl.col("CB")).write_csv( + selected_cbs_filename, + separator="\t", + include_header=False, + ) + + +def run_qc_run(args): qc( fragments_tsv_filename=args.fragments_tsv_filename, regions_bed_filename=args.regions_bed_filename, @@ -244,6 +287,17 @@ def run_qc(args): min_fragments_per_cb=args.min_fragments_per_cb, collapse_duplicates=args.collapse_duplicates, no_threads=args.threads, + engine=args.engine, + ) + + +def run_qc_filter_barcodes(args): + qc_filter_barcodes( + sample_id=args.sample_id, + pycistopic_qc_output_dir=args.pycistopic_qc_output_dir, + unique_fragments_threshold=args.unique_fragments_threshold, + tss_enrichment_threshold=args.tss_enrichment_threshold, + frip_threshold=args.frip_threshold, ) @@ -252,9 +306,20 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): "qc", help="Run QC statistics on fragment file.", ) - parser_qc.set_defaults(func=run_qc) + subparser_qc = parser_qc.add_subparsers( + title="QC", + dest="qc", + help="List of QC subcommands.", + ) + subparser_qc.required = True + + parser_qc_run = subparser_qc.add_parser( + "run", + help="Run QC statistics on fragment file.", + ) + parser_qc_run.set_defaults(func=run_qc_run) - parser_qc.add_argument( + parser_qc_run.add_argument( "-f", "--fragments", dest="fragments_tsv_filename", @@ -264,7 +329,7 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): help="Fragments TSV filename which contains scATAC fragments.", ) - parser_qc.add_argument( + parser_qc_run.add_argument( "-r", "--regions", dest="regions_bed_filename", @@ -277,7 +342,7 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): """, ) - parser_qc.add_argument( + parser_qc_run.add_argument( "-t", "--tss", dest="tss_annotation_bed_filename", @@ -290,7 +355,7 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): """, ) - parser_qc.add_argument( + parser_qc_run.add_argument( "-o", "--output", dest="output_prefix", @@ -300,7 +365,7 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): help="Output prefix to use for QC statistics parquet output files.", ) - parser_qc.add_argument( + parser_qc_run.add_argument( "--threads", dest="threads", action="store", @@ -314,10 +379,22 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): "Default: 8.", ) - group_qc_tss = parser_qc.add_argument_group( + parser_qc_run.add_argument( + "-e", + "--engine", + dest="engine", + action="store", + type=str, + choices=["polars", "pyarrow"], + required=False, + default="pyarrow", + help="Use Polars or pyarrow to read BED and fragment files. Default: pyarrow.", + ) + + group_qc_run_tss = parser_qc_run.add_argument_group( "TSS profile", "TSS profile statistics calculation settings." 
) - group_qc_tss.add_argument( + group_qc_run_tss.add_argument( "--tss_flank_window", dest="tss_flank_window", action="store", @@ -329,7 +406,7 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): "Default: 2000 (+/- 2000 bp).", ) - group_qc_tss.add_argument( + group_qc_run_tss.add_argument( "--tss_smoothing_rolling_window", dest="tss_smoothing_rolling_window", action="store", @@ -339,7 +416,7 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): help="Rolling window used to smooth the cut sites signal. Default: 10.", ) - group_qc_tss.add_argument( + group_qc_run_tss.add_argument( "--tss_minimum_signal_window", dest="tss_minimum_signal_window", action="store", @@ -356,7 +433,7 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): """, ) - group_qc_tss.add_argument( + group_qc_run_tss.add_argument( "--tss_window", dest="tss_window", action="store", @@ -370,7 +447,7 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): """, ) - group_qc_tss.add_argument( + group_qc_run_tss.add_argument( "--tss_min_norm", dest="tss_min_norm", action="store", @@ -385,7 +462,7 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): """, ) - group_qc_tss.add_argument( + group_qc_run_tss.add_argument( "--use-pyranges", dest="use_genomic_ranges", action="store_false", @@ -396,7 +473,7 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): """, ) - parser_qc.add_argument( + parser_qc_run.add_argument( "--min_fragments_per_cb", dest="min_fragments_per_cb", action="store", @@ -410,7 +487,7 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): """, ) - parser_qc.add_argument( + parser_qc_run.add_argument( "--dont-collapse_duplicates", dest="collapse_duplicates", action="store_false", @@ -421,3 +498,62 @@ def add_parser_qc(subparsers: _SubParsersAction[ArgumentParser]): Default: collapse duplicates. """, ) + + parser_qc_filter_barcodes = subparser_qc.add_parser( + "filter", + help="Filter cell barcodes based on QC statistics.", + ) + parser_qc_filter_barcodes.set_defaults(func=run_qc_filter_barcodes) + + parser_qc_filter_barcodes.add_argument( + "-s", + "--sample", + dest="sample_id", + action="store", + type=str, + required=True, + help="Sample ID for which to get list of cell barcodes based on QC statistics.", + ) + + parser_qc_filter_barcodes.add_argument( + "-o", + "--output", + dest="pycistopic_qc_output_dir", + action="store", + type=str, + required=True, + help='Output directory from "pycistopic run qc" which contains QC statistics parquet output files.', + ) + + parser_qc_filter_barcodes.add_argument( + "-f", + "--fragments", + dest="unique_fragments_threshold", + action="store", + type=int, + required=False, + default=1000, + help="Threshold for number of unique fragments in peaks. Default: 1000.", + ) + + parser_qc_filter_barcodes.add_argument( + "-t", + "--tss", + dest="tss_enrichment_threshold", + action="store", + type=float, + required=False, + default=5.0, + help="Threshold for TSS enrichment score. Default: 5.0.", + ) + + parser_qc_filter_barcodes.add_argument( + "-p", + "--frip", + dest="frip_threshold", + action="store", + type=float, + required=False, + default=0.0, + help="Threshold for fraction of reads in peaks (FRiP). 
Default: 0.0.", + ) diff --git a/src/pycisTopic/cli/subcommand/topic_modeling.py b/src/pycisTopic/cli/subcommand/topic_modeling.py index f301c4a..1577231 100644 --- a/src/pycisTopic/cli/subcommand/topic_modeling.py +++ b/src/pycisTopic/cli/subcommand/topic_modeling.py @@ -1,7 +1,9 @@ from __future__ import annotations +import logging import os import pickle +import sys import tempfile from argparse import ArgumentTypeError from typing import TYPE_CHECKING @@ -10,18 +12,18 @@ from argparse import ArgumentParser, _SubParsersAction -def run_topic_modeling_lda(args): +def run_topic_modeling_with_lda(args): from pycisTopic.lda_models import run_cgs_models input_filename = args.input output_filename = args.output - topics = args.topics + n_topics = args.topics alpha = args.alpha alpha_by_topic = args.alpha_by_topic eta = args.eta eta_by_topic = args.eta_by_topic - iterations = args.iterations - parallel = args.parallel + n_iter = args.iterations + n_cpu = args.parallel save_path = ( (output_filename[:-4] if output_filename.endswith(".pkl") else output_filename) if args.keep_intermediate_topic_models @@ -30,16 +32,22 @@ def run_topic_modeling_lda(args): random_state = args.seed temp_dir = args.temp_dir + if args.verbose: + level = logging.INFO + log_format = "%(asctime)s %(name)-12s %(levelname)-8s %(message)s" + handlers = [logging.StreamHandler(stream=sys.stdout)] + logging.basicConfig(level=level, format=log_format, handlers=handlers) + print("Run topic modeling with lda with the following settings:") print(f" - Input cisTopic object filename: {input_filename}") print(f" - Topic modeling output filename: {output_filename}") - print(f" - Number of topics to run topic modeling for: {topics}") + print(f" - Number of topics to run topic modeling for: {n_topics}") print(f" - Alpha: {alpha}") print(f" - Divide alpha by the number of topics: {alpha_by_topic}") print(f" - Eta: {eta}") print(f" - Divide eta by the number of topics: {eta_by_topic}") - print(f" - Number of iterations: {iterations}") - print(f" - Number of topic models to run in parallel: {parallel}") + print(f" - Number of iterations: {n_iter}") + print(f" - Number of topic models to run in parallel: {n_cpu}") print(f" - Seed: {random_state}") print(f" - Save intermediate topic models in dir: {save_path}") print(f" - TMP dir: {temp_dir}") @@ -53,9 +61,9 @@ def run_topic_modeling_lda(args): print("--------------") models = run_cgs_models( cistopic_obj, - n_topics=topics, - n_cpu=parallel, - n_iter=iterations, + n_topics=n_topics, + n_cpu=n_cpu, + n_iter=n_iter, random_state=random_state, alpha=alpha, alpha_by_topic=alpha_by_topic, @@ -70,76 +78,242 @@ def run_topic_modeling_lda(args): pickle.dump(models, fh) -def run_topic_modeling_mallet(args): - from pycisTopic.lda_models import run_cgs_models_mallet +def run_topic_modeling_with_mallet(args): + from pycisTopic.lda_models import LDAMallet - input_filename = args.input - output_filename = args.output - topics = args.topics + mallet_corpus_filename = args.mallet_corpus_filename + output_prefix = args.output_prefix + n_topics_list = [args.topics] if isinstance(args.topics, int) else args.topics alpha = args.alpha alpha_by_topic = args.alpha_by_topic eta = args.eta eta_by_topic = args.eta_by_topic - iterations = args.iterations - parallel = args.parallel - save_path = ( - (output_filename[:-4] if output_filename.endswith(".pkl") else output_filename) - if args.keep_intermediate_topic_models - else None - ) + n_iter = args.iterations + n_cpu = args.parallel random_state = args.seed 
memory_in_gb = f"{args.memory_in_gb}G" - temp_dir = args.temp_dir - reuse_corpus = args.reuse_corpus mallet_path = args.mallet_path + if args.verbose: + level = logging.INFO + log_format = "%(asctime)s %(name)-12s %(levelname)-8s %(message)s" + handlers = [logging.StreamHandler(stream=sys.stdout)] + logging.basicConfig(level=level, format=log_format, handlers=handlers) + print("Run topic modeling with Mallet with the following settings:") - print(f" - Input cisTopic object filename: {input_filename}") - print(f" - Topic modeling output filename: {output_filename}") - print(f" - Number of topics to run topic modeling for: {topics}") + print(f" - Mallet corpus filename: {mallet_corpus_filename}") + print(f" - Output prefix: {output_prefix}") + print(f" - Number of topics to run topic modeling for: {n_topics_list}") print(f" - Alpha: {alpha}") print(f" - Divide alpha by the number of topics: {alpha_by_topic}") print(f" - Eta: {eta}") print(f" - Divide eta by the number of topics: {eta_by_topic}") - print(f" - Number of iterations: {iterations}") - print(f" - Number threads Mallet is allowed to use: {parallel}") + print(f" - Number of iterations: {n_iter}") + print(f" - Number threads Mallet is allowed to use: {n_cpu}") print(f" - Seed: {random_state}") - print(f" - Save intermediate topic models in dir: {save_path}") - print(f" - TMP dir: {temp_dir}") - print(f" - Reuse Mallet corpus: {reuse_corpus}") print(f" - Amount of memory Mallet is allowed to use: {memory_in_gb}") print(f" - Mallet binary: {mallet_path}") - print(f'\nLoading cisTopic object from "{input_filename}"...\n') - with open(input_filename, "rb") as fh: - cistopic_obj = pickle.load(fh) + os.environ["MALLET_MEMORY"] = memory_in_gb - # Run models - print("Running models") - print("--------------") + for n_topics in n_topics_list: + # Run models + print(f"\nRunning Mallet topic modeling for {n_topics} topics.") + print(f"----------------------------------{'-' * len(str(n_topics))}--------") + + LDAMallet.run_mallet_topic_modeling( + mallet_corpus_filename=mallet_corpus_filename, + output_prefix=output_prefix, + n_topics=n_topics, + alpha=alpha, + alpha_by_topic=alpha_by_topic, + eta=eta, + eta_by_topic=eta_by_topic, + n_cpu=n_cpu, + optimize_interval=0, + iterations=n_iter, + topic_threshold=0.0, + random_seed=random_state, + mallet_path=mallet_path, + ) + + print( + f'\nWriting Mallet topic modeling output files to "{output_prefix}.{n_topics}_topics.*"...' + ) + + +def run_convert_binary_matrix_to_mallet_corpus_file(args): + import scipy + from pycisTopic.lda_models import LDAMallet + + binary_accessibility_matrix_filename = args.binary_accessibility_matrix_filename + mallet_corpus_filename = args.mallet_corpus_filename + mallet_path = args.mallet_path + memory_in_gb = f"{args.memory_in_gb}G" + + if args.verbose: + level = logging.INFO + log_format = "%(asctime)s %(name)-12s %(levelname)-8s %(message)s" + handlers = [logging.StreamHandler(stream=sys.stdout)] + logging.basicConfig(level=level, format=log_format, handlers=handlers) + + print( + f'Read binary accessibility matrix from "{binary_accessibility_matrix_filename}" Matrix Market file.' 
+ ) + binary_accessibility_matrix = scipy.io.mmread(binary_accessibility_matrix_filename) os.environ["MALLET_MEMORY"] = memory_in_gb - models = run_cgs_models_mallet( - cistopic_obj, - n_topics=topics, - n_cpu=parallel, - n_iter=iterations, - random_state=random_state, - alpha=alpha, - alpha_by_topic=alpha_by_topic, - eta=eta, - eta_by_topic=eta_by_topic, - save_path=save_path, - top_topics_coh=5, - tmp_path=temp_dir, - reuse_corpus=reuse_corpus, + print( + f'Convert binary accessibility matrix to Mallet serialized corpus file "{mallet_corpus_filename}".' + ) + LDAMallet.convert_binary_matrix_to_mallet_corpus_file( + binary_accessibility_matrix=binary_accessibility_matrix, + mallet_corpus_filename=mallet_corpus_filename, mallet_path=mallet_path, ) - print(f'\nWriting topic modeling output to "{output_filename}"...') - with open(output_filename, "wb") as fh: - pickle.dump(models, fh) + +def run_mallet_calculate_model_evaluation_stats(args): + import scipy + from pycisTopic.fragments import read_barcodes_file_to_polars_series + from pycisTopic.lda_models import LDAMallet, calculate_model_evaluation_stats + + binary_accessibility_matrix_filename = args.binary_accessibility_matrix_filename + cell_barcodes_filename = args.cell_barcodes_filename + region_ids_filename = args.region_ids_filename + output_prefix = args.output_prefix + n_topics_list = [args.topics] if isinstance(args.topics, int) else args.topics + + print( + f'Read binary accessibility matrix from "{binary_accessibility_matrix_filename}" Matrix Market file.' + ) + binary_accessibility_matrix = scipy.io.mmread(binary_accessibility_matrix_filename) + + print(f'Read cell barcodes filename "{cell_barcodes_filename}".') + cell_barcodes = read_barcodes_file_to_polars_series( + barcodes_tsv_filename=cell_barcodes_filename, + sample_id=None, + cb_end_to_remove=None, + cb_sample_separator=None, + ).to_list() + + print(f'Read region IDs filename "{region_ids_filename}".') + region_ids = read_barcodes_file_to_polars_series( + barcodes_tsv_filename=region_ids_filename, + sample_id=None, + cb_end_to_remove=None, + cb_sample_separator=None, + ).to_list() + + for n_topics in n_topics_list: + print( + f'Calculate model evaluation statistics for {n_topics} topics from "{output_prefix}.{n_topics}_topics.*"...' 
+ ) + calculate_model_evaluation_stats( + binary_accessibility_matrix=binary_accessibility_matrix, + cell_barcodes=cell_barcodes, + region_ids=region_ids, + output_prefix=output_prefix, + n_topics=n_topics, + top_topics_coh=5, + ) + +def binarize_cell_or_region_topic(args): + """ + target, method, ntop, smooth_topics, nbins, cb, regions, output, topic + """ + + target = args.target + method = args.method + ntop = args.ntop + smooth_topics = args.smooth_topics + nbins = args.nbins + cell_barcodes_filename = args.cell_barcodes_filename + region_ids_filename = args.region_ids_filename + output_prefix = args.output_prefix + n_topics = args.n_topics + out_dir = args.out_dir + + # input validation + if target == "cell" and cell_barcodes_filename is None: + raise ValueError("`cell_barcodes_filename` using `--cb` should be provided when target is `cell`") + if target == "region" and region_ids_filename is None: + raise ValueError("`region_ids_filename` using `--regions` should be provided when target is `region`") + + import os + if not os.path.exists(out_dir): + print(f'Making directory: {out_dir}') + os.makedirs(out_dir) + + from pycisTopic.fragments import read_barcodes_file_to_polars_series + from pycisTopic.lda_models import LDAMallet, LDAMalletFilenames + from pycisTopic.topic_binarization import binarize_topics + + lda_mallet_filenames = LDAMalletFilenames( + output_prefix=output_prefix, n_topics=n_topics + ) + + if target == "cell": + print(f'Read cell barcodes filename "{cell_barcodes_filename}".') + cell_or_region_names = read_barcodes_file_to_polars_series( + barcodes_tsv_filename=cell_barcodes_filename, + sample_id=None, + cb_end_to_remove=None, + cb_sample_separator=None, + ).to_list() + print(f'Read cell-topic probabilities filename "{lda_mallet_filenames.cell_topic_probabilities_parquet_filename}".') + cell_or_region_topic_prob = LDAMallet.read_cell_topic_probabilities_parquet_file( + mallet_cell_topic_probabilities_parquet_filename=lda_mallet_filenames.cell_topic_probabilities_parquet_filename + ) + + if target == "region": + print(f'Read region IDs filename "{region_ids_filename}".') + cell_or_region_names = read_barcodes_file_to_polars_series( + barcodes_tsv_filename=region_ids_filename, + sample_id=None, + cb_end_to_remove=None, + cb_sample_separator=None, + ).to_list() + print(f'Read region-topic probabilities filename "{lda_mallet_filenames.region_topic_counts_parquet_filename}".') + cell_or_region_topic_prob = LDAMallet.read_region_topic_counts_parquet_file_to_region_topic_probabilities( + mallet_region_topic_counts_parquet_filename=lda_mallet_filenames.region_topic_counts_parquet_filename + ).T + + print("Binarizing topics ...") + cell_or_region_names_per_topic, scores_per_topic, thresholds = binarize_topics( + cell_or_region_topic_prob=cell_or_region_topic_prob, + cell_or_region_names=cell_or_region_names, + method=method, + smooth_topics=smooth_topics, + ntop=ntop, + nbins=nbins + ) + + print(f'Saving results to "{out_dir}".') + + with open(os.path.join(out_dir, f"{target}_thresholds.tsv"), "wt") as f: + for topic, thr in enumerate(thresholds): + _ = f.write( + f"{topic + 1}\t{thr}\n" + ) + + if target == "cell": + for topic, (cells, scores) in enumerate(zip(cell_or_region_names_per_topic, scores_per_topic)): + with open(os.path.join(out_dir, f"{target}_Topic_{topic + 1}_binarized.txt"), "wt") as f: + for cell, score in zip(cells, scores): + _ = f.write( + f"{cell}\t{score}\n" + ) + + elif target == "region": + for topic, (regions, scores) in 
enumerate(zip(cell_or_region_names_per_topic, scores_per_topic)): + with open(os.path.join(out_dir, f"{target}_Topic_{topic + 1}_binarized.bed"), "wt") as f: + for region, score in zip(regions, scores): + chrom, start, end = region.replace(":", "-").split("-") + _ = f.write( + f"{chrom}\t{start}\t{end}\tTopic_{topic + 1}\t{score}\n" + ) def str_to_bool(v: str) -> bool: @@ -184,9 +358,9 @@ def add_parser_topic_modeling(subparsers: _SubParsersAction[ArgumentParser]): parser_topic_modeling_lda = subparser_topic_modeling.add_parser( "lda", - help='"Run LDA topic modeling with "lda" package.', + help='Run LDA topic modeling with "lda" package.', ) - parser_topic_modeling_lda.set_defaults(func=run_topic_modeling_lda) + parser_topic_modeling_lda.set_defaults(func=run_topic_modeling_with_lda) parser_topic_modeling_lda.add_argument( "-i", @@ -229,8 +403,8 @@ def add_parser_topic_modeling(subparsers: _SubParsersAction[ArgumentParser]): dest="iterations", type=int, required=False, - default=500, - help="Number of iterations. Default: 500.", + default=150, + help="Number of iterations. Default: 150.", ) parser_topic_modeling_lda.add_argument( "-a", @@ -300,32 +474,104 @@ def add_parser_topic_modeling(subparsers: _SubParsersAction[ArgumentParser]): default=None, help=f'TMP directory to use instead of the default ("{tempfile.gettempdir()}").', ) + parser_topic_modeling_lda.add_argument( + "-v", + "--verbose", + dest="verbose", + action="store_true", + required=False, + help="Enable verbose mode.", + ) parser_topic_modeling_mallet = subparser_topic_modeling.add_parser( - "mallet", - help='"Run LDA topic modeling with "Mallet".', + "mallet", help='Run LDA topic modeling with "Mallet".' + ) + + subparser_topic_modeling_mallet = parser_topic_modeling_mallet.add_subparsers( + title='Topic modeling with "Mallet"', + dest="mallet", + help='List of "Mallet" topic modeling subcommands.', + ) + subparser_topic_modeling_mallet.required = True + + parser_topic_modeling_mallet_create_corpus = subparser_topic_modeling_mallet.add_parser( + "create_corpus", + help="Convert binary accessibility matrix to Mallet serialized corpus file.", + ) + parser_topic_modeling_mallet_create_corpus.set_defaults( + func=run_convert_binary_matrix_to_mallet_corpus_file ) - parser_topic_modeling_mallet.set_defaults(func=run_topic_modeling_mallet) - parser_topic_modeling_mallet.add_argument( + parser_topic_modeling_mallet_create_corpus.add_argument( "-i", "--input", - dest="input", + dest="binary_accessibility_matrix_filename", action="store", type=str, required=True, - help="cisTopic object pickle input filename.", + help="Binary accessibility matrix (region IDs vs cell barcodes) in Matrix Market format.", ) - parser_topic_modeling_mallet.add_argument( + parser_topic_modeling_mallet_create_corpus.add_argument( "-o", "--output", - dest="output", + dest="mallet_corpus_filename", action="store", type=str, required=True, - help="Topic model list pickle output filename.", + help="Mallet serialized corpus filename.", + ) + parser_topic_modeling_mallet_create_corpus.add_argument( + "-m", + "--memory", + dest="memory_in_gb", + type=int, + required=False, + default=10, + help='Amount of memory (in GB) Mallet is allowed to use. Default: "10"', ) - parser_topic_modeling_mallet.add_argument( + parser_topic_modeling_mallet_create_corpus.add_argument( + "-b", + "--mallet_path", + dest="mallet_path", + type=str, + required=False, + default="mallet", + help='Path to Mallet binary (e.g. "/xxx/Mallet/bin/mallet"). 
Default: "mallet".', + ) + parser_topic_modeling_mallet_create_corpus.add_argument( + "-v", + "--verbose", + dest="verbose", + action="store_true", + required=False, + help="Enable verbose mode.", + ) + + parser_topic_modeling_mallet_run = subparser_topic_modeling_mallet.add_parser( + "run", + help='Run LDA topic modeling with "Mallet".', + ) + parser_topic_modeling_mallet_run.set_defaults(func=run_topic_modeling_with_mallet) + + parser_topic_modeling_mallet_run.add_argument( + "-i", + "--input", + dest="mallet_corpus_filename", + action="store", + type=str, + required=True, + help="Mallet corpus filename.", + ) + parser_topic_modeling_mallet_run.add_argument( + "-o", + "--output", + dest="output_prefix", + action="store", + type=str, + required=True, + help="Topic model output prefix.", + ) + parser_topic_modeling_mallet_run.add_argument( "-t", "--topics", dest="topics", @@ -334,7 +580,7 @@ def add_parser_topic_modeling(subparsers: _SubParsersAction[ArgumentParser]): nargs="+", help="Number(s) of topics to create during topic modeling.", ) - parser_topic_modeling_mallet.add_argument( + parser_topic_modeling_mallet_run.add_argument( "-p", "--parallel", dest="parallel", @@ -342,16 +588,16 @@ def add_parser_topic_modeling(subparsers: _SubParsersAction[ArgumentParser]): required=True, help="Number of threads Mallet is allowed to use.", ) - parser_topic_modeling_mallet.add_argument( + parser_topic_modeling_mallet_run.add_argument( "-n", "--iterations", dest="iterations", type=int, required=False, default=150, - help="Number of iterations. Default: 500.", + help="Number of iterations. Default: 150.", ) - parser_topic_modeling_mallet.add_argument( + parser_topic_modeling_mallet_run.add_argument( "-a", "--alpha", dest="alpha", @@ -360,7 +606,7 @@ def add_parser_topic_modeling(subparsers: _SubParsersAction[ArgumentParser]): default=50, help="Alpha value. Default: 50.", ) - parser_topic_modeling_mallet.add_argument( + parser_topic_modeling_mallet_run.add_argument( "-A", "--alpha_by_topic", dest="alpha_by_topic", @@ -370,7 +616,7 @@ def add_parser_topic_modeling(subparsers: _SubParsersAction[ArgumentParser]): default=True, help="Whether the alpha value should by divided by the number of topics. Default: True.", ) - parser_topic_modeling_mallet.add_argument( + parser_topic_modeling_mallet_run.add_argument( "-e", "--eta", dest="eta", @@ -379,7 +625,7 @@ def add_parser_topic_modeling(subparsers: _SubParsersAction[ArgumentParser]): default=0.1, help="Eta value. Default: 0.1.", ) - parser_topic_modeling_mallet.add_argument( + parser_topic_modeling_mallet_run.add_argument( "-E", "--eta_by_topic", dest="eta_by_topic", @@ -389,19 +635,7 @@ def add_parser_topic_modeling(subparsers: _SubParsersAction[ArgumentParser]): default=False, help="Whether the eta value should by divided by the number of topics. Default: False.", ) - parser_topic_modeling_mallet.add_argument( - "-k", - "--keep", - dest="keep_intermediate_topic_models", - type=str_to_bool, - choices=(True, False), - required=False, - default=False, - help="Whether intermediate topic models should be kept. " - "Useful to enable if running with a lot of topic numbers, to not lose finished topic model runs. " - "Default: False.", - ) - parser_topic_modeling_mallet.add_argument( + parser_topic_modeling_mallet_run.add_argument( "-s", "--seed", dest="seed", @@ -410,16 +644,7 @@ def add_parser_topic_modeling(subparsers: _SubParsersAction[ArgumentParser]): default=555, help="Seed for ensuring reproducibility. 
Default: 555.", ) - parser_topic_modeling_mallet.add_argument( - "-T", - "--temp_dir", - dest="temp_dir", - type=str, - required=False, - default=None, - help=f'TMP directory to use instead of the default ("{tempfile.gettempdir()}").', - ) - parser_topic_modeling_mallet.add_argument( + parser_topic_modeling_mallet_run.add_argument( "-m", "--memory", dest="memory_in_gb", @@ -428,16 +653,7 @@ def add_parser_topic_modeling(subparsers: _SubParsersAction[ArgumentParser]): default=100, help='Amount of memory (in GB) Mallet is allowed to use. Default: "100"', ) - parser_topic_modeling_mallet.add_argument( - "-r", - "--reuse_corpus", - dest="reuse_corpus", - type=str_to_bool, - required=False, - default=False, - help="Whether to reuse the corpus from Mallet. Default: False.", - ) - parser_topic_modeling_mallet.add_argument( + parser_topic_modeling_mallet_run.add_argument( "-b", "--mallet_path", dest="mallet_path", @@ -446,3 +662,174 @@ def add_parser_topic_modeling(subparsers: _SubParsersAction[ArgumentParser]): default="mallet", help='Path to Mallet binary (e.g. "/xxx/Mallet/bin/mallet"). Default: "mallet".', ) + parser_topic_modeling_mallet_run.add_argument( + "-v", + "--verbose", + dest="verbose", + action="store_true", + required=False, + help="Enable verbose mode.", + ) + + parser_topic_modeling_mallet_calculate_stats = ( + subparser_topic_modeling_mallet.add_parser( + "stats", + help="Calculate model evaluation statistics.", + ) + ) + parser_topic_modeling_mallet_calculate_stats.set_defaults( + func=run_mallet_calculate_model_evaluation_stats + ) + + parser_topic_modeling_mallet_calculate_stats.add_argument( + "-i", + "--input", + dest="binary_accessibility_matrix_filename", + action="store", + type=str, + required=True, + help="Binary accessibility matrix (region IDs vs cell barcodes) in Matrix Market format.", + ) + parser_topic_modeling_mallet_calculate_stats.add_argument( + "-c", + "--cb", + dest="cell_barcodes_filename", + action="store", + type=str, + required=True, + help="Filename with cell barcodes.", + ) + parser_topic_modeling_mallet_calculate_stats.add_argument( + "-r", + "--regions", + dest="region_ids_filename", + action="store", + type=str, + required=True, + help="Filename with region IDs.", + ) + parser_topic_modeling_mallet_calculate_stats.add_argument( + "-o", + "--output", + dest="output_prefix", + action="store", + type=str, + required=True, + help="Topic model output prefix.", + ) + parser_topic_modeling_mallet_calculate_stats.add_argument( + "-t", + "--topics", + dest="topics", + type=int, + required=True, + nargs="+", + help="Topic number(s) to create the model evaluation statistics for.", + ) + parser_topic_modeling_mallet_calculate_stats.add_argument( + "-v", + "--verbose", + dest="verbose", + action="store_true", + required=False, + help="Enable verbose mode.", + ) + + parser_topic_modeling_mallet_binarize = subparser_topic_modeling_mallet.add_parser( + "binarize", help="Binarize cell- or region-topic probabilities" + ) + parser_topic_modeling_mallet_binarize.set_defaults(func=binarize_cell_or_region_topic) + parser_topic_modeling_mallet_binarize.add_argument( + "-a", + "--target", + dest="target", + action="store", + type=str, + choices=["region", "cell"], + required=True, + help='Choose between "region" or "cell" topic binarization.' + ) + parser_topic_modeling_mallet_binarize.add_argument( + "-m", + "--method", + dest="method", + action="store", + type=str, + choices=("ntop", "otsu", "aucell", "li", "yen"), + required=True, + help='Binarization method. 
Choose between "ntop", "otsu", "aucell", "li" or "yen" for cell-or region-topic binarization.' + ) + parser_topic_modeling_mallet_binarize.add_argument( + "-n", + "--ntop", + dest="ntop", + action="store", + type=int, + required=False, + help="Number of top regions to select. Can only be used when `--method` is set to `ntop`." + ) + parser_topic_modeling_mallet_binarize.add_argument( + "-s", + "--smooth", + dest="smooth_topics", + action="store", + type=str_to_bool, + choices=(True, False), + required=False, + default=True, + help="Wether to smooth the cell- or region-topic probabilities." + ) + parser_topic_modeling_mallet_binarize.add_argument( + "-b", + "--nbins", + dest="nbins", + action="store", + type=int, + required=False, + default=100, + help="Number of bins to use in the histogram used for `otsu`, `yen` and `li` thresholding." + ) + parser_topic_modeling_mallet_binarize.add_argument( + "-c", + "--cb", + dest="cell_barcodes_filename", + action="store", + type=str, + required=False, + help="Filename with cell barcodes.", + ) + parser_topic_modeling_mallet_binarize.add_argument( + "-r", + "--regions", + dest="region_ids_filename", + action="store", + type=str, + required=False, + help="Filename with region IDs.", + ) + parser_topic_modeling_mallet_binarize.add_argument( + "-o", + "--output", + dest="output_prefix", + action="store", + type=str, + required=True, + help="Topic model output prefix.", + ) + parser_topic_modeling_mallet_binarize.add_argument( + "-t", + "--n_topics", + dest="n_topics", + type=int, + required=True, + help="Model with `topic` number of topics to binarize.", + ) + parser_topic_modeling_mallet_binarize.add_argument( + "-p", + "--output_dir", + dest="out_dir", + action="store", + type=str, + required=True, + help="Directory to store results.", + ) diff --git a/src/pycisTopic/cli/subcommand/tss.py b/src/pycisTopic/cli/subcommand/tss.py index 5e441dd..3e25fef 100644 --- a/src/pycisTopic/cli/subcommand/tss.py +++ b/src/pycisTopic/cli/subcommand/tss.py @@ -161,12 +161,24 @@ def get_tss_annotation_bed_file( f" - use_cache: {use_cache}", file=sys.stderr, ) - tss_annotation_bed_df_pl = ga.get_tss_annotation_from_ensembl( - biomart_name=biomart_name, - biomart_host=biomart_host, - transcript_type=transcript_type, - use_cache=use_cache, - ) + + try: + tss_annotation_bed_df_pl = ga.get_tss_annotation_from_ensembl( + biomart_name=biomart_name, + biomart_host=biomart_host, + transcript_type=transcript_type, + use_cache=use_cache, + ) + except Exception as e: + print( + "\nError: Could not get TSS annotation from Ensembl BioMart. " + "Likely this is caused by and invalid/incomplete cached request from " + "BioMart.\n\n" + 'Use "--no-cache" or remove ".pybiomart.sqlite" in the current working ' + "directory and try again.\n", + file=sys.stderr, + ) + raise e if to_chrom_source_name and ( chrom_sizes_and_alias_tsv_filename or ncbi_accession_id or ucsc_assembly @@ -296,10 +308,21 @@ def get_species_gene_annotation_ensembl_biomart_dataset_names( """ # noqa: W505 import pycisTopic.gene_annotation as ga - biomart_datasets = ga.get_all_gene_annotation_ensembl_biomart_dataset_names( - biomart_host=biomart_host, - use_cache=use_cache, - ) + try: + biomart_datasets = ga.get_all_gene_annotation_ensembl_biomart_dataset_names( + biomart_host=biomart_host, + use_cache=use_cache, + ) + except Exception as e: + print( + "Error: Could not get gene annotation Ensembl BioMart dataset names. 
" + "Likely this is caused by and invalid/incomplete cached request from " + "BioMart.\n\n" + 'Use "--no-cache" or remove ".pybiomart.sqlite" in the current working ' + "directory and try again.\n", + file=sys.stderr, + ) + raise e if not species: biomart_datasets.to_csv(sys.stdout, sep="\t", header=False, index=False) @@ -582,7 +605,10 @@ def add_parser_tss(subparsers: _SubParsersAction[ArgumentParser]): action="store_false", required=False, default=True, - help="Disable caching of requests to Ensembl BioMart server.", + help="Disable caching of requests to Ensembl BioMart server. Cached requests " + 'can also be removed by deleting ".pybiomart.sqlite" file in the current ' + 'working directory. If you got a crash running the "get_tss" subcommand, try ' + 'to remove the ".pybiomart.sqlite" file or add "--no_cache".', ) group_tgt_remap_chroms = parser_tss_get_tss.add_argument_group( @@ -683,7 +709,10 @@ def add_parser_tss(subparsers: _SubParsersAction[ArgumentParser]): action="store_false", required=False, default=True, - help="Disable caching of requests to Ensembl BioMart server.", + help="Disable caching of requests to Ensembl BioMart server. Cached requests " + 'can also be removed by deleting ".pybiomart.sqlite" file in the current ' + 'working directory. If you got a crash running the "gene_annotation_list" ' + 'subcommand, try to remove the ".pybiomart.sqlite" file or add "--no_cache".', ) parser_tss_get_ncbi_acc = subparser_tss.add_parser( diff --git a/src/pycisTopic/clust_vis.py b/src/pycisTopic/clust_vis.py index 653244a..4d2d69c 100644 --- a/src/pycisTopic/clust_vis.py +++ b/src/pycisTopic/clust_vis.py @@ -14,6 +14,7 @@ import matplotlib.patches as mpatches import matplotlib.patheffects as PathEffects import matplotlib.pyplot as plt +from matplotlib import colormaps import numpy as np import pandas as pd import seaborn as sns @@ -96,10 +97,11 @@ def find_clusters( if target == "cell": data_mat = model.cell_topic_harmony if harmony else model.cell_topic data_names = cistopic_obj.cell_names - - if target == "region": + elif target == "region": data_mat = model.topic_region.T data_names = cistopic_obj.region_names + else: + raise ValueError(f"target should be 'cell' or 'region' not {target}.") if selected_topics is not None: data_mat = data_mat.loc[["Topic" + str(x) for x in selected_topics]] @@ -929,26 +931,30 @@ def cell_topic_heatmap( columns=cell_topic.columns, ) - if remove_nan and (sum(cell_data[variables].isnull().sum()) > 0): - cell_data = cell_data[variables].dropna() - cell_topic = cell_topic.loc[:, cell_data.index.tolist()] - cell_topic = cell_topic.transpose() - var = variables[0] - var_data = cell_data.loc[:, var].sort_values() - cell_topic = cell_topic.loc[var_data.index.to_list()] - df = pd.concat([cell_topic, var_data], axis=1, sort=False) - topic_order = df.groupby(var).mean().idxmax().sort_values().index.to_list() - cell_topic = cell_topic.loc[:, topic_order].T - # Color dict col_colors = {} if variables is not None: + if remove_nan and (sum(cell_data[variables].isnull().sum()) > 0): + cell_data = cell_data[variables].dropna() + cell_topic = cell_topic.loc[:, cell_data.index.tolist()] + # sort by first variable + var = variables[0] + var_data = cell_data.loc[:, var].sort_values() + cell_topic = cell_topic.loc[var_data.index.to_list()] + df = pd.concat([cell_topic, var_data], axis=1, sort=False) + topic_order = df.groupby(var).mean().idxmax().sort_values().index.to_list() + cell_topic = cell_topic.loc[:, topic_order].T + + if color_dictionary is None: + 
color_dictionary = {} + for var in variables: var_data = cell_data.loc[:, var].sort_values() categories = set(var_data) if var not in color_dictionary: + # generate random color mapping random.seed(seed) color = [ mcolors.to_rgb("#" + "%06x" % random.randint(0, 0xFFFFFF)) @@ -956,7 +962,9 @@ def cell_topic_heatmap( ] color_dict = dict(zip(categories, color)) color_dictionary[var] = color_dict + col_colors[var] = var_data.map(color_dictionary[var]) + col_colors = pd.concat( [col_colors[var] for var in variables], axis=1, sort=False ) @@ -990,7 +998,7 @@ def cell_topic_heatmap( loc="center", title=key, ) - ax = plt.gca().add_artist(legend) + _ = plt.gca().add_artist(legend) pos += legend_dist_y else: g = sns.clustermap( diff --git a/src/pycisTopic/diff_features.py b/src/pycisTopic/diff_features.py index 54d28dc..c0d4bdf 100644 --- a/src/pycisTopic/diff_features.py +++ b/src/pycisTopic/diff_features.py @@ -4,10 +4,12 @@ import sys from typing import TYPE_CHECKING, Self +import math import matplotlib import matplotlib.pyplot as plt import numba import numpy as np +import numpy.typing as npt import pandas as pd import ray import scipy @@ -18,9 +20,13 @@ if TYPE_CHECKING: from pycisTopic.cistopic_class import CistopicObject -# FIXME -from .cistopic_class import * -from .utils import * +from pycisTopic.utils import ( + get_nonzero_row_indices, + get_position_index, + non_zero_rows, + prepare_tag_cells, + subset_list, +) class CistopicImputedFeatures: @@ -31,7 +37,7 @@ class CistopicImputedFeatures: cell names :attr:`cell_names` and feature names :attr:`feature_names`. Attributes - --------- + ---------- mtx: sparse.csr_matrix A matrix containing imputed values. cell_names: list @@ -183,24 +189,16 @@ def merge( if sparse.issparse(mtx): mtx_common = sparse.hstack( [ - mtx[ - common_index_fm, - ], - mtx_to_add[ - common_index_fm_to_add, - ], + mtx[common_index_fm,], + mtx_to_add[common_index_fm_to_add,], ], format="csr", ) else: mtx_common = np.hstack( [ - mtx[ - common_index_fm, - ], - mtx_to_add[ - common_index_fm_to_add, - ], + mtx[common_index_fm,], + mtx_to_add[common_index_fm_to_add,], ] ) if len(diff_features) > 0: @@ -211,9 +209,7 @@ def merge( if sparse.issparse(mtx): mtx_diff_1 = sparse.hstack( [ - mtx[ - diff_index_fm_1, - ], + mtx[diff_index_fm_1,], np.zeros((len(diff_features_1), mtx_to_add.shape[1])), ], format="csr", @@ -221,9 +217,7 @@ def merge( else: mtx_diff_1 = np.hstack( [ - mtx[ - diff_index_fm_1, - ], + mtx[diff_index_fm_1,], np.zeros((len(diff_features_1), mtx_to_add.shape[1])), ] ) @@ -238,9 +232,7 @@ def merge( mtx_diff_2 = sparse.hstack( [ np.zeros((len(diff_features_2), mtx.shape[1])), - mtx_to_add[ - diff_index_fm_2, - ], + mtx_to_add[diff_index_fm_2,], ], format="csr", ) @@ -251,9 +243,7 @@ def merge( mtx_diff_2 = np.hstack( [ np.zeros((len(diff_features_2), mtx.shape[1])), - mtx_to_add[ - diff_index_fm_2, - ], + mtx_to_add[diff_index_fm_2,], ] ) mtx = np.vstack([mtx_common, mtx_diff_1, mtx_diff_2]) @@ -328,15 +318,476 @@ def rank_scores_and_assign_random_ranking_in_range_for_ties( # Rank all scores per motif/track and assign a random ranking in range for regions/genes with the same score. 
for col_idx in range(len(imputed_acc_ranking.cell_names)): - imputed_acc_ranking.mtx[ - :, col_idx - ] = rank_scores_and_assign_random_ranking_in_range_for_ties( - mtx[:, col_idx].toarray().flatten() + imputed_acc_ranking.mtx[:, col_idx] = ( + rank_scores_and_assign_random_ranking_in_range_for_ties( + mtx[:, col_idx].toarray().flatten() + ) ) return imputed_acc_ranking +@numba.njit(parallel=False, error_model="numpy") +def calculate_partial_imputed_acc_sums_per_cell_for_requested_regions( + imputed_acc_chunk: npt.NDArray[np.float32], + region_idx_to_keep_chunk: npt.NDArray[np.intp], +) -> npt.NDArray[np.float64]: + """ + Calculate (partial) sum of imputed accessibility for the whole (partial) cell column for each cell. + + Calculate (partial) sum of imputed accessibility for the whole (partial) cell + column for each cell using all regions for which the whole row is not completely + zero (taken from `region_idx_to_keep_chunk`) + (regions which are never accessible in any cell). + + Returns + ------- + per_cell_imputed_acc_sums_partial + + """ + # To get the sum of imputed accessibility for the whole (partial) cell + # column, the values for each imputed_acc_chunk need to be summed together. + # To avoid problems with precision in summing a lot of values, summing is done + # with float64 values. + # + # This function is a memory optimized version of: + # + # np.sum( + # imputed_acc_chunk[region_idx_to_keep_chunk], + # axis=0, + # dtype=np.float64, + # ) + # + # The above code triggers a big temporarily memory allocation, when the input + # matrix is not contiguous in memory (which happens due subsetting + # `imputed_acc_chunk` with specific region indexes). + + n_cells = imputed_acc_chunk.shape[1] + + # Preallocate array for whole (partial) imputed accessibility per cell for each + # cell. + per_cell_imputed_acc_sums_partial = np.zeros((n_cells,), dtype=np.float64) + + # Get each region index of regions to keep and retrieve imputed accessibility + # per region and compute (partial) sum of imputed accessibility for the whole + # (partial) cell for each cell. Use float64 to avoid problems with precision + # when summing a lot of values. + for region_idx in region_idx_to_keep_chunk: + per_cell_imputed_acc_sums_partial += imputed_acc_chunk[region_idx].astype( + np.float64 + ) + return per_cell_imputed_acc_sums_partial + + +@numba.njit(parallel=True, error_model="numpy") +def calculate_per_region_mean_and_dispersion_on_normalized_imputed_acc_chunk( + normalized_imputed_acc_chunk: npt.NDArray[np.float32], +) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: + """ + Calculate mean and dispersion on normalized imputed accessibility for each region. + + Parameters + ---------- + normalized_imputed_acc_chunk + Normalized imputed accessibility matrix (regions_chunk_size x n_cells). 
+ + Returns + ------- + Mean and dispersion of normalized imputed accessibility per region for all regions: + (per_region_means_on_normalized_imputed_acc_chunk, + per_region_dispersions_on_normalized_imputed_acc_chunk) + + """ + # Memory and speed optimized version of: + # + # per_region_means_on_normalized_imputed_acc_chunk = normalized_imputed_acc_chunk.mean( + # axis=1, dtype=np.float64 + # ) + # per_region_variances_on_normalized_imputed_acc_chunk = normalized_imputed_acc_chunk.var( + # axis=1, dtype=np.float64 + # ) + # per_region_means_on_normalized_imputed_acc_chunk[ + # per_region_means_on_normalized_imputed_acc_chunk == 0 + # ] = 1e-12 + # per_region_dispersions_on_normalized_imputed_acc_chunk = ( + # per_region_variances_on_normalized_imputed_acc_chunk + # / per_region_means_on_normalized_imputed_acc_chunk + # ) + # per_region_dispersions_on_normalized_imputed_acc_chunk[ + # per_region_dispersions_on_normalized_imputed_acc_chunk == 0 + # ] = np.nan + # np.log( + # per_region_dispersions_on_normalized_imputed_acc_chunk, + # out=per_region_dispersions_on_normalized_imputed_acc_chunk, + # ) + # return ( + # per_region_means_on_normalized_imputed_acc_chunk, + # per_region_dispersions_on_normalized_imputed_acc_chunk, + # ) + + n_regions_in_chunk = normalized_imputed_acc_chunk.shape[0] + n_cells = normalized_imputed_acc_chunk.shape[1] + + # Preallocate arrays for mean and dispersion of normalized imputed accessibility + # per region for each region in the current chunk. + per_region_means_on_normalized_imputed_acc_chunk = np.empty( + (n_regions_in_chunk), + dtype=np.float64, + ) + per_region_dispersions_on_normalized_imputed_acc_chunk = np.empty( + (n_regions_in_chunk), + dtype=np.float64, + ) + + for region_idx in numba.prange(n_regions_in_chunk): + # Get normalized imputed accessibility for current region. + normalized_impute_acc_for_region = normalized_imputed_acc_chunk[ + region_idx + ].astype(np.float64) + + # Calculate mean of normalized imputed accessibility for current region. + mean = np.float64(0.0) + for cell_idx in range(n_cells): + mean += normalized_impute_acc_for_region[cell_idx] + mean /= n_cells + + # Calculate variance of normalized imputed accessibility for current region. + variance = np.float64(0.0) + for cell_idx in range(n_cells): + variance += (normalized_impute_acc_for_region[cell_idx] - mean) ** 2 + variance /= n_cells + + # Calculate dispersion of normalized imputed accessibility for current region. + mean = mean if mean != 0.0 else 1e-12 + dispersion = variance / mean + dispersion = np.log(dispersion) if dispersion != 0.0 else np.nan + + per_region_means_on_normalized_imputed_acc_chunk[region_idx] = mean + per_region_dispersions_on_normalized_imputed_acc_chunk[region_idx] = dispersion + return ( + per_region_means_on_normalized_imputed_acc_chunk, + per_region_dispersions_on_normalized_imputed_acc_chunk, + ) + + +def calculate_per_region_mean_and_dispersion_on_normalized_imputed_acc( + region_topic: npt.NDArray[np.float32], + cell_topic: npt.NDArray[np.float32], + region_names: list[str], + scale_factor1: int = 10**6, + scale_factor2: int = 10**4, + regions_chunk_size: int = 20000, +) -> tuple[list[str], npt.NDArray[np.float64], npt.NDArray[np.float64]]: + """ + Calculate per region mean and dispersion on normalized imputed accessibility in chunks of `regions_chunk_size`. + + High level overview of the function: + - Calculate imputed accessibility: `region_topic @ cell_topic`. 
+ - Scale imputed accessibility is scaled by `scale_factor1` to create a "count" + matrix. + - Only keep integer part of the scaled imputed accessibility. + - Remove all regions for which scaled imputed accessibility was 0 in all cells. + - Calculate total imputed accessibility per cell by summing the imputed + accessibility for each region. + - Calculate normalized imputed accessibility for regions for which scaled imputed + accessibility was not 0 in all cells, by dividing the scaled imputed + accessibility by the total imputed accessibility per cell and multiplying by + scale_factor2 and taking the log(x + 1) of the result. + - Calculate mean and dispersion of normalized imputed accessibility per region. + + Parameters + ---------- + region_topic + Region topic matrix (regions x topics). + cell_topic + Cell topic matrix (topic x cells). + region_names + List of all region names. + scale_factor1 + Multiply imputed accessiblitity by this scale factor to create a "count" matrix. + This will remove noise by putting very small values to zero. + Default: 10**6. + scale_factor2 + Scale factor used to normalize scaled imputed accessibility, similar to + RNA-seq normalization. + Divide scaled (`scale_factor1`) imputed accessibility by total imputed + accessibility per cell, multiply by `scale_factor2`, add 1 and take logarithm. + Default: 10**4. + regions_chunk_size + Regions chunk size used (number of regions for which imputed accessibility is + calculated at the same time). + + Returns + ------- + Numpy array with imputed accessibility for each region and a list of region + names (some regions for which all row values were 0 are filtered out). + (imputed_acc, region_names_to_keep) + + """ + # Create cisTopic logger + level = logging.INFO + log_format = "%(asctime)s %(name)-12s %(levelname)-8s %(message)s" + handlers = [logging.StreamHandler(stream=sys.stdout)] + logging.basicConfig(level=level, format=log_format, handlers=handlers) + log = logging.getLogger("cisTopic") + + output_regions_chunk_end = 0 + n_regions = region_topic.shape[0] + n_cells = cell_topic.shape[1] + # n_topics = region_topic.shape[1] + + region_idx_to_keep = np.zeros((n_regions,), dtype=np.intp) + + impute_acc_per_cell_sum = np.zeros((n_cells,), dtype=np.float64) + + log.info("Calculate total imputed accessibility per cell.") + region_topic = np.asarray(region_topic, dtype=np.float32) + cell_topic = np.asarray(cell_topic, dtype=np.float32) + + log.info( + f"Allocate {(regions_chunk_size * n_cells * 4 / 1024 ** 3):.3f} GiB of RAM for " + f"calculating (partial) imputed accessibility per cell for ({n_cells}) cells " + f"for chunk of {regions_chunk_size} regions." + ) + # Preallocate imputed accessibility chunk array (regions_chunk_size x n_cells) so + # it can be reused in each loop iteration (except for the last one as that one will + # likely be smaller). + imputed_acc_chunk = np.empty((regions_chunk_size, n_cells), dtype=np.float32) + + # Calculate total imputed accessibility per cell. + for input_regions_chunk_start in range(0, n_regions, regions_chunk_size): + input_regions_chunk_end = input_regions_chunk_start + regions_chunk_size + + # Set correct output chunk start position. + output_regions_chunk_start = output_regions_chunk_end + + log.info( + "Calculate (partial) imputed accessibility per cell for regions " + f"{input_regions_chunk_start}-{input_regions_chunk_end} (out of {n_regions})." + ) + + # Get the current chunk of regions. 
+ topic_region_chunk = region_topic[ + input_regions_chunk_start : input_regions_chunk_start + regions_chunk_size + ] + current_regions_chunk_size = topic_region_chunk.shape[0] + + if current_regions_chunk_size < regions_chunk_size: + log.info( + f"Allocate {(current_regions_chunk_size * n_cells * 4 / 1024 ** 3):.3f} " + "GiB of RAM for calculating (partial) imputed accessibility per cell " + f"for ({n_cells}) cells for chunk of {current_regions_chunk_size} " + "regions." + ) + # Reallocate imputed_acc_chunk to the correct size. + del imputed_acc_chunk + imputed_acc_chunk = np.empty( + (current_regions_chunk_size, n_cells), dtype=np.float32 + ) + + log.info( + " - Calculate imputed accessibility for the current chunk of regions." + ) + # Calculate imputed accessibility for the current chunk of regions. + np.matmul(topic_region_chunk, cell_topic, out=imputed_acc_chunk) + + log.info(' - Scale imputed accessibility matrix chunk ("count" matrix).') + # Scale imputed accessibility matrix chunk ("count" matrix). + imputed_acc_chunk *= np.float32(scale_factor1) + + log.info(" - Only keep integer part.") + # Only keep integer part. + # This will convert very small values (< 1.0) to zero and removes noise. + np.floor(imputed_acc_chunk, out=imputed_acc_chunk) + + log.info(" - Get non-zero regions.") + # Get all region index positions of the matrix for which the whole row is not + # completely zero (regions which are never accessible in any cell). + region_idx_to_keep_chunk = get_nonzero_row_indices(imputed_acc_chunk) + + # Set correct output chunk end position by taking into account + # that rows with all zeros will be filtered out. + output_regions_chunk_end = output_regions_chunk_start + len( + region_idx_to_keep_chunk + ) + + # Get all region indexes that need to be kept from this chunk and + # assign them to the correct positions in the full region_idx_to_keep array. + region_idx_to_keep[output_regions_chunk_start:output_regions_chunk_end] = ( + region_idx_to_keep_chunk + input_regions_chunk_start + ) + + log.info( + " - Calculate (partial) sum of imputed accessibility for the whole (partial) cell column." + ) + # Calculate (partial) sum of imputed accessibility for the whole (partial) cell + # column for each cell using all regions for which the whole row is not + # completely zero (regions which are never accessible in any cell). + # To avoid problems with precision in summing multiple chunks later, float64 are + # used. + cells_imputed_acc_sums_partial = ( + calculate_partial_imputed_acc_sums_per_cell_for_requested_regions( + imputed_acc_chunk, region_idx_to_keep_chunk + ) + ) + + impute_acc_per_cell_sum += cells_imputed_acc_sums_partial + + del imputed_acc_chunk + + n_regions_to_keep = output_regions_chunk_end + + log.info(f"Keeping {n_regions_to_keep} of {n_regions} (non_zero) regions.") + + # Only retain that part of region_idx_to_keep that was actually + # filled in. + region_idx_to_keep = region_idx_to_keep[:output_regions_chunk_end] + + # Subset region_topic to regions we want to keep. + region_topic = region_topic[region_idx_to_keep] + + # Get all region names that need to be kept. + region_names_to_keep = subset_list( + region_names, + region_idx_to_keep, + ) + + # Preallocate arrays for mean and dispersion of normalized imputed accessibility per + # region. 
+ per_region_means_on_normalized_imputed_acc = np.empty( + (n_regions_to_keep,), + dtype=np.float64, + ) + per_region_dispersions_on_normalized_imputed_acc = np.empty( + (n_regions_to_keep,), + dtype=np.float64, + ) + + log.info( + f"Scale total imputed accessibility per cell by dividing by {scale_factor2}." + ) + # Scale total imputed accessibility per cell by dividing by scale_factor2. + # The sum was calculated in float64 to avoid problems with precision, but after + # scaling it can be converted back to float32 to avoid converting + # normalized_imputed_acc_chunk in each for loop iteration to float64, causing a + # big allocation (if there are a lot of cells). + impute_acc_per_cell_sum_scaled = (impute_acc_per_cell_sum / scale_factor2).astype( + np.float32 + ) + + log.info( + f"Allocate {(regions_chunk_size * n_cells * 4 / 1024 ** 3):.3f} GiB of RAM " + "for calculating normalized imputed accessibility per region for chunk of " + f"{regions_chunk_size} regions." + ) + # Preallocate normalized imputed accessibility chunk array + # (regions_chunk_size x n_cells) so it can be reused in each loop iteration + # (except for the last one as that one will likely be smaller). + normalized_imputed_acc_chunk = np.empty( + (regions_chunk_size, n_cells), + dtype=np.float32, + ) + + output_regions_chunk_end = 0 + + log.info( + "Calculate mean and dispersion of normalized imputed accessibility per region." + ) + + # Calculate mean and dispersion of normalized imputed accessibility per region. + for input_regions_chunk_start in range(0, n_regions_to_keep, regions_chunk_size): + input_regions_chunk_end = input_regions_chunk_start + regions_chunk_size + + # Set correct output chunk start position. + output_regions_chunk_start = output_regions_chunk_end + + log.info( + "Calculate mean and dispersion of normalized imputed accessibility for regions " + f"{input_regions_chunk_start}-{input_regions_chunk_end} (out of {n_regions_to_keep})." + ) + + topic_region_chunk = region_topic[ + input_regions_chunk_start : input_regions_chunk_start + regions_chunk_size + ] + current_regions_chunk_size = topic_region_chunk.shape[0] + + if current_regions_chunk_size < regions_chunk_size: + log.info( + f"Allocate {(current_regions_chunk_size * n_cells * 4 / 1024 ** 3):.3f} " + f"GiB of RAM for calculating normalized imputed accessibility per " + f"region for chunk of {current_regions_chunk_size} regions." + ) + # Reallocate imputed_acc_chunk to the correct size. + del normalized_imputed_acc_chunk + normalized_imputed_acc_chunk = np.empty( + (current_regions_chunk_size, n_cells), + dtype=np.float32, + ) + + log.info( + " - Calculate imputed accessibility for the current chunk of regions." + ) + # Calculate imputed accessibility for the current chunk of regions. + np.matmul(topic_region_chunk, cell_topic, out=normalized_imputed_acc_chunk) + + log.info(' - Scale imputed accessibility matrix chunk ("count" matrix).') + # Scale imputed accessibility matrix chunk ("count" matrix). + normalized_imputed_acc_chunk *= np.float32(scale_factor1) + + log.info(" - Only keep integer part.") + # Only keep integer part. + # This will convert very small values (< 1.0) to zero and removes noise. + np.floor(normalized_imputed_acc_chunk, out=normalized_imputed_acc_chunk) + + # Set correct output chunk end position by taking into account + # that rows with all zeros will be filtered out. 
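+        # (In this second pass `region_topic` was already subset to the retained
+        # regions, so no additional rows are dropped here; for the last chunk the
+        # slice end can exceed `n_regions_to_keep`, which NumPy slice assignment
+        # clamps to the array length.)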
+ output_regions_chunk_end = output_regions_chunk_start + regions_chunk_size + + log.info( + " - Normalize imputed accessibility by dividing by the total imputed " + f"accessibility per cell and multiply by {scale_factor2}." + ) + # Normalize imputed accessibility by dividing by the total imputed + # accessibility per cell and multiply by 10^4. + normalized_imputed_acc_chunk /= impute_acc_per_cell_sum_scaled + + log.info(" - Add pseudocount of 1 and apply log normalization.") + # Add pseudocount of 1 and apply log normalization. + np.log1p( + normalized_imputed_acc_chunk, + out=normalized_imputed_acc_chunk, + ) + + log.info( + " - Calculate mean and dispersion of imputed accessibility per region." + ) + ( + per_region_means_on_normalized_imputed_acc_chunk, + per_region_dispersions_on_normalized_imputed_acc_chunk, + ) = calculate_per_region_mean_and_dispersion_on_normalized_imputed_acc_chunk( + normalized_imputed_acc_chunk + ) + + per_region_means_on_normalized_imputed_acc[ + output_regions_chunk_start:output_regions_chunk_end + ] = per_region_means_on_normalized_imputed_acc_chunk + per_region_dispersions_on_normalized_imputed_acc[ + output_regions_chunk_start:output_regions_chunk_end + ] = per_region_dispersions_on_normalized_imputed_acc_chunk + + del normalized_imputed_acc_chunk + + log.info( + "Finished Calculating mean and dispersion of imputed accessibility per region." + ) + + return ( + region_names_to_keep, + per_region_means_on_normalized_imputed_acc, + per_region_dispersions_on_normalized_imputed_acc, + ) + + def impute_accessibility( cistopic_obj: CistopicObject, selected_cells: list[str] | None = None, @@ -388,7 +839,7 @@ def impute_accessibility( if selected_regions is not None: topic_region = topic_region.loc[selected_regions] region_names = selected_regions - # Convert cell_topic and topic_region 2d arrays to np.float32 so + # Convert cell_topic and region_topic 2d arrays to np.float32 so # multiplying them uses 4 times less memory than with np.float64 cell_topic = cell_topic.to_numpy().astype(np.float32) topic_region = topic_region.to_numpy().astype(np.float32) @@ -400,7 +851,7 @@ def calculate_imputed_accessibility( cell_topic: np.ndarray, region_names: list, scale_factor: int, - chunk_size: int + chunk_size: int, ) -> tuple[np.ndarray, list]: """ Calculate imputed accessibility in chunks of chunk_size. @@ -451,7 +902,7 @@ def calculate_imputed_accessibility( f"{input_chunk_start}-{input_chunk_end}" ) topic_region_chunk = topic_region[ - input_chunk_start:input_chunk_start + chunk_size + input_chunk_start : input_chunk_start + chunk_size ] imputed_acc_chunk = topic_region_chunk @ cell_topic @@ -481,9 +932,9 @@ def calculate_imputed_accessibility( # Convert from float32 to int32 and fill in the values in the full # imputed accessibility matrix. - imputed_acc[ - output_chunk_start:output_chunk_end, : - ] = imputed_acc_chunk[region_idx_to_keep_chunk] + imputed_acc[output_chunk_start:output_chunk_end, :] = imputed_acc_chunk[ + region_idx_to_keep_chunk + ] # Only retain that part of the imputed accessibility matrix that was actually # filled in. @@ -508,111 +959,96 @@ def calculate_imputed_accessibility( return imputed_acc_obj -def normalize_scores( - imputed_acc: pd.DataFrame | CistopicImputedFeatures, - scale_factor: int = 10**4, -): - """ - Log-normalize imputation data. Feature counts for each cell are divided by the total counts for that cell and multiplied by the scale_factor. 
- - Parameters - ---------- - imputed_acc: pd.DataFrame or :class:`CistopicImputedFeatures` - A dataframe with values to be normalized or cisTopic imputation data. - scale_factor: int - Scale factor for cell-level normalization. Default: 10**4 - - Return - ------ - pd.DataFrame or CistopicImputedFeatures - The output class will be the same as the used as input. - - """ - # Create cisTopic logger - level = logging.INFO - log_format = "%(asctime)s %(name)-12s %(levelname)-8s %(message)s" - handlers = [logging.StreamHandler(stream=sys.stdout)] - logging.basicConfig(level=level, format=log_format, handlers=handlers) - log = logging.getLogger("cisTopic") - - log.info("Normalizing imputed data") - - def calculate_normalized_scores(imputed_acc: np.ndarray, scale_factor: int): - # Divide each imputed accessibility by sum of imputed accessibility for the - # whole cell column and multiply then by the scale factor. - # To avoid a big extra memory allocation matrix for applying the scale factor, - # imputed_acc is divided by (np.sum(imputed_acc, axis=0) / scale_factor), - # instead of doing `imputed_acc / np.sum(imputed_acc, axis=0) * scale_factor`. - normalized_acc = imputed_acc / (np.sum(imputed_acc, axis=0) / scale_factor) - # Apply log1p element wise in place, to avoid a big memory allocation. - return np.log1p(normalized_acc, out=normalized_acc) - - if isinstance(imputed_acc, CistopicImputedFeatures): - output = CistopicImputedFeatures( - calculate_normalized_scores( - imputed_acc=( - imputed_acc.mtx.toarray() - if scipy.sparse.issparse(imputed_acc.mtx) - else imputed_acc.mtx - ), - scale_factor=scale_factor - ), - imputed_acc.feature_names, - imputed_acc.cell_names, - imputed_acc.project, - ) - elif isinstance(imputed_acc, pd.DataFrame): - output = pd.DataFrame( - calculate_normalized_scores( - imputed_acc=imputed_acc.to_numpy(), - scale_factor=scale_factor - ), - index=imputed_acc.index, - columns=imputed_acc.columns, - ) - log.info("Done!") - return output - - def find_highly_variable_features( - input_mat: pd.DataFrame | CistopicImputedFeatures, + features: list[str], + per_region_means_on_normalized_imputed_acc: npt.NDArray[np.float64], + per_region_dispersions_on_normalized_imputed_acc: npt.NDArray[np.float64], min_disp: float = 0.05, min_mean: float = 0.0125, - max_disp: float = np.inf, + max_disp: float = float("inf"), max_mean: float = 3, n_bins: int = 20, - n_top_features: int = None, - plot: bool = True, - save: str = None, + n_top_features: int | None = None, + plot: bool | str | None = True, ): """ Find highly variable features. + Find highly variable features by using output of + `calculate_per_region_mean_and_dispersion_on_normalized_imputed_acc` + as input: + - `features` + - `per_region_means_on_normalized_imputed_acc` + - `per_region_dispersions_on_normalized_imputed_acc` + Parameters ---------- - input_mat: pd.DataFrame or :class:`CistopicImputedFeatures` - A dataframe with values to be normalize or cisTopic imputation data. - min_disp: float, optional - Minimum dispersion value for a feature to be selected. Default: 0.05 - min_mean: float, optional - Minimum mean value for a feature to be selected. Default: 0.0125 - max_disp: float, optional - Maximum dispersion value for a feature to be selected. Default: np.inf - max_mean: float, optional - Maximum mean value for a feature to be selected. Default: 3 - n_bins: int, optional - Number of bins for binning the mean gene expression. Normalization is done with respect to each bin. 
Default: 20 - n_top_features: int, optional - Number of highly-variable features to keep. If specifed, dispersion and mean thresholds will be ignored. Default: None - plot: bool, optional - Whether to plot dispersion versus mean values. Default: True. - save: str, optional - Path to save feature selection plot. Default: None + features + List of feature (region) names. + per_region_means_on_normalized_imputed_acc + Mean of normalized imputed accessibility per region. + per_region_dispersions_on_normalized_imputed_acc + Dispersion of normalized imputed accessibility per region. + min_disp + Minimum dispersion value for a feature to be selected. + Default: 0.05 + min_mean + Minimum mean value for a feature to be selected. + Default: 0.0125 + max_disp + Maximum dispersion value for a feature to be selected. + Default: float("inf") + max_mean + Maximum mean value for a feature to be selected. + Default: 3 + n_bins + Number of bins for binning the mean gene expression. + Normalization is done with respect to each bin. + Default: 20 + n_top_features + Number of highly-variable features to keep. + If specified, dispersion and mean thresholds will be ignored. + Default: None + plot + Whether to plot dispersion versus mean values. + If `True`, plot will be shown. + if `str`, plot will be saved to the specified path. + If `False` or `None, plot will not be shown. + Default: True Return ------ - List - List with selected features. + List with selected features. + + Examples + -------- + Get mean and dispersion of normalized imputed accessibility per region. + >>> ( + ... region_names_to_keep, + ... per_region_means_on_normalized_imputed_acc, + ... per_region_dispersions_on_normalized_imputed_acc, + ... ) = calculate_per_region_mean_and_dispersion_on_normalized_imputed_acc( + ... region_topic=region_topic, + ... cell_topic=cell_topic, + ... region_names=region_names, + ... scale_factor1 = 10**6, + ... scale_factor2 = 10**4, + ... regions_chunk_size=20000, + ... ) + + Find highly variable features. + >>> find_highly_variable_features( + ... features=region_names_to_keep, + ... per_region_means_on_normalized_imputed_acc=per_region_means_on_normalized_imputed_acc, + ... per_region_dispersions_on_normalized_imputed_acc=per_region_dispersions_on_normalized_imputed_acc, + ... min_disp = 0.05, + ... min_mean = 0.0125, + ... max_disp = float("inf"), + ... max_mean = 3, + ... n_bins = 20, + ... n_top_features = None, + ... plot = True, + ... 
) """ # Create cisTopic logger @@ -622,36 +1058,22 @@ def find_highly_variable_features( logging.basicConfig(level=level, format=log_format, handlers=handlers) log = logging.getLogger("cisTopic") - if isinstance(input_mat, pd.DataFrame): - mat = input_mat.values - features = input_mat.index.tolist() - else: - mat = input_mat.mtx - features = input_mat.feature_names - - if sparse.issparse(mat): - mean, var = sklearn.utils.sparsefuncs.mean_variance_axis(mat, axis=1) - else: - log.info("Calculating mean") - mean = np.mean(mat, axis=1, dtype=np.float32) - log.info("Calculating variance") - var = np.var(mat, axis=1, dtype=np.float32) - - mean[mean == 0] = 1e-12 - dispersion = var / mean - # Logarithmic dispersion as in Seurat - dispersion[dispersion == 0] = np.nan - dispersion = np.log(dispersion) df = pd.DataFrame() - df["means"] = mean - df["dispersions"] = dispersion + df["means"] = np.asarray( + per_region_means_on_normalized_imputed_acc, + dtype=np.float64, + ) + df["dispersions"] = np.asarray( + per_region_dispersions_on_normalized_imputed_acc, + dtype=np.float64, + ) df["mean_bin"] = pd.cut(df["means"], bins=n_bins) - disp_grouped = df.groupby("mean_bin")["dispersions"] + disp_grouped = df.groupby("mean_bin", observed=False)["dispersions"] disp_mean_bin = disp_grouped.mean() disp_std_bin = disp_grouped.std(ddof=1) # Retrieve those regions that have nan std, these are the ones where # only a single gene fell in the bin and implicitly set them to have - # a normalized dispersion of 1 + # a normalized dispersion of 1. one_feature_per_bin = disp_std_bin.isnull() feature_indices = np.where(one_feature_per_bin[df["mean_bin"].values])[0].tolist() @@ -687,8 +1109,8 @@ def find_highly_variable_features( dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat feature_subset = np.logical_and.reduce( ( - mean > min_mean, - mean < max_mean, + per_region_means_on_normalized_imputed_acc > min_mean, + per_region_means_on_normalized_imputed_acc < max_mean, dispersion_norm > min_disp, dispersion_norm < max_disp, ) @@ -697,16 +1119,16 @@ def find_highly_variable_features( df["highly_variable"] = feature_subset var_features = [features[i] for i in df[df.highly_variable].index.to_list()] - fig = plt.figure() if plot: + fig = plt.figure() matplotlib.rcParams["agg.path.chunksize"] = 10000 plt.scatter( df["means"], df["dispersions_norm"], c=feature_subset, s=10, alpha=0.1 ) plt.xlabel("Mean measurement of features") plt.ylabel("Normalized dispersion of the features") - if save is not None: - fig.savefig(save) + if isinstance(plot, str): + fig.savefig(plot) plt.show() log.info("Done!") @@ -723,7 +1145,7 @@ def find_diff_features( log2fc_thr: float = np.log2(1.5), split_pattern: str = "___", n_cpu: int = 1, - **kwargs + **kwargs, ): """ Find differential imputed features. @@ -895,7 +1317,9 @@ def markers( dtype=np.int64, ) - log.info(f"Subsetting data for {contrast_name} ({fg_cells_index.shape[0]} of {mat.shape[1]})") + log.info( + f"Subsetting data for {contrast_name} ({fg_cells_index.shape[0]} of {mat.shape[1]})" + ) if sparse.issparse(mat): fg_mat = mat[:, fg_cells_index].toarray() @@ -913,14 +1337,14 @@ def markers( chunk_size = 3000 - # Calculate wilcox test for each region in multiple ray processes (3000 regions per process). - wilcox_test_pvalues_nested_list = ray.get( + # Calculate wilcoxon test for each region in multiple ray processes (3000 regions per process). 
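+    # For example, with 45_000 regions and chunk_size=3000 this submits 15 Ray
+    # tasks; fg_mat_ref and bg_mat_ref are object store references, so every task
+    # reads the same shared foreground/background matrices instead of receiving
+    # its own copy.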
+ wilcoxon_test_pvalues_nested_list = ray.get( [ - get_wilcox_test_pvalues_ray.remote( + get_wilcoxon_test_pvalues_ray.remote( fg_mat_ref, bg_mat_ref, start=start, - end=min(start + chunk_size, fg_mat.shape[0]) + end=min(start + chunk_size, fg_mat.shape[0]), ) for start in range(0, fg_mat.shape[0], chunk_size) ] @@ -929,24 +1353,24 @@ def markers( # Remove foreground and background matrix from ray object store. del fg_mat_ref, bg_mat_ref - # Flatten wilcox tests pvalues nested list. - wilcox_test_pvalues = [] + # Flatten wilcoxon tests pvalues nested list. + wilcoxon_test_pvalues = [] - for wilcox_test_pvalues_part in wilcox_test_pvalues_nested_list: - wilcox_test_pvalues.extend(wilcox_test_pvalues_part) + for wilcoxon_test_pvalues_part in wilcoxon_test_pvalues_nested_list: + wilcoxon_test_pvalues.extend(wilcoxon_test_pvalues_part) else: - wilcox_test_pvalues = get_wilcox_test_pvalues(fg_mat, bg_mat) + wilcoxon_test_pvalues = get_wilcoxon_test_pvalues(fg_mat, bg_mat) log.info(f"Computing log2FC for {contrast_name}") log2_fc = get_log2_fc(fg_mat, bg_mat) - adj_pvalues = p_adjust_bh(wilcox_test_pvalues) + adj_pvalues = p_adjust_bh(wilcoxon_test_pvalues) markers_dataframe = pd.DataFrame( { "Log2FC": log2_fc, "Adjusted_pval": adj_pvalues, - "Contrast": [contrast_name] * adj_pvalues.shape[0] + "Contrast": [contrast_name] * adj_pvalues.shape[0], }, index=features, ) @@ -954,9 +1378,7 @@ def markers( markers_dataframe = markers_dataframe.loc[ markers_dataframe["Adjusted_pval"] <= adjpval_thr ] - markers_dataframe = markers_dataframe.loc[ - markers_dataframe["Log2FC"] >= log2fc_thr - ] + markers_dataframe = markers_dataframe.loc[markers_dataframe["Log2FC"] >= log2fc_thr] markers_dataframe = markers_dataframe.sort_values( ["Log2FC", "Adjusted_pval"], ascending=[False, True], @@ -965,9 +1387,9 @@ def markers( return markers_dataframe -def get_wilcox_test_pvalues(fg_mat, bg_mat): +def get_wilcoxon_test_pvalues(fg_mat, bg_mat): """ - Calculate wilcox test p-values between foreground and background matrix. + Calculate wilcoxon test p-values between foreground and background matrix. Parameters ---------- @@ -983,21 +1405,20 @@ def get_wilcox_test_pvalues(fg_mat, bg_mat): f" {fg_mat.shape[0]} vs {bg_mat.shape[0]}" ) - wilcox_test_pvalues = [ - wilcox_test.pvalue - for wilcox_test in [ - ranksums(fg_mat[i], y=bg_mat[i]) - for i in range(fg_mat.shape[0]) + wilcoxon_test_pvalues = [ + wilcoxon_test.pvalue + for wilcoxon_test in [ + ranksums(fg_mat[i], y=bg_mat[i]) for i in range(fg_mat.shape[0]) ] ] - return wilcox_test_pvalues + return wilcoxon_test_pvalues @ray.remote -def get_wilcox_test_pvalues_ray(fg_mat, bg_mat, start, end): +def get_wilcoxon_test_pvalues_ray(fg_mat, bg_mat, start, end): """ - Calculate wilcox test p-values with ray between a subset of foreground and background matrix. + Calculate wilcoxon test p-values with ray between a subset of foreground and background matrix. Parameters ---------- @@ -1017,21 +1438,19 @@ def get_wilcox_test_pvalues_ray(fg_mat, bg_mat, start, end): f" {fg_mat.shape[0]} vs {bg_mat.shape[0]}" ) - wilcox_test_pvalues_part = [ - wilcox_test.pvalue - for wilcox_test in [ - ranksums(fg_mat[i], y=bg_mat[i]) - for i in range(start, end) + wilcoxon_test_pvalues_part = [ + wilcoxon_test.pvalue + for wilcoxon_test in [ + ranksums(fg_mat[i], y=bg_mat[i]) for i in range(start, end) ] ] - return wilcox_test_pvalues_part + return wilcoxon_test_pvalues_part -def p_adjust_bh(p: float): - """ - Benjamini-Hochberg p-value correction for multiple hypothesis testing. 
- """ +# TODO: Add these generic functions to another package +def p_adjust_bh(p: npt.NDArray): + """Benjamini-Hochberg p-value correction for multiple hypothesis testing.""" p = np.asfarray(p) by_descend = p.argsort()[::-1] by_orig = by_descend.argsort() @@ -1041,7 +1460,7 @@ def p_adjust_bh(p: float): @numba.njit(parallel=True) -def subset_array_second_axis(arr, col_indices): +def subset_array_second_axis(arr: npt.NDArray, col_indices: npt.NDArray): """ Subset array by second axis based on provided `col_indices`. @@ -1056,10 +1475,20 @@ def subset_array_second_axis(arr, col_indices): 1D-numpy array (preferably with np.int64 as dtype) with column indices. """ + if arr.ndim != 2: + raise ValueError("arr should be a 2D array") + + if col_indices.ndim != 1: + raise ValueError("col_indices should be a 1D array") + if np.max(col_indices) >= arr.shape[1]: - raise IndexError(f"index {np.max(col_indices)} is out of bounds for axis 1 with size {arr.shape[1]}") + raise IndexError( + f"index {np.max(col_indices)} is out of bounds for axis 1 with size {arr.shape[1]}" + ) if np.min(col_indices) < -arr.shape[1]: - raise IndexError(f"index {np.min(col_indices)} is out of bounds for axis 1 with size {arr.shape[1]}") + raise IndexError( + f"index {np.min(col_indices)} is out of bounds for axis 1 with size {arr.shape[1]}" + ) # Create empty subset array of correct dimensions and dtype. subset_arr = np.empty( @@ -1075,7 +1504,7 @@ def subset_array_second_axis(arr, col_indices): @numba.njit(parallel=True) -def mean_axis1(arr): +def mean_axis1(arr: npt.NDArray): """ Calculate column wise mean of 2D-numpy matrix with numba, mimicking `np.mean(x, axis=1)`. @@ -1092,7 +1521,7 @@ def mean_axis1(arr): @numba.njit -def get_log2_fc(fg_mat, bg_mat): +def get_log2_fc(fg_mat: npt.NDArray, bg_mat: npt.NDArray): """ Calculate log2 fold change between foreground and background matrix. @@ -1104,6 +1533,9 @@ def get_log2_fc(fg_mat, bg_mat): 2D-numpy background matrix. """ + if fg_mat.ndim != 2 or bg_mat.ndim != 2: + raise ValueError("fg_mat and bg_mat should be 2D arrays") + if fg_mat.shape[0] != bg_mat.shape[0]: raise ValueError( "Foreground matrix and background matrix have a different first dimension:" @@ -1115,6 +1547,90 @@ def get_log2_fc(fg_mat, bg_mat): # np.log2( # (np.mean(fg_mat, axis=1) + 10**-12) / (np.mean(bg_mat, axis=1) + 10**-12) # ) - return np.log2( - (mean_axis1(fg_mat) + 10**-12) / (mean_axis1(bg_mat) + 10**-12) - ) + return np.log2((mean_axis1(fg_mat) + 10**-12) / (mean_axis1(bg_mat) + 10**-12)) + + +@numba.jit(nopython=True) +def rankdata_average_numba(arr: npt.NDArray): + """ + Assign ranks to data, dealing with ties by taking average of ranks that would have been assigned to all tied values. + + Ranks begin at 1. + + Algorithm based on `scipy.stats.ranksums` of scipy 1.11.x with the following + parameters: `rankdata(a, method="average, axis=None, nan_policy="omit")`, + but with the assumption that there are no `np.nan` values. 
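+
+    For example, tied values share the average of the ranks they span:
+    ``[3.0, 1.0, 2.0, 2.0]`` gets ranks ``[4.0, 1.0, 2.5, 2.5]``.
+
+    >>> ranks = rankdata_average_numba(np.array([3.0, 1.0, 2.0, 2.0]))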
+ + https://github.com/scipy/scipy/blob/maintenance/1.11.x/scipy/stats/_stats_py.py#L10123-L10267 + + """ + sorter = np.argsort(arr, kind="quicksort") + inv = np.empty(sorter.size, dtype=np.intp) + inv[sorter] = np.arange(sorter.size, dtype=np.intp) + arr = arr[sorter] + obs = np.empty(arr.shape, dtype=np.intp) + obs[0] = True + obs[1:] = arr[1:] != arr[:-1] + dense = obs.cumsum()[inv] + non_zero = np.nonzero(obs)[0] + count = np.empty(non_zero.shape[0] + 1) + count[0:-1] = non_zero + count[-1] = len(obs) + result = 0.5 * (count[dense] + count[dense - 1] + 1) + return result + + +@numba.jit(nopython=True) +def norm_sf(z: float): + """Survival function (1 - `cdf`) at z of the given RV.""" + return (1.0 + math.erf(-z / math.sqrt(2.0))) / 2.0 + + +@numba.jit(nopython=True) +def ranksums_numba(x: npt.NDArray, y: npt.NDArray): + """ + Compute the Wilcoxon rank-sum statistic for two samples. + + The Wilcoxon rank-sum test tests the null hypothesis that two sets + of measurements are drawn from the same distribution. The alternative + hypothesis is that values in one sample are more likely to be + larger than the values in the other sample. + + This test should be used to compare two samples from continuous + distributions. It does not handle ties between measurements + in x and y. + + Algorithm based on `scipy.stats.ranksums`. + """ + n1 = len(x) + n2 = len(y) + alldata = np.concatenate((x, y)) + ranked = rankdata_average_numba(alldata) + x = ranked[:n1] + s = np.sum(x, axis=0) + expected = n1 * (n1 + n2 + 1) / 2.0 + z = (s - expected) / np.sqrt(n1 * n2 * (n1 + n2 + 1) / 12.0) + prob = 2 * norm_sf(abs(z)) + return z, prob + + +@numba.jit(nopython=True, parallel=True) +def ranksums_numba_multiple(X: npt.NDArray, Y: npt.NDArray): + """ + Compute multiple Wilcoxon rank-sum statistics for two samples. + + For each row of X and Y Wilcoxon rank-sum statistics for two samples are computed. + + """ + if X.ndim != 2 or Y.ndim != 2: + raise ValueError("X and Y should be 2D arrays") + n = X.shape[0] + if Y.shape[0] != n: + raise ValueError("X and Y should have the same shape on dimension 0") + ranksums_z = np.empty((n,), dtype=np.float64) + ranksums_p = np.empty((n,), dtype=np.float64) + for i in numba.prange(n): + z, p = ranksums_numba(X[i], Y[i]) + ranksums_z[i] = z + ranksums_p[i] = p + return ranksums_z, ranksums_p diff --git a/src/pycisTopic/fragments.py b/src/pycisTopic/fragments.py index e35bf00..68c0306 100644 --- a/src/pycisTopic/fragments.py +++ b/src/pycisTopic/fragments.py @@ -2,7 +2,7 @@ import gzip from operator import itemgetter -from typing import Literal, Sequence +from typing import TYPE_CHECKING, Literal, Sequence import numpy as np import pandas as pd @@ -10,9 +10,15 @@ import pyarrow as pa # type: ignore[import] import pyarrow.csv # type: ignore[import] import pyranges as pr # type: ignore[import] +import scipy as sp + from pycisTopic.genomic_ranges import intersection as gr_intersection +from pycisTopic.genomic_ranges import overlap as gr_overlap from pycisTopic.utils import normalise_filepath +if TYPE_CHECKING: + from pathlib import Path + # Enable Polars global string cache so all categoricals are created with the same # string cache. 
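+# Without a shared string cache, Polars operations that combine Categorical columns
+# built from different files (e.g. joining fragments with consensus regions) can
+# require an expensive remapping or raise a string cache mismatch error.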
pl.enable_string_cache() @@ -103,7 +109,7 @@ def read_fragments_to_pyranges( separator="\t", use_pyarrow=False, new_columns=bed_column_names[:column_count], - dtypes={ + schema_overrides={ bed_column: dtype for bed_column, dtype in { "Chromosome": pl.Categorical, @@ -263,7 +269,7 @@ def read_bed_to_polars_df( separator="\t", use_pyarrow=False, new_columns=bed_column_names[:column_count], - dtypes={ + schema_overrides={ bed_column: dtype for bed_column, dtype in { "Chromosome": pl.Categorical, @@ -313,6 +319,9 @@ def read_bed_to_polars_df( def read_fragments_to_polars_df( fragments_bed_filename: str, engine: str | Literal["polars"] | Literal["pyarrow"] = "pyarrow", + sample_id: str | None = None, + cb_end_to_remove: str | None = "-1", + cb_sample_separator: str | None = "___", ) -> pl.DataFrame: """ Read fragments BED file to a Polars DataFrame. @@ -325,7 +334,15 @@ def read_fragments_to_polars_df( fragments_bed_filename Fragments BED filename. engine - Use Polars or pyarrow to read the fragments BED file (default: pyarrow). + Use Polars or pyarrow to read the fragments BED file (default: `pyarrow`). + sample_id + Optional sample ID to append after cell barcode after removing `cb_end_to_remove` + and appending `cb_sample_separator`. + cb_end_to_remove + Remove this string from the end of the cell barcode if `sample_id` is specified. + cb_sample_separator + Add this string to the cell barcode if `sample_id` is specified, after removing + `cb_end_to_remove` and before appending `sample_id`. Returns ------- @@ -349,29 +366,67 @@ def read_fragments_to_polars_df( ... fragments_bed_filename="fragments.tsv", ... ) + Read gzipped fragments BED file to a Polars DataFrame and add sample ID to cell + barcode names after removing `cb_end_to_remove` string from cell barcode and + appending `cb_sample_separator` to the cell barcode. + + >>> fragments_df_pl = read_fragments_to_polars_df( + ... fragments_bed_filename="fragments.tsv.gz", + ... sample_id="sample1", + ... cb_end_to_remove="-1", + ... cb_sample_separator="___", + ... ) + """ fragments_df_pl = read_bed_to_polars_df( bed_filename=fragments_bed_filename, engine=engine, min_column_count=4, - ) + ).lazy() # If no score is provided or score column is ".", generate a score column with the # number of fragments which have the same chromosome, start, end and CB. - if ( - "Score" not in fragments_df_pl.columns - or fragments_df_pl.schema["Score"] == pl.Utf8 - ): + if fragments_df_pl.collect_schema().get("Score") in (None, pl.Utf8): fragments_df_pl = fragments_df_pl.group_by( ["Chromosome", "Start", "End", "Name"] ).agg(pl.len().cast(pl.Int32()).alias("Score")) else: fragments_df_pl = fragments_df_pl.with_columns(pl.col("Score").cast(pl.Int32())) + # Modify cell barcode if sample ID is specified or an empty string. + if sample_id or sample_id == "": + separator_and_sample_id = ( + f"{cb_sample_separator + sample_id}" if cb_sample_separator else sample_id + ) + + if not cb_end_to_remove: + # Append separator and sample ID to cell barcode. + fragments_df_pl = fragments_df_pl.with_columns( + (pl.col("Name").cast(pl.Utf8) + pl.lit(separator_and_sample_id)).cast( + pl.Categorical + ) + ) + else: + # Remove `cb_end_to_remove` from the end of the cell barcode before adding + # separator and sample ID to cell barcode. 
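+            # e.g. with cb_end_to_remove="-1", cb_sample_separator="___" and
+            # sample_id="sample1": "AAACGAAAGGCTTCGC-1" becomes
+            # "AAACGAAAGGCTTCGC___sample1".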
+ fragments_df_pl = fragments_df_pl.with_columns( + pl.col("Name") + .cast(pl.Utf8) + .str.replace(cb_end_to_remove + "$", separator_and_sample_id) + .cast(pl.Categorical) + ) + + fragments_df_pl = fragments_df_pl.collect() + return fragments_df_pl -def read_barcodes_file_to_polars_series(barcodes_tsv_filename: str) -> pl.Series: +def read_barcodes_file_to_polars_series( + barcodes_tsv_filename: str, + sample_id: str | None = None, + cb_end_to_remove: str | None = "-1", + cb_sample_separator: str | None = "___", +) -> pl.Series: """ Read barcode TSV file to a Polars Series. @@ -379,6 +434,14 @@ def read_barcodes_file_to_polars_series(barcodes_tsv_filename: str) -> pl.Series ---------- barcodes_tsv_filename TSV file with CBs. + sample_id + Optional sample ID to append after cell barcode after removing `cb_end_to_remove` + and appending `cb_sample_separator`. + cb_end_to_remove + Remove this string from the end of the cell barcode if `sample_id` is specified. + cb_sample_separator + Add this string to the cell barcode if `sample_id` is specified, after removing + `cb_end_to_remove` and before appending `sample_id`. Returns ------- @@ -402,17 +465,55 @@ def read_barcodes_file_to_polars_series(barcodes_tsv_filename: str) -> pl.Series ... barcodes_tsv_filename="barcodes.tsv", ... ) + Read gzipped barcodes TSV file to a Polars Series and add sample ID to cell + barcode names after removing `cb_end_to_remove` string from cell barcode and + appending `cb_sample_separator` to the cell barcode. + + >>> cbs = read_barcodes_file_to_polars_series( + ... barcodes_tsv_filename="barcodes.tsv", + ... sample_id="sample1", + ... cb_end_to_remove="-1", + ... cb_sample_separator="___", + ... ) + """ - cbs = pl.read_csv( - barcodes_tsv_filename, - has_header=False, - separator="\t", - columns=[0], - new_columns=["CB"], - dtypes={"CB": pl.Categorical}, - ).to_series() + cbs = ( + pl.read_csv( + barcodes_tsv_filename, + has_header=False, + separator="\t", + columns=[0], + new_columns=["CB"], + schema={"CB": pl.Categorical}, + ) + .filter(pl.col("CB").is_not_null()) + .unique(maintain_order=True) + ) + + # Modify cell barcode if sample ID is specified or an empty string. + if sample_id or sample_id == "": + separator_and_sample_id = ( + f"{cb_sample_separator + sample_id}" if cb_sample_separator else sample_id + ) - return cbs + if not cb_end_to_remove: + # Append separator and sample ID to cell barcode. + cbs = cbs.with_columns( + (pl.col("CB").cast(pl.Utf8) + pl.lit(separator_and_sample_id)).cast( + pl.Categorical + ) + ) + else: + # Remove `cb_end_to_remove` from the end of the cell barcode before adding + # separator and sample ID to cell barcode. + cbs = cbs.with_columns( + pl.col("CB") + .cast(pl.Utf8) + .str.replace(cb_end_to_remove + "$", separator_and_sample_id) + .cast(pl.Categorical) + ) + + return cbs.to_series() def create_pyranges_from_polars_df(bed_df_pl: pl.DataFrame) -> pr.PyRanges: @@ -496,10 +597,10 @@ def create_pyranges_from_polars_df(bed_df_pl: pl.DataFrame) -> pr.PyRanges: # those index values or not). bed_with_idx_df_pl = ( bed_df_pl - # Add index column and cast it from UInt32 to Int64 - .with_row_index("__index_level_0__").with_columns( - pl.col("__index_level_0__").cast(pl.Int64) - ) + # Add index column and cast it from UInt32 (`polars`) or + # UInt64 (`polars-u64-idx`) to Int64. + .with_row_index("__index_level_0__") + .with_columns(pl.col("__index_level_0__").cast(pl.Int64)) # Put index column as last column. 
.select(pl.col(pa_schema_fixed_categoricals.names)) ) @@ -550,7 +651,7 @@ def create_per_chrom_or_chrom_strand_df_pd( # Partition Polars DataFrame with BED entries per chromosome-strand # (stranded). bed_with_idx_df_pl.partition_by( - by=["Chromosome", "Strand"], maintain_order=False, as_dict=True + "Chromosome", "Strand", maintain_order=False, as_dict=True ).items(), key=itemgetter(0), ) @@ -564,7 +665,7 @@ def create_per_chrom_or_chrom_strand_df_pd( # Partition Polars DataFrame with BED entries per chromosome # (unstranded). bed_with_idx_df_pl.partition_by( - by=["Chromosome"], maintain_order=False, as_dict=True + "Chromosome", maintain_order=False, as_dict=True ).items(), key=itemgetter(0), ) @@ -629,24 +730,28 @@ def get_fragments_per_cb( fragments_stats_per_cb_df_pl = ( fragments_df_pl.lazy() .rename({"Name": "CB"}) - .with_columns( - (pl.col("End") - pl.col("Start")).alias("fragment_length") - ) + .with_columns((pl.col("End") - pl.col("Start")).alias("fragment_length")) .with_columns( pl.col("fragment_length").lt(147).alias("nucleosome_free"), pl.col("fragment_length").is_between(147, 294).alias("mononucleosome"), ) - .group_by(by="CB", maintain_order=True) + .group_by("CB", maintain_order=True) .agg( [ - pl.col("Score").sum().alias("total_fragments_count"), - pl.len().alias("unique_fragments_count"), - (pl.col("mononucleosome").sum() / pl.col("nucleosome_free").sum()).alias("nucleosome_signal") + pl.col("Score").sum().cast(pl.UInt32).alias("total_fragments_count"), + pl.len().cast(pl.UInt32).alias("unique_fragments_count"), + ( + pl.col("mononucleosome").sum() / pl.col("nucleosome_free").sum() + ).alias("nucleosome_signal"), ] ) .filter(pl.col(fragments_count_column) > min_fragments_per_cb) - .sort(by=fragments_count_column, descending=True) - .with_row_index(name="barcode_rank", offset=1) + .sort(fragments_count_column, descending=True) + .with_row_index( + name="barcode_rank", + offset=1, + ) + .with_columns(pl.col("barcode_rank").cast(pl.UInt32)) .with_columns( (pl.col("total_fragments_count") - pl.col("unique_fragments_count")).alias( "duplication_count" @@ -656,9 +761,10 @@ def get_fragments_per_cb( (pl.col("duplication_count") / pl.col("total_fragments_count")).alias( "duplication_ratio" ) - ).select( + ) + .select( pl.selectors.all() - pl.selectors.by_name("nucleosome_signal"), - pl.selectors.by_name("nucleosome_signal") + pl.selectors.by_name("nucleosome_signal"), ) .collect() ) @@ -768,7 +874,7 @@ def get_cbs_passing_filter( elif isinstance(keep_top_x_cbs, int): fragments_stats_per_cb_filtered_df_pl = ( fragments_stats_per_cb_df_pl.lazy() - .sort(by=fragments_count_column, descending=True) + .sort(fragments_count_column, descending=True) .head(keep_top_x_cbs) .collect() ) @@ -892,10 +998,6 @@ def get_insert_size_distribution( ---------- fragments_df_pl Polars DataFrame with fragments. - cbs - List/Polars Series with Cell barcodes. - See :func:`pycisTopic.fragments.get_cbs_passing_filter` for a way to get a - filtered list of cell barcodes (``selected_cbs`` variable). 
Returns ------- @@ -929,8 +1031,8 @@ def get_insert_size_distribution( (pl.col("End") - pl.col("Start")).abs().alias("insert_size"), ) .group_by("insert_size") - .agg([pl.len().alias("fragments_count")]) - .sort(by="insert_size", descending=True) + .agg([pl.len().cast(pl.UInt32).alias("fragments_count")]) + .sort("insert_size", descending=True) .with_columns( (pl.col("fragments_count") / pl.col("fragments_count").sum()).alias( "fragments_ratio" @@ -942,7 +1044,10 @@ def get_insert_size_distribution( return insert_size_distribution_df_pl -def get_fragments_in_peaks(fragments_df_pl: pl.DataFrame, regions_df_pl: pl.DataFrame) -> pl.DataFrame: +def get_fragments_in_peaks( + fragments_df_pl: pl.DataFrame, + regions_df_pl: pl.DataFrame, +) -> pl.DataFrame: """ Get number of total and unique fragments in peaks. @@ -1009,13 +1114,224 @@ def get_fragments_in_peaks(fragments_df_pl: pl.DataFrame, regions_df_pl: pl.Data pl.col("Name").alias("CB"), pl.col("Score"), ) - .group_by(by="CB", maintain_order=True) + .group_by("CB", maintain_order=True) .agg( [ - pl.col("Score").sum().alias("total_fragments_in_peaks_count"), - pl.len().alias("unique_fragments_in_peaks_count"), + pl.col("Score") + .sum() + .cast(pl.UInt32) + .alias("total_fragments_in_peaks_count"), + pl.len().cast(pl.UInt32).alias("unique_fragments_in_peaks_count"), ] ) ) return fragments_in_peaks_df_pl + + +def create_fragment_matrix_from_fragments( + fragments_bed_filename: str | Path, + regions_bed_filename: str | Path, + barcodes_tsv_filename: str | Path, + blacklist_bed_filename: str | Path | None = None, + sample_id: str | None = None, + cb_end_to_remove: str | None = "-1", + cb_sample_separator: str | None = "___", +): + """ + Create fragments matrix from a fragment file and BED file with regions. + + Parameters + ---------- + fragments_bed_filename + Fragments BED filename. + regions_bed_filename + Consensus peaks / SCREEN regions BED file for which to make the fragments matrix per cell barcode. + barcodes_tsv_filename + TSV file with selected cell barcodes after pycisTopic QC filtering. + blacklist_bed_filename + BED file with blacklisted regions (Amemiya et al., 2019). Default: None. + sample_id + Optional sample ID to append after cell barcode after removing `cb_end_to_remove` + and appending `cb_sample_separator`. + cb_end_to_remove + Remove this string from the end of the cell barcode if `sample_id` is specified. + cb_sample_separator + Add this string to the cell barcode if `sample_id` is specified, after removing + `cb_end_to_remove` and before appending `sample_id`. + + Returns + ------- + ( + counts_fragments_matrix, + cbs, + region_ids, + ) + + References + ---------- + Amemiya, H. M., Kundaje, A., & Boyle, A. P. (2019). The ENCODE blacklist: identification of problematic regions of the genome. Scientific reports, 9(1), 1-5. + + """ + # Create logger + # level = logging.INFO + # log_format = "%(asctime)s %(name)-12s %(levelname)-8s %(message)s" + # handlers = [logging.StreamHandler(stream=sys.stdout)] + # logging.basicConfig(level=level, format=log_format, handlers=handlers) + # log = logging.getLogger("cisTopic") + + # Read file with cell barcodes as a Polars Series and add sample ID to cell barcodes. + cbs = read_barcodes_file_to_polars_series( + barcodes_tsv_filename=barcodes_tsv_filename, + sample_id=sample_id, + cb_end_to_remove=cb_end_to_remove, + cb_sample_separator=cb_sample_separator, + ) + + # log.info("Reading data for " + project) + # Read fragments file to Polars Dataframe and add sample ID to cell barcodes. 
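+    # engine="pyarrow" reads the (optionally gzipped) fragments TSV file with
+    # Arrow's multithreaded CSV reader; read_fragments_to_polars_df also accepts
+    # engine="polars".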
+ fragments_df_pl = read_fragments_to_polars_df( + fragments_bed_filename=fragments_bed_filename, + engine="pyarrow", + sample_id=sample_id, + cb_end_to_remove=cb_end_to_remove, + cb_sample_separator=cb_sample_separator, + ) + + # Only keep fragments with the requested cell barcode. + fragments_cb_filtered_df_pl = filter_fragments_by_cb( + fragments_df_pl=fragments_df_pl, + cbs=cbs, + ).rename({"Name": "CB", "Score": "CB_count"}) + + del fragments_df_pl + + # Read regions BED file as a Polars Dataframe. + regions_df_pl = ( + read_bed_to_polars_df( + bed_filename=regions_bed_filename, + engine="polars", + min_column_count=3, + ) + .with_columns( + ( + pl.col("Chromosome") + + ":" + + pl.col("Start").cast(pl.Utf8) + + "-" + + pl.col("End").cast(pl.Utf8) + ) + .cast(pl.Categorical) + .alias("RegionID") + ) + .select( + pl.col("Chromosome"), + pl.col("Start"), + pl.col("End"), + pl.col("RegionID"), + ) + ) + + if blacklist_bed_filename: + # Read BED file with blacklisted regions . + blacklist_df_pl = read_bed_to_polars_df( + bed_filename=blacklist_bed_filename, + engine="polars", + min_column_count=3, + ).select( + pl.col("Chromosome"), + pl.col("Start"), + pl.col("End"), + ) + + # Filter out regions that overlap with blacklisted regions. + regions_df_pl = ( + regions_df_pl.lazy() + .join( + # Get all regionIDs that overlap with blacklisted regions. + gr_overlap( + regions1_df_pl=regions_df_pl, + regions2_df_pl=blacklist_df_pl, + how="first", + ) + .lazy() + .select( + pl.col("RegionID"), + ), + on="RegionID", + how="anti", + ) + .select( + pl.col("Chromosome"), + pl.col("Start"), + pl.col("End"), + pl.col("RegionID"), + ) + .collect() + ) + + # Get accessibility (binary and counts) for each region ID and cell barcode. + region_cb_df_pl = ( + gr_intersection( + regions1_df_pl=regions_df_pl, + regions2_df_pl=fragments_cb_filtered_df_pl, + # how: Literal["all", "containment", "first", "last"] | str | None = None, + how="all", + regions1_info=True, + regions2_info=True, + regions1_coord=False, + regions2_coord=False, + regions1_suffix="@1", + regions2_suffix="@2", + ) + .rename({"CB@2": "CB"}) + .lazy() + .group_by(["RegionID", "CB"]) + .agg( + # Get accessibility in binary form. + pl.lit(1).cast(pl.Int8).alias("accessible_binary"), + # Get accessibility in count form. + pl.len().cast(pl.UInt32).alias("accessible_count"), + ) + .join( + regions_df_pl.lazy() + .select(pl.col("RegionID")) + .with_row_index("region_idx"), + on="RegionID", + how="left", + ) + .join( + cbs.to_frame().lazy().with_row_index("CB_idx"), + on="CB", + how="left", + ) + .collect() + ) + + # Construct binary accessibility matrix as a sparse matrix + # (regions as rows and cells as columns). + counts_fragments_matrix = sp.sparse.csr_matrix( + ( + # All data points are 1: + # - same as: region_cb_df_pl.get_column("accessible_binary").to_numpy() + # - for count matrix: region_cb_df_pl.get_column("accessible_count").to_numpy() + # np.ones(region_cb_df_pl.shape[0], dtype=np.int8), + region_cb_df_pl.get_column("accessible_count").to_numpy(), + ( + # Row indices: + region_cb_df_pl.get_column("region_idx").to_numpy(), + # Column indices: + region_cb_df_pl.get_column("CB_idx").to_numpy(), + ), + ), + # Specify shape of the sparse matrix to avoid potential issues if the last + # (few) rows or (few) columns are empty as this will cause the CB list + # and regions list to be greater than the dimensions of the sparse matrix. 
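+        # shape is (n_regions x n_cells): row indices follow the order of
+        # regions_df_pl (region_idx) and column indices follow the order of the
+        # selected cell barcodes (CB_idx).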
+ shape=(regions_df_pl.height, cbs.len()), + ) + + return ( + counts_fragments_matrix, + cbs.to_list(), + regions_df_pl.get_column("RegionID").to_list(), + ) diff --git a/src/pycisTopic/gene_activity.py b/src/pycisTopic/gene_activity.py index c093a5e..5d920ba 100644 --- a/src/pycisTopic/gene_activity.py +++ b/src/pycisTopic/gene_activity.py @@ -149,9 +149,7 @@ def get_gene_activity( gene_act = gene_act.round() gene_act = sparse.csr_matrix(gene_act) keep_features_index = non_zero_rows(gene_act) - gene_act = gene_act[ - keep_features_index, - ] + gene_act = gene_act[keep_features_index,] genes = subset_list(genes, keep_features_index) gene_act = CistopicImputedFeatures( gene_act, genes, imputed_acc_object.cell_names, project @@ -538,9 +536,7 @@ def region_weights( gini_weight, columns=["Gini"], index=subset_imputed_acc_object.feature_names ) gini_weight["Gini_weight"] = np.exp((1 - gini_weight["Gini"])) + np.exp(-1) - gini_weight = gini_weight.loc[ - regions_per_gene.Name, - ] + gini_weight = gini_weight.loc[regions_per_gene.Name,] regions_per_gene.Gini_weight = gini_weight.loc[:, "Gini_weight"] else: regions_per_gene.Gini_weight = 1 diff --git a/src/pycisTopic/gene_annotation.py b/src/pycisTopic/gene_annotation.py index d473a51..5388cb5 100644 --- a/src/pycisTopic/gene_annotation.py +++ b/src/pycisTopic/gene_annotation.py @@ -165,6 +165,7 @@ def get_tss_annotation_from_ensembl( "strand", "external_gene_name", "transcript_biotype", + "ensembl_gene_id", ], filters={"transcript_biotype": transcript_type} if transcript_type else None, ) @@ -192,6 +193,7 @@ def get_tss_annotation_from_ensembl( .alias("Strand") ), pl.col("Transcript type").alias("Transcript_type"), + pl.col("Gene stable ID").alias("Ensembl_gene_id"), ] ) @@ -261,7 +263,7 @@ def read_tss_annotation_from_bed(tss_annotation_bed_filename: str) -> pl.DataFra separator="\t", # Use 0-bytes as comment character so the header can start with "# Chromosome". comment_prefix="\0", - dtypes={ + schema_overrides={ # Convert Chromosome, Start and End column to the correct datatypes. "Chromosome": pl.Categorical, "# Chromosome": pl.Categorical, @@ -489,7 +491,7 @@ def get_chrom_sizes_and_alias_mapping_from_ncbi( chrom_sizes_and_alias_df_pl.rename({"ucsc": "# ucsc"}).write_csv( file=chrom_sizes_and_alias_tsv_filename, separator="\t", - has_header=True, + include_header=True, ) return chrom_sizes_and_alias_df_pl @@ -621,7 +623,7 @@ def get_chrom_sizes_and_alias_mapping_from_ucsc( has_header=True, comment_prefix="#", # Read all columns as strings. 
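+            # (`infer_schema=False` replaces `infer_schema_length=0` from older
+            # Polars releases; both skip schema inference so every column is read
+            # as a string.)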
-            infer_schema_length=0,
+            infer_schema=False,
         )
     else:
         raise ValueError(
@@ -646,7 +648,7 @@ def get_chrom_sizes_and_alias_mapping_from_ucsc(
             has_header=False,
             comment_prefix="#",
             new_columns=["ucsc", "length"],
-            dtypes=[pl.Utf8, pl.Int64],
+            schema={"ucsc": pl.Utf8, "length": pl.Int64},
         )
 
     else:
@@ -665,7 +667,7 @@ def get_chrom_sizes_and_alias_mapping_from_ucsc(
     chrom_sizes_and_alias_df_pl.rename({"ucsc": "# ucsc"}).write_csv(
         file=chrom_sizes_and_alias_tsv_filename,
         separator="\t",
-        has_header=True,
+        include_header=True,
     )
 
     return chrom_sizes_and_alias_df_pl
@@ -726,7 +728,7 @@ def find_most_likely_chromosome_source_in_bed(
     chrom_source_stats_df_pl = chrom_sizes_and_alias_df_pl.select(
         [
             pl.col(column_name).is_in(chroms_from_bed).sum()
-            for column_name in chrom_sizes_and_alias_df_pl.columns
+            for column_name in chrom_sizes_and_alias_df_pl.collect_schema().names()
         ]
     )
 
diff --git a/src/pycisTopic/genomic_ranges.py b/src/pycisTopic/genomic_ranges.py
index 411e219..ddf840e 100644
--- a/src/pycisTopic/genomic_ranges.py
+++ b/src/pycisTopic/genomic_ranges.py
@@ -7,6 +7,7 @@
 
 if TYPE_CHECKING:
     import numpy as np
+    import numpy.typing as npt
 
 # Intersection/overlap code is based on:
 # https://github.com/biocore-ntnu/pyranges/blob/master/pyranges/methods/intersection.py
@@ -15,7 +16,7 @@
 def _get_start_end_and_indexes_for_chrom(
     regions_per_chrom_dfs_pl: dict[str, pl.DataFrame],
     chrom: str,
-) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
     """
     Get start, end and index positions from per chromosome Polars dataframe.
 
@@ -55,7 +56,7 @@ def _intersect_per_chrom(
     regions2_per_chrom_dfs_pl: dict[str, pl.DataFrame],
     chrom: str,
     how: Literal["all", "containment", "first", "last"] | str | None = None,
-) -> tuple[np.ndarray, np.ndarray]:
+) -> tuple[pl.Series, pl.Series]:
     """
     Get intersection between two region sets per chromosome.
 
@@ -144,10 +145,10 @@ def _intersect_per_chrom(
     regions2_indexes = pl.Series("idx", regions2_indexes, dtype=pl.get_index_type())
 
     regions1_all_indexes = pl.arange(
-        0, indexes1_length, dtype=pl.get_index_type()(), eager=True
+        0, indexes1_length, dtype=pl.get_index_type(), eager=True
     ).alias("idx")
     regions2_all_indexes = pl.arange(
-        0, indexes2_length, dtype=pl.get_index_type()(), eager=True
+        0, indexes2_length, dtype=pl.get_index_type(), eager=True
     ).alias("idx")
 
     regions1_missing_indexes = (
@@ -170,12 +171,20 @@ def _intersect_per_chrom(
         .to_series()
     )
 
-    regions1_none_indexes = pl.repeat(
-        None, regions2_missing_indexes.len(), name="idx", eager=True
-    ).cast(pl.get_index_type())
-    regions2_none_indexes = pl.repeat(
-        None, regions1_missing_indexes.len(), name="idx", eager=True
-    ).cast(pl.get_index_type())
+    regions1_none_indexes = pl.select(
+        pl.repeat(
+            None,
+            n=regions2_missing_indexes.len(),
+            dtype=pl.get_index_type(),
+        ).alias("idx")
+    ).to_series()
+    regions2_none_indexes = pl.select(
+        pl.repeat(
+            None,
+            n=regions1_missing_indexes.len(),
+            dtype=pl.get_index_type(),
+        ).alias("idx")
+    ).to_series()
 
     if how == "outer":
         regions1_indexes = pl.concat(
@@ -214,7 +223,7 @@ def _overlap_per_chrom(
     regions2_per_chrom_dfs_pl: dict[str, pl.DataFrame],
     chrom: str,
     how: Literal["all", "containment", "first"] | str | None = "first",
-) -> np.ndarray:
+) -> npt.NDArray[np.int64]:
     """
     Get overlap between two region sets per chromosome.
 
@@ -280,8 +289,7 @@ def _filter_intersection_output_columns(
     regions2_coord: bool,
     regions1_suffix: str,
     regions2_suffix: str,
-) -> pl.DataFrame:
-    ...
+) -> pl.DataFrame: ... @overload @@ -293,8 +301,7 @@ def _filter_intersection_output_columns( regions2_coord: bool, regions1_suffix: str, regions2_suffix: str, -) -> pl.LazyFrame: - ... +) -> pl.LazyFrame: ... def _filter_intersection_output_columns( @@ -351,7 +358,7 @@ def _filter_intersection_output_columns( regions1_info_columns = [ # Remove region1 suffix from column names. pl.col(column_name).alias(column_name[:-regions1_suffix_length]) - for column_name in df.columns + for column_name in df.collect_schema().names() if ( column_name.endswith(regions1_suffix) and column_name not in regions1_coord_columns @@ -366,7 +373,7 @@ def _filter_intersection_output_columns( pl.col(column_name) if regions1_info else pl.col(column_name).alias(column_name[:-regions2_suffix_length]) - for column_name in df.columns + for column_name in df.collect_schema().names() if ( column_name.endswith(regions2_suffix) and column_name not in regions2_coord_columns @@ -579,13 +586,23 @@ def intersection( """ # TODO: chrom, stranded partitioning - regions1_per_chrom_dfs_pl = regions1_df_pl.partition_by( - ["Chromosome"], as_dict=True, maintain_order=True - ) - - regions2_per_chrom_dfs_pl = regions2_df_pl.partition_by( - ["Chromosome"], as_dict=True, maintain_order=True - ) + regions1_per_chrom_dfs_pl = { + str(chrom): regions1_chrom_df_pl + for (chrom,), regions1_chrom_df_pl in regions1_df_pl.partition_by( + ["Chromosome"], + as_dict=True, + maintain_order=True, + ).items() + } + + regions2_per_chrom_dfs_pl = { + str(chrom): regions2_chrom_df_pl + for (chrom,), regions2_chrom_df_pl in regions2_df_pl.partition_by( + ["Chromosome"], + as_dict=True, + maintain_order=True, + ).items() + } intersection_chrom_dfs_pl = {} @@ -725,7 +742,8 @@ def intersection( # Combine per chromosome dataframes to a full one. intersection_df_pl = pl.concat( - list(intersection_chrom_dfs_pl.values()), rechunk=False + list(intersection_chrom_dfs_pl.values()), + rechunk=False, ) return intersection_df_pl @@ -866,13 +884,23 @@ def overlap( """ # TODO: chrom, stranded partitioning - regions1_per_chrom_dfs_pl = regions1_df_pl.partition_by( - ["Chromosome"], as_dict=True, maintain_order=True - ) - - regions2_per_chrom_dfs_pl = regions2_df_pl.partition_by( - ["Chromosome"], as_dict=True, maintain_order=True - ) + regions1_per_chrom_dfs_pl = { + str(chrom): regions1_chrom_df_pl + for (chrom,), regions1_chrom_df_pl in regions1_df_pl.partition_by( + ["Chromosome"], + as_dict=True, + maintain_order=True, + ).items() + } + + regions2_per_chrom_dfs_pl = { + str(chrom): regions2_chrom_df_pl + for (chrom,), regions2_chrom_df_pl in regions2_df_pl.partition_by( + ["Chromosome"], + as_dict=True, + maintain_order=True, + ).items() + } overlap_chrom_dfs_pl = {} @@ -913,6 +941,9 @@ def overlap( return regions1_df_pl.head(0) # Combine per chromosome dataframes to a full one. 
- overlap_df_pl = pl.concat(list(overlap_chrom_dfs_pl.values()), rechunk=False) + overlap_df_pl = pl.concat( + list(overlap_chrom_dfs_pl.values()), + rechunk=False, + ) return overlap_df_pl diff --git a/src/pycisTopic/label_transfer.py b/src/pycisTopic/label_transfer.py index f8695ee..f0a1a97 100644 --- a/src/pycisTopic/label_transfer.py +++ b/src/pycisTopic/label_transfer.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from anndata import AnnData + def label_transfer( ref_anndata: AnnData, query_anndata: AnnData, @@ -29,7 +30,7 @@ def label_transfer( bbknn_components: int = 30, cca_components: int = 30, return_label_weights: bool = False, - **kwargs + **kwargs, ): """ diff --git a/src/pycisTopic/lda_models.py b/src/pycisTopic/lda_models.py index 1ea574f..8c9dd41 100644 --- a/src/pycisTopic/lda_models.py +++ b/src/pycisTopic/lda_models.py @@ -1,12 +1,11 @@ from __future__ import annotations +import json import logging import os import pickle -import random import subprocess import sys -import tempfile import time import warnings from itertools import chain @@ -19,15 +18,14 @@ import pandas as pd import polars as pl import ray +import scipy import tmtoolkit -from gensim import matutils, utils -from gensim.models import basemodel from pycisTopic.utils import loglikelihood, subset_list -from scipy import sparse if TYPE_CHECKING: from pycisTopic.cistopic_class import CistopicObject + class CistopicLDAModel: """ cisTopic LDA model class @@ -148,14 +146,16 @@ def run_cgs_models( Griffiths, T. L., & Steyvers, M. (2004). Finding scientific topics. Proceedings of the National academy of Sciences, 101(suppl 1), 5228-5235. """ - binary_matrix = sparse.csr_matrix(cistopic_obj.binary_matrix.transpose()) + binary_accessibility_matrix = scipy.sparse.csr_matrix( + cistopic_obj.binary_matrix.transpose() + ) region_names = cistopic_obj.region_names cell_names = cistopic_obj.cell_names ray.init(num_cpus=n_cpu, **kwargs) model_list = ray.get( [ run_cgs_model.remote( - binary_matrix, + binary_accessibility_matrix, n_topics=n_topic, cell_names=cell_names, region_names=region_names, @@ -177,7 +177,7 @@ def run_cgs_models( @ray.remote def run_cgs_model( - binary_matrix: sparse.csr_matrix, + binary_accessibility_matrix: sparse.csr_matrix, n_topics: int, cell_names: list[str], region_names: list[str], @@ -195,13 +195,13 @@ def run_cgs_model( Parameters ---------- - binary_matrix: sparse.csr_matrix + binary_accessibility_matrix: sparse.csr_matrix Binary sparse matrix containing cells as columns, regions as rows, and 1 if a regions is considered accessible on a cell (otherwise, 0). n_topics: int Number of topics to use in the model. - cell_names: list of str + cell_barcodes: list of str List containing cell names as ordered in the binary matrix columns. - region_names: list of str + region_ids: list of str List containing region names as ordered in the binary matrix rows. n_iter: int, optional Number of iterations for which the Gibbs sampler will be run. Default: 150. 
@@ -243,66 +243,40 @@ def run_cgs_model( lda_log.propagate = False warnings.filterwarnings("ignore") - # Set models - if alpha_by_topic and eta_by_topic: - model = lda.LDA( - n_topics=n_topics, - n_iter=n_iter, - random_state=random_state, - alpha=alpha / n_topics, - eta=eta / n_topics, - refresh=n_iter, - ) - elif alpha_by_topic and eta_by_topic is False: - model = lda.LDA( - n_topics=n_topics, - n_iter=n_iter, - random_state=random_state, - alpha=alpha / n_topics, - eta=eta, - refresh=n_iter, - ) - elif alpha_by_topic is False and eta_by_topic is True: - model = lda.LDA( - n_topics=n_topics, - n_iter=n_iter, - random_state=random_state, - alpha=alpha, - eta=eta / n_topics, - refresh=n_iter, - ) - else: - model = lda.LDA( - n_topics=n_topics, - n_iter=n_iter, - random_state=random_state, - alpha=alpha, - eta=eta, - refresh=n_iter, - ) + lda_alpha = alpha / n_topics if alpha_by_topic else alpha + lda_eta = eta / n_topics if eta_by_topic else eta + + model = lda.LDA( + n_topics=n_topics, + n_iter=n_iter, + random_state=random_state, + alpha=lda_alpha, + eta=lda_eta, + refresh=n_iter, + ) # Running model log.info(f"Running model with {n_topics} topics") start_time = time.time() - model.fit(binary_matrix) + model.fit(binary_accessibility_matrix) end_time = time.time() - start_time # Model evaluation arun_2010 = tmtoolkit.topicmod.evaluate.metric_arun_2010( model.topic_word_, model.doc_topic_, - np.asarray(binary_matrix.sum(axis=1)).astype(float), + np.asarray(binary_accessibility_matrix.sum(axis=1)).astype(float), ) cao_juan_2009 = tmtoolkit.topicmod.evaluate.metric_cao_juan_2009(model.topic_word_) mimno_2011 = tmtoolkit.topicmod.evaluate.metric_coherence_mimno_2011( model.topic_word_, - dtm=binary_matrix, + dtm=binary_accessibility_matrix, top_n=20, eps=1e-12, normalize=True, return_mean=False, ) - ll = loglikelihood(model.nzw_, model.ndz_, alpha, eta) + ll = loglikelihood(model.nzw_, model.ndz_, lda_alpha, lda_eta) # Organize data if len(mimno_2011) <= top_topics_coh: @@ -335,7 +309,7 @@ def run_cgs_model( list( chain.from_iterable( tmtoolkit.topicmod.model_stats.marginal_topic_distrib( - model.doc_topic_, binary_matrix.sum(axis=1) + model.doc_topic_, binary_accessibility_matrix.sum(axis=1) ).tolist() ) ), @@ -396,504 +370,576 @@ def run_cgs_model( return model -class LDAMallet(utils.SaveLoad, basemodel.BaseTopicModel): - """ - Wrapper class to run LDA models with Mallet. This class has been adapted from gensim (https://github.com/RaRe-Technologies/gensim/blob/27bbb7015dc6bbe02e00bb1853e7952ac13e7fe0/gensim/models/wrappers/ldamallet.py). +class LDAMallet: + """Class for running LDA models with Mallet.""" - Parameters - ---------- - num_topics: int - The number of topics to use in the model. - corpus: iterable of iterable of (int, int), optional - Collection of texts in BoW format. Default: None. - alpha: float, optional - Scalar value indicating the symmetric Dirichlet hyperparameter for topic proportions. Default: 50. - id2word : :class:`gensim.utils.FakeDict`, optional - Mapping between tokens ids and words from corpus, if not specified - will be inferred from `corpus`. Default: None. - n_cpu : int, optional - Number of threads that will be used for training. Default: 1. - tmp_dir : str, optional - tmp_dir for produced temporary files. Default: None. - optimize_interval : int, optional - Optimize hyperparameters every `optimize_interval` iterations (sometimes leads to Java exception 0 to switch off hyperparameter optimization). Default: 0. 
-    iterations : int, optional
-        Number of training iterations. Default: 150.
-    topic_threshold : float, optional
-        Threshold of the probability above which we consider a topic. Default: 0.0.
-    random_seed: int, optional
-        Random seed to ensure consistent results, if 0 - use system clock. Default: 555.
-    mallet_path: str
-        Path to the mallet binary (e.g. /xxx/Mallet/bin/mallet). Default: "mallet".
-
-    """
-
-    def __init__(
-        self,
-        num_topics: int,
-        corpus: list | None = None,
-        alpha: float = 50,
-        eta: float = 0.1,
-        id2word: utils.FakeDict = None,
-        n_cpu: int = 1,
-        tmp_dir: str = None,
-        optimize_interval: int = 0,
-        iterations: int = 150,
-        topic_threshold: float = 0.0,
-        random_seed: int = 555,
-        reuse_corpus: bool = False,
+    @staticmethod
+    def convert_binary_matrix_to_mallet_corpus_file(
+        binary_accessibility_matrix: scipy.sparse.csr_matrix,
+        mallet_corpus_filename: str,
         mallet_path: str = "mallet",
-    ):
-        logger = logging.getLogger("LDAMalletWrapper")
-        if id2word is None:
-            logger.warning(
-                "No id2word mapping provided; initializing from corpus, assuming identity"
-            )
-            self.num_terms = utils.get_max_id(corpus) + 1
-        else:
-            self.num_terms = id2word.num_terms
-
-        if self.num_terms == 0:
-            raise ValueError("Cannot compute LDA over an empty collection (no terms)")
-
-        self.num_topics = num_topics
-        self.topic_threshold = topic_threshold
-        self.alpha = alpha
-        self.eta = eta
-        self.tmp_dir = tmp_dir if tmp_dir else tempfile.gettempdir()
-        self.random_label = hex(random.randint(0, 0xFFFFFF))[2:]
-        self.n_cpu = n_cpu
-        self.optimize_interval = optimize_interval
-        self.iterations = iterations
-        self.random_seed = random_seed
-        self.mallet_path = mallet_path
-        if corpus is not None:
-            self.train(corpus, reuse_corpus)
-
-    def corpus_to_mallet(self, corpus, file_like):
+    ) -> None:
         """
-        Convert `corpus` to Mallet format and write it to `file_like` descriptor.
+        Convert binary matrix to Mallet serialized corpus file.
 
         Parameters
         ----------
-        corpus
-            iterable of iterable of (int, int)
-            Collection of texts in BoW format.
-        file_like
-            Writable file-like object in text mode.
+        binary_accessibility_matrix
+            Binary accessibility matrix (regions as rows, cell barcodes as columns).
+        mallet_corpus_filename
+            Mallet serialized corpus filename.
+        mallet_path
+            Path to Mallet binary.
 
         Returns
         -------
         None.
 
         """
-        # Iterate over each cell ("document").
-        for doc_idx, doc in enumerate(corpus):
-            # Get all accessible regions for the current cell.
-            tokens = chain.from_iterable([str(token_id)] for token_id, _cnt in doc)
+        logger = logging.getLogger("LDAMallet")
 
-            file_like.write(f'{doc_idx}\t0\t{" ".join(tokens)}\n')
-
-    def convert_input(self, corpus):
-        """
-        Convert corpus to Mallet format and save it to a temporary text file.
+        # Convert binary accessibility matrix to compressed sparse column matrix format
+        # and eliminate zeros as we assume later that for each found index, the
+        # associated value is 1.
+        binary_accessibility_matrix_csc = binary_accessibility_matrix.tocsc()
+        binary_accessibility_matrix_csc.eliminate_zeros()
 
-        Parameters
-        ----------
-        corpus
-            iterable of iterable of (int, int)
-            Collection of texts in BoW format.
+        mallet_corpus_txt_filename = f"{mallet_corpus_filename}.txt"
 
-        Returns
-        -------
-        None.
+        logger.info(
+            f'Serializing binary accessibility matrix to Mallet text corpus to "{mallet_corpus_txt_filename}".'
+        )
 
-        """
-        logger = logging.getLogger("LDAMalletWrapper")
+        if binary_accessibility_matrix_csc.shape[0] == 0:
+            raise ValueError(
+                "Binary accessibility matrix does not contain any regions."
+            )
 
-        logger.info(f"Serializing temporary corpus to {self.fcorpustxt()}")
+        if binary_accessibility_matrix_csc.shape[1] == 0:
+            raise ValueError(
+                "Binary accessibility matrix does not contain any cell barcodes."
+            )
 
-        with utils.open(self.fcorpustxt(), "wt") as fh:
-            self.corpus_to_mallet(corpus, fh)
+        with open(mallet_corpus_txt_filename, "w") as mallet_corpus_txt_fh:
+            # Iterate over each column (cell barcode index) of the sparse binary
+            # accessibility matrix in compressed sparse column matrix format and get
+            # all index positions (region IDs indices) for that cell barcode index.
+            for cell_barcode_idx, (indptr_start, indptr_end) in enumerate(
+                zip(
+                    binary_accessibility_matrix_csc.indptr,
+                    binary_accessibility_matrix_csc.indptr[1:],
+                )
+            ):
+                # Get all region ID indices (assume all have an associated value of 1)
+                # for the current cell barcode index.
+                region_ids_idx = binary_accessibility_matrix_csc.indices[
+                    indptr_start:indptr_end
+                ]
+
+                # Write Mallet text corpus for the current cell barcode index:
+                # - column 1: cell barcode index.
+                # - column 2: document number (always 0).
+                # - column 3: region IDs indices accessible in the current cell barcode.
+                mallet_corpus_txt_fh.write(
+                    f'{cell_barcode_idx}\t0\t{" ".join([str(x) for x in region_ids_idx])}\n'
+                )
 
-        cmd = [
-            self.mallet_path,
+        mallet_import_file_cmd = [
+            mallet_path,
             "import-file",
             "--preserve-case",
             "--keep-sequence",
             "--token-regex",
             "\\S+",
             "--input",
-            self.fcorpustxt(),
+            mallet_corpus_txt_filename,
             "--output",
-            self.fcorpusmallet(),
+            mallet_corpus_filename,
         ]
         logger.info(
-            f"Converting temporary corpus to MALLET format with: {' '.join(cmd)}"
+            f"Converting Mallet text corpus to Mallet serialized corpus with: {' '.join(mallet_import_file_cmd)}"
         )
+
         try:
-            subprocess.check_output(args=cmd, shell=False, stderr=subprocess.STDOUT)
+            subprocess.check_output(
+                args=mallet_import_file_cmd, shell=False, stderr=subprocess.STDOUT
+            )
         except subprocess.CalledProcessError as e:
             raise RuntimeError(
                 f"command '{e.cmd}' return with error (code {e.returncode}): {e.output}"
             )
 
-    def train(self, corpus, reuse_corpus):
+        # Remove Mallet text corpus as only Mallet serialized corpus file is needed.
+        if os.path.exists(mallet_corpus_txt_filename):
+            os.remove(mallet_corpus_txt_filename)
+
+    @staticmethod
+    def convert_cell_topic_probabilities_txt_to_parquet(
+        mallet_cell_topic_probabilities_txt_filename: str,
+        mallet_cell_topic_probabilities_parquet_filename: str,
+    ) -> None:
         """
-        Train Mallet LDA.
+        Convert cell-topic probabilities from Mallet output to Parquet file.
 
         Parameters
         ----------
-        corpus : iterable of iterable of (int, int)
-            Corpus in BoW format
-        reuse_corpus: bool, optional
-            Whether to reuse the mallet corpus in the tmp directory. 
Default: False - - """ - logger = logging.getLogger("LDAMalletWrapper") - if os.path.isfile(self.fcorpusmallet()) is False or reuse_corpus is False: - self.convert_input(corpus) - else: - logger.info("MALLET corpus already exists, training model") - - cmd = [ - self.mallet_path, - "train-topics", - "--input", - self.fcorpusmallet(), - "--num-topics", - str(self.num_topics), - "--alpha", - str(self.alpha), - "--beta", - str(self.eta), - "--optimize-interval", - str(self.optimize_interval), - "--num-threads", - str(self.n_cpu), - "--output-state", - self.fstate(), - "--output-doc-topics", - self.fdoctopics(), - "--output-topic-keys", - self.ftopickeys(), - "--num-iterations", - str(self.iterations), - "--inferencer-filename", - self.finferencer(), - "--doc-topics-threshold", - str(self.topic_threshold), - "--random-seed", - str(self.random_seed), - ] - - start = time.time() - logger.info(f"Training MALLET LDA with: {' '.join(cmd)}") - try: - subprocess.check_output(args=cmd, shell=False, stderr=subprocess.STDOUT) - except subprocess.CalledProcessError as e: - raise RuntimeError( - f"command '{e.cmd}' return with error (code {e.returncode}): {e.output}" - ) - self.word_topics = self.load_word_topics() - self.time = time.time() - start - - def load_word_topics(self): - """ - Load words X topics matrix from :meth:`gensim.models.wrappers.LDAMallet.LDAMallet.fstate` file. + mallet_cell_topic_probabilities_txt_filename + Mallet cell-topic probabilities text file. + mallet_cell_topic_probabilities_parquet_filename + Parquet output file with cell-topic probabilities. Returns ------- - np.ndarray - Matrix words X topics. + None """ - logger = logging.getLogger("LDAMalletWrapper") - logger.info("loading assigned topics from %s", self.fstate()) - word_topics = np.zeros((self.num_topics, self.num_terms), dtype=np.float64) - - with utils.open(self.fstate(), "rb") as fin: - _ = next(fin) # header - self.alpha = np.fromiter(next(fin).split()[2:], dtype=float) - assert ( - len(self.alpha) == self.num_topics - ), "Mismatch between MALLET vs. requested topics" - - # Get occurrence of each found topic-region combination: - # - Get region (type) and topic column from Mallet state file. - # - Count occurrence of each topic-region combination. - topic_region_occurrence_df_pl = ( - pl.read_csv( - self.fstate(), - separator=" ", - has_header=False, - skip_rows=3, - columns=[4, 5], - new_columns=["region", "topic"], + # Read cell-topic probabilities Mallet output file and extract for each cell + # barcode the probability for the cell barcode to belong to a certain topic. + # + # Column 0: order in which cell barcode idx was seen in the input corpus file. 
+ # Column 1: cell barcode idx + # Column 3-n: "topic probability" for each topic + # + # Mallet cell-topic probabilities file example: + # --------------------------------------------- + # + # 0 0 0.06355276993175679 0.1908026307651073 0.06691338680081645 0.007391295383790694 0.07775807681999052 0.08091252087499742 0.08262375523163516 9.793208667505102E-4 0.007721171886275076 0.01605055357400573 0.014071294559099437 0.025307712924973712 0.020524503638950167 0.061903387419334883 0.07344906500628827 0.02866832979403336 0.03520400799950518 0.07608807702616333 0.047656845968290625 0.022421293528235367 + # 1 1 0.10109016579604815 0.0016579604814898933 0.033499886441062915 0.003792868498750852 0.06665909607086078 0.19216443334090394 0.023143311378605497 0.0011128775834658188 0.08719055189643425 0.00401998637292755 0.0030206677265500795 0.03617987735634794 0.02473313649784238 0.255984555984556 0.004383374971610266 0.037179196002725415 0.023143311378605497 0.06202589143765614 0.009379968203497615 0.02963888258005905 + # 2 2 0.08937104175357427 0.03120615116234973 0.11623971329970799 0.03952083886381736 0.034562364898175886 0.08415658538435283 0.03002104744207213 0.040440479350752775 0.02172532140012894 0.025119458455003983 0.01332530623080132 0.06196196291099397 0.07174617922560582 0.03189825173499185 0.05144772270469111 0.00540881337934696 0.08696291099397019 0.07489381470666313 0.04997819409154689 0.04001384201145284 + # 3 3 0.05694870514375401 0.003620603552828708 0.07264393236783906 0.11541342655347078 0.005546835984875508 0.025451237782692444 0.010790468716558465 0.377309695369908 0.03540343868160091 0.007580081329813798 0.023453663408717986 0.02869729614040094 0.08166868802168795 0.01703288863522865 0.006153242491260612 0.0172112434900478 0.06311978312049654 0.02124206320896055 0.012895056003424414 0.017817649996432903 + # 4 4 0.08079825190344497 0.002168049438355697 0.06058588548601864 0.002919184676841135 0.07448188739799926 0.12989518249172044 0.15225852709208235 0.008962409095564889 0.02753593499265936 0.001519341732391 0.011386527365222438 0.012376660179589606 0.015108061046809382 0.1424596264809314 0.015449486155211854 0.027740790057700842 0.068370377957595 0.1540339376557752 0.002168049438355697 0.00978182935573082 + cell_topic_probabilities_ldf = pl.scan_csv( + mallet_cell_topic_probabilities_txt_filename, + separator="\t", + has_header=False, + with_column_names=lambda cols: [ + f"topic_{idx - 1}" if idx > 1 else f"cell_idx{idx}" + for idx, col in enumerate(cols) + ], + ) + # Get cell-topic probabilities as numpy matrix. + cell_topic_probabilities = ( + cell_topic_probabilities_ldf.select( + pl.col("^topic_[0-9]+$").cast(pl.Float32) ) - .lazy() - .group_by(["topic", "region"]) - .agg(pl.len().alias("occurrence")) .collect() + .to_numpy() ) - # Fill in word topics matrix values. - word_topics[ - topic_region_occurrence_df_pl.get_column("topic"), - topic_region_occurrence_df_pl.get_column("region"), - ] = topic_region_occurrence_df_pl.get_column("occurrence") - - return word_topics + # Write cell-topic probabilities matrix to one column of a Parquet file. + pl.Series( + "cell_topic_probabilities", cell_topic_probabilities + ).to_frame().write_parquet( + f"{mallet_cell_topic_probabilities_parquet_filename}" + ) - def get_topics(self): + @staticmethod + def read_cell_topic_probabilities_parquet_file( + mallet_cell_topic_probabilities_parquet_filename: str, + ) -> np.ndarray: """ - Get topics X words matrix. 
+ Read cell-topic probabilities Parquet file to cell-topic probabilities matrix. + + Parameters + ---------- + mallet_cell_topic_probabilities_parquet_filename + Mallet cell-topic probabilities Parquet filename. Returns ------- - np.ndarray - Topics X words matrix, shape `num_topics` x `vocabulary_size`. + Cell-topic probabilities matrix. """ - topics = self.word_topics - return topics / topics.sum(axis=1)[:, None] + return ( + pl.read_parquet(mallet_cell_topic_probabilities_parquet_filename) + .get_column("cell_topic_probabilities") + .to_numpy() + ) - def fcorpustxt(self): + @staticmethod + def convert_region_topic_counts_txt_to_parquet( + mallet_region_topic_counts_txt_filename: str, + mallet_region_topic_counts_parquet_filename: str, + ) -> None: """ - Get path to corpus text file. + Convert region-topic counts from Mallet output to Parquet file. + + Parameters + ---------- + mallet_region_topic_counts_txt_filename + Mallet region-topic counts text file. + mallet_region_topic_counts_parquet_filename + Parquet output file with region-topic counts. Returns ------- - str - Path to corpus text file. + None """ - return os.path.join(self.tmp_dir, "corpus.txt") + n_region_ids = -1 + n_topics = -1 + region_id_topic_counts = [] + + with open(mallet_region_topic_counts_txt_filename) as fh: + # Column 0: order in which region ID idx was seen in the input corpus file. + # Column 1: region ID idx + # Column 3-n: "topic:count" pairs + # + # Mallet region-topics count file example: + # ---------------------------------------- + # + # 0 12 3:94 11:84 1:84 18:75 17:36 0:31 13:25 4:23 6:22 12:16 9:10 10:6 15:3 7:2 8:1 + # 1 28 8:368 15:267 3:267 17:255 0:245 10:227 16:216 19:201 7:92 18:85 1:58 14:52 9:31 6:17 13:6 2:3 + # 2 33 8:431 16:418 10:354 3:257 17:211 12:146 7:145 9:115 4:108 13:106 18:66 1:60 15:45 6:45 19:33 5:19 14:12 0:1 + # 3 35 7:284 18:230 15:199 10:191 16:164 0:114 4:112 19:107 12:104 13:68 3:49 9:35 1:28 11:25 5:20 17:17 6:11 14:2 8:1 + # 4 57 8:192 3:90 19:88 1:69 18:67 2:63 10:62 17:38 15:37 13:10 4:9 12:2 9:1 + for line in fh: + columns = line.rstrip().split() + # Get region ID index from second column. + region_id_idx = int(columns[1]) + # Get topic index and counts from column 3 till the end by splitting + # "topic:count" pairs. + topics_counts = [ + (int(topic), int(count)) + for topic, count in [ + topic_counts.split(":", 1) for topic_counts in columns[2:] + ] + ] + # Get topic indices. + topics_idx = np.array([topic for topic, count in topics_counts]) + # Get counts. + counts = np.array([count for topic, count in topics_counts]) + # Store region ID index, topics indices and counts till we know how many + # regions and topics we have. + region_id_topic_counts.append((region_id_idx, topics_idx, counts)) + + # Keep track of the highest seen region ID index and topic index + # (0-based). + n_region_ids = max(region_id_idx, n_region_ids) + n_topics = max(topics_idx.max(), n_topics) + + # Add 1 to region IDs and topics counts to account for start at 0. + n_region_ids += 1 + n_topics += 1 + + # Create region-topic counts matrix and populate it. + regions_topic_counts = np.zeros((n_topics, n_region_ids), dtype=np.int32) + for region_idx, topics_idx, counts in region_id_topic_counts: + regions_topic_counts[topics_idx, region_idx] = counts + + # Write region-topic counts matrix to one column of a Parquet file. 
+ pl.Series("region_topic_counts", regions_topic_counts).to_frame().write_parquet( + mallet_region_topic_counts_parquet_filename + ) - def fcorpusmallet(self): + @staticmethod + def read_region_topic_counts_parquet_file( + mallet_region_topic_counts_parquet_filename: str, + ) -> np.ndarray: """ - Get path to corpus.mallet file. + Read region-topic counts Parquet file to region-topic counts matrix. + + Parameters + ---------- + mallet_region_topic_counts_parquet_filename + Mallet region-topic counts Parquet filename. Returns ------- - str - Path to corpus.mallet file. + Region-topic counts matrix. """ - return os.path.join(self.tmp_dir, "corpus.mallet") + return ( + pl.read_parquet(mallet_region_topic_counts_parquet_filename) + .get_column("region_topic_counts") + .to_numpy() + ) - def fstate(self): + @staticmethod + def read_region_topic_counts_parquet_file_to_region_topic_probabilities( + mallet_region_topic_counts_parquet_filename: str, + ) -> np.ndarray: """ - Get path to temporary file. + Get the region-topic probabilities matrix learned during inference. Returns ------- - str - Path to file. + The probability for each region in each topic, shape (n_regions, n_topics). """ - return os.path.join(self.tmp_dir, f"{self.random_label}_state.mallet.gz") + region_topic_counts = np.asarray( + LDAMallet.read_region_topic_counts_parquet_file( + mallet_region_topic_counts_parquet_filename=mallet_region_topic_counts_parquet_filename, + ), + np.float64, + ) - def fdoctopics(self): - """ - Get path to document topic text file. + # Create region-topic probabilities matrix by dividing all count values for a + # topic by total counts for that topic. + region_topic_probabilities = ( + region_topic_counts / region_topic_counts.sum(axis=1)[:, None] + ).astype(np.float32) - Returns - ------- - str - Path to document topic text file. + return region_topic_probabilities + @staticmethod + def read_parameters_json_filename(parameters_json_filename: str) -> dict: """ - return os.path.join(self.tmp_dir, f"{self.random_label}_doctopics.txt") + Read parameters from JSON file which gets written by `LDAMallet.run_mallet_topic_modeling`. - def finferencer(self): - """ - Get path to inferencer.mallet file. + Parameters + ---------- + parameters_json_filename + Parameters JSON filename created by `LDAMallet.run_mallet_topic_modeling`. Returns ------- - str - Path to inferencer.mallet file. + Dictionary with Mallet LDA parameters and settings. """ - return os.path.join(self.tmp_dir, f"{self.random_label}_inferencer.mallet") - - def ftopickeys(self): + with open(parameters_json_filename, "r") as fh: + mallet_train_topics_parameters = json.load(fh) + return mallet_train_topics_parameters + + @staticmethod + def run_mallet_topic_modeling( + mallet_corpus_filename: str, + output_prefix: str, + n_topics: int, + alpha: float = 50, + alpha_by_topic: bool = True, + eta: float = 0.1, + eta_by_topic: bool = False, + n_cpu: int = 1, + optimize_interval: int = 0, + iterations: int = 150, + topic_threshold: float = 0.0, + random_seed: int = 555, + mallet_path: str = "mallet", + ): """ - Get path to topic keys text file. + Run Mallet LDA. - Returns - ------- - str - Path to topic keys text file. + Parameters + ---------- + mallet_corpus_filename + Mallet corpus file. + output_prefix + Output prefix. + n_topics + The number of topics to use in the model. + alpha + Scalar value indicating the symmetric Dirichlet hyperparameter for topic + proportions. Default: 50. 
+ alpha_by_topic + Boolean indicating whether the scalar given in alpha has to be divided by + the number of topics. Default: True. + eta + Scalar value indicating the symmetric Dirichlet hyperparameter for topic + multinomials. Default: 0.1. + eta_by_topic + Boolean indicating whether the scalar given in beta has to be divided by + the number of topics. Default: False + n_cpu + Number of threads that will be used for training. Default: 1. + optimize_interval + Optimize hyperparameters every `optimize_interval` iterations (sometimes + leads to Java exception, 0 to switch off hyperparameter optimization). + Default: 0. + iterations + Number of training iterations. Default: 150. + topic_threshold + Threshold of the probability above which we consider a topic. Default: 0.0. + random_seed + Random seed to ensure consistent results, if 0 - use system clock. + Default: 555. + mallet_path + Path to the mallet binary (e.g. /xxx/Mallet/bin/mallet). Default: "mallet". """ - return os.path.join(self.tmp_dir, f"{self.random_label}_topickeys.txt") + logger = logging.getLogger("LDAMallet") + # Mallet divides alpha value by default by the number of topics, so in case + # alpha_by_topic=False, input alpha needs to be multiplied by n_topics. + mallet_alpha = alpha if alpha_by_topic else alpha * n_topics -def run_cgs_models_mallet( - cistopic_obj: CistopicObject, - n_topics: list[int], - n_cpu: int = 1, - n_iter: int = 150, - random_state: int = 555, - alpha: float = 50.0, - alpha_by_topic: bool = True, - eta: float = 0.1, - eta_by_topic: bool = False, - top_topics_coh: int = 5, - tmp_path: str = None, - save_path: str = None, - reuse_corpus: bool = False, - mallet_path: str = "mallet", -): - """ - Run Latent Dirichlet Allocation per model as implemented in Mallet (McCallum, 2002). + mallet_beta = eta / n_topics if eta_by_topic else eta - Parameters - ---------- - cistopic_obj: CistopicObject - A :class:`CistopicObject`. Note that cells/regions have to be filtered before running any LDA model. - n_topics: list of int - A list containing the number of topics to use in each model. - n_cpu: int, optional - Number of cpus to use for modelling. In this function parallelization is done per model, that is, one model will run entirely in a unique cpu. We recommend to set the number of cpus as the number of models that will be inferred, so all models start at the same time. - n_iter: int, optional - Number of iterations for which the Gibbs sampler will be run. Default: 150. - random_state: int, optional - Random seed to initialize the models. Default: 555. - alpha: float, optional - Scalar value indicating the symmetric Dirichlet hyperparameter for topic proportions. Default: 50. - alpha_by_topic: bool, optional - Boolean indicating whether the scalar given in alpha has to be divided by the number of topics. Default: True - eta: float, optional - Scalar value indicating the symmetric Dirichlet hyperparameter for topic multinomials. Default: 0.1. - eta_by_topic: bool, optional - Boolean indicating whether the scalar given in beta has to be divided by the number of topics. Default: False - top_topics_coh: int, optional - Number of topics to use to calculate the model coherence. For each model, the coherence will be calculated as the average of the top coherence values. Default: 5. - tmp_path: str, optional - Path to a temporary folder for Mallet. Default: None. - save_path: str, optional - Path to save models as independent files as they are completed. This is recommended for large data sets. Default: None. 
- reuse_corpus: bool, optional - Whether to reuse the mallet corpus in the tmp directory. Default: False - mallet_path: str - Path to Mallet binary (e.g. "/xxx/Mallet/bin/mallet"). Default: "mallet". + lda_mallet_filenames = LDAMalletFilenames( + output_prefix=output_prefix, n_topics=n_topics + ) - Return - ------ - list of :class:`CistopicLDAModel` - A list with cisTopic LDA models. + if not os.path.exists(mallet_corpus_filename): + raise FileNotFoundError( + f'Mallet corpus file "{mallet_corpus_filename}" does not exist.' + ) - References - ---------- - McCallum, A. K. (2002). Mallet: A machine learning for language toolkit. http://mallet.cs.umass.edu. + cmd = [ + mallet_path, + "train-topics", + "--input", + mallet_corpus_filename, + "--num-topics", + str(n_topics), + "--alpha", + str(mallet_alpha), + "--beta", + str(mallet_beta), + "--optimize-interval", + str(optimize_interval), + "--num-threads", + str(n_cpu), + "--num-iterations", + str(iterations), + "--word-topic-counts-file", + lda_mallet_filenames.region_topic_counts_txt_filename, + "--output-doc-topics", + lda_mallet_filenames.cell_topic_probabilities_txt_filename, + "--doc-topics-threshold", + str(topic_threshold), + "--random-seed", + str(random_seed), + ] - """ - # Create cisTopic logger - level = logging.INFO - log_format = "%(asctime)s %(name)-12s %(levelname)-8s %(message)s" - handlers = [logging.StreamHandler(stream=sys.stdout)] - logging.basicConfig(level=level, format=log_format, handlers=handlers) - log = logging.getLogger("cisTopic") + start_time = time.time() + logger.info(f"Train topics with Mallet LDA: {' '.join(cmd)}") + try: + subprocess.check_output(args=cmd, shell=False, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + raise RuntimeError( # noqa: B904 + f"command '{e.cmd}' return with error (code {e.returncode}): {e.output}" + ) - binary_matrix = cistopic_obj.binary_matrix - region_names = cistopic_obj.region_names - cell_names = cistopic_obj.cell_names + # Convert cell-topic probabilities text version to parquet. + logger.info( + f'Write cell-topic probabilities to "{lda_mallet_filenames.cell_topic_probabilities_parquet_filename}".' + ) + LDAMallet.convert_cell_topic_probabilities_txt_to_parquet( + mallet_cell_topic_probabilities_txt_filename=lda_mallet_filenames.cell_topic_probabilities_txt_filename, + mallet_cell_topic_probabilities_parquet_filename=lda_mallet_filenames.cell_topic_probabilities_parquet_filename, + ) - log.info("Formatting input to corpus") - corpus = matutils.Sparse2Corpus(binary_matrix) - id2word = utils.FakeDict(len(region_names)) - - model_list = [ - run_cgs_model_mallet( - binary_matrix=binary_matrix, - corpus=corpus, - id2word=id2word, - n_topics=n_topic, - cell_names=cell_names, - region_names=region_names, - n_cpu=n_cpu, - n_iter=n_iter, - random_state=random_state, - alpha=alpha, - alpha_by_topic=alpha_by_topic, - eta=eta, - eta_by_topic=eta_by_topic, - top_topics_coh=top_topics_coh, - tmp_path=tmp_path, - save_path=save_path, - reuse_corpus=reuse_corpus, - mallet_path=mallet_path, + # Convert region-topic counts text version to parquet. + logger.info( + f'Write region-topic counts to "{lda_mallet_filenames.region_topic_counts_parquet_filename}".' 
+ ) + LDAMallet.convert_region_topic_counts_txt_to_parquet( + mallet_region_topic_counts_txt_filename=lda_mallet_filenames.region_topic_counts_txt_filename, + mallet_region_topic_counts_parquet_filename=lda_mallet_filenames.region_topic_counts_parquet_filename, ) - for n_topic in n_topics - ] - return model_list + total_time = time.time() - start_time -def run_cgs_model_mallet( - binary_matrix: sparse.csr_matrix, - corpus: list, - id2word: utils.FakeDict, - n_topics: list[int], - cell_names: list[str], - region_names: list[str], - n_cpu: int = 1, - n_iter: int = 500, - random_state: int = 555, - alpha: float = 50, - alpha_by_topic: bool = True, - eta: float = 0.1, - eta_by_topic: bool = False, + # Write JSON file with all used parameters. + logger.info( + f'Write JSON parameters file to "{lda_mallet_filenames.parameters_json_filename}".' + ) + with open(lda_mallet_filenames.parameters_json_filename, "w") as fh: + mallet_train_topics_parameters = { + "mallet_corpus_filename": mallet_corpus_filename, + "output_prefix": output_prefix, + "n_topics": n_topics, + "alpha": alpha, + "alpha_by_topic": alpha_by_topic, + "eta": eta, + "eta_by_topic": eta_by_topic, + "n_cpu": n_cpu, + "optimize_interval": optimize_interval, + "iterations": iterations, + "random_seed": random_seed, + "mallet_path": mallet_path, + "time": total_time, + "mallet_cmd": cmd, + } + json.dump(mallet_train_topics_parameters, fh) + + +class LDAMalletFilenames: + """Class to generate output filenames when running functions of LDAMallet.""" + + def __init__(self, output_prefix: str, n_topics: int): + """ + Generate output filenames when running functions of LDAMallet. + + Parameters + ---------- + output_prefix + Output prefix. + n_topics + The number of topics used in the model. + + """ + self.output_prefix = output_prefix + self.n_topics = n_topics + + @property + def parameters_json_filename(self): + return os.path.join( + f"{self.output_prefix}.{self.n_topics}_topics.parameters.json" + ) + + @property + def cell_topic_probabilities_txt_filename(self): + return os.path.join( + f"{self.output_prefix}.{self.n_topics}_topics.cell_topic_probabilities.txt" + ) + + @property + def cell_topic_probabilities_parquet_filename(self): + return os.path.join( + f"{self.output_prefix}.{self.n_topics}_topics.cell_topic_probabilities.parquet" + ) + + @property + def region_topic_counts_txt_filename(self): + return f"{self.output_prefix}.{self.n_topics}_topics.region_topic_counts.txt" + + @property + def region_topic_counts_parquet_filename(self): + return ( + f"{self.output_prefix}.{self.n_topics}_topics.region_topic_counts.parquet" + ) + + @property + def model_pickle_filename(self): + return f"{self.output_prefix}.{self.n_topics}_topics.model.pkl" + + +def calculate_model_evaluation_stats( + binary_accessibility_matrix: scipy.sparse.csr_matrix, + cell_barcodes: list[str], + region_ids: list[str], + output_prefix: str, + n_topics: int, top_topics_coh: int = 5, - tmp_path: str = None, - save_path: str = None, - reuse_corpus: bool = False, - mallet_path: str = "mallet", -): +) -> None: """ - Run Latent Dirichlet Allocation in a model as implemented in Mallet (McCallum, 2002). + Calculate model evaluation statistics after running Mallet (McCallum, 2002) topic modeling. Parameters ---------- - binary_matrix: sparse.csr_matrix - Binary sparse matrix containing cells as columns, regions as rows, and 1 if a regions is considered accessible on a cell (otherwise, 0). 
-    n_topics: list of int
-        A list containing the number of topics to use in each model.
-    cell_names: list of str
+    binary_accessibility_matrix
+        Binary accessibility sparse matrix with cells as columns, regions as rows,
+        and 1 as value if a region is considered accessible in a cell (otherwise, 0).
+    cell_barcodes
         List containing cell names as ordered in the binary matrix columns.
-    region_names: list of str
+    region_ids
         List containing region names as ordered in the binary matrix rows.
-    n_cpu: int, optional
-        Number of cpus to use for modelling. In this function parallelization is done per model, that is, one model will run entirely in a unique cpu. We recommend to set the number of cpus as the number of models that will be inferred, so all models start at the same time.
-    n_iter: int, optional
-        Number of iterations for which the Gibbs sampler will be run. Default: 150.
-    random_state: int, optional
-        Random seed to initialize the models. Default: 555.
-    alpha: float, optional
-        Scalar value indicating the symmetric Dirichlet hyperparameter for topic proportions. Default: 50.
-    alpha_by_topic: bool, optional
-        Boolean indicating whether the scalar given in alpha has to be divided by the number of topics. Default: True
-    eta: float, optional
-        Scalar value indicating the symmetric Dirichlet hyperparameter for topic multinomials. Default: 0.1.
-    eta_by_topic: bool, optional
-        Boolean indicating whether the scalar given in beta has to be divided by the number of topics. Default: False
-    top_topics_coh: int, optional
-        Number of topics to use to calculate the model coherence. For each model, the coherence will be calculated as the average of the top coherence values. Default: 5.
-    tmp_path: str, optional
-        Path to a temporary folder for Mallet. Default: None.
-    save_path: str, optional
-        Path to save models as independent files as they are completed. This is recommended for large data sets. Default: None.
-    reuse_corpus: bool, optional
-        Whether to reuse the mallet corpus in the tmp directory. Default: False
-    mallet_path: str
-        Path to Mallet binary (e.g. "/xxx/Mallet/bin/mallet"). Default: "mallet".
+    output_prefix
+        Output prefix used for running topic modeling with Mallet.
+    n_topics
+        Number of topics used in the topic model created by Mallet.
+        In combination with output_prefix, this allows loading the correct
+        region-topic counts and cell-topic probabilities Parquet files.
+    top_topics_coh
+        Number of topics to use to calculate the model coherence. For each model,
+        the coherence will be calculated as the average of the top coherence values.
+        Default: 5.
 
     Return
     ------
-    CistopicLDAModel
-        A cisTopic LDA model. 
+ None References ---------- @@ -907,53 +953,62 @@ def run_cgs_model_mallet( logging.basicConfig(level=level, format=log_format, handlers=handlers) log = logging.getLogger("cisTopic") - # Set models - if not alpha_by_topic: - alpha = alpha * n_topics - if eta_by_topic: - eta = eta / n_topics - - # Running model - start = time.time() - log.info(f"Running model with {n_topics} topics") - model = LDAMallet( - corpus=corpus, - id2word=id2word, - num_topics=n_topics, - iterations=n_iter, - alpha=alpha, - eta=eta, - n_cpu=n_cpu, - tmp_dir=tmp_path, - random_seed=random_state, - reuse_corpus=reuse_corpus, - mallet_path=mallet_path, + # Get distributions + lda_mallet_filenames = LDAMalletFilenames( + output_prefix=output_prefix, n_topics=n_topics + ) + topic_word_distrib = LDAMallet.read_region_topic_counts_parquet_file_to_region_topic_probabilities( + mallet_region_topic_counts_parquet_filename=lda_mallet_filenames.region_topic_counts_parquet_filename + ) + doc_topic_distrib = LDAMallet.read_cell_topic_probabilities_parquet_file( + mallet_cell_topic_probabilities_parquet_filename=lda_mallet_filenames.cell_topic_probabilities_parquet_filename + ) + topic_word_counts = LDAMallet.read_region_topic_counts_parquet_file( + mallet_region_topic_counts_parquet_filename=lda_mallet_filenames.region_topic_counts_parquet_filename ) - end_time = time.time() - start - # Get distributions - topic_word = model.get_topics() - doc_topic = ( - pd.read_csv(model.fdoctopics(), header=None, sep="\t").iloc[:, 2:].to_numpy() + # Read used Mallet LDA parameters from JSON file. + mallet_train_topics_parameters = LDAMallet.read_parameters_json_filename( + lda_mallet_filenames.parameters_json_filename ) + if mallet_train_topics_parameters["n_topics"] != n_topics: + raise ValueError( + f"Number of topics does not match: {n_topics} vs {mallet_train_topics_parameters['n_topics']}." 
+ ) + + alpha = mallet_train_topics_parameters["alpha"] + alpha_by_topic = mallet_train_topics_parameters["alpha_by_topic"] + eta = mallet_train_topics_parameters["eta"] + eta_by_topic = mallet_train_topics_parameters["eta_by_topic"] + n_iter = mallet_train_topics_parameters["iterations"] + random_state = mallet_train_topics_parameters["random_seed"] + mallet_time = mallet_train_topics_parameters["time"] + + ll_alpha = alpha / n_topics if alpha_by_topic else alpha + ll_eta = eta / n_topics if eta_by_topic else eta + # Model evaluation - cell_cov = np.asarray(binary_matrix.sum(axis=0)).astype(float) + cell_cov = np.asarray(binary_accessibility_matrix.sum(axis=0)).astype(float) arun_2010 = tmtoolkit.topicmod.evaluate.metric_arun_2010( - topic_word, doc_topic, cell_cov + topic_word_distrib=topic_word_distrib, + doc_topic_distrib=doc_topic_distrib, + doc_lengths=cell_cov, + ) + cao_juan_2009 = tmtoolkit.topicmod.evaluate.metric_cao_juan_2009( + topic_word_distrib=topic_word_distrib ) - cao_juan_2009 = tmtoolkit.topicmod.evaluate.metric_cao_juan_2009(topic_word) mimno_2011 = tmtoolkit.topicmod.evaluate.metric_coherence_mimno_2011( - topic_word, - dtm=binary_matrix.transpose(), + topic_word_distrib=topic_word_distrib, + dtm=binary_accessibility_matrix.transpose(), top_n=20, eps=1e-12, normalize=True, return_mean=False, ) - topic_word_assig = model.word_topics - doc_topic_assig = (doc_topic.T * (cell_cov)).T - ll = loglikelihood(topic_word_assig, doc_topic_assig, alpha, eta) + + doc_topic_counts = (doc_topic_distrib.T * (cell_cov)).T + ll = loglikelihood(topic_word_counts, doc_topic_counts, ll_alpha, ll_eta) # Organize data if len(mimno_2011) <= top_topics_coh: @@ -983,25 +1038,27 @@ def run_cgs_model_mallet( marg_topic = pd.DataFrame( [ range(1, n_topics + 1), - tmtoolkit.topicmod.model_stats.marginal_topic_distrib(doc_topic, cell_cov), + tmtoolkit.topicmod.model_stats.marginal_topic_distrib( + doc_topic_distrib=doc_topic_distrib, doc_lengths=cell_cov + ), ], index=["Topic", "Marg_Topic"], ).transpose() topic_ass = pd.DataFrame.from_records( [ range(1, n_topics + 1), - list(chain.from_iterable(model.word_topics.sum(axis=1)[:, None])), + list(chain.from_iterable(topic_word_counts.sum(axis=1)[:, None])), ], index=["Topic", "Assignments"], ).transpose() cell_topic = pd.DataFrame.from_records( - doc_topic, - index=cell_names, + doc_topic_distrib, + index=cell_barcodes, columns=["Topic" + str(i) for i in range(1, n_topics + 1)], ).transpose() topic_region = pd.DataFrame.from_records( - topic_word, - columns=region_names, + topic_word_distrib, + columns=region_ids, index=["Topic" + str(i) for i in range(1, n_topics + 1)], ).transpose() parameters = pd.DataFrame( @@ -1014,8 +1071,7 @@ def run_cgs_model_mallet( alpha_by_topic, eta, top_topics_coh, - end_time, - model.time, + mallet_time, ], index=[ "package", @@ -1026,7 +1082,6 @@ def run_cgs_model_mallet( "alpha_by_topic", "eta", "top_topics_coh", - "full_time", "model_time", ], columns=["Parameter"], @@ -1035,14 +1090,12 @@ def run_cgs_model_mallet( model = CistopicLDAModel( metrics, coherence, marg_topic, topic_ass, cell_topic, topic_region, parameters ) - log.info(f"Model with {n_topics} topics done!") - if isinstance(save_path, str): - log.info(f"Saving model with {n_topics} topics at {save_path}") - if not os.path.exists(save_path): - os.mkdir(save_path) - with open(os.path.join(save_path, f"Topic{n_topics}.pkl"), "wb") as f: - pickle.dump(model, f) - return model + + log.info( + f"Saving model with {n_topics} topics at 
{lda_mallet_filenames.model_pickle_filename}" + ) + with open(lda_mallet_filenames.model_pickle_filename, "wb") as fh: + pickle.dump(model, fh) def evaluate_models( diff --git a/src/pycisTopic/loom.py b/src/pycisTopic/loom.py index 98a5878..081070f 100644 --- a/src/pycisTopic/loom.py +++ b/src/pycisTopic/loom.py @@ -496,7 +496,7 @@ def export_region_accessibility_to_loom( out_fname: str, selected_regions: list[str] | None = None, selected_cells: list[str] | None = None, - cluster_annotation: list[str] | None = None, + cluster_annotation: list[str] | None = None, cluster_markers: dict[str, dict[str, pd.DataFrame]] | None = None, tree_structure: tuple = (), title: str | None = None, @@ -746,7 +746,6 @@ def export_minimal_loom_region( auc_thresholds=None, compress: bool = False, ): - # Information on the general loom file format: http://linnarssonlab.org/loompy/format/index.html # Information on the SCope specific alterations: https://github.com/aertslab/SCope/wiki/Data-Format diff --git a/src/pycisTopic/plotting/qc_plot.py b/src/pycisTopic/plotting/qc_plot.py index 405f561..044794a 100644 --- a/src/pycisTopic/plotting/qc_plot.py +++ b/src/pycisTopic/plotting/qc_plot.py @@ -14,7 +14,7 @@ def plot_barcode_rank( fragments_stats_per_cb_df: pl.DataFrame, ax: plt.Axes | None = None, - **matplotlib_plot_kwargs + **matplotlib_plot_kwargs, ) -> plt.Axes: """ Plot barcode rank vs number of fragments in a log-log scale. @@ -47,6 +47,7 @@ def plot_barcode_rank( ax.set_ylabel("Number of fragments", fontsize=10) return ax + def plot_insert_size_distribution( fragments_insert_size_dist_df: pl.DataFrame, ax: plt.Axes | None = None, @@ -85,10 +86,11 @@ def plot_insert_size_distribution( ax.set_xlim(*insert_size_distriubtion_xlim) return ax + def plot_tss_enrichment( - tss_norm_matrix_sample_df, - ax: plt.Axes | None = None, - **matplotlib_plot_kwargs + tss_norm_matrix_sample_df: pl.DataFrame, + ax: plt.Axes | None = None, + **matplotlib_plot_kwargs, ) -> plt.Axes: """ Plot TSS enrichment. @@ -141,7 +143,7 @@ def plot_sample_stats( sample_id : str Sample ID. pycistopic_qc_output_dir : str | Path - Directory containing the output of the pycistopic qc command. + Directory containing the output of the ``pycistopic qc run`` command. save : str | Path, optional Path to save the plot, by default None. insert_size_distriubtion_xlim : tuple, optional @@ -195,8 +197,11 @@ def plot_sample_stats( figsize = (6.4 * ncols, 4.8 * nrows) fig, axs = plt.subplots( - nrows=nrows, ncols=ncols, figsize=figsize, - layout = "constrained") + nrows=nrows, + ncols=ncols, + figsize=figsize, + layout="constrained", + ) # Set centered sample title for 3 combined plots. fig.suptitle(sample_id if sample_alias is None else sample_alias) @@ -204,20 +209,20 @@ def plot_sample_stats( # Plot barcode rank plot on the left. plot_barcode_rank( fragments_stats_per_cb_df, - ax = axs[0] + ax=axs[0], ) # Plot insert size distribution plot in the center. plot_insert_size_distribution( fragments_insert_size_dist_df, - ax = axs[1], - insert_size_distriubtion_xlim = insert_size_distriubtion_xlim + ax=axs[1], + insert_size_distriubtion_xlim=insert_size_distriubtion_xlim, ) # Plot TSS enrichment plot on the right. 
plot_tss_enrichment( tss_norm_matrix_sample_df, - ax = axs[2] + ax=axs[2], ) if save: @@ -227,13 +232,14 @@ def plot_sample_stats( return fig + def _plot_fragment_stats( fragments_stats_per_cb_df: pl.DataFrame, ax: plt.Axes, x_var: str, y_var: str, c_var: str, - **matplotlib_plot_kwargs + **matplotlib_plot_kwargs, ) -> plt.Axes: """ Helper function to plot fragment statistics. @@ -260,16 +266,18 @@ def _plot_fragment_stats( """ fragments_stats_per_cb_df = fragments_stats_per_cb_df.sort( - by = c_var, descending = False + c_var, + descending=False, ) ax.scatter( - x = fragments_stats_per_cb_df.get_column(x_var).to_numpy(), - y = fragments_stats_per_cb_df.get_column(y_var).to_numpy(), - c = fragments_stats_per_cb_df.get_column(c_var).to_numpy(), - **matplotlib_plot_kwargs + x=fragments_stats_per_cb_df.get_column(x_var).to_numpy(), + y=fragments_stats_per_cb_df.get_column(y_var).to_numpy(), + c=fragments_stats_per_cb_df.get_column(c_var).to_numpy(), + **matplotlib_plot_kwargs, ) return ax + def plot_barcode_stats( sample_id: str, pycistopic_qc_output_dir: str | Path, @@ -290,7 +298,7 @@ def plot_barcode_stats( sample_id : str Sample ID. pycistopic_qc_output_dir : str | Path - Directory containing the output of the pycistopic qc command. + Directory containing the output of the ``pycistopic qc run`` command. unique_fragments_threshold : int, optional Unique fragments threshold, by default None. tss_enrichment_threshold : float, optional @@ -321,82 +329,92 @@ def plot_barcode_stats( """ # check if files exist - if not os.path.isfile(os.path.join(pycistopic_qc_output_dir, f"{sample_id}.fragments_stats_per_cb.parquet")): + if not os.path.isfile( + os.path.join( + pycistopic_qc_output_dir, + f"{sample_id}.fragments_stats_per_cb.parquet", + ) + ): raise FileNotFoundError( - f"Could not find {sample_id}.fragments_stats_per_cb.parquet in {pycistopic_qc_output_dir}") + f"Could not find {sample_id}.fragments_stats_per_cb.parquet in {pycistopic_qc_output_dir}" + ) fragments_stats_per_cb_df = pl.read_parquet( os.path.join( - pycistopic_qc_output_dir, f"{sample_id}.fragments_stats_per_cb.parquet" + pycistopic_qc_output_dir, + f"{sample_id}.fragments_stats_per_cb.parquet", ) ) - if detailed_title and bc_passing_filters is None: + if detailed_title and bc_passing_filters is None: Warning("bc_passing_filters is None, no detailed title will be shown") ncols = 3 nrows = 1 figsize = (4 * ncols, 4 * nrows) fig, axs = plt.subplots( - figsize = figsize, nrows = nrows, ncols = ncols, - sharex = True, - layout = "constrained") + figsize=figsize, + nrows=nrows, + ncols=ncols, + sharex=True, + layout="constrained", + ) # Plot TSS enrichment vs unique number of fragments on the left. axs[0] = _plot_fragment_stats( fragments_stats_per_cb_df, - ax = axs[0], - x_var = "unique_fragments_in_peaks_count", - y_var = "tss_enrichment", - c_var = "pdf_values_for_tss_enrichment", - s = 10, - edgecolors = None, - marker = "+", - cmap = "viridis" + ax=axs[0], + x_var="unique_fragments_in_peaks_count", + y_var="tss_enrichment", + c_var="pdf_values_for_tss_enrichment", + s=10, + edgecolors=None, + marker="+", + cmap="viridis", ) axs[0].set_ylabel("TSS enrichment") # Plot FRIP vs unique number of fragments in the center. 
axs[1] = _plot_fragment_stats( fragments_stats_per_cb_df, - ax = axs[1], - x_var = "unique_fragments_in_peaks_count", - y_var = "fraction_of_fragments_in_peaks", - c_var = "pdf_values_for_fraction_of_fragments_in_peaks", - s = 10, - edgecolors = None, - marker = "+", - cmap = "viridis" + ax=axs[1], + x_var="unique_fragments_in_peaks_count", + y_var="fraction_of_fragments_in_peaks", + c_var="pdf_values_for_fraction_of_fragments_in_peaks", + s=10, + edgecolors=None, + marker="+", + cmap="viridis", ) axs[1].set_ylabel("Fraction of fragments in peaks") # Plot duplication ratio vs unique number of fragments on the right. axs[2] = _plot_fragment_stats( fragments_stats_per_cb_df, - ax = axs[2], - x_var = "unique_fragments_in_peaks_count", - y_var = "duplication_ratio", - c_var = "pdf_values_for_duplication_ratio", - s = 10, - edgecolors = None, - marker = "+", - cmap = "viridis" + ax=axs[2], + x_var="unique_fragments_in_peaks_count", + y_var="duplication_ratio", + c_var="pdf_values_for_duplication_ratio", + s=10, + edgecolors=None, + marker="+", + cmap="viridis", ) axs[2].set_ylabel("Duplication ratio") # plot thresholds if unique_fragments_threshold is not None: for ax in axs: - ax.axvline(x = unique_fragments_threshold, color = "r", linestyle = "--") + ax.axvline(x=unique_fragments_threshold, color="r", linestyle="--") if tss_enrichment_threshold is not None: - axs[0].axhline(y = tss_enrichment_threshold, color = "r", linestyle = "--") + axs[0].axhline(y=tss_enrichment_threshold, color="r", linestyle="--") if frip_threshold is not None: - axs[1].axhline(y = frip_threshold, color = "r", linestyle = "--") + axs[1].axhline(y=frip_threshold, color="r", linestyle="--") if duplication_ratio_threshold is not None: - axs[2].axhline(y = duplication_ratio_threshold, color = "r", linestyle = "--") + axs[2].axhline(y=duplication_ratio_threshold, color="r", linestyle="--") # Set x-axis to log scale and plot x-axis label. 
for ax in axs: @@ -422,10 +440,12 @@ def plot_barcode_stats( title += f"Median Unique Fragments: {median_no_fragments:.0f}\n" title += f"Median TSS Enrichment: {median_tss_enrichment:.2f}\n" title += f"Median FRIP: {fraction_of_fragments_in_peaks:.2f}\n" - if (unique_fragments_threshold is not None) \ - or (tss_enrichment_threshold is not None) \ - or (frip_threshold is not None) \ - or (duplication_ratio_threshold is not None): + if ( + (unique_fragments_threshold is not None) + or (tss_enrichment_threshold is not None) + or (frip_threshold is not None) + or (duplication_ratio_threshold is not None) + ): title += "Thresholds:\n" if unique_fragments_threshold is not None: title += f" Unique fragments: {unique_fragments_threshold:.2f}\n" @@ -438,7 +458,10 @@ def plot_barcode_stats( else: title = sample_id if sample_alias is None else sample_alias - fig.suptitle(title, horizontalalignment = "left") + fig.suptitle( + title, + horizontalalignment="left", + ) if save: fig.savefig(save) diff --git a/src/pycisTopic/pseudobulk_peak_calling.py b/src/pycisTopic/pseudobulk_peak_calling.py index 5b66035..c697d9d 100644 --- a/src/pycisTopic/pseudobulk_peak_calling.py +++ b/src/pycisTopic/pseudobulk_peak_calling.py @@ -24,22 +24,24 @@ def _generate_bigwig( - path_to_fragments: str, - chromsizes: dict[str, int], - normalize_bigwig: bool, - bw_filename: str, - log: logging.Logger): + path_to_fragments: str, + chromsizes: dict[str, int], + normalize_bigwig: bool, + bw_filename: str, + log: logging.Logger, +): fragments_df = read_fragments_to_polars_df(path_to_fragments) fragments_to_bw( - fragments_df = fragments_df, - chrom_sizes = chromsizes, - bw_filename = bw_filename, - normalize = normalize_bigwig, - scaling_factor = 1, - cut_sites = False + fragments_df=fragments_df, + chrom_sizes=chromsizes, + bw_filename=bw_filename, + normalize=normalize_bigwig, + scaling_factor=1, + cut_sites=False, ) log.info(f"{bw_filename} done!") + def export_pseudobulk( input_data: Union[CistopicObject, pd.DataFrame], variable: str, @@ -51,7 +53,7 @@ def export_pseudobulk( n_cpu: int = 1, normalize_bigwig: bool = True, split_pattern: str = "___", - temp_dir: str = "/tmp" + temp_dir: str = "/tmp", ) -> tuple[dict[str, str], dict[str, str]]: """ Create pseudobulks as bed and bigwig from single cell fragments file given a barcode annotation. @@ -124,15 +126,19 @@ def export_pseudobulk( # Check wether we have a path to fragments for each sample if not all([sample_id in path_to_fragments.keys() for sample_id in sample_ids]): - raise ValueError("Please, provide a path to fragments for each sample in your cell metadata!") + raise ValueError( + "Please, provide a path to fragments for each sample in your cell metadata!" + ) # Check for NaNs in variable column if cell_data[variable].isna().any(): log.warning( - f"NaNs detected in {variable} column. These will be converted to 'nan' string.") + f"NaNs detected in {variable} column. These will be converted to 'nan' string." + ) # Check for numerical values in variable column if not all([isinstance(x, str) for x in cell_data[variable].dropna()]): log.warning( - f"Non-string values detected in {variable} column. These will be converted to strings.") + f"Non-string values detected in {variable} column. These will be converted to strings." 
+ ) # Convert variable column to string cell_data[variable] = cell_data[variable].astype(str) # make output folders, if they don't exists @@ -147,10 +153,11 @@ def export_pseudobulk( sample_to_cell_type_to_barcodes = {} for sample in sample_ids: _sample_cell_data = cell_data.loc[cell_data[sample_id_col] == sample] - _cell_type_to_cell_barcodes = _sample_cell_data \ - .groupby(variable, group_keys=False)["barcode"] \ - .apply(list) \ + _cell_type_to_cell_barcodes = ( + _sample_cell_data.groupby(variable, group_keys=False)["barcode"] + .apply(list) .to_dict() + ) sample_to_cell_type_to_barcodes[sample] = _cell_type_to_cell_barcodes if isinstance(chromsizes, pr.PyRanges): chromsizes_dict = chromsizes.df.set_index("Chromosome").to_dict()["End"] @@ -160,21 +167,22 @@ def export_pseudobulk( log.info("Splitting fragments by cell type.") split_fragment_files_by_cell_type( - sample_to_fragment_file = path_to_fragments, - path_to_temp_folder = temp_dir, - path_to_output_folder = bed_path, - sample_to_cell_type_to_cell_barcodes = sample_to_cell_type_to_barcodes, - chromsizes = chromsizes_dict, - n_cpu = n_cpu, - verbose = False, - clear_temp_folder = True + sample_to_fragment_file=path_to_fragments, + path_to_temp_folder=temp_dir, + path_to_output_folder=bed_path, + sample_to_cell_type_to_cell_barcodes=sample_to_cell_type_to_barcodes, + chromsizes=chromsizes_dict, + n_cpu=n_cpu, + verbose=False, + clear_temp_folder=True, ) bed_paths = {} for cell_type in cell_data[variable].unique(): _bed_fname = os.path.join( bed_path, - f"{_santize_string_for_filename(cell_type)}.fragments.tsv.gz") + f"{_santize_string_for_filename(cell_type)}.fragments.tsv.gz", + ) if os.path.exists(_bed_fname): bed_paths[cell_type] = _bed_fname else: @@ -182,13 +190,15 @@ def export_pseudobulk( log.info("generating bigwig files") joblib.Parallel(n_jobs=n_cpu)( - joblib.delayed(_generate_bigwig) - ( - path_to_fragments = bed_paths[cell_type], - chromsizes = chromsizes_dict, - normalize_bigwig = normalize_bigwig, - bw_filename = os.path.join(bigwig_path, f"{_santize_string_for_filename(cell_type)}.bw"), - log = log + joblib.delayed(_generate_bigwig)( + path_to_fragments=bed_paths[cell_type], + chromsizes=chromsizes_dict, + normalize_bigwig=normalize_bigwig, + bw_filename=os.path.join( + bigwig_path, + f"{_santize_string_for_filename(cell_type)}.bw", + ), + log=log, ) for cell_type in bed_paths.keys() ) @@ -196,7 +206,8 @@ def export_pseudobulk( for cell_type in cell_data[variable].unique(): _bw_fname = os.path.join( bigwig_path, - f"{_santize_string_for_filename(cell_type)}.bw") + f"{_santize_string_for_filename(cell_type)}.bw", + ) if os.path.exists(_bw_fname): bw_paths[cell_type] = _bw_fname else: @@ -204,6 +215,7 @@ def export_pseudobulk( return bw_paths, bed_paths + def peak_calling( macs_path: str, bed_paths: dict, @@ -217,7 +229,7 @@ def peak_calling( q_value: float = 0.05, nolambda: bool = True, skip_empty_peaks: bool = False, - **kwargs + **kwargs, ): """ Performs pseudobulk peak calling with MACS2. It requires to have MACS2 installed (https://github.com/macs3-project/MACS). 
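Reviewer note: a minimal usage sketch (not part of this diff) for the reformatted export_pseudobulk. Parameter names are taken from the hunks above; cell_data, fragments_dict, chromsizes and work_dir are placeholder objects the caller must provide.

# Usage sketch, assuming these objects already exist in the caller's session:
#   cell_data      - pandas DataFrame with one row per cell barcode
#   fragments_dict - {sample_id: path to fragments.tsv.gz}
#   chromsizes     - dict of chromosome sizes (or a pr.PyRanges)
#   work_dir       - output directory (placeholder)
import os

from pycisTopic.pseudobulk_peak_calling import export_pseudobulk

bw_paths, bed_paths = export_pseudobulk(
    input_data=cell_data,
    variable="cell_type",              # column used to group barcodes into pseudobulks
    sample_id_col="sample_id",
    chromsizes=chromsizes,
    bed_path=os.path.join(work_dir, "pseudobulk_bed"),
    bigwig_path=os.path.join(work_dir, "pseudobulk_bw"),
    path_to_fragments=fragments_dict,
    n_cpu=4,
    normalize_bigwig=True,
    temp_dir="/tmp",
)
# bed_paths / bw_paths map each cell type to its pseudobulk fragments file and bigwig.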
@@ -275,34 +287,33 @@ def peak_calling( keep_dup, q_value, nolambda, - skip_empty_peaks - + skip_empty_peaks, ) for name in list(bed_paths.keys()) ] ) except Exception as e: ray.shutdown() - raise(e) + raise (e) ray.shutdown() else: - narrow_peaks = [macs_call_peak( - macs_path, - bed_paths[name], - name, - outdir, - genome_size, - input_format, - shift, - ext_size, - keep_dup, - q_value, - nolambda, - skip_empty_peaks - - ) - for name in list(bed_paths.keys()) - ] + narrow_peaks = [ + macs_call_peak( + macs_path, + bed_paths[name], + name, + outdir, + genome_size, + input_format, + shift, + ext_size, + keep_dup, + q_value, + nolambda, + skip_empty_peaks, + ) + for name in list(bed_paths.keys()) + ] narrow_peaks_dict = { list(bed_paths.keys())[i]: narrow_peaks[i].narrow_peak for i in range(len(narrow_peaks)) @@ -310,6 +321,7 @@ def peak_calling( } return narrow_peaks_dict + def macs_call_peak( macs_path: str, bed_path: str, @@ -322,7 +334,7 @@ def macs_call_peak( keep_dup: str = "all", q_value: int = 0.05, nolambda: bool = True, - skip_empty_peaks: bool = False + skip_empty_peaks: bool = False, ): """ Performs pseudobulk peak calling with MACS2 in a group. It requires to have MACS2 installed (https://github.com/macs3-project/MACS). @@ -379,11 +391,12 @@ def macs_call_peak( keep_dup=keep_dup, q_value=q_value, nolambda=nolambda, - skip_empty_peaks=skip_empty_peaks + skip_empty_peaks=skip_empty_peaks, ) log.info(f"{name} done!") return MACS_peak_calling + @ray.remote def macs_call_peak_ray( macs_path: str, @@ -397,7 +410,7 @@ def macs_call_peak_ray( keep_dup: str = "all", q_value: int = 0.05, nolambda: bool = True, - skip_empty_peaks: bool = False + skip_empty_peaks: bool = False, ): """ Performs pseudobulk peak calling with MACS2 in a group. It requires to have MACS2 installed (https://github.com/macs3-project/MACS). @@ -454,8 +467,7 @@ def macs_call_peak_ray( keep_dup=keep_dup, q_value=q_value, nolambda=nolambda, - skip_empty_peaks=skip_empty_peaks - + skip_empty_peaks=skip_empty_peaks, ) log.info(name + " done!") return MACS_peak_calling @@ -580,9 +592,11 @@ def load_narrow_peak(self, skip_empty_peaks: bool): file_is_empty = True if file_is_empty and skip_empty_peaks: print(f"{self.name} has no peaks, skipping") - return pr.PyRanges() + return pr.PyRanges() elif file_is_empty and not skip_empty_peaks: - raise ValueError(f"{self.name} has no peaks, exiting. Set skip_empty_peaks to True to skip empty peaks.") + raise ValueError( + f"{self.name} has no peaks, exiting. Set skip_empty_peaks to True to skip empty peaks." + ) narrow_peak = pd.read_csv( os.path.join(self.outdir, f"{self.name}_peaks.narrowPeak"), sep="\t", diff --git a/src/pycisTopic/py.typed b/src/pycisTopic/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/pycisTopic/pyGREAT.py b/src/pycisTopic/pyGREAT.py index 51cd297..de65df9 100644 --- a/src/pycisTopic/pyGREAT.py +++ b/src/pycisTopic/pyGREAT.py @@ -30,7 +30,7 @@ def pyGREAT( bg_choice: str = "wholeGenome", tmp_dir: str | None = None, n_cpu: int = 1, - **kwargs + **kwargs, ): """ Running GREAT (McLean et al., 2010) on a dictionary of pyranges. For more details in GREAT parameters, please visit http://great.stanford.edu/public/html/. 
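Reviewer note: a minimal sketch (not part of this diff) of running the reformatted peak_calling on the pseudobulk fragments produced above. Parameter names and defaults are assumed from the surrounding hunks and typical pycisTopic usage; work_dir is a placeholder and "macs2" must be on PATH.

# Usage sketch: `bed_paths` is the dict returned by export_pseudobulk() above.
import os

from pycisTopic.pseudobulk_peak_calling import peak_calling

narrow_peaks_dict = peak_calling(
    macs_path="macs2",
    bed_paths=bed_paths,
    outdir=os.path.join(work_dir, "macs2"),
    genome_size="hs",                  # effective genome size passed to MACS2
    n_cpu=4,                           # parallelised over pseudobulks with Ray
    input_format="BEDPE",
    shift=73,
    ext_size=146,
    keep_dup="all",
    q_value=0.05,
    skip_empty_peaks=True,             # pseudobulks without peaks yield an empty PyRanges instead of raising
)
# narrow_peaks_dict maps each cell type to a pr.PyRanges of MACS2 narrowPeak calls.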
@@ -248,7 +248,7 @@ def pyGREAT_oneset( random_label = hex(random.randint(0, 0xFFFFFF))[2:] bed_file = os.path.join( tmp_dir if tmp_dir else tempfile.gettempdir(), - f"{random_label}_great.bed" + f"{random_label}_great.bed", ) region_set.df.to_csv(bed_file, sep="\t", index=False, header=False) @@ -263,11 +263,11 @@ def pyGREAT_oneset( "twoDistance": two_distance, "oneDistance": one_distance, "includeCuratedRegDoms": include_curated_reg_doms, - "bgChoice": "file" if (bg_choice != 'wholeGenome') else 'wholeGenome', + "bgChoice": "file" if (bg_choice != "wholeGenome") else "wholeGenome", "fgChoice": "file", } - if bg_choice == 'wholeGenome': + if bg_choice == "wholeGenome": files = {"fgFile": open(bed_file, "r")} else: files = {"fgFile": open(bed_file, "r"), "bgFile": open(bg_choice, "r")} diff --git a/src/pycisTopic/qc.py b/src/pycisTopic/qc.py index 3ab3574..b150b45 100644 --- a/src/pycisTopic/qc.py +++ b/src/pycisTopic/qc.py @@ -20,10 +20,14 @@ from pycisTopic.tss_profile import get_tss_profile from scipy.stats import gaussian_kde +if TYPE_CHECKING: + import numpy.typing as npt + # Enable Polars global string cache so all categoricals are created with the same # string cache. pl.enable_string_cache() + def get_barcodes_passing_qc_for_sample( sample_id: str, pycistopic_qc_output_dir: str | Path, @@ -31,7 +35,7 @@ def get_barcodes_passing_qc_for_sample( tss_enrichment_threshold: float | None = None, frip_threshold: float | None = None, use_automatic_thresholds: bool = True, -) -> tuple[np.ndarray, dict[str, float]]: +) -> tuple[list[str], dict[str, float]]: """ Get barcodes passing quality control (QC) for a sample. @@ -40,7 +44,7 @@ def get_barcodes_passing_qc_for_sample( sample_id Sample ID. pycistopic_qc_output_dir - Directory with output from pycistopic qc. + Directory with output from ``pycistopic qc run``. unique_fragments_threshold Threshold for number of unique fragments in peaks. If not defined, and use_automatic_thresholds is False, @@ -48,10 +52,10 @@ def get_barcodes_passing_qc_for_sample( tss_enrichment_threshold Threshold for TSS enrichment score. If not defined, and use_automatic_thresholds is False, - the threshold will be set to 0. + the threshold will be set to 0.0. frip_threshold Threshold for fraction of reads in peaks (FRiP). - If not defined the threshold will be set to 0. + If not defined the threshold will be set to 0.0. use_automatic_thresholds Use automatic thresholds for unique fragments in peaks and TSS enrichment score as calculated by Otsu's method. If False, the thresholds will be set to 0 if not @@ -60,7 +64,7 @@ def get_barcodes_passing_qc_for_sample( Returns ------- Tuple with: - - Numpy array with cell barcodes passing QC. + - List with cell barcodes passing QC. - Dictionary with thresholds used for QC. Raises @@ -69,84 +73,110 @@ def get_barcodes_passing_qc_for_sample( If the file with fragments statistics per cell barcode does not exist. 
""" - # Check wether files exist - if not os.path.exists(os.path.join(pycistopic_qc_output_dir, f"{sample_id}.fragments_stats_per_cb.parquet")): - raise FileNotFoundError(f"File {os.path.join(pycistopic_qc_output_dir, f'{sample_id}.fragments_stats_per_cb.parquet')} does not exist") - - first_print = True - if use_automatic_thresholds: - # Check wether files exist - if not os.path.exists(os.path.join(pycistopic_qc_output_dir, f"{sample_id}.otsu_thresholds.tsv")): - Warning(f"File {os.path.join(pycistopic_qc_output_dir, f'{sample_id}.otsu_thresholds.tsv')} does not exist") + otsu_thresholds_tsv_filename = os.path.join( + pycistopic_qc_output_dir, + f"{sample_id}.otsu_thresholds.tsv", + ) + + # Check whether files exist. + if not os.path.exists(otsu_thresholds_tsv_filename): + Warning(f'File "{otsu_thresholds_tsv_filename}" does not exist.') else: - # Read automatic thresholds - otsu_unique_fragments_threshold, otsu_tss_enrichment_threshold = pl.read_csv( - os.path.join(pycistopic_qc_output_dir, f"{sample_id}.otsu_thresholds.tsv"), - separator = "\t" - ).select(["unique_fragments_in_peaks_count_otsu_threshold", "tss_enrichment_otsu_threshold"]).to_numpy()[0] + # Read automatic thresholds. + ( + otsu_unique_fragments_threshold, + otsu_tss_enrichment_threshold, + ) = pl.read_csv( + otsu_thresholds_tsv_filename, + separator="\t", + columns=[ + "unique_fragments_in_peaks_count_otsu_threshold", + "tss_enrichment_otsu_threshold", + ], + ).row(0) + + print(f"{sample_id}:") + if unique_fragments_threshold is None: - if first_print: - print(f"{sample_id}:") - first_print = False - print(f"\tUsing automatic threshold for unique fragments: {otsu_unique_fragments_threshold}") + print( + f"\tUsing automatic threshold for unique fragments: {otsu_unique_fragments_threshold}" + ) unique_fragments_threshold = otsu_unique_fragments_threshold else: - if first_print: - print(f"{sample_id}:") - first_print = False - print(f"\tUsing user-defined threshold for unique fragments: {unique_fragments_threshold}") + print( + f"\tUsing user-defined threshold for unique fragments: {unique_fragments_threshold}" + ) + if tss_enrichment_threshold is None: - if first_print: - print(f"{sample_id}:") - first_print = False - print(f"\tUsing automatic threshold for TSS enrichment: {otsu_tss_enrichment_threshold}") + print( + f"\tUsing automatic threshold for TSS enrichment: {otsu_tss_enrichment_threshold}" + ) tss_enrichment_threshold = otsu_tss_enrichment_threshold else: - if first_print: - print(f"{sample_id}:") - first_print = False - print(f"\tUsing user-defined threshold for TSS enrichment: {tss_enrichment_threshold}") - - # Set thresholds to 0 if not defined - if unique_fragments_threshold is None: - if first_print: + print( + f"\tUsing user-defined threshold for TSS enrichment: {tss_enrichment_threshold}" + ) + + if ( + unique_fragments_threshold is None + or tss_enrichment_threshold is None + or frip_threshold is None + ): + if not use_automatic_thresholds: print(f"{sample_id}:") - first_print = False - print("\tNo threshold for unique fragments defined, setting to 0") - unique_fragments_threshold = 0 - if tss_enrichment_threshold is None: - if first_print: - print(f"{sample_id}:") - first_print = False - print("\tNo threshold for TSS enrichment defined, setting to 0") - tss_enrichment_threshold = 0 + # Set thresholds to 0 if not defined. 
+ if unique_fragments_threshold is None: + print("\tNo threshold for unique fragments defined, setting to 0.") + unique_fragments_threshold = 0 - if frip_threshold is None: - if first_print: - print(f"{sample_id}:") - first_print = False - print("\tNo threshold for FRiP defined, setting to 0") - frip_threshold = 0 - - # Get barcodes passing filters - barcodes_passing_filters = pl.scan_parquet( - os.path.join(pycistopic_qc_output_dir, f"{sample_id}.fragments_stats_per_cb.parquet") - ).filter( - ( pl.col("unique_fragments_in_peaks_count") > unique_fragments_threshold ) & \ - ( pl.col("tss_enrichment") > tss_enrichment_threshold ) & \ - ( pl.col("fraction_of_fragments_in_peaks") > frip_threshold ) - ).select("CB").collect().to_numpy().squeeze() + if tss_enrichment_threshold is None: + print("\tNo threshold for TSS enrichment defined, setting to 0.0.") + tss_enrichment_threshold = 0.0 + + if frip_threshold is None: + print("\tNo threshold for FRiP defined, setting to 0.0.") + frip_threshold = 0.0 + + # fragments_stats_per_cb_df_pl + fragments_stats_per_cb_filename = os.path.join( + pycistopic_qc_output_dir, + f"{sample_id}.fragments_stats_per_cb.parquet", + ) + + # Check whether files exist. + if not os.path.exists(fragments_stats_per_cb_filename): + raise FileNotFoundError( + f'File "{fragments_stats_per_cb_filename}" does not exist.' + ) + + # Get barcodes passing filters. + barcodes_passing_filters = ( + pl.scan_parquet(fragments_stats_per_cb_filename) + .filter( + (pl.col("unique_fragments_in_peaks_count") >= unique_fragments_threshold) + & (pl.col("tss_enrichment") >= tss_enrichment_threshold) + & (pl.col("fraction_of_fragments_in_peaks") >= frip_threshold) + ) + .select("CB") + .collect() + .to_series() + .to_list() + ) return barcodes_passing_filters, { "unique_fragments_threshold": unique_fragments_threshold, "tss_enrichment_threshold": tss_enrichment_threshold, - "frip_threshold": frip_threshold + "frip_threshold": frip_threshold, } -def compute_kde(training_data: np.ndarray, test_data: np.ndarray, no_threads: int = 8): +def compute_kde( + training_data: npt.ArrayLike, + test_data: npt.ArrayLike, + no_threads: int = 8, +) -> npt.NDArray[np.float64]: """ Compute kernel-density estimate (KDE) using Gaussian kernels. @@ -171,6 +201,27 @@ def compute_kde(training_data: np.ndarray, test_data: np.ndarray, no_threads: in test_data. """ + training_data = np.asarray(training_data, dtype=np.float64) + test_data = np.asarray(test_data, dtype=np.float64) + + # Avoid very rare cases where second column of training_data has the same + # value everywhere. This can happen in some cases for duplication ratio as + # it can be 0.0% when fragment counts for each fragment are 1. + # + # This will result in the following error: + # LinAlgError: The data appears to lie in a lower-dimensional subspace of + # the space in which it is expressed. This has resulted in a singular data + # covariance matrix, which cannot be treated using the algorithms implemented + # in `gaussian_kde`. Consider performing principle component analysis / + # dimensionality reduction and using `gaussian_kde` with the transformed data. + if np.var(training_data[1]) == 0.0: + # Add small value to first element to avoid all of them to be equal. + if training_data[1][0] == 0.0: + training_data[1][0] = 0.000000000001 + else: + # In even rarer case that the value is not 0.0, change the value proportionally. 
+ training_data[1][0] = training_data[1][0] * 1.000000000001 + # Convert 2D numpy array test data to complex number array so numpy considers both # columns at the same time in further operations. test_data_all = np.empty(test_data.shape[1], dtype=np.complex128) @@ -191,7 +242,9 @@ def compute_kde(training_data: np.ndarray, test_data: np.ndarray, no_threads: in axis=1, ) - def compute_kde_part(test_data_unique_split_array): + def compute_kde_part( + test_data_unique_split_array: npt.NDArray[np.float64], + ) -> npt.NDArray[np.float64]: """ Compute kernel-density estimate (KDE) using Gaussian kernels for a subsection of the test_data. diff --git a/src/pycisTopic/topic_binarization.py b/src/pycisTopic/topic_binarization.py index 1ce0aa3..0b2dffb 100644 --- a/src/pycisTopic/topic_binarization.py +++ b/src/pycisTopic/topic_binarization.py @@ -1,210 +1,17 @@ from __future__ import annotations -import logging -import sys -from typing import TYPE_CHECKING +from functools import partial +from typing import Callable, Literal -import matplotlib -import matplotlib.pyplot as plt import numpy as np +import numpy.typing as npt import pandas as pd from pyscenic import binarization -if TYPE_CHECKING: - from pycisTopic.cistopic_class import CistopicObject - - -def binarize_topics( - cistopic_obj: CistopicObject, - target: str | None = "region", - method: str | None = "otsu", - smooth_topics: bool = True, - ntop: int = 2000, - predefined_thr: dict[str, float] | None = None, - nbins: int = 100, - plot: bool = False, - figsize: tuple[float, float] | None = (6.4, 4.8), - num_columns: int = 1, - save: str | None = None, -): - r""" - Binarize topic distributions. - - Parameters - ---------- - cistopic_obj - A cisTopic object with a model in :class:`CistopicObject`. - target - Whether cell-topic ("cell") or region-topic ("region") distributions should be - binarized. Default: "region". - method - Method to use for topic binarization. Possible options are: - - ``otsu`` [Otsu, 1979] - - ``yen`` [Yen et al., 1995] - - ``li`` [Li & Lee, 1993] - - ``aucell`` [Van de Sande et al., 2020] - - ``ntop`` [Taking the top n regions per topic] - Default: ``otsu``. - smooth_topics - Whether to smooth topics distributions to penalize regions enriched across many - topics. The following formula is applied: - - .. math:: - \beta_{w, k} (\log\beta_{w,k} - 1 / K \sum_{k'} \log \beta_{w,k'}) - ntop - Number of top regions to select when using ``method="ntop"``. - Default: 2000. - predefined_thr - A dictionary containing topics as keys and threshold as values. If a topic is - not present, thresholds will be computed with the specified method. - This can be used for manually adjusting thresholds when necessary. - Default: None. - nbins - Number of bins to use in the histogram used for ``otsu``, ``yen`` and - ``li`` thresholding. - Default: 100. - plot - Whether to plot region-topic distributions and their threshold. - Default: False. - figsize - Size of the figure. If num_columns is 1, this is the size for each figure. - If ``num_columns`` is above 1, this is the overall size of the figure. - If keeping the default, it will be the size of each subplot in the figure. - Default: (6.4, 4.8). - num_columns - For multiplot figures, indicates the number of columns (the number of rows will - be automatically determined based on the number of plots). - Default: 1. - save - Path to save plot. - Default: None. 
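The zero-variance guard added to compute_kde() above exists because scipy.stats.gaussian_kde raises LinAlgError on a singular data covariance matrix, which happens when, for example, the duplication ratio is 0.0 for every barcode. A small self-contained sketch of the failure and the 1e-12 nudge, on synthetic data:

import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(0)

# Row 0: log10 unique fragment counts; row 1: a duplication ratio that is 0.0 everywhere.
training_data = np.vstack(
    [
        np.log10(rng.integers(100, 10_000, size=1_000)),
        np.zeros(1_000),
    ]
)

try:
    gaussian_kde(training_data)
except np.linalg.LinAlgError as e:
    print(f"singular covariance: {e}")

# Nudge a single value, as the guard in compute_kde() does, so the covariance
# matrix is no longer singular.
training_data[1, 0] = 0.000000000001
kde = gaussian_kde(training_data)
print(kde(training_data[:, :3]))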
- - Returns - ------- - A dictionary containing a pd.DataFrame with the selected regions with region names - as indexes and a topic score column. - - References - ---------- - - Otsu, N., 1979. - A threshold selection method from gray-level histograms. - IEEE transactions on systems, man, and cybernetics, 9(1), pp.62-66. - - Yen, J.C., Chang, F.J. and Chang, S., 1995. - A new criterion for automatic multilevel thresholding. - IEEE Transactions on Image Processing, 4(3), pp.370-378. - - Li, C.H. and Lee, C.K., 1993. - Minimum cross entropy thresholding. - Pattern recognition, 26(4), pp.617-625. - - Van de Sande, B., Flerin, C., Davie, K., De Waegeneer, M., Hulselmans, G., - Aibar, S., Seurinck, R., Saelens, W., Cannoodt, R., Rouchon, Q. and - Verbeiren, T., 2020. - A scalable SCENIC workflow for single-cell gene regulatory network analysis. - Nature Protocols, 15(7), pp.2247-2276. - - """ - # Create cisTopic logger - level = logging.INFO - log_format = "%(asctime)s %(name)-12s %(levelname)-8s %(message)s" - handlers = [logging.StreamHandler(stream=sys.stdout)] - logging.basicConfig(level=level, format=log_format, handlers=handlers) - log = logging.getLogger("cisTopic") - - if target == "region": - topic_dist = cistopic_obj.selected_model.topic_region - elif target == "cell": - topic_dist = cistopic_obj.selected_model.cell_topic.T - - if smooth_topics: - topic_dist = smooth_topics_distributions(topic_dist) - - binarized_topics = {} - pdf = None - if (save is not None) and (num_columns == 1): - pdf = matplotlib.backends.backend_pdf.PdfPages(save) - - if num_columns > 1: - num_rows = int(np.ceil(topic_dist.shape[1] / num_columns)) - if figsize == (6.4, 4.8): - figsize = (6.4 * num_columns, 4.8 * num_rows) - - fig = plt.figure(figsize=figsize) - j = 1 - for i in range(topic_dist.shape[1]): - l = np.asarray(topic_dist.iloc[:, i]) - l_norm = (l - np.min(l)) / np.ptp(l) - if isinstance(predefined_thr, dict) and "Topic" + str(i + 1) in ( - list(predefined_thr.keys()) - ): - thr = predefined_thr["Topic" + str(i + 1)] - elif method == "otsu": - thr = threshold_otsu(l_norm, nbins=nbins) - elif method == "yen": - thr = threshold_yen(l_norm, nbins=nbins) - elif method == "li": - thresholds = np.arange(np.min(l_norm) + 0.01, np.max(l_norm) - 0.01, 0.01) - entropies = [cross_entropy(l_norm, t, nbins=nbins) for t in thresholds] - thr = thresholds[np.argmin(entropies)] - elif method == "aucell": - df, thr = binarization.binarize(pd.DataFrame(l_norm)) - thr = float(thr) - elif method == "ntop": - data = pd.DataFrame(l_norm).sort_values(0, ascending=False) - thr = float(data.iloc[ntop,]) - else: - log.info( - 'Binarization method not found. Please choose: "otsu", "yen", "li" or "ntop".' 
- ) - - if plot: - if num_columns > 1: - plt.subplot(num_rows, num_columns, j) - j = j + 1 - plt.hist(l_norm, bins=nbins) - plt.axvline(thr, color="tomato", linestyle="--") - plt.xlabel( - "Standardized probability Topic " - + str(i + 1) - + "\n" - + "Selected:" - + str(sum(l_norm > thr)), - fontsize=10, - ) - if num_columns == 1: - if pdf is not None: - pdf.savefig(fig, bbox_inches="tight") - if plot: - plt.show() - binarized_topics["Topic" + str(i + 1)] = pd.DataFrame( - topic_dist.iloc[l_norm > thr, i] - ).sort_values("Topic" + str(i + 1), ascending=False) - - if target == "region": - cistopic_obj.selected_model.topic_ass["Regions_in_binarized_topic"] = [ - binarized_topics[x].shape[0] for x in binarized_topics - ] - elif target == "cell": - cistopic_obj.selected_model.topic_ass["Cells_in_binarized_topic"] = [ - binarized_topics[x].shape[0] for x in binarized_topics - ] - - if num_columns > 1: - plt.tight_layout() - if save is not None: - fig.savefig(save, bbox_inches="tight") - if plot: - plt.show() - else: - plt.close() - - if pdf is not None: - pdf.close() - - return binarized_topics - def smooth_topics_distributions( - topic_region_distributions: pd.DataFrame, -) -> pd.DataFrame: + cell_or_region_topic_prob: npt.NDArray[np.float64], +) -> npt.NDArray[np.float64]: r""" Smooth topic-region distributions. @@ -216,9 +23,8 @@ def smooth_topics_distributions( Parameters ---------- - topic_region_distributions - A pandas dataframe with topic-region distributions - (with topics as columns and regions as rows). + cell_or_region_topic_prob + Numpy array containing cell or region topic probabilities with topics along columns. Returns ------- @@ -226,7 +32,9 @@ def smooth_topics_distributions( """ - def smooth_topic_distribution(x: np.ndarray) -> np.ndarray: + def smooth_topic_distribution( + x: npt.NDArray[np.float64], + ) -> npt.NDArray[np.float64]: """ Smooth topic-region distribution for a topic. @@ -240,21 +48,12 @@ def smooth_topic_distribution(x: np.ndarray) -> np.ndarray: Smoothed topic-region distribution for a topic. """ - return x * (np.log(x + 1e-100) - np.sum(np.log(x + 1e-100)) / x.shape[0]) + return x * (np.log(x + 1e-45) - np.sum(np.log(x + 1e-45)) / x.shape[0]) - smoothed_topic_region_distributions = pd.DataFrame( - np.apply_along_axis( - smooth_topic_distribution, - 1, - topic_region_distributions.values, - ), - index=topic_region_distributions.index, - columns=topic_region_distributions.columns, - ) - return smoothed_topic_region_distributions + return np.apply_along_axis(smooth_topic_distribution, 1, cell_or_region_topic_prob) -def threshold_yen(array: np.ndarray, nbins: int = 100) -> float: +def threshold_yen(array: npt.NDArray[np.float64], nbins: int = 100) -> float: """ Apply Yen threshold on topic-region distributions [Yen et al., 1995]. @@ -288,7 +87,7 @@ def threshold_yen(array: np.ndarray, nbins: int = 100) -> float: return bin_centers[crit.argmax()] -def threshold_otsu(array: np.ndarray, nbins: int = 100) -> float: +def threshold_otsu(array: npt.NDArray[np.float64], nbins: int = 100) -> float: """ Apply Otsu threshold on topic-region distributions [Otsu, 1979]. @@ -326,7 +125,9 @@ def threshold_otsu(array: np.ndarray, nbins: int = 100) -> float: return threshold -def cross_entropy(array: np.ndarray, threshold: float, nbins: int = 100) -> float: +def cross_entropy( + array: npt.NDArray[np.float64], threshold: float, nbins: int = 100 +) -> float: """ Calculate entropies for Li thresholding on topic-region distributions [Li & Lee, 1993]. 
@@ -362,8 +163,8 @@ def cross_entropy(array: np.ndarray, threshold: float, nbins: int = 100) -> floa def histogram_and_bin_centers( - array: np.ndarray, nbins: int = 100 -) -> tuple[np.ndarray, np.ndarray]: + array: npt.NDArray[np.float64], nbins: int = 100 +) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: """ Draw histogram from distribution and identify centers. @@ -383,3 +184,133 @@ def histogram_and_bin_centers( hist, bin_edges = np.histogram(array, bins=nbins, range=None) bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0 return hist, bin_centers + + +def threshold_li(array: npt.NDArray[np.float64], nbins) -> float: + thresholds = np.arange(np.min(array) + 0.01, np.max(array) - 0.01, 0.01) + entropies = [cross_entropy(array, t, nbins=nbins) for t in thresholds] + thr = thresholds[np.argmin(entropies)] + return thr + + +def threshold_aucell(array: npt.NDArray[np.float64]): + _, thr = binarization.binarize(pd.DataFrame(array)) + return float(thr) + + +def threshold_ntop(array: npt.NDArray[np.float64], ntop: int) -> float: + return np.sort(array)[::-1][ntop] + + +def binarize_topics( + cell_or_region_topic_prob: npt.NDArray[np.float64], + cell_or_region_names: list[str], + method: Literal["otsu", "ntop", "li", "yen", "aucell"] = "otsu", + smooth_topics: bool = True, + ntop: int | None = None, + nbins: int = 100, +) -> tuple[list[list[str]], list[npt.NDArray[np.float64]], list[float]]: + r""" + Binarize topic distributions. + + Parameters + ---------- + cell_or_region_topic_prob + Numpy array containing cell or region topic probabilities with topics along columns. + cell_or_region_names + A list of str containing cell or region names (should be the same length as the number of rows in `cell_or_region_topic_prob`) + method + Method to use for topic binarization. Possible options are: + - ``otsu`` [Otsu, 1979] + - ``yen`` [Yen et al., 1995] + - ``li`` [Li & Lee, 1993] + - ``aucell`` [Van de Sande et al., 2020] + - ``ntop`` [Taking the top n regions per topic] + + smooth_topics + Whether to smooth topics distributions to penalize regions enriched across many + topics. The following formula is applied: + + .. math:: + \beta_{w, k} (\log\beta_{w,k} - 1 / K \sum_{k'} \log \beta_{w,k'}) + ntop + Number of top regions to select when using ``method="ntop"``. + nbins + Number of bins to use in the histogram used for ``otsu``, ``yen`` and + ``li`` thresholding. + Default: 100. + + Returns + ------- + A list of string containing binarized cells or regions, an array of scores and a list of floats containing thresholds + + """ + # input validation + if len(cell_or_region_names) != cell_or_region_topic_prob.shape[0]: + raise ValueError( + f"{len(cell_or_region_names)} cells or region names provided while `cell_or_region_topic_prob` only has {cell_or_region_topic_prob.shape[0]} rows." + ) + + if len(cell_or_region_names) != len(set(cell_or_region_names)): + raise ValueError("`cell_or_region_names` contains duplicates.") + + if method == "ntop" and ntop is None: + raise ValueError( + "A value for ntop should be provided when using `ntop` as binarization method." 
+ ) + + method_to_bin_func: dict[str, Callable[[npt.NDArray[np.float64]], float]] = { + "otsu": partial(threshold_otsu, nbins=nbins), + "yen": partial(threshold_yen, nbins=nbins), + "li": partial(threshold_li, nbins=nbins), + "aucell": threshold_aucell, + "ntop": partial(threshold_ntop, ntop=ntop), # type: ignore + } + + bin_func = method_to_bin_func.get(method) + + if bin_func is None: + raise ValueError( + f'`method` should be one of "otsu", "ntop", "li", "yen", "aucell". Not {method}.' + ) + + # create index used for sorting + cell_or_region_names_idx = {x: i for i, x in enumerate(cell_or_region_names)} + + if smooth_topics: + cell_or_region_topic_prob = smooth_topics_distributions( + cell_or_region_topic_prob + ) + + cell_or_region_names_per_topic: list[list[str]] = [] + scores_per_topic: list[npt.NDArray[np.float64]] = [] + thresholds: list[float] = [] + + # iterate over topics + for i in range(cell_or_region_topic_prob.shape[1]): + # normalize between 0 and 1 + l_norm = ( + cell_or_region_topic_prob[:, i] - np.min(cell_or_region_topic_prob[:, i]) + ) / np.ptp(cell_or_region_topic_prob[:, i]) + # get threshold + thr = bin_func(l_norm) + # sort cell or region names based on l_norm, features with highest score first (reverse=True) + cell_or_region_names_sorted = sorted( + cell_or_region_names, + key=lambda x: l_norm[cell_or_region_names_idx[x]], + reverse=True, + ) + # get cell or regions passing threshold + l_norm_a_sort = np.argsort(l_norm)[::-1] + l_norm_sorted = l_norm[l_norm_a_sort] + cell_or_regions_passing_threshold = cell_or_region_names_sorted[ + 0 : np.where(l_norm_sorted > thr)[0].max() + 1 + ] + scores_passing_threshold = cell_or_region_topic_prob[l_norm_a_sort, i][ + 0 : np.where(l_norm_sorted > thr)[0].max() + 1 + ] + cell_or_region_names_per_topic.append(cell_or_regions_passing_threshold) + scores_per_topic.append(scores_passing_threshold) + thresholds.append(thr) + + return cell_or_region_names_per_topic, scores_per_topic, thresholds diff --git a/src/pycisTopic/topic_qc.py b/src/pycisTopic/topic_qc.py index 021fdfd..7e091b5 100644 --- a/src/pycisTopic/topic_qc.py +++ b/src/pycisTopic/topic_qc.py @@ -19,7 +19,8 @@ def compute_topic_metrics( - cistopic_obj: CistopicObject, return_metrics: bool = True + cistopic_obj: CistopicObject, + return_metrics: bool = True, ): """ Compute topic quality control metrics. @@ -235,7 +236,7 @@ def topic_annotation( annot_var: str, binarized_cell_topic: dict[str, pd.DataFrame] | None = None, general_topic_thr: float = 0.2, - **kwargs + **kwargs, ): """ Automatic annotation of topics. @@ -268,6 +269,7 @@ def topic_annotation( annot = cistopic_obj.cell_data[annot_var] if binarized_cell_topic is None: from pycisTopic.topic_binarization import binarize_topics + binarized_cell_topic = binarize_topics(cistopic_obj, target="cell", **kwargs) topic_annot_dict = {topic: [] for topic in cell_topic.index.tolist()} diff --git a/src/pycisTopic/tss_profile.py b/src/pycisTopic/tss_profile.py index f690d65..026a4c3 100644 --- a/src/pycisTopic/tss_profile.py +++ b/src/pycisTopic/tss_profile.py @@ -96,7 +96,8 @@ def get_tss_profile( # Extend TSS position with flanking window and only keep minimal necessary columns # needed to find the overlap with fragments. tss_annotation_with_flanking_window_df_pl = ( - tss_annotation.select( + tss_annotation.clone() + .select( # Only keep needed columns for faster Genomics Ranges / PyRanges join. 
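The rewritten binarize_topics() above operates on a bare probability matrix plus a list of names and returns per-topic name lists, scores and thresholds instead of a dictionary of DataFrames. A hedged toy sketch of the new call pattern; the matrix values and region names are made up:

import numpy as np
from pycisTopic.topic_binarization import binarize_topics

# Toy region-topic probability matrix: 5 regions x 2 topics.
region_topic_prob = np.array(
    [
        [0.90, 0.10],
        [0.80, 0.20],
        [0.05, 0.95],
        [0.10, 0.90],
        [0.50, 0.50],
    ]
)
region_names = [f"chr1:{i * 500}-{(i + 1) * 500}" for i in range(5)]

region_names_per_topic, scores_per_topic, thresholds = binarize_topics(
    cell_or_region_topic_prob=region_topic_prob,
    cell_or_region_names=region_names,
    method="otsu",
    smooth_topics=False,  # skip smoothing for this tiny toy matrix
    nbins=10,
)

for topic_idx, (names, thr) in enumerate(zip(region_names_per_topic, thresholds), 1):
    print(f"Topic{topic_idx}: threshold={thr:.3f}, regions={names}")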
pl.col("Chromosome").cast(pl.Categorical), pl.col("Start"), @@ -156,9 +157,8 @@ def get_tss_profile( } ) if use_genomic_ranges - else # Use pyranges to calculate the intersection. - pl.from_pandas( + else pl.from_pandas( ( # Create PyRanges object from filtered fragments Polars DataFrame. create_pyranges_from_polars_df(filtered_fragments_df_pl).join( @@ -287,11 +287,11 @@ def get_tss_profile( aggregate_function="len", ) # Remove "no_CB" cell barcode (was only needed for the pivot). - .filter(pl.col("CB") != "no_CB").with_columns( - # Fill in 0, for non-observed values in the pivot table. - pl.col(pl.UInt32) - .cast(pl.Int32) - .fill_null(0), + .filter(pl.col("CB") != "no_CB") + .with_columns( + # Fill in 0, for non-observed values in the pivot table after casting + # column from UInt32 (`polars`) or UInt64 (`polars-u64-idx`) to Int32. + pl.col(pl.get_index_type()).cast(pl.Int32).fill_null(0), ) ) @@ -313,7 +313,8 @@ def get_tss_profile( header_name="position_from_tss", # Add old "CB" column as column names. column_names=tss_matrix_tmp.get_column("CB"), - ).with_columns( + ) + .with_columns( # Convert "position_from_tss" column from pl.Utf8 to pl.Int32. pl.col("position_from_tss").cast(pl.Int32) ) @@ -336,37 +337,41 @@ def get_tss_profile( # Normalize smoothed TSS matrix. # Get normalized sample TSS enrichment per position from the per CB # smoothed TSS matrix. - tss_norm_matrix_sample = tss_smoothed_matrix_per_cb.select( - pl.col("position_from_tss"), - # Get total number of cut sites per position over all CBs. - pl.sum_horizontal(pl.all().exclude("position_from_tss")).alias( - "smoothed_per_pos_sum" - ), - ).select( - pl.col("position_from_tss"), - # Normalize total number of cut sites per position over all CBs. - ( - pl.col("smoothed_per_pos_sum") - / ( - # Calculate background value from start and end over - # minimum_signal_window length. - ( + tss_norm_matrix_sample = ( + tss_smoothed_matrix_per_cb.clone() + .select( + pl.col("position_from_tss"), + # Get total number of cut sites per position over all CBs. + pl.sum_horizontal(pl.all().exclude("position_from_tss")).alias( + "smoothed_per_pos_sum" + ), + ) + .select( + pl.col("position_from_tss"), + # Normalize total number of cut sites per position over all CBs. + ( + pl.col("smoothed_per_pos_sum") + / ( + # Calculate background value from start and end over + # minimum_signal_window length. ( - pl.col("smoothed_per_pos_sum") - .head(minimum_signal_window) - .mean() - + pl.col("smoothed_per_pos_sum") - .tail(minimum_signal_window) - .mean() + ( + pl.col("smoothed_per_pos_sum") + .head(minimum_signal_window) + .mean() + + pl.col("smoothed_per_pos_sum") + .tail(minimum_signal_window) + .mean() + ) + / 2 ) - / 2 + # Or use min_norm. + .append(min_norm) + # Take highest value. + .max() ) - # Or use min_norm. - .append(min_norm) - # Take highest value. - .max() - ) - ).alias("normalized_tss_enrichment"), + ).alias("normalized_tss_enrichment"), + ) ) # Get normalized TSS matrix per CB for each cut site position. 
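The pivot fix above selects count columns by dtype via pl.get_index_type(), so the cast-and-fill works whether the pivot produced UInt32 (regular polars) or UInt64 (polars-u64-idx) counts. A minimal sketch of that dtype-based column selection on a toy frame, assuming polars >= 1:

import polars as pl

# Toy pivot-like result: one count column carrying polars' index dtype, with a
# null standing in for a position that received no cut sites.
df = pl.DataFrame(
    {
        "CB": ["AAACGAA", "AAAGTCC"],
        "counts": pl.Series([3, None], dtype=pl.get_index_type()),
    }
)

# pl.col(dtype) selects every column of that dtype, independent of its name.
print(df.with_columns(pl.col(pl.get_index_type()).cast(pl.Int32).fill_null(0)))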
@@ -390,7 +395,7 @@ def get_tss_profile( .max() ) ).alias(CB) - for CB in tss_smoothed_matrix_per_cb.columns[1:] + for CB in tss_smoothed_matrix_per_cb.collect_schema().names()[1:] ] ) diff --git a/src/pycisTopic/utils.py b/src/pycisTopic/utils.py index e2bde3d..4fc42b8 100644 --- a/src/pycisTopic/utils.py +++ b/src/pycisTopic/utils.py @@ -2,16 +2,17 @@ import gc import gzip -import logging import math import os import re from pathlib import Path -from typing import Literal, Sequence, Union +from typing import Sequence, Union import matplotlib.backends.backend_pdf import matplotlib.pyplot as plt +import numba import numpy as np +import numpy.typing as npt import pandas as pd import polars as pl import pyranges as pr @@ -40,10 +41,9 @@ def coord_to_region_names(df_pl: pl.DataFrame) -> list[str]: Returns ------- - List of region names. - """ + """ df_pl.select( [ ( @@ -67,13 +67,12 @@ def region_names_to_coordinates(region_names: Sequence[str]) -> pd.DataFrame: Returns ------- - Pandas DataFrame with region IDs to coordinates mapping. - """ + """ region_df = ( pl.DataFrame( - data = {"RegionIDs": region_names}, + data={"RegionIDs": region_names}, ) .with_columns( pl.col("RegionIDs") @@ -84,7 +83,8 @@ def region_names_to_coordinates(region_names: Sequence[str]) -> pd.DataFrame: # Give sensible names to each splitted part. .struct.rename_fields( ["Chromosome", "Start", "End"], - ).alias("RegionIDsFields") + ) + .alias("RegionIDsFields") ) # Unpack "RegionIDsFields" struct column and create Chromosome", "Start" and "End" columns. .unnest("RegionIDsFields") @@ -126,7 +126,46 @@ def non_zero_rows(matrix: Union[sparse.csr_matrix, np.ndarray]): return np.nonzero(np.count_nonzero(matrix, axis=1))[0] +@numba.njit +def get_nonzero_row_indices(x: npt.NDArray[np.float32]): + """Get the indices of the rows that have at least one nonzero element.""" + # Optimized version of: + # np.nonzero(np.count_nonzero(x, axis=1))[0] + nonzero_row_indices = np.empty((x.shape[0],), dtype=np.intp) + output_idx = 0 + for i in range(x.shape[0]): + for j in range(x.shape[1]): + if x[i, j] != 0: + # Found a nonzero element in the row, so keep the row index and go to + # the next row. + nonzero_row_indices[output_idx] = i + output_idx += 1 + break + # Return row indices of nonzero rows (and truncate the array to the correct size). + return nonzero_row_indices[:output_idx] + + def loglikelihood(nzw, ndz, alpha, eta): + """ + Loglikelihood function to use with collapsed gibbs sampling LDA model from python `lda` package. + + The loglikelihood function in python `lda` package does not return what we want: + https://github.com/lda-project/lda/issues/102 + + The loglikelihood function is based on the following implementation: + https://github.com/slycoder/R-lda/blob/master/src/gibbs.c + + Parameters + ---------- + nzw + ndz + alpha + eta + + Returns + ------- + + """ D = ndz.shape[0] n_topics = ndz.shape[1] vocab_size = nzw.shape[1] @@ -202,6 +241,7 @@ def regions_overlap(target, query): ).to_list() return selected_regions + def prepare_tag_cells(cell_names, split_pattern="___"): if split_pattern == "-": new_cell_names = [ @@ -353,112 +393,6 @@ def get_tss_matrix(fragments, flank_window, tss_space_annotation): return TSS_matrix -def read_fragments_from_file( - fragments_bed_filename, use_polars: bool = True -) -> pr.PyRanges: - """ - Read fragments BED file to PyRanges object. - - Parameters - ---------- - fragments_bed_filename: Fragments BED filename. 
- use_polars: Use polars instead of pandas for reading the fragments BED file. - - Returns - ------- - PyRanges object of fragments. - """ - - bed_column_names = ( - "Chromosome", - "Start", - "End", - "Name", - "Score", - "Strand", - "ThickStart", - "ThickEnd", - "ItemRGB", - "BlockCount", - "BlockSizes", - "BlockStarts", - ) - - # Set the correct open function depending if the fragments BED file is gzip compressed or not. - open_fn = gzip.open if fragments_bed_filename.endswith(".gz") else open - - skip_rows = 0 - nbr_columns = 0 - with open_fn(fragments_bed_filename, "rt") as fragments_bed_fh: - for line in fragments_bed_fh: - # Remove newlines and spaces. - line = line.strip() - - if not line or line.startswith("#"): - # Count number of empty lines and lines which start with a comment before the actual data. - skip_rows += 1 - else: - # Get number of columns from the first real BED entry. - nbr_columns = len(line.split("\t")) - - # Stop reading the BED file. - break - - if nbr_columns < 4: - raise ValueError( - f'Fragments BED file needs to have at least 4 columns. "{fragments_bed_filename}" contains only ' - f"{nbr_columns} columns." - ) - - if use_polars: - import polars as pl - - # Read fragments BED file with polars. - df = ( - pl.read_csv( - fragments_bed_filename, - has_header=False, - skip_rows=skip_rows, - separator="\t", - use_pyarrow=False, - new_columns=bed_column_names[:nbr_columns], - ) - .with_columns( - [ - pl.col("Chromosome").cast(pl.Utf8), - pl.col("Start").cast(pl.Int32), - pl.col("End").cast(pl.Int32), - pl.col("Name").cast(pl.Utf8), - ] - ) - .to_pandas() - ) - - # Convert "Name" column to pd.Categorical as groupby operations will be done on it later. - df["Name"] = df["Name"].astype("category") - else: - # Read fragments BED file with pandas. - df = pd.read_table( - fragments_bed_filename, - sep="\t", - skiprows=skip_rows, - header=None, - names=bed_column_names[:nbr_columns], - doublequote=False, - engine="c", - dtype={ - "Chromosome": str, - "Start'": np.int32, - "End": np.int32, - "Name": "category", - "Strand": str, - }, - ) - - # Convert pandas dataframe to PyRanges dataframe. - # This will convert "Chromosome" and "Strand" columns to pd.Categorical. - return pr.PyRanges(df) - def coord_to_region_names(coord): """
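The numba-compiled get_nonzero_row_indices() added to utils.py above sits next to the existing count_nonzero-based non_zero_rows() and short-circuits each row at its first nonzero entry. A quick equivalence check on a tiny matrix, assuming pycisTopic is importable:

import numpy as np
from pycisTopic.utils import get_nonzero_row_indices, non_zero_rows

x = np.array(
    [
        [0.0, 0.0, 0.0],
        [0.0, 1.5, 0.0],
        [2.0, 0.0, 3.0],
        [0.0, 0.0, 0.0],
    ],
    dtype=np.float32,
)

# Both helpers should report rows 1 and 2.
print(non_zero_rows(x))            # -> [1 2]
print(get_nonzero_row_indices(x))  # -> [1 2]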