diff --git a/openfisca_core/simulations/__init__.py b/openfisca_core/simulations/__init__.py index 670b922eb..a03d846dd 100644 --- a/openfisca_core/simulations/__init__.py +++ b/openfisca_core/simulations/__init__.py @@ -21,20 +21,17 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from openfisca_core.errors import ( # noqa: F401 - CycleError, - NaNCreationError, - SpiralError, -) +from openfisca_core.errors import CycleError, NaNCreationError, SpiralError -from .helpers import ( # noqa: F401 +from . import types +from .helpers import ( calculate_output_add, calculate_output_divide, check_type, transform_to_strict_syntax, ) -from .simulation import Simulation # noqa: F401 -from .simulation_builder import SimulationBuilder # noqa: F401 +from .simulation import Simulation +from .simulation_builder import SimulationBuilder __all__ = [ "CycleError", @@ -46,4 +43,5 @@ "calculate_output_divide", "check_type", "transform_to_strict_syntax", + "types", ] diff --git a/openfisca_core/simulations/_build_default_simulation.py b/openfisca_core/simulations/_build_default_simulation.py index f99c1d210..780dc9d49 100644 --- a/openfisca_core/simulations/_build_default_simulation.py +++ b/openfisca_core/simulations/_build_default_simulation.py @@ -1,17 +1,21 @@ """This module contains the _BuildDefaultSimulation class.""" -from typing import Union from typing_extensions import Self import numpy from .simulation import Simulation -from .typing import Entity, Population, TaxBenefitSystem +from .types import Populations, TaxBenefitSystem class _BuildDefaultSimulation: """Build a default simulation. + Attributes: + count(int): The number of periods. + populations(Populations): The built populations. + simulation(Simulation): The built simulation. + Args: tax_benefit_system(TaxBenefitSystem): The tax-benefit system. count(int): The number of periods. @@ -47,7 +51,7 @@ class _BuildDefaultSimulation: count: int #: The built populations. - populations: dict[str, Union[Population[Entity]]] + populations: Populations #: The built simulation. simulation: Simulation @@ -61,7 +65,7 @@ def add_count(self) -> Self: """Add the number of Population to the simulation. Returns: - _BuildDefaultSimulation: The builder. + Self: The builder. Examples: >>> from openfisca_core import entities, taxbenefitsystems @@ -94,7 +98,7 @@ def add_ids(self) -> Self: """Add the populations ids to the simulation. Returns: - _BuildDefaultSimulation: The builder. + Self: The builder. Examples: >>> from openfisca_core import entities, taxbenefitsystems @@ -129,7 +133,7 @@ def add_members_entity_id(self) -> Self: Each SingleEntity has its own GroupEntity. Returns: - _BuildDefaultSimulation: The builder. + Self: The builder. Examples: >>> from openfisca_core import entities, taxbenefitsystems diff --git a/openfisca_core/simulations/_build_from_variables.py b/openfisca_core/simulations/_build_from_variables.py index 60ff6148e..292a921d8 100644 --- a/openfisca_core/simulations/_build_from_variables.py +++ b/openfisca_core/simulations/_build_from_variables.py @@ -2,14 +2,15 @@ from __future__ import annotations +from collections.abc import Sized from typing_extensions import Self from openfisca_core import errors from ._build_default_simulation import _BuildDefaultSimulation -from ._type_guards import is_variable_dated +from ._guards import is_a_dated_value, is_a_pure_value from .simulation import Simulation -from .typing import Entity, Population, TaxBenefitSystem, Variables +from .types import Populations, TaxBenefitSystem, Variables class _BuildFromVariables: @@ -67,7 +68,7 @@ class _BuildFromVariables: default_period: str | None #: The built populations. - populations: dict[str, Population[Entity]] + populations: Populations #: The built simulation. simulation: Simulation @@ -99,7 +100,7 @@ def add_dated_values(self) -> Self: """Add the dated input values to the Simulation. Returns: - _BuildFromVariables: The builder. + Self: The builder. Examples: >>> from openfisca_core import entities, periods, taxbenefitsystems, variables @@ -141,7 +142,7 @@ def add_dated_values(self) -> Self: """ for variable, value in self.variables.items(): - if is_variable_dated(dated_variable := value): + if is_a_dated_value(dated_variable := value): for period, dated_value in dated_variable.items(): self.simulation.set_input(variable, period, dated_value) @@ -151,7 +152,7 @@ def add_undated_values(self) -> Self: """Add the undated input values to the Simulation. Returns: - _BuildFromVariables: The builder. + Self: The builder. Raises: SituationParsingError: If there is not a default period set. @@ -184,7 +185,7 @@ def add_undated_values(self) -> Self: >>> builder = _BuildFromVariables(tax_benefit_system, variables) >>> builder.add_undated_values() Traceback (most recent call last): - openfisca_core.errors.situation_parsing_error.SituationParsingError + openfisca_core.errors.situation_parsing_error.SituationParsingEr... >>> builder.default_period = period >>> builder.add_undated_values() <..._BuildFromVariables object at ...> @@ -199,7 +200,7 @@ def add_undated_values(self) -> Self: """ for variable, value in self.variables.items(): - if not is_variable_dated(undated_value := value): + if is_a_pure_value(undated_value := value): if (period := self.default_period) is None: message = ( "Can't deal with type: expected object. Input " @@ -218,7 +219,7 @@ def add_undated_values(self) -> Self: def _person_count(params: Variables) -> int: try: - first_value = next(iter(params.values())) + first_value: object = next(iter(params.values())) if isinstance(first_value, dict): first_value = next(iter(first_value.values())) @@ -226,7 +227,10 @@ def _person_count(params: Variables) -> int: if isinstance(first_value, str): return 1 - return len(first_value) + if isinstance(first_value, Sized): + return len(first_value) + + raise NotImplementedError except Exception: return 1 diff --git a/openfisca_core/simulations/_guards.py b/openfisca_core/simulations/_guards.py new file mode 100644 index 000000000..404d026d2 --- /dev/null +++ b/openfisca_core/simulations/_guards.py @@ -0,0 +1,613 @@ +"""Type guards to help type narrowing simulation parameters. + +Every calculation in a simulation requires an entity, a variable, a period, and +a value. However, the way users can specify these elements can vary. This +module provides type guards to help narrow down the type of simulation +parameters, to help both readability and maintainability. + +For example, the following is a perfectly valid, albeit complex, way to specify +a simulation's parameters:: + + .. code-block:: python + + params = { + "axes": [ + [ + { + "count": 2, + "max": 3000, + "min": 0, + "name": + "rent", + "period": "2018-11" + } + ] + ], + "households": { + "housea": { + "parents": ["Alicia", "Javier"] + }, + "houseb": { + "parents": ["Tom"] + }, + }, + "persons": { + "Alicia": { + "salary": { + "2018-11": 0 + } + }, + "Javier": {}, + "Tom": {} + }, + } + +""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing_extensions import TypeGuard + +import pydantic + +from openfisca_core import periods + +from .types import ( + Axes, + Axis, + DatedValue, + FullySpecifiedEntities, + ImplicitGroupEntities, + Params, + ParamsWithAxes, + PureValue, + Roles, + Variables, +) + +#: Pydantic type adapter to extract information from axes. +adapter = pydantic.TypeAdapter(Axis) + +#: Field schema for axes. +axis_schema = adapter.core_schema + +#: Required fields. +axis_required = [ + key for key, value in axis_schema["fields"].items() if value["required"] +] + + +def is_a_pure_value( + value: object, +) -> TypeGuard[PureValue]: + """Check if an input value is undated. + + The most atomic elements of a simulation are pure values. They can be + either scalars or vectors. For example:: + + .. code-block:: python + + 1.5 + True + [1000, 2000] + + Args: + value(object): A value. + + Returns: + bool: True if the value is undated. + + Examples: + >>> value = 2000 + >>> is_a_pure_value(value) + True + + >>> value = [2000, 3000] + >>> is_a_pure_value(value) + True + + >>> value = {"2000": 2000} + >>> is_a_pure_value(value) + False + + >>> value = {"2018-W01": [2000, 3000]} + >>> is_a_pure_value(value) + False + + >>> value = {"123": 123} + >>> is_a_pure_value(value) + False + + """ + + return not isinstance(value, dict) + + +def is_a_dated_value( + value: object, +) -> TypeGuard[DatedValue]: + """Check if an input value is dated. + + Pure values are associated with the simulation's period behind the scenes. + However, some calculations require different values for variables for + different periods. In such a case, users can specify dated values:: + + .. code-block:: python + + {"2018-01": 2000} + {"2018-W01": [2000, 3000]} + {"2018-W01-1": 2000, "2018-W01-2": [3000, 4000]} + + Args: + value(object): A value. + + Returns: + bool: True if the value is dated. + + Examples: + >>> value = 2000 + >>> is_a_dated_value(value) + False + + >>> value = [2000, 3000] + >>> is_a_dated_value(value) + False + + >>> value = {"2000": 2000} + >>> is_a_dated_value(value) + True + + >>> value = {"2018-W01": [2000, 3000]} + >>> is_a_dated_value(value) + True + + >>> value = {"123": 123} + >>> is_a_dated_value(value) + False + + """ + + if not isinstance(value, dict): + return False + + try: + return all(periods.period(key) for key in value.keys()) + + except ValueError: + return False + + +def are_variables( + value: object, +) -> TypeGuard[Variables]: + """Check if an input value is a map of variables. + + In a simulation, every value has to be associated with a variable. As with + values, variables cannot be inferred from the context. Users have to + explicitly specify them. For example:: + + .. code-block:: python + + {"salary": 2000} + {"taxes": {"2018-W01-1": [123, 234]}} + {"taxes": {"2018-W01-1": [123, 234]}, "salary": 123} + + Args: + value(object): A value. + + Returns: + bool: True if the value is a map of variables. + + Examples: + >>> value = 2000 + >>> are_variables(value) + False + + >>> value = [2000, 3000] + >>> are_variables(value) + False + + >>> value = {"2000": 2000} + >>> are_variables(value) + False + + >>> value = {"2018-W01": [2000, 3000]} + >>> are_variables(value) + False + + >>> value = {"salary": 123} + >>> are_variables(value) + True + + >>> value = {"taxes": {"2018-W01-1": [123, 234]}} + >>> are_variables(value) + True + + >>> value = {"taxes": {"2018-W01-1": [123, 234]}, "salary": 123} + >>> are_variables(value) + True + + """ + + if is_a_pure_value(value): + return False + + if is_a_dated_value(value): + return False + + return True + + +def are_roles( + value: object, +) -> TypeGuard[Roles]: + """Check if an input value is a map of roles. + + In a simulation, there are cases where we need to calculate things for + group entities, for example, households. In such cases, some calculations + require that we specify certain roles. For example:: + + .. code-block:: python + + {"principal": "Alicia"} + {"parents": ["Alicia", "Javier"]} + + Args: + value(object): A value. + + Returns: + bool: True if the value is a map of roles. + + Examples: + >>> value = "parent" + >>> are_roles(value) + False + + >>> value = ["dad", "son"] + >>> are_roles(value) + False + + >>> value = {"2018-W01": [2000, 3000]} + >>> are_roles(value) + False + + >>> value = {"salary": 123} + >>> are_roles(value) + False + + >>> value = {"principal": "Alicia"} + >>> are_roles(value) + True + + >>> value = {"kids": ["Alicia", "Javier"]} + >>> are_roles(value) + True + + >>> value = {"principal": "Alicia", "kids": ["Tom"]} + >>> are_roles(value) + True + + """ + + if not isinstance(value, dict): + return False + + for role_key, role_id in value.items(): + if not isinstance(role_key, str): + return False + + if not isinstance(role_id, (Iterable, str)): + return False + + if isinstance(role_id, Iterable): + for role in role_id: + if not isinstance(role, str): + return False + + return True + + +def are_axes(value: object) -> TypeGuard[Axes]: + """Check if the given params are axes. + + Axis expansion is a feature that allows users to parametrise some + dimensions in order to create and to evaluate a range of values for others. + + Args: + value(object): Simulation parameters. + + Returns: + bool: True if the params are axes. + + Examples: + >>> value = { + ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + >>> are_axes(value) + False + + >>> value = { + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + >>> are_axes(value) + False + + >>> value = [[{"a": 1, "b": 1, "c": 1}]] + >>> are_axes(value) + False + + >>> value = [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + >>> are_axes(value) + True + + """ + + if not isinstance(value, (list, tuple)): + return False + + (inner,) = value + + if not isinstance(inner, (list, tuple)): + return False + + return all(key in axis_required for key in inner[0].keys()) + + +def are_entities_specified( + params: Params, items: Iterable[str] +) -> TypeGuard[Variables]: + """Check if the params contains entities at all. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of variables. + + Returns: + bool: True if the params does not contain variables at the root level. + + Examples: + >>> variables = {"salary"} + + >>> params = { + ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + + >>> are_entities_specified(params, variables) + True + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}} + ... } + + >>> are_entities_specified(params, variables) + True + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}} + ... } + + >>> are_entities_specified(params, variables) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_specified(params, variables) + True + + >>> params = {"salary": {"2016-10": [12000, 13000]}} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": [12000, 13000]} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": 12000} + + >>> are_entities_specified(params, variables) + False + + >>> params = {} + + >>> are_entities_specified(params, variables) + False + + """ + + if not params: + return False + + return not any(key in items for key in params.keys()) + + +def are_entities_short_form( + params: Params, items: Iterable[str] +) -> TypeGuard[ImplicitGroupEntities]: + """Check if the params contain short form entities. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of entities in singular form. + + Returns: + bool: True if the params contain short form entities. + + Examples: + >>> entities = {"person", "household"} + + >>> params = { + ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + + >>> are_entities_short_form(params, entities) + False + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}} + ... } + + >>> are_entities_short_form(params, entities) + False + + >>> params = { + ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = { + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"household": {"parents": "Javier"}} + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_short_form(params, entities) + False + + >>> params = {"salary": 12000} + + >>> are_entities_short_form(params, entities) + False + + >>> params = {} + + >>> are_entities_short_form(params, entities) + False + + """ + + return not not set(params).intersection(items) + + +def are_entities_fully_specified( + params: Params, items: Iterable[str] +) -> TypeGuard[FullySpecifiedEntities]: + """Check if the params contain fully specified entities. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of entities in plural form. + + Returns: + bool: True if the params contain fully specified entities. + + Examples: + >>> entities = {"persons", "households"} + + >>> params = { + ... "axes": [ + ... [{"count": 2, "max": 3000, "min": 0, "name": "rent", "period": "2018-11"}] + ... ], + ... "households": { + ... "housea": {"parents": ["Alicia", "Javier"]}, + ... "houseb": {"parents": ["Tom"]}, + ... }, + ... "persons": {"Alicia": {"salary": {"2018-11": 0}}, "Javier": {}, "Tom": {}}, + ... } + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}} + ... } + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}} + ... } + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {"salary": 12000} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {} + + >>> are_entities_fully_specified(params, entities) + False + + """ + + if not params: + return False + + return all(key in items for key in params.keys() if key != "axes") + + +def has_axes(value: object) -> TypeGuard[ParamsWithAxes]: + """Check if the params contains axes. + + Args: + value(object): Simulation parameters. + + Returns: + bool: True if the params contain axes. + + Examples: + >>> value = { + ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + >>> has_axes(value) + True + + >>> value = { + ... "persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}} + ... } + >>> has_axes(value) + False + + """ + + if not isinstance(value, dict): + return False + + return value.get("axes", None) is not None diff --git a/openfisca_core/simulations/_type_guards.py b/openfisca_core/simulations/_type_guards.py deleted file mode 100644 index c34361041..000000000 --- a/openfisca_core/simulations/_type_guards.py +++ /dev/null @@ -1,304 +0,0 @@ -"""Type guards to help type narrowing simulation parameters.""" - -from __future__ import annotations - -from typing import Iterable -from typing_extensions import TypeGuard - -from .typing import ( - Axes, - DatedVariable, - FullySpecifiedEntities, - ImplicitGroupEntities, - Params, - UndatedVariable, - Variables, -) - - -def are_entities_fully_specified( - params: Params, items: Iterable[str] -) -> TypeGuard[FullySpecifiedEntities]: - """Check if the params contain fully specified entities. - - Args: - params(Params): Simulation parameters. - items(Iterable[str]): List of entities in plural form. - - Returns: - bool: True if the params contain fully specified entities. - - Examples: - >>> entities = {"persons", "households"} - - >>> params = { - ... "axes": [ - ... [{"count": 2, "max": 3000, "min": 0, "name": "rent", "period": "2018-11"}] - ... ], - ... "households": { - ... "housea": {"parents": ["Alicia", "Javier"]}, - ... "houseb": {"parents": ["Tom"]}, - ... }, - ... "persons": {"Alicia": {"salary": {"2018-11": 0}}, "Javier": {}, "Tom": {}}, - ... } - - >>> are_entities_fully_specified(params, entities) - True - - >>> params = { - ... "persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}} - ... } - - >>> are_entities_fully_specified(params, entities) - True - - >>> params = { - ... "persons": {"Javier": {"salary": {"2018-11": 2000}}} - ... } - - >>> are_entities_fully_specified(params, entities) - True - - >>> params = {"household": {"parents": ["Javier"]}} - - >>> are_entities_fully_specified(params, entities) - False - - >>> params = {"salary": {"2016-10": 12000}} - - >>> are_entities_fully_specified(params, entities) - False - - >>> params = {"salary": 12000} - - >>> are_entities_fully_specified(params, entities) - False - - >>> params = {} - - >>> are_entities_fully_specified(params, entities) - False - - """ - - if not params: - return False - - return all(key in items for key in params.keys() if key != "axes") - - -def are_entities_short_form( - params: Params, items: Iterable[str] -) -> TypeGuard[ImplicitGroupEntities]: - """Check if the params contain short form entities. - - Args: - params(Params): Simulation parameters. - items(Iterable[str]): List of entities in singular form. - - Returns: - bool: True if the params contain short form entities. - - Examples: - >>> entities = {"person", "household"} - - >>> params = { - ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, - ... "households": {"household": {"parents": ["Javier"]}}, - ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] - ... } - - >>> are_entities_short_form(params, entities) - False - - >>> params = { - ... "persons": {"Javier": {"salary": {"2018-11": 2000}}} - ... } - - >>> are_entities_short_form(params, entities) - False - - >>> params = { - ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, - ... "household": {"parents": ["Javier"]}, - ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] - ... } - - >>> are_entities_short_form(params, entities) - True - - >>> params = { - ... "household": {"parents": ["Javier"]}, - ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] - ... } - - >>> are_entities_short_form(params, entities) - True - - >>> params = {"household": {"parents": ["Javier"]}} - - >>> are_entities_short_form(params, entities) - True - - >>> params = {"household": {"parents": "Javier"}} - - >>> are_entities_short_form(params, entities) - True - - >>> params = {"salary": {"2016-10": 12000}} - - >>> are_entities_short_form(params, entities) - False - - >>> params = {"salary": 12000} - - >>> are_entities_short_form(params, entities) - False - - >>> params = {} - - >>> are_entities_short_form(params, entities) - False - - """ - - return not not set(params).intersection(items) - - -def are_entities_specified( - params: Params, items: Iterable[str] -) -> TypeGuard[Variables]: - """Check if the params contains entities at all. - - Args: - params(Params): Simulation parameters. - items(Iterable[str]): List of variables. - - Returns: - bool: True if the params does not contain variables at the root level. - - Examples: - >>> variables = {"salary"} - - >>> params = { - ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, - ... "households": {"household": {"parents": ["Javier"]}}, - ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] - ... } - - >>> are_entities_specified(params, variables) - True - - >>> params = { - ... "persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}} - ... } - - >>> are_entities_specified(params, variables) - True - - >>> params = { - ... "persons": {"Javier": {"salary": {"2018-11": 2000}}} - ... } - - >>> are_entities_specified(params, variables) - True - - >>> params = {"household": {"parents": ["Javier"]}} - - >>> are_entities_specified(params, variables) - True - - >>> params = {"salary": {"2016-10": [12000, 13000]}} - - >>> are_entities_specified(params, variables) - False - - >>> params = {"salary": {"2016-10": 12000}} - - >>> are_entities_specified(params, variables) - False - - >>> params = {"salary": [12000, 13000]} - - >>> are_entities_specified(params, variables) - False - - >>> params = {"salary": 12000} - - >>> are_entities_specified(params, variables) - False - - >>> params = {} - - >>> are_entities_specified(params, variables) - False - - """ - - if not params: - return False - - return not any(key in items for key in params.keys()) - - -def has_axes(params: Params) -> TypeGuard[Axes]: - """Check if the params contains axes. - - Args: - params(Params): Simulation parameters. - - Returns: - bool: True if the params contain axes. - - Examples: - >>> params = { - ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, - ... "households": {"household": {"parents": ["Javier"]}}, - ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] - ... } - - >>> has_axes(params) - True - - >>> params = { - ... "persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}} - ... } - - >>> has_axes(params) - False - - """ - - return params.get("axes", None) is not None - - -def is_variable_dated( - variable: DatedVariable | UndatedVariable, -) -> TypeGuard[DatedVariable]: - """Check if the variable is dated. - - Args: - variable(DatedVariable | UndatedVariable): A variable. - - Returns: - bool: True if the variable is dated. - - Examples: - >>> variable = {"2018-11": [2000, 3000]} - - >>> is_variable_dated(variable) - True - - >>> variable = {"2018-11": 2000} - - >>> is_variable_dated(variable) - True - - >>> variable = 2000 - - >>> is_variable_dated(variable) - False - - """ - - return isinstance(variable, dict) diff --git a/openfisca_core/simulations/helpers.py b/openfisca_core/simulations/helpers.py index d5984d88b..edc3fa1ad 100644 --- a/openfisca_core/simulations/helpers.py +++ b/openfisca_core/simulations/helpers.py @@ -2,7 +2,7 @@ from openfisca_core import errors -from .typing import ParamsWithoutAxes +from .types import ParamsWithoutAxes def calculate_output_add(simulation, variable_name: str, period): @@ -58,7 +58,7 @@ def check_unexpected_entities( >>> check_unexpected_entities(params, entities) Traceback (most recent call last): - openfisca_core.errors.situation_parsing_error.SituationParsingError + openfisca_core.errors.situation_parsing_error.SituationParsingError... """ diff --git a/openfisca_core/simulations/simulation.py b/openfisca_core/simulations/simulation.py index c4525525d..7cd2fe4c6 100644 --- a/openfisca_core/simulations/simulation.py +++ b/openfisca_core/simulations/simulation.py @@ -1,8 +1,6 @@ from __future__ import annotations -from typing import Dict, Mapping, NamedTuple, Optional, Set - -from openfisca_core.types import SinglePopulation, TaxBenefitSystem, Variable +from typing import NamedTuple, Optional, Set import tempfile import warnings @@ -12,6 +10,16 @@ from openfisca_core import commons, errors, indexed_enums, periods, tracers from openfisca_core import warnings as core_warnings +from .types import ( + EntityPlural, + GroupEntity, + GroupPopulation, + Populations, + TaxBenefitSystem, + Variable, + VariableName, +) + class Simulation: """ @@ -19,13 +27,13 @@ class Simulation: """ tax_benefit_system: TaxBenefitSystem - populations: Dict[str, SinglePopulation] + populations: Populations invalidated_caches: Set[Cache] def __init__( self, tax_benefit_system: TaxBenefitSystem, - populations: Mapping[str, SinglePopulation], + populations: Populations, ): """ This constructor is reserved for internal use; see :any:`SimulationBuilder`, @@ -93,7 +101,7 @@ def data_storage_dir(self): # ----- Calculation methods ----- # - def calculate(self, variable_name: str, period): + def calculate(self, variable_name: VariableName, period): """Calculate ``variable_name`` for ``period``.""" if period is not None and not isinstance(period, periods.Period): @@ -110,7 +118,7 @@ def calculate(self, variable_name: str, period): self.tracer.record_calculation_end() self.purge_cache_of_invalid_values() - def _calculate(self, variable_name: str, period: periods.Period): + def _calculate(self, variable_name: VariableName, period: periods.Period): """ Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists. @@ -162,7 +170,7 @@ def purge_cache_of_invalid_values(self): holder.delete_arrays(_period) self.invalidated_caches = set() - def calculate_add(self, variable_name: str, period): + def calculate_add(self, variable_name: VariableName, period): variable: Optional[Variable] variable = self.tax_benefit_system.get_variable( @@ -200,7 +208,7 @@ def calculate_add(self, variable_name: str, period): for sub_period in period.get_subperiods(variable.definition_period) ) - def calculate_divide(self, variable_name: str, period): + def calculate_divide(self, variable_name: VariableName, period): variable: Optional[Variable] variable = self.tax_benefit_system.get_variable( @@ -277,7 +285,7 @@ def calculate_divide(self, variable_name: str, period): return self.calculate(variable_name, calculation_period) / denominator - def calculate_output(self, variable_name: str, period): + def calculate_output(self, variable_name: VariableName, period): """ Calculate the value of a variable using the ``calculate_output`` attribute of the variable. """ @@ -427,7 +435,7 @@ def invalidate_spiral_variables(self, variable: str): # ----- Methods to access stored values ----- # - def get_array(self, variable_name: str, period): + def get_array(self, variable_name: VariableName, period): """ Return the value of ``variable_name`` for ``period``, if this value is alreay in the cache (if it has been set as an input or previously calculated). @@ -437,7 +445,7 @@ def get_array(self, variable_name: str, period): period = periods.period(period) return self.get_holder(variable_name).get_array(period) - def get_holder(self, variable_name: str): + def get_holder(self, variable_name: VariableName): """Get the holder associated with the variable.""" return self.get_variable_population(variable_name).get_holder(variable_name) @@ -500,7 +508,7 @@ def get_known_periods(self, variable): """ return self.get_holder(variable).get_known_periods() - def set_input(self, variable_name: str, period, value): + def set_input(self, variable_name: VariableName, period, value): """ Set a variable's value for a given period @@ -531,7 +539,7 @@ def set_input(self, variable_name: str, period, value): return self.get_holder(variable_name).set_input(period, value) - def get_variable_population(self, variable_name: str) -> SinglePopulation: + def get_variable_population(self, variable_name: VariableName) -> GroupPopulation: variable: Optional[Variable] variable = self.tax_benefit_system.get_variable( @@ -543,9 +551,7 @@ def get_variable_population(self, variable_name: str) -> SinglePopulation: return self.populations[variable.entity.key] - def get_population( - self, plural: Optional[str] = None - ) -> Optional[SinglePopulation]: + def get_population(self, plural: Optional[str] = None) -> Optional[GroupPopulation]: return next( ( population @@ -557,10 +563,14 @@ def get_population( def get_entity( self, - plural: Optional[str] = None, - ) -> Optional[SinglePopulation]: - population = self.get_population(plural) - return population and population.entity + plural: EntityPlural | None = None, + ) -> GroupEntity | None: + population: GroupPopulation | None = self.get_population(plural) + + if population is None: + return None + + return population.entity def describe_entities(self): return { diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index c42d0e4f2..f1a2b71b8 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -1,45 +1,50 @@ from __future__ import annotations from collections.abc import Iterable, Sequence -from numpy.typing import NDArray as Array -from typing import Dict, List import copy import dpath.util import numpy -from openfisca_core import entities, errors, periods, populations, variables +from openfisca_core import errors, periods from . import helpers from ._build_default_simulation import _BuildDefaultSimulation from ._build_from_variables import _BuildFromVariables -from ._type_guards import ( +from ._guards import ( are_entities_fully_specified, are_entities_short_form, are_entities_specified, has_axes, ) from .simulation import Simulation -from .typing import ( +from .types import ( + Array, Axis, - Entity, + EntityCounts, + EntityIds, + EntityRoles, FullySpecifiedEntities, GroupEntities, GroupEntity, ImplicitGroupEntities, + InputBuffer, + Memberships, Params, ParamsWithoutAxes, - Population, + Populations, Role, SingleEntity, + SinglePopulation, TaxBenefitSystem, + VariableEntity, Variables, ) class SimulationBuilder: - def __init__(self): + def __init__(self) -> None: self.default_period = ( None # Simulation period used for variables when no period is defined ) @@ -48,26 +53,24 @@ def __init__(self): ) # JSON input - Memory of known input values. Indexed by variable or axis name. - self.input_buffer: Dict[ - variables.Variable.name, Dict[str(periods.period), numpy.array] - ] = {} - self.populations: Dict[entities.Entity.key, populations.Population] = {} + self.input_buffer: InputBuffer = {} + self.populations: Populations = {} # JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes. - self.entity_counts: Dict[entities.Entity.plural, int] = {} + self.entity_counts: EntityCounts = {} # JSON input - List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. - self.entity_ids: Dict[entities.Entity.plural, List[int]] = {} + self.entity_ids: EntityIds = {} # Links entities with persons. For each person index in persons ids list, set entity index in entity ids id. E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id) - self.memberships: Dict[entities.Entity.plural, List[int]] = {} - self.roles: Dict[entities.Entity.plural, List[int]] = {} + self.memberships: Memberships = {} + self.roles: EntityRoles = {} - self.variable_entities: Dict[variables.Variable.name, entities.Entity] = {} + self.variable_entities: VariableEntity = {} self.axes = [[]] - self.axes_entity_counts: Dict[entities.Entity.plural, int] = {} - self.axes_entity_ids: Dict[entities.Entity.plural, List[int]] = {} - self.axes_memberships: Dict[entities.Entity.plural, List[int]] = {} - self.axes_roles: Dict[entities.Entity.plural, List[int]] = {} + self.axes_entity_counts: EntityCounts = {} + self.axes_entity_ids: EntityIds = {} + self.axes_memberships: Memberships = {} + self.axes_roles: EntityRoles = {} def build_from_dict( self, @@ -748,7 +751,7 @@ def expand_axes(self) -> None: if len(self.axes) == 1 and len(self.axes[0]): parallel_axes = self.axes[0] first_axis = parallel_axes[0] - axis_count: int = first_axis["count"] + axis_count = first_axis["count"] axis_entity = self.get_variable_entity(first_axis["name"]) axis_entity_step_size = self.entity_counts[axis_entity.plural] # Distribute values along axes @@ -803,10 +806,10 @@ def expand_axes(self) -> None: ) self.input_buffer[axis_name][str(axis_period)] = array - def get_variable_entity(self, variable_name: str) -> Entity: + def get_variable_entity(self, variable_name: str) -> SingleEntity: return self.variable_entities[variable_name] - def register_variable(self, variable_name: str, entity: Entity) -> None: + def register_variable(self, variable_name: str, entity: SingleEntity) -> None: self.variable_entities[variable_name] = entity def register_variables(self, simulation: Simulation) -> None: @@ -814,6 +817,6 @@ def register_variables(self, simulation: Simulation) -> None: variables: Iterable[str] = tax_benefit_system.variables.keys() for name in variables: - population: Population = simulation.get_variable_population(name) - entity: Entity = population.entity + population: SinglePopulation = simulation.get_variable_population(name) + entity: SingleEntity = population.entity self.register_variable(name, entity) diff --git a/openfisca_core/simulations/types.py b/openfisca_core/simulations/types.py new file mode 100644 index 000000000..68dd7dc96 --- /dev/null +++ b/openfisca_core/simulations/types.py @@ -0,0 +1,273 @@ +"""Type aliases of OpenFisca models to use in the context of simulations.""" + +from __future__ import annotations + +from collections.abc import Callable, Iterable, Sequence +from typing import Literal, NewType, Protocol, TypeVar, Union +from typing_extensions import NotRequired, TypeAlias, TypedDict + +import datetime + +from numpy import bool_ as Bool +from numpy import datetime64 as Date +from numpy import float32 as Float +from numpy import int16 as Enum +from numpy import int32 as Int +from numpy import str_ as String + +from openfisca_core import types as t + +# Generic type variables. +G = TypeVar("G", covariant=True) +T = TypeVar("T", Bool, Date, Enum, Float, Int, String, covariant=True) +U = TypeVar("U", bool, datetime.date, float, str) +V = TypeVar("V", covariant=True) + +# New types. + +#: Literally "axes". +AxesKey = Literal["axes"] + +#: For example "Juan". +EntityId = NewType("EntityId", int) + +#: For example "person". +EntityKey = NewType("EntityKey", str) + +#: For example "persons". +EntityPlural = NewType("EntityPlural", str) + +#: For example "2023-12". +PeriodStr = NewType("PeriodStr", str) + +#: For example "principal". +RoleKey = NewType("RoleKey", str) + +#: For example "parents". +RolePlural = NewType("RolePlural", str) + +#: For example "salary". +VariableName = NewType("VariableName", str) + +# Type aliases. + +#: Type alias for numpy arrays values. +Item: TypeAlias = Union[Bool, Date, Enum, Float, Int, String] + +#: Type Alias for a numpy Array. +Array: TypeAlias = t.Array + +#: Type alias for a role identifier. +RoleId: TypeAlias = EntityId + +# Entities + + +class CoreEntity(t.CoreEntity, Protocol): + key: EntityKey + plural: EntityPlural | None + + def get_variable( + self, + __variable_name: VariableName, + check_existence: bool = ..., + ) -> Variable[T] | None: + ... + + +class SingleEntity(t.SingleEntity, Protocol): + ... + + +class GroupEntity(t.GroupEntity, Protocol): + @property + def flattened_roles(self) -> Iterable[Role[G]]: + ... + + +class Role(t.Role, Protocol[G]): + ... + + +# Holders + + +class Holder(t.Holder, Protocol[V]): + @property + def variable(self) -> Variable[T]: + ... + + def get_array(self, __period: PeriodStr) -> t.Array[T] | None: + ... + + def set_input( + self, + __period: Period, + __array: t.Array[T] | Sequence[U], + ) -> t.Array[T] | None: + ... + + +# Periods + + +class Instant(t.Instant, Protocol): + ... + + +class Period(t.Period, Protocol): + ... + + +# Populations + + +class CorePopulation(t.CorePopulation, Protocol): + entity: CoreEntity + + def get_holder(self, __variable_name: VariableName) -> Holder[V]: + ... + + +class SinglePopulation(t.SinglePopulation, Protocol): + entity: SingleEntity + + +class GroupPopulation(t.GroupPopulation, Protocol): + entity: GroupEntity + members_entity_id: t.Array[String] + + def nb_persons(self, __role: Role[G] | None = ...) -> int: + ... + + +# Simulations + +#: Dictionary with axes parameters per variable. +InputBuffer: TypeAlias = dict[VariableName, dict[PeriodStr, Array]] + +#: Dictionary with entity/population key/pairs. +Populations: TypeAlias = dict[EntityKey, GroupPopulation] + +#: Dictionary with single entity count per group entity. +EntityCounts: TypeAlias = dict[EntityPlural, int] + +#: Dictionary with a list of single entities per group entity. +EntityIds: TypeAlias = dict[EntityPlural, Iterable[EntityId]] + +#: Dictionary with a list of members per group entity. +Memberships: TypeAlias = dict[EntityPlural, Iterable[int]] + +#: Dictionary with a list of roles per group entity. +EntityRoles: TypeAlias = dict[EntityPlural, Iterable[RoleKey]] + +#: Dictionary with a map between variables and entities. +VariableEntity: TypeAlias = dict[VariableName, CoreEntity] + +#: Type alias for a simulation dictionary with undated variable values. +PureValue: TypeAlias = Union[object, Sequence[object]] + +#: Type alias for a simulation dictionary with dated variable values. +DatedValue: TypeAlias = dict[PeriodStr, PureValue] + +#: Type alias for a simulation dictionary with abbreviated entities. +Variables: TypeAlias = dict[VariableName, Union[PureValue, DatedValue]] + +#: Type alias for a simulation dictionary defining the roles. +Roles: TypeAlias = Union[dict[RoleKey, RoleId], dict[RolePlural, Iterable[RoleId]]] + +#: Type alias for a simulation dictionary with axes parameters. +Axes: TypeAlias = Iterable[Iterable["Axis"]] + +#: Type alias for a simulation with fully specified single entities. +SingleEntities: TypeAlias = dict[str, dict[str, Variables]] + +#: Type alias for a simulation dictionary with implicit group entities. +ImplicitGroupEntities: TypeAlias = dict[str, Union[Roles, Variables]] + +#: Type alias for a simulation dictionary with explicit group entities. +GroupEntities: TypeAlias = dict[str, ImplicitGroupEntities] + +#: Type alias for a simulation dictionary with fully specified entities. +FullySpecifiedEntities: TypeAlias = Union[SingleEntities, GroupEntities] + +#: Type alias for a simulation dictionary without axes parameters. +ParamsWithoutAxes: TypeAlias = Union[ + Variables, ImplicitGroupEntities, FullySpecifiedEntities +] + +#: Type alias for a simulation dictionary with axes parameters. +ParamsWithAxes: TypeAlias = Union[dict[AxesKey, Axes], ParamsWithoutAxes] + +#: Type alias for a simulation dictionary with all the possible scenarios. +Params: TypeAlias = ParamsWithAxes + + +class Axis(TypedDict): + count: int + max: float + min: float + index: NotRequired[int] + name: EntityKey + period: NotRequired[Union[str, int]] + + +class Simulation(t.Simulation, Protocol): + ... + + +# Tax-Benefit systems + + +class TaxBenefitSystem(t.TaxBenefitSystem, Protocol): + @property + def person_entity(self) -> SingleEntity: + ... + + @person_entity.setter + def person_entity(self, person_entity: SingleEntity) -> None: + ... + + @property + def variables(self) -> dict[str, V]: + ... + + def entities_by_singular(self) -> dict[EntityKey, CoreEntity]: + ... + + def entities_plural(self) -> Iterable[EntityPlural]: + ... + + def get_variable( + self, + __variable_name: VariableName, + check_existence: bool = ..., + ) -> Variable[T] | None: + ... + + def instantiate_entities( + self, + ) -> Populations: + ... + + +# Variables + + +class Variable(t.Variable, Protocol[T]): + calculate_output: Callable[[Simulation, str, str], t.Array[T]] | None + definition_period: str + end: str + name: VariableName + + def default_array(self, __array_size: int) -> t.Array[T]: + ... + + def get_formula( + self, __period: Instant | Period | PeriodStr | Int + ) -> Formula | None: + ... + + +class Formula(t.Formula, Protocol): + ... diff --git a/openfisca_core/simulations/typing.py b/openfisca_core/simulations/typing.py deleted file mode 100644 index 8603d0d81..000000000 --- a/openfisca_core/simulations/typing.py +++ /dev/null @@ -1,199 +0,0 @@ -"""Type aliases of OpenFisca models to use in the context of simulations.""" - -from __future__ import annotations - -from collections.abc import Iterable, Sequence -from numpy.typing import NDArray as Array -from typing import Protocol, TypeVar, TypedDict, Union -from typing_extensions import NotRequired, Required, TypeAlias - -import datetime -from abc import abstractmethod - -from numpy import bool_ as Bool -from numpy import datetime64 as Date -from numpy import float32 as Float -from numpy import int16 as Enum -from numpy import int32 as Int -from numpy import str_ as String - -#: Generic type variables. -E = TypeVar("E") -G = TypeVar("G", covariant=True) -T = TypeVar("T", Bool, Date, Enum, Float, Int, String, covariant=True) -U = TypeVar("U", bool, datetime.date, float, str) -V = TypeVar("V", covariant=True) - - -#: Type alias for a simulation dictionary defining the roles. -Roles: TypeAlias = dict[str, Union[str, Iterable[str]]] - -#: Type alias for a simulation dictionary with undated variables. -UndatedVariable: TypeAlias = dict[str, object] - -#: Type alias for a simulation dictionary with dated variables. -DatedVariable: TypeAlias = dict[str, UndatedVariable] - -#: Type alias for a simulation dictionary with abbreviated entities. -Variables: TypeAlias = dict[str, Union[UndatedVariable, DatedVariable]] - -#: Type alias for a simulation with fully specified single entities. -SingleEntities: TypeAlias = dict[str, dict[str, Variables]] - -#: Type alias for a simulation dictionary with implicit group entities. -ImplicitGroupEntities: TypeAlias = dict[str, Union[Roles, Variables]] - -#: Type alias for a simulation dictionary with explicit group entities. -GroupEntities: TypeAlias = dict[str, ImplicitGroupEntities] - -#: Type alias for a simulation dictionary with fully specified entities. -FullySpecifiedEntities: TypeAlias = Union[SingleEntities, GroupEntities] - -#: Type alias for a simulation dictionary with axes parameters. -Axes: TypeAlias = dict[str, Iterable[Iterable["Axis"]]] - -#: Type alias for a simulation dictionary without axes parameters. -ParamsWithoutAxes: TypeAlias = Union[ - Variables, ImplicitGroupEntities, FullySpecifiedEntities -] - -#: Type alias for a simulation dictionary with axes parameters. -ParamsWithAxes: TypeAlias = Union[Axes, ParamsWithoutAxes] - -#: Type alias for a simulation dictionary with all the possible scenarios. -Params: TypeAlias = ParamsWithAxes - - -class Axis(TypedDict, total=False): - """Interface representing an axis of a simulation.""" - - count: Required[int] - index: NotRequired[int] - max: Required[float] - min: Required[float] - name: Required[str] - period: NotRequired[str | int] - - -class Entity(Protocol): - """Interface representing an entity of a simulation.""" - - key: str - plural: str | None - - def get_variable( - self, - __variable_name: str, - __check_existence: bool = ..., - ) -> Variable[T] | None: - """Get a variable.""" - - -class SingleEntity(Entity, Protocol): - """Interface representing a single entity of a simulation.""" - - -class GroupEntity(Entity, Protocol): - """Interface representing a group entity of a simulation.""" - - @property - @abstractmethod - def flattened_roles(self) -> Iterable[Role[G]]: - """Get the flattened roles of the GroupEntity.""" - - -class Holder(Protocol[V]): - """Interface representing a holder of a simulation's computed values.""" - - @property - @abstractmethod - def variable(self) -> Variable[T]: - """Get the Variable of the Holder.""" - - def get_array(self, __period: str) -> Array[T] | None: - """Get the values of the Variable for a given Period.""" - - def set_input( - self, - __period: Period, - __array: Array[T] | Sequence[U], - ) -> Array[T] | None: - """Set values for a Variable for a given Period.""" - - -class Period(Protocol): - """Interface representing a period of a simulation.""" - - -class Population(Protocol[E]): - """Interface representing a data vector of an Entity.""" - - count: int - entity: E - ids: Array[String] - - def get_holder(self, __variable_name: str) -> Holder[V]: - """Get the holder of a Variable.""" - - -class SinglePopulation(Population[E], Protocol): - """Interface representing a data vector of a SingleEntity.""" - - -class GroupPopulation(Population[E], Protocol): - """Interface representing a data vector of a GroupEntity.""" - - members_entity_id: Array[String] - - def nb_persons(self, __role: Role[G] | None = ...) -> int: - """Get the number of persons for a given Role.""" - - -class Role(Protocol[G]): - """Interface representing a role of the group entities of a simulation.""" - - -class TaxBenefitSystem(Protocol): - """Interface representing a tax-benefit system.""" - - @property - @abstractmethod - def person_entity(self) -> SingleEntity: - """Get the person entity of the tax-benefit system.""" - - @person_entity.setter - @abstractmethod - def person_entity(self, person_entity: SingleEntity) -> None: - """Set the person entity of the tax-benefit system.""" - - @property - @abstractmethod - def variables(self) -> dict[str, V]: - """Get the variables of the tax-benefit system.""" - - def entities_by_singular(self) -> dict[str, E]: - """Get the singular form of the entities' keys.""" - - def entities_plural(self) -> Iterable[str]: - """Get the plural form of the entities' keys.""" - - def get_variable( - self, - __variable_name: str, - __check_existence: bool = ..., - ) -> V | None: - """Get a variable.""" - - def instantiate_entities( - self, - ) -> dict[str, Population[E]]: - """Instantiate the populations of each Entity.""" - - -class Variable(Protocol[T]): - """Interface representing a variable of a tax-benefit system.""" - - end: str - - def default_array(self, __array_size: int) -> Array[T]: - """Fill an array with the default value of the Variable.""" diff --git a/openfisca_core/types.py b/openfisca_core/types.py index fcb7f4898..617cd2fee 100644 --- a/openfisca_core/types.py +++ b/openfisca_core/types.py @@ -145,16 +145,18 @@ def unit(self) -> DateUnit: class CorePopulation(Protocol): - ... - - -class SinglePopulation(CorePopulation, Protocol): + count: int entity: Any + ids: Array[numpy.str_] def get_holder(self, variable_name: Any) -> Any: ... +class SinglePopulation(CorePopulation, Protocol): + ... + + class GroupPopulation(CorePopulation, Protocol): ... diff --git a/openfisca_tasks/lint.mk b/openfisca_tasks/lint.mk index 4d1f3e097..128a3f51b 100644 --- a/openfisca_tasks/lint.mk +++ b/openfisca_tasks/lint.mk @@ -45,6 +45,7 @@ check-types: openfisca_core/indexed_enums \ openfisca_core/periods \ openfisca_core/projectors \ + openfisca_core/simulations \ openfisca_core/types.py @$(call print_pass,$@:) diff --git a/setup.cfg b/setup.cfg index cc850c06a..5c47d1cd8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -75,7 +75,7 @@ ignore_missing_imports = true implicit_reexport = false install_types = true non_interactive = true -plugins = numpy.typing.mypy_plugin +plugins = numpy.typing.mypy_plugin, pydantic.mypy pretty = true python_version = 3.9 strict = false diff --git a/setup.py b/setup.py index 9aa018fca..e44b38a96 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ "numpy >=1.24.2, <1.25", "pendulum >=3.0.0, <4.0.0", "psutil >=5.9.4, <6.0", + "pydantic >=2.9.1, <3.0", "pytest >=7.2.2, <8.0", "sortedcontainers >=2.4.0, <3.0", "typing_extensions >=4.5.0, <5.0",