Skip to content

Commit 97c381d

Browse files
committed
Merge branch 'main' into dev
2 parents d0f3d22 + 283c5a0 commit 97c381d

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

bayesflow/datasets/offline_dataset.py

Lines changed: 25 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33

44
from bayesflow.adapters import Adapter
5+
from bayesflow.utils import logging
56

67

78
class OfflineDataset(keras.utils.PyDataset):
@@ -11,12 +12,20 @@ class OfflineDataset(keras.utils.PyDataset):
1112
See the `DiskDataset` class for handling large datasets that are split into multiple smaller files.
1213
"""
1314

14-
def __init__(self, data: dict[str, np.ndarray], batch_size: int, adapter: Adapter | None, **kwargs):
15+
def __init__(
16+
self, data: dict[str, np.ndarray], batch_size: int, adapter: Adapter | None, num_samples: int = None, **kwargs
17+
):
1518
super().__init__(**kwargs)
1619
self.batch_size = batch_size
1720
self.data = data
1821
self.adapter = adapter
19-
self.num_samples = next(iter(data.values())).shape[0]
22+
23+
if num_samples is None:
24+
self.num_samples = self._get_num_samples_from_data(data)
25+
logging.debug(f"Automatically determined {self.num_samples} samples in data.")
26+
else:
27+
self.num_samples = num_samples
28+
2029
self.indices = np.arange(self.num_samples, dtype="int64")
2130

2231
self.shuffle()
@@ -29,7 +38,10 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]:
2938
item = slice(item * self.batch_size, (item + 1) * self.batch_size)
3039
item = self.indices[item]
3140

32-
batch = {key: np.take(value, item, axis=0) for key, value in self.data.items()}
41+
batch = {
42+
key: np.take(value, item, axis=0) if isinstance(value, np.ndarray) else value
43+
for key, value in self.data.items()
44+
}
3345

3446
if self.adapter is not None:
3547
batch = self.adapter(batch)
@@ -46,3 +58,13 @@ def on_epoch_end(self) -> None:
4658
def shuffle(self) -> None:
    """Randomly reorder the sample indices in-place.

    Only ``self.indices`` is permuted; the array object itself is kept and
    nothing is returned. Randomness comes from NumPy's global random state
    via ``np.random.shuffle``.
    """
    index_order = self.indices
    np.random.shuffle(index_order)
61+
62+
@staticmethod
63+
def _get_num_samples_from_data(data: dict) -> int:
64+
for key, value in data.items():
65+
if hasattr(value, "shape"):
66+
ndim = len(value.shape)
67+
if ndim > 1:
68+
return value.shape[0]
69+
70+
raise ValueError("Could not determine number of samples from data. Please pass it manually.")

0 commit comments

Comments
 (0)