Skip to content

Commit 91dac63

Browse files
committed
scan: improve docs & errors around dynamic length
1 parent e679811 commit 91dac63

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

jax/_src/lax/control_flow/loops.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,11 @@ def scan(f, init, xs, length=None):
178178
:py:func:`scan` compiles ``f``, so while it can be combined with
179179
:py:func:`jit`, it's usually unnecessary.
180180
181+
.. note::
182+
:func:`scan` is designed for iterating with a static number of iterations.
183+
For iteration with a dynamic number of iterations, use :func:`fori_loop`
184+
or :func:`while_loop`.
185+
181186
Args:
182187
f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
183188
that ``f`` accepts two arguments where the first is a value of the loop
@@ -239,7 +244,9 @@ def scan(f, init, xs, length=None):
239244
try:
240245
length = int(length)
241246
except core.ConcretizationTypeError as err:
242-
msg = 'The `length` argument to `scan` expects a concrete `int` value.'
247+
msg = ('The `length` argument to `scan` expects a concrete `int` value.'
248+
' For scan-like iteration with a dynamic length, use `while_loop`'
249+
' or `fori_loop`.')
243250
raise core.ConcretizationTypeError(length, msg) from None # type: ignore[arg-type]
244251
if not all(length == l for l in lengths):
245252
msg = ("scan got `length` argument of {} which disagrees with "

0 commit comments

Comments
 (0)