|
64 | 64 | "source": [ |
65 | 65 | "## Run the Tesseract\n", |
66 | 66 | "\n", |
67 | | - "The main entrypoint to `tesseract_jax` is the function `apply_tesseract`.\n", |
68 | | - "Using the `vectoradd_jax` Tesseract image we built earlier, let's add two vectors together.\n", |
| 67 | + "The main entrypoint to `tesseract_jax` is `apply_tesseract()`.\n", |
69 | 68 | "\n", |
| 69 | + "Using the `vectoradd_jax` Tesseract image we built earlier, let's add two vectors together.\n", |
70 | 70 | "The result should be:\n", |
71 | 71 | "\n", |
72 | 72 | "$$\\begin{pmatrix} 1 \\\\ 2 \\\\ 3 \\end{pmatrix} + 2 \\cdot \\begin{pmatrix} 4 \\\\ 5 \\\\ 6 \\end{pmatrix} = \\begin{pmatrix} 9 \\\\ 12 \\\\ 15 \\end{pmatrix}$$" |
|
83 | 83 | }, |
84 | 84 | { |
85 | 85 | "cell_type": "code", |
86 | | - "execution_count": 1, |
| 86 | + "execution_count": 2, |
87 | 87 | "metadata": {}, |
88 | 88 | "outputs": [], |
89 | 89 | "source": [ |
|
102 | 102 | }, |
103 | 103 | { |
104 | 104 | "cell_type": "code", |
105 | | - "execution_count": 2, |
| 105 | + "execution_count": 3, |
106 | 106 | "metadata": {}, |
107 | 107 | "outputs": [ |
108 | 108 | { |
|
118 | 118 | " 'abstract_eval']" |
119 | 119 | ] |
120 | 120 | }, |
121 | | - "execution_count": 2, |
| 121 | + "execution_count": 4, |
122 | 122 | "metadata": {}, |
123 | 123 | "output_type": "execute_result" |
124 | 124 | } |
|
147 | 147 | }, |
148 | 148 | { |
149 | 149 | "cell_type": "code", |
150 | | - "execution_count": 3, |
| 150 | + "execution_count": 5, |
151 | 151 | "metadata": {}, |
152 | 152 | "outputs": [ |
153 | 153 | { |
|
201 | 201 | }, |
202 | 202 | { |
203 | 203 | "cell_type": "code", |
204 | | - "execution_count": 4, |
| 204 | + "execution_count": 6, |
205 | 205 | "metadata": {}, |
206 | 206 | "outputs": [ |
207 | 207 | { |
|
210 | 210 | "Array(16.135319, dtype=float32)" |
211 | 211 | ] |
212 | 212 | }, |
213 | | - "execution_count": 4, |
| 213 | + "execution_count": 7, |
214 | 214 | "metadata": {}, |
215 | 215 | "output_type": "execute_result" |
216 | 216 | } |
|
241 | 241 | }, |
242 | 242 | { |
243 | 243 | "cell_type": "code", |
244 | | - "execution_count": 5, |
| 244 | + "execution_count": 8, |
245 | 245 | "metadata": {}, |
246 | 246 | "outputs": [ |
247 | 247 | { |
|
250 | 250 | "Array(16.135319, dtype=float32)" |
251 | 251 | ] |
252 | 252 | }, |
253 | | - "execution_count": 5, |
| 253 | + "execution_count": 9, |
254 | 254 | "metadata": {}, |
255 | 255 | "output_type": "execute_result" |
256 | 256 | } |
|
280 | 280 | }, |
281 | 281 | { |
282 | 282 | "cell_type": "code", |
283 | | - "execution_count": 6, |
| 283 | + "execution_count": 10, |
284 | 284 | "metadata": {}, |
285 | 285 | "outputs": [ |
286 | 286 | { |
287 | | - "data": { |
288 | | - "text/plain": [ |
289 | | - "(Array(16.135319, dtype=float32), Array(25.004124, dtype=float32))" |
290 | | - ] |
291 | | - }, |
292 | | - "execution_count": 6, |
293 | | - "metadata": {}, |
294 | | - "output_type": "execute_result" |
| 287 | + "name": "stdout", |
| 288 | + "output_type": "stream", |
| 289 | + "text": [ |
| 290 | + "primal=Array(16.135319, dtype=float32), jvp=Array(25.004124, dtype=float32)\n" |
| 291 | + ] |
295 | 292 | } |
296 | 293 | ], |
297 | 294 | "source": [ |
298 | | - "jax.jvp(fancy_operation, (a, b), (a, b))" |
| 295 | + "primal, jvp = jax.jvp(fancy_operation, (a, b), (a, b))\n", |
| 296 | + "print(f\"{primal=}, {jvp=}\")" |
299 | 297 | ] |
300 | 298 | }, |
301 | 299 | { |
302 | 300 | "cell_type": "markdown", |
303 | 301 | "metadata": {}, |
304 | 302 | "source": [ |
305 | | - "(where the first argument is the primal value, and the second is the Jacobian of fancy_operation calculated in $(a,b)$ multiplied with the vector $(a \\, a)$)." |
| 303 | + "Here, `primal` is the value of `fancy_operation` at $(a, b)$, and `jvp` is the Jacobian of `fancy_operation` evaluated at $(a, b)$ multiplied with the tangent vector $(a, b)$." |
306 | 304 | ] |
307 | 305 | }, |
308 | 306 | { |
|
314 | 312 | }, |
315 | 313 | { |
316 | 314 | "cell_type": "code", |
317 | | - "execution_count": 7, |
| 315 | + "execution_count": 11, |
318 | 316 | "metadata": {}, |
319 | 317 | "outputs": [ |
320 | 318 | { |
321 | | - "data": { |
322 | | - "text/plain": [ |
323 | | - "({'v': Array([-0.20733577, 0.56435245, -0.329298 ], dtype=float32)},\n", |
324 | | - " {'s': Array(80.709854, dtype=float32),\n", |
325 | | - " 'v': Array([-0.8293431, 50.663364 , -1.317192 ], dtype=float32)})" |
326 | | - ] |
327 | | - }, |
328 | | - "execution_count": 7, |
329 | | - "metadata": {}, |
330 | | - "output_type": "execute_result" |
| 319 | + "name": "stdout", |
| 320 | + "output_type": "stream", |
| 321 | + "text": [ |
| 322 | + "({'v': Array([-0.20733577, 0.56435245, -0.329298 ], dtype=float32)},\n", |
| 323 | + " {'s': Array(80.709854, dtype=float32),\n", |
| 324 | + " 'v': Array([-0.8293431, 50.663364 , -1.317192 ], dtype=float32)})\n" |
| 325 | + ] |
331 | 326 | } |
332 | 327 | ], |
333 | 328 | "source": [ |
334 | 329 | "primal, vjp = jax.vjp(fancy_operation, a, b)\n", |
335 | | - "vjp(primal)" |
| 330 | + "pprint(vjp(primal))" |
336 | 331 | ] |
337 | 332 | }, |
338 | 333 | { |
|
348 | 343 | "source": [ |
349 | 344 | "#### Computing the gradient\n", |
350 | 345 | "\n", |
351 | | - "Let's calculate the gradient of `fancy_operation` w.r.t. the `a` argument at the point $(a,b)$:" |
| 346 | + "Let's calculate the gradient of `fancy_operation` w.r.t. the `a` argument at the point $(a,b)$. Since `a` is the first argument, we pass `argnums=0` to `jax.grad()` (this is also the default)." |
352 | 347 | ] |
353 | 348 | }, |
354 | 349 | { |
355 | 350 | "cell_type": "code", |
356 | | - "execution_count": 8, |
| 351 | + "execution_count": 12, |
357 | 352 | "metadata": {}, |
358 | 353 | "outputs": [ |
359 | 354 | { |
|
362 | 357 | "{'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)}" |
363 | 358 | ] |
364 | 359 | }, |
365 | | - "execution_count": 8, |
| 360 | + "execution_count": 13, |
366 | 361 | "metadata": {}, |
367 | 362 | "output_type": "execute_result" |
368 | 363 | } |
369 | 364 | ], |
370 | 365 | "source": [ |
371 | | - "jax.grad(fancy_operation)(a, b)" |
| 366 | + "jax.grad(fancy_operation, argnums=0)(a, b)" |
| 367 | + ] |
| 368 | + }, |
| 369 | + { |
| 370 | + "cell_type": "markdown", |
| 371 | + "metadata": {}, |
| 372 | + "source": [ |
| 373 | + "Alternatively, similar to our VJP calculation, we can calculate the gradients with respect to both arguments `a` and `b` simultaneously by passing `argnums=[0, 1]`." |
| 374 | + ] |
| 375 | + }, |
| 376 | + { |
| 377 | + "cell_type": "code", |
| 378 | + "execution_count": 14, |
| 379 | + "metadata": {}, |
| 380 | + "outputs": [ |
| 381 | + { |
| 382 | + "data": { |
| 383 | + "text/plain": [ |
| 384 | + "({'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)},\n", |
| 385 | + " {'s': Array(5.002062, dtype=float32),\n", |
| 386 | + " 'v': Array([-0.05139923, 3.139905 , -0.08163408], dtype=float32)})" |
| 387 | + ] |
| 388 | + }, |
| 389 | + "execution_count": 15, |
| 390 | + "metadata": {}, |
| 391 | + "output_type": "execute_result" |
| 392 | + } |
| 393 | + ], |
| 394 | + "source": [ |
| 395 | + "jax.grad(fancy_operation, argnums=[0, 1])(a, b)" |
372 | 396 | ] |
373 | 397 | }, |
374 | 398 | { |
|
382 | 406 | }, |
383 | 407 | { |
384 | 408 | "cell_type": "code", |
385 | | - "execution_count": 9, |
| 409 | + "execution_count": 16, |
386 | 410 | "metadata": {}, |
387 | 411 | "outputs": [ |
388 | 412 | { |
|
391 | 415 | "{'v': Array([-0.01284981, 0.03497622, -0.02040852], dtype=float32)}" |
392 | 416 | ] |
393 | 417 | }, |
394 | | - "execution_count": 9, |
| 418 | + "execution_count": 17, |
395 | 419 | "metadata": {}, |
396 | 420 | "output_type": "execute_result" |
397 | 421 | } |
|
409 | 433 | "jax.jit(jax.grad(jitted_op))(a, b)" |
410 | 434 | ] |
411 | 435 | }, |
| 436 | + { |
| 437 | + "cell_type": "markdown", |
| 438 | + "metadata": {}, |
| 439 | + "source": [ |
| 440 | + "## Teardown and conclusions" |
| 441 | + ] |
| 442 | + }, |
| 443 | + { |
| 444 | + "cell_type": "markdown", |
| 445 | + "metadata": {}, |
| 446 | + "source": [ |
| 447 | + "Since we kept the Tesseract alive using `.serve()`, we now need to stop it using `.teardown()`." |
| 448 | + ] |
| 449 | + }, |
412 | 450 | { |
413 | 451 | "cell_type": "code", |
414 | | - "execution_count": 10, |
| 452 | + "execution_count": 18, |
415 | 453 | "metadata": {}, |
416 | 454 | "outputs": [], |
417 | 455 | "source": [ |
418 | 456 | "vectoradd.teardown()" |
419 | 457 | ] |
| 458 | + }, |
| 459 | + { |
| 460 | + "cell_type": "markdown", |
| 461 | + "metadata": {}, |
| 462 | + "source": [ |
| 463 | + "And that's it!\n", |
| 464 | + "You've worked through building up differentiable pipelines with Tesseracts that blend seamlessly with JAX's API, thanks to Tesseract-JAX." |
| 465 | + ] |
420 | 466 | } |
421 | 467 | ], |
422 | 468 | "metadata": { |
|
0 commit comments