Skip to content

Commit 45c7a05

Browse files
committed
heading structure and blurb
1 parent 5289a63 commit 45c7a05

File tree

1 file changed

+69
-42
lines changed

1 file changed

+69
-42
lines changed

examples/cfd/demo.ipynb

Lines changed: 69 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,19 @@
1212
"All examples are expected to run from the `examples/<example_name>` directory of the [Tesseract-JAX repository](https://github.com/pasteurlabs/tesseract-jax).\n",
1313
"</div>\n",
1414
"\n",
15-
"`cfd-tesseract` is a differentiable Navier-Stokes solver based on [JAX-CFD](https://github.com/google/jax-cfd) that is wrapped in a Tesseract. "
15+
"`cfd-tesseract` is a differentiable Navier-Stokes solver based on [JAX-CFD](https://github.com/google/jax-cfd) that is wrapped in a Tesseract. \n",
16+
"\n",
17+
"In this demo, you will learn how to:\n",
18+
"1. Define a fluid velocity field with periodic boundary conditions, using JAX-CFD\n",
19+
"1. Pass that velocity field to a Tesseract which applies a differentiable Navier-Stokes solver to evolve it\n",
20+
"1. Use gradient-based optimizations to find an initial field configuration which spontaneously evolves into the Pasteur Labs Logo!"
21+
]
22+
},
23+
{
24+
"cell_type": "markdown",
25+
"metadata": {},
26+
"source": [
27+
"## Step 1: Build `cfd-tesseract` + install dependencies"
1628
]
1729
},
1830
{
@@ -64,16 +76,30 @@
6476
"%pip install -r requirements.txt -q"
6577
]
6678
},
79+
{
80+
"cell_type": "markdown",
81+
"metadata": {},
82+
"source": [
83+
"## Step 2: Test the forward evaluation with Tesseract-JAX"
84+
]
85+
},
6786
{
6887
"cell_type": "markdown",
6988
"metadata": {},
7089
"source": [
7190
"Let's set up the Tesseract with Tesseract-JAX and test a simple forward evaluation."
7291
]
7392
},
93+
{
94+
"cell_type": "markdown",
95+
"metadata": {},
96+
"source": [
97+
"First, we'll define an initial guess for the velocity field over a grid. The resulting `vx` and `vy` give our horizontal and vertical velocity fields, respectively."
98+
]
99+
},
74100
{
75101
"cell_type": "code",
76-
"execution_count": 3,
102+
"execution_count": 1,
77103
"metadata": {},
78104
"outputs": [],
79105
"source": [
@@ -97,8 +123,15 @@
97123
"v0 = cfd.initial_conditions.filtered_velocity_field(\n",
98124
" jax.random.PRNGKey(seed), grid, max_velocity\n",
99125
")\n",
100-
"vx, vy = v0\n",
101-
"\n",
126+
"vx, vy = v0"
127+
]
128+
},
129+
{
130+
"cell_type": "code",
131+
"execution_count": 8,
132+
"metadata": {},
133+
"outputs": [],
134+
"source": [
102135
"cfd_tesseract = Tesseract.from_image(\"jax-cfd\")\n",
103136
"cfd_tesseract.serve()\n",
104137
"\n",
@@ -115,17 +148,22 @@
115148
"\n",
116149
"v0 = np.stack([np.array(vx.array.data), np.array(vy.array.data)], axis=-1)\n",
117150
"\n",
151+
"\n",
118152
"def cfd_tesseract_fn(v0):\n",
119153
" return apply_tesseract(cfd_tesseract, inputs=dict(v0=v0, **params))\n",
120154
"\n",
155+
"\n",
121156
"outputs = cfd_tesseract_fn(v0)"
122157
]
123158
},
124159
{
125160
"cell_type": "markdown",
126161
"metadata": {},
127162
"source": [
128-
"Lets look at the results of the forward pass."
163+
"## Step 3: Visualise the outputs from our Tesseract\n",
164+
"\n",
165+
"Using the results of the forward pass, we can set up a basic approach for visualising our velocity field.\n",
166+
"We'll use `matplotlib` to show the $x$ and $y$ components of the velocity as heatmaps."
129167
]
130168
},
131169
{
@@ -156,9 +194,16 @@
156194
"plt.show()"
157195
]
158196
},
197+
{
198+
"cell_type": "markdown",
199+
"metadata": {},
200+
"source": [
201+
"We can take this further, and view the vorticity field of the fluid; first we define periodic boundary conditions."
202+
]
203+
},
159204
{
160205
"cell_type": "code",
161-
"execution_count": 5,
206+
"execution_count": 9,
162207
"metadata": {},
163208
"outputs": [],
164209
"source": [
@@ -167,34 +212,15 @@
167212
" (cfd.boundaries.BCType.PERIODIC, cfd.boundaries.BCType.PERIODIC),\n",
168213
" (cfd.boundaries.BCType.PERIODIC, cfd.boundaries.BCType.PERIODIC),\n",
169214
" )\n",
170-
")"
171-
]
172-
},
173-
{
174-
"cell_type": "code",
175-
"execution_count": 6,
176-
"metadata": {},
177-
"outputs": [
178-
{
179-
"data": {
180-
"text/plain": [
181-
"HomogeneousBoundaryConditions(types=(('periodic', 'periodic'), ('periodic', 'periodic')), bc_values=((0.0, 0.0), (0.0, 0.0)))"
182-
]
183-
},
184-
"execution_count": 6,
185-
"metadata": {},
186-
"output_type": "execute_result"
187-
}
188-
],
189-
"source": [
190-
"bc"
215+
")\n",
216+
"print(bc)"
191217
]
192218
},
193219
{
194220
"cell_type": "markdown",
195221
"metadata": {},
196222
"source": [
197-
"And look at its vorticity field."
223+
"Next we define a vorticity function (recalling that vorticity is the curl of the flow velocity, *ie.* $\\omega = \\nabla \\times \\mathbf{v}$."
198224
]
199225
},
200226
{
@@ -240,9 +266,12 @@
240266
"cell_type": "markdown",
241267
"metadata": {},
242268
"source": [
243-
"## Optimization\n",
269+
"## Step 4: Optimizing the fluid to evolve into the Pasteur Labs logo\n",
270+
"\n",
271+
"Now we want to perform an actual optimization.\n",
272+
"Our goal is to find the initial state, such that the final state resembles logo of Pasteur Labs.\n",
244273
"\n",
245-
"Now we want to perform an actual optimization. The target is to find the initial state, such that the final state looks a bit like the logo of Pasteur Labs. Herefore we first load the logo.\n"
274+
"Let's start by loading in our logo!"
246275
]
247276
},
248277
{
@@ -334,7 +363,7 @@
334363
" vort = vorticity(vxn, vyn)\n",
335364
"\n",
336365
" # decrase difference of vorticity and image and ensure the field is divergence free\n",
337-
" return mse(vort, img) + 0.05 * mse(div, 0.0)\n"
366+
" return mse(vort, img) + 0.05 * mse(div, 0.0)"
338367
]
339368
},
340369
{
@@ -376,11 +405,11 @@
376405
"from tqdm import tqdm\n",
377406
"\n",
378407
"\n",
379-
"def loss_fn_capt(v0_flat, img=img, xlen: int= grid.shape[0]):\n",
408+
"def loss_fn_capt(v0_flat, img=img, xlen: int = grid.shape[0]):\n",
380409
" total_len = len(v0_flat)\n",
381410
" ylen = (total_len // 2) // xlen\n",
382-
" v0x = v0_flat[:total_len//2].reshape(xlen, ylen)\n",
383-
" v0y = v0_flat[total_len//2:].reshape(xlen, ylen)\n",
411+
" v0x = v0_flat[: total_len // 2].reshape(xlen, ylen)\n",
412+
" v0y = v0_flat[total_len // 2 :].reshape(xlen, ylen)\n",
384413
"\n",
385414
" div = divergence(v0x, v0y)\n",
386415
"\n",
@@ -395,7 +424,6 @@
395424
" return mse(vort, img) + 0.05 * mse(div, 0.0)\n",
396425
"\n",
397426
"\n",
398-
"\n",
399427
"v0_field = cfd.initial_conditions.filtered_velocity_field(\n",
400428
" jax.random.PRNGKey(221), grid, max_velocity\n",
401429
")\n",
@@ -406,12 +434,13 @@
406434
"max_iter = 400\n",
407435
"with tqdm(total=max_iter) as pbar:\n",
408436
" i = 0\n",
437+
"\n",
409438
" def callback(intermediate_result):\n",
410439
" global i\n",
411440
" i += 1\n",
412441
" pbar.set_postfix(loss=f\"{intermediate_result.fun:.4f}\")\n",
413442
" pbar.update(1)\n",
414-
" \n",
443+
"\n",
415444
" opt = minimize(\n",
416445
" grad_fn,\n",
417446
" v0_flat,\n",
@@ -423,13 +452,13 @@
423452
"\n",
424453
"if i < max_iter:\n",
425454
" print(\"Optimisation converged before reaching max_iter!\")\n",
426-
" \n",
455+
"\n",
427456
"# Reshape to generate gif in next cell\n",
428457
"v0_flat = opt.x\n",
429458
"xlen = grid.shape[0]\n",
430459
"ylen = grid.shape[1]\n",
431-
"v0x = v0_flat[:xlen*ylen].reshape(xlen, ylen)\n",
432-
"v0y = v0_flat[xlen*ylen:].reshape(xlen, ylen)\n",
460+
"v0x = v0_flat[: xlen * ylen].reshape(xlen, ylen)\n",
461+
"v0y = v0_flat[xlen * ylen :].reshape(xlen, ylen)\n",
433462
"v0 = jnp.stack([v0x, v0y], axis=-1)"
434463
]
435464
},
@@ -11791,10 +11820,8 @@
1179111820
}
1179211821
],
1179311822
"source": [
11794-
"from IPython.display import HTML\n",
11795-
"\n",
1179611823
"import matplotlib.animation as animation\n",
11797-
"\n",
11824+
"from IPython.display import HTML\n",
1179811825
"\n",
1179911826
"trajectory = []\n",
1180011827
"\n",

0 commit comments

Comments
 (0)