22import numpy as np
33
44from bayesflow .adapters import Adapter
5+ from bayesflow .utils import logging
56
67
78class OfflineDataset (keras .utils .PyDataset ):
@@ -11,12 +12,20 @@ class OfflineDataset(keras.utils.PyDataset):
1112 See the `DiskDataset` class for handling large datasets that are split into multiple smaller files.
1213 """
1314
14- def __init__ (self , data : dict [str , np .ndarray ], batch_size : int , adapter : Adapter | None , ** kwargs ):
15+ def __init__ (
16+ self , data : dict [str , np .ndarray ], batch_size : int , adapter : Adapter | None , num_samples : int = None , ** kwargs
17+ ):
1518 super ().__init__ (** kwargs )
1619 self .batch_size = batch_size
1720 self .data = data
1821 self .adapter = adapter
19- self .num_samples = next (iter (data .values ())).shape [0 ]
22+
23+ if num_samples is None :
24+ self .num_samples = self ._get_num_samples_from_data (data )
25+ logging .debug (f"Automatically determined { self .num_samples } samples in data." )
26+ else :
27+ self .num_samples = num_samples
28+
2029 self .indices = np .arange (self .num_samples , dtype = "int64" )
2130
2231 self .shuffle ()
@@ -29,7 +38,10 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]:
2938 item = slice (item * self .batch_size , (item + 1 ) * self .batch_size )
3039 item = self .indices [item ]
3140
32- batch = {key : np .take (value , item , axis = 0 ) for key , value in self .data .items ()}
41+ batch = {
42+ key : np .take (value , item , axis = 0 ) if isinstance (value , np .ndarray ) else value
43+ for key , value in self .data .items ()
44+ }
3345
3446 if self .adapter is not None :
3547 batch = self .adapter (batch )
@@ -46,3 +58,13 @@ def on_epoch_end(self) -> None:
4658 def shuffle (self ) -> None :
4759 """Shuffle the dataset in-place."""
4860 np .random .shuffle (self .indices )
61+
62+ @staticmethod
63+ def _get_num_samples_from_data (data : dict ) -> int :
64+ for key , value in data .items ():
65+ if hasattr (value , "shape" ):
66+ ndim = len (value .shape )
67+ if ndim > 1 :
68+ return value .shape [0 ]
69+
70+ raise ValueError ("Could not determine number of samples from data. Please pass it manually." )
0 commit comments