1+ """
2+ Visualization of discrete-continuous convolutions
3+ ==========================================================
4+
5+ In this example, we demonstrate the usage of the discrete-continuous (DISCO) convolutions
6+ used in the localized neural operator framework. These modules can be used on both equidistant
7+ and unstructured grids.
8+ """
9+
10+ # %%
11+ # Preparation
12+
13+
14+ import os
15+ import torch
16+ import torch .nn as nn
17+ import numpy as np
18+ import math
19+ from functools import partial
20+
21+ from matplotlib import image
22+
23+ from torch_harmonics .quadrature import legendre_gauss_weights , lobatto_weights , clenshaw_curtiss_weights
24+
25+ import matplotlib .pyplot as plt
26+
27+ cmap = "inferno"
28+ device = torch .device ('cuda:0' if torch .cuda .is_available () else 'cpu' )
29+
30+ from neuralop .layers .discrete_continuous_convolution import DiscreteContinuousConv2d , DiscreteContinuousConvTranspose2d , EquidistantDiscreteContinuousConv2d , EquidistantDiscreteContinuousConvTranspose2d
31+
32+ # %%
33+ # Let's start by loading an example image
34+ os .system ("curl https://upload.wikimedia.org/wikipedia/commons/thumb/d/d3/Albert_Einstein_Head.jpg/360px-Albert_Einstein_Head.jpg -o ./einstein.jpg" )
35+
36+ nx = 90
37+ ny = 120
38+
39+ img = image .imread ('./einstein.jpg' )
40+ data = nn .functional .interpolate (torch .from_numpy (img ).unsqueeze (0 ).unsqueeze (0 ), size = (ny ,nx )).squeeze ()
41+ plt .imshow (data , cmap = cmap )
42+ plt .show ()
43+
44+ # %%
45+ # Let's create a grid on which the data lives
46+
47+ x_in = torch .linspace (0 , 2 , nx )
48+ y_in = torch .linspace (0 , 3 , ny )
49+
50+ x_in , y_in = torch .meshgrid (x_in , y_in )
51+ grid_in = torch .stack ([x_in .reshape (- 1 ), y_in .reshape (- 1 )])
52+
53+ # compute the correct quadrature weights
54+ # IMPORTANT: this needs to be done right in order for the DISCO convolution to be normalized proeperly
55+ w_x = 2 * torch .ones_like (x_in ) / nx
56+ w_y = 3 * torch .ones_like (y_in ) / ny
57+ q_in = (w_x * w_y ).reshape (- 1 )
58+
59+ # %%
60+ # Visualize the grid
61+
62+ plt .figure (figsize = (4 ,6 ), )
63+ plt .scatter (grid_in [0 ], grid_in [1 ], s = 0.2 )
64+ plt .xlim (0 ,2 )
65+ plt .ylim (0 ,3 )
66+ plt .show ()
67+
68+
69+ # %%
70+ # Format data into the same format and plot it on the grid
71+
72+ data = data .permute (1 ,0 ).flip (1 ).reshape (- 1 )
73+
74+ plt .figure (figsize = (4 ,6 ), )
75+ plt .tripcolor (grid_in [0 ], grid_in [1 ], data , cmap = cmap )
76+ # plt.colorbar()
77+ plt .xlim (0 ,2 )
78+ plt .ylim (0 ,3 )
79+ plt .show ()
80+
81+ # %%
82+ # For the convolution output we require an output mesh
83+ nxo = 90
84+ nyo = 120
85+
86+ x_out = torch .linspace (0 , 2 , nxo )
87+ y_out = torch .linspace (0 , 3 , nyo )
88+
89+ x_out , y_out = torch .meshgrid (x_out , y_out )
90+ grid_out = torch .stack ([x_out .reshape (- 1 ), y_out .reshape (- 1 )])
91+
92+ # compute the correct quadrature weights
93+ w_x = 2 * torch .ones_like (x_out ) / nxo
94+ w_y = 3 * torch .ones_like (y_out ) / nyo
95+ q_out = (w_x * w_y ).reshape (- 1 )
96+
97+ # %%
98+ # Initialize the convolution and set the weights to something resembling an edge filter/finit differences
99+ conv = DiscreteContinuousConv2d (1 , 1 , grid_in = grid_in , grid_out = grid_out , quad_weights = q_in , kernel_shape = [2 ,4 ], radius_cutoff = 5 / nyo , periodic = False ).float ()
100+
101+ # initialize a kernel resembling an edge filter
102+ w = torch .zeros_like (conv .weight )
103+ w [0 ,0 ,1 ] = 1.0
104+ w [0 ,0 ,3 ] = - 1.0
105+ conv .weight = nn .Parameter (w )
106+ psi = conv .get_psi ()
107+
108+ # %% apply the DISCO convolution to the data and plot it
109+ # in order to compute the convolved image, we need to first bring it into the right shape with `batch_size x n_channels x n_grid_points`
110+ out = conv (data .reshape (1 , 1 , - 1 ))
111+
112+ print (out .shape )
113+
114+ plt .figure (figsize = (4 ,6 ), )
115+ plt .imshow (torch .flip (out .squeeze ().detach ().reshape (nxo , nyo ).transpose (0 ,1 ), dims = (- 2 , )), cmap = cmap )
116+ plt .colorbar ()
117+ plt .show ()
118+
119+ out1 = torch .flip (out .squeeze ().detach ().reshape (nxo , nyo ).transpose (0 ,1 ), dims = (- 2 , ))
120+
121+ # %% do the same but on an equidistant grid:
122+ conv_equi = EquidistantDiscreteContinuousConv2d (1 , 1 , (nx , ny ), (nxo , nyo ), kernel_shape = [2 ,4 ], radius_cutoff = 5 / nyo , domain_length = [2 ,3 ])
123+
124+ # initialize a kernel resembling an edge filter
125+ w = torch .zeros_like (conv .weight )
126+ w [0 ,0 ,1 ] = 1.0
127+ w [0 ,0 ,3 ] = - 1.0
128+ conv_equi .weight = nn .Parameter (w )
129+
130+ data = nn .functional .interpolate (torch .from_numpy (img ).unsqueeze (0 ).unsqueeze (0 ), size = (ny ,nx )).float ()
131+
132+ out_equi = conv_equi (data )
133+
134+ print (out_equi .shape )
135+
136+ plt .figure (figsize = (4 ,6 ), )
137+ plt .imshow (out_equi .squeeze ().detach (), cmap = cmap )
138+ plt .colorbar ()
139+ plt .show ()
140+
141+ out2 = out_equi .squeeze ().detach ()
142+
143+ print (out2 .shape )
144+
145+ # %%
146+
147+ plt .figure (figsize = (4 ,6 ), )
148+ plt .imshow (conv_equi .get_psi ()[0 ].detach (), cmap = cmap )
149+ plt .colorbar ()
150+
151+ # # %%
152+
153+ # print("plt the error:")
154+ # plt.figure(figsize=(4,6), )
155+ # plt.imshow(out1 - out2, cmap=cmap)
156+ # plt.colorbar()
157+ # plt.show()
158+
159+ # %% test the transpose convolution
160+ convt = DiscreteContinuousConvTranspose2d (1 , 1 , grid_in = grid_out , grid_out = grid_in , quad_weights = q_out , kernel_shape = [2 ,4 ], radius_cutoff = 3 / nyo , periodic = False ).float ()
161+
162+ # initialize a flat
163+ w = torch .zeros_like (conv .weight )
164+ w [0 ,0 ,0 ] = 1.0
165+ w [0 ,0 ,1 ] = 1.0
166+ w [0 ,0 ,2 ] = 1.0
167+ w [0 ,0 ,3 ] = 1.0
168+ convt .weight = nn .Parameter (w )
169+
170+ data = nn .functional .interpolate (torch .from_numpy (img ).unsqueeze (0 ).unsqueeze (0 ), size = (ny ,nx )).squeeze ().float ().permute (1 ,0 ).flip (1 ).reshape (- 1 )
171+ out = convt (data .reshape (1 , 1 , - 1 ))
172+
173+ print (out .shape )
174+
175+ plt .figure (figsize = (4 ,6 ), )
176+ plt .imshow (torch .flip (out .squeeze ().detach ().reshape (nx , ny ).transpose (0 ,1 ), dims = (- 2 , )), cmap = cmap )
177+ plt .colorbar ()
178+ plt .show ()
179+
180+
181+
182+ # %% test the equidistant transpose convolution
183+ convt_equi = EquidistantDiscreteContinuousConvTranspose2d (1 , 1 , (nxo , nyo ), (nx , ny ), kernel_shape = [2 ,4 ], radius_cutoff = 3 / nyo , domain_length = [2 ,3 ])
184+
185+ # initialize a flat
186+ w = torch .zeros_like (convt_equi .weight )
187+ w [0 ,0 ,0 ] = 1.0
188+ w [0 ,0 ,1 ] = 1.0
189+ w [0 ,0 ,2 ] = 1.0
190+ w [0 ,0 ,3 ] = 1.0
191+ convt_equi .weight = nn .Parameter (w )
192+
193+ data = nn .functional .interpolate (torch .from_numpy (img ).unsqueeze (0 ).unsqueeze (0 ), size = (nyo ,nxo )).float ()
194+ out_equi = convt_equi (data )
195+
196+ print (out_equi .shape )
197+
198+ plt .figure (figsize = (4 ,6 ), )
199+ plt .imshow (out_equi .squeeze ().detach (), cmap = cmap )
200+ plt .colorbar ()
201+ plt .show ()
0 commit comments