Skip to content

Commit 03a23f1

Browse files
authored
Merge pull request #249 from mathLab/multirom
Multirom
2 parents 3121318 + 5e54f60 commit 03a23f1

21 files changed

+2094
-86
lines changed

ezyrb/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
__all__ = [
44
'Database', 'Snapshot', 'Reduction', 'POD', 'Approximation', 'RBF', 'Linear', 'GPR',
55
'ANN', 'KNeighborsRegressor', 'RadiusNeighborsRegressor', 'AE',
6-
'ReducedOrderModel', 'PODAE', 'RegularGrid'
6+
'ReducedOrderModel', 'PODAE', 'RegularGrid',
7+
'MultiReducedOrderModel'
78
]
89

910
from .database import Database
1011
from .snapshot import Snapshot
1112
from .parameter import Parameter
12-
from .reducedordermodel import ReducedOrderModel
13+
from .reducedordermodel import ReducedOrderModel, MultiReducedOrderModel
1314
from .reduction import *
1415
from .approximation import *
1516
from .regular_grid import RegularGrid

ezyrb/approximation/ann.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,10 @@ def _build_model(self, points, values):
135135
layers.insert(0, points.shape[1])
136136
layers.append(values.shape[1])
137137

138-
self.model = self._list_to_sequential(layers, self.function)
138+
if self.model is None:
139+
self.model = self._list_to_sequential(layers, self.function)
140+
else:
141+
self.model = self.model
139142

140143
def fit(self, points, values):
141144
"""

ezyrb/database.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,33 @@ class Database():
1717
None meaning no scaling.
1818
:param array_like space: the input spatial data
1919
"""
20-
def __init__(self, parameters=None, snapshots=None):
20+
def __init__(self, parameters=None, snapshots=None, space=None):
2121
self._pairs = []
2222

2323
if parameters is None and snapshots is None:
2424
return
2525

26+
if parameters is None:
27+
parameters = [None] * len(snapshots)
28+
elif snapshots is None:
29+
snapshots = [None] * len(parameters)
30+
2631
if len(parameters) != len(snapshots):
27-
raise ValueError
32+
raise ValueError('parameters and snapshots must have the same length')
2833

2934
for param, snap in zip(parameters, snapshots):
30-
self.add(Parameter(param), Snapshot(snap))
35+
param = Parameter(param)
36+
if isinstance(space, dict):
37+
snap_space = space.get(tuple(param.values), None)
38+
# print('snap_space', snap_space)
39+
else:
40+
snap_space = space
41+
snap = Snapshot(snap, space=snap_space)
42+
43+
self.add(param, snap)
44+
45+
# TODO: eventually improve the `space` assignment in the snapshots,
46+
# snapshots can have different space coordinates
3147

3248
@property
3349
def parameters_matrix(self):
@@ -74,7 +90,9 @@ def __len__(self):
7490

7591
def __str__(self):
7692
""" Print minimal info about the Database """
77-
return str(self.parameters_matrix)
93+
s = 'Database with {} snapshots and {} parameters'.format(
94+
self.snapshots_matrix.shape[1], self.parameters_matrix.shape[1])
95+
return s
7896

7997
def add(self, parameter, snapshot):
8098
"""
@@ -103,6 +121,10 @@ def split(self, chunks, seed=None):
103121
>>> train, test = db.split([80, 20]) # n snapshots
104122
105123
"""
124+
125+
if seed is not None:
126+
np.random.seed(seed)
127+
106128
if all(isinstance(n, int) for n in chunks):
107129
if sum(chunks) != len(self):
108130
raise ValueError('chunk elements are inconsistent')
@@ -118,6 +140,7 @@ def split(self, chunks, seed=None):
118140
if not np.isclose(sum(chunks), 1.):
119141
raise ValueError('chunk elements are inconsistent')
120142

143+
121144
cum_chunks = np.cumsum(chunks)
122145
cum_chunks = np.insert(cum_chunks, 0, 0.0)
123146
ids = np.ones(len(self)) * -1.
@@ -137,3 +160,15 @@ def split(self, chunks, seed=None):
137160
new_database[i].add(p, s)
138161

139162
return new_database
163+
164+
def get_snapshot_space(self, index):
165+
"""
166+
Get the space coordinates of a snapshot by its index.
167+
168+
:param int index: The index of the snapshot.
169+
:return: The space coordinates of the snapshot.
170+
:rtype: numpy.ndarray
171+
"""
172+
if index < 0 or index >= len(self._pairs):
173+
raise IndexError("Snapshot index out of range.")
174+
return self._pairs[index][1].space

ezyrb/parameter.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
class Parameter:
55

66
def __init__(self, values):
7-
self.values = values
7+
if isinstance(values, Parameter):
8+
self.values = values.values
9+
else:
10+
self.values = values
811

912
@property
1013
def values(self):
@@ -15,4 +18,5 @@ def values(self):
1518
def values(self, new_values):
1619
if np.asarray(new_values).ndim != 1:
1720
raise ValueError('only 1D array are usable as parameter.')
18-
self._values = new_values
21+
22+
self._values = np.asarray(new_values)

ezyrb/plugin/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,15 @@
55
'DatabaseScaler',
66
'ShiftSnapshots',
77
'AutomaticShiftSnapshots',
8+
'Aggregation',
9+
'DatabaseSplitter',
10+
'DatabaseDictionarySplitter'
811
]
912

1013
from .scaler import DatabaseScaler
1114
from .plugin import Plugin
1215
from .shift import ShiftSnapshots
1316
from .automatic_shift import AutomaticShiftSnapshots
17+
from .aggregation import Aggregation
18+
from .database_splitter import DatabaseSplitter
19+
from .database_splitter import DatabaseDictionarySplitter

0 commit comments

Comments
 (0)