Skip to content

Commit 5e3aa2f

Browse files
committed
[mesh motion] best gaussian approximation settingss for stream function
1 parent bf51ede commit 5e3aa2f

File tree

2 files changed

+79
-33
lines changed

2 files changed

+79
-33
lines changed

src/displacementSmartSimMotionSolver/pytorchApproximationModels/rbf_network.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,34 @@
11
import torch
22
import torch.nn as nn
33

4+
import torch
5+
46
# Define various RBF functions with enforced compact support
57
def gaussian_rbf(r):
6-
"""Infinitely smooth Gaussian RBF, matching Wendland's implementation."""
7-
return torch.exp(-r**2) # Matches WendlandLinearNetwork
8+
"""Infinitely smooth Gaussian RBF."""
9+
return torch.exp(-r**2)
10+
11+
def wendland_d2_c2_rbf(r):
12+
"""
13+
Wendland's C^2 RBF for d=2.
14+
Compactly supported, continuously differentiable (C^2).
15+
16+
Formula: (1 - r)^4_+ (4r + 1)
17+
"""
18+
mask = (r < 1).float()
19+
rm = (1 - r).clamp(min=0.0)
20+
return mask * rm**4 * (4 * r + 1)
821

9-
def wendland_rbf(r):
10-
"""Compactly supported Wendland C^4 RBF."""
22+
def wendland_d2_c4_rbf(r):
23+
"""
24+
Wendland's C^4 RBF for d=2.
25+
Compactly supported, twice continuously differentiable (C^4).
26+
27+
Formula: (1 - r)^6_+ (35r^2 + 18r + 3)
28+
"""
1129
mask = (r < 1).float()
1230
rm = (1 - r).clamp(min=0.0)
13-
return mask * (1 + 6*r + (35/3)*r**2) * rm**6
31+
return mask * rm**6 * (35 * r**2 + 18 * r + 3)
1432

1533
def multiquadric_rbf(r):
1634
"""Multiquadric RBF with compact support."""
@@ -25,7 +43,8 @@ def inverse_multiquadric_rbf(r):
2543
# Create an RBF function dictionary
2644
rbf_dict = {
2745
"gaussian": gaussian_rbf,
28-
"wendland": wendland_rbf,
46+
"wendland_d2_c2": wendland_d2_c2_rbf,
47+
"wendland_d2_c4": wendland_d2_c4_rbf,
2948
"multiquadric": multiquadric_rbf,
3049
"inverse_multiquadric": inverse_multiquadric_rbf
3150
}

src/displacementSmartSimMotionSolver/pytorchApproximationModels/test_rbf_network_stream_function.py

Lines changed: 54 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ def compute_velocity(x, y, psi_values):
2424
v = -dx # v = -∂ψ/∂x
2525
return u, v
2626

27-
def visualize_psi(x, y, psi_values, centers, title="Stream Function"):
27+
def visualize_psi(x, y, psi_values, rbf_type, centers, title):
2828
plt.figure(figsize=(6, 6))
2929
plt.contourf(x, y, psi_values, levels=20, cmap='viridis')
3030
plt.colorbar(label='ψ')
31-
plt.title(title + f" num_centers {len(centers)}")
31+
plt.title(title + f"-rbf_type_{rbf_type}-num_centers_{len(centers)}")
3232
plt.xlabel('x')
3333
plt.ylabel('y')
3434
plt.grid()
@@ -39,17 +39,17 @@ def visualize_psi(x, y, psi_values, centers, title="Stream Function"):
3939
plt.legend()
4040

4141
fig_name = title.replace(" ", "-")
42-
plt.savefig(f"{fig_name}-num_centers-{len(centers)}.png", dpi=200)
42+
plt.savefig(f"{fig_name}-rbf_type_{rbf_type}-num_centers_{len(centers)}.png", dpi=200)
4343

44-
def visualize_velocity_field(x, y, u, v, num_centers, title="Velocity Field"):
44+
def visualize_velocity_field(x, y, u, v, rbf_type, num_centers, title="Velocity"):
4545
plt.figure(figsize=(6, 6))
4646
plt.quiver(x, y, u, v)
47-
plt.title(title)
47+
plt.title(f"{title}-{rbf_type}-n_centers_{num_centers}")
4848
plt.xlabel('x')
4949
plt.ylabel('y')
5050
plt.grid()
5151
fig_name = title.replace(" ", "-")
52-
plt.savefig(f"{fig_name}-num_centers{num_centers}.png", dpi=200)
52+
plt.savefig(f"{fig_name}-rbf_type_{rbf_type}-num_centers_{num_centers}.png", dpi=200)
5353

5454
def generate_centers(num_centers):
5555
"""
@@ -106,7 +106,7 @@ def estimate_convergence_order(csv_filename):
106106

107107
print(f"Updated {csv_filename} with convergence orders for {error_columns}.")
108108

109-
def main(num_points):
109+
def main(num_points, rbf_type):
110110
# Generate training data
111111
x = np.linspace(0, 1, num_points)
112112
y = np.linspace(0, 1, num_points)
@@ -119,21 +119,26 @@ def main(num_points):
119119
y_train = torch.tensor(psi_train, dtype=torch.float32)
120120

121121
# Generate centers
122-
# centers = generate_centers(32).clone().detach()
123122
centers = x_train
124123
print(centers.shape)
124+
125+
# Gaussian 3d-order support
126+
#r_max = 2.5 / num_points
125127
r_max = 2.5 / num_points
126-
smoothness = 4 # C^4 smoothness
127128

128129
# Initialize model
129-
model = RadialBasisFunctionNetwork(centers, r_max, rbf_dict, rbf_type="gaussian")
130+
model = RadialBasisFunctionNetwork(centers, r_max, rbf_dict, rbf_type=rbf_type)
130131

131132
# Optimizer and loss
132133
optimizer = optim.Adam(model.parameters(), lr=0.05)
133-
criterion = nn.MSELoss()
134+
criterion = torch.nn.MSELoss()
134135

135136
# Training loop
136137
epochs = 4000
138+
best_loss = float("inf") # Initialize best loss to a large value
139+
best_model_state = None # Store best model state
140+
stop_loss = 1e-08
141+
137142
for epoch in range(epochs):
138143
model.train()
139144
optimizer.zero_grad()
@@ -142,12 +147,28 @@ def main(num_points):
142147
loss.backward()
143148
optimizer.step()
144149

145-
if loss.item() < 1e-12:
146-
print(f"Stopping early at epoch {epoch + 1} due to reaching loss {loss.item()} < 1e-05")
150+
# Save the model if it has the lowest loss so far
151+
if loss.item() < best_loss:
152+
best_loss = loss.item()
153+
best_model_state = model.state_dict().copy() # Copy best model state
154+
155+
# Early stopping criterion
156+
if loss.item() < stop_loss:
157+
print(f"Stopping early at epoch {epoch + 1} due to reaching loss {loss.item()} < {stop_loss}")
147158
break
148159

149-
if ((epoch == 1) or ((epoch + 1) % 50 == 0)):
150-
print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.14f}')
160+
# Print progress every 50 epochs
161+
if epoch == 1 or (epoch + 1) % 50 == 0:
162+
print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.14f}, Best Loss: {best_loss:.14f}')
163+
164+
# Restore the best model state if training didn't reach convergence
165+
if best_model_state:
166+
model.load_state_dict(best_model_state)
167+
print(f"Restored best model with loss: {best_loss:.14f}")
168+
169+
# Save the best model to file
170+
torch.save(best_model_state, "best_rbf_model.pth")
171+
print("Best model saved as 'best_rbf_model.pth'.")
151172

152173
# Generate validation data
153174
num_points_val = 100
@@ -167,19 +188,24 @@ def main(num_points):
167188
psi_pred = pred.reshape(X_val.shape)
168189

169190
# Visualize actual and predicted stream functions
170-
visualize_psi(X_val, Y_val, psi_val, centers, title="Actual Stream Function")
171-
visualize_psi(X_val, Y_val, psi_pred, centers, title="Predicted Stream Function")
191+
visualize_psi(X_val, Y_val, psi_val, rbf_type,
192+
centers, title="Actual Stream Function")
193+
visualize_psi(X_val, Y_val, psi_pred, rbf_type,
194+
centers, title="Predicted Stream Function")
172195

173196
err_val = np.abs(psi_pred - psi_val) / np.max(psi_val)
174-
visualize_psi(X_val, Y_val, err_val, centers,
175-
title="Stream Function Relative Approximation Error")
197+
visualize_psi(X_val, Y_val, err_val, rbf_type, centers,
198+
title="Stream Function Relative Error")
176199

177200
# Define the filename
178201
csv_filename = "stream_function_validation.csv"
179202

180203
# Define the header and the values to be appended
181-
header = ["num_points", "point_dist", "r_max", "err_mean", "err_max"]
182-
data = [num_points, 1.0 / num_points, r_max, np.mean(err_val), np.max(err_val)]
204+
header = ["model_rbf_type", "num_points", "support_radius",
205+
"point_dist", "err_mean", "err_max"]
206+
207+
data = [model.rbf_type, num_points, r_max, 1.0 / num_points,
208+
np.mean(err_val), np.max(err_val)]
183209

184210
# Check if file exists
185211
file_exists = os.path.isfile(csv_filename)
@@ -199,18 +225,19 @@ def main(num_points):
199225

200226
# Compute and visualize actual and predicted velocity fields
201227
u_val, v_val = compute_velocity(X_val, Y_val, psi_val)
202-
visualize_velocity_field(X_val, Y_val, u_val, v_val, num_points,
203-
title="Actual Velocity Field")
228+
visualize_velocity_field(X_val, Y_val, u_val, v_val, rbf_type, num_points,
229+
title="Velocity Field")
204230

205231
u_pred, v_pred = compute_velocity(X_val, Y_val, psi_pred)
206-
visualize_velocity_field(X_val, Y_val, u_pred, v_pred, num_points,
232+
visualize_velocity_field(X_val, Y_val, u_pred, v_pred, rbf_type, num_points,
207233
title="Predicted Velocity Field")
208234

209235
if __name__ == "__main__":
210236

211-
# Run mesh convergence study
212-
for num_points in [4,8,16,32]:
213-
main(num_points)
237+
# Run the parameter study
238+
for rbf_type in ["gaussian"]:
239+
for num_points in [4,8,16,32]:
240+
main(num_points, rbf_type)
214241

215242
# Estimate convergence order
216243
estimate_convergence_order("stream_function_validation.csv")

0 commit comments

Comments
 (0)