-
I'm encountering an issue where `random.beta` fails when called with sharded inputs (in particular inside `shard_map`). Minimal Reproducible Example:

import jax
import jax.numpy as jnp
from jax import random
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
def main():
    """Try ``random.beta`` on sharded inputs under several execution modes.

    Builds a 1-D device mesh with a single ``'dp'`` axis. It places 64x1
    arrays of ones for ``alpha`` and ``beta`` row-sharded over that axis,
    plus a replicated PRNG key. It then calls ``random.beta`` four ways:
    eagerly, under ``jax.jit``, inside ``shard_map``, and inside
    ``shard_map`` wrapped in ``jax.jit``.

    Each attempt prints either the output shape or the exception it
    raised. Exceptions are caught rather than propagated, so all four
    variants always run. The last variant also prints the full traceback.
    """
    rows = P('dp', None)  # shard the leading axis over 'dp', keep the rest whole
    mesh = Mesh(
        mesh_utils.create_device_mesh((jax.device_count(),)),
        axis_names=('dp',),
    )
    data_sharding = NamedSharding(mesh, rows)
    key_sharding = NamedSharding(mesh, P())  # fully replicated

    key = jax.device_put(jax.random.PRNGKey(42), key_sharding)
    alpha = jax.device_put(jnp.ones((64, 1)), data_sharding)
    beta = jax.device_put(jnp.ones((64, 1)), data_sharding)

    for label, arr in (("alpha", alpha), ("beta", beta)):
        print(f"{label} shape: {arr.shape}, sharding: {arr.sharding}")
    print(f"rng sharding: {key.sharding}")

    def attempt(call, suffix="", show_traceback=False):
        # Run one variant and report its output shape or the error it hit.
        # `suffix` is spliced into both messages, e.g. " with shard_map".
        try:
            samples = call()
            print(f"Success{suffix}! actions shape: {samples.shape}")
        except Exception as err:
            print(f"Error{suffix}: {type(err).__name__}: {err}")
            if show_traceback:
                import traceback
                traceback.print_exc()

    # 1) Eager call on the sharded arrays.
    attempt(lambda: random.beta(key, alpha, beta))

    # 2) The same call under jax.jit.
    @jax.jit
    def jitted(rng, a, b):
        return random.beta(rng, a, b)

    print("\nTrying with jit:")
    attempt(lambda: jitted(key, alpha, beta))

    # Both shard_map variants use the same specs: the key is replicated,
    # alpha/beta come in row-sharded, and the samples go out row-sharded.
    # NOTE(review): because the key spec is P(), every shard receives the
    # same key. With identical alpha/beta per shard this presumably yields
    # identical samples on each shard. If independent per-shard draws are
    # intended, fold in jax.lax.axis_index('dp') -- confirm intent.
    specs = dict(mesh=mesh, in_specs=(P(), rows, rows), out_specs=rows)

    # 3) shard_map -- the scenario being reproduced. The prints show the
    #    per-shard shapes seen by the body.
    print("\nTrying with shard_map:")

    def per_shard(rng, a, b):
        print(f"Inside shard_map - alpha shape: {a.shape}")
        print(f"Inside shard_map - beta shape: {b.shape}")
        print(f"Inside shard_map - rng type: {type(rng)}")
        return random.beta(rng, a, b)

    sharded = shard_map(per_shard, **specs)
    attempt(lambda: sharded(key, alpha, beta), suffix=" with shard_map")

    # 4) jit wrapping shard_map; a failure here also dumps the traceback.
    print("\nTrying with jit + shard_map:")

    @jax.jit
    def jitted_sharded(rng, a, b):
        def draw(r, x, y):
            return random.beta(r, x, y)

        return shard_map(draw, **specs)(rng, a, b)

    attempt(
        lambda: jitted_sharded(key, alpha, beta),
        suffix=" with jit + shard_map",
        show_traceback=True,
    )
if __name__ == "__main__":
    # Print environment details first, followed by one blank line.
    print(f"JAX version: {jax.__version__}")
    print(f"Devices: {jax.devices()}", end="\n\n")
    main()
# Output: (not included in the original post)
Is there a recommended workaround for sampling from Beta/Gamma distributions inside `shard_map`?

More system info:
|
Beta Was this translation helpful? Give feedback.
Answered by
DBraun
Aug 25, 2025
Replies: 1 comment
-
Answered here #31080 |
Beta Was this translation helpful? Give feedback.
0 replies
Answer selected by
DBraun
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Answered here #31080