Skip to content

Commit bf51ede

Browse files
committed
[mesh motion] generalized RBF network, Gaussian converges
1 parent ff52b4e commit bf51ede

File tree

2 files changed

+79
-59
lines changed

2 files changed

+79
-59
lines changed
Lines changed: 54 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,78 @@
11
import torch
22
import torch.nn as nn
3-
import math
43

5-
class WendlandLinearNetwork(nn.Module):
6-
def __init__(self, centers, r_max, smoothness):
4+
# Define various RBF functions with enforced compact support
5+
def gaussian_rbf(r):
6+
"""Infinitely smooth Gaussian RBF, matching Wendland's implementation."""
7+
return torch.exp(-r**2) # Matches WendlandLinearNetwork
8+
9+
def wendland_rbf(r):
10+
"""Compactly supported Wendland C^4 RBF."""
11+
mask = (r < 1).float()
12+
rm = (1 - r).clamp(min=0.0)
13+
return mask * (1 + 6*r + (35/3)*r**2) * rm**6
14+
15+
def multiquadric_rbf(r):
16+
"""Multiquadric RBF with compact support."""
17+
mask = (r < 1).float()
18+
return mask * torch.sqrt(1 + r**2)
19+
20+
def inverse_multiquadric_rbf(r):
21+
"""Inverse multiquadric RBF with compact support."""
22+
mask = (r < 1).float()
23+
return mask / torch.sqrt(1 + r**2)
24+
25+
# Create an RBF function dictionary
26+
rbf_dict = {
27+
"gaussian": gaussian_rbf,
28+
"wendland": wendland_rbf,
29+
"multiquadric": multiquadric_rbf,
30+
"inverse_multiquadric": inverse_multiquadric_rbf
31+
}
32+
33+
class RadialBasisFunctionNetwork(nn.Module):
34+
def __init__(self, centers, r_max, rbf_dict, rbf_type):
735
"""
8-
Initialize the Wendland RBF network with linear polynomial terms.
36+
Generalized RBF network with user-selectable RBF functions.
937
1038
Parameters:
11-
centers (torch.Tensor): shape (num_centers, dimension) RBF centers.
12-
r_max (float): radius of compact support.
13-
smoothness (int): even integer specifying desired smoothness (C^smoothness).
39+
centers (torch.Tensor): shape (num_centers, dimension), RBF centers.
40+
r_max (float): radius of compact support (applies to all RBFs).
41+
rbf_dict (dict): Dictionary mapping RBF type names to function implementations.
42+
rbf_type (str): Type of RBF function to use (must be in rbf_dict).
1443
"""
1544
super().__init__()
1645

17-
if smoothness != 4:
18-
raise NotImplementedError("Only smoothness=4 (C⁴) is currently implemented explicitly.")
19-
20-
self.centers = centers.clone().detach()
46+
self.centers = centers.clone().detach() # Fixed RBF centers
2147
self.r_max = r_max
22-
self.smoothness = smoothness
23-
self.k = smoothness // 2
2448
self.num_centers, self.dimension = centers.shape
49+
self.rbf_type = rbf_type.lower() # Store selected RBF type as an attribute
2550

26-
# Trainable parameters (explicitly initialized!)
27-
self.weights = nn.Parameter(torch.zeros(self.num_centers))
28-
self.a0 = nn.Parameter(torch.tensor(0.0))
29-
self.a = nn.Parameter(torch.zeros(self.dimension))
51+
# Ensure rbf_type is valid
52+
if self.rbf_type not in rbf_dict:
53+
raise ValueError(f"Invalid RBF type '{self.rbf_type}'. Available options: {list(rbf_dict.keys())}")
3054

31-
#def rbf(self, x):
32-
# r = torch.cdist(x, self.centers) / self.r_max
33-
# mask = (r < 1).float()
34-
# rm = (1 - r).clamp(min=0.0)
55+
self.rbf_function = rbf_dict[self.rbf_type] # Store selected RBF function
3556

36-
# phi = (1 + 6*r + (35/3)*r**2) * rm**6 # Only compute once
37-
# return phi * mask # Ensure compact support
57+
# Trainable parameters (weights for RBFs) - match Wendland's model!
58+
self.weights = nn.Parameter(torch.zeros(self.num_centers)) # Initialize to zeros
59+
self.a0 = nn.Parameter(torch.tensor(0.0)) # Bias term initialized as 0 (like Wendland)
3860

3961
def rbf(self, x):
4062
"""
41-
Compute Gaussian RBF instead of Wendland.
63+
Compute the RBF values for input x using the selected RBF function.
4264
"""
43-
r = torch.cdist(x, self.centers) / self.r_max
44-
return torch.exp(-r**2) # Infinitely smooth Gaussian RBF
45-
46-
#def rbf(self, x):
47-
# """
48-
# Compute Gaussian RBF with enforced compact support.
49-
# """
50-
# r = torch.cdist(x, self.centers) / self.r_max
51-
# mask = (r < 1).float() # Ensure compact support
52-
# return mask * torch.exp(-r**2) # Compactly supported Gaussian
65+
r = torch.cdist(x, self.centers) / self.r_max # Compute normalized distance
66+
return self.rbf_function(r) # Apply the selected RBF function
5367

5468
def forward(self, x):
5569
"""
56-
Forward pass: Wendland RBF with explicit linear polynomial extension.
70+
Forward pass: Compute RBF output.
5771
"""
5872
rbf_output = self.rbf(x)
5973
rbf_term = rbf_output @ self.weights
60-
linear_term = x @ self.a
61-
return self.a0 + rbf_term #+ linear_term
74+
return self.a0 + rbf_term # No polynomial correction term
75+
76+
def get_rbf_type(self):
77+
"""Return the currently selected RBF type."""
78+
return self.rbf_type

src/displacementSmartSimMotionSolver/pytorchApproximationModels/test_rbf_network_stream_function.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pandas as pd
88
import os
99

10-
from rbf_network import WendlandLinearNetwork
10+
from rbf_network import rbf_dict, RadialBasisFunctionNetwork
1111

1212
def psi(x, y):
1313
"""
@@ -28,7 +28,7 @@ def visualize_psi(x, y, psi_values, centers, title="Stream Function"):
2828
plt.figure(figsize=(6, 6))
2929
plt.contourf(x, y, psi_values, levels=20, cmap='viridis')
3030
plt.colorbar(label='ψ')
31-
plt.title(title)
31+
plt.title(title + f" num_centers {len(centers)}")
3232
plt.xlabel('x')
3333
plt.ylabel('y')
3434
plt.grid()
@@ -64,7 +64,7 @@ def generate_centers(num_centers):
6464
def estimate_convergence_order(csv_filename):
6565
"""
6666
Opens a CSV file containing numerical convergence results and estimates the
67-
convergence order for each row using log-log error reduction.
67+
convergence order for each error column using log-log error reduction.
6868
6969
The last row's convergence order is set equal to the second-to-last row.
7070
@@ -75,33 +75,36 @@ def estimate_convergence_order(csv_filename):
7575
df = pd.read_csv(csv_filename)
7676

7777
# Ensure required columns exist
78-
required_columns = {"point_dist", "err_validation"}
78+
required_columns = {"point_dist", "err_mean", "err_max"}
7979
if not required_columns.issubset(df.columns):
8080
raise ValueError(f"Missing required columns in CSV: {required_columns - set(df.columns)}")
8181

82-
# Compute convergence order using log-log slope formula
83-
convergence_orders = []
82+
# List of error columns to process
83+
error_columns = ["err_mean", "err_max"]
84+
85+
for error_col in error_columns:
86+
convergence_orders = [] # Store convergence orders for this error type
8487

85-
for i in range(len(df) - 1): # Iterate up to the second-to-last row
86-
h_coarse, h_fine = df.iloc[i]["point_dist"], df.iloc[i + 1]["point_dist"]
87-
err_coarse, err_fine = df.iloc[i]["err_validation"], df.iloc[i + 1]["err_validation"]
88+
for i in range(len(df) - 1): # Iterate up to the second-to-last row
89+
h_coarse, h_fine = df.iloc[i]["point_dist"], df.iloc[i + 1]["point_dist"]
90+
err_coarse, err_fine = df.iloc[i][error_col], df.iloc[i + 1][error_col]
8891

89-
if err_coarse > 0 and err_fine > 0: # Avoid log errors due to zero or negative values
90-
p = np.log(err_coarse / err_fine) / np.log(h_coarse / h_fine)
91-
convergence_orders.append(p)
92-
else:
93-
convergence_orders.append(np.nan)
92+
if err_coarse > 0 and err_fine > 0: # Avoid log errors due to zero or negative values
93+
p = np.log(err_coarse / err_fine) / np.log(h_coarse / h_fine)
94+
convergence_orders.append(p)
95+
else:
96+
convergence_orders.append(np.nan)
9497

95-
# Ensure last row gets the same convergence order as the previous row
96-
convergence_orders.append(convergence_orders[-1] if len(convergence_orders) > 0 else np.nan)
98+
# Ensure last row gets the same convergence order as the previous row
99+
convergence_orders.append(convergence_orders[-1] if len(convergence_orders) > 0 else np.nan)
97100

98-
# Add convergence order column
99-
df["error_convergence_order"] = convergence_orders
101+
# Add convergence order column to DataFrame
102+
df[f"{error_col}_convergence_order"] = convergence_orders
100103

101104
# Save the updated CSV file
102105
df.to_csv(csv_filename, index=False)
103106

104-
print(f"Updated {csv_filename} with convergence orders.")
107+
print(f"Updated {csv_filename} with convergence orders for {error_columns}.")
105108

106109
def main(num_points):
107110
# Generate training data
@@ -123,7 +126,7 @@ def main(num_points):
123126
smoothness = 4 # C^4 smoothness
124127

125128
# Initialize model
126-
model = WendlandLinearNetwork(centers, r_max, smoothness)
129+
model = RadialBasisFunctionNetwork(centers, r_max, rbf_dict, rbf_type="gaussian")
127130

128131
# Optimizer and loss
129132
optimizer = optim.Adam(model.parameters(), lr=0.05)
@@ -175,8 +178,8 @@ def main(num_points):
175178
csv_filename = "stream_function_validation.csv"
176179

177180
# Define the header and the values to be appended
178-
header = ["num_points", "point_dist", "r_max", "err_validation"]
179-
data = [num_points, 1.0 / num_points, r_max, np.mean(err_val)]
181+
header = ["num_points", "point_dist", "r_max", "err_mean", "err_max"]
182+
data = [num_points, 1.0 / num_points, r_max, np.mean(err_val), np.max(err_val)]
180183

181184
# Check if file exists
182185
file_exists = os.path.isfile(csv_filename)

0 commit comments

Comments
 (0)