Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import defaultdict
from collections.abc import Collection, Iterator, Mapping
from dataclasses import dataclass, field
from functools import cached_property
from itertools import groupby
from typing import Any

Expand Down Expand Up @@ -45,6 +46,13 @@ def campaigns_grouped_by_condition_name(
):
yield condition_name, list(campaign_group)

@cached_property
def person_cohorts(self) -> set[str]:
cohorts_row: Mapping[str, dict[str, dict[str, dict[str, Any]]]] = next(
(row for row in self.person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"), {}
)
return set(cohorts_row.get("COHORT_MAP", {}).get("cohorts", {}).get("M", {}).keys())

def evaluate_eligibility(self) -> eligibility.EligibilityStatus:
"""Iterates over campaign groups, evaluates eligibility, and returns a consolidated status."""

Expand Down Expand Up @@ -78,12 +86,7 @@ def check_base_eligibility(self, iteration: rules.Iteration) -> bool:
}
if magic_cohort in iteration_cohorts:
return True

cohorts_row: Mapping[str, dict[str, dict[str, dict[str, Any]]]] = next(
(row for row in self.person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"), {}
)
person_cohorts: set[str] = set(cohorts_row.get("COHORT_MAP", {}).get("cohorts", {}).get("M", {}).keys())
return bool(iteration_cohorts & person_cohorts)
return bool(iteration_cohorts & self.person_cohorts)

def evaluate_eligibility_by_iteration_rules(
self, campaign_group: list[rules.CampaignConfig]
Expand Down Expand Up @@ -123,8 +126,12 @@ def evaluate_priority_group(
worst_status_so_far_for_condition: eligibility.Status,
) -> tuple[eligibility.Status, list[eligibility.Reason], list[eligibility.Reason]]:
exclusion_reasons, actionable_reasons = [], []

exclude_capable_rules = [
ir for ir in iteration_rule_group if ir.type in (rules.RuleType.filter, rules.RuleType.suppression)
ir
for ir in iteration_rule_group
if ir.type in (rules.RuleType.filter, rules.RuleType.suppression)
and (ir.cohort_label is None or (ir.cohort_label in self.person_cohorts))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like it will work for the case of multiple rules, with the same priority, that are executed in an AND

]

best_status = eligibility.Status.not_eligible if exclude_capable_rules else eligibility.Status.actionable
Expand Down
4 changes: 3 additions & 1 deletion tests/fixtures/builders/model/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def future_date(days_ahead: int = 365) -> date:
class IterationCohortFactory(ModelFactory[rules.IterationCohort]): ...


class IterationRuleFactory(ModelFactory[rules.IterationRule]): ...
class IterationRuleFactory(ModelFactory[rules.IterationRule]):
attribute_target = None
Copy link
Contributor Author

@Karthikeyannhs Karthikeyannhs May 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no random values will be ret for target_attribute and cohort_label.

cohort_label = None


class IterationFactory(ModelFactory[rules.Iteration]):
Expand Down
47 changes: 46 additions & 1 deletion tests/unit/services/calculators/test_eligibility_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ def test_status_on_target_based_on_last_successful_date(
),
rule_builder.IterationRuleFactory.build(
type=rules.RuleType.suppression,
name=rules.RuleName("You have a future booking to be vaccinated against RSV"),
name=rules.RuleName("You have a vaccination date in the future for RSV"),
description=rules.RuleDescription("Exclude anyone with future Completed RSV Vaccination"),
priority=10,
operator=rules.RuleOperator.day_lte,
Expand Down Expand Up @@ -763,3 +763,48 @@ def test_status_on_cohort_attribute_level(faker: Faker):
has_item(is_condition().with_condition_name(ConditionName("RSV")).and_status(Status.not_eligible))
),
)


@pytest.mark.parametrize(
("person_cohorts", "cohort_label", "expected_status", "test_comment"),
[
(["cohort1", "cohort2"], "cohort1", Status.not_actionable, "matches the cohort label"),
(["cohort2", "cohort3"], "cohort1", Status.actionable, "doesn't match the cohort label"),
],
)
def test_status_if_iteration_rules_contains_cohort_label_field(
person_cohorts, cohort_label: str, expected_status: Status, test_comment: str, faker: Faker
):
# Given
nhs_number = NHSNumber(faker.nhs_number())
date_of_birth = DateOfBirth(faker.date_of_birth(minimum_age=66, maximum_age=74))

person_rows = person_rows_builder(nhs_number, date_of_birth=date_of_birth, cohorts=person_cohorts)
campaign_configs = [
rule_builder.CampaignConfigFactory.build(
target="RSV",
iterations=[
rule_builder.IterationFactory.build(
iteration_cohorts=[
rule_builder.IterationCohortFactory.build(cohort_label="cohort1"),
rule_builder.IterationCohortFactory.build(cohort_label="cohort2"),
],
iteration_rules=[rule_builder.PersonAgeSuppressionRuleFactory.build(cohort_label=cohort_label)],
)
],
)
]

calculator = EligibilityCalculator(person_rows, campaign_configs)

# When
actual = calculator.evaluate_eligibility()

# Then
assert_that(
actual,
is_eligibility_status().with_conditions(
has_items(is_condition().with_condition_name(ConditionName("RSV")).and_status(expected_status))
),
test_comment,
)
50 changes: 50 additions & 0 deletions tests/unit/services/calculators/test_rule_calculator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from collections.abc import Collection, Mapping
from typing import Any

import pytest

from eligibility_signposting_api.model import rules
from eligibility_signposting_api.services.calculators.rule_calculator import RuleCalculator
from tests.fixtures.builders.model import rule as rule_builder

Row = Collection[Mapping[str, Any]]


@pytest.mark.parametrize(
("person_data", "rule", "expected"),
[
# PERSON attribute level
(
[{"ATTRIBUTE_TYPE": "PERSON", "POSTCODE": "SW19"}],
rule_builder.IterationRuleFactory.build(
attribute_level=rules.RuleAttributeLevel.PERSON, attribute_name="POSTCODE"
),
"SW19",
),
# TARGET attribute level
(
[{"ATTRIBUTE_TYPE": "RSV", "LAST_SUCCESSFUL_DATE": "20240101"}],
rule_builder.IterationRuleFactory.build(
attribute_level=rules.RuleAttributeLevel.TARGET,
attribute_name="LAST_SUCCESSFUL_DATE",
attribute_target="RSV",
),
"20240101",
),
# COHORT attribute level
(
[{"ATTRIBUTE_TYPE": "COHORTS", "COHORT_LABEL": ""}],
rule_builder.IterationRuleFactory.build(
attribute_level=rules.RuleAttributeLevel.COHORT, attribute_name="COHORT_LABEL"
),
"",
),
],
)
def test_get_attribute_value_for_all_attribute_levels(person_data: Row, rule: rules.IterationRule, expected: str):
# Given
calc = RuleCalculator(person_data=person_data, rule=rule)
# When
actual = calc.get_attribute_value()
# Then
assert actual == expected
Loading