|
45 | 45 | "import jax\n", |
46 | 46 | "from jax import lax\n", |
47 | 47 | "from jax import numpy as jnp\n", |
48 | | - "from jax.experimental import mesh_utils\n", |
49 | 48 | "from jax.experimental import pallas as pl\n", |
50 | 49 | "from jax.experimental import shard_map\n", |
51 | 50 | "from jax.experimental.pallas import tpu as pltpu\n", |
|
245 | 244 | ], |
246 | 245 | "source": [ |
247 | 246 | "partition = P(None, 'x')\n", |
248 | | - "devices = mesh_utils.create_device_mesh((1, num_devices))\n", |
249 | | - "mesh = jax.sharding.Mesh(devices, partition)\n", |
| 247 | + "mesh = jax.make_mesh((num_devices,), ('x',))\n", |
250 | 248 | "sharding = jax.sharding.NamedSharding(mesh, partition)\n", |
251 | 249 | "\n", |
252 | 250 | "# Create an input array that shards the last dimension across\n", |
|
263 | 261 | " dst_ref=output_ref,\n", |
264 | 262 | " send_sem=send_sem,\n", |
265 | 263 | " recv_sem=recv_sem,\n", |
266 | | - " device_id=(0, right_neighbor),\n", |
| 264 | + " device_id=(right_neighbor,),\n", |
267 | 265 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
268 | 266 | " )\n", |
269 | 267 | " remote_copy_op.start()\n", |
|
373 | 371 | ], |
374 | 372 | "source": [ |
375 | 373 | "partition = P('x', None)\n", |
376 | | - "devices = mesh_utils.create_device_mesh((num_devices, 1))\n", |
377 | | - "mesh = jax.sharding.Mesh(devices, partition)\n", |
| 374 | + "mesh = jax.make_mesh((num_devices,), ('x',))\n", |
378 | 375 | "sharding = jax.sharding.NamedSharding(mesh, partition)\n", |
379 | 376 | "\n", |
380 | 377 | "# Create an input array that shards the first dimension across\n", |
|
413 | 410 | " dst_ref=output_ref.at[copy_slot],\n", |
414 | 411 | " send_sem=send_sem,\n", |
415 | 412 | " recv_sem=recv_sems.at[outer_step],\n", |
416 | | - " device_id=(right_neighbor, 0),\n", |
| 413 | + " device_id=(right_neighbor,),\n", |
417 | 414 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
418 | 415 | " )\n", |
419 | 416 | " remote_copy_op.start()\n", |
|
683 | 680 | ], |
684 | 681 | "source": [ |
685 | 682 | "partition = P(None, 'x')\n", |
686 | | - "devices = mesh_utils.create_device_mesh((1, num_devices))\n", |
687 | | - "mesh = jax.sharding.Mesh(devices, partition)\n", |
| 683 | + "mesh = jax.make_mesh((num_devices,), ('x',))\n", |
688 | 684 | "sharding = jax.sharding.NamedSharding(mesh, partition)\n", |
689 | 685 | "\n", |
690 | 686 | "input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))\n", |
|
717 | 713 | " pltpu.semaphore_signal(\n", |
718 | 714 | " barrier_sem,\n", |
719 | 715 | " inc=1,\n", |
720 | | - " device_id=(0, left_neighbor),\n", |
| 716 | + " device_id=(left_neighbor,),\n", |
721 | 717 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
722 | 718 | " )\n", |
723 | 719 | " pltpu.semaphore_signal(\n", |
724 | 720 | " barrier_sem,\n", |
725 | 721 | " inc=1,\n", |
726 | | - " device_id=(0, right_neighbor),\n", |
| 722 | + " device_id=(right_neighbor,),\n", |
727 | 723 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
728 | 724 | " )\n", |
729 | 725 | " pltpu.semaphore_wait(barrier_sem, 2)\n", |
|
736 | 732 | " dst_ref=hbm_scratch.at[working_slot],\n", |
737 | 733 | " send_sem=remote_send_sem,\n", |
738 | 734 | " recv_sem=remote_recv_sem,\n", |
739 | | - " device_id=(0, right_neighbor),\n", |
| 735 | + " device_id=(right_neighbor,),\n", |
740 | 736 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
741 | 737 | " )\n", |
742 | 738 | " initial_copy.start()\n", |
|
748 | 744 | " pltpu.semaphore_signal(\n", |
749 | 745 | " capacity_sem,\n", |
750 | 746 | " inc=1,\n", |
751 | | - " device_id=(0, left_neighbor),\n", |
| 747 | + " device_id=(left_neighbor,),\n", |
752 | 748 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
753 | 749 | " )\n", |
754 | 750 | "\n", |
|
769 | 765 | " dst_ref=hbm_scratch.at[receiving_slot],\n", |
770 | 766 | " send_sem=remote_send_sem,\n", |
771 | 767 | " recv_sem=remote_recv_sem,\n", |
772 | | - " device_id=(0, right_neighbor),\n", |
| 768 | + " device_id=(right_neighbor,),\n", |
773 | 769 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
774 | 770 | " )\n", |
775 | 771 | " remote_copy.start()\n", |
|
913 | 909 | "outputs": [], |
914 | 910 | "source": [ |
915 | 911 | "partition = P(None, 'x')\n", |
916 | | - "devices = mesh_utils.create_device_mesh((1, num_devices))\n", |
917 | | - "mesh = jax.sharding.Mesh(devices, partition)\n", |
| 912 | + "mesh = jax.make_mesh((num_devices,), ('x',))\n", |
918 | 913 | "sharding = jax.sharding.NamedSharding(mesh, partition)\n", |
919 | 914 | "\n", |
920 | 915 | "# We need a block size of (16, 128) to ensure that a half-slice is at least\n", |
|
944 | 939 | " pltpu.semaphore_signal(\n", |
945 | 940 | " semaphore,\n", |
946 | 941 | " inc=1,\n", |
947 | | - " device_id=(0, neighbor),\n", |
| 942 | + " device_id=(neighbor,),\n", |
948 | 943 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
949 | 944 | " )\n", |
950 | 945 | "\n", |
|
985 | 980 | " dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n", |
986 | 981 | " send_sem=left_send_sem,\n", |
987 | 982 | " recv_sem=left_recv_sem,\n", |
988 | | - " device_id=(0, left_neighbor),\n", |
| 983 | + " device_id=(left_neighbor,),\n", |
989 | 984 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
990 | 985 | " )\n", |
991 | 986 | "\n", |
|
994 | 989 | " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", |
995 | 990 | " send_sem=right_send_sem,\n", |
996 | 991 | " recv_sem=right_recv_sem,\n", |
997 | | - " device_id=(0, right_neighbor),\n", |
| 992 | + " device_id=(right_neighbor,),\n", |
998 | 993 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
999 | 994 | " )\n", |
1000 | 995 | "\n", |
|
1003 | 998 | " dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n", |
1004 | 999 | " send_sem=left_send_sem,\n", |
1005 | 1000 | " recv_sem=left_recv_sem,\n", |
1006 | | - " device_id=(0, left_neighbor),\n", |
| 1001 | + " device_id=(left_neighbor,),\n", |
1007 | 1002 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
1008 | 1003 | " )\n", |
1009 | 1004 | " right_copy = pltpu.make_async_remote_copy(\n", |
|
1013 | 1008 | " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", |
1014 | 1009 | " send_sem=right_send_sem,\n", |
1015 | 1010 | " recv_sem=right_recv_sem,\n", |
1016 | | - " device_id=(0, right_neighbor),\n", |
| 1011 | + " device_id=(right_neighbor,),\n", |
1017 | 1012 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
1018 | 1013 | " )\n", |
1019 | 1014 | "\n", |
|
1026 | 1021 | " pltpu.semaphore_signal(\n", |
1027 | 1022 | " barrier_sem,\n", |
1028 | 1023 | " inc=1,\n", |
1029 | | - " device_id=(0, left_neighbor),\n", |
| 1024 | + " device_id=(left_neighbor,),\n", |
1030 | 1025 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
1031 | 1026 | " )\n", |
1032 | 1027 | " pltpu.semaphore_signal(\n", |
1033 | 1028 | " barrier_sem,\n", |
1034 | 1029 | " inc=1,\n", |
1035 | | - " device_id=(0, right_neighbor),\n", |
| 1030 | + " device_id=(right_neighbor,),\n", |
1036 | 1031 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
1037 | 1032 | " )\n", |
1038 | 1033 | " pltpu.semaphore_wait(barrier_sem, 2)\n", |
|
1378 | 1373 | "outputs": [], |
1379 | 1374 | "source": [ |
1380 | 1375 | "partition = P(None, 'x')\n", |
1381 | | - "devices = mesh_utils.create_device_mesh((1, num_devices))\n", |
1382 | | - "mesh = jax.sharding.Mesh(devices, partition)\n", |
| 1376 | + "mesh = jax.make_mesh((num_devices,), ('x',))\n", |
1383 | 1377 | "sharding = jax.sharding.NamedSharding(mesh, partition)\n", |
1384 | 1378 | "\n", |
1385 | 1379 | "# We pick a large outer kernel block size that we do not want to place\n", |
|
1445 | 1439 | " dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n", |
1446 | 1440 | " send_sem=left_send_sem,\n", |
1447 | 1441 | " recv_sem=left_recv_sem,\n", |
1448 | | - " device_id=(0, left_neighbor),\n", |
| 1442 | + " device_id=(left_neighbor,),\n", |
1449 | 1443 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
1450 | 1444 | " )\n", |
1451 | 1445 | "\n", |
|
1454 | 1448 | " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", |
1455 | 1449 | " send_sem=right_send_sem,\n", |
1456 | 1450 | " recv_sem=right_recv_sem,\n", |
1457 | | - " device_id=(0, right_neighbor),\n", |
| 1451 | + " device_id=(right_neighbor,),\n", |
1458 | 1452 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
1459 | 1453 | " )\n", |
1460 | 1454 | "\n", |
|
1463 | 1457 | " dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n", |
1464 | 1458 | " send_sem=left_send_sem,\n", |
1465 | 1459 | " recv_sem=left_recv_sem,\n", |
1466 | | - " device_id=(0, left_neighbor),\n", |
| 1460 | + " device_id=(left_neighbor,),\n", |
1467 | 1461 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
1468 | 1462 | " )\n", |
1469 | 1463 | " right_copy = pltpu.make_async_remote_copy(\n", |
1470 | 1464 | " src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],\n", |
1471 | 1465 | " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", |
1472 | 1466 | " send_sem=right_send_sem,\n", |
1473 | 1467 | " recv_sem=right_recv_sem,\n", |
1474 | | - " device_id=(0, right_neighbor),\n", |
| 1468 | + " device_id=(right_neighbor,),\n", |
1475 | 1469 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
1476 | 1470 | " )\n", |
1477 | 1471 | "\n", |
|
1484 | 1478 | " pltpu.semaphore_signal(\n", |
1485 | 1479 | " barrier_sem,\n", |
1486 | 1480 | " inc=1,\n", |
1487 | | - " device_id=(0, left_neighbor),\n", |
| 1481 | + " device_id=(left_neighbor,),\n", |
1488 | 1482 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
1489 | 1483 | " )\n", |
1490 | 1484 | " pltpu.semaphore_signal(\n", |
1491 | 1485 | " barrier_sem,\n", |
1492 | 1486 | " inc=1,\n", |
1493 | | - " device_id=(0, right_neighbor),\n", |
| 1487 | + " device_id=(right_neighbor,),\n", |
1494 | 1488 | " device_id_type=pltpu.DeviceIdType.MESH,\n", |
1495 | 1489 | " )\n", |
1496 | 1490 | " pltpu.semaphore_wait(barrier_sem, 2)\n", |
|
0 commit comments