Skip to content

Commit a6e3dfa

Browse files
authored
Merge pull request #148 from vrabiczan/TwoKey_plot
Two key plot
2 parents 13471fc + 09477a9 commit a6e3dfa

File tree

3 files changed

+192
-0
lines changed

3 files changed

+192
-0
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from examples.visualization_examples.prepare_datasets import get_data_developer_salary_data
2+
from niaarm import Dataset, get_rules
3+
from niaarm.visualize import two_key_plot
4+
5+
# Get prepared data developer salary data
arm_df = get_data_developer_salary_data()

# Prepare Dataset
dataset = Dataset(
    path_or_df=arm_df,
    delimiter=",",
)

# Get rules
metrics = ("support", "confidence")
rules, run_time = get_rules(
    dataset=dataset,
    algorithm="DifferentialEvolution",
    metrics=metrics,
    max_evals=500,
)

# Sort rules
rules.sort(by="support")

# Print rule information
print("\nRules:")
print(rules)
print(f"\nTime to generate rules: {run_time:.3f} seconds")

# The run may produce fewer than four rules, so a fixed index of 3 could
# raise IndexError; fall back to the last available rule in that case.
if len(rules) > 0:
    example_rule = rules[min(3, len(rules) - 1)]
    print("\nRule information: ", example_rule)
    print("Antecedent: ", example_rule.antecedent)
    print("Consequent: ", example_rule.consequent)
    print("Confidence: ", example_rule.confidence)
    print("Support: ", example_rule.support)
    print("Lift: ", example_rule.lift)
print("\nMetrics:", metrics)

# Visualize two-key plot
fig = two_key_plot(rules=rules, metrics=metrics, interactive=True)
fig.show()

niaarm/visualize.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,110 @@ def create_plot_data(data_frame):
448448
plt.grid(which="both", color="grey", linestyle="-", linewidth=0.5)
449449

450450
return plt
451+
452+
def two_key_plot(rules, metrics, interactive=False):
    """
    Visualize rules as a two-key plot: two metrics on the axes, rule order as color.

    The order of a rule is computed as the number of items in its antecedent
    plus the number of items in its consequent (0 if the rule has neither
    attribute).

    Args:
        rules (Rule): Association rule or rules to visualize. A single rule
            (anything that is not a list and has no ``data`` attribute) is
            wrapped in a list.
        metrics (tuple): Names of exactly two distinct rule attributes to
            display on the x and y axes. The names 'rule' and 'order' are
            reserved for the plot's own columns.
        interactive (bool): Make plot interactive. Default: False.

    Returns:
        The figure returned by ``px.scatter`` if ``interactive`` is True,
        otherwise the ``plt`` module with the scatter plot drawn on a new figure.

    Raises:
        ValueError: If ``metrics`` does not contain exactly two entries, if
            both entries are the same, or if an entry is 'rule' or 'order'.
    """

    # Ensure exactly two metrics for the axes
    if len(metrics) != 2:
        raise ValueError("Please provide exactly two metrics for a two-key plot.")

    x_metric, y_metric = metrics[0], metrics[1]

    # The metric names become column names next to "rule" and "order";
    # a duplicate or reserved name would make two columns share one list.
    if x_metric == y_metric or {x_metric, y_metric} & {"rule", "order"}:
        raise ValueError(
            "Metrics for a two-key plot must be two different names "
            "other than 'rule' and 'order'."
        )

    # Function to prepare the data
    def prepare_data(rules):
        data = {
            "rule": [],
            x_metric: [],
            y_metric: [],
            "order": [],  # Store rule order (length)
        }

        for rule in rules:
            data["rule"].append(repr(rule))
            data[x_metric].append(getattr(rule, x_metric))
            data[y_metric].append(getattr(rule, y_metric))

            # Order = total number of items in antecedent and consequent
            if hasattr(rule, "antecedent") and hasattr(rule, "consequent"):
                rule_order = len(rule.antecedent) + len(rule.consequent)
            else:
                rule_order = 0  # Fallback if structure is missing

            data["order"].append(rule_order)

        return pd.DataFrame(data)

    # Check if one or more rules
    if not hasattr(rules, "data") and not isinstance(rules, list):
        rules = [rules]

    # Prepare the data
    df = prepare_data(rules)

    # Interactive plot using Plotly
    if interactive:
        title = (
            f"Interactive two-key plot for {len(rules)} rules"
            if len(rules) > 1
            else "Interactive two-key plot for rule"
        )

        # Order is cast to str so it is treated as a discrete category
        fig = px.scatter(
            data_frame=df,
            x=x_metric,
            y=y_metric,
            color=df["order"].astype(str),
            hover_name="rule",
            title=title,
            labels={"color": "order"},
            color_discrete_sequence=px.colors.qualitative.Plotly,
        )
        fig.update_layout(
            xaxis_title=x_metric,
            yaxis_title=y_metric,
            legend_title="Order",
        )
        return fig

    # Static plot using Matplotlib
    plt.figure(figsize=(12, 8))

    # Map each order to a unique color
    unique_orders = sorted(df["order"].unique())
    color_map = plt.colormaps.get_cmap("Set1")
    color_indices = np.linspace(0, 1, len(unique_orders))
    colors = [color_map(i) for i in color_indices]
    color_mapping = {order: colors[i] for i, order in enumerate(unique_orders)}

    # Plot each order separately for discrete colors
    for order in unique_orders:
        subset = df[df["order"] == order]
        x_data = np.array(subset[x_metric].tolist())
        y_data = np.array(subset[y_metric].tolist())

        plt.scatter(
            x_data,
            y_data,
            label=order,
            color=color_mapping[order],
            alpha=0.7,
        )

    # Add legend and labels
    plt.title(f"Two-key plot for {len(rules)} rules")
    plt.xlabel(x_metric)
    plt.ylabel(y_metric)
    plt.legend(title="Order")
    plt.grid(True)
    return plt
557+

tests/test_two_key_plot.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from unittest import TestCase
2+
import matplotlib.pyplot as plt
3+
import pandas as pd
4+
from niaarm.visualize import two_key_plot
5+
6+
class Rule:
    """Minimal stand-in for an association rule, used only by these tests."""

    def __init__(self, antecedent, consequent, support, confidence):
        # Both sides of the rule
        self.antecedent, self.consequent = antecedent, consequent
        # The two metrics the tests plot
        self.support, self.confidence = support, confidence

    def __repr__(self):
        return "Rule({} -> {})".format(self.antecedent, self.consequent)
15+
16+
class TestTwoKeyPlot(TestCase):
    """Tests for two_key_plot using lightweight stand-in Rule objects."""

    @classmethod
    def setUpClass(cls):
        cls.rule1 = Rule(antecedent=["A", "B"], consequent=["C"], support=0.3, confidence=0.8)
        cls.rule2 = Rule(antecedent=["D"], consequent=["E", "F"], support=0.5, confidence=0.7)
        cls.rule3 = Rule(antecedent=["G", "H"], consequent=["I"], support=0.2, confidence=0.9)

        cls.rules = [cls.rule1, cls.rule2, cls.rule3]  # Ensure rules are available to all tests

    def tearDown(self):
        # The static plot opens a new Matplotlib figure on every call; close
        # them so figures do not accumulate across tests.
        plt.close("all")

    def test_two_key_plot(self):
        metrics = ("support", "confidence")

        plot = two_key_plot(self.rules, metrics, interactive=False)

        # Verify that the return type is Matplotlib's pyplot
        self.assertIs(plot, plt)

        # Ensure a figure is created
        self.assertTrue(plt.gcf().axes, "No axes found in the generated plot.")

    def test_invalid_metrics(self):
        with self.assertRaises(ValueError):
            two_key_plot(self.rules, ("support",), interactive=False)

    def test_interactive_plot(self):
        metrics = ("support", "confidence")
        fig = two_key_plot(self.rules, metrics, interactive=True)

        # Verify that a Plotly figure is returned
        self.assertEqual(fig.__class__.__name__, "Figure", "Expected a Plotly figure but got a different type.")

0 commit comments

Comments
 (0)