Skip to content

Commit 703f0c9

Browse files
Github action: auto-update.
1 parent 76d34d7 commit 703f0c9

File tree

95 files changed

+2541
-441
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

95 files changed

+2541
-441
lines changed
Binary file not shown.
Binary file not shown.
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""
2+
Resampling layers
3+
=================
4+
5+
When working with neural operators, we often need to change the resolution of our data.
6+
For some architectures, like the FNO, this is handled automatically due to the
7+
resolution-invariant nature of the Fourier domain.
8+
9+
However, for other architectures, like the U-Net, we need to explicitly upsample and downsample
10+
the data as it flows through the network. The ``neuralop.layers.resample`` function provides a
11+
convenient way to do this.
12+
13+
In this example, we'll demonstrate how to use the ``resample`` function to upsample and downsample
14+
a sample from a Gaussian Random Field, which serves as a better visual tool than piecewise
15+
constant data for observing the effects of interpolation.
16+
17+
For 1D and 2D inputs, the ``resample`` function uses PyTorch’s built-in spatial interpolators
18+
for efficiency, applying linear interpolation for 1D data and bicubic interpolation for 2D data directly
19+
in the spatial domain.
20+
21+
For 3D or higher-dimensional inputs, the ``resample`` function switches to a spectral interpolation method
22+
based on the Fourier transform. The input is transformed into the frequency domain using a real n-dimensional FFT,
23+
which decomposes the signal into its frequency components. By resizing this frequency representation and
24+
then applying an inverse FFT, the function achieves smooth, alias-free interpolation
25+
that preserves the signal’s overall structure.
26+
"""
27+
import torch
28+
import matplotlib.pyplot as plt
29+
from neuralop.layers.resample import resample
30+
31+
# %%
32+
# First, let's generate a data input. We create a high-resolution Gaussian Random Field (GRF), which
33+
# is a smooth, continuous signal, making it ideal for visualizing the effects of resampling.
34+
device = 'cpu'
35+
36+
def generate_grf(shape, alpha=2.5, device='cpu'):
37+
"""Generates a 2D Gaussian Random Field.
38+
39+
Parameters
40+
----------
41+
shape : tuple
42+
The desired output shape (height, width).
43+
alpha : float, optional
44+
A parameter controlling the smoothness of the field.
45+
Higher alpha leads to smoother fields, by default 2.5.
46+
device : str, optional
47+
The device to create the tensor on, by default 'cpu'.
48+
49+
Returns
50+
-------
51+
torch.Tensor
52+
A 4D tensor of shape (1, 1, height, width) containing the GRF.
53+
"""
54+
n, m = shape
55+
freq_x = torch.fft.fftfreq(n, d=1/n, device=device).view(-1, 1)
56+
freq_y = torch.fft.fftfreq(m, d=1/m, device=device).view(1, -1)
57+
58+
norm_sq = freq_x**2 + freq_y**2
59+
norm_sq[0, 0] = 1.0 # Avoid division by zero
60+
61+
# Generate white noise in frequency domain
62+
noise = torch.randn(n, m, dtype=torch.cfloat, device=device)
63+
64+
# Apply a power-law filter
65+
filtered_noise = noise * (norm_sq**(-alpha/2.0))
66+
67+
# Inverse FFT to get the spatial field
68+
field = torch.fft.ifft2(filtered_noise).real
69+
70+
# Normalize to [0, 1] for visualization
71+
field = (field - field.min()) / (field.max() - field.min())
72+
73+
return field.unsqueeze(0).unsqueeze(0) # Add batch and channel dims
74+
75+
# Generate a 128x128 sample as our ground truth
76+
high_res = 128
77+
high_res_data = generate_grf((high_res, high_res), device=device)
78+
79+
# Define the low resolution we want to simulate (4x downsampling)
80+
low_res = 32
81+
82+
# %%
83+
# Now, let's use the ``resample`` function to simulate downsampling and upsampling operations.
84+
# This could for instance be used in the encoder and decoder of a U-Net architecture.
85+
# The function takes an input tensor, a `scale_factor`, and a list of
86+
# `axis` dimensions to which the resampling is applied.
87+
88+
# To downsample from 128x128 to 32x32, we need a scale factor of 32/128 = 0.25
89+
downsample_factor = low_res / high_res
90+
downsampled_data = resample(high_res_data, downsample_factor, [2, 3])
91+
92+
# To upsample from 32x32 back to 128x128, we need a scale factor of 128/32 = 4
93+
upsample_factor = high_res / low_res
94+
upsampled_data = resample(downsampled_data, upsample_factor, [2, 3])
95+
96+
97+
# %%
98+
# Finally, let's visualize the results to see the effect of the ``resample`` function.
99+
100+
fig, axs = plt.subplots(1, 3, figsize=(14, 6))
101+
plt.subplots_adjust(wspace=0.04)
102+
fig.suptitle('Resampling a Gaussian Random Field', fontsize=24)
103+
104+
# Plot the original high-resolution data
105+
im1 = axs[0].imshow(high_res_data.squeeze().cpu().numpy(), cmap='viridis', vmin=0, vmax=1)
106+
axs[0].set_title(f'High-Res Data ({high_res}x{high_res})', fontsize=16, fontweight='bold')
107+
cbar1 = fig.colorbar(im1, ax=axs[0], fraction=0.046, pad=0.04, ticks=[0, 0.5, 1])
108+
cbar1.ax.tick_params(labelsize=14)
109+
110+
# Plot the downsampled data
111+
im2 = axs[1].imshow(downsampled_data.squeeze().cpu().numpy(), cmap='viridis', vmin=0, vmax=1)
112+
axs[1].set_title(f'Downsampled (x{downsample_factor}) ({low_res}x{low_res})', fontsize=16, fontweight='bold')
113+
cbar2 = fig.colorbar(im2, ax=axs[1], fraction=0.046, pad=0.04, ticks=[0, 0.5, 1])
114+
cbar2.ax.tick_params(labelsize=14)
115+
116+
# Plot the upsampled data
117+
im3 = axs[2].imshow(upsampled_data.squeeze().cpu().numpy(), cmap='viridis', vmin=0, vmax=1)
118+
axs[2].set_title(f'Upsampled Back (x{upsample_factor:.0f}) ({high_res}x{high_res})', fontsize=16, fontweight='bold')
119+
cbar3 = fig.colorbar(im3, ax=axs[2], fraction=0.046, pad=0.04, ticks=[0, 0.5, 1])
120+
cbar3.ax.tick_params(labelsize=14)
121+
122+
# Hide axis ticks for a cleaner look
123+
for ax in axs.flat:
124+
ax.set_xticks([])
125+
ax.set_yticks([])
126+
127+
plt.tight_layout(rect=[0, 0.03, 1, 1.08])
128+
plt.show()
129+
130+
# %%
131+
# The ``resample`` function effectively changes the resolution of the data.
132+
# Notice that the upsampled image on the right is a faithful, if slightly blurrier,
133+
# reconstruction of the original. This is because the downsampling step is lossy;
134+
# high-frequency details are lost and cannot be perfectly recovered.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)