33import torch
44import torch .nn as nn
55import torch .optim as optim
6+ import csv
7+ import os
68
79from rbf_network import WendlandLinearNetwork
810
@@ -21,7 +23,7 @@ def compute_velocity(x, y, psi_values):
2123 v = - dx # v = -∂ψ/∂x
2224 return u , v
2325
24- def visualize_psi (x , y , psi_values , title = "Stream Function" , centers = None ):
26+ def visualize_psi (x , y , psi_values , centers , title = "Stream Function" ):
2527 plt .figure (figsize = (6 , 6 ))
2628 plt .contourf (x , y , psi_values , levels = 20 , cmap = 'viridis' )
2729 plt .colorbar (label = 'ψ' )
@@ -30,24 +32,23 @@ def visualize_psi(x, y, psi_values, title="Stream Function", centers=None):
3032 plt .ylabel ('y' )
3133 plt .grid ()
3234
33- # Plot centers if provided
34- if centers is not None :
35- centers_np = centers .numpy () # Convert from torch tensor to numpy
36- plt .scatter (centers_np [:, 0 ], centers_np [:, 1 ], color = 'white' , marker = 'x' , s = 100 , linewidths = 2 , label = 'Centers' )
37- plt .legend ()
35+ # Plot centers
36+ centers_np = centers .numpy () # Convert from torch tensor to numpy
37+ plt .scatter (centers_np [:, 0 ], centers_np [:, 1 ], color = 'white' , marker = 'x' , s = 100 , linewidths = 2 , label = 'Centers' )
38+ plt .legend ()
3839
3940 fig_name = title .replace (" " , "-" )
40- plt .savefig (f"{ fig_name } .png" , dpi = 200 )
41+ plt .savefig (f"{ fig_name } -num_centers- { len ( centers ) } .png" , dpi = 200 )
4142
42- def visualize_velocity_field (x , y , u , v , title = "Velocity Field" ):
43+ def visualize_velocity_field (x , y , u , v , num_centers , title = "Velocity Field" ):
4344 plt .figure (figsize = (6 , 6 ))
4445 plt .quiver (x , y , u , v )
4546 plt .title (title )
4647 plt .xlabel ('x' )
4748 plt .ylabel ('y' )
4849 plt .grid ()
4950 fig_name = title .replace (" " , "-" )
50- plt .savefig (f"{ fig_name } .png" , dpi = 200 )
51+ plt .savefig (f"{ fig_name } -num_centers { num_centers } .png" , dpi = 200 )
5152
5253def generate_centers (num_centers ):
5354 """
@@ -59,9 +60,8 @@ def generate_centers(num_centers):
5960 centers = np .vstack ([X .ravel (), Y .ravel ()]).T
6061 return torch .tensor (centers , dtype = torch .float32 )
6162
62- def main ():
63+ def main (num_points ):
6364 # Generate training data
64- num_points = 10
6565 x = np .linspace (0 , 1 , num_points )
6666 y = np .linspace (0 , 1 , num_points )
6767 X , Y = np .meshgrid (x , y )
@@ -76,7 +76,7 @@ def main():
7676 # centers = generate_centers(32).clone().detach()
7777 centers = x_train
7878 print (centers .shape )
79- r_max = 0.1
79+ r_max = 3.0 / num_points
8080 smoothness = 4 # C^4 smoothness
8181
8282 # Initialize model
@@ -104,9 +104,9 @@ def main():
104104 print (f'Epoch [{ epoch + 1 } /{ epochs } ], Loss: { loss .item ():.14f} ' )
105105
106106 # Generate validation data
107- num_points = 100
108- x_val = np .linspace (0 , 1 , num_points )
109- y_val = np .linspace (0 , 1 , num_points )
107+ num_points_val = 100
108+ x_val = np .linspace (0 , 1 , num_points_val )
109+ y_val = np .linspace (0 , 1 , num_points_val )
110110 X_val , Y_val = np .meshgrid (x_val , y_val )
111111 xy_val = np .column_stack ((X_val .flatten (), Y_val .flatten ()))
112112 psi_val = psi (X_val , Y_val )
@@ -122,19 +122,46 @@ def main():
122122 #psi_actual = psi(X_val, Y_val)
123123
124124 # Visualize actual and predicted stream functions
125- visualize_psi (X_val , Y_val , psi_val , title = "Actual Stream Function" )
126- visualize_psi (X_val , Y_val , psi_pred , title = "Predicted Stream Function" )
125+ visualize_psi (X_val , Y_val , psi_val , centers , title = "Actual Stream Function" )
126+ visualize_psi (X_val , Y_val , psi_pred , centers , title = "Predicted Stream Function" )
127127
128- visualize_psi (X_val , Y_val , np .abs (psi_pred - psi_val ),
129- title = "Stream Function Approximation Error" ,
130- centers = centers )
128+ err_val = np .abs (psi_pred - psi_val )
129+ visualize_psi (X_val , Y_val , err_val , centers ,
130+ title = "Stream Function Approximation Error" )
131+
132+ # Define the filename
133+ csv_filename = "stream_function_validation.csv"
134+
135+ # Define the header and the values to be appended
136+ header = ["num_points" , "point_dist" , "r_max" , "err_validation" ]
137+ data = [num_points , 1.0 / num_points , r_max , np .mean (err_val )]
138+
139+ # Check if file exists
140+ file_exists = os .path .isfile (csv_filename )
141+
142+ # Open file in append mode
143+ with open (csv_filename , mode = 'a' , newline = '' ) as file :
144+ writer = csv .writer (file )
145+
146+ # If the file doesn't exist, write the header first
147+ if not file_exists :
148+ writer .writerow (header )
149+
150+ # Append the data row
151+ writer .writerow (data )
152+
153+ print (f"Appended to { csv_filename } : { data } " )
131154
132155 # Compute and visualize actual and predicted velocity fields
133156 u_val , v_val = compute_velocity (X_val , Y_val , psi_val )
134- visualize_velocity_field (X_val , Y_val , u_val , v_val , title = "Actual Velocity Field" )
157+ visualize_velocity_field (X_val , Y_val , u_val , v_val , num_points ,
158+ title = "Actual Velocity Field" )
135159
136160 u_pred , v_pred = compute_velocity (X_val , Y_val , psi_pred )
137- visualize_velocity_field (X_val , Y_val , u_pred , v_pred , title = "Predicted Velocity Field" )
161+ visualize_velocity_field (X_val , Y_val , u_pred , v_pred , num_points ,
162+ title = "Predicted Velocity Field" )
138163
139164if __name__ == "__main__" :
140- main ()
165+ main (num_points = 4 )
166+ main (num_points = 8 )
167+ main (num_points = 16 )
0 commit comments