Skip to content

Commit 0b46a23

Browse files
committed
Update Pallas distributed tutorials with jax.make_mesh
1 parent 16fca38 commit 0b46a23

File tree

2 files changed

+50
-62
lines changed

2 files changed

+50
-62
lines changed

docs/pallas/tpu/distributed.ipynb

Lines changed: 25 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,6 @@
4545
"import jax\n",
4646
"from jax import lax\n",
4747
"from jax import numpy as jnp\n",
48-
"from jax.experimental import mesh_utils\n",
4948
"from jax.experimental import pallas as pl\n",
5049
"from jax.experimental import shard_map\n",
5150
"from jax.experimental.pallas import tpu as pltpu\n",
@@ -245,8 +244,7 @@
245244
],
246245
"source": [
247246
"partition = P(None, 'x')\n",
248-
"devices = mesh_utils.create_device_mesh((1, num_devices))\n",
249-
"mesh = jax.sharding.Mesh(devices, partition)\n",
247+
"mesh = jax.make_mesh((num_devices,), ('x',))\n",
250248
"sharding = jax.sharding.NamedSharding(mesh, partition)\n",
251249
"\n",
252250
"# Create an input array that shards the last dimension across\n",
@@ -263,7 +261,7 @@
263261
" dst_ref=output_ref,\n",
264262
" send_sem=send_sem,\n",
265263
" recv_sem=recv_sem,\n",
266-
" device_id=(0, right_neighbor),\n",
264+
" device_id=(right_neighbor,),\n",
267265
" device_id_type=pltpu.DeviceIdType.MESH,\n",
268266
" )\n",
269267
" remote_copy_op.start()\n",
@@ -373,8 +371,7 @@
373371
],
374372
"source": [
375373
"partition = P('x', None)\n",
376-
"devices = mesh_utils.create_device_mesh((num_devices, 1))\n",
377-
"mesh = jax.sharding.Mesh(devices, partition)\n",
374+
"mesh = jax.make_mesh((num_devices,), ('x',))\n",
378375
"sharding = jax.sharding.NamedSharding(mesh, partition)\n",
379376
"\n",
380377
"# Create an input array that shards the first dimension across\n",
@@ -413,7 +410,7 @@
413410
" dst_ref=output_ref.at[copy_slot],\n",
414411
" send_sem=send_sem,\n",
415412
" recv_sem=recv_sems.at[outer_step],\n",
416-
" device_id=(right_neighbor, 0),\n",
413+
" device_id=(right_neighbor,),\n",
417414
" device_id_type=pltpu.DeviceIdType.MESH,\n",
418415
" )\n",
419416
" remote_copy_op.start()\n",
@@ -683,8 +680,7 @@
683680
],
684681
"source": [
685682
"partition = P(None, 'x')\n",
686-
"devices = mesh_utils.create_device_mesh((1, num_devices))\n",
687-
"mesh = jax.sharding.Mesh(devices, partition)\n",
683+
"mesh = jax.make_mesh((num_devices,), ('x',))\n",
688684
"sharding = jax.sharding.NamedSharding(mesh, partition)\n",
689685
"\n",
690686
"input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))\n",
@@ -717,13 +713,13 @@
717713
" pltpu.semaphore_signal(\n",
718714
" barrier_sem,\n",
719715
" inc=1,\n",
720-
" device_id=(0, left_neighbor),\n",
716+
" device_id=(left_neighbor,),\n",
721717
" device_id_type=pltpu.DeviceIdType.MESH,\n",
722718
" )\n",
723719
" pltpu.semaphore_signal(\n",
724720
" barrier_sem,\n",
725721
" inc=1,\n",
726-
" device_id=(0, right_neighbor),\n",
722+
" device_id=(right_neighbor,),\n",
727723
" device_id_type=pltpu.DeviceIdType.MESH,\n",
728724
" )\n",
729725
" pltpu.semaphore_wait(barrier_sem, 2)\n",
@@ -736,7 +732,7 @@
736732
" dst_ref=hbm_scratch.at[working_slot],\n",
737733
" send_sem=remote_send_sem,\n",
738734
" recv_sem=remote_recv_sem,\n",
739-
" device_id=(0, right_neighbor),\n",
735+
" device_id=(right_neighbor,),\n",
740736
" device_id_type=pltpu.DeviceIdType.MESH,\n",
741737
" )\n",
742738
" initial_copy.start()\n",
@@ -748,7 +744,7 @@
748744
" pltpu.semaphore_signal(\n",
749745
" capacity_sem,\n",
750746
" inc=1,\n",
751-
" device_id=(0, left_neighbor),\n",
747+
" device_id=(left_neighbor,),\n",
752748
" device_id_type=pltpu.DeviceIdType.MESH,\n",
753749
" )\n",
754750
"\n",
@@ -769,7 +765,7 @@
769765
" dst_ref=hbm_scratch.at[receiving_slot],\n",
770766
" send_sem=remote_send_sem,\n",
771767
" recv_sem=remote_recv_sem,\n",
772-
" device_id=(0, right_neighbor),\n",
768+
" device_id=(right_neighbor,),\n",
773769
" device_id_type=pltpu.DeviceIdType.MESH,\n",
774770
" )\n",
775771
" remote_copy.start()\n",
@@ -913,8 +909,7 @@
913909
"outputs": [],
914910
"source": [
915911
"partition = P(None, 'x')\n",
916-
"devices = mesh_utils.create_device_mesh((1, num_devices))\n",
917-
"mesh = jax.sharding.Mesh(devices, partition)\n",
912+
"mesh = jax.make_mesh((num_devices,), ('x',))\n",
918913
"sharding = jax.sharding.NamedSharding(mesh, partition)\n",
919914
"\n",
920915
"# We need a block size of (16, 128) to ensure that a half-slice is at least\n",
@@ -944,7 +939,7 @@
944939
" pltpu.semaphore_signal(\n",
945940
" semaphore,\n",
946941
" inc=1,\n",
947-
" device_id=(0, neighbor),\n",
942+
" device_id=(neighbor,),\n",
948943
" device_id_type=pltpu.DeviceIdType.MESH,\n",
949944
" )\n",
950945
"\n",
@@ -985,7 +980,7 @@
985980
" dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n",
986981
" send_sem=left_send_sem,\n",
987982
" recv_sem=left_recv_sem,\n",
988-
" device_id=(0, left_neighbor),\n",
983+
" device_id=(left_neighbor,),\n",
989984
" device_id_type=pltpu.DeviceIdType.MESH,\n",
990985
" )\n",
991986
"\n",
@@ -994,7 +989,7 @@
994989
" dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n",
995990
" send_sem=right_send_sem,\n",
996991
" recv_sem=right_recv_sem,\n",
997-
" device_id=(0, right_neighbor),\n",
992+
" device_id=(right_neighbor,),\n",
998993
" device_id_type=pltpu.DeviceIdType.MESH,\n",
999994
" )\n",
1000995
"\n",
@@ -1003,7 +998,7 @@
1003998
" dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n",
1004999
" send_sem=left_send_sem,\n",
10051000
" recv_sem=left_recv_sem,\n",
1006-
" device_id=(0, left_neighbor),\n",
1001+
" device_id=(left_neighbor,),\n",
10071002
" device_id_type=pltpu.DeviceIdType.MESH,\n",
10081003
" )\n",
10091004
" right_copy = pltpu.make_async_remote_copy(\n",
@@ -1013,7 +1008,7 @@
10131008
" dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n",
10141009
" send_sem=right_send_sem,\n",
10151010
" recv_sem=right_recv_sem,\n",
1016-
" device_id=(0, right_neighbor),\n",
1011+
" device_id=(right_neighbor,),\n",
10171012
" device_id_type=pltpu.DeviceIdType.MESH,\n",
10181013
" )\n",
10191014
"\n",
@@ -1026,13 +1021,13 @@
10261021
" pltpu.semaphore_signal(\n",
10271022
" barrier_sem,\n",
10281023
" inc=1,\n",
1029-
" device_id=(0, left_neighbor),\n",
1024+
" device_id=(left_neighbor,),\n",
10301025
" device_id_type=pltpu.DeviceIdType.MESH,\n",
10311026
" )\n",
10321027
" pltpu.semaphore_signal(\n",
10331028
" barrier_sem,\n",
10341029
" inc=1,\n",
1035-
" device_id=(0, right_neighbor),\n",
1030+
" device_id=(right_neighbor,),\n",
10361031
" device_id_type=pltpu.DeviceIdType.MESH,\n",
10371032
" )\n",
10381033
" pltpu.semaphore_wait(barrier_sem, 2)\n",
@@ -1378,8 +1373,7 @@
13781373
"outputs": [],
13791374
"source": [
13801375
"partition = P(None, 'x')\n",
1381-
"devices = mesh_utils.create_device_mesh((1, num_devices))\n",
1382-
"mesh = jax.sharding.Mesh(devices, partition)\n",
1376+
"mesh = jax.make_mesh((num_devices,), ('x',))\n",
13831377
"sharding = jax.sharding.NamedSharding(mesh, partition)\n",
13841378
"\n",
13851379
"# We pick a large outer kernel block size that we do not want to place\n",
@@ -1445,7 +1439,7 @@
14451439
" dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n",
14461440
" send_sem=left_send_sem,\n",
14471441
" recv_sem=left_recv_sem,\n",
1448-
" device_id=(0, left_neighbor),\n",
1442+
" device_id=(left_neighbor,),\n",
14491443
" device_id_type=pltpu.DeviceIdType.MESH,\n",
14501444
" )\n",
14511445
"\n",
@@ -1454,7 +1448,7 @@
14541448
" dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n",
14551449
" send_sem=right_send_sem,\n",
14561450
" recv_sem=right_recv_sem,\n",
1457-
" device_id=(0, right_neighbor),\n",
1451+
" device_id=(right_neighbor,),\n",
14581452
" device_id_type=pltpu.DeviceIdType.MESH,\n",
14591453
" )\n",
14601454
"\n",
@@ -1463,15 +1457,15 @@
14631457
" dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n",
14641458
" send_sem=left_send_sem,\n",
14651459
" recv_sem=left_recv_sem,\n",
1466-
" device_id=(0, left_neighbor),\n",
1460+
" device_id=(left_neighbor,),\n",
14671461
" device_id_type=pltpu.DeviceIdType.MESH,\n",
14681462
" )\n",
14691463
" right_copy = pltpu.make_async_remote_copy(\n",
14701464
" src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],\n",
14711465
" dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n",
14721466
" send_sem=right_send_sem,\n",
14731467
" recv_sem=right_recv_sem,\n",
1474-
" device_id=(0, right_neighbor),\n",
1468+
" device_id=(right_neighbor,),\n",
14751469
" device_id_type=pltpu.DeviceIdType.MESH,\n",
14761470
" )\n",
14771471
"\n",
@@ -1484,13 +1478,13 @@
14841478
" pltpu.semaphore_signal(\n",
14851479
" barrier_sem,\n",
14861480
" inc=1,\n",
1487-
" device_id=(0, left_neighbor),\n",
1481+
" device_id=(left_neighbor,),\n",
14881482
" device_id_type=pltpu.DeviceIdType.MESH,\n",
14891483
" )\n",
14901484
" pltpu.semaphore_signal(\n",
14911485
" barrier_sem,\n",
14921486
" inc=1,\n",
1493-
" device_id=(0, right_neighbor),\n",
1487+
" device_id=(right_neighbor,),\n",
14941488
" device_id_type=pltpu.DeviceIdType.MESH,\n",
14951489
" )\n",
14961490
" pltpu.semaphore_wait(barrier_sem, 2)\n",

0 commit comments

Comments
 (0)