Skip to content

Commit 62646fd

Browse files
CopilotRouthleckclaude
authored
Fix bm.for_loop jit parameter handling and remove unused parameters (#803)
* Initial plan * Fix bm.for_loop jit parameter handling and remove unused parameters Co-authored-by: Routhleck <[email protected]> * Enhance progress_bar parameter to support ProgressBar instances This commit improves the `progress_bar` parameter in `bm.for_loop()` and `bm.scan()` to accept ProgressBar instances and integers for advanced customization, while maintaining full backward compatibility. Changes: - Added `_convert_progress_bar_to_pbar()` helper function for parameter conversion - Updated type hints to `Union[bool, brainstate.transform.ProgressBar, int]` - Enhanced docstrings with detailed examples for all supported types - Exported `ProgressBar` from `brainpy.math` for easy access - Added 10 comprehensive test cases covering all usage patterns - Updated API documentation to include ProgressBar Supported usage: - `progress_bar=True/False` (backward compatible) - `progress_bar=bm.ProgressBar(freq=10)` (custom frequency) - `progress_bar=bm.ProgressBar(desc="Processing")` (custom description) - `progress_bar=10` (integer shorthand for freq parameter) All 37 tests in test_controls.py pass, ensuring no regressions. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]> * Fix runners.py to use functools.partial instead of removed unroll_kwargs The unroll_kwargs parameter was removed from bm.for_loop() as it was never actually implemented. Updated runners.py to use functools.partial() to bind shared_args to _step_func_predict, following the pattern already used in train/online.py. Changes: - Added functools import - Changed for_loop call to use functools.partial(self._step_func_predict, shared_args=shared_args) - Removed unroll_kwargs={'shared_args': shared_args} which was non-functional 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]> * Fix LoopOverTime to remove remat parameter from for_loop call The remat parameter was removed from bm.for_loop() as it was never implemented. Updated LoopOverTime to: 1. Remove remat from the for_loop() call 2. Keep remat parameter in __init__ for backward compatibility 3. Add deprecation warning when remat=True is passed Changes: - Removed remat=self.remat from for_loop call on line 283 - Added deprecation warning for remat parameter - Removed self.remat storage (commented out for clarity) This maintains backward compatibility while warning users that the parameter no longer has any effect. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]> * Fix zero-length scan error when using jit=False When jit=False is used with zero-length input arrays, JAX's disable_jit() mode cannot handle the scan operation because it cannot infer the output type. Changes: - Added check for zero-length inputs when jit=False - Automatically falls back to JIT mode for zero-length inputs - Issues a UserWarning to inform users of the fallback - Added test case to verify zero-length input handling - All 38 tests in test_controls.py pass This fix resolves: ValueError: zero-length scan is not supported in disable_jit() mode because the output type is unknown. * Set MPLBACKEND=Agg in CI to fix Tkinter issues on Windows On Windows Python 3.13 CI environment, Tcl/Tk is not properly configured, causing matplotlib tests to fail with TclError. Setting the MPLBACKEND environment variable to 'Agg' (non-interactive backend) resolves this issue. Changes: - Added MPLBACKEND=Agg env var to all test jobs (Linux, macOS, Windows) - This ensures consistent behavior across all CI platforms - Fixes TclError in test_phase_plane.py and test_aligns.py on Windows This is a cleaner solution than modifying individual test files. --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: Routhleck <[email protected]> Co-authored-by: routhleck <[email protected]> Co-authored-by: Claude <[email protected]>
1 parent 0c5b521 commit 62646fd

File tree

8 files changed

+2287
-26
lines changed

8 files changed

+2287
-26
lines changed

.github/workflows/CI.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ jobs:
4848
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
4949
pip install -e .
5050
- name: Test with pytest
51+
env:
52+
MPLBACKEND: Agg # Use non-interactive backend for matplotlib
5153
run: |
5254
pytest brainpy/
5355
@@ -77,6 +79,8 @@ jobs:
7779
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
7880
pip install -e .
7981
- name: Test with pytest
82+
env:
83+
MPLBACKEND: Agg # Use non-interactive backend for matplotlib
8084
run: |
8185
pytest brainpy/
8286
@@ -106,5 +110,7 @@ jobs:
106110
python -m pip install -r requirements-dev.txt
107111
pip install -e .
108112
- name: Test with pytest
113+
env:
114+
MPLBACKEND: Agg # Use non-interactive backend for matplotlib
109115
run: |
110116
pytest brainpy/

brainpy/math/object_transform/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
Details please see the following.
3030
"""
3131

32+
from brainstate.transform import ProgressBar
33+
3234
from .autograd import *
3335
from .base import *
3436
from .collectors import *

brainpy/math/object_transform/controls.py

Lines changed: 122 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numbers
1717
from typing import Union, Sequence, Any, Dict, Callable, Optional
1818

19+
import jax
1920
import jax.numpy as jnp
2021

2122
import brainstate
@@ -31,6 +32,42 @@
3132
]
3233

3334

35+
def _convert_progress_bar_to_pbar(
36+
progress_bar: Union[bool, brainstate.transform.ProgressBar, int, None]
37+
) -> Optional[brainstate.transform.ProgressBar]:
38+
"""Convert progress_bar parameter to brainstate pbar format.
39+
40+
Parameters
41+
----------
42+
progress_bar : bool, ProgressBar, int, None
43+
The progress_bar parameter value.
44+
45+
Returns
46+
-------
47+
pbar : ProgressBar or None
48+
The converted ProgressBar instance or None.
49+
50+
Raises
51+
------
52+
TypeError
53+
If progress_bar is not a valid type.
54+
"""
55+
if progress_bar is False or progress_bar is None:
56+
return None
57+
elif progress_bar is True:
58+
return brainstate.transform.ProgressBar()
59+
elif isinstance(progress_bar, int):
60+
# Support brainstate convention: int means freq parameter
61+
return brainstate.transform.ProgressBar(freq=progress_bar)
62+
elif isinstance(progress_bar, brainstate.transform.ProgressBar):
63+
return progress_bar
64+
else:
65+
raise TypeError(
66+
f"progress_bar must be bool, int, or ProgressBar instance, "
67+
f"got {type(progress_bar).__name__}"
68+
)
69+
70+
3471
def cond(
3572
pred: bool,
3673
true_fun: Union[Callable, jnp.ndarray, Array, numbers.Number],
@@ -205,10 +242,8 @@ def for_loop(
205242
operands: Any,
206243
reverse: bool = False,
207244
unroll: int = 1,
208-
remat: bool = False,
209245
jit: Optional[bool] = None,
210-
progress_bar: bool = False,
211-
unroll_kwargs: Optional[Dict] = None,
246+
progress_bar: Union[bool, brainstate.transform.ProgressBar, int] = False,
212247
):
213248
"""``for-loop`` control flow with :py:class:`~.Variable`.
214249
@@ -266,10 +301,6 @@ def for_loop(
266301
If body function `body_func` receives multiple arguments,
267302
`operands` should be a tuple/list whose length is equal to the
268303
number of arguments.
269-
remat: bool
270-
Make ``fun`` recompute internal linearization points when differentiated.
271-
jit: bool
272-
Whether to just-in-time compile the function.
273304
reverse: bool
274305
Optional boolean specifying whether to run the scan iteration
275306
forward (the default) or in reverse, equivalent to reversing the leading
@@ -278,10 +309,37 @@ def for_loop(
278309
Optional positive int specifying, in the underlying operation of the
279310
scan primitive, how many scan iterations to unroll within a single
280311
iteration of a loop.
281-
progress_bar: bool
282-
Whether we use the progress bar to report the running progress.
312+
jit: bool
313+
Whether to just-in-time compile the function. Set to ``False`` to disable JIT compilation.
314+
progress_bar: bool, ProgressBar, int
315+
Whether and how to display a progress bar during execution:
316+
317+
- ``False`` (default): No progress bar
318+
- ``True``: Display progress bar with default settings
319+
- ``ProgressBar`` instance: Display progress bar with custom settings
320+
- ``int``: Display progress bar updating every N iterations (treated as freq parameter)
321+
322+
For advanced customization, create a :py:class:`brainpy.math.ProgressBar` instance:
323+
324+
>>> import brainpy.math as bm
325+
>>> # Custom update frequency
326+
>>> pbar = bm.ProgressBar(freq=10)
327+
>>> result = bm.for_loop(body_fun, operands, progress_bar=pbar)
328+
>>>
329+
>>> # Custom description
330+
>>> pbar = bm.ProgressBar(desc="Processing data")
331+
>>> result = bm.for_loop(body_fun, operands, progress_bar=pbar)
332+
>>>
333+
>>> # Update exactly 20 times during execution
334+
>>> pbar = bm.ProgressBar(count=20)
335+
>>> result = bm.for_loop(body_fun, operands, progress_bar=pbar)
336+
>>>
337+
>>> # Integer shorthand (equivalent to ProgressBar(freq=10))
338+
>>> result = bm.for_loop(body_fun, operands, progress_bar=10)
283339
284340
.. versionadded:: 2.4.2
341+
.. versionchanged:: 2.7.3
342+
Now accepts ProgressBar instances and integers for advanced customization.
285343
dyn_vars: Variable, sequence of Variable, dict
286344
The instances of :py:class:`~.Variable`.
287345
@@ -296,8 +354,6 @@ def for_loop(
296354
.. deprecated:: 2.4.0
297355
No longer need to provide ``child_objs``. This function is capable of automatically
298356
collecting the children objects used in the target ``func``.
299-
unroll_kwargs: dict
300-
The keyword arguments without unrolling.
301357
302358
Returns::
303359
@@ -306,11 +362,45 @@ def for_loop(
306362
"""
307363
if not isinstance(operands, (tuple, list)):
308364
operands = (operands,)
309-
return brainstate.transform.for_loop(
310-
warp_to_no_state_input_output(body_fun),
311-
*operands, reverse=reverse, unroll=unroll,
312-
pbar=brainstate.transform.ProgressBar() if progress_bar else None,
313-
)
365+
366+
# Convert progress_bar to pbar format
367+
pbar = _convert_progress_bar_to_pbar(progress_bar)
368+
369+
# Handle jit parameter
370+
# Note: JAX's scan doesn't support zero-length inputs in disable_jit mode.
371+
# For zero-length inputs, we need to use JIT mode even when jit=False.
372+
should_disable_jit = False
373+
if jit is False:
374+
# Check if any operand has zero length
375+
first_operand = operands[0]
376+
is_zero_length = False
377+
if hasattr(first_operand, 'shape') and len(first_operand.shape) > 0:
378+
is_zero_length = (first_operand.shape[0] == 0)
379+
380+
if is_zero_length:
381+
# Use JIT mode for zero-length inputs to avoid JAX limitation
382+
import warnings
383+
warnings.warn(
384+
"for_loop with jit=False and zero-length input detected. "
385+
"Using JIT mode to avoid JAX's disable_jit limitation with zero-length scans.",
386+
UserWarning
387+
)
388+
else:
389+
should_disable_jit = True
390+
391+
if should_disable_jit:
392+
with jax.disable_jit():
393+
return brainstate.transform.for_loop(
394+
warp_to_no_state_input_output(body_fun),
395+
*operands, reverse=reverse, unroll=unroll,
396+
pbar=pbar,
397+
)
398+
else:
399+
return brainstate.transform.for_loop(
400+
warp_to_no_state_input_output(body_fun),
401+
*operands, reverse=reverse, unroll=unroll,
402+
pbar=pbar,
403+
)
314404

315405

316406
def scan(
@@ -320,7 +410,7 @@ def scan(
320410
reverse: bool = False,
321411
unroll: int = 1,
322412
remat: bool = False,
323-
progress_bar: bool = False,
413+
progress_bar: Union[bool, brainstate.transform.ProgressBar, int] = False,
324414
):
325415
"""``scan`` control flow with :py:class:`~.Variable`.
326416
@@ -359,23 +449,35 @@ def scan(
359449
Optional positive int specifying, in the underlying operation of the
360450
scan primitive, how many scan iterations to unroll within a single
361451
iteration of a loop.
362-
progress_bar: bool
363-
Whether we use the progress bar to report the running progress.
452+
progress_bar: bool, ProgressBar, int
453+
Whether and how to display a progress bar during execution:
454+
455+
- ``False`` (default): No progress bar
456+
- ``True``: Display progress bar with default settings
457+
- ``ProgressBar`` instance: Display progress bar with custom settings
458+
- ``int``: Display progress bar updating every N iterations (treated as freq parameter)
459+
460+
See :py:func:`for_loop` for detailed examples of ProgressBar usage.
364461
365462
.. versionadded:: 2.4.2
463+
.. versionchanged:: 2.7.3
464+
Now accepts ProgressBar instances and integers for advanced customization.
366465
367466
Returns::
368467
369468
outs: Any
370469
The stacked outputs of ``body_fun`` when scanned over the leading axis of the inputs.
371470
"""
471+
# Convert progress_bar to pbar format
472+
pbar = _convert_progress_bar_to_pbar(progress_bar)
473+
372474
return brainstate.transform.scan(
373475
warp_to_no_state_input_output(body_fun),
374476
init=init,
375477
xs=operands,
376478
reverse=reverse,
377479
unroll=unroll,
378-
pbar=brainstate.transform.ProgressBar() if progress_bar else None,
480+
pbar=pbar,
379481
)
380482

381483

0 commit comments

Comments
 (0)