Skip to content

Commit 12e5cb3

Browse files
committed
feat: add Apriori with association rule mining (support, confidence, lift)
1 parent 4ce1185 commit 12e5cb3

File tree

1 file changed

+92
-85
lines changed

1 file changed

+92
-85
lines changed
Lines changed: 92 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
"""
2-
Apriori Algorithm is an association rule mining technique, also known as market basket
3-
analysis, aims to discover interesting relationships or associations among a set of
4-
items in a transactional or relational database.
2+
Apriori Algorithm with Association Rules (support, confidence, lift).
53
6-
For example, Apriori Algorithm states: "If a customer buys item A and item B, then they
7-
are likely to buy item C." This rule suggests a relationship between items A, B, and C,
8-
indicating that customers who purchased A and B are more likely to also purchase item C.
4+
This implementation finds:
5+
- Frequent itemsets
6+
- Association rules with minimum confidence and lift
97
108
WIKI: https://en.wikipedia.org/wiki/Apriori_algorithm
11-
Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
129
"""
1310

1411
from itertools import combinations
12+
from collections import defaultdict
1513

1614

1715
def load_data() -> list[list[str]]:
    """Return a small sample set of shopping transactions.

    >>> load_data()
    [['milk'], ['milk', 'butter'], ['milk', 'bread'], ['milk', 'bread', 'chips']]
    """
    return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]]
2523

2624

27-
def prune(itemset: list, candidates: list, length: int) -> list:
28-
"""
29-
Prune candidate itemsets that are not frequent.
30-
The goal of pruning is to filter out candidate itemsets that are not frequent. This
31-
is done by checking if all the (k-1) subsets of a candidate itemset are present in
32-
the frequent itemsets of the previous iteration (valid subsequences of the frequent
33-
itemsets from the previous iteration).
34-
35-
Prunes candidate itemsets that are not frequent.
36-
37-
>>> itemset = ['X', 'Y', 'Z']
38-
>>> candidates = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
39-
>>> prune(itemset, candidates, 2)
40-
[['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
41-
42-
>>> itemset = ['1', '2', '3', '4']
43-
>>> candidates = ['1', '2', '4']
44-
>>> prune(itemset, candidates, 3)
45-
[]
46-
"""
47-
pruned = []
48-
for candidate in candidates:
49-
is_subsequence = True
50-
for item in candidate:
51-
if item not in itemset or itemset.count(item) < length - 1:
52-
is_subsequence = False
25+
class Apriori:
    """Apriori algorithm with association-rule mining.

    Frequent itemsets are discovered level-wise (1-itemsets, 2-itemsets, ...)
    and stored in ``self.itemsets`` as one ``{frozenset: support}`` dict per
    level.  Rules ``A -> B`` that meet the ``min_confidence`` and ``min_lift``
    thresholds are stored in ``self.rules`` as
    ``(antecedent, consequent, confidence, lift)`` tuples.

    WIKI: https://en.wikipedia.org/wiki/Apriori_algorithm
    """

    def __init__(
        self,
        transactions: list[list[str]],
        min_support: float = 0.25,
        min_confidence: float = 0.5,
        min_lift: float = 1.0,
    ) -> None:
        """Run the full mining pipeline on *transactions*.

        :param transactions: iterable of transactions, each a list of items
        :param min_support: minimum fraction of transactions an itemset must
            appear in to count as frequent
        :param min_confidence: minimum confidence for a rule to be kept
        :param min_lift: minimum lift for a rule to be kept
        """
        self.transactions = [set(transaction) for transaction in transactions]
        self.min_support = min_support
        self.min_confidence = min_confidence
        self.min_lift = min_lift
        self.itemsets: list[dict[frozenset, float]] = []
        self.rules: list[tuple[frozenset, frozenset, float, float]] = []

        # Mine eagerly so results are available right after construction.
        self.find_frequent_itemsets()
        self.generate_association_rules()

    def _get_support(self, itemset: frozenset) -> float:
        """Return the fraction of transactions that contain *itemset*.

        Returns 0.0 for an empty transaction list instead of raising
        ZeroDivisionError.
        """
        if not self.transactions:
            return 0.0
        hits = sum(1 for transaction in self.transactions if itemset.issubset(transaction))
        return hits / len(self.transactions)

    def confidence(self, antecedent: frozenset, consequent: frozenset) -> float:
        """Return confidence of the rule ``antecedent -> consequent``.

        confidence(A -> B) = support(A | B) / support(A); 0.0 when A never
        occurs (avoids division by zero).
        """
        support_antecedent = self._get_support(antecedent)
        if support_antecedent == 0:
            return 0.0
        return self._get_support(antecedent | consequent) / support_antecedent

    def lift(self, antecedent: frozenset, consequent: frozenset) -> float:
        """Return lift of the rule ``antecedent -> consequent``.

        lift(A -> B) = confidence(A -> B) / support(B); 0.0 when B never
        occurs.  Lift > 1 means A and B co-occur more often than expected
        under independence.
        """
        support_consequent = self._get_support(consequent)
        if support_consequent == 0:
            return 0.0
        return self.confidence(antecedent, consequent) / support_consequent

    def find_frequent_itemsets(self) -> list[dict[frozenset, float]]:
        """Populate and return ``self.itemsets`` (one support dict per level)."""
        total = len(self.transactions)
        item_counts: defaultdict[frozenset, int] = defaultdict(int)
        for transaction in self.transactions:
            for item in transaction:
                item_counts[frozenset([item])] += 1

        # Level 1: frequent single items (total > 0 whenever item_counts
        # is non-empty, so the division is safe).
        current_itemsets = {
            itemset: count / total
            for itemset, count in item_counts.items()
            if count / total >= self.min_support
        }
        self.itemsets.append(current_itemsets)

        k = 2
        while current_itemsets:
            # Join step: union pairs of frequent (k-1)-itemsets into
            # k-element candidates; prune any candidate with an infrequent
            # (k-1)-subset (the Apriori downward-closure property).
            candidates = set()
            keys = list(current_itemsets)
            for i in range(len(keys)):
                for j in range(i + 1, len(keys)):
                    union = keys[i] | keys[j]
                    if len(union) == k and all(
                        frozenset(subset) in current_itemsets
                        for subset in combinations(union, k - 1)
                    ):
                        candidates.add(union)

            # Compute support exactly once per candidate (the previous
            # version scanned all transactions twice per candidate).
            candidate_support = {candidate: self._get_support(candidate) for candidate in candidates}
            freq_candidates = {
                candidate: support
                for candidate, support in candidate_support.items()
                if support >= self.min_support
            }
            if not freq_candidates:
                break

            self.itemsets.append(freq_candidates)
            current_itemsets = freq_candidates
            k += 1

        return self.itemsets

    def generate_association_rules(self) -> list[tuple[frozenset, frozenset, float, float]]:
        """Populate and return ``self.rules`` from the frequent itemsets.

        Every non-empty proper subset of each frequent itemset is tried as
        an antecedent; rules below the confidence or lift thresholds are
        discarded.
        """
        for level in self.itemsets:
            for itemset in level:
                if len(itemset) < 2:
                    continue  # a rule needs at least one item on each side
                for size in range(1, len(itemset)):
                    for subset in combinations(itemset, size):
                        antecedent = frozenset(subset)
                        consequent = itemset - antecedent
                        conf = self.confidence(antecedent, consequent)
                        rule_lift = self.lift(antecedent, consequent)
                        if conf >= self.min_confidence and rule_lift >= self.min_lift:
                            self.rules.append((antecedent, consequent, conf, rule_lift))
        return self.rules
94102

95103

96104
if __name__ == "__main__":
    import doctest

    doctest.testmod()

    # Mine the sample basket data with permissive rule thresholds so that
    # every candidate rule is shown.
    transactions = load_data()
    model = Apriori(transactions, min_support=0.25, min_confidence=0.1, min_lift=0.0)

    print("Frequent itemsets:")
    for level in model.itemsets:
        for items, sup in level.items():
            print(f"{set(items)}: {sup:.2f}")

    print("\nAssociation Rules:")
    for antecedent, consequent, conf, lift in model.rules:
        print(f"{set(antecedent)} -> {set(consequent)}, conf={conf:.2f}, lift={lift:.2f}")

0 commit comments

Comments
 (0)