Skip to content

Commit a8605b8

Browse files
committed
refactor: register
1 parent ac5bc8d commit a8605b8

File tree

7 files changed

+122
-106
lines changed

7 files changed

+122
-106
lines changed

mols2grid/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from mols2grid.callbacks import make_popup_callback
55
from mols2grid.dispatch import display, save
66
from mols2grid.molgrid import MolGrid
7-
from mols2grid.select import get_selection, list_grids
7+
from mols2grid.select import get_selection, link_marimo_state, list_grids
88
from mols2grid.utils import is_running_within_streamlit, sdf_to_dataframe
99

1010
if is_running_within_streamlit():

mols2grid/molgrid.py

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import ast
21
import json
32
import os
43
import warnings
@@ -640,9 +639,9 @@ def to_interactive( # noqa: PLR0912
640639
try:
641640
cached_selection = register.get_selection(self._grid_id)
642641
except KeyError:
643-
register._init_grid(self._grid_id)
642+
register.add_grid(self._grid_id)
644643
else:
645-
register._update_current_grid(self._grid_id)
644+
register.current_selection = self._grid_id
646645
if self.cache_selection:
647646
df["cached_checkbox"] = False
648647
df.loc[
@@ -785,40 +784,6 @@ def get_selection(self):
785784
columns=self._extra_columns
786785
)
787786

788-
def get_marimo_selection(self):
789-
"""Returns a marimo state object containing the list of selected indices.
790-
Only available when running in marimo.
791-
792-
Returns
793-
-------
794-
getter
795-
A getter function for the selection state.
796-
Calling it with no arguments returns the current list of selected IDs.
797-
"""
798-
if not is_running_within_marimo():
799-
raise RuntimeError("This method is only available in a marimo notebook.")
800-
if not hasattr(self, "widget"):
801-
raise RuntimeError(
802-
"Please run the `display` method first to render the underlying widget"
803-
)
804-
805-
import marimo as mo
806-
807-
get_state, set_state = mo.state([])
808-
809-
def _on_change(change):
810-
try:
811-
sel = ast.literal_eval(change["new"])
812-
set_state(list(sel.keys()))
813-
except (ValueError, SyntaxError):
814-
pass
815-
816-
if not getattr(self.widget, "_marimo_hooked", False):
817-
self.widget.observe(_on_change, names=["selection"])
818-
self.widget._marimo_hooked = True
819-
820-
return get_state
821-
822787
def filter(self, mask):
823788
"""Filters the grid using a mask (boolean array).
824789

mols2grid/select.py

Lines changed: 74 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,47 +17,110 @@ class SelectionRegister:
1717

1818
def __init__(self):
1919
self.SELECTIONS = {}
20+
self.CALLBACKS = []
21+
self._current_selection = None
2022

21-
def _update_current_grid(self, name):
22-
self.current_selection = name
23+
@property
24+
def current_selection(self):
25+
"""The name of the last updated grid (created or interacted with),
26+
or ``None`` if no grid has been displayed yet."""
27+
return self._current_selection
28+
29+
@current_selection.setter
30+
def current_selection(self, name):
31+
if name is not None and name not in self.SELECTIONS:
32+
raise ValueError(
33+
f"The selection for {name} must be initialized "
34+
"before setting it as the current grid"
35+
)
36+
self._current_selection = name
2337

24-
def _init_grid(self, name):
38+
@current_selection.deleter
39+
def current_selection(self):
40+
self._current_selection = None
41+
42+
def add_grid(self, name):
43+
"""Adds a grid to track selections for."""
2544
overwrite = self.SELECTIONS.get(name, False)
2645
if overwrite and not is_running_within_marimo():
2746
warnings.warn(
2847
f"Overwriting non-empty {name!r} grid selection: {overwrite!s}",
2948
stacklevel=2,
3049
)
3150
self.SELECTIONS[name] = {}
32-
self._update_current_grid(name)
51+
self.current_selection = name
52+
self.add_callback(self._store_selection)
3353

3454
def selection_updated(self, name, event):
35-
self.SELECTIONS[name] = literal_eval(event.new)
36-
self._update_current_grid(name)
55+
"""Callback function linked to the widget."""
56+
self.current_selection = name
57+
selection = literal_eval(event.new)
58+
for callback in self.CALLBACKS:
59+
callback(name, selection)
60+
61+
def _store_selection(self, name, selection):
62+
"""Makes the selection available to the register."""
63+
self.SELECTIONS[name] = selection
64+
65+
def add_callback(self, callback):
66+
"""Add a callback function to be called when the selection is updated.
67+
68+
Parameters
69+
----------
70+
callback : callable
71+
The function to execute when a the selection is updated.
72+
"""
73+
self.CALLBACKS.append(callback)
3774

3875
def get_selection(self, name=None):
3976
"""Returns the selection for a specific MolGrid instance
4077
4178
Parameters
4279
----------
4380
name : str or None
44-
Name of the grid to fetch the selection from. If `None`, the most
81+
Name of the grid to fetch the selection from. If ``None``, the most
4582
recently updated grid is returned
4683
"""
4784
name = self.current_selection if name is None else name
4885
return self.SELECTIONS[name]
4986

87+
def link_marimo_state(self):
88+
"""Link the register to marimo by initializing a ``state`` dict.
89+
When the selection on a grid is updated, the state setter function from marimo
90+
is called to update the state's value.
91+
92+
Returns
93+
-------
94+
get_state, set_state:
95+
The state getter and setter returned by ``marimo.state({})``.
96+
``get_state`` returns a dictionary of all selections where the keys are
97+
grid names and values are the indices selected for that grid.
98+
"""
99+
if not is_running_within_marimo():
100+
raise RuntimeError("This method is only available in a marimo notebook.")
101+
102+
import marimo as mo
103+
104+
get_state, set_state = mo.state({})
105+
106+
def marimo_callback(name, selection):
107+
set_state(lambda value: {**value, name: list(selection)})
108+
109+
self.add_callback(marimo_callback)
110+
return get_state, set_state
111+
50112
def list_grids(self):
51113
"""Returns a list of grid names"""
52-
return list(self.SELECTIONS.keys())
114+
return list(self.SELECTIONS)
53115

54-
def _clear(self):
116+
def clear(self):
55117
"""Clears all selections"""
56-
if hasattr(self, "current_selection"):
57-
del self.current_selection
58118
self.SELECTIONS.clear()
119+
self.CALLBACKS.clear()
120+
del self.current_selection
59121

60122

61123
register = SelectionRegister()
62124
get_selection = register.get_selection
63125
list_grids = register.list_grids
126+
link_marimo_state = register.link_marimo_state

tests/test_interface.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def test_selection_click(driver: CustomDriver, html_doc):
176176
driver.wait_for_img_load()
177177
sel = driver.click_checkbox()
178178
assert sel == {0: "CCC(C)CC"}
179-
register._clear()
179+
register.clear()
180180

181181

182182
def test_export_csv(driver: CustomDriver, html_doc):
@@ -204,11 +204,11 @@ def test_export_csv(driver: CustomDriver, html_doc):
204204
)
205205
assert content == expected
206206
csv_file.unlink()
207-
register._clear()
207+
register.clear()
208208

209209

210210
def test_selection_with_cache_check_and_uncheck(driver: CustomDriver, df):
211-
register._init_grid("cached_sel")
211+
register.add_grid("cached_sel")
212212
event = SimpleNamespace(new='{0: "CCC(C)CC"}')
213213
register.selection_updated("cached_sel", event)
214214
grid = get_grid(df, name="cached_sel", cache_selection=True)
@@ -219,7 +219,7 @@ def test_selection_with_cache_check_and_uncheck(driver: CustomDriver, df):
219219
assert sel == {0: "CCC(C)CC"}
220220
empty_sel = driver.click_checkbox(is_empty=True)
221221
assert empty_sel is True
222-
register._clear()
222+
register.clear()
223223

224224

225225
def test_selection_check_uncheck_invert(driver: CustomDriver, html_doc):
@@ -243,7 +243,7 @@ def test_selection_check_uncheck_invert(driver: CustomDriver, html_doc):
243243
driver.grid_action("invert")
244244
sel = driver.wait_for_selection(is_empty=False)
245245
assert len(sel) == 29
246-
register._clear()
246+
register.clear()
247247

248248

249249
@pytest.mark.parametrize("prerender", [True, False])

tests/test_marimo_integration.py

Lines changed: 30 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
import sys
2+
from types import SimpleNamespace
23
from unittest.mock import MagicMock, patch
34

45
import pandas as pd
56
import pytest
67

78
from mols2grid import MolGrid
9+
from mols2grid.select import register
810
from mols2grid.utils import is_running_within_marimo
911
from mols2grid.widget import MolGridWidget
1012

1113

1214
@pytest.fixture
13-
def mock_marimo_module():
14-
with patch.dict(sys.modules, {"marimo": MagicMock()}):
15-
yield
15+
def mock_marimo_module(monkeypatch: pytest.MonkeyPatch):
16+
monkeypatch.setattr(sys, "modules", {**sys.modules, "marimo": MagicMock()})
1617

1718

1819
@pytest.mark.usefixtures("mock_marimo_module")
@@ -21,11 +22,9 @@ def test_is_running_within_marimo_true():
2122

2223

2324
def test_is_running_within_marimo_false():
24-
# Ensure marimo is not in sys.modules for this test
25-
with patch.dict(sys.modules):
26-
if "marimo" in sys.modules:
27-
del sys.modules["marimo"]
28-
assert is_running_within_marimo() is False
25+
if "marimo" in sys.modules:
26+
assert isinstance(sys.modules["marimo"], MagicMock)
27+
assert is_running_within_marimo() is False
2928

3029

3130
@pytest.fixture
@@ -52,13 +51,6 @@ def test_display_in_marimo(grid_fixture):
5251
assert result == mock_anywidget.return_value
5352

5453

55-
@pytest.mark.usefixtures("mock_marimo_module")
56-
def test_get_marimo_selection_before_rendering_raises(grid_fixture):
57-
_, mg = grid_fixture
58-
with pytest.raises(RuntimeError, match="run the `display` method first"):
59-
mg.get_marimo_selection()
60-
61-
6254
@pytest.mark.usefixtures("mock_marimo_module")
6355
def test_get_selection_state_inside_marimo(grid_fixture):
6456
_, mg = grid_fixture
@@ -71,57 +63,53 @@ def test_get_selection_state_inside_marimo(grid_fixture):
7163
"marimo.state", return_value=(mock_get_state, mock_set_state)
7264
) as mock_state:
7365
# Call get_marimo_selection
74-
state_getter = mg.get_marimo_selection()
66+
state_getter, _ = register.link_marimo_state()
7567

7668
# Check if marimo.state was called with empty list
77-
mock_state.assert_called_once_with([])
78-
79-
# Check if _marimo_hooked is set
80-
assert getattr(mg.widget, "_marimo_hooked", False) is True
69+
mock_state.assert_called_once_with({})
8170

8271
# Verify return value
8372
assert state_getter == mock_get_state
8473

8574

86-
def test_get_selection_state_outside_marimo(grid_fixture):
87-
_, mg = grid_fixture
88-
75+
def test_get_selection_state_outside_marimo():
8976
# Ensure marimo is not in sys.modules
9077
with patch.dict(sys.modules):
9178
if "marimo" in sys.modules:
9279
del sys.modules["marimo"]
9380

9481
with pytest.raises(RuntimeError, match="only available in a marimo notebook"):
95-
mg.get_marimo_selection()
82+
register.link_marimo_state()
9683

9784

9885
@pytest.mark.usefixtures("mock_marimo_module")
99-
def test_selection_state_update_logic(grid_fixture):
86+
def test_selection_state_update_logic(grid_fixture, monkeypatch: pytest.MonkeyPatch):
10087
_, mg = grid_fixture
101-
mg.render()
102-
10388
mock_set_state = MagicMock()
10489
with (
10590
patch("marimo.state", return_value=(MagicMock(), mock_set_state)),
106-
patch.object(mg.widget, "observe") as mock_observe,
91+
patch.object(MolGridWidget, "observe") as mock_observe,
10792
):
10893
# Inspect the observe call to capture the callback
109-
mg.get_marimo_selection()
94+
register.link_marimo_state()
95+
mock_callback = MagicMock(wraps=register.CALLBACKS[-1])
96+
monkeypatch.setattr(register, "CALLBACKS", [mock_callback])
11097

98+
mg.render()
11199
# Verify observe was called
112100
mock_observe.assert_called()
113-
args, _ = mock_observe.call_args
114-
callback = args[0]
115101

116102
# Simulate event with valid selection
117-
# The widget returns a string representation of a dict
118-
new_selection = {1: "C", 2: "CC"}
119-
event = {"new": str(new_selection)}
120-
121-
callback(event)
122-
mock_set_state.assert_called_with([1, 2])
123-
124-
# Test invalid input (should pass silently)
125-
mock_set_state.reset_mock()
126-
callback({"new": "invalid json"})
127-
mock_set_state.assert_not_called()
103+
event_values = {1: "C", 2: "CC"}
104+
event = SimpleNamespace(new=str(event_values))
105+
register.selection_updated("default", event)
106+
107+
# check callback was called with expected selection
108+
mock_callback.assert_called_with("default", event_values)
109+
110+
# check inner lambda works as expected: given a current state with
111+
# mol 42 selected, and new state where only 1 and 2 are selected,
112+
# 42 should disappear and only 1, 2 remain
113+
lambda_setter = mock_set_state.call_args[0][0]
114+
result = lambda_setter({"default": [42]})
115+
assert result == {"default": [1, 2]}

tests/test_molgrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def test_get_selection(small_df):
238238
assert register.current_selection == "grid"
239239
assert grid.get_selection().equals(small_df.head(1))
240240
assert other.get_selection().equals(small_df.head(0)) # empty dataframe
241-
register._clear()
241+
register.clear()
242242

243243

244244
def test_save(grid, tmp_path):

0 commit comments

Comments
 (0)