|
| 1 | +from argparse import Namespace |
| 2 | +from types import SimpleNamespace |
| 3 | +from unittest.mock import MagicMock, patch |
| 4 | + |
| 5 | +import pytest |
| 6 | + |
| 7 | +from mflux.callbacks.callback_manager import CallbackManager |
| 8 | + |
| 9 | + |
| 10 | +@pytest.mark.fast |
| 11 | +def test_register_memory_saver_sets_mlx_cache_limit_without_low_ram(): |
| 12 | + args = Namespace(low_ram=False, mlx_cache_limit_gb=2.5) |
| 13 | + model = SimpleNamespace(callbacks=MagicMock()) |
| 14 | + |
| 15 | + with ( |
| 16 | + patch("mflux.callbacks.callback_manager.mx.set_cache_limit") as mock_set_cache_limit, |
| 17 | + patch("mflux.callbacks.callback_manager.mx.clear_cache") as mock_clear_cache, |
| 18 | + patch("mflux.callbacks.callback_manager.mx.reset_peak_memory") as mock_reset_peak_memory, |
| 19 | + ): |
| 20 | + memory_saver = CallbackManager._register_memory_saver(args=args, model=model) |
| 21 | + |
| 22 | + assert memory_saver is None |
| 23 | + mock_set_cache_limit.assert_called_once_with(int(2.5 * (1000**3))) |
| 24 | + mock_clear_cache.assert_called_once() |
| 25 | + mock_reset_peak_memory.assert_called_once() |
| 26 | + model.callbacks.register.assert_not_called() |
| 27 | + |
| 28 | + |
| 29 | +@pytest.mark.fast |
| 30 | +def test_register_memory_saver_uses_mlx_cache_limit_for_low_ram_mode(): |
| 31 | + args = Namespace(low_ram=True, mlx_cache_limit_gb=3.0, seed=[42], image_path=None) |
| 32 | + model = SimpleNamespace(callbacks=MagicMock()) |
| 33 | + mocked_memory_saver = object() |
| 34 | + |
| 35 | + with patch("mflux.callbacks.callback_manager.MemorySaver", return_value=mocked_memory_saver) as mock_memory_saver: |
| 36 | + memory_saver = CallbackManager._register_memory_saver(args=args, model=model) |
| 37 | + |
| 38 | + assert memory_saver is mocked_memory_saver |
| 39 | + _, kwargs = mock_memory_saver.call_args |
| 40 | + assert kwargs["cache_limit_bytes"] == int(3.0 * (1000**3)) |
| 41 | + model.callbacks.register.assert_called_once_with(mocked_memory_saver) |
0 commit comments