Skip to content

Commit 8cc3db2

Browse files
committed
[mesh motion] initial working RBF boundary velocity interpolation
1 parent 5e3aa2f commit 8cc3db2

File tree

2 files changed

+220
-0
lines changed

2 files changed

+220
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
*.png
22
*.csv
3+
*.pth
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
import torch
4+
import torch.nn as nn
5+
import torch.optim as optim
6+
import csv
7+
import os
8+
9+
from rbf_network import rbf_dict, RadialBasisFunctionNetwork
10+
11+
12+
def velocity_u(x, y):
13+
return (np.sin(np.pi * x) ** 2) * np.sin(2 * np.pi * y) * np.pi
14+
15+
def velocity_v(x, y):
16+
return -np.sin(2 * np.pi * x) * (np.sin(np.pi * y) ** 2) * np.pi
17+
18+
def generate_boundary_points(num_rays, R, C):
19+
theta = np.linspace(0, 2 * np.pi, num_rays, endpoint=False)
20+
21+
circle_x = C[0] + R * np.cos(theta)
22+
circle_y = C[1] + R * np.sin(theta)
23+
circle_boundary = np.column_stack((circle_x, circle_y))
24+
25+
outer_boundary = []
26+
for t in theta:
27+
dx, dy = np.cos(t), np.sin(t)
28+
intersections = []
29+
30+
if dx != 0:
31+
for x_edge in [0.0, 1.0]:
32+
s = (x_edge - C[0]) / dx
33+
y = C[1] + s * dy
34+
if 0 <= y <= 1 and s > 0:
35+
intersections.append([x_edge, y])
36+
if dy != 0:
37+
for y_edge in [0.0, 1.0]:
38+
s = (y_edge - C[1]) / dy
39+
x = C[0] + s * dx
40+
if 0 <= x <= 1 and s > 0:
41+
intersections.append([x, y_edge])
42+
43+
if intersections:
44+
dists = [np.linalg.norm(np.array(p) - np.array(C)) for p in intersections]
45+
outer_point = intersections[np.argmin(dists)]
46+
outer_boundary.append(outer_point)
47+
48+
outer_boundary = np.array(outer_boundary)
49+
boundary_points = np.vstack([outer_boundary, circle_boundary])
50+
return torch.tensor(boundary_points, dtype=torch.float32)
51+
52+
53+
def filter_inside_circle(points, R, C):
54+
distances = np.sqrt((points[:, 0] - C[0]) ** 2 + (points[:, 1] - C[1]) ** 2)
55+
return points[distances > R]
56+
57+
58+
def visualize_velocity_field_with_mask(x, y, u, v, rbf_type, centers, title, R, C):
59+
fig, ax = plt.subplots(figsize=(6, 6))
60+
ax.set_aspect('equal')
61+
ax.quiver(x, y, u, v, scale=40)
62+
63+
circle = plt.Circle(C, R, color='white', zorder=10)
64+
ax.add_patch(circle)
65+
66+
num_centers = len(centers)
67+
s_max, s_min = 100, 25
68+
num_min, num_max = 16, 128
69+
s = s_max - (s_max - s_min) * (num_centers - num_min) / (num_max - num_min)
70+
s = max(s_min, min(s_max, s))
71+
72+
centers_np = centers.numpy()
73+
ax.scatter(centers_np[:, 0], centers_np[:, 1], color='k', marker='x', s=s, linewidths=2, label='Centers')
74+
ax.legend()
75+
ax.set_title(f"{title} {rbf_type.upper()} {num_centers}")
76+
ax.set_xlabel('x')
77+
ax.set_ylabel('y')
78+
ax.grid(True)
79+
80+
fig_name = title.replace(" ", "-")
81+
plt.savefig(f"{fig_name}-rbf_type_{rbf_type}-num_centers_{num_centers}.png", dpi=200)
82+
plt.close(fig)
83+
84+
def visualize_velocity_error_norm(x, y, u_pred, v_pred, rbf_type, centers, title, R, C):
85+
"""
86+
Visualize the 2-norm of the velocity error at validation points.
87+
"""
88+
# Exact velocity
89+
u_true = velocity_u(x, y)
90+
v_true = velocity_v(x, y)
91+
92+
error_norm = np.sqrt((u_pred - u_true)**2 + (v_pred - v_true)**2)
93+
94+
fig, ax = plt.subplots(figsize=(6, 6))
95+
ax.set_aspect("equal")
96+
97+
# Plot error as color scatter
98+
sc = ax.scatter(x, y, c=error_norm, cmap='magma', s=10)
99+
plt.colorbar(sc, ax=ax, label="||u_pred - u_true||")
100+
101+
# Mask the circle area in white
102+
circle = plt.Circle(C, R, color='white', zorder=10)
103+
ax.add_patch(circle)
104+
105+
# Plot centers
106+
num_centers = len(centers)
107+
s_max, s_min = 100, 25
108+
num_min, num_max = 16, 128
109+
s = s_max - (s_max - s_min) * (num_centers - num_min) / (num_max - num_min)
110+
s = max(s_min, min(s_max, s))
111+
centers_np = centers.numpy()
112+
ax.scatter(centers_np[:, 0], centers_np[:, 1], color='k', marker='x', s=s, linewidths=2, label='Centers')
113+
114+
ax.set_title(f"{title} {rbf_type.upper()} {num_centers}")
115+
ax.set_xlabel("x")
116+
ax.set_ylabel("y")
117+
ax.grid(True)
118+
ax.legend()
119+
120+
fig_name = f"{title.replace(' ', '-')}-rbf_type_{rbf_type}-num_centers_{num_centers}.png"
121+
plt.savefig(fig_name, dpi=200)
122+
plt.close(fig)
123+
124+
125+
def main(num_points, rbf_type):
126+
R = 0.15
127+
C = (0.5, 0.75)
128+
129+
centers = generate_boundary_points(num_points, R=R, C=C)
130+
131+
num_points_val = 100
132+
x_val = np.linspace(0, 1, num_points_val)
133+
y_val = np.linspace(0, 1, num_points_val)
134+
X_val, Y_val = np.meshgrid(x_val, y_val)
135+
xy_val = np.column_stack((X_val.flatten(), Y_val.flatten()))
136+
xy_val_filtered = filter_inside_circle(xy_val, R=R, C=C)
137+
138+
# Training data (u, v)
139+
x_train = centers
140+
u_train = torch.tensor(velocity_u(centers[:, 0], centers[:, 1]), dtype=torch.float32)
141+
v_train = torch.tensor(velocity_v(centers[:, 0], centers[:, 1]), dtype=torch.float32)
142+
143+
# Fit two RBF models: one for u, one for v
144+
r_max = 3 / num_points
145+
146+
def train_component_model(y_train):
147+
model = RadialBasisFunctionNetwork(x_train, r_max, rbf_dict, rbf_type=rbf_type)
148+
optimizer = optim.Adam(model.parameters(), lr=0.05)
149+
criterion = nn.MSELoss()
150+
best_loss = float("inf")
151+
best_model_state = None
152+
stop_loss = 1e-6
153+
epochs = 4000
154+
155+
for epoch in range(epochs):
156+
model.train()
157+
optimizer.zero_grad()
158+
output = model(x_train)
159+
loss = criterion(output, y_train)
160+
loss.backward()
161+
optimizer.step()
162+
163+
if loss.item() < best_loss:
164+
best_loss = loss.item()
165+
best_model_state = model.state_dict().copy()
166+
167+
if loss.item() < stop_loss:
168+
break
169+
170+
if best_model_state:
171+
model.load_state_dict(best_model_state)
172+
model.eval()
173+
return model
174+
175+
model_u = train_component_model(u_train)
176+
model_v = train_component_model(v_train)
177+
178+
# Predict velocities at validation points
179+
with torch.no_grad():
180+
x_val_torch = torch.tensor(xy_val_filtered, dtype=torch.float32)
181+
u_pred = model_u(x_val_torch).numpy()
182+
v_pred = model_v(x_val_torch).numpy()
183+
184+
# Visualize velocity field
185+
visualize_velocity_field_with_mask(
186+
xy_val_filtered[:, 0], xy_val_filtered[:, 1], u_pred, v_pred,
187+
rbf_type, centers, title="Velocity Field", R=R, C=C
188+
)
189+
190+
# Visualize 2-norm error
191+
visualize_velocity_error_norm(
192+
xy_val_filtered[:, 0], xy_val_filtered[:, 1], u_pred, v_pred,
193+
rbf_type, centers, title="Velocity Error Norm", R=R, C=C
194+
)
195+
196+
# Save mean/max error if desired
197+
u_true = velocity_u(xy_val_filtered[:, 0], xy_val_filtered[:, 1])
198+
v_true = velocity_v(xy_val_filtered[:, 0], xy_val_filtered[:, 1])
199+
err_u = np.abs(u_pred - u_true) / (np.max(np.abs(u_true)) + 1e-12)
200+
err_v = np.abs(v_pred - v_true) / (np.max(np.abs(v_true)) + 1e-12)
201+
202+
csv_filename = "velocity_validation.csv"
203+
header = ["model_rbf_type", "num_points", "r_max", "err_mean_u", "err_max_u", "err_mean_v", "err_max_v"]
204+
data = [rbf_type, num_points, r_max, np.mean(err_u), np.max(err_u), np.mean(err_v), np.max(err_v)]
205+
206+
file_exists = os.path.isfile(csv_filename)
207+
with open(csv_filename, mode='a', newline='') as file:
208+
writer = csv.writer(file)
209+
if not file_exists:
210+
writer.writerow(header)
211+
writer.writerow(data)
212+
213+
print(f"Appended to {csv_filename}: {data}")
214+
215+
216+
if __name__ == "__main__":
217+
for rbf_type in ["gaussian", "wendland_d2_c4"]:
218+
for num_points in [16, 32, 64, 128]:
219+
main(num_points, rbf_type)

0 commit comments

Comments
 (0)