Skip to content

Commit f236371

Browse files
icelaglaceCursor Assistantcursoragent
authored
Add cache-limit option (#351)
Co-authored-by: Cursor Assistant <assistant@cursor.com> Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 697a4d2 commit f236371

File tree

5 files changed

+116
-1
lines changed

5 files changed

+116
-1
lines changed

src/mflux/callbacks/callback_manager.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from argparse import Namespace
22

3+
import mlx.core as mx
4+
35
from mflux.callbacks.instances.battery_saver import BatterySaver
46
from mflux.callbacks.instances.canny_saver import CannyImageSaver
57
from mflux.callbacks.instances.depth_saver import DepthImageSaver
@@ -62,12 +64,28 @@ def _register_stepwise_handler(args: Namespace, model, latent_creator) -> None:
6264
@staticmethod
6365
def _register_memory_saver(args: Namespace, model) -> MemorySaver | None:
6466
memory_saver = None
67+
cache_limit_bytes = CallbackManager._resolve_cache_limit_bytes(getattr(args, "mlx_cache_limit_gb", None))
6568
if args.low_ram:
6669
seeds = getattr(args, "seed", []) or []
6770
images = getattr(args, "image_path", [])
6871
if not isinstance(images, list):
6972
images = [images] if images is not None else []
7073
keep_transformer = len(seeds) > 1 or len(images) > 1
71-
memory_saver = MemorySaver(model=model, keep_transformer=keep_transformer, args=args)
74+
memory_saver = MemorySaver(
75+
model=model,
76+
keep_transformer=keep_transformer,
77+
cache_limit_bytes=cache_limit_bytes or 1000**3,
78+
args=args,
79+
)
7280
model.callbacks.register(memory_saver)
81+
elif cache_limit_bytes is not None:
82+
mx.set_cache_limit(cache_limit_bytes)
83+
mx.clear_cache()
84+
mx.reset_peak_memory()
7385
return memory_saver
86+
87+
@staticmethod
88+
def _resolve_cache_limit_bytes(mlx_cache_limit_gb: float | None) -> int | None:
89+
if mlx_cache_limit_gb is None:
90+
return None
91+
return int(mlx_cache_limit_gb * (1000**3))

src/mflux/cli/parser/parsers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@ def int_or_special_value(value) -> int | scale_factor.ScaleFactor:
3535
)
3636

3737

38+
def positive_float(value: str) -> float:
39+
try:
40+
parsed = float(value)
41+
except ValueError:
42+
raise argparse.ArgumentTypeError(f"'{value}' is not a valid number")
43+
if parsed <= 0:
44+
raise argparse.ArgumentTypeError(f"'{value}' must be > 0")
45+
return parsed
46+
47+
3848
# fmt: off
3949
class CommandLineParser(argparse.ArgumentParser):
4050

@@ -52,6 +62,7 @@ def __init__(self, *args, **kwargs):
5262
def add_general_arguments(self) -> None:
5363
self.add_argument("--battery-percentage-stop-limit", "-B", type=lambda v: max(min(int(v), 99), 1), default=ui_defaults.BATTERY_PERCENTAGE_STOP_LIMIT, help=f"On Macs powered by battery, stop image generation when battery reaches this percentage. Default: {ui_defaults.BATTERY_PERCENTAGE_STOP_LIMIT}")
5464
self.add_argument("--low-ram", action="store_true", help="Enable low-RAM mode to reduce memory usage (may impact performance).")
65+
self.add_argument("--mlx-cache-limit-gb", type=positive_float, default=None, help="Limit MLX cache size in GB without enabling full low-RAM mode (e.g. 8 or 16).")
5566

5667
def add_seedvr2_upscale_arguments(self) -> None:
5768
self.supports_image_generation = True

src/mflux/models/common/README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ This README covers stable, shared patterns. For model-specific usage, see each m
1818
- [Metadata reuse](#metadata-reuse)
1919
- [Metadata inspection](#metadata-inspection)
2020
- [Resource and inspection options](#resource-and-inspection-options)
21+
- [MLX cache limit](#mlx-cache-limit)
2122
- [Cache locations](#cache-locations)
2223
- [Capabilities by task](#capabilities-by-task)
2324
- [Command reference](#command-reference)
@@ -522,6 +523,24 @@ image.save("image.png")
522523

523524
---
524525

526+
## MLX cache limit
527+
528+
Use `--mlx-cache-limit-gb` to cap MLX cache usage.
529+
530+
```sh
531+
mflux-generate-z-image-turbo \
532+
--model z-image-turbo \
533+
--steps 9 \
534+
--prompt "a portrait" \
535+
--mlx-cache-limit-gb 2.5
536+
```
537+
538+
- Value must be positive (`> 0`).
539+
- The value is converted internally using decimal gigabytes (`GB * 1000^3`).
540+
- Works in both normal mode and `--low-ram` mode.
541+
542+
---
543+
525544
## Cache locations
526545

527546
- **MFLUX cache**: set `MFLUX_CACHE_DIR` to override the default (`~/Library/Caches/mflux` on macOS, `~/.cache/mflux` on Linux).
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from unittest.mock import patch
2+
3+
import pytest
4+
5+
from mflux.cli.parser.parsers import CommandLineParser
6+
7+
8+
@pytest.fixture
9+
def parser() -> CommandLineParser:
10+
parser = CommandLineParser(description="Parser for MLX cache limit flag tests.")
11+
parser.add_general_arguments()
12+
return parser
13+
14+
15+
@pytest.mark.fast
16+
def test_mlx_cache_limit_gb_parses(parser: CommandLineParser):
17+
with patch("sys.argv", ["mflux-generate", "--mlx-cache-limit-gb", "8"]):
18+
args = parser.parse_args()
19+
assert args.mlx_cache_limit_gb == 8.0
20+
21+
22+
@pytest.mark.fast
23+
def test_mlx_cache_limit_gb_rejects_non_positive_value(parser: CommandLineParser):
24+
with patch("sys.argv", ["mflux-generate", "--mlx-cache-limit-gb", "0"]):
25+
with pytest.raises(SystemExit):
26+
parser.parse_args()
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from argparse import Namespace
2+
from types import SimpleNamespace
3+
from unittest.mock import MagicMock, patch
4+
5+
import pytest
6+
7+
from mflux.callbacks.callback_manager import CallbackManager
8+
9+
10+
@pytest.mark.fast
11+
def test_register_memory_saver_sets_mlx_cache_limit_without_low_ram():
12+
args = Namespace(low_ram=False, mlx_cache_limit_gb=2.5)
13+
model = SimpleNamespace(callbacks=MagicMock())
14+
15+
with (
16+
patch("mflux.callbacks.callback_manager.mx.set_cache_limit") as mock_set_cache_limit,
17+
patch("mflux.callbacks.callback_manager.mx.clear_cache") as mock_clear_cache,
18+
patch("mflux.callbacks.callback_manager.mx.reset_peak_memory") as mock_reset_peak_memory,
19+
):
20+
memory_saver = CallbackManager._register_memory_saver(args=args, model=model)
21+
22+
assert memory_saver is None
23+
mock_set_cache_limit.assert_called_once_with(int(2.5 * (1000**3)))
24+
mock_clear_cache.assert_called_once()
25+
mock_reset_peak_memory.assert_called_once()
26+
model.callbacks.register.assert_not_called()
27+
28+
29+
@pytest.mark.fast
30+
def test_register_memory_saver_uses_mlx_cache_limit_for_low_ram_mode():
31+
args = Namespace(low_ram=True, mlx_cache_limit_gb=3.0, seed=[42], image_path=None)
32+
model = SimpleNamespace(callbacks=MagicMock())
33+
mocked_memory_saver = object()
34+
35+
with patch("mflux.callbacks.callback_manager.MemorySaver", return_value=mocked_memory_saver) as mock_memory_saver:
36+
memory_saver = CallbackManager._register_memory_saver(args=args, model=model)
37+
38+
assert memory_saver is mocked_memory_saver
39+
_, kwargs = mock_memory_saver.call_args
40+
assert kwargs["cache_limit_bytes"] == int(3.0 * (1000**3))
41+
model.callbacks.register.assert_called_once_with(mocked_memory_saver)

0 commit comments

Comments
 (0)