Skip to content

Commit 92953a7

Browse files
committed
Continue to modularize trainer
Break the definition and particulars of model training out of the training script into their own files. To include a new model, users should add another file to the networks directory and create two new classes: the model class, which defines the architecture, and the trainer class, which controls how the model is trained.
1 parent a07cea6 commit 92953a7

File tree

5 files changed

+92
-75
lines changed

5 files changed

+92
-75
lines changed

run/meshMotion/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
*.stl
33
*.csv
44
*.pdf
5+
ellipsoid3D_MachineLearningMeshMotion/*
Lines changed: 28 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,41 @@
11
import argparse
2-
from smartredis import Client
32
import torch
4-
import torch.nn as nn
53
import numpy as np
64
import io
7-
from sklearn.model_selection import train_test_split
85
import torch.optim as optim
9-
import time
10-
from typing import Tuple, Union
11-
from matplotlib import pyplot as plt
12-
13-
from sklearn.metrics import mean_squared_error
14-
15-
class MLP(nn.Module):
16-
def __init__(self, num_layers, layer_width, input_size, output_size, activation_fn):
17-
super(MLP, self).__init__()
18-
19-
layers = []
20-
layers.append(nn.Linear(input_size, layer_width))
21-
layers.append(activation_fn)
22-
23-
for _ in range(num_layers - 2):
24-
layers.append(nn.Linear(layer_width, layer_width))
25-
layers.append(activation_fn)
266

27-
layers.append(nn.Linear(layer_width, output_size))
28-
self.layers = nn.Sequential(*layers)
29-
30-
def forward(self, x):
31-
return self.layers(x)
32-
33-
def loss_weighted_center(y_true, y_pred, weights, weights_power):
34-
weights_normed = torch.pow(weights, weights_power)
35-
weights_normed = weights_normed/torch.sum(weights_normed)
7+
from matplotlib import pyplot as plt
8+
from smartredis import Client
369

37-
return torch.sum(torch.sum((y_true-y_pred)**2, dim=1)*weights_normed)
10+
from MLP import MLP, MLPTrainer
3811

3912

4013
def train(args):
4114
client = Client()
4215
torch.set_default_dtype(torch.float64)
4316

4417
# Read the solution direction from a database
45-
dimension = int(client.get_tensor("solution_dim"))
18+
dimension = int(client.get_tensor("solution_dim")[0])
4619

4720
print (f"Solution dimension = {dimension}.")
48-
4921
# Initialize the model
50-
model = MLP(
51-
num_layers=3,
52-
layer_width=10,
53-
input_size=dimension,
54-
output_size=dimension,
55-
activation_fn=torch.nn.ELU()
56-
)
57-
58-
# Initialize the optimizer
59-
learning_rate = 1e-3
60-
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
22+
if args.model_name == "mlp":
23+
model = MLP(
24+
input_size=dimension,
25+
output_size=dimension,
26+
num_layers=3,
27+
layer_width=10,
28+
activation_fn=torch.nn.ELU()
29+
)
30+
trainer = MLPTrainer(model, args.radius_power)
6131

32+
data_ready = client.poll_key("points", 1, 10000)
33+
points = client.get_tensor("points")
34+
interior_points = np.vstack([client.get_tensor(f"points_MPI_{i}" for i in range(4))])
35+
X = torch.from_numpy(points).to(torch.float64)
6236
# Make sure all datasets are available in the smartredis database.
37+
38+
epochs = 5000
6339
iteration = 1
6440
while True:
6541

@@ -69,47 +45,22 @@ def train(args):
6945
if (not data_ready):
7046
raise RuntimeError("Data not found in SmartRedis; aborting training.")
7147

72-
points = client.get_tensor("points")
7348
displacements = client.get_tensor("displacements")
74-
49+
interior_points = client.get_tensor
7550
client.delete_tensor("data_ready")
7651

77-
X = torch.from_numpy(points).to(torch.float64)
7852
y = torch.from_numpy(displacements).to(torch.float64)
7953

80-
# Find the center of the shape as the average of all the points on the inner boundary
81-
r = torch.sqrt(torch.sum(X**2, dim=1))
82-
inner = r < 5
83-
center = torch.mean(X[inner], dim=0)
84-
85-
dist = torch.sqrt(torch.sum((X-center)**2, dim=1))
86-
wts = dist/torch.sum(dist)
8754

8855
validation_rmse = []
89-
model.train()
90-
epochs = 5000
9156
n_epochs = 0
9257

9358
for epoch in range(epochs):
94-
# Zero the gradients
95-
optimizer.zero_grad()
96-
97-
# Forward pass on the training data
98-
displ_pred = model(X)
99-
100-
# Compute loss on the training data
101-
loss_train = loss_weighted_center(displ_pred, y, wts, args.radius_power)
102-
103-
if (loss_train < 5e-05):
59+
loss, model = trainer.training_step(X, y)
60+
if trainer.converged():
10461
break
10562

106-
# Backward pass and optimization
107-
loss_train.backward()
108-
optimizer.step()
109-
110-
n_epochs = n_epochs + 1
111-
112-
print (f"MSE {loss_train.item()}, number of epochs {n_epochs}", flush=True)
63+
print(f"MSE {loss.item()}, number of epochs {epoch}", flush=True)
11364
np.savez(
11465
f"data_{iteration:02d}.npz",
11566
points=points,
@@ -150,6 +101,11 @@ def train(args):
150101
parser = argparse.ArgumentParser(description="Training script for mesh motion")
151102
parser.add_argument("mpi_ranks", help="number of mpi ranks", type=int)
152103
parser.add_argument("radius_power", help="power law to weight losses", type=float)
104+
parser.add_argument("model_name",
105+
help="which model to use to calculate interior displacements",
106+
choices=["mlp"],
107+
type=str
108+
)
153109
args = parser.parse_args()
154110

155111
train(args)

run/meshMotion/networks/MLP.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.optim as optim
4+
5+
class MLP(nn.Module):
    """Fully connected network with an activation after every non-final layer.

    The network consists of ``num_layers`` linear layers: one mapping
    ``input_size`` to ``layer_width``, ``num_layers - 2`` hidden layers of
    width ``layer_width``, and a final layer mapping ``layer_width`` to
    ``output_size``. No activation is applied after the final layer.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
        activation_fn: Activation module placed after every non-final linear
            layer. The same module instance is reused at each position.
        num_layers: Total number of linear layers; must be at least 2.
        layer_width: Number of units in each hidden layer.

    Raises:
        ValueError: If ``num_layers`` is smaller than 2, because the input
            and output layers are always created.
    """

    def __init__(
        self, input_size, output_size, activation_fn, num_layers, layer_width
    ):
        super().__init__()

        # The input and output layers always exist, so fewer than two layers
        # cannot be honoured; fail loudly instead of silently building two.
        if num_layers < 2:
            raise ValueError(
                f"num_layers must be at least 2, got {num_layers}"
            )

        layers = [nn.Linear(input_size, layer_width), activation_fn]
        for _ in range(num_layers - 2):
            layers += [nn.Linear(layer_width, layer_width), activation_fn]
        layers.append(nn.Linear(layer_width, output_size))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        return self.layers(x)
22+
23+
class MLPTrainer:
    """Train a model with Adam on a distance-weighted absolute-error loss.

    Args:
        model: The ``torch.nn.Module`` to optimise.
        radius_power: Exponent applied to each point's distance from the
            center of the moving points when building the loss weights.
        lr: Learning rate passed to the Adam optimizer.
        loss_stop: Loss threshold below which ``converged`` reports True.
    """

    def __init__(self, model, radius_power, lr=1e-3, loss_stop=5e-5):
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.loss_stop = loss_stop
        # Loss of the most recent training step; None until one has run.
        self.loss_value = None
        self.radius_power = radius_power

    def loss(self, X, y_true):
        """Return the weighted absolute error of the model prediction on ``X``.

        Each point (row of ``X``) is weighted by its distance from the center
        of the moving points raised to ``radius_power``, normalised so the
        weights sum to one.

        Args:
            X: 2-D tensor with one point per row.
            y_true: 2-D tensor of target displacements, one row per point.

        Returns:
            A scalar tensor holding the loss.
        """
        # A point counts as moving when any displacement component is
        # non-zero. The mask has to be per row: an element-wise mask would
        # pick individual coordinates and collapse the center to a scalar.
        moving = torch.any(y_true != 0, dim=1)
        if bool(moving.any()):
            center = torch.mean(X[moving], dim=0)
        else:
            # Nothing moved: averaging an empty selection would give NaN,
            # so fall back to the centroid of all points.
            center = torch.mean(X, dim=0)

        distance = torch.sqrt(torch.sum((X - center) ** 2, dim=1))
        scaled_dist = distance ** self.radius_power
        wts = scaled_dist / torch.sum(scaled_dist)

        y_pred = self.model(X)
        # abs(r) equals sqrt(r**2) in value but keeps a finite gradient when
        # a residual is exactly zero, where the sqrt form yields NaN.
        per_point_error = torch.sum(torch.abs(y_true - y_pred), dim=1)
        return torch.sum(wts * per_point_error)

    def training_step(self, X, y_true):
        """Run one optimisation step.

        Returns:
            A tuple ``(loss_value, model)``; the loss is the value computed
            before the parameter update.
        """
        self.optimizer.zero_grad()
        loss_value = self.loss(X, y_true)
        self.loss_value = loss_value
        loss_value.backward()
        self.optimizer.step()

        return loss_value, self.model

    def converged(self):
        """Return True once the most recent loss is below ``loss_stop``.

        Returns False if no training step has been run yet.
        """
        if self.loss_value is None:
            return False
        return self.loss_value.item() < self.loss_stop
53+
54+
55+
56+

run/meshMotion/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
gmsh
2+
PyFoam

run/meshMotion/smartsim_driver.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ def main(args):
5757
openfoam_rs = exp.create_run_settings(
5858
exe="moveDynamicMesh",
5959
exe_args="-parallel",
60+
run_command="mpirun"
6061
)
6162
openfoam_rs.set_tasks(num_mpi_ranks)
6263
openfoam_rs.set_nodes(1)
63-
openfoam_rs.set("exclusive")
6464

6565
# Create the model from the OpenFOAM case argument
6666
openfoam_model = exp.create_model(
@@ -75,7 +75,7 @@ def main(args):
7575

7676
training_rs = exp.create_run_settings(
7777
exe="python",
78-
exe_args=f"ml_model_training.py {num_mpi_ranks} {args.radius_power}"
78+
exe_args=f"ml_model_training.py {num_mpi_ranks} {args.radius_power} mlp"
7979
)
8080
training_rs.set_tasks(1)
8181
training_rs.set_nodes(1)
@@ -84,7 +84,9 @@ def main(args):
8484
name="ml_model_training",
8585
run_settings=training_rs
8686
)
87-
ml_model_training.attach_generator_files(to_copy="ml_model_training.py")
87+
ml_model_training.attach_generator_files(
88+
to_copy=["ml_model_training.py", "networks/MLP.py"]
89+
)
8890

8991
exp.generate(ml_model_training, overwrite=True)
9092

0 commit comments

Comments
 (0)