Skip to content

Commit a146e6a

Browse files
committed
OfflineEnsembleDataset: independent slices of indices for each ensemble member
1 parent d8fa571 commit a146e6a

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

bayesflow/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
from .offline_dataset import OfflineDataset
8+
from .offline_ensemble_dataset import OfflineEnsembleDataset
89
from .online_dataset import OnlineDataset
910
from .disk_dataset import DiskDataset
1011

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
3+
from .offline_dataset import OfflineDataset
4+
5+
6+
class OfflineEnsembleDataset(OfflineDataset):
    """
    A dataset that is pre-simulated and stored in memory, extending :py:class:`OfflineDataset`.

    The only difference is that it allows training an :py:class:`ApproximatorEnsemble` in parallel: the index
    array stored on the dataset has shape ``(num_samples, num_ensemble)`` and holds, for every ensemble member,
    an independently shuffled copy of all sample indices, so that batches can contain ``num_ensemble`` different
    random subsets of the available data.

    Parameters
    ----------
    num_ensemble : int
        Number of ensemble members, i.e. the number of independently shuffled index columns. Must be at least 1.
    **kwargs
        Forwarded unchanged to ``OfflineDataset.__init__``.

    Raises
    ------
    ValueError
        If ``num_ensemble`` is smaller than 1.
    """

    def __init__(self, num_ensemble: int, **kwargs):
        # Fail fast: zero members would silently produce an empty (num_samples, 0) index array,
        # and a negative count would only fail later inside ``np.repeat``.
        if num_ensemble < 1:
            raise ValueError(f"num_ensemble must be a positive integer, got {num_ensemble}.")

        super().__init__(**kwargs)
        self.num_ensemble = num_ensemble

        # Create indices with shape (num_samples, num_ensemble); every column starts as 0..num_samples-1.
        # NOTE(review): ``num_samples`` is presumably provided by ``OfflineDataset`` -- confirm.
        _indices = np.arange(self.num_samples, dtype="int64")
        _indices = np.repeat(_indices[:, None], self.num_ensemble, axis=1)

        # Shuffle independently along second axis: each column view is permuted in place,
        # so every ensemble member gets its own ordering of the same indices.
        for i in range(self.num_ensemble):
            np.random.shuffle(_indices[:, i])

        # Overrides whatever index array the parent constructor may have set up.
        self.indices = _indices

        # Shuffle first axis
        # NOTE(review): ``_shuffle`` and ``shuffle()`` come from the parent class; presumably
        # ``shuffle()`` permutes the rows of ``self.indices`` -- confirm.
        if self._shuffle:
            self.shuffle()

0 commit comments

Comments
 (0)