33
44Fourier Differentiation
55========================================================
6- An example of usage of our Fourier Differentiation Function on 1d data.
6+ An example of usage of our Fourier Differentiation Function
77"""
88
99# %%
1313import torch
1414import numpy as np
1515import matplotlib .pyplot as plt
16- from neuralop .losses .fourier_diff import fourier_derivative_1d
16+ from neuralop .losses .differentiation import FourierDiff
1717
1818device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
1919
2020
2121
2222# %%
2323# Creating an example of periodic 1D curve
24- # --------------------
25- # Here we consider sin(x) and cos(x), which are periodic on the interval [0,2pi ]
26- L = 2 * torch .pi
24+ # ----------------------------------------
25+ # Here we consider sin(x) and cos(x), which are periodic on the interval [0, 2π ]
26+ L = 2 * torch .pi
2727x = torch .linspace (0 , L , 101 )[:- 1 ]
2828f = torch .stack ([torch .sin (x ), torch .cos (x )], dim = 0 )
2929x_np = x .cpu ().numpy ()
3030
3131# %%
3232# Differentiate the signal
3333# -----------------------------------------
34- # We use the Fourier differentiation function to differentiate the signal
35- dfdx = fourier_derivative_1d ( f , order = 1 , L = L )
36- df2dx2 = fourier_derivative_1d (f , order = 2 , L = L )
37- df3dx3 = fourier_derivative_1d ( f , order = 3 , L = L )
34+ # We use the FourierDiff class to differentiate the signal
35+ fd1d = FourierDiff ( dim = 1 , L = L , use_fc = False )
36+ derivatives = fd1d . compute_multiple_derivatives (f , [ 1 , 2 , 3 ] )
37+ dfdx , df2dx2 , df3dx3 = derivatives
3838
3939
4040# %%
4141# Plot the results for sin(x)
42- # ----------------------
42+ # ---------------------------
4343plt .figure ()
4444plt .plot (x_np , dfdx [0 ].squeeze ().cpu ().numpy (), label = 'Fourier dfdx' )
4545plt .plot (x_np , np .cos (x_np ), '--' , label = 'dfdx' )
5353
5454# %%
5555# Plot the results for cos(x)
56- # ----------------------
56+ # ---------------------------
5757plt .figure ()
5858plt .plot (x_np , dfdx [1 ].squeeze ().cpu ().numpy (), label = 'Fourier dfdx' )
5959plt .plot (x_np , - np .sin (x_np ), '--' , label = 'dfdx' )
6969
7070# %%
7171# Creating an example of non-periodic 1D curve
72- # --------------------
73- # Here we consider sin(16x )-cos(8x ) and exp(-0.8x)+sin(x)
74- L = 2 * torch .pi
75- x = torch .linspace (0 , L , 101 )[:- 1 ]
76- f = torch .stack ([torch .sin (3 * x ) - torch .cos (x ), torch .exp (- 0.8 * x )+ torch .sin (x )], dim = 0 )
72+ # -------------------------------------------
73+ # Here we consider sin(3x )-cos(x ) and exp(-0.8x)+sin(x)
74+ L = 2 * torch .pi
75+ x = torch .linspace (0 , L , 101 )[:- 1 ]
76+ f = torch .stack ([torch .sin (3 * x ) - torch .cos (x ), torch .exp (- 0.8 * x ) + torch .sin (x )], dim = 0 )
7777x_np = x .cpu ().numpy ()
7878
7979# %%
8080# Differentiate the signal
8181# -----------------------------------------
82- # We use the Fourier differentiation function with Fourier continuation to differentiate the signal
83- dfdx = fourier_derivative_1d (f , order = 1 , L = L , use_FC = 'Legendre' , FC_d = 4 , FC_n_additional_pts = 30 )
84- df2dx2 = fourier_derivative_1d (f , order = 2 , L = L , use_FC = 'Legendre' , FC_d = 4 , FC_n_additional_pts = 30 )
82+ # We use the FourierDiff class with Fourier continuation to differentiate the signal
83+ fd1d = FourierDiff (dim = 1 , L = L , use_fc = 'Legendre' , fc_degree = 4 , fc_n_additional_pts = 50 )
84+ derivatives = fd1d .compute_multiple_derivatives (f , [1 , 2 ])
85+ dfdx , df2dx2 = derivatives
8586
8687
8788# %%
88- # Plot the results for sin(16x )-cos(8x )
89- # ----------------------
89+ # Plot the results for sin(3x )-cos(x )
90+ # --------------------------------------
9091plt .figure ()
9192plt .plot (x_np , dfdx [0 ].squeeze ().cpu ().numpy (), label = 'Fourier dfdx' )
9293plt .plot (x_np , 3 * torch .cos (3 * x ) + torch .sin (x ), '--' , label = 'dfdx' )
9899
99100# %%
100101# Plot the results for exp(-0.8x)+sin(x)
101- # ----------------------
102+ # ---------------------------------------
102103plt .figure ()
103104plt .plot (x_np , dfdx [1 ].squeeze ().cpu ().numpy (), label = 'Fourier dfdx' )
104105plt .plot (x_np , - 0.8 * torch .exp (- 0.8 * x )+ torch .cos (x ), '--' , label = 'dfdx' )
105106plt .plot (x_np , df2dx2 [1 ].squeeze ().cpu ().numpy (), label = 'Fourier df2dx2' )
106107plt .plot (x_np , 0.64 * torch .exp (- 0.8 * x )- torch .sin (x ), '--' , label = 'df2dx2' )
107108plt .xlabel ('x' )
108109plt .legend ()
110+ plt .show ()
111+
112+
113+ # %%
114+ # 2D Fourier Differentiation Examples
115+ # ===================================
116+ # Here we demonstrate the FourierDiff class for 2D functions
117+
118+ # %%
119+ # Creating an example of periodic 2D function
120+ # -----------------------------------------
121+ # Here we consider f(x,y) = sin(x) * cos(y), which is periodic on the interval [0, 2π] × [0, 2π]
122+ L_x , L_y = 2 * torch .pi , 2 * torch .pi
123+ nx , ny = 180 , 186
124+ x = torch .linspace (0 , L_x , nx , dtype = torch .float64 )
125+ y = torch .linspace (0 , L_y , ny , dtype = torch .float64 )
126+ X , Y = torch .meshgrid (x , y , indexing = 'ij' )
127+
128+ # Test function: f(x,y) = sin(x) * cos(y)
129+ f_2d = torch .sin (X ) * torch .cos (Y )
130+
131+ # %%
132+ # Differentiate the 2D signal
133+ # -----------------------------------------
134+ # We use the FourierDiff class to compute derivatives
135+ fd2d = FourierDiff (dim = 2 , L = (L_x , L_y ))
136+
137+ # Compute derivatives
138+ df_dx = fd2d .dx (f_2d )
139+ df_dy = fd2d .dy (f_2d )
140+ laplacian = fd2d .laplacian (f_2d )
141+
142+ # Expected analytical results for f(x,y) = sin(x) * cos(y)
143+ df_dx_expected = torch .cos (X ) * torch .cos (Y )
144+ df_dy_expected = - torch .sin (X ) * torch .sin (Y )
145+ laplacian_expected = - 2 * torch .sin (X ) * torch .cos (Y )
146+
147+ # %%
148+ # Plot the 2D results
149+ # ----------------------
150+ fig , axes = plt .subplots (2 , 3 , figsize = (15 , 10 ))
151+ fig .suptitle ('2D Fourier Differentiation Results: f(x,y) = sin(x) * cos(y)' )
152+
153+ # Compute consistent colorbar limits for each derivative pair
154+ df_dx_min = min (df_dx .min ().item (), df_dx_expected .min ().item ())
155+ df_dx_max = max (df_dx .max ().item (), df_dx_expected .max ().item ())
156+ df_dy_min = min (df_dy .min ().item (), df_dy_expected .min ().item ())
157+ df_dy_max = max (df_dy .max ().item (), df_dy_expected .max ().item ())
158+
159+ # Original function
160+ im0 = axes [0 , 0 ].imshow (f_2d .cpu ().numpy ())
161+ axes [0 , 0 ].set_title ('Original: sin(x) * cos(y)' )
162+ plt .colorbar (im0 , ax = axes [0 , 0 ], shrink = 0.57 )
163+
164+ # ∂f/∂x computed
165+ im1 = axes [0 , 1 ].imshow (df_dx .cpu ().numpy (), vmin = df_dx_min , vmax = df_dx_max )
166+ axes [0 , 1 ].set_title ('∂f/∂x (computed)' )
167+ plt .colorbar (im1 , ax = axes [0 , 1 ], shrink = 0.57 )
168+
169+ # ∂f/∂x expected
170+ im2 = axes [0 , 2 ].imshow (df_dx_expected .cpu ().numpy (), vmin = df_dx_min , vmax = df_dx_max )
171+ axes [0 , 2 ].set_title ('∂f/∂x (expected: cos(x) * cos(y))' )
172+ plt .colorbar (im2 , ax = axes [0 , 2 ], shrink = 0.57 )
173+
174+ # ∂f/∂y computed
175+ im3 = axes [1 , 0 ].imshow (df_dy .cpu ().numpy (), vmin = df_dy_min , vmax = df_dy_max )
176+ axes [1 , 0 ].set_title ('∂f/∂y (computed)' )
177+ plt .colorbar (im3 , ax = axes [1 , 0 ], shrink = 0.57 )
178+
179+ # ∂f/∂y expected
180+ im4 = axes [1 , 1 ].imshow (df_dy_expected .cpu ().numpy (), vmin = df_dy_min , vmax = df_dy_max )
181+ axes [1 , 1 ].set_title ('∂f/∂y (expected: -sin(x) * sin(y))' )
182+ plt .colorbar (im4 , ax = axes [1 , 1 ], shrink = 0.57 )
183+
184+ # Laplacian
185+ im5 = axes [1 , 2 ].imshow (laplacian .cpu ().numpy ())
186+ axes [1 , 2 ].set_title ('∇²f (computed)' )
187+ plt .colorbar (im5 , ax = axes [1 , 2 ], shrink = 0.57 )
188+
189+ plt .tight_layout ()
190+ plt .show ()
191+
192+
193+
194+
195+ # %%
196+ # 3D Fourier Differentiation Examples
197+ # ===================================
198+ # Here we demonstrate the FourierDiff class for 3D functions
199+
200+ # %%
201+ # Creating an example of periodic 3D function
202+ # -----------------------------------------
203+ # Here we consider f(x,y,z) = sin(x) * cos(y) * sin(z), which is periodic on [0, 2π]³
204+ L_x , L_y , L_z = 2 * torch .pi , 2 * torch .pi , 2 * torch .pi
205+ nx , ny , nz = 176 , 180 , 192
206+ x = torch .linspace (0 , L_x , nx , dtype = torch .float64 )
207+ y = torch .linspace (0 , L_y , ny , dtype = torch .float64 )
208+ z = torch .linspace (0 , L_z , nz , dtype = torch .float64 )
209+ X , Y , Z = torch .meshgrid (x , y , z , indexing = 'ij' )
210+
211+ # Test function: f(x,y,z) = sin(x) * cos(y) * sin(z)
212+ f_3d = torch .sin (X ) * torch .cos (Y ) * torch .sin (Z )
213+
214+ # Alternative: create tensor directly like in the test
215+ f_3d_alt = torch .randn (nx , ny , nz , dtype = torch .float64 )
216+
217+ # %%
218+ # Differentiate the 3D signal
219+ # -----------------------------------------
220+ # We use the FourierDiff class to compute derivatives
221+ fd3d = FourierDiff (dim = 3 , L = (L_x , L_y , L_z ))
222+
223+ # Compute derivatives
224+ df_dx_3d = fd3d .dx (f_3d )
225+ df_dy_3d = fd3d .dy (f_3d )
226+ df_dz_3d = fd3d .dz (f_3d )
227+ laplacian_3d = fd3d .laplacian (f_3d )
228+
229+ # Expected analytical results for f(x,y,z) = sin(x) * cos(y) * sin(z)
230+ df_dx_expected_3d = torch .cos (X ) * torch .cos (Y ) * torch .sin (Z )
231+ df_dy_expected_3d = - torch .sin (X ) * torch .sin (Y ) * torch .sin (Z )
232+ df_dz_expected_3d = torch .sin (X ) * torch .cos (Y ) * torch .cos (Z )
233+ laplacian_expected_3d = - 3 * torch .sin (X ) * torch .cos (Y ) * torch .sin (Z )
234+
235+ # %%
236+ # Plot a slice of the 3D results (z=0 plane)
237+ # ------------------------------------------
238+ z_slice_idx = nz // 2
239+ fig , axes = plt .subplots (2 , 3 , figsize = (18 , 12 ))
240+ fig .suptitle ('3D Fourier Differentiation Results (z=0 slice): f(x,y,z) = sin(x) * cos(y) * sin(z)' )
241+
242+ # Compute consistent colorbar limits for each derivative pair at the z-slice
243+ df_dx_3d_slice = df_dx_3d [:, :, z_slice_idx ]
244+ df_dx_expected_3d_slice = df_dx_expected_3d [:, :, z_slice_idx ]
245+ df_dy_3d_slice = df_dy_3d [:, :, z_slice_idx ]
246+ df_dy_expected_3d_slice = df_dy_expected_3d [:, :, z_slice_idx ]
247+
248+ df_dx_3d_min = min (df_dx_3d_slice .min ().item (), df_dx_expected_3d_slice .min ().item ())
249+ df_dx_3d_max = max (df_dx_3d_slice .max ().item (), df_dx_expected_3d_slice .max ().item ())
250+ df_dy_3d_min = min (df_dy_3d_slice .min ().item (), df_dy_expected_3d_slice .min ().item ())
251+ df_dy_3d_max = max (df_dy_3d_slice .max ().item (), df_dy_expected_3d_slice .max ().item ())
252+
253+ # Original function slice
254+ im0 = axes [0 , 0 ].imshow (f_3d [:, :, z_slice_idx ].cpu ().numpy ())
255+ axes [0 , 0 ].set_title ('Original: sin(x) * cos(y) * sin(z)' )
256+ plt .colorbar (im0 , ax = axes [0 , 0 ], shrink = 0.57 )
257+
258+ # ∂f/∂x slice
259+ im1 = axes [0 , 1 ].imshow (df_dx_3d_slice .cpu ().numpy (), vmin = df_dx_3d_min , vmax = df_dx_3d_max )
260+ axes [0 , 1 ].set_title ('∂f/∂x (computed)' )
261+ plt .colorbar (im1 , ax = axes [0 , 1 ], shrink = 0.57 )
262+
263+ # ∂f/∂x expected slice
264+ im2 = axes [0 , 2 ].imshow (df_dx_expected_3d_slice .cpu ().numpy (), vmin = df_dx_3d_min , vmax = df_dx_3d_max )
265+ axes [0 , 2 ].set_title ('∂f/∂x (expected: cos(x) * cos(y) * sin(z))' )
266+ plt .colorbar (im2 , ax = axes [0 , 2 ], shrink = 0.57 )
267+
268+ # ∂f/∂y slice
269+ im3 = axes [1 , 0 ].imshow (df_dy_3d_slice .cpu ().numpy (), vmin = df_dy_3d_min , vmax = df_dy_3d_max )
270+ axes [1 , 0 ].set_title ('∂f/∂y (computed)' )
271+ plt .colorbar (im3 , ax = axes [1 , 0 ], shrink = 0.57 )
272+
273+ # ∂f/∂y expected slice
274+ im4 = axes [1 , 1 ].imshow (df_dy_expected_3d_slice .cpu ().numpy (), vmin = df_dy_3d_min , vmax = df_dy_3d_max )
275+ axes [1 , 1 ].set_title ('∂f/∂y (expected: -sin(x) * sin(y) * sin(z))' )
276+ plt .colorbar (im4 , ax = axes [1 , 1 ], shrink = 0.57 )
277+
278+ # ∂f/∂z slice
279+ im5 = axes [1 , 2 ].imshow (df_dz_3d [:, :, z_slice_idx ].cpu ().numpy ())
280+ axes [1 , 2 ].set_title ('∂f/∂z (computed)' )
281+ plt .colorbar (im5 , ax = axes [1 , 2 ], shrink = 0.57 )
282+
283+ plt .tight_layout ()
109284plt .show ()
0 commit comments