Skip to content

Commit 7c4bf69

Browse files
committed
Add viz motifs, fix deprecated scipy function
1 parent 444ee0d commit 7c4bf69

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

hypergraphx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
sys.version_info >= MIN_PYTHON_VERSION
1212
), f"requires Python {'.'.join([str(n) for n in MIN_PYTHON_VERSION])} or newer"
1313

14-
__version__ = "1.7.7"
14+
__version__ = "1.7.8"

hypergraphx/viz/draw_motifs.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from hypergraphx import Hypergraph
2+
3+
import matplotlib.pyplot as plt
4+
import networkx as nx
5+
import itertools
6+
from matplotlib.patches import Polygon
7+
8+
def draw_motifs(patterns,
9+
edge_size_colors=None,
10+
node_labels=None,
11+
node_size=500,
12+
node_color='lightblue',
13+
edge_color='black',
14+
save_path=None):
15+
# Collect all unique nodes across all patterns
16+
all_nodes = set(itertools.chain.from_iterable(itertools.chain.from_iterable(patterns)))
17+
G_global = nx.Graph()
18+
G_global.add_nodes_from(all_nodes)
19+
global_pos = nx.spring_layout(G_global, seed=42) # consistent layout
20+
21+
if edge_size_colors is None:
22+
edge_size_colors = {
23+
3: '#FFDAB9', # light orange
24+
4: '#ADD8E6' # light blue
25+
}
26+
27+
default_color = '#D3D3D3' # light gray for other sizes
28+
29+
edge_sizes = set(len(edge) for graph in patterns for edge in graph if len(edge) > 2)
30+
for size in edge_sizes:
31+
if size not in edge_size_colors:
32+
edge_size_colors[size] = default_color
33+
34+
# Set up plots
35+
num_graphs = len(patterns)
36+
fig, axes = plt.subplots(1, num_graphs, figsize=(5 * num_graphs, 5))
37+
if num_graphs == 1:
38+
axes = [axes]
39+
40+
# Plot each hypergraph
41+
for idx, (hypergraph, ax) in enumerate(zip(patterns, axes)):
42+
G = nx.Graph()
43+
nodes = set(itertools.chain.from_iterable(hypergraph))
44+
G.add_nodes_from(nodes)
45+
pos = {n: global_pos[n] for n in nodes}
46+
47+
# Draw nodes
48+
nx.draw_networkx_nodes(G, pos, ax=ax, node_size=node_size, node_color=node_color, edgecolors="black")
49+
50+
if node_labels:
51+
nx.draw_networkx_labels(G, pos, ax=ax)
52+
53+
# Draw hyperedges
54+
for i, hedge in enumerate(hypergraph):
55+
hedge_pos = [pos[n] for n in hedge]
56+
edge_size = len(hedge)
57+
58+
if edge_size < 2:
59+
continue # skip size-1
60+
61+
if edge_size == 2:
62+
# Draw as traditional edge
63+
nx.draw_networkx_edges(G, pos, edgelist=[tuple(hedge)], ax=ax, edge_color=edge_color, width=2)
64+
else:
65+
color = edge_size_colors[edge_size]
66+
polygon = Polygon(hedge_pos, closed=True, fill=True, alpha=0.3, color=color, edgecolor=edge_color)
67+
ax.add_patch(polygon)
68+
69+
ax.axis('off')
70+
71+
plt.tight_layout()
72+
if save_path:
73+
plt.savefig(save_path, dpi=300, bbox_inches='tight')
74+
else:
75+
plt.show()

0 commit comments

Comments
 (0)