Merged
28 changes: 25 additions & 3 deletions bayesflow/datasets/offline_dataset.py
@@ -2,6 +2,7 @@
import numpy as np

from bayesflow.adapters import Adapter
from bayesflow.utils import logging


class OfflineDataset(keras.utils.PyDataset):
@@ -11,12 +12,20 @@
See the `DiskDataset` class for handling large datasets that are split into multiple smaller files.
"""

def __init__(self, data: dict[str, np.ndarray], batch_size: int, adapter: Adapter | None, **kwargs):
def __init__(
self, data: dict[str, np.ndarray], batch_size: int, adapter: Adapter | None, num_samples: int = None, **kwargs
Contributor comment: We should add docs about the arguments of the DataSet classes. But I guess this can be done independently of this bug fix.
):
super().__init__(**kwargs)
self.batch_size = batch_size
self.data = data
self.adapter = adapter
self.num_samples = next(iter(data.values())).shape[0]

if num_samples is None:
Contributor comment: I do like that the user has both options: (a) automatically compute num_samples (the default), or (b) specify it manually if, for some reason, it cannot be reliably inferred from the data (see the usage sketch at the end of this page).
self.num_samples = self._get_num_samples_from_data(data)
logging.debug(f"Automatically determined {self.num_samples} samples in data.")
else:
self.num_samples = num_samples

Codecov / codecov/patch warning (bayesflow/datasets/offline_dataset.py#L27): added line 27 was not covered by tests.
self.indices = np.arange(self.num_samples, dtype="int64")

self.shuffle()
@@ -29,7 +38,10 @@
item = slice(item * self.batch_size, (item + 1) * self.batch_size)
item = self.indices[item]

batch = {key: np.take(value, item, axis=0) for key, value in self.data.items()}
batch = {
key: np.take(value, item, axis=0) if isinstance(value, np.ndarray) else value
for key, value in self.data.items()
}

if self.adapter is not None:
batch = self.adapter(batch)
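
Illustration (not part of the diff): with the new dict comprehension, entries that are not NumPy arrays, such as static metadata, are passed through to every batch unchanged instead of being indexed with np.take. The data dict below is hypothetical; this is a minimal sketch of the same comprehension outside the class.

import numpy as np

# Hypothetical data dict: one per-sample array plus one non-array entry.
data = {
    "theta": np.random.rand(8, 2),      # indexed along axis 0
    "note": "simulated with prior v1",  # passed through as-is
}

item = np.arange(4)  # indices of the current batch
batch = {
    key: np.take(value, item, axis=0) if isinstance(value, np.ndarray) else value
    for key, value in data.items()
}

assert batch["theta"].shape == (4, 2)
assert batch["note"] == "simulated with prior v1"
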
@@ -46,3 +58,13 @@
def shuffle(self) -> None:
"""Shuffle the dataset in-place."""
np.random.shuffle(self.indices)

@staticmethod
def _get_num_samples_from_data(data: dict) -> int:
for key, value in data.items():
if hasattr(value, "shape"):
ndim = len(value.shape)
if ndim > 1:
return value.shape[0]

raise ValueError("Could not determine number of samples from data. Please pass it manually.")

Codecov / codecov/patch warning (bayesflow/datasets/offline_dataset.py#L70): added line 70 was not covered by tests.
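
Usage sketch (illustrative, not from the PR): the two ways of determining num_samples discussed above. It assumes OfflineDataset is importable from bayesflow.datasets and uses a hypothetical data dict.

import numpy as np
from bayesflow.datasets import OfflineDataset

# Hypothetical offline data: 128 simulated draws with 2 parameters and 10 observables.
data = {"theta": np.random.rand(128, 2), "x": np.random.rand(128, 10)}

# (a) default: num_samples is inferred from the first entry with more than one dimension
dataset = OfflineDataset(data=data, batch_size=32, adapter=None)

# (b) explicit: pass num_samples when it cannot be inferred reliably from the data
dataset = OfflineDataset(data=data, batch_size=32, adapter=None, num_samples=128)
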