11import argparse
2- from smartredis import Client
32import torch
4- import torch .nn as nn
53import numpy as np
64import io
7- from sklearn .model_selection import train_test_split
85import torch .optim as optim
9- import time
10- from typing import Tuple , Union
11- from matplotlib import pyplot as plt
12-
13- from sklearn .metrics import mean_squared_error
14-
15- class MLP (nn .Module ):
16- def __init__ (self , num_layers , layer_width , input_size , output_size , activation_fn ):
17- super (MLP , self ).__init__ ()
18-
19- layers = []
20- layers .append (nn .Linear (input_size , layer_width ))
21- layers .append (activation_fn )
22-
23- for _ in range (num_layers - 2 ):
24- layers .append (nn .Linear (layer_width , layer_width ))
25- layers .append (activation_fn )
266
27- layers .append (nn .Linear (layer_width , output_size ))
28- self .layers = nn .Sequential (* layers )
29-
30- def forward (self , x ):
31- return self .layers (x )
32-
33- def loss_weighted_center (y_true , y_pred , weights , weights_power ):
34- weights_normed = torch .pow (weights , weights_power )
35- weights_normed = weights_normed / torch .sum (weights_normed )
7+ from matplotlib import pyplot as plt
8+ from smartredis import Client
369
37- return torch . sum ( torch . sum (( y_true - y_pred ) ** 2 , dim = 1 ) * weights_normed )
10+ from MLP import MLP , MLPTrainer
3811
3912
4013def train (args ):
4114 client = Client ()
4215 torch .set_default_dtype (torch .float64 )
4316
4417 # Read the solution direction from a database
45- dimension = int (client .get_tensor ("solution_dim" ))
18+ dimension = int (client .get_tensor ("solution_dim" )[ 0 ] )
4619
4720 print (f"Solution dimension = { dimension } ." )
48-
4921 # Initialize the model
50- model = MLP (
51- num_layers = 3 ,
52- layer_width = 10 ,
53- input_size = dimension ,
54- output_size = dimension ,
55- activation_fn = torch .nn .ELU ()
56- )
57-
58- # Initialize the optimizer
59- learning_rate = 1e-3
60- optimizer = optim .Adam (model .parameters (), lr = learning_rate )
22+ if args .model_name == "mlp" :
23+ model = MLP (
24+ input_size = dimension ,
25+ output_size = dimension ,
26+ num_layers = 3 ,
27+ layer_width = 10 ,
28+ activation_fn = torch .nn .ELU ()
29+ )
30+ trainer = MLPTrainer (model , args .radius_power )
6131
32+ data_ready = client .poll_key ("points" , 1 , 10000 )
33+ points = client .get_tensor ("points" )
34+ interior_points = np .vstack ([client .get_tensor (f"points_MPI_{ i } " for i in range (4 ))])
35+ X = torch .from_numpy (points ).to (torch .float64 )
6236 # Make sure all datasets are avaialble in the smartredis database.
37+
38+ epochs = 5000
6339 iteration = 1
6440 while True :
6541
@@ -69,47 +45,22 @@ def train(args):
6945 if (not data_ready ):
7046 raise RuntimeError ("Data not found in SmartRedis; aborting training." )
7147
72- points = client .get_tensor ("points" )
7348 displacements = client .get_tensor ("displacements" )
74-
49+ interior_points = client . get_tensor
7550 client .delete_tensor ("data_ready" )
7651
77- X = torch .from_numpy (points ).to (torch .float64 )
7852 y = torch .from_numpy (displacements ).to (torch .float64 )
7953
80- # Find the center of the shape as the average of all the points on the inner boundary
81- r = torch .sqrt (torch .sum (X ** 2 , dim = 1 ))
82- inner = r < 5
83- center = torch .mean (X [inner ], dim = 0 )
84-
85- dist = torch .sqrt (torch .sum ((X - center )** 2 , dim = 1 ))
86- wts = dist / torch .sum (dist )
8754
8855 validation_rmse = []
89- model .train ()
90- epochs = 5000
9156 n_epochs = 0
9257
9358 for epoch in range (epochs ):
94- # Zero the gradients
95- optimizer .zero_grad ()
96-
97- # Forward pass on the training data
98- displ_pred = model (X )
99-
100- # Compute loss on the training data
101- loss_train = loss_weighted_center (displ_pred , y , wts , args .radius_power )
102-
103- if (loss_train < 5e-05 ):
59+ loss , model = trainer .training_step (X , y )
60+ if trainer .converged ():
10461 break
10562
106- # Backward pass and optimization
107- loss_train .backward ()
108- optimizer .step ()
109-
110- n_epochs = n_epochs + 1
111-
112- print (f"MSE { loss_train .item ()} , number of epochs { n_epochs } " , flush = True )
63+ print (f"MSE { loss .item ()} , number of epochs { epoch } " , flush = True )
11364 np .savez (
11465 f"data_{ iteration :02d} .npz" ,
11566 points = points ,
@@ -150,6 +101,11 @@ def train(args):
150101 parser = argparse .ArgumentParser (description = "Training script for mesh motion" )
151102 parser .add_argument ("mpi_ranks" , help = "number of mpi ranks" , type = int )
152103 parser .add_argument ("radius_power" , help = "power law to weight losses" , type = float )
104+ parser .add_argument ("model_name" ,
105+ help = "which model to use to calculate interior displacements" ,
106+ choices = ["mlp" ],
107+ type = str
108+ )
153109 args = parser .parse_args ()
154110
155111 train (args )
0 commit comments