File tree Expand file tree Collapse file tree 3 files changed +5
-17
lines changed Expand file tree Collapse file tree 3 files changed +5
-17
lines changed Original file line number Diff line number Diff line change @@ -42,7 +42,7 @@ def __init__(
4242
4343 self .shuffle ()
4444
45- def __getitem__ (self , item ):
45+ def __getitem__ (self , item ) -> dict [ str , np . ndarray ] :
4646 if not 0 <= item < self .num_batches :
4747 raise IndexError (f"Index { item } is out of bounds for dataset with { self .num_batches } batches." )
4848
Original file line number Diff line number Diff line change 11import keras
2+ import numpy as np
23
34from bayesflow .adapters import Adapter
45from bayesflow .simulators .simulator import Simulator
5- from bayesflow .types import Tensor
66
77
88class OnlineDataset (keras .utils .PyDataset ):
@@ -20,18 +20,12 @@ def __init__(
2020 ):
2121 super ().__init__ (** kwargs )
2222
23- if keras .backend .backend () == "torch" and kwargs .get ("use_multiprocessing" ):
24- # keras workaround: https://github.com/keras-team/keras/issues/19346
25- import multiprocessing as mp
26-
27- mp .set_start_method ("spawn" , force = True )
28-
2923 self .batch_size = batch_size
3024 self ._num_batches = num_batches
3125 self .adapter = adapter
3226 self .simulator = simulator
3327
34- def __getitem__ (self , item : int ) -> dict [str , Tensor ]:
28+ def __getitem__ (self , item : int ) -> dict [str , np . ndarray ]:
3529 batch = self .simulator .sample ((self .batch_size ,))
3630
3731 if self .adapter is not None :
Original file line number Diff line number Diff line change 11import keras
2+ import numpy as np
23
34from bayesflow .adapters import Adapter
45from bayesflow .simulators .simulator import Simulator
5- from bayesflow .types import Tensor
66from bayesflow .utils import logging
77
88
@@ -22,12 +22,6 @@ def __init__(
2222 ):
2323 super ().__init__ (** kwargs )
2424
25- if keras .backend .backend () == "torch" and kwargs .get ("use_multiprocessing" ):
26- # keras workaround: https://github.com/keras-team/keras/issues/19346
27- import multiprocessing as mp
28-
29- mp .set_start_method ("spawn" , force = True )
30-
3125 self .batches = None
3226 self ._num_batches = num_batches
3327 self .batch_size = batch_size
@@ -46,7 +40,7 @@ def __init__(
4640
4741 self .regenerate ()
4842
49- def __getitem__ (self , item : int ) -> dict [str , Tensor ]:
43+ def __getitem__ (self , item : int ) -> dict [str , np . ndarray ]:
5044 """Get a batch of pre-simulated data"""
5145 batch = self .batches [item ]
5246
You can’t perform that action at this time.
0 commit comments