Skip to content

Commit b2aab77

Browse files
author
Zan Vrabic
committed
Examples, added comments and tests
1 parent 0edbdd9 commit b2aab77

File tree

5 files changed

+222
-6
lines changed

5 files changed

+222
-6
lines changed

examples/visualization_examples/prepare_datasets.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,127 @@ def get_data_developer_salary_data():
165165
]]
166166

167167
return arm_df
168+
169+
170+
def get_abalone_data():
171+
# Read csv and create DataFrame
172+
df = pd.read_csv("datasets/Abalone.csv")
173+
174+
######### DISCRETIZATION #########
175+
def get_descriptive_stats(data_frame, column, bins_num):
176+
stats = data_frame[column].describe()
177+
bins_values = []
178+
if bins_num == 5:
179+
bins_values = [
180+
stats["min"],
181+
stats["25%"],
182+
stats["50%"],
183+
stats["75%"],
184+
stats["max"],
185+
stats["max"] + 0.01
186+
]
187+
elif bins_num == 3:
188+
bins_values = [
189+
stats["min"],
190+
(stats["min"] + (stats["max"] - stats["min"]) / 3),
191+
(stats["min"] + 2 * (stats["max"] - stats["min"]) / 3),
192+
stats["max"] + 0.01
193+
]
194+
195+
return bins_values
196+
197+
# LENGTH
198+
length_stats = get_descriptive_stats(df, "Length", 3)
199+
length_labels = ["Small", "Medium", "Large"]
200+
df["Length"] = pd.cut(
201+
df["Length"],
202+
bins=length_stats,
203+
labels=length_labels,
204+
include_lowest=True
205+
)
206+
207+
# DIAMETER
208+
diameter_stats = get_descriptive_stats(df, "Diameter", 3)
209+
diameter_labels = ["Small", "Medium", "Large"]
210+
df["Diameter"] = pd.cut(
211+
df["Diameter"],
212+
bins=diameter_stats,
213+
labels=diameter_labels,
214+
include_lowest=True
215+
)
216+
217+
# HEIGHT
218+
height_stats = get_descriptive_stats(df, "Height", 3)
219+
height_labels = ["Small", "Medium", "Large"]
220+
df["Height"] = pd.cut(
221+
df["Height"],
222+
bins=height_stats,
223+
labels=height_labels,
224+
include_lowest=True
225+
)
226+
227+
# WHOLE WEIGHT
228+
whole_weight_stats = get_descriptive_stats(df, "Whole weight", 3)
229+
whole_weight_labels = ["Light", "Medium", "Heavy"]
230+
df["Whole weight"] = pd.cut(
231+
df["Whole weight"],
232+
bins=whole_weight_stats,
233+
labels=whole_weight_labels,
234+
include_lowest=True
235+
)
236+
237+
# SHUCKED WEIGHT
238+
shucked_weight_stats = get_descriptive_stats(df, "Shucked weight", 3)
239+
shucked_weight_labels = ["Light", "Medium", "Heavy"]
240+
df["Shucked weight"] = pd.cut(
241+
df["Shucked weight"],
242+
bins=shucked_weight_stats,
243+
labels=shucked_weight_labels,
244+
include_lowest=True
245+
)
246+
247+
# VISCERA WEIGHT
248+
viscera_weight_stats = get_descriptive_stats(df, "Viscera weight", 3)
249+
viscera_weight_labels = ["Light", "Medium", "Heavy"]
250+
df["Viscera weight"] = pd.cut(
251+
df["Viscera weight"],
252+
bins=viscera_weight_stats,
253+
labels=viscera_weight_labels,
254+
include_lowest=True
255+
)
256+
257+
# SHELL WEIGHT
258+
shell_weight_stats = get_descriptive_stats(df, "Shell weight", 3)
259+
shell_weight_labels = ["Light", "Medium", "Heavy"]
260+
df["Shell weight"] = pd.cut(
261+
df["Shell weight"],
262+
bins=shell_weight_stats,
263+
labels=shell_weight_labels,
264+
include_lowest=True
265+
)
266+
267+
# AGE
268+
age_stats = get_descriptive_stats(df, "Rings", 3)
269+
age_labels = ["Young", "Adult", "Old"]
270+
df["Age"] = pd.cut(
271+
df["Rings"],
272+
bins=age_stats,
273+
labels=age_labels,
274+
include_lowest=True
275+
)
276+
277+
# Select relevant columns for ARM
278+
arm_df = df[[
279+
"Sex",
280+
"Length",
281+
"Diameter",
282+
"Height",
283+
"Whole weight",
284+
"Shucked weight",
285+
"Viscera weight",
286+
"Shell weight",
287+
"Age"
288+
]]
289+
290+
return arm_df
291+
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from examples.visualization_examples.prepare_datasets import get_abalone_data
2+
from niaarm import Dataset, get_rules
3+
from niaarm.visualize import sankey_diagram
4+
5+
# Get prepared data developer salary data
6+
arm_df = get_abalone_data()
7+
8+
# Prepare Dataset
9+
dataset = Dataset(
10+
path_or_df=arm_df,
11+
delimiter=","
12+
)
13+
14+
# Get rules
15+
metrics = ("support", "confidence")
16+
rules, run_time = get_rules(
17+
dataset=dataset,
18+
algorithm="DifferentialEvolution",
19+
metrics=metrics,
20+
max_evals=500
21+
)
22+
23+
# Sort rules
24+
rules.sort(by="support")
25+
# Print rule information
26+
print("\nRules:")
27+
print(rules)
28+
print(f'\nTime to generate rules: {f"{run_time:.3f}"} seconds')
29+
print("\nRule information: ", rules[3])
30+
print("Antecedent: ", rules[3].antecedent)
31+
print("Consequent: ", rules[3].consequent)
32+
print("Confidence: ", rules[3].confidence)
33+
print("Support: ", rules[3].support)
34+
print("Lift: ", rules[3].lift)
35+
print("\nMetrics:", metrics)
36+
37+
# Visualize sankey diagram
38+
fig = sankey_diagram(rules=rules, interestingness_measure="support", M=4)
39+
fig.show()

examples/visualization_examples/sankey_diagram/weather_data_sankey_diagram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,6 @@
4040
print("Lift: ", rules[3].lift)
4141
print("\nMetrics:", metrics)
4242

43-
# Visualize scatter plot
43+
# Visualize sankey diagram
4444
fig = sankey_diagram(rules=rules, interestingness_measure="support", M=4)
4545
fig.show()

niaarm/visualize.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -593,11 +593,11 @@ def build_adjacency_matrix(rules):
593593

594594
def knapsack_selection(adj_matrix, rules, M):
595595
fitness_scores = np.array([rule.fitness for rule in rules])
596-
N = len(rules)
597-
weights = np.ones(N)
596+
N = len(rules) # number of rules
597+
weights = np.ones(N) # all rules have the same weight
598598
similarity_weight = 1.0
599599
fitness_weight = 0.5
600-
combined_profits = similarity_weight * np.sum(adj_matrix) + fitness_weight * fitness_scores
600+
combined_profits = similarity_weight * np.sum(adj_matrix) + fitness_weight * fitness_scores # combined similarities with fitness for values
601601

602602
selected = np.zeros(N, dtype=int)
603603

@@ -622,6 +622,9 @@ def knapsack_selection(adj_matrix, rules, M):
622622
return selected_rules
623623

624624
def prepare_data(rules, M, interestingness_measure):
625+
if not rules:
626+
return [], [], [], []
627+
625628
adj_matrix = build_adjacency_matrix(rules)
626629
selected_rules = knapsack_selection(adj_matrix, rules, M)
627630

@@ -644,13 +647,17 @@ def prepare_data(rules, M, interestingness_measure):
644647
labels.append(str(consequent))
645648
targets.append(node_indices[str(consequent)])
646649

647-
measure_value = getattr(rule, interestingness_measure, rule.support) #default support
648-
values.append(measure_value)
650+
if hasattr(rule, interestingness_measure):
651+
measure_value = getattr(rule, interestingness_measure)
652+
else:
653+
measure_value=rule.support # Default support
654+
values.append(measure_value)
649655

650656
return labels, sources, targets, values
651657

652658
labels, sources, targets, values = prepare_data(rules, M, interestingness_measure)
653659

660+
# Visualization using Plotly
654661
fig = go.Figure(go.Sankey(
655662
node=dict(
656663
pad=15,

tests/test_sankey_diagram.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import unittest
2+
from niaarm.visualize import sankey_diagram
3+
from niaarm import Rule
4+
5+
class TestSankeyDiagram(unittest.TestCase):
6+
7+
@classmethod
8+
def setUpClass(cls):
9+
cls.rule1 = Rule(antecedent=["A", "B"], consequent=["C"])
10+
cls.rule1.fitness = 1.0
11+
cls.rule1.num_transactions = 10
12+
cls.rule2 = Rule(antecedent=["D"], consequent=["E", "F"])
13+
cls.rule2.fitness = 0.8
14+
cls.rule2.num_transactions = 15
15+
cls.rule3 = Rule(antecedent=["G", "H"], consequent=["I"])
16+
cls.rule3.fitness = 0.9
17+
cls.rule3.num_transactions = 12
18+
19+
cls.rules = [cls.rule1, cls.rule2, cls.rule3]
20+
21+
def test_sankey_output_type(self):
22+
fig = sankey_diagram(self.rules, "support", M=3)
23+
self.assertEqual(fig.__class__.__name__, "Figure")
24+
25+
def test_sankey_structure(self):
26+
fig = sankey_diagram(self.rules, "support", M=3)
27+
self.assertTrue("source" in fig.data[0].link)
28+
29+
def test_sankey_values(self):
30+
fig = sankey_diagram(self.rules, "support", M=3)
31+
link_data = fig.data[0].link
32+
flow_values = link_data['value']
33+
self.assertEqual(len(flow_values), len(self.rules))
34+
35+
def test_sankey_with_custom_fitness(self):
36+
fig = sankey_diagram(self.rules, "support", M=2)
37+
link_data = fig.data[0].link
38+
flow_values = link_data['value']
39+
self.assertGreater(len(flow_values), 0)
40+
41+
def test_sankey_no_empty_rules(self):
42+
fig = sankey_diagram([], "support", M=3)
43+
self.assertEqual(len(fig.data[0].link['source']), 0)
44+
self.assertEqual(len(fig.data[0].link['target']), 0)
45+
self.assertEqual(len(fig.data[0].link['value']), 0)
46+

0 commit comments

Comments
 (0)