Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
23795d2
added option to switch reconstruction algorithms in the end2end pipeline
vivaansinghvi07 Sep 1, 2025
a0e3488
added a few more options in the benchmarking
vivaansinghvi07 Sep 1, 2025
f40fe50
allow parameterization of naive and shortcut algorithm
vivaansinghvi07 Sep 3, 2025
e046cd4
made the reconstruction something adjusted within the program
vivaansinghvi07 Sep 7, 2025
c659edc
collapse unifs everywhere for better distance calculation
vivaansinghvi07 Sep 8, 2025
ba7349f
more flexibility when running the script
vivaansinghvi07 Sep 8, 2025
2ab73ed
write out reconstruction algorithm in the dataframe
vivaansinghvi07 Sep 9, 2025
dfb1468
removed unused imports
vivaansinghvi07 Sep 14, 2025
18e8ed3
added double quotes around variable access in shell script
vivaansinghvi07 Sep 14, 2025
e645265
added more data generation code
vivaansinghvi07 Sep 16, 2025
9ae078e
fixed issue with file paths
vivaansinghvi07 Sep 16, 2025
51f16c8
fixed issue with fossil interval and none
vivaansinghvi07 Sep 16, 2025
11b5cfc
remove index when printing to csv
vivaansinghvi07 Sep 16, 2025
00d2746
added option to customize dstream algos
vivaansinghvi07 Sep 21, 2025
d36c2b1
output retention algo in the record
vivaansinghvi07 Sep 23, 2025
23d0535
fix the futurewarning in mark ancestor origin time
vivaansinghvi07 Oct 16, 2025
2b64f95
Merge pull request #263 from mmore500/master
mmore500 Oct 21, 2025
97e217e
Bump downstream pin
mmore500 Oct 22, 2025
37c3433
Improve telemetry
mmore500 Oct 22, 2025
3f951fd
Add progress bar and hybrid algo
mmore500 Oct 22, 2025
6f3b075
Add inner progress bar
mmore500 Oct 22, 2025
65f3cd3
Clean up subprocessing and add telemetry
mmore500 Oct 22, 2025
2c78c12
Add telemetry
mmore500 Oct 22, 2025
206e2dc
Force flushes from all prints
mmore500 Oct 22, 2025
16062ce
Redirect stdout to stderr instead of silencing
mmore500 Oct 22, 2025
c9743fd
fixup! Redirect stdout to stderr instead of silencing
mmore500 Oct 22, 2025
d074d3c
Further clean up telemetry
mmore500 Oct 22, 2025
c1ce009
fixup! Further clean up telemetry
mmore500 Oct 22, 2025
36c5c82
Merge branch 'end2end-naive-reconstruction' into sync-260
mmore500 Oct 26, 2025
b901c14
Merge pull request #265 from mmore500/sync-260
mmore500 Oct 26, 2025
6016a13
Strip out of date log message
mmore500 Oct 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 160 additions & 45 deletions examples/end2end_tree_reconstruction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,25 @@
from Bio.Phylo.BaseTree import Clade as BioClade
import alifedata_phyloinformatics_convert as apc
from colorclade import draw_colorclade_tree
from downstream import dstream
import matplotlib.pyplot as plt
import opytional as opyt
import pandas as pd
from teeplot import teeplot as tp
from tqdm import tqdm

from hstrat._auxiliary_lib import (
alifestd_calc_triplet_distance_asexual,
alifestd_collapse_unifurcations,
alifestd_count_leaf_nodes,
alifestd_delete_unifurcating_roots_asexual,
alifestd_mark_node_depth_asexual,
alifestd_prune_extinct_lineages_asexual,
alifestd_try_add_ancestor_list_col,
)
from hstrat.dataframe._surface_unpack_reconstruct import (
ReconstructionAlgorithm,
)


def to_ascii(
Expand All @@ -35,8 +41,9 @@ def to_ascii(
phylogeny_df, mutate=True
).drop(columns=["extant"])
phylogeny_df = alifestd_collapse_unifurcations(phylogeny_df, mutate=True)

dp_tree = apc.RosettaTree(phylogeny_df).as_dendropy
if dp_tree is None:
return "Tree is empty after visualization preprocessing"
for nd in dp_tree.preorder_node_iter():
nd._child_nodes.sort(
key=lambda nd: max(leaf.taxon.label for leaf in nd.leaf_iter()),
Expand All @@ -63,8 +70,19 @@ def sample_reference_and_reconstruction(
differentia_bitwidth: int,
surface_size: int,
fossil_interval: typing.Optional[int],
*,
no_preset_randomness: bool,
reconstruction_algorithm: ReconstructionAlgorithm,
retention_algo: str,
) -> typing.Dict[str, pd.DataFrame]:
"""Sample a reference phylogeny and corresponding reconstruction."""
print("sample_reference_and_reconstruction subprocess...", flush=True)
print(f" differentia_bitwidth: {differentia_bitwidth}", flush=True)
print(f" surface_size: {surface_size}", flush=True)
print(f" fossil_interval: {fossil_interval}", flush=True)
print(f" no_preset_randomness: {no_preset_randomness}", flush=True)
print(f" reconst algo: {reconstruction_algorithm.value}", flush=True)
print(f" retention_algo: {retention_algo}", flush=True)
try:
paths = subprocess.run(
[
Expand All @@ -78,14 +96,22 @@ def sample_reference_and_reconstruction(
["--fossil-interval", f"{fossil_interval}"]
* (fossil_interval is not None)
),
"--retention-algo",
f"{retention_algo}",
*(["--no-preset-randomness"] if no_preset_randomness else []),
],
check=True,
capture_output=True,
env=dict(
os.environ,
HSTRAT_RECONSTRUCTION_ALGO=reconstruction_algorithm.value,
),
stderr=None,
stdout=subprocess.PIPE,
text=True,
).stdout.strip()
except subprocess.CalledProcessError as e:
print(f"\033[33m{e.stdout}\033[0m") # color yellow
print(f"\033[31m{e.stderr}\033[0m") # color red
print(f"\033[33m{e.stdout}\033[0m", flush=True) # color yellow
print(f"\033[31m{e.stderr}\033[0m", flush=True) # color red
raise e

path_vars = dict() # outparam for exec
Expand All @@ -94,6 +120,9 @@ def sample_reference_and_reconstruction(
reconst_phylo_df = alifestd_try_add_ancestor_list_col(
load_df(path_vars["reconst_phylo_df_path"]),
) # ancestor_list column must be added to comply with alife standard
for fp in path_vars.values(): # these are temporary anyways
if isinstance(fp, str) and os.path.exists(fp):
os.remove(fp)

assert alifestd_count_leaf_nodes(
true_phylo_df
Expand Down Expand Up @@ -212,14 +241,19 @@ def display_reconstruction(
"""Print a sample of the reference and reconstructed phylogenies."""
show_taxa = (
frames["reconst_dropped_fossils"]["taxon_label"]
.apply(
lambda x: (
pd.NA if not isinstance(x, str) or x.startswith("Inner") else x
)
)
.dropna()
.sample(6, random_state=1)
)
print("ground-truth phylogeny sample:")
print(to_ascii(frames["exact_dropped_fossils"], show_taxa))
print()
print("reconstructed phylogeny sample:")
print(to_ascii(frames["reconst_dropped_fossils"], show_taxa))
print("ground-truth phylogeny sample:", flush=True)
print(to_ascii(frames["exact_dropped_fossils"], show_taxa), flush=True)
print(flush=True)
print("reconstructed phylogeny sample:", flush=True)
print(to_ascii(frames["reconst_dropped_fossils"], show_taxa), flush=True)

if create_plots:
for df in frames.values():
Expand All @@ -242,17 +276,25 @@ def test_reconstruct_one(
fossil_interval: typing.Optional[int],
*,
visualize: bool,
) -> typing.Dict[str, typing.Union[int, float, None]]:
no_preset_randomness: bool,
reconstruction_algorithm: ReconstructionAlgorithm,
retention_algo: str,
) -> typing.Dict[str, typing.Union[int, float, str, None]]:
"""Test the reconstruction of a single phylogeny."""
print("=" * 80)
print(f"surface_size: {surface_size}")
print(f"differentia_bitwidth: {differentia_bitwidth}")
print(f"fossil_interval: {fossil_interval}")
print("=" * 80, flush=True)
print(f"surface_size: {surface_size}", flush=True)
print(f"differentia_bitwidth: {differentia_bitwidth}", flush=True)
print(f"fossil_interval: {fossil_interval}", flush=True)
print(f"reconstruction_algorithm: {reconstruction_algorithm}", flush=True)
print(f"retention_algo: {retention_algo}", flush=True)

frames = sample_reference_and_reconstruction(
differentia_bitwidth,
surface_size,
fossil_interval,
no_preset_randomness=no_preset_randomness,
reconstruction_algorithm=reconstruction_algorithm,
retention_algo=retention_algo,
)

display_reconstruction(
Expand All @@ -263,18 +305,33 @@ def test_reconstruct_one(
create_plots=visualize,
)
reconstruction_error = alifestd_calc_triplet_distance_asexual(
alifestd_collapse_unifurcations(frames["exact"]), frames["reconst"]
alifestd_delete_unifurcating_roots_asexual(
alifestd_collapse_unifurcations(frames["exact"])
),
alifestd_delete_unifurcating_roots_asexual(
alifestd_collapse_unifurcations(frames["reconst"])
),
taxon_label_key="taxon_label",
)

reconstruction_error_dropped_fossils = (
alifestd_calc_triplet_distance_asexual(
alifestd_collapse_unifurcations(frames["exact_dropped_fossils"]),
frames["reconst_dropped_fossils"],
alifestd_delete_unifurcating_roots_asexual(
alifestd_collapse_unifurcations(
frames["exact_dropped_fossils"]
)
),
alifestd_delete_unifurcating_roots_asexual(
alifestd_collapse_unifurcations(
frames["reconst_dropped_fossils"]
)
),
taxon_label_key="taxon_label",
)
)

print(f"{reconstruction_error=}")
print(f"{reconstruction_error_dropped_fossils=}")
print(f"{reconstruction_error=}", flush=True)
print(f"{reconstruction_error_dropped_fossils=}", flush=True)
assert 0 <= reconstruction_error <= 1 # should be in the range [0,1]

return {
Expand All @@ -283,58 +340,116 @@ def test_reconstruct_one(
"fossil_interval": fossil_interval,
"error": reconstruction_error,
"error_dropped_fossils": reconstruction_error_dropped_fossils,
"reconstruction_algorithm": reconstruction_algorithm.value,
"retention_algorithm": retention_algo,
}


def _parse_args():
    """Parse command-line options for the end-to-end reconstruction test.

    Options
    -------
    --skip-visualization : flag
    --no-preset-randomness : flag
    --repeats : int, default 1
    --reconstruction-algorithm : one or more ``ReconstructionAlgorithm``
        members, default ``(ReconstructionAlgorithm.SHORTCUT,)``
    --fossil-interval : one or more ints, or the literal string ``None``
        (converted to Python ``None``), default ``(None, 200, 50)``
    --surface-size : one or more ints, default ``(256, 64, 16)``
    --differentia-bitwidth : one or more of 64, 16, 8, 1,
        default ``(64, 8, 1)``
    --retention-algo : one or more ``dstream.<name>`` strings, where
        ``<name>`` is any attribute of ``dstream`` ending in ``algo``,
        default ``("dstream.steady_algo",)``
    --output-path : str, default ``/tmp/end2end-reconstruction-error.csv``

    Returns
    -------
    argparse.Namespace
        The parsed arguments.

    Raises
    ------
    ValueError
        If more than one repeat is requested without
        ``--no-preset-randomness``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--skip-visualization", action="store_true")
    # NOTE: no early `return parser.parse_args()` here -- every option below
    # must be registered before parsing, otherwise it is silently unavailable.
    parser.add_argument("--no-preset-randomness", action="store_true")
    parser.add_argument("--repeats", type=int, default=1)
    parser.add_argument(
        "--reconstruction-algorithm",
        type=ReconstructionAlgorithm,
        choices=list(ReconstructionAlgorithm),
        nargs="+",
        default=(ReconstructionAlgorithm.SHORTCUT,),
    )
    parser.add_argument(
        "--fossil-interval",
        # the literal text "None" selects "no fossil interval"
        type=lambda val: None if val == "None" else int(val),
        nargs="+",
        default=(None, 200, 50),
    )
    parser.add_argument(
        "--surface-size", type=int, nargs="+", default=(256, 64, 16)
    )
    parser.add_argument(
        "--differentia-bitwidth",
        type=int,
        nargs="+",
        choices=(64, 16, 8, 1),
        default=(64, 8, 1),
    )
    parser.add_argument(
        "--retention-algo",
        type=str,
        nargs="+",
        # offer every name exported by dstream that ends in "algo"
        choices=[f"dstream.{x}" for x in dir(dstream) if x.endswith("algo")],
        default=("dstream.steady_algo",),
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default="/tmp/end2end-reconstruction-error.csv",
    )
    args = parser.parse_args()
    if args.repeats > 1 and not args.no_preset_randomness:
        raise ValueError(
            "No point in having more than 1 repeat if using preset random seeds."
        )
    return args


if __name__ == "__main__":
    # NOTE(review): raised recursion limit is presumably needed by the tree
    # processing helpers on deep phylogenies -- confirm.
    sys.setrecursionlimit(100000)
    args = _parse_args()
    print(args, flush=True)

    # Run one reconstruction test per parameter combination (times
    # `args.repeats`) and collect the per-run result dicts as rows.
    reconstruction_error_results = pd.DataFrame(
        [
            test_reconstruct_one(
                differentia_bitwidth,
                surface_size,
                fossil_interval,
                visualize=not args.skip_visualization,
                no_preset_randomness=args.no_preset_randomness,
                reconstruction_algorithm=reconstruction_algorithm,
                retention_algo=retention_algo,
            )
            for (
                fossil_interval,
                surface_size,
                differentia_bitwidth,
                reconstruction_algorithm,
                retention_algo,
            ) in tqdm(
                itertools.product(
                    args.fossil_interval,
                    args.surface_size,
                    args.differentia_bitwidth,
                    args.reconstruction_algorithm,
                    args.retention_algo,
                )
            )
            for _ in tqdm(range(args.repeats))
        ]
    ).sort_values(
        # descending order, so consecutive rows within a fossil_interval
        # group move toward smaller surface sizes / bitwidths
        ["fossil_interval", "surface_size", "differentia_bitwidth"],
        ascending=False,
    )
    reconstruction_error_results.to_csv(args.output_path, index=False)

    # if there is a preset random seed, we need to make sure that the
    # error increases with decreasing surface size and differentia bitwidth
    if not args.no_preset_randomness:
        tolerance = 0.02
        # NOTE(review): groupby drops missing keys by default, so rows whose
        # fossil_interval is None are presumably not checked here -- confirm
        # whether that is intended.
        for _fossil_interval, group in reconstruction_error_results.groupby(
            "fossil_interval"
        ):
            for first, second in itertools.pairwise(group.itertuples()):
                if second.error_dropped_fossils < first.error_dropped_fossils:  # type: ignore
                    msg = (
                        f"Reconstruction error of {first.error_dropped_fossils} from run "  # type: ignore
                        f"{first.differentia_bitwidth}-{first.surface_size}-{opyt.apply_if(first.fossil_interval, int)} "  # type: ignore
                        f" unexpectedly higher than {second.error_dropped_fossils} from run "  # type: ignore
                        f"{second.differentia_bitwidth}-{second.surface_size}-{opyt.apply_if(second.fossil_interval, int)}"  # type: ignore
                    )
                    if (
                        first.error_dropped_fossils - second.error_dropped_fossils  # type: ignore
                        < tolerance
                    ):
                        # small inversions are reported but tolerated
                        print(msg, flush=True)
                        print(
                            "Difference within error tolerance, continuing...",
                            flush=True,
                        )
                    else:
                        raise ValueError(msg)
16 changes: 9 additions & 7 deletions examples/end2end_tree_reconstruction_with_dstream_surf.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,31 @@ set -euo pipefail

has_cppimport="$(python3 -m pip freeze | grep '^cppimport==' | wc -l)"
if [ "${has_cppimport}" -eq 0 ]; then
echo "cppimport required for $(basename "$0") but not installed."
echo "python3 -m pip install cppimport"
echo "cppimport required for $(basename "$0") but not installed." >&2
echo "python3 -m pip install cppimport" >&2
exit 1
fi

cd "$(dirname "$0")"

genome_df_path="/tmp/end2end-raw-genome-evolve_surf_dstream.pqt"
true_phylo_df_path="/tmp/end2end-true-phylo-evolve_surf_dstream.csv"
reconst_phylo_df_path="/tmp/end2end-reconst-phylo-evolve_surf_dstream.pqt"
id="$(date +"%H-%M-%S")-$(uuidgen)"
genome_df_path="/tmp/end2end-raw-genome-evolve_surf_dstream_$id.pqt"
true_phylo_df_path="/tmp/end2end-true-phylo-evolve_surf_dstream_$id.csv"
reconst_phylo_df_path="/tmp/end2end-reconst-phylo-evolve_surf_dstream_$id.pqt"

# generate data
./evolve_dstream_surf.py \
"$@" \
--genome-df-path "${genome_df_path}" \
--phylo-df-path "${true_phylo_df_path}" \
>/dev/null 2>&1
>&2

# do reconstruction
ls "${genome_df_path}" | python3 -m \
hstrat.dataframe.surface_unpack_reconstruct \
"${reconst_phylo_df_path}" \
>/dev/null 2>&1
--reconstruction-algorithm "${HSTRAT_RECONSTRUCTION_ALGO:-shortcut}" \
>&2

# log output paths
echo "genome_df_path = '${genome_df_path}'"
Expand Down
Loading
Loading