diff --git a/openfisca_core/commons/formulas.py b/openfisca_core/commons/formulas.py index 6a90622147..7249c16781 100644 --- a/openfisca_core/commons/formulas.py +++ b/openfisca_core/commons/formulas.py @@ -1,9 +1,10 @@ +from __future__ import annotations + from typing import Any, Dict, Sequence, TypeVar +from openfisca_core.typing import ArrayLike, ArrayType import numpy -from openfisca_core.types import ArrayLike, ArrayType - T = TypeVar("T") diff --git a/openfisca_core/commons/misc.py b/openfisca_core/commons/misc.py index dd05cea11b..2461b9afa2 100644 --- a/openfisca_core/commons/misc.py +++ b/openfisca_core/commons/misc.py @@ -1,6 +1,7 @@ -from typing import TypeVar +from __future__ import annotations -from openfisca_core.types import ArrayType +from typing import TypeVar +from openfisca_core.typing import ArrayType T = TypeVar("T") diff --git a/openfisca_core/commons/rates.py b/openfisca_core/commons/rates.py index d682824207..196f2072ea 100644 --- a/openfisca_core/commons/rates.py +++ b/openfisca_core/commons/rates.py @@ -1,9 +1,10 @@ +from __future__ import annotations + from typing import Optional +from openfisca_core.typing import ArrayLike, ArrayType import numpy -from openfisca_core.types import ArrayLike, ArrayType - def average_rate( target: ArrayType[float], diff --git a/openfisca_core/data_storage/on_disk_storage.py b/openfisca_core/data_storage/on_disk_storage.py index 10d4696b58..a8467ab7b1 100644 --- a/openfisca_core/data_storage/on_disk_storage.py +++ b/openfisca_core/data_storage/on_disk_storage.py @@ -1,10 +1,15 @@ +from __future__ import annotations + +from typing import Any, AbstractSet, MutableMapping +from openfisca_core.typing import ArrayType, PeriodProtocol + import os import shutil import numpy from openfisca_core import periods -from openfisca_core.indexed_enums import EnumArray +from openfisca_core.indexed_enums import Enum, EnumArray class OnDiskStorage: @@ -12,9 +17,15 @@ class OnDiskStorage: Low-level class responsible for storing 
and retrieving calculated vectors on disk """ - def __init__(self, storage_dir, is_eternal = False, preserve_storage_dir = False): - self._files = {} - self._enums = {} + def __init__( + self, + storage_dir: str, + is_eternal: bool = False, + preserve_storage_dir: bool = False, + ) -> None: + + self._files: MutableMapping[PeriodProtocol, ArrayType[Any]] = {} + self._enums: MutableMapping[str, Enum] = {} self.is_eternal = is_eternal self.preserve_storage_dir = preserve_storage_dir self.storage_dir = storage_dir @@ -26,7 +37,7 @@ def _decode_file(self, file): else: return numpy.load(file) - def get(self, period): + def get(self, period: PeriodProtocol) -> ArrayType[Any]: if self.is_eternal: period = periods.period(periods.ETERNITY) period = periods.period(period) @@ -36,7 +47,7 @@ def get(self, period): return None return self._decode_file(values) - def put(self, value, period): + def put(self, value: ArrayType[Any], period: PeriodProtocol) -> None: if self.is_eternal: period = periods.period(periods.ETERNITY) period = periods.period(period) @@ -65,10 +76,11 @@ def delete(self, period = None): if not period.contains(period_item) } - def get_known_periods(self): + def get_known_periods(self) -> AbstractSet[PeriodProtocol]: return self._files.keys() - def restore(self): + def restore(self) -> None: + files: MutableMapping[PeriodProtocol, ArrayType[Any]] self._files = files = {} # Restore self._files from content of storage_dir. 
for filename in os.listdir(self.storage_dir): @@ -79,7 +91,7 @@ def restore(self): period = periods.period(filename_core) files[period] = path - def __del__(self): + def __del__(self) -> None: if self.preserve_storage_dir: return shutil.rmtree(self.storage_dir) # Remove the holder temporary files diff --git a/openfisca_core/errors/variable_not_found_error.py b/openfisca_core/errors/variable_not_found_error.py index f84ce06f95..4a79eb1566 100644 --- a/openfisca_core/errors/variable_not_found_error.py +++ b/openfisca_core/errors/variable_not_found_error.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from openfisca_core.typing import TaxBenefitSystemProtocol + import os @@ -6,7 +10,11 @@ class VariableNotFoundError(Exception): Exception raised when a variable has been queried but is not defined in the TaxBenefitSystem. """ - def __init__(self, variable_name, tax_benefit_system): + def __init__( + self, + variable_name: str, + tax_benefit_system: TaxBenefitSystemProtocol, + ) -> None: """ :param variable_name: Name of the variable that was queried. 
:param tax_benefit_system: Tax benefits system that does not contain `variable_name` diff --git a/openfisca_core/holders/holder.py b/openfisca_core/holders/holder.py index 3d0379d22d..a4aaf47b6c 100644 --- a/openfisca_core/holders/holder.py +++ b/openfisca_core/holders/holder.py @@ -1,3 +1,8 @@ +from __future__ import annotations + +from typing import Any, Optional, Sequence +from openfisca_core.typing import ArrayType + import os import warnings @@ -48,7 +53,12 @@ def clone(self, population): return new - def create_disk_storage(self, directory = None, preserve = False): + def create_disk_storage( + self, + directory: Optional[str] = None, + preserve: bool = False, + ) -> OnDiskStorage: + if directory is None: directory = self.simulation.data_storage_dir storage_dir = os.path.join(directory, self.variable.name) @@ -71,12 +81,13 @@ def delete_arrays(self, period = None): if self._disk_storage: self._disk_storage.delete(period) - def get_array(self, period): + def get_array(self, period: periods.Period) -> Any: """ Get the value of the variable for the given period. If the value is not known, return ``None``. """ + if self.variable.is_neutralized: return self.default_array() value = self._memory_storage.get(period) @@ -122,7 +133,7 @@ def get_memory_usage(self): return usage - def get_known_periods(self): + def get_known_periods(self) -> Sequence[periods.Period]: """ Get the list of periods the variable value is known for. 
""" @@ -227,7 +238,12 @@ def _set(self, period, value): else: self._memory_storage.put(value, period) - def put_in_cache(self, value, period): + def put_in_cache( + self, + value: ArrayType[Any], + period: periods.Period, + ) -> None: + if self._do_not_store: return diff --git a/openfisca_core/populations/population.py b/openfisca_core/populations/population.py index 41cdbcd8c4..fa13ef3070 100644 --- a/openfisca_core/populations/population.py +++ b/openfisca_core/populations/population.py @@ -1,3 +1,8 @@ +from __future__ import annotations + +from typing import Optional +from openfisca_core.typing import ArrayLike + import traceback import numpy @@ -13,8 +18,8 @@ def __init__(self, entity): self.simulation = None self.entity = entity self._holders = {} - self.count = 0 - self.ids = [] + self.count: Optional[int] = 0 + self.ids: ArrayLike[str] = [] def clone(self, simulation): result = Population(self.entity) @@ -36,7 +41,7 @@ def __getattr__(self, attribute): raise AttributeError("You tried to use the '{}' of '{}' but that is not a known attribute.".format(attribute, self.entity.key)) return projector - def get_index(self, id): + def get_index(self, id: str) -> int: return self.ids.index(id) # Calculations diff --git a/openfisca_core/simulations/simulation.py b/openfisca_core/simulations/simulation.py index 5dd2694292..c737260839 100644 --- a/openfisca_core/simulations/simulation.py +++ b/openfisca_core/simulations/simulation.py @@ -1,3 +1,14 @@ +from __future__ import annotations + +from typing import Any, Mapping, Optional, Set, Tuple +from openfisca_core.typing import ( + ArrayType, + HolderProtocol, + PeriodProtocol, + PopulationProtocol, + TaxBenefitSystemProtocol, + ) + import tempfile import warnings @@ -18,14 +29,15 @@ class Simulation: def __init__( self, - tax_benefit_system, - populations - ): + tax_benefit_system: TaxBenefitSystemProtocol, + populations: Mapping[str, PopulationProtocol] + ) -> None: """ This constructor is reserved for internal use; see 
:any:`SimulationBuilder`, which is the preferred way to obtain a Simulation initialized with a consistent set of Entities. """ + self.tax_benefit_system = tax_benefit_system assert tax_benefit_system is not None @@ -34,7 +46,7 @@ def __init__( self.link_to_entities_instances() self.create_shortcuts() - self.invalidated_caches = set() + self.invalidated_caches: Set[Tuple[str, PeriodProtocol]] = set() self.debug = False self.trace = False @@ -83,7 +95,11 @@ def data_storage_dir(self): # ----- Calculation methods ----- # - def calculate(self, variable_name, period): + def calculate( + self, + variable_name: str, + period: Optional[Any], + ) -> ArrayType[Any]: """Calculate ``variable_name`` for ``period``.""" if period is not None and not isinstance(period, Period): @@ -291,10 +307,15 @@ def _check_for_cycle(self, variable: str, period): message = "Quasicircular definition detected on formula {}@{} involving {}".format(variable, period, self.tracer.stack) raise SpiralError(message, variable) - def invalidate_cache_entry(self, variable: str, period): + def invalidate_cache_entry( + self, + variable: str, + period: PeriodProtocol, + ) -> None: + self.invalidated_caches.add((variable, period)) - def invalidate_spiral_variables(self, variable: str): + def invalidate_spiral_variables(self, variable: str) -> None: # Visit the stack, from the bottom (most recent) up; we know that we'll find # the variable implicated in the spiral (max_spiral_loops+1) times; we keep the # intermediate values computed (to avoid impacting performance) but we mark them @@ -319,7 +340,7 @@ def get_array(self, variable_name, period): period = periods.period(period) return self.get_holder(variable_name).get_array(period) - def get_holder(self, variable_name): + def get_holder(self, variable_name: str) -> HolderProtocol: """ Get the :obj:`.Holder` associated with the variable ``variable_name`` for the simulation """ @@ -414,7 +435,11 @@ def get_variable_population(self, variable_name): variable = 
self.tax_benefit_system.get_variable(variable_name, check_existence = True) return self.populations[variable.entity.key] - def get_population(self, plural = None): + def get_population( + self, + plural: Optional[str] = None, + ) -> Optional[PopulationProtocol]: + return next((population for population in self.populations.values() if population.entity.plural == plural), None) def get_entity(self, plural = None): diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index 88553488db..552fa2b278 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -1,6 +1,15 @@ +from __future__ import annotations + +import typing +from typing import Any, Mapping, Optional, Sequence +from openfisca_core.typing import ( + ArrayType, + AxisSchema, + TaxBenefitSystemProtocol, + ) + import copy import dpath -import typing import numpy @@ -14,12 +23,14 @@ class SimulationBuilder: - def __init__(self): + default_period: Optional[str] + + def __init__(self) -> None: self.default_period = None # Simulation period used for variables when no period is defined self.persons_plural = None # Plural name for person entity in current tax and benefits system # JSON input - Memory of known input values. Indexed by variable or axis name. - self.input_buffer: typing.Dict[Variable.name, typing.Dict[str(periods.period), numpy.array]] = {} + self.input_buffer: typing.Dict[Variable.name, typing.Dict[str, ArrayType]] = {} self.populations: typing.Dict[Entity.key, Population] = {} # JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes. 
self.entity_counts: typing.Dict[Entity.plural, int] = {} @@ -32,13 +43,17 @@ def __init__(self): self.variable_entities: typing.Dict[Variable.name, Entity] = {} - self.axes = [[]] + self.axes: Sequence[Sequence[AxisSchema]] = [[]] self.axes_entity_counts: typing.Dict[Entity.plural, int] = {} self.axes_entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {} self.axes_memberships: typing.Dict[Entity.plural, typing.List[int]] = {} self.axes_roles: typing.Dict[Entity.plural, typing.List[int]] = {} - def build_from_dict(self, tax_benefit_system, input_dict): + def build_from_dict( + self, + tax_benefit_system: TaxBenefitSystemProtocol, + input_dict: Mapping[str, Any], + ) -> Simulation: """ Build a simulation from ``input_dict`` @@ -322,7 +337,7 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): self.roles[entity.plural] = self.roles[entity.plural].tolist() self.memberships[entity.plural] = self.memberships[entity.plural].tolist() - def set_default_period(self, period_str): + def set_default_period(self, period_str: Optional[str]) -> None: if period_str: self.default_period = str(periods.period(period_str)) diff --git a/openfisca_core/taxbenefitsystems/tax_benefit_system.py b/openfisca_core/taxbenefitsystems/tax_benefit_system.py index 26e37a7b81..4d8d0bd555 100644 --- a/openfisca_core/taxbenefitsystems/tax_benefit_system.py +++ b/openfisca_core/taxbenefitsystems/tax_benefit_system.py @@ -1,3 +1,8 @@ +from __future__ import annotations + +from typing import Mapping, Optional, overload +from typing_extensions import Literal + import copy import glob import importlib @@ -68,10 +73,10 @@ def base_tax_benefit_system(self): self._base_tax_benefit_system = base_tax_benefit_system = baseline.base_tax_benefit_system return base_tax_benefit_system - def instantiate_entities(self): + def instantiate_entities(self) -> Mapping[str, Population]: person = self.person_entity members = Population(person) - entities: typing.Dict[Entity.key, Entity] 
= {person.key: members} + entities: typing.Dict[Entity.key, Population] = {person.key: members} for entity in self.group_entities: entities[entity.key] = GroupPopulation(entity, members) @@ -284,7 +289,27 @@ def apply_reform(self, reform_path): return reform(self) - def get_variable(self, variable_name, check_existence = False): + @overload + def get_variable( + self, + variable_name: str, + check_existence: Literal[True] = ..., + ) -> Variable: + ... + + @overload + def get_variable( + self, + variable_name: str, + check_existence: bool = ..., + ) -> Optional[Variable]: + ... + + def get_variable( + self, + variable_name: str, + check_existence: bool = False, + ) -> Optional[Variable]: """ Get a variable from the tax and benefit system. @@ -307,8 +332,19 @@ def neutralize_variable(self, variable_name): """ self.variables[variable_name] = variables.get_neutralized_variable(self.get_variable(variable_name)) - def annualize_variable(self, variable_name: str, period: typing.Optional[Period] = None): - self.variables[variable_name] = variables.get_annualized_variable(self.get_variable(variable_name, period)) + def annualize_variable( + self, + variable_name: str, + period: Optional[Period] = None, + ) -> None: + + variable: Variable + variable = self.get_variable(variable_name, check_existence = True) + + annualized: Variable + annualized = variables.get_annualized_variable(variable, period) + + self.variables[variable_name] = annualized def load_parameters(self, path_to_yaml_dir): """ @@ -355,7 +391,7 @@ def get_parameters_at_instant(self, instant): self._parameters_at_instant_cache[instant] = parameters_at_instant return parameters_at_instant - def get_package_metadata(self): + def get_package_metadata(self) -> Mapping[str, str]: """ Gets metatada relative to the country package the tax and benefit system is built from. 
@@ -384,19 +420,31 @@ def get_package_metadata(self): } module = inspect.getmodule(self) + + if module is None: + return fallback_metadata + if not module.__package__: return fallback_metadata + package_name = module.__package__.split('.')[0] + try: distribution = pkg_resources.get_distribution(package_name) + except pkg_resources.DistributionNotFound: return fallback_metadata - location = inspect.getsourcefile(module).split(package_name)[0].rstrip('/') + sourcefile = inspect.getsourcefile(module) + + if sourcefile is None: + return fallback_metadata + + location = sourcefile.split(package_name)[0].rstrip('/') home_page_metadatas = [ metadata.split(':', 1)[1].strip(' ') - for metadata in distribution._get_metadata(distribution.PKG_INFO) if 'Home-page' in metadata + for metadata in distribution._get_metadata(distribution.PKG_INFO) if 'Home-page' in metadata # type: ignore ] repository_url = home_page_metadatas[0] if home_page_metadatas else '' return { diff --git a/openfisca_core/tools/__init__.py b/openfisca_core/tools/__init__.py index 9b1dd2cc5d..6022bdd08e 100644 --- a/openfisca_core/tools/__init__.py +++ b/openfisca_core/tools/__init__.py @@ -1,85 +1,13 @@ -# -*- coding: utf-8 -*- - - -import os - -import numexpr - -from openfisca_core.indexed_enums import EnumArray - - -def assert_near(value, target_value, absolute_error_margin = None, message = '', relative_error_margin = None): - ''' - - :param value: Value returned by the test - :param target_value: Value that the test should return to pass - :param absolute_error_margin: Absolute error margin authorized - :param message: Error message to be displayed if the test fails - :param relative_error_margin: Relative error margin authorized - - Limit : This function cannot be used to assert near periods. 
- - ''' - - import numpy as np - - if absolute_error_margin is None and relative_error_margin is None: - absolute_error_margin = 0 - if not isinstance(value, np.ndarray): - value = np.array(value) - if isinstance(value, EnumArray): - return assert_enum_equals(value, target_value, message) - if np.issubdtype(value.dtype, np.datetime64): - target_value = np.array(target_value, dtype = value.dtype) - assert_datetime_equals(value, target_value, message) - if isinstance(target_value, str): - target_value = eval_expression(target_value) - - target_value = np.array(target_value).astype(np.float32) - - value = np.array(value).astype(np.float32) - diff = abs(target_value - value) - if absolute_error_margin is not None: - assert (diff <= absolute_error_margin).all(), \ - '{}{} differs from {} with an absolute margin {} > {}'.format(message, value, target_value, - diff, absolute_error_margin) - if relative_error_margin is not None: - assert (diff <= abs(relative_error_margin * target_value)).all(), \ - '{}{} differs from {} with a relative margin {} > {}'.format(message, value, target_value, - diff, abs(relative_error_margin * target_value)) - - -def assert_datetime_equals(value, target_value, message = ''): - assert (value == target_value).all(), '{}{} differs from {}.'.format(message, value, target_value) - - -def assert_enum_equals(value, target_value, message = ''): - value = value.decode_to_str() - assert (value == target_value).all(), '{}{} differs from {}.'.format(message, value, target_value) - - -def indent(text): - return " {}".format(text.replace(os.linesep, "{} ".format(os.linesep))) - - -def get_trace_tool_link(scenario, variables, api_url, trace_tool_url): - import json - import urllib - - scenario_json = scenario.to_json() - simulation_json = { - 'scenarios': [scenario_json], - 'variables': variables, - } - url = trace_tool_url + '?' 
+ urllib.urlencode({ - 'simulation': json.dumps(simulation_json), - 'api_url': api_url, - }) - return url - - -def eval_expression(expression): - try: - return numexpr.evaluate(expression) - except (KeyError, TypeError): - return expression +"""bla""" + +from ._asserts import ( # noqa: F401 + assert_datetime_equals, + assert_enum_equals, + assert_near, + ) + +from ._misc import ( # noqa: F401 + eval_expression, + get_trace_tool_link, + indent, + ) diff --git a/openfisca_core/tools/_asserts.py b/openfisca_core/tools/_asserts.py new file mode 100644 index 0000000000..2f0cff79a7 --- /dev/null +++ b/openfisca_core/tools/_asserts.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from typing import Any, Optional +from openfisca_core.typing import ArrayType, ArrayLike + +import datetime + +import numpy + +from openfisca_core.indexed_enums import EnumArray + +from . import _misc + + +def assert_near( + value: ArrayType, + target_value: Any, + absolute_error_margin: Optional[float] = None, + message: str = '', + relative_error_margin: Optional[float] = None, + ) -> None: + ''' + + :param value: Value returned by the test + :param target_value: Value that the test should return to pass + :param absolute_error_margin: Absolute error margin authorized + :param message: Error message to be displayed if the test fails + :param relative_error_margin: Relative error margin authorized + + Limit : This function cannot be used to assert near periods. 
+ + ''' + + if absolute_error_margin is None and relative_error_margin is None: + absolute_error_margin = 0 + if not isinstance(value, numpy.ndarray): + value = numpy.array(value) + if isinstance(value, EnumArray): + return assert_enum_equals(value, target_value, message) + if numpy.issubdtype(value.dtype, numpy.datetime64): + target_value = numpy.array(target_value, dtype = value.dtype) + assert_datetime_equals(value, target_value, message) + if isinstance(target_value, str): + target_value = _misc.eval_expression(target_value) + + target_value = numpy.array(target_value).astype(numpy.float32) + + value = numpy.array(value).astype(numpy.float32) + diff = abs(target_value - value) + if absolute_error_margin is not None: + assert (diff <= absolute_error_margin).all(), \ + '{}{} differs from {} with an absolute margin {} > {}'.format(message, value, target_value, + diff, absolute_error_margin) + if relative_error_margin is not None: + assert (diff <= abs(relative_error_margin * target_value)).all(), \ + '{}{} differs from {} with a relative margin {} > {}'.format(message, value, target_value, + diff, abs(relative_error_margin * target_value)) + + +def assert_datetime_equals( + value: ArrayType[datetime.date], + target_value: ArrayLike[datetime.date], + message: str = '', + ) -> None: + + assert (value == target_value).all(), '{}{} differs from {}.'.format(message, value, target_value) + + +def assert_enum_equals( + value: EnumArray, + target_value: str, + message: str = '', + ) -> None: + + value_ = value.decode_to_str() + + assert (value_ == target_value).all(), '{}{} differs from {}.'.format(message, value, target_value) diff --git a/openfisca_core/tools/_misc.py b/openfisca_core/tools/_misc.py new file mode 100644 index 0000000000..663aac691b --- /dev/null +++ b/openfisca_core/tools/_misc.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any, Dict, Sequence, cast +from openfisca_core.typing import TaxBenefitSystemProtocol + +import json 
+import os +from urllib import parse + +import numexpr + +_tax_benefit_system_cache: Dict[int, TaxBenefitSystemProtocol] = {} + + +def indent(text: str) -> str: + return " {}".format(text.replace(os.linesep, "{} ".format(os.linesep))) + + +def eval_expression(expression: str) -> Any: + try: + return numexpr.evaluate(expression) + except (KeyError, TypeError): + return expression + + +def get_trace_tool_link( + scenario: Any, + variables: Any, + api_url: str, + trace_tool_url: str, + ) -> str: + + scenario_json = scenario.to_json() + simulation_json = { + 'scenarios': [scenario_json], + 'variables': variables, + } + url = trace_tool_url + '?' + parse.urlencode({ + 'simulation': json.dumps(simulation_json), + 'api_url': api_url, + }) + + return url + + +def _get_tax_benefit_system( + baseline: TaxBenefitSystemProtocol, + reforms: Sequence[str], + extensions: Sequence[str], + ) -> TaxBenefitSystemProtocol: + + if not isinstance(reforms, list): + reforms = cast(Sequence[str], [reforms]) + if not isinstance(extensions, list): + extensions = cast(Sequence[str], [extensions]) + + # keep reforms order in cache, ignore extensions order + key = hash((id(baseline), ':'.join(reforms), frozenset(extensions))) + if key in _tax_benefit_system_cache: + return _tax_benefit_system_cache[key] + + current_tax_benefit_system = baseline + + for reform_path in reforms: + current_tax_benefit_system = current_tax_benefit_system.apply_reform(reform_path) + + for extension in extensions: + current_tax_benefit_system = current_tax_benefit_system.clone() + current_tax_benefit_system.load_extension(extension) + + _tax_benefit_system_cache[key] = current_tax_benefit_system + + return current_tax_benefit_system diff --git a/openfisca_core/tools/_yaml.py b/openfisca_core/tools/_yaml.py new file mode 100644 index 0000000000..c3652ce4fc --- /dev/null +++ b/openfisca_core/tools/_yaml.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import os +import warnings +import yaml # noqa: F401 + 
+from openfisca_core.warnings import LibYAMLWarning + +try: + from yaml import CLoader as Loader + +except ImportError: + message = [ + "libyaml is not installed in your environment.", + "This can make OpenFisca slower to start,", + "and your test suite slower to run.", + "Once you have installed libyaml,", + "run 'pip uninstall pyyaml && pip install pyyaml --no-cache-dir'", + "so that it is used by your Python environment.", + os.linesep, + ] + warnings.warn(" ".join(message), LibYAMLWarning) + + # see https://github.com/python/mypy/issues/1153#issuecomment-455802270 + from yaml import SafeLoader as Loader # type: ignore # noqa: F401 diff --git a/openfisca_core/tools/_yaml_file.py b/openfisca_core/tools/_yaml_file.py new file mode 100644 index 0000000000..68fcdde51b --- /dev/null +++ b/openfisca_core/tools/_yaml_file.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import Sequence, Generator, Optional, cast +from openfisca_core.typing import ( + TaxBenefitSystemProtocol, + OptionsSchema, + TestSchema, + ) + +import os +import traceback + +from _pytest.python import Package +from py._path.local import LocalPath +from pytest import File + +from ._yaml import yaml, Loader +from ._yaml_item import YamlItem + + +class YamlFile(File): + + def __init__( + self, + path: LocalPath, + fspath: LocalPath, + parent: Package, + tax_benefit_system: TaxBenefitSystemProtocol, + options: OptionsSchema, + ) -> None: + + super(YamlFile, self).__init__(path, parent) + self.tax_benefit_system = tax_benefit_system + self.options = options + + def collect(self) -> Generator[YamlItem, None, None]: + tests: Sequence[TestSchema] + + try: + tests = yaml.load(self.fspath.open(), Loader = Loader) + except (yaml.scanner.ScannerError, yaml.parser.ParserError, TypeError): + message = os.linesep.join([ + traceback.format_exc(), + f"'{self.fspath}' is not a valid YAML file. 
Check the stack trace above for more details.", + ]) + raise ValueError(message) + + if not isinstance(tests, list): + tests = cast(Sequence[TestSchema], [tests]) + + for test in tests: + if not self.should_ignore(test): + yield YamlItem.from_parent(self, + name = '', + baseline_tax_benefit_system = self.tax_benefit_system, + test = test, options = self.options) + + def should_ignore(self, test: TestSchema) -> bool: + name_filter: Optional[str] = self.options.get('name_filter') + stem: str = os.path.splitext(self.fspath.basename)[0] + name: str = test.get('name', '') + kwds: Sequence[str] = test.get('keywords', []) + + return ( + name_filter is not None + and name_filter not in stem + and name_filter not in name + and name_filter not in kwds + ) diff --git a/openfisca_core/tools/_yaml_item.py b/openfisca_core/tools/_yaml_item.py new file mode 100644 index 0000000000..76983f5760 --- /dev/null +++ b/openfisca_core/tools/_yaml_item.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +from typing import Any, Mapping, Optional, Sequence, Set, Union +from openfisca_core.typing import ( + TaxBenefitSystemProtocol, + OptionsSchema, + TestSchema, + ) + +import os +import sys +import textwrap + +from _pytest._code import code +from pytest import File, Item + +from openfisca_core.errors import SituationParsingError, VariableNotFoundError +from openfisca_core.simulations import Simulation, SimulationBuilder +from openfisca_core.tracers import FullTracer + +from . import _asserts +from . import _misc + +TEST_KEYWORDS: Set[str] +TEST_KEYWORDS = { + 'absolute_error_margin', + 'description', + 'extensions', + 'ignore_variables', + 'input', + 'keywords', + 'max_spiral_loops', + 'name', + 'only_variables', + 'output', + 'period', + 'reforms', + 'relative_error_margin', + } + + +class YamlItem(Item): + """ + Terminal nodes of the test collection tree. 
+ """ + + baseline_tax_benefit_system: TaxBenefitSystemProtocol + name: str = "" + options: OptionsSchema + simulation: Simulation + tax_benefit_system: TaxBenefitSystemProtocol + test: TestSchema + + def __init__( + self, + name: str, + parent: File, + baseline_tax_benefit_system: TaxBenefitSystemProtocol, + test: TestSchema, + options: OptionsSchema, + ) -> None: + + super().__init__(name, parent) + self.baseline_tax_benefit_system = baseline_tax_benefit_system + self.options = options + self.test = test + + def runtest(self) -> None: + builder: SimulationBuilder + extensions: Sequence[str] = [] + input: Mapping[str, Any] = {} + max_spiral_loops: Optional[int] = None + performance_graph: bool = False + performance_tables: bool = False + period: Optional[str] = None + reforms: Sequence[str] = [] + unexpected_keys: Set[str] + verbose: bool = False + + if "name" in self.test: + self.name = self.test["name"] + + if "output" not in self.test: + raise ValueError("Missing key 'output' in test '{}' in file '{}'".format(self.name, self.fspath)) + + if not TEST_KEYWORDS.issuperset(self.test.keys()): + unexpected_keys = set(self.test.keys()).difference(TEST_KEYWORDS) + raise ValueError("Unexpected keys {} in test '{}' in file '{}'".format(unexpected_keys, self.name, self.fspath)) + + if "reforms" in self.test: + reforms = self.test["reforms"] + + if "extensions" in self.test: + extensions = self.test["extensions"] + + self.tax_benefit_system = _misc._get_tax_benefit_system(self.baseline_tax_benefit_system, reforms, extensions) + + builder = SimulationBuilder() + + if "input" in self.test: + input = self.test["input"] + + if "period" in self.test: + period = self.test["period"] + + if "max_spiral_loops" in self.test: + max_spiral_loops = self.test["max_spiral_loops"] + + if "verbose" in self.options: + verbose = self.options["verbose"] + + if "performance_graph" in self.options: + performance_graph = self.options["performance_graph"] + + if "performance_tables" in 
self.options: + performance_tables = self.options["performance_tables"] + + try: + builder.set_default_period(period) + self.simulation = builder.build_from_dict(self.tax_benefit_system, input) + except (VariableNotFoundError, SituationParsingError): + raise + except Exception as e: + error_message = os.linesep.join([str(e), '', f"Unexpected error raised while parsing '{self.fspath}'"]) + raise ValueError(error_message).with_traceback(sys.exc_info()[2]) from e # Keep the stack trace from the root error + + if max_spiral_loops: + self.simulation.max_spiral_loops = max_spiral_loops + + try: + self.simulation.trace = verbose or performance_graph or performance_tables + self.check_output() + finally: + tracer = self.simulation.tracer + if verbose: + assert isinstance(tracer, FullTracer) + self.print_computation_log(tracer) + if performance_graph: + assert isinstance(tracer, FullTracer) + self.generate_performance_graph(tracer) + if performance_tables: + assert isinstance(tracer, FullTracer) + self.generate_performance_tables(tracer) + + def print_computation_log(self, tracer: FullTracer) -> None: + print("Computation log:") # noqa T001 + tracer.print_computation_log() + + def generate_performance_graph(self, tracer: FullTracer) -> None: + tracer.generate_performance_graph('.') + + def generate_performance_tables(self, tracer: FullTracer) -> None: + tracer.generate_performance_tables('.') + + def check_output(self) -> None: + output = self.test.get('output') + + if output is None: + return + for key, expected_value in output.items(): + if self.tax_benefit_system.get_variable(key): # If key is a variable + self.check_variable(key, expected_value, self.test.get('period')) + elif self.simulation.populations.get(key): # If key is an entity singular + for variable_name, value in expected_value.items(): + self.check_variable(variable_name, value, self.test.get('period')) + else: + population = self.simulation.get_population(plural = key) + if population is not None: # If key 
is an entity plural + for instance_id, instance_values in expected_value.items(): + for variable_name, value in instance_values.items(): + entity_index = population.get_index(instance_id) + self.check_variable(variable_name, value, self.test.get('period'), entity_index) + else: + raise VariableNotFoundError(key, self.tax_benefit_system) + + def check_variable( + self, + variable_name: str, + expected_value: Mapping[str, Any], + period: Optional[str], + entity_index: Optional[int] = None, + ) -> None: + + if self.should_ignore_variable(variable_name): + return + if isinstance(expected_value, dict): + for requested_period, expected_value_at_period in expected_value.items(): + self.check_variable(variable_name, expected_value_at_period, requested_period, entity_index) + return + + actual_value = self.simulation.calculate(variable_name, period) + + if entity_index is not None: + actual_value = actual_value[entity_index] + return _asserts.assert_near( + actual_value, + expected_value, + absolute_error_margin = self.test.get('absolute_error_margin'), + message = f"{variable_name}@{period}: ", + relative_error_margin = self.test.get('relative_error_margin'), + ) + + def should_ignore_variable(self, variable_name: str) -> bool: + only_variables = self.options.get('only_variables') + ignore_variables = self.options.get('ignore_variables') + variable_ignored = ignore_variables is not None and variable_name in ignore_variables + variable_not_tested = only_variables is not None and variable_name not in only_variables + + return variable_ignored or variable_not_tested + + def repr_failure( + self, + excinfo: code.ExceptionInfo[BaseException], + style: Optional[code._TracebackStyle] = None, + ) -> Union[str, code.TerminalRepr]: + + if not isinstance(excinfo.value, (AssertionError, VariableNotFoundError, SituationParsingError)): + return super(YamlItem, self).repr_failure(excinfo) + + message = excinfo.value.args[0] + if isinstance(excinfo.value, SituationParsingError): + message 
= f"Could not parse situation described: {message}" + + return os.linesep.join([ + f"{str(self.fspath)}:", + f" Test '{str(self.name)}':", + textwrap.indent(message, ' ') + ]) diff --git a/openfisca_core/tools/_yaml_plugin.py b/openfisca_core/tools/_yaml_plugin.py new file mode 100644 index 0000000000..33c22c4de7 --- /dev/null +++ b/openfisca_core/tools/_yaml_plugin.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import Optional +from openfisca_core.typing import TaxBenefitSystemProtocol, OptionsSchema + +from _pytest.main import Session +from py._path.local import LocalPath + +from ._yaml_file import YamlFile + + +class YamlPlugin: + + def __init__( + self, + tax_benefit_system: TaxBenefitSystemProtocol, + options: OptionsSchema, + ) -> None: + + self.tax_benefit_system = tax_benefit_system + self.options = options + + def pytest_collect_file( + self, + parent: Session, + path: LocalPath, + ) -> Optional[YamlFile]: + """ + Called by pytest for all plugins. + :return: The collector for test methods. + """ + + if path.ext in [".yaml", ".yml"]: + return YamlFile.from_parent( # type: ignore + parent, + path = path, + fspath = path, + tax_benefit_system = self.tax_benefit_system, + options = self.options, + ) + + return None diff --git a/openfisca_core/tools/simulation_dumper.py b/openfisca_core/tools/simulation_dumper.py index 4b5907c0ff..4546ced1e4 100644 --- a/openfisca_core/tools/simulation_dumper.py +++ b/openfisca_core/tools/simulation_dumper.py @@ -1,5 +1,11 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations +from typing import Any, Optional +from openfisca_core.typing import ( + HolderProtocol, + PopulationProtocol, + TaxBenefitSystemProtocol, + ) import os @@ -10,10 +16,14 @@ from openfisca_core.periods import ETERNITY -def dump_simulation(simulation, directory): +def dump_simulation( + simulation: Simulation, + directory: str + ) -> None: """ Write simulation data to directory, so that it can be restored later. 
""" + parent_directory = os.path.abspath(os.path.join(directory, os.pardir)) if not os.path.isdir(parent_directory): # To deal with reforms os.mkdir(parent_directory) @@ -35,10 +45,15 @@ def dump_simulation(simulation, directory): _dump_holder(holder, directory) -def restore_simulation(directory, tax_benefit_system, **kwargs): +def restore_simulation( + directory: str, + tax_benefit_system: TaxBenefitSystemProtocol, + **kwargs: Any, + ) -> Simulation: """ Restore simulation from directory """ + simulation = Simulation(tax_benefit_system, tax_benefit_system.instantiate_entities()) entities_dump_dir = os.path.join(directory, "__entities__") @@ -60,14 +75,14 @@ def restore_simulation(directory, tax_benefit_system, **kwargs): return simulation -def _dump_holder(holder, directory): +def _dump_holder(holder: HolderProtocol, directory: str) -> None: disk_storage = holder.create_disk_storage(directory, preserve = True) for period in holder.get_known_periods(): value = holder.get_array(period) disk_storage.put(value, period) -def _dump_entity(population, directory): +def _dump_entity(population: PopulationProtocol, directory: str) -> None: path = os.path.join(directory, population.entity.key) os.mkdir(path) np.save(os.path.join(path, "id.npy"), population.ids) @@ -89,13 +104,17 @@ def _dump_entity(population, directory): np.save(os.path.join(path, "members_role.npy"), encoded_roles) -def _restore_entity(population, directory): +def _restore_entity( + population: PopulationProtocol, + directory: str, + ) -> Optional[int]: + path = os.path.join(directory, population.entity.key) population.ids = np.load(os.path.join(path, "id.npy")) if population.entity.is_person: - return + return None population.members_position = np.load(os.path.join(path, "members_position.npy")) population.members_entity_id = np.load(os.path.join(path, "members_entity_id.npy")) @@ -114,9 +133,19 @@ def _restore_entity(population, directory): return person_count -def _restore_holder(simulation, variable, 
directory): +def _restore_holder( + simulation: Simulation, + variable: str, + directory: str, + ) -> None: + + variable_ = simulation.tax_benefit_system.get_variable(variable) + + if variable_ is None: + return None + storage_dir = os.path.join(directory, variable) - is_variable_eternal = simulation.tax_benefit_system.get_variable(variable).definition_period == ETERNITY + is_variable_eternal = variable_.definition_period == ETERNITY disk_storage = OnDiskStorage( storage_dir, is_eternal = is_variable_eternal, diff --git a/openfisca_core/tools/test_runner.py b/openfisca_core/tools/test_runner.py index 1c37ea1469..10f99fc51a 100644 --- a/openfisca_core/tools/test_runner.py +++ b/openfisca_core/tools/test_runner.py @@ -1,44 +1,43 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations -import warnings -import sys -import os -import traceback -import textwrap -from typing import Dict, List +from typing import Optional, Sequence, Union +from openfisca_core.typing import TaxBenefitSystemProtocol, OptionsSchema import pytest -from openfisca_core.tools import assert_near -from openfisca_core.simulation_builder import SimulationBuilder -from openfisca_core.errors import SituationParsingError, VariableNotFound -from openfisca_core.warnings import LibYAMLWarning +# For backwards compatibility. +from openfisca_core.simulation_builder import SimulationBuilder # noqa: F401 +# For backwards compatibility. +from openfisca_core.errors import SituationParsingError, VariableNotFound # noqa: F401 -def import_yaml(): - import yaml - try: - from yaml import CLoader as Loader - except ImportError: - message = [ - "libyaml is not installed in your environment.", - "This can make your test suite slower to run. Once you have installed libyaml, ", - "run 'pip uninstall pyyaml && pip install pyyaml --no-cache-dir'", - "so that it is used in your Python environment." 
- ] - warnings.warn(" ".join(message), LibYAMLWarning) - from yaml import SafeLoader as Loader - return yaml, Loader +# For backwards compatibility. +from openfisca_core.warnings import LibYAMLWarning # noqa: F401 +# For backwards compatibility. +from ._asserts import assert_near # noqa: F401 -TEST_KEYWORDS = {'absolute_error_margin', 'description', 'extensions', 'ignore_variables', 'input', 'keywords', 'max_spiral_loops', 'name', 'only_variables', 'output', 'period', 'reforms', 'relative_error_margin'} +# For backwards compatibility. +from ._misc import _get_tax_benefit_system # noqa: F401 -yaml, Loader = import_yaml() +# For backwards compatibility. +from ._yaml import yaml, Loader # noqa: F401 -_tax_benefit_system_cache: Dict = {} +# For backwards compatibility. +from ._yaml_file import YamlFile # noqa: F401 +# For backwards compatibility. +from ._yaml_item import YamlItem # noqa: F401 -def run_tests(tax_benefit_system, paths, options = None): +# For backwards compatibility. +from ._yaml_plugin import YamlPlugin as OpenFiscaPlugin + + +def run_tests( + tax_benefit_system: TaxBenefitSystemProtocol, + paths: Sequence[str], + options: Optional[OptionsSchema] = None, + ) -> Union[int, pytest.ExitCode]: """ Runs all the YAML tests contained in a file or a directory. 
@@ -66,222 +65,16 @@ def run_tests(tax_benefit_system, paths, options = None): argv = [] - if options.get('pdb'): + if options is None: + options = {} + + if "pdb" in options and options["pdb"]: argv.append('--pdb') - if options.get('verbose'): + if "verbose" in options and options["verbose"]: argv.append('--verbose') if isinstance(paths, str): paths = [paths] return pytest.main([*argv, *paths] if True else paths, plugins = [OpenFiscaPlugin(tax_benefit_system, options)]) - - -class YamlFile(pytest.File): - - def __init__(self, path, fspath, parent, tax_benefit_system, options): - super(YamlFile, self).__init__(path, parent) - self.tax_benefit_system = tax_benefit_system - self.options = options - - def collect(self): - try: - tests = yaml.load(self.fspath.open(), Loader = Loader) - except (yaml.scanner.ScannerError, yaml.parser.ParserError, TypeError): - message = os.linesep.join([ - traceback.format_exc(), - f"'{self.fspath}' is not a valid YAML file. Check the stack trace above for more details.", - ]) - raise ValueError(message) - - if not isinstance(tests, list): - tests: List[Dict] = [tests] - - for test in tests: - if not self.should_ignore(test): - yield YamlItem.from_parent(self, - name = '', - baseline_tax_benefit_system = self.tax_benefit_system, - test = test, options = self.options) - - def should_ignore(self, test): - name_filter = self.options.get('name_filter') - return ( - name_filter is not None - and name_filter not in os.path.splitext(self.fspath.basename)[0] - and name_filter not in test.get('name', '') - and name_filter not in test.get('keywords', []) - ) - - -class YamlItem(pytest.Item): - """ - Terminal nodes of the test collection tree. 
- """ - - def __init__(self, name, parent, baseline_tax_benefit_system, test, options): - super(YamlItem, self).__init__(name, parent) - self.baseline_tax_benefit_system = baseline_tax_benefit_system - self.options = options - self.test = test - self.simulation = None - self.tax_benefit_system = None - - def runtest(self): - self.name = self.test.get('name', '') - if not self.test.get('output'): - raise ValueError("Missing key 'output' in test '{}' in file '{}'".format(self.name, self.fspath)) - - if not TEST_KEYWORDS.issuperset(self.test.keys()): - unexpected_keys = set(self.test.keys()).difference(TEST_KEYWORDS) - raise ValueError("Unexpected keys {} in test '{}' in file '{}'".format(unexpected_keys, self.name, self.fspath)) - - self.tax_benefit_system = _get_tax_benefit_system(self.baseline_tax_benefit_system, self.test.get('reforms', []), self.test.get('extensions', [])) - - builder = SimulationBuilder() - input = self.test.get('input', {}) - period = self.test.get('period') - max_spiral_loops = self.test.get('max_spiral_loops') - verbose = self.options.get('verbose') - performance_graph = self.options.get('performance_graph') - performance_tables = self.options.get('performance_tables') - - try: - builder.set_default_period(period) - self.simulation = builder.build_from_dict(self.tax_benefit_system, input) - except (VariableNotFound, SituationParsingError): - raise - except Exception as e: - error_message = os.linesep.join([str(e), '', f"Unexpected error raised while parsing '{self.fspath}'"]) - raise ValueError(error_message).with_traceback(sys.exc_info()[2]) from e # Keep the stack trace from the root error - - if max_spiral_loops: - self.simulation.max_spiral_loops = max_spiral_loops - - try: - self.simulation.trace = verbose or performance_graph or performance_tables - self.check_output() - finally: - tracer = self.simulation.tracer - if verbose: - self.print_computation_log(tracer) - if performance_graph: - self.generate_performance_graph(tracer) - if 
performance_tables: - self.generate_performance_tables(tracer) - - def print_computation_log(self, tracer): - print("Computation log:") # noqa T001 - tracer.print_computation_log() - - def generate_performance_graph(self, tracer): - tracer.generate_performance_graph('.') - - def generate_performance_tables(self, tracer): - tracer.generate_performance_tables('.') - - def check_output(self): - output = self.test.get('output') - - if output is None: - return - for key, expected_value in output.items(): - if self.tax_benefit_system.get_variable(key): # If key is a variable - self.check_variable(key, expected_value, self.test.get('period')) - elif self.simulation.populations.get(key): # If key is an entity singular - for variable_name, value in expected_value.items(): - self.check_variable(variable_name, value, self.test.get('period')) - else: - population = self.simulation.get_population(plural = key) - if population is not None: # If key is an entity plural - for instance_id, instance_values in expected_value.items(): - for variable_name, value in instance_values.items(): - entity_index = population.get_index(instance_id) - self.check_variable(variable_name, value, self.test.get('period'), entity_index) - else: - raise VariableNotFound(key, self.tax_benefit_system) - - def check_variable(self, variable_name, expected_value, period, entity_index = None): - if self.should_ignore_variable(variable_name): - return - if isinstance(expected_value, dict): - for requested_period, expected_value_at_period in expected_value.items(): - self.check_variable(variable_name, expected_value_at_period, requested_period, entity_index) - return - - actual_value = self.simulation.calculate(variable_name, period) - - if entity_index is not None: - actual_value = actual_value[entity_index] - return assert_near( - actual_value, - expected_value, - absolute_error_margin = self.test.get('absolute_error_margin'), - message = f"{variable_name}@{period}: ", - relative_error_margin = 
self.test.get('relative_error_margin'), - ) - - def should_ignore_variable(self, variable_name): - only_variables = self.options.get('only_variables') - ignore_variables = self.options.get('ignore_variables') - variable_ignored = ignore_variables is not None and variable_name in ignore_variables - variable_not_tested = only_variables is not None and variable_name not in only_variables - - return variable_ignored or variable_not_tested - - def repr_failure(self, excinfo): - if not isinstance(excinfo.value, (AssertionError, VariableNotFound, SituationParsingError)): - return super(YamlItem, self).repr_failure(excinfo) - - message = excinfo.value.args[0] - if isinstance(excinfo.value, SituationParsingError): - message = f"Could not parse situation described: {message}" - - return os.linesep.join([ - f"{str(self.fspath)}:", - f" Test '{str(self.name)}':", - textwrap.indent(message, ' ') - ]) - - -class OpenFiscaPlugin(object): - - def __init__(self, tax_benefit_system, options): - self.tax_benefit_system = tax_benefit_system - self.options = options - - def pytest_collect_file(self, parent, path): - """ - Called by pytest for all plugins. - :return: The collector for test methods. 
- """ - if path.ext in [".yaml", ".yml"]: - return YamlFile.from_parent(parent, path = path, fspath = path, - tax_benefit_system = self.tax_benefit_system, - options = self.options) - - -def _get_tax_benefit_system(baseline, reforms, extensions): - if not isinstance(reforms, list): - reforms = [reforms] - if not isinstance(extensions, list): - extensions = [extensions] - - # keep reforms order in cache, ignore extensions order - key = hash((id(baseline), ':'.join(reforms), frozenset(extensions))) - if _tax_benefit_system_cache.get(key): - return _tax_benefit_system_cache.get(key) - - current_tax_benefit_system = baseline - - for reform_path in reforms: - current_tax_benefit_system = current_tax_benefit_system.apply_reform(reform_path) - - for extension in extensions: - current_tax_benefit_system = current_tax_benefit_system.clone() - current_tax_benefit_system.load_extension(extension) - - _tax_benefit_system_cache[key] = current_tax_benefit_system - - return current_tax_benefit_system diff --git a/openfisca_core/tracers/full_tracer.py b/openfisca_core/tracers/full_tracer.py index 3fa46de5ab..d4a67f4379 100644 --- a/openfisca_core/tracers/full_tracer.py +++ b/openfisca_core/tracers/full_tracer.py @@ -1,23 +1,18 @@ from __future__ import annotations import time -import typing -from typing import Dict, Iterator, List, Optional, Union +from typing import Iterator, Optional, Sequence +from openfisca_core.typing import ArrayLike, FrameSchema, PeriodProtocol from .. 
import tracers -if typing.TYPE_CHECKING: - from numpy.typing import ArrayLike - - from openfisca_core.periods import Period - - Stack = List[Dict[str, Union[str, Period]]] +Stack = Sequence[FrameSchema] class FullTracer: _simple_tracer: tracers.SimpleTracer - _trees: list + _trees: Sequence[tracers.TraceNode] _current_node: Optional[tracers.TraceNode] def __init__(self) -> None: @@ -28,8 +23,9 @@ def __init__(self) -> None: def record_calculation_start( self, variable: str, - period: Period, + period: PeriodProtocol, ) -> None: + self._simple_tracer.record_calculation_start(variable, period) self._enter_calculation(variable, period) self._record_start_time() @@ -37,8 +33,9 @@ def record_calculation_start( def _enter_calculation( self, variable: str, - period: Period, + period: PeriodProtocol, ) -> None: + new_node = tracers.TraceNode( name = variable, period = period, @@ -46,7 +43,7 @@ def _enter_calculation( ) if self._current_node is None: - self._trees.append(new_node) + self._trees = [*self.trees, new_node] else: self._current_node.append_child(new_node) @@ -56,15 +53,21 @@ def _enter_calculation( def record_parameter_access( self, parameter: str, - period: Period, + period: PeriodProtocol, value: ArrayLike, ) -> None: if self._current_node is not None: - self._current_node.parameters.append( - tracers.TraceNode(name = parameter, period = period, value = value), + new_node = tracers.TraceNode( + name = parameter, + period = period, + value = value, ) + parameters = self._current_node.parameters + + self._current_node.parameters = [*parameters, new_node] + def _record_start_time( self, time_in_s: Optional[float] = None, @@ -103,7 +106,7 @@ def stack(self) -> Stack: return self._simple_tracer.stack @property - def trees(self) -> List[tracers.TraceNode]: + def trees(self) -> Sequence[tracers.TraceNode]: return self._trees @property @@ -121,7 +124,7 @@ def flat_trace(self) -> tracers.FlatTrace: def _get_time_in_sec(self) -> float: return time.time_ns() / (10**9) - 
def print_computation_log(self, aggregate = False): + def print_computation_log(self, aggregate: bool = False) -> None: self.computation_log.print_log(aggregate) def generate_performance_graph(self, dir_path: str) -> None: diff --git a/openfisca_core/tracers/simple_tracer.py b/openfisca_core/tracers/simple_tracer.py index 2fa98c6582..26e9559e0a 100644 --- a/openfisca_core/tracers/simple_tracer.py +++ b/openfisca_core/tracers/simple_tracer.py @@ -1,14 +1,9 @@ from __future__ import annotations -import typing -from typing import Dict, List, Union +from typing import Sequence +from openfisca_core.typing import ArrayLike, FrameSchema, PeriodProtocol -if typing.TYPE_CHECKING: - from numpy.typing import ArrayLike - - from openfisca_core.periods import Period - - Stack = List[Dict[str, Union[str, Period]]] +Stack = Sequence[FrameSchema] class SimpleTracer: @@ -18,8 +13,15 @@ class SimpleTracer: def __init__(self) -> None: self._stack = [] - def record_calculation_start(self, variable: str, period: Period) -> None: - self.stack.append({'name': variable, 'period': period}) + def record_calculation_start( + self, + variable: str, + period: PeriodProtocol, + ) -> None: + + frame: FrameSchema + frame = {'name': variable, 'period': period} + self.stack = [*self.stack, frame] def record_calculation_result(self, value: ArrayLike) -> None: pass # ignore calculation result @@ -28,8 +30,12 @@ def record_parameter_access(self, parameter: str, period, value): pass def record_calculation_end(self) -> None: - self.stack.pop() + self.stack = self.stack[:-1] @property def stack(self) -> Stack: return self._stack + + @stack.setter + def stack(self, value: Stack) -> None: + self._stack = value diff --git a/openfisca_core/tracers/trace_node.py b/openfisca_core/tracers/trace_node.py index 93b630886c..d4ffb2b016 100644 --- a/openfisca_core/tracers/trace_node.py +++ b/openfisca_core/tracers/trace_node.py @@ -1,30 +1,27 @@ from __future__ import annotations -import dataclasses -import typing 
+from typing import Union, Optional, Sequence +from openfisca_core.typing import ArrayLike, PeriodProtocol -if typing.TYPE_CHECKING: - import numpy +import dataclasses - from openfisca_core.indexed_enums import EnumArray - from openfisca_core.periods import Period +from openfisca_core.indexed_enums import EnumArray - Array = typing.Union[EnumArray, numpy.typing.ArrayLike] - Time = typing.Union[float, int] +Array = Union[EnumArray, ArrayLike] @dataclasses.dataclass class TraceNode: name: str - period: Period - parent: typing.Optional[TraceNode] = None - children: typing.List[TraceNode] = dataclasses.field(default_factory = list) - parameters: typing.List[TraceNode] = dataclasses.field(default_factory = list) - value: typing.Optional[Array] = None + period: PeriodProtocol + parent: Optional[TraceNode] = None + children: Sequence[TraceNode] = dataclasses.field(default_factory = list) + parameters: Sequence[TraceNode] = dataclasses.field(default_factory = list) + value: Optional[Array] = None start: float = 0 end: float = 0 - def calculation_time(self, round_: bool = True) -> Time: + def calculation_time(self, round_: bool = True) -> float: result = self.end - self.start if round_: @@ -47,8 +44,8 @@ def formula_time(self) -> float: return self.round(result) def append_child(self, node: TraceNode) -> None: - self.children.append(node) + self.children = [*self.children, node] @staticmethod - def round(time: Time) -> float: + def round(time: float) -> float: return float(f'{time:.4g}') # Keep only 4 significant figures diff --git a/openfisca_core/types/__init__.py b/openfisca_core/types/__init__.py deleted file mode 100644 index e14cfea65d..0000000000 --- a/openfisca_core/types/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Data types and protocols used by OpenFisca Core. - -The type definitions included in this sub-package are intented for -contributors, to help them better understand and document contracts -and expected behaviours. 
- -Official Public API: - * ``ArrayLike`` - * :attr:`.ArrayType` - -Note: - How imports are being used today:: - - from openfisca_core.types import * # Bad - from openfisca_core.types.data_types.arrays import ArrayLike # Bad - - - The previous examples provoke cyclic dependency problems, that prevents us - from modularizing the different components of the library, so as to make - them easier to test and to maintain. - - How could them be used after the next major release:: - - from openfisca_core.types import ArrayLike - - ArrayLike # Good: import types as publicly exposed - - .. seealso:: `PEP8#Imports`_ and `OpenFisca's Styleguide`_. - - .. _PEP8#Imports: - https://www.python.org/dev/peps/pep-0008/#imports - - .. _OpenFisca's Styleguide: - https://github.com/openfisca/openfisca-core/blob/master/STYLEGUIDE.md - -""" - -# Official Public API - -from .data_types import ( # noqa: F401 - ArrayLike, - ArrayType, - ) - -__all__ = ["ArrayLike", "ArrayType"] diff --git a/openfisca_core/types/data_types/__init__.py b/openfisca_core/types/data_types/__init__.py deleted file mode 100644 index 6dd38194e3..0000000000 --- a/openfisca_core/types/data_types/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .arrays import ArrayLike, ArrayType # noqa: F401 diff --git a/openfisca_core/typing/__init__.py b/openfisca_core/typing/__init__.py new file mode 100644 index 0000000000..648a96ffc9 --- /dev/null +++ b/openfisca_core/typing/__init__.py @@ -0,0 +1,82 @@ +"""Data types and protocols used by OpenFisca Core. + +The type definitions included in this sub-package are intended for +contributors, to help them better understand and document contracts +and expected behaviours. 
+ +Official Public API: + * :data:`.ArrayLike` + * :attr:`.ArrayType` + * :class:`.FormulaProtocol` + * :class:`.HolderProtocol` + * :class:`.PeriodProtocol` + * :class:`.PopulationProtocol` + * :class:`.TaxBenefitSystemProtocol` + * :class:`.AxisSchema` + * :class:`.FrameSchema` + * :class:`.OptionsSchema` + * :class:`.TestSchema` + + +Note: + How imports are being used today:: + + from openfisca_core.typing import * # Bad + from openfisca_core.typing.data_types.arrays import ArrayLike # Bad + + + The previous examples provoke cyclic dependency problems, that prevent us + from modularizing the different components of the library, so as to make + them easier to test and to maintain. + + How they could be used after the next major release:: + + from openfisca_core.typing import ArrayLike + + ArrayLike # Good: import types as publicly exposed + + .. seealso:: `PEP8#Imports`_ and `OpenFisca's Styleguide`_. + + .. _PEP8#Imports: + https://www.python.org/dev/peps/pep-0008/#imports + + .. _OpenFisca's Styleguide: + https://github.com/openfisca/openfisca-core/blob/master/STYLEGUIDE.md + +""" + +# Official Public API + +from ._types import ( # noqa: F401 + ArrayLike, + ArrayType, + ) + +from ._protocols import ( # noqa: F401 + FormulaProtocol, + HolderProtocol, + PeriodProtocol, + PopulationProtocol, + TaxBenefitSystemProtocol, + ) + +from ._schemas import ( # noqa: F401 + AxisSchema, + FrameSchema, + OptionsSchema, + TestSchema, + ) + +__all__ = [ + "ArrayLike", + "ArrayType", + "FormulaProtocol", + "HolderProtocol", + "PeriodProtocol", + "PopulationProtocol", + "TaxBenefitSystemProtocol", + "AxisSchema", + "FrameSchema", + "OptionsSchema", + "TestSchema", + ] diff --git a/openfisca_core/typing/_protocols.py b/openfisca_core/typing/_protocols.py new file mode 100644 index 0000000000..a5a2c0405d --- /dev/null +++ b/openfisca_core/typing/_protocols.py @@ -0,0 +1,215 @@ +# pylint: disable=missing-function-docstring + +from __future__ import annotations + +from typing import 
Any, Mapping, Optional, Sequence, Set, overload +from typing_extensions import Literal, Protocol +from ._types import ArrayType + +import abc + + +class EntityProtocol(Protocol): + """Duck-type for entities. + + .. versionadded:: 35.8.0 + + """ + + key: str + plural: str + is_person: bool + flattened_roles: Sequence[RoleProtocol] + + +class FormulaProtocol(Protocol): + """Duck-type for formulas""" + + def __call__( + self, + __population: PopulationProtocol, + __period: PeriodProtocol, + __pararameters: ParametersProtocol, + ) -> ArrayType[Any]: + ... + + +class HolderProtocol(Protocol): + """Duck-type for holders. + + .. versionadded:: 35.8.0 + + """ + + @abc.abstractmethod + def create_disk_storage( + self, + directory: Optional[str] = ..., + preserve: bool = ..., + ) -> StorageProtocol: + ... + + @abc.abstractmethod + def put_in_cache( + self, + value: ArrayType[Any], + period: PeriodProtocol, + ) -> None: + ... + + @abc.abstractmethod + def get_array(self, period: PeriodProtocol) -> Any: + ... + + @abc.abstractmethod + def get_known_periods(self) -> Sequence[PeriodProtocol]: + ... + + +class InstantProtocol(Protocol): + """Duck-type for instants. + + .. versionadded:: 35.8.0 + + """ + + +class ParameterNodeAtInstantProtocol(Protocol): + """Duck-type for parameter nodes at instant. + + .. versionadded:: 35.8.0 + + """ + + +class ParametersProtocol(Protocol): + """Duck-type for parameters. + + .. versionadded:: 35.8.0 + + """ + + def __call__( + self, + instant: InstantProtocol, + ) -> ParameterNodeAtInstantProtocol: + ... + + +class PeriodProtocol(Protocol): + """Duck-type for periods. + + .. versionadded:: 35.8.0 + + """ + + +class PopulationProtocol(Protocol): + """Duck-type for populations. + + .. 
versionadded:: 35.8.0 + + """ + + _holders: Mapping[str, HolderProtocol] + count: Optional[int] + entity: EntityProtocol + ids: Sequence[str] + members_entity_id: ArrayType[int] + members_position: ArrayType[int] + members_role: ArrayType[RoleProtocol] + + @abc.abstractmethod + def get_index(self, id: str) -> int: + ... + + +class RoleProtocol(Protocol): + """Duck-type for roles. + + .. versionadded:: 35.8.0 + + """ + + key: str + + +class StorageProtocol(Protocol): + """Duck-type for storage mechanisms. + + .. versionadded:: 35.8.0 + + """ + + @abc.abstractmethod + def put(self, value: ArrayType[Any], period: PeriodProtocol) -> None: + ... + + +class TaxBenefitSystemProtocol(Protocol): + """Duck-type for tax-benefit systems. + + .. versionadded:: 35.8.0 + + """ + + person_entity: EntityProtocol + + @abc.abstractmethod + def apply_reform(self, reform_path: str) -> TaxBenefitSystemProtocol: + ... + + @abc.abstractmethod + def clone(self) -> TaxBenefitSystemProtocol: + ... + + @abc.abstractmethod + def entities_plural(self) -> Set[str]: + ... + + @abc.abstractmethod + def get_package_metadata(self) -> Mapping[str, str]: + ... + + @overload + def get_variable( + self, + variable_name: str, + check_existence: Literal[True] = ..., + ) -> VariableProtocol: + ... + + @overload + def get_variable( + self, + variable_name: str, + check_existence: bool = ..., + ) -> Optional[VariableProtocol]: + ... + + @abc.abstractmethod + def get_variable( + self, + variable_name: str, + check_existence: bool = ..., + ) -> Optional[VariableProtocol]: + ... + + @abc.abstractmethod + def instantiate_entities(self) -> Mapping[str, PopulationProtocol]: + ... + + @abc.abstractmethod + def load_extension(self, extension: str) -> None: + ... + + +class VariableProtocol(Protocol): + """Duck-type for variables. + + .. 
versionadded:: 35.8.0 + + """ + + definition_period: str + name: str diff --git a/openfisca_core/typing/_schemas.py b/openfisca_core/typing/_schemas.py new file mode 100644 index 0000000000..b9d7a086cf --- /dev/null +++ b/openfisca_core/typing/_schemas.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import Any, Mapping, Optional, Sequence +from typing_extensions import TypedDict + +from ._protocols import PeriodProtocol + + +class AxisSchema(TypedDict): + """Data-schema of axes.""" + + count: int + index: int + max: float + min: float + name: str + period: str + + +class FrameSchema(TypedDict): + """Data-schema of tracer stack frames.""" + + name: str + period: PeriodProtocol + + +class OptionsSchema(TypedDict, total = False): + """Data-schema of ``openfisca test`` options.""" + + ignore_variables: Optional[Sequence[str]] + name_filter: Optional[str] + only_variables: Optional[Sequence[str]] + pdb: bool + performance_graph: bool + performance_tables: bool + verbose: bool + + +class TestSchema(TypedDict, total = False): + """Data-schema of ``openfisca test`` tests.""" + + absolute_error_margin: float + extensions: Sequence[str] + input: Mapping[str, Mapping[str, Any]] + keywords: Sequence[str] + max_spiral_loops: int + name: str + output: Mapping[str, Mapping[str, Any]] + period: str + reforms: Sequence[str] + relative_error_margin: float diff --git a/openfisca_core/types/data_types/arrays.py b/openfisca_core/typing/_types.py similarity index 92% rename from openfisca_core/types/data_types/arrays.py rename to openfisca_core/typing/_types.py index 5cfef639c5..8441e2dd9a 100644 --- a/openfisca_core/types/data_types/arrays.py +++ b/openfisca_core/typing/_types.py @@ -34,10 +34,13 @@ Todo: * Refactor once numpy version >= 1.21 is used. -.. versionadded:: 35.5.0 +.. versionchanged:: 35.8.0 + Moved to :mod:`.openfisca_core.typing` .. versionchanged:: 35.6.0 - Moved to :mod:`.types` + Moved to ``openfisca_core.types`` + +.. versionadded:: 35.5.0 .. 
_mypy: https://mypy.readthedocs.io/en/stable/ diff --git a/openfisca_core/variables/__init__.py b/openfisca_core/variables/__init__.py index fb36963f7d..3decaf8f42 100644 --- a/openfisca_core/variables/__init__.py +++ b/openfisca_core/variables/__init__.py @@ -24,4 +24,3 @@ from .config import VALUE_TYPES, FORMULA_NAME_PREFIX # noqa: F401 from .helpers import get_annualized_variable, get_neutralized_variable # noqa: F401 from .variable import Variable # noqa: F401 -from .typing import Formula # noqa: F401 diff --git a/openfisca_core/variables/typing.py b/openfisca_core/variables/typing.py deleted file mode 100644 index 892ec0bf9f..0000000000 --- a/openfisca_core/variables/typing.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Callable, Union - -import numpy - -from openfisca_core.parameters import ParameterNodeAtInstant -from openfisca_core.periods import Instant, Period -from openfisca_core.populations import Population, GroupPopulation - -#: A collection of :obj:`.Entity` or :obj:`.GroupEntity`. -People = Union[Population, GroupPopulation] - -#: A callable to get the parameters for the given instant. -Params = Callable[[Instant], ParameterNodeAtInstant] - -#: A callable defining a calculation, or a rule, on a system. -Formula = Callable[[People, Period, Params], numpy.ndarray] diff --git a/openfisca_core/variables/variable.py b/openfisca_core/variables/variable.py index 61a5d9274f..9dcd73d5d4 100644 --- a/openfisca_core/variables/variable.py +++ b/openfisca_core/variables/variable.py @@ -1,3 +1,8 @@ +from __future__ import annotations + +from typing import Optional +from openfisca_core.typing import FormulaProtocol + import datetime import inspect import re @@ -9,7 +14,7 @@ from openfisca_core import periods, tools from openfisca_core.entities import Entity from openfisca_core.indexed_enums import Enum, EnumArray -from openfisca_core.periods import Period +from openfisca_core.periods import Instant, Period from . 
import config, helpers @@ -302,17 +307,23 @@ def get_introspection_data(cls, tax_benefit_system): return comments, source_file_path, source_code, start_line_number - def get_formula(self, period = None): + def get_formula( + self, + period: Optional[Period] = None, + ) -> Optional[FormulaProtocol]: """ Returns the formula used to compute the variable at the given period. If no period is given and the variable has several formula, return the oldest formula. :returns: Formula used to compute the variable - :rtype: .Formula + :rtype: callable """ + instant: Instant + to_str: str + if not self.formulas: return None @@ -330,9 +341,10 @@ def get_formula(self, period = None): if self.end and instant.date > self.end: return None - instant = str(instant) + to_str = str(instant) + for start_date in reversed(self.formulas): - if start_date <= instant: + if start_date <= to_str: return self.formulas[start_date] return None diff --git a/openfisca_tasks/lint.mk b/openfisca_tasks/lint.mk index 115c6267bb..117da9e6c3 100644 --- a/openfisca_tasks/lint.mk +++ b/openfisca_tasks/lint.mk @@ -17,7 +17,7 @@ check-style: $(shell git ls-files "*.py") ## Run linters to check for syntax and style errors in the doc. lint-doc: \ lint-doc-commons \ - lint-doc-types \ + lint-doc-typing \ ; ## Run linters to check for syntax and style errors in the doc. @@ -42,7 +42,8 @@ check-types: ## Run static type checkers for type errors (strict). lint-typing-strict: \ lint-typing-strict-commons \ - lint-typing-strict-types \ + lint-typing-strict-tools \ + lint-typing-strict-typing \ ; ## Run static type checkers for type errors (strict). 
diff --git a/setup.cfg b/setup.cfg index bb3ff50fc5..ab06fe8f12 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,9 +14,10 @@ extend-ignore = D hang-closing = true ignore = E128,E251,F403,F405,E501,RST301,W503,W504 in-place = true -include-in-doctest = openfisca_core/commons openfisca_core/types +include-in-doctest = openfisca_core/commons openfisca_core/typing +per-file-ignores = openfisca_core/typing/_protocols.py:D102 rst-directives = attribute, deprecated, seealso, versionadded, versionchanged -rst-roles = any, attr, class, exc, func, meth, obj +rst-roles = any, attr, class, data, exc, func, meth, obj strictness = short [pylint.message_control] @@ -41,7 +42,7 @@ skip_empty = true addopts = --doctest-modules --disable-pytest-warnings --showlocals doctest_optionflags = ELLIPSIS IGNORE_EXCEPTION_DETAIL NUMBER NORMALIZE_WHITESPACE python_files = **/*.py -testpaths = openfisca_core/commons openfisca_core/types tests +testpaths = openfisca_core/commons openfisca_core/typing tests [mypy] ignore_missing_imports = True diff --git a/setup.py b/setup.py index 36e30a751e..a99b8d6973 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ 'numexpr >= 2.7.0, <= 3.0', 'numpy >= 1.11, < 1.21', 'psutil >= 5.4.7, < 6.0.0', - 'pytest >= 4.4.1, < 6.0.0', # For openfisca test + 'pytest >= 4.4.1, < 7.0.0', # For openfisca test 'PyYAML >= 3.10', 'sortedcontainers == 2.2.2', 'typing-extensions == 3.10.0.2', diff --git a/tests/core/tools/test_runner/test_yaml_runner.py b/tests/core/tools/test_runner/test_yaml_runner.py index bd7aaccad7..ba3f85fa3e 100644 --- a/tests/core/tools/test_runner/test_yaml_runner.py +++ b/tests/core/tools/test_runner/test_yaml_runner.py @@ -58,15 +58,15 @@ def get_population(self, plural = None): class TestFile(YamlFile): - def __init__(self): + def __init__(self, parent, *, fspath): self.config = None self.session = None self._nodeid = 'testname' class TestItem(YamlItem): - def __init__(self, test): - super().__init__('', TestFile(), TaxBenefitSystem(), test, {}) 
+ def __init__(self, parent, test): + super().__init__('', parent, TaxBenefitSystem(), test, {}) self.tax_benefit_system = self.baseline_tax_benefit_system self.simulation = Simulation() @@ -84,10 +84,15 @@ def __init__(self): self.dtype = np.float32 -def test_variable_not_found(): +@pytest.fixture +def test_file(): + return TestFile.from_parent(object(), fspath = "") + + +def test_variable_not_found(test_file): test = {"output": {"unknown_variable": 0}} with pytest.raises(VariableNotFound) as excinfo: - test_item = TestItem(test) + test_item = TestItem.from_parent(test_file, test = test) test_item.check_output() assert excinfo.value.variable_name == "unknown_variable" @@ -140,9 +145,9 @@ def test_extensions_order(): assert xy_tax_benefit_system == yx_tax_benefit_system # extensions order is ignored in cache -def test_performance_graph_option_output(): +def test_performance_graph_option_output(test_file): test = {'input': {'salary': {'2017-01': 2000}}, 'output': {'salary': {'2017-01': 2000}}} - test_item = TestItem(test) + test_item = TestItem.from_parent(test_file, test = test) test_item.options = {'performance_graph': True} paths = ["./performance_graph.html"] @@ -158,9 +163,9 @@ def test_performance_graph_option_output(): clean_performance_files(paths) -def test_performance_tables_option_output(): +def test_performance_tables_option_output(test_file): test = {'input': {'salary': {'2017-01': 2000}}, 'output': {'salary': {'2017-01': 2000}}} - test_item = TestItem(test) + test_item = TestItem.from_parent(test_file, test = test) test_item.options = {'performance_tables': True} paths = ["performance_table.csv", "aggregated_performance_table.csv"]