|
12 | 12 | "All examples are expected to run from the `examples/<example_name>` directory of the [Tesseract-JAX repository](https://github.com/pasteurlabs/tesseract-jax).\n", |
13 | 13 | "</div>\n", |
14 | 14 | "\n", |
15 | | - "`cfd-tesseract` is a differentiable Navier-Stokes solver based on [JAX-CFD](https://github.com/google/jax-cfd) that is wrapped in a Tesseract. " |
| 15 | + "`cfd-tesseract` is a differentiable Navier-Stokes solver based on [JAX-CFD](https://github.com/google/jax-cfd) that is wrapped in a Tesseract. \n", |
| 16 | + "\n", |
| 17 | + "In this demo, you will learn how to:\n", |
| 18 | + "1. Define a fluid velocity field with periodic boundary conditions, using JAX-CFD\n", |
| 19 | + "1. Pass that velocity field to a Tesseract which applies a differentiable Navier-Stokes solver to evolve it\n", |
| 20 | + "1. Use gradient-based optimizations to find an initial field configuration which spontaneously evolves into the Pasteur Labs Logo!" |
| 21 | + ] |
| 22 | + }, |
| 23 | + { |
| 24 | + "cell_type": "markdown", |
| 25 | + "metadata": {}, |
| 26 | + "source": [ |
| 27 | + "## Step 1: Build `cfd-tesseract` + install dependencies" |
16 | 28 | ] |
17 | 29 | }, |
18 | 30 | { |
|
64 | 76 | "%pip install -r requirements.txt -q" |
65 | 77 | ] |
66 | 78 | }, |
| 79 | + { |
| 80 | + "cell_type": "markdown", |
| 81 | + "metadata": {}, |
| 82 | + "source": [ |
| 83 | + "## Step 2: Test the forward evaluation with Tesseract-JAX" |
| 84 | + ] |
| 85 | + }, |
67 | 86 | { |
68 | 87 | "cell_type": "markdown", |
69 | 88 | "metadata": {}, |
70 | 89 | "source": [ |
71 | 90 | "Let's set up the Tesseract with Tesseract-JAX and test a simple forward evaluation." |
72 | 91 | ] |
73 | 92 | }, |
| 93 | + { |
| 94 | + "cell_type": "markdown", |
| 95 | + "metadata": {}, |
| 96 | + "source": [ |
| 97 | + "First, we'll define an initial guess for the velocity field over a grid. The resulting `vx` and `vy` give our horizontal and vertical velocity fields, respectively." |
| 98 | + ] |
| 99 | + }, |
74 | 100 | { |
75 | 101 | "cell_type": "code", |
76 | | - "execution_count": 3, |
| 102 | + "execution_count": 1, |
77 | 103 | "metadata": {}, |
78 | 104 | "outputs": [], |
79 | 105 | "source": [ |
|
97 | 123 | "v0 = cfd.initial_conditions.filtered_velocity_field(\n", |
98 | 124 | " jax.random.PRNGKey(seed), grid, max_velocity\n", |
99 | 125 | ")\n", |
100 | | - "vx, vy = v0\n", |
101 | | - "\n", |
| 126 | + "vx, vy = v0" |
| 127 | + ] |
| 128 | + }, |
| 129 | + { |
| 130 | + "cell_type": "code", |
| 131 | + "execution_count": 8, |
| 132 | + "metadata": {}, |
| 133 | + "outputs": [], |
| 134 | + "source": [ |
102 | 135 | "cfd_tesseract = Tesseract.from_image(\"jax-cfd\")\n", |
103 | 136 | "cfd_tesseract.serve()\n", |
104 | 137 | "\n", |
|
115 | 148 | "\n", |
116 | 149 | "v0 = np.stack([np.array(vx.array.data), np.array(vy.array.data)], axis=-1)\n", |
117 | 150 | "\n", |
| 151 | + "\n", |
118 | 152 | "def cfd_tesseract_fn(v0):\n", |
119 | 153 | " return apply_tesseract(cfd_tesseract, inputs=dict(v0=v0, **params))\n", |
120 | 154 | "\n", |
| 155 | + "\n", |
121 | 156 | "outputs = cfd_tesseract_fn(v0)" |
122 | 157 | ] |
123 | 158 | }, |
124 | 159 | { |
125 | 160 | "cell_type": "markdown", |
126 | 161 | "metadata": {}, |
127 | 162 | "source": [ |
128 | | - "Lets look at the results of the forward pass." |
| 163 | + "## Step 3: Visualise the outputs from our Tesseract\n", |
| 164 | + "\n", |
| 165 | + "Using the results of the forward pass, we can set up a basic approach for visualising our velocity field.\n", |
| 166 | + "We'll use `matplotlib` to show the $x$ and $y$ components of the velocity as heatmaps." |
129 | 167 | ] |
130 | 168 | }, |
131 | 169 | { |
|
156 | 194 | "plt.show()" |
157 | 195 | ] |
158 | 196 | }, |
| 197 | + { |
| 198 | + "cell_type": "markdown", |
| 199 | + "metadata": {}, |
| 200 | + "source": [ |
| 201 | + "We can take this further, and view the vorticity field of the fluid; first we define periodic boundary conditions." |
| 202 | + ] |
| 203 | + }, |
159 | 204 | { |
160 | 205 | "cell_type": "code", |
161 | | - "execution_count": 5, |
| 206 | + "execution_count": 9, |
162 | 207 | "metadata": {}, |
163 | 208 | "outputs": [], |
164 | 209 | "source": [ |
|
167 | 212 | " (cfd.boundaries.BCType.PERIODIC, cfd.boundaries.BCType.PERIODIC),\n", |
168 | 213 | " (cfd.boundaries.BCType.PERIODIC, cfd.boundaries.BCType.PERIODIC),\n", |
169 | 214 | " )\n", |
170 | | - ")" |
171 | | - ] |
172 | | - }, |
173 | | - { |
174 | | - "cell_type": "code", |
175 | | - "execution_count": 6, |
176 | | - "metadata": {}, |
177 | | - "outputs": [ |
178 | | - { |
179 | | - "data": { |
180 | | - "text/plain": [ |
181 | | - "HomogeneousBoundaryConditions(types=(('periodic', 'periodic'), ('periodic', 'periodic')), bc_values=((0.0, 0.0), (0.0, 0.0)))" |
182 | | - ] |
183 | | - }, |
184 | | - "execution_count": 6, |
185 | | - "metadata": {}, |
186 | | - "output_type": "execute_result" |
187 | | - } |
188 | | - ], |
189 | | - "source": [ |
190 | | - "bc" |
| 215 | + ")\n", |
| 216 | + "print(bc)" |
191 | 217 | ] |
192 | 218 | }, |
193 | 219 | { |
194 | 220 | "cell_type": "markdown", |
195 | 221 | "metadata": {}, |
196 | 222 | "source": [ |
197 | | - "And look at its vorticity field." |
| 223 | + "Next we define a vorticity function (recalling that vorticity is the curl of the flow velocity, *ie.* $\\omega = \\nabla \\times \\mathbf{v}$." |
198 | 224 | ] |
199 | 225 | }, |
200 | 226 | { |
|
240 | 266 | "cell_type": "markdown", |
241 | 267 | "metadata": {}, |
242 | 268 | "source": [ |
243 | | - "## Optimization\n", |
| 269 | + "## Step 4: Optimizing the fluid to evolve into the Pasteur Labs logo\n", |
| 270 | + "\n", |
| 271 | + "Now we want to perform an actual optimization.\n", |
| 272 | + "Our goal is to find the initial state, such that the final state resembles logo of Pasteur Labs.\n", |
244 | 273 | "\n", |
245 | | - "Now we want to perform an actual optimization. The target is to find the initial state, such that the final state looks a bit like the logo of Pasteur Labs. Herefore we first load the logo.\n" |
| 274 | + "Let's start by loading in our logo!" |
246 | 275 | ] |
247 | 276 | }, |
248 | 277 | { |
|
334 | 363 | " vort = vorticity(vxn, vyn)\n", |
335 | 364 | "\n", |
336 | 365 | " # decrase difference of vorticity and image and ensure the field is divergence free\n", |
337 | | - " return mse(vort, img) + 0.05 * mse(div, 0.0)\n" |
| 366 | + " return mse(vort, img) + 0.05 * mse(div, 0.0)" |
338 | 367 | ] |
339 | 368 | }, |
340 | 369 | { |
|
376 | 405 | "from tqdm import tqdm\n", |
377 | 406 | "\n", |
378 | 407 | "\n", |
379 | | - "def loss_fn_capt(v0_flat, img=img, xlen: int= grid.shape[0]):\n", |
| 408 | + "def loss_fn_capt(v0_flat, img=img, xlen: int = grid.shape[0]):\n", |
380 | 409 | " total_len = len(v0_flat)\n", |
381 | 410 | " ylen = (total_len // 2) // xlen\n", |
382 | | - " v0x = v0_flat[:total_len//2].reshape(xlen, ylen)\n", |
383 | | - " v0y = v0_flat[total_len//2:].reshape(xlen, ylen)\n", |
| 411 | + " v0x = v0_flat[: total_len // 2].reshape(xlen, ylen)\n", |
| 412 | + " v0y = v0_flat[total_len // 2 :].reshape(xlen, ylen)\n", |
384 | 413 | "\n", |
385 | 414 | " div = divergence(v0x, v0y)\n", |
386 | 415 | "\n", |
|
395 | 424 | " return mse(vort, img) + 0.05 * mse(div, 0.0)\n", |
396 | 425 | "\n", |
397 | 426 | "\n", |
398 | | - "\n", |
399 | 427 | "v0_field = cfd.initial_conditions.filtered_velocity_field(\n", |
400 | 428 | " jax.random.PRNGKey(221), grid, max_velocity\n", |
401 | 429 | ")\n", |
|
406 | 434 | "max_iter = 400\n", |
407 | 435 | "with tqdm(total=max_iter) as pbar:\n", |
408 | 436 | " i = 0\n", |
| 437 | + "\n", |
409 | 438 | " def callback(intermediate_result):\n", |
410 | 439 | " global i\n", |
411 | 440 | " i += 1\n", |
412 | 441 | " pbar.set_postfix(loss=f\"{intermediate_result.fun:.4f}\")\n", |
413 | 442 | " pbar.update(1)\n", |
414 | | - " \n", |
| 443 | + "\n", |
415 | 444 | " opt = minimize(\n", |
416 | 445 | " grad_fn,\n", |
417 | 446 | " v0_flat,\n", |
|
423 | 452 | "\n", |
424 | 453 | "if i < max_iter:\n", |
425 | 454 | " print(\"Optimisation converged before reaching max_iter!\")\n", |
426 | | - " \n", |
| 455 | + "\n", |
427 | 456 | "# Reshape to generate gif in next cell\n", |
428 | 457 | "v0_flat = opt.x\n", |
429 | 458 | "xlen = grid.shape[0]\n", |
430 | 459 | "ylen = grid.shape[1]\n", |
431 | | - "v0x = v0_flat[:xlen*ylen].reshape(xlen, ylen)\n", |
432 | | - "v0y = v0_flat[xlen*ylen:].reshape(xlen, ylen)\n", |
| 460 | + "v0x = v0_flat[: xlen * ylen].reshape(xlen, ylen)\n", |
| 461 | + "v0y = v0_flat[xlen * ylen :].reshape(xlen, ylen)\n", |
433 | 462 | "v0 = jnp.stack([v0x, v0y], axis=-1)" |
434 | 463 | ] |
435 | 464 | }, |
|
11791 | 11820 | } |
11792 | 11821 | ], |
11793 | 11822 | "source": [ |
11794 | | - "from IPython.display import HTML\n", |
11795 | | - "\n", |
11796 | 11823 | "import matplotlib.animation as animation\n", |
11797 | | - "\n", |
| 11824 | + "from IPython.display import HTML\n", |
11798 | 11825 | "\n", |
11799 | 11826 | "trajectory = []\n", |
11800 | 11827 | "\n", |
|
0 commit comments