@@ -17,17 +17,33 @@ class Database():
1717 None meaning no scaling.
1818 :param array_like space: the input spatial data
1919 """
20- def __init__ (self , parameters = None , snapshots = None ):
20+ def __init__ (self , parameters = None , snapshots = None , space = None ):
2121 self ._pairs = []
2222
2323 if parameters is None and snapshots is None :
2424 return
2525
26+ if parameters is None :
27+ parameters = [None ] * len (snapshots )
28+ elif snapshots is None :
29+ snapshots = [None ] * len (parameters )
30+
2631 if len (parameters ) != len (snapshots ):
27- raise ValueError
32+ raise ValueError ( 'parameters and snapshots must have the same length' )
2833
2934 for param , snap in zip (parameters , snapshots ):
30- self .add (Parameter (param ), Snapshot (snap ))
35+ param = Parameter (param )
36+ if isinstance (space , dict ):
37+ snap_space = space .get (tuple (param .values ), None )
38+ # print('snap_space', snap_space)
39+ else :
40+ snap_space = space
41+ snap = Snapshot (snap , space = snap_space )
42+
43+ self .add (param , snap )
44+
45+ # TODO: eventually improve the `space` assignment in the snapshots,
46+ # snapshots can have different space coordinates
3147
3248 @property
3349 def parameters_matrix (self ):
@@ -74,7 +90,9 @@ def __len__(self):
7490
7591 def __str__ (self ):
7692 """ Print minimal info about the Database """
77- return str (self .parameters_matrix )
93+ s = 'Database with {} snapshots and {} parameters' .format (
94+ self .snapshots_matrix .shape [1 ], self .parameters_matrix .shape [1 ])
95+ return s
7896
7997 def add (self , parameter , snapshot ):
8098 """
@@ -103,6 +121,10 @@ def split(self, chunks, seed=None):
103121 >>> train, test = db.split([80, 20]) # n snapshots
104122
105123 """
124+
125+ if seed is not None :
126+ np .random .seed (seed )
127+
106128 if all (isinstance (n , int ) for n in chunks ):
107129 if sum (chunks ) != len (self ):
108130 raise ValueError ('chunk elements are inconsistent' )
@@ -118,6 +140,7 @@ def split(self, chunks, seed=None):
118140 if not np .isclose (sum (chunks ), 1. ):
119141 raise ValueError ('chunk elements are inconsistent' )
120142
143+
121144 cum_chunks = np .cumsum (chunks )
122145 cum_chunks = np .insert (cum_chunks , 0 , 0.0 )
123146 ids = np .ones (len (self )) * - 1.
@@ -137,3 +160,15 @@ def split(self, chunks, seed=None):
137160 new_database [i ].add (p , s )
138161
139162 return new_database
163+
164+ def get_snapshot_space (self , index ):
165+ """
166+ Get the space coordinates of a snapshot by its index.
167+
168+ :param int index: The index of the snapshot.
169+ :return: The space coordinates of the snapshot.
170+ :rtype: numpy.ndarray
171+ """
172+ if index < 0 or index >= len (self ._pairs ):
173+ raise IndexError ("Snapshot index out of range." )
174+ return self ._pairs [index ][1 ].space
0 commit comments