From a5b542f066f990a1cfdcce17f30251ab4318748e Mon Sep 17 00:00:00 2001 From: "howsun.jow" Date: Tue, 10 Jun 2025 14:27:00 +0100 Subject: [PATCH 1/2] Fix for rule metric calculations where we have multiple categories attributes. --- niaarm/rule.py | 4 ++-- tests/test_metrics.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/niaarm/rule.py b/niaarm/rule.py index ce73b79..d1a7c51 100644 --- a/niaarm/rule.py +++ b/niaarm/rule.py @@ -237,7 +237,7 @@ def __post_init__(self, transactions): contains_antecedent &= transactions[attribute.name] >= attribute.min_val else: contains_antecedent &= ( - transactions[attribute.name] == attribute.categories[0] + np.isin(transactions[attribute.name], attribute.categories) ) self.antecedent_count = contains_antecedent.sum() @@ -255,7 +255,7 @@ def __post_init__(self, transactions): contains_consequent &= transactions[attribute.name] >= attribute.min_val else: contains_consequent &= ( - transactions[attribute.name] == attribute.categories[0] + np.isin(transactions[attribute.name], attribute.categories) ) self.__amplitude = 1 - (1 / (len(self.antecedent) + len(self.consequent))) * acc self.consequent_count = contains_consequent.sum() diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 2805588..907f539 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,4 +1,5 @@ import os +import pandas as pd from unittest import TestCase from niaarm import Dataset, Feature, Rule @@ -74,3 +75,28 @@ def test_zhang(self): def test_leverage(self): self.assertAlmostEqual(self.rule_one.leverage, 0.102040816326) self.assertAlmostEqual(self.rule_two.leverage, 0.102040816326) + + +class TestMetricsMultipleCategories(TestCase): + def setUp(self): + self.data = Dataset( + pd.DataFrame({'col1': [1.5,2.5, 1.0], 'col2': ['Green', 'Blue', 'Red']}) + ) + self.rule = Rule([Feature("col1", dtype="float", min_val=0.99, max_val=1.51)], + [Feature("col2", dtype="cat", categories=["Red", "Green"])], + transactions=self.data.transactions) + + def test_support(self): + self.assertEqual(self.rule.support, 2 / 3) + + def test_coverage(self): + self.assertEqual(self.rule.coverage, 2 / 3) + + def test_rhs_support(self): + self.assertEqual(self.rule.rhs_support, 2 / 3) + + def test_confidence(self): + self.assertEqual(self.rule.confidence, 1) + + def test_lift(self): + self.assertEqual(self.rule.lift, 1.5) From 1f5b7955cd3027a10ec481a1166b33ac67382cd1 Mon Sep 17 00:00:00 2001 From: "howsun.jow" Date: Wed, 11 Jun 2025 09:48:36 +0100 Subject: [PATCH 2/2] Added other metrics to unit test. --- tests/test_metrics.py | 49 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 907f539..7d76e37 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -80,23 +80,56 @@ def test_leverage(self): class TestMetricsMultipleCategories(TestCase): def setUp(self): self.data = Dataset( - pd.DataFrame({'col1': [1.5,2.5, 1.0], 'col2': ['Green', 'Blue', 'Red']}) + pd.DataFrame({"col1": [1.5, 2.5, 1.0], "col2": ["Green", "Blue", "Red"]}) + ) + self.rule = Rule( + [Feature("col1", dtype="float", min_val=1.0, max_val=1.5)], + [Feature("col2", dtype="cat", categories=["Red", "Green"])], + transactions=self.data.transactions, ) - self.rule = Rule([Feature("col1", dtype="float", min_val=0.99, max_val=1.51)], - [Feature("col2", dtype="cat", categories=["Red", "Green"])], - transactions=self.data.transactions) def test_support(self): self.assertEqual(self.rule.support, 2 / 3) + def test_confidence(self): + self.assertEqual(self.rule.confidence, 1) + + def test_lift(self): + self.assertEqual(self.rule.lift, 1.5) + def test_coverage(self): self.assertEqual(self.rule.coverage, 2 / 3) def test_rhs_support(self): self.assertEqual(self.rule.rhs_support, 2 / 3) - def test_confidence(self): - self.assertEqual(self.rule.confidence, 1) + def test_conviction(self): + self.assertAlmostEqual( + self.rule.conviction, + (1 - self.rule.rhs_support) + / (1 - self.rule.confidence + 2.220446049250313e-16), + ) - def test_lift(self): - self.assertEqual(self.rule.lift, 1.5) + def test_amplitude(self): + self.assertEqual(self.rule.amplitude, 5 / 6) + + def test_inclusion(self): + self.assertEqual(self.rule.inclusion, 1) + + def test_interestingness(self): + self.assertEqual(self.rule.interestingness, 1 * 1 * (1 - (2 / 3) / 3)) + + def test_comprehensibility(self): + self.assertAlmostEqual(self.rule.comprehensibility, 0.630929753571) + + def test_netconf(self): + self.assertAlmostEqual(self.rule.netconf, ((2/3) - (2/3 * 2/3))/(2/3 * 1/3)) + + def test_yulesq(self): + self.assertAlmostEqual(self.rule.yulesq,1) + + def test_zhang(self): + self.assertAlmostEqual(self.rule.zhang, 1) + + def test_leverage(self): + self.assertAlmostEqual(self.rule.leverage, 2/3 - (2/3 * 2/3))