|
1 | 1 | import matplotlib.pyplot as plt |
2 | 2 | import networkx as nx |
3 | 3 | import numpy as np |
4 | | -from matplotlib import cm |
| 4 | +from matplotlib import cm, colormaps |
| 5 | +from matplotlib.colors import Normalize |
5 | 6 | from matplotlib.patches import Patch |
6 | 7 | from graphconstructor import Graph |
7 | 8 |
|
8 | 9 |
|
9 | 10 | def plot_graph_by_feature( |
10 | 11 | G: Graph, |
11 | | - class_attr: str = None, |
| 12 | + color_attribute: str = None, |
| 13 | + attribute_type: str = "categorical", |
12 | 14 | pos=None, |
13 | 15 | cmap_name: str = "tab20", |
14 | 16 | default_color="teal", |
15 | 17 | with_labels: bool = True): |
16 | 18 | """ |
17 | | - Color nodes by the categorical attribute stored on each node (e.g., node['cf_class']). |
18 | | -
|
| 19 | + Color nodes by the selected attribute stored on each node (e.g., node['cf_class']). |
| 20 | + |
19 | 21 | Parameters |
20 | 22 | ---------- |
21 | 23 | G : Graph |
22 | 24 | Graph whose nodes carry the class attribute. |
23 | | - class_attr : str |
| 25 | + color_attribute : str |
24 | 26 | Node attribute name with the class label (default: 'cf_class'). |
| 27 | + attribute_type : str |
| 28 | + 'categorical' or 'continuous'. This will determine the legend used. |
25 | 29 | pos : dict or None |
26 | 30 | Optional positions dict; if None, uses nx.spring_layout. |
27 | 31 | cmap_name : str |
28 | | - Matplotlib categorical colormap (e.g., 'tab20', 'tab10', 'Set3'). |
| 32 | + Matplotlib colormap (e.g., 'tab20' for categorical, 'viridis' for continuous). |
29 | 33 | default_color : str |
30 | 34 | Color for nodes missing the class attribute. |
31 | 35 | with_labels : bool |
32 | 36 | Draw node labels. |
33 | 37 | """ |
| 38 | + if attribute_type not in {"categorical", "continuous"}: |
| 39 | + raise ValueError("attribute_type must be 'categorical' or 'continuous'") |
| 40 | + |
34 | 41 | nxG = G.to_networkx() |
35 | | - |
36 | | - # Collect class labels for nodes (in node order) |
37 | 42 | node_list = list(nxG.nodes()) |
38 | | - if class_attr: |
39 | | - node_classes = [nxG.nodes[n].get(class_attr, None) for n in node_list] |
40 | | - |
41 | | - # Stable set of unique classes (preserve first-seen order, skip None) |
42 | | - unique_classes = [c for c in dict.fromkeys(node_classes) if c is not None] |
43 | | - unique_classes.sort() |
44 | 43 |
|
45 | | - # Map classes -> colors |
46 | | - if unique_classes: |
47 | | - cmap = cm.get_cmap(cmap_name, len(unique_classes)) |
48 | | - class_to_color = {c: cmap(i) for i, c in enumerate(unique_classes)} |
49 | | - else: |
50 | | - class_to_color = {} |
| 44 | + # Initialize variables for continuous colorbar |
| 45 | + norm = None |
| 46 | + cmap_continuous = None |
51 | 47 |
|
52 | | - node_colors = [class_to_color.get(c, default_color) for c in node_classes] |
| 48 | + if color_attribute: |
| 49 | + node_features = [nxG.nodes[n].get(color_attribute, None) for n in node_list] |
| 50 | + |
| 51 | + if attribute_type == "categorical": |
| 52 | + # Stable set of unique classes (preserve first-seen order, skip None) |
| 53 | + unique_classes = [c for c in dict.fromkeys(node_features) if c is not None] |
| 54 | + unique_classes.sort() |
| 55 | + |
| 56 | + # Map classes -> colors |
| 57 | + if unique_classes: |
| 58 | + cmap = colormaps.get_cmap(cmap_name, len(unique_classes)) |
| 59 | + class_to_color = {c: cmap(i) for i, c in enumerate(unique_classes)} |
| 60 | + else: |
| 61 | + class_to_color = {} |
| 62 | + |
| 63 | + node_colors = [class_to_color.get(c, default_color) for c in node_features] |
| 64 | + |
| 65 | + elif attribute_type == "continuous": |
| 66 | + # Filter out None values to find min/max |
| 67 | + valid_values = [v for v in node_features if v is not None] |
| 68 | + |
| 69 | + if valid_values: |
| 70 | + # Convert to numeric (in case they aren't already) |
| 71 | + try: |
| 72 | + valid_values = [float(v) for v in valid_values] |
| 73 | + vmin, vmax = min(valid_values), max(valid_values) |
| 74 | + |
| 75 | + # Create normalization and colormap for continuous scale |
| 76 | + norm = Normalize(vmin=vmin, vmax=vmax) |
| 77 | + cmap_continuous = colormaps.get_cmap(cmap_name) |
| 78 | + |
| 79 | + # Map node values to colors |
| 80 | + node_colors = [] |
| 81 | + for val in node_features: |
| 82 | + if val is not None: |
| 83 | + try: |
| 84 | + node_colors.append(cmap_continuous(norm(float(val)))) |
| 85 | + except (ValueError, TypeError): |
| 86 | + node_colors.append(default_color) |
| 87 | + else: |
| 88 | + node_colors.append(default_color) |
| 89 | + |
| 90 | + unique_classes = True # Flag to indicate we have valid data |
| 91 | + except (ValueError, TypeError): |
| 92 | + # Fall back to default color if conversion fails |
| 93 | + node_colors = [default_color] * len(node_list) |
| 94 | + unique_classes = False |
| 95 | + else: |
| 96 | + node_colors = [default_color] * len(node_list) |
| 97 | + unique_classes = False |
53 | 98 | else: |
54 | 99 | node_colors = default_color |
55 | 100 | unique_classes = False |
56 | 101 |
|
| 102 | + # Handle edge weights |
57 | 103 | if G.weighted: |
58 | 104 | edge_weights = [d.get("weight", 1.0) for _, _, d in nxG.edges(data=True)] |
59 | | - # Scale edge widths for visibility; tweak as needed |
60 | | - edge_widths = [0.5 + 5.0 * (w / max(edge_weights)) for w in edge_weights] |
61 | | - |
62 | | - # --- Node sizes (optional): use degree for a bit of visual structure --- |
| 105 | + if edge_weights: |
| 106 | + max_weight = max(edge_weights) |
| 107 | + edge_widths = [0.5 + 5.0 * (w / max_weight) for w in edge_weights] |
| 108 | + edge_colors = [cm.gray(w/max_weight) for w in edge_weights] |
| 109 | + else: |
| 110 | + edge_widths = 1.0 |
| 111 | + edge_colors = "gray" |
| 112 | + else: |
| 113 | + edge_widths = 1.0 |
| 114 | + edge_colors = "gray" |
| 115 | + |
| 116 | + # Node sizes based on degree |
63 | 117 | degrees = dict(nxG.degree()) |
64 | | - # Scale size gently: 200 for degree 0, 200*(1+sqrt(deg)) otherwise |
65 | 118 | node_sizes = [200.0 * (1.0 + np.sqrt(degrees.get(n, 0))) for n in nxG.nodes()] |
66 | | - |
| 119 | + |
67 | 120 | # Layout |
68 | 121 | if pos is None: |
69 | 122 | pos = nx.spring_layout(nxG, seed=42) |
70 | | - |
71 | | - # Figure size similar to your original heuristic |
| 123 | + |
| 124 | + # Figure size |
72 | 125 | size = (len(node_list) ** 0.5) |
73 | 126 | fig, ax = plt.subplots(figsize=(size, size)) |
74 | | - |
| 127 | + |
| 128 | + # Draw the graph |
75 | 129 | nx.draw( |
76 | 130 | nxG, |
77 | 131 | pos=pos, |
78 | 132 | ax=ax, |
79 | 133 | with_labels=with_labels, |
80 | 134 | node_color=node_colors, |
81 | 135 | node_size=node_sizes, |
82 | | - edge_color="gray" if not G.weighted else [cm.gray(w/max(edge_weights)) for w in edge_weights], |
83 | | - width=1.0 if not G.weighted else edge_widths, |
| 136 | + edge_color=edge_colors, |
| 137 | + width=edge_widths, |
84 | 138 | alpha=0.85, |
85 | 139 | linewidths=0.5, |
86 | 140 | font_size=8, |
87 | 141 | ) |
88 | | - |
89 | | - # Legend |
| 142 | + |
| 143 | + # Legend or Colorbar |
90 | 144 | if unique_classes: |
91 | | - handles = [Patch(facecolor=class_to_color[c], edgecolor="none", label=str(c)) for c in unique_classes] |
92 | | - ax.legend(handles=handles, title=class_attr, loc="best", frameon=True) |
93 | | - |
| 145 | + if attribute_type == "categorical": |
| 146 | + # Categorical legend with patches |
| 147 | + handles = [Patch(facecolor=class_to_color[c], edgecolor="none", label=str(c)) |
| 148 | + for c in unique_classes] |
| 149 | + ax.legend(handles=handles, title=color_attribute, loc="best", frameon=True) |
| 150 | + |
| 151 | + elif attribute_type == "continuous" and norm is not None and cmap_continuous is not None: |
| 152 | + # Continuous colorbar |
| 153 | + sm = cm.ScalarMappable(cmap=cmap_continuous, norm=norm) |
| 154 | + sm.set_array([]) |
| 155 | + cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04) |
| 156 | + cbar.set_label(color_attribute, rotation=270, labelpad=15) |
| 157 | + |
94 | 158 | ax.set_axis_off() |
95 | 159 | fig.tight_layout() |
96 | 160 | plt.show() |
| 161 | + |
97 | 162 | return fig, ax |
0 commit comments