Skip to content

Commit 0b831d4

Browse files
authored
Merge pull request #289 from neo4j/update-widget-data-util
update widget data util
2 parents c5503e3 + d0ec689 commit 0b831d4

File tree

9 files changed

+465
-402
lines changed

9 files changed

+465
-402
lines changed

changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
## New features
66

7+
- Add convenience method `add_data` and `remove_data` to `GraphWidget`.
8+
79
## Bug fixes
810

911
- Fixed a bug with the theme detection inn VSCode.

examples/getting-started.ipynb

Lines changed: 327 additions & 368 deletions
Large diffs are not rendered by default.

justfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ py-sync:
22
cd python-wrapper && uv sync --group dev --group docs --group notebook --extra pandas --extra neo4j --extra gds --extra snowflake
33

44
py-style:
5+
just py-sync
56
./scripts/makestyle.sh && ./scripts/checkstyle.sh
67

78
py-test:

python-wrapper/src/neo4j_viz/node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .node_size import RealNumber
1010
from .options import CaptionAlignment
1111

12-
NodeIdType = Union[str, int]
12+
NodeIdType = str | int
1313

1414

1515
def create_aliases(field_name: str) -> AliasChoices:

python-wrapper/src/neo4j_viz/relationship.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
from .options import CaptionAlignment
1111

12+
RelationshipIdType = str | int
13+
1214

1315
def create_aliases(field_name: str) -> AliasChoices:
1416
valid_names = [field_name]
@@ -43,7 +45,7 @@ class Relationship(
4345
"""
4446

4547
#: Unique identifier for the relationship
46-
id: Union[str, int] = Field(
48+
id: RelationshipIdType = Field(
4749
default_factory=lambda: uuid4().hex, description="Unique identifier for the relationship"
4850
)
4951
#: Node ID where the relationship points from

python-wrapper/src/neo4j_viz/visualization_graph.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def _build_render_options(
8888
self,
8989
layout: Layout | None,
9090
layout_options: dict[str, Any] | LayoutOptions | None,
91-
renderer: Renderer,
91+
renderer: Renderer | str,
9292
pan_position: tuple[float, float] | None,
9393
initial_zoom: float | None,
9494
min_zoom: float,
@@ -105,6 +105,9 @@ def _build_render_options(
105105
"overriding `max_allowed_nodes`, but rendering could then take a long time"
106106
)
107107

108+
if isinstance(renderer, str):
109+
renderer = Renderer(renderer)
110+
108111
Renderer.check(renderer, num_nodes)
109112

110113
if not layout:
@@ -133,7 +136,7 @@ def render(
133136
self,
134137
layout: Layout | None = None,
135138
layout_options: dict[str, Any] | LayoutOptions | None = None,
136-
renderer: Renderer = Renderer.CANVAS,
139+
renderer: Renderer | str = Renderer.CANVAS,
137140
width: str = "100%",
138141
height: str = "600px",
139142
pan_position: tuple[float, float] | None = None,
@@ -207,7 +210,7 @@ def render_widget(
207210
self,
208211
layout: Layout | None = None,
209212
layout_options: dict[str, Any] | LayoutOptions | None = None,
210-
renderer: Renderer = Renderer.CANVAS,
213+
renderer: Renderer | str = Renderer.CANVAS,
211214
width: str = "100%",
212215
height: str = "600px",
213216
pan_position: tuple[float, float] | None = None,

python-wrapper/src/neo4j_viz/widget.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import anywidget
88
import traitlets
99

10-
from .node import Node
10+
from .node import Node, NodeIdType
1111
from .options import RenderOptions
12-
from .relationship import Relationship
12+
from .relationship import Relationship, RelationshipIdType
1313

1414

1515
def _serialize_entity(entity: Union[Node, Relationship]) -> dict[str, Any]:
@@ -79,3 +79,72 @@ def from_graph_data(
7979
options=options.to_js_options() if options else {},
8080
theme=theme,
8181
)
82+
83+
def add_data(
84+
self, nodes: Node | list[Node] | None = None, relationships: Relationship | list[Relationship] | None = None
85+
) -> None:
86+
"""
87+
Add nodes or relationships to the graph widget.
88+
89+
Parameters
90+
-----------
91+
nodes:
92+
Nodes to add to the graph widget.
93+
relationships:
94+
Relationships to add to the graph widget.
95+
"""
96+
if isinstance(nodes, Node):
97+
nodes = [nodes]
98+
if isinstance(relationships, Relationship):
99+
relationships = [relationships]
100+
101+
if nodes:
102+
self.nodes = self.nodes + [_serialize_entity(n) for n in nodes]
103+
if relationships:
104+
self.relationships = self.relationships + [_serialize_entity(r) for r in relationships]
105+
106+
def remove_data(
107+
self,
108+
nodes: Node | list[Node | NodeIdType] | NodeIdType | None = None,
109+
relationships: Relationship | list[Relationship | RelationshipIdType] | RelationshipIdType | None = None,
110+
) -> None:
111+
"""
112+
Remove nodes or relationships from the graph widget.
113+
114+
Parameters
115+
-----------
116+
nodes:
117+
Nodes to remove from the graph widget.
118+
relationships:
119+
Relationships to remove from the graph widget.
120+
"""
121+
if isinstance(nodes, Node):
122+
node_ids_to_remove = {nodes.id}
123+
elif isinstance(nodes, NodeIdType):
124+
node_ids_to_remove = {nodes}
125+
elif nodes is None:
126+
node_ids_to_remove = set()
127+
else:
128+
node_ids_to_remove = {n.id if isinstance(n, Node) else n for n in nodes}
129+
130+
if isinstance(relationships, Relationship):
131+
rel_ids_to_remove = {relationships.id}
132+
elif isinstance(relationships, RelationshipIdType):
133+
rel_ids_to_remove = {relationships}
134+
elif relationships is None:
135+
rel_ids_to_remove = set()
136+
else:
137+
rel_ids_to_remove = {r.id if isinstance(r, Relationship) else r for r in relationships}
138+
139+
if node_ids_to_remove:
140+
self.nodes = [n for n in self.nodes if n["id"] not in node_ids_to_remove]
141+
142+
def keep_rel(r: dict[str, Any]) -> bool:
143+
return (
144+
r["id"] not in rel_ids_to_remove
145+
and r["from"] not in node_ids_to_remove
146+
and r["to"] not in node_ids_to_remove
147+
)
148+
149+
if rel_ids_to_remove:
150+
self.relationships = [r for r in self.relationships if keep_rel(r)]

python-wrapper/tests/test_widget.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,33 @@ def test_replace_all_data(self) -> None:
182182
assert len(widget.relationships) == 2
183183
assert widget.nodes[0]["id"] == "x1"
184184

185+
def test_add_data(self) -> None:
186+
"""Test adding new data to existing graph."""
187+
nodes = [Node(id="n1"), Node(id="n2")]
188+
rels = [Relationship(source="n1", target="n2")]
189+
widget = GraphWidget.from_graph_data(nodes, rels)
190+
191+
widget.add_data(Node(id="x1"), Relationship(source="x1", target="x2"))
192+
193+
assert len(widget.nodes) == 3
194+
assert len(widget.relationships) == 2
195+
196+
def test_remove_data(self) -> None:
197+
"""Test removing data from the graph."""
198+
node_1 = Node(id="n1")
199+
nodes = [node_1, Node(id="n2"), Node(id="n3")]
200+
rels = [
201+
Relationship(source="n1", target="n2"),
202+
Relationship(id=42, source="n2", target="n1"),
203+
Relationship(source="n2", target="n1"), # detach delete
204+
Relationship(id=43, source="n3", target="n3"),
205+
]
206+
widget = GraphWidget.from_graph_data(nodes, rels)
207+
208+
widget.remove_data(nodes=[node_1, "n2"], relationships=[rels[0], "42"])
209+
assert {n["id"] for n in widget.nodes} == {"n3"}
210+
assert {r["id"] for r in widget.relationships} == {"43"}
211+
185212

186213
render_widget_cases = {
187214
"default": {},

0 commit comments

Comments
 (0)