Skip to content

Commit 44b045e

Browse files
committed
Add initial implementation of a discretized manifold
1 parent 068f6ab commit 44b045e

File tree

1 file changed

+293
-0
lines changed

1 file changed

+293
-0
lines changed

stochman/discretized_manifold.py

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
#!/usr/bin/env python3
2+
from typing import Optional, Tuple, Union
3+
from math import ceil
4+
5+
import networkx as nx
6+
import torch
7+
8+
from stochman.manifold import Manifold
9+
from stochman.curves import CubicSpline, DiscreteCurve
10+
11+
12+
class DiscretizedManifold(Manifold):
13+
def __init__(self):
14+
self.grid = []
15+
self.grid_size = []
16+
self.G = nx.Graph()
17+
self.__metric__ = torch.Tensor()
18+
self._diagonal_metric = False
19+
self._alpha = torch.Tensor()
20+
21+
22+
def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise=0.0):
23+
"""
24+
Discretize a manifold to a given grid.
25+
26+
Input:
27+
model: a stochman.Manifold that is to be approximed with a graph.
28+
29+
grid: a list of torch.linspace's that defines the grid over which
30+
the manifold will be discretized. For example,
31+
grid = [torch.linspace(-3, 3, 50), torch.linspace(-3, 3, 50)]
32+
will discretize a two-dimensional manifold on a 50x50 grid.
33+
34+
use_diagonals:
35+
If True, diagonal edges are included in the graph, otherwise
36+
they are excluded.
37+
Default: True.
38+
39+
batch_size: Number of edge-lengths that are computed in parallel. The larger
40+
value you pick here, the faster the discretization will be.
41+
However, memory usage increases with this number, so a good
42+
choice is model and hardware specific.
43+
Default: 4.
44+
45+
interpolation_noise:
46+
On fitting, the manifold metric is evalated on the provided grid.
47+
The `metric` function then performs interpolation of this metric,
48+
using the mean of a Gaussian process. The observation noise of
49+
this GP regressor can be tuned through the `interpolation_noise`
50+
argument.
51+
Default: 0.0.
52+
"""
53+
self.grid = grid
54+
self.grid_size = [g.numel() for g in grid]
55+
self.G = nx.Graph()
56+
57+
dim = len(grid)
58+
if len(grid) != 2:
59+
raise Exception('Currently we only support 2D grids -- sorry!')
60+
61+
# Add nodes to graph
62+
xsize, ysize = len(grid[0]), len(grid[1])
63+
node_idx = lambda x, y: x*ysize + y
64+
self.G.add_nodes_from(range(xsize*ysize))
65+
66+
point_set = torch.cartesian_prod(
67+
torch.linspace(0, xsize-1, xsize, dtype=torch.long),
68+
torch.linspace(0, ysize-1, ysize, dtype=torch.long)
69+
) # (big)x2
70+
71+
point_sets = [ ] # these will be [N, 2] matrices of index points
72+
neighbour_funcs = [ ] # these will be functions for getting the neighbour index
73+
74+
# add sets
75+
point_sets.append(point_set[point_set[:, 0] > 0]) # x > 0
76+
neighbour_funcs.append([lambda x: x-1, lambda y: y])
77+
78+
point_sets.append(point_set[point_set[:, 1] > 0]) # y > 0
79+
neighbour_funcs.append([lambda x: x, lambda y: y-1])
80+
81+
point_sets.append(point_set[point_set[:, 0] < xsize-1]) # x < xsize-1
82+
neighbour_funcs.append([lambda x: x+1, lambda y: y])
83+
84+
point_sets.append(point_set[point_set[:, 1] < ysize-1]) # y < ysize-1
85+
neighbour_funcs.append([lambda x: x, lambda y: y+1])
86+
87+
if use_diagonals:
88+
point_sets.append(point_set[torch.logical_and(point_set[:,0] > 0, point_set[:,1] > 0)])
89+
neighbour_funcs.append([lambda x: x-1, lambda y: y-1])
90+
91+
point_sets.append(point_set[torch.logical_and(point_set[:,0] < xsize-1, point_set[:,1] > 0)])
92+
neighbour_funcs.append([lambda x: x+1, lambda y: y-1])
93+
94+
t = torch.linspace(0, 1, 2)
95+
for ps, nf in zip(point_sets, neighbour_funcs):
96+
for i in range(ceil(ps.shape[0] / batch_size)):
97+
x = ps[batch_size*i:batch_size*(i+1), 0]
98+
y = ps[batch_size*i:batch_size*(i+1), 1]
99+
xn = nf[0](x); yn = nf[1](y)
100+
101+
bs = x.shape[0] # may be different from batch size for the last batch
102+
103+
line = CubicSpline(begin=torch.zeros(bs, dim), end=torch.ones(bs, dim), num_nodes=2)
104+
line.begin = torch.cat([grid[0][x].view(-1, 1), grid[1][y].view(-1, 1)], dim=1) # (bs)x2
105+
line.end = torch.cat([grid[0][xn].view(-1, 1), grid[1][yn].view(-1, 1)], dim=1) # (bs)x2
106+
107+
#if external_curve_length_function:
108+
# weight = external_curve_length_function(model, line(t))
109+
#else:
110+
with torch.no_grad():
111+
weight = model.curve_length(line(t))
112+
113+
node_index1 = node_idx(x, y)
114+
node_index2 = node_idx(xn, yn)
115+
116+
for n1, n2, w in zip(node_index1, node_index2, weight):
117+
self.G.add_edge(n1.item(), n2.item(), weight=w.item())
118+
119+
# Evaluate metric at grid
120+
try:
121+
Mlist = []
122+
with torch.no_grad():
123+
for x in range(xsize):
124+
for y in range(ysize):
125+
p = torch.tensor([self.grid[0][x], self.grid[1][y]])
126+
Mlist.append(model.metric(p)) # 1x(d)x(d) or 1x(d)
127+
M = torch.cat(Mlist, dim=0) # (big)x(d)x(d) or (big)x(d)
128+
self._diagonal_metric = M.dim() == 2
129+
d = M.shape[-1]
130+
if self._diagonal_metric:
131+
self.__metric__ = M.view([*self.grid_size, d]) # e.g. (xsize)x(ysize)x(d)
132+
else:
133+
self.__metric__ = M.view([*self.grid_size, d, d]) # e.g. (xsize)x(ysize)x(d)x(d)
134+
135+
# Compute interpolation weights. We use the mean function of a GP regressor.
136+
mesh = torch.meshgrid(*self.grid, indexing='ij')
137+
grid_points = torch.cat([m.unsqueeze(-1) for m in mesh], dim=-1) # e.g. 100x100x2 a 2D grid with 100 points in each dim
138+
K = self._kernel(grid_points.view(-1, len(self.grid))) # (num_grid)x(num_grid)
139+
if interpolation_noise > 0.0:
140+
K += interpolation_noise * torch.eye(K.shape[0])
141+
num_grid = K.shape[0]
142+
self._alpha = torch.linalg.solve(K, self.__metric__.view(num_grid, -1)) # (num_grid)x(d²) or (num_grid)x(d)
143+
except:
144+
import warnings
145+
warnings.warn("It appears that your model does not implement a metric.")
146+
# XXX: Down the road, we should be able to estimate the metric from the observed distances
147+
148+
# def set(self, metric, grid, use_diagonals=True):
149+
# """
150+
# be able to set metric directly from a pre-evaluated grid -- tihs is currently not implemented
151+
# """
152+
# pass
153+
154+
def metric(self, points):
155+
"""
156+
Return the metric tensor at a specified set of points.
157+
158+
Input:
159+
points: a Nx(d) torch Tensor representing a set of
160+
points where the metric tensor is to be
161+
computed.
162+
163+
Output:
164+
M: a Nx(d)x(d) or Nx(d) torch Tensor representing
165+
the metric tensor at the given points.
166+
If M is Nx(d)x(d) then M[i] is a (d)x(d) symmetric
167+
positive definite matrix. If M is Nx(d) then M[i]
168+
is to be interpreted as the diagonal elements of
169+
a (d)x(d) diagonal matrix.
170+
"""
171+
# XXX: We should also support returning the derivative of the metric! (for ODEs; see local_PCA)
172+
K = self._kernel(points) # Nx(num_grid)
173+
M = K.mm(self._alpha) # Nx(d²) or Nx(d)
174+
if not self._diagonal_metric:
175+
d = len(grid)
176+
M = M.view(-1, d, d)
177+
return M
178+
179+
def _grid_dist2(self, p):
180+
"""Return the squared Euclidean distance from a set of points to the grid.
181+
182+
Input:
183+
p: a Nx(d) torch Tensor corresponding to N latent points.
184+
185+
Output:
186+
dist2: a NxM torch Tensor containing all Euclidean distances
187+
to the M grid points.
188+
"""
189+
190+
dist2 = torch.zeros(p.shape[0], self.G.number_of_nodes())
191+
mesh = torch.meshgrid(*self.grid, indexing='ij') # XXX: IT MUST BE POSSIBLE TO AVOID THIS GRID CONSTRUCTION
192+
for mesh_dim, dim in zip(mesh, range(len(self.grid))):
193+
dist2 += (p[:, dim].view(-1, 1) - mesh_dim.reshape(1, -1))**2
194+
return dist2
195+
196+
def _kernel(self, p):
197+
"""Evaluate the interpolation kernel for computing the metric.
198+
199+
Input:
200+
p: a torch Tensor corresponding to a point on the manifold.
201+
202+
Output:
203+
val: a torch Tensor with the kernel values.
204+
"""
205+
lengthscales = [(g[1]-g[0])**2 for g in self.grid]
206+
207+
dist2 = torch.zeros(p.shape[0], self.G.number_of_nodes())
208+
mesh = torch.meshgrid(*self.grid, indexing='ij')
209+
for mesh_dim, dim in zip(mesh, range(len(self.grid))):
210+
dist2 += (p[:, dim].view(-1, 1) - mesh_dim.reshape(1, -1))**2/lengthscales[dim]
211+
212+
return torch.exp(-dist2)
213+
214+
def _grid_point(self, p):
215+
"""Return the index of the nearest grid point.
216+
217+
Input:
218+
p: a torch Tensor corresponding to a latent point.
219+
220+
Output:
221+
idx: an integer correponding to the node index of
222+
the nearest point on the grid.
223+
"""
224+
return self._grid_dist2(p).argmin().item()
225+
226+
def shortest_path(self, p1, p2):
227+
"""Compute the shortest path on the discretized manifold.
228+
229+
Inputs:
230+
p1: a torch Tensor corresponding to one latent point.
231+
232+
p2: a torch Tensor corresponding to another latent point.
233+
234+
Outputs:
235+
curve: a DiscreteCurve forming the shortest path from p1 to p2.
236+
237+
dist: a scalar indicating the length of the shortest curve.
238+
"""
239+
idx1 = self._grid_point(p1)
240+
idx2 = self._grid_point(p2)
241+
path = nx.shortest_path(self.G, source=idx1, target=idx2, weight='weight') # list with N elements
242+
#coordinates = self.grid.view(self.grid.shape[0], -1)[:, path] # (dim)xN
243+
mesh = torch.meshgrid(*self.grid, indexing='ij')
244+
raw_coordinates = [m.flatten()[path].view(1, -1) for m in mesh]
245+
coordinates = torch.cat(raw_coordinates, dim=0) # (dim)xN
246+
N = len(path)
247+
curve = DiscreteCurve(begin=coordinates[:, 0], end=coordinates[:, -1], num_nodes=N)
248+
with torch.no_grad():
249+
curve.parameters[:, :] = coordinates[:, 1:-1].t()
250+
dist = 0
251+
for i in range(N-1):
252+
dist += self.G.edges[path[i], path[i+1]]['weight']
253+
return curve, dist
254+
255+
def connecting_geodesic(self, p1, p2, curve=None):
256+
"""Compute the shortest path on the discretized manifold and fit
257+
a smooth curve to the resulting discrete curve.
258+
259+
Inputs:
260+
p1: a torch Tensor corresponding to one latent point.
261+
262+
p2: a torch Tensor corresponding to another latent point.
263+
264+
Optional input:
265+
curve: a curve that should be fitted to the discrete graph
266+
geodesic. By default this is None and a CubicSpline
267+
with default paramaters will be constructed.
268+
269+
Outputs:
270+
curve: a smooth curve forming the shortest path from p1 to p2.
271+
By default the curve is a CubicSpline with its default
272+
parameters; this can be changed through the optional
273+
curve input.
274+
"""
275+
device = p1.device
276+
idx1 = self._grid_point(p1)
277+
idx2 = self._grid_point(p2)
278+
path = nx.shortest_path(self.G, source=idx1, target=idx2, weight='weight') # list with N elements
279+
weights = [self.G.edges[path[k], path[k+1]]['weight'] for k in range(len(path)-1)]
280+
mesh = torch.meshgrid(*self.grid, indexing='ij')
281+
raw_coordinates = [m.flatten()[path[1:-1]].view(-1, 1) for m in mesh]
282+
coordinates = torch.cat(raw_coordinates, dim=1) # Nx(dim)
283+
t = torch.tensor(weights[:-1], device=device).cumsum(dim=0) / sum(weights)
284+
285+
if curve is None:
286+
curve = CubicSpline(p1, p2)
287+
else:
288+
curve.begin = p1
289+
curve.end = p2
290+
291+
curve.fit(t, coordinates)
292+
293+
return curve

0 commit comments

Comments
 (0)