1+ import logging
12import multiprocessing as mp
23import numbers
34import os
89import numpy as np
910from matplotlib import animation , cm , gridspec
1011from numpy .exceptions import AxisError
11- from ripser import ripser
12+ from canns_ripser import ripser
13+ # from ripser import ripser
1214from scipy import signal
1315from scipy .ndimage import (
1416 _nd_image ,
@@ -58,6 +60,7 @@ class TDAConfig:
5860 show : bool = True
5961 do_shuffle : bool = False
6062 num_shuffles : int = 1000
63+ progress_bar : bool = True
6164
6265
6366@dataclass
@@ -465,6 +468,7 @@ def tda_vis(embed_data: np.ndarray, config: TDAConfig | None = None, **kwargs) -
465468 show = kwargs .get ("show" , True ),
466469 do_shuffle = kwargs .get ("do_shuffle" , False ),
467470 num_shuffles = kwargs .get ("num_shuffles" , 1000 ),
471+ progress_bar = kwargs .get ("progress_bar" , True ),
468472 )
469473
470474 try :
@@ -496,31 +500,29 @@ def tda_vis(embed_data: np.ndarray, config: TDAConfig | None = None, **kwargs) -
496500def _compute_real_persistence (embed_data : np .ndarray , config : TDAConfig ) -> dict [str , Any ]:
497501 """Compute persistent homology for real data with progress tracking."""
498502
499- with tqdm ( total = 5 , desc = "Processing real data" ) as pbar :
500- # Step 1: Time point downsampling
501- pbar . set_description ( " Time point downsampling" )
502- times_cube = _downsample_timepoints ( embed_data , config . num_times )
503- pbar . update ( 1 )
503+ logging . info ( "Processing real data - Starting TDA analysis (5 steps)" )
504+
505+ # Step 1: Time point downsampling
506+ logging . info ( "Step 1/5: Time point downsampling" )
507+ times_cube = _downsample_timepoints ( embed_data , config . num_times )
504508
505- # Step 2: Select most active time points
506- pbar .set_description ("Selecting active time points" )
507- movetimes = _select_active_timepoints (embed_data , times_cube , config .active_times )
508- pbar .update (1 )
509+ # Step 2: Select most active time points
510+ logging .info ("Step 2/5: Selecting active time points" )
511+ movetimes = _select_active_timepoints (embed_data , times_cube , config .active_times )
509512
510- # Step 3: PCA dimensionality reduction
511- pbar .set_description ("PCA dimensionality reduction" )
512- dimred = _apply_pca_reduction (embed_data , movetimes , config .dim )
513- pbar .update (1 )
513+ # Step 3: PCA dimensionality reduction
514+ logging .info ("Step 3/5: PCA dimensionality reduction" )
515+ dimred = _apply_pca_reduction (embed_data , movetimes , config .dim )
514516
515- # Step 4: Point cloud sampling (denoising)
516- pbar .set_description ("Point cloud denoising" )
517- indstemp = _apply_denoising (dimred , config )
518- pbar .update (1 )
517+ # Step 4: Point cloud sampling (denoising)
518+ logging .info ("Step 4/5: Point cloud denoising" )
519+ indstemp = _apply_denoising (dimred , config )
519520
520- # Step 5: Compute persistent homology
521- pbar .set_description ("Computing persistent homology" )
522- persistence = _compute_persistence_homology (dimred , indstemp , config )
523- pbar .update (1 )
521+ # Step 5: Compute persistent homology
522+ logging .info ("Step 5/5: Computing persistent homology" )
523+ persistence = _compute_persistence_homology (dimred , indstemp , config )
524+
525+ logging .info ("TDA analysis completed successfully" )
524526
525527 # Return all necessary data in dictionary format
526528 return {
@@ -573,7 +575,12 @@ def _compute_persistence_homology(
573575 np .fill_diagonal (d , 0 )
574576
575577 return ripser (
576- d , maxdim = config .maxdim , coeff = config .coeff , do_cocycles = True , distance_matrix = True
578+ d ,
579+ maxdim = config .maxdim ,
580+ coeff = config .coeff ,
581+ do_cocycles = True ,
582+ distance_matrix = True ,
583+ progress_bar = config .progress_bar
577584 )
578585
579586
@@ -598,6 +605,7 @@ def _perform_shuffle_analysis(embed_data: np.ndarray, config: TDAConfig) -> dict
598605 embed_data ,
599606 num_shuffles = config .num_shuffles ,
600607 num_cores = Constants .MULTIPROCESSING_CORES ,
608+ progress_bar = config .progress_bar ,
601609 ** shuffle_params ,
602610 )
603611
@@ -907,6 +915,7 @@ def _compute_persistence(
907915 nbs = 800 ,
908916 maxdim = 1 ,
909917 coeff = 47 ,
918+ progress_bar = True ,
910919):
911920 # Time point downsampling
912921 times_cube = np .arange (0 , sspikes .shape [0 ], num_times )
@@ -927,7 +936,7 @@ def _compute_persistence(
927936 np .fill_diagonal (d , 0 )
928937
929938 # Compute persistent homology
930- persistence = ripser (d , maxdim = maxdim , coeff = coeff , do_cocycles = True , distance_matrix = True )
939+ persistence = ripser (d , maxdim = maxdim , coeff = coeff , do_cocycles = True , distance_matrix = True , progress_bar = progress_bar )
931940
932941 return persistence
933942
@@ -1302,32 +1311,28 @@ def _second_build(data, indstemp, nbs=800, metric="cosine"):
13021311 return d
13031312
13041313
1305- def _run_shuffle_analysis (sspikes , num_shuffles = 1000 , num_cores = 4 , ** kwargs ):
1314+ def _run_shuffle_analysis (sspikes , num_shuffles = 1000 , num_cores = 4 , progress_bar = True , ** kwargs ):
13061315 """Perform shuffle analysis with optimized computation."""
1307- return _run_shuffle_analysis_multiprocessing (sspikes , num_shuffles , num_cores , ** kwargs )
1316+ return _run_shuffle_analysis_multiprocessing (sspikes , num_shuffles , num_cores , progress_bar , ** kwargs )
13081317
13091318
1310- def _run_shuffle_analysis_multiprocessing (sspikes , num_shuffles = 1000 , num_cores = 4 , ** kwargs ):
1319+ def _run_shuffle_analysis_multiprocessing (sspikes , num_shuffles = 1000 , num_cores = 4 , progress_bar = True , ** kwargs ):
13111320 """Original multiprocessing implementation for fallback."""
13121321 max_lifetimes = {0 : [], 1 : [], 2 : []}
13131322
13141323 # Estimate runtime with a test iteration
1315- print ("Running test iteration to estimate runtime..." )
1324+ logging . info ("Running test iteration to estimate runtime..." )
13161325
13171326 _ = _process_single_shuffle ((0 , sspikes , kwargs ))
13181327
13191328 # Prepare task list
13201329 tasks = [(i , sspikes , kwargs ) for i in range (num_shuffles )]
1330+ logging .info (f"Starting shuffle analysis with { num_shuffles } iterations using { num_cores } cores..." )
13211331
13221332 # Use multiprocessing pool for parallel processing
13231333 with mp .Pool (processes = num_cores ) as pool :
1324- results = list (
1325- tqdm (
1326- pool .imap (_process_single_shuffle , tasks ),
1327- total = num_shuffles ,
1328- desc = "Running shuffle analysis" ,
1329- )
1330- )
1334+ results = list (pool .imap (_process_single_shuffle , tasks ))
1335+ logging .info ("Shuffle analysis completed" )
13311336
13321337 # Collect results
13331338 for res in results :
@@ -1798,7 +1803,6 @@ def plot_3d_bump_on_torus(
17981803 frame_data = []
17991804 prev_m = None
18001805
1801- print ("Preparing animation data..." )
18021806 for frame_idx in tqdm (range (n_frames ), desc = "Processing frames" ):
18031807 start_idx = frame_idx * frame_step
18041808 end_idx = start_idx + window_size
@@ -2102,8 +2106,8 @@ def _smooth_image(img, sigma):
21022106 # reduce_func = reducer.fit_transform
21032107 #
21042108 # plot_projection(reduce_func=reduce_func, embed_data=spikes, show=True)
2105- # results = tda_vis(
2106- # embed_data=spikes, maxdim=2 , do_shuffle=False, show=True
2107- # )
2109+ results = tda_vis (
2110+ embed_data = spikes , maxdim = 1 , do_shuffle = False , show = True
2111+ )
21082112
2109- results = tda_vis (embed_data = spikes , maxdim = 1 , do_shuffle = True , num_shuffles = 10 , show = True )
2113+ # results = tda_vis(embed_data=spikes, maxdim=1, do_shuffle=True, num_shuffles=10, show=True)
0 commit comments