diff --git a/src/virtual_stain_flow/engine/context.py b/src/virtual_stain_flow/engine/context.py index 8110136..64b8ffd 100644 --- a/src/virtual_stain_flow/engine/context.py +++ b/src/virtual_stain_flow/engine/context.py @@ -6,11 +6,12 @@ isolated and modular computations. """ -from typing import Dict, Iterable, Tuple, Union +from typing import Dict, Iterable, Union, Optional import torch +from torch import Tensor -from .names import TARGETS, PREDS, RESERVED_KEYS, RESERVED_MODEL_KEYS +from .names import INPUTS, TARGETS, PREDS, RESERVED_KEYS, RESERVED_MODEL_KEYS ContextValue = Union[torch.Tensor, torch.nn.Module] @@ -47,17 +48,9 @@ def add(self, **items: ContextValue) -> "Context": where keys are the names of the tensors. """ - for k, v in items.items(): - if k in RESERVED_KEYS and not isinstance(v, torch.Tensor): - raise ReservedKeyTypeError( - f"Reserved key '{k}' must be a torch.Tensor, got {type(v)}" - ) - elif k in RESERVED_MODEL_KEYS and not isinstance(v, torch.nn.Module): - raise ReservedKeyTypeError( - f"Reserved key '{k}' must be a torch.nn.Module, got {type(v)}" - ) - - self._store.update(items) + for key, value in items.items(): + self[key] = value + return self def require(self, keys: Iterable[str]) -> None: @@ -81,15 +74,18 @@ def as_kwargs(self) -> Dict[str, ContextValue]: """ return self._store - def as_metric_args(self) -> Tuple[ContextValue, ContextValue]: + def as_metric_args(self) -> tuple[Tensor, Tensor]: """ Returns the predictions and targets tensors for Image quality assessment metric computation. Intended use: metric.update(*ctx.as_metric_args()) + + :return: A tuple (preds, targets) of tensors. + :raises ValueError: If either preds or targets is missing. 
""" - self.require([PREDS, TARGETS]) - preds = self._store[PREDS] - targs = self._store[TARGETS] + self.require(keys=[PREDS, TARGETS]) + preds: Tensor = self.preds + targs: Tensor = self.targets return (preds, targs) def __repr__(self) -> str: @@ -109,6 +105,28 @@ def __repr__(self) -> str: # --- Methods for dict like behavior of context class --- def __setitem__(self, key: str, value: ContextValue) -> None: + """ + Sets a context item, with checks for reserved keys. + + :param key: The name of the context item. + :param value: The tensor/module to store. + """ + # Only allow torch.Tensor or torch.nn.Module values + if not isinstance(value, (torch.Tensor, torch.nn.Module)): + raise TypeError( + f"Context values must be torch.Tensor or torch.nn.Module, got {type(value)}" + ) + + # Further type check matching for reserved keys + if key in RESERVED_KEYS and not isinstance(value, torch.Tensor): + raise ReservedKeyTypeError( + f"Reserved key '{key}' must be a torch.Tensor, got {type(value)}" + ) + elif key in RESERVED_MODEL_KEYS and not isinstance(value, torch.nn.Module): + raise ReservedKeyTypeError( + f"Reserved key '{key}' must be a torch.nn.Module, got {type(value)}" + ) + self._store[key] = value def __contains__(self, key: str) -> bool: @@ -123,7 +141,7 @@ def __iter__(self): def __len__(self): return len(self._store) - def get(self, key: str, default: ContextValue = None) -> ContextValue: + def get(self, key: str, default: Optional[ContextValue] = None) -> Optional[ContextValue]: return self._store.get(key, default) def values(self): @@ -134,3 +152,59 @@ def items(self): def keys(self): return self._store.keys() + + def pop(self, key: str) -> ContextValue: + """ + Remove and return the value for key if key is in the context, + else raises a KeyError. 
+ """ + if key not in self._store: + raise KeyError(f"Key '{key}' not found in Context.") + return self._store.pop(key) + + def __or__(self, other: "Context") -> "Context": + """ + Merge two Context objects using the | operator. + Returns a new Context with items from both contexts. + Items from the right operand (other) take precedence in case of key conflicts. + + :param other: Another Context object to merge with. + :return: A new Context object containing items from both contexts. + """ + if not isinstance(other, Context): + raise NotImplementedError( + "__or__ operation only supported between Context objects." + ) + new_context = Context(**self._store) + new_context.add(**other._store) + return new_context + + def __ror__(self, other: "Context") -> "Context": + """ + Reverse merge (right | operator) for Context objects. + Called when the left operand doesn't support __or__ with Context. + + :param other: Another Context object to merge with. + :return: A new Context object containing items from both contexts. + """ + if not isinstance(other, Context): + raise NotImplementedError( + "__or__ operation only supported between Context objects." 
+ ) + new_context = Context(**other._store) + new_context.add(**self._store) + return new_context + + # --- Properties for robust typing for reserved keys --- + # let fail if key is not present + @property + def inputs(self) -> Tensor: + return self._store[INPUTS] # type: ignore + + @property + def targets(self) -> Tensor: + return self._store[TARGETS] # type: ignore + + @property + def preds(self) -> Tensor: + return self._store[PREDS] # type: ignore diff --git a/src/virtual_stain_flow/engine/forward_groups.py b/src/virtual_stain_flow/engine/forward_groups.py index 20acca8..0f8f906 100644 --- a/src/virtual_stain_flow/engine/forward_groups.py +++ b/src/virtual_stain_flow/engine/forward_groups.py @@ -31,13 +31,13 @@ """ from abc import ABC, abstractmethod -from typing import Optional, Tuple, Dict +from typing import Optional, Dict import torch import torch.optim as optim import torch.nn as nn -from .names import INPUTS, TARGETS, PREDS, GENERATOR_MODEL +from .names import INPUTS, TARGETS, PREDS, GENERATOR_MODEL, DISCRIMINATOR_MODEL from .context import Context @@ -52,9 +52,9 @@ class AbstractForwardGroup(ABC): """ # Subclasses should override these with ordered tuples. - input_keys: Tuple[str, ...] - target_keys: Tuple[str, ...] - output_keys: Tuple[str, ...] + input_keys: tuple[str, ...] + target_keys: tuple[str, ...] + output_keys: tuple[str, ...] def __init__( self, @@ -75,7 +75,7 @@ def _move_tensors(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tenso } @staticmethod - def _normalize_outputs(raw) -> Tuple[torch.Tensor, ...]: + def _normalize_outputs(raw) -> tuple[torch.Tensor, ...]: """ Normalize model outputs to a tuple of tensors while preserving order. @@ -140,9 +140,9 @@ class GeneratorForwardGroup(AbstractForwardGroup): metric_value = metric_fn(preds, targets) """ - input_keys: Tuple[str, ...] = (INPUTS,) - target_keys: Tuple[str, ...] = (TARGETS,) - output_keys: Tuple[str, ...] = (PREDS,) + input_keys: tuple[str, ...] 
= (INPUTS,) + target_keys: tuple[str, ...] = (TARGETS,) + output_keys: tuple[str, ...] = (PREDS,) def __init__( self, @@ -207,3 +207,88 @@ def optimizer(self) -> Optional[optim.Optimizer]: Convenience property to access the generator optimizer directly. """ return self._optimizers[GENERATOR_MODEL] + + +class DiscriminatorForwardGroup(AbstractForwardGroup): + """ + Forward group for a simple single (GAN/wGAN) discriminator workflow. + The discriminator is assumed to take in a "stack" of input and target + images concatenated along the channel dimension, and output a score/probability. + Relevant context values are input_keys, target_keys, output_keys for a + single-discriminator model, where: + - the forward is called as: + p = discriminator(stack) + - the evaluation is less straightforward, but typically involves + computing losses/metrics based on p and real/fake labels: + metric_value = metric_fn(p, real_or_fake_labels) + or perhaps involving the discriminator model itself for wasserstein distance: + metric_value = metric_fn(discriminator, stack, real_or_fake_labels) + """ + + input_keys: tuple[str, ...] = ("stack",) + target_keys: tuple[str, ...] = () + output_keys: tuple[str, ...] = ("p",) + + def __init__( + self, + discriminator: nn.Module, + optimizer: Optional[optim.Optimizer] = None, + device: torch.device = torch.device("cpu"), + ): + super().__init__(device=device) + + self._models[DISCRIMINATOR_MODEL] = discriminator + self._models[DISCRIMINATOR_MODEL].to(self.device) + self._optimizers[DISCRIMINATOR_MODEL] = optimizer + + def __call__(self, train: bool, **inputs: torch.Tensor) -> Context: + """ + Executes the forward pass, managing training/eval modes and optimizer steps. + Subclasses may override this method if needed. + + :param train: Whether to run in training mode. Meant to be specified + by the trainer to switch between train/eval modes and determine + whether gradients should be computed. + :param inputs: Keyword arguments of input tensors. 
+ """ + + fp_model = self.model + fp_optimizer = self.optimizer + + # 1) Stage and validate inputs/targets + ctx = Context(**self._move_tensors(inputs), **{DISCRIMINATOR_MODEL: fp_model }) + ctx.require(self.input_keys) + ctx.require(self.target_keys) + + # 2) Forward, with grad only when training + fp_model.train(mode=train) + train and fp_optimizer is not None and fp_optimizer.zero_grad(set_to_none=True) + with torch.set_grad_enabled(train): + model_inputs = [ctx[k] for k in self.input_keys] # ordered + raw = fp_model(*model_inputs) + y_tuple = self._normalize_outputs(raw) + + # 3) Arity check + map outputs to names + if len(y_tuple) != len(self.output_keys): + raise ValueError( + f"Model returned {len(y_tuple)} outputs, " + f"but output_keys expects {len(self.output_keys)}" + ) + outputs = {k: v for k, v in zip(self.output_keys, y_tuple)} + + # 5) Return enriched context (preds available for losses/metrics) + return ctx.add(**outputs) + + @property + def model(self) -> nn.Module: + """ + Convenience property to access the discriminator model directly. + """ + return self._models[DISCRIMINATOR_MODEL] + + @property + def optimizer(self) -> Optional[optim.Optimizer]: + """ + Convenience property to access the discriminator optimizer directly. + """ + return self._optimizers[DISCRIMINATOR_MODEL] diff --git a/src/virtual_stain_flow/engine/loss_group.py b/src/virtual_stain_flow/engine/loss_group.py index 2dd98cc..16b3a74 100644 --- a/src/virtual_stain_flow/engine/loss_group.py +++ b/src/virtual_stain_flow/engine/loss_group.py @@ -26,10 +26,12 @@ import torch -from .loss_utils import AbstractLoss, _get_loss_name, _scalar_from_ctx -from .context import Context +from .loss_utils import BaseLoss, _get_loss_name, _scalar_from_ctx +from .context import Context, ContextValue from .names import PREDS, TARGETS +Scalar = Union[int, float, bool] + @dataclass class LossItem: @@ -53,7 +55,7 @@ class LossItem: losses and centralizes device management. 
) """ - module: Union[torch.nn.Module, AbstractLoss] + module: Union[torch.nn.Module, BaseLoss] args: Union[str, Tuple[str, ...]] = (PREDS, TARGETS) key: Optional[str] = None weight: float = 1.0 @@ -63,7 +65,7 @@ class LossItem: def __post_init__(self): - self.key = self.key or _get_loss_name(self.module) + self.key = str(self.key or _get_loss_name(self.module)) self.args = (self.args,) if isinstance(self.args, str) else self.args try: @@ -94,7 +96,9 @@ def __call__( if context is not None: context.require(self.args) - inputs = {arg: context[arg] for arg in self.args} + inputs: Dict[str, ContextValue] = { + arg: context[arg] for arg in self.args + } if not self.enabled or (not train and not self.compute_at_val): zero = _scalar_from_ctx(0.0, inputs) @@ -127,7 +131,7 @@ class LossGroup: items: Sequence[LossItem] @property - def item_names(self) -> List[str]: + def item_names(self) -> List[Optional[str]]: return [item.key for item in self.items] def __call__( @@ -135,7 +139,7 @@ def __call__( train: bool, context: Optional[Context] = None, **inputs: torch.Tensor - ) -> Tuple[torch.Tensor, Dict[str, float]]: + ) -> Tuple[torch.Tensor, Dict[str, Scalar]]: """ Compute the total loss and individual loss values. @@ -153,7 +157,7 @@ def __call__( for item in self.items: raw, weighted = item(train, context=context, **inputs) - logs[item.key] = raw.item() + logs[item.key] = raw.item() # type: ignore total += weighted return total, logs diff --git a/src/virtual_stain_flow/engine/loss_utils.py b/src/virtual_stain_flow/engine/loss_utils.py index ec61b0c..22bada3 100644 --- a/src/virtual_stain_flow/engine/loss_utils.py +++ b/src/virtual_stain_flow/engine/loss_utils.py @@ -4,27 +4,28 @@ Utility functions for loss handling. 
""" -from typing import Union, Dict +from typing import Union, Mapping import torch -from ..losses.AbstractLoss import AbstractLoss +from ..losses.BaseLoss import BaseLoss +from .context import Context, ContextValue def _get_loss_name( - loss_fn: Union[torch.nn.Module, AbstractLoss] + loss_fn: Union[torch.nn.Module, BaseLoss] ) -> str: """ Helper method to get the name of the loss function. """ - if isinstance(loss_fn, AbstractLoss) and hasattr(loss_fn, "metric_name"): + if isinstance(loss_fn, BaseLoss) and hasattr(loss_fn, "metric_name"): return loss_fn.metric_name elif isinstance(loss_fn, torch.nn.Module): return type(loss_fn).__name__ else: raise TypeError( "Expected loss_fn to be either a torch.nn.Module or " - "an AbstractLoss instance." + "a BaseLoss instance." f"Got {type(loss_fn)} instead." ) @@ -40,7 +41,7 @@ def _scalar_from_device( def _scalar_from_ctx( value: float, - ctx: Dict[str, torch.Tensor] + ctx: Union[Mapping[str, ContextValue], Context] ): """ Helper method to convert a scalar value on the same device and diff --git a/src/virtual_stain_flow/engine/names.py b/src/virtual_stain_flow/engine/names.py index 1379059..fd7a217 100644 --- a/src/virtual_stain_flow/engine/names.py +++ b/src/virtual_stain_flow/engine/names.py @@ -11,6 +11,7 @@ PREDS: Final[str] = "preds" # always mean the predicted image tensor predicting from inputs the targets GENERATOR_MODEL: Final[str] = "generator" +DISCRIMINATOR_MODEL: Final[str] = "discriminator" RESERVED_KEYS: FrozenSet[str] = frozenset({INPUTS, TARGETS, PREDS}) -RESERVED_MODEL_KEYS: FrozenSet[str] = frozenset({GENERATOR_MODEL}) +RESERVED_MODEL_KEYS: FrozenSet[str] = frozenset({GENERATOR_MODEL, DISCRIMINATOR_MODEL}) diff --git a/src/virtual_stain_flow/engine/orchestrators.py b/src/virtual_stain_flow/engine/orchestrators.py new file mode 100644 index 0000000..52a91b1 --- /dev/null +++ b/src/virtual_stain_flow/engine/orchestrators.py @@ -0,0 +1,183 @@ +""" +orchestrators.py + +Collection of orchestrator classes that 
manage training flow +for complex models involving multiple components, such as GANs. + +This is contrasted with ForwardGroup classes, which handle +the forward pass and optimization of single model components. +The addition of orchestrators helps keep ForwardGroup classes simple. + +Internally, an orchestrator manages multiple ForwardGroups +and defines coordinated training steps that involve forward passes +through several components in a specific sequence. +""" + +from dataclasses import dataclass +from typing import Callable, Optional + +import torch +from torch import Tensor +from torch import optim + +from .forward_groups import GeneratorForwardGroup, DiscriminatorForwardGroup +from .context import Context + + +@dataclass +class OrchestratedStep: + """ + Thin wrapper around orchestrator methods to present step-like objects + to trainers with the same interface as ForwardGroups, exposing: + - __call__(train=..., **batch) for forward pass, and + - .step() to step the optimizer + """ + + forward_fn: Callable[..., Context] + optimizer: Optional[optim.Optimizer] = None + + def __call__(self, train: bool, **batch) -> Context: + return self.forward_fn(train=train, **batch) + + def step(self) -> None: + if self.optimizer is not None: + self.optimizer.step() + + +class GANOrchestrator: + """ + Orchestrator for a GAN-style setup with separate generator and discriminator + training steps. + + Stores GeneratorForwardGroup and a DiscriminatorForwardGroup: + The GeneratorForwardGroup and DiscriminatorForwardGroup are the + simplified building blocks that conduct exclusively the forward pass of + either generator or discriminator. + The Orchestrator._discriminator_forward and Orchestrator._generator_forward + methods use these simple forward groups to build more complex steps that + enable GAN training, which requires a coordinated forward pass through both + the discriminator and generator. 
+ """ + + def __init__( + self, + generator_fg: GeneratorForwardGroup, + discriminator_fg: DiscriminatorForwardGroup, + ): + """ + Initialize from already-constructed forward groups. + + This keeps concerns separated: forward groups manage single-module + behavior; the orchestrator manages their composition. + """ + # simple forward group storage + if not isinstance(generator_fg, GeneratorForwardGroup): + raise TypeError("generator_fg must be a GeneratorForwardGroup") + self._gen_fg: GeneratorForwardGroup = generator_fg + if not isinstance(discriminator_fg, DiscriminatorForwardGroup): + raise TypeError("discriminator_fg must be a DiscriminatorForwardGroup") + self._disc_fg: DiscriminatorForwardGroup = discriminator_fg + + # Public step-like objects that trainers can use directly + self.discriminator_step = OrchestratedStep( + forward_fn=self._discriminator_forward, + optimizer=self._disc_fg.optimizer, + ) + self.generator_step = OrchestratedStep( + forward_fn=self._generator_forward, + optimizer=self._gen_fg.optimizer, + ) + + def _build_real_fake_contexts( + self, + train: bool, + gen_ctx: Context, + ) -> Context: + """ + Given a generator context containing inputs / targets / preds, + generates the real and fake stacks by concatenating the true + input with the true target or predicted target along the + channel dimension. The result stacks serve as direct inputs + to the discriminator. + + The discriminator is then run on both stacks to produce + outputs scores of if it thinks the provided stack is real. + + :param train: Whether the model is in training mode. + :param gen_ctx: The Context produced by the generator forward pass, + containing at least INPUTS, TARGETS, and PREDS tensors. + :return: A merged Context containing outputs from both + the real and fake discriminator passes, as well as + the original generator context. 
+ """ + # Stack along channel dim: [inputs, targets] vs [inputs, preds] + # Context objects handle type checking, ensuring the reserved key + # context values (inputs, targets, preds) are tensors, so no + # further type checking is needed here. + real_stack: Tensor = torch.cat(tensors=[gen_ctx.inputs, gen_ctx.targets], dim=1) + fake_stack: Tensor = torch.cat(tensors=[gen_ctx.inputs, gen_ctx.preds], dim=1) + + # Real batch: D(x, y_true) + ctx_real: Context = self._disc_fg(train=train, stack=real_stack) + ctx_real["real_stack"] = real_stack + ctx_real["p_real_as_real"] = ctx_real.pop(key="p") + + # Fake batch: D(x, y_fake) + ctx_fake: Context = self._disc_fg(train=train, stack=fake_stack) + ctx_fake["fake_stack"] = fake_stack + ctx_fake["p_fake_as_real"] = ctx_fake.pop(key="p") + + # Merge: real info, fake info, and generator info + return ctx_real | ctx_fake | gen_ctx + + def _discriminator_forward( + self, + train: bool, + inputs: torch.Tensor, + targets: torch.Tensor, + ) -> Context: + """ + Forward step to train only the discriminator. + + :param train: Whether the model is in training mode. + :param inputs: The input tensor for the models. + :param targets: The target tensor for the models. + :return: A Context containing discriminator outputs for both + real and fake stacks, as well as the original generator context. + """ + # Generator is always eval for discriminator updates + gen_ctx: Context = self._gen_fg(train=False, inputs=inputs, targets=targets) + + return self._build_real_fake_contexts(train=train, gen_ctx=gen_ctx) + + def _generator_forward( + self, + train: bool, + inputs: torch.Tensor, + targets: torch.Tensor, + ) -> Context: + """ + Forward step to train only the generator. + + :param train: Whether the model is in training mode. + :param inputs: The input tensor for the models. + :param targets: The target tensor for the models. + :return: A Context containing generator outputs plus + p_fake_as_real from the discriminator. 
+ """ + + # Generate predictions and then run discriminator on fake stack + gen_ctx: Context = self._gen_fg(train=train,inputs=inputs,targets=targets) + fake_stack: Tensor = torch.cat(tensors=[gen_ctx.inputs, gen_ctx.preds], dim=1) + disc_ctx: Context = self._disc_fg(train=train, stack=fake_stack) + + # Attach discriminator score to the generator context and return. + return gen_ctx.add(p_fake_as_real=disc_ctx["p"]) + + @property + def generator_forward_group(self) -> GeneratorForwardGroup: + return self._gen_fg + + @property + def discriminator_forward_group(self) -> DiscriminatorForwardGroup: + return self._disc_fg diff --git a/src/virtual_stain_flow/losses/AbstractLoss.py b/src/virtual_stain_flow/losses/AbstractLoss.py deleted file mode 100644 index 8342e54..0000000 --- a/src/virtual_stain_flow/losses/AbstractLoss.py +++ /dev/null @@ -1,50 +0,0 @@ -from abc import ABC, abstractmethod - -import torch -import torch.nn as nn - -""" -Adapted from https://github.com/WayScience/nuclear_speckles_analysis -""" -class AbstractLoss(nn.Module, ABC): - """Abstract class for metrics""" - - def __init__(self, _metric_name: str): - - super(AbstractLoss, self).__init__() - - self._metric_name = _metric_name - self._trainer = None - - @property - def trainer(self): - return self._trainer - - @trainer.setter - def trainer(self, value): - """ - Setter of trainer meant to be called by the trainer class during initialization - """ - self._trainer = value - - @property - def metric_name(self, _metric_name: str): - """Defines the mertic name returned by the class.""" - return self._metric_name - - @abstractmethod - def forward(self, truth: torch.Tensor, generated: torch.Tensor - ) -> float: - """ - Computes the metric given information about the data - - :param truth: The tensor containing the ground truth image, - should be of shape [batch_size, channel_number, img_height, img_width]. 
- :type truth: torch.Tensor - :param generated: The tensor containing model generated image, - should be of shape [batch_size, channel_number, img_height, img_width]. - :type generated: torch.Tensor - :return: The computed metric as a float value. - :rtype: float - """ - pass diff --git a/src/virtual_stain_flow/losses/BaseLoss.py b/src/virtual_stain_flow/losses/BaseLoss.py new file mode 100644 index 0000000..33f8fbc --- /dev/null +++ b/src/virtual_stain_flow/losses/BaseLoss.py @@ -0,0 +1,50 @@ +""" + +""" +from __future__ import annotations +from typing import Protocol, runtime_checkable, Any + +from torch import Tensor +import torch.nn as nn + + +@runtime_checkable +class LossLike(Protocol): + """ + For type checking purposes only. + """ + + def forward(self, *args: Any, **kwargs: Any) -> Tensor: ... + def __call__(self, *args: Any, **kwargs: Any) -> Tensor: ... + + +class BaseLoss(nn.Module): + """Base class for loss functions.""" + + def __init__(self, _metric_name: str): + + super(BaseLoss, self).__init__() + + self._metric_name = _metric_name + self._trainer = None + + @property + def trainer(self): + return self._trainer + + @trainer.setter + def trainer(self, value): + """ + Setter of trainer meant to be called by the trainer class during initialization + """ + self._trainer = value + + @property + def metric_name(self): + """Defines the metric name returned by the class.""" + return self._metric_name + + def forward(self, *args: Any, **kwargs: Any) -> Tensor: + raise NotImplementedError( + f"{self.__class__.__name__}.forward() must be implemented by subclasses." 
+ ) diff --git a/src/virtual_stain_flow/losses/DiscriminatorLoss.py b/src/virtual_stain_flow/losses/DiscriminatorLoss.py deleted file mode 100644 index bbd937e..0000000 --- a/src/virtual_stain_flow/losses/DiscriminatorLoss.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch - -from .AbstractLoss import AbstractLoss - -class WassersteinLoss(AbstractLoss): - """ - This class implements the loss function for the discriminator in a Wasserstein Generative Adversarial Network (wGAN). - The discriminator loss measures how well the discriminator is able to distinguish between real (ground expected_truth) - images and fake (expected_generated) images produced by the generator. - """ - def __init__(self, _metric_name): - super().__init__(_metric_name) - - def forward(self, expected_truth, expected_generated): - """ - Computes the Wasserstein Discriminator Loss given probability scores expected_truth and expected_generated from the discriminator - - :param expected_truth: The tensor containing the ground expected_truth - probability score predicted by the discriminator over a batch of real images (input target pair), - should be of shape [batch_size, 1]. - :type expected_truth: torch.Tensor - :param expected_generated: The tensor containing model expected_generated - probability score predicted by the discriminator over a batch of generated images (input generated pair), - should be of shape [batch_size, 1]. - :type expected_generated: torch.Tensor - :return: The computed metric as a float value. 
- :rtype: float - """ - - # If the probability output is more than Scalar, take the mean of the output - # For compatibility with both a Discriminator class that would output a scalar probability (currently implemented) - # and a Discriminator class that would output a 2d matrix of probabilities (currently not implemented) - if expected_truth.dim() >= 3: - expected_truth = torch.mean(expected_truth, tuple(range(2, expected_truth.dim()))) - if expected_generated.dim() >= 3: - expected_generated = torch.mean(expected_generated, tuple(range(2, expected_generated.dim()))) - - return (expected_generated - expected_truth).mean() diff --git a/src/virtual_stain_flow/losses/GeneratorLoss.py b/src/virtual_stain_flow/losses/GeneratorLoss.py deleted file mode 100644 index 351f639..0000000 --- a/src/virtual_stain_flow/losses/GeneratorLoss.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Optional - -import torch -from torch.nn import L1Loss - -from .AbstractLoss import AbstractLoss - -class GeneratorLoss(AbstractLoss): - """ - Computes the loss for the GaN generator. - Combines an adversarial loss component with an image reconstruction loss. 
- """ - def __init__(self, - _metric_name: str, - reconstruction_loss: Optional[torch.tensor] = L1Loss(), - reconstruction_weight: float = 1.0 - ): - """ - :param reconstruction_loss: The image reconstruction loss, - defaults to L1Loss(reduce=False) - :type reconstruction_loss: torch.tensor - :param reconstruction_weight: The weight for the image reconstruction loss, defaults to 1.0 - :type reconstruction_weight: float - """ - - super().__init__(_metric_name) - - self._reconstruction_loss = reconstruction_loss - if isinstance(reconstruction_weight, float): - self._reconstruction_weight = reconstruction_weight - else: - raise ValueError("reconstruction_weight must be a float value") - - def forward(self, - discriminator_probs: torch.tensor, - truth: torch.tensor, - generated: torch.tensor, - epoch: int = 0 - ): - """ - Computes the loss for the GaN generator. - - :param discriminator_probs: The probabilities of the discriminator for the fake images being real. - :type discriminator_probs: torch.tensor - :param truth: The tensor containing the ground truth image, - should be of shape [batch_size, channel_number, img_height, img_width]. - :type truth: torch.Tensor - :param generated: The tensor containing model generated image, - should be of shape [batch_size, channel_number, img_height, img_width]. - :type generated: torch.Tensor - :param epoch: The current epoch number. - Used for a smoothing weight for the adversarial loss component - Defaults to 0. - :type epoch: int - :return: The computed metric as a float value. 
- :rtype: float - """ - - # Adversarial loss - adversarial_loss = -torch.mean(discriminator_probs) - adversarial_loss = 0.01 * adversarial_loss/(epoch + 1) - - image_loss = self._reconstruction_loss(generated, truth) - - return adversarial_loss + self._reconstruction_weight * image_loss.mean() diff --git a/src/virtual_stain_flow/losses/GradientPenaltyLoss.py b/src/virtual_stain_flow/losses/GradientPenaltyLoss.py deleted file mode 100644 index f40242d..0000000 --- a/src/virtual_stain_flow/losses/GradientPenaltyLoss.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -import torch.autograd as autograd - -from .AbstractLoss import AbstractLoss - -class GradientPenaltyLoss(AbstractLoss): - def __init__(self, _metric_name, discriminator, weight=10.0): - super().__init__(_metric_name) - - self._discriminator = discriminator - self._weight = weight - - def forward(self, truth, generated): - """ - Computes Gradient Penalty Loss for wGaN GP - - :param truth: The tensor containing the ground truth image, - should be of shape [batch_size, channel_number, img_height, img_width]. - :type truth: torch.Tensor - :param generated: The tensor containing model generated image, - should be of shape [batch_size, channel_number, img_height, img_width]. - :type generated: torch.Tensor - :return: The computed metric as a float value. 
- :rtype: float - """ - - device = self.trainer.device - - batch_size = truth.size(0) - eta = torch.rand(batch_size, 1, 1, 1, device=device).expand_as(truth) - interpolated = (eta * truth + (1 - eta) * generated).requires_grad_(True) - prob_interpolated = self._discriminator(interpolated) - - gradients = autograd.grad( - outputs=prob_interpolated, - inputs=interpolated, - grad_outputs=torch.ones_like(prob_interpolated), - create_graph=True, - retain_graph=True, - )[0] - - gradients = gradients.view(batch_size, -1) - gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() - return self._weight * gradient_penalty diff --git a/src/virtual_stain_flow/losses/wgan_losses.py b/src/virtual_stain_flow/losses/wgan_losses.py new file mode 100644 index 0000000..ba4365c --- /dev/null +++ b/src/virtual_stain_flow/losses/wgan_losses.py @@ -0,0 +1,100 @@ +""" +wgan_losses.py + +Wasserstein GAN loss implementations. +""" + +import torch +from torch import Tensor + +from .BaseLoss import BaseLoss + + +class WassersteinLoss(BaseLoss): + """ + This class implements the loss function for the discriminator in a + Wasserstein Generative Adversarial Network (wGAN). + The discriminator loss measures how well the discriminator is able to + distinguish between real (ground expected_truth) + images and fake (expected_generated) images produced by the generator. 
+ """ + def __init__(self, _metric_name='WassersteinLoss'): + super().__init__(_metric_name=_metric_name) + + def forward( + self, + p_real_as_real: Tensor, + p_fake_as_real: Tensor + ) -> Tensor: + """ + Computes the Wasserstein Discriminator Loss given probability scores + """ + + # Ensure reduction of p tensors to [batch_size, 1] + if p_real_as_real.dim() >= 3: + p_real_as_real = torch.mean(p_real_as_real, tuple(range(2, p_real_as_real.dim()))) + if p_fake_as_real.dim() >= 3: + p_fake_as_real = torch.mean(p_fake_as_real, tuple(range(2, p_fake_as_real.dim()))) + + return (p_fake_as_real - p_real_as_real).mean() + + +class GradientPenaltyLoss(BaseLoss): + """ + This class implements the gradient penalty loss for the discriminator in a + Wasserstein Generative Adversarial Network (wGAN-GP). + The gradient penalty is used to enforce the Lipschitz constraint on the discriminator, + which helps stabilize the training of the GAN. + """ + def __init__(self, _metric_name='GradientPenaltyLoss'): + super().__init__(_metric_name=_metric_name) + + def forward( + self, + real_stack: torch.Tensor, + fake_stack: torch.Tensor, + discriminator: torch.nn.Module, + ): + """ + Computes the Gradient Penalty Loss given the gradients of the discriminator's output + with respect to its input. 
+ """ + + device = next(discriminator.parameters()).device + batch_size = real_stack.size(0) + eta = torch.rand(batch_size, 1, 1, 1, device=device).expand_as(real_stack) + interpolated = (eta * real_stack + (1 - eta) * fake_stack).requires_grad_(True) + p_interpolated = discriminator(interpolated) + + gradients = torch.autograd.grad( + outputs=p_interpolated, + inputs=interpolated, + grad_outputs=torch.ones_like(p_interpolated, device=device), + create_graph=True, + retain_graph=True, + )[0] + + gradients = gradients.view(batch_size, -1) + gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() + return gradient_penalty + + +class AdversarialLoss(BaseLoss): + """ + Adversarial loss for the generator in a Wasserstein Generative Adversarial Network (wGAN). + """ + + def __init__(self, _metric_name='AdversarialLoss'): + super().__init__(_metric_name=_metric_name) + + def forward(self, p_fake_as_real: torch.Tensor): + """ + Computes the Adversarial Loss for the generator given the probability scores + assigned by the discriminator to the fake (generated) images. 
+ """ + + # Ensure reduction of p tensors to [batch_size, 1] + if p_fake_as_real.dim() >= 3: + p_fake_as_real = torch.mean(p_fake_as_real, tuple(range(2, p_fake_as_real.dim()))) + + return -p_fake_as_real.mean() diff --git a/src/virtual_stain_flow/models/blocks.py b/src/virtual_stain_flow/models/blocks.py index a1e9866..5071618 100644 --- a/src/virtual_stain_flow/models/blocks.py +++ b/src/virtual_stain_flow/models/blocks.py @@ -49,7 +49,7 @@ class AbstractBlock(ABC, nn.Module): def __init__( self, in_channels: int, - out_channels: int, + out_channels: Optional[int] = None, num_units: int = 1, **kwargs: dict ): @@ -63,6 +63,8 @@ def __init__( if in_channels <= 0: raise ValueError("Expected in_channels to be positive, " f"got {in_channels}") + if out_channels is None: + out_channels = in_channels if not isinstance(out_channels, int): raise TypeError("Expected out_channels to be int, " f"got {type(out_channels).__name__}") @@ -101,10 +103,9 @@ def num_units(self) -> int: # These 2 below should be overriden to reflect the actual spatial dimension # changes the block applies. By default they indicate spatial preserving # blocks, i.e. the height and width of the input tensor remain unchanged. - @property def out_h(self, in_h: int) -> int: return in_h - @property + def out_w(self, in_w: int) -> int: return in_w diff --git a/src/virtual_stain_flow/models/discriminator.py b/src/virtual_stain_flow/models/discriminator.py index 7129f59..615bba7 100644 --- a/src/virtual_stain_flow/models/discriminator.py +++ b/src/virtual_stain_flow/models/discriminator.py @@ -4,7 +4,7 @@ Implementation of GaN discriminators to use along with UNet or FNet generator. 
""" -from typing import Dict, Any +from typing import Dict, Any, Optional import torch from torch import nn @@ -109,6 +109,7 @@ def __init__( self, n_in_channels: int, n_in_filters: int, + out_activation: Optional[torch.nn.Module] = None, _conv_depth: int=4, _leaky_relu_alpha: float=0.2, _batch_norm: bool=False, @@ -120,6 +121,7 @@ def __init__( :param n_in_channels: (int) number of input channels :param n_in_filters: (int) number of filters in the first convolutional layer. Every subsequent layer will double the number of filters + :param out_activation: output activation function :param _conv_depth: (int) depth of the convolutional network :param _leaky_relu_alpha: (float) alpha value for leaky ReLU activation. ust be between 0 and 1 @@ -168,12 +170,14 @@ def __init__( nn.LazyLinear(512), nn.LeakyReLU(_leaky_relu_alpha, inplace=True), nn.Linear(512, 1), - nn.Sigmoid() ) + self.out_activation = out_activation or torch.nn.Identity() + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self._conv_layers(x) x = self.fc(x) + x = self.out_activation(x) return x @@ -189,7 +193,6 @@ def to_config(self) -> Dict[str, Any]: "_conv_depth": self._conv_depth, "_leaky_relu_alpha": self._leaky_relu_alpha, "_batch_norm": self._batch_norm, - "_pool_before_fc": self._pool_before_fc, }, } diff --git a/src/virtual_stain_flow/models/model.py b/src/virtual_stain_flow/models/model.py index 9915335..8417ea1 100644 --- a/src/virtual_stain_flow/models/model.py +++ b/src/virtual_stain_flow/models/model.py @@ -14,7 +14,7 @@ """ from abc import ABC, abstractmethod -from typing import Optional, Dict, Union +from typing import Optional, Dict, Union, Any import pathlib import torch @@ -78,7 +78,7 @@ def save_weights( return weight_file @abstractmethod - def to_config(self) -> Dict: + def to_config(self) -> Dict[str, Any]: """ Converts the model configuration to a dictionary format. 
@@ -88,7 +88,7 @@ def to_config(self) -> Dict: @classmethod @abstractmethod - def from_config(cls, config: Dict) -> 'BaseGeneratorModel': + def from_config(cls, config: Dict) -> 'BaseModel': """ Creates a model instance from a configuration dictionary. diff --git a/src/virtual_stain_flow/models/stages.py b/src/virtual_stain_flow/models/stages.py index 8599d96..bbc0dfb 100644 --- a/src/virtual_stain_flow/models/stages.py +++ b/src/virtual_stain_flow/models/stages.py @@ -173,18 +173,16 @@ def skip_channels(self) -> int: def out_channels(self) -> int: return self._out_channels - @property def out_h(self, in_h: int) -> int: _out_h = in_h - for block in self.blocks: + for block in [self.in_block, self.comp_block]: if isinstance(block, Conv2DDownBlock): _out_h = block.out_h(_out_h) return _out_h - @property def out_w(self, in_w: int) -> int: _out_w = in_w - for block in self.blocks: + for block in [self.in_block, self.comp_block]: if isinstance(block, Conv2DDownBlock): _out_w = block.out_w(_out_w) return _out_w diff --git a/src/virtual_stain_flow/trainers/AbstractTrainer.py b/src/virtual_stain_flow/trainers/AbstractTrainer.py index 1485322..6c64acf 100644 --- a/src/virtual_stain_flow/trainers/AbstractTrainer.py +++ b/src/virtual_stain_flow/trainers/AbstractTrainer.py @@ -177,7 +177,7 @@ def _init_data( return None @abstractmethod - def train_step(self, inputs: torch.tensor, targets: torch.tensor)->Dict[str, torch.Tensor]: + def train_step(self, inputs: torch.Tensor, targets: torch.Tensor)->Dict[str, float]: """ Abstract method for training the model on one batch Must be implemented by subclasses. 
@@ -194,7 +194,7 @@ def train_step(self, inputs: torch.tensor, targets: torch.tensor)->Dict[str, tor pass @abstractmethod - def evaluate_step(self, inputs: torch.tensor, targets: torch.tensor)->Dict[str, torch.Tensor]: + def evaluate_step(self, inputs: torch.Tensor, targets: torch.Tensor)->Dict[str, float]: """ Abstract method for evaluating the model on one batch Must be implemented by subclasses. @@ -361,7 +361,7 @@ def train( if hasattr(logger, "on_train_end"): logger.on_train_end() - def _collect_early_stop_metric(self) -> float: + def _collect_early_stop_metric(self) -> Optional[float]: if self._early_termination_metric is None: # Do not perform early stopping when no termination metric is specified early_term_metric = None @@ -563,7 +563,7 @@ def update_loss(self, self._train_losses[loss_name].append(loss) def update_metrics(self, - metric: torch.tensor, + metric: torch.Tensor, metric_name: str, validation: bool = False): if validation: diff --git a/src/virtual_stain_flow/trainers/logging_gan_trainer.py b/src/virtual_stain_flow/trainers/logging_gan_trainer.py new file mode 100644 index 0000000..8e5fe69 --- /dev/null +++ b/src/virtual_stain_flow/trainers/logging_gan_trainer.py @@ -0,0 +1,332 @@ +""" +Logging GAN Trainer + +This module defines the BaseGANTrainer and LoggingWGANTrainer classes, which extend the +AbstractTrainer to provide training and evaluation functionalities for a GAN +model using the engine subpackage for forward passes and loss computations. 
+""" + +import pathlib +from typing import Dict, List, Union, Optional + +import torch + +from .AbstractTrainer import AbstractTrainer +from ..engine.loss_group import LossGroup, LossItem +from ..losses.wgan_losses import ( + WassersteinLoss, + GradientPenaltyLoss, + AdversarialLoss +) +from ..engine.forward_groups import GeneratorForwardGroup, DiscriminatorForwardGroup +from ..engine.orchestrators import GANOrchestrator + +Scalar = Union[int, float] + + +class BaseGANTrainer(AbstractTrainer): + """ + Flexible trainer class for GAN models with logging. + """ + + def __init__( + self, + generator: torch.nn.Module, + discriminator: torch.nn.Module, + generator_optimizer: torch.optim.Optimizer, + discriminator_optimizer: torch.optim.Optimizer, + generator_loss_group: LossGroup, + discriminator_loss_group: LossGroup, + n_discriminator_steps: int = 3, + **kwargs + ): + """ + Initialize the trainer with the GAN orchestrator and loss groups. + + :param generator: The generator model to be trained. + :param discriminator: The discriminator model to be trained. + :param generator_optimizer: The optimizer for the generator. + :param discriminator_optimizer: The optimizer for the discriminator. + :param generator_loss_group: The loss group for the generator. + :param discriminator_loss_group: The loss group for the discriminator. + :param n_discriminator_steps: Number of discriminator steps per generator step. 
+ :kwargs: Additional arguments for the AbstractTrainer + """ + + device = kwargs.get('device', torch.device('cpu')) + + # Registry for logging model parameters + self._models: List[torch.nn.Module] = [generator, discriminator] + + generator_fg = GeneratorForwardGroup( + generator=generator, + optimizer=generator_optimizer, + device=device + ) + + discriminator_fg = DiscriminatorForwardGroup( + discriminator=discriminator, + optimizer=discriminator_optimizer, + device=device + ) + + self._orchestrator = GANOrchestrator( + generator_fg=generator_fg, + discriminator_fg=discriminator_fg + ) + + self._generator_loss_group: LossGroup = generator_loss_group + self._discriminator_loss_group: LossGroup = discriminator_loss_group + + # Internal counters for update frequencies + self._n_discriminator_steps: int = n_discriminator_steps + self._global_step: int = 0 + + super().__init__( + model=generator, # register generator as main model for early stopping + optimizer=generator_optimizer, + losses=generator_loss_group, + **kwargs + ) + + def train_step( + self, + inputs: torch.Tensor, + targets: torch.Tensor + ) -> Dict[str, float]: + """ + Perform a single training step for both the generator and discriminator. + + :param inputs: The input tensor for the models. + :param targets: The target tensor for the models. + :return: A dictionary containing the loss values for both generator and discriminator. 
+ """ + # Always update discriminator + #disc_ctx = self._discriminator_forward_group( + disc_ctx = self._orchestrator.discriminator_step( + train=True, + inputs=inputs, + targets=targets + ) + disc_weighted_total, disc_logs = self._discriminator_loss_group( + train=True, + context=disc_ctx + ) + disc_weighted_total.backward() + #self._discriminator_forward_group.step() + self._orchestrator.discriminator_step.step() + + # always evaluate metrics on discriminator context + # which will always represent the most updated generator state + ctx = disc_ctx + for _, metric in self.metrics.items(): + metric.update(*ctx.as_metric_args(), validation=True) + + # Update generator 1 in every n_discriminator_steps + if (self._global_step % self._n_discriminator_steps) == 0: + + # Generator step + gen_ctx = self._orchestrator.generator_step( + train=True, + inputs=inputs, + targets=targets + ) + gen_weighted_total, gen_logs = self._generator_loss_group( + train=True, + context=gen_ctx + ) + gen_weighted_total.backward() + self._orchestrator.generator_step.step() + else: + gen_ctx = None + gen_logs = {} + + self._global_step += 1 + + # if generator logs are not computed this step (due to skipped update), + # compute from discriminator context + if not gen_logs: + _, gen_logs = self._generator_loss_group( + train=True, + context=ctx + ) + + return gen_logs | disc_logs + + def evaluate_step( + self, + inputs: torch.Tensor, + targets: torch.Tensor + ) -> Dict[str, float]: + """ + Perform a single evaluation step for both the generator and discriminator. + + :param inputs: The input tensor for the models. + :param targets: The target tensor for the models. + :return: A dictionary containing the loss values for both generator and discriminator. 
+ """ + + # Taking a shortcut here by only evaluating through + # the discriminator forward group, which will also + # contain the generator outputs + ctx = self._orchestrator.discriminator_step( + train=False, + inputs=inputs, + targets=targets + ) + _, gen_logs = self._generator_loss_group( + train=False, + context=ctx + ) + _, disc_logs = self._discriminator_loss_group( + train=False, + context=ctx + ) + + for _, metric in self.metrics.items(): + metric.update(*ctx.as_metric_args(), validation=True) + + return gen_logs | disc_logs + + def save_model( + self, + save_path: pathlib.Path, + file_name_prefix: Optional[str] = None, + file_name_suffix: Optional[str] = None, + file_ext: str = '.pth', + best_model: bool = True + ) -> Optional[List[pathlib.Path]]: + + if file_name_suffix is None: + file_name_suffix = 'weights_' + ( + 'best' if best_model else str(self.epoch) + ) + + gen_path = self.model.save_weights( + filename=f"generator_{file_name_suffix}{file_ext}", + dir=save_path + ) + + return [gen_path] + +class LoggingWGANTrainer(BaseGANTrainer): + """ + Predefined WGAN trainer needing only drop-in generator losses and weights. 
+ + Under default settings, this trainer: + + trains the generator with: + - AdversarialLoss() operating on p_fake_as_real from the discriminator + - + [any additional provided generator losses] + + trains the discriminator with: + - WassersteinLoss() operating on p_real_as_real and p_fake_as_real + - GradientPenaltyLoss() operating on real_stack, fake_stack, and discriminator + """ + + def __init__( + self, + *, + generator: torch.nn.Module, + discriminator: torch.nn.Module, + generator_optimizer: torch.optim.Optimizer, + discriminator_optimizer: torch.optim.Optimizer, + generator_losses: Union[torch.nn.Module, List[torch.nn.Module]], + generator_loss_weights: Optional[Union[Scalar, List[Scalar]]] = None, + generator_adverserial_loss: torch.nn.Module = AdversarialLoss(), + generator_adverserial_loss_weight: Scalar = 1.0, + discriminator_loss: torch.nn.Module = WassersteinLoss(), + discriminator_loss_weight: Scalar = 1.0, + discriminator_gradient_penalty_loss: torch.nn.Module = GradientPenaltyLoss(), + discriminator_gradient_penalty_loss_weight: Scalar = 10.0, + n_discriminator_steps: int = 3, + **kwargs + ): + """ + Initialize the WGAN trainer with the GAN orchestrator and loss groups. + + :param generator: The generator model to be trained. + :param discriminator: The discriminator model to be trained. + :param generator_optimizer: The optimizer for the generator. + :param discriminator_optimizer: The optimizer for the discriminator. + :param generator_losses: The loss function(s) for the generator. + :param generator_loss_weights: The weight(s) for the generator loss function(s). + :param generator_adverserial_loss: The adversarial loss function for the generator. + :param generator_adverserial_loss_weight: The weight for the generator adversarial loss function. + :param discriminator_loss: The loss function for the discriminator. + :param discriminator_loss_weight: The weight for the discriminator loss function. 
+ :param discriminator_gradient_penalty_loss: The gradient penalty loss function for the discriminator. + :param discriminator_gradient_penalty_loss_weight: The weight for the gradient penalty loss function. + :param n_discriminator_steps: Number of discriminator steps per generator step. + :kwargs: Additional arguments for the AbstractTrainer + """ + + device = kwargs.get('device', torch.device('cpu')) + + generator_losses = generator_losses if isinstance( + generator_losses, + list + ) else [generator_losses] + if generator_loss_weights is None: + generator_loss_weights = [1.0] * len(generator_losses) + elif isinstance(generator_loss_weights, Scalar): + generator_loss_weights = [generator_loss_weights] * len(generator_losses) + elif isinstance(generator_loss_weights, list): + if len(generator_loss_weights) != len(generator_losses): + raise ValueError( + "Length of generator_loss_weights must match length of generator_losses." + ) + else: + raise TypeError( + "generator_loss_weights must be a float or list of floats." 
+ ) + + generator_loss_group = LossGroup( + items=[ + LossItem( + module=loss, + weight=weight, + args=('preds', 'targets'), + device=device + ) + for loss, weight in zip( + generator_losses, + generator_loss_weights + ) + ] + [ + LossItem( + module=generator_adverserial_loss, + weight=generator_adverserial_loss_weight, + args=('p_fake_as_real',), + device=device + ) + ] + ) + + discriminator_loss_group = LossGroup( + items=[ + LossItem( + module=discriminator_loss, + weight=discriminator_loss_weight, + args=('p_real_as_real', 'p_fake_as_real'), + device=device + ), + LossItem( + module=discriminator_gradient_penalty_loss, + weight=discriminator_gradient_penalty_loss_weight, + args=('real_stack', 'fake_stack', 'discriminator'), + device=device + ) + ] + ) + + super().__init__( + generator=generator, + discriminator=discriminator, + generator_optimizer=generator_optimizer, + discriminator_optimizer=discriminator_optimizer, + generator_loss_group=generator_loss_group, + discriminator_loss_group=discriminator_loss_group, + n_discriminator_steps=n_discriminator_steps, + **kwargs + ) diff --git a/src/virtual_stain_flow/trainers/logging_trainer.py b/src/virtual_stain_flow/trainers/logging_trainer.py index 9fe8da7..6442dea 100644 --- a/src/virtual_stain_flow/trainers/logging_trainer.py +++ b/src/virtual_stain_flow/trainers/logging_trainer.py @@ -16,6 +16,8 @@ from ..engine.loss_group import LossGroup, LossItem from ..engine.forward_groups import GeneratorForwardGroup +Scalar = Union[int, float] + class SingleGeneratorTrainer(AbstractTrainer): """ @@ -28,7 +30,7 @@ def __init__( optimizer: torch.optim.Optimizer, losses: Union[torch.nn.Module, List[torch.nn.Module]], device: torch.device, - loss_weights: Optional[Union[float, List[float]]] = None, + loss_weights: Optional[Union[Scalar, List[Scalar]]] = None, **kwargs ): """ @@ -42,6 +44,9 @@ def __init__( :kwargs: Additional arguments for the AbstractTrainer (for data/metric and more) """ + # Registry for logging model 
parameters + self._models: List[torch.nn.Module] = [model] + self._forward_group = GeneratorForwardGroup( generator=model, optimizer=optimizer, @@ -51,11 +56,16 @@ def __init__( losses = losses if isinstance(losses, list) else [losses] if loss_weights is None: loss_weights = [1.0] * len(losses) - elif isinstance(loss_weights, float): + elif isinstance(loss_weights, Scalar): loss_weights = [loss_weights] * len(losses) - elif len(loss_weights) != len(losses): - raise ValueError( - "Length of loss_weights must match length of losses." + elif isinstance(loss_weights, List): + if len(loss_weights) != len(losses): + raise ValueError( + "Length of loss_weights must match length of losses." + ) + else: + raise TypeError( + "loss_weights must be a float or list of floats." ) self._loss_group = LossGroup( @@ -72,7 +82,7 @@ def __init__( super().__init__( model=self._forward_group.model, - optimizer=self._forward_group.optimizer, + optimizer=self._forward_group.optimizer, # type: ignore **kwargs ) @@ -138,8 +148,6 @@ def save_model( file_ext: str = '.pth', best_model: bool = True ) -> Optional[List[pathlib.Path]]: - pass - if file_name_prefix is None: file_name_prefix = 'generator' diff --git a/src/virtual_stain_flow/trainers/trainer_protocol.py b/src/virtual_stain_flow/trainers/trainer_protocol.py index 4494330..80b55cc 100644 --- a/src/virtual_stain_flow/trainers/trainer_protocol.py +++ b/src/virtual_stain_flow/trainers/trainer_protocol.py @@ -4,7 +4,8 @@ Protocol for defining behavior and needed attributes of a trainer class. """ -from typing import Protocol, Dict, runtime_checkable +import pathlib +from typing import Protocol, Dict, runtime_checkable, Any, List, Optional import torch @@ -38,7 +39,7 @@ def train_epoch(self) -> Dict[str, float]: ... def evaluate_epoch(self) -> Dict[str, float]: ... - def train(self, num_epochs: int) -> None: ... + def train(self, *args: Any, **kwargs: Any) -> None: ... @property def epoch(self) -> int: ... 
@@ -54,3 +55,13 @@ def model(self) -> torch.nn.Module: ... @property def best_model(self) -> torch.nn.Module: ... + + def save_model( + self, + save_path: pathlib.Path, + file_name_prefix: Optional[str], + file_name_suffix: Optional[str], + file_ext: str = '.pth', + best_model: bool = True, + ) -> Optional[List[pathlib.Path]]: + ... diff --git a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py index b48c1ac..a46c61d 100644 --- a/src/virtual_stain_flow/vsf_logging/MlflowLogger.py +++ b/src/virtual_stain_flow/vsf_logging/MlflowLogger.py @@ -2,6 +2,7 @@ MlflowLogger.py """ +import json import pathlib import tempfile from typing import Union, Dict, Optional, List, Any @@ -9,6 +10,7 @@ import mlflow from torch import nn +from ..models.model import BaseModel from ..trainers.trainer_protocol import TrainerProtocol from .callbacks.LoggerCallback import ( AbstractLoggerCallback, @@ -41,8 +43,8 @@ def __init__( description: Optional[str] = None, target_channel_name: Optional[str] = None, tags: Optional[Dict[str, str]] = None, - mlflow_start_run_args: dict = None, - mlflow_log_params_args: dict = None, + mlflow_start_run_args: Optional[Dict] = None, + mlflow_log_params_args: Optional[Dict] = None, callbacks: Optional[List[Any]] = None, save_model_at_train_end: bool = True, save_model_every_n_epochs: Optional[int] = None, @@ -193,6 +195,33 @@ def on_train_start(self): raise TypeError("mlflow_log_params_args must be None or a dictionary.") self._run_id = mlflow.active_run().info.run_id + + # log model config if available + if hasattr(self.trainer, '_models') and isinstance(self.trainer._models, List): + models = self.trainer._models + elif hasattr(self.trainer, 'model'): + models = [self.trainer.model] + else: + models = [] + + for model in models: + + if isinstance(model, BaseModel) and hasattr(model, 'to_config'): + try: + config = model.to_config() # type: ignore + except Exception as e: + print(f"Could not get model 
config for logging: {e}") + config = None + if config: + try: + self.log_config( + tag=model.__class__.__name__, + config=config, # type: ignore + stage=None + ) + except Exception as e: + print(f"Fail to log model config as artifact: {e}") + for callback in self.callbacks: # TODO consider if we want hasattr checks @@ -467,11 +496,57 @@ def _save_model_weights( best_model=best_model ) - for saved_file_path in saved_file_paths: + for saved_file_path in (saved_file_paths or []): mlflow.log_artifact( str(saved_file_path), artifact_path=artifact_path ) + + def log_config( + self, + tag: str, + config: Dict[str, Any], + stage: Optional[str] = None, + ) -> None: + """ + Serialize a configuration dict to JSON, save it to a temporary file, + and log it to MLflow as an artifact. + + :param tag: Name/identifier for the config (used in filename / artifact path). + :param config: The configuration to log (must be a dictionary). + :param stage: Optional stage to nest under the artifact path + :raises TypeError: If `config` is not a dict. + """ + + if not isinstance(config, dict): + raise TypeError(f"`config` must be a dict, got {type(config).__name__}") + + # Where to place it inside MLflow’s artifact store (a directory path) + artifact_path = "/".join(p for p in ("configs", stage) if p) + + # Write JSON into a temporary directory so MLflow can copy it, then clean up. + with tempfile.TemporaryDirectory(prefix="log_config_") as tmpdir: + tmpdir_path = pathlib.Path(tmpdir) + file_path = tmpdir_path / f"{tag}.json" + + # Use default=str to avoid failures on non-JSON-serializable objects + # (e.g., pathlib.Path, numpy types, Enums); they'll be stringified. 
+ file_path.write_text( + json.dumps( + config, + indent=2, + sort_keys=True, + ensure_ascii=False, + default=str, + ), + encoding="utf-8", + ) + + # Log the JSON file as an MLflow artifact + mlflow.log_artifact( + str(file_path), + artifact_path=artifact_path + ) """ Access point for callback to model diff --git a/tests/engine/conftest.py b/tests/engine/conftest.py index aae74bb..63617f8 100644 --- a/tests/engine/conftest.py +++ b/tests/engine/conftest.py @@ -112,6 +112,43 @@ def forward(self, x): return MultiOutputConv() +@pytest.fixture +def simple_discriminator(): + """ + Simple discriminator model for GAN testing. + Takes concatenated input/target stack (B, 6, H, W) -> outputs score (B, 1) + Uses conv + global average pooling + linear layer. + """ + class SimpleDiscriminator(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d( + in_channels=6, # stacked input + target + out_channels=16, + kernel_size=3, + padding=1, + bias=True + ) + self.pool = nn.AdaptiveAvgPool2d(1) # Global average pooling + self.fc = nn.Linear(16, 1) # Output single score + + def forward(self, x): + x = self.conv(x) + x = torch.relu(x) + x = self.pool(x) # (B, 16, 1, 1) + x = x.flatten(1) # (B, 16) + x = self.fc(x) # (B, 1) + return x + + return SimpleDiscriminator() + + +@pytest.fixture +def random_stack(): + """Random stack tensor (batch=2, channels=6, height=8, width=8) for discriminator.""" + return torch.randn(2, 6, 8, 8) + + @pytest.fixture def sample_inputs(): """Create sample inputs for loss computation.""" @@ -208,3 +245,31 @@ def forward_pass_context_eval(forward_group, random_input, random_target, torch_ inputs=random_input.to(torch_device), targets=random_target.to(torch_device), ) + + +@pytest.fixture +def disc_optimizer(simple_discriminator): + """Create an Adam optimizer for the discriminator model.""" + import torch.optim as optim + return optim.Adam(simple_discriminator.parameters(), lr=1e-3) + + +@pytest.fixture +def 
discriminator_forward_group(simple_discriminator, disc_optimizer, torch_device): + """Create a DiscriminatorForwardGroup with the simple discriminator and optimizer.""" + from virtual_stain_flow.engine.forward_groups import DiscriminatorForwardGroup + return DiscriminatorForwardGroup( + discriminator=simple_discriminator, + optimizer=disc_optimizer, + device=torch_device, + ) + + +@pytest.fixture +def gan_orchestrator(forward_group, discriminator_forward_group): + """Create a GANOrchestrator with generator and discriminator forward groups.""" + from virtual_stain_flow.engine.orchestrators import GANOrchestrator + return GANOrchestrator( + generator_fg=forward_group, + discriminator_fg=discriminator_forward_group, + ) diff --git a/tests/engine/test_context.py b/tests/engine/test_context.py index 7c99f00..5b623e2 100644 --- a/tests/engine/test_context.py +++ b/tests/engine/test_context.py @@ -42,19 +42,32 @@ def test_context_with_module(self, simple_conv_model, method): def test_context_getitem(self, random_input): """Test retrieving items from context.""" ctx = Context(inputs=random_input) - retrieved = ctx[INPUTS] - assert torch.equal(retrieved, random_input) + retrieved = ctx[INPUTS] + assert torch.equal(retrieved, random_input)#type: ignore + + def test_invalid_context_value_type(self): + """Test that adding invalid context value types raises TypeError.""" + with pytest.raises( + TypeError, + match="Context values must be torch.Tensor or torch.nn.Module" + ): + ctx = Context() + ctx.add(invalid_value=42) # type: ignore - @pytest.mark.parametrize("key,value,expected_msg", [ - (PREDS, "not a tensor", "Reserved key 'preds' must be a torch.Tensor"), - (TARGETS, 42, "Reserved key 'targets' must be a torch.Tensor"), - (INPUTS, [1, 2, 3], "Reserved key 'inputs' must be a torch.Tensor"), + @pytest.mark.parametrize("key,expected_msg", [ + (PREDS, "Reserved key 'preds' must be a torch.Tensor"), + (TARGETS, "Reserved key 'targets' must be a torch.Tensor"), + (INPUTS, 
"Reserved key 'inputs' must be a torch.Tensor"), ]) - def test_context_reserved_key_type_error(self, key, value, expected_msg): + def test_context_reserved_key_type_error(self, key, expected_msg, simple_conv_model): """Test that reserved keys must be tensors.""" - with pytest.raises(ReservedKeyTypeError, match=expected_msg): + with pytest.raises( + ReservedKeyTypeError, + match=expected_msg + ): ctx = Context() - ctx.add(**{key: value}) + # try adding a module under a reserved tensor key + ctx.add(**{key: simple_conv_model}) def test_context_generator_model_addition(self, simple_conv_model): """Test adding generator model with reserved key.""" @@ -63,11 +76,15 @@ def test_context_generator_model_addition(self, simple_conv_model): assert isinstance(ctx[GENERATOR_MODEL], nn.Module) assert ctx[GENERATOR_MODEL] is simple_conv_model - def test_context_reserved_model_key_type_error(self): + def test_context_reserved_model_key_type_error(self, random_input): """Test that reserved model keys must be torch.nn.Module.""" - with pytest.raises(ReservedKeyTypeError, match="Reserved key 'generator' must be a torch.nn.Module"): + with pytest.raises( + ReservedKeyTypeError, + match="Reserved key 'generator' must be a torch.nn.Module" + ): ctx = Context() - ctx.add(generator="not a module") + # try adding a tensor under a reserved model key + ctx.add(generator=random_input) #type: ignore class TestContextRequire: @@ -146,7 +163,7 @@ def test_repr_with_module(self, simple_conv_model): def test_getitem(self, random_input): """Test __getitem__ retrieves stored values.""" ctx = Context(inputs=random_input) - assert torch.equal(ctx[INPUTS], random_input) + assert torch.equal(ctx[INPUTS], random_input) #type: ignore def test_getitem_missing_key(self): """Test __getitem__ raises KeyError for missing key.""" @@ -203,15 +220,15 @@ def test_setitem(self, random_input): """Test __setitem__ to add/update values.""" ctx = Context() ctx[INPUTS] = random_input - assert torch.equal(ctx[INPUTS], 
random_input) + assert torch.equal(ctx[INPUTS], random_input) #type: ignore assert len(ctx) == 1 def test_setitem_override(self, random_input, random_target): """Test __setitem__ overrides existing value.""" ctx = Context(inputs=random_input) ctx[INPUTS] = random_target - assert torch.equal(ctx[INPUTS], random_target) - assert not torch.equal(ctx[INPUTS], random_input) + assert torch.equal(ctx[INPUTS], random_target) #type: ignore + assert not torch.equal(ctx[INPUTS], random_input) #type: ignore def test_contains_present_key(self, random_input): """Test __contains__ for present key.""" @@ -229,7 +246,7 @@ def test_get_existing_key(self, random_input): """Test get() with existing key.""" ctx = Context(inputs=random_input) retrieved = ctx.get(INPUTS) - assert torch.equal(retrieved, random_input) + assert torch.equal(retrieved, random_input) #type: ignore def test_get_missing_key_default_none(self): """Test get() with missing key returns None by default.""" @@ -241,5 +258,80 @@ def test_get_missing_key_custom_default(self, random_input): """Test get() with missing key returns custom default.""" ctx = Context() default_value = "default" - result = ctx.get("nonexistent", default_value) + result = ctx.get("nonexistent", default_value) #type: ignore assert result == default_value + + +class TestContextMerge: + """Test Context merge operations using | operator.""" + + def test_or_merge_basic(self, random_input, random_target): + """Test basic merge of two contexts with different keys.""" + ctx1 = Context(inputs=random_input) + ctx2 = Context(targets=random_target) + + merged = ctx1 | ctx2 + + assert INPUTS in merged + assert TARGETS in merged + assert torch.equal(merged[INPUTS], random_input) #type: ignore + assert torch.equal(merged[TARGETS], random_target) #type: ignore + + def test_or_merge_precedence(self, random_input, random_target): + """Test that right operand takes precedence in key conflicts.""" + ctx1 = Context(inputs=random_input) + ctx2 = 
Context(inputs=random_target) + + merged = ctx1 | ctx2 + + # ctx2's value should win + assert torch.equal(merged[INPUTS], random_target) #type: ignore + assert not torch.equal(merged[INPUTS], random_input) #type: ignore + + def test_or_original_unchanged(self, random_input, random_target): + """Test that original contexts are unchanged after merge.""" + ctx1 = Context(inputs=random_input) + ctx2 = Context(targets=random_target) + + _ = ctx1 | ctx2 + + # Original contexts should remain unchanged + assert len(ctx1) == 1 + assert len(ctx2) == 1 + assert TARGETS not in ctx1 + assert INPUTS not in ctx2 + + def test_ror_merge_basic(self, random_input, random_target): + """Test reverse merge with same result as forward merge.""" + ctx1 = Context(inputs=random_input) + ctx2 = Context(targets=random_target) + + # Both should produce same result for non-overlapping keys + merged_or = ctx1 | ctx2 + merged_ror = ctx2.__ror__(ctx1) + + assert set(merged_or.keys()) == set(merged_ror.keys()) + assert torch.equal(merged_or[INPUTS], merged_ror[INPUTS]) #type: ignore + assert torch.equal(merged_or[TARGETS], merged_ror[TARGETS]) #type: ignore + + def test_or_not_implemented(self, random_input): + """Test __or__ raises NotImplementedError for non-Context operand.""" + ctx = Context(inputs=random_input) + x = {} + + with pytest.raises( + NotImplementedError, + match="__or__ operation only supported between Context objects." + ): + _ = ctx | x # type: ignore + + def test_ror_not_implemented(self, random_input): + """Test __ror__ raises NotImplementedError for non-Context operand.""" + ctx = Context(inputs=random_input) + x = {} + + with pytest.raises( + NotImplementedError, + match="__or__ operation only supported between Context objects." 
+ ): + _ = x | ctx # type: ignore diff --git a/tests/engine/test_forward_group.py b/tests/engine/test_forward_group.py index 93e6a9e..b4b6c7d 100644 --- a/tests/engine/test_forward_group.py +++ b/tests/engine/test_forward_group.py @@ -5,7 +5,8 @@ from virtual_stain_flow.engine.forward_groups import ( AbstractForwardGroup, - GeneratorForwardGroup + GeneratorForwardGroup, + DiscriminatorForwardGroup ) from virtual_stain_flow.engine.names import INPUTS, TARGETS, PREDS, GENERATOR_MODEL @@ -128,3 +129,52 @@ def test_forward_output_arity_mismatch(self, multi_output_model, random_input, r with pytest.raises(ValueError, match="Model returned 2 outputs.*output_keys expects 1"): forward_group(train=False, inputs=random_input, targets=random_target) + + +class TestDiscriminatorForwardGroup: + """Test DiscriminatorForwardGroup functionality.""" + + def test_forward_train_mode(self, simple_discriminator, random_stack): + """Test that discriminator is set to train mode when train=True.""" + forward_group = DiscriminatorForwardGroup( + device=torch.device("cpu"), + discriminator=simple_discriminator + ) + + ctx = forward_group(train=True, stack=random_stack) + + assert forward_group.model.training is True + assert ctx["p"].requires_grad is True + + def test_forward_eval_mode(self, simple_discriminator, random_stack): + """Test that discriminator is set to eval mode when train=False.""" + forward_group = DiscriminatorForwardGroup( + device=torch.device("cpu"), + discriminator=simple_discriminator + ) + + ctx = forward_group(train=False, stack=random_stack) + + assert forward_group.model.training is False + assert ctx["p"].requires_grad is False + + def test_optimizer_zero_grad(self, simple_discriminator, disc_optimizer, random_stack): + """Test that optimizer.zero_grad() is called when train=True.""" + forward_group = DiscriminatorForwardGroup( + device=torch.device("cpu"), + discriminator=simple_discriminator, + optimizer=disc_optimizer + ) + + # Manually create some gradients + 
dummy_loss = sum(p.sum() for p in forward_group.model.parameters()) + dummy_loss.backward() + + # Check that gradients exist + assert any(p.grad is not None for p in forward_group.model.parameters()) + + # Forward should zero gradients + _ = forward_group(train=True, stack=random_stack) + + # Gradients should be None (set_to_none=True) + assert all(p.grad is None for p in forward_group.model.parameters()) diff --git a/tests/engine/test_orchestrator.py b/tests/engine/test_orchestrator.py new file mode 100644 index 0000000..0493e62 --- /dev/null +++ b/tests/engine/test_orchestrator.py @@ -0,0 +1,63 @@ +"""Tests for GANOrchestrator.""" + +import torch + +from virtual_stain_flow.engine.names import INPUTS, TARGETS, PREDS + + +class TestGANOrchestrator: + """Test GANOrchestrator functionality.""" + + def test_discriminator_forward(self, gan_orchestrator, random_input, random_target): + """Test that _discriminator_forward produces correct context with real and fake stacks.""" + ctx = gan_orchestrator._discriminator_forward( + train=False, + inputs=random_input, + targets=random_target + ) + + # Check that generator outputs are present + assert INPUTS in ctx + assert TARGETS in ctx + assert PREDS in ctx + + # Check that discriminator outputs for real and fake are present + assert "real_stack" in ctx + assert "fake_stack" in ctx + assert "p_real_as_real" in ctx + assert "p_fake_as_real" in ctx + + # Verify shapes + batch_size = random_input.shape[0] + assert ctx["p_real_as_real"].shape[0] == batch_size + assert ctx["p_fake_as_real"].shape[0] == batch_size + + # Verify real_stack is concatenation of inputs and targets + expected_real_stack = torch.cat([ctx[INPUTS], ctx[TARGETS]], dim=1) + assert torch.allclose(ctx["real_stack"], expected_real_stack) + + # Verify fake_stack is concatenation of inputs and preds + expected_fake_stack = torch.cat([ctx[INPUTS], ctx[PREDS]], dim=1) + assert torch.allclose(ctx["fake_stack"], expected_fake_stack) + + def 
test_generator_forward(self, gan_orchestrator, random_input, random_target): + """Test that _generator_forward produces correct context with generator outputs and discriminator score.""" + ctx = gan_orchestrator._generator_forward( + train=False, + inputs=random_input, + targets=random_target + ) + + # Check that generator outputs are present + assert INPUTS in ctx + assert TARGETS in ctx + assert PREDS in ctx + + # Check that discriminator score for fake is present + assert "p_fake_as_real" in ctx + + # Verify shapes + batch_size = random_input.shape[0] + assert ctx[PREDS].shape[0] == batch_size + assert ctx["p_fake_as_real"].shape[0] == batch_size + assert ctx["p_fake_as_real"].shape[1] == 1 # Single score output diff --git a/tests/trainers/conftest.py b/tests/trainers/conftest.py index 6992140..cd5e443 100644 --- a/tests/trainers/conftest.py +++ b/tests/trainers/conftest.py @@ -29,9 +29,9 @@ class DummyProgressBar: def set_postfix_str(self, *args, **kwargs): pass - self._epoch_pbar = DummyProgressBar() + self._epoch_pbar = DummyProgressBar() # type: ignore - def train_step(self, inputs: torch.tensor, targets: torch.tensor) -> dict: + def train_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict: """ Minimal train step that returns a dict of losses. Stores call information for verification. @@ -48,7 +48,7 @@ def train_step(self, inputs: torch.tensor, targets: torch.tensor) -> dict: 'loss_b': torch.tensor(0.3), } - def evaluate_step(self, inputs: torch.tensor, targets: torch.tensor) -> dict: + def evaluate_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> dict: """ Minimal evaluate step that returns a dict of losses. Stores call information for verification. @@ -154,3 +154,55 @@ def conv_trainer(conv_model, conv_optimizer, simple_loss, image_train_loader, im early_termination_metric='MSELoss' ) return trainer + + +@pytest.fixture +def simple_discriminator(): + """ + Simple discriminator model for GAN testing. 
+ Takes concatenated input/target stack (B, 2, H, W) -> outputs score (B, 1) + """ + import torch.nn as nn + + class SimpleDiscriminator(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(in_channels=2, out_channels=16, kernel_size=3, padding=1) + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(16, 1) + + def forward(self, x): + x = torch.relu(self.conv(x)) + x = self.pool(x).flatten(1) + return self.fc(x) + + return SimpleDiscriminator() + + +@pytest.fixture +def discriminator_optimizer(simple_discriminator): + """Create an optimizer for the discriminator.""" + return torch.optim.Adam(simple_discriminator.parameters(), lr=0.0001) + + +@pytest.fixture +def wgan_trainer(conv_model, simple_discriminator, conv_optimizer, discriminator_optimizer, + simple_loss, image_train_loader, image_val_loader): + """ + Create a LoggingWGANTrainer for testing. + """ + from virtual_stain_flow.trainers.logging_gan_trainer import LoggingWGANTrainer + + trainer = LoggingWGANTrainer( + generator=conv_model, + discriminator=simple_discriminator, + generator_optimizer=conv_optimizer, + discriminator_optimizer=discriminator_optimizer, + generator_losses=simple_loss, + device=torch.device('cpu'), + train_loader=image_train_loader, + val_loader=image_val_loader, + batch_size=4, + n_discriminator_steps=3 + ) + return trainer diff --git a/tests/trainers/test_logging_wgan_trainer.py b/tests/trainers/test_logging_wgan_trainer.py new file mode 100644 index 0000000..54004b0 --- /dev/null +++ b/tests/trainers/test_logging_wgan_trainer.py @@ -0,0 +1,179 @@ +""" +Tests for LoggingWGANTrainer train_step and evaluate_step methods +""" + +import torch + + +class TestLoggingWGANTrainerTrainStep: + """Tests for LoggingWGANTrainer.train_step method.""" + + def test_train_step_returns_dict(self, wgan_trainer): + """Test that train_step returns a dictionary.""" + inputs = torch.randn(2, 1, 16, 16) + targets = torch.randn(2, 1, 16, 16) + + losses = 
wgan_trainer.train_step(inputs, targets) + + assert isinstance(losses, dict) + + def test_train_step_returns_generator_and_discriminator_losses(self, wgan_trainer): + """Test that train_step returns both generator and discriminator losses.""" + inputs = torch.randn(2, 1, 16, 16) + targets = torch.randn(2, 1, 16, 16) + + losses = wgan_trainer.train_step(inputs, targets) + + # Should have generator loss (MSE + Adversarial) + assert 'MSELoss' in losses + assert 'AdversarialLoss' in losses + # Should have discriminator losses (Wasserstein + GP) + assert 'WassersteinLoss' in losses + assert 'GradientPenaltyLoss' in losses + + def test_train_step_updates_discriminator_every_step(self, wgan_trainer): + """Test that discriminator is updated every training step.""" + inputs = torch.randn(2, 1, 16, 16) + targets = torch.randn(2, 1, 16, 16) + + # Get initial discriminator parameters + disc_params_before = [ + p.clone() for p in wgan_trainer._orchestrator.discriminator_forward_group.model.parameters() + ] + + # Run train step + wgan_trainer.train_step(inputs, targets) + + # Check that discriminator parameters changed + disc_params_after = list(wgan_trainer._orchestrator.discriminator_forward_group.model.parameters()) + + params_changed = any( + not torch.equal(p_before, p_after) + for p_before, p_after in zip(disc_params_before, disc_params_after) + ) + assert params_changed + + def test_train_step_generator_update_frequency(self, wgan_trainer): + """Test that generator is updated according to n_discriminator_steps.""" + inputs = torch.randn(2, 1, 16, 16) + targets = torch.randn(2, 1, 16, 16) + + # Reset global step to ensure consistent starting point + wgan_trainer._global_step = 0 + + # Get initial generator parameters + gen_params_before = [ + p.clone() for p in wgan_trainer.model.parameters() + ] + + # Run first train step (_global_step=0, should update) + wgan_trainer.train_step(inputs, targets) + + gen_params_after_step0 = [ + p.clone() for p in 
wgan_trainer.model.parameters() + ] + + # Generator should have updated on step 0 + params_changed_step0 = any( + not torch.equal(p_before, p_after) + for p_before, p_after in zip(gen_params_before, gen_params_after_step0) + ) + assert params_changed_step0, "Generator should update at step 0" + + # Run step 1 (_global_step=1, should NOT update) + wgan_trainer.train_step(inputs, targets) + + gen_params_after_step1 = [ + p.clone() for p in wgan_trainer.model.parameters() + ] + + # Generator should NOT have changed from step 0 to step 1 + params_unchanged_step1 = all( + torch.equal(p_step0, p_step1) + for p_step0, p_step1 in zip(gen_params_after_step0, gen_params_after_step1) + ) + assert params_unchanged_step1, "Generator should not update at step 1" + + # Run step 2 (_global_step=2, should NOT update) + wgan_trainer.train_step(inputs, targets) + + gen_params_after_step2 = [ + p.clone() for p in wgan_trainer.model.parameters() + ] + + # Generator should NOT have changed from step 1 to step 2 + params_unchanged_step2 = all( + torch.equal(p_step1, p_step2) + for p_step1, p_step2 in zip(gen_params_after_step1, gen_params_after_step2) + ) + assert params_unchanged_step2, "Generator should not update at step 2" + + # Run step 3 (_global_step=3, should update) + wgan_trainer.train_step(inputs, targets) + + gen_params_after_step3 = [ + p.clone() for p in wgan_trainer.model.parameters() + ] + + params_changed_step3 = any( + not torch.equal(p_step2, p_step3) + for p_step2, p_step3 in zip(gen_params_after_step2, gen_params_after_step3) + ) + assert params_changed_step3, "Generator should update at step 3" + + +class TestLoggingWGANTrainerEvaluateStep: + """Tests for LoggingWGANTrainer.evaluate_step method.""" + + def test_evaluate_step_returns_dict(self, wgan_trainer): + """Test that evaluate_step returns a dictionary.""" + inputs = torch.randn(2, 1, 16, 16) + targets = torch.randn(2, 1, 16, 16) + + losses = wgan_trainer.evaluate_step(inputs, targets) + + assert 
isinstance(losses, dict) + + def test_evaluate_step_returns_generator_and_discriminator_losses(self, wgan_trainer): + """Test that evaluate_step returns both generator and discriminator losses.""" + inputs = torch.randn(2, 1, 16, 16) + targets = torch.randn(2, 1, 16, 16) + + losses = wgan_trainer.evaluate_step(inputs, targets) + + # Should have generator losses + assert 'MSELoss' in losses + assert 'AdversarialLoss' in losses + # Should have discriminator losses + assert 'WassersteinLoss' in losses + assert 'GradientPenaltyLoss' in losses + + def test_evaluate_step_does_not_update_models(self, wgan_trainer): + """Test that evaluate_step does not update generator or discriminator.""" + inputs = torch.randn(2, 1, 16, 16) + targets = torch.randn(2, 1, 16, 16) + + # Get initial parameters + gen_params_before = [p.clone() for p in wgan_trainer.model.parameters()] + disc_params_before = [ + p.clone() for p in wgan_trainer._orchestrator.discriminator_forward_group.model.parameters() + ] + + # Run evaluate step + wgan_trainer.evaluate_step(inputs, targets) + + # Check that parameters did not change + gen_params_after = list(wgan_trainer.model.parameters()) + disc_params_after = list(wgan_trainer._orchestrator.discriminator_forward_group.model.parameters()) + + gen_changed = any( + not torch.equal(p_before, p_after) + for p_before, p_after in zip(gen_params_before, gen_params_after) + ) + disc_changed = any( + not torch.equal(p_before, p_after) + for p_before, p_after in zip(disc_params_before, disc_params_after) + ) + + assert not gen_changed + assert not disc_changed