diff --git a/niaarm/rule.py b/niaarm/rule.py index d1a7c51..15eb221 100644 --- a/niaarm/rule.py +++ b/niaarm/rule.py @@ -224,17 +224,15 @@ def __post_init__(self, transactions): min_ = transactions.min(numeric_only=True) max_ = transactions.max(numeric_only=True) acc = 0 - contains_antecedent = pd.Series( - np.ones(self.num_transactions, dtype=bool), dtype=bool - ) + contains_antecedent = np.ones(self.num_transactions, dtype=bool) for attribute in self.antecedent: if attribute.dtype != "cat": feature_min = min_[attribute.name] feature_max = max_[attribute.name] acc += 1 if feature_max == feature_min \ else (attribute.max_val - attribute.min_val) / (feature_max - feature_min) - contains_antecedent &= transactions[attribute.name] <= attribute.max_val - contains_antecedent &= transactions[attribute.name] >= attribute.min_val + contains_antecedent &= ((transactions[attribute.name] <= attribute.max_val) & + (transactions[attribute.name] >= attribute.min_val)).to_numpy() else: contains_antecedent &= ( np.isin(transactions[attribute.name], attribute.categories) @@ -242,17 +240,16 @@ def __post_init__(self, transactions): self.antecedent_count = contains_antecedent.sum() - contains_consequent = pd.Series( - np.ones(self.num_transactions, dtype=bool), dtype=bool - ) + contains_consequent = np.ones(self.num_transactions, dtype=bool) + for attribute in self.consequent: if attribute.dtype != "cat": feature_min = min_[attribute.name] feature_max = max_[attribute.name] acc += 1 if feature_max == feature_min \ else (attribute.max_val - attribute.min_val) / (feature_max - feature_min) - contains_consequent &= transactions[attribute.name] <= attribute.max_val - contains_consequent &= transactions[attribute.name] >= attribute.min_val + contains_consequent &= ((transactions[attribute.name] <= attribute.max_val) & + (transactions[attribute.name] >= attribute.min_val)).to_numpy() else: contains_consequent &= ( np.isin(transactions[attribute.name], attribute.categories) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 7d76e37..186eb49 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -133,3 +133,61 @@ def test_zhang(self): def test_leverage(self): self.assertAlmostEqual(self.rule.leverage, 2/3 - (2/3 * 2/3)) + + +class TestMetricsShuffleDataframe(TestCase): + def setUp(self): + self.data = Dataset( + pd.DataFrame({"col1": [1.5, 2.5, 1.0], "col2": ["Green", "Blue", "Red"]}).sample(frac=1.0) + ) + self.rule = Rule( + [Feature("col1", dtype="float", min_val=1.0, max_val=1.5)], + [Feature("col2", dtype="cat", categories=["Red", "Green"])], + transactions=self.data.transactions, + ) + + def test_support(self): + self.assertEqual(self.rule.support, 2 / 3) + + def test_confidence(self): + self.assertEqual(self.rule.confidence, 1) + + def test_lift(self): + self.assertEqual(self.rule.lift, 1.5) + + def test_coverage(self): + self.assertEqual(self.rule.coverage, 2 / 3) + + def test_rhs_support(self): + self.assertEqual(self.rule.rhs_support, 2 / 3) + + def test_conviction(self): + self.assertAlmostEqual( + self.rule.conviction, + (1 - self.rule.rhs_support) + / (1 - self.rule.confidence + 2.220446049250313e-16), + ) + + def test_amplitude(self): + self.assertEqual(self.rule.amplitude, 5 / 6) + + def test_inclusion(self): + self.assertEqual(self.rule.inclusion, 1) + + def test_interestingness(self): + self.assertEqual(self.rule.interestingness, 1 * 1 * (1 - (2 / 3) / 3)) + + def test_comprehensibility(self): + self.assertAlmostEqual(self.rule.comprehensibility, 0.630929753571) + + def test_netconf(self): + self.assertAlmostEqual(self.rule.netconf, ((2/3) - (2/3 * 2/3))/(2/3 * 1/3)) + + def test_yulesq(self): + self.assertAlmostEqual(self.rule.yulesq,1) + + def test_zhang(self): + self.assertAlmostEqual(self.rule.zhang, 1) + + def test_leverage(self): + self.assertAlmostEqual(self.rule.leverage, 2/3 - (2/3 * 2/3))