
Commit 213e41e

Fixed mask sharding if inputs is sharded
Fixes #5209
1 parent 01c0bd1 commit 213e41e

File tree

2 files changed: +15 -1 lines changed


flax/nnx/nn/stochastic.py

Lines changed: 3 additions & 1 deletion
@@ -150,7 +150,9 @@ def __call__(
     broadcast_shape = list(inputs.shape)
     for dim in self.broadcast_dims:
       broadcast_shape[dim] = 1
-    mask = random.bernoulli(key, p=keep_prob, shape=broadcast_shape)
+    mask = random.bernoulli(
+      key, p=keep_prob, shape=broadcast_shape, out_sharding=inputs.sharding
+    )
     mask = jnp.broadcast_to(mask, inputs.shape)
     return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
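Not part of the commit, but here is a minimal sketch of the pattern the fix relies on: under JAX's explicit-sharding mode, jax.random.bernoulli takes an out_sharding argument, so the dropout mask can be sampled with the same sharding as inputs, and lax.select then combines arrays whose shardings already match; without it the mask does not carry the inputs' sharding, which appears to be the mismatch reported in #5209. Import locations (AxisType, PartitionSpec, reshard), the mesh axis names, and the shapes below are assumptions for illustration, and the snippet needs at least four devices (for example XLA_FLAGS=--xla_force_host_platform_device_count=4).

import jax
import jax.numpy as jnp
from jax.sharding import AxisType, PartitionSpec as P
# Assumed import path; in some JAX versions reshard lives elsewhere
# (e.g. jax.experimental.shard.reshard).
from jax.sharding import reshard

# Illustrative 2x2 mesh with explicit (sharding-in-types) axes.
mesh = jax.make_mesh(
    (2, 2), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit)
)
with jax.set_mesh(mesh):
  # Inputs sharded over the "X" axis, like the new test's sharded_array.
  inputs = reshard(jnp.arange(8, dtype=jnp.float32).reshape(2, 4), P("X", None))
  keep_prob = 0.5
  # Sample the mask directly with the inputs' sharding, mirroring the fix.
  mask = jax.random.bernoulli(
      jax.random.key(0),
      p=keep_prob,
      shape=inputs.shape,
      out_sharding=inputs.sharding,
  )
  out = jax.lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
  print(jax.typeof(out))  # expected to show float32[2@X,4]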

tests/nnx/spmd_test.py

Lines changed: 12 additions & 0 deletions
@@ -243,6 +243,18 @@ def test_out_sharding_embed_attend(self):
     assert 'float32[2@X,10]' in str(jax.typeof(layer.attend(sharded_array)))
     assert 'float32[2@X,10@Y]' in str(jax.typeof(layer.attend(sharded_array, out_sharding=P("X", "Y"))))
 
+  def test_out_sharding_dropout(self):
+    mesh = jax.make_mesh((2, 2), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit))
+    with jax.set_mesh(mesh):
+      replicated_array = jnp.arange(8).reshape(2, 4).astype(jnp.float32)
+      sharded_array = reshard(replicated_array, P("X", None))
+      layers = [
+        nnx.Dropout(rate=0.5, rngs=nnx.Rngs(0)),
+        nnx.Dropout(rate=0.5, broadcast_dims=(1,), rngs=nnx.Rngs(0)),
+      ]
+      for layer in layers:
+        assert 'float32[2@X,4]' in str(jax.typeof(layer(sharded_array)))
+
   @parameterized.product(use_hijax=[True, False])
   def test_logical_rules(self, use_hijax):
     self.enter_context(nnx.var_defaults(hijax=use_hijax))
