Skip to content

Commit 9ae5c8e

Browse files
committed
update PReLU implementation
1 parent d229885 commit 9ae5c8e

File tree

4 files changed

+30
-27
lines changed

4 files changed

+30
-27
lines changed

examples/local_pca_mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def get_subset_mnist(n: int = 1000):
3434
# Plot metric and data
3535
plt.figure()
3636
ran = torch.linspace(-3.0, 3.0, 100)
37-
X, Y = torch.meshgrid([ran, ran], indexing='ij')
37+
X, Y = torch.meshgrid([ran, ran], indexing="ij")
3838
XY = torch.stack((X.flatten(), Y.flatten()), dim=1) # 10000x2
3939
gridM = M.metric(XY) # 10000x2
4040
Mim = gridM.sum(dim=1).reshape((100, 100)).detach().t()
@@ -56,7 +56,7 @@ def get_subset_mnist(n: int = 1000):
5656
# Compute discretized geodesics
5757
plt.figure()
5858
ran2 = torch.linspace(-3.0, 3.0, 133)
59-
X2, Y2 = torch.meshgrid([ran2, ran2], indexing='ij')
59+
X2, Y2 = torch.meshgrid([ran2, ran2], indexing="ij")
6060
XY2 = torch.stack((X2.flatten(), Y2.flatten()), dim=1) # 17689x2 (133 * 133 grid points)
6161
DMim = DM.metric(XY2).log().sum(dim=1).view(133, 133).t()
6262
plt.imshow(DMim, extent=(ran[0], ran[-1], ran[0], ran[-1]), origin="lower")

stochman/discretized_manifold.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
5555

5656
dim = len(grid)
5757
if len(grid) != 2:
58-
raise Exception('Currently we only support 2D grids -- sorry!')
58+
raise Exception("Currently we only support 2D grids -- sorry!")
5959

6060
# Add nodes to graph
6161
xsize, ysize = len(grid[0]), len(grid[1])
@@ -64,7 +64,7 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
6464

6565
point_set = torch.cartesian_prod(
6666
torch.linspace(0, xsize - 1, xsize, dtype=torch.long),
67-
torch.linspace(0, ysize - 1, ysize, dtype=torch.long)
67+
torch.linspace(0, ysize - 1, ysize, dtype=torch.long),
6868
) # (big)x2
6969

7070
point_sets = [] # these will be [N, 2] matrices of index points
@@ -93,8 +93,8 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
9393
t = torch.linspace(0, 1, 2)
9494
for ps, nf in zip(point_sets, neighbour_funcs):
9595
for i in range(ceil(ps.shape[0] / batch_size)):
96-
x = ps[batch_size * i:batch_size * (i + 1), 0]
97-
y = ps[batch_size * i:batch_size * (i + 1), 1]
96+
x = ps[batch_size * i : batch_size * (i + 1), 0]
97+
y = ps[batch_size * i : batch_size * (i + 1), 1]
9898
xn, yn = nf[0](x), nf[1](y)
9999

100100
bs = x.shape[0] # may be different from batch size for the last batch
@@ -132,7 +132,7 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
132132
self.__metric__ = M.view([*self.grid_size, d, d]) # e.g. (xsize)x(ysize)x(d)x(d)
133133

134134
# Compute interpolation weights. We use the mean function of a GP regressor.
135-
mesh = torch.meshgrid(*self.grid, indexing='ij')
135+
mesh = torch.meshgrid(*self.grid, indexing="ij")
136136
grid_points = torch.cat(
137137
[m.unsqueeze(-1) for m in mesh], dim=-1
138138
) # e.g. 100x100x2 a 2D grid with 100 points in each dim
@@ -145,6 +145,7 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
145145
) # (num_grid)x(d²) or (num_grid)x(d)
146146
except:
147147
import warnings
148+
148149
warnings.warn("It appears that your model does not implement a metric.")
149150
# XXX: Down the road, we should be able to estimate the metric from the observed distances
150151

@@ -191,9 +192,9 @@ def _grid_dist2(self, p):
191192
"""
192193

193194
dist2 = torch.zeros(p.shape[0], self.G.number_of_nodes())
194-
mesh = torch.meshgrid(*self.grid, indexing='ij')
195+
mesh = torch.meshgrid(*self.grid, indexing="ij")
195196
for mesh_dim, dim in zip(mesh, range(len(self.grid))):
196-
dist2 += (p[:, dim].view(-1, 1) - mesh_dim.reshape(1, -1))**2
197+
dist2 += (p[:, dim].view(-1, 1) - mesh_dim.reshape(1, -1)) ** 2
197198
return dist2
198199

199200
def _kernel(self, p):
@@ -205,12 +206,12 @@ def _kernel(self, p):
205206
Output:
206207
val: a torch Tensor with the kernel values.
207208
"""
208-
lengthscales = [(g[1] - g[0])**2 for g in self.grid]
209+
lengthscales = [(g[1] - g[0]) ** 2 for g in self.grid]
209210

210211
dist2 = torch.zeros(p.shape[0], self.G.number_of_nodes())
211-
mesh = torch.meshgrid(*self.grid, indexing='ij')
212+
mesh = torch.meshgrid(*self.grid, indexing="ij")
212213
for mesh_dim, dim in zip(mesh, range(len(self.grid))):
213-
dist2 += (p[:, dim].view(-1, 1) - mesh_dim.reshape(1, -1))**2 / lengthscales[dim]
214+
dist2 += (p[:, dim].view(-1, 1) - mesh_dim.reshape(1, -1)) ** 2 / lengthscales[dim]
214215

215216
return torch.exp(-dist2)
216217

@@ -241,9 +242,9 @@ def shortest_path(self, p1, p2):
241242
"""
242243
idx1 = self._grid_point(p1)
243244
idx2 = self._grid_point(p2)
244-
path = nx.shortest_path(self.G, source=idx1, target=idx2, weight='weight') # list with N elements
245+
path = nx.shortest_path(self.G, source=idx1, target=idx2, weight="weight") # list with N elements
245246
# coordinates = self.grid.view(self.grid.shape[0], -1)[:, path] # (dim)xN
246-
mesh = torch.meshgrid(*self.grid, indexing='ij')
247+
mesh = torch.meshgrid(*self.grid, indexing="ij")
247248
raw_coordinates = [m.flatten()[path].view(1, -1) for m in mesh]
248249
coordinates = torch.cat(raw_coordinates, dim=0) # (dim)xN
249250
N = len(path)
@@ -252,7 +253,7 @@ def shortest_path(self, p1, p2):
252253
curve.parameters[:, :] = coordinates[:, 1:-1].t()
253254
dist = 0
254255
for i in range(N - 1):
255-
dist += self.G.edges[path[i], path[i + 1]]['weight']
256+
dist += self.G.edges[path[i], path[i + 1]]["weight"]
256257
return curve, dist
257258

258259
def connecting_geodesic(self, p1, p2, curve=None):
@@ -282,7 +283,7 @@ def connecting_geodesic(self, p1, p2, curve=None):
282283
p2 = p2.unsqueeze(0) # 1xD
283284
B = p1.shape[0]
284285
if p1.shape != p2.shape:
285-
raise NameError('shape mismatch')
286+
raise NameError("shape mismatch")
286287

287288
if curve is None:
288289
curve = CubicSpline(p1, p2)
@@ -295,10 +296,10 @@ def connecting_geodesic(self, p1, p2, curve=None):
295296
idx1 = self._grid_point(p1[b].unsqueeze(0))
296297
idx2 = self._grid_point(p2[b].unsqueeze(0))
297298
path = nx.shortest_path(
298-
self.G, source=idx1, target=idx2, weight='weight'
299+
self.G, source=idx1, target=idx2, weight="weight"
299300
) # list with N elements
300-
weights = [self.G.edges[path[k], path[k + 1]]['weight'] for k in range(len(path) - 1)]
301-
mesh = torch.meshgrid(*self.grid, indexing='ij')
301+
weights = [self.G.edges[path[k], path[k + 1]]["weight"] for k in range(len(path) - 1)]
302+
mesh = torch.meshgrid(*self.grid, indexing="ij")
302303
raw_coordinates = [m.flatten()[path[1:-1]].view(-1, 1) for m in mesh]
303304
coordinates = torch.cat(raw_coordinates, dim=1) # Nx(dim)
304305
t = torch.tensor(weights[:-1], device=device).cumsum(dim=0) / sum(weights)

stochman/nnj.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,14 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
319319
return jac
320320

321321

322+
class PReLU(AbstractActivationJacobian, nn.PReLU):
323+
def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
324+
jac = (val >= 0.0).type(val.dtype) + (val < 0.0).type(val.dtype) * self.weight.reshape(
325+
(1, self.num_parameters) + (1,) * (val.ndim - 2)
326+
)
327+
return jac
328+
329+
322330
class ELU(AbstractActivationJacobian, nn.ELU):
323331
def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
324332
jac = torch.ones_like(val)
@@ -340,13 +348,6 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
340348
return jac
341349

342350

343-
class PReLU(AbstractActivationJacobian, nn.PReLU):
344-
def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
345-
jac = torch.ones_like(val)
346-
jac[x < 0.0] = self.weight
347-
return jac
348-
349-
350351
class LeakyReLU(AbstractActivationJacobian, nn.LeakyReLU):
351352
def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
352353
jac = torch.ones_like(val)

tests/test_nnj.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
5656
(nnj.Sequential(nnj.Linear(_features, 2), nnj.OneMinusX()), _linear_input_shape),
5757
(nnj.Sequential(nnj.Linear(_features, 2), nnj.PReLU()), _linear_input_shape),
5858
(nnj.Sequential(nnj.Linear(_features, 2), nnj.Softmax(dim=-1)), _linear_input_shape),
59+
(nnj.Sequential(nnj.Linear(_features, 2), nnj.PReLU()), _linear_input_shape),
5960
(
6061
nnj.Sequential(nnj.Conv1d(_features, 2, 5), nnj.ConvTranspose1d(2, _features, 5)),
6162
_1d_conv_input_shape,
@@ -145,7 +146,7 @@ def test_jacobians(self, model, input_shape, device, dtype):
145146
input = torch.randn(*input_shape, device=device, dtype=dtype)
146147
_, jac = model(input, jacobian=True)
147148
jacnum = _compare_jacobian(model, input).to(device)
148-
assert torch.isclose(jac, jacnum, atol=1e-4).all(), "jacobians did not match"
149+
assert torch.isclose(jac, jacnum, atol=1e-3).all(), "jacobians did not match"
149150

150151
@pytest.mark.parametrize("return_jac", [True, False])
151152
def test_jac_return(self, model, input_shape, device, return_jac):

0 commit comments

Comments
 (0)