Skip to content

Commit 83ca7c3

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Don't create a mesh inside jit in sharded-computation tutorial
PiperOrigin-RevId: 707288222
1 parent 8f4e13f commit 83ca7c3

File tree

2 files changed

+0
-2
lines changed

2 files changed

+0
-2
lines changed

docs/sharded-computation.ipynb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,6 @@
399399
"@jax.jit\n",
400400
"def f_contract_2(x):\n",
401401
" out = x.sum(axis=0)\n",
402-
" mesh = jax.make_mesh((8,), ('x',))\n",
403402
" sharding = jax.sharding.NamedSharding(mesh, P('x'))\n",
404403
" return jax.lax.with_sharding_constraint(out, sharding)\n",
405404
"\n",

docs/sharded-computation.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ For example, suppose that within `f_contract` above, you'd prefer the output not
143143
@jax.jit
144144
def f_contract_2(x):
145145
out = x.sum(axis=0)
146-
mesh = jax.make_mesh((8,), ('x',))
147146
sharding = jax.sharding.NamedSharding(mesh, P('x'))
148147
return jax.lax.with_sharding_constraint(out, sharding)
149148

0 commit comments

Comments
 (0)