diff --git a/examples/end2end_tree_reconstruction_test.py b/examples/end2end_tree_reconstruction_test.py index 23187ca8e..01b0596b5 100755 --- a/examples/end2end_tree_reconstruction_test.py +++ b/examples/end2end_tree_reconstruction_test.py @@ -10,19 +10,25 @@ from Bio.Phylo.BaseTree import Clade as BioClade import alifedata_phyloinformatics_convert as apc from colorclade import draw_colorclade_tree +from downstream import dstream import matplotlib.pyplot as plt import opytional as opyt import pandas as pd from teeplot import teeplot as tp +from tqdm import tqdm from hstrat._auxiliary_lib import ( alifestd_calc_triplet_distance_asexual, alifestd_collapse_unifurcations, alifestd_count_leaf_nodes, + alifestd_delete_unifurcating_roots_asexual, alifestd_mark_node_depth_asexual, alifestd_prune_extinct_lineages_asexual, alifestd_try_add_ancestor_list_col, ) +from hstrat.dataframe._surface_unpack_reconstruct import ( + ReconstructionAlgorithm, +) def to_ascii( @@ -35,8 +41,10 @@ def to_ascii( phylogeny_df, mutate=True ).drop(columns=["extant"]) phylogeny_df = alifestd_collapse_unifurcations(phylogeny_df, mutate=True) dp_tree = apc.RosettaTree(phylogeny_df).as_dendropy + if dp_tree is None: + return "Tree is empty after visualization preprocessing" for nd in dp_tree.preorder_node_iter(): nd._child_nodes.sort( key=lambda nd: max(leaf.taxon.label for leaf in nd.leaf_iter()), @@ -63,8 +70,19 @@ def sample_reference_and_reconstruction( differentia_bitwidth: int, surface_size: int, fossil_interval: typing.Optional[int], + *, + no_preset_randomness: bool, + reconstruction_algorithm: ReconstructionAlgorithm, + retention_algo: str, ) -> typing.Dict[str, pd.DataFrame]: """Sample a reference phylogeny and corresponding reconstruction.""" + print("sample_reference_and_reconstruction subprocess...", flush=True) + print(f" differentia_bitwidth: {differentia_bitwidth}", flush=True) + print(f" surface_size: {surface_size}", flush=True) + print(f" fossil_interval: {fossil_interval}", 
flush=True) + print(f" no_preset_randomness: {no_preset_randomness}", flush=True) + print(f" reconst algo: {reconstruction_algorithm.value}", flush=True) + print(f" retention_algo: {retention_algo}", flush=True) try: paths = subprocess.run( [ @@ -78,14 +96,22 @@ def sample_reference_and_reconstruction( ["--fossil-interval", f"{fossil_interval}"] * (fossil_interval is not None) ), + "--retention-algo", + f"{retention_algo}", + *(["--no-preset-randomness"] if no_preset_randomness else []), ], check=True, - capture_output=True, + env=dict( + os.environ, + HSTRAT_RECONSTRUCTION_ALGO=reconstruction_algorithm.value, + ), + stderr=None, + stdout=subprocess.PIPE, text=True, ).stdout.strip() except subprocess.CalledProcessError as e: - print(f"\033[33m{e.stdout}\033[0m") # color yellow - print(f"\033[31m{e.stderr}\033[0m") # color red + print(f"\033[33m{e.stdout}\033[0m", flush=True) # color yellow + print(f"\033[31m{e.stderr}\033[0m", flush=True) # color red raise e path_vars = dict() # outparam for exec @@ -94,6 +120,9 @@ def sample_reference_and_reconstruction( reconst_phylo_df = alifestd_try_add_ancestor_list_col( load_df(path_vars["reconst_phylo_df_path"]), ) # ancestor_list column must be added to comply with alife standard + for fp in path_vars.values(): # these are temporary anyways + if isinstance(fp, str) and os.path.exists(fp): + os.remove(fp) assert alifestd_count_leaf_nodes( true_phylo_df @@ -212,14 +241,19 @@ def display_reconstruction( """Print a sample of the reference and reconstructed phylogenies.""" show_taxa = ( frames["reconst_dropped_fossils"]["taxon_label"] + .apply( + lambda x: ( + pd.NA if not isinstance(x, str) or x.startswith("Inner") else x + ) + ) .dropna() .sample(6, random_state=1) ) - print("ground-truth phylogeny sample:") - print(to_ascii(frames["exact_dropped_fossils"], show_taxa)) - print() - print("reconstructed phylogeny sample:") - print(to_ascii(frames["reconst_dropped_fossils"], show_taxa)) + print("ground-truth phylogeny sample:", 
flush=True) + print(to_ascii(frames["exact_dropped_fossils"], show_taxa), flush=True) + print(flush=True) + print("reconstructed phylogeny sample:", flush=True) + print(to_ascii(frames["reconst_dropped_fossils"], show_taxa), flush=True) if create_plots: for df in frames.values(): @@ -242,17 +276,25 @@ def test_reconstruct_one( fossil_interval: typing.Optional[int], *, visualize: bool, -) -> typing.Dict[str, typing.Union[int, float, None]]: + no_preset_randomness: bool, + reconstruction_algorithm: ReconstructionAlgorithm, + retention_algo: str, +) -> typing.Dict[str, typing.Union[int, float, str, None]]: """Test the reconstruction of a single phylogeny.""" - print("=" * 80) - print(f"surface_size: {surface_size}") - print(f"differentia_bitwidth: {differentia_bitwidth}") - print(f"fossil_interval: {fossil_interval}") + print("=" * 80, flush=True) + print(f"surface_size: {surface_size}", flush=True) + print(f"differentia_bitwidth: {differentia_bitwidth}", flush=True) + print(f"fossil_interval: {fossil_interval}", flush=True) + print(f"reconstruction_algorithm: {reconstruction_algorithm}", flush=True) + print(f"retention_algo: {retention_algo}", flush=True) frames = sample_reference_and_reconstruction( differentia_bitwidth, surface_size, fossil_interval, + no_preset_randomness=no_preset_randomness, + reconstruction_algorithm=reconstruction_algorithm, + retention_algo=retention_algo, ) display_reconstruction( @@ -263,18 +305,33 @@ def test_reconstruct_one( create_plots=visualize, ) reconstruction_error = alifestd_calc_triplet_distance_asexual( - alifestd_collapse_unifurcations(frames["exact"]), frames["reconst"] + alifestd_delete_unifurcating_roots_asexual( + alifestd_collapse_unifurcations(frames["exact"]) + ), + alifestd_delete_unifurcating_roots_asexual( + alifestd_collapse_unifurcations(frames["reconst"]) + ), + taxon_label_key="taxon_label", ) reconstruction_error_dropped_fossils = ( alifestd_calc_triplet_distance_asexual( - 
alifestd_collapse_unifurcations(frames["exact_dropped_fossils"]), - frames["reconst_dropped_fossils"], + alifestd_delete_unifurcating_roots_asexual( + alifestd_collapse_unifurcations( + frames["exact_dropped_fossils"] + ) + ), + alifestd_delete_unifurcating_roots_asexual( + alifestd_collapse_unifurcations( + frames["reconst_dropped_fossils"] + ) + ), + taxon_label_key="taxon_label", ) ) - print(f"{reconstruction_error=}") - print(f"{reconstruction_error_dropped_fossils=}") + print(f"{reconstruction_error=}", flush=True) + print(f"{reconstruction_error_dropped_fossils=}", flush=True) assert 0 <= reconstruction_error <= 1 # should be in the range [0,1] return { @@ -283,18 +340,63 @@ def test_reconstruct_one( "fossil_interval": fossil_interval, "error": reconstruction_error, "error_dropped_fossils": reconstruction_error_dropped_fossils, + "reconstruction_algorithm": reconstruction_algorithm.value, + "retention_algorithm": retention_algo, } def _parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--skip-visualization", action="store_true") - return parser.parse_args() + parser.add_argument("--no-preset-randomness", action="store_true") + parser.add_argument("--repeats", type=int, default=1) + parser.add_argument( + "--reconstruction-algorithm", + type=ReconstructionAlgorithm, + choices=list(ReconstructionAlgorithm), + nargs="+", + default=(ReconstructionAlgorithm.SHORTCUT,), + ) + parser.add_argument( + "--fossil-interval", + type=lambda val: None if val == "None" else int(val), + nargs="+", + default=(None, 200, 50), + ) + parser.add_argument( + "--surface-size", type=int, nargs="+", default=(256, 64, 16) + ) + parser.add_argument( + "--differentia-bitwidth", + type=int, + nargs="+", + choices=(64, 16, 8, 1), + default=(64, 8, 1), + ) + parser.add_argument( + "--retention-algo", + type=str, + nargs="+", + choices=[f"dstream.{x}" for x in dir(dstream) if x.endswith("algo")], + default=("dstream.steady_algo",), + ) + parser.add_argument( + 
"--output-path", + type=str, + default="/tmp/end2end-reconstruction-error.csv", + ) + args = parser.parse_args() + if args.repeats > 1 and not args.no_preset_randomness: + raise ValueError( + "No point in having more than 1 repeat if using preset random seeds." + ) + return args if __name__ == "__main__": sys.setrecursionlimit(100000) args = _parse_args() + print(args, flush=True) reconstruction_error_results = pd.DataFrame( [ test_reconstruct_one( @@ -302,39 +404,52 @@ def _parse_args(): surface_size, fossil_interval, visualize=not args.skip_visualization, + no_preset_randomness=args.no_preset_randomness, + reconstruction_algorithm=reconstruction_algorithm, + retention_algo=retention_algo ) for ( fossil_interval, surface_size, differentia_bitwidth, - ) in itertools.product((None, 50, 200), (256, 64, 16), (64, 8, 1)) + reconstruction_algorithm, + retention_algo, + ) in tqdm(itertools.product( + args.fossil_interval, + args.surface_size, + args.differentia_bitwidth, + args.reconstruction_algorithm, + args.retention_algo, + )) + for _ in tqdm(range(args.repeats)) ] ).sort_values( ["fossil_interval", "surface_size", "differentia_bitwidth"], ascending=False, ) - reconstruction_error_results.to_csv( - "/tmp/end2end-reconstruction-error.csv", - ) - - # error should increase with decreasing surface size - tolerance = 0.02 - for f, x in reconstruction_error_results.groupby("fossil_interval"): - for first, second in itertools.pairwise(x.itertuples()): - if second.error_dropped_fossils < first.error_dropped_fossils: # type: ignore - msg = ( - f"Reconstruction error of {first.error_dropped_fossils} from run " # type: ignore - f"{first.differentia_bitwidth}-{first.surface_size}-{opyt.apply_if(first.fossil_interval, int)} " # type: ignore - f" unexpectedly higher than {second.error_dropped_fossils} from run " # type: ignore - f"{second.differentia_bitwidth}-{second.surface_size}-{opyt.apply_if(second.fossil_interval, int)}" # type: ignore - ) - if ( - 
first.error_dropped_fossils - second.error_dropped_fossils # type: ignore - < tolerance - ): - print(msg) - print( - "Difference is within error tolerance, continuing..." + reconstruction_error_results.to_csv(args.output_path, index=False) + + # if there is a preset random seed, we need to make sure that the + # error increases with decreasing surface size and differentia bitwidth + if not args.no_preset_randomness: + tolerance = 0.02 + for f, x in reconstruction_error_results.groupby("fossil_interval"): + for first, second in itertools.pairwise(x.itertuples()): + if second.error_dropped_fossils < first.error_dropped_fossils: # type: ignore + msg = ( + f"Reconstruction error of {first.error_dropped_fossils} from run " # type: ignore + f"{first.differentia_bitwidth}-{first.surface_size}-{opyt.apply_if(first.fossil_interval, int)} " # type: ignore + f" unexpectedly higher than {second.error_dropped_fossils} from run " # type: ignore + f"{second.differentia_bitwidth}-{second.surface_size}-{opyt.apply_if(second.fossil_interval, int)}" # type: ignore ) - else: - raise ValueError(msg) + if ( + first.error_dropped_fossils - second.error_dropped_fossils # type: ignore + < tolerance + ): + print(msg, flush=True) + print( + "Difference within error tolerance, continuing...", + flush=True, + ) + else: + raise ValueError(msg) diff --git a/examples/end2end_tree_reconstruction_with_dstream_surf.sh b/examples/end2end_tree_reconstruction_with_dstream_surf.sh index 2caee29fa..9d22fac84 100755 --- a/examples/end2end_tree_reconstruction_with_dstream_surf.sh +++ b/examples/end2end_tree_reconstruction_with_dstream_surf.sh @@ -4,29 +4,31 @@ set -euo pipefail has_cppimport="$(python3 -m pip freeze | grep '^cppimport==' | wc -l)" if [ "${has_cppimport}" -eq 0 ]; then - echo "cppimport required for $(basename "$0") but not installed." - echo "python3 -m pip install cppimport" + echo "cppimport required for $(basename "$0") but not installed." 
>&2 + echo "python3 -m pip install cppimport" >&2 exit 1 fi cd "$(dirname "$0")" -genome_df_path="/tmp/end2end-raw-genome-evolve_surf_dstream.pqt" -true_phylo_df_path="/tmp/end2end-true-phylo-evolve_surf_dstream.csv" -reconst_phylo_df_path="/tmp/end2end-reconst-phylo-evolve_surf_dstream.pqt" +id="$(date +"%H-%M-%S")-$(uuidgen)" +genome_df_path="/tmp/end2end-raw-genome-evolve_surf_dstream_$id.pqt" +true_phylo_df_path="/tmp/end2end-true-phylo-evolve_surf_dstream_$id.csv" +reconst_phylo_df_path="/tmp/end2end-reconst-phylo-evolve_surf_dstream_$id.pqt" # generate data ./evolve_dstream_surf.py \ "$@" \ --genome-df-path "${genome_df_path}" \ --phylo-df-path "${true_phylo_df_path}" \ - >/dev/null 2>&1 + >&2 # do reconstruction ls "${genome_df_path}" | python3 -m \ hstrat.dataframe.surface_unpack_reconstruct \ "${reconst_phylo_df_path}" \ - >/dev/null 2>&1 + --reconstruction-algorithm "${HSTRAT_RECONSTRUCTION_ALGO:-shortcut}" \ + >&2 # log output paths echo "genome_df_path = '${genome_df_path}'" diff --git a/examples/evolve_dstream_surf.py b/examples/evolve_dstream_surf.py index 15fac7dc6..96aa39b7f 100755 --- a/examples/evolve_dstream_surf.py +++ b/examples/evolve_dstream_surf.py @@ -26,7 +26,7 @@ print("python3 -m pip install phylotrackpy") raise e -evolution_selector = random.Random(1) # ensure consistent true phylogeny +evolution_selector: random.Random def consistent_state_randrange(bitwidth: int): @@ -126,6 +126,7 @@ def make_Organism( surf_dtype = { 1: np.uint8, 8: np.uint8, + 16: np.uint16, 64: np.uint64, }[differentia_bitwidth] empty_surface = np.empty(surface_size, dtype=surf_dtype) @@ -313,6 +314,13 @@ def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--differentia-bitwidth", type=int, default=1) parser.add_argument("--surface-size", type=int, default=64) + parser.add_argument( + "--fossil-interval", + type=int, + ) + parser.add_argument( + "--retention-algo", type=str, default="dstream.steady_algo" + ) 
parser.add_argument( "--genome-df-path", type=str, @@ -324,13 +332,13 @@ def _parse_args() -> argparse.Namespace: default="/tmp/phylo-evolve_surf_dstream.csv", ) parser.add_argument( - "--fossil-interval", - type=int, + "--no-preset-randomness", + action="store_true", ) args = parser.parse_args() - if args.differentia_bitwidth not in (1, 8, 64): + if args.differentia_bitwidth not in (1, 8, 16, 64): raise NotImplementedError() if args.surface_size < 8: @@ -358,16 +366,21 @@ def _get_df_save_handler(path: str) -> typing.Callable: if __name__ == "__main__": - np.random.seed(2) # ensure reproducibility - random.seed(2) - args = _parse_args() + if args.no_preset_randomness: + evolution_selector = random.Random() + else: + # ensure consistent true phylogeny + evolution_selector = random.Random(1) + np.random.seed(2) # ensure reproducibility + random.seed(2) + # configure organism class syst = systematics.Systematics(lambda x: x.uid) # each org is own taxon syst.add_snapshot_fun(systematics.Taxon.get_info, "taxon_label") Organism = make_Organism( - dstream_algo=dstream.steady_algo, + dstream_algo=eval(args.retention_algo, {"dstream": dstream}), differentia_bitwidth=args.differentia_bitwidth, surface_size=args.surface_size, syst=syst, diff --git a/examples/naive_shortcut_comparison_data_generation.sh b/examples/naive_shortcut_comparison_data_generation.sh new file mode 100755 index 000000000..d8f3632c3 --- /dev/null +++ b/examples/naive_shortcut_comparison_data_generation.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +set -euo pipefail # stop running if something errs (ex. 
a KB interrupt) + +if [[ $# -ne 2 ]]; then + echo "Must pass in two arguments, the first denoting the maximum number of jobs running at a time and the second denoting the number of samples" + exit 1 +fi + +echo "which python3 $(which python3)" +python3 --version +python3 -m pip show downstream +python3 examples/end2end_tree_reconstruction_test.py --help 2>&1 | head -n 4 +python3 examples/end2end_tree_reconstruction_test.py --help 2>&1 | head -n 4 & +wait + +for i in $(seq 1 $2); do + while [[ $(jobs | wc -l) -ge $1 ]]; do + sleep 1 + done + echo "Spawning job $i" + python3 examples/end2end_tree_reconstruction_test.py \ + --skip-vis \ + --no-preset-random \ + --repeats 1 \ + --fossil-interval None 200 50 \ + --reconstruction-algo shortcut naive \ + --retention-algo dstream.steady_algo dstream.tilted_algo dstream.stretched_algo dstream.hybrid_0_steady_1_tilted_2_algo \ + --differentia-bitwidth 64 8 1 \ + --surface-size 256 32 16 \ + --output-path end2end-reconstruction-error-$i.csv \ + 2>&1 | tee "run-$i.log" & +done + +wait + +zip data.zip end2end-reconstruction-error-*.csv +zip archive.zip run-*.log diff --git a/hstrat/_auxiliary_lib/_alifestd_mark_ancestor_origin_time_asexual.py b/hstrat/_auxiliary_lib/_alifestd_mark_ancestor_origin_time_asexual.py index 4a5376c39..1dec53ece 100644 --- a/hstrat/_auxiliary_lib/_alifestd_mark_ancestor_origin_time_asexual.py +++ b/hstrat/_auxiliary_lib/_alifestd_mark_ancestor_origin_time_asexual.py @@ -70,6 +70,6 @@ def alifestd_mark_ancestor_origin_time_asexual( for idx in reversed(phylogeny_df.index): ancestor_id = phylogeny_df.at[idx, "ancestor_id"] ancestor_origin_time = phylogeny_df.at[ancestor_id, "origin_time"] - phylogeny_df.at[idx, "ancestor_origin_time"] = ancestor_origin_time + phylogeny_df.at[idx, "ancestor_origin_time"] = int(ancestor_origin_time) return phylogeny_df diff --git a/hstrat/dataframe/_surface_unpack_reconstruct.py b/hstrat/dataframe/_surface_unpack_reconstruct.py index efceb66a6..b7f544a40 100644 --- 
a/hstrat/dataframe/_surface_unpack_reconstruct.py +++ b/hstrat/dataframe/_surface_unpack_reconstruct.py @@ -1,4 +1,5 @@ import contextlib +from enum import Enum import logging import math import multiprocessing @@ -6,6 +7,7 @@ import typing import uuid +import downstream as dstream from downstream import dataframe as dstream_dataframe import pandas as pd import polars as pl @@ -20,12 +22,19 @@ log_memory_usage, render_polars_snapshot, ) +from ..phylogenetic_inference.tree import build_tree_trie from ..phylogenetic_inference.tree._impl._build_tree_searchtable_cpp_impl_stub import ( Records, collapse_unifurcations, extend_tree_searchtable_cpp_from_exploded, extract_records_to_dict, ) +from ..serialization import surf_from_hex + + +class ReconstructionAlgorithm(Enum): + NAIVE = "naive" + SHORTCUT = "shortcut" def _sort_Tbar_argv( @@ -331,6 +340,7 @@ def _surface_unpacked_reconstruct( ) -> pl.DataFrame: """Reconstruct phylogenetic tree from unpacked dstream data.""" logging.info("building tree searchtable chunkwise...") + records = _build_records_chunked( slices, collapse_unif_freq=collapse_unif_freq, @@ -401,6 +411,7 @@ def surface_unpack_reconstruct( collapse_unif_freq: int = 1, exploded_slice_size: int = 1_000_000, mp_context: str = "spawn", + reconstruction_algorithm: ReconstructionAlgorithm = ReconstructionAlgorithm.SHORTCUT, ) -> pl.DataFrame: """Unpack dstream buffer and counter from genome data and construct an estimated phylogenetic tree for the genomes. @@ -455,6 +466,11 @@ def surface_unpack_reconstruct( mp_context : str, default 'spawn' Multiprocessing context to use for parallel processing. + reconstruction_algorithm : ReconstructionAlgorithm, default SHORTCUT + Reconstruction algorithm to use. ReconstructionAlgorithm.SHORTCUT + should nearly always be used, but ReconstructionAlgorithm.NAIVE + is also supported for benchmarking and evaluation purposes. 
+ Returns ------- pl.DataFrame @@ -524,18 +540,63 @@ def surface_unpack_reconstruct( ) differentia_bitwidth = dstream_storage_bitwidth // dstream_S logging.info(f" - differentia bitwidth: {differentia_bitwidth}") - - logging.info("dispatching to surface_unpacked_reconstruct") - with _generate_exploded_slices_mp( - df, exploded_slice_size, mp_context - ) as slices: - phylo_df = _surface_unpacked_reconstruct( - slices, - collapse_unif_freq=collapse_unif_freq, - differentia_bitwidth=differentia_bitwidth, - dstream_S=dstream_S, - exploded_slice_size=exploded_slice_size, + if reconstruction_algorithm == ReconstructionAlgorithm.NAIVE: + if "downstream_exclude_exploded" in df.columns: + df = df.filter( + pl.col("downstream_exclude_exploded").not_().fill_null(True) + ).drop("downstream_exclude_exploded") + if "dstream_data_id" not in df.columns: + df = df.with_row_index("dstream_data_id").with_columns( + pl.col("dstream_data_id").cast(pl.UInt64) + ) + population = [ + surf_from_hex( + hex, + eval(algo, {"dstream": dstream.dstream}), + dstream_S=S, + dstream_T_bitwidth=T_bitwidth, + dstream_T_bitoffset=T_bitoffset, + dstream_storage_bitwidth=storage_bitwidth, + dstream_storage_bitoffset=storage_bitoffset, + ) + for hex, algo, S, T_bitwidth, T_bitoffset, storage_bitwidth, storage_bitoffset in df.lazy() + .select( + [ + "data_hex", + "dstream_algo", + "dstream_S", + "dstream_T_bitwidth", + "dstream_T_bitoffset", + "dstream_storage_bitwidth", + "dstream_storage_bitoffset", + ] + ) + .collect() + .iter_rows() + ] + phylo_df = pl.from_pandas( + build_tree_trie( + population, + [*df.lazy().collect()["taxon_label"]], + force_common_ancestry=True, + ) + ).join( + df.select(["taxon_label", "dstream_data_id"]).lazy().collect(), + on="taxon_label", + how="left", ) + else: + logging.info("dispatching to surface_unpacked_reconstruct") + with _generate_exploded_slices_mp( + df, exploded_slice_size, mp_context + ) as slices: + phylo_df = _surface_unpacked_reconstruct( + slices, + 
collapse_unif_freq=collapse_unif_freq, + differentia_bitwidth=differentia_bitwidth, + dstream_S=dstream_S, + exploded_slice_size=exploded_slice_size, + ) logging.info("joining user-defined columns...") with log_context_duration("_join_user_defined_columns", logging.info): diff --git a/hstrat/dataframe/surface_unpack_reconstruct.py b/hstrat/dataframe/surface_unpack_reconstruct.py index 8454a7a0c..ac1c9e3e0 100644 --- a/hstrat/dataframe/surface_unpack_reconstruct.py +++ b/hstrat/dataframe/surface_unpack_reconstruct.py @@ -12,7 +12,10 @@ get_hstrat_version, log_context_duration, ) -from ._surface_unpack_reconstruct import surface_unpack_reconstruct +from ._surface_unpack_reconstruct import ( + ReconstructionAlgorithm, + surface_unpack_reconstruct, +) raw_message = f"""{os.path.basename(__file__)} | (hstrat v{get_hstrat_version()}/joinem v{joinem.__version__}) @@ -158,12 +161,19 @@ def _create_parser() -> argparse.ArgumentParser: default=1_000_000, help="Number of rows to process at once. Low values reduce memory use.", ) + parser.add_argument( + "--reconstruction-algorithm", + type=ReconstructionAlgorithm, + default=ReconstructionAlgorithm.SHORTCUT, + choices=list(ReconstructionAlgorithm), + help='Phylogenetic tree reconstruction algorithm to use. "shortcut" should be used unless benchmarking naive approach.', + ) return parser def _main(mp_context: str) -> None: parser = _create_parser() - args, __ = parser.parse_known_args() + args, _ = parser.parse_known_args() with log_context_duration( "hstrat.dataframe.surface_unpack_reconstruct", logging.info @@ -174,6 +184,7 @@ def _main(mp_context: str) -> None: surface_unpack_reconstruct, collapse_unif_freq=args.collapse_unif_freq, exploded_slice_size=args.exploded_slice_size, + reconstruction_algorithm=args.reconstruction_algorithm, mp_context=mp_context, ), )