Changes from all commits
23 commits
d16d917
Add pop method and merge operators to Context class for the convenien…
wli51 Jan 5, 2026
cea1cc3
Add DiscriminatorForwardGroup class for GAN discriminator management
wli51 Jan 5, 2026
123ece2
Add GANOrchestrator and OrchestratedStep class to abstract away the G…
wli51 Jan 5, 2026
b2aae07
Refactor GANOrchestrator methods to streamline context handling and i…
wli51 Jan 5, 2026
668a213
Add minimal testing for GANOrchestrator and implement simple discrimi…
wli51 Jan 5, 2026
adcbb1f
Enhance Context class with reserved key checks and add properties for…
wli51 Jan 7, 2026
f53527c
Refactor DiscriminatorForwardGroup to use constant keys for model and…
wli51 Jan 7, 2026
6ba51f4
Refactor GANOrchestrator to tighten type checks accessing values from…
wli51 Jan 7, 2026
146d238
Refactor AbstractBlock and Stage classes to improve output channel ha…
wli51 Jan 7, 2026
c91d2c2
Refactor type annotations in AbstractForwardGroup and its subclasses …
wli51 Jan 7, 2026
f053713
Add tests for or and ror operations and switch from returning NotImp…
wli51 Jan 7, 2026
8e6b630
Enhance Context class with type checks for values and update pop meth…
wli51 Jan 7, 2026
83d35d1
Tighten up typing of the loss and lossgroup modules. Due to the need …
wli51 Jan 8, 2026
b4b6edb
Refactor type annotations in AbstractTrainer class to use Tensor type…
wli51 Jan 8, 2026
fd312b9
Fix return type of from_config method in BaseModel to return BaseMode…
wli51 Jan 8, 2026
df370b9
Refactor train method in TrainerProtocol to accept variable arguments…
wli51 Jan 8, 2026
21f0d85
Tighten up type annotations in BaseModel class for to_config method
wli51 Jan 8, 2026
78876a0
Allow customizable or no output activation function in GlobalDiscrimi…
wli51 Jan 8, 2026
68f0ca2
Add model configuration logging (as artifact) to MlflowLogger
wli51 Jan 8, 2026
7fb9693
Enhance TrainerProtocol with model saving functionality and update Ml…
wli51 Jan 8, 2026
f068af2
Refactor SingleGeneratorTrainer to enhance type annotations for loss_…
wli51 Jan 8, 2026
c048d40
Add LoggingGANTrainer class for enhanced GAN training and evaluation …
wli51 Jan 8, 2026
afb60d8
Add minimal tests for LoggingWGANTrainer train_step and evaluate_step…
wli51 Jan 8, 2026
110 changes: 92 additions & 18 deletions src/virtual_stain_flow/engine/context.py
@@ -6,11 +6,12 @@
isolated and modular computations.
"""

from typing import Dict, Iterable, Tuple, Union
from typing import Dict, Iterable, Union, Optional

import torch
from torch import Tensor

from .names import TARGETS, PREDS, RESERVED_KEYS, RESERVED_MODEL_KEYS
from .names import INPUTS, TARGETS, PREDS, RESERVED_KEYS, RESERVED_MODEL_KEYS

ContextValue = Union[torch.Tensor, torch.nn.Module]

@@ -47,17 +48,9 @@ def add(self, **items: ContextValue) -> "Context":
where keys are the names of the tensors.
"""

for k, v in items.items():
if k in RESERVED_KEYS and not isinstance(v, torch.Tensor):
raise ReservedKeyTypeError(
f"Reserved key '{k}' must be a torch.Tensor, got {type(v)}"
)
elif k in RESERVED_MODEL_KEYS and not isinstance(v, torch.nn.Module):
raise ReservedKeyTypeError(
f"Reserved key '{k}' must be a torch.nn.Module, got {type(v)}"
)

self._store.update(items)
for key, value in items.items():
self[key] = value

return self
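Since add now funnels every item through __setitem__, both assignment paths share the same validation. A minimal sketch, assuming the constructor also accepts zero arguments (keys here are arbitrary, non-reserved names):

import torch
from virtual_stain_flow.engine.context import Context

# add() returns self, so calls chain; each item passes through __setitem__
ctx = Context().add(mask=torch.ones(1, 8, 8))
ctx.add(weights=torch.full((1,), 0.5))  # equivalent to ctx["weights"] = ...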

def require(self, keys: Iterable[str]) -> None:
@@ -81,15 +74,18 @@ def as_kwargs(self) -> Dict[str, ContextValue]:
"""
return self._store

def as_metric_args(self) -> Tuple[ContextValue, ContextValue]:
def as_metric_args(self) -> tuple[Tensor, Tensor]:
"""
Returns the predictions and targets tensors for
image quality assessment metric computation.
Intended use: metric.update(*ctx.as_metric_args())

:return: A tuple (preds, targets) of tensors.
:raises ValueError: If either preds or targets is missing.
"""
self.require([PREDS, TARGETS])
preds = self._store[PREDS]
targs = self._store[TARGETS]
self.require(keys=[PREDS, TARGETS])
preds: Tensor = self.preds
targs: Tensor = self.targets
return (preds, targs)
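A usage sketch following the docstring's intended use; torchmetrics' PSNR is illustrative, and any metric with an update(preds, targets) method fits:

import torch
from torchmetrics.image import PeakSignalNoiseRatio

from virtual_stain_flow.engine.context import Context
from virtual_stain_flow.engine.names import PREDS, TARGETS

ctx = Context(**{
    PREDS: torch.rand(4, 1, 64, 64),     # stand-in model outputs
    TARGETS: torch.rand(4, 1, 64, 64),   # stand-in ground truth
})

metric = PeakSignalNoiseRatio()
metric.update(*ctx.as_metric_args())     # unpacks the (preds, targets) tuple
print(metric.compute())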

def __repr__(self) -> str:
@@ -109,6 +105,28 @@ def __repr__(self) -> str:
# --- Methods for dict like behavior of context class ---

def __setitem__(self, key: str, value: ContextValue) -> None:
"""
Sets a context item, with checks for reserved keys.

:param key: The name of the context item.
:param value: The tensor/module to store.
"""
# Only allow torch.Tensor or torch.nn.Module values
if not isinstance(value, (torch.Tensor, torch.nn.Module)):
raise TypeError(
f"Context values must be torch.Tensor or torch.nn.Module, got {type(value)}"
)

# Further type check matching for reserved keys
if key in RESERVED_KEYS and not isinstance(value, torch.Tensor):
raise ReservedKeyTypeError(
f"Reserved key '{key}' must be a torch.Tensor, got {type(value)}"
)
elif key in RESERVED_MODEL_KEYS and not isinstance(value, torch.nn.Module):
raise ReservedKeyTypeError(
f"Reserved key '{key}' must be a torch.nn.Module, got {type(value)}"
)

self._store[key] = value
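A sketch of the three guard paths, assuming ReservedKeyTypeError is exported from this module as the raises above imply:

import torch
from virtual_stain_flow.engine.context import Context, ReservedKeyTypeError
from virtual_stain_flow.engine.names import TARGETS

ctx = Context()
ctx["mask"] = torch.ones(1, 8, 8)        # unreserved key: any Tensor or Module

try:
    ctx[TARGETS] = torch.nn.Identity()   # reserved tensor key given a Module
except ReservedKeyTypeError as exc:
    print(exc)

try:
    ctx["count"] = 3                     # plain Python values are rejected
except TypeError as exc:
    print(exc)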

def __contains__(self, key: str) -> bool:
@@ -123,7 +141,7 @@ def __iter__(self):
def __len__(self):
return len(self._store)

def get(self, key: str, default: ContextValue = None) -> ContextValue:
def get(self, key: str, default: Optional[ContextValue] = None) -> Optional[ContextValue]:
return self._store.get(key, default)

def values(self):
@@ -134,3 +152,59 @@ def items(self):

def keys(self):
return self._store.keys()

def pop(self, key: str) -> ContextValue:
"""
Remove and return the value for key if key is in the context,
else raises a KeyError.
"""
if key not in self._store:
raise KeyError(f"Key '{key}' not found in Context.")
return self._store.pop(key)

def __or__(self, other: "Context") -> "Context":
"""
Merge two Context objects using the | operator.
Returns a new Context with items from both contexts.
Items from the right operand (other) take precedence in case of key conflicts.

:param other: Another Context object to merge with.
:return: A new Context object containing items from both contexts.
"""
if not isinstance(other, Context):
raise NotImplementedError(
"__or__ operation only supported between Context objects."
)
new_context = Context(**self._store)
new_context.add(**other._store)
return new_context

def __ror__(self, other: "Context") -> "Context":
"""
Reverse merge (right | operator) for Context objects.
Called when the left operand doesn't support __or__ with Context.

:param other: Another Context object to merge with.
:return: A new Context object containing items from both contexts.
"""
if not isinstance(other, Context):
raise NotImplementedError(
"__or__ operation only supported between Context objects."
)
new_context = Context(**other._store)
new_context.add(**self._store)
return new_context
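The merge semantics mirror dict's | operator; a short sketch with arbitrary, non-reserved keys:

import torch
from virtual_stain_flow.engine.context import Context

a = Context(mask=torch.zeros(1, 4, 4), bias=torch.ones(1))
b = Context(mask=torch.ones(1, 4, 4))

merged = a | b
assert torch.equal(merged["mask"], b["mask"])   # right operand wins conflicts
assert "bias" in merged and len(a) == 2         # operands are left untouched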

# --- Properties for robust typing for reserved keys ---
# let fail if key is not present
@property
def inputs(self) -> Tensor:
return self._store[INPUTS] # type: ignore

@property
def targets(self) -> Tensor:
return self._store[TARGETS] # type: ignore

@property
def preds(self) -> Tensor:
return self._store[PREDS] # type: ignore
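The typed accessors give static checkers a Tensor instead of the ContextValue union; paired with pop, a read-then-consume sketch:

import torch
from virtual_stain_flow.engine.context import Context
from virtual_stain_flow.engine.names import INPUTS, TARGETS, PREDS

ctx = Context(**{
    INPUTS: torch.rand(2, 1, 8, 8),
    TARGETS: torch.rand(2, 1, 8, 8),
    PREDS: torch.rand(2, 1, 8, 8),
})

x = ctx.inputs            # typed as Tensor; raises KeyError if unset
y = ctx.pop(TARGETS)      # remove and return; KeyError if absent
assert TARGETS not in ctx and x.shape == y.shape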
103 changes: 94 additions & 9 deletions src/virtual_stain_flow/engine/forward_groups.py
@@ -31,13 +31,13 @@
"""

from abc import ABC, abstractmethod
from typing import Optional, Tuple, Dict
from typing import Optional, Dict

import torch
import torch.optim as optim
import torch.nn as nn

from .names import INPUTS, TARGETS, PREDS, GENERATOR_MODEL
from .names import INPUTS, TARGETS, PREDS, GENERATOR_MODEL, DISCRIMINATOR_MODEL
from .context import Context


@@ -52,9 +52,9 @@ class AbstractForwardGroup(ABC):
"""

# Subclasses should override these with ordered tuples.
input_keys: Tuple[str, ...]
target_keys: Tuple[str, ...]
output_keys: Tuple[str, ...]
input_keys: tuple[str, ...]
target_keys: tuple[str, ...]
output_keys: tuple[str, ...]

def __init__(
self,
@@ -75,7 +75,7 @@ def _move_tensors(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
}

@staticmethod
def _normalize_outputs(raw) -> Tuple[torch.Tensor, ...]:
def _normalize_outputs(raw) -> tuple[torch.Tensor, ...]:
"""
Normalize model outputs to a tuple of tensors while preserving order.

@@ -140,9 +140,9 @@ class GeneratorForwardGroup(AbstractForwardGroup):
metric_value = metric_fn(preds, targets)
"""

input_keys: Tuple[str, ...] = (INPUTS,)
target_keys: Tuple[str, ...] = (TARGETS,)
output_keys: Tuple[str, ...] = (PREDS,)
input_keys: tuple[str, ...] = (INPUTS,)
target_keys: tuple[str, ...] = (TARGETS,)
output_keys: tuple[str, ...] = (PREDS,)

def __init__(
self,
@@ -207,3 +207,88 @@ def optimizer(self) -> Optional[optim.Optimizer]:
Convenience property to access the generator optimizer directly.
"""
return self._optimizers[GENERATOR_MODEL]
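A sketch of one generator forward pass; the constructor's keyword names and the __call__(train, **inputs) contract are inferred from the discriminator group below and may differ in the hidden hunks:

import torch
import torch.nn as nn

from virtual_stain_flow.engine.forward_groups import GeneratorForwardGroup
from virtual_stain_flow.engine.names import INPUTS, TARGETS

generator = nn.Conv2d(1, 1, kernel_size=3, padding=1)   # stand-in generator
group = GeneratorForwardGroup(
    generator=generator,
    optimizer=torch.optim.Adam(generator.parameters(), lr=1e-4),
)

batch = {INPUTS: torch.rand(2, 1, 16, 16), TARGETS: torch.rand(2, 1, 16, 16)}
ctx = group(train=True, **batch)                 # stages tensors, runs forward
loss = nn.functional.l1_loss(*ctx.as_metric_args())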


class DiscriminatorForwardGroup(AbstractForwardGroup):
"""
Forward group for a simple single (GAN/wGAN) discriminator workflow.
The discriminator is assumed to take in a "stack" of input and target
images concatenated along the channel dimension, and output a score/probability.
Relevant context values are input_keys, target_keys, output_keys for a
single-discriminator model, where:
- the forward is called as:
p = discriminator(stack)
- the evaluation is less straightforward, but typically involves
computing losses/metrics based on p and real/fake labels:
metric_value = metric_fn(p, real_or_fake_labels)
or perhaps involving the discriminator model itself for Wasserstein distance:
metric_value = metric_fn(discriminator, stack, real_or_fake_labels)
"""

input_keys: tuple[str, ...] = ("stack",)
target_keys: tuple[str, ...] = ()
output_keys: tuple[str, ...] = ("p",)

def __init__(
self,
discriminator: nn.Module,
optimizer: Optional[optim.Optimizer] = None,
device: torch.device = torch.device("cpu"),
):
super().__init__(device=device)

self._models[DISCRIMINATOR_MODEL] = discriminator
self._models[DISCRIMINATOR_MODEL].to(self.device)
self._optimizers[DISCRIMINATOR_MODEL] = optimizer

def __call__(self, train: bool, **inputs: torch.Tensor) -> Context:
"""
Executes the forward pass, managing training/eval modes and optimizer steps.
Subclasses may override this method if needed.

:param train: Whether to run in training mode. Meant to be specified
by the trainer to switch between train/eval modes and determine
whether gradients should be computed.
:param inputs: Keyword arguments of input tensors.
"""

fp_model = self.model
fp_optimizer = self.optimizer

# 1) Stage and validate inputs/targets
ctx = Context(**self._move_tensors(inputs), **{DISCRIMINATOR_MODEL: fp_model})
ctx.require(self.input_keys)
ctx.require(self.target_keys)

# 2) Forward, with grad only when training
fp_model.train(mode=train)
if train and fp_optimizer is not None:
    fp_optimizer.zero_grad(set_to_none=True)
with torch.set_grad_enabled(train):
model_inputs = [ctx[k] for k in self.input_keys] # ordered
raw = fp_model(*model_inputs)
y_tuple = self._normalize_outputs(raw)

# 3) Arity check + map outputs to names
if len(y_tuple) != len(self.output_keys):
raise ValueError(
f"Model returned {len(y_tuple)} outputs, "
f"but output_keys expects {len(self.output_keys)}"
)
outputs = {k: v for k, v in zip(self.output_keys, y_tuple)}

# 4) Return enriched context (outputs available for losses/metrics)
return ctx.add(**outputs)

@property
def model(self) -> nn.Module:
"""
Convenience property to access the discriminator model directly.
"""
return self._models[DISCRIMINATOR_MODEL]

@property
def optimizer(self) -> Optional[optim.Optimizer]:
"""
Convenience property to access the discriminator optimizer directly.
"""
return self._optimizers[DISCRIMINATOR_MODEL]
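A sketch of the discriminator path following the class docstring's channel-stacking convention; the tiny critic is illustrative:

import torch
import torch.nn as nn

from virtual_stain_flow.engine.forward_groups import DiscriminatorForwardGroup

critic = nn.Sequential(                 # stand-in critic over a 2-channel stack
    nn.Conv2d(2, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 1),
)
group = DiscriminatorForwardGroup(
    discriminator=critic,
    optimizer=torch.optim.Adam(critic.parameters(), lr=1e-4),
)

inputs = torch.rand(2, 1, 16, 16)
targets = torch.rand(2, 1, 16, 16)
stack = torch.cat([inputs, targets], dim=1)   # concat along channel dimension

ctx = group(train=True, stack=stack)          # input_keys == ("stack",)
score = ctx["p"]                              # output_keys == ("p",)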
20 changes: 12 additions & 8 deletions src/virtual_stain_flow/engine/loss_group.py
@@ -26,10 +26,12 @@

import torch

from .loss_utils import AbstractLoss, _get_loss_name, _scalar_from_ctx
from .context import Context
from .loss_utils import BaseLoss, _get_loss_name, _scalar_from_ctx
from .context import Context, ContextValue
from .names import PREDS, TARGETS

Scalar = Union[int, float, bool]


@dataclass
class LossItem:
@@ -53,7 +55,7 @@ class LossItem:
losses and centralizes device management.
)
"""
module: Union[torch.nn.Module, AbstractLoss]
module: Union[torch.nn.Module, BaseLoss]
args: Union[str, Tuple[str, ...]] = (PREDS, TARGETS)
key: Optional[str] = None
weight: float = 1.0
Expand All @@ -63,7 +65,7 @@ class LossItem:

def __post_init__(self):

self.key = self.key or _get_loss_name(self.module)
self.key = str(self.key or _get_loss_name(self.module))
self.args = (self.args,) if isinstance(self.args, str) else self.args

try:
@@ -94,7 +96,9 @@ def __call__(

if context is not None:
context.require(self.args)
inputs = {arg: context[arg] for arg in self.args}
inputs: Dict[str, ContextValue] = {
arg: context[arg] for arg in self.args
}

if not self.enabled or (not train and not self.compute_at_val):
zero = _scalar_from_ctx(0.0, inputs)
@@ -127,15 +131,15 @@ class LossGroup:
items: Sequence[LossItem]

@property
def item_names(self) -> List[str]:
def item_names(self) -> List[Optional[str]]:
return [item.key for item in self.items]

def __call__(
self,
train: bool,
context: Optional[Context] = None,
**inputs: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, float]]:
) -> Tuple[torch.Tensor, Dict[str, Scalar]]:
"""
Compute the total loss and individual loss values.

@@ -153,7 +157,7 @@

for item in self.items:
raw, weighted = item(train, context=context, **inputs)
logs[item.key] = raw.item()
logs[item.key] = raw.item() # type: ignore
total += weighted

return total, logs
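A sketch wiring two weighted criteria into a group. How each LossItem invokes its module sits in a hidden hunk, so the loss modules below accept (preds, targets) positionally and name their parameters after the assumed reserved-key strings as a defensive choice:

import torch

from virtual_stain_flow.engine.context import Context
from virtual_stain_flow.engine.loss_group import LossGroup, LossItem
from virtual_stain_flow.engine.names import PREDS, TARGETS


class MAE(torch.nn.Module):
    def forward(self, preds, targets):
        return (preds - targets).abs().mean()


class MSE(torch.nn.Module):
    def forward(self, preds, targets):
        return ((preds - targets) ** 2).mean()


group = LossGroup(items=[
    LossItem(module=MAE(), key="mae", weight=1.0),
    LossItem(module=MSE(), key="mse", weight=0.1),
])

ctx = Context(**{
    PREDS: torch.rand(2, 1, 8, 8, requires_grad=True),  # grad needed below
    TARGETS: torch.rand(2, 1, 8, 8),
})
total, logs = group(train=True, context=ctx)
total.backward()              # total is the weighted, graph-connected sum
print(logs)                   # per-item raw values, e.g. {"mae": ..., "mse": ...}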