From f53f48b307e7ab3ac9170adc0842ef507589a58d Mon Sep 17 00:00:00 2001 From: "dihan.zheng" Date: Mon, 7 Jul 2025 15:48:23 -0700 Subject: [PATCH 01/10] updated the first version of the FID code --- applications/dynacell/fid.py | 220 ++++++++++ applications/dynacell/test_fid.sh | 8 + applications/dynacell/vae_3d/__init__.py | 0 .../dynacell/vae_3d/modules/__init__.py | 0 .../dynacell/vae_3d/modules/autoencoders.py | 160 ++++++++ .../dynacell/vae_3d/modules/blocks.py | 385 ++++++++++++++++++ .../dynacell/vae_3d/modules/decoder.py | 142 +++++++ .../dynacell/vae_3d/modules/encoder.py | 157 +++++++ applications/dynacell/vae_3d/modules/utils.py | 9 + applications/dynacell/vae_3d/vae_3d_config.py | 16 + applications/dynacell/vae_3d/vae_3d_model.py | 138 +++++++ 11 files changed, 1235 insertions(+) create mode 100644 applications/dynacell/fid.py create mode 100644 applications/dynacell/test_fid.sh create mode 100644 applications/dynacell/vae_3d/__init__.py create mode 100644 applications/dynacell/vae_3d/modules/__init__.py create mode 100644 applications/dynacell/vae_3d/modules/autoencoders.py create mode 100644 applications/dynacell/vae_3d/modules/blocks.py create mode 100644 applications/dynacell/vae_3d/modules/decoder.py create mode 100644 applications/dynacell/vae_3d/modules/encoder.py create mode 100644 applications/dynacell/vae_3d/modules/utils.py create mode 100644 applications/dynacell/vae_3d/vae_3d_config.py create mode 100644 applications/dynacell/vae_3d/vae_3d_model.py diff --git a/applications/dynacell/fid.py b/applications/dynacell/fid.py new file mode 100644 index 000000000..044264023 --- /dev/null +++ b/applications/dynacell/fid.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +import argparse +from pathlib import Path + +import torch +from tqdm import tqdm +from iohub.ngff import open_ome_zarr +from torch import Tensor + +from vae_3d.vae_3d_config import VAE3DConfig +from vae_3d.vae_3d_model import VAE3DModel + +# ----------------------------------------------------------------------------- # +# Helper functions # +# ----------------------------------------------------------------------------- # + +def read_zarr(zarr_path: str): + plate = open_ome_zarr(zarr_path, mode="r") + return [pos for _, pos in plate.positions()] + +def normalise(volume: torch.Tensor) -> torch.Tensor: + """Per-sample min max → [-1,1]. 
Shape: (D, H, W) or (B, D, H, W).""" + v_min = volume.amin(dim=(-3, -2, -1), keepdim=True) + v_max = volume.amax(dim=(-3, -2, -1), keepdim=True) + volume = (volume - v_min) / (v_max - v_min + 1e-6) # → [0,1] + return volume * 2.0 - 1.0 # → [-1,1] + +@torch.no_grad() +def encode_fovs( + fov_pairs, + vae, + channel_name1: str, + channel_name2: str, + device: str = "cuda", + batch_size: int = 4, + input_spatial_size: tuple = (32, 512, 512), +): + """ + For each FOV pair: + • take all T time-frames (shape: T, D, H, W) + • normalise to [-1, 1] + • feed through VAE in chunks of ≤ batch_size frames + • average the resulting T latent vectors → one embedding / FOV + Returns + emb1, emb2 : (N, latent_dim) tensors + """ + emb1, emb2 = [], [] + + for pos1, pos2 in tqdm(fov_pairs, desc="Encoding FOVs"): + # ---------------- load & normalise ---------------- # + v1 = torch.as_tensor( + pos1.data[:, pos1.get_channel_index(channel_name1)], + dtype=torch.float32, device=device, + ) # (T, D, H, W) + v2 = torch.as_tensor( + pos2.data[:, pos2.get_channel_index(channel_name2)], + dtype=torch.float32, device=device, + ) + + v1 = normalise(v1) # still (T, D, H, W) + v2 = normalise(v2) + + # ---------------- chunked VAE inference ----------- # + for t0 in range(0, v1.shape[0], batch_size): + slice1 = v1[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) + slice2 = v2[t0 : t0 + batch_size].unsqueeze(1) + + # resize to input spatial size + slice1 = torch.nn.functional.interpolate( + slice1, size=input_spatial_size, mode="trilinear", align_corners=False, + ) # (b, 1, D, H, W) + slice2 = torch.nn.functional.interpolate( + slice2, size=input_spatial_size, mode="trilinear", align_corners=False, + ) # (b, 1, D, H, W) + + feat1 = vae.encode(slice1).mean + feat2 = vae.encode(slice2).mean + + feat1 = feat1.mean(dim=(1, 2)) + feat2 = feat2.mean(dim=(1, 2)) + + feat1 = feat1.flatten(start_dim=1) # (b, latent_dim) + feat2 = feat2.flatten(start_dim=1) # (b, latent_dim) + + emb1.append(feat1) + emb2.append(feat2) + + return torch.cat(emb1, 0), torch.cat(emb2, 0) + +@torch.jit.script_if_tracing +def sqrtm(sigma: Tensor) -> Tensor: + r"""Returns the square root of a positive semi-definite matrix. + + .. math:: \sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T + + where :math:`Q \Lambda Q^T` is the eigendecomposition of :math:`\Sigma`. + + Args: + sigma: A positive semi-definite matrix, :math:`(*, D, D)`. + + Example: + >>> V = torch.randn(4, 4, dtype=torch.double) + >>> A = V @ V.T + >>> B = sqrtm(A @ A) + >>> torch.allclose(A, B) + True + """ + + L, Q = torch.linalg.eigh(sigma) + L = L.relu().sqrt() + + return Q @ (L[..., None] * Q.mT) + +@torch.jit.script_if_tracing +def frechet_distance( + mu_x: Tensor, + sigma_x: Tensor, + mu_y: Tensor, + sigma_y: Tensor, +) -> Tensor: + r"""Returns the Fréchet distance between two multivariate Gaussian distributions. + + .. math:: d^2 = \left\| \mu_x - \mu_y \right\|_2^2 + + \operatorname{tr} \left( \Sigma_x + \Sigma_y - 2 \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right) + + Wikipedia: + https://wikipedia.org/wiki/Frechet_distance + + Args: + mu_x: The mean :math:`\mu_x` of the first distribution, :math:`(*, D)`. + sigma_x: The covariance :math:`\Sigma_x` of the first distribution, :math:`(*, D, D)`. + mu_y: The mean :math:`\mu_y` of the second distribution, :math:`(*, D)`. + sigma_y: The covariance :math:`\Sigma_y` of the second distribution, :math:`(*, D, D)`. 
+ + Example: + >>> mu_x = torch.arange(3).float() + >>> sigma_x = torch.eye(3) + >>> mu_y = 2 * mu_x + 1 + >>> sigma_y = 2 * sigma_x + 1 + >>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y) + tensor(15.8710) + """ + + sigma_y_12 = sqrtm(sigma_y) + + a = (mu_x - mu_y).square().sum(dim=-1) + b = sigma_x.trace() + sigma_y.trace() + c = sqrtm(sigma_y_12 @ sigma_x @ sigma_y_12).trace() + + return a + b - 2 * c + +@torch.no_grad() +def fid_from_features(f1, f2, eps=1e-6): + mu1, sigma1 = f1.mean(0), torch.cov(f1.T) + mu2, sigma2 = f2.mean(0), torch.cov(f2.T) + + eye = torch.eye(sigma1.size(0), device=sigma1.device, dtype=sigma1.dtype) + sigma1 = sigma1 + eps * eye + sigma2 = sigma2 + eps * eye + + return frechet_distance(mu1, sigma1, mu2, sigma2).clamp_min_(0).item() + +# ----------------------------------------------------------------------------- # +# Main # +# ----------------------------------------------------------------------------- # + +def build_argparser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(add_help=False) # let Hydra handle VAE3DConfig + p.add_argument("--data_path1", type=Path, required=True) + p.add_argument("--data_path2", type=Path, required=True) + p.add_argument("--channel_name", type=str, default=None) + p.add_argument("--channel_name1", type=str, default=None) + p.add_argument("--channel_name2", type=str, default=None) + p.add_argument("--input_spatial_size", type=str, default="32,512,512", + help="Input spatial size for the VAE, e.g. '32,512,512'.") + p.add_argument("--loadcheck_path", type=Path, default=None, + help="Path to the VAE model checkpoint for loading.") + p.add_argument("--batch_size", type=int, default=4) + p.add_argument("--device", type=str, default="cuda") + p.add_argument("--max_fov", type=int, default=None, + help="Limit number of FOV pairs (for quick tests).") + return p + +def main(args) -> None: + device = args.device + + # ----------------- VAE ----------------- # + model_cfg = VAE3DConfig() + model_cfg.loadcheck_path = args.loadcheck_path + vae = VAE3DModel(config=model_cfg).to(device).eval() + + # ----------------- FOV list ------------ # + fovs1, fovs2 = read_zarr(args.data_path1), read_zarr(args.data_path2) + n = min(len(fovs1), len(fovs2)) + if args.max_fov: + n = min(n, args.max_fov) + pair_list = list(zip(fovs1[:n], fovs2[:n])) + + # ----------------- Embeddings ----------- # + input_spatial_size = [int(dim) for dim in args.input_spatial_size.split(",")] + + if args.channel_name is not None: + args.channel_name1 = args.channel_name2 = args.channel_name + + emb1, emb2 = encode_fovs( + pair_list, vae, + args.channel_name1, + args.channel_name2, + device, args.batch_size, + input_spatial_size, + ) + + # ----------------- FID ------------------ # + fid_val = fid_from_features(emb1, emb2) + print(f"\nFID (VAE latent, N={emb1.size(0)}): {fid_val:.6f}") + +if __name__ == "__main__": + parser = build_argparser() + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/applications/dynacell/test_fid.sh b/applications/dynacell/test_fid.sh new file mode 100644 index 000000000..990e93f5e --- /dev/null +++ b/applications/dynacell/test_fid.sh @@ -0,0 +1,8 @@ +python applications/dynacell/fid.py \ + --data_path1 /hpc/projects/group.huang/dihan.zheng/CELL-Diff-VS/prediction/a549/output.zarr \ + --data_path2 /hpc/projects/group.huang/dihan.zheng/CELL-Diff-VS/prediction/a549/output.zarr \ + --channel_name1 Nuclei-prediction \ + --channel_name2 Nuclei-prediction \ + --loadcheck_path 
/hpc/projects/group.huang/dihan.zheng/CELL-Diff-VS/pretrain_cyto3d/PT_VAE-3D_nucleus_poisson_KL1e-3_LC2_ch-32-64-128-256/checkpoint-50000/pytorch_model.bin \ + --batch_size 4 \ + --device cuda \ No newline at end of file diff --git a/applications/dynacell/vae_3d/__init__.py b/applications/dynacell/vae_3d/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/dynacell/vae_3d/modules/__init__.py b/applications/dynacell/vae_3d/modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/dynacell/vae_3d/modules/autoencoders.py b/applications/dynacell/vae_3d/modules/autoencoders.py new file mode 100644 index 000000000..9c3fb927a --- /dev/null +++ b/applications/dynacell/vae_3d/modules/autoencoders.py @@ -0,0 +1,160 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from .decoder import Decoder +from .encoder import Encoder + +from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin + + +class Autoencoder3DKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + num_down_blocks: int = 2, + num_up_blocks: int = 2, + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + num_down_blocks=num_down_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + num_up_blocks=num_up_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, Decoder)): + module.gradient_checkpointing = value + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
+ """ + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/applications/dynacell/vae_3d/modules/blocks.py b/applications/dynacell/vae_3d/modules/blocks.py new file mode 100644 index 000000000..b0d4e4380 --- /dev/null +++ b/applications/dynacell/vae_3d/modules/blocks.py @@ -0,0 +1,385 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn +from diffusers.models.normalization import RMSNorm +from diffusers.utils import is_torch_version +from diffusers.models.activations import get_activation + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class 
DownEncoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + groups: int = 32, + groups_out: Optional[int] = None, + eps: float = 1e-6, + non_linearity: str = "swish", + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + conv_shortcut_bias: bool = True, + ): + super().__init__() + + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + conv_3d_out_channels = out_channels + self.conv2 = nn.Conv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1, padding=1) + + self.nonlinearity = get_activation(non_linearity) + self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv3d( + in_channels, + conv_3d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + +class Downsample3D(nn.Module): + """A 3D downsampling layer with an optional convolution. 
+ """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + kernel_size: int = 3, + norm_type: Optional[str] = None, + eps: Optional[float] = 1e-5, + elementwise_affine: Optional[bool] = True, + bias: bool = True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + self.kernel_size = kernel_size + stride = 2 # Downsampling stride is fixed to 2 + + # Initialize normalization + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(self.channels, eps=eps, elementwise_affine=elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"Unknown norm_type: {norm_type}") + + # Choose between convolutional or pooling downsampling + if use_conv: + self.conv = nn.Conv3d( + self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ) + else: + assert self.channels == self.out_channels, "out_channels must match channels when using pooling" + self.conv = nn.AvgPool3d(kernel_size=stride, stride=stride) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the downsampling block. + + Args: + hidden_states (torch.Tensor): Input feature map of shape (B, C, D, H, W). + + Returns: + torch.Tensor: Downsampled feature map. + """ + assert hidden_states.shape[1] == self.channels, \ + f"Expected input channels {self.channels}, but got {hidden_states.shape[1]}" + + # Apply normalization if specified + if self.norm is not None: + # LayerNorm expects (B, C, D, H, W), but normalizes over C. Permute to (B, D, H, W, C) + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 4, 1)) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) # Back to (B, C, D, H, W) + + # Apply padding if using conv downsampling and no padding was specified + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1, 0, 1) # Padding for 3D tensor: (D, H, W) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + + # Apply downsampling + hidden_states = self.conv(hidden_states) + + return hidden_states + +class Upsample3D(nn.Module): + """A 3D upsampling layer with a convolution. 
+ """ + + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + kernel_size: Optional[int] = None, + padding=1, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(self.channels, eps=eps, elementwise_affine=elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + conv = None + if kernel_size is None: + kernel_size = 3 + conv = nn.Conv3d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + self.conv = conv + + def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor: + assert hidden_states.shape[1] == self.channels, f"Expected {self.channels} channels, got {hidden_states.shape[1]}" + + # Apply normalization if specified + if self.norm is not None: + # LayerNorm expects (B, C, D, H, W), but normalizes over C. Permute to (B, D, H, W, C) + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 4, 1)) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) # Back to (B, C, D, H, W) + + dtype = hidden_states.dtype + if dtype == torch.bfloat16 and is_torch_version("<", "2.1"): + hidden_states = hidden_states.to(torch.float32) + + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + if output_size is None: + B, C, D, H, W = hidden_states.shape + if B*C*D*H*W*8 > 2_100_000_000: + x1 = F.interpolate(hidden_states[:, :, :D//2], scale_factor=2, mode="nearest") + x2 = F.interpolate(hidden_states[:, :, D//2:], scale_factor=2, mode="nearest") + hidden_states = torch.cat([x1, x2], dim=2) + else: + hidden_states = F.interpolate(hidden_states, scale_factor=2, mode="nearest") + # hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # Cast back to original dtype + if dtype == torch.bfloat16 and is_torch_version("<", "2.1"): + hidden_states = hidden_states.to(dtype) + + hidden_states = self.conv(hidden_states) + + return hidden_states + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. 
+ """ + + def __init__( + self, + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + if attn_groups is None: + attn_groups = resnet_groups + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ), + ] + + for _ in range(num_layers): + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states) + + return hidden_states \ No newline at end of file diff --git a/applications/dynacell/vae_3d/modules/decoder.py b/applications/dynacell/vae_3d/modules/decoder.py new file mode 100644 index 000000000..a1887f0b1 --- /dev/null +++ b/applications/dynacell/vae_3d/modules/decoder.py @@ -0,0 +1,142 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +from .blocks import UNetMidBlock3D, UpDecoderBlock3D + +from diffusers.utils import is_torch_version +from diffusers.models.attention_processor import SpatialNorm + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + num_up_blocks: int = 2, + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv3d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_groups=norm_num_groups, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(num_up_blocks): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock3D( + in_channels=prev_output_channel, + out_channels=output_channel, + num_layers=self.layers_per_block + 1, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_upsample=not is_final_block, + ) + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv3d( + block_out_channels[0], + out_channels, + kernel_size=3, + padding=1, + padding_mode='reflect', + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + sample, + use_reentrant=False, + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + use_reentrant=False, + ) + else: + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) + else: + # middle + sample = self.mid_block(sample) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample \ No newline at end of file diff --git a/applications/dynacell/vae_3d/modules/encoder.py b/applications/dynacell/vae_3d/modules/encoder.py new file mode 100644 index 000000000..31b8a70cf --- /dev/null +++ b/applications/dynacell/vae_3d/modules/encoder.py @@ -0,0 +1,157 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from .blocks import DownEncoderBlock3D, UNetMidBlock3D +from diffusers.utils import BaseOutput, is_torch_version + + +@dataclass +class DecoderOutput(BaseOutput): + 
r""" + Output of decoding method. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The decoded output sample from the last layer of the model. + """ + + sample: torch.Tensor + commit_loss: Optional[torch.FloatTensor] = None + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + num_down_blocks: int = 2, + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv3d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + padding_mode='reflect' + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i in range(num_down_blocks): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock3D( + in_channels=input_channel, + out_channels=output_channel, + dropout=0.0, + num_layers=self.layers_per_block, + add_downsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + downsample_padding=0, + ) + + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_groups=norm_num_groups, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv3d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `Encoder` class.""" + + sample = self.conv_in(sample) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # down + if is_torch_version(">=", "1.11.0"): + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, use_reentrant=False + ) + # middle + sample = torch.utils.checkpoint.checkpoint( + 
create_custom_forward(self.mid_block), sample, use_reentrant=False + ) + else: + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + + else: + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + diff --git a/applications/dynacell/vae_3d/modules/utils.py b/applications/dynacell/vae_3d/modules/utils.py new file mode 100644 index 000000000..0c2a8185e --- /dev/null +++ b/applications/dynacell/vae_3d/modules/utils.py @@ -0,0 +1,9 @@ +import torch +from dataclasses import dataclass +from transformers.utils import ModelOutput + +@dataclass +class VAEOutput(ModelOutput): + loss: torch.FloatTensor = None + recon_loss: torch.FloatTensor = None + kl_loss: torch.FloatTensor = None \ No newline at end of file diff --git a/applications/dynacell/vae_3d/vae_3d_config.py b/applications/dynacell/vae_3d/vae_3d_config.py new file mode 100644 index 000000000..73f353219 --- /dev/null +++ b/applications/dynacell/vae_3d/vae_3d_config.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +from dataclasses import dataclass +from transformers import PretrainedConfig +from dataclasses import field + +@dataclass +class VAE3DConfig(PretrainedConfig): + model_type: str = 'vae' + + # Model parameters + in_channels: int = 1 + out_channels: int = 1 + num_down_blocks: int = 4 + latent_channels: int = 2 + vae_block_out_channels: list = field(default_factory=lambda: [32, 64, 128, 256]) + loadcheck_path: str = "" \ No newline at end of file diff --git a/applications/dynacell/vae_3d/vae_3d_model.py b/applications/dynacell/vae_3d/vae_3d_model.py new file mode 100644 index 000000000..d01138b1e --- /dev/null +++ b/applications/dynacell/vae_3d/vae_3d_model.py @@ -0,0 +1,138 @@ +import os +import torch +import torch.nn as nn +from .modules.autoencoders import Autoencoder3DKL +from .vae_3d_config import VAE3DConfig +from transformers import PreTrainedModel +from .modules.utils import VAEOutput + + +class VAE3DModel(PreTrainedModel): + config_class = VAE3DConfig + + def __init__(self, config: VAE3DConfig): + super().__init__(config) + self.config = config + + self.num_down_blocks = config.num_down_blocks + self.num_up_blocks = self.num_down_blocks + + # Initialize Autoencoder3DKL + self.vae = Autoencoder3DKL( + in_channels=config.in_channels, + out_channels=config.out_channels, + num_down_blocks=self.num_down_blocks, + num_up_blocks=self.num_up_blocks, + block_out_channels=config.vae_block_out_channels, + latent_channels=config.latent_channels, + ) + + self.load_pretrained_weights(checkpoint_path=config.loadcheck_path) + + def load_pretrained_weights(self, checkpoint_path): + """ + Load pretrained weights from a given state_dict. 
+ """ + + if os.path.splitext(checkpoint_path)[1] == '.safetensors': + from safetensors.torch import load_file + checkpoints_state = load_file(checkpoint_path) + else: + checkpoints_state = torch.load(checkpoint_path, map_location="cpu") + + if "model" in checkpoints_state: + checkpoints_state = checkpoints_state["model"] + elif "module" in checkpoints_state: + checkpoints_state = checkpoints_state["module"] + + IncompatibleKeys = self.load_state_dict(checkpoints_state, strict=True) + IncompatibleKeys = IncompatibleKeys._asdict() + + missing_keys = [] + for keys in IncompatibleKeys["missing_keys"]: + if keys.find("dummy") == -1: + missing_keys.append(keys) + + unexpected_keys = [] + for keys in IncompatibleKeys["unexpected_keys"]: + if keys.find("dummy") == -1: + unexpected_keys.append(keys) + + if len(missing_keys) > 0: + print( + "Missing keys in {}: {}".format( + checkpoint_path, + missing_keys, + ) + ) + + if len(unexpected_keys) > 0: + print( + "Unexpected keys {}: {}".format( + checkpoint_path, + unexpected_keys, + ) + ) + + def encode(self, x): + """Encodes input into latent space.""" + return self.vae.encode(x).latent_dist + + def decode(self, latents): + """Decodes latent space into reconstructed input.""" + return self.vae.decode(latents) + + def forward(self, batched_data): + x = batched_data['data'] + + """Forward pass through the VAE.""" + latent_dist = self.encode(x) + latents = latent_dist.sample() + recon_x = self.decode(latents).sample + + total_loss, recon_loss, kl_loss = self.compute_loss(x, recon_x, latent_dist) + + return VAEOutput(total_loss, recon_loss, kl_loss) + + def compute_loss(self, x, recon_x, latent_dist): + """Compute reconstruction and KL divergence loss.""" + if self.config.vae_recon_loss_type == 'mse': + recon_loss = nn.MSELoss()(recon_x, x) + elif self.config.vae_recon_loss_type == 'poisson': + x = x.clip(-1, 1) + recon_x = recon_x.clip(-1, 1) + peak = self.config.poisson_peak if hasattr(self.config, 'poisson_peak') else 1.0 + target = (x + 1) / 2.0 * peak + lam = (recon_x + 1) / 2.0 * peak + recon_loss = torch.mean(lam - target * torch.log(lam + 1e-8)) + + kl_loss = -0.5 * torch.mean(1 + latent_dist.logvar - latent_dist.mean.pow(2) - latent_dist.logvar.exp()) + total_loss = self.config.recon_loss_coeff * recon_loss + self.config.kl_loss_coeff * kl_loss + return total_loss, recon_loss, kl_loss + + def sample(self, num_samples=1, latent_size=32, device="cpu"): + """ + Generate samples from the latent space. + + Args: + num_samples (int): Number of samples to generate. + device (str): Device to perform sampling on. + + Returns: + torch.Tensor: Generated images. 
+ """ + # Sample from a standard normal distribution in latent space + latents = torch.randn((num_samples, self.config.latent_channels, latent_size, latent_size, latent_size), device=device) # Shape matches latent dimensions + + # Decode latents to generate images + with torch.no_grad(): + generated_images = self.decode(latents).sample + + return generated_images + + def reconstruct(self, x): + latent_dist = self.encode(x) + latents = latent_dist.sample() # Reparameterization trick + recon_x = self.decode(latents).sample + + return recon_x From d33f8097bb52a4d1bb91f7fd52cbfb168aa4ebc9 Mon Sep 17 00:00:00 2001 From: "dihan.zheng" Date: Fri, 18 Jul 2025 14:10:51 -0700 Subject: [PATCH 02/10] update --- applications/dynacell/fid.py | 125 +++++++++--------- applications/dynacell/test_fid.sh | 17 ++- applications/dynacell/vae_3d/vae_3d_config.py | 4 +- 3 files changed, 76 insertions(+), 70 deletions(-) diff --git a/applications/dynacell/fid.py b/applications/dynacell/fid.py index 044264023..577c65346 100644 --- a/applications/dynacell/fid.py +++ b/applications/dynacell/fid.py @@ -25,68 +25,6 @@ def normalise(volume: torch.Tensor) -> torch.Tensor: volume = (volume - v_min) / (v_max - v_min + 1e-6) # → [0,1] return volume * 2.0 - 1.0 # → [-1,1] -@torch.no_grad() -def encode_fovs( - fov_pairs, - vae, - channel_name1: str, - channel_name2: str, - device: str = "cuda", - batch_size: int = 4, - input_spatial_size: tuple = (32, 512, 512), -): - """ - For each FOV pair: - • take all T time-frames (shape: T, D, H, W) - • normalise to [-1, 1] - • feed through VAE in chunks of ≤ batch_size frames - • average the resulting T latent vectors → one embedding / FOV - Returns - emb1, emb2 : (N, latent_dim) tensors - """ - emb1, emb2 = [], [] - - for pos1, pos2 in tqdm(fov_pairs, desc="Encoding FOVs"): - # ---------------- load & normalise ---------------- # - v1 = torch.as_tensor( - pos1.data[:, pos1.get_channel_index(channel_name1)], - dtype=torch.float32, device=device, - ) # (T, D, H, W) - v2 = torch.as_tensor( - pos2.data[:, pos2.get_channel_index(channel_name2)], - dtype=torch.float32, device=device, - ) - - v1 = normalise(v1) # still (T, D, H, W) - v2 = normalise(v2) - - # ---------------- chunked VAE inference ----------- # - for t0 in range(0, v1.shape[0], batch_size): - slice1 = v1[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) - slice2 = v2[t0 : t0 + batch_size].unsqueeze(1) - - # resize to input spatial size - slice1 = torch.nn.functional.interpolate( - slice1, size=input_spatial_size, mode="trilinear", align_corners=False, - ) # (b, 1, D, H, W) - slice2 = torch.nn.functional.interpolate( - slice2, size=input_spatial_size, mode="trilinear", align_corners=False, - ) # (b, 1, D, H, W) - - feat1 = vae.encode(slice1).mean - feat2 = vae.encode(slice2).mean - - feat1 = feat1.mean(dim=(1, 2)) - feat2 = feat2.mean(dim=(1, 2)) - - feat1 = feat1.flatten(start_dim=1) # (b, latent_dim) - feat2 = feat2.flatten(start_dim=1) # (b, latent_dim) - - emb1.append(feat1) - emb2.append(feat2) - - return torch.cat(emb1, 0), torch.cat(emb2, 0) - @torch.jit.script_if_tracing def sqrtm(sigma: Tensor) -> Tensor: r"""Returns the square root of a positive semi-definite matrix. 
@@ -160,12 +98,71 @@ def fid_from_features(f1, f2, eps=1e-6): return frechet_distance(mu1, sigma1, mu2, sigma2).clamp_min_(0).item() +@torch.no_grad() +def encode_fovs( + fov_pairs, + vae, + channel_name1: str, + channel_name2: str, + device: str = "cuda", + batch_size: int = 4, + input_spatial_size: tuple = (32, 512, 512), +): + """ + For each FOV pair: + • take all T time-frames (shape: T, D, H, W) + • normalise to [-1, 1] + • feed through VAE in chunks of ≤ batch_size frames + • average the resulting T latent vectors → one embedding / FOV + Returns + emb1, emb2 : (N, latent_dim) tensors + """ + emb1, emb2 = [], [] + + for pos1, pos2 in tqdm(fov_pairs, desc="Encoding FOVs"): + # ---------------- load & normalise ---------------- # + v1 = torch.as_tensor( + pos1.data[:, pos1.get_channel_index(channel_name1)], + dtype=torch.float32, device=device, + ) # (T, D, H, W) + v2 = torch.as_tensor( + pos2.data[:, pos2.get_channel_index(channel_name2)], + dtype=torch.float32, device=device, + ) + + v1 = normalise(v1) # still (T, D, H, W) + v2 = normalise(v2) + + # ---------------- chunked VAE inference ----------- # + for t0 in range(0, v1.shape[0], batch_size): + slice1 = v1[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) + slice2 = v2[t0 : t0 + batch_size].unsqueeze(1) + + # resize to input spatial size + slice1 = torch.nn.functional.interpolate( + slice1, size=input_spatial_size, mode="trilinear", align_corners=False, + ) # (b, 1, D, H, W) + slice2 = torch.nn.functional.interpolate( + slice2, size=input_spatial_size, mode="trilinear", align_corners=False, + ) # (b, 1, D, H, W) + + feat1 = vae.encode(slice1).mean + feat2 = vae.encode(slice2).mean + + feat1 = feat1.flatten(start_dim=1) # (b, latent_dim) + feat2 = feat2.flatten(start_dim=1) # (b, latent_dim) + + emb1.append(feat1) + emb2.append(feat2) + + return torch.cat(emb1, 0), torch.cat(emb2, 0) + # ----------------------------------------------------------------------------- # # Main # # ----------------------------------------------------------------------------- # def build_argparser() -> argparse.ArgumentParser: - p = argparse.ArgumentParser(add_help=False) # let Hydra handle VAE3DConfig + p = argparse.ArgumentParser(add_help=False) p.add_argument("--data_path1", type=Path, required=True) p.add_argument("--data_path2", type=Path, required=True) p.add_argument("--channel_name", type=str, default=None) @@ -212,7 +209,7 @@ def main(args) -> None: # ----------------- FID ------------------ # fid_val = fid_from_features(emb1, emb2) - print(f"\nFID (VAE latent, N={emb1.size(0)}): {fid_val:.6f}") + print(f"\nFID: {fid_val:.6f}") if __name__ == "__main__": parser = build_argparser() diff --git a/applications/dynacell/test_fid.sh b/applications/dynacell/test_fid.sh index 990e93f5e..9e8adea15 100644 --- a/applications/dynacell/test_fid.sh +++ b/applications/dynacell/test_fid.sh @@ -1,8 +1,17 @@ -python applications/dynacell/fid.py \ - --data_path1 /hpc/projects/group.huang/dihan.zheng/CELL-Diff-VS/prediction/a549/output.zarr \ - --data_path2 /hpc/projects/group.huang/dihan.zheng/CELL-Diff-VS/prediction/a549/output.zarr \ +python fid.py \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ --channel_name1 Nuclei-prediction \ --channel_name2 Nuclei-prediction \ - --loadcheck_path 
/hpc/projects/group.huang/dihan.zheng/CELL-Diff-VS/pretrain_cyto3d/PT_VAE-3D_nucleus_poisson_KL1e-3_LC2_ch-32-64-128-256/checkpoint-50000/pytorch_model.bin \ + --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae.pth \ + --batch_size 4 \ + --device cuda + +python fid.py \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ + --channel_name1 Membrane-prediction \ + --channel_name2 Membrane-prediction \ + --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae.pth \ --batch_size 4 \ --device cuda \ No newline at end of file diff --git a/applications/dynacell/vae_3d/vae_3d_config.py b/applications/dynacell/vae_3d/vae_3d_config.py index 73f353219..d30883ed2 100644 --- a/applications/dynacell/vae_3d/vae_3d_config.py +++ b/applications/dynacell/vae_3d/vae_3d_config.py @@ -10,7 +10,7 @@ class VAE3DConfig(PretrainedConfig): # Model parameters in_channels: int = 1 out_channels: int = 1 - num_down_blocks: int = 4 + num_down_blocks: int = 5 latent_channels: int = 2 - vae_block_out_channels: list = field(default_factory=lambda: [32, 64, 128, 256]) + vae_block_out_channels: list = field(default_factory=lambda: [32, 64, 128, 256, 256]) loadcheck_path: str = "" \ No newline at end of file From 14df418d553f83018f03bcf0132f7fd56086e600 Mon Sep 17 00:00:00 2001 From: "dihan.zheng" Date: Thu, 7 Aug 2025 11:20:18 -0700 Subject: [PATCH 03/10] update torchscript model --- .gitignore | 2 + applications/dynacell/fid_ts.py | 213 ++++++++++++++++++ applications/dynacell/test_fid_ts.sh | 17 ++ .../vae_3d/modules/autoencoders_ts.py | 82 +++++++ .../dynacell/vae_3d/modules/blocks.py | 50 +--- .../dynacell/vae_3d/modules/decoder.py | 54 +---- .../dynacell/vae_3d/modules/encoder.py | 86 +------ .../dynacell/vae_3d/vae_3d_model_ts.py | 90 ++++++++ 8 files changed, 426 insertions(+), 168 deletions(-) create mode 100644 applications/dynacell/fid_ts.py create mode 100644 applications/dynacell/test_fid_ts.sh create mode 100644 applications/dynacell/vae_3d/modules/autoencoders_ts.py create mode 100644 applications/dynacell/vae_3d/vae_3d_model_ts.py diff --git a/.gitignore b/.gitignore index c390d5982..4850e8961 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,5 @@ slurm*.out #lightning_logs directory lightning_logs/ + +applications/dynacell/test \ No newline at end of file diff --git a/applications/dynacell/fid_ts.py b/applications/dynacell/fid_ts.py new file mode 100644 index 000000000..0bbddf823 --- /dev/null +++ b/applications/dynacell/fid_ts.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- +import argparse +from pathlib import Path + +import torch +from tqdm import tqdm +from iohub.ngff import open_ome_zarr +from torch import Tensor + +# ----------------------------------------------------------------------------- # +# Helper functions # +# ----------------------------------------------------------------------------- # + +def read_zarr(zarr_path: str): + plate = open_ome_zarr(zarr_path, mode="r") + return [pos for _, pos in plate.positions()] + +def normalise(volume: torch.Tensor) -> torch.Tensor: + """Per-sample min max → [-1,1]. 
Shape: (D, H, W) or (B, D, H, W).""" + v_min = volume.amin(dim=(-3, -2, -1), keepdim=True) + v_max = volume.amax(dim=(-3, -2, -1), keepdim=True) + volume = (volume - v_min) / (v_max - v_min + 1e-6) # → [0,1] + return volume * 2.0 - 1.0 # → [-1,1] + +@torch.jit.script_if_tracing +def sqrtm(sigma: Tensor) -> Tensor: + r"""Returns the square root of a positive semi-definite matrix. + + .. math:: \sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T + + where :math:`Q \Lambda Q^T` is the eigendecomposition of :math:`\Sigma`. + + Args: + sigma: A positive semi-definite matrix, :math:`(*, D, D)`. + + Example: + >>> V = torch.randn(4, 4, dtype=torch.double) + >>> A = V @ V.T + >>> B = sqrtm(A @ A) + >>> torch.allclose(A, B) + True + """ + + L, Q = torch.linalg.eigh(sigma) + L = L.relu().sqrt() + + return Q @ (L[..., None] * Q.mT) + +@torch.jit.script_if_tracing +def frechet_distance( + mu_x: Tensor, + sigma_x: Tensor, + mu_y: Tensor, + sigma_y: Tensor, +) -> Tensor: + r"""Returns the Fréchet distance between two multivariate Gaussian distributions. + + .. math:: d^2 = \left\| \mu_x - \mu_y \right\|_2^2 + + \operatorname{tr} \left( \Sigma_x + \Sigma_y - 2 \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right) + + Wikipedia: + https://wikipedia.org/wiki/Frechet_distance + + Args: + mu_x: The mean :math:`\mu_x` of the first distribution, :math:`(*, D)`. + sigma_x: The covariance :math:`\Sigma_x` of the first distribution, :math:`(*, D, D)`. + mu_y: The mean :math:`\mu_y` of the second distribution, :math:`(*, D)`. + sigma_y: The covariance :math:`\Sigma_y` of the second distribution, :math:`(*, D, D)`. + + Example: + >>> mu_x = torch.arange(3).float() + >>> sigma_x = torch.eye(3) + >>> mu_y = 2 * mu_x + 1 + >>> sigma_y = 2 * sigma_x + 1 + >>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y) + tensor(15.8710) + """ + + sigma_y_12 = sqrtm(sigma_y) + + a = (mu_x - mu_y).square().sum(dim=-1) + b = sigma_x.trace() + sigma_y.trace() + c = sqrtm(sigma_y_12 @ sigma_x @ sigma_y_12).trace() + + return a + b - 2 * c + +@torch.no_grad() +def fid_from_features(f1, f2, eps=1e-6): + mu1, sigma1 = f1.mean(0), torch.cov(f1.T) + mu2, sigma2 = f2.mean(0), torch.cov(f2.T) + + eye = torch.eye(sigma1.size(0), device=sigma1.device, dtype=sigma1.dtype) + sigma1 = sigma1 + eps * eye + sigma2 = sigma2 + eps * eye + + return frechet_distance(mu1, sigma1, mu2, sigma2).clamp_min_(0).item() + +@torch.no_grad() +def encode_fovs( + fov_pairs, + vae, + channel_name1: str, + channel_name2: str, + device: str = "cuda", + batch_size: int = 4, + input_spatial_size: tuple = (32, 512, 512), +): + """ + For each FOV pair: + • take all T time-frames (shape: T, D, H, W) + • normalise to [-1, 1] + • feed through VAE in chunks of ≤ batch_size frames + • average the resulting T latent vectors → one embedding / FOV + Returns + emb1, emb2 : (N, latent_dim) tensors + """ + emb1, emb2 = [], [] + + for pos1, pos2 in tqdm(fov_pairs, desc="Encoding FOVs"): + # ---------------- load & normalise ---------------- # + v1 = torch.as_tensor( + pos1.data[:, pos1.get_channel_index(channel_name1)], + dtype=torch.float32, device=device, + ) # (T, D, H, W) + v2 = torch.as_tensor( + pos2.data[:, pos2.get_channel_index(channel_name2)], + dtype=torch.float32, device=device, + ) + + v1 = normalise(v1) # still (T, D, H, W) + v2 = normalise(v2) + + # ---------------- chunked VAE inference ----------- # + for t0 in range(0, v1.shape[0], batch_size): + slice1 = v1[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) + slice2 = v2[t0 : t0 + batch_size].unsqueeze(1) + + # 
resize to input spatial size + slice1 = torch.nn.functional.interpolate( + slice1, size=input_spatial_size, mode="trilinear", align_corners=False, + ) # (b, 1, D, H, W) + slice2 = torch.nn.functional.interpolate( + slice2, size=input_spatial_size, mode="trilinear", align_corners=False, + ) # (b, 1, D, H, W) + + feat1 = vae.encode(slice1)[0] # mean + feat2 = vae.encode(slice2)[0] # mean + + feat1 = feat1.flatten(start_dim=1) # (b, latent_dim) + feat2 = feat2.flatten(start_dim=1) # (b, latent_dim) + + emb1.append(feat1) + emb2.append(feat2) + + return torch.cat(emb1, 0), torch.cat(emb2, 0) + +# ----------------------------------------------------------------------------- # +# Main # +# ----------------------------------------------------------------------------- # + +def build_argparser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(add_help=False) + p.add_argument("--data_path1", type=Path, required=True) + p.add_argument("--data_path2", type=Path, required=True) + p.add_argument("--channel_name", type=str, default=None) + p.add_argument("--channel_name1", type=str, default=None) + p.add_argument("--channel_name2", type=str, default=None) + p.add_argument("--input_spatial_size", type=str, default="32,512,512", + help="Input spatial size for the VAE, e.g. '32,512,512'.") + p.add_argument("--loadcheck_path", type=Path, default=None, + help="Path to the VAE model checkpoint for loading.") + p.add_argument("--batch_size", type=int, default=4) + p.add_argument("--device", type=str, default="cuda") + p.add_argument("--max_fov", type=int, default=None, + help="Limit number of FOV pairs (for quick tests).") + return p + +def main(args) -> None: + device = args.device + + # ----------------- VAE ----------------- # + vae = torch.jit.load(args.loadcheck_path).to(device) + vae.eval() + + # ----------------- FOV list ------------ # + fovs1, fovs2 = read_zarr(args.data_path1), read_zarr(args.data_path2) + n = min(len(fovs1), len(fovs2)) + if args.max_fov: + n = min(n, args.max_fov) + pair_list = list(zip(fovs1[:n], fovs2[:n])) + + # ----------------- Embeddings ----------- # + input_spatial_size = [int(dim) for dim in args.input_spatial_size.split(",")] + + if args.channel_name is not None: + args.channel_name1 = args.channel_name2 = args.channel_name + + emb1, emb2 = encode_fovs( + pair_list, vae, + args.channel_name1, + args.channel_name2, + device, args.batch_size, + input_spatial_size, + ) + + # ----------------- FID ------------------ # + fid_val = fid_from_features(emb1, emb2) + print(f"\nFID: {fid_val:.6f}") + +if __name__ == "__main__": + parser = build_argparser() + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/applications/dynacell/test_fid_ts.sh b/applications/dynacell/test_fid_ts.sh new file mode 100644 index 000000000..d98e5047f --- /dev/null +++ b/applications/dynacell/test_fid_ts.sh @@ -0,0 +1,17 @@ +python fid_ts.py \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ + --channel_name1 Nuclei-prediction \ + --channel_name2 Nuclei-prediction \ + --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae_ts.pt \ + --batch_size 4 \ + --device cuda + +python fid_ts.py \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ + --channel_name1 
Membrane-prediction \ + --channel_name2 Membrane-prediction \ + --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae_ts.pt \ + --batch_size 4 \ + --device cuda \ No newline at end of file diff --git a/applications/dynacell/vae_3d/modules/autoencoders_ts.py b/applications/dynacell/vae_3d/modules/autoencoders_ts.py new file mode 100644 index 000000000..9438ddd6d --- /dev/null +++ b/applications/dynacell/vae_3d/modules/autoencoders_ts.py @@ -0,0 +1,82 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +from .decoder import Decoder +from .encoder import Encoder + + +class Autoencoder3DKL(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + num_down_blocks: int = 2, + num_up_blocks: int = 2, + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + num_down_blocks=num_down_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + num_up_blocks=num_up_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + def encode(self, x: torch.Tensor): + h = self._encode(x) + mean, logvar = torch.chunk(h, 2, dim=1) + + return mean, logvar + + def _decode(self, z: torch.Tensor): + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + dec = self.decoder(z) + + return dec + + def decode(self, z: torch.FloatTensor): + decoded = self._decode(z) + + return decoded + + def forward(self, x): + # placeholder forward + return x \ No newline at end of file diff --git a/applications/dynacell/vae_3d/modules/blocks.py b/applications/dynacell/vae_3d/modules/blocks.py index b0d4e4380..569d66e9a 100644 --- a/applications/dynacell/vae_3d/modules/blocks.py +++ b/applications/dynacell/vae_3d/modules/blocks.py @@ -4,7 +4,6 @@ import torch.nn.functional as F from torch import nn from diffusers.models.normalization import RMSNorm -from diffusers.utils import is_torch_version from diffusers.models.activations import get_activation class UpDecoderBlock3D(nn.Module): @@ -93,14 +92,14 @@ def __init__( self.downsamplers = nn.ModuleList( [ Downsample3D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, ) ] ) else: self.downsamplers = None - def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for resnet in self.resnets: hidden_states = resnet(hidden_states) @@ -248,7 +247,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Apply padding 
if using conv downsampling and no padding was specified if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1, 0, 1) # Padding for 3D tensor: (D, H, W) - hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0.0) # Apply downsampling hidden_states = self.conv(hidden_states) @@ -291,7 +290,7 @@ def __init__( # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed self.conv = conv - def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None) -> torch.Tensor: assert hidden_states.shape[1] == self.channels, f"Expected {self.channels} channels, got {hidden_states.shape[1]}" # Apply normalization if specified @@ -300,29 +299,14 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None hidden_states = self.norm(hidden_states.permute(0, 2, 3, 4, 1)) hidden_states = hidden_states.permute(0, 4, 1, 2, 3) # Back to (B, C, D, H, W) - dtype = hidden_states.dtype - if dtype == torch.bfloat16 and is_torch_version("<", "2.1"): - hidden_states = hidden_states.to(torch.float32) - if hidden_states.shape[0] >= 64: hidden_states = hidden_states.contiguous() if output_size is None: - B, C, D, H, W = hidden_states.shape - if B*C*D*H*W*8 > 2_100_000_000: - x1 = F.interpolate(hidden_states[:, :, :D//2], scale_factor=2, mode="nearest") - x2 = F.interpolate(hidden_states[:, :, D//2:], scale_factor=2, mode="nearest") - hidden_states = torch.cat([x1, x2], dim=2) - else: - hidden_states = F.interpolate(hidden_states, scale_factor=2, mode="nearest") - # hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") else: hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") - # Cast back to original dtype - if dtype == torch.bfloat16 and is_torch_version("<", "2.1"): - hidden_states = hidden_states.to(dtype) - hidden_states = self.conv(hidden_states) return hidden_states @@ -349,8 +333,7 @@ def __init__( if attn_groups is None: attn_groups = resnet_groups - # there is always at least one resnet - resnets = [ + self.resnets = nn.ModuleList([ ResnetBlock3D( in_channels=in_channels, out_channels=in_channels, @@ -359,27 +342,12 @@ def __init__( dropout=dropout, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, - ), - ] - - for _ in range(num_layers): - resnets.append( - ResnetBlock3D( - in_channels=in_channels, - out_channels=in_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - ) ) - - self.resnets = nn.ModuleList(resnets) + for _ in range(num_layers + 1) + ]) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.resnets[0](hidden_states) - for resnet in self.resnets[1:]: + for resnet in self.resnets: hidden_states = resnet(hidden_states) return hidden_states \ No newline at end of file diff --git a/applications/dynacell/vae_3d/modules/decoder.py b/applications/dynacell/vae_3d/modules/decoder.py index a1887f0b1..19ff8725b 100644 --- a/applications/dynacell/vae_3d/modules/decoder.py +++ b/applications/dynacell/vae_3d/modules/decoder.py @@ -4,14 +4,9 @@ import torch.nn as nn from .blocks import UNetMidBlock3D, UpDecoderBlock3D - -from diffusers.utils import is_torch_version from 
diffusers.models.attention_processor import SpatialNorm class Decoder(nn.Module): - r""" - The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. - """ def __init__( self, @@ -87,53 +82,14 @@ def __init__( self.gradient_checkpointing = False def forward(self, sample: torch.Tensor) -> torch.Tensor: - r"""The forward method of the `Decoder` class.""" - sample = self.conv_in(sample) - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), - sample, - use_reentrant=False, - ) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), - sample, - use_reentrant=False, - ) - else: - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, - ) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) - else: - # middle - sample = self.mid_block(sample) - sample = sample.to(upscale_dtype) + # middle + sample = self.mid_block(sample) - # up - for up_block in self.up_blocks: - sample = up_block(sample) + # up + for up_block in self.up_blocks: + sample = up_block(sample) sample = self.conv_norm_out(sample) sample = self.conv_act(sample) diff --git a/applications/dynacell/vae_3d/modules/encoder.py b/applications/dynacell/vae_3d/modules/encoder.py index 31b8a70cf..bc93f857b 100644 --- a/applications/dynacell/vae_3d/modules/encoder.py +++ b/applications/dynacell/vae_3d/modules/encoder.py @@ -1,51 +1,10 @@ -from dataclasses import dataclass -from typing import Optional, Tuple - import torch import torch.nn as nn +from typing import Tuple from .blocks import DownEncoderBlock3D, UNetMidBlock3D -from diffusers.utils import BaseOutput, is_torch_version - - -@dataclass -class DecoderOutput(BaseOutput): - r""" - Output of decoding method. - - Args: - sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): - The decoded output sample from the last layer of the model. - """ - - sample: torch.Tensor - commit_loss: Optional[torch.FloatTensor] = None - class Encoder(nn.Module): - r""" - The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): - The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available - options. - block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - act_fn (`str`, *optional*, defaults to `"silu"`): - The activation function to use. See `~diffusers.models.activations.get_activation` for available options. 
- double_z (`bool`, *optional*, defaults to `True`): - Whether to double the number of output channels for the last block. - """ - def __init__( self, in_channels: int = 3, @@ -111,47 +70,18 @@ def __init__( self.gradient_checkpointing = False def forward(self, sample: torch.Tensor) -> torch.Tensor: - r"""The forward method of the `Encoder` class.""" - sample = self.conv_in(sample) + + # down + for down_block in self.down_blocks: + sample = down_block(sample) - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - # down - if is_torch_version(">=", "1.11.0"): - for down_block in self.down_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), sample, use_reentrant=False - ) - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, use_reentrant=False - ) - else: - for down_block in self.down_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) - # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) - - else: - # down - for down_block in self.down_blocks: - sample = down_block(sample) - - # middle - sample = self.mid_block(sample) + # middle + sample = self.mid_block(sample) # post-process sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) - return sample - - + return sample \ No newline at end of file diff --git a/applications/dynacell/vae_3d/vae_3d_model_ts.py b/applications/dynacell/vae_3d/vae_3d_model_ts.py new file mode 100644 index 000000000..92a53ac24 --- /dev/null +++ b/applications/dynacell/vae_3d/vae_3d_model_ts.py @@ -0,0 +1,90 @@ +import os +import torch +import torch.nn as nn +from .modules.autoencoders_ts import Autoencoder3DKL +from .vae_3d_config import VAE3DConfig + + +class VAE3DModel(nn.Module): + def __init__(self, config: VAE3DConfig): + super().__init__() + self.config = config + + self.num_down_blocks = config.num_down_blocks + self.num_up_blocks = self.num_down_blocks + + # Initialize Autoencoder3DKL + self.vae = Autoencoder3DKL( + in_channels=config.in_channels, + out_channels=config.out_channels, + num_down_blocks=self.num_down_blocks, + num_up_blocks=self.num_up_blocks, + block_out_channels=config.vae_block_out_channels, + latent_channels=config.latent_channels, + ) + + self.load_pretrained_weights(checkpoint_path=config.loadcheck_path) + + def load_pretrained_weights(self, checkpoint_path): + """ + Load pretrained weights from a given state_dict. 
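+ Supports both '.safetensors' files and standard torch.load checkpoints; if the
+ loaded dict is nested under a "model" or "module" key, that sub-dict is used.
+ Keys containing "dummy" are excluded from the missing/unexpected key report.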
+ """ + + if os.path.splitext(checkpoint_path)[1] == '.safetensors': + from safetensors.torch import load_file + checkpoints_state = load_file(checkpoint_path) + else: + checkpoints_state = torch.load(checkpoint_path, map_location="cpu") + + if "model" in checkpoints_state: + checkpoints_state = checkpoints_state["model"] + elif "module" in checkpoints_state: + checkpoints_state = checkpoints_state["module"] + + IncompatibleKeys = self.load_state_dict(checkpoints_state, strict=True) + IncompatibleKeys = IncompatibleKeys._asdict() + + missing_keys = [] + for keys in IncompatibleKeys["missing_keys"]: + if keys.find("dummy") == -1: + missing_keys.append(keys) + + unexpected_keys = [] + for keys in IncompatibleKeys["unexpected_keys"]: + if keys.find("dummy") == -1: + unexpected_keys.append(keys) + + if len(missing_keys) > 0: + print( + "Missing keys in {}: {}".format( + checkpoint_path, + missing_keys, + ) + ) + + if len(unexpected_keys) > 0: + print( + "Unexpected keys {}: {}".format( + checkpoint_path, + unexpected_keys, + ) + ) + + def encode(self, x): + """Encodes input into latent space.""" + return self.vae.encode(x) + + def decode(self, latents): + """Decodes latent space into reconstructed input.""" + return self.vae.decode(latents) + + def forward(self, x): + # placeholder forward + return x + + def reconstruct(self, x): + mean, logvar = self.encode(x) + latents = mean + torch.exp(0.5 * logvar) * torch.randn_like(logvar) # Reparameterization trick + recon_x = self.decode(latents) + + return recon_x From 0f013543367e3a2cf619d179235a6f20e9237406 Mon Sep 17 00:00:00 2001 From: "dihan.zheng" Date: Wed, 17 Sep 2025 12:09:09 -0700 Subject: [PATCH 04/10] update fid script --- applications/dynacell/fid.py | 66 ++++++++++++---------------- applications/dynacell/fid_ts.py | 66 ++++++++++++---------------- applications/dynacell/test_fid.sh | 12 ++--- applications/dynacell/test_fid_ts.sh | 12 ++--- 4 files changed, 68 insertions(+), 88 deletions(-) diff --git a/applications/dynacell/fid.py b/applications/dynacell/fid.py index 577c65346..d53392dc4 100644 --- a/applications/dynacell/fid.py +++ b/applications/dynacell/fid.py @@ -100,10 +100,9 @@ def fid_from_features(f1, f2, eps=1e-6): @torch.no_grad() def encode_fovs( - fov_pairs, + fovs, vae, - channel_name1: str, - channel_name2: str, + channel_name: str, device: str = "cuda", batch_size: int = 4, input_spatial_size: tuple = (32, 512, 512), @@ -115,47 +114,33 @@ def encode_fovs( • feed through VAE in chunks of ≤ batch_size frames • average the resulting T latent vectors → one embedding / FOV Returns - emb1, emb2 : (N, latent_dim) tensors + emb: (N, latent_dim) tensors """ - emb1, emb2 = [], [] + emb = [] - for pos1, pos2 in tqdm(fov_pairs, desc="Encoding FOVs"): + for pos in tqdm(fovs, desc="Encoding FOVs"): # ---------------- load & normalise ---------------- # - v1 = torch.as_tensor( - pos1.data[:, pos1.get_channel_index(channel_name1)], + v = torch.as_tensor( + pos.data[:, pos.get_channel_index(channel_name)], dtype=torch.float32, device=device, ) # (T, D, H, W) - v2 = torch.as_tensor( - pos2.data[:, pos2.get_channel_index(channel_name2)], - dtype=torch.float32, device=device, - ) - v1 = normalise(v1) # still (T, D, H, W) - v2 = normalise(v2) + v = normalise(v) # still (T, D, H, W) # ---------------- chunked VAE inference ----------- # - for t0 in range(0, v1.shape[0], batch_size): - slice1 = v1[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) - slice2 = v2[t0 : t0 + batch_size].unsqueeze(1) + for t0 in range(0, v.shape[0], batch_size): + 
slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) # resize to input spatial size - slice1 = torch.nn.functional.interpolate( - slice1, size=input_spatial_size, mode="trilinear", align_corners=False, - ) # (b, 1, D, H, W) - slice2 = torch.nn.functional.interpolate( - slice2, size=input_spatial_size, mode="trilinear", align_corners=False, + slice = torch.nn.functional.interpolate( + slice, size=input_spatial_size, mode="trilinear", align_corners=False, ) # (b, 1, D, H, W) - feat1 = vae.encode(slice1).mean - feat2 = vae.encode(slice2).mean - - feat1 = feat1.flatten(start_dim=1) # (b, latent_dim) - feat2 = feat2.flatten(start_dim=1) # (b, latent_dim) - - emb1.append(feat1) - emb2.append(feat2) + feat = vae.encode(slice).mean # mean, + feat = feat.flatten(start_dim=1) # (b, latent_dim) + emb.append(feat) - return torch.cat(emb1, 0), torch.cat(emb2, 0) + return torch.cat(emb, 0) # ----------------------------------------------------------------------------- # # Main # @@ -188,21 +173,26 @@ def main(args) -> None: # ----------------- FOV list ------------ # fovs1, fovs2 = read_zarr(args.data_path1), read_zarr(args.data_path2) - n = min(len(fovs1), len(fovs2)) if args.max_fov: - n = min(n, args.max_fov) - pair_list = list(zip(fovs1[:n], fovs2[:n])) + fovs1 = fovs1[:args.max_fov] + fovs2 = fovs2[:args.max_fov] # ----------------- Embeddings ----------- # input_spatial_size = [int(dim) for dim in args.input_spatial_size.split(",")] if args.channel_name is not None: args.channel_name1 = args.channel_name2 = args.channel_name + + emb1 = encode_fovs( + fovs1, vae, + args.channel_name1, + device, args.batch_size, + input_spatial_size, + ) - emb1, emb2 = encode_fovs( - pair_list, vae, - args.channel_name1, - args.channel_name2, + emb2 = encode_fovs( + fovs2, vae, + args.channel_name2, device, args.batch_size, input_spatial_size, ) diff --git a/applications/dynacell/fid_ts.py b/applications/dynacell/fid_ts.py index 0bbddf823..8c8f272cf 100644 --- a/applications/dynacell/fid_ts.py +++ b/applications/dynacell/fid_ts.py @@ -97,10 +97,9 @@ def fid_from_features(f1, f2, eps=1e-6): @torch.no_grad() def encode_fovs( - fov_pairs, + fovs, vae, - channel_name1: str, - channel_name2: str, + channel_name: str, device: str = "cuda", batch_size: int = 4, input_spatial_size: tuple = (32, 512, 512), @@ -112,47 +111,33 @@ def encode_fovs( • feed through VAE in chunks of ≤ batch_size frames • average the resulting T latent vectors → one embedding / FOV Returns - emb1, emb2 : (N, latent_dim) tensors + emb: (N, latent_dim) tensors """ - emb1, emb2 = [], [] + emb = [] - for pos1, pos2 in tqdm(fov_pairs, desc="Encoding FOVs"): + for pos in tqdm(fovs, desc="Encoding FOVs"): # ---------------- load & normalise ---------------- # - v1 = torch.as_tensor( - pos1.data[:, pos1.get_channel_index(channel_name1)], + v = torch.as_tensor( + pos.data[:, pos.get_channel_index(channel_name)], dtype=torch.float32, device=device, ) # (T, D, H, W) - v2 = torch.as_tensor( - pos2.data[:, pos2.get_channel_index(channel_name2)], - dtype=torch.float32, device=device, - ) - v1 = normalise(v1) # still (T, D, H, W) - v2 = normalise(v2) + v = normalise(v) # still (T, D, H, W) # ---------------- chunked VAE inference ----------- # - for t0 in range(0, v1.shape[0], batch_size): - slice1 = v1[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) - slice2 = v2[t0 : t0 + batch_size].unsqueeze(1) + for t0 in range(0, v.shape[0], batch_size): + slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) # resize to input spatial size - slice1 = 
torch.nn.functional.interpolate( - slice1, size=input_spatial_size, mode="trilinear", align_corners=False, - ) # (b, 1, D, H, W) - slice2 = torch.nn.functional.interpolate( - slice2, size=input_spatial_size, mode="trilinear", align_corners=False, + slice = torch.nn.functional.interpolate( + slice, size=input_spatial_size, mode="trilinear", align_corners=False, ) # (b, 1, D, H, W) - feat1 = vae.encode(slice1)[0] # mean - feat2 = vae.encode(slice2)[0] # mean - - feat1 = feat1.flatten(start_dim=1) # (b, latent_dim) - feat2 = feat2.flatten(start_dim=1) # (b, latent_dim) - - emb1.append(feat1) - emb2.append(feat2) + feat = vae.encode(slice)[0] # mean, + feat = feat.flatten(start_dim=1) # (b, latent_dim) + emb.append(feat) - return torch.cat(emb1, 0), torch.cat(emb2, 0) + return torch.cat(emb, 0) # ----------------------------------------------------------------------------- # # Main # @@ -184,21 +169,26 @@ def main(args) -> None: # ----------------- FOV list ------------ # fovs1, fovs2 = read_zarr(args.data_path1), read_zarr(args.data_path2) - n = min(len(fovs1), len(fovs2)) if args.max_fov: - n = min(n, args.max_fov) - pair_list = list(zip(fovs1[:n], fovs2[:n])) + fovs1 = fovs1[:args.max_fov] + fovs2 = fovs2[:args.max_fov] # ----------------- Embeddings ----------- # input_spatial_size = [int(dim) for dim in args.input_spatial_size.split(",")] if args.channel_name is not None: args.channel_name1 = args.channel_name2 = args.channel_name + + emb1 = encode_fovs( + fovs1, vae, + args.channel_name1, + device, args.batch_size, + input_spatial_size, + ) - emb1, emb2 = encode_fovs( - pair_list, vae, - args.channel_name1, - args.channel_name2, + emb2 = encode_fovs( + fovs2, vae, + args.channel_name2, device, args.batch_size, input_spatial_size, ) diff --git a/applications/dynacell/test_fid.sh b/applications/dynacell/test_fid.sh index 9e8adea15..84c330e5d 100644 --- a/applications/dynacell/test_fid.sh +++ b/applications/dynacell/test_fid.sh @@ -1,17 +1,17 @@ python fid.py \ - --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ - --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ --channel_name1 Nuclei-prediction \ - --channel_name2 Nuclei-prediction \ + --channel_name2 Organelle \ --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae.pth \ --batch_size 4 \ --device cuda python fid.py \ - --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ - --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ --channel_name1 Membrane-prediction \ - --channel_name2 Membrane-prediction \ + --channel_name2 Membrane \ --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae.pth \ --batch_size 4 \ --device cuda \ No newline at end of file diff --git a/applications/dynacell/test_fid_ts.sh b/applications/dynacell/test_fid_ts.sh index d98e5047f..012cb23b1 100644 --- a/applications/dynacell/test_fid_ts.sh +++ b/applications/dynacell/test_fid_ts.sh @@ -1,17 +1,17 @@ python fid_ts.py \ - --data_path1 
/hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ - --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ --channel_name1 Nuclei-prediction \ - --channel_name2 Nuclei-prediction \ + --channel_name2 Organelle \ --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae_ts.pt \ --batch_size 4 \ --device cuda python fid_ts.py \ - --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ - --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/prediction/a549/output.zarr \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ --channel_name1 Membrane-prediction \ - --channel_name2 Membrane-prediction \ + --channel_name2 Membrane \ --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae_ts.pt \ --batch_size 4 \ --device cuda \ No newline at end of file From 593c6dffd786ef1196c6b6f89d85bfc965731912 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 18 Sep 2025 16:59:19 -0700 Subject: [PATCH 05/10] moving files to benchmar/Dynacell --- applications/benchmarking/DynaCell/fid.py | 207 ++++++++++ applications/benchmarking/DynaCell/fid_ts.py | 203 ++++++++++ .../benchmarking/DynaCell/test_fid.sh | 17 + .../benchmarking/DynaCell/test_fid_ts.sh | 17 + .../benchmarking/DynaCell/vae_3d/__init__.py | 0 .../DynaCell/vae_3d/modules/__init__.py | 0 .../DynaCell/vae_3d/modules/autoencoders.py | 160 ++++++++ .../vae_3d/modules/autoencoders_ts.py | 82 ++++ .../DynaCell/vae_3d/modules/blocks.py | 353 ++++++++++++++++++ .../DynaCell/vae_3d/modules/decoder.py | 98 +++++ .../DynaCell/vae_3d/modules/encoder.py | 87 +++++ .../DynaCell/vae_3d/modules/utils.py | 9 + .../DynaCell/vae_3d/vae_3d_config.py | 16 + .../DynaCell/vae_3d/vae_3d_model.py | 138 +++++++ .../DynaCell/vae_3d/vae_3d_model_ts.py | 90 +++++ 15 files changed, 1477 insertions(+) create mode 100644 applications/benchmarking/DynaCell/fid.py create mode 100644 applications/benchmarking/DynaCell/fid_ts.py create mode 100644 applications/benchmarking/DynaCell/test_fid.sh create mode 100644 applications/benchmarking/DynaCell/test_fid_ts.sh create mode 100644 applications/benchmarking/DynaCell/vae_3d/__init__.py create mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/__init__.py create mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/autoencoders.py create mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/autoencoders_ts.py create mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/blocks.py create mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/decoder.py create mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/encoder.py create mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/utils.py create mode 100644 applications/benchmarking/DynaCell/vae_3d/vae_3d_config.py create mode 100644 applications/benchmarking/DynaCell/vae_3d/vae_3d_model.py create mode 100644 applications/benchmarking/DynaCell/vae_3d/vae_3d_model_ts.py diff --git a/applications/benchmarking/DynaCell/fid.py b/applications/benchmarking/DynaCell/fid.py new file mode 100644 index 000000000..d53392dc4 --- /dev/null +++ 
b/applications/benchmarking/DynaCell/fid.py @@ -0,0 +1,207 @@ +# -*- coding: utf-8 -*- +import argparse +from pathlib import Path + +import torch +from tqdm import tqdm +from iohub.ngff import open_ome_zarr +from torch import Tensor + +from vae_3d.vae_3d_config import VAE3DConfig +from vae_3d.vae_3d_model import VAE3DModel + +# ----------------------------------------------------------------------------- # +# Helper functions # +# ----------------------------------------------------------------------------- # + +def read_zarr(zarr_path: str): + plate = open_ome_zarr(zarr_path, mode="r") + return [pos for _, pos in plate.positions()] + +def normalise(volume: torch.Tensor) -> torch.Tensor: + """Per-sample min max → [-1,1]. Shape: (D, H, W) or (B, D, H, W).""" + v_min = volume.amin(dim=(-3, -2, -1), keepdim=True) + v_max = volume.amax(dim=(-3, -2, -1), keepdim=True) + volume = (volume - v_min) / (v_max - v_min + 1e-6) # → [0,1] + return volume * 2.0 - 1.0 # → [-1,1] + +@torch.jit.script_if_tracing +def sqrtm(sigma: Tensor) -> Tensor: + r"""Returns the square root of a positive semi-definite matrix. + + .. math:: \sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T + + where :math:`Q \Lambda Q^T` is the eigendecomposition of :math:`\Sigma`. + + Args: + sigma: A positive semi-definite matrix, :math:`(*, D, D)`. + + Example: + >>> V = torch.randn(4, 4, dtype=torch.double) + >>> A = V @ V.T + >>> B = sqrtm(A @ A) + >>> torch.allclose(A, B) + True + """ + + L, Q = torch.linalg.eigh(sigma) + L = L.relu().sqrt() + + return Q @ (L[..., None] * Q.mT) + +@torch.jit.script_if_tracing +def frechet_distance( + mu_x: Tensor, + sigma_x: Tensor, + mu_y: Tensor, + sigma_y: Tensor, +) -> Tensor: + r"""Returns the Fréchet distance between two multivariate Gaussian distributions. + + .. math:: d^2 = \left\| \mu_x - \mu_y \right\|_2^2 + + \operatorname{tr} \left( \Sigma_x + \Sigma_y - 2 \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right) + + Wikipedia: + https://wikipedia.org/wiki/Frechet_distance + + Args: + mu_x: The mean :math:`\mu_x` of the first distribution, :math:`(*, D)`. + sigma_x: The covariance :math:`\Sigma_x` of the first distribution, :math:`(*, D, D)`. + mu_y: The mean :math:`\mu_y` of the second distribution, :math:`(*, D)`. + sigma_y: The covariance :math:`\Sigma_y` of the second distribution, :math:`(*, D, D)`. 
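+ Returns:
+ The squared Fréchet distance :math:`d^2` between the two distributions.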
+ + Example: + >>> mu_x = torch.arange(3).float() + >>> sigma_x = torch.eye(3) + >>> mu_y = 2 * mu_x + 1 + >>> sigma_y = 2 * sigma_x + 1 + >>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y) + tensor(15.8710) + """ + + sigma_y_12 = sqrtm(sigma_y) + + a = (mu_x - mu_y).square().sum(dim=-1) + b = sigma_x.trace() + sigma_y.trace() + c = sqrtm(sigma_y_12 @ sigma_x @ sigma_y_12).trace() + + return a + b - 2 * c + +@torch.no_grad() +def fid_from_features(f1, f2, eps=1e-6): + mu1, sigma1 = f1.mean(0), torch.cov(f1.T) + mu2, sigma2 = f2.mean(0), torch.cov(f2.T) + + eye = torch.eye(sigma1.size(0), device=sigma1.device, dtype=sigma1.dtype) + sigma1 = sigma1 + eps * eye + sigma2 = sigma2 + eps * eye + + return frechet_distance(mu1, sigma1, mu2, sigma2).clamp_min_(0).item() + +@torch.no_grad() +def encode_fovs( + fovs, + vae, + channel_name: str, + device: str = "cuda", + batch_size: int = 4, + input_spatial_size: tuple = (32, 512, 512), +): + """ + For each FOV pair: + • take all T time-frames (shape: T, D, H, W) + • normalise to [-1, 1] + • feed through VAE in chunks of ≤ batch_size frames + • average the resulting T latent vectors → one embedding / FOV + Returns + emb: (N, latent_dim) tensors + """ + emb = [] + + for pos in tqdm(fovs, desc="Encoding FOVs"): + # ---------------- load & normalise ---------------- # + v = torch.as_tensor( + pos.data[:, pos.get_channel_index(channel_name)], + dtype=torch.float32, device=device, + ) # (T, D, H, W) + + v = normalise(v) # still (T, D, H, W) + + # ---------------- chunked VAE inference ----------- # + for t0 in range(0, v.shape[0], batch_size): + slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) + + # resize to input spatial size + slice = torch.nn.functional.interpolate( + slice, size=input_spatial_size, mode="trilinear", align_corners=False, + ) # (b, 1, D, H, W) + + feat = vae.encode(slice).mean # mean, + feat = feat.flatten(start_dim=1) # (b, latent_dim) + emb.append(feat) + + return torch.cat(emb, 0) + +# ----------------------------------------------------------------------------- # +# Main # +# ----------------------------------------------------------------------------- # + +def build_argparser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(add_help=False) + p.add_argument("--data_path1", type=Path, required=True) + p.add_argument("--data_path2", type=Path, required=True) + p.add_argument("--channel_name", type=str, default=None) + p.add_argument("--channel_name1", type=str, default=None) + p.add_argument("--channel_name2", type=str, default=None) + p.add_argument("--input_spatial_size", type=str, default="32,512,512", + help="Input spatial size for the VAE, e.g. 
'32,512,512'.") + p.add_argument("--loadcheck_path", type=Path, default=None, + help="Path to the VAE model checkpoint for loading.") + p.add_argument("--batch_size", type=int, default=4) + p.add_argument("--device", type=str, default="cuda") + p.add_argument("--max_fov", type=int, default=None, + help="Limit number of FOV pairs (for quick tests).") + return p + +def main(args) -> None: + device = args.device + + # ----------------- VAE ----------------- # + model_cfg = VAE3DConfig() + model_cfg.loadcheck_path = args.loadcheck_path + vae = VAE3DModel(config=model_cfg).to(device).eval() + + # ----------------- FOV list ------------ # + fovs1, fovs2 = read_zarr(args.data_path1), read_zarr(args.data_path2) + if args.max_fov: + fovs1 = fovs1[:args.max_fov] + fovs2 = fovs2[:args.max_fov] + + # ----------------- Embeddings ----------- # + input_spatial_size = [int(dim) for dim in args.input_spatial_size.split(",")] + + if args.channel_name is not None: + args.channel_name1 = args.channel_name2 = args.channel_name + + emb1 = encode_fovs( + fovs1, vae, + args.channel_name1, + device, args.batch_size, + input_spatial_size, + ) + + emb2 = encode_fovs( + fovs2, vae, + args.channel_name2, + device, args.batch_size, + input_spatial_size, + ) + + # ----------------- FID ------------------ # + fid_val = fid_from_features(emb1, emb2) + print(f"\nFID: {fid_val:.6f}") + +if __name__ == "__main__": + parser = build_argparser() + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/fid_ts.py b/applications/benchmarking/DynaCell/fid_ts.py new file mode 100644 index 000000000..8c8f272cf --- /dev/null +++ b/applications/benchmarking/DynaCell/fid_ts.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +import argparse +from pathlib import Path + +import torch +from tqdm import tqdm +from iohub.ngff import open_ome_zarr +from torch import Tensor + +# ----------------------------------------------------------------------------- # +# Helper functions # +# ----------------------------------------------------------------------------- # + +def read_zarr(zarr_path: str): + plate = open_ome_zarr(zarr_path, mode="r") + return [pos for _, pos in plate.positions()] + +def normalise(volume: torch.Tensor) -> torch.Tensor: + """Per-sample min max → [-1,1]. Shape: (D, H, W) or (B, D, H, W).""" + v_min = volume.amin(dim=(-3, -2, -1), keepdim=True) + v_max = volume.amax(dim=(-3, -2, -1), keepdim=True) + volume = (volume - v_min) / (v_max - v_min + 1e-6) # → [0,1] + return volume * 2.0 - 1.0 # → [-1,1] + +@torch.jit.script_if_tracing +def sqrtm(sigma: Tensor) -> Tensor: + r"""Returns the square root of a positive semi-definite matrix. + + .. math:: \sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T + + where :math:`Q \Lambda Q^T` is the eigendecomposition of :math:`\Sigma`. + + Args: + sigma: A positive semi-definite matrix, :math:`(*, D, D)`. + + Example: + >>> V = torch.randn(4, 4, dtype=torch.double) + >>> A = V @ V.T + >>> B = sqrtm(A @ A) + >>> torch.allclose(A, B) + True + """ + + L, Q = torch.linalg.eigh(sigma) + L = L.relu().sqrt() + + return Q @ (L[..., None] * Q.mT) + +@torch.jit.script_if_tracing +def frechet_distance( + mu_x: Tensor, + sigma_x: Tensor, + mu_y: Tensor, + sigma_y: Tensor, +) -> Tensor: + r"""Returns the Fréchet distance between two multivariate Gaussian distributions. + + .. 
math:: d^2 = \left\| \mu_x - \mu_y \right\|_2^2 + + \operatorname{tr} \left( \Sigma_x + \Sigma_y - 2 \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right) + + Wikipedia: + https://wikipedia.org/wiki/Frechet_distance + + Args: + mu_x: The mean :math:`\mu_x` of the first distribution, :math:`(*, D)`. + sigma_x: The covariance :math:`\Sigma_x` of the first distribution, :math:`(*, D, D)`. + mu_y: The mean :math:`\mu_y` of the second distribution, :math:`(*, D)`. + sigma_y: The covariance :math:`\Sigma_y` of the second distribution, :math:`(*, D, D)`. + + Example: + >>> mu_x = torch.arange(3).float() + >>> sigma_x = torch.eye(3) + >>> mu_y = 2 * mu_x + 1 + >>> sigma_y = 2 * sigma_x + 1 + >>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y) + tensor(15.8710) + """ + + sigma_y_12 = sqrtm(sigma_y) + + a = (mu_x - mu_y).square().sum(dim=-1) + b = sigma_x.trace() + sigma_y.trace() + c = sqrtm(sigma_y_12 @ sigma_x @ sigma_y_12).trace() + + return a + b - 2 * c + +@torch.no_grad() +def fid_from_features(f1, f2, eps=1e-6): + mu1, sigma1 = f1.mean(0), torch.cov(f1.T) + mu2, sigma2 = f2.mean(0), torch.cov(f2.T) + + eye = torch.eye(sigma1.size(0), device=sigma1.device, dtype=sigma1.dtype) + sigma1 = sigma1 + eps * eye + sigma2 = sigma2 + eps * eye + + return frechet_distance(mu1, sigma1, mu2, sigma2).clamp_min_(0).item() + +@torch.no_grad() +def encode_fovs( + fovs, + vae, + channel_name: str, + device: str = "cuda", + batch_size: int = 4, + input_spatial_size: tuple = (32, 512, 512), +): + """ + For each FOV pair: + • take all T time-frames (shape: T, D, H, W) + • normalise to [-1, 1] + • feed through VAE in chunks of ≤ batch_size frames + • average the resulting T latent vectors → one embedding / FOV + Returns + emb: (N, latent_dim) tensors + """ + emb = [] + + for pos in tqdm(fovs, desc="Encoding FOVs"): + # ---------------- load & normalise ---------------- # + v = torch.as_tensor( + pos.data[:, pos.get_channel_index(channel_name)], + dtype=torch.float32, device=device, + ) # (T, D, H, W) + + v = normalise(v) # still (T, D, H, W) + + # ---------------- chunked VAE inference ----------- # + for t0 in range(0, v.shape[0], batch_size): + slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) + + # resize to input spatial size + slice = torch.nn.functional.interpolate( + slice, size=input_spatial_size, mode="trilinear", align_corners=False, + ) # (b, 1, D, H, W) + + feat = vae.encode(slice)[0] # mean, + feat = feat.flatten(start_dim=1) # (b, latent_dim) + emb.append(feat) + + return torch.cat(emb, 0) + +# ----------------------------------------------------------------------------- # +# Main # +# ----------------------------------------------------------------------------- # + +def build_argparser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(add_help=False) + p.add_argument("--data_path1", type=Path, required=True) + p.add_argument("--data_path2", type=Path, required=True) + p.add_argument("--channel_name", type=str, default=None) + p.add_argument("--channel_name1", type=str, default=None) + p.add_argument("--channel_name2", type=str, default=None) + p.add_argument("--input_spatial_size", type=str, default="32,512,512", + help="Input spatial size for the VAE, e.g. 
'32,512,512'.") + p.add_argument("--loadcheck_path", type=Path, default=None, + help="Path to the VAE model checkpoint for loading.") + p.add_argument("--batch_size", type=int, default=4) + p.add_argument("--device", type=str, default="cuda") + p.add_argument("--max_fov", type=int, default=None, + help="Limit number of FOV pairs (for quick tests).") + return p + +def main(args) -> None: + device = args.device + + # ----------------- VAE ----------------- # + vae = torch.jit.load(args.loadcheck_path).to(device) + vae.eval() + + # ----------------- FOV list ------------ # + fovs1, fovs2 = read_zarr(args.data_path1), read_zarr(args.data_path2) + if args.max_fov: + fovs1 = fovs1[:args.max_fov] + fovs2 = fovs2[:args.max_fov] + + # ----------------- Embeddings ----------- # + input_spatial_size = [int(dim) for dim in args.input_spatial_size.split(",")] + + if args.channel_name is not None: + args.channel_name1 = args.channel_name2 = args.channel_name + + emb1 = encode_fovs( + fovs1, vae, + args.channel_name1, + device, args.batch_size, + input_spatial_size, + ) + + emb2 = encode_fovs( + fovs2, vae, + args.channel_name2, + device, args.batch_size, + input_spatial_size, + ) + + # ----------------- FID ------------------ # + fid_val = fid_from_features(emb1, emb2) + print(f"\nFID: {fid_val:.6f}") + +if __name__ == "__main__": + parser = build_argparser() + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/test_fid.sh b/applications/benchmarking/DynaCell/test_fid.sh new file mode 100644 index 000000000..84c330e5d --- /dev/null +++ b/applications/benchmarking/DynaCell/test_fid.sh @@ -0,0 +1,17 @@ +python fid.py \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --channel_name1 Nuclei-prediction \ + --channel_name2 Organelle \ + --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae.pth \ + --batch_size 4 \ + --device cuda + +python fid.py \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --channel_name1 Membrane-prediction \ + --channel_name2 Membrane \ + --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae.pth \ + --batch_size 4 \ + --device cuda \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/test_fid_ts.sh b/applications/benchmarking/DynaCell/test_fid_ts.sh new file mode 100644 index 000000000..012cb23b1 --- /dev/null +++ b/applications/benchmarking/DynaCell/test_fid_ts.sh @@ -0,0 +1,17 @@ +python fid_ts.py \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --channel_name1 Nuclei-prediction \ + --channel_name2 Organelle \ + --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae_ts.pt \ + --batch_size 4 \ + --device cuda + +python fid_ts.py \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --channel_name1 Membrane-prediction \ + --channel_name2 Membrane \ + --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae_ts.pt \ + --batch_size 4 \ + 
--device cuda \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/vae_3d/__init__.py b/applications/benchmarking/DynaCell/vae_3d/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/__init__.py b/applications/benchmarking/DynaCell/vae_3d/modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/autoencoders.py b/applications/benchmarking/DynaCell/vae_3d/modules/autoencoders.py new file mode 100644 index 000000000..9c3fb927a --- /dev/null +++ b/applications/benchmarking/DynaCell/vae_3d/modules/autoencoders.py @@ -0,0 +1,160 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from .decoder import Decoder +from .encoder import Encoder + +from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin + + +class Autoencoder3DKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + num_down_blocks: int = 2, + num_up_blocks: int = 2, + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + num_down_blocks=num_down_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + num_up_blocks=num_up_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, Decoder)): + module.gradient_checkpointing = value + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
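+ Note: the raw encoder output stacks mean and log-variance along the channel
+ axis (2 * latent_channels); `DiagonalGaussianDistribution` splits them internally.
+
+ Example (toy configuration, shown only to illustrate the latent shape):
+ >>> vae = Autoencoder3DKL(in_channels=1, out_channels=1, num_down_blocks=2,
+ ... num_up_blocks=2, block_out_channels=(32, 64), latent_channels=2)
+ >>> x = torch.randn(1, 1, 8, 32, 32)
+ >>> vae.encode(x).latent_dist.mean.shape
+ torch.Size([1, 2, 4, 16, 16])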
+ """ + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/autoencoders_ts.py b/applications/benchmarking/DynaCell/vae_3d/modules/autoencoders_ts.py new file mode 100644 index 000000000..9438ddd6d --- /dev/null +++ b/applications/benchmarking/DynaCell/vae_3d/modules/autoencoders_ts.py @@ -0,0 +1,82 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +from .decoder import Decoder +from .encoder import Encoder + + +class Autoencoder3DKL(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + num_down_blocks: int = 2, + num_up_blocks: int = 2, + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + num_down_blocks=num_down_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + num_up_blocks=num_up_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + def 
_encode(self, x: torch.Tensor) -> torch.Tensor: + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + def encode(self, x: torch.Tensor): + h = self._encode(x) + mean, logvar = torch.chunk(h, 2, dim=1) + + return mean, logvar + + def _decode(self, z: torch.Tensor): + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + dec = self.decoder(z) + + return dec + + def decode(self, z: torch.FloatTensor): + decoded = self._decode(z) + + return decoded + + def forward(self, x): + # placeholder forward + return x \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/blocks.py b/applications/benchmarking/DynaCell/vae_3d/modules/blocks.py new file mode 100644 index 000000000..569d66e9a --- /dev/null +++ b/applications/benchmarking/DynaCell/vae_3d/modules/blocks.py @@ -0,0 +1,353 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn +from diffusers.models.normalization import RMSNorm +from diffusers.models.activations import get_activation + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class DownEncoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels: 
int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + groups: int = 32, + groups_out: Optional[int] = None, + eps: float = 1e-6, + non_linearity: str = "swish", + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + conv_shortcut_bias: bool = True, + ): + super().__init__() + + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + conv_3d_out_channels = out_channels + self.conv2 = nn.Conv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1, padding=1) + + self.nonlinearity = get_activation(non_linearity) + self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv3d( + in_channels, + conv_3d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + +class Downsample3D(nn.Module): + """A 3D downsampling layer with an optional convolution. 
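+ Args:
+ channels: Number of input channels.
+ use_conv: If True, downsample with a stride-2 Conv3d; otherwise use AvgPool3d
+ (in which case out_channels must equal channels).
+ out_channels: Number of output channels (defaults to channels).
+ padding: Convolution padding; when 0, the input is padded by one voxel at the
+ end of each spatial dimension before the convolution.
+ kernel_size: Convolution kernel size.
+ norm_type: Optional channel-last normalization applied before downsampling
+ ("ln_norm" or "rms_norm").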
+ """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + kernel_size: int = 3, + norm_type: Optional[str] = None, + eps: Optional[float] = 1e-5, + elementwise_affine: Optional[bool] = True, + bias: bool = True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + self.kernel_size = kernel_size + stride = 2 # Downsampling stride is fixed to 2 + + # Initialize normalization + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(self.channels, eps=eps, elementwise_affine=elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"Unknown norm_type: {norm_type}") + + # Choose between convolutional or pooling downsampling + if use_conv: + self.conv = nn.Conv3d( + self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ) + else: + assert self.channels == self.out_channels, "out_channels must match channels when using pooling" + self.conv = nn.AvgPool3d(kernel_size=stride, stride=stride) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the downsampling block. + + Args: + hidden_states (torch.Tensor): Input feature map of shape (B, C, D, H, W). + + Returns: + torch.Tensor: Downsampled feature map. + """ + assert hidden_states.shape[1] == self.channels, \ + f"Expected input channels {self.channels}, but got {hidden_states.shape[1]}" + + # Apply normalization if specified + if self.norm is not None: + # LayerNorm expects (B, C, D, H, W), but normalizes over C. Permute to (B, D, H, W, C) + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 4, 1)) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) # Back to (B, C, D, H, W) + + # Apply padding if using conv downsampling and no padding was specified + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1, 0, 1) # Padding for 3D tensor: (D, H, W) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0.0) + + # Apply downsampling + hidden_states = self.conv(hidden_states) + + return hidden_states + +class Upsample3D(nn.Module): + """A 3D upsampling layer with a convolution. 
+ """ + + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + kernel_size: Optional[int] = None, + padding=1, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(self.channels, eps=eps, elementwise_affine=elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + conv = None + if kernel_size is None: + kernel_size = 3 + conv = nn.Conv3d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + self.conv = conv + + def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None) -> torch.Tensor: + assert hidden_states.shape[1] == self.channels, f"Expected {self.channels} channels, got {hidden_states.shape[1]}" + + # Apply normalization if specified + if self.norm is not None: + # LayerNorm expects (B, C, D, H, W), but normalizes over C. Permute to (B, D, H, W, C) + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 4, 1)) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) # Back to (B, C, D, H, W) + + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + hidden_states = self.conv(hidden_states) + + return hidden_states + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + """ + + def __init__( + self, + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + if attn_groups is None: + attn_groups = resnet_groups + + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + for _ in range(num_layers + 1) + ]) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/decoder.py b/applications/benchmarking/DynaCell/vae_3d/modules/decoder.py new file mode 100644 index 000000000..19ff8725b --- /dev/null +++ b/applications/benchmarking/DynaCell/vae_3d/modules/decoder.py @@ -0,0 +1,98 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +from .blocks import UNetMidBlock3D, UpDecoderBlock3D +from diffusers.models.attention_processor import SpatialNorm + +class Decoder(nn.Module): + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + num_up_blocks: int = 2, + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv3d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_groups=norm_num_groups, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(num_up_blocks): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock3D( + in_channels=prev_output_channel, + out_channels=output_channel, + num_layers=self.layers_per_block + 1, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_upsample=not is_final_block, + ) + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv3d( + block_out_channels[0], + out_channels, + kernel_size=3, + padding=1, + padding_mode='reflect', + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample) + + # up + for up_block in self.up_blocks: + sample = up_block(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/encoder.py b/applications/benchmarking/DynaCell/vae_3d/modules/encoder.py new file mode 100644 index 000000000..bc93f857b --- /dev/null +++ b/applications/benchmarking/DynaCell/vae_3d/modules/encoder.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn + +from typing import Tuple +from .blocks import DownEncoderBlock3D, UNetMidBlock3D + +class Encoder(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + num_down_blocks: int = 2, + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv3d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + padding_mode='reflect' + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i in range(num_down_blocks): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock3D( + in_channels=input_channel, + out_channels=output_channel, + dropout=0.0, + num_layers=self.layers_per_block, + add_downsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + downsample_padding=0, + ) + + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_groups=norm_num_groups, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv3d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + sample = self.conv_in(sample) + + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/utils.py b/applications/benchmarking/DynaCell/vae_3d/modules/utils.py new file mode 100644 index 000000000..0c2a8185e --- /dev/null +++ b/applications/benchmarking/DynaCell/vae_3d/modules/utils.py @@ -0,0 +1,9 @@ +import torch +from dataclasses import dataclass +from transformers.utils import ModelOutput + +@dataclass +class VAEOutput(ModelOutput): + loss: torch.FloatTensor = None + recon_loss: torch.FloatTensor = None + kl_loss: torch.FloatTensor = None \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/vae_3d/vae_3d_config.py b/applications/benchmarking/DynaCell/vae_3d/vae_3d_config.py new file mode 100644 index 000000000..d30883ed2 --- /dev/null +++ b/applications/benchmarking/DynaCell/vae_3d/vae_3d_config.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +from dataclasses import dataclass +from transformers import PretrainedConfig +from dataclasses import field + +@dataclass +class VAE3DConfig(PretrainedConfig): + model_type: str = 'vae' + + # Model parameters + in_channels: int = 1 + out_channels: int = 1 + num_down_blocks: int = 5 + latent_channels: int = 2 + vae_block_out_channels: list = field(default_factory=lambda: [32, 64, 128, 256, 256]) + loadcheck_path: str = "" \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/vae_3d/vae_3d_model.py b/applications/benchmarking/DynaCell/vae_3d/vae_3d_model.py new file mode 100644 index 000000000..d01138b1e --- /dev/null +++ b/applications/benchmarking/DynaCell/vae_3d/vae_3d_model.py @@ -0,0 +1,138 @@ +import os +import torch +import torch.nn as nn +from .modules.autoencoders import Autoencoder3DKL +from .vae_3d_config import VAE3DConfig 
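As a quick orientation to the defaults defined in vae_3d_config.py above, a hedged sketch (the import path and checkpoint location are placeholders based on this patch's layout):

    from vae_3d.vae_3d_config import VAE3DConfig

    cfg = VAE3DConfig(loadcheck_path="/path/to/vae_checkpoint.pt")
    print(cfg.in_channels, cfg.latent_channels, cfg.num_down_blocks)  # 1 2 5
    print(cfg.vae_block_out_channels)                                 # [32, 64, 128, 256, 256]
    # With five down blocks the final block skips downsampling, so the encoder
    # should apply four stride-2 reductions: a (32, 512, 512) input maps to a
    # (2, 32, 32) latent grid with latent_channels feature channels.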
+from transformers import PreTrainedModel +from .modules.utils import VAEOutput + + +class VAE3DModel(PreTrainedModel): + config_class = VAE3DConfig + + def __init__(self, config: VAE3DConfig): + super().__init__(config) + self.config = config + + self.num_down_blocks = config.num_down_blocks + self.num_up_blocks = self.num_down_blocks + + # Initialize Autoencoder3DKL + self.vae = Autoencoder3DKL( + in_channels=config.in_channels, + out_channels=config.out_channels, + num_down_blocks=self.num_down_blocks, + num_up_blocks=self.num_up_blocks, + block_out_channels=config.vae_block_out_channels, + latent_channels=config.latent_channels, + ) + + self.load_pretrained_weights(checkpoint_path=config.loadcheck_path) + + def load_pretrained_weights(self, checkpoint_path): + """ + Load pretrained weights from a given state_dict. + """ + + if os.path.splitext(checkpoint_path)[1] == '.safetensors': + from safetensors.torch import load_file + checkpoints_state = load_file(checkpoint_path) + else: + checkpoints_state = torch.load(checkpoint_path, map_location="cpu") + + if "model" in checkpoints_state: + checkpoints_state = checkpoints_state["model"] + elif "module" in checkpoints_state: + checkpoints_state = checkpoints_state["module"] + + IncompatibleKeys = self.load_state_dict(checkpoints_state, strict=True) + IncompatibleKeys = IncompatibleKeys._asdict() + + missing_keys = [] + for keys in IncompatibleKeys["missing_keys"]: + if keys.find("dummy") == -1: + missing_keys.append(keys) + + unexpected_keys = [] + for keys in IncompatibleKeys["unexpected_keys"]: + if keys.find("dummy") == -1: + unexpected_keys.append(keys) + + if len(missing_keys) > 0: + print( + "Missing keys in {}: {}".format( + checkpoint_path, + missing_keys, + ) + ) + + if len(unexpected_keys) > 0: + print( + "Unexpected keys {}: {}".format( + checkpoint_path, + unexpected_keys, + ) + ) + + def encode(self, x): + """Encodes input into latent space.""" + return self.vae.encode(x).latent_dist + + def decode(self, latents): + """Decodes latent space into reconstructed input.""" + return self.vae.decode(latents) + + def forward(self, batched_data): + x = batched_data['data'] + + """Forward pass through the VAE.""" + latent_dist = self.encode(x) + latents = latent_dist.sample() + recon_x = self.decode(latents).sample + + total_loss, recon_loss, kl_loss = self.compute_loss(x, recon_x, latent_dist) + + return VAEOutput(total_loss, recon_loss, kl_loss) + + def compute_loss(self, x, recon_x, latent_dist): + """Compute reconstruction and KL divergence loss.""" + if self.config.vae_recon_loss_type == 'mse': + recon_loss = nn.MSELoss()(recon_x, x) + elif self.config.vae_recon_loss_type == 'poisson': + x = x.clip(-1, 1) + recon_x = recon_x.clip(-1, 1) + peak = self.config.poisson_peak if hasattr(self.config, 'poisson_peak') else 1.0 + target = (x + 1) / 2.0 * peak + lam = (recon_x + 1) / 2.0 * peak + recon_loss = torch.mean(lam - target * torch.log(lam + 1e-8)) + + kl_loss = -0.5 * torch.mean(1 + latent_dist.logvar - latent_dist.mean.pow(2) - latent_dist.logvar.exp()) + total_loss = self.config.recon_loss_coeff * recon_loss + self.config.kl_loss_coeff * kl_loss + return total_loss, recon_loss, kl_loss + + def sample(self, num_samples=1, latent_size=32, device="cpu"): + """ + Generate samples from the latent space. + + Args: + num_samples (int): Number of samples to generate. + device (str): Device to perform sampling on. + + Returns: + torch.Tensor: Generated images. 
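A hypothetical usage sketch for the model defined in this file (editor's addition, not part of the patch; the checkpoint path is a placeholder and must point to a real trained VAE for strict weight loading to succeed):

    import torch
    from vae_3d.vae_3d_config import VAE3DConfig
    from vae_3d.vae_3d_model import VAE3DModel

    cfg = VAE3DConfig(loadcheck_path="/path/to/vae_checkpoint.pt")  # placeholder
    model = VAE3DModel(cfg).eval()

    volume = torch.rand(1, 1, 32, 512, 512) * 2 - 1      # (B, C, D, H, W) in [-1, 1]
    recon = model.reconstruct(volume)                     # same shape as the input
    samples = model.sample(num_samples=2, latent_size=4, device="cpu")
    print(recon.shape, samples.shape)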
+ """ + # Sample from a standard normal distribution in latent space + latents = torch.randn((num_samples, self.config.latent_channels, latent_size, latent_size, latent_size), device=device) # Shape matches latent dimensions + + # Decode latents to generate images + with torch.no_grad(): + generated_images = self.decode(latents).sample + + return generated_images + + def reconstruct(self, x): + latent_dist = self.encode(x) + latents = latent_dist.sample() # Reparameterization trick + recon_x = self.decode(latents).sample + + return recon_x diff --git a/applications/benchmarking/DynaCell/vae_3d/vae_3d_model_ts.py b/applications/benchmarking/DynaCell/vae_3d/vae_3d_model_ts.py new file mode 100644 index 000000000..92a53ac24 --- /dev/null +++ b/applications/benchmarking/DynaCell/vae_3d/vae_3d_model_ts.py @@ -0,0 +1,90 @@ +import os +import torch +import torch.nn as nn +from .modules.autoencoders_ts import Autoencoder3DKL +from .vae_3d_config import VAE3DConfig + + +class VAE3DModel(nn.Module): + def __init__(self, config: VAE3DConfig): + super().__init__() + self.config = config + + self.num_down_blocks = config.num_down_blocks + self.num_up_blocks = self.num_down_blocks + + # Initialize Autoencoder3DKL + self.vae = Autoencoder3DKL( + in_channels=config.in_channels, + out_channels=config.out_channels, + num_down_blocks=self.num_down_blocks, + num_up_blocks=self.num_up_blocks, + block_out_channels=config.vae_block_out_channels, + latent_channels=config.latent_channels, + ) + + self.load_pretrained_weights(checkpoint_path=config.loadcheck_path) + + def load_pretrained_weights(self, checkpoint_path): + """ + Load pretrained weights from a given state_dict. + """ + + if os.path.splitext(checkpoint_path)[1] == '.safetensors': + from safetensors.torch import load_file + checkpoints_state = load_file(checkpoint_path) + else: + checkpoints_state = torch.load(checkpoint_path, map_location="cpu") + + if "model" in checkpoints_state: + checkpoints_state = checkpoints_state["model"] + elif "module" in checkpoints_state: + checkpoints_state = checkpoints_state["module"] + + IncompatibleKeys = self.load_state_dict(checkpoints_state, strict=True) + IncompatibleKeys = IncompatibleKeys._asdict() + + missing_keys = [] + for keys in IncompatibleKeys["missing_keys"]: + if keys.find("dummy") == -1: + missing_keys.append(keys) + + unexpected_keys = [] + for keys in IncompatibleKeys["unexpected_keys"]: + if keys.find("dummy") == -1: + unexpected_keys.append(keys) + + if len(missing_keys) > 0: + print( + "Missing keys in {}: {}".format( + checkpoint_path, + missing_keys, + ) + ) + + if len(unexpected_keys) > 0: + print( + "Unexpected keys {}: {}".format( + checkpoint_path, + unexpected_keys, + ) + ) + + def encode(self, x): + """Encodes input into latent space.""" + return self.vae.encode(x) + + def decode(self, latents): + """Decodes latent space into reconstructed input.""" + return self.vae.decode(latents) + + def forward(self, x): + # placeholder forward + return x + + def reconstruct(self, x): + mean, logvar = self.encode(x) + latents = mean + torch.exp(0.5 * logvar) * torch.randn_like(logvar) # Reparameterization trick + recon_x = self.decode(latents) + + return recon_x From 90c77b2d5c863cd5ad4cc1a66a4832f85a4c443c Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 18 Sep 2025 17:01:16 -0700 Subject: [PATCH 06/10] rename shell script to run ts --- applications/benchmarking/DynaCell/fid_ts.py | 3 +-- .../benchmarking/DynaCell/test_fid_ts.sh | 17 ----------------- 2 files changed, 1 insertion(+), 
19 deletions(-) delete mode 100644 applications/benchmarking/DynaCell/test_fid_ts.sh diff --git a/applications/benchmarking/DynaCell/fid_ts.py b/applications/benchmarking/DynaCell/fid_ts.py index 8c8f272cf..a5751eab9 100644 --- a/applications/benchmarking/DynaCell/fid_ts.py +++ b/applications/benchmarking/DynaCell/fid_ts.py @@ -1,11 +1,10 @@ -# -*- coding: utf-8 -*- import argparse from pathlib import Path import torch -from tqdm import tqdm from iohub.ngff import open_ome_zarr from torch import Tensor +from tqdm import tqdm # ----------------------------------------------------------------------------- # # Helper functions # diff --git a/applications/benchmarking/DynaCell/test_fid_ts.sh b/applications/benchmarking/DynaCell/test_fid_ts.sh deleted file mode 100644 index 012cb23b1..000000000 --- a/applications/benchmarking/DynaCell/test_fid_ts.sh +++ /dev/null @@ -1,17 +0,0 @@ -python fid_ts.py \ - --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --channel_name1 Nuclei-prediction \ - --channel_name2 Organelle \ - --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae_ts.pt \ - --batch_size 4 \ - --device cuda - -python fid_ts.py \ - --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --channel_name1 Membrane-prediction \ - --channel_name2 Membrane \ - --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae_ts.pt \ - --batch_size 4 \ - --device cuda \ No newline at end of file From 65ea06fe2545f11eeba1e75214d3bf25adcf0b23 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 18 Sep 2025 17:08:56 -0700 Subject: [PATCH 07/10] convert numpy docs, using click, torch_inference mode and simplifiying logic for getting the channels --- applications/benchmarking/DynaCell/fid_ts.py | 213 ++++++++++++------- 1 file changed, 139 insertions(+), 74 deletions(-) diff --git a/applications/benchmarking/DynaCell/fid_ts.py b/applications/benchmarking/DynaCell/fid_ts.py index a5751eab9..024f3448c 100644 --- a/applications/benchmarking/DynaCell/fid_ts.py +++ b/applications/benchmarking/DynaCell/fid_ts.py @@ -1,6 +1,6 @@ -import argparse from pathlib import Path +import click import torch from iohub.ngff import open_ome_zarr from torch import Tensor @@ -15,7 +15,18 @@ def read_zarr(zarr_path: str): return [pos for _, pos in plate.positions()] def normalise(volume: torch.Tensor) -> torch.Tensor: - """Per-sample min max → [-1,1]. Shape: (D, H, W) or (B, D, H, W).""" + """Normalize volume to [-1, 1] range using min-max normalization. + + Parameters + ---------- + volume : torch.Tensor + Input volume with shape (D, H, W) or (B, D, H, W) + + Returns + ------- + torch.Tensor + Normalized volume in [-1, 1] range with same shape as input + """ v_min = volume.amin(dim=(-3, -2, -1), keepdim=True) v_max = volume.amax(dim=(-3, -2, -1), keepdim=True) volume = (volume - v_min) / (v_max - v_min + 1e-6) # → [0,1] @@ -23,21 +34,28 @@ def normalise(volume: torch.Tensor) -> torch.Tensor: @torch.jit.script_if_tracing def sqrtm(sigma: Tensor) -> Tensor: - r"""Returns the square root of a positive semi-definite matrix. - - .. math:: \sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T + r"""Compute the square root of a positive semi-definite matrix. 
+ Uses eigendecomposition: :math:`\sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T` where :math:`Q \Lambda Q^T` is the eigendecomposition of :math:`\Sigma`. - Args: - sigma: A positive semi-definite matrix, :math:`(*, D, D)`. - - Example: - >>> V = torch.randn(4, 4, dtype=torch.double) - >>> A = V @ V.T - >>> B = sqrtm(A @ A) - >>> torch.allclose(A, B) - True + Parameters + ---------- + sigma : Tensor + A positive semi-definite matrix with shape (*, D, D) + + Returns + ------- + Tensor + Square root of the input matrix with same shape + + Examples + -------- + >>> V = torch.randn(4, 4, dtype=torch.double) + >>> A = V @ V.T + >>> B = sqrtm(A @ A) + >>> torch.allclose(A, B) + True """ L, Q = torch.linalg.eigh(sigma) @@ -52,27 +70,40 @@ def frechet_distance( mu_y: Tensor, sigma_y: Tensor, ) -> Tensor: - r"""Returns the Fréchet distance between two multivariate Gaussian distributions. + r"""Compute the Fréchet distance between two multivariate Gaussian distributions. + The Fréchet distance is given by: .. math:: d^2 = \left\| \mu_x - \mu_y \right\|_2^2 + \operatorname{tr} \left( \Sigma_x + \Sigma_y - 2 \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right) - Wikipedia: - https://wikipedia.org/wiki/Frechet_distance - - Args: - mu_x: The mean :math:`\mu_x` of the first distribution, :math:`(*, D)`. - sigma_x: The covariance :math:`\Sigma_x` of the first distribution, :math:`(*, D, D)`. - mu_y: The mean :math:`\mu_y` of the second distribution, :math:`(*, D)`. - sigma_y: The covariance :math:`\Sigma_y` of the second distribution, :math:`(*, D, D)`. - - Example: - >>> mu_x = torch.arange(3).float() - >>> sigma_x = torch.eye(3) - >>> mu_y = 2 * mu_x + 1 - >>> sigma_y = 2 * sigma_x + 1 - >>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y) - tensor(15.8710) + Parameters + ---------- + mu_x : Tensor + Mean of the first distribution with shape (*, D) + sigma_x : Tensor + Covariance of the first distribution with shape (*, D, D) + mu_y : Tensor + Mean of the second distribution with shape (*, D) + sigma_y : Tensor + Covariance of the second distribution with shape (*, D, D) + + Returns + ------- + Tensor + Fréchet distance between the two distributions + + References + ---------- + .. [1] https://wikipedia.org/wiki/Frechet_distance + + Examples + -------- + >>> mu_x = torch.arange(3).float() + >>> sigma_x = torch.eye(3) + >>> mu_y = 2 * mu_x + 1 + >>> sigma_y = 2 * sigma_x + 1 + >>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y) + tensor(15.8710) """ sigma_y_12 = sqrtm(sigma_y) @@ -85,6 +116,22 @@ def frechet_distance( @torch.no_grad() def fid_from_features(f1, f2, eps=1e-6): + """Compute Fréchet Inception Distance (FID) from feature embeddings. 
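A small numerical sanity check of the helpers above (editor's sketch; it assumes sqrtm, frechet_distance, and fid_from_features from this file are in scope):

    import torch

    torch.manual_seed(0)
    f1 = torch.randn(5000, 16)
    f2 = f1.clone()
    print(fid_from_features(f1, f2))        # ~0 for identical feature sets
    print(fid_from_features(f1, f2 + 1.0))  # ~16: a unit mean shift in 16 dimensions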
+ + Parameters + ---------- + f1 : torch.Tensor + Features from first dataset with shape (N1, D) + f2 : torch.Tensor + Features from second dataset with shape (N2, D) + eps : float, default=1e-6 + Small value added to diagonal for numerical stability + + Returns + ------- + float + FID score between the two feature sets + """ mu1, sigma1 = f1.mean(0), torch.cov(f1.T) mu2, sigma2 = f2.mean(0), torch.cov(f2.T) @@ -94,7 +141,7 @@ def fid_from_features(f1, f2, eps=1e-6): return frechet_distance(mu1, sigma1, mu2, sigma2).clamp_min_(0).item() -@torch.no_grad() +@torch.inference_mode() def encode_fovs( fovs, vae, @@ -103,14 +150,33 @@ def encode_fovs( batch_size: int = 4, input_spatial_size: tuple = (32, 512, 512), ): - """ - For each FOV pair: - • take all T time-frames (shape: T, D, H, W) - • normalise to [-1, 1] - • feed through VAE in chunks of ≤ batch_size frames - • average the resulting T latent vectors → one embedding / FOV + """Encode field-of-view (FOV) data using a variational autoencoder. + + For each FOV: + - Extract all time-frames with shape (T, D, H, W) + - Normalize to [-1, 1] range + - Process through VAE in batches of ≤ batch_size frames + - Collect all latent vectors from all time points + + Parameters + ---------- + fovs : list + List of FOV position objects + vae : torch.nn.Module + Pre-trained VAE model for encoding + channel_name : str + Name of the channel to extract from each FOV + device : str, default="cuda" + Device to run computations on + batch_size : int, default=4 + Number of frames to process simultaneously + input_spatial_size : tuple, default=(32, 512, 512) + Target spatial dimensions for VAE input (D, H, W) + Returns - emb: (N, latent_dim) tensors + ------- + torch.Tensor + Concatenated embeddings from all FOVs and timepoints with shape (N_total_timepoints, latent_dim) """ emb = [] @@ -142,53 +208,54 @@ def encode_fovs( # Main # # ----------------------------------------------------------------------------- # -def build_argparser() -> argparse.ArgumentParser: - p = argparse.ArgumentParser(add_help=False) - p.add_argument("--data_path1", type=Path, required=True) - p.add_argument("--data_path2", type=Path, required=True) - p.add_argument("--channel_name", type=str, default=None) - p.add_argument("--channel_name1", type=str, default=None) - p.add_argument("--channel_name2", type=str, default=None) - p.add_argument("--input_spatial_size", type=str, default="32,512,512", - help="Input spatial size for the VAE, e.g. '32,512,512'.") - p.add_argument("--loadcheck_path", type=Path, default=None, - help="Path to the VAE model checkpoint for loading.") - p.add_argument("--batch_size", type=int, default=4) - p.add_argument("--device", type=str, default="cuda") - p.add_argument("--max_fov", type=int, default=None, - help="Limit number of FOV pairs (for quick tests).") - return p - -def main(args) -> None: - device = args.device +@click.command() +@click.option("--source_path", type=click.Path(exists=True, path_type=Path), required=True) +@click.option("--target_path", type=click.Path(exists=True, path_type=Path), required=True) +@click.option("--channel_names", type=str, multiple=True, required=True, + help="Channel names for source and target (1 or 2 values). If 1 value, same channel used for both.") +@click.option("--input_spatial_size", type=str, default="32,512,512", + help="Input spatial size for the VAE, e.g. 
'32,512,512'.") +@click.option("--loadcheck_path", type=click.Path(exists=True, path_type=Path), default=None, + help="Path to the VAE model checkpoint for loading.") +@click.option("--batch_size", type=int, default=4) +@click.option("--device", type=str, default="cuda") +@click.option("--max_fov", type=int, default=None, + help="Limit number of FOV pairs (for quick tests).") +def main(source_path, target_path, channel_names, + input_spatial_size, loadcheck_path, batch_size, device, max_fov) -> None: # ----------------- VAE ----------------- # - vae = torch.jit.load(args.loadcheck_path).to(device) + vae = torch.jit.load(loadcheck_path).to(device) vae.eval() # ----------------- FOV list ------------ # - fovs1, fovs2 = read_zarr(args.data_path1), read_zarr(args.data_path2) - if args.max_fov: - fovs1 = fovs1[:args.max_fov] - fovs2 = fovs2[:args.max_fov] + fovs1, fovs2 = read_zarr(source_path), read_zarr(target_path) + if max_fov: + fovs1 = fovs1[:max_fov] + fovs2 = fovs2[:max_fov] # ----------------- Embeddings ----------- # - input_spatial_size = [int(dim) for dim in args.input_spatial_size.split(",")] - - if args.channel_name is not None: - args.channel_name1 = args.channel_name2 = args.channel_name + input_spatial_size = [int(dim) for dim in input_spatial_size.split(",")] + + # Handle channel names: use same for both if only one provided + if len(channel_names) == 1: + channel_name1 = channel_name2 = channel_names[0] + elif len(channel_names) == 2: + channel_name1, channel_name2 = channel_names + else: + raise ValueError("Must provide 1 or 2 channel names") emb1 = encode_fovs( fovs1, vae, - args.channel_name1, - device, args.batch_size, + channel_name1, + device, batch_size, input_spatial_size, ) emb2 = encode_fovs( fovs2, vae, - args.channel_name2, - device, args.batch_size, + channel_name2, + device, batch_size, input_spatial_size, ) @@ -197,6 +264,4 @@ def main(args) -> None: print(f"\nFID: {fid_val:.6f}") if __name__ == "__main__": - parser = build_argparser() - args = parser.parse_args() - main(args) \ No newline at end of file + main() \ No newline at end of file From 31b429ccbbc94e7fdda473c2586543234449d8e6 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 18 Sep 2025 18:34:01 -0700 Subject: [PATCH 08/10] store embeddings --- applications/benchmarking/DynaCell/fid_ts.py | 353 ++++++++++++++---- .../benchmarking/DynaCell/run_fid_ts.sh | 32 ++ 2 files changed, 308 insertions(+), 77 deletions(-) create mode 100644 applications/benchmarking/DynaCell/run_fid_ts.sh diff --git a/applications/benchmarking/DynaCell/fid_ts.py b/applications/benchmarking/DynaCell/fid_ts.py index 024f3448c..4c1f8ccb0 100644 --- a/applications/benchmarking/DynaCell/fid_ts.py +++ b/applications/benchmarking/DynaCell/fid_ts.py @@ -1,8 +1,11 @@ +import warnings from pathlib import Path import click +import numpy as np import torch -from iohub.ngff import open_ome_zarr +import xarray as xr +from iohub.ngff import Position, open_ome_zarr from torch import Tensor from tqdm import tqdm @@ -10,10 +13,6 @@ # Helper functions # # ----------------------------------------------------------------------------- # -def read_zarr(zarr_path: str): - plate = open_ome_zarr(zarr_path, mode="r") - return [pos for _, pos in plate.positions()] - def normalise(volume: torch.Tensor) -> torch.Tensor: """Normalize volume to [-1, 1] range using min-max normalization. 
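The main() above loads the VAE with torch.jit.load, which works without importing the model's Python class because the checkpoint is a serialized TorchScript module. A minimal round-trip sketch with a stand-in module (ToyEncoder is hypothetical, not the project's VAE):

    import torch
    import torch.nn as nn

    class ToyEncoder(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # stand-in "latent": global average over D, H, W
            return x.mean(dim=[-3, -2, -1])

    torch.jit.save(torch.jit.script(ToyEncoder()), "toy_encoder_ts.pt")
    vae = torch.jit.load("toy_encoder_ts.pt").eval()   # mirrors the loading in main()
    print(vae(torch.randn(1, 1, 8, 32, 32)).shape)     # torch.Size([1, 1])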
@@ -142,17 +141,16 @@ def fid_from_features(f1, f2, eps=1e-6): return frechet_distance(mu1, sigma1, mu2, sigma2).clamp_min_(0).item() @torch.inference_mode() -def encode_fovs( - fovs, - vae, +def encode_position( + position: Position, + vae: torch.nn.Module, channel_name: str, device: str = "cuda", batch_size: int = 4, input_spatial_size: tuple = (32, 512, 512), ): - """Encode field-of-view (FOV) data using a variational autoencoder. + """Encode position data using a variational autoencoder. - For each FOV: - Extract all time-frames with shape (T, D, H, W) - Normalize to [-1, 1] range - Process through VAE in batches of ≤ batch_size frames @@ -160,12 +158,12 @@ def encode_fovs( Parameters ---------- - fovs : list - List of FOV position objects + position : Position + Single position object from zarr plate vae : torch.nn.Module Pre-trained VAE model for encoding channel_name : str - Name of the channel to extract from each FOV + Name of the channel to extract from the position device : str, default="cuda" Device to run computations on batch_size : int, default=4 @@ -176,92 +174,293 @@ def encode_fovs( Returns ------- torch.Tensor - Concatenated embeddings from all FOVs and timepoints with shape (N_total_timepoints, latent_dim) + Embeddings from all timepoints with shape (N_timepoints, latent_dim) """ emb = [] - for pos in tqdm(fovs, desc="Encoding FOVs"): - # ---------------- load & normalise ---------------- # - v = torch.as_tensor( - pos.data[:, pos.get_channel_index(channel_name)], - dtype=torch.float32, device=device, - ) # (T, D, H, W) + # ---------------- load & normalise ---------------- # + v = torch.as_tensor( + position.data[:, position.get_channel_index(channel_name)], + dtype=torch.float32, device=device, + ) # (T, D, H, W) - v = normalise(v) # still (T, D, H, W) + v = normalise(v) # still (T, D, H, W) - # ---------------- chunked VAE inference ----------- # - for t0 in range(0, v.shape[0], batch_size): - slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) + # ---------------- chunked VAE inference ----------- # + for t0 in tqdm(range(0, v.shape[0], batch_size), desc="Encoding timepoints"): + slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) - # resize to input spatial size - slice = torch.nn.functional.interpolate( - slice, size=input_spatial_size, mode="trilinear", align_corners=False, - ) # (b, 1, D, H, W) + # resize to input spatial size + slice = torch.nn.functional.interpolate( + slice, size=input_spatial_size, mode="trilinear", align_corners=False, + ) # (b, 1, D, H, W) - feat = vae.encode(slice)[0] # mean, - feat = feat.flatten(start_dim=1) # (b, latent_dim) - emb.append(feat) + feat = vae.encode(slice)[0] # mean, + feat = feat.flatten(start_dim=1) # (b, latent_dim) + emb.append(feat) return torch.cat(emb, 0) +@torch.inference_mode() +def encode_position_with_metadata( + position: Position, + vae: torch.nn.Module, + channel_name: str, + device: str = "cuda", + batch_size: int = 4, + input_spatial_size: tuple = (32, 512, 512), +): + """Encode position data using a variational autoencoder with metadata. 
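To make the per-timepoint pipeline concrete, a hedged sketch of what encode_position and encode_position_with_metadata do, using a toy volume and a stand-in for vae.encode (normalise is the helper defined earlier in this file; shapes and the zarr path are illustrative):

    import numpy as np
    import torch
    import torch.nn.functional as F
    import xarray as xr

    video = torch.rand(6, 16, 128, 128)                    # toy (T, D, H, W) stack
    video = normalise(video)                               # -> [-1, 1]

    feats = []
    for t0 in range(0, video.shape[0], 2):                 # batch_size = 2
        chunk = video[t0 : t0 + 2].unsqueeze(1)            # (b, 1, D, H, W)
        chunk = F.interpolate(chunk, size=(32, 512, 512),
                              mode="trilinear", align_corners=False)
        feats.append(chunk.mean(dim=(-3, -2, -1)).flatten(1))  # stand-in for vae.encode

    emb = torch.cat(feats, 0).numpy()                      # (T, latent_dim)
    ds = xr.Dataset(
        {"embeddings": (["sample", "feature"], emb)},
        coords={"sample": np.arange(emb.shape[0]),
                "feature": np.arange(emb.shape[1]),
                "timepoint": ("sample", np.arange(emb.shape[0]))},
    )
    # ds.to_zarr("toy_embeddings.zarr", mode="w")          # as embed_dataset does below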
+ + Parameters + ---------- + position : Position + Single position object from zarr plate + vae : torch.nn.Module + Pre-trained VAE model for encoding + channel_name : str + Name of the channel to extract from the position + position_name : str + Name/identifier for this position (e.g., 'A/1/0') + device : str, default="cuda" + Device to run computations on + batch_size : int, default=4 + Number of frames to process simultaneously + input_spatial_size : tuple, default=(32, 512, 512) + Target spatial dimensions for VAE input (D, H, W) + + Returns + ------- + xr.Dataset + Dataset with embeddings and metadata + """ + position_name = position.zgroup.name + embeddings_list = [] + timepoints_list = [] + + # ---------------- load & normalise ---------------- # + v = torch.as_tensor( + position.data[:, position.get_channel_index(channel_name)], + dtype=torch.float32, device=device, + ) # (T, D, H, W) + + v = normalise(v) # still (T, D, H, W) + + # ---------------- chunked VAE inference ----------- # + timepoint = 0 + for t0 in tqdm(range(0, v.shape[0], batch_size), desc=f"Encoding {position_name}/{channel_name}"): + slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) + + # resize to input spatial size + slice = torch.nn.functional.interpolate( + slice, size=input_spatial_size, mode="trilinear", align_corners=False, + ) # (b, 1, D, H, W) + + feat = vae.encode(slice)[0] # mean, + feat = feat.flatten(start_dim=1) # (b, latent_dim) + + # Convert to numpy and collect embeddings + feat_np = feat.cpu().numpy() + for i, embedding in enumerate(feat_np): + embeddings_list.append(embedding) + timepoints_list.append(timepoint + i) + + timepoint += feat.shape[0] + + # Create xarray Dataset + embeddings_array = np.stack(embeddings_list) + n_samples, n_features = embeddings_array.shape + + ds = xr.Dataset({ + 'embeddings': (['sample', 'feature'], embeddings_array) + }, coords={ + 'sample': range(n_samples), + 'feature': range(n_features), + 'timepoint': ('sample', timepoints_list) + }) + + # Add metadata as attributes + ds.attrs['position_name'] = position_name + ds.attrs['channel_name'] = channel_name + + return ds + # ----------------------------------------------------------------------------- # # Main # # ----------------------------------------------------------------------------- # @click.command() -@click.option("--source_path", type=click.Path(exists=True, path_type=Path), required=True) -@click.option("--target_path", type=click.Path(exists=True, path_type=Path), required=True) -@click.option("--channel_names", type=str, multiple=True, required=True, - help="Channel names for source and target (1 or 2 values). If 1 value, same channel used for both.") -@click.option("--input_spatial_size", type=str, default="32,512,512", - help="Input spatial size for the VAE, e.g. 
'32,512,512'.") -@click.option("--loadcheck_path", type=click.Path(exists=True, path_type=Path), default=None, +@click.option("--source_position", "-s", type=click.Path(exists=True, path_type=Path), required=True, help="Full path to source position (e.g., '/path/to/plate.zarr/A/1/0')") +@click.option("--target_position", "-t", type=click.Path(exists=True, path_type=Path), required=True, help="Full path to target position (e.g., '/path/to/plate.zarr/B/2/0')") +@click.option("--source_channel", "-sc", type=str, required=True, help="Channel name for source position") +@click.option("--target_channel", "-tc", type=str, required=True, help="Channel name for target position") +@click.option("-z", type=int, default=32, help="Depth dimension for VAE input") +@click.option("-y", type=int, default=512, help="Height dimension for VAE input") +@click.option("-x", type=int, default=512, help="Width dimension for VAE input") +@click.option("--ckpt_path", "-c", type=click.Path(exists=True, path_type=Path), required=True, help="Path to the VAE model checkpoint for loading.") -@click.option("--batch_size", type=int, default=4) -@click.option("--device", type=str, default="cuda") -@click.option("--max_fov", type=int, default=None, - help="Limit number of FOV pairs (for quick tests).") -def main(source_path, target_path, channel_names, - input_spatial_size, loadcheck_path, batch_size, device, max_fov) -> None: +@click.option("--batch_size", "-b", type=int, default=4) +@click.option("--device", "-d", type=str, default="cuda") +@click.option("--source_output", "-so", type=click.Path(path_type=Path), help="Path to save source embeddings") +@click.option("--target_output", "-to", type=click.Path(path_type=Path), help="Path to save target embeddings") +def embed_dataset(source_position, target_position, source_channel, target_channel, z, y, x, + ckpt_path, batch_size, device, source_output, target_output) -> None: + """Encode positions using a pre-trained VAE and optionally compute FID or save embeddings. + + This function loads two zarr positions, encodes them using a variational autoencoder, + and can either compute FID scores or save embeddings with metadata to a parquet file. 
+ + Parameters + ---------- + source_position : Path + Full path to the source position (e.g., '/path/to/plate.zarr/A/1/0') + target_position : Path + Full path to the target position (e.g., '/path/to/plate.zarr/B/2/0') + source_channel : str + Channel name for source position + target_channel : str + Channel name for target position + z : int + Depth dimension for VAE input + y : int + Height dimension for VAE input + x : int + Width dimension for VAE input + ckpt_path : Path + Path to the pre-trained VAE model checkpoint (.pt file) + batch_size : int + Number of timepoints to process simultaneously through the VAE + device : str + Device to run computations on ("cuda" or "cpu") + + Examples + -------- + Compute FID score between two positions: + + $ python fid_ts.py \\ + --source_position /path/to/dataset1.zarr/A/1/0 \\ + --target_position /path/to/dataset2.zarr/A/1/0 \\ + --source_channel phase \\ + --target_channel phase \\ + --ckpt_path /path/to/vae_model.pt \\ + --compute_fid + + Save embeddings to parquet file: + + $ python fid_ts.py \\ + --source_position /path/to/dataset1.zarr/A/1/0 \\ + --target_position /path/to/dataset2.zarr/B/2/0 \\ + --source_channel phase \\ + --target_channel brightfield \\ + --ckpt_path /path/to/vae_model.pt \\ + --output_path embeddings.parquet + + Save embeddings and compute FID: + + $ python fid_ts.py \\ + --source_position /path/to/dataset1.zarr/A/1/0 \\ + --target_position /path/to/dataset2.zarr/B/2/0 \\ + --source_channel phase \\ + --target_channel brightfield \\ + --ckpt_path /path/to/vae_model.pt \\ + --output_path embeddings.parquet \\ + --compute_fid + """ # ----------------- VAE ----------------- # - vae = torch.jit.load(loadcheck_path).to(device) + if device == "cuda" and not torch.cuda.is_available(): + warnings.warn("CUDA is not available, using CPU instead") + device = "cpu" + + vae = torch.jit.load(ckpt_path).to(device) vae.eval() - # ----------------- FOV list ------------ # - fovs1, fovs2 = read_zarr(source_path), read_zarr(target_path) - if max_fov: - fovs1 = fovs1[:max_fov] - fovs2 = fovs2[:max_fov] + # ----------------- Load positions ------------ # + source_position = open_ome_zarr(source_position) + source_channel_names = source_position.channel_names + assert source_channel in source_channel_names, f"Channel {source_channel} not found in source position" + + target_position = open_ome_zarr(target_position) + target_channel_names = target_position.channel_names + assert target_channel in target_channel_names, f"Channel {target_channel} not found in target position" # ----------------- Embeddings ----------- # - input_spatial_size = [int(dim) for dim in input_spatial_size.split(",")] - - # Handle channel names: use same for both if only one provided - if len(channel_names) == 1: - channel_name1 = channel_name2 = channel_names[0] - elif len(channel_names) == 2: - channel_name1, channel_name2 = channel_names - else: - raise ValueError("Must provide 1 or 2 channel names") + input_spatial_size = (z, y, x) + + + if source_output or target_output: + # Generate source embeddings + if source_output: + source_ds = encode_position_with_metadata( + position=source_position, vae=vae, + channel_name=source_channel, + device=device, batch_size=batch_size, + input_spatial_size=input_spatial_size, + ) + source_output.parent.mkdir(parents=True, exist_ok=True) + source_ds.to_zarr(source_output, mode='w') + print(f"Source embeddings saved to: {source_output}") + + # Generate target embeddings + if target_output: + target_ds = 
encode_position_with_metadata( + position=target_position, vae=vae, + channel_name=target_channel, + device=device, batch_size=batch_size, + input_spatial_size=input_spatial_size, + ) + target_output.parent.mkdir(parents=True, exist_ok=True) + target_ds.to_zarr(target_output, mode='w') + print(f"Target embeddings saved to: {target_output}") + +@click.command() +@click.option("--source_path", "-sp", type=click.Path(exists=True, path_type=Path), required=True, help="Path to the source embeddings zarr file") +@click.option("--target_path", "-tp", type=click.Path(exists=True, path_type=Path), required=True, help="Path to the target embeddings zarr file") +def compute_fid_cli(source_path: Path, target_path: Path) -> None: + """Compute FID score between two embedding datasets. - emb1 = encode_fovs( - fovs1, vae, - channel_name1, - device, batch_size, - input_spatial_size, - ) - - emb2 = encode_fovs( - fovs2, vae, - channel_name2, - device, batch_size, - input_spatial_size, - ) - - # ----------------- FID ------------------ # - fid_val = fid_from_features(emb1, emb2) - print(f"\nFID: {fid_val:.6f}") + Parameters + ---------- + source_path : Path + Path to the source embeddings zarr file + target_path : Path + Path to the target embeddings zarr file + + Examples + -------- + $ python fid_ts.py compute-fid \\ + -sp source_embeddings.zarr \\ + -tp target_embeddings.zarr + """ + # Load the datasets + source_ds = xr.open_zarr(source_path) + target_ds = xr.open_zarr(target_path) + + # Get embeddings arrays + source_embeddings = torch.tensor(source_ds.embeddings.values, dtype=torch.float32) + target_embeddings = torch.tensor(target_ds.embeddings.values, dtype=torch.float32) + + fid_score = fid_from_features(source_embeddings, target_embeddings) + + # Get metadata from attributes + source_channel = source_ds.attrs.get('channel_name', 'unknown') + target_channel = target_ds.attrs.get('channel_name', 'unknown') + source_position = source_ds.attrs.get('position_name', 'unknown') + target_position = target_ds.attrs.get('position_name', 'unknown') + + print(f"Source: {source_position}/{source_channel} ({len(source_embeddings)} samples)") + print(f"Target: {target_position}/{target_channel} ({len(target_embeddings)} samples)") + print(f"FID score: {fid_score:.6f}") + + return fid_score + +@click.group() +def cli(): + """VAE embedding and FID computation tools.""" + pass + +cli.add_command(embed_dataset, name="embed") +cli.add_command(compute_fid_cli, name="compute-fid") if __name__ == "__main__": - main() \ No newline at end of file + cli() \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/run_fid_ts.sh b/applications/benchmarking/DynaCell/run_fid_ts.sh new file mode 100644 index 000000000..eb70ccd15 --- /dev/null +++ b/applications/benchmarking/DynaCell/run_fid_ts.sh @@ -0,0 +1,32 @@ +# Generate nucleus embeddings separately +python fid_ts.py embed \ + -s /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr/0/HIST2H2BE/0000010 \ + -t /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr/0/HIST2H2BE/0000010 \ + -sc Nuclei-prediction \ + -tc Organelle \ + -c /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae_ts.pt \ + -so nuclei_prediction_embeddings.zarr \ + -to organelle_embeddings.zarr \ + -b 4 \ + -d cuda + +# Generate membrane embeddings separately +python fid_ts.py embed \ + -s /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr/0/HIST2H2BE/0000010 \ + -t 
/hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr/0/HIST2H2BE/0000010 \ + -sc Membrane-prediction \ + -tc Membrane \ + -c /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae_ts.pt \ + -so membrane_prediction_embeddings.zarr \ + -to membrane_embeddings.zarr \ + -b 4 \ + -d cuda + +# Compute FID from separate embedding files +python fid_ts.py compute-fid \ + -sp nuclei_prediction_embeddings.zarr \ + -tp organelle_embeddings.zarr + +python fid_ts.py compute-fid \ + -sp membrane_prediction_embeddings.zarr \ + -tp membrane_embeddings.zarr \ No newline at end of file From efbb1883bcf3045461340352514fcfcb757bf202 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 18 Sep 2025 18:36:42 -0700 Subject: [PATCH 09/10] deleting uncessary files. keeping only the torch script related scripts --- applications/benchmarking/DynaCell/fid.py | 207 ---------- .../benchmarking/DynaCell/test_fid.sh | 17 - .../benchmarking/DynaCell/vae_3d/__init__.py | 0 .../DynaCell/vae_3d/modules/__init__.py | 0 .../DynaCell/vae_3d/modules/autoencoders.py | 160 -------- .../vae_3d/modules/autoencoders_ts.py | 82 ---- .../DynaCell/vae_3d/modules/blocks.py | 353 ------------------ .../DynaCell/vae_3d/modules/decoder.py | 98 ----- .../DynaCell/vae_3d/modules/encoder.py | 87 ----- .../DynaCell/vae_3d/modules/utils.py | 9 - .../DynaCell/vae_3d/vae_3d_config.py | 16 - .../DynaCell/vae_3d/vae_3d_model.py | 138 ------- .../DynaCell/vae_3d/vae_3d_model_ts.py | 90 ----- applications/dynacell/fid.py | 207 ---------- applications/dynacell/fid_ts.py | 203 ---------- applications/dynacell/test_fid.sh | 17 - applications/dynacell/test_fid_ts.sh | 17 - applications/dynacell/vae_3d/__init__.py | 0 .../dynacell/vae_3d/modules/__init__.py | 0 .../dynacell/vae_3d/modules/autoencoders.py | 160 -------- .../vae_3d/modules/autoencoders_ts.py | 82 ---- .../dynacell/vae_3d/modules/blocks.py | 353 ------------------ .../dynacell/vae_3d/modules/decoder.py | 98 ----- .../dynacell/vae_3d/modules/encoder.py | 87 ----- applications/dynacell/vae_3d/modules/utils.py | 9 - applications/dynacell/vae_3d/vae_3d_config.py | 16 - applications/dynacell/vae_3d/vae_3d_model.py | 138 ------- .../dynacell/vae_3d/vae_3d_model_ts.py | 90 ----- 28 files changed, 2734 deletions(-) delete mode 100644 applications/benchmarking/DynaCell/fid.py delete mode 100644 applications/benchmarking/DynaCell/test_fid.sh delete mode 100644 applications/benchmarking/DynaCell/vae_3d/__init__.py delete mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/__init__.py delete mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/autoencoders.py delete mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/autoencoders_ts.py delete mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/blocks.py delete mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/decoder.py delete mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/encoder.py delete mode 100644 applications/benchmarking/DynaCell/vae_3d/modules/utils.py delete mode 100644 applications/benchmarking/DynaCell/vae_3d/vae_3d_config.py delete mode 100644 applications/benchmarking/DynaCell/vae_3d/vae_3d_model.py delete mode 100644 applications/benchmarking/DynaCell/vae_3d/vae_3d_model_ts.py delete mode 100644 applications/dynacell/fid.py delete mode 100644 applications/dynacell/fid_ts.py delete mode 100644 applications/dynacell/test_fid.sh delete mode 100644 applications/dynacell/test_fid_ts.sh delete mode 100644 
applications/dynacell/vae_3d/__init__.py delete mode 100644 applications/dynacell/vae_3d/modules/__init__.py delete mode 100644 applications/dynacell/vae_3d/modules/autoencoders.py delete mode 100644 applications/dynacell/vae_3d/modules/autoencoders_ts.py delete mode 100644 applications/dynacell/vae_3d/modules/blocks.py delete mode 100644 applications/dynacell/vae_3d/modules/decoder.py delete mode 100644 applications/dynacell/vae_3d/modules/encoder.py delete mode 100644 applications/dynacell/vae_3d/modules/utils.py delete mode 100644 applications/dynacell/vae_3d/vae_3d_config.py delete mode 100644 applications/dynacell/vae_3d/vae_3d_model.py delete mode 100644 applications/dynacell/vae_3d/vae_3d_model_ts.py diff --git a/applications/benchmarking/DynaCell/fid.py b/applications/benchmarking/DynaCell/fid.py deleted file mode 100644 index d53392dc4..000000000 --- a/applications/benchmarking/DynaCell/fid.py +++ /dev/null @@ -1,207 +0,0 @@ -# -*- coding: utf-8 -*- -import argparse -from pathlib import Path - -import torch -from tqdm import tqdm -from iohub.ngff import open_ome_zarr -from torch import Tensor - -from vae_3d.vae_3d_config import VAE3DConfig -from vae_3d.vae_3d_model import VAE3DModel - -# ----------------------------------------------------------------------------- # -# Helper functions # -# ----------------------------------------------------------------------------- # - -def read_zarr(zarr_path: str): - plate = open_ome_zarr(zarr_path, mode="r") - return [pos for _, pos in plate.positions()] - -def normalise(volume: torch.Tensor) -> torch.Tensor: - """Per-sample min max → [-1,1]. Shape: (D, H, W) or (B, D, H, W).""" - v_min = volume.amin(dim=(-3, -2, -1), keepdim=True) - v_max = volume.amax(dim=(-3, -2, -1), keepdim=True) - volume = (volume - v_min) / (v_max - v_min + 1e-6) # → [0,1] - return volume * 2.0 - 1.0 # → [-1,1] - -@torch.jit.script_if_tracing -def sqrtm(sigma: Tensor) -> Tensor: - r"""Returns the square root of a positive semi-definite matrix. - - .. math:: \sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T - - where :math:`Q \Lambda Q^T` is the eigendecomposition of :math:`\Sigma`. - - Args: - sigma: A positive semi-definite matrix, :math:`(*, D, D)`. - - Example: - >>> V = torch.randn(4, 4, dtype=torch.double) - >>> A = V @ V.T - >>> B = sqrtm(A @ A) - >>> torch.allclose(A, B) - True - """ - - L, Q = torch.linalg.eigh(sigma) - L = L.relu().sqrt() - - return Q @ (L[..., None] * Q.mT) - -@torch.jit.script_if_tracing -def frechet_distance( - mu_x: Tensor, - sigma_x: Tensor, - mu_y: Tensor, - sigma_y: Tensor, -) -> Tensor: - r"""Returns the Fréchet distance between two multivariate Gaussian distributions. - - .. math:: d^2 = \left\| \mu_x - \mu_y \right\|_2^2 + - \operatorname{tr} \left( \Sigma_x + \Sigma_y - 2 \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right) - - Wikipedia: - https://wikipedia.org/wiki/Frechet_distance - - Args: - mu_x: The mean :math:`\mu_x` of the first distribution, :math:`(*, D)`. - sigma_x: The covariance :math:`\Sigma_x` of the first distribution, :math:`(*, D, D)`. - mu_y: The mean :math:`\mu_y` of the second distribution, :math:`(*, D)`. - sigma_y: The covariance :math:`\Sigma_y` of the second distribution, :math:`(*, D, D)`. 
- - Example: - >>> mu_x = torch.arange(3).float() - >>> sigma_x = torch.eye(3) - >>> mu_y = 2 * mu_x + 1 - >>> sigma_y = 2 * sigma_x + 1 - >>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y) - tensor(15.8710) - """ - - sigma_y_12 = sqrtm(sigma_y) - - a = (mu_x - mu_y).square().sum(dim=-1) - b = sigma_x.trace() + sigma_y.trace() - c = sqrtm(sigma_y_12 @ sigma_x @ sigma_y_12).trace() - - return a + b - 2 * c - -@torch.no_grad() -def fid_from_features(f1, f2, eps=1e-6): - mu1, sigma1 = f1.mean(0), torch.cov(f1.T) - mu2, sigma2 = f2.mean(0), torch.cov(f2.T) - - eye = torch.eye(sigma1.size(0), device=sigma1.device, dtype=sigma1.dtype) - sigma1 = sigma1 + eps * eye - sigma2 = sigma2 + eps * eye - - return frechet_distance(mu1, sigma1, mu2, sigma2).clamp_min_(0).item() - -@torch.no_grad() -def encode_fovs( - fovs, - vae, - channel_name: str, - device: str = "cuda", - batch_size: int = 4, - input_spatial_size: tuple = (32, 512, 512), -): - """ - For each FOV pair: - • take all T time-frames (shape: T, D, H, W) - • normalise to [-1, 1] - • feed through VAE in chunks of ≤ batch_size frames - • average the resulting T latent vectors → one embedding / FOV - Returns - emb: (N, latent_dim) tensors - """ - emb = [] - - for pos in tqdm(fovs, desc="Encoding FOVs"): - # ---------------- load & normalise ---------------- # - v = torch.as_tensor( - pos.data[:, pos.get_channel_index(channel_name)], - dtype=torch.float32, device=device, - ) # (T, D, H, W) - - v = normalise(v) # still (T, D, H, W) - - # ---------------- chunked VAE inference ----------- # - for t0 in range(0, v.shape[0], batch_size): - slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) - - # resize to input spatial size - slice = torch.nn.functional.interpolate( - slice, size=input_spatial_size, mode="trilinear", align_corners=False, - ) # (b, 1, D, H, W) - - feat = vae.encode(slice).mean # mean, - feat = feat.flatten(start_dim=1) # (b, latent_dim) - emb.append(feat) - - return torch.cat(emb, 0) - -# ----------------------------------------------------------------------------- # -# Main # -# ----------------------------------------------------------------------------- # - -def build_argparser() -> argparse.ArgumentParser: - p = argparse.ArgumentParser(add_help=False) - p.add_argument("--data_path1", type=Path, required=True) - p.add_argument("--data_path2", type=Path, required=True) - p.add_argument("--channel_name", type=str, default=None) - p.add_argument("--channel_name1", type=str, default=None) - p.add_argument("--channel_name2", type=str, default=None) - p.add_argument("--input_spatial_size", type=str, default="32,512,512", - help="Input spatial size for the VAE, e.g. 
'32,512,512'.") - p.add_argument("--loadcheck_path", type=Path, default=None, - help="Path to the VAE model checkpoint for loading.") - p.add_argument("--batch_size", type=int, default=4) - p.add_argument("--device", type=str, default="cuda") - p.add_argument("--max_fov", type=int, default=None, - help="Limit number of FOV pairs (for quick tests).") - return p - -def main(args) -> None: - device = args.device - - # ----------------- VAE ----------------- # - model_cfg = VAE3DConfig() - model_cfg.loadcheck_path = args.loadcheck_path - vae = VAE3DModel(config=model_cfg).to(device).eval() - - # ----------------- FOV list ------------ # - fovs1, fovs2 = read_zarr(args.data_path1), read_zarr(args.data_path2) - if args.max_fov: - fovs1 = fovs1[:args.max_fov] - fovs2 = fovs2[:args.max_fov] - - # ----------------- Embeddings ----------- # - input_spatial_size = [int(dim) for dim in args.input_spatial_size.split(",")] - - if args.channel_name is not None: - args.channel_name1 = args.channel_name2 = args.channel_name - - emb1 = encode_fovs( - fovs1, vae, - args.channel_name1, - device, args.batch_size, - input_spatial_size, - ) - - emb2 = encode_fovs( - fovs2, vae, - args.channel_name2, - device, args.batch_size, - input_spatial_size, - ) - - # ----------------- FID ------------------ # - fid_val = fid_from_features(emb1, emb2) - print(f"\nFID: {fid_val:.6f}") - -if __name__ == "__main__": - parser = build_argparser() - args = parser.parse_args() - main(args) \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/test_fid.sh b/applications/benchmarking/DynaCell/test_fid.sh deleted file mode 100644 index 84c330e5d..000000000 --- a/applications/benchmarking/DynaCell/test_fid.sh +++ /dev/null @@ -1,17 +0,0 @@ -python fid.py \ - --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --channel_name1 Nuclei-prediction \ - --channel_name2 Organelle \ - --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae.pth \ - --batch_size 4 \ - --device cuda - -python fid.py \ - --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --channel_name1 Membrane-prediction \ - --channel_name2 Membrane \ - --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae.pth \ - --batch_size 4 \ - --device cuda \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/vae_3d/__init__.py b/applications/benchmarking/DynaCell/vae_3d/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/__init__.py b/applications/benchmarking/DynaCell/vae_3d/modules/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/autoencoders.py b/applications/benchmarking/DynaCell/vae_3d/modules/autoencoders.py deleted file mode 100644 index 9c3fb927a..000000000 --- a/applications/benchmarking/DynaCell/vae_3d/modules/autoencoders.py +++ /dev/null @@ -1,160 +0,0 @@ -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn - -from .decoder import Decoder -from .encoder import Encoder - -from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution -from diffusers.configuration_utils import ConfigMixin, register_to_config 
-from diffusers.loaders.single_file_model import FromOriginalModelMixin -from diffusers.utils.accelerate_utils import apply_forward_hook -from diffusers.models.modeling_outputs import AutoencoderKLOutput -from diffusers.models.modeling_utils import ModelMixin - - -class Autoencoder3DKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - num_down_blocks: int = 2, - num_up_blocks: int = 2, - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 4, - norm_num_groups: int = 32, - use_quant_conv: bool = True, - use_post_quant_conv: bool = True, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder( - in_channels=in_channels, - out_channels=latent_channels, - num_down_blocks=num_down_blocks, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=True, - ) - - # pass init params to Decoder - self.decoder = Decoder( - in_channels=latent_channels, - out_channels=out_channels, - num_up_blocks=num_up_blocks, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, - act_fn=act_fn, - ) - - self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None - self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (Encoder, Decoder)): - module.gradient_checkpointing = value - - def _encode(self, x: torch.Tensor) -> torch.Tensor: - - enc = self.encoder(x) - if self.quant_conv is not None: - enc = self.quant_conv(enc) - - return enc - - @apply_forward_hook - def encode( - self, x: torch.Tensor, return_dict: bool = True - ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: - """ - Encode a batch of images into latents. - - Args: - x (`torch.Tensor`): Input batch of images. - return_dict (`bool`, *optional*, defaults to `True`): - Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. - - Returns: - The latent representations of the encoded images. If `return_dict` is True, a - [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. - """ - h = self._encode(x) - posterior = DiagonalGaussianDistribution(h) - - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) - - def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - - if self.post_quant_conv is not None: - z = self.post_quant_conv(z) - - dec = self.decoder(z) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - @apply_forward_hook - def decode( - self, z: torch.FloatTensor, return_dict: bool = True, generator=None - ) -> Union[DecoderOutput, torch.FloatTensor]: - """ - Decode a batch of images. - - Args: - z (`torch.Tensor`): Input batch of latent vectors. - return_dict (`bool`, *optional*, defaults to `True`): - Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. - - Returns: - [`~models.vae.DecoderOutput`] or `tuple`: - If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is - returned. 
- - """ - decoded = self._decode(z).sample - - if not return_dict: - return (decoded,) - - return DecoderOutput(sample=decoded) - - def forward( - self, - sample: torch.Tensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.Tensor]: - r""" - Args: - sample (`torch.Tensor`): Input sample. - sample_posterior (`bool`, *optional*, defaults to `False`): - Whether to sample from the posterior. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - """ - x = sample - posterior = self.encode(x).latent_dist - if sample_posterior: - z = posterior.sample(generator=generator) - else: - z = posterior.mode() - dec = self.decode(z).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/autoencoders_ts.py b/applications/benchmarking/DynaCell/vae_3d/modules/autoencoders_ts.py deleted file mode 100644 index 9438ddd6d..000000000 --- a/applications/benchmarking/DynaCell/vae_3d/modules/autoencoders_ts.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import Tuple - -import torch -import torch.nn as nn - -from .decoder import Decoder -from .encoder import Encoder - - -class Autoencoder3DKL(nn.Module): - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - num_down_blocks: int = 2, - num_up_blocks: int = 2, - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 4, - norm_num_groups: int = 32, - use_quant_conv: bool = True, - use_post_quant_conv: bool = True, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder( - in_channels=in_channels, - out_channels=latent_channels, - num_down_blocks=num_down_blocks, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=True, - ) - - # pass init params to Decoder - self.decoder = Decoder( - in_channels=latent_channels, - out_channels=out_channels, - num_up_blocks=num_up_blocks, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, - act_fn=act_fn, - ) - - self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None - self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None - - def _encode(self, x: torch.Tensor) -> torch.Tensor: - - enc = self.encoder(x) - if self.quant_conv is not None: - enc = self.quant_conv(enc) - - return enc - - def encode(self, x: torch.Tensor): - h = self._encode(x) - mean, logvar = torch.chunk(h, 2, dim=1) - - return mean, logvar - - def _decode(self, z: torch.Tensor): - - if self.post_quant_conv is not None: - z = self.post_quant_conv(z) - dec = self.decoder(z) - - return dec - - def decode(self, z: torch.FloatTensor): - decoded = self._decode(z) - - return decoded - - def forward(self, x): - # placeholder forward - return x \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/blocks.py b/applications/benchmarking/DynaCell/vae_3d/modules/blocks.py deleted file mode 100644 index 569d66e9a..000000000 --- a/applications/benchmarking/DynaCell/vae_3d/modules/blocks.py +++ /dev/null @@ -1,353 +0,0 @@ -from typing import Optional - -import torch -import torch.nn.functional as F -from torch import nn -from 
diffusers.models.normalization import RMSNorm -from diffusers.models.activations import get_activation - -class UpDecoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - - resnets.append( - ResnetBlock3D( - in_channels=in_channels, - out_channels=out_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels=out_channels)]) - else: - self.upsamplers = None - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class DownEncoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - downsample_padding: int = 1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock3D( - in_channels=in_channels, - out_channels=out_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample3D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states - - -class ResnetBlock3D(nn.Module): - def __init__( - self, - *, - in_channels: int, - out_channels: Optional[int] = None, - conv_shortcut: bool = False, - dropout: float = 0.0, - groups: int = 32, - groups_out: Optional[int] = None, - eps: float = 1e-6, - non_linearity: str = "swish", - output_scale_factor: float = 1.0, - use_in_shortcut: Optional[bool] = None, - conv_shortcut_bias: bool = True, - ): - super().__init__() - - self.pre_norm = True - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - - self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - - self.dropout = torch.nn.Dropout(dropout) - 
conv_3d_out_channels = out_channels - self.conv2 = nn.Conv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1, padding=1) - - self.nonlinearity = get_activation(non_linearity) - self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = nn.Conv3d( - in_channels, - conv_3d_out_channels, - kernel_size=1, - stride=1, - padding=0, - bias=conv_shortcut_bias, - ) - - def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: - hidden_states = input_tensor - - hidden_states = self.norm1(hidden_states) - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.conv1(hidden_states) - hidden_states = self.norm2(hidden_states) - - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - - return output_tensor - -class Downsample3D(nn.Module): - """A 3D downsampling layer with an optional convolution. - """ - - def __init__( - self, - channels: int, - use_conv: bool = False, - out_channels: Optional[int] = None, - padding: int = 1, - kernel_size: int = 3, - norm_type: Optional[str] = None, - eps: Optional[float] = 1e-5, - elementwise_affine: Optional[bool] = True, - bias: bool = True, - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - self.kernel_size = kernel_size - stride = 2 # Downsampling stride is fixed to 2 - - # Initialize normalization - if norm_type == "ln_norm": - self.norm = nn.LayerNorm(self.channels, eps=eps, elementwise_affine=elementwise_affine) - elif norm_type == "rms_norm": - self.norm = RMSNorm(channels, eps, elementwise_affine) - elif norm_type is None: - self.norm = None - else: - raise ValueError(f"Unknown norm_type: {norm_type}") - - # Choose between convolutional or pooling downsampling - if use_conv: - self.conv = nn.Conv3d( - self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias - ) - else: - assert self.channels == self.out_channels, "out_channels must match channels when using pooling" - self.conv = nn.AvgPool3d(kernel_size=stride, stride=stride) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the downsampling block. - - Args: - hidden_states (torch.Tensor): Input feature map of shape (B, C, D, H, W). - - Returns: - torch.Tensor: Downsampled feature map. - """ - assert hidden_states.shape[1] == self.channels, \ - f"Expected input channels {self.channels}, but got {hidden_states.shape[1]}" - - # Apply normalization if specified - if self.norm is not None: - # LayerNorm expects (B, C, D, H, W), but normalizes over C. 
Permute to (B, D, H, W, C) - hidden_states = self.norm(hidden_states.permute(0, 2, 3, 4, 1)) - hidden_states = hidden_states.permute(0, 4, 1, 2, 3) # Back to (B, C, D, H, W) - - # Apply padding if using conv downsampling and no padding was specified - if self.use_conv and self.padding == 0: - pad = (0, 1, 0, 1, 0, 1) # Padding for 3D tensor: (D, H, W) - hidden_states = F.pad(hidden_states, pad, mode="constant", value=0.0) - - # Apply downsampling - hidden_states = self.conv(hidden_states) - - return hidden_states - -class Upsample3D(nn.Module): - """A 3D upsampling layer with a convolution. - """ - - def __init__( - self, - channels: int, - out_channels: Optional[int] = None, - kernel_size: Optional[int] = None, - padding=1, - norm_type=None, - eps=None, - elementwise_affine=None, - bias=True, - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - - if norm_type == "ln_norm": - self.norm = nn.LayerNorm(self.channels, eps=eps, elementwise_affine=elementwise_affine) - elif norm_type == "rms_norm": - self.norm = RMSNorm(channels, eps, elementwise_affine) - elif norm_type is None: - self.norm = None - else: - raise ValueError(f"unknown norm_type: {norm_type}") - - conv = None - if kernel_size is None: - kernel_size = 3 - conv = nn.Conv3d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - self.conv = conv - - def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None) -> torch.Tensor: - assert hidden_states.shape[1] == self.channels, f"Expected {self.channels} channels, got {hidden_states.shape[1]}" - - # Apply normalization if specified - if self.norm is not None: - # LayerNorm expects (B, C, D, H, W), but normalizes over C. Permute to (B, D, H, W, C) - hidden_states = self.norm(hidden_states.permute(0, 2, 3, 4, 1)) - hidden_states = hidden_states.permute(0, 4, 1, 2, 3) # Back to (B, C, D, H, W) - - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - if output_size is None: - hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") - else: - hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") - - hidden_states = self.conv(hidden_states) - - return hidden_states - -class UNetMidBlock3D(nn.Module): - """ - A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. 
- """ - - def __init__( - self, - in_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - attn_groups: Optional[int] = None, - output_scale_factor: float = 1.0, - ): - super().__init__() - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - if attn_groups is None: - attn_groups = resnet_groups - - self.resnets = nn.ModuleList([ - ResnetBlock3D( - in_channels=in_channels, - out_channels=in_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - ) - for _ in range(num_layers + 1) - ]) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - - return hidden_states \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/decoder.py b/applications/benchmarking/DynaCell/vae_3d/modules/decoder.py deleted file mode 100644 index 19ff8725b..000000000 --- a/applications/benchmarking/DynaCell/vae_3d/modules/decoder.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import Tuple - -import torch -import torch.nn as nn - -from .blocks import UNetMidBlock3D, UpDecoderBlock3D -from diffusers.models.attention_processor import SpatialNorm - -class Decoder(nn.Module): - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - num_up_blocks: int = 2, - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - norm_type: str = "group", # group, spatial - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = nn.Conv3d( - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - ) - - self.up_blocks = nn.ModuleList([]) - - temb_channels = in_channels if norm_type == "spatial" else None - - # mid - self.mid_block = UNetMidBlock3D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_groups=norm_num_groups, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i in range(num_up_blocks): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = UpDecoderBlock3D( - in_channels=prev_output_channel, - out_channels=output_channel, - num_layers=self.layers_per_block + 1, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - add_upsample=not is_final_block, - ) - - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_type == "spatial": - self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) - else: - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv3d( - block_out_channels[0], - out_channels, - kernel_size=3, - padding=1, - padding_mode='reflect', - ) - - self.gradient_checkpointing = False - - def forward(self, sample: torch.Tensor) -> torch.Tensor: - sample = self.conv_in(sample) - - # middle - sample = self.mid_block(sample) - - # up - for up_block in self.up_blocks: - sample = up_block(sample) - - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return 
sample \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/encoder.py b/applications/benchmarking/DynaCell/vae_3d/modules/encoder.py deleted file mode 100644 index bc93f857b..000000000 --- a/applications/benchmarking/DynaCell/vae_3d/modules/encoder.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch -import torch.nn as nn - -from typing import Tuple -from .blocks import DownEncoderBlock3D, UNetMidBlock3D - -class Encoder(nn.Module): - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - num_down_blocks: int = 2, - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - double_z: bool = True, - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = nn.Conv3d( - in_channels, - block_out_channels[0], - kernel_size=3, - stride=1, - padding=1, - padding_mode='reflect' - ) - - self.down_blocks = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i in range(num_down_blocks): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = DownEncoderBlock3D( - in_channels=input_channel, - out_channels=output_channel, - dropout=0.0, - num_layers=self.layers_per_block, - add_downsample=not is_final_block, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - downsample_padding=0, - ) - - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock3D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_groups=norm_num_groups, - ) - - # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = nn.Conv3d(block_out_channels[-1], conv_out_channels, 3, padding=1) - - self.gradient_checkpointing = False - - def forward(self, sample: torch.Tensor) -> torch.Tensor: - sample = self.conv_in(sample) - - # down - for down_block in self.down_blocks: - sample = down_block(sample) - - # middle - sample = self.mid_block(sample) - - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return sample \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/vae_3d/modules/utils.py b/applications/benchmarking/DynaCell/vae_3d/modules/utils.py deleted file mode 100644 index 0c2a8185e..000000000 --- a/applications/benchmarking/DynaCell/vae_3d/modules/utils.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from dataclasses import dataclass -from transformers.utils import ModelOutput - -@dataclass -class VAEOutput(ModelOutput): - loss: torch.FloatTensor = None - recon_loss: torch.FloatTensor = None - kl_loss: torch.FloatTensor = None \ No newline at end of file diff --git a/applications/benchmarking/DynaCell/vae_3d/vae_3d_config.py b/applications/benchmarking/DynaCell/vae_3d/vae_3d_config.py deleted file mode 100644 index d30883ed2..000000000 --- a/applications/benchmarking/DynaCell/vae_3d/vae_3d_config.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -from dataclasses import dataclass -from transformers import PretrainedConfig -from dataclasses import field - -@dataclass -class VAE3DConfig(PretrainedConfig): - model_type: str = 'vae' - - # Model parameters - in_channels: int = 1 - out_channels: int 
= 1
-    num_down_blocks: int = 5
-    latent_channels: int = 2
-    vae_block_out_channels: list = field(default_factory=lambda: [32, 64, 128, 256, 256])
-    loadcheck_path: str = ""
\ No newline at end of file
diff --git a/applications/benchmarking/DynaCell/vae_3d/vae_3d_model.py b/applications/benchmarking/DynaCell/vae_3d/vae_3d_model.py
deleted file mode 100644
index d01138b1e..000000000
--- a/applications/benchmarking/DynaCell/vae_3d/vae_3d_model.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import os
-import torch
-import torch.nn as nn
-from .modules.autoencoders import Autoencoder3DKL
-from .vae_3d_config import VAE3DConfig
-from transformers import PreTrainedModel
-from .modules.utils import VAEOutput
-
-
-class VAE3DModel(PreTrainedModel):
-    config_class = VAE3DConfig
-
-    def __init__(self, config: VAE3DConfig):
-        super().__init__(config)
-        self.config = config
-
-        self.num_down_blocks = config.num_down_blocks
-        self.num_up_blocks = self.num_down_blocks
-
-        # Initialize Autoencoder3DKL
-        self.vae = Autoencoder3DKL(
-            in_channels=config.in_channels,
-            out_channels=config.out_channels,
-            num_down_blocks=self.num_down_blocks,
-            num_up_blocks=self.num_up_blocks,
-            block_out_channels=config.vae_block_out_channels,
-            latent_channels=config.latent_channels,
-        )
-
-        self.load_pretrained_weights(checkpoint_path=config.loadcheck_path)
-
-    def load_pretrained_weights(self, checkpoint_path):
-        """
-        Load pretrained weights from a given state_dict.
-        """
-
-        if os.path.splitext(checkpoint_path)[1] == '.safetensors':
-            from safetensors.torch import load_file
-            checkpoints_state = load_file(checkpoint_path)
-        else:
-            checkpoints_state = torch.load(checkpoint_path, map_location="cpu")
-
-        if "model" in checkpoints_state:
-            checkpoints_state = checkpoints_state["model"]
-        elif "module" in checkpoints_state:
-            checkpoints_state = checkpoints_state["module"]
-
-        IncompatibleKeys = self.load_state_dict(checkpoints_state, strict=True)
-        IncompatibleKeys = IncompatibleKeys._asdict()
-
-        missing_keys = []
-        for keys in IncompatibleKeys["missing_keys"]:
-            if keys.find("dummy") == -1:
-                missing_keys.append(keys)
-
-        unexpected_keys = []
-        for keys in IncompatibleKeys["unexpected_keys"]:
-            if keys.find("dummy") == -1:
-                unexpected_keys.append(keys)
-
-        if len(missing_keys) > 0:
-            print(
-                "Missing keys in {}: {}".format(
-                    checkpoint_path,
-                    missing_keys,
-                )
-            )
-
-        if len(unexpected_keys) > 0:
-            print(
-                "Unexpected keys {}: {}".format(
-                    checkpoint_path,
-                    unexpected_keys,
-                )
-            )
-
-    def encode(self, x):
-        """Encodes input into latent space."""
-        return self.vae.encode(x).latent_dist
-
-    def decode(self, latents):
-        """Decodes latent space into reconstructed input."""
-        return self.vae.decode(latents)
-
-    def forward(self, batched_data):
-        x = batched_data['data']
-
-        """Forward pass through the VAE."""
-        latent_dist = self.encode(x)
-        latents = latent_dist.sample()
-        recon_x = self.decode(latents).sample
-
-        total_loss, recon_loss, kl_loss = self.compute_loss(x, recon_x, latent_dist)
-
-        return VAEOutput(total_loss, recon_loss, kl_loss)
-
-    def compute_loss(self, x, recon_x, latent_dist):
-        """Compute reconstruction and KL divergence loss."""
-        if self.config.vae_recon_loss_type == 'mse':
-            recon_loss = nn.MSELoss()(recon_x, x)
-        elif self.config.vae_recon_loss_type == 'poisson':
-            x = x.clip(-1, 1)
-            recon_x = recon_x.clip(-1, 1)
-            peak = self.config.poisson_peak if hasattr(self.config, 'poisson_peak') else 1.0
-            target = (x + 1) / 2.0 * peak
-            lam = (recon_x + 1) / 2.0 * peak
-            recon_loss = torch.mean(lam - target * torch.log(lam + 1e-8))
-
-        kl_loss = -0.5 * torch.mean(1 + latent_dist.logvar - latent_dist.mean.pow(2) - latent_dist.logvar.exp())
-        total_loss = self.config.recon_loss_coeff * recon_loss + self.config.kl_loss_coeff * kl_loss
-        return total_loss, recon_loss, kl_loss
-
-    def sample(self, num_samples=1, latent_size=32, device="cpu"):
-        """
-        Generate samples from the latent space.
-
-        Args:
-            num_samples (int): Number of samples to generate.
-            device (str): Device to perform sampling on.
-
-        Returns:
-            torch.Tensor: Generated images.
-        """
-        # Sample from a standard normal distribution in latent space
-        latents = torch.randn((num_samples, self.config.latent_channels, latent_size, latent_size, latent_size), device=device) # Shape matches latent dimensions
-
-        # Decode latents to generate images
-        with torch.no_grad():
-            generated_images = self.decode(latents).sample
-
-        return generated_images
-
-    def reconstruct(self, x):
-        latent_dist = self.encode(x)
-        latents = latent_dist.sample() # Reparameterization trick
-        recon_x = self.decode(latents).sample
-
-        return recon_x
diff --git a/applications/benchmarking/DynaCell/vae_3d/vae_3d_model_ts.py b/applications/benchmarking/DynaCell/vae_3d/vae_3d_model_ts.py
deleted file mode 100644
index 92a53ac24..000000000
--- a/applications/benchmarking/DynaCell/vae_3d/vae_3d_model_ts.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import os
-import torch
-import torch.nn as nn
-from .modules.autoencoders_ts import Autoencoder3DKL
-from .vae_3d_config import VAE3DConfig
-
-
-class VAE3DModel(nn.Module):
-    def __init__(self, config: VAE3DConfig):
-        super().__init__()
-        self.config = config
-
-        self.num_down_blocks = config.num_down_blocks
-        self.num_up_blocks = self.num_down_blocks
-
-        # Initialize Autoencoder3DKL
-        self.vae = Autoencoder3DKL(
-            in_channels=config.in_channels,
-            out_channels=config.out_channels,
-            num_down_blocks=self.num_down_blocks,
-            num_up_blocks=self.num_up_blocks,
-            block_out_channels=config.vae_block_out_channels,
-            latent_channels=config.latent_channels,
-        )
-
-        self.load_pretrained_weights(checkpoint_path=config.loadcheck_path)
-
-    def load_pretrained_weights(self, checkpoint_path):
-        """
-        Load pretrained weights from a given state_dict.
- """ - - if os.path.splitext(checkpoint_path)[1] == '.safetensors': - from safetensors.torch import load_file - checkpoints_state = load_file(checkpoint_path) - else: - checkpoints_state = torch.load(checkpoint_path, map_location="cpu") - - if "model" in checkpoints_state: - checkpoints_state = checkpoints_state["model"] - elif "module" in checkpoints_state: - checkpoints_state = checkpoints_state["module"] - - IncompatibleKeys = self.load_state_dict(checkpoints_state, strict=True) - IncompatibleKeys = IncompatibleKeys._asdict() - - missing_keys = [] - for keys in IncompatibleKeys["missing_keys"]: - if keys.find("dummy") == -1: - missing_keys.append(keys) - - unexpected_keys = [] - for keys in IncompatibleKeys["unexpected_keys"]: - if keys.find("dummy") == -1: - unexpected_keys.append(keys) - - if len(missing_keys) > 0: - print( - "Missing keys in {}: {}".format( - checkpoint_path, - missing_keys, - ) - ) - - if len(unexpected_keys) > 0: - print( - "Unexpected keys {}: {}".format( - checkpoint_path, - unexpected_keys, - ) - ) - - def encode(self, x): - """Encodes input into latent space.""" - return self.vae.encode(x) - - def decode(self, latents): - """Decodes latent space into reconstructed input.""" - return self.vae.decode(latents) - - def forward(self, x): - # placeholder forward - return x - - def reconstruct(self, x): - mean, logvar = self.encode(x) - latents = mean + torch.exp(0.5 * logvar) * torch.randn_like(logvar) # Reparameterization trick - recon_x = self.decode(latents) - - return recon_x diff --git a/applications/dynacell/fid.py b/applications/dynacell/fid.py deleted file mode 100644 index d53392dc4..000000000 --- a/applications/dynacell/fid.py +++ /dev/null @@ -1,207 +0,0 @@ -# -*- coding: utf-8 -*- -import argparse -from pathlib import Path - -import torch -from tqdm import tqdm -from iohub.ngff import open_ome_zarr -from torch import Tensor - -from vae_3d.vae_3d_config import VAE3DConfig -from vae_3d.vae_3d_model import VAE3DModel - -# ----------------------------------------------------------------------------- # -# Helper functions # -# ----------------------------------------------------------------------------- # - -def read_zarr(zarr_path: str): - plate = open_ome_zarr(zarr_path, mode="r") - return [pos for _, pos in plate.positions()] - -def normalise(volume: torch.Tensor) -> torch.Tensor: - """Per-sample min max → [-1,1]. Shape: (D, H, W) or (B, D, H, W).""" - v_min = volume.amin(dim=(-3, -2, -1), keepdim=True) - v_max = volume.amax(dim=(-3, -2, -1), keepdim=True) - volume = (volume - v_min) / (v_max - v_min + 1e-6) # → [0,1] - return volume * 2.0 - 1.0 # → [-1,1] - -@torch.jit.script_if_tracing -def sqrtm(sigma: Tensor) -> Tensor: - r"""Returns the square root of a positive semi-definite matrix. - - .. math:: \sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T - - where :math:`Q \Lambda Q^T` is the eigendecomposition of :math:`\Sigma`. - - Args: - sigma: A positive semi-definite matrix, :math:`(*, D, D)`. - - Example: - >>> V = torch.randn(4, 4, dtype=torch.double) - >>> A = V @ V.T - >>> B = sqrtm(A @ A) - >>> torch.allclose(A, B) - True - """ - - L, Q = torch.linalg.eigh(sigma) - L = L.relu().sqrt() - - return Q @ (L[..., None] * Q.mT) - -@torch.jit.script_if_tracing -def frechet_distance( - mu_x: Tensor, - sigma_x: Tensor, - mu_y: Tensor, - sigma_y: Tensor, -) -> Tensor: - r"""Returns the Fréchet distance between two multivariate Gaussian distributions. - - .. 
math:: d^2 = \left\| \mu_x - \mu_y \right\|_2^2 + - \operatorname{tr} \left( \Sigma_x + \Sigma_y - 2 \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right) - - Wikipedia: - https://wikipedia.org/wiki/Frechet_distance - - Args: - mu_x: The mean :math:`\mu_x` of the first distribution, :math:`(*, D)`. - sigma_x: The covariance :math:`\Sigma_x` of the first distribution, :math:`(*, D, D)`. - mu_y: The mean :math:`\mu_y` of the second distribution, :math:`(*, D)`. - sigma_y: The covariance :math:`\Sigma_y` of the second distribution, :math:`(*, D, D)`. - - Example: - >>> mu_x = torch.arange(3).float() - >>> sigma_x = torch.eye(3) - >>> mu_y = 2 * mu_x + 1 - >>> sigma_y = 2 * sigma_x + 1 - >>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y) - tensor(15.8710) - """ - - sigma_y_12 = sqrtm(sigma_y) - - a = (mu_x - mu_y).square().sum(dim=-1) - b = sigma_x.trace() + sigma_y.trace() - c = sqrtm(sigma_y_12 @ sigma_x @ sigma_y_12).trace() - - return a + b - 2 * c - -@torch.no_grad() -def fid_from_features(f1, f2, eps=1e-6): - mu1, sigma1 = f1.mean(0), torch.cov(f1.T) - mu2, sigma2 = f2.mean(0), torch.cov(f2.T) - - eye = torch.eye(sigma1.size(0), device=sigma1.device, dtype=sigma1.dtype) - sigma1 = sigma1 + eps * eye - sigma2 = sigma2 + eps * eye - - return frechet_distance(mu1, sigma1, mu2, sigma2).clamp_min_(0).item() - -@torch.no_grad() -def encode_fovs( - fovs, - vae, - channel_name: str, - device: str = "cuda", - batch_size: int = 4, - input_spatial_size: tuple = (32, 512, 512), -): - """ - For each FOV pair: - • take all T time-frames (shape: T, D, H, W) - • normalise to [-1, 1] - • feed through VAE in chunks of ≤ batch_size frames - • average the resulting T latent vectors → one embedding / FOV - Returns - emb: (N, latent_dim) tensors - """ - emb = [] - - for pos in tqdm(fovs, desc="Encoding FOVs"): - # ---------------- load & normalise ---------------- # - v = torch.as_tensor( - pos.data[:, pos.get_channel_index(channel_name)], - dtype=torch.float32, device=device, - ) # (T, D, H, W) - - v = normalise(v) # still (T, D, H, W) - - # ---------------- chunked VAE inference ----------- # - for t0 in range(0, v.shape[0], batch_size): - slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) - - # resize to input spatial size - slice = torch.nn.functional.interpolate( - slice, size=input_spatial_size, mode="trilinear", align_corners=False, - ) # (b, 1, D, H, W) - - feat = vae.encode(slice).mean # mean, - feat = feat.flatten(start_dim=1) # (b, latent_dim) - emb.append(feat) - - return torch.cat(emb, 0) - -# ----------------------------------------------------------------------------- # -# Main # -# ----------------------------------------------------------------------------- # - -def build_argparser() -> argparse.ArgumentParser: - p = argparse.ArgumentParser(add_help=False) - p.add_argument("--data_path1", type=Path, required=True) - p.add_argument("--data_path2", type=Path, required=True) - p.add_argument("--channel_name", type=str, default=None) - p.add_argument("--channel_name1", type=str, default=None) - p.add_argument("--channel_name2", type=str, default=None) - p.add_argument("--input_spatial_size", type=str, default="32,512,512", - help="Input spatial size for the VAE, e.g. 
'32,512,512'.") - p.add_argument("--loadcheck_path", type=Path, default=None, - help="Path to the VAE model checkpoint for loading.") - p.add_argument("--batch_size", type=int, default=4) - p.add_argument("--device", type=str, default="cuda") - p.add_argument("--max_fov", type=int, default=None, - help="Limit number of FOV pairs (for quick tests).") - return p - -def main(args) -> None: - device = args.device - - # ----------------- VAE ----------------- # - model_cfg = VAE3DConfig() - model_cfg.loadcheck_path = args.loadcheck_path - vae = VAE3DModel(config=model_cfg).to(device).eval() - - # ----------------- FOV list ------------ # - fovs1, fovs2 = read_zarr(args.data_path1), read_zarr(args.data_path2) - if args.max_fov: - fovs1 = fovs1[:args.max_fov] - fovs2 = fovs2[:args.max_fov] - - # ----------------- Embeddings ----------- # - input_spatial_size = [int(dim) for dim in args.input_spatial_size.split(",")] - - if args.channel_name is not None: - args.channel_name1 = args.channel_name2 = args.channel_name - - emb1 = encode_fovs( - fovs1, vae, - args.channel_name1, - device, args.batch_size, - input_spatial_size, - ) - - emb2 = encode_fovs( - fovs2, vae, - args.channel_name2, - device, args.batch_size, - input_spatial_size, - ) - - # ----------------- FID ------------------ # - fid_val = fid_from_features(emb1, emb2) - print(f"\nFID: {fid_val:.6f}") - -if __name__ == "__main__": - parser = build_argparser() - args = parser.parse_args() - main(args) \ No newline at end of file diff --git a/applications/dynacell/fid_ts.py b/applications/dynacell/fid_ts.py deleted file mode 100644 index 8c8f272cf..000000000 --- a/applications/dynacell/fid_ts.py +++ /dev/null @@ -1,203 +0,0 @@ -# -*- coding: utf-8 -*- -import argparse -from pathlib import Path - -import torch -from tqdm import tqdm -from iohub.ngff import open_ome_zarr -from torch import Tensor - -# ----------------------------------------------------------------------------- # -# Helper functions # -# ----------------------------------------------------------------------------- # - -def read_zarr(zarr_path: str): - plate = open_ome_zarr(zarr_path, mode="r") - return [pos for _, pos in plate.positions()] - -def normalise(volume: torch.Tensor) -> torch.Tensor: - """Per-sample min max → [-1,1]. Shape: (D, H, W) or (B, D, H, W).""" - v_min = volume.amin(dim=(-3, -2, -1), keepdim=True) - v_max = volume.amax(dim=(-3, -2, -1), keepdim=True) - volume = (volume - v_min) / (v_max - v_min + 1e-6) # → [0,1] - return volume * 2.0 - 1.0 # → [-1,1] - -@torch.jit.script_if_tracing -def sqrtm(sigma: Tensor) -> Tensor: - r"""Returns the square root of a positive semi-definite matrix. - - .. math:: \sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T - - where :math:`Q \Lambda Q^T` is the eigendecomposition of :math:`\Sigma`. - - Args: - sigma: A positive semi-definite matrix, :math:`(*, D, D)`. - - Example: - >>> V = torch.randn(4, 4, dtype=torch.double) - >>> A = V @ V.T - >>> B = sqrtm(A @ A) - >>> torch.allclose(A, B) - True - """ - - L, Q = torch.linalg.eigh(sigma) - L = L.relu().sqrt() - - return Q @ (L[..., None] * Q.mT) - -@torch.jit.script_if_tracing -def frechet_distance( - mu_x: Tensor, - sigma_x: Tensor, - mu_y: Tensor, - sigma_y: Tensor, -) -> Tensor: - r"""Returns the Fréchet distance between two multivariate Gaussian distributions. - - .. 
math:: d^2 = \left\| \mu_x - \mu_y \right\|_2^2 + - \operatorname{tr} \left( \Sigma_x + \Sigma_y - 2 \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right) - - Wikipedia: - https://wikipedia.org/wiki/Frechet_distance - - Args: - mu_x: The mean :math:`\mu_x` of the first distribution, :math:`(*, D)`. - sigma_x: The covariance :math:`\Sigma_x` of the first distribution, :math:`(*, D, D)`. - mu_y: The mean :math:`\mu_y` of the second distribution, :math:`(*, D)`. - sigma_y: The covariance :math:`\Sigma_y` of the second distribution, :math:`(*, D, D)`. - - Example: - >>> mu_x = torch.arange(3).float() - >>> sigma_x = torch.eye(3) - >>> mu_y = 2 * mu_x + 1 - >>> sigma_y = 2 * sigma_x + 1 - >>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y) - tensor(15.8710) - """ - - sigma_y_12 = sqrtm(sigma_y) - - a = (mu_x - mu_y).square().sum(dim=-1) - b = sigma_x.trace() + sigma_y.trace() - c = sqrtm(sigma_y_12 @ sigma_x @ sigma_y_12).trace() - - return a + b - 2 * c - -@torch.no_grad() -def fid_from_features(f1, f2, eps=1e-6): - mu1, sigma1 = f1.mean(0), torch.cov(f1.T) - mu2, sigma2 = f2.mean(0), torch.cov(f2.T) - - eye = torch.eye(sigma1.size(0), device=sigma1.device, dtype=sigma1.dtype) - sigma1 = sigma1 + eps * eye - sigma2 = sigma2 + eps * eye - - return frechet_distance(mu1, sigma1, mu2, sigma2).clamp_min_(0).item() - -@torch.no_grad() -def encode_fovs( - fovs, - vae, - channel_name: str, - device: str = "cuda", - batch_size: int = 4, - input_spatial_size: tuple = (32, 512, 512), -): - """ - For each FOV pair: - • take all T time-frames (shape: T, D, H, W) - • normalise to [-1, 1] - • feed through VAE in chunks of ≤ batch_size frames - • average the resulting T latent vectors → one embedding / FOV - Returns - emb: (N, latent_dim) tensors - """ - emb = [] - - for pos in tqdm(fovs, desc="Encoding FOVs"): - # ---------------- load & normalise ---------------- # - v = torch.as_tensor( - pos.data[:, pos.get_channel_index(channel_name)], - dtype=torch.float32, device=device, - ) # (T, D, H, W) - - v = normalise(v) # still (T, D, H, W) - - # ---------------- chunked VAE inference ----------- # - for t0 in range(0, v.shape[0], batch_size): - slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) - - # resize to input spatial size - slice = torch.nn.functional.interpolate( - slice, size=input_spatial_size, mode="trilinear", align_corners=False, - ) # (b, 1, D, H, W) - - feat = vae.encode(slice)[0] # mean, - feat = feat.flatten(start_dim=1) # (b, latent_dim) - emb.append(feat) - - return torch.cat(emb, 0) - -# ----------------------------------------------------------------------------- # -# Main # -# ----------------------------------------------------------------------------- # - -def build_argparser() -> argparse.ArgumentParser: - p = argparse.ArgumentParser(add_help=False) - p.add_argument("--data_path1", type=Path, required=True) - p.add_argument("--data_path2", type=Path, required=True) - p.add_argument("--channel_name", type=str, default=None) - p.add_argument("--channel_name1", type=str, default=None) - p.add_argument("--channel_name2", type=str, default=None) - p.add_argument("--input_spatial_size", type=str, default="32,512,512", - help="Input spatial size for the VAE, e.g. 
'32,512,512'.") - p.add_argument("--loadcheck_path", type=Path, default=None, - help="Path to the VAE model checkpoint for loading.") - p.add_argument("--batch_size", type=int, default=4) - p.add_argument("--device", type=str, default="cuda") - p.add_argument("--max_fov", type=int, default=None, - help="Limit number of FOV pairs (for quick tests).") - return p - -def main(args) -> None: - device = args.device - - # ----------------- VAE ----------------- # - vae = torch.jit.load(args.loadcheck_path).to(device) - vae.eval() - - # ----------------- FOV list ------------ # - fovs1, fovs2 = read_zarr(args.data_path1), read_zarr(args.data_path2) - if args.max_fov: - fovs1 = fovs1[:args.max_fov] - fovs2 = fovs2[:args.max_fov] - - # ----------------- Embeddings ----------- # - input_spatial_size = [int(dim) for dim in args.input_spatial_size.split(",")] - - if args.channel_name is not None: - args.channel_name1 = args.channel_name2 = args.channel_name - - emb1 = encode_fovs( - fovs1, vae, - args.channel_name1, - device, args.batch_size, - input_spatial_size, - ) - - emb2 = encode_fovs( - fovs2, vae, - args.channel_name2, - device, args.batch_size, - input_spatial_size, - ) - - # ----------------- FID ------------------ # - fid_val = fid_from_features(emb1, emb2) - print(f"\nFID: {fid_val:.6f}") - -if __name__ == "__main__": - parser = build_argparser() - args = parser.parse_args() - main(args) \ No newline at end of file diff --git a/applications/dynacell/test_fid.sh b/applications/dynacell/test_fid.sh deleted file mode 100644 index 84c330e5d..000000000 --- a/applications/dynacell/test_fid.sh +++ /dev/null @@ -1,17 +0,0 @@ -python fid.py \ - --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --channel_name1 Nuclei-prediction \ - --channel_name2 Organelle \ - --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae.pth \ - --batch_size 4 \ - --device cuda - -python fid.py \ - --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --channel_name1 Membrane-prediction \ - --channel_name2 Membrane \ - --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae.pth \ - --batch_size 4 \ - --device cuda \ No newline at end of file diff --git a/applications/dynacell/test_fid_ts.sh b/applications/dynacell/test_fid_ts.sh deleted file mode 100644 index 012cb23b1..000000000 --- a/applications/dynacell/test_fid_ts.sh +++ /dev/null @@ -1,17 +0,0 @@ -python fid_ts.py \ - --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --channel_name1 Nuclei-prediction \ - --channel_name2 Organelle \ - --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae_ts.pt \ - --batch_size 4 \ - --device cuda - -python fid_ts.py \ - --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ - --channel_name1 Membrane-prediction \ - --channel_name2 Membrane \ - --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae_ts.pt \ - --batch_size 4 \ - --device cuda \ No newline at end of file diff --git 
a/applications/dynacell/vae_3d/__init__.py b/applications/dynacell/vae_3d/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/applications/dynacell/vae_3d/modules/__init__.py b/applications/dynacell/vae_3d/modules/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/applications/dynacell/vae_3d/modules/autoencoders.py b/applications/dynacell/vae_3d/modules/autoencoders.py deleted file mode 100644 index 9c3fb927a..000000000 --- a/applications/dynacell/vae_3d/modules/autoencoders.py +++ /dev/null @@ -1,160 +0,0 @@ -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn - -from .decoder import Decoder -from .encoder import Encoder - -from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders.single_file_model import FromOriginalModelMixin -from diffusers.utils.accelerate_utils import apply_forward_hook -from diffusers.models.modeling_outputs import AutoencoderKLOutput -from diffusers.models.modeling_utils import ModelMixin - - -class Autoencoder3DKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - num_down_blocks: int = 2, - num_up_blocks: int = 2, - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 4, - norm_num_groups: int = 32, - use_quant_conv: bool = True, - use_post_quant_conv: bool = True, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder( - in_channels=in_channels, - out_channels=latent_channels, - num_down_blocks=num_down_blocks, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=True, - ) - - # pass init params to Decoder - self.decoder = Decoder( - in_channels=latent_channels, - out_channels=out_channels, - num_up_blocks=num_up_blocks, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, - act_fn=act_fn, - ) - - self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None - self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (Encoder, Decoder)): - module.gradient_checkpointing = value - - def _encode(self, x: torch.Tensor) -> torch.Tensor: - - enc = self.encoder(x) - if self.quant_conv is not None: - enc = self.quant_conv(enc) - - return enc - - @apply_forward_hook - def encode( - self, x: torch.Tensor, return_dict: bool = True - ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: - """ - Encode a batch of images into latents. - - Args: - x (`torch.Tensor`): Input batch of images. - return_dict (`bool`, *optional*, defaults to `True`): - Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. - - Returns: - The latent representations of the encoded images. If `return_dict` is True, a - [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
- """ - h = self._encode(x) - posterior = DiagonalGaussianDistribution(h) - - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) - - def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - - if self.post_quant_conv is not None: - z = self.post_quant_conv(z) - - dec = self.decoder(z) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - @apply_forward_hook - def decode( - self, z: torch.FloatTensor, return_dict: bool = True, generator=None - ) -> Union[DecoderOutput, torch.FloatTensor]: - """ - Decode a batch of images. - - Args: - z (`torch.Tensor`): Input batch of latent vectors. - return_dict (`bool`, *optional*, defaults to `True`): - Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. - - Returns: - [`~models.vae.DecoderOutput`] or `tuple`: - If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is - returned. - - """ - decoded = self._decode(z).sample - - if not return_dict: - return (decoded,) - - return DecoderOutput(sample=decoded) - - def forward( - self, - sample: torch.Tensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.Tensor]: - r""" - Args: - sample (`torch.Tensor`): Input sample. - sample_posterior (`bool`, *optional*, defaults to `False`): - Whether to sample from the posterior. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - """ - x = sample - posterior = self.encode(x).latent_dist - if sample_posterior: - z = posterior.sample(generator=generator) - else: - z = posterior.mode() - dec = self.decode(z).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) diff --git a/applications/dynacell/vae_3d/modules/autoencoders_ts.py b/applications/dynacell/vae_3d/modules/autoencoders_ts.py deleted file mode 100644 index 9438ddd6d..000000000 --- a/applications/dynacell/vae_3d/modules/autoencoders_ts.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import Tuple - -import torch -import torch.nn as nn - -from .decoder import Decoder -from .encoder import Encoder - - -class Autoencoder3DKL(nn.Module): - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - num_down_blocks: int = 2, - num_up_blocks: int = 2, - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 4, - norm_num_groups: int = 32, - use_quant_conv: bool = True, - use_post_quant_conv: bool = True, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder( - in_channels=in_channels, - out_channels=latent_channels, - num_down_blocks=num_down_blocks, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=True, - ) - - # pass init params to Decoder - self.decoder = Decoder( - in_channels=latent_channels, - out_channels=out_channels, - num_up_blocks=num_up_blocks, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, - act_fn=act_fn, - ) - - self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None - self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None - - def _encode(self, x: torch.Tensor) -> 
torch.Tensor: - - enc = self.encoder(x) - if self.quant_conv is not None: - enc = self.quant_conv(enc) - - return enc - - def encode(self, x: torch.Tensor): - h = self._encode(x) - mean, logvar = torch.chunk(h, 2, dim=1) - - return mean, logvar - - def _decode(self, z: torch.Tensor): - - if self.post_quant_conv is not None: - z = self.post_quant_conv(z) - dec = self.decoder(z) - - return dec - - def decode(self, z: torch.FloatTensor): - decoded = self._decode(z) - - return decoded - - def forward(self, x): - # placeholder forward - return x \ No newline at end of file diff --git a/applications/dynacell/vae_3d/modules/blocks.py b/applications/dynacell/vae_3d/modules/blocks.py deleted file mode 100644 index 569d66e9a..000000000 --- a/applications/dynacell/vae_3d/modules/blocks.py +++ /dev/null @@ -1,353 +0,0 @@ -from typing import Optional - -import torch -import torch.nn.functional as F -from torch import nn -from diffusers.models.normalization import RMSNorm -from diffusers.models.activations import get_activation - -class UpDecoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - - resnets.append( - ResnetBlock3D( - in_channels=in_channels, - out_channels=out_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels=out_channels)]) - else: - self.upsamplers = None - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class DownEncoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - downsample_padding: int = 1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock3D( - in_channels=in_channels, - out_channels=out_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample3D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states - - -class ResnetBlock3D(nn.Module): - def __init__( - self, - *, - in_channels: int, - out_channels: Optional[int] = None, - conv_shortcut: bool = 
False, - dropout: float = 0.0, - groups: int = 32, - groups_out: Optional[int] = None, - eps: float = 1e-6, - non_linearity: str = "swish", - output_scale_factor: float = 1.0, - use_in_shortcut: Optional[bool] = None, - conv_shortcut_bias: bool = True, - ): - super().__init__() - - self.pre_norm = True - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - - self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - - self.dropout = torch.nn.Dropout(dropout) - conv_3d_out_channels = out_channels - self.conv2 = nn.Conv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1, padding=1) - - self.nonlinearity = get_activation(non_linearity) - self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = nn.Conv3d( - in_channels, - conv_3d_out_channels, - kernel_size=1, - stride=1, - padding=0, - bias=conv_shortcut_bias, - ) - - def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: - hidden_states = input_tensor - - hidden_states = self.norm1(hidden_states) - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.conv1(hidden_states) - hidden_states = self.norm2(hidden_states) - - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - - return output_tensor - -class Downsample3D(nn.Module): - """A 3D downsampling layer with an optional convolution. - """ - - def __init__( - self, - channels: int, - use_conv: bool = False, - out_channels: Optional[int] = None, - padding: int = 1, - kernel_size: int = 3, - norm_type: Optional[str] = None, - eps: Optional[float] = 1e-5, - elementwise_affine: Optional[bool] = True, - bias: bool = True, - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - self.kernel_size = kernel_size - stride = 2 # Downsampling stride is fixed to 2 - - # Initialize normalization - if norm_type == "ln_norm": - self.norm = nn.LayerNorm(self.channels, eps=eps, elementwise_affine=elementwise_affine) - elif norm_type == "rms_norm": - self.norm = RMSNorm(channels, eps, elementwise_affine) - elif norm_type is None: - self.norm = None - else: - raise ValueError(f"Unknown norm_type: {norm_type}") - - # Choose between convolutional or pooling downsampling - if use_conv: - self.conv = nn.Conv3d( - self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias - ) - else: - assert self.channels == self.out_channels, "out_channels must match channels when using pooling" - self.conv = nn.AvgPool3d(kernel_size=stride, stride=stride) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the downsampling block. 
- - Args: - hidden_states (torch.Tensor): Input feature map of shape (B, C, D, H, W). - - Returns: - torch.Tensor: Downsampled feature map. - """ - assert hidden_states.shape[1] == self.channels, \ - f"Expected input channels {self.channels}, but got {hidden_states.shape[1]}" - - # Apply normalization if specified - if self.norm is not None: - # LayerNorm expects (B, C, D, H, W), but normalizes over C. Permute to (B, D, H, W, C) - hidden_states = self.norm(hidden_states.permute(0, 2, 3, 4, 1)) - hidden_states = hidden_states.permute(0, 4, 1, 2, 3) # Back to (B, C, D, H, W) - - # Apply padding if using conv downsampling and no padding was specified - if self.use_conv and self.padding == 0: - pad = (0, 1, 0, 1, 0, 1) # Padding for 3D tensor: (D, H, W) - hidden_states = F.pad(hidden_states, pad, mode="constant", value=0.0) - - # Apply downsampling - hidden_states = self.conv(hidden_states) - - return hidden_states - -class Upsample3D(nn.Module): - """A 3D upsampling layer with a convolution. - """ - - def __init__( - self, - channels: int, - out_channels: Optional[int] = None, - kernel_size: Optional[int] = None, - padding=1, - norm_type=None, - eps=None, - elementwise_affine=None, - bias=True, - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - - if norm_type == "ln_norm": - self.norm = nn.LayerNorm(self.channels, eps=eps, elementwise_affine=elementwise_affine) - elif norm_type == "rms_norm": - self.norm = RMSNorm(channels, eps, elementwise_affine) - elif norm_type is None: - self.norm = None - else: - raise ValueError(f"unknown norm_type: {norm_type}") - - conv = None - if kernel_size is None: - kernel_size = 3 - conv = nn.Conv3d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - self.conv = conv - - def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None) -> torch.Tensor: - assert hidden_states.shape[1] == self.channels, f"Expected {self.channels} channels, got {hidden_states.shape[1]}" - - # Apply normalization if specified - if self.norm is not None: - # LayerNorm expects (B, C, D, H, W), but normalizes over C. Permute to (B, D, H, W, C) - hidden_states = self.norm(hidden_states.permute(0, 2, 3, 4, 1)) - hidden_states = hidden_states.permute(0, 4, 1, 2, 3) # Back to (B, C, D, H, W) - - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - if output_size is None: - hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") - else: - hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") - - hidden_states = self.conv(hidden_states) - - return hidden_states - -class UNetMidBlock3D(nn.Module): - """ - A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. 
- """ - - def __init__( - self, - in_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - attn_groups: Optional[int] = None, - output_scale_factor: float = 1.0, - ): - super().__init__() - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - if attn_groups is None: - attn_groups = resnet_groups - - self.resnets = nn.ModuleList([ - ResnetBlock3D( - in_channels=in_channels, - out_channels=in_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - ) - for _ in range(num_layers + 1) - ]) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - - return hidden_states \ No newline at end of file diff --git a/applications/dynacell/vae_3d/modules/decoder.py b/applications/dynacell/vae_3d/modules/decoder.py deleted file mode 100644 index 19ff8725b..000000000 --- a/applications/dynacell/vae_3d/modules/decoder.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import Tuple - -import torch -import torch.nn as nn - -from .blocks import UNetMidBlock3D, UpDecoderBlock3D -from diffusers.models.attention_processor import SpatialNorm - -class Decoder(nn.Module): - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - num_up_blocks: int = 2, - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - norm_type: str = "group", # group, spatial - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = nn.Conv3d( - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - ) - - self.up_blocks = nn.ModuleList([]) - - temb_channels = in_channels if norm_type == "spatial" else None - - # mid - self.mid_block = UNetMidBlock3D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_groups=norm_num_groups, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i in range(num_up_blocks): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = UpDecoderBlock3D( - in_channels=prev_output_channel, - out_channels=output_channel, - num_layers=self.layers_per_block + 1, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - add_upsample=not is_final_block, - ) - - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_type == "spatial": - self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) - else: - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv3d( - block_out_channels[0], - out_channels, - kernel_size=3, - padding=1, - padding_mode='reflect', - ) - - self.gradient_checkpointing = False - - def forward(self, sample: torch.Tensor) -> torch.Tensor: - sample = self.conv_in(sample) - - # middle - sample = self.mid_block(sample) - - # up - for up_block in self.up_blocks: - sample = up_block(sample) - - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return sample \ No newline at end of file diff 
--git a/applications/dynacell/vae_3d/modules/encoder.py b/applications/dynacell/vae_3d/modules/encoder.py deleted file mode 100644 index bc93f857b..000000000 --- a/applications/dynacell/vae_3d/modules/encoder.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch -import torch.nn as nn - -from typing import Tuple -from .blocks import DownEncoderBlock3D, UNetMidBlock3D - -class Encoder(nn.Module): - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - num_down_blocks: int = 2, - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - double_z: bool = True, - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = nn.Conv3d( - in_channels, - block_out_channels[0], - kernel_size=3, - stride=1, - padding=1, - padding_mode='reflect' - ) - - self.down_blocks = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i in range(num_down_blocks): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = DownEncoderBlock3D( - in_channels=input_channel, - out_channels=output_channel, - dropout=0.0, - num_layers=self.layers_per_block, - add_downsample=not is_final_block, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - downsample_padding=0, - ) - - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock3D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_groups=norm_num_groups, - ) - - # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = nn.Conv3d(block_out_channels[-1], conv_out_channels, 3, padding=1) - - self.gradient_checkpointing = False - - def forward(self, sample: torch.Tensor) -> torch.Tensor: - sample = self.conv_in(sample) - - # down - for down_block in self.down_blocks: - sample = down_block(sample) - - # middle - sample = self.mid_block(sample) - - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return sample \ No newline at end of file diff --git a/applications/dynacell/vae_3d/modules/utils.py b/applications/dynacell/vae_3d/modules/utils.py deleted file mode 100644 index 0c2a8185e..000000000 --- a/applications/dynacell/vae_3d/modules/utils.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from dataclasses import dataclass -from transformers.utils import ModelOutput - -@dataclass -class VAEOutput(ModelOutput): - loss: torch.FloatTensor = None - recon_loss: torch.FloatTensor = None - kl_loss: torch.FloatTensor = None \ No newline at end of file diff --git a/applications/dynacell/vae_3d/vae_3d_config.py b/applications/dynacell/vae_3d/vae_3d_config.py deleted file mode 100644 index d30883ed2..000000000 --- a/applications/dynacell/vae_3d/vae_3d_config.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -from dataclasses import dataclass -from transformers import PretrainedConfig -from dataclasses import field - -@dataclass -class VAE3DConfig(PretrainedConfig): - model_type: str = 'vae' - - # Model parameters - in_channels: int = 1 - out_channels: int = 1 - num_down_blocks: int = 5 - latent_channels: int = 2 - vae_block_out_channels: list = field(default_factory=lambda: [32, 64, 128, 256, 256]) - 
loadcheck_path: str = "" \ No newline at end of file diff --git a/applications/dynacell/vae_3d/vae_3d_model.py b/applications/dynacell/vae_3d/vae_3d_model.py deleted file mode 100644 index d01138b1e..000000000 --- a/applications/dynacell/vae_3d/vae_3d_model.py +++ /dev/null @@ -1,138 +0,0 @@ -import os -import torch -import torch.nn as nn -from .modules.autoencoders import Autoencoder3DKL -from .vae_3d_config import VAE3DConfig -from transformers import PreTrainedModel -from .modules.utils import VAEOutput - - -class VAE3DModel(PreTrainedModel): - config_class = VAE3DConfig - - def __init__(self, config: VAE3DConfig): - super().__init__(config) - self.config = config - - self.num_down_blocks = config.num_down_blocks - self.num_up_blocks = self.num_down_blocks - - # Initialize Autoencoder3DKL - self.vae = Autoencoder3DKL( - in_channels=config.in_channels, - out_channels=config.out_channels, - num_down_blocks=self.num_down_blocks, - num_up_blocks=self.num_up_blocks, - block_out_channels=config.vae_block_out_channels, - latent_channels=config.latent_channels, - ) - - self.load_pretrained_weights(checkpoint_path=config.loadcheck_path) - - def load_pretrained_weights(self, checkpoint_path): - """ - Load pretrained weights from a given state_dict. - """ - - if os.path.splitext(checkpoint_path)[1] == '.safetensors': - from safetensors.torch import load_file - checkpoints_state = load_file(checkpoint_path) - else: - checkpoints_state = torch.load(checkpoint_path, map_location="cpu") - - if "model" in checkpoints_state: - checkpoints_state = checkpoints_state["model"] - elif "module" in checkpoints_state: - checkpoints_state = checkpoints_state["module"] - - IncompatibleKeys = self.load_state_dict(checkpoints_state, strict=True) - IncompatibleKeys = IncompatibleKeys._asdict() - - missing_keys = [] - for keys in IncompatibleKeys["missing_keys"]: - if keys.find("dummy") == -1: - missing_keys.append(keys) - - unexpected_keys = [] - for keys in IncompatibleKeys["unexpected_keys"]: - if keys.find("dummy") == -1: - unexpected_keys.append(keys) - - if len(missing_keys) > 0: - print( - "Missing keys in {}: {}".format( - checkpoint_path, - missing_keys, - ) - ) - - if len(unexpected_keys) > 0: - print( - "Unexpected keys {}: {}".format( - checkpoint_path, - unexpected_keys, - ) - ) - - def encode(self, x): - """Encodes input into latent space.""" - return self.vae.encode(x).latent_dist - - def decode(self, latents): - """Decodes latent space into reconstructed input.""" - return self.vae.decode(latents) - - def forward(self, batched_data): - x = batched_data['data'] - - """Forward pass through the VAE.""" - latent_dist = self.encode(x) - latents = latent_dist.sample() - recon_x = self.decode(latents).sample - - total_loss, recon_loss, kl_loss = self.compute_loss(x, recon_x, latent_dist) - - return VAEOutput(total_loss, recon_loss, kl_loss) - - def compute_loss(self, x, recon_x, latent_dist): - """Compute reconstruction and KL divergence loss.""" - if self.config.vae_recon_loss_type == 'mse': - recon_loss = nn.MSELoss()(recon_x, x) - elif self.config.vae_recon_loss_type == 'poisson': - x = x.clip(-1, 1) - recon_x = recon_x.clip(-1, 1) - peak = self.config.poisson_peak if hasattr(self.config, 'poisson_peak') else 1.0 - target = (x + 1) / 2.0 * peak - lam = (recon_x + 1) / 2.0 * peak - recon_loss = torch.mean(lam - target * torch.log(lam + 1e-8)) - - kl_loss = -0.5 * torch.mean(1 + latent_dist.logvar - latent_dist.mean.pow(2) - latent_dist.logvar.exp()) - total_loss = self.config.recon_loss_coeff * recon_loss 
+ self.config.kl_loss_coeff * kl_loss - return total_loss, recon_loss, kl_loss - - def sample(self, num_samples=1, latent_size=32, device="cpu"): - """ - Generate samples from the latent space. - - Args: - num_samples (int): Number of samples to generate. - device (str): Device to perform sampling on. - - Returns: - torch.Tensor: Generated images. - """ - # Sample from a standard normal distribution in latent space - latents = torch.randn((num_samples, self.config.latent_channels, latent_size, latent_size, latent_size), device=device) # Shape matches latent dimensions - - # Decode latents to generate images - with torch.no_grad(): - generated_images = self.decode(latents).sample - - return generated_images - - def reconstruct(self, x): - latent_dist = self.encode(x) - latents = latent_dist.sample() # Reparameterization trick - recon_x = self.decode(latents).sample - - return recon_x diff --git a/applications/dynacell/vae_3d/vae_3d_model_ts.py b/applications/dynacell/vae_3d/vae_3d_model_ts.py deleted file mode 100644 index 92a53ac24..000000000 --- a/applications/dynacell/vae_3d/vae_3d_model_ts.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import torch -import torch.nn as nn -from .modules.autoencoders_ts import Autoencoder3DKL -from .vae_3d_config import VAE3DConfig - - -class VAE3DModel(nn.Module): - def __init__(self, config: VAE3DConfig): - super().__init__() - self.config = config - - self.num_down_blocks = config.num_down_blocks - self.num_up_blocks = self.num_down_blocks - - # Initialize Autoencoder3DKL - self.vae = Autoencoder3DKL( - in_channels=config.in_channels, - out_channels=config.out_channels, - num_down_blocks=self.num_down_blocks, - num_up_blocks=self.num_up_blocks, - block_out_channels=config.vae_block_out_channels, - latent_channels=config.latent_channels, - ) - - self.load_pretrained_weights(checkpoint_path=config.loadcheck_path) - - def load_pretrained_weights(self, checkpoint_path): - """ - Load pretrained weights from a given state_dict. 
- """ - - if os.path.splitext(checkpoint_path)[1] == '.safetensors': - from safetensors.torch import load_file - checkpoints_state = load_file(checkpoint_path) - else: - checkpoints_state = torch.load(checkpoint_path, map_location="cpu") - - if "model" in checkpoints_state: - checkpoints_state = checkpoints_state["model"] - elif "module" in checkpoints_state: - checkpoints_state = checkpoints_state["module"] - - IncompatibleKeys = self.load_state_dict(checkpoints_state, strict=True) - IncompatibleKeys = IncompatibleKeys._asdict() - - missing_keys = [] - for keys in IncompatibleKeys["missing_keys"]: - if keys.find("dummy") == -1: - missing_keys.append(keys) - - unexpected_keys = [] - for keys in IncompatibleKeys["unexpected_keys"]: - if keys.find("dummy") == -1: - unexpected_keys.append(keys) - - if len(missing_keys) > 0: - print( - "Missing keys in {}: {}".format( - checkpoint_path, - missing_keys, - ) - ) - - if len(unexpected_keys) > 0: - print( - "Unexpected keys {}: {}".format( - checkpoint_path, - unexpected_keys, - ) - ) - - def encode(self, x): - """Encodes input into latent space.""" - return self.vae.encode(x) - - def decode(self, latents): - """Decodes latent space into reconstructed input.""" - return self.vae.decode(latents) - - def forward(self, x): - # placeholder forward - return x - - def reconstruct(self, x): - mean, logvar = self.encode(x) - latents = mean + torch.exp(0.5 * logvar) * torch.randn_like(logvar) # Reparameterization trick - recon_x = self.decode(latents) - - return recon_x From 0098771e06fc9c1b3b143dfba6b0243dc1df5ef8 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 18 Sep 2025 20:17:06 -0700 Subject: [PATCH 10/10] fid on position . --- applications/benchmarking/DynaCell/fid_ts.py | 189 ++++-------------- .../benchmarking/DynaCell/run_fid_ts.sh | 14 +- 2 files changed, 44 insertions(+), 159 deletions(-) diff --git a/applications/benchmarking/DynaCell/fid_ts.py b/applications/benchmarking/DynaCell/fid_ts.py index 4c1f8ccb0..7d8cbbee7 100644 --- a/applications/benchmarking/DynaCell/fid_ts.py +++ b/applications/benchmarking/DynaCell/fid_ts.py @@ -9,10 +9,6 @@ from torch import Tensor from tqdm import tqdm -# ----------------------------------------------------------------------------- # -# Helper functions # -# ----------------------------------------------------------------------------- # - def normalise(volume: torch.Tensor) -> torch.Tensor: """Normalize volume to [-1, 1] range using min-max normalization. @@ -140,69 +136,9 @@ def fid_from_features(f1, f2, eps=1e-6): return frechet_distance(mu1, sigma1, mu2, sigma2).clamp_min_(0).item() -@torch.inference_mode() -def encode_position( - position: Position, - vae: torch.nn.Module, - channel_name: str, - device: str = "cuda", - batch_size: int = 4, - input_spatial_size: tuple = (32, 512, 512), -): - """Encode position data using a variational autoencoder. 
- - - Extract all time-frames with shape (T, D, H, W) - - Normalize to [-1, 1] range - - Process through VAE in batches of ≤ batch_size frames - - Collect all latent vectors from all time points - - Parameters - ---------- - position : Position - Single position object from zarr plate - vae : torch.nn.Module - Pre-trained VAE model for encoding - channel_name : str - Name of the channel to extract from the position - device : str, default="cuda" - Device to run computations on - batch_size : int, default=4 - Number of frames to process simultaneously - input_spatial_size : tuple, default=(32, 512, 512) - Target spatial dimensions for VAE input (D, H, W) - - Returns - ------- - torch.Tensor - Embeddings from all timepoints with shape (N_timepoints, latent_dim) - """ - emb = [] - - # ---------------- load & normalise ---------------- # - v = torch.as_tensor( - position.data[:, position.get_channel_index(channel_name)], - dtype=torch.float32, device=device, - ) # (T, D, H, W) - - v = normalise(v) # still (T, D, H, W) - - # ---------------- chunked VAE inference ----------- # - for t0 in tqdm(range(0, v.shape[0], batch_size), desc="Encoding timepoints"): - slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) - - # resize to input spatial size - slice = torch.nn.functional.interpolate( - slice, size=input_spatial_size, mode="trilinear", align_corners=False, - ) # (b, 1, D, H, W) - - feat = vae.encode(slice)[0] # mean, - feat = feat.flatten(start_dim=1) # (b, latent_dim) - emb.append(feat) - - return torch.cat(emb, 0) @torch.inference_mode() -def encode_position_with_metadata( +def embed_position( position: Position, vae: torch.nn.Module, channel_name: str, @@ -220,8 +156,6 @@ def encode_position_with_metadata( Pre-trained VAE model for encoding channel_name : str Name of the channel to extract from the position - position_name : str - Name/identifier for this position (e.g., 'A/1/0') device : str, default="cuda" Device to run computations on batch_size : int, default=4 @@ -238,7 +172,6 @@ def encode_position_with_metadata( embeddings_list = [] timepoints_list = [] - # ---------------- load & normalise ---------------- # v = torch.as_tensor( position.data[:, position.get_channel_index(channel_name)], dtype=torch.float32, device=device, @@ -246,28 +179,22 @@ def encode_position_with_metadata( v = normalise(v) # still (T, D, H, W) - # ---------------- chunked VAE inference ----------- # timepoint = 0 for t0 in tqdm(range(0, v.shape[0], batch_size), desc=f"Encoding {position_name}/{channel_name}"): - slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) - - # resize to input spatial size - slice = torch.nn.functional.interpolate( - slice, size=input_spatial_size, mode="trilinear", align_corners=False, - ) # (b, 1, D, H, W) + batch_slice = v[t0 : t0 + batch_size].unsqueeze(1) + batch_slice = torch.nn.functional.interpolate( + batch_slice, size=input_spatial_size, mode="trilinear", align_corners=False, + ) - feat = vae.encode(slice)[0] # mean, + feat = vae.encode(batch_slice)[0] # mean, feat = feat.flatten(start_dim=1) # (b, latent_dim) - # Convert to numpy and collect embeddings feat_np = feat.cpu().numpy() for i, embedding in enumerate(feat_np): embeddings_list.append(embedding) timepoints_list.append(timepoint + i) - timepoint += feat.shape[0] - # Create xarray Dataset embeddings_array = np.stack(embeddings_list) n_samples, n_features = embeddings_array.shape @@ -276,19 +203,14 @@ def encode_position_with_metadata( }, coords={ 'sample': range(n_samples), 'feature': 
range(n_features), - 'timepoint': ('sample', timepoints_list) + 't': ('sample', timepoints_list) }) - # Add metadata as attributes ds.attrs['position_name'] = position_name ds.attrs['channel_name'] = channel_name return ds -# ----------------------------------------------------------------------------- # -# Main # -# ----------------------------------------------------------------------------- # - @click.command() @click.option("--source_position", "-s", type=click.Path(exists=True, path_type=Path), required=True, help="Full path to source position (e.g., '/path/to/plate.zarr/A/1/0')") @click.option("--target_position", "-t", type=click.Path(exists=True, path_type=Path), required=True, help="Full path to target position (e.g., '/path/to/plate.zarr/B/2/0')") @@ -301,10 +223,9 @@ def encode_position_with_metadata( help="Path to the VAE model checkpoint for loading.") @click.option("--batch_size", "-b", type=int, default=4) @click.option("--device", "-d", type=str, default="cuda") -@click.option("--source_output", "-so", type=click.Path(path_type=Path), help="Path to save source embeddings") -@click.option("--target_output", "-to", type=click.Path(path_type=Path), help="Path to save target embeddings") +@click.option("--output_dir", "-o", type=click.Path(path_type=Path), help="Path to save source embeddings") def embed_dataset(source_position, target_position, source_channel, target_channel, z, y, x, - ckpt_path, batch_size, device, source_output, target_output) -> None: + ckpt_path, batch_size, device, output_dir) -> None: """Encode positions using a pre-trained VAE and optionally compute FID or save embeddings. This function loads two zarr positions, encodes them using a variational autoencoder, @@ -332,39 +253,8 @@ def embed_dataset(source_position, target_position, source_channel, target_chann Number of timepoints to process simultaneously through the VAE device : str Device to run computations on ("cuda" or "cpu") - - Examples - -------- - Compute FID score between two positions: - - $ python fid_ts.py \\ - --source_position /path/to/dataset1.zarr/A/1/0 \\ - --target_position /path/to/dataset2.zarr/A/1/0 \\ - --source_channel phase \\ - --target_channel phase \\ - --ckpt_path /path/to/vae_model.pt \\ - --compute_fid - - Save embeddings to parquet file: - - $ python fid_ts.py \\ - --source_position /path/to/dataset1.zarr/A/1/0 \\ - --target_position /path/to/dataset2.zarr/B/2/0 \\ - --source_channel phase \\ - --target_channel brightfield \\ - --ckpt_path /path/to/vae_model.pt \\ - --output_path embeddings.parquet - - Save embeddings and compute FID: - - $ python fid_ts.py \\ - --source_position /path/to/dataset1.zarr/A/1/0 \\ - --target_position /path/to/dataset2.zarr/B/2/0 \\ - --source_channel phase \\ - --target_channel brightfield \\ - --ckpt_path /path/to/vae_model.pt \\ - --output_path embeddings.parquet \\ - --compute_fid + output_dir : Path + Path to save embeddings """ # ----------------- VAE ----------------- # @@ -375,7 +265,6 @@ def embed_dataset(source_position, target_position, source_channel, target_chann vae = torch.jit.load(ckpt_path).to(device) vae.eval() - # ----------------- Load positions ------------ # source_position = open_ome_zarr(source_position) source_channel_names = source_position.channel_names assert source_channel in source_channel_names, f"Channel {source_channel} not found in source position" @@ -384,38 +273,36 @@ def embed_dataset(source_position, target_position, source_channel, target_chann target_channel_names = target_position.channel_names assert 
target_channel in target_channel_names, f"Channel {target_channel} not found in target position" - # ----------------- Embeddings ----------- # input_spatial_size = (z, y, x) - - if source_output or target_output: - # Generate source embeddings - if source_output: - source_ds = encode_position_with_metadata( - position=source_position, vae=vae, - channel_name=source_channel, - device=device, batch_size=batch_size, - input_spatial_size=input_spatial_size, - ) - source_output.parent.mkdir(parents=True, exist_ok=True) - source_ds.to_zarr(source_output, mode='w') - print(f"Source embeddings saved to: {source_output}") + if output_dir: + output_dir = Path(output_dir) + source_name = source_position.zgroup.name.split('/')[-1] if source_position.zgroup.name else "source" + target_name = target_position.zgroup.name.split('/')[-1] if target_position.zgroup.name else "target" + source_output = output_dir / f"{source_name}_{source_channel}.zarr" + target_output = output_dir / f"{target_name}_{target_channel}.zarr" + + source_ds = embed_position( + position=source_position, vae=vae, + channel_name=source_channel, + device=device, batch_size=batch_size, + input_spatial_size=input_spatial_size, + ) + source_ds.to_zarr(source_output, mode='w') + print(f"Source embeddings saved to: {source_output}") - # Generate target embeddings - if target_output: - target_ds = encode_position_with_metadata( - position=target_position, vae=vae, - channel_name=target_channel, - device=device, batch_size=batch_size, - input_spatial_size=input_spatial_size, - ) - target_output.parent.mkdir(parents=True, exist_ok=True) - target_ds.to_zarr(target_output, mode='w') - print(f"Target embeddings saved to: {target_output}") + target_ds = embed_position( + position=target_position, vae=vae, + channel_name=target_channel, + device=device, batch_size=batch_size, + input_spatial_size=input_spatial_size, + ) + target_ds.to_zarr(target_output, mode='w') + print(f"Target embeddings saved to: {target_output}") @click.command() -@click.option("--source_path", "-sp", type=click.Path(exists=True, path_type=Path), required=True, help="Path to the source embeddings zarr file") -@click.option("--target_path", "-tp", type=click.Path(exists=True, path_type=Path), required=True, help="Path to the target embeddings zarr file") +@click.option("--source_path", "-s", type=click.Path(exists=True, path_type=Path), required=True, help="Path to the source embeddings zarr file") +@click.option("--target_path", "-t", type=click.Path(exists=True, path_type=Path), required=True, help="Path to the target embeddings zarr file") def compute_fid_cli(source_path: Path, target_path: Path) -> None: """Compute FID score between two embedding datasets. @@ -429,8 +316,8 @@ def compute_fid_cli(source_path: Path, target_path: Path) -> None: Examples -------- $ python fid_ts.py compute-fid \\ - -sp source_embeddings.zarr \\ - -tp target_embeddings.zarr + -s source_embeddings.zarr \\ + -t target_embeddings.zarr """ # Load the datasets source_ds = xr.open_zarr(source_path) diff --git a/applications/benchmarking/DynaCell/run_fid_ts.sh b/applications/benchmarking/DynaCell/run_fid_ts.sh index eb70ccd15..65c581715 100644 --- a/applications/benchmarking/DynaCell/run_fid_ts.sh +++ b/applications/benchmarking/DynaCell/run_fid_ts.sh @@ -5,8 +5,7 @@ python fid_ts.py embed \ -sc Nuclei-prediction \ -tc Organelle \ -c /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae_ts.pt \ - -so nuclei_prediction_embeddings.zarr \ - -to organelle_embeddings.zarr \ + -o . 
\ -b 4 \ -d cuda @@ -17,16 +16,15 @@ python fid_ts.py embed \ -sc Membrane-prediction \ -tc Membrane \ -c /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae_ts.pt \ - -so membrane_prediction_embeddings.zarr \ - -to membrane_embeddings.zarr \ + -o . \ -b 4 \ -d cuda # Compute FID from separate embedding files python fid_ts.py compute-fid \ - -sp nuclei_prediction_embeddings.zarr \ - -tp organelle_embeddings.zarr + -s _Nuclei-prediction.zarr \ + -t _Organelle.zarr python fid_ts.py compute-fid \ - -sp membrane_prediction_embeddings.zarr \ - -tp membrane_embeddings.zarr \ No newline at end of file + -s _Membrane-prediction.zarr \ + -t _Membrane.zarr \ No newline at end of file
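
For reference, the Fréchet distance that the `compute-fid` command above reports can be reproduced outside the CLI in a few lines of PyTorch. The sketch below is illustrative only, not part of the patch: it assumes two embedding stores written by the `embed` command (xarray Datasets with an `embeddings` variable of shape (sample, feature)), and the store paths and helper names (`gaussian_stats`, `sqrtm_psd`, `frechet`) are placeholders rather than functions defined in fid_ts.py.

# fid_sketch.py -- illustrative sketch, assuming embedding stores from `fid_ts.py embed`
import torch
import xarray as xr


def gaussian_stats(features: torch.Tensor):
    # Mean and unbiased covariance of an (N, D) feature matrix.
    mu = features.mean(dim=0)
    centered = features - mu
    sigma = centered.T @ centered / (features.shape[0] - 1)
    return mu, sigma


def sqrtm_psd(sigma: torch.Tensor) -> torch.Tensor:
    # Square root of a positive semi-definite matrix via eigendecomposition:
    # sqrt(Sigma) = Q diag(sqrt(lambda)) Q^T, with negative eigenvalues clipped to zero.
    L, Q = torch.linalg.eigh(sigma)
    return Q @ (L.clamp_min(0).sqrt()[..., None] * Q.mT)


def frechet(mu_x, sigma_x, mu_y, sigma_y) -> float:
    # d^2 = ||mu_x - mu_y||^2 + tr(Sigma_x + Sigma_y - 2 (Sigma_y^{1/2} Sigma_x Sigma_y^{1/2})^{1/2})
    root_y = sqrtm_psd(sigma_y)
    cross = sqrtm_psd(root_y @ sigma_x @ root_y)
    d2 = (mu_x - mu_y).square().sum() + (sigma_x + sigma_y - 2 * cross).trace()
    return float(d2.clamp_min(0))


if __name__ == "__main__":
    # Placeholder paths; substitute the zarr stores produced by `fid_ts.py embed`.
    src = xr.open_zarr("source_embeddings.zarr")["embeddings"].values
    tgt = xr.open_zarr("target_embeddings.zarr")["embeddings"].values
    f1 = torch.as_tensor(src, dtype=torch.float64)
    f2 = torch.as_tensor(tgt, dtype=torch.float64)
    print("FID:", frechet(*gaussian_stats(f1), *gaussian_stats(f2)))

Using float64 keeps the eigendecomposition stable for high-dimensional latents; the CLI's own helpers (e.g. fid_from_features) may differ in details such as covariance normalization or the eps regularizer.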