Skip to content

Commit 8cdf0de

Browse files
committed
lazily import optional deps so that they are actually optional
1 parent 74f54f8 commit 8cdf0de

File tree

12 files changed

+373
-277
lines changed

12 files changed

+373
-277
lines changed

justfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ lint:
2020

2121
# run tests
2222
test *FILES:
23-
uv run --group dev pytest {{FILES}}
23+
uv run --group dev --all-extras pytest {{FILES}}
2424

2525
# include --dev-addr localhost:8001 to avoid conflicts with other mkdocs instances
2626
# serve docs for live editing

mismo/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
import importlib.metadata
4+
import warnings
5+
36
from mismo import arrays as arrays
47
from mismo import cluster as cluster
58
from mismo import compare as compare
@@ -45,4 +48,8 @@
4548
from mismo.types import UnionTable as UnionTable
4649
from mismo.types import Updates as Updates
4750

48-
__version__ = "0.2.0"
51+
try:
52+
__version__ = importlib.metadata.version(__name__)
53+
except importlib.metadata.PackageNotFoundError as e:
54+
warnings.warn(f"Could not determine version of {__name__}\n{e!s}", stacklevel=2)
55+
__version__ = "unknown"

mismo/_upset.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
from collections.abc import Iterable
44
from itertools import combinations
5-
from typing import Any
5+
from typing import TYPE_CHECKING, Any
66

7-
import altair as alt
87
import pandas as pd
98

9+
if TYPE_CHECKING:
10+
import altair as alt
11+
1012

1113
def combos(set_names: Iterable[str]) -> frozenset[frozenset[str]]:
1214
"""
@@ -64,6 +66,8 @@ def upset_chart(data: Any) -> alt.Chart:
6466
Chart
6567
An Altair chart.
6668
"""
69+
import altair as alt
70+
6771
df = _to_df(data)
6872
longer = _pivot_longer(df)
6973
sets = (
@@ -184,6 +188,8 @@ def _pivot_longer(df: pd.DataFrame) -> pd.DataFrame:
184188

185189

186190
def _intersection_plot(base: alt.Chart, sets: Iterable[str], x) -> alt.Chart:
191+
import altair as alt
192+
187193
sets = list(sets)
188194
intersection_base = base.transform_filter(alt.datum.set == sets[0]).encode(
189195
x=x,
@@ -200,6 +206,8 @@ def _intersection_plot(base: alt.Chart, sets: Iterable[str], x) -> alt.Chart:
200206

201207

202208
def _matrix_plot(base: alt.Chart, sets: Iterable[str], x, y) -> alt.Chart:
209+
import altair as alt
210+
203211
sets = list(sets)
204212
matrix_circle_bg = base.mark_circle(size=100, color="lightgray", opacity=1).encode(
205213
x=x,
@@ -214,6 +222,8 @@ def _matrix_plot(base: alt.Chart, sets: Iterable[str], x, y) -> alt.Chart:
214222

215223

216224
def _set_plot(base: alt.Chart, y) -> alt.Chart:
225+
import altair as alt
226+
217227
set_base = base.transform_filter(alt.datum.is_intersect).encode(
218228
x=alt.X("sum(intersection_size):Q", title="Number of Pairs"),
219229
y=y.axis(None),

mismo/cluster/_dashboard.py

Lines changed: 6 additions & 217 deletions
Original file line numberDiff line numberDiff line change
@@ -1,158 +1,15 @@
11
from __future__ import annotations
22

3-
import json
4-
from typing import TYPE_CHECKING, Any, Iterable, Mapping
3+
from typing import TYPE_CHECKING, Iterable, Mapping
54

6-
import ibis
7-
from ibis import _
85
from ibis.expr import types as ir
9-
from IPython.display import display
10-
import ipywidgets
11-
import rich
12-
import solara
136

14-
from mismo import _util
157
from mismo._datasets import Datasets
16-
from mismo.cluster._connected_components import connected_components
178

189
if TYPE_CHECKING:
19-
import ipycytoscape # type: ignore
10+
import solara
2011

2112

22-
def cytoscape_widget(
23-
tables: Datasets | ir.Table | Iterable[ir.Table] | Mapping[str, ir.Table],
24-
links: ir.Table,
25-
) -> ipycytoscape.CytoscapeWidget:
26-
"""Make a ipycytoscape.CytoscapeWidget that shows records and links.
27-
28-
This shows ALL the supplied records and links, so be careful,
29-
you probably want to filter them down first.
30-
31-
Parameters
32-
----------
33-
tables :
34-
Table(s) of records with at least the column `record_id`.
35-
links :
36-
A table of edges with at least columns
37-
(record_id_l, record_id_r) and optionally other columns.
38-
The column `width` is used to set the width of the edges.
39-
If not given, it is determined from the column `odds`, if
40-
present, or set to 5 otherwise.
41-
The column `opacity` is used to set the opacity of the edges.
42-
If not given, it is set to 0.5.
43-
"""
44-
with _util.optional_import("ipycytoscape"):
45-
import ipycytoscape # type: ignore
46-
47-
ds = Datasets(tables)
48-
links = _filter_links(links, ds)
49-
graph = {"nodes": _nodes_to_json(ds), "edges": _edges_to_json(links)}
50-
cyto = ipycytoscape.CytoscapeWidget(graph, layout={"name": "fcose"})
51-
style = [
52-
*cyto.get_style(),
53-
{
54-
"selector": "node",
55-
"css": {
56-
"label": "data(label)",
57-
"font-size": 8,
58-
"color": "data(color)",
59-
"width": 15,
60-
"height": 15,
61-
},
62-
},
63-
{
64-
"selector": "edge",
65-
"css": {
66-
"curve-style": "straight",
67-
"width": "data(width)",
68-
"opacity": "data(opacity)",
69-
},
70-
},
71-
]
72-
cyto.set_style(style)
73-
return cyto
74-
75-
76-
def _filter_links(links: ir.Table, ds: Datasets) -> ir.Table:
77-
if "record_id_l" not in links.columns:
78-
raise ValueError("links must have a record_id_l column")
79-
if "record_id_r" not in links.columns:
80-
raise ValueError("links must have a record_id_r column")
81-
links = links.filter(
82-
_.record_id_l.isin(ds.all_record_ids()),
83-
_.record_id_r.isin(ds.all_record_ids()),
84-
)
85-
return links
86-
87-
88-
def _nodes_to_json(ds: Datasets) -> list[dict[str, Any]]:
89-
colors = ["blue", "red", "green"]
90-
cmap = dict(zip(ds.names, colors[: len(ds)]))
91-
92-
def f(name: str, t: ir.Table) -> ir.Table:
93-
m = {
94-
"dataset": ibis.literal(name),
95-
"id": _.record_id.cast(str),
96-
}
97-
if "label" not in t.columns:
98-
m["label"] = name + ":" + _.record_id.cast(str)
99-
if "color" not in t.columns:
100-
m["color"] = ibis.literal(cmap[name])
101-
return t.mutate(m)
102-
103-
return _to_json(*ds.map(f))
104-
105-
106-
def _edges_to_json(links: ir.Table) -> list[dict[str, Any]]:
107-
def _ensure_has_width(links: ir.Table) -> ir.Table:
108-
if "width" in links.columns:
109-
return links
110-
if "odds" not in links.columns:
111-
return links.mutate(width=5)
112-
log_odds = _.odds.log10()
113-
log_odds_fraction = log_odds / log_odds.max()
114-
width = 10 * log_odds_fraction
115-
return links.mutate(width=width)
116-
117-
def _ensure_has_opacity(links: ir.Table) -> ir.Table:
118-
if "opacity" in links.columns:
119-
return links
120-
return links.mutate(opacity=0.5)
121-
122-
if "source" in links.columns:
123-
raise ValueError("links must not have a source column")
124-
if "target" in links.columns:
125-
raise ValueError("links must not have a target column")
126-
links = links.mutate(source="record_id_l", target="record_id_r")
127-
links = _ensure_has_width(links)
128-
links = _ensure_has_opacity(links)
129-
return _to_json(links)
130-
131-
132-
def _to_json(*tables: ir.Table) -> list[dict[str, Any]]:
133-
records = []
134-
for t in tables:
135-
# include default_handler to avoid https://stackoverflow.com/a/60492211/5156887
136-
json_str = t.to_pandas().to_json(
137-
orient="records", default_handler=str, date_format="iso"
138-
)
139-
records.extend(json.loads(json_str))
140-
return records
141-
142-
143-
def _render_to_html(obj) -> str:
144-
console = rich.console.Console(record=True)
145-
with console.capture():
146-
console.print(obj)
147-
template = """
148-
<pre style="font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace; line-height: 1.2em">
149-
<code style="font-family:inherit">{code}</code>
150-
</pre>
151-
""" # noqa: E501
152-
return console.export_html(code_format=template, inline_styles=True)
153-
154-
155-
@solara.component
15613
def cluster_dashboard(
15714
ds: Datasets | ir.Table | Iterable[ir.Table] | Mapping[str, ir.Table],
15815
links: ir.Table,
@@ -179,50 +36,11 @@ def cluster_dashboard(
17936
The column `opacity` is used to set the opacity of the edges.
18037
If not given, it is set to 0.5.
18138
"""
182-
ds = Datasets(ds)
183-
184-
def get_output():
185-
out = ipywidgets.Output()
186-
out.append_display_data(ipywidgets.HTML("Select a node or edge..."))
187-
return out
188-
189-
info = solara.use_memo(get_output)
39+
from mismo.cluster._dashboard_internal import cluster_dashboard
19040

191-
def make_cyto() -> tuple[Any, dict[Any, dict]]:
192-
cyto = cytoscape_widget(ds, links)
193-
lookup = {r["record_id"]: r for r in _nodes_to_json(ds)}
41+
return cluster_dashboard(ds, links)
19442

195-
def on_record(node: dict[str, Any]):
196-
info.clear_output()
197-
html_widget = ipywidgets.HTML(_render_to_html(node["data"]))
198-
with info:
199-
display(html_widget)
20043

201-
def on_edge(edge: dict[str, Any]):
202-
s = edge["data"]["record_id_l"]
203-
t = edge["data"]["record_id_r"]
204-
record_l = lookup[s]
205-
record_r = lookup[t]
206-
box = ipywidgets.HBox(
207-
[
208-
ipywidgets.HTML(_render_to_html(record_l)),
209-
ipywidgets.HTML(_render_to_html(record_r)),
210-
]
211-
)
212-
info.clear_output()
213-
with info:
214-
display(box)
215-
216-
cyto.on("node", "click", on_record)
217-
cyto.on("edge", "click", on_edge)
218-
return cyto
219-
220-
cyto = solara.use_memo(make_cyto, [ds, links])
221-
222-
return solara.Column([cyto, info])
223-
224-
225-
@solara.component
22644
def clusters_dashboard(
22745
tables: Datasets | ir.Table | Iterable[ir.Table] | Mapping[str, ir.Table],
22846
links: ir.Table,
@@ -232,35 +50,6 @@ def clusters_dashboard(
23250
Pass the entire dataset and the links between records,
23351
and use this to filter down to a particular cluster.
23452
"""
53+
from mismo.cluster._dashboard_internal import clusters_dashboard
23554

236-
def get_ds():
237-
ds = Datasets(tables)
238-
li = _filter_links(links, ds)
239-
ds = connected_components(records=ds, links=li)
240-
ds = ds.cache()
241-
return ds
242-
243-
ds = solara.use_memo(get_ds, [tables, links])
244-
245-
def get_components():
246-
return (
247-
ds.unioned()
248-
.select("component")
249-
.distinct()
250-
.order_by("component")
251-
.component.execute()
252-
.to_list()
253-
)
254-
255-
all_components = solara.use_memo(get_components, [ds])
256-
257-
component = solara.use_reactive(all_components[0] if len(all_components) else None)
258-
component_selector = solara.Select(
259-
"Component", values=all_components, value=component
260-
)
261-
262-
def get_subgraph():
263-
return ds.map(lambda name, t: t.filter(_.component == component.value))
264-
265-
subgraph = solara.use_memo(get_subgraph, [ds, component.value])
266-
return solara.Column([component_selector, cluster_dashboard(subgraph, links)])
55+
return clusters_dashboard(tables, links)

0 commit comments

Comments
 (0)