diff --git a/bayesflow/datasets/offline_dataset.py b/bayesflow/datasets/offline_dataset.py index cbb1b2972..94fdffebc 100644 --- a/bayesflow/datasets/offline_dataset.py +++ b/bayesflow/datasets/offline_dataset.py @@ -2,6 +2,7 @@ import numpy as np from bayesflow.adapters import Adapter +from bayesflow.utils import logging class OfflineDataset(keras.utils.PyDataset): @@ -11,12 +12,20 @@ class OfflineDataset(keras.utils.PyDataset): See the `DiskDataset` class for handling large datasets that are split into multiple smaller files. """ - def __init__(self, data: dict[str, np.ndarray], batch_size: int, adapter: Adapter | None, **kwargs): + def __init__( + self, data: dict[str, np.ndarray], batch_size: int, adapter: Adapter | None, num_samples: int = None, **kwargs + ): super().__init__(**kwargs) self.batch_size = batch_size self.data = data self.adapter = adapter - self.num_samples = next(iter(data.values())).shape[0] + + if num_samples is None: + self.num_samples = self._get_num_samples_from_data(data) + logging.debug(f"Automatically determined {self.num_samples} samples in data.") + else: + self.num_samples = num_samples + self.indices = np.arange(self.num_samples, dtype="int64") self.shuffle() @@ -29,7 +38,10 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]: item = slice(item * self.batch_size, (item + 1) * self.batch_size) item = self.indices[item] - batch = {key: np.take(value, item, axis=0) for key, value in self.data.items()} + batch = { + key: np.take(value, item, axis=0) if isinstance(value, np.ndarray) else value + for key, value in self.data.items() + } if self.adapter is not None: batch = self.adapter(batch) @@ -46,3 +58,13 @@ def on_epoch_end(self) -> None: def shuffle(self) -> None: """Shuffle the dataset in-place.""" np.random.shuffle(self.indices) + + @staticmethod + def _get_num_samples_from_data(data: dict) -> int: + for key, value in data.items(): + if hasattr(value, "shape"): + ndim = len(value.shape) + if ndim > 1: + return value.shape[0] + + raise ValueError("Could not determine number of samples from data. Please pass it manually.")