diff --git a/examples/turbdiff/README.md b/examples/turbdiff/README.md new file mode 100644 index 000000000..74207b1b0 --- /dev/null +++ b/examples/turbdiff/README.md @@ -0,0 +1,36 @@ +# TurbDiff: Generative Modeling for 3D Flow Simulation + +This is an implementation of the TurbDiff model as described in the paper "From Zero to Turbulence: Generative Modeling for 3D Flow Simulation" (ICLR 2024) by Marten Lienen, David Lüdke, Jan Hansen-Palmus, and Stephan Günnemann. + +## Overview + +TurbDiff is a denoising diffusion probabilistic model (DDPM) designed for generating realistic 3D turbulent flow fields. Unlike traditional autoregressive approaches, TurbDiff directly learns the manifold of all possible turbulent flow states without relying on any initial flow state. + +## Model Architecture + +The model architecture consists of: + +1. A 3D U-Net backbone with attention mechanisms +2. Specialized conditioning for boundary conditions and geometry +3. A diffusion process based on the DDPM framework +4. Custom components for handling turbulent flow characteristics + +## Usage + +Please refer to the `train.py` and `infer.py` scripts for training and inference examples. + +## Citation + +``` +@inproceedings{lienen2024zero, + title = {From {{Zero}} to {{Turbulence}}: {{Generative Modeling}} for {{3D Flow Simulation}}}, + author = {Lienen, Marten and L{\"u}dke, David and {Hansen-Palmus}, Jan and G{\"u}nnemann, Stephan}, + booktitle = {International {{Conference}} on {{Learning Representations}}}, + year = {2024}, +} +``` + +## References + +- Original implementation: [https://github.com/martenlienen/generative-turbulence](https://github.com/martenlienen/generative-turbulence) +- Paper: [https://arxiv.org/abs/2306.01776](https://arxiv.org/abs/2306.01776) diff --git a/examples/turbdiff/conditioning.py b/examples/turbdiff/conditioning.py new file mode 100644 index 000000000..4be1cc428 --- /dev/null +++ b/examples/turbdiff/conditioning.py @@ -0,0 +1,168 @@ +""" +Conditioning module for TurbDiff model. +""" + +from enum import Enum, auto +import paddle +import paddle.nn as nn + + +class ConditioningType(Enum): + """Types of conditioning supported by the model.""" + LOCAL = auto() # Local conditioning like boundary conditions + GLOBAL = auto() # Global conditioning like domain parameters + + +class Conditioning: + """Handles the conditioning for the TurbDiff model.""" + + def __init__(self, cell_type_embedding=None, use_cell_pos=False): + """ + Initialize the conditioning. + + Args: + cell_type_embedding: Module to embed cell types + use_cell_pos: Whether to use cell positions as features + """ + self.cell_type_embedding = cell_type_embedding + self.use_cell_pos = use_cell_pos + + # Calculate conditioning dimensions + self.local_conditioning_dim = 0 + if cell_type_embedding is not None: + self.local_conditioning_dim += cell_type_embedding.embedding_dim + if use_cell_pos: + self.local_conditioning_dim += 3 # x, y, z positions + + self.global_conditioning_dim = 0 # Will be set by specific implementations + + def prepare_local_conditioning(self, data): + """ + Prepare local conditioning from data. + + Args: + data: Input data containing cell types and possibly other information + + Returns: + Tensor of local conditioning features + """ + conditioning_elements = [] + + # Add cell type embeddings if available + if self.cell_type_embedding is not None and hasattr(data, 'cell_type'): + cell_type_emb = self.cell_type_embedding(data.cell_type) + conditioning_elements.append(cell_type_emb) + + # Add cell positions if requested + if self.use_cell_pos and hasattr(data, 'cell_pos'): + # Normalize cell positions to [-1, 1] range + pos_min = paddle.min(data.cell_pos, axis=[0, 2, 3, 4], keepdim=True) + pos_max = paddle.max(data.cell_pos, axis=[0, 2, 3, 4], keepdim=True) + pos_norm = 2 * (data.cell_pos - pos_min) / (pos_max - pos_min + 1e-8) - 1 + conditioning_elements.append(pos_norm) + + # Combine all conditioning elements + if conditioning_elements: + return paddle.concat(conditioning_elements, axis=1) + else: + return None + + def prepare_global_conditioning(self, data): + """ + Prepare global conditioning from data. + + Args: + data: Input data containing global parameters + + Returns: + Tensor of global conditioning features or None + """ + # This should be implemented by specific model extensions + # For base implementation, return None + return None + + def prepare_conditioning(self, data): + """ + Prepare all conditioning from data. + + Args: + data: Input data + + Returns: + Dictionary of conditioning tensors by type + """ + conditioning = {} + + local_cond = self.prepare_local_conditioning(data) + if local_cond is not None: + conditioning[ConditioningType.LOCAL] = local_cond + + global_cond = self.prepare_global_conditioning(data) + if global_cond is not None: + conditioning[ConditioningType.GLOBAL] = global_cond + + return conditioning + + +class CellTypeEmbedding(nn.Layer): + """Embedding for different cell types (fluid, solid, boundary, etc.).""" + + def __init__(self, num_types, embedding_dim): + """ + Initialize the cell type embedding. + + Args: + num_types: Number of different cell types + embedding_dim: Dimension of the embedding + """ + super().__init__() + self.embedding = nn.Embedding(num_types, embedding_dim) + self.embedding_dim = embedding_dim + + def forward(self, cell_types): + """ + Get embeddings for cell types. + + Args: + cell_types: Tensor of cell type indices [B, 1, H, W, D] + + Returns: + Embedded tensor [B, embedding_dim, H, W, D] + """ + # Reshape for embedding lookup + shape = cell_types.shape + flat_types = paddle.reshape(cell_types, [shape[0], -1]) + + # Lookup embeddings + embeddings = self.embedding(flat_types) + + # Reshape back to original shape with embedding dimension + embeddings = paddle.reshape(embeddings, + [shape[0], shape[2], shape[3], shape[4], self.embedding_dim]) + embeddings = paddle.transpose(embeddings, [0, 4, 1, 2, 3]) + + return embeddings + + @classmethod + def create(cls, embedding_type, embedding_dim, num_types=5): + """ + Create a cell type embedding of the specified type. + + Args: + embedding_type: Type of embedding ('learned', 'fixed', etc.) + embedding_dim: Dimension of the embedding + num_types: Number of different cell types + + Returns: + CellTypeEmbedding instance + """ + if embedding_type == 'learned': + return cls(num_types, embedding_dim) + elif embedding_type == 'fixed': + embedding = cls(num_types, embedding_dim) + # Initialize with fixed values and freeze parameters + for param in embedding.parameters(): + param.stop_gradient = True + return embedding + else: + raise ValueError(f"Unknown embedding type: {embedding_type}") diff --git a/examples/turbdiff/config.yaml b/examples/turbdiff/config.yaml new file mode 100644 index 000000000..17c42c581 --- /dev/null +++ b/examples/turbdiff/config.yaml @@ -0,0 +1,57 @@ +# TurbDiff model configuration + +# Variables to model +variables: + - name: U + dims: 3 # Velocity vector field (u, v, w) + - name: p + dims: 1 # Pressure scalar field + +# Model architecture configuration +model: + dim: 64 # Base dimension for feature maps + u_net_levels: 4 # Number of U-Net downsampling/upsampling levels + actfn: SiLU # Activation function + norm_type: instance # Normalization type: instance, batch, layer, or none + with_geometry_embedding: true # Whether to use geometry embedding + cell_type_features: true # Whether to use cell type conditioning + cell_type_embedding_type: learned # Type of cell embedding: learned or fixed + cell_type_embedding_dim: 8 # Dimension of cell type embedding + cell_pos_features: true # Whether to use cell position features + num_cell_types: 5 # Number of different cell types + +# Diffusion process configuration +diffusion: + timesteps: 1000 # Number of diffusion timesteps + loss_type: l2 # Loss type: l1, l2, or huber + beta_schedule: sigmoid # Schedule for noise variance: linear, cosine, or sigmoid + clip_denoised: false # Whether to clip denoised values to [-1, 1] + noise_bcs: true # Whether to add noise to boundary conditions + learned_variances: false # Whether to predict variance + elbo_weight: null # Weight for evidence lower bound term + detach_elbo_mean: true # Whether to detach mean for ELBO calculation + +# Data configuration +data: + normalization_mode: mean-std # Normalization mode: mean-std, min-max, or none + num_workers: 4 # Number of data loader workers + +# Training configuration +training: + learning_rate: 1.0e-4 # Base learning rate + min_learning_rate: 1.0e-5 # Minimum learning rate for scheduler + warmup_steps: 1000 # Number of warmup steps for learning rate + weight_decay: 1.0e-4 # Weight decay for regularization + beta1: 0.9 # Adam beta1 + beta2: 0.999 # Adam beta2 + gradient_clip_val: 1.0 # Gradient clipping value + save_interval: 10 # Save checkpoint every N epochs + val_interval: 5 # Validate every N epochs + +# Inference configuration +inference: + num_timesteps: 100 # Number of timesteps for fast sampling + metrics: + - mse # Mean squared error + - psnr # Peak signal-to-noise ratio + - ssim # Structural similarity diff --git a/examples/turbdiff/data_utils.py b/examples/turbdiff/data_utils.py new file mode 100644 index 000000000..404d0a3d8 --- /dev/null +++ b/examples/turbdiff/data_utils.py @@ -0,0 +1,385 @@ +""" +Data handling utilities for TurbDiff model. +""" + +import os +import h5py +import numpy as np +import paddle +from paddle.io import Dataset, DataLoader + + +class Variable: + """Represents a physical variable in the flow field.""" + + def __init__(self, name, dims, index=None): + """ + Initialize variable. + + Args: + name: Variable name + dims: Number of dimensions (components) + index: Optional index for multi-variable storage + """ + self.name = name + self.dims = dims + self.index = index + + def __repr__(self): + return f"Variable({self.name}, dims={self.dims})" + + +class TurbulenceDataset(Dataset): + """Dataset for 3D turbulence data.""" + + def __init__(self, data_dir, split='train', variables=None, transform=None): + """ + Initialize the dataset. + + Args: + data_dir: Directory containing HDF5 data files + split: Data split ('train', 'val', 'test') + variables: List of variables to load + transform: Optional transform to apply to data + """ + super().__init__() + self.data_dir = data_dir + self.split = split + self.transform = transform + + # Default variables if none specified + self.variables = variables or [ + Variable('U', 3), # Velocity + Variable('p', 1), # Pressure + ] + + # Find all HDF5 files for the given split + self.data_files = [] + for file in os.listdir(os.path.join(data_dir, split)): + if file.endswith('.h5'): + self.data_files.append(os.path.join(data_dir, split, file)) + + # Load dataset statistics + self.stats = self._load_statistics() + + # Cache for data samples + self.cache = {} + self.max_cache_size = 100 # Adjust based on memory availability + + def _load_statistics(self): + """Load dataset statistics for normalization.""" + stats_file = os.path.join(self.data_dir, 'statistics.h5') + if not os.path.exists(stats_file): + print(f"Warning: Statistics file {stats_file} not found. Using default normalization.") + return self._default_statistics() + + with h5py.File(stats_file, 'r') as f: + stats = {} + for var in self.variables: + if var.name in f: + stats[var.name] = { + 'mean': paddle.to_tensor(f[var.name]['mean'][()]), + 'std': paddle.to_tensor(f[var.name]['std'][()]), + 'min': paddle.to_tensor(f[var.name]['min'][()]), + 'max': paddle.to_tensor(f[var.name]['max'][()]), + } + else: + print(f"Warning: Statistics for {var.name} not found. Using defaults.") + stats[var.name] = self._default_variable_statistics(var) + + return stats + + def _default_statistics(self): + """Create default statistics if no statistics file is available.""" + stats = {} + for var in self.variables: + stats[var.name] = self._default_variable_statistics(var) + return stats + + def _default_variable_statistics(self, var): + """Create default statistics for a variable.""" + if var.name == 'U': + # Velocity statistics + return { + 'mean': paddle.zeros([var.dims]), + 'std': paddle.ones([var.dims]), + 'min': paddle.full([var.dims], -10.0), + 'max': paddle.full([var.dims], 10.0), + } + elif var.name == 'p': + # Pressure statistics + return { + 'mean': paddle.zeros([var.dims]), + 'std': paddle.ones([var.dims]), + 'min': paddle.full([var.dims], -5.0), + 'max': paddle.full([var.dims], 5.0), + } + else: + # Default statistics for unknown variables + return { + 'mean': paddle.zeros([var.dims]), + 'std': paddle.ones([var.dims]), + 'min': paddle.full([var.dims], -1.0), + 'max': paddle.full([var.dims], 1.0), + } + + def __len__(self): + """Get number of samples in the dataset.""" + return len(self.data_files) + + def __getitem__(self, idx): + """ + Get a sample from the dataset. + + Args: + idx: Sample index + + Returns: + Dictionary containing data fields + """ + # Check if sample is in cache + if idx in self.cache: + return self.cache[idx] + + # Load sample from file + file_path = self.data_files[idx] + sample = self._load_sample(file_path) + + # Apply transforms if any + if self.transform: + sample = self.transform(sample) + + # Cache sample + if len(self.cache) < self.max_cache_size: + self.cache[idx] = sample + + return sample + + def _load_sample(self, file_path): + """ + Load a sample from an HDF5 file. + + Args: + file_path: Path to HDF5 file + + Returns: + Dictionary with data fields + """ + with h5py.File(file_path, 'r') as f: + sample = {} + + # Load variables + var_data = [] + for var in self.variables: + if var.name in f: + data = paddle.to_tensor(f[var.name][()], dtype=paddle.float32) + var_data.append(data) + else: + print(f"Warning: Variable {var.name} not found in {file_path}") + # Create zero tensor with appropriate shape + shape = list(f.attrs.get('grid_shape', [64, 64, 64])) + shape = [var.dims] + shape + var_data.append(paddle.zeros(shape, dtype=paddle.float32)) + + # Combine variables into a single tensor + if var_data: + sample['x'] = paddle.concat(var_data, axis=0) + + # Load cell types if available + if 'cell_type' in f: + sample['cell_type'] = paddle.to_tensor(f['cell_type'][()], dtype=paddle.int64) + + # Load cell positions if available + if 'cell_pos' in f: + sample['cell_pos'] = paddle.to_tensor(f['cell_pos'][()], dtype=paddle.float32) + + # Load cell indices (for boundary conditions) + if 'cell_idx' in f: + sample['cell_idx'] = paddle.to_tensor(f['cell_idx'][()], dtype=paddle.int64) + + # Load metadata + sample['metadata'] = { + 'filename': os.path.basename(file_path), + 'shape': sample['x'].shape[1:], + } + + # Add any other attributes from the file + for key, value in f.attrs.items(): + if isinstance(value, (int, float, str, bool, np.ndarray)): + sample['metadata'][key] = value + + return sample + + +class Normalization: + """Handles normalization and denormalization of data.""" + + def __init__(self, variables, mode='mean-std'): + """ + Initialize normalization. + + Args: + variables: List of variables to normalize + mode: Normalization mode ('mean-std', 'min-max', 'none') + """ + self.variables = variables + self.mode = mode + + def normalize(self, x, stats): + """ + Normalize data. + + Args: + x: Input tensor [B, C, H, W, D] + stats: Statistics dictionary + + Returns: + Normalized tensor + """ + if self.mode == 'none': + return x + + # Start with a copy of the input + normalized = paddle.clone(x) + + # Normalize each variable + start_idx = 0 + for var in self.variables: + if var.name in stats: + end_idx = start_idx + var.dims + + if self.mode == 'mean-std': + mean = stats[var.name]['mean'] + std = stats[var.name]['std'] + + # Handle broadcasting for channel dimension + if mean.ndim == 1: + mean = mean.reshape([1, -1, 1, 1, 1]) + std = std.reshape([1, -1, 1, 1, 1]) + + normalized[:, start_idx:end_idx] = (x[:, start_idx:end_idx] - mean) / (std + 1e-8) + + elif self.mode == 'min-max': + min_val = stats[var.name]['min'] + max_val = stats[var.name]['max'] + + # Handle broadcasting for channel dimension + if min_val.ndim == 1: + min_val = min_val.reshape([1, -1, 1, 1, 1]) + max_val = max_val.reshape([1, -1, 1, 1, 1]) + + normalized[:, start_idx:end_idx] = 2.0 * (x[:, start_idx:end_idx] - min_val) / (max_val - min_val + 1e-8) - 1.0 + + # Move to next variable + start_idx += var.dims + + return normalized + + def denormalize(self, x, stats): + """ + Denormalize data. + + Args: + x: Normalized tensor [B, C, H, W, D] + stats: Statistics dictionary + + Returns: + Denormalized tensor + """ + if self.mode == 'none': + return x + + # Start with a copy of the input + denormalized = paddle.clone(x) + + # Denormalize each variable + start_idx = 0 + for var in self.variables: + if var.name in stats: + end_idx = start_idx + var.dims + + if self.mode == 'mean-std': + mean = stats[var.name]['mean'] + std = stats[var.name]['std'] + + # Handle broadcasting for channel dimension + if mean.ndim == 1: + mean = mean.reshape([1, -1, 1, 1, 1]) + std = std.reshape([1, -1, 1, 1, 1]) + + denormalized[:, start_idx:end_idx] = x[:, start_idx:end_idx] * (std + 1e-8) + mean + + elif self.mode == 'min-max': + min_val = stats[var.name]['min'] + max_val = stats[var.name]['max'] + + # Handle broadcasting for channel dimension + if min_val.ndim == 1: + min_val = min_val.reshape([1, -1, 1, 1, 1]) + max_val = max_val.reshape([1, -1, 1, 1, 1]) + + denormalized[:, start_idx:end_idx] = (x[:, start_idx:end_idx] + 1.0) * 0.5 * (max_val - min_val + 1e-8) + min_val + + # Move to next variable + start_idx += var.dims + + return denormalized + + def normalize_batch(self, batch, stats): + """Normalize a batch of data.""" + if 'x' in batch: + batch['x_normalized'] = self.normalize(batch['x'], stats) + return batch + + def denormalize_batch(self, batch, stats): + """Denormalize a batch of data.""" + if 'x_normalized' in batch: + batch['x'] = self.denormalize(batch['x_normalized'], stats) + return batch + + +def create_dataloader(dataset, batch_size, shuffle=True, num_workers=0): + """ + Create a DataLoader for the dataset. + + Args: + dataset: Dataset instance + batch_size: Batch size + shuffle: Whether to shuffle the data + num_workers: Number of worker processes + + Returns: + DataLoader instance + """ + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=collate_fn, + ) + + +def collate_fn(batch): + """ + Custom collate function for batching samples. + + Args: + batch: List of samples + + Returns: + Batched sample + """ + # Extract fields present in all samples + fields = batch[0].keys() + + result = {} + for field in fields: + if field == 'metadata': + # Metadata is not tensors, just collect in a list + result[field] = [sample[field] for sample in batch] + else: + # Stack tensors along the batch dimension + result[field] = paddle.stack([sample[field] for sample in batch]) + + return result diff --git a/examples/turbdiff/diffusion.py b/examples/turbdiff/diffusion.py new file mode 100644 index 000000000..e054998d0 --- /dev/null +++ b/examples/turbdiff/diffusion.py @@ -0,0 +1,408 @@ +""" +Diffusion process implementation for TurbDiff model in PaddleScience. +""" + +import math +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +def linear_beta_schedule(timesteps): + """ + Linear schedule, proposed in original DDPM paper. + """ + scale = 1000 / timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return paddle.linspace(beta_start, beta_end, timesteps, dtype=paddle.float32) + + +def log_linear_beta_schedule(timesteps): + """ + A version of the linear beta schedule that works for arbitrary timesteps. + """ + scale = 1000 / timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + betas = paddle.linspace(beta_start, beta_end, timesteps, dtype=paddle.float32) + + # Map step indices from [0, T-1] to [-1, 1] + t = paddle.linspace(-1, 1, timesteps, dtype=paddle.float32) + + # Define the mapping function + a = 3 + # Make the log-slope more shallow for small indices + slope = (1 + t) ** a + # Normalize to maintain the same beta sum + slope = slope / paddle.sum(slope) * timesteps + + # Map to the corresponding beta value + mapped_indices = slope.cumsum(0) - 1 + mapped_indices = paddle.clip(mapped_indices, 0, timesteps - 1).astype(paddle.int64) + + return paddle.gather(betas, mapped_indices) + + +def log_snr_linear_beta_schedule(timesteps, snr_1=1e3, snr_T=1e-5): + """ + A beta schedule that decays the log-SNR linearly. + """ + log_snr_1 = math.log(snr_1) + log_snr_T = math.log(snr_T) + + # Linear schedule in log-SNR space + log_snr = paddle.linspace(log_snr_1, log_snr_T, timesteps, dtype=paddle.float32) + + # Convert log-SNR to alpha_cumprod using the formula + # log(snr) = log(alpha_cumprod / (1 - alpha_cumprod)) + alphas_cumprod = paddle.exp(log_snr) / (1 + paddle.exp(log_snr)) + + # Get betas from alphas_cumprod + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], [1, 0], value=1.0) + alphas = alphas_cumprod / alphas_cumprod_prev + betas = 1 - alphas + + return betas + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + t = paddle.linspace(0, timesteps, steps, dtype=paddle.float32) / timesteps + alphas_cumprod = paddle.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return paddle.clip(betas, 0, 0.999) + + +def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5): + """ + Sigmoid schedule proposed in https://arxiv.org/abs/2212.11972 - Figure 8 + """ + steps = timesteps + 1 + t = paddle.linspace(0, timesteps, steps, dtype=paddle.float32) / timesteps + v_start = paddle.sigmoid(paddle.to_tensor(start / tau, dtype=paddle.float32)) + v_end = paddle.sigmoid(paddle.to_tensor(end / tau, dtype=paddle.float32)) + alphas_cumprod = paddle.sigmoid((t * (end - start) + start) / tau) + alphas_cumprod = (alphas_cumprod - v_start) / (v_end - v_start) + alphas_cumprod = paddle.clip(alphas_cumprod, clamp_min, 1.0) + alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] + betas = 1 - alphas + return betas + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + KL divergence between normal distributions parameterized by mean and log-variance. + """ + kl = 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + paddle.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * paddle.exp(-logvar2) + ) + return kl + + +def normal_log_lk(x, mean, log_var): + """Log-likelihood of x under the given normal distribution.""" + log_2pi = math.log(2 * math.pi) + return -0.5 * (log_2pi + log_var + (x - mean) ** 2 * paddle.exp(-log_var)) + + +class GaussianDiffusion(nn.Layer): + """ + Gaussian diffusion model for 3D turbulence. + """ + + def __init__( + self, + model, + *, + timesteps=1000, + loss_type="l2", + beta_schedule="sigmoid", + clip_denoised=False, + noise_bcs=False, + learned_variances=False, + elbo_weight=None, + detach_elbo_mean=True, + ): + super().__init__() + + self.model = model + self.clip_denoised = clip_denoised + self.noise_bcs = noise_bcs + self.learned_variances = learned_variances + self.elbo_weight = elbo_weight + self.detach_elbo_mean = detach_elbo_mean + + # Set up beta schedule + if beta_schedule == "linear": + beta_schedule_fn = linear_beta_schedule + elif beta_schedule == "log-linear": + beta_schedule_fn = log_linear_beta_schedule + elif beta_schedule == "log-snr-linear": + beta_schedule_fn = log_snr_linear_beta_schedule + elif beta_schedule == "cosine": + beta_schedule_fn = cosine_beta_schedule + elif beta_schedule == "sigmoid": + beta_schedule_fn = sigmoid_beta_schedule + else: + raise ValueError(f"Unknown beta schedule {beta_schedule}") + + betas = beta_schedule_fn(timesteps) + + alphas = 1.0 - betas + alphas_cumprod = paddle.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], [1, 0], value=1.0) + + self.num_timesteps = timesteps + self.loss_type = loss_type + + # Register buffers for diffusion process + self.register_buffer("betas", betas) + self.register_buffer("alphas_cumprod", alphas_cumprod) + + # Calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", paddle.sqrt(alphas_cumprod)) + self.register_buffer("sqrt_one_minus_alphas_cumprod", paddle.sqrt(1.0 - alphas_cumprod)) + self.register_buffer("sqrt_recip_alphas_cumprod", 1.0 / paddle.sqrt(alphas_cumprod)) + self.register_buffer("sqrt_recipm1_alphas_cumprod", paddle.sqrt(1.0 / alphas_cumprod - 1)) + + # Calculations for posterior q(x_{t-1} | x_t, x_0) + self.register_buffer("log_betas", paddle.log(betas)) + + # Posterior log var - numerically stable version + posterior_log_var = ( + self.log_betas + + paddle.log1p(-alphas_cumprod_prev) + - paddle.log1p(-alphas_cumprod) + ) + + # Adjust the first timestep to avoid -inf + posterior_log_var[0] = self.log_betas[0] * ( + posterior_log_var[1] / self.log_betas[1] + ) + self.register_buffer("posterior_log_var", posterior_log_var) + + self.register_buffer( + "posterior_mean_coef1", + betas * paddle.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod), + ) + self.register_buffer( + "posterior_mean_coef2", + (1.0 - alphas_cumprod_prev) * paddle.sqrt(alphas) / (1.0 - alphas_cumprod), + ) + + def predict_start_from_noise(self, x_t, t, noise): + """Predict x_0 from noise.""" + return ( + self.sqrt_recip_alphas_cumprod[t] * x_t + - self.sqrt_recipm1_alphas_cumprod[t] * noise + ) + + def predict_noise_from_start(self, x_t, t, x0): + """Predict noise from x_0.""" + return ( + (self.sqrt_recip_alphas_cumprod[t] * x_t - x0) + / self.sqrt_recipm1_alphas_cumprod[t] + ) + + def q_posterior(self, x_start, x_t, t): + """Compute the posterior mean and log-variance.""" + posterior_mean = ( + self.posterior_mean_coef1[t] * x_start + self.posterior_mean_coef2[t] * x_t + ) + posterior_log_var = self.posterior_log_var[t] + return posterior_mean, posterior_log_var + + def model_predictions(self, x_t, t, C, cell_idx, clip_x_start=False): + """Get model predictions for x_0 and noise.""" + model_output = self.model(x_t, t, C) + + if self.learned_variances: + # Split the output into prediction and variance + model_output, model_log_var = paddle.split(model_output, 2, axis=1) + + # Apply boundary conditions + if cell_idx is not None: + bc_mask = paddle.zeros_like(cell_idx, dtype=paddle.float32) + bc_mask = paddle.where(cell_idx > 0, paddle.ones_like(bc_mask), bc_mask) + bc_mask = bc_mask.unsqueeze(1).tile([1, model_output.shape[1], 1, 1, 1]) + model_log_var = paddle.where(bc_mask > 0, model_log_var, paddle.ones_like(model_log_var) * -20) + else: + model_log_var = self.posterior_log_var[t] + + # Get x_0 prediction + x_start = self.predict_start_from_noise(x_t, t, model_output) + + if clip_x_start and self.clip_denoised: + x_start = paddle.clip(x_start, -1.0, 1.0) + + # Get the mean for q(x_{t-1} | x_t, x_0) + model_mean, _ = self.q_posterior(x_start, x_t, t) + + return model_output, x_start, model_mean, model_log_var + + def p_sample(self, x_t, t, C, cell_idx): + """Sample from p(x_{t-1} | x_t).""" + noise, x_start, model_mean, model_log_var = self.model_predictions( + x_t, t, C, cell_idx, clip_x_start=True + ) + + # No noise when t == 0 + nonzero_mask = paddle.cast(t > 0, dtype=paddle.float32) + nonzero_mask = nonzero_mask.reshape([-1, 1, 1, 1, 1]) + + # Sample from the predicted distribution + noise = paddle.randn(x_t.shape, dtype=paddle.float32) + + # Apply boundary conditions to noise + if cell_idx is not None: + # If cell_idx > 0, it's a boundary cell, so we don't add noise + bc_mask = paddle.zeros_like(cell_idx, dtype=paddle.float32) + bc_mask = paddle.where(cell_idx > 0, paddle.ones_like(bc_mask), bc_mask) + bc_mask = bc_mask.unsqueeze(1).tile([1, noise.shape[1], 1, 1, 1]) + + if not self.noise_bcs: + noise = paddle.where(bc_mask > 0, paddle.zeros_like(noise), noise) + + sample = model_mean + nonzero_mask * paddle.exp(0.5 * model_log_var) * noise + + return sample + + def p_sample_loop(self, x_bcs, C, cell_idx, pbar=False, start_from=None): + """Run the reverse diffusion process to generate samples.""" + # Start from pure noise + b = x_bcs.shape[0] + sample = paddle.randn(x_bcs.shape, dtype=paddle.float32) + + # Apply boundary conditions from the start + if cell_idx is not None: + bc_mask = paddle.zeros_like(cell_idx, dtype=paddle.float32) + bc_mask = paddle.where(cell_idx > 0, paddle.ones_like(bc_mask), bc_mask) + bc_mask = bc_mask.unsqueeze(1).tile([1, sample.shape[1], 1, 1, 1]) + sample = paddle.where(bc_mask > 0, x_bcs, sample) + + # Choose starting timestep + timesteps = self.num_timesteps + if start_from is not None: + timesteps = min(timesteps, start_from) + + # Iterate through all timesteps (or from start_from) + time_range = list(reversed(range(0, timesteps))) + if pbar: + # If pbar is True, we would use tqdm here in PyTorch, but for simplicity + # we'll just use the range directly in this implementation + pass + + for i in time_range: + t = paddle.full([b], i, dtype=paddle.int64) + sample = self.p_sample(sample, t, C, cell_idx) + + # Apply boundary conditions at each step + if cell_idx is not None: + bc_mask = paddle.zeros_like(cell_idx, dtype=paddle.float32) + bc_mask = paddle.where(cell_idx > 0, paddle.ones_like(bc_mask), bc_mask) + bc_mask = bc_mask.unsqueeze(1).tile([1, sample.shape[1], 1, 1, 1]) + sample = paddle.where(bc_mask > 0, x_bcs, sample) + + return sample + + def q_sample(self, x_start, t, noise=None): + """Sample from q(x_t | x_0).""" + if noise is None: + noise = paddle.randn(x_start.shape, dtype=paddle.float32) + + return ( + self.sqrt_alphas_cumprod[t].reshape([-1, 1, 1, 1, 1]) * x_start + + self.sqrt_one_minus_alphas_cumprod[t].reshape([-1, 1, 1, 1, 1]) * noise + ) + + def loss_fn(self): + """Get the appropriate loss function.""" + if self.loss_type == "l1": + return F.l1_loss + elif self.loss_type == "l2": + return F.mse_loss + else: + raise ValueError(f"Unknown loss type {self.loss_type}") + + def p_losses(self, x_start, t, C, cell_idx, cell_mask=None): + """Compute training losses.""" + # Generate random noise + noise = paddle.randn(x_start.shape, dtype=paddle.float32) + + # Apply boundary conditions to noise + if cell_idx is not None and not self.noise_bcs: + bc_mask = paddle.zeros_like(cell_idx, dtype=paddle.float32) + bc_mask = paddle.where(cell_idx > 0, paddle.ones_like(bc_mask), bc_mask) + bc_mask = bc_mask.unsqueeze(1).tile([1, noise.shape[1], 1, 1, 1]) + noise = paddle.where(bc_mask > 0, paddle.zeros_like(noise), noise) + + # Get noisy samples + x_t = self.q_sample(x_start, t, noise) + + # Get model predictions + model_output, x_0_pred, model_mean, model_log_var = self.model_predictions( + x_t, t, C, cell_idx + ) + + # Basic loss term + loss_fn = self.loss_fn() + + if self.learned_variances: + # If learning variances, predict noise + target = noise + else: + # Otherwise, directly predict x_0 + target = x_start + model_output = x_0_pred + + # Compute the basic loss + loss = loss_fn(model_output, target, reduction="none") + + # Apply cell mask if provided + if cell_mask is not None: + cell_mask = cell_mask.unsqueeze(1).tile([1, loss.shape[1], 1, 1, 1]) + loss = loss * cell_mask + + # Reduce over non-batch dimensions + loss = paddle.mean(loss, axis=[1, 2, 3, 4]) + + # Add ELBO term if specified + if self.elbo_weight is not None: + # Compute KL divergence for learned variances + # Between the model posterior and the true posterior + true_mean, true_log_var = self.q_posterior(x_start, x_t, t) + if self.detach_elbo_mean: + true_mean = true_mean.detach() + kl = normal_kl(model_mean, model_log_var, true_mean, true_log_var) + kl = paddle.mean(kl, axis=[1, 2, 3, 4]) + + # Compute log likelihood of x_start under the posterior + decoder_nll = -normal_log_lk(x_start, model_mean, model_log_var) + decoder_nll = paddle.mean(decoder_nll, axis=[1, 2, 3, 4]) + + # At t=0, use the decoder NLL, + # otherwise use the KL divergence + mask_t0 = (t == 0).astype(paddle.float32) + mask_not_t0 = 1 - mask_t0 + + elbo_loss = mask_t0 * decoder_nll + mask_not_t0 * kl + loss = loss + self.elbo_weight * elbo_loss + + return loss + + def forward(self, x, C, cell_idx=None, cell_mask=None): + """Forward pass for training.""" + b = x.shape[0] + t = paddle.randint(0, self.num_timesteps, [b], dtype=paddle.int64) + loss = self.p_losses(x, t, C, cell_idx, cell_mask) + return paddle.mean(loss), t diff --git a/examples/turbdiff/evaluate.py b/examples/turbdiff/evaluate.py new file mode 100644 index 000000000..26db88af0 --- /dev/null +++ b/examples/turbdiff/evaluate.py @@ -0,0 +1,556 @@ +""" +Evaluation script for TurbDiff model in PaddleScience. +""" + +import os +import argparse +import yaml +import paddle +import numpy as np +import h5py +import time +from tqdm import tqdm +import matplotlib.pyplot as plt +from skimage.metrics import structural_similarity as ssim +import pandas as pd + +from model import DenoisingModel +from diffusion import GaussianDiffusion +from data_utils import ( + Variable, TurbulenceDataset, Normalization, + create_dataloader +) +from conditioning import Conditioning, CellTypeEmbedding, ConditioningType +from infer import load_model, load_config, setup_environment, create_variables + + +def parse_args(): + parser = argparse.ArgumentParser(description='Evaluate TurbDiff model') + parser.add_argument('--config', type=str, default='config.yaml', + help='Path to configuration file') + parser.add_argument('--model_path', type=str, required=True, + help='Path to trained model checkpoint') + parser.add_argument('--data_dir', type=str, required=True, + help='Path to dataset directory') + parser.add_argument('--output_dir', type=str, default='evaluation_results', + help='Output directory for evaluation results') + parser.add_argument('--num_samples', type=int, default=50, + help='Number of samples to evaluate') + parser.add_argument('--batch_size', type=int, default=4, + help='Batch size for evaluation') + parser.add_argument('--seed', type=int, default=42, + help='Random seed for reproducibility') + parser.add_argument('--device', type=str, default='gpu', + help='Device to use (gpu or cpu)') + parser.add_argument('--metrics', type=str, nargs='+', + default=['mse', 'mae', 'psnr', 'ssim', 'energy_spectrum'], + help='Metrics to evaluate') + parser.add_argument('--save_samples', action='store_true', + help='Save generated samples') + return parser.parse_args() + + +def calculate_mse(x, y): + """Calculate Mean Squared Error.""" + return paddle.mean((x - y) ** 2).item() + + +def calculate_mae(x, y): + """Calculate Mean Absolute Error.""" + return paddle.mean(paddle.abs(x - y)).item() + + +def calculate_psnr(x, y, data_range=None): + """Calculate Peak Signal-to-Noise Ratio.""" + if data_range is None: + data_range = paddle.max(y) - paddle.min(y) + + mse = calculate_mse(x, y) + if mse == 0: + return float('inf') + + return 20 * np.log10(data_range) - 10 * np.log10(mse) + + +def calculate_ssim(x, y): + """Calculate Structural Similarity Index.""" + # Convert to numpy arrays + x_np = x.numpy() + y_np = y.numpy() + + # Calculate SSIM for each channel and average + ssim_values = [] + for c in range(x_np.shape[1]): + for d in range(x_np.shape[4]): # For each slice in depth + ssim_val = ssim( + x_np[0, c, :, :, d], + y_np[0, c, :, :, d], + data_range=np.max(y_np[0, c, :, :, d]) - np.min(y_np[0, c, :, :, d]) + ) + ssim_values.append(ssim_val) + + return np.mean(ssim_values) + + +def calculate_energy_spectrum(velocity_field, dx=1.0): + """ + Calculate energy spectrum for a velocity field. + + Args: + velocity_field: Tensor of shape [3, H, W, D] (u, v, w components) + dx: Grid spacing + + Returns: + k: Wave numbers + E: Energy spectrum + """ + # Get shape and dimensions + if velocity_field.ndim > 3: + # If batch dimension is present, take first item + velocity_field = velocity_field[0] + + # Convert to numpy + u = velocity_field[0].numpy() + v = velocity_field[1].numpy() if velocity_field.shape[0] > 1 else np.zeros_like(u) + w = velocity_field[2].numpy() if velocity_field.shape[0] > 2 else np.zeros_like(u) + + # Get grid shape + nx, ny, nz = u.shape + + # Compute FFTs of velocity components + u_hat = np.fft.fftn(u) / (nx * ny * nz) + v_hat = np.fft.fftn(v) / (nx * ny * nz) + w_hat = np.fft.fftn(w) / (nx * ny * nz) + + # Create wavenumber grid + kx = 2 * np.pi * np.fft.fftfreq(nx, dx) + ky = 2 * np.pi * np.fft.fftfreq(ny, dx) + kz = 2 * np.pi * np.fft.fftfreq(nz, dx) + + # Create meshgrid + kxx, kyy, kzz = np.meshgrid(kx, ky, kz, indexing='ij') + k_squared = kxx**2 + kyy**2 + kzz**2 + + # Energy in spectral space + E_hat = 0.5 * (np.abs(u_hat)**2 + np.abs(v_hat)**2 + np.abs(w_hat)**2) + + # Compute energy spectrum by binning + k_min = 2 * np.pi / max(nx, ny, nz) + k_max = np.sqrt(3) * np.pi * min(nx, ny, nz) + k_bins = np.logspace(np.log10(k_min), np.log10(k_max), 32) + + # Initialize energy spectrum + E = np.zeros_like(k_bins[:-1]) + + # Bin energy + for i in range(len(k_bins) - 1): + k_lower = k_bins[i] + k_upper = k_bins[i+1] + + # Find wavenumbers in this bin + mask = (k_squared > k_lower**2) & (k_squared <= k_upper**2) + + # Accumulate energy + if np.any(mask): + E[i] = np.sum(E_hat[mask]) + + # Wave number for each bin (use midpoint) + k = 0.5 * (k_bins[1:] + k_bins[:-1]) + + return k, E + + +def compute_energy_spectrum_error(pred_k, pred_E, true_k, true_E): + """ + Compute error between predicted and true energy spectra. + + Args: + pred_k: Predicted wave numbers + pred_E: Predicted energy spectrum + true_k: True wave numbers + true_E: True energy spectrum + + Returns: + Error metric + """ + # Interpolate spectra to common wave numbers if they're different + if not np.array_equal(pred_k, true_k): + # Use a common k-space for comparison + common_k = np.unique(np.concatenate([pred_k, true_k])) + common_k.sort() + + # Interpolate to common k-space + from scipy.interpolate import interp1d + + # Handle zero values with small epsilon to enable log interpolation + epsilon = 1e-12 + + # Interpolate in log-log space + pred_interp = interp1d( + np.log10(pred_k), + np.log10(pred_E + epsilon), + kind='linear', + bounds_error=False, + fill_value='extrapolate' + ) + + true_interp = interp1d( + np.log10(true_k), + np.log10(true_E + epsilon), + kind='linear', + bounds_error=False, + fill_value='extrapolate' + ) + + # Get interpolated values + pred_E_common = 10 ** pred_interp(np.log10(common_k)) - epsilon + true_E_common = 10 ** true_interp(np.log10(common_k)) - epsilon + + # Use only valid region (where both spectra have data) + valid_mask = ( + (common_k >= max(pred_k[0], true_k[0])) & + (common_k <= min(pred_k[-1], true_k[-1])) + ) + + pred_E_valid = pred_E_common[valid_mask] + true_E_valid = true_E_common[valid_mask] + else: + # Same wave numbers, no interpolation needed + pred_E_valid = pred_E + true_E_valid = true_E + + # Compute relative error in energy spectrum + rel_error = np.mean(np.abs(pred_E_valid - true_E_valid) / (true_E_valid + 1e-8)) + + return rel_error + + +def evaluate_sample(generated, target, variables, metrics): + """ + Evaluate a generated sample against ground truth. + + Args: + generated: Generated sample tensor [B, C, H, W, D] + target: Ground truth tensor [B, C, H, W, D] + variables: List of variables + metrics: List of metrics to evaluate + + Returns: + Dictionary of metrics results + """ + results = {} + + # Calculate global metrics for all variables + if 'mse' in metrics: + results['mse'] = calculate_mse(generated, target) + + if 'mae' in metrics: + results['mae'] = calculate_mae(generated, target) + + if 'psnr' in metrics: + results['psnr'] = calculate_psnr(generated, target) + + if 'ssim' in metrics: + results['ssim'] = calculate_ssim(generated, target) + + # Calculate per-variable metrics + start_idx = 0 + for var in variables: + end_idx = start_idx + var.dims + + # Extract variable data + generated_var = generated[:, start_idx:end_idx] + target_var = target[:, start_idx:end_idx] + + # Calculate metrics for this variable + var_results = {} + + if 'mse' in metrics: + var_results['mse'] = calculate_mse(generated_var, target_var) + + if 'mae' in metrics: + var_results['mae'] = calculate_mae(generated_var, target_var) + + if 'psnr' in metrics: + var_results['psnr'] = calculate_psnr(generated_var, target_var) + + if 'ssim' in metrics: + var_results['ssim'] = calculate_ssim(generated_var, target_var) + + # Calculate energy spectrum for velocity field + if 'energy_spectrum' in metrics and var.name == 'U' and var.dims == 3: + # Get energy spectrum for generated sample + gen_k, gen_E = calculate_energy_spectrum(generated_var) + + # Get energy spectrum for target + true_k, true_E = calculate_energy_spectrum(target_var) + + # Calculate error + var_results['energy_spectrum_error'] = compute_energy_spectrum_error( + gen_k, gen_E, true_k, true_E + ) + + # Save spectra for later plotting + var_results['gen_spectrum'] = (gen_k, gen_E) + var_results['true_spectrum'] = (true_k, true_E) + + # Store variable results + results[var.name] = var_results + + # Move to next variable + start_idx = end_idx + + return results + + +def plot_energy_spectra(results, output_dir): + """ + Plot energy spectra from evaluation results. + + Args: + results: Dictionary of evaluation results + output_dir: Output directory for plots + """ + os.makedirs(os.path.join(output_dir, 'plots'), exist_ok=True) + + # Check if we have energy spectrum results for velocity + if 'U' in results and 'samples' in results and 'gen_spectrum' in results['U']: + # Create figure + plt.figure(figsize=(10, 8)) + + # Get sample results + sample_results = results['samples'] + + # Get average spectra + true_k, true_E = results['U']['true_spectrum'] + gen_k, gen_E = results['U']['gen_spectrum'] + + # Plot spectra + plt.loglog(true_k, true_E, 'b-', linewidth=2, label='Ground Truth') + plt.loglog(gen_k, gen_E, 'r-', linewidth=2, label='Generated') + + # Plot Kolmogorov -5/3 scaling law for reference + k_range = np.logspace(np.log10(true_k[5]), np.log10(true_k[-5]), 100) + # Scale to match true spectrum + scale_idx = len(true_k) // 3 + scale_factor = true_E[scale_idx] / (true_k[scale_idx] ** (-5/3)) + kolmogorov = scale_factor * k_range ** (-5/3) + plt.loglog(k_range, kolmogorov, 'k--', linewidth=1, label='k^(-5/3)') + + # Add labels and legend + plt.xlabel('Wave Number (k)') + plt.ylabel('Energy Spectrum E(k)') + plt.title('Turbulence Energy Spectrum Comparison') + plt.grid(True, which="both", ls="--", alpha=0.5) + plt.legend() + + # Save figure + plt.tight_layout() + plt.savefig(os.path.join(output_dir, 'plots', 'energy_spectrum.png'), dpi=300) + plt.close() + + +def main(): + # Parse arguments + args = parse_args() + + # Load configuration + config = load_config(args.config) + + # Set up environment + setup_environment(args, config) + + # Create variables from config + variables = create_variables(config) + + # Load model + model, conditioning = load_model(args, config, variables) + + # Create test dataset + test_dataset = TurbulenceDataset( + data_dir=args.data_dir, + split='test', + variables=variables, + ) + + # Create data loader + test_loader = create_dataloader( + test_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=config['data'].get('num_workers', 0) + ) + + # Create normalization + normalization = Normalization( + variables, + mode=config['data'].get('normalization_mode', 'mean-std') + ) + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Initialize results storage + all_results = { + 'samples': [], + 'global': {metric: [] for metric in args.metrics if metric != 'energy_spectrum'}, + } + + # Add per-variable metrics + for var in variables: + all_results[var.name] = {metric: [] for metric in args.metrics if metric != 'energy_spectrum'} + if var.name == 'U' and 'energy_spectrum' in args.metrics: + all_results[var.name]['energy_spectrum_error'] = [] + + # Process samples + print(f"Evaluating model on {min(args.num_samples, len(test_loader))} test samples...") + + # Set number of timesteps for inference + num_timesteps = config['inference'].get('num_timesteps', None) + + for batch_idx, batch in enumerate(tqdm(test_loader)): + if batch_idx >= args.num_samples: + break + + # Get ground truth + ground_truth = batch['x'] + + # Prepare conditioning + C = {} + if 'cell_type' in batch: + local_cond = conditioning.prepare_local_conditioning(batch) + if local_cond is not None: + C[ConditioningType.LOCAL] = local_cond + + # Generate sample + cell_idx = batch.get('cell_idx', None) + cell_mask = batch.get('cell_mask', None) + + # Generate random noise + shape = ground_truth.shape + noise = paddle.randn(shape) + + # Sample from the model + with paddle.no_grad(): + sample = model.p_sample_loop( + noise, + C, + cell_idx=cell_idx, + cell_mask=cell_mask, + verbose=False, + num_timesteps=num_timesteps + ) + + # Denormalize + denormalized_sample = normalization.denormalize(sample, test_dataset.stats) + denormalized_ground_truth = normalization.denormalize(ground_truth, test_dataset.stats) + + # Evaluate sample + results = evaluate_sample( + denormalized_sample, + denormalized_ground_truth, + variables, + args.metrics + ) + + # Store results + all_results['samples'].append(results) + + # Update global metrics + for metric in args.metrics: + if metric == 'energy_spectrum': + continue + if metric in results: + all_results['global'][metric].append(results[metric]) + + # Update per-variable metrics + for var in variables: + if var.name in results: + for metric_name, value in results[var.name].items(): + if not metric_name.startswith('gen_') and not metric_name.startswith('true_'): + all_results[var.name][metric_name].append(value) + + # Save sample if requested + if args.save_samples: + sample_dir = os.path.join(args.output_dir, 'samples') + os.makedirs(sample_dir, exist_ok=True) + + # Save as HDF5 + sample_path = os.path.join(sample_dir, f'sample_{batch_idx:04d}.h5') + with h5py.File(sample_path, 'w') as f: + # Save generated sample + f.create_dataset('generated', data=denormalized_sample.numpy()) + + # Save ground truth + f.create_dataset('ground_truth', data=denormalized_ground_truth.numpy()) + + # Calculate average metrics + avg_results = {'global': {}} + + # Global metrics + for metric in all_results['global']: + values = all_results['global'][metric] + avg_results['global'][metric] = np.mean(values) + + # Per-variable metrics + for var in variables: + avg_results[var.name] = {} + for metric in all_results[var.name]: + values = all_results[var.name][metric] + avg_results[var.name][metric] = np.mean(values) + + # Print results + print("\nEvaluation Results:") + print("===================") + + print("\nGlobal Metrics:") + for metric, value in avg_results['global'].items(): + print(f" {metric}: {value:.6f}") + + print("\nPer-Variable Metrics:") + for var in variables: + print(f" {var.name}:") + for metric, value in avg_results[var.name].items(): + if not metric.startswith('gen_') and not metric.startswith('true_'): + print(f" {metric}: {value:.6f}") + + # Save results to CSV + results_df = pd.DataFrame() + + # Add global metrics + for metric, value in avg_results['global'].items(): + results_df.loc['global', metric] = value + + # Add per-variable metrics + for var in variables: + for metric, value in avg_results[var.name].items(): + if not metric.startswith('gen_') and not metric.startswith('true_'): + results_df.loc[var.name, metric] = value + + # Save DataFrame + results_df.to_csv(os.path.join(args.output_dir, 'evaluation_results.csv')) + + # Save full results + with open(os.path.join(args.output_dir, 'all_results.yaml'), 'w') as f: + # Filter out non-serializable items + serializable_results = { + 'global': all_results['global'], + } + + for var in variables: + serializable_results[var.name] = {} + for metric, values in all_results[var.name].items(): + if not metric.startswith('gen_') and not metric.startswith('true_'): + serializable_results[var.name][metric] = values + + yaml.dump(serializable_results, f) + + # Plot energy spectra if applicable + if 'energy_spectrum' in args.metrics: + plot_energy_spectra(all_results, args.output_dir) + + print(f"\nResults saved to {args.output_dir}") + + +if __name__ == '__main__': + main() diff --git a/examples/turbdiff/infer.py b/examples/turbdiff/infer.py new file mode 100644 index 000000000..7fc78b192 --- /dev/null +++ b/examples/turbdiff/infer.py @@ -0,0 +1,549 @@ +""" +Inference script for TurbDiff model in PaddleScience. +""" + +import os +import argparse +import yaml +import paddle +import numpy as np +import h5py +import time +from tqdm import tqdm +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D + +from model import DenoisingModel +from diffusion import GaussianDiffusion +from data_utils import ( + Variable, TurbulenceDataset, Normalization, + create_dataloader +) +from conditioning import Conditioning, CellTypeEmbedding, ConditioningType + + +def parse_args(): + parser = argparse.ArgumentParser(description='Inference with TurbDiff model') + parser.add_argument('--config', type=str, default='config.yaml', + help='Path to configuration file') + parser.add_argument('--model_path', type=str, required=True, + help='Path to trained model checkpoint') + parser.add_argument('--data_dir', type=str, required=True, + help='Path to dataset directory') + parser.add_argument('--output_dir', type=str, default='inference_results', + help='Output directory for generated samples') + parser.add_argument('--num_samples', type=int, default=10, + help='Number of samples to generate') + parser.add_argument('--batch_size', type=int, default=1, + help='Batch size for inference') + parser.add_argument('--seed', type=int, default=42, + help='Random seed for reproducibility') + parser.add_argument('--device', type=str, default='gpu', + help='Device to use (gpu or cpu)') + parser.add_argument('--mode', type=str, default='sample', + choices=['sample', 'interpolate', 'unconditional'], + help='Inference mode') + parser.add_argument('--visualize', action='store_true', + help='Visualize generated samples') + return parser.parse_args() + + +def load_config(config_path): + """Load configuration from YAML file.""" + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + return config + + +def setup_environment(args, config): + """Set up inference environment.""" + # Set random seed + paddle.seed(args.seed) + np.random.seed(args.seed) + + # Set device + paddle.set_device(args.device) + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + +def create_variables(config): + """Create variable definitions from config.""" + variables = [] + for var_config in config['variables']: + variables.append(Variable(var_config['name'], var_config['dims'])) + return variables + + +def load_model(args, config, variables): + """Load the trained model.""" + # Calculate total dimensions for variables + vars_dim = sum(var.dims for var in variables) + + # Create cell type embedding if specified + cell_type_embedding = None + if config['model'].get('cell_type_features', True): + cell_type_embedding = CellTypeEmbedding.create( + config['model'].get('cell_type_embedding_type', 'learned'), + config['model'].get('cell_type_embedding_dim', 4), + config['model'].get('num_cell_types', 5) + ) + + # Create conditioning module + conditioning = Conditioning( + cell_type_embedding=cell_type_embedding, + use_cell_pos=config['model'].get('cell_pos_features', False) + ) + + # Create denoising model + model = DenoisingModel( + in_features=vars_dim, + out_features=vars_dim * (2 if config['diffusion'].get('learned_variances', False) else 1), + c_local_features=conditioning.local_conditioning_dim, + c_global_features=conditioning.global_conditioning_dim, + timesteps=config['diffusion'].get('timesteps', 1000), + dim=config['model'].get('dim', 32), + u_net_levels=config['model'].get('u_net_levels', 4), + actfn=getattr(paddle.nn, config['model'].get('actfn', 'Silu')), + norm_type=config['model'].get('norm_type', 'instance'), + with_geometry_embedding=config['model'].get('with_geometry_embedding', True), + ) + + # Wrap with diffusion model + diffusion = GaussianDiffusion( + model, + timesteps=config['diffusion'].get('timesteps', 1000), + loss_type=config['diffusion'].get('loss_type', 'l2'), + beta_schedule=config['diffusion'].get('beta_schedule', 'sigmoid'), + clip_denoised=config['diffusion'].get('clip_denoised', False), + noise_bcs=config['diffusion'].get('noise_bcs', False), + learned_variances=config['diffusion'].get('learned_variances', False), + elbo_weight=config['diffusion'].get('elbo_weight', None), + detach_elbo_mean=config['diffusion'].get('detach_elbo_mean', True), + ) + + # Load checkpoint + checkpoint = paddle.load(args.model_path) + diffusion.set_state_dict(checkpoint['model_state_dict']) + + # Switch to evaluation mode + diffusion.eval() + + return diffusion, conditioning + + +def load_test_dataset(args, config, variables): + """Load test dataset for conditioning and evaluation.""" + # Create test dataset + test_dataset = TurbulenceDataset( + data_dir=args.data_dir, + split='test', + variables=variables, + ) + + # Create data loader + test_loader = create_dataloader( + test_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=config['data'].get('num_workers', 0) + ) + + # Create normalization + normalization = Normalization( + variables, + mode=config['data'].get('normalization_mode', 'mean-std') + ) + + return test_dataset, test_loader, normalization + + +def sample_from_model(model, conditioning, batch, normalization, stats, num_timesteps=None): + """Sample from the model using the provided conditioning.""" + # Prepare conditioning + C = {} + if 'cell_type' in batch: + local_cond = conditioning.prepare_local_conditioning(batch) + if local_cond is not None: + C[ConditioningType.LOCAL] = local_cond + + # Get shape from the input batch + shape = batch['x'].shape + + # Generate random noise + noise = paddle.randn(shape) + + # Sample from the model + cell_idx = batch.get('cell_idx', None) + cell_mask = batch.get('cell_mask', None) + + # Start timer + start_time = time.time() + + with paddle.no_grad(): + sample = model.p_sample_loop( + noise, + C, + cell_idx=cell_idx, + cell_mask=cell_mask, + verbose=True, + num_timesteps=num_timesteps + ) + + # End timer + sample_time = time.time() - start_time + + # Denormalize the sample + denormalized_sample = normalization.denormalize(sample, stats) + + return denormalized_sample, sample_time + + +def interpolate_samples(model, conditioning, batch1, batch2, normalization, stats, num_steps=5, num_timesteps=None): + """Interpolate between two conditioning sets.""" + # Prepare conditioning for first batch + C1 = {} + if 'cell_type' in batch1: + local_cond1 = conditioning.prepare_local_conditioning(batch1) + if local_cond1 is not None: + C1[ConditioningType.LOCAL] = local_cond1 + + # Prepare conditioning for second batch + C2 = {} + if 'cell_type' in batch2: + local_cond2 = conditioning.prepare_local_conditioning(batch2) + if local_cond2 is not None: + C2[ConditioningType.LOCAL] = local_cond2 + + # Get shape from the input batch + shape = batch1['x'].shape + + # Generate same random noise for all interpolations + noise = paddle.randn(shape) + + # Sample with interpolated conditioning + samples = [] + for alpha in np.linspace(0, 1, num_steps): + # Interpolate conditioning + C = {} + if ConditioningType.LOCAL in C1 and ConditioningType.LOCAL in C2: + C[ConditioningType.LOCAL] = (1 - alpha) * C1[ConditioningType.LOCAL] + alpha * C2[ConditioningType.LOCAL] + + # Get cell indices from first batch (could be modified to interpolate these too) + cell_idx = batch1.get('cell_idx', None) + cell_mask = batch1.get('cell_mask', None) + + with paddle.no_grad(): + sample = model.p_sample_loop( + noise, + C, + cell_idx=cell_idx, + cell_mask=cell_mask, + verbose=False, + num_timesteps=num_timesteps + ) + + # Denormalize the sample + denormalized_sample = normalization.denormalize(sample, stats) + samples.append(denormalized_sample) + + return samples + + +def generate_unconditional_samples(model, shape, normalization, stats, num_samples=1, num_timesteps=None): + """Generate unconditional samples.""" + samples = [] + + for _ in range(num_samples): + # Generate random noise + noise = paddle.randn(shape) + + # Sample from the model without conditioning + with paddle.no_grad(): + sample = model.p_sample_loop( + noise, + {}, + verbose=False, + num_timesteps=num_timesteps + ) + + # Denormalize the sample + denormalized_sample = normalization.denormalize(sample, stats) + samples.append(denormalized_sample) + + return samples + + +def save_samples(samples, variables, output_dir, prefix='sample'): + """Save generated samples to HDF5 files.""" + os.makedirs(output_dir, exist_ok=True) + + # Handle both single sample and batch of samples + if not isinstance(samples, list): + samples = [samples] + + for i, sample in enumerate(samples): + # Create HDF5 file + file_path = os.path.join(output_dir, f"{prefix}_{i:04d}.h5") + with h5py.File(file_path, 'w') as f: + # Split sample by variables + start_idx = 0 + for var in variables: + end_idx = start_idx + var.dims + var_data = sample[:, start_idx:end_idx].numpy() + + # Save variable data + if var.dims == 1: + # Scalar field + f.create_dataset(var.name, data=var_data[0]) + else: + # Vector field components + f.create_dataset(var.name, data=var_data[0]) + + # Move to next variable + start_idx = end_idx + + # Save metadata + f.attrs['generated'] = True + f.attrs['generated_time'] = np.string_(time.strftime("%Y-%m-%d %H:%M:%S")) + + print(f"Saved {len(samples)} samples to {output_dir}") + + +def visualize_sample(sample, variables, output_dir, idx=0): + """Visualize a generated sample.""" + os.makedirs(os.path.join(output_dir, 'visualizations'), exist_ok=True) + + # Get the first batch item + if sample.ndim > 4: # [B, C, H, W, D] + sample = sample[0] # [C, H, W, D] + + # Get grid dimensions + _, h, w, d = sample.shape + + # Create a mesh grid for visualization + y, x, z = np.meshgrid( + np.linspace(0, 1, h), + np.linspace(0, 1, w), + np.linspace(0, 1, d), + ) + + # Start variable index + start_idx = 0 + + # Process each variable + for var in variables: + end_idx = start_idx + var.dims + var_data = sample[start_idx:end_idx].numpy() + + if var.name == 'U' and var.dims == 3: + # Vector field (velocity) + u = var_data[0] + v = var_data[1] + w = var_data[2] + + # Calculate velocity magnitude + magnitude = np.sqrt(u**2 + v**2 + w**2) + + # Create 3D plot for velocity magnitude + fig = plt.figure(figsize=(10, 8)) + ax = fig.add_subplot(111, projection='3d') + + # Plot a slice of the velocity field + slice_idx = d // 2 + sc = ax.scatter( + x[:, :, slice_idx].flatten(), + y[:, :, slice_idx].flatten(), + z[:, :, slice_idx].flatten(), + c=magnitude[:, :, slice_idx].flatten(), + cmap='viridis', + s=2 + ) + + # Add colorbar + plt.colorbar(sc, ax=ax, label='Velocity Magnitude') + + # Set labels and title + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.set_title(f'Velocity Magnitude (Z-slice at {slice_idx/d:.2f})') + + # Save figure + plt.savefig(os.path.join(output_dir, 'visualizations', f'velocity_magnitude_{idx:04d}.png')) + plt.close() + + # Create 2D slices plot + fig, axes = plt.subplots(1, 3, figsize=(18, 6)) + + # X-slice + im0 = axes[0].imshow(magnitude[w//2, :, :].T, cmap='viridis', origin='lower') + axes[0].set_title(f'X-slice at {w//2/w:.2f}') + plt.colorbar(im0, ax=axes[0]) + + # Y-slice + im1 = axes[1].imshow(magnitude[:, h//2, :].T, cmap='viridis', origin='lower') + axes[1].set_title(f'Y-slice at {h//2/h:.2f}') + plt.colorbar(im1, ax=axes[1]) + + # Z-slice + im2 = axes[2].imshow(magnitude[:, :, d//2], cmap='viridis', origin='lower') + axes[2].set_title(f'Z-slice at {d//2/d:.2f}') + plt.colorbar(im2, ax=axes[2]) + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, 'visualizations', f'velocity_slices_{idx:04d}.png')) + plt.close() + + elif var.name == 'p' and var.dims == 1: + # Scalar field (pressure) + p = var_data[0] + + # Create 2D slices plot + fig, axes = plt.subplots(1, 3, figsize=(18, 6)) + + # X-slice + im0 = axes[0].imshow(p[w//2, :, :].T, cmap='coolwarm', origin='lower') + axes[0].set_title(f'Pressure X-slice at {w//2/w:.2f}') + plt.colorbar(im0, ax=axes[0]) + + # Y-slice + im1 = axes[1].imshow(p[:, h//2, :].T, cmap='coolwarm', origin='lower') + axes[1].set_title(f'Pressure Y-slice at {h//2/h:.2f}') + plt.colorbar(im1, ax=axes[1]) + + # Z-slice + im2 = axes[2].imshow(p[:, :, d//2], cmap='coolwarm', origin='lower') + axes[2].set_title(f'Pressure Z-slice at {d//2/d:.2f}') + plt.colorbar(im2, ax=axes[2]) + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, 'visualizations', f'pressure_slices_{idx:04d}.png')) + plt.close() + + # Move to next variable + start_idx = end_idx + + +def main(): + # Parse arguments + args = parse_args() + + # Load configuration + config = load_config(args.config) + + # Set up environment + setup_environment(args, config) + + # Create variables from config + variables = create_variables(config) + + # Load model + model, conditioning = load_model(args, config, variables) + + # Load test dataset + test_dataset, test_loader, normalization = load_test_dataset(args, config, variables) + + # Set number of timesteps for inference + num_timesteps = config['inference'].get('num_timesteps', None) + + # Generate samples based on mode + if args.mode == 'sample': + print(f"Generating {args.num_samples} samples from test set conditions...") + + # Limit to requested number of samples + sample_count = 0 + + for batch in tqdm(test_loader): + if sample_count >= args.num_samples: + break + + # Generate sample + sample, sample_time = sample_from_model( + model, + conditioning, + batch, + normalization, + test_dataset.stats, + num_timesteps + ) + + # Save sample + save_samples(sample, variables, args.output_dir, prefix=f'sample_{sample_count:04d}') + + # Visualize if requested + if args.visualize: + visualize_sample(sample, variables, args.output_dir, sample_count) + + # Print timing + print(f"Sample {sample_count} generated in {sample_time:.2f}s") + + sample_count += 1 + + elif args.mode == 'interpolate': + print("Generating interpolated samples...") + + # Get two different conditioning samples + if len(test_loader) < 2: + print("Error: Need at least 2 samples in test set for interpolation") + return + + # Get the first two batches + batches = [] + for i, batch in enumerate(test_loader): + batches.append(batch) + if i >= 1: + break + + # Interpolate between the two batches + interpolated_samples = interpolate_samples( + model, + conditioning, + batches[0], + batches[1], + normalization, + test_dataset.stats, + num_steps=args.num_samples, + num_timesteps=num_timesteps + ) + + # Save interpolated samples + save_samples(interpolated_samples, variables, args.output_dir, prefix='interpolation') + + # Visualize if requested + if args.visualize: + for i, sample in enumerate(interpolated_samples): + visualize_sample(sample, variables, args.output_dir, i) + + elif args.mode == 'unconditional': + print("Generating unconditional samples...") + + # Get shape from test dataset + for batch in test_loader: + shape = batch['x'].shape + break + + # Generate unconditional samples + unconditional_samples = generate_unconditional_samples( + model, + shape, + normalization, + test_dataset.stats, + num_samples=args.num_samples, + num_timesteps=num_timesteps + ) + + # Save unconditional samples + save_samples(unconditional_samples, variables, args.output_dir, prefix='unconditional') + + # Visualize if requested + if args.visualize: + for i, sample in enumerate(unconditional_samples): + visualize_sample(sample, variables, args.output_dir, i) + + print("Inference completed!") + + +if __name__ == '__main__': + main() diff --git a/examples/turbdiff/model.py b/examples/turbdiff/model.py new file mode 100644 index 000000000..7237a89d3 --- /dev/null +++ b/examples/turbdiff/model.py @@ -0,0 +1,452 @@ +""" +TurbDiff model implementation for PaddleScience. +Based on the paper "From Zero to Turbulence: Generative Modeling for 3D Flow Simulation" +by Marten Lienen, David Lüdke, Jan Hansen-Palmus, and Stephan Günnemann. +""" + +import math +from dataclasses import dataclass +from functools import partial +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +@dataclass +class ModelPrediction: + """Model prediction output container.""" + noise: paddle.Tensor + x_start: paddle.Tensor + mean: paddle.Tensor + log_var: paddle.Tensor + + +# Small helper modules + +class Residual(nn.Layer): + """Residual connection wrapper for a layer.""" + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + x + + +def pad_to_multiple_of(x: paddle.Tensor, n: int, *, mode: str): + """Pad tensor to be a multiple of n in each spatial dimension.""" + h, w, d = x.shape[-3:] + h_pad = n - h % n if h % n != 0 else 0 + w_pad = n - w % n if w % n != 0 else 0 + d_pad = n - d % n if d % n != 0 else 0 + + if min(h_pad, w_pad, d_pad) > 0: + return F.pad(x, [0, d_pad, 0, w_pad, 0, h_pad], mode=mode), ( + h_pad, + w_pad, + d_pad, + ) + else: + return x, (0, 0, 0) + + +def unpad(x: paddle.Tensor, padding): + """Remove padding from tensor.""" + h_pad, w_pad, d_pad = padding + if min(padding) > 0: + return x[..., :-h_pad if h_pad > 0 else None, + :-w_pad if w_pad > 0 else None, + :-d_pad if d_pad > 0 else None] + else: + return x + + +class PreNorm(nn.Layer): + """Apply normalization before a function.""" + def __init__(self, norm: nn.Layer, fn: nn.Layer): + super().__init__() + self.norm = norm + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + + +# Sinusoidal positional embeddings + +class SinusoidalPosEmb(nn.Layer): + """Sinusoidal positional embedding.""" + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, t): + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = paddle.exp(paddle.arange(half_dim, dtype=paddle.float32) * -emb) + emb = t[:, None] * emb[None, :] + emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1) + return emb + + +class NyquistFrequencyEmbedding(nn.Layer): + """ + Sine-cosine embedding for timesteps that scales from 1/8 to a (< 1) multiple of + the Nyquist frequency. + """ + def __init__(self, dim: int, timesteps: int): + super().__init__() + assert dim % 2 == 0 + + T = timesteps + k = dim // 2 + + # Nyquist frequency for T samples per cycle + nyquist_frequency = T / 2 + + golden_ratio = (1 + np.sqrt(5)) / 2 + frequencies = np.geomspace(1 / 8, nyquist_frequency / (2 * golden_ratio), num=k) + + # Sample every frequency twice, once shifted by pi/2 to get cosine + scale = np.repeat(2 * np.pi * frequencies / timesteps, 2) + bias = np.tile(np.array([0, np.pi / 2]), k) + + self.scale = paddle.to_tensor(scale, dtype=paddle.float32) + self.bias = paddle.to_tensor(bias, dtype=paddle.float32) + + def forward(self, t): + # paddle equivalent of torch.addcmul + phase = self.bias + self.scale * t[..., None] + return paddle.sin(phase) + + +# Building block modules + +class Block(nn.Layer): + """Basic convolutional block with normalization and activation.""" + def __init__( + self, + dim, + dim_out, + actfn, + norm_klass=None, + ): + super().__init__() + self.conv = nn.Conv3D(dim, dim_out, 3, padding=1, padding_mode='replicate') + self.norm = norm_klass(dim_out) + self.act = actfn() + + def forward(self, x, scale_shift=None): + x = self.conv(x) + x = self.norm(x) + + if scale_shift is not None: + scale, shift = scale_shift + x = shift + (scale + 1) * x + + x = self.act(x) + return x + + +class ResnetBlock(nn.Layer): + """Residual block with conditioning.""" + def __init__(self, dim_in, dim_out, *, c_dim: int, actfn, norm_klass): + super().__init__() + + self.project_onto_scale_shift = nn.Linear(c_dim, dim_out * 2) + + self.block1 = Block(dim_in, dim_out, actfn=actfn, norm_klass=norm_klass) + self.block2 = Block(dim_out, dim_out, actfn=actfn, norm_klass=norm_klass) + self.conv = nn.Conv3D(dim_in, dim_out, 1) if dim_in != dim_out else nn.Identity() + + def forward(self, x, c): + # Reshape conditioning output + c = self.project_onto_scale_shift(c) + c = paddle.reshape(c, shape=[*c.shape[:-1], c.shape[-1], 1, 1, 1]) + scale, shift = paddle.split(c, 2, axis=-4) + + h = self.block1(x, scale_shift=(scale, shift)) + h = self.block2(h) + + return h + self.conv(x) + + +class LinearAttention(nn.Layer): + """Linear attention mechanism.""" + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv3D(dim, hidden_dim * 3, 1, bias_attr=False) + + self.to_out = nn.Sequential( + nn.Conv3D(hidden_dim, dim, 1), + nn.GroupNorm(1, dim) + ) + + def forward(self, x): + b, c, h, w, d = x.shape + qkv = self.to_qkv(x).chunk(3, axis=1) + q, k, v = map(lambda t: paddle.reshape( + t, [b, self.heads, -1, h * w * d]), qkv) + + q = q * self.scale + + k = paddle.transpose(k, [0, 1, 3, 2]) + context = paddle.matmul(k, v) + + out = paddle.matmul(context, q) + out = paddle.reshape(out, [b, self.heads, -1, h, w, d]) + out = paddle.transpose(out, [0, 1, 2, 3, 4, 5]) + out = paddle.reshape(out, [b, -1, h, w, d]) + return self.to_out(out) + + +class Attention(nn.Layer): + """Standard attention mechanism.""" + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + + self.to_qkv = nn.Conv3D(dim, hidden_dim * 3, 1, bias_attr=False) + self.to_out = nn.Conv3D(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w, d = x.shape + qkv = self.to_qkv(x).chunk(3, axis=1) + q, k, v = map(lambda t: paddle.reshape( + t, [b, self.heads, -1, h * w * d]), qkv) + + # Transpose for attention dot product + q = paddle.transpose(q, [0, 1, 3, 2]) + v = paddle.transpose(v, [0, 1, 3, 2]) + + # Scale dot-product attention + dots = paddle.matmul(q, k) * (k.shape[-1] ** -0.5) + attn = F.softmax(dots, axis=-1) + out = paddle.matmul(attn, v) + + out = paddle.transpose(out, [0, 1, 3, 2]) + out = paddle.reshape(out, [b, -1, h, w, d]) + + return self.to_out(out) + + +# U-Net model + +class UNet(nn.Layer): + """ + A general U-Net structure with interpolation instead of max-pool and transposed + convolutions. + """ + def __init__( + self, + downsampling_blocks, + upsampling_blocks, + center_block, + *, + downsampling_factor=2.0, + ): + super().__init__() + + assert len(downsampling_blocks) == len(upsampling_blocks) + + self.downsampling_blocks = nn.LayerList(downsampling_blocks) + self.upsampling_blocks = nn.LayerList(upsampling_blocks) + self.center_block = center_block + self.downsampling_factor = downsampling_factor + self.scale_factor = 1 / downsampling_factor + + def forward(self, x, *args, **kwargs): + h = [x] + + # Downsample + for block in self.downsampling_blocks: + h.append(block(h[-1], *args, **kwargs)) + h[-1] = F.interpolate( + h[-1], + scale_factor=self.scale_factor, + mode="trilinear", + align_corners=False, + data_format="NCDHW" + ) + + # Center block + h[-1] = self.center_block(h[-1], *args, **kwargs) + + # Upsample + for i, block in enumerate(self.upsampling_blocks): + h[-1] = F.interpolate( + h[-1], + size=h[-i - 2].shape[-3:], + mode="trilinear", + align_corners=False, + data_format="NCDHW" + ) + h[-1] = block(paddle.concat([h[-1], h[-i - 2]], axis=1), *args, **kwargs) + + return h[-1] + + +class GeometryEmbedding(nn.Layer): + """Extract geometry features from local conditioning.""" + def __init__(self, in_features, out_features, actfn): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.actfn = actfn + + self.extract_features = nn.Sequential( + nn.Conv3D(in_features, out_features, kernel_size=5, stride=5), + actfn(), + nn.Conv3D(out_features, out_features, kernel_size=5, stride=1), + actfn(), + nn.Conv3D(out_features, out_features, kernel_size=5, stride=5), + ) + + def forward(self, c_local): + # We pool the geometry embedding over all spatial dimensions + # to get a global feature vector + h = self.extract_features(c_local) + return paddle.mean(h, axis=[-3, -2, -1]) + + +class DenoisingModel(nn.Layer): + """Core denoising model for the diffusion process.""" + def __init__( + self, + *, + in_features: int, + out_features: int, + c_local_features: int, + c_global_features: int, + timesteps: int, + dim: int, + u_net_levels: int, + actfn=nn.Silu, + norm_type: str = "instance", + with_geometry_embedding: bool = False, + ): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.c_local_features = c_local_features + self.c_global_features = c_global_features + self.dim = dim + self.timesteps = timesteps + self.u_net_levels = u_net_levels + self.with_geometry_embedding = with_geometry_embedding + + # Set up normalization + if norm_type == "instance": + norm_klass = lambda dim: nn.GroupNorm(dim, dim) + elif norm_type == "layer": + norm_klass = lambda dim: nn.GroupNorm(1, dim) + elif norm_type == "group": + norm_klass = lambda dim: nn.GroupNorm(8, dim) + else: + raise RuntimeError(f"Unknown norm type {norm_type}") + + # Input encoding + self.encode_x = nn.Conv3D(in_features, dim, 1) + + # Setup conditioning + c_local_dim = 0 + if c_local_features > 0: + self.encode_c_local = nn.Conv3D(c_local_features, dim, 1) + c_local_dim += dim + + c_dim = dim + self.encode_t = NyquistFrequencyEmbedding(dim, timesteps) + + if c_global_features > 0: + self.encode_c_global = nn.Linear(c_global_features, dim) + c_dim += dim + + if with_geometry_embedding and c_local_features > 0: + self.geometry_embedding = GeometryEmbedding(c_local_features, dim, actfn) + c_dim += dim + + # Conditioning processing + self.process_c = nn.Sequential( + nn.Linear(c_dim, 4 * c_dim), + actfn(), + nn.Linear(4 * c_dim, c_dim), + actfn(), + ) + + # Decoder + resnet_block = partial( + ResnetBlock, c_dim=c_dim, actfn=actfn, norm_klass=norm_klass + ) + + self.decode = nn.Sequential( + resnet_block(dim, dim), + nn.Conv3D(dim, out_features, 1) + ) + + # U-Net architecture + downsampling_blocks = [resnet_block(dim + c_local_dim, dim * 2)] + [ + resnet_block(dim * 2**i, dim * 2 ** (i + 1)) for i in range(1, u_net_levels) + ] + upsampling_blocks = [ + resnet_block(2 * dim * 2 ** (i + 1), dim * 2**i) + for i in reversed(range(u_net_levels)) + ] + center_dim = dim * 2**u_net_levels + center_block = nn.Sequential( + resnet_block(center_dim, center_dim), + Residual(PreNorm(norm_klass(center_dim), Attention(center_dim))), + resnet_block(center_dim, center_dim), + ) + self.u_net = UNet(downsampling_blocks, upsampling_blocks, center_block) + + def forward(self, x, t, C): + """ + x: Input tensor [B, C, H, W, D] + t: Timestep tensor [B] + C: Dictionary of conditioning tensors + """ + # Encode input + h = self.encode_x(x) + + # Process time embedding + t_emb = self.encode_t(t) + + # Process conditioning + cond_elements = [t_emb] + + # Process local conditioning (boundary conditions) + c_local = None + if self.c_local_features > 0 and "local" in C: + c_local = C["local"] + c_local_encoded = self.encode_c_local(c_local) + h = paddle.concat([h, c_local_encoded], axis=1) + + # Process global conditioning + if self.c_global_features > 0 and "global" in C: + c_global = C["global"] + c_global_encoded = self.encode_c_global(c_global) + cond_elements.append(c_global_encoded) + + # Process geometry embedding + if self.with_geometry_embedding and c_local is not None: + geom_emb = self.geometry_embedding(c_local) + cond_elements.append(geom_emb) + + # Combine all conditioning elements + c = paddle.concat(cond_elements, axis=-1) + c = self.process_c(c) + + # Apply U-Net + h = self.u_net(h, c=c) + + # Decode + output = self.decode(h) + + return output diff --git a/examples/turbdiff/train.py b/examples/turbdiff/train.py new file mode 100644 index 000000000..1d66b8db6 --- /dev/null +++ b/examples/turbdiff/train.py @@ -0,0 +1,343 @@ +""" +Training script for TurbDiff model in PaddleScience. +""" + +import os +import argparse +import time +import yaml +import paddle +import numpy as np + +from model import DenoisingModel +from diffusion import GaussianDiffusion +from data_utils import ( + Variable, TurbulenceDataset, Normalization, + create_dataloader +) +from conditioning import Conditioning, CellTypeEmbedding, ConditioningType + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train TurbDiff model') + parser.add_argument('--config', type=str, default='config.yaml', + help='Path to configuration file') + parser.add_argument('--data_dir', type=str, required=True, + help='Path to dataset directory') + parser.add_argument('--output_dir', type=str, default='output', + help='Output directory for checkpoints and logs') + parser.add_argument('--batch_size', type=int, default=16, + help='Batch size for training') + parser.add_argument('--epochs', type=int, default=100, + help='Number of training epochs') + parser.add_argument('--learning_rate', type=float, default=1e-4, + help='Initial learning rate') + parser.add_argument('--seed', type=int, default=42, + help='Random seed for reproducibility') + parser.add_argument('--device', type=str, default='gpu', + help='Device to use (gpu or cpu)') + return parser.parse_args() + + +def load_config(config_path): + """Load configuration from YAML file.""" + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + return config + + +def setup_environment(args, config): + """Set up training environment.""" + # Set random seed + paddle.seed(args.seed) + np.random.seed(args.seed) + + # Set device + paddle.set_device(args.device) + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Save configuration + with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as f: + yaml.dump(config, f) + + +def create_variables(config): + """Create variable definitions from config.""" + variables = [] + for var_config in config['variables']: + variables.append(Variable(var_config['name'], var_config['dims'])) + return variables + + +def create_model(config, variables): + """Create model from configuration.""" + # Calculate total dimensions for variables + vars_dim = sum(var.dims for var in variables) + + # Create cell type embedding if specified + cell_type_embedding = None + if config['model'].get('cell_type_features', True): + cell_type_embedding = CellTypeEmbedding.create( + config['model'].get('cell_type_embedding_type', 'learned'), + config['model'].get('cell_type_embedding_dim', 4), + config['model'].get('num_cell_types', 5) + ) + + # Create conditioning module + conditioning = Conditioning( + cell_type_embedding=cell_type_embedding, + use_cell_pos=config['model'].get('cell_pos_features', False) + ) + + # Create denoising model + model = DenoisingModel( + in_features=vars_dim, + out_features=vars_dim * (2 if config['diffusion'].get('learned_variances', False) else 1), + c_local_features=conditioning.local_conditioning_dim, + c_global_features=conditioning.global_conditioning_dim, + timesteps=config['diffusion'].get('timesteps', 1000), + dim=config['model'].get('dim', 32), + u_net_levels=config['model'].get('u_net_levels', 4), + actfn=getattr(paddle.nn, config['model'].get('actfn', 'Silu')), + norm_type=config['model'].get('norm_type', 'instance'), + with_geometry_embedding=config['model'].get('with_geometry_embedding', True), + ) + + # Wrap with diffusion model + diffusion = GaussianDiffusion( + model, + timesteps=config['diffusion'].get('timesteps', 1000), + loss_type=config['diffusion'].get('loss_type', 'l2'), + beta_schedule=config['diffusion'].get('beta_schedule', 'sigmoid'), + clip_denoised=config['diffusion'].get('clip_denoised', False), + noise_bcs=config['diffusion'].get('noise_bcs', False), + learned_variances=config['diffusion'].get('learned_variances', False), + elbo_weight=config['diffusion'].get('elbo_weight', None), + detach_elbo_mean=config['diffusion'].get('detach_elbo_mean', True), + ) + + return diffusion, conditioning + + +def create_datasets(args, config, variables): + """Create datasets from configuration.""" + # Create normalization + normalization = Normalization( + variables, + mode=config['data'].get('normalization_mode', 'mean-std') + ) + + # Create training dataset + train_dataset = TurbulenceDataset( + data_dir=args.data_dir, + split='train', + variables=variables, + ) + + # Create validation dataset + val_dataset = TurbulenceDataset( + data_dir=args.data_dir, + split='val', + variables=variables, + ) + + # Create data loaders + train_loader = create_dataloader( + train_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=config['data'].get('num_workers', 0) + ) + + val_loader = create_dataloader( + val_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=config['data'].get('num_workers', 0) + ) + + return train_dataset, val_dataset, train_loader, val_loader, normalization + + +def create_optimizer(model, config, total_steps): + """Create optimizer and learning rate scheduler.""" + # Create learning rate scheduler + learning_rate = config['training']['learning_rate'] + min_learning_rate = config['training'].get('min_learning_rate', learning_rate / 10) + + # Cosine decay with linear warmup + scheduler = paddle.optimizer.lr.CosineAnnealingDecay( + learning_rate=learning_rate, + T_max=total_steps, + eta_min=min_learning_rate + ) + + if config['training'].get('warmup_steps', 0) > 0: + scheduler = paddle.optimizer.lr.LinearWarmup( + scheduler, + warmup_steps=config['training']['warmup_steps'], + start_lr=learning_rate / 10, + end_lr=learning_rate + ) + + # Create optimizer + optimizer = paddle.optimizer.Adam( + learning_rate=scheduler, + parameters=model.parameters(), + weight_decay=config['training'].get('weight_decay', 0.0), + beta1=config['training'].get('beta1', 0.9), + beta2=config['training'].get('beta2', 0.999), + ) + + return optimizer, scheduler + + +def train_epoch(model, train_loader, optimizer, conditioning, normalization, epoch, device, log_interval=10): + """Train for one epoch.""" + model.train() + total_loss = 0.0 + start_time = time.time() + + for batch_idx, batch in enumerate(train_loader): + # Prepare data + x = batch['x'] + + # Normalize data + x_normalized = normalization.normalize(x, train_loader.dataset.stats) + + # Prepare conditioning + C = {} + if 'cell_type' in batch: + local_cond = conditioning.prepare_local_conditioning(batch) + if local_cond is not None: + C[ConditioningType.LOCAL] = local_cond + + # Forward pass + cell_idx = batch.get('cell_idx', None) + cell_mask = batch.get('cell_mask', None) + loss, t = model(x_normalized, C, cell_idx, cell_mask) + + # Backward pass and optimize + optimizer.clear_grad() + loss.backward() + optimizer.step() + + # Update statistics + total_loss += loss.item() + + # Log progress + if (batch_idx + 1) % log_interval == 0: + elapsed = time.time() - start_time + print(f'Epoch {epoch} | Batch {batch_idx+1}/{len(train_loader)} | ' + f'Loss {loss.item():.4f} | {elapsed:.2f}s elapsed') + + # Return average loss + return total_loss / len(train_loader) + + +def validate(model, val_loader, conditioning, normalization, device): + """Validate the model.""" + model.eval() + total_loss = 0.0 + + with paddle.no_grad(): + for batch in val_loader: + # Prepare data + x = batch['x'] + + # Normalize data + x_normalized = normalization.normalize(x, val_loader.dataset.stats) + + # Prepare conditioning + C = {} + if 'cell_type' in batch: + local_cond = conditioning.prepare_local_conditioning(batch) + if local_cond is not None: + C[ConditioningType.LOCAL] = local_cond + + # Forward pass + cell_idx = batch.get('cell_idx', None) + cell_mask = batch.get('cell_mask', None) + loss, _ = model(x_normalized, C, cell_idx, cell_mask) + + # Update statistics + total_loss += loss.item() + + # Return average loss + return total_loss / len(val_loader) + + +def save_checkpoint(model, optimizer, epoch, loss, path): + """Save model checkpoint.""" + paddle.save({ + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + }, path) + + +def main(): + # Parse arguments + args = parse_args() + + # Load configuration + config = load_config(args.config) + + # Set up environment + setup_environment(args, config) + + # Create variables from config + variables = create_variables(config) + + # Create datasets and loaders + train_dataset, val_dataset, train_loader, val_loader, normalization = create_datasets( + args, config, variables + ) + + # Create model + model, conditioning = create_model(config, variables) + + # Create optimizer and scheduler + total_steps = args.epochs * len(train_loader) + optimizer, lr_scheduler = create_optimizer(model, config, total_steps) + + # Training loop + best_val_loss = float('inf') + + for epoch in range(1, args.epochs + 1): + print(f"Epoch {epoch}/{args.epochs}") + + # Train for one epoch + train_loss = train_epoch( + model, train_loader, optimizer, conditioning, + normalization, epoch, args.device + ) + + # Validate + val_loss = validate( + model, val_loader, conditioning, + normalization, args.device + ) + + # Print metrics + print(f"Epoch {epoch} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}") + + # Save checkpoint + checkpoint_path = os.path.join(args.output_dir, f"checkpoint_epoch_{epoch}.pdparams") + save_checkpoint(model, optimizer, epoch, val_loss, checkpoint_path) + + # Save best model + if val_loss < best_val_loss: + best_val_loss = val_loss + best_path = os.path.join(args.output_dir, "best_model.pdparams") + save_checkpoint(model, optimizer, epoch, val_loss, best_path) + print(f"New best model saved with val_loss: {val_loss:.4f}") + + print("Training completed!") + + +if __name__ == '__main__': + main()