Skip to content

Commit feec9fc

Browse files
committed
expand graph plot to continuous feature color
1 parent c5b8602 commit feec9fc

File tree

1 file changed

+102
-37
lines changed

1 file changed

+102
-37
lines changed
Lines changed: 102 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,97 +1,162 @@
11
import matplotlib.pyplot as plt
22
import networkx as nx
33
import numpy as np
4-
from matplotlib import cm
4+
from matplotlib import cm, colormaps
5+
from matplotlib.colors import Normalize
56
from matplotlib.patches import Patch
67
from graphconstructor import Graph
78

89

910
def plot_graph_by_feature(
1011
G: Graph,
11-
class_attr: str = None,
12+
color_attribute: str = None,
13+
attribute_type: str = "categorical",
1214
pos=None,
1315
cmap_name: str = "tab20",
1416
default_color="teal",
1517
with_labels: bool = True):
1618
"""
17-
Color nodes by the categorical attribute stored on each node (e.g., node['cf_class']).
18-
19+
Color nodes by the selected attribute stored on each node (e.g., node['cf_class']).
20+
1921
Parameters
2022
----------
2123
G : Graph
2224
Graph whose nodes carry the class attribute.
23-
class_attr : str
25+
color_attribute : str
2426
Node attribute name with the class label (default: 'cf_class').
27+
attribute_type : str
28+
'categorical' or 'continuous'. This will determine the legend used.
2529
pos : dict or None
2630
Optional positions dict; if None, uses nx.spring_layout.
2731
cmap_name : str
28-
Matplotlib categorical colormap (e.g., 'tab20', 'tab10', 'Set3').
32+
Matplotlib colormap (e.g., 'tab20' for categorical, 'viridis' for continuous).
2933
default_color : str
3034
Color for nodes missing the class attribute.
3135
with_labels : bool
3236
Draw node labels.
3337
"""
38+
if attribute_type not in {"categorical", "continuous"}:
39+
raise ValueError("attribute_type must be 'categorical' or 'continuous'")
40+
3441
nxG = G.to_networkx()
35-
36-
# Collect class labels for nodes (in node order)
3742
node_list = list(nxG.nodes())
38-
if class_attr:
39-
node_classes = [nxG.nodes[n].get(class_attr, None) for n in node_list]
40-
41-
# Stable set of unique classes (preserve first-seen order, skip None)
42-
unique_classes = [c for c in dict.fromkeys(node_classes) if c is not None]
43-
unique_classes.sort()
4443

45-
# Map classes -> colors
46-
if unique_classes:
47-
cmap = cm.get_cmap(cmap_name, len(unique_classes))
48-
class_to_color = {c: cmap(i) for i, c in enumerate(unique_classes)}
49-
else:
50-
class_to_color = {}
44+
# Initialize variables for continuous colorbar
45+
norm = None
46+
cmap_continuous = None
5147

52-
node_colors = [class_to_color.get(c, default_color) for c in node_classes]
48+
if color_attribute:
49+
node_features = [nxG.nodes[n].get(color_attribute, None) for n in node_list]
50+
51+
if attribute_type == "categorical":
52+
# Stable set of unique classes (preserve first-seen order, skip None)
53+
unique_classes = [c for c in dict.fromkeys(node_features) if c is not None]
54+
unique_classes.sort()
55+
56+
# Map classes -> colors
57+
if unique_classes:
58+
cmap = colormaps.get_cmap(cmap_name, len(unique_classes))
59+
class_to_color = {c: cmap(i) for i, c in enumerate(unique_classes)}
60+
else:
61+
class_to_color = {}
62+
63+
node_colors = [class_to_color.get(c, default_color) for c in node_features]
64+
65+
elif attribute_type == "continuous":
66+
# Filter out None values to find min/max
67+
valid_values = [v for v in node_features if v is not None]
68+
69+
if valid_values:
70+
# Convert to numeric (in case they aren't already)
71+
try:
72+
valid_values = [float(v) for v in valid_values]
73+
vmin, vmax = min(valid_values), max(valid_values)
74+
75+
# Create normalization and colormap for continuous scale
76+
norm = Normalize(vmin=vmin, vmax=vmax)
77+
cmap_continuous = colormaps.get_cmap(cmap_name)
78+
79+
# Map node values to colors
80+
node_colors = []
81+
for val in node_features:
82+
if val is not None:
83+
try:
84+
node_colors.append(cmap_continuous(norm(float(val))))
85+
except (ValueError, TypeError):
86+
node_colors.append(default_color)
87+
else:
88+
node_colors.append(default_color)
89+
90+
unique_classes = True # Flag to indicate we have valid data
91+
except (ValueError, TypeError):
92+
# Fall back to default color if conversion fails
93+
node_colors = [default_color] * len(node_list)
94+
unique_classes = False
95+
else:
96+
node_colors = [default_color] * len(node_list)
97+
unique_classes = False
5398
else:
5499
node_colors = default_color
55100
unique_classes = False
56101

102+
# Handle edge weights
57103
if G.weighted:
58104
edge_weights = [d.get("weight", 1.0) for _, _, d in nxG.edges(data=True)]
59-
# Scale edge widths for visibility; tweak as needed
60-
edge_widths = [0.5 + 5.0 * (w / max(edge_weights)) for w in edge_weights]
61-
62-
# --- Node sizes (optional): use degree for a bit of visual structure ---
105+
if edge_weights:
106+
max_weight = max(edge_weights)
107+
edge_widths = [0.5 + 5.0 * (w / max_weight) for w in edge_weights]
108+
edge_colors = [cm.gray(w/max_weight) for w in edge_weights]
109+
else:
110+
edge_widths = 1.0
111+
edge_colors = "gray"
112+
else:
113+
edge_widths = 1.0
114+
edge_colors = "gray"
115+
116+
# Node sizes based on degree
63117
degrees = dict(nxG.degree())
64-
# Scale size gently: 200 for degree 0, 200*(1+sqrt(deg)) otherwise
65118
node_sizes = [200.0 * (1.0 + np.sqrt(degrees.get(n, 0))) for n in nxG.nodes()]
66-
119+
67120
# Layout
68121
if pos is None:
69122
pos = nx.spring_layout(nxG, seed=42)
70-
71-
# Figure size similar to your original heuristic
123+
124+
# Figure size
72125
size = (len(node_list) ** 0.5)
73126
fig, ax = plt.subplots(figsize=(size, size))
74-
127+
128+
# Draw the graph
75129
nx.draw(
76130
nxG,
77131
pos=pos,
78132
ax=ax,
79133
with_labels=with_labels,
80134
node_color=node_colors,
81135
node_size=node_sizes,
82-
edge_color="gray" if not G.weighted else [cm.gray(w/max(edge_weights)) for w in edge_weights],
83-
width=1.0 if not G.weighted else edge_widths,
136+
edge_color=edge_colors,
137+
width=edge_widths,
84138
alpha=0.85,
85139
linewidths=0.5,
86140
font_size=8,
87141
)
88-
89-
# Legend
142+
143+
# Legend or Colorbar
90144
if unique_classes:
91-
handles = [Patch(facecolor=class_to_color[c], edgecolor="none", label=str(c)) for c in unique_classes]
92-
ax.legend(handles=handles, title=class_attr, loc="best", frameon=True)
93-
145+
if attribute_type == "categorical":
146+
# Categorical legend with patches
147+
handles = [Patch(facecolor=class_to_color[c], edgecolor="none", label=str(c))
148+
for c in unique_classes]
149+
ax.legend(handles=handles, title=color_attribute, loc="best", frameon=True)
150+
151+
elif attribute_type == "continuous" and norm is not None and cmap_continuous is not None:
152+
# Continuous colorbar
153+
sm = cm.ScalarMappable(cmap=cmap_continuous, norm=norm)
154+
sm.set_array([])
155+
cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
156+
cbar.set_label(color_attribute, rotation=270, labelpad=15)
157+
94158
ax.set_axis_off()
95159
fig.tight_layout()
96160
plt.show()
161+
97162
return fig, ax

0 commit comments

Comments
 (0)