#!/usr/bin/env python3
from typing import Optional, Tuple, Union
from math import ceil

import networkx as nx
import torch

from stochman.manifold import Manifold
from stochman.curves import CubicSpline, DiscreteCurve


class DiscretizedManifold(Manifold):
    def __init__(self):
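        # Attributes are populated by `fit`: `grid` holds the per-dimension grid
        # coordinates, `G` is the graph whose edge weights are discretized curve
        # lengths, and `__metric__`/`_alpha` store the metric evaluated on the grid
        # together with the GP interpolation weights used by `metric`.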
        self.grid = []
        self.grid_size = []
        self.G = nx.Graph()
        self.__metric__ = torch.Tensor()
        self._diagonal_metric = False
        self._alpha = torch.Tensor()

    def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise=0.0):
        """
        Discretize a manifold to a given grid.

        Input:
            model: a stochman.Manifold that is to be approximated with a graph.

            grid: a list of torch.linspace's that defines the grid over which
                the manifold will be discretized. For example,
                grid = [torch.linspace(-3, 3, 50), torch.linspace(-3, 3, 50)]
                will discretize a two-dimensional manifold on a 50x50 grid.

            use_diagonals:
                If True, diagonal edges are included in the graph; otherwise
                they are excluded.
                Default: True.

            batch_size: Number of edge lengths that are computed in parallel.
                Larger values speed up the discretization but also increase
                memory usage, so a good choice is model and hardware specific.
                Default: 4.

            interpolation_noise:
                On fitting, the manifold metric is evaluated on the provided grid.
                The `metric` function then interpolates this metric using the
                mean of a Gaussian process. The observation noise of this GP
                regressor can be tuned through the `interpolation_noise`
                argument.
                Default: 0.0.
        """
        self.grid = grid
        self.grid_size = [g.numel() for g in grid]
        self.G = nx.Graph()

        dim = len(grid)
        if len(grid) != 2:
            raise Exception('Currently we only support 2D grids -- sorry!')

        # Add nodes to graph
        xsize, ysize = len(grid[0]), len(grid[1])
        node_idx = lambda x, y: x*ysize + y
        self.G.add_nodes_from(range(xsize*ysize))

        point_set = torch.cartesian_prod(
            torch.arange(xsize, dtype=torch.long),
            torch.arange(ysize, dtype=torch.long),
        )  # (big)x2

        point_sets = []  # these will be [N, 2] matrices of index points
        neighbour_funcs = []  # these will be functions for getting the neighbour index

        # add sets
        point_sets.append(point_set[point_set[:, 0] > 0])  # x > 0
        neighbour_funcs.append([lambda x: x-1, lambda y: y])

        point_sets.append(point_set[point_set[:, 1] > 0])  # y > 0
        neighbour_funcs.append([lambda x: x, lambda y: y-1])

        point_sets.append(point_set[point_set[:, 0] < xsize-1])  # x < xsize-1
        neighbour_funcs.append([lambda x: x+1, lambda y: y])

        point_sets.append(point_set[point_set[:, 1] < ysize-1])  # y < ysize-1
        neighbour_funcs.append([lambda x: x, lambda y: y+1])

        if use_diagonals:
            point_sets.append(point_set[torch.logical_and(point_set[:, 0] > 0, point_set[:, 1] > 0)])
            neighbour_funcs.append([lambda x: x-1, lambda y: y-1])

            point_sets.append(point_set[torch.logical_and(point_set[:, 0] < xsize-1, point_set[:, 1] > 0)])
            neighbour_funcs.append([lambda x: x+1, lambda y: y-1])

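        # Each (point_set, neighbour_func) pair encodes one edge direction: for every
        # grid index in the set, the lambdas give the index of its neighbour in that
        # direction. The graph is undirected, so the two diagonal sets above already
        # cover all four diagonal directions. Each edge weight computed below is the
        # length, under the model's metric, of the straight line between the two
        # neighbouring grid points.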
        t = torch.linspace(0, 1, 2)
        for ps, nf in zip(point_sets, neighbour_funcs):
            for i in range(ceil(ps.shape[0] / batch_size)):
                x = ps[batch_size*i:batch_size*(i+1), 0]
                y = ps[batch_size*i:batch_size*(i+1), 1]
                xn = nf[0](x)
                yn = nf[1](y)

                bs = x.shape[0]  # may be different from batch size for the last batch

                line = CubicSpline(begin=torch.zeros(bs, dim), end=torch.ones(bs, dim), num_nodes=2)
                line.begin = torch.cat([grid[0][x].view(-1, 1), grid[1][y].view(-1, 1)], dim=1)  # (bs)x2
                line.end = torch.cat([grid[0][xn].view(-1, 1), grid[1][yn].view(-1, 1)], dim=1)  # (bs)x2

                # if external_curve_length_function:
                #     weight = external_curve_length_function(model, line(t))
                # else:
                with torch.no_grad():
                    weight = model.curve_length(line(t))

                node_index1 = node_idx(x, y)
                node_index2 = node_idx(xn, yn)

                for n1, n2, w in zip(node_index1, node_index2, weight):
                    self.G.add_edge(n1.item(), n2.item(), weight=w.item())

        # Evaluate metric at grid
        try:
            Mlist = []
            with torch.no_grad():
                for x in range(xsize):
                    for y in range(ysize):
                        p = torch.tensor([self.grid[0][x], self.grid[1][y]])
                        Mlist.append(model.metric(p))  # 1x(d)x(d) or 1x(d)
            M = torch.cat(Mlist, dim=0)  # (big)x(d)x(d) or (big)x(d)
            self._diagonal_metric = M.dim() == 2
            d = M.shape[-1]
            if self._diagonal_metric:
                self.__metric__ = M.view([*self.grid_size, d])  # e.g. (xsize)x(ysize)x(d)
            else:
                self.__metric__ = M.view([*self.grid_size, d, d])  # e.g. (xsize)x(ysize)x(d)x(d)

            # Compute interpolation weights. We use the mean function of a GP regressor.
            mesh = torch.meshgrid(*self.grid, indexing='ij')
            grid_points = torch.cat([m.unsqueeze(-1) for m in mesh], dim=-1)  # e.g. 100x100x2 for a 2D grid with 100 points in each dim
            K = self._kernel(grid_points.view(-1, len(self.grid)))  # (num_grid)x(num_grid)
            if interpolation_noise > 0.0:
                K += interpolation_noise * torch.eye(K.shape[0])
            num_grid = K.shape[0]
            self._alpha = torch.linalg.solve(K, self.__metric__.view(num_grid, -1))  # (num_grid)x(d²) or (num_grid)x(d)
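            # `_alpha` now holds K^{-1} M_grid, where M_grid stacks the metric values
            # at the grid points; the GP posterior mean at a query point p is then
            # k(p, grid) @ _alpha, which is exactly what `metric` evaluates.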
        except Exception:
            import warnings
            warnings.warn("It appears that your model does not implement a metric.")
            # XXX: Down the road, we should be able to estimate the metric from the observed distances

    # def set(self, metric, grid, use_diagonals=True):
    #     """
    #     be able to set metric directly from a pre-evaluated grid -- this is currently not implemented
    #     """
    #     pass

    def metric(self, points):
        """
        Return the metric tensor at a specified set of points.

        Input:
            points: a Nx(d) torch Tensor representing a set of
                points where the metric tensor is to be
                computed.

        Output:
            M: a Nx(d)x(d) or Nx(d) torch Tensor representing
                the metric tensor at the given points.
                If M is Nx(d)x(d) then M[i] is a (d)x(d) symmetric
                positive definite matrix. If M is Nx(d) then M[i]
                is to be interpreted as the diagonal elements of
                a (d)x(d) diagonal matrix.
        """
        # XXX: We should also support returning the derivative of the metric! (for ODEs; see local_PCA)
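        # GP posterior mean: `_alpha` was precomputed in `fit` as K_grid^{-1} M_grid,
        # so the interpolated metric at `points` is k(points, grid) @ _alpha.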
        K = self._kernel(points)  # Nx(num_grid)
        M = K.mm(self._alpha)  # Nx(d²) or Nx(d)
        if not self._diagonal_metric:
            d = len(self.grid)
            M = M.view(-1, d, d)
        return M

    def _grid_dist2(self, p):
        """Return the squared Euclidean distance from a set of points to the grid.

        Input:
            p: a Nx(d) torch Tensor corresponding to N latent points.

        Output:
            dist2: a NxM torch Tensor containing the squared Euclidean
                distances to the M grid points.
        """

        dist2 = torch.zeros(p.shape[0], self.G.number_of_nodes())
        mesh = torch.meshgrid(*self.grid, indexing='ij')  # XXX: IT MUST BE POSSIBLE TO AVOID THIS GRID CONSTRUCTION
        for dim, mesh_dim in enumerate(mesh):
            dist2 += (p[:, dim].view(-1, 1) - mesh_dim.reshape(1, -1))**2
        return dist2

    def _kernel(self, p):
        """Evaluate the interpolation kernel for computing the metric.

        Input:
            p: a Nx(d) torch Tensor of latent points at which to evaluate the kernel.

        Output:
            val: a Nx(num_grid) torch Tensor with the kernel values.
        """
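        # Squared-exponential (RBF) kernel with per-dimension lengthscales set to the
        # squared grid spacing: k(p, g) = exp(-sum_d (p_d - g_d)^2 / spacing_d^2).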
        lengthscales = [(g[1]-g[0])**2 for g in self.grid]

        dist2 = torch.zeros(p.shape[0], self.G.number_of_nodes())
        mesh = torch.meshgrid(*self.grid, indexing='ij')
        for dim, mesh_dim in enumerate(mesh):
            dist2 += (p[:, dim].view(-1, 1) - mesh_dim.reshape(1, -1))**2 / lengthscales[dim]

        return torch.exp(-dist2)

    def _grid_point(self, p):
        """Return the index of the nearest grid point.

        Input:
            p: a torch Tensor corresponding to a latent point.

        Output:
            idx: an integer corresponding to the node index of
                the nearest point on the grid.
        """
        return self._grid_dist2(p).argmin().item()

    def shortest_path(self, p1, p2):
        """Compute the shortest path on the discretized manifold.

        Inputs:
            p1: a torch Tensor corresponding to one latent point.

            p2: a torch Tensor corresponding to another latent point.

        Outputs:
            curve: a DiscreteCurve forming the shortest path from p1 to p2.

            dist: a scalar indicating the length of the shortest curve.
        """
        idx1 = self._grid_point(p1)
        idx2 = self._grid_point(p2)
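        # With a `weight` key, networkx uses Dijkstra's algorithm, so `path` minimizes
        # the sum of the edge lengths computed in `fit`.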
        path = nx.shortest_path(self.G, source=idx1, target=idx2, weight='weight')  # list with N elements
        # coordinates = self.grid.view(self.grid.shape[0], -1)[:, path]  # (dim)xN
        mesh = torch.meshgrid(*self.grid, indexing='ij')
        raw_coordinates = [m.flatten()[path].view(1, -1) for m in mesh]
        coordinates = torch.cat(raw_coordinates, dim=0)  # (dim)xN
        N = len(path)
        curve = DiscreteCurve(begin=coordinates[:, 0], end=coordinates[:, -1], num_nodes=N)
        with torch.no_grad():
            curve.parameters[:, :] = coordinates[:, 1:-1].t()
        dist = 0
        for i in range(N-1):
            dist += self.G.edges[path[i], path[i+1]]['weight']
        return curve, dist

    def connecting_geodesic(self, p1, p2, curve=None):
        """Compute the shortest path on the discretized manifold and fit
        a smooth curve to the resulting discrete curve.

        Inputs:
            p1: a torch Tensor corresponding to one latent point.

            p2: a torch Tensor corresponding to another latent point.

        Optional input:
            curve: a curve that should be fitted to the discrete graph
                geodesic. By default this is None and a CubicSpline
                with default parameters will be constructed.

        Outputs:
            curve: a smooth curve forming the shortest path from p1 to p2.
                By default the curve is a CubicSpline with its default
                parameters; this can be changed through the optional
                curve input.
        """
        device = p1.device
        idx1 = self._grid_point(p1)
        idx2 = self._grid_point(p2)
        path = nx.shortest_path(self.G, source=idx1, target=idx2, weight='weight')  # list with N elements
        weights = [self.G.edges[path[k], path[k+1]]['weight'] for k in range(len(path)-1)]
        mesh = torch.meshgrid(*self.grid, indexing='ij')
        raw_coordinates = [m.flatten()[path[1:-1]].view(-1, 1) for m in mesh]
        coordinates = torch.cat(raw_coordinates, dim=1)  # Nx(dim)
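        # Interior grid points get curve parameters proportional to the cumulative
        # edge length along the graph geodesic, i.e. an (approximately) unit-speed
        # parametrization of [0, 1] for the spline fit below.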
        t = torch.tensor(weights[:-1], device=device).cumsum(dim=0) / sum(weights)

        if curve is None:
            curve = CubicSpline(p1, p2)
        else:
            curve.begin = p1
            curve.end = p2

        curve.fit(t, coordinates)

        return curve
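

# Hedged usage sketch (comment only, not executed): assumes `model` is a
# stochman.Manifold implementing `curve_length` and, optionally, `metric`; all
# variable names below are illustrative and not part of this module.
#
#   grid = [torch.linspace(-3, 3, 50), torch.linspace(-3, 3, 50)]
#   manifold = DiscretizedManifold()
#   manifold.fit(model, grid, batch_size=32)
#   p1, p2 = torch.tensor([[-2.0, -2.0]]), torch.tensor([[2.0, 2.0]])
#   discrete_curve, dist = manifold.shortest_path(p1, p2)
#   spline = manifold.connecting_geodesic(p1, p2)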