Skip to content

Commit 69e2387

Browse files
committed
Update datasets for usage with NumPy arrays.
Since we no longer use Keras tensors, we can remove the multiprocessing start-method workaround (`mp.set_start_method("spawn", force=True)`).
1 parent 5334284 commit 69e2387

File tree

3 files changed

Lines changed: 5 additions & 17 deletions

3 files changed

Lines changed: 5 additions & 17 deletions

bayesflow/datasets/disk_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def __init__(
4242

4343
self.shuffle()
4444

45-
def __getitem__(self, item):
45+
def __getitem__(self, item) -> dict[str, np.ndarray]:
4646
if not 0 <= item < self.num_batches:
4747
raise IndexError(f"Index {item} is out of bounds for dataset with {self.num_batches} batches.")
4848

bayesflow/datasets/online_dataset.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
11
import keras
2+
import numpy as np
23

34
from bayesflow.adapters import Adapter
45
from bayesflow.simulators.simulator import Simulator
5-
from bayesflow.types import Tensor
66

77

88
class OnlineDataset(keras.utils.PyDataset):
@@ -20,18 +20,12 @@ def __init__(
2020
):
2121
super().__init__(**kwargs)
2222

23-
if keras.backend.backend() == "torch" and kwargs.get("use_multiprocessing"):
24-
# keras workaround: https://github.com/keras-team/keras/issues/19346
25-
import multiprocessing as mp
26-
27-
mp.set_start_method("spawn", force=True)
28-
2923
self.batch_size = batch_size
3024
self._num_batches = num_batches
3125
self.adapter = adapter
3226
self.simulator = simulator
3327

34-
def __getitem__(self, item: int) -> dict[str, Tensor]:
28+
def __getitem__(self, item: int) -> dict[str, np.ndarray]:
3529
batch = self.simulator.sample((self.batch_size,))
3630

3731
if self.adapter is not None:

bayesflow/datasets/rounds_dataset.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
11
import keras
2+
import numpy as np
23

34
from bayesflow.adapters import Adapter
45
from bayesflow.simulators.simulator import Simulator
5-
from bayesflow.types import Tensor
66
from bayesflow.utils import logging
77

88

@@ -22,12 +22,6 @@ def __init__(
2222
):
2323
super().__init__(**kwargs)
2424

25-
if keras.backend.backend() == "torch" and kwargs.get("use_multiprocessing"):
26-
# keras workaround: https://github.com/keras-team/keras/issues/19346
27-
import multiprocessing as mp
28-
29-
mp.set_start_method("spawn", force=True)
30-
3125
self.batches = None
3226
self._num_batches = num_batches
3327
self.batch_size = batch_size
@@ -46,7 +40,7 @@ def __init__(
4640

4741
self.regenerate()
4842

49-
def __getitem__(self, item: int) -> dict[str, Tensor]:
43+
def __getitem__(self, item: int) -> dict[str, np.ndarray]:
5044
"""Get a batch of pre-simulated data"""
5145
batch = self.batches[item]
5246

0 commit comments

Comments (0)