Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
96f266f
protype of metrics and files to modify
edyoshikun Apr 21, 2025
3aa3ddf
fix ruff settings
ieivanov Apr 22, 2025
d5394d3
add panoptica to deps
ieivanov Apr 22, 2025
93524ea
remove metrics module and add optional dtype casting to SegmentationD…
ieivanov Apr 23, 2025
71eb7ed
rename Segmentation to TargetPrediction modules
ieivanov Apr 23, 2025
5275349
add dynacell data module draft
ieivanov Apr 23, 2025
56a0119
dynacell dataloader WIP
ieivanov Apr 24, 2025
10a33c7
style
ieivanov Apr 24, 2025
c13914f
style
ieivanov Apr 25, 2025
48df3de
debug
ieivanov Apr 25, 2025
620e368
splitting the logic for computing metrics by accepting two databases …
edyoshikun May 8, 2025
149d36e
fix demo
edyoshikun May 8, 2025
bc93f89
CLI prototype to compute metrics
edyoshikun May 8, 2025
d3c3d8a
support z-slice 3d via list to slice object conversion
edyoshikun May 8, 2025
b7141d9
allow for independent target and prediction databases
ieivanov Jul 18, 2025
4f2e13a
refactor demo script
ieivanov Jul 18, 2025
9bc4f49
rename demo script
ieivanov Jul 18, 2025
e9ab7a1
docs
ieivanov Jul 18, 2025
db86e3c
test of specified positions
ieivanov Jul 18, 2025
5b1b683
fixing the trainer from 'auto' for resources to 'cpu' and limiting th…
edyoshikun Jul 23, 2025
ea4ed97
adding transforms to do normalization.
edyoshikun Jul 29, 2025
0e9dd7c
WIP
ieivanov Jul 31, 2025
7215723
bugfix - convert data to float32
ieivanov Jul 31, 2025
a10ac34
segment prototype. can be deleted later
edyoshikun Jul 31, 2025
ab8617b
add plotting
edyoshikun Jul 31, 2025
8da84e9
vs metrics v1
ieivanov Sep 8, 2025
411aeb6
compute metrics on multiple conditions at a time
ieivanov Sep 9, 2025
1de0525
use multiple workers
ieivanov Sep 9, 2025
dbd9b56
cleaner messaging
ieivanov Sep 9, 2025
9467a91
use gpu acceleration
ieivanov Sep 10, 2025
62d2703
add note on ssim data_range
ieivanov Sep 10, 2025
c80e4aa
ivan's VS metrics scripts
ieivanov Oct 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions applications/DynaCell/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# DynaCell Application

This directory contains tools for computing and analyzing virtual staining metrics using the DynaCell benchmark datasets.

## Overview

The DynaCell application provides functionality to:
- Compute intensity-based metrics between target and predicted virtual staining images
- Process multiple infection conditions (Mock, DENV) in single function calls
- Support parallel processing for faster computation
- Generate detailed CSV reports with position names and dataset information

## Key Files

- `compute_virtual_staining_metrics.py` - Main script for computing VSCyto3D and CellDiff metrics
- `benchmark.py` - Core functions for metrics computation with parallel processing support
- `example_parallel_usage.py` - Example showing sequential vs parallel processing
- `test_parallel_metrics.py` - Test script for verifying parallel functionality

## Parallel Processing Architecture

### Dataset Structure and Worker Distribution

The DynaCell metrics pipeline processes data at the individual timepoint level, enabling efficient parallel processing:

```
Dataset Structure:
├── Position B/1/000001 (Mock, A549, HIST2H2BE)
│ ├── Sample 0: timepoint 0 <- Worker A processes this
│ ├── Sample 1: timepoint 1 <- Worker B processes this
│ ├── Sample 2: timepoint 2 <- Worker C processes this
│ └── Sample 3: timepoint 3 <- Worker D processes this
├── Position B/1/000002 (DENV, A549, HIST2H2BE)
│ ├── Sample 4: timepoint 0 <- Worker A processes this
│ ├── Sample 5: timepoint 1 <- Worker B processes this
│ └── ...
└── Position B/2/000001 (Mock, A549, HIST2H2BE)
├── Sample N: timepoint 0 <- Worker C processes this
└── ...
```

### How Workers Process Data

**Granularity**: Each worker processes individual `(position, timepoint)` combinations

**Distribution**: PyTorch's DataLoader distributes samples across workers in round-robin fashion:
- With `num_workers=4` and 100 samples: Worker 0 gets samples [0,4,8,12...], Worker 1 gets [1,5,9,13...], etc.

**Batch Processing**: With `batch_size=1`, each worker processes exactly one sample at a time for metrics compatibility

**Concurrency Benefits**:
- **I/O Parallelism**: Workers read different zarr files/slices simultaneously
- **CPU Parallelism**: Image processing, transforms, and metrics computation happen in parallel
- **Memory Efficiency**: Each worker only loads one timepoint at a time

### Performance Optimization

This design is particularly effective for DynaCell data because:
- Each timepoint requires significant I/O (loading image slices from zarr)
- Metrics computation is CPU-intensive
- Different timepoints are independent and can be processed in any order

**Recommended Settings**:
- `num_workers=4-12` for typical HPC setups
- `batch_size=1` (hardcoded for metrics compatibility)
- More workers help when you have many positions/timepoints

## Usage Examples

### Basic Usage (Sequential)
```python
metrics = compute_metrics(
metrics_module=IntensityMetrics(),
cell_types=["A549"],
organelles=["HIST2H2BE"],
infection_conditions=["Mock", "DENV"],
target_database=database,
target_channel_name="raw Cy5 EX639 EM698-70",
prediction_database=database,
prediction_channel_name="nuclei_prediction",
log_output_dir=output_dir,
log_name="metrics_example"
)
```

### Parallel Processing
```python
metrics = compute_metrics(
# ... same parameters as above ...
num_workers=8, # Use 8 workers for parallel processing
)
```

### Multiple Conditions
The system supports processing multiple infection conditions in a single call:
```python
infection_conditions=["Mock", "DENV"] # Processes both conditions together
```

## Output Format

The generated CSV files include:
- **Standard metrics**: SSIM, PSNR, correlation coefficients, etc.
- **Position information**: `position_name` (e.g., "B/1/000001")
- **Dataset information**: `dataset` name from the database
- **Condition metadata**: `cell_type`, `organelle`, `infection_condition`
- **Temporal information**: `time_idx` for tracking timepoints

## Thread Safety

The system uses `ParallelSafeMetricsLogger` to prevent race conditions when multiple workers write metrics. This logger:
- Collects metrics in memory during processing
- Writes all data atomically to CSV after completion
- Prevents file corruption from concurrent writes

## Database Structure

The system expects database CSV files with columns:
- `Cell type` - Cell line (e.g., "A549", "HEK293T")
- `Organelle` - Target organelle (e.g., "HIST2H2BE")
- `Infection` - Condition (e.g., "Mock", "DENV")
- `Path` - Path to zarr files
- `Dataset` - Dataset identifier
- Additional metadata columns

The system automatically handles:
- Filtering by multiple conditions using OR logic
- Deduplication by zarr path AND FOV name
- Metadata preservation through the processing pipeline
218 changes: 218 additions & 0 deletions applications/DynaCell/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
"""
This script is a demo script for the DynaCell application.
It loads the ome-zarr 0.4v format, calculates metrics and saves the results as csv files
"""

import datetime
import os
from pathlib import Path

import pandas as pd
import torch
from lightning import LightningModule
from lightning.pytorch.loggers import CSVLogger

from viscy.data.dynacell import DynaCellDatabase, DynaCellDataModule
from viscy.trainer import Trainer
from viscy.utils.logging import ParallelSafeMetricsLogger

# Set float32 matmul precision for better performance on Tensor Cores
torch.set_float32_matmul_precision("high")

# Suppress Lightning warnings for intentional CPU usage
os.environ["SLURM_NTASKS"] = "1" # Suppress SLURM warning
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this have side effect for the current shell?

import warnings
warnings.filterwarnings("ignore", "GPU available but not used")
warnings.filterwarnings("ignore", "The `srun` command is available")


def compute_metrics(
metrics_module: LightningModule,
cell_types: list,
organelles: list,
infection_conditions: list,
target_database: pd.DataFrame,
target_channel_name: str,
prediction_database: pd.DataFrame,
prediction_channel_name: str,
log_output_dir: Path,
log_name: str = "dynacell_metrics",
log_version: str = None,
z_slice: slice = None,
transforms: list = None,
num_workers: int = 0,
use_gpu: bool = False,
):
"""
Compute DynaCell metrics with optional parallel processing.

This function processes virtual staining metrics at the individual timepoint level,
enabling efficient parallel computation across multiple positions and timepoints.

Parallel Processing Architecture:
- Each sample represents one (position, timepoint) combination
- Workers are distributed samples in round-robin fashion by PyTorch DataLoader
- With num_workers=4: Worker 0 gets samples [0,4,8...], Worker 1 gets [1,5,9...], etc.
- Each worker processes different timepoints/positions simultaneously
- Thread-safe logging prevents race conditions in CSV output

Parameters
----------
metrics_module : LightningModule
The metrics module to use (e.g., IntensityMetrics())
cell_types : list
List of cell types to process (e.g., ["A549"])
organelles : list
List of organelles to process (e.g., ["HIST2H2BE"])
infection_conditions : list
List of infection conditions to process (e.g., ["Mock", "DENV"])
Multiple conditions are processed with OR logic in a single call
target_database : pd.DataFrame
Database containing target image paths and metadata
target_channel_name : str
Channel name in target dataset
prediction_database : pd.DataFrame
Database containing prediction image paths and metadata
prediction_channel_name : str
Channel name in prediction dataset
log_output_dir : Path
Directory for output metrics CSV files
log_name : str, optional
Name for metrics logging, by default "dynacell_metrics"
log_version : str, optional
Version string for logging, by default None (uses timestamp)
z_slice : slice, optional
Z-slice to extract from 3D data, by default None
transforms : list, optional
List of data transforms to apply, by default None
num_workers : int, optional
Number of workers for parallel data loading, by default 0 (sequential)
Recommended: 2-4 workers for CPU, 4-8 for GPU
use_gpu : bool, optional
Whether to use GPU acceleration, by default False
GPU provides 10-25x speedup for metrics computation

Notes
-----
- GPU acceleration provides massive speedup for metrics computation
- batch_size is hardcoded to 1 for compatibility with existing metrics code
- GPU acceleration works excellently even with batch_size=1
- Uses ParallelSafeMetricsLogger to prevent race conditions in CSV writing
- Output CSV includes position_name, dataset, and condition metadata

Returns
-------
pd.DataFrame or None
Metrics DataFrame if CSV file is successfully created, None otherwise
"""
# Generate timestamp for unique versioning
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
if log_version is None:
log_version = timestamp

# Create target database
target_db = DynaCellDatabase(
database=target_database,
cell_types=cell_types,
organelles=organelles,
infection_conditions=infection_conditions,
channel_name=target_channel_name,
z_slice=z_slice,
)

# For segmentation, use same channel for pred and target (self-comparison)
pred_db = DynaCellDatabase(
database=prediction_database,
cell_types=cell_types,
organelles=organelles,
infection_conditions=infection_conditions,
channel_name=prediction_channel_name,
z_slice=z_slice,
)

# Create data module with both databases
dm = DynaCellDataModule(
target_database=target_db,
pred_database=pred_db,
batch_size=1, # Hardcoded to 1 for metrics compatibility
num_workers=num_workers,
transforms=transforms,
)
dm.setup(stage="test")

# Print dataset configuration summary
sample = next(iter(dm.test_dataloader()))
# Determine device and processing info
device_name = "GPU" if use_gpu and torch.cuda.is_available() else "CPU"
processing_mode = "Parallel" if num_workers > 0 else "Sequential"

print(f"\n📊 Dataset Configuration:")
print(f" • Samples: {len(dm.test_dataset)} total across all positions/timepoints")
print(f" • Cell types: {cell_types}")
print(f" • Organelles: {organelles}")
print(f" • Infection conditions: {infection_conditions}")
print(f" • Sample metadata: {sample['cell_type']}, {sample['organelle']}, {sample['infection_condition']}")

# Setup logging
log_output_dir.mkdir(exist_ok=True)

if num_workers > 0:
logger = ParallelSafeMetricsLogger(save_dir=log_output_dir, name=log_name, version=log_version)
print(f"\n🚀 Processing Mode: {processing_mode} ({num_workers} workers)")
else:
logger = CSVLogger(save_dir=log_output_dir, name=log_name, version=log_version)
print(f"\n🔄 Processing Mode: {processing_mode} (single-threaded)")

print(f" • Device: {device_name}")
print(f" • Batch size: 1 (hardcoded for metrics compatibility)")
if use_gpu and torch.cuda.is_available():
print(f" • GPU: {torch.cuda.get_device_name()}")

# Configure trainer based on device preference
if use_gpu and torch.cuda.is_available():
accelerator = "gpu"
precision = "16-mixed" # Use fp16 on GPU
else:
accelerator = "cpu"
precision = "bf16-mixed" # Use bf16 for CPU

trainer = Trainer(
logger=logger,
accelerator=accelerator,
devices=1,
precision=precision,
num_nodes=1,
enable_progress_bar=True,
enable_model_summary=False
)
trainer.test(metrics_module, datamodule=dm)

# Finalize logging if using parallel-safe logger
if hasattr(logger, 'finalize'):
logger.finalize()

# Find and report results
metrics_file = log_output_dir / log_name / log_version / "metrics.csv"
if metrics_file.exists():
metrics = pd.read_csv(metrics_file)
print(f"\n✅ Metrics computation completed successfully!")
print(f" • Output file: {metrics_file}")
print(f" • Records: {len(metrics)} samples")
print(f" • Device: {device_name}")
print(f" • Batch size: 1 (hardcoded)")
print(f" • Metrics: {[col for col in metrics.columns if col not in ['position', 'time', 'cell_type', 'organelle', 'infection_condition', 'dataset', 'position_name']]}")

# Show infection condition breakdown
if 'infection_condition' in metrics.columns:
condition_counts = metrics['infection_condition'].value_counts()
print(f" • Conditions: {dict(condition_counts)}")

# Show GPU memory usage if applicable
if use_gpu and torch.cuda.is_available():
memory_used = torch.cuda.max_memory_allocated() / 1024**3
print(f" • Peak GPU memory: {memory_used:.2f} GB")
else:
print(f"❌ Warning: Metrics file not found at {metrics_file}")
metrics = None

return metrics
Loading