diff --git a/src/eligibility_signposting_api/services/calculators/rule_calculator.py b/src/eligibility_signposting_api/services/calculators/rule_calculator.py index be418f10..7d75772a 100644 --- a/src/eligibility_signposting_api/services/calculators/rule_calculator.py +++ b/src/eligibility_signposting_api/services/calculators/rule_calculator.py @@ -40,6 +40,18 @@ def get_attribute_value(self) -> str | None: (r for r in self.person_data if r.get("ATTRIBUTE_TYPE", "") == "PERSON"), None ) attribute_value = person.get(self.rule.attribute_name) if person else None + case rules.RuleAttributeLevel.COHORT: + cohorts: Mapping[str, str | None] | None = next( + (r for r in self.person_data if r.get("ATTRIBUTE_TYPE", "") == "COHORTS"), None + ) + if self.rule.attribute_name == "COHORT_LABEL": + cohort_map = self.get_value(cohorts, "COHORT_MAP") + cohorts_dict = self.get_value(cohort_map, "cohorts") + m_dict = self.get_value(cohorts_dict, "M") + person_cohorts: set[str] = set(m_dict.keys()) + attribute_value = ",".join(person_cohorts) + else: + attribute_value = cohorts.get(self.rule.attribute_name) if cohorts else None case rules.RuleAttributeLevel.TARGET: target: Mapping[str, str | None] | None = next( (r for r in self.person_data if r.get("ATTRIBUTE_TYPE", "") == self.rule.attribute_target), None @@ -50,6 +62,11 @@ def get_attribute_value(self) -> str | None: raise NotImplementedError(msg) return attribute_value + @staticmethod + def get_value(dictionary: Mapping[str, Any] | None, key: str) -> dict: + v = dictionary.get(key, {}) if isinstance(dictionary, dict) else {} + return v if isinstance(v, dict) else {} + def evaluate_rule(self, attribute_value: str | None) -> tuple[eligibility.Status, str]: """Evaluate a rule against a person data attribute. Return the result, and the reason for the result.""" matcher_class = OperatorRegistry.get(self.rule.operator) diff --git a/src/eligibility_signposting_api/services/rules/operators.py b/src/eligibility_signposting_api/services/rules/operators.py index 7ecd15c1..1f9c4af8 100644 --- a/src/eligibility_signposting_api/services/rules/operators.py +++ b/src/eligibility_signposting_api/services/rules/operators.py @@ -165,8 +165,9 @@ def _matches(self, item: str | None) -> bool: class IsIn(Operator): def _matches(self, item: str | None) -> bool: item = item if item is not None else self.item_default - comparators = str(self.rule_value).split(",") - return str(item) in comparators + comparators = set(str(self.rule_value).split(",")) + items = set(str(item).split(",")) + return bool(items & comparators) @OperatorRegistry.register(RuleOperator.not_in) @@ -174,8 +175,9 @@ def _matches(self, item: str | None) -> bool: class NotIn(Operator): def _matches(self, item: str | None) -> bool: item = item if item is not None else self.item_default - comparators = str(self.rule_value).split(",") - return str(item) not in comparators + comparators = set(str(self.rule_value).split(",")) + items = set(str(item).split(",")) + return not bool(items & comparators) @OperatorRegistry.register(RuleOperator.is_null) diff --git a/tests/unit/services/calculators/test_eligibility_calculator.py b/tests/unit/services/calculators/test_eligibility_calculator.py index 8e8e564c..893cefb7 100644 --- a/tests/unit/services/calculators/test_eligibility_calculator.py +++ b/tests/unit/services/calculators/test_eligibility_calculator.py @@ -649,7 +649,7 @@ def test_base_eligible_and_icb_example( ], ) @freeze_time("2025-01-01") -def test_not_actionable_status_on_target_when_last_successful_date_lte_today( +def test_status_on_target_based_on_last_successful_date( vaccine: str, last_successful_date: str, expected_status: Status, test_comment: str, faker: Faker ): # Given @@ -718,3 +718,48 @@ def test_not_actionable_status_on_target_when_last_successful_date_lte_today( ), test_comment, ) + + +def test_status_on_cohort_attribute_level(faker: Faker): + # Given + nhs_number = NHSNumber(faker.nhs_number()) + + person_row = person_rows_builder(nhs_number, cohorts=["cohort1", "covid_eligibility_complaint_list"]) + + campaign_configs = [ + rule_builder.CampaignConfigFactory.build( + target="RSV", + iterations=[ + rule_builder.IterationFactory.build( + iteration_cohorts=[rule_builder.IterationCohortFactory.build(cohort_label="cohort1")], + iteration_rules=[ + rule_builder.IterationRuleFactory.build( + type=rules.RuleType.filter, + name=rules.RuleName("Exclude those in a complaint cohort"), + description=rules.RuleDescription( + "Ensure anyone who has registered a complaint is not shown as eligible" + ), + priority=15, + operator=rules.RuleOperator.member_of, + attribute_level=rules.RuleAttributeLevel.COHORT, + attribute_name=rules.RuleAttributeName("COHORT_LABEL"), + comparator=rules.RuleComparator("covid_eligibility_complaint_list"), + ) + ], + ) + ], + ) + ] + + calculator = EligibilityCalculator(person_row, campaign_configs) + + # When + actual = calculator.evaluate_eligibility() + + # Then + assert_that( + actual, + is_eligibility_status().with_conditions( + has_item(is_condition().with_condition_name(ConditionName("RSV")).and_status(Status.not_eligible)) + ), + ) diff --git a/tests/unit/services/operators/test_operators.py b/tests/unit/services/operators/test_operators.py index fe731e78..1c2b2ba7 100644 --- a/tests/unit/services/operators/test_operators.py +++ b/tests/unit/services/operators/test_operators.py @@ -374,6 +374,11 @@ ("PP77", RuleOperator.is_in, "QH8,QJG[[NVL:QH8]]", False, "Default value specified, but unused"), (None, RuleOperator.is_in, "QH8,QJG[[NVL:QH8]]", True, "Default value used"), (None, RuleOperator.is_in, "QH8,QJG[[NVL:PP77]]", False, "Default value used"), + ("QH8", RuleOperator.is_in, "QH8", True, ""), + ("QH8,QJG", RuleOperator.is_in, "QH8", True, ""), + ("QH8,QJG,QGX", RuleOperator.is_in, "QH8,QJG", True, ""), + ("QH8,QGX", RuleOperator.is_in, "QH8,QJG", True, ""), + ("QH8,QJG", RuleOperator.is_in, "QH8,QJG,QGX", True, ""), ] # is not_in @@ -386,6 +391,12 @@ ("PP77", RuleOperator.not_in, "QH8,QJG[[NVL:QH8]]", True, "Default value specified, but unused"), (None, RuleOperator.not_in, "QH8,QJG[[NVL:QH8]]", False, "Default value used"), (None, RuleOperator.not_in, "QH8,QJG[[NVL:PP77]]", True, "Default value used"), + ("QH8", RuleOperator.not_in, "QH8", False, ""), + ("QH8,QJG", RuleOperator.not_in, "QH8", False, ""), + ("QH8,QJG,QGX", RuleOperator.not_in, "QH8,QJG", False, ""), + ("QH8,QGX", RuleOperator.not_in, "QH8,QJG", False, ""), + ("QH8,QJG", RuleOperator.not_in, "QH8,QJG,QGX", False, ""), + ("QH8,QJG", RuleOperator.not_in, "QHX", True, ""), ] # is member_of