diff --git a/README.md b/README.md index a512c7585..24c152c49 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ Please choose a stable release from PyPI for production use. ## Cytoland (robust virtual staining) -### Online demo [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm-dark.svg)](https://huggingface.co/spaces/chanzuckerberg/Cytoland) +### Demo [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm-dark.svg)](https://huggingface.co/spaces/chanzuckerberg/Cytoland) Try the 2D virtual staining demo of cell nuclei and membrane from label-free images on [Hugging Face](https://huggingface.co/spaces/chanzuckerberg/Cytoland). @@ -70,7 +70,7 @@ See the full gallery [here](https://github.com/mehta-lab/VisCy/wiki/Gallery). |:---:|:---:|:---:| | [![HEK293T](https://github.com/mehta-lab/VisCy/blob/dde3e27482e58a30f7c202e56d89378031180c75/docs/figures/svideo_1.png?raw=true)](https://github.com/mehta-lab/VisCy/assets/67518483/d53a81eb-eb37-44f3-b522-8bd7bddc7755) | [![Neuromast](https://github.com/mehta-lab/VisCy/blob/dde3e27482e58a30f7c202e56d89378031180c75/docs/figures/svideo_3.png?raw=true)](https://github.com/mehta-lab/VisCy/assets/67518483/4cef8333-895c-486c-b260-167debb7fd64) | [![A549](https://github.com/mehta-lab/VisCy/blob/dde3e27482e58a30f7c202e56d89378031180c75/docs/figures/svideo_5.png?raw=true)](https://github.com/mehta-lab/VisCy/assets/67518483/287737dd-6b74-4ce3-8ee5-25fbf8be0018) | -### Reference +### References The virtual staining models and training protocols are reported in our recent [preprint on robust virtual staining](https://www.biorxiv.org/content/10.1101/2024.05.31.596901). @@ -123,35 +123,19 @@ This package evolved from the [TensorFlow version of virtual staining pipeline]( The robust virtual staining models (i.e *VSCyto2D*, *VSCyto3D*, *VSNeuromast*), and fine-tuned models can be found [here](https://github.com/mehta-lab/VisCy/wiki/Library-of-virtual-staining-(VS)-Models) -### Pipeline -A full illustration of the virtual staining pipeline can be found [here](https://github.com/mehta-lab/VisCy/blob/dde3e27482e58a30f7c202e56d89378031180c75/docs/virtual_staining.md). +## DynaCLR (Embedding Cell Dynamics via Contrastive Learning of Representations) -## DynaCLR (Contrastive learning of representations of cell dynamics) +DynaCLR is a self-supervised method for learning robust and temporally-regularized representations of cell and organelle dynamics from time-lapse microscopy using contrastive learning. It supports diverse downstream biological tasks -- including cell state classification with efficient human annotations, knowledge distillation across fluorescence and label-free imaging channels, and alignment of cell state dynamics. -We are currently developing self-supervised representation learning to map cell state dynamics in response to perturbations, -with focus on cell and organelle remodeling due to viral infection. +### Preprint +[DynaCLR on arXiv](https://arxiv.org/abs/2410.11281): -See our recent work on temporally regularized contrastive sampling method -for representation learning on [arXiv](https://arxiv.org/abs/2410.11281). +![DynaCLR schematic](https://github.com/mehta-lab/VisCy/blob/e5318d88e2bb5d404d3bae8d633b8cc07b1fbd61/docs/figures/DynaCLR_schematic_v2.png?raw=true) -
-<details>
-  <summary> Pradeep, Imran, Liu et al., 2024 </summary>
-
-<pre><code>
-@misc{pradeep_contrastive_2024,
-      title={Contrastive learning of cell state dynamics in response to perturbations},
-      author={Soorya Pradeep and Alishba Imran and Ziwen Liu and Taylla Milena Theodoro and Eduardo Hirata-Miyasaki and Ivan Ivanov and Madhura Bhave and Sudip Khadka and Hunter Woosley and Carolina Arias and Shalin B. Mehta},
-      year={2024},
-      eprint={2410.11281},
-      archivePrefix={arXiv},
-      primaryClass={cs.CV},
-      url={https://arxiv.org/abs/2410.11281},
-}
-</code></pre>
-</details>
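+Saved embeddings can be loaded back for analysis with
+`viscy.representation.embedding_writer.read_embedding_dataset`. A minimal
+sketch, assuming a local copy of the demo predictions (the zarr path below is
+a placeholder):
+
+```python
+from pathlib import Path
+
+from viscy.representation.embedding_writer import read_embedding_dataset
+
+# Placeholder path: point at a downloaded DynaCLR prediction store.
+dataset = read_embedding_dataset(Path("DynaCLR_demo_predictions.zarr"))
+# Feature vectors are indexed by sample metadata (fov_name, track_id, t).
+features = dataset["features"]
+print(features.shape)
+```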
-### Workflow demo +### Demo +- [DynaCLR demos](examples/DynaCLR/README.md) - Example test dataset, model checkpoint, and predictions can be found [here](https://public.czbiohub.org/comp.micro/viscy/DynaCLR_demo/). @@ -159,7 +143,6 @@ for representation learning on [arXiv](https://arxiv.org/abs/2410.11281). - See tutorial on exploration of learned embeddings with napari-iohub [here](https://github.com/czbiohub-sf/napari-iohub/wiki/View-tracked-cells-and-their-associated-predictions/). -![DynaCLR schematic](https://github.com/mehta-lab/VisCy/blob/9eaab7eca50d684d8a473ad9da089aeab0e8f6a0/docs/figures/dynaCLR_schematic.png?raw=true) ## Installation diff --git a/applications/benchmarking/DynaCLR/ImageNet/config.yml b/applications/benchmarking/DynaCLR/ImageNet/config.yml new file mode 100644 index 000000000..630ec8f99 --- /dev/null +++ b/applications/benchmarking/DynaCLR/ImageNet/config.yml @@ -0,0 +1,45 @@ +datamodule: + batch_size: 32 + final_yx_patch_size: + - 160 + - 160 + include_fov_names: null + include_track_ids: null + initial_yx_patch_size: + - 160 + - 160 + normalizations: + - class_path: viscy.transforms.ScaleIntensityRangePercentilesd + init_args: + b_max: 1.0 + b_min: 0.0 + keys: + - RFP + lower: 50 + upper: 99 + num_workers: 60 + source_channel: + - RFP + z_range: + - 15 + - 45 +embedding: + pca_kwargs: + n_components: 8 + phate_kwargs: + decay: 40 + knn: 5 + n_components: 2 + n_jobs: -1 + random_state: 42 +execution: + overwrite: true + save_config: true + show_config: true +model: + channel_reduction_methods: + RFP: max +paths: + data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr + output_path: /home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/ImageNet/20240204_A549_DENV_ZIKV_sensor_only_imagenet.zarr + tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr diff --git a/applications/benchmarking/DynaCLR/ImageNet/imagenet_embeddings.py b/applications/benchmarking/DynaCLR/ImageNet/imagenet_embeddings.py new file mode 100644 index 000000000..168e97cc7 --- /dev/null +++ b/applications/benchmarking/DynaCLR/ImageNet/imagenet_embeddings.py @@ -0,0 +1,388 @@ +""" +Generate embeddings using a pre-trained ImageNet model and save them to a zarr store +using VisCy Trainer and EmbeddingWriter callback. +""" + +import importlib +import logging +import os +from pathlib import Path +from typing import Dict, List, Literal, Optional + +import click +import timm +import torch +import yaml +from lightning.pytorch import LightningModule + +from viscy.data.triplet import TripletDataModule +from viscy.representation.embedding_writer import EmbeddingWriter +from viscy.trainer import VisCyTrainer + +logger = logging.getLogger(__name__) + + +class ImageNetModule(LightningModule): + def __init__( + self, + model_name: str = "convnext_tiny", + channel_reduction_methods: Optional[ + Dict[str, Literal["middle_slice", "mean", "max"]] + ] = None, + channel_names: Optional[List[str]] = None, + ): + """Initialize the ImageNet module. 
+ + Args: + model_name: Name of the pre-trained ImageNet model to use + channel_reduction_methods: Dict mapping channel names to reduction methods: + - "middle_slice": Take the middle slice along the depth dimension + - "mean": Average across the depth dimension + - "max": Take the maximum value across the depth dimension + channel_names: List of channel names corresponding to the input channels + """ + super().__init__() + self.channel_reduction_methods = channel_reduction_methods or {} + self.channel_names = channel_names or [] + + try: + torch.set_float32_matmul_precision("high") + self.model = timm.create_model(model_name, pretrained=True) + self.model.eval() + except ImportError: + raise ImportError("Please install the timm library: " "pip install timm") + + def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor: + """Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods. + + Args: + x: 5D input tensor + + Returns: + 4D tensor after applying reduction methods + """ + if x.dim() != 5: + return x + + B, C, D, H, W = x.shape + result = torch.zeros((B, C, H, W), device=x.device) + + # Process all channels at once for each reduction method to minimize loops + middle_slice_indices = [] + mean_indices = [] + max_indices = [] + + # Group channels by reduction method + for c in range(C): + channel_name = ( + self.channel_names[c] if c < len(self.channel_names) else f"channel_{c}" + ) + method = self.channel_reduction_methods.get(channel_name, "middle_slice") + + if method == "mean": + mean_indices.append(c) + elif method == "max": + max_indices.append(c) + else: # Default to middle_slice for any unknown method + middle_slice_indices.append(c) + + # Apply middle_slice reduction to all relevant channels at once + if middle_slice_indices: + indices = torch.tensor(middle_slice_indices, device=x.device) + result[:, indices] = x[:, indices, D // 2] + + # Apply mean reduction to all relevant channels at once + if mean_indices: + indices = torch.tensor(mean_indices, device=x.device) + result[:, indices] = x[:, indices].mean(dim=2) + + # Apply max reduction to all relevant channels at once + if max_indices: + indices = torch.tensor(max_indices, device=x.device) + result[:, indices] = x[:, indices].max(dim=2)[0] + + return result + + def _convert_to_rgb(self, x: torch.Tensor) -> torch.Tensor: + """Convert input tensor to 3-channel RGB format as needed. 
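+
+        For 2-channel input, each channel is min-max normalized to [0, 1] and
+        mapped to R and G respectively, with B set to their mean, so both
+        channels remain visible to an RGB-pretrained backbone.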
+ + Args: + x: Input tensor with 1, 2, or 3+ channels + + Returns: + 3-channel tensor suitable for ImageNet models + """ + if x.shape[1] == 3: + return x + elif x.shape[1] == 1: + # Convert to RGB by repeating the channel 3 times + return x.repeat(1, 3, 1, 1) + elif x.shape[1] == 2: + # Normalize each channel independently to handle different scales + B, _, H, W = x.shape + x_3ch = torch.zeros((B, 3, H, W), device=x.device, dtype=x.dtype) + + # Normalize each channel to 0-1 range + ch0 = x[:, 0:1] + ch1 = x[:, 1:2] + + ch0_min = ch0.reshape(B, -1).min(dim=1, keepdim=True)[0].reshape(B, 1, 1, 1) + ch0_max = ch0.reshape(B, -1).max(dim=1, keepdim=True)[0].reshape(B, 1, 1, 1) + ch0_range = ch0_max - ch0_min + 1e-7 # Add epsilon for numerical stability + ch0_norm = (ch0 - ch0_min) / ch0_range + + ch1_min = ch1.reshape(B, -1).min(dim=1, keepdim=True)[0].reshape(B, 1, 1, 1) + ch1_max = ch1.reshape(B, -1).max(dim=1, keepdim=True)[0].reshape(B, 1, 1, 1) + ch1_range = ch1_max - ch1_min + 1e-7 # Add epsilon for numerical stability + ch1_norm = (ch1 - ch1_min) / ch1_range + + # Create blended RGB channels - map each normalized channel to different colors + x_3ch[:, 0] = ch0_norm.squeeze(1) # R channel from first input + x_3ch[:, 1] = ch1_norm.squeeze(1) # G channel from second input + x_3ch[:, 2] = 0.5 * ( + ch0_norm.squeeze(1) + ch1_norm.squeeze(1) + ) # B channel as blend + + return x_3ch + else: + # For more than 3 channels, use the first 3 + return x[:, :3] + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + """Extract features from the input images. + + Returns: + Dictionary with features, properly shaped empty projections tensor, and index information + """ + x = batch["anchor"] + + # Handle 5D input (B, C, D, H, W) using configured reduction methods + if x.dim() == 5: + x = self._reduce_5d_input(x) + + # Convert input to RGB format + x = self._convert_to_rgb(x) + + # Get embeddings + with torch.no_grad(): + features = self.model.forward_features(x) + + # Average pooling to get feature vector + if features.dim() > 2: + features = features.mean(dim=[2, 3]) + + # Return features and empty projections with correct batch dimension + return { + "features": features, + "projections": torch.zeros((features.shape[0], 0), device=features.device), + "index": batch["index"], + } + + +def load_config(config_file): + """Load configuration from a YAML file.""" + with open(config_file, "r") as f: + config = yaml.safe_load(f) + return config + + +def load_normalization_from_config(norm_config): + """Load a normalization transform from a configuration dictionary.""" + class_path = norm_config["class_path"] + init_args = norm_config.get("init_args", {}) + + # Split module and class name + module_path, class_name = class_path.rsplit(".", 1) + + # Import the module + module = importlib.import_module(module_path) + + # Get the class + transform_class = getattr(module, class_name) + + # Instantiate the transform + return transform_class(**init_args) + + +@click.command() +@click.option( + "--config", + "-c", + type=click.Path(exists=True), + required=True, + help="Path to YAML configuration file", +) +@click.option( + "--model", + "-m", + type=str, + default="convnext_tiny", + help="Name of the pre-trained ImageNet model to use", +) +def main(config, model): + """Extract ImageNet embeddings and save to zarr format using VisCy Trainer.""" + # Configure logging + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + # Load config file + cfg = load_config(config) + 
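+    # Expected top-level config sections (see config.yml in this directory):
+    #   paths      - data_path, tracks_path, output_path (required)
+    #   datamodule - TripletDataModule arguments; `normalizations` is required
+    #   model      - optional per-channel channel_reduction_methods
+    #   embedding  - optional phate_kwargs / pca_kwargs for EmbeddingWriter
+    #   execution  - optional overwrite / save_config / show_config flags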
logger.info(f"Loaded configuration from {config}") + + # Prepare datamodule parameters + dm_params = {} + + # Add data and tracks paths from the paths section + if "paths" not in cfg: + raise ValueError("Configuration must contain a 'paths' section") + + if "data_path" not in cfg["paths"]: + raise ValueError( + "Data path is required in the configuration file (paths.data_path)" + ) + dm_params["data_path"] = cfg["paths"]["data_path"] + + if "tracks_path" not in cfg["paths"]: + raise ValueError( + "Tracks path is required in the configuration file (paths.tracks_path)" + ) + dm_params["tracks_path"] = cfg["paths"]["tracks_path"] + + # Add datamodule parameters + if "datamodule" not in cfg: + raise ValueError("Configuration must contain a 'datamodule' section") + + # Prepare normalizations + if ( + "normalizations" not in cfg["datamodule"] + or not cfg["datamodule"]["normalizations"] + ): + raise ValueError( + "Normalizations are required in the configuration file (datamodule.normalizations)" + ) + + norm_configs = cfg["datamodule"]["normalizations"] + normalizations = [load_normalization_from_config(norm) for norm in norm_configs] + dm_params["normalizations"] = normalizations + + # Copy all other datamodule parameters + for param, value in cfg["datamodule"].items(): + if param != "normalizations": + # Handle patch sizes + if param == "patch_size": + dm_params["initial_yx_patch_size"] = value + dm_params["final_yx_patch_size"] = value + else: + dm_params[param] = value + + # Set up the data module + logger.info("Setting up data module") + dm = TripletDataModule(**dm_params) + + # Get model parameters for handling 5D inputs + channel_reduction_methods = {} + + if "model" in cfg and "channel_reduction_methods" in cfg["model"]: + channel_reduction_methods = cfg["model"]["channel_reduction_methods"] + + # Initialize ImageNet model with reduction settings + logger.info(f"Loading ImageNet model: {model}") + model_module = ImageNetModule( + model_name=model, + channel_reduction_methods=channel_reduction_methods, + channel_names=dm_params.get("source_channel", []), + ) + + # Get dimensionality reduction parameters from config + phate_kwargs = None + pca_kwargs = None + + if "embedding" in cfg: + # Check for both capitalization variants and normalize + if "phate_kwargs" in cfg["embedding"]: + phate_kwargs = cfg["embedding"]["phate_kwargs"] + + if "umap_kwargs" in cfg["embedding"]: + umap_kwargs = cfg["embedding"]["umap_kwargs"] + + if "pca_kwargs" in cfg["embedding"]: + pca_kwargs = cfg["embedding"]["pca_kwargs"] + + # Check if output path exists and should be overwritten + if "output_path" not in cfg["paths"]: + raise ValueError( + "Output path is required in the configuration file (paths.output_path)" + ) + + output_path = Path(cfg["paths"]["output_path"]) + output_dir = output_path.parent + output_dir.mkdir(parents=True, exist_ok=True) + + overwrite = False + if "execution" in cfg and "overwrite" in cfg["execution"]: + overwrite = cfg["execution"]["overwrite"] + elif output_path.exists(): + logger.warning(f"Output path {output_path} already exists, will overwrite") + overwrite = True + + # Set up EmbeddingWriter callback + embedding_writer = EmbeddingWriter( + output_path=output_path, + phate_kwargs=phate_kwargs, + pca_kwargs=pca_kwargs, + overwrite=overwrite, + ) + + # Set up and run VisCy trainer + logger.info("Setting up VisCy trainer") + trainer = VisCyTrainer( + accelerator="gpu" if torch.cuda.is_available() else "cpu", + devices=1, + callbacks=[embedding_writer], + inference_mode=True, + ) + 
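+    # EmbeddingWriter collects the dicts returned by predict_step
+    # ({"features", "projections", "index"}) across batches and writes them,
+    # with the configured PHATE/PCA reductions, to the output zarr store.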
+ logger.info(f"Running prediction and saving to {output_path}") + trainer.predict(model_module, datamodule=dm) + + # Save configuration if requested + save_config_flag = True + show_config_flag = True + + if "execution" in cfg: + if "save_config" in cfg["execution"]: + save_config_flag = cfg["execution"]["save_config"] + if "show_config" in cfg["execution"]: + show_config_flag = cfg["execution"]["show_config"] + + # Save configuration if requested + if save_config_flag: + config_path = os.path.join(output_dir, "config.yml") + with open(config_path, "w") as f: + yaml.dump(cfg, f, default_flow_style=False) + logger.info(f"Configuration saved to {config_path}") + + # Display configuration if requested + if show_config_flag: + click.echo("\nConfiguration used:") + click.echo("-" * 40) + for key, value in cfg.items(): + click.echo(f"{key}:") + if isinstance(value, dict): + for subkey, subvalue in value.items(): + if isinstance(subvalue, list) and subkey == "normalizations": + click.echo(f" {subkey}:") + for norm in subvalue: + click.echo(f" - class_path: {norm['class_path']}") + click.echo(f" init_args: {norm['init_args']}") + else: + click.echo(f" {subkey}: {subvalue}") + else: + click.echo(f" {value}") + click.echo("-" * 40) + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml b/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml new file mode 100644 index 000000000..4139e0d08 --- /dev/null +++ b/applications/benchmarking/DynaCLR/OpenPhenom/config_template.yml @@ -0,0 +1,66 @@ +# OpenPhenom Embeddings Configuration + +# Paths section +paths: + data_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr + tracks_path: /hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr + output_path: "/home/eduardo.hirata/repos/viscy/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_sec61b_n_phase_3.zarr" + +# Model configuration +model: + # Channel-specific 5D input handling methods + # Options: "middle_slice", "mean", "max" + # Default is "middle_slice" if not specified + channel_reduction_methods: + "Phase3D": "middle_slice" # For phase contrast, middle slice often works well + "raw GFP EX488 EM525-45": "max" + +# Data module configuration +datamodule: + source_channel: + - Phase3D + - "raw GFP EX488 EM525-45" + z_range: [25, 40] + batch_size: 32 + num_workers: 10 + initial_yx_patch_size: [192, 192] + final_yx_patch_size: [192, 192] + predict_cells: true + include_fov_names: + - "/C/2/000000" + - "/C/2/000000" + - "/C/2/000000" + - "/C/2/000000" + include_track_ids: [33,60,57,65] + normalizations: + - class_path: viscy.transforms.ScaleIntensityRangePercentilesd + init_args: + keys: ["Phase3D"] + lower: 50 + upper: 99 + b_min: 0.0 + b_max: 1.0 + - class_path: viscy.transforms.ScaleIntensityRangePercentilesd + init_args: + keys: ["raw GFP EX488 EM525-45"] + lower: 50 + upper: 99 + b_min: 0.0 + b_max: 1.0 + +# Embedding parameters +embedding: + phate_kwargs: + n_components: 2 + knn: 5 + decay: 40 + n_jobs: -1 + random_state: 42 + pca_kwargs: + n_components: 2 + +# Execution configuration +execution: + overwrite: false + save_config: true + show_config: true \ No newline at end of file diff --git a/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py 
b/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py new file mode 100644 index 000000000..510d8e132 --- /dev/null +++ b/applications/benchmarking/DynaCLR/OpenPhenom/openphenom_embeddings.py @@ -0,0 +1,330 @@ +""" +Generate embeddings using the OpenPhenom model and save them to a zarr store +using VisCy Trainer and EmbeddingWriter callback. +""" + +import importlib +import logging +import os +from pathlib import Path +from typing import Dict, List, Literal, Optional + +import click +import torch +import yaml +from lightning.pytorch import LightningModule +from transformers import AutoModel + +from viscy.data.triplet import TripletDataModule +from viscy.representation.embedding_writer import EmbeddingWriter +from viscy.trainer import VisCyTrainer + + +class OpenPhenomModule(LightningModule): + def __init__( + self, + channel_reduction_methods: Optional[ + Dict[str, Literal["middle_slice", "mean", "max"]] + ] = None, + channel_names: Optional[List[str]] = None, + ): + """Initialize the OpenPhenom module. + + Parameters + ---------- + channel_reduction_methods : dict, optional + Dictionary mapping channel names to reduction methods: + - "middle_slice": Take the middle slice along the depth dimension + - "mean": Average across the depth dimension + - "max": Take the maximum value across the depth dimension + channel_names : list of str, optional + List of channel names corresponding to the input channels + + Notes + ----- + The module uses the OpenPhenom model from HuggingFace for generating embeddings. + """ + super().__init__() + + self.channel_reduction_methods = channel_reduction_methods or {} + self.channel_names = channel_names or [] + + try: + torch.set_float32_matmul_precision("high") + self.model = AutoModel.from_pretrained( + "recursionpharma/OpenPhenom", trust_remote_code=True + ) + self.model.eval() + except ImportError: + raise ImportError( + "Please install the OpenPhenom dependencies: " + "pip install transformers" + ) + + def on_predict_start(self): + # Move model to GPU when prediction starts + self.model.to(self.device) + + def _reduce_5d_input(self, x: torch.Tensor) -> torch.Tensor: + """Reduce 5D input (B, C, D, H, W) to 4D (B, C, H, W) using specified methods. + + Args: + x: 5D input tensor + + Returns: + 4D tensor after applying reduction methods + """ + if x.dim() != 5: + return x + + B, C, D, H, W = x.shape + result = torch.zeros((B, C, H, W), device=x.device) + + # Apply reduction method for each channel + for c in range(C): + channel_name = ( + self.channel_names[c] if c < len(self.channel_names) else f"channel_{c}" + ) + # Default to middle slice if not specified + method = self.channel_reduction_methods.get(channel_name, "middle_slice") + + if method == "middle_slice": + result[:, c] = x[:, c, D // 2] + elif method == "mean": + result[:, c] = x[:, c].mean(dim=1) + elif method == "max": + result[:, c] = x[:, c].max(dim=1)[0] + else: + # Fallback to middle slice for unknown methods + result[:, c] = x[:, c, D // 2] + + return result + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + """Extract features from the input images. 
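+
+        5D inputs are first reduced to (B, C, H, W) per the configured
+        channel reduction methods, then rescaled to uint8, the input format
+        the OpenPhenom model expects.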
+ + Returns: + Dictionary with features, projections (None), and index information + """ + x = batch["anchor"] + + # OpenPhenom expects [B, C, H, W] but our data might be [B, C, D, H, W] + # If 5D input, handle according to specified reduction methods + if x.dim() == 5: + x = self._reduce_5d_input(x) + + # Convert to uint8 as OpenPhenom expects uint8 inputs + if x.dtype != torch.uint8: + x = ( + ((x - x.min()) / (x.max() - x.min()) * 255) + .clamp(0, 255) + .to(torch.uint8) + ) + + # Get embeddings + self.model.return_channelwise_embeddings = False + features = self.model.predict(x) + # Create empty projections tensor with same batch size as features + # This ensures the EmbeddingWriter can process it + projections = torch.zeros((features.shape[0], 0), device=features.device) + + return { + "features": features, + "projections": projections, + "index": batch["index"], + } + + +def load_config(config_file): + """Load configuration from a YAML file.""" + with open(config_file, "r") as f: + config = yaml.safe_load(f) + return config + + +def load_normalization_from_config(norm_config): + """Load a normalization transform from a configuration dictionary.""" + class_path = norm_config["class_path"] + init_args = norm_config.get("init_args", {}) + + # Split module and class name + module_path, class_name = class_path.rsplit(".", 1) + + # Import the module + module = importlib.import_module(module_path) + + # Get the class + transform_class = getattr(module, class_name) + + # Instantiate the transform + return transform_class(**init_args) + + +@click.command() +@click.option( + "--config", + "-c", + type=click.Path(exists=True), + required=True, + help="Path to YAML configuration file", +) +def main(config): + """Extract OpenPhenom embeddings and save to zarr format using VisCy Trainer.""" + # Configure logging + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + # Load config file + cfg = load_config(config) + logger.info(f"Loaded configuration from {config}") + + # Prepare datamodule parameters + dm_params = {} + + # Add data and tracks paths from the paths section + if "paths" not in cfg: + raise ValueError("Configuration must contain a 'paths' section") + + if "data_path" not in cfg["paths"]: + raise ValueError( + "Data path is required in the configuration file (paths.data_path)" + ) + dm_params["data_path"] = cfg["paths"]["data_path"] + + if "tracks_path" not in cfg["paths"]: + raise ValueError( + "Tracks path is required in the configuration file (paths.tracks_path)" + ) + dm_params["tracks_path"] = cfg["paths"]["tracks_path"] + + # Add datamodule parameters + if "datamodule" not in cfg: + raise ValueError("Configuration must contain a 'datamodule' section") + + # Prepare normalizations + if ( + "normalizations" not in cfg["datamodule"] + or not cfg["datamodule"]["normalizations"] + ): + raise ValueError( + "Normalizations are required in the configuration file (datamodule.normalizations)" + ) + + norm_configs = cfg["datamodule"]["normalizations"] + normalizations = [load_normalization_from_config(norm) for norm in norm_configs] + dm_params["normalizations"] = normalizations + + # Copy all other datamodule parameters + for param, value in cfg["datamodule"].items(): + if param != "normalizations": + # Handle patch sizes + if param == "patch_size": + dm_params["initial_yx_patch_size"] = value + dm_params["final_yx_patch_size"] = value + else: + dm_params[param] = value + + # Set up the data module + logger.info("Setting up data module") + dm = 
TripletDataModule(**dm_params) + + # Get model parameters for handling 5D inputs + channel_reduction_methods = {} + + if "model" in cfg and "channel_reduction_methods" in cfg["model"]: + channel_reduction_methods = cfg["model"]["channel_reduction_methods"] + + # Initialize OpenPhenom model with reduction settings + logger.info("Loading OpenPhenom model") + model = OpenPhenomModule( + channel_reduction_methods=channel_reduction_methods, + channel_names=dm_params.get("source_channel", []), + ) + + # Get dimensionality reduction parameters from config + phate_kwargs = None + pca_kwargs = None + + if "embedding" in cfg: + if "phate_kwargs" in cfg["embedding"]: + phate_kwargs = cfg["embedding"]["phate_kwargs"] + if "pca_kwargs" in cfg["embedding"]: + pca_kwargs = cfg["embedding"]["pca_kwargs"] + # Check if output path exists and should be overwritten + if "output_path" not in cfg["paths"]: + raise ValueError( + "Output path is required in the configuration file (paths.output_path)" + ) + + output_path = Path(cfg["paths"]["output_path"]) + output_dir = output_path.parent + output_dir.mkdir(parents=True, exist_ok=True) + + overwrite = False + if "execution" in cfg and "overwrite" in cfg["execution"]: + overwrite = cfg["execution"]["overwrite"] + elif output_path.exists(): + logger.warning(f"Output path {output_path} already exists, will overwrite") + overwrite = True + + # Set up EmbeddingWriter callback + embedding_writer = EmbeddingWriter( + output_path=output_path, + phate_kwargs=phate_kwargs, + pca_kwargs=pca_kwargs, + overwrite=overwrite, + ) + + # Set up and run VisCy trainer + logger.info("Setting up VisCy trainer") + trainer = VisCyTrainer( + accelerator="gpu" if torch.cuda.is_available() else "cpu", + devices=1, + callbacks=[embedding_writer], + inference_mode=True, + ) + + logger.info(f"Running prediction and saving to {output_path}") + trainer.predict(model, datamodule=dm) + + # Save configuration if requested + save_config_flag = True + show_config_flag = True + + if "execution" in cfg: + if "save_config" in cfg["execution"]: + save_config_flag = cfg["execution"]["save_config"] + if "show_config" in cfg["execution"]: + show_config_flag = cfg["execution"]["show_config"] + + # Save configuration if requested + if save_config_flag: + config_path = os.path.join(output_dir, "config.yml") + with open(config_path, "w") as f: + yaml.dump(cfg, f, default_flow_style=False) + logger.info(f"Configuration saved to {config_path}") + + # Display configuration if requested + if show_config_flag: + click.echo("\nConfiguration used:") + click.echo("-" * 40) + for key, value in cfg.items(): + click.echo(f"{key}:") + if isinstance(value, dict): + for subkey, subvalue in value.items(): + if isinstance(subvalue, list) and subkey == "normalizations": + click.echo(f" {subkey}:") + for norm in subvalue: + click.echo(f" - class_path: {norm['class_path']}") + click.echo(f" init_args: {norm['init_args']}") + else: + click.echo(f" {subkey}: {subvalue}") + else: + click.echo(f" {value}") + click.echo("-" * 40) + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/applications/contrastive_phenotyping/contrastive_cli/fit_ctc_mps.yml b/applications/contrastive_phenotyping/contrastive_cli/fit_ctc_mps.yml deleted file mode 100644 index 47def3ae9..000000000 --- a/applications/contrastive_phenotyping/contrastive_cli/fit_ctc_mps.yml +++ /dev/null @@ -1,117 +0,0 @@ -# See help here on how to configure hyper-parameters with config files: -# 
https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced.html -seed_everything: 42 -trainer: - accelerator: gpu - strategy: auto - devices: 1 - num_nodes: 1 - precision: 32-true - logger: - class_path: lightning.pytorch.loggers.TensorBoardLogger - # Nesting the logger config like this is equivalent to - # supplying the following argument to `lightning.pytorch.Trainer`: - # logger=TensorBoardLogger( - # "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations", - # log_graph=True, - # version="vanilla", - # ) - init_args: - save_dir: /Users/ziwen.liu/Projects/test-time - # this is the name of the experiment. - # The logs will be saved in `save_dir/lightning_logs/version` - version: time_interval_1 - log_graph: True - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: loss/val - every_n_epochs: 1 - save_top_k: 4 - save_last: true - fast_dev_run: false - max_epochs: 100 - log_every_n_steps: 10 - enable_checkpointing: true - inference_mode: true - use_distributed_sampler: true - # synchronize batchnorm parameters across multiple GPUs. - # important for contrastive learning to normalize the tensors across the whole batch. - sync_batchnorm: true -model: - class_path: viscy.representation.engine.ContrastiveModule - init_args: - encoder: - class_path: viscy.representation.contrastive.ContrastiveEncoder - init_args: - backbone: convnext_tiny - in_channels: 1 - in_stack_depth: 1 - stem_kernel_size: [1, 4, 4] - stem_stride: [1, 4, 4] - embedding_dim: 768 - projection_dim: 32 - drop_path_rate: 0.0 - loss_function: - class_path: torch.nn.TripletMarginLoss - init_args: - margin: 0.5 - lr: 0.0002 - log_batches_per_epoch: 3 - log_samples_per_batch: 2 - example_input_array_shape: [1, 1, 1, 128, 128] -data: - class_path: viscy.data.triplet.TripletDataModule - init_args: - data_path: /Users/ziwen.liu/Downloads/Hela_CTC.zarr - tracks_path: /Users/ziwen.liu/Downloads/Hela_CTC.zarr - source_channel: - - DIC - z_range: [0, 1] - batch_size: 16 - num_workers: 4 - initial_yx_patch_size: [256, 256] - final_yx_patch_size: [128, 128] - time_interval: 1 - normalizations: - - class_path: viscy.transforms.NormalizeSampled - init_args: - keys: [DIC] - level: fov_statistics - subtrahend: mean - divisor: std - augmentations: - - class_path: viscy.transforms.RandAffined - init_args: - keys: [DIC] - prob: 0.8 - scale_range: [0, 0.2, 0.2] - rotate_range: [3.14, 0.0, 0.0] - shear_range: [0.0, 0.01, 0.01] - padding_mode: zeros - - class_path: viscy.transforms.RandAdjustContrastd - init_args: - keys: [DIC] - prob: 0.5 - gamma: [0.8, 1.2] - - class_path: viscy.transforms.RandScaleIntensityd - init_args: - keys: [DIC] - prob: 0.5 - factors: 0.5 - - class_path: viscy.transforms.RandGaussianSmoothd - init_args: - keys: [DIC] - prob: 0.5 - sigma_x: [0.25, 0.75] - sigma_y: [0.25, 0.75] - sigma_z: [0.0, 0.0] - - class_path: viscy.transforms.RandGaussianNoised - init_args: - keys: [DIC] - prob: 0.5 - mean: 0.0 - std: 0.2 diff --git a/applications/contrastive_phenotyping/contrastive_cli/predict_ctc_mps.yml b/applications/contrastive_phenotyping/contrastive_cli/predict_ctc_mps.yml deleted file mode 100644 index d305e32aa..000000000 --- a/applications/contrastive_phenotyping/contrastive_cli/predict_ctc_mps.yml +++ /dev/null @@ -1,48 +0,0 @@ -seed_everything: 42 -trainer: - accelerator: gpu - strategy: auto - devices: auto - 
num_nodes: 1 - precision: 32-true - callbacks: - - class_path: viscy.representation.embedding_writer.EmbeddingWriter - init_args: - output_path: /Users/ziwen.liu/Projects/test-time/predict/time_interval_1.zarr - inference_mode: true -model: - class_path: viscy.representation.engine.ContrastiveModule - init_args: - encoder: - class_path: viscy.representation.contrastive.ContrastiveEncoder - init_args: - backbone: convnext_tiny - in_channels: 1 - in_stack_depth: 1 - stem_kernel_size: [1, 4, 4] - stem_stride: [1, 4, 4] - embedding_dim: 768 - projection_dim: 32 - drop_path_rate: 0.0 - example_input_array_shape: [1, 1, 1, 128, 128] -data: - class_path: viscy.data.triplet.TripletDataModule - init_args: - data_path: /Users/ziwen.liu/Downloads/Hela_CTC.zarr - tracks_path: /Users/ziwen.liu/Downloads/Hela_CTC.zarr - source_channel: DIC - z_range: [0, 1] - batch_size: 16 - num_workers: 4 - initial_yx_patch_size: [128, 128] - final_yx_patch_size: [128, 128] - time_interval: 1 - normalizations: - - class_path: viscy.transforms.NormalizeSampled - init_args: - keys: [DIC] - level: fov_statistics - subtrahend: mean - divisor: std -return_predictions: false -ckpt_path: /Users/ziwen.liu/Projects/test-time/lightning_logs/time_interval_1/checkpoints/last.ckpt diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py new file mode 100644 index 000000000..b79c7cbe7 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/ALFI_MSD_v2.py @@ -0,0 +1,227 @@ +# %% +from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation.distance import ( + compute_displacement, + compute_displacement_statistics, +) + +# Paths to datasets +feature_paths = { + "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr", + "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr", +} + +# Colors for different time intervals +interval_colors = { + "7 min interval": "blue", + "21 min interval": "red", +} + +# %% Compute MSD for each dataset +results = {} +raw_displacements = {} + +for label, path in feature_paths.items(): + print(f"\nProcessing {label}...") + embedding_dataset = read_embedding_dataset(Path(path)) + + # Compute displacements + displacements = compute_displacement( + embedding_dataset=embedding_dataset, + distance_metric="euclidean_squared", + ) + means, stds = compute_displacement_statistics(displacements) + results[label] = (means, stds) + raw_displacements[label] = displacements + + # Print some statistics + taus = sorted(means.keys()) + print(f" Number of different τ values: {len(taus)}") + print(f" τ range: {min(taus)} to {max(taus)}") + print(f" MSD at τ=1: {means[1]:.4f} ± {stds[1]:.4f}") + +# %% Plot MSD vs time (linear scale) +plt.figure(figsize=(10, 6)) + +# Plot each time interval +for interval_label, path in feature_paths.items(): + means, stds = results[interval_label] + + # Sort by tau for plotting + taus = sorted(means.keys()) + mean_values = [means[tau] for tau in taus] + std_values = [stds[tau] for tau in taus] + + plt.plot( + taus, + mean_values, + "-", + color=interval_colors[interval_label], + alpha=0.5, + zorder=1, + ) + plt.scatter( + taus, + mean_values, + color=interval_colors[interval_label], + s=20, + label=interval_label, + zorder=2, + ) + +plt.xlabel("Time Shift 
(τ)") +plt.ylabel("Mean Square Displacement") +plt.title("MSD vs Time Shift") +plt.grid(True, alpha=0.3) +plt.legend() +plt.tight_layout() +plt.show() + +# %% Plot MSD vs time (log-log scale with slopes) +plt.figure(figsize=(10, 6)) + +# Plot each time interval +for interval_label, path in feature_paths.items(): + means, stds = results[interval_label] + + # Sort by tau for plotting + taus = sorted(means.keys()) + mean_values = [means[tau] for tau in taus] + std_values = [stds[tau] for tau in taus] + + # Filter out non-positive values for log scale + valid_mask = np.array(mean_values) > 0 + valid_taus = np.array(taus)[valid_mask] + valid_means = np.array(mean_values)[valid_mask] + + # Calculate slopes for different regions + log_taus = np.log(valid_taus) + log_means = np.log(valid_means) + + # Early slope (first third of points) + n_points = len(log_taus) + early_end = n_points // 3 + early_slope, early_intercept = np.polyfit( + log_taus[:early_end], log_means[:early_end], 1 + ) + + # Late slope (last third of points) + late_start = 2 * (n_points // 3) + late_slope, late_intercept = np.polyfit( + log_taus[late_start:], log_means[late_start:], 1 + ) + + plt.plot( + valid_taus, + valid_means, + "-", + color=interval_colors[interval_label], + alpha=0.5, + zorder=1, + ) + plt.scatter( + valid_taus, + valid_means, + color=interval_colors[interval_label], + s=20, + label=f"{interval_label} (α_early={early_slope:.2f}, α_late={late_slope:.2f})", + zorder=2, + ) + + # Plot fitted lines for early and late regions + early_fit = np.exp(early_intercept + early_slope * log_taus[:early_end]) + late_fit = np.exp(late_intercept + late_slope * log_taus[late_start:]) + + plt.plot( + valid_taus[:early_end], + early_fit, + "--", + color=interval_colors[interval_label], + alpha=0.3, + zorder=1, + ) + plt.plot( + valid_taus[late_start:], + late_fit, + "--", + color=interval_colors[interval_label], + alpha=0.3, + zorder=1, + ) + +plt.xscale("log") +plt.yscale("log") +plt.xlabel("Time Shift (τ)") +plt.ylabel("Mean Square Displacement") +plt.title("MSD vs Time Shift (log-log)") +plt.grid(True, alpha=0.3, which="both") +plt.legend( + title="α = slope in log-log space", bbox_to_anchor=(1.05, 1), loc="upper left" +) +plt.tight_layout() +plt.show() + +# %% Plot slopes analysis +early_slopes = [] +late_slopes = [] +intervals = [] + +for interval_label in feature_paths.keys(): + means, _ = results[interval_label] + + # Calculate slopes + taus = np.array(sorted(means.keys())) + mean_values = np.array([means[tau] for tau in taus]) + valid_mask = mean_values > 0 + + if np.sum(valid_mask) > 3: # Need at least 4 points to calculate both slopes + log_taus = np.log(taus[valid_mask]) + log_means = np.log(mean_values[valid_mask]) + + # Calculate early and late slopes + n_points = len(log_taus) + early_end = n_points // 3 + late_start = 2 * (n_points // 3) + + early_slope, _ = np.polyfit(log_taus[:early_end], log_means[:early_end], 1) + late_slope, _ = np.polyfit(log_taus[late_start:], log_means[late_start:], 1) + + early_slopes.append(early_slope) + late_slopes.append(late_slope) + intervals.append(interval_label) + +# Create bar plot +plt.figure(figsize=(12, 6)) + +x = np.arange(len(intervals)) +width = 0.35 + +plt.bar(x - width / 2, early_slopes, width, label="Early slope", alpha=0.7) +plt.bar(x + width / 2, late_slopes, width, label="Late slope", alpha=0.7) + +# Add reference lines +plt.axhline(y=1, color="k", linestyle="--", alpha=0.3, label="Normal diffusion (α=1)") +plt.axhline(y=0, color="k", linestyle="-", alpha=0.2) + 
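+# MSD ~ tau^alpha, so the fitted log-log slope alpha separates sub-diffusion
+# (alpha < 1), normal diffusion (alpha = 1), and super-diffusion (alpha > 1)
+# of cell trajectories in the embedding space.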
+plt.xlabel("Time Interval") +plt.ylabel("Slope (α)") +plt.title("MSD Slopes by Time Interval") +plt.xticks(x, intervals, rotation=45) +plt.legend() + +# Add annotations for diffusion regimes +plt.text( + plt.xlim()[1] * 1.2, 1.5, "Super-diffusion", rotation=90, verticalalignment="center" +) +plt.text( + plt.xlim()[1] * 1.2, 0.5, "Sub-diffusion", rotation=90, verticalalignment="center" +) + +plt.grid(True, alpha=0.3) +plt.tight_layout() +plt.show() + +# %% diff --git a/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py b/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py deleted file mode 100644 index 595f283f7..000000000 --- a/applications/contrastive_phenotyping/evaluation/ALFI_displacement.py +++ /dev/null @@ -1,312 +0,0 @@ -# %% -from pathlib import Path -import matplotlib.pyplot as plt -import numpy as np -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation.distance import ( - calculate_normalized_euclidean_distance_cell, - compute_displacement, - compute_dynamic_range, - compute_rms_per_track, -) -from collections import defaultdict -from tabulate import tabulate - -import numpy as np -from sklearn.metrics.pairwise import cosine_similarity -from collections import OrderedDict - -# %% function - -# Removed redundant compute_displacement_mean_std_full function -# Removed redundant compute_dynamic_range and compute_rms_per_track functions - - -def plot_rms_histogram(rms_values, label, bins=30): - """ - Plot histogram of RMS values across tracks. - - Parameters: - rms_values : list - List of RMS values, one for each track. - label : str - Label for the dataset (used in the title). - bins : int, optional - Number of bins for the histogram. Default is 30. - - Returns: - None: Displays the histogram. - """ - plt.figure(figsize=(10, 6)) - plt.hist(rms_values, bins=bins, alpha=0.7, color="blue", edgecolor="black") - plt.title(f"Histogram of RMS Values Across Tracks ({label})", fontsize=16) - plt.xlabel("RMS of Time Derivative", fontsize=14) - plt.ylabel("Frequency", fontsize=14) - plt.grid(True) - plt.show() - - -def plot_displacement( - mean_displacement, std_displacement, label, metrics_no_track=None -): - """ - Plot embedding displacement over time with mean and standard deviation. - - Parameters: - mean_displacement : dict - Mean displacement for each tau. - std_displacement : dict - Standard deviation of displacement for each tau. - label : str - Label for the dataset. - metrics_no_track : dict, optional - Metrics for the "Classical Contrastive (No Tracking)" dataset to compare against. - - Returns: - None: Displays the plot. 
- """ - plt.figure(figsize=(10, 6)) - taus = list(mean_displacement.keys()) - mean_values = list(mean_displacement.values()) - std_values = list(std_displacement.values()) - - plt.plot(taus, mean_values, marker="o", label=f"{label}", color="green") - plt.fill_between( - taus, - np.array(mean_values) - np.array(std_values), - np.array(mean_values) + np.array(std_values), - color="green", - alpha=0.3, - label=f"Std Dev ({label})", - ) - - if metrics_no_track: - mean_values_no_track = list(metrics_no_track["mean_displacement"].values()) - std_values_no_track = list(metrics_no_track["std_displacement"].values()) - - plt.plot( - taus, - mean_values_no_track, - marker="o", - label="Classical Contrastive (No Tracking)", - color="blue", - ) - plt.fill_between( - taus, - np.array(mean_values_no_track) - np.array(std_values_no_track), - np.array(mean_values_no_track) + np.array(std_values_no_track), - color="blue", - alpha=0.3, - label="Std Dev (No Tracking)", - ) - - plt.xlabel("Time Shift (τ)", fontsize=14) - plt.ylabel("Euclidean Distance", fontsize=14) - plt.title(f"Embedding Displacement Over Time ({label})", fontsize=16) - plt.grid(True) - plt.legend(fontsize=12) - plt.show() - - -def plot_overlay_displacement(overlay_displacement_data): - """ - Plot embedding displacement over time for all datasets in one plot. - - Parameters: - overlay_displacement_data : dict - A dictionary containing mean displacement per tau for all datasets. - - Returns: - None: Displays the plot. - """ - plt.figure(figsize=(12, 8)) - for label, mean_displacement in overlay_displacement_data.items(): - taus = list(mean_displacement.keys()) - mean_values = list(mean_displacement.values()) - plt.plot(taus, mean_values, marker="o", label=label) - - plt.xlabel("Time Shift (τ)", fontsize=14) - plt.ylabel("Euclidean Distance", fontsize=14) - plt.title("Overlayed Embedding Displacement Over Time", fontsize=16) - plt.grid(True) - plt.legend(fontsize=12) - plt.show() - - -# %% hist stats -def plot_boxplot_rms_across_models(datasets_rms): - """ - Plot a boxplot for the distribution of RMS values across models. - - Parameters: - datasets_rms : dict - A dictionary where keys are dataset names and values are lists of RMS values. - - Returns: - None: Displays the boxplot. - """ - plt.figure(figsize=(12, 6)) - labels = list(datasets_rms.keys()) - data = list(datasets_rms.values()) - print(labels) - print(data) - # Plot the boxplot - plt.boxplot(data, tick_labels=labels, patch_artist=True, showmeans=True) - - plt.title( - "Distribution of RMS of Rate of Change of Embedding Across Models", fontsize=16 - ) - plt.ylabel("RMS of Time Derivative", fontsize=14) - plt.xticks(rotation=45, fontsize=12) - plt.grid(axis="y", linestyle="--", alpha=0.7) - plt.tight_layout() - plt.show() - - -def plot_histogram_absolute_differences(datasets_abs_diff): - """ - Plot histograms of absolute differences across embeddings for all models. - - Parameters: - datasets_abs_diff : dict - A dictionary where keys are dataset names and values are lists of absolute differences. - - Returns: - None: Displays the histograms. 
- """ - plt.figure(figsize=(12, 6)) - for label, abs_diff in datasets_abs_diff.items(): - plt.hist(abs_diff, bins=50, alpha=0.5, label=label, density=True) - - plt.title("Histograms of Absolute Differences Across Models", fontsize=16) - plt.xlabel("Absolute Difference", fontsize=14) - plt.ylabel("Density", fontsize=14) - plt.legend(fontsize=12) - plt.grid(alpha=0.7) - plt.tight_layout() - plt.show() - - -# %% Paths to datasets -feature_paths = { - "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_7mins.zarr", - "21 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_21mins.zarr", - "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_updated_28mins.zarr", - "56 min interval": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_56mins.zarr", - "Cell Aware": "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_cellaware.zarr", -} - -no_track_path = "/hpc/projects/organelle_phenotyping/ALFI_benchmarking/predictions_final/ALFI_opp_classical.zarr" - -# %% Process Datasets -max_tau = 69 -metrics = {} - -overlay_displacement_data = {} -datasets_rms = {} -datasets_abs_diff = {} - -# Process "No Tracking" dataset -features_path_no_track = Path(no_track_path) -embedding_dataset_no_track = read_embedding_dataset(features_path_no_track) - -mean_displacement_no_track, std_displacement_no_track = compute_displacement( - embedding_dataset_no_track, max_tau=max_tau, return_mean_std=True -) -dynamic_range_no_track = compute_dynamic_range(mean_displacement_no_track) -metrics["No Tracking"] = { - "dynamic_range": dynamic_range_no_track, - "mean_displacement": mean_displacement_no_track, - "std_displacement": std_displacement_no_track, -} - -overlay_displacement_data["No Tracking"] = mean_displacement_no_track - -print("\nProcessing No Tracking dataset...") -print(f"Dynamic Range for No Tracking: {dynamic_range_no_track}") - -plot_displacement(mean_displacement_no_track, std_displacement_no_track, "No Tracking") - -rms_values_no_track = compute_rms_per_track(embedding_dataset_no_track) -datasets_rms["No Tracking"] = rms_values_no_track - -print(f"Plotting histogram of RMS values for No Tracking dataset...") -plot_rms_histogram(rms_values_no_track, "No Tracking", bins=30) - -# Compute absolute differences for "No Tracking" -abs_diff_no_track = np.concatenate( - [ - np.linalg.norm( - np.diff(embedding_dataset_no_track["features"].values[indices], axis=0), - axis=-1, - ) - for indices in np.split( - np.arange(len(embedding_dataset_no_track["track_id"])), - np.where(np.diff(embedding_dataset_no_track["track_id"]) != 0)[0] + 1, - ) - ] -) -datasets_abs_diff["No Tracking"] = abs_diff_no_track - -# Process other datasets -for label, path in feature_paths.items(): - print(f"\nProcessing {label} dataset...") - - features_path = Path(path) - embedding_dataset = read_embedding_dataset(features_path) - - mean_displacement, std_displacement = compute_displacement( - embedding_dataset, max_tau=max_tau, return_mean_std=True - ) - dynamic_range = compute_dynamic_range(mean_displacement) - metrics[label] = { - "dynamic_range": dynamic_range, - "mean_displacement": mean_displacement, - "std_displacement": std_displacement, - } - - overlay_displacement_data[label] = mean_displacement - - print(f"Dynamic Range for {label}: {dynamic_range}") - - plot_displacement( - mean_displacement, - std_displacement, - label, - 
metrics_no_track=metrics.get("No Tracking", None), - ) - - rms_values = compute_rms_per_track(embedding_dataset) - datasets_rms[label] = rms_values - - print(f"Plotting histogram of RMS values for {label}...") - plot_rms_histogram(rms_values, label, bins=30) - - abs_diff = np.concatenate( - [ - np.linalg.norm( - np.diff(embedding_dataset["features"].values[indices], axis=0), axis=-1 - ) - for indices in np.split( - np.arange(len(embedding_dataset["track_id"])), - np.where(np.diff(embedding_dataset["track_id"]) != 0)[0] + 1, - ) - ] - ) - datasets_abs_diff[label] = abs_diff - -print("\nPlotting overlayed displacement for all datasets...") -plot_overlay_displacement(overlay_displacement_data) - -print("\nSummary of Dynamic Ranges:") -for label, metric in metrics.items(): - print(f"{label}: Dynamic Range = {metric['dynamic_range']}") - -print("\nPlotting RMS boxplot across models...") -plot_boxplot_rms_across_models(datasets_rms) - -print("\nPlotting histograms of absolute differences across models...") -plot_histogram_absolute_differences(datasets_abs_diff) - - -# %% diff --git a/applications/contrastive_phenotyping/evaluation/PC_vs_CF.py b/applications/contrastive_phenotyping/evaluation/PC_vs_CF.py deleted file mode 100644 index 66a459ddc..000000000 --- a/applications/contrastive_phenotyping/evaluation/PC_vs_CF.py +++ /dev/null @@ -1,449 +0,0 @@ -""" Script to compute the correlation between PCA and UMAP features and computed features -* finds the computed features best representing the PCA and UMAP components -* outputs a heatmap of the correlation between PCA and UMAP features and computed features -""" - -# %% -import os -import sys -from pathlib import Path - -sys.path.append("/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy") - -import matplotlib.pyplot as plt -import numpy as np -import seaborn as sns -from sklearn.decomposition import PCA - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation import dataset_of_tracks -from viscy.representation.evaluation.feature import ( - FeatureExtractor as FE, -) - -# %% -features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" -) -data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" -) -tracks_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" -) - -# %% - -source_channel = ["Phase3D", "RFP"] -z_range = (28, 43) -normalizations = None -# fov_name = "/B/4/5" -# track_id = 11 - -embedding_dataset = read_embedding_dataset(features_path) -embedding_dataset - -# load all unprojected features: -features = embedding_dataset["features"] - -# %% PCA analysis of the features - -pca = PCA(n_components=5) -pca_features = pca.fit_transform(features.values) -features = ( - features.assign_coords(PCA1=("sample", pca_features[:, 0])) - .assign_coords(PCA2=("sample", pca_features[:, 1])) - .assign_coords(PCA3=("sample", pca_features[:, 2])) - .assign_coords(PCA4=("sample", pca_features[:, 3])) - .assign_coords(PCA5=("sample", pca_features[:, 4])) - .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4", "PCA5"], append=True) -) - -# %% convert the xarray to dataframe structure and add columns for computed features -features_df = features.to_dataframe() -features_df = 
features_df.drop(columns=["features"]) -df = features_df.drop_duplicates() -features = df.reset_index(drop=True) - -features = features[features["fov_name"].str.startswith("/B/")] - -features["Phase Symmetry Score"] = np.nan -features["Fluor Symmetry Score"] = np.nan -features["Sensor Area"] = np.nan -features["Masked Sensor Intensity"] = np.nan -features["Entropy Phase"] = np.nan -features["Entropy Fluor"] = np.nan -features["Contrast Phase"] = np.nan -features["Dissimilarity Phase"] = np.nan -features["Homogeneity Phase"] = np.nan -features["Contrast Fluor"] = np.nan -features["Dissimilarity Fluor"] = np.nan -features["Homogeneity Fluor"] = np.nan -features["Phase IQR"] = np.nan -features["Fluor Mean Intensity"] = np.nan -features["Phase Standard Deviation"] = np.nan -features["Fluor Standard Deviation"] = np.nan -features["Phase radial profile"] = np.nan -features["Fluor radial profile"] = np.nan - -# %% compute the computed features and add them to the dataset - -fov_names_list = features["fov_name"].unique() -unique_fov_names = sorted(list(set(fov_names_list))) - - -for fov_name in unique_fov_names: - - unique_track_ids = features[features["fov_name"] == fov_name]["track_id"].unique() - unique_track_ids = list(set(unique_track_ids)) - - for track_id in unique_track_ids: - - prediction_dataset = dataset_of_tracks( - data_path, - tracks_path, - [fov_name], - [track_id], - source_channel=source_channel, - ) - - whole = np.stack([p["anchor"] for p in prediction_dataset]) - phase = whole[:, 0, 3] - fluor = np.max(whole[:, 1], axis=1) - - for t in range(phase.shape[0]): - # Compute Fourier descriptors for phase image - phase_descriptors = FE.compute_fourier_descriptors(phase[t]) - # Analyze symmetry of phase image - phase_symmetry_score = FE.analyze_symmetry(phase_descriptors) - - # Compute Fourier descriptors for fluor image - fluor_descriptors = FE.compute_fourier_descriptors(fluor[t]) - # Analyze symmetry of fluor image - fluor_symmetry_score = FE.analyze_symmetry(fluor_descriptors) - - # Compute area of sensor - masked_intensity, area = FE.compute_area(fluor[t]) - - # Compute higher frequency features using spectral entropy - entropy_phase = FE.compute_spectral_entropy(phase[t]) - entropy_fluor = FE.compute_spectral_entropy(fluor[t]) - - # Compute texture analysis using GLCM - contrast_phase, dissimilarity_phase, homogeneity_phase = ( - FE.compute_glcm_features(phase[t]) - ) - contrast_fluor, dissimilarity_fluor, homogeneity_fluor = ( - FE.compute_glcm_features(fluor[t]) - ) - - # Compute interqualtile range of pixel intensities - iqr = FE.compute_iqr(phase[t]) - - # Compute mean pixel intensity - fluor_mean_intensity = FE.compute_mean_intensity(fluor[t]) - - # Compute standard deviation of pixel intensities - phase_std_dev = FE.compute_std_dev(phase[t]) - fluor_std_dev = FE.compute_std_dev(fluor[t]) - - # Compute radial intensity gradient - phase_radial_profile = FE.compute_radial_intensity_gradient(phase[t]) - fluor_radial_profile = FE.compute_radial_intensity_gradient(fluor[t]) - - # update the features dataframe with the computed features - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Fluor Symmetry Score", - ] = fluor_symmetry_score - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Phase Symmetry Score", - ] = phase_symmetry_score - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] 
== t), - "Sensor Area", - ] = area - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Masked Sensor Intensity", - ] = masked_intensity - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Entropy Phase", - ] = entropy_phase - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Entropy Fluor", - ] = entropy_fluor - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Contrast Phase", - ] = contrast_phase - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Dissimilarity Phase", - ] = dissimilarity_phase - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Homogeneity Phase", - ] = homogeneity_phase - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Contrast Fluor", - ] = contrast_fluor - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Dissimilarity Fluor", - ] = dissimilarity_fluor - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Homogeneity Fluor", - ] = homogeneity_fluor - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Phase IQR", - ] = iqr - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Fluor Mean Intensity", - ] = fluor_mean_intensity - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Phase Standard Deviation", - ] = phase_std_dev - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Fluor Standard Deviation", - ] = fluor_std_dev - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Phase radial profile", - ] = phase_radial_profile - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Fluor radial profile", - ] = fluor_radial_profile - -# %% - -# Save the features dataframe to a CSV file -features.to_csv( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_twoChan.csv", - index=False, -) - -# # read the features dataframe from the CSV file -# features = pd.read_csv( -# "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_twoChan.csv" -# ) - -# remove the rows with missing values -features = features.dropna() - -# sub_features = features[features["Time"] == 20] -feature_df_removed = features.drop( - columns=["fov_name", "track_id", "t", "id", "parent_track_id", "parent_id"] -) - -# Compute correlation between PCA features and computed features -correlation = feature_df_removed.corr(method="spearman") - -# %% display PCA correlation as a heatmap - -plt.figure(figsize=(20, 5)) -sns.heatmap( - correlation.drop(columns=["PCA1", "PCA2", "PCA3", "PCA4", "PCA5"]).loc[ - "PCA1":"PCA5", : - ], - annot=True, - cmap="coolwarm", - 
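Every assignment above rebuilds the same three-way boolean mask, which rescans the whole dataframe once per feature per timepoint. A hypothetical refactoring sketch, reusing the script's own column names, indexes once on `(fov_name, track_id, t)` and writes several values in a single lookup:

```python
features = features.set_index(["fov_name", "track_id", "t"]).sort_index()

# Inside the loop: one indexed lookup instead of one boolean scan per feature.
features.loc[
    (fov_name, track_id, t),
    ["Contrast Phase", "Dissimilarity Phase", "Homogeneity Phase"],
] = [contrast_phase, dissimilarity_phase, homogeneity_phase]

features = features.reset_index()  # restore the flat layout before saving
```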
fmt=".2f", -) -plt.title("Correlation between PCA features and computed features") -plt.xlabel("Computed Features") -plt.ylabel("PCA Features") -plt.savefig( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/PC_vs_CF_2chan_pca.svg" -) - - -# %% plot PCA vs set of computed features - -set_features = [ - "Fluor radial profile", - "Homogeneity Phase", - "Phase IQR", - "Phase Standard Deviation", - "Sensor Area", - "Homogeneity Fluor", - "Contrast Fluor", - "Phase radial profile", -] - -plt.figure(figsize=(8, 10)) -sns.heatmap( - correlation.loc[set_features, "PCA1":"PCA5"], - annot=True, - cmap="coolwarm", - fmt=".2f", - vmin=-1, - vmax=1, -) - -plt.savefig( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/PC_vs_CF_2chan_pca_setfeatures.svg" -) - -# %% find the cell patches with the highest and lowest value in each feature - -def save_patches(fov_name, track_id): - data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" - ) - tracks_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" - ) - source_channel = ["Phase3D", "RFP"] - prediction_dataset = dataset_of_tracks( - data_path, - tracks_path, - [fov_name], - [track_id], - source_channel=source_channel, - ) - whole = np.stack([p["anchor"] for p in prediction_dataset]) - phase = whole[:, 0] - fluor = whole[:, 1] - out_dir = "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/data/computed_features/" - fov_name_out = fov_name.replace("/", "_") - np.save( - (os.path.join(out_dir, "phase" + fov_name_out + "_" + str(track_id) + ".npy")), - phase, - ) - np.save( - (os.path.join(out_dir, "fluor" + fov_name_out + "_" + str(track_id) + ".npy")), - fluor, - ) - - -# PCA1: Fluor radial profile -highest_fluor_radial_profile = features.loc[features["Fluor radial profile"].idxmax()] -print("Row with highest 'Fluor radial profile':") -# print(highest_fluor_radial_profile) -print( - f"fov_name: {highest_fluor_radial_profile['fov_name']}, time: {highest_fluor_radial_profile['t']}" -) -save_patches( - highest_fluor_radial_profile["fov_name"], highest_fluor_radial_profile["track_id"] -) - -lowest_fluor_radial_profile = features.loc[features["Fluor radial profile"].idxmin()] -print("Row with lowest 'Fluor radial profile':") -# print(lowest_fluor_radial_profile) -print( - f"fov_name: {lowest_fluor_radial_profile['fov_name']}, time: {lowest_fluor_radial_profile['t']}" -) -save_patches( - lowest_fluor_radial_profile["fov_name"], lowest_fluor_radial_profile["track_id"] -) - -# PCA2: Entropy phase -highest_entropy_phase = features.loc[features["Entropy Phase"].idxmax()] -print("Row with highest 'Entropy Phase':") -# print(highest_entropy_phase) -print( - f"fov_name: {highest_entropy_phase['fov_name']}, time: {highest_entropy_phase['t']}" -) -save_patches(highest_entropy_phase["fov_name"], highest_entropy_phase["track_id"]) - -lowest_entropy_phase = features.loc[features["Entropy Phase"].idxmin()] -print("Row with lowest 'Entropy Phase':") -# print(lowest_entropy_phase) -print( - f"fov_name: {lowest_entropy_phase['fov_name']}, time: {lowest_entropy_phase['t']}" -) -save_patches(lowest_entropy_phase["fov_name"], lowest_entropy_phase["track_id"]) - -# PCA3: Phase IQR -highest_phase_iqr = 
features.loc[features["Phase IQR"].idxmax()] -print("Row with highest 'Phase IQR':") -# print(highest_phase_iqr) -print(f"fov_name: {highest_phase_iqr['fov_name']}, time: {highest_phase_iqr['t']}") -save_patches(highest_phase_iqr["fov_name"], highest_phase_iqr["track_id"]) - -tenth_lowest_phase_iqr = features.nsmallest(10, "Phase IQR").iloc[9] -print("Row with tenth lowest 'Phase IQR':") -# print(tenth_lowest_phase_iqr) -print( - f"fov_name: {tenth_lowest_phase_iqr['fov_name']}, time: {tenth_lowest_phase_iqr['t']}" -) -save_patches(tenth_lowest_phase_iqr["fov_name"], tenth_lowest_phase_iqr["track_id"]) - -# PCA4: Phase Standard Deviation -highest_phase_std_dev = features.loc[features["Phase Standard Deviation"].idxmax()] -print("Row with highest 'Phase Standard Deviation':") -# print(highest_phase_std_dev) -print( - f"fov_name: {highest_phase_std_dev['fov_name']}, time: {highest_phase_std_dev['t']}" -) -save_patches(highest_phase_std_dev["fov_name"], highest_phase_std_dev["track_id"]) - -lowest_phase_std_dev = features.loc[features["Phase Standard Deviation"].idxmin()] -print("Row with lowest 'Phase Standard Deviation':") -# print(lowest_phase_std_dev) -print( - f"fov_name: {lowest_phase_std_dev['fov_name']}, time: {lowest_phase_std_dev['t']}" -) -save_patches(lowest_phase_std_dev["fov_name"], lowest_phase_std_dev["track_id"]) - -# PCA5: Sensor area -highest_sensor_area = features.loc[features["Sensor Area"].idxmax()] -print("Row with highest 'Sensor Area':") -# print(highest_sensor_area) -print(f"fov_name: {highest_sensor_area['fov_name']}, time: {highest_sensor_area['t']}") -save_patches(highest_sensor_area["fov_name"], highest_sensor_area["track_id"]) - -tenth_lowest_sensor_area = features.nsmallest(10, "Sensor Area").iloc[9] -print("Row with tenth lowest 'Sensor Area':") -# print(tenth_lowest_sensor_area) -print( - f"fov_name: {tenth_lowest_sensor_area['fov_name']}, time: {tenth_lowest_sensor_area['t']}" -) -save_patches(tenth_lowest_sensor_area["fov_name"], tenth_lowest_sensor_area["track_id"]) diff --git a/applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py b/applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py deleted file mode 100644 index 3d5049166..000000000 --- a/applications/contrastive_phenotyping/evaluation/PC_vs_CF_singleChannel.py +++ /dev/null @@ -1,245 +0,0 @@ -""" Script to compute the correlation between PCA and UMAP features and computed features -* finds the computed features best representing the PCA and UMAP components -* outputs a heatmap of the correlation between PCA and UMAP features and computed features -""" - -# %% -import sys -from pathlib import Path - -sys.path.append("/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy") - -import numpy as np -import pandas as pd -import plotly.express as px -from scipy.stats import spearmanr -from sklearn.decomposition import PCA - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation import dataset_of_tracks -from viscy.representation.evaluation.feature import ( - FeatureExtractor as FE, -) - -# %% -features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval_phase/predictions/epoch_186/1chan_128patch_186ckpt_Febtest.zarr" -) -data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" -) -tracks_path = Path( - 
"/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/track.zarr" -) - -# %% - -source_channel = ["Phase3D"] -z_range = (28, 43) -normalizations = None -# fov_name = "/B/4/5" -# track_id = 11 - -embedding_dataset = read_embedding_dataset(features_path) -embedding_dataset - -# load all unprojected features: -features = embedding_dataset["features"] - -# %% PCA analysis of the features - -pca = PCA(n_components=3) -embedding = pca.fit_transform(features.values) -features = ( - features.assign_coords(PCA1=("sample", embedding[:, 0])) - .assign_coords(PCA2=("sample", embedding[:, 1])) - .assign_coords(PCA3=("sample", embedding[:, 2])) - .set_index(sample=["PCA1", "PCA2", "PCA3"], append=True) -) - -# %% convert the xarray to dataframe structure and add columns for computed features -features_df = features.to_dataframe() -features_df = features_df.drop(columns=["features"]) -df = features_df.drop_duplicates() -features = df.reset_index(drop=True) - -features = features[features["fov_name"].str.startswith("/B/")] - -features["Phase Symmetry Score"] = np.nan -features["Entropy Phase"] = np.nan -features["Contrast Phase"] = np.nan -features["Dissimilarity Phase"] = np.nan -features["Homogeneity Phase"] = np.nan -features["Phase IQR"] = np.nan -features["Phase Standard Deviation"] = np.nan -features["Phase radial profile"] = np.nan - -# %% compute the computed features and add them to the dataset - -fov_names_list = features["fov_name"].unique() -unique_fov_names = sorted(list(set(fov_names_list))) - -for fov_name in unique_fov_names: - - unique_track_ids = features[features["fov_name"] == fov_name]["track_id"].unique() - unique_track_ids = list(set(unique_track_ids)) - - for track_id in unique_track_ids: - - # load the image patches - - prediction_dataset = dataset_of_tracks( - data_path, - tracks_path, - [fov_name], - [track_id], - source_channel=source_channel, - ) - - whole = np.stack([p["anchor"] for p in prediction_dataset]) - phase = whole[:, 0, 3] - - for t in range(phase.shape[0]): - # Compute Fourier descriptors for phase image - phase_descriptors = FE.compute_fourier_descriptors(phase[t]) - # Analyze symmetry of phase image - phase_symmetry_score = FE.analyze_symmetry(phase_descriptors) - - # Compute higher frequency features using spectral entropy - entropy_phase = FE.compute_spectral_entropy(phase[t]) - - # Compute texture analysis using GLCM - contrast_phase, dissimilarity_phase, homogeneity_phase = ( - FE.compute_glcm_features(phase[t]) - ) - - # Compute interqualtile range of pixel intensities - iqr = FE.compute_iqr(phase[t]) - - # Compute standard deviation of pixel intensities - phase_std_dev = FE.compute_std_dev(phase[t]) - - # Compute radial intensity gradient - phase_radial_profile = FE.compute_radial_intensity_gradient(phase[t]) - - # update the features dataframe with the computed features - - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Phase Symmetry Score", - ] = phase_symmetry_score - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Entropy Phase", - ] = entropy_phase - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Contrast Phase", - ] = contrast_phase - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Dissimilarity 
Phase", - ] = dissimilarity_phase - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Homogeneity Phase", - ] = homogeneity_phase - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Phase IQR", - ] = iqr - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Phase Standard Deviation", - ] = phase_std_dev - features.loc[ - (features["fov_name"] == fov_name) - & (features["track_id"] == track_id) - & (features["t"] == t), - "Phase radial profile", - ] = phase_radial_profile - -# %% -# Save the features dataframe to a CSV file -features.to_csv( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_oneChan.csv", - index=False, -) - -# read the csv file -# features = pd.read_csv( -# "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_oneChan.csv" -# ) - -# remove the rows with missing values -features = features.dropna() - -# sub_features = features[features["Time"] == 20] -feature_df_removed = features.drop( - columns=["fov_name", "track_id", "t", "id", "parent_track_id", "parent_id"] -) - -# Compute correlation between PCA features and computed features -correlation = feature_df_removed.corr(method="spearman") - -# %% calculate the p-value and draw volcano plot to show the significance of the correlation - -p_values = pd.DataFrame(index=correlation.index, columns=correlation.columns) - -for i in correlation.index: - for j in correlation.columns: - if i != j: - p_values.loc[i, j] = spearmanr( - feature_df_removed[i], feature_df_removed[j] - )[1] - -p_values = p_values.astype(float) - -# %% draw an interactive volcano plot showing -log10(p-value) vs fold change - -# Flatten the correlation and p-values matrices and create a DataFrame -correlation_flat = correlation.values.flatten() -p_values_flat = p_values.values.flatten() -# Create a list of feature names for the flattened correlation and p-values -feature_names = [f"{i}_{j}" for i in correlation.index for j in correlation.columns] - -data = pd.DataFrame( - { - "Correlation": correlation_flat, - "-log10(p-value)": -np.log10(p_values_flat), - "feature_names": feature_names, - } -) - -# Create an interactive scatter plot using Plotly -fig = px.scatter( - data, - x="Correlation", - y="-log10(p-value)", - title="Volcano plot showing significance of correlation", - labels={"Correlation": "Correlation", "-log10(p-value)": "-log10(p-value)"}, - opacity=0.5, - hover_data=["feature_names"], -) - -fig.show() -# Save the interactive volcano plot as an HTML file -fig.write_html( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/volcano_plot_1chan.html" -) - -# %% diff --git a/applications/contrastive_phenotyping/evaluation/PC_vs_computed_features.py b/applications/contrastive_phenotyping/evaluation/PC_vs_computed_features.py new file mode 100644 index 000000000..3d300ea7b --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/PC_vs_computed_features.py @@ -0,0 +1,160 @@ +""" Script to compute the correlation between PCA and UMAP features and computed features +* finds the computed features best representing the PCA and UMAP components +* outputs a heatmap of the correlation between PCA and UMAP features and computed 
features +""" + +# %% +from pathlib import Path +import matplotlib.pyplot as plt +import seaborn as sns +from compute_pca_features import compute_features, compute_correlation_and_save_png + +# %% for sensor features + +features_path = Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_94ckpt_rev6_2.zarr" +) +data_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" +) +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" +) + +source_channel = ["Phase3D", "RFP"] +seg_channel = ["Nuclei_prediction_labels"] +z_range = (28, 43) +fov_list = ["/A/3", "/B/3", "/B/4"] + +features_sensor = compute_features( + features_path, + data_path, + tracks_path, + source_channel, + seg_channel, + z_range, + fov_list, +) + +features_sensor.to_csv( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_allset_sensor.csv", + index=False, +) + +# features_sensor = pd.read_csv("/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_allset_sensor.csv") + +# take a subset without the 768 features +feature_columns = [f"feature_{i+1}" for i in range(768)] +features_subset_sensor = features_sensor.drop(columns=feature_columns) +correlation_sensor = compute_correlation_and_save_png( + features_subset_sensor, + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/PC_vs_CF_2chan_pca_sensor_allset.svg", +) + +# %% plot PCA vs set of computed features for sensor features + +set_features = [ + "Fluor Radial Intensity Gradient", + "Phase Interquartile Range", + "Perimeter area ratio", + "Fluor Interquartile Range", + "Phase Entropy", + "Fluor Zernike Moment Mean", +] + +plt.figure(figsize=(10, 8)) +sns.heatmap( + correlation_sensor.loc[set_features, "PCA1":"PCA6"], + annot=True, + cmap="coolwarm", + fmt=".2f", + annot_kws={"size": 24}, + vmin=-1, + vmax=1, +) +plt.xlabel("Computed Features", fontsize=24) +plt.ylabel("PCA Features", fontsize=24) +plt.xticks(fontsize=24) # Increase x-axis tick labels +plt.yticks(fontsize=24) # Increase y-axis tick labels + +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/PC_vs_CF_2chan_pca_allset_sensor_6features.svg" +) + +# plot the PCA1 vs PCA2 map for sensor features + +plt.figure(figsize=(10, 10)) +sns.scatterplot( + x="PCA1", + y="PCA2", + data=features_sensor, +) + + +# .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. .-.-. 
+# / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ / / \ \ +# '-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-'-' '-' + + +# %% for organelle features + +features_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/predictions/Soorya/timeAware_2chan_ntxent_192patch_91ckpt_rev7_GT.zarr" +) +data_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_ZIKV_DENV.zarr" +) +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr" +) + +source_channel = ["Phase3D", "raw GFP EX488 EM525-45"] +seg_channel = ["nuclei_prediction_labels_labels"] +z_range = (16, 21) +normalizations = None +fov_list = ["/B/2/000000", "/B/3/000000", "/C/2/000000"] + +features_organelle = compute_features( + features_path, + data_path, + tracks_path, + source_channel, + seg_channel, + z_range, + fov_list, +) + +# Save the features dataframe to a CSV file +features_organelle.to_csv( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_twoChan_organelle_multiwell.csv", + index=False, +) + +correlation_organelle = compute_correlation_and_save_png( + features_organelle, + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/PC_vs_CF_2chan_pca_organelle_multiwell.svg", +) + +# features_organelle = pd.read_csv("/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/features_twoChan_organelle_multiwell_refinedPCA.csv") + +# %% plot PCA vs set of computed features for organelle features + +plt.figure(figsize=(10, 8)) +sns.heatmap( + correlation_organelle.loc[set_features, "PCA1":"PCA6"], + annot=True, + cmap="coolwarm", + fmt=".2f", + annot_kws={"size": 24}, + vmin=-1, + vmax=1, +) +plt.xlabel("Computed Features", fontsize=24) +plt.ylabel("PCA Features", fontsize=24) +plt.xticks(fontsize=24) # Increase x-axis tick labels +plt.yticks(fontsize=24) # Increase y-axis tick labels +plt.savefig( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/cell_division/PC_vs_CF_2chan_pca_setfeatures_organelle_6features.svg" +) + +# %% diff --git a/applications/contrastive_phenotyping/evaluation/compute_pca_features.py b/applications/contrastive_phenotyping/evaluation/compute_pca_features.py new file mode 100644 index 000000000..a7f09b59f --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/compute_pca_features.py @@ -0,0 +1,382 @@ +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation import dataset_of_tracks +from viscy.representation.evaluation.feature import CellFeatures + + +## function to read the embedding dataset and return the features +def compute_PCA(features_path: Path): + """Compute PCA components from embedding features and 
combine with original features. + + This function reads an embedding dataset, standardizes the features, and computes + 8 principal components. The PCA components are then combined with the original + features in an xarray dataset structure. + + Parameters + ---------- + features_path : Path + Path to the embedding dataset containing the feature vectors. + + Returns + ------- + features: xarray dataset with PCA components as new coordinates + + """ + embedding_dataset = read_embedding_dataset(features_path) + embedding_dataset + + # load all unprojected features: + features = embedding_dataset["features"] + scaled_features = StandardScaler().fit_transform(features.values) + # PCA analysis of the features + + pca = PCA(n_components=8) + pca_features = pca.fit_transform(scaled_features) + features = ( + features.assign_coords(PCA1=("sample", pca_features[:, 0])) + .assign_coords(PCA2=("sample", pca_features[:, 1])) + .assign_coords(PCA3=("sample", pca_features[:, 2])) + .assign_coords(PCA4=("sample", pca_features[:, 3])) + .assign_coords(PCA5=("sample", pca_features[:, 4])) + .assign_coords(PCA6=("sample", pca_features[:, 5])) + .assign_coords(PCA7=("sample", pca_features[:, 6])) + .assign_coords(PCA8=("sample", pca_features[:, 7])) + .set_index( + sample=["PCA1", "PCA2", "PCA3", "PCA4", "PCA5", "PCA6", "PCA7", "PCA8"], + append=True, + ) + ) + + return features + + +def compute_features( + features_path: Path, + data_path: Path, + tracks_path: Path, + source_channel: list, + seg_channel: list, + z_range: tuple, + fov_list: list, +): + """Compute various cell features and combine them with PCA features. + + This function processes cell tracking data to compute various morphological and + intensity-based features for both phase and fluorescence channels, and combines + them with PCA features from an embedding dataset. + + Parameters + ---------- + features_path : Path + Path to the embedding dataset containing PCA features. + data_path : Path + Path to the raw data directory containing image data. + tracks_path : Path + Path to the directory containing tracking data in CSV format. + source_channel : list + List of source channels to process from the data. + seg_channel : list + List of segmentation channels to process from the data. + z_range : tuple + Tuple specifying the z-range to process (min_z, max_z). + fov_list : list + List of field of view names to process. + + Returns + ------- + pandas.DataFrame + DataFrame containing all computed features including: + - Basic features (mean intensity, std dev, kurtosis, etc.) 
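`compute_PCA` fixes `n_components=8` without reporting how much variance those components carry. A quick diagnostic sketch using the same scale-then-PCA pipeline (a check to run separately, not part of this diff):

```python
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

scaled = StandardScaler().fit_transform(features.values)
pca = PCA(n_components=8).fit(scaled)
print("per component:", pca.explained_variance_ratio_.round(3))
print("cumulative:   ", pca.explained_variance_ratio_.cumsum().round(3))
```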
for both Phase and Fluor channels + - Organelle features (area, masked intensity) + - Nuclear features (area, perimeter, eccentricity) + - PCA components (PCA1-PCA8) + - Original tracking information (fov_name, track_id, time points) + """ + + embedding_dataset = compute_PCA(features_path) + features_npy = embedding_dataset["features"].values + + # convert the xarray to dataframe structure and add columns for computed features + embedding_df = embedding_dataset["sample"].to_dataframe().reset_index(drop=True) + feature_columns = pd.DataFrame( + features_npy, columns=[f"feature_{i+1}" for i in range(768)] + ) + + embedding_df = pd.concat([embedding_df, feature_columns], axis=1) + embedding_df = embedding_df.drop(columns=["sample", "UMAP1", "UMAP2"]) + + # Filter features based on FOV names that start with any of the items in fov_list + embedding_df = embedding_df[ + embedding_df["fov_name"].apply( + lambda x: any(x.startswith(fov) for fov in fov_list) + ) + ] + + # Define feature categories and their corresponding column names + feature_columns = { + "basic_features": [ + ("Mean Intensity", ["Phase", "Fluor"]), + ("Std Dev", ["Phase", "Fluor"]), + ("Kurtosis", ["Phase", "Fluor"]), + ("Skewness", ["Phase", "Fluor"]), + ("Entropy", ["Phase", "Fluor"]), + ("Interquartile Range", ["Phase", "Fluor"]), + ("Dissimilarity", ["Phase", "Fluor"]), + ("Contrast", ["Phase", "Fluor"]), + ("Texture", ["Phase", "Fluor"]), + ("Weighted Intensity Gradient", ["Phase", "Fluor"]), + ("Radial Intensity Gradient", ["Phase", "Fluor"]), + ("Zernike Moment Std", ["Phase", "Fluor"]), + ("Zernike Moment Mean", ["Phase", "Fluor"]), + ("Intensity Localization", ["Phase", "Fluor"]), + ], + "organelle_features": [ + "Fluor Area", + "Fluor Masked Intensity", + ], + "nuclear_features": [ + "Nuclear area", + "Perimeter", + "Perimeter area ratio", + "Nucleus eccentricity", + ], + } + + # Initialize all feature columns + for category, feature_list in feature_columns.items(): + if isinstance(feature_list[0], tuple): # Handle features with multiple channels + for feature, channels in feature_list: + for channel in channels: + col_name = f"{channel} {feature}" + embedding_df[col_name] = np.nan + else: # Handle single features + for feature in feature_list: + embedding_df[feature] = np.nan + + # compute the computed features and add them to the dataset + + fov_names_list = embedding_df["fov_name"].unique() + unique_fov_names = sorted(list(set(fov_names_list))) + + for fov_name in unique_fov_names: + + unique_track_ids = embedding_df[embedding_df["fov_name"] == fov_name][ + "track_id" + ].unique() + unique_track_ids = list(set(unique_track_ids)) + + # iteration_count = 0 + + for track_id in unique_track_ids: + if not embedding_df[ + (embedding_df["fov_name"] == fov_name) + & (embedding_df["track_id"] == track_id) + ].empty: + + prediction_dataset = dataset_of_tracks( + data_path, + tracks_path, + [fov_name], + [track_id], + z_range=z_range, + source_channel=source_channel, + ) + track_channel = dataset_of_tracks( + tracks_path, + tracks_path, + [fov_name], + [track_id], + z_range=(0, 1), + source_channel=seg_channel, + ) + + whole = np.stack([p["anchor"] for p in prediction_dataset]) + seg_mask = np.stack([p["anchor"] for p in track_channel]) + phase = whole[:, 0, 2] + # Normalize phase image to 0-255 range + # phase = ((phase - phase.min()) / (phase.max() - phase.min()) * 255).astype(np.uint8) + # Normalize fluorescence image to 0-255 range + fluor = np.max(whole[:, 1], axis=1) + # fluor = ((fluor - fluor.min()) / (fluor.max() - 
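The FOV filter above evaluates a Python lambda per row; `Series.str.startswith` also accepts a tuple of prefixes (pandas >= 1.4), giving a vectorized equivalent:

```python
embedding_df = embedding_df[embedding_df["fov_name"].str.startswith(tuple(fov_list))]
```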
fluor.min()) * 255).astype(np.uint8) + nucl_mask = seg_mask[:, 0, 0] + + for i, t in enumerate( + embedding_df[ + (embedding_df["fov_name"] == fov_name) + & (embedding_df["track_id"] == track_id) + ]["t"] + ): + + # Basic statistical features for both channels + phase_features = CellFeatures(phase[i], nucl_mask[i]) + PF = phase_features.compute_all_features() + + # Get all basic statistical measures at once + phase_stats = { + "Mean Intensity": PF["mean_intensity"], + "Std Dev": PF["std_dev"], + "Kurtosis": PF["kurtosis"], + "Skewness": PF["skewness"], + "Interquartile Range": PF["iqr"], + "Entropy": PF["spectral_entropy"], + "Dissimilarity": PF["dissimilarity"], + "Contrast": PF["contrast"], + "Texture": PF["texture"], + "Zernike Moment Std": PF["zernike_std"], + "Zernike Moment Mean": PF["zernike_mean"], + "Radial Intensity Gradient": PF["radial_intensity_gradient"], + "Weighted Intensity Gradient": PF[ + "weighted_intensity_gradient" + ], + "Intensity Localization": PF["intensity_localization"], + } + + fluor_cell_features = CellFeatures(fluor[i], nucl_mask[i]) + + FF = fluor_cell_features.compute_all_features() + + fluor_stats = { + "Mean Intensity": FF["mean_intensity"], + "Std Dev": FF["std_dev"], + "Kurtosis": FF["kurtosis"], + "Skewness": FF["skewness"], + "Interquartile Range": FF["iqr"], + "Entropy": FF["spectral_entropy"], + "Contrast": FF["contrast"], + "Dissimilarity": FF["dissimilarity"], + "Texture": FF["texture"], + "Masked Area": FF["masked_area"], + "Masked Intensity": FF["masked_intensity"], + "Weighted Intensity Gradient": FF[ + "weighted_intensity_gradient" + ], + "Radial Intensity Gradient": FF["radial_intensity_gradient"], + "Zernike Moment Std": FF["zernike_std"], + "Zernike Moment Mean": FF["zernike_mean"], + "Intensity Localization": FF["intensity_localization"], + "Area": FF["area"], + } + + mask_features = CellFeatures(nucl_mask[i], nucl_mask[i]) + MF = mask_features.compute_all_features() + + mask_stats = { + "perimeter": MF["perimeter"], + "area": MF["area"], + "eccentricity": MF["eccentricity"], + "perimeter_area_ratio": MF["perimeter_area_ratio"], + } + + # Create dictionaries for each feature category + phase_feature_mapping = { + f"Phase {k.replace('_', ' ').title()}": v + for k, v in phase_stats.items() + } + + fluor_feature_mapping = { + f"Fluor {k.replace('_', ' ').title()}": v + for k, v in fluor_stats.items() + } + + mask_feature_mapping = { + "Nuclear area": mask_stats["area"], + "Perimeter": mask_stats["perimeter"], + "Perimeter area ratio": mask_stats["perimeter_area_ratio"], + "Nucleus eccentricity": mask_stats["eccentricity"], + } + + # Combine all feature dictionaries + feature_values = { + **phase_feature_mapping, + **fluor_feature_mapping, + **mask_feature_mapping, + } + + # update the features dataframe + for feature_name, value in feature_values.items(): + embedding_df.loc[ + (embedding_df["fov_name"] == fov_name) + & (embedding_df["track_id"] == track_id) + & (embedding_df["t"] == t), + feature_name, + ] = value[0] + + # iteration_count += 1 + print(f"Processed {fov_name}+{track_id}") + + return embedding_df + + +## save all feature dataframe to png file +def compute_correlation_and_save_png(features: pd.DataFrame, filename: str): + """Compute correlation between PCA features and computed features, and save as heatmap. + + This function calculates the Spearman correlation between PCA components and all + computed features, then visualizes the results as a heatmap. 
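Like the older scripts, `compute_features` writes each per-frame value back through a three-way boolean mask. A hypothetical alternative is to accumulate per-frame records and merge once per run; `track_timepoints` below is a stand-in for the `t` values iterated above, and the NaN pre-initialization of feature columns would be dropped in this variant:

```python
import pandas as pd

records = []
for i, t in enumerate(track_timepoints):  # hypothetical: the "t" values of this track
    # feature_values: the combined dict built from CellFeatures, as in the loop above
    records.append(
        {"fov_name": fov_name, "track_id": track_id, "t": t,
         **{name: value[0] for name, value in feature_values.items()}}
    )

computed = pd.DataFrame.from_records(records)
embedding_df = embedding_df.merge(computed, on=["fov_name", "track_id", "t"], how="left")
```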
The heatmap focuses + on the correlation between PCA components (PCA1-PCA8) and all other computed features. + + Parameters + ---------- + features : pandas.DataFrame + DataFrame containing all features including: + - PCA components (PCA1-PCA8) + - Computed features (morphological, intensity-based, etc.) + - Tracking metadata (fov_name, track_id, t, etc.) + filename : str + Path where the correlation heatmap will be saved as a PNG or SVG file. + + Returns + ------- + pandas.DataFrame + The correlation matrix between all features. + """ + # remove the rows with missing values + features = features.dropna() + + # sub_features = features[features["Time"] == 20] + feature_df_removed = features.drop( + columns=["fov_name", "track_id", "t", "id", "parent_track_id", "parent_id"] + ) + + # Compute correlation between PCA features and computed features + correlation = feature_df_removed.corr(method="spearman") + + # display PCA correlation as a heatmap + + plt.figure(figsize=(30, 10)) + sns.heatmap( + correlation.drop( + columns=["PCA1", "PCA2", "PCA3", "PCA4", "PCA5", "PCA6", "PCA7", "PCA8"] + ).loc["PCA1":"PCA8", :], + annot=True, + cmap="coolwarm", + fmt=".2f", + annot_kws={"size": 18}, + cbar=False, + ) + plt.title("Correlation between PCA features and computed features", fontsize=12) + plt.xlabel("Computed Features", fontsize=18) + plt.ylabel("PCA Features", fontsize=18) + plt.xticks(fontsize=18, rotation=45, ha="right") # Rotate labels and align them + plt.yticks(fontsize=18) + + # Adjust layout to prevent label cutoff + plt.tight_layout() + + plt.savefig( + filename, + dpi=300, + bbox_inches="tight", + pad_inches=0.5, # Add padding around the figure + ) + plt.close() + + return correlation diff --git a/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py b/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py new file mode 100644 index 000000000..b3a1a75b8 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py @@ -0,0 +1,122 @@ +# metrics for the knowledge distillation figure + +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from sklearn.metrics import accuracy_score, classification_report, f1_score + +# %% +# Mantis +test_virus = ["C/2/000000", "C/2/001001"] +test_mock = ["B/3/000000", "B/3/000001"] + +# Mantis +TRAIN_FOVS = ["C/2/000001", "C/2/001000", "B/3/001000", "B/3/001001"] + +VAL_FOVS = test_virus + test_mock + +# %% +prediction_from_scratch = pd.read_csv( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/bootstrap-labels/test/from-scratch-last-1126.csv" +) +prediction_from_scratch["pretraining"] = "ImageNet" + +prediction_finetuned = pd.read_csv( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/bootstrap-labels/test/fine-tune-last-1126.csv" +) +pretrained_name = "DynaCLR" +prediction_finetuned["pretraining"] = pretrained_name + +prediction = pd.concat([prediction_from_scratch, prediction_finetuned], axis=0) + +# %% +prediction = prediction[prediction["fov_name"].isin(VAL_FOVS)] +prediction["prediction_binary"] = prediction["prediction"] > 0.5 +prediction + +# %% +print( + classification_report( + prediction["label"], prediction["prediction_binary"], digits=3 + ) +) + +# %% +prediction["HPI"] = prediction["t"] / 6 + 3 + +bins = [3, 6, 9, 12, 15, 18, 21, 24] +labels = [f"{start}-{end}" for start, end in zip(bins[:-1], bins[1:])] +prediction["stage"] = 
pd.cut(prediction["HPI"], bins=bins, labels=labels, right=True) +prediction["well"] = prediction["fov_name"].apply( + lambda x: "ZIKV" if x in test_virus else "Mock" +) +comparison = prediction.melt( + id_vars=["fov_name", "id", "HPI", "well", "stage", "pretraining"], + value_vars=["label", "prediction_binary"], + var_name="source", + value_name="value", +) +with sns.axes_style("whitegrid"): + ax = sns.lineplot( + data=comparison[comparison["pretraining"] == pretrained_name], + x="HPI", + y="value", + hue="well", + hue_order=["Mock", "ZIKV"], + style="source", + errorbar=None, + color="gray", + ) + ax.set_ylabel("Infection ratio") + +# %% +id_vars = ["stage", "pretraining"] + +accuracy_by_t = prediction.groupby(id_vars).apply( + lambda x: float(accuracy_score(x["label"], x["prediction_binary"])) +) +f1_by_t = prediction.groupby(id_vars).apply( + lambda x: float(f1_score(x["label"], x["prediction_binary"])) +) + +metrics_df = pd.DataFrame( + data={"accuracy": accuracy_by_t.values, "F1": f1_by_t.values}, + index=f1_by_t.index, +).reset_index() + +metrics_long = metrics_df.melt( + id_vars=id_vars, + value_vars=["accuracy"], + var_name="metric", + value_name="score", +) + +with sns.axes_style("ticks"): + plt.style.use("../figures/figure.mplstyle") + g = sns.catplot( + data=metrics_long, + x="stage", + y="score", + hue="pretraining", + kind="point", + linewidth=1.5, + linestyles="--", + ) + g.set_axis_labels("HPI", "accuracy") + sns.move_legend(g, "upper left", bbox_to_anchor=(0.35, 1.1)) + g.figure.set_size_inches(3.5, 1.5) + g.set(xlim=(-1, 7), ylim=(0.6, 1.0)) + plt.show() + + +# %% +g.figure.savefig( + Path.home() + / "gdrive/publications/dynaCLR/2025_dynaCLR_paper/fig_manuscript_svg/figure_knowledge_distillation/figure_parts/accuracy_students.pdf", + dpi=300, +) + +# %% diff --git a/applications/contrastive_phenotyping/evaluation/knowledge_distillation_teacher.py b/applications/contrastive_phenotyping/evaluation/knowledge_distillation_teacher.py new file mode 100644 index 000000000..6afe391a7 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/knowledge_distillation_teacher.py @@ -0,0 +1,149 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from iohub.ngff import open_ome_zarr +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import accuracy_score, classification_report, f1_score + +from viscy.representation.embedding_writer import read_embedding_dataset + +# %% +train_annotations = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_11_26_A549_ZIKA-sensor_ZIKV/3-phenotype/annotate-infection/combined_annotations.csv" +) +train_embeddings = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/bootstrap-labels/generate-labels/sensor-2024-11-26.zarr" +) +val_annotations = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_08_14_ZIKV_pal17_48h/6-phenotype/combined_annotations.csv" + # "/hpc/projects/intracellular_dashboard/viral-sensor/2024_11_05_A549_pAL10_24h/4-phenotype/annotate-infection/combined_annotations.csv" +) +val_embeddings = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/bootstrap-labels/generate-labels/sensor-2024-08-14-annotation.zarr" + # "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/bootstrap-labels/generate-labels/sensor-2024-11-05.zarr" +) + + +# %% +def filter_train_fovs(fov_name: pd.Series) -> pd.Series: + return 
fov_name.str[1:4].isin(["C/2", "B/3"]) + + +def filter_val_fovs(fov_name: pd.Series) -> pd.Series: + return fov_name.str[1:4].isin(["0/3"]) | (fov_name == "/0/4/000001") + # return fov_name.isin( + # ["/0/15/000001", "/0/11/002000", "/0/11/002001", "/0/11/002002"] + # ) + + +def all_fovs(fov_name: pd.Series) -> pd.Series: + return None + + +def load_features_and_annotations(embedding_path, annotation_path, filter_fn): + dataset = read_embedding_dataset(embedding_path) + features = dataset["features"][filter_fn(dataset["fov_name"])] + annotation = pd.read_csv(annotation_path) + annotation["fov_name"] = "/" + annotation["fov_name"] + annotation = annotation.set_index(["fov_name", "id"]) + index = features["sample"].to_dataframe().reset_index(drop=True)[["fov_name", "id"]] + selected = pd.merge( + left=index, right=annotation, on=["fov_name", "id"], how="inner" + ) + selected["infection_state"] = selected["infection_state"].astype("category") + return features, selected["infection_state"], selected + + +# %% +train_features, train_annotation, train_selected = load_features_and_annotations( + train_embeddings, train_annotations, filter_fn=filter_train_fovs +) +val_features, val_annotation, val_selected = load_features_and_annotations( + val_embeddings, val_annotations, filter_fn=filter_val_fovs +) + +model = LogisticRegression(class_weight="balanced", random_state=42, solver="liblinear") +model = model.fit(train_features, train_annotation) +train_prediction = model.predict(train_features) +val_prediction = model.predict(val_features) + +print("Training\n", classification_report(train_annotation, train_prediction)) +print("Validation\n", classification_report(val_annotation, val_prediction)) + +val_selected["label"] = val_selected["infection_state"].cat.codes +val_selected["prediction_binary"] = val_prediction + +# %% +prediction = val_selected +prediction["HPI"] = prediction["t"] / 2 + 3 +bins = [3, 6, 9, 12, 15, 18, 21, 24] +labels = [f"{start}-{end}" for start, end in zip(bins[:-1], bins[1:])] +prediction["stage"] = pd.cut(prediction["HPI"], bins=bins, labels=labels, right=True) +comparison = prediction.melt( + id_vars=["fov_name", "id", "HPI"], + value_vars=["label", "prediction_binary"], + var_name="source", + value_name="value", +) +with sns.axes_style("whitegrid"): + ax = sns.lineplot( + data=comparison, + x="HPI", + y="value", + style="source", + errorbar=None, + color="gray", + ) + ax.set_ylabel("Infection ratio") + +# %% +accuracy_by_t = prediction.groupby(["stage"]).apply( + lambda x: float(accuracy_score(x["label"], x["prediction_binary"])) +) +f1_by_t = prediction.groupby(["stage"]).apply( + lambda x: float(f1_score(x["label"], x["prediction_binary"])) +) + +metrics_df = pd.DataFrame( + data={ + "accuracy": accuracy_by_t.values, + "F1": f1_by_t.values, + }, + index=f1_by_t.index, +).reset_index() + +metrics_long = metrics_df.melt( + id_vars=["stage"], + value_vars=["accuracy"], + var_name="metric", + value_name="score", +) +with sns.axes_style("ticks"): + plt.style.use("../figures/figure.mplstyle") + g = sns.catplot( + data=metrics_long, + x="stage", + y="score", + kind="point", + linewidth=1.5, + linestyles="--", + legend=False, + color="gray", + ) + g.set_axis_labels("HPI", "accuracy") + g.figure.set_size_inches(3.5, 0.75) + g.set(xlim=(-1, 7), ylim=(0.9, 1.0)) + plt.show() + +# %% +g.savefig( + Path.home() + / "gdrive/publications/dynaCLR/2025_dynaCLR_paper/fig_manuscript_svg/figure_knowledge_distillation/figure_parts/teacher_accuracy.pdf", + dpi=300, + bbox_inches="tight", +) + 
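Both distillation scripts convert frame index to hours post infection and bin into 3 h stages with `pd.cut(..., right=True)`, so each interval is open on the left and closed on the right. A standalone check with the teacher's sampling (2 frames per hour, first frame at 3 HPI); note that a frame landing exactly on 3 HPI falls outside `(3, 6]` unless `include_lowest=True` is passed:

```python
import pandas as pd

t = pd.Series([0, 6, 7, 42])  # frame indices
hpi = t / 2 + 3               # 2 frames/hour, imaging starts at 3 HPI
bins = [3, 6, 9, 12, 15, 18, 21, 24]
labels = [f"{a}-{b}" for a, b in zip(bins[:-1], bins[1:])]
stage = pd.cut(hpi, bins=bins, labels=labels, right=True)
print(list(stage))  # [nan, '3-6', '6-9', '21-24']
```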
+# %% diff --git a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py b/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py deleted file mode 100644 index 5f59da3e0..000000000 --- a/applications/contrastive_phenotyping/evaluation/pca_umap_embeddings_time.py +++ /dev/null @@ -1,220 +0,0 @@ -# %% -from pathlib import Path - -import matplotlib.pyplot as plt -import seaborn as sns -from sklearn.decomposition import PCA -from sklearn.preprocessing import StandardScaler -from umap import UMAP - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation import load_annotation - -# %% Paths and parameters. - - -features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" -) -data_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr" -) -tracks_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr" -) - - -# %% -embedding_dataset = read_embedding_dataset(features_path) -embedding_dataset - - -# %% -# Compute UMAP over all features -features = embedding_dataset["features"] -# or select a well: -# features = features[features["fov_name"].str.contains("B/4")] - - -scaled_features = StandardScaler().fit_transform(features.values) -umap = UMAP() -# Fit UMAP on all features -embedding = umap.fit_transform(scaled_features) - - -# %% -# Add UMAP coordinates to the dataset and plot w/ time - - -features = ( - features.assign_coords(UMAP1=("sample", embedding[:, 0])) - .assign_coords(UMAP2=("sample", embedding[:, 1])) - .set_index(sample=["UMAP1", "UMAP2"], append=True) -) -features - - -sns.scatterplot( - x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 -) - - -# Add the title to the plot -plt.title("Cell & Time Aware Sampling (30 min interval)") -plt.xlim(-10, 20) -plt.ylim(-10, 20) -# plt.savefig('umap_cell_time_aware_time.svg', format='svg') -plt.savefig("updated_cell_time_aware_time.png", format="png") -# Show the plot -plt.show() - - -# %% - - -any_features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_difcell_randomtime_sampling/Ver2_updateTracking_refineModel/predictions/Feb_2chan_128patch_32projDim/2chan_128patch_56ckpt_FebTest.zarr" -) -embedding_dataset = read_embedding_dataset(any_features_path) -embedding_dataset - - -# %% -# Compute UMAP over all features -features = embedding_dataset["features"] -# or select a well: -# features = features[features["fov_name"].str.contains("B/4")] - - -scaled_features = StandardScaler().fit_transform(features.values) -umap = UMAP() -# Fit UMAP on all features -embedding = umap.fit_transform(scaled_features) - - -# %% Any time sampling plot - - -features = ( - features.assign_coords(UMAP1=("sample", embedding[:, 0])) - .assign_coords(UMAP2=("sample", embedding[:, 1])) - .set_index(sample=["UMAP1", "UMAP2"], append=True) -) -features - - -sns.scatterplot( - x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 -) - - -# Add the title to the plot -plt.title("Cell Aware Sampling") - -plt.xlim(-10, 20) -plt.ylim(-10, 20) - -plt.savefig("1_updated_cell_aware_time.png", format="png") -# plt.savefig('umap_cell_aware_time.pdf', 
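The removed script fits `UMAP()` with library defaults, so repeated runs can produce different layouts. A sketch that pins the seed for comparable panels (the other parameters shown are umap-learn's defaults, written out explicitly):

```python
from sklearn.preprocessing import StandardScaler
from umap import UMAP

scaled = StandardScaler().fit_transform(features.values)
embedding = UMAP(n_neighbors=15, min_dist=0.1, random_state=42).fit_transform(scaled)
```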
format='pdf') -# Show the plot -plt.show() - - -# %% - - -contrastive_learning_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/feb_fixed_test_predict.zarr" -) -embedding_dataset = read_embedding_dataset(contrastive_learning_path) -embedding_dataset - - -# %% -# Compute UMAP over all features -features = embedding_dataset["features"] -# or select a well: -# features = features[features["fov_name"].str.contains("B/4")] - - -scaled_features = StandardScaler().fit_transform(features.values) -umap = UMAP() -# Fit UMAP on all features -embedding = umap.fit_transform(scaled_features) - - -# %% Any time sampling plot - - -features = ( - features.assign_coords(UMAP1=("sample", embedding[:, 0])) - .assign_coords(UMAP2=("sample", embedding[:, 1])) - .set_index(sample=["UMAP1", "UMAP2"], append=True) -) -features - -sns.scatterplot( - x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 -) - -# Add the title to the plot -plt.title("Classical Contrastive Learning Sampling") -plt.xlim(-10, 20) -plt.ylim(-10, 20) -plt.savefig("updated_classical_time.png", format="png") -# plt.savefig('classical_time.pdf', format='pdf') - -# Show the plot -plt.show() - - -# %% PCA - - -pca = PCA(n_components=4) -# scaled_features = StandardScaler().fit_transform(features.values) -# pca_features = pca.fit_transform(scaled_features) -pca_features = pca.fit_transform(features.values) - - -features = ( - features.assign_coords(PCA1=("sample", pca_features[:, 0])) - .assign_coords(PCA2=("sample", pca_features[:, 1])) - .assign_coords(PCA3=("sample", pca_features[:, 2])) - .assign_coords(PCA4=("sample", pca_features[:, 3])) - .set_index(sample=["PCA1", "PCA2", "PCA3", "PCA4"], append=True) -) - - -# %% plot PCA components w/ time - - -plt.figure(figsize=(10, 10)) -sns.scatterplot( - x=features["PCA1"], y=features["PCA2"], hue=features["t"], s=7, alpha=0.8 -) - - -# %% OVERLAY INFECTION ANNOTATION -ann_root = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" -) - - -infection = load_annotation( - features, - ann_root / "extracted_inf_state.csv", - "infection_state", - {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, -) - - -# %% -sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8) - - -# %% plot PCA components with infection hue -sns.scatterplot(x=features["PCA1"], y=features["PCA2"], hue=infection, s=7, alpha=0.8) - - -# %% diff --git a/applications/contrastive_phenotyping/evaluation/predict_infection_score_supervised.py b/applications/contrastive_phenotyping/evaluation/predict_infection_score_supervised.py deleted file mode 100644 index fb64b9f07..000000000 --- a/applications/contrastive_phenotyping/evaluation/predict_infection_score_supervised.py +++ /dev/null @@ -1,166 +0,0 @@ -import os -import warnings -from argparse import ArgumentParser - -import numpy as np -import pandas as pd -from torch.utils.data import DataLoader -from tqdm import tqdm - -from viscy.data.triplet import TripletDataModule - -warnings.filterwarnings( - "ignore", - category=UserWarning, - message="To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).", -) - -# %% Paths and constants -save_dir = ( - 
"/hpc/mydata/alishba.imran/VisCy/applications/contrastive_phenotyping/embeddings4" -) - -# rechunked data -data_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/2.2-register_annotations/updated_all_annotations.zarr" - -# updated tracking data -tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr" - -source_channel = ["background_mask", "uninfected_mask", "infected_mask"] -z_range = (0, 1) -batch_size = 1 # match the number of fovs being processed such that no data is left -# set to 15 for full, 12 for infected, and 8 for uninfected - -# non-rechunked data -data_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr" - -# updated tracking data -tracks_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr" - -source_channel_1 = ["Nuclei_prediction_labels"] - - -# %% Define the main function for training -def main(hparams): - # Initialize the data module for prediction, re-do embeddings but with size 224 by 224 - data_module = TripletDataModule( - data_path=data_path, - tracks_path=tracks_path, - source_channel=source_channel, - z_range=z_range, - initial_yx_patch_size=(224, 224), - final_yx_patch_size=(224, 224), - batch_size=batch_size, - num_workers=hparams.num_workers, - ) - - data_module.setup(stage="predict") - - print(f"Total prediction dataset size: {len(data_module.predict_dataset)}") - - dataloader = DataLoader( - data_module.predict_dataset, - batch_size=batch_size, - num_workers=hparams.num_workers, - ) - - # Initialize the second data module for segmentation masks - seg_data_module = TripletDataModule( - data_path=data_path_1, - tracks_path=tracks_path_1, - source_channel=source_channel_1, - z_range=z_range, - initial_yx_patch_size=(224, 224), - final_yx_patch_size=(224, 224), - batch_size=batch_size, - num_workers=hparams.num_workers, - ) - - seg_data_module.setup(stage="predict") - - seg_dataloader = DataLoader( - seg_data_module.predict_dataset, - batch_size=batch_size, - num_workers=hparams.num_workers, - ) - - # Initialize lists to store average values - background_avg = [] - uninfected_avg = [] - infected_avg = [] - - for batch, seg_batch in tqdm( - zip(dataloader, seg_dataloader), - desc="Processing batches", - total=len(data_module.predict_dataset), - ): - anchor = batch["anchor"] - seg_anchor = seg_batch["anchor"].int() - - # Extract the fov_name and id from the batch - fov_name = batch["index"]["fov_name"][0] - cell_id = batch["index"]["id"].item() - - fov_dirs = fov_name.split("/") - # Construct the path to the CSV file - csv_path = os.path.join( - tracks_path, *fov_dirs, f"tracks{fov_name.replace('/', '_')}.csv" - ) - - # Read the CSV file - df = pd.read_csv(csv_path) - - # Find the row with the specified id and extract the track_id - track_id = df.loc[df["id"] == cell_id, "track_id"].values[0] - - # Create a boolean mask where segmentation values are equal to the track_id - mask = seg_anchor == track_id - # mask = (seg_anchor > 0) - - # Find the most frequent non-zero value in seg_anchor - # unique, counts = np.unique(seg_anchor[seg_anchor > 0], return_counts=True) - # most_frequent_value = unique[np.argmax(counts)] - - # # Create a boolean mask where segmentation values are equal to the most frequent value - # mask = (seg_anchor == most_frequent_value) - - # Expand the mask to match the anchor 
tensor shape - mask = mask.expand(1, 3, 1, 224, 224) - - # Calculate average values for each channel (background, uninfected, infected) using the mask - background_avg.append(anchor[:, 0, :, :, :][mask[:, 0]].mean().item()) - uninfected_avg.append(anchor[:, 1, :, :, :][mask[:, 1]].mean().item()) - infected_avg.append(anchor[:, 2, :, :, :][mask[:, 2]].mean().item()) - - # Convert lists to numpy arrays - background_avg = np.array(background_avg) - uninfected_avg = np.array(uninfected_avg) - infected_avg = np.array(infected_avg) - - print("Average values per cell for each mask calculated.") - print("Background average shape:", background_avg.shape) - print("Uninfected average shape:", uninfected_avg.shape) - print("Infected average shape:", infected_avg.shape) - - # Save the averages as .npy files - np.save(os.path.join(save_dir, "background_avg.npy"), background_avg) - np.save(os.path.join(save_dir, "uninfected_avg.npy"), uninfected_avg) - np.save(os.path.join(save_dir, "infected_avg.npy"), infected_avg) - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--backbone", type=str, default="resnet50") - parser.add_argument("--margin", type=float, default=0.5) - parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--schedule", type=str, default="Constant") - parser.add_argument("--log_steps_per_epoch", type=int, default=10) - parser.add_argument("--embedding_len", type=int, default=256) - parser.add_argument("--max_epochs", type=int, default=100) - parser.add_argument("--accelerator", type=str, default="gpu") - parser.add_argument("--devices", type=int, default=1) - parser.add_argument("--num_nodes", type=int, default=1) - parser.add_argument("--log_every_n_steps", type=int, default=1) - parser.add_argument("--num_workers", type=int, default=8) - args = parser.parse_args() - main(args) diff --git a/applications/contrastive_phenotyping/examples_cli/predict.yml b/applications/contrastive_phenotyping/examples_cli/predict.yml deleted file mode 100644 index 622f273fb..000000000 --- a/applications/contrastive_phenotyping/examples_cli/predict.yml +++ /dev/null @@ -1,59 +0,0 @@ -seed_everything: 42 -trainer: - accelerator: gpu - strategy: auto - devices: auto - num_nodes: 1 - precision: 32-true - callbacks: - - class_path: viscy.representation.embedding_writer.EmbeddingWriter - init_args: - output_path: "/path/to/output.zarr" - phate_kwargs: - n_components: 2 - knn: 10 - decay: 50 - gamma: 1 - # edit the following lines to specify logging path - # - class_path: lightning.pytorch.loggers.TensorBoardLogger - # init_args: - # save_dir: /path/to/save_dir - # version: name-of-experiment - # log_graph: True - inference_mode: true -model: - class_path: viscy.representation.engine.ContrastiveModule - init_args: - backbone: convnext_tiny - in_channels: 2 - in_stack_depth: 15 - stem_kernel_size: [5, 4, 4] -data: - class_path: viscy.data.triplet.TripletDataModule - init_args: - data_path: /hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr - tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr - source_channel: - - Phase3D - - RFP - z_range: [28, 43] - batch_size: 32 - num_workers: 16 - initial_yx_patch_size: [192, 192] - final_yx_patch_size: [192, 192] - normalizations: - - class_path: viscy.transforms.NormalizeSampled - init_args: - keys: [Phase3D] - level: fov_statistics - subtrahend: mean - divisor: std - - class_path: 
viscy.transforms.ScaleIntensityRangePercentilesd - init_args: - keys: [RFP] - lower: 50 - upper: 99 - b_min: 0.0 - b_max: 1.0 -return_predictions: false -ckpt_path: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/lightning_logs/tokenized-drop-path-0.0/checkpoints/epoch=96-step=23377.ckpt diff --git a/applications/contrastive_phenotyping/figures/classify_feb.py b/applications/contrastive_phenotyping/figures/classify_feb.py deleted file mode 100644 index b9dd81b8e..000000000 --- a/applications/contrastive_phenotyping/figures/classify_feb.py +++ /dev/null @@ -1,100 +0,0 @@ -# %% Importing Necessary Libraries -from pathlib import Path - -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns -from imblearn.over_sampling import SMOTE -from sklearn.linear_model import LogisticRegression -from sklearn.metrics import classification_report, confusion_matrix -from tqdm import tqdm - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation import load_annotation -from viscy.representation.evaluation.dimensionality_reduction import compute_pca - -# %% Defining Paths for February Dataset -feb_features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr" -) - - -# %% Load and Process February Dataset -feb_embedding_dataset = read_embedding_dataset(feb_features_path) -print(feb_embedding_dataset) -pca_df = compute_pca(feb_embedding_dataset, n_components=6) - -# Print shape before merge -print("Shape of pca_df before merge:", pca_df.shape) - -# Load the ground truth infection labels -feb_ann_root = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track" -) -feb_infection = load_annotation( - feb_embedding_dataset, - feb_ann_root / "tracking_v1_infection.csv", - "infection class", - {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, -) - -# Print shape of feb_infection -print("Shape of feb_infection:", feb_infection.shape) - -# Merge PCA results with ground truth labels on both 'fov_name' and 'id' -pca_df = pd.merge(pca_df, feb_infection.reset_index(), on=["fov_name", "id"]) - -# Print shape after merge -print("Shape of pca_df after merge:", pca_df.shape) - -# Prepare the full dataset -X = pca_df[["PCA1", "PCA2", "PCA3", "PCA4", "PCA5", "PCA6"]] -y = pca_df["infection class"] - -# Apply SMOTE to balance the classes in the full dataset -smote = SMOTE(random_state=42) -X_resampled, y_resampled = smote.fit_resample(X, y) - -# Print shape after SMOTE -print( - f"Shape after SMOTE - X_resampled: {X_resampled.shape}, y_resampled: {y_resampled.shape}" -) - -# %% Train Logistic Regression Classifier with Progress Bar -model = LogisticRegression(max_iter=1000, random_state=42) - -# Wrap the training with tqdm to show a progress bar -for _ in tqdm(range(1)): - model.fit(X_resampled, y_resampled) - -# %% Predict Labels for the Entire Dataset -pca_df["Predicted_Label"] = model.predict(X) - -# Compute metrics based on the entire original dataset -print("Classification Report for Entire Dataset:") -print(classification_report(pca_df["infection class"], pca_df["Predicted_Label"])) - -print("Confusion Matrix for Entire Dataset:") -print(confusion_matrix(pca_df["infection class"], pca_df["Predicted_Label"])) - -# %% Plotting the Results -plt.figure(figsize=(10, 8)) -sns.scatterplot( - 
x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["infection class"], s=7, alpha=0.8 -) -plt.title("PCA with Ground Truth Labels") -plt.savefig("up_pca_ground_truth_labels.png", format="png", dpi=300) -plt.show() - -plt.figure(figsize=(10, 8)) -sns.scatterplot( - x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["Predicted_Label"], s=7, alpha=0.8 -) -plt.title("PCA with Logistic Regression Predicted Labels") -plt.savefig("up_pca_predicted_labels.png", format="png", dpi=300) -plt.show() - -# %% Save Predicted Labels to CSV -save_path_csv = "up_logistic_regression_predicted_labels_feb_pca.csv" -pca_df[["id", "fov_name", "Predicted_Label"]].to_csv(save_path_csv, index=False) -print(f"Predicted labels saved to {save_path_csv}") diff --git a/applications/contrastive_phenotyping/figures/classify_feb_embeddings.py b/applications/contrastive_phenotyping/figures/classify_feb_embeddings.py deleted file mode 100644 index da63c52a8..000000000 --- a/applications/contrastive_phenotyping/figures/classify_feb_embeddings.py +++ /dev/null @@ -1,94 +0,0 @@ -# %% Importing Necessary Libraries -from pathlib import Path - -import pandas as pd -from imblearn.over_sampling import SMOTE -from sklearn.linear_model import LogisticRegression -from sklearn.metrics import classification_report, confusion_matrix - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation import load_annotation - -# %% Defining Paths for February Dataset -feb_features_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/negpair_random_sampling2/" -) - - -# %% Load and Process February Dataset (Embedding Features) -feb_embedding_dataset = read_embedding_dataset( - feb_features_path / "febtest_predict.zarr" -) -print(feb_embedding_dataset) - -# Extract the embedding feature values as the input matrix (X) -X = feb_embedding_dataset["features"].values - -# Prepare a DataFrame for the embeddings with id and fov_name -embedding_df = pd.DataFrame(X, columns=[f"feature_{i+1}" for i in range(X.shape[1])]) -embedding_df["id"] = feb_embedding_dataset["id"].values -embedding_df["fov_name"] = feb_embedding_dataset["fov_name"].values -print(embedding_df.head()) - -# %% Load the ground truth infection labels -feb_ann_root = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred" -) -feb_infection = load_annotation( - feb_embedding_dataset, - feb_ann_root / "extracted_inf_state.csv", - "infection_state", - {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, -) - -# %% Merge embedding features with infection labels on 'fov_name' and 'id' -merged_df = pd.merge(embedding_df, feb_infection.reset_index(), on=["fov_name", "id"]) -print(merged_df.head()) -# %% Prepare the full dataset for training -X = merged_df.drop( - columns=["id", "fov_name", "infection_state"] -).values # Use embeddings as features -y = merged_df["infection_state"] # Use infection state as labels -print(X.shape) -print(y.shape) -# %% Print class distribution before applying SMOTE -print("Class distribution before SMOTE:") -print(y.value_counts()) - -# Apply SMOTE to balance the classes -smote = SMOTE(random_state=42) -X_resampled, y_resampled = smote.fit_resample(X, y) - -# Print class distribution after applying SMOTE -print("Class distribution after SMOTE:") -print(pd.Series(y_resampled).value_counts()) - -# Train Logistic Regression Classifier -model = LogisticRegression(max_iter=1000, 
random_state=42) -model.fit(X_resampled, y_resampled) - -# Predict Labels for the Entire Dataset -y_pred = model.predict(X) - -# Compute metrics based on the entire original dataset -print("Classification Report for Entire Dataset:") -print(classification_report(y, y_pred)) - -print("Confusion Matrix for Entire Dataset:") -print(confusion_matrix(y, y_pred)) - -# %% -# Save the predicted labels to a CSV -save_path_csv = feb_features_path / "feb_test_regression_predicted_labels_embedding.csv" -predicted_labels_df = pd.DataFrame( - { - "id": merged_df["id"].values, - "fov_name": merged_df["fov_name"].values, - "Predicted_Label": y_pred, - } -) - -predicted_labels_df.to_csv(save_path_csv, index=False) -print(f"Predicted labels saved to {save_path_csv}") - -# %% diff --git a/applications/contrastive_phenotyping/figures/classify_june.py b/applications/contrastive_phenotyping/figures/classify_june.py deleted file mode 100644 index ca51f2b17..000000000 --- a/applications/contrastive_phenotyping/figures/classify_june.py +++ /dev/null @@ -1,121 +0,0 @@ -# %% Importing Necessary Libraries -from pathlib import Path - -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns -from imblearn.over_sampling import SMOTE -from sklearn.decomposition import PCA -from sklearn.linear_model import LogisticRegression -from sklearn.metrics import classification_report, confusion_matrix -from sklearn.preprocessing import StandardScaler -from tqdm import tqdm - -from viscy.representation.embedding_writer import read_embedding_dataset - -# %% Defining Paths for June Dataset -june_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/Phase_RFP_smallPatch_June/phaseRFP_36patch_June.zarr") - -# %% Function to Load Annotations -def load_annotation(da, path, name, categories: dict | None = None): - annotation = pd.read_csv(path) - annotation["fov_name"] = "/" + annotation["fov ID"] - annotation = annotation.set_index(["fov_name", "id"]) - mi = pd.MultiIndex.from_arrays( - [da["fov_name"].values, da["id"].values], names=["fov_name", "id"] - ) - selected = annotation.loc[mi][name] - if categories: - selected = selected.astype("category").cat.rename_categories(categories) - return selected - -# %% Function to Compute PCA -def compute_pca(embedding_dataset, n_components=6): - features = embedding_dataset["features"] - scaled_features = StandardScaler().fit_transform(features.values) - - # Compute PCA with specified number of components - pca = PCA(n_components=n_components, random_state=42) - pca_embedding = pca.fit_transform(scaled_features) - - # Prepare DataFrame with id and PCA coordinates - pca_df = pd.DataFrame({ - "id": embedding_dataset["id"].values, - "fov_name": embedding_dataset["fov_name"].values, - "PCA1": pca_embedding[:, 0], - "PCA2": pca_embedding[:, 1], - "PCA3": pca_embedding[:, 2], - "PCA4": pca_embedding[:, 3], - "PCA5": pca_embedding[:, 4], - "PCA6": pca_embedding[:, 5] - }) - - return pca_df - -# %% Load and Process June Dataset -june_embedding_dataset = read_embedding_dataset(june_features_path) -print(june_embedding_dataset) -pca_df = compute_pca(june_embedding_dataset, n_components=6) - -# Print shape before merge -print("Shape of pca_df before merge:", pca_df.shape) - -# Load the ground truth infection labels -june_ann_root = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.1-tracking") -june_infection = load_annotation(june_embedding_dataset, june_ann_root 
/ "tracking_v1_infection.csv", "infection class", - {0.0: "background", 1.0: "uninfected", 2.0: "infected"}) - -# Print shape of june_infection -print("Shape of june_infection:", june_infection.shape) - -# Merge PCA results with ground truth labels on both 'fov_name' and 'id' -pca_df = pd.merge(pca_df, june_infection.reset_index(), on=['fov_name', 'id']) - -# Print shape after merge -print("Shape of pca_df after merge:", pca_df.shape) - -# Prepare the full dataset -X = pca_df[["PCA1", "PCA2", "PCA3", "PCA4", "PCA5", "PCA6"]] -y = pca_df["infection class"] - -# Apply SMOTE to balance the classes in the full dataset -smote = SMOTE(random_state=42) -X_resampled, y_resampled = smote.fit_resample(X, y) - -# Print shape after SMOTE -print(f"Shape after SMOTE - X_resampled: {X_resampled.shape}, y_resampled: {y_resampled.shape}") - -# %% Train Logistic Regression Classifier with Progress Bar -model = LogisticRegression(max_iter=1000, random_state=42) - -# Wrap the training with tqdm to show a progress bar -for _ in tqdm(range(1)): - model.fit(X_resampled, y_resampled) - -# %% Predict Labels for the Entire Dataset -pca_df["Predicted_Label"] = model.predict(X) - -# Compute metrics based on the entire original dataset -print("Classification Report for Entire Dataset:") -print(classification_report(pca_df["infection class"], pca_df["Predicted_Label"])) - -print("Confusion Matrix for Entire Dataset:") -print(confusion_matrix(pca_df["infection class"], pca_df["Predicted_Label"])) - -# %% Plotting the Results -plt.figure(figsize=(10, 8)) -sns.scatterplot(x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["infection class"], s=7, alpha=0.8) -plt.title("PCA with Ground Truth Labels") -plt.savefig("june_pca_ground_truth_labels.png", format='png', dpi=300) -plt.show() - -plt.figure(figsize=(10, 8)) -sns.scatterplot(x=pca_df["PCA1"], y=pca_df["PCA2"], hue=pca_df["Predicted_Label"], s=7, alpha=0.8) -plt.title("PCA with Logistic Regression Predicted Labels") -plt.savefig("june_pca_predicted_labels.png", format='png', dpi=300) -plt.show() - -# %% Save Predicted Labels to CSV -save_path_csv = "june_logistic_regression_predicted_labels_feb_pca.csv" -pca_df[['id', 'fov_name', 'Predicted_Label']].to_csv(save_path_csv, index=False) -print(f"Predicted labels saved to {save_path_csv}") diff --git a/applications/contrastive_phenotyping/evaluation/figure.mplstyle b/applications/contrastive_phenotyping/figures/figure.mplstyle similarity index 100% rename from applications/contrastive_phenotyping/evaluation/figure.mplstyle rename to applications/contrastive_phenotyping/figures/figure.mplstyle diff --git a/applications/contrastive_phenotyping/figures/figure_4a_1.py b/applications/contrastive_phenotyping/figures/figure_4a_1.py deleted file mode 100644 index a670db0d0..000000000 --- a/applications/contrastive_phenotyping/figures/figure_4a_1.py +++ /dev/null @@ -1,167 +0,0 @@ -# %% Importing Necessary Libraries -from pathlib import Path - -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns -from sklearn.preprocessing import StandardScaler -from umap import UMAP - -from viscy.representation.embedding_writer import read_embedding_dataset - -# %% Defining Paths for February and June Datasets -feb_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr") -feb_data_path = Path("/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr") 
-feb_tracks_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr") - -# %% Function to Load and Process the Embedding Dataset -def compute_umap(embedding_dataset): - features = embedding_dataset["features"] - scaled_features = StandardScaler().fit_transform(features.values) - umap = UMAP() - embedding = umap.fit_transform(scaled_features) - - features = ( - features.assign_coords(UMAP1=("sample", embedding[:, 0])) - .assign_coords(UMAP2=("sample", embedding[:, 1])) - .set_index(sample=["UMAP1", "UMAP2"], append=True) - ) - return features - -# %% Function to Load Annotations -def load_annotation(da, path, name, categories: dict | None = None): - annotation = pd.read_csv(path) - annotation["fov_name"] = "/" + annotation["fov ID"] - annotation = annotation.set_index(["fov_name", "id"]) - mi = pd.MultiIndex.from_arrays( - [da["fov_name"].values, da["id"].values], names=["fov_name", "id"] - ) - selected = annotation.loc[mi][name] - if categories: - selected = selected.astype("category").cat.rename_categories(categories) - return selected - -# %% Function to Plot UMAP with Infection Annotations -def plot_umap_infection(features, infection, title): - plt.figure(figsize=(10, 8)) - sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8) - plt.title(f"UMAP Plot - {title}") - plt.show() - -# %% Load and Process February Dataset -feb_embedding_dataset = read_embedding_dataset(feb_features_path) -feb_features = compute_umap(feb_embedding_dataset) - -feb_ann_root = Path("/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track") -feb_infection = load_annotation(feb_features, feb_ann_root / "tracking_v1_infection.csv", "infection class", {0.0: "background", 1.0: "uninfected", 2.0: "infected"}) - -# %% Plot UMAP with Infection Status for February Dataset -plot_umap_infection(feb_features, feb_infection, "February Dataset") - -# %% -print(feb_embedding_dataset) -print(feb_infection) -print(feb_features) -# %% - - -# %% Identify cells by infection type using fov_name -mock_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/A/3') | feb_features['fov_name'].str.contains('/B/3')) -zika_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/A/4')) -dengue_cells = feb_features.sel(sample=feb_features['fov_name'].str.contains('/B/4')) - -# %% Plot UMAP with Infection Status -plt.figure(figsize=(10, 8)) -sns.scatterplot(x=feb_features["UMAP1"], y=feb_features["UMAP2"], hue=feb_infection, s=7, alpha=0.8) - -# Overlay with circled cells -plt.scatter(mock_cells["UMAP1"], mock_cells["UMAP2"], facecolors='none', edgecolors='blue', s=20, label='Mock Cells') -plt.scatter(zika_cells["UMAP1"], zika_cells["UMAP2"], facecolors='none', edgecolors='green', s=20, label='Zika MOI 5') -plt.scatter(dengue_cells["UMAP1"], dengue_cells["UMAP2"], facecolors='none', edgecolors='red', s=20, label='Dengue MOI 5') - -# Add legend and show plot -plt.legend(loc='best') -plt.title("UMAP Plot - February Dataset with Mock, Zika, and Dengue Highlighted") -plt.show() - -# %% -# %% Create a 1x3 grid of heatmaps -fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True) - -# Mock Cells Heatmap -sns.histplot(x=mock_cells["UMAP1"], y=mock_cells["UMAP2"], bins=50, pmax=1, cmap="Blues", ax=axs[0]) -axs[0].set_title('Mock Cells') -axs[0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) 
-axs[0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Zika Cells Heatmap -sns.histplot(x=zika_cells["UMAP1"], y=zika_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[1]) -axs[1].set_title('Zika MOI 5') -axs[1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Dengue Cells Heatmap -sns.histplot(x=dengue_cells["UMAP1"], y=dengue_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[2]) -axs[2].set_title('Dengue MOI 5') -axs[2].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[2].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Set labels and adjust layout -for ax in axs: - ax.set_xlabel('UMAP1') - ax.set_ylabel('UMAP2') - -plt.tight_layout() -plt.show() - -# %% -import matplotlib.pyplot as plt -import seaborn as sns - -# %% Create a 2x3 grid of heatmaps (1 row for each heatmap, splitting infected and uninfected in the second row) -fig, axs = plt.subplots(2, 3, figsize=(24, 12), sharex=True, sharey=True) - -# Mock Cells Heatmap -sns.histplot(x=mock_cells["UMAP1"], y=mock_cells["UMAP2"], bins=50, pmax=1, cmap="Blues", ax=axs[0, 0]) -axs[0, 0].set_title('Mock Cells') -axs[0, 0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[0, 0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Zika Cells Heatmap -sns.histplot(x=zika_cells["UMAP1"], y=zika_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[0, 1]) -axs[0, 1].set_title('Zika MOI 5') -axs[0, 1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[0, 1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Dengue Cells Heatmap -sns.histplot(x=dengue_cells["UMAP1"], y=dengue_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[0, 2]) -axs[0, 2].set_title('Dengue MOI 5') -axs[0, 2].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[0, 2].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Infected Cells Heatmap -sns.histplot(x=infected_cells["UMAP1"], y=infected_cells["UMAP2"], bins=50, pmax=1, cmap="Reds", ax=axs[1, 0]) -axs[1, 0].set_title('Infected Cells') -axs[1, 0].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[1, 0].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Uninfected Cells Heatmap -sns.histplot(x=uninfected_cells["UMAP1"], y=uninfected_cells["UMAP2"], bins=50, pmax=1, cmap="Greens", ax=axs[1, 1]) -axs[1, 1].set_title('Uninfected Cells') -axs[1, 1].set_xlim(feb_features["UMAP1"].min(), feb_features["UMAP1"].max()) -axs[1, 1].set_ylim(feb_features["UMAP2"].min(), feb_features["UMAP2"].max()) - -# Remove the last subplot (bottom right corner) -fig.delaxes(axs[1, 2]) - -# Set labels and adjust layout -for ax in axs.flat: - ax.set_xlabel('UMAP1') - ax.set_ylabel('UMAP2') - -plt.tight_layout() -plt.show() - - - -# %% diff --git a/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py b/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py deleted file mode 100644 index d3052018a..000000000 --- a/applications/contrastive_phenotyping/figures/figure_4e_2_feb.py +++ /dev/null @@ -1,87 +0,0 @@ -# %% Importing Necessary Libraries -from pathlib import Path - -import matplotlib.pyplot as plt -import pandas as pd - -from viscy.representation.embedding_writer import read_embedding_dataset - - -# %% Function to Load Annotations from GMM CSV -def load_gmm_annotation(gmm_csv_path): - gmm_df = 
pd.read_csv(gmm_csv_path) - return gmm_df - -# %% Function to Count and Calculate Percentage of Infected Cells Over Time Based on GMM Labels -def count_infected_cell_states_over_time(embedding_dataset, gmm_df): - # Convert the embedding dataset to a DataFrame - df = pd.DataFrame({ - "fov_name": embedding_dataset["fov_name"].values, - "track_id": embedding_dataset["track_id"].values, - "t": embedding_dataset["t"].values, - "id": embedding_dataset["id"].values - }) - - # Merge with GMM data to add GMM labels - df = pd.merge(df, gmm_df[['id', 'fov_name', 'Predicted_Label']], on=['fov_name', 'id'], how='left') - - # Filter by time range (3 HPI to 30 HPI) - df = df[(df['t'] >= 3) & (df['t'] <= 27)] - - # Determine the well type (Mock, Zika, Dengue) based on fov_name - df['well_type'] = df['fov_name'].apply(lambda x: 'Mock' if '/A/3' in x or '/B/3' in x else - ('Zika' if '/A/4' in x else 'Dengue')) - - # Group by time, well type, and GMM label to count the number of infected cells - state_counts = df.groupby(['t', 'well_type', 'Predicted_Label']).size().unstack(fill_value=0) - - # Ensure that 'infected' column exists - if 'infected' not in state_counts.columns: - state_counts['infected'] = 0 - - # Calculate the percentage of infected cells - state_counts['total'] = state_counts.sum(axis=1) - state_counts['infected'] = (state_counts['infected'] / state_counts['total']) * 100 - - return state_counts - -# %% Function to Plot Percentage of Infected Cells Over Time -def plot_infected_cell_states(state_counts): - plt.figure(figsize=(12, 8)) - - # Loop through each well type - for well_type in ['Mock', 'Zika', 'Dengue']: - # Select the data for the current well type - if well_type in state_counts.index.get_level_values('well_type'): - well_data = state_counts.xs(well_type, level='well_type') - - # Plot only the percentage of infected cells - if 'infected' in well_data.columns: - plt.plot(well_data.index, well_data['infected'], label=f'{well_type} - Infected') - - plt.title("Percentage of Infected Cells Over Time - February") - plt.xlabel("Hours Post Perturbation") - plt.ylabel("Percentage of Infected Cells") - plt.legend(title="Well Type") - plt.grid(True) - plt.show() - -# %% Load and process Feb Dataset -feb_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/June_140Patch_2chan/phaseRFP_140patch_99ckpt_Feb.zarr") -feb_embedding_dataset = read_embedding_dataset(feb_features_path) - -# Load the GMM annotation CSV -gmm_csv_path = "june_logistic_regression_predicted_labels_feb_pca.csv" # Path to CSV file -gmm_df = load_gmm_annotation(gmm_csv_path) - -# %% Count Infected Cell States Over Time as Percentage using GMM labels -state_counts = count_infected_cell_states_over_time(feb_embedding_dataset, gmm_df) -print(state_counts.head()) -state_counts.info() - -# %% Plot Infected Cell States Over Time as Percentage -plot_infected_cell_states(state_counts) - -# %% - - diff --git a/applications/contrastive_phenotyping/figures/figure_4e_2_june.py b/applications/contrastive_phenotyping/figures/figure_4e_2_june.py deleted file mode 100644 index 1605ba278..000000000 --- a/applications/contrastive_phenotyping/figures/figure_4e_2_june.py +++ /dev/null @@ -1,85 +0,0 @@ -# %% Importing Necessary Libraries -from pathlib import Path - -import matplotlib.pyplot as plt -import pandas as pd - -from viscy.representation.embedding_writer import read_embedding_dataset - - -# %% Function to Load Annotations from CSV -def load_annotation(csv_path): - 
return pd.read_csv(csv_path) - -# %% Function to Count and Calculate Percentage of Infected Cells Over Time Based on Predicted Labels -def count_infected_cell_states_over_time(embedding_dataset, prediction_df): - # Convert the embedding dataset to a DataFrame - df = pd.DataFrame({ - "fov_name": embedding_dataset["fov_name"].values, - "track_id": embedding_dataset["track_id"].values, - "t": embedding_dataset["t"].values, - "id": embedding_dataset["id"].values - }) - - # Merge with the prediction data to add Predicted Labels - df = pd.merge(df, prediction_df[['id', 'fov_name', 'Infection_Class']], on=['fov_name', 'id'], how='left') - - # Filter by time range (2 HPI to 50 HPI) - df = df[(df['t'] >= 2) & (df['t'] <= 50)] - - # Determine the well type (Mock, Dengue, Zika) based on fov_name - df['well_type'] = df['fov_name'].apply( - lambda x: 'Mock' if '/0/1' in x or '/0/2' in x or '/0/3' in x or '/0/4' in x else - ('Dengue' if '/0/5' in x or '/0/6' in x else 'Zika')) - - # Group by time, well type, and Predicted_Label to count the number of infected cells - state_counts = df.groupby(['t', 'well_type', 'Infection_Class']).size().unstack(fill_value=0) - - # Ensure that 'infected' column exists - if 'infected' not in state_counts.columns: - state_counts['infected'] = 0 - - # Calculate the percentage of infected cells - state_counts['total'] = state_counts.sum(axis=1) - state_counts['infected'] = (state_counts['infected'] / state_counts['total']) * 100 - - return state_counts - -# %% Function to Plot Percentage of Infected Cells Over Time -def plot_infected_cell_states(state_counts): - plt.figure(figsize=(12, 8)) - - # Loop through each well type - for well_type in ['Mock', 'Dengue', 'Zika']: - # Select the data for the current well type - if well_type in state_counts.index.get_level_values('well_type'): - well_data = state_counts.xs(well_type, level='well_type') - - # Plot only the percentage of infected cells - if 'infected' in well_data.columns: - plt.plot(well_data.index, well_data['infected'], label=f'{well_type} - Infected') - - plt.title("Percentage of Infected Cells Over Time - June") - plt.xlabel("Hours Post Perturbation") - plt.ylabel("Percentage of Infected Cells") - plt.legend(title="Well Type") - plt.grid(True) - plt.show() - -# %% Load and process June Dataset -june_features_path = Path("/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/code_testing_soorya/output/Phase_RFP_smallPatch_June/phaseRFP_36patch_June.zarr") -june_embedding_dataset = read_embedding_dataset(june_features_path) - -# Load the predicted labels from CSV -prediction_csv_path = "3up_gmm_clustering_results_june_pca_6components.csv" # Path to predicted labels CSV file -prediction_df = load_annotation(prediction_csv_path) - -# %% Count Infected Cell States Over Time as Percentage using Predicted labels -state_counts = count_infected_cell_states_over_time(june_embedding_dataset, prediction_df) -print(state_counts.head()) -state_counts.info() - -# %% Plot Infected Cell States Over Time as Percentage -plot_infected_cell_states(state_counts) - -# %% diff --git a/applications/contrastive_phenotyping/evaluation/grad_attr.py b/applications/contrastive_phenotyping/figures/grad_attr.py similarity index 86% rename from applications/contrastive_phenotyping/evaluation/grad_attr.py rename to applications/contrastive_phenotyping/figures/grad_attr.py index 2def1bfce..f0873c288 100644 --- a/applications/contrastive_phenotyping/evaluation/grad_attr.py +++ 
b/applications/contrastive_phenotyping/figures/grad_attr.py @@ -29,8 +29,8 @@ # %% dm = TripletDataModule( - data_path="/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr", - tracks_path="/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr", + data_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr", + tracks_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr", source_channel=["Phase3D", "RFP"], z_range=[25, 40], batch_size=48, @@ -76,10 +76,10 @@ "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178_gt_tracks.zarr" ) path_annotations_infection = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv" + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv" ) path_annotations_division = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv" + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv" ) infection_dataset = read_embedding_dataset(path_infection_embedding) @@ -145,7 +145,7 @@ # %% # load infection annotations infection = pd.read_csv( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv", + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/supervised_inf_pred/extracted_inf_state.csv", ) track_classes_infection = infection[infection["fov_name"] == fov[1:]] track_classes_infection = track_classes_infection[ @@ -155,7 +155,7 @@ # %% # load division annotations division = pd.read_csv( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv", + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_02_04_A549_DENV_ZIKV_timelapse/9-lineage-cell-division/lineages_gt/cell_division_state_test_set.csv", ) track_classes_division = division[division["fov_name"] == fov[1:]] track_classes_division = track_classes_division[ @@ -258,7 +258,8 @@ def clim_percentile(heatmap, low=1, high=99): # %% f.savefig( Path.home() - / "gdrive/publications/learning_impacts_of_infection/fig_manuscript/fig_explanation/fig_explanation_patch12_stride4.pdf", + / "mydata" + / "gdrive/publications/dynaCLR/2025_dynaCLR_paper/fig_manuscript_svg/figure_occlusion_analysis/figure_parts/fig_explanation_patch12_stride4.pdf", dpi=300, ) diff --git a/applications/contrastive_phenotyping/figures/save_patches.py b/applications/contrastive_phenotyping/figures/save_patches.py deleted file mode 100644 index ebba6c320..000000000 --- a/applications/contrastive_phenotyping/figures/save_patches.py +++ /dev/null @@ -1,67 +0,0 @@ -# %% script to save 128 by 128 image patches from napari viewer - -import os -import sys -from 
pathlib import Path
-
-import numpy as np
-
-sys.path.append("/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy")
-# from viscy.data.triplet import TripletDataModule
-from viscy.representation.evaluation import dataset_of_tracks
-
-# %% input parameters
-
-data_path = Path(
-    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/registered_test.zarr"
-)
-tracks_path = Path(
-    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/8-train-test-split/track_test.zarr"
-)
-
-fov_name = "/B/4/6"
-track_id = 52
-source_channel = ["Phase3D", "RFP"]
-
-# %% load dataset
-
-prediction_dataset = dataset_of_tracks(
-    data_path,
-    tracks_path,
-    [fov_name],
-    [track_id],
-    source_channel=source_channel,
-)
-whole = np.stack([p["anchor"] for p in prediction_dataset])
-phase = whole[:, 0]
-fluor = whole[:, 1]
-
-# use the following if you want to visualize a specific phase slice with max-projected fluorescence
-# phase = whole[:, 0, 3]  # 3 is the slice number
-# fluor = np.max(whole[:, 1], axis=1)
-
-# load image
-# v = napari.Viewer()
-# v.add_image(phase)
-# v.add_image(fluor)
-
-# %% save patches as png images
-
-# Use the napari sliders to set the desired contrast and make other adjustments,
-# then take a screenshot to save the image patch manually.
-# This step can be scripted if desired.
-
-# %% save as numpy files
-
-out_dir = "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/Figure_panels/data/"
-fov_name_out = fov_name.replace("/", "_")
-np.save(
-    (os.path.join(out_dir, "phase" + fov_name_out + "_" + str(track_id) + ".npy")),
-    phase,
-)
-np.save(
-    (os.path.join(out_dir, "fluor" + fov_name_out + "_" + str(track_id) + ".npy")),
-    fluor,
-)
-
-# %%
diff --git a/applications/infection_classification/Infection_classification_25DModel.py b/applications/infection_classification/Infection_classification_25DModel.py
deleted file mode 100644
index a4e712f5b..000000000
--- a/applications/infection_classification/Infection_classification_25DModel.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# %%
-import lightning.pytorch as pl
-import torch
-import torch.nn as nn
-from applications.infection_classification.classify_infection_25D import (
-    SemanticSegUNet25D,
-)
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import TensorBoardLogger
-
-from viscy.data.hcs import HCSDataModule
-from viscy.preprocessing.pixel_ratio import sematic_class_weights
-from viscy.transforms import NormalizeSampled, RandWeightedCropd
-
-# %% Create a dataloader and visualize the batches.
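-# A minimal sketch of batch visualization (commented out; assumes the `train_dm`
-# dataloader constructed below and `matplotlib.pyplot` imported as `plt`):
-# batch = next(iter(train_dm))
-# plt.imshow(batch["source"][0, 0, 2].numpy(), cmap="gray")  # center z-slice of the Phase channel
-# plt.show()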
- -# Set the path to the dataset -dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr" - -# %% create data module - -# Create an instance of HCSDataModule -data_module = HCSDataModule( - dataset_path, - source_channel=["Phase", "HSP90"], - target_channel=["Inf_mask"], - yx_patch_size=[512, 512], - split_ratio=0.8, - z_window_size=5, - architecture="2.5D", - num_workers=3, - batch_size=32, - normalizations=[ - NormalizeSampled( - keys=["Phase", "HSP90"], - level="fov_statistics", - subtrahend="median", - divisor="iqr", - ) - ], - augmentations=[ - RandWeightedCropd( - num_samples=4, - spatial_size=[-1, 512, 512], - keys=["Phase", "HSP90"], - w_key="Inf_mask", - ) - ], -) - -pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask") - -# Prepare the data -data_module.prepare_data() - -# Setup the data -data_module.setup(stage="fit") - -# Create a dataloader -train_dm = data_module.train_dataloader() - -val_dm = data_module.val_dataloader() - - -# %% Define the logger -logger = TensorBoardLogger( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/", - name="logs", -) - -# Pass the logger to the Trainer -trainer = pl.Trainer( - logger=logger, - max_epochs=200, - default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", - log_every_n_steps=1, - devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs -) - -# Define the checkpoint callback -checkpoint_callback = ModelCheckpoint( - dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", - filename="checkpoint_{epoch:02d}", - save_top_k=-1, - verbose=True, - monitor="loss/validate", - mode="min", -) - -# Add the checkpoint callback to the trainer -trainer.callbacks.append(checkpoint_callback) - -# Fit the model -model = SemanticSegUNet25D( - in_channels=2, - out_channels=3, - loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), -) - -print(model) - -# %% Run training. - -trainer.fit(model, data_module) - -# %% diff --git a/applications/infection_classification/Infection_classification_covnextModel.py b/applications/infection_classification/Infection_classification_covnextModel.py deleted file mode 100644 index bfe203625..000000000 --- a/applications/infection_classification/Infection_classification_covnextModel.py +++ /dev/null @@ -1,107 +0,0 @@ -# %% -# import sys -# sys.path.append("/hpc/mydata/soorya.pradeep/viscy_infection_phenotyping/Viscy/") -import lightning.pytorch as pl -import torch -import torch.nn as nn -from applications.infection_classification.classify_infection_covnext import ( - SemanticSegUNet22D, -) -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger - -from viscy.data.hcs import HCSDataModule -from viscy.preprocessing.pixel_ratio import sematic_class_weights -from viscy.transforms import NormalizeSampled, RandWeightedCropd - -# %% Create a dataloader and visualize the batches. 
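-# Note: compared with the 2.5D variant above, this ConvNeXt-based model takes four
-# input channels: the two imaged channels (Phase, HSP90) plus two precomputed
-# statistic channels (phase_nucl_iqr, hsp90_skew).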
-
-# Set the path to the dataset
-dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated_train.zarr"
-
-# %% create data module
-
-# Create an instance of HCSDataModule
-data_module = HCSDataModule(
-    dataset_path,
-    source_channel=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
-    target_channel=["Inf_mask"],
-    yx_patch_size=[256, 256],
-    split_ratio=0.8,
-    z_window_size=5,
-    architecture="2.2D",
-    num_workers=3,
-    batch_size=16,
-    normalizations=[
-        NormalizeSampled(
-            keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
-            level="fov_statistics",
-            subtrahend="median",
-            divisor="iqr",
-        )
-    ],
-    augmentations=[
-        RandWeightedCropd(
-            num_samples=4,
-            spatial_size=[-1, 256, 256],
-            keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
-            w_key="Inf_mask",
-        )
-    ],
-)
-pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask")
-
-# Prepare the data
-data_module.prepare_data()
-
-# Setup the data
-data_module.setup(stage="fit")
-
-# Create a dataloader
-train_dm = data_module.train_dataloader()
-
-val_dm = data_module.val_dataloader()
-
-
-# %% Define the logger
-logger = TensorBoardLogger(
-    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/",
-    name="logs",
-)
-
-# Pass the logger to the Trainer
-trainer = pl.Trainer(
-    logger=logger,
-    max_epochs=200,
-    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
-    log_every_n_steps=1,
-    devices=1,  # Set the number of GPUs to use. This avoids a run-time exception from distributed training when the node has multiple GPUs
-)
-
-# Define the checkpoint callback
-checkpoint_callback = ModelCheckpoint(
-    dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
-    filename="checkpoint_{epoch:02d}",
-    save_top_k=-1,
-    verbose=True,
-    monitor="loss/validate",
-    mode="min",
-)
-
-# Add the checkpoint callback to the trainer
-trainer.callbacks.append(checkpoint_callback)
-
-# Instantiate the model
-model = SemanticSegUNet22D(
-    in_channels=4,
-    out_channels=3,
-    loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
-)
-
-print(model)
-
-# %% Run training.
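-# To resume an interrupted run, Lightning's standard `ckpt_path` argument can be
-# passed to `fit` (a sketch; the checkpoint path is hypothetical):
-# trainer.fit(model, data_module, ckpt_path="logs/checkpoint_epoch=50.ckpt")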
- -trainer.fit(model, data_module) - -# %% diff --git a/applications/infection_classification/readme.md b/applications/infection_classification/README.md similarity index 100% rename from applications/infection_classification/readme.md rename to applications/infection_classification/README.md diff --git a/applications/infection_classification/classify_infection_25D.py b/applications/infection_classification/classify_infection_25D.py deleted file mode 100644 index e16f56f42..000000000 --- a/applications/infection_classification/classify_infection_25D.py +++ /dev/null @@ -1,356 +0,0 @@ -# import torchview -from typing import Literal, Sequence - -import cv2 -import lightning.pytorch as pl -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from matplotlib.cm import get_cmap -from monai.transforms import DivisiblePad -from skimage.exposure import rescale_intensity -from skimage.measure import label, regionprops -from torch import Tensor - -from viscy.data.hcs import Sample -from viscy.unet.networks.Unet25D import Unet25d - -# %% Methods to compute confusion matrix per cell using torchmetrics - - -# The confusion matrix at the single-cell resolution. -def confusion_matrix_per_cell( - y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int -): - """Compute confusion matrix per cell. - - Args: - y_true (torch.Tensor): Ground truth label image (BXHXW). - y_pred (torch.Tensor): Predicted label image (BXHXW). - num_classes (int): Number of classes. - - Returns: - torch.Tensor: Confusion matrix per cell (BXCXC). - """ - # Convert the image class to the nuclei class - confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes) - confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) - return confusion_matrix_per_cell - - -# These images can be logged with prediction. -def compute_confusion_matrix( - y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int -): - """Convert the class of the image to the class of the nuclei. - - Args: - label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. - num_classes (int): Number of classes. - - Returns: - torch.Tensor: Label images with a consensus class at the centroid of nuclei. - """ - - batch_size = y_true.size(0) - # find centroids of nuclei from y_true - conf_mat = np.zeros((num_classes, num_classes)) - for i in range(batch_size): - y_true_cpu = y_true[i].cpu().numpy() - y_pred_cpu = y_pred[i].cpu().numpy() - y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) - y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) - y_pred_resized = cv2.resize( - y_pred_reshaped, - dsize=y_true_reshaped.shape[::-1], - interpolation=cv2.INTER_NEAREST, - ) - y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) - - # find objects in every image - label_img = label(y_true_reshaped) - regions = regionprops(label_img) - - # Find centroids, pixel coordinates from the ground truth. 
-        for region in regions:
-            if region.area > 0:
-                row, col = region.centroid
-                pred_id = y_pred_resized[int(row), int(col)]
-                test_id = y_true_reshaped[int(row), int(col)]
-
-                # Tally one count per nucleus, with rows as the true class and
-                # columns as the predicted class (index 0 = infected, 1 = uninfected).
-                if pred_id == 1 and test_id == 1:
-                    conf_mat[1, 1] += 1
-                if pred_id == 1 and test_id == 2:
-                    conf_mat[0, 1] += 1
-                if pred_id == 2 and test_id == 1:
-                    conf_mat[1, 0] += 1
-                if pred_id == 2 and test_id == 2:
-                    conf_mat[0, 0] += 1
-    return conf_mat
-
-
-def plot_confusion_matrix(confusion_matrix, index_to_label_dict):
-    # Create a figure and axis to plot the confusion matrix
-    fig, ax = plt.subplots()
-
-    # Create a color heatmap for the confusion matrix
-    cax = ax.matshow(confusion_matrix, cmap="viridis")
-
-    # Create a colorbar and set the label
-    index_to_label_dict = dict(
-        enumerate(index_to_label_dict)
-    )  # Convert list to dictionary
-    fig.colorbar(cax, label="Frequency")
-
-    # Set labels for the classes
-    ax.set_xticks(np.arange(len(index_to_label_dict)))
-    ax.set_yticks(np.arange(len(index_to_label_dict)))
-    ax.set_xticklabels(index_to_label_dict.values(), rotation=45)
-    ax.set_yticklabels(index_to_label_dict.values())
-
-    # Set labels for the axes
-    ax.set_xlabel("Predicted")
-    ax.set_ylabel("True")
-
-    # Add text annotations to the confusion matrix
-    for i in range(len(index_to_label_dict)):
-        for j in range(len(index_to_label_dict)):
-            ax.text(
-                j,
-                i,
-                str(int(confusion_matrix[i, j])),
-                ha="center",
-                va="center",
-                color="white",
-            )
-
-    # plt.show(fig)  # Show the figure
-    return fig
-
-
-# Define a 25d unet model for infection classification as a lightning module.
-
-
-class SemanticSegUNet25D(pl.LightningModule):
-    # Model for semantic segmentation.
- def __init__( - self, - in_channels: int, # Number of input channels - out_channels: int, # Number of output channels - lr: float = 1e-3, # Learning rate - loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function - schedule: Literal[ - "WarmupCosine", "Constant" - ] = "Constant", # Learning rate schedule - log_batches_per_epoch: int = 2, # Number of batches to log per epoch - log_samples_per_batch: int = 2, # Number of samples to log per batch - ckpt_path: str = None, # Path to the checkpoint - ): - super(SemanticSegUNet25D, self).__init__() # Call the superclass initializer - # Initialize the UNet model - self.unet_model = Unet25d( - in_channels=in_channels, - out_channels=out_channels, - num_blocks=4, - num_block_layers=4, - ) - if ckpt_path is not None: - state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ - "state_dict" - ] - state_dict.pop("loss_function.weight", None) # Remove the unexpected key - self.load_state_dict(state_dict) # loading only weights - self.lr = lr # Set the learning rate - # Set the loss function to CrossEntropyLoss if none is provided - self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() - self.schedule = schedule # Set the learning rate schedule - self.log_batches_per_epoch = ( - log_batches_per_epoch # Set the number of batches to log per epoch - ) - self.log_samples_per_batch = ( - log_samples_per_batch # Set the number of samples to log per batch - ) - self.training_step_outputs = [] # Initialize the list of training step outputs - self.validation_step_outputs = ( - [] - ) # Initialize the list of validation step outputs - - self.pred_cm = None # Initialize the confusion matrix - self.index_to_label_dict = ["Infected", "Uninfected"] - - # Define the forward pass - def forward(self, x): - return self.unet_model(x) # Pass the input through the UNet model - - # Define the optimizer - def configure_optimizers(self): - optimizer = torch.optim.Adam( - self.parameters(), lr=self.lr - ) # Use the Adam optimizer - return optimizer - - # Define the training step - def training_step(self, batch: Sample, batch_idx: int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - pred = self.forward(source) # Make a prediction using the source - # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert the target to float type - train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss - # Log the training step outputs if the batch index is less than the number of batches to log per epoch - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - # Log the training loss - self.log( - "loss/train", - train_loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - ) - return train_loss # Return the training loss - - def validation_step(self, batch: Sample, batch_idx: int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - pred = self.forward(source) # Make a prediction using the source - # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert the target to float type - loss = 
self.loss_function(pred, target_one_hot) # Calculate the loss - # Log the validation step outputs if the batch index is less than the number of batches to log per epoch - if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - # Log the validation loss - self.log( - "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True - ) - return loss # Return the validation loss - - def on_predict_start(self): - """Pad the input shape to be divisible by the downsampling factor. - The inverse of this transform crops the prediction to original shape. - """ - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - - # Define the prediction step - def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse( - self.forward(source) - ) # Predict and remove padding. - prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - # Go from probabilities/one-hot encoded data to class labels. - labels_pred = torch.argmax( - prob_pred, dim=1, keepdim=True - ) # Calculate the predicted labels - # prob_chan = prob_pred[:, 2, :, :] - # prob_chan = prob_chan.unsqueeze(1) - return labels_pred # log the class predicted image - # return prob_chan # log the probability predicted image - - def on_test_start(self): - self.pred_cm = torch.zeros((2, 2)) - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - - def test_step(self, batch: Sample): - source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse(self.forward(source)) - prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - labels_pred = torch.argmax( - prob_pred, dim=1, keepdim=True - ) # Calculate the predicted labels - - target = self._predict_pad(batch["target"]) # Extract the target from the batch - pred_cm = confusion_matrix_per_cell( - target, labels_pred, num_classes=2 - ) # Calculate the confusion matrix per cell - self.pred_cm += pred_cm # Append the confusion matrix to pred_cm - - self.logger.experiment.add_figure( - "Confusion Matrix per Cell", - plot_confusion_matrix(pred_cm, self.index_to_label_dict), - self.current_epoch, - ) - - # Accumulate the confusion matrix at the end of test epoch and log. 
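-# `pred_cm` accumulates the raw per-cell counts from `test_step` across all test
-# batches; `on_test_end` below renders the summed matrix once with the
-# Infected/Uninfected labels.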
- def on_test_end(self): - confusion_matrix_sum = self.pred_cm - self.logger.experiment.add_figure( - "Confusion Matrix", - plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), - self.current_epoch, - ) - - # Define what happens at the end of a training epoch - def on_train_epoch_end(self): - self._log_samples( - "train_samples", self.training_step_outputs - ) # Log the training samples - self.training_step_outputs = [] # Reset the list of training step outputs - - # Define what happens at the end of a validation epoch - def on_validation_epoch_end(self): - self._log_samples( - "val_samples", self.validation_step_outputs - ) # Log the validation samples - self.validation_step_outputs = [] # Reset the list of validation step outputs - - # Define a method to detach a sample - def _detach_sample(self, imgs: Sequence[Tensor]): - # Detach the images and convert them to numpy arrays - num_samples = 3 - return [ - [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] - for i in range(num_samples) - ] - - # Define a method to log samples - def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): - images_grid = [] # Initialize the list of image grids - for sample_images in imgs: # For each sample image - images_row = [] # Initialize the list of image rows - for i, image in enumerate( - sample_images - ): # For each image in the sample images - cm_name = "gray" if i == 0 else "inferno" # Set the colormap name - if image.ndim == 2: # If the image is 2D - image = image[np.newaxis] # Add a new axis - for channel in image: # For each channel in the image - channel = rescale_intensity( - channel, out_range=(0, 1) - ) # Rescale the intensity of the channel - render = get_cmap(cm_name)(channel, bytes=True)[ - ..., :3 - ] # Render the channel - images_row.append( - render - ) # Append the render to the list of image rows - images_grid.append( - np.concatenate(images_row, axis=1) - ) # Append the concatenated image rows to the list of image grids - grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids - # Log the image grid - self.logger.experiment.add_image( - key, grid, self.current_epoch, dataformats="HWC" - ) - - -# %% diff --git a/applications/infection_classification/classify_infection_covnext.py b/applications/infection_classification/classify_infection_covnext.py deleted file mode 100644 index 397e822db..000000000 --- a/applications/infection_classification/classify_infection_covnext.py +++ /dev/null @@ -1,363 +0,0 @@ -# import torchview -from typing import Literal, Sequence - -import cv2 -import lightning.pytorch as pl -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from matplotlib.cm import get_cmap -from monai.transforms import DivisiblePad -from skimage.exposure import rescale_intensity -from skimage.measure import label, regionprops -from torch import Tensor - -from viscy.data.hcs import Sample -from viscy.translation.engine import VSUNet - -# -# %% Methods to compute confusion matrix per cell using torchmetrics - - -# The confusion matrix at the single-cell resolution. -def confusion_matrix_per_cell( - y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int -): - """Compute confusion matrix per cell. - - Args: - y_true (torch.Tensor): Ground truth label image (BXHXW). - y_pred (torch.Tensor): Predicted label image (BXHXW). - num_classes (int): Number of classes. - - Returns: - torch.Tensor: Confusion matrix per cell (BXCXC). 
- """ - # Convert the image class to the nuclei class - confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes) - confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) - return confusion_matrix_per_cell - - -# These images can be logged with prediction. -def compute_confusion_matrix( - y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int -): - """Convert the class of the image to the class of the nuclei. - - Args: - label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. - num_classes (int): Number of classes. - - Returns: - torch.Tensor: Label images with a consensus class at the centroid of nuclei. - """ - - batch_size = y_true.size(0) - # find centroids of nuclei from y_true - conf_mat = np.zeros((num_classes, num_classes)) - for i in range(batch_size): - y_true_cpu = y_true[i].cpu().numpy() - y_pred_cpu = y_pred[i].cpu().numpy() - y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) - y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) - y_pred_resized = cv2.resize( - y_pred_reshaped, - dsize=y_true_reshaped.shape[::-1], - interpolation=cv2.INTER_NEAREST, - ) - y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) - - # find objects in every image - label_img = label(y_true_reshaped) - regions = regionprops(label_img) - - # Find centroids, pixel coordinates from the ground truth. - for region in regions: - if region.area > 0: - row, col = region.centroid - pred_id = y_pred_resized[int(row), int(col)] - test_id = y_true_reshaped[int(row), int(col)] - - if pred_id == 1 and test_id == 1: - conf_mat[1, 1] += 1 - if pred_id == 1 and test_id == 2: - conf_mat[0, 1] += 1 - if pred_id == 2 and test_id == 1: - conf_mat[1, 0] += 1 - if pred_id == 2 and test_id == 2: - conf_mat[0, 0] += 1 - # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. - # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. - return conf_mat - - -def plot_confusion_matrix(confusion_matrix, index_to_label_dict): - # Create a figure and axis to plot the confusion matrix - fig, ax = plt.subplots() - - # Create a color heatmap for the confusion matrix - cax = ax.matshow(confusion_matrix, cmap="viridis") - - # Create a colorbar and set the label - index_to_label_dict = dict( - enumerate(index_to_label_dict) - ) # Convert list to dictionary - fig.colorbar(cax, label="Frequency") - - # Set labels for the classes - ax.set_xticks(np.arange(len(index_to_label_dict))) - ax.set_yticks(np.arange(len(index_to_label_dict))) - ax.set_xticklabels(index_to_label_dict.values(), rotation=45) - ax.set_yticklabels(index_to_label_dict.values()) - - # Set labels for the axes - ax.set_xlabel("Predicted") - ax.set_ylabel("True") - - # Add text annotations to the confusion matrix - for i in range(len(index_to_label_dict)): - for j in range(len(index_to_label_dict)): - ax.text( - j, - i, - str(int(confusion_matrix[i, j])), - ha="center", - va="center", - color="white", - ) - - # plt.show(fig) # Show the figure - return fig - - -# Define a 25d unet model for infection classification as a lightning module. - - -class SemanticSegUNet22D(pl.LightningModule): - # Model for semantic segmentation. 
- def __init__( - self, - in_channels: int, # Number of input channels - out_channels: int, # Number of output channels - lr: float = 1e-3, # Learning rate - loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function - schedule: Literal[ - "WarmupCosine", "Constant" - ] = "Constant", # Learning rate schedule - log_batches_per_epoch: int = 2, # Number of batches to log per epoch - log_samples_per_batch: int = 2, # Number of samples to log per batch - ckpt_path: str = None, # Path to the checkpoint - ): - super(SemanticSegUNet22D, self).__init__() # Call the superclass initializer - # Initialize the UNet model - self.unet_model = VSUNet( - architecture="2.2D", - model_config={ - "in_channels": in_channels, - "out_channels": out_channels, - "in_stack_depth": 5, - "backbone": "convnextv2_tiny", - "stem_kernel_size": (5, 4, 4), - "decoder_mode": "pixelshuffle", - "head_expansion_ratio": 4, - }, - ) - if ckpt_path is not None: - state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ - "state_dict" - ] - state_dict.pop("loss_function.weight", None) # Remove the unexpected key - self.load_state_dict(state_dict) # loading only weights - self.lr = lr # Set the learning rate - # Set the loss function to CrossEntropyLoss if none is provided - self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() - self.schedule = schedule # Set the learning rate schedule - self.log_batches_per_epoch = ( - log_batches_per_epoch # Set the number of batches to log per epoch - ) - self.log_samples_per_batch = ( - log_samples_per_batch # Set the number of samples to log per batch - ) - self.training_step_outputs = [] # Initialize the list of training step outputs - self.validation_step_outputs = ( - [] - ) # Initialize the list of validation step outputs - - self.pred_cm = None # Initialize the confusion matrix - self.index_to_label_dict = ["Infected", "Uninfected"] - - # Define the forward pass - def forward(self, x): - return self.unet_model(x) # Pass the input through the UNet model - - # Define the optimizer - def configure_optimizers(self): - optimizer = torch.optim.Adam( - self.parameters(), lr=self.lr - ) # Use the Adam optimizer - return optimizer - - # Define the training step - def training_step(self, batch: Sample, batch_idx: int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - pred = self.forward(source) # Make a prediction using the source - # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert the target to float type - train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss - # Log the training step outputs if the batch index is less than the number of batches to log per epoch - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - # Log the training loss - self.log( - "loss/train", - train_loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - ) - return train_loss # Return the training loss - - def validation_step(self, batch: Sample, batch_idx: int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - pred = self.forward(source) # Make a prediction using the source - # Convert the target to one-hot encoding - target_one_hot = 
F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert the target to float type - loss = self.loss_function(pred, target_one_hot) # Calculate the loss - # Log the validation step outputs if the batch index is less than the number of batches to log per epoch - if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - # Log the validation loss - self.log( - "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True - ) - return loss # Return the validation loss - - def on_predict_start(self): - """Pad the input shape to be divisible by the downsampling factor. - The inverse of this transform crops the prediction to original shape. - """ - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - - # Define the prediction step - def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse( - self.forward(source) - ) # Predict and remove padding. - prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - # Go from probabilities/one-hot encoded data to class labels. - labels_pred = torch.argmax( - prob_pred, dim=1, keepdim=True - ) # Calculate the predicted labels - # prob_chan = prob_pred[:, 2, :, :] - # prob_chan = prob_chan.unsqueeze(1) - return labels_pred # log the class predicted image - # return prob_chan # log the probability predicted image - - def on_test_start(self): - self.pred_cm = torch.zeros((2, 2)) - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - - def test_step(self, batch: Sample): - source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse(self.forward(source)) - prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - labels_pred = torch.argmax( - prob_pred, dim=1, keepdim=True - ) # Calculate the predicted labels - - target = self._predict_pad(batch["target"]) # Extract the target from the batch - pred_cm = confusion_matrix_per_cell( - target, labels_pred, num_classes=2 - ) # Calculate the confusion matrix per cell - self.pred_cm += pred_cm # Append the confusion matrix to pred_cm - - self.logger.experiment.add_figure( - "Confusion Matrix per Cell", - plot_confusion_matrix(pred_cm, self.index_to_label_dict), - self.current_epoch, - ) - - # Accumulate the confusion matrix at the end of test epoch and log. 
-    def on_test_end(self):
-        confusion_matrix_sum = self.pred_cm
-        self.logger.experiment.add_figure(
-            "Confusion Matrix",
-            plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict),
-            self.current_epoch,
-        )
-
-    # Define what happens at the end of a training epoch
-    def on_train_epoch_end(self):
-        self._log_samples(
-            "train_samples", self.training_step_outputs
-        )  # Log the training samples
-        self.training_step_outputs = []  # Reset the list of training step outputs
-
-    # Define what happens at the end of a validation epoch
-    def on_validation_epoch_end(self):
-        self._log_samples(
-            "val_samples", self.validation_step_outputs
-        )  # Log the validation samples
-        self.validation_step_outputs = []  # Reset the list of validation step outputs
-
-    # Define a method to detach a sample
-    def _detach_sample(self, imgs: Sequence[Tensor]):
-        # Detach the images and convert them to numpy arrays
-        num_samples = 3
-        return [
-            [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs]
-            for i in range(num_samples)
-        ]
-
-    # Define a method to log samples
-    def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
-        images_grid = []  # Initialize the list of image grids
-        for sample_images in imgs:  # For each sample image
-            images_row = []  # Initialize the list of image rows
-            for i, image in enumerate(
-                sample_images
-            ):  # For each image in the sample images
-                cm_name = "gray" if i == 0 else "inferno"  # Set the colormap name
-                if image.ndim == 2:  # If the image is 2D
-                    image = image[np.newaxis]  # Add a new axis
-                for channel in image:  # For each channel in the image
-                    channel = rescale_intensity(
-                        channel, out_range=(0, 1)
-                    )  # Rescale the intensity of the channel
-                    render = get_cmap(cm_name)(channel, bytes=True)[
-                        ..., :3
-                    ]  # Render the channel
-                    images_row.append(
-                        render
-                    )  # Append the render to the list of image rows
-            images_grid.append(
-                np.concatenate(images_row, axis=1)
-            )  # Append the concatenated image rows to the list of image grids
-        grid = np.concatenate(images_grid, axis=0)  # Concatenate the image grids
-        # Log the image grid
-        self.logger.experiment.add_image(
-            key, grid, self.current_epoch, dataformats="HWC"
-        )
-
-
-# %%
diff --git a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings.py b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings.py
new file mode 100644
index 000000000..78a5a719b
--- /dev/null
+++ b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings.py
@@ -0,0 +1,766 @@
+# %%
+import ast
+import logging
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from plotting_utils import (
+    find_pattern_matches,
+    identify_lineages,
+    plot_pc_trajectories,
+    plot_reference_vs_full_lineages,
+)
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
+from tqdm import tqdm
+
+from viscy.data.triplet import TripletDataModule
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.evaluation.dimensionality_reduction import compute_pca
+
+logger = logging.getLogger("viscy")
+logger.setLevel(logging.INFO)
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.INFO)
+formatter = logging.Formatter("%(message)s")  # Simplified format
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
+
+
+NAPARI = True
+if NAPARI:
+    import os
+
+    import napari
+
+    os.environ["DISPLAY"] = ":1"
+    viewer = napari.Viewer()
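+    # Headless rendering: assumes an X server (e.g. Xvfb) is serving display :1;
+    # adjust DISPLAY above if napari should render to a different screen.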
+# %%
+# Organelle and PHATE embeddings aligned to infection
+
+input_data_path = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr"
+)
+tracks_path = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr"
+)
+infection_annotations_path = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/0-annotation/combined_annotations_n_tracks_infection.csv"
+)
+
+pretrain_features_root = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/prediction_pretrained_models"
+)
+# Phase and organelle channels
+# dynaclr_features_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/predictions/timeAware_2chan__ntxent_192patch_70ckpt_rev7_GT.zarr"
+
+# Phase and sensor channels
+dynaclr_features_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/3-phenotyping/predictions_infection/2chan_192patch_100ckpt_timeAware_ntxent_GT.zarr"
+
+output_root = Path(
+    "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/figure/SEC61B/model_comparison"
+)
+
+
+# Load embeddings
+imagenet_features_path = (
+    pretrain_features_root / "ImageNet/20241107_sensor_n_phase_imagenet.zarr"
+)
+openphenom_features_path = (
+    pretrain_features_root / "OpenPhenom/20241107_sensor_n_phase_openphenom.zarr"
+)
+
+dynaclr_embeddings = read_embedding_dataset(dynaclr_features_path)
+imagenet_embeddings = read_embedding_dataset(imagenet_features_path)
+openphenom_embeddings = read_embedding_dataset(openphenom_features_path)
+
+# Load infection annotations
+infection_annotations_df = pd.read_csv(infection_annotations_path)
+infection_annotations_df["fov_name"] = "/C/2/000001"
+
+process_embeddings = [
+    (dynaclr_embeddings, "dynaclr"),
+    (imagenet_embeddings, "imagenet"),
+    (openphenom_embeddings, "openphenom"),
+]
+
+
+output_root.mkdir(parents=True, exist_ok=True)
+# %%
+feature_df = dynaclr_embeddings["sample"].to_dataframe().reset_index(drop=True)
+
+# Identify distinct lineages and keep those with enough timepoints
+lineages = identify_lineages(feature_df)
+logger.info(f"Found {len(lineages)} distinct lineages")
+filtered_lineages = []
+min_timepoints = 20
+for fov_id, track_ids in lineages:
+    # Get all rows for this lineage
+    lineage_rows = feature_df[
+        (feature_df["fov_name"] == fov_id) & (feature_df["track_id"].isin(track_ids))
+    ]
+
+    # Count the total number of timepoints
+    total_timepoints = len(lineage_rows)
+
+    # Only keep lineages with at least min_timepoints
+    if total_timepoints >= min_timepoints:
+        filtered_lineages.append((fov_id, track_ids))
+logger.info(
+    f"Found {len(filtered_lineages)} lineages with at least {min_timepoints} timepoints"
+)
+
+# %%
+# Aligning condition embeddings to infection
+# OPTION 1: Use the infection annotations to find the reference lineage
+reference_lineage_fov = "/C/2/001000"
+reference_lineage_track_id = [129]
+reference_timepoints = [8, 70]  # sensor relocalization and partial remodelling
+
+# OPTION 2: pick a lineage from FOV C/2/000001 among the filtered lineages
+# (this overrides OPTION 1 above)
+reference_lineage_fov = "/C/2/000001"
+for fov_id, track_ids in filtered_lineages:
+    if reference_lineage_fov == fov_id:
+        break
+# NOTE: `track_ids` persists from the loop above after `break`
+reference_lineage_track_id = track_ids
+reference_timepoints = [8, 70]  # sensor relocalization and partial remodelling
+
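+# Each lineage below is read as `embeddings.sel(sample=(fov, track_id)).features`,
+# which is expected to yield a (T, D) array per track; the reference pattern is
+# the (T_ref, D) slice between the two annotated timepoints, e.g. roughly
+#   embeddings.sel(sample=(fov_id, track_ids)).features.values[8:70]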
+# %%
+# Dictionary to store alignment results for comparison
+alignment_results = {}
+
+for embeddings, name in process_embeddings:
+    # Get the reference pattern from the current embedding space
+    reference_pattern = None
+    reference_lineage = []
+    for fov_id, track_ids in filtered_lineages:
+        if fov_id == reference_lineage_fov and all(
+            track_id in track_ids for track_id in reference_lineage_track_id
+        ):
+            logger.info(
+                f"Found reference pattern for {fov_id} {reference_lineage_track_id} using {name} embeddings"
+            )
+            reference_pattern = embeddings.sel(
+                sample=(fov_id, reference_lineage_track_id)
+            ).features.values
+            reference_lineage.append(reference_pattern)
+            break
+    if reference_pattern is None:
+        logger.info(f"Reference pattern not found for {name} embeddings. Skipping.")
+        continue
+    reference_pattern = np.concatenate(reference_lineage)
+    reference_pattern = reference_pattern[
+        reference_timepoints[0] : reference_timepoints[1]
+    ]
+
+    # Find all matches to the reference pattern
+    metric = "cosine"
+    all_match_positions = find_pattern_matches(
+        reference_pattern,
+        filtered_lineages,
+        embeddings,
+        window_step_fraction=0.1,
+        num_candidates=4,
+        method="bernd_clifford",
+        save_path=output_root / f"{name}_matching_lineages_{metric}.csv",
+        metric=metric,
+    )
+
+    # Store results for later comparison
+    alignment_results[name] = all_match_positions
+
+# Visualize warping paths in PC space instead of raw embedding dimensions
+for name, match_positions in alignment_results.items():
+    if match_positions is not None and not match_positions.empty:
+        # Plot PC trajectories of the matched lineages with plotting_utils
+        plot_pc_trajectories(
+            reference_lineage_fov=reference_lineage_fov,
+            reference_lineage_track_id=reference_lineage_track_id,
+            reference_timepoints=reference_timepoints,
+            match_positions=match_positions,
+            embeddings_dataset=next(
+                emb for emb, emb_name in process_embeddings if emb_name == name
+            ),
+            filtered_lineages=filtered_lineages,
+            name=name,
+            save_path=output_root / f"{name}_pc_lineage_alignment.png",
+        )
+
+
+# %%
+# Compare DTW performance between embedding methods
+
+# Create a DataFrame to collect the alignment statistics for comparison
+match_data = []
+for name, match_positions in alignment_results.items():
+    if match_positions is not None and not match_positions.empty:
+        for i, row in match_positions.head(10).iterrows():  # Take top 10 matches
+            warping_path = (
+                ast.literal_eval(row["warp_path"])
+                if isinstance(row["warp_path"], str)
+                else row["warp_path"]
+            )
+            match_data.append(
+                {
+                    "model": name,
+                    "match_position": row["start_timepoint"],
+                    "dtw_distance": row["distance"],
+                    "path_skewness": row["skewness"],
+                    "path_length": len(warping_path),
+                }
+            )
+
+comparison_df = pd.DataFrame(match_data)
+
+# Create visualizations to compare alignment quality
+plt.figure(figsize=(12, 10))
+
+# 1. Compare DTW distances
+plt.subplot(2, 2, 1)
+sns.boxplot(x="model", y="dtw_distance", data=comparison_df)
+plt.title("DTW Distance by Model")
+plt.ylabel("DTW Distance (lower is better)")
+
+# 2. Compare path skewness
+plt.subplot(2, 2, 2)
+sns.boxplot(x="model", y="path_skewness", data=comparison_df)
+plt.title("Path Skewness by Model")
+plt.ylabel("Skewness (lower is better)")
+
+# 3. Compare path lengths
+plt.subplot(2, 2, 3)
+sns.boxplot(x="model", y="path_length", data=comparison_df)
+plt.title("Warping Path Length by Model")
+plt.ylabel("Path Length")
+
+# 4. Scatterplot of DTW distance vs path skewness
+plt.subplot(2, 2, 4)
+scatter = sns.scatterplot(
+    x="dtw_distance", y="path_skewness", hue="model", data=comparison_df
+)
+plt.title("DTW Distance vs Path Skewness")
+plt.xlabel("DTW Distance")
+plt.ylabel("Path Skewness")
+plt.legend(title="Model")
+
+plt.tight_layout()
+plt.savefig(output_root / "dtw_alignment_comparison.png", dpi=300)
+plt.close()
+
+# %%
+# Analyze warping path step patterns for better understanding of alignment quality
+
+# Step pattern analysis
+step_pattern_counts = {
+    name: {"diagonal": 0, "horizontal": 0, "vertical": 0, "total": 0}
+    for name in alignment_results.keys()
+}
+
+for name, match_positions in alignment_results.items():
+    if match_positions is not None and not match_positions.empty:
+        # Get the top match
+        top_match = match_positions.iloc[0]
+        path = (
+            ast.literal_eval(top_match["warp_path"])
+            if isinstance(top_match["warp_path"], str)
+            else top_match["warp_path"]
+        )
+
+        # Count step types
+        for i in range(1, len(path)):
+            prev_i, prev_j = path[i - 1]
+            curr_i, curr_j = path[i]
+
+            step_i = curr_i - prev_i
+            step_j = curr_j - prev_j
+
+            if step_i == 1 and step_j == 1:
+                step_pattern_counts[name]["diagonal"] += 1
+            elif step_i == 1 and step_j == 0:
+                step_pattern_counts[name]["vertical"] += 1
+            elif step_i == 0 and step_j == 1:
+                step_pattern_counts[name]["horizontal"] += 1
+
+            step_pattern_counts[name]["total"] += 1
+
+# Convert to percentages
+for name in step_pattern_counts:
+    total = step_pattern_counts[name]["total"]
+    if total > 0:
+        for key in ["diagonal", "horizontal", "vertical"]:
+            step_pattern_counts[name][key] = (
+                step_pattern_counts[name][key] / total
+            ) * 100
+
+# Visualize step pattern distributions
+step_df = pd.DataFrame(
+    {
+        "model": [name for name in step_pattern_counts.keys() for _ in range(3)],
+        "step_type": ["diagonal", "horizontal", "vertical"] * len(step_pattern_counts),
+        "percentage": [
+            step_pattern_counts[name]["diagonal"] for name in step_pattern_counts.keys()
+        ]
+        + [
+            step_pattern_counts[name]["horizontal"]
+            for name in step_pattern_counts.keys()
+        ]
+        + [
+            step_pattern_counts[name]["vertical"] for name in step_pattern_counts.keys()
+        ],
+    }
+)
+
+plt.figure(figsize=(10, 6))
+sns.barplot(x="model", y="percentage", hue="step_type", data=step_df)
+plt.title("Step Pattern Distribution in Warping Paths")
+plt.ylabel("Percentage (%)")
+plt.savefig(output_root / "step_pattern_distribution.png", dpi=300)
+plt.close()

+# %%
+# Load the precomputed matches for a single model and align its image stacks
+MODEL = "openphenom"
+alignment_df_path = output_root / f"{MODEL}_matching_lineages_cosine.csv"
+alignment_df = pd.read_csv(alignment_df_path)
+# NOTE: `reference_pattern` used below is whatever remained from the last model
+# processed in the loop above, not necessarily the MODEL selected here.
+
+# Get the top N aligned cells
+
+source_channels = [
+    "Phase3D",
+    "raw GFP EX488 EM525-45",
+    "raw mCherry EX561 EM600-37",
+]
+yx_patch_size = (192, 192)
+z_range = (10, 30)
+view_ref_sector_only = True  # keep only the section matched to the reference
+
+all_lineage_images = []
+all_aligned_stacks = []
+all_unaligned_stacks = []
+
+# Get aligned and unaligned stacks
+top_aligned_cells = alignment_df.head(5)
+napari_viewer = viewer if NAPARI else None
+# Plot the aligned and unaligned stacks
+for idx, row in tqdm(
+    top_aligned_cells.iterrows(),
+    total=len(top_aligned_cells),
+    desc="Aligning images",
+):
+    fov_name = row["fov_name"]
+    track_ids = ast.literal_eval(row["track_ids"])
+    warp_path = ast.literal_eval(row["warp_path"])
+    start_time = int(row["start_timepoint"])
+
+    print(f"Aligning images for {fov_name} with track ids: {track_ids}")
+    data_module = TripletDataModule(
+        data_path=input_data_path,
+        tracks_path=tracks_path,
+        source_channel=source_channels,
+        z_range=z_range,
+        initial_yx_patch_size=yx_patch_size,
+        final_yx_patch_size=yx_patch_size,
+        batch_size=1,
+        num_workers=12,
+        predict_cells=True,
+        include_fov_names=[fov_name] * len(track_ids),
+        include_track_ids=track_ids,
+    )
+    data_module.setup("predict")
+
+    # Get the images for the lineage
+    lineage_images = []
+    for batch in data_module.predict_dataloader():
+        image = batch["anchor"].numpy()[0]
+        lineage_images.append(image)
+
+    lineage_images = np.array(lineage_images)
+    all_lineage_images.append(lineage_images)
+    print(f"Lineage images shape: {np.array(lineage_images).shape}")
+
+    # Create an aligned stack based on the warping path
+    if view_ref_sector_only:
+        aligned_stack = np.zeros(
+            (len(reference_pattern),) + lineage_images.shape[-4:],
+            dtype=lineage_images.dtype,
+        )
+        unaligned_stack = np.zeros(
+            (len(reference_pattern),) + lineage_images.shape[-4:],
+            dtype=lineage_images.dtype,
+        )
+
+        # Map each reference timepoint to the corresponding lineage timepoint
+        for ref_idx in range(len(reference_pattern)):
+            # Find matches in warping path for this reference index
+            matches = [(i, q) for i, q in warp_path if i == ref_idx]
+            unaligned_stack[ref_idx] = lineage_images[ref_idx]
+            if matches:
+                # Get the corresponding lineage timepoint (first match if multiple)
+                print(f"Found match for ref idx: {ref_idx}")
+                match = matches[0]
+                query_idx = match[1]
+                lineage_idx = int(start_time + query_idx)
+                print(
+                    f"Lineage index: {lineage_idx}, start time: {start_time}, query idx: {query_idx}, ref idx: {ref_idx}"
+                )
+                # Copy the image if it's within bounds
+                if 0 <= lineage_idx < len(lineage_images):
+                    aligned_stack[ref_idx] = lineage_images[lineage_idx]
+                else:
+                    # Find nearest valid timepoint if out of bounds
+                    nearest_idx = min(max(0, lineage_idx), len(lineage_images) - 1)
+                    aligned_stack[ref_idx] = lineage_images[nearest_idx]
+            else:
+                # If no direct match, find closest reference timepoint in warping path
+                print(f"No match found for ref idx: {ref_idx}")
+                all_ref_indices = [i for i, _ in warp_path]
+                if all_ref_indices:
+                    closest_ref_idx = min(
+                        all_ref_indices, key=lambda x: abs(x - ref_idx)
+                    )
+                    closest_matches = [
+                        (i, q) for i, q in warp_path if i == closest_ref_idx
+                    ]
+
+                    if closest_matches:
+                        closest_query_idx = closest_matches[0][1]
+                        lineage_idx = int(start_time + closest_query_idx)
+
+                        if 0 <= lineage_idx < len(lineage_images):
+                            aligned_stack[ref_idx] = lineage_images[lineage_idx]
+                        else:
+                            # Bound to valid range
+                            nearest_idx = min(
+                                max(0, lineage_idx), len(lineage_images) - 1
+                            )
+                            aligned_stack[ref_idx] = lineage_images[nearest_idx]
+
+    all_aligned_stacks.append(aligned_stack)
+    all_unaligned_stacks.append(unaligned_stack)
+
+all_aligned_stacks = np.array(all_aligned_stacks)
+all_unaligned_stacks = np.array(all_unaligned_stacks)
+# %%
+if NAPARI:
+    for idx, row in tqdm(
+        top_aligned_cells.reset_index().iterrows(),
+        total=len(top_aligned_cells),
+        desc="Plotting aligned and unaligned stacks",
+    ):
+        fov_name = row["fov_name"]
+        # track_ids is stored as a string like "[129]" in the CSV; parse it back
+        track_ids = ast.literal_eval(row["track_ids"])
+
+        aligned_stack = all_aligned_stacks[idx]
+        unaligned_stack = all_unaligned_stacks[idx]
+
+        unaligned_gfp_mip = np.max(unaligned_stack[:, 1, :, :], axis=1)
+        aligned_gfp_mip = np.max(aligned_stack[:, 1, :, :], axis=1)
+        unaligned_mcherry_mip = np.max(unaligned_stack[:, 2, :, :], axis=1)
+        aligned_mcherry_mip = np.max(aligned_stack[:, 2, :, :], axis=1)
+
+        z_slice = 15
+        unaligned_phase = 
unaligned_stack[:, 0, z_slice, :] + aligned_phase = aligned_stack[:, 0, z_slice, :] + + # unaligned + viewer.add_image( + unaligned_gfp_mip, + name=f"unaligned_gfp_{fov_name}_{track_ids[0]}", + colormap="green", + contrast_limits=(106, 215), + ) + viewer.add_image( + unaligned_mcherry_mip, + name=f"unaligned_mcherry_{fov_name}_{track_ids[0]}", + colormap="magenta", + contrast_limits=(106, 190), + ) + viewer.add_image( + unaligned_phase, + name=f"unaligned_phase_{fov_name}_{track_ids[0]}", + colormap="gray", + contrast_limits=(-0.74, 0.4), + ) + # aligned + viewer.add_image( + aligned_gfp_mip, + name=f"aligned_gfp_{fov_name}_{track_ids[0]}", + colormap="green", + contrast_limits=(106, 215), + ) + viewer.add_image( + aligned_mcherry_mip, + name=f"aligned_mcherry_{fov_name}_{track_ids[0]}", + colormap="magenta", + contrast_limits=(106, 190), + ) + viewer.add_image( + aligned_phase, + name=f"aligned_phase_{fov_name}_{track_ids[0]}", + colormap="gray", + contrast_limits=(-0.74, 0.4), + ) + viewer.grid.enabled = True + viewer.grid.shape = (-1, 6) +# %% +# Evaluate model performance based on infection state warping accuracy +# Check unique infection status values +unique_infection_statuses = infection_annotations_df["infection_status"].unique() +logger.info(f"Unique infection status values: {unique_infection_statuses}") + +# If "infected" is not in the unique values, this could explain zero precision/recall +if "infected" not in unique_infection_statuses: + logger.warning('The label "infected" is not found in the infection_status column!') + logger.info(f"Using these values instead: {unique_infection_statuses}") + + # If we need to map values, we could do it here + if len(unique_infection_statuses) >= 2: + logger.info( + f'Will treat "{unique_infection_statuses[1]}" as "infected" for metrics calculation' + ) + infection_target_value = unique_infection_statuses[1] + else: + infection_target_value = unique_infection_statuses[0] +else: + infection_target_value = "infected" + +logger.info(f'Using "{infection_target_value}" as positive class for F1 calculation') + +# Check if the reference track is in the annotations +logger.info( + f"Looking for infection annotations for reference lineage: {reference_lineage_fov}, tracks: {reference_lineage_track_id}" +) +print(f"Sample of infection_annotations_df: {infection_annotations_df.head()}") + +reference_infection_states = {} +for track_id in reference_lineage_track_id: + reference_annotations = infection_annotations_df[ + (infection_annotations_df["fov_name"] == reference_lineage_fov) + & (infection_annotations_df["track_id"] == track_id) + ] + + # Add annotations for this reference track + annotation_count = len(reference_annotations) + logger.info(f"Found {annotation_count} annotations for track {track_id}") + if annotation_count > 0: + print( + f"Sample annotations for track {track_id}: {reference_annotations.head()}" + ) + + for _, row in reference_annotations.iterrows(): + reference_infection_states[row["t"]] = row["infection_status"] + +if reference_infection_states: + logger.info( + f"Total reference timepoints with infection status: {len(reference_infection_states)}" + ) + reference_t_range = range(reference_timepoints[0], reference_timepoints[1]) + reference_gt_states = [ + reference_infection_states.get(t, "unknown") for t in reference_t_range + ] + logger.info(f"Reference track infection states: {reference_gt_states[:5]}...") + + # Evaluate warping accuracy for each model + model_performance = [] + + for name, match_positions in 
alignment_results.items(): + if match_positions is not None and not match_positions.empty: + total_correct = 0 + total_predictions = 0 + true_positives = 0 + false_positives = 0 + false_negatives = 0 + + # Analyze top alignments for this model + alignment_details = [] + for i, row in match_positions.head(10).iterrows(): + fov_name = row["fov_name"] + track_ids = row[ + "track_ids" + ] # This is already a list of track IDs for the lineage + warp_path = ( + ast.literal_eval(row["warp_path"]) + if isinstance(row["warp_path"], str) + else row["warp_path"] + ) + start_time = int(row["start_timepoint"]) + + # Get annotations for all tracks in this lineage + track_infection_states = {} + for track_id in track_ids: + track_annotations = infection_annotations_df[ + (infection_annotations_df["fov_name"] == fov_name) + & (infection_annotations_df["track_id"] == track_id) + ] + + # Add annotations for this track to the combined dictionary + for _, annotation_row in track_annotations.iterrows(): + # Use t + track-specific offset if needed to handle timepoint overlaps between tracks + track_infection_states[annotation_row["t"]] = annotation_row[ + "infection_status" + ] + + # Only proceed if we found annotations for at least one track + if track_infection_states: + # For each reference timepoint, check if the warped timepoint maintains the infection state + track_correct = 0 + track_predictions = 0 + track_tp = 0 + track_fp = 0 + track_fn = 0 + + for ref_idx, query_idx in warp_path: + # Map to actual timepoints + ref_t = reference_timepoints[0] + ref_idx + query_t = start_time + query_idx + + # Get ground truth infection states + ref_state = reference_infection_states.get(ref_t, "unknown") + query_state = track_infection_states.get(query_t, "unknown") + + # Skip unknown states + if ref_state != "unknown" and query_state != "unknown": + track_predictions += 1 + + # Count correct alignments + if ref_state == query_state: + track_correct += 1 + + # Calculate F1 score components for "infected" state + if ( + ref_state == infection_target_value + and query_state == infection_target_value + ): + track_tp += 1 + elif ( + ref_state != infection_target_value + and query_state == infection_target_value + ): + track_fp += 1 + elif ( + ref_state == infection_target_value + and query_state != infection_target_value + ): + track_fn += 1 + + # Calculate track-specific metrics + if track_predictions > 0: + track_accuracy = track_correct / track_predictions + track_precision = ( + track_tp / (track_tp + track_fp) + if (track_tp + track_fp) > 0 + else 0 + ) + track_recall = ( + track_tp / (track_tp + track_fn) + if (track_tp + track_fn) > 0 + else 0 + ) + track_f1 = ( + 2 + * (track_precision * track_recall) + / (track_precision + track_recall) + if (track_precision + track_recall) > 0 + else 0 + ) + + alignment_details.append( + { + "fov_name": fov_name, + "track_ids": track_ids, + "accuracy": track_accuracy, + "precision": track_precision, + "recall": track_recall, + "f1_score": track_f1, + "correct": track_correct, + "total": track_predictions, + } + ) + + # Add to model totals + total_correct += track_correct + total_predictions += track_predictions + true_positives += track_tp + false_positives += track_fp + false_negatives += track_fn + + # Calculate metrics + accuracy = total_correct / total_predictions if total_predictions > 0 else 0 + precision = ( + true_positives / (true_positives + false_positives) + if (true_positives + false_positives) > 0 + else 0 + ) + recall = ( + true_positives / (true_positives + 
false_negatives) + if (true_positives + false_negatives) > 0 + else 0 + ) + f1 = ( + 2 * (precision * recall) / (precision + recall) + if (precision + recall) > 0 + else 0 + ) + + # Store alignment details for this model + if alignment_details: + alignment_details_df = pd.DataFrame(alignment_details) + print(f"\nDetailed alignment results for {name}:") + print(alignment_details_df) + alignment_details_df.to_csv( + output_root / f"{name}_alignment_details.csv", index=False + ) + + model_performance.append( + { + "model": name, + "accuracy": accuracy, + "precision": precision, + "recall": recall, + "f1_score": f1, + "total_predictions": total_predictions, + } + ) + + # Create performance DataFrame and visualize + performance_df = pd.DataFrame(model_performance) + print(performance_df) + + # Plot performance metrics + plt.figure(figsize=(12, 8)) + + # Accuracy plot + plt.subplot(2, 2, 1) + sns.barplot(x="model", y="accuracy", data=performance_df) + plt.title("Infection State Warping Accuracy") + plt.ylabel("Accuracy") + + # Precision plot + plt.subplot(2, 2, 2) + sns.barplot(x="model", y="precision", data=performance_df) + plt.title("Precision for Infected State") + plt.ylabel("Precision") + + # Recall plot + plt.subplot(2, 2, 3) + sns.barplot(x="model", y="recall", data=performance_df) + plt.title("Recall for Infected State") + plt.ylabel("Recall") + + # F1 score plot + plt.subplot(2, 2, 4) + sns.barplot(x="model", y="f1_score", data=performance_df) + plt.title("F1 Score for Infected State") + plt.ylabel("F1 Score") + + plt.tight_layout() + # plt.savefig(output_root / "infection_state_warping_performance.png", dpi=300) + # plt.close() +else: + logger.warning("Reference track annotations not found in infection_annotations_df") + +# %% diff --git a/applications/pseudotime_analysis/evaluation/dtw_compare_openphenom.py b/applications/pseudotime_analysis/evaluation/dtw_compare_openphenom.py new file mode 100644 index 000000000..6e8fb6042 --- /dev/null +++ b/applications/pseudotime_analysis/evaluation/dtw_compare_openphenom.py @@ -0,0 +1,145 @@ +# %% +from pathlib import Path + +import pandas as pd +import torch +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +from tqdm import tqdm + +# Load model directly +from transformers import AutoModel + +from viscy.data.triplet import TripletDataModule +from viscy.representation.evaluation.dimensionality_reduction import compute_phate +from viscy.transforms import ScaleIntensityRangePercentilesd + + +# %% OpenPhenom Wrapper +class OpenPhenomWrapper: + def __init__(self): + try: + self.model = AutoModel.from_pretrained( + "recursionpharma/OpenPhenom", trust_remote_code=True + ) + self.model.eval() + self.model.to("cuda") + except ImportError: + raise ImportError( + "Please install the OpenPhenom dependencies: " + "pip install git+https://github.com/recursionpharma/maes_microscopy.git" + ) + + def extract_features(self, x): + """Extract features from the input images. 
+ + Args: + x: Input tensor of shape [B, C, D, H, W] or [B, C, H, W] + + Returns: + Features of shape [B, 384] + """ + # OpenPhenom expects [B, C, H, W] but our data might be [B, C, D, H, W] + # If 5D input, take middle slice or average across D + if x.dim() == 5: + # Take middle slice or average across D dimension + d = x.shape[2] + x = x[:, :, d // 2, :, :] + + # Convert to uint8 as OpenPhenom expects uint8 inputs + if x.dtype != torch.uint8: + # Normalize to 0-1 range if not already + x = (x - x.min()) / (x.max() - x.min() + 1e-10) + x = (x * 255).clamp(0, 255).to(torch.uint8) + + # Get embeddings + self.model.return_channelwise_embeddings = False + with torch.no_grad(): + embeddings = self.model.predict(x) + + return embeddings + + +# %% Initialize OpenPhenom model +print("Loading OpenPhenom model...") +openphenom = OpenPhenomWrapper() +# For infection dataset with phase and RFP +print("Setting up data module...") +dm = TripletDataModule( + data_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr", + tracks_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr", + source_channel=["Phase3D", "raw GFP EX488 EM525-45"], + batch_size=32, # Lower batch size for OpenPhenom which is larger + num_workers=10, + z_range=(25, 40), + initial_yx_patch_size=(192, 192), + final_yx_patch_size=(192, 192), + normalizations=[ + ScaleIntensityRangePercentilesd( + keys=["raw GFP EX488 EM525-45"], lower=50, upper=99, b_min=0.0, b_max=1.0 + ), + ScaleIntensityRangePercentilesd( + keys=["Phase3D"], lower=50, upper=99, b_min=0.0, b_max=1.0 + ), + ], +) +dm.prepare_data() +dm.setup("predict") +# %% +print("Extracting features...") +features = [] +indices = [] + +with torch.inference_mode(): + for batch in tqdm(dm.predict_dataloader()): + # Get both channels and handle dimensions properly + phase = batch["anchor"][:, 0] # Phase channel + rfp = batch["anchor"][:, 1] # RFP channel + rfp = torch.max(rfp, dim=1).values + Z = phase.shape[-3] + phase = phase[:, Z // 2] + img = torch.stack([phase, rfp], dim=1).to("cuda") + + # Extract features using OpenPhenom + batch_features = openphenom.extract_features(img) + features.append(batch_features.cpu()) + indices.append(batch["index"]) + +# %% +print("Processing features...") +pooled = torch.cat(features).numpy() +tracks = pd.concat([pd.DataFrame(idx) for idx in indices]) +print("Computing PCA and PHATE...") +scaled_features = StandardScaler().fit_transform(pooled) +pca = PCA(n_components=8) +pca_features = pca.fit_transform(scaled_features) + +phate_embedding = compute_phate( + embeddings=pooled, + n_components=2, + knn=5, + decay=40, + n_jobs=15, +) +# %% Add features to dataframe +for i, feature in enumerate(pooled.T): + tracks[f"feature_{i}"] = feature +# Add PCA features to dataframe +for i, feature in enumerate(pca_features.T): + tracks[f"pca_{i}"] = feature +for i, feature in enumerate(phate_embedding.T): + tracks[f"phate_{i}"] = feature + +# %% Save the extracted features + +output_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/figure/SEC61B/openphenom_pretrained_analysis" +) +output_path.mkdir(parents=True, exist_ok=True) +output_embeddings_file = ( + output_path / "openphenom_pretrained_features_SEC61B_n_Phase.csv" +) +print(f"Saving features to {output_embeddings_file}") 
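+# The saved CSV is expected to hold one row per (fov_name, track_id, t) index
+# entry, with the feature_*, pca_*, and phate_* columns appended above; a
+# downstream script could reload it with, e.g. (hypothetical usage):
+#   df = pd.read_csv(output_embeddings_file)
+#   feats = df.filter(like="feature_").to_numpy()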
+tracks.to_csv(output_embeddings_file, index=False) + +# %% diff --git a/applications/pseudotime_analysis/plotting_utils.py b/applications/pseudotime_analysis/plotting_utils.py new file mode 100644 index 000000000..c5e0ba8c0 --- /dev/null +++ b/applications/pseudotime_analysis/plotting_utils.py @@ -0,0 +1,1225 @@ +# Plotting utils +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import xarray as xr + +from viscy.data.triplet import TripletDataModule + +logger = logging.getLogger(__name__) + + +def plot_reference_aligned_average( + reference_pattern: np.ndarray, + top_aligned_cells: pd.DataFrame, + embeddings_dataset: xr.Dataset, + save_path: str | None = None, +) -> np.ndarray: + """ + Plot the reference embedding, aligned embeddings, and average aligned embedding. + + Args: + reference_pattern: The reference pattern embeddings + top_aligned_cells: DataFrame with alignment information + embeddings_dataset: Dataset containing embeddings + save_path: Path to save the figure (optional) + """ + plt.figure(figsize=(15, 10)) + + # Get the reference pattern embeddings + reference_embeddings = reference_pattern + + # Calculate average aligned embeddings + all_aligned_embeddings = [] + for idx, row in top_aligned_cells.iterrows(): + fov_name = row["fov_name"] + track_ids = row["track_ids"] + warp_path = row["warp_path"] + start_time = int(row["start_timepoint"]) + + # Reconstruct the concatenated lineage + lineages = [] + track_offsets = ( + {} + ) # To keep track of where each track starts in the concatenated array + current_offset = 0 + + for track_id in track_ids: + track_embeddings = embeddings_dataset.sel( + sample=(fov_name, track_id) + ).features.values + track_offsets[track_id] = current_offset + current_offset += len(track_embeddings) + lineages.append(track_embeddings) + + lineage_embeddings = np.concatenate(lineages, axis=0) + + # Create aligned embeddings using the warping path + aligned_embeddings = np.zeros( + (len(reference_pattern), lineage_embeddings.shape[1]), + dtype=lineage_embeddings.dtype, + ) + + # Create mapping from reference to lineage + ref_to_lineage = {} + for ref_idx, query_idx in warp_path: + lineage_idx = int(start_time + query_idx) + if 0 <= lineage_idx < len(lineage_embeddings): + ref_to_lineage[ref_idx] = lineage_idx + + # Fill aligned embeddings + for ref_idx in range(len(reference_pattern)): + if ref_idx in ref_to_lineage: + aligned_embeddings[ref_idx] = lineage_embeddings[ + ref_to_lineage[ref_idx] + ] + elif ref_to_lineage: + closest_ref_idx = min( + ref_to_lineage.keys(), key=lambda x: abs(x - ref_idx) + ) + aligned_embeddings[ref_idx] = lineage_embeddings[ + ref_to_lineage[closest_ref_idx] + ] + + all_aligned_embeddings.append(aligned_embeddings) + + # Calculate average aligned embeddings + average_aligned_embeddings = np.mean(all_aligned_embeddings, axis=0) + + # Plot dimension 0 + plt.subplot(2, 1, 1) + # Plot reference pattern + plt.plot( + range(len(reference_embeddings)), + reference_embeddings[:, 0], + label="Reference", + color="black", + linewidth=3, + ) + + # Plot each aligned embedding + for i, aligned_embeddings in enumerate(all_aligned_embeddings): + plt.plot( + range(len(aligned_embeddings)), + aligned_embeddings[:, 0], + label=f"Aligned {i}", + alpha=0.4, + linestyle="--", + ) + + # Plot average aligned embedding + plt.plot( + range(len(average_aligned_embeddings)), + average_aligned_embeddings[:, 0], + label="Average Aligned", + color="red", + linewidth=2, + ) + + 
plt.title("Dimension 0: Reference, Aligned, and Average Embeddings") + plt.xlabel("Reference Time Index") + plt.ylabel("Embedding Value") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot dimension 1 + plt.subplot(2, 1, 2) + # Plot reference pattern + plt.plot( + range(len(reference_embeddings)), + reference_embeddings[:, 1], + label="Reference", + color="black", + linewidth=3, + ) + + # Plot each aligned embedding + for i, aligned_embeddings in enumerate(all_aligned_embeddings): + plt.plot( + range(len(aligned_embeddings)), + aligned_embeddings[:, 1], + label=f"Aligned {i}", + alpha=0.4, + linestyle="--", + ) + + # Plot average aligned embedding + plt.plot( + range(len(average_aligned_embeddings)), + average_aligned_embeddings[:, 1], + label="Average Aligned", + color="red", + linewidth=2, + ) + + plt.title("Dimension 1: Reference, Aligned, and Average Embeddings") + plt.xlabel("Reference Time Index") + plt.ylabel("Embedding Value") + plt.legend() + plt.grid(True, alpha=0.3) + + plt.tight_layout() + if save_path: + plt.savefig(save_path, dpi=300) + plt.show() + + return average_aligned_embeddings + + +def plot_reference_vs_full_lineages( + reference_pattern: np.ndarray, + top_aligned_cells: pd.DataFrame, + embeddings_dataset: xr.Dataset, + save_path: str | None = None, +) -> np.ndarray: + """ + Visualize where the reference pattern matches in each full lineage. + + Args: + reference_pattern: The reference pattern embeddings + top_aligned_cells: DataFrame with alignment information + embeddings_dataset: Dataset containing embeddings + save_path: Path to save the figure (optional) + """ + plt.figure(figsize=(15, 15)) + + # First, plot the reference pattern for comparison + plt.subplot(len(top_aligned_cells) + 1, 2, 1) + plt.plot( + range(len(reference_pattern)), + reference_pattern[:, 0], + label="Reference Dim 0", + color="black", + linewidth=2, + ) + plt.title("Reference Pattern - Dimension 0") + plt.xlabel("Time Index") + plt.ylabel("Embedding Value") + plt.grid(True, alpha=0.3) + plt.legend() + + plt.subplot(len(top_aligned_cells) + 1, 2, 2) + plt.plot( + range(len(reference_pattern)), + reference_pattern[:, 1], + label="Reference Dim 1", + color="black", + linewidth=2, + ) + plt.title("Reference Pattern - Dimension 1") + plt.xlabel("Time Index") + plt.ylabel("Embedding Value") + plt.grid(True, alpha=0.3) + plt.legend() + + # Then plot each lineage with the matched section highlighted + for i, (_, row) in enumerate(top_aligned_cells.iterrows()): + fov_name = row["fov_name"] + track_ids = row["track_ids"] + warp_path = row["warp_path"] + start_time = row["start_timepoint"] + distance = row["distance"] + + # Get the full lineage embeddings + lineage_embeddings = embeddings_dataset.sel( + sample=(fov_name, track_ids) + ).features.values + + # Create a subplot for dimension 0 + plt.subplot(len(top_aligned_cells) + 1, 2, 2 * i + 3) + + # Plot the full lineage + plt.plot( + range(len(lineage_embeddings)), + lineage_embeddings[:, 0], + label="Full Lineage", + color="blue", + alpha=0.7, + ) + + # Highlight the matched section + matched_indices = set() + for _, query_idx in warp_path: + lineage_idx = int(start_time + query_idx) + if 0 <= lineage_idx < len(lineage_embeddings): + matched_indices.add(lineage_idx) + + matched_indices = sorted(list(matched_indices)) + if matched_indices: + plt.plot( + matched_indices, + [lineage_embeddings[idx, 0] for idx in matched_indices], + "ro-", + label=f"Matched Section (DTW dist={distance:.2f})", + linewidth=2, + ) + + # Add vertical lines to mark the 
start and end of the matched section + plt.axvline(x=min(matched_indices), color="red", linestyle="--", alpha=0.5) + plt.axvline(x=max(matched_indices), color="red", linestyle="--", alpha=0.5) + + # Add text labels + plt.text( + min(matched_indices), + min(lineage_embeddings[:, 0]), + f"Start: {min(matched_indices)}", + color="red", + fontsize=10, + ) + plt.text( + max(matched_indices), + min(lineage_embeddings[:, 0]), + f"End: {max(matched_indices)}", + color="red", + fontsize=10, + ) + + plt.title(f"Lineage {i} ({fov_name}) Track {track_ids[0]} - Dimension 0") + plt.xlabel("Lineage Time") + plt.ylabel("Embedding Value") + plt.legend() + plt.grid(True, alpha=0.3) + + # Create a subplot for dimension 1 + plt.subplot(len(top_aligned_cells) + 1, 2, 2 * i + 4) + + # Plot the full lineage + plt.plot( + range(len(lineage_embeddings)), + lineage_embeddings[:, 1], + label="Full Lineage", + color="green", + alpha=0.7, + ) + + # Highlight the matched section + if matched_indices: + plt.plot( + matched_indices, + [lineage_embeddings[idx, 1] for idx in matched_indices], + "ro-", + label=f"Matched Section (DTW dist={distance:.2f})", + linewidth=2, + ) + + # Add vertical lines to mark the start and end of the matched section + plt.axvline(x=min(matched_indices), color="red", linestyle="--", alpha=0.5) + plt.axvline(x=max(matched_indices), color="red", linestyle="--", alpha=0.5) + + # Add text labels + plt.text( + min(matched_indices), + min(lineage_embeddings[:, 1]), + f"Start: {min(matched_indices)}", + color="red", + fontsize=10, + ) + plt.text( + max(matched_indices), + min(lineage_embeddings[:, 1]), + f"End: {max(matched_indices)}", + color="red", + fontsize=10, + ) + + plt.title(f"Lineage {i} ({fov_name}) - Dimension 1") + plt.xlabel("Lineage Time") + plt.ylabel("Embedding Value") + plt.legend() + plt.grid(True, alpha=0.3) + + plt.tight_layout() + if save_path: + plt.savefig(save_path, dpi=300) + plt.show() + + +def align_and_average_embeddings( + reference_pattern: np.ndarray, + top_aligned_cells: pd.DataFrame, + embeddings_dataset: xr.Dataset, + use_median: bool = False, +) -> np.ndarray: + """ + Align embeddings from multiple lineages to a reference pattern and compute their average. 
+ + Args: + reference_pattern: The reference pattern embeddings + top_aligned_cells: DataFrame with alignment information + embeddings_dataset: Dataset containing embeddings + use_median: If True, use median instead of mean for averaging + + Returns: + The average (or median) aligned embeddings + """ + all_aligned_embeddings = [] + + for idx, row in top_aligned_cells.iterrows(): + fov_name = row["fov_name"] + track_ids = row["track_ids"] + warp_path = row["warp_path"] + start_time = int(row["start_timepoint"]) + + # Reconstruct the concatenated lineage + lineages = [] + track_offsets = ( + {} + ) # To keep track of where each track starts in the concatenated array + current_offset = 0 + + for track_id in track_ids: + track_embeddings = embeddings_dataset.sel( + sample=(fov_name, track_id) + ).features.values + track_offsets[track_id] = current_offset + current_offset += len(track_embeddings) + lineages.append(track_embeddings) + + lineage_embeddings = np.concatenate(lineages, axis=0) + + # Create aligned embeddings using the warping path + aligned_segment = np.zeros_like(reference_pattern) + + # Map each reference timepoint to the corresponding lineage timepoint + ref_to_lineage = {} + for ref_idx, query_idx in warp_path: + lineage_idx = int(start_time + query_idx) + if 0 <= lineage_idx < len(lineage_embeddings): + ref_to_lineage[ref_idx] = lineage_idx + aligned_segment[ref_idx] = lineage_embeddings[lineage_idx] + + # Fill in missing values by using the closest available reference index + for ref_idx in range(len(reference_pattern)): + if ref_idx not in ref_to_lineage and ref_to_lineage: + closest_ref_idx = min( + ref_to_lineage.keys(), key=lambda x: abs(x - ref_idx) + ) + aligned_segment[ref_idx] = lineage_embeddings[ + ref_to_lineage[closest_ref_idx] + ] + + all_aligned_embeddings.append(aligned_segment) + + all_aligned_embeddings = np.array(all_aligned_embeddings) + + # Compute average or median + if use_median: + return np.median(all_aligned_embeddings, axis=0) + else: + return np.mean(all_aligned_embeddings, axis=0) + + +def align_image_stacks( + reference_pattern: np.ndarray, + top_aligned_cells: pd.DataFrame, + input_data_path: Path, + tracks_path: Path, + source_channels: list[str], + yx_patch_size: tuple[int, int] = (192, 192), + z_range: tuple[int, int] = (0, 1), + view_ref_sector_only: bool = True, + napari_viewer=None, +) -> tuple[list, list]: + """ + Align image stacks from multiple lineages to a reference pattern. 
+ + Args: + reference_pattern: The reference pattern embeddings + top_aligned_cells: DataFrame with alignment information + input_data_path: Path to the input data + tracks_path: Path to the tracks data + source_channels: List of channels to include + yx_patch_size: Patch size for images + z_range: Z-range to include + view_ref_sector_only: If True, only show the section that matches the reference pattern + napari_viewer: Optional napari viewer for visualization + + Returns: + Tuple of (all_lineage_images, all_aligned_stacks) + """ + from tqdm import tqdm + + all_lineage_images = [] + all_aligned_stacks = [] + + for idx, row in tqdm( + top_aligned_cells.iterrows(), + total=len(top_aligned_cells), + desc="Aligning images", + ): + fov_name = row["fov_name"] + track_ids = row["track_ids"] + warp_path = row["warp_path"] + start_time = int(row["start_timepoint"]) + + print(f"Aligning images for {fov_name} with track ids: {track_ids}") + data_module = TripletDataModule( + data_path=input_data_path, + tracks_path=tracks_path, + source_channel=source_channels, + z_range=z_range, + initial_yx_patch_size=yx_patch_size, + final_yx_patch_size=yx_patch_size, + batch_size=1, + num_workers=12, + predict_cells=True, + include_fov_names=[fov_name] * len(track_ids), + include_track_ids=track_ids, + ) + data_module.setup("predict") + + # Get the images for the lineage + lineage_images = [] + for batch in data_module.predict_dataloader(): + image = batch["anchor"].numpy()[0] + lineage_images.append(image) + + lineage_images = np.array(lineage_images) + all_lineage_images.append(lineage_images) + print(f"Lineage images shape: {np.array(lineage_images).shape}") + + # Create an aligned stack based on the warping path + if view_ref_sector_only: + aligned_stack = np.zeros( + (len(reference_pattern),) + lineage_images.shape[-4:], + dtype=lineage_images.dtype, + ) + + # Map each reference timepoint to the corresponding lineage timepoint + for ref_idx in range(len(reference_pattern)): + # Find matches in warping path for this reference index + matches = [(i, q) for i, q in warp_path if i == ref_idx] + + if matches: + # Get the corresponding lineage timepoint (first match if multiple) + print(f"Found match for ref idx: {ref_idx}") + match = matches[0] + query_idx = match[1] + lineage_idx = int(start_time + query_idx) + print( + f"Lineage index: {lineage_idx}, start time: {start_time}, query idx: {query_idx}, ref idx: {ref_idx}" + ) + # Copy the image if it's within bounds + if 0 <= lineage_idx < len(lineage_images): + aligned_stack[ref_idx] = lineage_images[lineage_idx] + else: + # Find nearest valid timepoint if out of bounds + nearest_idx = min(max(0, lineage_idx), len(lineage_images) - 1) + aligned_stack[ref_idx] = lineage_images[nearest_idx] + else: + # If no direct match, find closest reference timepoint in warping path + print(f"No match found for ref idx: {ref_idx}") + all_ref_indices = [i for i, _ in warp_path] + if all_ref_indices: + closest_ref_idx = min( + all_ref_indices, key=lambda x: abs(x - ref_idx) + ) + closest_matches = [ + (i, q) for i, q in warp_path if i == closest_ref_idx + ] + + if closest_matches: + closest_query_idx = closest_matches[0][1] + lineage_idx = int(start_time + closest_query_idx) + + if 0 <= lineage_idx < len(lineage_images): + aligned_stack[ref_idx] = lineage_images[lineage_idx] + else: + # Bound to valid range + nearest_idx = min( + max(0, lineage_idx), len(lineage_images) - 1 + ) + aligned_stack[ref_idx] = lineage_images[nearest_idx] + + all_aligned_stacks.append(aligned_stack) + 
if napari_viewer: + napari_viewer.add_image( + aligned_stack, + name=f"Aligned_{fov_name}_track_{track_ids[0]}", + channel_axis=1, + ) + else: + # View the whole lineage shifted by the start time + start_idx = int(start_time) + aligned_stack = lineage_images[start_idx:] + all_aligned_stacks.append(aligned_stack) + if napari_viewer: + napari_viewer.add_image( + aligned_stack, + name=f"Aligned_{fov_name}_track_{track_ids[0]}", + channel_axis=1, + ) + + return all_lineage_images, all_aligned_stacks + + +def find_pattern_matches( + reference_pattern: np.ndarray, + filtered_lineages: list[tuple[str, list[int]]], + embeddings_dataset: xr.Dataset, + window_step_fraction: float = 0.25, + num_candidates: int = 3, + max_distance: float = float("inf"), + max_skew: float = 0.8, # Add skewness parameter + save_path: str | None = None, + method: str = "bernd_clifford", + normalize: bool = True, + metric: str = "euclidean", +) -> pd.DataFrame: + """ + Find the best matches of a reference pattern in multiple lineages using DTW. + + Args: + reference_pattern: The reference pattern embeddings + filtered_lineages: List of lineages to search in (fov_name, track_ids) + embeddings_dataset: Dataset containing embeddings + window_step_fraction: Fraction of reference pattern length to use as window step + num_candidates: Number of best candidates to consider per lineage + max_distance: Maximum distance threshold to consider a match + max_skew: Maximum allowed path skewness (0-1, where 0=perfect diagonal) + save_path: Optional path to save the results CSV + method: DTW method to use - 'bernd_clifford' (from utils.py) or 'dtai' (dtaidistance library) + + Returns: + DataFrame with match positions and distances + """ + from scipy.spatial.distance import cdist + from tqdm import tqdm + + # Calculate window step based on reference pattern length + window_step = max(1, int(len(reference_pattern) * window_step_fraction)) + + all_match_positions = { + "fov_name": [], + "track_ids": [], + "distance": [], + "skewness": [], # Add skewness to results + "warp_path": [], + "start_timepoint": [], + "end_timepoint": [], + } + + for i, (fov_name, track_ids) in tqdm( + enumerate(filtered_lineages), + total=len(filtered_lineages), + desc="Finding pattern matches", + ): + print(f"Finding pattern matches for {fov_name} with track ids: {track_ids}") + # Reconstruct the concatenated lineage + lineages = [] + track_offsets = ( + {} + ) # To keep track of where each track starts in the concatenated array + current_offset = 0 + + for track_id in track_ids: + track_embeddings = embeddings_dataset.sel( + sample=(fov_name, track_id) + ).features.values + track_offsets[track_id] = current_offset + current_offset += len(track_embeddings) + lineages.append(track_embeddings) + + lineage_embeddings = np.concatenate(lineages, axis=0) + + # Find best matches using the selected DTW method + if method == "bernd_clifford": + matches_df = find_best_match_dtw_bernd_clifford( + lineage_embeddings, + reference_pattern=reference_pattern, + num_candidates=num_candidates, + window_step=window_step, + max_distance=max_distance, + max_skew=max_skew, + normalize=normalize, + metric=metric, + ) + else: + matches_df = find_best_match_dtw( + lineage_embeddings, + reference_pattern=reference_pattern, + num_candidates=num_candidates, + window_step=window_step, + max_distance=max_distance, + max_skew=max_skew, + normalize=normalize, + ) + + if not matches_df.empty: + # Get the best match (first row of the sorted DataFrame) + best_match = matches_df.iloc[0] + best_pos = 
best_match["position"] + best_path = best_match["path"] + best_dist = best_match["distance"] + best_skew = best_match["skewness"] + + all_match_positions["fov_name"].append(fov_name) + all_match_positions["track_ids"].append(track_ids) + all_match_positions["distance"].append(best_dist) + all_match_positions["skewness"].append(best_skew) + all_match_positions["warp_path"].append(best_path) + all_match_positions["start_timepoint"].append(best_pos) + all_match_positions["end_timepoint"].append( + best_pos + len(reference_pattern) + ) + else: + all_match_positions["fov_name"].append(fov_name) + all_match_positions["track_ids"].append(track_ids) + all_match_positions["distance"].append(None) + all_match_positions["skewness"].append(None) + all_match_positions["warp_path"].append(None) + all_match_positions["start_timepoint"].append(None) + all_match_positions["end_timepoint"].append(None) + + # Convert to DataFrame and drop rows with no matches + all_match_positions = pd.DataFrame(all_match_positions) + all_match_positions = all_match_positions.dropna() + + # Sort by distance (primary) and skewness (secondary) + all_match_positions = all_match_positions.sort_values( + by=["distance", "skewness"], ascending=[True, True] + ) + + # Save to CSV if path is provided + if save_path: + all_match_positions.to_csv(save_path, index=False) + + return all_match_positions + + +def find_best_match_dtw( + lineage: np.ndarray, + reference_pattern: np.ndarray, + num_candidates: int = 5, + window_step: int = 5, + max_distance: float = float("inf"), + max_skew: float = 0.8, + normalize: bool = True, +) -> pd.DataFrame: + """ + Find the best matches in a lineage using DTW with dtaidistance. + + Args: + lineage: The lineage to search (t,embeddings). + reference_pattern: The pattern to search for (t,embeddings). + num_candidates: The number of candidates to return. + window_step: The step size for the window. + max_distance: Maximum distance threshold to consider a match. + max_skew: Maximum allowed path skewness (0-1). 
+ + Returns: + DataFrame with position, warping_path, distance, and skewness for the best matches + """ + from dtaidistance.dtw_ndim import warping_path + + from utils import path_skew + + dtw_results = [] + n_windows = len(lineage) - len(reference_pattern) + 1 + + if n_windows <= 0: + return pd.DataFrame(columns=["position", "path", "distance", "skewness"]) + + for i in range(0, n_windows, window_step): + window = lineage[i : i + len(reference_pattern)] + path, dist = warping_path( + reference_pattern, + window, + include_distance=True, + ) + if normalize: + # Normalize by path length to match bernd_clifford method + dist = dist / len(path) + # Calculate skewness using the utils function + skewness = path_skew(path, len(reference_pattern), len(window)) + + if dist <= max_distance and skewness <= max_skew: + dtw_results.append( + {"position": i, "path": path, "distance": dist, "skewness": skewness} + ) + + # Convert to DataFrame + results_df = pd.DataFrame(dtw_results) + + # Sort by distance first (primary) and then by skewness (secondary) + if not results_df.empty: + results_df = results_df.sort_values(by=["distance", "skewness"]).head( + num_candidates + ) + + return results_df + + +def find_best_match_dtw_bernd_clifford( + lineage: np.ndarray, + reference_pattern: np.ndarray, + num_candidates: int = 5, + window_step: int = 5, + normalize: bool = True, + max_distance: float = float("inf"), + max_skew: float = 0.8, + metric: str = "euclidean", +) -> pd.DataFrame: + """ + Find the best matches in a lineage using DTW with the utils.py implementation. + + Args: + lineage: The lineage to search (t,embeddings). + reference_pattern: The pattern to search for (t,embeddings). + num_candidates: The number of candidates to return. + window_step: The step size for the window. + max_distance: Maximum distance threshold to consider a match. + max_skew: Maximum allowed path skewness (0-1). 
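+        normalize: Whether to normalize the DTW distance by the warping path length.
+        metric: Distance metric passed to scipy.spatial.distance.cdist (default: "euclidean").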
+ + Returns: + DataFrame with position, warping_path, distance, and skewness for the best matches + """ + from scipy.spatial.distance import cdist + + from utils import dtw_with_matrix, path_skew + + dtw_results = [] + n_windows = len(lineage) - len(reference_pattern) + 1 + + if n_windows <= 0: + return pd.DataFrame(columns=["position", "path", "distance", "skewness"]) + + for i in range(0, n_windows, window_step): + window = lineage[i : i + len(reference_pattern)] + + # Create distance matrix + distance_matrix = cdist(reference_pattern, window, metric=metric) + + # Apply DTW using utils.py implementation + distance, _, path = dtw_with_matrix(distance_matrix, normalize=normalize) + + # Calculate skewness + skewness = path_skew(path, len(reference_pattern), len(window)) + + # Only add if both distance and skewness pass thresholds + if distance <= max_distance and skewness <= max_skew: + logger.debug( + f"Found match at {i} with distance {distance} and skewness {skewness}" + ) + dtw_results.append( + { + "position": i, + "path": path, + "distance": distance, + "skewness": skewness, + } + ) + + # Convert to DataFrame + results_df = pd.DataFrame(dtw_results) + + # Sort by distance first (primary) and then by skewness (secondary) + if not results_df.empty: + results_df = results_df.sort_values(by=["distance", "skewness"]).head( + num_candidates + ) + + return results_df + + +def create_consensus_embedding( + reference_pattern: np.ndarray, + top_aligned_cells: pd.DataFrame, + embeddings_dataset: xr.Dataset, +) -> np.ndarray: + """ + Create a consensus embedding from multiple aligned embeddings using + a weighted approach based on DTW distances. + + Args: + reference_pattern: The reference pattern embeddings + top_aligned_cells: DataFrame with alignment information + embeddings_dataset: Dataset containing embeddings + + Returns: + The consensus embedding + """ + all_aligned_embeddings = [] + distances = [] + + for idx, row in top_aligned_cells.iterrows(): + fov_name = row["fov_name"] + track_ids = row["track_ids"] + warp_path = row["warp_path"] + start_time = int(row["start_timepoint"]) + distance = row["distance"] + + # Get lineage embeddings (similar to align_and_average_embeddings) + lineages = [] + for track_id in track_ids: + track_embeddings = embeddings_dataset.sel( + sample=(fov_name, track_id) + ).features.values + lineages.append(track_embeddings) + + lineage_embeddings = np.concatenate(lineages, axis=0) + + # Create aligned embeddings using the warping path + aligned_segment = np.zeros_like(reference_pattern) + + # Map each reference timepoint to the corresponding lineage timepoint + ref_to_lineage = {} + for ref_idx, query_idx in warp_path: + lineage_idx = int(start_time + query_idx) + if 0 <= lineage_idx < len(lineage_embeddings): + ref_to_lineage[ref_idx] = lineage_idx + aligned_segment[ref_idx] = lineage_embeddings[lineage_idx] + + # Fill in missing values + for ref_idx in range(len(reference_pattern)): + if ref_idx not in ref_to_lineage and ref_to_lineage: + closest_ref_idx = min( + ref_to_lineage.keys(), key=lambda x: abs(x - ref_idx) + ) + aligned_segment[ref_idx] = lineage_embeddings[ + ref_to_lineage[closest_ref_idx] + ] + + all_aligned_embeddings.append(aligned_segment) + distances.append(distance) + + all_aligned_embeddings = np.array(all_aligned_embeddings) + + # Convert distances to weights (smaller distance = higher weight) + weights = 1.0 / ( + np.array(distances) + 1e-10 + ) # Add small epsilon to avoid division by zero + weights = weights / np.sum(weights) # Normalize 
weights
+
+    # Create weighted consensus
+    consensus_embedding = np.zeros_like(reference_pattern)
+    for i, aligned_embedding in enumerate(all_aligned_embeddings):
+        consensus_embedding += weights[i] * aligned_embedding
+
+    return consensus_embedding
+
+
+def identify_lineages(
+    tracking_df: pd.DataFrame, return_both_branches: bool = False
+) -> list[tuple[str, list[int]]]:
+    """
+    Identifies all distinct lineages in the cell tracking data. By default, only
+    the first branch after each division event is followed; set
+    return_both_branches=True to return every branch.
+
+    Args:
+        tracking_df: Tracking data with fov_name, track_id, and parent_track_id columns
+        return_both_branches: Whether to return all branches of each lineage tree
+
+    Returns:
+        A list of tuples, where each tuple contains (fov_id, [track_ids])
+        representing a single branch lineage within a single FOV
+    """
+    # Process each FOV separately to handle repeated track_ids
+    all_lineages = []
+
+    # Group by FOV
+    for fov_id, fov_df in tracking_df.groupby("fov_name"):
+        # Create a dictionary to map tracks to their parents within this FOV
+        child_to_parent = {}
+
+        # Group by track_id and get the first row for each track to find its parent
+        for track_id, track_group in fov_df.groupby("track_id"):
+            first_row = track_group.iloc[0]
+            parent_track_id = first_row["parent_track_id"]
+
+            if parent_track_id != -1:
+                child_to_parent[track_id] = parent_track_id
+
+        # Find root tracks: those with parent_track_id == -1, or whose parent
+        # does not appear in this FOV
+        all_tracks = set(fov_df["track_id"].unique())
+        root_tracks = set()
+        for track_id in all_tracks:
+            track_data = fov_df[fov_df["track_id"] == track_id]
+            if (
+                track_data.iloc[0]["parent_track_id"] == -1
+                or track_data.iloc[0]["parent_track_id"] not in all_tracks
+            ):
+                root_tracks.add(track_id)
+
+        # Build a parent-to-children mapping
+        parent_to_children = {}
+        for child, parent in child_to_parent.items():
+            if parent not in parent_to_children:
+                parent_to_children[parent] = []
+            parent_to_children[parent].append(child)
+
+        # Function to get all branches from each parent
+        def get_all_branches(track_id):
+            branches = []
+            current_branch = [track_id]
+
+            if track_id in parent_to_children:
+                # For each child, get all their branches
+                for child in parent_to_children[track_id]:
+                    child_branches = get_all_branches(child)
+                    # Add current track to start of each child branch
+                    for branch in child_branches:
+                        branches.append(current_branch + branch)
+            else:
+                # If no children, return just this track
+                branches.append(current_branch)
+
+            return branches
+
+        # Build lineages starting from root tracks within this FOV
+        for root_track in root_tracks:
+            # Get all branches from this root
+            lineage_tracks = get_all_branches(root_track)
+            if return_both_branches:
+                for branch in lineage_tracks:
+                    all_lineages.append((fov_id, branch))
+            else:
+                all_lineages.append((fov_id, lineage_tracks[0]))
+
+    return all_lineages
+
+
+def plot_pc_trajectories(
+    reference_lineage_fov: str,
+    reference_lineage_track_id: list[int],
+    reference_timepoints: list[int],
+    match_positions: pd.DataFrame,
+    embeddings_dataset: xr.Dataset,
+    filtered_lineages: list[tuple[str, list[int]]],
+    name: str,
+    save_path: Path,
+):
+    """
+    Visualize warping paths in PC space, comparing reference pattern with aligned lineages.
+ + Args: + reference_lineage_fov: FOV name for the reference lineage + reference_lineage_track_id: Track ID for the reference lineage + reference_timepoints: Time range [start, end] to use from reference + match_positions: DataFrame with alignment matches + embeddings_dataset: Dataset with embeddings + filtered_lineages: List of lineages to search in (fov_name, track_ids) + name: Name of the embedding model + save_path: Path to save the figure + """ + import ast + from sklearn.decomposition import PCA + from sklearn.preprocessing import StandardScaler + import matplotlib.pyplot as plt + import numpy as np + import pandas as pd + + # Get reference pattern + ref_pattern = None + for fov_id, track_ids in filtered_lineages: + if fov_id == reference_lineage_fov and all( + track_id in track_ids for track_id in reference_lineage_track_id + ): + ref_pattern = embeddings_dataset.sel( + sample=(fov_id, reference_lineage_track_id) + ).features.values + break + + if ref_pattern is None: + logger.info( + f"Reference pattern not found for {name}. Skipping PC trajectory plot." + ) + return + + ref_pattern = np.concatenate([ref_pattern]) + ref_pattern = ref_pattern[reference_timepoints[0] : reference_timepoints[1]] + + # Get top matches + top_n_aligned_cells = match_positions.head(5) + + # Compute PCA directly with sklearn + # Scale the input data + scaler = StandardScaler() + ref_pattern_scaled = scaler.fit_transform(ref_pattern) + + # Create and fit PCA model + pca_model = PCA(n_components=2, random_state=42) + pca_ref = pca_model.fit_transform(ref_pattern_scaled) + + # Create a figure to display the results + plt.figure(figsize=(15, 15)) + + # Plot the reference pattern PCs + plt.subplot(len(top_n_aligned_cells) + 1, 2, 1) + plt.plot( + range(len(pca_ref)), + pca_ref[:, 0], + label="Reference PC1", + color="black", + linewidth=2, + ) + plt.title(f"{name} - Reference Pattern - PC1") + plt.xlabel("Time Index") + plt.ylabel("PC1 Value") + plt.grid(True, alpha=0.3) + plt.legend() + + plt.subplot(len(top_n_aligned_cells) + 1, 2, 2) + plt.plot( + range(len(pca_ref)), + pca_ref[:, 1], + label="Reference PC2", + color="black", + linewidth=2, + ) + plt.title(f"{name} - Reference Pattern - PC2") + plt.xlabel("Time Index") + plt.ylabel("PC2 Value") + plt.grid(True, alpha=0.3) + plt.legend() + + # Then plot each lineage with the matched section highlighted + for i, (_, row) in enumerate(top_n_aligned_cells.iterrows()): + fov_name = row["fov_name"] + track_ids = row["track_ids"] + if isinstance(track_ids, str): + track_ids = ast.literal_eval(track_ids) + warp_path = row["warp_path"] + if isinstance(warp_path, str): + warp_path = ast.literal_eval(warp_path) + start_time = row["start_timepoint"] + distance = row["distance"] + + # Get the full lineage embeddings + lineage_embeddings = [] + for track_id in track_ids: + try: + track_emb = embeddings_dataset.sel( + sample=(fov_name, track_id) + ).features.values + lineage_embeddings.append(track_emb) + except KeyError: + pass + + if not lineage_embeddings: + continue + + lineage_embeddings = np.concatenate(lineage_embeddings, axis=0) + + # Transform lineage embeddings using the same PCA model + # Scale first using the same scaler + lineage_scaled = scaler.transform(lineage_embeddings) + pca_lineage = pca_model.transform(lineage_scaled) + + # Create a subplot for PC1 + plt.subplot(len(top_n_aligned_cells) + 1, 2, 2 * i + 3) + + # Plot the full lineage PC1 + plt.plot( + range(len(pca_lineage)), + pca_lineage[:, 0], + label="Full Lineage", + color="blue", + alpha=0.7, + ) 
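+
+        # The warp path pairs reference-pattern indices with window-relative query
+        # indices; adding start_time (where the matched window begins in the full
+        # lineage) converts them to absolute lineage timepoints for highlighting.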
+ + # Highlight the matched section + matched_indices = set() + for _, query_idx in warp_path: + lineage_idx = ( + int(start_time) + query_idx if not pd.isna(start_time) else query_idx + ) + if 0 <= lineage_idx < len(pca_lineage): + matched_indices.add(lineage_idx) + + matched_indices = sorted(list(matched_indices)) + if matched_indices: + plt.plot( + matched_indices, + [pca_lineage[idx, 0] for idx in matched_indices], + "ro-", + label=f"Matched Section (DTW dist={distance:.2f})", + linewidth=2, + ) + + # Add vertical lines to mark the start and end of the matched section + plt.axvline(x=min(matched_indices), color="red", linestyle="--", alpha=0.5) + plt.axvline(x=max(matched_indices), color="red", linestyle="--", alpha=0.5) + + # Add text labels + plt.text( + min(matched_indices), + min(pca_lineage[:, 0]), + f"Start: {min(matched_indices)}", + color="red", + fontsize=10, + ) + plt.text( + max(matched_indices), + min(pca_lineage[:, 0]), + f"End: {max(matched_indices)}", + color="red", + fontsize=10, + ) + + plt.title(f"Lineage {i} ({fov_name}) Track {track_ids[0]} - PC1") + plt.xlabel("Lineage Time") + plt.ylabel("PC1 Value") + plt.legend() + plt.grid(True, alpha=0.3) + + # Create a subplot for PC2 + plt.subplot(len(top_n_aligned_cells) + 1, 2, 2 * i + 4) + + # Plot the full lineage PC2 + plt.plot( + range(len(pca_lineage)), + pca_lineage[:, 1], + label="Full Lineage", + color="green", + alpha=0.7, + ) + + # Highlight the matched section + if matched_indices: + plt.plot( + matched_indices, + [pca_lineage[idx, 1] for idx in matched_indices], + "ro-", + label=f"Matched Section (DTW dist={distance:.2f})", + linewidth=2, + ) + + # Add vertical lines to mark the start and end of the matched section + plt.axvline(x=min(matched_indices), color="red", linestyle="--", alpha=0.5) + plt.axvline(x=max(matched_indices), color="red", linestyle="--", alpha=0.5) + + # Add text labels + plt.text( + min(matched_indices), + min(pca_lineage[:, 1]), + f"Start: {min(matched_indices)}", + color="red", + fontsize=10, + ) + plt.text( + max(matched_indices), + min(pca_lineage[:, 1]), + f"End: {max(matched_indices)}", + color="red", + fontsize=10, + ) + + plt.title(f"Lineage {i} ({fov_name}) - PC2") + plt.xlabel("Lineage Time") + plt.ylabel("PC2 Value") + plt.legend() + plt.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(save_path, dpi=300) + plt.close() diff --git a/applications/pseudotime_analysis/simulation/demo_dtw_cellfeatures.py b/applications/pseudotime_analysis/simulation/demo_dtw_cellfeatures.py new file mode 100644 index 000000000..084a87eaf --- /dev/null +++ b/applications/pseudotime_analysis/simulation/demo_dtw_cellfeatures.py @@ -0,0 +1,224 @@ +# %% +import matplotlib.pyplot as plt +import numpy as np +from dtaidistance import dtw +from scipy.cluster.hierarchy import dendrogram, linkage + +# testing if we can use DTW to align cell trajectories using a short reference pattern + +np.random.seed(42) +timepoints = 50 +cells = 8 + +# Create synthetic cell trajectories (e.g., PCA1 shape evolution) +cell_trajectories = [ + np.sin(np.linspace(0, 10, timepoints) + np.random.rand() * 2) for _ in range(cells) +] +# add extra transforms to signal +cell_trajectories[cells - 1] += np.sin(np.linspace(0, 5, timepoints)) * 2 +cell_trajectories[cells - 2] += np.sin(np.linspace(0, 5, timepoints)) * 3 + +# %% +# plot cell trajectories +plt.figure(figsize=(8, 5)) +for i in range(cells): + plt.plot(cell_trajectories[i], label=f"Cell {i+1}") +plt.legend() +plt.title("Original Cell Trajectories") +plt.show() +# %% 
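+# A quick sanity check of the dtaidistance API used below: dtw.distance returns
+# the accumulated cost of the optimal warping path between two 1-D sequences,
+# so it is ~0 when one series is just a time-stretched copy of the other.
+# The toy series here are illustrative and not part of the original demo.
+toy_a = np.array([0.0, 1.0, 2.0, 1.0])
+toy_b = np.array([0.0, 0.0, 1.0, 2.0, 2.0, 1.0])  # toy_a with repeated samples
+print(f"Toy DTW distance (expected ~0): {dtw.distance(toy_a, toy_b):.3f}")
+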
+# Set reference cell for all subsequent analysis +reference_cell = 0 # Use first cell as reference + +# Compute DTW distance matrix +dtw_matrix = np.zeros((cells, cells)) +for i in range(cells): + for j in range(i + 1, cells): + dtw_matrix[i, j] = dtw.distance(cell_trajectories[i], cell_trajectories[j]) + dtw_matrix[j, i] = dtw_matrix[i, j] + +# Print distance matrix for examination +print("DTW Distance Matrix:") +for i in range(cells): + print(f"Cell {i+1}: {dtw_matrix[reference_cell, i]:.2f}") + +# Plot distance heatmap +plt.figure(figsize=(8, 6)) +plt.imshow(dtw_matrix, cmap="viridis", origin="lower") +plt.colorbar(label="DTW Distance") +plt.title("DTW Distance Matrix Between Cells") +plt.xlabel("Cell Index") +plt.ylabel("Cell Index") +plt.tight_layout() +plt.show() + +linkage_matrix = linkage(dtw_matrix, method="ward") + +# Plot the dendrogram +plt.figure(figsize=(8, 5)) +dendrogram(linkage_matrix, labels=[f"Cell {i+1}" for i in range(cells)]) +plt.xlabel("Cells") +plt.ylabel("DTW Distance") +plt.title("Hierarchical Clustering of Cells Based on DTW") +plt.show() + +# %% +# Align cells using DTW with distance filtering +# Set a threshold for maximum allowed DTW distance +# Cells with distances above this threshold won't be aligned +# This can be set based on the distribution of distances or domain knowledge +distance_threshold = np.median(dtw_matrix[reference_cell, :]) * 1.5 # Example threshold + +print(f"Using distance threshold: {distance_threshold:.2f}") +print("Distances from reference cell:") +for i in range(cells): + distance = dtw_matrix[reference_cell, i] + status = ( + "Included" + if distance <= distance_threshold or i == reference_cell + else "Excluded (too dissimilar)" + ) + print(f"Cell {i+1}: {distance:.2f} - {status}") + +# Initialize aligned trajectories with the reference cell +aligned_cell_trajectories = [cell_trajectories[reference_cell].copy()] +alignment_status = [True] # Reference cell is always included + +for i in range(1, cells): + distance = dtw_matrix[reference_cell, i] + + # Skip cells that are too dissimilar to the reference + if distance > distance_threshold: + aligned_cell_trajectories.append( + np.full_like(cell_trajectories[reference_cell], np.nan) + ) + alignment_status.append(False) + continue + + # Find optimal warping path + path = dtw.warping_path(cell_trajectories[reference_cell], cell_trajectories[i]) + + # Create aligned trajectory by mapping query points to reference timeline + aligned_trajectory = np.zeros_like(cell_trajectories[reference_cell]) + path_dict = {} + + # Group by reference indices + for ref_idx, query_idx in path: + if ref_idx not in path_dict: + path_dict[ref_idx] = [] + path_dict[ref_idx].append(query_idx) + + # For each reference index, average the corresponding query values + for ref_idx, query_indices in path_dict.items(): + query_values = [cell_trajectories[i][idx] for idx in query_indices] + aligned_trajectory[ref_idx] = np.mean(query_values) + + aligned_cell_trajectories.append(aligned_trajectory) + alignment_status.append(True) + +# %% +# plot aligned cell trajectories (only included cells) +plt.figure(figsize=(10, 6)) + +# Plot reference cell first +plt.plot(aligned_cell_trajectories[0], "k-", linewidth=2.5, label=f"Reference (Cell 1)") + +# Plot other cells that were successfully aligned +for i in range(1, cells): + if alignment_status[i]: + plt.plot(aligned_cell_trajectories[i], label=f"Cell {i+1}") + +plt.legend() +plt.title("Aligned Cell Trajectories (Filtered by DTW Distance)") +plt.show() + +# %% +# Visualize 
warping paths for examples
+plt.figure(figsize=(15, 10))
+
+# First find cells to include based on distance threshold
+included_cells = [i for i in range(1, cells) if alignment_status[i]]
+excluded_cells = [i for i in range(1, cells) if not alignment_status[i]]
+
+# Show included cells examples
+for idx, target_cell in enumerate(included_cells[: min(2, len(included_cells))]):
+    plt.subplot(2, 3, idx + 1)
+
+    # Get warping path
+    path = dtw.warping_path(
+        cell_trajectories[reference_cell], cell_trajectories[target_cell]
+    )
+
+    # Plot both signals
+    plt.plot(cell_trajectories[reference_cell], label="Reference", linewidth=2)
+    plt.plot(cell_trajectories[target_cell], label=f"Cell {target_cell+1}", linewidth=2)
+
+    # Plot warping connections
+    for ref_idx, query_idx in path:
+        plt.plot(
+            [ref_idx, query_idx],
+            [
+                cell_trajectories[reference_cell][ref_idx],
+                cell_trajectories[target_cell][query_idx],
+            ],
+            "k-",
+            alpha=0.1,
+        )
+
+    plt.title(
+        f"Included - Cell {target_cell+1} (Dist: {dtw_matrix[reference_cell, target_cell]:.2f})"
+    )
+    plt.legend()
+
+# Show excluded cells examples
+for idx, target_cell in enumerate(excluded_cells[: min(2, len(excluded_cells))]):
+    plt.subplot(2, 3, 3 + idx)
+
+    plt.plot(cell_trajectories[reference_cell], label="Reference", linewidth=2)
+    plt.plot(cell_trajectories[target_cell], label=f"Cell {target_cell+1}", linewidth=2)
+
+    # Show distance value
+    plt.title(
+        f"Excluded - Cell {target_cell+1} (Dist: {dtw_matrix[reference_cell, target_cell]:.2f})"
+    )
+    plt.legend()
+
+# Compare original and aligned for an included cell
+
+if included_cells:
+    plt.subplot(2, 3, 5)
+    target_cell = included_cells[0]
+    plt.plot(cell_trajectories[reference_cell], label="Reference", linewidth=2)
+    plt.plot(
+        cell_trajectories[target_cell],
+        label=f"Original Cell {target_cell+1}",
+        linewidth=2,
+        linestyle="--",
+        alpha=0.7,
+    )
+    plt.plot(
+        aligned_cell_trajectories[target_cell],
+        label=f"Aligned Cell {target_cell+1}",
+        linewidth=2,
+    )
+    plt.title("Alignment Example (Included)")
+    plt.legend()
+
+# Show distance distribution
+plt.subplot(2, 3, 6)
+distances = dtw_matrix[reference_cell, 1:]  # Skip distance to self
+plt.hist(distances, bins=10)
+plt.axvline(
+    distance_threshold,
+    color="r",
+    linestyle="--",
+    label=f"Threshold: {distance_threshold:.2f}",
+)
+plt.title("DTW Distance Distribution")
+plt.xlabel("DTW Distance from Reference")
+plt.ylabel("Count")
+plt.legend()
+
+plt.tight_layout()
+plt.show()
+# %%
diff --git a/applications/pseudotime_analysis/simulation/demo_ndim_dtw.py b/applications/pseudotime_analysis/simulation/demo_ndim_dtw.py
new file mode 100644
index 000000000..3bac7f227
--- /dev/null
+++ b/applications/pseudotime_analysis/simulation/demo_ndim_dtw.py
@@ -0,0 +1,634 @@
+# %%
+import matplotlib.pyplot as plt
+import numpy as np
+from dtaidistance.dtw_ndim import warping_path
+from scipy.spatial.distance import cdist
+
+# %%
+# Simulation of embeddings with temporal warping
+np.random.seed(42)  # For reproducibility
+
+# Base parameters
+num_cells = 8
+num_timepoints = 30  # More timepoints for better visualization of warping
+embedding_dim = 100
+
+# Create a reference trajectory (Cell 1)
+t_ref = np.linspace(0, 4 * np.pi, num_timepoints)  # 2 complete periods (0 to 4π)
+base_pattern = np.zeros((num_timepoints, embedding_dim))
+
+# Generate a structured pattern with clear sinusoidal shape
+for dim in range(embedding_dim):
+    # Use lower frequencies to ensure at least one full period
+    freq = 0.2 + 0.3 * np.random.rand()  # Frequency between 0.2 and 0.5
+    phase = np.random.rand() * np.pi  # Random phase
+    amplitude = 0.7 + 0.6 * np.random.rand()  # Amplitude between 0.7 and 1.3
+
+    # Create basic sine wave for this dimension
+    base_pattern[:, dim] = amplitude * np.sin(freq * t_ref + phase)
+
+# Cell embeddings with different temporal dynamics
+cell_embeddings = np.zeros((num_cells, num_timepoints, embedding_dim))
+
+# Cell 1 (reference) - standard progression
+cell_embeddings[0] = base_pattern.copy()
+
+# Cell 2 - similar to reference with small variations
+cell_embeddings[1] = base_pattern + np.random.randn(num_timepoints, embedding_dim) * 0.2
+
+# Cell 3 - starts slow, then accelerates (time warping)
+# Map [0,1] -> [0,4π] with non-linear warping
+t_warped = np.power(np.linspace(0, 1, num_timepoints), 1.7) * 4 * np.pi
+for dim in range(embedding_dim):
+    # Draw new random frequency/phase/amplitude per dimension (the RNG state has
+    # advanced, so these differ from the reference pattern's draws)
+    freq = 0.2 + 0.3 * np.random.rand()
+    phase = np.random.rand() * np.pi
+    amplitude = 0.7 + 0.6 * np.random.rand()
+
+    # Apply the warping to the timepoints
+    cell_embeddings[2, :, dim] = amplitude * np.sin(freq * t_warped + phase)
+cell_embeddings[2] += np.random.randn(num_timepoints, embedding_dim) * 0.15
+
+# Cell 4 - starts fast, then slows down (time warping)
+t_warped = np.power(np.linspace(0, 1, num_timepoints), 0.6) * 4 * np.pi
+for dim in range(embedding_dim):
+    freq = 0.2 + 0.3 * np.random.rand()
+    phase = np.random.rand() * np.pi
+    amplitude = 0.7 + 0.6 * np.random.rand()
+    cell_embeddings[3, :, dim] = amplitude * np.sin(freq * t_warped + phase)
+cell_embeddings[3] += np.random.randn(num_timepoints, embedding_dim) * 0.15
+
+# Cell 5 - missing middle section (temporal gap)
+t_warped = np.concatenate(
+    [
+        np.linspace(0, 1.5 * np.pi, num_timepoints // 2),  # First 1.5 periods
+        np.linspace(
+            2.5 * np.pi, 4 * np.pi, num_timepoints // 2
+        ),  # Last 1.5 periods (skip middle)
+    ]
+)
+for dim in range(embedding_dim):
+    freq = 0.2 + 0.3 * np.random.rand()
+    phase = np.random.rand() * np.pi
+    amplitude = 0.7 + 0.6 * np.random.rand()
+    cell_embeddings[4, :, dim] = amplitude * np.sin(freq * t_warped + phase)
+cell_embeddings[4] += np.random.randn(num_timepoints, embedding_dim) * 0.15
+
+# Cell 6 - phase shifted (out of sync with reference)
+cell_embeddings[5] = np.roll(
+    base_pattern, shift=num_timepoints // 4, axis=0
+)  # 1/4 cycle shift
+cell_embeddings[5] += np.random.randn(num_timepoints, embedding_dim) * 0.2
+
+# Cell 7 - Double frequency (faster oscillations)
+for dim in range(embedding_dim):
+    freq = (0.2 + 0.3 * np.random.rand()) * 2  # Double frequency
+    phase = np.random.rand() * np.pi
+    amplitude = 0.7 + 0.6 * np.random.rand()
+    cell_embeddings[6, :, dim] = amplitude * np.sin(freq * t_ref + phase)
+cell_embeddings[6] += np.random.randn(num_timepoints, embedding_dim) * 0.15
+
+# Cell 8 - Very different pattern with trend
+cell_embeddings[7] = np.random.randn(num_timepoints, embedding_dim) * 1.5
+trend = np.linspace(0, 3, num_timepoints).reshape(-1, 1)
+cell_embeddings[7] += trend * np.random.randn(1, embedding_dim)
+
+# %%
+# Visualize the first four dimensions of each cell's embeddings to see the temporal patterns
+plt.figure(figsize=(18, 10))
+
+# Create subplots for 4 dimensions
+for dim in range(4):
+    plt.subplot(2, 2, dim + 1)
+
+    for i in range(num_cells):
+        plt.plot(
+            range(num_timepoints),
+            cell_embeddings[i, :, dim],
+            label=f"Cell {i+1}",
+            linewidth=2,
+        )
+
+    plt.title(f"Dimension {dim+1} over time")
+    plt.xlabel("Timepoint")
+    plt.ylabel(f"Value (Dim {dim+1})")
+    plt.grid(alpha=0.3)
+
+    if dim == 
0: + plt.legend(loc="upper right") + +plt.tight_layout() +plt.show() + + +# %% +# Helper function to compute DTW warping matrix +def compute_dtw_matrix(s1, s2): + """ + Compute the DTW warping matrix and best path manually. + + Args: + s1: First sequence (reference) + s2: Second sequence (query) + + Returns: + warping_matrix: The accumulated cost matrix + best_path: The optimal warping path + """ + # Compute pairwise distances between all timepoints + distance_matrix = cdist(s1, s2) + + n, m = distance_matrix.shape + + # Initialize the accumulated cost matrix + warping_matrix = np.full((n, m), np.inf) + warping_matrix[0, 0] = distance_matrix[0, 0] + + # Fill the first column and row + for i in range(1, n): + warping_matrix[i, 0] = warping_matrix[i - 1, 0] + distance_matrix[i, 0] + for j in range(1, m): + warping_matrix[0, j] = warping_matrix[0, j - 1] + distance_matrix[0, j] + + # Fill the rest of the matrix + for i in range(1, n): + for j in range(1, m): + warping_matrix[i, j] = distance_matrix[i, j] + min( + warping_matrix[i - 1, j], # insertion + warping_matrix[i, j - 1], # deletion + warping_matrix[i - 1, j - 1], # match + ) + + # Backtrack to find the optimal path + i, j = n - 1, m - 1 + path = [(i, j)] + + while i > 0 or j > 0: + if i == 0: + j -= 1 + elif j == 0: + i -= 1 + else: + min_cost = min( + warping_matrix[i - 1, j], + warping_matrix[i, j - 1], + warping_matrix[i - 1, j - 1], + ) + + if min_cost == warping_matrix[i - 1, j - 1]: + i, j = i - 1, j - 1 + elif min_cost == warping_matrix[i - 1, j]: + i -= 1 + else: + j -= 1 + + path.append((i, j)) + + path.reverse() + + return warping_matrix, path + + +# %% +# Compute DTW distances and warping paths between cell 1 and all other cells +reference_cell = 0 # Cell 1 (0-indexed) +dtw_results = [] + +for i in range(num_cells): + if i != reference_cell: + # Get distance and path from dtaidistance + path, dist = warping_path( + cell_embeddings[reference_cell], + cell_embeddings[i], + include_distance=True, + ) + + # Compute our own warping matrix for visualization + warping_matrix, _ = compute_dtw_matrix( + cell_embeddings[reference_cell], cell_embeddings[i] + ) + + dtw_results.append( + (i + 1, dist, path, warping_matrix) + ) # Store cell number, distance, path, matrix + print(f"DTW distance between Cell 1 and Cell {i+1} dist: {dist:.4f}") + +# %% +# Visualize the DTW distances +cell_ids = [result[0] for result in dtw_results] +distances = [result[1] for result in dtw_results] + +plt.figure(figsize=(10, 6)) +plt.bar(cell_ids, distances) +plt.xlabel("Cell ID") +plt.ylabel("DTW Distance from Cell 1") +plt.title("DTW Distances from Cell 1 to Other Cells") +plt.xticks(cell_ids) +plt.tight_layout() +plt.show() + +# %% +# Create a grid of all warping matrices in a 4x2 layout +fig, axes = plt.subplots(4, 2, figsize=(16, 24)) +axes = axes.flatten() + +# Common colorbar limits for better comparison +all_matrices = [result[3] for result in dtw_results] +vmin = min(matrix.min() for matrix in all_matrices) +vmax = max(matrix.max() for matrix in all_matrices) + +# Add diagonal reference line for comparison +diagonal = np.linspace(0, num_timepoints - 1, 100) + +for i, result in enumerate(dtw_results): + cell_id, dist, path, warping_matrix = result + ax = axes[i] + + # Plot the warping matrix + im = ax.imshow( + warping_matrix, + origin="lower", + aspect="auto", + cmap="viridis", + vmin=vmin, + vmax=vmax, + ) + + # Plot diagonal reference line + ax.plot(diagonal, diagonal, "w--", alpha=0.7, linewidth=1, label="Diagonal") + + # Extract and plot the 
best path + path_x = [p[0] for p in path] + path_y = [p[1] for p in path] + ax.plot(path_y, path_x, "r-", linewidth=2, label="Best path") + + # Add some arrows to show direction + step = max(1, len(path) // 5) # Show 5 arrows along the path + for j in range(0, len(path) - 1, step): + ax.annotate( + "", + xy=(path_y[j + 1], path_x[j + 1]), + xytext=(path_y[j], path_x[j]), + arrowprops=dict(arrowstyle="->", color="orange", lw=1.5), + ) + + # Add title and axes labels + cell_desc = "" + if cell_id == 2: + cell_desc = " (Small variations)" + elif cell_id == 3: + cell_desc = " (Slow→Fast)" + elif cell_id == 4: + cell_desc = " (Fast→Slow)" + elif cell_id == 5: + cell_desc = " (Missing middle)" + elif cell_id == 6: + cell_desc = " (Phase shift)" + elif cell_id == 7: + cell_desc = " (Double frequency)" + elif cell_id == 8: + cell_desc = " (Different pattern)" + + ax.set_title(f"Cell 1 vs Cell {cell_id}{cell_desc} (Dist: {dist:.2f})") + ax.set_xlabel("Cell {} Timepoints".format(cell_id)) + ax.set_ylabel("Cell 1 Timepoints") + + # Add legend + ax.legend(loc="lower right", fontsize=8) + +# Add a colorbar at the bottom spanning all subplots +cbar_ax = fig.add_axes([0.15, 0.05, 0.7, 0.02]) +cbar = fig.colorbar(im, cax=cbar_ax, orientation="horizontal") +cbar.set_label("Accumulated Cost") + +plt.suptitle("DTW Warping Matrices: Cell 1 vs All Other Cells", fontsize=16) +plt.tight_layout(rect=[0, 0.07, 1, 0.98]) +plt.show() + + +# %% +# Compute average embedding by aligning all cells to the reference cell +def align_to_reference( + reference: np.ndarray, query: np.ndarray, path: list[tuple[int, int]] +) -> np.ndarray: + """ + Align a query embedding to the reference timepoints based on the DTW path. + + Args: + reference: Reference embedding (n_timepoints x n_dims) + query: Query embedding to align (m_timepoints x n_dims) + path: DTW path as list of (ref_idx, query_idx) tuples + + Returns: + aligned_query: Query embeddings aligned to reference timepoints + """ + n_ref, n_dims = reference.shape + aligned_query = np.zeros_like(reference) + + # Count how many query timepoints map to each reference timepoint + counts = np.zeros(n_ref) + + # Sum query embeddings for each reference timepoint based on the path + for ref_idx, query_idx in path: + aligned_query[ref_idx] += query[query_idx] + counts[ref_idx] += 1 + + # Average when multiple query timepoints map to the same reference timepoint + for i in range(n_ref): + if counts[i] > 0: + aligned_query[i] /= counts[i] + else: + # If no query timepoints map to this reference, use nearest neighbors + nearest_idx = min(range(len(path)), key=lambda j: abs(path[j][0] - i)) + aligned_query[i] = query[path[nearest_idx][1]] + + return aligned_query + + +# Identify the top 5 most similar cells +# Sort cells by distance +sorted_results = sorted(dtw_results, key=lambda x: x[1]) # Sort by distance (x[1]) + +# Select top 5 closest cells +top_5_cells = sorted_results[:5] +print("Top 5 cells (by DTW distance to reference):") +for cell_id, dist, _, _ in top_5_cells: + print(f"Cell {cell_id}: Distance = {dist:.4f}") + +# First, align all the top 5 cells to the reference timepoints +aligned_cells = [] +all_cell_ids = [reference_cell + 1] + [cell_id for cell_id, _, _, _ in top_5_cells] + +# Include reference cell as-is (it's already aligned) +aligned_cells.append(cell_embeddings[reference_cell]) + +# Align each of the top 5 cells +for cell_id, _, path, _ in top_5_cells: + cell_idx = cell_id - 1 # Convert to 0-indexed + + # Get aligned version of this cell + aligned_cell = 
align_to_reference( + cell_embeddings[reference_cell], cell_embeddings[cell_idx], path + ) + + # Store the aligned cell + aligned_cells.append(aligned_cell) + +# %% +# Visualize the aligned cells before averaging +plt.figure(figsize=(18, 10)) + +# Create subplots for the first 4 dimensions +for dim in range(4): + plt.subplot(2, 2, dim + 1) + + # Plot each aligned cell + for i, (cell_id, aligned_cell) in enumerate(zip(all_cell_ids, aligned_cells)): + if i == 0: + # Reference cell + plt.plot( + range(num_timepoints), + aligned_cell[:, dim], + "b-", + linewidth=2, + label=f"Cell 1 (Reference)", + ) + else: + # Other aligned cells + plt.plot( + range(num_timepoints), + aligned_cell[:, dim], + "g-", + alpha=0.5, + linewidth=1, + label=f"Cell {cell_id} (Aligned)" if i == 1 else None, + ) + + plt.title(f"Dimension {dim+1}: Aligned Cells") + plt.xlabel("Reference Timepoint") + plt.ylabel(f"Value (Dim {dim+1})") + plt.grid(alpha=0.3) + plt.legend() + +plt.tight_layout() +plt.suptitle("Cells Aligned to Reference before Averaging", fontsize=16, y=1.02) +plt.show() + +# %% +# Now compute the average of the aligned cells +average_embedding = np.zeros_like(cell_embeddings[reference_cell]) + +# Add all aligned cells +for aligned_cell in aligned_cells: + average_embedding += aligned_cell + +# Divide by number of cells +average_embedding /= len(aligned_cells) + +# %% +# Visualize the original embeddings and the average embedding +plt.figure(figsize=(18, 8)) + +# Create subplots for the first 4 dimensions +for dim in range(4): + plt.subplot(2, 2, dim + 1) + + # Plot all original cells (transparent) + for i in range(num_cells): + # Determine if this cell is in the top 5 + is_top5 = False + for cell_id, _, _, _ in top_5_cells: + if i + 1 == cell_id: # Convert 1-indexed to 0-indexed + is_top5 = True + break + + # Style based on cell type + alpha = 0.3 + color = "gray" + label = None + + if i == reference_cell: + alpha = 0.7 + color = "blue" + label = "Cell 1 (Reference)" + elif is_top5: + alpha = 0.5 + color = "green" + if dim == 0 and i == top_5_cells[0][0] - 1: # Only label once + label = "Top 5 Cells" + + plt.plot( + range(num_timepoints), + cell_embeddings[i, :, dim], + alpha=alpha, + color=color, + linewidth=1, + label=label, + ) + + # Plot the average embedding + plt.plot( + range(num_timepoints), + average_embedding[:, dim], + "r-", + linewidth=2, + label="Average Embedding", + ) + + plt.title(f"Dimension {dim+1}: Original vs Average") + plt.xlabel("Timepoint") + plt.ylabel(f"Value (Dim {dim+1})") + plt.grid(alpha=0.3) + plt.legend() + +plt.tight_layout() +plt.suptitle( + "Average Embedding from Top 5 Similar Cells (via DTW)", fontsize=16, y=1.02 +) +plt.show() + +# %% +# Evaluate the average embedding as a reference +# Compute DTW distances from average to all cells +average_dtw_results = [] + +for i in range(num_cells): + # Get distance and path from the average to each cell + path, dist = warping_path( + average_embedding, + cell_embeddings[i], + include_distance=True, + ) + + # Compute warping matrix for visualization + warping_matrix, _ = compute_dtw_matrix(average_embedding, cell_embeddings[i]) + + average_dtw_results.append((i + 1, dist, path, warping_matrix)) + print(f"DTW distance between Average and Cell {i+1} dist: {dist:.4f}") + +# %% +# Compare distances: Cell 1 as reference vs Average as reference +# Combine the DTW distances for comparison +comparison_data = [] + +# Add Cell 1 reference distances +for i in range(num_cells): + if i == reference_cell: + # Distance to self is 0 + 
comparison_data.append( + { + "Cell ID": i + 1, + "To Cell 1": 0.0, + "To Average": average_dtw_results[i][1], + } + ) + else: + # Find the matching result from dtw_results + for cell_id, dist, _, _ in dtw_results: + if cell_id == i + 1: + comparison_data.append( + { + "Cell ID": i + 1, + "To Cell 1": dist, + "To Average": average_dtw_results[i][1], + } + ) + break + +# Prepare bar chart data +cell_ids = [d["Cell ID"] for d in comparison_data] +to_cell1 = [d["To Cell 1"] for d in comparison_data] +to_average = [d["To Average"] for d in comparison_data] + +# Compute some statistics +total_to_cell1 = sum(to_cell1) +total_to_average = sum(to_average) +avg_to_cell1 = total_to_cell1 / len(cell_ids) +avg_to_average = total_to_average / len(cell_ids) + +# Create a comparison bar chart +plt.figure(figsize=(12, 6)) +x = np.arange(len(cell_ids)) +width = 0.35 + +plt.bar(x - width / 2, to_cell1, width, label="Distance to Cell 1") +plt.bar(x + width / 2, to_average, width, label="Distance to Average") + +plt.xlabel("Cell ID") +plt.ylabel("DTW Distance") +plt.title("Comparison: Cell 1 vs Average as Reference") +plt.xticks(x, cell_ids) +plt.legend() + +# Add summary as text annotation +plt.figtext( + 0.5, + 0.01, + f"Total distance - Cell 1: {total_to_cell1:.2f}, Average: {total_to_average:.2f}\n" + f"Mean distance - Cell 1: {avg_to_cell1:.2f}, Average: {avg_to_average:.2f}", + ha="center", + fontsize=10, + bbox=dict(facecolor="white", alpha=0.8), +) + +plt.tight_layout(rect=[0, 0.05, 1, 0.95]) +plt.show() + +# %% +# Visualize selected warping matrices using the average as reference +# Select a few representative cells (one similar, one different) +similar_cell_idx = 1 # Cell 2 +different_cell_idx = 7 # Cell 8 + +plt.figure(figsize=(16, 7)) + +# Plot warping matrix for similar cell +plt.subplot(1, 2, 1) +warping_matrix = average_dtw_results[similar_cell_idx - 1][3] +path = average_dtw_results[similar_cell_idx - 1][2] +dist = average_dtw_results[similar_cell_idx - 1][1] + +plt.imshow(warping_matrix, origin="lower", aspect="auto", cmap="viridis") +plt.colorbar(label="Accumulated Cost") + +# Plot diagonal and path +diagonal = np.linspace(0, num_timepoints - 1, 100) +plt.plot(diagonal, diagonal, "w--", alpha=0.7, linewidth=1, label="Diagonal") + +# Extract and plot the best path +path_x = [p[0] for p in path] +path_y = [p[1] for p in path] +plt.plot(path_y, path_x, "r-", linewidth=2, label="Best path") + +plt.title(f"Average vs Cell {similar_cell_idx} (Similar, Dist: {dist:.2f})") +plt.xlabel(f"Cell {similar_cell_idx} Timepoints") +plt.ylabel("Average Timepoints") +plt.legend(loc="lower right", fontsize=8) + +# Plot warping matrix for different cell +plt.subplot(1, 2, 2) +warping_matrix = average_dtw_results[different_cell_idx - 1][3] +path = average_dtw_results[different_cell_idx - 1][2] +dist = average_dtw_results[different_cell_idx - 1][1] + +plt.imshow(warping_matrix, origin="lower", aspect="auto", cmap="viridis") +plt.colorbar(label="Accumulated Cost") + +# Plot diagonal and path +plt.plot(diagonal, diagonal, "w--", alpha=0.7, linewidth=1, label="Diagonal") + +# Extract and plot the best path +path_x = [p[0] for p in path] +path_y = [p[1] for p in path] +plt.plot(path_y, path_x, "r-", linewidth=2, label="Best path") + +plt.title(f"Average vs Cell {different_cell_idx} (Different, Dist: {dist:.2f})") +plt.xlabel(f"Cell {different_cell_idx} Timepoints") +plt.ylabel("Average Timepoints") +plt.legend(loc="lower right", fontsize=8) + +plt.suptitle("DTW Warping Matrices: Average as Reference", fontsize=16) 
+plt.tight_layout(rect=[0, 0, 1, 0.95]) +plt.show() + +# %% diff --git a/applications/pseudotime_analysis/utils.py b/applications/pseudotime_analysis/utils.py new file mode 100644 index 000000000..4523f2825 --- /dev/null +++ b/applications/pseudotime_analysis/utils.py @@ -0,0 +1,273 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +import xarray as xr + + +def load_annotation( + embedding_dataset: xr.Dataset, + track_csv_path: str, + name: str, + categories: dict | None = None, +) -> pd.Series: + """ + Load annotations from a CSV file and map them to the dataset. + """ + annotation = pd.read_csv(track_csv_path) + annotation["fov_name"] = "/" + annotation["fov ID"] + + embedding_index = pd.MultiIndex.from_arrays( + [ + embedding_dataset["fov_name"].values, + embedding_dataset["id"].values, + embedding_dataset["t"].values, + embedding_dataset["track_id"].values, + ], + names=["fov_name", "id", "t", "track_id"], + ) + + annotation = annotation.set_index(["fov_name", "id", "t", "track_id"]) + selected = annotation.reindex(embedding_index)[name] + + if categories: + if -1 in selected.values and -1 not in categories: + categories = categories.copy() + categories[-1] = np.nan + selected = selected.map(categories) + + return selected + + +def identify_lineages(annotations_path: Path) -> list[tuple[str, list[int]]]: + """ + Identifies all distinct lineages in the cell tracking data, following only + one branch after each division event. + + Args: + annotations_path: Path to the annotations CSV file + + Returns: + A list of tuples, where each tuple contains (fov_id, [track_ids]) + representing a single branch lineage within a single FOV + """ + # Read the CSV file + df = pd.read_csv(annotations_path) + + # Ensure column names are consistent + if "fov ID" in df.columns and "fov_id" not in df.columns: + df["fov_id"] = df["fov ID"] + + # Process each FOV separately to handle repeated track_ids + all_lineages = [] + + # Group by FOV + for fov_id, fov_df in df.groupby("fov_id"): + # Create a dictionary to map tracks to their parents within this FOV + child_to_parent = {} + + # Group by track_id and get the first row for each track to find its parent + for track_id, track_group in fov_df.groupby("track_id"): + first_row = track_group.iloc[0] + parent_track_id = first_row["parent_track_id"] + + if parent_track_id != -1: + child_to_parent[track_id] = parent_track_id + + # Find root tracks (those without parents or with parent_track_id = -1) + all_tracks = set(fov_df["track_id"].unique()) + child_tracks = set(child_to_parent.keys()) + root_tracks = all_tracks - child_tracks + + # Build a parent-to-children mapping + parent_to_children = {} + for child, parent in child_to_parent.items(): + if parent not in parent_to_children: + parent_to_children[parent] = [] + parent_to_children[parent].append(child) + + # Function to get a single branch from each parent + # We'll choose the first child in the list (arbitrary choice) + def get_single_branch(track_id): + branch = [track_id] + if track_id in parent_to_children: + # Choose only the first child (you could implement other selection criteria) + first_child = parent_to_children[track_id][0] + branch.extend(get_single_branch(first_child)) + return branch + + # Build lineages starting from root tracks within this FOV + for root_track in root_tracks: + # Get a single branch from this root + lineage_tracks = get_single_branch(root_track) + all_lineages.append((fov_id, lineage_tracks)) + + return all_lineages + + +def path_skew(warping_path, 
ref_len, query_len): + """ + Calculate the skewness of a DTW warping path. + + Args: + warping_path: List of tuples (ref_idx, query_idx) representing the warping path + ref_len: Length of the reference sequence + query_len: Length of the query sequence + + Returns: + A skewness metric between 0 and 1, where: + - 0 means perfectly diagonal path (ideal alignment) + - 1 means completely skewed (worst alignment) + """ + # Convert path to numpy array for easier manipulation + path_array = np.array(warping_path) + + # Calculate "ideal" diagonal indices + diagonal_x = np.linspace(0, ref_len - 1, len(warping_path)) + diagonal_y = np.linspace(0, query_len - 1, len(warping_path)) + diagonal_path = np.column_stack((diagonal_x, diagonal_y)) + + # Calculate distances from points on the warping path to the diagonal + # Normalize based on max possible distance (corner to diagonal) + max_distance = max(ref_len, query_len) + + # Calculate distance from each point to the diagonal + distances = [] + for i, (x, y) in enumerate(path_array): + # Find the closest point on the diagonal + dx, dy = diagonal_path[i] + # Simple Euclidean distance + dist = np.sqrt((x - dx) ** 2 + (y - dy) ** 2) + distances.append(dist) + + # Average normalized distance as skewness metric + skew = np.mean(distances) / max_distance + + return skew + + +def dtw_with_matrix(distance_matrix, normalize=True): + """ + Compute DTW using a pre-computed distance matrix. + + Args: + distance_matrix: Pre-computed distance matrix between two sequences + normalize: Whether to normalize the distance by path length (default: True) + + Returns: + dtw_distance: The DTW distance + warping_matrix: The accumulated cost matrix + best_path: The optimal warping path + """ + n, m = distance_matrix.shape + + # Initialize the accumulated cost matrix + warping_matrix = np.full((n, m), np.inf) + warping_matrix[0, 0] = distance_matrix[0, 0] + + # Fill the first column and row + for i in range(1, n): + warping_matrix[i, 0] = warping_matrix[i - 1, 0] + distance_matrix[i, 0] + for j in range(1, m): + warping_matrix[0, j] = warping_matrix[0, j - 1] + distance_matrix[0, j] + + # Fill the rest of the matrix + for i in range(1, n): + for j in range(1, m): + warping_matrix[i, j] = distance_matrix[i, j] + min( + warping_matrix[i - 1, j], # insertion + warping_matrix[i, j - 1], # deletion + warping_matrix[i - 1, j - 1], # match + ) + + # Backtrack to find the optimal path + i, j = n - 1, m - 1 + path = [(i, j)] + + while i > 0 or j > 0: + if i == 0: + j -= 1 + elif j == 0: + i -= 1 + else: + min_cost = min( + warping_matrix[i - 1, j], + warping_matrix[i, j - 1], + warping_matrix[i - 1, j - 1], + ) + + if min_cost == warping_matrix[i - 1, j - 1]: + i, j = i - 1, j - 1 + elif min_cost == warping_matrix[i - 1, j]: + i -= 1 + else: + j -= 1 + + path.append((i, j)) + + path.reverse() + + # Get the DTW distance (bottom-right cell) + dtw_distance = warping_matrix[n - 1, m - 1] + + # Normalize by path length if requested + if normalize: + dtw_distance = dtw_distance / len(path) + + return dtw_distance, warping_matrix, path + + +# %% +def filter_lineages_by_timepoints(lineages, annotation_path, min_timepoints=10): + """ + Filter lineages that have fewer than min_timepoints total timepoints. 
+
+    Args:
+        lineages: List of tuples (fov_id, [track_ids])
+        annotation_path: Path to the annotations CSV file
+        min_timepoints: Minimum number of timepoints required
+
+    Returns:
+        List of filtered lineages
+    """
+    # Read the annotations file
+    df = pd.read_csv(annotation_path)
+
+    # Ensure column names are consistent
+    if "fov ID" in df.columns and "fov_id" not in df.columns:
+        df["fov_id"] = df["fov ID"]
+
+    filtered_lineages = []
+
+    for fov_id, track_ids in lineages:
+        # Get all rows for this lineage
+        lineage_rows = df[(df["fov_id"] == fov_id) & (df["track_id"].isin(track_ids))]
+
+        # Count the total number of timepoints
+        total_timepoints = len(lineage_rows)
+
+        # Only keep lineages with at least min_timepoints
+        if total_timepoints >= min_timepoints:
+            filtered_lineages.append((fov_id, track_ids))
+
+    return filtered_lineages
+
+
+def find_top_matching_tracks(cell_division_df, infection_df, n_top=10) -> pd.DataFrame:
+    """Return the n_top tracks found in both match tables, ranked by the
+    smallest sum of their DTW distances."""
+    # Find common tracks between datasets
+    intersection_df = pd.merge(
+        cell_division_df,
+        infection_df,
+        on=["fov_name", "track_ids"],
+        how="inner",
+        suffixes=("_df1", "_df2"),
+    )
+
+    # Add column with sum of the values
+    intersection_df["distance_sum"] = (
+        intersection_df["distance_df1"] + intersection_df["distance_df2"]
+    )
+
+    # Find rows with the smallest sum
+    intersection_df.sort_values(by="distance_sum", ascending=True, inplace=True)
+    return intersection_df.head(n_top)
diff --git a/docs/figures/DynaCLR_schematic_v2.png b/docs/figures/DynaCLR_schematic_v2.png
new file mode 100644
index 000000000..274f89b62
Binary files /dev/null and b/docs/figures/DynaCLR_schematic_v2.png differ
diff --git a/docs/figures/dynaCLR_schematic.png b/docs/figures/dynaCLR_schematic.png
deleted file mode 100644
index eb591cb7d..000000000
Binary files a/docs/figures/dynaCLR_schematic.png and /dev/null differ
diff --git a/examples/DynaCLR/DynaCLR-DENV-VS-Ph/README.md b/examples/DynaCLR/DynaCLR-DENV-VS-Ph/README.md
new file mode 100644
index 000000000..252b70bb8
--- /dev/null
+++ b/examples/DynaCLR/DynaCLR-DENV-VS-Ph/README.md
@@ -0,0 +1,94 @@
+# Cell Infection Analysis Demo: ImageNet vs DynaCLR-DENV-VS+Ph model
+
+This demo compares different feature extraction methods for analyzing infected vs uninfected cells using microscopy images.
+
+As the cells get infected, the red fluorescent protein (RFP) translocates from the cytoplasm into the nucleus.
+
+## Overview
+
+The `demo_infection.py` script demonstrates:
+
+ - PHATE plots of the embeddings generated by DynaCLR and ImageNet
+ - Infection progression in cells via the Phase and RFP (viral sensor) channels
+ - Highlighted trajectories for sample infected and uninfected cells over time
+
+## Key Features
+
+- **Feature Extraction**: Compare ImageNet pre-trained and specialized DynaCLR features
+- **Interactive Visualization**: Create plotly-based visualizations with time sliders
+- **Side-by-Side Comparison**: Directly compare cell images and PHATE embeddings
+- **Trajectory Analysis**: Visualize and track cell trajectories over time
+- **Infection State Analysis**: See how different models capture infection dynamics
+
+
+## Usage
+
+After [setting up the environment and downloading the data](/examples/DynaCLR/README.md#setup), activate the environment and run the demo script:
+
+```bash
+conda activate dynaclr
+python demo_infection.py
+```
+
+Before running, make sure the paths in the script point to the downloaded data:
+```python
+# Update these paths to your data
+input_data_path = "/path/to/registered_test.zarr"
+tracks_path = "/path/to/track_test.zarr"
+ann_path = "/path/to/extracted_inf_state.csv"
+
+# Update paths to features
+dynaclr_features_path = "/path/to/dynaclr_features.zarr"
+imagenet_features_path = "/path/to/imagenet_features.zarr"
+```
+
+Check out the demo's output visualization:
+
+- [Open Visualization](https://public.czbiohub.org/comp.micro/viscy/DynaCLR_data/DENV/test/20240204_A549_DENV_ZIKV_timelapse/cell_infection_visualization.html)
+
+Note: you may need to press pause/play for the images to render.
+
+## (OPTIONAL) Generating DynaCLR-DENV-VS+Ph Features
+
+1. Open `dynaclr_denv-vs-ph_test_data.yml` and modify the following to point to your downloaded files:
+
+- Replace the output path (`.zarr`) for the embeddings:
+```yaml
+  callbacks:
+    - class_path: viscy.representation.embedding_writer.EmbeddingWriter
+      init_args:
+        output_path: '/TODO_REPLACE_TO_OUTPUT_PATH.zarr' # Select the path to save
+```
+
+- Point to the downloaded checkpoint for DynaCLR-DENV-VS+Ph:
+```yaml
+  ckpt_path: '/downloaded.ckpt' # Point to ckpt file
+```
+
+2. Run inference with the CLI:
+```bash
+viscy predict -c dynaclr_denv-vs-ph_test_data.yml
+```
+
+## (OPTIONAL) Generating ImageNet Features
+
+To generate ImageNet features for your own data, you can use the `imagenet_embeddings.py` script with the example config [here](../../../applications/benchmarking/DynaCLR/ImageNet/config.yml):
+
+1. Modify the `infection_example_config.yml` lines:
+
+```yaml
+paths:
+  data_path: /path/to/downloaded/registered_test.zarr
+  tracks_path: /path/to/downloaded/track_test.zarr
+  output_path: /path/to/output.zarr
+```
+
+2. Run the Python script:
+
+```bash
+# Navigate to the ImageNet scripts directory
cd ../../applications/benchmarking/DynaCLR/ImageNet
+
+# Run the script with the example config
+python imagenet_embeddings.py -c infection_example_config.yml
+```
\ No newline at end of file
diff --git a/examples/DynaCLR/DynaCLR-DENV-VS-Ph/demo_infection.py b/examples/DynaCLR/DynaCLR-DENV-VS-Ph/demo_infection.py
new file mode 100644
index 000000000..986953672
--- /dev/null
+++ b/examples/DynaCLR/DynaCLR-DENV-VS-Ph/demo_infection.py
@@ -0,0 +1,244 @@
+# %% [markdown]
+# # Demo: Comparing DynaCLR vs ImageNet Embeddings for Cell Infection Analysis
+#
+# This tutorial demonstrates how to:
+# 1. Use ImageNet pre-trained features for analyzing cell infection
+# 2. Compare with DynaCLR learned features
+# 3. 
Visualize the differences between approaches + +# %% [markdown] +# ## Setup and Imports + +# %% +from pathlib import Path + +import numpy as np +import pandas as pd +from skimage.exposure import rescale_intensity + +from utils import ( + create_combined_visualization, +) +from viscy.data.triplet import TripletDataModule +from viscy.representation.embedding_writer import read_embedding_dataset + +# %% [markdown] +# ## Set Data Paths +# +# The data, tracks, annotations and precomputed embeddings can be downloaded from [here]() +# +# ## Note: +# +# Alternatively, you can run the CLI to compute the features yourself by following the instructions in the [README.md](./README.md) + +# %% +# TODO: Update the paths to the downloaded data +# Point to the *.zarr files +download_root = Path.home() / "data/dynaclr/demo" +input_data_path = ( + download_root / "registered_test.zarr" +) # Replace with path to registered_test.zarr +tracks_path = download_root / "track_test.zarr" # Replace with path to track_test.zarr +ann_path = ( + download_root / "extracted_inf_state.csv" +) # Replace with path to extracted_inf_state.csv + +# TODO: Update the path to the DynaCLR and ImageNet features +# Point to the precomputed embeddings +dynaclr_features_path = ( + download_root / "precomputed_embeddings/infection_160patch_94ckpt_rev6_dynaclr.zarr" +) +imagenet_features_path = ( + download_root + / "precomputed_embeddings/20240204_A549_DENV_ZIKV_sensor_only_imagenet.zarr" +) + +# %% [markdown] +# ## Load the embeddings and annotations +# Load the embeddings you downloaded and append the human annotations to the dataframe + +# %% +# Load the embeddings +dynaclr_embeddings = read_embedding_dataset(dynaclr_features_path) +imagenet_embeddings = read_embedding_dataset(imagenet_features_path) + +dynaclr_features_df = dynaclr_embeddings["sample"].to_dataframe().reset_index(drop=True) +imagenet_features_df = ( + imagenet_embeddings["sample"].to_dataframe().reset_index(drop=True) +) + +# Load the annotations and create a dataframe with the infection state +annotation = pd.read_csv(ann_path) +annotation["fov_name"] = "/" + annotation["fov_name"] + +imagenet_features_df["infection"] = float("nan") + +for index, row in annotation.iterrows(): + mask = ( + (imagenet_features_df["fov_name"] == row["fov_name"]) + & (imagenet_features_df["track_id"] == row["track_id"]) + & (imagenet_features_df["t"] == row["t"]) + ) + imagenet_features_df.loc[mask, "infection"] = row["infection_state"] + mask = ( + (dynaclr_features_df["fov_name"] == row["fov_name"]) + & (dynaclr_features_df["track_id"] == row["track_id"]) + & (dynaclr_features_df["t"] == row["t"]) + ) + dynaclr_features_df.loc[mask, "infection"] = row["infection_state"] + +# Filter out rows with infection state 0 +imagenet_features_df = imagenet_features_df[imagenet_features_df["infection"] != 0] +dynaclr_features_df = dynaclr_features_df[dynaclr_features_df["infection"] != 0] + +# %% [markdown] +# ## Choose a representative track for visualization + +# %% +# NOTE: We have chosen these tracks to be representative of the data. 
Feel free to open the dataset and select other tracks +fov_name_mock = "/A/3/9" +track_id_mock = [19] +fov_name_inf = "/B/4/9" +track_id_inf = [42] + +# Default parameters for the test dataset +z_range = (24, 29) +yx_patch_size = (160, 160) + +channels_to_display = ["Phase3D", "RFP"] +fov_name_mock_list = [fov_name_mock] * len(track_id_mock) +fov_name_inf_list = [fov_name_inf] * len(track_id_inf) + +conditions_to_compare = { + "uninfected": { + "fov_name_list": fov_name_mock_list, + "track_id_list": track_id_mock, + }, + "infected": { + "fov_name_list": fov_name_inf_list, + "track_id_list": track_id_inf, + }, +} + +print("Caching sample images...") +image_cache = {} +for condition, condition_data in conditions_to_compare.items(): + dm = TripletDataModule( + data_path=input_data_path, + tracks_path=tracks_path, + source_channel=channels_to_display, + z_range=z_range, + initial_yx_patch_size=yx_patch_size, + final_yx_patch_size=yx_patch_size, + include_fov_names=condition_data["fov_name_list"] + * len(condition_data["track_id_list"]), + include_track_ids=condition_data["track_id_list"], + predict_cells=True, + batch_size=1, + ) + dm.setup("predict") + + condition_key = f"{condition}_cache" + image_cache[condition_key] = { + "fov_name": None, + "track_id": None, + "images_by_timepoint": {}, + } + for i, patch in enumerate(dm.predict_dataloader()): + fov_name = patch["index"]["fov_name"][0] + track_id = patch["index"]["track_id"][0] + images = patch["anchor"].numpy()[0] + t = int(patch["index"]["t"][0]) + + if image_cache[condition_key]["fov_name"] is None: + image_cache[condition_key]["fov_name"] = fov_name + image_cache[condition_key]["track_id"] = track_id + + z_idx = images.shape[1] // 2 + C, Z, Y, X = images.shape + image_out = np.zeros((C, 1, Y, X), dtype=np.float32) + # NOTE: here we are using the default percentile range for the RFP channel, change if using different channels or this threshold does not work + for c_idx, channel in enumerate(channels_to_display): + if channel in ["Phase3D", "DIC", "BF"]: + image_out[c_idx] = images[c_idx, z_idx] + image_out[c_idx] = ( + image_out[c_idx] - image_out[c_idx].mean() + ) / image_out[c_idx].std() + image_out[c_idx] = rescale_intensity(image_out[c_idx], out_range=(0, 1)) + else: + image_out[c_idx] = np.max(images[c_idx], axis=0) + lower, upper = np.percentile(image_out[c_idx], (50, 99)) + image_out[c_idx] = (image_out[c_idx] - lower) / (upper - lower) + image_out[c_idx] = rescale_intensity(image_out[c_idx], out_range=(0, 1)) + + image_cache[condition_key]["images_by_timepoint"][t] = image_out + + print( + f"Cached {condition_key} with {len(image_cache[condition_key]['images_by_timepoint'])} timepoints" + ) + +# %% +print("Creating Cell Images and PHATE Embeddings Visualization...") +create_combined_visualization( + image_cache, + imagenet_features_df, + dynaclr_features_df, + highlight_tracks={ + 1: [(fov_name_mock, track_id_mock[0])], # Uninfected tracks + 2: [(fov_name_inf, track_id_inf[0])], # Infected tracks + }, + subplot_titles=[ + "Uninfected Phase", + "Uninfected Viral Sensor", + "Infected Phase", + "Infected Viral Sensor", + ], + condition_keys=["uninfected_cache", "infected_cache"], + channel_colormaps=["gray", "magma"], + category_colors={1: "cornflowerblue", 2: "salmon"}, + highlight_colors={1: "blue", 2: "red"}, + category_labels={1: "Uninfected", 2: "Infected"}, + plot_size_xy=(1200, 600), + title_location="top", +) + +# Save the visualization as an interactive HTML file +fig = create_combined_visualization( + image_cache, + 
imagenet_features_df,
+    dynaclr_features_df,
+    highlight_tracks={
+        1: [(fov_name_mock, track_id_mock[0])],  # Uninfected tracks
+        2: [(fov_name_inf, track_id_inf[0])],  # Infected tracks
+    },
+    subplot_titles=[
+        "Uninfected Phase",
+        "Uninfected Viral Sensor",
+        "Infected Phase",
+        "Infected Viral Sensor",
+    ],
+    condition_keys=["uninfected_cache", "infected_cache"],
+    channel_colormaps=["gray", "magma"],
+    category_colors={1: "cornflowerblue", 2: "salmon"},
+    highlight_colors={1: "blue", 2: "red"},
+    category_labels={1: "Uninfected", 2: "Infected"},
+    plot_size_xy=(1200, 600),
+    title_location="top",
+)
+
+# Create output directory if it doesn't exist
+output_dir = Path("output")
+output_dir.mkdir(exist_ok=True)
+
+# Save the interactive visualization
+output_path = output_dir / "cell_infection_visualization.html"
+fig.write_html(str(output_path))
+print(f"Saved interactive visualization to: {output_path}")
+
+# %% [markdown]
+# ## Conclusion
+#
+# Time-aware sampling improves the temporal continuity and dynamic range of the embeddings.
+# These improvements are visible in the PHATE projections: compared to the ImageNet baseline,
+# the DynaCLR embeddings trace smoother trajectories and span a wider dynamic range.
+#
diff --git a/examples/DynaCLR/DynaCLR-DENV-VS-Ph/utils.py b/examples/DynaCLR/DynaCLR-DENV-VS-Ph/utils.py
new file mode 100644
index 000000000..64f7906d9
--- /dev/null
+++ b/examples/DynaCLR/DynaCLR-DENV-VS-Ph/utils.py
@@ -0,0 +1,1240 @@
+"""Utility functions for visualization and analysis."""
+
+import warnings
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from matplotlib import cm
+from skimage.exposure import rescale_intensity
+
+
+def add_arrows(df, color, df_coordinates=["PHATE1", "PHATE2"]):
+    """
+    Add arrows to a plot to show the direction of a trajectory.
+
+    Parameters
+    ----------
+    df : pandas.DataFrame
+        DataFrame containing custom coordinates (like PHATE coordinates (PHATE1, PHATE2))
+    color : str
+        Color for the arrows
+    df_coordinates : list, optional
+        Column names for the x and y coordinates, by default ["PHATE1", "PHATE2"]
+    """
+    from matplotlib.patches import FancyArrowPatch
+
+    for i in range(df.shape[0] - 1):
+        start = df.iloc[i]
+        end = df.iloc[i + 1]
+        # arrowstyle "-" draws a plain segment connecting consecutive timepoints
+        arrow = FancyArrowPatch(
+            (start[df_coordinates[0]], start[df_coordinates[1]]),
+            (end[df_coordinates[0]], end[df_coordinates[1]]),
+            color=color,
+            arrowstyle="-",
+            mutation_scale=10,
+            lw=1,
+            shrinkA=0,
+            shrinkB=0,
+        )
+        plt.gca().add_patch(arrow)
+
+
+def plot_phate_time_trajectories(
+    df,
+    output_dir="./phate_timeseries",
+    highlight_tracks=None,
+):
+    """
+    Generate a series of PHATE embedding plots for each timepoint, showing trajectories.
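+
+    One PNG is saved to output_dir per timepoint; only the first frame is shown inline.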
+ + Parameters + ---------- + df : pandas.DataFrame + DataFrame containing the PHATE embeddings + output_dir : str, optional + Directory to save the PNG files, by default "./phate_timeseries" + highlight_tracks : dict, optional + Dictionary specifying tracks to highlight, by default None + """ + import os + + import matplotlib.pyplot as plt + from matplotlib.lines import Line2D + + if highlight_tracks is None: + # Default tracks to highlight + highlight_tracks = { + "infected": [("/B/4/9", 42)], + "uninfected": [("/A/3/9", 19)], + } + + os.makedirs(output_dir, exist_ok=True) + + # Get unique time points + all_times = sorted(df["t"].unique()) + + # Calculate global axis limits to keep them consistent + padding = 0.1 # Add padding to the limits for better visualization + x_min = df["PHATE1"].min() - padding * (df["PHATE1"].max() - df["PHATE1"].min()) + x_max = df["PHATE1"].max() + padding * (df["PHATE1"].max() - df["PHATE1"].min()) + y_min = df["PHATE2"].min() - padding * (df["PHATE2"].max() - df["PHATE2"].min()) + y_max = df["PHATE2"].max() + padding * (df["PHATE2"].max() - df["PHATE2"].min()) + + # Make sure the aspect ratio is 1:1 by using the same range for both axes + x_range = x_max - x_min + y_range = y_max - y_min + if x_range > y_range: + # Expand y-limits to match x-range + center = (y_max + y_min) / 2 + y_min = center - x_range / 2 + y_max = center + x_range / 2 + else: + # Expand x-limits to match y-range + center = (x_max + x_min) / 2 + x_min = center - y_range / 2 + x_max = center + y_range / 2 + + # Generate plots for each time step + for t_idx, t in enumerate(all_times): + plt.close("all") + fig, ax = plt.figure(figsize=(10, 10)), plt.subplot(111) + + # Plot historical points in gray (all points from previous time steps) + if t_idx > 0: + historical_df = df[df["t"] < t] + ax.scatter( + historical_df["PHATE1"], + historical_df["PHATE2"], + c="lightgray", + s=10, + alpha=0.15, + ) + + # Plot current time points + current_df = df[df["t"] == t] + + # Plot infected vs uninfected points for current time + for infection_state, color in [(1, "cornflowerblue"), (2, "salmon")]: + points = current_df[current_df["infection"] == infection_state] + ax.scatter(points["PHATE1"], points["PHATE2"], c=color, s=30, alpha=0.7) + + # Add track trajectories for highlighted cells + for label, track_list in highlight_tracks.items(): + for fov_name, track_id in track_list: + # Get all timepoints up to current time for this track + track_data = df[ + (df["fov_name"] == fov_name) + & (df["track_id"] == track_id) + & (df["t"] <= t) + ].sort_values("t") + + if len(track_data) > 0: + # Draw trajectory using arrows + color = "red" if label == "infected" else "blue" + + if len(track_data) > 1: + # Use the arrow function that works with PHATE1/PHATE2 columns + add_arrows( + track_data, color, df_coordinates=["PHATE1", "PHATE2"] + ) + + # Mark current position with a larger point + current_pos = track_data[track_data["t"] == t] + if len(current_pos) > 0: + ax.scatter( + current_pos["PHATE1"], + current_pos["PHATE2"], + s=150, + color=color, + edgecolor="black", + linewidth=1.5, + zorder=10, + ) + + # Set the same axis limits for all frames + ax.set_xlim(x_min, x_max) + ax.set_ylim(y_min, y_max) + + # Add legend + legend_elements = [ + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor="blue", + markersize=8, + label="Uninfected", + ), + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor="red", + markersize=8, + label="Infected", + ), + Line2D( + [0], + [0], + marker="o", + 
color="w", + markerfacecolor="blue", + markersize=12, + markeredgecolor="black", + label="Highlighted Uninfected Track", + ), + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor="red", + markersize=12, + markeredgecolor="black", + label="Highlighted Infected Track", + ), + ] + ax.legend(handles=legend_elements, loc="upper right") + + # Add labels and title with time info + ax.set_title(f"ImageNet PHATE Embedding - Time: {t}") + ax.set_xlabel("PHATE1") + ax.set_ylabel("PHATE2") + + # Set equal aspect ratio for better visualization + ax.set_aspect("equal") + + # Save figure + plt.tight_layout() + plt.savefig( + f"{output_dir}/phate_embedding_t{t:03d}.png", dpi=300, bbox_inches="tight" + ) + + # Only show the first frame in the notebook + if t == all_times[0]: + plt.show() + + +def create_plotly_visualization( + df, + highlight_tracks=None, + df_coordinates=["PHATE1", "PHATE2"], + time_column="t", + category_column="infection", + category_labels={1: "Uninfected", 2: "Infected"}, + category_colors={1: "cornflowerblue", 2: "salmon"}, + highlight_colors={1: "blue", 2: "red"}, + title_prefix="PHATE Embedding", + plot_size_xy=(1000, 800), +): + """ + Create an interactive visualization using Plotly with a time slider. + + Parameters + ---------- + df : pandas.DataFrame + DataFrame containing the embedding coordinates + highlight_tracks : dict, optional + Dictionary specifying tracks to highlight, by default None + Format: {category_name: [(fov_name, track_id), ...]} + e.g., {"infected": [("/B/4/9", 42)], "uninfected": [("/A/3/9", 19)]} + or {1: [("/A/3/9", 19)], 2: [("/B/4/9", 42)]} where 1=uninfected, 2=infected + df_coordinates : list, optional + Column names for the x and y coordinates, by default ["PHATE1", "PHATE2"] + time_column : str, optional + Column name for the time points, by default "t" + category_column : str, optional + Column name for the category to color by, by default "infection" + category_labels : dict, optional + Mapping from category values to display labels, by default {1: "Uninfected", 2: "Infected"} + category_colors : dict, optional + Mapping from category values to colors for markers, by default {1: "cornflowerblue", 2: "salmon"} + highlight_colors : dict, optional + Mapping from category values to colors for highlighted tracks, by default {1: "blue", 2: "red"} + title_prefix : str, optional + Prefix for the plot title, by default "PHATE Embedding" + plot_size_xy : tuple, optional + Width and height of the plot in pixels, by default (1000, 800) + + Returns + ------- + plotly.graph_objects.Figure + The interactive Plotly figure + """ + # Check if plotly is available + try: + import plotly.graph_objects as go + except ImportError: + print("Plotly is not installed. 
Please install it using: pip install plotly") + return None + + highlight_track_map = {} + category_value_map = {"uninfected": 1, "infected": 2} + for key, tracks in highlight_tracks.items(): + # If the key is a string like "infected" or "uninfected", convert to category value + if isinstance(key, str) and key.lower() in category_value_map: + category = category_value_map[key.lower()] + else: + # Otherwise use the key directly (assumed to be a category value) + category = key + highlight_track_map[category] = tracks + + # Get unique time points and categories + all_times = sorted(df[time_column].unique()) + categories = sorted(df[category_column].unique()) + + # Calculate global axis limits + padding = 0.1 + x_min = df[df_coordinates[0]].min() - padding * ( + df[df_coordinates[0]].max() - df[df_coordinates[0]].min() + ) + x_max = df[df_coordinates[0]].max() + padding * ( + df[df_coordinates[0]].max() - df[df_coordinates[0]].min() + ) + y_min = df[df_coordinates[1]].min() - padding * ( + df[df_coordinates[1]].max() - df[df_coordinates[1]].min() + ) + y_max = df[df_coordinates[1]].max() + padding * ( + df[df_coordinates[1]].max() - df[df_coordinates[1]].min() + ) + + # Make sure the aspect ratio is 1:1 + x_range = x_max - x_min + y_range = y_max - y_min + if x_range > y_range: + center = (y_max + y_min) / 2 + y_min = center - x_range / 2 + y_max = center + x_range / 2 + else: + center = (x_max + x_min) / 2 + x_min = center - y_range / 2 + x_max = center + y_range / 2 + + # Pre-compute all track data to ensure consistency across frames + track_data_cache = {} + for category, track_list in highlight_track_map.items(): + for idx, (fov_name, track_id) in enumerate(track_list): + track_key = f"{category}_{fov_name}_{track_id}" + print(f"Processing track: {track_key}") + # Get all data for this track + full_track_data = df[ + (df["fov_name"] == fov_name) & (df["track_id"] == track_id) + ].sort_values(time_column) + + print(f"Found {len(full_track_data)} points for track {track_key}") + if len(full_track_data) > 0: + track_data_cache[track_key] = full_track_data + print( + f"Time points for {track_key}: {sorted(full_track_data[time_column].unique())}" + ) + else: + print(f"WARNING: No data found for track {track_key}") + + print(f"Track data cache keys: {list(track_data_cache.keys())}") + + # Prepare data for all frames of the animation + frames = [] + + # Create traces for each time point + for t_idx, t in enumerate(all_times): + frame_data = [] + + # Historical data trace (all points from previous timepoints) + if t_idx > 0: + historical_df = df[df[time_column] < t] + frame_data.append( + go.Scatter( + x=historical_df[df_coordinates[0]], + y=historical_df[df_coordinates[1]], + mode="markers", + marker=dict(color="lightgray", size=5, opacity=0.2), + name="Historical", + hoverinfo="none", + showlegend=False, + ) + ) + else: + # Empty trace as placeholder + frame_data.append( + go.Scatter( + x=[], y=[], mode="markers", name="Historical", showlegend=False + ) + ) + + # Current time data + current_df = df[df[time_column] == t] + + # Plot each category + for category in categories: + category_points = current_df[current_df[category_column] == category] + if len(category_points) > 0: + frame_data.append( + go.Scatter( + x=category_points[df_coordinates[0]], + y=category_points[df_coordinates[1]], + mode="markers", + marker=dict( + color=category_colors.get(category, "gray"), + size=8, + opacity=0.7, + ), + name=category_labels.get(category, f"Category {category}"), + hovertext=[ + f"FOV: 
{row['fov_name']}, Track: {row['track_id']}, {category_labels.get(category, f'Category {category}')}" + for _, row in category_points.iterrows() + ], + hoverinfo="text", + showlegend=False, # Never show legend + ) + ) + else: + frame_data.append( + go.Scatter( + x=[], + y=[], + mode="markers", + name=category_labels.get(category, f"Category {category}"), + showlegend=False, # Never show legend + ) + ) + + # Add highlighted tracks + for category, track_list in highlight_track_map.items(): + for idx, (fov_name, track_id) in enumerate(track_list): + track_key = f"{category}_{fov_name}_{track_id}" + + if track_key in track_data_cache: + # Get the full track data from cache + full_track_data = track_data_cache[track_key] + + # Filter for data up to current time for trajectory + track_data = full_track_data[full_track_data[time_column] <= t] + + if len(track_data) > 0: + color = highlight_colors.get(category, "gray") + label = category_labels.get(category, f"Category {category}") + + # Create single line trace for the entire trajectory + frame_data.append( + go.Scatter( + x=track_data[df_coordinates[0]], + y=track_data[df_coordinates[1]], + mode="lines", + line=dict(color=color, width=2), + name=f"Track {track_id} ({label})", + showlegend=False, # Never show legend + ) + ) + + # Add current position marker + current_pos = track_data[track_data[time_column] == t] + + # If no data at current time but we have track data, show the last known position + if len(current_pos) == 0: + # Get the most recent position before current timepoint + latest_pos = track_data.iloc[-1:] + + if t_idx == 0: + print( + f"No current position for {track_key} at time {t}, using last known at {latest_pos[time_column].iloc[0]}" + ) + + # Add a semi-transparent marker at the last known position + frame_data.append( + go.Scatter( + x=latest_pos[df_coordinates[0]], + y=latest_pos[df_coordinates[1]], + mode="markers", + marker=dict( + color=color, + size=15, + line=dict(color="black", width=1), + opacity=0.5, # Lower opacity for non-current positions + ), + name=f"Last Known Position - {label}", + hovertext=[ + f"FOV: {row['fov_name']}, Track: {row['track_id']}, Last Seen at t={row[time_column]}, {label}" + for _, row in latest_pos.iterrows() + ], + hoverinfo="text", + showlegend=False, + ) + ) + else: + # Normal case - we have data at current timepoint + if t_idx == 0: + print( + f"Found current position for {track_key} at time {t}" + ) + + frame_data.append( + go.Scatter( + x=current_pos[df_coordinates[0]], + y=current_pos[df_coordinates[1]], + mode="markers", + marker=dict( + color=color, + size=15, + line=dict(color="black", width=1), + ), + name=f"Highlighted {label}", + hovertext=[ + f"FOV: {row['fov_name']}, Track: {row['track_id']}, Highlighted {label}" + for _, row in current_pos.iterrows() + ], + hoverinfo="text", + showlegend=False, # Never show legend + ) + ) + + # Create a frame for this time point + frames.append(go.Frame(data=frame_data, name=str(t))) + + # Create the base figure with the first frame data + fig = go.Figure( + data=frames[0].data, + frames=frames, + layout=go.Layout( + title=f"{title_prefix} - Time: {all_times[0]}", + xaxis=dict(title=df_coordinates[0], range=[x_min, x_max]), + yaxis=dict( + title=df_coordinates[1], + range=[y_min, y_max], + scaleanchor="x", # Make it 1:1 aspect ratio + scaleratio=1, + ), + updatemenus=[ + { + "type": "buttons", + "direction": "right", + "x": 0.15, + "y": 0, + "buttons": [ + { + "label": "Play", + "method": "animate", + "args": [ + None, + { + "frame": 
{"duration": 500, "redraw": True}, + "fromcurrent": True, + "transition": {"duration": 0}, + }, + ], + }, + { + "label": "Pause", + "method": "animate", + "args": [ + [None], + { + "frame": {"duration": 0, "redraw": False}, + "mode": "immediate", + "transition": {"duration": 0}, + }, + ], + }, + ], + } + ], + sliders=[ + { + "active": 0, + "yanchor": "top", + "xanchor": "left", + "currentvalue": { + "font": {"size": 16}, + "prefix": "Time: ", + "visible": True, + "xanchor": "right", + }, + "transition": {"duration": 0}, + "pad": {"b": 10, "t": 50}, + "len": 0.9, + "x": 0.1, + "y": 0, + "steps": [ + { + "args": [ + [str(t)], + { + "frame": {"duration": 0, "redraw": True}, + "mode": "immediate", + "transition": {"duration": 0}, + }, + ], + "label": str(t), + "method": "animate", + } + for t in all_times + ], + } + ], + legend=dict( + orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 + ), + ), + ) + + # Update figure layout + fig.update_layout( + width=plot_size_xy[0], + height=plot_size_xy[1], + margin=dict(l=50, r=50, t=100, b=100), + template="plotly_white", + ) + return fig + + +def create_image_visualization( + image_cache, + subplot_titles=["Mock Phase", "Mock RFP", "Infected Phase", "Infected RFP"], + condition_keys=["uinfected_cache", "infected_cache"], + channel_colormaps=["gray", "magma"], + plot_size_xy=(1000, 800), + horizontal_spacing=0.05, + vertical_spacing=0.1, +): + """ + Create an interactive visualization of images from image cache using Plotly with a time slider. + + Parameters + ---------- + image_cache : dict + Dictionary containing cached images by condition and timepoint + Format: {"condition_key": {"images_by_timepoint": {t: image_array}}} + subplot_titles : list, optional + Titles for the subplots, by default ["Mock Phase", "Mock RFP", "Infected Phase", "Infected RFP"] + condition_keys : list, optional + Keys for the conditions in the image_cache, by default ["uinfected_cache", "infected_cache"] + channel_colormaps : list, optional + Colormaps for each channel, by default ["gray", "magma"] + plot_size_xy : tuple, optional + Width and height of the plot in pixels, by default (1000, 800) + horizontal_spacing : float, optional + Horizontal spacing between subplots, by default 0.05 + vertical_spacing : float, optional + Vertical spacing between subplots, by default 0.1 + + Returns + ------- + plotly.graph_objects.Figure + The interactive Plotly figure + """ + # Check if plotly is available + try: + import plotly.graph_objects as go + from plotly.subplots import make_subplots + except ImportError: + print("Plotly is not installed. 
Please install it using: pip install plotly")
+        return None
+
+    # Get all available timepoints from all conditions
+    all_timepoints = []
+    for condition_key in condition_keys:
+        if (
+            condition_key in image_cache
+            and "images_by_timepoint" in image_cache[condition_key]
+        ):
+            all_timepoints.extend(
+                list(image_cache[condition_key]["images_by_timepoint"].keys())
+            )
+
+    all_timepoints = sorted(list(set(all_timepoints)))
+    print(f"All timepoints: {all_timepoints}")
+
+    if not all_timepoints:
+        print("No timepoints found in the image cache")
+        return None
+
+    # Create the figure with subplots
+    fig = make_subplots(
+        rows=len(condition_keys),
+        cols=len(channel_colormaps),
+        subplot_titles=subplot_titles,
+        horizontal_spacing=horizontal_spacing,
+        vertical_spacing=vertical_spacing,
+    )
+
+    # Create initial frame
+    t_initial = all_timepoints[0]
+
+    # Add each condition as a row
+    for row_idx, condition_key in enumerate(condition_keys, 1):
+        if (
+            condition_key in image_cache
+            and t_initial in image_cache[condition_key]["images_by_timepoint"]
+        ):
+            img = image_cache[condition_key]["images_by_timepoint"][t_initial]
+
+            # Add each channel as a column
+            for col_idx, colormap in enumerate(channel_colormaps, 1):
+                if col_idx <= img.shape[0]:  # Make sure we have this channel
+                    cmap = cm.get_cmap(colormap)
+                    # Index the channel without clobbering `img`
+                    # (enumerate starts at 1, so the channel axis index is col_idx - 1)
+                    channel_img = img[col_idx - 1, 0]
+                    colored_img = cmap(channel_img)
+
+                    # Convert to RGB format (remove alpha channel)
+                    colored_img = (colored_img[:, :, :3] * 255).astype(np.uint8)
+
+                    fig.add_trace(
+                        go.Image(
+                            z=colored_img,
+                            x0=0,
+                            y0=0,
+                            dx=1,
+                            dy=1,
+                            colormodel="rgb",
+                        ),
+                        row=row_idx,
+                        col=col_idx,
+                    )
+                else:
+                    # Empty placeholder if channel doesn't exist
+                    fig.add_trace(
+                        go.Image(
+                            z=np.zeros((10, 10, 3)),
+                            colormodel="rgb",
+                            x0=0,
+                            y0=0,
+                            dx=1,
+                            dy=1,
+                        ),
+                        row=row_idx,
+                        col=col_idx,
+                    )
+        else:
+            # Empty placeholders if condition or timepoint not found
+            for col_idx, colormap in enumerate(channel_colormaps, 1):
+                fig.add_trace(
+                    go.Image(
+                        z=np.zeros((10, 10, 3)),
+                        colormodel="rgb",
+                        x0=0,
+                        y0=0,
+                        dx=1,
+                        dy=1,
+                    ),
+                    row=row_idx,
+                    col=col_idx,
+                )
+
+    # Function to create a frame for a specific timepoint
+    def create_frame_for_timepoint(t):
+        frame_data = []
+
+        for condition_key in condition_keys:
+            if (
+                condition_key in image_cache
+                and t in image_cache[condition_key]["images_by_timepoint"]
+            ):
+                img = image_cache[condition_key]["images_by_timepoint"][t]
+
+                for col_idx, colormap in enumerate(channel_colormaps):
+                    if col_idx < img.shape[0]:  # Make sure we have this channel
+                        cmap = cm.get_cmap(colormap)
+                        # Index the channel without clobbering `img`
+                        channel_img = img[col_idx, 0]
+                        colored_img = cmap(channel_img)
+
+                        # Convert to RGB format (remove alpha channel)
+                        colored_img = (colored_img[:, :, :3] * 255).astype(np.uint8)
+
+                        frame_data.append(
+                            go.Image(
+                                z=colored_img,
+                                colormodel="rgb",
+                                x0=0,
+                                y0=0,
+                                dx=1,
+                                dy=1,
+                            )
+                        )
+                    else:
+                        # Empty placeholder
+                        frame_data.append(
+                            go.Image(
+                                z=np.zeros((10, 10, 3)),
+                                colormodel="rgb",
+                                x0=0,
+                                y0=0,
+                                dx=1,
+                                dy=1,
+                            )
+                        )
+            else:
+                # Empty placeholders if condition not found
+                for _ in channel_colormaps:
+                    frame_data.append(
+                        go.Image(
+                            z=np.zeros((10, 10, 3)),
+                            colormodel="rgb",
+                            x0=0,
+                            y0=0,
+                            dx=1,
+                            dy=1,
+                        )
+                    )
+
+        # Create trace indices for updating the correct traces in each frame
+        trace_indices = list(range(len(condition_keys) * len(channel_colormaps)))
+        return go.Frame(data=frame_data, name=str(t), traces=trace_indices)
+
+    # Create frames for the slider
+    frames = 
[create_frame_for_timepoint(t) for t in all_timepoints] + fig.frames = frames + + # Update layout + fig.update_layout( + title=f"Cell Images - Time: {t_initial}", + height=plot_size_xy[1], + width=plot_size_xy[0], + sliders=[ + { + "active": 0, + "yanchor": "top", + "xanchor": "left", + "currentvalue": { + "font": {"size": 16}, + "prefix": "Time: ", + "visible": True, + "xanchor": "right", + }, + "transition": {"duration": 0}, + "pad": {"b": 10, "t": 50}, + "len": 0.9, + "x": 0.1, + "y": 0, + "steps": [ + { + "args": [ + [str(t)], + { + "frame": {"duration": 0, "redraw": True}, + "mode": "immediate", + "transition": {"duration": 0}, + }, + ], + "label": str(t), + "method": "animate", + } + for t in all_timepoints + ], + } + ], + ) + + # Update axes to hide ticks and labels + for row in range(1, len(condition_keys) + 1): + for col in range(1, len(channel_colormaps) + 1): + fig.update_xaxes( + showticklabels=False, showgrid=False, zeroline=False, row=row, col=col + ) + fig.update_yaxes( + showticklabels=False, showgrid=False, zeroline=False, row=row, col=col + ) + + return fig + + +def create_combined_visualization( + image_cache, + imagenet_df: pd.DataFrame, + dynaclr_df: pd.DataFrame, + highlight_tracks: dict, + subplot_titles=[ + "Uninfected Phase", + "Uninfected Viral Sensor", + "Infected Phase", + "Infected Viral Sensor", + ], + condition_keys=["uninfected_cache", "infected_cache"], + channel_colormaps=["gray", "magma"], + category_colors={1: "cornflowerblue", 2: "salmon"}, + highlight_colors={1: "blue", 2: "red"}, + category_labels={1: "Uninfected", 2: "Infected"}, + plot_size_xy=(1800, 600), + title_location="inside", +): + """ + Creates a combined visualization with cell images and PHATE embeddings with a shared time slider. + All plots are arranged side by side in one row. + + Parameters + ---------- + image_cache : dict + Image cache dictionary with cell images + imagenet_df : pandas.DataFrame + DataFrame with ImageNet PHATE embeddings + dynaclr_df : pandas.DataFrame + DataFrame with DynaCLR PHATE embeddings + highlight_tracks : dict + Dictionary of tracks to highlight in PHATE plots + subplot_titles : list + Titles for the image subplots + condition_keys : list + Keys for conditions in image cache + channel_colormaps : list + Colormaps for image channels + category_colors, highlight_colors, category_labels : dict + Visual configuration for PHATE plots + plot_size_xy : tuple + Width and height of the plot + title_location : str + Location of subplot titles. 
Either "inside" (default) or "top" + + Returns + ------- + plotly.graph_objects.Figure + Combined interactive figure + """ + import plotly.graph_objects as go + from plotly.subplots import make_subplots + + all_timepoints_images = set() + for condition_key in condition_keys: + if ( + condition_key in image_cache + and "images_by_timepoint" in image_cache[condition_key] + ): + all_timepoints_images.update( + image_cache[condition_key]["images_by_timepoint"].keys() + ) + + all_timepoints_imagenet = set(imagenet_df["t"].unique()) + all_timepoints_dynaclr = set(dynaclr_df["t"].unique()) + + all_timepoints = sorted( + list( + all_timepoints_images.intersection( + all_timepoints_imagenet, all_timepoints_dynaclr + ) + ) + ) + + if not all_timepoints: + print("No common timepoints found across all datasets") + all_timepoints = sorted( + list( + all_timepoints_images.union( + all_timepoints_imagenet, all_timepoints_dynaclr + ) + ) + ) + + def create_phate_traces( + df: pd.DataFrame, t: int, df_coordinates: list[str] = ["PHATE1", "PHATE2"] + ): + """Creates PHATE plot traces for a specific timepoint""" + traces = [] + + historical_df = df[df["t"] < t] + if len(historical_df) > 0: + traces.append( + go.Scatter( + x=historical_df[df_coordinates[0]], + y=historical_df[df_coordinates[1]], + mode="markers", + marker=dict(color="lightgray", size=5, opacity=0.2), + name="Historical", + hoverinfo="none", + showlegend=False, + ) + ) + else: + traces.append(go.Scatter(x=[], y=[], mode="markers", showlegend=False)) + + current_df = df[df["t"] == t] + categories = sorted(df["infection"].unique()) + + for category in categories: + category_points = current_df[current_df["infection"] == category] + if len(category_points) > 0: + traces.append( + go.Scatter( + x=category_points[df_coordinates[0]], + y=category_points[df_coordinates[1]], + mode="markers", + marker=dict( + color=category_colors.get(category, "gray"), + size=8, + opacity=0.7, + ), + name=category_labels.get(category, f"Category {category}"), + hovertext=[ + f"FOV: {row['fov_name']}, Track: {row['track_id']}, {category_labels.get(category, f'Category {category}')}" + for _, row in category_points.iterrows() + ], + hoverinfo="text", + showlegend=False, + ) + ) + else: + traces.append(go.Scatter(x=[], y=[], mode="markers", showlegend=False)) + + for category, track_list in highlight_tracks.items(): + for fov_name, track_id in track_list: + track_data = df[ + (df["fov_name"] == fov_name) + & (df["track_id"] == track_id) + & (df["t"] <= t) + ].sort_values("t") + + if len(track_data) > 0: + color = highlight_colors.get(category, "gray") + + traces.append( + go.Scatter( + x=track_data[df_coordinates[0]], + y=track_data[df_coordinates[1]], + mode="lines", + line=dict(color=color, width=2), + showlegend=False, + ) + ) + + current_pos = track_data[track_data["t"] == t] + if len(current_pos) == 0: + latest_pos = track_data.iloc[-1:] + opacity = 0.5 + else: + latest_pos = current_pos + opacity = 1.0 + + traces.append( + go.Scatter( + x=latest_pos[df_coordinates[0]], + y=latest_pos[df_coordinates[1]], + mode="markers", + marker=dict( + color=color, + size=15, + line=dict(color="black", width=1), + opacity=opacity, + ), + hovertext=[ + f"FOV: {row['fov_name']}, Track: {row['track_id']}, t={row['t']}" + for _, row in latest_pos.iterrows() + ], + hoverinfo="text", + showlegend=False, + ) + ) + + return traces + + def get_phate_limits(df, df_coordinates=["PHATE1", "PHATE2"]): + padding = 0.1 + x_min = df[df_coordinates[0]].min() - padding * ( + 
df[df_coordinates[0]].max() - df[df_coordinates[0]].min() + ) + x_max = df[df_coordinates[0]].max() + padding * ( + df[df_coordinates[0]].max() - df[df_coordinates[0]].min() + ) + y_min = df[df_coordinates[1]].min() - padding * ( + df[df_coordinates[1]].max() - df[df_coordinates[1]].min() + ) + y_max = df[df_coordinates[1]].max() + padding * ( + df[df_coordinates[1]].max() - df[df_coordinates[1]].min() + ) + + x_range = x_max - x_min + y_range = y_max - y_min + if x_range > y_range: + center = (y_max + y_min) / 2 + y_min = center - x_range / 2 + y_max = center + x_range / 2 + else: + center = (x_max + x_min) / 2 + x_min = center - y_range / 2 + x_max = center + y_range / 2 + + return x_min, x_max, y_min, y_max + + imagenet_limits = get_phate_limits(imagenet_df) + dynaclr_limits = get_phate_limits(dynaclr_df) + + t_initial = all_timepoints[0] + + main_fig = make_subplots( + rows=1, + cols=3, + column_widths=[0.33, 0.33, 0.33], + subplot_titles=["", "ImageNet PHATE", "DynaCLR PHATE"], + specs=[[{"type": "xy"}, {"type": "xy"}, {"type": "xy"}]], + ) + + def create_cell_image_traces(t): + traces = [] + from matplotlib import cm + + for row_idx, condition_key in enumerate(condition_keys): + if ( + condition_key in image_cache + and t in image_cache[condition_key]["images_by_timepoint"] + ): + img = image_cache[condition_key]["images_by_timepoint"][t] + + for col_idx, colormap in enumerate(channel_colormaps): + if col_idx < img.shape[0]: # Check if channel exists + img_data = img[col_idx, 0] + img_data = rescale_intensity(img_data, out_range=(0, 1)) + + if colormap == "gray": + rgb_img = np.stack([img_data] * 3, axis=-1) + rgb_img = (rgb_img * 255).astype(np.uint8) + else: + cmap = cm.get_cmap(colormap) + colored_img = cmap(img_data) + rgb_img = (colored_img[:, :, :3] * 255).astype(np.uint8) + + x_pos = col_idx * 0.5 + y_pos = 1.0 - row_idx * 0.5 + + x_coords = np.linspace(x_pos, x_pos + 0.45, rgb_img.shape[1]) + y_coords = np.linspace(y_pos - 0.45, y_pos, rgb_img.shape[0]) + + traces.append( + go.Image( + z=rgb_img, + x0=x_coords[0], + y0=y_coords[0], + dx=(x_coords[-1] - x_coords[0]) / rgb_img.shape[1], + dy=(y_coords[-1] - y_coords[0]) / rgb_img.shape[0], + colormodel="rgb", + name=subplot_titles[ + row_idx * len(channel_colormaps) + col_idx + ], + ) + ) + else: + warnings.warn( + f"Channel {col_idx} does not exist in image cache for timepoint {t}" + ) + + return traces + + for trace in create_cell_image_traces(t_initial): + main_fig.add_trace(trace, row=1, col=1) + + for trace in create_phate_traces(imagenet_df, t_initial, ["PHATE1", "PHATE2"]): + main_fig.add_trace(trace, row=1, col=2) + + for trace in create_phate_traces(dynaclr_df, t_initial, ["PHATE1", "PHATE2"]): + main_fig.add_trace(trace, row=1, col=3) + + for i, title in enumerate(subplot_titles): + row = i // 2 + col = i % 2 + + if title_location == "top": + x_pos = col * 0.5 + 0.22 + y_pos = 1 - row * 0.5 + yanchor = "bottom" + font_color = "black" + else: + x_pos = col * 0.5 + 0.22 + y_pos = 1 - row * 0.5 - 0.05 + yanchor = "top" + font_color = "white" + + main_fig.add_annotation( + x=x_pos, + y=y_pos, + text=title, + showarrow=False, + xref="x", + yref="y", + xanchor="center", + yanchor=yanchor, + font=dict(size=10, color=font_color), + row=1, + col=1, + ) + + main_fig.update_xaxes( + range=[0, 1], showticklabels=False, showgrid=False, zeroline=False, row=1, col=1 + ) + main_fig.update_yaxes( + range=[0, 1], showticklabels=False, showgrid=False, zeroline=False, row=1, col=1 + ) + + main_fig.update_xaxes(title="PHATE1", 
range=imagenet_limits[:2], row=1, col=2)
+    main_fig.update_yaxes(
+        title="PHATE2",
+        range=imagenet_limits[2:],
+        scaleanchor="x2",
+        scaleratio=1,
+        row=1,
+        col=2,
+    )
+    main_fig.update_xaxes(title="PHATE1", range=dynaclr_limits[:2], row=1, col=3)
+    main_fig.update_yaxes(
+        title="PHATE2",
+        range=dynaclr_limits[2:],
+        scaleanchor="x3",
+        scaleratio=1,
+        row=1,
+        col=3,
+    )
+
+    main_fig.update_layout(
+        title="Cell Images and PHATE Embeddings",
+        width=plot_size_xy[0],
+        height=plot_size_xy[1],
+        sliders=[
+            {
+                "active": 0,
+                "yanchor": "top",
+                "xanchor": "left",
+                "currentvalue": {
+                    "font": {"size": 16},
+                    "prefix": "Time: ",
+                    "visible": True,
+                    "xanchor": "right",
+                },
+                "transition": {"duration": 0},
+                "pad": {"b": 10, "t": 50},
+                "len": 0.9,
+                "x": 0.1,
+                "y": 0,
+                "steps": [
+                    {
+                        "args": [
+                            [str(t)],
+                            {
+                                "frame": {"duration": 0, "redraw": True},
+                                "mode": "immediate",
+                                "transition": {"duration": 0},
+                                "fromcurrent": False,
+                            },
+                        ],
+                        "label": str(t),
+                        "method": "animate",
+                    }
+                    for t in all_timepoints
+                ],
+            }
+        ],
+    )
+
+    frames = []
+    for t in all_timepoints:
+        frame_data = []
+
+        frame_data.extend(create_cell_image_traces(t))
+
+        frame_data.extend(create_phate_traces(imagenet_df, t, ["PHATE1", "PHATE2"]))
+        frame_data.extend(create_phate_traces(dynaclr_df, t, ["PHATE1", "PHATE2"]))
+
+        frames.append(go.Frame(data=frame_data, name=str(t)))
+
+    main_fig.frames = frames
+
+    main_fig.update_layout(
+        transition={"duration": 0}, updatemenus=[]  # Remove any animation buttons
+    )
+
+    return main_fig
diff --git a/examples/DynaCLR/DynaCLR-classical-sampling/README.md b/examples/DynaCLR/DynaCLR-classical-sampling/README.md
new file mode 100644
index 000000000..db0628117
--- /dev/null
+++ b/examples/DynaCLR/DynaCLR-classical-sampling/README.md
@@ -0,0 +1,44 @@
+# DynaCLR Classical Sampling
+
+This module implements classical triplet sampling for training DynaCLR models by generating pseudo-tracking data from 2D segmentation masks. It processes segmentation data from an HCS OME-Zarr store and creates corresponding tracking CSV files with the following information:
+- Track IDs from segmentation masks
+- Centroid coordinates (t, y, x) for each segmented object per time point
+- Unique IDs for each object
+
+## Prerequisites
+- Input HCS OME-Zarr store containing segmentation masks
+
+## Usage
+
+### 1. Configure Input/Output Paths
+Open `create_pseudo_tracks.py` and modify:
+```python
+# Input path to your segmentation data
+input_data_path = "/path/to/your/input.zarr"
+# Output path for tracking data
+track_data_path = "/path/to/your/output.zarr"
+# Channel name for the segmentations
+segmentation_channel_name = "Nucl_mask"
+# Z-slice to use for 2D tracking
+Z_SLICE = 30
+```
+
+### 2. Run the Script
+```bash
+python create_pseudo_tracks.py
+```
+
+## Processing Steps
+1. Loads segmentation data from the input zarr store
+2. For each well and position:
+   - Processes each timepoint
+   - Extracts the 2D segmentation at the specified z-slice
+   - Calculates centroid coordinates (i.e., (y, x)) for segmented objects
+   - Generates and saves the pseudo-tracking data to CSV files
+3. Creates a new zarr store with the processed data
+
+## Notes
+- Currently only supports 2D segmentation tracking at a single z-slice
+- The z-slice index can be modified in the script
+- Output CSV files are organized by well and position
+- Make sure your zarr stores are properly configured before running the script
\ No newline at end of file
diff --git a/examples/DynaCLR/DynaCLR-classical-sampling/create_pseudo_tracks.py b/examples/DynaCLR/DynaCLR-classical-sampling/create_pseudo_tracks.py
new file mode 100644
index 000000000..b225d17c6
--- /dev/null
+++ b/examples/DynaCLR/DynaCLR-classical-sampling/create_pseudo_tracks.py
@@ -0,0 +1,119 @@
+# %%
+import os
+
+import numpy as np
+import pandas as pd
+from iohub.ngff import open_ome_zarr
+from iohub.ngff.utils import create_empty_plate
+from tqdm import tqdm
+
+# %% create training and validation dataset
+# TODO: Modify path to the input data
+input_data_path = "/training_data.zarr"
+# TODO: Modify path to the output data
+track_data_path = "/training_data_tracks.zarr"
+
+# TODO: Modify the channel name to the one you are using for the segmentation mask
+segmentation_channel_name = "Nucl_mask"
+# TODO: Modify the z-slice to the one you are using for the segmentation mask
+Z_SLICE = 30
+# %%
+"""
+Add CSVs with pseudo-tracking data to the tracking store.
+
+The tracking data is a CSV with the following columns:
+- track_id: label values from the segmentation mask
+- t: the timepoint index of the segmentation mask
+- x, y: the coordinates of the centroid of the segmentation mask
+- id: unique numbers starting from 100000 (100000 + track_id)
+- parent_track_id: all -1
+- parent_id: all -1
+"""
+
+
+def create_track_df(seg_mask, time):
+    track_id = np.unique(seg_mask)
+    track_id = track_id[track_id != 0]
+    track_rows = []
+    # Get coordinates for each track_id separately
+    for tid in track_id:
+        y, x = np.where(seg_mask == tid)  # Note: y comes first from np.where
+        # Use mean coordinates as centroid
+        mean_y = np.mean(y)
+        mean_x = np.mean(x)
+        track_rows.append(
+            {
+                "track_id": tid,
+                "t": time,
+                "y": mean_y,  # Using mean y coordinate
+                "x": mean_x,  # Using mean x coordinate
+                "id": 100000 + tid,
+                "parent_track_id": -1,
+                "parent_id": -1,
+            }
+        )
+    track_df = pd.DataFrame(track_rows)
+    return track_df
+
+
+def save_track_df(track_df, well_id, pos_name, out_path):
+    folder, subfolder = well_id.split("/")
+    out_name = f"{folder}_{subfolder}_{pos_name}_tracks.csv"
+    out_path = os.path.join(out_path, folder, subfolder, pos_name, out_name)
+    track_df.to_csv(out_path, index=False)
+
+
+# %%
+def main():
+    # Load the input segmentation data (read-only)
+    zarr_input = open_ome_zarr(
+        input_data_path,
+        layout="hcs",
+        mode="r",
+    )
+    chan_names = zarr_input.channel_names
+    assert (
+        segmentation_channel_name in chan_names
+    ), "Channel name not found in the input data"
+
+    # Create the empty store for the tracking data
+    position_names = []
+    for ds, position in zarr_input.positions():
+        position_names.append(tuple(ds.split("/")))
+
+    # Shape, chunks, and scale are taken from the last position
+    # (assumes a homogeneous plate); channel_names expects a list
+    create_empty_plate(
+        store_path=track_data_path,
+        position_keys=position_names,
+        channel_names=[segmentation_channel_name],
+        shape=(1, 1, 1, *position.data.shape[3:]),
+        chunks=position.data.chunks,
+        scale=position.scale,
+    )
+
+    # Populate the tracking data
+    with open_ome_zarr(track_data_path, layout="hcs", mode="r+") as track_store:
+        # Create progress bar for wells and positions
+        for well_id, well_data in tqdm(zarr_input.wells(), desc="Processing wells"):
+            for pos_name, pos_data in tqdm(
+                well_data.positions(),
+                desc=f"Processing positions in {well_id}",
+                leave=False,
+            ):
+                data = pos_data.data
+                T, C, Z, Y, X = data.shape
+                track_df_all = pd.DataFrame()
+                for time in range(T):
+                    seg_mask = data[
+                        time, chan_names.index(segmentation_channel_name), Z_SLICE, :, :
+                    ]
+                    track_pos = track_store[well_id + "/" + pos_name]
+                    # The output store has a single (T, C, Z) slot, so the mask
+                    # written last (the final timepoint) is the one that is kept
+                    track_pos["0"][0, 0, 0] = seg_mask
+                    track_df = create_track_df(seg_mask, time)
+                    track_df_all = pd.concat([track_df_all, track_df])
+                save_track_df(track_df_all, well_id, pos_name, track_data_path)
+    zarr_input.close()
+
+
+# %%
+if __name__ == "__main__":
+    main()
diff --git a/examples/DynaCLR/README.md b/examples/DynaCLR/README.md
new file mode 100644
index 000000000..143eed8d5
--- /dev/null
+++ b/examples/DynaCLR/README.md
@@ -0,0 +1,65 @@
+# DynaCLR Demos
+
+This directory contains examples and demos for embedding cellular dynamics with DynaCLR.
+
+## Available Demos
+
+- [ImageNet vs DynaCLR embeddings (cell infection)](/examples/DynaCLR/DynaCLR-DENV-VS-Ph/README.md)
+- [Embedding visualization](/examples/DynaCLR/embedding-web-visualization/README.md)
+
+## Setup
+
+To run the demos, you need to download the data and activate the environment.
+
+> **Note**: The `download_data.sh` script downloads data to `$HOME/data/dynaclr/demo` by default. Modify the script to download the data to a different directory if needed.
+
+```bash
+# To set up the environment
+bash setup.sh
+
+# To download the data
+bash download_data.sh
+```
+
+## Generate DynaCLR Embeddings
+
+For this demo, we will use the `DynaCLR-DENV-VS-Ph` model as an example.
+
+The datasets and config files for the models can be found here:
+- [Test datasets](https://public.czbiohub.org/comp.micro/viscy/DynaCLR_data/)
+- [Models](https://public.czbiohub.org/comp.micro/viscy/DynaCLR_models/)
+
+### Modify the Config File
+
+Open `dynaclr_denv-vs-ph_test_data.yml` and modify the following to point to your download.
+
+Replace the output path where you want to save the xarray `.zarr` file with the embeddings:
+
+```yaml
+callbacks:
+- class_path: viscy.representation.embedding_writer.EmbeddingWriter
+  init_args:
+    output_path: '/TODO_REPLACE_TO_OUTPUT_PATH.zarr' # Select the path to save
+```
+
+Point to the downloaded checkpoint for the desired model (e.g., `DynaCLR-DENV-VS-Ph`):
+
+```yaml
+ckpt_path: '/downloaded.ckpt' # Point to ckpt file
+```
+
+---
+### DynaCLR with classical triplet sampling
+
+To train DynaCLR models using classical triplet sampling, you need to generate pseudo-tracking data from 2D segmentation masks.
+
+These pseudo-tracks are used to run the same training pipeline. For more information, see the [README.md](./DynaCLR-classical-sampling/README.md).
+
+### Exporting DynaCLR models
+
+To export DynaCLR models to ONNX, run:
+
+`viscy export -c config.yml`
+
+The `config.yml` is similar to the `fit.yml` that describes the model. An example can be found [here](./examples_cli/dynaclr_microglia_onnx.yml), and the export-specific keys are shown below.
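+
+For reference, these are the export-specific keys from that example config (the paths are placeholders to replace):
+
+```yaml
+ckpt_path: /path/to/checkpoint.ckpt # TODO: change to the checkpoint path
+format: onnx
+export_path: dynaclr_model.onnx # TODO: change to the export path
+```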
\ No newline at end of file
diff --git a/examples/DynaCLR/download_data.sh b/examples/DynaCLR/download_data.sh
new file mode 100644
index 000000000..d71b9a1bb
--- /dev/null
+++ b/examples/DynaCLR/download_data.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+
+START_DIR=$(pwd)
+
+# Create the directory structure
+output_dir="$HOME"
+mkdir -p "$output_dir/data/dynaclr/demo"
+
+# Change to the target directory (edit this to download the data to a different directory)
+cd "$output_dir/data/dynaclr/demo"
+
+# Download the data
+wget -m -np -nH --cut-dirs=6 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/DynaCLR_data/DENV/test/20240204_A549_DENV_ZIKV_timelapse/"
+
+echo "Data downloaded successfully."
+
+# Change back to the starting directory
+cd "$START_DIR"
diff --git a/examples/DynaCLR/embedding-web-visualization/README.md b/examples/DynaCLR/embedding-web-visualization/README.md
new file mode 100644
index 000000000..bd2694362
--- /dev/null
+++ b/examples/DynaCLR/embedding-web-visualization/README.md
@@ -0,0 +1,55 @@
+# Web-based embedding exploration
+
+## Overview
+
+The `interactive_visualizer.py` script provides interactive visualization and exploration of the embeddings.
+
+## Key Features
+
+- **Interactive visualization**: Plotly-Dash visualization of the embeddings
+- Lasso selection to display image clusters
+- Display of principal components and PHATE plots
+- Single-cell selection
+
+## Setup
+
+The demo embeds the dynamic cellular response and plots the principal components or PHATE coordinates of the embeddings.
+
+You can download the data with the `download_data.sh` script (see the [top-level README](../README.md)) or use your own data by updating the paths:
+
+```python
+# Update these paths to the downloaded data
+download_root = Path.home() / "data/dynaclr/demo"
+viz_config = {
+    "data_path": download_root / "registered_test.zarr",  # TODO add path to data
+    "tracks_path": download_root / "track_test.zarr",  # TODO add path to tracks
+    "features_path": download_root
+    / "precomputed_embeddings/infection_160patch_94ckpt_rev6_dynaclr.zarr",  # TODO add path to features
+    "channels_to_display": ["Phase3D", "RFP"],
+    # TODO: Modify for specific FOVs: [A/3/*] (uninfected) and [B/4/*] (infected), FOVs 0-9. They will be cached in memory.
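+    # A dict comprehension can cache, e.g., all ten FOVs per condition
+    # (illustrative only; adjust the FOV names and track-id range to your data):
+    # "fov_tracks": {f"/{well}/{i}": list(range(50)) for well in ("A/3", "B/4") for i in range(10)},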
+ "fov_tracks": { + "/A/3/9": list(range(50)), + "/B/4/9": list(range(50)), + }, + "yx_patch_size": (160, 160), + "num_PC_components": 8, +} +``` + +## Usage + +After setting up the environment, activate it and run the demo script: + +```bash +conda activate dynaclr +python interactive_visualizer.py +``` + +## Demo + +### Embeddings per track (click on the track to see the embeddings) +![embeddings_per_track](/examples/DynaCLR/embedding-web-visualization/demo_imgs/demo2_embeddings_visualization_track.png) + +### Clustering (use the lasso to select the embeddings) +![embeddings_per_cluster](/examples/DynaCLR/embedding-web-visualization/demo_imgs/demo2_embedding_visualization_cluster.png) diff --git a/examples/DynaCLR/embedding-web-visualization/demo_imgs/demo2_embedding_visualization_cluster.png b/examples/DynaCLR/embedding-web-visualization/demo_imgs/demo2_embedding_visualization_cluster.png new file mode 100644 index 000000000..a7cae6180 Binary files /dev/null and b/examples/DynaCLR/embedding-web-visualization/demo_imgs/demo2_embedding_visualization_cluster.png differ diff --git a/examples/DynaCLR/embedding-web-visualization/demo_imgs/demo2_embeddings_visualization_track.png b/examples/DynaCLR/embedding-web-visualization/demo_imgs/demo2_embeddings_visualization_track.png new file mode 100644 index 000000000..d528dc055 Binary files /dev/null and b/examples/DynaCLR/embedding-web-visualization/demo_imgs/demo2_embeddings_visualization_track.png differ diff --git a/examples/DynaCLR/embedding-web-visualization/interactive_visualizer.py b/examples/DynaCLR/embedding-web-visualization/interactive_visualizer.py new file mode 100644 index 000000000..cce99c9e4 --- /dev/null +++ b/examples/DynaCLR/embedding-web-visualization/interactive_visualizer.py @@ -0,0 +1,55 @@ +"""Interactive visualization of phenotype data.""" + +import logging +from pathlib import Path + +from numpy.random import seed + +from viscy.representation.evaluation.visualization import EmbeddingVisualizationApp + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +seed(42) + + +def main(): + """Main function to run the visualization app.""" + + # Config for the visualization app + # TODO: Update the paths to the downloaded data. 
By default the data is downloaded to ~/data/dynaclr/demo + download_root = Path.home() / "data/dynaclr/demo" + output_path = Path.home() / "data/dynaclr/demo/embedding-web-visualization" + viz_config = { + "data_path": download_root / "registered_test.zarr", # TODO add path to data + "tracks_path": download_root / "track_test.zarr", # TODO add path to tracks + "features_path": download_root + / "precomputed_embeddings/infection_160patch_94ckpt_rev6_dynaclr.zarr", # TODO add path to features + "channels_to_display": ["Phase3D", "RFP"], + "fov_tracks": { + "/A/3/9": list(range(50)), + "/B/4/9": list(range(50)), + }, + "yx_patch_size": (160, 160), + "z_range": (24, 29), + "num_PC_components": 8, + "output_dir": output_path, + } + + # Create and run the visualization app + try: + # Create and run the visualization app + app = EmbeddingVisualizationApp(**viz_config) + app.preload_images() + app.run(debug=True) + + except KeyboardInterrupt: + logger.info("Application shutdown requested by user") + except Exception as e: + logger.error(f"Application error: {e}") + finally: + logger.info("Application shutdown complete") + + +if __name__ == "__main__": + main() diff --git a/examples/DynaCLR/examples_cli/dynaclr_microglia_onnx.yml b/examples/DynaCLR/examples_cli/dynaclr_microglia_onnx.yml new file mode 100644 index 000000000..2dbcfd1e4 --- /dev/null +++ b/examples/DynaCLR/examples_cli/dynaclr_microglia_onnx.yml @@ -0,0 +1,58 @@ +# lightning.pytorch==2.4.0 + +# TODO: Check the TODO's and change the paths to the correct ones + +seed_everything: 42 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 32-true + callbacks: [] +model: + class_path: viscy.representation.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy.representation.contrastive.ContrastiveEncoder + init_args: + backbone: convnext_tiny + in_channels: 1 + in_stack_depth: 1 + stem_kernel_size: + - 1 + - 4 + - 4 + stem_stride: + - 1 + - 4 + - 4 + embedding_dim: 768 + projection_dim: 32 + drop_path_rate: 0.0 + loss_function: + class_path: pytorch_metric_learning.losses.NTXentLoss + init_args: + temperature: 0.2 + embedding_regularizer: null + embedding_reg_weight: 1 + reducer: null + distance: null + collect_stats: null + lr: 2.0e-05 + schedule: Constant + log_batches_per_epoch: 3 + log_samples_per_batch: 3 + log_embeddings: false + example_input_array_shape: + - 1 + - 1 + - 1 + - 256 + - 256 +ckpt_path: /epoch=19-step=12960.ckpt #TODO: change to the checkpoint path +format: onnx +export_path: dynaclr_microglia.onnx #TODO: change to the export path + + + diff --git a/applications/contrastive_phenotyping/examples_cli/fit.yml b/examples/DynaCLR/examples_cli/fit.yml similarity index 67% rename from applications/contrastive_phenotyping/examples_cli/fit.yml rename to examples/DynaCLR/examples_cli/fit.yml index 90911824e..281bb1ebe 100644 --- a/applications/contrastive_phenotyping/examples_cli/fit.yml +++ b/examples/DynaCLR/examples_cli/fit.yml @@ -1,5 +1,10 @@ # See help here on how to configure hyper-parameters with config files: # https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced.html + +# TODO: point to the path to save the embeddings +# TODO: point to the path to the data +# TODO: point to the path to the tracks + seed_everything: 42 trainer: accelerator: gpu @@ -9,18 +14,9 @@ trainer: precision: 32-true logger: class_path: lightning.pytorch.loggers.TensorBoardLogger - # Nesting the logger config like this is equivalent to - # supplying the following argument to 
`lightning.pytorch.Trainer`: - # logger=TensorBoardLogger( - # "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations", - # log_graph=True, - # version="vanilla", - # ) init_args: - save_dir: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations - # this is the name of the experiment. - # The logs will be saved in `save_dir/lightning_logs/version` - version: l2_projection_batchnorm + save_dir: #TODO point to the path to save the logs + version: #TODO point to the version name log_graph: True callbacks: - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -38,44 +34,42 @@ trainer: enable_checkpointing: true inference_mode: true use_distributed_sampler: true - # synchronize batchnorm parameters across multiple GPUs. - # important for contrastive learning to normalize the tensors across the whole batch. - sync_batchnorm: true model: - class_path: + class_path: viscy.representation.engine.ContrastiveModule init_args: encoder: class_path: viscy.representation.contrastive.ContrastiveEncoder init_args: backbone: convnext_tiny in_channels: 2 - in_stack_depth: 15 + in_stack_depth: 30 stem_kernel_size: [5, 4, 4] stem_stride: [5, 4, 4] embedding_dim: 768 - projection_dim: 128 + projection_dim: 32 drop_path_rate: 0.0 loss_function: class_path: torch.nn.TripletMarginLoss init_args: margin: 0.5 - lr: 0.0002 + lr: 0.00002 log_batches_per_epoch: 3 log_samples_per_batch: 3 - example_input_array_shape: [1, 2, 15, 256, 256] + example_input_array_shape: [1, 2, 30, 256, 256] data: class_path: viscy.data.triplet.TripletDataModule init_args: - data_path: /hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr - tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr + data_path: #TODO point to the path to the data (e.g. /2024_10_16_A549_SEC61_sensor_train.zarr) + tracks_path: #TODO point to the path to the corresponding tracks (e.g. 
/track_trainVal.zarr) source_channel: - Phase3D - - RFP - z_range: [25, 40] - batch_size: 32 + - raw mCherry EX561 EM600-37 + z_range: [15, 45] + batch_size: 64 num_workers: 12 initial_yx_patch_size: [384, 384] - final_yx_patch_size: [192, 192] + final_yx_patch_size: [160, 160] + time_interval: 1 normalizations: - class_path: viscy.transforms.NormalizeSampled init_args: @@ -85,7 +79,7 @@ data: divisor: std - class_path: viscy.transforms.ScaleIntensityRangePercentilesd init_args: - keys: [RFP] + keys: [raw mCherry EX561 EM600-37] lower: 50 upper: 99 b_min: 0.0 @@ -93,7 +87,7 @@ data: augmentations: - class_path: viscy.transforms.RandAffined init_args: - keys: [Phase3D, RFP] + keys: [Phase3D, raw mCherry EX561 EM600-37] prob: 0.8 scale_range: [0, 0.2, 0.2] rotate_range: [3.14, 0.0, 0.0] @@ -101,9 +95,9 @@ data: padding_mode: zeros - class_path: viscy.transforms.RandAdjustContrastd init_args: - keys: [RFP] + keys: [raw mCherry EX561 EM600-37] prob: 0.5 - gamma: [0.7, 1.3] + gamma: [0.8, 1.2] - class_path: viscy.transforms.RandAdjustContrastd init_args: keys: [Phase3D] @@ -111,8 +105,8 @@ data: gamma: [0.8, 1.2] - class_path: viscy.transforms.RandScaleIntensityd init_args: - keys: [RFP] - prob: 0.7 + keys: [raw mCherry EX561 EM600-37] + prob: 0.5 factors: 0.5 - class_path: viscy.transforms.RandScaleIntensityd init_args: @@ -121,20 +115,20 @@ data: factors: 0.5 - class_path: viscy.transforms.RandGaussianSmoothd init_args: - keys: [Phase3D, RFP] + keys: [Phase3D, raw mCherry EX561 EM600-37] prob: 0.5 sigma_x: [0.25, 0.75] sigma_y: [0.25, 0.75] sigma_z: [0.0, 0.0] - class_path: viscy.transforms.RandGaussianNoised init_args: - keys: [RFP] + keys: [raw mCherry EX561 EM600-37] prob: 0.5 mean: 0.0 - std: 0.5 + std: 0.2 - class_path: viscy.transforms.RandGaussianNoised init_args: keys: [Phase3D] prob: 0.5 mean: 0.0 - std: 0.2 + std: 0.2 \ No newline at end of file diff --git a/applications/contrastive_phenotyping/examples_cli/fit_slurm.sh b/examples/DynaCLR/examples_cli/fit_slurm.sh similarity index 50% rename from applications/contrastive_phenotyping/examples_cli/fit_slurm.sh rename to examples/DynaCLR/examples_cli/fit_slurm.sh index 220e98373..0ff0ea7c8 100644 --- a/applications/contrastive_phenotyping/examples_cli/fit_slurm.sh +++ b/examples/DynaCLR/examples_cli/fit_slurm.sh @@ -9,36 +9,33 @@ #SBATCH --mem-per-cpu=15G #SBATCH --time=0-20:00:00 -# debugging flags (optional) +# NOTE: debugging flags (optional) # https://lightning.ai/docs/pytorch/stable/clouds/cluster_advanced.html export NCCL_DEBUG=INFO export PYTHONFAULTHANDLER=1 - - -# Cleanup function to remove the temporary files function cleanup() { rm -rf /tmp/$SLURM_JOB_ID/*.zarr echo "Cleanup Completed." } - trap cleanup EXIT -# trap the EXIT signal sent to the process and invoke the cleanup. -# Activate the conda environment - specfic to your installation! -module load anaconda/2022.05 -# You'll need to replace this path with path to your own conda environment. -conda activate /hpc/mydata/$USER/envs/viscy -config=./demo_cli_fit.yml +# TODO: Activate the conda environment - specfic to your installation! +# TODO: You'll need to replace this path with path to your own conda environment +module load anaconda/latest +conda activate dynaclr + +# TODO: point to the path to the config file +config=./fit.yml # Printing this to the stdout lets us connect the job id to config. scontrol show job $SLURM_JOB_ID cat $config # Run the training CLI -srun python -m viscy.cli.contrastive_triplet fit -c $config +viscy fit -c $config # Tips: -# 1. 
run this script with `sbatch demo_cli_fit_slurm.sh` -# 2. check the status of the job with `squeue -u $USER` -# 3. use turm to monitor the job with `turm -u first.last`. Use module load turm to load the turm module. +# 1. Run this script with `sbatch fit_slurm.sh` +# 2. Check the status of the job with `squeue -u $USER` +# 3. Use turm to monitor the job with `turm -u first.last`. Use module load turm to load the turm module. diff --git a/examples/DynaCLR/examples_cli/predict.yml b/examples/DynaCLR/examples_cli/predict.yml new file mode 100644 index 000000000..50b490d00 --- /dev/null +++ b/examples/DynaCLR/examples_cli/predict.yml @@ -0,0 +1,68 @@ +# TODO: point to the path to save the embeddings +# TODO: point to the path to the data +# TODO: point to the path to the tracks +# TODO: point to the path to the checkpoint + +seed_everything: 42 +trainer: + accelerator: gpu + strategy: auto + devices: auto + num_nodes: 1 + precision: 32-true + callbacks: + - class_path: viscy.representation.embedding_writer.EmbeddingWriter + init_args: + output_path: #TODO point to the path to save the embeddings + phate_kwargs: #TODO modify default parameters. Set to null to skip PHATE computation. + knn: 5 + decay: 40 + n_jobs: -1 + random_state: 42 + pca_kwargs: #TODO modify default parameters. Set to null to skip PCA computation. + n_components: 8 + inference_mode: true +model: + class_path: viscy.representation.engine.ContrastiveModule + init_args: + encoder: + class_path: viscy.representation.contrastive.ContrastiveEncoder + init_args: + backbone: convnext_tiny + in_channels: 2 + in_stack_depth: 30 + stem_kernel_size: [5, 4, 4] + stem_stride: [5, 4, 4] + embedding_dim: 768 + projection_dim: 32 + drop_path_rate: 0.0 + example_input_array_shape: [1, 2, 30, 256, 256] +data: + class_path: viscy.data.triplet.TripletDataModule + init_args: + data_path: #TODO point to the path to the data (e.g. /registered_test.zarr) + tracks_path: #TODO point to the path to the tracks (e.g. /track_test.zarr) + source_channel: + - Phase3D + - RFP + z_range: [15, 45] + batch_size: 32 + num_workers: 30 + initial_yx_patch_size: [160, 160] + final_yx_patch_size: [160, 160] + normalizations: + - class_path: viscy.transforms.NormalizeSampled + init_args: + keys: [Phase3D] + level: fov_statistics + subtrahend: mean + divisor: std + - class_path: viscy.transforms.ScaleIntensityRangePercentilesd + init_args: + keys: [RFP] + lower: 50 + upper: 99 + b_min: 0.0 + b_max: 1.0 +return_predictions: false +ckpt_path: #TODO point to the path to the checkpoint (e.g. 
/checkpoints/epoch=94-step=2375.ckpt)
diff --git a/applications/contrastive_phenotyping/examples_cli/predict_slurm.sh b/examples/DynaCLR/examples_cli/predict_slurm.sh
similarity index 73%
rename from applications/contrastive_phenotyping/examples_cli/predict_slurm.sh
rename to examples/DynaCLR/examples_cli/predict_slurm.sh
index 3f91fc9bc..cdf887971 100644
--- a/applications/contrastive_phenotyping/examples_cli/predict_slurm.sh
+++ b/examples/DynaCLR/examples_cli/predict_slurm.sh
@@ -9,13 +9,14 @@
 #SBATCH --mem-per-cpu=7G
 #SBATCH --time=0-01:00:00
 
-module load anaconda/2022.05
+module load anaconda/latest
 # Update to use the actual prefix
-conda activate $MYDATA/envs/viscy
+conda activate dynaclr
 
 scontrol show job $SLURM_JOB_ID
 
 # use absolute path in production
 config=./predict.yml
 cat $config
+
-srun python -m viscy.cli.contrastive_triplet predict -c $config
+viscy predict -c $config
diff --git a/examples/DynaCLR/setup.sh b/examples/DynaCLR/setup.sh
new file mode 100644
index 000000000..fcc72d5ba
--- /dev/null
+++ b/examples/DynaCLR/setup.sh
@@ -0,0 +1,31 @@
+#!/usr/bin/env bash
+
+START_DIR=$(pwd)
+
+# Initialize conda for the shell
+eval "$(conda shell.bash hook)"
+conda deactivate
+
+# Check if environment exists
+if ! conda env list | grep -q "dynaclr"; then
+    echo "Creating new dynaclr environment..."
+    conda config --add channels defaults
+    conda create -y --name dynaclr python=3.11
+else
+    echo "Environment already exists. Updating packages..."
+fi
+
+# Activate the environment
+conda activate dynaclr
+
+# Install/update conda packages
+conda install -y ipykernel nbformat nbconvert black jupytext ipywidgets
+python -m ipykernel install --user --name dynaclr --display-name "Python (dynaclr)"
+
+# Install viscy and its dependencies using pip
+pip install "viscy[visual,metrics,phate,examples]>=0.4.0a1"
+
+# Change back to the starting directory
+cd "$START_DIR"
+
+echo "DynaCLR environment setup complete."
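Together, `predict.yml` and the `EmbeddingWriter` callback write the per-cell embeddings (plus optional PHATE/PCA coordinates) to a single Zarr store. The following is a minimal sketch of loading that store with the reader added later in this diff; the output path here is hypothetical and should match `EmbeddingWriter.output_path` in your config:

```python
from viscy.representation.embedding_writer import read_embedding_dataset

# Hypothetical path; use the output_path set in predict.yml.
dataset = read_embedding_dataset("embeddings.zarr")
features = dataset["features"]        # (n_samples, embedding_dim), e.g. 768-d per cell
projections = dataset["projections"]  # (n_samples, projection_dim), e.g. 32-d
# Samples are indexed by fov_name/track_id/t (see INDEX_COLUMNS in viscy.data.triplet);
# PHATE1/PHATE2 and PCA1..PCA8 coordinates are attached when enabled in the config.
print(features.shape, projections.shape)
```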
diff --git a/pyproject.toml b/pyproject.toml index d2fb83788..6d0fb1e17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,8 +27,7 @@ dependencies = [ "matplotlib>=3.9.0", "numpy", "xarray", - "pytorch-metric-learning>2.0.0", -] + "pytorch-metric-learning>2.0.0"] dynamic = ["version"] [project.optional-dependencies] @@ -40,9 +39,12 @@ metrics = [ "ptflops>=0.7", "umap-learn", "captum>=0.7.0", + "mahotas", +] +phate = [ "phate", ] -examples = ["napari", "jupyter", "jupytext"] +examples = ["napari", "jupyter", "jupytext", "transformers>=4.51.3"] visual = [ "ipykernel", "graphviz", @@ -51,9 +53,10 @@ visual = [ "plotly", "nbformat", "cmap", + "dash", ] dev = [ - "viscy[metrics,examples,visual]", + "viscy[metrics,phate,examples,visual]", "pytest", "pytest-cov", "hypothesis", diff --git a/tests/conftest.py b/tests/conftest.py index 38ccaab26..8fc6c8216 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING import numpy as np +import pandas as pd from iohub import open_ome_zarr from pytest import TempPathFactory, fixture @@ -28,7 +29,7 @@ def _build_hcs( ) for row in ("A", "B"): for col in ("1", "2"): - for fov in ("0", "1"): + for fov in ("0", "1", "2", "3"): pos = dataset.create_position(row, col, fov) pos.create_image( "0", @@ -61,9 +62,40 @@ def small_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path: return dataset_path +@fixture(scope="function") +def small_hcs_labels(tmp_path_factory: TempPathFactory) -> Path: + """Provides a small, not preprocessed HCS OME-Zarr dataset with labels.""" + dataset_path = tmp_path_factory.mktemp("small_with_labels.zarr") + _build_hcs( + dataset_path, ["nuclei_labels", "membrane_labels"], (12, 64, 64), np.uint16, 50 + ) + return dataset_path + + @fixture(scope="function") def labels_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path: """Provides a small, not preprocessed HCS OME-Zarr dataset.""" dataset_path = tmp_path_factory.mktemp("labels.zarr") _build_hcs(dataset_path, ["DAPI", "GFP"], (2, 16, 16), np.uint16, 3) return dataset_path + + +@fixture(scope="function") +def tracks_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path: + """Provides a HCS OME-Zarr dataset with tracking CSV results.""" + dataset_path = tmp_path_factory.mktemp("tracks.zarr") + _build_hcs(dataset_path, ["nuclei_labels"], (1, 256, 256), np.uint16, 3) + for fov_name, _ in open_ome_zarr(dataset_path).positions(): + fake_tracks = pd.DataFrame( + { + "track_id": [0, 1], + "t": [0, 1], + "y": [100, 200], + "x": [96, 160], + "id": [0, 1], + "parent_track_id": [-1, -1], + "parent_id": [-1, -1], + } + ) + fake_tracks.to_csv(dataset_path / fov_name / "tracks.csv", index=False) + return dataset_path diff --git a/tests/data/test_data.py b/tests/data/test_hcs.py similarity index 100% rename from tests/data/test_data.py rename to tests/data/test_hcs.py diff --git a/tests/data/test_select.py b/tests/data/test_select.py new file mode 100644 index 000000000..6ffc55b83 --- /dev/null +++ b/tests/data/test_select.py @@ -0,0 +1,30 @@ +import pytest +from iohub.ngff import open_ome_zarr + +from viscy.data.select import SelectWell + + +@pytest.mark.parametrize("include_wells", [None, ["A/1", "A/2", "B/2"]]) +@pytest.mark.parametrize("exclude_fovs", [None, ["A/1/0", "A/1/1", "A/2/2"]]) +def test_select_well(include_wells, exclude_fovs, preprocessed_hcs_dataset): + dummy = SelectWell() + dummy._include_wells = include_wells + dummy._exclude_fovs = exclude_fovs + plate = open_ome_zarr(preprocessed_hcs_dataset) + filtered_positions = 
dummy._filter_fit_fovs(plate) + fovs_per_well = len(plate["A/1"]) + if include_wells is None: + total_wells = len(list(plate.wells())) + else: + total_wells = len(include_wells) + total_fovs = total_wells * fovs_per_well + if exclude_fovs is not None: + total_fovs -= len(exclude_fovs) + assert len(filtered_positions) == total_fovs + for position in filtered_positions: + fov_name = position.zgroup.name.strip("/") + well_name, _ = fov_name.rsplit("/", 1) + if include_wells is not None: + assert well_name in include_wells + if exclude_fovs is not None: + assert fov_name not in exclude_fovs diff --git a/tests/data/test_triplet.py b/tests/data/test_triplet.py new file mode 100644 index 000000000..ef0210958 --- /dev/null +++ b/tests/data/test_triplet.py @@ -0,0 +1,68 @@ +import pandas as pd +from iohub import open_ome_zarr +from pytest import mark + +from viscy.data.triplet import TripletDataModule + + +@mark.parametrize("include_wells", [None, ["A/1", "A/2", "B/1"]]) +@mark.parametrize("exclude_fovs", [None, ["A/1/0", "A/1/1", "A/2/2", "B/1/3"]]) +def test_datamodule_setup_fit( + preprocessed_hcs_dataset, tracks_hcs_dataset, include_wells, exclude_fovs +): + data_path = preprocessed_hcs_dataset + z_window_size = 5 + split_ratio = 0.75 + yx_patch_size = [32, 32] + batch_size = 4 + with open_ome_zarr(data_path) as dataset: + channel_names = dataset.channel_names + total_wells = len(list(dataset.wells())) + fovs_per_well = len(dataset["A/1"]) + if include_wells is not None: + total_wells = len(include_wells) + total_fovs = total_wells * fovs_per_well + if exclude_fovs is not None: + total_fovs -= len(exclude_fovs) + len_total = total_fovs * 2 + len_train = int(len_total * split_ratio) + len_val = len_total - len_train + dm = TripletDataModule( + data_path=data_path, + tracks_path=tracks_hcs_dataset, + source_channel=channel_names, + z_range=(4, 9), + initial_yx_patch_size=(64, 64), + final_yx_patch_size=(32, 32), + num_workers=0, + split_ratio=split_ratio, + batch_size=batch_size, + fit_include_wells=include_wells, + fit_exclude_fovs=exclude_fovs, + return_negative=True, + ) + dm.setup(stage="fit") + assert len(dm.train_dataset) == len_train + assert len(dm.val_dataset) == len_val + all_tracks = pd.concat([dm.train_dataset.tracks, dm.val_dataset.tracks]) + filtered_fov_names = all_tracks["fov_name"].str[1:].unique() + for fov_name in filtered_fov_names: + well_name, _ = fov_name.rsplit("/", 1) + if include_wells is not None: + assert well_name in include_wells + if exclude_fovs is not None: + assert fov_name not in exclude_fovs + assert len(all_tracks) == len_total + for batch in dm.train_dataloader(): + assert batch["anchor"].shape == ( + batch_size, + len(channel_names), + z_window_size, + *yx_patch_size, + ) + assert batch["negative"].shape == ( + batch_size, + len(channel_names), + z_window_size, + *yx_patch_size, + ) diff --git a/tests/evaluation/test_cell_feature_metrics.py b/tests/evaluation/test_cell_feature_metrics.py new file mode 100644 index 000000000..d118fd381 --- /dev/null +++ b/tests/evaluation/test_cell_feature_metrics.py @@ -0,0 +1,257 @@ +import numpy as np +import pandas as pd +import pytest +from skimage import measure + +from viscy.representation.evaluation.feature import CellFeatures, DynamicFeatures + + +@pytest.fixture +def simple_image(): + """Create a simple test image with a known pattern""" + axis = [0.0, 0.2, 0.4, 0.2, 0.0] + x, y = np.meshgrid(axis, axis) + + image = x + y + + return image + + +@pytest.fixture +def simple_mask(): + """Create a simple binary mask.""" 
+ from skimage.morphology import disk + + mask = disk(2) + return mask + + +def test_intensity_features(simple_image): + """Test computation of intensity-based features with known input.""" + cell_features = CellFeatures(simple_image) + cell_features.compute_intensity_features() + + features = cell_features.intensity_features + + assert np.isclose(features["mean_intensity"], 0.32, atol=1e-6) + assert np.isclose( + features["std_dev"], 0.21166010488516723, atol=1e-6 + ) # Actual std dev + assert np.isclose(features["min_intensity"], 0.0) + assert np.isclose(features["max_intensity"], 0.8) + + assert not np.isnan(features["kurtosis"]) + assert not np.isnan(features["skewness"]) + assert features["spectral_entropy"] > 0 + assert features["iqr"] > 0 + + +def test_texture_features(simple_image): + """Test computation of texture features with known input.""" + cell_features = CellFeatures(simple_image) + cell_features.compute_texture_features() + + features = cell_features.texture_features + + assert features["contrast"] >= 0 + assert features["dissimilarity"] >= 0 + assert 0 <= features["homogeneity"] <= 1 + + assert features["spectral_entropy"] > 0 + assert features["entropy"] > 0 + assert features["texture"] >= 0 + + +def test_morphology_features(simple_image, simple_mask): + """Test computation of morphological features with known input.""" + # Convert mask to labeled image + labeled_mask = measure.label(simple_mask.astype(int)) + + cell_features = CellFeatures(simple_image, labeled_mask) + cell_features.compute_morphology_features() + + features = cell_features.morphology_features + + assert features["area"] == 13 # Number of True pixels in mask + assert features["perimeter"] > 0 + assert features["perimeter_area_ratio"] > 0 + + assert 0 <= features["eccentricity"] <= 1 + + assert features["intensity_localization"] > 0 + assert features["masked_intensity"] > 0 + assert features["masked_area"] == 13 + + +def test_symmetry_descriptor(simple_image): + """Test computation of symmetry features with known input.""" + cell_features = CellFeatures(simple_image) + cell_features.compute_symmetry_descriptor() + + features = cell_features.symmetry_descriptor + + assert features["zernike_std"] >= 0 + assert not np.isnan(features["zernike_mean"]) + + assert features["radial_intensity_gradient"] < 0 + + +def test_all_features(simple_image, simple_mask): + """Test computation of all features together.""" + labeled_mask = measure.label(simple_mask.astype(int)) + + cell_features = CellFeatures(simple_image, labeled_mask) + features_df = cell_features.compute_all_features() + + # Test that all feature types are present + assert "mean_intensity" in features_df.columns + assert "contrast" in features_df.columns + assert "area" in features_df.columns + assert "zernike_std" in features_df.columns + + # Test that no features are NaN + assert not features_df.isna().any().any() + + +def test_edge_cases(): + """Test behavior with edge cases.""" + constant_image = np.ones((5, 5)) + cell_features = CellFeatures(constant_image) + cell_features.compute_intensity_features() + + features = cell_features.intensity_features + assert features["std_dev"] == 0 + assert np.isnan(features["kurtosis"]) # Kurtosis is undefined for constant values + assert np.isnan(features["skewness"]) # Skewness is undefined for constant values + + empty_mask = np.zeros((5, 5), dtype=int) + cell_features = CellFeatures(constant_image, empty_mask) + with pytest.raises(AssertionError): + cell_features.compute_morphology_features() + + +def 
test_normalization(simple_image, simple_mask): + """Test that features are invariant to intensity scaling.""" + # Convert mask to labeled image + labeled_mask = measure.label(simple_mask.astype(int)) + + # Compute features for original image + cell_features1 = CellFeatures(simple_image, labeled_mask) + features1 = cell_features1.compute_all_features() + + # Compute features for scaled image + scaled_image = simple_image * 2.0 + cell_features2 = CellFeatures(scaled_image, labeled_mask) + features2 = cell_features2.compute_all_features() + + # Compare features that should be invariant to scaling + assert np.allclose(features1["eccentricity"], features2["eccentricity"]) + assert np.allclose( + features1["perimeter_area_ratio"], features2["perimeter_area_ratio"] + ) + assert np.allclose(features1["zernike_std"], features2["zernike_std"]) + + +@pytest.fixture +def simple_track(): + """Create a simple track with known properties.""" + # Create a track moving in a straight line + t = np.array([0, 1, 2, 3, 4]) + x = np.array([0, 1, 2, 3, 4]) + y = np.array([0, 0, 0, 0, 0]) + track_id = np.array(["1"] * 5) + + return pd.DataFrame({"track_id": track_id, "t": t, "x": x, "y": y}) + + +def test_velocity_features(simple_track): + """Test computation of velocity features with known input.""" + dynamic_features = DynamicFeatures(simple_track) + features_df = dynamic_features.compute_all_features("1") + + # Test velocity features + assert "mean_velocity" in features_df.columns + assert "max_velocity" in features_df.columns + assert "min_velocity" in features_df.columns + assert "std_velocity" in features_df.columns + + # For straight line motion at constant speed: + assert np.isclose(features_df["mean_velocity"].iloc[0], 0.8, atol=1e-6) + assert np.isclose( + features_df["std_velocity"].iloc[0], 0.4, atol=1e-6 + ) # Actual std dev + + +def test_displacement_features(simple_track): + """Test computation of displacement features with known input.""" + dynamic_features = DynamicFeatures(simple_track) + features_df = dynamic_features.compute_all_features("1") + + assert "total_distance" in features_df.columns + assert "net_displacement" in features_df.columns + assert "directional_persistence" in features_df.columns + + # For straight line motion: + assert np.isclose( + features_df["total_distance"].iloc[0], 4.0, atol=1e-6 + ) # Total distance + assert np.isclose( + features_df["net_displacement"].iloc[0], 4.0, atol=1e-6 + ) # Net displacement + assert np.isclose( + features_df["directional_persistence"].iloc[0], 1.0, atol=1e-6 + ) # Perfect persistence + + +def test_angular_features(simple_track): + """Test computation of angular features with known input.""" + dynamic_features = DynamicFeatures(simple_track) + features_df = dynamic_features.compute_all_features("1") + + assert "mean_angular_velocity" in features_df.columns + assert "max_angular_velocity" in features_df.columns + assert "std_angular_velocity" in features_df.columns + + # For straight line motion: + assert np.isclose( + features_df["mean_angular_velocity"].iloc[0], 1.4142136207497351e-05, atol=1e-6 + ) # Actual small value + assert np.isclose(features_df["std_angular_velocity"].iloc[0], 0.0, atol=1e-6) + + +def test_tracking_edge_cases(): + """Test behavior with edge cases in tracking data.""" + # Test with single point + single_point = pd.DataFrame({"track_id": ["1"], "t": [0], "x": [0], "y": [0]}) + dynamic_features = DynamicFeatures(single_point) + features_df = dynamic_features.compute_all_features("1") + + assert 
np.isclose(features_df["mean_velocity"].iloc[0], 0.0, atol=1e-6) + assert np.isclose(features_df["total_distance"].iloc[0], 0.0, atol=1e-6) + assert np.isclose(features_df["mean_angular_velocity"].iloc[0], 0.0, atol=1e-6) + + two_points = pd.DataFrame( + {"track_id": ["1", "1"], "t": [0, 1], "x": [0, 1], "y": [0, 1]} + ) + dynamic_features = DynamicFeatures(two_points) + features_df = dynamic_features.compute_all_features("1") + + assert np.isclose( + features_df["mean_velocity"].iloc[0], 0.7071067811865476, atol=1e-6 + ) + assert np.isclose(features_df["mean_angular_velocity"].iloc[0], 0.0, atol=1e-6) + + +def test_tracking_invalid_data(): + """Test behavior with invalid tracking data.""" + # Test with missing columns + invalid_data = pd.DataFrame({"track_id": ["1"], "t": [0]}) + with pytest.raises(ValueError): + DynamicFeatures(invalid_data) + + # Test with non-numeric coordinates + invalid_data = pd.DataFrame( + {"track_id": ["1"], "t": [0], "x": ["invalid"], "y": [0]} + ) + with pytest.raises(ValueError): + DynamicFeatures(invalid_data) diff --git a/tests/representation/test_feature.py b/tests/representation/test_feature.py new file mode 100644 index 000000000..5bd83ebdc --- /dev/null +++ b/tests/representation/test_feature.py @@ -0,0 +1,114 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +from iohub import open_ome_zarr + +from viscy.representation.evaluation.feature import ( + CellFeatures, + DynamicFeatures, +) + + +@pytest.mark.parametrize("channel_idx", [0, 1]) +def test_cell_features_with_labels_hcs( + small_hcs_dataset, small_hcs_labels, channel_idx +): + """Test CellFeatures with labels HCS dataset.""" + data_path = small_hcs_dataset + with open_ome_zarr(data_path) as dataset: + _, position = next(dataset.positions()) + image_array = position["0"] + + with open_ome_zarr(small_hcs_labels) as labels_dataset: + _, position = next(labels_dataset.positions()) + labels_array = position["0"] + + # Extract patch from center + patch_size = 16 + t, z = 0, 0 + y_center, x_center = image_array.shape[-2] // 2, image_array.shape[-1] // 2 + half_patch = patch_size // 2 + + y_slice = slice(y_center - half_patch, y_center + half_patch) + x_slice = slice(x_center - half_patch, x_center + half_patch) + + image_patch = image_array[t, channel_idx, z, y_slice, x_slice] + labels_patch = labels_array[t, channel_idx, z, y_slice, x_slice] + + cf = CellFeatures( + image=image_patch.astype(np.float32), segmentation_mask=labels_patch + ) + features_df = cf.compute_all_features() + + assert isinstance(features_df, pd.DataFrame) + assert len(features_df) == 1 + + assert "mean_intensity" in features_df.columns + assert "contrast" in features_df.columns + assert "zernike_std" in features_df.columns + assert "area" in features_df.columns + + for col in features_df.columns: + value = features_df[col].iloc[0] + if col in ["kurtosis", "skewness"]: + patch_std = np.std(image_patch) + if patch_std < 1e-10: + # For constant images, kurtosis and skewness should be NaN + assert np.isnan(value), ( + f"Feature {col} should be NaN for constant image (std={patch_std})" + ) + else: + # For non-constant images, values should be finite and reasonable + assert np.isfinite(value), ( + f"Feature {col} is not finite for non-constant image (std={patch_std})" + ) + assert -10 < value < 10, ( + f"Feature {col} = {value} seems unreasonable for random data" + ) + else: + assert np.isfinite(value), f"Feature {col} is not finite: {value}" + + +@pytest.mark.parametrize("fov_path", ["A/1/0", "A/1/1", 
"A/2/0"]) +def test_dynamic_features_with_tracks_hcs(tracks_hcs_dataset, fov_path): + """Test DynamicFeatures with tracks HCS dataset.""" + + tracks_path = Path(tracks_hcs_dataset) / fov_path / "tracks.csv" + if not tracks_path.exists(): + pytest.skip(f"Tracks file not found at {tracks_path}") + + tracks_df = pd.read_csv(tracks_path) + + if len(tracks_df) == 0: + pytest.skip("No tracks found in dataset") + + # Test with first track + first_track_id = tracks_df["track_id"].iloc[0] + + df = DynamicFeatures(tracks_df) + features_df = df.compute_all_features(first_track_id) + + assert isinstance(features_df, pd.DataFrame) + assert len(features_df) == 1 + + # Check expected columns + expected_cols = { + "mean_velocity", + "max_velocity", + "min_velocity", + "std_velocity", + "total_distance", + "net_displacement", + "directional_persistence", + "mean_angular_velocity", + "max_angular_velocity", + "std_angular_velocity", + "instantaneous_velocity", + } + assert set(features_df.columns) == expected_cols + + for col in features_df.columns: + if col != "instantaneous_velocity": + assert np.isfinite(features_df[col].iloc[0]), f"Feature {col} is not finite" diff --git a/viscy/data/cell_classification.py b/viscy/data/cell_classification.py new file mode 100644 index 000000000..ac72a6601 --- /dev/null +++ b/viscy/data/cell_classification.py @@ -0,0 +1,185 @@ +from pathlib import Path +from typing import Callable + +import pandas as pd +import torch +from iohub.ngff import Plate, open_ome_zarr +from lightning.pytorch import LightningDataModule +from monai.transforms import Compose +from torch import Tensor +from torch.utils.data import DataLoader, Dataset + +from viscy.data.hcs import _read_norm_meta +from viscy.data.triplet import INDEX_COLUMNS + + +class ClassificationDataset(Dataset): + def __init__( + self, + plate: Plate, + annotation: pd.DataFrame, + channel_name: str, + z_range: tuple[int, int], + transform: Callable | None, + initial_yx_patch_size: tuple[int, int], + return_indices: bool = False, + ): + self.plate = plate + self.z_range = z_range + self.initial_yx_patch_size = initial_yx_patch_size + self.transform = transform + self.channel_name = channel_name + self.channel_index = plate.get_channel_index(channel_name) + self.return_indices = return_indices + y_exclude, x_exclude = ( + self.initial_yx_patch_size[0] // 2, + self.initial_yx_patch_size[1] // 2, + ) + example_image_shape = next(plate.positions())[1]["0"].shape + y_range = (y_exclude, example_image_shape[-2] - y_exclude) + x_range = (x_exclude, example_image_shape[-1] - x_exclude) + self.annotation = annotation[ + annotation["y"].between(*y_range, inclusive="neither") + & annotation["x"].between(*x_range, inclusive="neither") + ] + + def __len__(self): + return len(self.annotation) + + def __getitem__( + self, idx + ) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, dict[str, int | str]]: + row = self.annotation.iloc[idx] + fov_name, t, y, x = row["fov_name"], row["t"], row["y"], row["x"] + fov = self.plate[fov_name] + y_half, x_half = (s // 2 for s in self.initial_yx_patch_size) + image = torch.from_numpy( + fov["0"][ + t, + self.channel_index, + slice(*self.z_range), + slice(y - y_half, y + y_half), + slice(x - x_half, x + x_half), + ] + ).float()[None] + norm_meta = _read_norm_meta(fov)[self.channel_name]["fov_statistics"] + img = (image - norm_meta["mean"]) / norm_meta["std"] + if self.transform is not None: + img = self.transform(img) + label = torch.tensor(row["infection_state"]).float()[None] + if self.return_indices: + 
return img, label, row[INDEX_COLUMNS].to_dict()
+        else:
+            return img, label
+
+
+class ClassificationDataModule(LightningDataModule):
+    def __init__(
+        self,
+        image_path: Path,
+        annotation_path: Path,
+        val_fovs: list[str] | None,
+        channel_name: str,
+        z_range: tuple[int, int],
+        train_exclude_timepoints: list[int],
+        train_transforms: list[Callable] | None,
+        val_transforms: list[Callable] | None,
+        initial_yx_patch_size: tuple[int, int],
+        batch_size: int,
+        num_workers: int,
+    ):
+        super().__init__()
+        self.image_path = image_path
+        self.annotation_path = annotation_path
+        self.val_fovs = val_fovs
+        self.channel_name = channel_name
+        self.z_range = z_range
+        self.train_exclude_timepoints = train_exclude_timepoints
+        self.train_transform = Compose(train_transforms)
+        self.val_transform = Compose(val_transforms)
+        self.initial_yx_patch_size = initial_yx_patch_size
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+    def _subset(
+        self,
+        plate: Plate,
+        annotation: pd.DataFrame,
+        fov_names: list[str],
+        transform: Callable | None,
+        exclude_timepoints: list[int] = [],
+        return_indices: bool = False,
+    ) -> ClassificationDataset:
+        if exclude_timepoints:
+            filter_timepoints = annotation["t"].isin(exclude_timepoints)
+            annotation = annotation[~filter_timepoints]
+        return ClassificationDataset(
+            plate=plate,
+            annotation=annotation[annotation["fov_name"].isin(fov_names)],
+            channel_name=self.channel_name,
+            z_range=self.z_range,
+            transform=transform,
+            initial_yx_patch_size=self.initial_yx_patch_size,
+            return_indices=return_indices,
+        )
+
+    def setup(self, stage=None):
+        plate = open_ome_zarr(self.image_path)
+        all_fovs = ["/" + name for (name, _) in plate.positions()]
+        annotation = pd.read_csv(self.annotation_path)
+        for column in ("t", "y", "x"):
+            annotation[column] = annotation[column].astype(int)
+        if stage in (None, "fit", "validate"):
+            train_fovs = list(set(all_fovs) - set(self.val_fovs))
+            self.train_dataset = self._subset(
+                plate,
+                annotation,
+                train_fovs,
+                transform=self.train_transform,
+                exclude_timepoints=self.train_exclude_timepoints,
+            )
+            self.val_dataset = self._subset(
+                plate,
+                annotation,
+                self.val_fovs,
+                transform=self.val_transform,
+                exclude_timepoints=[],
+            )
+        elif stage == "predict":
+            self.predict_dataset = ClassificationDataset(
+                plate=plate,
+                annotation=annotation,
+                channel_name=self.channel_name,
+                z_range=self.z_range,
+                transform=None,
+                initial_yx_patch_size=self.initial_yx_patch_size,
+                return_indices=True,
+            )
+        elif stage == "test":
+            raise NotImplementedError("Test stage not implemented.")
+        else:
+            raise ValueError(f"Unknown stage: {stage}")
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=True,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+        )
+
+    def predict_dataloader(self):
+        return DataLoader(
+            self.predict_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+        )
diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py
index 3ec8ca9e2..5eb200f33 100644
--- a/viscy/data/gpu_aug.py
+++ b/viscy/data/gpu_aug.py
@@ -18,8 +18,8 @@
 from viscy.data.distributed import ShardedDistributedSampler
 from viscy.data.hcs import _ensure_channel_list, _read_norm_meta
+from viscy.data.select import SelectWell
 from viscy.data.typing import DictTransform, NormMeta
-from viscy.preprocessing.precompute import 
_filter_fovs, _filter_wells if TYPE_CHECKING: from multiprocessing.managers import DictProxy @@ -173,22 +173,6 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]: return sample -class SelectWell: - _include_wells: list[str] | None - _exclude_fovs: list[str] | None - - def _filter_fit_fovs(self, plate: Plate) -> list[Position]: - positions = [] - for well in _filter_wells(plate, include_wells=self._include_wells): - for fov in _filter_fovs(well, exclude_fovs=self._exclude_fovs): - positions.append(fov) - if len(positions) < 2: - raise ValueError( - "At least 2 FOVs are required for training and validation." - ) - return positions - - class CachedOmeZarrDataModule(GPUTransformDataModule, SelectWell): """Data module for cached OME-Zarr arrays. diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 52d0e5103..4a7e9dfb5 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -345,6 +345,7 @@ def __init__( ground_truth_masks: Path | None = None, persistent_workers=False, prefetch_factor=None, + pin_memory=False, ): super().__init__() self.data_path = Path(data_path) @@ -363,6 +364,7 @@ def __init__( self.prepare_data_per_node = True self.persistent_workers = persistent_workers self.prefetch_factor = prefetch_factor + self.pin_memory = pin_memory @property def cache_path(self): @@ -554,6 +556,7 @@ def train_dataloader(self): persistent_workers=self.persistent_workers, collate_fn=_collate_samples, drop_last=True, + pin_memory=self.pin_memory, ) def val_dataloader(self): @@ -564,6 +567,7 @@ def val_dataloader(self): shuffle=False, prefetch_factor=self.prefetch_factor if self.num_workers else None, persistent_workers=self.persistent_workers, + pin_memory=self.pin_memory, ) def test_dataloader(self): diff --git a/viscy/data/mmap_cache.py b/viscy/data/mmap_cache.py index f621c87ca..735159903 100644 --- a/viscy/data/mmap_cache.py +++ b/viscy/data/mmap_cache.py @@ -16,8 +16,9 @@ from torch.multiprocessing import Manager from torch.utils.data import Dataset -from viscy.data.gpu_aug import GPUTransformDataModule, SelectWell +from viscy.data.gpu_aug import GPUTransformDataModule from viscy.data.hcs import _ensure_channel_list, _read_norm_meta +from viscy.data.select import SelectWell from viscy.data.typing import DictTransform, NormMeta if TYPE_CHECKING: diff --git a/viscy/data/select.py b/viscy/data/select.py new file mode 100644 index 000000000..6e00c10e8 --- /dev/null +++ b/viscy/data/select.py @@ -0,0 +1,36 @@ +from typing import Generator + +from iohub.ngff.nodes import Plate, Position, Well + + +def _filter_wells( + plate: Plate, include_wells: list[str] | None +) -> Generator[Well, None, None]: + for well_name, well in plate.wells(): + if include_wells is None or well_name in include_wells: + yield well + + +def _filter_fovs( + well: Well, exclude_fovs: list[str] | None +) -> Generator[Position, None, None]: + for _, fov in well.positions(): + fov_name = fov.zgroup.name.strip("/") + if exclude_fovs is None or fov_name not in exclude_fovs: + yield fov + + +class SelectWell: + _include_wells: list[str] | None + _exclude_fovs: list[str] | None + + def _filter_fit_fovs(self, plate: Plate) -> list[Position]: + positions = [] + for well in _filter_wells(plate, include_wells=self._include_wells): + for fov in _filter_fovs(well, exclude_fovs=self._exclude_fovs): + positions.append(fov) + if len(positions) < 2: + raise ValueError( + "At least 2 FOVs are required for training and validation." 
+ ) + return positions diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index a04a20f47..c25a0fc74 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -10,11 +10,22 @@ from torch.utils.data import Dataset from viscy.data.hcs import HCSDataModule, _read_norm_meta +from viscy.data.select import _filter_fovs, _filter_wells from viscy.data.typing import DictTransform, NormMeta, TripletSample _logger = logging.getLogger("lightning.pytorch") -INDEX_COLUMNS = ["fov_name", "track_id", "t", "id", "parent_track_id", "parent_id"] +INDEX_COLUMNS = [ + "fov_name", + "track_id", + "t", + "id", + "parent_track_id", + "parent_id", + "z", + "y", + "x", +] def _scatter_channels( @@ -272,7 +283,16 @@ def __getitem__(self, index: int) -> TripletSample: else: sample.update({"positive": positive_patch}) else: - sample.update({"index": anchor_row[INDEX_COLUMNS].to_dict()}) + # For new predictions, ensure all INDEX_COLUMNS are included + index_dict = {} + for col in INDEX_COLUMNS: + if col in anchor_row.index: + index_dict[col] = anchor_row[col] + else: + # Skip y and x for legacy data - they weren't part of INDEX_COLUMNS + if col not in ["y", "x", "z"]: + raise KeyError(f"Required column '{col}' not found in data") + sample.update({"index": index_dict}) return sample @@ -291,11 +311,16 @@ def __init__( normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], caching: bool = False, + fit_include_wells: list[str] | None = None, + fit_exclude_fovs: list[str] | None = None, predict_cells: bool = False, include_fov_names: list[str] | None = None, include_track_ids: list[int] | None = None, time_interval: Literal["any"] | int = "any", return_negative: bool = True, + persistent_workers: bool = False, + prefetch_factor: int | None = None, + pin_memory: bool = False, ): """Lightning data module for triplet sampling of patches. 
@@ -325,6 +350,10 @@ def __init__( Augmentation transforms, by default [] caching : bool, optional Whether to cache the dataset, by default False + fit_include_wells : list[str], optional + Only include these wells for fitting, by default None + fit_exclude_fovs : list[str], optional + Exclude these FOVs for fitting, by default None predict_cells : bool, optional Only predict for selected cells, by default False include_fov_names : list[str] | None, optional @@ -339,6 +368,12 @@ def __init__( Whether to return the negative sample during the fit stage (can be set to False when using a loss function like NT-Xent), by default True + persistent_workers : bool, optional + Whether to keep worker processes alive between iterations, by default False + prefetch_factor : int | None, optional + Number of batches loaded in advance by each worker, by default None + pin_memory : bool, optional + Whether to pin memory in CPU for faster GPU transfer, by default False """ super().__init__( data_path=data_path, @@ -353,10 +388,15 @@ def __init__( normalizations=normalizations, augmentations=augmentations, caching=caching, + persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + pin_memory=pin_memory, ) self.z_range = slice(*z_range) self.tracks_path = Path(tracks_path) self.initial_yx_patch_size = initial_yx_patch_size + self._include_wells = fit_include_wells + self._exclude_fovs = fit_exclude_fovs self.predict_cells = predict_cells self.include_fov_names = include_fov_names self.include_track_ids = include_track_ids @@ -377,12 +417,13 @@ def _align_tracks_tables_with_positions( positions = [] tracks_tables = [] images_plate = open_ome_zarr(self.data_path) - for fov_name, _ in open_ome_zarr(self.tracks_path).positions(): - positions.append(images_plate[fov_name]) - tracks_df = pd.read_csv( - next((self.tracks_path / fov_name).glob("*.csv")) - ).astype(int) - tracks_tables.append(tracks_df) + for well in _filter_wells(images_plate, include_wells=self._include_wells): + for fov in _filter_fovs(well, exclude_fovs=self._exclude_fovs): + positions.append(fov) + tracks_df = pd.read_csv( + next((self.tracks_path / fov.zgroup.name.strip("/")).glob("*.csv")) + ).astype(int) + tracks_tables.append(tracks_df) return positions, tracks_tables diff --git a/viscy/preprocessing/precompute.py b/viscy/preprocessing/precompute.py index 815354ded..1c68ad300 100644 --- a/viscy/preprocessing/precompute.py +++ b/viscy/preprocessing/precompute.py @@ -3,14 +3,13 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Generator, Literal +from typing import Literal import dask.array as da from dask.diagnostics import ProgressBar from iohub.ngff import open_ome_zarr -if TYPE_CHECKING: - from iohub.ngff.nodes import Plate, Position, Well +from viscy.data.select import _filter_fovs, _filter_wells def _normalize_image( @@ -31,22 +30,6 @@ def _normalize_image( return (image - subtrahend_value) / divisor_value -def _filter_wells( - plate: Plate, include_wells: list[str] | None -) -> Generator[Well, None, None]: - for well_name, well in plate.wells(): - if include_wells is None or well_name in include_wells: - yield well - - -def _filter_fovs( - well: Well, exclude_fovs: list[str] | None -) -> Generator[Position, None, None]: - for fov_name, fov in well.positions(): - if exclude_fovs is None or fov_name not in exclude_fovs: - yield fov - - def precompute_array( data_path: Path, output_path: Path, diff --git a/viscy/representation/classification.py 
b/viscy/representation/classification.py new file mode 100644 index 000000000..0b4ed58a8 --- /dev/null +++ b/viscy/representation/classification.py @@ -0,0 +1,105 @@ +from pathlib import Path + +import pandas as pd +import torch +from lightning.pytorch import LightningModule +from lightning.pytorch.callbacks import BasePredictionWriter +from torch import nn +from torchmetrics.functional.classification import binary_accuracy, binary_f1_score + +from viscy.representation.contrastive import ContrastiveEncoder +from viscy.utils.log_images import render_images + + +class ClassificationPredictionWriter(BasePredictionWriter): + def __init__(self, output_path: Path): + super().__init__("epoch") + if Path(output_path).exists(): + raise FileExistsError(f"Output path {output_path} already exists.") + self.output_path = output_path + + def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): + all_predictions = [] + for prediction in predictions: + for key, value in prediction.items(): + if isinstance(value, torch.Tensor): + prediction[key] = value.detach().cpu().numpy().flatten() + all_predictions.append(pd.DataFrame(prediction)) + pd.concat(all_predictions).to_csv(self.output_path, index=False) + + +class ClassificationModule(LightningModule): + def __init__( + self, + encoder: ContrastiveEncoder, + lr: float | None, + loss: nn.Module | None = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(1.0)), + ): + super().__init__() + self.stem = encoder.stem + self.backbone = encoder.encoder + self.backbone.head.fc = nn.Linear(768, 1) + self.loss = loss + self.lr = lr + self.example_input_array = torch.rand(2, 1, 15, 160, 160) + + def forward(self, x): + x = self.stem(x) + return self.backbone(x) + + def on_fit_start(self): + self.train_examples = [] + self.val_examples = [] + + def _fit_step(self, batch, stage: str, loss_on_step: bool): + x, y = batch + y_hat = self(x) + loss = self.loss(y_hat, y) + acc = binary_accuracy(y_hat, y) + f1 = binary_f1_score(y_hat, y) + self.log(f"loss/{stage}", loss, on_step=loss_on_step, on_epoch=True) + self.log_dict( + {f"metric/accuracy/{stage}": acc, f"metric/f1_score/{stage}": f1}, + on_step=False, + on_epoch=True, + ) + return loss, x[0, 0, x.shape[2] // 2].detach().cpu().numpy() + + def training_step(self, batch, batch_idx: int): + loss, example = self._fit_step(batch, "train", loss_on_step=True) + if batch_idx < 4: + self.train_examples.append([example]) + return loss + + def validation_step(self, batch, batch_idx: int): + loss, example = self._fit_step(batch, "val", loss_on_step=False) + if batch_idx < 4: + self.val_examples.append([example]) + return loss + + def predict_step(self, batch, batch_idx: int, dataloader_idx: int | None = None): + x, y, indices = batch + y_hat = nn.functional.sigmoid(self(x)) + indices["label"] = y + indices["prediction"] = y_hat + return indices + + def _log_images(self, examples, stage): + image = render_images(examples) + self.logger.experiment.add_image( + f"{stage}/examples", + image, + global_step=self.current_epoch, + dataformats="HWC", + ) + + def on_train_epoch_end(self): + self._log_images(self.train_examples, "train") + self.train_examples.clear() + + def on_validation_epoch_end(self): + self._log_images(self.val_examples, "val") + self.val_examples.clear() + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=self.lr) diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index c188c7396..7b6296356 100644 --- 
a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -1,7 +1,8 @@ import logging from pathlib import Path -from typing import Literal, Sequence +from typing import Any, Dict, Literal, Optional, Sequence +import numpy as np import pandas as pd import torch from lightning.pytorch import LightningModule, Trainer @@ -13,16 +14,18 @@ from viscy.representation.engine import ContrastivePrediction from viscy.representation.evaluation.dimensionality_reduction import ( _fit_transform_umap, + compute_pca, compute_phate, ) -__all__ = ["read_embedding_dataset", "EmbeddingWriter"] +__all__ = ["read_embedding_dataset", "EmbeddingWriter", "write_embedding_dataset"] _logger = logging.getLogger("lightning.pytorch") def read_embedding_dataset(path: Path) -> Dataset: """ Read the embedding dataset written by the EmbeddingWriter callback. + Supports both legacy datasets (without x/y coordinates) and new datasets. Parameters ---------- @@ -34,7 +37,19 @@ def read_embedding_dataset(path: Path) -> Dataset: Dataset Xarray dataset with features and projections. """ - return open_zarr(path).set_index(sample=INDEX_COLUMNS) + dataset = open_zarr(path) + # Check which index columns are present in the dataset + available_cols = [col for col in INDEX_COLUMNS if col in dataset.coords] + + # Warn if any INDEX_COLUMNS are missing + missing_cols = set(INDEX_COLUMNS) - set(available_cols) + if missing_cols: + _logger.warning( + f"Dataset at {path} is missing index columns: {sorted(missing_cols)}. " + "This appears to be a legacy dataset format." + ) + + return dataset.set_index(sample=available_cols) def _move_and_stack_embeddings( @@ -44,47 +59,166 @@ def _move_and_stack_embeddings( return torch.cat([p[key].cpu() for p in predictions], dim=0).numpy() -class EmbeddingWriter(BasePredictionWriter): +def write_embedding_dataset( + output_path: Path, + features: np.ndarray, + index_df: pd.DataFrame, + projections: Optional[np.ndarray] = None, + umap_kwargs: Optional[Dict[str, Any]] = None, + phate_kwargs: Optional[Dict[str, Any]] = None, + pca_kwargs: Optional[Dict[str, Any]] = None, + overwrite: bool = False, +) -> None: """ - Callback to write embeddings to a zarr store in an Xarray-compatible format. + Write embeddings to a zarr store in an Xarray-compatible format. Parameters ---------- output_path : Path Path to the zarr store. - write_interval : Literal["batch", "epoch", "batch_and_epoch"], optional - When to write the embeddings, by default 'epoch'. - phate_kwargs : dict, optional - Keyword arguments passed to PHATE, by default None. + features : np.ndarray + Array of shape (n_samples, n_features) containing the embeddings. + index_df : pd.DataFrame + DataFrame containing the index information for each embedding. + projections : np.ndarray, optional + Array of shape (n_samples, n_projections) containing projection values, by default None. + umap_kwargs : Dict[str, Any], optional + Keyword arguments passed to UMAP, by default None (i.e. UMAP is not computed) + Common parameters include: + - n_components: int, dimensions for projection (default: 2) + - n_neighbors: int, number of nearest neighbors (default: 15) + - min_dist: float, minimum distance between points (default: 0.1) + - metric: str, distance metric (default: 'euclidean') + phate_kwargs : Dict[str, Any], optional + Keyword arguments passed to PHATE, by default None (i.e. 
PHATE is not computed) Common parameters include: - knn: int, number of nearest neighbors (default: 5) - decay: int, decay rate for kernel (default: 40) - n_jobs: int, number of jobs for parallel processing - t: int, number of diffusion steps - potential_method: str, potential method to use - See phate.PHATE for all available parameters. + pca_kwargs : Dict[str, Any], optional + Keyword arguments passed to PCA, by default None (i.e. PCA is not computed) + Common parameters include: + - n_components: int, dimensions for projection (default: 2) + - whiten: bool, whether to whiten (default: False) + overwrite : bool, optional + Whether to overwrite existing zarr store, by default False. + + Raises + ------ + FileExistsError + If output_path exists and overwrite is False. + """ + output_path = Path(output_path) + + # Check if output_path exists + if output_path.exists() and not overwrite: + raise FileExistsError(f"Output path {output_path} already exists.") + + # Create a copy of the index DataFrame to avoid modifying the original + ultrack_indices = index_df.copy() + n_samples = len(features) + + # Set up default kwargs for each method + if umap_kwargs: + if umap_kwargs["n_neighbors"] >= n_samples: + _logger.warning( + f"Reducing n_neighbors from {umap_kwargs['n_neighbors']} to {min(15, n_samples // 2)} due to small dataset size" + ) + umap_kwargs["n_neighbors"] = min(15, n_samples // 2) + + _logger.debug(f"Using UMAP kwargs: {umap_kwargs}") + _, UMAP = _fit_transform_umap(features, **umap_kwargs) + for i in range(UMAP.shape[1]): + ultrack_indices[f"UMAP{i + 1}"] = UMAP[:, i] + + if phate_kwargs: + # Update with user-provided kwargs + _logger.debug(f"Using PHATE kwargs: {phate_kwargs}") + # Ensure knn is appropriate for dataset size for PHATE + if phate_kwargs["knn"] >= n_samples: + _logger.warning( + f"Reducing knn from {phate_kwargs['knn']} to {max(2, n_samples // 2)} due to small dataset size" + ) + phate_kwargs["knn"] = max(2, n_samples // 2) + + # Compute PHATE + try: + _logger.debug("Computing PHATE") + _, PHATE = compute_phate(features, **phate_kwargs) + for i in range(PHATE.shape[1]): + ultrack_indices[f"PHATE{i + 1}"] = PHATE[:, i] + except Exception as e: + _logger.warning(f"PHATE computation failed: {str(e)}") + + if pca_kwargs: + # Update with user-provided kwargs + _logger.debug(f"Using PCA kwargs: {pca_kwargs}") + try: + _logger.debug("Computing PCA") + PCA_features, _ = compute_pca(features, **pca_kwargs) + for i in range(PCA_features.shape[1]): + ultrack_indices[f"PCA{i + 1}"] = PCA_features[:, i] + except Exception as e: + _logger.warning(f"PCA computation failed: {str(e)}") + + # Create multi-index and dataset + index = pd.MultiIndex.from_frame(ultrack_indices) + + # Create dataset dictionary with features + dataset_dict = {"features": (("sample", "features"), features)} + + # Add projections if provided + if projections is not None: + dataset_dict["projections"] = (("sample", "projections"), projections) + + # Create the dataset + dataset = Dataset(dataset_dict, coords={"sample": index}).reset_index("sample") + + _logger.debug(f"Writing dataset to {output_path}") + with dataset.to_zarr(output_path, mode="w") as zarr_store: + zarr_store.close() + + +class EmbeddingWriter(BasePredictionWriter): + """ + Callback to write embeddings to a zarr store in an Xarray-compatible format. + + Parameters + ---------- + output_path : Path + Path to the zarr store. + write_interval : Literal["batch", "epoch", "batch_and_epoch"], optional + When to write the embeddings, by default 'epoch'. 
+ umap_kwargs : dict, optional + Keyword arguments passed to UMAP, by default None (i.e. UMAP is not computed). + phate_kwargs : dict, optional + Keyword arguments passed to PHATE, by default PHATE is computed with default parameters. + pca_kwargs : dict, optional + Keyword arguments passed to PCA, by default PCA is computed with default parameters. """ def __init__( self, output_path: Path, write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "epoch", - phate_kwargs: dict | None = None, - ): - super().__init__(write_interval) - self.output_path = Path(output_path) - - # Set default PHATE parameters - default_phate_kwargs = { - "n_components": 2, + umap_kwargs: dict | None = None, + phate_kwargs: dict | None = { "knn": 5, "decay": 40, "n_jobs": -1, "random_state": 42, - } - if phate_kwargs is not None: - default_phate_kwargs.update(phate_kwargs) - self.phate_kwargs = default_phate_kwargs + }, + pca_kwargs: dict | None = {"n_components": 8}, + overwrite: bool = False, + ): + super().__init__(write_interval) + self.output_path = Path(output_path) + self.umap_kwargs = umap_kwargs + self.phate_kwargs = phate_kwargs + self.pca_kwargs = pca_kwargs + self.overwrite = overwrite def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: if self.output_path.exists(): @@ -115,29 +249,13 @@ def write_on_epoch_end( projections = _move_and_stack_embeddings(predictions, "projections") ultrack_indices = pd.concat([pd.DataFrame(p["index"]) for p in predictions]) - _logger.info( - f"Computing dimensionality reductions for {len(features)} samples." + write_embedding_dataset( + output_path=self.output_path, + features=features, + index_df=ultrack_indices, + projections=projections, + umap_kwargs=self.umap_kwargs, + phate_kwargs=self.phate_kwargs, + pca_kwargs=self.pca_kwargs, + overwrite=self.overwrite, ) - _, umap = _fit_transform_umap(features, n_components=2, normalize=True) - _, phate = compute_phate( - features, - **self.phate_kwargs, - ) - - # Add dimensionality reduction coordinates - ultrack_indices["UMAP1"], ultrack_indices["UMAP2"] = umap[:, 0], umap[:, 1] - ultrack_indices["PHATE1"], ultrack_indices["PHATE2"] = phate[:, 0], phate[:, 1] - - # Create multi-index and dataset - index = pd.MultiIndex.from_frame(ultrack_indices) - dataset = Dataset( - { - "features": (("sample", "features"), features), - "projections": (("sample", "projections"), projections), - }, - coords={"sample": index}, - ).reset_index("sample") - - _logger.debug(f"Writing predictions dataset:\n{dataset}") - with dataset.to_zarr(self.output_path, mode="w") as zarr_store: - zarr_store.close() diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index ca67263a2..7a35d93f1 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -50,9 +50,20 @@ def __init__( self.validation_step_outputs = [] self.log_embeddings = log_embeddings - def forward(self, x: Tensor) -> Tensor: - "Only return projected embeddings for training and validation." - return self.model(x)[1] + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Return both features and projections. 
+ + Parameters + ---------- + x : Tensor + Input tensor + + Returns + ------- + tuple[Tensor, Tensor] + Tuple of (features, projections) + """ + return self.model(x) def log_feature_statistics(self, embeddings: Tensor, prefix: str): mean = torch.mean(embeddings, dim=0).detach().cpu().numpy() @@ -137,8 +148,8 @@ def log_embedding_umap(self, embeddings: Tensor, tag: str): def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: anchor_img = batch["anchor"] pos_img = batch["positive"] - anchor_projection = self(anchor_img) - positive_projection = self(pos_img) + _, anchor_projection = self(anchor_img) + _, positive_projection = self(pos_img) negative_projection = None if isinstance(self.loss_function, NTXentLoss): indices = torch.arange( @@ -151,7 +162,7 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: self._log_step_samples(batch_idx, (anchor_img, pos_img), "train") else: neg_img = batch["negative"] - negative_projection = self(neg_img) + _, negative_projection = self(neg_img) loss = self.loss_function( anchor_projection, positive_projection, negative_projection ) @@ -180,8 +191,8 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: """Validation step of the model.""" anchor = batch["anchor"] pos_img = batch["positive"] - anchor_projection = self(anchor) - positive_projection = self(pos_img) + _, anchor_projection = self(anchor) + _, positive_projection = self(pos_img) negative_projection = None if isinstance(self.loss_function, NTXentLoss): indices = torch.arange( @@ -194,7 +205,7 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: self._log_step_samples(batch_idx, (anchor, pos_img), "val") else: neg_img = batch["negative"] - negative_projection = self(neg_img) + _, negative_projection = self(neg_img) loss = self.loss_function( anchor_projection, positive_projection, negative_projection ) diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index 5916f7124..eb5d43f91 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -1,7 +1,6 @@ """PCA and UMAP dimensionality reduction.""" import pandas as pd -import phate import umap from numpy.typing import NDArray from sklearn.decomposition import PCA @@ -16,7 +15,7 @@ def compute_phate( decay: int = 40, update_dataset: bool = False, **phate_kwargs, -) -> tuple[phate.PHATE, NDArray]: +) -> tuple[object, NDArray]: """ Compute PHATE embeddings for features and optionally update dataset. @@ -38,10 +37,20 @@ def compute_phate( Returns ------- - phate.PHATE, NDArray + tuple[object, NDArray] PHATE model and PHATE embeddings + + Raises + ------ + ImportError + If PHATE is not installed. Install with: pip install viscy[phate] """ - import phate + try: + import phate + except ImportError: + raise ImportError( + "PHATE is not available. Install with: pip install viscy[phate]" + ) # Get embeddings from dataset if needed embeddings = ( @@ -67,43 +76,58 @@ def compute_phate( def compute_pca(embedding_dataset, n_components=None, normalize_features=True): - features = embedding_dataset["features"] - projections = embedding_dataset["projections"] + """Compute PCA embeddings for features and optionally update dataset. + + Parameters + ---------- + embedding_dataset : xarray.Dataset or NDArray + The dataset containing embeddings, timepoints, fov_name, and track_id, + or a numpy array of embeddings. 
+ n_components : int, optional + Number of components to keep in the PCA, by default None + normalize_features : bool, optional + Whether to normalize the features, by default True + + Returns + ------- + tuple[NDArray, pd.DataFrame] + PCA embeddings and PCA DataFrame + """ + + embeddings = ( + embedding_dataset["features"].values + if isinstance(embedding_dataset, Dataset) + else embedding_dataset + ) if normalize_features: - scaled_projections = StandardScaler().fit_transform(projections.values) - scaled_features = StandardScaler().fit_transform(features.values) + scaled_features = StandardScaler().fit_transform(embeddings) else: - scaled_projections = projections.values - scaled_features = features.values + scaled_features = embeddings # Compute PCA with specified number of components PCA_features = PCA(n_components=n_components, random_state=42) - PCA_projection = PCA(n_components=n_components, random_state=42) pc_features = PCA_features.fit_transform(scaled_features) - pc_projection = PCA_projection.fit_transform(scaled_projections) - # Prepare DataFrame with id and PCA coordinates - pca_df = pd.DataFrame( - { + # Create base dictionary with id and fov_name + if isinstance(embedding_dataset, Dataset): + pca_dict = { "id": embedding_dataset["id"].values, "fov_name": embedding_dataset["fov_name"].values, - "PCA1": pc_features[:, 0], - "PCA2": pc_features[:, 1], - "PCA3": pc_features[:, 2], - "PCA4": pc_features[:, 3], - "PCA5": pc_features[:, 4], - "PCA6": pc_features[:, 5], - "PCA1_proj": pc_projection[:, 0], - "PCA2_proj": pc_projection[:, 1], - "PCA3_proj": pc_projection[:, 2], - "PCA4_proj": pc_projection[:, 3], - "PCA5_proj": pc_projection[:, 4], - "PCA6_proj": pc_projection[:, 5], + "t": embedding_dataset["t"].values, + "track_id": embedding_dataset["track_id"].values, } - ) + else: + pca_dict = {} + + # Add PCA components for features + for i in range(pc_features.shape[1]): + pca_dict[f"PCA{i + 1}"] = pc_features[:, i] + + # Create DataFrame with all components + pca_df = pd.DataFrame(pca_dict) - return PCA_features, PCA_projection, pca_df + return pc_features, pca_df def _fit_transform_umap( diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 9a1c72ef3..a920eb072 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -1,4 +1,5 @@ from collections import defaultdict +from typing import Literal import numpy as np from sklearn.metrics.pairwise import cosine_similarity @@ -22,65 +23,111 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): def compute_displacement( embedding_dataset, - max_tau=10, - use_cosine=False, - use_dissimilarity=False, - use_umap=False, - return_mean_std=False, -): - """Compute the norm of differences between embeddings at t and t + tau""" + distance_metric: Literal["euclidean_squared", "cosine"] = "euclidean_squared", +) -> dict[int, list[float]]: + """Compute the displacement or mean square displacement (MSD) of embeddings. + + For each time difference τ, computes either: + - |r(t + τ) - r(t)|² for squared Euclidean (MSD) + - cos_sim(r(t + τ), r(t)) for cosine + for all particles and initial times t. + + Parameters + ---------- + embedding_dataset : xarray.Dataset + Dataset containing embeddings and metadata + distance_metric : str + The metric to use for computing distances between embeddings. 
+        Valid options are:
+        - "euclidean_squared": Squared Euclidean distance (for MSD, default)
+        - "cosine": Cosine similarity
+
+    Returns
+    -------
+    dict[int, list[float]]
+        Dictionary mapping τ to list of displacements for all particles and initial times
+    """
+    # Get unique tracks efficiently using pandas operations
+    unique_tracks_df = (
+        embedding_dataset[["fov_name", "track_id"]].to_dataframe().drop_duplicates()
+    )
+
+    # Get data from dataset
     fov_names = embedding_dataset["fov_name"].values
     track_ids = embedding_dataset["track_id"].values
     timepoints = embedding_dataset["t"].values
+    embeddings = embedding_dataset["features"].values
 
-    if use_umap:
-        embeddings = np.vstack(
-            (embedding_dataset["UMAP1"].values, embedding_dataset["UMAP2"].values)
-        ).T
-    else:
-        embeddings = embedding_dataset["features"].values
-
+    # Initialize results dictionary with empty lists
     displacement_per_tau = defaultdict(list)
 
-    for i in range(len(fov_names)):
-        fov_name = fov_names[i]
-        track_id = track_ids[i]
-        current_time = timepoints[i]
-        current_embedding = embeddings[i].reshape(1, -1)
-
-        for tau in range(1, max_tau + 1):
-            future_time = current_time + tau
-            matching_indices = np.where(
-                (fov_names == fov_name)
-                & (track_ids == track_id)
-                & (timepoints == future_time)
-            )[0]
-
-            if len(matching_indices) == 1:
-                future_embedding = embeddings[matching_indices[0]].reshape(1, -1)
-
-                if use_cosine:
-                    similarity = cosine_similarity(current_embedding, future_embedding)[
-                        0
-                    ][0]
-                    displacement = 1 - similarity if use_dissimilarity else similarity
-                else:
-                    displacement = np.sum((current_embedding - future_embedding) ** 2)
-
-                displacement_per_tau[tau].append(displacement)
-
-    if return_mean_std:
-        mean_displacement_per_tau = {
-            tau: np.mean(displacements)
-            for tau, displacements in displacement_per_tau.items()
-        }
-        std_displacement_per_tau = {
-            tau: np.std(displacements)
-            for tau, displacements in displacement_per_tau.items()
-        }
-        return mean_displacement_per_tau, std_displacement_per_tau
-
-    return displacement_per_tau
+    # Process each track
+    for fov_name, track_id in zip(
+        unique_tracks_df["fov_name"], unique_tracks_df["track_id"]
+    ):
+        # Get sorted track data
+        mask = (fov_names == fov_name) & (track_ids == track_id)
+        times = timepoints[mask]
+        track_embeddings = embeddings[mask]
+
+        # Sort by time
+        time_order = np.argsort(times)
+        times = times[time_order]
+        track_embeddings = track_embeddings[time_order]
+
+        # Process each time point
+        for t_idx, t in enumerate(times[:-1]):
+            current_embedding = track_embeddings[t_idx]
+
+            # Check all possible future time points
+            for future_idx, future_time in enumerate(
+                times[t_idx + 1 :], start=t_idx + 1
+            ):
+                tau = future_time - t
+                future_embedding = track_embeddings[future_idx]
+
+                if distance_metric == "cosine":
+                    dot_product = np.dot(current_embedding, future_embedding)
+                    norms = np.linalg.norm(current_embedding) * np.linalg.norm(
+                        future_embedding
+                    )
+                    displacement = dot_product / norms
+                else:  # squared Euclidean distance (MSD)
+                    displacement = np.sum((current_embedding - future_embedding) ** 2)
+                displacement_per_tau[int(tau)].append(displacement)
+
+    return dict(displacement_per_tau)
+
+
+def compute_displacement_statistics(
+    displacement_per_tau: dict[int, list[float]],
+) -> tuple[dict[int, float], dict[int, float]]:
+    """Compute mean and standard deviation of displacements for each tau.
+ + Parameters + ---------- + displacement_per_tau : dict[int, list[float]] + Dictionary mapping τ to list of displacements + + Returns + ------- + tuple[dict[int, float], dict[int, float]] + Tuple of (mean_displacements, std_displacements) where each is a + dictionary mapping τ to the statistic + """ + mean_displacement_per_tau = { + tau: np.mean(displacements) + for tau, displacements in displacement_per_tau.items() + } + std_displacement_per_tau = { + tau: np.std(displacements) + for tau, displacements in displacement_per_tau.items() + } + return mean_displacement_per_tau, std_displacement_per_tau def compute_dynamic_range(mean_displacement_per_tau): diff --git a/viscy/representation/evaluation/feature.py b/viscy/representation/evaluation/feature.py index 3ecbd1b6e..c06138629 100644 --- a/viscy/representation/evaluation/feature.py +++ b/viscy/representation/evaluation/feature.py @@ -1,73 +1,260 @@ +from typing import TypedDict + +import mahotas as mh import numpy as np +import pandas as pd +import scipy.stats from numpy import fft +from numpy.typing import ArrayLike +from scipy.ndimage import distance_transform_edt +from scipy.stats import linregress +from skimage.exposure import rescale_intensity from skimage.feature import graycomatrix, graycoprops from skimage.filters import gaussian, threshold_otsu +from skimage.measure import regionprops -class FeatureExtractor: - # FIXME: refactor into a separate module with standalone functions +class IntensityFeatures(TypedDict): + """Intensity-based features extracted from a single cell.""" - def __init__(self): - pass + mean_intensity: float + std_dev: float + min_intensity: float + max_intensity: float + kurtosis: float + skewness: float + spectral_entropy: float + iqr: float + weighted_intensity_gradient: float - def compute_fourier_descriptors(image): - """ - Compute the Fourier descriptors of the image - The sensor or nuclear shape changes when infected, which can be captured by analyzing Fourier descriptors - :param np.array image: input image - :return: Fourier descriptors - """ - # Convert contour to complex numbers - contour_complex = image[:, 0] + 1j * image[:, 1] - # Compute Fourier descriptors - descriptors = np.fft.fft(contour_complex) +class TextureFeatures(TypedDict): + """Texture-based features extracted from a single cell.""" + + spectral_entropy: float + contrast: float + entropy: float + homogeneity: float + dissimilarity: float + texture: float + + +class MorphologyFeatures(TypedDict): + """Morphological features extracted from a single cell.""" + + area: float + perimeter: float + perimeter_area_ratio: float + eccentricity: float + intensity_localization: float + masked_intensity: float + masked_area: float + + +class SymmetryDescriptor(TypedDict): + """Symmetry-based features extracted from a single cell.""" + + zernike_std: float + zernike_mean: float + radial_intensity_gradient: float + + +class TrackFeatures(TypedDict): + """Velocity-based features extracted from a single track.""" + + instantaneous_velocity: list[float] + mean_velocity: float + max_velocity: float + min_velocity: float + std_velocity: float + + +class DisplacementFeatures(TypedDict): + """Displacement-based features extracted from a single track.""" + + total_distance: float + net_displacement: float + directional_persistence: float + - return descriptors +class AngularFeatures(TypedDict): + """Angular features extracted from a single track.""" - def analyze_symmetry(descriptors): + mean_angular_velocity: float + max_angular_velocity: float + 
std_angular_velocity: float + + +class CellFeatures: + """Class for computing various features from a single cell image patch. + + This class provides methods to compute intensity, texture, morphological, + and symmetry features from a cell image and its segmentation mask. + + Parameters + ---------- + image : ArrayLike + Input image array of the cell. + segmentation_mask : ArrayLike, optional + Binary mask of the cell segmentation, by default None. + + Attributes + ---------- + image : ArrayLike + Input image array. + segmentation_mask : ArrayLike + Binary segmentation mask. + intensity_features : IntensityFeatures + Computed intensity features. + texture_features : TextureFeatures + Computed texture features. + morphology_features : MorphologyFeatures + Computed morphological features. + symmetry_descriptor : SymmetryDescriptor + Computed symmetry features. + """ + + def __init__(self, image: ArrayLike, segmentation_mask: ArrayLike | None = None): + self.image = image + self.segmentation_mask = segmentation_mask + self.image_normalized = rescale_intensity(self.image, out_range=(0, 1)) + + # Initialize feature containers + self.intensity_features = None + self.texture_features = None + self.morphology_features = None + self.symmetry_descriptor = None + + self._eps = 1e-10 + + def _compute_kurtosis(self): + """Compute the kurtosis of the image. + + Returns + ------- + kurtosis: float + Kurtosis of the image intensity distribution (scale-invariant). + Returns nan for constant arrays. """ - Analyze the symmetry of the Fourier descriptors - Symmetry of the sensor or nuclear shape changes when infected - :param np.array descriptors: Fourier descriptors - :return: standard deviation of the descriptors + if np.std(self.image) == 0: + return np.nan + return scipy.stats.kurtosis(self.image, fisher=True, axis=None) + + def _compute_skewness(self): + """Compute the skewness of the image. + + Returns + ------- + skewness: float + Skewness of the image intensity distribution (scale-invariant). + Returns nan for constant arrays. """ - # Normalize descriptors - descriptors = np.abs(descriptors) / np.max(np.abs(descriptors)) + if np.std(self.image) == 0: + return np.nan + return scipy.stats.skew(self.image, axis=None) - return np.std(descriptors) # Lower standard deviation indicates higher symmetry + def _compute_glcm_features(self): + """Compute GLCM-based texture features from the image. - def compute_area(input_image, sigma=0.6): - """Create a binary mask using morphological operations - Sensor area will increase when infected due to expression in nucleus - :param np.array input_image: generate masks from this 3D image - :param float sigma: Gaussian blur standard deviation, increase in value increases blur - :return: area of the sensor mask & mean intensity inside the sensor area + Converts normalized image to uint8 for GLCM computation. 
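One subtlety when reading `_compute_glcm_features` below: `skimage.feature.graycomatrix` expects angles in radians, so a 45-degree neighbor offset is written `np.pi / 4` (a bare `45` would be interpreted as 45 radians). A minimal sketch on a synthetic uint8 patch:

```python
import numpy as np
from skimage.feature import graycomatrix, graycoprops

patch = np.random.default_rng(0).integers(0, 256, size=(64, 64), dtype=np.uint8)

# Distances are in pixels; angles are in radians (np.pi / 4 == 45 degrees)
glcm = graycomatrix(
    patch, distances=[1], angles=[np.pi / 4], symmetric=True, normed=True
)
print(
    graycoprops(glcm, "contrast")[0, 0],
    graycoprops(glcm, "dissimilarity")[0, 0],
    graycoprops(glcm, "homogeneity")[0, 0],
)
```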
""" + # Convert 0-1 normalized image to uint8 (0-255) + image_uint8 = (self.image_normalized * 255).astype(np.uint8) - input_image_blur = gaussian(input_image, sigma=sigma) + glcm = graycomatrix(image_uint8, [1], [45], symmetric=True, normed=True) - thresh = threshold_otsu(input_image_blur) - mask = input_image >= thresh + contrast = graycoprops(glcm, "contrast")[0, 0] + dissimilarity = graycoprops(glcm, "dissimilarity")[0, 0] + homogeneity = graycoprops(glcm, "homogeneity")[0, 0] - # Apply sensor mask to the image - masked_image = input_image * mask + return contrast, dissimilarity, homogeneity - # Compute the mean intensity inside the sensor area - masked_intensity = np.mean(masked_image) + def _compute_iqr(self): + """Compute the interquartile range of pixel intensities. - return masked_intensity, np.sum(mask) + The IQR is observed to increase when a cell is infected, + providing a measure of intensity distribution spread. - def compute_spectral_entropy(image): + Returns + ------- + iqr: float + Interquartile range of pixel intensities. """ - Compute the spectral entropy of the image - High frequency components are observed to increase in phase and reduce in sensor when cell is infected - :param np.array image: input image - :return: spectral entropy + iqr = np.percentile(self.image, 75) - np.percentile(self.image, 25) + + return iqr + + def _compute_weighted_intensity_gradient(self): + """Compute the weighted radial intensity gradient profile. + + Calculates the slope of the azimuthally averaged radial gradient + profile, weighted by intensity. This provides information about + how intensity changes with distance from the cell center. + + Returns + ------- + slope: float + Slope of the weighted radial intensity gradient profile. """ + # Get image dimensions + h, w = self.image.shape + center_y, center_x = h // 2, w // 2 + + # Create meshgrid of coordinates + y, x = np.ogrid[:h, :w] + + # Calculate radial distances from center + r = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2) + + # Calculate gradients in x and y directions + gy, gx = np.gradient(self.image) + + # Calculate magnitude of gradient + gradient_magnitude = np.sqrt(gx**2 + gy**2) + + # Weight gradient by intensity + weighted_gradient = gradient_magnitude * self.image + + # Calculate maximum radius (to edge of image) + max_radius = int(min(h // 2, w // 2)) + + # Initialize arrays for radial profile + radial_profile = np.zeros(max_radius) + counts = np.zeros(max_radius) + + # Bin pixels by radius + for i in range(h): + for j in range(w): + radius = int(r[i, j]) + if radius < max_radius: + radial_profile[radius] += weighted_gradient[i, j] + counts[radius] += 1 + + # Average by counts (avoiding division by zero) + valid_mask = counts > 0 + radial_profile[valid_mask] /= counts[valid_mask] + + # Calculate slope using linear regression + x = np.arange(max_radius)[valid_mask] + y = radial_profile[valid_mask] + slope = np.polyfit(x, y, 1)[0] + return slope + + def _compute_spectral_entropy(self): + """Compute the spectral entropy of the image. + + Spectral entropy measures the complexity of the image's frequency + components. High frequency components are observed to increase in + phase and reduce in sensor when a cell is infected. + + Returns + ------- + entropy: float + Spectral entropy of the image. 
+ """ # Compute the 2D Fourier Transform - f_transform = fft.fft2(image) + f_transform = fft.fft2(self.image) # Compute the power spectrum power_spectrum = np.abs(f_transform) ** 2 @@ -81,94 +268,602 @@ def compute_spectral_entropy(image): return entropy - def compute_glcm_features(image): + def _compute_texture_features(self): + """Compute Haralick texture features from the image. + + Converts normalized image to uint8 for Haralick computation. """ - Compute the contrast, dissimilarity and homogeneity of the image - Both sensor and phase texture changes when infected, smooth in sensor, and rough in phase - :param np.array image: input image - :return: contrast, dissimilarity, homogeneity + # Convert 0-1 normalized image to uint8 (0-255) + image_uint8 = (self.image_normalized * 255).astype(np.uint8) + texture_features = mh.features.haralick(image_uint8).ptp(0) + return np.mean(texture_features) + + def _compute_perimeter_area_ratio(self): + """Compute the perimeter of the nuclear segmentations found inside the patch. + + This function calculates the average perimeter, average area, and their ratio + for all nuclear segmentations in the patch. + + Returns + ------- + average_perimeter, average_area, ratio: tuple + Tuple containing: + - average_perimeter : float + Average perimeter of all regions in the patch + - average_area : float + Average area of all regions + - ratio : float + Ratio of total perimeter to total area """ + total_perimeter = 0 + total_area = 0 - # Normalize the input image from 0 to 255 - image = (image - np.min(image)) * (255 / (np.max(image) - np.min(image))) - image = image.astype(np.uint8) + # Use regionprops to analyze each labeled region + regions = regionprops(self.segmentation_mask) - # Compute the GLCM - distances = [1] # Distance between pixels - angles = [0] # Angle in radians + if not regions: # If no regions found + return 0, 0, 0 - glcm = graycomatrix(image, distances, angles, symmetric=True, normed=True) + # Sum up perimeter and area for all regions + for region in regions: + total_perimeter += region.perimeter + total_area += region.area - # Compute GLCM properties - contrast = graycoprops(glcm, "contrast")[0, 0] - dissimilarity = graycoprops(glcm, "dissimilarity")[0, 0] - homogeneity = graycoprops(glcm, "homogeneity")[0, 0] + average_area = total_area / len(regions) + average_perimeter = total_perimeter / len(regions) - return contrast, dissimilarity, homogeneity + return average_perimeter, average_area, total_perimeter / total_area - def compute_iqr(image): - """ - Compute the interquartile range of pixel intensities - Observed to increase when cell is infected - :param np.array image: input image - :return: interquartile range of pixel intensities + def _compute_nucleus_eccentricity(self): + """Compute the eccentricity of the nucleus. + + Eccentricity measures how much the nucleus deviates from + a perfect circle, with 0 being perfectly circular and 1 + being a line segment. + + Returns + ------- + eccentricity: float + Eccentricity of the nucleus (0 to 1). 
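The per-region measurements here come directly from `skimage.measure.regionprops`. A small sketch on a synthetic labeled mask showing perimeter, area, and eccentricity:

```python
import numpy as np
from skimage.measure import label, regionprops

mask = np.zeros((64, 64), dtype=np.uint8)
mask[10:30, 10:30] = 1   # square blob
mask[40:60, 35:50] = 1   # rectangular blob

regions = regionprops(label(mask))
total_perimeter = sum(r.perimeter for r in regions)
total_area = sum(r.area for r in regions)
print(total_perimeter / total_area, [round(r.eccentricity, 2) for r in regions])
```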
""" + # Use regionprops to analyze each labeled region + regions = regionprops(self.segmentation_mask) - # Compute the interquartile range of pixel intensities - iqr = np.percentile(image, 75) - np.percentile(image, 25) + if not regions: # If no regions found + return 0.0 - return iqr + # Calculate mean eccentricity across all regions + eccentricities = [region.eccentricity for region in regions] + return float(np.mean(eccentricities)) - def compute_mean_intensity(image): - """ - Compute the mean pixel intensity - Expected to vary when cell morphology changes due to infection, divison or death - :param np.array image: input image - :return: mean pixel intensity + def _compute_Eucledian_distance_transform(self): + """Compute the Euclidean distance transform of the segmentation mask. + + This transform computes the distance from each pixel to the + nearest background pixel, providing information about the + spatial distribution of the cell. + + Returns + ------- + dist_transform: ndarray + Distance transform of the segmentation mask. """ + # Ensure the image is binary + binary_mask = (self.segmentation_mask > 0).astype(np.uint8) - # Compute the mean pixel intensity - mean_intensity = np.mean(image) + # Compute the distance transform using scikit-image + dist_transform = distance_transform_edt(binary_mask) - return mean_intensity + return dist_transform - def compute_std_dev(image): + def _compute_intensity_localization(self): + """Compute localization of fluor using Eucledian distance transformation and fluor intensity. + + This function computes the intensity-weighted center of the fluor + using the Euclidean distance transform of the segmentation mask. + The intensity-weighted center is calculated as the sum of the + product of the image intensity and the distance transform, + divided by the sum of the distance transform. + + Returns + ------- + intensity_weighted_center: float + Intensity-weighted center of the fluor. """ - Compute the standard deviation of pixel intensities - Expected to vary when cell morphology changes due to infection, divison or death - :param np.array image: input image - :return: standard deviation of pixel intensities + # compute EDT of mask + edt = self._compute_Eucledian_distance_transform() + # compute the intensity weighted center of the fluor + intensity_weighted_center = np.sum(self.image * edt) / (np.sum(edt) + self._eps) + return intensity_weighted_center + + def _compute_area(self, sigma=0.6): + """Create a binary mask using morphological operations. + + This function creates a binary mask from the input image using Gaussian blur + and Otsu thresholding. The sensor area will increase when infected due to + expression in nucleus. + + Parameters + ---------- + sigma : float + Gaussian blur standard deviation. 
Increasing this value increases the blur, + by default 0.6 + + Returns + ------- + masked_intensity, masked_area: tuple + Tuple containing: + - masked_intensity : float + Mean intensity inside the sensor area + - masked_area : float + Area of the sensor mask in pixels """ - # Compute the standard deviation of pixel intensities - std_dev = np.std(image) + input_image_blur = gaussian(self.image, sigma=sigma) - return std_dev + thresh = threshold_otsu(input_image_blur) + mask = self.image >= thresh - def compute_radial_intensity_gradient(image): - """ - Compute the radial intensity gradient of the image - The sensor relocalizes inside the nucleus, which is center of the image when cells are infected - Expected negative gradient when infected and zero to positive gradient when not infected - :param np.array image: input image - :return: radial intensity gradient + # Apply sensor mask to the image + masked_image = self.image * mask + + # Compute the mean intensity inside the sensor area + masked_intensity = np.mean(masked_image) + + return masked_intensity, np.sum(mask) + + def _compute_zernike_moments(self): + """Compute the Zernike moments of the image. + + Zernike moments are a set of orthogonal moments that capture + the shape of the image. They are invariant to translation, rotation, + and scale. + + Returns + ------- + zernike_moments: np.ndarray + Zernike moments of the image. """ - # normalize the image - image = (image - np.min(image)) / (np.max(image) - np.min(image)) + zernike_moments = mh.features.zernike_moments(self.image, 32) + return zernike_moments - # compute the intensity gradient from center to periphery - y, x = np.indices(image.shape) - center = np.array(image.shape) / 2 + def _compute_radial_intensity_gradient(self): + """Compute the radial intensity gradient of the image. + + Uses 0-1 normalized image directly for gradient calculation. + """ + # Use 0-1 normalized image directly + y, x = np.indices(self.image_normalized.shape) + center = np.array(self.image_normalized.shape) / 2 r = np.sqrt((x - center[1]) ** 2 + (y - center[0]) ** 2) r = r.astype(int) - tbin = np.bincount(r.ravel(), image.ravel()) + + tbin = np.bincount(r.ravel(), self.image_normalized.ravel()) nr = np.bincount(r.ravel()) radial_intensity_values = tbin / nr - # get the slope radial_intensity_values - from scipy.stats import linregress - radial_intensity_gradient = linregress( range(len(radial_intensity_values)), radial_intensity_values ) return radial_intensity_gradient[0] + + def compute_intensity_features(self): + """Compute intensity features. + + This function computes various intensity-based features from the input image. + It calculates the mean, standard deviation, minimum, maximum, kurtosis, + skewness, spectral entropy, interquartile range, and weighted intensity gradient. + + Returns + ------- + IntensityFeatures + Dictionary containing all computed intensity features. + """ + self.intensity_features = IntensityFeatures( + mean_intensity=float(np.mean(self.image)), + std_dev=float(np.std(self.image)), + min_intensity=float(np.min(self.image)), + max_intensity=float(np.max(self.image)), + kurtosis=self._compute_kurtosis(), + skewness=self._compute_skewness(), + spectral_entropy=self._compute_spectral_entropy(), + iqr=self._compute_iqr(), + weighted_intensity_gradient=self._compute_weighted_intensity_gradient(), + ) + + def compute_texture_features(self): + """Compute texture features. + + This function computes texture features from the input image. 
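`_compute_area` is blur-then-threshold: Otsu is computed on a Gaussian-smoothed copy and the threshold is applied back to the raw image. A standalone sketch of those steps on synthetic data:

```python
import numpy as np
from skimage.filters import gaussian, threshold_otsu

rng = np.random.default_rng(0)
image = rng.random((64, 64))
image[20:44, 20:44] += 1.0   # bright "sensor" region

thresh = threshold_otsu(gaussian(image, sigma=0.6))
mask = image >= thresh

# Mean is taken over the whole patch (zeros outside the mask), as in the method above
masked_intensity = np.mean(image * mask)
print(masked_intensity, int(mask.sum()))
```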
+ It calculates the spectral entropy, contrast, entropy, homogeneity, + dissimilarity, and texture features. + + Returns + ------- + TextureFeatures + Dictionary containing all computed texture features. + """ + contrast, dissimilarity, homogeneity = self._compute_glcm_features() + self.texture_features = TextureFeatures( + spectral_entropy=self._compute_spectral_entropy(), + contrast=contrast, + entropy=self._compute_spectral_entropy(), # Note: This could be redundant + homogeneity=homogeneity, + dissimilarity=dissimilarity, + texture=self._compute_texture_features(), + ) + + def compute_morphology_features(self): + """Compute morphology features. + + This function computes morphology features from the input image. + It calculates the area, perimeter, perimeter-to-area ratio, + eccentricity, intensity localization, masked intensity, and masked area. + + Returns + ------- + MorphologyFeatures + Dictionary containing all computed morphology features. + + Raises + ------ + AssertionError + If segmentation mask is None or empty + """ + if self.segmentation_mask is None: + raise AssertionError("Segmentation mask is required") + + if np.sum(self.segmentation_mask) == 0: + raise AssertionError("Segmentation mask is empty") + + masked_intensity, masked_area = self._compute_area() + perimeter, area, ratio = self._compute_perimeter_area_ratio() + self.morphology_features = MorphologyFeatures( + area=area, + perimeter=perimeter, + perimeter_area_ratio=ratio, + eccentricity=self._compute_nucleus_eccentricity(), + intensity_localization=self._compute_intensity_localization(), + masked_intensity=masked_intensity, + masked_area=masked_area, + ) + + def compute_symmetry_descriptor(self): + """Compute the symmetry descriptor of the image. + + This function computes the symmetry descriptor of the image. + It calculates the Zernike moments, Zernike mean, and radial intensity gradient. + + Returns + ------- + SymmetryDescriptor + Dictionary containing all computed symmetry descriptor features. + """ + self.symmetry_descriptor = SymmetryDescriptor( + zernike_std=np.std(self._compute_zernike_moments()), + zernike_mean=np.mean(self._compute_zernike_moments()), + radial_intensity_gradient=self._compute_radial_intensity_gradient(), + ) + + def compute_all_features(self) -> pd.DataFrame: + """Compute all features. + + This function computes all features from the input image. + It calculates the intensity, texture, symmetry descriptor, + and morphology features. + + Returns + ------- + pd.DataFrame + DataFrame containing all computed features. + """ + # Compute intensity features + self.compute_intensity_features() + + # Compute texture features + self.compute_texture_features() + + # Compute symmetry descriptor + self.compute_symmetry_descriptor() + + if self.segmentation_mask is not None: + self.compute_morphology_features() + + return self.to_df() + + def to_df(self) -> pd.DataFrame: + """Convert all features to a pandas DataFrame. + + This function combines all computed features (intensity, texture, + morphology, and symmetry features) into a single pandas DataFrame. + The features are organized in a flat structure where each column + represents a different feature. 
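Putting the class together, a hedged end-to-end usage sketch on synthetic data (this assumes the `CellFeatures` API defined above, and that `mahotas` is importable since the Haralick and Zernike features depend on it):

```python
import numpy as np

from viscy.representation.evaluation.feature import CellFeatures

rng = np.random.default_rng(0)
image = rng.random((64, 64)).astype(np.float32)
mask = np.zeros((64, 64), dtype=np.uint8)
mask[16:48, 16:48] = 1

cf = CellFeatures(image, segmentation_mask=mask)
features_df = cf.compute_all_features()  # single-row DataFrame, one column per feature
print(features_df.filter(like="intensity").T)
```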
+ + Returns + ------- + pd.DataFrame + DataFrame containing all computed features with the following columns: + - Intensity features (if computed) + - Texture features (if computed) + - Morphology features (if computed) + - Symmetry descriptor (if computed) + + Notes + ----- + Only features that have been computed (non-None) will be included + in the output DataFrame. The DataFrame will have a single row + containing all the features. + """ + features_dict = {} + if self.intensity_features: + features_dict.update(self.intensity_features) + if self.texture_features: + features_dict.update(self.texture_features) + if self.morphology_features: + features_dict.update(self.morphology_features) + if self.symmetry_descriptor: + features_dict.update(self.symmetry_descriptor) + return pd.DataFrame([features_dict]) + + +class DynamicFeatures: + """Compute dynamic features from cell tracking data. + + This class provides methods to compute various dynamic features from cell + tracking data, including velocity, displacement, and angular features. + These features are useful for analyzing cell movement patterns and behavior. + + Parameters + ---------- + tracking_df : pandas.DataFrame + DataFrame containing cell tracking data with track_id, t, x, y columns + + Attributes + ---------- + tracking_df : pandas.DataFrame + The input tracking dataframe containing cell position data over time + track_features : TrackFeatures or None + Computed velocity-based features including mean, max, min velocities + and their standard deviation + displacement_features : DisplacementFeatures or None + Computed displacement features including total distance traveled, + net displacement, and directional persistence + angular_features : AngularFeatures or None + Computed angular features including mean, max, and standard deviation + of angular velocities + + Raises + ------ + ValueError + If the tracking dataframe is missing any of the required columns + (track_id, t, x, y) + """ + + def __init__(self, tracking_df: pd.DataFrame): + self.tracking_df = tracking_df + self.track_features = None + self.displacement_features = None + self.angular_features = None + + self._eps = 1e-10 + # Verify required columns exist + required_cols = ["track_id", "t", "x", "y"] + missing_cols = [col for col in required_cols if col not in tracking_df.columns] + if missing_cols: + raise ValueError(f"Missing required columns: {missing_cols}") + + # Verify numeric types for coordinates + for col in ["t", "x", "y"]: + if not np.issubdtype(tracking_df[col].dtype, np.number): + raise ValueError(f"Column {col} must be numeric") + + def _compute_instantaneous_velocity(self, track_id: str) -> np.ndarray: + """Compute the instantaneous velocity for all timepoints in a track. 
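A toy track makes the velocity bookkeeping concrete: with unit time steps, the instantaneous speed is just the step length. A sketch assuming the `DynamicFeatures` API above:

```python
import pandas as pd

from viscy.representation.evaluation.feature import DynamicFeatures

# One track: 3 px/frame to the right for two frames, then 4 px up
tracking_df = pd.DataFrame(
    {
        "track_id": [1, 1, 1, 1],
        "t": [0, 1, 2, 3],
        "x": [0.0, 3.0, 6.0, 6.0],
        "y": [0.0, 0.0, 0.0, 4.0],
    }
)

dyn = DynamicFeatures(tracking_df)
features = dyn.compute_all_features(track_id=1)
print(features[["mean_velocity", "max_velocity", "total_distance"]])
```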
+ + Parameters + ---------- + track_id : str + ID of the track to compute velocities for + + Returns + ------- + velocities : np.ndarray + Array of instantaneous velocities for each timepoint + """ + # Get track data sorted by time + track_data = self.tracking_df[ + self.tracking_df["track_id"] == track_id + ].sort_values("t") + + # TODO: decide if we want to return nans or zeros + if len(track_data) < 2: + return np.array([0.0]) # Return zero velocity for single-point tracks + + # Calculate displacements between consecutive points + dx = np.diff(track_data["x"].values) + dy = np.diff(track_data["y"].values) + dt = np.diff(track_data["t"].values) + + # Compute distances + distances = np.sqrt(dx**2 + dy**2) + + # Compute velocities (avoid division by zero) + velocities = np.zeros(len(track_data)) + velocities[1:] = distances / np.maximum(dt, self._eps) + + return velocities + + def _compute_displacement(self, track_id: str) -> tuple[float, float, float]: + """Compute displacement-based features for a track. + + This function calculates various displacement metrics for a given track, + including total distance traveled, net displacement, and directional + persistence. These metrics help characterize the movement pattern of + the tracked cell. + + Parameters + ---------- + track_id : str + ID of the track to compute displacement features for + + Returns + ------- + total_distance, net_displacement, directional_persistence: tuple + Tuple containing: + - total_distance : float + Total distance traveled by the cell along its path + - net_displacement : float + Straight-line distance between start and end positions + - directional_persistence : float + Ratio of net displacement to total distance (0 to 1), + where 1 indicates perfectly straight movement + """ + track_data = self.tracking_df[ + self.tracking_df["track_id"] == track_id + ].sort_values("t") + + if len(track_data) < 2: + return 0.0, 0.0, 0.0 + + # Compute total distance + dx = np.diff(track_data["x"].values) + dy = np.diff(track_data["y"].values) + distances = np.sqrt(dx**2 + dy**2) + total_distance = np.sum(distances) + + # Compute net displacement + start_point = track_data.iloc[0][["x", "y"]].values + end_point = track_data.iloc[-1][["x", "y"]].values + net_displacement = np.sqrt(np.sum((end_point - start_point) ** 2)) + + # Compute directional persistence + directional_persistence = ( + net_displacement / total_distance if total_distance > 0 else 0.0 + ) + + return total_distance, net_displacement, directional_persistence + + def _compute_angular_velocity(self, track_id: str) -> tuple[float, float, float]: + """Compute angular velocity features for a track. + + This function calculates the angular velocity statistics for a given track, + including mean, maximum, and standard deviation of angular velocities. + Angular velocity is computed as the change in angle between consecutive + movement vectors over time. 
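Directional persistence, as computed in `_compute_displacement` above, is the ratio of net (straight-line) displacement to total path length: 1 for perfectly straight motion, approaching 0 for a meandering track. A tiny worked version of the same arithmetic:

```python
import numpy as np

# L-shaped path: 3 px right, then 4 px up
xy = np.array([[0.0, 0.0], [3.0, 0.0], [3.0, 4.0]])

steps = np.diff(xy, axis=0)
total_distance = np.sum(np.linalg.norm(steps, axis=1))  # 3 + 4 = 7
net_displacement = np.linalg.norm(xy[-1] - xy[0])       # 5 (3-4-5 triangle)
print(net_displacement / total_distance)                # ~0.714
```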
+ + Parameters + ---------- + track_id : str + ID of the track to compute angular velocity for + + Returns + ------- + mean_angular_velocity, max_angular_velocity, std_angular_velocity: tuple + Tuple containing: + - mean_angular_velocity + - max_angular_velocity + - std_angular_velocity + """ + track_data = self.tracking_df[ + self.tracking_df["track_id"] == track_id + ].sort_values("t") + + if len(track_data) < 3: # Need at least 3 points to compute angle changes + return 0.0, 0.0, 0.0 + + # Compute vectors between consecutive points + dx = np.diff(track_data["x"].values) + dy = np.diff(track_data["y"].values) + dt = np.diff(track_data["t"].values) + + # Compute angles between consecutive vectors + vectors = np.column_stack([dx, dy]) + angles = np.zeros(len(vectors) - 1) + for i in range(len(vectors) - 1): + v1, v2 = vectors[i], vectors[i + 1] + cos_angle = np.dot(v1, v2) / ( + np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-10 + ) + angles[i] = np.arccos(np.clip(cos_angle, -1.0, 1.0)) + + # Compute angular velocities (change in angle over time) + angular_velocities = angles / (dt[1:] + self._eps) + + return ( + float(np.mean(angular_velocities)), + float(np.max(angular_velocities)), + float(np.std(angular_velocities)), + ) + + def compute_all_features(self, track_id: str) -> pd.DataFrame: + """Compute all dynamic features for a given track. + + This function computes a comprehensive set of dynamic features for a track, + including velocity, displacement, and angular features. These features + characterize the movement patterns and behavior of the tracked cell. + + Parameters + ---------- + track_id : str + ID of the track to compute features for + + Returns + ------- + pd.DataFrame + DataFrame containing all computed features: + - Velocity features: instantaneous, mean, max, min velocities and std + - Displacement features: total distance, net displacement, persistence + - Angular features: mean, max, and std of angular velocities + """ + # Compute velocity features + velocities = self._compute_instantaneous_velocity(track_id) + self.velocity_features = TrackFeatures( + instantaneous_velocity=velocities.tolist(), + mean_velocity=float(np.mean(velocities)), + max_velocity=float(np.max(velocities)), + min_velocity=float(np.min(velocities)), + std_velocity=float(np.std(velocities)), + ) + + # Compute displacement features + total_dist, net_disp, dir_persist = self._compute_displacement(track_id) + self.displacement_features = DisplacementFeatures( + total_distance=total_dist, + net_displacement=net_disp, + directional_persistence=dir_persist, + ) + + # Compute angular features + mean_ang, max_ang, std_ang = self._compute_angular_velocity(track_id) + self.angular_features = AngularFeatures( + mean_angular_velocity=mean_ang, + max_angular_velocity=max_ang, + std_angular_velocity=std_ang, + ) + + return self.to_df() + + def to_df(self) -> pd.DataFrame: + """Convert all features to a pandas DataFrame. + + This function combines all computed features (velocity, displacement, + and angular features) into a single pandas DataFrame. The features + are organized in a flat structure where each column represents a + different feature. 
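The turning angle between consecutive step vectors is recovered from the clipped cosine, as in `_compute_angular_velocity` above; clipping guards against floating-point values that land just outside [-1, 1]. A minimal sketch:

```python
import numpy as np

v1 = np.array([1.0, 0.0])   # step to the right
v2 = np.array([1.0, 1.0])   # step up and to the right: a 45 degree turn

cos_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-10)
angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
print(np.degrees(angle))   # ~45.0
```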
+ + Returns + ------- + pd.DataFrame + DataFrame containing all computed features with the following columns: + - Velocity features + - Displacement features + - Angular features + """ + features_dict = {} + if self.velocity_features: + features_dict.update(self.velocity_features) + if self.displacement_features: + features_dict.update(self.displacement_features) + if self.angular_features: + features_dict.update(self.angular_features) + return pd.DataFrame([features_dict]) diff --git a/viscy/representation/evaluation/visualization.py b/viscy/representation/evaluation/visualization.py new file mode 100644 index 000000000..9d787fe05 --- /dev/null +++ b/viscy/representation/evaluation/visualization.py @@ -0,0 +1,2244 @@ +import atexit +import base64 +import json +import logging +from io import BytesIO +from pathlib import Path + +import dash +import dash.dependencies as dd +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import plotly.graph_objects as go +from dash import dcc, html +from PIL import Image +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +from viscy.data.triplet import TripletDataModule +from viscy.representation.embedding_writer import read_embedding_dataset + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class EmbeddingVisualizationApp: + def __init__( + self, + data_path: str, + tracks_path: str, + features_path: str, + channels_to_display: list[str] | str, + fov_tracks: dict[str, list[int] | str], + z_range: tuple[int, int] = (0, 1), + yx_patch_size: tuple[int, int] = (128, 128), + num_PC_components: int = 3, + cache_path: str | None = None, + num_loading_workers: int = 16, + output_dir: str | None = None, + ) -> None: + """ + Initialize a Dash application for visualizing the DynaCLR embeddings. + + This class provides a visualization tool for visualizing the DynaCLR embeddings into a 2D space (e.g. PCA, UMAP, PHATE). + It allows users to interactively explore and analyze trajectories, visualize clusters, and explore the embedding space. + + Parameters + ---------- + data_path: str + Path to the data directory. + tracks_path: str + Path to the tracks directory. + features_path: str + Path to the features directory. + channels_to_display: list[str] | str + List of channels to display. + fov_tracks: dict[str, list[int] | str] + Dictionary of FOV names and track IDs. + z_range: tuple[int, int] | list[int,int] + Range of z-slices to display. + yx_patch_size: tuple[int, int] | list[int,int] + Size of the yx-patch to display. + num_PC_components: int + Number of PCA components to use. + cache_path: str | None + Path to the cache directory. + num_loading_workers: int + Number of workers to use for loading data. + output_dir: str | None, optional + Directory to save CSV files and other outputs. If None, uses current working directory. + Returns + ------- + None + Initializes the visualization app. 
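A hedged launch sketch for the app. All paths, the channel name, and the FOV key below are placeholders, and the final `run` call is one plausible entry point (recent Dash versions expose `app.run`; the launch method itself is not shown in this diff):

```python
from viscy.representation.evaluation.visualization import EmbeddingVisualizationApp

viz = EmbeddingVisualizationApp(
    data_path="data.zarr",            # placeholder image store
    tracks_path="tracks.zarr",        # placeholder tracking results
    features_path="embeddings.zarr",  # placeholder embedding dataset
    channels_to_display=["Phase3D"],
    fov_tracks={"A/1/0": "all"},      # every track in this FOV
    z_range=(0, 1),
    yx_patch_size=(128, 128),
)
viz.app.run(debug=True)  # serves the Dash UI; exact launch API may vary by Dash version
```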
+ """ + self.data_path = Path(data_path) + self.tracks_path = Path(tracks_path) + self.features_path = Path(features_path) + self.fov_tracks = fov_tracks + self.image_cache = {} + self.cache_path = Path(cache_path) if cache_path else None + self.output_dir = Path(output_dir) if output_dir else Path.cwd() + self.app = None + self.features_df = None + self.fig = None + self.channels_to_display = channels_to_display + self.z_range = z_range + self.yx_patch_size = yx_patch_size + self.filtered_tracks_by_fov = {} + self._z_idx = (self.z_range[1] - self.z_range[0]) // 2 + self.num_PC_components = num_PC_components + self.num_loading_workers = num_loading_workers + # Initialize cluster storage before preparing data and creating figure + self.clusters = [] # List to store all clusters + self.cluster_points = set() # Set to track all points in clusters + self.cluster_names = {} # Dictionary to store cluster names + self.next_cluster_id = 1 # Counter for cluster IDs + # Initialize data + self._prepare_data() + self._create_figure() + self._init_app() + atexit.register(self._cleanup_cache) + + def _prepare_data(self): + """Prepare the feature data and PCA transformation""" + embedding_dataset = read_embedding_dataset(self.features_path) + features = embedding_dataset["features"] + self.features_df = features["sample"].to_dataframe().reset_index(drop=True) + + # Check if UMAP or PHATE columns already exist + existing_dims = [] + dim_options = [] + + # Check for PCA and compute if needed + if not any(col.startswith("PCA") for col in self.features_df.columns): + # PCA transformation + scaled_features = StandardScaler().fit_transform(features.values) + pca = PCA(n_components=self.num_PC_components) + pca_coords = pca.fit_transform(scaled_features) + + # Add PCA coordinates to the features dataframe + for i in range(self.num_PC_components): + self.features_df[f"PCA{i + 1}"] = pca_coords[:, i] + + # Store explained variance for PCA + self.pca_explained_variance = [ + f"PC{i + 1} ({var:.1f}%)" + for i, var in enumerate(pca.explained_variance_ratio_ * 100) + ] + + # Add PCA options + for i, pc_label in enumerate(self.pca_explained_variance): + dim_options.append({"label": pc_label, "value": f"PCA{i + 1}"}) + existing_dims.append(f"PCA{i + 1}") + + # Check for UMAP coordinates + umap_dims = [col for col in self.features_df.columns if col.startswith("UMAP")] + if umap_dims: + for dim in umap_dims: + dim_options.append({"label": dim, "value": dim}) + existing_dims.append(dim) + + # Check for PHATE coordinates + phate_dims = [ + col for col in self.features_df.columns if col.startswith("PHATE") + ] + if phate_dims: + for dim in phate_dims: + dim_options.append({"label": dim, "value": dim}) + existing_dims.append(dim) + + # Store dimension options for dropdowns + self.dim_options = dim_options + + # Set default x and y axes based on available dimensions + self.default_x = existing_dims[0] if existing_dims else "PCA1" + self.default_y = existing_dims[1] if len(existing_dims) > 1 else "PCA2" + + # Process each FOV and its track IDs + all_filtered_features = [] + for fov_name, track_ids in self.fov_tracks.items(): + if track_ids == "all": + fov_tracks = ( + self.features_df[self.features_df["fov_name"] == fov_name][ + "track_id" + ] + .unique() + .tolist() + ) + else: + fov_tracks = track_ids + + self.filtered_tracks_by_fov[fov_name] = fov_tracks + + # Filter features for this FOV and its track IDs + fov_features = self.features_df[ + (self.features_df["fov_name"] == fov_name) + & 
(self.features_df["track_id"].isin(fov_tracks)) + ] + all_filtered_features.append(fov_features) + + # Combine all filtered features + self.filtered_features_df = pd.concat(all_filtered_features, axis=0) + + def _create_figure(self): + """Create the initial scatter plot figure""" + self.fig = self._create_track_colored_figure() + + def _init_app(self): + """Initialize the Dash application""" + self.app = dash.Dash(__name__) + + # Add cluster assignment button next to clear selection + cluster_controls = html.Div( + [ + html.Button( + "Assign to New Cluster", + id="assign-cluster", + style={ + "backgroundColor": "#28a745", + "color": "white", + "border": "none", + "padding": "5px 10px", + "borderRadius": "4px", + "cursor": "pointer", + "marginRight": "10px", + }, + ), + html.Button( + "Clear All Clusters", + id="clear-clusters", + style={ + "backgroundColor": "#dc3545", + "color": "white", + "border": "none", + "padding": "5px 10px", + "borderRadius": "4px", + "cursor": "pointer", + "marginRight": "10px", + }, + ), + html.Button( + "Save Clusters to CSV", + id="save-clusters-csv", + style={ + "backgroundColor": "#17a2b8", + "color": "white", + "border": "none", + "padding": "5px 10px", + "borderRadius": "4px", + "cursor": "pointer", + "marginRight": "10px", + }, + ), + html.Button( + "Clear Selection", + id="clear-selection", + style={ + "backgroundColor": "#6c757d", + "color": "white", + "border": "none", + "padding": "5px 10px", + "borderRadius": "4px", + "cursor": "pointer", + }, + ), + ], + style={"marginLeft": "10px", "display": "inline-block"}, + ) + # Create tabs for different views + tabs = dcc.Tabs( + id="view-tabs", + value="timeline-tab", + children=[ + dcc.Tab( + label="Track Timeline", + value="timeline-tab", + children=[ + html.Div( + id="track-timeline", + style={ + "height": "auto", + "overflowY": "auto", + "maxHeight": "80vh", + "padding": "10px", + "marginTop": "10px", + }, + ), + ], + ), + dcc.Tab( + label="Clusters", + value="clusters-tab", + id="clusters-tab", + children=[ + html.Div( + id="cluster-container", + style={ + "padding": "10px", + "marginTop": "10px", + }, + ), + ], + style={"display": "none"}, # Initially hidden + ), + ], + style={"marginTop": "20px"}, + ) + + # Add modal for cluster naming + cluster_name_modal = html.Div( + id="cluster-name-modal", + children=[ + html.Div( + [ + html.H3("Name Your Cluster", style={"marginBottom": "20px"}), + html.Label("Cluster Name:"), + dcc.Input( + id="cluster-name-input", + type="text", + placeholder="Enter cluster name...", + style={"width": "100%", "marginBottom": "20px"}, + ), + html.Div( + [ + html.Button( + "Save", + id="save-cluster-name", + style={ + "backgroundColor": "#28a745", + "color": "white", + "border": "none", + "padding": "8px 16px", + "borderRadius": "4px", + "cursor": "pointer", + "marginRight": "10px", + }, + ), + html.Button( + "Cancel", + id="cancel-cluster-name", + style={ + "backgroundColor": "#6c757d", + "color": "white", + "border": "none", + "padding": "8px 16px", + "borderRadius": "4px", + "cursor": "pointer", + }, + ), + ], + style={"textAlign": "right"}, + ), + ], + style={ + "backgroundColor": "white", + "padding": "30px", + "borderRadius": "8px", + "maxWidth": "400px", + "margin": "auto", + "boxShadow": "0 4px 6px rgba(0, 0, 0, 0.1)", + "border": "1px solid #ddd", + }, + ) + ], + style={ + "display": "none", + "position": "fixed", + "top": "0", + "left": "0", + "width": "100%", + "height": "100%", + "backgroundColor": "rgba(0, 0, 0, 0.5)", + "zIndex": "1000", + "justifyContent": 
"center", + "alignItems": "center", + }, + ) + + # Update layout to use tabs + self.app.layout = html.Div( + style={ + "maxWidth": "95vw", + "margin": "auto", + "padding": "20px", + }, + children=[ + html.H1( + "Track Visualization", + style={"textAlign": "center", "marginBottom": "20px"}, + ), + html.Div( + [ + html.Div( + style={ + "width": "100%", + "display": "inline-block", + "verticalAlign": "top", + }, + children=[ + html.Div( + style={ + "marginBottom": "20px", + "display": "flex", + "alignItems": "center", + "gap": "20px", + "flexWrap": "wrap", + }, + children=[ + html.Div( + [ + html.Label( + "Color by:", + style={"marginRight": "10px"}, + ), + dcc.Dropdown( + id="color-mode", + options=[ + { + "label": "Track ID", + "value": "track", + }, + { + "label": "Time", + "value": "time", + }, + ], + value="track", + style={"width": "200px"}, + ), + ] + ), + html.Div( + [ + dcc.Checklist( + id="show-arrows", + options=[ + { + "label": "Show arrows", + "value": "show", + } + ], + value=[], + style={"marginLeft": "20px"}, + ), + ] + ), + html.Div( + [ + html.Label( + "X-axis:", + style={"marginRight": "10px"}, + ), + dcc.Dropdown( + id="x-axis", + options=self.dim_options, + value=self.default_x, + style={"width": "200px"}, + ), + ] + ), + html.Div( + [ + html.Label( + "Y-axis:", + style={"marginRight": "10px"}, + ), + dcc.Dropdown( + id="y-axis", + options=self.dim_options, + value=self.default_y, + style={"width": "200px"}, + ), + ] + ), + cluster_controls, + ], + ), + ], + ), + ] + ), + dcc.Loading( + id="loading", + children=[ + dcc.Graph( + id="scatter-plot", + figure=self.fig, + config={ + "displayModeBar": True, + "editable": False, + "showEditInChartStudio": False, + "modeBarButtonsToRemove": [ + "select2d", + "resetScale2d", + ], + "edits": { + "annotationPosition": False, + "annotationTail": False, + "annotationText": False, + "shapePosition": True, + }, + "scrollZoom": True, + }, + style={"height": "50vh"}, + ), + ], + type="default", + ), + tabs, + cluster_name_modal, + ], + ) + + @self.app.callback( + [ + dd.Output("scatter-plot", "figure", allow_duplicate=True), + dd.Output("scatter-plot", "selectedData", allow_duplicate=True), + ], + [ + dd.Input("color-mode", "value"), + dd.Input("show-arrows", "value"), + dd.Input("x-axis", "value"), + dd.Input("y-axis", "value"), + dd.Input("scatter-plot", "relayoutData"), + dd.Input("scatter-plot", "selectedData"), + ], + [dd.State("scatter-plot", "figure")], + prevent_initial_call=True, + ) + def update_figure( + color_mode, + show_arrows, + x_axis, + y_axis, + relayout_data, + selected_data, + current_figure, + ): + show_arrows = len(show_arrows or []) > 0 + + ctx = dash.callback_context + if not ctx.triggered: + triggered_id = "No clicks yet" + else: + triggered_id = ctx.triggered[0]["prop_id"].split(".")[0] + + # Create new figure when necessary + if triggered_id in [ + "color-mode", + "show-arrows", + "x-axis", + "y-axis", + ]: + if color_mode == "track": + fig = self._create_track_colored_figure(show_arrows, x_axis, y_axis) + else: + fig = self._create_time_colored_figure(show_arrows, x_axis, y_axis) + + # Update dragmode and selection settings + fig.update_layout( + dragmode="lasso", + clickmode="event+select", + uirevision="true", + selectdirection="any", + ) + else: + fig = dash.no_update + + return fig, selected_data + + @self.app.callback( + dd.Output("track-timeline", "children"), + [dd.Input("scatter-plot", "clickData")], + prevent_initial_call=True, + ) + def update_track_timeline(clickData): + """Update the track timeline 
based on the clicked point"""
+            if clickData is None:
+                return html.Div("Click on a point to see the track timeline")
+
+            # Parse the hover text to get track_id, time and fov_name
+            hover_text = clickData["points"][0]["text"]
+            track_id = int(hover_text.split("<br>")[0].split(": ")[1])
+            clicked_time = int(hover_text.split("<br>")[1].split(": ")[1])
+            fov_name = hover_text.split("
")[2].split(": ")[1] + + # Get all timepoints for this track + track_data = self.features_df[ + (self.features_df["fov_name"] == fov_name) + & (self.features_df["track_id"] == track_id) + ].sort_values("t") + + if track_data.empty: + return html.Div(f"No data found for track {track_id}") + + # Get unique timepoints + timepoints = track_data["t"].unique() + + # Create a list to store all timepoint columns + timepoint_columns = [] + + # First create the time labels row + time_labels = [] + for t in timepoints: + is_clicked = t == clicked_time + time_style = { + "width": "150px", + "textAlign": "center", + "padding": "5px", + "fontWeight": "bold" if is_clicked else "normal", + "color": "#007bff" if is_clicked else "black", + } + time_labels.append(html.Div(f"t={t}", style=time_style)) + + timepoint_columns.append( + html.Div( + time_labels, + style={ + "display": "flex", + "flexDirection": "row", + "minWidth": "fit-content", + "borderBottom": "2px solid #ddd", + "marginBottom": "10px", + "paddingBottom": "5px", + }, + ) + ) + + # Then create image rows for each channel + for channel in self.channels_to_display: + channel_images = [] + for t in timepoints: + cache_key = (fov_name, track_id, t) + if ( + cache_key in self.image_cache + and channel in self.image_cache[cache_key] + ): + is_clicked = t == clicked_time + image_style = { + "width": "150px", + "height": "150px", + "border": ( + "3px solid #007bff" if is_clicked else "1px solid #ddd" + ), + "borderRadius": "4px", + } + channel_images.append( + html.Div( + html.Img( + src=self.image_cache[cache_key][channel], + style=image_style, + ), + style={ + "width": "150px", + "padding": "5px", + }, + ) + ) + + if channel_images: + # Add channel label + timepoint_columns.append( + html.Div( + [ + html.Div( + channel, + style={ + "width": "100px", + "fontWeight": "bold", + "fontSize": "14px", + "padding": "5px", + "backgroundColor": "#f8f9fa", + "borderRadius": "4px", + "marginBottom": "5px", + "textAlign": "center", + }, + ), + html.Div( + channel_images, + style={ + "display": "flex", + "flexDirection": "row", + "minWidth": "fit-content", + "marginBottom": "15px", + }, + ), + ] + ) + ) + + # Create the main container with synchronized scrolling + return html.Div( + [ + html.H4( + f"Track {track_id} (FOV: {fov_name})", + style={ + "marginBottom": "20px", + "fontSize": "20px", + "fontWeight": "bold", + "color": "#2c3e50", + }, + ), + html.Div( + timepoint_columns, + style={ + "overflowX": "auto", + "overflowY": "hidden", + "whiteSpace": "nowrap", + "backgroundColor": "white", + "padding": "20px", + "borderRadius": "8px", + "boxShadow": "0 2px 4px rgba(0,0,0,0.1)", + "marginBottom": "20px", + }, + ), + ] + ) + + # Add callback to show/hide clusters tab and handle modal + @self.app.callback( + [ + dd.Output("clusters-tab", "style"), + dd.Output("cluster-container", "children"), + dd.Output("view-tabs", "value"), + dd.Output("scatter-plot", "figure", allow_duplicate=True), + dd.Output("cluster-name-modal", "style"), + dd.Output("cluster-name-input", "value"), + dd.Output("scatter-plot", "selectedData", allow_duplicate=True), + ], + [ + dd.Input("assign-cluster", "n_clicks"), + dd.Input("clear-clusters", "n_clicks"), + dd.Input("save-cluster-name", "n_clicks"), + dd.Input("cancel-cluster-name", "n_clicks"), + dd.Input({"type": "edit-cluster-name", "index": dash.ALL}, "n_clicks"), + ], + [ + dd.State("scatter-plot", "selectedData"), + dd.State("scatter-plot", "figure"), + dd.State("color-mode", "value"), + dd.State("show-arrows", "value"), + 
dd.State("x-axis", "value"), + dd.State("y-axis", "value"), + dd.State("cluster-name-input", "value"), + ], + prevent_initial_call=True, + ) + def update_clusters_tab( + assign_clicks, + clear_clicks, + save_name_clicks, + cancel_name_clicks, + edit_name_clicks, + selected_data, + current_figure, + color_mode, + show_arrows, + x_axis, + y_axis, + cluster_name, + ): + ctx = dash.callback_context + if not ctx.triggered: + return ( + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + ) + + button_id = ctx.triggered[0]["prop_id"].split(".")[0] + + # Handle edit cluster name button clicks + if button_id.startswith('{"type":"edit-cluster-name"'): + try: + id_dict = json.loads(button_id) + cluster_idx = id_dict["index"] + + # Get current cluster name + current_name = self.cluster_names.get( + cluster_idx, f"Cluster {cluster_idx + 1}" + ) + + # Show modal + modal_style = { + "display": "flex", + "position": "fixed", + "top": "0", + "left": "0", + "width": "100%", + "height": "100%", + "backgroundColor": "rgba(0, 0, 0, 0.5)", + "zIndex": "1000", + "justifyContent": "center", + "alignItems": "center", + } + + return ( + {"display": "block"}, + self._get_cluster_images(), + "clusters-tab", + dash.no_update, + modal_style, + current_name, + dash.no_update, # Don't change selection + ) + except Exception: + return ( + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + ) + + if ( + button_id == "assign-cluster" + and selected_data + and selected_data.get("points") + ): + # Create new cluster from selected points + new_cluster = [] + for point in selected_data["points"]: + text = point["text"] + lines = text.split("
") + track_id = int(lines[0].split(": ")[1]) + t = int(lines[1].split(": ")[1]) + fov = lines[2].split(": ")[1] + + cache_key = (fov, track_id, t) + if cache_key in self.image_cache: + new_cluster.append( + { + "track_id": track_id, + "t": t, + "fov_name": fov, + } + ) + self.cluster_points.add(cache_key) + + if new_cluster: + # Add cluster to list but don't assign name yet + self.clusters.append(new_cluster) + # Open modal for naming + modal_style = { + "display": "flex", + "position": "fixed", + "top": "0", + "left": "0", + "width": "100%", + "height": "100%", + "backgroundColor": "rgba(0, 0, 0, 0.5)", + "zIndex": "1000", + "justifyContent": "center", + "alignItems": "center", + } + return ( + {"display": "block"}, + self._get_cluster_images(), + "clusters-tab", + dash.no_update, # Don't update figure yet + modal_style, # Show modal + "", # Clear input + None, # Clear selection + ) + + elif button_id == "save-cluster-name" and cluster_name: + # Assign name to the most recently created cluster + if self.clusters: + cluster_id = len(self.clusters) - 1 + self.cluster_names[cluster_id] = cluster_name.strip() + + # Create new figure with updated colors + fig = self._create_track_colored_figure( + len(show_arrows or []) > 0, + x_axis, + y_axis, + ) + # Ensure the dragmode is set based on selection_mode + fig.update_layout( + dragmode="lasso", + clickmode="event+select", + uirevision="true", # Keep the UI state + selectdirection="any", + ) + modal_style = {"display": "none"} + return ( + {"display": "block"}, + self._get_cluster_images(), + "clusters-tab", + fig, + modal_style, # Hide modal + "", # Clear input + None, # Clear selection + ) + + elif button_id == "cancel-cluster-name": + # Remove the cluster that was just created + if self.clusters: + # Remove points from cluster_points set + for point in self.clusters[-1]: + cache_key = (point["fov_name"], point["track_id"], point["t"]) + self.cluster_points.discard(cache_key) + # Remove the cluster + self.clusters.pop() + + # Create new figure with updated colors + fig = self._create_track_colored_figure( + len(show_arrows or []) > 0, + x_axis, + y_axis, + ) + # Ensure the dragmode is set based on selection_mode + fig.update_layout( + dragmode="lasso", + clickmode="event+select", + uirevision="true", # Keep the UI state + selectdirection="any", + ) + modal_style = {"display": "none"} + return ( + ( + {"display": "none"} + if not self.clusters + else {"display": "block"} + ), + self._get_cluster_images() if self.clusters else None, + "timeline-tab" if not self.clusters else "clusters-tab", + fig, + modal_style, # Hide modal + "", # Clear input + None, # Clear selection + ) + + elif button_id == "clear-clusters": + self.clusters = [] + self.cluster_points.clear() + self.cluster_names.clear() + # Restore original coloring + fig = self._create_track_colored_figure( + len(show_arrows or []) > 0, + x_axis, + y_axis, + ) + # Reset UI state completely to ensure clean slate + fig.update_layout( + dragmode="lasso", + clickmode="event+select", + uirevision=None, # Reset UI state completely + selectdirection="any", + ) + modal_style = {"display": "none"} + return ( + {"display": "none"}, + None, + "timeline-tab", + fig, + modal_style, + "", + None, + ) # Clear selection + + return ( + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + dash.no_update, + ) + + # Add callback for saving clusters to CSV + @self.app.callback( + dd.Output("cluster-container", "children", allow_duplicate=True), + 
[dd.Input("save-clusters-csv", "n_clicks")], + prevent_initial_call=True, + ) + def save_clusters_csv(n_clicks): + """Callback to save clusters to CSV file""" + if n_clicks and self.clusters: + try: + output_path = self.save_clusters_to_csv() + return html.Div( + [ + html.H3("Clusters", style={"marginBottom": "20px"}), + html.Div( + f"✅ Successfully saved {len(self.clusters)} clusters to: {output_path}", + style={ + "backgroundColor": "#d4edda", + "color": "#155724", + "padding": "10px", + "borderRadius": "4px", + "marginBottom": "20px", + "border": "1px solid #c3e6cb", + }, + ), + self._get_cluster_images(), + ] + ) + except Exception as e: + return html.Div( + [ + html.H3("Clusters", style={"marginBottom": "20px"}), + html.Div( + f"❌ Error saving clusters: {str(e)}", + style={ + "backgroundColor": "#f8d7da", + "color": "#721c24", + "padding": "10px", + "borderRadius": "4px", + "marginBottom": "20px", + "border": "1px solid #f5c6cb", + }, + ), + self._get_cluster_images(), + ] + ) + elif n_clicks and not self.clusters: + return html.Div( + [ + html.H3("Clusters", style={"marginBottom": "20px"}), + html.Div( + "⚠️ No clusters to save. Create clusters first by selecting points and clicking 'Assign to New Cluster'.", + style={ + "backgroundColor": "#fff3cd", + "color": "#856404", + "padding": "10px", + "borderRadius": "4px", + "marginBottom": "20px", + "border": "1px solid #ffeaa7", + }, + ), + ] + ) + return dash.no_update + + @self.app.callback( + [ + dd.Output("scatter-plot", "figure", allow_duplicate=True), + dd.Output("scatter-plot", "selectedData", allow_duplicate=True), + ], + [dd.Input("clear-selection", "n_clicks")], + [ + dd.State("color-mode", "value"), + dd.State("show-arrows", "value"), + dd.State("x-axis", "value"), + dd.State("y-axis", "value"), + ], + prevent_initial_call=True, + ) + def clear_selection(n_clicks, color_mode, show_arrows, x_axis, y_axis): + """Callback to clear the selection and restore original opacity""" + if n_clicks: + # Create a new figure with no selections + if color_mode == "track": + fig = self._create_track_colored_figure( + len(show_arrows or []) > 0, + x_axis, + y_axis, + ) + else: + fig = self._create_time_colored_figure( + len(show_arrows or []) > 0, + x_axis, + y_axis, + ) + + # Update layout to maintain lasso mode but clear selections + fig.update_layout( + dragmode="lasso", + clickmode="event+select", + uirevision=None, # Reset UI state + selectdirection="any", + ) + + return fig, None # Return new figure and clear selectedData + return dash.no_update, dash.no_update + + def _calculate_equal_aspect_ranges(self, x_data, y_data): + """Calculate ranges for x and y axes to ensure equal aspect ratio. 
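The equal-aspect logic pads each axis range by 5% and then widens the narrower axis around its center until both spans match. A worked sketch of the same arithmetic:

```python
import numpy as np

x = np.array([0.0, 10.0])   # x spans 10 units
y = np.array([4.0, 6.0])    # y spans only 2 units

x_min, x_max = x.min() - 0.05 * np.ptp(x), x.max() + 0.05 * np.ptp(x)
y_min, y_max = y.min() - 0.05 * np.ptp(y), y.max() + 0.05 * np.ptp(y)

x_span, y_span = x_max - x_min, y_max - y_min
if x_span > y_span:   # widen y around its center to match x
    y_center = (y_max + y_min) / 2
    y_min, y_max = y_center - x_span / 2, y_center + x_span / 2
else:                 # widen x around its center to match y
    x_center = (x_max + x_min) / 2
    x_min, x_max = x_center - y_span / 2, x_center + y_span / 2

print((x_min, x_max), (y_min, y_max))   # both spans are now 11.0
```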
+ + Parameters + ---------- + x_data : array-like + Data for x-axis + y_data : array-like + Data for y-axis + + Returns + ------- + tuple + (x_range, y_range) as tuples of (min, max) with equal scaling + """ + # Get data ranges + x_min, x_max = np.min(x_data), np.max(x_data) + y_min, y_max = np.min(y_data), np.max(y_data) + + # Add padding (5% on each side) + x_padding = 0.05 * (x_max - x_min) + y_padding = 0.05 * (y_max - y_min) + + x_min -= x_padding + x_max += x_padding + y_min -= y_padding + y_max += y_padding + + # Ensure equal scaling by using the larger range + x_range = x_max - x_min + y_range = y_max - y_min + + if x_range > y_range: + # Expand y-range to match x-range aspect ratio + y_center = (y_max + y_min) / 2 + y_min = y_center - x_range / 2 + y_max = y_center + x_range / 2 + else: + # Expand x-range to match y-range aspect ratio + x_center = (x_max + x_min) / 2 + x_min = x_center - y_range / 2 + x_max = x_center + y_range / 2 + + return (x_min, x_max), (y_min, y_max) + + def _create_track_colored_figure( + self, + show_arrows=False, + x_axis=None, + y_axis=None, + ): + """Create scatter plot with track-based coloring""" + x_axis = x_axis or self.default_x + y_axis = y_axis or self.default_y + + unique_tracks = self.filtered_features_df["track_id"].unique() + cmap = plt.cm.tab20 + track_colors = { + track_id: f"rgb{tuple(int(x * 255) for x in cmap(i % 20)[:3])}" + for i, track_id in enumerate(unique_tracks) + } + + fig = go.Figure() + + # Set initial layout with lasso mode + fig.update_layout( + dragmode="lasso", + clickmode="event+select", + selectdirection="any", + plot_bgcolor="white", + title="PCA visualization of Selected Tracks", + xaxis_title=x_axis, + yaxis_title=y_axis, + uirevision=True, + hovermode="closest", + showlegend=True, + legend=dict( + yanchor="top", + y=1, + xanchor="left", + x=1.02, + title="Tracks", + bordercolor="Black", + borderwidth=1, + ), + margin=dict(l=50, r=150, t=50, b=50), + autosize=True, + ) + fig.update_xaxes(showgrid=False) + fig.update_yaxes(showgrid=False) + + # Add background points with hover info (excluding the colored tracks) + background_df = self.features_df[ + (self.features_df["fov_name"].isin(self.fov_tracks.keys())) + & (~self.features_df["track_id"].isin(unique_tracks)) + ] + + if not background_df.empty: + # Subsample background points if there are too many + if len(background_df) > 5000: # Adjust this threshold as needed + background_df = background_df.sample(n=5000, random_state=42) + + fig.add_trace( + go.Scattergl( + x=background_df[x_axis], + y=background_df[y_axis], + mode="markers", + marker=dict(size=12, color="lightgray", opacity=0.3), + name=f"Other tracks (showing {len(background_df)} of {len(self.features_df)} points)", + text=[ + f"Track: {track_id}
<br>Time: {t}<br>FOV: {fov}"
+                        for track_id, t, fov in zip(
+                            background_df["track_id"],
+                            background_df["t"],
+                            background_df["fov_name"],
+                        )
+                    ],
+                    hoverinfo="text",
+                    showlegend=True,
+                    hoverlabel=dict(namelength=-1),
+                )
+            )
+
+        # Add points for each selected track
+        for track_id in unique_tracks:
+            track_data = self.filtered_features_df[
+                self.filtered_features_df["track_id"] == track_id
+            ].sort_values("t")
+
+            # Get points for this track that are in clusters
+            track_points = list(
+                zip(
+                    [fov for fov in track_data["fov_name"]],
+                    [track_id] * len(track_data),
+                    [t for t in track_data["t"]],
+                )
+            )
+
+            # Determine colors based on cluster membership
+            colors = []
+            opacities = []
+            if self.clusters:
+                cluster_colors = [
+                    f"rgb{tuple(int(x * 255) for x in plt.cm.Set2(i % 8)[:3])}"
+                    for i in range(len(self.clusters))
+                ]
+                point_to_cluster = {}
+                for cluster_idx, cluster in enumerate(self.clusters):
+                    for point in cluster:
+                        point_key = (point["fov_name"], point["track_id"], point["t"])
+                        point_to_cluster[point_key] = cluster_idx
+
+                for point in track_points:
+                    if point in point_to_cluster:
+                        colors.append(cluster_colors[point_to_cluster[point]])
+                        opacities.append(1.0)
+                    else:
+                        colors.append("lightgray")
+                        opacities.append(0.3)
+            else:
+                colors = [track_colors[track_id]] * len(track_data)
+                opacities = [1.0] * len(track_data)
+
+            # Add points using Scattergl for better performance
+            scatter_kwargs = {
+                "x": track_data[x_axis],
+                "y": track_data[y_axis],
+                "mode": "markers",
+                "marker": dict(
+                    size=10,  # Reduced size
+                    color=colors,
+                    line=dict(width=1, color="black"),
+                    opacity=opacities,
+                ),
+                "name": f"Track {track_id}",
+                "text": [
+                    f"Track: {track_id}<br>Time: {t}<br>
FOV: {fov}" + for t, fov in zip(track_data["t"], track_data["fov_name"]) + ], + "hoverinfo": "text", + "hoverlabel": dict(namelength=-1), # Show full text in hover + } + + # Only apply selection properties if there are clusters + # This prevents opacity conflicts when no clusters exist + if self.clusters: + scatter_kwargs.update( + { + "unselected": dict(marker=dict(opacity=0.3, size=10)), + "selected": dict(marker=dict(size=12, opacity=1.0)), + } + ) + + fig.add_trace(go.Scattergl(**scatter_kwargs)) + + # Add trajectory lines and arrows if requested + if show_arrows and len(track_data) > 1: + x_coords = track_data[x_axis].values + y_coords = track_data[y_axis].values + + # Add dashed lines for the trajectory using Scattergl + fig.add_trace( + go.Scattergl( + x=x_coords, + y=y_coords, + mode="lines", + line=dict( + color=track_colors[track_id], + width=1, + dash="dot", + ), + showlegend=False, + hoverinfo="skip", + ) + ) + + # Add arrows at regular intervals (reduced frequency) + arrow_interval = max( + 1, len(track_data) // 3 + ) # Reduced number of arrows + for i in range(0, len(track_data) - 1, arrow_interval): + # Calculate arrow angle + dx = x_coords[i + 1] - x_coords[i] + dy = y_coords[i + 1] - y_coords[i] + + # Only add arrow if there's significant movement + if dx * dx + dy * dy > 1e-6: # Minimum distance threshold + # Add arrow annotation + fig.add_annotation( + x=x_coords[i + 1], + y=y_coords[i + 1], + ax=x_coords[i], + ay=y_coords[i], + xref="x", + yref="y", + axref="x", + ayref="y", + showarrow=True, + arrowhead=2, + arrowsize=1, # Reduced size + arrowwidth=1, # Reduced width + arrowcolor=track_colors[track_id], + opacity=0.8, + ) + + # Compute axis ranges to ensure equal aspect ratio + all_x_data = self.filtered_features_df[x_axis] + all_y_data = self.filtered_features_df[y_axis] + + if not all_x_data.empty and not all_y_data.empty: + x_range, y_range = self._calculate_equal_aspect_ranges( + all_x_data, all_y_data + ) + + # Set equal aspect ratio and range + fig.update_layout( + xaxis=dict( + range=x_range, scaleanchor="y", scaleratio=1, constrain="domain" + ), + yaxis=dict(range=y_range, constrain="domain"), + ) + + return fig + + def _create_time_colored_figure( + self, + show_arrows=False, + x_axis=None, + y_axis=None, + ): + """Create scatter plot with time-based coloring""" + x_axis = x_axis or self.default_x + y_axis = y_axis or self.default_y + + fig = go.Figure() + + # Set initial layout with lasso mode + fig.update_layout( + dragmode="lasso", + clickmode="event+select", + selectdirection="any", + plot_bgcolor="white", + title="PCA visualization of Selected Tracks", + xaxis_title=x_axis, + yaxis_title=y_axis, + uirevision=True, + hovermode="closest", + showlegend=True, + legend=dict( + yanchor="top", + y=1, + xanchor="left", + x=1.02, + title="Tracks", + bordercolor="Black", + borderwidth=1, + ), + margin=dict(l=50, r=150, t=50, b=50), + autosize=True, + ) + fig.update_xaxes(showgrid=False) + fig.update_yaxes(showgrid=False) + + # Add background points with hover info + all_tracks_df = self.features_df[ + self.features_df["fov_name"].isin(self.fov_tracks.keys()) + ] + + # Subsample background points if there are too many + if len(all_tracks_df) > 5000: # Adjust this threshold as needed + all_tracks_df = all_tracks_df.sample(n=5000, random_state=42) + + fig.add_trace( + go.Scattergl( + x=all_tracks_df[x_axis], + y=all_tracks_df[y_axis], + mode="markers", + marker=dict(size=12, color="lightgray", opacity=0.3), + name=f"Other points (showing {len(all_tracks_df)} of 
{len(self.features_df)} points)", + text=[ + f"Track: {track_id}
<br>Time: {t}<br>FOV: {fov}"
+                    for track_id, t, fov in zip(
+                        all_tracks_df["track_id"],
+                        all_tracks_df["t"],
+                        all_tracks_df["fov_name"],
+                    )
+                ],
+                hoverinfo="text",
+                hoverlabel=dict(namelength=-1),
+            )
+        )
+
+        # Add time-colored points using Scattergl
+        fig.add_trace(
+            go.Scattergl(
+                x=self.filtered_features_df[x_axis],
+                y=self.filtered_features_df[y_axis],
+                mode="markers",
+                marker=dict(
+                    size=10,  # Reduced size
+                    color=self.filtered_features_df["t"],
+                    colorscale="Viridis",
+                    colorbar=dict(title="Time"),
+                ),
+                text=[
+                    f"Track: {track_id}<br>Time: {t}<br>
FOV: {fov}" + for track_id, t, fov in zip( + self.filtered_features_df["track_id"], + self.filtered_features_df["t"], + self.filtered_features_df["fov_name"], + ) + ], + hoverinfo="text", + showlegend=False, + hoverlabel=dict(namelength=-1), # Show full text in hover + ) + ) + + # Add arrows if requested, but more efficiently + if show_arrows: + for track_id in self.filtered_features_df["track_id"].unique(): + track_data = self.filtered_features_df[ + self.filtered_features_df["track_id"] == track_id + ].sort_values("t") + + if len(track_data) > 1: + # Calculate distances between consecutive points + x_coords = track_data[x_axis].values + y_coords = track_data[y_axis].values + distances = np.sqrt(np.diff(x_coords) ** 2 + np.diff(y_coords) ** 2) + + # Only show arrows for movements larger than the median distance + threshold = np.median(distances) * 0.5 + + # Add arrows as a single trace + arrow_x = [] + arrow_y = [] + + for i in range(len(track_data) - 1): + if distances[i] > threshold: + arrow_x.extend([x_coords[i], x_coords[i + 1], None]) + arrow_y.extend([y_coords[i], y_coords[i + 1], None]) + + if arrow_x: # Only add if there are arrows to show + fig.add_trace( + go.Scatter( + x=arrow_x, + y=arrow_y, + mode="lines", + line=dict( + color="rgba(128, 128, 128, 0.5)", + width=1, + dash="dot", + ), + showlegend=False, + hoverinfo="skip", + ) + ) + + # Compute axis ranges to ensure equal aspect ratio + all_x_data = self.filtered_features_df[x_axis] + all_y_data = self.filtered_features_df[y_axis] + if not all_x_data.empty and not all_y_data.empty: + x_range, y_range = self._calculate_equal_aspect_ranges( + all_x_data, all_y_data + ) + + # Set equal aspect ratio and range + fig.update_layout( + xaxis=dict( + range=x_range, scaleanchor="y", scaleratio=1, constrain="domain" + ), + yaxis=dict(range=y_range, constrain="domain"), + ) + + return fig + + @staticmethod + def _normalize_image(img_array): + """Normalize a single image array to [0, 255] more efficiently""" + min_val = img_array.min() + max_val = img_array.max() + if min_val == max_val: + return np.zeros_like(img_array, dtype=np.uint8) + # Normalize in one step + return ((img_array - min_val) * 255 / (max_val - min_val)).astype(np.uint8) + + @staticmethod + def _numpy_to_base64(img_array): + """Convert numpy array to base64 string with compression""" + if not isinstance(img_array, np.uint8): + img_array = img_array.astype(np.uint8) + img = Image.fromarray(img_array) + buffered = BytesIO() + # Use JPEG format with quality=85 for better compression + img.save(buffered, format="JPEG", quality=85, optimize=True) + return "data:image/jpeg;base64," + base64.b64encode(buffered.getvalue()).decode( + "utf-8" + ) + + def save_cache(self, cache_path: str | None = None): + """Save the image cache to disk using pickle. + + Parameters + ---------- + cache_path : str | None, optional + Path to save the cache. 
If None, uses self.cache_path, by default None + """ + import pickle + + if cache_path is None: + if self.cache_path is None: + logger.warning("No cache path specified, skipping cache save") + return + cache_path = self.cache_path + else: + cache_path = Path(cache_path) + + # Create parent directory if it doesn't exist + cache_path.parent.mkdir(parents=True, exist_ok=True) + + # Save cache metadata for validation + cache_metadata = { + "data_path": str(self.data_path), + "tracks_path": str(self.tracks_path), + "features_path": str(self.features_path), + "channels": self.channels_to_display, + "z_range": self.z_range, + "yx_patch_size": self.yx_patch_size, + "cache_size": len(self.image_cache), + } + + try: + logger.info(f"Saving image cache to {cache_path}") + with open(cache_path, "wb") as f: + pickle.dump((cache_metadata, self.image_cache), f) + logger.info(f"Successfully saved cache with {len(self.image_cache)} images") + except Exception as e: + logger.error(f"Error saving cache: {e}") + + def load_cache(self, cache_path: str | None = None) -> bool: + """Load the image cache from disk using pickle. + + Parameters + ---------- + cache_path : str | None, optional + Path to load the cache from. If None, uses self.cache_path, by default None + + Returns + ------- + bool + True if cache was successfully loaded, False otherwise + """ + import pickle + + if cache_path is None: + if self.cache_path is None: + logger.warning("No cache path specified, skipping cache load") + return False + cache_path = self.cache_path + else: + cache_path = Path(cache_path) + + if not cache_path.exists(): + logger.warning(f"Cache file {cache_path} does not exist") + return False + + try: + logger.info(f"Loading image cache from {cache_path}") + with open(cache_path, "rb") as f: + cache_metadata, loaded_cache = pickle.load(f) + + # Validate cache metadata + if ( + cache_metadata["data_path"] != str(self.data_path) + or cache_metadata["tracks_path"] != str(self.tracks_path) + or cache_metadata["features_path"] != str(self.features_path) + or cache_metadata["channels"] != self.channels_to_display + or cache_metadata["z_range"] != self.z_range + or cache_metadata["yx_patch_size"] != self.yx_patch_size + ): + logger.warning("Cache metadata mismatch, skipping cache load") + return False + + self.image_cache = loaded_cache + logger.info( + f"Successfully loaded cache with {len(self.image_cache)} images" + ) + return True + except Exception as e: + logger.error(f"Error loading cache: {e}") + return False + + def preload_images(self): + """Preload all images into memory""" + # Try to load from cache first + if self.cache_path and self.load_cache(): + return + + logger.info("Preloading images into cache...") + logger.info(f"FOVs to process: {list(self.filtered_tracks_by_fov.keys())}") + + # Process each FOV and its tracks + for fov_name, track_ids in self.filtered_tracks_by_fov.items(): + if not track_ids: # Skip FOVs with no tracks + logger.info(f"Skipping FOV {fov_name} as it has no tracks") + continue + + logger.info(f"Processing FOV {fov_name} with tracks {track_ids}") + + try: + data_module = TripletDataModule( + data_path=self.data_path, + tracks_path=self.tracks_path, + include_fov_names=[fov_name] * len(track_ids), + include_track_ids=track_ids, + source_channel=self.channels_to_display, + z_range=self.z_range, + initial_yx_patch_size=self.yx_patch_size, + final_yx_patch_size=self.yx_patch_size, + batch_size=1, + num_workers=self.num_loading_workers, + normalizations=None, + predict_cells=True, + ) + 
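+                # The data module is configured for inference only: each FOV is
+                # paired with its own track list, batch_size=1 yields one
+                # (fov_name, track_id, t) patch per batch, and
+                # normalizations=None keeps raw intensities so that
+                # _normalize_image can rescale each patch for display below.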
data_module.setup("predict") + + for batch in data_module.predict_dataloader(): + try: + images = batch["anchor"].numpy() + indices = batch["index"] + track_id = indices["track_id"].tolist() + t = indices["t"].tolist() + + img = np.stack(images) + cache_key = (fov_name, track_id[0], t[0]) + + logger.debug(f"Processing cache key: {cache_key}") + + # Process each channel based on its type + processed_channels = {} + for idx, channel in enumerate(self.channels_to_display): + try: + if channel in ["Phase3D", "DIC", "BF"]: + # For phase contrast, use the middle z-slice + z_idx = (self.z_range[1] - self.z_range[0]) // 2 + processed = self._normalize_image( + img[0, idx, z_idx] + ) + else: + # For fluorescence, use max projection + processed = self._normalize_image( + np.max(img[0, idx], axis=0) + ) + + processed_channels[channel] = self._numpy_to_base64( + processed + ) + logger.debug( + f"Successfully processed channel {channel} for {cache_key}" + ) + except Exception as e: + logger.error( + f"Error processing channel {channel} for {cache_key}: {e}" + ) + continue + + if ( + processed_channels + ): # Only store if at least one channel was processed + self.image_cache[cache_key] = processed_channels + + except Exception as e: + logger.error( + f"Error processing batch for {fov_name}, track {track_id}: {e}" + ) + continue + + except Exception as e: + logger.error(f"Error setting up data module for FOV {fov_name}: {e}") + continue + + logger.info(f"Successfully cached {len(self.image_cache)} images") + # Log some statistics about the cache + cached_fovs = set(key[0] for key in self.image_cache.keys()) + cached_tracks = set((key[0], key[1]) for key in self.image_cache.keys()) + logger.info(f"Cached FOVs: {cached_fovs}") + logger.info(f"Number of unique track-FOV combinations: {len(cached_tracks)}") + + # Save cache if path is specified + if self.cache_path: + self.save_cache() + + def _cleanup_cache(self): + """Clear the image cache when the program exits""" + logging.info("Cleaning up image cache...") + self.image_cache.clear() + + def _get_trajectory_images_lasso(self, x_axis, y_axis, selected_data): + """Get images of points selected by lasso""" + if not selected_data or not selected_data.get("points"): + return html.Div("Use the lasso tool to select points") + + # Dictionary to store points for each lasso selection + lasso_clusters = {} + + # Track which points we've seen to avoid duplicates within clusters + seen_points = set() + + # Process each selected point + for point in selected_data["points"]: + text = point["text"] + lines = text.split("
") + track_id = int(lines[0].split(": ")[1]) + t = int(lines[1].split(": ")[1]) + fov = lines[2].split(": ")[1] + + point_id = (track_id, t, fov) + cache_key = (fov, track_id, t) + + # Skip if we don't have the image in cache + if cache_key not in self.image_cache: + logger.debug(f"Skipping point {point_id} as it's not in the cache") + continue + + # Determine which curve (lasso selection) this point belongs to + curve_number = point.get("curveNumber", 0) + if curve_number not in lasso_clusters: + lasso_clusters[curve_number] = [] + + # Only add if we haven't seen this point in this cluster + cluster_point_id = (curve_number, point_id) + if cluster_point_id not in seen_points: + seen_points.add(cluster_point_id) + lasso_clusters[curve_number].append( + { + "track_id": track_id, + "t": t, + "fov_name": fov, + x_axis: point["x"], + y_axis: point["y"], + } + ) + + if not lasso_clusters: + return html.Div("No cached images found for the selected points") + + # Create sections for each lasso selection + cluster_sections = [] + for cluster_idx, points in lasso_clusters.items(): + cluster_df = pd.DataFrame(points) + + # Create channel rows for this cluster + channel_rows = [] + for channel in self.channels_to_display: + images = [] + for _, row in cluster_df.iterrows(): + cache_key = (row["fov_name"], row["track_id"], row["t"]) + images.append( + html.Div( + [ + html.Img( + src=self.image_cache[cache_key][channel], + style={ + "width": "150px", + "height": "150px", + "margin": "5px", + "border": "1px solid #ddd", + }, + ), + html.Div( + f"Track {row['track_id']}, t={row['t']}", + style={ + "textAlign": "center", + "fontSize": "12px", + }, + ), + ], + style={ + "display": "inline-block", + "margin": "5px", + "verticalAlign": "top", + }, + ) + ) + + if images: # Only add row if there are images + channel_rows.extend( + [ + html.H5( + f"{channel}", + style={ + "margin": "10px 5px", + "fontSize": "16px", + "fontWeight": "bold", + }, + ), + html.Div( + images, + style={ + "overflowX": "auto", + "whiteSpace": "nowrap", + "padding": "10px", + "border": "1px solid #ddd", + "borderRadius": "5px", + "marginBottom": "20px", + "backgroundColor": "#f8f9fa", + }, + ), + ] + ) + + if channel_rows: # Only add cluster section if it has images + cluster_sections.append( + html.Div( + [ + html.H3( + f"Lasso Selection {cluster_idx + 1}", + style={ + "marginTop": "30px", + "marginBottom": "15px", + "fontSize": "24px", + "fontWeight": "bold", + "borderBottom": "2px solid #007bff", + "paddingBottom": "5px", + }, + ), + html.Div( + channel_rows, + style={ + "backgroundColor": "#ffffff", + "padding": "15px", + "borderRadius": "8px", + "boxShadow": "0 2px 4px rgba(0,0,0,0.1)", + }, + ), + ] + ) + ) + + return html.Div( + [ + html.H2( + f"Selected Points ({len(cluster_sections)} selections)", + style={ + "marginBottom": "20px", + "fontSize": "28px", + "fontWeight": "bold", + "color": "#2c3e50", + }, + ), + html.Div(cluster_sections), + ] + ) + + def _get_output_info_display(self) -> html.Div: + """ + Create a display component showing the output directory information. 
+ + Returns + ------- + html.Div + HTML component displaying output directory info + """ + return html.Div( + [ + html.H4( + "Output Directory", + style={"marginBottom": "10px", "fontSize": "16px"}, + ), + html.Div( + [ + html.Span("📁 ", style={"fontSize": "14px"}), + html.Span( + str(self.output_dir), + style={ + "fontFamily": "monospace", + "backgroundColor": "#f8f9fa", + "padding": "4px 8px", + "borderRadius": "4px", + "border": "1px solid #dee2e6", + "fontSize": "12px", + }, + ), + ], + style={"marginBottom": "10px"}, + ), + html.Div( + "CSV files will be saved to this directory with timestamped names.", + style={ + "fontSize": "12px", + "color": "#6c757d", + "fontStyle": "italic", + }, + ), + ], + style={ + "backgroundColor": "#e9ecef", + "padding": "10px", + "borderRadius": "6px", + "marginBottom": "15px", + "border": "1px solid #ced4da", + }, + ) + + def _get_cluster_images(self): + """Display images for all clusters in a grid layout""" + if not self.clusters: + return html.Div( + [self._get_output_info_display(), html.Div("No clusters created yet")] + ) + + # Create cluster colors once + cluster_colors = [ + f"rgb{tuple(int(x * 255) for x in plt.cm.Set2(i % 8)[:3])}" + for i in range(len(self.clusters)) + ] + + # Create individual cluster panels + cluster_panels = [] + for cluster_idx, cluster_points in enumerate(self.clusters): + # Get cluster name or use default + cluster_name = self.cluster_names.get( + cluster_idx, f"Cluster {cluster_idx + 1}" + ) + + # Create a single scrollable container for all channels + all_channel_images = [] + for channel in self.channels_to_display: + images = [] + for point in cluster_points: + cache_key = (point["fov_name"], point["track_id"], point["t"]) + + images.append( + html.Div( + [ + html.Img( + src=self.image_cache[cache_key][channel], + style={ + "width": "100px", + "height": "100px", + "margin": "2px", + "border": f"2px solid {cluster_colors[cluster_idx]}", + "borderRadius": "4px", + }, + ), + html.Div( + f"Track {point['track_id']}, t={point['t']}", + style={ + "textAlign": "center", + "fontSize": "10px", + }, + ), + ], + style={ + "display": "inline-block", + "margin": "2px", + "verticalAlign": "top", + }, + ) + ) + + if images: + all_channel_images.extend( + [ + html.H6( + f"{channel}", + style={ + "margin": "5px", + "fontSize": "12px", + "fontWeight": "bold", + "position": "sticky", + "left": "0", + "backgroundColor": "#f8f9fa", + "zIndex": "1", + "paddingLeft": "5px", + }, + ), + html.Div( + images, + style={ + "whiteSpace": "nowrap", + "marginBottom": "10px", + }, + ), + ] + ) + + if all_channel_images: + # Create a panel for this cluster with synchronized scrolling + cluster_panels.append( + html.Div( + [ + html.Div( + [ + html.Span( + cluster_name, + style={ + "color": cluster_colors[cluster_idx], + "fontWeight": "bold", + "fontSize": "16px", + }, + ), + html.Span( + f" ({len(cluster_points)} points)", + style={ + "color": "#2c3e50", + "fontSize": "14px", + }, + ), + html.Button( + "✏️", + id={ + "type": "edit-cluster-name", + "index": cluster_idx, + }, + style={ + "backgroundColor": "transparent", + "border": "none", + "cursor": "pointer", + "fontSize": "12px", + "marginLeft": "5px", + "color": "#6c757d", + }, + title="Edit cluster name", + ), + ], + style={ + "marginBottom": "10px", + "borderBottom": f"2px solid {cluster_colors[cluster_idx]}", + "paddingBottom": "5px", + "position": "sticky", + "top": "0", + "backgroundColor": "white", + "zIndex": "1", + }, + ), + html.Div( + all_channel_images, + style={ + "overflowX": "auto", + 
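+                                # fixed-height panel body scrolls vertically;
+                                # wide channel strips scroll horizontally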
"overflowY": "auto", + "height": "400px", + "backgroundColor": "#ffffff", + "padding": "10px", + "borderRadius": "8px", + "boxShadow": "0 2px 4px rgba(0,0,0,0.1)", + }, + ), + ], + style={ + "width": "24%", + "display": "inline-block", + "verticalAlign": "top", + "padding": "5px", + "boxSizing": "border-box", + }, + ) + ) + + # Create rows of 4 panels each + rows = [] + for i in range(0, len(cluster_panels), 4): + row = html.Div( + cluster_panels[i : i + 4], + style={ + "display": "flex", + "justifyContent": "flex-start", + "gap": "10px", + "marginBottom": "10px", + }, + ) + rows.append(row) + + return html.Div( + [ + html.H2( + [ + "Clusters ", + html.Span( + f"({len(self.clusters)} total)", + style={"color": "#666"}, + ), + ], + style={ + "marginBottom": "20px", + "fontSize": "28px", + "fontWeight": "bold", + "color": "#2c3e50", + }, + ), + self._get_output_info_display(), + html.Div( + rows, + style={ + "maxHeight": "calc(100vh - 200px)", + "overflowY": "auto", + "padding": "10px", + }, + ), + ] + ) + + def get_output_dir(self) -> Path: + """ + Get the output directory for saving files. + + Returns + ------- + Path + The output directory path + """ + return self.output_dir + + def save_clusters_to_csv(self, output_path: str | None = None) -> str: + """ + Save cluster information to CSV file. + + This method exports all cluster data including track_id, time, FOV, + cluster assignment, and cluster names to a CSV file for further analysis. + + Parameters + ---------- + output_path : str | None, optional + Path to save the CSV file. If None, generates a timestamped filename + in the output directory, by default None + + Returns + ------- + str + Path to the saved CSV file + + Notes + ----- + The CSV will contain columns: + - cluster_id: The cluster number (1-indexed) + - cluster_name: The custom name assigned to the cluster + - track_id: The track identifier + - time: The timepoint + - fov_name: The field of view name + - cluster_size: Number of points in the cluster + """ + if not self.clusters: + logger.warning("No clusters to save") + return "" + + # Prepare data for CSV export + csv_data = [] + for cluster_idx, cluster in enumerate(self.clusters): + cluster_id = cluster_idx + 1 # 1-indexed for user-friendly output + cluster_size = len(cluster) + cluster_name = self.cluster_names.get(cluster_idx, f"Cluster {cluster_id}") + + for point in cluster: + csv_data.append( + { + "cluster_id": cluster_id, + "cluster_name": cluster_name, + "track_id": point["track_id"], + "time": point["t"], + "fov_name": point["fov_name"], + "cluster_size": cluster_size, + } + ) + + # Create DataFrame and save to CSV + df = pd.DataFrame(csv_data) + + if output_path is None: + # Generate timestamped filename in output directory + from datetime import datetime + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = self.output_dir / f"clusters_{timestamp}.csv" + else: + output_path = Path(output_path) + # If only filename is provided, use output directory + if not output_path.parent.name: + output_path = self.output_dir / output_path.name + + try: + # Create parent directory if it doesn't exist + output_path.parent.mkdir(parents=True, exist_ok=True) + + df.to_csv(output_path, index=False) + logger.info(f"Successfully saved {len(df)} cluster points to {output_path}") + return str(output_path) + + except Exception as e: + logger.error(f"Error saving clusters to CSV: {e}") + raise + + def run(self, debug=False, port=None): + """Run the Dash server + + Parameters + ---------- + debug : bool, optional + 
Whether to run in debug mode, by default False + port : int, optional + Port to run on. If None, will try ports from 8050-8070, by default None + """ + import socket + + def is_port_in_use(port): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("127.0.0.1", port)) + return False + except socket.error: + return True + + if port is None: + # Try ports from 8050 to 8070 + # FIXME: set a range for the ports + port_range = list(range(8050, 8071)) + for p in port_range: + if not is_port_in_use(p): + port = p + break + if port is None: + raise RuntimeError( + f"Could not find an available port in range {port_range[0]}-{port_range[-1]}" + ) + + try: + logger.info(f"Starting server on port {port}") + self.app.run( + debug=debug, + port=port, + use_reloader=False, # Disable reloader to prevent multiple instances + ) + except KeyboardInterrupt: + logger.info("Server shutdown requested...") + except Exception as e: + logger.error(f"Error running server: {e}") + finally: + self._cleanup_cache() + logger.info("Server shutdown complete") diff --git a/viscy/representation/multi_modal.py b/viscy/representation/multi_modal.py new file mode 100644 index 000000000..55481d434 --- /dev/null +++ b/viscy/representation/multi_modal.py @@ -0,0 +1,139 @@ +from logging import getLogger +from typing import Literal, Sequence + +import torch +from pytorch_metric_learning.losses import NTXentLoss +from torch import Tensor, nn + +from viscy.data.typing import TripletSample +from viscy.representation.contrastive import ContrastiveEncoder +from viscy.representation.engine import ContrastiveModule + +_logger = getLogger("lightning.pytorch") + + +class JointEncoders(nn.Module): + def __init__( + self, + source_encoder: nn.Module | ContrastiveEncoder, + target_encoder: nn.Module | ContrastiveEncoder, + ) -> None: + super().__init__() + self.source_encoder = source_encoder + self.target_encoder = target_encoder + + def forward( + self, source: Tensor, target: Tensor + ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: + return self.source_encoder(source), self.target_encoder(target) + + def forward_features(self, source: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: + return self.source_encoder(source)[0], self.target_encoder(target)[0] + + def forward_projections( + self, source: Tensor, target: Tensor + ) -> tuple[Tensor, Tensor]: + return self.source_encoder(source)[1], self.target_encoder(target)[1] + + +class JointContrastiveModule(ContrastiveModule): + """CLIP-style model pair for self-supervised cross-modality representation learning.""" + + def __init__( + self, + encoder: nn.Module | JointEncoders, + loss_function: ( + nn.Module | nn.CosineEmbeddingLoss | nn.TripletMarginLoss | NTXentLoss + ) = nn.TripletMarginLoss(margin=0.5), + lr: float = 1e-3, + schedule: Literal["WarmupCosine", "Constant"] = "Constant", + log_batches_per_epoch: int = 8, + log_samples_per_batch: int = 1, + log_embeddings: bool = False, + example_input_array_shape: Sequence[int] = (1, 2, 15, 256, 256), + prediction_arm: Literal["source", "target"] = "source", + ) -> None: + super().__init__( + encoder=encoder, + loss_function=loss_function, + lr=lr, + schedule=schedule, + log_batches_per_epoch=log_batches_per_epoch, + log_samples_per_batch=log_samples_per_batch, + log_embeddings=log_embeddings, + example_input_array_shape=example_input_array_shape, + ) + self.example_input_array = (self.example_input_array, self.example_input_array) + self._prediction_arm = prediction_arm + + def forward(self, source: 
Tensor, target: Tensor) -> tuple[Tensor, Tensor]: + return self.model.forward_projections(source, target) + + def _info_nce_style_loss(self, z1: Tensor, z2: Tensor) -> Tensor: + indices = torch.arange(0, z1.size(0), device=z2.device) + labels = torch.cat((indices, indices)) + embeddings = torch.cat((z1, z2)) + return self.loss_function(embeddings, labels) + + def _fit_forward_step( + self, batch: TripletSample, batch_idx: int, stage: Literal["train", "val"] + ) -> Tensor: + anchor_img = batch["anchor"] + pos_img = batch["positive"] + anchor_source_projection, anchor_target_projection = ( + self.model.forward_projections(anchor_img[:, 0:1], anchor_img[:, 1:2]) + ) + positive_source_projection, positive_target_projection = ( + self.model.forward_projections(pos_img[:, 0:1], pos_img[:, 1:2]) + ) + # loss_source = self._info_nce_style_loss( + # anchor_source_projection, positive_source_projection + # ) + # loss_target = self._info_nce_style_loss( + # anchor_target_projection, positive_target_projection + # ) + loss_joint = self._info_nce_style_loss( + anchor_source_projection, anchor_target_projection + ) + self._info_nce_style_loss( + positive_target_projection, positive_source_projection + ) + # loss = loss_source + loss_target + loss_joint + loss = loss_joint + self._log_step_samples(batch_idx, (anchor_img, pos_img), stage) + self._log_metrics( + loss=loss, + anchor=anchor_source_projection, + positive=anchor_target_projection, + negative=None, + stage=stage, + ) + return loss + + def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: + return self._fit_forward_step(batch=batch, batch_idx=batch_idx, stage="train") + + def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: + return self._fit_forward_step(batch=batch, batch_idx=batch_idx, stage="val") + + def on_predict_start(self) -> None: + _logger.info(f"Using {self._prediction_arm} encoder for predictions.") + if self._prediction_arm == "source": + self._prediction_encoder = self.model.source_encoder + self._prediction_channel_slice = slice(0, 1) + elif self._prediction_arm == "target": + self._prediction_encoder = self.model.target_encoder + self._prediction_channel_slice = slice(1, 2) + else: + raise ValueError("Invalid prediction arm.") + + def predict_step( + self, batch: TripletSample, batch_idx: int, dataloader_idx: int = 0 + ): + features, projections = self._prediction_encoder( + batch["anchor"][:, self._prediction_channel_slice] + ) + return { + "features": features, + "projections": projections, + "index": batch["index"], + } diff --git a/viscy/transforms/__init__.py b/viscy/transforms/__init__.py index 3fff21fa1..12177b64b 100644 --- a/viscy/transforms/__init__.py +++ b/viscy/transforms/__init__.py @@ -1,9 +1,12 @@ from viscy.transforms._redef import ( + CenterSpatialCropd, RandAdjustContrastd, RandAffined, + RandFlipd, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd, + RandSpatialCropd, RandWeightedCropd, ScaleIntensityRangePercentilesd, ) @@ -17,13 +20,16 @@ __all__ = [ "BatchedZoom", + "CenterSpatialCropd", "NormalizeSampled", "RandAdjustContrastd", "RandAffined", + "RandFlipd", "RandGaussianNoised", "RandGaussianSmoothd", "RandInvertIntensityd", "RandScaleIntensityd", + "RandSpatialCropd", "RandWeightedCropd", "ScaleIntensityRangePercentilesd", "StackChannelsd", diff --git a/viscy/transforms/_redef.py b/viscy/transforms/_redef.py index e41c27446..fe168603a 100644 --- a/viscy/transforms/_redef.py +++ b/viscy/transforms/_redef.py @@ -3,11 +3,14 @@ from typing import Sequence 
from monai.transforms import ( + CenterSpatialCropd, RandAdjustContrastd, RandAffined, + RandFlipd, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd, + RandSpatialCropd, RandWeightedCropd, ScaleIntensityRangePercentilesd, ) @@ -37,9 +40,9 @@ def __init__( self, keys: Sequence[str] | str, prob: float, - rotate_range: Sequence[float] | float, - shear_range: Sequence[float] | float, - scale_range: Sequence[float] | float, + rotate_range: Sequence[float | Sequence[float]] | float, + shear_range: Sequence[float | Sequence[float]] | float, + scale_range: Sequence[float | Sequence[float]] | float, **kwargs, ): super().__init__( @@ -132,3 +135,40 @@ def __init__( dtype=dtype, allow_missing_keys=allow_missing_keys, ) + + +class RandSpatialCropd(RandSpatialCropd): + def __init__( + self, + keys: Sequence[str] | str, + roi_size: Sequence[int] | int, + random_center: bool = True, + **kwargs, + ): + super().__init__( + keys=keys, + roi_size=roi_size, + random_center=random_center, + **kwargs, + ) + + +class CenterSpatialCropd(CenterSpatialCropd): + def __init__( + self, + keys: Sequence[str] | str, + roi_size: Sequence[int] | int, + **kwargs, + ): + super().__init__(keys=keys, roi_size=roi_size, **kwargs) + + +class RandFlipd(RandFlipd): + def __init__( + self, + keys: Sequence[str] | str, + prob: float, + spatial_axis: Sequence[int] | int, + **kwargs, + ): + super().__init__(keys=keys, prob=prob, spatial_axis=spatial_axis, **kwargs) diff --git a/viscy/utils/slurm_utils.py b/viscy/utils/slurm_utils.py new file mode 100644 index 000000000..9cfafb84b --- /dev/null +++ b/viscy/utils/slurm_utils.py @@ -0,0 +1,94 @@ +import psutil +import torch + + +def calculate_dataloader_settings( + batch_size: int, + sample_memory_mb: float, + available_ram_gb: float | None = None, + available_cpu_cores: int | None = None, + target_ram_usage: float = 0.25, + target_vram_usage: float = 0.1, + available_vram_gb: float | None = None, + use_gpu: bool = True, +) -> dict: + """ + Calculate optimal DataLoader settings based on system resources including GPU VRAM. + + Parameters + ---------- + batch_size: int + Size of each batch + sample_memory_mb: float + Approximate memory per sample in MB + available_ram_gb: float, optional + Available RAM in GB. If None, will use system RAM + available_cpu_cores: int, optional + Number of CPU cores. If None, will use system cores + target_ram_usage: float, optional + Target fraction of RAM to use for prefetching. If None, will use 0.25 of RAM. + target_vram_usage: float, optional + Target fraction of VRAM to use for prefetching. If None, will use 0.1 of VRAM. + available_vram_gb: float, optional + Available VRAM in GB. If None, will use system VRAM. + use_gpu: bool, optional + Whether to consider GPU memory constraints. 
+        If False, VRAM constraints are ignored.
+
+    Returns
+    -------
+    dict
+        Recommended settings for DataLoader.
+    """
+    # Get system resources if not provided
+    if available_ram_gb is None:
+        available_ram_gb = psutil.virtual_memory().total / (1024**3)
+    if available_cpu_cores is None:
+        # psutil.cpu_count(logical=False) can return None on some platforms
+        available_cpu_cores = psutil.cpu_count(logical=False) or psutil.cpu_count() or 1
+
+    # Calculate memory per batch
+    batch_memory_mb = batch_size * sample_memory_mb
+
+    # Calculate maximum prefetch factor based on RAM
+    max_prefetch_memory_mb = (available_ram_gb * 1024) * target_ram_usage
+    max_prefetch_factor_ram = int(
+        max_prefetch_memory_mb / (batch_memory_mb * available_cpu_cores)
+    )
+
+    # Calculate maximum prefetch factor based on VRAM if GPU is being used
+    max_prefetch_factor_vram = float("inf")
+    if use_gpu:
+        if available_vram_gb is None:
+            if torch.cuda.is_available():
+                available_vram_gb = torch.cuda.get_device_properties(
+                    0
+                ).total_memory / (1024**3)
+            else:
+                raise ValueError(
+                    "use_gpu is True but no VRAM specified and CUDA is not available"
+                )
+
+        max_prefetch_memory_mb_vram = (available_vram_gb * 1024) * target_vram_usage
+        max_prefetch_factor_vram = int(
+            max_prefetch_memory_mb_vram / (batch_memory_mb * available_cpu_cores)
+        )
+
+    # Take the minimum of the RAM- and VRAM-based prefetch factors
+    max_prefetch_factor = min(max_prefetch_factor_ram, max_prefetch_factor_vram)
+    max_prefetch_factor = max(1, min(max_prefetch_factor, 4))  # Cap between 1 and 4
+
+    # Calculate optimal number of workers,
+    # leaving 2 cores for the main process and other tasks
+    optimal_workers = max(1, available_cpu_cores - 2)
+
+    return {
+        "num_workers": optimal_workers,
+        "prefetch_factor": max_prefetch_factor,
+        "persistent_workers": True,
+        "pin_memory": True,
+        "estimated_memory_usage_mb": batch_memory_mb
+        * optimal_workers
+        * max_prefetch_factor,
+        "estimated_vram_usage_mb": (
+            batch_memory_mb * optimal_workers * max_prefetch_factor if use_gpu else 0
+        ),
+        "available_vram_gb": available_vram_gb if use_gpu else None,
+    }
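+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (illustrative; not part of the original diff): how
+    # calculate_dataloader_settings might be called before constructing a
+    # torch.utils.data.DataLoader. batch_size and sample_memory_mb are assumed
+    # values -- a (2, 15, 256, 256) float32 patch is ~7.5 MB, so 8 MB per
+    # sample is a plausible estimate for that shape.
+    settings = calculate_dataloader_settings(
+        batch_size=16,
+        sample_memory_mb=8.0,
+        use_gpu=torch.cuda.is_available(),
+    )
+    print(
+        f"num_workers={settings['num_workers']}, "
+        f"prefetch_factor={settings['prefetch_factor']}, "
+        f"estimated RAM use {settings['estimated_memory_usage_mb']:.0f} MB"
+    )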