Skip to content

Commit 4b89528

Browse files
authored
Integrate canns-ripser with progress bar support (#24)
* Switch to canns-ripser and add progress bar support

  Replaced ripser with canns-ripser for persistent homology computations and added progress bar configuration throughout TDA analysis functions. Updated dependencies in pyproject.toml and improved logging for analysis steps. Also updated example usage and made related dependency and lock file changes.

* Remove old ripser dependency in favor of canns-ripser

  This completes the migration from the original ripser package to canns-ripser, eliminating potential conflicts and ensuring consistent behavior.
1 parent 30b8f30 commit 4b89528

File tree

4 files changed

+1197
-924
lines changed

4 files changed

+1197
-924
lines changed

examples/experimental_cann2d_analysis.py

Lines changed: 3 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -60,9 +60,11 @@
6060
tda_config = TDAConfig(
6161
maxdim=1,
6262
do_shuffle=False,
63+
# num_shuffles=10,
6364
show=True,
6465
dim=6,
65-
n_points=1200
66+
n_points=1200,
67+
progress_bar=True,
6668
)
6769

6870
persistence_result = tda_vis(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -39,10 +39,10 @@ classifiers = [
3939

4040
dependencies = [
4141
"BrainX[cpu]",
42+
"canns-ripser>=0.4.3",
4243
"furo>=2025.7.19",
4344
"notebook>=7.4.4",
4445
"ratinabox>=1.15.3",
45-
"ripser>=0.6.12",
4646
"tqdm",
4747
]
4848

src/canns/analyzer/experimental_data/cann2d.py

Lines changed: 44 additions & 40 deletions
Original file line number · Diff line number · Diff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import multiprocessing as mp
23
import numbers
34
import os
@@ -8,7 +9,8 @@
89
import numpy as np
910
from matplotlib import animation, cm, gridspec
1011
from numpy.exceptions import AxisError
11-
from ripser import ripser
12+
from canns_ripser import ripser
13+
# from ripser import ripser
1214
from scipy import signal
1315
from scipy.ndimage import (
1416
_nd_image,
@@ -58,6 +60,7 @@ class TDAConfig:
5860
show: bool = True
5961
do_shuffle: bool = False
6062
num_shuffles: int = 1000
63+
progress_bar: bool = True
6164

6265

6366
@dataclass
@@ -465,6 +468,7 @@ def tda_vis(embed_data: np.ndarray, config: TDAConfig | None = None, **kwargs) -
465468
show=kwargs.get("show", True),
466469
do_shuffle=kwargs.get("do_shuffle", False),
467470
num_shuffles=kwargs.get("num_shuffles", 1000),
471+
progress_bar=kwargs.get("progress_bar", True),
468472
)
469473

470474
try:
@@ -496,31 +500,29 @@ def tda_vis(embed_data: np.ndarray, config: TDAConfig | None = None, **kwargs) -
496500
def _compute_real_persistence(embed_data: np.ndarray, config: TDAConfig) -> dict[str, Any]:
497501
"""Compute persistent homology for real data with progress tracking."""
498502

499-
with tqdm(total=5, desc="Processing real data") as pbar:
500-
# Step 1: Time point downsampling
501-
pbar.set_description("Time point downsampling")
502-
times_cube = _downsample_timepoints(embed_data, config.num_times)
503-
pbar.update(1)
503+
logging.info("Processing real data - Starting TDA analysis (5 steps)")
504+
505+
# Step 1: Time point downsampling
506+
logging.info("Step 1/5: Time point downsampling")
507+
times_cube = _downsample_timepoints(embed_data, config.num_times)
504508

505-
# Step 2: Select most active time points
506-
pbar.set_description("Selecting active time points")
507-
movetimes = _select_active_timepoints(embed_data, times_cube, config.active_times)
508-
pbar.update(1)
509+
# Step 2: Select most active time points
510+
logging.info("Step 2/5: Selecting active time points")
511+
movetimes = _select_active_timepoints(embed_data, times_cube, config.active_times)
509512

510-
# Step 3: PCA dimensionality reduction
511-
pbar.set_description("PCA dimensionality reduction")
512-
dimred = _apply_pca_reduction(embed_data, movetimes, config.dim)
513-
pbar.update(1)
513+
# Step 3: PCA dimensionality reduction
514+
logging.info("Step 3/5: PCA dimensionality reduction")
515+
dimred = _apply_pca_reduction(embed_data, movetimes, config.dim)
514516

515-
# Step 4: Point cloud sampling (denoising)
516-
pbar.set_description("Point cloud denoising")
517-
indstemp = _apply_denoising(dimred, config)
518-
pbar.update(1)
517+
# Step 4: Point cloud sampling (denoising)
518+
logging.info("Step 4/5: Point cloud denoising")
519+
indstemp = _apply_denoising(dimred, config)
519520

520-
# Step 5: Compute persistent homology
521-
pbar.set_description("Computing persistent homology")
522-
persistence = _compute_persistence_homology(dimred, indstemp, config)
523-
pbar.update(1)
521+
# Step 5: Compute persistent homology
522+
logging.info("Step 5/5: Computing persistent homology")
523+
persistence = _compute_persistence_homology(dimred, indstemp, config)
524+
525+
logging.info("TDA analysis completed successfully")
524526

525527
# Return all necessary data in dictionary format
526528
return {
@@ -573,7 +575,12 @@ def _compute_persistence_homology(
573575
np.fill_diagonal(d, 0)
574576

575577
return ripser(
576-
d, maxdim=config.maxdim, coeff=config.coeff, do_cocycles=True, distance_matrix=True
578+
d,
579+
maxdim=config.maxdim,
580+
coeff=config.coeff,
581+
do_cocycles=True,
582+
distance_matrix=True,
583+
progress_bar=config.progress_bar
577584
)
578585

579586

@@ -598,6 +605,7 @@ def _perform_shuffle_analysis(embed_data: np.ndarray, config: TDAConfig) -> dict
598605
embed_data,
599606
num_shuffles=config.num_shuffles,
600607
num_cores=Constants.MULTIPROCESSING_CORES,
608+
progress_bar=config.progress_bar,
601609
**shuffle_params,
602610
)
603611

@@ -907,6 +915,7 @@ def _compute_persistence(
907915
nbs=800,
908916
maxdim=1,
909917
coeff=47,
918+
progress_bar=True,
910919
):
911920
# Time point downsampling
912921
times_cube = np.arange(0, sspikes.shape[0], num_times)
@@ -927,7 +936,7 @@ def _compute_persistence(
927936
np.fill_diagonal(d, 0)
928937

929938
# Compute persistent homology
930-
persistence = ripser(d, maxdim=maxdim, coeff=coeff, do_cocycles=True, distance_matrix=True)
939+
persistence = ripser(d, maxdim=maxdim, coeff=coeff, do_cocycles=True, distance_matrix=True, progress_bar=progress_bar)
931940

932941
return persistence
933942

@@ -1302,32 +1311,28 @@ def _second_build(data, indstemp, nbs=800, metric="cosine"):
13021311
return d
13031312

13041313

1305-
def _run_shuffle_analysis(sspikes, num_shuffles=1000, num_cores=4, **kwargs):
1314+
def _run_shuffle_analysis(sspikes, num_shuffles=1000, num_cores=4, progress_bar=True, **kwargs):
13061315
"""Perform shuffle analysis with optimized computation."""
1307-
return _run_shuffle_analysis_multiprocessing(sspikes, num_shuffles, num_cores, **kwargs)
1316+
return _run_shuffle_analysis_multiprocessing(sspikes, num_shuffles, num_cores, progress_bar, **kwargs)
13081317

13091318

1310-
def _run_shuffle_analysis_multiprocessing(sspikes, num_shuffles=1000, num_cores=4, **kwargs):
1319+
def _run_shuffle_analysis_multiprocessing(sspikes, num_shuffles=1000, num_cores=4, progress_bar=True, **kwargs):
13111320
"""Original multiprocessing implementation for fallback."""
13121321
max_lifetimes = {0: [], 1: [], 2: []}
13131322

13141323
# Estimate runtime with a test iteration
1315-
print("Running test iteration to estimate runtime...")
1324+
logging.info("Running test iteration to estimate runtime...")
13161325

13171326
_ = _process_single_shuffle((0, sspikes, kwargs))
13181327

13191328
# Prepare task list
13201329
tasks = [(i, sspikes, kwargs) for i in range(num_shuffles)]
1330+
logging.info(f"Starting shuffle analysis with {num_shuffles} iterations using {num_cores} cores...")
13211331

13221332
# Use multiprocessing pool for parallel processing
13231333
with mp.Pool(processes=num_cores) as pool:
1324-
results = list(
1325-
tqdm(
1326-
pool.imap(_process_single_shuffle, tasks),
1327-
total=num_shuffles,
1328-
desc="Running shuffle analysis",
1329-
)
1330-
)
1334+
results = list(pool.imap(_process_single_shuffle, tasks))
1335+
logging.info("Shuffle analysis completed")
13311336

13321337
# Collect results
13331338
for res in results:
@@ -1798,7 +1803,6 @@ def plot_3d_bump_on_torus(
17981803
frame_data = []
17991804
prev_m = None
18001805

1801-
print("Preparing animation data...")
18021806
for frame_idx in tqdm(range(n_frames), desc="Processing frames"):
18031807
start_idx = frame_idx * frame_step
18041808
end_idx = start_idx + window_size
@@ -2102,8 +2106,8 @@ def _smooth_image(img, sigma):
21022106
# reduce_func = reducer.fit_transform
21032107
#
21042108
# plot_projection(reduce_func=reduce_func, embed_data=spikes, show=True)
2105-
# results = tda_vis(
2106-
# embed_data=spikes, maxdim=2, do_shuffle=False, show=True
2107-
# )
2109+
results = tda_vis(
2110+
embed_data=spikes, maxdim=1, do_shuffle=False, show=True
2111+
)
21082112

2109-
results = tda_vis(embed_data=spikes, maxdim=1, do_shuffle=True, num_shuffles=10, show=True)
2113+
# results = tda_vis(embed_data=spikes, maxdim=1, do_shuffle=True, num_shuffles=10, show=True)

0 commit comments

Comments (0)