Skip to content

Commit 8868354

Browse files
authored
Merge pull request #149 from vrabiczan/Sankey-diagram
Sankey diagram
2 parents 978c7be + 332e945 commit 8868354

File tree

5 files changed

+378
-1
lines changed

5 files changed

+378
-1
lines changed

examples/visualization_examples/prepare_datasets.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,127 @@ def get_data_developer_salary_data():
165165
]]
166166

167167
return arm_df
168+
169+
170+
def get_abalone_data():
171+
# Read csv and create DataFrame
172+
df = pd.read_csv("datasets/Abalone.csv")
173+
174+
######### DISCRETIZATION #########
175+
def get_descriptive_stats(data_frame, column, bins_num):
176+
stats = data_frame[column].describe()
177+
bins_values = []
178+
if bins_num == 5:
179+
bins_values = [
180+
stats["min"],
181+
stats["25%"],
182+
stats["50%"],
183+
stats["75%"],
184+
stats["max"],
185+
stats["max"] + 0.01
186+
]
187+
elif bins_num == 3:
188+
bins_values = [
189+
stats["min"],
190+
(stats["min"] + (stats["max"] - stats["min"]) / 3),
191+
(stats["min"] + 2 * (stats["max"] - stats["min"]) / 3),
192+
stats["max"] + 0.01
193+
]
194+
195+
return bins_values
196+
197+
# LENGTH
198+
length_stats = get_descriptive_stats(df, "Length", 3)
199+
length_labels = ["Small", "Medium", "Large"]
200+
df["Length"] = pd.cut(
201+
df["Length"],
202+
bins=length_stats,
203+
labels=length_labels,
204+
include_lowest=True
205+
)
206+
207+
# DIAMETER
208+
diameter_stats = get_descriptive_stats(df, "Diameter", 3)
209+
diameter_labels = ["Small", "Medium", "Large"]
210+
df["Diameter"] = pd.cut(
211+
df["Diameter"],
212+
bins=diameter_stats,
213+
labels=diameter_labels,
214+
include_lowest=True
215+
)
216+
217+
# HEIGHT
218+
height_stats = get_descriptive_stats(df, "Height", 3)
219+
height_labels = ["Small", "Medium", "Large"]
220+
df["Height"] = pd.cut(
221+
df["Height"],
222+
bins=height_stats,
223+
labels=height_labels,
224+
include_lowest=True
225+
)
226+
227+
# WHOLE WEIGHT
228+
whole_weight_stats = get_descriptive_stats(df, "Whole weight", 3)
229+
whole_weight_labels = ["Light", "Medium", "Heavy"]
230+
df["Whole weight"] = pd.cut(
231+
df["Whole weight"],
232+
bins=whole_weight_stats,
233+
labels=whole_weight_labels,
234+
include_lowest=True
235+
)
236+
237+
# SHUCKED WEIGHT
238+
shucked_weight_stats = get_descriptive_stats(df, "Shucked weight", 3)
239+
shucked_weight_labels = ["Light", "Medium", "Heavy"]
240+
df["Shucked weight"] = pd.cut(
241+
df["Shucked weight"],
242+
bins=shucked_weight_stats,
243+
labels=shucked_weight_labels,
244+
include_lowest=True
245+
)
246+
247+
# VISCERA WEIGHT
248+
viscera_weight_stats = get_descriptive_stats(df, "Viscera weight", 3)
249+
viscera_weight_labels = ["Light", "Medium", "Heavy"]
250+
df["Viscera weight"] = pd.cut(
251+
df["Viscera weight"],
252+
bins=viscera_weight_stats,
253+
labels=viscera_weight_labels,
254+
include_lowest=True
255+
)
256+
257+
# SHELL WEIGHT
258+
shell_weight_stats = get_descriptive_stats(df, "Shell weight", 3)
259+
shell_weight_labels = ["Light", "Medium", "Heavy"]
260+
df["Shell weight"] = pd.cut(
261+
df["Shell weight"],
262+
bins=shell_weight_stats,
263+
labels=shell_weight_labels,
264+
include_lowest=True
265+
)
266+
267+
# AGE
268+
age_stats = get_descriptive_stats(df, "Rings", 3)
269+
age_labels = ["Young", "Adult", "Old"]
270+
df["Age"] = pd.cut(
271+
df["Rings"],
272+
bins=age_stats,
273+
labels=age_labels,
274+
include_lowest=True
275+
)
276+
277+
# Select relevant columns for ARM
278+
arm_df = df[[
279+
"Sex",
280+
"Length",
281+
"Diameter",
282+
"Height",
283+
"Whole weight",
284+
"Shucked weight",
285+
"Viscera weight",
286+
"Shell weight",
287+
"Age"
288+
]]
289+
290+
return arm_df
291+
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from examples.visualization_examples.prepare_datasets import get_abalone_data
2+
from niaarm import Dataset, get_rules
3+
from niaarm.visualize import sankey_diagram
4+
5+
# Get prepared data developer salary data
6+
arm_df = get_abalone_data()
7+
8+
# Prepare Dataset
9+
dataset = Dataset(
10+
path_or_df=arm_df,
11+
delimiter=","
12+
)
13+
14+
# Get rules
15+
metrics = ("support", "confidence")
16+
rules, run_time = get_rules(
17+
dataset=dataset,
18+
algorithm="DifferentialEvolution",
19+
metrics=metrics,
20+
max_evals=500
21+
)
22+
23+
# Sort rules
24+
rules.sort(by="support")
25+
# Print rule information
26+
print("\nRules:")
27+
print(rules)
28+
print(f'\nTime to generate rules: {f"{run_time:.3f}"} seconds')
29+
print("\nRule information: ", rules[3])
30+
print("Antecedent: ", rules[3].antecedent)
31+
print("Consequent: ", rules[3].consequent)
32+
print("Confidence: ", rules[3].confidence)
33+
print("Support: ", rules[3].support)
34+
print("Lift: ", rules[3].lift)
35+
print("\nMetrics:", metrics)
36+
37+
# Visualize sankey diagram
38+
fig = sankey_diagram(rules=rules, interestingness_measure="support", M=4)
39+
fig.show()
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from examples.visualization_examples.prepare_datasets import get_weather_data
2+
from niaarm import Dataset, get_rules
3+
from niaarm.visualize import sankey_diagram
4+
5+
# Get prepared weather data
6+
arm_df = get_weather_data()
7+
8+
# Prepare Dataset
9+
dataset = Dataset(
10+
path_or_df=arm_df,
11+
delimiter=","
12+
)
13+
14+
# Get rules
15+
metrics = ("support", "confidence")
16+
rules, run_time = get_rules(
17+
dataset=dataset,
18+
algorithm="DifferentialEvolution",
19+
metrics=metrics,
20+
max_evals=500
21+
)
22+
23+
# Add lift after the rules have been generated
24+
# Cannot be in metrics before because get_rules metrics doesn't contain lift, therefore we need to add after
25+
metrics = list(metrics)
26+
metrics.append("lift")
27+
metrics = tuple(metrics)
28+
29+
# Sort rules
30+
rules.sort(by="support")
31+
# Print rule information
32+
print("\nRules:")
33+
print(rules)
34+
print(f'\nTime to generate rules: {f"{run_time:.3f}"} seconds')
35+
print("\nRule information: ", rules[3])
36+
print("Antecedent: ", rules[3].antecedent)
37+
print("Consequent: ", rules[3].consequent)
38+
print("Confidence: ", rules[3].confidence)
39+
print("Support: ", rules[3].support)
40+
print("Lift: ", rules[3].lift)
41+
print("\nMetrics:", metrics)
42+
43+
# Visualize sankey diagram
44+
fig = sankey_diagram(rules=rules, interestingness_measure="support", M=4)
45+
fig.show()

niaarm/visualize.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
from matplotlib.colors import Normalize
44
import numpy as np
55
import plotly.express as px
6+
import plotly.graph_objects as go
67
import pandas as pd
78
from sklearn.cluster import KMeans
9+
from itertools import combinations
810

911

1012
def hill_slopes(rule, transactions):
@@ -554,4 +556,124 @@ def prepare_data(rules, metrics):
554556
plt.legend(title="Order")
555557
plt.grid(True)
556558
return plt
557-
559+
560+
561+
def sankey_diagram(rules, interestingness_measure, M=4):
562+
"""
563+
Visualize rules as a sankey diagram.
564+
565+
Args:
566+
rules (Rule): Association rule or rules to visualize.
567+
interestingness_measures (str): Interestingness measure Z = {supp, cons, lift},reflecting the quality of a particular connection.
568+
m (int): Maximum number of rules to be selected for visualization. Default: 4
569+
570+
Returns:
571+
Figure or plot.
572+
"""
573+
574+
575+
def compute_similarity(rule1, rule2):
576+
"""Compute similarity between two rules."""
577+
ant_inter = len(set(str(rule1.antecedent)) & set(str(rule2.antecedent)))
578+
ant_union = len(set(str(rule1.antecedent)) | set(str(rule2.antecedent)))
579+
con_inter = len(set(str(rule1.consequent)) & set(str(rule2.consequent)))
580+
con_union = len(set(str(rule1.consequent)) | set(str(rule2.consequent)))
581+
return (ant_inter + con_inter) / (ant_union + con_union)
582+
583+
def build_adjacency_matrix(rules):
584+
size = len(rules)
585+
adjacency_matrix = np.zeros((size, size))
586+
587+
for i, j in combinations(range(size), 2):
588+
similarity = compute_similarity(rules[i], rules[j])
589+
adjacency_matrix[i, j] = similarity
590+
adjacency_matrix[j, i] = similarity
591+
592+
return adjacency_matrix
593+
594+
def knapsack_selection(adj_matrix, rules, M):
595+
fitness_scores = np.array([rule.fitness for rule in rules])
596+
N = len(rules) # number of rules
597+
weights = np.ones(N) # all rules have the same weight
598+
similarity_weight = 1.0
599+
fitness_weight = 0.5
600+
combined_profits = similarity_weight * np.sum(adj_matrix) + fitness_weight * fitness_scores # combined similarities with fitness for values
601+
602+
selected = np.zeros(N, dtype=int)
603+
604+
# Initialize DP table
605+
dp = np.zeros((N + 1, M + 1))
606+
for i in range(1, N + 1):
607+
for w in range(1, M + 1):
608+
if weights[i - 1] <= w:
609+
dp[i, w] = max(dp[i - 1, w], dp[i - 1, w - 1] + combined_profits[i - 1])
610+
else:
611+
dp[i, w] = dp[i - 1, w]
612+
613+
# Backtrack to find selected rules
614+
w = M
615+
for i in range(N, 0, -1):
616+
if dp[i, w] != dp[i - 1, w]:
617+
selected[i - 1] = 1
618+
w -= 1
619+
620+
selected_rules = [rules[i] for i in range(N) if selected[i]]
621+
622+
return selected_rules
623+
624+
def prepare_data(rules, M, interestingness_measure):
625+
if not rules:
626+
return [], [], [], []
627+
628+
adj_matrix = build_adjacency_matrix(rules)
629+
selected_rules = knapsack_selection(adj_matrix, rules, M)
630+
631+
sources=[]
632+
targets=[]
633+
values=[]
634+
labels=[]
635+
node_indices = {}
636+
637+
for rule in selected_rules:
638+
# Ensure all antecedents and consequents exist in the node list
639+
for item in rule.antecedent + rule.consequent:
640+
item_str = str(item)
641+
if item_str not in node_indices:
642+
node_indices[item_str] = len(labels)
643+
labels.append(item_str)
644+
645+
# Connect each antecedent to each consequent
646+
for antecedent in rule.antecedent:
647+
for consequent in rule.consequent:
648+
sources.append(node_indices[str(antecedent)])
649+
targets.append(node_indices[str(consequent)])
650+
651+
# Assign measure value for each connection
652+
if hasattr(rule, interestingness_measure):
653+
measure_value = getattr(rule, interestingness_measure)
654+
else:
655+
measure_value = rule.support # Default support
656+
values.append(measure_value)
657+
658+
return labels, sources, targets, values
659+
660+
labels, sources, targets, values = prepare_data(rules, M, interestingness_measure)
661+
662+
# Visualization using Plotly
663+
fig = go.Figure(go.Sankey(
664+
node=dict(
665+
pad=15,
666+
thickness=20,
667+
line=dict(color='black', width=0.5),
668+
label=labels
669+
),
670+
link=dict(
671+
source=sources,
672+
target=targets,
673+
value=values
674+
)
675+
))
676+
fig.update_layout(title_text=f'Sankey Diagram of Association Rules ({interestingness_measure})', font_size=10)
677+
678+
return fig
679+

0 commit comments

Comments
 (0)