@@ -24,11 +24,11 @@ def compute_velocity(x, y, psi_values):
2424 v = - dx # v = -∂ψ/∂x
2525 return u , v
2626
27- def visualize_psi (x , y , psi_values , centers , title = "Stream Function" ):
27+ def visualize_psi (x , y , psi_values , rbf_type , centers , title ):
2828 plt .figure (figsize = (6 , 6 ))
2929 plt .contourf (x , y , psi_values , levels = 20 , cmap = 'viridis' )
3030 plt .colorbar (label = 'ψ' )
31- plt .title (title + f" num_centers { len (centers )} " )
31+ plt .title (title + f"-rbf_type_ { rbf_type } -num_centers_ { len (centers )} " )
3232 plt .xlabel ('x' )
3333 plt .ylabel ('y' )
3434 plt .grid ()
@@ -39,17 +39,17 @@ def visualize_psi(x, y, psi_values, centers, title="Stream Function"):
3939 plt .legend ()
4040
4141 fig_name = title .replace (" " , "-" )
42- plt .savefig (f"{ fig_name } -num_centers- { len (centers )} .png" , dpi = 200 )
42+ plt .savefig (f"{ fig_name } -rbf_type_ { rbf_type } -num_centers_ { len (centers )} .png" , dpi = 200 )
4343
44- def visualize_velocity_field (x , y , u , v , num_centers , title = "Velocity Field " ):
44+ def visualize_velocity_field (x , y , u , v , rbf_type , num_centers , title = "Velocity" ):
4545 plt .figure (figsize = (6 , 6 ))
4646 plt .quiver (x , y , u , v )
47- plt .title (title )
47+ plt .title (f" { title } - { rbf_type } -n_centers_ { num_centers } " )
4848 plt .xlabel ('x' )
4949 plt .ylabel ('y' )
5050 plt .grid ()
5151 fig_name = title .replace (" " , "-" )
52- plt .savefig (f"{ fig_name } -num_centers { num_centers } .png" , dpi = 200 )
52+ plt .savefig (f"{ fig_name } -rbf_type_ { rbf_type } -num_centers_ { num_centers } .png" , dpi = 200 )
5353
5454def generate_centers (num_centers ):
5555 """
@@ -106,7 +106,7 @@ def estimate_convergence_order(csv_filename):
106106
107107 print (f"Updated { csv_filename } with convergence orders for { error_columns } ." )
108108
109- def main (num_points ):
109+ def main (num_points , rbf_type ):
110110 # Generate training data
111111 x = np .linspace (0 , 1 , num_points )
112112 y = np .linspace (0 , 1 , num_points )
@@ -119,21 +119,26 @@ def main(num_points):
119119 y_train = torch .tensor (psi_train , dtype = torch .float32 )
120120
121121 # Generate centers
122- # centers = generate_centers(32).clone().detach()
123122 centers = x_train
124123 print (centers .shape )
124+
125+ # Gaussian 3d-order support
126+ #r_max = 2.5 / num_points
125127 r_max = 2.5 / num_points
126- smoothness = 4 # C^4 smoothness
127128
128129 # Initialize model
129- model = RadialBasisFunctionNetwork (centers , r_max , rbf_dict , rbf_type = "gaussian" )
130+ model = RadialBasisFunctionNetwork (centers , r_max , rbf_dict , rbf_type = rbf_type )
130131
131132 # Optimizer and loss
132133 optimizer = optim .Adam (model .parameters (), lr = 0.05 )
133- criterion = nn .MSELoss ()
134+ criterion = torch . nn .MSELoss ()
134135
135136 # Training loop
136137 epochs = 4000
138+ best_loss = float ("inf" ) # Initialize best loss to a large value
139+ best_model_state = None # Store best model state
140+ stop_loss = 1e-08
141+
137142 for epoch in range (epochs ):
138143 model .train ()
139144 optimizer .zero_grad ()
@@ -142,12 +147,28 @@ def main(num_points):
142147 loss .backward ()
143148 optimizer .step ()
144149
145- if loss .item () < 1e-12 :
146- print (f"Stopping early at epoch { epoch + 1 } due to reaching loss { loss .item ()} < 1e-05" )
150+ # Save the model if it has the lowest loss so far
151+ if loss .item () < best_loss :
152+ best_loss = loss .item ()
153+ best_model_state = model .state_dict ().copy () # Copy best model state
154+
155+ # Early stopping criterion
156+ if loss .item () < stop_loss :
157+ print (f"Stopping early at epoch { epoch + 1 } due to reaching loss { loss .item ()} < { stop_loss } " )
147158 break
148159
149- if ((epoch == 1 ) or ((epoch + 1 ) % 50 == 0 )):
150- print (f'Epoch [{ epoch + 1 } /{ epochs } ], Loss: { loss .item ():.14f} ' )
160+ # Print progress every 50 epochs
161+ if epoch == 1 or (epoch + 1 ) % 50 == 0 :
162+ print (f'Epoch [{ epoch + 1 } /{ epochs } ], Loss: { loss .item ():.14f} , Best Loss: { best_loss :.14f} ' )
163+
164+ # Restore the best model state if training didn't reach convergence
165+ if best_model_state :
166+ model .load_state_dict (best_model_state )
167+ print (f"Restored best model with loss: { best_loss :.14f} " )
168+
169+ # Save the best model to file
170+ torch .save (best_model_state , "best_rbf_model.pth" )
171+ print ("Best model saved as 'best_rbf_model.pth'." )
151172
152173 # Generate validation data
153174 num_points_val = 100
@@ -167,19 +188,24 @@ def main(num_points):
167188 psi_pred = pred .reshape (X_val .shape )
168189
169190 # Visualize actual and predicted stream functions
170- visualize_psi (X_val , Y_val , psi_val , centers , title = "Actual Stream Function" )
171- visualize_psi (X_val , Y_val , psi_pred , centers , title = "Predicted Stream Function" )
191+ visualize_psi (X_val , Y_val , psi_val , rbf_type ,
192+ centers , title = "Actual Stream Function" )
193+ visualize_psi (X_val , Y_val , psi_pred , rbf_type ,
194+ centers , title = "Predicted Stream Function" )
172195
173196 err_val = np .abs (psi_pred - psi_val ) / np .max (psi_val )
174- visualize_psi (X_val , Y_val , err_val , centers ,
175- title = "Stream Function Relative Approximation Error" )
197+ visualize_psi (X_val , Y_val , err_val , rbf_type , centers ,
198+ title = "Stream Function Relative Error" )
176199
177200 # Define the filename
178201 csv_filename = "stream_function_validation.csv"
179202
180203 # Define the header and the values to be appended
181- header = ["num_points" , "point_dist" , "r_max" , "err_mean" , "err_max" ]
182- data = [num_points , 1.0 / num_points , r_max , np .mean (err_val ), np .max (err_val )]
204+ header = ["model_rbf_type" , "num_points" , "support_radius" ,
205+ "point_dist" , "err_mean" , "err_max" ]
206+
207+ data = [model .rbf_type , num_points , r_max , 1.0 / num_points ,
208+ np .mean (err_val ), np .max (err_val )]
183209
184210 # Check if file exists
185211 file_exists = os .path .isfile (csv_filename )
@@ -199,18 +225,19 @@ def main(num_points):
199225
200226 # Compute and visualize actual and predicted velocity fields
201227 u_val , v_val = compute_velocity (X_val , Y_val , psi_val )
202- visualize_velocity_field (X_val , Y_val , u_val , v_val , num_points ,
203- title = "Actual Velocity Field" )
228+ visualize_velocity_field (X_val , Y_val , u_val , v_val , rbf_type , num_points ,
229+ title = "Velocity Field" )
204230
205231 u_pred , v_pred = compute_velocity (X_val , Y_val , psi_pred )
206- visualize_velocity_field (X_val , Y_val , u_pred , v_pred , num_points ,
232+ visualize_velocity_field (X_val , Y_val , u_pred , v_pred , rbf_type , num_points ,
207233 title = "Predicted Velocity Field" )
208234
209235if __name__ == "__main__" :
210236
211- # Run mesh convergence study
212- for num_points in [4 ,8 ,16 ,32 ]:
213- main (num_points )
237+ # Run the parameter study
238+ for rbf_type in ["gaussian" ]:
239+ for num_points in [4 ,8 ,16 ,32 ]:
240+ main (num_points , rbf_type )
214241
215242 # Estimate convergence order
216243 estimate_convergence_order ("stream_function_validation.csv" )
0 commit comments