Skip to content

Commit 65b6088

Browse files
IvyZXhawkinsp
authored andcommitted
Avoid index out of range error in carry structure check
1 parent 259194a commit 65b6088

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

CHANGELOG.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@ Remember to align the itemized text with the first line of an item within a list
1010
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
1111
-->
1212

13-
## jax 0.4.36
13+
## jax 0.4.37
14+
15+
* Bug fixes
16+
* Fix a bug that will throw `index out of range` error in
17+
{func}`jax.lax.while_loop` if the user register pytree node class with
18+
different aux data for the flatten and flatten_with_path.
19+
20+
## jax 0.4.36 (Dec 5, 2024)
1421

1522
* Breaking Changes
1623
* This release lands "stackless", an internal change to JAX's tracing

jax/_src/lax/control_flow/loops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
376376
f'of the carry output is a {thing2}, so {explanation}'
377377
for path, thing1, thing2, explanation
378378
in equality_errors(in_carry, out_carry)]
379+
if len(diffs) == 0:
380+
# The trees may have different aux data but structures are the same.
381+
return
379382
if len(diffs) == 1:
380383
differences = f'{diffs[0]}.\n'.capitalize()
381384
else:
@@ -393,6 +396,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
393396
f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}'
394397
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
395398
if not core.typematch(in_aval, out_aval)]
399+
if len(diffs) == 0:
400+
# The trees may have different aux data but structures are the same.
401+
return
396402
if len(diffs) == 1:
397403
differences = f'{diffs[0]}.\n'.capitalize()
398404
else:

tests/lax_control_flow_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,19 @@ def testWhileTypeErrors(self):
322322
lax.while_loop(lambda c: True, lambda c: (True, True),
323323
(np.bool_(True), np.float32(0.)))
324324

325+
def testWhileLoopCustomPytreeDiffAuxData(self):
326+
class Node:
327+
def __init__(self, x, y):
328+
self.x = x
329+
self.y = y
330+
tree_util.register_pytree_with_keys(
331+
Node,
332+
lambda o: ((("x", o.x), ("y", o.y)), 'with_keys'), # flatten_with_keys
333+
lambda _, xy: Node(xy[0], xy[1]), # unflatten (no key involved)
334+
lambda o: ((o.x, o.y), 'without_keys'), # flatten
335+
)
336+
lax.while_loop(lambda o: o.x > 0., lambda c: Node(0., 0.), Node(1., 1.))
337+
325338
def testNestedWhileWithDynamicUpdateSlice(self):
326339
num = 5
327340

0 commit comments

Comments
 (0)