-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_utils.py
More file actions
103 lines (83 loc) · 2.9 KB
/
data_utils.py
File metadata and controls
103 lines (83 loc) · 2.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import numpy as np
import keras
import tools
import data_manager_cmes
import data_manager_swot
from importlib import reload
reload(data_manager_cmes)
reload(data_manager_swot)
from data_manager_cmes import DataManagerCMEMS
from data_manager_swot import DataManagerSWOT
class DataFactory():
    """Factory that builds the data manager for a given case study.

    Instantiating this class returns a ``DataManagerCMEMS`` or a
    ``DataManagerSWOT`` object (built in ``__new__``) rather than a
    ``DataFactory`` instance.
    """

    def __new__(
        cls,
        testing_mode=False,
        case_study='cmems',
    ):
        """Return the data manager matching ``case_study``.

        Args:
            testing_mode: Forwarded positionally to the selected data
                manager's constructor.
            case_study: Name of the case study, ``'cmems'`` or ``'swot'``.

        Returns:
            ``DataManagerCMEMS(testing_mode)`` for ``'cmems'`` or
            ``DataManagerSWOT(testing_mode)`` for ``'swot'``.

        Raises:
            ValueError: If ``case_study`` is not one of the supported names.
        """
        supported = ('cmems', 'swot')
        if case_study not in supported:
            raise ValueError("unknown case study")
        # Each manager class is only looked up on the branch that builds it.
        if case_study == 'swot':
            return DataManagerSWOT(testing_mode)
        return DataManagerCMEMS(testing_mode)
class DataGenerator(keras.utils.PyDataset):
    """Keras ``PyDataset`` serving batches of state / feedthrough data.

    ``x`` must hold exactly two arrays, ``[state, feedthrough]``, and ``y``
    exactly one, ``[target]``. Sample anchor indices run along the first
    axis (``shape[0]``) of the state array; each batch is assembled by
    ``tools.create_lookback`` from a slice of those indices.

    Feedthrough modes (``ft_type``):
        'hybrid'   -- keep both inputs. When ``unroll_dim > 0`` the
                      feedthrough and the target are each replaced by
                      ``unroll_dim + 1`` copies, copy ``i`` starting at
                      sample ``i``.
        'only'     -- keep both inputs, no unrolling.
        'disabled' -- keep only the state input.
    """

    def __init__(self, x, y,
                 ft_type='hybrid',
                 batch_size=4,
                 shuffle=False,
                 lookback=0,
                 encoder=None,
                 unroll_dim=1,
                 **kwargs):
        """Store the configuration and prepare indices and arrays.

        Args:
            x: Sequence of exactly two arrays ``[state, feedthrough]``.
            y: Sequence of exactly one target array.
            ft_type: ``'hybrid'``, ``'only'`` or ``'disabled'`` (see the
                class docstring).
            batch_size: Maximum number of samples per batch.
            shuffle: If true, shuffle the sample order at construction and
                after every epoch, using the global NumPy RNG.
            lookback: The first ``lookback`` samples are never used as
                batch anchors; also passed to ``tools.create_lookback``.
            encoder: Stored as ``self.encoder``; not used by this class.
            unroll_dim: Number of extra shifted copies in ``'hybrid'`` mode.
                The last ``unroll_dim`` samples are never used as anchors.
            **kwargs: Forwarded to the ``keras.utils.PyDataset`` constructor.

        Raises:
            ValueError: If ``x`` or ``y`` has the wrong number of arrays,
                or ``ft_type`` is not recognized.
        """
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.ft_type = ft_type
        self.shuffle = shuffle
        self.lookback = lookback
        self.encoder = encoder
        self.unroll_dim = unroll_dim
        # Unrolling applies only in 'hybrid' mode, and to inputs and
        # targets together.
        self.unroll_y = (self.unroll_dim > 0 and
                         self.ft_type == 'hybrid')
        self.unroll_x = self.unroll_y
        self.__setup_data(x, y)

    def __setup_data(self, x, y):
        """Validate inputs, then build ``indices``, ``n``, ``x`` and ``y``."""
        # Explicit checks rather than ``assert``: asserts are removed under
        # ``python -O`` and would let malformed input through silently.
        if len(x) != 2:
            raise ValueError(
                f'x must contain exactly 2 arrays, got {len(x)}')
        if len(y) != 1:
            raise ValueError(
                f'y must contain exactly 1 array, got {len(y)}')
        # Validate before touching the data or the global RNG.
        if self.ft_type not in ('hybrid', 'only', 'disabled'):
            raise ValueError('invalid feedthrough type')

        # Anchor indices skip the first ``lookback`` samples and stop
        # ``unroll_dim`` samples early so that shifted copies stay in range.
        # NOTE(review): the trailing trim also applies in 'only'/'disabled'
        # modes, where nothing is unrolled -- confirm that is intended.
        self.indices = np.arange(self.lookback,
                                 x[0].shape[0] - self.unroll_dim)
        self.n = len(self.indices)
        self.__do_shuffle()

        if self.ft_type == 'disabled':
            self.x = [x[0]]
        else:  # 'hybrid' or 'only': keep state and feedthrough inputs
            self.x = x
        self.y = y

        if self.unroll_x:
            # Copy ``i`` of the feedthrough starts at sample ``i``; since
            # anchors end ``unroll_dim`` samples early, every copy can be
            # indexed at any anchor (assuming create_lookback only reads at
            # or before the anchor -- TODO confirm).
            ft_set = [x[1][i:] for i in range(self.unroll_dim + 1)]
            # State input first, followed by the shifted feedthroughs.
            self.x = [x[0]] + ft_set
        if self.unroll_y:
            self.y = [y[0][i:] for i in range(self.unroll_dim + 1)]

    def __len__(self):
        """Return the number of batches per epoch (the last may be partial)."""
        return int(np.ceil(self.n / self.batch_size))

    def __getitem__(self, index):
        """Return the ``(batch_x, batch_y)`` pair for batch number ``index``."""
        low = index * self.batch_size
        # Slicing clamps at the end of ``self.indices`` (length ``self.n``),
        # so the final batch is simply shorter.
        inds = self.indices[low:low + self.batch_size]
        batch_x = tools.create_lookback(inds, self.x, self.lookback)
        batch_y = tools.create_lookback(inds, self.y, self.lookback)
        return (batch_x, batch_y)

    def __do_shuffle(self):
        """Shuffle ``self.indices`` in place when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)

    def on_epoch_end(self):
        """Keras hook run after every epoch: reshuffle the sample order."""
        self.__do_shuffle()