-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase_model.py
More file actions
163 lines (131 loc) · 4.46 KB
/
base_model.py
File metadata and controls
163 lines (131 loc) · 4.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import keras
from keras import ops
from keras import layers
from keras import backend as K
import numpy as np
# NOTE(review): clearing the Keras session at import time is a module-level
# side effect — any models or layers created before this module is imported
# are invalidated. Consider moving this into an explicit setup call.
K.clear_session()
class BaseModel(keras.Model):
    """Common scaffolding for super-resolution style models.

    On construction, one random batch is drawn from ``data_gen`` to infer
    the LR/HR input shapes, the number of variables (channels), the
    LR->HR coarsening factor and the land/sea mask. Subclasses provide
    the network via ``builder`` and the optimization via ``train_step``.

    Parameters
    ----------
    data_gen : keras.utils.Sequence-like
        Batch generator; items are ``(x, y)`` where ``x`` contains
        ``'LR_data'``, ``'HR_data'`` and a ``'meta'`` dict with ``'mask'``,
        ``'grid_LR'`` and ``'grid_HR'`` entries.
    """

    def __init__(
            self,
            data_gen,
            **kwargs,
            ):
        super().__init__(**kwargs)
        self.data_gen = data_gen
        # use random batch as test input to infer shapes/metadata
        test_x, test_y = self.get_random_item()
        self.coarsening_factor = self.get_coarsening_factor(test_x)
        self.input_name_LR = 'LR_data'
        self.input_name_HR = 'HR_data'
        # per-sample shapes: drop the leading batch axis
        self.input_shape_LR = \
            test_x[self.input_name_LR].shape[1:]
        self.input_shape_HR = \
            test_x[self.input_name_HR].shape[1:]
        self.num_vars = self.input_shape_LR[-1]
        # BUG FIX: the original asserted num_vars == input_shape_LR[-1],
        # comparing the value against itself (always true). The intent —
        # per the message — is to check LR vs HR channel counts.
        assert self.num_vars == self.input_shape_HR[-1], \
            "unequal variables in LR and HR data"
        # create mask layer from the first sample's mask
        self.masking = \
            Masking(test_x['meta']['mask'][0,],
                    self.num_vars,
                    name=self.name + '_masking')
        # setup loss and loss tracker
        self.loss_MSE = keras.losses.MeanSquaredError()
        # BUG FIX: loss_MSLE was instantiated as plain MeanSquaredError,
        # silently making MSLE-based training identical to MSE.
        self.loss_MSLE = keras.losses.MeanSquaredLogarithmicError()
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        # only the tracked mean loss; lets Keras reset it between epochs
        return [
            self.loss_tracker,
        ]

    def get_random_item(self):
        """Return one randomly chosen batch ``(x, y)`` from the generator."""
        # idiom: len()/indexing instead of __len__()/__getitem__()
        idx = np.random.randint(len(self.data_gen))
        return self.data_gen[idx]

    def build_model(self, name):
        """Assemble the inner ``keras.Model`` from ``builder`` and build it."""
        inputs, outputs = self.builder()
        self.model = keras.Model(
            inputs=inputs,
            outputs=outputs,
            name=name)
        # use random batch as test input to trigger shape building
        test_x, test_y = self.get_random_item()
        input_x = self.create_input(test_x)
        self.build(input_x)
        return self.model

    def build(self, *args, **kwargs):
        # delegate building to the wrapped functional model
        self.model.build(*args, **kwargs)
        self.built = True

    def create_input(self, inputs):
        """Hook for subclasses to adapt a raw batch into model inputs."""
        return inputs

    def get_coarsening_factor(self, test_x):
        """Infer the HR/LR grid ratio (per axis) from the batch metadata.

        The number of necessary upsampling blocks is inferred from the
        LR and HR grids; lat and lon must be coarsened equally.
        """
        grid_HR_shape = test_x['meta']['grid_HR']['lat'].shape
        grid_LR_shape = test_x['meta']['grid_LR']['lat'].shape
        coarsening = \
            np.asarray(grid_HR_shape) / np.asarray(grid_LR_shape)
        assert coarsening[0] == coarsening[1], "unequal lat/lon coarsening"
        coarsening_factor = coarsening[0]
        return coarsening_factor

    def summary(self, **kwargs):
        return self.model.summary(**kwargs)

    def call(self, inputs, training=True):
        inputs = self.create_input(inputs)
        return self.model(inputs, training=training)

    def builder(self):
        """Subclass hook: return ``(inputs, outputs)`` for the inner model."""
        pass

    def train_step(self, data, training=True):
        """Subclass hook: one optimization step."""
        pass

    def test_step(self, data):
        # evaluation reuses the training logic with gradients disabled
        return self.train_step(data, training=False)
class Masking(layers.Layer):
    """Elementwise mask layer: zeroes out invalid grid cells.

    The 2-D ``mask`` (1 = valid, 0 = masked) is broadcast over the
    trailing variable axis so it can be multiplied with
    ``(..., lat, lon, num_vars)`` tensors.
    """

    def __init__(
            self,
            mask,
            num_vars,
            **kwargs,
            ):
        super().__init__(**kwargs)
        mask_tensor = ops.convert_to_tensor(mask)
        # indices of valid (mask == 1) cells, kept for later lookups
        self.rows, self.cols = ops.where(mask_tensor == 1)
        self.num_vars = num_vars
        # replicate the mask along a new channel axis for multiplication
        channel_mask = ops.expand_dims(mask_tensor, -1)
        self.mask = ops.tile(channel_mask, self.num_vars)

    def call(self, inputs):
        return ops.multiply(inputs, self.mask)
class Activation(layers.Layer):
    """Configurable activation wrapper.

    Supported ``activation`` strings:
      * ``'prelu'``       — trainable PReLU
      * ``'tanh_scaled'`` — tanh followed by a rescale from [-1, 1] to [0, 1]
      * ``'relu1'``       — ReLU clipped at 1 (output in [0, 1])
      * anything else     — passed straight to ``keras.layers.Activation``
    """

    def __init__(
            self,
            activation,
            **kwargs,
            ):
        super().__init__(**kwargs)
        self.activation = activation

    def build(self, input_shape):
        if self.activation == 'prelu':
            self.actv_lr = layers.PReLU()
        elif self.activation == 'tanh_scaled':
            self.actv_lr = layers.Activation('tanh')
            # FIX: create the rescale sublayer once here instead of
            # instantiating a fresh ScaleAndShift() on every call()
            # (sublayers built inside call are a Keras anti-pattern).
            self.scale_shift = ScaleAndShift()
        elif self.activation == 'relu1':
            self.actv_lr = layers.Identity()
        else:
            self.actv_lr = layers.Activation(self.activation)

    def call(self, inputs):
        # FIX: 'relu1' previously ran inputs through the Identity layer
        # and then discarded that result; short-circuit instead.
        if self.activation == 'relu1':
            return keras.activations.relu(inputs, max_value=1)
        activated = self.actv_lr(inputs)
        if self.activation == 'tanh_scaled':
            activated = self.scale_shift(activated)
        return activated
class ScaleAndShift(layers.Layer):
    """Stateless layer mapping values from [-1, 1] onto [0, 1]."""

    def __init__(
            self,
            **kwargs,
            ):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # no weights to create — purely arithmetic
        pass

    def call(self, inputs):
        shifted = inputs + 1
        return shifted / 2