|
| 1 | +import numpy as np |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | +import torch.optim as optim |
| 6 | +import csv |
| 7 | +import os |
| 8 | + |
| 9 | +from rbf_network import rbf_dict, RadialBasisFunctionNetwork |
| 10 | + |
| 11 | + |
| 12 | +def velocity_u(x, y): |
| 13 | + return (np.sin(np.pi * x) ** 2) * np.sin(2 * np.pi * y) * np.pi |
| 14 | + |
| 15 | +def velocity_v(x, y): |
| 16 | + return -np.sin(2 * np.pi * x) * (np.sin(np.pi * y) ** 2) * np.pi |
| 17 | + |
| 18 | +def generate_boundary_points(num_rays, R, C): |
| 19 | + theta = np.linspace(0, 2 * np.pi, num_rays, endpoint=False) |
| 20 | + |
| 21 | + circle_x = C[0] + R * np.cos(theta) |
| 22 | + circle_y = C[1] + R * np.sin(theta) |
| 23 | + circle_boundary = np.column_stack((circle_x, circle_y)) |
| 24 | + |
| 25 | + outer_boundary = [] |
| 26 | + for t in theta: |
| 27 | + dx, dy = np.cos(t), np.sin(t) |
| 28 | + intersections = [] |
| 29 | + |
| 30 | + if dx != 0: |
| 31 | + for x_edge in [0.0, 1.0]: |
| 32 | + s = (x_edge - C[0]) / dx |
| 33 | + y = C[1] + s * dy |
| 34 | + if 0 <= y <= 1 and s > 0: |
| 35 | + intersections.append([x_edge, y]) |
| 36 | + if dy != 0: |
| 37 | + for y_edge in [0.0, 1.0]: |
| 38 | + s = (y_edge - C[1]) / dy |
| 39 | + x = C[0] + s * dx |
| 40 | + if 0 <= x <= 1 and s > 0: |
| 41 | + intersections.append([x, y_edge]) |
| 42 | + |
| 43 | + if intersections: |
| 44 | + dists = [np.linalg.norm(np.array(p) - np.array(C)) for p in intersections] |
| 45 | + outer_point = intersections[np.argmin(dists)] |
| 46 | + outer_boundary.append(outer_point) |
| 47 | + |
| 48 | + outer_boundary = np.array(outer_boundary) |
| 49 | + boundary_points = np.vstack([outer_boundary, circle_boundary]) |
| 50 | + return torch.tensor(boundary_points, dtype=torch.float32) |
| 51 | + |
| 52 | + |
| 53 | +def filter_inside_circle(points, R, C): |
| 54 | + distances = np.sqrt((points[:, 0] - C[0]) ** 2 + (points[:, 1] - C[1]) ** 2) |
| 55 | + return points[distances > R] |
| 56 | + |
| 57 | + |
| 58 | +def visualize_velocity_field_with_mask(x, y, u, v, rbf_type, centers, title, R, C): |
| 59 | + fig, ax = plt.subplots(figsize=(6, 6)) |
| 60 | + ax.set_aspect('equal') |
| 61 | + ax.quiver(x, y, u, v, scale=40) |
| 62 | + |
| 63 | + circle = plt.Circle(C, R, color='white', zorder=10) |
| 64 | + ax.add_patch(circle) |
| 65 | + |
| 66 | + num_centers = len(centers) |
| 67 | + s_max, s_min = 100, 25 |
| 68 | + num_min, num_max = 16, 128 |
| 69 | + s = s_max - (s_max - s_min) * (num_centers - num_min) / (num_max - num_min) |
| 70 | + s = max(s_min, min(s_max, s)) |
| 71 | + |
| 72 | + centers_np = centers.numpy() |
| 73 | + ax.scatter(centers_np[:, 0], centers_np[:, 1], color='k', marker='x', s=s, linewidths=2, label='Centers') |
| 74 | + ax.legend() |
| 75 | + ax.set_title(f"{title} {rbf_type.upper()} {num_centers}") |
| 76 | + ax.set_xlabel('x') |
| 77 | + ax.set_ylabel('y') |
| 78 | + ax.grid(True) |
| 79 | + |
| 80 | + fig_name = title.replace(" ", "-") |
| 81 | + plt.savefig(f"{fig_name}-rbf_type_{rbf_type}-num_centers_{num_centers}.png", dpi=200) |
| 82 | + plt.close(fig) |
| 83 | + |
| 84 | +def visualize_velocity_error_norm(x, y, u_pred, v_pred, rbf_type, centers, title, R, C): |
| 85 | + """ |
| 86 | + Visualize the 2-norm of the velocity error at validation points. |
| 87 | + """ |
| 88 | + # Exact velocity |
| 89 | + u_true = velocity_u(x, y) |
| 90 | + v_true = velocity_v(x, y) |
| 91 | + |
| 92 | + error_norm = np.sqrt((u_pred - u_true)**2 + (v_pred - v_true)**2) |
| 93 | + |
| 94 | + fig, ax = plt.subplots(figsize=(6, 6)) |
| 95 | + ax.set_aspect("equal") |
| 96 | + |
| 97 | + # Plot error as color scatter |
| 98 | + sc = ax.scatter(x, y, c=error_norm, cmap='magma', s=10) |
| 99 | + plt.colorbar(sc, ax=ax, label="||u_pred - u_true||") |
| 100 | + |
| 101 | + # Mask the circle area in white |
| 102 | + circle = plt.Circle(C, R, color='white', zorder=10) |
| 103 | + ax.add_patch(circle) |
| 104 | + |
| 105 | + # Plot centers |
| 106 | + num_centers = len(centers) |
| 107 | + s_max, s_min = 100, 25 |
| 108 | + num_min, num_max = 16, 128 |
| 109 | + s = s_max - (s_max - s_min) * (num_centers - num_min) / (num_max - num_min) |
| 110 | + s = max(s_min, min(s_max, s)) |
| 111 | + centers_np = centers.numpy() |
| 112 | + ax.scatter(centers_np[:, 0], centers_np[:, 1], color='k', marker='x', s=s, linewidths=2, label='Centers') |
| 113 | + |
| 114 | + ax.set_title(f"{title} {rbf_type.upper()} {num_centers}") |
| 115 | + ax.set_xlabel("x") |
| 116 | + ax.set_ylabel("y") |
| 117 | + ax.grid(True) |
| 118 | + ax.legend() |
| 119 | + |
| 120 | + fig_name = f"{title.replace(' ', '-')}-rbf_type_{rbf_type}-num_centers_{num_centers}.png" |
| 121 | + plt.savefig(fig_name, dpi=200) |
| 122 | + plt.close(fig) |
| 123 | + |
| 124 | + |
| 125 | +def main(num_points, rbf_type): |
| 126 | + R = 0.15 |
| 127 | + C = (0.5, 0.75) |
| 128 | + |
| 129 | + centers = generate_boundary_points(num_points, R=R, C=C) |
| 130 | + |
| 131 | + num_points_val = 100 |
| 132 | + x_val = np.linspace(0, 1, num_points_val) |
| 133 | + y_val = np.linspace(0, 1, num_points_val) |
| 134 | + X_val, Y_val = np.meshgrid(x_val, y_val) |
| 135 | + xy_val = np.column_stack((X_val.flatten(), Y_val.flatten())) |
| 136 | + xy_val_filtered = filter_inside_circle(xy_val, R=R, C=C) |
| 137 | + |
| 138 | + # Training data (u, v) |
| 139 | + x_train = centers |
| 140 | + u_train = torch.tensor(velocity_u(centers[:, 0], centers[:, 1]), dtype=torch.float32) |
| 141 | + v_train = torch.tensor(velocity_v(centers[:, 0], centers[:, 1]), dtype=torch.float32) |
| 142 | + |
| 143 | + # Fit two RBF models: one for u, one for v |
| 144 | + r_max = 3 / num_points |
| 145 | + |
| 146 | + def train_component_model(y_train): |
| 147 | + model = RadialBasisFunctionNetwork(x_train, r_max, rbf_dict, rbf_type=rbf_type) |
| 148 | + optimizer = optim.Adam(model.parameters(), lr=0.05) |
| 149 | + criterion = nn.MSELoss() |
| 150 | + best_loss = float("inf") |
| 151 | + best_model_state = None |
| 152 | + stop_loss = 1e-6 |
| 153 | + epochs = 4000 |
| 154 | + |
| 155 | + for epoch in range(epochs): |
| 156 | + model.train() |
| 157 | + optimizer.zero_grad() |
| 158 | + output = model(x_train) |
| 159 | + loss = criterion(output, y_train) |
| 160 | + loss.backward() |
| 161 | + optimizer.step() |
| 162 | + |
| 163 | + if loss.item() < best_loss: |
| 164 | + best_loss = loss.item() |
| 165 | + best_model_state = model.state_dict().copy() |
| 166 | + |
| 167 | + if loss.item() < stop_loss: |
| 168 | + break |
| 169 | + |
| 170 | + if best_model_state: |
| 171 | + model.load_state_dict(best_model_state) |
| 172 | + model.eval() |
| 173 | + return model |
| 174 | + |
| 175 | + model_u = train_component_model(u_train) |
| 176 | + model_v = train_component_model(v_train) |
| 177 | + |
| 178 | + # Predict velocities at validation points |
| 179 | + with torch.no_grad(): |
| 180 | + x_val_torch = torch.tensor(xy_val_filtered, dtype=torch.float32) |
| 181 | + u_pred = model_u(x_val_torch).numpy() |
| 182 | + v_pred = model_v(x_val_torch).numpy() |
| 183 | + |
| 184 | + # Visualize velocity field |
| 185 | + visualize_velocity_field_with_mask( |
| 186 | + xy_val_filtered[:, 0], xy_val_filtered[:, 1], u_pred, v_pred, |
| 187 | + rbf_type, centers, title="Velocity Field", R=R, C=C |
| 188 | + ) |
| 189 | + |
| 190 | + # Visualize 2-norm error |
| 191 | + visualize_velocity_error_norm( |
| 192 | + xy_val_filtered[:, 0], xy_val_filtered[:, 1], u_pred, v_pred, |
| 193 | + rbf_type, centers, title="Velocity Error Norm", R=R, C=C |
| 194 | + ) |
| 195 | + |
| 196 | + # Save mean/max error if desired |
| 197 | + u_true = velocity_u(xy_val_filtered[:, 0], xy_val_filtered[:, 1]) |
| 198 | + v_true = velocity_v(xy_val_filtered[:, 0], xy_val_filtered[:, 1]) |
| 199 | + err_u = np.abs(u_pred - u_true) / (np.max(np.abs(u_true)) + 1e-12) |
| 200 | + err_v = np.abs(v_pred - v_true) / (np.max(np.abs(v_true)) + 1e-12) |
| 201 | + |
| 202 | + csv_filename = "velocity_validation.csv" |
| 203 | + header = ["model_rbf_type", "num_points", "r_max", "err_mean_u", "err_max_u", "err_mean_v", "err_max_v"] |
| 204 | + data = [rbf_type, num_points, r_max, np.mean(err_u), np.max(err_u), np.mean(err_v), np.max(err_v)] |
| 205 | + |
| 206 | + file_exists = os.path.isfile(csv_filename) |
| 207 | + with open(csv_filename, mode='a', newline='') as file: |
| 208 | + writer = csv.writer(file) |
| 209 | + if not file_exists: |
| 210 | + writer.writerow(header) |
| 211 | + writer.writerow(data) |
| 212 | + |
| 213 | + print(f"Appended to {csv_filename}: {data}") |
| 214 | + |
| 215 | + |
| 216 | +if __name__ == "__main__": |
| 217 | + for rbf_type in ["gaussian", "wendland_d2_c4"]: |
| 218 | + for num_points in [16, 32, 64, 128]: |
| 219 | + main(num_points, rbf_type) |
0 commit comments