Currently, nnx.Dropout samples a mask that matches the shape of the input (up to broadcasting rules):
flax/flax/nnx/nn/stochastic.py
Lines 150 to 153 in 01c0bd1
broadcast_shape = list(inputs.shape)
for dim in self.broadcast_dims:
    broadcast_shape[dim] = 1
mask = random.bernoulli(key, p=keep_prob, shape=broadcast_shape)
It should also match the sharding of the input.
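One possible direction for a fix (a minimal sketch, not the actual Flax implementation): broadcast the mask to the input's shape, then reshard it to the input's partition spec before the select. This assumes jax.typeof and jax.sharding.reshard from JAX's explicit-sharding API are available in the installed version; the helper name sharded_dropout_mask is hypothetical.

import jax
from jax import numpy as jnp
from jax import random

def sharded_dropout_mask(key, inputs, keep_prob, broadcast_dims=()):
    # Sample the mask on the reduced broadcast shape, as the current code does.
    broadcast_shape = list(inputs.shape)
    for dim in broadcast_dims:
        broadcast_shape[dim] = 1
    mask = random.bernoulli(key, p=keep_prob, shape=broadcast_shape)
    # Broadcast to the full input shape, then move the replicated mask onto
    # the input's partition spec so lax.select sees matching shardings.
    mask = jnp.broadcast_to(mask, inputs.shape)
    return jax.sharding.reshard(mask, jax.typeof(inputs).sharding.spec)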
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Fedora 41
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
  - flax: 0.8.5
  - jax: 0.4.30
  - jaxlib: 0.4.30
- Python version: 3.13.9
- GPU/TPU model and memory: NVIDIA RTX 2000 Ada Generation (8G)
- CUDA version (if applicable): 13
Problem you have encountered:
nnx.Dropout does not work with explicitly sharded inputs.
What you expected to happen:
nnx.Dropout should accept explicitly sharded inputs.
Logs, error messages, etc:
Traceback (most recent call last):
File "/home/ngranger/Projects/datarater/draft/bug.py", line 27, in <module>
dropout(x, rngs=rngs)
~~~~~~~^^^^^^^^^^^^^^
File "./venv/lib64/python3.13/site-packages/flax/nnx/nn/stochastic.py", line 155, in __call__
return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./venv/lib64/python3.13/site-packages/jax/_src/lax/lax.py", line 2898, in select
return select_n_p.bind(pred, on_false, on_true)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
File "./venv/lib64/python3.13/site-packages/jax/_src/core.py", line 632, in bind
return self._true_bind(*args, **params)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "./venv/lib64/python3.13/site-packages/jax/_src/core.py", line 648, in _true_bind
return self.bind_with_trace(prev_trace, args, params)
~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
File "./venv/lib64/python3.13/site-packages/jax/_src/core.py", line 660, in bind_with_trace
return trace.process_primitive(self, args, params)
~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
File "./venv/lib64/python3.13/site-packages/jax/_src/core.py", line 1205, in process_primitive
return primitive.impl(*args, **params)
~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "./venv/lib64/python3.13/site-packages/jax/_src/dispatch.py", line 91, in apply_primitive
outs = fun(*args)
jax._src.core.ShardingTypeError: select `which` must be scalar or have the same sharding as cases, got `which` sharding NamedSharding(mesh=AbstractMesh('batch': 8, axis_types=(Explicit,), device_kind=cpu, num_cores=None), spec=PartitionSpec(None, None)) but case sharding NamedSharding(mesh=AbstractMesh('batch': 8, axis_types=(Explicit,), device_kind=cpu, num_cores=None), spec=PartitionSpec('batch', None)).
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Steps to reproduce:
import jax
from flax import nnx
from jax import numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import AxisType
from jax.sharding import PartitionSpec as P

jax.config.update("jax_num_cpu_devices", 8)
jax.config.update("jax_platform_name", "cpu")
print(jax.devices())

mesh = jax.sharding.Mesh(
    mesh_utils.create_device_mesh((jax.device_count(),)),
    axis_names=("batch",),
    axis_types=(AxisType.Explicit,),
)

dropout = nnx.Dropout(rate=0.5)
rngs = nnx.Rngs(0)

with jax.set_mesh(mesh):
    x = jnp.zeros([128, 16], out_sharding=P("batch"))
    print(jax.typeof(x))
    dropout(x, rngs=rngs)
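For reference, a possible user-side workaround in the meantime (a sketch that trades the sharding away rather than a real fix, assuming jax.sharding.reshard is available): replicate the input before calling dropout, then reshard the result back along the batch axis so both arguments to the internal lax.select carry matching shardings.

with jax.set_mesh(mesh):
    x = jnp.zeros([128, 16], out_sharding=P("batch"))
    # Replicate the input so the mask and inputs have identical shardings.
    x_rep = jax.sharding.reshard(x, P(None, None))
    # Apply dropout on the replicated array, then shard the result back.
    y = jax.sharding.reshard(dropout(x_rep, rngs=rngs), P("batch", None))
    print(jax.typeof(y))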