DynaCell Metrics #242
Draft: edyoshikun wants to merge 32 commits into main from dynacell_metrics
Commits (32):
96f266f protype of metrics and files to modify (edyoshikun)
3aa3ddf fix ruff settings (ieivanov)
d5394d3 add panoptica to deps (ieivanov)
93524ea remove metrics module and add optional dtype casting to SegmentationD… (ieivanov)
71eb7ed rename Segmentation to TargetPrediction modules (ieivanov)
5275349 add dynacell data module draft (ieivanov)
56a0119 dynacell dataloader WIP (ieivanov)
10a33c7 style (ieivanov)
c13914f style (ieivanov)
48df3de debug (ieivanov)
620e368 splitting the logic for computing metrics by accepting two databases … (edyoshikun)
149d36e fix demo (edyoshikun)
bc93f89 CLI prototype to compute metrics (edyoshikun)
d3c3d8a support z-slice 3d via list to slice object conversion (edyoshikun)
b7141d9 allow for independent target and prediction databases (ieivanov)
4f2e13a refactor demo script (ieivanov)
9bc4f49 rename demo script (ieivanov)
e9ab7a1 docs (ieivanov)
db86e3c test of specified positions (ieivanov)
5b1b683 fixing the trainer from 'auto' for resources to 'cpu' and limiting th… (edyoshikun)
ea4ed97 adding transforms to do normalization. (edyoshikun)
0e9dd7c WIP (ieivanov)
7215723 bugfix - convert data to float32 (ieivanov)
a10ac34 segment prototype. can be deleted later (edyoshikun)
ab8617b add plotting (edyoshikun)
8da84e9 vs metrics v1 (ieivanov)
411aeb6 compute metrics on multiple conditions at a time (ieivanov)
1de0525 use multiple workers (ieivanov)
dbd9b56 cleaner messaging (ieivanov)
9467a91 use gpu acceleration (ieivanov)
62d2703 add note on ssim data_range (ieivanov)
c80e4aa ivan's VS metrics scripts (ieivanov)
# DynaCell Application

This directory contains tools for computing and analyzing virtual staining metrics using the DynaCell benchmark datasets.

## Overview

The DynaCell application provides functionality to:
- Compute intensity-based metrics between target and predicted virtual staining images
- Process multiple infection conditions (Mock, DENV) in a single function call
- Support parallel processing for faster computation
- Generate detailed CSV reports with position names and dataset information

## Key Files

- `compute_virtual_staining_metrics.py` - Main script for computing VSCyto3D and CellDiff metrics
- `benchmark.py` - Core functions for metrics computation with parallel processing support
- `example_parallel_usage.py` - Example showing sequential vs. parallel processing
- `test_parallel_metrics.py` - Test script for verifying parallel functionality
## Parallel Processing Architecture

### Dataset Structure and Worker Distribution

The DynaCell metrics pipeline processes data at the individual timepoint level, enabling efficient parallel processing:

```
Dataset Structure:
├── Position B/1/000001 (Mock, A549, HIST2H2BE)
│   ├── Sample 0: timepoint 0   <- Worker A processes this
│   ├── Sample 1: timepoint 1   <- Worker B processes this
│   ├── Sample 2: timepoint 2   <- Worker C processes this
│   └── Sample 3: timepoint 3   <- Worker D processes this
├── Position B/1/000002 (DENV, A549, HIST2H2BE)
│   ├── Sample 4: timepoint 0   <- Worker A processes this
│   ├── Sample 5: timepoint 1   <- Worker B processes this
│   └── ...
└── Position B/2/000001 (Mock, A549, HIST2H2BE)
    ├── Sample N: timepoint 0   <- Worker C processes this
    └── ...
```

### How Workers Process Data

**Granularity**: Each worker processes individual `(position, timepoint)` combinations.

**Distribution**: PyTorch's `DataLoader` distributes samples across workers in round-robin fashion (a runnable sketch follows below):
- With `num_workers=4` and 100 samples, worker 0 gets samples [0, 4, 8, 12, ...], worker 1 gets [1, 5, 9, 13, ...], etc.

**Batch Processing**: With `batch_size=1`, each worker processes exactly one sample at a time for metrics compatibility.
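A minimal, self-contained sketch of this distribution, using a hypothetical `TimepointDataset` as a stand-in for the real DynaCell dataset:

```python
# Hypothetical illustration of DataLoader's round-robin worker assignment.
import torch
from torch.utils.data import DataLoader, Dataset


class TimepointDataset(Dataset):
    """Stand-in dataset: one sample per (position, timepoint) combination."""

    def __init__(self, n_samples: int):
        self.n_samples = n_samples

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, idx: int):
        info = torch.utils.data.get_worker_info()
        worker_id = info.id if info is not None else -1  # -1 = main process
        return idx, worker_id


loader = DataLoader(TimepointDataset(12), batch_size=1, num_workers=4)
for sample_idx, worker_id in loader:
    # With 4 workers: worker 0 loads samples 0, 4, 8, ...; worker 1 loads 1, 5, 9, ...
    print(f"sample {sample_idx.item()} loaded by worker {worker_id.item()}")
```

Running this prints each worker yielding every fourth sample index, mirroring the round-robin pattern described above.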
**Concurrency Benefits**:
- **I/O Parallelism**: Workers read different zarr files/slices simultaneously
- **CPU Parallelism**: Image processing, transforms, and metrics computation happen in parallel
- **Memory Efficiency**: Each worker only loads one timepoint at a time

### Performance Optimization

This design is particularly effective for DynaCell data because:
- Each timepoint requires significant I/O (loading image slices from zarr)
- Metrics computation is CPU-intensive
- Different timepoints are independent and can be processed in any order

**Recommended Settings**:
- `num_workers=4-12` for typical HPC setups (one way to pick a value is sketched after this list)
- `batch_size=1` (hardcoded for metrics compatibility)
- More workers help when you have many positions/timepoints
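One possible heuristic for choosing the worker count, stated here as an assumption rather than a project recommendation, is to derive it from the cores actually allocated to the job:

```python
import os

# Cores actually available to this process (Linux-only API, common on HPC).
cpu_count = len(os.sched_getaffinity(0))
# Stay within the 4-12 range suggested above and leave headroom for the main process.
num_workers = max(1, min(12, cpu_count - 2))
```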

## Usage Examples

### Basic Usage (Sequential)
```python
metrics = compute_metrics(
    metrics_module=IntensityMetrics(),
    cell_types=["A549"],
    organelles=["HIST2H2BE"],
    infection_conditions=["Mock", "DENV"],
    target_database=database,
    target_channel_name="raw Cy5 EX639 EM698-70",
    prediction_database=database,
    prediction_channel_name="nuclei_prediction",
    log_output_dir=output_dir,
    log_name="metrics_example",
)
```

### Parallel Processing
```python
metrics = compute_metrics(
    # ... same parameters as above ...
    num_workers=8,  # Use 8 workers for parallel processing
)
```
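The function also exposes a `use_gpu` flag (documented in the demo script below); a GPU variant of the same call might look like:

```python
metrics = compute_metrics(
    # ... same parameters as above ...
    num_workers=4,
    use_gpu=True,  # per the function docstring, GPU gives a 10-25x speedup
)
```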

### Multiple Conditions
The system supports processing multiple infection conditions in a single call:
```python
infection_conditions=["Mock", "DENV"]  # Processes both conditions together
```

## Output Format

The generated CSV files include:
- **Standard metrics**: SSIM, PSNR, correlation coefficients, etc.
- **Position information**: `position_name` (e.g., "B/1/000001")
- **Dataset information**: `dataset` name from the database
- **Condition metadata**: `cell_type`, `organelle`, `infection_condition`
- **Temporal information**: `time_idx` for tracking timepoints
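As a quick post-processing sketch (the metric column names depend on the metrics module, so `"SSIM"` below is an assumption):

```python
import pandas as pd

# Load a metrics CSV written by the pipeline and summarize per condition.
metrics = pd.read_csv("metrics.csv")  # path to a generated metrics file
# "SSIM" is an assumed metric column name; substitute the columns you have.
print(metrics.groupby("infection_condition")["SSIM"].mean())
```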

## Thread Safety

The system uses `ParallelSafeMetricsLogger` to prevent race conditions when multiple workers write metrics. This logger:
- Collects metrics in memory during processing
- Writes all data atomically to CSV after completion
- Prevents file corruption from concurrent writes
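In miniature, the pattern looks roughly like the hypothetical `BufferedCsvLogger` below; this is an illustration of the approach, not the actual `ParallelSafeMetricsLogger` implementation:

```python
import csv
import threading
from pathlib import Path


class BufferedCsvLogger:
    """Buffer metric rows in memory; write the CSV once, atomically, at the end."""

    def __init__(self, output_file: Path):
        self.output_file = output_file
        self._rows: list[dict] = []
        self._lock = threading.Lock()

    def log(self, row: dict) -> None:
        # Serialize appends so concurrent callers cannot corrupt the buffer.
        with self._lock:
            self._rows.append(row)

    def finalize(self) -> None:
        # One write after all workers finish: no interleaved partial rows.
        fieldnames = sorted({key for row in self._rows for key in row})
        tmp = self.output_file.with_suffix(".tmp")
        with open(tmp, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(self._rows)
        tmp.replace(self.output_file)  # atomic rename on POSIX filesystems
```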

## Database Structure

The system expects database CSV files with columns:
- `Cell type` - Cell line (e.g., "A549", "HEK293T")
- `Organelle` - Target organelle (e.g., "HIST2H2BE")
- `Infection` - Condition (e.g., "Mock", "DENV")
- `Path` - Path to zarr files
- `Dataset` - Dataset identifier
- Additional metadata columns

The system automatically handles:
- Filtering by multiple conditions using OR logic
- Deduplication by zarr path AND FOV name
- Metadata preservation through the processing pipeline
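A rough sketch of that filtering and deduplication in pandas, where the column names follow the schema above except `"FOV name"`, which is an assumed name for the FOV column:

```python
import pandas as pd


def select_rows(
    db: pd.DataFrame, cell_types: list, organelles: list, infection_conditions: list
) -> pd.DataFrame:
    # OR logic within each condition list: a row matches if its value is in the list.
    mask = (
        db["Cell type"].isin(cell_types)
        & db["Organelle"].isin(organelles)
        & db["Infection"].isin(infection_conditions)
    )
    # Deduplicate by zarr path AND FOV name ("FOV name" is an assumed column name).
    return db[mask].drop_duplicates(subset=["Path", "FOV name"])
```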
""" | ||
This script is a demo script for the DynaCell application. | ||
It loads the ome-zarr 0.4v format, calculates metrics and saves the results as csv files | ||
""" | ||
|
||
import datetime | ||
import os | ||
from pathlib import Path | ||
|
||
import pandas as pd | ||
import torch | ||
from lightning import LightningModule | ||
from lightning.pytorch.loggers import CSVLogger | ||
|
||
from viscy.data.dynacell import DynaCellDatabase, DynaCellDataModule | ||
from viscy.trainer import Trainer | ||
from viscy.utils.logging import ParallelSafeMetricsLogger | ||
|
||
# Set float32 matmul precision for better performance on Tensor Cores | ||
torch.set_float32_matmul_precision("high") | ||
|
||
# Suppress Lightning warnings for intentional CPU usage | ||
os.environ["SLURM_NTASKS"] = "1" # Suppress SLURM warning | ||
import warnings | ||
warnings.filterwarnings("ignore", "GPU available but not used") | ||
warnings.filterwarnings("ignore", "The `srun` command is available") | ||
|
||
|
||
def compute_metrics(
    metrics_module: LightningModule,
    cell_types: list,
    organelles: list,
    infection_conditions: list,
    target_database: pd.DataFrame,
    target_channel_name: str,
    prediction_database: pd.DataFrame,
    prediction_channel_name: str,
    log_output_dir: Path,
    log_name: str = "dynacell_metrics",
    log_version: str | None = None,
    z_slice: slice | None = None,
    transforms: list | None = None,
    num_workers: int = 0,
    use_gpu: bool = False,
):
""" | ||
Compute DynaCell metrics with optional parallel processing. | ||
|
||
This function processes virtual staining metrics at the individual timepoint level, | ||
enabling efficient parallel computation across multiple positions and timepoints. | ||
|
||
Parallel Processing Architecture: | ||
- Each sample represents one (position, timepoint) combination | ||
- Workers are distributed samples in round-robin fashion by PyTorch DataLoader | ||
- With num_workers=4: Worker 0 gets samples [0,4,8...], Worker 1 gets [1,5,9...], etc. | ||
- Each worker processes different timepoints/positions simultaneously | ||
- Thread-safe logging prevents race conditions in CSV output | ||
|
||
Parameters | ||
---------- | ||
metrics_module : LightningModule | ||
The metrics module to use (e.g., IntensityMetrics()) | ||
cell_types : list | ||
List of cell types to process (e.g., ["A549"]) | ||
organelles : list | ||
List of organelles to process (e.g., ["HIST2H2BE"]) | ||
infection_conditions : list | ||
List of infection conditions to process (e.g., ["Mock", "DENV"]) | ||
Multiple conditions are processed with OR logic in a single call | ||
target_database : pd.DataFrame | ||
Database containing target image paths and metadata | ||
target_channel_name : str | ||
Channel name in target dataset | ||
prediction_database : pd.DataFrame | ||
Database containing prediction image paths and metadata | ||
prediction_channel_name : str | ||
Channel name in prediction dataset | ||
log_output_dir : Path | ||
Directory for output metrics CSV files | ||
log_name : str, optional | ||
Name for metrics logging, by default "dynacell_metrics" | ||
log_version : str, optional | ||
Version string for logging, by default None (uses timestamp) | ||
z_slice : slice, optional | ||
Z-slice to extract from 3D data, by default None | ||
transforms : list, optional | ||
List of data transforms to apply, by default None | ||
num_workers : int, optional | ||
Number of workers for parallel data loading, by default 0 (sequential) | ||
Recommended: 2-4 workers for CPU, 4-8 for GPU | ||
use_gpu : bool, optional | ||
Whether to use GPU acceleration, by default False | ||
GPU provides 10-25x speedup for metrics computation | ||
|
||
Notes | ||
----- | ||
- GPU acceleration provides massive speedup for metrics computation | ||
- batch_size is hardcoded to 1 for compatibility with existing metrics code | ||
- GPU acceleration works excellently even with batch_size=1 | ||
- Uses ParallelSafeMetricsLogger to prevent race conditions in CSV writing | ||
- Output CSV includes position_name, dataset, and condition metadata | ||
|
||
Returns | ||
------- | ||
pd.DataFrame or None | ||
Metrics DataFrame if CSV file is successfully created, None otherwise | ||
""" | ||
    # Generate timestamp for unique versioning
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    if log_version is None:
        log_version = timestamp

    # Create the target database
    target_db = DynaCellDatabase(
        database=target_database,
        cell_types=cell_types,
        organelles=organelles,
        infection_conditions=infection_conditions,
        channel_name=target_channel_name,
        z_slice=z_slice,
    )

    # Create the prediction database (for segmentation, use the same channel
    # for prediction and target to get a self-comparison)
    pred_db = DynaCellDatabase(
        database=prediction_database,
        cell_types=cell_types,
        organelles=organelles,
        infection_conditions=infection_conditions,
        channel_name=prediction_channel_name,
        z_slice=z_slice,
    )

    # Create the data module with both databases
    dm = DynaCellDataModule(
        target_database=target_db,
        pred_database=pred_db,
        batch_size=1,  # Hardcoded to 1 for metrics compatibility
        num_workers=num_workers,
        transforms=transforms,
    )
    dm.setup(stage="test")

    # Print a dataset configuration summary
    sample = next(iter(dm.test_dataloader()))
    device_name = "GPU" if use_gpu and torch.cuda.is_available() else "CPU"
    processing_mode = "Parallel" if num_workers > 0 else "Sequential"

    print("\n📊 Dataset Configuration:")
    print(f"  • Samples: {len(dm.test_dataset)} total across all positions/timepoints")
    print(f"  • Cell types: {cell_types}")
    print(f"  • Organelles: {organelles}")
    print(f"  • Infection conditions: {infection_conditions}")
    print(
        f"  • Sample metadata: {sample['cell_type']}, {sample['organelle']}, "
        f"{sample['infection_condition']}"
    )

    # Set up logging
    log_output_dir.mkdir(exist_ok=True)

    if num_workers > 0:
        logger = ParallelSafeMetricsLogger(
            save_dir=log_output_dir, name=log_name, version=log_version
        )
        print(f"\n🚀 Processing Mode: {processing_mode} ({num_workers} workers)")
    else:
        logger = CSVLogger(save_dir=log_output_dir, name=log_name, version=log_version)
        print(f"\n🔄 Processing Mode: {processing_mode} (single-threaded)")

    print(f"  • Device: {device_name}")
    print("  • Batch size: 1 (hardcoded for metrics compatibility)")
    if use_gpu and torch.cuda.is_available():
        print(f"  • GPU: {torch.cuda.get_device_name()}")

    # Configure the trainer based on device preference
    if use_gpu and torch.cuda.is_available():
        accelerator = "gpu"
        precision = "16-mixed"  # fp16 mixed precision on GPU
    else:
        accelerator = "cpu"
        precision = "bf16-mixed"  # bf16 mixed precision on CPU

    trainer = Trainer(
        logger=logger,
        accelerator=accelerator,
        devices=1,
        precision=precision,
        num_nodes=1,
        enable_progress_bar=True,
        enable_model_summary=False,
    )
    trainer.test(metrics_module, datamodule=dm)

    # Finalize logging when using the parallel-safe logger. Lightning's
    # CSVLogger.finalize requires a status argument, so guard on the logger
    # type rather than hasattr (every Lightning logger has a finalize method).
    if isinstance(logger, ParallelSafeMetricsLogger):
        logger.finalize()

    # Find and report results
    metrics_file = log_output_dir / log_name / log_version / "metrics.csv"
    if metrics_file.exists():
        metrics = pd.read_csv(metrics_file)
        metadata_cols = [
            "position", "time", "cell_type", "organelle",
            "infection_condition", "dataset", "position_name",
        ]
        print("\n✅ Metrics computation completed successfully!")
        print(f"  • Output file: {metrics_file}")
        print(f"  • Records: {len(metrics)} samples")
        print(f"  • Device: {device_name}")
        print("  • Batch size: 1 (hardcoded)")
        print(f"  • Metrics: {[col for col in metrics.columns if col not in metadata_cols]}")

        # Show infection condition breakdown
        if "infection_condition" in metrics.columns:
            condition_counts = metrics["infection_condition"].value_counts()
            print(f"  • Conditions: {dict(condition_counts)}")

        # Show GPU memory usage if applicable
        if use_gpu and torch.cuda.is_available():
            memory_used = torch.cuda.max_memory_allocated() / 1024**3
            print(f"  • Peak GPU memory: {memory_used:.2f} GB")
    else:
        print(f"❌ Warning: Metrics file not found at {metrics_file}")
        metrics = None

    return metrics
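

# Hypothetical usage sketch: the IntensityMetrics import path and the file
# paths below are assumptions; adapt them to the actual project layout.
if __name__ == "__main__":
    from viscy.translation.evaluation import IntensityMetrics  # assumed path

    database = pd.read_csv("dynacell_database.csv")  # assumed database CSV
    compute_metrics(
        metrics_module=IntensityMetrics(),
        cell_types=["A549"],
        organelles=["HIST2H2BE"],
        infection_conditions=["Mock", "DENV"],
        target_database=database,
        target_channel_name="raw Cy5 EX639 EM698-70",
        prediction_database=database,
        prediction_channel_name="nuclei_prediction",
        log_output_dir=Path("./metrics_output"),
        num_workers=8,
        use_gpu=torch.cuda.is_available(),
    )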
Review comment: Will this have a side effect for the current shell?