Skip to content

Commit 279c650

Browse files
committed
destruc'd vjp output for clarity, grad argnums passed explicitly
1 parent 12dd660 commit 279c650

File tree

1 file changed

+86
-40
lines changed

1 file changed

+86
-40
lines changed

examples/simple/demo.ipynb

Lines changed: 86 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@
6464
"source": [
6565
"## Run the Tesseract\n",
6666
"\n",
67-
"The main entrypoint to `tesseract_jax` is the function `apply_tesseract`.\n",
68-
"Using the `vectoradd_jax` Tesseract image we built earlier, let's add two vectors together.\n",
67+
"The main entrypoint to `tesseract_jax` is `apply_tesseract()`.\n",
6968
"\n",
69+
"Using the `vectoradd_jax` Tesseract image we built earlier, let's add two vectors together.\n",
7070
"The result should be:\n",
7171
"\n",
7272
"$$\\begin{pmatrix} 1 \\\\ 2 \\\\ 3 \\end{pmatrix} + 2 \\cdot \\begin{pmatrix} 4 \\\\ 5 \\\\ 6 \\end{pmatrix} = \\begin{pmatrix} 9 \\\\ 12 \\\\ 15 \\end{pmatrix}$$"
@@ -83,7 +83,7 @@
8383
},
8484
{
8585
"cell_type": "code",
86-
"execution_count": 1,
86+
"execution_count": 2,
8787
"metadata": {},
8888
"outputs": [],
8989
"source": [
@@ -102,7 +102,7 @@
102102
},
103103
{
104104
"cell_type": "code",
105-
"execution_count": 2,
105+
"execution_count": 3,
106106
"metadata": {},
107107
"outputs": [
108108
{
@@ -118,7 +118,7 @@
118118
" 'abstract_eval']"
119119
]
120120
},
121-
"execution_count": 2,
121+
"execution_count": 4,
122122
"metadata": {},
123123
"output_type": "execute_result"
124124
}
@@ -147,7 +147,7 @@
147147
},
148148
{
149149
"cell_type": "code",
150-
"execution_count": 3,
150+
"execution_count": 5,
151151
"metadata": {},
152152
"outputs": [
153153
{
@@ -201,7 +201,7 @@
201201
},
202202
{
203203
"cell_type": "code",
204-
"execution_count": 4,
204+
"execution_count": 6,
205205
"metadata": {},
206206
"outputs": [
207207
{
@@ -210,7 +210,7 @@
210210
"Array(16.135319, dtype=float32)"
211211
]
212212
},
213-
"execution_count": 4,
213+
"execution_count": 7,
214214
"metadata": {},
215215
"output_type": "execute_result"
216216
}
@@ -241,7 +241,7 @@
241241
},
242242
{
243243
"cell_type": "code",
244-
"execution_count": 5,
244+
"execution_count": 8,
245245
"metadata": {},
246246
"outputs": [
247247
{
@@ -250,7 +250,7 @@
250250
"Array(16.135319, dtype=float32)"
251251
]
252252
},
253-
"execution_count": 5,
253+
"execution_count": 9,
254254
"metadata": {},
255255
"output_type": "execute_result"
256256
}
@@ -280,29 +280,27 @@
280280
},
281281
{
282282
"cell_type": "code",
283-
"execution_count": 6,
283+
"execution_count": 10,
284284
"metadata": {},
285285
"outputs": [
286286
{
287-
"data": {
288-
"text/plain": [
289-
"(Array(16.135319, dtype=float32), Array(25.004124, dtype=float32))"
290-
]
291-
},
292-
"execution_count": 6,
293-
"metadata": {},
294-
"output_type": "execute_result"
287+
"name": "stdout",
288+
"output_type": "stream",
289+
"text": [
290+
"primal=Array(16.135319, dtype=float32), jvp=Array(25.004124, dtype=float32)\n"
291+
]
295292
}
296293
],
297294
"source": [
298-
"jax.jvp(fancy_operation, (a, b), (a, b))"
295+
"primal, jvp = jax.jvp(fancy_operation, (a, b), (a, b))\n",
296+
"print(f\"{primal=}, {jvp=}\")"
299297
]
300298
},
301299
{
302300
"cell_type": "markdown",
303301
"metadata": {},
304302
"source": [
305-
"(where the first argument is the primal value, and the second is the Jacobian of fancy_operation calculated in $(a,b)$ multiplied with the vector $(a \\, a)$)."
303+
"Where `jvp` is the Jacobian of `fancy_operation` calculated in $(a,b)$ multiplied with the vector $(a, a)$."
306304
]
307305
},
308306
{
@@ -314,25 +312,22 @@
314312
},
315313
{
316314
"cell_type": "code",
317-
"execution_count": 7,
315+
"execution_count": 11,
318316
"metadata": {},
319317
"outputs": [
320318
{
321-
"data": {
322-
"text/plain": [
323-
"({'v': Array([-0.20733577, 0.56435245, -0.329298 ], dtype=float32)},\n",
324-
" {'s': Array(80.709854, dtype=float32),\n",
325-
" 'v': Array([-0.8293431, 50.663364 , -1.317192 ], dtype=float32)})"
326-
]
327-
},
328-
"execution_count": 7,
329-
"metadata": {},
330-
"output_type": "execute_result"
319+
"name": "stdout",
320+
"output_type": "stream",
321+
"text": [
322+
"({'v': Array([-0.20733577, 0.56435245, -0.329298 ], dtype=float32)},\n",
323+
" {'s': Array(80.709854, dtype=float32),\n",
324+
" 'v': Array([-0.8293431, 50.663364 , -1.317192 ], dtype=float32)})\n"
325+
]
331326
}
332327
],
333328
"source": [
334329
"primal, vjp = jax.vjp(fancy_operation, a, b)\n",
335-
"vjp(primal)"
330+
"pprint(vjp(primal))"
336331
]
337332
},
338333
{
@@ -348,12 +343,12 @@
348343
"source": [
349344
"#### Computing the gradient\n",
350345
"\n",
351-
"Let's calculate the gradient of `fancy_operation` w.r.t. the `a` argument at the point $(a,b)$:"
346+
"Let's calculate the gradient of `fancy_operation` w.r.t. the `a` argument at the point $(a,b)$. `a` is the first argument, so we pass `jax.grad()` a parameter `argnums=0`."
352347
]
353348
},
354349
{
355350
"cell_type": "code",
356-
"execution_count": 8,
351+
"execution_count": 12,
357352
"metadata": {},
358353
"outputs": [
359354
{
@@ -362,13 +357,42 @@
362357
"{'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)}"
363358
]
364359
},
365-
"execution_count": 8,
360+
"execution_count": 13,
366361
"metadata": {},
367362
"output_type": "execute_result"
368363
}
369364
],
370365
"source": [
371-
"jax.grad(fancy_operation)(a, b)"
366+
"jax.grad(fancy_operation, argnums=0)(a, b)"
367+
]
368+
},
369+
{
370+
"cell_type": "markdown",
371+
"metadata": {},
372+
"source": [
373+
"Or similar to our VJP calculation, we could calculate the gradients for both parameters `a` and `b` simultaneously."
374+
]
375+
},
376+
{
377+
"cell_type": "code",
378+
"execution_count": 14,
379+
"metadata": {},
380+
"outputs": [
381+
{
382+
"data": {
383+
"text/plain": [
384+
"({'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)},\n",
385+
" {'s': Array(5.002062, dtype=float32),\n",
386+
" 'v': Array([-0.05139923, 3.139905 , -0.08163408], dtype=float32)})"
387+
]
388+
},
389+
"execution_count": 15,
390+
"metadata": {},
391+
"output_type": "execute_result"
392+
}
393+
],
394+
"source": [
395+
"jax.grad(fancy_operation, argnums=[0, 1])(a, b)"
372396
]
373397
},
374398
{
@@ -382,7 +406,7 @@
382406
},
383407
{
384408
"cell_type": "code",
385-
"execution_count": 9,
409+
"execution_count": 16,
386410
"metadata": {},
387411
"outputs": [
388412
{
@@ -391,7 +415,7 @@
391415
"{'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)}"
392416
]
393417
},
394-
"execution_count": 9,
418+
"execution_count": 17,
395419
"metadata": {},
396420
"output_type": "execute_result"
397421
}
@@ -409,14 +433,36 @@
409433
"jax.jit(jax.grad(jitted_op))(a, b)"
410434
]
411435
},
436+
{
437+
"cell_type": "markdown",
438+
"metadata": {},
439+
"source": [
440+
"## Teardown and conclusions"
441+
]
442+
},
443+
{
444+
"cell_type": "markdown",
445+
"metadata": {},
446+
"source": [
447+
"Since we kept the Tesseract alive using `.serve()`, now we need to stop it using `.teardown()`"
448+
]
449+
},
412450
{
413451
"cell_type": "code",
414-
"execution_count": 10,
452+
"execution_count": 18,
415453
"metadata": {},
416454
"outputs": [],
417455
"source": [
418456
"vectoradd.teardown()"
419457
]
458+
},
459+
{
460+
"cell_type": "markdown",
461+
"metadata": {},
462+
"source": [
463+
"And that's it!\n",
464+
"You've worked through building up differentiable pipelines with Tesseracts that blend seamlessly with JAX's API, thanks to Tesseract-JAX."
465+
]
420466
}
421467
],
422468
"metadata": {

0 commit comments

Comments
 (0)