Skip to content

Commit e03d3b0

Browse files
Github action: auto-update.
1 parent ef0edc0 commit e03d3b0

File tree

61 files changed

+4557
-231
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+4557
-231
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.

dev/_downloads/32ba71ed3a8fca9c71321516055525f0/plot_DISCO_convolutions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,14 @@
9696

9797
# %%
9898
# Initialize the convolution and set the weights to something resembling an edge filter/finit differences
99-
conv = DiscreteContinuousConv2d(1, 1, grid_in=grid_in, grid_out=grid_out, quad_weights=q_in, kernel_shape=[2,4], radius_cutoff=5/nyo, periodic=False).float()
99+
conv = DiscreteContinuousConv2d(1, 1, grid_in=grid_in, grid_out=grid_out, quadrature_weights=q_in, kernel_shape=[2,4], radius_cutoff=5/nyo, periodic=False).float()
100100

101101
# initialize a kernel resembling an edge filter
102102
w = torch.zeros_like(conv.weight)
103103
w[0,0,1] = 1.0
104104
w[0,0,3] = -1.0
105105
conv.weight = nn.Parameter(w)
106-
psi = conv.get_psi()
106+
psi = conv.get_local_filter_matrix()
107107

108108
# %% apply the DISCO convolution to the data and plot it
109109
# in order to compute the convolved image, we need to first bring it into the right shape with `batch_size x n_channels x n_grid_points`
@@ -145,7 +145,7 @@
145145
# %%
146146

147147
plt.figure(figsize=(4,6), )
148-
plt.imshow(conv_equi.get_psi()[0].detach(), cmap=cmap)
148+
plt.imshow(conv_equi.get_local_filter_matrix()[0].detach(), cmap=cmap)
149149
plt.colorbar()
150150

151151
# # %%
@@ -157,7 +157,7 @@
157157
# plt.show()
158158

159159
# %% test the transpose convolution
160-
convt = DiscreteContinuousConvTranspose2d(1, 1, grid_in=grid_out, grid_out=grid_in, quad_weights=q_out, kernel_shape=[2,4], radius_cutoff=3/nyo, periodic=False).float()
160+
convt = DiscreteContinuousConvTranspose2d(1, 1, grid_in=grid_out, grid_out=grid_in, quadrature_weights=q_out, kernel_shape=[2,4], radius_cutoff=3/nyo, periodic=False).float()
161161

162162
# initialize a flat
163163
w = torch.zeros_like(conv.weight)
Binary file not shown.

dev/_downloads/38d20cb4d1344392d40c819f90342451/plot_DISCO_convolutions.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@
130130
},
131131
"outputs": [],
132132
"source": [
133-
"conv = DiscreteContinuousConv2d(1, 1, grid_in=grid_in, grid_out=grid_out, quad_weights=q_in, kernel_shape=[2,4], radius_cutoff=5/nyo, periodic=False).float()\n\n# initialize a kernel resembling an edge filter\nw = torch.zeros_like(conv.weight)\nw[0,0,1] = 1.0\nw[0,0,3] = -1.0\nconv.weight = nn.Parameter(w)\npsi = conv.get_psi()"
133+
"conv = DiscreteContinuousConv2d(1, 1, grid_in=grid_in, grid_out=grid_out, quadrature_weights=q_in, kernel_shape=[2,4], radius_cutoff=5/nyo, periodic=False).float()\n\n# initialize a kernel resembling an edge filter\nw = torch.zeros_like(conv.weight)\nw[0,0,1] = 1.0\nw[0,0,3] = -1.0\nconv.weight = nn.Parameter(w)\npsi = conv.get_local_filter_matrix()"
134134
]
135135
},
136136
{
@@ -170,7 +170,7 @@
170170
},
171171
"outputs": [],
172172
"source": [
173-
"plt.figure(figsize=(4,6), )\nplt.imshow(conv_equi.get_psi()[0].detach(), cmap=cmap)\nplt.colorbar()\n\n# # %%\n\n# print(\"plt the error:\")\n# plt.figure(figsize=(4,6), )\n# plt.imshow(out1 - out2, cmap=cmap)\n# plt.colorbar()\n# plt.show()"
173+
"plt.figure(figsize=(4,6), )\nplt.imshow(conv_equi.get_local_filter_matrix()[0].detach(), cmap=cmap)\nplt.colorbar()\n\n# # %%\n\n# print(\"plt the error:\")\n# plt.figure(figsize=(4,6), )\n# plt.imshow(out1 - out2, cmap=cmap)\n# plt.colorbar()\n# plt.show()"
174174
]
175175
},
176176
{
@@ -181,7 +181,7 @@
181181
},
182182
"outputs": [],
183183
"source": [
184-
"convt = DiscreteContinuousConvTranspose2d(1, 1, grid_in=grid_out, grid_out=grid_in, quad_weights=q_out, kernel_shape=[2,4], radius_cutoff=3/nyo, periodic=False).float()\n\n# initialize a flat\nw = torch.zeros_like(conv.weight)\nw[0,0,0] = 1.0\nw[0,0,1] = 1.0\nw[0,0,2] = 1.0\nw[0,0,3] = 1.0\nconvt.weight = nn.Parameter(w)\n\ndata = nn.functional.interpolate(torch.from_numpy(img).unsqueeze(0).unsqueeze(0), size=(ny,nx)).squeeze().float().permute(1,0).flip(1).reshape(-1)\nout = convt(data.reshape(1, 1, -1))\n\nprint(out.shape)\n\nplt.figure(figsize=(4,6), )\nplt.imshow(torch.flip(out.squeeze().detach().reshape(nx, ny).transpose(0,1), dims=(-2, )), cmap=cmap)\nplt.colorbar()\nplt.show()"
184+
"convt = DiscreteContinuousConvTranspose2d(1, 1, grid_in=grid_out, grid_out=grid_in, quadrature_weights=q_out, kernel_shape=[2,4], radius_cutoff=3/nyo, periodic=False).float()\n\n# initialize a flat\nw = torch.zeros_like(conv.weight)\nw[0,0,0] = 1.0\nw[0,0,1] = 1.0\nw[0,0,2] = 1.0\nw[0,0,3] = 1.0\nconvt.weight = nn.Parameter(w)\n\ndata = nn.functional.interpolate(torch.from_numpy(img).unsqueeze(0).unsqueeze(0), size=(ny,nx)).squeeze().float().permute(1,0).flip(1).reshape(-1)\nout = convt(data.reshape(1, 1, -1))\n\nprint(out.shape)\n\nplt.figure(figsize=(4,6), )\nplt.imshow(torch.flip(out.squeeze().detach().reshape(nx, ny).transpose(0,1), dims=(-2, )), cmap=cmap)\nplt.colorbar()\nplt.show()"
185185
]
186186
},
187187
{
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)