|
15 | 15 | "metadata": {}, |
16 | 16 | "outputs": [], |
17 | 17 | "source": [ |
18 | | - "from typing import Callable\n", |
| 18 | + "from collections.abc import Callable\n", |
19 | 19 | "\n", |
20 | 20 | "import matplotlib.pyplot as plt\n", |
21 | 21 | "import numpy as np\n", |
|
245 | 245 | } |
246 | 246 | ], |
247 | 247 | "source": [ |
248 | | - "steps, loss_vals = zip(*losses.items())\n", |
| 248 | + "steps, loss_vals = zip(*losses.items(), strict=True)\n", |
249 | 249 | "plt.plot(steps, loss_vals)" |
250 | 250 | ] |
251 | 251 | }, |
|
321 | 321 | "outputs": [], |
322 | 322 | "source": [ |
323 | 323 | "def plot_grid_warp(\n", |
324 | | - " ax: plt.Axes, z: np.ndarray, target_samples: np.ndarray, n_lines: int, idx: int\n", |
| 324 | + " ax: plt.Axes,\n", |
| 325 | + " z: np.ndarray,\n", |
| 326 | + " target_samples: np.ndarray | torch.Tensor,\n", |
| 327 | + " n_lines: int,\n", |
| 328 | + " idx: int,\n", |
325 | 329 | "):\n", |
326 | 330 | " \"\"\"plots how the flow warps space\"\"\"\n", |
327 | 331 | "\n", |
328 | 332 | " grid = z.reshape((n_lines, n_lines, 2))\n", |
329 | 333 | " # y coords\n", |
330 | 334 | " p1 = np.reshape(grid[1:, :, :], (n_lines**2 - n_lines, 2))\n", |
331 | 335 | " p2 = np.reshape(grid[:-1, :, :], (n_lines**2 - n_lines, 2))\n", |
332 | | - " lcy = LineCollection(tuple(zip(p1, p2)), alpha=0.3)\n", |
| 336 | + " lcy = LineCollection(tuple(zip(p1, p2, strict=True)), alpha=0.3)\n", |
333 | 337 | " # x coords\n", |
334 | 338 | " p1 = np.reshape(grid[:, 1:, :], (n_lines**2 - n_lines, 2))\n", |
335 | 339 | " p2 = np.reshape(grid[:, :-1, :], (n_lines**2 - n_lines, 2))\n", |
336 | | - " lcx = LineCollection(tuple(zip(p1, p2)), alpha=0.3)\n", |
| 340 | + " lcx = LineCollection(tuple(zip(p1, p2, strict=True)), alpha=0.3)\n", |
337 | 341 | " # draw the lines\n", |
338 | 342 | " ax.add_collection(lcx)\n", |
339 | 343 | " ax.add_collection(lcy)\n", |
|
477 | 481 | "xs, *_ = model.forward(latent_grid)\n", |
478 | 482 | "xs = [z.detach().numpy() for z in xs]\n", |
479 | 483 | "\n", |
480 | | - "for idx, [z0, z1] in enumerate(zip(xs, xs[1:])):\n", |
| 484 | + "for idx, [z0, z1] in enumerate(zip(xs, xs[1:], strict=True)):\n", |
481 | 485 | " _, [ax1, ax2] = plt.subplots(1, 2, figsize=(10, 5))\n", |
482 | 486 | "\n", |
483 | 487 | " plot_point_flow(ax1, z0, z1)\n", |
|
506 | 510 | " fig, axes = plt.subplots(plot_grid_height, 2 * plot_grid_height, figsize=(20, 10))\n", |
507 | 511 | " fig.subplots_adjust(wspace=0.05, hspace=0.05)\n", |
508 | 512 | "\n", |
509 | | - " for z0, z1, ax in zip(xs, xs[1:], axes[:, :plot_grid_height].flat):\n", |
| 513 | + " for z0, z1, ax in zip(xs, xs[1:], axes[:, :plot_grid_height].flat, strict=True):\n", |
510 | 514 | " plot_point_flow(ax, z0, z1)\n", |
511 | 515 | " ax.set(xlim=[-4, 4], ylim=[-4, 4], xticks=[], yticks=[])\n", |
512 | 516 | "\n", |
|
0 commit comments