|
49 | 49 | }, |
50 | 50 | { |
51 | 51 | "cell_type": "code", |
52 | | - "execution_count": 1, |
| 52 | + "execution_count": 7, |
53 | 53 | "metadata": { |
54 | | - "colab": { |
55 | | - "base_uri": "https://localhost:8080/" |
56 | | - }, |
57 | | - "id": "hVi6mApuVw3r", |
58 | | - "outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf" |
| 54 | + "id": "hVi6mApuVw3r" |
59 | 55 | }, |
60 | 56 | "outputs": [], |
61 | 57 | "source": [ |
|
84 | 80 | }, |
85 | 81 | { |
86 | 82 | "cell_type": "code", |
87 | | - "execution_count": 2, |
| 83 | + "execution_count": 8, |
88 | 84 | "metadata": { |
89 | 85 | "colab": { |
90 | 86 | "base_uri": "https://localhost:8080/" |
91 | 87 | }, |
92 | 88 | "id": "mzDIDvj7Vw0k", |
93 | | - "outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434" |
| 89 | + "outputId": "09ef049b-461f-47db-bf58-dc10b42fe40a" |
94 | 90 | }, |
95 | 91 | "outputs": [ |
96 | 92 | { |
|
119 | 115 | }, |
120 | 116 | { |
121 | 117 | "cell_type": "code", |
122 | | - "execution_count": 3, |
| 118 | + "execution_count": 9, |
123 | 119 | "metadata": { |
124 | 120 | "colab": { |
125 | 121 | "base_uri": "https://localhost:8080/" |
126 | 122 | }, |
127 | 123 | "id": "IyPx_-IBVwxr", |
128 | | - "outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499" |
| 124 | + "outputId": "0cd3122f-e579-45d7-868d-e42bb0eacddb" |
129 | 125 | }, |
130 | 126 | "outputs": [ |
131 | 127 | { |
|
141 | 137 | "Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)" |
142 | 138 | ] |
143 | 139 | }, |
144 | | - "execution_count": 3, |
| 140 | + "execution_count": 9, |
145 | 141 | "metadata": {}, |
146 | 142 | "output_type": "execute_result" |
147 | 143 | } |
|
172 | 168 | }, |
173 | 169 | { |
174 | 170 | "cell_type": "code", |
175 | | - "execution_count": 4, |
| 171 | + "execution_count": 10, |
176 | 172 | "metadata": { |
177 | 173 | "colab": { |
178 | 174 | "base_uri": "https://localhost:8080/" |
179 | 175 | }, |
180 | 176 | "id": "NO2ulM_QW7a8", |
181 | | - "outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb" |
| 177 | + "outputId": "d888371b-080e-4bff-be5d-ea56beda3aac" |
182 | 178 | }, |
183 | 179 | "outputs": [ |
184 | 180 | { |
|
208 | 204 | }, |
209 | 205 | { |
210 | 206 | "cell_type": "code", |
211 | | - "execution_count": 5, |
| 207 | + "execution_count": 11, |
212 | 208 | "metadata": { |
213 | 209 | "colab": { |
214 | 210 | "base_uri": "https://localhost:8080/" |
215 | 211 | }, |
216 | 212 | "id": "1-TzmA0AXCAf", |
217 | | - "outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71" |
| 213 | + "outputId": "1c7cc3ac-4b0e-42b7-facc-c706af10d7d2" |
218 | 214 | }, |
219 | 215 | "outputs": [ |
220 | 216 | { |
|
256 | 252 | }, |
257 | 253 | { |
258 | 254 | "cell_type": "code", |
259 | | - "execution_count": 6, |
| 255 | + "execution_count": 12, |
260 | 256 | "metadata": { |
261 | 257 | "colab": { |
262 | 258 | "base_uri": "https://localhost:8080/" |
263 | 259 | }, |
264 | 260 | "id": "Gy7ABds3XND3", |
265 | | - "outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b" |
| 261 | + "outputId": "0d72dad2-381a-4e96-f771-40d705da1376" |
266 | 262 | }, |
267 | 263 | "outputs": [ |
268 | 264 | { |
|
297 | 293 | }, |
298 | 294 | { |
299 | 295 | "cell_type": "code", |
300 | | - "execution_count": 7, |
| 296 | + "execution_count": 13, |
301 | 297 | "metadata": { |
302 | 298 | "colab": { |
303 | 299 | "base_uri": "https://localhost:8080/" |
304 | 300 | }, |
305 | 301 | "id": "grCcotr-XQjY", |
306 | | - "outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a" |
| 302 | + "outputId": "c2db656c-809f-49a6-c948-629d6420360c" |
307 | 303 | }, |
308 | 304 | "outputs": [ |
309 | 305 | { |
|
324 | 320 | " [ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)" |
325 | 321 | ] |
326 | 322 | }, |
327 | | - "execution_count": 7, |
| 323 | + "execution_count": 13, |
328 | 324 | "metadata": {}, |
329 | 325 | "output_type": "execute_result" |
330 | 326 | } |
|
460 | 456 | }, |
461 | 457 | { |
462 | 458 | "cell_type": "code", |
463 | | - "execution_count": 13, |
| 459 | + "execution_count": 14, |
464 | 460 | "metadata": { |
465 | 461 | "colab": { |
466 | 462 | "base_uri": "https://localhost:8080/" |
467 | 463 | }, |
468 | 464 | "id": "fpFEaMBcXsJG", |
469 | | - "outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660" |
| 465 | + "outputId": "5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef" |
470 | 466 | }, |
471 | 467 | "outputs": [ |
472 | 468 | { |
|
479 | 475 | "We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n", |
480 | 476 | "Result type: ShapedArray(int32[4@X,4])\n" |
481 | 477 | ] |
482 | | - }, |
483 | | - { |
484 | | - "name": "stdout", |
485 | | - "output_type": "stream", |
486 | | - "text": [ |
487 | | - "Result type: ShapedArray(int32[4@X,4])\n" |
488 | | - ] |
489 | 478 | } |
490 | 479 | ], |
491 | 480 | "source": [ |
|
550 | 539 | }, |
551 | 540 | { |
552 | 541 | "cell_type": "code", |
553 | | - "execution_count": 10, |
| 542 | + "execution_count": 15, |
554 | 543 | "metadata": { |
555 | 544 | "colab": { |
556 | 545 | "base_uri": "https://localhost:8080/" |
557 | 546 | }, |
558 | 547 | "id": "geptWrdYX0OM", |
559 | | - "outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f" |
| 548 | + "outputId": "b8c3813f-60bb-4ccf-9da7-73462c57963f" |
560 | 549 | }, |
561 | 550 | "outputs": [ |
562 | 551 | { |
|
588 | 577 | { |
589 | 578 | "cell_type": "markdown", |
590 | 579 | "metadata": { |
591 | | - "id": "AQQjzUeGX4P6" |
| 580 | + "id": "LZWjgiMZ7uSS" |
| 581 | + }, |
| 582 | + "source": [ |
| 583 | + "You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over other. For example:" |
| 584 | + ] |
| 585 | + }, |
| 586 | + { |
| 587 | + "cell_type": "code", |
| 588 | + "execution_count": 27, |
| 589 | + "metadata": { |
| 590 | + "colab": { |
| 591 | + "base_uri": "https://localhost:8080/" |
| 592 | + }, |
| 593 | + "id": "IVzPSkp77uCF", |
| 594 | + "outputId": "db80a604-98ac-4343-8677-23729adf7ffc" |
| 595 | + }, |
| 596 | + "outputs": [ |
| 597 | + { |
| 598 | + "name": "stdout", |
| 599 | + "output_type": "stream", |
| 600 | + "text": [ |
| 601 | + "mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n", |
| 602 | + "x.sharding: ShapedArray(float32[4@X,4@Y])\n", |
| 603 | + "\n", |
| 604 | + "mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit))\n", |
| 605 | + "y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])\n", |
| 606 | + "\n", |
| 607 | + "z.sharding: ShapedArray(float32[4@X,4@Y])\n", |
| 608 | + "\n" |
| 609 | + ] |
| 610 | + }, |
| 611 | + { |
| 612 | + "data": { |
| 613 | + "text/plain": [ |
| 614 | + "Array([[ 1. , 2.682942 , 2.818595 , 1.28224 ],\n", |
| 615 | + " [-0.513605 , -0.9178486 , 0.44116902, 2.3139732 ],\n", |
| 616 | + " [ 2.9787164 , 1.824237 , -0.08804226, -0.99998045],\n", |
| 617 | + " [-0.07314587, 1.840334 , 2.9812148 , 2.3005757 ]], dtype=float32)" |
| 618 | + ] |
| 619 | + }, |
| 620 | + "execution_count": 27, |
| 621 | + "metadata": {}, |
| 622 | + "output_type": "execute_result" |
| 623 | + } |
| 624 | + ], |
| 625 | + "source": [ |
| 626 | + "import functools\n", |
| 627 | + "\n", |
| 628 | + "@functools.partial(auto_axes, axes='X')\n", |
| 629 | + "def g(y):\n", |
| 630 | + " print(f'mesh inside g: {get_abstract_mesh()}')\n", |
| 631 | + " print(f'y.sharding inside g: {jax.typeof(y) = }', end='\\n\\n')\n", |
| 632 | + " return y * 2\n", |
| 633 | + "\n", |
| 634 | + "@jax.jit\n", |
| 635 | + "def f(arr1):\n", |
| 636 | + " print(f'mesh inside f: {get_abstract_mesh()}')\n", |
| 637 | + " x = jnp.sin(arr1)\n", |
| 638 | + " print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n", |
| 639 | + "\n", |
| 640 | + " z = g(x, out_shardings=P(\"X\", \"Y\"))\n", |
| 641 | + "\n", |
| 642 | + " print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n", |
| 643 | + " return z + 1\n", |
| 644 | + "\n", |
| 645 | + "some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n", |
| 646 | + "f(some_x)" |
| 647 | + ] |
| 648 | + }, |
| 649 | + { |
| 650 | + "cell_type": "markdown", |
| 651 | + "metadata": { |
| 652 | + "id": "_3sfJjRq8w9f" |
| 653 | + }, |
| 654 | + "source": [ |
| 655 | + "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`." |
| 656 | + ] |
| 657 | + }, |
| 658 | + { |
| 659 | + "cell_type": "markdown", |
| 660 | + "metadata": { |
| 661 | + "id": "sJcWbfAh7UcO" |
592 | 662 | }, |
593 | 663 | "source": [ |
594 | 664 | "## Concrete array shardings can mention `Auto` mesh axis\n", |
|
606 | 676 | }, |
607 | 677 | { |
608 | 678 | "cell_type": "code", |
609 | | - "execution_count": 25, |
| 679 | + "execution_count": null, |
610 | 680 | "metadata": { |
611 | 681 | "colab": { |
612 | 682 | "base_uri": "https://localhost:8080/" |
|
708 | 778 | } |
709 | 779 | }, |
710 | 780 | "nbformat": 4, |
711 | | - "nbformat_minor": 4 |
| 781 | + "nbformat_minor": 0 |
712 | 782 | } |
0 commit comments