Skip to content

Commit 5523498

Browse files
committed
[mesh motion] RBF convergence tests added
1 parent b1ee5f5 commit 5523498

File tree

6 files changed

+50
-23
lines changed

6 files changed

+50
-23
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

src/displacementSmartSimMotionSolver/pytorchApproximationModels/test_rbf_network_stream_function.py

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import torch
44
import torch.nn as nn
55
import torch.optim as optim
6+
import csv
7+
import os
68

79
from rbf_network import WendlandLinearNetwork
810

@@ -21,7 +23,7 @@ def compute_velocity(x, y, psi_values):
2123
v = -dx # v = -∂ψ/∂x
2224
return u, v
2325

24-
def visualize_psi(x, y, psi_values, title="Stream Function", centers=None):
26+
def visualize_psi(x, y, psi_values, centers, title="Stream Function"):
2527
plt.figure(figsize=(6, 6))
2628
plt.contourf(x, y, psi_values, levels=20, cmap='viridis')
2729
plt.colorbar(label='ψ')
@@ -30,24 +32,23 @@ def visualize_psi(x, y, psi_values, title="Stream Function", centers=None):
3032
plt.ylabel('y')
3133
plt.grid()
3234

33-
# Plot centers if provided
34-
if centers is not None:
35-
centers_np = centers.numpy() # Convert from torch tensor to numpy
36-
plt.scatter(centers_np[:, 0], centers_np[:, 1], color='white', marker='x', s=100, linewidths=2, label='Centers')
37-
plt.legend()
35+
# Plot centers
36+
centers_np = centers.numpy() # Convert from torch tensor to numpy
37+
plt.scatter(centers_np[:, 0], centers_np[:, 1], color='white', marker='x', s=100, linewidths=2, label='Centers')
38+
plt.legend()
3839

3940
fig_name = title.replace(" ", "-")
40-
plt.savefig(f"{fig_name}.png", dpi=200)
41+
plt.savefig(f"{fig_name}-num_centers-{len(centers)}.png", dpi=200)
4142

42-
def visualize_velocity_field(x, y, u, v, title="Velocity Field"):
43+
def visualize_velocity_field(x, y, u, v, num_centers, title="Velocity Field"):
4344
plt.figure(figsize=(6, 6))
4445
plt.quiver(x, y, u, v)
4546
plt.title(title)
4647
plt.xlabel('x')
4748
plt.ylabel('y')
4849
plt.grid()
4950
fig_name = title.replace(" ", "-")
50-
plt.savefig(f"{fig_name}.png", dpi=200)
51+
plt.savefig(f"{fig_name}-num_centers{num_centers}.png", dpi=200)
5152

5253
def generate_centers(num_centers):
5354
"""
@@ -59,9 +60,8 @@ def generate_centers(num_centers):
5960
centers = np.vstack([X.ravel(), Y.ravel()]).T
6061
return torch.tensor(centers, dtype=torch.float32)
6162

62-
def main():
63+
def main(num_points):
6364
# Generate training data
64-
num_points = 10
6565
x = np.linspace(0, 1, num_points)
6666
y = np.linspace(0, 1, num_points)
6767
X, Y = np.meshgrid(x, y)
@@ -76,7 +76,7 @@ def main():
7676
# centers = generate_centers(32).clone().detach()
7777
centers = x_train
7878
print(centers.shape)
79-
r_max = 0.1
79+
r_max = 3.0 / num_points
8080
smoothness = 4 # C^4 smoothness
8181

8282
# Initialize model
@@ -104,9 +104,9 @@ def main():
104104
print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.14f}')
105105

106106
# Generate validation data
107-
num_points = 100
108-
x_val = np.linspace(0, 1, num_points)
109-
y_val = np.linspace(0, 1, num_points)
107+
num_points_val = 100
108+
x_val = np.linspace(0, 1, num_points_val)
109+
y_val = np.linspace(0, 1, num_points_val)
110110
X_val, Y_val = np.meshgrid(x_val, y_val)
111111
xy_val = np.column_stack((X_val.flatten(), Y_val.flatten()))
112112
psi_val = psi(X_val, Y_val)
@@ -122,19 +122,46 @@ def main():
122122
#psi_actual = psi(X_val, Y_val)
123123

124124
# Visualize actual and predicted stream functions
125-
visualize_psi(X_val, Y_val, psi_val, title="Actual Stream Function")
126-
visualize_psi(X_val, Y_val, psi_pred, title="Predicted Stream Function")
125+
visualize_psi(X_val, Y_val, psi_val, centers, title="Actual Stream Function")
126+
visualize_psi(X_val, Y_val, psi_pred, centers, title="Predicted Stream Function")
127127

128-
visualize_psi(X_val, Y_val, np.abs(psi_pred - psi_val),
129-
title="Stream Function Approximation Error",
130-
centers=centers)
128+
err_val = np.abs(psi_pred - psi_val)
129+
visualize_psi(X_val, Y_val, err_val, centers,
130+
title="Stream Function Approximation Error")
131+
132+
# Define the filename
133+
csv_filename = "stream_function_validation.csv"
134+
135+
# Define the header and the values to be appended
136+
header = ["num_points", "point_dist", "r_max", "err_validation"]
137+
data = [num_points, 1.0 / num_points, r_max, np.mean(err_val)]
138+
139+
# Check if file exists
140+
file_exists = os.path.isfile(csv_filename)
141+
142+
# Open file in append mode
143+
with open(csv_filename, mode='a', newline='') as file:
144+
writer = csv.writer(file)
145+
146+
# If the file doesn't exist, write the header first
147+
if not file_exists:
148+
writer.writerow(header)
149+
150+
# Append the data row
151+
writer.writerow(data)
152+
153+
print(f"Appended to {csv_filename}: {data}")
131154

132155
# Compute and visualize actual and predicted velocity fields
133156
u_val, v_val = compute_velocity(X_val, Y_val, psi_val)
134-
visualize_velocity_field(X_val, Y_val, u_val, v_val, title="Actual Velocity Field")
157+
visualize_velocity_field(X_val, Y_val, u_val, v_val, num_points,
158+
title="Actual Velocity Field")
135159

136160
u_pred, v_pred = compute_velocity(X_val, Y_val, psi_pred)
137-
visualize_velocity_field(X_val, Y_val, u_pred, v_pred, title="Predicted Velocity Field")
161+
visualize_velocity_field(X_val, Y_val, u_pred, v_pred, num_points,
162+
title="Predicted Velocity Field")
138163

139164
if __name__ == "__main__":
140-
main()
165+
main(num_points=4)
166+
main(num_points=8)
167+
main(num_points=16)

0 commit comments

Comments
 (0)