-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmesh_model.py
More file actions
executable file
·386 lines (311 loc) · 16.9 KB
/
mesh_model.py
File metadata and controls
executable file
·386 lines (311 loc) · 16.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
import torch
#import torch_geometric
#import torch
import torch_scatter
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils
from torch import Tensor
from typing import Union, Tuple, Optional
from torch.nn import Parameter, Linear, Sequential, LayerNorm, ReLU
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn import GCNConv
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax, degree
from torch_geometric_temporal.nn.recurrent import GConvLSTM
import enum
import stats
@enum.unique
class NodeType(enum.IntEnum):
    """Integer codes used in the one-hot node-type slice of the node features.

    The numbering is kept consistent with the codes of the original
    MeshGraphNets study:
    https://github.com/deepmind/deepmind-research/tree/master/meshgraphnets

    ``SIZE`` is not a node type: it is the number of type codes, used as the
    width of the one-hot slice when building the loss masks.
    """

    NORMAL = 0    # type code 0
    WELL = 1      # type code 1
    FAULT = 2     # type code 2
    BOUNDARY = 3  # type code 3
    SIZE = 4      # count of the codes above (one-hot width)
# GCN-based model.
# Modified from https://github.com/locuslab/cfd-gcn/blob/master/models.py
class MeshGCN(nn.Module):
    """Plain stack of ``GCNConv`` layers with a ReLU between consecutive layers.

    Args:
        in_channels: number of input node features, NOT counting the extra
            SDF channel; the first layer is built for ``in_channels + 1``.
        hidden_channels: width of every hidden layer.
        out_channels: number of output features per node.
        num_layers: number of ``GCNConv`` layers (ReLU after all but the last).
        improved, cached, bias: forwarded unchanged to every ``GCNConv``.
        fine_marker_dict: unused; kept for interface compatibility.

    ``self.sdf`` starts as ``None``. When it is set to a per-node tensor
    (presumably shape ``(num_nodes, 1)`` given the single reserved channel —
    confirm with the caller), ``forward`` appends it to ``data.x``. When it is
    left as ``None``, ``data.x`` itself must already have ``in_channels + 1``
    columns.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=6, improved=False,
                 cached=False, bias=True, fine_marker_dict=None):
        super().__init__()
        self.sdf = None
        in_channels += 1  # account for sdf
        channels = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = nn.ModuleList(
            GCNConv(channels[i], channels[i + 1], improved=improved,
                    cached=cached, bias=bias)
            for i in range(num_layers)
        )

    def forward(self, data):
        """Apply every GCN layer to ``data.x`` over ``data.edge_index``.

        Returns the output of the last layer (no activation applied).
        """
        x = data.x
        edge_index = data.edge_index
        if self.sdf is not None:
            # Fill the extra input channel reserved in __init__; previously the
            # attribute was never read, so the reserved channel had no source.
            x = torch.cat([x, self.sdf], dim=1)
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
        return self.convs[-1](x, edge_index)
# Original MeshGraphNet model.
class MeshGraphNet(torch.nn.Module):
    """MeshGraphNet model, built upon DeepMind's 2021 MeshGraphNets paper.

    The model consists of three parts:

    1. Encoder: MLPs mapping raw node features (and, unless ``node_based``,
       raw edge features) into latent embeddings of size ``hidden_dim``.
    2. Processor: ``args.num_layers`` stacked ``ProcessorLayer`` layers. In
       edge-based mode each layer updates the edge embeddings first, then the
       node embeddings.
    3. Decoder: an MLP mapping node embeddings to ``output_dim`` values
       (nodes only).

    Args:
        input_dim_node: node feature size (dynamic variables + one-hot node
            type; node positions are expected to be encoded in edge attributes).
        input_dim_edge: edge feature size; unused when ``args.node_based``.
        hidden_dim: latent embedding size (128 in DeepMind's paper).
        output_dim: number of predicted values per node.
        args: configuration object; reads ``device``, ``well_weight``,
            ``data_type``, ``num_layers``, ``node_type_index`` and
            ``node_based``.
        emb: unused; kept for interface compatibility.
    """

    def __init__(self, input_dim_node, input_dim_edge,
                 hidden_dim, output_dim, args,
                 emb=False):
        super(MeshGraphNet, self).__init__()
        self.device = args.device
        self.well_weight = args.well_weight
        self.data_type = args.data_type
        self.num_layers = args.num_layers
        self.node_type_index = args.node_type_index
        self.node_based = args.node_based
        # NOTE: child registration order (node_encoder, [edge_encoder],
        # processor, decoder) is relied upon by code that slices
        # ``children()``, so it must not change.
        # Encoder: convert raw inputs into latent embeddings.
        self.node_encoder = Sequential(Linear(input_dim_node, hidden_dim),
                                       ReLU(),
                                       Linear(hidden_dim, hidden_dim),
                                       LayerNorm(hidden_dim))
        if not self.node_based:
            self.edge_encoder = Sequential(Linear(input_dim_edge, hidden_dim),
                                           ReLU(),
                                           Linear(hidden_dim, hidden_dim),
                                           LayerNorm(hidden_dim))
        self.processor = nn.ModuleList()
        assert (self.num_layers >= 1), 'Number of message passing layers is not >=1'
        processor_layer = self.build_processor_model()
        for _ in range(self.num_layers):
            self.processor.append(processor_layer(hidden_dim, hidden_dim, node_based=self.node_based))
        # Decoder: only for node embeddings.
        self.decoder = Sequential(Linear(hidden_dim, hidden_dim),
                                  ReLU(),
                                  Linear(hidden_dim, output_dim))

    def build_processor_model(self):
        """Return the message-passing layer class used by the processor."""
        return ProcessorLayer

    def forward(self, data, mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge):
        """Normalize, encode, run message passing and decode node outputs.

        Args:
            data: graph with ``x``, ``edge_index`` and ``edge_attr``.
            mean_vec_x, std_vec_x: node-feature statistics for ``stats.normalize``.
            mean_vec_edge, std_vec_edge: edge-feature statistics for ``stats.normalize``.

        Returns:
            Decoded per-node outputs of width ``output_dim``.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = stats.normalize(x, mean_vec_x, std_vec_x)
        edge_attr = stats.normalize(edge_attr, mean_vec_edge, std_vec_edge)
        # Step 1: encode node/edge features into latent embeddings.
        x = self.node_encoder(x)
        if not self.node_based:
            # edge_encoder only exists in edge-based mode; calling it
            # unconditionally raised AttributeError for node-based models.
            edge_attr = self.edge_encoder(edge_attr)
        # Step 2: message passing over the latent embeddings.
        for i in range(self.num_layers):
            x, edge_attr = self.processor[i](x, edge_index, edge_attr)
        # Step 3: decode latent node embeddings into the target quantities.
        return self.decoder(x)

    def _loss_masks(self, x):
        """Return ``(well_mask, normal_mask)`` boolean masks over the rows of ``x``.

        HEXA: node type is the argmax of ``x[:, 1:]``; code 0 -> well group,
        code 1 -> normal group. PEBI: node type is the argmax of the
        ``NodeType.SIZE``-wide slice starting at ``self.node_type_index``;
        WELL/FAULT -> well group, NORMAL/BOUNDARY -> normal group.

        Raises:
            ValueError: if ``self.data_type`` is neither HEXA nor PEBI.
        """
        data_type = self.data_type.upper()
        if data_type == 'HEXA':
            node_type = torch.argmax(x[:, 1:], dim=1)
            return node_type == 0, node_type == 1
        if data_type == 'PEBI':
            start = self.node_type_index
            node_type = torch.argmax(x[:, start:start + int(NodeType.SIZE)], dim=1)
            well = (node_type == int(NodeType.WELL)) | (node_type == int(NodeType.FAULT))
            normal = (node_type == int(NodeType.NORMAL)) | (node_type == int(NodeType.BOUNDARY))
            return well, normal
        raise ValueError(
            f"Unsupported data_type {self.data_type!r}; expected 'HEXA' or 'PEBI'")

    @staticmethod
    def _masked_rmse(error, mask):
        """Square root of the mean of ``error[mask]``; zero if ``mask`` selects nothing.

        ``torch.mean`` of an empty selection is NaN, which would turn the whole
        loss (and its gradients) into NaN for graphs lacking one node group.
        """
        if not bool(mask.any()):
            return error.new_zeros(())
        return torch.sqrt(torch.mean(error[mask]))

    def loss(self, pred, inputs, mean_vec_y, std_vec_y, num):
        """Weighted RMSE between ``pred`` and the normalized label column ``num``.

        Args:
            pred: predictions in normalized space (presumably ``(num_nodes, 1)``).
            inputs: graph with node features ``x`` and labels ``y``.
            mean_vec_y, std_vec_y: label statistics; entry ``num`` is used.
            num: index of the label column being predicted.

        Returns:
            RMSE over the normal group plus ``well_weight`` times RMSE over the
            well group; a group with no nodes contributes zero.

        Raises:
            ValueError: if ``self.data_type`` is neither HEXA nor PEBI.
        """
        well_loss_mask, normal_loss_mask = self._loss_masks(inputs.x)
        # Normalize labels with dataset statistics.
        labels = stats.normalize(inputs.y[:, num], mean_vec_y[num], std_vec_y[num]).unsqueeze(-1)
        # Per-node sum of squared errors.
        error = torch.sum((labels - pred) ** 2, dim=1)
        return (self._masked_rmse(error, normal_loss_mask)
                + self.well_weight * self._masked_rmse(error, well_loss_mask))
# Recurrent MGN model.
class TransferTempoMGN(torch.nn.Module):
    """Recurrent MeshGraphNet built on top of an existing MeshGraphNet.

    The encoder(s) and processor layers of ``mgn_model`` are reused as a
    feature extractor; the resulting node embeddings go through a
    ``GConvLSTM`` cell, and the new hidden state is decoded per node.

    Args:
        mgn_model: a MeshGraphNet; every child module except its last one
            (its decoder) becomes ``self.feature_extractor``.
        hidden_dim: size of the node embeddings and of the LSTM state.
        output_dim: output size of the decoder created when
            ``args.pre_trained`` is false.
        args: configuration object; reads ``device``, ``well_weight``,
            ``data_type``, ``num_layers``, ``node_type_index``,
            ``need_edge_weight``, ``node_based``, ``pre_trained`` and
            ``lstm_filter_size``.
        emb: unused; kept for interface compatibility.

    When ``args.pre_trained`` is true, ``mgn_model``'s own decoder is reused and
    every parameter of ``mgn_model`` is frozen (``requires_grad = False``);
    otherwise a fresh decoder is created and the feature extractor stays
    trainable.
    """

    def __init__(self, mgn_model, hidden_dim, output_dim, args, emb=False):
        super(TransferTempoMGN, self).__init__()
        self.device = args.device
        self.well_weight = args.well_weight
        self.data_type = args.data_type
        self.num_layers = args.num_layers
        self.node_type_index = args.node_type_index
        self.need_edge_weight = args.need_edge_weight
        self.node_based = args.node_based
        # Assuming mgn_model is a MeshGraphNet, its children are, in order:
        # node_encoder, [edge_encoder (edge-based only)], processor, decoder.
        # Dropping the last child keeps everything up to the processor.
        # This is a ModuleList (no forward of its own); forward indexes into it.
        self.feature_extractor = nn.ModuleList(mgn_model.children())[:-1]
        if args.pre_trained:
            self.decoder = nn.ModuleList(mgn_model.children())[-1]
            for param in mgn_model.parameters():
                param.requires_grad = False
        else:
            # Fresh decoder to fine-tune (the pre-trained one could be reused too).
            self.decoder = Sequential(Linear(hidden_dim, hidden_dim),
                                      ReLU(),
                                      Linear(hidden_dim, output_dim))
        # Graph convolutional LSTM stacked after the last processor layer.
        self.lstm_filter_size = args.lstm_filter_size
        self.recurrent_model = GConvLSTM(hidden_dim, hidden_dim, self.lstm_filter_size)

    def build_processor_model(self):
        """Return the message-passing layer class (kept for parity with MeshGraphNet)."""
        return ProcessorLayer

    def forward(self, data, mean_vec_x, std_vec_x, mean_vec_edge, std_vec_edge, h_0, c_0):
        """Encode, message-pass, run one GConvLSTM step and decode.

        Args:
            data: graph with ``x``, ``edge_index`` and ``edge_attr``.
            mean_vec_x, std_vec_x: node-feature statistics for ``stats.normalize``.
            mean_vec_edge, std_vec_edge: edge-feature statistics for ``stats.normalize``.
            h_0: hidden state from the previous timestep.
            c_0: cell state from the previous timestep.

        Returns:
            ``(decoded_output, h_new, c_new)``.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = stats.normalize(x, mean_vec_x, std_vec_x)
        edge_attr = stats.normalize(edge_attr, mean_vec_edge, std_vec_edge)
        # Step 1: encode node/edge features into latent embeddings.
        x = self.feature_extractor[0](x)
        if not self.node_based:
            edge_attr = self.feature_extractor[1](edge_attr)
        # Step 2: message passing. The processor sits at index 1 when there is
        # no edge encoder (node-based) and at index 2 otherwise.
        for i in range(self.num_layers):
            if not self.node_based:
                x, edge_attr = self.feature_extractor[2][i](x, edge_index, edge_attr)
            else:
                # Node-based layers return None for edges; keep edge_attr as-is.
                x, _ = self.feature_extractor[1][i](x, edge_index, edge_attr)
        # Step 3: feed the propagated node embeddings into the GConvLSTM.
        if self.need_edge_weight:
            # NOTE(review): edge_attr is passed through as-is. In edge-based
            # mode it has been encoded to hidden_dim columns, while a per-edge
            # weight is usually a 1-D vector; confirm the intended shape.
            edge_weight = edge_attr
        else:
            # One unit weight per edge, created directly on x's device/dtype.
            edge_weight = torch.ones(edge_index.size(1), dtype=x.dtype, device=x.device)
        h_new, c_new = self.recurrent_model(x, edge_index, edge_weight, h_0, c_0)
        # Step 4: decode; step 5: also return hidden and cell state.
        return self.decoder(h_new), h_new, c_new

    def _loss_masks(self, x):
        """Return ``(well_mask, normal_mask)`` boolean masks over the rows of ``x``.

        HEXA: node type is the argmax of ``x[:, 1:]``; code 0 -> well group,
        code 1 -> normal group. PEBI: node type is the argmax of the
        ``NodeType.SIZE``-wide slice starting at ``self.node_type_index``;
        WELL/FAULT -> well group, NORMAL/BOUNDARY -> normal group.

        Raises:
            ValueError: if ``self.data_type`` is neither HEXA nor PEBI.
        """
        data_type = self.data_type.upper()
        if data_type == 'HEXA':
            node_type = torch.argmax(x[:, 1:], dim=1)
            return node_type == 0, node_type == 1
        if data_type == 'PEBI':
            start = self.node_type_index
            node_type = torch.argmax(x[:, start:start + int(NodeType.SIZE)], dim=1)
            well = (node_type == int(NodeType.WELL)) | (node_type == int(NodeType.FAULT))
            normal = (node_type == int(NodeType.NORMAL)) | (node_type == int(NodeType.BOUNDARY))
            return well, normal
        raise ValueError(
            f"Unsupported data_type {self.data_type!r}; expected 'HEXA' or 'PEBI'")

    @staticmethod
    def _masked_rmse(error, mask):
        """Square root of the mean of ``error[mask]``; zero if ``mask`` selects nothing.

        ``torch.mean`` of an empty selection is NaN, which would turn the whole
        loss (and its gradients) into NaN for graphs lacking one node group.
        """
        if not bool(mask.any()):
            return error.new_zeros(())
        return torch.sqrt(torch.mean(error[mask]))

    def loss(self, pred, inputs, mean_vec_y, std_vec_y, num):
        """Weighted RMSE between ``pred`` and the normalized label column ``num``.

        Args:
            pred: predictions in normalized space (presumably ``(num_nodes, 1)``).
            inputs: graph with node features ``x`` and labels ``y``.
            mean_vec_y, std_vec_y: label statistics; entry ``num`` is used.
            num: index of the label column being predicted.

        Returns:
            RMSE over the normal group plus ``well_weight`` times RMSE over the
            well group; a group with no nodes contributes zero.

        Raises:
            ValueError: if ``self.data_type`` is neither HEXA nor PEBI.
        """
        well_loss_mask, normal_loss_mask = self._loss_masks(inputs.x)
        # Normalize labels with dataset statistics.
        labels = stats.normalize(inputs.y[:, num], mean_vec_y[num], std_vec_y[num]).unsqueeze(-1)
        # Per-node sum of squared errors.
        error = torch.sum((labels - pred) ** 2, dim=1)
        return (self._masked_rmse(error, normal_loss_mask)
                + self.well_weight * self._masked_rmse(error, well_loss_mask))
# ProcessorLayer inherits from the PyG MessagePassing base class and handles the
# processor (GNN) part of the architecture: edge message passing, aggregation and
# updating, via an edge MLP and a node MLP.
class ProcessorLayer(MessagePassing):
    """One MeshGraphNet message-passing step.

    Edge-based mode (``node_based=False``):
      * each edge embedding is updated by ``edge_mlp`` applied to
        ``[x_i, x_j, edge_attr]`` plus a residual connection;
      * the updated edge embeddings are aggregated onto nodes with
        ``agg_method`` (via PyG's ``index``, i.e. the receiving nodes under the
        default ``source_to_target`` flow);
      * each node embedding is updated by ``node_mlp`` applied to
        ``[x, aggregated]`` plus a residual connection.

    Node-based mode (``node_based=True``): there is no edge MLP; the raw
    neighbour embeddings ``x_j`` are aggregated and the same node update is
    applied. Updated edges are returned as ``None``.

    Args:
        in_channels: size of the incoming node (and edge) embeddings.
        out_channels: size of the produced embeddings. The residual additions
            require ``out_channels == in_channels``.
        node_based: select node-based mode (see above).
        agg_method: ``reduce`` argument passed to ``torch_scatter.scatter``.
        **kwargs: forwarded to ``MessagePassing``.
    """

    def __init__(self, in_channels, out_channels,
                 node_based=False, agg_method='sum', **kwargs):
        super(ProcessorLayer, self).__init__(**kwargs)
        self.node_based = node_based
        self.agg_method = agg_method
        if not node_based:
            # Edge input: both endpoint embeddings plus the edge embedding,
            # i.e. three embeddings of size in_channels.
            self.edge_mlp = Sequential(Linear(3 * in_channels, out_channels),
                                       ReLU(),
                                       Linear(out_channels, out_channels),
                                       LayerNorm(out_channels))
        # Node input: the node's own embedding plus its aggregated messages.
        self.node_mlp = Sequential(Linear(2 * in_channels, out_channels),
                                   ReLU(),
                                   Linear(out_channels, out_channels),
                                   LayerNorm(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the Linear layers of the stacked MLPs."""
        if hasattr(self, 'edge_mlp'):
            self.edge_mlp[0].reset_parameters()
            self.edge_mlp[2].reset_parameters()
        self.node_mlp[0].reset_parameters()
        self.node_mlp[2].reset_parameters()

    def forward(self, x, edge_index, edge_attr, size=None):
        """Run message passing, then update node embeddings residually.

        Args:
            x: node embeddings, ``[num_nodes, in_channels]``.
            edge_index: ``[2, num_edges]`` connectivity.
            edge_attr: edge embeddings, ``[num_edges, in_channels]`` (ignored
                by ``message`` in node-based mode).
            size: optional size hint forwarded to ``propagate``.

        Returns:
            ``(updated_nodes, updated_edges)``; ``updated_edges`` is ``None`` in
            node-based mode.
        """
        if self.node_based:
            out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
            updated_edges = None
        else:
            out, updated_edges = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
        # Node MLP on [x, aggregated messages], added back to x (residual connection).
        updated_nodes = x + self.node_mlp(torch.cat([x, out], dim=1))
        return updated_nodes, updated_edges

    def message(self, x_i, x_j, edge_attr):
        """Compute one message per edge.

        Following PyG's convention, ``x_i`` holds the embedding of the node
        that receives the message (the target under the default flow) and
        ``x_j`` that of the neighbour sending it, both ``[num_edges, in_channels]``.

        Edge-based mode returns the residually updated edge embedding
        ``edge_mlp([x_i, x_j, edge_attr]) + edge_attr``; node-based mode returns
        the raw neighbour embeddings ``x_j`` unchanged.
        """
        if not self.node_based:
            updated_edges = torch.cat([x_i, x_j, edge_attr], dim=1)  # [E, 3 * in_channels]
            return self.edge_mlp(updated_edges) + edge_attr
        return x_j

    def aggregate(self, inputs, index, dim_size=None):
        """Reduce per-edge messages onto nodes with ``self.agg_method``.

        ``dim_size`` (supplied by PyG) fixes the number of output rows to the
        number of nodes, so nodes with no incoming edges still get a (zero)
        row; without it trailing message-less nodes would be dropped and the
        concatenation with ``x`` in ``forward`` would fail.

        Returns ``(aggregated, inputs)`` in edge-based mode, so the updated
        edge embeddings reach ``forward``; only ``aggregated`` otherwise.
        """
        # Axis of `inputs` indexed by edge, i.e. the one reduced onto nodes.
        node_dim = 0
        out = torch_scatter.scatter(inputs, index, dim=node_dim,
                                    dim_size=dim_size, reduce=self.agg_method)
        if self.node_based:
            return out
        return out, inputs