Skip to content

Commit ae78070

Browse files
Fix issues in flax/nnx/rnglib.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 3a9a55d commit ae78070

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

flax/nnx/rnglib.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -474,13 +474,16 @@ def split(self, k: tp.Mapping[filterlib.Filter, int | tuple[int, ...]] | int | t
474474
for name, stream in self.items():
475475
for predicate, num_splits in split_predicates.items():
476476
if predicate((), stream):
477-
if num_splits is None:
478-
keys[name] = stream
479-
else:
477+
if num_splits is not None:
480478
keys[name] = stream.split(num_splits)
479+
else:
480+
# Fork to create a new stream without splitting the key, avoiding shared state.
481+
keys[name] = stream.fork()
481482
break
482-
else:
483-
keys[name] = stream
483+
else:
484+
# If no predicate matches, fork the stream to avoid sharing state.
485+
# This is consistent with the previous `fork` behavior.
486+
keys[name] = stream.fork()
484487

485488
return Rngs(**keys)
486489

0 commit comments

Comments
 (0)