- Pradeep, Imran, Liu et al., 2024
-
-
-@misc{pradeep_contrastive_2024,
- title={Contrastive learning of cell state dynamics in response to perturbations},
- author={Soorya Pradeep and Alishba Imran and Ziwen Liu and Taylla Milena Theodoro and Eduardo Hirata-Miyasaki and Ivan Ivanov and Madhura Bhave and Sudip Khadka and Hunter Woosley and Carolina Arias and Shalin B. Mehta},
- year={2024},
- eprint={2410.11281},
- archivePrefix={arXiv},
- primaryClass={cs.CV},
- url={https://arxiv.org/abs/2410.11281},
-}
-
-
-### Workflow demo
+### Demo
+- [DynaCLR demos](examples/DynaCLR/README.md)
- Example test dataset, model checkpoint, and predictions can be found
[here](https://public.czbiohub.org/comp.micro/viscy/DynaCLR_demo/).
@@ -159,7 +143,6 @@ for representation learning on [arXiv](https://arxiv.org/abs/2410.11281).
- See tutorial on exploration of learned embeddings with napari-iohub
[here](https://github.com/czbiohub-sf/napari-iohub/wiki/View-tracked-cells-and-their-associated-predictions/).
-
## Installation
diff --git a/applications/benchmarking/DynaCLR/ImageNet/config.yml b/applications/benchmarking/DynaCLR/ImageNet/config.yml
new file mode 100644
index 000000000..e832c0b21
--- /dev/null
+++ b/applications/benchmarking/DynaCLR/ImageNet/config.yml
@@ -0,0 +1,45 @@
+datamodule:
+ batch_size: 32
+ final_yx_patch_size:
+ - 160
+ - 160
+ include_fov_names: null
+ include_track_ids: null
+ initial_yx_patch_size:
+ - 160
+ - 160
+ normalizations:
+ - class_path: viscy.transforms.ScaleIntensityRangePercentilesd
+ init_args:
+ b_max: 1.0
+ b_min: 0.0
+ keys:
+ - RFP
+ lower: 50
+ upper: 99
+ num_workers: 60
+ source_channel:
+ - RFP
+ z_range:
+ - 15
+ - 45
+embedding:
+ PCA_kwargs:
+ n_components: 8
+ PHATE_kwargs:
+ decay: 40
+ knn: 5
+ n_components: 2
+ n_jobs: -1
+ random_state: 42
+execution:
+ overwrite: true
+ save_config: true
+ show_config: true
+model:
+ channel_reduction_methods:
+ RFP: max
+paths:
+ data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr
+ output_path: /home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/ImageNet/20240204_A549_DENV_ZIKV_sensor_only_imagenet.zarr
+ tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr
diff --git a/applications/benchmarking/DynaCLR/ImageNet/imagenet_embeddings.py b/applications/benchmarking/DynaCLR/ImageNet/imagenet_embeddings.py
new file mode 100644
index 000000000..e33b7cf0c
--- /dev/null
+++ b/applications/benchmarking/DynaCLR/ImageNet/imagenet_embeddings.py
@@ -0,0 +1,388 @@
+"""
+Generate embeddings using a pre-trained ImageNet model and save them to a zarr store
+using VisCy Trainer and EmbeddingWriter callback.
+"""
+
+import importlib
+import logging
+import os
+from pathlib import Path
+from typing import Dict, List, Literal, Optional
+
+import click
+import timm
+import torch
+import yaml
+from lightning.pytorch import LightningModule
+
+from viscy.data.triplet import TripletDataModule
+from viscy.representation.embedding_writer import EmbeddingWriter
+from viscy.trainer import VisCyTrainer
+
+logger = logging.getLogger(__name__)
+
+
class ImageNetModule(LightningModule):
    """Wrap a pre-trained timm ImageNet backbone for embedding extraction.

    Only ``predict_step`` is implemented so the module can be driven by the
    Lightning/VisCy prediction loop together with the ``EmbeddingWriter``
    callback.
    """

    def __init__(
        self,
        model_name: str = "convnext_tiny",
        channel_reduction_methods: Optional[
            Dict[str, Literal["middle_slice", "mean", "max"]]
        ] = None,
        channel_names: Optional[List[str]] = None,
    ):
        """Initialize the ImageNet module.

        Args:
            model_name: Name of the pre-trained ImageNet model to use
            channel_reduction_methods: Dict mapping channel names to reduction methods:
                - "middle_slice": Take the middle slice along the depth dimension
                - "mean": Average across the depth dimension
                - "max": Take the maximum value across the depth dimension
            channel_names: List of channel names corresponding to the input channels
        """
        super().__init__()
        self.channel_reduction_methods = channel_reduction_methods or {}
        self.channel_names = channel_names or []

        # Allow TF32 matmul on Ampere+ GPUs: faster inference, ample precision.
        torch.set_float32_matmul_precision("high")
        # NOTE: `timm` is imported at module level, so a missing install fails
        # at import time; guarding `create_model` with `except ImportError`
        # (as before) could never trigger and was removed.
        self.model = timm.create_model(model_name, pretrained=True)
        self.model.eval()

    def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods.

        Args:
            x: 5D input tensor

        Returns:
            4D tensor after applying reduction methods; non-5D input is
            returned unchanged.
        """
        if x.dim() != 5:
            return x

        B, C, D, H, W = x.shape
        # Preserve the input dtype; the previous default-float32 buffer
        # silently cast every input.
        result = torch.zeros((B, C, H, W), device=x.device, dtype=x.dtype)

        # Group channels by reduction method so each method is applied with a
        # single vectorized indexing operation instead of a per-channel loop.
        middle_slice_indices = []
        mean_indices = []
        max_indices = []

        for c in range(C):
            channel_name = (
                self.channel_names[c] if c < len(self.channel_names) else f"channel_{c}"
            )
            method = self.channel_reduction_methods.get(channel_name, "middle_slice")

            if method == "mean":
                mean_indices.append(c)
            elif method == "max":
                max_indices.append(c)
            else:  # Default to middle_slice for any unknown method
                middle_slice_indices.append(c)

        if middle_slice_indices:
            indices = torch.tensor(middle_slice_indices, device=x.device)
            result[:, indices] = x[:, indices, D // 2]

        if mean_indices:
            indices = torch.tensor(mean_indices, device=x.device)
            result[:, indices] = x[:, indices].mean(dim=2)

        if max_indices:
            indices = torch.tensor(max_indices, device=x.device)
            result[:, indices] = x[:, indices].max(dim=2)[0]

        return result

    @staticmethod
    def _minmax_per_sample(ch: torch.Tensor) -> torch.Tensor:
        """Normalize a (B, 1, H, W) tensor to [0, 1] independently per sample."""
        B = ch.shape[0]
        ch_min = ch.reshape(B, -1).min(dim=1, keepdim=True)[0].reshape(B, 1, 1, 1)
        ch_max = ch.reshape(B, -1).max(dim=1, keepdim=True)[0].reshape(B, 1, 1, 1)
        # Epsilon keeps the division finite on constant images.
        return (ch - ch_min) / (ch_max - ch_min + 1e-7)

    def _convert_to_rgb(self, x: torch.Tensor) -> torch.Tensor:
        """Convert input tensor to 3-channel RGB format as needed.

        Args:
            x: Input tensor with 1, 2, or 3+ channels

        Returns:
            3-channel tensor suitable for ImageNet models
        """
        if x.shape[1] == 3:
            return x
        elif x.shape[1] == 1:
            # Grayscale: replicate the single channel into R, G and B.
            return x.repeat(1, 3, 1, 1)
        elif x.shape[1] == 2:
            # Two channels: normalize each independently (they may have very
            # different intensity scales), then map them to R and G with a
            # blended B channel.
            B, _, H, W = x.shape
            x_3ch = torch.zeros((B, 3, H, W), device=x.device, dtype=x.dtype)

            ch0_norm = self._minmax_per_sample(x[:, 0:1])
            ch1_norm = self._minmax_per_sample(x[:, 1:2])

            x_3ch[:, 0] = ch0_norm.squeeze(1)  # R channel from first input
            x_3ch[:, 1] = ch1_norm.squeeze(1)  # G channel from second input
            x_3ch[:, 2] = 0.5 * (
                ch0_norm.squeeze(1) + ch1_norm.squeeze(1)
            )  # B channel as blend

            return x_3ch
        else:
            # For more than 3 channels, use the first 3
            return x[:, :3]

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Extract features from the input images.

        Returns:
            Dictionary with features, an empty projections tensor (shape
            (B, 0), kept so EmbeddingWriter can process the output), and
            index information.
        """
        x = batch["anchor"]

        # Collapse the depth dimension of 5D input per configured methods.
        if x.dim() == 5:
            x = self._reduce_5d_input(x)

        # ImageNet backbones expect 3-channel input.
        x = self._convert_to_rgb(x)

        with torch.no_grad():
            features = self.model.forward_features(x)

        # Pool to a feature vector. CNN backbones return (B, C, H, W) maps;
        # ViT-style backbones return (B, N, D) token sequences, which the
        # previous `mean(dim=[2, 3])` would have crashed on.
        if features.dim() == 4:
            features = features.mean(dim=[2, 3])
        elif features.dim() == 3:
            features = features.mean(dim=1)

        return {
            "features": features,
            "projections": torch.zeros((features.shape[0], 0), device=features.device),
            "index": batch["index"],
        }
+
+
def load_config(config_file):
    """Read a YAML configuration file and return its parsed contents."""
    with open(config_file, "r") as stream:
        return yaml.safe_load(stream)
+
+
def load_normalization_from_config(norm_config):
    """Load a normalization transform from a configuration dictionary."""
    init_args = norm_config.get("init_args", {})
    # Resolve a dotted "package.module.Class" path into a class object,
    # then instantiate it with the configured keyword arguments.
    module_path, class_name = norm_config["class_path"].rsplit(".", 1)
    transform_class = getattr(importlib.import_module(module_path), class_name)
    return transform_class(**init_args)
+
+
@click.command()
@click.option(
    "--config",
    "-c",
    type=click.Path(exists=True),
    required=True,
    help="Path to YAML configuration file",
)
@click.option(
    "--model",
    "-m",
    type=str,
    default="convnext_tiny",
    help="Name of the pre-trained ImageNet model to use",
)
def main(config, model):
    """Extract ImageNet embeddings and save to zarr format using VisCy Trainer."""
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    cfg = load_config(config)
    logger.info(f"Loaded configuration from {config}")

    # Assemble TripletDataModule keyword arguments from the config.
    dm_params = {}

    if "paths" not in cfg:
        raise ValueError("Configuration must contain a 'paths' section")

    if "data_path" not in cfg["paths"]:
        raise ValueError(
            "Data path is required in the configuration file (paths.data_path)"
        )
    dm_params["data_path"] = cfg["paths"]["data_path"]

    if "tracks_path" not in cfg["paths"]:
        raise ValueError(
            "Tracks path is required in the configuration file (paths.tracks_path)"
        )
    dm_params["tracks_path"] = cfg["paths"]["tracks_path"]

    if "datamodule" not in cfg:
        raise ValueError("Configuration must contain a 'datamodule' section")

    if (
        "normalizations" not in cfg["datamodule"]
        or not cfg["datamodule"]["normalizations"]
    ):
        raise ValueError(
            "Normalizations are required in the configuration file (datamodule.normalizations)"
        )

    # Instantiate normalization transforms from their dotted class paths.
    norm_configs = cfg["datamodule"]["normalizations"]
    normalizations = [load_normalization_from_config(norm) for norm in norm_configs]
    dm_params["normalizations"] = normalizations

    # Copy all other datamodule parameters verbatim, expanding the legacy
    # `patch_size` shorthand into initial/final patch sizes.
    for param, value in cfg["datamodule"].items():
        if param != "normalizations":
            if param == "patch_size":
                dm_params["initial_yx_patch_size"] = value
                dm_params["final_yx_patch_size"] = value
            else:
                dm_params[param] = value

    logger.info("Setting up data module")
    dm = TripletDataModule(**dm_params)

    # Per-channel strategies for collapsing 5D (B, C, D, H, W) inputs.
    channel_reduction_methods = {}
    if "model" in cfg and "channel_reduction_methods" in cfg["model"]:
        channel_reduction_methods = cfg["model"]["channel_reduction_methods"]

    logger.info(f"Loading ImageNet model: {model}")
    model_module = ImageNetModule(
        model_name=model,
        channel_reduction_methods=channel_reduction_methods,
        # NOTE(review): assumes `source_channel` is a list of names; confirm
        # configs never pass a single string here.
        channel_names=dm_params.get("source_channel", []),
    )

    # Dimensionality-reduction settings forwarded to EmbeddingWriter.
    # NOTE(review): any `embedding.UMAP_kwargs` in the config is intentionally
    # ignored — the writer used here only accepts PHATE and PCA kwargs.
    # (The previous version read UMAP_kwargs into an unused local.)
    PHATE_kwargs = None
    PCA_kwargs = None
    if "embedding" in cfg:
        if "PHATE_kwargs" in cfg["embedding"]:
            PHATE_kwargs = cfg["embedding"]["PHATE_kwargs"]
        if "PCA_kwargs" in cfg["embedding"]:
            PCA_kwargs = cfg["embedding"]["PCA_kwargs"]

    if "output_path" not in cfg["paths"]:
        raise ValueError(
            "Output path is required in the configuration file (paths.output_path)"
        )

    output_path = Path(cfg["paths"]["output_path"])
    output_dir = output_path.parent
    output_dir.mkdir(parents=True, exist_ok=True)

    # Overwrite policy: an explicit execution.overwrite wins; otherwise
    # overwrite (with a warning) only when the output already exists.
    overwrite = False
    if "execution" in cfg and "overwrite" in cfg["execution"]:
        overwrite = cfg["execution"]["overwrite"]
    elif output_path.exists():
        logger.warning(f"Output path {output_path} already exists, will overwrite")
        overwrite = True

    embedding_writer = EmbeddingWriter(
        output_path=output_path,
        PHATE_kwargs=PHATE_kwargs,
        PCA_kwargs=PCA_kwargs,
        overwrite=overwrite,
    )

    logger.info("Setting up VisCy trainer")
    trainer = VisCyTrainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        callbacks=[embedding_writer],
        inference_mode=True,
    )

    logger.info(f"Running prediction and saving to {output_path}")
    trainer.predict(model_module, datamodule=dm)

    # Resolve config persistence/echo flags (both default to enabled).
    save_config_flag = True
    show_config_flag = True
    if "execution" in cfg:
        if "save_config" in cfg["execution"]:
            save_config_flag = cfg["execution"]["save_config"]
        if "show_config" in cfg["execution"]:
            show_config_flag = cfg["execution"]["show_config"]

    # Persist a copy of the configuration next to the output store.
    if save_config_flag:
        config_path = os.path.join(output_dir, "config.yml")
        with open(config_path, "w") as f:
            yaml.dump(cfg, f, default_flow_style=False)
        logger.info(f"Configuration saved to {config_path}")

    # Echo the effective configuration to stdout for provenance.
    if show_config_flag:
        click.echo("\nConfiguration used:")
        click.echo("-" * 40)
        for key, value in cfg.items():
            click.echo(f"{key}:")
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    if isinstance(subvalue, list) and subkey == "normalizations":
                        click.echo(f"  {subkey}:")
                        for norm in subvalue:
                            click.echo(f"    - class_path: {norm['class_path']}")
                            click.echo(f"      init_args: {norm['init_args']}")
                    else:
                        click.echo(f"  {subkey}: {subvalue}")
            else:
                click.echo(f"  {value}")
        click.echo("-" * 40)

    logger.info("Done!")


if __name__ == "__main__":
    main()
diff --git a/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml b/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml
new file mode 100644
index 000000000..5dd65687b
--- /dev/null
+++ b/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml
@@ -0,0 +1,66 @@
+# OpenPhenom Embeddings Configuration
+
+# Paths section
+paths:
+ data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr
+ tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr
+ output_path: "/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_sec61b_n_phase_3.zarr"
+
+# Model configuration
+model:
+ # Channel-specific 5D input handling methods
+ # Options: "middle_slice", "mean", "max"
+ # Default is "middle_slice" if not specified
+ channel_reduction_methods:
+ "Phase3D": "middle_slice" # For phase contrast, middle slice often works well
+ "raw GFP EX488 EM525-45": "max"
+
+# Data module configuration
+datamodule:
+ source_channel:
+ - Phase3D
+ - "raw GFP EX488 EM525-45"
+ z_range: [25, 40]
+ batch_size: 32
+ num_workers: 10
+ initial_yx_patch_size: [192, 192]
+ final_yx_patch_size: [192, 192]
+ predict_cells: true
+ include_fov_names:
+ - "/C/2/000000"
+ - "/C/2/000000"
+ - "/C/2/000000"
+ - "/C/2/000000"
+ include_track_ids: [33,60,57,65]
+ normalizations:
+ - class_path: viscy.transforms.ScaleIntensityRangePercentilesd
+ init_args:
+ keys: ["Phase3D"]
+ lower: 50
+ upper: 99
+ b_min: 0.0
+ b_max: 1.0
+ - class_path: viscy.transforms.ScaleIntensityRangePercentilesd
+ init_args:
+ keys: ["raw GFP EX488 EM525-45"]
+ lower: 50
+ upper: 99
+ b_min: 0.0
+ b_max: 1.0
+
+# Embedding parameters
+embedding:
+ PHATE_kwargs:
+ n_components: 2
+ knn: 5
+ decay: 40
+ n_jobs: -1
+ random_state: 42
+ PCA_kwargs:
+ n_components: 2
+
+# Execution configuration
+execution:
+ overwrite: false
+ save_config: true
+ show_config: true
\ No newline at end of file
diff --git a/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py b/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py
new file mode 100644
index 000000000..621a711c5
--- /dev/null
+++ b/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py
@@ -0,0 +1,330 @@
+"""
+Generate embeddings using the OpenPhenom model and save them to a zarr store
+using VisCy Trainer and EmbeddingWriter callback.
+"""
+
+import importlib
+import logging
+import os
+from pathlib import Path
+from typing import Dict, List, Literal, Optional
+
+import click
+import torch
+import yaml
+from lightning.pytorch import LightningModule
+from transformers import AutoModel
+
+from viscy.data.triplet import TripletDataModule
+from viscy.representation.embedding_writer import EmbeddingWriter
+from viscy.trainer import VisCyTrainer
+
+
class OpenPhenomModule(LightningModule):
    """Wrap the OpenPhenom HuggingFace model for embedding extraction.

    Only ``predict_step`` is implemented so the module can be driven by the
    Lightning/VisCy prediction loop with the ``EmbeddingWriter`` callback.
    """

    def __init__(
        self,
        channel_reduction_methods: Optional[
            Dict[str, Literal["middle_slice", "mean", "max"]]
        ] = None,
        channel_names: Optional[List[str]] = None,
    ):
        """Initialize the OpenPhenom module.

        Parameters
        ----------
        channel_reduction_methods : dict, optional
            Dictionary mapping channel names to reduction methods:
            - "middle_slice": Take the middle slice along the depth dimension
            - "mean": Average across the depth dimension
            - "max": Take the maximum value across the depth dimension
        channel_names : list of str, optional
            List of channel names corresponding to the input channels

        Notes
        -----
        The module uses the OpenPhenom model from HuggingFace for generating embeddings.
        """
        super().__init__()

        self.channel_reduction_methods = channel_reduction_methods or {}
        self.channel_names = channel_names or []

        try:
            torch.set_float32_matmul_precision("high")
            # trust_remote_code pulls in model code whose own dependencies
            # may be missing, which is what the ImportError guard covers.
            self.model = AutoModel.from_pretrained(
                "recursionpharma/OpenPhenom", trust_remote_code=True
            )
            self.model.eval()
        except ImportError as e:
            raise ImportError(
                "Please install the OpenPhenom dependencies: "
                "pip install transformers"
            ) from e

    def on_predict_start(self):
        # Move model to the trainer's device when prediction starts.
        self.model.to(self.device)

    def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods.

        Args:
            x: 5D input tensor

        Returns:
            4D tensor after applying reduction methods; non-5D input is
            returned unchanged.
        """
        if x.dim() != 5:
            return x

        B, C, D, H, W = x.shape
        # Preserve the input dtype; the previous default-float32 buffer
        # silently cast every input.
        result = torch.zeros((B, C, H, W), device=x.device, dtype=x.dtype)

        # Apply the configured reduction method for each channel.
        for c in range(C):
            channel_name = (
                self.channel_names[c] if c < len(self.channel_names) else f"channel_{c}"
            )
            # Default to middle slice if not specified.
            method = self.channel_reduction_methods.get(channel_name, "middle_slice")

            if method == "mean":
                result[:, c] = x[:, c].mean(dim=1)
            elif method == "max":
                result[:, c] = x[:, c].max(dim=1)[0]
            else:
                # "middle_slice" and any unknown method fall back to the
                # central depth slice.
                result[:, c] = x[:, c, D // 2]

        return result

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Extract features from the input images.

        Returns:
            Dictionary with features, an empty (B, 0) projections tensor
            (kept so EmbeddingWriter can process the output), and index
            information.
        """
        x = batch["anchor"]

        # OpenPhenom expects [B, C, H, W]; collapse the depth dimension of
        # 5D [B, C, D, H, W] input per configured reduction methods.
        if x.dim() == 5:
            x = self._reduce_5d_input(x)

        # Convert to uint8 as OpenPhenom expects uint8 inputs.
        # NOTE(review): min/max are taken over the whole batch, so per-sample
        # contrast depends on batch composition — confirm this is intended.
        if x.dtype != torch.uint8:
            # clamp_min keeps the scale finite for constant inputs, which
            # previously produced NaN and undefined uint8 values.
            scale = (x.max() - x.min()).clamp_min(1e-8)
            x = ((x - x.min()) / scale * 255).clamp(0, 255).to(torch.uint8)

        # Get a single pooled embedding (not channel-wise embeddings).
        self.model.return_channelwise_embeddings = False
        features = self.model.predict(x)
        # Empty projections with matching batch size for EmbeddingWriter.
        projections = torch.zeros((features.shape[0], 0), device=features.device)

        return {
            "features": features,
            "projections": projections,
            "index": batch["index"],
        }
+
+
def load_config(config_file):
    """Read a YAML configuration file and return its parsed contents."""
    with open(config_file, "r") as stream:
        return yaml.safe_load(stream)
+
+
def load_normalization_from_config(norm_config):
    """Load a normalization transform from a configuration dictionary."""
    init_args = norm_config.get("init_args", {})
    # Resolve a dotted "package.module.Class" path into a class object,
    # then instantiate it with the configured keyword arguments.
    module_path, class_name = norm_config["class_path"].rsplit(".", 1)
    transform_class = getattr(importlib.import_module(module_path), class_name)
    return transform_class(**init_args)
+
+
@click.command()
@click.option(
    "--config",
    "-c",
    type=click.Path(exists=True),
    required=True,
    help="Path to YAML configuration file",
)
def main(config):
    """Extract OpenPhenom embeddings and save to zarr format using VisCy Trainer."""
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    cfg = load_config(config)
    logger.info(f"Loaded configuration from {config}")

    # Assemble TripletDataModule keyword arguments from the config.
    dm_params = {}

    if "paths" not in cfg:
        raise ValueError("Configuration must contain a 'paths' section")

    if "data_path" not in cfg["paths"]:
        raise ValueError(
            "Data path is required in the configuration file (paths.data_path)"
        )
    dm_params["data_path"] = cfg["paths"]["data_path"]

    if "tracks_path" not in cfg["paths"]:
        raise ValueError(
            "Tracks path is required in the configuration file (paths.tracks_path)"
        )
    dm_params["tracks_path"] = cfg["paths"]["tracks_path"]

    if "datamodule" not in cfg:
        raise ValueError("Configuration must contain a 'datamodule' section")

    if (
        "normalizations" not in cfg["datamodule"]
        or not cfg["datamodule"]["normalizations"]
    ):
        raise ValueError(
            "Normalizations are required in the configuration file (datamodule.normalizations)"
        )

    # Instantiate normalization transforms from their dotted class paths.
    norm_configs = cfg["datamodule"]["normalizations"]
    normalizations = [load_normalization_from_config(norm) for norm in norm_configs]
    dm_params["normalizations"] = normalizations

    # Copy all other datamodule parameters verbatim, expanding the legacy
    # `patch_size` shorthand into initial/final patch sizes.
    for param, value in cfg["datamodule"].items():
        if param != "normalizations":
            if param == "patch_size":
                dm_params["initial_yx_patch_size"] = value
                dm_params["final_yx_patch_size"] = value
            else:
                dm_params[param] = value

    logger.info("Setting up data module")
    dm = TripletDataModule(**dm_params)

    # Per-channel strategies for collapsing 5D (B, C, D, H, W) inputs.
    channel_reduction_methods = {}
    if "model" in cfg and "channel_reduction_methods" in cfg["model"]:
        channel_reduction_methods = cfg["model"]["channel_reduction_methods"]

    logger.info("Loading OpenPhenom model")
    model = OpenPhenomModule(
        channel_reduction_methods=channel_reduction_methods,
        # NOTE(review): assumes `source_channel` is a list of names; confirm
        # configs never pass a single string here.
        channel_names=dm_params.get("source_channel", []),
    )

    # Dimensionality-reduction settings forwarded to EmbeddingWriter.
    PHATE_kwargs = None
    PCA_kwargs = None
    if "embedding" in cfg:
        if "PHATE_kwargs" in cfg["embedding"]:
            PHATE_kwargs = cfg["embedding"]["PHATE_kwargs"]
        if "PCA_kwargs" in cfg["embedding"]:
            PCA_kwargs = cfg["embedding"]["PCA_kwargs"]

    if "output_path" not in cfg["paths"]:
        raise ValueError(
            "Output path is required in the configuration file (paths.output_path)"
        )

    output_path = Path(cfg["paths"]["output_path"])
    output_dir = output_path.parent
    output_dir.mkdir(parents=True, exist_ok=True)

    # Overwrite policy: an explicit execution.overwrite wins; otherwise
    # overwrite (with a warning) only when the output already exists.
    overwrite = False
    if "execution" in cfg and "overwrite" in cfg["execution"]:
        overwrite = cfg["execution"]["overwrite"]
    elif output_path.exists():
        logger.warning(f"Output path {output_path} already exists, will overwrite")
        overwrite = True

    embedding_writer = EmbeddingWriter(
        output_path=output_path,
        PHATE_kwargs=PHATE_kwargs,
        PCA_kwargs=PCA_kwargs,
        overwrite=overwrite,
    )

    logger.info("Setting up VisCy trainer")
    trainer = VisCyTrainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        callbacks=[embedding_writer],
        inference_mode=True,
    )

    logger.info(f"Running prediction and saving to {output_path}")
    trainer.predict(model, datamodule=dm)

    # Resolve config persistence/echo flags (both default to enabled).
    save_config_flag = True
    show_config_flag = True
    if "execution" in cfg:
        if "save_config" in cfg["execution"]:
            save_config_flag = cfg["execution"]["save_config"]
        if "show_config" in cfg["execution"]:
            show_config_flag = cfg["execution"]["show_config"]

    # Persist a copy of the configuration next to the output store.
    if save_config_flag:
        config_path = os.path.join(output_dir, "config.yml")
        with open(config_path, "w") as f:
            yaml.dump(cfg, f, default_flow_style=False)
        logger.info(f"Configuration saved to {config_path}")

    # Echo the effective configuration to stdout for provenance.
    if show_config_flag:
        click.echo("\nConfiguration used:")
        click.echo("-" * 40)
        for key, value in cfg.items():
            click.echo(f"{key}:")
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    if isinstance(subvalue, list) and subkey == "normalizations":
                        click.echo(f"  {subkey}:")
                        for norm in subvalue:
                            click.echo(f"    - class_path: {norm['class_path']}")
                            click.echo(f"      init_args: {norm['init_args']}")
                    else:
                        click.echo(f"  {subkey}: {subvalue}")
            else:
                click.echo(f"  {value}")
        click.echo("-" * 40)

    logger.info("Done!")


if __name__ == "__main__":
    main()
diff --git a/applications/contrastive_phenotyping/contrastive_cli/fit_ctc_mps.yml b/applications/contrastive_phenotyping/contrastive_cli/fit_ctc_mps.yml
deleted file mode 100644
index 47def3ae9..000000000
--- a/applications/contrastive_phenotyping/contrastive_cli/fit_ctc_mps.yml
+++ /dev/null
@@ -1,117 +0,0 @@
-# See help here on how to configure hyper-parameters with config files:
-# https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced.html
-seed_everything: 42
-trainer:
- accelerator: gpu
- strategy: auto
- devices: 1
- num_nodes: 1
- precision: 32-true
- logger:
- class_path: lightning.pytorch.loggers.TensorBoardLogger
- # Nesting the logger config like this is equivalent to
- # supplying the following argument to `lightning.pytorch.Trainer`:
- # logger=TensorBoardLogger(
- # "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations",
- # log_graph=True,
- # version="vanilla",
- # )
- init_args:
- save_dir: /Users/ziwen.liu/Projects/test-time
- # this is the name of the experiment.
- # The logs will be saved in `save_dir/lightning_logs/version`
- version: time_interval_1
- log_graph: True
- callbacks:
- - class_path: lightning.pytorch.callbacks.LearningRateMonitor
- init_args:
- logging_interval: step
- - class_path: lightning.pytorch.callbacks.ModelCheckpoint
- init_args:
- monitor: loss/val
- every_n_epochs: 1
- save_top_k: 4
- save_last: true
- fast_dev_run: false
- max_epochs: 100
- log_every_n_steps: 10
- enable_checkpointing: true
- inference_mode: true
- use_distributed_sampler: true
- # synchronize batchnorm parameters across multiple GPUs.
- # important for contrastive learning to normalize the tensors across the whole batch.
- sync_batchnorm: true
-model:
- class_path: viscy.representation.engine.ContrastiveModule
- init_args:
- encoder:
- class_path: viscy.representation.contrastive.ContrastiveEncoder
- init_args:
- backbone: convnext_tiny
- in_channels: 1
- in_stack_depth: 1
- stem_kernel_size: [1, 4, 4]
- stem_stride: [1, 4, 4]
- embedding_dim: 768
- projection_dim: 32
- drop_path_rate: 0.0
- loss_function:
- class_path: torch.nn.TripletMarginLoss
- init_args:
- margin: 0.5
- lr: 0.0002
- log_batches_per_epoch: 3
- log_samples_per_batch: 2
- example_input_array_shape: [1, 1, 1, 128, 128]
-data:
- class_path: viscy.data.triplet.TripletDataModule
- init_args:
- data_path: /Users/ziwen.liu/Downloads/Hela_CTC.zarr
- tracks_path: /Users/ziwen.liu/Downloads/Hela_CTC.zarr
- source_channel:
- - DIC
- z_range: [0, 1]
- batch_size: 16
- num_workers: 4
- initial_yx_patch_size: [256, 256]
- final_yx_patch_size: [128, 128]
- time_interval: 1
- normalizations:
- - class_path: viscy.transforms.NormalizeSampled
- init_args:
- keys: [DIC]
- level: fov_statistics
- subtrahend: mean
- divisor: std
- augmentations:
- - class_path: viscy.transforms.RandAffined
- init_args:
- keys: [DIC]
- prob: 0.8
- scale_range: [0, 0.2, 0.2]
- rotate_range: [3.14, 0.0, 0.0]
- shear_range: [0.0, 0.01, 0.01]
- padding_mode: zeros
- - class_path: viscy.transforms.RandAdjustContrastd
- init_args:
- keys: [DIC]
- prob: 0.5
- gamma: [0.8, 1.2]
- - class_path: viscy.transforms.RandScaleIntensityd
- init_args:
- keys: [DIC]
- prob: 0.5
- factors: 0.5
- - class_path: viscy.transforms.RandGaussianSmoothd
- init_args:
- keys: [DIC]
- prob: 0.5
- sigma_x: [0.25, 0.75]
- sigma_y: [0.25, 0.75]
- sigma_z: [0.0, 0.0]
- - class_path: viscy.transforms.RandGaussianNoised
- init_args:
- keys: [DIC]
- prob: 0.5
- mean: 0.0
- std: 0.2
diff --git a/applications/contrastive_phenotyping/contrastive_cli/predict_ctc_mps.yml b/applications/contrastive_phenotyping/contrastive_cli/predict_ctc_mps.yml
deleted file mode 100644
index d305e32aa..000000000
--- a/applications/contrastive_phenotyping/contrastive_cli/predict_ctc_mps.yml
+++ /dev/null
@@ -1,48 +0,0 @@
-seed_everything: 42
-trainer:
- accelerator: gpu
- strategy: auto
- devices: auto
- num_nodes: 1
- precision: 32-true
- callbacks:
- - class_path: viscy.representation.embedding_writer.EmbeddingWriter
- init_args:
- output_path: /Users/ziwen.liu/Projects/test-time/predict/time_interval_1.zarr
- inference_mode: true
-model:
- class_path: viscy.representation.engine.ContrastiveModule
- init_args:
- encoder:
- class_path: viscy.representation.contrastive.ContrastiveEncoder
- init_args:
- backbone: convnext_tiny
- in_channels: 1
- in_stack_depth: 1
- stem_kernel_size: [1, 4, 4]
- stem_stride: [1, 4, 4]
- embedding_dim: 768
- projection_dim: 32
- drop_path_rate: 0.0
- example_input_array_shape: [1, 1, 1, 128, 128]
-data:
- class_path: viscy.data.triplet.TripletDataModule
- init_args:
- data_path: /Users/ziwen.liu/Downloads/Hela_CTC.zarr
- tracks_path: /Users/ziwen.liu/Downloads/Hela_CTC.zarr
- source_channel: DIC
- z_range: [0, 1]
- batch_size: 16
- num_workers: 4
- initial_yx_patch_size: [128, 128]
- final_yx_patch_size: [128, 128]
- time_interval: 1
- normalizations:
- - class_path: viscy.transforms.NormalizeSampled
- init_args:
- keys: [DIC]
- level: fov_statistics
- subtrahend: mean
- divisor: std
-return_predictions: false
-ckpt_path: /Users/ziwen.liu/Projects/test-time/lightning_logs/time_interval_1/checkpoints/last.ckpt
diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py
new file mode 100644
index 000000000..b79c7cbe7
--- /dev/null
+++ b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py
@@ -0,0 +1,227 @@
+# %%
+from pathlib import Path
+import matplotlib.pyplot as plt
+import numpy as np
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.evaluation.distance import (
+ compute_displacement,
+ compute_displacement_statistics,
+)
+
+# Paths to datasets
+feature_paths = {
+ "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr",
+ "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr",
+}
+
+# Colors for different time intervals
+interval_colors = {
+ "7 min interval": "blue",
+ "21 min interval": "red",
+}
+
+# %% Compute MSD for each dataset
+results = {}
+raw_displacements = {}
+
+for label, path in feature_paths.items():
+ print(f"\nProcessing {label}...")
+ embedding_dataset = read_embedding_dataset(Path(path))
+
+ # Compute displacements
+ displacements = compute_displacement(
+ embedding_dataset=embedding_dataset,
+ distance_metric="euclidean_squared",
+ )
+ means, stds = compute_displacement_statistics(displacements)
+ results[label] = (means, stds)
+ raw_displacements[label] = displacements
+
+ # Print some statistics
+ taus = sorted(means.keys())
+ print(f" Number of different τ values: {len(taus)}")
+ print(f" τ range: {min(taus)} to {max(taus)}")
+ print(f" MSD at τ=1: {means[1]:.4f} ± {stds[1]:.4f}")
+
+# %% Plot MSD vs time (linear scale)
+plt.figure(figsize=(10, 6))
+
+# Plot each time interval
+for interval_label, path in feature_paths.items():
+ means, stds = results[interval_label]
+
+ # Sort by tau for plotting
+ taus = sorted(means.keys())
+ mean_values = [means[tau] for tau in taus]
+ std_values = [stds[tau] for tau in taus]
+
+ plt.plot(
+ taus,
+ mean_values,
+ "-",
+ color=interval_colors[interval_label],
+ alpha=0.5,
+ zorder=1,
+ )
+ plt.scatter(
+ taus,
+ mean_values,
+ color=interval_colors[interval_label],
+ s=20,
+ label=interval_label,
+ zorder=2,
+ )
+
+plt.xlabel("Time Shift (τ)")
+plt.ylabel("Mean Square Displacement")
+plt.title("MSD vs Time Shift")
+plt.grid(True, alpha=0.3)
+plt.legend()
+plt.tight_layout()
+plt.show()
+
+# %% Plot MSD vs time (log-log scale with slopes)
+plt.figure(figsize=(10, 6))
+
+# Plot each time interval
+for interval_label, path in feature_paths.items():
+ means, stds = results[interval_label]
+
+ # Sort by tau for plotting
+ taus = sorted(means.keys())
+ mean_values = [means[tau] for tau in taus]
+ std_values = [stds[tau] for tau in taus]
+
+ # Filter out non-positive values for log scale
+ valid_mask = np.array(mean_values) > 0
+ valid_taus = np.array(taus)[valid_mask]
+ valid_means = np.array(mean_values)[valid_mask]
+
+ # Calculate slopes for different regions
+ log_taus = np.log(valid_taus)
+ log_means = np.log(valid_means)
+
+ # Early slope (first third of points)
+ n_points = len(log_taus)
+ early_end = n_points // 3
+ early_slope, early_intercept = np.polyfit(
+ log_taus[:early_end], log_means[:early_end], 1
+ )
+
+ # Late slope (last third of points)
+ late_start = 2 * (n_points // 3)
+ late_slope, late_intercept = np.polyfit(
+ log_taus[late_start:], log_means[late_start:], 1
+ )
+
+ plt.plot(
+ valid_taus,
+ valid_means,
+ "-",
+ color=interval_colors[interval_label],
+ alpha=0.5,
+ zorder=1,
+ )
+ plt.scatter(
+ valid_taus,
+ valid_means,
+ color=interval_colors[interval_label],
+ s=20,
+ label=f"{interval_label} (α_early={early_slope:.2f}, α_late={late_slope:.2f})",
+ zorder=2,
+ )
+
+ # Plot fitted lines for early and late regions
+ early_fit = np.exp(early_intercept + early_slope * log_taus[:early_end])
+ late_fit = np.exp(late_intercept + late_slope * log_taus[late_start:])
+
+ plt.plot(
+ valid_taus[:early_end],
+ early_fit,
+ "--",
+ color=interval_colors[interval_label],
+ alpha=0.3,
+ zorder=1,
+ )
+ plt.plot(
+ valid_taus[late_start:],
+ late_fit,
+ "--",
+ color=interval_colors[interval_label],
+ alpha=0.3,
+ zorder=1,
+ )
+
+plt.xscale("log")
+plt.yscale("log")
+plt.xlabel("Time Shift (τ)")
+plt.ylabel("Mean Square Displacement")
+plt.title("MSD vs Time Shift (log-log)")
+plt.grid(True, alpha=0.3, which="both")
+plt.legend(
+ title="α = slope in log-log space", bbox_to_anchor=(1.05, 1), loc="upper left"
+)
+plt.tight_layout()
+plt.show()
+
+# %% Plot slopes analysis
+early_slopes = []
+late_slopes = []
+intervals = []
+
+for interval_label in feature_paths.keys():
+ means, _ = results[interval_label]
+
+ # Calculate slopes
+ taus = np.array(sorted(means.keys()))
+ mean_values = np.array([means[tau] for tau in taus])
+ valid_mask = mean_values > 0
+
+ if np.sum(valid_mask) > 3: # Need at least 4 points to calculate both slopes
+ log_taus = np.log(taus[valid_mask])
+ log_means = np.log(mean_values[valid_mask])
+
+ # Calculate early and late slopes
+ n_points = len(log_taus)
+ early_end = n_points // 3
+ late_start = 2 * (n_points // 3)
+
+ early_slope, _ = np.polyfit(log_taus[:early_end], log_means[:early_end], 1)
+ late_slope, _ = np.polyfit(log_taus[late_start:], log_means[late_start:], 1)
+
+ early_slopes.append(early_slope)
+ late_slopes.append(late_slope)
+ intervals.append(interval_label)
+
+# Create bar plot
+plt.figure(figsize=(12, 6))
+
+x = np.arange(len(intervals))
+width = 0.35
+
+plt.bar(x - width / 2, early_slopes, width, label="Early slope", alpha=0.7)
+plt.bar(x + width / 2, late_slopes, width, label="Late slope", alpha=0.7)
+
+# Add reference lines
+plt.axhline(y=1, color="k", linestyle="--", alpha=0.3, label="Normal diffusion (α=1)")
+plt.axhline(y=0, color="k", linestyle="-", alpha=0.2)
+
+plt.xlabel("Time Interval")
+plt.ylabel("Slope (α)")
+plt.title("MSD Slopes by Time Interval")
+plt.xticks(x, intervals, rotation=45)
+plt.legend()
+
+# Add annotations for diffusion regimes
+plt.text(
+ plt.xlim()[1] * 1.2, 1.5, "Super-diffusion", rotation=90, verticalalignment="center"
+)
+plt.text(
+ plt.xlim()[1] * 1.2, 0.5, "Sub-diffusion", rotation=90, verticalalignment="center"
+)
+
+plt.grid(True, alpha=0.3)
+plt.tight_layout()
+plt.show()
+
+# %%
diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py b/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py
deleted file mode 100644
index 595f283f7..000000000
--- a/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py
+++ /dev/null
@@ -1,312 +0,0 @@
-# %%
-from pathlib import Path
-import matplotlib.pyplot as plt
-import numpy as np
-from viscy.representation.embedding_writer import read_embedding_dataset
-from viscy.representation.evaluation.distance import (
- calculate_normalized_euclidean_distance_cell,
- compute_displacement,
- compute_dynamic_range,
- compute_rms_per_track,
-)
-from collections import defaultdict
-from tabulate import tabulate
-
-import numpy as np
-from sklearn.metrics.pairwise import cosine_similarity
-from collections import OrderedDict
-
-# %% function
-
-# Removed redundant compute_displacement_mean_std_full function
-# Removed redundant compute_dynamic_range and compute_rms_per_track functions
-
-
-def plot_rms_histogram(rms_values, label, bins=30):
- """
- Plot histogram of RMS values across tracks.
-
- Parameters:
- rms_values : list
- List of RMS values, one for each track.
- label : str
- Label for the dataset (used in the title).
- bins : int, optional
- Number of bins for the histogram. Default is 30.
-
- Returns:
- None: Displays the histogram.
- """
- plt.figure(figsize=(10, 6))
- plt.hist(rms_values, bins=bins, alpha=0.7, color="blue", edgecolor="black")
- plt.title(f"Histogram of RMS Values Across Tracks ({label})", fontsize=16)
- plt.xlabel("RMS of Time Derivative", fontsize=14)
- plt.ylabel("Frequency", fontsize=14)
- plt.grid(True)
- plt.show()
-
-
-def plot_displacement(
- mean_displacement, std_displacement, label, metrics_no_track=None
-):
- """
- Plot embedding displacement over time with mean and standard deviation.
-
- Parameters:
- mean_displacement : dict
- Mean displacement for each tau.
- std_displacement : dict
- Standard deviation of displacement for each tau.
- label : str
- Label for the dataset.
- metrics_no_track : dict, optional
- Metrics for the "Classical Contrastive (No Tracking)" dataset to compare against.
-
- Returns:
- None: Displays the plot.
- """
- plt.figure(figsize=(10, 6))
- taus = list(mean_displacement.keys())
- mean_values = list(mean_displacement.values())
- std_values = list(std_displacement.values())
-
- plt.plot(taus, mean_values, marker="o", label=f"{label}", color="green")
- plt.fill_between(
- taus,
- np.array(mean_values) - np.array(std_values),
- np.array(mean_values) + np.array(std_values),
- color="green",
- alpha=0.3,
- label=f"Std Dev ({label})",
- )
-
- if metrics_no_track:
- mean_values_no_track = list(metrics_no_track["mean_displacement"].values())
- std_values_no_track = list(metrics_no_track["std_displacement"].values())
-
- plt.plot(
- taus,
- mean_values_no_track,
- marker="o",
- label="Classical Contrastive (No Tracking)",
- color="blue",
- )
- plt.fill_between(
- taus,
- np.array(mean_values_no_track) - np.array(std_values_no_track),
- np.array(mean_values_no_track) + np.array(std_values_no_track),
- color="blue",
- alpha=0.3,
- label="Std Dev (No Tracking)",
- )
-
- plt.xlabel("Time Shift (τ)", fontsize=14)
- plt.ylabel("Euclidean Distance", fontsize=14)
- plt.title(f"Embedding Displacement Over Time ({label})", fontsize=16)
- plt.grid(True)
- plt.legend(fontsize=12)
- plt.show()
-
-
-def plot_overlay_displacement(overlay_displacement_data):
- """
- Plot embedding displacement over time for all datasets in one plot.
-
- Parameters:
- overlay_displacement_data : dict
- A dictionary containing mean displacement per tau for all datasets.
-
- Returns:
- None: Displays the plot.
- """
- plt.figure(figsize=(12, 8))
- for label, mean_displacement in overlay_displacement_data.items():
- taus = list(mean_displacement.keys())
- mean_values = list(mean_displacement.values())
- plt.plot(taus, mean_values, marker="o", label=label)
-
- plt.xlabel("Time Shift (τ)", fontsize=14)
- plt.ylabel("Euclidean Distance", fontsize=14)
- plt.title("Overlayed Embedding Displacement Over Time", fontsize=16)
- plt.grid(True)
- plt.legend(fontsize=12)
- plt.show()
-
-
-# %% hist stats
-def plot_boxplot_rms_across_models(datasets_rms):
- """
- Plot a boxplot for the distribution of RMS values across models.
-
- Parameters:
- datasets_rms : dict
- A dictionary where keys are dataset names and values are lists of RMS values.
-
- Returns:
- None: Displays the boxplot.
- """
- plt.figure(figsize=(12, 6))
- labels = list(datasets_rms.keys())
- data = list(datasets_rms.values())
- print(labels)
- print(data)
- # Plot the boxplot
- plt.boxplot(data, tick_labels=labels, patch_artist=True, showmeans=True)
-
- plt.title(
- "Distribution of RMS of Rate of Change of Embedding Across Models", fontsize=16
- )
- plt.ylabel("RMS of Time Derivative", fontsize=14)
- plt.xticks(rotation=45, fontsize=12)
- plt.grid(axis="y", linestyle="--", alpha=0.7)
- plt.tight_layout()
- plt.show()
-
-
-def plot_histogram_absolute_differences(datasets_abs_diff):
- """
- Plot histograms of absolute differences across embeddings for all models.
-
- Parameters:
- datasets_abs_diff : dict
- A dictionary where keys are dataset names and values are lists of absolute differences.
-
- Returns:
- None: Displays the histograms.
- """
- plt.figure(figsize=(12, 6))
- for label, abs_diff in datasets_abs_diff.items():
- plt.hist(abs_diff, bins=50, alpha=0.5, label=label, density=True)
-
- plt.title("Histograms of Absolute Differences Across Models", fontsize=16)
- plt.xlabel("Absolute Difference", fontsize=14)
- plt.ylabel("Density", fontsize=14)
- plt.legend(fontsize=12)
- plt.grid(alpha=0.7)
- plt.tight_layout()
- plt.show()
-
-
-# %% Paths to datasets
-feature_paths = {
- "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr",
- "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr",
- "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_updated_28mins.zarr",
- "56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_56mins.zarr",
- "Cell Aware": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_cellaware.zarr",
-}
-
-no_track_path = "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_classical.zarr"
-
-# %% Process Datasets
-max_tau = 69
-metrics = {}
-
-overlay_displacement_data = {}
-datasets_rms = {}
-datasets_abs_diff = {}
-
-# Process "No Tracking" dataset
-features_path_no_track = Path(no_track_path)
-embedding_dataset_no_track = read_embedding_dataset(features_path_no_track)
-
-mean_displacement_no_track, std_displacement_no_track = compute_displacement(
- embedding_dataset_no_track, max_tau=max_tau, return_mean_std=True
-)
-dynamic_range_no_track = compute_dynamic_range(mean_displacement_no_track)
-metrics["No Tracking"] = {
- "dynamic_range": dynamic_range_no_track,
- "mean_displacement": mean_displacement_no_track,
- "std_displacement": std_displacement_no_track,
-}
-
-overlay_displacement_data["No Tracking"] = mean_displacement_no_track
-
-print("\nProcessing No Tracking dataset...")
-print(f"Dynamic Range for No Tracking: {dynamic_range_no_track}")
-
-plot_displacement(mean_displacement_no_track, std_displacement_no_track, "No Tracking")
-
-rms_values_no_track = compute_rms_per_track(embedding_dataset_no_track)
-datasets_rms["No Tracking"] = rms_values_no_track
-
-print(f"Plotting histogram of RMS values for No Tracking dataset...")
-plot_rms_histogram(rms_values_no_track, "No Tracking", bins=30)
-
-# Compute absolute differences for "No Tracking"
-abs_diff_no_track = np.concatenate(
- [
- np.linalg.norm(
- np.diff(embedding_dataset_no_track["features"].values[indices], axis=0),
- axis=-1,
- )
- for indices in np.split(
- np.arange(len(embedding_dataset_no_track["track_id"])),
- np.where(np.diff(embedding_dataset_no_track["track_id"]) != 0)[0] + 1,
- )
- ]
-)
-datasets_abs_diff["No Tracking"] = abs_diff_no_track
-
-# Process other datasets
-for label, path in feature_paths.items():
- print(f"\nProcessing {label} dataset...")
-
- features_path = Path(path)
- embedding_dataset = read_embedding_dataset(features_path)
-
- mean_displacement, std_displacement = compute_displacement(
- embedding_dataset, max_tau=max_tau, return_mean_std=True
- )
- dynamic_range = compute_dynamic_range(mean_displacement)
- metrics[label] = {
- "dynamic_range": dynamic_range,
- "mean_displacement": mean_displacement,
- "std_displacement": std_displacement,
- }
-
- overlay_displacement_data[label] = mean_displacement
-
- print(f"Dynamic Range for {label}: {dynamic_range}")
-
- plot_displacement(
- mean_displacement,
- std_displacement,
- label,
- metrics_no_track=metrics.get("No Tracking", None),
- )
-
- rms_values = compute_rms_per_track(embedding_dataset)
- datasets_rms[label] = rms_values
-
- print(f"Plotting histogram of RMS values for {label}...")
- plot_rms_histogram(rms_values, label, bins=30)
-
- abs_diff = np.concatenate(
- [
- np.linalg.norm(
- np.diff(embedding_dataset["features"].values[indices], axis=0), axis=-1
- )
- for indices in np.split(
- np.arange(len(embedding_dataset["track_id"])),
- np.where(np.diff(embedding_dataset["track_id"]) != 0)[0] + 1,
- )
- ]
- )
- datasets_abs_diff[label] = abs_diff
-
-print("\nPlotting overlayed displacement for all datasets...")
-plot_overlay_displacement(overlay_displacement_data)
-
-print("\nSummary of Dynamic Ranges:")
-for label, metric in metrics.items():
- print(f"{label}: Dynamic Range = {metric['dynamic_range']}")
-
-print("\nPlotting RMS boxplot across models...")
-plot_boxplot_rms_across_models(datasets_rms)
-
-print("\nPlotting histograms of absolute differences across models...")
-plot_histogram_absolute_differences(datasets_abs_diff)
-
-
-# %%
diff --git a/applications/contrastive_phenotyping/evaluation/PC_vs_CF.py b/applications/contrastive_phenotyping/evaluation/PC_vs_CF.py
deleted file mode 100644
index 66a459ddc..000000000
--- a/applications/contrastive_phenotyping/evaluation/PC_vs_CF.py
+++ /dev/null
@@ -1,449 +0,0 @@
-""" Script to compute the correlation between PCA and UMAP features and computed features
-* finds the computed features best representing the PCA and UMAP components
-* outputs a heatmap of the correlation between PCA and UMAP features and computed features
-"""
-
-# %%
-import os
-import sys
-from pathlib import Path
-
-sys.path.append("/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy")
-
-import matplotlib.pyplot as plt
-import numpy as np
-import seaborn as sns
-from sklearn.decomposition import PCA
-
-from viscy.representation.embedding_writer import read_embedding_dataset
-from viscy.representation.evaluation import dataset_of_tracks
-from viscy.representation.evaluation.feature import (
- FeatureExtractor as FE,
-)
-
-# %%
-features_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr"
-)
-data_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr"
-)
-tracks_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr"
-)
-
-# %%
-
-source_channel = ["Phase3D", "RFP"]
-z_range = (28, 43)
-normalizations = None
-# fov_name = "/B/4/5"
-# track_id = 11
-
-embedding_dataset = read_embedding_dataset(features_path)
-embedding_dataset
-
-# load all unprojected features:
-features = embedding_dataset["features"]
-
-# %% PCA analysis of the features
-
-pca = PCA(n_components=5)
-pca_features = pca.fit_transform(features.values)
-features = (
- features.assign_coords(PCA1=("sample", pca_features[:, 0]))
- .assign_coords(PCA2=("sample", pca_features[:, 1]))
- .assign_coords(PCA3=("sample", pca_features[:, 2]))
- .assign_coords(PCA4=("sample", pca_features[:, 3]))
- .assign_coords(PCA5=("sample", pca_features[:, 4]))
- .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4", "PCA5"], append=True)
-)
-
-# %% convert the xarray to dataframe structure and add columns for computed features
-features_df = features.to_dataframe()
-features_df = features_df.drop(columns=["features"])
-df = features_df.drop_duplicates()
-features = df.reset_index(drop=True)
-
-features = features[features["fov_name"].str.startswith("/B/")]
-
-features["Phase Symmetry Score"] = np.nan
-features["Fluor Symmetry Score"] = np.nan
-features["Sensor Area"] = np.nan
-features["Masked Sensor Intensity"] = np.nan
-features["Entropy Phase"] = np.nan
-features["Entropy Fluor"] = np.nan
-features["Contrast Phase"] = np.nan
-features["Dissimilarity Phase"] = np.nan
-features["Homogeneity Phase"] = np.nan
-features["Contrast Fluor"] = np.nan
-features["Dissimilarity Fluor"] = np.nan
-features["Homogeneity Fluor"] = np.nan
-features["Phase IQR"] = np.nan
-features["Fluor Mean Intensity"] = np.nan
-features["Phase Standard Deviation"] = np.nan
-features["Fluor Standard Deviation"] = np.nan
-features["Phase radial profile"] = np.nan
-features["Fluor radial profile"] = np.nan
-
-# %% compute the computed features and add them to the dataset
-
-fov_names_list = features["fov_name"].unique()
-unique_fov_names = sorted(list(set(fov_names_list)))
-
-
-for fov_name in unique_fov_names:
-
- unique_track_ids = features[features["fov_name"] == fov_name]["track_id"].unique()
- unique_track_ids = list(set(unique_track_ids))
-
- for track_id in unique_track_ids:
-
- prediction_dataset = dataset_of_tracks(
- data_path,
- tracks_path,
- [fov_name],
- [track_id],
- source_channel=source_channel,
- )
-
- whole = np.stack([p["anchor"] for p in prediction_dataset])
- phase = whole[:, 0, 3]
- fluor = np.max(whole[:, 1], axis=1)
-
- for t in range(phase.shape[0]):
- # Compute Fourier descriptors for phase image
- phase_descriptors = FE.compute_fourier_descriptors(phase[t])
- # Analyze symmetry of phase image
- phase_symmetry_score = FE.analyze_symmetry(phase_descriptors)
-
- # Compute Fourier descriptors for fluor image
- fluor_descriptors = FE.compute_fourier_descriptors(fluor[t])
- # Analyze symmetry of fluor image
- fluor_symmetry_score = FE.analyze_symmetry(fluor_descriptors)
-
- # Compute area of sensor
- masked_intensity, area = FE.compute_area(fluor[t])
-
- # Compute higher frequency features using spectral entropy
- entropy_phase = FE.compute_spectral_entropy(phase[t])
- entropy_fluor = FE.compute_spectral_entropy(fluor[t])
-
- # Compute texture analysis using GLCM
- contrast_phase, dissimilarity_phase, homogeneity_phase = (
- FE.compute_glcm_features(phase[t])
- )
- contrast_fluor, dissimilarity_fluor, homogeneity_fluor = (
- FE.compute_glcm_features(fluor[t])
- )
-
- # Compute interqualtile range of pixel intensities
- iqr = FE.compute_iqr(phase[t])
-
- # Compute mean pixel intensity
- fluor_mean_intensity = FE.compute_mean_intensity(fluor[t])
-
- # Compute standard deviation of pixel intensities
- phase_std_dev = FE.compute_std_dev(phase[t])
- fluor_std_dev = FE.compute_std_dev(fluor[t])
-
- # Compute radial intensity gradient
- phase_radial_profile = FE.compute_radial_intensity_gradient(phase[t])
- fluor_radial_profile = FE.compute_radial_intensity_gradient(fluor[t])
-
- # update the features dataframe with the computed features
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Fluor Symmetry Score",
- ] = fluor_symmetry_score
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Phase Symmetry Score",
- ] = phase_symmetry_score
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Sensor Area",
- ] = area
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Masked Sensor Intensity",
- ] = masked_intensity
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Entropy Phase",
- ] = entropy_phase
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Entropy Fluor",
- ] = entropy_fluor
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Contrast Phase",
- ] = contrast_phase
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Dissimilarity Phase",
- ] = dissimilarity_phase
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Homogeneity Phase",
- ] = homogeneity_phase
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Contrast Fluor",
- ] = contrast_fluor
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Dissimilarity Fluor",
- ] = dissimilarity_fluor
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Homogeneity Fluor",
- ] = homogeneity_fluor
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Phase IQR",
- ] = iqr
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Fluor Mean Intensity",
- ] = fluor_mean_intensity
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Phase Standard Deviation",
- ] = phase_std_dev
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Fluor Standard Deviation",
- ] = fluor_std_dev
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Phase radial profile",
- ] = phase_radial_profile
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Fluor radial profile",
- ] = fluor_radial_profile
-
-# %%
-
-# Save the features dataframe to a CSV file
-features.to_csv(
- "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_twoChan.csv",
- index=False,
-)
-
-# # read the features dataframe from the CSV file
-# features = pd.read_csv(
-# "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_twoChan.csv"
-# )
-
-# remove the rows with missing values
-features = features.dropna()
-
-# sub_features = features[features["Time"] == 20]
-feature_df_removed = features.drop(
- columns=["fov_name", "track_id", "t", "id", "parent_track_id", "parent_id"]
-)
-
-# Compute correlation between PCA features and computed features
-correlation = feature_df_removed.corr(method="spearman")
-
-# %% display PCA correlation as a heatmap
-
-plt.figure(figsize=(20, 5))
-sns.heatmap(
- correlation.drop(columns=["PCA1", "PCA2", "PCA3", "PCA4", "PCA5"]).loc[
- "PCA1":"PCA5", :
- ],
- annot=True,
- cmap="coolwarm",
- fmt=".2f",
-)
-plt.title("Correlation between PCA features and computed features")
-plt.xlabel("Computed Features")
-plt.ylabel("PCA Features")
-plt.savefig(
- "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/PC_vs_CF_2chan_pca.svg"
-)
-
-
-# %% plot PCA vs set of computed features
-
-set_features = [
- "Fluor radial profile",
- "Homogeneity Phase",
- "Phase IQR",
- "Phase Standard Deviation",
- "Sensor Area",
- "Homogeneity Fluor",
- "Contrast Fluor",
- "Phase radial profile",
-]
-
-plt.figure(figsize=(8, 10))
-sns.heatmap(
- correlation.loc[set_features, "PCA1":"PCA5"],
- annot=True,
- cmap="coolwarm",
- fmt=".2f",
- vmin=-1,
- vmax=1,
-)
-
-plt.savefig(
- "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/PC_vs_CF_2chan_pca_setfeatures.svg"
-)
-
-# %% find the cell patches with the highest and lowest value in each feature
-
-def save_patches(fov_name, track_id):
- data_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr"
- )
- tracks_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr"
- )
- source_channel = ["Phase3D", "RFP"]
- prediction_dataset = dataset_of_tracks(
- data_path,
- tracks_path,
- [fov_name],
- [track_id],
- source_channel=source_channel,
- )
- whole = np.stack([p["anchor"] for p in prediction_dataset])
- phase = whole[:, 0]
- fluor = whole[:, 1]
- out_dir = "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/data/computed_features/"
- fov_name_out = fov_name.replace("/", "_")
- np.save(
- (os.path.join(out_dir, "phase" + fov_name_out + "_" + str(track_id) + ".npy")),
- phase,
- )
- np.save(
- (os.path.join(out_dir, "fluor" + fov_name_out + "_" + str(track_id) + ".npy")),
- fluor,
- )
-
-
-# PCA1: Fluor radial profile
-highest_fluor_radial_profile = features.loc[features["Fluor radial profile"].idxmax()]
-print("Row with highest 'Fluor radial profile':")
-# print(highest_fluor_radial_profile)
-print(
- f"fov_name: {highest_fluor_radial_profile['fov_name']}, time: {highest_fluor_radial_profile['t']}"
-)
-save_patches(
- highest_fluor_radial_profile["fov_name"], highest_fluor_radial_profile["track_id"]
-)
-
-lowest_fluor_radial_profile = features.loc[features["Fluor radial profile"].idxmin()]
-print("Row with lowest 'Fluor radial profile':")
-# print(lowest_fluor_radial_profile)
-print(
- f"fov_name: {lowest_fluor_radial_profile['fov_name']}, time: {lowest_fluor_radial_profile['t']}"
-)
-save_patches(
- lowest_fluor_radial_profile["fov_name"], lowest_fluor_radial_profile["track_id"]
-)
-
-# PCA2: Entropy phase
-highest_entropy_phase = features.loc[features["Entropy Phase"].idxmax()]
-print("Row with highest 'Entropy Phase':")
-# print(highest_entropy_phase)
-print(
- f"fov_name: {highest_entropy_phase['fov_name']}, time: {highest_entropy_phase['t']}"
-)
-save_patches(highest_entropy_phase["fov_name"], highest_entropy_phase["track_id"])
-
-lowest_entropy_phase = features.loc[features["Entropy Phase"].idxmin()]
-print("Row with lowest 'Entropy Phase':")
-# print(lowest_entropy_phase)
-print(
- f"fov_name: {lowest_entropy_phase['fov_name']}, time: {lowest_entropy_phase['t']}"
-)
-save_patches(lowest_entropy_phase["fov_name"], lowest_entropy_phase["track_id"])
-
-# PCA3: Phase IQR
-highest_phase_iqr = features.loc[features["Phase IQR"].idxmax()]
-print("Row with highest 'Phase IQR':")
-# print(highest_phase_iqr)
-print(f"fov_name: {highest_phase_iqr['fov_name']}, time: {highest_phase_iqr['t']}")
-save_patches(highest_phase_iqr["fov_name"], highest_phase_iqr["track_id"])
-
-tenth_lowest_phase_iqr = features.nsmallest(10, "Phase IQR").iloc[9]
-print("Row with tenth lowest 'Phase IQR':")
-# print(tenth_lowest_phase_iqr)
-print(
- f"fov_name: {tenth_lowest_phase_iqr['fov_name']}, time: {tenth_lowest_phase_iqr['t']}"
-)
-save_patches(tenth_lowest_phase_iqr["fov_name"], tenth_lowest_phase_iqr["track_id"])
-
-# PCA4: Phase Standard Deviation
-highest_phase_std_dev = features.loc[features["Phase Standard Deviation"].idxmax()]
-print("Row with highest 'Phase Standard Deviation':")
-# print(highest_phase_std_dev)
-print(
- f"fov_name: {highest_phase_std_dev['fov_name']}, time: {highest_phase_std_dev['t']}"
-)
-save_patches(highest_phase_std_dev["fov_name"], highest_phase_std_dev["track_id"])
-
-lowest_phase_std_dev = features.loc[features["Phase Standard Deviation"].idxmin()]
-print("Row with lowest 'Phase Standard Deviation':")
-# print(lowest_phase_std_dev)
-print(
- f"fov_name: {lowest_phase_std_dev['fov_name']}, time: {lowest_phase_std_dev['t']}"
-)
-save_patches(lowest_phase_std_dev["fov_name"], lowest_phase_std_dev["track_id"])
-
-# PCA5: Sensor area
-highest_sensor_area = features.loc[features["Sensor Area"].idxmax()]
-print("Row with highest 'Sensor Area':")
-# print(highest_sensor_area)
-print(f"fov_name: {highest_sensor_area['fov_name']}, time: {highest_sensor_area['t']}")
-save_patches(highest_sensor_area["fov_name"], highest_sensor_area["track_id"])
-
-tenth_lowest_sensor_area = features.nsmallest(10, "Sensor Area").iloc[9]
-print("Row with tenth lowest 'Sensor Area':")
-# print(tenth_lowest_sensor_area)
-print(
- f"fov_name: {tenth_lowest_sensor_area['fov_name']}, time: {tenth_lowest_sensor_area['t']}"
-)
-save_patches(tenth_lowest_sensor_area["fov_name"], tenth_lowest_sensor_area["track_id"])
diff --git a/applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py b/applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py
deleted file mode 100644
index 3d5049166..000000000
--- a/applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py
+++ /dev/null
@@ -1,245 +0,0 @@
-""" Script to compute the correlation between PCA and UMAP features and computed features
-* finds the computed features best representing the PCA and UMAP components
-* outputs a heatmap of the correlation between PCA and UMAP features and computed features
-"""
-
-# %%
-import sys
-from pathlib import Path
-
-sys.path.append("/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy")
-
-import numpy as np
-import pandas as pd
-import plotly.express as px
-from scipy.stats import spearmanr
-from sklearn.decomposition import PCA
-
-from viscy.representation.embedding_writer import read_embedding_dataset
-from viscy.representation.evaluation import dataset_of_tracks
-from viscy.representation.evaluation.feature import (
- FeatureExtractor as FE,
-)
-
-# %%
-features_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval_phase/predictions/epoch_186/1chan_128patch_186ckpt_Febtest.zarr"
-)
-data_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr"
-)
-tracks_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/track.zarr"
-)
-
-# %%
-
-source_channel = ["Phase3D"]
-z_range = (28, 43)
-normalizations = None
-# fov_name = "/B/4/5"
-# track_id = 11
-
-embedding_dataset = read_embedding_dataset(features_path)
-embedding_dataset
-
-# load all unprojected features:
-features = embedding_dataset["features"]
-
-# %% PCA analysis of the features
-
-pca = PCA(n_components=3)
-embedding = pca.fit_transform(features.values)
-features = (
- features.assign_coords(PCA1=("sample", embedding[:, 0]))
- .assign_coords(PCA2=("sample", embedding[:, 1]))
- .assign_coords(PCA3=("sample", embedding[:, 2]))
- .set_index(sample=["PCA1", "PCA2", "PCA3"], append=True)
-)
-
-# %% convert the xarray to dataframe structure and add columns for computed features
-features_df = features.to_dataframe()
-features_df = features_df.drop(columns=["features"])
-df = features_df.drop_duplicates()
-features = df.reset_index(drop=True)
-
-features = features[features["fov_name"].str.startswith("/B/")]
-
-features["Phase Symmetry Score"] = np.nan
-features["Entropy Phase"] = np.nan
-features["Contrast Phase"] = np.nan
-features["Dissimilarity Phase"] = np.nan
-features["Homogeneity Phase"] = np.nan
-features["Phase IQR"] = np.nan
-features["Phase Standard Deviation"] = np.nan
-features["Phase radial profile"] = np.nan
-
-# %% compute the computed features and add them to the dataset
-
-fov_names_list = features["fov_name"].unique()
-unique_fov_names = sorted(list(set(fov_names_list)))
-
-for fov_name in unique_fov_names:
-
- unique_track_ids = features[features["fov_name"] == fov_name]["track_id"].unique()
- unique_track_ids = list(set(unique_track_ids))
-
- for track_id in unique_track_ids:
-
- # load the image patches
-
- prediction_dataset = dataset_of_tracks(
- data_path,
- tracks_path,
- [fov_name],
- [track_id],
- source_channel=source_channel,
- )
-
- whole = np.stack([p["anchor"] for p in prediction_dataset])
- phase = whole[:, 0, 3]
-
- for t in range(phase.shape[0]):
- # Compute Fourier descriptors for phase image
- phase_descriptors = FE.compute_fourier_descriptors(phase[t])
- # Analyze symmetry of phase image
- phase_symmetry_score = FE.analyze_symmetry(phase_descriptors)
-
- # Compute higher frequency features using spectral entropy
- entropy_phase = FE.compute_spectral_entropy(phase[t])
-
- # Compute texture analysis using GLCM
- contrast_phase, dissimilarity_phase, homogeneity_phase = (
- FE.compute_glcm_features(phase[t])
- )
-
- # Compute interqualtile range of pixel intensities
- iqr = FE.compute_iqr(phase[t])
-
- # Compute standard deviation of pixel intensities
- phase_std_dev = FE.compute_std_dev(phase[t])
-
- # Compute radial intensity gradient
- phase_radial_profile = FE.compute_radial_intensity_gradient(phase[t])
-
- # update the features dataframe with the computed features
-
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Phase Symmetry Score",
- ] = phase_symmetry_score
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Entropy Phase",
- ] = entropy_phase
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Contrast Phase",
- ] = contrast_phase
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Dissimilarity Phase",
- ] = dissimilarity_phase
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Homogeneity Phase",
- ] = homogeneity_phase
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Phase IQR",
- ] = iqr
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Phase Standard Deviation",
- ] = phase_std_dev
- features.loc[
- (features["fov_name"] == fov_name)
- & (features["track_id"] == track_id)
- & (features["t"] == t),
- "Phase radial profile",
- ] = phase_radial_profile
-
-# %%
-# Save the features dataframe to a CSV file
-features.to_csv(
- "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_oneChan.csv",
- index=False,
-)
-
-# read the csv file
-# features = pd.read_csv(
-# "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_oneChan.csv"
-# )
-
-# remove the rows with missing values
-features = features.dropna()
-
-# sub_features = features[features["Time"] == 20]
-feature_df_removed = features.drop(
- columns=["fov_name", "track_id", "t", "id", "parent_track_id", "parent_id"]
-)
-
-# Compute correlation between PCA features and computed features
-correlation = feature_df_removed.corr(method="spearman")
-
-# %% calculate the p-value and draw volcano plot to show the significance of the correlation
-
-p_values = pd.DataFrame(index=correlation.index, columns=correlation.columns)
-
-for i in correlation.index:
- for j in correlation.columns:
- if i != j:
- p_values.loc[i, j] = spearmanr(
- feature_df_removed[i], feature_df_removed[j]
- )[1]
-
-p_values = p_values.astype(float)
-
-# %% draw an interactive volcano plot showing -log10(p-value) vs fold change
-
-# Flatten the correlation and p-values matrices and create a DataFrame
-correlation_flat = correlation.values.flatten()
-p_values_flat = p_values.values.flatten()
-# Create a list of feature names for the flattened correlation and p-values
-feature_names = [f"{i}_{j}" for i in correlation.index for j in correlation.columns]
-
-data = pd.DataFrame(
- {
- "Correlation": correlation_flat,
- "-log10(p-value)": -np.log10(p_values_flat),
- "feature_names": feature_names,
- }
-)
-
-# Create an interactive scatter plot using Plotly
-fig = px.scatter(
- data,
- x="Correlation",
- y="-log10(p-value)",
- title="Volcano plot showing significance of correlation",
- labels={"Correlation": "Correlation", "-log10(p-value)": "-log10(p-value)"},
- opacity=0.5,
- hover_data=["feature_names"],
-)
-
-fig.show()
-# Save the interactive volcano plot as an HTML file
-fig.write_html(
- "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/volcano_plot_1chan.html"
-)
-
-# %%
diff --git a/applications/contrastive_phenotyping/evaluation/PC_vs_computed_features.py b/applications/contrastive_phenotyping/evaluation/PC_vs_computed_features.py
new file mode 100644
index 000000000..3d300ea7b
--- /dev/null
+++ b/applications/contrastive_phenotyping/evaluation/PC_vs_computed_features.py
@@ -0,0 +1,160 @@
+""" Script to compute the correlation between PCA and UMAP features and computed features
+* finds the computed features best representing the PCA and UMAP components
+* outputs a heatmap of the correlation between PCA and UMAP features and computed features
+"""
+
+# %%
+from pathlib import Path
+import matplotlib.pyplot as plt
+import seaborn as sns
+from compute_pca_features import compute_features, compute_correlation_and_save_png
+
+# %% for sensor features
+
+features_path = Path(
+ "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_94ckpt_rev6_2.zarr"
+)
+data_path = Path(
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr"
+)
+tracks_path = Path(
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr"
+)
+
+source_channel = ["Phase3D", "RFP"]
+seg_channel = ["Nuclei_prediction_labels"]
+z_range = (28, 43)
+fov_list = ["/A/3", "/B/3", "/B/4"]
+
+features_sensor = compute_features(
+ features_path,
+ data_path,
+ tracks_path,
+ source_channel,
+ seg_channel,
+ z_range,
+ fov_list,
+)
+
+features_sensor.to_csv(
+ "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_allset_sensor.csv",
+ index=False,
+)
+
+# features_sensor = pd.read_csv("/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_allset_sensor.csv")
+
+# take a subset without the 768 features
+feature_columns = [f"feature_{i+1}" for i in range(768)]
+features_subset_sensor = features_sensor.drop(columns=feature_columns)
+correlation_sensor = compute_correlation_and_save_png(
+ features_subset_sensor,
+ "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/PC_vs_CF_2chan_pca_sensor_allset.svg",
+)
+
+# %% plot PCA vs set of computed features for sensor features
+
+set_features = [
+ "Fluor Radial Intensity Gradient",
+ "Phase Interquartile Range",
+ "Perimeter area ratio",
+ "Fluor Interquartile Range",
+ "Phase Entropy",
+ "Fluor Zernike Moment Mean",
+]
+
+plt.figure(figsize=(10, 8))
+sns.heatmap(
+ correlation_sensor.loc[set_features, "PCA1":"PCA6"],
+ annot=True,
+ cmap="coolwarm",
+ fmt=".2f",
+ annot_kws={"size": 24},
+ vmin=-1,
+ vmax=1,
+)
+plt.xlabel("Computed Features", fontsize=24)
+plt.ylabel("PCA Features", fontsize=24)
+plt.xticks(fontsize=24) # Increase x-axis tick labels
+plt.yticks(fontsize=24) # Increase y-axis tick labels
+
+plt.savefig(
+ "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/PC_vs_CF_2chan_pca_allset_sensor_6features.svg"
+)
+
+# plot the PCA1 vs PCA2 map for sensor features
+
+plt.figure(figsize=(10, 10))
+sns.scatterplot(
+ x="PCA1",
+ y="PCA2",
+ data=features_sensor,
+)
+
+
+# .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-.
+# / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \
+# '-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'
+
+
+# %% for organelle features
+
+features_path = Path(
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/predictions/Soorya/timeAware_2chan_ntxent_192patch_91ckpt_rev7_GT.zarr"
+)
+data_path = Path(
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_ZIKV_DENV.zarr"
+)
+tracks_path = Path(
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr"
+)
+
+source_channel = ["Phase3D", "raw GFP EX488 EM525-45"]
+seg_channel = ["nuclei_prediction_labels_labels"]
+z_range = (16, 21)
+normalizations = None
+fov_list = ["/B/2/000000", "/B/3/000000", "/C/2/000000"]
+
+features_organelle = compute_features(
+ features_path,
+ data_path,
+ tracks_path,
+ source_channel,
+ seg_channel,
+ z_range,
+ fov_list,
+)
+
+# Save the features dataframe to a CSV file
+features_organelle.to_csv(
+ "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_twoChan_organelle_multiwell.csv",
+ index=False,
+)
+
+correlation_organelle = compute_correlation_and_save_png(
+ features_organelle,
+ "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/PC_vs_CF_2chan_pca_organelle_multiwell.svg",
+)
+
+# features_organelle = pd.read_csv("/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_twoChan_organelle_multiwell_refinedPCA.csv")
+
+# %% plot PCA vs set of computed features for organelle features
+
+plt.figure(figsize=(10, 8))
+sns.heatmap(
+ correlation_organelle.loc[set_features, "PCA1":"PCA6"],
+ annot=True,
+ cmap="coolwarm",
+ fmt=".2f",
+ annot_kws={"size": 24},
+ vmin=-1,
+ vmax=1,
+)
+plt.xlabel("Computed Features", fontsize=24)
+plt.ylabel("PCA Features", fontsize=24)
+plt.xticks(fontsize=24) # Increase x-axis tick labels
+plt.yticks(fontsize=24) # Increase y-axis tick labels
+plt.savefig(
+ "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/PC_vs_CF_2chan_pca_setfeatures_organelle_6features.svg"
+)
+
+# %%
diff --git a/applications/contrastive_phenotyping/evaluation/compute_pca_features.py b/applications/contrastive_phenotyping/evaluation/compute_pca_features.py
new file mode 100644
index 000000000..a7f09b59f
--- /dev/null
+++ b/applications/contrastive_phenotyping/evaluation/compute_pca_features.py
@@ -0,0 +1,382 @@
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
+
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.evaluation import dataset_of_tracks
+from viscy.representation.evaluation.feature import CellFeatures
+
+
+## Read the embedding dataset and append PCA coordinates to the features
+def compute_PCA(features_path: Path):
+ """Compute PCA components from embedding features and combine with original features.
+
+ This function reads an embedding dataset, standardizes the features, and computes
+ 8 principal components. The PCA components are then combined with the original
+ features in an xarray dataset structure.
+
+ Parameters
+ ----------
+ features_path : Path
+ Path to the embedding dataset containing the feature vectors.
+
+ Returns
+ -------
+ features: xarray dataset with PCA components as new coordinates
+
+ """
+ embedding_dataset = read_embedding_dataset(features_path)
+ embedding_dataset
+
+ # load all unprojected features:
+ features = embedding_dataset["features"]
+ scaled_features = StandardScaler().fit_transform(features.values)
+ # PCA analysis of the features
+
+ pca = PCA(n_components=8)
+ pca_features = pca.fit_transform(scaled_features)
+ features = (
+ features.assign_coords(PCA1=("sample", pca_features[:, 0]))
+ .assign_coords(PCA2=("sample", pca_features[:, 1]))
+ .assign_coords(PCA3=("sample", pca_features[:, 2]))
+ .assign_coords(PCA4=("sample", pca_features[:, 3]))
+ .assign_coords(PCA5=("sample", pca_features[:, 4]))
+ .assign_coords(PCA6=("sample", pca_features[:, 5]))
+ .assign_coords(PCA7=("sample", pca_features[:, 6]))
+ .assign_coords(PCA8=("sample", pca_features[:, 7]))
+ .set_index(
+ sample=["PCA1", "PCA2", "PCA3", "PCA4", "PCA5", "PCA6", "PCA7", "PCA8"],
+ append=True,
+ )
+ )
+
+ return features
+
+
+def compute_features(
+ features_path: Path,
+ data_path: Path,
+ tracks_path: Path,
+ source_channel: list,
+ seg_channel: list,
+ z_range: tuple,
+ fov_list: list,
+):
+ """Compute various cell features and combine them with PCA features.
+
+ This function processes cell tracking data to compute various morphological and
+ intensity-based features for both phase and fluorescence channels, and combines
+ them with PCA features from an embedding dataset.
+
+ Parameters
+ ----------
+ features_path : Path
+ Path to the embedding dataset containing PCA features.
+ data_path : Path
+ Path to the raw data directory containing image data.
+ tracks_path : Path
+        Path to the tracking data (zarr store) used to extract cell patches.
+ source_channel : list
+ List of source channels to process from the data.
+ seg_channel : list
+ List of segmentation channels to process from the data.
+ z_range : tuple
+ Tuple specifying the z-range to process (min_z, max_z).
+ fov_list : list
+ List of field of view names to process.
+
+ Returns
+ -------
+ pandas.DataFrame
+ DataFrame containing all computed features including:
+ - Basic features (mean intensity, std dev, kurtosis, etc.) for both Phase and Fluor channels
+ - Organelle features (area, masked intensity)
+ - Nuclear features (area, perimeter, eccentricity)
+ - PCA components (PCA1-PCA8)
+ - Original tracking information (fov_name, track_id, time points)
+ """
+
+ embedding_dataset = compute_PCA(features_path)
+ features_npy = embedding_dataset["features"].values
+
+ # convert the xarray to dataframe structure and add columns for computed features
+ embedding_df = embedding_dataset["sample"].to_dataframe().reset_index(drop=True)
+ feature_columns = pd.DataFrame(
+ features_npy, columns=[f"feature_{i+1}" for i in range(768)]
+ )
+
+ embedding_df = pd.concat([embedding_df, feature_columns], axis=1)
+ embedding_df = embedding_df.drop(columns=["sample", "UMAP1", "UMAP2"])
+
+ # Filter features based on FOV names that start with any of the items in fov_list
+ embedding_df = embedding_df[
+ embedding_df["fov_name"].apply(
+ lambda x: any(x.startswith(fov) for fov in fov_list)
+ )
+ ]
+
+ # Define feature categories and their corresponding column names
+ feature_columns = {
+ "basic_features": [
+ ("Mean Intensity", ["Phase", "Fluor"]),
+ ("Std Dev", ["Phase", "Fluor"]),
+ ("Kurtosis", ["Phase", "Fluor"]),
+ ("Skewness", ["Phase", "Fluor"]),
+ ("Entropy", ["Phase", "Fluor"]),
+ ("Interquartile Range", ["Phase", "Fluor"]),
+ ("Dissimilarity", ["Phase", "Fluor"]),
+ ("Contrast", ["Phase", "Fluor"]),
+ ("Texture", ["Phase", "Fluor"]),
+ ("Weighted Intensity Gradient", ["Phase", "Fluor"]),
+ ("Radial Intensity Gradient", ["Phase", "Fluor"]),
+ ("Zernike Moment Std", ["Phase", "Fluor"]),
+ ("Zernike Moment Mean", ["Phase", "Fluor"]),
+ ("Intensity Localization", ["Phase", "Fluor"]),
+ ],
+ "organelle_features": [
+ "Fluor Area",
+ "Fluor Masked Intensity",
+ ],
+ "nuclear_features": [
+ "Nuclear area",
+ "Perimeter",
+ "Perimeter area ratio",
+ "Nucleus eccentricity",
+ ],
+ }
+
+ # Initialize all feature columns
+ for category, feature_list in feature_columns.items():
+ if isinstance(feature_list[0], tuple): # Handle features with multiple channels
+ for feature, channels in feature_list:
+ for channel in channels:
+ col_name = f"{channel} {feature}"
+ embedding_df[col_name] = np.nan
+ else: # Handle single features
+ for feature in feature_list:
+ embedding_df[feature] = np.nan
+
+ # compute the computed features and add them to the dataset
+
+ fov_names_list = embedding_df["fov_name"].unique()
+ unique_fov_names = sorted(list(set(fov_names_list)))
+
+ for fov_name in unique_fov_names:
+
+ unique_track_ids = embedding_df[embedding_df["fov_name"] == fov_name][
+ "track_id"
+ ].unique()
+ unique_track_ids = list(set(unique_track_ids))
+
+ # iteration_count = 0
+
+ for track_id in unique_track_ids:
+ if not embedding_df[
+ (embedding_df["fov_name"] == fov_name)
+ & (embedding_df["track_id"] == track_id)
+ ].empty:
+
+ prediction_dataset = dataset_of_tracks(
+ data_path,
+ tracks_path,
+ [fov_name],
+ [track_id],
+ z_range=z_range,
+ source_channel=source_channel,
+ )
+ track_channel = dataset_of_tracks(
+ tracks_path,
+ tracks_path,
+ [fov_name],
+ [track_id],
+ z_range=(0, 1),
+ source_channel=seg_channel,
+ )
+
+ whole = np.stack([p["anchor"] for p in prediction_dataset])
+ seg_mask = np.stack([p["anchor"] for p in track_channel])
+ phase = whole[:, 0, 2]
+ # Normalize phase image to 0-255 range
+ # phase = ((phase - phase.min()) / (phase.max() - phase.min()) * 255).astype(np.uint8)
+                # Max-project the fluorescence channel along z (0-255 normalization disabled)
+ fluor = np.max(whole[:, 1], axis=1)
+ # fluor = ((fluor - fluor.min()) / (fluor.max() - fluor.min()) * 255).astype(np.uint8)
+ nucl_mask = seg_mask[:, 0, 0]
+
+ for i, t in enumerate(
+ embedding_df[
+ (embedding_df["fov_name"] == fov_name)
+ & (embedding_df["track_id"] == track_id)
+ ]["t"]
+ ):
+
+ # Basic statistical features for both channels
+ phase_features = CellFeatures(phase[i], nucl_mask[i])
+ PF = phase_features.compute_all_features()
+
+ # Get all basic statistical measures at once
+ phase_stats = {
+ "Mean Intensity": PF["mean_intensity"],
+ "Std Dev": PF["std_dev"],
+ "Kurtosis": PF["kurtosis"],
+ "Skewness": PF["skewness"],
+ "Interquartile Range": PF["iqr"],
+ "Entropy": PF["spectral_entropy"],
+ "Dissimilarity": PF["dissimilarity"],
+ "Contrast": PF["contrast"],
+ "Texture": PF["texture"],
+ "Zernike Moment Std": PF["zernike_std"],
+ "Zernike Moment Mean": PF["zernike_mean"],
+ "Radial Intensity Gradient": PF["radial_intensity_gradient"],
+ "Weighted Intensity Gradient": PF[
+ "weighted_intensity_gradient"
+ ],
+ "Intensity Localization": PF["intensity_localization"],
+ }
+
+ fluor_cell_features = CellFeatures(fluor[i], nucl_mask[i])
+
+ FF = fluor_cell_features.compute_all_features()
+
+ fluor_stats = {
+ "Mean Intensity": FF["mean_intensity"],
+ "Std Dev": FF["std_dev"],
+ "Kurtosis": FF["kurtosis"],
+ "Skewness": FF["skewness"],
+ "Interquartile Range": FF["iqr"],
+ "Entropy": FF["spectral_entropy"],
+ "Contrast": FF["contrast"],
+ "Dissimilarity": FF["dissimilarity"],
+ "Texture": FF["texture"],
+ "Masked Area": FF["masked_area"],
+ "Masked Intensity": FF["masked_intensity"],
+ "Weighted Intensity Gradient": FF[
+ "weighted_intensity_gradient"
+ ],
+ "Radial Intensity Gradient": FF["radial_intensity_gradient"],
+ "Zernike Moment Std": FF["zernike_std"],
+ "Zernike Moment Mean": FF["zernike_mean"],
+ "Intensity Localization": FF["intensity_localization"],
+ "Area": FF["area"],
+ }
+
+ mask_features = CellFeatures(nucl_mask[i], nucl_mask[i])
+ MF = mask_features.compute_all_features()
+
+ mask_stats = {
+ "perimeter": MF["perimeter"],
+ "area": MF["area"],
+ "eccentricity": MF["eccentricity"],
+ "perimeter_area_ratio": MF["perimeter_area_ratio"],
+ }
+
+ # Create dictionaries for each feature category
+ phase_feature_mapping = {
+ f"Phase {k.replace('_', ' ').title()}": v
+ for k, v in phase_stats.items()
+ }
+
+ fluor_feature_mapping = {
+ f"Fluor {k.replace('_', ' ').title()}": v
+ for k, v in fluor_stats.items()
+ }
+
+ mask_feature_mapping = {
+ "Nuclear area": mask_stats["area"],
+ "Perimeter": mask_stats["perimeter"],
+ "Perimeter area ratio": mask_stats["perimeter_area_ratio"],
+ "Nucleus eccentricity": mask_stats["eccentricity"],
+ }
+
+ # Combine all feature dictionaries
+ feature_values = {
+ **phase_feature_mapping,
+ **fluor_feature_mapping,
+ **mask_feature_mapping,
+ }
+
+ # update the features dataframe
+ for feature_name, value in feature_values.items():
+ embedding_df.loc[
+ (embedding_df["fov_name"] == fov_name)
+ & (embedding_df["track_id"] == track_id)
+ & (embedding_df["t"] == t),
+ feature_name,
+ ] = value[0]
+
+ # iteration_count += 1
+ print(f"Processed {fov_name}+{track_id}")
+
+ return embedding_df
+
+
+## Compute feature correlations and save the heatmap as an image file
+def compute_correlation_and_save_png(features: pd.DataFrame, filename: str):
+ """Compute correlation between PCA features and computed features, and save as heatmap.
+
+ This function calculates the Spearman correlation between PCA components and all
+ computed features, then visualizes the results as a heatmap. The heatmap focuses
+ on the correlation between PCA components (PCA1-PCA8) and all other computed features.
+
+ Parameters
+ ----------
+ features : pandas.DataFrame
+ DataFrame containing all features including:
+ - PCA components (PCA1-PCA8)
+ - Computed features (morphological, intensity-based, etc.)
+ - Tracking metadata (fov_name, track_id, t, etc.)
+ filename : str
+ Path where the correlation heatmap will be saved as a PNG or SVG file.
+
+ Returns
+ -------
+ pandas.DataFrame
+ The correlation matrix between all features.
+ """
+ # remove the rows with missing values
+ features = features.dropna()
+
+ # sub_features = features[features["Time"] == 20]
+ feature_df_removed = features.drop(
+ columns=["fov_name", "track_id", "t", "id", "parent_track_id", "parent_id"]
+ )
+
+ # Compute correlation between PCA features and computed features
+ correlation = feature_df_removed.corr(method="spearman")
+
+ # display PCA correlation as a heatmap
+
+ plt.figure(figsize=(30, 10))
+ sns.heatmap(
+ correlation.drop(
+ columns=["PCA1", "PCA2", "PCA3", "PCA4", "PCA5", "PCA6", "PCA7", "PCA8"]
+ ).loc["PCA1":"PCA8", :],
+ annot=True,
+ cmap="coolwarm",
+ fmt=".2f",
+ annot_kws={"size": 18},
+ cbar=False,
+ )
+ plt.title("Correlation between PCA features and computed features", fontsize=12)
+ plt.xlabel("Computed Features", fontsize=18)
+ plt.ylabel("PCA Features", fontsize=18)
+ plt.xticks(fontsize=18, rotation=45, ha="right") # Rotate labels and align them
+ plt.yticks(fontsize=18)
+
+ # Adjust layout to prevent label cutoff
+ plt.tight_layout()
+
+ plt.savefig(
+ filename,
+ dpi=300,
+ bbox_inches="tight",
+ pad_inches=0.5, # Add padding around the figure
+ )
+ plt.close()
+
+ return correlation
diff --git a/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py b/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py
new file mode 100644
index 000000000..6005408cb
--- /dev/null
+++ b/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py
@@ -0,0 +1,121 @@
+# metrics for the knowledge distillation figure
+
+# %%
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
+from sklearn.metrics import accuracy_score, classification_report, f1_score
+
+# %%
+# Mantis
+test_virus = ["C/2/000000", "C/2/001001"]
+test_mock = ["B/3/000000", "B/3/000001"]
+
+# Mantis
+TRAIN_FOVS = ["C/2/000001", "C/2/001000", "B/3/001000", "B/3/001001"]
+
+VAL_FOVS = test_virus + test_mock
+
+# %%
+prediction_from_scratch = pd.read_csv(
+ "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/bootstrap-labels/test/from-scratch-last-1126.csv"
+)
+prediction_from_scratch["pretraining"] = "ImageNet"
+
+prediction_finetuned = pd.read_csv(
+ "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/bootstrap-labels/test/fine-tune-last-1126.csv"
+)
+pretrained_name = "DynaCLR"
+prediction_finetuned["pretraining"] = pretrained_name
+
+prediction = pd.concat([prediction_from_scratch, prediction_finetuned], axis=0)
+
+# %%
+prediction = prediction[prediction["fov_name"].isin(VAL_FOVS)]
+prediction["prediction_binary"] = prediction["prediction"] > 0.5
+prediction
+
+# %%
+print(
+ classification_report(
+ prediction["label"], prediction["prediction_binary"], digits=3
+ )
+)
+
+# %%
+prediction["HPI"] = prediction["t"] / 6 + 3
+
+bins = [3, 6, 9, 12, 15, 18, 21, 24]
+labels = [f"{start}-{end}" for start, end in zip(bins[:-1], bins[1:])]
+prediction["stage"] = pd.cut(prediction["HPI"], bins=bins, labels=labels, right=True)
+prediction["well"] = prediction["fov_name"].apply(
+ lambda x: "ZIKV" if x in test_virus else "Mock"
+)
+comparison = prediction.melt(
+ id_vars=["fov_name", "id", "HPI", "well", "stage", "pretraining"],
+ value_vars=["label", "prediction_binary"],
+ var_name="source",
+ value_name="value",
+)
+with sns.axes_style("whitegrid"):
+ ax = sns.lineplot(
+ data=comparison[comparison["pretraining"] == pretrained_name],
+ x="HPI",
+ y="value",
+ hue="well",
+ hue_order=["Mock", "ZIKV"],
+ style="source",
+ errorbar=None,
+ color="gray",
+ )
+ ax.set_ylabel("Infection ratio")
+
+# %%
+id_vars = ["stage", "pretraining"]
+
+accuracy_by_t = prediction.groupby(id_vars).apply(
+ lambda x: float(accuracy_score(x["label"], x["prediction_binary"]))
+)
+f1_by_t = prediction.groupby(id_vars).apply(
+ lambda x: float(f1_score(x["label"], x["prediction_binary"]))
+)
+
+metrics_df = pd.DataFrame(
+ data={"accuracy": accuracy_by_t.values, "F1": f1_by_t.values},
+ index=f1_by_t.index,
+).reset_index()
+
+metrics_long = metrics_df.melt(
+ id_vars=id_vars,
+ value_vars=["accuracy"],
+ var_name="metric",
+ value_name="score",
+)
+
+with sns.axes_style("ticks"):
+ plt.style.use("./figure.mplstyle")
+ g = sns.catplot(
+ data=metrics_long,
+ x="stage",
+ y="score",
+ hue="pretraining",
+ kind="point",
+ linewidth=1.5,
+ linestyles="--",
+ )
+ g.set_axis_labels("HPI", "accuracy")
+ sns.move_legend(g, "upper left", bbox_to_anchor=(0.35, 1.1))
+ g.figure.set_size_inches(3.5, 2)
+ plt.show()
+
+
+# %%
+g.figure.savefig(
+ Path.home()
+ / "gdrive/publications/dynaCLR/2025_dynaCLR_paper/fig_manuscript_svg/figure_knowledge_distillation/figure_parts/metrics_points.pdf",
+ dpi=300,
+)
+
+# %%
diff --git a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py b/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py
deleted file mode 100644
index 5f59da3e0..000000000
--- a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py
+++ /dev/null
@@ -1,220 +0,0 @@
-# %%
-from pathlib import Path
-
-import matplotlib.pyplot as plt
-import seaborn as sns
-from sklearn.decomposition import PCA
-from sklearn.preprocessing import StandardScaler
-from umap import UMAP
-
-from viscy.representation.embedding_writer import read_embedding_dataset
-from viscy.representation.evaluation import load_annotation
-
-# %% Paths and parameters.
-
-
-features_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr"
-)
-data_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr"
-)
-tracks_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr"
-)
-
-
-# %%
-embedding_dataset = read_embedding_dataset(features_path)
-embedding_dataset
-
-
-# %%
-# Compute UMAP over all features
-features = embedding_dataset["features"]
-# or select a well:
-# features = features[features["fov_name"].str.contains("B/4")]
-
-
-scaled_features = StandardScaler().fit_transform(features.values)
-umap = UMAP()
-# Fit UMAP on all features
-embedding = umap.fit_transform(scaled_features)
-
-
-# %%
-# Add UMAP coordinates to the dataset and plot w/ time
-
-
-features = (
- features.assign_coords(UMAP1=("sample", embedding[:, 0]))
- .assign_coords(UMAP2=("sample", embedding[:, 1]))
- .set_index(sample=["UMAP1", "UMAP2"], append=True)
-)
-features
-
-
-sns.scatterplot(
- x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8
-)
-
-
-# Add the title to the plot
-plt.title("Cell & Time Aware Sampling (30 min interval)")
-plt.xlim(-10, 20)
-plt.ylim(-10, 20)
-# plt.savefig('umap_cell_time_aware_time.svg', format='svg')
-plt.savefig("updated_cell_time_aware_time.png", format="png")
-# Show the plot
-plt.show()
-
-
-# %%
-
-
-any_features_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr"
-)
-embedding_dataset = read_embedding_dataset(any_features_path)
-embedding_dataset
-
-
-# %%
-# Compute UMAP over all features
-features = embedding_dataset["features"]
-# or select a well:
-# features = features[features["fov_name"].str.contains("B/4")]
-
-
-scaled_features = StandardScaler().fit_transform(features.values)
-umap = UMAP()
-# Fit UMAP on all features
-embedding = umap.fit_transform(scaled_features)
-
-
-# %% Any time sampling plot
-
-
-features = (
- features.assign_coords(UMAP1=("sample", embedding[:, 0]))
- .assign_coords(UMAP2=("sample", embedding[:, 1]))
- .set_index(sample=["UMAP1", "UMAP2"], append=True)
-)
-features
-
-
-sns.scatterplot(
- x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8
-)
-
-
-# Add the title to the plot
-plt.title("Cell Aware Sampling")
-
-plt.xlim(-10, 20)
-plt.ylim(-10, 20)
-
-plt.savefig("1_updated_cell_aware_time.png", format="png")
-# plt.savefig('umap_cell_aware_time.pdf', format='pdf')
-# Show the plot
-plt.show()
-
-
-# %%
-
-
-contrastive_learning_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr"
-)
-embedding_dataset = read_embedding_dataset(contrastive_learning_path)
-embedding_dataset
-
-
-# %%
-# Compute UMAP over all features
-features = embedding_dataset["features"]
-# or select a well:
-# features = features[features["fov_name"].str.contains("B/4")]
-
-
-scaled_features = StandardScaler().fit_transform(features.values)
-umap = UMAP()
-# Fit UMAP on all features
-embedding = umap.fit_transform(scaled_features)
-
-
-# %% Any time sampling plot
-
-
-features = (
- features.assign_coords(UMAP1=("sample", embedding[:, 0]))
- .assign_coords(UMAP2=("sample", embedding[:, 1]))
- .set_index(sample=["UMAP1", "UMAP2"], append=True)
-)
-features
-
-sns.scatterplot(
- x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8
-)
-
-# Add the title to the plot
-plt.title("Classical Contrastive Learning Sampling")
-plt.xlim(-10, 20)
-plt.ylim(-10, 20)
-plt.savefig("updated_classical_time.png", format="png")
-# plt.savefig('classical_time.pdf', format='pdf')
-
-# Show the plot
-plt.show()
-
-
-# %% PCA
-
-
-pca = PCA(n_components=4)
-# scaled_features = StandardScaler().fit_transform(features.values)
-# pca_features = pca.fit_transform(scaled_features)
-pca_features = pca.fit_transform(features.values)
-
-
-features = (
- features.assign_coords(PCA1=("sample", pca_features[:, 0]))
- .assign_coords(PCA2=("sample", pca_features[:, 1]))
- .assign_coords(PCA3=("sample", pca_features[:, 2]))
- .assign_coords(PCA4=("sample", pca_features[:, 3]))
- .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4"], append=True)
-)
-
-
-# %% plot PCA components w/ time
-
-
-plt.figure(figsize=(10, 10))
-sns.scatterplot(
- x=features["PCA1"], y=features["PCA2"], hue=features["t"], s=7, alpha=0.8
-)
-
-
-# %% OVERLAY INFECTION ANNOTATION
-ann_root = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred"
-)
-
-
-infection = load_annotation(
- features,
- ann_root / "extracted_inf_state.csv",
- "infection_state",
- {0.0: "background", 1.0: "uninfected", 2.0: "infected"},
-)
-
-
-# %%
-sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8)
-
-
-# %% plot PCA components with infection hue
-sns.scatterplot(x=features["PCA1"], y=features["PCA2"], hue=infection, s=7, alpha=0.8)
-
-
-# %%
diff --git a/applications/contrastive_phenotyping/evaluation/predict_infection_score_supervised.py b/applications/contrastive_phenotyping/evaluation/predict_infection_score_supervised.py
deleted file mode 100644
index fb64b9f07..000000000
--- a/applications/contrastive_phenotyping/evaluation/predict_infection_score_supervised.py
+++ /dev/null
@@ -1,166 +0,0 @@
-import os
-import warnings
-from argparse import ArgumentParser
-
-import numpy as np
-import pandas as pd
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-
-from viscy.data.triplet import TripletDataModule
-
-warnings.filterwarnings(
- "ignore",
- category=UserWarning,
- message="To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).",
-)
-
-# %% Paths and constants
-save_dir = (
- "/hpc/mydata/alishba.imran/VisCy/applications/contrastive_phenotyping/embeddings4"
-)
-
-# rechunked data
-data_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/2.2-register_annotations/updated_all_annotations.zarr"
-
-# updated tracking data
-tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr"
-
-source_channel = ["background_mask", "uninfected_mask", "infected_mask"]
-z_range = (0, 1)
-batch_size = 1 # match the number of fovs being processed such that no data is left
-# set to 15 for full, 12 for infected, and 8 for uninfected
-
-# non-rechunked data
-data_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr"
-
-# updated tracking data
-tracks_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr"
-
-source_channel_1 = ["Nuclei_prediction_labels"]
-
-
-# %% Define the main function for training
-def main(hparams):
- # Initialize the data module for prediction, re-do embeddings but with size 224 by 224
- data_module = TripletDataModule(
- data_path=data_path,
- tracks_path=tracks_path,
- source_channel=source_channel,
- z_range=z_range,
- initial_yx_patch_size=(224, 224),
- final_yx_patch_size=(224, 224),
- batch_size=batch_size,
- num_workers=hparams.num_workers,
- )
-
- data_module.setup(stage="predict")
-
- print(f"Total prediction dataset size: {len(data_module.predict_dataset)}")
-
- dataloader = DataLoader(
- data_module.predict_dataset,
- batch_size=batch_size,
- num_workers=hparams.num_workers,
- )
-
- # Initialize the second data module for segmentation masks
- seg_data_module = TripletDataModule(
- data_path=data_path_1,
- tracks_path=tracks_path_1,
- source_channel=source_channel_1,
- z_range=z_range,
- initial_yx_patch_size=(224, 224),
- final_yx_patch_size=(224, 224),
- batch_size=batch_size,
- num_workers=hparams.num_workers,
- )
-
- seg_data_module.setup(stage="predict")
-
- seg_dataloader = DataLoader(
- seg_data_module.predict_dataset,
- batch_size=batch_size,
- num_workers=hparams.num_workers,
- )
-
- # Initialize lists to store average values
- background_avg = []
- uninfected_avg = []
- infected_avg = []
-
- for batch, seg_batch in tqdm(
- zip(dataloader, seg_dataloader),
- desc="Processing batches",
- total=len(data_module.predict_dataset),
- ):
- anchor = batch["anchor"]
- seg_anchor = seg_batch["anchor"].int()
-
- # Extract the fov_name and id from the batch
- fov_name = batch["index"]["fov_name"][0]
- cell_id = batch["index"]["id"].item()
-
- fov_dirs = fov_name.split("/")
- # Construct the path to the CSV file
- csv_path = os.path.join(
- tracks_path, *fov_dirs, f"tracks{fov_name.replace('/', '_')}.csv"
- )
-
- # Read the CSV file
- df = pd.read_csv(csv_path)
-
- # Find the row with the specified id and extract the track_id
- track_id = df.loc[df["id"] == cell_id, "track_id"].values[0]
-
- # Create a boolean mask where segmentation values are equal to the track_id
- mask = seg_anchor == track_id
- # mask = (seg_anchor > 0)
-
- # Find the most frequent non-zero value in seg_anchor
- # unique, counts = np.unique(seg_anchor[seg_anchor > 0], return_counts=True)
- # most_frequent_value = unique[np.argmax(counts)]
-
- # # Create a boolean mask where segmentation values are equal to the most frequent value
- # mask = (seg_anchor == most_frequent_value)
-
- # Expand the mask to match the anchor tensor shape
- mask = mask.expand(1, 3, 1, 224, 224)
-
- # Calculate average values for each channel (background, uninfected, infected) using the mask
- background_avg.append(anchor[:, 0, :, :, :][mask[:, 0]].mean().item())
- uninfected_avg.append(anchor[:, 1, :, :, :][mask[:, 1]].mean().item())
- infected_avg.append(anchor[:, 2, :, :, :][mask[:, 2]].mean().item())
-
- # Convert lists to numpy arrays
- background_avg = np.array(background_avg)
- uninfected_avg = np.array(uninfected_avg)
- infected_avg = np.array(infected_avg)
-
- print("Average values per cell for each mask calculated.")
- print("Background average shape:", background_avg.shape)
- print("Uninfected average shape:", uninfected_avg.shape)
- print("Infected average shape:", infected_avg.shape)
-
- # Save the averages as .npy files
- np.save(os.path.join(save_dir, "background_avg.npy"), background_avg)
- np.save(os.path.join(save_dir, "uninfected_avg.npy"), uninfected_avg)
- np.save(os.path.join(save_dir, "infected_avg.npy"), infected_avg)
-
-
-if __name__ == "__main__":
- parser = ArgumentParser()
- parser.add_argument("--backbone", type=str, default="resnet50")
- parser.add_argument("--margin", type=float, default=0.5)
- parser.add_argument("--lr", type=float, default=1e-3)
- parser.add_argument("--schedule", type=str, default="Constant")
- parser.add_argument("--log_steps_per_epoch", type=int, default=10)
- parser.add_argument("--embedding_len", type=int, default=256)
- parser.add_argument("--max_epochs", type=int, default=100)
- parser.add_argument("--accelerator", type=str, default="gpu")
- parser.add_argument("--devices", type=int, default=1)
- parser.add_argument("--num_nodes", type=int, default=1)
- parser.add_argument("--log_every_n_steps", type=int, default=1)
- parser.add_argument("--num_workers", type=int, default=8)
- args = parser.parse_args()
- main(args)
diff --git a/applications/contrastive_phenotyping/examples_cli/predict.yml b/applications/contrastive_phenotyping/examples_cli/predict.yml
deleted file mode 100644
index 622f273fb..000000000
--- a/applications/contrastive_phenotyping/examples_cli/predict.yml
+++ /dev/null
@@ -1,59 +0,0 @@
-seed_everything: 42
-trainer:
- accelerator: gpu
- strategy: auto
- devices: auto
- num_nodes: 1
- precision: 32-true
- callbacks:
- - class_path: viscy.representation.embedding_writer.EmbeddingWriter
- init_args:
- output_path: "/path/to/output.zarr"
- phate_kwargs:
- n_components: 2
- knn: 10
- decay: 50
- gamma: 1
- # edit the following lines to specify logging path
- # - class_path: lightning.pytorch.loggers.TensorBoardLogger
- # init_args:
- # save_dir: /path/to/save_dir
- # version: name-of-experiment
- # log_graph: True
- inference_mode: true
-model:
- class_path: viscy.representation.engine.ContrastiveModule
- init_args:
- backbone: convnext_tiny
- in_channels: 2
- in_stack_depth: 15
- stem_kernel_size: [5, 4, 4]
-data:
- class_path: viscy.data.triplet.TripletDataModule
- init_args:
- data_path: /hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr
- tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr
- source_channel:
- - Phase3D
- - RFP
- z_range: [28, 43]
- batch_size: 32
- num_workers: 16
- initial_yx_patch_size: [192, 192]
- final_yx_patch_size: [192, 192]
- normalizations:
- - class_path: viscy.transforms.NormalizeSampled
- init_args:
- keys: [Phase3D]
- level: fov_statistics
- subtrahend: mean
- divisor: std
- - class_path: viscy.transforms.ScaleIntensityRangePercentilesd
- init_args:
- keys: [RFP]
- lower: 50
- upper: 99
- b_min: 0.0
- b_max: 1.0
-return_predictions: false
-ckpt_path: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/lightning_logs/tokenized-drop-path-0.0/checkpoints/epoch=96-step=23377.ckpt
diff --git a/applications/contrastive_phenotyping/figures/classify_feb.py b/applications/contrastive_phenotyping/figures/classify_feb.py
deleted file mode 100644
index b9dd81b8e..000000000
--- a/applications/contrastive_phenotyping/figures/classify_feb.py
+++ /dev/null
@@ -1,100 +0,0 @@
-# %% Importing Necessary Libraries
-from pathlib import Path
-
-import matplotlib.pyplot as plt
-import pandas as pd
-import seaborn as sns
-from imblearn.over_sampling import SMOTE
-from sklearn.linear_model import LogisticRegression
-from sklearn.metrics import classification_report, confusion_matrix
-from tqdm import tqdm
-
-from viscy.representation.embedding_writer import read_embedding_dataset
-from viscy.representation.evaluation import load_annotation
-from viscy.representation.evaluation.dimensionality_reduction import compute_pca
-
-# %% Defining Paths for February Dataset
-feb_features_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr"
-)
-
-
-# %% Load and Process February Dataset
-feb_embedding_dataset = read_embedding_dataset(feb_features_path)
-print(feb_embedding_dataset)
-pca_df = compute_pca(feb_embedding_dataset, n_components=6)
-
-# Print shape before merge
-print("Shape of pca_df before merge:", pca_df.shape)
-
-# Load the ground truth infection labels
-feb_ann_root = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track"
-)
-feb_infection = load_annotation(
- feb_embedding_dataset,
- feb_ann_root / "tracking_v1_infection.csv",
- "infection class",
- {0.0: "background", 1.0: "uninfected", 2.0: "infected"},
-)
-
-# Print shape of feb_infection
-print("Shape of feb_infection:", feb_infection.shape)
-
-# Merge PCA results with ground truth labels on both 'fov_name' and 'id'
-pca_df = pd.merge(pca_df, feb_infection.reset_index(), on=["fov_name", "id"])
-
-# Print shape after merge
-print("Shape of pca_df after merge:", pca_df.shape)
-
-# Prepare the full dataset
-X = pca_df[["PCA1", "PCA2", "PCA3", "PCA4", "PCA5", "PCA6"]]
-y = pca_df["infection class"]
-
-# Apply SMOTE to balance the classes in the full dataset
-smote = SMOTE(random_state=42)
-X_resampled, y_resampled = smote.fit_resample(X, y)
-
-# Print shape after SMOTE
-print(
- f"Shape after SMOTE - X_resampled: {X_resampled.shape}, y_resampled: {y_resampled.shape}"
-)
-
-# %% Train Logistic Regression Classifier with Progress Bar
-model = LogisticRegression(max_iter=1000, random_state=42)
-
-# Wrap the training with tqdm to show a progress bar
-for _ in tqdm(range(1)):
- model.fit(X_resampled, y_resampled)
-
-# %% Predict Labels for the Entire Dataset
-pca_df["Predicted_Label"] = model.predict(X)
-
-# Compute metrics based on the entire original dataset
-print("Classification Report for Entire Dataset:")
-print(classification_report(pca_df["infection class"], pca_df["Predicted_Label"]))
-
-print("Confusion Matrix for Entire Dataset:")
-print(confusion_matrix(pca_df["infection class"], pca_df["Predicted_Label"]))
-
-# %% Plotting the Results
-plt.figure(figsize=(10, 8))
-sns.scatterplot(
- x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["infection class"], s=7, alpha=0.8
-)
-plt.title("PCA with Ground Truth Labels")
-plt.savefig("up_pca_ground_truth_labels.png", format="png", dpi=300)
-plt.show()
-
-plt.figure(figsize=(10, 8))
-sns.scatterplot(
- x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["Predicted_Label"], s=7, alpha=0.8
-)
-plt.title("PCA with Logistic Regression Predicted Labels")
-plt.savefig("up_pca_predicted_labels.png", format="png", dpi=300)
-plt.show()
-
-# %% Save Predicted Labels to CSV
-save_path_csv = "up_logistic_regression_predicted_labels_feb_pca.csv"
-pca_df[["id", "fov_name", "Predicted_Label"]].to_csv(save_path_csv, index=False)
-print(f"Predicted labels saved to {save_path_csv}")
diff --git a/applications/contrastive_phenotyping/figures/classify_feb_embeddings.py b/applications/contrastive_phenotyping/figures/classify_feb_embeddings.py
deleted file mode 100644
index da63c52a8..000000000
--- a/applications/contrastive_phenotyping/figures/classify_feb_embeddings.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# %% Importing Necessary Libraries
-from pathlib import Path
-
-import pandas as pd
-from imblearn.over_sampling import SMOTE
-from sklearn.linear_model import LogisticRegression
-from sklearn.metrics import classification_report, confusion_matrix
-
-from viscy.representation.embedding_writer import read_embedding_dataset
-from viscy.representation.evaluation import load_annotation
-
-# %% Defining Paths for February Dataset
-feb_features_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/"
-)
-
-
-# %% Load and Process February Dataset (Embedding Features)
-feb_embedding_dataset = read_embedding_dataset(
- feb_features_path / "febtest_predict.zarr"
-)
-print(feb_embedding_dataset)
-
-# Extract the embedding feature values as the input matrix (X)
-X = feb_embedding_dataset["features"].values
-
-# Prepare a DataFrame for the embeddings with id and fov_name
-embedding_df = pd.DataFrame(X, columns=[f"feature_{i+1}" for i in range(X.shape[1])])
-embedding_df["id"] = feb_embedding_dataset["id"].values
-embedding_df["fov_name"] = feb_embedding_dataset["fov_name"].values
-print(embedding_df.head())
-
-# %% Load the ground truth infection labels
-feb_ann_root = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred"
-)
-feb_infection = load_annotation(
- feb_embedding_dataset,
- feb_ann_root / "extracted_inf_state.csv",
- "infection_state",
- {0.0: "background", 1.0: "uninfected", 2.0: "infected"},
-)
-
-# %% Merge embedding features with infection labels on 'fov_name' and 'id'
-merged_df = pd.merge(embedding_df, feb_infection.reset_index(), on=["fov_name", "id"])
-print(merged_df.head())
-# %% Prepare the full dataset for training
-X = merged_df.drop(
- columns=["id", "fov_name", "infection_state"]
-).values # Use embeddings as features
-y = merged_df["infection_state"] # Use infection state as labels
-print(X.shape)
-print(y.shape)
-# %% Print class distribution before applying SMOTE
-print("Class distribution before SMOTE:")
-print(y.value_counts())
-
-# Apply SMOTE to balance the classes
-smote = SMOTE(random_state=42)
-X_resampled, y_resampled = smote.fit_resample(X, y)
-
-# Print class distribution after applying SMOTE
-print("Class distribution after SMOTE:")
-print(pd.Series(y_resampled).value_counts())
-
-# Train Logistic Regression Classifier
-model = LogisticRegression(max_iter=1000, random_state=42)
-model.fit(X_resampled, y_resampled)
-
-# Predict Labels for the Entire Dataset
-y_pred = model.predict(X)
-
-# Compute metrics based on the entire original dataset
-print("Classification Report for Entire Dataset:")
-print(classification_report(y, y_pred))
-
-print("Confusion Matrix for Entire Dataset:")
-print(confusion_matrix(y, y_pred))
-
-# %%
-# Save the predicted labels to a CSV
-save_path_csv = feb_features_path / "feb_test_regression_predicted_labels_embedding.csv"
-predicted_labels_df = pd.DataFrame(
- {
- "id": merged_df["id"].values,
- "fov_name": merged_df["fov_name"].values,
- "Predicted_Label": y_pred,
- }
-)
-
-predicted_labels_df.to_csv(save_path_csv, index=False)
-print(f"Predicted labels saved to {save_path_csv}")
-
-# %%
diff --git a/applications/contrastive_phenotyping/figures/classify_june.py b/applications/contrastive_phenotyping/figures/classify_june.py
deleted file mode 100644
index ca51f2b17..000000000
--- a/applications/contrastive_phenotyping/figures/classify_june.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# %% Importing Necessary Libraries
-from pathlib import Path
-
-import matplotlib.pyplot as plt
-import pandas as pd
-import seaborn as sns
-from imblearn.over_sampling import SMOTE
-from sklearn.decomposition import PCA
-from sklearn.linear_model import LogisticRegression
-from sklearn.metrics import classification_report, confusion_matrix
-from sklearn.preprocessing import StandardScaler
-from tqdm import tqdm
-
-from viscy.representation.embedding_writer import read_embedding_dataset
-
-# %% Defining Paths for June Dataset
-june_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/Phase_RFP_smallPatch_June/phaseRFP_36patch_June.zarr")
-
-# %% Function to Load Annotations
-def load_annotation(da, path, name, categories: dict | None = None):
- annotation = pd.read_csv(path)
- annotation["fov_name"] = "/" + annotation["fov ID"]
- annotation = annotation.set_index(["fov_name", "id"])
- mi = pd.MultiIndex.from_arrays(
- [da["fov_name"].values, da["id"].values], names=["fov_name", "id"]
- )
- selected = annotation.loc[mi][name]
- if categories:
- selected = selected.astype("category").cat.rename_categories(categories)
- return selected
-
-# %% Function to Compute PCA
-def compute_pca(embedding_dataset, n_components=6):
- features = embedding_dataset["features"]
- scaled_features = StandardScaler().fit_transform(features.values)
-
- # Compute PCA with specified number of components
- pca = PCA(n_components=n_components, random_state=42)
- pca_embedding = pca.fit_transform(scaled_features)
-
- # Prepare DataFrame with id and PCA coordinates
- pca_df = pd.DataFrame({
- "id": embedding_dataset["id"].values,
- "fov_name": embedding_dataset["fov_name"].values,
- "PCA1": pca_embedding[:, 0],
- "PCA2": pca_embedding[:, 1],
- "PCA3": pca_embedding[:, 2],
- "PCA4": pca_embedding[:, 3],
- "PCA5": pca_embedding[:, 4],
- "PCA6": pca_embedding[:, 5]
- })
-
- return pca_df
-
-# %% Load and Process June Dataset
-june_embedding_dataset = read_embedding_dataset(june_features_path)
-print(june_embedding_dataset)
-pca_df = compute_pca(june_embedding_dataset, n_components=6)
-
-# Print shape before merge
-print("Shape of pca_df before merge:", pca_df.shape)
-
-# Load the ground truth infection labels
-june_ann_root = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.1-tracking")
-june_infection = load_annotation(june_embedding_dataset, june_ann_root / "tracking_v1_infection.csv", "infection class",
- {0.0: "background", 1.0: "uninfected", 2.0: "infected"})
-
-# Print shape of june_infection
-print("Shape of june_infection:", june_infection.shape)
-
-# Merge PCA results with ground truth labels on both 'fov_name' and 'id'
-pca_df = pd.merge(pca_df, june_infection.reset_index(), on=['fov_name', 'id'])
-
-# Print shape after merge
-print("Shape of pca_df after merge:", pca_df.shape)
-
-# Prepare the full dataset
-X = pca_df[["PCA1", "PCA2", "PCA3", "PCA4", "PCA5", "PCA6"]]
-y = pca_df["infection class"]
-
-# Apply SMOTE to balance the classes in the full dataset
-smote = SMOTE(random_state=42)
-X_resampled, y_resampled = smote.fit_resample(X, y)
-
-# Print shape after SMOTE
-print(f"Shape after SMOTE - X_resampled: {X_resampled.shape}, y_resampled: {y_resampled.shape}")
-
-# %% Train Logistic Regression Classifier with Progress Bar
-model = LogisticRegression(max_iter=1000, random_state=42)
-
-# Wrap the training with tqdm to show a progress bar
-for _ in tqdm(range(1)):
- model.fit(X_resampled, y_resampled)
-
-# %% Predict Labels for the Entire Dataset
-pca_df["Predicted_Label"] = model.predict(X)
-
-# Compute metrics based on the entire original dataset
-print("Classification Report for Entire Dataset:")
-print(classification_report(pca_df["infection class"], pca_df["Predicted_Label"]))
-
-print("Confusion Matrix for Entire Dataset:")
-print(confusion_matrix(pca_df["infection class"], pca_df["Predicted_Label"]))
-
-# %% Plotting the Results
-plt.figure(figsize=(10, 8))
-sns.scatterplot(x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["infection class"], s=7, alpha=0.8)
-plt.title("PCA with Ground Truth Labels")
-plt.savefig("june_pca_ground_truth_labels.png", format='png', dpi=300)
-plt.show()
-
-plt.figure(figsize=(10, 8))
-sns.scatterplot(x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["Predicted_Label"], s=7, alpha=0.8)
-plt.title("PCA with Logistic Regression Predicted Labels")
-plt.savefig("june_pca_predicted_labels.png", format='png', dpi=300)
-plt.show()
-
-# %% Save Predicted Labels to CSV
-save_path_csv = "june_logistic_regression_predicted_labels_feb_pca.csv"
-pca_df[['id', 'fov_name', 'Predicted_Label']].to_csv(save_path_csv, index=False)
-print(f"Predicted labels saved to {save_path_csv}")
diff --git a/applications/contrastive_phenotyping/evaluation/figure.mplstyle b/applications/contrastive_phenotyping/figures/figure.mplstyle
similarity index 100%
rename from applications/contrastive_phenotyping/evaluation/figure.mplstyle
rename to applications/contrastive_phenotyping/figures/figure.mplstyle
diff --git a/applications/contrastive_phenotyping/figures/figure_4a_1.py b/applications/contrastive_phenotyping/figures/figure_4a_1.py
deleted file mode 100644
index a670db0d0..000000000
--- a/applications/contrastive_phenotyping/figures/figure_4a_1.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# %% Importing Necessary Libraries
-from pathlib import Path
-
-import matplotlib.pyplot as plt
-import pandas as pd
-import seaborn as sns
-from sklearn.preprocessing import StandardScaler
-from umap import UMAP
-
-from viscy.representation.embedding_writer import read_embedding_dataset
-
-# %% Defining Paths for February and June Datasets
-feb_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr")
-feb_data_path = Path("/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr")
-feb_tracks_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr")
-
-# %% Function to Load and Process the Embedding Dataset
-def compute_umap(embedding_dataset):
- features = embedding_dataset["features"]
- scaled_features = StandardScaler().fit_transform(features.values)
- umap = UMAP()
- embedding = umap.fit_transform(scaled_features)
-
- features = (
- features.assign_coords(UMAP1=("sample", embedding[:, 0]))
- .assign_coords(UMAP2=("sample", embedding[:, 1]))
- .set_index(sample=["UMAP1", "UMAP2"], append=True)
- )
- return features
-
-# %% Function to Load Annotations
-def load_annotation(da, path, name, categories: dict | None = None):
- annotation = pd.read_csv(path)
- annotation["fov_name"] = "/" + annotation["fov ID"]
- annotation = annotation.set_index(["fov_name", "id"])
- mi = pd.MultiIndex.from_arrays(
- [da["fov_name"].values, da["id"].values], names=["fov_name", "id"]
- )
- selected = annotation.loc[mi][name]
- if categories:
- selected = selected.astype("category").cat.rename_categories(categories)
- return selected
-
-# %% Function to Plot UMAP with Infection Annotations
-def plot_umap_infection(features, infection, title):
- plt.figure(figsize=(10, 8))
- sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8)
- plt.title(f"UMAP Plot - {title}")
- plt.show()
-
-# %% Load and Process February Dataset
-feb_embedding_dataset = read_embedding_dataset(feb_features_path)
-feb_features = compute_umap(feb_embedding_dataset)
-
-feb_ann_root = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track")
-feb_infection = load_annotation(feb_features, feb_ann_root / "tracking_v1_infection.csv", "infection class", {0.0: "background", 1.0: "uninfected", 2.0: "infected"})
-
-# %% Plot UMAP with Infection Status for February Dataset
-plot_umap_infection(feb_features, feb_infection, "February Dataset")
-
-# %%
-print(feb_embedding_dataset)
-print(feb_infection)
-print(feb_features)
-# %%
-
-
-# %% Identify cells by infection type using fov_name
-mock_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/A/3') | feb_features['fov_name'].str.contains('/B/3'))
-zika_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/A/4'))
-dengue_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/B/4'))
-
-# %% Plot UMAP with Infection Status
-plt.figure(figsize=(10, 8))
-sns.scatterplot(x=feb_features["UMAP1"], y=feb_features["UMAP2"], hue=feb_infection, s=7, alpha=0.8)
-
-# Overlay with circled cells
-plt.scatter(mock_cells["UMAP1"], mock_cells["UMAP2"], facecolors='none', edgecolors='blue', s=20, label='Mock Cells')
-plt.scatter(zika_cells["UMAP1"], zika_cells["UMAP2"], facecolors='none', edgecolors='green', s=20, label='Zika MOI 5')
-plt.scatter(dengue_cells["UMAP1"], dengue_cells["UMAP2"], facecolors='none', edgecolors='red', s=20, label='Dengue MOI 5')
-
-# Add legend and show plot
-plt.legend(loc='best')
-plt.title("UMAP Plot - February Dataset with Mock, Zika, and Dengue Highlighted")
-plt.show()
-
-# %%
-# %% Create a 1x3 grid of heatmaps
-fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True)
-
-# Mock Cells Heatmap
-sns.histplot(x=mock_cells["UMAP1"], y=mock_cells["UMAP2"], bins=50, pmax=1, cmap="Blues", ax=axs[0])
-axs[0].set_title('Mock Cells')
-axs[0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Zika Cells Heatmap
-sns.histplot(x=zika_cells["UMAP1"], y=zika_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[1])
-axs[1].set_title('Zika MOI 5')
-axs[1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Dengue Cells Heatmap
-sns.histplot(x=dengue_cells["UMAP1"], y=dengue_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[2])
-axs[2].set_title('Dengue MOI 5')
-axs[2].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[2].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Set labels and adjust layout
-for ax in axs:
- ax.set_xlabel('UMAP1')
- ax.set_ylabel('UMAP2')
-
-plt.tight_layout()
-plt.show()
-
-# %%
-import matplotlib.pyplot as plt
-import seaborn as sns
-
-# %% Create a 2x3 grid of heatmaps (1 row for each heatmap, splitting infected and uninfected in the second row)
-fig, axs = plt.subplots(2, 3, figsize=(24, 12), sharex=True, sharey=True)
-
-# Mock Cells Heatmap
-sns.histplot(x=mock_cells["UMAP1"], y=mock_cells["UMAP2"], bins=50, pmax=1, cmap="Blues", ax=axs[0, 0])
-axs[0, 0].set_title('Mock Cells')
-axs[0, 0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[0, 0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Zika Cells Heatmap
-sns.histplot(x=zika_cells["UMAP1"], y=zika_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[0, 1])
-axs[0, 1].set_title('Zika MOI 5')
-axs[0, 1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[0, 1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Dengue Cells Heatmap
-sns.histplot(x=dengue_cells["UMAP1"], y=dengue_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[0, 2])
-axs[0, 2].set_title('Dengue MOI 5')
-axs[0, 2].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[0, 2].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Infected Cells Heatmap
-sns.histplot(x=infected_cells["UMAP1"], y=infected_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[1, 0])
-axs[1, 0].set_title('Infected Cells')
-axs[1, 0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[1, 0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Uninfected Cells Heatmap
-sns.histplot(x=uninfected_cells["UMAP1"], y=uninfected_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[1, 1])
-axs[1, 1].set_title('Uninfected Cells')
-axs[1, 1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max())
-axs[1, 1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max())
-
-# Remove the last subplot (bottom right corner)
-fig.delaxes(axs[1, 2])
-
-# Set labels and adjust layout
-for ax in axs.flat:
- ax.set_xlabel('UMAP1')
- ax.set_ylabel('UMAP2')
-
-plt.tight_layout()
-plt.show()
-
-
-
-# %%
diff --git a/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py b/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py
deleted file mode 100644
index d3052018a..000000000
--- a/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py
+++ /dev/null
@@ -1,87 +0,0 @@
-# %% Importing Necessary Libraries
-from pathlib import Path
-
-import matplotlib.pyplot as plt
-import pandas as pd
-
-from viscy.representation.embedding_writer import read_embedding_dataset
-
-
-# %% Function to Load Annotations from GMM CSV
-def load_gmm_annotation(gmm_csv_path):
- gmm_df = pd.read_csv(gmm_csv_path)
- return gmm_df
-
-# %% Function to Count and Calculate Percentage of Infected Cells Over Time Based on GMM Labels
-def count_infected_cell_states_over_time(embedding_dataset, gmm_df):
- # Convert the embedding dataset to a DataFrame
- df = pd.DataFrame({
- "fov_name": embedding_dataset["fov_name"].values,
- "track_id": embedding_dataset["track_id"].values,
- "t": embedding_dataset["t"].values,
- "id": embedding_dataset["id"].values
- })
-
- # Merge with GMM data to add GMM labels
- df = pd.merge(df, gmm_df[['id', 'fov_name', 'Predicted_Label']], on=['fov_name', 'id'], how='left')
-
- # Filter by time range (3 HPI to 30 HPI)
- df = df[(df['t'] >= 3) & (df['t'] <= 27)]
-
- # Determine the well type (Mock, Zika, Dengue) based on fov_name
- df['well_type'] = df['fov_name'].apply(lambda x: 'Mock' if '/A/3' in x or '/B/3' in x else
- ('Zika' if '/A/4' in x else 'Dengue'))
-
- # Group by time, well type, and GMM label to count the number of infected cells
- state_counts = df.groupby(['t', 'well_type', 'Predicted_Label']).size().unstack(fill_value=0)
-
- # Ensure that 'infected' column exists
- if 'infected' not in state_counts.columns:
- state_counts['infected'] = 0
-
- # Calculate the percentage of infected cells
- state_counts['total'] = state_counts.sum(axis=1)
- state_counts['infected'] = (state_counts['infected'] / state_counts['total']) * 100
-
- return state_counts
-
-# %% Function to Plot Percentage of Infected Cells Over Time
-def plot_infected_cell_states(state_counts):
- plt.figure(figsize=(12, 8))
-
- # Loop through each well type
- for well_type in ['Mock', 'Zika', 'Dengue']:
- # Select the data for the current well type
- if well_type in state_counts.index.get_level_values('well_type'):
- well_data = state_counts.xs(well_type, level='well_type')
-
- # Plot only the percentage of infected cells
- if 'infected' in well_data.columns:
- plt.plot(well_data.index, well_data['infected'], label=f'{well_type} - Infected')
-
- plt.title("Percentage of Infected Cells Over Time - February")
- plt.xlabel("Hours Post Perturbation")
- plt.ylabel("Percentage of Infected Cells")
- plt.legend(title="Well Type")
- plt.grid(True)
- plt.show()
-
-# %% Load and process Feb Dataset
-feb_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr")
-feb_embedding_dataset = read_embedding_dataset(feb_features_path)
-
-# Load the GMM annotation CSV
-gmm_csv_path = "june_logistic_regression_predicted_labels_feb_pca.csv" # Path to CSV file
-gmm_df = load_gmm_annotation(gmm_csv_path)
-
-# %% Count Infected Cell States Over Time as Percentage using GMM labels
-state_counts = count_infected_cell_states_over_time(feb_embedding_dataset, gmm_df)
-print(state_counts.head())
-state_counts.info()
-
-# %% Plot Infected Cell States Over Time as Percentage
-plot_infected_cell_states(state_counts)
-
-# %%
-
-
diff --git a/applications/contrastive_phenotyping/figures/figure_4e_2_june.py b/applications/contrastive_phenotyping/figures/figure_4e_2_june.py
deleted file mode 100644
index 1605ba278..000000000
--- a/applications/contrastive_phenotyping/figures/figure_4e_2_june.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# %% Importing Necessary Libraries
-from pathlib import Path
-
-import matplotlib.pyplot as plt
-import pandas as pd
-
-from viscy.representation.embedding_writer import read_embedding_dataset
-
-
-# %% Function to Load Annotations from CSV
-def load_annotation(csv_path):
- return pd.read_csv(csv_path)
-
-# %% Function to Count and Calculate Percentage of Infected Cells Over Time Based on Predicted Labels
-def count_infected_cell_states_over_time(embedding_dataset, prediction_df):
- # Convert the embedding dataset to a DataFrame
- df = pd.DataFrame({
- "fov_name": embedding_dataset["fov_name"].values,
- "track_id": embedding_dataset["track_id"].values,
- "t": embedding_dataset["t"].values,
- "id": embedding_dataset["id"].values
- })
-
- # Merge with the prediction data to add Predicted Labels
- df = pd.merge(df, prediction_df[['id', 'fov_name', 'Infection_Class']], on=['fov_name', 'id'], how='left')
-
- # Filter by time range (2 HPI to 50 HPI)
- df = df[(df['t'] >= 2) & (df['t'] <= 50)]
-
- # Determine the well type (Mock, Dengue, Zika) based on fov_name
- df['well_type'] = df['fov_name'].apply(
- lambda x: 'Mock' if '/0/1' in x or '/0/2' in x or '/0/3' in x or '/0/4' in x else
- ('Dengue' if '/0/5' in x or '/0/6' in x else 'Zika'))
-
- # Group by time, well type, and Predicted_Label to count the number of infected cells
- state_counts = df.groupby(['t', 'well_type', 'Infection_Class']).size().unstack(fill_value=0)
-
- # Ensure that 'infected' column exists
- if 'infected' not in state_counts.columns:
- state_counts['infected'] = 0
-
- # Calculate the percentage of infected cells
- state_counts['total'] = state_counts.sum(axis=1)
- state_counts['infected'] = (state_counts['infected'] / state_counts['total']) * 100
-
- return state_counts
-
-# %% Function to Plot Percentage of Infected Cells Over Time
-def plot_infected_cell_states(state_counts):
- plt.figure(figsize=(12, 8))
-
- # Loop through each well type
- for well_type in ['Mock', 'Dengue', 'Zika']:
- # Select the data for the current well type
- if well_type in state_counts.index.get_level_values('well_type'):
- well_data = state_counts.xs(well_type, level='well_type')
-
- # Plot only the percentage of infected cells
- if 'infected' in well_data.columns:
- plt.plot(well_data.index, well_data['infected'], label=f'{well_type} - Infected')
-
- plt.title("Percentage of Infected Cells Over Time - June")
- plt.xlabel("Hours Post Perturbation")
- plt.ylabel("Percentage of Infected Cells")
- plt.legend(title="Well Type")
- plt.grid(True)
- plt.show()
-
-# %% Load and process June Dataset
-june_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/Phase_RFP_smallPatch_June/phaseRFP_36patch_June.zarr")
-june_embedding_dataset = read_embedding_dataset(june_features_path)
-
-# Load the predicted labels from CSV
-prediction_csv_path = "3up_gmm_clustering_results_june_pca_6components.csv" # Path to predicted labels CSV file
-prediction_df = load_annotation(prediction_csv_path)
-
-# %% Count Infected Cell States Over Time as Percentage using Predicted labels
-state_counts = count_infected_cell_states_over_time(june_embedding_dataset, prediction_df)
-print(state_counts.head())
-state_counts.info()
-
-# %% Plot Infected Cell States Over Time as Percentage
-plot_infected_cell_states(state_counts)
-
-# %%
diff --git a/applications/contrastive_phenotyping/evaluation/grad_attr.py b/applications/contrastive_phenotyping/figures/grad_attr.py
similarity index 86%
rename from applications/contrastive_phenotyping/evaluation/grad_attr.py
rename to applications/contrastive_phenotyping/figures/grad_attr.py
index 2def1bfce..f0873c288 100644
--- a/applications/contrastive_phenotyping/evaluation/grad_attr.py
+++ b/applications/contrastive_phenotyping/figures/grad_attr.py
@@ -29,8 +29,8 @@
# %%
dm = TripletDataModule(
- data_path="/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr",
- tracks_path="/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr",
+ data_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr",
+ tracks_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr",
source_channel=["Phase3D", "RFP"],
z_range=[25, 40],
batch_size=48,
@@ -76,10 +76,10 @@
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178_gt_tracks.zarr"
)
path_annotations_infection = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv"
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv"
)
path_annotations_division = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv"
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv"
)
infection_dataset = read_embedding_dataset(path_infection_embedding)
@@ -145,7 +145,7 @@
# %%
# load infection annotations
infection = pd.read_csv(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv",
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv",
)
track_classes_infection = infection[infection["fov_name"] == fov[1:]]
track_classes_infection = track_classes_infection[
@@ -155,7 +155,7 @@
# %%
# load division annotations
division = pd.read_csv(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv",
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv",
)
track_classes_division = division[division["fov_name"] == fov[1:]]
track_classes_division = track_classes_division[
@@ -258,7 +258,8 @@ def clim_percentile(heatmap, low=1, high=99):
# %%
f.savefig(
Path.home()
- / "gdrive/publications/learning_impacts_of_infection/fig_manuscript/fig_explanation/fig_explanation_patch12_stride4.pdf",
+ / "mydata"
+ / "gdrive/publications/dynaCLR/2025_dynaCLR_paper/fig_manuscript_svg/figure_occlusion_analysis/figure_parts/fig_explanation_patch12_stride4.pdf",
dpi=300,
)
diff --git a/applications/contrastive_phenotyping/figures/save_patches.py b/applications/contrastive_phenotyping/figures/save_patches.py
deleted file mode 100644
index ebba6c320..000000000
--- a/applications/contrastive_phenotyping/figures/save_patches.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# %% script to save 128 by 128 image patches from napari viewer
-
-import os
-import sys
-from pathlib import Path
-
-import numpy as np
-
-sys.path.append("/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy")
-# from viscy.data.triplet import TripletDataModule
-from viscy.representation.evaluation import dataset_of_tracks
-
-# %% input parameters
-
-data_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr"
-)
-tracks_path = Path(
- "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr"
-)
-
-fov_name = "/B/4/6"
-track_id = 52
-source_channel = ["Phase3D", "RFP"]
-
-# %% load dataset
-
-prediction_dataset = dataset_of_tracks(
- data_path,
- tracks_path,
- [fov_name],
- [track_id],
- source_channel=source_channel,
-)
-whole = np.stack([p["anchor"] for p in prediction_dataset])
-phase = whole[:, 0]
-fluor = whole[:, 1]
-
-# use the following if you want to visualize a specific phase slice with max projected fluor
-# phase = whole[:, 0, 3] # 3 is the slice number
-# fluor = np.max(whole[:, 1], axis=1)
-
-# load image
-# v = napari.Viewer()
-# v.add_image(phase)
-# v.add_image(fluor)
-
-# %% save patches as png images
-
-# use sliders on napari to get the deisred contrast and make other adjustments
-# then use save screenshot if saving the image patch manually
-# you can add code to automate the process if desired
-
-# %% save as numpy files
-
-out_dir = "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/data/"
-fov_name_out = fov_name.replace("/", "_")
-np.save(
- (os.path.join(out_dir, "phase" + fov_name_out + "_" + str(track_id) + ".npy")),
- phase,
-)
-np.save(
- (os.path.join(out_dir, "fluor" + fov_name_out + "_" + str(track_id) + ".npy")),
- fluor,
-)
-
-# %%
diff --git a/applications/infection_classification/Infection_classification_25DModel.py b/applications/infection_classification/Infection_classification_25DModel.py
deleted file mode 100644
index a4e712f5b..000000000
--- a/applications/infection_classification/Infection_classification_25DModel.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# %%
-import lightning.pytorch as pl
-import torch
-import torch.nn as nn
-from applications.infection_classification.classify_infection_25D import (
- SemanticSegUNet25D,
-)
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import TensorBoardLogger
-
-from viscy.data.hcs import HCSDataModule
-from viscy.preprocessing.pixel_ratio import sematic_class_weights
-from viscy.transforms import NormalizeSampled, RandWeightedCropd
-
-# %% Create a dataloader and visualize the batches.
-
-# Set the path to the dataset
-dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr"
-
-# %% create data module
-
-# Create an instance of HCSDataModule
-data_module = HCSDataModule(
- dataset_path,
- source_channel=["Phase", "HSP90"],
- target_channel=["Inf_mask"],
- yx_patch_size=[512, 512],
- split_ratio=0.8,
- z_window_size=5,
- architecture="2.5D",
- num_workers=3,
- batch_size=32,
- normalizations=[
- NormalizeSampled(
- keys=["Phase", "HSP90"],
- level="fov_statistics",
- subtrahend="median",
- divisor="iqr",
- )
- ],
- augmentations=[
- RandWeightedCropd(
- num_samples=4,
- spatial_size=[-1, 512, 512],
- keys=["Phase", "HSP90"],
- w_key="Inf_mask",
- )
- ],
-)
-
-pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask")
-
-# Prepare the data
-data_module.prepare_data()
-
-# Setup the data
-data_module.setup(stage="fit")
-
-# Create a dataloader
-train_dm = data_module.train_dataloader()
-
-val_dm = data_module.val_dataloader()
-
-
-# %% Define the logger
-logger = TensorBoardLogger(
- "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/",
- name="logs",
-)
-
-# Pass the logger to the Trainer
-trainer = pl.Trainer(
- logger=logger,
- max_epochs=200,
- default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
- log_every_n_steps=1,
- devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs
-)
-
-# Define the checkpoint callback
-checkpoint_callback = ModelCheckpoint(
- dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
- filename="checkpoint_{epoch:02d}",
- save_top_k=-1,
- verbose=True,
- monitor="loss/validate",
- mode="min",
-)
-
-# Add the checkpoint callback to the trainer
-trainer.callbacks.append(checkpoint_callback)
-
-# Fit the model
-model = SemanticSegUNet25D(
- in_channels=2,
- out_channels=3,
- loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
-)
-
-print(model)
-
-# %% Run training.
-
-trainer.fit(model, data_module)
-
-# %%
diff --git a/applications/infection_classification/Infection_classification_covnextModel.py b/applications/infection_classification/Infection_classification_covnextModel.py
deleted file mode 100644
index bfe203625..000000000
--- a/applications/infection_classification/Infection_classification_covnextModel.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# %%
-# import sys
-# sys.path.append("/hpc/mydata/soorya.pradeep/viscy_infection_phenotyping/Viscy/")
-import lightning.pytorch as pl
-import torch
-import torch.nn as nn
-from applications.infection_classification.classify_infection_covnext import (
- SemanticSegUNet22D,
-)
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import TensorBoardLogger
-
-from viscy.data.hcs import HCSDataModule
-from viscy.preprocessing.pixel_ratio import sematic_class_weights
-from viscy.transforms import NormalizeSampled, RandWeightedCropd
-
-# %% Create a dataloader and visualize the batches.
-
-# Set the path to the dataset
-dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated_train.zarr"
-
-# %% craete data module
-
-# Create an instance of HCSDataModule
-data_module = HCSDataModule(
- dataset_path,
- source_channel=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
- target_channel=["Inf_mask"],
- yx_patch_size=[256, 256],
- split_ratio=0.8,
- z_window_size=5,
- architecture="2.2D",
- num_workers=3,
- batch_size=16,
- normalizations=[
- NormalizeSampled(
- keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
- level="fov_statistics",
- subtrahend="median",
- divisor="iqr",
- )
- ],
- augmentations=[
- RandWeightedCropd(
- num_samples=4,
- spatial_size=[-1, 256, 256],
- keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
- w_key="Inf_mask",
- )
- ],
-)
-pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask")
-
-# Prepare the data
-data_module.prepare_data()
-
-# Setup the data
-data_module.setup(stage="fit")
-
-# Create a dataloader
-train_dm = data_module.train_dataloader()
-
-val_dm = data_module.val_dataloader()
-
-
-# %% Define the logger
-logger = TensorBoardLogger(
- "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/",
- name="logs",
-)
-
-# Pass the logger to the Trainer
-trainer = pl.Trainer(
- logger=logger,
- max_epochs=200,
- default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
- log_every_n_steps=1,
- devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs
-)
-
-# Define the checkpoint callback
-checkpoint_callback = ModelCheckpoint(
- dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
- filename="checkpoint_{epoch:02d}",
- save_top_k=-1,
- verbose=True,
- monitor="loss/validate",
- mode="min",
-)
-
-# Add the checkpoint callback to the trainer``
-trainer.callbacks.append(checkpoint_callback)
-
-# Fit the model
-model = SemanticSegUNet22D(
- in_channels=4,
- out_channels=3,
- loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
-)
-
-print(model)
-
-# %% Run training.
-
-trainer.fit(model, data_module)
-
-# %%
diff --git a/applications/infection_classification/readme.md b/applications/infection_classification/README.md
similarity index 100%
rename from applications/infection_classification/readme.md
rename to applications/infection_classification/README.md
diff --git a/applications/infection_classification/classify_infection_25D.py b/applications/infection_classification/classify_infection_25D.py
deleted file mode 100644
index e16f56f42..000000000
--- a/applications/infection_classification/classify_infection_25D.py
+++ /dev/null
@@ -1,356 +0,0 @@
-# import torchview
-from typing import Literal, Sequence
-
-import cv2
-import lightning.pytorch as pl
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from matplotlib.cm import get_cmap
-from monai.transforms import DivisiblePad
-from skimage.exposure import rescale_intensity
-from skimage.measure import label, regionprops
-from torch import Tensor
-
-from viscy.data.hcs import Sample
-from viscy.unet.networks.Unet25D import Unet25d
-
-# %% Methods to compute confusion matrix per cell using torchmetrics
-
-
-# The confusion matrix at the single-cell resolution.
-def confusion_matrix_per_cell(
- y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
-):
- """Compute confusion matrix per cell.
-
- Args:
- y_true (torch.Tensor): Ground truth label image (BXHXW).
- y_pred (torch.Tensor): Predicted label image (BXHXW).
- num_classes (int): Number of classes.
-
- Returns:
- torch.Tensor: Confusion matrix per cell (BXCXC).
- """
- # Convert the image class to the nuclei class
- confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes)
- confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell)
- return confusion_matrix_per_cell
-
-
-# These images can be logged with prediction.
-def compute_confusion_matrix(
- y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
-):
- """Convert the class of the image to the class of the nuclei.
-
- Args:
- label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation.
- num_classes (int): Number of classes.
-
- Returns:
- torch.Tensor: Label images with a consensus class at the centroid of nuclei.
- """
-
- batch_size = y_true.size(0)
- # find centroids of nuclei from y_true
- conf_mat = np.zeros((num_classes, num_classes))
- for i in range(batch_size):
- y_true_cpu = y_true[i].cpu().numpy()
- y_pred_cpu = y_pred[i].cpu().numpy()
- y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:])
- y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:])
- y_pred_resized = cv2.resize(
- y_pred_reshaped,
- dsize=y_true_reshaped.shape[::-1],
- interpolation=cv2.INTER_NEAREST,
- )
- y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0)
-
- # find objects in every image
- label_img = label(y_true_reshaped)
- regions = regionprops(label_img)
-
- # Find centroids, pixel coordinates from the ground truth.
- for region in regions:
- if region.area > 0:
- row, col = region.centroid
- pred_id = y_pred_resized[int(row), int(col)]
- test_id = y_true_reshaped[int(row), int(col)]
-
- if pred_id == 1 and test_id == 1:
- conf_mat[1, 1] += 1
- if pred_id == 1 and test_id == 2:
- conf_mat[0, 1] += 1
- if pred_id == 2 and test_id == 1:
- conf_mat[1, 0] += 1
- if pred_id == 2 and test_id == 2:
- conf_mat[0, 0] += 1
- # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction.
- # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction.
- return conf_mat
-
-
-def plot_confusion_matrix(confusion_matrix, index_to_label_dict):
- # Create a figure and axis to plot the confusion matrix
- fig, ax = plt.subplots()
-
- # Create a color heatmap for the confusion matrix
- cax = ax.matshow(confusion_matrix, cmap="viridis")
-
- # Create a colorbar and set the label
- index_to_label_dict = dict(
- enumerate(index_to_label_dict)
- ) # Convert list to dictionary
- fig.colorbar(cax, label="Frequency")
-
- # Set labels for the classes
- ax.set_xticks(np.arange(len(index_to_label_dict)))
- ax.set_yticks(np.arange(len(index_to_label_dict)))
- ax.set_xticklabels(index_to_label_dict.values(), rotation=45)
- ax.set_yticklabels(index_to_label_dict.values())
-
- # Set labels for the axes
- ax.set_xlabel("Predicted")
- ax.set_ylabel("True")
-
- # Add text annotations to the confusion matrix
- for i in range(len(index_to_label_dict)):
- for j in range(len(index_to_label_dict)):
- ax.text(
- j,
- i,
- str(int(confusion_matrix[i, j])),
- ha="center",
- va="center",
- color="white",
- )
-
- # plt.show(fig) # Show the figure
- return fig
-
-
-# Define a 25d unet model for infection classification as a lightning module.
-
-
-class SemanticSegUNet25D(pl.LightningModule):
- # Model for semantic segmentation.
- def __init__(
- self,
- in_channels: int, # Number of input channels
- out_channels: int, # Number of output channels
- lr: float = 1e-3, # Learning rate
- loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function
- schedule: Literal[
- "WarmupCosine", "Constant"
- ] = "Constant", # Learning rate schedule
- log_batches_per_epoch: int = 2, # Number of batches to log per epoch
- log_samples_per_batch: int = 2, # Number of samples to log per batch
- ckpt_path: str = None, # Path to the checkpoint
- ):
- super(SemanticSegUNet25D, self).__init__() # Call the superclass initializer
- # Initialize the UNet model
- self.unet_model = Unet25d(
- in_channels=in_channels,
- out_channels=out_channels,
- num_blocks=4,
- num_block_layers=4,
- )
- if ckpt_path is not None:
- state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[
- "state_dict"
- ]
- state_dict.pop("loss_function.weight", None) # Remove the unexpected key
- self.load_state_dict(state_dict) # loading only weights
- self.lr = lr # Set the learning rate
- # Set the loss function to CrossEntropyLoss if none is provided
- self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss()
- self.schedule = schedule # Set the learning rate schedule
- self.log_batches_per_epoch = (
- log_batches_per_epoch # Set the number of batches to log per epoch
- )
- self.log_samples_per_batch = (
- log_samples_per_batch # Set the number of samples to log per batch
- )
- self.training_step_outputs = [] # Initialize the list of training step outputs
- self.validation_step_outputs = (
- []
- ) # Initialize the list of validation step outputs
-
- self.pred_cm = None # Initialize the confusion matrix
- self.index_to_label_dict = ["Infected", "Uninfected"]
-
- # Define the forward pass
- def forward(self, x):
- return self.unet_model(x) # Pass the input through the UNet model
-
- # Define the optimizer
- def configure_optimizers(self):
- optimizer = torch.optim.Adam(
- self.parameters(), lr=self.lr
- ) # Use the Adam optimizer
- return optimizer
-
- # Define the training step
- def training_step(self, batch: Sample, batch_idx: int):
- source = batch["source"] # Extract the source from the batch
- target = batch["target"] # Extract the target from the batch
- pred = self.forward(source) # Make a prediction using the source
- # Convert the target to one-hot encoding
- target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(
- 0, 4, 1, 2, 3
- )
- target_one_hot = target_one_hot.float() # Convert the target to float type
- train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss
- # Log the training step outputs if the batch index is less than the number of batches to log per epoch
- if batch_idx < self.log_batches_per_epoch:
- self.training_step_outputs.extend(
- self._detach_sample((source, target_one_hot, pred))
- )
- # Log the training loss
- self.log(
- "loss/train",
- train_loss,
- on_step=True,
- on_epoch=True,
- prog_bar=True,
- logger=True,
- sync_dist=True,
- )
- return train_loss # Return the training loss
-
- def validation_step(self, batch: Sample, batch_idx: int):
- source = batch["source"] # Extract the source from the batch
- target = batch["target"] # Extract the target from the batch
- pred = self.forward(source) # Make a prediction using the source
- # Convert the target to one-hot encoding
- target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(
- 0, 4, 1, 2, 3
- )
- target_one_hot = target_one_hot.float() # Convert the target to float type
- loss = self.loss_function(pred, target_one_hot) # Calculate the loss
- # Log the validation step outputs if the batch index is less than the number of batches to log per epoch
- if batch_idx < self.log_batches_per_epoch:
- self.validation_step_outputs.extend(
- self._detach_sample((source, target_one_hot, pred))
- )
- # Log the validation loss
- self.log(
- "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True
- )
- return loss # Return the validation loss
-
- def on_predict_start(self):
- """Pad the input shape to be divisible by the downsampling factor.
- The inverse of this transform crops the prediction to original shape.
- """
- down_factor = 2**self.unet_model.num_blocks
- self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor))
-
- # Define the prediction step
- def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
- source = self._predict_pad(batch["source"]) # Pad the source
- logits = self._predict_pad.inverse(
- self.forward(source)
- ) # Predict and remove padding.
- prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities
- # Go from probabilities/one-hot encoded data to class labels.
- labels_pred = torch.argmax(
- prob_pred, dim=1, keepdim=True
- ) # Calculate the predicted labels
- # prob_chan = prob_pred[:, 2, :, :]
- # prob_chan = prob_chan.unsqueeze(1)
- return labels_pred # log the class predicted image
- # return prob_chan # log the probability predicted image
-
- def on_test_start(self):
- self.pred_cm = torch.zeros((2, 2))
- down_factor = 2**self.unet_model.num_blocks
- self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor))
-
- def test_step(self, batch: Sample):
- source = self._predict_pad(batch["source"]) # Pad the source
- logits = self._predict_pad.inverse(self.forward(source))
- prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities
- labels_pred = torch.argmax(
- prob_pred, dim=1, keepdim=True
- ) # Calculate the predicted labels
-
- target = self._predict_pad(batch["target"]) # Extract the target from the batch
- pred_cm = confusion_matrix_per_cell(
- target, labels_pred, num_classes=2
- ) # Calculate the confusion matrix per cell
- self.pred_cm += pred_cm # Append the confusion matrix to pred_cm
-
- self.logger.experiment.add_figure(
- "Confusion Matrix per Cell",
- plot_confusion_matrix(pred_cm, self.index_to_label_dict),
- self.current_epoch,
- )
-
- # Accumulate the confusion matrix at the end of test epoch and log.
- def on_test_end(self):
- confusion_matrix_sum = self.pred_cm
- self.logger.experiment.add_figure(
- "Confusion Matrix",
- plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict),
- self.current_epoch,
- )
-
- # Define what happens at the end of a training epoch
- def on_train_epoch_end(self):
- self._log_samples(
- "train_samples", self.training_step_outputs
- ) # Log the training samples
- self.training_step_outputs = [] # Reset the list of training step outputs
-
- # Define what happens at the end of a validation epoch
- def on_validation_epoch_end(self):
- self._log_samples(
- "val_samples", self.validation_step_outputs
- ) # Log the validation samples
- self.validation_step_outputs = [] # Reset the list of validation step outputs
-
- # Define a method to detach a sample
- def _detach_sample(self, imgs: Sequence[Tensor]):
- # Detach the images and convert them to numpy arrays
- num_samples = 3
- return [
- [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs]
- for i in range(num_samples)
- ]
-
- # Define a method to log samples
- def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
- images_grid = [] # Initialize the list of image grids
- for sample_images in imgs: # For each sample image
- images_row = [] # Initialize the list of image rows
- for i, image in enumerate(
- sample_images
- ): # For each image in the sample images
- cm_name = "gray" if i == 0 else "inferno" # Set the colormap name
- if image.ndim == 2: # If the image is 2D
- image = image[np.newaxis] # Add a new axis
- for channel in image: # For each channel in the image
- channel = rescale_intensity(
- channel, out_range=(0, 1)
- ) # Rescale the intensity of the channel
- render = get_cmap(cm_name)(channel, bytes=True)[
- ..., :3
- ] # Render the channel
- images_row.append(
- render
- ) # Append the render to the list of image rows
- images_grid.append(
- np.concatenate(images_row, axis=1)
- ) # Append the concatenated image rows to the list of image grids
- grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids
- # Log the image grid
- self.logger.experiment.add_image(
- key, grid, self.current_epoch, dataformats="HWC"
- )
-
-
-# %%
diff --git a/applications/infection_classification/classify_infection_covnext.py b/applications/infection_classification/classify_infection_covnext.py
deleted file mode 100644
index 397e822db..000000000
--- a/applications/infection_classification/classify_infection_covnext.py
+++ /dev/null
@@ -1,363 +0,0 @@
-# import torchview
-from typing import Literal, Sequence
-
-import cv2
-import lightning.pytorch as pl
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from matplotlib.cm import get_cmap
-from monai.transforms import DivisiblePad
-from skimage.exposure import rescale_intensity
-from skimage.measure import label, regionprops
-from torch import Tensor
-
-from viscy.data.hcs import Sample
-from viscy.translation.engine import VSUNet
-
-#
-# %% Methods to compute confusion matrix per cell using torchmetrics
-
-
-# The confusion matrix at the single-cell resolution.
-def confusion_matrix_per_cell(
- y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
-):
- """Compute confusion matrix per cell.
-
- Args:
- y_true (torch.Tensor): Ground truth label image (BXHXW).
- y_pred (torch.Tensor): Predicted label image (BXHXW).
- num_classes (int): Number of classes.
-
- Returns:
- torch.Tensor: Confusion matrix per cell (BXCXC).
- """
- # Convert the image class to the nuclei class
- confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes)
- confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell)
- return confusion_matrix_per_cell
-
-
-# These images can be logged with prediction.
-def compute_confusion_matrix(
- y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
-):
- """Convert the class of the image to the class of the nuclei.
-
- Args:
- label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation.
- num_classes (int): Number of classes.
-
- Returns:
- torch.Tensor: Label images with a consensus class at the centroid of nuclei.
- """
-
- batch_size = y_true.size(0)
- # find centroids of nuclei from y_true
- conf_mat = np.zeros((num_classes, num_classes))
- for i in range(batch_size):
- y_true_cpu = y_true[i].cpu().numpy()
- y_pred_cpu = y_pred[i].cpu().numpy()
- y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:])
- y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:])
- y_pred_resized = cv2.resize(
- y_pred_reshaped,
- dsize=y_true_reshaped.shape[::-1],
- interpolation=cv2.INTER_NEAREST,
- )
- y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0)
-
- # find objects in every image
- label_img = label(y_true_reshaped)
- regions = regionprops(label_img)
-
- # Find centroids, pixel coordinates from the ground truth.
- for region in regions:
- if region.area > 0:
- row, col = region.centroid
- pred_id = y_pred_resized[int(row), int(col)]
- test_id = y_true_reshaped[int(row), int(col)]
-
- if pred_id == 1 and test_id == 1:
- conf_mat[1, 1] += 1
- if pred_id == 1 and test_id == 2:
- conf_mat[0, 1] += 1
- if pred_id == 2 and test_id == 1:
- conf_mat[1, 0] += 1
- if pred_id == 2 and test_id == 2:
- conf_mat[0, 0] += 1
- # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction.
- # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction.
- return conf_mat
-
-
-def plot_confusion_matrix(confusion_matrix, index_to_label_dict):
- # Create a figure and axis to plot the confusion matrix
- fig, ax = plt.subplots()
-
- # Create a color heatmap for the confusion matrix
- cax = ax.matshow(confusion_matrix, cmap="viridis")
-
- # Create a colorbar and set the label
- index_to_label_dict = dict(
- enumerate(index_to_label_dict)
- ) # Convert list to dictionary
- fig.colorbar(cax, label="Frequency")
-
- # Set labels for the classes
- ax.set_xticks(np.arange(len(index_to_label_dict)))
- ax.set_yticks(np.arange(len(index_to_label_dict)))
- ax.set_xticklabels(index_to_label_dict.values(), rotation=45)
- ax.set_yticklabels(index_to_label_dict.values())
-
- # Set labels for the axes
- ax.set_xlabel("Predicted")
- ax.set_ylabel("True")
-
- # Add text annotations to the confusion matrix
- for i in range(len(index_to_label_dict)):
- for j in range(len(index_to_label_dict)):
- ax.text(
- j,
- i,
- str(int(confusion_matrix[i, j])),
- ha="center",
- va="center",
- color="white",
- )
-
- # plt.show(fig) # Show the figure
- return fig
-
-
-# Define a 25d unet model for infection classification as a lightning module.
-
-
-class SemanticSegUNet22D(pl.LightningModule):
- # Model for semantic segmentation.
- def __init__(
- self,
- in_channels: int, # Number of input channels
- out_channels: int, # Number of output channels
- lr: float = 1e-3, # Learning rate
- loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function
- schedule: Literal[
- "WarmupCosine", "Constant"
- ] = "Constant", # Learning rate schedule
- log_batches_per_epoch: int = 2, # Number of batches to log per epoch
- log_samples_per_batch: int = 2, # Number of samples to log per batch
- ckpt_path: str = None, # Path to the checkpoint
- ):
- super(SemanticSegUNet22D, self).__init__() # Call the superclass initializer
- # Initialize the UNet model
- self.unet_model = VSUNet(
- architecture="2.2D",
- model_config={
- "in_channels": in_channels,
- "out_channels": out_channels,
- "in_stack_depth": 5,
- "backbone": "convnextv2_tiny",
- "stem_kernel_size": (5, 4, 4),
- "decoder_mode": "pixelshuffle",
- "head_expansion_ratio": 4,
- },
- )
- if ckpt_path is not None:
- state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[
- "state_dict"
- ]
- state_dict.pop("loss_function.weight", None) # Remove the unexpected key
- self.load_state_dict(state_dict) # loading only weights
- self.lr = lr # Set the learning rate
- # Set the loss function to CrossEntropyLoss if none is provided
- self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss()
- self.schedule = schedule # Set the learning rate schedule
- self.log_batches_per_epoch = (
- log_batches_per_epoch # Set the number of batches to log per epoch
- )
- self.log_samples_per_batch = (
- log_samples_per_batch # Set the number of samples to log per batch
- )
- self.training_step_outputs = [] # Initialize the list of training step outputs
- self.validation_step_outputs = (
- []
- ) # Initialize the list of validation step outputs
-
- self.pred_cm = None # Initialize the confusion matrix
- self.index_to_label_dict = ["Infected", "Uninfected"]
-
- # Define the forward pass
- def forward(self, x):
- return self.unet_model(x) # Pass the input through the UNet model
-
- # Define the optimizer
- def configure_optimizers(self):
- optimizer = torch.optim.Adam(
- self.parameters(), lr=self.lr
- ) # Use the Adam optimizer
- return optimizer
-
- # Define the training step
- def training_step(self, batch: Sample, batch_idx: int):
- source = batch["source"] # Extract the source from the batch
- target = batch["target"] # Extract the target from the batch
- pred = self.forward(source) # Make a prediction using the source
- # Convert the target to one-hot encoding
- target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(
- 0, 4, 1, 2, 3
- )
- target_one_hot = target_one_hot.float() # Convert the target to float type
- train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss
- # Log the training step outputs if the batch index is less than the number of batches to log per epoch
- if batch_idx < self.log_batches_per_epoch:
- self.training_step_outputs.extend(
- self._detach_sample((source, target_one_hot, pred))
- )
- # Log the training loss
- self.log(
- "loss/train",
- train_loss,
- on_step=True,
- on_epoch=True,
- prog_bar=True,
- logger=True,
- sync_dist=True,
- )
- return train_loss # Return the training loss
-
- def validation_step(self, batch: Sample, batch_idx: int):
- source = batch["source"] # Extract the source from the batch
- target = batch["target"] # Extract the target from the batch
- pred = self.forward(source) # Make a prediction using the source
- # Convert the target to one-hot encoding
- target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(
- 0, 4, 1, 2, 3
- )
- target_one_hot = target_one_hot.float() # Convert the target to float type
- loss = self.loss_function(pred, target_one_hot) # Calculate the loss
- # Log the validation step outputs if the batch index is less than the number of batches to log per epoch
- if batch_idx < self.log_batches_per_epoch:
- self.validation_step_outputs.extend(
- self._detach_sample((source, target_one_hot, pred))
- )
- # Log the validation loss
- self.log(
- "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True
- )
- return loss # Return the validation loss
-
- def on_predict_start(self):
- """Pad the input shape to be divisible by the downsampling factor.
- The inverse of this transform crops the prediction to original shape.
- """
- down_factor = 2**self.unet_model.num_blocks
- self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor))
-
- # Define the prediction step
- def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
- source = self._predict_pad(batch["source"]) # Pad the source
- logits = self._predict_pad.inverse(
- self.forward(source)
- ) # Predict and remove padding.
- prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities
- # Go from probabilities/one-hot encoded data to class labels.
- labels_pred = torch.argmax(
- prob_pred, dim=1, keepdim=True
- ) # Calculate the predicted labels
- # prob_chan = prob_pred[:, 2, :, :]
- # prob_chan = prob_chan.unsqueeze(1)
- return labels_pred # log the class predicted image
- # return prob_chan # log the probability predicted image
-
- def on_test_start(self):
- self.pred_cm = torch.zeros((2, 2))
- down_factor = 2**self.unet_model.num_blocks
- self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor))
-
- def test_step(self, batch: Sample):
- source = self._predict_pad(batch["source"]) # Pad the source
- logits = self._predict_pad.inverse(self.forward(source))
- prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities
- labels_pred = torch.argmax(
- prob_pred, dim=1, keepdim=True
- ) # Calculate the predicted labels
-
- target = self._predict_pad(batch["target"]) # Extract the target from the batch
- pred_cm = confusion_matrix_per_cell(
- target, labels_pred, num_classes=2
- ) # Calculate the confusion matrix per cell
- self.pred_cm += pred_cm # Append the confusion matrix to pred_cm
-
- self.logger.experiment.add_figure(
- "Confusion Matrix per Cell",
- plot_confusion_matrix(pred_cm, self.index_to_label_dict),
- self.current_epoch,
- )
-
- # Accumulate the confusion matrix at the end of test epoch and log.
- def on_test_end(self):
- confusion_matrix_sum = self.pred_cm
- self.logger.experiment.add_figure(
- "Confusion Matrix",
- plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict),
- self.current_epoch,
- )
-
- # Define what happens at the end of a training epoch
- def on_train_epoch_end(self):
- self._log_samples(
- "train_samples", self.training_step_outputs
- ) # Log the training samples
- self.training_step_outputs = [] # Reset the list of training step outputs
-
- # Define what happens at the end of a validation epoch
- def on_validation_epoch_end(self):
- self._log_samples(
- "val_samples", self.validation_step_outputs
- ) # Log the validation samples
- self.validation_step_outputs = [] # Reset the list of validation step outputs
-
- # Define a method to detach a sample
- def _detach_sample(self, imgs: Sequence[Tensor]):
- # Detach the images and convert them to numpy arrays
- num_samples = 3
- return [
- [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs]
- for i in range(num_samples)
- ]
-
- # Define a method to log samples
- def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
- images_grid = [] # Initialize the list of image grids
- for sample_images in imgs: # For each sample image
- images_row = [] # Initialize the list of image rows
- for i, image in enumerate(
- sample_images
- ): # For each image in the sample images
- cm_name = "gray" if i == 0 else "inferno" # Set the colormap name
- if image.ndim == 2: # If the image is 2D
- image = image[np.newaxis] # Add a new axis
- for channel in image: # For each channel in the image
- channel = rescale_intensity(
- channel, out_range=(0, 1)
- ) # Rescale the intensity of the channel
- render = get_cmap(cm_name)(channel, bytes=True)[
- ..., :3
- ] # Render the channel
- images_row.append(
- render
- ) # Append the render to the list of image rows
- images_grid.append(
- np.concatenate(images_row, axis=1)
- ) # Append the concatenated image rows to the list of image grids
- grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids
- # Log the image grid
- self.logger.experiment.add_image(
- key, grid, self.current_epoch, dataformats="HWC"
- )
-
-
-# %%
diff --git a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings.py b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings.py
new file mode 100644
index 000000000..78a5a719b
--- /dev/null
+++ b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings.py
@@ -0,0 +1,766 @@
+# %%
+import ast
+import logging
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from plotting_utils import (
+ find_pattern_matches,
+ identify_lineages,
+ plot_pc_trajectories,
+ plot_reference_vs_full_lineages,
+)
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
+from tqdm import tqdm
+
+from viscy.data.triplet import TripletDataModule
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.evaluation.dimensionality_reduction import compute_pca
+
+logger = logging.getLogger("viscy")
+logger.setLevel(logging.INFO)
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.INFO)
+formatter = logging.Formatter("%(message)s") # Simplified format
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
+
+
+NAPARI = True
+if NAPARI:
+ import os
+
+ import napari
+
+ os.environ["DISPLAY"] = ":1"
+ viewer = napari.Viewer()
+# %%
+# Organelle and Phate aligned to infection
+
+input_data_path = Path(
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr"
+)
+tracks_path = Path(
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr"
+)
+infection_annotations_path = Path(
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/0-annotation/combined_annotations_n_tracks_infection.csv"
+)
+
+pretrain_features_root = Path(
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/prediction_pretrained_models"
+)
+# Phase n organelle
+# dynaclr_features_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/predictions/timeAware_2chan__ntxent_192patch_70ckpt_rev7_GT.zarr"
+
+# phase n sensor
+dynaclr_features_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/3-phenotyping/predictions_infection/2chan_192patch_100ckpt_timeAware_ntxent_GT.zarr"
+
+output_root = Path(
+ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/figure/SEC61B/model_comparison"
+)
+
+
+# Load embeddings
+imagenet_features_path = (
+ pretrain_features_root / "ImageNet/20241107_sensor_n_phase_imagenet.zarr"
+)
+openphenom_features_path = (
+ pretrain_features_root / "OpenPhenom/20241107_sensor_n_phase_openphenom.zarr"
+)
+
+dynaclr_embeddings = read_embedding_dataset(dynaclr_features_path)
+imagenet_embeddings = read_embedding_dataset(imagenet_features_path)
+openphenom_embeddings = read_embedding_dataset(openphenom_features_path)
+
+# Load infection annotations
+infection_annotations_df = pd.read_csv(infection_annotations_path)
+infection_annotations_df["fov_name"] = "/C/2/000001"
+
+process_embeddings = [
+ (dynaclr_embeddings, "dynaclr"),
+ (imagenet_embeddings, "imagenet"),
+ (openphenom_embeddings, "openphenom"),
+]
+
+
+output_root.mkdir(parents=True, exist_ok=True)
+# %%
+feature_df = dynaclr_embeddings["sample"].to_dataframe().reset_index(drop=True)
+
+# Logic to find lineages
+lineages = identify_lineages(feature_df)
+logger.info(f"Found {len(lineages)} distinct lineages")
+filtered_lineages = []
+min_timepoints = 20
+for fov_id, track_ids in lineages:
+ # Get all rows for this lineage
+ lineage_rows = feature_df[
+ (feature_df["fov_name"] == fov_id) & (feature_df["track_id"].isin(track_ids))
+ ]
+
+ # Count the total number of timepoints
+ total_timepoints = len(lineage_rows)
+
+ # Only keep lineages with at least min_timepoints
+ if total_timepoints >= min_timepoints:
+ filtered_lineages.append((fov_id, track_ids))
+logger.info(
+ f"Found {len(filtered_lineages)} lineages with at least {min_timepoints} timepoints"
+)
+
+# %%
+# Aligning condition embeddings to infection
+# OPTION 1: Use the infection annotations to find the reference lineage
+reference_lineage_fov = "/C/2/001000"
+reference_lineage_track_id = [129]
+reference_timepoints = [8, 70]  # sensor relocalization and partial remodelling
+
+# Option 2: from the filtered lineages find one from FOV C/2/000001
+reference_lineage_fov = "/C/2/000001"
+for fov_id, track_ids in filtered_lineages:
+ if reference_lineage_fov == fov_id:
+ break
+reference_lineage_track_id = track_ids
+reference_timepoints = [8, 70]  # sensor relocalization and partial remodelling
+
+# %%
+# Dictionary to store alignment results for comparison
+alignment_results = {}
+
+for embeddings, name in process_embeddings:
+ # Get the reference pattern from the current embedding space
+ reference_pattern = None
+ reference_lineage = []
+ for fov_id, track_ids in filtered_lineages:
+ if fov_id == reference_lineage_fov and all(
+ track_id in track_ids for track_id in reference_lineage_track_id
+ ):
+ logger.info(
+ f"Found reference pattern for {fov_id} {reference_lineage_track_id} using {name} embeddings"
+ )
+ reference_pattern = embeddings.sel(
+ sample=(fov_id, reference_lineage_track_id)
+ ).features.values
+ reference_lineage.append(reference_pattern)
+ break
+ if reference_pattern is None:
+ logger.info(f"Reference pattern not found for {name} embeddings. Skipping.")
+ continue
+ reference_pattern = np.concatenate(reference_lineage)
+ reference_pattern = reference_pattern[
+ reference_timepoints[0] : reference_timepoints[1]
+ ]
+
+ # Find all matches to the reference pattern
+ metric = "cosine"
+ all_match_positions = find_pattern_matches(
+ reference_pattern,
+ filtered_lineages,
+ embeddings,
+ window_step_fraction=0.1,
+ num_candidates=4,
+ method="bernd_clifford",
+ save_path=output_root / f"{name}_matching_lineages_{metric}.csv",
+ metric=metric,
+ )
+
+ # Store results for later comparison
+ alignment_results[name] = all_match_positions
+
+# Visualize warping paths in PC space instead of raw embedding dimensions
+for name, match_positions in alignment_results.items():
+ if match_positions is not None and not match_positions.empty:
+ # Call the new function from plotting_utils
+ plot_pc_trajectories(
+ reference_lineage_fov=reference_lineage_fov,
+ reference_lineage_track_id=reference_lineage_track_id,
+ reference_timepoints=reference_timepoints,
+ match_positions=match_positions,
+ embeddings_dataset=next(
+ emb for emb, emb_name in process_embeddings if emb_name == name
+ ),
+ filtered_lineages=filtered_lineages,
+ name=name,
+ save_path=output_root / f"{name}_pc_lineage_alignment.png",
+ )
+
+
+# %%
+# Compare DTW performance between embedding methods
+
+# Create a DataFrame to collect the alignment statistics for comparison
+match_data = []
+for name, match_positions in alignment_results.items():
+ if match_positions is not None and not match_positions.empty:
+ for i, row in match_positions.head(10).iterrows(): # Take top 10 matches
+ warping_path = (
+ ast.literal_eval(row["warp_path"])
+ if isinstance(row["warp_path"], str)
+ else row["warp_path"]
+ )
+ match_data.append(
+ {
+ "model": name,
+ "match_position": row["start_timepoint"],
+ "dtw_distance": row["distance"],
+ "path_skewness": row["skewness"],
+ "path_length": len(warping_path),
+ }
+ )
+
+comparison_df = pd.DataFrame(match_data)
+
+# Create visualizations to compare alignment quality
+plt.figure(figsize=(12, 10))
+
+# 1. Compare DTW distances
+plt.subplot(2, 2, 1)
+sns.boxplot(x="model", y="dtw_distance", data=comparison_df)
+plt.title("DTW Distance by Model")
+plt.ylabel("DTW Distance (lower is better)")
+
+# 2. Compare path skewness
+plt.subplot(2, 2, 2)
+sns.boxplot(x="model", y="path_skewness", data=comparison_df)
+plt.title("Path Skewness by Model")
+plt.ylabel("Skewness (lower is better)")
+
+# 3. Compare path lengths
+plt.subplot(2, 2, 3)
+sns.boxplot(x="model", y="path_length", data=comparison_df)
+plt.title("Warping Path Length by Model")
+plt.ylabel("Path Length")
+
+# 4. Scatterplot of distance vs skewness
+plt.subplot(2, 2, 4)
+scatter = sns.scatterplot(
+ x="dtw_distance", y="path_skewness", hue="model", data=comparison_df
+)
+plt.title("DTW Distance vs Path Skewness")
+plt.xlabel("DTW Distance")
+plt.ylabel("Path Skewness")
+plt.legend(title="Model")
+
+plt.tight_layout()
+plt.savefig(output_root / "dtw_alignment_comparison.png", dpi=300)
+plt.close()
+
+# %%
+# Analyze warping path step patterns for better understanding of alignment quality
+
+# Step pattern analysis
+step_pattern_counts = {
+ name: {"diagonal": 0, "horizontal": 0, "vertical": 0, "total": 0}
+ for name in alignment_results.keys()
+}
+
+for name, match_positions in alignment_results.items():
+ if match_positions is not None and not match_positions.empty:
+ # Get the top match
+ top_match = match_positions.iloc[0]
+ path = (
+ ast.literal_eval(top_match["warp_path"])
+ if isinstance(top_match["warp_path"], str)
+ else top_match["warp_path"]
+ )
+
+ # Count step types
+ for i in range(1, len(path)):
+ prev_i, prev_j = path[i - 1]
+ curr_i, curr_j = path[i]
+
+ step_i = curr_i - prev_i
+ step_j = curr_j - prev_j
+
+ if step_i == 1 and step_j == 1:
+ step_pattern_counts[name]["diagonal"] += 1
+ elif step_i == 1 and step_j == 0:
+ step_pattern_counts[name]["vertical"] += 1
+ elif step_i == 0 and step_j == 1:
+ step_pattern_counts[name]["horizontal"] += 1
+
+ step_pattern_counts[name]["total"] += 1
+
+# Convert to percentages
+for name in step_pattern_counts:
+ total = step_pattern_counts[name]["total"]
+ if total > 0:
+ for key in ["diagonal", "horizontal", "vertical"]:
+ step_pattern_counts[name][key] = (
+ step_pattern_counts[name][key] / total
+ ) * 100
+
+# Visualize step pattern distributions
+step_df = pd.DataFrame(
+ {
+ "model": [name for name in step_pattern_counts.keys() for _ in range(3)],
+ "step_type": ["diagonal", "horizontal", "vertical"] * len(step_pattern_counts),
+ "percentage": [
+ step_pattern_counts[name]["diagonal"] for name in step_pattern_counts.keys()
+ ]
+ + [
+ step_pattern_counts[name]["horizontal"]
+ for name in step_pattern_counts.keys()
+ ]
+ + [
+ step_pattern_counts[name]["vertical"] for name in step_pattern_counts.keys()
+ ],
+ }
+)
+
+plt.figure(figsize=(10, 6))
+sns.barplot(x="model", y="percentage", hue="step_type", data=step_df)
+plt.title("Step Pattern Distribution in Warping Paths")
+plt.ylabel("Percentage (%)")
+plt.savefig(output_root / "step_pattern_distribution.png", dpi=300)
+plt.close()
+
+# %%
+# Find all matches to the reference pattern
+MODEL = "openphenom"
+alignment_df_path = output_root / f"{MODEL}_matching_lineages_cosine.csv"
+alignment_df = pd.read_csv(alignment_df_path)
+
+# Get the top N aligned cells
+
+source_channels = [
+ "Phase3D",
+ "raw GFP EX488 EM525-45",
+ "raw mCherry EX561 EM600-37",
+]
+yx_patch_size = (192, 192)
+z_range = (10, 30)
+view_ref_sector_only = True  # boolean flag (was a stray 1-tuple `(True,)`)
+
+all_lineage_images = []
+all_aligned_stacks = []
+all_unaligned_stacks = []
+
+# Get aligned and unaligned stacks
+top_aligned_cells = alignment_df.head(5)
+napari_viewer = viewer if NAPARI else None
+# Plot the aligned and unaligned stacks
+for idx, row in tqdm(
+ top_aligned_cells.iterrows(),
+ total=len(top_aligned_cells),
+ desc="Aligning images",
+):
+ fov_name = row["fov_name"]
+ track_ids = ast.literal_eval(row["track_ids"])
+ warp_path = ast.literal_eval(row["warp_path"])
+ start_time = int(row["start_timepoint"])
+
+ print(f"Aligning images for {fov_name} with track ids: {track_ids}")
+ data_module = TripletDataModule(
+ data_path=input_data_path,
+ tracks_path=tracks_path,
+ source_channel=source_channels,
+ z_range=z_range,
+ initial_yx_patch_size=yx_patch_size,
+ final_yx_patch_size=yx_patch_size,
+ batch_size=1,
+ num_workers=12,
+ predict_cells=True,
+ include_fov_names=[fov_name] * len(track_ids),
+ include_track_ids=track_ids,
+ )
+ data_module.setup("predict")
+
+ # Get the images for the lineage
+ lineage_images = []
+ for batch in data_module.predict_dataloader():
+ image = batch["anchor"].numpy()[0]
+ lineage_images.append(image)
+
+ lineage_images = np.array(lineage_images)
+ all_lineage_images.append(lineage_images)
+ print(f"Lineage images shape: {np.array(lineage_images).shape}")
+
+ # Create an aligned stack based on the warping path
+ if view_ref_sector_only:
+ aligned_stack = np.zeros(
+ (len(reference_pattern),) + lineage_images.shape[-4:],
+ dtype=lineage_images.dtype,
+ )
+ unaligned_stack = np.zeros(
+ (len(reference_pattern),) + lineage_images.shape[-4:],
+ dtype=lineage_images.dtype,
+ )
+
+ # Map each reference timepoint to the corresponding lineage timepoint
+ for ref_idx in range(len(reference_pattern)):
+ # Find matches in warping path for this reference index
+ matches = [(i, q) for i, q in warp_path if i == ref_idx]
+ unaligned_stack[ref_idx] = lineage_images[ref_idx]
+ if matches:
+ # Get the corresponding lineage timepoint (first match if multiple)
+ print(f"Found match for ref idx: {ref_idx}")
+ match = matches[0]
+ query_idx = match[1]
+ lineage_idx = int(start_time + query_idx)
+ print(
+ f"Lineage index: {lineage_idx}, start time: {start_time}, query idx: {query_idx}, ref idx: {ref_idx}"
+ )
+ # Copy the image if it's within bounds
+ if 0 <= lineage_idx < len(lineage_images):
+ aligned_stack[ref_idx] = lineage_images[lineage_idx]
+ else:
+ # Find nearest valid timepoint if out of bounds
+ nearest_idx = min(max(0, lineage_idx), len(lineage_images) - 1)
+ aligned_stack[ref_idx] = lineage_images[nearest_idx]
+ else:
+ # If no direct match, find closest reference timepoint in warping path
+ print(f"No match found for ref idx: {ref_idx}")
+ all_ref_indices = [i for i, _ in warp_path]
+ if all_ref_indices:
+ closest_ref_idx = min(
+ all_ref_indices, key=lambda x: abs(x - ref_idx)
+ )
+ closest_matches = [
+ (i, q) for i, q in warp_path if i == closest_ref_idx
+ ]
+
+ if closest_matches:
+ closest_query_idx = closest_matches[0][1]
+ lineage_idx = int(start_time + closest_query_idx)
+
+ if 0 <= lineage_idx < len(lineage_images):
+ aligned_stack[ref_idx] = lineage_images[lineage_idx]
+ else:
+ # Bound to valid range
+ nearest_idx = min(
+ max(0, lineage_idx), len(lineage_images) - 1
+ )
+ aligned_stack[ref_idx] = lineage_images[nearest_idx]
+
+ all_aligned_stacks.append(aligned_stack)
+ all_unaligned_stacks.append(unaligned_stack)
+
+all_aligned_stacks = np.array(all_aligned_stacks)
+all_unaligned_stacks = np.array(all_unaligned_stacks)
+# %%
+if NAPARI:
+ for idx, row in tqdm(
+ top_aligned_cells.reset_index().iterrows(),
+ total=len(top_aligned_cells),
+ desc="Plotting aligned and unaligned stacks",
+ ):
+ fov_name = row["fov_name"]
+ # track_ids = ast.literal_eval(row["track_ids"])
+ track_ids = row["track_ids"]
+
+ aligned_stack = all_aligned_stacks[idx]
+ unaligned_stack = all_unaligned_stacks[idx]
+
+ unaligned_gfp_mip = np.max(unaligned_stack[:, 1, :, :], axis=1)
+ aligned_gfp_mip = np.max(aligned_stack[:, 1, :, :], axis=1)
+ unaligned_mcherry_mip = np.max(unaligned_stack[:, 2, :, :], axis=1)
+ aligned_mcherry_mip = np.max(aligned_stack[:, 2, :, :], axis=1)
+
+ z_slice = 15
+ unaligned_phase = unaligned_stack[:, 0, z_slice, :]
+ aligned_phase = aligned_stack[:, 0, z_slice, :]
+
+ # unaligned
+ viewer.add_image(
+ unaligned_gfp_mip,
+ name=f"unaligned_gfp_{fov_name}_{track_ids[0]}",
+ colormap="green",
+ contrast_limits=(106, 215),
+ )
+ viewer.add_image(
+ unaligned_mcherry_mip,
+ name=f"unaligned_mcherry_{fov_name}_{track_ids[0]}",
+ colormap="magenta",
+ contrast_limits=(106, 190),
+ )
+ viewer.add_image(
+ unaligned_phase,
+ name=f"unaligned_phase_{fov_name}_{track_ids[0]}",
+ colormap="gray",
+ contrast_limits=(-0.74, 0.4),
+ )
+ # aligned
+ viewer.add_image(
+ aligned_gfp_mip,
+ name=f"aligned_gfp_{fov_name}_{track_ids[0]}",
+ colormap="green",
+ contrast_limits=(106, 215),
+ )
+ viewer.add_image(
+ aligned_mcherry_mip,
+ name=f"aligned_mcherry_{fov_name}_{track_ids[0]}",
+ colormap="magenta",
+ contrast_limits=(106, 190),
+ )
+ viewer.add_image(
+ aligned_phase,
+ name=f"aligned_phase_{fov_name}_{track_ids[0]}",
+ colormap="gray",
+ contrast_limits=(-0.74, 0.4),
+ )
+ viewer.grid.enabled = True
+ viewer.grid.shape = (-1, 6)
+# %%
+# Evaluate model performance based on infection state warping accuracy
+# Check unique infection status values
+unique_infection_statuses = infection_annotations_df["infection_status"].unique()
+logger.info(f"Unique infection status values: {unique_infection_statuses}")
+
+# If "infected" is not in the unique values, this could explain zero precision/recall
+if "infected" not in unique_infection_statuses:
+ logger.warning('The label "infected" is not found in the infection_status column!')
+ logger.info(f"Using these values instead: {unique_infection_statuses}")
+
+ # If we need to map values, we could do it here
+ if len(unique_infection_statuses) >= 2:
+ logger.info(
+ f'Will treat "{unique_infection_statuses[1]}" as "infected" for metrics calculation'
+ )
+ infection_target_value = unique_infection_statuses[1]
+ else:
+ infection_target_value = unique_infection_statuses[0]
+else:
+ infection_target_value = "infected"
+
logger.info(f'Using "{infection_target_value}" as positive class for F1 calculation')

# Check if the reference track is in the annotations
logger.info(
    f"Looking for infection annotations for reference lineage: {reference_lineage_fov}, tracks: {reference_lineage_track_id}"
)
print(f"Sample of infection_annotations_df: {infection_annotations_df.head()}")

# Build a {timepoint: infection_status} map for the reference lineage by
# merging the annotations of every track in the lineage.
reference_infection_states = {}
for track_id in reference_lineage_track_id:
    reference_annotations = infection_annotations_df[
        (infection_annotations_df["fov_name"] == reference_lineage_fov)
        & (infection_annotations_df["track_id"] == track_id)
    ]

    # Add annotations for this reference track
    annotation_count = len(reference_annotations)
    logger.info(f"Found {annotation_count} annotations for track {track_id}")
    if annotation_count > 0:
        print(
            f"Sample annotations for track {track_id}: {reference_annotations.head()}"
        )

    for _, row in reference_annotations.iterrows():
        # NOTE(review): a later track overwrites an earlier one at the same t;
        # confirm timepoints do not overlap between tracks of a lineage.
        reference_infection_states[row["t"]] = row["infection_status"]

if reference_infection_states:
    logger.info(
        f"Total reference timepoints with infection status: {len(reference_infection_states)}"
    )
    reference_t_range = range(reference_timepoints[0], reference_timepoints[1])
    reference_gt_states = [
        reference_infection_states.get(t, "unknown") for t in reference_t_range
    ]
    logger.info(f"Reference track infection states: {reference_gt_states[:5]}...")

    # Evaluate warping accuracy for each model
    model_performance = []

    for name, match_positions in alignment_results.items():
        if match_positions is not None and not match_positions.empty:
            # Confusion-matrix accumulators pooled over all analyzed lineages.
            total_correct = 0
            total_predictions = 0
            true_positives = 0
            false_positives = 0
            false_negatives = 0

            # Analyze top alignments for this model
            alignment_details = []
            for i, row in match_positions.head(10).iterrows():
                fov_name = row["fov_name"]
                track_ids = row[
                    "track_ids"
                ]  # This is already a list of track IDs for the lineage
                # warp_path may have round-tripped through CSV as a string literal.
                warp_path = (
                    ast.literal_eval(row["warp_path"])
                    if isinstance(row["warp_path"], str)
                    else row["warp_path"]
                )
                start_time = int(row["start_timepoint"])

                # Get annotations for all tracks in this lineage
                track_infection_states = {}
                for track_id in track_ids:
                    track_annotations = infection_annotations_df[
                        (infection_annotations_df["fov_name"] == fov_name)
                        & (infection_annotations_df["track_id"] == track_id)
                    ]

                    # Add annotations for this track to the combined dictionary
                    for _, annotation_row in track_annotations.iterrows():
                        # Use t + track-specific offset if needed to handle timepoint overlaps between tracks
                        track_infection_states[annotation_row["t"]] = annotation_row[
                            "infection_status"
                        ]

                # Only proceed if we found annotations for at least one track
                if track_infection_states:
                    # For each reference timepoint, check if the warped timepoint maintains the infection state
                    track_correct = 0
                    track_predictions = 0
                    track_tp = 0
                    track_fp = 0
                    track_fn = 0

                    for ref_idx, query_idx in warp_path:
                        # Map to actual timepoints
                        ref_t = reference_timepoints[0] + ref_idx
                        query_t = start_time + query_idx

                        # Get ground truth infection states
                        ref_state = reference_infection_states.get(ref_t, "unknown")
                        query_state = track_infection_states.get(query_t, "unknown")

                        # Skip unknown states
                        if ref_state != "unknown" and query_state != "unknown":
                            track_predictions += 1

                            # Count correct alignments
                            if ref_state == query_state:
                                track_correct += 1

                            # Calculate F1 score components for "infected" state
                            if (
                                ref_state == infection_target_value
                                and query_state == infection_target_value
                            ):
                                track_tp += 1
                            elif (
                                ref_state != infection_target_value
                                and query_state == infection_target_value
                            ):
                                track_fp += 1
                            elif (
                                ref_state == infection_target_value
                                and query_state != infection_target_value
                            ):
                                track_fn += 1

                    # Calculate track-specific metrics
                    if track_predictions > 0:
                        track_accuracy = track_correct / track_predictions
                        track_precision = (
                            track_tp / (track_tp + track_fp)
                            if (track_tp + track_fp) > 0
                            else 0
                        )
                        track_recall = (
                            track_tp / (track_tp + track_fn)
                            if (track_tp + track_fn) > 0
                            else 0
                        )
                        track_f1 = (
                            2
                            * (track_precision * track_recall)
                            / (track_precision + track_recall)
                            if (track_precision + track_recall) > 0
                            else 0
                        )

                        alignment_details.append(
                            {
                                "fov_name": fov_name,
                                "track_ids": track_ids,
                                "accuracy": track_accuracy,
                                "precision": track_precision,
                                "recall": track_recall,
                                "f1_score": track_f1,
                                "correct": track_correct,
                                "total": track_predictions,
                            }
                        )

                    # Add to model totals
                    total_correct += track_correct
                    total_predictions += track_predictions
                    true_positives += track_tp
                    false_positives += track_fp
                    false_negatives += track_fn

            # Calculate metrics (micro-averaged over all lineages of this model)
            accuracy = total_correct / total_predictions if total_predictions > 0 else 0
            precision = (
                true_positives / (true_positives + false_positives)
                if (true_positives + false_positives) > 0
                else 0
            )
            recall = (
                true_positives / (true_positives + false_negatives)
                if (true_positives + false_negatives) > 0
                else 0
            )
            f1 = (
                2 * (precision * recall) / (precision + recall)
                if (precision + recall) > 0
                else 0
            )

            # Store alignment details for this model
            if alignment_details:
                alignment_details_df = pd.DataFrame(alignment_details)
                print(f"\nDetailed alignment results for {name}:")
                print(alignment_details_df)
                alignment_details_df.to_csv(
                    output_root / f"{name}_alignment_details.csv", index=False
                )

            model_performance.append(
                {
                    "model": name,
                    "accuracy": accuracy,
                    "precision": precision,
                    "recall": recall,
                    "f1_score": f1,
                    "total_predictions": total_predictions,
                }
            )

    # Create performance DataFrame and visualize
    performance_df = pd.DataFrame(model_performance)
    print(performance_df)

    # Plot performance metrics
    plt.figure(figsize=(12, 8))

    # Accuracy plot
    plt.subplot(2, 2, 1)
    sns.barplot(x="model", y="accuracy", data=performance_df)
    plt.title("Infection State Warping Accuracy")
    plt.ylabel("Accuracy")

    # Precision plot
    plt.subplot(2, 2, 2)
    sns.barplot(x="model", y="precision", data=performance_df)
    plt.title("Precision for Infected State")
    plt.ylabel("Precision")

    # Recall plot
    plt.subplot(2, 2, 3)
    sns.barplot(x="model", y="recall", data=performance_df)
    plt.title("Recall for Infected State")
    plt.ylabel("Recall")

    # F1 score plot
    plt.subplot(2, 2, 4)
    sns.barplot(x="model", y="f1_score", data=performance_df)
    plt.title("F1 Score for Infected State")
    plt.ylabel("F1 Score")

    plt.tight_layout()
    # plt.savefig(output_root / "infection_state_warping_performance.png", dpi=300)
    # plt.close()
else:
    logger.warning("Reference track annotations not found in infection_annotations_df")
+
+# %%
diff --git a/applications/pseudotime_analysis/evaluation/dtw_compare_openphenom.py b/applications/pseudotime_analysis/evaluation/dtw_compare_openphenom.py
new file mode 100644
index 000000000..595474801
--- /dev/null
+++ b/applications/pseudotime_analysis/evaluation/dtw_compare_openphenom.py
@@ -0,0 +1,162 @@
+# %%
+import sys
+from pathlib import Path
+from typing import Literal
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import phate
+import seaborn as sns
+import torch
+import xarray as xr
+from sklearn.decomposition import PCA
+from sklearn.linear_model import LogisticRegression
+from sklearn.preprocessing import StandardScaler
+from tqdm import tqdm
+
+# Load model directly
+from transformers import AutoModel
+
+from viscy.data.triplet import TripletDataModule
+from viscy.transforms import NormalizeSampled, ScaleIntensityRangePercentilesd
+
+
+# %% function to compute phate from embedding values
def compute_phate(embeddings, n_components=2, knn=15, decay=0.5, **phate_kwargs):
    """Fit a PHATE model on *embeddings* and return the low-dimensional coordinates.

    Extra keyword arguments are forwarded to ``phate.PHATE``.
    """
    model = phate.PHATE(
        n_components=n_components, knn=knn, decay=decay, **phate_kwargs
    )
    return model.fit_transform(embeddings)
+
+
+# %% OpenPhenom Wrapper
class OpenPhenomWrapper:
    """Thin wrapper around the pretrained OpenPhenom model from HuggingFace."""

    def __init__(self):
        # trust_remote_code pulls in the maes_microscopy package; a missing
        # dependency surfaces as ImportError here.
        try:
            self.model = AutoModel.from_pretrained(
                "recursionpharma/OpenPhenom", trust_remote_code=True
            )
            self.model.eval()
            self.model.to("cuda")
        except ImportError:
            raise ImportError(
                "Please install the OpenPhenom dependencies: "
                "pip install git+https://github.com/recursionpharma/maes_microscopy.git"
            )

    def extract_features(self, x):
        """Extract features from the input images.

        Args:
            x: Input tensor of shape [B, C, D, H, W] or [B, C, H, W]

        Returns:
            Features of shape [B, 384]
        """
        # OpenPhenom consumes 4D [B, C, H, W]; for 5D input keep the middle
        # depth slice.
        if x.dim() == 5:
            x = x[:, :, x.shape[2] // 2, :, :]

        # OpenPhenom expects uint8 pixels, so rescale to [0, 255] first.
        # NOTE(review): min/max are taken over the whole batch, not per
        # image — confirm this is the intended normalization.
        if x.dtype != torch.uint8:
            span = x.max() - x.min()
            x = ((x - x.min()) / (span + 1e-10) * 255).clamp(0, 255).to(torch.uint8)

        self.model.return_channelwise_embeddings = False
        with torch.no_grad():
            return self.model.predict(x)
+
+
# %% Initialize OpenPhenom model
# Download/load the pretrained OpenPhenom model onto the GPU.
print("Loading OpenPhenom model...")
openphenom = OpenPhenomWrapper()
# For infection dataset with phase and RFP
print("Setting up data module...")
dm = TripletDataModule(
    data_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr",
    tracks_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr",
    source_channel=["Phase3D", "raw GFP EX488 EM525-45"],
    batch_size=32,  # Lower batch size for OpenPhenom which is larger
    num_workers=10,
    z_range=(25, 40),
    initial_yx_patch_size=(192, 192),
    final_yx_patch_size=(192, 192),
    normalizations=[
        ScaleIntensityRangePercentilesd(
            keys=["raw GFP EX488 EM525-45"], lower=50, upper=99, b_min=0.0, b_max=1.0
        ),
        ScaleIntensityRangePercentilesd(
            keys=["Phase3D"], lower=50, upper=99, b_min=0.0, b_max=1.0
        ),
    ],
)
dm.prepare_data()
dm.setup("predict")
# %%
print("Extracting features...")
features = []
indices = []

with torch.inference_mode():
    for batch in tqdm(dm.predict_dataloader()):
        # Get both channels and handle dimensions properly
        phase = batch["anchor"][:, 0]  # Phase channel
        rfp = batch["anchor"][:, 1]  # RFP channel
        # Max-project fluorescence over Z; take the mid Z-slice of phase.
        rfp = torch.max(rfp, dim=1).values
        Z = phase.shape[-3]
        phase = phase[:, Z // 2]
        img = torch.stack([phase, rfp], dim=1).to("cuda")

        # Extract features using OpenPhenom
        batch_features = openphenom.extract_features(img)
        features.append(batch_features.cpu())
        indices.append(batch["index"])

# %%
print("Processing features...")
pooled = torch.cat(features).numpy()
tracks = pd.concat([pd.DataFrame(idx) for idx in indices])
print("Computing PCA and PHATE...")
# NOTE(review): PCA is fit on standardized features but PHATE is fit on the
# raw pooled features — confirm this asymmetry is intended.
scaled_features = StandardScaler().fit_transform(pooled)
pca = PCA(n_components=8)
pca_features = pca.fit_transform(scaled_features)

phate_embedding = compute_phate(
    embeddings=pooled,
    n_components=2,
    knn=5,
    decay=40,
    n_jobs=15,
)
# %% Add features to dataframe
for i, feature in enumerate(pooled.T):
    tracks[f"feature_{i}"] = feature
# Add PCA features to dataframe
for i, feature in enumerate(pca_features.T):
    tracks[f"pca_{i}"] = feature
for i, feature in enumerate(phate_embedding.T):
    tracks[f"phate_{i}"] = feature

# %% Save the extracted features

output_path = Path(
    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/figure/SEC61B/openphenom_pretrained_analysis"
)
output_path.mkdir(parents=True, exist_ok=True)
output_embeddings_file = (
    output_path / "openphenom_pretrained_features_SEC61B_n_Phase.csv"
)
print(f"Saving features to {output_embeddings_file}")
tracks.to_csv(output_embeddings_file, index=False)
+
+# %%
diff --git a/applications/pseudotime_analysis/plotting_utils.py b/applications/pseudotime_analysis/plotting_utils.py
new file mode 100644
index 000000000..c5e0ba8c0
--- /dev/null
+++ b/applications/pseudotime_analysis/plotting_utils.py
@@ -0,0 +1,1225 @@
+# Plotting utils
+import logging
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+from viscy.data.triplet import TripletDataModule
+
+logger = logging.getLogger(__name__)
+
+
def plot_reference_aligned_average(
    reference_pattern: np.ndarray,
    top_aligned_cells: pd.DataFrame,
    embeddings_dataset: xr.Dataset,
    save_path: str | None = None,
) -> np.ndarray:
    """
    Plot the reference embedding, aligned embeddings, and average aligned embedding.

    Args:
        reference_pattern: The reference pattern embeddings (time, dim)
        top_aligned_cells: DataFrame with alignment information
            (columns: fov_name, track_ids, warp_path, start_timepoint)
        embeddings_dataset: Dataset containing embeddings
        save_path: Path to save the figure (optional)

    Returns:
        The mean of the warped lineage embeddings, same shape as
        ``reference_pattern``.
    """
    plt.figure(figsize=(15, 10))

    reference_embeddings = reference_pattern

    # Warp every matched lineage onto the reference time axis.
    all_aligned_embeddings = []
    for _, row in top_aligned_cells.iterrows():
        fov_name = row["fov_name"]
        track_ids = row["track_ids"]
        warp_path = row["warp_path"]
        start_time = int(row["start_timepoint"])

        # Concatenate the per-track embeddings of the lineage in order.
        lineage_embeddings = np.concatenate(
            [
                embeddings_dataset.sel(sample=(fov_name, track_id)).features.values
                for track_id in track_ids
            ],
            axis=0,
        )

        aligned_embeddings = np.zeros(
            (len(reference_pattern), lineage_embeddings.shape[1]),
            dtype=lineage_embeddings.dtype,
        )

        # Map each reference index to a lineage index via the warping path
        # (a later pair for the same ref index overwrites an earlier one).
        ref_to_lineage = {}
        for ref_idx, query_idx in warp_path:
            lineage_idx = int(start_time + query_idx)
            if 0 <= lineage_idx < len(lineage_embeddings):
                ref_to_lineage[ref_idx] = lineage_idx

        # Fill aligned embeddings; unmapped indices borrow the nearest mapped one.
        for ref_idx in range(len(reference_pattern)):
            if ref_idx in ref_to_lineage:
                aligned_embeddings[ref_idx] = lineage_embeddings[
                    ref_to_lineage[ref_idx]
                ]
            elif ref_to_lineage:
                closest_ref_idx = min(
                    ref_to_lineage.keys(), key=lambda x: abs(x - ref_idx)
                )
                aligned_embeddings[ref_idx] = lineage_embeddings[
                    ref_to_lineage[closest_ref_idx]
                ]

        all_aligned_embeddings.append(aligned_embeddings)

    # Calculate average aligned embeddings
    average_aligned_embeddings = np.mean(all_aligned_embeddings, axis=0)

    # Plot dimensions 0 and 1 on stacked subplots.
    plt.subplot(2, 1, 1)
    _plot_aligned_dimension(
        0, reference_embeddings, all_aligned_embeddings, average_aligned_embeddings
    )
    plt.subplot(2, 1, 2)
    _plot_aligned_dimension(
        1, reference_embeddings, all_aligned_embeddings, average_aligned_embeddings
    )

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300)
    plt.show()

    return average_aligned_embeddings


def _plot_aligned_dimension(
    dim, reference_embeddings, all_aligned_embeddings, average_aligned_embeddings
):
    """Plot one embedding dimension (reference, per-lineage aligned curves, and
    their average) into the current axes."""
    plt.plot(
        range(len(reference_embeddings)),
        reference_embeddings[:, dim],
        label="Reference",
        color="black",
        linewidth=3,
    )
    for i, aligned in enumerate(all_aligned_embeddings):
        plt.plot(
            range(len(aligned)),
            aligned[:, dim],
            label=f"Aligned {i}",
            alpha=0.4,
            linestyle="--",
        )
    plt.plot(
        range(len(average_aligned_embeddings)),
        average_aligned_embeddings[:, dim],
        label="Average Aligned",
        color="red",
        linewidth=2,
    )
    plt.title(f"Dimension {dim}: Reference, Aligned, and Average Embeddings")
    plt.xlabel("Reference Time Index")
    plt.ylabel("Embedding Value")
    plt.legend()
    plt.grid(True, alpha=0.3)
+
+
def plot_reference_vs_full_lineages(
    reference_pattern: np.ndarray,
    top_aligned_cells: pd.DataFrame,
    embeddings_dataset: xr.Dataset,
    save_path: str | None = None,
) -> None:
    """
    Visualize where the reference pattern matches in each full lineage.

    Draws one row of subplots per aligned lineage (embedding dimensions 0 and
    1), highlighting the DTW-matched section in red. Shows the figure and
    returns None.

    Args:
        reference_pattern: The reference pattern embeddings
        top_aligned_cells: DataFrame with alignment information
        embeddings_dataset: Dataset containing embeddings
        save_path: Path to save the figure (optional)
    """
    plt.figure(figsize=(15, 15))

    # First, plot the reference pattern for comparison
    plt.subplot(len(top_aligned_cells) + 1, 2, 1)
    plt.plot(
        range(len(reference_pattern)),
        reference_pattern[:, 0],
        label="Reference Dim 0",
        color="black",
        linewidth=2,
    )
    plt.title("Reference Pattern - Dimension 0")
    plt.xlabel("Time Index")
    plt.ylabel("Embedding Value")
    plt.grid(True, alpha=0.3)
    plt.legend()

    plt.subplot(len(top_aligned_cells) + 1, 2, 2)
    plt.plot(
        range(len(reference_pattern)),
        reference_pattern[:, 1],
        label="Reference Dim 1",
        color="black",
        linewidth=2,
    )
    plt.title("Reference Pattern - Dimension 1")
    plt.xlabel("Time Index")
    plt.ylabel("Embedding Value")
    plt.grid(True, alpha=0.3)
    plt.legend()

    # Then plot each lineage with the matched section highlighted
    for i, (_, row) in enumerate(top_aligned_cells.iterrows()):
        fov_name = row["fov_name"]
        track_ids = row["track_ids"]
        warp_path = row["warp_path"]
        start_time = row["start_timepoint"]
        distance = row["distance"]

        # Get the full lineage embeddings
        # NOTE(review): passes the whole track_ids list to .sel, unlike the
        # per-track concatenation used elsewhere in this module — confirm this
        # yields the concatenated lineage in track order.
        lineage_embeddings = embeddings_dataset.sel(
            sample=(fov_name, track_ids)
        ).features.values

        # Create a subplot for dimension 0
        plt.subplot(len(top_aligned_cells) + 1, 2, 2 * i + 3)

        # Plot the full lineage
        plt.plot(
            range(len(lineage_embeddings)),
            lineage_embeddings[:, 0],
            label="Full Lineage",
            color="blue",
            alpha=0.7,
        )

        # Highlight the matched section
        matched_indices = set()
        for _, query_idx in warp_path:
            lineage_idx = int(start_time + query_idx)
            if 0 <= lineage_idx < len(lineage_embeddings):
                matched_indices.add(lineage_idx)

        matched_indices = sorted(list(matched_indices))
        if matched_indices:
            plt.plot(
                matched_indices,
                [lineage_embeddings[idx, 0] for idx in matched_indices],
                "ro-",
                label=f"Matched Section (DTW dist={distance:.2f})",
                linewidth=2,
            )

            # Add vertical lines to mark the start and end of the matched section
            plt.axvline(x=min(matched_indices), color="red", linestyle="--", alpha=0.5)
            plt.axvline(x=max(matched_indices), color="red", linestyle="--", alpha=0.5)

            # Add text labels
            plt.text(
                min(matched_indices),
                min(lineage_embeddings[:, 0]),
                f"Start: {min(matched_indices)}",
                color="red",
                fontsize=10,
            )
            plt.text(
                max(matched_indices),
                min(lineage_embeddings[:, 0]),
                f"End: {max(matched_indices)}",
                color="red",
                fontsize=10,
            )

        plt.title(f"Lineage {i} ({fov_name}) Track {track_ids[0]} - Dimension 0")
        plt.xlabel("Lineage Time")
        plt.ylabel("Embedding Value")
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Create a subplot for dimension 1
        plt.subplot(len(top_aligned_cells) + 1, 2, 2 * i + 4)

        # Plot the full lineage
        plt.plot(
            range(len(lineage_embeddings)),
            lineage_embeddings[:, 1],
            label="Full Lineage",
            color="green",
            alpha=0.7,
        )

        # Highlight the matched section (reuses matched_indices from dim 0)
        if matched_indices:
            plt.plot(
                matched_indices,
                [lineage_embeddings[idx, 1] for idx in matched_indices],
                "ro-",
                label=f"Matched Section (DTW dist={distance:.2f})",
                linewidth=2,
            )

            # Add vertical lines to mark the start and end of the matched section
            plt.axvline(x=min(matched_indices), color="red", linestyle="--", alpha=0.5)
            plt.axvline(x=max(matched_indices), color="red", linestyle="--", alpha=0.5)

            # Add text labels
            plt.text(
                min(matched_indices),
                min(lineage_embeddings[:, 1]),
                f"Start: {min(matched_indices)}",
                color="red",
                fontsize=10,
            )
            plt.text(
                max(matched_indices),
                min(lineage_embeddings[:, 1]),
                f"End: {max(matched_indices)}",
                color="red",
                fontsize=10,
            )

        plt.title(f"Lineage {i} ({fov_name}) - Dimension 1")
        plt.xlabel("Lineage Time")
        plt.ylabel("Embedding Value")
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300)
    plt.show()
+
+
def align_and_average_embeddings(
    reference_pattern: np.ndarray,
    top_aligned_cells: pd.DataFrame,
    embeddings_dataset: "xr.Dataset",
    use_median: bool = False,
) -> np.ndarray:
    """
    Align embeddings from multiple lineages to a reference pattern and compute their average.

    Args:
        reference_pattern: The reference pattern embeddings (time, dim)
        top_aligned_cells: DataFrame with alignment information
            (columns: fov_name, track_ids, warp_path, start_timepoint)
        embeddings_dataset: Dataset containing embeddings
        use_median: If True, use median instead of mean for averaging

    Returns:
        The average (or median) aligned embeddings, same shape as
        ``reference_pattern``.
    """
    all_aligned_embeddings = []

    for _, row in top_aligned_cells.iterrows():
        fov_name = row["fov_name"]
        track_ids = row["track_ids"]
        warp_path = row["warp_path"]
        start_time = int(row["start_timepoint"])

        # Concatenate the per-track embeddings of the lineage in order.
        lineage_embeddings = np.concatenate(
            [
                embeddings_dataset.sel(sample=(fov_name, track_id)).features.values
                for track_id in track_ids
            ],
            axis=0,
        )

        aligned_segment = np.zeros_like(reference_pattern)

        # Map each reference timepoint to the corresponding lineage timepoint
        # (a later warp pair for the same ref index overwrites an earlier one).
        ref_to_lineage = {}
        for ref_idx, query_idx in warp_path:
            lineage_idx = int(start_time + query_idx)
            if 0 <= lineage_idx < len(lineage_embeddings):
                ref_to_lineage[ref_idx] = lineage_idx
                aligned_segment[ref_idx] = lineage_embeddings[lineage_idx]

        # Fill in missing values by using the closest available reference index
        for ref_idx in range(len(reference_pattern)):
            if ref_idx not in ref_to_lineage and ref_to_lineage:
                closest_ref_idx = min(
                    ref_to_lineage.keys(), key=lambda x: abs(x - ref_idx)
                )
                aligned_segment[ref_idx] = lineage_embeddings[
                    ref_to_lineage[closest_ref_idx]
                ]

        all_aligned_embeddings.append(aligned_segment)

    stacked = np.array(all_aligned_embeddings)

    # Compute average or median across lineages
    if use_median:
        return np.median(stacked, axis=0)
    return np.mean(stacked, axis=0)
+
+
def align_image_stacks(
    reference_pattern: np.ndarray,
    top_aligned_cells: pd.DataFrame,
    input_data_path: Path,
    tracks_path: Path,
    source_channels: list[str],
    yx_patch_size: tuple[int, int] = (192, 192),
    z_range: tuple[int, int] = (0, 1),
    view_ref_sector_only: bool = True,
    napari_viewer=None,
) -> tuple[list, list]:
    """
    Align image stacks from multiple lineages to a reference pattern.

    Loads the image patches for each aligned lineage through a
    TripletDataModule and, per lineage, either resamples them along the DTW
    warping path onto the reference time axis or simply crops from the match
    start. If a napari viewer is given, each aligned stack is added as a layer.

    Args:
        reference_pattern: The reference pattern embeddings
        top_aligned_cells: DataFrame with alignment information
        input_data_path: Path to the input data
        tracks_path: Path to the tracks data
        source_channels: List of channels to include
        yx_patch_size: Patch size for images
        z_range: Z-range to include
        view_ref_sector_only: If True, only show the section that matches the reference pattern
        napari_viewer: Optional napari viewer for visualization

    Returns:
        Tuple of (all_lineage_images, all_aligned_stacks)
    """
    from tqdm import tqdm

    all_lineage_images = []
    all_aligned_stacks = []

    for idx, row in tqdm(
        top_aligned_cells.iterrows(),
        total=len(top_aligned_cells),
        desc="Aligning images",
    ):
        fov_name = row["fov_name"]
        track_ids = row["track_ids"]
        warp_path = row["warp_path"]
        start_time = int(row["start_timepoint"])

        print(f"Aligning images for {fov_name} with track ids: {track_ids}")
        # One datamodule per lineage, restricted to exactly these tracks.
        data_module = TripletDataModule(
            data_path=input_data_path,
            tracks_path=tracks_path,
            source_channel=source_channels,
            z_range=z_range,
            initial_yx_patch_size=yx_patch_size,
            final_yx_patch_size=yx_patch_size,
            batch_size=1,
            num_workers=12,
            predict_cells=True,
            include_fov_names=[fov_name] * len(track_ids),
            include_track_ids=track_ids,
        )
        data_module.setup("predict")

        # Get the images for the lineage (batch_size=1: one timepoint per batch)
        lineage_images = []
        for batch in data_module.predict_dataloader():
            image = batch["anchor"].numpy()[0]
            lineage_images.append(image)

        lineage_images = np.array(lineage_images)
        all_lineage_images.append(lineage_images)
        print(f"Lineage images shape: {np.array(lineage_images).shape}")

        # Create an aligned stack based on the warping path
        if view_ref_sector_only:
            aligned_stack = np.zeros(
                (len(reference_pattern),) + lineage_images.shape[-4:],
                dtype=lineage_images.dtype,
            )

            # Map each reference timepoint to the corresponding lineage timepoint
            for ref_idx in range(len(reference_pattern)):
                # Find matches in warping path for this reference index
                matches = [(i, q) for i, q in warp_path if i == ref_idx]

                if matches:
                    # Get the corresponding lineage timepoint (first match if multiple)
                    print(f"Found match for ref idx: {ref_idx}")
                    match = matches[0]
                    query_idx = match[1]
                    lineage_idx = int(start_time + query_idx)
                    print(
                        f"Lineage index: {lineage_idx}, start time: {start_time}, query idx: {query_idx}, ref idx: {ref_idx}"
                    )
                    # Copy the image if it's within bounds
                    if 0 <= lineage_idx < len(lineage_images):
                        aligned_stack[ref_idx] = lineage_images[lineage_idx]
                    else:
                        # Find nearest valid timepoint if out of bounds
                        nearest_idx = min(max(0, lineage_idx), len(lineage_images) - 1)
                        aligned_stack[ref_idx] = lineage_images[nearest_idx]
                else:
                    # If no direct match, find closest reference timepoint in warping path
                    print(f"No match found for ref idx: {ref_idx}")
                    all_ref_indices = [i for i, _ in warp_path]
                    if all_ref_indices:
                        closest_ref_idx = min(
                            all_ref_indices, key=lambda x: abs(x - ref_idx)
                        )
                        closest_matches = [
                            (i, q) for i, q in warp_path if i == closest_ref_idx
                        ]

                        if closest_matches:
                            closest_query_idx = closest_matches[0][1]
                            lineage_idx = int(start_time + closest_query_idx)

                            if 0 <= lineage_idx < len(lineage_images):
                                aligned_stack[ref_idx] = lineage_images[lineage_idx]
                            else:
                                # Bound to valid range
                                nearest_idx = min(
                                    max(0, lineage_idx), len(lineage_images) - 1
                                )
                                aligned_stack[ref_idx] = lineage_images[nearest_idx]

            all_aligned_stacks.append(aligned_stack)
            if napari_viewer:
                napari_viewer.add_image(
                    aligned_stack,
                    name=f"Aligned_{fov_name}_track_{track_ids[0]}",
                    channel_axis=1,
                )
        else:
            # View the whole lineage shifted by the start time
            start_idx = int(start_time)
            aligned_stack = lineage_images[start_idx:]
            all_aligned_stacks.append(aligned_stack)
            if napari_viewer:
                napari_viewer.add_image(
                    aligned_stack,
                    name=f"Aligned_{fov_name}_track_{track_ids[0]}",
                    channel_axis=1,
                )

    return all_lineage_images, all_aligned_stacks
+
+
def find_pattern_matches(
    reference_pattern: np.ndarray,
    filtered_lineages: list[tuple[str, list[int]]],
    embeddings_dataset: xr.Dataset,
    window_step_fraction: float = 0.25,
    num_candidates: int = 3,
    max_distance: float = float("inf"),
    max_skew: float = 0.8,
    save_path: str | None = None,
    method: str = "bernd_clifford",
    normalize: bool = True,
    metric: str = "euclidean",
) -> pd.DataFrame:
    """
    Find the best matches of a reference pattern in multiple lineages using DTW.

    Args:
        reference_pattern: The reference pattern embeddings
        filtered_lineages: List of lineages to search in (fov_name, track_ids)
        embeddings_dataset: Dataset containing embeddings
        window_step_fraction: Fraction of reference pattern length to use as window step
        num_candidates: Number of best candidates to consider per lineage
        max_distance: Maximum distance threshold to consider a match
        max_skew: Maximum allowed path skewness (0-1, where 0=perfect diagonal)
        save_path: Optional path to save the results CSV
        method: DTW method to use - 'bernd_clifford' (from utils.py) or 'dtai' (dtaidistance library)
        normalize: If True, normalize DTW distances by warping-path length
        metric: Pairwise distance metric for the 'bernd_clifford' method

    Returns:
        DataFrame with match positions and distances, sorted by distance then
        skewness; lineages without a passing match are dropped.
    """
    from tqdm import tqdm

    # Calculate window step based on reference pattern length
    window_step = max(1, int(len(reference_pattern) * window_step_fraction))

    all_match_positions = {
        "fov_name": [],
        "track_ids": [],
        "distance": [],
        "skewness": [],
        "warp_path": [],
        "start_timepoint": [],
        "end_timepoint": [],
    }

    for fov_name, track_ids in tqdm(
        filtered_lineages,
        total=len(filtered_lineages),
        desc="Finding pattern matches",
    ):
        print(f"Finding pattern matches for {fov_name} with track ids: {track_ids}")
        # Concatenate the per-track embeddings of the lineage in order.
        lineage_embeddings = np.concatenate(
            [
                embeddings_dataset.sel(sample=(fov_name, track_id)).features.values
                for track_id in track_ids
            ],
            axis=0,
        )

        # Find best matches using the selected DTW method
        if method == "bernd_clifford":
            matches_df = find_best_match_dtw_bernd_clifford(
                lineage_embeddings,
                reference_pattern=reference_pattern,
                num_candidates=num_candidates,
                window_step=window_step,
                max_distance=max_distance,
                max_skew=max_skew,
                normalize=normalize,
                metric=metric,
            )
        else:
            matches_df = find_best_match_dtw(
                lineage_embeddings,
                reference_pattern=reference_pattern,
                num_candidates=num_candidates,
                window_step=window_step,
                max_distance=max_distance,
                max_skew=max_skew,
                normalize=normalize,
            )

        if not matches_df.empty:
            # Get the best match (first row of the sorted DataFrame)
            best_match = matches_df.iloc[0]
            best_pos = best_match["position"]
            best_path = best_match["path"]
            best_dist = best_match["distance"]
            best_skew = best_match["skewness"]

            all_match_positions["fov_name"].append(fov_name)
            all_match_positions["track_ids"].append(track_ids)
            all_match_positions["distance"].append(best_dist)
            all_match_positions["skewness"].append(best_skew)
            all_match_positions["warp_path"].append(best_path)
            all_match_positions["start_timepoint"].append(best_pos)
            all_match_positions["end_timepoint"].append(
                best_pos + len(reference_pattern)
            )
        else:
            # Record a placeholder row; dropped by dropna() below.
            all_match_positions["fov_name"].append(fov_name)
            all_match_positions["track_ids"].append(track_ids)
            all_match_positions["distance"].append(None)
            all_match_positions["skewness"].append(None)
            all_match_positions["warp_path"].append(None)
            all_match_positions["start_timepoint"].append(None)
            all_match_positions["end_timepoint"].append(None)

    # Convert to DataFrame and drop rows with no matches
    all_match_positions = pd.DataFrame(all_match_positions)
    all_match_positions = all_match_positions.dropna()

    # Sort by distance (primary) and skewness (secondary)
    all_match_positions = all_match_positions.sort_values(
        by=["distance", "skewness"], ascending=[True, True]
    )

    # Save to CSV if path is provided
    if save_path:
        all_match_positions.to_csv(save_path, index=False)

    return all_match_positions
+
+
def find_best_match_dtw(
    lineage: np.ndarray,
    reference_pattern: np.ndarray,
    num_candidates: int = 5,
    window_step: int = 5,
    max_distance: float = float("inf"),
    max_skew: float = 0.8,
    normalize: bool = True,
) -> pd.DataFrame:
    """
    Find the best matches in a lineage using DTW with dtaidistance.

    Args:
        lineage: The lineage to search (t,embeddings).
        reference_pattern: The pattern to search for (t,embeddings).
        num_candidates: The number of candidates to return.
        window_step: The step size for the window.
        max_distance: Maximum distance threshold to consider a match.
        max_skew: Maximum allowed path skewness (0-1).
        normalize: If True, divide the DTW distance by the warping-path length.

    Returns:
        DataFrame with position, warping_path, distance, and skewness for the best matches
    """
    from dtaidistance.dtw_ndim import warping_path

    from utils import path_skew

    pattern_len = len(reference_pattern)
    n_windows = len(lineage) - pattern_len + 1
    if n_windows <= 0:
        return pd.DataFrame(columns=["position", "path", "distance", "skewness"])

    candidates = []
    for start in range(0, n_windows, window_step):
        segment = lineage[start : start + pattern_len]
        path, dist = warping_path(
            reference_pattern,
            segment,
            include_distance=True,
        )
        if normalize:
            # Normalize by path length to match bernd_clifford method
            dist = dist / len(path)
        skew = path_skew(path, pattern_len, len(segment))

        # Keep only windows passing both thresholds.
        if dist <= max_distance and skew <= max_skew:
            candidates.append(
                {"position": start, "path": path, "distance": dist, "skewness": skew}
            )

    results = pd.DataFrame(candidates)
    if results.empty:
        return results

    # Best candidates first: lowest distance, ties broken by skewness.
    return results.sort_values(by=["distance", "skewness"]).head(num_candidates)
+
+
+def find_best_match_dtw_bernd_clifford(
+ lineage: np.ndarray,
+ reference_pattern: np.ndarray,
+ num_candidates: int = 5,
+ window_step: int = 5,
+ normalize: bool = True,
+ max_distance: float = float("inf"),
+ max_skew: float = 0.8,
+ metric: str = "euclidean",
+) -> pd.DataFrame:
+ """
+ Find the best matches in a lineage using DTW with the utils.py implementation.
+
+ Args:
+ lineage: The lineage to search (t,embeddings).
+ reference_pattern: The pattern to search for (t,embeddings).
+ num_candidates: The number of candidates to return.
+ window_step: The step size for the window.
+ max_distance: Maximum distance threshold to consider a match.
+        max_skew: Maximum allowed path skewness (0-1).
+        normalize: normalize distance by path length; metric: cdist metric name.
+ Returns:
+ DataFrame with position, warping_path, distance, and skewness for the best matches
+ """
+ from scipy.spatial.distance import cdist
+
+ from utils import dtw_with_matrix, path_skew
+
+ dtw_results = []
+ n_windows = len(lineage) - len(reference_pattern) + 1
+
+ if n_windows <= 0:
+ return pd.DataFrame(columns=["position", "path", "distance", "skewness"])
+
+ for i in range(0, n_windows, window_step):
+ window = lineage[i : i + len(reference_pattern)]
+
+ # Create distance matrix
+ distance_matrix = cdist(reference_pattern, window, metric=metric)
+
+ # Apply DTW using utils.py implementation
+ distance, _, path = dtw_with_matrix(distance_matrix, normalize=normalize)
+
+ # Calculate skewness
+ skewness = path_skew(path, len(reference_pattern), len(window))
+
+ # Only add if both distance and skewness pass thresholds
+ if distance <= max_distance and skewness <= max_skew:
+ logger.debug(
+ f"Found match at {i} with distance {distance} and skewness {skewness}"
+ )
+ dtw_results.append(
+ {
+ "position": i,
+ "path": path,
+ "distance": distance,
+ "skewness": skewness,
+ }
+ )
+
+ # Convert to DataFrame
+ results_df = pd.DataFrame(dtw_results)
+
+ # Sort by distance first (primary) and then by skewness (secondary)
+ if not results_df.empty:
+ results_df = results_df.sort_values(by=["distance", "skewness"]).head(
+ num_candidates
+ )
+
+ return results_df
+
+
+def create_consensus_embedding(
+ reference_pattern: np.ndarray,
+ top_aligned_cells: pd.DataFrame,
+ embeddings_dataset: xr.Dataset,
+) -> np.ndarray:
+ """
+ Create a consensus embedding from multiple aligned embeddings using
+ a weighted approach based on DTW distances.
+
+ Args:
+ reference_pattern: The reference pattern embeddings
+ top_aligned_cells: DataFrame with alignment information
+ embeddings_dataset: Dataset containing embeddings
+
+ Returns:
+ The consensus embedding
+ """
+ all_aligned_embeddings = []
+ distances = []
+
+ for idx, row in top_aligned_cells.iterrows():
+ fov_name = row["fov_name"]
+ track_ids = row["track_ids"]
+ warp_path = row["warp_path"]
+ start_time = int(row["start_timepoint"])
+ distance = row["distance"]
+
+ # Get lineage embeddings (similar to align_and_average_embeddings)
+ lineages = []
+ for track_id in track_ids:
+ track_embeddings = embeddings_dataset.sel(
+ sample=(fov_name, track_id)
+ ).features.values
+ lineages.append(track_embeddings)
+
+ lineage_embeddings = np.concatenate(lineages, axis=0)
+
+ # Create aligned embeddings using the warping path
+ aligned_segment = np.zeros_like(reference_pattern)
+
+ # Map each reference timepoint to the corresponding lineage timepoint
+ ref_to_lineage = {}
+ for ref_idx, query_idx in warp_path:
+ lineage_idx = int(start_time + query_idx)
+ if 0 <= lineage_idx < len(lineage_embeddings):
+ ref_to_lineage[ref_idx] = lineage_idx
+ aligned_segment[ref_idx] = lineage_embeddings[lineage_idx]
+
+ # Fill in missing values
+ for ref_idx in range(len(reference_pattern)):
+ if ref_idx not in ref_to_lineage and ref_to_lineage:
+ closest_ref_idx = min(
+ ref_to_lineage.keys(), key=lambda x: abs(x - ref_idx)
+ )
+ aligned_segment[ref_idx] = lineage_embeddings[
+ ref_to_lineage[closest_ref_idx]
+ ]
+
+ all_aligned_embeddings.append(aligned_segment)
+ distances.append(distance)
+
+ all_aligned_embeddings = np.array(all_aligned_embeddings)
+
+ # Convert distances to weights (smaller distance = higher weight)
+ weights = 1.0 / (
+ np.array(distances) + 1e-10
+ ) # Add small epsilon to avoid division by zero
+ weights = weights / np.sum(weights) # Normalize weights
+
+ # Create weighted consensus
+ consensus_embedding = np.zeros_like(reference_pattern)
+ for i, aligned_embedding in enumerate(all_aligned_embeddings):
+ consensus_embedding += weights[i] * aligned_embedding
+
+ return consensus_embedding
+
+
+def identify_lineages(
+ tracking_df: pd.DataFrame, return_both_branches: bool = False
+) -> list[tuple[str, list[int]]]:
+ """
+    Identify all distinct lineages in the cell tracking data, following one
+    branch per division event (or every branch if return_both_branches=True).
+
+    Args:
+        tracking_df: tracking DataFrame with fov_name, track_id, parent_track_id
+        return_both_branches: if True, return every branch of each lineage tree
+ Returns:
+ A list of tuples, where each tuple contains (fov_id, [track_ids])
+ representing a single branch lineage within a single FOV
+ """
+    # Identify lineages per FOV (input is the in-memory tracking DataFrame)
+
+ # Process each FOV separately to handle repeated track_ids
+ all_lineages = []
+
+ # Group by FOV
+ for fov_id, fov_df in tracking_df.groupby("fov_name"):
+ # Create a dictionary to map tracks to their parents within this FOV
+ child_to_parent = {}
+
+ # Group by track_id and get the first row for each track to find its parent
+ for track_id, track_group in fov_df.groupby("track_id"):
+ first_row = track_group.iloc[0]
+ parent_track_id = first_row["parent_track_id"]
+
+ if parent_track_id != -1:
+ child_to_parent[track_id] = parent_track_id
+
+ # Find root tracks (those without parents or with parent_track_id = -1)
+ all_tracks = set(fov_df["track_id"].unique())
+        # NOTE(review): a plain set difference (all_tracks - children) was
+        # superseded by the validated computation below, which also treats
+
+ # Additional validation for root tracks
+ root_tracks = set()
+ for track_id in all_tracks:
+ track_data = fov_df[fov_df["track_id"] == track_id]
+ # Check if it's truly a root track
+ if (
+ track_data.iloc[0]["parent_track_id"] == -1
+ or track_data.iloc[0]["parent_track_id"] not in all_tracks
+ ):
+ root_tracks.add(track_id)
+
+ # Build a parent-to-children mapping
+ parent_to_children = {}
+ for child, parent in child_to_parent.items():
+ if parent not in parent_to_children:
+ parent_to_children[parent] = []
+ parent_to_children[parent].append(child)
+
+ # Function to get all branches from each parent
+ def get_all_branches(track_id):
+ branches = []
+ current_branch = [track_id]
+
+ if track_id in parent_to_children:
+ # For each child, get all their branches
+ for child in parent_to_children[track_id]:
+ child_branches = get_all_branches(child)
+ # Add current track to start of each child branch
+ for branch in child_branches:
+ branches.append(current_branch + branch)
+ else:
+ # If no children, return just this track
+ branches.append(current_branch)
+
+ return branches
+
+ # Build lineages starting from root tracks within this FOV
+ for root_track in root_tracks:
+ # Get all branches from this root
+ lineage_tracks = get_all_branches(root_track)
+ if return_both_branches:
+ for branch in lineage_tracks:
+ all_lineages.append((fov_id, branch))
+ else:
+ all_lineages.append((fov_id, lineage_tracks[0]))
+
+ return all_lineages
+
+
+def plot_pc_trajectories(
+ reference_lineage_fov: str,
+ reference_lineage_track_id: list[int],
+ reference_timepoints: list[int],
+ match_positions: pd.DataFrame,
+ embeddings_dataset: xr.Dataset,
+ filtered_lineages: list[tuple[str, list[int]]],
+ name: str,
+ save_path: Path,
+):
+ """
+ Visualize warping paths in PC space, comparing reference pattern with aligned lineages.
+
+ Args:
+ reference_lineage_fov: FOV name for the reference lineage
+ reference_lineage_track_id: Track ID for the reference lineage
+ reference_timepoints: Time range [start, end] to use from reference
+ match_positions: DataFrame with alignment matches
+ embeddings_dataset: Dataset with embeddings
+ filtered_lineages: List of lineages to search in (fov_name, track_ids)
+ name: Name of the embedding model
+ save_path: Path to save the figure
+ """
+ import ast
+ from sklearn.decomposition import PCA
+ from sklearn.preprocessing import StandardScaler
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+
+ # Get reference pattern
+ ref_pattern = None
+ for fov_id, track_ids in filtered_lineages:
+ if fov_id == reference_lineage_fov and all(
+ track_id in track_ids for track_id in reference_lineage_track_id
+ ):
+ ref_pattern = embeddings_dataset.sel(
+ sample=(fov_id, reference_lineage_track_id)
+ ).features.values
+ break
+
+ if ref_pattern is None:
+ logger.info(
+ f"Reference pattern not found for {name}. Skipping PC trajectory plot."
+ )
+ return
+
+ ref_pattern = np.concatenate([ref_pattern])
+ ref_pattern = ref_pattern[reference_timepoints[0] : reference_timepoints[1]]
+
+ # Get top matches
+ top_n_aligned_cells = match_positions.head(5)
+
+ # Compute PCA directly with sklearn
+ # Scale the input data
+ scaler = StandardScaler()
+ ref_pattern_scaled = scaler.fit_transform(ref_pattern)
+
+ # Create and fit PCA model
+ pca_model = PCA(n_components=2, random_state=42)
+ pca_ref = pca_model.fit_transform(ref_pattern_scaled)
+
+ # Create a figure to display the results
+ plt.figure(figsize=(15, 15))
+
+ # Plot the reference pattern PCs
+ plt.subplot(len(top_n_aligned_cells) + 1, 2, 1)
+ plt.plot(
+ range(len(pca_ref)),
+ pca_ref[:, 0],
+ label="Reference PC1",
+ color="black",
+ linewidth=2,
+ )
+ plt.title(f"{name} - Reference Pattern - PC1")
+ plt.xlabel("Time Index")
+ plt.ylabel("PC1 Value")
+ plt.grid(True, alpha=0.3)
+ plt.legend()
+
+ plt.subplot(len(top_n_aligned_cells) + 1, 2, 2)
+ plt.plot(
+ range(len(pca_ref)),
+ pca_ref[:, 1],
+ label="Reference PC2",
+ color="black",
+ linewidth=2,
+ )
+ plt.title(f"{name} - Reference Pattern - PC2")
+ plt.xlabel("Time Index")
+ plt.ylabel("PC2 Value")
+ plt.grid(True, alpha=0.3)
+ plt.legend()
+
+ # Then plot each lineage with the matched section highlighted
+ for i, (_, row) in enumerate(top_n_aligned_cells.iterrows()):
+ fov_name = row["fov_name"]
+ track_ids = row["track_ids"]
+ if isinstance(track_ids, str):
+ track_ids = ast.literal_eval(track_ids)
+ warp_path = row["warp_path"]
+ if isinstance(warp_path, str):
+ warp_path = ast.literal_eval(warp_path)
+ start_time = row["start_timepoint"]
+ distance = row["distance"]
+
+ # Get the full lineage embeddings
+ lineage_embeddings = []
+ for track_id in track_ids:
+ try:
+ track_emb = embeddings_dataset.sel(
+ sample=(fov_name, track_id)
+ ).features.values
+ lineage_embeddings.append(track_emb)
+ except KeyError:
+ pass
+
+ if not lineage_embeddings:
+ continue
+
+ lineage_embeddings = np.concatenate(lineage_embeddings, axis=0)
+
+ # Transform lineage embeddings using the same PCA model
+ # Scale first using the same scaler
+ lineage_scaled = scaler.transform(lineage_embeddings)
+ pca_lineage = pca_model.transform(lineage_scaled)
+
+ # Create a subplot for PC1
+ plt.subplot(len(top_n_aligned_cells) + 1, 2, 2 * i + 3)
+
+ # Plot the full lineage PC1
+ plt.plot(
+ range(len(pca_lineage)),
+ pca_lineage[:, 0],
+ label="Full Lineage",
+ color="blue",
+ alpha=0.7,
+ )
+
+ # Highlight the matched section
+ matched_indices = set()
+ for _, query_idx in warp_path:
+ lineage_idx = (
+ int(start_time) + query_idx if not pd.isna(start_time) else query_idx
+ )
+ if 0 <= lineage_idx < len(pca_lineage):
+ matched_indices.add(lineage_idx)
+
+ matched_indices = sorted(list(matched_indices))
+ if matched_indices:
+ plt.plot(
+ matched_indices,
+ [pca_lineage[idx, 0] for idx in matched_indices],
+ "ro-",
+ label=f"Matched Section (DTW dist={distance:.2f})",
+ linewidth=2,
+ )
+
+ # Add vertical lines to mark the start and end of the matched section
+ plt.axvline(x=min(matched_indices), color="red", linestyle="--", alpha=0.5)
+ plt.axvline(x=max(matched_indices), color="red", linestyle="--", alpha=0.5)
+
+ # Add text labels
+ plt.text(
+ min(matched_indices),
+ min(pca_lineage[:, 0]),
+ f"Start: {min(matched_indices)}",
+ color="red",
+ fontsize=10,
+ )
+ plt.text(
+ max(matched_indices),
+ min(pca_lineage[:, 0]),
+ f"End: {max(matched_indices)}",
+ color="red",
+ fontsize=10,
+ )
+
+ plt.title(f"Lineage {i} ({fov_name}) Track {track_ids[0]} - PC1")
+ plt.xlabel("Lineage Time")
+ plt.ylabel("PC1 Value")
+ plt.legend()
+ plt.grid(True, alpha=0.3)
+
+ # Create a subplot for PC2
+ plt.subplot(len(top_n_aligned_cells) + 1, 2, 2 * i + 4)
+
+ # Plot the full lineage PC2
+ plt.plot(
+ range(len(pca_lineage)),
+ pca_lineage[:, 1],
+ label="Full Lineage",
+ color="green",
+ alpha=0.7,
+ )
+
+ # Highlight the matched section
+ if matched_indices:
+ plt.plot(
+ matched_indices,
+ [pca_lineage[idx, 1] for idx in matched_indices],
+ "ro-",
+ label=f"Matched Section (DTW dist={distance:.2f})",
+ linewidth=2,
+ )
+
+ # Add vertical lines to mark the start and end of the matched section
+ plt.axvline(x=min(matched_indices), color="red", linestyle="--", alpha=0.5)
+ plt.axvline(x=max(matched_indices), color="red", linestyle="--", alpha=0.5)
+
+ # Add text labels
+ plt.text(
+ min(matched_indices),
+ min(pca_lineage[:, 1]),
+ f"Start: {min(matched_indices)}",
+ color="red",
+ fontsize=10,
+ )
+ plt.text(
+ max(matched_indices),
+ min(pca_lineage[:, 1]),
+ f"End: {max(matched_indices)}",
+ color="red",
+ fontsize=10,
+ )
+
+ plt.title(f"Lineage {i} ({fov_name}) - PC2")
+ plt.xlabel("Lineage Time")
+ plt.ylabel("PC2 Value")
+ plt.legend()
+ plt.grid(True, alpha=0.3)
+
+ plt.tight_layout()
+ plt.savefig(save_path, dpi=300)
+ plt.close()
diff --git a/applications/pseudotime_analysis/simulation/demo_dtw_cellfeatures.py b/applications/pseudotime_analysis/simulation/demo_dtw_cellfeatures.py
new file mode 100644
index 000000000..084a87eaf
--- /dev/null
+++ b/applications/pseudotime_analysis/simulation/demo_dtw_cellfeatures.py
@@ -0,0 +1,224 @@
+# %%
+import matplotlib.pyplot as plt
+import numpy as np
+from dtaidistance import dtw
+from scipy.cluster.hierarchy import dendrogram, linkage
+
+# testing if we can use DTW to align cell trajectories using a short reference pattern
+
+np.random.seed(42)
+timepoints = 50
+cells = 8
+
+# Create synthetic cell trajectories (e.g., PCA1 shape evolution)
+cell_trajectories = [
+ np.sin(np.linspace(0, 10, timepoints) + np.random.rand() * 2) for _ in range(cells)
+]
+# add extra transforms to signal
+cell_trajectories[cells - 1] += np.sin(np.linspace(0, 5, timepoints)) * 2
+cell_trajectories[cells - 2] += np.sin(np.linspace(0, 5, timepoints)) * 3
+
+# %%
+# plot cell trajectories
+plt.figure(figsize=(8, 5))
+for i in range(cells):
+ plt.plot(cell_trajectories[i], label=f"Cell {i+1}")
+plt.legend()
+plt.title("Original Cell Trajectories")
+plt.show()
+# %%
+# Set reference cell for all subsequent analysis
+reference_cell = 0 # Use first cell as reference
+
+# Compute DTW distance matrix
+dtw_matrix = np.zeros((cells, cells))
+for i in range(cells):
+ for j in range(i + 1, cells):
+ dtw_matrix[i, j] = dtw.distance(cell_trajectories[i], cell_trajectories[j])
+ dtw_matrix[j, i] = dtw_matrix[i, j]
+
+# Print distance matrix for examination
+print("DTW Distance Matrix:")
+for i in range(cells):
+ print(f"Cell {i+1}: {dtw_matrix[reference_cell, i]:.2f}")
+
+# Plot distance heatmap
+plt.figure(figsize=(8, 6))
+plt.imshow(dtw_matrix, cmap="viridis", origin="lower")
+plt.colorbar(label="DTW Distance")
+plt.title("DTW Distance Matrix Between Cells")
+plt.xlabel("Cell Index")
+plt.ylabel("Cell Index")
+plt.tight_layout()
+plt.show()
+
+linkage_matrix = linkage(dtw_matrix, method="ward")
+
+# Plot the dendrogram
+plt.figure(figsize=(8, 5))
+dendrogram(linkage_matrix, labels=[f"Cell {i+1}" for i in range(cells)])
+plt.xlabel("Cells")
+plt.ylabel("DTW Distance")
+plt.title("Hierarchical Clustering of Cells Based on DTW")
+plt.show()
+
+# %%
+# Align cells using DTW with distance filtering
+# Set a threshold for maximum allowed DTW distance
+# Cells with distances above this threshold won't be aligned
+# This can be set based on the distribution of distances or domain knowledge
+distance_threshold = np.median(dtw_matrix[reference_cell, :]) * 1.5 # Example threshold
+
+print(f"Using distance threshold: {distance_threshold:.2f}")
+print("Distances from reference cell:")
+for i in range(cells):
+ distance = dtw_matrix[reference_cell, i]
+ status = (
+ "Included"
+ if distance <= distance_threshold or i == reference_cell
+ else "Excluded (too dissimilar)"
+ )
+ print(f"Cell {i+1}: {distance:.2f} - {status}")
+
+# Initialize aligned trajectories with the reference cell
+aligned_cell_trajectories = [cell_trajectories[reference_cell].copy()]
+alignment_status = [True] # Reference cell is always included
+
+for i in range(1, cells):
+ distance = dtw_matrix[reference_cell, i]
+
+ # Skip cells that are too dissimilar to the reference
+ if distance > distance_threshold:
+ aligned_cell_trajectories.append(
+ np.full_like(cell_trajectories[reference_cell], np.nan)
+ )
+ alignment_status.append(False)
+ continue
+
+ # Find optimal warping path
+ path = dtw.warping_path(cell_trajectories[reference_cell], cell_trajectories[i])
+
+ # Create aligned trajectory by mapping query points to reference timeline
+ aligned_trajectory = np.zeros_like(cell_trajectories[reference_cell])
+ path_dict = {}
+
+ # Group by reference indices
+ for ref_idx, query_idx in path:
+ if ref_idx not in path_dict:
+ path_dict[ref_idx] = []
+ path_dict[ref_idx].append(query_idx)
+
+ # For each reference index, average the corresponding query values
+ for ref_idx, query_indices in path_dict.items():
+ query_values = [cell_trajectories[i][idx] for idx in query_indices]
+ aligned_trajectory[ref_idx] = np.mean(query_values)
+
+ aligned_cell_trajectories.append(aligned_trajectory)
+ alignment_status.append(True)
+
+# %%
+# plot aligned cell trajectories (only included cells)
+plt.figure(figsize=(10, 6))
+
+# Plot reference cell first
+plt.plot(aligned_cell_trajectories[0], "k-", linewidth=2.5, label=f"Reference (Cell 1)")
+
+# Plot other cells that were successfully aligned
+for i in range(1, cells):
+ if alignment_status[i]:
+ plt.plot(aligned_cell_trajectories[i], label=f"Cell {i+1}")
+
+plt.legend()
+plt.title("Aligned Cell Trajectories (Filtered by DTW Distance)")
+plt.show()
+
+# %%
+# Visualize warping paths for examples
+plt.figure(figsize=(15, 10))
+
+# First find cells to include based on distance threshold
+included_cells = [i for i in range(1, cells) if alignment_status[i]]
+excluded_cells = [i for i in range(1, cells) if not alignment_status[i]]
+
+# Show included cells examples
+for idx, target_cell in enumerate(included_cells[: min(2, len(included_cells))]):
+ plt.subplot(2, 3, idx + 1)
+
+ # Get warping path
+ path = dtw.warping_path(
+ cell_trajectories[reference_cell], cell_trajectories[target_cell]
+ )
+
+ # Plot both signals
+ plt.plot(cell_trajectories[reference_cell], label=f"Reference", linewidth=2)
+ plt.plot(cell_trajectories[target_cell], label=f"Cell {target_cell+1}", linewidth=2)
+
+ # Plot warping connections
+ for ref_idx, query_idx in path:
+ plt.plot(
+ [ref_idx, query_idx],
+ [
+ cell_trajectories[reference_cell][ref_idx],
+ cell_trajectories[target_cell][query_idx],
+ ],
+ "k-",
+ alpha=0.1,
+ )
+
+ plt.title(
+ f"Included - Cell {target_cell+1} (Dist: {dtw_matrix[reference_cell, target_cell]:.2f})"
+ )
+ plt.legend()
+
+# Show excluded cells examples
+for idx, target_cell in enumerate(excluded_cells[: min(2, len(excluded_cells))]):
+ plt.subplot(2, 3, 3 + idx)
+
+ plt.plot(cell_trajectories[reference_cell], label=f"Reference", linewidth=2)
+ plt.plot(cell_trajectories[target_cell], label=f"Cell {target_cell+1}", linewidth=2)
+
+ # Show distance value
+ plt.title(
+ f"Excluded - Cell {target_cell+1} (Dist: {dtw_matrix[reference_cell, target_cell]:.2f})"
+ )
+ plt.legend()
+
+# Compare original and aligned for an included cell
+
+if included_cells:
+ plt.subplot(2, 3, 5)
+ target_cell = included_cells[0]
+ plt.plot(cell_trajectories[reference_cell], label="Reference", linewidth=2)
+ plt.plot(
+ cell_trajectories[target_cell],
+ label=f"Original Cell {target_cell+1}",
+ linewidth=2,
+ linestyle="--",
+ alpha=0.7,
+ )
+ plt.plot(
+ aligned_cell_trajectories[target_cell],
+ label=f"Aligned Cell {target_cell+1}",
+ linewidth=2,
+ )
+ plt.title(f"Alignment Example (Included)")
+ plt.legend()
+
+# Show distance distribution
+plt.subplot(2, 3, 6)
+distances = dtw_matrix[reference_cell, 1:] # Skip distance to self
+plt.hist(distances, bins=10)
+plt.axvline(
+ distance_threshold,
+ color="r",
+ linestyle="--",
+ label=f"Threshold: {distance_threshold:.2f}",
+)
+plt.title("DTW Distance Distribution")
+plt.xlabel("DTW Distance from Reference")
+plt.ylabel("Count")
+plt.legend()
+
+plt.tight_layout()
+plt.show()
+# %%
diff --git a/applications/pseudotime_analysis/simulation/demo_ndim_dtw.py b/applications/pseudotime_analysis/simulation/demo_ndim_dtw.py
new file mode 100644
index 000000000..3bac7f227
--- /dev/null
+++ b/applications/pseudotime_analysis/simulation/demo_ndim_dtw.py
@@ -0,0 +1,634 @@
+# %%
+import matplotlib.pyplot as plt
+import numpy as np
+from dtaidistance.dtw_ndim import warping_path
+from scipy.spatial.distance import cdist
+
+# %%
+# Simulation of embeddings with temporal warping
+np.random.seed(42) # For reproducibility
+
+# Base parameters
+num_cells = 8
+num_timepoints = 30 # More timepoints for better visualization of warping
+embedding_dim = 100
+
+# Create a reference trajectory (Cell 1)
+t_ref = np.linspace(0, 4 * np.pi, num_timepoints) # 2 complete periods (0 to 4π)
+base_pattern = np.zeros((num_timepoints, embedding_dim))
+
+# Generate a structured pattern with clear sinusoidal shape
+for dim in range(embedding_dim):
+ # Use lower frequencies to ensure at least one full period
+ freq = 0.2 + 0.3 * np.random.rand() # Frequency between 0.2 and 0.5
+ phase = np.random.rand() * np.pi # Random phase
+ amplitude = 0.7 + 0.6 * np.random.rand() # Amplitude between 0.7 and 1.3
+
+ # Create basic sine wave for this dimension
+ base_pattern[:, dim] = amplitude * np.sin(freq * t_ref + phase)
+
+# Cell embeddings with different temporal dynamics
+cell_embeddings = np.zeros((num_cells, num_timepoints, embedding_dim))
+
+# Cell 1 (reference) - standard progression
+cell_embeddings[0] = base_pattern.copy()
+
+# Cell 2 - similar to reference with small variations
+cell_embeddings[1] = base_pattern + np.random.randn(num_timepoints, embedding_dim) * 0.2
+
+# Cell 3 - starts slow, then accelerates (time warping)
+# Map [0,1] -> [0,4π] with non-linear warping
+t_warped = np.power(np.linspace(0, 1, num_timepoints), 1.7) * 4 * np.pi
+for dim in range(embedding_dim):
+    # Draw fresh per-dimension frequency/phase/amplitude (the RNG state has
+ freq = 0.2 + 0.3 * np.random.rand()
+ phase = np.random.rand() * np.pi
+ amplitude = 0.7 + 0.6 * np.random.rand()
+
+ # Apply the warping to the timepoints
+ cell_embeddings[2, :, dim] = amplitude * np.sin(freq * t_warped + phase)
+cell_embeddings[2] += np.random.randn(num_timepoints, embedding_dim) * 0.15
+
+# Cell 4 - starts fast, then slows down (time warping)
+t_warped = np.power(np.linspace(0, 1, num_timepoints), 0.6) * 4 * np.pi
+for dim in range(embedding_dim):
+ freq = 0.2 + 0.3 * np.random.rand()
+ phase = np.random.rand() * np.pi
+ amplitude = 0.7 + 0.6 * np.random.rand()
+ cell_embeddings[3, :, dim] = amplitude * np.sin(freq * t_warped + phase)
+cell_embeddings[3] += np.random.randn(num_timepoints, embedding_dim) * 0.15
+
+# Cell 5 - missing middle section (temporal gap)
+t_warped = np.concatenate(
+ [
+ np.linspace(0, 1.5 * np.pi, num_timepoints // 2), # First 1.5 periods
+ np.linspace(
+ 2.5 * np.pi, 4 * np.pi, num_timepoints // 2
+        ),  # Last 1.5 periods (skip middle)
+ ]
+)
+for dim in range(embedding_dim):
+ freq = 0.2 + 0.3 * np.random.rand()
+ phase = np.random.rand() * np.pi
+ amplitude = 0.7 + 0.6 * np.random.rand()
+ cell_embeddings[4, :, dim] = amplitude * np.sin(freq * t_warped + phase)
+cell_embeddings[4] += np.random.randn(num_timepoints, embedding_dim) * 0.15
+
+# Cell 6 - phase shifted (out of sync with reference)
+cell_embeddings[5] = np.roll(
+ base_pattern, shift=num_timepoints // 4, axis=0
+) # 1/4 cycle shift
+cell_embeddings[5] += np.random.randn(num_timepoints, embedding_dim) * 0.2
+
+# Cell 7 - Double frequency (faster oscillations)
+for dim in range(embedding_dim):
+ freq = (0.2 + 0.3 * np.random.rand()) * 2 # Double frequency
+ phase = np.random.rand() * np.pi
+ amplitude = 0.7 + 0.6 * np.random.rand()
+ cell_embeddings[6, :, dim] = amplitude * np.sin(freq * t_ref + phase)
+cell_embeddings[6] += np.random.randn(num_timepoints, embedding_dim) * 0.15
+
+# Cell 8 - Very different pattern with trend
+cell_embeddings[7] = np.random.randn(num_timepoints, embedding_dim) * 1.5
+trend = np.linspace(0, 3, num_timepoints).reshape(-1, 1)
+cell_embeddings[7] += trend * np.random.randn(1, embedding_dim)
+
+# %%
+# Visualize the first two dimensions of each cell's embeddings to see the temporal patterns
+plt.figure(figsize=(18, 10))
+
+# Create subplots for 4 dimensions
+for dim in range(4):
+ plt.subplot(2, 2, dim + 1)
+
+ for i in range(num_cells):
+ plt.plot(
+ range(num_timepoints),
+ cell_embeddings[i, :, dim],
+ label=f"Cell {i+1}",
+ linewidth=2,
+ )
+
+ plt.title(f"Dimension {dim+1} over time")
+ plt.xlabel("Timepoint")
+ plt.ylabel(f"Value (Dim {dim+1})")
+ plt.grid(alpha=0.3)
+
+ if dim == 0:
+ plt.legend(loc="upper right")
+
+plt.tight_layout()
+plt.show()
+
+
+# %%
+# Helper function to compute DTW warping matrix
+def compute_dtw_matrix(s1, s2):
+ """
+ Compute the DTW warping matrix and best path manually.
+
+ Args:
+ s1: First sequence (reference)
+ s2: Second sequence (query)
+
+ Returns:
+ warping_matrix: The accumulated cost matrix
+ best_path: The optimal warping path
+ """
+ # Compute pairwise distances between all timepoints
+ distance_matrix = cdist(s1, s2)
+
+ n, m = distance_matrix.shape
+
+ # Initialize the accumulated cost matrix
+ warping_matrix = np.full((n, m), np.inf)
+ warping_matrix[0, 0] = distance_matrix[0, 0]
+
+ # Fill the first column and row
+ for i in range(1, n):
+ warping_matrix[i, 0] = warping_matrix[i - 1, 0] + distance_matrix[i, 0]
+ for j in range(1, m):
+ warping_matrix[0, j] = warping_matrix[0, j - 1] + distance_matrix[0, j]
+
+ # Fill the rest of the matrix
+ for i in range(1, n):
+ for j in range(1, m):
+ warping_matrix[i, j] = distance_matrix[i, j] + min(
+ warping_matrix[i - 1, j], # insertion
+ warping_matrix[i, j - 1], # deletion
+ warping_matrix[i - 1, j - 1], # match
+ )
+
+ # Backtrack to find the optimal path
+ i, j = n - 1, m - 1
+ path = [(i, j)]
+
+ while i > 0 or j > 0:
+ if i == 0:
+ j -= 1
+ elif j == 0:
+ i -= 1
+ else:
+ min_cost = min(
+ warping_matrix[i - 1, j],
+ warping_matrix[i, j - 1],
+ warping_matrix[i - 1, j - 1],
+ )
+
+ if min_cost == warping_matrix[i - 1, j - 1]:
+ i, j = i - 1, j - 1
+ elif min_cost == warping_matrix[i - 1, j]:
+ i -= 1
+ else:
+ j -= 1
+
+ path.append((i, j))
+
+ path.reverse()
+
+ return warping_matrix, path
+
+
+# %%
+# Compute DTW distances and warping paths between cell 1 and all other cells
+reference_cell = 0 # Cell 1 (0-indexed)
+dtw_results = []
+
+for i in range(num_cells):
+ if i != reference_cell:
+ # Get distance and path from dtaidistance
+ path, dist = warping_path(
+ cell_embeddings[reference_cell],
+ cell_embeddings[i],
+ include_distance=True,
+ )
+
+ # Compute our own warping matrix for visualization
+ warping_matrix, _ = compute_dtw_matrix(
+ cell_embeddings[reference_cell], cell_embeddings[i]
+ )
+
+ dtw_results.append(
+ (i + 1, dist, path, warping_matrix)
+ ) # Store cell number, distance, path, matrix
+ print(f"DTW distance between Cell 1 and Cell {i+1} dist: {dist:.4f}")
+
+# %%
+# Visualize the DTW distances
+cell_ids = [result[0] for result in dtw_results]
+distances = [result[1] for result in dtw_results]
+
+plt.figure(figsize=(10, 6))
+plt.bar(cell_ids, distances)
+plt.xlabel("Cell ID")
+plt.ylabel("DTW Distance from Cell 1")
+plt.title("DTW Distances from Cell 1 to Other Cells")
+plt.xticks(cell_ids)
+plt.tight_layout()
+plt.show()
+
+# %%
+# Create a grid of all warping matrices in a 4x2 layout
+fig, axes = plt.subplots(4, 2, figsize=(16, 24))
+axes = axes.flatten()
+
+# Common colorbar limits for better comparison
+all_matrices = [result[3] for result in dtw_results]
+vmin = min(matrix.min() for matrix in all_matrices)
+vmax = max(matrix.max() for matrix in all_matrices)
+
+# Add diagonal reference line for comparison
+diagonal = np.linspace(0, num_timepoints - 1, 100)
+
+for i, result in enumerate(dtw_results):
+ cell_id, dist, path, warping_matrix = result
+ ax = axes[i]
+
+ # Plot the warping matrix
+ im = ax.imshow(
+ warping_matrix,
+ origin="lower",
+ aspect="auto",
+ cmap="viridis",
+ vmin=vmin,
+ vmax=vmax,
+ )
+
+ # Plot diagonal reference line
+ ax.plot(diagonal, diagonal, "w--", alpha=0.7, linewidth=1, label="Diagonal")
+
+ # Extract and plot the best path
+ path_x = [p[0] for p in path]
+ path_y = [p[1] for p in path]
+ ax.plot(path_y, path_x, "r-", linewidth=2, label="Best path")
+
+ # Add some arrows to show direction
+ step = max(1, len(path) // 5) # Show 5 arrows along the path
+ for j in range(0, len(path) - 1, step):
+ ax.annotate(
+ "",
+ xy=(path_y[j + 1], path_x[j + 1]),
+ xytext=(path_y[j], path_x[j]),
+ arrowprops=dict(arrowstyle="->", color="orange", lw=1.5),
+ )
+
+ # Add title and axes labels
+ cell_desc = ""
+ if cell_id == 2:
+ cell_desc = " (Small variations)"
+ elif cell_id == 3:
+ cell_desc = " (Slow→Fast)"
+ elif cell_id == 4:
+ cell_desc = " (Fast→Slow)"
+ elif cell_id == 5:
+ cell_desc = " (Missing middle)"
+ elif cell_id == 6:
+ cell_desc = " (Phase shift)"
+ elif cell_id == 7:
+ cell_desc = " (Double frequency)"
+ elif cell_id == 8:
+ cell_desc = " (Different pattern)"
+
+ ax.set_title(f"Cell 1 vs Cell {cell_id}{cell_desc} (Dist: {dist:.2f})")
+ ax.set_xlabel("Cell {} Timepoints".format(cell_id))
+ ax.set_ylabel("Cell 1 Timepoints")
+
+ # Add legend
+ ax.legend(loc="lower right", fontsize=8)
+
+# Add a colorbar at the bottom spanning all subplots
+cbar_ax = fig.add_axes([0.15, 0.05, 0.7, 0.02])
+cbar = fig.colorbar(im, cax=cbar_ax, orientation="horizontal")
+cbar.set_label("Accumulated Cost")
+
+plt.suptitle("DTW Warping Matrices: Cell 1 vs All Other Cells", fontsize=16)
+plt.tight_layout(rect=[0, 0.07, 1, 0.98])
+plt.show()
+
+
+# %%
+# Compute average embedding by aligning all cells to the reference cell
+def align_to_reference(
+    reference: np.ndarray, query: np.ndarray, path: list[tuple[int, int]]
+) -> np.ndarray:
+    """
+    Align a query embedding to the reference timepoints based on the DTW path.
+
+    Args:
+        reference: Reference embedding (n_timepoints x n_dims)
+        query: Query embedding to align (m_timepoints x n_dims)
+        path: DTW path as list of (ref_idx, query_idx) tuples
+
+    Returns:
+        aligned_query: Query embeddings aligned to reference timepoints,
+        with the same shape as ``reference``.
+    """
+    n_ref, n_dims = reference.shape
+    # Accumulator with the reference's shape and dtype; averaged in place below.
+    aligned_query = np.zeros_like(reference)
+
+    # Count how many query timepoints map to each reference timepoint
+    counts = np.zeros(n_ref)
+
+    # Sum query embeddings for each reference timepoint based on the path
+    for ref_idx, query_idx in path:
+        aligned_query[ref_idx] += query[query_idx]
+        counts[ref_idx] += 1
+
+    # Average when multiple query timepoints map to the same reference timepoint
+    for i in range(n_ref):
+        if counts[i] > 0:
+            aligned_query[i] /= counts[i]
+        else:
+            # If no query timepoints map to this reference, use nearest neighbors.
+            # NOTE(review): a monotonic DTW path from (0, 0) to (n_ref-1, m-1)
+            # visits every reference index, so this branch should be unreachable
+            # for well-formed paths — confirm against the path producer.
+            nearest_idx = min(range(len(path)), key=lambda j: abs(path[j][0] - i))
+            aligned_query[i] = query[path[nearest_idx][1]]
+
+    return aligned_query
+
+
+# Identify the top 5 most similar cells
+# Sort cells by distance
+# Each dtw_results entry is a (cell_id, distance, path, warping_matrix) tuple,
+# with cell_id 1-indexed.
+# NOTE(review): assumes dtw_results excludes the reference cell itself;
+# otherwise the reference (distance 0) would also appear in the top 5 and be
+# counted twice in the average — verify upstream.
+sorted_results = sorted(dtw_results, key=lambda x: x[1])  # Sort by distance (x[1])
+
+# Select top 5 closest cells
+top_5_cells = sorted_results[:5]
+print("Top 5 cells (by DTW distance to reference):")
+for cell_id, dist, _, _ in top_5_cells:
+    print(f"Cell {cell_id}: Distance = {dist:.4f}")
+
+# First, align all the top 5 cells to the reference timepoints
+aligned_cells = []
+# 1-indexed IDs of the cells being averaged (reference cell listed first).
+all_cell_ids = [reference_cell + 1] + [cell_id for cell_id, _, _, _ in top_5_cells]
+
+# Include reference cell as-is (it's already aligned)
+aligned_cells.append(cell_embeddings[reference_cell])
+
+# Align each of the top 5 cells
+for cell_id, _, path, _ in top_5_cells:
+    cell_idx = cell_id - 1  # Convert to 0-indexed
+
+    # Get aligned version of this cell (its timepoints warped onto the
+    # reference timeline via the stored DTW path)
+    aligned_cell = align_to_reference(
+        cell_embeddings[reference_cell], cell_embeddings[cell_idx], path
+    )
+
+    # Store the aligned cell
+    aligned_cells.append(aligned_cell)
+
+# %%
+# Visualize the aligned cells before averaging
+plt.figure(figsize=(18, 10))
+
+# Create subplots for the first 4 dimensions
+for dim in range(4):
+    plt.subplot(2, 2, dim + 1)
+
+    # Plot each aligned cell
+    for i, (cell_id, aligned_cell) in enumerate(zip(all_cell_ids, aligned_cells)):
+        if i == 0:
+            # Reference cell: thick blue line.
+            # NOTE(review): the label hard-codes "Cell 1", which is only
+            # accurate while reference_cell == 0 — confirm if the reference
+            # ever changes.
+            plt.plot(
+                range(num_timepoints),
+                aligned_cell[:, dim],
+                "b-",
+                linewidth=2,
+                label=f"Cell 1 (Reference)",
+            )
+        else:
+            # Other aligned cells: thin green lines; only the first one is
+            # labeled so the legend shows a single representative entry.
+            plt.plot(
+                range(num_timepoints),
+                aligned_cell[:, dim],
+                "g-",
+                alpha=0.5,
+                linewidth=1,
+                label=f"Cell {cell_id} (Aligned)" if i == 1 else None,
+            )
+
+    plt.title(f"Dimension {dim+1}: Aligned Cells")
+    plt.xlabel("Reference Timepoint")
+    plt.ylabel(f"Value (Dim {dim+1})")
+    plt.grid(alpha=0.3)
+    plt.legend()
+
+plt.tight_layout()
+plt.suptitle("Cells Aligned to Reference before Averaging", fontsize=16, y=1.02)
+plt.show()
+
+# %%
+# Now compute the average of the aligned cells
+# (equivalent to np.mean(aligned_cells, axis=0), written out explicitly)
+average_embedding = np.zeros_like(cell_embeddings[reference_cell])
+
+# Add all aligned cells (all share the reference's shape after alignment)
+for aligned_cell in aligned_cells:
+    average_embedding += aligned_cell
+
+# Divide by number of cells
+average_embedding /= len(aligned_cells)
+
+# %%
+# Visualize the original embeddings and the average embedding
+plt.figure(figsize=(18, 8))
+
+# Create subplots for the first 4 dimensions
+for dim in range(4):
+    plt.subplot(2, 2, dim + 1)
+
+    # Plot all original cells (transparent)
+    for i in range(num_cells):
+        # Determine if this cell is in the top 5
+        is_top5 = False
+        for cell_id, _, _, _ in top_5_cells:
+            if i + 1 == cell_id:  # compare 0-indexed i against 1-indexed cell_id
+                is_top5 = True
+                break
+
+        # Style based on cell type (defaults: faint gray, unlabeled)
+        alpha = 0.3
+        color = "gray"
+        label = None
+
+        if i == reference_cell:
+            alpha = 0.7
+            color = "blue"
+            label = "Cell 1 (Reference)"
+        elif is_top5:
+            alpha = 0.5
+            color = "green"
+            # Label only the closest top-5 cell, and only in the first subplot;
+            # the dim 2-4 subplots will therefore lack a "Top 5 Cells" legend entry.
+            if dim == 0 and i == top_5_cells[0][0] - 1:  # Only label once
+                label = "Top 5 Cells"
+
+        plt.plot(
+            range(num_timepoints),
+            cell_embeddings[i, :, dim],
+            alpha=alpha,
+            color=color,
+            linewidth=1,
+            label=label,
+        )
+
+    # Plot the average embedding (red, thicker line)
+    plt.plot(
+        range(num_timepoints),
+        average_embedding[:, dim],
+        "r-",
+        linewidth=2,
+        label="Average Embedding",
+    )
+
+    plt.title(f"Dimension {dim+1}: Original vs Average")
+    plt.xlabel("Timepoint")
+    plt.ylabel(f"Value (Dim {dim+1})")
+    plt.grid(alpha=0.3)
+    plt.legend()
+
+plt.tight_layout()
+plt.suptitle(
+    "Average Embedding from Top 5 Similar Cells (via DTW)", fontsize=16, y=1.02
+)
+plt.show()
+
+# %%
+# Evaluate the average embedding as a reference
+# Compute DTW distances from average to all cells
+average_dtw_results = []
+
+for i in range(num_cells):
+    # Get distance and path from the average to each cell
+    # (warping_path and compute_dtw_matrix are defined earlier in this script)
+    path, dist = warping_path(
+        average_embedding,
+        cell_embeddings[i],
+        include_distance=True,
+    )
+
+    # Compute warping matrix for visualization
+    warping_matrix, _ = compute_dtw_matrix(average_embedding, cell_embeddings[i])
+
+    # Entries mirror dtw_results: (cell_id, distance, path, warping_matrix)
+    # with cell_id 1-indexed; later code relies on this list being ordered by
+    # cell index so that average_dtw_results[i] belongs to cell i + 1.
+    average_dtw_results.append((i + 1, dist, path, warping_matrix))
+    print(f"DTW distance between Average and Cell {i+1} dist: {dist:.4f}")
+
+# %%
+# Compare distances: Cell 1 as reference vs Average as reference
+# Combine the DTW distances for comparison
+comparison_data = []
+
+# Add Cell 1 reference distances
+for i in range(num_cells):
+    if i == reference_cell:
+        # Distance to self is 0
+        # (average_dtw_results is ordered by cell index, so [i][1] is the
+        # distance between the average and cell i + 1)
+        comparison_data.append(
+            {
+                "Cell ID": i + 1,
+                "To Cell 1": 0.0,
+                "To Average": average_dtw_results[i][1],
+            }
+        )
+    else:
+        # Find the matching result from dtw_results
+        # (linear scan per cell — O(num_cells^2) overall, fine at this scale)
+        for cell_id, dist, _, _ in dtw_results:
+            if cell_id == i + 1:
+                comparison_data.append(
+                    {
+                        "Cell ID": i + 1,
+                        "To Cell 1": dist,
+                        "To Average": average_dtw_results[i][1],
+                    }
+                )
+                break
+
+# Prepare bar chart data
+cell_ids = [d["Cell ID"] for d in comparison_data]
+to_cell1 = [d["To Cell 1"] for d in comparison_data]
+to_average = [d["To Average"] for d in comparison_data]
+
+# Compute some statistics (totals and means across all cells, including the
+# reference's zero self-distance)
+total_to_cell1 = sum(to_cell1)
+total_to_average = sum(to_average)
+avg_to_cell1 = total_to_cell1 / len(cell_ids)
+avg_to_average = total_to_average / len(cell_ids)
+
+# Create a comparison bar chart with paired bars per cell
+plt.figure(figsize=(12, 6))
+x = np.arange(len(cell_ids))
+width = 0.35
+
+plt.bar(x - width / 2, to_cell1, width, label="Distance to Cell 1")
+plt.bar(x + width / 2, to_average, width, label="Distance to Average")
+
+plt.xlabel("Cell ID")
+plt.ylabel("DTW Distance")
+plt.title("Comparison: Cell 1 vs Average as Reference")
+plt.xticks(x, cell_ids)
+plt.legend()
+
+# Add summary as text annotation
+plt.figtext(
+    0.5,
+    0.01,
+    f"Total distance - Cell 1: {total_to_cell1:.2f}, Average: {total_to_average:.2f}\n"
+    f"Mean distance - Cell 1: {avg_to_cell1:.2f}, Average: {avg_to_average:.2f}",
+    ha="center",
+    fontsize=10,
+    bbox=dict(facecolor="white", alpha=0.8),
+)
+
+plt.tight_layout(rect=[0, 0.05, 1, 0.95])
+plt.show()
+
+# %%
+# Visualize selected warping matrices using the average as reference
+# Select a few representative cells (one similar, one different).
+# These are 1-indexed cell IDs used as average_dtw_results[idx - 1].
+# NOTE(review): the original comments said "Cell 2" and "Cell 8", but the
+# idx - 1 indexing below selects Cells 1 and 7 — confirm which cells were
+# actually intended.
+similar_cell_idx = 1  # Cell 1
+different_cell_idx = 7  # Cell 7
+
+plt.figure(figsize=(16, 7))
+
+# Plot warping matrix for similar cell
+plt.subplot(1, 2, 1)
+warping_matrix = average_dtw_results[similar_cell_idx - 1][3]
+path = average_dtw_results[similar_cell_idx - 1][2]
+dist = average_dtw_results[similar_cell_idx - 1][1]
+
+plt.imshow(warping_matrix, origin="lower", aspect="auto", cmap="viridis")
+plt.colorbar(label="Accumulated Cost")
+
+# Plot diagonal and path (diagonal reused by the second subplot below)
+diagonal = np.linspace(0, num_timepoints - 1, 100)
+plt.plot(diagonal, diagonal, "w--", alpha=0.7, linewidth=1, label="Diagonal")
+
+# Extract and plot the best path (path tuples are (ref_idx, query_idx),
+# so query indices go on the x-axis and reference indices on the y-axis)
+path_x = [p[0] for p in path]
+path_y = [p[1] for p in path]
+plt.plot(path_y, path_x, "r-", linewidth=2, label="Best path")
+
+plt.title(f"Average vs Cell {similar_cell_idx} (Similar, Dist: {dist:.2f})")
+plt.xlabel(f"Cell {similar_cell_idx} Timepoints")
+plt.ylabel("Average Timepoints")
+plt.legend(loc="lower right", fontsize=8)
+
+# Plot warping matrix for different cell
+plt.subplot(1, 2, 2)
+warping_matrix = average_dtw_results[different_cell_idx - 1][3]
+path = average_dtw_results[different_cell_idx - 1][2]
+dist = average_dtw_results[different_cell_idx - 1][1]
+
+plt.imshow(warping_matrix, origin="lower", aspect="auto", cmap="viridis")
+plt.colorbar(label="Accumulated Cost")
+
+# Plot diagonal and path
+plt.plot(diagonal, diagonal, "w--", alpha=0.7, linewidth=1, label="Diagonal")
+
+# Extract and plot the best path
+path_x = [p[0] for p in path]
+path_y = [p[1] for p in path]
+plt.plot(path_y, path_x, "r-", linewidth=2, label="Best path")
+
+plt.title(f"Average vs Cell {different_cell_idx} (Different, Dist: {dist:.2f})")
+plt.xlabel(f"Cell {different_cell_idx} Timepoints")
+plt.ylabel("Average Timepoints")
+plt.legend(loc="lower right", fontsize=8)
+
+plt.suptitle("DTW Warping Matrices: Average as Reference", fontsize=16)
+plt.tight_layout(rect=[0, 0, 1, 0.95])
+plt.show()
+
+# %%
diff --git a/applications/pseudotime_analysis/utils.py b/applications/pseudotime_analysis/utils.py
new file mode 100644
index 000000000..4523f2825
--- /dev/null
+++ b/applications/pseudotime_analysis/utils.py
@@ -0,0 +1,273 @@
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+
+def load_annotation(
+    embedding_dataset: xr.Dataset,
+    track_csv_path: str,
+    name: str,
+    categories: dict | None = None,
+) -> pd.Series:
+    """
+    Load annotations from a CSV file and map them to the dataset.
+
+    Args:
+        embedding_dataset: Dataset providing ``fov_name``, ``id``, ``t`` and
+            ``track_id`` coordinates used to index into the annotations.
+        track_csv_path: Path to the annotation CSV. Must contain a ``fov ID``
+            column plus ``id``, ``t``, ``track_id`` and the ``name`` column.
+        name: Name of the annotation column to extract.
+        categories: Optional mapping from raw annotation values to labels.
+            A raw value of ``-1`` is mapped to NaN unless the mapping
+            already covers it.
+
+    Returns:
+        Annotation values aligned to the dataset samples; NaN where a sample
+        has no matching annotation row.
+    """
+    annotation = pd.read_csv(track_csv_path)
+    # Prepend "/" — presumably to match the dataset's fov_name format
+    # (TODO confirm against the dataset producer).
+    annotation["fov_name"] = "/" + annotation["fov ID"]
+
+    embedding_index = pd.MultiIndex.from_arrays(
+        [
+            embedding_dataset["fov_name"].values,
+            embedding_dataset["id"].values,
+            embedding_dataset["t"].values,
+            embedding_dataset["track_id"].values,
+        ],
+        names=["fov_name", "id", "t", "track_id"],
+    )
+
+    # reindex aligns annotations to the dataset order, inserting NaN for
+    # samples with no annotation row.
+    annotation = annotation.set_index(["fov_name", "id", "t", "track_id"])
+    selected = annotation.reindex(embedding_index)[name]
+
+    if categories:
+        # Treat -1 as "unannotated" unless the caller mapped it explicitly;
+        # copy first so the caller's dict is not mutated.
+        if -1 in selected.values and -1 not in categories:
+            categories = categories.copy()
+            categories[-1] = np.nan
+        selected = selected.map(categories)
+
+    return selected
+
+
+def identify_lineages(annotations_path: Path) -> list[tuple[str, list[int]]]:
+    """
+    Identifies all distinct lineages in the cell tracking data, following only
+    one branch after each division event.
+
+    Args:
+        annotations_path: Path to the annotations CSV file. Must contain
+            ``track_id`` and ``parent_track_id`` columns plus a ``fov_id``
+            (or ``fov ID``) column.
+
+    Returns:
+        A list of tuples, where each tuple contains (fov_id, [track_ids])
+        representing a single branch lineage within a single FOV
+    """
+    # Read the CSV file
+    df = pd.read_csv(annotations_path)
+
+    # Ensure column names are consistent ("fov ID" is the raw CSV spelling)
+    if "fov ID" in df.columns and "fov_id" not in df.columns:
+        df["fov_id"] = df["fov ID"]
+
+    # Process each FOV separately to handle repeated track_ids
+    all_lineages = []
+
+    # Group by FOV
+    for fov_id, fov_df in df.groupby("fov_id"):
+        # Create a dictionary to map tracks to their parents within this FOV
+        child_to_parent = {}
+
+        # Group by track_id and get the first row for each track to find its
+        # parent (parent_track_id is assumed constant within a track).
+        for track_id, track_group in fov_df.groupby("track_id"):
+            first_row = track_group.iloc[0]
+            parent_track_id = first_row["parent_track_id"]
+
+            # parent_track_id == -1 marks a track with no parent (a root)
+            if parent_track_id != -1:
+                child_to_parent[track_id] = parent_track_id
+
+        # Find root tracks (those without parents or with parent_track_id = -1)
+        all_tracks = set(fov_df["track_id"].unique())
+        child_tracks = set(child_to_parent.keys())
+        root_tracks = all_tracks - child_tracks
+
+        # Build a parent-to-children mapping (inverse of child_to_parent)
+        parent_to_children = {}
+        for child, parent in child_to_parent.items():
+            if parent not in parent_to_children:
+                parent_to_children[parent] = []
+            parent_to_children[parent].append(child)
+
+        # Function to get a single branch from each parent
+        # We'll choose the first child in the list (arbitrary choice)
+        # NOTE(review): recursive — could hit Python's recursion limit only for
+        # implausibly long division chains (~1000 generations).
+        def get_single_branch(track_id):
+            branch = [track_id]
+            if track_id in parent_to_children:
+                # Choose only the first child (you could implement other selection criteria)
+                first_child = parent_to_children[track_id][0]
+                branch.extend(get_single_branch(first_child))
+            return branch
+
+        # Build lineages starting from root tracks within this FOV
+        for root_track in root_tracks:
+            # Get a single branch from this root
+            lineage_tracks = get_single_branch(root_track)
+            all_lineages.append((fov_id, lineage_tracks))
+
+    return all_lineages
+
+
+def path_skew(warping_path, ref_len, query_len):
+    """
+    Calculate the skewness of a DTW warping path.
+
+    Args:
+        warping_path: List of tuples (ref_idx, query_idx) representing the warping path
+        ref_len: Length of the reference sequence
+        query_len: Length of the query sequence
+
+    Returns:
+        A skewness metric, approximately between 0 and 1, where:
+        - 0 means perfectly diagonal path (ideal alignment)
+        - higher values mean a more skewed (worse) alignment
+
+    NOTE(review): the normalization by ``max(ref_len, query_len)`` is a
+    heuristic bound, not a strict guarantee that the value stays <= 1 for
+    every possible path shape — confirm if a hard [0, 1] range is required.
+    """
+    # Convert path to numpy array for easier manipulation
+    path_array = np.array(warping_path)
+
+    # Calculate "ideal" diagonal indices, resampled to the path's length so
+    # each path point has a diagonal counterpart at the same index.
+    diagonal_x = np.linspace(0, ref_len - 1, len(warping_path))
+    diagonal_y = np.linspace(0, query_len - 1, len(warping_path))
+    diagonal_path = np.column_stack((diagonal_x, diagonal_y))
+
+    # Calculate distances from points on the warping path to the diagonal.
+    # Normalize by the longer sequence length (approximate upper bound).
+    max_distance = max(ref_len, query_len)
+
+    # Calculate distance from each point to the diagonal
+    distances = []
+    for i, (x, y) in enumerate(path_array):
+        # Pair each path point with the diagonal point at the same path index
+        dx, dy = diagonal_path[i]
+        # Simple Euclidean distance
+        dist = np.sqrt((x - dx) ** 2 + (y - dy) ** 2)
+        distances.append(dist)
+
+    # Average normalized distance as skewness metric
+    skew = np.mean(distances) / max_distance
+
+    return skew
+
+
+def dtw_with_matrix(distance_matrix, normalize=True):
+    """
+    Compute DTW using a pre-computed distance matrix.
+
+    Runs the classic O(n*m) dynamic program, then backtracks the optimal path.
+
+    Args:
+        distance_matrix: Pre-computed distance matrix between two sequences
+        normalize: Whether to normalize the distance by path length (default: True)
+
+    Returns:
+        dtw_distance: The DTW distance
+        warping_matrix: The accumulated cost matrix
+        best_path: The optimal warping path as (row, col) tuples from (0, 0)
+            to (n-1, m-1)
+    """
+    n, m = distance_matrix.shape
+
+    # Initialize the accumulated cost matrix
+    warping_matrix = np.full((n, m), np.inf)
+    warping_matrix[0, 0] = distance_matrix[0, 0]
+
+    # Fill the first column and row (only one predecessor is possible there)
+    for i in range(1, n):
+        warping_matrix[i, 0] = warping_matrix[i - 1, 0] + distance_matrix[i, 0]
+    for j in range(1, m):
+        warping_matrix[0, j] = warping_matrix[0, j - 1] + distance_matrix[0, j]
+
+    # Fill the rest of the matrix: each cell accumulates its local cost plus
+    # the cheapest of the three admissible predecessors.
+    for i in range(1, n):
+        for j in range(1, m):
+            warping_matrix[i, j] = distance_matrix[i, j] + min(
+                warping_matrix[i - 1, j],  # insertion
+                warping_matrix[i, j - 1],  # deletion
+                warping_matrix[i - 1, j - 1],  # match
+            )
+
+    # Backtrack to find the optimal path, from the bottom-right corner
+    i, j = n - 1, m - 1
+    path = [(i, j)]
+
+    while i > 0 or j > 0:
+        if i == 0:
+            j -= 1
+        elif j == 0:
+            i -= 1
+        else:
+            min_cost = min(
+                warping_matrix[i - 1, j],
+                warping_matrix[i, j - 1],
+                warping_matrix[i - 1, j - 1],
+            )
+
+            # Tie-break preference: diagonal first, then vertical, then
+            # horizontal (the order of these checks).
+            if min_cost == warping_matrix[i - 1, j - 1]:
+                i, j = i - 1, j - 1
+            elif min_cost == warping_matrix[i - 1, j]:
+                i -= 1
+            else:
+                j -= 1
+
+        path.append((i, j))
+
+    # Path was built backwards; restore start-to-end order
+    path.reverse()
+
+    # Get the DTW distance (bottom-right cell)
+    dtw_distance = warping_matrix[n - 1, m - 1]
+
+    # Normalize by path length if requested (makes distances comparable
+    # across sequence pairs of different lengths)
+    if normalize:
+        dtw_distance = dtw_distance / len(path)
+
+    return dtw_distance, warping_matrix, path
+
+
+# %%
+def filter_lineages_by_timepoints(lineages, annotation_path, min_timepoints=10):
+    """
+    Filter lineages that have fewer than min_timepoints total timepoints.
+
+    Args:
+        lineages: List of tuples (fov_id, [track_ids])
+        annotation_path: Path to the annotations CSV file
+        min_timepoints: Minimum number of timepoints required
+
+    Returns:
+        List of filtered lineages
+    """
+    # Read the annotations file
+    df = pd.read_csv(annotation_path)
+
+    # Ensure column names are consistent ("fov ID" is the raw CSV spelling)
+    if "fov ID" in df.columns and "fov_id" not in df.columns:
+        df["fov_id"] = df["fov ID"]
+
+    filtered_lineages = []
+
+    for fov_id, track_ids in lineages:
+        # Get all rows for this lineage
+        lineage_rows = df[(df["fov_id"] == fov_id) & (df["track_id"].isin(track_ids))]
+
+        # Count the total number of timepoints.
+        # NOTE(review): this counts annotation rows; it equals the number of
+        # timepoints only if each row is one (track, timepoint) — confirm the
+        # CSV schema.
+        total_timepoints = len(lineage_rows)
+
+        # Only keep lineages with at least min_timepoints
+        if total_timepoints >= min_timepoints:
+            filtered_lineages.append((fov_id, track_ids))
+
+    return filtered_lineages
+
+
+def find_top_matching_tracks(cell_division_df, infection_df, n_top=10) -> pd.DataFrame:
+    """
+    Find tracks present in both dataframes with the smallest combined distance.
+
+    Args:
+        cell_division_df: DataFrame with ``fov_name``, ``track_ids`` and a
+            ``distance`` column.
+        infection_df: DataFrame with the same columns.
+        n_top: Number of best-matching tracks to return (default 10).
+
+    Returns:
+        The ``n_top`` rows of the inner merge with the smallest
+        ``distance_sum``. The merge suffixes overlapping columns, so the two
+        per-input distances appear as ``distance_df1`` and ``distance_df2``.
+    """
+    # Find common tracks between datasets (inner join on FOV + track)
+    intersection_df = pd.merge(
+        cell_division_df,
+        infection_df,
+        on=["fov_name", "track_ids"],
+        how="inner",
+        suffixes=("_df1", "_df2"),
+    )
+
+    # Add column with sum of the values
+    intersection_df["distance_sum"] = (
+        intersection_df["distance_df1"] + intersection_df["distance_df2"]
+    )
+
+    # Find rows with the smallest sum (in-place sort mutates only the merged
+    # copy, not the caller's inputs)
+    intersection_df.sort_values(by="distance_sum", ascending=True, inplace=True)
+    return intersection_df.head(n_top)
diff --git a/docs/figures/DynaCLR_schematic_v2.png b/docs/figures/DynaCLR_schematic_v2.png
new file mode 100644
index 000000000..274f89b62
Binary files /dev/null and b/docs/figures/DynaCLR_schematic_v2.png differ
diff --git a/docs/figures/dynaCLR_schematic.png b/docs/figures/dynaCLR_schematic.png
deleted file mode 100644
index eb591cb7d..000000000
Binary files a/docs/figures/dynaCLR_schematic.png and /dev/null differ
diff --git a/examples/DynaCLR/DynaCLR-DENV-VS-Ph/README.md b/examples/DynaCLR/DynaCLR-DENV-VS-Ph/README.md
new file mode 100644
index 000000000..7534c61c5
--- /dev/null
+++ b/examples/DynaCLR/DynaCLR-DENV-VS-Ph/README.md
@@ -0,0 +1,94 @@
+# Cell Infection Analysis Demo: ImageNet vs DynaCLR-DENV-VS+Ph model
+
+This demo compares different feature extraction methods for analyzing infected vs uninfected cells using microscopy images.
+
+As the cells get infected, the red fluorescence protein (RFP) translocates from the cytoplasm into the nucleus.
+
+## Overview
+
+The `demo_infection.py` script demonstrates:
+
+    - PHATE plots of the embeddings generated by DynaCLR and ImageNet
+    - Infection progression in cells shown via the Phase and RFP (viral sensor) channels
+    - Highlighted trajectories of sample infected and uninfected cells over time
+
+## Key Features
+
+- **Feature Extraction**: Compare ImageNet pre-trained and specialized DynaCLR features
+- **Interactive Visualization**: Create plotly-based visualizations with time sliders
+- **Side-by-Side Comparison**: Directly compare cell images and PHATE embeddings
+- **Trajectory Analysis**: Visualize and track cell trajectories over time
+- **Infection State Analysis**: See how different models capture infection dynamics
+
+
+## Usage
+
+After [setting up the environment and downloading the data](/examples/DynaCLR/README.md#setup), activate it and run the demo script:
+
+```bash
+conda activate dynaclr
+python demo_infection.py
+```
+
+Before running the demo, update the following variables in the script to point to the paths of the downloaded data:
+```python
+# Update these paths to your data
+input_data_path = "/path/to/registered_test.zarr"
+tracks_path = "/path/to/track_test.zarr"
+ann_path = "/path/to/extracted_inf_state.csv"
+
+# Update paths to features
+dynaclr_features_path = "/path/to/dynaclr_features.zarr"
+imagenet_features_path = "/path/to/imagenet_features.zarr"
+```
+
+Check out the demo's output visualization:
+
+