Skip to content

Commit 6206a26

Browse files
committed
Fix for transaction indexing (e.g. when the transaction dataframe is shuffled or sampled from) affecting metric calculations.
1 parent 90373f7 commit 6206a26

File tree

2 files changed

+65
-10
lines changed

2 files changed

+65
-10
lines changed

niaarm/rule.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -224,35 +224,32 @@ def __post_init__(self, transactions):
224224
min_ = transactions.min(numeric_only=True)
225225
max_ = transactions.max(numeric_only=True)
226226
acc = 0
227-
contains_antecedent = pd.Series(
228-
np.ones(self.num_transactions, dtype=bool), dtype=bool
229-
)
227+
contains_antecedent = np.ones(self.num_transactions, dtype=bool)
230228
for attribute in self.antecedent:
231229
if attribute.dtype != "cat":
232230
feature_min = min_[attribute.name]
233231
feature_max = max_[attribute.name]
234232
acc += 1 if feature_max == feature_min \
235233
else (attribute.max_val - attribute.min_val) / (feature_max - feature_min)
236-
contains_antecedent &= transactions[attribute.name] <= attribute.max_val
237-
contains_antecedent &= transactions[attribute.name] >= attribute.min_val
234+
contains_antecedent &= ((transactions[attribute.name] <= attribute.max_val) &
235+
(transactions[attribute.name] >= attribute.min_val)).to_numpy()
238236
else:
239237
contains_antecedent &= (
240238
np.isin(transactions[attribute.name], attribute.categories)
241239
)
242240

243241
self.antecedent_count = contains_antecedent.sum()
244242

245-
contains_consequent = pd.Series(
246-
np.ones(self.num_transactions, dtype=bool), dtype=bool
247-
)
243+
contains_consequent = np.ones(self.num_transactions, dtype=bool)
244+
248245
for attribute in self.consequent:
249246
if attribute.dtype != "cat":
250247
feature_min = min_[attribute.name]
251248
feature_max = max_[attribute.name]
252249
acc += 1 if feature_max == feature_min \
253250
else (attribute.max_val - attribute.min_val) / (feature_max - feature_min)
254-
contains_consequent &= transactions[attribute.name] <= attribute.max_val
255-
contains_consequent &= transactions[attribute.name] >= attribute.min_val
251+
contains_consequent &= ((transactions[attribute.name] <= attribute.max_val) &
252+
(transactions[attribute.name] >= attribute.min_val)).to_numpy()
256253
else:
257254
contains_consequent &= (
258255
np.isin(transactions[attribute.name], attribute.categories)

tests/test_metrics.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,61 @@ def test_zhang(self):
133133

134134
def test_leverage(self):
135135
self.assertAlmostEqual(self.rule.leverage, 2/3 - (2/3 * 2/3))
136+
137+
138+
class TestMetricsShuffleDataframe(TestCase):
139+
def setUp(self):
140+
self.data = Dataset(
141+
pd.DataFrame({"col1": [1.5, 2.5, 1.0], "col2": ["Green", "Blue", "Red"]}).sample(frac=1.0)
142+
)
143+
self.rule = Rule(
144+
[Feature("col1", dtype="float", min_val=1.0, max_val=1.5)],
145+
[Feature("col2", dtype="cat", categories=["Red", "Green"])],
146+
transactions=self.data.transactions,
147+
)
148+
149+
def test_support(self):
150+
self.assertEqual(self.rule.support, 2 / 3)
151+
152+
def test_confidence(self):
153+
self.assertEqual(self.rule.confidence, 1)
154+
155+
def test_lift(self):
156+
self.assertEqual(self.rule.lift, 1.5)
157+
158+
def test_coverage(self):
159+
self.assertEqual(self.rule.coverage, 2 / 3)
160+
161+
def test_rhs_support(self):
162+
self.assertEqual(self.rule.rhs_support, 2 / 3)
163+
164+
def test_conviction(self):
165+
self.assertAlmostEqual(
166+
self.rule.conviction,
167+
(1 - self.rule.rhs_support)
168+
/ (1 - self.rule.confidence + 2.220446049250313e-16),
169+
)
170+
171+
def test_amplitude(self):
172+
self.assertEqual(self.rule.amplitude, 5 / 6)
173+
174+
def test_inclusion(self):
175+
self.assertEqual(self.rule.inclusion, 1)
176+
177+
def test_interestingness(self):
178+
self.assertEqual(self.rule.interestingness, 1 * 1 * (1 - (2 / 3) / 3))
179+
180+
def test_comprehensibility(self):
181+
self.assertAlmostEqual(self.rule.comprehensibility, 0.630929753571)
182+
183+
def test_netconf(self):
184+
self.assertAlmostEqual(self.rule.netconf, ((2/3) - (2/3 * 2/3))/(2/3 * 1/3))
185+
186+
def test_yulesq(self):
187+
self.assertAlmostEqual(self.rule.yulesq,1)
188+
189+
def test_zhang(self):
190+
self.assertAlmostEqual(self.rule.zhang, 1)
191+
192+
def test_leverage(self):
193+
self.assertAlmostEqual(self.rule.leverage, 2/3 - (2/3 * 2/3))

0 commit comments

Comments
 (0)