
Commit 3caafff

Github action: auto-update.
1 parent e9cd933 commit 3caafff

51 files changed: +1580 -175 lines changed
Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
"""
Visualization of discrete-continuous convolutions
==================================================

In this example, we demonstrate the usage of the discrete-continuous (DISCO) convolutions
used in the localized neural operator framework. These modules can be used on both equidistant
and unstructured grids.
"""

# %%
# Preparation

import os
import torch
import torch.nn as nn
import numpy as np
import math
from functools import partial

from matplotlib import image
import matplotlib.pyplot as plt

from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights

from neuralop.layers.discrete_continuous_convolution import (
    DiscreteContinuousConv2d,
    DiscreteContinuousConvTranspose2d,
    EquidistantDiscreteContinuousConv2d,
    EquidistantDiscreteContinuousConvTranspose2d,
)

cmap = "inferno"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# %%
# Let's start by loading an example image
os.system("curl https://upload.wikimedia.org/wikipedia/commons/thumb/d/d3/Albert_Einstein_Head.jpg/360px-Albert_Einstein_Head.jpg -o ./einstein.jpg")

nx = 90
ny = 120

img = image.imread('./einstein.jpg')
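
# (added robustness guard, not in the original script) the example treats the
# downloaded image as single-channel; if your copy decodes with a trailing color
# axis, average the channels so the rest of the script can use it as grayscale
if img.ndim == 3:
    img = img.mean(axis=-1)
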
data = nn.functional.interpolate(torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float(), size=(ny, nx)).squeeze()
plt.imshow(data, cmap=cmap)
plt.show()

# %%
# Let's create a grid on which the data lives

x_in = torch.linspace(0, 2, nx)
y_in = torch.linspace(0, 3, ny)

x_in, y_in = torch.meshgrid(x_in, y_in, indexing="ij")
grid_in = torch.stack([x_in.reshape(-1), y_in.reshape(-1)])

# compute the correct quadrature weights
# IMPORTANT: this needs to be done right in order for the DISCO convolution to be normalized properly
w_x = 2 * torch.ones_like(x_in) / nx
w_y = 3 * torch.ones_like(y_in) / ny
q_in = (w_x * w_y).reshape(-1)
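
# %%
# (added sanity check, not part of the original example) with correct quadrature
# weights, their sum should approximate the area of the [0, 2] x [0, 3] domain, i.e. 6
print(q_in.sum())
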
# %%
# Visualize the grid

plt.figure(figsize=(4, 6))
plt.scatter(grid_in[0], grid_in[1], s=0.2)
plt.xlim(0, 2)
plt.ylim(0, 3)
plt.show()

# %%
# Reorder the image data to match the grid layout and plot it on the grid

data = data.permute(1, 0).flip(1).reshape(-1)

plt.figure(figsize=(4, 6))
plt.tripcolor(grid_in[0], grid_in[1], data, cmap=cmap)
# plt.colorbar()
plt.xlim(0, 2)
plt.ylim(0, 3)
plt.show()
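
# %%
# (added sanity check, not in the original example) invert the permute/flip/flatten
# above to recover the image in its original orientation; this verifies that the
# flattening convention matches the grid layout
img_back = data.reshape(nx, ny).flip(1).permute(1, 0)
plt.imshow(img_back, cmap=cmap)
plt.show()
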
# %%
# For the convolution output we require an output mesh
nxo = 90
nyo = 120

x_out = torch.linspace(0, 2, nxo)
y_out = torch.linspace(0, 3, nyo)

x_out, y_out = torch.meshgrid(x_out, y_out, indexing="ij")
grid_out = torch.stack([x_out.reshape(-1), y_out.reshape(-1)])

# compute the correct quadrature weights
w_x = 2 * torch.ones_like(x_out) / nxo
w_y = 3 * torch.ones_like(y_out) / nyo
q_out = (w_x * w_y).reshape(-1)
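
# %%
# (added note) the output grid need not match the input grid: the constructor below
# takes separate grid_in and grid_out arguments, so choosing e.g. nxo, nyo = 45, 60
# above (and rebuilding grid_out and q_out) would yield a downsampling convolution
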
# %%
# Initialize the convolution and set the weights to something resembling an edge filter/finite differences
conv = DiscreteContinuousConv2d(1, 1, grid_in=grid_in, grid_out=grid_out, quad_weights=q_in, kernel_shape=[2, 4], radius_cutoff=5/nyo, periodic=False).float()

# initialize a kernel resembling an edge filter
w = torch.zeros_like(conv.weight)
w[0, 0, 1] = 1.0
w[0, 0, 3] = -1.0
conv.weight = nn.Parameter(w)
psi = conv.get_psi()
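
# %%
# (added illustration) the convolution is parametrized by a small set of kernel basis
# functions; the trainable weight tensor appears to be indexed as
# [output channel, input channel, basis function], which is why the edge filter above
# only sets the entries w[0, 0, 1] and w[0, 0, 3]
print(conv.weight.shape)
print(psi.shape)
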
# %% apply the DISCO convolution to the data and plot it
# in order to compute the convolved image, we first need to bring it into the right shape
# with `batch_size x n_channels x n_grid_points`
out = conv(data.reshape(1, 1, -1))

print(out.shape)

out1 = torch.flip(out.squeeze().detach().reshape(nxo, nyo).transpose(0, 1), dims=(-2,))

plt.figure(figsize=(4, 6))
plt.imshow(out1, cmap=cmap)
plt.colorbar()
plt.show()

# %% do the same but on an equidistant grid:
conv_equi = EquidistantDiscreteContinuousConv2d(1, 1, (nx, ny), (nxo, nyo), kernel_shape=[2, 4], radius_cutoff=5/nyo, domain_length=[2, 3])

# initialize a kernel resembling an edge filter
w = torch.zeros_like(conv_equi.weight)
w[0, 0, 1] = 1.0
w[0, 0, 3] = -1.0
conv_equi.weight = nn.Parameter(w)

data = nn.functional.interpolate(torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float(), size=(ny, nx))

out_equi = conv_equi(data)

print(out_equi.shape)

plt.figure(figsize=(4, 6))
plt.imshow(out_equi.squeeze().detach(), cmap=cmap)
plt.colorbar()
plt.show()

out2 = out_equi.squeeze().detach()

print(out2.shape)

# %%
# Visualize psi for the equidistant convolution

plt.figure(figsize=(4, 6))
plt.imshow(conv_equi.get_psi()[0].detach(), cmap=cmap)
plt.colorbar()
plt.show()

# # %%

# print("plot the error:")
# plt.figure(figsize=(4, 6))
# plt.imshow(out1 - out2, cmap=cmap)
# plt.colorbar()
# plt.show()

# %% test the transpose convolution
convt = DiscreteContinuousConvTranspose2d(1, 1, grid_in=grid_out, grid_out=grid_in, quad_weights=q_out, kernel_shape=[2, 4], radius_cutoff=3/nyo, periodic=False).float()

# initialize a flat (averaging) kernel
w = torch.zeros_like(convt.weight)
w[0, 0, 0] = 1.0
w[0, 0, 1] = 1.0
w[0, 0, 2] = 1.0
w[0, 0, 3] = 1.0
convt.weight = nn.Parameter(w)

data = nn.functional.interpolate(torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float(), size=(ny, nx)).squeeze().permute(1, 0).flip(1).reshape(-1)
out = convt(data.reshape(1, 1, -1))

print(out.shape)

plt.figure(figsize=(4, 6))
plt.imshow(torch.flip(out.squeeze().detach().reshape(nx, ny).transpose(0, 1), dims=(-2,)), cmap=cmap)
plt.colorbar()
plt.show()

# %% test the equidistant transpose convolution
convt_equi = EquidistantDiscreteContinuousConvTranspose2d(1, 1, (nxo, nyo), (nx, ny), kernel_shape=[2, 4], radius_cutoff=3/nyo, domain_length=[2, 3])

# initialize a flat (averaging) kernel
w = torch.zeros_like(convt_equi.weight)
w[0, 0, 0] = 1.0
w[0, 0, 1] = 1.0
w[0, 0, 2] = 1.0
w[0, 0, 3] = 1.0
convt_equi.weight = nn.Parameter(w)

data = nn.functional.interpolate(torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float(), size=(nyo, nxo))
out_equi = convt_equi(data)

print(out_equi.shape)

plt.figure(figsize=(4, 6))
plt.imshow(out_equi.squeeze().detach(), cmap=cmap)
plt.colorbar()
plt.show()
