From fbe8ecbe13ad2505beec763fbd89a38888b04f5a Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Wed, 16 Apr 2025 15:32:49 +0000 Subject: [PATCH] Set default stage in adapter to "inference" - In addition, make the stage used in the datasets configurable with a parameter. Defaults to "training" for now. --- bayesflow/adapters/adapter.py | 22 +++++++++++++------- bayesflow/adapters/transforms/standardize.py | 2 +- bayesflow/datasets/disk_dataset.py | 4 +++- bayesflow/datasets/offline_dataset.py | 5 ++++- bayesflow/datasets/online_dataset.py | 5 ++++- bayesflow/datasets/rounds_dataset.py | 5 ++++- 6 files changed, 31 insertions(+), 12 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 5e3b8aaef..f16bbe4bd 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -79,13 +79,15 @@ def from_config(cls, config: dict, custom_objects=None) -> "Adapter": def get_config(self) -> dict: return {"transforms": serialize(self.transforms)} - def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: + def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -> dict[str, np.ndarray]: """Apply the transforms in the forward direction. Parameters ---------- data : dict The data to be transformed. + stage : str, one of ["training", "validation", "inference"] + The stage the function is called in. **kwargs : dict Additional keyword arguments passed to each transform. @@ -97,17 +99,19 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]: data = data.copy() for transform in self.transforms: - data = transform(data, **kwargs) + data = transform(data, stage=stage, **kwargs) return data - def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, any]: + def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kwargs) -> dict[str, any]: """Apply the transforms in the inverse direction. Parameters ---------- data : dict The data to be transformed. + stage : str, one of ["training", "validation", "inference"] + The stage the function is called in. **kwargs : dict Additional keyword arguments passed to each transform. @@ -119,11 +123,13 @@ def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, any]: data = data.copy() for transform in reversed(self.transforms): - data = transform(data, inverse=True, **kwargs) + data = transform(data, stage=stage, inverse=True, **kwargs) return data - def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs) -> dict[str, np.ndarray]: + def __call__( + self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs + ) -> dict[str, np.ndarray]: """Apply the transforms in the given direction. Parameters @@ -132,6 +138,8 @@ def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs) The data to be transformed. inverse : bool, optional If False, apply the forward transform, else apply the inverse transform (default False). + stage : str, one of ["training", "validation", "inference"] + The stage the function is called in. **kwargs Additional keyword arguments passed to each transform. @@ -141,9 +149,9 @@ def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs) The transformed data. """ if inverse: - return self.inverse(data, **kwargs) + return self.inverse(data, stage=stage, **kwargs) - return self.forward(data, **kwargs) + return self.forward(data, stage=stage, **kwargs) def __repr__(self): result = "" diff --git a/bayesflow/adapters/transforms/standardize.py b/bayesflow/adapters/transforms/standardize.py index 740f21ee0..2a058f14a 100644 --- a/bayesflow/adapters/transforms/standardize.py +++ b/bayesflow/adapters/transforms/standardize.py @@ -96,7 +96,7 @@ def get_config(self) -> dict: "momentum": serialize(self.momentum), } - def forward(self, data: np.ndarray, stage: str = "training", **kwargs) -> np.ndarray: + def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray: if self.axis is None: self.axis = tuple(range(data.ndim - 1)) diff --git a/bayesflow/datasets/disk_dataset.py b/bayesflow/datasets/disk_dataset.py index d2f2d18bd..8753e3480 100644 --- a/bayesflow/datasets/disk_dataset.py +++ b/bayesflow/datasets/disk_dataset.py @@ -31,6 +31,7 @@ def __init__( batch_size: int, load_fn: callable = None, adapter: Adapter | None, + stage: str = "training", **kwargs, ): super().__init__(**kwargs) @@ -39,6 +40,7 @@ def __init__( self.load_fn = load_fn or pickle_load self.adapter = adapter self.files = list(map(str, self.root.glob(pattern))) + self.stage = stage self.shuffle() @@ -55,7 +57,7 @@ def __getitem__(self, item) -> dict[str, np.ndarray]: batch = tree_stack(batch) if self.adapter is not None: - batch = self.adapter(batch) + batch = self.adapter(batch, stage=self.stage) return batch diff --git a/bayesflow/datasets/offline_dataset.py b/bayesflow/datasets/offline_dataset.py index 380ab5d60..51f2b51f7 100644 --- a/bayesflow/datasets/offline_dataset.py +++ b/bayesflow/datasets/offline_dataset.py @@ -21,12 +21,15 @@ def __init__( batch_size: int, adapter: Adapter | None, num_samples: int = None, + *, + stage: str = "training", **kwargs, ): super().__init__(**kwargs) self.batch_size = batch_size self.data = data self.adapter = adapter + self.stage = stage if num_samples is None: self.num_samples = self._get_num_samples_from_data(data) @@ -52,7 +55,7 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]: } if self.adapter is not None: - batch = self.adapter(batch) + batch = self.adapter(batch, stage=self.stage) return batch diff --git a/bayesflow/datasets/online_dataset.py b/bayesflow/datasets/online_dataset.py index 33199a418..18701f70e 100644 --- a/bayesflow/datasets/online_dataset.py +++ b/bayesflow/datasets/online_dataset.py @@ -16,6 +16,8 @@ def __init__( batch_size: int, num_batches: int, adapter: Adapter | None, + *, + stage: str = "training", **kwargs, ): super().__init__(**kwargs) @@ -24,12 +26,13 @@ def __init__( self._num_batches = num_batches self.adapter = adapter self.simulator = simulator + self.stage = stage def __getitem__(self, item: int) -> dict[str, np.ndarray]: batch = self.simulator.sample((self.batch_size,)) if self.adapter is not None: - batch = self.adapter(batch) + batch = self.adapter(batch, stage=self.stage) return batch diff --git a/bayesflow/datasets/rounds_dataset.py b/bayesflow/datasets/rounds_dataset.py index bb2d96d97..b6c59336c 100644 --- a/bayesflow/datasets/rounds_dataset.py +++ b/bayesflow/datasets/rounds_dataset.py @@ -18,6 +18,8 @@ def __init__( num_batches: int, epochs_per_round: int, adapter: Adapter | None, + *, + stage: str = "training", **kwargs, ): super().__init__(**kwargs) @@ -27,6 +29,7 @@ def __init__( self.batch_size = batch_size self.adapter = adapter self.epoch = 0 + self.stage = stage if epochs_per_round == 1: logging.warning( @@ -45,7 +48,7 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]: batch = self.batches[item] if self.adapter is not None: - batch = self.adapter(batch) + batch = self.adapter(batch, stage=self.stage) return batch