Skip to content
Merged
4 changes: 3 additions & 1 deletion atomai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from .imspec import ImSpec
from .regressor import Regressor
from .classifier import Classifier
from .denoiser import DenoisingAutoencoder, denoise_images
from .dgm import BaseVAE, VAE, rVAE, jVAE, jrVAE
from .dklgp import dklGPR, Reconstructor
from .loaders import load_model, load_ensemble, load_pretrained_model

__all__ = ["Segmentor", "ImSpec", "BaseVAE", "VAE", "rVAE",
"jVAE", "jrVAE", "load_model", "load_ensemble",
"load_pretrained_model", "dklGPR", "Regressor",
"Classifier", "Reconstructor"]
"Classifier", "Reconstructor", "DenoisingAutoencoder",
"denoise_images"]
270 changes: 270 additions & 0 deletions atomai/models/denoiser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
"""
denoiser.py
===========

Denoising autoencoder model for image cleaning

Created by Maxim Ziatdinov (email: [email protected])
Modified with conventional batch normalization approach
"""

from typing import Type, Union, Optional, Tuple
import torch
import numpy as np
from ..trainers import BaseTrainer
from ..predictors import BasePredictor
from ..nets import ConvBlock, UpsampleBlock
from ..utils import set_train_rng, preprocess_denoiser_data


class DenoisingAutoencoder(BaseTrainer):
"""
Denoising autoencoder model for image cleaning and noise reduction

Args:
encoder_filters: List of filter sizes for encoder layers (Default: [8, 16, 32, 64])
decoder_filters: List of filter sizes for decoder layers (Default: [64, 32, 16, 8])
encoder_layers: Number of convolutional layers per encoder block (Default: [1, 2, 2, 2])
decoder_layers: Number of convolutional layers per decoder block (Default: [2, 2, 2, 1])
use_batch_norm: Whether to use batch normalization in both encoder and decoder (Default: True)
upsampling_mode: Upsampling method ('nearest' or 'bilinear') (Default: 'nearest')
**seed: Random seed for reproducibility (Default: 1)

Example:
>>> # Initialize model
>>> model = aoi.models.DenoisingAutoencoder()
>>> # Train on noisy/clean image pairs
>>> model.fit(noisy_images, clean_images, noisy_test, clean_test,
>>> training_cycles=500, swa=True)
>>> # Denoise new images
>>> cleaned = model.predict(new_noisy_images)
"""

def __init__(self,
encoder_filters: list = [8, 16, 32, 64],
decoder_filters: list = [64, 32, 16, 8],
encoder_layers: list = [1, 2, 2, 2],
decoder_layers: list = [2, 2, 2, 1],
use_batch_norm: bool = False,
upsampling_mode: str = 'nearest',
**kwargs) -> None:
"""
Initialize denoising autoencoder
"""
super(DenoisingAutoencoder, self).__init__()

seed = kwargs.get("seed", 1)
set_train_rng(seed)

# Store architecture parameters
self.encoder_filters = encoder_filters
self.decoder_filters = decoder_filters
self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers
self.use_batch_norm = use_batch_norm
self.upsampling_mode = upsampling_mode

# Build the autoencoder
self.net = self._build_autoencoder()
self.net.to(self.device)

# Initialize meta state dict for saving/loading
self.meta_state_dict = {
"model_type": "denoising_autoencoder",
"encoder_filters": encoder_filters,
"decoder_filters": decoder_filters,
"encoder_layers": encoder_layers,
"decoder_layers": decoder_layers,
"use_batch_norm": use_batch_norm,
"upsampling_mode": upsampling_mode,
"weights": self.net.state_dict()
}

def _build_autoencoder(self) -> torch.nn.Module:
"""
Build the encoder-decoder architecture with consistent batch norm placement
"""
# Build encoder
encoder_modules = []
in_channels = 1 # Assuming grayscale images

for i, (filters, layers) in enumerate(zip(self.encoder_filters, self.encoder_layers)):
# Add convolutional block with consistent batch norm usage
encoder_modules.append(
ConvBlock(ndim=2, nb_layers=layers, input_channels=in_channels,
output_channels=filters, batch_norm=self.use_batch_norm)
)
# Add max pooling (except for the last layer)
if i < len(self.encoder_filters) - 1:
encoder_modules.append(torch.nn.MaxPool2d(2, 2))
in_channels = filters

encoder = torch.nn.Sequential(*encoder_modules)

# Build decoder
decoder_modules = []

for i, (filters, layers) in enumerate(zip(self.decoder_filters, self.decoder_layers)):
# Add upsampling (except for the first layer)
if i > 0:
decoder_modules.append(
UpsampleBlock(ndim=2, input_channels=in_channels,
output_channels=in_channels, mode=self.upsampling_mode)
)

# Add convolutional block with same batch norm setting as encoder
decoder_modules.append(
ConvBlock(ndim=2, nb_layers=layers, input_channels=in_channels,
output_channels=filters, batch_norm=self.use_batch_norm)
)
in_channels = filters

# Final output layer (no batch norm for final reconstruction)
decoder_modules.append(torch.nn.Conv2d(in_channels, 1, 1))

decoder = torch.nn.Sequential(*decoder_modules)

# Combine encoder and decoder
autoencoder = torch.nn.Sequential(encoder, decoder)

return autoencoder

def fit(self,
X_train: Union[np.ndarray, torch.Tensor],
y_train: Union[np.ndarray, torch.Tensor],
X_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
y_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
loss: str = 'mse',
optimizer: Optional[Type[torch.optim.Optimizer]] = None,
training_cycles: int = 500,
batch_size: int = 32,
compute_accuracy: bool = False,
full_epoch: bool = False,
swa: bool = True,
perturb_weights: bool = False,
**kwargs):
"""
Train the denoising autoencoder

Args:
X_train: Noisy input images for training
y_train: Clean target images for training
X_test: Noisy input images for testing
y_test: Clean target images for testing
loss: Loss function (Default: 'mse')
optimizer: Optimizer (Default: Adam with lr=1e-3)
training_cycles: Number of training epochs
batch_size: Batch size for training
compute_accuracy: Whether to compute accuracy metrics
full_epoch: Whether to use full epochs
swa: Whether to use stochastic weight averaging
perturb_weights: Whether to use weight perturbation
**kwargs: Additional arguments for training
"""
if X_test is None or y_test is None:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
X_train, y_train, test_size=kwargs.get("test_size", .15),
shuffle=True, random_state=kwargs.get("seed", 1))

# Preprocess data
X_train, y_train, X_test, y_test = preprocess_denoiser_data(
X_train, y_train, X_test, y_test)

# Compile and run training
self.compile_trainer(
(X_train, y_train, X_test, y_test),
loss=loss, optimizer=optimizer, training_cycles=training_cycles,
batch_size=batch_size, compute_accuracy=compute_accuracy,
full_epoch=full_epoch, swa=swa, perturb_weights=perturb_weights,
**kwargs
)

self.run()

# Update meta state dict
self.meta_state_dict["weights"] = self.net.state_dict()

def predict(self,
data: Union[np.ndarray, torch.Tensor],
**kwargs) -> np.ndarray:
"""
Denoise input images

Args:
data: Input noisy images
**num_batches: Number of batches for prediction (Default: 10)

Returns:
Denoised images
"""
use_gpu = self.device == 'cuda'
predictor = BasePredictor(self.net, use_gpu, **kwargs)

# Ensure proper format for prediction
if isinstance(data, np.ndarray):
if data.ndim == 2:
data = data[None, None, ...] # Add batch and channel dims
elif data.ndim == 3:
data = data[:, None, ...] # Add channel dim

prediction = predictor.predict(data, **kwargs)

return prediction.detach().cpu().numpy().squeeze()

def load_weights(self, filepath: str) -> None:
"""
Load saved model weights
"""
weight_dict = torch.load(filepath, map_location=self.device)
if "weights" in weight_dict:
self.net.load_state_dict(weight_dict["weights"])
else:
self.net.load_state_dict(weight_dict)


def init_denoising_autoencoder(**kwargs) -> Tuple[Type[torch.nn.Module], dict]:
"""
Initialize a denoising autoencoder model

Returns:
Tuple of (model, meta_state_dict)
"""
model = DenoisingAutoencoder(**kwargs)
return model.net, model.meta_state_dict


# Convenience function for quick denoising
def denoise_images(noisy_images: np.ndarray,
clean_images: np.ndarray,
test_noisy: Optional[np.ndarray] = None,
test_clean: Optional[np.ndarray] = None,
training_cycles: int = 500,
**kwargs) -> Tuple[DenoisingAutoencoder, np.ndarray]:
"""
Convenience function for training a denoising autoencoder and making predictions

Args:
noisy_images: Training noisy images
clean_images: Training clean images
test_noisy: Test noisy images (optional)
test_clean: Test clean images (optional)
training_cycles: Number of training cycles
**kwargs: Additional arguments for model and training

Returns:
Tuple of (trained_model, predictions_on_test_data)
"""
# Initialize model
model = DenoisingAutoencoder(**kwargs)

# Train model
model.fit(noisy_images, clean_images, test_noisy, test_clean,
training_cycles=training_cycles, **kwargs)

# Make predictions if test data provided
predictions = None
if test_noisy is not None:
predictions = model.predict(test_noisy)

return model, predictions
43 changes: 43 additions & 0 deletions atomai/models/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .imspec import ImSpec
from .regressor import Regressor
from .classifier import Classifier
from .denoiser import DenoisingAutoencoder
from .dgm import BaseVAE, VAE, rVAE, jrVAE, jVAE
from ..utils import average_weights

Expand Down Expand Up @@ -49,6 +50,8 @@ def load_model(filepath: str) -> Union[Segmentor, Union[VAE, rVAE, jrVAE, jVAE],
model = load_cls_model(loaded_dict)
elif model_type == "vae":
model = load_vae_model(loaded_dict)
elif model_type == "denoising_autoencoder":
model = load_denoising_autoencoder(loaded_dict)
else:
raise ValueError(
"The model type {} cannot be loaded".format(model_type))
Expand Down Expand Up @@ -192,6 +195,46 @@ def load_vae_model(meta_dict: Dict[str, torch.Tensor]) -> Type[BaseVAE]:
return m


def load_denoising_autoencoder(meta_dict: Dict[str, torch.Tensor]) -> Type[DenoisingAutoencoder]:
"""
Loads trained AtomAI denoising autoencoder models

Args:
meta_dict (dict):
dictionary with trained weights and key information
about model's structure

Returns:
DenoisingAutoencoder object with NN in evaluation state
"""
from .denoiser import DenoisingAutoencoder

encoder_filters = meta_dict.pop("encoder_filters", [8, 16, 32, 64])
decoder_filters = meta_dict.pop("decoder_filters", [64, 32, 16, 8])
encoder_layers = meta_dict.pop("encoder_layers", [1, 2, 2, 2])
decoder_layers = meta_dict.pop("decoder_layers", [2, 2, 2, 1])
use_batch_norm = meta_dict.pop("use_batch_norm", True)
upsampling_mode = meta_dict.pop("upsampling_mode", 'nearest')
weights = meta_dict.pop("weights")

model = DenoisingAutoencoder(
encoder_filters=encoder_filters,
decoder_filters=decoder_filters,
encoder_layers=encoder_layers,
decoder_layers=decoder_layers,
use_batch_norm=use_batch_norm,
upsampling_mode=upsampling_mode,
**meta_dict
)

model.net.load_state_dict(weights)
if "optimizer" in meta_dict.keys():
optimizer = meta_dict.pop("optimizer")
model.optimizer = optimizer
model.net.eval()
return model


def load_ensemble(filepath: str) -> Tuple[Type[torch.nn.Module], Dict[int, Dict[str, torch.Tensor]]]:
"""
Loads trained ensemble models
Expand Down
3 changes: 2 additions & 1 deletion atomai/stat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .multivar import (imlocal, calculate_transition_matrix,
sum_transitions, update_classes)
from .fft_nmf import SlidingFFTNMF

__all__ = ['imlocal', 'calculate_transition_matrix', 'sum_transitions',
'update_classes']
'update_classes', 'SlidingFFTNMF']
Loading