Skip to content

Commit 0edbdd9

Browse files
author
Zan Vrabic
committed
Sankey diagram v1 and example
1 parent a6e3dfa commit 0edbdd9

File tree

2 files changed

+159
-1
lines changed

2 files changed

+159
-1
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from examples.visualization_examples.prepare_datasets import get_weather_data
2+
from niaarm import Dataset, get_rules
3+
from niaarm.visualize import sankey_diagram
4+
5+
# Get prepared weather data
arm_df = get_weather_data()

# Prepare Dataset
dataset = Dataset(
    path_or_df=arm_df,
    delimiter=",",
)

# Get rules
metrics = ("support", "confidence")
rules, run_time = get_rules(
    dataset=dataset,
    algorithm="DifferentialEvolution",
    metrics=metrics,
    max_evals=500,
)

# "lift" is deliberately not passed to get_rules above; it is appended only
# afterwards so that the reported metrics tuple also lists it.
metrics = (*metrics, "lift")

# Sort rules
rules.sort(by="support")

# Print rule information
print("\nRules:")
print(rules)
print(f"\nTime to generate rules: {run_time:.3f} seconds")

# Show the fourth rule as a sample. The optimizer may return fewer than four
# rules, so fall back to the last available one and skip the report entirely
# when nothing was mined instead of raising IndexError.
if len(rules) > 0:
    sample_rule = rules[min(3, len(rules) - 1)]
    print("\nRule information: ", sample_rule)
    print("Antecedent: ", sample_rule.antecedent)
    print("Consequent: ", sample_rule.consequent)
    print("Confidence: ", sample_rule.confidence)
    print("Support: ", sample_rule.support)
    print("Lift: ", sample_rule.lift)
print("\nMetrics:", metrics)

# Visualize sankey diagram
fig = sankey_diagram(rules=rules, interestingness_measure="support", M=4)
fig.show()

niaarm/visualize.py

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
from matplotlib.colors import Normalize
44
import numpy as np
55
import plotly.express as px
6+
import plotly.graph_objects as go
67
import pandas as pd
78
from sklearn.cluster import KMeans
9+
from itertools import combinations
810

911

1012
def hill_slopes(rule, transactions):
@@ -554,4 +556,115 @@ def prepare_data(rules, metrics):
554556
plt.legend(title="Order")
555557
plt.grid(True)
556558
return plt
557-
559+
560+
561+
def sankey_diagram(rules, interestingness_measure, M=4):
    """
    Visualize rules as a sankey diagram.

    At most ``M`` rules are picked by a knapsack-style selection that rewards
    rules which are similar to the other rules and have a high fitness. Every
    (antecedent attribute, consequent attribute) pair of a selected rule
    becomes one link of the diagram, weighted by the chosen measure.

    Args:
        rules: Indexable, sized collection of association rules. Each rule
            must expose ``antecedent``, ``consequent``, ``fitness`` and
            ``support``.
        interestingness_measure (str): Name of the rule attribute used as the
            link value, e.g. "support", "confidence" or "lift". If a rule has
            no attribute of that name, its support is used instead.
        M (int): Maximum number of rules to be selected for visualization.
            Default: 4

    Returns:
        plotly.graph_objects.Figure: Figure holding a single Sankey trace.
    """

    def compute_similarity(rule1, rule2):
        """Return the Jaccard-style similarity of two rules in [0, 1].

        Attributes are compared by their string form. Intersections and
        unions are computed separately for antecedents and consequents and
        then pooled.
        """
        ant1 = {str(attribute) for attribute in rule1.antecedent}
        ant2 = {str(attribute) for attribute in rule2.antecedent}
        con1 = {str(attribute) for attribute in rule1.consequent}
        con2 = {str(attribute) for attribute in rule2.consequent}

        union_size = len(ant1 | ant2) + len(con1 | con2)
        if union_size == 0:
            # Two rules without any attributes: nothing to compare.
            return 0.0
        return (len(ant1 & ant2) + len(con1 & con2)) / union_size

    def build_adjacency_matrix(rules):
        """Return the symmetric matrix of pairwise rule similarities."""
        size = len(rules)
        adjacency_matrix = np.zeros((size, size))

        # The diagonal stays 0 so a rule does not count as similar to itself.
        for i, j in combinations(range(size), 2):
            similarity = compute_similarity(rules[i], rules[j])
            adjacency_matrix[i, j] = similarity
            adjacency_matrix[j, i] = similarity

        return adjacency_matrix

    def knapsack_selection(adj_matrix, rules, M):
        """Select up to ``M`` rules maximizing the summed combined profit.

        Every rule costs one slot of the knapsack. The profit of a rule is
        its total similarity to all other rules plus half of its fitness.
        Selected rules are returned in their original order.
        """
        N = len(rules)
        # Clamp so that a negative or oversized M cannot break the DP table.
        capacity = max(0, min(int(M), N))

        fitness_scores = np.array([rule.fitness for rule in rules], dtype=float)
        similarity_weight = 1.0
        fitness_weight = 0.5
        # Row sums give one similarity score per rule; summing the whole
        # matrix would add the same constant to every rule.
        combined_profits = (
            similarity_weight * np.sum(adj_matrix, axis=1)
            + fitness_weight * fitness_scores
        )

        # dp[i, w]: best profit using the first i rules and at most w slots.
        dp = np.zeros((N + 1, capacity + 1))
        for i in range(1, N + 1):
            for w in range(1, capacity + 1):
                dp[i, w] = max(
                    dp[i - 1, w],
                    dp[i - 1, w - 1] + combined_profits[i - 1],
                )

        # Backtrack: a changed value means rule i - 1 was taken.
        selected_rules = []
        w = capacity
        for i in range(N, 0, -1):
            if w > 0 and dp[i, w] != dp[i - 1, w]:
                selected_rules.append(rules[i - 1])
                w -= 1
        selected_rules.reverse()

        return selected_rules

    def prepare_data(rules, M, interestingness_measure):
        """Build node labels and aligned link arrays for the Sankey trace."""
        adj_matrix = build_adjacency_matrix(rules)
        selected_rules = knapsack_selection(adj_matrix, rules, M)

        sources = []
        targets = []
        values = []
        labels = []
        node_indices = {}

        def node_index(attribute):
            # Register each distinct attribute string as a node exactly once.
            label = str(attribute)
            if label not in node_indices:
                node_indices[label] = len(labels)
                labels.append(label)
            return node_indices[label]

        for rule in selected_rules:
            measure_value = getattr(rule, interestingness_measure, rule.support)  # default support
            antecedent_nodes = [node_index(a) for a in rule.antecedent]
            consequent_nodes = [node_index(c) for c in rule.consequent]

            # One link per antecedent/consequent pair keeps sources, targets
            # and values the same length and aligned index by index.
            for source in antecedent_nodes:
                for target in consequent_nodes:
                    sources.append(source)
                    targets.append(target)
                    values.append(measure_value)

        return labels, sources, targets, values

    labels, sources, targets, values = prepare_data(rules, M, interestingness_measure)

    fig = go.Figure(go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color='black', width=0.5),
            label=labels
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values
        )
    ))
    fig.update_layout(title_text=f'Sankey Diagram of Association Rules ({interestingness_measure})', font_size=10)

    return fig
670+

0 commit comments

Comments
 (0)