diff --git a/applications/DynaCell/README.md b/applications/DynaCell/README.md new file mode 100644 index 000000000..3daf170d6 --- /dev/null +++ b/applications/DynaCell/README.md @@ -0,0 +1,129 @@ +# DynaCell Application + +This directory contains tools for computing and analyzing virtual staining metrics using the DynaCell benchmark datasets. + +## Overview + +The DynaCell application provides functionality to: +- Compute intensity-based metrics between target and predicted virtual staining images +- Process multiple infection conditions (Mock, DENV) in single function calls +- Support parallel processing for faster computation +- Generate detailed CSV reports with position names and dataset information + +## Key Files + +- `compute_virtual_staining_metrics.py` - Main script for computing VSCyto3D and CellDiff metrics +- `benchmark.py` - Core functions for metrics computation with parallel processing support +- `example_parallel_usage.py` - Example showing sequential vs parallel processing +- `test_parallel_metrics.py` - Test script for verifying parallel functionality + +## Parallel Processing Architecture + +### Dataset Structure and Worker Distribution + +The DynaCell metrics pipeline processes data at the individual timepoint level, enabling efficient parallel processing: + +``` +Dataset Structure: +├── Position B/1/000001 (Mock, A549, HIST2H2BE) +│ ├── Sample 0: timepoint 0 <- Worker A processes this +│ ├── Sample 1: timepoint 1 <- Worker B processes this +│ ├── Sample 2: timepoint 2 <- Worker C processes this +│ └── Sample 3: timepoint 3 <- Worker D processes this +├── Position B/1/000002 (DENV, A549, HIST2H2BE) +│ ├── Sample 4: timepoint 0 <- Worker A processes this +│ ├── Sample 5: timepoint 1 <- Worker B processes this +│ └── ... +└── Position B/2/000001 (Mock, A549, HIST2H2BE) + ├── Sample N: timepoint 0 <- Worker C processes this + └── ... +``` + +### How Workers Process Data + +**Granularity**: Each worker processes individual `(position, timepoint)` combinations + +**Distribution**: PyTorch's DataLoader distributes samples across workers in round-robin fashion: +- With `num_workers=4` and 100 samples: Worker 0 gets samples [0,4,8,12...], Worker 1 gets [1,5,9,13...], etc. 
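+
+The sketch below illustrates this assignment with a toy map-style dataset (a hypothetical stand-in, not the DynaCell loader); `get_worker_info()` reports which worker loaded each sample:
+
+```python
+from torch.utils.data import DataLoader, Dataset, get_worker_info
+
+class ToyTimepointDataset(Dataset):
+    """Hypothetical stand-in for a dataset of (position, timepoint) samples."""
+
+    def __len__(self):
+        return 8
+
+    def __getitem__(self, idx):
+        worker = get_worker_info()  # populated only inside worker processes
+        return idx, worker.id if worker is not None else -1
+
+# With batch_size=1 and the default sampler, sample i is loaded by worker i % 4
+for sample_idx, worker_id in DataLoader(ToyTimepointDataset(), batch_size=1, num_workers=4):
+    print(f"sample {sample_idx.item()} <- worker {worker_id.item()}")
+```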
+ +**Batch Processing**: With `batch_size=1`, each worker processes exactly one sample at a time for metrics compatibility + +**Concurrency Benefits**: +- **I/O Parallelism**: Workers read different zarr files/slices simultaneously +- **CPU Parallelism**: Image processing, transforms, and metrics computation happen in parallel +- **Memory Efficiency**: Each worker only loads one timepoint at a time + +### Performance Optimization + +This design is particularly effective for DynaCell data because: +- Each timepoint requires significant I/O (loading image slices from zarr) +- Metrics computation is CPU-intensive +- Different timepoints are independent and can be processed in any order + +**Recommended Settings**: +- `num_workers=4-12` for typical HPC setups +- `batch_size=1` (hardcoded for metrics compatibility) +- More workers help when you have many positions/timepoints + +## Usage Examples + +### Basic Usage (Sequential) +```python +metrics = compute_metrics( + metrics_module=IntensityMetrics(), + cell_types=["A549"], + organelles=["HIST2H2BE"], + infection_conditions=["Mock", "DENV"], + target_database=database, + target_channel_name="raw Cy5 EX639 EM698-70", + prediction_database=database, + prediction_channel_name="nuclei_prediction", + log_output_dir=output_dir, + log_name="metrics_example" +) +``` + +### Parallel Processing +```python +metrics = compute_metrics( + # ... same parameters as above ... + num_workers=8, # Use 8 workers for parallel processing +) +``` + +### Multiple Conditions +The system supports processing multiple infection conditions in a single call: +```python +infection_conditions=["Mock", "DENV"] # Processes both conditions together +``` + +## Output Format + +The generated CSV files include: +- **Standard metrics**: SSIM, PSNR, correlation coefficients, etc. +- **Position information**: `position_name` (e.g., "B/1/000001") +- **Dataset information**: `dataset` name from the database +- **Condition metadata**: `cell_type`, `organelle`, `infection_condition` +- **Temporal information**: `time_idx` for tracking timepoints + +## Thread Safety + +The system uses `ParallelSafeMetricsLogger` to prevent race conditions when multiple workers write metrics. This logger: +- Collects metrics in memory during processing +- Writes all data atomically to CSV after completion +- Prevents file corruption from concurrent writes + +## Database Structure + +The system expects database CSV files with columns: +- `Cell type` - Cell line (e.g., "A549", "HEK293T") +- `Organelle` - Target organelle (e.g., "HIST2H2BE") +- `Infection` - Condition (e.g., "Mock", "DENV") +- `Path` - Path to zarr files +- `Dataset` - Dataset identifier +- Additional metadata columns + +The system automatically handles: +- Filtering by multiple conditions using OR logic +- Deduplication by zarr path AND FOV name +- Metadata preservation through the processing pipeline \ No newline at end of file diff --git a/applications/DynaCell/benchmark.py b/applications/DynaCell/benchmark.py new file mode 100644 index 000000000..c6b37c31d --- /dev/null +++ b/applications/DynaCell/benchmark.py @@ -0,0 +1,218 @@ +""" +This script is a demo script for the DynaCell application. 
+It loads OME-Zarr v0.4 stores, computes metrics, and saves the results as CSV files.
+"""
+
+import datetime
+import os
+import warnings
+from pathlib import Path
+
+import pandas as pd
+import torch
+from lightning import LightningModule
+from lightning.pytorch.loggers import CSVLogger
+
+from viscy.data.dynacell import DynaCellDatabase, DynaCellDataModule
+from viscy.trainer import Trainer
+from viscy.utils.logging import ParallelSafeMetricsLogger
+
+# Set float32 matmul precision for better performance on Tensor Cores
+torch.set_float32_matmul_precision("high")
+
+# Suppress Lightning warnings for intentional CPU usage
+os.environ["SLURM_NTASKS"] = "1"  # Suppress SLURM warning
+warnings.filterwarnings("ignore", "GPU available but not used")
+warnings.filterwarnings("ignore", "The `srun` command is available")
+
+
+def compute_metrics(
+    metrics_module: LightningModule,
+    cell_types: list,
+    organelles: list,
+    infection_conditions: list,
+    target_database: pd.DataFrame,
+    target_channel_name: str,
+    prediction_database: pd.DataFrame,
+    prediction_channel_name: str,
+    log_output_dir: Path,
+    log_name: str = "dynacell_metrics",
+    log_version: str = None,
+    z_slice: slice = None,
+    transforms: list = None,
+    num_workers: int = 0,
+    use_gpu: bool = False,
+):
+    """
+    Compute DynaCell metrics with optional parallel processing.
+
+    This function computes virtual staining metrics at the individual timepoint level,
+    enabling efficient parallel computation across multiple positions and timepoints.
+
+    Parallel Processing Architecture:
+    - Each sample represents one (position, timepoint) combination
+    - PyTorch's DataLoader distributes samples across workers in round-robin fashion
+    - With num_workers=4: Worker 0 gets samples [0,4,8...], Worker 1 gets [1,5,9...], etc.
+    - Each worker processes different timepoints/positions simultaneously
+    - Thread-safe logging prevents race conditions in CSV output
+
+    Parameters
+    ----------
+    metrics_module : LightningModule
+        The metrics module to use (e.g., IntensityMetrics())
+    cell_types : list
+        List of cell types to process (e.g., ["A549"])
+    organelles : list
+        List of organelles to process (e.g., ["HIST2H2BE"])
+    infection_conditions : list
+        List of infection conditions to process (e.g., ["Mock", "DENV"])
+        Multiple conditions are processed with OR logic in a single call
+    target_database : pd.DataFrame
+        Database containing target image paths and metadata
+    target_channel_name : str
+        Channel name in target dataset
+    prediction_database : pd.DataFrame
+        Database containing prediction image paths and metadata
+    prediction_channel_name : str
+        Channel name in prediction dataset
+    log_output_dir : Path
+        Directory for output metrics CSV files
+    log_name : str, optional
+        Name for metrics logging, by default "dynacell_metrics"
+    log_version : str, optional
+        Version string for logging, by default None (uses timestamp)
+    z_slice : slice, optional
+        Z-slice to extract from 3D data, by default None
+    transforms : list, optional
+        List of data transforms to apply, by default None
+    num_workers : int, optional
+        Number of workers for parallel data loading, by default 0 (sequential)
+        Recommended: 2-4 workers for CPU, 4-8 for GPU
+    use_gpu : bool, optional
+        Whether to use GPU acceleration, by default False
+        GPU provides 10-25x speedup for metrics computation
+
+    Notes
+    -----
+    - GPU acceleration substantially speeds up metrics computation, even with batch_size=1
+    - batch_size is hardcoded to 1 for compatibility with existing metrics code
+    - Uses ParallelSafeMetricsLogger to prevent race conditions in CSV writing
+    - Output CSV includes position_name, dataset, and condition metadata
+
+    Returns
+    -------
+    pd.DataFrame or None
+        Metrics DataFrame if CSV file is successfully created, None otherwise
+    """
+    # Generate timestamp for unique versioning
+    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    if log_version is None:
+        log_version = timestamp
+
+    # Create target database
+    target_db = DynaCellDatabase(
+        database=target_database,
+        cell_types=cell_types,
+        organelles=organelles,
+        infection_conditions=infection_conditions,
+        channel_name=target_channel_name,
+        z_slice=z_slice,
+    )
+
+    # Prediction database (for segmentation self-comparison, point this at the same channel as the target)
+    pred_db = DynaCellDatabase(
+        database=prediction_database,
+        cell_types=cell_types,
+        organelles=organelles,
+        infection_conditions=infection_conditions,
+        channel_name=prediction_channel_name,
+        z_slice=z_slice,
+    )
+
+    # Create data module with both databases
+    dm = DynaCellDataModule(
+        target_database=target_db,
+        pred_database=pred_db,
+        batch_size=1,  # Hardcoded to 1 for metrics compatibility
+        num_workers=num_workers,
+        transforms=transforms,
+    )
+    dm.setup(stage="test")
+
+    # Print dataset configuration summary
+    sample = next(iter(dm.test_dataloader()))
+    # Determine device and processing info
+    device_name = "GPU" if use_gpu and torch.cuda.is_available() else "CPU"
+    processing_mode = "Parallel" if num_workers > 0 else "Sequential"
+
+    print("\n📊 Dataset Configuration:")
+    print(f"   • Samples: {len(dm.test_dataset)} total across all positions/timepoints")
+    print(f"   • Cell types: {cell_types}")
+    print(f"   • Organelles: {organelles}")
+    print(f"   • Infection conditions: {infection_conditions}")
{infection_conditions}") + print(f" • Sample metadata: {sample['cell_type']}, {sample['organelle']}, {sample['infection_condition']}") + + # Setup logging + log_output_dir.mkdir(exist_ok=True) + + if num_workers > 0: + logger = ParallelSafeMetricsLogger(save_dir=log_output_dir, name=log_name, version=log_version) + print(f"\n🚀 Processing Mode: {processing_mode} ({num_workers} workers)") + else: + logger = CSVLogger(save_dir=log_output_dir, name=log_name, version=log_version) + print(f"\n🔄 Processing Mode: {processing_mode} (single-threaded)") + + print(f" • Device: {device_name}") + print(f" • Batch size: 1 (hardcoded for metrics compatibility)") + if use_gpu and torch.cuda.is_available(): + print(f" • GPU: {torch.cuda.get_device_name()}") + + # Configure trainer based on device preference + if use_gpu and torch.cuda.is_available(): + accelerator = "gpu" + precision = "16-mixed" # Use fp16 on GPU + else: + accelerator = "cpu" + precision = "bf16-mixed" # Use bf16 for CPU + + trainer = Trainer( + logger=logger, + accelerator=accelerator, + devices=1, + precision=precision, + num_nodes=1, + enable_progress_bar=True, + enable_model_summary=False + ) + trainer.test(metrics_module, datamodule=dm) + + # Finalize logging if using parallel-safe logger + if hasattr(logger, 'finalize'): + logger.finalize() + + # Find and report results + metrics_file = log_output_dir / log_name / log_version / "metrics.csv" + if metrics_file.exists(): + metrics = pd.read_csv(metrics_file) + print(f"\n✅ Metrics computation completed successfully!") + print(f" • Output file: {metrics_file}") + print(f" • Records: {len(metrics)} samples") + print(f" • Device: {device_name}") + print(f" • Batch size: 1 (hardcoded)") + print(f" • Metrics: {[col for col in metrics.columns if col not in ['position', 'time', 'cell_type', 'organelle', 'infection_condition', 'dataset', 'position_name']]}") + + # Show infection condition breakdown + if 'infection_condition' in metrics.columns: + condition_counts = metrics['infection_condition'].value_counts() + print(f" • Conditions: {dict(condition_counts)}") + + # Show GPU memory usage if applicable + if use_gpu and torch.cuda.is_available(): + memory_used = torch.cuda.max_memory_allocated() / 1024**3 + print(f" • Peak GPU memory: {memory_used:.2f} GB") + else: + print(f"❌ Warning: Metrics file not found at {metrics_file}") + metrics = None + + return metrics diff --git a/applications/DynaCell/check_metrics_accuracy.py b/applications/DynaCell/check_metrics_accuracy.py new file mode 100644 index 000000000..068e366f2 --- /dev/null +++ b/applications/DynaCell/check_metrics_accuracy.py @@ -0,0 +1,269 @@ +#Compare metric accuracy when computed using scikit-image and torchmetrics as implemented in viscy + +# Metrics list: mae, mse, ssim, pearson +# %% +import numpy as np +from pathlib import Path +from iohub import open_ome_zarr +import torch + +from sklearn.metrics import ( + mean_squared_error as mse, + mean_absolute_error as mae, +) +from skimage.metrics import ( + structural_similarity as ssim +) +from skimage.measure import pearson_corr_coeff as pearson +from skimage.exposure import rescale_intensity + +from torchmetrics.functional import ( + mean_squared_error as torch_mse, + mean_absolute_error as torch_mae, + pearson_corrcoef as torch_pearson, + structural_similarity_index_measure as torch_ssim +) + +from monai.transforms import NormalizeIntensity + + +# %% Load data +data_path = 
Path("/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_04_17_A549_H2B_CAAX_DENV/2-assemble/2025_04_17_A549_H2B_CAAX_DENV.zarr/B/1/000001") +t_idx = 0 +z_idx = 28 +target_channel_name = "raw Cy5 EX639 EM698-70" +prediction_channel_name = "nuclei_prediction" + +target_offset = 100 +pred_offset = 0 + +with open_ome_zarr(data_path, mode="r") as dataset: + channel_names = dataset.channel_names + target_channel_index = channel_names.index(target_channel_name) + prediction_channel_index = channel_names.index(prediction_channel_name) + target_volume_raw = dataset.data[t_idx, target_channel_index] + pred = dataset.data[t_idx, prediction_channel_index, z_idx] + +# Normalize entire volume by median and iqr, then take out slice of interest +# Offset doesn't matter +median = np.median(target_volume_raw) +iqr = np.percentile(target_volume_raw, 75) - np.percentile(target_volume_raw, 25) + +target_volume_correctly_normalized = (target_volume_raw - median) / iqr +target = target_volume_correctly_normalized[z_idx] + +# Convert to tensor +target_tensor = torch.from_numpy(target) +pred_tensor = torch.from_numpy(pred) + +# %% Compare MSE +mse_sk = mse(target, pred) +mse_torch = torch_mse(target_tensor, pred_tensor).item() +print(f"MSE: SK: {mse_sk:.3f}, Torch: {mse_torch:.3f}") + +# %% Compare MAE +mae_sk = mae(target, pred) +mae_torch = torch_mae(target_tensor, pred_tensor).item() +print(f"MAE: SK: {mae_sk:.3f}, Torch: {mae_torch:.3f}") + +# %% Compare Pearson +pearson_sk = pearson(target, pred).statistic +pearson_torch = torch_pearson(target_tensor.flatten(), pred_tensor.flatten()).item() +print(f"Pearson: SK: {pearson_sk:.3f}, Torch: {pearson_torch:.3f}") + +# %% Compare SSIM +ssim_sk = ssim(target, pred, data_range=target.max() - target.min()) +ssim_torch = torch_ssim( + target_tensor.unsqueeze(0).unsqueeze(0), + pred_tensor.unsqueeze(0).unsqueeze(0) +).item() +print(f"SSIM: SK: {ssim_sk:.3f}, Torch: {ssim_torch:.3f}") + + +# %% +### ------------------ Normalize both target and prediction ------------------ +# Note: we are normalizing the 2D slice which is not correct + + +# Normalize +target_raw = target_volume_raw[z_idx] +target_norm = (target_raw - target_raw.mean()) / target_raw.std() +pred_norm = (pred - pred.mean()) / pred.std() + +target_no_offset = target_raw - target_offset +pred_no_offset = pred - pred_offset + +target_no_offset_norm = (target_no_offset - target_no_offset.mean()) / target_no_offset.std() +pred_no_offset_norm = (pred_no_offset - pred_no_offset.mean()) / pred_no_offset.std() + +# Convert to tensor +target_tensor = torch.from_numpy(target_raw) +pred_tensor = torch.from_numpy(pred) + +target_no_offset_tensor = torch.from_numpy(target_no_offset) +pred_no_offset_tensor = torch.from_numpy(pred_no_offset) + +intensity_normalizer = NormalizeIntensity() +target_tensor_norm = intensity_normalizer(target_tensor) +pred_tensor_norm = intensity_normalizer(pred_tensor) + +target_no_offset_tensor_norm = intensity_normalizer(target_no_offset_tensor) +pred_no_offset_tensor_norm = intensity_normalizer(pred_no_offset_tensor) + +# SSIM normalization +target_ssim = rescale_intensity( + target_no_offset, + in_range=(np.quantile(target_no_offset, 0.01), np.quantile(target_no_offset, 0.99)), + out_range=np.uint16 +) +pred_ssim = rescale_intensity( + pred_no_offset, + in_range=(np.quantile(pred_no_offset, 0.01), np.quantile(pred_no_offset, 0.99)), + out_range=np.uint16 +) + +# %% Compare MSE +# Note: MSE is sensitive intensity normalization but not offset +print("Comparing MSE...") + +# 
+# With offset, no normalization
+print("  With offset:")
+mse_sk = mse(target_raw, pred)
+mse_torch = torch_mse(target_tensor, pred_tensor).item()
+print(
+    f"    Without normalization: SK: {mse_sk:.3f}, Torch: {mse_torch:.3f}"
+)
+
+# With offset, with intensity normalization
+mse_sk = mse(target_norm, pred_norm)
+mse_torch = torch_mse(target_tensor_norm, pred_tensor_norm).item()
+print(
+    f"    With intensity normalization: SK: {mse_sk:.3f}, Torch: {mse_torch:.3f}"
+)
+
+# Without offset, no intensity normalization
+mse_sk = mse(target_no_offset, pred_no_offset)
+mse_torch = torch_mse(target_no_offset_tensor, pred_no_offset_tensor).item()
+print("  Without offset:")
+print(
+    f"    Without normalization: SK: {mse_sk:.3f}, Torch: {mse_torch:.3f}"
+)
+
+# Without offset, with intensity normalization
+mse_sk = mse(target_no_offset_norm, pred_no_offset_norm)
+mse_torch = torch_mse(target_no_offset_tensor_norm, pred_no_offset_tensor_norm).item()
+print(
+    f"    With intensity normalization: SK: {mse_sk:.3f}, Torch: {mse_torch:.3f}"
+)
+
+mse_sk = mse(target_norm, pred)
+mse_torch = torch_mse(target_tensor_norm, pred_tensor).item()
+print(
+    f"Comparing normalized target with unnormalized prediction: SK: {mse_sk:.3f}, Torch: {mse_torch:.3f}"
+)
+
+# %% Compare MAE
+# Note: MAE is sensitive to intensity normalization but not to offset
+print("Comparing MAE...")
+
+# With offset, no normalization
+print("  With offset:")
+mae_sk = mae(target_raw, pred)
+mae_torch = torch_mae(target_tensor, pred_tensor).item()
+print(
+    f"    Without normalization: SK: {mae_sk:.3f}, Torch: {mae_torch:.3f}"
+)
+
+# With offset, with intensity normalization
+mae_sk = mae(target_norm, pred_norm)
+mae_torch = torch_mae(target_tensor_norm, pred_tensor_norm).item()
+print(
+    f"    With intensity normalization: SK: {mae_sk:.3f}, Torch: {mae_torch:.3f}"
+)
+
+# Without offset, no intensity normalization
+mae_sk = mae(target_no_offset, pred_no_offset)
+mae_torch = torch_mae(target_no_offset_tensor, pred_no_offset_tensor).item()
+print("  Without offset:")
+print(
+    f"    Without normalization: SK: {mae_sk:.3f}, Torch: {mae_torch:.3f}"
+)
+
+# Without offset, with intensity normalization
+mae_sk = mae(target_no_offset_norm, pred_no_offset_norm)
+mae_torch = torch_mae(target_no_offset_tensor_norm, pred_no_offset_tensor_norm).item()
+print(
+    f"    With intensity normalization: SK: {mae_sk:.3f}, Torch: {mae_torch:.3f}"
+)
+
+# %% Compare SSIM
+# Note: SSIM is sensitive to offset if intensity normalization is not applied
+# SSIM is sensitive to intensity normalization
+print("Comparing SSIM...")
+
+# With offset, no normalization
+print("  With offset:")
+ssim_sk = ssim(target_raw, pred, data_range=target_raw.max() - target_raw.min())
+ssim_torch = torch_ssim(target_tensor.unsqueeze(0).unsqueeze(0), pred_tensor.unsqueeze(0).unsqueeze(0)).item()
+print(
+    f"    Without normalization: SK: {ssim_sk:.3f}, Torch: {ssim_torch:.3f}"
+)
+
+# With offset, with intensity normalization
+ssim_sk = ssim(target_norm, pred_norm, data_range=target_norm.max() - target_norm.min())
+ssim_torch = torch_ssim(target_tensor_norm.unsqueeze(0).unsqueeze(0), pred_tensor_norm.unsqueeze(0).unsqueeze(0)).item()
+print(
+    f"    With intensity normalization: SK: {ssim_sk:.3f}, Torch: {ssim_torch:.3f}"
+)
+
+# Without offset, no intensity normalization
+ssim_sk = ssim(target_no_offset, pred_no_offset, data_range=target_no_offset.max() - target_no_offset.min())
+ssim_torch = torch_ssim(target_no_offset_tensor.unsqueeze(0).unsqueeze(0), pred_no_offset_tensor.unsqueeze(0).unsqueeze(0)).item()
+print("  Without offset:")
Without offset:") +print( + f" Without normalization: SK: {ssim_sk:.3f}, Torch: {ssim_torch:.3f}" +) + +# Without offset, with intensity normalization +ssim_sk = ssim(target_no_offset_norm, pred_no_offset_norm, data_range=target_no_offset_norm.max() - target_no_offset_norm.min()) +ssim_torch = torch_ssim(target_no_offset_tensor_norm.unsqueeze(0).unsqueeze(0), pred_no_offset_tensor_norm.unsqueeze(0).unsqueeze(0)).item() +print( + f" With intensity normalization: SK: {ssim_sk:.3f}, Torch: {ssim_torch:.3f}" +) + +# %% Compare Pearson +# Note: Pearson is not sensitive to offset or intensity normalization +print("Comparing Pearson...") + +# With offset, no normalization +print(" With offset:") +pearson_sk = pearson(target_raw, pred).statistic +pearson_torch = torch_pearson(target_tensor.flatten(), pred_tensor.flatten()).item() +print( + f" Without normalization: SK: {pearson_sk:.3f}, Torch: {pearson_torch:.3f}" +) + +# With offset, with intensity normalization +pearson_sk = pearson(target_norm, pred_norm).statistic +pearson_torch = torch_pearson(target_tensor_norm.flatten(), pred_tensor_norm.flatten()).item() +print( + f" With intensity normalization: SK: {pearson_sk:.3f}, Torch: {pearson_torch:.3f}" +) + +# Without offset, no normalization +pearson_sk = pearson(target_no_offset, pred_no_offset).statistic +pearson_torch = torch_pearson(target_no_offset_tensor.flatten(), pred_no_offset_tensor.flatten()).item() +print(" Without offset:") +print( + f" Without normalization: SK: {pearson_sk:.3f}, Torch: {pearson_torch:.3f}" +) + +# Without offset, with intensity normalization +pearson_sk = pearson(target_no_offset_norm, pred_no_offset_norm).statistic +pearson_torch = torch_pearson(target_no_offset_tensor_norm.flatten(), pred_no_offset_tensor_norm.flatten()).item() +print( + f" With intensity normalization: SK: {pearson_sk:.3f}, Torch: {pearson_torch:.3f}" +) + +# # %% +# %% diff --git a/applications/DynaCell/compute_virtual_staining_metrics.py b/applications/DynaCell/compute_virtual_staining_metrics.py new file mode 100644 index 000000000..5222d49a5 --- /dev/null +++ b/applications/DynaCell/compute_virtual_staining_metrics.py @@ -0,0 +1,340 @@ +# %% +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import pandas as pd +from functools import partial +from DynaCell.benchmark import compute_metrics +from monai.transforms import NormalizeIntensityd +# from viscy.transforms import NormalizedSampled # TODO +from viscy.translation.evaluation import IntensityMetrics + + +# csv_database_path = Path( +# "~/mydata/gdrive/dynacell/summary_table/dynacell_summary_table_2025_05_05.csv" +# ).expanduser() +# output_dir = Path("/home/eduardo.hirata/repos/viscy/applications/DynaCell/metrics") + +csv_database_path = Path( + "~/gdrive/publications/dynacell/summary_table/dynacell_summary_table_2025_05_05.csv" +).expanduser() +output_dir = Path("~/Documents/dynacell//metrics/virtual_staining").expanduser() +output_dir.mkdir(parents=True, exist_ok=True) + +database = pd.read_csv(csv_database_path, dtype={"FOV": str}) + +# Select test set only +database = database[database["Test Set"] == "x"] + +# TODO: +# z index may be different between Mock and Infected +# Don't normalize prediction, only target + +# %% Compute VSCyto3D intensity-based metrics +print("\nComputing VSCyto3D intensity-based metrics...") + +# HEK293T cells - Mock condition only +metrics = compute_metrics( + metrics_module=IntensityMetrics(), + cell_types=["HEK293T"], + organelles=["HIST2H2BE"], + 
+    infection_conditions=["Mock"],
+    target_database=database,
+    target_channel_name="Organelle",
+    prediction_database=database,
+    prediction_channel_name="Nuclei-prediction",
+    log_output_dir=output_dir,
+    log_name="intensity_VSCyto3D_HEK293T_nuclei_mock",
+    z_slice=36,
+    num_workers=8,
+    use_gpu=True,
+    transforms=[
+        NormalizeIntensityd(
+            keys=["pred", "target"],
+        )
+    ],
+)
+
+metrics = compute_metrics(
+    metrics_module=IntensityMetrics(),
+    cell_types=["HEK293T"],
+    organelles=["HIST2H2BE"],
+    infection_conditions=["Mock"],
+    target_database=database,
+    target_channel_name="Membrane",
+    prediction_database=database,
+    prediction_channel_name="Membrane-prediction",
+    log_output_dir=output_dir,
+    log_name="intensity_VSCyto3D_HEK293T_membrane_mock",
+    z_slice=36,
+    num_workers=8,
+    use_gpu=True,
+    transforms=[
+        NormalizeIntensityd(
+            keys=["pred", "target"],
+        )
+    ],
+)
+
+# %% A549 cells - Both Mock and DENV conditions in single calls
+# Note: these metrics are computed on the full FOV; CellDiff crops it down
+print("Computing A549 nuclei metrics for Mock and DENV conditions...")
+metrics = compute_metrics(
+    metrics_module=IntensityMetrics(),
+    cell_types=["A549"],
+    organelles=["HIST2H2BE"],
+    infection_conditions=["Mock", "DENV"],  # Multiple conditions in single call
+    target_database=database,
+    target_channel_name="raw Cy5 EX639 EM698-70",
+    prediction_database=database,
+    prediction_channel_name="nuclei_prediction",
+    log_output_dir=output_dir,
+    log_name="intensity_VSCyto3D_A549_nuclei_mock_denv",
+    z_slice=28,
+    num_workers=8,  # Use parallel processing for speed
+    use_gpu=True,
+    transforms=[
+        NormalizeIntensityd(
+            keys=["pred", "target"],
+        )
+    ],
+)
+
+print("Computing A549 membrane metrics for Mock and DENV conditions...")
+metrics = compute_metrics(
+    metrics_module=IntensityMetrics(),
+    cell_types=["A549"],
+    organelles=["HIST2H2BE"],
+    infection_conditions=["Mock", "DENV"],  # Multiple conditions in single call
+    target_database=database,
+    target_channel_name="raw mCherry EX561 EM600-37",
+    prediction_database=database,
+    prediction_channel_name="membrane_prediction",
+    log_output_dir=output_dir,
+    log_name="intensity_VSCyto3D_A549_membrane_mock_denv",
+    z_slice=28,
+    num_workers=8,  # Use parallel processing for speed
+    use_gpu=True,
+    transforms=[
+        NormalizeIntensityd(
+            keys=["pred", "target"],
+        )
+    ],
+)
+
+# %% Compute CellDiff intensity metrics
+
+# Construct required databases - cropped target and CellDiff prediction are saved in target_root and pred_root
+hek_h2b_database = database[
+    (database["Organelle"] == "HIST2H2BE") & (database["Cell type"] == "HEK293T")
+]
+_database = hek_h2b_database.copy()
+pred_database = hek_h2b_database.copy()
+target_database = hek_h2b_database.copy()
+
+def replace_root(path: str, new_root: Path) -> str:
+    new_path = new_root / Path(path).relative_to(old_root)
+    return str(new_path)
+
+old_root = Path("/hpc/projects/comp.micro/mantis/mantis_paper_data_release/figure_4.zarr")
+target_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/crops/hek/mantis_figure_4.zarr")
+_database["Path"] = hek_h2b_database["Path"].apply(partial(replace_root, new_root=target_root))
+
+metrics = compute_metrics(
+    metrics_module=IntensityMetrics(),
+    cell_types=["HEK293T"],
+    organelles=["HIST2H2BE"],
+    infection_conditions=["Mock"],
+    target_database=_database,
+    target_channel_name="Organelle",
+    prediction_database=_database,  # VSCyto3D predictions are in the same store
+    prediction_channel_name="Nuclei-prediction",
+    
log_output_dir=output_dir, + log_name="intensity_VSCyto3D_HEK293T_cropped_nuclei_mock", + z_slice=36-15, + num_workers=8, + use_gpu=True, + transforms=[ + NormalizeIntensityd( + keys=["pred", "target"], + ) + ], +) + +metrics = compute_metrics( + metrics_module=IntensityMetrics(), + cell_types=["HEK293T"], + organelles=["HIST2H2BE"], + infection_conditions=["Mock"], + target_database=_database, + target_channel_name="Membrane", + prediction_database=_database, + prediction_channel_name="Membrane-prediction", + log_output_dir=output_dir, + log_name="intensity_VSCyto3D_HEK293T_cropped_membrane_mock", + z_slice=36-15, + num_workers=8, + use_gpu=True, + transforms=[ + NormalizeIntensityd( + keys=["pred", "target"], + ) + ], +) + +# CellDiff predictions are in a different store +pred_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/prediction/hek/output.zarr") +target_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/crops/hek/mantis_figure_4.zarr") + +pred_database["Path"] = pred_database["Path"].apply(partial(replace_root, new_root=pred_root)) +target_database["Path"] = target_database["Path"].apply(partial(replace_root, new_root=target_root)) + +print("\nComputing CellDiff intensity-based metrics...") +metrics = compute_metrics( + metrics_module=IntensityMetrics(), + cell_types=["HEK293T"], + organelles=["HIST2H2BE"], + infection_conditions=["Mock"], + target_database=target_database, + target_channel_name="Organelle", + prediction_database=pred_database, + prediction_channel_name="Nuclei-prediction", + log_output_dir=output_dir, + log_name="intensity_CellDiff_HEK293T_nuclei", + z_slice=36-15, # Crop start at slice 15 + num_workers=8, + use_gpu=True, + transforms=[ + NormalizeIntensityd( + keys=["pred", "target"], + ) + ], +) + +metrics = compute_metrics( + metrics_module=IntensityMetrics(), + cell_types=["HEK293T"], + organelles=["HIST2H2BE"], + infection_conditions=["Mock"], + target_database=target_database, + target_channel_name="Membrane", + prediction_database=pred_database, + prediction_channel_name="Membrane-prediction", + log_output_dir=output_dir, + log_name="intensity_CellDiff_HEK293T_membrane", + z_slice=36-15, # Crop start at slice 15 + num_workers=8, + use_gpu=True, + transforms=[ + NormalizeIntensityd( + keys=["pred", "target"], + ) + ], +) + +# %% +a549_h2b_database = database[ + (database["Organelle"] == "HIST2H2BE") & (database["Cell type"] == "A549") +] +_database = a549_h2b_database.copy() +pred_database = a549_h2b_database.copy() +target_database = a549_h2b_database.copy() + +old_root = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_04_17_A549_H2B_CAAX_DENV/2-assemble/2025_04_17_A549_H2B_CAAX_DENV.zarr") +target_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/crops/a549/2025_04_17_A549_H2B_CAAX_DENV.zarr") +_database["Path"] = a549_h2b_database["Path"].apply(partial(replace_root, new_root=target_root)) + +metrics = compute_metrics( + metrics_module=IntensityMetrics(), + cell_types=["A549"], + organelles=["HIST2H2BE"], + infection_conditions=["Mock", "DENV"], # Multiple conditions in single call + target_database=_database, + target_channel_name="raw Cy5 EX639 EM698-70", + prediction_database=_database, + prediction_channel_name="nuclei_prediction", + log_output_dir=output_dir, + log_name="intensity_VSCyto3D_A549_cropped_nuclei_mock_denv", + z_slice=13, + num_workers=8, # Use parallel processing for speed + use_gpu=True, + transforms=[ + NormalizeIntensityd( + keys=["pred", "target"], + ) + ], +) + 
+print("Computing A549 membrane metrics for Mock and DENV conditions...") +metrics = compute_metrics( + metrics_module=IntensityMetrics(), + cell_types=["A549"], + organelles=["HIST2H2BE"], + infection_conditions=["Mock", "DENV"], # Multiple conditions in single call + target_database=_database, + target_channel_name="raw mCherry EX561 EM600-37", + prediction_database=_database, + prediction_channel_name="membrane_prediction", + log_output_dir=output_dir, + log_name="intensity_VSCyto3D_A549_cropped_membrane_mock_denv", + z_slice=13, + num_workers=8, # Use parallel processing for speed + use_gpu=True, + transforms=[ + NormalizeIntensityd( + keys=["pred", "target"], + ) + ], +) + +pred_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr") +target_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/crops/a549/2025_04_17_A549_H2B_CAAX_DENV.zarr") +pred_database["Path"] = pred_database["Path"].apply(partial(replace_root, new_root=pred_root)) +target_database["Path"] = target_database["Path"].apply(partial(replace_root, new_root=target_root)) + +print("Computing CellDiff A549 nuclei metrics for Mock and DENV conditions...") +metrics = compute_metrics( + metrics_module=IntensityMetrics(), + cell_types=["A549"], + organelles=["HIST2H2BE"], + infection_conditions=["Mock", "DENV"], # Multiple conditions in single call + target_database=target_database, + target_channel_name="raw Cy5 EX639 EM698-70", + prediction_database=pred_database, + prediction_channel_name="Nuclei-prediction", + log_output_dir=output_dir, + log_name="intensity_CellDiff_A549_nuclei_mock_denv", + z_slice=13, + num_workers=8, # Use parallel processing for speed + use_gpu=True, + transforms=[ + NormalizeIntensityd( + keys=["pred", "target"], + ) + ], +) + +print("Computing CellDiff A549 membrane metrics for Mock and DENV conditions...") +metrics = compute_metrics( + metrics_module=IntensityMetrics(), + cell_types=["A549"], + organelles=["HIST2H2BE"], + infection_conditions=["Mock", "DENV"], # Multiple conditions in single call + target_database=target_database, + target_channel_name="raw mCherry EX561 EM600-37", + prediction_database=pred_database, + prediction_channel_name="Membrane-prediction", + log_output_dir=output_dir, + log_name="intensity_CellDiff_A549_membrane_mock_denv", + z_slice=13, + num_workers=8, # Use parallel processing for speed + use_gpu=True, + transforms=[ + NormalizeIntensityd( + keys=["pred", "target"], + ) + ], +) + +# %% diff --git a/applications/DynaCell/compute_virtual_staining_metrics_cropped.py b/applications/DynaCell/compute_virtual_staining_metrics_cropped.py new file mode 100644 index 000000000..675531327 --- /dev/null +++ b/applications/DynaCell/compute_virtual_staining_metrics_cropped.py @@ -0,0 +1,374 @@ +# %% +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import pandas as pd +from functools import partial +from DynaCell.benchmark import compute_metrics +from monai.transforms import NormalizeIntensityd +# from viscy.transforms import NormalizedSampled # TODO +from viscy.translation.evaluation import IntensityMetrics + +def replace_root(path: str, old_root: Path, new_root: Path) -> str: + new_path = new_root / Path(path).relative_to(old_root) + return str(new_path) + + +# csv_database_path = Path( +# "~/mydata/gdrive/dynacell/summary_table/dynacell_summary_table_2025_05_05.csv" +# ).expanduser() +# output_dir = Path("/home/eduardo.hirata/repos/viscy/applications/DynaCell/metrics") + 
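+# Example of replace_root with hypothetical paths:
+#   replace_root("/old/root/plate.zarr/B/1/000001", Path("/old/root"), Path("/new/root"))
+#   returns "/new/root/plate.zarr/B/1/000001"
+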
+csv_database_path = Path(
+    "~/gdrive/publications/dynacell/summary_table/dynacell_summary_table_2025_09_16.csv"
+).expanduser()
+output_dir = Path("~/Documents/dynacell/metrics/virtual_staining").expanduser()
+output_dir.mkdir(parents=True, exist_ok=True)
+
+database = pd.read_csv(csv_database_path, dtype={"FOV": str})
+
+# Select test set only
+database = database[database["Test Set"] == "x"]
+
+# TODO:
+# z index may be different between Mock and Infected
+# Don't normalize prediction, only target
+
+# %%
+old_a549_h2b_database = database[
+    (database["Organelle"] == "HIST2H2BE") & (database["Cell type"] == "A549")
+]
+crops_per_fov = 4
+
+old_root = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_04_17_A549_H2B_CAAX_DENV/2-assemble/2025_04_17_A549_H2B_CAAX_DENV.zarr")
+new_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/crops/2025_04_17_A549_H2B_CAAX_DENV.zarr/")
+rows = []
+for idx, row in old_a549_h2b_database.iterrows():
+    for crop_idx in range(crops_per_fov):
+        _row = row.copy()
+        fov = f'{_row["FOV"]:0>6}{crop_idx}'
+        _row["Path"] = replace_root(_row["Path"], old_root, new_root) + str(crop_idx)
+        _row["FOV"] = fov
+        rows.append(_row)
+a549_h2b_database = pd.DataFrame(rows).reset_index(drop=True)
+
+
+old_hek_h2b_database = database[
+    (database["Organelle"] == "HIST2H2BE") & (database["Cell type"] == "HEK293T")
+]
+crops_per_fov = 6
+old_root = Path("/hpc/projects/comp.micro/mantis/mantis_paper_data_release/figure_4.zarr")
+new_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr")
+rows = []
+for idx, row in old_hek_h2b_database.iterrows():
+    for crop_idx in range(crops_per_fov):
+        _row = row.copy()
+        fov = f'{_row["FOV"]:0>6}{crop_idx}'
+        _row["Path"] = replace_root(_row["Path"], old_root, new_root) + str(crop_idx)
+        _row["FOV"] = fov
+        rows.append(_row)
+hek_h2b_database = pd.DataFrame(rows).reset_index(drop=True)
+
+# %% Compute VSCyto3D intensity-based metrics
+print("\nComputing HEK293T VSCyto3D intensity-based metrics...")
+
+# HEK293T cells - Mock condition only
+metrics = compute_metrics(
+    metrics_module=IntensityMetrics(),
+    cell_types=["HEK293T"],
+    organelles=["HIST2H2BE"],
+    infection_conditions=["Mock"],
+    target_database=hek_h2b_database,
+    target_channel_name="Organelle",
+    prediction_database=hek_h2b_database,
+    prediction_channel_name="Nuclei-prediction",
+    log_output_dir=output_dir,
+    log_name="intensity_VSCyto3D_HEK293T_nuclei_mock",
+    z_slice=16,
+    num_workers=8,
+    use_gpu=True,
+    transforms=[
+        NormalizeIntensityd(
+            keys=["pred", "target"],
+        )
+    ],
+)
+
+metrics = compute_metrics(
+    metrics_module=IntensityMetrics(),
+    cell_types=["HEK293T"],
+    organelles=["HIST2H2BE"],
+    infection_conditions=["Mock"],
+    target_database=hek_h2b_database,
+    target_channel_name="Membrane",
+    prediction_database=hek_h2b_database,
+    prediction_channel_name="Membrane-prediction",
+    log_output_dir=output_dir,
+    log_name="intensity_VSCyto3D_HEK293T_membrane_mock",
+    z_slice=16,
+    num_workers=8,
+    use_gpu=True,
+    transforms=[
+        NormalizeIntensityd(
+            keys=["pred", "target"],
+        )
+    ],
+)
+
+# %% A549 cells - Both Mock and DENV conditions in single calls
+# Note: unlike the full-FOV script, these metrics are computed on the cropped FOVs constructed above
+print("Computing A549 nuclei metrics for Mock and DENV conditions...")
+metrics = compute_metrics(
+    metrics_module=IntensityMetrics(),
+    cell_types=["A549"],
+    organelles=["HIST2H2BE"],
+    infection_conditions=["Mock", "DENV"],  # Multiple conditions in single call
+    target_database=a549_h2b_database,
+    target_channel_name="raw Cy5 EX639 EM698-70",
+    prediction_database=a549_h2b_database,
+    prediction_channel_name="nuclei_prediction",
+    log_output_dir=output_dir,
+    log_name="intensity_VSCyto3D_A549_nuclei_mock_denv",
+    z_slice=16,
+    num_workers=8,  # Use parallel processing for speed
+    use_gpu=True,
+    transforms=[
+        NormalizeIntensityd(
+            keys=["pred", "target"],
+        )
+    ],
+)
+
+print("Computing A549 membrane metrics for Mock and DENV conditions...")
+metrics = compute_metrics(
+    metrics_module=IntensityMetrics(),
+    cell_types=["A549"],
+    organelles=["HIST2H2BE"],
+    infection_conditions=["Mock", "DENV"],  # Multiple conditions in single call
+    target_database=a549_h2b_database,
+    target_channel_name="raw mCherry EX561 EM600-37",
+    prediction_database=a549_h2b_database,
+    prediction_channel_name="membrane_prediction",
+    log_output_dir=output_dir,
+    log_name="intensity_VSCyto3D_A549_membrane_mock_denv",
+    z_slice=16,
+    num_workers=8,  # Use parallel processing for speed
+    use_gpu=True,
+    transforms=[
+        NormalizeIntensityd(
+            keys=["pred", "target"],
+        )
+    ],
+)
+
+# %% Compute CellDiff intensity metrics
+# Construct required databases - cropped targets and CellDiff predictions are stored under separate roots
+
+# HEK293T cells - Mock condition only
+target_database = hek_h2b_database
+pred_database = hek_h2b_database.copy()
+old_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr")
+new_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/prediction/hek/output.zarr")
+pred_database["Path"] = pred_database["Path"].apply(partial(replace_root, old_root=old_root, new_root=new_root))
+
+metrics = compute_metrics(
+    metrics_module=IntensityMetrics(),
+    cell_types=["HEK293T"],
+    organelles=["HIST2H2BE"],
+    infection_conditions=["Mock"],
+    target_database=target_database,
+    target_channel_name="Organelle",
+    prediction_database=pred_database,  # CellDiff predictions are in a separate store (paths remapped above)
+    prediction_channel_name="Nuclei-prediction",
+    log_output_dir=output_dir,
+    log_name="intensity_CellDiff_HEK293T_nuclei_mock",
+    z_slice=16,
+    num_workers=8,
+    use_gpu=True,
+    transforms=[
+        NormalizeIntensityd(
+            keys=["pred", "target"],
+        )
+    ],
+)
+
+metrics = compute_metrics(
+    metrics_module=IntensityMetrics(),
+    cell_types=["HEK293T"],
+    organelles=["HIST2H2BE"],
+    infection_conditions=["Mock"],
+    target_database=target_database,
+    target_channel_name="Membrane",
+    prediction_database=pred_database,
+    prediction_channel_name="Membrane-prediction",
+    log_output_dir=output_dir,
+    log_name="intensity_CellDiff_HEK293T_membrane_mock",
+    z_slice=16,
+    num_workers=8,
+    use_gpu=True,
+    transforms=[
+        NormalizeIntensityd(
+            keys=["pred", "target"],
+        )
+    ],
+)
+
+# %%
+# A549 cells - Both Mock and DENV conditions
+
+target_database = a549_h2b_database
+pred_database = a549_h2b_database.copy()
+old_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/crops/2025_04_17_A549_H2B_CAAX_DENV.zarr/")
+new_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr")
+pred_database["Path"] = pred_database["Path"].apply(partial(replace_root, old_root=old_root, new_root=new_root))
+
+print("\nComputing CellDiff A549 nuclei metrics for Mock and DENV conditions...")
+
+metrics = compute_metrics(
+    metrics_module=IntensityMetrics(),
+    cell_types=["A549"],
+    organelles=["HIST2H2BE"],
+    infection_conditions=["Mock", "DENV"],  # Multiple conditions in single call
+    target_database=target_database,
+    
target_channel_name="raw Cy5 EX639 EM698-70", + prediction_database=pred_database, + prediction_channel_name="Nuclei-prediction", + log_output_dir=output_dir, + log_name="intensity_CellDiff_A549_nuclei_mock_denv", + z_slice=16, + num_workers=8, + use_gpu=True, + transforms=[ + NormalizeIntensityd( + keys=["pred", "target"], + ) + ], +) + +metrics = compute_metrics( + metrics_module=IntensityMetrics(), + cell_types=["A549"], + organelles=["HIST2H2BE"], + infection_conditions=["Mock", "DENV"], # Multiple conditions in single call + target_database=target_database, + target_channel_name="raw mCherry EX561 EM600-37", + prediction_database=pred_database, + prediction_channel_name="Membrane-prediction", + log_output_dir=output_dir, + log_name="intensity_CellDiff_A549_membrane_mock_denv", + z_slice=16, + num_workers=8, + use_gpu=True, + transforms=[ + NormalizeIntensityd( + keys=["pred", "target"], + ) + ], +) + +# %% +# a549_h2b_database = database[ +# (database["Organelle"] == "HIST2H2BE") & (database["Cell type"] == "A549") +# ] +# _database = a549_h2b_database.copy() +# pred_database = a549_h2b_database.copy() +# target_database = a549_h2b_database.copy() + +# old_root = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_04_17_A549_H2B_CAAX_DENV/2-assemble/2025_04_17_A549_H2B_CAAX_DENV.zarr") +# target_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/crops/a549/2025_04_17_A549_H2B_CAAX_DENV.zarr") +# _database["Path"] = a549_h2b_database["Path"].apply(partial(replace_root, new_root=target_root)) + +# metrics = compute_metrics( +# metrics_module=IntensityMetrics(), +# cell_types=["A549"], +# organelles=["HIST2H2BE"], +# infection_conditions=["Mock", "DENV"], # Multiple conditions in single call +# target_database=_database, +# target_channel_name="raw Cy5 EX639 EM698-70", +# prediction_database=_database, +# prediction_channel_name="nuclei_prediction", +# log_output_dir=output_dir, +# log_name="intensity_VSCyto3D_A549_cropped_nuclei_mock_denv", +# z_slice=13, +# num_workers=8, # Use parallel processing for speed +# use_gpu=True, +# transforms=[ +# NormalizeIntensityd( +# keys=["pred", "target"], +# ) +# ], +# ) + +# print("Computing A549 membrane metrics for Mock and DENV conditions...") +# metrics = compute_metrics( +# metrics_module=IntensityMetrics(), +# cell_types=["A549"], +# organelles=["HIST2H2BE"], +# infection_conditions=["Mock", "DENV"], # Multiple conditions in single call +# target_database=_database, +# target_channel_name="raw mCherry EX561 EM600-37", +# prediction_database=_database, +# prediction_channel_name="membrane_prediction", +# log_output_dir=output_dir, +# log_name="intensity_VSCyto3D_A549_cropped_membrane_mock_denv", +# z_slice=13, +# num_workers=8, # Use parallel processing for speed +# use_gpu=True, +# transforms=[ +# NormalizeIntensityd( +# keys=["pred", "target"], +# ) +# ], +# ) + +# pred_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr") +# target_root = Path("/hpc/projects/virtual_staining/datasets/huang-lab/crops/a549/2025_04_17_A549_H2B_CAAX_DENV.zarr") +# pred_database["Path"] = pred_database["Path"].apply(partial(replace_root, new_root=pred_root)) +# target_database["Path"] = target_database["Path"].apply(partial(replace_root, new_root=target_root)) + +# print("Computing CellDiff A549 nuclei metrics for Mock and DENV conditions...") +# metrics = compute_metrics( +# metrics_module=IntensityMetrics(), +# cell_types=["A549"], +# organelles=["HIST2H2BE"], +# 
infection_conditions=["Mock", "DENV"], # Multiple conditions in single call +# target_database=target_database, +# target_channel_name="raw Cy5 EX639 EM698-70", +# prediction_database=pred_database, +# prediction_channel_name="Nuclei-prediction", +# log_output_dir=output_dir, +# log_name="intensity_CellDiff_A549_nuclei_mock_denv", +# z_slice=13, +# num_workers=8, # Use parallel processing for speed +# use_gpu=True, +# transforms=[ +# NormalizeIntensityd( +# keys=["pred", "target"], +# ) +# ], +# ) + +# print("Computing CellDiff A549 membrane metrics for Mock and DENV conditions...") +# metrics = compute_metrics( +# metrics_module=IntensityMetrics(), +# cell_types=["A549"], +# organelles=["HIST2H2BE"], +# infection_conditions=["Mock", "DENV"], # Multiple conditions in single call +# target_database=target_database, +# target_channel_name="raw mCherry EX561 EM600-37", +# prediction_database=pred_database, +# prediction_channel_name="Membrane-prediction", +# log_output_dir=output_dir, +# log_name="intensity_CellDiff_A549_membrane_mock_denv", +# z_slice=13, +# num_workers=8, # Use parallel processing for speed +# use_gpu=True, +# transforms=[ +# NormalizeIntensityd( +# keys=["pred", "target"], +# ) +# ], +# ) + +# %% diff --git a/applications/DynaCell/demo_plotter.py b/applications/DynaCell/demo_plotter.py new file mode 100644 index 000000000..950c40ee6 --- /dev/null +++ b/applications/DynaCell/demo_plotter.py @@ -0,0 +1,141 @@ + +#%% +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + + +def plot_segmentation_metrics(csv_path, metrics=['dice', 'jaccard', 'mAP_50'], + cell_type=None, infection_condition=None, organelle=None): + """ + Plot segmentation metrics over time with filtering options. + + Parameters + ---------- + csv_path : str or Path + Path to the CSV file containing segmentation metrics data + metrics : list of str, default ['dice', 'jaccard', 'mAP_50'] + List of metric column names to plot from the CSV + cell_type : str, optional + Filter by cell type + infection_condition : str, optional + Filter by infection condition + organelle : str, optional + Filter by organelle + """ + # Load data + df = pd.read_csv(csv_path) + df_clean = df.dropna() + + # Apply filters + filtered_data = df_clean.copy() + if cell_type: + filtered_data = filtered_data[filtered_data['cell_type'] == cell_type] + if infection_condition: + filtered_data = filtered_data[filtered_data['infection_condition'] == infection_condition] + if organelle: + filtered_data = filtered_data[filtered_data['organelle'] == organelle] + + if filtered_data.empty: + print("No data found with the specified filters.") + return + + # Set up colorblind-friendly colors (blue and orange) + plt.style.use('seaborn-v0_8') + colors = ['#1f77b4', '#ff7f0e'] # Blue and orange + sns.set_palette(colors) + + # Validate metrics exist in data + available_metrics = [m for m in metrics if m in df_clean.columns] + if not available_metrics: + print(f"None of the requested metrics {metrics} found in data.") + print(f"Available columns: {list(df_clean.columns)}") + return + + # Get unique conditions for plotting + conditions = filtered_data['infection_condition'].unique() + + # Create individual plots for each metric + for metric in available_metrics: + fig, ax = plt.subplots(figsize=(12, 6)) + + # Plot each condition + for i, condition in enumerate(conditions): + condition_data = filtered_data[filtered_data['infection_condition'] == condition] + + # Plot individual trajectories with transparency + 
for pos in condition_data['position'].unique(): + pos_data = condition_data[condition_data['position'] == pos] + ax.plot(pos_data['time'], pos_data[metric], + alpha=0.3, linewidth=1, color=colors[i % len(colors)]) + + # Plot mean trend line + sns.lineplot(data=condition_data, x='time', y=metric, + ax=ax, label=condition, linewidth=2, marker='o', + markersize=4, color=colors[i % len(colors)]) + + # Customize plot + filter_info = [] + if cell_type: + filter_info.append(f"Cell: {cell_type}") + if infection_condition: + filter_info.append(f"Condition: {infection_condition}") + if organelle: + filter_info.append(f"Organelle: {organelle}") + + title = f'{metric} Over Time' + if filter_info: + title += f" ({', '.join(filter_info)})" + + ax.set_title(title, fontsize=14, fontweight='bold') + ax.set_xlabel('Time', fontsize=12) + ax.set_ylabel(metric, fontsize=12) + ax.grid(True, alpha=0.3) + + if len(conditions) > 1: + ax.legend() + + # Set appropriate y-axis limits + y_min = filtered_data[metric].min() + y_max = filtered_data[metric].max() + padding = (y_max - y_min) * 0.05 if y_max > y_min else 0.05 + ax.set_ylim(max(0, y_min - padding), min(1, y_max + padding)) + + plt.tight_layout() + + # Save plot + filter_suffix = "" + if cell_type or infection_condition or organelle: + filter_parts = [] + if cell_type: + filter_parts.append(cell_type) + if infection_condition: + filter_parts.append(infection_condition) + if organelle: + filter_parts.append(organelle) + filter_suffix = f"_{'_'.join(filter_parts)}" + + output_path = Path(csv_path).parent / f'{metric}_over_time{filter_suffix}.png' + plt.savefig(output_path, dpi=300, bbox_inches='tight') + print(f"Saved {metric} plot to: {output_path}") + + plt.show() + +# Usage +if __name__ == "__main__": + + # Example for segmentation metrics + segmentation_csv = "/home/eduardo.hirata/repos/viscy/applications/DynaCell/metrics/segmentation_2025_04_17_A549_H2B_CAAX_DENV_membrane_only/20250731_131927/metrics.csv" + + # Plot default segmentation metrics (dice, jaccard, mAP_50) + plot_segmentation_metrics(segmentation_csv) + + # Plot custom metrics with filters + plot_segmentation_metrics(segmentation_csv, + metrics=['dice', 'jaccard', 'mAP_50', 'accuracy'], + cell_type='A549', + infection_condition='DENV', + organelle='HIST2H2BE') +# %% diff --git a/applications/DynaCell/plot_virtual_staining_metrics.py b/applications/DynaCell/plot_virtual_staining_metrics.py new file mode 100644 index 000000000..5ba8e5bef --- /dev/null +++ b/applications/DynaCell/plot_virtual_staining_metrics.py @@ -0,0 +1,160 @@ +# %% +from pathlib import Path + + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + + +csv_path = Path("~/Documents/dynacell/metrics/virtual_staining/master_metrics.csv").expanduser() +df = pd.read_csv(csv_path) + +hek_time_step = 15 # in minutes +a549_time_step = 10 # in minutes + +# Add a new column to df for time in hours, depending on cell_type +def compute_time_hours(row): + if row["cell_type"].lower() == "hek293t": + return row["time"] * hek_time_step / 60 + elif row["cell_type"].lower() == "a549": + return row["time"] * a549_time_step / 60 + else: + return np.nan + +df["time_hours"] = df.apply(compute_time_hours, axis=1) + +# %% Group by model - VSCyto3D and CellDiff, and plot pearson over time +# for each cell line, organelle, and infection condition + +# Make sure the plots directory exists +plots_dir = csv_path.parent / "plots" +plots_dir.mkdir(exist_ok=True) + +# Group by cell line and organelle, plot both infection conditions in one 
plot +# Define a consistent color mapping for models +model_colors = { + "VSCyto3D": "#1f77b4", # blue + "CellDiff": "#ff7f0e", # orange + # Add more models here if needed +} + +# %% +# Plot each group individually in a 2x3 grid of subplots + +# Get all unique combinations of cell_type, organelle, infection_condition +# Keep nuclei on the top row and membrane on the bottom row +groups_nuclei = [g for g in df.groupby(["cell_type", "organelle", "infection_condition"]).groups.keys() if g[1].lower() == "nuclei"] +groups_membrane = [g for g in df.groupby(["cell_type", "organelle", "infection_condition"]).groups.keys() if g[1].lower() == "membrane"] +groups = groups_nuclei + groups_membrane + +n_groups = len(groups) +n_rows, n_cols = 2, 3 +fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 10), sharex=True, sharey=True) +axes = axes.flatten() + +for idx, ((cell_type, organelle, infection_condition), ax) in enumerate(zip(groups, axes)): + group = df[ + (df["cell_type"] == cell_type) & + (df["organelle"] == organelle) & + (df["infection_condition"] == infection_condition) + ] + if group.empty: + ax.set_visible(False) + continue + for model, sub_group in group.groupby("model"): + color = model_colors.get(model, None) + ax.plot( + sub_group["time_hours"], + sub_group["pearson"], + label=model, + color=color, + marker="o" + ) + ax.set_title(f"{cell_type}, {organelle}, {infection_condition}") + ax.set_xlabel("Time [hours]") + ax.set_ylabel("Pearson") + ax.set_ylim(-0.1, 0.85) + ax.legend() + +# Hide any unused subplots +for j in range(idx + 1, n_rows * n_cols): + axes[j].set_visible(False) + +plt.tight_layout() +# plt.show() +plt.savefig(plots_dir / "virtual_staining_metrics_grid.png") +plt.close() + +# %% Compare mock and infected conditions +for (cell_line, organelle), group in df.groupby(["cell_type", "organelle"]): + if not group.empty: + plt.figure() + for (infection_condition, model), sub_group in group.groupby(["infection_condition", "model"]): + if infection_condition.lower() == "mock": + linestyle = "-" + else: + linestyle = "--" + color = model_colors.get(model, None) + plt.plot( + sub_group["time_hours"], + sub_group["pearson"], + label=f"{model} ({infection_condition})", + linestyle=linestyle, + color=color, + ) + plt.ylim(-0.1, 0.85) + plt.legend() + plt.title(f"{cell_line} {organelle}") + plt.xlabel("Time [hours]") + plt.ylabel("Pearson Cross-Correlation Coefficient") + plt.savefig(plots_dir / f"{cell_line}_{organelle}_all_conditions.png") + plt.close() + +# %% Compare cell lines +# Plot for nuclei and membrane, comparing A549 (mock only) and HEK293T (all conditions), for both models + +for organelle in ["nuclei", "membrane"]: + plt.figure(figsize=(8, 6)) + for model in model_colors.keys(): + # A549, mock only (solid line) + group_a549 = df[ + (df["cell_type"] == "A549") & + (df["organelle"].str.lower() == organelle) & + (df["infection_condition"].str.lower() == "mock") & + (df["model"] == model) + ] + if not group_a549.empty: + plt.plot( + group_a549["time_hours"], + group_a549["pearson"], + label=f"A549 ({model})", + color=model_colors[model], + linestyle="-" + ) + # HEK293T, mock only (dashed line) + group_hek_mock = df[ + (df["cell_type"] == "HEK293T") & + (df["organelle"].str.lower() == organelle) & + (df["infection_condition"].str.lower() == "mock") & + (df["model"] == model) + ] + if not group_hek_mock.empty: + plt.plot( + group_hek_mock["time_hours"], + group_hek_mock["pearson"], + label=f"HEK293T ({model})", + color=model_colors[model], + linestyle="--" + ) + plt.ylim(-0.1, 
0.85) + plt.xlabel("Time [hours]") + plt.ylabel("Pearson Cross-Correlation Coefficient") + plt.title(f"Virtual Staining: {organelle.capitalize()} (A549 mock vs HEK293T mock)") + plt.legend() + plt.tight_layout() + # plt.show() + plt.savefig(plots_dir / f"compare_A549_HEK293T_{organelle}.png") + plt.close() + +# %% diff --git a/applications/DynaCell/plot_virtual_staining_metrics_cropped.py b/applications/DynaCell/plot_virtual_staining_metrics_cropped.py new file mode 100644 index 000000000..aa494870d --- /dev/null +++ b/applications/DynaCell/plot_virtual_staining_metrics_cropped.py @@ -0,0 +1,319 @@ +# %% +from pathlib import Path + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns + +root = Path("/home/ivan.ivanov/Documents/dynacell/metrics/virtual_staining") +a549_VSCyto3D_csv_path = root / "intensity_VSCyto3D_A549_nuclei_mock_denv/20250916_171028/metrics.csv" +hek_VSCyto3D_csv_path = root / "intensity_VSCyto3D_HEK293T_nuclei_mock/20250916_170923/metrics.csv" +a549_CellDiff_csv_path = root / "intensity_CellDiff_A549_nuclei_mock_denv/20251002_092905/metrics.csv" +hek_CellDiff_csv_path = root / "intensity_CellDiff_HEK293T_nuclei_mock/20251002_092333/metrics.csv" +a549_VSCyto3D_df = pd.read_csv(a549_VSCyto3D_csv_path) +hek_VSCyto3D_df = pd.read_csv(hek_VSCyto3D_csv_path) +a549_VSCyto3D_df["Model"] = "VSCyto3D" +hek_VSCyto3D_df["Model"] = "VSCyto3D" +a549_CellDiff_df = pd.read_csv(a549_CellDiff_csv_path) +hek_CellDiff_df = pd.read_csv(hek_CellDiff_csv_path) +a549_CellDiff_df["Model"] = "CellDiff" +hek_CellDiff_df["Model"] = "CellDiff" +df = pd.concat([a549_VSCyto3D_df, hek_VSCyto3D_df, a549_CellDiff_df, hek_CellDiff_df]) + +plots_dir = root / "plots" +plots_dir.mkdir(exist_ok=True) + +hek_time_step = 15 # in minutes +a549_time_step = 10 # in minutes + +# Add a new column to df for time in hours, depending on cell_type +def compute_time_hours(row): + if row["cell_type"].lower() == "hek293t": + return row["time"] * hek_time_step / 60 + elif row["cell_type"].lower() == "a549": + return row["time"] * a549_time_step / 60 + else: + return np.nan + +df["time_hours"] = df.apply(compute_time_hours, axis=1) +df["original_position_name"] = df["position_name"].str[:-1] +df["crop_index"] = df["position_name"].str[-1] + +# %% A549 VSCyto3D plots +_df_mock = df[(df["infection_condition"] == "Mock") & (df["cell_type"] == "A549") & (df["Model"] == "VSCyto3D")] +_df_mock_mean_vscyto3d = _df_mock.groupby("time")[["pearson", "ssim", "time_hours"]].mean() + +plt.figure() +for idx, sub_df in _df_mock.groupby("position_name"): + plt.plot(sub_df["time_hours"], sub_df["pearson"], label=f"FOV {idx}") +plt.plot(_df_mock_mean_vscyto3d["time_hours"], _df_mock_mean_vscyto3d["pearson"], 'k-', label="Average") +plt.xlabel("Time [hours]") +plt.ylabel("Pearson") +plt.title("A549 Mock VSCyto3D Nuclei") +# plt.legend() +plt.savefig(plots_dir / "A549_Mock_VSCyto3D_Nuclei_pearson.png") + +plt.figure() +for idx, sub_df in _df_mock.groupby("position_name"): + plt.plot(sub_df["time_hours"], sub_df["ssim"], label=f"FOV {idx}") +plt.plot(_df_mock_mean_vscyto3d["time_hours"], _df_mock_mean_vscyto3d["ssim"], 'k-', label="Average") +plt.xlabel("Time [hours]") +plt.ylabel("SSIM") +plt.title("A549 Mock VSCyto3D Nuclei") +# plt.legend() +plt.show() +# plt.savefig(plots_dir / "A549_Mock_VSCyto3D_Nuclei_ssim.png") + + +_df_denv = df[(df["infection_condition"] == "DENV") & (df["Model"] == "VSCyto3D")] +_df_denv_mean_vscyto3d = _df_denv.groupby("time")[["pearson", "ssim", "time_hours"]].mean() + 
+plt.figure() +for idx, sub_df in _df_denv.groupby("position_name"): + plt.plot(sub_df["time_hours"], sub_df["pearson"], label=f"FOV {idx}") +plt.plot(_df_denv_mean_vscyto3d["time_hours"], _df_denv_mean_vscyto3d["pearson"], 'k-', label="Average") +plt.xlabel("Time [hours]") +plt.ylabel("Pearson") +plt.title("A549 DENV VSCyto3D Nuclei") +# plt.legend() +plt.show() +# plt.savefig(plots_dir / "A549_DENV_VSCyto3D_Nuclei_pearson.png") + + +plt.figure() +for idx, sub_df in _df_denv.groupby("position_name"): + plt.plot(sub_df["time_hours"], sub_df["ssim"], label=f"FOV {idx}") +plt.plot(_df_denv_mean_vscyto3d["time_hours"], _df_denv_mean_vscyto3d["ssim"], 'k-', label="Average") +plt.xlabel("Time [hours]") +plt.ylabel("SSIM") +plt.title("A549 DENV VSCyto3D Nuclei") +# plt.legend() +plt.show() +# plt.savefig(plots_dir / "A549_DENV_VSCyto3D_Nuclei_ssim.png") + + +# %% Compare A549 Mock and DENV +plt.figure() +plt.plot(_df_mock_mean_vscyto3d["time_hours"], _df_mock_mean_vscyto3d["pearson"], label="Mock") +plt.plot(_df_denv_mean_vscyto3d["time_hours"], _df_denv_mean_vscyto3d["pearson"], label="DENV") +plt.xlabel("Time [hours]") +plt.ylabel("Pearson") +plt.title("A549 VSCyto3D Nuclei") +plt.legend() +plt.show() +# plt.savefig(plots_dir / "A549_VSCyto3D_Nuclei_pearson.png") + + +plt.figure() +plt.plot(_df_mock_mean_vscyto3d["time_hours"], _df_mock_mean_vscyto3d["ssim"], label="Mock") +plt.plot(_df_denv_mean_vscyto3d["time_hours"], _df_denv_mean_vscyto3d["ssim"], label="DENV") +plt.xlabel("Time [hours]") +plt.ylabel("SSIM") +plt.title("A549 VSCyto3D Nuclei") +plt.legend() +plt.show() +# plt.savefig(plots_dir / "A549_VSCyto3D_Nuclei_ssim.png") + + +# %% A549 CellDiff plots +_df_mock = df[(df["infection_condition"] == "Mock") & (df["cell_type"] == "A549") & (df["Model"] == "CellDiff")] +_df_mock_mean_celldiff = _df_mock.groupby("time")[["pearson", "ssim", "time_hours"]].mean() + +plt.figure() +for idx, sub_df in _df_mock.groupby("position_name"): + plt.plot(sub_df["time_hours"], sub_df["pearson"], label=f"FOV {idx}") +plt.plot(_df_mock_mean_celldiff["time_hours"], _df_mock_mean_celldiff["pearson"], 'k-', label="Average") +plt.xlabel("Time [hours]") +plt.ylabel("Pearson") +plt.title("A549 Mock CellDiff Nuclei") +# plt.legend() +# plt.show() +plt.savefig(plots_dir / "A549_Mock_CellDiff_Nuclei_pearson.png") + +plt.figure() +for idx, sub_df in _df_mock.groupby("position_name"): + plt.plot(sub_df["time_hours"], sub_df["ssim"], label=f"FOV {idx}") +plt.plot(_df_mock_mean_celldiff["time_hours"], _df_mock_mean_celldiff["ssim"], 'k-', label="Average") +plt.xlabel("Time [hours]") +plt.ylabel("SSIM") +plt.title("A549 Mock CellDiff Nuclei") +# plt.legend() +# plt.show() +plt.savefig(plots_dir / "A549_Mock_CellDiff_Nuclei_ssim.png") + +_df_denv = df[(df["infection_condition"] == "DENV") & (df["cell_type"] == "A549") & (df["Model"] == "CellDiff")] +_df_denv_mean_celldiff = _df_denv.groupby("time")[["pearson", "ssim", "time_hours"]].mean() + +plt.figure() +for idx, sub_df in _df_denv.groupby("position_name"): + plt.plot(sub_df["time_hours"], sub_df["pearson"], label=f"FOV {idx}") +plt.plot(_df_denv_mean_celldiff["time_hours"], _df_denv_mean_celldiff["pearson"], 'k-', label="Average") +plt.xlabel("Time [hours]") +plt.ylabel("Pearson") +plt.title("A549 DENV CellDiff Nuclei") +# plt.legend() +# plt.show() +plt.savefig(plots_dir / "A549_DENV_CellDiff_Nuclei_pearson.png") + + +plt.figure() +for idx, sub_df in _df_denv.groupby("position_name"): + plt.plot(sub_df["time_hours"], sub_df["ssim"], label=f"FOV {idx}") 
+plt.plot(_df_denv_mean_celldiff["time_hours"], _df_denv_mean_celldiff["ssim"], 'k-', label="Average")
+plt.xlabel("Time [hours]")
+plt.ylabel("SSIM")
+plt.title("A549 DENV CellDiff Nuclei")
+# plt.legend()
+# plt.show()
+plt.savefig(plots_dir / "A549_DENV_CellDiff_Nuclei_ssim.png")
+
+# %% Compare A549 Mock and DENV
+plt.figure()
+plt.plot(_df_mock_mean_celldiff["time_hours"], _df_mock_mean_celldiff["pearson"], label="Mock")
+plt.plot(_df_denv_mean_celldiff["time_hours"], _df_denv_mean_celldiff["pearson"], label="DENV")
+plt.xlabel("Time [hours]")
+plt.ylabel("Pearson")
+plt.title("A549 CellDiff Nuclei")
+plt.legend()
+# plt.show()
+plt.savefig(plots_dir / "A549_CellDiff_Nuclei_pearson.png")
+
+plt.figure()
+plt.plot(_df_mock_mean_celldiff["time_hours"], _df_mock_mean_celldiff["ssim"], label="Mock")
+plt.plot(_df_denv_mean_celldiff["time_hours"], _df_denv_mean_celldiff["ssim"], label="DENV")
+plt.xlabel("Time [hours]")
+plt.ylabel("SSIM")
+plt.title("A549 CellDiff Nuclei")
+plt.legend()
+# plt.show()
+plt.savefig(plots_dir / "A549_CellDiff_Nuclei_ssim.png")
+
+# %% Compare mean pearson and ssim for all models
+# Color encodes the model, linestyle encodes the condition (Mock solid,
+# DENV dashed), so all four curves remain distinguishable.
+plt.figure()
+plt.plot(_df_mock_mean_vscyto3d["time_hours"], _df_mock_mean_vscyto3d["pearson"], label="Mock VSCyto3D", color="blue")
+plt.plot(_df_mock_mean_celldiff["time_hours"], _df_mock_mean_celldiff["pearson"], label="Mock CellDiff", color="orange")
+plt.plot(_df_denv_mean_vscyto3d["time_hours"], _df_denv_mean_vscyto3d["pearson"], label="DENV VSCyto3D", linestyle="--", color="blue")
+plt.plot(_df_denv_mean_celldiff["time_hours"], _df_denv_mean_celldiff["pearson"], label="DENV CellDiff", linestyle="--", color="orange")
+plt.xlabel("Time [hours]")
+plt.ylabel("Pearson")
+plt.title("A549 Nuclei")
+plt.legend()
+# plt.show()
+plt.savefig(plots_dir / "A549_Nuclei_pearson.png")
+
+plt.figure()
+plt.plot(_df_mock_mean_vscyto3d["time_hours"], _df_mock_mean_vscyto3d["ssim"], label="Mock VSCyto3D", color="blue")
+plt.plot(_df_mock_mean_celldiff["time_hours"], _df_mock_mean_celldiff["ssim"], label="Mock CellDiff", color="orange")
+plt.plot(_df_denv_mean_vscyto3d["time_hours"], _df_denv_mean_vscyto3d["ssim"], label="DENV VSCyto3D", linestyle="--", color="blue")
+plt.plot(_df_denv_mean_celldiff["time_hours"], _df_denv_mean_celldiff["ssim"], label="DENV CellDiff", linestyle="--", color="orange")
+plt.xlabel("Time [hours]")
+plt.ylabel("SSIM")
+plt.title("A549 Nuclei")
+plt.legend()
+# plt.show()
+plt.savefig(plots_dir / "A549_Nuclei_ssim.png")
+
+
+# %% HEK293T VSCyto3D plots
+_df_mock = df[(df["infection_condition"] == "Mock") & (df["cell_type"] == "HEK293T") & (df["Model"] == "VSCyto3D")]
+_df_hek_mean_vscyto3d = _df_mock.groupby("time")[["pearson", "ssim", "time_hours"]].mean()
+
+plt.figure()
+for idx, sub_df in _df_mock.groupby("position_name"):
+    plt.plot(sub_df["time_hours"], sub_df["pearson"], label=f"FOV {idx}")
+plt.plot(_df_hek_mean_vscyto3d["time_hours"], _df_hek_mean_vscyto3d["pearson"], 'k-', label="Average")
+plt.xlabel("Time [hours]")
+plt.ylabel("Pearson")
+plt.title("HEK293T Mock VSCyto3D Nuclei")
+# plt.legend()
+plt.show()
+# plt.savefig(plots_dir / "HEK293T_Mock_VSCyto3D_Nuclei_pearson.png")
+
+
+plt.figure()
+for idx, sub_df in _df_mock.groupby("position_name"):
+    plt.plot(sub_df["time_hours"], sub_df["ssim"], label=f"FOV {idx}")
+plt.plot(_df_hek_mean_vscyto3d["time_hours"], _df_hek_mean_vscyto3d["ssim"], 'k-', label="Average")
+plt.xlabel("Time [hours]")
+plt.ylabel("SSIM")
+plt.title("HEK293T Mock VSCyto3D Nuclei")
+# plt.legend()
+plt.show()
/ "HEK293T_Mock_VSCyto3D_Nuclei_ssim.png") + +# %% HEK293T CellDiff plots +_df_mock = df[(df["infection_condition"] == "Mock") & (df["cell_type"] == "HEK293T") & (df["Model"] == "CellDiff")] +_df_hek_mean_celldiff = _df_mock.groupby("time")[["pearson", "ssim", "time_hours"]].mean() + +plt.figure() +for idx, sub_df in _df_mock.groupby("position_name"): + plt.plot(sub_df["time_hours"], sub_df["pearson"], label=f"FOV {idx}") +plt.plot(_df_hek_mean_celldiff["time_hours"], _df_hek_mean_celldiff["pearson"], 'k-', label="Average") +plt.xlabel("Time [hours]") +plt.ylabel("Pearson") +plt.title("HEK293T Mock CellDiff Nuclei") +# plt.legend() +# plt.show() +plt.savefig(plots_dir / "HEK293T_Mock_CellDiff_Nuclei_pearson.png") + +plt.figure() +for idx, sub_df in _df_mock.groupby("position_name"): + plt.plot(sub_df["time_hours"], sub_df["ssim"], label=f"FOV {idx}") +plt.plot(_df_hek_mean_celldiff["time_hours"], _df_hek_mean_celldiff["ssim"], 'k-', label="Average") +plt.xlabel("Time [hours]") +plt.ylabel("SSIM") +plt.title("HEK293T Mock CellDiff Nuclei") +# plt.legend() +# plt.show() +plt.savefig(plots_dir / "HEK293T_Mock_CellDiff_Nuclei_ssim.png") + +# %% Compare mean metrics for A549 and HEK293T +plt.figure() +plt.plot(_df_mock_mean_vscyto3d["time_hours"], _df_mock_mean_vscyto3d["pearson"], label="A549 Mock VSCyto3D", color="blue") +plt.plot(_df_mock_mean_celldiff["time_hours"], _df_mock_mean_celldiff["pearson"], label="A549 Mock CellDiff", linestyle="--", color="orange") +plt.plot(_df_hek_mean_vscyto3d["time_hours"], _df_hek_mean_vscyto3d["pearson"], label="HEK293T Mock VSCyto3D", color="green") +plt.plot(_df_hek_mean_celldiff["time_hours"], _df_hek_mean_celldiff["pearson"], label="HEK293T Mock CellDiff", linestyle="--", color="green") +plt.xlabel("Time [hours]") +plt.ylabel("Pearson") +plt.title("A549 and HEK293T Nuclei") +plt.legend() +# plt.show() +plt.savefig(plots_dir / "A549_HEK293T_Nuclei_pearson.png") + +plt.figure() +plt.plot(_df_mock_mean_vscyto3d["time_hours"], _df_mock_mean_vscyto3d["ssim"], label="A549 Mock VSCyto3D", color="blue") +plt.plot(_df_mock_mean_celldiff["time_hours"], _df_mock_mean_celldiff["ssim"], label="A549 Mock CellDiff", linestyle="--", color="orange") +plt.plot(_df_hek_mean_vscyto3d["time_hours"], _df_hek_mean_vscyto3d["ssim"], label="HEK293T Mock VSCyto3D", color="green") +plt.plot(_df_hek_mean_celldiff["time_hours"], _df_hek_mean_celldiff["ssim"], label="HEK293T Mock CellDiff", linestyle="--", color="green") +plt.xlabel("Time [hours]") +plt.ylabel("SSIM") +plt.title("A549 and HEK293T Nuclei") +plt.legend() +# plt.show() +plt.savefig(plots_dir / "A549_HEK293T_Nuclei_ssim.png") + +# %% Plot heatmap of pearson with time as columns and cell type as rows, averaging over position name +_df = df[(df["time_hours"] < 7) & (df["infection_condition"] == "Mock") & (df["Model"] == "VSCyto3D")] +# bin time_hours into 1 hour bins +_df["time_hours_binned"] = _df["time_hours"].round(0) + +plt.figure() +sns.heatmap(_df.groupby(["time_hours_binned", "cell_type"])["pearson"].mean().unstack(), cmap="Blues") +plt.ylabel("Time [hours]") +plt.xlabel("Cell Type") +plt.title("Pearson Heatmap") +plt.show() +# plt.savefig(plots_dir / "Pearson_Heatmap_Mock.png") + +plt.figure() +sns.heatmap(_df.groupby(["time_hours_binned", "cell_type"])["ssim"].mean().unstack(), cmap="Blues") +plt.ylabel("Time [hours]") +plt.xlabel("Cell Type") +plt.title("Pearson Heatmap") +plt.show() +# plt.savefig(plots_dir / "SSIM_Heatmap_Mock.png") + + + +# %% diff --git a/docs/usage.md b/docs/usage.md index 
229029a9b..947748c41 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -84,3 +84,49 @@ requires an exclusive node on HPC OR a non-distributed system (e.g. a PC). with a valid `config.yaml` in order to be initialized. This can be "hacked" by locating the config in a directory called `checkpoints` beneath a valid config's directory. + +## DynaCell Metrics + +Compute metrics on DynaCell datasets using the `compute_dynacell_metrics` command: + +```sh +viscy compute_dynacell_metrics -c config.yaml +``` + +### Configuration File Format + +Example configuration file: + +```yaml +# Required parameters +target_database: /path/to/target_database.csv +pred_database: /path/to/prediction_database.csv +output_dir: ./metrics_output +method: intensity # Options: 'intensity' or 'segmentation2D' + +# Optional parameters +target_channel: Organelle +pred_channel: Organelle +# Z-slice options: +# - Single integer (e.g., 16): Use specific z-slice +# - List of two integers [start, end] (e.g., [15, 17]): Use range of z-slices +# - -1: Use all available z-slices +target_z_slice: 16 +pred_z_slice: 16 +# You can also specify a range of z-slices: +# target_z_slice: [15, 17] # Use z-slices from 15 to 16 (exclusive of 17) +# pred_z_slice: [15, 17] # Use z-slices from 15 to 16 (exclusive of 17) +target_cell_types: [HEK293T] # or leave empty [] for all available +target_organelles: [HIST2H2BE] +target_infection_conditions: [Mock] +pred_cell_types: [HEK293T] +pred_organelles: [HIST2H2BE] +pred_infection_conditions: [Mock] +batch_size: 1 +num_workers: 0 +version: "1" +``` + +If cell types, organelles, or infection conditions are not specified or left empty, all available values from the respective database will be used. + +Using a z-slice range (e.g., `[15, 17]`) can be particularly useful for computing metrics on multiple consecutive z-slices, which is beneficial for 3D analysis or when working with volumes where the structures of interest span multiple z-slices. 
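+
+For reference, the accepted z-slice values map onto Python indexing as
+sketched below. This mirrors the normalization performed inside
+`compute_dynacell_metrics`; the helper name is illustrative, not part of the
+API:
+
+```python
+def normalize_z_slice(z: int | list[int]) -> int | slice:
+    """Map a config z-slice value to an integer index or a slice (sketch)."""
+    if isinstance(z, list) and len(z) == 2:
+        start, stop = z
+        # A length-1 range collapses to a single index
+        return start if stop - start == 1 else slice(start, stop)
+    # -1 selects all slices; any other integer selects one slice
+    return slice(None) if z == -1 else int(z)
+
+normalize_z_slice(16)        # -> 16
+normalize_z_slice([15, 17])  # -> slice(15, 17)
+normalize_z_slice(-1)        # -> slice(None)
+```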
diff --git a/examples/configs/dynacell_metrics_example.yml b/examples/configs/dynacell_metrics_example.yml new file mode 100644 index 000000000..13e5561d0 --- /dev/null +++ b/examples/configs/dynacell_metrics_example.yml @@ -0,0 +1,34 @@ +# Example configuration for DynaCell metrics computation + +# Required parameters +target_database: /path/to/target_database.csv +pred_database: /path/to/prediction_database.csv +output_dir: ./metrics_output +method: intensity # Options: 'intensity' or 'segmentation2D' + +# Target dataset parameters +target_channel: Organelle +# Z-slice can be a single integer (e.g., 16) or a range specified as a list of two integers [start, end] (e.g., [15, 17]) +target_z_slice: 16 # Use -1 for all slices, or a list like [15, 17] for a range +target_cell_types: + - HEK293T +target_organelles: + - HIST2H2BE +target_infection_conditions: + - Mock + +# Prediction dataset parameters +pred_channel: Organelle +# Z-slice can be a single integer (e.g., 16) or a range specified as a list of two integers [start, end] (e.g., [15, 17]) +pred_z_slice: 16 # Use -1 for all slices, or a list like [15, 17] for a range +pred_cell_types: + - HEK293T +pred_organelles: + - HIST2H2BE +pred_infection_conditions: + - Mock + +# Processing parameters +batch_size: 1 +num_workers: 0 +version: "1" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 32c3374dd..18c619e88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ metrics = [ "umap-learn", "captum>=0.7.0", "phate", + "panoptica", ] examples = ["napari", "jupyter", "jupytext"] visual = [ @@ -77,5 +78,9 @@ line-length = 88 [tool.ruff] src = ["viscy", "tests"] -lint.extend-select = ["I001"] -lint.isort.known-first-party = ["viscy"] + +[tool.ruff.lint] +extend-select = ["I001"] + +[tool.ruff.lint.isort] +known-first-party = ["viscy"] diff --git a/tests/translation/test_evaluation.py b/tests/translation/test_evaluation.py index 14d399ef9..9a04b6176 100644 --- a/tests/translation/test_evaluation.py +++ b/tests/translation/test_evaluation.py @@ -4,14 +4,14 @@ from lightning.pytorch.loggers import CSVLogger from numpy.testing import assert_array_equal -from viscy.data.segmentation import SegmentationDataModule +from viscy.data.segmentation import TargetPredictionDataModule from viscy.trainer import Trainer from viscy.translation.evaluation import SegmentationMetrics2D @pytest.mark.parametrize("pred_channel", ["DAPI", "GFP"]) def test_segmentation_metrics_2d(pred_channel, labels_hcs_dataset, tmp_path) -> None: - dm = SegmentationDataModule( + dm = TargetPredictionDataModule( pred_dataset=labels_hcs_dataset, target_dataset=labels_hcs_dataset, target_channel="DAPI", diff --git a/viscy/cli.py b/viscy/cli.py index 1a82ad505..405cc92a0 100644 --- a/viscy/cli.py +++ b/viscy/cli.py @@ -2,6 +2,7 @@ import os import sys from datetime import datetime +from pathlib import Path import torch from jsonargparse import lazy_instance @@ -22,6 +23,7 @@ def subcommands() -> dict[str, set[str]]: subcommands["preprocess"] = subcommand_base_args subcommands["export"] = subcommand_base_args subcommands["precompute"] = subcommand_base_args + subcommands["compute_dynacell_metrics"] = subcommand_base_args return subcommands def add_arguments_to_parser(self, parser) -> None: @@ -51,8 +53,15 @@ def main() -> None: Set default random seed to 42. 
""" _setup_environment() - require_model = {"preprocess", "precompute"}.isdisjoint(sys.argv) - require_data = {"preprocess", "precompute", "export"}.isdisjoint(sys.argv) + require_model = {"preprocess", "precompute", "compute_dynacell_metrics"}.isdisjoint( + sys.argv + ) + require_data = { + "preprocess", + "precompute", + "export", + "compute_dynacell_metrics", + }.isdisjoint(sys.argv) _ = VisCyCLI( model_class=LightningModule, datamodule_class=LightningDataModule if require_data else None, diff --git a/viscy/data/dynacell.py b/viscy/data/dynacell.py new file mode 100644 index 000000000..8a74c0f3c --- /dev/null +++ b/viscy/data/dynacell.py @@ -0,0 +1,298 @@ +"""Test stage data modules for loading data from DynaCell benchmark datasets.""" + +import logging +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from iohub.ngff import open_ome_zarr +from lightning.pytorch import LightningDataModule +from monai.transforms import Compose, MapTransform +from torch.utils.data import ConcatDataset, DataLoader + +from viscy.data.segmentation import TargetPredictionDataset +from viscy.data.typing import DynaCellSample + +_logger = logging.getLogger("lightning.pytorch") + + +class DynaCellDataset(TargetPredictionDataset): + """Return a DynaCellSample object with the cell type, organelle, and infection condition.""" + + def __init__( + self, + cell_type: str, + organelle: str, + infection_condition: str, + dataset: str, + transforms: list[MapTransform] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.cell_type = cell_type + self.organelle = organelle + self.infection_condition = infection_condition + self.dataset = dataset + self.transforms = Compose(transforms) if transforms else None + + def __getitem__(self, idx) -> DynaCellSample: + sample = super().__getitem__(idx) + position_idx = sample["position_idx"] + + # Convert tensors to float32 for metrics compatibility + sample["pred"] = sample["pred"].float() + sample["target"] = sample["target"].float() + + # Add channel dimension if needed for the metrics (BxHxW -> BxCxHxW) + if sample["pred"].ndim == 2: + sample["pred"] = sample["pred"].unsqueeze(0) # Add channel dimension + elif sample["pred"].ndim == 3 and sample["pred"].shape[0] == 1: + # If the first dimension is batch with size 1, reshape to [C,H,W] + sample["pred"] = sample["pred"].squeeze(0).unsqueeze(0) + + if sample["target"].ndim == 2: + sample["target"] = sample["target"].unsqueeze(0) # Add channel dimension + elif sample["target"].ndim == 3 and sample["target"].shape[0] == 1: + # If the first dimension is batch with size 1, reshape to [C,H,W] + sample["target"] = sample["target"].squeeze(0).unsqueeze(0) + + sample.update( + { + "cell_type": self.cell_type, + "organelle": self.organelle, + "infection_condition": self.infection_condition, + "dataset": self.dataset, + "position_name": self.position_names[position_idx], + } + ) + + # Apply transforms if provided + if self.transforms: + sample = self.transforms(sample) + + return sample + + +class DynaCellDatabase: + """Database for DynaCell datasets filtered by cell types, organelles, and infection conditions.""" + + def __init__( + self, + database: pd.DataFrame, + cell_types: list[str], + organelles: list[str], + infection_conditions: list[str], + channel_name: str, + zarr_path_column_name: str = "Path", + z_slice: int | slice | None = None, + ): + self.database = database + self.cell_types = cell_types + self.organelles = organelles + self.infection_conditions = infection_conditions 
+        self.channel_name = channel_name
+        self.z_slice = z_slice
+        self.zarr_path_column_name = zarr_path_column_name
+
+        required_columns = [
+            "Cell type",
+            "Organelle",
+            "Infection",
+            "Dataset",
+            zarr_path_column_name,
+        ]
+        if not set(required_columns).issubset(self.database.columns):
+            raise ValueError(f"Database must contain {required_columns}.")
+
+        self._process_database()
+
+    def _process_database(self):
+        # Select the portion of the database that matches the provided criteria
+        self._filtered_db = self.database[
+            self.database["Cell type"].isin(self.cell_types)
+            & self.database["Organelle"].isin(self.organelles)
+            & self.database["Infection"].isin(self.infection_conditions)
+        ].copy()
+
+        # Extract zarr store paths
+        self._filtered_db["Zarr path"] = self._filtered_db[
+            self.zarr_path_column_name
+        ].apply(lambda x: Path(*Path(x).parts[:-3]))
+        self._filtered_db["FOV name"] = self._filtered_db[
+            self.zarr_path_column_name
+        ].apply(lambda x: Path(*Path(x).parts[-3:]).as_posix())
+        # Deduplicate by both zarr path AND FOV name, so that different
+        # positions/conditions within the same store are preserved
+        self._filtered_db = self._filtered_db.drop_duplicates(subset=["Zarr path", "FOV name"])
+
+        # Store values for later use
+        self.zarr_paths = self._filtered_db["Zarr path"].values.tolist()
+        self.position_names = self._filtered_db["FOV name"].values.tolist()
+        self.cell_types_per_store = self._filtered_db["Cell type"].values.tolist()
+        self.organelles_per_store = self._filtered_db["Organelle"].values.tolist()
+        self.infection_per_store = self._filtered_db["Infection"].values.tolist()
+        self.datasets_per_store = self._filtered_db["Dataset"].values.tolist()
+
+    def __getitem__(self, idx) -> dict:
+        return {
+            "zarr_path": self.zarr_paths[idx],
+            "position_names": [self.position_names[idx]],
+            "cell_type": self.cell_types_per_store[idx],
+            "organelle": self.organelles_per_store[idx],
+            "infection_condition": self.infection_per_store[idx],
+            "dataset": self.datasets_per_store[idx],
+            "channel_name": self.channel_name,
+            "z_slice": self.z_slice,
+        }
+
+    def __len__(self) -> int:
+        return len(self.zarr_paths)
+
+
+class DynaCellDataModule(LightningDataModule):
+    """
+    Lightning DataModule for DynaCell metrics computation with parallel processing support.
+
+    This data module creates datasets from database entries and enables parallel processing
+    at the individual timepoint level. Each sample represents a single (position, timepoint)
+    combination, allowing workers to process different timepoints/positions simultaneously.
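+
+    A minimal construction sketch (the CSV path, filters, and channel names
+    are illustrative)::
+
+        db = pd.read_csv("dynacell_database.csv")
+        target_db = DynaCellDatabase(
+            database=db,
+            cell_types=["HEK293T"],
+            organelles=["HIST2H2BE"],
+            infection_conditions=["Mock"],
+            channel_name="Organelle",
+            z_slice=16,
+        )
+        pred_db = DynaCellDatabase(
+            database=db,
+            cell_types=["HEK293T"],
+            organelles=["HIST2H2BE"],
+            infection_conditions=["Mock"],
+            channel_name="nuclei_prediction",
+            z_slice=16,
+        )
+        dm = DynaCellDataModule(target_db, pred_db, batch_size=1, num_workers=8)
+        dm.setup(stage="test")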
+ + Parallel Processing: + - Samples are distributed across workers in round-robin fashion + - Each worker loads and processes independent (position, timepoint) combinations + - Thread-safe collate function preserves metadata for metrics logging + - Recommended num_workers: 4-12 for typical HPC environments + + Parameters + ---------- + target_database : DynaCellDatabase + Database containing target image information + pred_database : DynaCellDatabase + Database containing prediction image information + batch_size : int + Batch size (typically 1 for metrics compatibility) + num_workers : int + Number of parallel workers for data loading + transforms : list[MapTransform] | None + Optional data transforms to apply + """ + def __init__( + self, + target_database: DynaCellDatabase, + pred_database: DynaCellDatabase, + batch_size: int, + num_workers: int, + transforms: list[MapTransform] | None = None, + ) -> None: + super().__init__() + self.target_database = target_database + self.pred_database = pred_database + self.batch_size = batch_size + self.num_workers = num_workers + self.transforms = transforms + + def setup(self, stage: str) -> None: + if stage != "test": + raise NotImplementedError("Only test stage is supported!") + + # Verify both databases have the same length + if len(self.target_database) != len(self.pred_database): + raise ValueError( + f"Target database length ({len(self.target_database)}) doesn't match " + f"prediction database length ({len(self.pred_database)})" + ) + + # Create datasets + datasets = [] + for i in range(len(self.target_database)): + target_data = self.target_database[i] + pred_data = self.pred_database[i] + + # Ensure target and prediction metadata match + self._validate_matching_metadata(target_data, pred_data, i) + + datasets.append( + DynaCellDataset( + cell_type=target_data["cell_type"], + organelle=target_data["organelle"], + infection_condition=target_data["infection_condition"], + dataset=target_data["dataset"], + pred_dataset=open_ome_zarr(pred_data["zarr_path"]), + target_dataset=open_ome_zarr(target_data["zarr_path"]), + position_names=target_data["position_names"], + pred_channel=pred_data["channel_name"], + target_channel=target_data["channel_name"], + pred_z_slice=pred_data["z_slice"], + target_z_slice=target_data["z_slice"], + transforms=self.transforms, + dtype=np.float32, # Ensure float32 + ) + ) + + self.test_dataset = ConcatDataset(datasets) + + def _validate_matching_metadata( + self, target_data: dict, pred_data: dict, idx: int + ) -> None: + """Validate that target and prediction metadata match.""" + # Check cell type + if target_data["cell_type"] != pred_data["cell_type"]: + raise ValueError( + f"Cell type mismatch at index {idx}: " + f"target={target_data['cell_type']}, pred={pred_data['cell_type']}" + ) + + # Check organelle + if target_data["organelle"] != pred_data["organelle"]: + raise ValueError( + f"Organelle mismatch at index {idx}: " + f"target={target_data['organelle']}, pred={pred_data['organelle']}" + ) + + # Check infection condition + if target_data["infection_condition"] != pred_data["infection_condition"]: + raise ValueError( + f"Infection condition mismatch at index {idx}: " + f"target={target_data['infection_condition']}, pred={pred_data['infection_condition']}" + ) + + # Check dataset + if target_data["dataset"] != pred_data["dataset"]: + raise ValueError( + f"Dataset mismatch at index {idx}: " + f"target={target_data['dataset']}, pred={pred_data['dataset']}" + ) + + # Check zarr paths if they should match + if 
target_data["zarr_path"] != pred_data["zarr_path"]: + _logger.warning( + f"Zarr path mismatch at index {idx}: " + f"target={target_data['zarr_path']}, pred={pred_data['zarr_path']}" + ) + + def test_dataloader(self) -> DataLoader: + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self._custom_collate, + ) + + def _custom_collate(self, batch): + """Custom collate function that preserves metadata strings.""" + assert len(batch) == 1, "Batch size must be 1 for DynaCellDataModule" + # Extract metadata from single element in batch + metadata = { + "cell_type": batch[0]["cell_type"], + "organelle": batch[0]["organelle"], + "infection_condition": batch[0]["infection_condition"], + "dataset": batch[0]["dataset"], + "position_name": batch[0]["position_name"], + } + + # Standard collate for tensors + collated = torch.utils.data.default_collate(batch) + + # Add metadata back into collated batch + collated.update(metadata) + + return collated diff --git a/viscy/data/segmentation.py b/viscy/data/segmentation.py index 553d9241c..8918f6452 100644 --- a/viscy/data/segmentation.py +++ b/viscy/data/segmentation.py @@ -14,32 +14,58 @@ _logger = logging.getLogger("lightning.pytorch") -class SegmentationDataset(Dataset): +class TargetPredictionDataset(Dataset): + """ + A PyTorch Dataset providing paired target and prediction images / volumes from OME-Zarr + datasets. + + Attributes: + pred_dataset (Plate): The prediction dataset Plate object. + target_dataset (Plate): The target dataset Plate object. + pred_channel (str): The channel name in the prediction dataset. + target_channel (str): The channel name in the target dataset. + pred_z_slice (int | slice | None): The z-slice or range of z-slices for the + prediction dataset. Defaults to None which is converted to slice(None). + target_z_slice (int | slice | None): The z-slice or range of z-slices for the + target dataset. Defaults to None which is converted to slice(None). + img_name (str): The name of the image to retrieve from the datasets. Defaults to "0". + dtype (np.dtype | None): The data type to cast the images to. Defaults to np.int16. 
+    """
+
    def __init__(
        self,
        pred_dataset: Plate,
        target_dataset: Plate,
        pred_channel: str,
        target_channel: str,
-        pred_z_slice: int | slice,
-        target_z_slice: int | slice,
+        pred_z_slice: int | slice | None = None,
+        target_z_slice: int | slice | None = None,
+        position_names: list[str] | None = None,
        img_name: str = "0",
+        dtype: np.dtype | None = np.int16,
    ) -> None:
        super().__init__()
        self.pred_dataset = pred_dataset
        self.target_dataset = target_dataset
        self.pred_channel = pred_dataset.get_channel_index(pred_channel)
        self.target_channel = target_dataset.get_channel_index(target_channel)
-        self.pred_z_slice = pred_z_slice
-        self.target_z_slice = target_z_slice
+        self.pred_z_slice = pred_z_slice if pred_z_slice is not None else slice(None)
+        self.target_z_slice = (
+            target_z_slice if target_z_slice is not None else slice(None)
+        )
        self.img_name = img_name
+        self.dtype = dtype
+        self.position_names = position_names
+        if not position_names:
+            self.position_names = [p[0] for p in self.target_dataset.positions()]
+
        self._build_indices()

    def _build_indices(self) -> None:
        self._indices = []
-        for p, (name, target_fov) in enumerate(self.target_dataset.positions()):
+        for p, name in enumerate(self.position_names):
            pred_img: ImageArray = self.pred_dataset[name][self.img_name]
-            target_img: ImageArray = target_fov[self.img_name]
+            target_img: ImageArray = self.target_dataset[name][self.img_name]
            if not pred_img.shape[0] == target_img.shape[0]:
                raise ValueError(
                    "Shape mismatch between prediction and target: "
@@ -47,7 +73,12 @@ def _build_indices(self) -> None:
            )
            for t in range(pred_img.shape[0]):
                self._indices.append((pred_img, target_img, p, t))
-        _logger.info(f"Number of test samples: {len(self)}")
+        # Only log the sample count once per dataset class to reduce noise
+        if not getattr(type(self), "_samples_logged", False):
+            _logger.info(
+                f"Built dataset with {len(self)} samples across {len(self.position_names)} positions"
+            )
+            type(self)._samples_logged = True

    def __len__(self) -> int:
        return len(self._indices)
@@ -55,26 +86,28 @@ def __len__(self) -> int:
    def __getitem__(self, idx: int) -> SegmentationSample:
        pred_img, target_img, p, t = self._indices[idx]
        _logger.debug(f"Target image: {target_img.name}")
-        pred = torch.from_numpy(
-            pred_img[t, self.pred_channel, self.pred_z_slice].astype(np.int16)
-        )
-        target = torch.from_numpy(
-            target_img[t, self.target_channel, self.target_z_slice].astype(np.int16)
-        )
+        _pred = pred_img[t, self.pred_channel, self.pred_z_slice]
+        _target = target_img[t, self.target_channel, self.target_z_slice]
+        # Cast only when a dtype is requested; astype(None) would silently
+        # upcast to float64 and defeat the dtype=None passthrough.
+        if self.dtype is not None:
+            _pred = _pred.astype(self.dtype)
+            _target = _target.astype(self.dtype)
+        pred = torch.from_numpy(_pred)
+        target = torch.from_numpy(_target)
        return {"pred": pred, "target": target, "position_idx": p, "time_idx": t}


-class SegmentationDataModule(LightningDataModule):
+class TargetPredictionDataModule(LightningDataModule):
    def __init__(
        self,
        pred_dataset: Path,
        target_dataset: Path,
        pred_channel: str,
        target_channel: str,
-        pred_z_slice: int,
-        target_z_slice: int,
+        pred_z_slice: int | slice | None,
+        target_z_slice: int | slice | None,
        batch_size: int,
        num_workers: int,
+        dtype: np.dtype | None = np.int16,
    ) -> None:
        super().__init__()
        self.pred_dataset = open_ome_zarr(pred_dataset)
@@ -85,17 +118,19 @@ def __init__(
        self.target_z_slice = target_z_slice
        self.batch_size = batch_size
        self.num_workers = num_workers
+        self.dtype = dtype

    def setup(self, stage: str) -> None:
stage != "test": raise NotImplementedError("Only test stage is supported!") - self.test_dataset = SegmentationDataset( + self.test_dataset = TargetPredictionDataset( self.pred_dataset, self.target_dataset, self.pred_channel, self.target_channel, self.pred_z_slice, self.target_z_slice, + dtype=self.dtype, ) def test_dataloader(self) -> DataLoader: diff --git a/viscy/data/typing.py b/viscy/data/typing.py index c824b9416..88e91a708 100644 --- a/viscy/data/typing.py +++ b/viscy/data/typing.py @@ -64,6 +64,18 @@ class SegmentationSample(TypedDict): time_idx: OneOrSeq[int] +class DynaCellSample(SegmentationSample): + """ + DynaCell sample type for mini-batches. + """ + + cell_type: str + organelle: str + infection: str + dataset: str + position_name: str + + class ChannelMap(TypedDict): """Source channel names.""" diff --git a/viscy/trainer.py b/viscy/trainer.py index 03395a371..2e64803ef 100644 --- a/viscy/trainer.py +++ b/viscy/trainer.py @@ -1,14 +1,19 @@ +import datetime import logging from pathlib import Path from typing import Literal +import pandas as pd import torch from iohub import open_ome_zarr from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized from torch.onnx import OperatorExportTypes +from viscy.data.dynacell import DynaCellDatabase, DynaCellDataModule from viscy.preprocessing.precompute import precompute_array +from viscy.translation.evaluation import IntensityMetrics, SegmentationMetrics from viscy.utils.meta_utils import generate_normalization_metadata _logger = logging.getLogger("lightning.pytorch") @@ -128,3 +133,244 @@ def precompute( include_wells=include_wells, exclude_fovs=exclude_fovs, ) + + def compute_dynacell_metrics( + self, + target_database: Path, + pred_database: Path, + output_dir: Path, + method: str = "intensity", + target_channel: str = "Organelle", + pred_channel: str = "Organelle", + target_z_slice: int | list[int] = 16, + pred_z_slice: int | list[int] = 16, + target_cell_types: list[str] = None, + target_organelles: list[str] = None, + target_infection_conditions: list[str] = None, + pred_cell_types: list[str] = None, + pred_organelles: list[str] = None, + pred_infection_conditions: list[str] = None, + batch_size: int = 1, + num_workers: int = 0, + version: str = "1", + transforms: list = None, + model: LightningModule | None = None, + ): + """ + Compute metrics for DynaCell datasets. 
+ + Parameters + ---------- + target_database : Path + Path to the target DynaCell database file + pred_database : Path + Path to the prediction DynaCell database file + output_dir : Path + Directory to save output metrics + method : str, optional + Type of metrics to compute ('intensity' or 'segmentation2D'), by default "intensity" + target_channel : str, optional + Channel name for target dataset, by default "Organelle" + pred_channel : str, optional + Channel name for prediction dataset, by default "Organelle" + target_z_slice : int | list[int], optional + Z-slice to use for target dataset, by default 16 + pred_z_slice : int | list[int], optional + Z-slice to use for prediction dataset, by default 16 + target_cell_types : list[str], optional + Cell types to include for target dataset, by default None (all available) + target_organelles : list[str], optional + Organelles to include for target dataset, by default None (all available) + target_infection_conditions : list[str], optional + Infection conditions to include for target dataset, by default None (all available) + pred_cell_types : list[str], optional + Cell types to include for prediction dataset, by default None (all available) + pred_organelles : list[str], optional + Organelles to include for prediction dataset, by default None (all available) + pred_infection_conditions : list[str], optional + Infection conditions to include for prediction dataset, by default None (all available) + batch_size : int, optional + Batch size for processing, by default 1 + num_workers : int, optional + Number of workers for data loading, by default 0 + version : str, optional + Version string for output directory, by default "1" + transforms : list, optional + List of transforms to apply to the data (e.g., normalization), by default None + model : LightningModule | None, optional + Ignored placeholder, by default None + """ + if model is not None: + _logger.warning("Ignoring model configuration for DynaCell metrics.") + + # Set default empty lists for filters + target_cell_types = target_cell_types or [] + target_organelles = target_organelles or [] + target_infection_conditions = target_infection_conditions or [] + pred_cell_types = pred_cell_types or [] + pred_organelles = pred_organelles or [] + pred_infection_conditions = pred_infection_conditions or [] + + # Create output directory + output_dir.mkdir(parents=True, exist_ok=True) + + # Generate timestamp for unique versioning + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + + # Handle z_slice values (-1 means all slices) + if isinstance(target_z_slice, list) and len(target_z_slice) == 2: + # Use the list as a range [start, stop] for the slice + if target_z_slice[1] - target_z_slice[0] == 1: + # If range length is 1, just use the single integer + target_z_slice_value = int(target_z_slice[0]) + else: + target_z_slice_value = slice(target_z_slice[0], target_z_slice[1]) + else: + target_z_slice_value = ( + slice(None) if target_z_slice == -1 else int(target_z_slice) + ) + + if isinstance(pred_z_slice, list) and len(pred_z_slice) == 2: + # Use the list as a range [start, stop] for the slice + if pred_z_slice[1] - pred_z_slice[0] == 1: + # If range length is 1, just use the single integer + pred_z_slice_value = int(pred_z_slice[0]) + else: + pred_z_slice_value = slice(pred_z_slice[0], pred_z_slice[1]) + else: + pred_z_slice_value = ( + slice(None) if pred_z_slice == -1 else int(pred_z_slice) + ) + + # Default to all available values if not specified for target database + if ( + not 
target_cell_types
+            or not target_organelles
+            or not target_infection_conditions
+        ):
+            _logger.info("Loading target database to get available values...")
+            df = pd.read_csv(target_database, dtype={"FOV": str})
+
+            if not target_cell_types:
+                target_cell_types = df["Cell type"].unique().tolist()
+                _logger.info(
+                    f"Using all available target cell types: {target_cell_types}"
+                )
+
+            if not target_organelles:
+                target_organelles = df["Organelle"].unique().tolist()
+                _logger.info(
+                    f"Using all available target organelles: {target_organelles}"
+                )
+
+            if not target_infection_conditions:
+                target_infection_conditions = df["Infection"].unique().tolist()
+                _logger.info(
+                    f"Using all available target infection conditions: {target_infection_conditions}"
+                )
+
+        # Default to all available values if not specified for prediction database
+        if not pred_cell_types or not pred_organelles or not pred_infection_conditions:
+            _logger.info("Loading prediction database to get available values...")
+            df = pd.read_csv(pred_database, dtype={"FOV": str})
+
+            if not pred_cell_types:
+                pred_cell_types = df["Cell type"].unique().tolist()
+                _logger.info(
+                    f"Using all available prediction cell types: {pred_cell_types}"
+                )
+
+            if not pred_organelles:
+                pred_organelles = df["Organelle"].unique().tolist()
+                _logger.info(
+                    f"Using all available prediction organelles: {pred_organelles}"
+                )
+
+            if not pred_infection_conditions:
+                pred_infection_conditions = df["Infection"].unique().tolist()
+                _logger.info(
+                    f"Using all available prediction infection conditions: {pred_infection_conditions}"
+                )
+
+        # DynaCellDatabase expects a DataFrame, not a CSV path, so load the
+        # databases here before constructing them
+        target_df = pd.read_csv(target_database, dtype={"FOV": str})
+        pred_df = pd.read_csv(pred_database, dtype={"FOV": str})
+
+        # Create target database
+        _logger.info(
+            f"Creating target database from {target_database} with channel '{target_channel}'"
+        )
+        target_db = DynaCellDatabase(
+            database=target_df,
+            cell_types=target_cell_types,
+            organelles=target_organelles,
+            infection_conditions=target_infection_conditions,
+            channel_name=target_channel,
+            z_slice=target_z_slice_value,
+        )
+
+        # Create prediction database
+        _logger.info(
+            f"Creating prediction database from {pred_database} with channel '{pred_channel}'"
+        )
+        pred_db = DynaCellDatabase(
+            database=pred_df,
+            cell_types=pred_cell_types,
+            organelles=pred_organelles,
+            infection_conditions=pred_infection_conditions,
+            channel_name=pred_channel,
+            z_slice=pred_z_slice_value,
+        )
+
+        # Create datamodule
+        _logger.info("Creating DynaCellDataModule...")
+        dm = DynaCellDataModule(
+            target_database=target_db,
+            pred_database=pred_db,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            transforms=transforms,
+        )
+
+        # Setup datamodule
+        dm.setup(stage="test")
+
+        # Determine run-specific output paths
+        method_dir = output_dir / method
+        method_dir.mkdir(exist_ok=True)
+
+        # Unique name based on method and timestamp
+        run_name = f"{method}_{timestamp}"
+
+        # Create logger and attach it to this trainer so that test metrics
+        # are written to the run directory resolved below
+        _logger.info(f"Creating logger for run '{run_name}' with version '{version}'")
+        logger = CSVLogger(save_dir=method_dir, name=run_name, version=version)
+        self.logger = logger
+
+        # Select and run appropriate metrics; accept the documented values
+        # ('intensity', 'segmentation2D') case-insensitively
+        if method.lower() == "segmentation2d":
+            _logger.info("Running segmentation metrics...")
+            metrics_module = SegmentationMetrics()
+        elif method.lower() == "intensity":
+            _logger.info("Running intensity metrics...")
+            metrics_module = IntensityMetrics()
+        else:
+            raise ValueError(f"Invalid method: {method}")
+
+        # Run the metrics computation
+        self.test(metrics_module, datamodule=dm)
+
+        # Find the metrics file
+        metrics_file = method_dir / run_name / version / "metrics.csv"
+
+        if not metrics_file.exists():
+            _logger.warning(f"No
metrics file found at {metrics_file}") + return None + + metrics_df = pd.read_csv(metrics_file) + _logger.info(f"Metrics saved to: {metrics_file}") + _logger.info(f"Computed {len(metrics_df)} metric rows") + + # Display columns in the metrics + _logger.info(f"Metrics columns: {metrics_df.columns.tolist()}") + + # Display a preview of the metrics + if not metrics_df.empty: + _logger.info(f"Metrics preview:\n{repr(metrics_df.head())}") + + return metrics_file diff --git a/viscy/translation/evaluation.py b/viscy/translation/evaluation.py index 11812f4fe..225b959c6 100644 --- a/viscy/translation/evaluation.py +++ b/viscy/translation/evaluation.py @@ -1,8 +1,12 @@ -"""Test stage lightning module for comparing virtual staining and segmentations.""" +"""Test stage lightning modules for comparing segmentation based on virtual staining and fluorescence ground truth""" import logging +import warnings +import numpy as np from lightning.pytorch import LightningModule +from panoptica import InputType, Panoptica_Evaluator +from panoptica.metrics import Metric from torchmetrics.functional import accuracy, jaccard_index from torchmetrics.functional.segmentation import dice_score @@ -12,43 +16,270 @@ _logger = logging.getLogger("lightning.pytorch") -class SegmentationMetrics2D(LightningModule): - """Test runner for 2D segmentation.""" +class SegmentationMetrics(LightningModule): + """Test runner for segmentation that handles both 2D and 3D data. - def __init__(self, aggregate_epoch: bool = False) -> None: + Parameters + ---------- + mode : str, optional + Mode of operation, can be "auto", "2D", or "3D", by default "auto". + In "auto" mode, dimensionality is determined from input data. + aggregate_epoch : bool, optional + Whether to aggregate results over the entire epoch, by default False + """ + + def __init__(self, mode: str = "auto", aggregate_epoch: bool = False) -> None: super().__init__() + self.mode = mode self.aggregate_epoch = aggregate_epoch + self._validate_mode() + + def _validate_mode(self): + """Validate the mode parameter.""" + valid_modes = ["auto", "2D", "3D"] + if self.mode not in valid_modes: + raise ValueError(f"Mode must be one of {valid_modes}, got {self.mode}") def test_step(self, batch: SegmentationSample, batch_idx: int) -> None: - pred = batch["pred"] - target = batch["target"] - if not pred.shape[0] == 1 and target.shape[0] == 1: - raise ValueError( - f"Expected 2D segmentation, got {pred.shape[0]} and {target.shape[0]}" - ) - pred = pred[0] - target = target[0] + pred = batch["pred"] #4D + target = batch["target"] + assert pred.ndim == target.ndim, f"Pred and target must have the same number of dimensions, got {pred.ndim} and {target.ndim}" + + # Determine dimensionality from input data if in auto mode + if self.mode == "auto": + if pred.ndim == 4: + current_mode = "2D" + elif pred.ndim == 5: + current_mode = "3D" + else: + raise ValueError( + f"Cannot determine dimensionality from shapes {pred.shape} and {target.shape}" + ) + else: + if self.mode == "2D": + if pred.ndim ==4: + warnings.warn(f"Pred and target have more than 2 dimensions: {pred.shape}. Taking the last two dimensions (Y,X)") + pred = pred[0,0] + target = target[0,0] + elif pred.ndim == 5: + warnings.warn(f"Pred and target have more than 2 dimensions: {pred.shape}. 
Taking the last two dimensions (Y,X)")
+                pred = pred[0, 0, 0]
+                target = target[0, 0, 0]
+            else:
+                raise ValueError(
+                    f"Cannot determine dimensionality from shapes {pred.shape} and {target.shape}"
+                )
+        elif self.mode == "3D":
+            if pred.ndim == 5:
+                warnings.warn(f"Pred and target have more than 3 dimensions: {pred.shape}. Taking the last three dimensions (Z,Y,X)")
+                pred = pred[0, 0]
+                target = target[0, 0]
+            else:
+                raise ValueError(
+                    f"Cannot determine dimensionality from shapes {pred.shape} and {target.shape}"
+                )
+        else:
+            raise ValueError(f"Invalid mode: {self.mode}")
+
+        # Common preprocessing for both modes
+        pred_binary = pred > 0
+        target_binary = target > 0
+
+        # Dispatch on the effective mode; in "auto" mode the dimensionality
+        # was resolved from the input shape above.
+        effective_mode = current_mode if self.mode == "auto" else self.mode
+        if effective_mode == "2D":
+            self._compute_2d_metrics(pred, target, pred_binary, target_binary, batch)
+        else:  # 3D mode
+            self._compute_3d_metrics(pred, target, pred_binary, target_binary, batch)
+
+    def _compute_2d_metrics(self, pred, target, pred_binary, target_binary, batch):
+        """Compute and log metrics for 2D segmentation."""
+        coco_metrics = mean_average_precision(pred, target)
        _logger.debug(coco_metrics)
-        self.logger.log_metrics(
-            {
-                "position": batch["position_idx"][0],
-                "time": batch["time_idx"][0],
-                "accuracy": (accuracy(pred_binary, target_binary, task="binary")),
-                "dice": (
-                    dice_score(
-                        pred_binary.long()[None],
-                        target_binary.long()[None],
-                        num_classes=2,
-                        input_format="index",
-                    )
-                ),
-                "jaccard": (jaccard_index(pred_binary, target_binary, task="binary")),
-                "mAP": coco_metrics["map"],
-                "mAP_50": coco_metrics["map_50"],
-                "mAP_75": coco_metrics["map_75"],
-                "mAR_100": coco_metrics["mar_100"],
-            }
+
+        # Create metrics dictionary
+        metrics_dict = {
+            "position": batch["position_idx"][0],
+            "time": batch["time_idx"][0],
+            "accuracy": (accuracy(pred_binary, target_binary, task="binary")),
+            "dice": (
+                dice_score(
+                    pred_binary.long()[None],
+                    target_binary.long()[None],
+                    num_classes=2,
+                    input_format="index",
+                )
+            ),
+            "jaccard": (jaccard_index(pred_binary, target_binary, task="binary")),
+            "mAP": coco_metrics["map"],
+            "mAP_50": coco_metrics["map_50"],
+            "mAP_75": coco_metrics["map_75"],
+            "mAR_100": coco_metrics["mar_100"],
+        }
+
+        # Add metadata if available
+        if "cell_type" in batch:
+            metrics_dict["cell_type"] = batch["cell_type"]
+        if "organelle" in batch:
+            metrics_dict["organelle"] = batch["organelle"]
+        if "infection_condition" in batch:
+            metrics_dict["infection_condition"] = batch["infection_condition"]
+
+        self.logger.log_metrics(metrics_dict)
+
+    def _compute_3d_metrics(self, pred, target, pred_binary, target_binary, batch):
+        """Compute and log metrics for 3D segmentation."""
+        unique_instances_target = np.unique(target)
+        unique_instances_pred = np.unique(pred)
+
+        _logger.debug(
+            f"Unique instances: {unique_instances_target} and {unique_instances_pred}"
+        )
+
+        ## Measuring Panoptic Quality
+        evaluator = Panoptica_Evaluator(
+            expected_input=InputType.UnmatchedInstancePair,
+            instance_metrics=[Metric.DSC, Metric.IoU],
+            decision_metric=Metric.DSC,
+            decision_threshold=0.5,
+            log_times=True,
        )
+        result = evaluator.evaluate(pred, target, verbose=False)
+        result = result.to_dict()
+        _logger.debug(result)
+
+        # Create metrics dictionary
+        metrics_dict = {
+            "position": batch["position_idx"][0],
+            "time": batch["time_idx"][0],
+            "target_instances": unique_instances_target,
+            "pred_instances": unique_instances_pred,
+            **result,
+        }
+
+        # Add metadata if available
+        if "cell_type" in batch:
+            metrics_dict["cell_type"] = batch["cell_type"]
+        if "organelle" in batch:
batch["organelle"] + if "infection_condition" in batch: + metrics_dict["infection_condition"] = batch["infection_condition"] + + self.logger.log_metrics(metrics_dict) + + +class BiologicalMetrics(LightningModule): + """Test runner for biological metrics.""" + + def __init__(self, aggregate_epoch: bool = False) -> None: + super().__init__() + self.aggregate_epoch = aggregate_epoch + + def test_step(self, batch: SegmentationSample, batch_idx: int) -> None: + # TODO: Implement biological metrics (i.e regionprops logic) + raise NotImplementedError("Biological metrics not implemented") + + +class IntensityMetrics(LightningModule): + """Test runner for intensity metrics. + + Parameters + ---------- + metrics : list[str], optional + List of metrics to compute, by default ["mae", "mse", "ssim", "pearson"] + aggregate_epoch : bool, optional + Whether to aggregate results over the entire epoch, by default False + """ + + def __init__( + self, + metrics: list[str] = ["mae", "mse", "ssim", "pearson"], + aggregate_epoch: bool = False, + ) -> None: + super().__init__() + self.metrics = metrics + self.aggregate_epoch = aggregate_epoch + self._validate_metrics() + + def _validate_metrics(self): + """Validate the metrics parameter.""" + valid_metrics = ["mae", "mse", "ssim", "ms_ssim", "pearson", "cosine"] + for metric in self.metrics: + if metric not in valid_metrics: + raise ValueError(f"Metric '{metric}' not in {valid_metrics}") + + def test_step(self, batch, batch_idx: int) -> None: + """Compute intensity metrics between prediction and target.""" + from torchmetrics.functional import ( + cosine_similarity, + mean_absolute_error, + mean_squared_error, + pearson_corrcoef, + structural_similarity_index_measure, + ) + + from viscy.translation.evaluation_metrics import ms_ssim_25d + + pred = batch["pred"] + target = batch["target"] + + # Dictionary to store computed metrics + metrics_dict = { + "position": batch["position_idx"][0] if "position_idx" in batch else -1, + "time": batch["time_idx"][0] if "time_idx" in batch else -1, + } + + # Add metadata if available + if "cell_type" in batch: + metrics_dict["cell_type"] = batch["cell_type"] + if "organelle" in batch: + metrics_dict["organelle"] = batch["organelle"] + if "infection_condition" in batch: + metrics_dict["infection_condition"] = batch["infection_condition"] + if "dataset" in batch: + metrics_dict["dataset"] = batch["dataset"] + if "position_name" in batch: + metrics_dict["position_name"] = batch["position_name"] + + # Compute metrics + for metric in self.metrics: + if metric == "mae": + metrics_dict["mae"] = mean_absolute_error(pred, target) + elif metric == "mse": + metrics_dict["mse"] = mean_squared_error(pred, target) + elif metric == "ssim": + # TODO: find out more about data_range parameter + # Handle different dimensionality cases + if pred.shape[0] > 1: # 3D/2.5D case + metrics_dict["ssim"] = structural_similarity_index_measure( + ( + pred.squeeze(2) + if pred.shape[2] == 1 + else pred[:, :, pred.shape[2] // 2] + ), + ( + target.squeeze(2) + if target.shape[2] == 1 + else target[:, :, target.shape[2] // 2] + ), + ) + else: # 2D case + metrics_dict["ssim"] = structural_similarity_index_measure( + pred, target + ) + elif metric == "ms_ssim": + if pred.ndim > 1: + metrics_dict["ms_ssim"] = ms_ssim_25d(pred, target) + elif metric == "pearson": + metrics_dict["pearson"] = pearson_corrcoef( + pred.flatten(), target.flatten() + ) + elif metric == "cosine": + metrics_dict["cosine"] = cosine_similarity( + pred.flatten(), target.flatten() + ) + + # 
+        # Convert tensors to Python numbers before logging
+        from viscy.utils.logging import convert_tensors_to_numbers
+        cleaned_metrics = convert_tensors_to_numbers(metrics_dict)
+
+        # Log computed metrics
+        self.logger.log_metrics(cleaned_metrics)
diff --git a/viscy/utils/logging.py b/viscy/utils/logging.py
index 33c66f9da..a5844475f 100644
--- a/viscy/utils/logging.py
+++ b/viscy/utils/logging.py
@@ -1,12 +1,59 @@
 import datetime
+import logging
 import os
+import threading
 import time
+from pathlib import Path
+from typing import Any, Dict

+import pandas as pd
 import torch
+from lightning.pytorch.loggers import Logger

 from viscy.utils.cli_utils import save_figure
 from viscy.utils.normalize import hist_clipping

+_logger = logging.getLogger("lightning.pytorch")
+
+
+def convert_tensors_to_numbers(metrics: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Convert tensor values in a metrics dictionary to Python numbers.
+
+    This function handles PyTorch tensors, numpy arrays, and numpy scalars
+    holding a single value, converting them to native Python types suitable
+    for CSV serialization. Multi-element tensors are passed through unchanged.
+
+    Parameters
+    ----------
+    metrics : Dict[str, Any]
+        Dictionary containing metric names and values, potentially with tensor objects
+
+    Returns
+    -------
+    Dict[str, Any]
+        Dictionary with tensor values converted to Python numbers
+
+    Examples
+    --------
+    >>> import torch
+    >>> metrics = {"mae": torch.tensor(0.5), "count": 42, "name": "test"}
+    >>> convert_tensors_to_numbers(metrics)
+    {'mae': 0.5, 'count': 42, 'name': 'test'}
+    """
+    cleaned_metrics = {}
+    for key, value in metrics.items():
+        # Torch tensors, numpy arrays, and numpy scalars all expose .item()
+        if hasattr(value, "item"):
+            try:
+                cleaned_metrics[key] = value.item()
+            except (ValueError, RuntimeError):
+                # Multi-element tensors/arrays cannot be reduced to a scalar
+                cleaned_metrics[key] = value
+        else:
+            cleaned_metrics[key] = value
+    return cleaned_metrics
+

 def log_feature(feature_map, name, log_save_folder, debug_mode):
    """
@@ -282,3 +329,86 @@ def interleave_bars(self, arrays, axis, pixel_width=3, value=0):
    for i in range(1, len(arrays) * 2 - 1, 2):
        arrays.insert(i, bar)
    return arrays
+
+
+class ParallelSafeMetricsLogger(Logger):
+    """
+    A Lightning logger that collects metrics in memory and writes them atomically
+    to avoid race conditions with multiple workers.
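+
+    A usage sketch (illustrative wiring; Lightning calls ``finalize`` itself
+    at the end of ``Trainer.test``, producing a single ``metrics.csv``):
+
+        logger = ParallelSafeMetricsLogger(
+            save_dir=Path("./metrics"), name="intensity_run", version="1"
+        )
+        trainer = Trainer(logger=logger)
+        trainer.test(IntensityMetrics(), datamodule=dm)  # metrics buffered in memory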
+ """ + + def __init__(self, save_dir: Path, name: str, version: str): + super().__init__() + self._save_dir = Path(save_dir) + self._name = name + self._version = version + self._log_dir = self._save_dir / name / version + self._log_dir.mkdir(parents=True, exist_ok=True) + + # Thread-safe storage for metrics + self._metrics = [] + self._lock = threading.Lock() + self._finalized = False + + @property + def experiment(self): + return None + + @property + def save_dir(self) -> Path: + return self._save_dir + + @property + def log_dir(self) -> Path: + return self._log_dir + + def log_metrics(self, metrics: Dict[str, Any], step: int = None) -> None: + """Log metrics in a thread-safe manner.""" + with self._lock: + # Convert tensor values to Python numbers + cleaned_metrics = convert_tensors_to_numbers(metrics) + + # Add step if provided + if step is not None: + cleaned_metrics["step"] = step + self._metrics.append(cleaned_metrics) + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Hyperparameters logging - not needed for metrics computation.""" + pass + + def finalize(self, status: str = "") -> None: + """Write all collected metrics to CSV file atomically.""" + if self._finalized: + return # Already finalized, skip duplicate write + + if not self._metrics: + _logger.warning("No metrics collected") + return + + # Convert to DataFrame (tensors already converted in log_metrics) + df = pd.DataFrame(self._metrics) + + # Write to CSV file + metrics_file = self._log_dir / "metrics.csv" + df.to_csv(metrics_file, index=False) + + _logger.debug(f"ParallelSafeLogger: Wrote {len(df)} records to {metrics_file.name}") + + self._finalized = True + + @property + def name(self) -> str: + return self._name + + @name.setter + def name(self, name: str) -> None: + self._name = name + + @property + def version(self) -> str: + return self._version + + @version.setter + def version(self, version: str) -> None: + self._version = version