I'm trying to figure out how to call a JAX function (one that would be annoying to rewrite in pure Python) on NumPy arrays at compile time, inside of `jit`. The only way I could make it work was to change the mesh by hand in the JAX config. What is the proper way to get this to work?

```python
import jax
import jax.numpy as jnp
from functools import partial
from jax.sharding import *
from jax import shard_map, P
import numpy as np

if jax.config.jax_num_cpu_devices <= 1:
    jax.config.update("jax_num_cpu_devices", 4)

mesh = jax.make_mesh((4,), ("i",), axis_types=(AxisType.Explicit,))
jax.sharding.set_mesh(mesh)

x = jnp.ones((4, 8), out_sharding=P('i'))

def fun_jax(x):
    return jax.vmap(jax.lax.sin)(x).sum()

def fun_np(x):
    return np.sin(x).sum()

fun = fun_jax
workaround = False
# fun = fun_np

@jax.jit
@partial(jax.shard_map, in_specs=P('i'), out_specs=P('i'))
@jax.vmap
def test(x):
    with jax.ensure_compile_time_eval():
        a = np.ones(123) + x.shape[-1]
        if not workaround:
            # want to run this jax function on numpy arrays at compile time
            # this gives sharding errors if a mesh is set (here we call it on numpy arrays)
            const = np.asarray(fun(a))  # ZeroDivisionError: integer division or modulo by zero

            # reshard by hand
            a = jnp.array(a, device=NamedSharding(jax.typeof(x).sharding.mesh, P()))
            const = np.asarray(fun(a))  # ZeroDivisionError: integer division or modulo by zero

            # we are inside jit, so the set_mesh context manager doesn't let us change the mesh
            with jax.set_mesh(mesh):  # ValueError: `set_mesh` can only be used outside of `jax.jit`
                const = np.asarray(fun(a))
        else:  # workaround
            # unset the mesh by hand, and reset it afterwards
            mesh_bak = jax._src.config.abstract_mesh_context_manager.value
            jax._src.config.abstract_mesh_context_manager.set_local(None)
            const = np.asarray(fun(a))
            jax._src.config.abstract_mesh_context_manager.set_local(mesh_bak)
    return x.sum() + const

test(x)
```

This errors with:

```
File ~/venv/lib/python3.13/site-packages/jax/_src/sharding.py:186, in Sharding.shard_shape(self, global_shape)
    180 def shard_shape(self, global_shape: Shape) -> Shape:
    181   """Returns the shape of the data on each device.
    182 
    183   The shard shape returned by this function is calculated from
    184   ``global_shape`` and the properties of the sharding.
    185   """
--> 186   return _common_shard_shape(self, global_shape)

File ~/venv/lib/python3.13/site-packages/jax/_src/sharding.py:62, in _common_shard_shape(self, global_shape)
     60 for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)):
     61   try:
---> 62     quotient, remainder = divmod(s, p)
     63   except TypeError:
     64     # TODO Figure out how to partition dynamic shapes
     65     raise NotImplementedError

ZeroDivisionError: integer division or modulo by zero
```
```
jax:    0.7.1
jaxlib: 0.7.1
numpy:  2.2.6
python: 3.13.5
device info: cpu-4, 4 local devices"
process_count: 1
```
Answered by yashk2810 on Sep 2, 2025
I'll look. Maybe convert this into an issue?
I looked into this. The problem is that using `ensure_compile_time_eval` under `shard_map` creates `Manual` HloShardings that don't really work in the execution path, i.e. the impl rule of primitives. Your workaround is fine for now, but you can use `jax.sharding.use_abstract_mesh` to override the mesh (just be careful about doing that under a `shard_map`).