From 927263a6ec2cc17cdcd3b7b3122e9b568b5b6cad Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 12:59:40 +0200 Subject: [PATCH 01/38] Extract general requirements to file --- requirements/install.in | 70 +++++++++++++++++++++++++++++++++++++++++ setup.py | 66 +++++++++++++++----------------------- 2 files changed, 95 insertions(+), 41 deletions(-) create mode 100644 requirements/install.in diff --git a/requirements/install.in b/requirements/install.in new file mode 100644 index 0000000000..24eeab6f64 --- /dev/null +++ b/requirements/install.in @@ -0,0 +1,70 @@ +# Please make sure to cap all dependency versions, in order to avoid unwanted +# functional and integration breaks caused by external code updates. +# +# These are general dependencies, so: +# +# * the ranges should be broad, and +# * the number of them kept discreet. +# +# If a dependency seems trivial, please propose a pull request to remove it. +# +# General Rule +# ======== +# +# * Cap to latest major. +# +# NumPy Rule +# ======== +# +# * Cap to latest four minors (document why). +# +# Exceptions +# ======== +# +# * Avoiding a version (document why). +# * Pinning a version (document why). + +# For globing over dictionaries as if they were filesystems +# +# We do not support 2 because of a bug introducted to period parsing in +# OpenFisca's Web API. +dpath ~= 1.5 + +# For evaluating numerical expressions +# +# TODO: support for 2 should be extended. +numexpr ~= 2.7 + +# For vectorial support +# +# We support the latest four minors because NumPy is generally a transitive +# dependency that users rely on within the projects where OpenFisca is +# depended on by. +# +# TODO: support for < 1.17 should be dropped. +numpy ~= 1.11, < 1.21 + +# For caching +# +# We support psutil >= 5.4.7 because users have found problems running +# older versions on Windows (specifically 5.4.2). +# TODO: support should be extended to >= 5.4.3. +psutil >= 5.4.7, < 6 + +# For openfisca test +# +# TODO: support for 4 should be dropped. +# TODO: support for 6 requires fixing some tests. +# See: https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent +pytest >= 4.4.1, < 6 + +# For parameters, tests +# +# TODO: support for 3 should be dropped. +# TODO: support for 4 should be dropped. +PyYAML >= 3.10, < 6 + +# For sorting formulas by date +# +# TODO: support for 2 should be extended. +sortedcontainers == 2.2.2 diff --git a/setup.py b/setup.py index 36e30a751e..1d209f0d48 100644 --- a/setup.py +++ b/setup.py @@ -1,43 +1,28 @@ #! /usr/bin/env python +from __future__ import annotations + +from typing import List + +import re from setuptools import setup, find_packages -# Please make sure to cap all dependency versions, in order to avoid unwanted -# functional and integration breaks caused by external code updates. -general_requirements = [ - 'dpath >= 1.5.0, < 2.0.0', - 'nptyping == 1.4.4', - 'numexpr >= 2.7.0, <= 3.0', - 'numpy >= 1.11, < 1.21', - 'psutil >= 5.4.7, < 6.0.0', - 'pytest >= 4.4.1, < 6.0.0', # For openfisca test - 'PyYAML >= 3.10', - 'sortedcontainers == 2.2.2', - 'typing-extensions == 3.10.0.2', - ] +def load_requirements_from_file(filename: str) -> List[str]: + """Allows for composable requirement files with the `-r filename` flag.""" + + reqs = open(f"requirements/{filename}").readlines() + pattern = re.compile(r"^\s*-r\s*(?P.*)$") + + for req in reqs: + match = pattern.match(req) -api_requirements = [ - 'flask == 1.1.2', - 'flask-cors == 3.0.10', - 'gunicorn >= 20.0.0, < 21.0.0', - 'werkzeug >= 1.0.0, < 2.0.0', - ] + if match: + reqs.remove(req) + reqs.extend(load_requirements_from_file(match.group("filename"))) + + return reqs -dev_requirements = [ - 'autopep8 >= 1.4.0, < 1.6.0', - 'coverage == 6.0.2', - 'darglint == 1.8.0', - 'flake8 >= 3.9.0, < 4.0.0', - 'flake8-bugbear >= 19.3.0, < 20.0.0', - 'flake8-docstrings == 1.6.0', - 'flake8-print >= 3.1.0, < 4.0.0', - 'flake8-rst-docstrings == 0.2.3', - 'mypy == 0.910', - 'openfisca-country-template >= 3.10.0, < 4.0.0', - 'openfisca-extension-template >= 1.2.0rc0, < 2.0.0', - 'pylint == 2.10.2', - ] + api_requirements setup( name = 'OpenFisca-Core', @@ -49,7 +34,6 @@ 'License :: OSI Approved :: GNU Affero General Public License v3', 'Operating System :: POSIX', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Topic :: Scientific/Engineering :: Information Analysis', ], @@ -57,7 +41,6 @@ keywords = 'benefit microsimulation social tax', license = 'https://www.fsf.org/licensing/licenses/agpl-3.0.html', url = 'https://github.com/openfisca/openfisca-core', - data_files = [ ( 'share/openfisca/openfisca-core', @@ -70,14 +53,15 @@ 'openfisca-run-test=openfisca_core.scripts.openfisca_command:main', ], }, + python_requires = ">= 3.7", + install_requires = load_requirements_from_file("install"), extras_require = { - 'web-api': api_requirements, - 'dev': dev_requirements, - 'tracker': [ - 'openfisca-tracker == 0.4.0', - ], + "coverage": load_requirements_from_file("coverage"), + "dev": load_requirements_from_file("dev"), + "publication": load_requirements_from_file("publication"), + "tracker": load_requirements_from_file("tracker"), + "web-api": load_requirements_from_file("web-api"), }, include_package_data = True, # Will read MANIFEST.in - install_requires = general_requirements, packages = find_packages(exclude=['tests*']), ) From 027256ab1e9ee42db3e3b5d9cc2725d4ee476c8f Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 13:27:07 +0200 Subject: [PATCH 02/38] Extract api requirements to file --- requirements/web-api.in | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 requirements/web-api.in diff --git a/requirements/web-api.in b/requirements/web-api.in new file mode 100644 index 0000000000..fc8970d994 --- /dev/null +++ b/requirements/web-api.in @@ -0,0 +1,16 @@ +# Please make sure to pin all dependency versions, in order to avoid unwanted +# functional and integration breaks caused by external code updates. +# +# These are web-api dependencies, so pin them. + +# For OpenFisca's Web API. +# +# TODO: pin. +flask == 1.1.2 +gunicorn >= 20.0.0, < 21.0.0 + +# For OpenFisca's Web API users requiring CORS. +# +# TODO: pin. +flask-cors == 3.0.10 +werkzeug >= 1.0.0, < 2.0.0 From 0076c6fdc9c9d07fd8fd1e0d7a425a30bbc15b2f Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 13:57:57 +0200 Subject: [PATCH 03/38] Extract ci requirements to file --- .circleci/config.yml | 5 +++-- openfisca_tasks/install.mk | 5 +++-- requirements/coverage.in | 5 +++++ requirements/publication.in | 8 ++++++++ requirements/tracking.in | 7 +++++++ 5 files changed, 26 insertions(+), 4 deletions(-) create mode 100644 requirements/coverage.in create mode 100644 requirements/publication.in create mode 100644 requirements/tracking.in diff --git a/.circleci/config.yml b/.circleci/config.yml index 67c45002e0..416fb281c1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -127,7 +127,7 @@ jobs: name: Submit coverage to Coveralls command: | source /tmp/venv/openfisca_core/bin/activate - pip install coveralls + pip install --editable .[coverage] coveralls - save_cache: @@ -153,9 +153,10 @@ jobs: command: if ! .circleci/has-functional-changes.sh ; then circleci step halt ; fi - run: - name: Upload a Python package to Pypi + name: Upload a Python package to PyPi command: | source /tmp/venv/openfisca_core/bin/activate + pip install --editable .[publication] .circleci/publish-python-package.sh - run: diff --git a/openfisca_tasks/install.mk b/openfisca_tasks/install.mk index f37d17f26f..7918c76b14 100644 --- a/openfisca_tasks/install.mk +++ b/openfisca_tasks/install.mk @@ -1,8 +1,9 @@ ## Install project dependencies. install: @$(call print_help,$@:) - @pip install --upgrade pip twine wheel - @pip install --editable .[dev] --upgrade --use-deprecated=legacy-resolver + @pip install --upgrade pip setuptools + @pip install --requirement requirements/dev --upgrade + @pip install --editable . --upgrade --no-dependencies ## Uninstall project dependencies. uninstall: diff --git a/requirements/coverage.in b/requirements/coverage.in new file mode 100644 index 0000000000..9dcd97d83d --- /dev/null +++ b/requirements/coverage.in @@ -0,0 +1,5 @@ +# These are dependencies to upload test coverage statistics, so we always want +# the latest versions. + +# For sending test statistics to the Coveralls third-party service. +coveralls diff --git a/requirements/publication.in b/requirements/publication.in new file mode 100644 index 0000000000..85d4ccd2ea --- /dev/null +++ b/requirements/publication.in @@ -0,0 +1,8 @@ +# These are dependencies to publish the library, so we always want the latest +# versions. + +# For publishing on PyPI. +twine + +# For building the package. +wheel diff --git a/requirements/tracking.in b/requirements/tracking.in new file mode 100644 index 0000000000..929c660ee4 --- /dev/null +++ b/requirements/tracking.in @@ -0,0 +1,7 @@ +# Please make sure to pin all dependency versions, in order to avoid unwanted +# functional and integration breaks caused by external code updates. +# +# These are web-api tracking dependencies, so pin them. + +# For sending usage statistics to the Matomo third-party service. +openfisca-tracker == 0.4.0 From 98fe43a4d41fb65f67078366c726fd8e6fc51eed Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 14:58:42 +0200 Subject: [PATCH 04/38] Extract dev requirements to file --- requirements/debug.in | 5 +++++ requirements/dev.in | 28 ++++++++++++++++++++++++ requirements/{tracking.in => tracker.in} | 0 3 files changed, 33 insertions(+) create mode 100644 requirements/debug.in create mode 100644 requirements/dev.in rename requirements/{tracking.in => tracker.in} (100%) diff --git a/requirements/debug.in b/requirements/debug.in new file mode 100644 index 0000000000..59cf404cae --- /dev/null +++ b/requirements/debug.in @@ -0,0 +1,5 @@ +# These are dependencies to help with debug and profiling, so we always want +# the latest versions. + +# Interactive console on steroids (even makes your coffee)! +ipython diff --git a/requirements/dev.in b/requirements/dev.in new file mode 100644 index 0000000000..fcced7d54f --- /dev/null +++ b/requirements/dev.in @@ -0,0 +1,28 @@ +# Please make sure to pin all dependency versions, in order to avoid unwanted +# functional and integration breaks caused by external code updates. +# +# These are dev dependencies, so pin them. + +# For automatic style formatting. +# +# TODO: pin +autopep8 >= 1.4.0, < 1.6.0 + +# For style & code checking. +flake8 >= 3.9.0, < 4.0.0 +flake8-bugbear >= 19.3.0, < 20.0.0 +flake8-print >= 3.1.0, < 4.0.0 +flake8-rst-docstrings == 0.2.3 + +# For PyTest test coverage integration. +pytest-cov >= 2.6.1, < 3.0.0 + +# For optional duck & static type checking. +mypy >= 0.701, < 0.800 + +# For testing: parameters, variables, etc. +openfisca-country-template >= 3.10.0, < 4.0.0 +openfisca-extension-template >= 1.2.0rc0, < 2.0.0 + +# Include Web API dependencies for development +-r web-api.in diff --git a/requirements/tracking.in b/requirements/tracker.in similarity index 100% rename from requirements/tracking.in rename to requirements/tracker.in From cc2dccd5f80fef443fecddfd05678bbcc0d745a8 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 15:08:50 +0200 Subject: [PATCH 05/38] Pin dev requirements --- requirements/dev.in | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/requirements/dev.in b/requirements/dev.in index fcced7d54f..ffc8007a23 100644 --- a/requirements/dev.in +++ b/requirements/dev.in @@ -4,25 +4,23 @@ # These are dev dependencies, so pin them. # For automatic style formatting. -# -# TODO: pin -autopep8 >= 1.4.0, < 1.6.0 +autopep8 == 1.5.7 # For style & code checking. -flake8 >= 3.9.0, < 4.0.0 -flake8-bugbear >= 19.3.0, < 20.0.0 -flake8-print >= 3.1.0, < 4.0.0 +flake8 == 3.9.1 +flake8-bugbear == 21.4.3 +flake8-print == 4.0.0 flake8-rst-docstrings == 0.2.3 # For PyTest test coverage integration. -pytest-cov >= 2.6.1, < 3.0.0 +pytest-cov == 2.11.1 # For optional duck & static type checking. -mypy >= 0.701, < 0.800 +mypy == 0.812 # For testing: parameters, variables, etc. -openfisca-country-template >= 3.10.0, < 4.0.0 -openfisca-extension-template >= 1.2.0rc0, < 2.0.0 +openfisca-country-template == 3.12.5 +openfisca-extension-template == 1.3.6 # Include Web API dependencies for development -r web-api.in From 9fbb223a8272840242061e663bf7cbc3212c55d1 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 15:14:10 +0200 Subject: [PATCH 06/38] Pin api requirements --- requirements/web-api.in | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/requirements/web-api.in b/requirements/web-api.in index fc8970d994..614541cefe 100644 --- a/requirements/web-api.in +++ b/requirements/web-api.in @@ -4,13 +4,9 @@ # These are web-api dependencies, so pin them. # For OpenFisca's Web API. -# -# TODO: pin. flask == 1.1.2 -gunicorn >= 20.0.0, < 21.0.0 +gunicorn == 20.1.0 # For OpenFisca's Web API users requiring CORS. -# -# TODO: pin. flask-cors == 3.0.10 -werkzeug >= 1.0.0, < 2.0.0 +werkzeug == 1.0.1 From 65207fe25970b44dac0a1d41a02862b3c0964228 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 15:53:36 +0200 Subject: [PATCH 07/38] Relax sortedcontainers --- requirements/install.in | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/requirements/install.in b/requirements/install.in index 24eeab6f64..90eb9b65e2 100644 --- a/requirements/install.in +++ b/requirements/install.in @@ -64,7 +64,5 @@ pytest >= 4.4.1, < 6 # TODO: support for 4 should be dropped. PyYAML >= 3.10, < 6 -# For sorting formulas by date -# -# TODO: support for 2 should be extended. -sortedcontainers == 2.2.2 +# For sorting formulas by period. +sortedcontainers >= 2, < 3 From 024400c1552274d5945c41e262fb717d9cbaa7a8 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 16:03:04 +0200 Subject: [PATCH 08/38] Relax psutil --- requirements/install.in | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/requirements/install.in b/requirements/install.in index 90eb9b65e2..49934634f4 100644 --- a/requirements/install.in +++ b/requirements/install.in @@ -44,11 +44,10 @@ numexpr ~= 2.7 # TODO: support for < 1.17 should be dropped. numpy ~= 1.11, < 1.21 -# For caching +# Memory monitoring for caching. # -# We support psutil >= 5.4.7 because users have found problems running -# older versions on Windows (specifically 5.4.2). -# TODO: support should be extended to >= 5.4.3. +# We support psutil >= 5.4.7 because it is the first version compatible with +# Python 3.7. psutil >= 5.4.7, < 6 # For openfisca test From 83a53be3ad2cc0fd7e1b0dcf5997e1728ad07b32 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 16:21:04 +0200 Subject: [PATCH 09/38] Relax dpath --- requirements/install.in | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/requirements/install.in b/requirements/install.in index 49934634f4..db4c3634c1 100644 --- a/requirements/install.in +++ b/requirements/install.in @@ -24,11 +24,14 @@ # * Avoiding a version (document why). # * Pinning a version (document why). -# For globing over dictionaries as if they were filesystems +# For globing over dictionaries as if they were filesystems. # -# We do not support 2 because of a bug introducted to period parsing in +# We support from 1.3.2 on because it is the first version published +# following the semantic versioning specification. +# +# We do not support 2 because of a bug introduced to period parsing in # OpenFisca's Web API. -dpath ~= 1.5 +dpath >= 1.3.2, < 2 # For evaluating numerical expressions # From e167093d29712297e80ca0adcec7e8cd3055ce58 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 16:41:06 +0200 Subject: [PATCH 10/38] Drop support for numexpr < 2.7.1 --- requirements/install.in | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/requirements/install.in b/requirements/install.in index db4c3634c1..e2f4eb5dac 100644 --- a/requirements/install.in +++ b/requirements/install.in @@ -35,8 +35,9 @@ dpath >= 1.3.2, < 2 # For evaluating numerical expressions # -# TODO: support for 2 should be extended. -numexpr ~= 2.7 +# We support numexpr >= 2.7.1 because it is the first version compatible with +# Python 3.7. +numexpr >= 2.7.1, < 3 # For vectorial support # From caf1c8535149f73ec799ac626fa775d1caad439c Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 20:24:20 +0200 Subject: [PATCH 11/38] Drop support for pytest < 5.4.2 --- requirements/install.in | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/requirements/install.in b/requirements/install.in index e2f4eb5dac..864964d6d6 100644 --- a/requirements/install.in +++ b/requirements/install.in @@ -29,7 +29,7 @@ # We support from 1.3.2 on because it is the first version published # following the semantic versioning specification. # -# We do not support 2 because of a bug introduced to period parsing in +# We do not support 2 yet because of a bug introduced to period parsing in # OpenFisca's Web API. dpath >= 1.3.2, < 2 @@ -56,10 +56,12 @@ psutil >= 5.4.7, < 6 # For openfisca test # -# TODO: support for 4 should be dropped. -# TODO: support for 6 requires fixing some tests. +# We support pytest >= 5.4.2 because `openfisca test` relies on the signature of +# `pytest.Item.from_module()` introduced since this 5.4.2. +# +# We do not support 6 yet because it requires fixing some tests before. # See: https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent -pytest >= 4.4.1, < 6 +pytest >= 5.4.2, < 6 # For parameters, tests # From b31aec295c4186b7a546ab667b089e021c080d5c Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 20:40:14 +0200 Subject: [PATCH 12/38] Drop support for PyYAML < 5.1 --- requirements/install.in | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/requirements/install.in b/requirements/install.in index 864964d6d6..43613a9f4b 100644 --- a/requirements/install.in +++ b/requirements/install.in @@ -63,11 +63,8 @@ psutil >= 5.4.7, < 6 # See: https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent pytest >= 5.4.2, < 6 -# For parameters, tests -# -# TODO: support for 3 should be dropped. -# TODO: support for 4 should be dropped. -PyYAML >= 3.10, < 6 +# For parameters, tests. +PyYAML >= 5.1, < 6 # For sorting formulas by period. sortedcontainers >= 2, < 3 From 7edbe620ad6611f3c49ceb1de8811f65cec60b16 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 20:54:59 +0200 Subject: [PATCH 13/38] Drop support for NumPy < 1.17 --- requirements/install.in | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/requirements/install.in b/requirements/install.in index 43613a9f4b..fcb8c75047 100644 --- a/requirements/install.in +++ b/requirements/install.in @@ -44,9 +44,7 @@ numexpr >= 2.7.1, < 3 # We support the latest four minors because NumPy is generally a transitive # dependency that users rely on within the projects where OpenFisca is # depended on by. -# -# TODO: support for < 1.17 should be dropped. -numpy ~= 1.11, < 1.21 +numpy >= 1.17, < 1.21 # Memory monitoring for caching. # From a38458be0ad2ff5c8e06d438f33802d76cb59d48 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 22:31:07 +0200 Subject: [PATCH 14/38] Add integrity test for NumPy --- .circleci/config.yml | 128 ++++++++++++++----- .circleci/get-numpy-version.py | 38 ------ README.md | 4 +- requirements/constraints | 6 + requirements/{coverage.in => coverage} | 0 requirements/{debug.in => debug} | 0 requirements/{dev.in => dev} | 2 +- requirements/{install.in => install} | 0 requirements/{publication.in => publication} | 0 requirements/{tracker.in => tracker} | 0 requirements/{web-api.in => web-api} | 0 setup.py | 2 +- 12 files changed, 108 insertions(+), 72 deletions(-) delete mode 100755 .circleci/get-numpy-version.py create mode 100644 requirements/constraints rename requirements/{coverage.in => coverage} (100%) rename requirements/{debug.in => debug} (100%) rename requirements/{dev.in => dev} (97%) rename requirements/{install.in => install} (100%) rename requirements/{publication.in => publication} (100%) rename requirements/{tracker.in => tracker} (100%) rename requirements/{web-api.in => web-api} (100%) diff --git a/.circleci/config.yml b/.circleci/config.yml index 416fb281c1..50675206dd 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,6 +1,19 @@ version: 2 jobs: - run_tests: + check_version: + docker: + - image: python:3.7 + + steps: + - checkout + + - run: + name: Check version number has been properly updated + command: | + git fetch + .circleci/is-version-number-acceptable.sh + + build: docker: - image: python:3.7 environment: @@ -17,13 +30,16 @@ jobs: command: | mkdir -p /tmp/venv/openfisca_core python -m venv /tmp/venv/openfisca_core - echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV + + - run: + name: Activate virtualenv + command: echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV - run: name: Install dependencies command: | - make install - make clean + pip install --upgrade pip + pip install --requirement requirements/dev # pip install --editable git+https://github.com/openfisca/country-template.git@BRANCH_NAME#egg=OpenFisca-Country-Template # use a specific branch of OpenFisca-Country-Template # pip install --editable git+https://github.com/openfisca/extension-template.git@BRANCH_NAME#egg=OpenFisca-Extension-Template # use a specific branch of OpenFisca-Extension-Template @@ -32,6 +48,27 @@ jobs: paths: - /tmp/venv/openfisca_core + test: + docker: + - image: python:3.7 + + environment: + PYTEST_ADDOPTS: --exitfirst + + steps: + - checkout + + - restore_cache: + key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} + + - run: + name: Activate virtualenv + command: echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV + + - run: + name: Install core + command: pip install --editable . --upgrade --no-dependencies + - run: name: Run linters command: make lint @@ -49,14 +86,41 @@ jobs: command: make test-extension pytest_args="--exitfirst" - run: - name: Check NumPy typing against latest 3 minor versions - command: for i in {1..3}; do VERSION=$(.circleci/get-numpy-version.py prev) && pip install numpy==$VERSION && make check-types; done + name: Run core tests + command: make test - persist_to_workspace: root: . paths: - .coverage + test_compatibility: + docker: + - image: python:3.7 + + environment: + PYTEST_ADDOPTS: --exitfirst + + steps: + - checkout + + - restore_cache: + key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} + + - run: + name: Activate virtualenv + command: echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV + + - run: + name: Install core with a constrained Numpy version + command: | + pip install --requirement requirements/dev --upgrade --constraint requirements/compatibility + pip install --editable . --upgrade --no-dependencies + + - run: + name: Run core tests + command: make test + test_docs: docker: - image: python:3.7 @@ -96,20 +160,6 @@ jobs: name: Run doc tests command: make test-doc-build - - check_version: - docker: - - image: python:3.7 - - steps: - - checkout - - - run: - name: Check version number has been properly updated - command: | - git fetch - .circleci/is-version-number-acceptable.sh - submit_coverage: docker: - image: python:3.7 @@ -123,11 +173,16 @@ jobs: - restore_cache: key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} + - run: + name: Activate virtualenv + command: echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV + - run: name: Submit coverage to Coveralls command: | - source /tmp/venv/openfisca_core/bin/activate - pip install --editable .[coverage] + pip install --requirement requirements/coverage --upgrade + pip install --requirement requirements/dev --upgrade + pip install --editable . --upgrade --no-dependencies coveralls - save_cache: @@ -138,6 +193,7 @@ jobs: deploy: docker: - image: python:3.7 + environment: PYPI_USERNAME: openfisca-bot # PYPI_PASSWORD: this value is set in CircleCI's web interface; do not set it here, it is a secret! @@ -152,11 +208,15 @@ jobs: name: Check for functional changes command: if ! .circleci/has-functional-changes.sh ; then circleci step halt ; fi + - run: + name: Activate virtualenv + command: echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV + - run: name: Upload a Python package to PyPi command: | - source /tmp/venv/openfisca_core/bin/activate - pip install --editable .[publication] + pip install --requirement requirements/publication --upgrade + pip install --editable . --upgrade --no-dependencies .circleci/publish-python-package.sh - run: @@ -172,17 +232,27 @@ workflows: version: 2 build_and_deploy: jobs: - - run_tests - - test_docs - check_version + - build + - test: + requires: + - build + - test_compatibility: + requires: + - build + - test_docs: + requires: + - build - submit_coverage: requires: - - run_tests + - test + - test_compatibility - deploy: requires: - - run_tests - - test_docs - check_version + - test + - test_compatibility + - test_docs filters: branches: only: master diff --git a/.circleci/get-numpy-version.py b/.circleci/get-numpy-version.py deleted file mode 100755 index 64cb68532e..0000000000 --- a/.circleci/get-numpy-version.py +++ /dev/null @@ -1,38 +0,0 @@ -#! /usr/bin/env python - -from __future__ import annotations - -import os -import sys -import typing -from packaging import version -from typing import NoReturn, Union - -import numpy - -if typing.TYPE_CHECKING: - from packaging.version import LegacyVersion, Version - - -def prev() -> NoReturn: - release = _installed().release - - if release is None: - sys.exit(os.EX_DATAERR) - - major, minor, _ = release - - if minor == 0: - sys.exit(os.EX_DATAERR) - - minor -= 1 - print(f"{major}.{minor}.0") # noqa: T001 - sys.exit(os.EX_OK) - - -def _installed() -> Union[LegacyVersion, Version]: - return version.parse(numpy.__version__) - - -if __name__ == "__main__": - globals()[sys.argv[1]]() diff --git a/README.md b/README.md index 7f253c9114..0d1dd703ca 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ This package contains the core features of OpenFisca, which are meant to be used OpenFisca runs on Python 3.7. More recent versions should work, but are not tested. -OpenFisca also relies strongly on NumPy. Last four minor versions should work, but only latest/stable is tested. +OpenFisca also relies strongly on NumPy. Last four minor versions should work, but only upper and lower bound versions are tested. ## Installation @@ -58,8 +58,6 @@ pytest tests/core/test_parameters.py -k test_parameter_for_period This repository relies on MyPy for optional dynamic & static type checking. -As NumPy introduced the `typing` module in 1.20.0, to ensure type hints do not break the code at runtime, we run the checker against the last four minor NumPy versions. - Type checking is already run with `make test`. To run the type checker alone: ```sh diff --git a/requirements/constraints b/requirements/constraints new file mode 100644 index 0000000000..4deb711fae --- /dev/null +++ b/requirements/constraints @@ -0,0 +1,6 @@ +# These are constraint versions to ensure the integrity of distributions. +# +# Normally, we want to add here the pinned lower-bound supported versions of +# dependencies critical to OpenFisca's usability. + +numpy == 1.17 diff --git a/requirements/coverage.in b/requirements/coverage similarity index 100% rename from requirements/coverage.in rename to requirements/coverage diff --git a/requirements/debug.in b/requirements/debug similarity index 100% rename from requirements/debug.in rename to requirements/debug diff --git a/requirements/dev.in b/requirements/dev similarity index 97% rename from requirements/dev.in rename to requirements/dev index ffc8007a23..8c45496681 100644 --- a/requirements/dev.in +++ b/requirements/dev @@ -23,4 +23,4 @@ openfisca-country-template == 3.12.5 openfisca-extension-template == 1.3.6 # Include Web API dependencies for development --r web-api.in +-r web-api diff --git a/requirements/install.in b/requirements/install similarity index 100% rename from requirements/install.in rename to requirements/install diff --git a/requirements/publication.in b/requirements/publication similarity index 100% rename from requirements/publication.in rename to requirements/publication diff --git a/requirements/tracker.in b/requirements/tracker similarity index 100% rename from requirements/tracker.in rename to requirements/tracker diff --git a/requirements/web-api.in b/requirements/web-api similarity index 100% rename from requirements/web-api.in rename to requirements/web-api diff --git a/setup.py b/setup.py index 1d209f0d48..786a28d7d1 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ def load_requirements_from_file(filename: str) -> List[str]: setup( name = 'OpenFisca-Core', - version = '35.7.1', + version = '36.0.0', author = 'OpenFisca Team', author_email = 'contact@openfisca.org', classifiers = [ From 62c0585f9a38c7d9c6fee895848e931d876e1404 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Tue, 7 Sep 2021 23:51:07 +0200 Subject: [PATCH 15/38] Improve phrasing in README.md Co-authored-by: Matti Schneider --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0d1dd703ca..8474d72150 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ This package contains the core features of OpenFisca, which are meant to be used OpenFisca runs on Python 3.7. More recent versions should work, but are not tested. -OpenFisca also relies strongly on NumPy. Last four minor versions should work, but only upper and lower bound versions are tested. +OpenFisca also relies strongly on NumPy. Only upper and lower bound versions are tested. ## Installation From 7b90b14bad3f055d6ff6fb5ff4947e22fff4f25e Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Tue, 7 Sep 2021 23:56:11 +0200 Subject: [PATCH 16/38] Remove debug requirements --- requirements/debug | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 requirements/debug diff --git a/requirements/debug b/requirements/debug deleted file mode 100644 index 59cf404cae..0000000000 --- a/requirements/debug +++ /dev/null @@ -1,5 +0,0 @@ -# These are dependencies to help with debug and profiling, so we always want -# the latest versions. - -# Interactive console on steroids (even makes your coffee)! -ipython From 8b1b84b46287bd681620ff9ee91cdc956c07d641 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Tue, 7 Sep 2021 23:59:02 +0200 Subject: [PATCH 17/38] Do not pin deps/tracker --- requirements/tracker | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/tracker b/requirements/tracker index 929c660ee4..6ab42528d2 100644 --- a/requirements/tracker +++ b/requirements/tracker @@ -4,4 +4,4 @@ # These are web-api tracking dependencies, so pin them. # For sending usage statistics to the Matomo third-party service. -openfisca-tracker == 0.4.0 +openfisca-tracker <= 0.4.0 From 73eaac86bd097646dd6adc08e85d73532894d79b Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Wed, 8 Sep 2021 00:13:33 +0200 Subject: [PATCH 18/38] Clarify requirements/install Co-authored-by: Matti Schneider --- requirements/install | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/requirements/install b/requirements/install index fcb8c75047..fd67bba9cf 100644 --- a/requirements/install +++ b/requirements/install @@ -1,12 +1,12 @@ # Please make sure to cap all dependency versions, in order to avoid unwanted # functional and integration breaks caused by external code updates. # -# These are general dependencies, so: +# These dependencies are always installed, so: # # * the ranges should be broad, and -# * the number of them kept discreet. +# * the number of them kept low. # -# If a dependency seems trivial, please propose a pull request to remove it. +# If a dependency seems redundant, please propose a pull request to remove it. # # General Rule # ======== @@ -24,7 +24,7 @@ # * Avoiding a version (document why). # * Pinning a version (document why). -# For globing over dictionaries as if they were filesystems. +# For globbing over dictionaries as if they were filesystems. # # We support from 1.3.2 on because it is the first version published # following the semantic versioning specification. @@ -61,7 +61,7 @@ psutil >= 5.4.7, < 6 # See: https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent pytest >= 5.4.2, < 6 -# For parameters, tests. +# For parameters and tests parsing. PyYAML >= 5.1, < 6 # For sorting formulas by period. From a2103a6eaa04d28a112c715ba38502c6748cc65d Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Wed, 8 Sep 2021 00:25:52 +0200 Subject: [PATCH 19/38] Delete duplicated numpy rule --- requirements/install | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/requirements/install b/requirements/install index fd67bba9cf..e6ebb4e021 100644 --- a/requirements/install +++ b/requirements/install @@ -9,17 +9,12 @@ # If a dependency seems redundant, please propose a pull request to remove it. # # General Rule -# ======== +# ============ # # * Cap to latest major. # -# NumPy Rule -# ======== -# -# * Cap to latest four minors (document why). -# # Exceptions -# ======== +# ========== # # * Avoiding a version (document why). # * Pinning a version (document why). From d865888368d88f394be3a142f3a22fcdd368b8d1 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Wed, 8 Sep 2021 00:27:53 +0200 Subject: [PATCH 20/38] Add link to issue in dpath --- requirements/install | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/install b/requirements/install index e6ebb4e021..ab9a60af9b 100644 --- a/requirements/install +++ b/requirements/install @@ -24,8 +24,8 @@ # We support from 1.3.2 on because it is the first version published # following the semantic versioning specification. # -# We do not support 2 yet because of a bug introduced to period parsing in -# OpenFisca's Web API. +# We do not support 2 yet because of a [bug introduced to period parsing in +# OpenFisca's Web API: https://github.com/openfisca/openfisca-core/pull/948 dpath >= 1.3.2, < 2 # For evaluating numerical expressions From 9bec4dec6fd46c08d08efaef1cb9c94514597ebb Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Wed, 8 Sep 2021 00:31:33 +0200 Subject: [PATCH 21/38] Improve message in tracker deps --- requirements/tracker | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/requirements/tracker b/requirements/tracker index 6ab42528d2..7dd90ccee8 100644 --- a/requirements/tracker +++ b/requirements/tracker @@ -1,7 +1,5 @@ -# Please make sure to pin all dependency versions, in order to avoid unwanted -# functional and integration breaks caused by external code updates. -# -# These are web-api tracking dependencies, so pin them. +# Dependencies for tracking are optional, so we always want the latest +# versions. # For sending usage statistics to the Matomo third-party service. openfisca-tracker <= 0.4.0 From 65855e82513e9cadbf7b7c43da44cbb2f5454951 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Wed, 8 Sep 2021 00:41:27 +0200 Subject: [PATCH 22/38] Improve wording in web-api deps --- requirements/web-api | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/requirements/web-api b/requirements/web-api index 614541cefe..ce7a2cc431 100644 --- a/requirements/web-api +++ b/requirements/web-api @@ -1,7 +1,10 @@ -# Please make sure to pin all dependency versions, in order to avoid unwanted -# functional and integration breaks caused by external code updates. +# These are dependencies to serve the Web-API, so we always want to support +# the latest versions. # -# These are web-api dependencies, so pin them. +# As a safety measure, compatibility could be smoke-tested automatically: +# https://github.com/openfisca/country-template/pull/113 +# +# In the meantime, we pin these dependencies. # For OpenFisca's Web API. flask == 1.1.2 From 4fbf18a9f2a5bb397291e345b0fc280acd486604 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Wed, 8 Sep 2021 22:00:28 +0200 Subject: [PATCH 23/38] Remove legacy resolver --- requirements/dev | 15 ++++++++++++--- requirements/tracker | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/requirements/dev b/requirements/dev index 8c45496681..173777abcc 100644 --- a/requirements/dev +++ b/requirements/dev @@ -1,7 +1,16 @@ # Please make sure to pin all dependency versions, in order to avoid unwanted # functional and integration breaks caused by external code updates. # -# These are dev dependencies, so pin them. +# General Rule +# ============ +# +# * Pin them. +# +# Exceptions +# ========== +# +# * openfisca-country-template should not be constrained (circular dep). +# * openfisca-extension-template should not be constrained (circular dep). # For automatic style formatting. autopep8 == 1.5.7 @@ -19,8 +28,8 @@ pytest-cov == 2.11.1 mypy == 0.812 # For testing: parameters, variables, etc. -openfisca-country-template == 3.12.5 -openfisca-extension-template == 1.3.6 +openfisca-country-template +openfisca-extension-template # Include Web API dependencies for development -r web-api diff --git a/requirements/tracker b/requirements/tracker index 7dd90ccee8..8b9f13a850 100644 --- a/requirements/tracker +++ b/requirements/tracker @@ -2,4 +2,4 @@ # versions. # For sending usage statistics to the Matomo third-party service. -openfisca-tracker <= 0.4.0 +openfisca-tracker From d0bd5110b0e0bef6e7132cb36015f427573b606f Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Wed, 8 Sep 2021 22:14:42 +0200 Subject: [PATCH 24/38] Update README.md --- README.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8474d72150..5eb4ce4f40 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,8 @@ cd openfisca-core python3 -m venv .venv source .venv/bin/activate pip install -U pip -pip install --editable .[dev] --use-deprecated=legacy-resolver +pip install --requirement requirements/dev --upgrade +pip install --editable . --upgrade --no-dependencies ``` ## Testing @@ -194,9 +195,16 @@ The OpenFisca Web API comes with an [optional tracker](https://github.com/openfi The tracker is not installed by default. To install it, run: ```sh -pip install openfisca_core[tracker] --use-deprecated=legacy-resolver # Or `pip install --editable ".[tracker]"` for an editable installation +pip install openfisca_core[tracker] ``` +Or for an editable installation: + +``` +pip install --requirement requirements/tracker --upgrade +pip install --requirement requirements/dev --upgrade +pip install --editable . --upgrade --no-dependencies +``` #### Tracker configuration From 0818f5782d39e349dfbc06d873543c2a82f6db16 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Mon, 25 Oct 2021 23:19:30 +0200 Subject: [PATCH 25/38] Apply suggestions from code review Co-authored-by: Matti Schneider --- requirements/install | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/install b/requirements/install index ab9a60af9b..49aabadba6 100644 --- a/requirements/install +++ b/requirements/install @@ -24,7 +24,7 @@ # We support from 1.3.2 on because it is the first version published # following the semantic versioning specification. # -# We do not support 2 yet because of a [bug introduced to period parsing in +# We do not support 2 yet because of a bug introduced to period parsing in # OpenFisca's Web API: https://github.com/openfisca/openfisca-core/pull/948 dpath >= 1.3.2, < 2 @@ -50,7 +50,7 @@ psutil >= 5.4.7, < 6 # For openfisca test # # We support pytest >= 5.4.2 because `openfisca test` relies on the signature of -# `pytest.Item.from_module()` introduced since this 5.4.2. +# `pytest.Item.from_module()` introduced since 5.4.2. # # We do not support 6 yet because it requires fixing some tests before. # See: https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent From e346c09c85fcacec8a7f7463d2f5372a0e96834e Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Tue, 26 Oct 2021 13:56:27 +0200 Subject: [PATCH 26/38] Rename constraints => compatibility --- requirements/{constraints => compatibility} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename requirements/{constraints => compatibility} (64%) diff --git a/requirements/constraints b/requirements/compatibility similarity index 64% rename from requirements/constraints rename to requirements/compatibility index 4deb711fae..3dcf417557 100644 --- a/requirements/constraints +++ b/requirements/compatibility @@ -1,4 +1,4 @@ -# These are constraint versions to ensure the integrity of distributions. +# These are constraint versions to ensure the compatibility of distributions. # # Normally, we want to add here the pinned lower-bound supported versions of # dependencies critical to OpenFisca's usability. From e5b803eabf521b6259bcf1748e6997c0ec500be0 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Tue, 26 Oct 2021 14:21:14 +0200 Subject: [PATCH 27/38] Update dependencies --- requirements/dev | 13 ++++++++----- requirements/install | 6 ++++++ requirements/tracker | 4 +++- requirements/web-api | 17 ++++++++--------- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/requirements/dev b/requirements/dev index 173777abcc..5c367e8ae0 100644 --- a/requirements/dev +++ b/requirements/dev @@ -13,19 +13,22 @@ # * openfisca-extension-template should not be constrained (circular dep). # For automatic style formatting. -autopep8 == 1.5.7 +autopep8 == 1.6.0 # For style & code checking. -flake8 == 3.9.1 -flake8-bugbear == 21.4.3 +darglint == 1.8.0 +flake8 == 4.0.1 +flake8-bugbear == 21.9.2 +flake8-docstrings == 1.6.0 flake8-print == 4.0.0 flake8-rst-docstrings == 0.2.3 +pylint == 2.11.1 # For PyTest test coverage integration. -pytest-cov == 2.11.1 +pytest-cov == 3.0.0 # For optional duck & static type checking. -mypy == 0.812 +mypy == 0.910 # For testing: parameters, variables, etc. openfisca-country-template diff --git a/requirements/install b/requirements/install index 49aabadba6..cd4a27a25a 100644 --- a/requirements/install +++ b/requirements/install @@ -28,6 +28,9 @@ # OpenFisca's Web API: https://github.com/openfisca/openfisca-core/pull/948 dpath >= 1.3.2, < 2 +# For Numpy type-hints. +nptyping >= 1, < 2 + # For evaluating numerical expressions # # We support numexpr >= 2.7.1 because it is the first version compatible with @@ -61,3 +64,6 @@ PyYAML >= 5.1, < 6 # For sorting formulas by period. sortedcontainers >= 2, < 3 + +# For typing backports. +typing-extensions >= 3, < 4 diff --git a/requirements/tracker b/requirements/tracker index 8b9f13a850..c0c8ce0599 100644 --- a/requirements/tracker +++ b/requirements/tracker @@ -2,4 +2,6 @@ # versions. # For sending usage statistics to the Matomo third-party service. -openfisca-tracker +# +# We start from the currently supported version forward. +openfisca-tracker >= 0.4.0 diff --git a/requirements/web-api b/requirements/web-api index ce7a2cc431..5034026bcd 100644 --- a/requirements/web-api +++ b/requirements/web-api @@ -1,15 +1,14 @@ # These are dependencies to serve the Web-API, so we always want to support # the latest versions. -# -# As a safety measure, compatibility could be smoke-tested automatically: -# https://github.com/openfisca/country-template/pull/113 -# -# In the meantime, we pin these dependencies. # For OpenFisca's Web API. -flask == 1.1.2 -gunicorn == 20.1.0 +# +# We start from the currently supported versions forward. +flask >= 1.1.2 +gunicorn >= 20.1.0 # For OpenFisca's Web API users requiring CORS. -flask-cors == 3.0.10 -werkzeug == 1.0.1 +# +# We start from the currently supported versions forward. +flask-cors >= 3.0.10 +werkzeug >= 1.0.1 From c0a654e0364b7e1bd3b5ad698398f2012e65b8ff Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Tue, 26 Oct 2021 16:24:29 +0200 Subject: [PATCH 28/38] Use make as task-manager --- .circleci/config.yml | 93 ++++++++++++++--------------- .circleci/publish-python-package.sh | 4 -- MANIFEST.in | 3 +- openfisca_tasks/install.mk | 33 +++++++++- openfisca_tasks/publish.mk | 43 ++++++++++++- requirements/common | 8 +++ requirements/publication | 2 +- setup.py | 5 +- 8 files changed, 130 insertions(+), 61 deletions(-) delete mode 100755 .circleci/publish-python-package.sh create mode 100644 requirements/common diff --git a/.circleci/config.yml b/.circleci/config.yml index 50675206dd..e581c5555e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -38,8 +38,8 @@ jobs: - run: name: Install dependencies command: | - pip install --upgrade pip - pip install --requirement requirements/dev + make install-deps + make install-dev # pip install --editable git+https://github.com/openfisca/country-template.git@BRANCH_NAME#egg=OpenFisca-Country-Template # use a specific branch of OpenFisca-Country-Template # pip install --editable git+https://github.com/openfisca/extension-template.git@BRANCH_NAME#egg=OpenFisca-Extension-Template # use a specific branch of OpenFisca-Extension-Template @@ -51,9 +51,8 @@ jobs: test: docker: - image: python:3.7 - environment: - PYTEST_ADDOPTS: --exitfirst + TERM: xterm-256color # To colorize output of make tasks. steps: - checkout @@ -67,7 +66,7 @@ jobs: - run: name: Install core - command: pip install --editable . --upgrade --no-dependencies + command: make install-core - run: name: Run linters @@ -94,33 +93,6 @@ jobs: paths: - .coverage - test_compatibility: - docker: - - image: python:3.7 - - environment: - PYTEST_ADDOPTS: --exitfirst - - steps: - - checkout - - - restore_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - - - run: - name: Activate virtualenv - command: echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV - - - run: - name: Install core with a constrained Numpy version - command: | - pip install --requirement requirements/dev --upgrade --constraint requirements/compatibility - pip install --editable . --upgrade --no-dependencies - - - run: - name: Run core tests - command: make test - test_docs: docker: - image: python:3.7 @@ -145,7 +117,10 @@ jobs: command: | mkdir -p /tmp/venv/openfisca_doc python -m venv /tmp/venv/openfisca_doc - echo "source /tmp/venv/openfisca_doc/bin/activate" >> $BASH_ENV + + - run: + name: Activate virtualenv + command: echo "source /tmp/venv/openfisca_doc/bin/activate" >> $BASH_ENV - run: name: Install dependencies @@ -160,6 +135,34 @@ jobs: name: Run doc tests command: make test-doc-build + + test_compatibility: + docker: + - image: python:3.7 + + environment: + PYTEST_ADDOPTS: --exitfirst + + steps: + - checkout + + - restore_cache: + key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} + + - run: + name: Activate virtualenv + command: echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV + + - run: + name: Install core with a constrained Numpy version + command: | + make install-core + make install-compat + + - run: + name: Run core tests + command: make test + submit_coverage: docker: - image: python:3.7 @@ -180,16 +183,10 @@ jobs: - run: name: Submit coverage to Coveralls command: | - pip install --requirement requirements/coverage --upgrade - pip install --requirement requirements/dev --upgrade - pip install --editable . --upgrade --no-dependencies + make install-core + make install-cov coveralls - - save_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - paths: - - /tmp/venv/openfisca_core - deploy: docker: - image: python:3.7 @@ -215,9 +212,8 @@ jobs: - run: name: Upload a Python package to PyPi command: | - pip install --requirement requirements/publication --upgrade - pip install --editable . --upgrade --no-dependencies - .circleci/publish-python-package.sh + make build + make publish - run: name: Publish a git tag @@ -237,22 +233,23 @@ workflows: - test: requires: - build - - test_compatibility: - requires: - - build - test_docs: requires: - build + - test_compatibility: + requires: + - test_docs - submit_coverage: requires: - test + - test_docs - test_compatibility - deploy: requires: - check_version - test - - test_compatibility - test_docs + - test_compatibility filters: branches: only: master diff --git a/.circleci/publish-python-package.sh b/.circleci/publish-python-package.sh deleted file mode 100755 index 8d331bd946..0000000000 --- a/.circleci/publish-python-package.sh +++ /dev/null @@ -1,4 +0,0 @@ -#! /usr/bin/env bash - -python setup.py bdist_wheel # build this package in the dist directory -twine upload dist/* --username $PYPI_USERNAME --password $PYPI_PASSWORD # publish diff --git a/MANIFEST.in b/MANIFEST.in index 166788d7fa..507d218461 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,3 @@ -recursive-include openfisca_core/scripts * +graft requirements include openfisca_web_api/openAPI.yml +recursive-include openfisca_core/scripts * diff --git a/openfisca_tasks/install.mk b/openfisca_tasks/install.mk index 7918c76b14..f7d1f8074b 100644 --- a/openfisca_tasks/install.mk +++ b/openfisca_tasks/install.mk @@ -1,9 +1,36 @@ ## Install project dependencies. install: + @${MAKE} install-deps + @${MAKE} install-dev + @${MAKE} install-core + @$(call print_pass,$@:) + +## Install common dependencies. +install-deps: + @$(call print_help,$@:) + @pip install --quiet --upgrade --constraint requirements/common pip setuptools + +## Install development dependencies. +install-dev: + @$(call print_help,$@:) + @pip install --quiet --upgrade --requirement requirements/install + @pip install --quiet --upgrade --requirement requirements/dev + +## Install package. +install-core: + @$(call print_help,$@:) + @pip uninstall --quiet --yes openfisca-core + @pip install --quiet --no-dependencies --editable . + +## Install lower-bound dependencies for compatibility check. +install-compat: + @$(call print_help,$@:) + @pip install --quiet --upgrade --constraint requirements/compatibility numpy + +## Install coverage dependencies. +install-cov: @$(call print_help,$@:) - @pip install --upgrade pip setuptools - @pip install --requirement requirements/dev --upgrade - @pip install --editable . --upgrade --no-dependencies + @pip install --quiet --upgrade --constraint requirements/coverage coveralls ## Uninstall project dependencies. uninstall: diff --git a/openfisca_tasks/publish.mk b/openfisca_tasks/publish.mk index 2bcd2c0ba7..09686a6274 100644 --- a/openfisca_tasks/publish.mk +++ b/openfisca_tasks/publish.mk @@ -1,8 +1,45 @@ -## Install openfisca-core for deployment and publishing. +.PHONY: build + +## Build openfisca-core for deployment and publishing. build: @## This allows us to be sure tests are run against the packaged version @## of openfisca-core, the same we put in the hands of users and reusers. @$(call print_help,$@:) - @python setup.py bdist_wheel - @find dist -name "*.whl" -exec pip install --force-reinstall {}[dev] \; + @${MAKE} install-deps + @${MAKE} build-deps + @${MAKE} build-build + @${MAKE} build-install @$(call print_pass,$@:) + +## Install building dependencies. +build-deps: + @$(call print_help,$@:) + @pip install --quiet --upgrade --constraint requirements/publication build + +## Build the package. +build-build: + @$(call print_help,$@:) + @python -m build + +## Install the built package. +build-install: + @$(call print_help,$@:) + @pip uninstall --quiet --yes openfisca-core + @find dist -name "*.whl" -exec pip install --quiet --no-dependencies {} \; + +## Publish package. +publish: + @$(call print_help,$@:) + @${MAKE} publish-deps + @${MAKE} publish-upload + @$(call print_pass,$@:) + +## Install required publishing dependencies. +publish-deps: + @$(call print_help,$@:) + @pip install --quiet --upgrade --constraint requirements/publication twine + +## Upload package to PyPi. +publish-upload: + @$(call print_help,$@:) + twine upload dist/* --username $${PYPI_USERNAME} --password $${PYPI_PASSWORD} diff --git a/requirements/common b/requirements/common new file mode 100644 index 0000000000..5599db953a --- /dev/null +++ b/requirements/common @@ -0,0 +1,8 @@ +# These are dependencies to build the library, so we always want the latest +# versions. + +# For managing dependencies. +pip + +# For packaging the package. +setuptools diff --git a/requirements/publication b/requirements/publication index 85d4ccd2ea..1711493b6d 100644 --- a/requirements/publication +++ b/requirements/publication @@ -5,4 +5,4 @@ twine # For building the package. -wheel +build diff --git a/setup.py b/setup.py index 786a28d7d1..5975877b72 100644 --- a/setup.py +++ b/setup.py @@ -5,13 +5,15 @@ from typing import List import re +from pathlib import Path from setuptools import setup, find_packages def load_requirements_from_file(filename: str) -> List[str]: """Allows for composable requirement files with the `-r filename` flag.""" - reqs = open(f"requirements/{filename}").readlines() + file = Path(f"./requirements/{filename}").resolve() + reqs = open(file).readlines() pattern = re.compile(r"^\s*-r\s*(?P.*)$") for req in reqs: @@ -56,6 +58,7 @@ def load_requirements_from_file(filename: str) -> List[str]: python_requires = ">= 3.7", install_requires = load_requirements_from_file("install"), extras_require = { + "common": load_requirements_from_file("common"), "coverage": load_requirements_from_file("coverage"), "dev": load_requirements_from_file("dev"), "publication": load_requirements_from_file("publication"), From 5801347244af2ade872b11974a508c9406bc6cca Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Tue, 26 Oct 2021 16:27:50 +0200 Subject: [PATCH 29/38] Update README.md --- README.md | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 5eb4ce4f40..5f060caf23 100644 --- a/README.md +++ b/README.md @@ -30,9 +30,7 @@ git clone https://github.com/openfisca/openfisca-core.git cd openfisca-core python3 -m venv .venv source .venv/bin/activate -pip install -U pip -pip install --requirement requirements/dev --upgrade -pip install --editable . --upgrade --no-dependencies +make install ``` ## Testing @@ -46,10 +44,10 @@ make test To run all the tests defined on a test file: ```sh -pytest tests/core/test_parameters.py +openfisca test tests/core/test_parameters.py ``` -To run a single test: +You can also use `pytest`, for example to run a single test: ```sh pytest tests/core/test_parameters.py -k test_parameter_for_period From 65c46213783c8d2531d2dee1274ff89a41dd742b Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Tue, 26 Oct 2021 16:34:51 +0200 Subject: [PATCH 30/38] Expire deprecations --- openfisca_core/commons/__init__.py | 13 ++----------- openfisca_core/errors/__init__.py | 6 +++--- openfisca_core/formula_helpers.py | 9 --------- openfisca_core/memory_config.py | 9 --------- openfisca_core/parameters/__init__.py | 7 ++----- openfisca_core/rates.py | 9 --------- openfisca_core/simulation_builder.py | 16 ---------------- openfisca_core/simulations/__init__.py | 2 -- openfisca_core/taxbenefitsystems/__init__.py | 2 -- openfisca_core/taxscales/__init__.py | 2 -- openfisca_core/variables/__init__.py | 1 - openfisca_core/variables/variable.py | 2 +- 12 files changed, 8 insertions(+), 70 deletions(-) delete mode 100644 openfisca_core/formula_helpers.py delete mode 100644 openfisca_core/memory_config.py delete mode 100644 openfisca_core/rates.py delete mode 100644 openfisca_core/simulation_builder.py diff --git a/openfisca_core/commons/__init__.py b/openfisca_core/commons/__init__.py index b3b5d8cbb2..c2927dea22 100644 --- a/openfisca_core/commons/__init__.py +++ b/openfisca_core/commons/__init__.py @@ -12,12 +12,9 @@ * :func:`.stringify_array` * :func:`.switch` -Deprecated: - * :class:`.Dummy` - Note: - The ``deprecated`` imports are transitional, in order to ensure non-breaking - changes, and could be removed from the codebase in the next + The ``deprecated`` imports are transitional, in order to ensure + non-breaking changes, and could be removed from the codebase in the next major release. Note: @@ -59,9 +56,3 @@ __all__ = ["apply_thresholds", "concat", "switch"] __all__ = ["empty_clone", "stringify_array", *__all__] __all__ = ["average_rate", "marginal_rate", *__all__] - -# Deprecated - -from .dummy import Dummy # noqa: F401 - -__all__ = ["Dummy", *__all__] diff --git a/openfisca_core/errors/__init__.py b/openfisca_core/errors/__init__.py index ccd19af9b2..e5b9abbc78 100644 --- a/openfisca_core/errors/__init__.py +++ b/openfisca_core/errors/__init__.py @@ -24,10 +24,10 @@ from .cycle_error import CycleError # noqa: F401 from .empty_argument_error import EmptyArgumentError # noqa: F401 from .nan_creation_error import NaNCreationError # noqa: F401 -from .parameter_not_found_error import ParameterNotFoundError, ParameterNotFoundError as ParameterNotFound # noqa: F401 +from .parameter_not_found_error import ParameterNotFoundError # noqa: F401 from .parameter_parsing_error import ParameterParsingError # noqa: F401 from .period_mismatch_error import PeriodMismatchError # noqa: F401 from .situation_parsing_error import SituationParsingError # noqa: F401 from .spiral_error import SpiralError # noqa: F401 -from .variable_name_config_error import VariableNameConflictError, VariableNameConflictError as VariableNameConflict # noqa: F401 -from .variable_not_found_error import VariableNotFoundError, VariableNotFoundError as VariableNotFound # noqa: F401 +from .variable_name_config_error import VariableNameConflictError # noqa: F401 +from .variable_not_found_error import VariableNotFoundError # noqa: F401 diff --git a/openfisca_core/formula_helpers.py b/openfisca_core/formula_helpers.py deleted file mode 100644 index e0c755348e..0000000000 --- a/openfisca_core/formula_helpers.py +++ /dev/null @@ -1,9 +0,0 @@ -# The formula_helpers module has been deprecated since X.X.X, -# and will be removed in the future. -# -# The helpers have been moved to the commons module. -# -# The following are transitional imports to ensure non-breaking changes. -# Could be deprecated in the next major release. - -from openfisca_core.commons import apply_thresholds, concat, switch # noqa: F401 diff --git a/openfisca_core/memory_config.py b/openfisca_core/memory_config.py deleted file mode 100644 index 18c4cebcdc..0000000000 --- a/openfisca_core/memory_config.py +++ /dev/null @@ -1,9 +0,0 @@ -# The memory config module has been deprecated since X.X.X, -# and will be removed in the future. -# -# Module's contents have been moved to the experimental module. -# -# The following are transitional imports to ensure non-breaking changes. -# Could be deprecated in the next major release. - -from openfisca_core.experimental import MemoryConfig # noqa: F401 diff --git a/openfisca_core/parameters/__init__.py b/openfisca_core/parameters/__init__.py index 040ae47056..bbf5a0595f 100644 --- a/openfisca_core/parameters/__init__.py +++ b/openfisca_core/parameters/__init__.py @@ -21,9 +21,6 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from openfisca_core.errors import ParameterNotFound, ParameterParsingError # noqa: F401 - - from .config import ( # noqa: F401 ALLOWED_PARAM_TYPES, COMMON_KEYS, @@ -39,6 +36,6 @@ from .vectorial_parameter_node_at_instant import VectorialParameterNodeAtInstant # noqa: F401 from .parameter import Parameter # noqa: F401 from .parameter_node import ParameterNode # noqa: F401 -from .parameter_scale import ParameterScale, ParameterScale as Scale # noqa: F401 -from .parameter_scale_bracket import ParameterScaleBracket, ParameterScaleBracket as Bracket # noqa: F401 +from .parameter_scale import ParameterScale # noqa: F401 +from .parameter_scale_bracket import ParameterScaleBracket # noqa: F401 from .values_history import ValuesHistory # noqa: F401 diff --git a/openfisca_core/rates.py b/openfisca_core/rates.py deleted file mode 100644 index 9dfbbefcf0..0000000000 --- a/openfisca_core/rates.py +++ /dev/null @@ -1,9 +0,0 @@ -# The formula_helpers module has been deprecated since X.X.X, -# and will be removed in the future. -# -# The helpers have been moved to the commons module. -# -# The following are transitional imports to ensure non-breaking changes. -# Could be deprecated in the next major release. - -from openfisca_core.commons import average_rate, marginal_rate # noqa: F401 diff --git a/openfisca_core/simulation_builder.py b/openfisca_core/simulation_builder.py deleted file mode 100644 index 57c7765ebe..0000000000 --- a/openfisca_core/simulation_builder.py +++ /dev/null @@ -1,16 +0,0 @@ -# The simulation builder module has been deprecated since X.X.X, -# and will be removed in the future. -# -# Module's contents have been moved to the simulation module. -# -# The following are transitional imports to ensure non-breaking changes. -# Could be deprecated in the next major release. - -from openfisca_core.simulations import ( # noqa: F401 - Simulation, - SimulationBuilder, - calculate_output_add, - calculate_output_divide, - check_type, - transform_to_strict_syntax, - ) diff --git a/openfisca_core/simulations/__init__.py b/openfisca_core/simulations/__init__.py index 5b02dc1a22..2f7a9c6d51 100644 --- a/openfisca_core/simulations/__init__.py +++ b/openfisca_core/simulations/__init__.py @@ -21,8 +21,6 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from openfisca_core.errors import CycleError, NaNCreationError, SpiralError # noqa: F401 - from .helpers import calculate_output_add, calculate_output_divide, check_type, transform_to_strict_syntax # noqa: F401 from .simulation import Simulation # noqa: F401 from .simulation_builder import SimulationBuilder # noqa: F401 diff --git a/openfisca_core/taxbenefitsystems/__init__.py b/openfisca_core/taxbenefitsystems/__init__.py index bf5f224c2c..05a2deb36b 100644 --- a/openfisca_core/taxbenefitsystems/__init__.py +++ b/openfisca_core/taxbenefitsystems/__init__.py @@ -21,6 +21,4 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from openfisca_core.errors import VariableNameConflict, VariableNotFound # noqa: F401 - from .tax_benefit_system import TaxBenefitSystem # noqa: F401 diff --git a/openfisca_core/taxscales/__init__.py b/openfisca_core/taxscales/__init__.py index 0364101d71..0e074b2e6e 100644 --- a/openfisca_core/taxscales/__init__.py +++ b/openfisca_core/taxscales/__init__.py @@ -21,8 +21,6 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from openfisca_core.errors import EmptyArgumentError # noqa: F401 - from .helpers import combine_tax_scales # noqa: F401 from .tax_scale_like import TaxScaleLike # noqa: F401 from .rate_tax_scale_like import RateTaxScaleLike # noqa: F401 diff --git a/openfisca_core/variables/__init__.py b/openfisca_core/variables/__init__.py index fb36963f7d..3decaf8f42 100644 --- a/openfisca_core/variables/__init__.py +++ b/openfisca_core/variables/__init__.py @@ -24,4 +24,3 @@ from .config import VALUE_TYPES, FORMULA_NAME_PREFIX # noqa: F401 from .helpers import get_annualized_variable, get_neutralized_variable # noqa: F401 from .variable import Variable # noqa: F401 -from .typing import Formula # noqa: F401 diff --git a/openfisca_core/variables/variable.py b/openfisca_core/variables/variable.py index 61a5d9274f..acfeb9fe70 100644 --- a/openfisca_core/variables/variable.py +++ b/openfisca_core/variables/variable.py @@ -309,7 +309,7 @@ def get_formula(self, period = None): If no period is given and the variable has several formula, return the oldest formula. :returns: Formula used to compute the variable - :rtype: .Formula + :rtype: callable """ From 6872ae788b54b4ce2d231718c2d71b54f956a865 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Tue, 26 Oct 2021 16:54:26 +0200 Subject: [PATCH 31/38] Fix failing tests after expire --- openfisca_core/model_api.py | 8 ++++---- openfisca_core/tools/test_runner.py | 10 +++++----- openfisca_web_api/handlers.py | 2 +- openfisca_web_api/loader/parameters.py | 4 ++-- .../parameter_validation/test_parameter_validation.py | 5 ++++- .../parameters_fancy_indexing/test_fancy_indexing.py | 7 ++++--- .../tax_scales/test_linear_average_rate_tax_scale.py | 5 +++-- .../core/tax_scales/test_marginal_amount_tax_scale.py | 2 +- tests/core/tax_scales/test_marginal_rate_tax_scale.py | 5 +++-- tests/core/tax_scales/test_single_amount_tax_scale.py | 6 +++--- tests/core/test_holders.py | 2 +- tests/core/test_parameters.py | 7 ++++--- tests/core/test_tracers.py | 3 ++- tests/core/tools/test_runner/test_yaml_runner.py | 4 ++-- tests/core/variables/test_variables.py | 2 +- tests/web_api/loader/test_parameters.py | 10 +++++----- 16 files changed, 45 insertions(+), 37 deletions(-) diff --git a/openfisca_core/model_api.py b/openfisca_core/model_api.py index 8ccf5c2763..3140c04d69 100644 --- a/openfisca_core/model_api.py +++ b/openfisca_core/model_api.py @@ -19,12 +19,12 @@ from openfisca_core.indexed_enums import Enum # noqa: F401 from openfisca_core.parameters import ( # noqa: F401 - load_parameter_file, - ParameterNode, - Scale, - Bracket, Parameter, + ParameterNode, + ParameterScale, + ParameterScaleBracket, ValuesHistory, + load_parameter_file, ) from openfisca_core.periods import DAY, MONTH, YEAR, ETERNITY, period # noqa: F401 diff --git a/openfisca_core/tools/test_runner.py b/openfisca_core/tools/test_runner.py index 1c37ea1469..286ff06991 100644 --- a/openfisca_core/tools/test_runner.py +++ b/openfisca_core/tools/test_runner.py @@ -10,8 +10,8 @@ import pytest from openfisca_core.tools import assert_near -from openfisca_core.simulation_builder import SimulationBuilder -from openfisca_core.errors import SituationParsingError, VariableNotFound +from openfisca_core.simulations import SimulationBuilder +from openfisca_core.errors import SituationParsingError, VariableNotFoundError from openfisca_core.warnings import LibYAMLWarning @@ -150,7 +150,7 @@ def runtest(self): try: builder.set_default_period(period) self.simulation = builder.build_from_dict(self.tax_benefit_system, input) - except (VariableNotFound, SituationParsingError): + except (VariableNotFoundError, SituationParsingError): raise except Exception as e: error_message = os.linesep.join([str(e), '', f"Unexpected error raised while parsing '{self.fspath}'"]) @@ -200,7 +200,7 @@ def check_output(self): entity_index = population.get_index(instance_id) self.check_variable(variable_name, value, self.test.get('period'), entity_index) else: - raise VariableNotFound(key, self.tax_benefit_system) + raise VariableNotFoundError(key, self.tax_benefit_system) def check_variable(self, variable_name, expected_value, period, entity_index = None): if self.should_ignore_variable(variable_name): @@ -231,7 +231,7 @@ def should_ignore_variable(self, variable_name): return variable_ignored or variable_not_tested def repr_failure(self, excinfo): - if not isinstance(excinfo.value, (AssertionError, VariableNotFound, SituationParsingError)): + if not isinstance(excinfo.value, (AssertionError, VariableNotFoundError, SituationParsingError)): return super(YamlItem, self).repr_failure(excinfo) message = excinfo.value.args[0] diff --git a/openfisca_web_api/handlers.py b/openfisca_web_api/handlers.py index 9c8826772c..1a6ace07db 100644 --- a/openfisca_web_api/handlers.py +++ b/openfisca_web_api/handlers.py @@ -2,7 +2,7 @@ import dpath -from openfisca_core.simulation_builder import SimulationBuilder +from openfisca_core.simulations import SimulationBuilder from openfisca_core.indexed_enums import Enum diff --git a/openfisca_web_api/loader/parameters.py b/openfisca_web_api/loader/parameters.py index 23a5f738b5..39534f972e 100644 --- a/openfisca_web_api/loader/parameters.py +++ b/openfisca_web_api/loader/parameters.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from openfisca_core.parameters import Parameter, ParameterNode, Scale +from openfisca_core.parameters import Parameter, ParameterNode, ParameterScale def build_api_values_history(values_history): @@ -77,7 +77,7 @@ def build_api_parameter(parameter, country_package_metadata): if parameter.documentation: api_parameter['documentation'] = parameter.documentation.strip() api_parameter['values'] = build_api_values_history(parameter) - elif isinstance(parameter, Scale): + elif isinstance(parameter, ParameterScale): if 'rate' in parameter.brackets[0].children: api_parameter['brackets'] = build_api_scale(parameter, 'rate') elif 'amount' in parameter.brackets[0].children: diff --git a/tests/core/parameter_validation/test_parameter_validation.py b/tests/core/parameter_validation/test_parameter_validation.py index 62b2b0c132..561fb28cb1 100644 --- a/tests/core/parameter_validation/test_parameter_validation.py +++ b/tests/core/parameter_validation/test_parameter_validation.py @@ -1,8 +1,11 @@ # -*- coding: utf-8 -*- import os + import pytest -from openfisca_core.parameters import load_parameter_file, ParameterNode, ParameterParsingError + +from openfisca_core.errors import ParameterParsingError +from openfisca_core.parameters import load_parameter_file, ParameterNode BASE_DIR = os.path.dirname(os.path.abspath(__file__)) year = 2016 diff --git a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py index d34eb00773..41f42fad88 100644 --- a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py +++ b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py @@ -7,9 +7,10 @@ import pytest -from openfisca_core.tools import assert_near -from openfisca_core.parameters import ParameterNode, Parameter, ParameterNotFound +from openfisca_core.errors import ParameterNotFoundError from openfisca_core.model_api import * # noqa +from openfisca_core.parameters import ParameterNode, Parameter +from openfisca_core.tools import assert_near LOCAL_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -59,7 +60,7 @@ def test_triple_fancy_indexing(): def test_wrong_key(): zone = np.asarray(['z1', 'z2', 'z2', 'toto']) - with pytest.raises(ParameterNotFound) as e: + with pytest.raises(ParameterNotFoundError) as e: P.single.owner[zone] assert "'rate.single.owner.toto' was not found" in get_message(e.value) diff --git a/tests/core/tax_scales/test_linear_average_rate_tax_scale.py b/tests/core/tax_scales/test_linear_average_rate_tax_scale.py index 83153024c7..74b2762963 100644 --- a/tests/core/tax_scales/test_linear_average_rate_tax_scale.py +++ b/tests/core/tax_scales/test_linear_average_rate_tax_scale.py @@ -2,6 +2,7 @@ from openfisca_core import taxscales from openfisca_core import tools +from openfisca_core.errors import EmptyArgumentError import pytest @@ -49,7 +50,7 @@ def test_bracket_indices_without_tax_base(): tax_scale.add_bracket(2, 0) tax_scale.add_bracket(4, 0) - with pytest.raises(taxscales.EmptyArgumentError): + with pytest.raises(EmptyArgumentError): tax_scale.bracket_indices(tax_base) @@ -57,7 +58,7 @@ def test_bracket_indices_without_brackets(): tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() - with pytest.raises(taxscales.EmptyArgumentError): + with pytest.raises(EmptyArgumentError): tax_scale.bracket_indices(tax_base) diff --git a/tests/core/tax_scales/test_marginal_amount_tax_scale.py b/tests/core/tax_scales/test_marginal_amount_tax_scale.py index 7582d725b4..cdd7cc4f27 100644 --- a/tests/core/tax_scales/test_marginal_amount_tax_scale.py +++ b/tests/core/tax_scales/test_marginal_amount_tax_scale.py @@ -35,7 +35,7 @@ def test_calc(): # TODO: move, as we're testing Scale, not MarginalAmountTaxScale def test_dispatch_scale_type_on_creation(data): - scale = parameters.Scale("amount_scale", data, "") + scale = parameters.ParameterScale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) result = scale.get_at_instant(first_jan) diff --git a/tests/core/tax_scales/test_marginal_rate_tax_scale.py b/tests/core/tax_scales/test_marginal_rate_tax_scale.py index 1688e7e3cc..505d103348 100644 --- a/tests/core/tax_scales/test_marginal_rate_tax_scale.py +++ b/tests/core/tax_scales/test_marginal_rate_tax_scale.py @@ -2,6 +2,7 @@ from openfisca_core import taxscales from openfisca_core import tools +from openfisca_core.errors import EmptyArgumentError import pytest @@ -49,7 +50,7 @@ def test_bracket_indices_without_tax_base(): tax_scale.add_bracket(2, 0) tax_scale.add_bracket(4, 0) - with pytest.raises(taxscales.EmptyArgumentError): + with pytest.raises(EmptyArgumentError): tax_scale.bracket_indices(tax_base) @@ -57,7 +58,7 @@ def test_bracket_indices_without_brackets(): tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() - with pytest.raises(taxscales.EmptyArgumentError): + with pytest.raises(EmptyArgumentError): tax_scale.bracket_indices(tax_base) diff --git a/tests/core/tax_scales/test_single_amount_tax_scale.py b/tests/core/tax_scales/test_single_amount_tax_scale.py index c5e6483a7d..0eb63c1f26 100644 --- a/tests/core/tax_scales/test_single_amount_tax_scale.py +++ b/tests/core/tax_scales/test_single_amount_tax_scale.py @@ -49,7 +49,7 @@ def test_to_dict(): # TODO: move, as we're testing Scale, not SingleAmountTaxScale def test_assign_thresholds_on_creation(data): - scale = parameters.Scale("amount_scale", data, "") + scale = parameters.ParameterScale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) scale_at_instant = scale.get_at_instant(first_jan) @@ -60,7 +60,7 @@ def test_assign_thresholds_on_creation(data): # TODO: move, as we're testing Scale, not SingleAmountTaxScale def test_assign_amounts_on_creation(data): - scale = parameters.Scale("amount_scale", data, "") + scale = parameters.ParameterScale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) scale_at_instant = scale.get_at_instant(first_jan) @@ -71,7 +71,7 @@ def test_assign_amounts_on_creation(data): # TODO: move, as we're testing Scale, not SingleAmountTaxScale def test_dispatch_scale_type_on_creation(data): - scale = parameters.Scale("amount_scale", data, "") + scale = parameters.ParameterScale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) result = scale.get_at_instant(first_jan) diff --git a/tests/core/test_holders.py b/tests/core/test_holders.py index cd26231037..d06ce34f04 100644 --- a/tests/core/test_holders.py +++ b/tests/core/test_holders.py @@ -7,7 +7,7 @@ from openfisca_core import holders, periods, tools from openfisca_core.errors import PeriodMismatchError -from openfisca_core.memory_config import MemoryConfig +from openfisca_core.experimental import MemoryConfig from openfisca_core.simulations import SimulationBuilder from openfisca_core.holders import Holder diff --git a/tests/core/test_parameters.py b/tests/core/test_parameters.py index 40d8bb3fc9..4f74f9d907 100644 --- a/tests/core/test_parameters.py +++ b/tests/core/test_parameters.py @@ -2,7 +2,8 @@ import pytest -from openfisca_core.parameters import ParameterNotFound, ParameterNode, ParameterNodeAtInstant, load_parameter_file +from openfisca_core.errors import ParameterNotFoundError +from openfisca_core.parameters import ParameterNode, ParameterNodeAtInstant, load_parameter_file def test_get_at_instant(tax_benefit_system): @@ -27,7 +28,7 @@ def test_param_values(tax_benefit_system): def test_param_before_it_is_defined(tax_benefit_system): - with pytest.raises(ParameterNotFound): + with pytest.raises(ParameterNotFoundError): tax_benefit_system.get_parameters_at_instant('1997-12-31').taxes.income_tax_rate @@ -41,7 +42,7 @@ def test_stopped_parameter_before_end_value(tax_benefit_system): def test_stopped_parameter_after_end_value(tax_benefit_system): - with pytest.raises(ParameterNotFound): + with pytest.raises(ParameterNotFoundError): tax_benefit_system.get_parameters_at_instant('2016-12-01').benefits.housing_allowance diff --git a/tests/core/test_tracers.py b/tests/core/test_tracers.py index 2e3d8dbb56..383723d20b 100644 --- a/tests/core/test_tracers.py +++ b/tests/core/test_tracers.py @@ -6,7 +6,8 @@ import numpy as np from pytest import fixture, mark, raises, approx -from openfisca_core.simulations import Simulation, CycleError, SpiralError +from openfisca_core.errors import CycleError, SpiralError +from openfisca_core.simulations import Simulation from openfisca_core.tracers import SimpleTracer, FullTracer, TracingParameterNodeAtInstant, TraceNode from openfisca_country_template.variables.housing import HousingOccupancyStatus from .parameters_fancy_indexing.test_fancy_indexing import parameters diff --git a/tests/core/tools/test_runner/test_yaml_runner.py b/tests/core/tools/test_runner/test_yaml_runner.py index bd7aaccad7..a8cd55c154 100644 --- a/tests/core/tools/test_runner/test_yaml_runner.py +++ b/tests/core/tools/test_runner/test_yaml_runner.py @@ -5,7 +5,7 @@ import numpy as np from openfisca_core.tools.test_runner import _get_tax_benefit_system, YamlItem, YamlFile -from openfisca_core.errors import VariableNotFound +from openfisca_core.errors import VariableNotFoundError from openfisca_core.variables import Variable from openfisca_core.populations import Population from openfisca_core.entities import Entity @@ -86,7 +86,7 @@ def __init__(self): def test_variable_not_found(): test = {"output": {"unknown_variable": 0}} - with pytest.raises(VariableNotFound) as excinfo: + with pytest.raises(VariableNotFoundError) as excinfo: test_item = TestItem(test) test_item.check_output() assert excinfo.value.variable_name == "unknown_variable" diff --git a/tests/core/variables/test_variables.py b/tests/core/variables/test_variables.py index 876145bde1..f01ce7c480 100644 --- a/tests/core/variables/test_variables.py +++ b/tests/core/variables/test_variables.py @@ -4,7 +4,7 @@ from openfisca_core.model_api import Variable from openfisca_core.periods import MONTH, ETERNITY -from openfisca_core.simulation_builder import SimulationBuilder +from openfisca_core.simulations import SimulationBuilder from openfisca_core.tools import assert_near import openfisca_country_template as country_template diff --git a/tests/web_api/loader/test_parameters.py b/tests/web_api/loader/test_parameters.py index e17472a9d6..232bd24c26 100644 --- a/tests/web_api/loader/test_parameters.py +++ b/tests/web_api/loader/test_parameters.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from openfisca_core.parameters import Scale +from openfisca_core.parameters import ParameterScale from openfisca_web_api.loader.parameters import build_api_scale, build_api_parameter @@ -8,21 +8,21 @@ def test_build_rate_scale(): '''Extracts a 'rate' children from a bracket collection''' data = {'brackets': [{'rate': {'2014-01-01': {'value': 0.5}}, 'threshold': {'2014-01-01': {'value': 1}}}]} - rate = Scale('this rate', data, None) + rate = ParameterScale('this rate', data, None) assert build_api_scale(rate, 'rate') == {'2014-01-01': {1: 0.5}} def test_build_amount_scale(): '''Extracts an 'amount' children from a bracket collection''' data = {'brackets': [{'amount': {'2014-01-01': {'value': 0}}, 'threshold': {'2014-01-01': {'value': 1}}}]} - rate = Scale('that amount', data, None) + rate = ParameterScale('that amount', data, None) assert build_api_scale(rate, 'amount') == {'2014-01-01': {1: 0}} def test_full_rate_scale(): '''Serializes a 'rate' scale parameter''' data = {'brackets': [{'rate': {'2014-01-01': {'value': 0.5}}, 'threshold': {'2014-01-01': {'value': 1}}}]} - scale = Scale('rate', data, None) + scale = ParameterScale('rate', data, None) api_scale = build_api_parameter(scale, {}) assert api_scale == {'description': None, 'id': 'rate', 'metadata': {}, 'brackets': {'2014-01-01': {1: 0.5}}} @@ -30,6 +30,6 @@ def test_full_rate_scale(): def test_walk_node_amount_scale(): '''Serializes an 'amount' scale parameter ''' data = {'brackets': [{'amount': {'2014-01-01': {'value': 0}}, 'threshold': {'2014-01-01': {'value': 1}}}]} - scale = Scale('amount', data, None) + scale = ParameterScale('amount', data, None) api_scale = build_api_parameter(scale, {}) assert api_scale == {'description': None, 'id': 'amount', 'metadata': {}, 'brackets': {'2014-01-01': {1: 0}}} From 64007a237a125943c2d77328f097bd3ae7e8bab2 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Tue, 26 Oct 2021 16:54:51 +0200 Subject: [PATCH 32/38] Remove outdated instruction from circle --- .circleci/config.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index e581c5555e..f1ff914272 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -40,8 +40,6 @@ jobs: command: | make install-deps make install-dev - # pip install --editable git+https://github.com/openfisca/country-template.git@BRANCH_NAME#egg=OpenFisca-Country-Template # use a specific branch of OpenFisca-Country-Template - # pip install --editable git+https://github.com/openfisca/extension-template.git@BRANCH_NAME#egg=OpenFisca-Extension-Template # use a specific branch of OpenFisca-Extension-Template - save_cache: key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} From 6d2c415bb3ae7e3a01404bc002a7fdec7d1bf9fe Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 23:08:16 +0200 Subject: [PATCH 33/38] Update to major version 36.0.0 --- CHANGELOG.md | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 27d95ddaa3..6d20703016 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,59 @@ # Changelog +# 36.0.0 [#1015](https://github.com/openfisca/openfisca-core/pull/1015) + +#### Technical changes + +- Extract requirements to separate files for easier contract enforcement. +- Add explicit contract regarding supported dependencies. +- Add constraint file to test against lower-bound NumPy. +- Add extra dependencies. + - Add coveralls (latest) to extra requirements. + - Add twine (latest) to extra requirements. + - Add wheel (latest) to extra requirements. +- Pin non-distribution dependencies. + - Pin autopep8 at latest. + - Pin flake8 at latest. + - Pin flake8-bugbear at latest. + - Pin flake8-print at latest. + - Pin pytest-cov at latest. + - Pin mypy at latest. + - Pin flask at 1.1.2. + - Pin gunicorn at 20.1.0. + - Pin flask-cors at 3.0.10. + - Pin werkzeug at 1.0.1. +- Relax distrubution dependencies. + - Set dpath at >= 1.3.2, < 2. + - Set psutil at >= 5.4.7, < 6. + - Set sortedcontainers at >= 2, < 3. +- Relax circular dependencies. + - Relax openfisca-country-template. + - Relax openfisca-extension-template. + +#### Breaking changes + +- Drop support for Python < 3.7. + - Python 3.7 [introduces backwards incompatible syntax changes](https://docs.python.org/3/whatsnew/3.7.html) that might be used in your country models. +- Drop support for numexpr < 2.7.1. + - numexpr 2.7.1 [introduces no breaking changes](https://numexpr.readthedocs.io/projects/NumExpr3/en/latest/release_notes.html#changes-from-2-7-0-to-2-7-1). +- Drop support for NumPy < 1.17 + - NumPy 1.12 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.12.0-notes.html#compatibility-notes) that might be used in your country models. + - NumPy 1.13 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.13.0-notes.html#compatibility-notes) that might be used in your country models. + - NumPy 1.14 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.14.0-notes.html#compatibility-notes) that might be used in your country models. + - NumPy 1.15 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.15.0-notes.html#compatibility-notes) that might be used in your country models. + - NumPy 1.16 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.16.0-notes.html#expired-deprecations) that might be used in your country models. + - NumPy 1.17 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.17.0-notes.html#compatibility-notes) that might be used in your country models. +- Drop support for pytest < 5.4.2. + - pytest 5 [introduces a list of removals and deprecations](https://docs.pytest.org/en/stable/changelog.html#pytest-5-0-0-2019-06-28) that might be used in your country models. + - pytest 5.1 [introduces a list of removals](https://docs.pytest.org/en/stable/changelog.html#pytest-5-1-0-2019-08-15) that might be used in your country models. + - pytest 5.2 [introduces a list of deprecations](https://docs.pytest.org/en/stable/changelog.html#pytest-5-2-0-2019-09-28) that might be used in your country models. + - pytest 5.3 [introduces a list of deprecations](https://docs.pytest.org/en/stable/changelog.html#pytest-5-3-0-2019-11-19) that might be used in your country models. + - pytest 5.4 [introduces a list of breaking changes and deprecations](https://docs.pytest.org/en/stable/changelog.html#pytest-5-3-0-2019-11-19) that might be used in your country models. + - pytest 5.4.1 [introduces no breaking changes](https://docs.pytest.org/en/stable/changelog.html#pytest-5-4-1-2020-03-13). + - pytest 5.4.2 [introduces no breaking changes](https://docs.pytest.org/en/stable/changelog.html#pytest-5-4-2-2020-05-08). +- Drop support for PyYAML < 5.1. + - PyYAML 5.1 [introduces some breaking changes](https://github.com/yaml/pyyaml/blob/ee37f4653c08fc07aecff69cfd92848e6b1a540e/CHANGES#L66-L97) that might be used in your country models. + ### 35.7.1 [#1075](https://github.com/openfisca/openfisca-core/pull/1075) #### Bug fix From 6f73e98a9958cb05a7d81e5c8f28291f3c929c07 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Tue, 26 Oct 2021 17:17:14 +0200 Subject: [PATCH 34/38] Fix circleci config --- .circleci/config.yml | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index f1ff914272..b107a97b1b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -51,6 +51,7 @@ jobs: - image: python:3.7 environment: TERM: xterm-256color # To colorize output of make tasks. + PYTEST_ADDOPTS: --exitfirst steps: - checkout @@ -72,19 +73,15 @@ jobs: - run: name: Run openfisca-core tests - command: make test-core pytest_args="--exitfirst" + command: make test-core - run: name: Run country-template tests - command: make test-country pytest_args="--exitfirst" + command: make test-country - run: name: Run extension-template tests - command: make test-extension pytest_args="--exitfirst" - - - run: - name: Run core tests - command: make test + command: make test-extension - persist_to_workspace: root: . @@ -137,8 +134,8 @@ jobs: test_compatibility: docker: - image: python:3.7 - environment: + TERM: xterm-256color # To colorize output of make tasks. PYTEST_ADDOPTS: --exitfirst steps: @@ -158,8 +155,20 @@ jobs: make install-compat - run: - name: Run core tests - command: make test + name: Run linters + command: make lint + + - run: + name: Run openfisca-core tests + command: make test-core + + - run: + name: Run country-template tests + command: make test-country + + - run: + name: Run extension-template tests + command: make test-extension submit_coverage: docker: From 23f4aff36c27f848288003aacf20af6ee52f0611 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Tue, 26 Oct 2021 17:32:46 +0200 Subject: [PATCH 35/38] Remove deprecation leftovers --- openfisca_core/commons/dummy.py | 23 ---------------------- openfisca_core/commons/tests/test_dummy.py | 10 ---------- openfisca_core/variables/typing.py | 16 --------------- 3 files changed, 49 deletions(-) delete mode 100644 openfisca_core/commons/dummy.py delete mode 100644 openfisca_core/commons/tests/test_dummy.py delete mode 100644 openfisca_core/variables/typing.py diff --git a/openfisca_core/commons/dummy.py b/openfisca_core/commons/dummy.py deleted file mode 100644 index 5f1b0be330..0000000000 --- a/openfisca_core/commons/dummy.py +++ /dev/null @@ -1,23 +0,0 @@ -import warnings - - -class Dummy: - """A class that did nothing. - - Examples: - >>> Dummy() - None: - message = [ - "The 'Dummy' class has been deprecated since version 34.7.0,", - "and will be removed in the future.", - ] - warnings.warn(" ".join(message), DeprecationWarning) - pass diff --git a/openfisca_core/commons/tests/test_dummy.py b/openfisca_core/commons/tests/test_dummy.py deleted file mode 100644 index d4ecec3842..0000000000 --- a/openfisca_core/commons/tests/test_dummy.py +++ /dev/null @@ -1,10 +0,0 @@ -import pytest - -from openfisca_core.commons import Dummy - - -def test_dummy_deprecation(): - """Dummy throws a deprecation warning when instantiated.""" - - with pytest.warns(DeprecationWarning): - assert Dummy() diff --git a/openfisca_core/variables/typing.py b/openfisca_core/variables/typing.py deleted file mode 100644 index 892ec0bf9f..0000000000 --- a/openfisca_core/variables/typing.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Callable, Union - -import numpy - -from openfisca_core.parameters import ParameterNodeAtInstant -from openfisca_core.periods import Instant, Period -from openfisca_core.populations import Population, GroupPopulation - -#: A collection of :obj:`.Entity` or :obj:`.GroupEntity`. -People = Union[Population, GroupPopulation] - -#: A callable to get the parameters for the given instant. -Params = Callable[[Instant], ParameterNodeAtInstant] - -#: A callable defining a calculation, or a rule, on a system. -Formula = Callable[[People, Period, Params], numpy.ndarray] From 2b322a38d6e56d0d7b0f522338c1a2f4ba7ea9d0 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Tue, 26 Oct 2021 17:55:25 +0200 Subject: [PATCH 36/38] Update tracker installation instructions --- README.md | 5 ++--- openfisca_tasks/install.mk | 5 +++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 5f060caf23..c38a614754 100644 --- a/README.md +++ b/README.md @@ -199,9 +199,8 @@ pip install openfisca_core[tracker] Or for an editable installation: ``` -pip install --requirement requirements/tracker --upgrade -pip install --requirement requirements/dev --upgrade -pip install --editable . --upgrade --no-dependencies +make install +make install-tracker ``` #### Tracker configuration diff --git a/openfisca_tasks/install.mk b/openfisca_tasks/install.mk index f7d1f8074b..ac6bf1b34a 100644 --- a/openfisca_tasks/install.mk +++ b/openfisca_tasks/install.mk @@ -22,6 +22,11 @@ install-core: @pip uninstall --quiet --yes openfisca-core @pip install --quiet --no-dependencies --editable . +## Install the WebAPI tracker. +install-tracker: + @$(call print_help,$@:) + @pip install --quiet --upgrade --constraint requirements/tracker openfisca-tracker + ## Install lower-bound dependencies for compatibility check. install-compat: @$(call print_help,$@:) From 24522416fdc1c6eaba2fbe5390e58555ecdbdbe9 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Sat, 1 May 2021 23:08:16 +0200 Subject: [PATCH 37/38] Update to major version 36.0.0 --- CHANGELOG.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d20703016..502ca99d51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,27 @@ - Drop support for PyYAML < 5.1. - PyYAML 5.1 [introduces some breaking changes](https://github.com/yaml/pyyaml/blob/ee37f4653c08fc07aecff69cfd92848e6b1a540e/CHANGES#L66-L97) that might be used in your country models. +#### Expired deprecations + +- `openfisca_core.commons.Dummy` => `openfisca_core.commons.empty_clone` +- `openfisca_core.errors.ParameterNotFound` => `openfisca_core.errors.ParameterNotFoundError` +- `openfisca_core.errors.VariableNameConflict` => `openfisca_core.errors.VariableNameConflictError` +- `openfisca_core.errors.VariableNotFound` => `openfisca_core.errors.VariableNotFoundError` +- `openfisca_core.formula_helpers.py` => `openfisca_core.commons` +- `openfisca_core.memory_config.py` => `openfisca_core.experimental` +- `openfisca_core.parameters.Bracket` => `openfisca_core.errors.ParameterScaleBracket` +- `openfisca_core.parameters.ParameterNotFound` => `openfisca_core.errors.ParameterNotFoundError` +- `openfisca_core.parameters.ParameterParsingError` => `openfisca_core.errors.ParameterParsingError` +- `openfisca_core.parameters.Scale` => `openfisca_core.errors.ParameterScale` +- `openfisca_core.rates` => `openfisca_core.commons` +- `openfisca_core.simulation_builder` => `openfisca_core.simulations` +- `openfisca_core.simulations.CycleError` => `openfisca_core.errors.CycleError` +- `openfisca_core.simulations.NaNCreationError` => `openfisca_core.errors.NaNCreationError` +- `openfisca_core.simulations.SpiralError` => `openfisca_core.errors.SpiralError` +- `openfisca_core.taxbenefitsystems.VariableNameConflict` => `openfisca_core.errors.VariableNameConflictError` +- `openfisca_core.taxbenefitsystems.VariableNotFound` => `openfisca_core.errors.VariableNotFoundError` +- `openfisca_core.taxscales.EmptyArgumentError` => `openfisca_core.errors.EmptyArgumentError` + ### 35.7.1 [#1075](https://github.com/openfisca/openfisca-core/pull/1075) #### Bug fix From d088e40519095127985e7eedb9a6553626aac889 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Fri, 25 Oct 2024 21:49:03 +0200 Subject: [PATCH 38/38] chore: backport changes --- .circleci/config.yml | 262 ---- .circleci/publish-git-tag.sh | 4 - .conda/README.md | 37 + .conda/openfisca-core/conda_build_config.yaml | 9 + .conda/openfisca-core/meta.yaml | 89 ++ .conda/openfisca-country-template/recipe.yaml | 38 + .../openfisca-country-template/variants.yaml | 7 + .../openfisca-extension-template/recipe.yaml | 39 + .../variants.yaml | 7 + .conda/pylint-per-file-ignores/recipe.yaml | 41 + .conda/pylint-per-file-ignores/variants.yaml | 4 + .github/dependabot.yml | 2 +- .github/get_pypi_info.py | 83 ++ .../has-functional-changes.sh | 2 +- .../is-version-number-acceptable.sh | 2 +- .github/workflows/_before-conda.yaml | 109 ++ .github/workflows/_before-pip.yaml | 103 ++ .github/workflows/_lint-pip.yaml | 57 + .github/workflows/_test-conda.yaml | 76 ++ .github/workflows/_test-pip.yaml | 71 + .github/workflows/_version.yaml | 38 + .github/workflows/merge.yaml | 250 ++++ .github/workflows/push.yaml | 89 ++ .gitignore | 3 +- CHANGELOG.md | 799 +++++++++-- CONTRIBUTING.md | 2 +- MANIFEST.in | 3 +- Makefile | 3 +- README.md | 152 +-- STYLEGUIDE.md | 2 +- conftest.py | 3 +- openfisca_core/commons/__init__.py | 76 +- openfisca_core/commons/formulas.py | 89 +- openfisca_core/commons/misc.py | 61 +- .../commons/py.typed | 0 openfisca_core/commons/rates.py | 73 +- openfisca_core/commons/tests/test_formulas.py | 38 +- openfisca_core/commons/tests/test_rates.py | 16 +- openfisca_core/commons/types.py | 3 + openfisca_core/data_storage/__init__.py | 30 +- .../data_storage/in_memory_storage.py | 178 ++- .../data_storage/on_disk_storage.py | 270 +++- openfisca_core/data_storage/types.py | 14 + openfisca_core/entities/__init__.py | 48 +- openfisca_core/entities/_core_entity.py | 218 +++ openfisca_core/entities/_description.py | 55 + openfisca_core/entities/entity.py | 85 +- openfisca_core/entities/group_entity.py | 127 +- openfisca_core/entities/helpers.py | 167 ++- .../entities/py.typed | 0 openfisca_core/entities/role.py | 98 +- openfisca_core/entities/tests/__init__.py | 0 openfisca_core/entities/tests/test_entity.py | 10 + .../entities/tests/test_group_entity.py | 70 + openfisca_core/entities/tests/test_role.py | 11 + openfisca_core/entities/types.py | 42 + openfisca_core/errors/__init__.py | 33 +- openfisca_core/errors/cycle_error.py | 2 - openfisca_core/errors/empty_argument_error.py | 17 +- openfisca_core/errors/nan_creation_error.py | 2 - .../errors/parameter_not_found_error.py | 19 +- .../errors/parameter_parsing_error.py | 18 +- .../errors/period_mismatch_error.py | 6 +- .../errors/situation_parsing_error.py | 23 +- openfisca_core/errors/spiral_error.py | 2 - .../errors/variable_name_config_error.py | 6 +- .../errors/variable_not_found_error.py | 31 +- openfisca_core/experimental/__init__.py | 31 +- openfisca_core/experimental/_errors.py | 5 + openfisca_core/experimental/_memory_config.py | 42 + openfisca_core/experimental/memory_config.py | 24 - openfisca_core/holders/__init__.py | 12 +- openfisca_core/holders/helpers.py | 49 +- openfisca_core/holders/holder.py | 319 +++-- openfisca_core/holders/tests/__init__.py | 0 openfisca_core/holders/tests/test_helpers.py | 134 ++ openfisca_core/holders/types.py | 3 + openfisca_core/indexed_enums/__init__.py | 42 +- openfisca_core/indexed_enums/_enum_type.py | 70 + openfisca_core/indexed_enums/_errors.py | 35 + openfisca_core/indexed_enums/_guards.py | 209 +++ openfisca_core/indexed_enums/_utils.py | 187 +++ openfisca_core/indexed_enums/config.py | 3 + openfisca_core/indexed_enums/enum.py | 290 ++-- openfisca_core/indexed_enums/enum_array.py | 335 ++++- openfisca_core/indexed_enums/py.typed | 0 .../indexed_enums/tests/__init__.py | 0 .../indexed_enums/tests/test_enum.py | 135 ++ .../indexed_enums/tests/test_enum_array.py | 30 + openfisca_core/indexed_enums/types.py | 41 + openfisca_core/model_api.py | 74 +- openfisca_core/parameters/__init__.py | 55 +- openfisca_core/parameters/at_instant_like.py | 7 +- openfisca_core/parameters/config.py | 41 +- openfisca_core/parameters/helpers.py | 89 +- openfisca_core/parameters/parameter.py | 149 +- .../parameters/parameter_at_instant.py | 57 +- openfisca_core/parameters/parameter_node.py | 120 +- .../parameters/parameter_node_at_instant.py | 39 +- openfisca_core/parameters/parameter_scale.py | 107 +- .../parameters/parameter_scale_bracket.py | 6 +- openfisca_core/parameters/values_history.py | 8 +- ...ial_asof_date_parameter_node_at_instant.py | 81 ++ .../vectorial_parameter_node_at_instant.py | 163 ++- openfisca_core/periods/__init__.py | 61 +- openfisca_core/periods/_errors.py | 28 + openfisca_core/periods/_parsers.py | 121 ++ openfisca_core/periods/config.py | 23 +- openfisca_core/periods/date_unit.py | 110 ++ openfisca_core/periods/helpers.py | 448 +++--- openfisca_core/periods/instant_.py | 423 +++--- openfisca_core/periods/period_.py | 1209 +++++++++++------ openfisca_core/periods/py.typed | 0 openfisca_core/periods/tests/__init__.py | 0 .../periods/tests/helpers/__init__.py | 0 .../periods/tests/helpers/test_helpers.py | 65 + .../periods/tests/helpers/test_instant.py | 73 + .../periods/tests/helpers/test_period.py | 134 ++ openfisca_core/periods/tests/test_instant.py | 32 + openfisca_core/periods/tests/test_parsers.py | 129 ++ openfisca_core/periods/tests/test_period.py | 283 ++++ openfisca_core/periods/types.py | 183 +++ openfisca_core/populations/__init__.py | 46 +- .../populations/_core_population.py | 455 +++++++ openfisca_core/populations/_errors.py | 65 + openfisca_core/populations/config.py | 2 - .../populations/group_population.py | 242 ++-- openfisca_core/populations/population.py | 193 +-- openfisca_core/populations/types.py | 103 ++ openfisca_core/projectors/__init__.py | 21 +- .../projectors/entity_to_person_projector.py | 4 +- .../first_person_to_entity_projector.py | 4 +- openfisca_core/projectors/helpers.py | 145 +- openfisca_core/projectors/projector.py | 11 +- openfisca_core/projectors/typing.py | 27 + .../unique_role_to_entity_projector.py | 6 +- openfisca_core/reforms/reform.py | 51 +- openfisca_core/scripts/__init__.py | 72 +- openfisca_core/scripts/find_placeholders.py | 36 +- .../measure_numpy_condition_notations.py | 94 +- .../scripts/measure_performances.py | 226 +-- .../measure_performances_fancy_indexing.py | 108 +- .../xml_to_yaml_country_template.py | 32 +- .../xml_to_yaml_extension_template.py | 25 +- .../scripts/migrations/v24_to_25.py | 91 +- openfisca_core/scripts/openfisca_command.py | 146 +- openfisca_core/scripts/remove_fuzzy.py | 101 +- openfisca_core/scripts/run_test.py | 37 +- .../scripts/simulation_generator.py | 82 +- openfisca_core/simulations/__init__.py | 25 +- .../simulations/_build_default_simulation.py | 159 +++ .../simulations/_build_from_variables.py | 230 ++++ openfisca_core/simulations/_type_guards.py | 298 ++++ openfisca_core/simulations/helpers.py | 106 +- openfisca_core/simulations/simulation.py | 518 ++++--- .../simulations/simulation_builder.py | 774 +++++++---- openfisca_core/simulations/typing.py | 203 +++ openfisca_core/taxbenefitsystems/__init__.py | 2 + .../taxbenefitsystems/tax_benefit_system.py | 498 ++++--- openfisca_core/taxscales/__init__.py | 14 +- .../taxscales/abstract_rate_tax_scale.py | 35 +- .../taxscales/abstract_tax_scale.py | 49 +- .../taxscales/amount_tax_scale_like.py | 43 +- openfisca_core/taxscales/helpers.py | 17 +- .../linear_average_rate_tax_scale.py | 44 +- .../taxscales/marginal_amount_tax_scale.py | 26 +- .../taxscales/marginal_rate_tax_scale.py | 166 ++- .../taxscales/rate_tax_scale_like.py | 140 +- .../taxscales/single_amount_tax_scale.py | 32 +- openfisca_core/taxscales/tax_scale_like.py | 53 +- openfisca_core/tools/__init__.py | 103 +- openfisca_core/tools/simulation_dumper.py | 86 +- openfisca_core/tools/test_runner.py | 357 +++-- openfisca_core/tracers/__init__.py | 26 +- openfisca_core/tracers/computation_log.py | 145 +- openfisca_core/tracers/flat_trace.py | 116 +- openfisca_core/tracers/full_tracer.py | 227 ++-- openfisca_core/tracers/performance_log.py | 123 +- openfisca_core/tracers/simple_tracer.py | 67 +- openfisca_core/tracers/trace_node.py | 127 +- .../tracing_parameter_node_at_instant.py | 63 +- openfisca_core/tracers/types.py | 108 ++ openfisca_core/types.py | 299 ++++ openfisca_core/types/__init__.py | 45 - openfisca_core/types/data_types/__init__.py | 1 - openfisca_core/types/data_types/arrays.py | 51 - openfisca_core/variables/__init__.py | 2 +- openfisca_core/variables/config.py | 75 +- openfisca_core/variables/helpers.py | 40 +- openfisca_core/variables/tests/__init__.py | 0 .../variables/tests/test_definition_period.py | 43 + openfisca_core/variables/variable.py | 327 +++-- openfisca_core/warnings/__init__.py | 1 - openfisca_core/warnings/libyaml_warning.py | 5 +- openfisca_core/warnings/memory_warning.py | 5 - openfisca_core/warnings/tempfile_warning.py | 5 +- openfisca_tasks/install.mk | 55 +- openfisca_tasks/lint.mk | 51 +- openfisca_tasks/publish.mk | 46 +- openfisca_tasks/test_code.mk | 43 +- openfisca_tasks/test_doc.mk | 78 -- openfisca_web_api/app.py | 193 +-- openfisca_web_api/errors.py | 11 +- openfisca_web_api/handlers.py | 75 +- openfisca_web_api/loader/__init__.py | 24 +- openfisca_web_api/loader/entities.py | 33 +- openfisca_web_api/loader/parameters.py | 115 +- openfisca_web_api/loader/spec.py | 199 +-- .../loader/tax_benefit_system.py | 20 +- openfisca_web_api/loader/variables.py | 93 +- openfisca_web_api/openAPI.yml | 542 ++++---- openfisca_web_api/scripts/serve.py | 66 +- pyproject.toml | 14 + setup.cfg | 135 +- setup.py | 160 ++- stubs/numexpr/__init__.pyi | 10 + .../test_parameter_clone.py | 21 +- .../test_parameter_validation.py | 94 +- .../core/parameters_date_indexing/__init__.py | 0 .../full_rate_age.yaml | 121 ++ .../full_rate_required_duration.yml | 162 +++ .../test_date_indexing.py | 48 + .../coefficient_de_minoration.yaml | 135 ++ .../test_fancy_indexing.py | 180 +-- .../test_abstract_rate_tax_scale.py | 8 +- .../tax_scales/test_abstract_tax_scale.py | 8 +- .../test_linear_average_rate_tax_scale.py | 35 +- .../test_marginal_amount_tax_scale.py | 24 +- .../test_marginal_rate_tax_scale.py | 112 +- .../tax_scales/test_rate_tax_scale_like.py | 17 + .../test_single_amount_tax_scale.py | 36 +- .../tax_scales/test_tax_scales_commons.py | 20 +- tests/core/test_axes.py | 412 ++++-- tests/core/test_calculate_output.py | 48 +- tests/core/test_countries.py | 82 +- tests/core/test_cycles.py | 89 +- tests/core/test_dump_restore.py | 30 +- tests/core/test_entities.py | 369 ++--- tests/core/test_extensions.py | 24 +- tests/core/test_formulas.py | 161 +-- tests/core/test_holders.py | 233 ++-- tests/core/test_opt_out_cache.py | 59 +- tests/core/test_parameters.py | 130 +- tests/core/test_periods.py | 203 --- tests/core/test_projectors.py | 315 +++-- tests/core/test_reforms.py | 458 ++++--- tests/core/test_simulation_builder.py | 636 ++++++--- tests/core/test_simulations.py | 77 +- tests/core/test_tracers.py | 453 +++--- tests/core/test_yaml.py | 142 +- tests/core/tools/test_assert_near.py | 26 +- .../tools/test_runner/test_yaml_runner.py | 123 +- tests/core/variables/test_annualize.py | 56 +- .../core/variables/test_definition_period.py | 43 + tests/core/variables/test_variables.py | 430 +++--- tests/fixtures/appclient.py | 8 +- tests/fixtures/entities.py | 28 +- tests/fixtures/extensions.py | 18 + tests/fixtures/simulations.py | 8 +- tests/fixtures/taxbenefitsystems.py | 2 +- tests/fixtures/variables.py | 6 +- .../failing_test_absolute_error_margin.yaml | 11 + .../failing_test_relative_error_margin.yaml | 11 + .../test_absolute_error_margin.yaml | 11 + .../fixtures/yaml_tests/test_name_filter.yaml | 4 +- .../test_relative_error_margin.yaml | 11 + tests/web_api/__init__.py | 4 - tests/web_api/basic_case/__init__.py | 9 - .../case_with_extension/test_extensions.py | 41 +- .../web_api/case_with_reform/test_reforms.py | 39 +- tests/web_api/loader/test_parameters.py | 85 +- tests/web_api/test_calculate.py | 611 +++++---- tests/web_api/test_entities.py | 47 +- tests/web_api/test_headers.py | 19 +- tests/web_api/test_helpers.py | 62 +- tests/web_api/test_parameters.py | 184 ++- tests/web_api/test_spec.py | 88 +- tests/web_api/test_trace.py | 133 +- tests/web_api/test_variables.py | 209 +-- 279 files changed, 19559 insertions(+), 8615 deletions(-) delete mode 100644 .circleci/config.yml delete mode 100755 .circleci/publish-git-tag.sh create mode 100644 .conda/README.md create mode 100644 .conda/openfisca-core/conda_build_config.yaml create mode 100644 .conda/openfisca-core/meta.yaml create mode 100644 .conda/openfisca-country-template/recipe.yaml create mode 100644 .conda/openfisca-country-template/variants.yaml create mode 100644 .conda/openfisca-extension-template/recipe.yaml create mode 100644 .conda/openfisca-extension-template/variants.yaml create mode 100644 .conda/pylint-per-file-ignores/recipe.yaml create mode 100644 .conda/pylint-per-file-ignores/variants.yaml create mode 100644 .github/get_pypi_info.py rename {.circleci => .github}/has-functional-changes.sh (89%) rename {.circleci => .github}/is-version-number-acceptable.sh (95%) create mode 100644 .github/workflows/_before-conda.yaml create mode 100644 .github/workflows/_before-pip.yaml create mode 100644 .github/workflows/_lint-pip.yaml create mode 100644 .github/workflows/_test-conda.yaml create mode 100644 .github/workflows/_test-pip.yaml create mode 100644 .github/workflows/_version.yaml create mode 100644 .github/workflows/merge.yaml create mode 100644 .github/workflows/push.yaml rename tests/web_api/case_with_extension/__init__.py => openfisca_core/commons/py.typed (100%) create mode 100644 openfisca_core/commons/types.py create mode 100644 openfisca_core/data_storage/types.py create mode 100644 openfisca_core/entities/_core_entity.py create mode 100644 openfisca_core/entities/_description.py rename tests/web_api/loader/__init__.py => openfisca_core/entities/py.typed (100%) create mode 100644 openfisca_core/entities/tests/__init__.py create mode 100644 openfisca_core/entities/tests/test_entity.py create mode 100644 openfisca_core/entities/tests/test_group_entity.py create mode 100644 openfisca_core/entities/tests/test_role.py create mode 100644 openfisca_core/entities/types.py create mode 100644 openfisca_core/experimental/_errors.py create mode 100644 openfisca_core/experimental/_memory_config.py delete mode 100644 openfisca_core/experimental/memory_config.py create mode 100644 openfisca_core/holders/tests/__init__.py create mode 100644 openfisca_core/holders/tests/test_helpers.py create mode 100644 openfisca_core/holders/types.py create mode 100644 openfisca_core/indexed_enums/_enum_type.py create mode 100644 openfisca_core/indexed_enums/_errors.py create mode 100644 openfisca_core/indexed_enums/_guards.py create mode 100644 openfisca_core/indexed_enums/_utils.py create mode 100644 openfisca_core/indexed_enums/py.typed create mode 100644 openfisca_core/indexed_enums/tests/__init__.py create mode 100644 openfisca_core/indexed_enums/tests/test_enum.py create mode 100644 openfisca_core/indexed_enums/tests/test_enum_array.py create mode 100644 openfisca_core/indexed_enums/types.py create mode 100644 openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py create mode 100644 openfisca_core/periods/_errors.py create mode 100644 openfisca_core/periods/_parsers.py create mode 100644 openfisca_core/periods/date_unit.py create mode 100644 openfisca_core/periods/py.typed create mode 100644 openfisca_core/periods/tests/__init__.py create mode 100644 openfisca_core/periods/tests/helpers/__init__.py create mode 100644 openfisca_core/periods/tests/helpers/test_helpers.py create mode 100644 openfisca_core/periods/tests/helpers/test_instant.py create mode 100644 openfisca_core/periods/tests/helpers/test_period.py create mode 100644 openfisca_core/periods/tests/test_instant.py create mode 100644 openfisca_core/periods/tests/test_parsers.py create mode 100644 openfisca_core/periods/tests/test_period.py create mode 100644 openfisca_core/periods/types.py create mode 100644 openfisca_core/populations/_core_population.py create mode 100644 openfisca_core/populations/_errors.py delete mode 100644 openfisca_core/populations/config.py create mode 100644 openfisca_core/populations/types.py create mode 100644 openfisca_core/projectors/typing.py create mode 100644 openfisca_core/simulations/_build_default_simulation.py create mode 100644 openfisca_core/simulations/_build_from_variables.py create mode 100644 openfisca_core/simulations/_type_guards.py create mode 100644 openfisca_core/simulations/typing.py create mode 100644 openfisca_core/tracers/types.py create mode 100644 openfisca_core/types.py delete mode 100644 openfisca_core/types/__init__.py delete mode 100644 openfisca_core/types/data_types/__init__.py delete mode 100644 openfisca_core/types/data_types/arrays.py create mode 100644 openfisca_core/variables/tests/__init__.py create mode 100644 openfisca_core/variables/tests/test_definition_period.py delete mode 100644 openfisca_core/warnings/memory_warning.py delete mode 100644 openfisca_tasks/test_doc.mk create mode 100644 pyproject.toml create mode 100644 stubs/numexpr/__init__.pyi create mode 100644 tests/core/parameters_date_indexing/__init__.py create mode 100644 tests/core/parameters_date_indexing/full_rate_age.yaml create mode 100644 tests/core/parameters_date_indexing/full_rate_required_duration.yml create mode 100644 tests/core/parameters_date_indexing/test_date_indexing.py create mode 100644 tests/core/parameters_fancy_indexing/coefficient_de_minoration.yaml create mode 100644 tests/core/tax_scales/test_rate_tax_scale_like.py delete mode 100644 tests/core/test_periods.py create mode 100644 tests/core/variables/test_definition_period.py create mode 100644 tests/fixtures/extensions.py delete mode 100644 tests/web_api/basic_case/__init__.py diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index b107a97b1b..0000000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,262 +0,0 @@ -version: 2 -jobs: - check_version: - docker: - - image: python:3.7 - - steps: - - checkout - - - run: - name: Check version number has been properly updated - command: | - git fetch - .circleci/is-version-number-acceptable.sh - - build: - docker: - - image: python:3.7 - environment: - TERM: xterm-256color # To colorize output of make tasks. - - steps: - - checkout - - - restore_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - - - run: - name: Create a virtualenv - command: | - mkdir -p /tmp/venv/openfisca_core - python -m venv /tmp/venv/openfisca_core - - - run: - name: Activate virtualenv - command: echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV - - - run: - name: Install dependencies - command: | - make install-deps - make install-dev - - - save_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - paths: - - /tmp/venv/openfisca_core - - test: - docker: - - image: python:3.7 - environment: - TERM: xterm-256color # To colorize output of make tasks. - PYTEST_ADDOPTS: --exitfirst - - steps: - - checkout - - - restore_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - - - run: - name: Activate virtualenv - command: echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV - - - run: - name: Install core - command: make install-core - - - run: - name: Run linters - command: make lint - - - run: - name: Run openfisca-core tests - command: make test-core - - - run: - name: Run country-template tests - command: make test-country - - - run: - name: Run extension-template tests - command: make test-extension - - - persist_to_workspace: - root: . - paths: - - .coverage - - test_docs: - docker: - - image: python:3.7 - environment: - TERM: xterm-256color # To colorize output of make tasks. - - steps: - - checkout - - - run: - name: Checkout docs - command: make test-doc-checkout branch=$CIRCLE_BRANCH - - - restore_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - - - restore_cache: - key: v1-py3-docs-{{ .Branch }}-{{ checksum "doc/requirements.txt" }} - - - run: - name: Create a virtualenv - command: | - mkdir -p /tmp/venv/openfisca_doc - python -m venv /tmp/venv/openfisca_doc - - - run: - name: Activate virtualenv - command: echo "source /tmp/venv/openfisca_doc/bin/activate" >> $BASH_ENV - - - run: - name: Install dependencies - command: make test-doc-install - - - save_cache: - key: v1-py3-docs-{{ .Branch }}-{{ checksum "doc/requirements.txt" }} - paths: - - /tmp/venv/openfisca_doc - - - run: - name: Run doc tests - command: make test-doc-build - - - test_compatibility: - docker: - - image: python:3.7 - environment: - TERM: xterm-256color # To colorize output of make tasks. - PYTEST_ADDOPTS: --exitfirst - - steps: - - checkout - - - restore_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - - - run: - name: Activate virtualenv - command: echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV - - - run: - name: Install core with a constrained Numpy version - command: | - make install-core - make install-compat - - - run: - name: Run linters - command: make lint - - - run: - name: Run openfisca-core tests - command: make test-core - - - run: - name: Run country-template tests - command: make test-country - - - run: - name: Run extension-template tests - command: make test-extension - - submit_coverage: - docker: - - image: python:3.7 - - steps: - - checkout - - - attach_workspace: - at: . - - - restore_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - - - run: - name: Activate virtualenv - command: echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV - - - run: - name: Submit coverage to Coveralls - command: | - make install-core - make install-cov - coveralls - - deploy: - docker: - - image: python:3.7 - - environment: - PYPI_USERNAME: openfisca-bot - # PYPI_PASSWORD: this value is set in CircleCI's web interface; do not set it here, it is a secret! - - steps: - - checkout - - - restore_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - - - run: - name: Check for functional changes - command: if ! .circleci/has-functional-changes.sh ; then circleci step halt ; fi - - - run: - name: Activate virtualenv - command: echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV - - - run: - name: Upload a Python package to PyPi - command: | - make build - make publish - - - run: - name: Publish a git tag - command: .circleci/publish-git-tag.sh - - - run: - name: Update doc - command: | - curl -X POST --header "Content-Type: application/json" -d '{"branch":"master"}' https://circleci.com/api/v1.1/project/github/openfisca/openfisca-doc/build?circle-token=$CIRCLE_TOKEN - -workflows: - version: 2 - build_and_deploy: - jobs: - - check_version - - build - - test: - requires: - - build - - test_docs: - requires: - - build - - test_compatibility: - requires: - - test_docs - - submit_coverage: - requires: - - test - - test_docs - - test_compatibility - - deploy: - requires: - - check_version - - test - - test_docs - - test_compatibility - filters: - branches: - only: master diff --git a/.circleci/publish-git-tag.sh b/.circleci/publish-git-tag.sh deleted file mode 100755 index 4450357cbc..0000000000 --- a/.circleci/publish-git-tag.sh +++ /dev/null @@ -1,4 +0,0 @@ -#! /usr/bin/env bash - -git tag `python setup.py --version` -git push --tags # update the repository version diff --git a/.conda/README.md b/.conda/README.md new file mode 100644 index 0000000000..ac0d2c2be5 --- /dev/null +++ b/.conda/README.md @@ -0,0 +1,37 @@ +# Publish OpenFisca-Core to conda + +We use two systems to publish to conda: +- A fully automatic in OpenFisca-Core CI that publishes to an `openfisca` channel. See below for more information. +- A more complex in Conda-Forge CI, that publishes to [Conda-Forge](https://conda-forge.org). See this [YouTube video](https://www.youtube.com/watch?v=N2XwK9BkJpA) as an introduction to Conda-Forge, and [openfisca-core-feedstock repository](https://github.com/openfisca/openfisca-core-feedstock) for the project publishing process on Conda-Forge. + +We use both channels. With conda-forge users get an easier way to install and use openfisca-core: conda-forge is the default channel in Anaconda and it allows for publishing packages that depend on openfisca-core to conda-forge. + + +## Automatic upload + +The CI automatically uploads the PyPi package; see the `.github/workflow.yml`, step `publish-to-conda`. + +## Manual actions for first time publishing + +- Create an account on https://anaconda.org. +- Create a token on https://anaconda.org/openfisca/settings/access with `Allow write access to the API site`. Warning, it expires on 2023/01/13. + +- Put the token in a CI environment variable named `ANACONDA_TOKEN`. + + +## Manual actions to test before CI + +Everything is done by the CI but if you want to test it locally, here is how to do it. + +Do the following in the project root folder: + +- Auto-update `.conda/meta.yaml` with last infos from pypi by running: + - `python .github/get_pypi_info.py -p OpenFisca-Core` + +- Build package: + - `conda install -c anaconda conda-build anaconda-client` (`conda-build` to build the package and [anaconda-client](https://github.com/Anaconda-Platform/anaconda-client) to push the package to anaconda.org) + - `conda build -c conda-forge .conda` + + - Upload the package to Anaconda.org, but DON'T do it if you don't want to publish your locally built package as official openfisca-core library: + - `anaconda login` + - `anaconda upload openfisca-core--py_0.tar.bz2` diff --git a/.conda/openfisca-core/conda_build_config.yaml b/.conda/openfisca-core/conda_build_config.yaml new file mode 100644 index 0000000000..02754f3894 --- /dev/null +++ b/.conda/openfisca-core/conda_build_config.yaml @@ -0,0 +1,9 @@ +numpy: +- 1.24 +- 1.25 +- 1.26 + +python: +- 3.9 +- 3.10 +- 3.11 diff --git a/.conda/openfisca-core/meta.yaml b/.conda/openfisca-core/meta.yaml new file mode 100644 index 0000000000..1c90e6191e --- /dev/null +++ b/.conda/openfisca-core/meta.yaml @@ -0,0 +1,89 @@ +############################################################################### +## File for Anaconda.org +## It use Jinja2 templating code to retreive information from setup.py +############################################################################### + +{% set name = "OpenFisca-Core" %} +{% set data = load_setup_py_data() %} +{% set version = data.get('version') %} + +package: + name: {{ name|lower }} + version: {{ version }} + +source: + path: ../.. + +build: + noarch: python + number: 0 + script: "{{ PYTHON }} -m pip install . -vv" + entry_points: + - openfisca = openfisca_core.scripts.openfisca_command:main + - openfisca-run-test = openfisca_core.scripts.openfisca_command:main + +requirements: + host: + - numpy + - pip + - python + - setuptools >=61.0 + run: + - numpy + - python + {% for req in data['install_requires'] %} + {% if not req.startswith('numpy') %} + - {{ req }} + {% endif %} + {% endfor %} + +test: + imports: + - openfisca_core + - openfisca_core.commons + +outputs: + - name: openfisca-core + type: conda_v2 + + - name: openfisca-core-api + type: conda_v2 + build: + noarch: python + requirements: + host: + - numpy + - python + run: + - numpy + - python + {% for req in data['extras_require']['web-api'] %} + - {{ req }} + {% endfor %} + - {{ pin_subpackage('openfisca-core', exact=True) }} + + - name: openfisca-core-dev + type: conda_v2 + build: + noarch: python + requirements: + host: + - numpy + - python + run: + - numpy + - python + {% for req in data['extras_require']['dev'] %} + - {{ req }} + {% endfor %} + - {{ pin_subpackage('openfisca-core-api', exact=True) }} + +about: + home: https://openfisca.org + license_family: AGPL + license: AGPL-3.0-only + license_file: LICENSE + summary: "A versatile microsimulation free software" + doc_url: https://openfisca.org + dev_url: https://github.com/openfisca/openfisca-core/ + description: This package contains the core features of OpenFisca, which are meant to be used by country packages such as OpenFisca-Country-Template. diff --git a/.conda/openfisca-country-template/recipe.yaml b/.conda/openfisca-country-template/recipe.yaml new file mode 100644 index 0000000000..871e591708 --- /dev/null +++ b/.conda/openfisca-country-template/recipe.yaml @@ -0,0 +1,38 @@ +schema_version: 1 + +context: + name: openfisca-country-template + version: 7.1.5 + +package: + name: ${{ name|lower }} + version: ${{ version }} + +source: + url: https://pypi.org/packages/source/${{ name[0] }}/${{ name }}/openfisca_country_template-${{ version }}.tar.gz + sha256: b2f2ac9945d9ccad467aed0925bd82f7f4d5ce4e96b212324cd071b8bee46914 + +build: + number: 2 + noarch: python + script: pip install . -v --no-deps + +requirements: + host: + - numpy + - pip + - python + - setuptools >=61.0 + run: + - numpy + - python + - openfisca-core >=42 + +about: + summary: OpenFisca Rules as Code model for Country-Template. + license: AGPL-3.0 + license_file: LICENSE + +extra: + recipe-maintainers: + - bonjourmauko diff --git a/.conda/openfisca-country-template/variants.yaml b/.conda/openfisca-country-template/variants.yaml new file mode 100644 index 0000000000..64e0aaf0f1 --- /dev/null +++ b/.conda/openfisca-country-template/variants.yaml @@ -0,0 +1,7 @@ +numpy: +- "1.26" + +python: +- "3.9" +- "3.10" +- "3.11" diff --git a/.conda/openfisca-extension-template/recipe.yaml b/.conda/openfisca-extension-template/recipe.yaml new file mode 100644 index 0000000000..c30e28cde7 --- /dev/null +++ b/.conda/openfisca-extension-template/recipe.yaml @@ -0,0 +1,39 @@ +schema_version: 1 + +context: + name: openfisca-extension-template + version: 1.3.15 + +package: + name: ${{ name|lower }} + version: ${{ version }} + +source: + url: https://pypi.org/packages/source/${{ name[0] }}/${{ name }}/openfisca_extension_template-${{ version }}.tar.gz + sha256: e16ee9cbefdd5e9ddc1c2c0e12bcd74307c8cb1be55353b3b2788d64a90a5df9 + +build: + number: 2 + noarch: python + script: pip install . -v --no-deps + +requirements: + host: + - numpy + - pip + - python + - setuptools >=61.0 + run: + - numpy + - python + - openfisca-country-template >=7.1.5 + +about: + summary: An OpenFisca extension that adds some variables to an already-existing + tax and benefit system. + license: AGPL-3.0 + license_file: LICENSE + +extra: + recipe-maintainers: + - bonjourmauko diff --git a/.conda/openfisca-extension-template/variants.yaml b/.conda/openfisca-extension-template/variants.yaml new file mode 100644 index 0000000000..64e0aaf0f1 --- /dev/null +++ b/.conda/openfisca-extension-template/variants.yaml @@ -0,0 +1,7 @@ +numpy: +- "1.26" + +python: +- "3.9" +- "3.10" +- "3.11" diff --git a/.conda/pylint-per-file-ignores/recipe.yaml b/.conda/pylint-per-file-ignores/recipe.yaml new file mode 100644 index 0000000000..4a573982f8 --- /dev/null +++ b/.conda/pylint-per-file-ignores/recipe.yaml @@ -0,0 +1,41 @@ +schema_version: 1 + +context: + name: pylint-per-file-ignores + version: 1.3.2 + +package: + name: ${{ name|lower }} + version: ${{ version }} + +source: + url: https://pypi.org/packages/source/${{ name[0] }}/${{ name }}/pylint_per_file_ignores-${{ version }}.tar.gz + sha256: 3c641f69c316770749a8a353556504dae7469541cdaef38e195fe2228841451e + +build: + noarch: python + script: pip install . -v + +requirements: + host: + - python + - poetry-core >=1.0.0 + - pip + run: + - pylint >=3.3.1,<4.0 + - python + - tomli >=2.0.1,<3.0.0 + +tests: +- python: + imports: + - pylint_per_file_ignores + +about: + summary: A pylint plugin to ignore error codes per file. + license: MIT + homepage: https://github.com/christopherpickering/pylint-per-file-ignores.git + +extra: + recipe-maintainers: + - bonjourmauko diff --git a/.conda/pylint-per-file-ignores/variants.yaml b/.conda/pylint-per-file-ignores/variants.yaml new file mode 100644 index 0000000000..ab419e422e --- /dev/null +++ b/.conda/pylint-per-file-ignores/variants.yaml @@ -0,0 +1,4 @@ +python: +- "3.9" +- "3.10" +- "3.11" diff --git a/.github/dependabot.yml b/.github/dependabot.yml index fcb2acc162..71eaf02d67 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,7 +1,7 @@ version: 2 updates: - package-ecosystem: pip - directory: "/" + directory: / schedule: interval: monthly labels: diff --git a/.github/get_pypi_info.py b/.github/get_pypi_info.py new file mode 100644 index 0000000000..70013fbe98 --- /dev/null +++ b/.github/get_pypi_info.py @@ -0,0 +1,83 @@ +"""Script to get information needed by .conda/meta.yaml from PyPi JSON API. + +This script use get_info to get the info (yes !) and replace_in_file to +write them into .conda/meta.yaml. +Sample call: +python3 .github/get_pypi_info.py -p OpenFisca-Core +""" + +import argparse + +import requests + + +def get_info(package_name: str = "") -> dict: + """Get minimal information needed by .conda/meta.yaml from PyPi JSON API. + + ::package_name:: Name of package to get infos from. + ::return:: A dict with last_version, url and sha256 + """ + if package_name == "": + msg = "Package name not provided." + raise ValueError(msg) + url = f"https://pypi.org/pypi/{package_name}/json" + print(f"Calling {url}") # noqa: T201 + resp = requests.get(url) + if resp.status_code != 200: + msg = f"ERROR calling PyPI ({url}) : {resp}" + raise Exception(msg) + resp = resp.json() + version = resp["info"]["version"] + + for v in resp["releases"][version]: + # Find packagetype=="sdist" to get source code in .tar.gz + if v["packagetype"] == "sdist": + return { + "last_version": version, + "url": v["url"], + "sha256": v["digests"]["sha256"], + } + return {} + + +def replace_in_file(filepath: str, info: dict) -> None: + """Replace placeholder in meta.yaml by their values. + + ::filepath:: Path to meta.yaml, with filename. + ::info:: Dict with information to populate. + """ + with open(filepath, encoding="utf-8") as fin: + meta = fin.read() + # Replace with info from PyPi + meta = meta.replace("PYPI_VERSION", info["last_version"]) + meta = meta.replace("PYPI_URL", info["url"]) + meta = meta.replace("PYPI_SHA256", info["sha256"]) + with open(filepath, "w", encoding="utf-8") as fout: + fout.write(meta) + print(f"File {filepath} has been updated with info from PyPi.") # noqa: T201 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-p", + "--package", + type=str, + default="", + required=True, + help="The name of the package", + ) + parser.add_argument( + "-f", + "--filename", + type=str, + default=".conda/openfisca-core/meta.yaml", + help="Path to meta.yaml, with filename", + ) + args = parser.parse_args() + info = get_info(args.package) + print( # noqa: T201 + "Information of the last published PyPi package :", + info["last_version"], + ) + replace_in_file(args.filename, info) diff --git a/.circleci/has-functional-changes.sh b/.github/has-functional-changes.sh similarity index 89% rename from .circleci/has-functional-changes.sh rename to .github/has-functional-changes.sh index b591716932..bf1270989a 100755 --- a/.circleci/has-functional-changes.sh +++ b/.github/has-functional-changes.sh @@ -1,6 +1,6 @@ #! /usr/bin/env bash -IGNORE_DIFF_ON="README.md CONTRIBUTING.md Makefile .gitignore LICENSE* .circleci/* .github/* openfisca_tasks/*.mk tasks/*.mk tests/*" +IGNORE_DIFF_ON="README.md CONTRIBUTING.md Makefile .gitignore LICENSE* .github/* tests/* openfisca_tasks/*.mk tasks/*.mk" last_tagged_commit=`git describe --tags --abbrev=0 --first-parent` # --first-parent ensures we don't follow tags not published in master through an unlikely intermediary merge commit diff --git a/.circleci/is-version-number-acceptable.sh b/.github/is-version-number-acceptable.sh similarity index 95% rename from .circleci/is-version-number-acceptable.sh rename to .github/is-version-number-acceptable.sh index ae370e2a17..0f704a93fe 100755 --- a/.circleci/is-version-number-acceptable.sh +++ b/.github/is-version-number-acceptable.sh @@ -1,6 +1,6 @@ #! /usr/bin/env bash -if [[ $CIRCLE_BRANCH == master ]] +if [[ ${GITHUB_REF#refs/heads/} == master ]] then echo "No need for a version check on master." exit 0 diff --git a/.github/workflows/_before-conda.yaml b/.github/workflows/_before-conda.yaml new file mode 100644 index 0000000000..06d0067eff --- /dev/null +++ b/.github/workflows/_before-conda.yaml @@ -0,0 +1,109 @@ +name: Setup conda + +on: + workflow_call: + inputs: + os: + required: true + type: string + + numpy: + required: true + type: string + + python: + required: true + type: string + +defaults: + run: + shell: bash -l {0} + +jobs: + setup: + runs-on: ${{ inputs.os }} + name: conda-setup-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }} + env: + # To colorize output of make tasks. + TERM: xterm-256color + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Cache conda env + uses: actions/cache@v4 + with: + path: | + /usr/share/miniconda/envs/openfisca + ~/.conda/envs/openfisca + .env.yaml + key: conda-env-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + restore-keys: conda-env-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}- + id: cache-env + + - name: Cache conda deps + uses: actions/cache@v4 + with: + path: ~/conda_pkgs_dir + key: conda-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + restore-keys: conda-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}- + id: cache-deps + + - name: Cache release + uses: actions/cache@v4 + with: + path: ~/conda-rel + key: conda-release-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Setup conda + uses: conda-incubator/setup-miniconda@v3 + with: + activate-environment: openfisca + miniforge-version: latest + python-version: ${{ inputs.python }} + use-mamba: true + if: steps.cache-env.outputs.cache-hit != 'true' + + - name: Install dependencies + run: mamba install boa rattler-build + if: steps.cache-env.outputs.cache-hit != 'true' + + - name: Update conda & dependencies + uses: conda-incubator/setup-miniconda@v3 + with: + activate-environment: openfisca + environment-file: .env.yaml + miniforge-version: latest + use-mamba: true + if: steps.cache-env.outputs.cache-hit == 'true' + + - name: Build pylint plugin package + run: | + rattler-build build \ + --recipe .conda/pylint-per-file-ignores \ + --output-dir ~/conda-rel + + - name: Build core package + run: | + conda mambabuild .conda/openfisca-core \ + --use-local \ + --no-anaconda-upload \ + --output-folder ~/conda-rel \ + --numpy ${{ inputs.numpy }} \ + --python ${{ inputs.python }} + + - name: Build country template package + run: | + rattler-build build \ + --recipe .conda/openfisca-country-template \ + --output-dir ~/conda-rel \ + + - name: Build extension template package + run: | + rattler-build build \ + --recipe .conda/openfisca-extension-template \ + --output-dir ~/conda-rel + + - name: Export env + run: mamba env export --name openfisca > .env.yaml diff --git a/.github/workflows/_before-pip.yaml b/.github/workflows/_before-pip.yaml new file mode 100644 index 0000000000..02554419c8 --- /dev/null +++ b/.github/workflows/_before-pip.yaml @@ -0,0 +1,103 @@ +name: Setup package + +on: + workflow_call: + inputs: + os: + required: true + type: string + + numpy: + required: true + type: string + + python: + required: true + type: string + + activate_command: + required: true + type: string + +jobs: + deps: + runs-on: ${{ inputs.os }} + name: pip-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }} + env: + # To colorize output of make tasks. + TERM: xterm-256color + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python }} + + - name: Use zstd for faster cache restore (windows) + if: ${{ startsWith(inputs.os, 'windows') }} + shell: cmd + run: echo C:\Program Files\Git\usr\bin>>"%GITHUB_PATH%" + + - name: Cache dependencies + uses: actions/cache@v4 + with: + path: venv + key: pip-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + restore-keys: pip-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}- + + - name: Install dependencies + run: | + python -m venv venv + ${{ inputs.activate_command }} + make install-deps install-dist + pip install numpy==${{ inputs.numpy }} + + build: + runs-on: ${{ inputs.os }} + needs: [deps] + name: pip-build-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }} + env: + TERM: xterm-256color + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python }} + + - name: Use zstd for faster cache restore (windows) + if: ${{ startsWith(inputs.os, 'windows') }} + shell: cmd + run: echo C:\Program Files\Git\usr\bin>>"%GITHUB_PATH%" + + - name: Cache dependencies + uses: actions/cache@v4 + with: + path: venv + key: pip-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + + - name: Cache build + uses: actions/cache@v4 + with: + path: venv/**/[Oo]pen[Ff]isca* + key: pip-build-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }}-${{ github.sha }} + restore-keys: | + pip-build-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }}- + pip-build-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}- + + - name: Cache release + uses: actions/cache@v4 + with: + path: dist + key: pip-release-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Build package + run: | + ${{ inputs.activate_command }} + make install-test clean build diff --git a/.github/workflows/_lint-pip.yaml b/.github/workflows/_lint-pip.yaml new file mode 100644 index 0000000000..e994f473e3 --- /dev/null +++ b/.github/workflows/_lint-pip.yaml @@ -0,0 +1,57 @@ +name: Lint package + +on: + workflow_call: + inputs: + os: + required: true + type: string + + numpy: + required: true + type: string + + python: + required: true + type: string + + activate_command: + required: true + type: string + +jobs: + lint: + runs-on: ${{ inputs.os }} + name: pip-lint-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }} + env: + TERM: xterm-256color # To colorize output of make tasks. + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python }} + + - name: Use zstd for faster cache restore (windows) + if: ${{ startsWith(inputs.os, 'windows') }} + shell: cmd + run: echo C:\Program Files\Git\usr\bin>>"%GITHUB_PATH%" + + - name: Cache dependencies + uses: actions/cache@v4 + with: + path: venv + key: pip-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + + - name: Lint doc + run: | + ${{ inputs.activate_command }} + make clean check-syntax-errors lint-doc + + - name: Lint styles + run: | + ${{ inputs.activate_command }} + make clean check-syntax-errors check-style diff --git a/.github/workflows/_test-conda.yaml b/.github/workflows/_test-conda.yaml new file mode 100644 index 0000000000..fab88ac1df --- /dev/null +++ b/.github/workflows/_test-conda.yaml @@ -0,0 +1,76 @@ +name: Test conda package + +on: + workflow_call: + inputs: + os: + required: true + type: string + + numpy: + required: true + type: string + + python: + required: true + type: string + +defaults: + run: + shell: bash -l {0} + +jobs: + test: + runs-on: ${{ inputs.os }} + name: conda-test-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }} + env: + TERM: xterm-256color # To colorize output of make tasks. + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Cache conda env + uses: actions/cache@v4 + with: + path: | + /usr/share/miniconda/envs/openfisca + ~/.conda/envs/openfisca + .env.yaml + key: conda-env-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + + - name: Cache conda deps + uses: actions/cache@v4 + with: + path: ~/conda_pkgs_dir + key: conda-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + + - name: Cache release + uses: actions/cache@v4 + with: + path: ~/conda-rel + key: conda-release-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Update conda & dependencies + uses: conda-incubator/setup-miniconda@v3 + with: + activate-environment: openfisca + environment-file: .env.yaml + miniforge-version: latest + use-mamba: true + + - name: Install packages + run: | + mamba install --channel file:///home/runner/conda-rel \ + openfisca-core-dev \ + openfisca-country-template \ + openfisca-extension-template + + - name: Run core tests + run: make test-core + + - name: Run country tests + run: make test-country + + - name: Run extension tests + run: make test-extension diff --git a/.github/workflows/_test-pip.yaml b/.github/workflows/_test-pip.yaml new file mode 100644 index 0000000000..e2db77ac3d --- /dev/null +++ b/.github/workflows/_test-pip.yaml @@ -0,0 +1,71 @@ +name: Test package + +on: + workflow_call: + inputs: + os: + required: true + type: string + + numpy: + required: true + type: string + + python: + required: true + type: string + + activate_command: + required: true + type: string + +jobs: + test: + runs-on: ${{ inputs.os }} + name: pip-test-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TERM: xterm-256color # To colorize output of make tasks. + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python }} + + - name: Use zstd for faster cache restore (windows) + if: ${{ startsWith(inputs.os, 'windows') }} + shell: cmd + run: echo C:\Program Files\Git\usr\bin>>"%GITHUB_PATH%" + + - name: Cache dependencies + uses: actions/cache@v4 + with: + path: venv + key: pip-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + + - name: Cache build + uses: actions/cache@v4 + with: + path: venv/**/[Oo]pen[Ff]isca* + key: pip-build-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Run Openfisca Core tests + run: | + ${{ inputs.activate_command }} + make test-core + + - name: Run Country Template tests + if: ${{ startsWith(inputs.os, 'ubuntu') }} + run: | + ${{ inputs.activate_command }} + make test-country + + - name: Run Extension Template tests + if: ${{ startsWith(inputs.os, 'ubuntu') }} + run: | + ${{ inputs.activate_command }} + make test-extension diff --git a/.github/workflows/_version.yaml b/.github/workflows/_version.yaml new file mode 100644 index 0000000000..27c4737a4f --- /dev/null +++ b/.github/workflows/_version.yaml @@ -0,0 +1,38 @@ +name: Check version + +on: + workflow_call: + inputs: + os: + required: true + type: string + + python: + required: true + type: string + +jobs: + # The idea behind these dependencies is that we want to give feedback to + # contributors on the version number only after they have passed all tests, + # so they don't have to do it twice after changes happened to the main branch + # during the time they took to fix the tests. + check-version: + runs-on: ${{ inputs.os }} + env: + # To colorize output of make tasks. + TERM: xterm-256color + + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + # Fetch all the tags + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python }} + + - name: Check version number has been properly updated + run: ${GITHUB_WORKSPACE}/.github/is-version-number-acceptable.sh diff --git a/.github/workflows/merge.yaml b/.github/workflows/merge.yaml new file mode 100644 index 0000000000..31e863a96b --- /dev/null +++ b/.github/workflows/merge.yaml @@ -0,0 +1,250 @@ +name: OpenFisca-Core / Deploy package to PyPi & Conda + +on: + push: + branches: [master] + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: true + +jobs: + setup-pip: + strategy: + fail-fast: true + matrix: + os: [ubuntu-22.04, windows-2019] + numpy: [1.26.4, 1.24.2] + # Patch version must be specified to avoid any cache confusion, since + # the cache key depends on the full Python version. If left unspecified, + # different patch versions could be allocated between jobs, and any + # such difference would lead to a cache not found error. + python: [3.11.9, 3.9.13] + include: + - os: ubuntu-22.04 + activate_command: source venv/bin/activate + - os: windows-2019 + activate_command: .\venv\Scripts\activate + uses: ./.github/workflows/_before-pip.yaml + with: + os: ${{ matrix.os }} + numpy: ${{ matrix.numpy }} + python: ${{ matrix.python }} + activate_command: ${{ matrix.activate_command }} + + setup-conda: + uses: ./.github/workflows/_before-conda.yaml + with: + os: ubuntu-22.04 + numpy: 1.26.4 + python: 3.10.6 + + test-pip: + needs: [setup-pip] + strategy: + fail-fast: true + matrix: + os: [ubuntu-22.04, windows-2019] + numpy: [1.26.4, 1.24.2] + python: [3.11.9, 3.9.13] + include: + - os: ubuntu-22.04 + activate_command: source venv/bin/activate + - os: windows-2019 + activate_command: .\venv\Scripts\activate + uses: ./.github/workflows/_test-pip.yaml + with: + os: ${{ matrix.os }} + numpy: ${{ matrix.numpy }} + python: ${{ matrix.python }} + activate_command: ${{ matrix.activate_command }} + + test-conda: + uses: ./.github/workflows/_test-conda.yaml + needs: [setup-conda] + with: + os: ubuntu-22.04 + numpy: 1.26.4 + python: 3.10.6 + + lint-pip: + needs: [setup-pip] + strategy: + fail-fast: true + matrix: + numpy: [1.24.2] + python: [3.11.9, 3.9.13] + uses: ./.github/workflows/_lint-pip.yaml + with: + os: ubuntu-22.04 + numpy: ${{ matrix.numpy }} + python: ${{ matrix.python }} + activate_command: source venv/bin/activate + + check-version: + needs: [test-pip, test-conda, lint-pip] + uses: ./.github/workflows/_version.yaml + with: + os: ubuntu-22.04 + python: 3.9.13 + + # GitHub Actions does not have a halt job option, to stop from deploying if + # no functional changes were found. We build a separate job to substitute the + # halt option. The `deploy` job is dependent on the output of the + # `check-for-functional-changes`job. + check-for-functional-changes: + runs-on: ubuntu-22.04 + # Last job to run + needs: [check-version] + outputs: + status: ${{ steps.stop-early.outputs.status }} + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.9.13 + + - id: stop-early + # The `check-for-functional-changes` job should always succeed regardless + # of the `has-functional-changes` script's exit code. Consequently, we do + # not use that exit code to trigger deploy, but rather a dedicated output + # variable `status`, to avoid a job failure if the exit code is different + # from 0. Conversely, if the job fails the entire workflow would be + # marked as `failed` which is disturbing for contributors. + run: if "${GITHUB_WORKSPACE}/.github/has-functional-changes.sh" ; then echo + "::set-output name=status::success" ; fi + + publish-to-pypi: + runs-on: ubuntu-22.04 + needs: [check-for-functional-changes] + if: needs.check-for-functional-changes.outputs.status == 'success' + env: + PYPI_USERNAME: __token__ + PYPI_TOKEN: ${{ secrets.PYPI_TOKEN_OPENFISCA_BOT }} + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.9.13 + + - name: Cache deps + uses: actions/cache@v4 + with: + path: venv + key: pip-deps-ubuntu-22.04-np1.24.2-py3.9.13-${{ hashFiles('setup.py') }} + + - name: Cache build + uses: actions/cache@v4 + with: + path: venv/**/[oO]pen[fF]isca* + key: pip-build-ubuntu-22.04-np1.24.2-py3.9.13-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Cache release + uses: actions/cache@v4 + with: + path: dist + key: pip-release-ubuntu-22.04-np1.24.2-py3.9.13-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Upload package to PyPi + run: | + source venv/bin/activate + twine upload dist/* --username $PYPI_USERNAME --password $PYPI_TOKEN + + - name: Update version + run: | + source venv/bin/activate + git tag `python setup.py --version` + git push --tags # update the repository version + + - name: Update doc + run: | + curl -L \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.OPENFISCADOC_BOT_ACCESS_TOKEN }}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + https://api.github.com/repos/openfisca/openfisca-doc/actions/workflows/deploy.yaml/dispatches \ + -d '{"ref":"main"}' + + publish-to-conda: + runs-on: ubuntu-22.04 + needs: [publish-to-pypi] + defaults: + run: + shell: bash -l {0} + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Cache conda env + uses: actions/cache@v4 + with: + path: | + /usr/share/miniconda/envs/openfisca + ~/.conda/envs/openfisca + .env.yaml + key: conda-env-ubuntu-22.04-np1.26.4-py3.10.6-${{ hashFiles('setup.py') }} + + - name: Cache conda deps + uses: actions/cache@v4 + with: + path: ~/conda_pkgs_dir + key: conda-deps-ubuntu-22.04-np1.26.4-py3.10.6-${{ hashFiles('setup.py') }} + + - name: Cache release + uses: actions/cache@v4 + with: + path: ~/conda-rel + key: conda-release-ubuntu-22.04-np1.26.4-py3.10.6-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Update conda & dependencies + uses: conda-incubator/setup-miniconda@v3 + with: + activate-environment: openfisca + environment-file: .env.yaml + miniforge-version: latest + use-mamba: true + + - name: Publish to conda + run: | + rattler-build upload anaconda ~/conda-rel/noarch/*.conda \ + --force \ + --owner openfisca \ + --api-key ${{ secrets.ANACONDA_TOKEN }} + + test-on-windows: + runs-on: windows-2019 + needs: [publish-to-conda] + defaults: + run: + shell: bash -l {0} + + steps: + - uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: true + # See GHA Windows + # https://raw.githubusercontent.com/actions/python-versions/main/versions-manifest.json + python-version: 3.10.6 + channels: conda-forge + activate-environment: true + + - name: Checkout + uses: actions/checkout@v4 + + - name: Install with conda + run: conda install -c openfisca openfisca-core + + - name: Test openfisca + run: openfisca --help diff --git a/.github/workflows/push.yaml b/.github/workflows/push.yaml new file mode 100644 index 0000000000..7bee48c81c --- /dev/null +++ b/.github/workflows/push.yaml @@ -0,0 +1,89 @@ +name: OpenFisca-Core / Pull request review + +on: + pull_request: + types: [assigned, opened, reopened, synchronize, ready_for_review] + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: true + +jobs: + setup-pip: + strategy: + fail-fast: true + matrix: + os: [ubuntu-22.04, windows-2019] + numpy: [1.26.4, 1.24.2] + # Patch version must be specified to avoid any cache confusion, since + # the cache key depends on the full Python version. If left unspecified, + # different patch versions could be allocated between jobs, and any + # such difference would lead to a cache not found error. + python: [3.11.9, 3.9.13] + include: + - os: ubuntu-22.04 + activate_command: source venv/bin/activate + - os: windows-2019 + activate_command: .\venv\Scripts\activate + uses: ./.github/workflows/_before-pip.yaml + with: + os: ${{ matrix.os }} + numpy: ${{ matrix.numpy }} + python: ${{ matrix.python }} + activate_command: ${{ matrix.activate_command }} + + setup-conda: + uses: ./.github/workflows/_before-conda.yaml + with: + os: ubuntu-22.04 + numpy: 1.26.4 + python: 3.10.6 + + test-pip: + needs: [setup-pip] + strategy: + fail-fast: true + matrix: + os: [ubuntu-22.04, windows-2019] + numpy: [1.26.4, 1.24.2] + python: [3.11.9, 3.9.13] + include: + - os: ubuntu-22.04 + activate_command: source venv/bin/activate + - os: windows-2019 + activate_command: .\venv\Scripts\activate + uses: ./.github/workflows/_test-pip.yaml + with: + os: ${{ matrix.os }} + numpy: ${{ matrix.numpy }} + python: ${{ matrix.python }} + activate_command: ${{ matrix.activate_command }} + + test-conda: + uses: ./.github/workflows/_test-conda.yaml + needs: [setup-conda] + with: + os: ubuntu-22.04 + numpy: 1.26.4 + python: 3.10.6 + + lint-pip: + needs: [setup-pip] + strategy: + fail-fast: true + matrix: + numpy: [1.24.2] + python: [3.11.9, 3.9.13] + uses: ./.github/workflows/_lint-pip.yaml + with: + os: ubuntu-22.04 + numpy: ${{ matrix.numpy }} + python: ${{ matrix.python }} + activate_command: source venv/bin/activate + + check-version: + needs: [test-pip, test-conda, lint-pip] + uses: ./.github/workflows/_version.yaml + with: + os: ubuntu-22.04 + python: 3.9.13 diff --git a/.gitignore b/.gitignore index c66d2bd194..cdf7204715 100644 --- a/.gitignore +++ b/.gitignore @@ -2,8 +2,8 @@ *.mo *.pyc *~ -.coverage .mypy_cache +.mypy_cache* .noseids .project .pydevproject @@ -13,7 +13,6 @@ .tags* .venv .vscode -.vscode build cover dist diff --git a/CHANGELOG.md b/CHANGELOG.md index 502ca99d51..e746ad57a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,60 +1,14 @@ # Changelog -# 36.0.0 [#1015](https://github.com/openfisca/openfisca-core/pull/1015) - -#### Technical changes - -- Extract requirements to separate files for easier contract enforcement. -- Add explicit contract regarding supported dependencies. -- Add constraint file to test against lower-bound NumPy. -- Add extra dependencies. - - Add coveralls (latest) to extra requirements. - - Add twine (latest) to extra requirements. - - Add wheel (latest) to extra requirements. -- Pin non-distribution dependencies. - - Pin autopep8 at latest. - - Pin flake8 at latest. - - Pin flake8-bugbear at latest. - - Pin flake8-print at latest. - - Pin pytest-cov at latest. - - Pin mypy at latest. - - Pin flask at 1.1.2. - - Pin gunicorn at 20.1.0. - - Pin flask-cors at 3.0.10. - - Pin werkzeug at 1.0.1. -- Relax distrubution dependencies. - - Set dpath at >= 1.3.2, < 2. - - Set psutil at >= 5.4.7, < 6. - - Set sortedcontainers at >= 2, < 3. -- Relax circular dependencies. - - Relax openfisca-country-template. - - Relax openfisca-extension-template. +# 44.0.0 [#1015](https://github.com/openfisca/openfisca-core/pull/1015) -#### Breaking changes +#### Technical changes + +- Add `pyproject.toml` to the repository. +- Add `poetry` to manage dependencies. +- Add `tox` to manage isolated tests. -- Drop support for Python < 3.7. - - Python 3.7 [introduces backwards incompatible syntax changes](https://docs.python.org/3/whatsnew/3.7.html) that might be used in your country models. -- Drop support for numexpr < 2.7.1. - - numexpr 2.7.1 [introduces no breaking changes](https://numexpr.readthedocs.io/projects/NumExpr3/en/latest/release_notes.html#changes-from-2-7-0-to-2-7-1). -- Drop support for NumPy < 1.17 - - NumPy 1.12 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.12.0-notes.html#compatibility-notes) that might be used in your country models. - - NumPy 1.13 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.13.0-notes.html#compatibility-notes) that might be used in your country models. - - NumPy 1.14 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.14.0-notes.html#compatibility-notes) that might be used in your country models. - - NumPy 1.15 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.15.0-notes.html#compatibility-notes) that might be used in your country models. - - NumPy 1.16 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.16.0-notes.html#expired-deprecations) that might be used in your country models. - - NumPy 1.17 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.17.0-notes.html#compatibility-notes) that might be used in your country models. -- Drop support for pytest < 5.4.2. - - pytest 5 [introduces a list of removals and deprecations](https://docs.pytest.org/en/stable/changelog.html#pytest-5-0-0-2019-06-28) that might be used in your country models. - - pytest 5.1 [introduces a list of removals](https://docs.pytest.org/en/stable/changelog.html#pytest-5-1-0-2019-08-15) that might be used in your country models. - - pytest 5.2 [introduces a list of deprecations](https://docs.pytest.org/en/stable/changelog.html#pytest-5-2-0-2019-09-28) that might be used in your country models. - - pytest 5.3 [introduces a list of deprecations](https://docs.pytest.org/en/stable/changelog.html#pytest-5-3-0-2019-11-19) that might be used in your country models. - - pytest 5.4 [introduces a list of breaking changes and deprecations](https://docs.pytest.org/en/stable/changelog.html#pytest-5-3-0-2019-11-19) that might be used in your country models. - - pytest 5.4.1 [introduces no breaking changes](https://docs.pytest.org/en/stable/changelog.html#pytest-5-4-1-2020-03-13). - - pytest 5.4.2 [introduces no breaking changes](https://docs.pytest.org/en/stable/changelog.html#pytest-5-4-2-2020-05-08). -- Drop support for PyYAML < 5.1. - - PyYAML 5.1 [introduces some breaking changes](https://github.com/yaml/pyyaml/blob/ee37f4653c08fc07aecff69cfd92848e6b1a540e/CHANGES#L66-L97) that might be used in your country models. - -#### Expired deprecations +#### Deprecations - `openfisca_core.commons.Dummy` => `openfisca_core.commons.empty_clone` - `openfisca_core.errors.ParameterNotFound` => `openfisca_core.errors.ParameterNotFoundError` @@ -75,6 +29,653 @@ - `openfisca_core.taxbenefitsystems.VariableNotFound` => `openfisca_core.errors.VariableNotFoundError` - `openfisca_core.taxscales.EmptyArgumentError` => `openfisca_core.errors.EmptyArgumentError` +### 43.2.2 [#1280](https://github.com/openfisca/openfisca-core/pull/1280) + +#### Documentation + +- Add types to common tracers (`SimpleTracer`, `FlatTracer`, etc.) + +### 43.2.1 [#1283](https://github.com/openfisca/openfisca-core/pull/1283) + +#### Technical changes + +- Remove `coveralls` + +## 43.2.0 [#1279](https://github.com/openfisca/openfisca-core/pull/1279) + +#### New features + +- Introduce `populations.CorePopulation` + - Allows for testing and better subclassing custom populations + +### 43.1.2 [#1274](https://github.com/openfisca/openfisca-core/pull/1275) + +#### Documentation + +- Add docs to experimental + +### 43.1.1 [#1282](https://github.com/openfisca/openfisca-core/pull/1282) + +#### Technical changes + +- Add check to spot common spelling mistakes + +## 43.1.0 [#1255](https://github.com/openfisca/openfisca-core/pull/1255) + +- Make `CoreEntity` public + - Allows for more easily creating customised entities. + +#### Technical changes + +- Add missing doctests. + +# 43.0.0 [#1224](https://github.com/openfisca/openfisca-core/pull/1224) + +#### Technical changes + +- Add documentation to the `indexed_enums` module +- Fix type definitions in the enums module +- Fix doctests +- Fix bug in `Enum.encode` when passing a scalar +- Fix bug in `Enum.encode` when encoding values not present in the enum + +#### New features + +- Introduce `indexed_enums.EnumType` + - Allows for actually fancy indexing `indexed_enums.Enum` + +#### Note + +This changeset has not breaking changes to the `indexed_enums` public API. +However, as a conservative measure concerning data preparation for large +population simulations, it has been marked as a major release. + +##### Before + +```python +from openfisca_core import indexed_enums as enum + +class TestEnum(enum.Enum): + ONE = "one" + TWO = "two" + +TestEnum.encode([2]) +# EnumArray([0]) +``` + +##### After + +```python +from openfisca_core import indexed_enums as enum + +class TestEnum(enum.Enum): + ONE = "one" + TWO = "two" + +TestEnum.encode([2]) +# EnumArray([]) + +TestEnum.encode([0,1,2,5]) +# EnumArray([ ]) +``` + +### 42.0.7 [#1264](https://github.com/openfisca/openfisca-core/pull/1264) + +#### Technical changes + +- Add typing to `data_storage` module + +### 42.0.6 [#1263](https://github.com/openfisca/openfisca-core/pull/1263) + +#### Documentation + +- Fix docs of the `data_storage` module + +### 42.0.5 [#1261](https://github.com/openfisca/openfisca-core/pull/1261) + +#### Technical changes + +- Fix doctests of `data_storage` module + +### 42.0.4 [#1257](https://github.com/openfisca/openfisca-core/pull/1257) + +#### Technical changes + +- Fix conda test and publish +- Add matrix testing to CI + - Now it tests lower and upper bounds of python and numpy versions + +### 42.0.3 [#1234](https://github.com/openfisca/openfisca-core/pull/1234) + +#### Technical changes + +- Add matrix testing to CI + - Now it tests lower and upper bounds of python and numpy versions + +> Note: Version `42.0.3` has been unpublished as was deployed by mistake. +> Please use versions `42.0.4` and subsequents. + +### 42.0.2 [#1256](https://github.com/openfisca/openfisca-core/pull/1256) + +#### Documentation + +- Fix bad indent + +### 42.0.1 [#1253](https://github.com/openfisca/openfisca-core/pull/1253) + +#### Documentation + +- Fix documentation of `entities` + +# 42.0.0 [#1223](https://github.com/openfisca/openfisca-core/pull/1223) + +#### Breaking changes + +- Changes to `eternity` instants and periods + - Eternity instants are now `` instead of + `` + - Eternity periods are now `, -1))>` + instead of `, inf))>` + - The reason is to avoid mixing data types: `inf` is a float, periods and + instants are integers. Mixed data types make memory optimisations impossible. + - Migration should be straightforward. If you have a test that checks for + `inf`, you should update it to check for `-1` or use the `is_eternal` method. +- `periods.instant` no longer returns `None` + - Now, it raises `periods.InstantError` + +#### New features + +- Introduce `Instant.eternity()` + - This behaviour was duplicated across + - Now it is encapsulated in a single method +- Introduce `Instant.is_eternal` and `Period.is_eternal` + - These methods check if the instant or period are eternity (`bool`). +- Now `periods.instant` parses also ISO calendar strings (weeks) + - For instance, `2022-W01` is now a valid input + +#### Technical changes + +- Update `pendulum` +- Reduce code complexity +- Remove run-time type-checks +- Add typing to the periods module + +### 41.5.7 [#1225](https://github.com/openfisca/openfisca-core/pull/1225) + +#### Technical changes + +- Refactor & test `eval_expression` + +### 41.5.6 [#1185](https://github.com/openfisca/openfisca-core/pull/1185) + +#### Technical changes + +- Remove pre Python 3.9 syntax. + +### 41.5.5 [#1220](https://github.com/openfisca/openfisca-core/pull/1220) + +#### Technical changes + +- Fix doc & type definitions in the entities module + +### 41.5.4 [#1219](https://github.com/openfisca/openfisca-core/pull/1219) + +#### Technical changes + +- Fix doc & type definitions in the commons module + +### 41.5.3 [#1218](https://github.com/openfisca/openfisca-core/pull/1218) + +#### Technical changes + +- Fix `flake8` doc linting: + - Add format "google" + - Fix per-file skips +- Fix failing lints + +### 41.5.2 [#1217](https://github.com/openfisca/openfisca-core/pull/1217) + +#### Technical changes + +- Fix styles by applying `isort`. +- Add a `isort` dry-run check to `make lint` + +### 41.5.1 [#1216](https://github.com/openfisca/openfisca-core/pull/1216) + +#### Technical changes + +- Fix styles by applying `black`. +- Add a `black` dry-run check to `make lint` + +## 41.5.0 [#1212](https://github.com/openfisca/openfisca-core/pull/1212) + +#### New features + +- Introduce `VectorialAsofDateParameterNodeAtInstant` + - It is a parameter node of the legislation at a given instant which has been vectorized along some date. + - Vectorized parameters allow requests such as parameters.housing_benefit[date], where date is a `numpy.datetime64` vector + +### 41.4.7 [#1211](https://github.com/openfisca/openfisca-core/pull/1211) + +#### Technical changes + +- Update documentation continuous deployment method to reflect OpenFisca-Doc [process updates](https://github.com/openfisca/openfisca-doc/pull/308) + +### 41.4.6 [#1210](https://github.com/openfisca/openfisca-core/pull/1210) + +#### Technical changes + +- Abide by OpenAPI v3.0.0 instead of v3.1.0 + - Drop support for `propertyNames` in `Values` definition + +### 41.4.5 [#1209](https://github.com/openfisca/openfisca-core/pull/1209) + +#### Technical changes + +- Support loading metadata from both `setup.py` and `pyproject.toml` package description files. + +### ~41.4.4~ [#1208](https://github.com/openfisca/openfisca-core/pull/1208) + +_Unpublished due to introduced backwards incompatibilities._ + +#### Technical changes + +- Adapt testing pipeline to Country Template [v7](https://github.com/openfisca/country-template/pull/139). + +### 41.4.3 [#1206](https://github.com/openfisca/openfisca-core/pull/1206) + +#### Technical changes + +- Increase spiral and cycle tests robustness. + - The current test is ambiguous, as it hides a failure at the first spiral + occurrence (from 2017 to 2016). + +### 41.4.2 [#1203](https://github.com/openfisca/openfisca-core/pull/1203) + +#### Technical changes + +- Changes the Pypi's deployment authentication way to use token API following Pypi's 2FA enforcement starting 2024/01/01. + +### 41.4.1 [#1202](https://github.com/openfisca/openfisca-core/pull/1202) + +#### Technical changes + +- Check that entities are fully specified when expanding over axes. + +## 41.4.0 [#1197](https://github.com/openfisca/openfisca-core/pull/1197) + +#### New features + +- Add `entities.find_role()` to find roles by key and `max`. + +#### Technical changes + +- Document `projectors.get_projector_from_shortcut()`. + +## 41.3.0 [#1200](https://github.com/openfisca/openfisca-core/pull/1200) + +> As `TracingParameterNodeAtInstant` is a wrapper for `ParameterNodeAtInstant` +> which allows iteration and the use of `contains`, it was not possible +> to use those on a `TracingParameterNodeAtInstant` + +#### New features + +- Allows iterations on `TracingParameterNodeAtInstant` +- Allows keyword `contains` on `TracingParameterNodeAtInstant` + +## 41.2.0 [#1199](https://github.com/openfisca/openfisca-core/pull/1199) + +#### Technical changes + +- Fix `openfisca-core` Web API error triggered by `Gunicorn` < 22.0. + - Bump `Gunicorn` major revision to fix error on Web API. + Source: https://github.com/benoitc/gunicorn/issues/2564 + +### 41.1.2 [#1192](https://github.com/openfisca/openfisca-core/pull/1192) + +#### Technical changes + +- Add tests to `entities`. + +### 41.1.1 [#1186](https://github.com/openfisca/openfisca-core/pull/1186) + +#### Technical changes + +- Skip type-checking tasks + - Before their definition was commented out but still run with `make test` + - Now they're skipped but not commented, which is needed to fix the + underlying issues + +## 41.1.0 [#1195](https://github.com/openfisca/openfisca-core/pull/1195) + +#### Technical changes + +- Make `Role` explicitly hashable. +- Details: + - By introducing `__eq__`, naturally `Role` became unhashable, because + equality was calculated based on a property of `Role` + (`role.key == another_role.key`), and no longer structurally + (`"1" == "1"`). + - This changeset removes `__eq__`, as `Role` is being used downstream as a + hashable object, and adds a test to ensure `Role`'s hashability. + +### 41.0.2 [#1194](https://github.com/openfisca/openfisca-core/pull/1194) + +#### Technical changes + +- Add `__hash__` method to `Role`. + +### 41.0.1 [#1187](https://github.com/openfisca/openfisca-core/pull/1187) + +#### Technical changes + +- Document `Role`. + +# 41.0.0 [#1189](https://github.com/openfisca/openfisca-core/pull/1189) + +#### Breaking changes + +- `Variable.get_introspection_data` no longer has parameters nor calling functions + +The Web API was very prone to crashing, timeouting at startup because of the time consuming python file parsing to generate documentation displayed for instance in the Legislation Explorer. + +## 40.1.0 [#1174](https://github.com/openfisca/openfisca-core/pull/1174) + +#### New Features + +- Allows for dispatching and dividing inputs over a broader range. + - For example, divide a monthly variable by week. + +### 40.0.1 [#1184](https://github.com/openfisca/openfisca-core/pull/1184) + +#### Technical changes + +- Require numpy < 1.25 because of memory leak detected in OpenFisca-France. + +# 40.0.0 [#1181](https://github.com/openfisca/openfisca-core/pull/1181) + +#### Breaking changes + +- Upgrade every dependencies to its latest version. +- Upgrade to Python >= 3.9 + +Note: Checks on mypy typings are disabled, because they cause generate of errors that we were not able to fix easily. + +# 39.0.0 [#1181](https://github.com/openfisca/openfisca-core/pull/1181) + +#### Breaking changes + +- Upgrade every dependencies to their latest versions. +- Upgrade to Python >= 3.9 + +Main changes, that may require some code changes in country packages: +- numpy +- pytest +- Flask + +### 38.0.4 [#1182](https://github.com/openfisca/openfisca-core/pull/1182) + +#### Technical changes + +- Method `_get_tax_benefit_system()` of class `YamlItem` in file `openfisca_core/tools/test_runner.py` will now clone the TBS when applying reforms to avoid running tests with previously reformed TBS. + +### 38.0.3 [#1179](https://github.com/openfisca/openfisca-core/pull/1179) + +#### Bug fix + +- Do not install dependencies outside the `setup.py` + - Dependencies installed outside the `setup.py` are not taken into account by + `pip`'s dependency resolver. + - In case of conflicting transient dependencies, the last library installed + will "impose" its dependency version. + - This makes the installation and build of the library non-deterministic and + prone to unforeseen bugs caused by external changes in dependencies' + versions. + +#### Note + +A definite way to solve this issue is to clearly separate library dependencies +(with a `virtualenv`) and a universal dependency installer for CI requirements +(like `pipx`), taking care of: + +- Always running tests inside the `virtualenv` (for example with `nox`). +- Always building outside of the `virtualenv` (for example with `poetry` + installed by `pipx`). + +Moreover, it is indeed even better to have a lock file for dependencies, +using `pip freeze`) or with tools providing such features (`pipenv`, etc.). + +### 38.0.2 [#1178](https://github.com/openfisca/openfisca-core/pull/1178) + +#### Technical changes + +- Remove use of `importlib_metadata`. + +### 38.0.1 - + +> Note: Version `38.0.1` has been unpublished as was deployed by mistake. +> Please use versions `38.0.2` and subsequents. + + +# 38.0.0 [#989](https://github.com/openfisca/openfisca-core/pull/989) + +> Note: Version `38.0.0` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### New Features + +- Upgrade OpenAPI specification of the API to v3 from Swagger v2. +- Continuously validate OpenAPI specification. + +#### Breaking changes + +- Drop support for OpenAPI specification v2 and prior. + - Users relying on OpenAPI v2 can use [Swagger Converter](https://converter.swagger.io/api/convert?url=OAS2_YAML_OR_JSON_URL) to migrate ([example](https://web.archive.org/web/20221103230822/https://converter.swagger.io/api/convert?url=https://api.demo.openfisca.org/latest/spec)). + +### 37.0.2 [#1170](https://github.com/openfisca/openfisca-core/pull/1170) + +> Note: Version `37.0.2` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### Technical changes + +- Always import numpy + +### 37.0.1 [#1169](https://github.com/openfisca/openfisca-core/pull/1169) + +> Note: Version `37.0.1` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### Technical changes + +- Unify casing of NumPy. + +# 37.0.0 [#1142](https://github.com/openfisca/openfisca-core/pull/1142) + +> Note: Version `37.0.0` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### Deprecations + +- In _periods.Instant_: + - Remove `period`, method used to build a `Period` from an `Instant`. + - This method created an upward circular dependency between `Instant` and `Period` causing lots of trouble. + - The functionality is still provided by `periods.period` and the `Period` constructor. + +#### Migration details + +- Replace `some_period.start.period` and similar methods with `Period((unit, some_period.start, 1))`. + +# 36.0.0 [#1149](https://github.com/openfisca/openfisca-core/pull/1162) + +> Note: Version `36.0.0` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### Breaking changes + +- In `ParameterScaleBracket`: + - Remove the `base` attribute + - The attribute's usage was unclear and it was only being used by some French social security variables + +## 35.12.0 [#1160](https://github.com/openfisca/openfisca-core/pull/1160) + +> Note: Version `35.12.0` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### New Features + +- Lighter install by removing test packages from systematic install. + +### 35.11.2 [#1166](https://github.com/openfisca/openfisca-core/pull/1166) + +> Note: Version `35.11.2` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### Technical changes + +- Fix Holder's doctests. + +### 35.11.1 [#1165](https://github.com/openfisca/openfisca-core/pull/1165) + +> Note: Version `35.11.1` has been unpublished as it introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### Bug fix + +- Fix documentation + - Suppression of some modules broke the documentation build + +## 35.11.0 [#1149](https://github.com/openfisca/openfisca-core/pull/1149) + +#### New Features + +- Introduce variable dependent error margins in YAML tests. + +### 35.10.1 [#1143](https://github.com/openfisca/openfisca-core/pull/1143) + +#### Bug fix + +- Reintroduce support for the ``day`` date unit in `holders.set_input_dispatch_by_period` and `holders. + set_input_divide_by_period` + - Allows for dispatching values per day, for example, to provide a daily (week, fortnight) to an yearly variable. + - Inversely, allows for calculating the daily (week, fortnight) value of a yearly input. + +## 35.10.0 [#1151](https://github.com/openfisca/openfisca-core/pull/1151) + +#### New features + +- Add type hints for all instances of `variable_name` in function declarations. +- Add type hints for some `Simulation` and `Population` properties. + +## 35.9.0 [#1150](https://github.com/openfisca/openfisca-core/pull/1150) + +#### New Features + +- Introduce a maximal depth for computation logs + - Allows for limiting the depth of the computation log chain + +### 35.8.6 [#1145](https://github.com/openfisca/openfisca-core/pull/1145) + +#### Technical changes + +- Removes the automatic documentation build check + - It has been proven difficult to maintain, specifically due _dependency hell_ and a very contrived build workflow. + +### 35.8.5 [#1137](https://github.com/openfisca/openfisca-core/pull/1137) + +#### Technical changes + +- Fix pylint dependency in fresh editable installations + - Ignore pytest requirement, used to collect test cases, if it is not yet installed. + +### 35.8.4 [#1131](https://github.com/openfisca/openfisca-core/pull/1131) + +#### Technical changes + +- Correct some type hints and docstrings. + +### 35.8.3 [#1127](https://github.com/openfisca/openfisca-core/pull/1127) + +#### Technical changes + +- Fix the build for Anaconda in CI. The conda build failed on master because of a replacement in a comment string. + - The _ were removed in the comment to avoid a replace. + +### 35.8.2 [#1128](https://github.com/openfisca/openfisca-core/pull/1128) + +#### Technical changes + +- Remove ambiguous links in docstrings. + +### 35.8.1 [#1105](https://github.com/openfisca/openfisca-core/pull/1105) + +#### Technical changes + +- Add publish to Anaconda in CI. See file .conda/README.md. + +## 35.8.0 [#1114](https://github.com/openfisca/openfisca-core/pull/1114) + +#### New Features + +- Introduce `rate_from_bracket_indice` method on `RateTaxScaleLike` class + - Allows for the determination of the tax rate based on the tax bracket indice + +- Introduce `rate_from_tax_base` method on `RateTaxScaleLike` class + - Allows for the determination of the tax rate based on the tax base + +- Introduce `threshold_from_tax_base` method on `RateTaxScaleLike` class + - Allows for the determination of the lower threshold based on the tax base + +- Add publish openfisca-core library to Anaconda in CI. See file .conda/README.md. + +### 35.7.8 [#1086](https://github.com/openfisca/openfisca-core/pull/1086) + +#### Technical changes + +### 35.7.7 [#1109](https://github.com/openfisca/openfisca-core/pull/1109) + +#### Technical changes + +- Fix `openfisca-core` Web API error triggered by `Flask` dependencies updates + - Bump `Flask` patch revision to fix `cannot import name 'json' from 'itsdangerous'` on Web API. + - Then, fix `MarkupSafe` revision to avoid `cannot import name 'soft_unicode' from 'markupsafe'` error on Web API. + +### 35.7.6 [#1065](https://github.com/openfisca/openfisca-core/pull/1065) + +#### Technical changes + +- Made code compatible with dpath versions >=1.5.0,<3.0.0, instead of >=1.5.0,<2.0.0 + +### 35.7.5 [#1090](https://github.com/openfisca/openfisca-core/pull/1090) + +#### Technical changes + +- Remove calls to deprecated imp module + +### 35.7.4 [#1083](https://github.com/openfisca/openfisca-core/pull/1083) + +#### Technical changes + +- Add GitHub `pull-request` event as a trigger to GitHub Actions workflow + +### 35.7.3 [#1081](https://github.com/openfisca/openfisca-core/pull/1081) + +- Correct error message in case of mis-sized population + +### 35.7.2 [#1057](https://github.com/openfisca/openfisca-core/pull/1057) + +#### Technical changes + +- Switch CI provider from CircleCI to GitHub Actions + ### 35.7.1 [#1075](https://github.com/openfisca/openfisca-core/pull/1075) #### Bug fix @@ -206,7 +807,7 @@ - When libraries do not implement their own types, MyPy provides stubs, or type sheds - Thanks to `__future__.annotations`, those stubs or type sheds are casted to `typing.Any` - Since 1.20.x, NumPy now provides their own type definitions - - The introduction of NumPy 1.20.x in #990 caused one major problem: + - The introduction of NumPy 1.20.x in #990 caused one major problem: - It is general practice to do not import at runtime modules only used for typing purposes, thanks to the `typing.TYPE_CHEKING` variable - The new `numpy.typing` module was being imported at runtime, rendering OpenFisca unusable to all users depending on previous versions of NumPy (1.20.x-) - These changes revert #990 and solve #1009 and #1012 @@ -248,7 +849,7 @@ _Note: this version has been unpublished due to an issue introduced by NumPy upg #### Bug fix - Repair expansion of axes on a variable given as input - - When expanding axes, the expected behavour is to override any input value for the requested variable and period + - When expanding axes, the expected behaviour is to override any input value for the requested variable and period - As longs as we passed some input for a variable on a period, it was not being overrode, creating a NumPy's error (boradcasting) - By additionally checking that an input was given, now we make that the array has the correct shape by constructing it with NumPy's tile with a shape equal to the number of the axis expansion count requested. @@ -317,7 +918,7 @@ _Note: this version has been unpublished due to an issue introduced by NumPy upg #### Technical changes -- Improve error message when laoding parameters file to detect the problematic file +- Improve error message when loading parameters file to detect the problematic file ### 35.0.3 [#961](https://github.com/openfisca/openfisca-core/pull/961) @@ -331,7 +932,7 @@ _Note: this version has been unpublished due to an issue introduced by NumPy upg #### Technical changes -- Update dependency: `flask-cors` (`Flask` extension for Cross Origin Resouce Sharing) +- Update dependency: `flask-cors` (`Flask` extension for Cross Origin Resource Sharing) ### 35.0.1 [#968](https://github.com/openfisca/openfisca-core/pull/968) @@ -344,22 +945,22 @@ _Note: this version has been unpublished due to an issue introduced by NumPy upg #### Breaking changes -- Update Numpy version's upper bound to 1.18 - - Numpy 1.18 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.18.0-notes.html#expired-deprecations) that might be used in openfisca country models. +- Update NumPy version's upper bound to 1.18 + - NumPy 1.18 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.18.0-notes.html#expired-deprecations) that might be used in openfisca country models. #### Migration details -You might need to change your code if any of the [Numpy expired deprecations](https://numpy.org/devdocs/release/1.18.0-notes.html#expired-deprecations) is used in your model formulas. +You might need to change your code if any of the [NumPy expired deprecations](https://numpy.org/devdocs/release/1.18.0-notes.html#expired-deprecations) is used in your model formulas. Here is a subset of the deprecations that you might find in your model with some checks and migration steps (where `np` stands for `numpy`): -* `Removed deprecated support for boolean and empty condition lists in np.select.` - * Before `np.select([], [])` result was `0` (for a `default` argument value set to `0`). +* `Removed deprecated support for boolean and empty condition lists in numpy.select.` + * Before `numpy.select([], [])` result was `0` (for a `default` argument value set to `0`). * Now, we have to check for empty conditions and, return `0` or the defined default argument value when we want to keep the same behavior. * Before, integer conditions where transformed to booleans. - * For example, `np.select([0, 1, 0], ['a', 'b', 'c'])` result was `array('b', dtype=' ``` > And two parameters `parameters.city_tax.z1` and `parameters.city_tax.z2`, they can be dynamically accessed through: > ```py -> zone = np.asarray([TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]) +> zone = numpy.asarray([TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]) > zone_value = parameters.rate._get_at_instant('2015-01-01').single.owner[zone] > ``` > returns @@ -2154,16 +2755,16 @@ class housing_occupancy_status(Variable): - When using the Python API (`set_input`), the three following inputs are accepted: - The enum item (e.g. HousingOccupancyStatus.tenant) - The enum string identifier (e.g. "tenant") - - The enum item index, though this is not recommanded. + - The enum item index, though this is not recommended. - If you rely on index, make sure to specify an `__order__` attribute to all your enums to make sure each intem has the right index. See the enum34 [doc](https://pypi.python.org/pypi/enum34/1.1.1). > Example: ```py holder = simulation.household.get_holder('housing_occupancy_status') # Three possibilities -holder.set_input(period, np.asarray([HousingOccupancyStatus.owner])) -holder.set_input(period, np.asarray(['owner'])) -holder.set_input(period, np.asarray([0])) # Highly not recommanded +holder.set_input(period, numpy.asarray([HousingOccupancyStatus.owner])) +holder.set_input(period, numpy.asarray(['owner'])) +holder.set_input(period, numpy.asarray([0])) # Highly not recommended ``` - When calculating an Enum variable, the output will be an [EnumArray](https://openfisca.org/doc/openfisca-python-api/enum_array.html#module-openfisca_core.indexed_enums). @@ -2326,7 +2927,7 @@ column = make_column_from_variable(variable) - In `Variable`: * Remove `to_column` - * Variables can now directly be instanciated: + * Variables can now directly be instantiated: ```py class salary(Variable): @@ -2377,7 +2978,7 @@ tax_benefit_system.parameters.benefits.basic_income - Be more flexible about parameters definitions -The two following expressions are for instance striclty equivalent: +The two following expressions are for instance strictly equivalent: ``` Parameter("taxes.rate", {"2015-01-01": 2000}) @@ -2598,7 +3199,7 @@ For more information, check the [documentation](https://openfisca.org/doc/coding reference_parameters.add_child('plf2016_conterfactual', reform_parameters_subtree) ``` - - Note that this way of creating parameters is only recommanded when using dynamically computed values (for instance `round(1135 * (1 + inflation))` in the previous example). If the values are static, the new parameters can be directly built from YAML (See New features section). + - Note that this way of creating parameters is only recommended when using dynamically computed values (for instance `round(1135 * (1 + inflation))` in the previous example). If the values are static, the new parameters can be directly built from YAML (See New features section). ##### TaxBenefitSystem @@ -2658,8 +3259,8 @@ For more information, check the [documentation](https://openfisca.org/doc/coding #### Technical changes * Refactor the internal representation and the interface of legislation parameters - - The parameters of a legislation are wraped into the classes `Node`, `Parameter`, `Scale`, `Bracket`, `ValuesHistory`, `ValueAtInstant` instead of bare python dict. - - The parameters of a legislation at a given instant are wraped into the classes `NodeAtInstant`, `ValueAtInstant` and tax scales instead of bare python objects. + - The parameters of a legislation are wrapped into the classes `Node`, `Parameter`, `Scale`, `Bracket`, `ValuesHistory`, `ValueAtInstant` instead of bare python dict. + - The parameters of a legislation at a given instant are wrapped into the classes `NodeAtInstant`, `ValueAtInstant` and tax scales instead of bare python objects. - The file `parameters.py` and the classes defined inside are responsible both for loading and accessing the parameters. Before the loading was implemented in `legislationsxml.py` and the other processings were implemented in `legislations.py` - The validation of the XML files was performed against a XML schema defined in `legislation.xsd`. Now the YAML files are loaded with the library `yaml` and then validated in basic Python. @@ -2670,7 +3271,7 @@ For more information, check the [documentation](https://openfisca.org/doc/coding - `Simulation.get_compact_legislation()` -> `Simulation._get_parameters_at_instant()` - `Simulation.get_baseline_compact_legislation()` -> `Simulation._get_baseline_parameters_at_instant()` -* The optionnal parameter `traced_simulation` is removed in function `TaxBenefitSystem.get_compact_legislation()` (now `TaxBenefitSystem.get_parameters_at_instant()`). This parameter had no effect. +* The optional parameter `traced_simulation` is removed in function `TaxBenefitSystem.get_compact_legislation()` (now `TaxBenefitSystem.get_parameters_at_instant()`). This parameter had no effect. * The optional parameter `with_source_file_infos` is removed in functions `TaxBenefitSystem.compute_legislation()` (now `TaxBenefitSystem._compute_parameters()`) and `TaxBenefitSystem.get_legislation()`. This parameter had no effect. @@ -2688,7 +3289,7 @@ For more information, check the [documentation](https://openfisca.org/doc/coding In the preview web API, for variables of type `Enum`: * Accept and recommend to use strings as simulation inputs, instead of the enum indices. - - For instance, `{"housing_occupancy_status": {"2017-01": "Tenant"}}` is now accepted and prefered to `{"housing_occupancy_status": {"2017-01": 0}}`). + - For instance, `{"housing_occupancy_status": {"2017-01": "Tenant"}}` is now accepted and preferred to `{"housing_occupancy_status": {"2017-01": 0}}`). - Using the enum indices as inputs is _still accepted_ for backward compatibility, but _should not_ be encouraged. * Return strings instead of enum indices. - For instance, is `housing_occupancy_status` is calculated for `2017-01`, `{"housing_occupancy_status": {"2017-01": "Tenant"}}` is now returned, instead of `{"housing_occupancy_status": {"2017-01": 0}}`. @@ -2734,7 +3335,7 @@ In the preview web API, for variables of type `Enum`: - This attribute is the legislative reference of a variable. - As previously, this attribute can be a string, or a list of strings. * Rename `Variable` attribute `reference` to `baseline_variable` - - This attibute is, for a variable defined in a reform, the baseline variable the reform variable is replacing. + - This attribute is, for a variable defined in a reform, the baseline variable the reform variable is replacing. * Remove variable attribute `law_reference` * Rename `TaxBenefitSystem.reference` to `TaxBenefitSystem.baseline` * Rename `TaxBenefitSystem.get_reference_compact_legislation` to `TaxBenefitSystem.get_baseline_compact_legislation` @@ -2803,7 +3404,7 @@ In the preview web API, for variables of type `Enum`: - These functionalities are now provided by `entity.get_holder(name)` - Deprecate constructor `Holder(simulation, column)` - - A `Holder` should now be instanciated with `Holder(entity = entity, column = column)` + - A `Holder` should now be instantiated with `Holder(entity = entity, column = column)` ### 14.0.1 - [#527](https://github.com/openfisca/openfisca-core/pull/527) @@ -3012,7 +3613,7 @@ These breaking changes only concern variable and tax and benefit system **metada # 9.0.0 -* Make sure identic periods are stringified the same way +* Make sure identical periods are stringified the same way * _Breaking changes_: - Change `periods.period` signature. - It now only accepts strings. @@ -3049,7 +3650,7 @@ These breaking changes only concern variable and tax and benefit system **metada ## 6.1.0 * Move `base.py` content (file usually located in country packages) to core module `formula_toolbox` so that it can be reused by all countries -* Use `AbstractScenario` if no custom scenario is defined for a tax and benefit sytem +* Use `AbstractScenario` if no custom scenario is defined for a tax and benefit system # 6.0.0 @@ -3091,7 +3692,7 @@ These breaking changes only concern variable and tax and benefit system **metada * Improve `openfisca-run-test` script - Make country package detection more robust (it only worked for packages installed in editable mode) - Use spaces instead of commas as separator in the script arguments when loading several extensions or reforms (this is more standard) -* Refactor the `scripts` module to seperate the logic specific to yaml test running from the one that can be re-used by any script which needs to build a tax and benefit system. +* Refactor the `scripts` module to separate the logic specific to yaml test running from the one that can be re-used by any script which needs to build a tax and benefit system. # 5.0.0 @@ -3109,7 +3710,7 @@ These breaking changes only concern variable and tax and benefit system **metada ### 4.3.4 -* Fix occasionnal `NaN` creation in `MarginalRateTaxScale.calc` resulting from `0 * np.inf` +* Fix occasionnal `NaN` creation in `MarginalRateTaxScale.calc` resulting from `0 * numpy.inf` ### 4.3.3 @@ -3168,8 +3769,8 @@ Unlike simple formulas, a `DatedVariable` have several functions. We thus need t ### 4.1.2-Beta * Enable simulation initialization with only legacy roles - * New roles are in this case automatically infered - * Positions are always infered from persons entity id + * New roles are in this case automatically inferred + * Positions are always inferred from persons entity id ### 4.1.1-Beta @@ -3269,7 +3870,7 @@ Unlike simple formulas, a `DatedVariable` have several functions. We thus need t # 2.0.0 – [diff](https://github.com/openfisca/openfisca-core/compare/1.1.0...2.0.0) -* Variables are not added to the TaxBenefitSystem when the entities class are imported, but explicitely when the TaxBenefitSystem is instanciated. +* Variables are not added to the TaxBenefitSystem when the entities class are imported, but explicitly when the TaxBenefitSystem is instantiated. * Metaclasses are not used anymore. * New API for TaxBenefitSystem * Columns are now stored in the TaxBenefitSystem, not in entities. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8ba10fd606..7d20ce44ce 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -182,7 +182,7 @@ def get(self, key: str) -> Any: * Return None when key is not found. .. versionadded:: 1.2.3 - This will help people to undestand the code evolution. + This will help people to understand the code evolution. .. deprecated:: 2.3.4 This, to have time to adapt their own codebases before the code is diff --git a/MANIFEST.in b/MANIFEST.in index 507d218461..166788d7fa 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,2 @@ -graft requirements -include openfisca_web_api/openAPI.yml recursive-include openfisca_core/scripts * +include openfisca_web_api/openAPI.yml diff --git a/Makefile b/Makefile index b5c73a5ff8..9271f9431a 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,6 @@ include openfisca_tasks/lint.mk include openfisca_tasks/publish.mk include openfisca_tasks/serve.mk include openfisca_tasks/test_code.mk -include openfisca_tasks/test_doc.mk ## To share info with the user, but no action is needed. print_info = $$(tput setaf 6)[i]$$(tput sgr0) @@ -14,7 +13,7 @@ print_warn = $$(tput setaf 3)[!]$$(tput sgr0) ## To let the user know where we are in the task pipeline. print_work = $$(tput setaf 5)[⚙]$$(tput sgr0) -## To let the user know the task in progress succeded. +## To let the user know the task in progress succeeded. ## The `$1` is a function argument, passed from a task (usually the task name). print_pass = echo $$(tput setaf 2)[✓]$$(tput sgr0) $$(tput setaf 8)$1$$(tput sgr0)$$(tput setaf 2)passed$$(tput sgr0) $$(tput setaf 1)❤$$(tput sgr0) diff --git a/README.md b/README.md index c38a614754..6911c73318 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,16 @@ # OpenFisca Core -[![Newsletter](https://img.shields.io/badge/newsletter-subscribe!-informational.svg?style=flat)](mailto:contact%40openfisca.org?subject=Subscribe%20to%20your%20newsletter%20%7C%20S'inscrire%20%C3%A0%20votre%20newsletter&body=%5BEnglish%20version%20below%5D%0A%0ABonjour%2C%0A%0AVotre%C2%A0pr%C3%A9sence%C2%A0ici%C2%A0nous%C2%A0ravit%C2%A0!%20%F0%9F%98%83%0A%0AEnvoyez-nous%20cet%20email%20pour%20que%20l'on%20puisse%20vous%20inscrire%20%C3%A0%20la%20newsletter.%20%0A%0AAh%C2%A0!%20Et%20si%20vous%20pouviez%20remplir%20ce%20petit%20questionnaire%2C%20%C3%A7a%20serait%20encore%20mieux%C2%A0!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2F45M0VR1TYKD1RGzX2%0A%0AAmiti%C3%A9%2C%0AL%E2%80%99%C3%A9quipe%20OpenFisca%0A%0A%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%20ENGLISH%20VERSION%20%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%0A%0AHi%2C%20%0A%0AWe're%20glad%20to%20see%20you%20here!%20%F0%9F%98%83%0A%0APlease%20send%20us%20this%20email%2C%20so%20we%20can%20subscribe%20you%20to%20the%20newsletter.%0A%0AAlso%2C%20if%20you%20can%20fill%20out%20this%20short%20survey%2C%20even%20better!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2FsOg8K1abhhm441LG2%0A%0ACheers%2C%0AThe%20OpenFisca%20Team) -[![Twitter](https://img.shields.io/badge/twitter-follow%20us!-9cf.svg?style=flat)](https://twitter.com/intent/follow?screen_name=openfisca) -[![Slack](https://img.shields.io/badge/slack-join%20us!-blueviolet.svg?style=flat)](mailto:contact%40openfisca.org?subject=Join%20you%20on%20Slack%20%7C%20Nous%20rejoindre%20sur%20Slack&body=%5BEnglish%20version%20below%5D%0A%0ABonjour%2C%0A%0AVotre%C2%A0pr%C3%A9sence%C2%A0ici%C2%A0nous%C2%A0ravit%C2%A0!%20%F0%9F%98%83%0A%0ARacontez-nous%20un%20peu%20de%20vous%2C%20et%20du%20pourquoi%20de%20votre%20int%C3%A9r%C3%AAt%20de%20rejoindre%20la%20communaut%C3%A9%20OpenFisca%20sur%20Slack.%0A%0AAh%C2%A0!%20Et%20si%20vous%20pouviez%20remplir%20ce%20petit%20questionnaire%2C%20%C3%A7a%20serait%20encore%20mieux%C2%A0!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2F45M0VR1TYKD1RGzX2%0A%0AN%E2%80%99oubliez%20pas%20de%20nous%20envoyer%20cet%20email%C2%A0!%20Sinon%2C%20on%20ne%20pourra%20pas%20vous%20contacter%20ni%20vous%20inviter%20sur%20Slack.%0A%0AAmiti%C3%A9%2C%0AL%E2%80%99%C3%A9quipe%20OpenFisca%0A%0A%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%20ENGLISH%20VERSION%20%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%0A%0AHi%2C%20%0A%0AWe're%20glad%20to%20see%20you%20here!%20%F0%9F%98%83%0A%0APlease%20tell%20us%20a%20bit%20about%20you%20and%20why%20you%20want%20to%20join%20the%20OpenFisca%20community%20on%20Slack.%0A%0AAlso%2C%20if%20you%20can%20fill%20out%20this%20short%20survey%2C%20even%20better!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2FsOg8K1abhhm441LG2.%0A%0ADon't%20forget%20to%20send%20us%20this%20email!%20Otherwise%20we%20won't%20be%20able%20to%20contact%20you%20back%2C%20nor%20invite%20you%20on%20Slack.%0A%0ACheers%2C%0AThe%20OpenFisca%20Team) +[![PyPi Downloads](https://img.shields.io/pypi/dm/openfisca-core?label=pypi%2Fdownloads&style=for-the-badge)](https://pepy.tech/project/openfisca-core) +[![PyPi Version](https://img.shields.io/pypi/v/openfisca-core.svg?label=pypi%2Fversion&style=for-the-badge)](https://pypi.python.org/pypi/openfisca-core) +[![Conda Downloads](https://img.shields.io/conda/dn/conda-forge/openfisca-core?label=conda%2Fdownloads&style=for-the-badge)](https://anaconda.org/conda-forge/openfisca-core) +[![Conda Version](https://img.shields.io/conda/vn/conda-forge/openfisca-core.svg?label=conda%2Fversion&style=for-the-badge)](https://anaconda.org/conda-forge/openfisca-core) -[![CircleCI](https://img.shields.io/circleci/project/github/openfisca/openfisca-core/master.svg?style=flat)](https://circleci.com/gh/openfisca/openfisca-core) -[![Coveralls](https://img.shields.io/coveralls/github/openfisca/openfisca-core/master.svg?style=flat)](https://coveralls.io/github/openfisca/openfisca-core?branch=master) -[![Python](https://img.shields.io/pypi/pyversions/openfisca-core.svg)](https://pypi.python.org/pypi/openfisca-core) -[![PyPi](https://img.shields.io/pypi/v/openfisca-core.svg?style=flat)](https://pypi.python.org/pypi/openfisca-core) +[![Python](https://img.shields.io/pypi/pyversions/openfisca-core.svg?label=python&style=for-the-badge)](https://pypi.python.org/pypi/openfisca-core) +[![Contributors](https://img.shields.io/github/contributors/openfisca/openfisca-core.svg?style=for-the-badge)](https://github.com/openfisca/openfisca-core/graphs/contributors) + +[![Newsletter](https://img.shields.io/badge/newsletter-subscribe!-informational.svg?style=for-the-badge)](mailto:contact%40openfisca.org?subject=Subscribe%20to%20your%20newsletter%20%7C%20S'inscrire%20%C3%A0%20votre%20newsletter&body=%5BEnglish%20version%20below%5D%0A%0ABonjour%2C%0A%0AVotre%C2%A0pr%C3%A9sence%C2%A0ici%C2%A0nous%C2%A0ravit%C2%A0!%20%F0%9F%98%83%0A%0AEnvoyez-nous%20cet%20email%20pour%20que%20l'on%20puisse%20vous%20inscrire%20%C3%A0%20la%20newsletter.%20%0A%0AAh%C2%A0!%20Et%20si%20vous%20pouviez%20remplir%20ce%20petit%20questionnaire%2C%20%C3%A7a%20serait%20encore%20mieux%C2%A0!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2F45M0VR1TYKD1RGzX2%0A%0AAmiti%C3%A9%2C%0AL%E2%80%99%C3%A9quipe%20OpenFisca%0A%0A%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%20ENGLISH%20VERSION%20%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%0A%0AHi%2C%20%0A%0AWe're%20glad%20to%20see%20you%20here!%20%F0%9F%98%83%0A%0APlease%20send%20us%20this%20email%2C%20so%20we%20can%20subscribe%20you%20to%20the%20newsletter.%0A%0AAlso%2C%20if%20you%20can%20fill%20out%20this%20short%20survey%2C%20even%20better!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2FsOg8K1abhhm441LG2%0A%0ACheers%2C%0AThe%20OpenFisca%20Team) +[![Twitter](https://img.shields.io/badge/twitter-follow%20us!-9cf.svg?style=for-the-badge)](https://twitter.com/intent/follow?screen_name=openfisca) +[![Slack](https://img.shields.io/badge/slack-join%20us!-blueviolet.svg?style=for-the-badge)](mailto:contact%40openfisca.org?subject=Join%20you%20on%20Slack%20%7C%20Nous%20rejoindre%20sur%20Slack&body=%5BEnglish%20version%20below%5D%0A%0ABonjour%2C%0A%0AVotre%C2%A0pr%C3%A9sence%C2%A0ici%C2%A0nous%C2%A0ravit%C2%A0!%20%F0%9F%98%83%0A%0ARacontez-nous%20un%20peu%20de%20vous%2C%20et%20du%20pourquoi%20de%20votre%20int%C3%A9r%C3%AAt%20de%20rejoindre%20la%20communaut%C3%A9%20OpenFisca%20sur%20Slack.%0A%0AAh%C2%A0!%20Et%20si%20vous%20pouviez%20remplir%20ce%20petit%20questionnaire%2C%20%C3%A7a%20serait%20encore%20mieux%C2%A0!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2F45M0VR1TYKD1RGzX2%0A%0AN%E2%80%99oubliez%20pas%20de%20nous%20envoyer%20cet%20email%C2%A0!%20Sinon%2C%20on%20ne%20pourra%20pas%20vous%20contacter%20ni%20vous%20inviter%20sur%20Slack.%0A%0AAmiti%C3%A9%2C%0AL%E2%80%99%C3%A9quipe%20OpenFisca%0A%0A%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%20ENGLISH%20VERSION%20%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%0A%0AHi%2C%20%0A%0AWe're%20glad%20to%20see%20you%20here!%20%F0%9F%98%83%0A%0APlease%20tell%20us%20a%20bit%20about%20you%20and%20why%20you%20want%20to%20join%20the%20OpenFisca%20community%20on%20Slack.%0A%0AAlso%2C%20if%20you%20can%20fill%20out%20this%20short%20survey%2C%20even%20better!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2FsOg8K1abhhm441LG2.%0A%0ADon't%20forget%20to%20send%20us%20this%20email!%20Otherwise%20we%20won't%20be%20able%20to%20contact%20you%20back%2C%20nor%20invite%20you%20on%20Slack.%0A%0ACheers%2C%0AThe%20OpenFisca%20Team) [OpenFisca](https://openfisca.org/doc/) is a versatile microsimulation free software. Check the [online documentation](https://openfisca.org/doc/) for more details. @@ -15,26 +18,69 @@ This package contains the core features of OpenFisca, which are meant to be used ## Environment -OpenFisca runs on Python 3.7. More recent versions should work, but are not tested. +OpenFisca runs on Python 3.7. More recent versions should work but are not tested. -OpenFisca also relies strongly on NumPy. Only upper and lower bound versions are tested. +OpenFisca also relies strongly on NumPy. The last four minor versions should work, but only the latest/stable is tested. ## Installation -If you're developing your own country package, you don't need to explicitly install OpenFisca-Core. It just needs to appear [in your package dependencies](https://github.com/openfisca/openfisca-france/blob/18.2.1/setup.py#L53). +If you're developing your own country package, you don't need to explicitly install OpenFisca-Core. It just needs to appear [in your package dependencies](https://github.com/openfisca/openfisca-france/blob/100.0.0/setup.py#L60). +If you want to contribute to OpenFisca-Core itself, welcome! +To install it locally you can use one of these two options: +* [conda](https://docs.conda.io/en/latest/) package manager that we recommend for Windows operating system users, +* or standard Python [pip](https://packaging.python.org/en/latest/key_projects/#pip) package manager. + +### Installing `openfisca-core` with `pip` -If you want to contribute to OpenFisca-Core itself, welcome! To install it locally in development mode run the following commands: +This installation requires [Python](https://www.python.org/downloads/) 3.7+ and [GIT](https://git-scm.com) installations. + +To install `openfisca-core` locally in development mode run the following commands in a shell terminal: ```bash git clone https://github.com/openfisca/openfisca-core.git cd openfisca-core python3 -m venv .venv source .venv/bin/activate -make install +make install-deps install-edit ``` +### Installing `openfisca-core` with `conda` + +Since `openfisca-core` version [35.7.7](https://anaconda.org/conda-forge/openfisca-core), you could use `conda` to install OpenFisca-Core. + +Conda is the easiest way to use OpenFisca under Windows as by installing Anaconda you will get: +- Python +- The package manager [Anaconda.org](https://docs.anaconda.com/anacondaorg/user-guide/) +- A virtual environment manager : [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) +- A GUI [Anaconda Navigator](https://docs.anaconda.com/anaconda/navigator/index.html) if you choose to install the full [Anaconda](https://www.anaconda.com/products/individual) + +If you are familiar with the command line you could use [Miniconda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/windows.html), which needs very much less disk space than Anaconda. + +After installing conda, run these commands in an `Anaconda Powershell Prompt`: +- `conda create --name openfisca python=3.7` to create an `openfisca` environment. +- `conda activate openfisca` to use your new environment. + +Then, choose one of the following options according to your use case: +- `conda install -c conda-forge openfisca-core` for default dependencies, +- or `conda install -c conda-forge openfisca-core-api` if you want the Web API part, +- or `conda install -c conda-forge -c openfisca openfisca-core-dev` if you want all the dependencies needed to contribute to the project. + +For information on how we publish to conda-forge, see [openfisca-core-feedstock](https://github.com/openfisca/openfisca-core-feedstock/blob/master/recipe/README.md). + ## Testing +Install the test dependencies: + +``` +make install-deps install-edit install-test +``` + +> For integration testing purposes, `openfisca-core` relies on +> [country-template](https://github.com/openfisca/country-template.git) and +> [extension-template](https://github.com/openfisca/extension-template.git). +> Because these packages rely at the same time on `openfisca-core`, they need +> to be installed separately. + To run the entire test suite: ```sh @@ -44,10 +90,10 @@ make test To run all the tests defined on a test file: ```sh -openfisca test tests/core/test_parameters.py +pytest tests/core/test_parameters.py ``` -You can also use `pytest`, for example to run a single test: +To run a single test: ```sh pytest tests/core/test_parameters.py -k test_parameter_for_period @@ -57,6 +103,8 @@ pytest tests/core/test_parameters.py -k test_parameter_for_period This repository relies on MyPy for optional dynamic & static type checking. +As NumPy introduced the `typing` module in 1.20.0, to ensure type hints do not break the code at runtime, we run the checker against the last four minor NumPy versions. + Type checking is already run with `make test`. To run the type checker alone: ```sh @@ -95,74 +143,10 @@ END ## Documentation -Yet however OpenFisca does not follow a common convention for docstrings, our current toolchain allows to check whether documentation builds correctly and to update it automatically with each contribution to this repository. +OpenFisca’s toolchain checks whether documentation builds correctly and updates it automatically with each contribution to this repository. In the meantime, please take a look at our [contributing guidelines](CONTRIBUTING.md) for some general tips on how to document your contributions, and at our official documentation's [repository](https://github.com/openfisca/openfisca-doc/blob/master/README.md) to in case you want to know how to build it by yourself —and improve it! -### To verify that the documentation still builds correctly - -You can run: - -```sh -make test-doc -``` - -### If it doesn't, or if the doc is already broken. - -Here's how you can fix it: - -1. Clone the documentation, if not yet done: - -``` -make test-doc-checkout -``` - -2. Install the documentation's dependencies, if not yet done: - -``` -make test-doc-install -``` - -3. Create a branch, both in core and in the doc, to correct the problems: - -``` -git checkout -b fix-doc -sh -c "cd doc && git checkout -b `git branch --show-current`" -``` - -4. Fix the offending problems —they could be in core, in the doc, or in both. - -You can test-drive your fixes by checking that each change works as expected: - -``` -make test-doc-build branch=`git branch --show-current` -``` - -5. Commit at each step, so you don't accidentally lose your progress: - -``` -git add -A && git commit -m "Fix outdated argument for Entity" -sh -c "cd doc && git add -A && git commit -m \"Fix outdated argument for Entity\"" -``` - -6. Once you're done, push your changes and cleanup: - -``` -git push origin `git branch --show-current` -sh -c "cd doc && git push origin `git branch --show-current`" -rm -rf doc -``` - -7. Finally, open a pull request both in [core](https://github.com/openfisca/openfisca-core/compare/master...fix-doc) and in the [doc](https://github.com/openfisca/openfisca-doc/compare/master...fix-doc). - -[CircleCI](.circleci/config.yml) will automatically try to build the documentation from the same branch in both core and the doc (in our example "fix-doc") so we can integrate first our changes to core, and then our changes to the doc. - -If no changes were needed to the doc, then your changes to core will be verified against the production version of the doc. - -If your changes concern only the doc, please take a look at the doc's [README](https://github.com/openfisca/openfisca-doc/blob/master/README.md). - -That's it! 🙌 - ## Serving the API OpenFisca-Core provides a Web-API. It is by default served on the `5000` port. @@ -193,15 +177,9 @@ The OpenFisca Web API comes with an [optional tracker](https://github.com/openfi The tracker is not installed by default. To install it, run: ```sh -pip install openfisca_core[tracker] +pip install openfisca_core[tracker] --use-deprecated=legacy-resolver # Or `pip install --editable ".[tracker]"` for an editable installation ``` -Or for an editable installation: - -``` -make install -make install-tracker -``` #### Tracker configuration @@ -209,7 +187,7 @@ The tracker is activated when these two options are set: * `--tracker-url`: An URL ending with `piwik.php`. It defines the Piwik instance that will receive the tracking information. To use the main OpenFisca Piwik instance, use `https://stats.data.gouv.fr/piwik.php`. * `--tracker-idsite`: An integer. It defines the identifier of the tracked site on your Piwik instance. To use the main OpenFisca piwik instance, use `4`. -* `--tracker-token`: A string. It defines the Piwik API Authentification token to differentiate API calls based on the user IP. Otherwise, all API calls will seem to come from your server. The Piwik API Authentification token can be found in your Piwik interface, when you are logged. +* `--tracker-token`: A string. It defines the Piwik API Authentication token to differentiate API calls based on the user IP. Otherwise, all API calls will seem to come from your server. The Piwik API Authentication token can be found in your Piwik interface when you are logged in. For instance, to run the Web API with the mock country package `openfisca_country_template` and the tracker activated, run: diff --git a/STYLEGUIDE.md b/STYLEGUIDE.md index 97b0461aa3..cffca2cfc3 100644 --- a/STYLEGUIDE.md +++ b/STYLEGUIDE.md @@ -1,6 +1,6 @@ # OpenFisca's Python Style Guide -Arguments over code style and formatting are the bread and butter of most open-source projets out there, including OpenFisca. +Arguments over code style and formatting are the bread and butter of most open-source projects out there, including OpenFisca. To avoid this, we have a style guide, that is a set or arbitrary but consistent conventions about how code should be written, for contributors and maintainers alike. diff --git a/conftest.py b/conftest.py index 3adc794111..fbe03e7d37 100644 --- a/conftest.py +++ b/conftest.py @@ -1,6 +1,7 @@ pytest_plugins = [ "tests.fixtures.appclient", "tests.fixtures.entities", + "tests.fixtures.extensions", "tests.fixtures.simulations", "tests.fixtures.taxbenefitsystems", - ] +] diff --git a/openfisca_core/commons/__init__.py b/openfisca_core/commons/__init__.py index c2927dea22..d7bdff5f71 100644 --- a/openfisca_core/commons/__init__.py +++ b/openfisca_core/commons/__init__.py @@ -1,58 +1,18 @@ -"""Common tools for contributors and users. - -The tools in this sub-package are intended, to help both contributors -to OpenFisca Core and to country packages. - -Official Public API: - * :func:`.apply_thresholds` - * :func:`.average_rate` - * :func:`.concat` - * :func:`.empty_clone` - * :func:`.marginal_rate` - * :func:`.stringify_array` - * :func:`.switch` - -Note: - The ``deprecated`` imports are transitional, in order to ensure - non-breaking changes, and could be removed from the codebase in the next - major release. - -Note: - How imports are being used today:: - - from openfisca_core.commons import * # Bad - from openfisca_core.commons.formulas import switch # Bad - from openfisca_core.commons.decorators import deprecated # Bad - - - The previous examples provoke cyclic dependency problems, that prevent us - from modularizing the different components of the library, which would make - them easier to test and to maintain. - - How they could be used in a future release: - - from openfisca_core import commons - from openfisca_core.commons import deprecated - - deprecated() # Good: import classes as publicly exposed - commons.switch() # Good: use functions as publicly exposed - - .. seealso:: `PEP8#Imports`_ and `OpenFisca's Styleguide`_. - - .. _PEP8#Imports: - https://www.python.org/dev/peps/pep-0008/#imports - - .. _OpenFisca's Styleguide: - https://github.com/openfisca/openfisca-core/blob/master/STYLEGUIDE.md - -""" - -# Official Public API - -from .formulas import apply_thresholds, concat, switch # noqa: F401 -from .misc import empty_clone, stringify_array # noqa: F401 -from .rates import average_rate, marginal_rate # noqa: F401 - -__all__ = ["apply_thresholds", "concat", "switch"] -__all__ = ["empty_clone", "stringify_array", *__all__] -__all__ = ["average_rate", "marginal_rate", *__all__] +"""Common tools for contributors and users.""" + +from . import types +from .formulas import apply_thresholds, concat, switch +from .misc import empty_clone, eval_expression, stringify_array +from .rates import average_rate, marginal_rate + +__all__ = [ + "apply_thresholds", + "average_rate", + "concat", + "empty_clone", + "eval_expression", + "marginal_rate", + "stringify_array", + "switch", + "types", +] diff --git a/openfisca_core/commons/formulas.py b/openfisca_core/commons/formulas.py index 6a90622147..1df8039410 100644 --- a/openfisca_core/commons/formulas.py +++ b/openfisca_core/commons/formulas.py @@ -1,17 +1,17 @@ -from typing import Any, Dict, Sequence, TypeVar +from __future__ import annotations -import numpy +from collections.abc import Mapping -from openfisca_core.types import ArrayLike, ArrayType +import numpy -T = TypeVar("T") +from . import types as t def apply_thresholds( - input: ArrayType[float], - thresholds: ArrayLike[float], - choices: ArrayLike[float], - ) -> ArrayType[float]: + input: t.Array[numpy.float32], + thresholds: t.ArrayLike[float], + choices: t.ArrayLike[float], +) -> t.Array[numpy.float32]: """Makes a choice based on an input and thresholds. From a list of ``choices``, this function selects one of these values @@ -24,12 +24,7 @@ def apply_thresholds( choices: A list of the possible values to choose from. Returns: - :obj:`numpy.ndarray` of :obj:`float`: - A list of the values chosen. - - Raises: - :exc:`AssertionError`: When the number of ``thresholds`` (t) and the - number of choices (c) are not either t == c or t == c - 1. + ndarray[float32]: A list of the values chosen. Examples: >>> input = numpy.array([4, 5, 6, 7, 8]) @@ -39,8 +34,7 @@ def apply_thresholds( array([10, 10, 15, 15, 20]) """ - - condlist: Sequence[ArrayType[bool]] + condlist: list[t.Array[numpy.bool_] | bool] condlist = [input <= threshold for threshold in thresholds] if len(condlist) == len(choices) - 1: @@ -48,25 +42,27 @@ def apply_thresholds( # must be true to return it. condlist += [True] - assert len(condlist) == len(choices), \ - " ".join([ - "'apply_thresholds' must be called with the same number of", - "thresholds than choices, or one more choice.", - ]) + msg = ( + "'apply_thresholds' must be called with the same number of thresholds " + "than choices, or one more choice." + ) + assert len(condlist) == len(choices), msg return numpy.select(condlist, choices) -def concat(this: ArrayLike[str], that: ArrayLike[str]) -> ArrayType[str]: - """Concatenates the values of two arrays. +def concat( + this: t.Array[numpy.str_] | t.ArrayLike[object], + that: t.Array[numpy.str_] | t.ArrayLike[object], +) -> t.Array[numpy.str_]: + """Concatenate the values of two arrays. Args: this: An array to concatenate. that: Another array to concatenate. Returns: - :obj:`numpy.ndarray` of :obj:`float`: - An array with the concatenated values. + ndarray[str_]: An array with the concatenated values. Examples: >>> this = ["this", "that"] @@ -75,25 +71,26 @@ def concat(this: ArrayLike[str], that: ArrayLike[str]) -> ArrayType[str]: array(['this1.0', 'that2.5']...) """ + if not isinstance(this, numpy.ndarray): + this = numpy.array(this) - if isinstance(this, numpy.ndarray) and \ - not numpy.issubdtype(this.dtype, numpy.str_): + if not numpy.issubdtype(this.dtype, numpy.str_): + this = this.astype("str") - this = this.astype('str') + if not isinstance(that, numpy.ndarray): + that = numpy.array(that) - if isinstance(that, numpy.ndarray) and \ - not numpy.issubdtype(that.dtype, numpy.str_): - - that = that.astype('str') + if not numpy.issubdtype(that.dtype, numpy.str_): + that = that.astype("str") return numpy.char.add(this, that) def switch( - conditions: ArrayType[Any], - value_by_condition: Dict[float, T], - ) -> ArrayType[T]: - """Mimicks a switch statement. + conditions: t.Array[numpy.float32] | t.ArrayLike[float], + value_by_condition: Mapping[float, float], +) -> t.Array[numpy.float32]: + """Mimic a switch statement. Given an array of conditions, returns an array of the same size, replacing each condition item with the matching given value. @@ -103,11 +100,7 @@ def switch( value_by_condition: Values to replace for each condition. Returns: - :obj:`numpy.ndarray`: - An array with the replaced values. - - Raises: - :exc:`AssertionError`: When ``value_by_condition`` is empty. + ndarray[float32]: An array with the replaced values. Examples: >>> conditions = numpy.array([1, 1, 1, 2]) @@ -116,13 +109,13 @@ def switch( array([80, 80, 80, 90]) """ + assert ( + len(value_by_condition) > 0 + ), "'switch' must be called with at least one value." + + condlist = [conditions == condition for condition in value_by_condition] - assert len(value_by_condition) > 0, \ - "'switch' must be called with at least one value." + return numpy.select(condlist, tuple(value_by_condition.values())) - condlist = [ - conditions == condition - for condition in value_by_condition.keys() - ] - return numpy.select(condlist, value_by_condition.values()) +__all__ = ["apply_thresholds", "concat", "switch"] diff --git a/openfisca_core/commons/misc.py b/openfisca_core/commons/misc.py index dd05cea11b..e3e55948d5 100644 --- a/openfisca_core/commons/misc.py +++ b/openfisca_core/commons/misc.py @@ -1,18 +1,19 @@ -from typing import TypeVar +from __future__ import annotations -from openfisca_core.types import ArrayType +import numexpr +import numpy -T = TypeVar("T") +from openfisca_core import types as t -def empty_clone(original: T) -> T: - """Creates an empty instance of the same class of the original object. +def empty_clone(original: object) -> object: + """Create an empty instance of the same class of the original object. Args: original: An object to clone. Returns: - The cloned, empty, object. + object: The cloned, empty, object. Examples: >>> Foo = type("Foo", (list,), {}) @@ -29,36 +30,35 @@ def empty_clone(original: T) -> T: """ - Dummy: object - new: T + def __init__(_: object) -> None: ... Dummy = type( "Dummy", (original.__class__,), - {"__init__": lambda self: None}, - ) + {"__init__": __init__}, + ) new = Dummy() new.__class__ = original.__class__ return new -def stringify_array(array: ArrayType) -> str: - """Generates a clean string representation of a numpy array. +def stringify_array(array: None | t.Array[numpy.generic]) -> str: + """Generate a clean string representation of a numpy array. Args: array: An array. Returns: - :obj:`str`: - "None" if the ``array`` is None, the stringified ``array`` otherwise. + str: ``"None"`` if the ``array`` is ``None``. + str: The stringified ``array`` otherwise. Examples: >>> import numpy >>> stringify_array(None) 'None' - >>> array = numpy.array([10, 20.]) + >>> array = numpy.array([10, 20.0]) >>> stringify_array(array) '[10.0, 20.0]' @@ -71,8 +71,37 @@ def stringify_array(array: ArrayType) -> str: "[, {}, str | t.Array[numpy.bool_] | t.Array[numpy.int32] | t.Array[numpy.float32]: + """Evaluate a string expression to a numpy array. + + Args: + expression: An expression to evaluate. + + Returns: + ndarray: The result of the evaluation. + str: The expression if it couldn't be evaluated. + + Examples: + >>> eval_expression("1 + 2") + array(3, dtype=int32) + + >>> eval_expression("salary") + 'salary' + + """ + try: + return numexpr.evaluate(expression) + + except (KeyError, TypeError): + return expression + + +__all__ = ["empty_clone", "eval_expression", "stringify_array"] diff --git a/tests/web_api/case_with_extension/__init__.py b/openfisca_core/commons/py.typed similarity index 100% rename from tests/web_api/case_with_extension/__init__.py rename to openfisca_core/commons/py.typed diff --git a/openfisca_core/commons/rates.py b/openfisca_core/commons/rates.py index d682824207..1d17c77352 100644 --- a/openfisca_core/commons/rates.py +++ b/openfisca_core/commons/rates.py @@ -1,16 +1,16 @@ -from typing import Optional +from __future__ import annotations import numpy -from openfisca_core.types import ArrayLike, ArrayType +from . import types as t def average_rate( - target: ArrayType[float], - varying: ArrayLike[float], - trim: Optional[ArrayLike[float]] = None, - ) -> ArrayType[float]: - """Computes the average rate of a target net income. + target: t.Array[numpy.float32], + varying: t.Array[numpy.float32] | t.ArrayLike[float], + trim: None | t.ArrayLike[float] = None, +) -> t.Array[numpy.float32]: + """Compute the average rate of a target net income. Given a ``target`` net income, and according to the ``varying`` gross income. Optionally, a ``trim`` can be applied consisting of the lower and @@ -25,49 +25,45 @@ def average_rate( trim: The lower and upper bounds of the average rate. Returns: - :obj:`numpy.ndarray` of :obj:`float`: - - The average rate for each target. - - When ``trim`` is provided, values that are out of the provided bounds - are replaced by :obj:`numpy.nan`. + ndarray[float32]: The average rate for each target. When ``trim`` + is provided, values that are out of the provided bounds are + replaced by :obj:`numpy.nan`. Examples: >>> target = numpy.array([1, 2, 3]) >>> varying = [2, 2, 2] - >>> trim = [-1, .25] + >>> trim = [-1, 0.25] >>> average_rate(target, varying, trim) array([ nan, 0. , -0.5]) """ - - average_rate: ArrayType[float] + if not isinstance(varying, numpy.ndarray): + varying = numpy.array(varying, dtype=numpy.float32) average_rate = 1 - target / varying if trim is not None: - average_rate = numpy.where( average_rate <= max(trim), average_rate, numpy.nan, - ) + ) average_rate = numpy.where( average_rate >= min(trim), average_rate, numpy.nan, - ) + ) return average_rate def marginal_rate( - target: ArrayType[float], - varying: ArrayType[float], - trim: Optional[ArrayLike[float]] = None, - ) -> ArrayType[float]: - """Computes the marginal rate of a target net income. + target: t.Array[numpy.float32], + varying: t.Array[numpy.float32] | t.ArrayLike[float], + trim: None | t.ArrayLike[float] = None, +) -> t.Array[numpy.float32]: + """Compute the marginal rate of a target net income. Given a ``target`` net income, and according to the ``varying`` gross income. Optionally, a ``trim`` can be applied consisting of the lower and @@ -82,42 +78,37 @@ def marginal_rate( trim: The lower and upper bounds of the marginal rate. Returns: - :obj:`numpy.ndarray` of :obj:`float`: - - The marginal rate for each target. - - When ``trim`` is provided, values that are out of the provided bounds - are replaced by :obj:`numpy.nan`. + ndarray[float32]: The marginal rate for each target. When ``trim`` + is provided, values that are out of the provided bounds are + replaced by :class:`numpy.nan`. Examples: >>> target = numpy.array([1, 2, 3]) >>> varying = numpy.array([1, 2, 4]) - >>> trim = [.25, .75] + >>> trim = [0.25, 0.75] >>> marginal_rate(target, varying, trim) array([nan, 0.5]) """ + if not isinstance(varying, numpy.ndarray): + varying = numpy.array(varying, dtype=numpy.float32) - marginal_rate: ArrayType[float] - - marginal_rate = ( - + 1 - - (target[:-1] - target[1:]) - / (varying[:-1] - varying[1:]) - ) + marginal_rate = +1 - (target[:-1] - target[1:]) / (varying[:-1] - varying[1:]) if trim is not None: - marginal_rate = numpy.where( marginal_rate <= max(trim), marginal_rate, numpy.nan, - ) + ) marginal_rate = numpy.where( marginal_rate >= min(trim), marginal_rate, numpy.nan, - ) + ) return marginal_rate + + +__all__ = ["average_rate", "marginal_rate"] diff --git a/openfisca_core/commons/tests/test_formulas.py b/openfisca_core/commons/tests/test_formulas.py index f05725cb80..6fa98a7c20 100644 --- a/openfisca_core/commons/tests/test_formulas.py +++ b/openfisca_core/commons/tests/test_formulas.py @@ -5,9 +5,8 @@ from openfisca_core import commons -def test_apply_thresholds_when_several_inputs(): - """Makes a choice for any given input.""" - +def test_apply_thresholds_when_several_inputs() -> None: + """Make a choice for any given input.""" input_ = numpy.array([4, 5, 6, 7, 8, 9, 10]) thresholds = [5, 7, 9] choices = [10, 15, 20, 25] @@ -17,9 +16,8 @@ def test_apply_thresholds_when_several_inputs(): assert_array_equal(result, [10, 10, 15, 15, 20, 20, 25]) -def test_apply_thresholds_when_too_many_thresholds(): - """Raises an AssertionError when thresholds > choices.""" - +def test_apply_thresholds_when_too_many_thresholds() -> None: + """Raise an AssertionError when thresholds > choices.""" input_ = numpy.array([6]) thresholds = [5, 7, 9, 11] choices = [10, 15, 20] @@ -28,9 +26,8 @@ def test_apply_thresholds_when_too_many_thresholds(): assert commons.apply_thresholds(input_, thresholds, choices) -def test_apply_thresholds_when_too_many_choices(): - """Raises an AssertionError when thresholds < choices - 1.""" - +def test_apply_thresholds_when_too_many_choices() -> None: + """Raise an AssertionError when thresholds < choices - 1.""" input_ = numpy.array([6]) thresholds = [5, 7] choices = [10, 15, 20, 25] @@ -39,9 +36,8 @@ def test_apply_thresholds_when_too_many_choices(): assert commons.apply_thresholds(input_, thresholds, choices) -def test_concat_when_this_is_array_not_str(): - """Casts ``this`` to ``str`` when it is a numpy array other than string.""" - +def test_concat_when_this_is_array_not_str() -> None: + """Cast ``this`` to ``str`` when it is a NumPy array other than string.""" this = numpy.array([1, 2]) that = numpy.array(["la", "o"]) @@ -50,9 +46,8 @@ def test_concat_when_this_is_array_not_str(): assert_array_equal(result, ["1la", "2o"]) -def test_concat_when_that_is_array_not_str(): - """Casts ``that`` to ``str`` when it is a numpy array other than string.""" - +def test_concat_when_that_is_array_not_str() -> None: + """Cast ``that`` to ``str`` when it is a NumPy array other than string.""" this = numpy.array(["ho", "cha"]) that = numpy.array([1, 2]) @@ -61,19 +56,18 @@ def test_concat_when_that_is_array_not_str(): assert_array_equal(result, ["ho1", "cha2"]) -def test_concat_when_args_not_str_array_like(): - """Raises a TypeError when args are not a string array-like object.""" - +def test_concat_when_args_not_str_array_like() -> None: + """Cast ``this`` and ``that`` to a NumPy array or strings.""" this = (1, 2) that = (3, 4) - with pytest.raises(TypeError): - commons.concat(this, that) + result = commons.concat(this, that) + assert_array_equal(result, ["13", "24"]) -def test_switch_when_values_are_empty(): - """Raises an AssertionError when the values are empty.""" +def test_switch_when_values_are_empty() -> None: + """Raise an AssertionError when the values are empty.""" conditions = [1, 1, 1, 2] value_by_condition = {} diff --git a/openfisca_core/commons/tests/test_rates.py b/openfisca_core/commons/tests/test_rates.py index e603a05241..fbee4cc83c 100644 --- a/openfisca_core/commons/tests/test_rates.py +++ b/openfisca_core/commons/tests/test_rates.py @@ -1,26 +1,26 @@ +import math + import numpy from numpy.testing import assert_array_equal from openfisca_core import commons -def test_average_rate_when_varying_is_zero(): - """Yields infinity when the varying gross income crosses zero.""" - +def test_average_rate_when_varying_is_zero() -> None: + """Yield infinity when the varying gross income crosses zero.""" target = numpy.array([1, 2, 3]) varying = [0, 0, 0] result = commons.average_rate(target, varying) - assert_array_equal(result, [- numpy.inf, - numpy.inf, - numpy.inf]) - + assert_array_equal(result, numpy.array([-math.inf, -math.inf, -math.inf])) -def test_marginal_rate_when_varying_is_zero(): - """Yields infinity when the varying gross income crosses zero.""" +def test_marginal_rate_when_varying_is_zero() -> None: + """Yield infinity when the varying gross income crosses zero.""" target = numpy.array([1, 2, 3]) varying = numpy.array([0, 0, 0]) result = commons.marginal_rate(target, varying) - assert_array_equal(result, [numpy.inf, numpy.inf]) + assert_array_equal(result, numpy.array([math.inf, math.inf])) diff --git a/openfisca_core/commons/types.py b/openfisca_core/commons/types.py new file mode 100644 index 0000000000..39c067f455 --- /dev/null +++ b/openfisca_core/commons/types.py @@ -0,0 +1,3 @@ +from openfisca_core.types import Array, ArrayLike + +__all__ = ["Array", "ArrayLike"] diff --git a/openfisca_core/data_storage/__init__.py b/openfisca_core/data_storage/__init__.py index e2b4d8911d..4dbbb89543 100644 --- a/openfisca_core/data_storage/__init__.py +++ b/openfisca_core/data_storage/__init__.py @@ -1,25 +1,7 @@ -# Transitional imports to ensure non-breaking changes. -# Could be deprecated in the next major release. -# -# How imports are being used today: -# -# >>> from openfisca_core.module import symbol -# -# The previous example provokes cyclic dependency problems -# that prevent us from modularizing the different components -# of the library so to make them easier to test and to maintain. -# -# How could them be used after the next major release: -# -# >>> from openfisca_core import module -# >>> module.symbol() -# -# And for classes: -# -# >>> from openfisca_core.module import Symbol -# >>> Symbol() -# -# See: https://www.python.org/dev/peps/pep-0008/#imports +"""Different storage backends for the data of a simulation.""" -from .in_memory_storage import InMemoryStorage # noqa: F401 -from .on_disk_storage import OnDiskStorage # noqa: F401 +from . import types +from .in_memory_storage import InMemoryStorage +from .on_disk_storage import OnDiskStorage + +__all__ = ["InMemoryStorage", "OnDiskStorage", "types"] diff --git a/openfisca_core/data_storage/in_memory_storage.py b/openfisca_core/data_storage/in_memory_storage.py index bd40460a56..d4d5240c92 100644 --- a/openfisca_core/data_storage/in_memory_storage.py +++ b/openfisca_core/data_storage/in_memory_storage.py @@ -1,20 +1,62 @@ +from __future__ import annotations + +from collections.abc import KeysView, MutableMapping + import numpy from openfisca_core import periods +from openfisca_core.periods import DateUnit + +from . import types as t class InMemoryStorage: + """Storing and retrieving calculated vectors in memory. + + Args: + is_eternal: Whether the storage is eternal. + """ - Low-level class responsible for storing and retrieving calculated vectors in memory - """ - def __init__(self, is_eternal = False): + #: Whether the storage is eternal. + is_eternal: bool + + #: A dictionary containing data that has been stored in memory. + _arrays: MutableMapping[t.Period, t.Array[t.DTypeGeneric]] + + def __init__(self, is_eternal: bool = False) -> None: self._arrays = {} self.is_eternal = is_eternal - def get(self, period): + def get(self, period: None | t.Period = None) -> None | t.Array[t.DTypeGeneric]: + """Retrieve the data for the specified :obj:`.Period` from memory. + + Args: + period: The :obj:`.Period` for which data should be retrieved. + + Returns: + None: If no data is available. + EnumArray: The data for the specified :obj:`.Period`. + ndarray[generic]: The data for the specified :obj:`.Period`. + + Examples: + >>> import numpy + + >>> from openfisca_core import data_storage, periods + + >>> storage = data_storage.InMemoryStorage() + >>> value = numpy.array([1, 2, 3]) + >>> instant = periods.Instant((2017, 1, 1)) + >>> period = periods.Period(("year", instant, 1)) + + >>> storage.put(value, period) + + >>> storage.get(period) + array([1, 2, 3]) + + """ if self.is_eternal: - period = periods.period(periods.ETERNITY) + period = periods.period(DateUnit.ETERNITY) period = periods.period(period) values = self._arrays.get(period) @@ -22,43 +64,135 @@ def get(self, period): return None return values - def put(self, value, period): + def put(self, value: t.Array[t.DTypeGeneric], period: None | t.Period) -> None: + """Store the specified data in memory for the specified :obj:`.Period`. + + Args: + value: The data to store + period: The :obj:`.Period` for which the data should be stored. + + Examples: + >>> import numpy + + >>> from openfisca_core import data_storage, periods + + >>> storage = data_storage.InMemoryStorage() + >>> value = numpy.array([1, "2", "salary"]) + >>> instant = periods.Instant((2017, 1, 1)) + >>> period = periods.Period(("year", instant, 1)) + + >>> storage.put(value, period) + + >>> storage.get(period) + array(['1', '2', 'salary'], ...) + + """ if self.is_eternal: - period = periods.period(periods.ETERNITY) + period = periods.period(DateUnit.ETERNITY) period = periods.period(period) self._arrays[period] = value - def delete(self, period = None): + def delete(self, period: None | t.Period = None) -> None: + """Delete the data for the specified :obj:`.Period` from memory. + + Args: + period: The :obj:`.Period` for which data should be deleted. + + Note: + If ``period`` is specified, all data will be deleted. + + Examples: + >>> import numpy + + >>> from openfisca_core import data_storage, periods + + >>> storage = data_storage.InMemoryStorage() + >>> value = numpy.array([1, 2, 3]) + >>> instant = periods.Instant((2017, 1, 1)) + >>> period = periods.Period(("year", instant, 1)) + + >>> storage.put(value, period) + + >>> storage.get(period) + array([1, 2, 3]) + + >>> storage.delete(period) + + >>> storage.get(period) + + >>> storage.put(value, period) + + >>> storage.delete() + + >>> storage.get(period) + + """ if period is None: self._arrays = {} return if self.is_eternal: - period = periods.period(periods.ETERNITY) + period = periods.period(DateUnit.ETERNITY) period = periods.period(period) self._arrays = { period_item: value for period_item, value in self._arrays.items() if not period.contains(period_item) - } + } + + def get_known_periods(self) -> KeysView[t.Period]: + """List of storage's known periods. - def get_known_periods(self): + Returns: + KeysView[Period]: A sequence containing the storage's known periods. + + Examples: + >>> from openfisca_core import data_storage, periods + + >>> storage = data_storage.InMemoryStorage() + >>> storage.get_known_periods() + dict_keys([]) + + >>> instant = periods.Instant((2017, 1, 1)) + >>> period = periods.Period(("year", instant, 1)) + >>> storage.put([], period) + + >>> storage.get_known_periods() + dict_keys([Period(('year', Instant((2017, 1, 1)), 1))]) + + """ return self._arrays.keys() - def get_memory_usage(self): + def get_memory_usage(self) -> t.MemoryUsage: + """Memory usage of the storage. + + Returns: + MemoryUsage: A dictionary representing the storage's memory usage. + + Examples: + >>> from openfisca_core import data_storage + + >>> storage = data_storage.InMemoryStorage() + >>> storage.get_memory_usage() + {'nb_arrays': 0, 'total_nb_bytes': 0, 'cell_size': nan} + + """ if not self._arrays: - return dict( - nb_arrays = 0, - total_nb_bytes = 0, - cell_size = numpy.nan, - ) + return { + "nb_arrays": 0, + "total_nb_bytes": 0, + "cell_size": numpy.nan, + } nb_arrays = len(self._arrays) array = next(iter(self._arrays.values())) - return dict( - nb_arrays = nb_arrays, - total_nb_bytes = array.nbytes * nb_arrays, - cell_size = array.itemsize, - ) + return { + "nb_arrays": nb_arrays, + "total_nb_bytes": array.nbytes * nb_arrays, + "cell_size": array.itemsize, + } + + +__all__ = ["InMemoryStorage"] diff --git a/openfisca_core/data_storage/on_disk_storage.py b/openfisca_core/data_storage/on_disk_storage.py index 10d4696b58..a13ce37fc6 100644 --- a/openfisca_core/data_storage/on_disk_storage.py +++ b/openfisca_core/data_storage/on_disk_storage.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from collections.abc import KeysView, MutableMapping + import os import shutil @@ -5,30 +9,127 @@ from openfisca_core import periods from openfisca_core.indexed_enums import EnumArray +from openfisca_core.periods import DateUnit + +from . import types as t class OnDiskStorage: + """Storing and retrieving calculated vectors on disk. + + Args: + storage_dir: Path to store calculated vectors. + is_eternal: Whether the storage is eternal. + preserve_storage_dir: Whether to preserve the storage directory. + """ - Low-level class responsible for storing and retrieving calculated vectors on disk - """ - def __init__(self, storage_dir, is_eternal = False, preserve_storage_dir = False): + #: A dictionary containing data that has been stored on disk. + storage_dir: str + + #: Whether the storage is eternal. + is_eternal: bool + + #: Whether to preserve the storage directory. + preserve_storage_dir: bool + + #: Mapping of file paths to possible :class:`.Enum` values. + _enums: MutableMapping[str, type[t.Enum]] + + #: Mapping of periods to file paths. + _files: MutableMapping[t.Period, str] + + def __init__( + self, + storage_dir: str, + is_eternal: bool = False, + preserve_storage_dir: bool = False, + ) -> None: self._files = {} self._enums = {} self.is_eternal = is_eternal self.preserve_storage_dir = preserve_storage_dir self.storage_dir = storage_dir - def _decode_file(self, file): + def _decode_file(self, file: str) -> t.Array[t.DTypeGeneric]: + """Decode a file by loading its contents as a :mod:`numpy` array. + + Args: + file: Path to the file to be decoded. + + Returns: + EnumArray: Representing the data in the file. + ndarray[generic]: Representing the data in the file. + + Note: + If the file is associated with :class:`~indexed_enums.Enum` values, the + array is converted back to an :obj:`~indexed_enums.EnumArray` object. + + Examples: + >>> import tempfile + + >>> import numpy + + >>> from openfisca_core import data_storage, indexed_enums, periods + + >>> class Housing(indexed_enums.Enum): + ... OWNER = "Owner" + ... TENANT = "Tenant" + ... FREE_LODGER = "Free lodger" + ... HOMELESS = "Homeless" + + >>> array = numpy.array([1]) + >>> value = indexed_enums.EnumArray(array, Housing) + >>> instant = periods.Instant((2017, 1, 1)) + >>> period = periods.Period(("year", instant, 1)) + + >>> with tempfile.TemporaryDirectory() as directory: + ... storage = data_storage.OnDiskStorage(directory) + ... storage.put(value, period) + ... storage._decode_file(storage._files[period]) + EnumArray([Housing.TENANT]) + + """ enum = self._enums.get(file) + if enum is not None: return EnumArray(numpy.load(file), enum) - else: - return numpy.load(file) - def get(self, period): + array: t.Array[t.DTypeGeneric] = numpy.load(file) + + return array + + def get(self, period: None | t.Period = None) -> None | t.Array[t.DTypeGeneric]: + """Retrieve the data for the specified period from disk. + + Args: + period: The period for which data should be retrieved. + + Returns: + None: If no data is available. + EnumArray: Representing the data for the specified period. + ndarray[generic]: Representing the data for the specified period. + + Examples: + >>> import tempfile + + >>> import numpy + + >>> from openfisca_core import data_storage, periods + + >>> value = numpy.array([1, 2, 3]) + >>> instant = periods.Instant((2017, 1, 1)) + >>> period = periods.Period(("year", instant, 1)) + + >>> with tempfile.TemporaryDirectory() as directory: + ... storage = data_storage.OnDiskStorage(directory) + ... storage.put(value, period) + ... storage.get(period) + array([1, 2, 3]) + + """ if self.is_eternal: - period = periods.period(periods.ETERNITY) + period = periods.period(DateUnit.ETERNITY) period = periods.period(period) values = self._files.get(period) @@ -36,50 +137,166 @@ def get(self, period): return None return self._decode_file(values) - def put(self, value, period): + def put(self, value: t.Array[t.DTypeGeneric], period: None | t.Period) -> None: + """Store the specified data on disk for the specified period. + + Args: + value: The data to store + period: The period for which the data should be stored. + + Examples: + >>> import tempfile + + >>> import numpy + + >>> from openfisca_core import data_storage, periods + + >>> value = numpy.array([1, "2", "salary"]) + >>> instant = periods.Instant((2017, 1, 1)) + >>> period = periods.Period(("year", instant, 1)) + + >>> with tempfile.TemporaryDirectory() as directory: + ... storage = data_storage.OnDiskStorage(directory) + ... storage.put(value, period) + ... storage.get(period) + array(['1', '2', 'salary'], ...) + + """ if self.is_eternal: - period = periods.period(periods.ETERNITY) + period = periods.period(DateUnit.ETERNITY) period = periods.period(period) filename = str(period) - path = os.path.join(self.storage_dir, filename) + '.npy' - if isinstance(value, EnumArray): + path = os.path.join(self.storage_dir, filename) + ".npy" + if isinstance(value, EnumArray) and value.possible_values is not None: self._enums[path] = value.possible_values value = value.view(numpy.ndarray) numpy.save(path, value) self._files[period] = path - def delete(self, period = None): + def delete(self, period: None | t.Period = None) -> None: + """Delete the data for the specified period from disk. + + Args: + period: The period for which data should be deleted. If not + specified, all data will be deleted. + + Examples: + >>> import tempfile + + >>> import numpy + + >>> from openfisca_core import data_storage, periods + + >>> value = numpy.array([1, 2, 3]) + >>> instant = periods.Instant((2017, 1, 1)) + >>> period = periods.Period(("year", instant, 1)) + + >>> with tempfile.TemporaryDirectory() as directory: + ... storage = data_storage.OnDiskStorage(directory) + ... storage.put(value, period) + ... storage.get(period) + array([1, 2, 3]) + + >>> with tempfile.TemporaryDirectory() as directory: + ... storage = data_storage.OnDiskStorage(directory) + ... storage.put(value, period) + ... storage.delete(period) + ... storage.get(period) + + >>> with tempfile.TemporaryDirectory() as directory: + ... storage = data_storage.OnDiskStorage(directory) + ... storage.put(value, period) + ... storage.delete() + ... storage.get(period) + + """ if period is None: self._files = {} return if self.is_eternal: - period = periods.period(periods.ETERNITY) + period = periods.period(DateUnit.ETERNITY) period = periods.period(period) - if period is not None: - self._files = { - period_item: value - for period_item, value in self._files.items() - if not period.contains(period_item) - } + self._files = { + period_item: value + for period_item, value in self._files.items() + if not period.contains(period_item) + } + + def get_known_periods(self) -> KeysView[t.Period]: + """List of storage's known periods. + + Returns: + KeysView[Period]: A sequence containing the storage's known periods. + + Examples: + >>> import tempfile - def get_known_periods(self): + >>> import numpy + + >>> from openfisca_core import data_storage, periods + + >>> instant = periods.Instant((2017, 1, 1)) + >>> period = periods.Period(("year", instant, 1)) + + >>> with tempfile.TemporaryDirectory() as directory: + ... storage = data_storage.OnDiskStorage(directory) + ... storage.get_known_periods() + dict_keys([]) + + >>> with tempfile.TemporaryDirectory() as directory: + ... storage = data_storage.OnDiskStorage(directory) + ... storage.put([], period) + ... storage.get_known_periods() + dict_keys([Period(('year', Instant((2017, 1, 1)), 1))]) + + """ return self._files.keys() - def restore(self): + def restore(self) -> None: + """Restore the storage from disk. + + Examples: + >>> import tempfile + + >>> import numpy + + >>> from openfisca_core import data_storage, periods + + >>> value = numpy.array([1, 2, 3]) + >>> instant = periods.Instant((2017, 1, 1)) + >>> period = periods.Period(("year", instant, 1)) + >>> directory = tempfile.TemporaryDirectory() + + >>> storage1 = data_storage.OnDiskStorage(directory.name) + >>> storage1.put(value, period) + >>> storage1._files + {Period(('year', Instant((2017, 1, 1)), 1)): '...2017.npy'} + + >>> storage2 = data_storage.OnDiskStorage(directory.name) + >>> storage2._files + {} + + >>> storage2.restore() + >>> storage2._files + {Period((, Instant((2017, 1, 1...2017.npy'} + + >>> directory.cleanup() + + """ self._files = files = {} # Restore self._files from content of storage_dir. for filename in os.listdir(self.storage_dir): - if not filename.endswith('.npy'): + if not filename.endswith(".npy"): continue path = os.path.join(self.storage_dir, filename) - filename_core = filename.rsplit('.', 1)[0] + filename_core = filename.rsplit(".", 1)[0] period = periods.period(filename_core) files[period] = path - def __del__(self): + def __del__(self) -> None: if self.preserve_storage_dir: return shutil.rmtree(self.storage_dir) # Remove the holder temporary files @@ -87,3 +304,6 @@ def __del__(self): parent_dir = os.path.abspath(os.path.join(self.storage_dir, os.pardir)) if not os.listdir(parent_dir): shutil.rmtree(parent_dir) + + +__all__ = ["OnDiskStorage"] diff --git a/openfisca_core/data_storage/types.py b/openfisca_core/data_storage/types.py new file mode 100644 index 0000000000..db71abbf57 --- /dev/null +++ b/openfisca_core/data_storage/types.py @@ -0,0 +1,14 @@ +from typing_extensions import TypedDict + +from openfisca_core.types import Array, DTypeGeneric, Enum, Period + + +class MemoryUsage(TypedDict, total=True): + """Memory usage information.""" + + cell_size: float + nb_arrays: int + total_nb_bytes: int + + +__all__ = ["Array", "DTypeGeneric", "Enum", "Period"] diff --git a/openfisca_core/entities/__init__.py b/openfisca_core/entities/__init__.py index 15b38e2a5c..1811e3fe94 100644 --- a/openfisca_core/entities/__init__.py +++ b/openfisca_core/entities/__init__.py @@ -1,27 +1,23 @@ -# Transitional imports to ensure non-breaking changes. -# Could be deprecated in the next major release. -# -# How imports are being used today: -# -# >>> from openfisca_core.module import symbol -# -# The previous example provokes cyclic dependency problems -# that prevent us from modularizing the different components -# of the library so to make them easier to test and to maintain. -# -# How could them be used after the next major release: -# -# >>> from openfisca_core import module -# >>> module.symbol() -# -# And for classes: -# -# >>> from openfisca_core import module -# >>> module.Symbol() -# -# See: https://www.python.org/dev/peps/pep-0008/#imports +"""Provide a way of representing the entities of a rule system.""" -from .helpers import build_entity # noqa: F401 -from .role import Role # noqa: F401 -from .entity import Entity # noqa: F401 -from .group_entity import GroupEntity # noqa: F401 +from . import types +from ._core_entity import CoreEntity +from .entity import Entity +from .group_entity import GroupEntity +from .helpers import build_entity, find_role +from .role import Role + +SingleEntity = Entity +check_role_validity = CoreEntity.check_role_validity + +__all__ = [ + "CoreEntity", + "Entity", + "GroupEntity", + "Role", + "SingleEntity", + "build_entity", + "check_role_validity", + "find_role", + "types", +] diff --git a/openfisca_core/entities/_core_entity.py b/openfisca_core/entities/_core_entity.py new file mode 100644 index 0000000000..33002e9af5 --- /dev/null +++ b/openfisca_core/entities/_core_entity.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +from typing import ClassVar + +import abc +import os + +from . import types as t +from .role import Role + + +class CoreEntity: + """Base class to build entities from. + + Args: + *__args: Any arguments. + **__kwargs: Any keyword arguments. + + Examples: + >>> from openfisca_core import entities + >>> from openfisca_core.entities import types as t + + >>> class Entity(entities.CoreEntity): + ... def __init__(self, key): + ... self.key = t.EntityKey(key) + + >>> Entity("individual") + Entity(individual) + + """ + + #: A key to identify the ``CoreEntity``. + key: t.EntityKey + + #: The ``key`` pluralised. + plural: t.EntityPlural + + #: A summary description. + label: str + + #: A full description. + doc: str + + #: Whether the ``CoreEntity`` is a person or not. + is_person: ClassVar[bool] + + #: A ``TaxBenefitSystem`` instance. + _tax_benefit_system: None | t.TaxBenefitSystem = None + + @abc.abstractmethod + def __init__(self, *__args: object, **__kwargs: object) -> None: ... + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.key})" + + def set_tax_benefit_system(self, tax_benefit_system: t.TaxBenefitSystem) -> None: + """A ``CoreEntity`` belongs to a ``TaxBenefitSystem``.""" + self._tax_benefit_system = tax_benefit_system + + def get_variable( + self, + variable_name: t.VariableName, + check_existence: bool = False, + ) -> t.Variable | None: + """Get ``variable_name`` from ``variables``. + + Args: + variable_name: The ``Variable`` to be found. + check_existence: Was the ``Variable`` found? + + Returns: + Variable: When the ``Variable`` exists. + None: When the ``Variable`` doesn't exist. + + Raises: + ValueError: When the :attr:`_tax_benefit_system` is not set yet. + ValueError: When ``check_existence`` is ``True`` and + the ``Variable`` doesn't exist. + + Examples: + >>> from openfisca_core import ( + ... entities, + ... periods, + ... taxbenefitsystems, + ... variables, + ... ) + + >>> this = entities.SingleEntity("this", "", "", "") + >>> that = entities.SingleEntity("that", "", "", "") + + >>> this.get_variable("tax") + Traceback (most recent call last): + ValueError: You must set 'tax_benefit_system' before calling this... + + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem([this]) + >>> this.set_tax_benefit_system(tax_benefit_system) + + >>> this.get_variable("tax") + + >>> this.get_variable("tax", check_existence=True) + Traceback (most recent call last): + VariableNotFoundError: You tried to calculate or to set a value... + + >>> class tax(variables.Variable): + ... definition_period = periods.MONTH + ... value_type = float + ... entity = that + + >>> this._tax_benefit_system.add_variable(tax) + + + >>> this.get_variable("tax") + + + """ + if self._tax_benefit_system is None: + msg = "You must set 'tax_benefit_system' before calling this method." + raise ValueError( + msg, + ) + return self._tax_benefit_system.get_variable(variable_name, check_existence) + + def check_variable_defined_for_entity(self, variable_name: t.VariableName) -> None: + """Check if ``variable_name`` is defined for ``self``. + + Args: + variable_name: The ``Variable`` to be found. + + Raises: + ValueError: When the ``Variable`` exists but is defined + for another ``Entity``. + + Examples: + >>> from openfisca_core import ( + ... entities, + ... periods, + ... taxbenefitsystems, + ... variables, + ... ) + + >>> this = entities.SingleEntity("this", "", "", "") + >>> that = entities.SingleEntity("that", "", "", "") + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem([that]) + >>> this.set_tax_benefit_system(tax_benefit_system) + + >>> this.check_variable_defined_for_entity("tax") + Traceback (most recent call last): + VariableNotFoundError: You tried to calculate or to set a value... + + >>> class tax(variables.Variable): + ... definition_period = periods.WEEK + ... value_type = int + ... entity = that + + >>> this._tax_benefit_system.add_variable(tax) + + + >>> this.check_variable_defined_for_entity("tax") + Traceback (most recent call last): + ValueError: You tried to compute the variable 'tax' for the enti... + + >>> tax.entity = this + + >>> this._tax_benefit_system.update_variable(tax) + + + >>> this.check_variable_defined_for_entity("tax") + + """ + entity: None | t.CoreEntity = None + variable: None | t.Variable = self.get_variable( + variable_name, + check_existence=True, + ) + + if variable is not None: + entity = variable.entity + + if entity is None: + return + + if entity.key != self.key: + message = ( + f"You tried to compute the variable '{variable_name}' for", + f"the entity '{self.plural}'; however the variable", + f"'{variable_name}' is defined for '{entity.plural}'.", + "Learn more about entities in our documentation:", + ".", + ) + raise ValueError(os.linesep.join(message)) + + @staticmethod + def check_role_validity(role: object) -> None: + """Check if ``role`` is an instance of ``Role``. + + Args: + role: Any object. + + Raises: + ValueError: When ``role`` is not a ``Role``. + + Examples: + >>> from openfisca_core import entities + + >>> role = entities.Role({"key": "key"}, object()) + >>> entities.check_role_validity(role) + + >>> entities.check_role_validity("hey!") + Traceback (most recent call last): + ValueError: hey! is not a valid role + + """ + if role is not None and not isinstance(role, Role): + msg = f"{role} is not a valid role" + raise ValueError(msg) + + +__all__ = ["CoreEntity"] diff --git a/openfisca_core/entities/_description.py b/openfisca_core/entities/_description.py new file mode 100644 index 0000000000..6e2d68af1b --- /dev/null +++ b/openfisca_core/entities/_description.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import dataclasses +import textwrap + + +@dataclasses.dataclass(frozen=True) +class _Description: + r"""A ``Role``'s description. + + Examples: + >>> data = { + ... "key": "parent", + ... "label": "Parents", + ... "plural": "parents", + ... "doc": "\t\t\tThe one/two adults in charge of the household.", + ... } + + >>> description = _Description(**data) + + >>> repr(_Description) + "" + + >>> repr(description) + "_Description(key='parent', plural='parents', label='Parents', ...)" + + >>> str(description) + "_Description(key='parent', plural='parents', label='Parents', ...)" + + >>> {description} + {_Description(key='parent', plural='parents', label='Parents', doc=...} + + >>> description.key + 'parent' + + """ + + #: A key to identify the ``Role``. + key: str + + #: The ``key`` pluralised. + plural: None | str = None + + #: A summary description. + label: None | str = None + + #: A full description, non-indented. + doc: None | str = None + + def __post_init__(self) -> None: + if self.doc is not None: + object.__setattr__(self, "doc", textwrap.dedent(self.doc)) + + +__all__ = ["_Description"] diff --git a/openfisca_core/entities/entity.py b/openfisca_core/entities/entity.py index c0da47f9bc..673aae48b7 100644 --- a/openfisca_core/entities/entity.py +++ b/openfisca_core/entities/entity.py @@ -1,40 +1,61 @@ -import os +from typing import ClassVar + import textwrap -from openfisca_core.entities import Role +from . import types as t +from ._core_entity import CoreEntity -class Entity: - """ - Represents an entity (e.g. a person, a household, etc.) on which calculations can be run. +class Entity(CoreEntity): + r"""An entity (e.g. a person, a household) on which calculations can be run. + + Args: + key: A key to identify the ``Entity``. + plural: The ``key`` pluralised. + label: A summary description. + doc: A full description. + + Examples: + >>> from openfisca_core import entities + + >>> entity = entities.SingleEntity( + ... "individual", + ... "individuals", + ... "An individual", + ... "\t\t\tThe minimal legal entity on which a rule might be a...", + ... ) + + >>> repr(entities.SingleEntity) + "" + + >>> repr(entity) + 'Entity(individual)' + + >>> str(entity) + 'Entity(individual)' + """ - def __init__(self, key, plural, label, doc): - self.key = key + #: A key to identify the ``Entity``. + key: t.EntityKey + + #: The ``key`` pluralised. + plural: t.EntityPlural + + #: A summary description. + label: str + + #: A full description. + doc: str + + #: Whether the ``Entity`` is a person or not. + is_person: ClassVar[bool] = True + + def __init__(self, key: str, plural: str, label: str, doc: str) -> None: + self.key = t.EntityKey(key) + self.plural = t.EntityPlural(plural) self.label = label - self.plural = plural self.doc = textwrap.dedent(doc) - self.is_person = True - self._tax_benefit_system = None - - def set_tax_benefit_system(self, tax_benefit_system): - self._tax_benefit_system = tax_benefit_system - - def check_role_validity(self, role): - if role is not None and not type(role) == Role: - raise ValueError("{} is not a valid role".format(role)) - - def get_variable(self, variable_name, check_existence = False): - return self._tax_benefit_system.get_variable(variable_name, check_existence) - - def check_variable_defined_for_entity(self, variable_name): - variable_entity = self.get_variable(variable_name, check_existence = True).entity - # Should be this: - # if variable_entity is not self: - if variable_entity.key != self.key: - message = os.linesep.join([ - "You tried to compute the variable '{0}' for the entity '{1}';".format(variable_name, self.plural), - "however the variable '{0}' is defined for '{1}'.".format(variable_name, variable_entity.plural), - "Learn more about entities in our documentation:", - "."]) - raise ValueError(message) + + +__all__ = ["Entity"] diff --git a/openfisca_core/entities/group_entity.py b/openfisca_core/entities/group_entity.py index 0d58acc6ba..796da105ee 100644 --- a/openfisca_core/entities/group_entity.py +++ b/openfisca_core/entities/group_entity.py @@ -1,43 +1,124 @@ -from openfisca_core.entities import Entity, Role +from __future__ import annotations +from collections.abc import Iterable, Sequence +from typing import ClassVar -class GroupEntity(Entity): - """Represents an entity containing several others with different roles. +import textwrap +from itertools import chain - A :class:`.GroupEntity` represents an :class:`.Entity` containing - several other :class:`.Entity` with different :class:`.Role`, and on - which calculations can be run. +from . import types as t +from ._core_entity import CoreEntity +from .role import Role + + +class GroupEntity(CoreEntity): + r"""Represents an entity containing several others with different roles. + + A ``GroupEntity`` represents an ``Entity`` containing several other entities, + with different roles, and on which calculations can be run. Args: - key: A key to identify the group entity. - plural: The ``key``, pluralised. + key: A key to identify the ``GroupEntity``. + plural: The ``key`` pluralised. label: A summary description. doc: A full description. - roles: The list of :class:`.Role` of the group entity. + roles: The list of roles of the group entity. containing_entities: The list of keys of group entities whose members are guaranteed to be a superset of this group's entities. - .. versionchanged:: 35.7.0 - Added ``containing_entities``, that allows the defining of group - entities which entirely contain other group entities. + Examples: + >>> from openfisca_core import entities + + >>> family_roles = [ + ... { + ... "key": "parent", + ... "subroles": ["first_parent", "second_parent"], + ... } + ... ] + + >>> family = entities.GroupEntity( + ... "family", + ... "families", + ... "A family", + ... "\t\t\tAll the people somehow related living together.", + ... family_roles, + ... ) + + >>> household_roles = [ + ... { + ... "key": "partners", + ... "subroles": ["first_partner", "second_partner"], + ... } + ... ] + + >>> household = entities.GroupEntity( + ... "household", + ... "households", + ... "A household", + ... "\t\t\tAll the people who live together in the same place.", + ... household_roles, + ... (family.key,), + ... ) + + >>> repr(entities.GroupEntity) + "" + + >>> repr(household) + 'GroupEntity(household)' + + >>> str(household) + 'GroupEntity(household)' """ - def __init__(self, key, plural, label, doc, roles, containing_entities = ()): - super().__init__(key, plural, label, doc) + #: A key to identify the ``Entity``. + key: t.EntityKey + + #: The ``key`` pluralised. + plural: t.EntityPlural + + #: A summary description. + label: str + + #: A full description. + doc: str + + #: The list of roles of the ``GroupEntity``. + roles: Iterable[Role] + + #: Whether the entity is a person or not. + is_person: ClassVar[bool] = False + + def __init__( + self, + key: str, + plural: str, + label: str, + doc: str, + roles: Sequence[t.RoleParams], + containing_entities: Iterable[str] = (), + ) -> None: + self.key = t.EntityKey(key) + self.plural = t.EntityPlural(plural) + self.label = label + self.doc = textwrap.dedent(doc) self.roles_description = roles - self.roles = [] + self.roles: Iterable[Role] = () for role_description in roles: role = Role(role_description, self) setattr(self, role.key.upper(), role) - self.roles.append(role) - if role_description.get('subroles'): - role.subroles = [] - for subrole_key in role_description['subroles']: - subrole = Role({'key': subrole_key, 'max': 1}, self) + self.roles = (*self.roles, role) + if subroles := role_description.get("subroles"): + role.subroles = () + for subrole_key in subroles: + subrole = Role({"key": subrole_key, "max": 1}, self) setattr(self, subrole.key.upper(), subrole) - role.subroles.append(subrole) + role.subroles = (*role.subroles, subrole) role.max = len(role.subroles) - self.flattened_roles = sum([role2.subroles or [role2] for role2 in self.roles], []) - self.is_person = False + self.flattened_roles = tuple( + chain.from_iterable(role.subroles or [role] for role in self.roles), + ) self.containing_entities = containing_entities + + +__all__ = ["GroupEntity"] diff --git a/openfisca_core/entities/helpers.py b/openfisca_core/entities/helpers.py index 86d7bb6a6b..1dcdad88a3 100644 --- a/openfisca_core/entities/helpers.py +++ b/openfisca_core/entities/helpers.py @@ -1,8 +1,165 @@ -from openfisca_core import entities +from __future__ import annotations +from collections.abc import Iterable, Sequence -def build_entity(key, plural, label, doc = "", roles = None, is_person = False, class_override = None, containing_entities = ()): +from . import types as t +from .entity import Entity as SingleEntity +from .group_entity import GroupEntity + + +def build_entity( + key: str, + plural: str, + label: str, + doc: str = "", + roles: None | Sequence[t.RoleParams] = None, + is_person: bool = False, + *, + class_override: object = None, + containing_entities: Sequence[str] = (), +) -> t.SingleEntity | t.GroupEntity: + """Build an ``Entity`` or a ``GroupEntity``. + + Args: + key: Key to identify the ``Entity`` or ``GroupEntity``. + plural: The ``key`` pluralised. + label: A summary description. + doc: A full description. + roles: A list of roles —if it's a ``GroupEntity``. + is_person: If is an individual, or not. + class_override: ? + containing_entities: Keys of contained entities. + + Returns: + Entity: When ``is_person`` is ``True``. + GroupEntity: When ``is_person`` is ``False``. + + Raises: + NotImplementedError: If ``roles`` is ``None``. + + Examples: + >>> from openfisca_core import entities + + >>> entity = entities.build_entity( + ... "syndicate", + ... "syndicates", + ... "Banks loaning jointly.", + ... roles=[], + ... containing_entities=(), + ... ) + >>> entity + GroupEntity(syndicate) + + >>> entities.build_entity( + ... "company", + ... "companies", + ... "A small or medium company.", + ... is_person=True, + ... ) + Entity(company) + + >>> role = entities.Role({"key": "key"}, entity) + + >>> entities.build_entity( + ... "syndicate", + ... "syndicates", + ... "Banks loaning jointly.", + ... roles=[role], + ... ) + Traceback (most recent call last): + TypeError: 'Role' object is not subscriptable + + """ if is_person: - return entities.Entity(key, plural, label, doc) - else: - return entities.GroupEntity(key, plural, label, doc, roles, containing_entities = containing_entities) + return SingleEntity(key, plural, label, doc) + + if roles is not None: + return GroupEntity( + key, + plural, + label, + doc, + roles, + containing_entities=containing_entities, + ) + + raise NotImplementedError + + +def find_role( + roles: Iterable[t.Role], + key: t.RoleKey, + *, + total: None | int = None, +) -> None | t.Role: + """Find a ``Role`` in a ``GroupEntity``. + + Args: + roles: The roles to search. + key: The key of the role to find. + total: The ``max`` attribute of the role to find. + + Returns: + Role: The role if found + None: Else ``None``. + + Examples: + >>> from openfisca_core import entities + >>> from openfisca_core.entities import types as t + + >>> principal = t.RoleParams( + ... key="principal", + ... label="Principal", + ... doc="Person focus of a calculation in a family context.", + ... max=1, + ... ) + + >>> partner = t.RoleParams( + ... key="partner", + ... plural="partners", + ... label="Partners", + ... doc="Persons partners of the principal.", + ... ) + + >>> parent = t.RoleParams( + ... key="parent", + ... plural="parents", + ... label="Parents", + ... doc="Persons parents of children of the principal", + ... subroles=["first_parent", "second_parent"], + ... ) + + >>> group_entity = entities.build_entity( + ... key="family", + ... plural="families", + ... label="Family", + ... doc="A Family represents a collection of related persons.", + ... roles=[principal, partner, parent], + ... ) + + >>> entities.find_role(group_entity.roles, "principal", total=1) + Role(principal) + + >>> entities.find_role(group_entity.roles, "partner") + Role(partner) + + >>> entities.find_role(group_entity.roles, "parent", total=2) + Role(parent) + + >>> entities.find_role(group_entity.roles, "first_parent", total=1) + Role(first_parent) + + """ + for role in roles: + if role.subroles: + for subrole in role.subroles: + if (subrole.max == total) and (subrole.key == key): + return subrole + + if (role.max == total) and (role.key == key): + return role + + return None + + +__all__ = ["build_entity", "find_role"] diff --git a/tests/web_api/loader/__init__.py b/openfisca_core/entities/py.typed similarity index 100% rename from tests/web_api/loader/__init__.py rename to openfisca_core/entities/py.typed diff --git a/openfisca_core/entities/role.py b/openfisca_core/entities/role.py index ea815ed513..39bd5090ed 100644 --- a/openfisca_core/entities/role.py +++ b/openfisca_core/entities/role.py @@ -1,16 +1,92 @@ -import textwrap +from __future__ import annotations + +from collections.abc import Iterable + +from . import types as t +from ._description import _Description class Role: + """The role of an ``Entity`` within a ``GroupEntity``. + + Each ``Entity`` related to a ``GroupEntity`` has a ``Role``. For example, + if you have a family, its roles could include a parent, a child, and so on. + Or if you have a tax household, its roles could include the taxpayer, a + spouse, several dependents, and the like. + + Args: + description: A description of the Role. + entity: The Entity to which the Role belongs. + + Examples: + >>> from openfisca_core import entities + + >>> entity = entities.GroupEntity("key", "plural", "label", "doc", []) + >>> role = entities.Role({"key": "parent"}, entity) + + >>> repr(entities.Role) + "" + + >>> repr(role) + 'Role(parent)' + + >>> str(role) + 'Role(parent)' + + >>> {role} + {Role(parent)} + + >>> role.key + 'parent' + + """ - def __init__(self, description, entity): + #: The ``GroupEntity`` the Role belongs to. + entity: t.GroupEntity + + #: A description of the ``Role``. + description: _Description + + #: Max number of members. + max: None | int = None + + #: A list of subroles. + subroles: None | Iterable[Role] = None + + @property + def key(self) -> t.RoleKey: + """A key to identify the ``Role``.""" + return t.RoleKey(self.description.key) + + @property + def plural(self) -> None | t.RolePlural: + """The ``key`` pluralised.""" + if (plural := self.description.plural) is None: + return None + return t.RolePlural(plural) + + @property + def label(self) -> None | str: + """A summary description.""" + return self.description.label + + @property + def doc(self) -> None | str: + """A full description, non-indented.""" + return self.description.doc + + def __init__(self, description: t.RoleParams, entity: t.GroupEntity) -> None: + self.description = _Description( + key=description["key"], + plural=description.get("plural"), + label=description.get("label"), + doc=description.get("doc"), + ) self.entity = entity - self.key = description['key'] - self.label = description.get('label') - self.plural = description.get('plural') - self.doc = textwrap.dedent(description.get('doc', "")) - self.max = description.get('max') - self.subroles = None - - def __repr__(self): - return "Role({})".format(self.key) + self.max = description.get("max") + + def __repr__(self) -> str: + return f"Role({self.key})" + + +__all__ = ["Role"] diff --git a/openfisca_core/entities/tests/__init__.py b/openfisca_core/entities/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfisca_core/entities/tests/test_entity.py b/openfisca_core/entities/tests/test_entity.py new file mode 100644 index 0000000000..b3cb813ddc --- /dev/null +++ b/openfisca_core/entities/tests/test_entity.py @@ -0,0 +1,10 @@ +from openfisca_core import entities + + +def test_init_when_doc_indented() -> None: + """De-indent the ``doc`` attribute if it is passed at initialisation.""" + key = "\tkey" + doc = "\tdoc" + entity = entities.Entity(key, "label", "plural", doc) + assert entity.key == key + assert entity.doc == doc.lstrip() diff --git a/openfisca_core/entities/tests/test_group_entity.py b/openfisca_core/entities/tests/test_group_entity.py new file mode 100644 index 0000000000..092c9d3575 --- /dev/null +++ b/openfisca_core/entities/tests/test_group_entity.py @@ -0,0 +1,70 @@ +from collections.abc import Mapping +from typing import Any + +import pytest + +from openfisca_core import entities + + +@pytest.fixture +def parent() -> str: + return "parent" + + +@pytest.fixture +def uncle() -> str: + return "uncle" + + +@pytest.fixture +def first_parent() -> str: + return "first_parent" + + +@pytest.fixture +def second_parent() -> str: + return "second_parent" + + +@pytest.fixture +def third_parent() -> str: + return "third_parent" + + +@pytest.fixture +def role(parent: str, first_parent: str, third_parent: str) -> Mapping[str, Any]: + return {"key": parent, "subroles": {first_parent, third_parent}} + + +@pytest.fixture +def group_entity(role: Mapping[str, Any]) -> entities.GroupEntity: + return entities.GroupEntity("key", "label", "plural", "doc", (role,)) + + +def test_init_when_doc_indented() -> None: + """De-indent the ``doc`` attribute if it is passed at initialisation.""" + key = "\tkey" + doc = "\tdoc" + group_entity = entities.GroupEntity(key, "label", "plural", doc, ()) + assert group_entity.key == key + assert group_entity.doc == doc.lstrip() + + +def test_group_entity_with_roles( + group_entity: entities.GroupEntity, + parent: str, + uncle: str, +) -> None: + """Assign a Role for each role-like passed as argument.""" + assert hasattr(group_entity, parent.upper()) + assert not hasattr(group_entity, uncle.upper()) + + +def test_group_entity_with_subroles( + group_entity: entities.GroupEntity, + first_parent: str, + second_parent: str, +) -> None: + """Assign a Role for each subrole-like passed as argument.""" + assert hasattr(group_entity, first_parent.upper()) + assert not hasattr(group_entity, second_parent.upper()) diff --git a/openfisca_core/entities/tests/test_role.py b/openfisca_core/entities/tests/test_role.py new file mode 100644 index 0000000000..454d862c70 --- /dev/null +++ b/openfisca_core/entities/tests/test_role.py @@ -0,0 +1,11 @@ +from openfisca_core import entities + + +def test_init_when_doc_indented() -> None: + """De-indent the ``doc`` attribute if it is passed at initialisation.""" + key = "\tkey" + doc = "\tdoc" + entity = entities.GroupEntity("key", "plural", "label", "doc", []) + role = entities.Role({"key": key, "doc": doc}, entity) + assert role.key == key + assert role.doc == doc.lstrip() diff --git a/openfisca_core/entities/types.py b/openfisca_core/entities/types.py new file mode 100644 index 0000000000..ef6af9024f --- /dev/null +++ b/openfisca_core/entities/types.py @@ -0,0 +1,42 @@ +from typing_extensions import Required, TypedDict + +from openfisca_core.types import ( + CoreEntity, + EntityKey, + EntityPlural, + GroupEntity, + Role, + RoleKey, + RolePlural, + SingleEntity, + TaxBenefitSystem, + Variable, + VariableName, +) + +# Entities + + +class RoleParams(TypedDict, total=False): + key: Required[str] + plural: str + label: str + doc: str + max: int + subroles: list[str] + + +__all__ = [ + "CoreEntity", + "EntityKey", + "EntityPlural", + "GroupEntity", + "Role", + "RoleKey", + "RoleParams", + "RolePlural", + "SingleEntity", + "TaxBenefitSystem", + "Variable", + "VariableName", +] diff --git a/openfisca_core/errors/__init__.py b/openfisca_core/errors/__init__.py index e5b9abbc78..afe88980d9 100644 --- a/openfisca_core/errors/__init__.py +++ b/openfisca_core/errors/__init__.py @@ -21,13 +21,26 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .cycle_error import CycleError # noqa: F401 -from .empty_argument_error import EmptyArgumentError # noqa: F401 -from .nan_creation_error import NaNCreationError # noqa: F401 -from .parameter_not_found_error import ParameterNotFoundError # noqa: F401 -from .parameter_parsing_error import ParameterParsingError # noqa: F401 -from .period_mismatch_error import PeriodMismatchError # noqa: F401 -from .situation_parsing_error import SituationParsingError # noqa: F401 -from .spiral_error import SpiralError # noqa: F401 -from .variable_name_config_error import VariableNameConflictError # noqa: F401 -from .variable_not_found_error import VariableNotFoundError # noqa: F401 +from .cycle_error import CycleError +from .empty_argument_error import EmptyArgumentError +from .nan_creation_error import NaNCreationError +from .parameter_not_found_error import ParameterNotFoundError +from .parameter_parsing_error import ParameterParsingError +from .period_mismatch_error import PeriodMismatchError +from .situation_parsing_error import SituationParsingError +from .spiral_error import SpiralError +from .variable_name_config_error import VariableNameConflictError +from .variable_not_found_error import VariableNotFoundError + +__all__ = [ + "CycleError", + "EmptyArgumentError", + "NaNCreationError", + "ParameterNotFoundError", + "ParameterParsingError", + "PeriodMismatchError", + "SituationParsingError", + "SpiralError", + "VariableNameConflictError", + "VariableNotFoundError", +] diff --git a/openfisca_core/errors/cycle_error.py b/openfisca_core/errors/cycle_error.py index b4d44b5993..b81cc7b3f9 100644 --- a/openfisca_core/errors/cycle_error.py +++ b/openfisca_core/errors/cycle_error.py @@ -1,4 +1,2 @@ class CycleError(Exception): """Simulation error.""" - - pass diff --git a/openfisca_core/errors/empty_argument_error.py b/openfisca_core/errors/empty_argument_error.py index d3bcddbf9a..960d8d28c2 100644 --- a/openfisca_core/errors/empty_argument_error.py +++ b/openfisca_core/errors/empty_argument_error.py @@ -1,6 +1,7 @@ +import typing + import os import traceback -import typing import numpy @@ -11,12 +12,12 @@ class EmptyArgumentError(IndexError): message: str def __init__( - self, - class_name: str, - method_name: str, - arg_name: str, - arg_value: typing.Union[typing.List, numpy.ndarray] - ) -> None: + self, + class_name: str, + method_name: str, + arg_name: str, + arg_value: typing.Union[list, numpy.ndarray], + ) -> None: message = [ f"'{class_name}.{method_name}' can't be run with an empty '{arg_name}':\n", f">>> {arg_name}", @@ -31,7 +32,7 @@ def __init__( "- Mention us via https://twitter.com/openfisca", "- Drop us a line to contact@openfisca.org\n", "😃", - ] + ] stacktrace = os.linesep.join(traceback.format_stack()) self.message = os.linesep.join([f" {line}" for line in message]) self.message = os.linesep.join([stacktrace, self.message]) diff --git a/openfisca_core/errors/nan_creation_error.py b/openfisca_core/errors/nan_creation_error.py index dfd1b7af7e..373e391517 100644 --- a/openfisca_core/errors/nan_creation_error.py +++ b/openfisca_core/errors/nan_creation_error.py @@ -1,4 +1,2 @@ class NaNCreationError(Exception): """Simulation error.""" - - pass diff --git a/openfisca_core/errors/parameter_not_found_error.py b/openfisca_core/errors/parameter_not_found_error.py index 624ef490e6..bad33c89f4 100644 --- a/openfisca_core/errors/parameter_not_found_error.py +++ b/openfisca_core/errors/parameter_not_found_error.py @@ -1,21 +1,16 @@ class ParameterNotFoundError(AttributeError): - """ - Exception raised when a parameter is not found in the parameters. - """ + """Exception raised when a parameter is not found in the parameters.""" - def __init__(self, name, instant_str, variable_name = None): - """ - :param name: Name of the parameter + def __init__(self, name, instant_str, variable_name=None) -> None: + """:param name: Name of the parameter :param instant_str: Instant where the parameter does not exist, in the format `YYYY-MM-DD`. :param variable_name: If the parameter was queried during the computation of a variable, name of that variable. """ self.name = name self.instant_str = instant_str self.variable_name = variable_name - message = "The parameter '{}'".format(name) + message = f"The parameter '{name}'" if variable_name is not None: - message += " requested by variable '{}'".format(variable_name) - message += ( - " was not found in the {} tax and benefit system." - ).format(instant_str) - super(ParameterNotFoundError, self).__init__(message) + message += f" requested by variable '{variable_name}'" + message += f" was not found in the {instant_str} tax and benefit system." + super().__init__(message) diff --git a/openfisca_core/errors/parameter_parsing_error.py b/openfisca_core/errors/parameter_parsing_error.py index aa92124290..7628e42d86 100644 --- a/openfisca_core/errors/parameter_parsing_error.py +++ b/openfisca_core/errors/parameter_parsing_error.py @@ -2,21 +2,17 @@ class ParameterParsingError(Exception): - """ - Exception raised when a parameter cannot be parsed. - """ + """Exception raised when a parameter cannot be parsed.""" - def __init__(self, message, file = None, traceback = None): - """ - :param message: Error message + def __init__(self, message, file=None, traceback=None) -> None: + """:param message: Error message :param file: Parameter file which caused the error (optional) :param traceback: Traceback (optional) """ if file is not None: - message = os.linesep.join([ - "Error parsing parameter file '{}':".format(file), - message - ]) + message = os.linesep.join( + [f"Error parsing parameter file '{file}':", message], + ) if traceback is not None: message = os.linesep.join([traceback, message]) - super(ParameterParsingError, self).__init__(message) + super().__init__(message) diff --git a/openfisca_core/errors/period_mismatch_error.py b/openfisca_core/errors/period_mismatch_error.py index 0ba01abcd0..fcece9474d 100644 --- a/openfisca_core/errors/period_mismatch_error.py +++ b/openfisca_core/errors/period_mismatch_error.py @@ -1,9 +1,7 @@ class PeriodMismatchError(ValueError): - """ - Exception raised when one tries to set a variable value for a period that doesn't match its definition period - """ + """Exception raised when one tries to set a variable value for a period that doesn't match its definition period.""" - def __init__(self, variable_name, period, definition_period, message): + def __init__(self, variable_name: str, period, definition_period, message) -> None: self.variable_name = variable_name self.period = period self.definition_period = definition_period diff --git a/openfisca_core/errors/situation_parsing_error.py b/openfisca_core/errors/situation_parsing_error.py index f5c11e65cb..a5d7ee88d3 100644 --- a/openfisca_core/errors/situation_parsing_error.py +++ b/openfisca_core/errors/situation_parsing_error.py @@ -1,20 +1,27 @@ +from __future__ import annotations + +from collections.abc import Iterable + import os -import dpath +import dpath.util class SituationParsingError(Exception): - """ - Exception raised when the situation provided as an input for a simulation cannot be parsed - """ + """Exception raised when the situation provided as an input for a simulation cannot be parsed.""" - def __init__(self, path, message, code = None): + def __init__( + self, + path: Iterable[str], + message: str, + code: int | None = None, + ) -> None: self.error = {} - dpath_path = '/'.join([str(item) for item in path]) - message = str(message).strip(os.linesep).replace(os.linesep, ' ') + dpath_path = "/".join([str(item) for item in path]) + message = str(message).strip(os.linesep).replace(os.linesep, " ") dpath.util.new(self.error, dpath_path, message) self.code = code Exception.__init__(self, str(self.error)) - def __str__(self): + def __str__(self) -> str: return str(self.error) diff --git a/openfisca_core/errors/spiral_error.py b/openfisca_core/errors/spiral_error.py index 0495439b68..ffa7fe2850 100644 --- a/openfisca_core/errors/spiral_error.py +++ b/openfisca_core/errors/spiral_error.py @@ -1,4 +1,2 @@ class SpiralError(Exception): """Simulation error.""" - - pass diff --git a/openfisca_core/errors/variable_name_config_error.py b/openfisca_core/errors/variable_name_config_error.py index 7a87d7f5c8..fec1c45864 100644 --- a/openfisca_core/errors/variable_name_config_error.py +++ b/openfisca_core/errors/variable_name_config_error.py @@ -1,6 +1,2 @@ class VariableNameConflictError(Exception): - """ - Exception raised when two variables with the same name are added to a tax and benefit system. - """ - - pass + """Exception raised when two variables with the same name are added to a tax and benefit system.""" diff --git a/openfisca_core/errors/variable_not_found_error.py b/openfisca_core/errors/variable_not_found_error.py index f84ce06f95..46ece4b13c 100644 --- a/openfisca_core/errors/variable_not_found_error.py +++ b/openfisca_core/errors/variable_not_found_error.py @@ -2,29 +2,28 @@ class VariableNotFoundError(Exception): - """ - Exception raised when a variable has been queried but is not defined in the TaxBenefitSystem. - """ + """Exception raised when a variable has been queried but is not defined in the TaxBenefitSystem.""" - def __init__(self, variable_name, tax_benefit_system): - """ - :param variable_name: Name of the variable that was queried. + def __init__(self, variable_name: str, tax_benefit_system) -> None: + """:param variable_name: Name of the variable that was queried. :param tax_benefit_system: Tax benefits system that does not contain `variable_name` """ country_package_metadata = tax_benefit_system.get_package_metadata() - country_package_name = country_package_metadata['name'] - country_package_version = country_package_metadata['version'] + country_package_name = country_package_metadata["name"] + country_package_version = country_package_metadata["version"] if country_package_version: - country_package_id = '{}@{}'.format(country_package_name, country_package_version) + country_package_id = f"{country_package_name}@{country_package_version}" else: country_package_id = country_package_name - message = os.linesep.join([ - "You tried to calculate or to set a value for variable '{0}', but it was not found in the loaded tax and benefit system ({1}).".format(variable_name, country_package_id), - "Are you sure you spelled '{0}' correctly?".format(variable_name), - "If this code used to work and suddenly does not, this is most probably linked to an update of the tax and benefit system.", - "Look at its changelog to learn about renames and removals and update your code. If it is an official package,", - "it is probably available on .".format(country_package_name) - ]) + message = os.linesep.join( + [ + f"You tried to calculate or to set a value for variable '{variable_name}', but it was not found in the loaded tax and benefit system ({country_package_id}).", + f"Are you sure you spelled '{variable_name}' correctly?", + "If this code used to work and suddenly does not, this is most probably linked to an update of the tax and benefit system.", + "Look at its changelog to learn about renames and removals and update your code. If it is an official package,", + f"it is probably available on .", + ], + ) self.message = message self.variable_name = variable_name Exception.__init__(self, self.message) diff --git a/openfisca_core/experimental/__init__.py b/openfisca_core/experimental/__init__.py index 83faabe2bb..07114cdd27 100644 --- a/openfisca_core/experimental/__init__.py +++ b/openfisca_core/experimental/__init__.py @@ -1,24 +1,9 @@ -# Transitional imports to ensure non-breaking changes. -# Could be deprecated in the next major release. -# -# How imports are being used today: -# -# >>> from openfisca_core.module import symbol -# -# The previous example provokes cyclic dependency problems -# that prevent us from modularizing the different components -# of the library so to make them easier to test and to maintain. -# -# How could them be used after the next major release: -# -# >>> from openfisca_core import module -# >>> module.symbol() -# -# And for classes: -# -# >>> from openfisca_core.module import Symbol -# >>> Symbol() -# -# See: https://www.python.org/dev/peps/pep-0008/#imports +"""Experimental features of OpenFisca-Core.""" -from .memory_config import MemoryConfig # noqa: F401 +from ._errors import MemoryConfigWarning +from ._memory_config import MemoryConfig + +__all__ = [ + "MemoryConfig", + "MemoryConfigWarning", +] diff --git a/openfisca_core/experimental/_errors.py b/openfisca_core/experimental/_errors.py new file mode 100644 index 0000000000..6957e36c26 --- /dev/null +++ b/openfisca_core/experimental/_errors.py @@ -0,0 +1,5 @@ +class MemoryConfigWarning(UserWarning): + """Custom warning for MemoryConfig.""" + + +__all__ = ["MemoryConfigWarning"] diff --git a/openfisca_core/experimental/_memory_config.py b/openfisca_core/experimental/_memory_config.py new file mode 100644 index 0000000000..6fba790e90 --- /dev/null +++ b/openfisca_core/experimental/_memory_config.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from collections.abc import Iterable + +import warnings + +from ._errors import MemoryConfigWarning + + +class MemoryConfig: + """Experimental memory configuration.""" + + #: Maximum memory occupation allowed. + max_memory_occupation: float + + #: Priority variables. + priority_variables: frozenset[str] + + #: Variables to drop. + variables_to_drop: frozenset[str] + + def __init__( + self, + max_memory_occupation: str | float, + priority_variables: Iterable[str] = frozenset(), + variables_to_drop: Iterable[str] = frozenset(), + ) -> None: + message = [ + "Memory configuration is a feature that is still currently under " + "experimentation. You are very welcome to use it and send us " + "precious feedback, but keep in mind that the way it is used might " + "change without any major version bump.", + ] + warnings.warn(" ".join(message), MemoryConfigWarning, stacklevel=2) + + self.max_memory_occupation = float(max_memory_occupation) + if self.max_memory_occupation > 1: + msg = "max_memory_occupation must be <= 1" + raise ValueError(msg) + self.max_memory_occupation_pc = self.max_memory_occupation * 100 + self.priority_variables = frozenset(priority_variables) + self.variables_to_drop = frozenset(variables_to_drop) diff --git a/openfisca_core/experimental/memory_config.py b/openfisca_core/experimental/memory_config.py deleted file mode 100644 index 5f3b4a1126..0000000000 --- a/openfisca_core/experimental/memory_config.py +++ /dev/null @@ -1,24 +0,0 @@ -import warnings - -from openfisca_core.warnings import MemoryConfigWarning - - -class MemoryConfig: - - def __init__(self, - max_memory_occupation, - priority_variables = None, - variables_to_drop = None): - message = [ - "Memory configuration is a feature that is still currently under experimentation.", - "You are very welcome to use it and send us precious feedback,", - "but keep in mind that the way it is used might change without any major version bump." - ] - warnings.warn(" ".join(message), MemoryConfigWarning) - - self.max_memory_occupation = float(max_memory_occupation) - if self.max_memory_occupation > 1: - raise ValueError("max_memory_occupation must be <= 1") - self.max_memory_occupation_pc = self.max_memory_occupation * 100 - self.priority_variables = set(priority_variables) if priority_variables else set() - self.variables_to_drop = set(variables_to_drop) if variables_to_drop else set() diff --git a/openfisca_core/holders/__init__.py b/openfisca_core/holders/__init__.py index a7d46e38a6..a120a671b9 100644 --- a/openfisca_core/holders/__init__.py +++ b/openfisca_core/holders/__init__.py @@ -21,5 +21,13 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .helpers import set_input_dispatch_by_period, set_input_divide_by_period # noqa: F401 -from .holder import Holder # noqa: F401 +from . import types +from .helpers import set_input_dispatch_by_period, set_input_divide_by_period +from .holder import Holder + +__all__ = [ + "Holder", + "set_input_dispatch_by_period", + "set_input_divide_by_period", + "types", +] diff --git a/openfisca_core/holders/helpers.py b/openfisca_core/holders/helpers.py index efe16388e0..fcc6563c79 100644 --- a/openfisca_core/holders/helpers.py +++ b/openfisca_core/holders/helpers.py @@ -7,9 +7,8 @@ log = logging.getLogger(__name__) -def set_input_dispatch_by_period(holder, period, array): - """ - This function can be declared as a ``set_input`` attribute of a variable. +def set_input_dispatch_by_period(holder, period, array) -> None: + """This function can be declared as a ``set_input`` attribute of a variable. In this case, the variable will accept inputs on larger periods that its definition period, and the value for the larger period will be applied to all its subperiods. @@ -20,17 +19,19 @@ def set_input_dispatch_by_period(holder, period, array): period_size = period.size period_unit = period.unit - if holder.variable.definition_period == periods.MONTH: - cached_period_unit = periods.MONTH - elif holder.variable.definition_period == periods.YEAR: - cached_period_unit = periods.YEAR - else: - raise ValueError('set_input_dispatch_by_period can be used only for yearly or monthly variables.') + if holder.variable.definition_period not in ( + periods.DateUnit.isoformat + periods.DateUnit.isocalendar + ): + msg = "set_input_dispatch_by_period can't be used for eternal variables." + raise ValueError( + msg, + ) + cached_period_unit = holder.variable.definition_period after_instant = period.start.offset(period_size, period_unit) # Cache the input data, skipping the existing cached months - sub_period = period.start.period(cached_period_unit) + sub_period = periods.Period((cached_period_unit, period.start, 1)) while sub_period.start < after_instant: existing_array = holder.get_array(sub_period) if existing_array is None: @@ -42,9 +43,8 @@ def set_input_dispatch_by_period(holder, period, array): sub_period = sub_period.offset(1) -def set_input_divide_by_period(holder, period, array): - """ - This function can be declared as a ``set_input`` attribute of a variable. +def set_input_divide_by_period(holder, period, array) -> None: + """This function can be declared as a ``set_input`` attribute of a variable. In this case, the variable will accept inputs on larger periods that its definition period, and the value for the larger period will be divided between its subperiods. @@ -55,18 +55,20 @@ def set_input_divide_by_period(holder, period, array): period_size = period.size period_unit = period.unit - if holder.variable.definition_period == periods.MONTH: - cached_period_unit = periods.MONTH - elif holder.variable.definition_period == periods.YEAR: - cached_period_unit = periods.YEAR - else: - raise ValueError('set_input_divide_by_period can be used only for yearly or monthly variables.') + if holder.variable.definition_period not in ( + periods.DateUnit.isoformat + periods.DateUnit.isocalendar + ): + msg = "set_input_divide_by_period can't be used for eternal variables." + raise ValueError( + msg, + ) + cached_period_unit = holder.variable.definition_period after_instant = period.start.offset(period_size, period_unit) # Count the number of elementary periods to change, and the difference with what is already known. remaining_array = array.copy() - sub_period = period.start.period(cached_period_unit) + sub_period = periods.Period((cached_period_unit, period.start, 1)) sub_periods_count = 0 while sub_period.start < after_instant: existing_array = holder.get_array(sub_period) @@ -79,10 +81,13 @@ def set_input_divide_by_period(holder, period, array): # Cache the input data if sub_periods_count > 0: divided_array = remaining_array / sub_periods_count - sub_period = period.start.period(cached_period_unit) + sub_period = periods.Period((cached_period_unit, period.start, 1)) while sub_period.start < after_instant: if holder.get_array(sub_period) is None: holder._set(sub_period, divided_array) sub_period = sub_period.offset(1) elif not (remaining_array == 0).all(): - raise ValueError("Inconsistent input: variable {0} has already been set for all months contained in period {1}, and value {2} provided for {1} doesn't match the total ({3}). This error may also be thrown if you try to call set_input twice for the same variable and period.".format(holder.variable.name, period, array, array - remaining_array)) + msg = f"Inconsistent input: variable {holder.variable.name} has already been set for all months contained in period {period}, and value {array} provided for {period} doesn't match the total ({array - remaining_array}). This error may also be thrown if you try to call set_input twice for the same variable and period." + raise ValueError( + msg, + ) diff --git a/openfisca_core/holders/holder.py b/openfisca_core/holders/holder.py index 3d0379d22d..f60d92f70b 100644 --- a/openfisca_core/holders/holder.py +++ b/openfisca_core/holders/holder.py @@ -1,79 +1,87 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + import os import warnings import numpy import psutil -from openfisca_core import commons, periods, tools -from openfisca_core.errors import PeriodMismatchError -from openfisca_core.data_storage import InMemoryStorage, OnDiskStorage -from openfisca_core.indexed_enums import Enum +from openfisca_core import ( + commons, + data_storage as storage, + errors, + indexed_enums as enums, + periods, + types, +) + +from . import types as t class Holder: - """ - A holder keeps tracks of a variable values after they have been calculated, or set as an input. - """ + """A holder keeps tracks of a variable values after they have been calculated, or set as an input.""" - def __init__(self, variable, population): + def __init__(self, variable, population) -> None: self.population = population self.variable = variable self.simulation = population.simulation - self._memory_storage = InMemoryStorage(is_eternal = (self.variable.definition_period == periods.ETERNITY)) + self._eternal = self.variable.definition_period == periods.DateUnit.ETERNITY + self._memory_storage = storage.InMemoryStorage(is_eternal=self._eternal) # By default, do not activate on-disk storage, or variable dropping self._disk_storage = None self._on_disk_storable = False self._do_not_store = False if self.simulation and self.simulation.memory_config: - if self.variable.name not in self.simulation.memory_config.priority_variables: + if ( + self.variable.name + not in self.simulation.memory_config.priority_variables + ): self._disk_storage = self.create_disk_storage() self._on_disk_storable = True if self.variable.name in self.simulation.memory_config.variables_to_drop: self._do_not_store = True - def clone(self, population): - """ - Copy the holder just enough to be able to run a new simulation without modifying the original simulation. - """ + def clone(self, population: t.CorePopulation) -> t.Holder: + """Copy the holder just enough to be able to run a new simulation without modifying the original simulation.""" new = commons.empty_clone(self) new_dict = new.__dict__ for key, value in self.__dict__.items(): - if key not in ('population', 'formula', 'simulation'): + if key not in ("population", "formula", "simulation"): new_dict[key] = value - new_dict['population'] = population - new_dict['simulation'] = population.simulation + new_dict["population"] = population + new_dict["simulation"] = population.simulation return new - def create_disk_storage(self, directory = None, preserve = False): + def create_disk_storage(self, directory=None, preserve=False): if directory is None: directory = self.simulation.data_storage_dir storage_dir = os.path.join(directory, self.variable.name) if not os.path.isdir(storage_dir): os.mkdir(storage_dir) - return OnDiskStorage( + return storage.OnDiskStorage( storage_dir, - is_eternal = (self.variable.definition_period == periods.ETERNITY), - preserve_storage_dir = preserve - ) + self._eternal, + preserve_storage_dir=preserve, + ) - def delete_arrays(self, period = None): - """ - If ``period`` is ``None``, remove all known values of the variable. + def delete_arrays(self, period=None) -> None: + """If ``period`` is ``None``, remove all known values of the variable. If ``period`` is not ``None``, only remove all values for any period included in period (e.g. if period is "2017", values for "2017-01", "2017-07", etc. would be removed) """ - self._memory_storage.delete(period) if self._disk_storage: self._disk_storage.delete(period) def get_array(self, period): - """ - Get the value of the variable for the given period. + """Get the value of the variable for the given period. If the value is not known, return ``None``. """ @@ -84,92 +92,149 @@ def get_array(self, period): return value if self._disk_storage: return self._disk_storage.get(period) + return None - def get_memory_usage(self): - """ - Get data about the virtual memory usage of the holder. - - :returns: Memory usage data - :rtype: dict - - Example: - - >>> holder.get_memory_usage() - >>> { - >>> 'nb_arrays': 12, # The holder contains the variable values for 12 different periods - >>> 'nb_cells_by_array': 100, # There are 100 entities (e.g. persons) in our simulation - >>> 'cell_size': 8, # Each value takes 8B of memory - >>> 'dtype': dtype('float64') # Each value is a float 64 - >>> 'total_nb_bytes': 10400 # The holder uses 10.4kB of virtual memory - >>> 'nb_requests': 24 # The variable has been computed 24 times - >>> 'nb_requests_by_array': 2 # Each array stored has been on average requested twice - >>> } - """ + def get_memory_usage(self) -> t.MemoryUsage: + """Get data about the virtual memory usage of the Holder. - usage = dict( - nb_cells_by_array = self.population.count, - dtype = self.variable.dtype, - ) + Returns: + Memory usage data. + + Examples: + >>> from pprint import pprint + + >>> from openfisca_core import ( + ... entities, + ... populations, + ... simulations, + ... taxbenefitsystems, + ... variables, + ... ) + + >>> entity = entities.Entity("", "", "", "") + + >>> class MyVariable(variables.Variable): + ... definition_period = periods.DateUnit.YEAR + ... entity = entity + ... value_type = int + + >>> population = populations.Population(entity) + >>> variable = MyVariable() + >>> holder = Holder(variable, population) + + >>> tbs = taxbenefitsystems.TaxBenefitSystem([entity]) + >>> entities = {entity.key: population} + >>> simulation = simulations.Simulation(tbs, entities) + >>> holder.simulation = simulation + + >>> pprint(holder.get_memory_usage(), indent=3) + { 'cell_size': nan, + 'dtype': , + 'nb_arrays': 0, + 'nb_cells_by_array': 0, + 'total_nb_bytes': 0... + + """ + usage = t.MemoryUsage( + nb_cells_by_array=self.population.count, + dtype=self.variable.dtype, + ) usage.update(self._memory_storage.get_memory_usage()) if self.simulation.trace: nb_requests = self.simulation.tracer.get_nb_requests(self.variable.name) - usage.update(dict( - nb_requests = nb_requests, - nb_requests_by_array = nb_requests / float(usage['nb_arrays']) if usage['nb_arrays'] > 0 else numpy.nan - )) + usage.update( + { + "nb_requests": nb_requests, + "nb_requests_by_array": ( + nb_requests / float(usage["nb_arrays"]) + if usage["nb_arrays"] > 0 + else numpy.nan + ), + }, + ) return usage def get_known_periods(self): - """ - Get the list of periods the variable value is known for. - """ + """Get the list of periods the variable value is known for.""" + return list(self._memory_storage.get_known_periods()) + list( + self._disk_storage.get_known_periods() if self._disk_storage else [], + ) - return list(self._memory_storage.get_known_periods()) + list(( - self._disk_storage.get_known_periods() if self._disk_storage else [])) + def set_input( + self, + period: types.Period, + array: numpy.ndarray | Sequence[Any], + ) -> numpy.ndarray | None: + """Set a Variable's array of values of a given Period. - def set_input(self, period, array): - """ - Set a variable's value (``array``) for a given period (``period``) + Args: + period: The period at which the value is set. + array: The input value for the variable. - :param array: the input value for the variable - :param period: the period at which the value is setted + Returns: + The set input array. - Example : + Note: + If a ``set_input`` property has been set for the variable, this + method may accept inputs for periods not matching the + ``definition_period`` of the Variable. To read + more about this, check the `documentation`_. - >>> holder.set_input([12, 14], '2018-04') - >>> holder.get_array('2018-04') - >>> [12, 14] + Examples: + >>> from openfisca_core import entities, populations, variables + >>> entity = entities.Entity("", "", "", "") - If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation `_. - """ + >>> class MyVariable(variables.Variable): + ... definition_period = periods.DateUnit.YEAR + ... entity = entity + ... value_type = float + + >>> variable = MyVariable() + >>> population = populations.Population(entity) + >>> population.count = 2 + + >>> holder = Holder(variable, population) + >>> holder.set_input("2018", numpy.array([12.5, 14])) + >>> holder.get_array("2018") + array([12.5, 14. ], dtype=float32) + + >>> holder.set_input("2018", [12.5, 14]) + >>> holder.get_array("2018") + array([12.5, 14. ], dtype=float32) + + .. _documentation: + https://openfisca.org/doc/coding-the-legislation/35_periods.html#set-input-automatically-process-variable-inputs-defined-for-periods-not-matching-the-definition-period + + """ period = periods.period(period) - if period.unit == periods.ETERNITY and self.variable.definition_period != periods.ETERNITY: - error_message = os.linesep.join([ - 'Unable to set a value for variable {0} for periods.ETERNITY.', - '{0} is only defined for {1}s. Please adapt your input.', - ]).format( - self.variable.name, - self.variable.definition_period - ) - raise PeriodMismatchError( + + if period.unit == periods.DateUnit.ETERNITY and not self._eternal: + error_message = os.linesep.join( + [ + "Unable to set a value for variable {1} for {0}.", + "{1} is only defined for {2}s. Please adapt your input.", + ], + ).format( + periods.DateUnit.ETERNITY.upper(), + self.variable.name, + self.variable.definition_period, + ) + raise errors.PeriodMismatchError( self.variable.name, period, self.variable.definition_period, - error_message - ) + error_message, + ) if self.variable.is_neutralized: - warning_message = "You cannot set a value for the variable {}, as it has been neutralized. The value you provided ({}) will be ignored.".format(self.variable.name, array) - return warnings.warn( - warning_message, - Warning - ) + warning_message = f"You cannot set a value for the variable {self.variable.name}, as it has been neutralized. The value you provided ({array}) will be ignored." + return warnings.warn(warning_message, Warning, stacklevel=2) if self.variable.value_type in (float, int) and isinstance(array, str): - array = tools.eval_expression(array) + array = commons.eval_expression(array) if self.variable.set_input: return self.variable.set_input(self, period, array) return self._set(period, array) @@ -181,66 +246,80 @@ def _to_array(self, value): # 0-dim arrays are casted to scalar when they interact with float. We don't want that. value = value.reshape(1) if len(value) != self.population.count: + msg = f'Unable to set value "{value}" for variable "{self.variable.name}", as its length is {len(value)} while there are {self.population.count} {self.population.entity.plural} in the simulation.' raise ValueError( - 'Unable to set value "{}" for variable "{}", as its length is {} while there are {} {} in the simulation.' - .format(value, self.variable.name, len(value), self.population.count, self.population.entity.plural)) - if self.variable.value_type == Enum: + msg, + ) + if self.variable.value_type == enums.Enum: value = self.variable.possible_values.encode(value) if value.dtype != self.variable.dtype: try: value = value.astype(self.variable.dtype) except ValueError: + msg = f'Unable to set value "{value}" for variable "{self.variable.name}", as the variable dtype "{self.variable.dtype}" does not match the value dtype "{value.dtype}".' raise ValueError( - 'Unable to set value "{}" for variable "{}", as the variable dtype "{}" does not match the value dtype "{}".' - .format(value, self.variable.name, self.variable.dtype, value.dtype)) + msg, + ) return value - def _set(self, period, value): + def _set(self, period, value) -> None: value = self._to_array(value) - if self.variable.definition_period != periods.ETERNITY: + if not self._eternal: if period is None: - raise ValueError('A period must be specified to set values, except for variables with periods.ETERNITY as as period_definition.') - if (self.variable.definition_period != period.unit or period.size > 1): + msg = ( + f"A period must be specified to set values, except for variables with " + f"{periods.DateUnit.ETERNITY.upper()} as as period_definition." + ) + raise ValueError( + msg, + ) + if self.variable.definition_period != period.unit or period.size > 1: name = self.variable.name - period_size_adj = f'{period.unit}' if (period.size == 1) else f'{period.size}-{period.unit}s' - error_message = os.linesep.join([ - f'Unable to set a value for variable "{name}" for {period_size_adj}-long period "{period}".', - f'"{name}" can only be set for one {self.variable.definition_period} at a time. Please adapt your input.', - f'If you are the maintainer of "{name}", you can consider adding it a set_input attribute to enable automatic period casting.' - ]) - - raise PeriodMismatchError( + period_size_adj = ( + f"{period.unit}" + if (period.size == 1) + else f"{period.size}-{period.unit}s" + ) + error_message = os.linesep.join( + [ + f'Unable to set a value for variable "{name}" for {period_size_adj}-long period "{period}".', + f'"{name}" can only be set for one {self.variable.definition_period} at a time. Please adapt your input.', + f'If you are the maintainer of "{name}", you can consider adding it a set_input attribute to enable automatic period casting.', + ], + ) + + raise errors.PeriodMismatchError( self.variable.name, period, self.variable.definition_period, - error_message - ) + error_message, + ) should_store_on_disk = ( - self._on_disk_storable and - self._memory_storage.get(period) is None and # If there is already a value in memory, replace it and don't put a new value in the disk storage - psutil.virtual_memory().percent >= self.simulation.memory_config.max_memory_occupation_pc - ) + self._on_disk_storable + and self._memory_storage.get(period) is None + and psutil.virtual_memory().percent # If there is already a value in memory, replace it and don't put a new value in the disk storage + >= self.simulation.memory_config.max_memory_occupation_pc + ) if should_store_on_disk: self._disk_storage.put(value, period) else: self._memory_storage.put(value, period) - def put_in_cache(self, value, period): + def put_in_cache(self, value, period) -> None: if self._do_not_store: return - if (self.simulation.opt_out_cache and - self.simulation.tax_benefit_system.cache_blacklist and - self.variable.name in self.simulation.tax_benefit_system.cache_blacklist): + if ( + self.simulation.opt_out_cache + and self.simulation.tax_benefit_system.cache_blacklist + and self.variable.name in self.simulation.tax_benefit_system.cache_blacklist + ): return self._set(period, value) def default_array(self): - """ - Return a new array of the appropriate length for the entity, filled with the variable default values. - """ - + """Return a new array of the appropriate length for the entity, filled with the variable default values.""" return self.variable.default_array(self.population.count) diff --git a/openfisca_core/holders/tests/__init__.py b/openfisca_core/holders/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfisca_core/holders/tests/test_helpers.py b/openfisca_core/holders/tests/test_helpers.py new file mode 100644 index 0000000000..948f25288f --- /dev/null +++ b/openfisca_core/holders/tests/test_helpers.py @@ -0,0 +1,134 @@ +import pytest + +from openfisca_core import holders, tools +from openfisca_core.entities import Entity +from openfisca_core.holders import Holder +from openfisca_core.periods import DateUnit, Instant, Period +from openfisca_core.populations import Population +from openfisca_core.variables import Variable + + +@pytest.fixture +def people(): + return Entity( + key="person", + plural="people", + label="An individual member of a larger group.", + doc="People have the particularity of not being someone else.", + ) + + +@pytest.fixture +def Income(people): + return type( + "Income", + (Variable,), + {"value_type": float, "entity": people}, + ) + + +@pytest.fixture +def population(people): + population = Population(people) + population.count = 1 + return population + + +@pytest.mark.parametrize( + ("dispatch_unit", "definition_unit", "values", "expected"), + [ + (DateUnit.YEAR, DateUnit.YEAR, [1.0], [3.0]), + (DateUnit.YEAR, DateUnit.MONTH, [1.0], [36.0]), + (DateUnit.YEAR, DateUnit.DAY, [1.0], [1096.0]), + (DateUnit.YEAR, DateUnit.WEEK, [1.0], [157.0]), + (DateUnit.YEAR, DateUnit.WEEKDAY, [1.0], [1096.0]), + (DateUnit.MONTH, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.MONTH, DateUnit.MONTH, [1.0], [3.0]), + (DateUnit.MONTH, DateUnit.DAY, [1.0], [90.0]), + (DateUnit.MONTH, DateUnit.WEEK, [1.0], [13.0]), + (DateUnit.MONTH, DateUnit.WEEKDAY, [1.0], [90.0]), + (DateUnit.DAY, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.DAY, [1.0], [3.0]), + (DateUnit.DAY, DateUnit.WEEK, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.WEEKDAY, [1.0], [3.0]), + (DateUnit.WEEK, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.WEEK, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.WEEK, DateUnit.DAY, [1.0], [21.0]), + (DateUnit.WEEK, DateUnit.WEEK, [1.0], [3.0]), + (DateUnit.WEEK, DateUnit.WEEKDAY, [1.0], [21.0]), + (DateUnit.WEEKDAY, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.DAY, [1.0], [3.0]), + (DateUnit.WEEKDAY, DateUnit.WEEK, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.WEEKDAY, [1.0], [3.0]), + ], +) +def test_set_input_dispatch_by_period( + Income, + population, + dispatch_unit, + definition_unit, + values, + expected, +) -> None: + Income.definition_period = definition_unit + income = Income() + holder = Holder(income, population) + instant = Instant((2022, 1, 1)) + dispatch_period = Period((dispatch_unit, instant, 3)) + + holders.set_input_dispatch_by_period(holder, dispatch_period, values) + total = sum(map(holder.get_array, holder.get_known_periods())) + + tools.assert_near(total, expected, absolute_error_margin=0.001) + + +@pytest.mark.parametrize( + ("divide_unit", "definition_unit", "values", "expected"), + [ + (DateUnit.YEAR, DateUnit.YEAR, [3.0], [1.0]), + (DateUnit.YEAR, DateUnit.MONTH, [36.0], [1.0]), + (DateUnit.YEAR, DateUnit.DAY, [1095.0], [1.0]), + (DateUnit.YEAR, DateUnit.WEEK, [157.0], [1.0]), + (DateUnit.YEAR, DateUnit.WEEKDAY, [1095.0], [1.0]), + (DateUnit.MONTH, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.MONTH, DateUnit.MONTH, [3.0], [1.0]), + (DateUnit.MONTH, DateUnit.DAY, [90.0], [1.0]), + (DateUnit.MONTH, DateUnit.WEEK, [13.0], [1.0]), + (DateUnit.MONTH, DateUnit.WEEKDAY, [90.0], [1.0]), + (DateUnit.DAY, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.DAY, [3.0], [1.0]), + (DateUnit.DAY, DateUnit.WEEK, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.WEEKDAY, [3.0], [1.0]), + (DateUnit.WEEK, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.WEEK, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.WEEK, DateUnit.DAY, [21.0], [1.0]), + (DateUnit.WEEK, DateUnit.WEEK, [3.0], [1.0]), + (DateUnit.WEEK, DateUnit.WEEKDAY, [21.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.DAY, [3.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.WEEK, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.WEEKDAY, [3.0], [1.0]), + ], +) +def test_set_input_divide_by_period( + Income, + population, + divide_unit, + definition_unit, + values, + expected, +) -> None: + Income.definition_period = definition_unit + income = Income() + holder = Holder(income, population) + instant = Instant((2022, 1, 1)) + divide_period = Period((divide_unit, instant, 3)) + + holders.set_input_divide_by_period(holder, divide_period, values) + last = holder.get_array(holder.get_known_periods()[-1]) + + tools.assert_near(last, expected, absolute_error_margin=0.001) diff --git a/openfisca_core/holders/types.py b/openfisca_core/holders/types.py new file mode 100644 index 0000000000..7137b86483 --- /dev/null +++ b/openfisca_core/holders/types.py @@ -0,0 +1,3 @@ +from openfisca_core.types import CorePopulation, Holder, MemoryUsage + +__all__ = ["CorePopulation", "Holder", "MemoryUsage"] diff --git a/openfisca_core/indexed_enums/__init__.py b/openfisca_core/indexed_enums/__init__.py index 6a18aa4809..494601fc8d 100644 --- a/openfisca_core/indexed_enums/__init__.py +++ b/openfisca_core/indexed_enums/__init__.py @@ -1,26 +1,18 @@ -# Transitional imports to ensure non-breaking changes. -# Could be deprecated in the next major release. -# -# How imports are being used today: -# -# >>> from openfisca_core.module import symbol -# -# The previous example provokes cyclic dependency problems -# that prevent us from modularizing the different components -# of the library so to make them easier to test and to maintain. -# -# How could them be used after the next major release: -# -# >>> from openfisca_core import module -# >>> module.symbol() -# -# And for classes: -# -# >>> from openfisca_core.module import Symbol -# >>> Symbol() -# -# See: https://www.python.org/dev/peps/pep-0008/#imports +"""Enumerations for variables with a limited set of possible values.""" -from .config import ENUM_ARRAY_DTYPE # noqa: F401 -from .enum_array import EnumArray # noqa: F401 -from .enum import Enum # noqa: F401 +from . import types +from ._enum_type import EnumType +from ._errors import EnumEncodingError, EnumMemberNotFoundError +from .config import ENUM_ARRAY_DTYPE +from .enum import Enum +from .enum_array import EnumArray + +__all__ = [ + "ENUM_ARRAY_DTYPE", + "Enum", + "EnumArray", + "EnumEncodingError", + "EnumMemberNotFoundError", + "EnumType", + "types", +] diff --git a/openfisca_core/indexed_enums/_enum_type.py b/openfisca_core/indexed_enums/_enum_type.py new file mode 100644 index 0000000000..8083a6d49f --- /dev/null +++ b/openfisca_core/indexed_enums/_enum_type.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from typing import final + +import numpy + +from . import types as t + + +@final +class EnumType(t.EnumType): + """Meta class for creating an indexed :class:`.Enum`. + + Examples: + >>> from openfisca_core import indexed_enums as enum + + >>> class Enum(enum.Enum, metaclass=enum.EnumType): + ... pass + + >>> Enum.items + Traceback (most recent call last): + AttributeError: ... + + >>> class Housing(Enum): + ... OWNER = "Owner" + ... TENANT = "Tenant" + + >>> Housing.indices + array([0, 1], dtype=uint8) + + >>> Housing.names + array(['OWNER', 'TENANT'], dtype='>> Housing.enums + array([Housing.OWNER, Housing.TENANT], dtype=object) + + """ + + def __new__( + metacls, + name: str, + bases: tuple[type, ...], + classdict: t.EnumDict, + **kwds: object, + ) -> t.EnumType: + """Create a new indexed enum class.""" + # Create the enum class. + cls = super().__new__(metacls, name, bases, classdict, **kwds) + + # If the enum class has no members, return it as is. + if not cls.__members__: + return cls + + # Add the indices attribute to the enum class. + cls.indices = numpy.arange(len(cls), dtype=t.EnumDType) + + # Add the names attribute to the enum class. + cls.names = numpy.array(cls._member_names_, dtype=t.StrDType) + + # Add the enums attribute to the enum class. + cls.enums = numpy.array(cls, dtype=t.ObjDType) + + # Return the modified enum class. + return cls + + def __dir__(cls) -> list[str]: + return sorted({"indices", "names", "enums", *super().__dir__()}) + + +__all__ = ["EnumType"] diff --git a/openfisca_core/indexed_enums/_errors.py b/openfisca_core/indexed_enums/_errors.py new file mode 100644 index 0000000000..e9b543fc73 --- /dev/null +++ b/openfisca_core/indexed_enums/_errors.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from . import types as t + + +class EnumEncodingError(TypeError): + """Raised when an enum is encoded with an unsupported type.""" + + def __init__( + self, enum_class: type[t.Enum], value: t.VarArray | t.ArrayLike[object] + ) -> None: + msg = ( + f"Failed to encode \"{value}\" of type '{value[0].__class__.__name__}', " + "as it is not supported. Please, try again with an array of " + f"'{int.__name__}', '{str.__name__}', or '{enum_class.__name__}'." + ) + super().__init__(msg) + + +class EnumMemberNotFoundError(IndexError): + """Raised when a member is not found in an enum.""" + + def __init__(self, enum_class: type[t.Enum]) -> None: + index = [str(enum.index) for enum in enum_class] + names = [enum.name for enum in enum_class] + msg = ( + f"Some members were not found in enum '{enum_class.__name__}'. " + f"Possible values are: {', '.join(names[:-1])}, and {names[-1]!s}; " + f"or their corresponding indices: {', '.join(index[:-1])}, and " + f"{index[-1]}." + ) + super().__init__(msg) + + +__all__ = ["EnumEncodingError", "EnumMemberNotFoundError"] diff --git a/openfisca_core/indexed_enums/_guards.py b/openfisca_core/indexed_enums/_guards.py new file mode 100644 index 0000000000..6c47471b3e --- /dev/null +++ b/openfisca_core/indexed_enums/_guards.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from typing import Final +from typing_extensions import TypeIs + +import numpy + +from . import types as t + +#: Types for int arrays. +ints: Final = { + numpy.uint8, + numpy.uint16, + numpy.uint32, + numpy.uint64, + numpy.int8, + numpy.int16, + numpy.int32, + numpy.int64, +} + +#: Types for object arrays. +objs: Final = {numpy.object_} + +#: Types for str arrays. +strs: Final = {numpy.str_} + + +def _is_enum_array(array: t.VarArray) -> TypeIs[t.ObjArray]: + """Narrow the type of a given array to an array of :obj:`numpy.object_`. + + Args: + array: Array to check. + + Returns: + bool: True if ``array`` is an array of :obj:`numpy.object_`, False otherwise. + + Examples: + >>> import numpy + + >>> from openfisca_core import indexed_enums as enum + + >>> Enum = enum.Enum("Enum", ["A", "B"]) + >>> array = numpy.array([Enum.A], dtype=numpy.object_) + >>> _is_enum_array(array) + True + + >>> array = numpy.array([1.0]) + >>> _is_enum_array(array) + False + + """ + return array.dtype.type in objs + + +def _is_enum_array_like( + array: t.VarArray | t.ArrayLike[object], +) -> TypeIs[t.ArrayLike[t.Enum]]: + """Narrow the type of a given array-like to an sequence of :class:`.Enum`. + + Args: + array: Array to check. + + Returns: + bool: True if ``array`` is an array-like of :class:`.Enum`, False otherwise. + + Examples: + >>> from openfisca_core import indexed_enums as enum + + >>> class Housing(enum.Enum): + ... OWNER = "owner" + ... TENANT = "tenant" + + >>> array = [Housing.OWNER] + >>> _is_enum_array_like(array) + True + + >>> array = ["owner"] + >>> _is_enum_array_like(array) + False + + """ + return all(isinstance(item, t.Enum) for item in array) + + +def _is_int_array(array: t.VarArray) -> TypeIs[t.IndexArray]: + """Narrow the type of a given array to an array of :obj:`numpy.integer`. + + Args: + array: Array to check. + + Returns: + bool: True if ``array`` is an array of :obj:`numpy.integer`, False otherwise. + + Examples: + >>> import numpy + + >>> array = numpy.array([1], dtype=numpy.int16) + >>> _is_int_array(array) + True + + >>> array = numpy.array([1], dtype=numpy.int32) + >>> _is_int_array(array) + True + + >>> array = numpy.array([1.0]) + >>> _is_int_array(array) + False + + """ + return array.dtype.type in ints + + +def _is_int_array_like( + array: t.VarArray | t.ArrayLike[object], +) -> TypeIs[t.ArrayLike[int]]: + """Narrow the type of a given array-like to a sequence of :obj:`int`. + + Args: + array: Array to check. + + Returns: + bool: True if ``array`` is an array-like of :obj:`int`, False otherwise. + + Examples: + >>> array = [1] + >>> _is_int_array_like(array) + True + + >>> array = (1, 2) + >>> _is_int_array_like(array) + True + + >>> array = [1.0] + >>> _is_int_array_like(array) + False + + """ + return all(isinstance(item, int) for item in array) + + +def _is_str_array(array: t.VarArray) -> TypeIs[t.StrArray]: + """Narrow the type of a given array to an array of :obj:`numpy.str_`. + + Args: + array: Array to check. + + Returns: + bool: True if ``array`` is an array of :obj:`numpy.str_`, False otherwise. + + Examples: + >>> import numpy + + >>> from openfisca_core import indexed_enums as enum + + >>> class Housing(enum.Enum): + ... OWNER = "owner" + ... TENANT = "tenant" + + >>> array = numpy.array([Housing.OWNER]) + >>> _is_str_array(array) + False + + >>> array = numpy.array(["owner"]) + >>> _is_str_array(array) + True + + """ + return array.dtype.type in strs + + +def _is_str_array_like( + array: t.VarArray | t.ArrayLike[object], +) -> TypeIs[t.ArrayLike[str]]: + """Narrow the type of a given array-like to an sequence of :obj:`str`. + + Args: + array: Array to check. + + Returns: + bool: True if ``array`` is an array-like of :obj:`str`, False otherwise. + + Examples: + >>> from openfisca_core import indexed_enums as enum + + >>> class Housing(enum.Enum): + ... OWNER = "owner" + ... TENANT = "tenant" + + >>> array = [Housing.OWNER] + >>> _is_str_array_like(array) + False + + >>> array = ["owner"] + >>> _is_str_array_like(array) + True + + """ + return all(isinstance(item, str) for item in array) + + +__all__ = [ + "_is_enum_array", + "_is_enum_array_like", + "_is_int_array", + "_is_int_array_like", + "_is_str_array", + "_is_str_array_like", +] diff --git a/openfisca_core/indexed_enums/_utils.py b/openfisca_core/indexed_enums/_utils.py new file mode 100644 index 0000000000..aa676b92f7 --- /dev/null +++ b/openfisca_core/indexed_enums/_utils.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import numpy + +from . import types as t + + +def _enum_to_index(value: t.ObjArray | t.ArrayLike[t.Enum]) -> t.IndexArray: + """Transform an array of enum members into an index array. + + Args: + value: The enum members array to encode. + + Returns: + The index array. + + Examples: + >>> import numpy + + >>> from openfisca_core import indexed_enums as enum + + >>> class Road(enum.Enum): + ... STREET = ( + ... "A public road that connects two points, but also has " + ... "buildings on both sides of it; these typically run " + ... "perpendicular to avenues." + ... ) + ... AVENUE = ( + ... "A public way that also has buildings and/or trees on both " + ... "sides; these run perpendicular to streets and are " + ... "traditionally wider." + ... ) + + >>> class Rogue(enum.Enum): + ... BOULEVARD = "More like a shady impasse, to be honest." + + >>> _enum_to_index(Road.AVENUE) + Traceback (most recent call last): + TypeError: 'Road' object is not iterable + + >>> _enum_to_index([Road.AVENUE]) + array([1], dtype=uint8) + + >>> _enum_to_index(numpy.array(Road.AVENUE)) + Traceback (most recent call last): + TypeError: iteration over a 0-d array + + >>> _enum_to_index(numpy.array([Road.AVENUE])) + array([1], dtype=uint8) + + >>> value = numpy.array([Road.STREET, Road.AVENUE, Road.STREET]) + >>> _enum_to_index(value) + array([0, 1, 0], dtype=uint8) + + >>> value = numpy.array([Road.AVENUE, Road.AVENUE, Rogue.BOULEVARD]) + >>> _enum_to_index(value) + array([1, 1, 0], dtype=uint8) + + """ + return numpy.array([enum.index for enum in value], t.EnumDType) + + +def _int_to_index( + enum_class: type[t.Enum], value: t.IndexArray | t.ArrayLike[int] +) -> t.IndexArray: + """Transform an integer array into an index array. + + Args: + enum_class: The enum class to encode the integer array. + value: The integer array to encode. + + Returns: + The index array. + + Examples: + >>> from array import array + + >>> import numpy + + >>> from openfisca_core import indexed_enums as enum + + >>> class Road(enum.Enum): + ... STREET = ( + ... "A public road that connects two points, but also has " + ... "buildings on both sides of it; these typically run " + ... "perpendicular to avenues." + ... ) + ... AVENUE = ( + ... "A public way that also has buildings and/or trees on both " + ... "sides; these run perpendicular to streets and are " + ... "traditionally wider." + ... ) + + >>> _int_to_index(Road, 1) + Traceback (most recent call last): + TypeError: 'int' object is not iterable + + >>> _int_to_index(Road, [1]) + array([1], dtype=uint8) + + >>> _int_to_index(Road, array("B", [1])) + array([1], dtype=uint8) + + >>> _int_to_index(Road, memoryview(array("B", [1]))) + array([1], dtype=uint8) + + >>> _int_to_index(Road, numpy.array(1)) + Traceback (most recent call last): + TypeError: iteration over a 0-d array + + >>> _int_to_index(Road, numpy.array([1])) + array([1], dtype=uint8) + + >>> _int_to_index(Road, numpy.array([0, 1, 0])) + array([0, 1, 0], dtype=uint8) + + >>> _int_to_index(Road, numpy.array([1, 1, 2])) + array([1, 1], dtype=uint8) + + """ + return numpy.array( + [index for index in value if index < len(enum_class.__members__)], t.EnumDType + ) + + +def _str_to_index( + enum_class: type[t.Enum], value: t.StrArray | t.ArrayLike[str] +) -> t.IndexArray: + """Transform a string array into an index array. + + Args: + enum_class: The enum class to encode the string array. + value: The string array to encode. + + Returns: + The index array. + + Examples: + >>> from array import array + + >>> import numpy + + >>> from openfisca_core import indexed_enums as enum + + >>> class Road(enum.Enum): + ... STREET = ( + ... "A public road that connects two points, but also has " + ... "buildings on both sides of it; these typically run " + ... "perpendicular to avenues." + ... ) + ... AVENUE = ( + ... "A public way that also has buildings and/or trees on both " + ... "sides; these run perpendicular to streets and are " + ... "traditionally wider." + ... ) + + >>> _str_to_index(Road, "AVENUE") + array([], dtype=uint8) + + >>> _str_to_index(Road, ["AVENUE"]) + array([1], dtype=uint8) + + >>> _str_to_index(Road, numpy.array("AVENUE")) + Traceback (most recent call last): + TypeError: iteration over a 0-d array + + >>> _str_to_index(Road, numpy.array(["AVENUE"])) + array([1], dtype=uint8) + + >>> _str_to_index(Road, numpy.array(["STREET", "AVENUE", "STREET"])) + array([0, 1, 0], dtype=uint8) + + >>> _str_to_index(Road, numpy.array(["AVENUE", "AVENUE", "BOULEVARD"])) + array([1, 1], dtype=uint8) + + """ + return numpy.array( + [ + enum_class.__members__[name].index + for name in value + if name in enum_class._member_names_ + ], + t.EnumDType, + ) + + +__all__ = ["_enum_to_index", "_int_to_index", "_str_to_index"] diff --git a/openfisca_core/indexed_enums/config.py b/openfisca_core/indexed_enums/config.py index f7da69b847..abb8817de3 100644 --- a/openfisca_core/indexed_enums/config.py +++ b/openfisca_core/indexed_enums/config.py @@ -1,3 +1,6 @@ import numpy ENUM_ARRAY_DTYPE = numpy.int16 + + +__all__ = ["ENUM_ARRAY_DTYPE"] diff --git a/openfisca_core/indexed_enums/enum.py b/openfisca_core/indexed_enums/enum.py index 3d9fc08447..d116a56ba4 100644 --- a/openfisca_core/indexed_enums/enum.py +++ b/openfisca_core/indexed_enums/enum.py @@ -1,99 +1,237 @@ from __future__ import annotations -import enum -from typing import Union +from collections.abc import Sequence import numpy -from . import ENUM_ARRAY_DTYPE, EnumArray +from . import types as t +from ._enum_type import EnumType +from ._errors import EnumEncodingError, EnumMemberNotFoundError +from ._guards import ( + _is_enum_array, + _is_enum_array_like, + _is_int_array, + _is_int_array_like, + _is_str_array, + _is_str_array_like, +) +from ._utils import _enum_to_index, _int_to_index, _str_to_index +from .enum_array import EnumArray -class Enum(enum.Enum): - """ - Enum based on `enum34 `_, whose items - have an index. +class Enum(t.Enum, metaclass=EnumType): + """Enum based on `enum34 `_. + + Its items have an :class:`int` index, useful and performant when running + :mod:`~openfisca_core.simulations` on large :mod:`~openfisca_core.populations`. + + Examples: + >>> from openfisca_core import indexed_enums as enum + + >>> class Housing(enum.Enum): + ... OWNER = "Owner" + ... TENANT = "Tenant" + ... FREE_LODGER = "Free lodger" + ... HOMELESS = "Homeless" + + >>> repr(Housing) + "" + + >>> repr(Housing.TENANT) + 'Housing.TENANT' + + >>> str(Housing.TENANT) + 'Housing.TENANT' + + >>> dict([(Housing.TENANT, Housing.TENANT.value)]) + {Housing.TENANT: 'Tenant'} + + >>> list(Housing) + [Housing.OWNER, Housing.TENANT, Housing.FREE_LODGER, Housing.HOMELESS] + + >>> Housing["TENANT"] + Housing.TENANT + + >>> Housing("Tenant") + Housing.TENANT + + >>> Housing.TENANT in Housing + True + + >>> len(Housing) + 4 + + >>> Housing.TENANT == Housing.TENANT + True + + >>> Housing.TENANT != Housing.TENANT + False + + >>> Housing.TENANT.index + 1 + + >>> Housing.TENANT.name + 'TENANT' + + >>> Housing.TENANT.value + 'Tenant' + """ - # Tweak enums to add an index attribute to each enum item - def __init__(self, name: str) -> None: - # When the enum item is initialized, self._member_names_ contains the - # names of the previously initialized items, so its length is the index - # of this item. + #: The :attr:`index` of the :class:`.Enum` member. + index: int + + def __init__(self, *__args: object, **__kwargs: object) -> None: + """Tweak :class:`enum.Enum` to add an :attr:`.index` to each enum item. + + When the enum is initialised, ``_member_names_`` contains the names of + the already initialized items, so its length is the index of this item. + + Args: + *__args: Positional arguments. + **__kwargs: Keyword arguments. + + Examples: + >>> import numpy + + >>> from openfisca_core import indexed_enums as enum + + >>> Housing = enum.Enum("Housing", "owner tenant") + >>> Housing.tenant.index + 1 + + >>> class Housing(enum.Enum): + ... OWNER = "Owner" + ... TENANT = "Tenant" + + >>> Housing.TENANT.index + 1 + + >>> array = numpy.array([[1, 2], [3, 4]]) + >>> array[Housing.TENANT.index] + array([3, 4]) + + Note: + ``_member_names_`` is undocumented in upstream :class:`enum.Enum`. + + """ self.index = len(self._member_names_) - # Bypass the slow Enum.__eq__ - __eq__ = object.__eq__ + def __repr__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" - # In Python 3, __hash__ must be defined if __eq__ is defined to stay - # hashable. - __hash__ = object.__hash__ + def __hash__(self) -> int: + return object.__hash__(self.__class__.__name__ + self.name) + + def __eq__(self, other: object) -> bool: + if ( + isinstance(other, Enum) + and self.__class__.__name__ == other.__class__.__name__ + ): + return self.index == other.index + return NotImplemented + + def __ne__(self, other: object) -> bool: + if ( + isinstance(other, Enum) + and self.__class__.__name__ == other.__class__.__name__ + ): + return self.index != other.index + return NotImplemented @classmethod - def encode( - cls, - array: Union[ - EnumArray, - numpy.int_, - numpy.float_, - numpy.object_, - ], - ) -> EnumArray: - """ - Encode a string numpy array, an enum item numpy array, or an int numpy - array into an :any:`EnumArray`. See :any:`EnumArray.decode` for - decoding. + def encode(cls, array: t.VarArray | t.ArrayLike[object]) -> t.EnumArray: + """Encode an encodable array into an :class:`.EnumArray`. + + Args: + array: :class:`~numpy.ndarray` to encode. + + Returns: + EnumArray: An :class:`.EnumArray` with the encoded input values. + + Examples: + >>> import numpy - :param numpy.ndarray array: Array of string identifiers, or of enum - items, to encode. + >>> from openfisca_core import indexed_enums as enum - :returns: An :any:`EnumArray` encoding the input array values. - :rtype: :any:`EnumArray` + >>> class Housing(enum.Enum): + ... OWNER = "Owner" + ... TENANT = "Tenant" - For instance: + # EnumArray - >>> string_identifier_array = asarray(['free_lodger', 'owner']) - >>> encoded_array = HousingOccupancyStatus.encode(string_identifier_array) - >>> encoded_array[0] - 2 # Encoded value + >>> array = numpy.array([1]) + >>> enum_array = enum.EnumArray(array, Housing) + >>> Housing.encode(enum_array) + EnumArray([Housing.TENANT]) + + # Array of Enum + + >>> array = numpy.array([Housing.TENANT]) + >>> enum_array = Housing.encode(array) + >>> enum_array == Housing.TENANT + array([ True]) + + # Array of integers + + >>> array = numpy.array([1]) + >>> enum_array = Housing.encode(array) + >>> enum_array == Housing.TENANT + array([ True]) + + # Array of strings + + >>> array = numpy.array(["TENANT"]) + >>> enum_array = Housing.encode(array) + >>> enum_array == Housing.TENANT + array([ True]) + + # Array of bytes + + >>> array = numpy.array([b"TENANT"]) + >>> enum_array = Housing.encode(array) + Traceback (most recent call last): + EnumEncodingError: Failed to encode "[b'TENANT']" of type 'bytes... + + .. seealso:: + :meth:`.EnumArray.decode` for decoding. - >>> free_lodger = HousingOccupancyStatus.free_lodger - >>> owner = HousingOccupancyStatus.owner - >>> enum_item_array = asarray([free_lodger, owner]) - >>> encoded_array = HousingOccupancyStatus.encode(enum_item_array) - >>> encoded_array[0] - 2 # Encoded value """ if isinstance(array, EnumArray): return array + if len(array) == 0: + return EnumArray(numpy.asarray(array, t.EnumDType), cls) + if isinstance(array, Sequence): + return cls._encode_array_like(array) + return cls._encode_array(array) - # String array - if isinstance(array, numpy.ndarray) and \ - array.dtype.kind in {'U', 'S'}: - array = numpy.select( - [array == item.name for item in cls], - [item.index for item in cls], - ).astype(ENUM_ARRAY_DTYPE) - - # Enum items arrays - elif isinstance(array, numpy.ndarray) and \ - array.dtype.kind == 'O': - # Ensure we are comparing the comparable. The problem this fixes: - # On entering this method "cls" will generally come from - # variable.possible_values, while the array values may come from - # directly importing a module containing an Enum class. However, - # variables (and hence their possible_values) are loaded by a call - # to load_module, which gives them a different identity from the - # ones imported in the usual way. - # - # So, instead of relying on the "cls" passed in, we use only its - # name to check that the values in the array, if non-empty, are of - # the right type. - if len(array) > 0 and cls.__name__ is array[0].__class__.__name__: - cls = array[0].__class__ - - array = numpy.select( - [array == item for item in cls], - [item.index for item in cls], - ).astype(ENUM_ARRAY_DTYPE) - - return EnumArray(array, cls) + @classmethod + def _encode_array(cls, value: t.VarArray) -> t.EnumArray: + if _is_int_array(value): + indices = _int_to_index(cls, value) + elif _is_str_array(value): # type: ignore[unreachable] + indices = _str_to_index(cls, value) + elif _is_enum_array(value) and cls.__name__ is value[0].__class__.__name__: + indices = _enum_to_index(value) + else: + raise EnumEncodingError(cls, value) + if indices.size != len(value): + raise EnumMemberNotFoundError(cls) + return EnumArray(indices, cls) + + @classmethod + def _encode_array_like(cls, value: t.ArrayLike[object]) -> t.EnumArray: + if _is_int_array_like(value): + indices = _int_to_index(cls, value) + elif _is_str_array_like(value): # type: ignore[unreachable] + indices = _str_to_index(cls, value) + elif _is_enum_array_like(value): + indices = _enum_to_index(value) + else: + raise EnumEncodingError(cls, value) + if indices.size != len(value): + raise EnumMemberNotFoundError(cls) + return EnumArray(indices, cls) + + +__all__ = ["Enum"] diff --git a/openfisca_core/indexed_enums/enum_array.py b/openfisca_core/indexed_enums/enum_array.py index 6a77be57a7..98f9b4c6aa 100644 --- a/openfisca_core/indexed_enums/enum_array.py +++ b/openfisca_core/indexed_enums/enum_array.py @@ -1,58 +1,240 @@ from __future__ import annotations -import typing -from typing import Any, NoReturn, Optional, Type +from typing import NoReturn +from typing_extensions import Self import numpy -if typing.TYPE_CHECKING: - from openfisca_core.indexed_enums import Enum +from . import types as t -class EnumArray(numpy.ndarray): - """ - Numpy array subclass representing an array of enum items. +class EnumArray(t.EnumArray): + """A subclass of :class:`~numpy.ndarray` of :class:`.Enum`. + + :class:`.Enum` arrays are encoded as :class:`int` to improve performance. + + Note: + Subclassing :class:`~numpy.ndarray` is a little tricky™. To read more + about the :meth:`.__new__` and :meth:`.__array_finalize__` methods + below, see `Subclassing ndarray`_. + + Examples: + >>> import numpy + + >>> from openfisca_core import indexed_enums as enum, variables + + >>> class Housing(enum.Enum): + ... OWNER = "Owner" + ... TENANT = "Tenant" + ... FREE_LODGER = "Free lodger" + ... HOMELESS = "Homeless" + + >>> array = numpy.array([1], dtype=numpy.int16) + >>> enum_array = enum.EnumArray(array, Housing) + + >>> repr(enum.EnumArray) + "" + + >>> repr(enum_array) + 'EnumArray([Housing.TENANT])' + + >>> str(enum_array) + "['TENANT']" + + >>> list(map(int, enum_array)) + [1] + + >>> int(enum_array[0]) + 1 + + >>> enum_array[0] in enum_array + True + + >>> len(enum_array) + 1 + + >>> enum_array = enum.EnumArray(list(Housing), Housing) + Traceback (most recent call last): + AttributeError: 'list' object has no attribute 'view' + + >>> class OccupancyStatus(variables.Variable): + ... value_type = enum.Enum + ... possible_values = Housing + + >>> enum.EnumArray(array, OccupancyStatus.possible_values) + EnumArray([Housing.TENANT]) + + .. _Subclassing ndarray: + https://numpy.org/doc/stable/user/basics.subclassing.html - EnumArrays are encoded as ``int`` arrays to improve performance """ - # Subclassing ndarray is a little tricky. - # To read more about the two following methods, see: - # https://docs.scipy.org/doc/numpy-1.13.0/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array. + #: Enum type of the array items. + possible_values: None | type[t.Enum] + def __new__( - cls, - input_array: numpy.int_, - possible_values: Optional[Type[Enum]] = None, - ) -> EnumArray: - obj = numpy.asarray(input_array).view(cls) + cls, + input_array: t.IndexArray, + possible_values: type[t.Enum], + ) -> Self: + """See comment above.""" + obj = input_array.view(cls) obj.possible_values = possible_values return obj - # See previous comment - def __array_finalize__(self, obj: Optional[numpy.int_]) -> None: + def __array_finalize__(self, obj: None | t.EnumArray | t.VarArray) -> None: + """See comment above.""" if obj is None: return - self.possible_values = getattr(obj, "possible_values", None) - def __eq__(self, other: Any) -> bool: - # When comparing to an item of self.possible_values, use the item index - # to speed up the comparison. - if other.__class__.__name__ is self.possible_values.__name__: - # Use view(ndarray) so that the result is a classic ndarray, not an - # EnumArray. - return self.view(numpy.ndarray) == other.index + def __eq__(self, other: object) -> t.BoolArray: # type: ignore[override] + """Compare equality with the item's :attr:`~.Enum.index`. + + When comparing to an item of :attr:`.possible_values`, use the + item's :attr:`~.Enum.index`. to speed up the comparison. + + Whenever possible, use :any:`numpy.ndarray.view` so that the result is + a classic :class:`~numpy.ndarray`, not an :obj:`.EnumArray`. + + Args: + other: Another :class:`object` to compare to. + + Returns: + bool: When ??? + ndarray[bool_]: When ??? + + Examples: + >>> import numpy + + >>> from openfisca_core import indexed_enums as enum + + >>> class Housing(enum.Enum): + ... OWNER = "Owner" + ... TENANT = "Tenant" + + >>> array = numpy.array([1]) + >>> enum_array = enum.EnumArray(array, Housing) + + >>> enum_array == Housing + array([False, True]) + + >>> enum_array == Housing.TENANT + array([ True]) + + >>> enum_array == 1 + array([ True]) + + >>> enum_array == [1] + array([ True]) + + >>> enum_array == [2] + array([False]) + + >>> enum_array == "1" + array([False]) + + >>> enum_array is None + False + + >>> enum_array == enum.EnumArray(numpy.array([1]), Housing) + array([ True]) - return self.view(numpy.ndarray) == other + Note: + This breaks the `Liskov substitution principle`_. - def __ne__(self, other: Any) -> bool: + .. _Liskov substitution principle: + https://en.wikipedia.org/wiki/Liskov_substitution_principle + + """ + result: t.BoolArray + + if self.possible_values is None: + return NotImplemented + if other is None: + return NotImplemented + if ( + isinstance(other, type(t.Enum)) + and other.__name__ is self.possible_values.__name__ + ): + result = ( + self.view(numpy.ndarray) + == self.possible_values.indices[ + self.possible_values.indices <= max(self) + ] + ) + return result + if ( + isinstance(other, t.Enum) + and other.__class__.__name__ is self.possible_values.__name__ + ): + result = self.view(numpy.ndarray) == other.index + return result + # For NumPy >=1.26.x. + if isinstance(is_equal := self.view(numpy.ndarray) == other, numpy.ndarray): + return is_equal + # For NumPy <1.26.x. + return numpy.array([is_equal], dtype=t.BoolDType) + + def __ne__(self, other: object) -> t.BoolArray: # type: ignore[override] + """Inequality. + + Args: + other: Another :class:`object` to compare to. + + Returns: + bool: When ??? + ndarray[bool_]: When ??? + + Examples: + >>> import numpy + + >>> from openfisca_core import indexed_enums as enum + + >>> class Housing(enum.Enum): + ... OWNER = "Owner" + ... TENANT = "Tenant" + + >>> array = numpy.array([1]) + >>> enum_array = enum.EnumArray(array, Housing) + + >>> enum_array != Housing + array([ True, False]) + + >>> enum_array != Housing.TENANT + array([False]) + + >>> enum_array != 1 + array([False]) + + >>> enum_array != [1] + array([False]) + + >>> enum_array != [2] + array([ True]) + + >>> enum_array != "1" + array([ True]) + + >>> enum_array is not None + True + + Note: + This breaks the `Liskov substitution principle`_. + + .. _Liskov substitution principle: + https://en.wikipedia.org/wiki/Liskov_substitution_principle + + """ return numpy.logical_not(self == other) - def _forbidden_operation(self, other: Any) -> NoReturn: - raise TypeError( + @staticmethod + def _forbidden_operation(*__args: object, **__kwds: object) -> NoReturn: + msg = ( "Forbidden operation. The only operations allowed on EnumArrays " - "are '==' and '!='.", - ) + "are '==' and '!='." + ) + raise TypeError(msg) __add__ = _forbidden_operation __mul__ = _forbidden_operation @@ -63,44 +245,81 @@ def _forbidden_operation(self, other: Any) -> NoReturn: __and__ = _forbidden_operation __or__ = _forbidden_operation - def decode(self) -> numpy.object_: - """ - Return the array of enum items corresponding to self. + def decode(self) -> t.ObjArray: + """Decode itself to a normal array. + + Returns: + ndarray[Enum]: The items of the :obj:`.EnumArray`. + + Raises: + TypeError: When the :attr:`.possible_values` is not defined. + + Examples: + >>> import numpy - For instance: + >>> from openfisca_core import indexed_enums as enum - >>> enum_array = household('housing_occupancy_status', period) - >>> enum_array[0] - >>> 2 # Encoded value - >>> enum_array.decode()[0] - + >>> class Housing(enum.Enum): + ... OWNER = "Owner" + ... TENANT = "Tenant" + + >>> array = numpy.array([1]) + >>> enum_array = enum.EnumArray(array, Housing) + >>> enum_array.decode() + array([Housing.TENANT], dtype=object) - Decoded value: enum item """ - return numpy.select( - [self == item.index for item in self.possible_values], - list(self.possible_values), + result: t.ObjArray + if self.possible_values is None: + msg = ( + f"The possible values of the {self.__class__.__name__} are " + f"not defined." ) + raise TypeError(msg) + array = self.reshape(1).astype(t.EnumDType) if self.ndim == 0 else self + result = self.possible_values.enums[array] + return result - def decode_to_str(self) -> numpy.str_: - """ - Return the array of string identifiers corresponding to self. + def decode_to_str(self) -> t.StrArray: + """Decode itself to an array of strings. + + Returns: + ndarray[str_]: The string values of the :obj:`.EnumArray`. + + Raises: + TypeError: When the :attr:`.possible_values` is not defined. + + Examples: + >>> import numpy - For instance: + >>> from openfisca_core import indexed_enums as enum + + >>> class Housing(enum.Enum): + ... OWNER = "Owner" + ... TENANT = "Tenant" + + >>> array = numpy.array([1]) + >>> enum_array = enum.EnumArray(array, Housing) + >>> enum_array.decode_to_str() + array(['TENANT'], dtype='>> enum_array = household('housing_occupancy_status', period) - >>> enum_array[0] - >>> 2 # Encoded value - >>> enum_array.decode_to_str()[0] - 'free_lodger' # String identifier """ - return numpy.select( - [self == item.index for item in self.possible_values], - [item.name for item in self.possible_values], + result: t.StrArray + if self.possible_values is None: + msg = ( + f"The possible values of the {self.__class__.__name__} are " + f"not defined." ) + raise TypeError(msg) + array = self.reshape(1).astype(t.EnumDType) if self.ndim == 0 else self + result = self.possible_values.names[array] + return result def __repr__(self) -> str: - return f"{self.__class__.__name__}({str(self.decode())})" + return f"{self.__class__.__name__}({self.decode()!s})" def __str__(self) -> str: return str(self.decode_to_str()) + + +__all__ = ["EnumArray"] diff --git a/openfisca_core/indexed_enums/py.typed b/openfisca_core/indexed_enums/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfisca_core/indexed_enums/tests/__init__.py b/openfisca_core/indexed_enums/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfisca_core/indexed_enums/tests/test_enum.py b/openfisca_core/indexed_enums/tests/test_enum.py new file mode 100644 index 0000000000..2e49c17427 --- /dev/null +++ b/openfisca_core/indexed_enums/tests/test_enum.py @@ -0,0 +1,135 @@ +import numpy +import pytest +from numpy.testing import assert_array_equal + +from openfisca_core import indexed_enums as enum + + +class Animal(enum.Enum): + CAT = b"Cat" + DOG = b"Dog" + + +class Colour(enum.Enum): + INCARNADINE = "incarnadine" + TURQUOISE = "turquoise" + AMARANTH = "amaranth" + + +# Arrays of Enum + + +def test_enum_encode_with_array_of_enum(): + """Does encode when called with an array of enums.""" + array = numpy.array([Animal.DOG, Animal.DOG, Animal.CAT]) + enum_array = Animal.encode(array) + assert_array_equal(enum_array, numpy.array([1, 1, 0])) + + +def test_enum_encode_with_enum_sequence(): + """Does encode when called with an enum sequence.""" + sequence = list(Animal) + enum_array = Animal.encode(sequence) + assert Animal.DOG in enum_array + + +def test_enum_encode_with_enum_scalar_array(): + """Does not encode when called with an enum scalar array.""" + array = numpy.array(Animal.DOG) + with pytest.raises(TypeError): + Animal.encode(array) + + +def test_enum_encode_with_enum_with_bad_value(): + """Does not encode when called with a value not in an Enum.""" + array = numpy.array([Colour.AMARANTH]) + with pytest.raises(TypeError): + Animal.encode(array) + + +# Arrays of int + + +def test_enum_encode_with_array_of_int(): + """Does encode when called with an array of int.""" + array = numpy.array([1, 1, 0]) + enum_array = Animal.encode(array) + assert_array_equal(enum_array, numpy.array([1, 1, 0])) + + +def test_enum_encode_with_int_sequence(): + """Does encode when called with an int sequence.""" + sequence = (0, 1) + enum_array = Animal.encode(sequence) + assert Animal.DOG in enum_array + + +def test_enum_encode_with_int_scalar_array(): + """Does not encode when called with an int scalar array.""" + array = numpy.array(1) + with pytest.raises(TypeError): + Animal.encode(array) + + +def test_enum_encode_with_int_with_bad_value(): + """Does not encode when called with a value not in an Enum.""" + array = numpy.array([2]) + with pytest.raises(IndexError): + Animal.encode(array) + + +# Arrays of strings + + +def test_enum_encode_with_array_of_string(): + """Does encode when called with an array of string.""" + array = numpy.array(["DOG", "DOG", "CAT"]) + enum_array = Animal.encode(array) + assert_array_equal(enum_array, numpy.array([1, 1, 0])) + + +def test_enum_encode_with_str_sequence(): + """Does encode when called with a str sequence.""" + sequence = ("DOG", "CAT") + enum_array = Animal.encode(sequence) + assert Animal.DOG in enum_array + + +def test_enum_encode_with_str_scalar_array(): + """Does not encode when called with a str scalar array.""" + array = numpy.array("DOG") + with pytest.raises(TypeError): + Animal.encode(array) + + +def test_enum_encode_with_str_with_bad_value(): + """Encode encode when called with a value not in an Enum.""" + array = numpy.array(["JAIBA"]) + with pytest.raises(IndexError): + Animal.encode(array) + + +# Unsupported encodings + + +def test_enum_encode_with_any_array(): + """Does not encode when called with unsupported types.""" + value = {"animal": "dog"} + array = numpy.array([value]) + with pytest.raises(TypeError): + Animal.encode(array) + + +def test_enum_encode_with_any_scalar_array(): + """Does not encode when called with unsupported types.""" + value = 1.5 + array = numpy.array(value) + with pytest.raises(TypeError): + Animal.encode(array) + + +def test_enum_encode_with_any_sequence(): + """Does not encode when called with unsupported types.""" + sequence = memoryview(b"DOG") + with pytest.raises(IndexError): + Animal.encode(sequence) diff --git a/openfisca_core/indexed_enums/tests/test_enum_array.py b/openfisca_core/indexed_enums/tests/test_enum_array.py new file mode 100644 index 0000000000..1ab2474688 --- /dev/null +++ b/openfisca_core/indexed_enums/tests/test_enum_array.py @@ -0,0 +1,30 @@ +import numpy +import pytest + +from openfisca_core import indexed_enums as enum + + +class Fruit(enum.Enum): + APPLE = b"apple" + BERRY = b"berry" + + +@pytest.fixture +def enum_array(): + return enum.EnumArray(numpy.array([1]), Fruit) + + +def test_enum_array_eq_operation(enum_array): + """The equality operation is permitted.""" + assert enum_array == enum.EnumArray(numpy.array([1]), Fruit) + + +def test_enum_array_ne_operation(enum_array): + """The non-equality operation is permitted.""" + assert enum_array != enum.EnumArray(numpy.array([0]), Fruit) + + +def test_enum_array_any_other_operation(enum_array): + """Only equality and non-equality operations are permitted.""" + with pytest.raises(TypeError, match="Forbidden operation."): + enum_array * 1 diff --git a/openfisca_core/indexed_enums/types.py b/openfisca_core/indexed_enums/types.py new file mode 100644 index 0000000000..e0a71b3221 --- /dev/null +++ b/openfisca_core/indexed_enums/types.py @@ -0,0 +1,41 @@ +from typing_extensions import TypeAlias + +from openfisca_core.types import Array, ArrayLike, DTypeLike, Enum, EnumArray, EnumType + +from enum import _EnumDict as EnumDict # noqa: PLC2701 + +from numpy import ( + bool_ as BoolDType, + generic as VarDType, + int32 as IntDType, + object_ as ObjDType, + str_ as StrDType, + uint8 as EnumDType, +) + +#: Type for enum indices arrays. +IndexArray: TypeAlias = Array[EnumDType] + +#: Type for boolean arrays. +BoolArray: TypeAlias = Array[BoolDType] + +#: Type for int arrays. +IntArray: TypeAlias = Array[IntDType] + +#: Type for str arrays. +StrArray: TypeAlias = Array[StrDType] + +#: Type for object arrays. +ObjArray: TypeAlias = Array[ObjDType] + +#: Type for generic arrays. +VarArray: TypeAlias = Array[VarDType] + +__all__ = [ + "ArrayLike", + "DTypeLike", + "Enum", + "EnumArray", + "EnumDict", + "EnumType", +] diff --git a/openfisca_core/model_api.py b/openfisca_core/model_api.py index 3140c04d69..e36e0d5f76 100644 --- a/openfisca_core/model_api.py +++ b/openfisca_core/model_api.py @@ -1,39 +1,63 @@ -from datetime import date # noqa: F401 +from datetime import date -from numpy import ( # noqa: F401 +from numpy import ( logical_not as not_, maximum as max_, minimum as min_, round as round_, select, where, - ) +) -from openfisca_core.commons import apply_thresholds, concat, switch # noqa: F401 - -from openfisca_core.holders import ( # noqa: F401 +from openfisca_core.commons import apply_thresholds, concat, switch +from openfisca_core.holders import ( set_input_dispatch_by_period, set_input_divide_by_period, - ) - -from openfisca_core.indexed_enums import Enum # noqa: F401 - -from openfisca_core.parameters import ( # noqa: F401 +) +from openfisca_core.indexed_enums import Enum +from openfisca_core.parameters import ( + Bracket, Parameter, ParameterNode, - ParameterScale, - ParameterScaleBracket, + Scale, ValuesHistory, load_parameter_file, - ) - -from openfisca_core.periods import DAY, MONTH, YEAR, ETERNITY, period # noqa: F401 -from openfisca_core.populations import ADD, DIVIDE # noqa: F401 -from openfisca_core.reforms import Reform # noqa: F401 - -from openfisca_core.simulations import ( # noqa: F401 - calculate_output_add, - calculate_output_divide, - ) - -from openfisca_core.variables import Variable # noqa: F401 +) +from openfisca_core.periods import DAY, ETERNITY, MONTH, YEAR, period +from openfisca_core.populations import ADD, DIVIDE +from openfisca_core.reforms import Reform +from openfisca_core.simulations import calculate_output_add, calculate_output_divide +from openfisca_core.variables import Variable + +__all__ = [ + "date", + "not_", + "max_", + "min_", + "round_", + "select", + "where", + "apply_thresholds", + "concat", + "switch", + "set_input_dispatch_by_period", + "set_input_divide_by_period", + "Enum", + "Bracket", + "Parameter", + "ParameterNode", + "Scale", + "ValuesHistory", + "load_parameter_file", + "DAY", + "ETERNITY", + "MONTH", + "YEAR", + "period", + "ADD", + "DIVIDE", + "Reform", + "calculate_output_add", + "calculate_output_divide", + "Variable", +] diff --git a/openfisca_core/parameters/__init__.py b/openfisca_core/parameters/__init__.py index bbf5a0595f..5d742d4611 100644 --- a/openfisca_core/parameters/__init__.py +++ b/openfisca_core/parameters/__init__.py @@ -21,21 +21,52 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .config import ( # noqa: F401 +from openfisca_core.errors import ParameterNotFound, ParameterParsingError + +from .at_instant_like import AtInstantLike +from .config import ( ALLOWED_PARAM_TYPES, COMMON_KEYS, FILE_EXTENSIONS, date_constructor, dict_no_duplicate_constructor, - ) +) +from .helpers import contains_nan, load_parameter_file +from .parameter import Parameter +from .parameter_at_instant import ParameterAtInstant +from .parameter_node import ParameterNode +from .parameter_node_at_instant import ParameterNodeAtInstant +from .parameter_scale import ParameterScale, ParameterScale as Scale +from .parameter_scale_bracket import ( + ParameterScaleBracket, + ParameterScaleBracket as Bracket, +) +from .values_history import ValuesHistory +from .vectorial_asof_date_parameter_node_at_instant import ( + VectorialAsofDateParameterNodeAtInstant, +) +from .vectorial_parameter_node_at_instant import VectorialParameterNodeAtInstant -from .at_instant_like import AtInstantLike # noqa: F401 -from .helpers import contains_nan, load_parameter_file # noqa: F401 -from .parameter_at_instant import ParameterAtInstant # noqa: F401 -from .parameter_node_at_instant import ParameterNodeAtInstant # noqa: F401 -from .vectorial_parameter_node_at_instant import VectorialParameterNodeAtInstant # noqa: F401 -from .parameter import Parameter # noqa: F401 -from .parameter_node import ParameterNode # noqa: F401 -from .parameter_scale import ParameterScale # noqa: F401 -from .parameter_scale_bracket import ParameterScaleBracket # noqa: F401 -from .values_history import ValuesHistory # noqa: F401 +__all__ = [ + "ParameterNotFound", + "ParameterParsingError", + "AtInstantLike", + "ALLOWED_PARAM_TYPES", + "COMMON_KEYS", + "FILE_EXTENSIONS", + "date_constructor", + "dict_no_duplicate_constructor", + "contains_nan", + "load_parameter_file", + "Parameter", + "ParameterAtInstant", + "ParameterNode", + "ParameterNodeAtInstant", + "ParameterScale", + "Scale", + "ParameterScaleBracket", + "Bracket", + "ValuesHistory", + "VectorialAsofDateParameterNodeAtInstant", + "VectorialParameterNodeAtInstant", +] diff --git a/openfisca_core/parameters/at_instant_like.py b/openfisca_core/parameters/at_instant_like.py index 1a1db34beb..19c28e98c2 100644 --- a/openfisca_core/parameters/at_instant_like.py +++ b/openfisca_core/parameters/at_instant_like.py @@ -4,9 +4,7 @@ class AtInstantLike(abc.ABC): - """ - Base class for various types of parameters implementing the at instant protocol. - """ + """Base class for various types of parameters implementing the at instant protocol.""" def __call__(self, instant): return self.get_at_instant(instant) @@ -16,5 +14,4 @@ def get_at_instant(self, instant): return self._get_at_instant(instant) @abc.abstractmethod - def _get_at_instant(self, instant): - ... + def _get_at_instant(self, instant): ... diff --git a/openfisca_core/parameters/config.py b/openfisca_core/parameters/config.py index e9a3041ae8..5fb1198bea 100644 --- a/openfisca_core/parameters/config.py +++ b/openfisca_core/parameters/config.py @@ -1,9 +1,9 @@ -import warnings import os +import warnings + import yaml -import typing -from openfisca_core.warnings import LibYAMLWarning +from openfisca_core.warnings import LibYAMLWarning try: from yaml import CLoader as Loader @@ -12,33 +12,44 @@ "libyaml is not installed in your environment.", "This can make OpenFisca slower to start.", "Once you have installed libyaml, run 'pip uninstall pyyaml && pip install pyyaml --no-cache-dir'", - "so that it is used in your Python environment." + os.linesep - ] - warnings.warn(" ".join(message), LibYAMLWarning) - from yaml import Loader # type: ignore # (see https://github.com/python/mypy/issues/1153#issuecomment-455802270) + "so that it is used in your Python environment." + os.linesep, + ] + warnings.warn(" ".join(message), LibYAMLWarning, stacklevel=2) + from yaml import ( # type: ignore # (see https://github.com/python/mypy/issues/1153#issuecomment-455802270) + Loader, + ) # 'unit' and 'reference' are only listed here for backward compatibility. # It is now recommended to include them in metadata, until a common consensus emerges. -ALLOWED_PARAM_TYPES = (float, int, bool, type(None), typing.List) -COMMON_KEYS = {'description', 'metadata', 'unit', 'reference', 'documentation'} -FILE_EXTENSIONS = {'.yaml', '.yml'} +ALLOWED_PARAM_TYPES = (float, int, bool, type(None), list) +COMMON_KEYS = {"description", "metadata", "unit", "reference", "documentation"} +FILE_EXTENSIONS = {".yaml", ".yml"} def date_constructor(_loader, node): return node.value -yaml.add_constructor('tag:yaml.org,2002:timestamp', date_constructor, Loader = Loader) +yaml.add_constructor("tag:yaml.org,2002:timestamp", date_constructor, Loader=Loader) -def dict_no_duplicate_constructor(loader, node, deep = False): +def dict_no_duplicate_constructor(loader, node, deep=False): keys = [key.value for key, value in node.value] if len(keys) != len(set(keys)): - duplicate = next((key for key in keys if keys.count(key) > 1)) - raise yaml.parser.ParserError('', node.start_mark, f"Found duplicate key '{duplicate}'") + duplicate = next(key for key in keys if keys.count(key) > 1) + msg = "" + raise yaml.parser.ParserError( + msg, + node.start_mark, + f"Found duplicate key '{duplicate}'", + ) return loader.construct_mapping(node, deep) -yaml.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, dict_no_duplicate_constructor, Loader = Loader) +yaml.add_constructor( + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, + dict_no_duplicate_constructor, + Loader=Loader, +) diff --git a/openfisca_core/parameters/helpers.py b/openfisca_core/parameters/helpers.py index 75d5a18b73..09925bbcdb 100644 --- a/openfisca_core/parameters/helpers.py +++ b/openfisca_core/parameters/helpers.py @@ -9,91 +9,98 @@ def contains_nan(vector): - if numpy.issubdtype(vector.dtype, numpy.record): - return any([contains_nan(vector[name]) for name in vector.dtype.names]) - else: - return numpy.isnan(vector).any() + if numpy.issubdtype(vector.dtype, numpy.record) or numpy.issubdtype( + vector.dtype, + numpy.void, + ): + return any(contains_nan(vector[name]) for name in vector.dtype.names) + return numpy.isnan(vector).any() -def load_parameter_file(file_path, name = ''): - """ - Load parameters from a YAML file (or a directory containing YAML files). +def load_parameter_file(file_path, name=""): + """Load parameters from a YAML file (or a directory containing YAML files). :returns: An instance of :class:`.ParameterNode` or :class:`.ParameterScale` or :class:`.Parameter`. """ if not os.path.exists(file_path): - raise ValueError("{} does not exist".format(file_path)) + msg = f"{file_path} does not exist" + raise ValueError(msg) if os.path.isdir(file_path): - return parameters.ParameterNode(name, directory_path = file_path) + return parameters.ParameterNode(name, directory_path=file_path) data = _load_yaml_file(file_path) return _parse_child(name, data, file_path) -def _compose_name(path, child_name = None, item_name = None): +def _compose_name(path, child_name=None, item_name=None): if not path: return child_name if child_name is not None: - return '{}.{}'.format(path, child_name) + return f"{path}.{child_name}" if item_name is not None: - return '{}[{}]'.format(path, item_name) + return f"{path}[{item_name}]" + return None def _load_yaml_file(file_path): - with open(file_path, 'r') as f: + with open(file_path) as f: try: - return config.yaml.load(f, Loader = config.Loader) + return config.yaml.load(f, Loader=config.Loader) except (config.yaml.scanner.ScannerError, config.yaml.parser.ParserError): stack_trace = traceback.format_exc() + msg = "Invalid YAML. Check the traceback above for more details." raise ParameterParsingError( - "Invalid YAML. Check the traceback above for more details.", + msg, file_path, - stack_trace - ) + stack_trace, + ) except Exception: stack_trace = traceback.format_exc() + msg = "Invalid parameter file content. Check the traceback above for more details." raise ParameterParsingError( - "Invalid parameter file content. Check the traceback above for more details.", + msg, file_path, - stack_trace - ) + stack_trace, + ) def _parse_child(child_name, child, child_path): - if 'values' in child: + if "values" in child: return parameters.Parameter(child_name, child, child_path) - elif 'brackets' in child: + if "brackets" in child: return parameters.ParameterScale(child_name, child, child_path) - elif isinstance(child, dict) and all([periods.INSTANT_PATTERN.match(str(key)) for key in child.keys()]): + if isinstance(child, dict) and all( + periods.INSTANT_PATTERN.match(str(key)) for key in child + ): return parameters.Parameter(child_name, child, child_path) - else: - return parameters.ParameterNode(child_name, data = child, file_path = child_path) + return parameters.ParameterNode(child_name, data=child, file_path=child_path) -def _set_backward_compatibility_metadata(parameter, data): - if data.get('unit') is not None: - parameter.metadata['unit'] = data['unit'] - if data.get('reference') is not None: - parameter.metadata['reference'] = data['reference'] +def _set_backward_compatibility_metadata(parameter, data) -> None: + if data.get("unit") is not None: + parameter.metadata["unit"] = data["unit"] + if data.get("reference") is not None: + parameter.metadata["reference"] = data["reference"] -def _validate_parameter(parameter, data, data_type = None, allowed_keys = None): +def _validate_parameter(parameter, data, data_type=None, allowed_keys=None) -> None: type_map = { - dict: 'object', - list: 'array', - } + dict: "object", + list: "array", + } if data_type is not None and not isinstance(data, data_type): + msg = f"'{parameter.name}' must be of type {type_map[data_type]}." raise ParameterParsingError( - "'{}' must be of type {}.".format(parameter.name, type_map[data_type]), - parameter.file_path - ) + msg, + parameter.file_path, + ) if allowed_keys is not None: keys = data.keys() for key in keys: if key not in allowed_keys: + msg = f"Unexpected property '{key}' in '{parameter.name}'. Allowed properties are {list(allowed_keys)}." raise ParameterParsingError( - "Unexpected property '{}' in '{}'. Allowed properties are {}." - .format(key, parameter.name, list(allowed_keys)), - parameter.file_path - ) + msg, + parameter.file_path, + ) diff --git a/openfisca_core/parameters/parameter.py b/openfisca_core/parameters/parameter.py index 62fd3f6766..528f54cccd 100644 --- a/openfisca_core/parameters/parameter.py +++ b/openfisca_core/parameters/parameter.py @@ -1,22 +1,30 @@ +from __future__ import annotations + import copy import os -import typing from openfisca_core import commons, periods from openfisca_core.errors import ParameterParsingError -from openfisca_core.parameters import config, helpers, AtInstantLike, ParameterAtInstant + +from . import config, helpers +from .at_instant_like import AtInstantLike +from .parameter_at_instant import ParameterAtInstant class Parameter(AtInstantLike): - """ - A parameter of the legislation. Parameters can change over time. + """A parameter of the legislation. + + Parameters can change over time. - :param string name: Name of the parameter, e.g. "taxes.some_tax.some_param" - :param dict data: Data loaded from a YAML file. - :param string file_path: File the parameter was loaded from. - :param string documentation: Documentation describing parameter usage and context. + Attributes: + values_list: List of the values, in reverse chronological order. + Args: + name: Name of the parameter, e.g. "taxes.some_tax.some_param". + data: Data loaded from a YAML file. + file_path: File the parameter was loaded from. + Instantiate a parameter without metadata: >>> Parameter('rate', data = { @@ -34,63 +42,84 @@ class Parameter(AtInstantLike): } }) - .. attribute:: values_list - - List of the values, in reverse chronological order """ - def __init__(self, name, data, file_path = None): + def __init__(self, name: str, data: dict, file_path: str | None = None) -> None: self.name: str = name - self.file_path: str = file_path - helpers._validate_parameter(self, data, data_type = dict) - self.description: str = None - self.metadata: typing.Dict = {} - self.documentation: str = None + self.file_path: str | None = file_path + helpers._validate_parameter(self, data, data_type=dict) + self.description: str | None = None + self.metadata: dict = {} + self.documentation: str | None = None self.values_history = self # Only for backward compatibility # Normal parameter declaration: the values are declared under the 'values' key: parse the description and metadata. - if data.get('values'): + if data.get("values"): # 'unit' and 'reference' are only listed here for backward compatibility - helpers._validate_parameter(self, data, allowed_keys = config.COMMON_KEYS.union({'values'})) - self.description = data.get('description') + helpers._validate_parameter( + self, + data, + allowed_keys=config.COMMON_KEYS.union({"values"}), + ) + self.description = data.get("description") helpers._set_backward_compatibility_metadata(self, data) - self.metadata.update(data.get('metadata', {})) + self.metadata.update(data.get("metadata", {})) - helpers._validate_parameter(self, data['values'], data_type = dict) - values = data['values'] + helpers._validate_parameter(self, data["values"], data_type=dict) + values = data["values"] - self.documentation = data.get('documentation') + self.documentation = data.get("documentation") else: # Simplified parameter declaration: only values are provided values = data - instants = sorted(values.keys(), reverse = True) # sort in reverse chronological order + instants = sorted( + values.keys(), + reverse=True, + ) # sort in reverse chronological order values_list = [] for instant_str in instants: if not periods.INSTANT_PATTERN.match(instant_str): + msg = f"Invalid property '{instant_str}' in '{self.name}'. Properties must be valid YYYY-MM-DD instants, such as 2017-01-15." raise ParameterParsingError( - "Invalid property '{}' in '{}'. Properties must be valid YYYY-MM-DD instants, such as 2017-01-15." - .format(instant_str, self.name), - file_path) + msg, + file_path, + ) instant_info = values[instant_str] # Ignore expected values, as they are just metadata - if instant_info == "expected" or isinstance(instant_info, dict) and instant_info.get("expected"): + if ( + instant_info == "expected" + or isinstance(instant_info, dict) + and instant_info.get("expected") + ): continue - value_name = helpers._compose_name(name, item_name = instant_str) - value_at_instant = ParameterAtInstant(value_name, instant_str, data = instant_info, file_path = self.file_path, metadata = self.metadata) + value_name = helpers._compose_name(name, item_name=instant_str) + value_at_instant = ParameterAtInstant( + value_name, + instant_str, + data=instant_info, + file_path=self.file_path, + metadata=self.metadata, + ) values_list.append(value_at_instant) - self.values_list: typing.List[ParameterAtInstant] = values_list + self.values_list: list[ParameterAtInstant] = values_list - def __repr__(self): - return os.linesep.join([ - '{}: {}'.format(value.instant_str, value.value if value.value is not None else 'null') for value in self.values_list - ]) + def __repr__(self) -> str: + return os.linesep.join( + [ + "{}: {}".format( + value.instant_str, + value.value if value.value is not None else "null", + ) + for value in self.values_list + ], + ) def __eq__(self, other): return (self.name == other.name) and (self.values_list == other.values_list) @@ -100,12 +129,13 @@ def clone(self): clone.__dict__ = self.__dict__.copy() clone.metadata = copy.deepcopy(self.metadata) - clone.values_list = [parameter_at_instant.clone() for parameter_at_instant in self.values_list] + clone.values_list = [ + parameter_at_instant.clone() for parameter_at_instant in self.values_list + ] return clone - def update(self, period = None, start = None, stop = None, value = None): - """ - Change the value for a given period. + def update(self, period=None, start=None, stop=None, value=None) -> None: + """Change the value for a given period. :param period: Period where the value is modified. If set, `start` and `stop` should be `None`. :param start: Start of the period. Instance of `openfisca_core.periods.Instant`. If set, `period` should be `None`. @@ -114,15 +144,19 @@ def update(self, period = None, start = None, stop = None, value = None): """ if period is not None: if start is not None or stop is not None: - raise TypeError("Wrong input for 'update' method: use either 'update(period, value = value)' or 'update(start = start, stop = stop, value = value)'. You cannot both use 'period' and 'start' or 'stop'.") + msg = "Wrong input for 'update' method: use either 'update(period, value = value)' or 'update(start = start, stop = stop, value = value)'. You cannot both use 'period' and 'start' or 'stop'." + raise TypeError( + msg, + ) if isinstance(period, str): period = periods.period(period) start = period.start stop = period.stop if start is None: - raise ValueError("You must provide either a start or a period") + msg = "You must provide either a start or a period" + raise ValueError(msg) start_str = str(start) - stop_str = str(stop.offset(1, 'day')) if stop else None + stop_str = str(stop.offset(1, "day")) if stop else None old_values = self.values_list new_values = [] @@ -139,20 +173,27 @@ def update(self, period = None, start = None, stop = None, value = None): if stop_str: if new_values and (stop_str == new_values[-1].instant_str): pass # such interval is empty + elif i < n: + overlapped_value = old_values[i].value + value_name = helpers._compose_name(self.name, item_name=stop_str) + new_interval = ParameterAtInstant( + value_name, + stop_str, + data={"value": overlapped_value}, + ) + new_values.append(new_interval) else: - if i < n: - overlapped_value = old_values[i].value - value_name = helpers._compose_name(self.name, item_name = stop_str) - new_interval = ParameterAtInstant(value_name, stop_str, data = {'value': overlapped_value}) - new_values.append(new_interval) - else: - value_name = helpers._compose_name(self.name, item_name = stop_str) - new_interval = ParameterAtInstant(value_name, stop_str, data = {'value': None}) - new_values.append(new_interval) + value_name = helpers._compose_name(self.name, item_name=stop_str) + new_interval = ParameterAtInstant( + value_name, + stop_str, + data={"value": None}, + ) + new_values.append(new_interval) # Insert new interval - value_name = helpers._compose_name(self.name, item_name = start_str) - new_interval = ParameterAtInstant(value_name, start_str, data = {'value': value}) + value_name = helpers._compose_name(self.name, item_name=start_str) + new_interval = ParameterAtInstant(value_name, start_str, data={"value": value}) new_values.append(new_interval) # Remove covered intervals diff --git a/openfisca_core/parameters/parameter_at_instant.py b/openfisca_core/parameters/parameter_at_instant.py index ea91d25421..ae525cf829 100644 --- a/openfisca_core/parameters/parameter_at_instant.py +++ b/openfisca_core/parameters/parameter_at_instant.py @@ -1,5 +1,4 @@ import copy -import typing from openfisca_core import commons from openfisca_core.errors import ParameterParsingError @@ -7,23 +6,22 @@ class ParameterAtInstant: - """ - A value of a parameter at a given instant. - """ + """A value of a parameter at a given instant.""" # 'unit' and 'reference' are only listed here for backward compatibility - _allowed_keys = set(['value', 'metadata', 'unit', 'reference']) + _allowed_keys = {"value", "metadata", "unit", "reference"} - def __init__(self, name, instant_str, data = None, file_path = None, metadata = None): - """ - :param string name: name of the parameter, e.g. "taxes.some_tax.some_param" - :param string instant_str: Date of the value in the format `YYYY-MM-DD`. + def __init__( + self, name, instant_str, data=None, file_path=None, metadata=None + ) -> None: + """:param str name: name of the parameter, e.g. "taxes.some_tax.some_param" + :param str instant_str: Date of the value in the format `YYYY-MM-DD`. :param dict data: Data, usually loaded from a YAML file. """ self.name: str = name self.instant_str: str = instant_str self.file_path: str = file_path - self.metadata: typing.Dict = {} + self.metadata: dict = {} # Accept { 2015-01-01: 4000 } if not isinstance(data, dict) and isinstance(data, config.ALLOWED_PARAM_TYPES): @@ -31,33 +29,44 @@ def __init__(self, name, instant_str, data = None, file_path = None, metadata = return self.validate(data) - self.value: float = data['value'] + self.value: float = data["value"] if metadata is not None: self.metadata.update(metadata) # Inherit metadata from Parameter helpers._set_backward_compatibility_metadata(self, data) - self.metadata.update(data.get('metadata', {})) + self.metadata.update(data.get("metadata", {})) - def validate(self, data): - helpers._validate_parameter(self, data, data_type = dict, allowed_keys = self._allowed_keys) + def validate(self, data) -> None: + helpers._validate_parameter( + self, + data, + data_type=dict, + allowed_keys=self._allowed_keys, + ) try: - value = data['value'] + value = data["value"] except KeyError: + msg = f"Missing 'value' property for {self.name}" raise ParameterParsingError( - "Missing 'value' property for {}".format(self.name), - self.file_path - ) + msg, + self.file_path, + ) if not isinstance(value, config.ALLOWED_PARAM_TYPES): + msg = f"Value in {self.name} has type {type(value)}, which is not one of the allowed types ({config.ALLOWED_PARAM_TYPES}): {value}" raise ParameterParsingError( - "Value in {} has type {}, which is not one of the allowed types ({}): {}".format(self.name, type(value), config.ALLOWED_PARAM_TYPES, value), - self.file_path - ) + msg, + self.file_path, + ) def __eq__(self, other): - return (self.name == other.name) and (self.instant_str == other.instant_str) and (self.value == other.value) + return ( + (self.name == other.name) + and (self.instant_str == other.instant_str) + and (self.value == other.value) + ) - def __repr__(self): - return "ParameterAtInstant({})".format({self.instant_str: self.value}) + def __repr__(self) -> str: + return "ParameterAtInstant({self.instant_str: self.value})" def clone(self): clone = commons.empty_clone(self) diff --git a/openfisca_core/parameters/parameter_node.py b/openfisca_core/parameters/parameter_node.py index 1dae81dfb4..6f43379b36 100644 --- a/openfisca_core/parameters/parameter_node.py +++ b/openfisca_core/parameters/parameter_node.py @@ -1,28 +1,30 @@ from __future__ import annotations +from collections.abc import Iterable + import copy import os -import typing from openfisca_core import commons, parameters, tools -from . import config, helpers, AtInstantLike, Parameter, ParameterNodeAtInstant + +from . import config, helpers +from .at_instant_like import AtInstantLike +from .parameter import Parameter +from .parameter_node_at_instant import ParameterNodeAtInstant class ParameterNode(AtInstantLike): - """ - A node in the legislation `parameter tree `_. - """ + """A node in the legislation `parameter tree `_.""" - _allowed_keys: typing.Optional[typing.Iterable[str]] = None # By default, no restriction on the keys + _allowed_keys: None | Iterable[str] = None # By default, no restriction on the keys - def __init__(self, name = "", directory_path = None, data = None, file_path = None): - """ - Instantiate a ParameterNode either from a dict, (using `data`), or from a directory containing YAML files (using `directory_path`). + def __init__(self, name="", directory_path=None, data=None, file_path=None) -> None: + """Instantiate a ParameterNode either from a dict, (using `data`), or from a directory containing YAML files (using `directory_path`). - :param string name: Name of the node, eg "taxes.some_tax". - :param string directory_path: Directory containing YAML files describing the node. + :param str name: Name of the node, eg "taxes.some_tax". + :param str directory_path: Directory containing YAML files describing the node. :param dict data: Object representing the parameter node. It usually has been extracted from a YAML file. - :param string file_path: YAML file from which the `data` has been extracted from. + :param str file_path: YAML file from which the `data` has been extracted from. Instantiate a ParameterNode from a dict: @@ -44,14 +46,20 @@ def __init__(self, name = "", directory_path = None, data = None, file_path = No Instantiate a ParameterNode from a directory containing YAML parameter files: - >>> node = ParameterNode('benefits', directory_path = '/path/to/country_package/parameters/benefits') + >>> node = ParameterNode( + ... "benefits", + ... directory_path="/path/to/country_package/parameters/benefits", + ... ) """ self.name: str = name - self.children: typing.Dict[str, typing.Union[ParameterNode, Parameter, parameters.ParameterScale]] = {} + self.children: dict[ + str, + ParameterNode | Parameter | parameters.ParameterScale, + ] = {} self.description: str = None self.documentation: str = None self.file_path: str = None - self.metadata: typing.Dict = {} + self.metadata: dict = {} if directory_path: self.file_path = directory_path @@ -64,31 +72,46 @@ def __init__(self, name = "", directory_path = None, data = None, file_path = No if ext not in config.FILE_EXTENSIONS: continue - if child_name == 'index': + if child_name == "index": data = helpers._load_yaml_file(child_path) or {} - helpers._validate_parameter(self, data, allowed_keys = config.COMMON_KEYS) - self.description = data.get('description') - self.documentation = data.get('documentation') + helpers._validate_parameter( + self, + data, + allowed_keys=config.COMMON_KEYS, + ) + self.description = data.get("description") + self.documentation = data.get("documentation") helpers._set_backward_compatibility_metadata(self, data) - self.metadata.update(data.get('metadata', {})) + self.metadata.update(data.get("metadata", {})) else: child_name_expanded = helpers._compose_name(name, child_name) - child = helpers.load_parameter_file(child_path, child_name_expanded) + child = helpers.load_parameter_file( + child_path, + child_name_expanded, + ) self.add_child(child_name, child) elif os.path.isdir(child_path): child_name = os.path.basename(child_path) child_name_expanded = helpers._compose_name(name, child_name) - child = ParameterNode(child_name_expanded, directory_path = child_path) + child = ParameterNode( + child_name_expanded, + directory_path=child_path, + ) self.add_child(child_name, child) else: self.file_path = file_path - helpers._validate_parameter(self, data, data_type = dict, allowed_keys = self._allowed_keys) - self.description = data.get('description') - self.documentation = data.get('documentation') + helpers._validate_parameter( + self, + data, + data_type=dict, + allowed_keys=self._allowed_keys, + ) + self.description = data.get("description") + self.documentation = data.get("documentation") helpers._set_backward_compatibility_metadata(self, data) - self.metadata.update(data.get('metadata', {})) + self.metadata.update(data.get("metadata", {})) for child_name, child in data.items(): if child_name in config.COMMON_KEYS: continue # do not treat reserved keys as subparameters. @@ -98,41 +121,43 @@ def __init__(self, name = "", directory_path = None, data = None, file_path = No child = helpers._parse_child(child_name_expanded, child, file_path) self.add_child(child_name, child) - def merge(self, other): - """ - Merges another ParameterNode into the current node. + def merge(self, other) -> None: + """Merges another ParameterNode into the current node. In case of child name conflict, the other node child will replace the current node child. """ for child_name, child in other.children.items(): self.add_child(child_name, child) - def add_child(self, name, child): - """ - Add a new child to the node. + def add_child(self, name, child) -> None: + """Add a new child to the node. :param name: Name of the child that must be used to access that child. Should not contain anything that could interfere with the operator `.` (dot). :param child: The new child, an instance of :class:`.ParameterScale` or :class:`.Parameter` or :class:`.ParameterNode`. """ if name in self.children: - raise ValueError("{} has already a child named {}".format(self.name, name)) - if not (isinstance(child, ParameterNode) or isinstance(child, Parameter) or isinstance(child, parameters.ParameterScale)): - raise TypeError("child must be of type ParameterNode, Parameter, or Scale. Instead got {}".format(type(child))) + msg = f"{self.name} has already a child named {name}" + raise ValueError(msg) + if not ( + isinstance(child, (ParameterNode, Parameter, parameters.ParameterScale)) + ): + msg = f"child must be of type ParameterNode, Parameter, or Scale. Instead got {type(child)}" + raise TypeError( + msg, + ) self.children[name] = child setattr(self, name, child) - def __repr__(self): - result = os.linesep.join( - [os.linesep.join( - ["{}:", "{}"]).format(name, tools.indent(repr(value))) - for name, value in sorted(self.children.items())] - ) - return result + def __repr__(self) -> str: + return os.linesep.join( + [ + os.linesep.join(["{}:", "{}"]).format(name, tools.indent(repr(value))) + for name, value in sorted(self.children.items()) + ], + ) def get_descendants(self): - """ - Return a generator containing all the parameters and nodes recursively contained in this `ParameterNode` - """ + """Return a generator containing all the parameters and nodes recursively contained in this `ParameterNode`.""" for child in self.children.values(): yield child yield from child.get_descendants() @@ -142,10 +167,7 @@ def clone(self): clone.__dict__ = self.__dict__.copy() clone.metadata = copy.deepcopy(self.metadata) - clone.children = { - key: child.clone() - for key, child in self.children.items() - } + clone.children = {key: child.clone() for key, child in self.children.items()} for child_key, child in clone.children.items(): setattr(clone, child_key, child) diff --git a/openfisca_core/parameters/parameter_node_at_instant.py b/openfisca_core/parameters/parameter_node_at_instant.py index 49a7704c35..a8420a2f9f 100644 --- a/openfisca_core/parameters/parameter_node_at_instant.py +++ b/openfisca_core/parameters/parameter_node_at_instant.py @@ -1,5 +1,4 @@ import os -import sys import numpy @@ -9,17 +8,13 @@ class ParameterNodeAtInstant: - """ - Parameter node of the legislation, at a given instant. - """ + """Parameter node of the legislation, at a given instant.""" - def __init__(self, name, node, instant_str): - """ - :param name: Name of the node. + def __init__(self, name, node, instant_str) -> None: + """:param name: Name of the node. :param node: Original :any:`ParameterNode` instance. :param instant_str: A date in the format `YYYY-MM-DD`. """ - # The "technical" attributes are hidden, so that the node children can be easily browsed with auto-completion without pollution self._name = name self._instant_str = instant_str @@ -30,29 +25,35 @@ def __init__(self, name, node, instant_str): if child_at_instant is not None: self.add_child(child_name, child_at_instant) - def add_child(self, child_name, child_at_instant): + def add_child(self, child_name, child_at_instant) -> None: self._children[child_name] = child_at_instant setattr(self, child_name, child_at_instant) def __getattr__(self, key): - param_name = helpers._compose_name(self._name, item_name = key) + param_name = helpers._compose_name(self._name, item_name=key) raise ParameterNotFoundError(param_name, self._instant_str) def __getitem__(self, key): # If fancy indexing is used, cast to a vectorial node if isinstance(key, numpy.ndarray): + # If fancy indexing is used with a datetime64, cast to a vectorial node supporting datetime64 + if numpy.issubdtype(key.dtype, numpy.datetime64): + return ( + parameters.VectorialAsofDateParameterNodeAtInstant.build_from_node( + self, + )[key] + ) + return parameters.VectorialParameterNodeAtInstant.build_from_node(self)[key] return self._children[key] def __iter__(self): return iter(self._children) - def __repr__(self): - result = os.linesep.join( - [os.linesep.join( - ["{}:", "{}"]).format(name, tools.indent(repr(value))) - for name, value in self._children.items()] - ) - if sys.version_info < (3, 0): - return result - return result + def __repr__(self) -> str: + return os.linesep.join( + [ + os.linesep.join(["{}:", "{}"]).format(name, tools.indent(repr(value))) + for name, value in self._children.items() + ], + ) diff --git a/openfisca_core/parameters/parameter_scale.py b/openfisca_core/parameters/parameter_scale.py index d1cfc26379..b01b6a372a 100644 --- a/openfisca_core/parameters/parameter_scale.py +++ b/openfisca_core/parameters/parameter_scale.py @@ -1,65 +1,72 @@ import copy import os -import typing from openfisca_core import commons, parameters, tools from openfisca_core.errors import ParameterParsingError -from openfisca_core.parameters import config, helpers, AtInstantLike +from openfisca_core.parameters import AtInstantLike, config, helpers from openfisca_core.taxscales import ( LinearAverageRateTaxScale, MarginalAmountTaxScale, MarginalRateTaxScale, SingleAmountTaxScale, - ) +) class ParameterScale(AtInstantLike): - """ - A parameter scale (for instance a marginal scale). - """ + """A parameter scale (for instance a marginal scale).""" # 'unit' and 'reference' are only listed here for backward compatibility - _allowed_keys = config.COMMON_KEYS.union({'brackets'}) + _allowed_keys = config.COMMON_KEYS.union({"brackets"}) - def __init__(self, name, data, file_path): - """ - :param name: name of the scale, eg "taxes.some_scale" + def __init__(self, name, data, file_path) -> None: + """:param name: name of the scale, eg "taxes.some_scale" :param data: Data loaded from a YAML file. In case of a reform, the data can also be created dynamically. :param file_path: File the parameter was loaded from. """ self.name: str = name self.file_path: str = file_path - helpers._validate_parameter(self, data, data_type = dict, allowed_keys = self._allowed_keys) - self.description: str = data.get('description') - self.metadata: typing.Dict = {} + helpers._validate_parameter( + self, + data, + data_type=dict, + allowed_keys=self._allowed_keys, + ) + self.description: str = data.get("description") + self.metadata: dict = {} helpers._set_backward_compatibility_metadata(self, data) - self.metadata.update(data.get('metadata', {})) + self.metadata.update(data.get("metadata", {})) - if not isinstance(data.get('brackets', []), list): + if not isinstance(data.get("brackets", []), list): + msg = f"Property 'brackets' of scale '{self.name}' must be of type array." raise ParameterParsingError( - "Property 'brackets' of scale '{}' must be of type array." - .format(self.name), - self.file_path - ) + msg, + self.file_path, + ) brackets = [] - for i, bracket_data in enumerate(data.get('brackets', [])): - bracket_name = helpers._compose_name(name, item_name = i) - bracket = parameters.ParameterScaleBracket(name = bracket_name, data = bracket_data, file_path = file_path) + for i, bracket_data in enumerate(data.get("brackets", [])): + bracket_name = helpers._compose_name(name, item_name=i) + bracket = parameters.ParameterScaleBracket( + name=bracket_name, + data=bracket_data, + file_path=file_path, + ) brackets.append(bracket) - self.brackets: typing.List[parameters.ParameterScaleBracket] = brackets + self.brackets: list[parameters.ParameterScaleBracket] = brackets def __getitem__(self, key): if isinstance(key, int) and key < len(self.brackets): return self.brackets[key] - else: - raise KeyError(key) + raise KeyError(key) - def __repr__(self): + def __repr__(self) -> str: return os.linesep.join( - ['brackets:'] - + [tools.indent('-' + tools.indent(repr(bracket))[1:]) for bracket in self.brackets] - ) + ["brackets:"] + + [ + tools.indent("-" + tools.indent(repr(bracket))[1:]) + for bracket in self.brackets + ], + ) def get_descendants(self): return iter(()) @@ -76,45 +83,39 @@ def clone(self): def _get_at_instant(self, instant): brackets = [bracket.get_at_instant(instant) for bracket in self.brackets] - if self.metadata.get('type') == 'single_amount': + if self.metadata.get("type") == "single_amount": scale = SingleAmountTaxScale() for bracket in brackets: - if 'amount' in bracket._children and 'threshold' in bracket._children: + if "amount" in bracket._children and "threshold" in bracket._children: amount = bracket.amount threshold = bracket.threshold scale.add_bracket(threshold, amount) return scale - elif any('amount' in bracket._children for bracket in brackets): + if any("amount" in bracket._children for bracket in brackets): scale = MarginalAmountTaxScale() for bracket in brackets: - if 'amount' in bracket._children and 'threshold' in bracket._children: + if "amount" in bracket._children and "threshold" in bracket._children: amount = bracket.amount threshold = bracket.threshold scale.add_bracket(threshold, amount) return scale - elif any('average_rate' in bracket._children for bracket in brackets): + if any("average_rate" in bracket._children for bracket in brackets): scale = LinearAverageRateTaxScale() for bracket in brackets: - if 'base' in bracket._children: - base = bracket.base - else: - base = 1. - if 'average_rate' in bracket._children and 'threshold' in bracket._children: + if ( + "average_rate" in bracket._children + and "threshold" in bracket._children + ): average_rate = bracket.average_rate threshold = bracket.threshold - scale.add_bracket(threshold, average_rate * base) - return scale - else: - scale = MarginalRateTaxScale() - - for bracket in brackets: - if 'base' in bracket._children: - base = bracket.base - else: - base = 1. - if 'rate' in bracket._children and 'threshold' in bracket._children: - rate = bracket.rate - threshold = bracket.threshold - scale.add_bracket(threshold, rate * base) + scale.add_bracket(threshold, average_rate) return scale + scale = MarginalRateTaxScale() + + for bracket in brackets: + if "rate" in bracket._children and "threshold" in bracket._children: + rate = bracket.rate + threshold = bracket.threshold + scale.add_bracket(threshold, rate) + return scale diff --git a/openfisca_core/parameters/parameter_scale_bracket.py b/openfisca_core/parameters/parameter_scale_bracket.py index 6d361d09fa..b9691ea3ca 100644 --- a/openfisca_core/parameters/parameter_scale_bracket.py +++ b/openfisca_core/parameters/parameter_scale_bracket.py @@ -2,8 +2,6 @@ class ParameterScaleBracket(ParameterNode): - """ - A parameter scale bracket. - """ + """A parameter scale bracket.""" - _allowed_keys = set(['amount', 'threshold', 'rate', 'average_rate', 'base']) + _allowed_keys = {"amount", "threshold", "rate", "average_rate"} diff --git a/openfisca_core/parameters/values_history.py b/openfisca_core/parameters/values_history.py index fc55400c89..4c56c72398 100644 --- a/openfisca_core/parameters/values_history.py +++ b/openfisca_core/parameters/values_history.py @@ -1,9 +1,5 @@ -from openfisca_core.parameters import Parameter +from .parameter import Parameter class ValuesHistory(Parameter): - """ - Only for backward compatibility. - """ - - pass + """Only for backward compatibility.""" diff --git a/openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py b/openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py new file mode 100644 index 0000000000..27be1f6946 --- /dev/null +++ b/openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py @@ -0,0 +1,81 @@ +import numpy + +from openfisca_core.parameters.parameter_node_at_instant import ParameterNodeAtInstant +from openfisca_core.parameters.vectorial_parameter_node_at_instant import ( + VectorialParameterNodeAtInstant, +) + + +class VectorialAsofDateParameterNodeAtInstant(VectorialParameterNodeAtInstant): + """Parameter node of the legislation at a given instant which has been vectorized along some date. + Vectorized parameters allow requests such as parameters.housing_benefit[date], where date is a numpy.datetime64 type vector. + """ + + @staticmethod + def build_from_node(node): + VectorialParameterNodeAtInstant.check_node_vectorisable(node) + subnodes_name = node._children.keys() + # Recursively vectorize the children of the node + vectorial_subnodes = tuple( + [ + ( + VectorialAsofDateParameterNodeAtInstant.build_from_node( + node[subnode_name], + ).vector + if isinstance(node[subnode_name], ParameterNodeAtInstant) + else node[subnode_name] + ) + for subnode_name in subnodes_name + ], + ) + # A vectorial node is a wrapper around a numpy recarray + # We first build the recarray + recarray = numpy.array( + [vectorial_subnodes], + dtype=[ + ( + subnode_name, + subnode.dtype if isinstance(subnode, numpy.recarray) else "float", + ) + for (subnode_name, subnode) in zip(subnodes_name, vectorial_subnodes) + ], + ) + return VectorialAsofDateParameterNodeAtInstant( + node._name, + recarray.view(numpy.recarray), + node._instant_str, + ) + + def __getitem__(self, key): + # If the key is a string, just get the subnode + if isinstance(key, str): + key = numpy.array([key], dtype="datetime64[D]") + return self.__getattr__(key) + # If the key is a vector, e.g. ['1990-11-25', '1983-04-17', '1969-09-09'] + if isinstance(key, numpy.ndarray): + assert numpy.issubdtype(key.dtype, numpy.datetime64) + names = list( + self.dtype.names, + ) # Get all the names of the subnodes, e.g. ['before_X', 'after_X', 'after_Y'] + values = numpy.asarray(list(self.vector[0])) + names = [name for name in names if not name.startswith("before")] + names = [ + numpy.datetime64("-".join(name[len("after_") :].split("_"))) + for name in names + ] + conditions = sum([name <= key for name in names]) + result = values[conditions] + + # If the result is not a leaf, wrap the result in a vectorial node. + if numpy.issubdtype(result.dtype, numpy.record) or numpy.issubdtype( + result.dtype, + numpy.void, + ): + return VectorialAsofDateParameterNodeAtInstant( + self._name, + result.view(numpy.recarray), + self._instant_str, + ) + + return result + return None diff --git a/openfisca_core/parameters/vectorial_parameter_node_at_instant.py b/openfisca_core/parameters/vectorial_parameter_node_at_instant.py index 845b2f9664..eaa679c869 100644 --- a/openfisca_core/parameters/vectorial_parameter_node_at_instant.py +++ b/openfisca_core/parameters/vectorial_parameter_node_at_instant.py @@ -1,3 +1,5 @@ +from typing import NoReturn + import numpy from openfisca_core import parameters @@ -7,91 +9,90 @@ class VectorialParameterNodeAtInstant: - """ - Parameter node of the legislation at a given instant which has been vectorized. - Vectorized parameters allow requests such as parameters.housing_benefit[zipcode], where zipcode is a vector + """Parameter node of the legislation at a given instant which has been vectorized. + Vectorized parameters allow requests such as parameters.housing_benefit[zipcode], where zipcode is a vector. """ @staticmethod def build_from_node(node): VectorialParameterNodeAtInstant.check_node_vectorisable(node) - subnodes_name = node._children.keys() + subnodes_name = sorted(node._children.keys()) # Recursively vectorize the children of the node - vectorial_subnodes = tuple([ - VectorialParameterNodeAtInstant.build_from_node(node[subnode_name]).vector if isinstance(node[subnode_name], parameters.ParameterNodeAtInstant) else node[subnode_name] - for subnode_name in subnodes_name - ]) + vectorial_subnodes = tuple( + [ + ( + VectorialParameterNodeAtInstant.build_from_node( + node[subnode_name], + ).vector + if isinstance(node[subnode_name], parameters.ParameterNodeAtInstant) + else node[subnode_name] + ) + for subnode_name in subnodes_name + ], + ) # A vectorial node is a wrapper around a numpy recarray # We first build the recarray recarray = numpy.array( [vectorial_subnodes], - dtype = [ - (subnode_name, subnode.dtype if isinstance(subnode, numpy.recarray) else 'float') + dtype=[ + ( + subnode_name, + subnode.dtype if isinstance(subnode, numpy.recarray) else "float", + ) for (subnode_name, subnode) in zip(subnodes_name, vectorial_subnodes) - ] - ) + ], + ) - return VectorialParameterNodeAtInstant(node._name, recarray.view(numpy.recarray), node._instant_str) + return VectorialParameterNodeAtInstant( + node._name, + recarray.view(numpy.recarray), + node._instant_str, + ) @staticmethod - def check_node_vectorisable(node): - """ - Check that a node can be casted to a vectorial node, in order to be able to use fancy indexing. - """ + def check_node_vectorisable(node) -> None: + """Check that a node can be casted to a vectorial node, in order to be able to use fancy indexing.""" MESSAGE_PART_1 = "Cannot use fancy indexing on parameter node '{}', as" - MESSAGE_PART_3 = "To use fancy indexing on parameter node, its children must be homogenous." + MESSAGE_PART_3 = ( + "To use fancy indexing on parameter node, its children must be homogeneous." + ) MESSAGE_PART_4 = "See more at ." - def raise_key_inhomogeneity_error(node_with_key, node_without_key, missing_key): - message = " ".join([ - MESSAGE_PART_1, - "'{}' exists, but '{}' doesn't.", - MESSAGE_PART_3, - MESSAGE_PART_4, - ]).format( + def raise_key_inhomogeneity_error( + node_with_key, node_without_key, missing_key + ) -> NoReturn: + message = f"{MESSAGE_PART_1} '{{}}' exists, but '{{}}' doesn't. {MESSAGE_PART_3} {MESSAGE_PART_4}".format( node._name, - '.'.join([node_with_key, missing_key]), - '.'.join([node_without_key, missing_key]), - ) + f"{node_with_key}.{missing_key}", + f"{node_without_key}.{missing_key}", + ) raise ValueError(message) - def raise_type_inhomogeneity_error(node_name, non_node_name): - message = " ".join([ - MESSAGE_PART_1, - "'{}' is a node, but '{}' is not.", - MESSAGE_PART_3, - MESSAGE_PART_4, - ]).format( + def raise_type_inhomogeneity_error(node_name, non_node_name) -> NoReturn: + message = f"{MESSAGE_PART_1} '{{}}' is a node, but '{{}}' is not. {MESSAGE_PART_3} {MESSAGE_PART_4}".format( node._name, node_name, non_node_name, - ) + ) raise ValueError(message) - def raise_not_implemented(node_name, node_type): - message = " ".join([ - MESSAGE_PART_1, - "'{}' is a '{}', and fancy indexing has not been implemented yet on this kind of parameters.", - MESSAGE_PART_4, - ]).format( + def raise_not_implemented(node_name, node_type) -> NoReturn: + message = f"{MESSAGE_PART_1} '{{}}' is a '{{}}', and fancy indexing has not been implemented yet on this kind of parameters. {MESSAGE_PART_4}".format( node._name, node_name, node_type, - ) + ) raise NotImplementedError(message) def extract_named_children(node): return { - '.'.join([node._name, key]): value - for key, value in node._children.items() - } - - def check_nodes_homogeneous(named_nodes): - """ - Check than several nodes (or parameters, or baremes) have the same structure. - """ + f"{node._name}.{key}": value for key, value in node._children.items() + } + + def check_nodes_homogeneous(named_nodes) -> None: + """Check than several nodes (or parameters, or baremes) have the same structure.""" names = list(named_nodes.keys()) nodes = list(named_nodes.values()) first_node = nodes[0] @@ -103,18 +104,24 @@ def check_nodes_homogeneous(named_nodes): raise_type_inhomogeneity_error(first_name, name) first_node_keys = first_node._children.keys() node_keys = node._children.keys() - if not first_node_keys == node_keys: + if first_node_keys != node_keys: missing_keys = set(first_node_keys).difference(node_keys) if missing_keys: # If the first_node has a key that node hasn't - raise_key_inhomogeneity_error(first_name, name, missing_keys.pop()) + raise_key_inhomogeneity_error( + first_name, + name, + missing_keys.pop(), + ) else: # If If the node has a key that first_node doesn't have - missing_key = set(node_keys).difference(first_node_keys).pop() + missing_key = ( + set(node_keys).difference(first_node_keys).pop() + ) raise_key_inhomogeneity_error(name, first_name, missing_key) children.update(extract_named_children(node)) check_nodes_homogeneous(children) - elif isinstance(first_node, float) or isinstance(first_node, int): + elif isinstance(first_node, (float, int)): for node, name in list(zip(nodes, names))[1:]: - if isinstance(node, int) or isinstance(node, float): + if isinstance(node, (int, float)): pass elif isinstance(node, parameters.ParameterNodeAtInstant): raise_type_inhomogeneity_error(name, first_name) @@ -126,8 +133,7 @@ def check_nodes_homogeneous(named_nodes): check_nodes_homogeneous(extract_named_children(node)) - def __init__(self, name, vector, instant_str): - + def __init__(self, name, vector, instant_str) -> None: self.vector = vector self._name = name self._instant_str = instant_str @@ -143,28 +149,51 @@ def __getitem__(self, key): if isinstance(key, str): return self.__getattr__(key) # If the key is a vector, e.g. ['zone_1', 'zone_2', 'zone_1'] - elif isinstance(key, numpy.ndarray): + if isinstance(key, numpy.ndarray): if not numpy.issubdtype(key.dtype, numpy.str_): # In case the key is not a string vector, stringify it if key.dtype == object and issubclass(type(key[0]), Enum): enum = type(key[0]) - key = numpy.select([key == item for item in enum], [item.name for item in enum]) + key = numpy.select( + [key == item for item in enum], + [item.name for item in enum], + ) elif isinstance(key, EnumArray): enum = key.possible_values - key = numpy.select([key == item.index for item in enum], [item.name for item in enum]) + key = numpy.select( + [key == item.index for item in enum], + [item.name for item in enum], + ) else: - key = key.astype('str') - names = list(self.dtype.names) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2'] - default = numpy.full_like(self.vector[key[0]], numpy.nan) # In case of unexpected key, we will set the corresponding value to NaN. + key = key.astype("str") + names = list( + self.dtype.names, + ) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2'] + default = numpy.full_like( + self.vector[key[0]], + numpy.nan, + ) # In case of unexpected key, we will set the corresponding value to NaN. conditions = [key == name for name in names] values = [self.vector[name] for name in names] result = numpy.select(conditions, values, default) if helpers.contains_nan(result): unexpected_key = set(key).difference(self.vector.dtype.names).pop() - raise ParameterNotFoundError('.'.join([self._name, unexpected_key]), self._instant_str) + msg = f"{self._name}.{unexpected_key}" + raise ParameterNotFoundError( + msg, + self._instant_str, + ) # If the result is not a leaf, wrap the result in a vectorial node. - if numpy.issubdtype(result.dtype, numpy.record): - return VectorialParameterNodeAtInstant(self._name, result.view(numpy.recarray), self._instant_str) + if numpy.issubdtype(result.dtype, numpy.record) or numpy.issubdtype( + result.dtype, + numpy.void, + ): + return VectorialParameterNodeAtInstant( + self._name, + result.view(numpy.recarray), + self._instant_str, + ) return result + return None diff --git a/openfisca_core/periods/__init__.py b/openfisca_core/periods/__init__.py index 4cd9db648c..2335f1792a 100644 --- a/openfisca_core/periods/__init__.py +++ b/openfisca_core/periods/__init__.py @@ -21,26 +21,59 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .config import ( # noqa: F401 - DAY, - MONTH, - YEAR, - ETERNITY, +from . import types +from ._errors import InstantError, ParserError, PeriodError +from .config import ( INSTANT_PATTERN, date_by_instant_cache, str_by_instant_cache, year_or_month_or_day_re, - ) - -from .helpers import ( # noqa: F401 - N_, +) +from .date_unit import DateUnit +from .helpers import ( instant, instant_date, - period, key_period_size, - unit_weights, + period, unit_weight, - ) + unit_weights, +) +from .instant_ import Instant +from .period_ import Period + +WEEKDAY = DateUnit.WEEKDAY +WEEK = DateUnit.WEEK +DAY = DateUnit.DAY +MONTH = DateUnit.MONTH +YEAR = DateUnit.YEAR +ETERNITY = DateUnit.ETERNITY +ISOFORMAT = DateUnit.isoformat +ISOCALENDAR = DateUnit.isocalendar -from .instant_ import Instant # noqa: F401 -from .period_ import Period # noqa: F401 +__all__ = [ + "DAY", + "DateUnit", + "ETERNITY", + "INSTANT_PATTERN", + "ISOCALENDAR", + "ISOFORMAT", + "Instant", + "InstantError", + "MONTH", + "ParserError", + "Period", + "PeriodError", + "WEEK", + "WEEKDAY", + "YEAR", + "date_by_instant_cache", + "instant", + "instant_date", + "key_period_size", + "period", + "str_by_instant_cache", + "types", + "unit_weight", + "unit_weights", + "year_or_month_or_day_re", +] diff --git a/openfisca_core/periods/_errors.py b/openfisca_core/periods/_errors.py new file mode 100644 index 0000000000..733d03ce2a --- /dev/null +++ b/openfisca_core/periods/_errors.py @@ -0,0 +1,28 @@ +from pendulum.parsing.exceptions import ParserError + + +class InstantError(ValueError): + """Raised when an invalid instant-like is provided.""" + + def __init__(self, value: str) -> None: + msg = ( + f"'{value}' is not a valid instant string. Instants are described " + "using either the 'YYYY-MM-DD' format, for instance '2015-06-15', " + "or the 'YYYY-Www-D' format, for instance '2015-W24-1'." + ) + super().__init__(msg) + + +class PeriodError(ValueError): + """Raised when an invalid period-like is provided.""" + + def __init__(self, value: str) -> None: + msg = ( + "Expected a period (eg. '2017', 'month:2017-01', 'week:2017-W01-1:3', " + f"...); got: '{value}'. Learn more about legal period formats in " + "OpenFisca: ." + ) + super().__init__(msg) + + +__all__ = ["InstantError", "ParserError", "PeriodError"] diff --git a/openfisca_core/periods/_parsers.py b/openfisca_core/periods/_parsers.py new file mode 100644 index 0000000000..9973b890a0 --- /dev/null +++ b/openfisca_core/periods/_parsers.py @@ -0,0 +1,121 @@ +"""To parse periods and instants from strings.""" + +from __future__ import annotations + +import datetime + +import pendulum + +from . import types as t +from ._errors import InstantError, ParserError, PeriodError +from .date_unit import DateUnit +from .instant_ import Instant +from .period_ import Period + + +def parse_instant(value: str) -> t.Instant: + """Parse a string into an instant. + + Args: + value (str): The string to parse. + + Returns: + An InstantStr. + + Raises: + InstantError: When the string is not a valid ISO Calendar/Format. + ParserError: When the string couldn't be parsed. + + Examples: + >>> parse_instant("2022") + Instant((2022, 1, 1)) + + >>> parse_instant("2022-02") + Instant((2022, 2, 1)) + + >>> parse_instant("2022-W02-7") + Instant((2022, 1, 16)) + + >>> parse_instant("2022-W013") + Traceback (most recent call last): + openfisca_core.periods._errors.InstantError: '2022-W013' is not a va... + + >>> parse_instant("2022-02-29") + Traceback (most recent call last): + pendulum.parsing.exceptions.ParserError: Unable to parse string [202... + + """ + + if not isinstance(value, t.InstantStr): + raise InstantError(str(value)) + + date = pendulum.parse(value, exact=True) + + if not isinstance(date, datetime.date): + msg = f"Unable to parse string [{value}]" + raise ParserError(msg) + + return Instant((date.year, date.month, date.day)) + + +def parse_period(value: str) -> t.Period: + """Parses ISO format/calendar periods. + + Such as "2012" or "2015-03". + + Examples: + >>> parse_period("2022") + Period((, Instant((2022, 1, 1)), 1)) + + >>> parse_period("2022-02") + Period((, Instant((2022, 2, 1)), 1)) + + >>> parse_period("2022-W02-7") + Period((, Instant((2022, 1, 16)), 1)) + + """ + + try: + instant = parse_instant(value) + + except InstantError as error: + raise PeriodError(value) from error + + unit = parse_unit(value) + + return Period((unit, instant, 1)) + + +def parse_unit(value: str) -> t.DateUnit: + """Determine the date unit of a date string. + + Args: + value (str): The date string to parse. + + Returns: + A DateUnit. + + Raises: + InstantError: when no DateUnit can be determined. + + Examples: + >>> parse_unit("2022") + + + >>> parse_unit("2022-W03-1") + + + """ + + if not isinstance(value, t.InstantStr): + raise InstantError(str(value)) + + length = len(value.split("-")) + + if isinstance(value, t.ISOCalendarStr): + return DateUnit.isocalendar[-length] + + return DateUnit.isoformat[-length] + + +__all__ = ["parse_instant", "parse_period", "parse_unit"] diff --git a/openfisca_core/periods/config.py b/openfisca_core/periods/config.py index 6e0c698098..4486a5caf0 100644 --- a/openfisca_core/periods/config.py +++ b/openfisca_core/periods/config.py @@ -1,15 +1,20 @@ import re -import typing -DAY = 'day' -MONTH = 'month' -YEAR = 'year' -ETERNITY = 'eternity' +import pendulum + +from . import types as t # Matches "2015", "2015-01", "2015-01-01" # Does not match "2015-13", "2015-12-32" -INSTANT_PATTERN = re.compile(r"^\d{4}(-(0[1-9]|1[012]))?(-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01]))?$") +INSTANT_PATTERN = re.compile( + r"^\d{4}(-(0[1-9]|1[012]))?(-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01]))?$", +) + +date_by_instant_cache: dict[t.Instant, pendulum.Date] = {} +str_by_instant_cache: dict[t.Instant, t.InstantStr] = {} +year_or_month_or_day_re = re.compile( + r"(18|19|20)\d{2}(-(0?[1-9]|1[0-2])(-([0-2]?\d|3[0-1]))?)?$", +) + -date_by_instant_cache: typing.Dict = {} -str_by_instant_cache: typing.Dict = {} -year_or_month_or_day_re = re.compile(r'(18|19|20)\d{2}(-(0?[1-9]|1[0-2])(-([0-2]?\d|3[0-1]))?)?$') +__all__ = ["INSTANT_PATTERN", "date_by_instant_cache", "str_by_instant_cache"] diff --git a/openfisca_core/periods/date_unit.py b/openfisca_core/periods/date_unit.py new file mode 100644 index 0000000000..c66346c3c2 --- /dev/null +++ b/openfisca_core/periods/date_unit.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from enum import EnumMeta + +from strenum import StrEnum + +from . import types as t + + +class DateUnitMeta(EnumMeta): + @property + def isoformat(self) -> tuple[t.DateUnit, ...]: + """Creates a :obj:`tuple` of ``key`` with isoformat items. + + Returns: + tuple(str): A :obj:`tuple` containing the ``keys``. + + Examples: + >>> DateUnit.isoformat + (, , >> DateUnit.DAY in DateUnit.isoformat + True + + >>> DateUnit.WEEK in DateUnit.isoformat + False + + """ + return DateUnit.DAY, DateUnit.MONTH, DateUnit.YEAR + + @property + def isocalendar(self) -> tuple[t.DateUnit, ...]: + """Creates a :obj:`tuple` of ``key`` with isocalendar items. + + Returns: + tuple(str): A :obj:`tuple` containing the ``keys``. + + Examples: + >>> DateUnit.isocalendar + (, , >> DateUnit.WEEK in DateUnit.isocalendar + True + + >>> "day" in DateUnit.isocalendar + False + + """ + return DateUnit.WEEKDAY, DateUnit.WEEK, DateUnit.YEAR + + +class DateUnit(StrEnum, metaclass=DateUnitMeta): + """The date units of a rule system. + + Examples: + >>> repr(DateUnit) + "" + + >>> repr(DateUnit.DAY) + "" + + >>> str(DateUnit.DAY) + 'day' + + >>> dict([(DateUnit.DAY, DateUnit.DAY.value)]) + {: 'day'} + + >>> list(DateUnit) + [, , >> len(DateUnit) + 6 + + >>> DateUnit["DAY"] + + + >>> DateUnit(DateUnit.DAY) + + + >>> DateUnit.DAY in DateUnit + True + + >>> "day" in list(DateUnit) + True + + >>> DateUnit.DAY == "day" + True + + >>> DateUnit.DAY.name + 'DAY' + + >>> DateUnit.DAY.value + 'day' + + """ + + def __contains__(self, other: object) -> bool: + if isinstance(other, str): + return super().__contains__(other) + return NotImplemented + + WEEKDAY = "weekday" + WEEK = "week" + DAY = "day" + MONTH = "month" + YEAR = "year" + ETERNITY = "eternity" + + +__all__ = ["DateUnit"] diff --git a/openfisca_core/periods/helpers.py b/openfisca_core/periods/helpers.py index 9ddf794d06..fab26c48ab 100644 --- a/openfisca_core/periods/helpers.py +++ b/openfisca_core/periods/helpers.py @@ -1,203 +1,313 @@ +from __future__ import annotations + +from typing import NoReturn + import datetime -import os +import functools + +import pendulum + +from . import config, types as t +from ._errors import InstantError, PeriodError +from ._parsers import parse_instant, parse_period +from .date_unit import DateUnit +from .instant_ import Instant +from .period_ import Period + + +@functools.singledispatch +def instant(value: object) -> t.Instant: + """Build a new instant, aka a triple of integers (year, month, day). + + Args: + value(object): An ``instant-like`` object. + + Returns: + :obj:`.Instant`: A new instant. + + Raises: + :exc:`ValueError`: When the arguments were invalid, like "2021-32-13". + + Examples: + >>> instant((2021,)) + Instant((2021, 1, 1)) + + >>> instant((2021, 9)) + Instant((2021, 9, 1)) + + >>> instant(datetime.date(2021, 9, 16)) + Instant((2021, 9, 16)) -from openfisca_core import periods -from openfisca_core.periods import config + >>> instant(Instant((2021, 9, 16))) + Instant((2021, 9, 16)) + >>> instant(Period((DateUnit.YEAR, Instant((2021, 9, 16)), 1))) + Instant((2021, 9, 16)) -def N_(message): - return message + >>> instant(2021) + Instant((2021, 1, 1)) + >>> instant("2021") + Instant((2021, 1, 1)) -def instant(instant): - """Return a new instant, aka a triple of integers (year, month, day). + >>> instant([2021]) + Instant((2021, 1, 1)) - >>> instant(2014) - Instant((2014, 1, 1)) - >>> instant('2014') - Instant((2014, 1, 1)) - >>> instant('2014-02') - Instant((2014, 2, 1)) - >>> instant('2014-3-2') - Instant((2014, 3, 2)) - >>> instant(instant('2014-3-2')) - Instant((2014, 3, 2)) - >>> instant(period('month', '2014-3-2')) - Instant((2014, 3, 2)) + >>> instant([2021, 9]) + Instant((2021, 9, 1)) + + >>> instant(None) + Traceback (most recent call last): + openfisca_core.periods._errors.InstantError: 'None' is not a valid i... + + """ + + if isinstance(value, t.SeqInt): + return Instant((list(value) + [1] * 3)[:3]) + + raise InstantError(str(value)) + + +@instant.register +def _(value: None) -> NoReturn: + raise InstantError(str(value)) + + +@instant.register +def _(value: int) -> t.Instant: + return Instant((value, 1, 1)) + + +@instant.register +def _(value: Period) -> t.Instant: + return value.start + + +@instant.register +def _(value: t.Instant) -> t.Instant: + return value + + +@instant.register +def _(value: datetime.date) -> t.Instant: + return Instant((value.year, value.month, value.day)) + + +@instant.register +def _(value: str) -> t.Instant: + return parse_instant(value) + + +def instant_date(instant: None | t.Instant) -> None | datetime.date: + """Returns the date representation of an ``Instant``. + + Args: + instant: An ``Instant``. + + Returns: + None: When ``instant`` is None. + datetime.date: Otherwise. + + Examples: + >>> instant_date(Instant((2021, 1, 1))) + Date(2021, 1, 1) - >>> instant(None) """ if instant is None: return None - if isinstance(instant, periods.Instant): - return instant - if isinstance(instant, str): - if not config.INSTANT_PATTERN.match(instant): - raise ValueError("'{}' is not a valid instant. Instants are described using the 'YYYY-MM-DD' format, for instance '2015-06-15'.".format(instant)) - instant = periods.Instant( - int(fragment) - for fragment in instant.split('-', 2)[:3] - ) - elif isinstance(instant, datetime.date): - instant = periods.Instant((instant.year, instant.month, instant.day)) - elif isinstance(instant, int): - instant = (instant,) - elif isinstance(instant, list): - assert 1 <= len(instant) <= 3 - instant = tuple(instant) - elif isinstance(instant, periods.Period): - instant = instant.start - else: - assert isinstance(instant, tuple), instant - assert 1 <= len(instant) <= 3 - if len(instant) == 1: - return periods.Instant((instant[0], 1, 1)) - if len(instant) == 2: - return periods.Instant((instant[0], instant[1], 1)) - return periods.Instant(instant) - - -def instant_date(instant): - if instant is None: - return None + instant_date = config.date_by_instant_cache.get(instant) + if instant_date is None: - config.date_by_instant_cache[instant] = instant_date = datetime.date(*instant) + config.date_by_instant_cache[instant] = instant_date = pendulum.date(*instant) + return instant_date -def period(value): - """Return a new period, aka a triple (unit, start_instant, size). +@functools.singledispatch +def period(value: object) -> t.Period: + """Build a new period, aka a triple (unit, start_instant, size). + + Args: + value: A ``period-like`` object. + + Returns: + :obj:`.Period`: A period. + + Raises: + :exc:`ValueError`: When the arguments were invalid, like "2021-32-13". + + Examples: + >>> period(Period((DateUnit.YEAR, Instant((2021, 1, 1)), 1))) + Period((, Instant((2021, 1, 1)), 1)) + + >>> period(Instant((2021, 1, 1))) + Period((, Instant((2021, 1, 1)), 1)) + + >>> period(DateUnit.ETERNITY) + Period((, Instant((-1, -1, -1)), -1)) - >>> period('2014') - Period((YEAR, Instant((2014, 1, 1)), 1)) - >>> period('year:2014') - Period((YEAR, Instant((2014, 1, 1)), 1)) + >>> period(2021) + Period((, Instant((2021, 1, 1)), 1)) - >>> period('2014-2') - Period((MONTH, Instant((2014, 2, 1)), 1)) - >>> period('2014-02') - Period((MONTH, Instant((2014, 2, 1)), 1)) - >>> period('month:2014-2') - Period((MONTH, Instant((2014, 2, 1)), 1)) + >>> period("2014") + Period((, Instant((2014, 1, 1)), 1)) + + >>> period("year:2014") + Period((, Instant((2014, 1, 1)), 1)) + + >>> period("month:2014-02") + Period((, Instant((2014, 2, 1)), 1)) + + >>> period("year:2014-02") + Period((, Instant((2014, 2, 1)), 1)) + + >>> period("day:2014-02-02") + Period((, Instant((2014, 2, 2)), 1)) + + >>> period("day:2014-02-02:3") + Period((, Instant((2014, 2, 2)), 3)) - >>> period('year:2014-2') - Period((YEAR, Instant((2014, 2, 1)), 1)) """ - if isinstance(value, periods.Period): - return value - - if isinstance(value, periods.Instant): - return periods.Period((config.DAY, value, 1)) - - def parse_simple_period(value): - """ - Parses simple periods respecting the ISO format, such as 2012 or 2015-03 - """ - try: - date = datetime.datetime.strptime(value, '%Y') - except ValueError: + + one, two, three = 1, 2, 3 + + # We return an "eternity-period", for example + # ``, -1))>``. + if str(value).lower() == DateUnit.ETERNITY: + return Period.eternity() + + # We try to parse from an ISO format/calendar period. + if isinstance(value, t.InstantStr): + return parse_period(value) + + # A complex period has a ':' in its string. + if isinstance(value, t.PeriodStr): + components = value.split(":") + + # The left-most component must be a valid unit + unit = components[0] + + if unit not in list(DateUnit) or unit == DateUnit.ETERNITY: + raise PeriodError(str(value)) + + # Cast ``unit`` to DateUnit. + unit = DateUnit(unit) + + # The middle component must be a valid iso period + period = parse_period(components[1]) + + # Periods like year:2015-03 have a size of 1 + if len(components) == two: + size = one + + # if provided, make sure the size is an integer + elif len(components) == three: try: - date = datetime.datetime.strptime(value, '%Y-%m') - except ValueError: - try: - date = datetime.datetime.strptime(value, '%Y-%m-%d') - except ValueError: - return None - else: - return periods.Period((config.DAY, periods.Instant((date.year, date.month, date.day)), 1)) - else: - return periods.Period((config.MONTH, periods.Instant((date.year, date.month, 1)), 1)) + size = int(components[2]) + + except ValueError as error: + raise PeriodError(str(value)) from error + + # If there are more than 2 ":" in the string, the period is invalid else: - return periods.Period((config.YEAR, periods.Instant((date.year, date.month, 1)), 1)) - - def raise_error(value): - message = os.linesep.join([ - "Expected a period (eg. '2017', '2017-01', '2017-01-01', ...); got: '{}'.".format(value), - "Learn more about legal period formats in OpenFisca:", - "." - ]) - raise ValueError(message) - - if value == 'ETERNITY' or value == config.ETERNITY: - return periods.Period(('eternity', instant(datetime.date.min), float("inf"))) - - # check the type - if isinstance(value, int): - return periods.Period((config.YEAR, periods.Instant((value, 1, 1)), 1)) - if not isinstance(value, str): - raise_error(value) - - # try to parse as a simple period - period = parse_simple_period(value) - if period is not None: - return period - - # complex period must have a ':' in their strings - if ":" not in value: - raise_error(value) - - components = value.split(':') - - # left-most component must be a valid unit - unit = components[0] - if unit not in (config.DAY, config.MONTH, config.YEAR): - raise_error(value) - - # middle component must be a valid iso period - base_period = parse_simple_period(components[1]) - if not base_period: - raise_error(value) - - # period like year:2015-03 have a size of 1 - if len(components) == 2: - size = 1 - # if provided, make sure the size is an integer - elif len(components) == 3: - try: - size = int(components[2]) - except ValueError: - raise_error(value) - # if there is more than 2 ":" in the string, the period is invalid - else: - raise_error(value) - - # reject ambiguous period such as month:2014 - if unit_weight(base_period.unit) > unit_weight(unit): - raise_error(value) - - return periods.Period((unit, base_period.start, size)) - - -def key_period_size(period): - """ - Defines a key in order to sort periods by length. It uses two aspects : first unit then size + raise PeriodError(str(value)) + + # Reject ambiguous periods such as month:2014 + if unit_weight(period.unit) > unit_weight(unit): + raise PeriodError(str(value)) + + return Period((unit, period.start, size)) + + raise PeriodError(str(value)) + - :param period: an OpenFisca period - :return: a string +@period.register +def _(value: None) -> NoReturn: + raise PeriodError(str(value)) - >>> key_period_size(period('2014')) - '2_1' - >>> key_period_size(period('2013')) - '2_1' - >>> key_period_size(period('2014-01')) - '1_1' + +@period.register +def _(value: int) -> t.Period: + return Period((DateUnit.YEAR, instant(value), 1)) + + +@period.register +def _(value: t.Period) -> t.Period: + return value + + +@period.register +def _(value: t.Instant) -> t.Period: + return Period((DateUnit.DAY, value, 1)) + + +@period.register +def _(value: datetime.date) -> t.Period: + return Period((DateUnit.DAY, instant(value), 1)) + + +def key_period_size(period: t.Period) -> str: + """Define a key in order to sort periods by length. + + It uses two aspects: first, ``unit``, then, ``size``. + + Args: + period: An :mod:`.openfisca_core` :obj:`.Period`. + + Returns: + :obj:`str`: A string. + + Examples: + >>> instant = Instant((2021, 9, 14)) + + >>> period = Period((DateUnit.DAY, instant, 1)) + >>> key_period_size(period) + '100_1' + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> key_period_size(period) + '300_3' """ - unit, start, size = period + return f"{unit_weight(period.unit)}_{period.size}" + - return '{}_{}'.format(unit_weight(unit), size) +def unit_weights() -> dict[t.DateUnit, int]: + """Assign weights to date units. + Examples: + >>> unit_weights() + {: 100, ...ETERNITY: 'eternity'>: 400} -def unit_weights(): + """ return { - config.DAY: 100, - config.MONTH: 200, - config.YEAR: 300, - config.ETERNITY: 400, - } + DateUnit.WEEKDAY: 100, + DateUnit.WEEK: 200, + DateUnit.DAY: 100, + DateUnit.MONTH: 200, + DateUnit.YEAR: 300, + DateUnit.ETERNITY: 400, + } -def unit_weight(unit): +def unit_weight(unit: t.DateUnit) -> int: + """Retrieves a specific date unit weight. + + Examples: + >>> unit_weight(DateUnit.DAY) + 100 + + """ return unit_weights()[unit] + + +__all__ = [ + "instant", + "instant_date", + "key_period_size", + "period", + "unit_weight", + "unit_weights", +] diff --git a/openfisca_core/periods/instant_.py b/openfisca_core/periods/instant_.py index c3da65f894..f71dbb3222 100644 --- a/openfisca_core/periods/instant_.py +++ b/openfisca_core/periods/instant_.py @@ -1,249 +1,224 @@ -import calendar -import datetime +from __future__ import annotations -from openfisca_core import periods -from openfisca_core.periods import config +import pendulum +from . import config, types as t +from .date_unit import DateUnit -class Instant(tuple): - def __repr__(self): - """ - Transform instant to to its Python representation as a string. - - >>> repr(instant(2014)) - 'Instant((2014, 1, 1))' - >>> repr(instant('2014-2')) - 'Instant((2014, 2, 1))' - >>> repr(instant('2014-2-3')) - 'Instant((2014, 2, 3))' - """ - return '{}({})'.format(self.__class__.__name__, super(Instant, self).__repr__()) +class Instant(tuple[int, int, int]): + """An instant in time (year, month, day). - def __str__(self): - """ - Transform instant to a string. + An :class:`.Instant` represents the most atomic and indivisible + legislation's date unit. - >>> str(instant(2014)) - '2014-01-01' - >>> str(instant('2014-2')) - '2014-02-01' - >>> str(instant('2014-2-3')) - '2014-02-03' + Current implementation considers this unit to be a day, so + :obj:`instants <.Instant>` can be thought of as "day dates". - """ + Examples: + >>> instant = Instant((2021, 9, 13)) + + >>> repr(Instant) + "" + + >>> repr(instant) + 'Instant((2021, 9, 13))' + + >>> str(instant) + '2021-09-13' + + >>> dict([(instant, (2021, 9, 13))]) + {Instant((2021, 9, 13)): (2021, 9, 13)} + + >>> list(instant) + [2021, 9, 13] + + >>> instant[0] + 2021 + + >>> instant[0] in instant + True + + >>> len(instant) + 3 + + >>> instant == (2021, 9, 13) + True + + >>> instant != (2021, 9, 13) + False + + >>> instant > (2020, 9, 13) + True + + >>> instant < (2020, 9, 13) + False + + >>> instant >= (2020, 9, 13) + True + + >>> instant <= (2020, 9, 13) + False + + >>> instant.year + 2021 + + >>> instant.month + 9 + + >>> instant.day + 13 + + >>> instant.date + Date(2021, 9, 13) + + >>> year, month, day = instant + + """ + + __slots__ = () + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + def __str__(self) -> t.InstantStr: instant_str = config.str_by_instant_cache.get(self) + if instant_str is None: - config.str_by_instant_cache[self] = instant_str = self.date.isoformat() + instant_str = t.InstantStr(self.date.isoformat()) + config.str_by_instant_cache[self] = instant_str + return instant_str + def __lt__(self, other: object) -> bool: + if isinstance(other, Instant): + return super().__lt__(other) + return NotImplemented + + def __le__(self, other: object) -> bool: + if isinstance(other, Instant): + return super().__le__(other) + return NotImplemented + @property - def date(self): - """ - Convert instant to a date. - - >>> instant(2014).date - datetime.date(2014, 1, 1) - >>> instant('2014-2').date - datetime.date(2014, 2, 1) - >>> instant('2014-2-3').date - datetime.date(2014, 2, 3) - """ + def date(self) -> pendulum.Date: instant_date = config.date_by_instant_cache.get(self) + if instant_date is None: - config.date_by_instant_cache[self] = instant_date = datetime.date(*self) + instant_date = pendulum.date(*self) + config.date_by_instant_cache[self] = instant_date + return instant_date @property - def day(self): - """ - Extract day from instant. - - >>> instant(2014).day - 1 - >>> instant('2014-2').day - 1 - >>> instant('2014-2-3').day - 3 - """ + def day(self) -> int: return self[2] @property - def month(self): - """ - Extract month from instant. - - >>> instant(2014).month - 1 - >>> instant('2014-2').month - 2 - >>> instant('2014-2-3').month - 2 - """ + def month(self) -> int: return self[1] - def period(self, unit, size = 1): - """ - Create a new period starting at instant. - - >>> instant(2014).period('month') - Period(('month', Instant((2014, 1, 1)), 1)) - >>> instant('2014-2').period('year', 2) - Period(('year', Instant((2014, 2, 1)), 2)) - >>> instant('2014-2-3').period('day', size = 2) - Period(('day', Instant((2014, 2, 3)), 2)) - """ - assert unit in (config.DAY, config.MONTH, config.YEAR), 'Invalid unit: {} of type {}'.format(unit, type(unit)) - assert isinstance(size, int) and size >= 1, 'Invalid size: {} of type {}'.format(size, type(size)) - return periods.Period((unit, self, size)) - - def offset(self, offset, unit): - """ - Increment (or decrement) the given instant with offset units. - - >>> instant(2014).offset(1, 'day') - Instant((2014, 1, 2)) - >>> instant(2014).offset(1, 'month') - Instant((2014, 2, 1)) - >>> instant(2014).offset(1, 'year') - Instant((2015, 1, 1)) - - >>> instant('2014-1-31').offset(1, 'day') - Instant((2014, 2, 1)) - >>> instant('2014-1-31').offset(1, 'month') - Instant((2014, 2, 28)) - >>> instant('2014-1-31').offset(1, 'year') - Instant((2015, 1, 31)) - - >>> instant('2011-2-28').offset(1, 'day') - Instant((2011, 3, 1)) - >>> instant('2011-2-28').offset(1, 'month') - Instant((2011, 3, 28)) - >>> instant('2012-2-29').offset(1, 'year') - Instant((2013, 2, 28)) - - >>> instant(2014).offset(-1, 'day') - Instant((2013, 12, 31)) - >>> instant(2014).offset(-1, 'month') - Instant((2013, 12, 1)) - >>> instant(2014).offset(-1, 'year') - Instant((2013, 1, 1)) - - >>> instant('2011-3-1').offset(-1, 'day') - Instant((2011, 2, 28)) - >>> instant('2011-3-31').offset(-1, 'month') - Instant((2011, 2, 28)) - >>> instant('2012-2-29').offset(-1, 'year') - Instant((2011, 2, 28)) - - >>> instant('2014-1-30').offset(3, 'day') - Instant((2014, 2, 2)) - >>> instant('2014-10-2').offset(3, 'month') - Instant((2015, 1, 2)) - >>> instant('2014-1-1').offset(3, 'year') - Instant((2017, 1, 1)) - - >>> instant(2014).offset(-3, 'day') - Instant((2013, 12, 29)) - >>> instant(2014).offset(-3, 'month') - Instant((2013, 10, 1)) - >>> instant(2014).offset(-3, 'year') - Instant((2011, 1, 1)) - - >>> instant(2014).offset('first-of', 'month') - Instant((2014, 1, 1)) - >>> instant('2014-2').offset('first-of', 'month') - Instant((2014, 2, 1)) - >>> instant('2014-2-3').offset('first-of', 'month') - Instant((2014, 2, 1)) - - >>> instant(2014).offset('first-of', 'year') - Instant((2014, 1, 1)) - >>> instant('2014-2').offset('first-of', 'year') - Instant((2014, 1, 1)) - >>> instant('2014-2-3').offset('first-of', 'year') - Instant((2014, 1, 1)) - - >>> instant(2014).offset('last-of', 'month') - Instant((2014, 1, 31)) - >>> instant('2014-2').offset('last-of', 'month') - Instant((2014, 2, 28)) - >>> instant('2012-2-3').offset('last-of', 'month') - Instant((2012, 2, 29)) - - >>> instant(2014).offset('last-of', 'year') - Instant((2014, 12, 31)) - >>> instant('2014-2').offset('last-of', 'year') - Instant((2014, 12, 31)) - >>> instant('2014-2-3').offset('last-of', 'year') - Instant((2014, 12, 31)) - """ - year, month, day = self - assert unit in (config.DAY, config.MONTH, config.YEAR), 'Invalid unit: {} of type {}'.format(unit, type(unit)) - if offset == 'first-of': - if unit == config.MONTH: - day = 1 - elif unit == config.YEAR: - month = 1 - day = 1 - elif offset == 'last-of': - if unit == config.MONTH: - day = calendar.monthrange(year, month)[1] - elif unit == config.YEAR: - month = 12 - day = 31 - else: - assert isinstance(offset, int), 'Invalid offset: {} of type {}'.format(offset, type(offset)) - if unit == config.DAY: - day += offset - if offset < 0: - while day < 1: - month -= 1 - if month == 0: - year -= 1 - month = 12 - day += calendar.monthrange(year, month)[1] - elif offset > 0: - month_last_day = calendar.monthrange(year, month)[1] - while day > month_last_day: - month += 1 - if month == 13: - year += 1 - month = 1 - day -= month_last_day - month_last_day = calendar.monthrange(year, month)[1] - elif unit == config.MONTH: - month += offset - if offset < 0: - while month < 1: - year -= 1 - month += 12 - elif offset > 0: - while month > 12: - year += 1 - month -= 12 - month_last_day = calendar.monthrange(year, month)[1] - if day > month_last_day: - day = month_last_day - elif unit == config.YEAR: - year += offset - # Handle february month of leap year. - month_last_day = calendar.monthrange(year, month)[1] - if day > month_last_day: - day = month_last_day - - return self.__class__((year, month, day)) + @property + def year(self) -> int: + return self[0] @property - def year(self): - """ - Extract year from instant. - - >>> instant(2014).year - 2014 - >>> instant('2014-2').year - 2014 - >>> instant('2014-2-3').year - 2014 + def is_eternal(self) -> bool: + return self == self.eternity() + + def offset(self, offset: str | int, unit: t.DateUnit) -> t.Instant | None: + """Increments/decrements the given instant with offset units. + + Args: + offset: How much of ``unit`` to offset. + unit: What to offset + + Returns: + :obj:`.Instant`: A new :obj:`.Instant` in time. + + Raises: + :exc:`AssertionError`: When ``unit`` is not a date unit. + :exc:`AssertionError`: When ``offset`` is not either ``first-of``, + ``last-of``, or any :obj:`int`. + + Examples: + >>> Instant((2020, 12, 31)).offset("first-of", DateUnit.MONTH) + Instant((2020, 12, 1)) + + >>> Instant((2020, 1, 1)).offset("last-of", DateUnit.YEAR) + Instant((2020, 12, 31)) + + >>> Instant((2020, 1, 1)).offset(1, DateUnit.YEAR) + Instant((2021, 1, 1)) + + >>> Instant((2020, 1, 1)).offset(-3, DateUnit.DAY) + Instant((2019, 12, 29)) + """ - return self[0] + year, month, _ = self + + assert unit in ( + DateUnit.isoformat + DateUnit.isocalendar + ), f"Invalid unit: {unit} of type {type(unit)}" + + if offset == "first-of": + if unit == DateUnit.YEAR: + return self.__class__((year, 1, 1)) + + if unit == DateUnit.MONTH: + return self.__class__((year, month, 1)) + + if unit == DateUnit.WEEK: + date = self.date + date = date.start_of("week") + return self.__class__((date.year, date.month, date.day)) + return None + + if offset == "last-of": + if unit == DateUnit.YEAR: + return self.__class__((year, 12, 31)) + + if unit == DateUnit.MONTH: + date = self.date + date = date.end_of("month") + return self.__class__((date.year, date.month, date.day)) + + if unit == DateUnit.WEEK: + date = self.date + date = date.end_of("week") + return self.__class__((date.year, date.month, date.day)) + return None + + assert isinstance( + offset, + int, + ), f"Invalid offset: {offset} of type {type(offset)}" + + if unit == DateUnit.YEAR: + date = self.date + date = date.add(years=offset) + return self.__class__((date.year, date.month, date.day)) + + if unit == DateUnit.MONTH: + date = self.date + date = date.add(months=offset) + return self.__class__((date.year, date.month, date.day)) + + if unit == DateUnit.WEEK: + date = self.date + date = date.add(weeks=offset) + return self.__class__((date.year, date.month, date.day)) + + if unit in (DateUnit.DAY, DateUnit.WEEKDAY): + date = self.date + date = date.add(days=offset) + return self.__class__((date.year, date.month, date.day)) + return None + + @classmethod + def eternity(cls) -> t.Instant: + """Return an eternity instant.""" + return cls((-1, -1, -1)) + + +__all__ = ["Instant"] diff --git a/openfisca_core/periods/period_.py b/openfisca_core/periods/period_.py index 808540f28a..00e833d861 100644 --- a/openfisca_core/periods/period_.py +++ b/openfisca_core/periods/period_.py @@ -1,124 +1,415 @@ from __future__ import annotations +from collections.abc import Sequence + import calendar +import datetime -from openfisca_core import periods -from openfisca_core.periods import config, helpers +import pendulum +from . import helpers, types as t +from .date_unit import DateUnit +from .instant_ import Instant -class Period(tuple): - """ - Toolbox to handle date intervals. - A period is a triple (unit, start, size), where unit is either "month" or "year", where start format is a - (year, month, day) triple, and where size is an integer > 1. +class Period(tuple[t.DateUnit, t.Instant, int]): + """Toolbox to handle date intervals. + + A :class:`.Period` is a triple (``unit``, ``start``, ``size``). + + Attributes: + unit (:obj:`str`): + Either ``year``, ``month``, ``day`` or ``eternity``. + start (:obj:`.Instant`): + The "instant" the :obj:`.Period` starts at. + size (:obj:`int`): + The amount of ``unit``, starting at ``start``, at least ``1``. + + Args: + (tuple(tuple(str, .Instant, int))): + The ``unit``, ``start``, and ``size``, accordingly. + + Examples: + >>> instant = Instant((2021, 10, 1)) + >>> period = Period((DateUnit.YEAR, instant, 3)) + + >>> repr(Period) + "" + + >>> repr(period) + "Period((, Instant((2021, 10, 1)), 3))" + + >>> str(period) + 'year:2021-10:3' + + >>> dict([period, instant]) + Traceback (most recent call last): + ValueError: dictionary update sequence element #0 has length 3... + + >>> list(period) + [, Instant((2021, 10, 1)), 3] + + >>> period[0] + + + >>> period[0] in period + True + + >>> len(period) + 3 + + >>> period == Period((DateUnit.YEAR, instant, 3)) + True + + >>> period != Period((DateUnit.YEAR, instant, 3)) + False + + >>> period > Period((DateUnit.YEAR, instant, 3)) + False + + >>> period < Period((DateUnit.YEAR, instant, 3)) + False + + >>> period >= Period((DateUnit.YEAR, instant, 3)) + True + + >>> period <= Period((DateUnit.YEAR, instant, 3)) + True + + >>> period.days + 1096 + + >>> period.size_in_months + 36 + + >>> period.size_in_days + 1096 + + >>> period.stop + Instant((2024, 9, 30)) + + >>> period.unit + + + >>> period.last_3_months + Period((, Instant((2021, 7, 1)), 3)) + + >>> period.last_month + Period((, Instant((2021, 9, 1)), 1)) + + >>> period.last_year + Period((, Instant((2020, 1, 1)), 1)) + + >>> period.n_2 + Period((, Instant((2019, 1, 1)), 1)) + + >>> period.this_year + Period((, Instant((2021, 1, 1)), 1)) + + >>> period.first_month + Period((, Instant((2021, 10, 1)), 1)) + + >>> period.first_day + Period((, Instant((2021, 10, 1)), 1)) + Since a period is a triple it can be used as a dictionary key. + """ - def __repr__(self): - """ - Transform period to to its Python representation as a string. - - >>> repr(period('year', 2014)) - "Period(('year', Instant((2014, 1, 1)), 1))" - >>> repr(period('month', '2014-2')) - "Period(('month', Instant((2014, 2, 1)), 1))" - >>> repr(period('day', '2014-2-3')) - "Period(('day', Instant((2014, 2, 3)), 1))" - """ - return '{}({})'.format(self.__class__.__name__, super(Period, self).__repr__()) + __slots__ = () - def __str__(self): - """ - Transform period to a string. - - >>> str(period(YEAR, 2014)) - '2014' - - >>> str(period(YEAR, '2014-2')) - 'year:2014-02' - >>> str(period(MONTH, '2014-2')) - '2014-02' - - >>> str(period(YEAR, 2012, size = 2)) - 'year:2012:2' - >>> str(period(MONTH, 2012, size = 2)) - 'month:2012-01:2' - >>> str(period(MONTH, 2012, size = 12)) - '2012' - - >>> str(period(YEAR, '2012-3', size = 2)) - 'year:2012-03:2' - >>> str(period(MONTH, '2012-3', size = 2)) - 'month:2012-03:2' - >>> str(period(MONTH, '2012-3', size = 12)) - 'year:2012-03' - """ + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + def __str__(self) -> t.PeriodStr: unit, start_instant, size = self - if unit == config.ETERNITY: - return 'ETERNITY' - year, month, day = start_instant + + if unit == DateUnit.ETERNITY: + return t.PeriodStr(unit.upper()) + + # ISO format date units. + f_year, month, day = start_instant + + # ISO calendar date units. + c_year, week, weekday = datetime.date(f_year, month, day).isocalendar() # 1 year long period - if (unit == config.MONTH and size == 12 or unit == config.YEAR and size == 1): + if unit == DateUnit.MONTH and size == 12 or unit == DateUnit.YEAR and size == 1: if month == 1: # civil year starting from january - return str(year) - else: - # rolling year - return '{}:{}-{:02d}'.format(config.YEAR, year, month) + return t.PeriodStr(str(f_year)) + # rolling year + return t.PeriodStr(f"{DateUnit.YEAR}:{f_year}-{month:02d}") + # simple month - if unit == config.MONTH and size == 1: - return '{}-{:02d}'.format(year, month) + if unit == DateUnit.MONTH and size == 1: + return t.PeriodStr(f"{f_year}-{month:02d}") + # several civil years - if unit == config.YEAR and month == 1: - return '{}:{}:{}'.format(unit, year, size) + if unit == DateUnit.YEAR and month == 1: + return t.PeriodStr(f"{unit}:{f_year}:{size}") - if unit == config.DAY: + if unit == DateUnit.DAY: if size == 1: - return '{}-{:02d}-{:02d}'.format(year, month, day) - else: - return '{}:{}-{:02d}-{:02d}:{}'.format(unit, year, month, day, size) + return t.PeriodStr(f"{f_year}-{month:02d}-{day:02d}") + return t.PeriodStr(f"{unit}:{f_year}-{month:02d}-{day:02d}:{size}") + + # 1 week + if unit == DateUnit.WEEK and size == 1: + if week < 10: + return t.PeriodStr(f"{c_year}-W0{week}") + + return t.PeriodStr(f"{c_year}-W{week}") + + # several weeks + if unit == DateUnit.WEEK and size > 1: + if week < 10: + return t.PeriodStr(f"{unit}:{c_year}-W0{week}:{size}") + + return t.PeriodStr(f"{unit}:{c_year}-W{week}:{size}") + + # 1 weekday + if unit == DateUnit.WEEKDAY and size == 1: + if week < 10: + return t.PeriodStr(f"{c_year}-W0{week}-{weekday}") + + return t.PeriodStr(f"{c_year}-W{week}-{weekday}") + + # several weekdays + if unit == DateUnit.WEEKDAY and size > 1: + if week < 10: + return t.PeriodStr(f"{unit}:{c_year}-W0{week}-{weekday}:{size}") + + return t.PeriodStr(f"{unit}:{c_year}-W{week}-{weekday}:{size}") # complex period - return '{}:{}-{:02d}:{}'.format(unit, year, month, size) + return t.PeriodStr(f"{unit}:{f_year}-{month:02d}:{size}") @property - def date(self): - assert self.size == 1, '"date" is undefined for a period of size > 1: {}'.format(self) + def unit(self) -> t.DateUnit: + """The ``unit`` of the ``Period``. + + Example: + >>> instant = Instant((2021, 10, 1)) + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.unit + + + """ + return self[0] + + @property + def start(self) -> t.Instant: + """The ``Instant`` at which the ``Period`` starts. + + Example: + >>> instant = Instant((2021, 10, 1)) + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.start + Instant((2021, 10, 1)) + + """ + return self[1] + + @property + def size(self) -> int: + """The ``size`` of the ``Period``. + + Example: + >>> instant = Instant((2021, 10, 1)) + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.size + 3 + + """ + return self[2] + + @property + def date(self) -> pendulum.Date: + """The date representation of the ``Period`` start date. + + Examples: + >>> instant = Instant((2021, 10, 1)) + + >>> period = Period((DateUnit.YEAR, instant, 1)) + >>> period.date + Date(2021, 10, 1) + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.date + Traceback (most recent call last): + ValueError: "date" is undefined for a period of size > 1: year:2021-10:3. + + """ + if self.size != 1: + msg = f'"date" is undefined for a period of size > 1: {self}.' + raise ValueError(msg) + return self.start.date @property - def days(self): + def size_in_years(self) -> int: + """The ``size`` of the ``Period`` in years. + + Examples: + >>> instant = Instant((2021, 10, 1)) + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.size_in_years + 3 + + >>> period = Period((DateUnit.MONTH, instant, 3)) + >>> period.size_in_years + Traceback (most recent call last): + ValueError: Can't calculate number of years in a month. + """ - Count the number of days in period. - - >>> period('day', 2014).days - 365 - >>> period('month', 2014).days - 365 - >>> period('year', 2014).days - 365 - - >>> period('day', '2014-2').days - 28 - >>> period('month', '2014-2').days - 28 - >>> period('year', '2014-2').days - 365 - - >>> period('day', '2014-2-3').days - 1 - >>> period('month', '2014-2-3').days - 28 - >>> period('year', '2014-2-3').days - 365 + if self.unit == DateUnit.YEAR: + return self.size + + msg = f"Can't calculate number of years in a {self.unit}." + raise ValueError(msg) + + @property + def size_in_months(self) -> int: + """The ``size`` of the ``Period`` in months. + + Examples: + >>> instant = Instant((2021, 10, 1)) + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.size_in_months + 36 + + >>> period = Period((DateUnit.DAY, instant, 3)) + >>> period.size_in_months + Traceback (most recent call last): + ValueError: Can't calculate number of months in a day. + + """ + if self.unit == DateUnit.YEAR: + return self.size * 12 + + if self.unit == DateUnit.MONTH: + return self.size + + msg = f"Can't calculate number of months in a {self.unit}." + raise ValueError(msg) + + @property + def size_in_days(self) -> int: + """The ``size`` of the ``Period`` in days. + + Examples: + >>> instant = Instant((2019, 10, 1)) + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.size_in_days + 1096 + + >>> period = Period((DateUnit.MONTH, instant, 3)) + >>> period.size_in_days + 92 + """ + if self.unit in (DateUnit.YEAR, DateUnit.MONTH): + last = self.start.offset(self.size, self.unit) + if last is None: + raise NotImplementedError + last_day = last.offset(-1, DateUnit.DAY) + if last_day is None: + raise NotImplementedError + return (last_day.date - self.start.date).days + 1 + + if self.unit == DateUnit.WEEK: + return self.size * 7 + + if self.unit in (DateUnit.DAY, DateUnit.WEEKDAY): + return self.size + + msg = f"Can't calculate number of days in a {self.unit}." + raise ValueError(msg) + + @property + def size_in_weeks(self) -> int: + """The ``size`` of the ``Period`` in weeks. + + Examples: + >>> instant = Instant((2019, 10, 1)) + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.size_in_weeks + 156 + + >>> period = Period((DateUnit.YEAR, instant, 5)) + >>> period.size_in_weeks + 261 + + """ + if self.unit == DateUnit.YEAR: + start = self.start.date + cease = start.add(years=self.size) + delta = start.diff(cease) + return delta.in_weeks() + + if self.unit == DateUnit.MONTH: + start = self.start.date + cease = start.add(months=self.size) + delta = start.diff(cease) + return delta.in_weeks() + + if self.unit == DateUnit.WEEK: + return self.size + + msg = f"Can't calculate number of weeks in a {self.unit}." + raise ValueError(msg) + + @property + def size_in_weekdays(self) -> int: + """The ``size`` of the ``Period`` in weekdays. + + Examples: + >>> instant = Instant((2019, 10, 1)) + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.size_in_weekdays + 1092 + + >>> period = Period((DateUnit.WEEK, instant, 3)) + >>> period.size_in_weekdays + 21 + + """ + if self.unit == DateUnit.YEAR: + return self.size_in_weeks * 7 + + if DateUnit.MONTH in self.unit: + last = self.start.offset(self.size, self.unit) + if last is None: + raise NotImplementedError + last_day = last.offset(-1, DateUnit.DAY) + if last_day is None: + raise NotImplementedError + return (last_day.date - self.start.date).days + 1 + + if self.unit == DateUnit.WEEK: + return self.size * 7 + + if self.unit in (DateUnit.DAY, DateUnit.WEEKDAY): + return self.size + + msg = f"Can't calculate number of weekdays in a {self.unit}." + raise ValueError(msg) + + @property + def days(self) -> int: + """Same as ``size_in_days``.""" return (self.stop.date - self.start.date).days + 1 - def intersection(self, start, stop): + def intersection( + self, start: t.Instant | None, stop: t.Instant | None + ) -> t.Period | None: if start is None and stop is None: return self period_start = self[1] @@ -133,351 +424,495 @@ def intersection(self, start, stop): intersection_stop = min(period_stop, stop) if intersection_start == period_start and intersection_stop == period_stop: return self - if intersection_start.day == 1 and intersection_start.month == 1 \ - and intersection_stop.day == 31 and intersection_stop.month == 12: - return self.__class__(( - 'year', - intersection_start, - intersection_stop.year - intersection_start.year + 1, - )) - if intersection_start.day == 1 and intersection_stop.day == calendar.monthrange(intersection_stop.year, - intersection_stop.month)[1]: - return self.__class__(( - 'month', - intersection_start, + if ( + intersection_start.day == 1 + and intersection_start.month == 1 + and intersection_stop.day == 31 + and intersection_stop.month == 12 + ): + return self.__class__( + ( + DateUnit.YEAR, + intersection_start, + intersection_stop.year - intersection_start.year + 1, + ), + ) + if ( + intersection_start.day == 1 + and intersection_stop.day + == calendar.monthrange(intersection_stop.year, intersection_stop.month)[1] + ): + return self.__class__( ( - (intersection_stop.year - intersection_start.year) * 12 - + intersection_stop.month - - intersection_start.month - + 1 + DateUnit.MONTH, + intersection_start, + ( + (intersection_stop.year - intersection_start.year) * 12 + + intersection_stop.month + - intersection_start.month + + 1 ), - )) - return self.__class__(( - 'day', - intersection_start, - (intersection_stop.date - intersection_start.date).days + 1, - )) - - def get_subperiods(self, unit): - """ - Return the list of all the periods of unit ``unit`` contained in self. + ), + ) + return self.__class__( + ( + DateUnit.DAY, + intersection_start, + (intersection_stop.date - intersection_start.date).days + 1, + ), + ) + + def get_subperiods(self, unit: t.DateUnit) -> Sequence[t.Period]: + """Return the list of periods of unit ``unit`` contained in self. Examples: + >>> period = Period((DateUnit.YEAR, Instant((2021, 1, 1)), 1)) + >>> period.get_subperiods(DateUnit.MONTH) + [Period((, Instant((2021, 1, 1)), 1)),...] - >>> period('2017').get_subperiods(MONTH) - >>> [period('2017-01'), period('2017-02'), ... period('2017-12')] + >>> period = Period((DateUnit.YEAR, Instant((2021, 1, 1)), 2)) + >>> period.get_subperiods(DateUnit.YEAR) + [Period((, Instant((2021, 1, 1)), 1)), P...] - >>> period('year:2014:2').get_subperiods(YEAR) - >>> [period('2014'), period('2015')] """ if helpers.unit_weight(self.unit) < helpers.unit_weight(unit): - raise ValueError('Cannot subdivide {0} into {1}'.format(self.unit, unit)) + msg = f"Cannot subdivide {self.unit} into {unit}" + raise ValueError(msg) - if unit == config.YEAR: - return [self.this_year.offset(i, config.YEAR) for i in range(self.size)] + if unit == DateUnit.YEAR: + return [self.this_year.offset(i, DateUnit.YEAR) for i in range(self.size)] - if unit == config.MONTH: - return [self.first_month.offset(i, config.MONTH) for i in range(self.size_in_months)] + if unit == DateUnit.MONTH: + return [ + self.first_month.offset(i, DateUnit.MONTH) + for i in range(self.size_in_months) + ] - if unit == config.DAY: - return [self.first_day.offset(i, config.DAY) for i in range(self.size_in_days)] + if unit == DateUnit.DAY: + return [ + self.first_day.offset(i, DateUnit.DAY) for i in range(self.size_in_days) + ] - def offset(self, offset, unit = None): - """ - Increment (or decrement) the given period with offset units. - - >>> period('day', 2014).offset(1) - Period(('day', Instant((2014, 1, 2)), 365)) - >>> period('day', 2014).offset(1, 'day') - Period(('day', Instant((2014, 1, 2)), 365)) - >>> period('day', 2014).offset(1, 'month') - Period(('day', Instant((2014, 2, 1)), 365)) - >>> period('day', 2014).offset(1, 'year') - Period(('day', Instant((2015, 1, 1)), 365)) - - >>> period('month', 2014).offset(1) - Period(('month', Instant((2014, 2, 1)), 12)) - >>> period('month', 2014).offset(1, 'day') - Period(('month', Instant((2014, 1, 2)), 12)) - >>> period('month', 2014).offset(1, 'month') - Period(('month', Instant((2014, 2, 1)), 12)) - >>> period('month', 2014).offset(1, 'year') - Period(('month', Instant((2015, 1, 1)), 12)) - - >>> period('year', 2014).offset(1) - Period(('year', Instant((2015, 1, 1)), 1)) - >>> period('year', 2014).offset(1, 'day') - Period(('year', Instant((2014, 1, 2)), 1)) - >>> period('year', 2014).offset(1, 'month') - Period(('year', Instant((2014, 2, 1)), 1)) - >>> period('year', 2014).offset(1, 'year') - Period(('year', Instant((2015, 1, 1)), 1)) - - >>> period('day', '2011-2-28').offset(1) - Period(('day', Instant((2011, 3, 1)), 1)) - >>> period('month', '2011-2-28').offset(1) - Period(('month', Instant((2011, 3, 28)), 1)) - >>> period('year', '2011-2-28').offset(1) - Period(('year', Instant((2012, 2, 28)), 1)) - - >>> period('day', '2011-3-1').offset(-1) - Period(('day', Instant((2011, 2, 28)), 1)) - >>> period('month', '2011-3-1').offset(-1) - Period(('month', Instant((2011, 2, 1)), 1)) - >>> period('year', '2011-3-1').offset(-1) - Period(('year', Instant((2010, 3, 1)), 1)) - - >>> period('day', '2014-1-30').offset(3) - Period(('day', Instant((2014, 2, 2)), 1)) - >>> period('month', '2014-1-30').offset(3) - Period(('month', Instant((2014, 4, 30)), 1)) - >>> period('year', '2014-1-30').offset(3) - Period(('year', Instant((2017, 1, 30)), 1)) - - >>> period('day', 2014).offset(-3) - Period(('day', Instant((2013, 12, 29)), 365)) - >>> period('month', 2014).offset(-3) - Period(('month', Instant((2013, 10, 1)), 12)) - >>> period('year', 2014).offset(-3) - Period(('year', Instant((2011, 1, 1)), 1)) - - >>> period('day', '2014-2-3').offset('first-of', 'month') - Period(('day', Instant((2014, 2, 1)), 1)) - >>> period('day', '2014-2-3').offset('first-of', 'year') - Period(('day', Instant((2014, 1, 1)), 1)) - - >>> period('day', '2014-2-3', 4).offset('first-of', 'month') - Period(('day', Instant((2014, 2, 1)), 4)) - >>> period('day', '2014-2-3', 4).offset('first-of', 'year') - Period(('day', Instant((2014, 1, 1)), 4)) - - >>> period('month', '2014-2-3').offset('first-of') - Period(('month', Instant((2014, 2, 1)), 1)) - >>> period('month', '2014-2-3').offset('first-of', 'month') - Period(('month', Instant((2014, 2, 1)), 1)) - >>> period('month', '2014-2-3').offset('first-of', 'year') - Period(('month', Instant((2014, 1, 1)), 1)) - - >>> period('month', '2014-2-3', 4).offset('first-of') - Period(('month', Instant((2014, 2, 1)), 4)) - >>> period('month', '2014-2-3', 4).offset('first-of', 'month') - Period(('month', Instant((2014, 2, 1)), 4)) - >>> period('month', '2014-2-3', 4).offset('first-of', 'year') - Period(('month', Instant((2014, 1, 1)), 4)) - - >>> period('year', 2014).offset('first-of') - Period(('year', Instant((2014, 1, 1)), 1)) - >>> period('year', 2014).offset('first-of', 'month') - Period(('year', Instant((2014, 1, 1)), 1)) - >>> period('year', 2014).offset('first-of', 'year') - Period(('year', Instant((2014, 1, 1)), 1)) - - >>> period('year', '2014-2-3').offset('first-of') - Period(('year', Instant((2014, 1, 1)), 1)) - >>> period('year', '2014-2-3').offset('first-of', 'month') - Period(('year', Instant((2014, 2, 1)), 1)) - >>> period('year', '2014-2-3').offset('first-of', 'year') - Period(('year', Instant((2014, 1, 1)), 1)) - - >>> period('day', '2014-2-3').offset('last-of', 'month') - Period(('day', Instant((2014, 2, 28)), 1)) - >>> period('day', '2014-2-3').offset('last-of', 'year') - Period(('day', Instant((2014, 12, 31)), 1)) - - >>> period('day', '2014-2-3', 4).offset('last-of', 'month') - Period(('day', Instant((2014, 2, 28)), 4)) - >>> period('day', '2014-2-3', 4).offset('last-of', 'year') - Period(('day', Instant((2014, 12, 31)), 4)) - - >>> period('month', '2014-2-3').offset('last-of') - Period(('month', Instant((2014, 2, 28)), 1)) - >>> period('month', '2014-2-3').offset('last-of', 'month') - Period(('month', Instant((2014, 2, 28)), 1)) - >>> period('month', '2014-2-3').offset('last-of', 'year') - Period(('month', Instant((2014, 12, 31)), 1)) - - >>> period('month', '2014-2-3', 4).offset('last-of') - Period(('month', Instant((2014, 2, 28)), 4)) - >>> period('month', '2014-2-3', 4).offset('last-of', 'month') - Period(('month', Instant((2014, 2, 28)), 4)) - >>> period('month', '2014-2-3', 4).offset('last-of', 'year') - Period(('month', Instant((2014, 12, 31)), 4)) - - >>> period('year', 2014).offset('last-of') - Period(('year', Instant((2014, 12, 31)), 1)) - >>> period('year', 2014).offset('last-of', 'month') - Period(('year', Instant((2014, 1, 31)), 1)) - >>> period('year', 2014).offset('last-of', 'year') - Period(('year', Instant((2014, 12, 31)), 1)) - - >>> period('year', '2014-2-3').offset('last-of') - Period(('year', Instant((2014, 12, 31)), 1)) - >>> period('year', '2014-2-3').offset('last-of', 'month') - Period(('year', Instant((2014, 2, 28)), 1)) - >>> period('year', '2014-2-3').offset('last-of', 'year') - Period(('year', Instant((2014, 12, 31)), 1)) - """ - return self.__class__((self[0], self[1].offset(offset, self[0] if unit is None else unit), self[2])) + if unit == DateUnit.WEEK: + return [ + self.first_week.offset(i, DateUnit.WEEK) + for i in range(self.size_in_weeks) + ] + + if unit == DateUnit.WEEKDAY: + return [ + self.first_weekday.offset(i, DateUnit.WEEKDAY) + for i in range(self.size_in_weekdays) + ] + + msg = f"Cannot subdivide {self.unit} into {unit}" + raise ValueError(msg) + + def offset(self, offset: str | int, unit: t.DateUnit | None = None) -> t.Period: + """Increment (or decrement) the given period with offset units. + + Examples: + >>> Period((DateUnit.DAY, Instant((2021, 1, 1)), 365)).offset(1) + Period((, Instant((2021, 1, 2)), 365)) + + >>> Period((DateUnit.DAY, Instant((2021, 1, 1)), 365)).offset( + ... 1, DateUnit.DAY + ... ) + Period((, Instant((2021, 1, 2)), 365)) + + >>> Period((DateUnit.DAY, Instant((2021, 1, 1)), 365)).offset( + ... 1, DateUnit.MONTH + ... ) + Period((, Instant((2021, 2, 1)), 365)) + + >>> Period((DateUnit.DAY, Instant((2021, 1, 1)), 365)).offset( + ... 1, DateUnit.YEAR + ... ) + Period((, Instant((2022, 1, 1)), 365)) + + >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset(1) + Period((, Instant((2021, 2, 1)), 12)) + + >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset( + ... 1, DateUnit.DAY + ... ) + Period((, Instant((2021, 1, 2)), 12)) + + >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset( + ... 1, DateUnit.MONTH + ... ) + Period((, Instant((2021, 2, 1)), 12)) + + >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset( + ... 1, DateUnit.YEAR + ... ) + Period((, Instant((2022, 1, 1)), 12)) + + >>> Period((DateUnit.YEAR, Instant((2021, 1, 1)), 1)).offset(1) + Period((, Instant((2022, 1, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2021, 1, 1)), 1)).offset( + ... 1, DateUnit.DAY + ... ) + Period((, Instant((2021, 1, 2)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2021, 1, 1)), 1)).offset( + ... 1, DateUnit.MONTH + ... ) + Period((, Instant((2021, 2, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2021, 1, 1)), 1)).offset( + ... 1, DateUnit.YEAR + ... ) + Period((, Instant((2022, 1, 1)), 1)) + + >>> Period((DateUnit.DAY, Instant((2011, 2, 28)), 1)).offset(1) + Period((, Instant((2011, 3, 1)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2011, 2, 28)), 1)).offset(1) + Period((, Instant((2011, 3, 28)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2011, 2, 28)), 1)).offset(1) + Period((, Instant((2012, 2, 28)), 1)) + + >>> Period((DateUnit.DAY, Instant((2011, 3, 1)), 1)).offset(-1) + Period((, Instant((2011, 2, 28)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2011, 3, 1)), 1)).offset(-1) + Period((, Instant((2011, 2, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2011, 3, 1)), 1)).offset(-1) + Period((, Instant((2010, 3, 1)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 1, 30)), 1)).offset(3) + Period((, Instant((2014, 2, 2)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 1, 30)), 1)).offset(3) + Period((, Instant((2014, 4, 30)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 1, 30)), 1)).offset(3) + Period((, Instant((2017, 1, 30)), 1)) + + >>> Period((DateUnit.DAY, Instant((2021, 1, 1)), 365)).offset(-3) + Period((, Instant((2020, 12, 29)), 365)) + + >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset(-3) + Period((, Instant((2020, 10, 1)), 12)) + + >>> Period((DateUnit.YEAR, Instant((2014, 1, 1)), 1)).offset(-3) + Period((, Instant((2011, 1, 1)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 1)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset( + ... "first-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 1)), 4)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset( + ... "first-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 1, 1)), 4)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset("first-of") + Period((, Instant((2014, 2, 1)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 1)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset("first-of") + Period((, Instant((2014, 2, 1)), 4)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset( + ... "first-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 1)), 4)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset( + ... "first-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 1, 1)), 4)) + + >>> Period((DateUnit.YEAR, Instant((2014, 1, 30)), 1)).offset("first-of") + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 1, 30)), 1)).offset( + ... "first-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 1, 30)), 1)).offset( + ... "first-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("first-of") + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 28)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 12, 31)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset( + ... "last-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 28)), 4)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset( + ... "last-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 12, 31)), 4)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset("last-of") + Period((, Instant((2014, 2, 28)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 28)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 12, 31)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset("last-of") + Period((, Instant((2014, 2, 28)), 4)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset( + ... "last-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 28)), 4)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset( + ... "last-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 12, 31)), 4)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("last-of") + Period((, Instant((2014, 12, 31)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 1, 1)), 1)).offset( + ... "last-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 1, 31)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 12, 31)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("last-of") + Period((, Instant((2014, 12, 31)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 28)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 12, 31)), 1)) - def contains(self, other: Period) -> bool: """ - Returns ``True`` if the period contains ``other``. For instance, ``period(2015)`` contains ``period(2015-01)`` + + start: None | t.Instant = self[1].offset( + offset, self[0] if unit is None else unit + ) + + if start is None: + raise NotImplementedError + + return self.__class__( + ( + self[0], + start, + self[2], + ), + ) + + def contains(self, other: t.Period) -> bool: + """Returns ``True`` if the period contains ``other``. + + For instance, ``period(2015)`` contains ``period(2015-01)``. + """ return self.start <= other.start and self.stop >= other.stop @property - def size(self): - """ - Return the size of the period. + def stop(self) -> t.Instant: + """Return the last day of the period as an Instant instance. - >>> period('month', '2012-2-29', 4).size - 4 - """ - return self[2] + Examples: + >>> Period((DateUnit.YEAR, Instant((2022, 1, 1)), 1)).stop + Instant((2022, 12, 31)) - @property - def size_in_months(self): - """ - Return the size of the period in months. + >>> Period((DateUnit.MONTH, Instant((2022, 1, 1)), 12)).stop + Instant((2022, 12, 31)) - >>> period('month', '2012-2-29', 4).size_in_months - 4 - >>> period('year', '2012', 1).size_in_months - 12 - """ - if (self[0] == config.MONTH): - return self[2] - if(self[0] == config.YEAR): - return self[2] * 12 - raise ValueError("Cannot calculate number of months in {0}".format(self[0])) + >>> Period((DateUnit.DAY, Instant((2022, 1, 1)), 365)).stop + Instant((2022, 12, 31)) - @property - def size_in_days(self): - """ - Return the size of the period in days. + >>> Period((DateUnit.YEAR, Instant((2012, 2, 29)), 1)).stop + Instant((2013, 2, 27)) - >>> period('month', '2012-2-29', 4).size_in_days - 28 - >>> period('year', '2012', 1).size_in_days - 366 - """ - unit, instant, length = self + >>> Period((DateUnit.MONTH, Instant((2012, 2, 29)), 1)).stop + Instant((2012, 3, 28)) - if unit == config.DAY: - return length - if unit in [config.MONTH, config.YEAR]: - last_day = self.start.offset(length, unit).offset(-1, config.DAY) - return (last_day.date - self.start.date).days + 1 + >>> Period((DateUnit.DAY, Instant((2012, 2, 29)), 1)).stop + Instant((2012, 2, 29)) - raise ValueError("Cannot calculate number of days in {0}".format(unit)) + >>> Period((DateUnit.YEAR, Instant((2012, 2, 29)), 2)).stop + Instant((2014, 2, 27)) - @property - def start(self) -> periods.Instant: - """ - Return the first day of the period as an Instant instance. + >>> Period((DateUnit.MONTH, Instant((2012, 2, 29)), 2)).stop + Instant((2012, 4, 28)) - >>> period('month', '2012-2-29', 4).start - Instant((2012, 2, 29)) - """ - return self[1] + >>> Period((DateUnit.DAY, Instant((2012, 2, 29)), 2)).stop + Instant((2012, 3, 1)) - @property - def stop(self) -> periods.Instant: - """ - Return the last day of the period as an Instant instance. - - >>> period('year', 2014).stop - Instant((2014, 12, 31)) - >>> period('month', 2014).stop - Instant((2014, 12, 31)) - >>> period('day', 2014).stop - Instant((2014, 12, 31)) - - >>> period('year', '2012-2-29').stop - Instant((2013, 2, 28)) - >>> period('month', '2012-2-29').stop - Instant((2012, 3, 28)) - >>> period('day', '2012-2-29').stop - Instant((2012, 2, 29)) - - >>> period('year', '2012-2-29', 2).stop - Instant((2014, 2, 28)) - >>> period('month', '2012-2-29', 2).stop - Instant((2012, 4, 28)) - >>> period('day', '2012-2-29', 2).stop - Instant((2012, 3, 1)) """ unit, start_instant, size = self - year, month, day = start_instant - if unit == config.ETERNITY: - return periods.Instant((float("inf"), float("inf"), float("inf"))) - if unit == 'day': - if size > 1: - day += size - 1 - month_last_day = calendar.monthrange(year, month)[1] - while day > month_last_day: - month += 1 - if month == 13: - year += 1 - month = 1 - day -= month_last_day - month_last_day = calendar.monthrange(year, month)[1] - else: - if unit == 'month': - month += size - while month > 12: - year += 1 - month -= 12 - else: - assert unit == 'year', 'Invalid unit: {} of type {}'.format(unit, type(unit)) - year += size - day -= 1 - if day < 1: - month -= 1 - if month == 0: - year -= 1 - month = 12 - day += calendar.monthrange(year, month)[1] - else: - month_last_day = calendar.monthrange(year, month)[1] - if day > month_last_day: - month += 1 - if month == 13: - year += 1 - month = 1 - day -= month_last_day - return periods.Instant((year, month, day)) + + if unit == DateUnit.ETERNITY: + return Instant.eternity() + + if unit == DateUnit.YEAR: + date = start_instant.date.add(years=size, days=-1) + return Instant((date.year, date.month, date.day)) + + if unit == DateUnit.MONTH: + date = start_instant.date.add(months=size, days=-1) + return Instant((date.year, date.month, date.day)) + + if unit == DateUnit.WEEK: + date = start_instant.date.add(weeks=size, days=-1) + return Instant((date.year, date.month, date.day)) + + if unit in (DateUnit.DAY, DateUnit.WEEKDAY): + date = start_instant.date.add(days=size - 1) + return Instant((date.year, date.month, date.day)) + + raise ValueError @property - def unit(self): - return self[0] + def is_eternal(self) -> bool: + return self == self.eternity() # Reference periods @property - def last_3_months(self): - return self.first_month.start.period('month', 3).offset(-3) + def last_week(self) -> t.Period: + return self.first_week.offset(-1) @property - def last_month(self): + def last_fortnight(self) -> t.Period: + start: t.Instant = self.first_week.start + return self.__class__((DateUnit.WEEK, start, 1)).offset(-2) + + @property + def last_2_weeks(self) -> t.Period: + start: t.Instant = self.first_week.start + return self.__class__((DateUnit.WEEK, start, 2)).offset(-2) + + @property + def last_26_weeks(self) -> t.Period: + start: t.Instant = self.first_week.start + return self.__class__((DateUnit.WEEK, start, 26)).offset(-26) + + @property + def last_52_weeks(self) -> t.Period: + start: t.Instant = self.first_week.start + return self.__class__((DateUnit.WEEK, start, 52)).offset(-52) + + @property + def last_month(self) -> t.Period: return self.first_month.offset(-1) @property - def last_year(self): - return self.start.offset('first-of', 'year').period('year').offset(-1) + def last_3_months(self) -> t.Period: + start: t.Instant = self.first_month.start + return self.__class__((DateUnit.MONTH, start, 3)).offset(-3) + + @property + def last_year(self) -> t.Period: + start: None | t.Instant = self.start.offset("first-of", DateUnit.YEAR) + if start is None: + raise NotImplementedError + return self.__class__((DateUnit.YEAR, start, 1)).offset(-1) @property - def n_2(self): - return self.start.offset('first-of', 'year').period('year').offset(-2) + def n_2(self) -> t.Period: + start: None | t.Instant = self.start.offset("first-of", DateUnit.YEAR) + if start is None: + raise NotImplementedError + return self.__class__((DateUnit.YEAR, start, 1)).offset(-2) @property - def this_year(self): - return self.start.offset('first-of', 'year').period('year') + def this_year(self) -> t.Period: + start: None | t.Instant = self.start.offset("first-of", DateUnit.YEAR) + if start is None: + raise NotImplementedError + return self.__class__((DateUnit.YEAR, start, 1)) + + @property + def first_month(self) -> t.Period: + start: None | t.Instant = self.start.offset("first-of", DateUnit.MONTH) + if start is None: + raise NotImplementedError + return self.__class__((DateUnit.MONTH, start, 1)) @property - def first_month(self): - return self.start.offset('first-of', 'month').period('month') + def first_day(self) -> t.Period: + return self.__class__((DateUnit.DAY, self.start, 1)) @property - def first_day(self): - return self.start.period('day') + def first_week(self) -> t.Period: + start: None | t.Instant = self.start.offset("first-of", DateUnit.WEEK) + if start is None: + raise NotImplementedError + return self.__class__((DateUnit.WEEK, start, 1)) + + @property + def first_weekday(self) -> t.Period: + return self.__class__((DateUnit.WEEKDAY, self.start, 1)) + + @classmethod + def eternity(cls) -> t.Period: + """Return an eternity period.""" + return cls((DateUnit.ETERNITY, Instant.eternity(), -1)) + + +__all__ = ["Period"] diff --git a/openfisca_core/periods/py.typed b/openfisca_core/periods/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfisca_core/periods/tests/__init__.py b/openfisca_core/periods/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfisca_core/periods/tests/helpers/__init__.py b/openfisca_core/periods/tests/helpers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfisca_core/periods/tests/helpers/test_helpers.py b/openfisca_core/periods/tests/helpers/test_helpers.py new file mode 100644 index 0000000000..175ea8c873 --- /dev/null +++ b/openfisca_core/periods/tests/helpers/test_helpers.py @@ -0,0 +1,65 @@ +import datetime + +import pytest + +from openfisca_core import periods +from openfisca_core.periods import DateUnit, Instant, Period + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + (None, None), + (Instant((1, 1, 1)), datetime.date(1, 1, 1)), + (Instant((4, 2, 29)), datetime.date(4, 2, 29)), + ((1, 1, 1), datetime.date(1, 1, 1)), + ], +) +def test_instant_date(arg, expected) -> None: + assert periods.instant_date(arg) == expected + + +@pytest.mark.parametrize( + ("arg", "error"), + [ + (Instant((-1, 1, 1)), ValueError), + (Instant((1, -1, 1)), ValueError), + (Instant((1, 13, -1)), ValueError), + (Instant((1, 1, -1)), ValueError), + (Instant((1, 1, 32)), ValueError), + (Instant((1, 2, 29)), ValueError), + (Instant(("1", 1, 1)), TypeError), + ((1,), TypeError), + ((1, 1), TypeError), + ], +) +def test_instant_date_with_an_invalid_argument(arg, error) -> None: + with pytest.raises(error): + periods.instant_date(arg) + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + (Period((DateUnit.WEEKDAY, Instant((1, 1, 1)), 5)), "100_5"), + (Period((DateUnit.WEEK, Instant((1, 1, 1)), 26)), "200_26"), + (Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), "100_365"), + (Period((DateUnit.MONTH, Instant((1, 1, 1)), 12)), "200_12"), + (Period((DateUnit.YEAR, Instant((1, 1, 1)), 2)), "300_2"), + (Period((DateUnit.ETERNITY, Instant((1, 1, 1)), 1)), "400_1"), + ], +) +def test_key_period_size(arg, expected) -> None: + assert periods.key_period_size(arg) == expected + + +@pytest.mark.parametrize( + ("arg", "error"), + [ + ((DateUnit.DAY, None, 1), AttributeError), + ((DateUnit.MONTH, None, -1000), AttributeError), + ], +) +def test_key_period_size_when_an_invalid_argument(arg, error): + with pytest.raises(error): + periods.key_period_size(arg) diff --git a/openfisca_core/periods/tests/helpers/test_instant.py b/openfisca_core/periods/tests/helpers/test_instant.py new file mode 100644 index 0000000000..fb4472814b --- /dev/null +++ b/openfisca_core/periods/tests/helpers/test_instant.py @@ -0,0 +1,73 @@ +import datetime + +import pytest + +from openfisca_core import periods +from openfisca_core.periods import DateUnit, Instant, InstantError, Period + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + (datetime.date(1, 1, 1), Instant((1, 1, 1))), + (Instant((1, 1, 1)), Instant((1, 1, 1))), + (Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), Instant((1, 1, 1))), + (-1, Instant((-1, 1, 1))), + (0, Instant((0, 1, 1))), + (1, Instant((1, 1, 1))), + (999, Instant((999, 1, 1))), + (1000, Instant((1000, 1, 1))), + ("1000", Instant((1000, 1, 1))), + ("1000-01", Instant((1000, 1, 1))), + ("1000-01-01", Instant((1000, 1, 1))), + ((-1,), Instant((-1, 1, 1))), + ((-1, -1), Instant((-1, -1, 1))), + ((-1, -1, -1), Instant((-1, -1, -1))), + ], +) +def test_instant(arg, expected) -> None: + assert periods.instant(arg) == expected + + +@pytest.mark.parametrize( + ("arg", "error"), + [ + (None, InstantError), + (DateUnit.YEAR, ValueError), + (DateUnit.ETERNITY, ValueError), + ("1000-0", ValueError), + ("1000-0-0", ValueError), + ("1000-1", ValueError), + ("1000-1-1", ValueError), + ("1", ValueError), + ("a", ValueError), + ("year", ValueError), + ("eternity", ValueError), + ("999", ValueError), + ("1:1000-01-01", ValueError), + ("a:1000-01-01", ValueError), + ("year:1000-01-01", ValueError), + ("year:1000-01-01:1", ValueError), + ("year:1000-01-01:3", ValueError), + ("1000-01-01:a", ValueError), + ("1000-01-01:1", ValueError), + ((), InstantError), + ({}, InstantError), + ("", InstantError), + ((None,), InstantError), + ((None, None), InstantError), + ((None, None, None), InstantError), + ((None, None, None, None), InstantError), + (("-1",), InstantError), + (("-1", "-1"), InstantError), + (("-1", "-1", "-1"), InstantError), + (("1-1",), InstantError), + (("1-1-1",), InstantError), + ((datetime.date(1, 1, 1),), InstantError), + ((Instant((1, 1, 1)),), InstantError), + ((Period((DateUnit.DAY, Instant((1, 1, 1)), 365)),), InstantError), + ], +) +def test_instant_with_an_invalid_argument(arg, error) -> None: + with pytest.raises(error): + periods.instant(arg) diff --git a/openfisca_core/periods/tests/helpers/test_period.py b/openfisca_core/periods/tests/helpers/test_period.py new file mode 100644 index 0000000000..d2d5c6679a --- /dev/null +++ b/openfisca_core/periods/tests/helpers/test_period.py @@ -0,0 +1,134 @@ +import datetime + +import pytest + +from openfisca_core import periods +from openfisca_core.periods import DateUnit, Instant, Period, PeriodError + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + ("eternity", Period((DateUnit.ETERNITY, Instant((-1, -1, -1)), -1))), + ("ETERNITY", Period((DateUnit.ETERNITY, Instant((-1, -1, -1)), -1))), + ( + DateUnit.ETERNITY, + Period((DateUnit.ETERNITY, Instant((-1, -1, -1)), -1)), + ), + (datetime.date(1, 1, 1), Period((DateUnit.DAY, Instant((1, 1, 1)), 1))), + (Instant((1, 1, 1)), Period((DateUnit.DAY, Instant((1, 1, 1)), 1))), + ( + Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), + Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), + ), + (-1, Period((DateUnit.YEAR, Instant((-1, 1, 1)), 1))), + (0, Period((DateUnit.YEAR, Instant((0, 1, 1)), 1))), + (1, Period((DateUnit.YEAR, Instant((1, 1, 1)), 1))), + (999, Period((DateUnit.YEAR, Instant((999, 1, 1)), 1))), + (1000, Period((DateUnit.YEAR, Instant((1000, 1, 1)), 1))), + ("1001", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("1001-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("1001-01-01", Period((DateUnit.DAY, Instant((1001, 1, 1)), 1))), + ("1004-02-29", Period((DateUnit.DAY, Instant((1004, 2, 29)), 1))), + ("1001-W01", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("1001-W01-1", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 1))), + ("year:1001", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-01", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-01-01", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-W01", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))), + ("year:1001-W01-1", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))), + ("year:1001:1", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-01:1", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-01-01:1", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-W01:1", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))), + ("year:1001-W01-1:1", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))), + ("year:1001:3", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 3))), + ("year:1001-01:3", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 3))), + ("year:1001-01-01:3", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 3))), + ("year:1001-W01:3", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 3))), + ("year:1001-W01-1:3", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 3))), + ("month:1001-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("month:1001-01-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("week:1001-W01", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("week:1001-W01-1", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("month:1001-01:1", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("month:1001-01:3", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 3))), + ("month:1001-01-01:3", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 3))), + ("week:1001-W01:1", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("week:1001-W01:3", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 3))), + ("week:1001-W01-1:3", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 3))), + ("day:1001-01-01", Period((DateUnit.DAY, Instant((1001, 1, 1)), 1))), + ("day:1001-01-01:3", Period((DateUnit.DAY, Instant((1001, 1, 1)), 3))), + ("weekday:1001-W01-1", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 1))), + ( + "weekday:1001-W01-1:3", + Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 3)), + ), + ], +) +def test_period(arg, expected) -> None: + assert periods.period(arg) == expected + + +@pytest.mark.parametrize( + ("arg", "error"), + [ + (None, PeriodError), + (DateUnit.YEAR, PeriodError), + ("1", PeriodError), + ("999", PeriodError), + ("1000-0", PeriodError), + ("1000-13", PeriodError), + ("1000-W0", PeriodError), + ("1000-W54", PeriodError), + ("1000-0-0", PeriodError), + ("1000-1-0", PeriodError), + ("1000-2-31", PeriodError), + ("1000-W0-0", PeriodError), + ("1000-W1-0", PeriodError), + ("1000-W1-8", PeriodError), + ("a", PeriodError), + ("year", PeriodError), + ("1:1000", PeriodError), + ("a:1000", PeriodError), + ("month:1000", PeriodError), + ("week:1000", PeriodError), + ("day:1000-01", PeriodError), + ("weekday:1000-W1", PeriodError), + ("1000:a", PeriodError), + ("1000:1", PeriodError), + ("1000-01:1", PeriodError), + ("1000-01-01:1", PeriodError), + ("1000-W1:1", PeriodError), + ("1000-W1-1:1", PeriodError), + ("month:1000:1", PeriodError), + ("week:1000:1", PeriodError), + ("day:1000:1", PeriodError), + ("day:1000-01:1", PeriodError), + ("weekday:1000:1", PeriodError), + ("weekday:1000-W1:1", PeriodError), + ((), PeriodError), + ({}, PeriodError), + ("", PeriodError), + ((None,), PeriodError), + ((None, None), PeriodError), + ((None, None, None), PeriodError), + ((None, None, None, None), PeriodError), + ((Instant((1, 1, 1)),), PeriodError), + ((Period((DateUnit.DAY, Instant((1, 1, 1)), 365)),), PeriodError), + ((1,), PeriodError), + ((1, 1), PeriodError), + ((1, 1, 1), PeriodError), + ((-1,), PeriodError), + ((-1, -1), PeriodError), + ((-1, -1, -1), PeriodError), + (("-1",), PeriodError), + (("-1", "-1"), PeriodError), + (("-1", "-1", "-1"), PeriodError), + (("1-1",), PeriodError), + (("1-1-1",), PeriodError), + ], +) +def test_period_with_an_invalid_argument(arg, error) -> None: + with pytest.raises(error): + periods.period(arg) diff --git a/openfisca_core/periods/tests/test_instant.py b/openfisca_core/periods/tests/test_instant.py new file mode 100644 index 0000000000..e9c73ef6aa --- /dev/null +++ b/openfisca_core/periods/tests/test_instant.py @@ -0,0 +1,32 @@ +import pytest + +from openfisca_core.periods import DateUnit, Instant + + +@pytest.mark.parametrize( + ("instant", "offset", "unit", "expected"), + [ + (Instant((2020, 2, 29)), "first-of", DateUnit.YEAR, Instant((2020, 1, 1))), + (Instant((2020, 2, 29)), "first-of", DateUnit.MONTH, Instant((2020, 2, 1))), + (Instant((2020, 2, 29)), "first-of", DateUnit.WEEK, Instant((2020, 2, 24))), + (Instant((2020, 2, 29)), "first-of", DateUnit.DAY, None), + (Instant((2020, 2, 29)), "first-of", DateUnit.WEEKDAY, None), + (Instant((2020, 2, 29)), "last-of", DateUnit.YEAR, Instant((2020, 12, 31))), + (Instant((2020, 2, 29)), "last-of", DateUnit.MONTH, Instant((2020, 2, 29))), + (Instant((2020, 2, 29)), "last-of", DateUnit.WEEK, Instant((2020, 3, 1))), + (Instant((2020, 2, 29)), "last-of", DateUnit.DAY, None), + (Instant((2020, 2, 29)), "last-of", DateUnit.WEEKDAY, None), + (Instant((2020, 2, 29)), -3, DateUnit.YEAR, Instant((2017, 2, 28))), + (Instant((2020, 2, 29)), -3, DateUnit.MONTH, Instant((2019, 11, 29))), + (Instant((2020, 2, 29)), -3, DateUnit.WEEK, Instant((2020, 2, 8))), + (Instant((2020, 2, 29)), -3, DateUnit.DAY, Instant((2020, 2, 26))), + (Instant((2020, 2, 29)), -3, DateUnit.WEEKDAY, Instant((2020, 2, 26))), + (Instant((2020, 2, 29)), 3, DateUnit.YEAR, Instant((2023, 2, 28))), + (Instant((2020, 2, 29)), 3, DateUnit.MONTH, Instant((2020, 5, 29))), + (Instant((2020, 2, 29)), 3, DateUnit.WEEK, Instant((2020, 3, 21))), + (Instant((2020, 2, 29)), 3, DateUnit.DAY, Instant((2020, 3, 3))), + (Instant((2020, 2, 29)), 3, DateUnit.WEEKDAY, Instant((2020, 3, 3))), + ], +) +def test_offset(instant, offset, unit, expected) -> None: + assert instant.offset(offset, unit) == expected diff --git a/openfisca_core/periods/tests/test_parsers.py b/openfisca_core/periods/tests/test_parsers.py new file mode 100644 index 0000000000..c9131414b2 --- /dev/null +++ b/openfisca_core/periods/tests/test_parsers.py @@ -0,0 +1,129 @@ +import pytest + +from openfisca_core.periods import ( + DateUnit, + Instant, + InstantError, + ParserError, + Period, + PeriodError, + _parsers, +) + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + ("1001", Instant((1001, 1, 1))), + ("1001-01", Instant((1001, 1, 1))), + ("1001-12", Instant((1001, 12, 1))), + ("1001-01-01", Instant((1001, 1, 1))), + ("2028-02-29", Instant((2028, 2, 29))), + ("1001-W01", Instant((1000, 12, 29))), + ("1001-W52", Instant((1001, 12, 21))), + ("1001-W01-1", Instant((1000, 12, 29))), + ], +) +def test_parse_instant(arg, expected) -> None: + assert _parsers.parse_instant(arg) == expected + + +@pytest.mark.parametrize( + ("arg", "error"), + [ + (None, InstantError), + ({}, InstantError), + ((), InstantError), + ([], InstantError), + (1, InstantError), + ("", InstantError), + ("à", InstantError), + ("1", InstantError), + ("-1", InstantError), + ("999", InstantError), + ("1000-0", InstantError), + ("1000-1", ParserError), + ("1000-1-1", InstantError), + ("1000-00", InstantError), + ("1000-13", InstantError), + ("1000-01-00", InstantError), + ("1000-01-99", InstantError), + ("2029-02-29", ParserError), + ("1000-W0", InstantError), + ("1000-W1", InstantError), + ("1000-W99", InstantError), + ("1000-W1-0", InstantError), + ("1000-W1-1", InstantError), + ("1000-W1-99", InstantError), + ("1000-W01-0", InstantError), + ("1000-W01-00", InstantError), + ], +) +def test_parse_instant_with_invalid_argument(arg, error) -> None: + with pytest.raises(error): + _parsers.parse_instant(arg) + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + ("1001", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("1001-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("1001-12", Period((DateUnit.MONTH, Instant((1001, 12, 1)), 1))), + ("1001-01-01", Period((DateUnit.DAY, Instant((1001, 1, 1)), 1))), + ("1001-W01", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("1001-W52", Period((DateUnit.WEEK, Instant((1001, 12, 21)), 1))), + ("1001-W01-1", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 1))), + ], +) +def test_parse_period(arg, expected) -> None: + assert _parsers.parse_period(arg) == expected + + +@pytest.mark.parametrize( + ("arg", "error"), + [ + (None, PeriodError), + ({}, PeriodError), + ((), PeriodError), + ([], PeriodError), + (1, PeriodError), + ("", PeriodError), + ("à", PeriodError), + ("1", PeriodError), + ("-1", PeriodError), + ("999", PeriodError), + ("1000-0", PeriodError), + ("1000-1", ParserError), + ("1000-1-1", PeriodError), + ("1000-00", PeriodError), + ("1000-13", PeriodError), + ("1000-01-00", PeriodError), + ("1000-01-99", PeriodError), + ("1000-W0", PeriodError), + ("1000-W1", PeriodError), + ("1000-W99", PeriodError), + ("1000-W1-0", PeriodError), + ("1000-W1-1", PeriodError), + ("1000-W1-99", PeriodError), + ("1000-W01-0", PeriodError), + ("1000-W01-00", PeriodError), + ], +) +def test_parse_period_with_invalid_argument(arg, error) -> None: + with pytest.raises(error): + _parsers.parse_period(arg) + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + ("2022", DateUnit.YEAR), + ("2022-01", DateUnit.MONTH), + ("2022-01-01", DateUnit.DAY), + ("2022-W01", DateUnit.WEEK), + ("2022-W01-1", DateUnit.WEEKDAY), + ], +) +def test_parse_unit(arg, expected) -> None: + assert _parsers.parse_unit(arg) == expected diff --git a/openfisca_core/periods/tests/test_period.py b/openfisca_core/periods/tests/test_period.py new file mode 100644 index 0000000000..9e53bf7d12 --- /dev/null +++ b/openfisca_core/periods/tests/test_period.py @@ -0,0 +1,283 @@ +import pytest + +from openfisca_core.periods import DateUnit, Instant, Period + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.YEAR, Instant((2022, 1, 1)), 1, "2022"), + (DateUnit.MONTH, Instant((2022, 1, 1)), 12, "2022"), + (DateUnit.YEAR, Instant((2022, 3, 1)), 1, "year:2022-03"), + (DateUnit.MONTH, Instant((2022, 3, 1)), 12, "year:2022-03"), + (DateUnit.YEAR, Instant((2022, 1, 1)), 3, "year:2022:3"), + (DateUnit.YEAR, Instant((2022, 1, 3)), 3, "year:2022:3"), + ], +) +def test_str_with_years(date_unit, instant, size, expected) -> None: + assert str(Period((date_unit, instant, size))) == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.MONTH, Instant((2022, 1, 1)), 1, "2022-01"), + (DateUnit.MONTH, Instant((2022, 1, 1)), 3, "month:2022-01:3"), + (DateUnit.MONTH, Instant((2022, 3, 1)), 3, "month:2022-03:3"), + ], +) +def test_str_with_months(date_unit, instant, size, expected) -> None: + assert str(Period((date_unit, instant, size))) == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.DAY, Instant((2022, 1, 1)), 1, "2022-01-01"), + (DateUnit.DAY, Instant((2022, 1, 1)), 3, "day:2022-01-01:3"), + (DateUnit.DAY, Instant((2022, 3, 1)), 3, "day:2022-03-01:3"), + ], +) +def test_str_with_days(date_unit, instant, size, expected) -> None: + assert str(Period((date_unit, instant, size))) == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.WEEK, Instant((2022, 1, 1)), 1, "2021-W52"), + (DateUnit.WEEK, Instant((2022, 1, 1)), 3, "week:2021-W52:3"), + (DateUnit.WEEK, Instant((2022, 3, 1)), 1, "2022-W09"), + (DateUnit.WEEK, Instant((2022, 3, 1)), 3, "week:2022-W09:3"), + ], +) +def test_str_with_weeks(date_unit, instant, size, expected) -> None: + assert str(Period((date_unit, instant, size))) == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.WEEKDAY, Instant((2022, 1, 1)), 1, "2021-W52-6"), + (DateUnit.WEEKDAY, Instant((2022, 1, 1)), 3, "weekday:2021-W52-6:3"), + (DateUnit.WEEKDAY, Instant((2022, 3, 1)), 1, "2022-W09-2"), + (DateUnit.WEEKDAY, Instant((2022, 3, 1)), 3, "weekday:2022-W09-2:3"), + ], +) +def test_str_with_weekdays(date_unit, instant, size, expected) -> None: + assert str(Period((date_unit, instant, size))) == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.YEAR, Instant((2022, 12, 1)), 1, 1), + (DateUnit.YEAR, Instant((2022, 1, 1)), 2, 2), + ], +) +def test_size_in_years(date_unit, instant, size, expected) -> None: + period = Period((date_unit, instant, size)) + assert period.size_in_years == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.YEAR, Instant((2020, 1, 1)), 1, 12), + (DateUnit.YEAR, Instant((2022, 1, 1)), 2, 24), + (DateUnit.MONTH, Instant((2012, 1, 3)), 3, 3), + ], +) +def test_size_in_months(date_unit, instant, size, expected) -> None: + period = Period((date_unit, instant, size)) + assert period.size_in_months == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.YEAR, Instant((2022, 12, 1)), 1, 365), + (DateUnit.YEAR, Instant((2020, 1, 1)), 1, 366), + (DateUnit.YEAR, Instant((2022, 1, 1)), 2, 730), + (DateUnit.MONTH, Instant((2022, 12, 1)), 1, 31), + (DateUnit.MONTH, Instant((2020, 2, 3)), 1, 29), + (DateUnit.MONTH, Instant((2022, 1, 3)), 3, 31 + 28 + 31), + (DateUnit.MONTH, Instant((2012, 1, 3)), 3, 31 + 29 + 31), + (DateUnit.DAY, Instant((2022, 12, 31)), 1, 1), + (DateUnit.DAY, Instant((2022, 12, 31)), 3, 3), + (DateUnit.WEEK, Instant((2022, 12, 31)), 1, 7), + (DateUnit.WEEK, Instant((2022, 12, 31)), 3, 21), + (DateUnit.WEEKDAY, Instant((2022, 12, 31)), 1, 1), + (DateUnit.WEEKDAY, Instant((2022, 12, 31)), 3, 3), + ], +) +def test_size_in_days(date_unit, instant, size, expected) -> None: + period = Period((date_unit, instant, size)) + assert period.size_in_days == expected + assert period.size_in_days == period.days + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.YEAR, Instant((2022, 12, 1)), 1, 52), + (DateUnit.YEAR, Instant((2020, 1, 1)), 5, 261), + (DateUnit.MONTH, Instant((2022, 12, 1)), 1, 4), + (DateUnit.MONTH, Instant((2020, 2, 3)), 1, 4), + (DateUnit.MONTH, Instant((2022, 1, 3)), 3, 12), + (DateUnit.MONTH, Instant((2012, 1, 3)), 3, 13), + (DateUnit.WEEK, Instant((2022, 12, 31)), 1, 1), + (DateUnit.WEEK, Instant((2022, 12, 31)), 3, 3), + ], +) +def test_size_in_weeks(date_unit, instant, size, expected) -> None: + period = Period((date_unit, instant, size)) + assert period.size_in_weeks == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.YEAR, Instant((2022, 12, 1)), 1, 364), + (DateUnit.YEAR, Instant((2020, 1, 1)), 1, 364), + (DateUnit.YEAR, Instant((2022, 1, 1)), 2, 728), + (DateUnit.MONTH, Instant((2022, 12, 1)), 1, 31), + (DateUnit.MONTH, Instant((2020, 2, 3)), 1, 29), + (DateUnit.MONTH, Instant((2022, 1, 3)), 3, 31 + 28 + 31), + (DateUnit.MONTH, Instant((2012, 1, 3)), 3, 31 + 29 + 31), + (DateUnit.DAY, Instant((2022, 12, 31)), 1, 1), + (DateUnit.DAY, Instant((2022, 12, 31)), 3, 3), + (DateUnit.WEEK, Instant((2022, 12, 31)), 1, 7), + (DateUnit.WEEK, Instant((2022, 12, 31)), 3, 21), + (DateUnit.WEEKDAY, Instant((2022, 12, 31)), 1, 1), + (DateUnit.WEEKDAY, Instant((2022, 12, 31)), 3, 3), + ], +) +def test_size_in_weekdays(date_unit, instant, size, expected) -> None: + period = Period((date_unit, instant, size)) + assert period.size_in_weekdays == expected + + +@pytest.mark.parametrize( + ("period_unit", "sub_unit", "instant", "start", "cease", "count"), + [ + ( + DateUnit.YEAR, + DateUnit.YEAR, + Instant((2022, 12, 31)), + Instant((2022, 1, 1)), + Instant((2024, 1, 1)), + 3, + ), + ( + DateUnit.YEAR, + DateUnit.MONTH, + Instant((2022, 12, 31)), + Instant((2022, 12, 1)), + Instant((2025, 11, 1)), + 36, + ), + ( + DateUnit.YEAR, + DateUnit.DAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2025, 12, 30)), + 1096, + ), + ( + DateUnit.YEAR, + DateUnit.WEEK, + Instant((2022, 12, 31)), + Instant((2022, 12, 26)), + Instant((2025, 12, 15)), + 156, + ), + ( + DateUnit.YEAR, + DateUnit.WEEKDAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2025, 12, 26)), + 1092, + ), + ( + DateUnit.MONTH, + DateUnit.MONTH, + Instant((2022, 12, 31)), + Instant((2022, 12, 1)), + Instant((2023, 2, 1)), + 3, + ), + ( + DateUnit.MONTH, + DateUnit.DAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 3, 30)), + 90, + ), + ( + DateUnit.DAY, + DateUnit.DAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 1, 2)), + 3, + ), + ( + DateUnit.DAY, + DateUnit.WEEKDAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 1, 2)), + 3, + ), + ( + DateUnit.WEEK, + DateUnit.DAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 1, 20)), + 21, + ), + ( + DateUnit.WEEK, + DateUnit.WEEK, + Instant((2022, 12, 31)), + Instant((2022, 12, 26)), + Instant((2023, 1, 9)), + 3, + ), + ( + DateUnit.WEEK, + DateUnit.WEEKDAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 1, 20)), + 21, + ), + ( + DateUnit.WEEKDAY, + DateUnit.DAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 1, 2)), + 3, + ), + ( + DateUnit.WEEKDAY, + DateUnit.WEEKDAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 1, 2)), + 3, + ), + ], +) +def test_subperiods(period_unit, sub_unit, instant, start, cease, count) -> None: + period = Period((period_unit, instant, 3)) + subperiods = period.get_subperiods(sub_unit) + assert len(subperiods) == count + assert subperiods[0] == Period((sub_unit, start, 1)) + assert subperiods[-1] == Period((sub_unit, cease, 1)) diff --git a/openfisca_core/periods/types.py b/openfisca_core/periods/types.py new file mode 100644 index 0000000000..092509c621 --- /dev/null +++ b/openfisca_core/periods/types.py @@ -0,0 +1,183 @@ +# TODO(): Properly resolve metaclass types. +# https://github.com/python/mypy/issues/14033 + +from collections.abc import Sequence + +from openfisca_core.types import DateUnit, Instant, Period + +import re + +#: Matches "2015", "2015-01", "2015-01-01" but not "2015-13", "2015-12-32". +iso_format = re.compile(r"^\d{4}(-(?:0[1-9]|1[0-2])(-(?:0[1-9]|[12]\d|3[01]))?)?$") + +#: Matches "2015", "2015-W01", "2015-W53-1" but not "2015-W54", "2015-W10-8". +iso_calendar = re.compile(r"^\d{4}(-W(0[1-9]|[1-4][0-9]|5[0-3]))?(-[1-7])?$") + + +class _SeqIntMeta(type): + def __instancecheck__(self, arg: object) -> bool: + return ( + bool(arg) + and isinstance(arg, Sequence) + and all(isinstance(item, int) for item in arg) + ) + + +class SeqInt(list[int], metaclass=_SeqIntMeta): # type: ignore[misc] + """A sequence of integers. + + Examples: + >>> isinstance([1, 2, 3], SeqInt) + True + + >>> isinstance((1, 2, 3), SeqInt) + True + + >>> isinstance({1, 2, 3}, SeqInt) + False + + >>> isinstance([1, 2, "3"], SeqInt) + False + + >>> isinstance(1, SeqInt) + False + + >>> isinstance([], SeqInt) + False + + """ + + +class _InstantStrMeta(type): + def __instancecheck__(self, arg: object) -> bool: + return isinstance(arg, (ISOFormatStr, ISOCalendarStr)) + + +class InstantStr(str, metaclass=_InstantStrMeta): # type: ignore[misc] + """A string representing an instant in string format. + + Examples: + >>> isinstance("2015", InstantStr) + True + + >>> isinstance("2015-01", InstantStr) + True + + >>> isinstance("2015-W01", InstantStr) + True + + >>> isinstance("2015-W01-12", InstantStr) + False + + >>> isinstance("week:2015-W01:3", InstantStr) + False + + """ + + __slots__ = () + + +class _ISOFormatStrMeta(type): + def __instancecheck__(self, arg: object) -> bool: + return isinstance(arg, str) and bool(iso_format.match(arg)) + + +class ISOFormatStr(str, metaclass=_ISOFormatStrMeta): # type: ignore[misc] + """A string representing an instant in ISO format. + + Examples: + >>> isinstance("2015", ISOFormatStr) + True + + >>> isinstance("2015-01", ISOFormatStr) + True + + >>> isinstance("2015-01-01", ISOFormatStr) + True + + >>> isinstance("2015-13", ISOFormatStr) + False + + >>> isinstance("2015-W01", ISOFormatStr) + False + + """ + + __slots__ = () + + +class _ISOCalendarStrMeta(type): + def __instancecheck__(self, arg: object) -> bool: + return isinstance(arg, str) and bool(iso_calendar.match(arg)) + + +class ISOCalendarStr(str, metaclass=_ISOCalendarStrMeta): # type: ignore[misc] + """A string representing an instant in ISO calendar. + + Examples: + >>> isinstance("2015", ISOCalendarStr) + True + + >>> isinstance("2015-W01", ISOCalendarStr) + True + + >>> isinstance("2015-W11-7", ISOCalendarStr) + True + + >>> isinstance("2015-W010", ISOCalendarStr) + False + + >>> isinstance("2015-01", ISOCalendarStr) + False + + """ + + __slots__ = () + + +class _PeriodStrMeta(type): + def __instancecheck__(self, arg: object) -> bool: + return ( + isinstance(arg, str) + and ":" in arg + and isinstance(arg.split(":")[1], InstantStr) + ) + + +class PeriodStr(str, metaclass=_PeriodStrMeta): # type: ignore[misc] + """A string representing a period. + + Examples: + >>> isinstance("year", PeriodStr) + False + + >>> isinstance("2015", PeriodStr) + False + + >>> isinstance("year:2015", PeriodStr) + True + + >>> isinstance("month:2015-01", PeriodStr) + True + + >>> isinstance("weekday:2015-W01-1:365", PeriodStr) + True + + >>> isinstance("2015-W01:1", PeriodStr) + False + + """ + + __slots__ = () + + +__all__ = [ + "DateUnit", + "ISOCalendarStr", + "ISOFormatStr", + "Instant", + "InstantStr", + "Period", + "PeriodStr", + "SeqInt", +] diff --git a/openfisca_core/populations/__init__.py b/openfisca_core/populations/__init__.py index 7dedd71dc6..36f000e38d 100644 --- a/openfisca_core/populations/__init__.py +++ b/openfisca_core/populations/__init__.py @@ -21,18 +21,44 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from openfisca_core.projectors import ( # noqa: F401 - Projector, +from openfisca_core.projectors import ( EntityToPersonProjector, FirstPersonToEntityProjector, + Projector, UniqueRoleToEntityProjector, - ) +) +from openfisca_core.projectors.helpers import get_projector_from_shortcut, projectable + +from . import types +from ._core_population import CorePopulation +from ._errors import ( + IncompatibleOptionsError, + InvalidArraySizeError, + InvalidOptionError, + PeriodValidityError, +) +from .group_population import GroupPopulation +from .population import Population -from openfisca_core.projectors.helpers import ( # noqa: F401 - projectable, - get_projector_from_shortcut, - ) +ADD, DIVIDE = types.Option +SinglePopulation = Population -from .config import ADD, DIVIDE # noqa: F401 -from .population import Population # noqa: F401 -from .group_population import GroupPopulation # noqa: F401 +__all__ = [ + "ADD", + "DIVIDE", + "CorePopulation", + "EntityToPersonProjector", + "FirstPersonToEntityProjector", + "GroupPopulation", + "IncompatibleOptionsError", + "InvalidArraySizeError", + "InvalidOptionError", + "PeriodValidityError", + "Population", + "Projector", + "SinglePopulation", + "UniqueRoleToEntityProjector", + "get_projector_from_shortcut", + "projectable", + "types", +] diff --git a/openfisca_core/populations/_core_population.py b/openfisca_core/populations/_core_population.py new file mode 100644 index 0000000000..0041a6927a --- /dev/null +++ b/openfisca_core/populations/_core_population.py @@ -0,0 +1,455 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import TypeVar + +import traceback + +import numpy + +from openfisca_core import holders, periods + +from . import types as t +from ._errors import ( + IncompatibleOptionsError, + InvalidArraySizeError, + InvalidOptionError, + PeriodValidityError, +) + +#: Type variable for a covariant data type. +_DT_co = TypeVar("_DT_co", covariant=True, bound=t.VarDType) + + +class CorePopulation: + """Base class to build populations from. + + Args: + entity: The :class:`.CoreEntity` of the population. + *__args: Variable length argument list. + **__kwds: Arbitrary keyword arguments. + + """ + + #: The number :class:`.CoreEntity` members in the population. + count: int = 0 + + #: The :class:`.CoreEntity` of the population. + entity: t.CoreEntity + + #: A pseudo index for the members of the population. + ids: Sequence[str] = [] + + #: The :class:`.Simulation` for which the population is calculated. + simulation: None | t.Simulation = None + + def __init__(self, entity: t.CoreEntity, *__args: object, **__kwds: object) -> None: + self.entity = entity + self._holders: t.HolderByVariable = {} + + def __call__( + self, + variable_name: t.VariableName, + period: t.PeriodLike, + options: None | Sequence[t.Option] = None, + ) -> None | t.VarArray: + """Calculate ``variable_name`` for ``period``, using the formula if it exists. + + Args: + variable_name: The name of the variable to calculate. + period: The period to calculate the variable for. + options: The options to use for the calculation. + + Returns: + None: If there is no :class:`.Simulation`. + ndarray[generic]: The result of the calculation. + + Raises: + IncompatibleOptionsError: If the options are incompatible. + InvalidOptionError: If the option is invalid. + + Examples: + >>> from openfisca_core import ( + ... entities, + ... periods, + ... populations, + ... simulations, + ... taxbenefitsystems, + ... variables, + ... ) + + >>> class Person(entities.SingleEntity): ... + + >>> person = Person("person", "people", "", "") + >>> period = periods.Period.eternity() + >>> population = populations.CorePopulation(person) + >>> population.count = 3 + >>> population("salary", period) + + >>> tbs = taxbenefitsystems.TaxBenefitSystem([person]) + >>> person.set_tax_benefit_system(tbs) + >>> simulation = simulations.Simulation(tbs, {person.key: population}) + >>> population("salary", period) + Traceback (most recent call last): + VariableNotFoundError: You tried to calculate or to set a value ... + + >>> class Salary(variables.Variable): + ... definition_period = periods.ETERNITY + ... entity = person + ... value_type = int + + >>> tbs.add_variable(Salary) + >> population(Salary().name, period) + array([0, 0, 0], dtype=int32) + + >>> class Tax(Salary): + ... default_value = 100.0 + ... definition_period = periods.ETERNITY + ... entity = person + ... value_type = float + + >>> tbs.add_variable(Tax) + >> population(Tax().name, period) + array([100., 100., 100.], dtype=float32) + + >>> population(Tax().name, period, [populations.ADD]) + Traceback (most recent call last): + ValueError: Unable to ADD constant variable 'Tax' over the perio... + + >>> population(Tax().name, period, [populations.DIVIDE]) + Traceback (most recent call last): + ValueError: Unable to DIVIDE constant variable 'Tax' over the pe... + + >>> population(Tax().name, period, [populations.ADD, populations.DIVIDE]) + Traceback (most recent call last): + IncompatibleOptionsError: Options ADD and DIVIDE are incompatibl... + + >>> population(Tax().name, period, ["LAGRANGIAN"]) + Traceback (most recent call last): + InvalidOptionError: Option LAGRANGIAN is not a valid option (try... + + """ + if self.simulation is None: + return None + + calculate = t.Calculate( + variable=variable_name, + period=periods.period(period), + option=options, + ) + + self.entity.check_variable_defined_for_entity(calculate.variable) + self.check_period_validity(calculate.variable, calculate.period) + + if not isinstance(calculate.option, Sequence): + return self.simulation.calculate( + calculate.variable, + calculate.period, + ) + + if t.Option.ADD in calculate.option and t.Option.DIVIDE in calculate.option: + raise IncompatibleOptionsError(variable_name) + + if t.Option.ADD in calculate.option: + return self.simulation.calculate_add( + calculate.variable, + calculate.period, + ) + + if t.Option.DIVIDE in calculate.option: + return self.simulation.calculate_divide( + calculate.variable, + calculate.period, + ) + + raise InvalidOptionError(calculate.option[0], variable_name) + + def empty_array(self) -> t.FloatArray: + """Return an empty array. + + Returns: + ndarray[float32]: An empty array. + + Examples: + >>> import numpy + + >>> from openfisca_core import populations as p + + >>> class Population(p.CorePopulation): ... + + >>> population = Population(None) + >>> population.empty_array() + array([], dtype=float32) + + >>> population.count = 3 + >>> population.empty_array() + array([0., 0., 0.], dtype=float32) + + """ + return numpy.zeros(self.count, dtype=t.FloatDType) + + def filled_array( + self, value: _DT_co, dtype: None | t.DTypeLike = None + ) -> t.Array[_DT_co]: + """Return an array filled with a value. + + Args: + value: The value to fill the array with. + dtype: The data type of the array. + + Returns: + ndarray[generic]: An array filled with the value. + + Examples: + >>> import numpy + + >>> from openfisca_core import populations + + >>> class Population(populations.CorePopulation): ... + + >>> population = Population(None) + >>> population.count = 3 + >>> population.filled_array(1) + array([1, 1, 1]) + + >>> population.filled_array(numpy.float32(1)) + array([1., 1., 1.], dtype=float32) + + >>> population.filled_array(1, dtype=str) + array(['1', '1', '1'], dtype='>> population.filled_array("hola", dtype=numpy.uint8) + Traceback (most recent call last): + ValueError: could not convert string to float: 'hola' + + """ + return numpy.full(self.count, value, dtype) + + def get_index(self, id: str) -> int: + """Return the index of an `id``. + + Args: + id: The id to get the index for. + + Returns: + int: The index of the id. + + Examples: + >>> from openfisca_core import entities, populations + + >>> class Person(entities.SingleEntity): ... + + >>> person = Person("person", "people", "", "") + >>> population = populations.CorePopulation(person) + >>> population.ids = ["Juan", "Megan", "Brahim"] + + >>> population.get_index("Megan") + 1 + + >>> population.get_index("Ibrahim") + Traceback (most recent call last): + ValueError: 'Ibrahim' is not in list + + """ + return self.ids.index(id) + + # Calculations + + def check_array_compatible_with_entity(self, array: t.VarArray) -> None: + """Check if an array is compatible with the population. + + Args: + array: The array to check. + + Raises: + InvalidArraySizeError: If the array is not compatible. + + Examples: + >>> import numpy + + >>> from openfisca_core import entities, populations + + >>> class Person(entities.SingleEntity): ... + + >>> person = Person("person", "people", "", "") + >>> population = populations.CorePopulation(person) + >>> population.count = 3 + + >>> array = numpy.array([1, 2, 3]) + >>> population.check_array_compatible_with_entity(array) + + >>> array = numpy.array([1, 2, 3, 4]) + >>> population.check_array_compatible_with_entity(array) + Traceback (most recent call last): + InvalidArraySizeError: Input [1 2 3 4] is not a valid value for t... + + """ + if self.count == array.size: + return + raise InvalidArraySizeError(array, self.entity.key, self.count) + + @staticmethod + def check_period_validity( + variable_name: t.VariableName, + period: None | t.PeriodLike = None, + ) -> None: + """Check if a period is valid. + + Args: + variable_name: The name of the variable. + period: The period to check. + + Raises: + PeriodValidityError: If the period is not valid. + + Examples: + >>> from openfisca_core import entities, periods, populations + + >>> class Person(entities.SingleEntity): ... + + >>> person = Person("person", "people", "", "") + >>> period = periods.Period("2017-04") + >>> population = populations.CorePopulation(person) + + >>> population.check_period_validity("salary") + Traceback (most recent call last): + PeriodValidityError: You requested computation of variable "sala... + + >>> population.check_period_validity("salary", 2017) + + >>> population.check_period_validity("salary", "2017-04") + + >>> population.check_period_validity("salary", period) + + """ + if isinstance(period, (int, str, periods.Period)): + return + stack = traceback.extract_stack() + filename, line_number, _, line_of_code = stack[-3] + raise PeriodValidityError(variable_name, filename, line_number, line_of_code) + + # Helpers + + def get_holder(self, variable_name: t.VariableName) -> t.Holder: + """Return the holder of a variable. + + Args: + variable_name: The name of the variable. + + Returns: + Holder: The holder of the variable. + + Examples: + >>> from openfisca_core import ( + ... entities, + ... holders, + ... periods, + ... populations, + ... simulations, + ... taxbenefitsystems, + ... simulations, + ... variables, + ... ) + + >>> class Person(entities.SingleEntity): ... + + >>> person = Person("person", "people", "", "") + + >>> class Salary(variables.Variable): + ... definition_period = periods.WEEK + ... entity = person + ... value_type = int + + >>> tbs = taxbenefitsystems.TaxBenefitSystem([person]) + >>> person.set_tax_benefit_system(tbs) + >>> population = populations.SinglePopulation(person) + >>> simulation = simulations.Simulation(tbs, {person.key: population}) + >>> population.get_holder("income_tax") + Traceback (most recent call last): + VariableNotFoundError: You tried to calculate or to set a value ... + + >>> tbs.add_variable(Salary) + >> salary = Salary() + >>> population.get_holder(salary.name) + t.MemoryUsageByVariable: + """Return the memory usage of the population per variable. + + Args: + variables: The variables to get the memory usage for. + + Returns: + MemoryUsageByVariable: The memory usage of the population per variable. + + Examples: + >>> from openfisca_core import ( + ... entities, + ... holders, + ... periods, + ... populations, + ... simulations, + ... taxbenefitsystems, + ... simulations, + ... variables, + ... ) + + >>> class Person(entities.SingleEntity): ... + + >>> person = Person("person", "people", "", "") + + >>> class Salary(variables.Variable): + ... definition_period = periods.WEEK + ... entity = person + ... value_type = int + + >>> tbs = taxbenefitsystems.TaxBenefitSystem([person]) + >>> population = populations.SinglePopulation(person) + >>> simulation = simulations.Simulation(tbs, {person.key: population}) + >>> salary = Salary() + >>> holder = holders.Holder(salary, population) + >>> population._holders[salary.name] = holder + + >>> population.get_memory_usage() + {'total_nb_bytes': 0, 'by_variable': {'Salary': {'nb_cells_by...}}} + + >>> population.get_memory_usage([salary.name]) + {'total_nb_bytes': 0, 'by_variable': {'Salary': {'nb_cells_by...}}} + + """ + holders_memory_usage = { + variable_name: holder.get_memory_usage() + for variable_name, holder in self._holders.items() + if variables is None or variable_name in variables + } + + total_memory_usage = sum( + holder_memory_usage["total_nb_bytes"] + for holder_memory_usage in holders_memory_usage.values() + ) + + return t.MemoryUsageByVariable( + total_nb_bytes=total_memory_usage, + by_variable=holders_memory_usage, + ) + + +__all__ = ["CorePopulation"] diff --git a/openfisca_core/populations/_errors.py b/openfisca_core/populations/_errors.py new file mode 100644 index 0000000000..0aad0d11dc --- /dev/null +++ b/openfisca_core/populations/_errors.py @@ -0,0 +1,65 @@ +from . import types as t + + +class IncompatibleOptionsError(ValueError): + """Raised when two options are incompatible.""" + + def __init__(self, variable_name: t.VariableName) -> None: + add, divide = t.Option + msg = ( + f"Options {add} and {divide} are incompatible (trying to compute " + f"variable {variable_name})." + ) + super().__init__(msg) + + +class InvalidOptionError(ValueError): + """Raised when an option is invalid.""" + + def __init__(self, option: str, variable_name: t.VariableName) -> None: + msg = ( + f"Option {option} is not a valid option (trying to compute " + f"variable {variable_name})." + ) + super().__init__(msg) + + +class InvalidArraySizeError(ValueError): + """Raised when an array has an invalid size.""" + + def __init__(self, array: t.VarArray, entity: t.EntityKey, count: int) -> None: + msg = ( + f"Input {array} is not a valid value for the entity {entity} " + f"(size = {array.size} != {count} = count)." + ) + super().__init__(msg) + + +class PeriodValidityError(ValueError): + """Raised when a period is not valid.""" + + def __init__( + self, + variable_name: t.VariableName, + filename: str, + line_number: int, + line_of_code: int, + ) -> None: + msg = ( + f'You requested computation of variable "{variable_name}", but ' + f'you did not specify on which period in "{filename}:{line_number}": ' + f"{line_of_code}. When you request the computation of a variable " + "within a formula, you must always specify the period as the second " + 'parameter. The convention is to call this parameter "period". For ' + 'example: computed_salary = person("salary", period). More information at ' + "." + ) + super().__init__(msg) + + +__all__ = [ + "IncompatibleOptionsError", + "InvalidArraySizeError", + "InvalidOptionError", + "PeriodValidityError", +] diff --git a/openfisca_core/populations/config.py b/openfisca_core/populations/config.py deleted file mode 100644 index 92a0b28865..0000000000 --- a/openfisca_core/populations/config.py +++ /dev/null @@ -1,2 +0,0 @@ -ADD = 'add' -DIVIDE = 'divide' diff --git a/openfisca_core/populations/group_population.py b/openfisca_core/populations/group_population.py index bcd345df10..120dc9c656 100644 --- a/openfisca_core/populations/group_population.py +++ b/openfisca_core/populations/group_population.py @@ -1,15 +1,17 @@ +from __future__ import annotations + import typing import numpy -from openfisca_core import projectors -from openfisca_core.entities import Role -from openfisca_core.indexed_enums import EnumArray -from openfisca_core.populations import Population +from openfisca_core import entities, indexed_enums, projectors + +from . import types as t +from .population import Population class GroupPopulation(Population): - def __init__(self, entity, members): + def __init__(self, entity: t.GroupEntity, members: t.Members) -> None: super().__init__(entity) self.members = members self._members_entity_id = None @@ -20,7 +22,9 @@ def __init__(self, entity, members): def clone(self, simulation): result = GroupPopulation(self.entity, self.members) result.simulation = simulation - result._holders = {variable: holder.clone(self) for (variable, holder) in self._holders.items()} + result._holders = { + variable: holder.clone(self) for (variable, holder) in self._holders.items() + } result.count = self.count result.ids = self.ids result._members_entity_id = self._members_entity_id @@ -32,7 +36,7 @@ def clone(self, simulation): @property def members_position(self): if self._members_position is None and self.members_entity_id is not None: - # We could use self.count and self.members.count , but with the current initilization, we are not sure count will be set before members_position is called + # We could use self.count and self.members.count , but with the current initialization, we are not sure count will be set before members_position is called nb_entities = numpy.max(self.members_entity_id) + 1 nb_persons = len(self.members_entity_id) self._members_position = numpy.empty_like(self.members_entity_id) @@ -45,7 +49,7 @@ def members_position(self): return self._members_position @members_position.setter - def members_position(self, members_position): + def members_position(self, members_position) -> None: self._members_position = members_position @property @@ -53,7 +57,7 @@ def members_entity_id(self): return self._members_entity_id @members_entity_id.setter - def members_entity_id(self, members_entity_id): + def members_entity_id(self, members_entity_id) -> None: self._members_entity_id = members_entity_id @property @@ -64,39 +68,44 @@ def members_role(self): return self._members_role @members_role.setter - def members_role(self, members_role: typing.Iterable[Role]): + def members_role(self, members_role: typing.Iterable[entities.Role]) -> None: if members_role is not None: self._members_role = numpy.array(list(members_role)) @property def ordered_members_map(self): - """ - Mask to group the persons by entity + """Mask to group the persons by entity This function only caches the map value, to see what the map is used for, see value_nth_person method. """ if self._ordered_members_map is None: self._ordered_members_map = numpy.argsort(self.members_entity_id) return self._ordered_members_map + # Helpers + def get_role(self, role_name): - return next((role for role in self.entity.flattened_roles if role.key == role_name), None) + return next( + (role for role in self.entity.flattened_roles if role.key == role_name), + None, + ) # Aggregation persons -> entity @projectors.projectable - def sum(self, array, role = None): - """ - Return the sum of ``array`` for the members of the entity. + def sum(self, array, role=None): + """Return the sum of ``array`` for the members of the entity. - ``array`` must have the dimension of the number of persons in the simulation + ``array`` must have the dimension of the number of persons in the simulation - If ``role`` is provided, only the entity member with the given role are taken into account. + If ``role`` is provided, only the entity member with the given role are taken into account. - Example: + Example: + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] + >>> household.sum(salaries) + >>> array([3500]) - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] - >>> household.sum(salaries) - >>> array([3500]) """ self.entity.check_role_validity(role) self.members.check_array_compatible_with_entity(array) @@ -104,142 +113,164 @@ def sum(self, array, role = None): role_filter = self.members.has_role(role) return numpy.bincount( self.members_entity_id[role_filter], - weights = array[role_filter], - minlength = self.count) - else: - return numpy.bincount(self.members_entity_id, weights = array) + weights=array[role_filter], + minlength=self.count, + ) + return numpy.bincount(self.members_entity_id, weights=array) @projectors.projectable - def any(self, array, role = None): - """ - Return ``True`` if ``array`` is ``True`` for any members of the entity. + def any(self, array, role=None): + """Return ``True`` if ``array`` is ``True`` for any members of the entity. - ``array`` must have the dimension of the number of persons in the simulation + ``array`` must have the dimension of the number of persons in the simulation - If ``role`` is provided, only the entity member with the given role are taken into account. + If ``role`` is provided, only the entity member with the given role are taken into account. - Example: + Example: + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] + >>> household.any(salaries >= 1800) + >>> array([True]) - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] - >>> household.any(salaries >= 1800) - >>> array([True]) """ - sum_in_entity = self.sum(array, role = role) - return (sum_in_entity > 0) + sum_in_entity = self.sum(array, role=role) + return sum_in_entity > 0 @projectors.projectable - def reduce(self, array, reducer, neutral_element, role = None): + def reduce(self, array, reducer, neutral_element, role=None): self.members.check_array_compatible_with_entity(array) self.entity.check_role_validity(role) position_in_entity = self.members_position role_filter = self.members.has_role(role) if role is not None else True filtered_array = numpy.where(role_filter, array, neutral_element) - result = self.filled_array(neutral_element) # Neutral value that will be returned if no one with the given role exists. + result = self.filled_array( + neutral_element, + ) # Neutral value that will be returned if no one with the given role exists. # We loop over the positions in the entity - # Looping over the entities is tempting, but potentielly slow if there are a lot of entities + # Looping over the entities is tempting, but potentially slow if there are a lot of entities biggest_entity_size = numpy.max(position_in_entity) + 1 for p in range(biggest_entity_size): - values = self.value_nth_person(p, filtered_array, default = neutral_element) + values = self.value_nth_person(p, filtered_array, default=neutral_element) result = reducer(result, values) return result @projectors.projectable - def all(self, array, role = None): - """ - Return ``True`` if ``array`` is ``True`` for all members of the entity. + def all(self, array, role=None): + """Return ``True`` if ``array`` is ``True`` for all members of the entity. - ``array`` must have the dimension of the number of persons in the simulation + ``array`` must have the dimension of the number of persons in the simulation - If ``role`` is provided, only the entity member with the given role are taken into account. + If ``role`` is provided, only the entity member with the given role are taken into account. - Example: + Example: + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] + >>> household.all(salaries >= 1800) + >>> array([False]) - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] - >>> household.all(salaries >= 1800) - >>> array([False]) """ - return self.reduce(array, reducer = numpy.logical_and, neutral_element = True, role = role) + return self.reduce( + array, + reducer=numpy.logical_and, + neutral_element=True, + role=role, + ) @projectors.projectable - def max(self, array, role = None): - """ - Return the maximum value of ``array`` for the entity members. + def max(self, array, role=None): + """Return the maximum value of ``array`` for the entity members. - ``array`` must have the dimension of the number of persons in the simulation + ``array`` must have the dimension of the number of persons in the simulation - If ``role`` is provided, only the entity member with the given role are taken into account. + If ``role`` is provided, only the entity member with the given role are taken into account. - Example: + Example: + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] + >>> household.max(salaries) + >>> array([2000]) - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] - >>> household.max(salaries) - >>> array([2000]) """ - return self.reduce(array, reducer = numpy.maximum, neutral_element = - numpy.infty, role = role) + return self.reduce( + array, + reducer=numpy.maximum, + neutral_element=-numpy.inf, + role=role, + ) @projectors.projectable - def min(self, array, role = None): - """ - Return the minimum value of ``array`` for the entity members. + def min(self, array, role=None): + """Return the minimum value of ``array`` for the entity members. - ``array`` must have the dimension of the number of persons in the simulation + ``array`` must have the dimension of the number of persons in the simulation - If ``role`` is provided, only the entity member with the given role are taken into account. + If ``role`` is provided, only the entity member with the given role are taken into account. - Example: + Example: + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] + >>> household.min(salaries) + >>> array([0]) + >>> household.min( + ... salaries, role=Household.PARENT + ... ) # Assuming the 1st two persons are parents + >>> array([1500]) - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] - >>> household.min(salaries) - >>> array([0]) - >>> household.min(salaries, role = Household.PARENT) # Assuming the 1st two persons are parents - >>> array([1500]) """ - return self.reduce(array, reducer = numpy.minimum, neutral_element = numpy.infty, role = role) + return self.reduce( + array, + reducer=numpy.minimum, + neutral_element=numpy.inf, + role=role, + ) @projectors.projectable - def nb_persons(self, role = None): - """ - Returns the number of persons contained in the entity. + def nb_persons(self, role=None): + """Returns the number of persons contained in the entity. - If ``role`` is provided, only the entity member with the given role are taken into account. + If ``role`` is provided, only the entity member with the given role are taken into account. """ if role: if role.subroles: - role_condition = numpy.logical_or.reduce([self.members_role == subrole for subrole in role.subroles]) + role_condition = numpy.logical_or.reduce( + [self.members_role == subrole for subrole in role.subroles], + ) else: role_condition = self.members_role == role return self.sum(role_condition) - else: - return numpy.bincount(self.members_entity_id) + return numpy.bincount(self.members_entity_id) # Projection person -> entity @projectors.projectable - def value_from_person(self, array, role, default = 0): - """ - Get the value of ``array`` for the person with the unique role ``role``. + def value_from_person(self, array, role, default=0): + """Get the value of ``array`` for the person with the unique role ``role``. - ``array`` must have the dimension of the number of persons in the simulation + ``array`` must have the dimension of the number of persons in the simulation - If such a person does not exist, return ``default`` instead + If such a person does not exist, return ``default`` instead - The result is a vector which dimension is the number of entities + The result is a vector which dimension is the number of entities """ self.entity.check_role_validity(role) if role.max != 1: + msg = f"You can only use value_from_person with a role that is unique in {self.key}. Role {role.key} is not unique." raise Exception( - 'You can only use value_from_person with a role that is unique in {}. Role {} is not unique.' - .format(self.key, role.key) - ) + msg, + ) self.members.check_array_compatible_with_entity(array) members_map = self.ordered_members_map - result = self.filled_array(default, dtype = array.dtype) - if isinstance(array, EnumArray): - result = EnumArray(result, array.possible_values) + result = self.filled_array(default, dtype=array.dtype) + if isinstance(array, indexed_enums.EnumArray): + result = indexed_enums.EnumArray(result, array.possible_values) role_filter = self.members.has_role(role) entity_filter = self.any(role_filter) @@ -248,24 +279,28 @@ def value_from_person(self, array, role, default = 0): return result @projectors.projectable - def value_nth_person(self, n, array, default = 0): - """ - Get the value of array for the person whose position in the entity is n. + def value_nth_person(self, n, array, default=0): + """Get the value of array for the person whose position in the entity is n. - Note that this position is arbitrary, and that members are not sorted. + Note that this position is arbitrary, and that members are not sorted. - If the nth person does not exist, return ``default`` instead. + If the nth person does not exist, return ``default`` instead. - The result is a vector which dimension is the number of entities. + The result is a vector which dimension is the number of entities. """ self.members.check_array_compatible_with_entity(array) positions = self.members_position nb_persons_per_entity = self.nb_persons() members_map = self.ordered_members_map - result = self.filled_array(default, dtype = array.dtype) + result = self.filled_array(default, dtype=array.dtype) # For households that have at least n persons, set the result as the value of criteria for the person for which the position is n. # The map is needed b/c the order of the nth persons of each household in the persons vector is not necessarily the same than the household order. - result[nb_persons_per_entity > n] = array[members_map][positions[members_map] == n] + result[nb_persons_per_entity > n] = array[members_map][ + positions[members_map] == n + ] + + if isinstance(array, indexed_enums.EnumArray): + result = indexed_enums.EnumArray(result, array.possible_values) return result @@ -275,11 +310,10 @@ def value_from_first_person(self, array): # Projection entity -> person(s) - def project(self, array, role = None): + def project(self, array, role=None): self.check_array_compatible_with_entity(array) self.entity.check_role_validity(role) if role is None: return array[self.members_entity_id] - else: - role_condition = self.members.has_role(role) - return numpy.where(role_condition, array[self.members_entity_id], 0) + role_condition = self.members.has_role(role) + return numpy.where(role_condition, array[self.members_entity_id], 0) diff --git a/openfisca_core/populations/population.py b/openfisca_core/populations/population.py index 41cdbcd8c4..24742ab0a0 100644 --- a/openfisca_core/populations/population.py +++ b/openfisca_core/populations/population.py @@ -1,140 +1,80 @@ -import traceback +from __future__ import annotations import numpy from openfisca_core import projectors -from openfisca_core.holders import Holder -from openfisca_core.populations import config -from openfisca_core.projectors import Projector +from . import types as t +from ._core_population import CorePopulation -class Population: - def __init__(self, entity): - self.simulation = None - self.entity = entity - self._holders = {} - self.count = 0 - self.ids = [] - def clone(self, simulation): +class Population(CorePopulation): + def __init__(self, entity: t.SingleEntity) -> None: + super().__init__(entity) + + def clone(self, simulation: t.Simulation) -> t.CorePopulation: result = Population(self.entity) result.simulation = simulation - result._holders = {variable: holder.clone(result) for (variable, holder) in self._holders.items()} + result._holders = { + variable: holder.clone(result) + for (variable, holder) in self._holders.items() + } result.count = self.count result.ids = self.ids return result - def empty_array(self): - return numpy.zeros(self.count) - - def filled_array(self, value, dtype = None): - return numpy.full(self.count, value, dtype) - - def __getattr__(self, attribute): + def __getattr__(self, attribute: str) -> projectors.Projector: + projector: projectors.Projector | None projector = projectors.get_projector_from_shortcut(self, attribute) - if not projector: - raise AttributeError("You tried to use the '{}' of '{}' but that is not a known attribute.".format(attribute, self.entity.key)) - return projector - - def get_index(self, id): - return self.ids.index(id) - - # Calculations - - def check_array_compatible_with_entity(self, array): - if not self.count == array.size: - raise ValueError("Input {} is not a valid value for the entity {} (size = {} != {} = count)".format( - array, self.key, array.size, self.count)) - - def check_period_validity(self, variable_name, period): - if period is None: - stack = traceback.extract_stack() - filename, line_number, function_name, line_of_code = stack[-3] - raise ValueError(''' -You requested computation of variable "{}", but you did not specify on which period in "{}:{}": - {} -When you request the computation of a variable within a formula, you must always specify the period as the second parameter. The convention is to call this parameter "period". For example: - computed_salary = person('salary', period). -See more information at . -'''.format(variable_name, filename, line_number, line_of_code)) - - def __call__(self, variable_name, period = None, options = None): - """ - Calculate the variable ``variable_name`` for the entity and the period ``period``, using the variable formula if it exists. - Example: - - >>> person('salary', '2017-04') - >>> array([300.]) - - :returns: A numpy array containing the result of the calculation - """ - self.entity.check_variable_defined_for_entity(variable_name) - self.check_period_validity(variable_name, period) + if isinstance(projector, projectors.Projector): + return projector - if options is None: - options = [] - - if config.ADD in options and config.DIVIDE in options: - raise ValueError('Options config.ADD and config.DIVIDE are incompatible (trying to compute variable {})'.format(variable_name).encode('utf-8')) - elif config.ADD in options: - return self.simulation.calculate_add(variable_name, period) - elif config.DIVIDE in options: - return self.simulation.calculate_divide(variable_name, period) - else: - return self.simulation.calculate(variable_name, period) + msg = f"You tried to use the '{attribute}' of '{self.entity.key}' but that is not a known attribute." + raise AttributeError( + msg, + ) # Helpers - def get_holder(self, variable_name): - self.entity.check_variable_defined_for_entity(variable_name) - holder = self._holders.get(variable_name) - if holder: - return holder - variable = self.entity.get_variable(variable_name) - self._holders[variable_name] = holder = Holder(variable, self) - return holder - - def get_memory_usage(self, variables = None): - holders_memory_usage = { - variable_name: holder.get_memory_usage() - for variable_name, holder in self._holders.items() - if variables is None or variable_name in variables - } - - total_memory_usage = sum( - holder_memory_usage['total_nb_bytes'] for holder_memory_usage in holders_memory_usage.values() - ) - - return dict( - total_nb_bytes = total_memory_usage, - by_variable = holders_memory_usage - ) - @projectors.projectable - def has_role(self, role): - """ - Check if a person has a given role within its :any:`GroupEntity` + def has_role(self, role: t.Role) -> None | t.BoolArray: + """Check if a person has a given role within its `GroupEntity`. - Example: + Example: + >>> person.has_role(Household.CHILD) + >>> array([False]) - >>> person.has_role(Household.CHILD) - >>> array([False]) """ + if self.simulation is None: + return None + self.entity.check_role_validity(role) + group_population = self.simulation.get_population(role.entity.plural) + if role.subroles: - return numpy.logical_or.reduce([group_population.members_role == subrole for subrole in role.subroles]) - else: - return group_population.members_role == role + return numpy.logical_or.reduce( + [group_population.members_role == subrole for subrole in role.subroles], + ) + + return group_population.members_role == role @projectors.projectable - def value_from_partner(self, array, entity, role): + def value_from_partner( + self, + array: t.FloatArray, + entity: projectors.Projector, + role: t.Role, + ) -> None | t.FloatArray: self.check_array_compatible_with_entity(array) self.entity.check_role_validity(role) - if not role.subroles or not len(role.subroles) == 2: - raise Exception('Projection to partner is only implemented for roles having exactly two subroles.') + if not role.subroles or len(role.subroles) != 2: + msg = "Projection to partner is only implemented for roles having exactly two subroles." + raise Exception( + msg, + ) [subrole_1, subrole_2] = role.subroles value_subrole_1 = entity.value_from_person(array, subrole_1) @@ -143,28 +83,39 @@ def value_from_partner(self, array, entity, role): return numpy.select( [self.has_role(subrole_1), self.has_role(subrole_2)], [value_subrole_2, value_subrole_1], - ) + ) @projectors.projectable - def get_rank(self, entity, criteria, condition = True): - """ - Get the rank of a person within an entity according to a criteria. + def get_rank( + self, + entity: Population, + criteria: t.FloatArray, + condition: bool = True, + ) -> t.IntArray: + """Get the rank of a person within an entity according to a criteria. The person with rank 0 has the minimum value of criteria. If condition is specified, then the persons who don't respect it are not taken into account and their rank is -1. Example: - - >>> age = person('age', period) # e.g [32, 34, 2, 8, 1] + >>> age = person("age", period) # e.g [32, 34, 2, 8, 1] >>> person.get_rank(household, age) >>> [3, 4, 0, 2, 1] - >>> is_child = person.has_role(Household.CHILD) # [False, False, True, True, True] - >>> person.get_rank(household, - age, condition = is_child) # Sort in reverse order so that the eldest child gets the rank 0. + >>> is_child = person.has_role( + ... Household.CHILD + ... ) # [False, False, True, True, True] + >>> person.get_rank( + ... household, -age, condition=is_child + ... ) # Sort in reverse order so that the eldest child gets the rank 0. >>> [-1, -1, 1, 0, 2] - """ + """ # If entity is for instance 'person.household', we get the reference entity 'household' behind the projector - entity = entity if not isinstance(entity, Projector) else entity.reference_entity + entity = ( + entity + if not isinstance(entity, projectors.Projector) + else entity.reference_entity + ) positions = entity.members_position biggest_entity_size = numpy.max(positions) + 1 @@ -172,10 +123,12 @@ def get_rank(self, entity, criteria, condition = True): ids = entity.members_entity_id # Matrix: the value in line i and column j is the value of criteria for the jth person of the ith entity - matrix = numpy.asarray([ - entity.value_nth_person(k, filtered_criteria, default = numpy.inf) - for k in range(biggest_entity_size) - ]).transpose() + matrix = numpy.asarray( + [ + entity.value_nth_person(k, filtered_criteria, default=numpy.inf) + for k in range(biggest_entity_size) + ], + ).transpose() # We double-argsort all lines of the matrix. # Double-argsorting gets the rank of each value once sorted diff --git a/openfisca_core/populations/types.py b/openfisca_core/populations/types.py new file mode 100644 index 0000000000..07f34d2f5f --- /dev/null +++ b/openfisca_core/populations/types.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections.abc import Iterable, MutableMapping, Sequence +from typing import NamedTuple, Union +from typing_extensions import TypeAlias, TypedDict + +from openfisca_core.types import ( + Array, + CoreEntity, + CorePopulation, + DTypeLike, + EntityKey, + GroupEntity, + Holder, + MemoryUsage, + Period, + PeriodInt, + PeriodStr, + Role, + Simulation, + SingleEntity, + SinglePopulation, + VariableName, +) + +import enum + +import strenum +from numpy import ( + bool_ as BoolDType, + float32 as FloatDType, + generic as VarDType, + int32 as IntDType, + str_ as StrDType, +) + +# Commons + +#: Type alias for an array of strings. +IntArray: TypeAlias = Array[IntDType] + +#: Type alias for an array of strings. +StrArray: TypeAlias = Array[StrDType] + +#: Type alias for an array of booleans. +BoolArray: TypeAlias = Array[BoolDType] + +#: Type alias for an array of floats. +FloatArray: TypeAlias = Array[FloatDType] + +#: Type alias for an array of generic objects. +VarArray: TypeAlias = Array[VarDType] + +# Periods + +#: Type alias for a period-like object. +PeriodLike: TypeAlias = Union[Period, PeriodStr, PeriodInt] + +# Populations + +#: Type alias for a population's holders. +HolderByVariable: TypeAlias = MutableMapping[VariableName, Holder] + +# TODO(Mauko Quiroga-Alvarado): I'm not sure if this type alias is correct. +# https://openfisca.org/doc/coding-the-legislation/50_entities.html +Members: TypeAlias = Iterable[SinglePopulation] + + +class Option(strenum.StrEnum): + ADD = enum.auto() + DIVIDE = enum.auto() + + def __contains__(self, option: str) -> bool: + return option.upper() is self + + +class Calculate(NamedTuple): + variable: VariableName + period: Period + option: None | Sequence[Option] + + +class MemoryUsageByVariable(TypedDict, total=False): + by_variable: dict[VariableName, MemoryUsage] + total_nb_bytes: int + + +__all__ = [ + "CoreEntity", + "CorePopulation", + "DTypeLike", + "EntityKey", + "GroupEntity", + "Holder", + "MemoryUsage", + "Period", + "Role", + "Simulation", + "SingleEntity", + "SinglePopulation", + "VarDType", + "VariableName", +] diff --git a/openfisca_core/projectors/__init__.py b/openfisca_core/projectors/__init__.py index 02982bf982..28776e3cf9 100644 --- a/openfisca_core/projectors/__init__.py +++ b/openfisca_core/projectors/__init__.py @@ -21,8 +21,19 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .helpers import projectable, get_projector_from_shortcut # noqa: F401 -from .projector import Projector # noqa: F401 -from .entity_to_person_projector import EntityToPersonProjector # noqa: F401 -from .first_person_to_entity_projector import FirstPersonToEntityProjector # noqa: F401 -from .unique_role_to_entity_projector import UniqueRoleToEntityProjector # noqa: F401 +from . import typing +from .entity_to_person_projector import EntityToPersonProjector +from .first_person_to_entity_projector import FirstPersonToEntityProjector +from .helpers import get_projector_from_shortcut, projectable +from .projector import Projector +from .unique_role_to_entity_projector import UniqueRoleToEntityProjector + +__all__ = [ + "EntityToPersonProjector", + "FirstPersonToEntityProjector", + "get_projector_from_shortcut", + "projectable", + "Projector", + "UniqueRoleToEntityProjector", + "typing", +] diff --git a/openfisca_core/projectors/entity_to_person_projector.py b/openfisca_core/projectors/entity_to_person_projector.py index 3990233c70..392fda08a1 100644 --- a/openfisca_core/projectors/entity_to_person_projector.py +++ b/openfisca_core/projectors/entity_to_person_projector.py @@ -1,10 +1,10 @@ -from openfisca_core.projectors import Projector +from .projector import Projector class EntityToPersonProjector(Projector): """For instance person.family.""" - def __init__(self, entity, parent = None): + def __init__(self, entity, parent=None) -> None: self.reference_entity = entity self.parent = parent diff --git a/openfisca_core/projectors/first_person_to_entity_projector.py b/openfisca_core/projectors/first_person_to_entity_projector.py index 3912ccef1e..d986460cdc 100644 --- a/openfisca_core/projectors/first_person_to_entity_projector.py +++ b/openfisca_core/projectors/first_person_to_entity_projector.py @@ -1,10 +1,10 @@ -from openfisca_core.projectors import Projector +from .projector import Projector class FirstPersonToEntityProjector(Projector): """For instance famille.first_person.""" - def __init__(self, entity, parent = None): + def __init__(self, entity, parent=None) -> None: self.target_entity = entity self.reference_entity = entity.members self.parent = parent diff --git a/openfisca_core/projectors/helpers.py b/openfisca_core/projectors/helpers.py index 7bc55e0fd9..8071eecf94 100644 --- a/openfisca_core/projectors/helpers.py +++ b/openfisca_core/projectors/helpers.py @@ -1,25 +1,140 @@ -from openfisca_core import projectors +from __future__ import annotations + +from collections.abc import Mapping + +from openfisca_core.types import GroupEntity, Role, SingleEntity + +from openfisca_core import entities, projectors + +from .typing import GroupPopulation, Population def projectable(function): - """ - Decorator to indicate that when called on a projector, the outcome of the function must be projected. + """Decorator to indicate that when called on a projector, the outcome of the function must be projected. For instance person.household.sum(...) must be projected on person, while it would not make sense for person.household.get_holder. """ function.projectable = True return function -def get_projector_from_shortcut(population, shortcut, parent = None): - if population.entity.is_person: - if shortcut in population.simulation.populations: - entity_2 = population.simulation.populations[shortcut] - return projectors.EntityToPersonProjector(entity_2, parent) - else: - if shortcut == 'first_person': - return projectors.FirstPersonToEntityProjector(population, parent) - role = next((role for role in population.entity.flattened_roles if (role.max == 1) and (role.key == shortcut)), None) - if role: +def get_projector_from_shortcut( + population: Population | GroupPopulation, + shortcut: str, + parent: projectors.Projector | None = None, +) -> projectors.Projector | None: + """Get a projector from a shortcut. + + Projectors are used to project an invidividual Population's or a + collective GroupPopulation's on to other populations. + + The currently available cases are projecting: + - from an individual to a group + - from a group to an individual + - from a group to an individual with a unique role + + For example, if there are two entities, person (Entity) and household + (GroupEntity), on which calculations can be run (Population and + GroupPopulation respectively), and there is a Variable "rent" defined for + the household entity, then `person.household("rent")` will assign a rent to + every person within that household. + + Behind the scenes, this is done thanks to a Projector, and this function is + used to find the appropriate one for each case. In the above example, the + `shortcut` argument would be "household", and the `population` argument + would be the Population linked to the "person" Entity in the context + of a specific Simulation and TaxBenefitSystem. + + Args: + population (Population | GroupPopulation): Where to project from. + shortcut (str): Where to project to. + parent: ??? + + Examples: + >>> from openfisca_core import ( + ... entities, + ... populations, + ... simulations, + ... taxbenefitsystems, + ... ) + + >>> entity = entities.Entity("person", "", "", "") + + >>> group_entity_1 = entities.GroupEntity("family", "", "", "", []) + + >>> roles = [ + ... {"key": "person", "max": 1}, + ... {"key": "animal", "subroles": ["cat", "dog"]}, + ... ] + + >>> group_entity_2 = entities.GroupEntity("household", "", "", "", roles) + + >>> population = populations.Population(entity) + + >>> group_population_1 = populations.GroupPopulation(group_entity_1, []) + + >>> group_population_2 = populations.GroupPopulation(group_entity_2, []) + + >>> populations = { + ... entity.key: population, + ... group_entity_1.key: group_population_1, + ... group_entity_2.key: group_population_2, + ... } + + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem( + ... [entity, group_entity_1, group_entity_2] + ... ) + + >>> simulation = simulations.Simulation(tax_benefit_system, populations) + + >>> get_projector_from_shortcut(population, "person") + <...EntityToPersonProjector object at ...> + + >>> get_projector_from_shortcut(population, "family") + <...EntityToPersonProjector object at ...> + + >>> get_projector_from_shortcut(population, "household") + <...EntityToPersonProjector object at ...> + + >>> get_projector_from_shortcut(group_population_2, "first_person") + <...FirstPersonToEntityProjector object at ...> + + >>> get_projector_from_shortcut(group_population_2, "person") + <...UniqueRoleToEntityProjector object at ...> + + >>> get_projector_from_shortcut(group_population_2, "cat") + <...UniqueRoleToEntityProjector object at ...> + + >>> get_projector_from_shortcut(group_population_2, "dog") + <...UniqueRoleToEntityProjector object at ...> + + """ + entity: SingleEntity | GroupEntity = population.entity + + if isinstance(entity, entities.Entity): + populations: Mapping[ + str, + Population | GroupPopulation, + ] = population.simulation.populations + + if shortcut not in populations: + return None + + return projectors.EntityToPersonProjector(populations[shortcut], parent) + + if shortcut == "first_person": + return projectors.FirstPersonToEntityProjector(population, parent) + + if isinstance(entity, entities.GroupEntity): + role: Role | None = entities.find_role(entity.roles, shortcut, total=1) + + if role is not None: return projectors.UniqueRoleToEntityProjector(population, role, parent) - if shortcut in population.entity.containing_entities: - return getattr(projectors.FirstPersonToEntityProjector(population, parent), shortcut) + + if shortcut in entity.containing_entities: + projector: projectors.Projector = getattr( + projectors.FirstPersonToEntityProjector(population, parent), + shortcut, + ) + return projector + + return None diff --git a/openfisca_core/projectors/projector.py b/openfisca_core/projectors/projector.py index 41138813b5..37881201dc 100644 --- a/openfisca_core/projectors/projector.py +++ b/openfisca_core/projectors/projector.py @@ -6,12 +6,16 @@ class Projector: parent = None def __getattr__(self, attribute): - projector = helpers.get_projector_from_shortcut(self.reference_entity, attribute, parent = self) + projector = helpers.get_projector_from_shortcut( + self.reference_entity, + attribute, + parent=self, + ) if projector: return projector reference_attr = getattr(self.reference_entity, attribute) - if not hasattr(reference_attr, 'projectable'): + if not hasattr(reference_attr, "projectable"): return reference_attr def projector_function(*args, **kwargs): @@ -28,8 +32,7 @@ def transform_and_bubble_up(self, result): transformed_result = self.transform(result) if self.parent is None: return transformed_result - else: - return self.parent.transform_and_bubble_up(transformed_result) + return self.parent.transform_and_bubble_up(transformed_result) def transform(self, result): return NotImplementedError() diff --git a/openfisca_core/projectors/typing.py b/openfisca_core/projectors/typing.py new file mode 100644 index 0000000000..a49bc96621 --- /dev/null +++ b/openfisca_core/projectors/typing.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Protocol + +from openfisca_core.types import GroupEntity, SingleEntity + + +class Population(Protocol): + @property + def entity(self) -> SingleEntity: ... + + @property + def simulation(self) -> Simulation: ... + + +class GroupPopulation(Protocol): + @property + def entity(self) -> GroupEntity: ... + + @property + def simulation(self) -> Simulation: ... + + +class Simulation(Protocol): + @property + def populations(self) -> Mapping[str, Population | GroupPopulation]: ... diff --git a/openfisca_core/projectors/unique_role_to_entity_projector.py b/openfisca_core/projectors/unique_role_to_entity_projector.py index 25b3258dc3..c565484339 100644 --- a/openfisca_core/projectors/unique_role_to_entity_projector.py +++ b/openfisca_core/projectors/unique_role_to_entity_projector.py @@ -1,10 +1,10 @@ -from openfisca_core.projectors import Projector +from .projector import Projector class UniqueRoleToEntityProjector(Projector): - """ For instance famille.declarant_principal.""" + """For instance famille.declarant_principal.""" - def __init__(self, entity, role, parent = None): + def __init__(self, entity, role, parent=None) -> None: self.target_entity = entity self.reference_entity = entity.members self.parent = parent diff --git a/openfisca_core/reforms/reform.py b/openfisca_core/reforms/reform.py index 1ba0be30a8..76e7152334 100644 --- a/openfisca_core/reforms/reform.py +++ b/openfisca_core/reforms/reform.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy from openfisca_core.parameters import ParameterNode @@ -5,25 +7,22 @@ class Reform(TaxBenefitSystem): - """ - A modified TaxBenefitSystem - + """A modified TaxBenefitSystem. - All reforms must subclass `Reform` and implement a method `apply()`. + All reforms must subclass `Reform` and implement a method `apply()`. - In this method, the reform can add or replace variables and call :any:`modify_parameters` to modify the parameters of the legislation. - - Example: + In this method, the reform can add or replace variables and call `modify_parameters` to modify the parameters of the legislation. + Example: >>> from openfisca_core import reforms >>> from openfisca_core.parameters import load_parameter_file >>> >>> def modify_my_parameters(parameters): - >>> # Add new parameters + >>> # Add new parameters >>> new_parameters = load_parameter_file(name='reform_name', file_path='path_to_yaml_file.yaml') >>> parameters.add_child('reform_name', new_parameters) >>> - >>> # Update a value + >>> # Update a value >>> parameters.taxes.some_tax.some_param.update(period=some_period, value=1000.0) >>> >>> return parameters @@ -33,13 +32,13 @@ class Reform(TaxBenefitSystem): >>> self.add_variable(some_variable) >>> self.update_variable(some_other_variable) >>> self.modify_parameters(modifier_function = modify_my_parameters) + """ + name = None - def __init__(self, baseline): - """ - :param baseline: Baseline TaxBenefitSystem. - """ + def __init__(self, baseline) -> None: + """:param baseline: Baseline TaxBenefitSystem.""" super().__init__(baseline.entities) self.baseline = baseline self.parameters = baseline.parameters @@ -47,8 +46,9 @@ def __init__(self, baseline): self.variables = baseline.variables.copy() self.decomposition_file_path = baseline.decomposition_file_path self.key = self.__class__.__name__ - if not hasattr(self, 'apply'): - raise Exception("Reform {} must define an `apply` function".format(self.key)) + if not hasattr(self, "apply"): + msg = f"Reform {self.key} must define an `apply` function" + raise Exception(msg) self.apply() def __getattr__(self, attribute): @@ -57,27 +57,30 @@ def __getattr__(self, attribute): @property def full_key(self): key = self.key - assert key is not None, 'key was not set for reform {} (name: {!r})'.format(self, self.name) - if self.baseline is not None and hasattr(self.baseline, 'key'): + assert ( + key is not None + ), f"key was not set for reform {self} (name: {self.name!r})" + if self.baseline is not None and hasattr(self.baseline, "key"): baseline_full_key = self.baseline.full_key - key = '.'.join([baseline_full_key, key]) + key = f"{baseline_full_key}.{key}" return key def modify_parameters(self, modifier_function): - """ - Make modifications on the parameters of the legislation + """Make modifications on the parameters of the legislation. Call this function in `apply()` if the reform asks for legislation parameter modifications. - :param modifier_function: A function that takes an object of type :any:`ParameterNode` and should return an object of the same type. + Args: + modifier_function: A function that takes a :obj:`.ParameterNode` and should return an object of the same type. + """ baseline_parameters = self.baseline.parameters baseline_parameters_copy = copy.deepcopy(baseline_parameters) reform_parameters = modifier_function(baseline_parameters_copy) if not isinstance(reform_parameters, ParameterNode): return ValueError( - 'modifier_function {} in module {} must return a ParameterNode' - .format(modifier_function.__name__, modifier_function.__module__,) - ) + f"modifier_function {modifier_function.__name__} in module {modifier_function.__module__} must return a ParameterNode", + ) self.parameters = reform_parameters self._parameters_at_instant_cache = {} + return None diff --git a/openfisca_core/scripts/__init__.py b/openfisca_core/scripts/__init__.py index 9e0a3b67bc..e9080f2381 100644 --- a/openfisca_core/scripts/__init__.py +++ b/openfisca_core/scripts/__init__.py @@ -1,18 +1,33 @@ -# -*- coding: utf-8 -*- - -import traceback import importlib import logging import pkgutil +import traceback from os import linesep log = logging.getLogger(__name__) def add_tax_benefit_system_arguments(parser): - parser.add_argument('-c', '--country-package', action = 'store', help = 'country package to use. If not provided, an automatic detection will be attempted by scanning the python packages installed in your environment which name contains the word "openfisca".') - parser.add_argument('-e', '--extensions', action = 'store', help = 'extensions to load', nargs = '*') - parser.add_argument('-r', '--reforms', action = 'store', help = 'reforms to apply to the country package', nargs = '*') + parser.add_argument( + "-c", + "--country-package", + action="store", + help='country package to use. If not provided, an automatic detection will be attempted by scanning the python packages installed in your environment which name contains the word "openfisca".', + ) + parser.add_argument( + "-e", + "--extensions", + action="store", + help="extensions to load", + nargs="*", + ) + parser.add_argument( + "-r", + "--reforms", + action="store", + help="reforms to apply to the country package", + nargs="*", + ) return parser @@ -23,14 +38,21 @@ def build_tax_benefit_system(country_package_name, extensions, reforms): try: country_package = importlib.import_module(country_package_name) except ImportError: - message = linesep.join([traceback.format_exc(), - 'Could not import module `{}`.'.format(country_package_name), - 'Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.', - 'See more at .']) + message = linesep.join( + [ + traceback.format_exc(), + f"Could not import module `{country_package_name}`.", + "Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.", + "See more at .", + ], + ) raise ImportError(message) - if not hasattr(country_package, 'CountryTaxBenefitSystem'): - raise ImportError('`{}` does not seem to be a valid Openfisca country package.'.format(country_package_name)) + if not hasattr(country_package, "CountryTaxBenefitSystem"): + msg = f"`{country_package_name}` does not seem to be a valid Openfisca country package." + raise ImportError( + msg, + ) country_package = importlib.import_module(country_package_name) tax_benefit_system = country_package.CountryTaxBenefitSystem() @@ -54,19 +76,31 @@ def detect_country_package(): installed_country_packages = [] for module_description in pkgutil.iter_modules(): module_name = module_description[1] - if 'openfisca' in module_name.lower(): + if "openfisca" in module_name.lower(): try: module = importlib.import_module(module_name) except ImportError: - message = linesep.join([traceback.format_exc(), - 'Could not import module `{}`.'.format(module_name), - 'Look at the stack trace above to determine the error that stopped installed modules detection.']) + message = linesep.join( + [ + traceback.format_exc(), + f"Could not import module `{module_name}`.", + "Look at the stack trace above to determine the error that stopped installed modules detection.", + ], + ) raise ImportError(message) - if hasattr(module, 'CountryTaxBenefitSystem'): + if hasattr(module, "CountryTaxBenefitSystem"): installed_country_packages.append(module_name) if len(installed_country_packages) == 0: - raise ImportError('No country package has been detected on your environment. If your country package is installed but not detected, please use the --country-package option.') + msg = "No country package has been detected on your environment. If your country package is installed but not detected, please use the --country-package option." + raise ImportError( + msg, + ) if len(installed_country_packages) > 1: - log.warning('Several country packages detected : `{}`. Using `{}` by default. To use another package, please use the --country-package option.'.format(', '.join(installed_country_packages), installed_country_packages[0])) + log.warning( + "Several country packages detected : `{}`. Using `{}` by default. To use another package, please use the --country-package option.".format( + ", ".join(installed_country_packages), + installed_country_packages[0], + ), + ) return installed_country_packages[0] diff --git a/openfisca_core/scripts/find_placeholders.py b/openfisca_core/scripts/find_placeholders.py index b14fd5fea9..b7b5a81969 100644 --- a/openfisca_core/scripts/find_placeholders.py +++ b/openfisca_core/scripts/find_placeholders.py @@ -1,8 +1,7 @@ -# -*- coding: utf-8 -*- # flake8: noqa T001 -import os import fnmatch +import os import sys from bs4 import BeautifulSoup @@ -10,42 +9,37 @@ def find_param_files(input_dir): param_files = [] - for root, dirnames, filenames in os.walk(input_dir): - for filename in fnmatch.filter(filenames, '*.xml'): + for root, _dirnames, filenames in os.walk(input_dir): + for filename in fnmatch.filter(filenames, "*.xml"): param_files.append(os.path.join(root, filename)) return param_files def find_placeholders(filename_input): - with open(filename_input, 'r') as f: + with open(filename_input) as f: xml_content = f.read() xml_parsed = BeautifulSoup(xml_content, "lxml-xml") - placeholders = xml_parsed.find_all('PLACEHOLDER') + placeholders = xml_parsed.find_all("PLACEHOLDER") output_list = [] for placeholder in placeholders: parent_list = list(placeholder.parents)[:-1] - path = '.'.join([p.attrs['code'] for p in parent_list if 'code' in p.attrs][::-1]) + path = ".".join( + [p.attrs["code"] for p in parent_list if "code" in p.attrs][::-1], + ) - deb = placeholder.attrs['deb'] + deb = placeholder.attrs["deb"] output_list.append((deb, path)) - output_list = sorted(output_list, key = lambda x: x[0]) - - return output_list + return sorted(output_list, key=lambda x: x[0]) if __name__ == "__main__": - print('''find_placeholders.py : Find nodes PLACEHOLDER in xml parameter files -Usage : - python find_placeholders /dir/to/search -''') - - assert(len(sys.argv) == 2) + assert len(sys.argv) == 2 input_dir = sys.argv[1] param_files = find_param_files(input_dir) @@ -53,9 +47,5 @@ def find_placeholders(filename_input): for filename_input in param_files: output_list = find_placeholders(filename_input) - print('File {}'.format(filename_input)) - - for deb, path in output_list: - print('{} {}'.format(deb, path)) - - print('\n') + for _deb, _path in output_list: + pass diff --git a/openfisca_core/scripts/measure_numpy_condition_notations.py b/openfisca_core/scripts/measure_numpy_condition_notations.py index 3132205a1c..65e48f6e2c 100755 --- a/openfisca_core/scripts/measure_numpy_condition_notations.py +++ b/openfisca_core/scripts/measure_numpy_condition_notations.py @@ -1,33 +1,30 @@ #! /usr/bin/env python -# -*- coding: utf-8 -*- # flake8: noqa T001 -""" -Measure and compare different vectorial condition notations: +"""Measure and compare different vectorial condition notations: - using multiplication notation: (choice == 1) * choice_1_value + (choice == 2) * choice_2_value -- using np.select: the same than multiplication but more idiomatic like a "switch" control-flow statement -- using np.fromiter: iterates in Python over the array and calculates lazily only the required values +- using numpy.select: the same than multiplication but more idiomatic like a "switch" control-flow statement +- using numpy.fromiter: iterates in Python over the array and calculates lazily only the required values. The aim of this script is to compare the time taken by the calculation of the values """ -from contextlib import contextmanager + import argparse import sys import time +from contextlib import contextmanager -import numpy as np - +import numpy args = None @contextmanager def measure_time(title): - t1 = time.time() + time.time() yield - t2 = time.time() - print('{}\t: {:.8f} seconds elapsed'.format(title, t2 - t1)) + time.time() def switch_fromiter(conditions, function_by_condition, dtype): @@ -39,34 +36,28 @@ def get_or_store_value(condition): value_by_condition[condition] = value return value_by_condition[condition] - return np.fromiter( - ( - get_or_store_value(condition) - for condition in conditions - ), + return numpy.fromiter( + (get_or_store_value(condition) for condition in conditions), dtype, - ) + ) def switch_select(conditions, value_by_condition): - condlist = [ - conditions == condition - for condition in value_by_condition.keys() - ] - return np.select(condlist, value_by_condition.values()) + condlist = [conditions == condition for condition in value_by_condition] + return numpy.select(condlist, value_by_condition.values()) -def calculate_choice_1_value(): +def calculate_choice_1_value() -> int: time.sleep(args.calculate_time) return 80 -def calculate_choice_2_value(): +def calculate_choice_2_value() -> int: time.sleep(args.calculate_time) return 90 -def calculate_choice_3_value(): +def calculate_choice_3_value() -> int: time.sleep(args.calculate_time) return 95 @@ -75,61 +66,70 @@ def test_multiplication(choice): choice_1_value = calculate_choice_1_value() choice_2_value = calculate_choice_2_value() choice_3_value = calculate_choice_3_value() - result = (choice == 1) * choice_1_value + (choice == 2) * choice_2_value + (choice == 3) * choice_3_value - return result + return ( + (choice == 1) * choice_1_value + + (choice == 2) * choice_2_value + + (choice == 3) * choice_3_value + ) def test_switch_fromiter(choice): - result = switch_fromiter( + return switch_fromiter( choice, { 1: calculate_choice_1_value, 2: calculate_choice_2_value, 3: calculate_choice_3_value, - }, - dtype = np.int, - ) - return result + }, + dtype=int, + ) def test_switch_select(choice): choice_1_value = calculate_choice_1_value() choice_2_value = calculate_choice_2_value() choice_3_value = calculate_choice_2_value() - result = switch_select( + return switch_select( choice, { 1: choice_1_value, 2: choice_2_value, 3: choice_3_value, - }, - ) - return result + }, + ) -def test_all_notations(): +def test_all_notations() -> None: # choice is an array with 1 and 2 items like [2, 1, ..., 1, 2] - choice = np.random.randint(2, size = args.array_length) + 1 + choice = numpy.random.randint(2, size=args.array_length) + 1 - with measure_time('multiplication'): + with measure_time("multiplication"): test_multiplication(choice) - with measure_time('switch_select'): + with measure_time("switch_select"): test_switch_select(choice) - with measure_time('switch_fromiter'): + with measure_time("switch_fromiter"): test_switch_fromiter(choice) -def main(): - parser = argparse.ArgumentParser(description = __doc__) - parser.add_argument('--array-length', default = 1000, type = int, help = "length of the array") - parser.add_argument('--calculate-time', default = 0.1, type = float, - help = "time taken by the calculation in seconds") +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--array-length", + default=1000, + type=int, + help="length of the array", + ) + parser.add_argument( + "--calculate-time", + default=0.1, + type=float, + help="time taken by the calculation in seconds", + ) global args args = parser.parse_args() - print(args) test_all_notations() diff --git a/openfisca_core/scripts/measure_performances.py b/openfisca_core/scripts/measure_performances.py index 1d84ddd585..48b99c93f8 100644 --- a/openfisca_core/scripts/measure_performances.py +++ b/openfisca_core/scripts/measure_performances.py @@ -1,35 +1,32 @@ #! /usr/bin/env python -# -*- coding: utf-8 -*- # flake8: noqa T001 """Measure performances of a basic tax-benefit system to compare to other OpenFisca implementations.""" + import argparse import logging import sys import time -import numpy as np +import numpy from numpy.core.defchararray import startswith from openfisca_core import periods, simulations -from openfisca_core.periods import ETERNITY from openfisca_core.entities import build_entity -from openfisca_core.variables import Variable +from openfisca_core.periods import DateUnit from openfisca_core.taxbenefitsystems import TaxBenefitSystem from openfisca_core.tools import assert_near - +from openfisca_core.variables import Variable args = None def timeit(method): def timed(*args, **kwargs): - start_time = time.time() - result = method(*args, **kwargs) + time.time() + return method(*args, **kwargs) # print '%r (%r, %r) %2.9f s' % (method.__name__, args, kw, time.time() - start_time) - print('{:2.6f} s'.format(time.time() - start_time)) - return result return timed @@ -37,31 +34,31 @@ def timed(*args, **kwargs): # Entities Famille = build_entity( - key = "famille", - plural = "familles", - label = 'Famille', - roles = [ + key="famille", + plural="familles", + label="Famille", + roles=[ { - 'key': 'parent', - 'plural': 'parents', - 'label': 'Parents', - 'subroles': ['demandeur', 'conjoint'] - }, + "key": "parent", + "plural": "parents", + "label": "Parents", + "subroles": ["demandeur", "conjoint"], + }, { - 'key': 'enfant', - 'plural': 'enfants', - 'label': 'Enfants', - } - ] - ) + "key": "enfant", + "plural": "enfants", + "label": "Enfants", + }, + ], +) Individu = build_entity( - key = "individu", - plural = "individus", - label = 'Individu', - is_person = True, - ) + key="individu", + plural="individus", + label="Individu", + is_person=True, +) # Input variables @@ -73,16 +70,16 @@ class age_en_mois(Variable): class birth(Variable): - value_type = 'Date' + value_type = "Date" entity = Individu label = "Date de naissance" class city_code(Variable): - value_type = 'FixedStr' + value_type = "FixedStr" max_length = 5 entity = Famille - definition_period = ETERNITY + definition_period = DateUnit.ETERNITY label = """Code INSEE "city_code" de la commune de résidence de la famille""" @@ -94,30 +91,33 @@ class salaire_brut(Variable): # Calculated variables + class age(Variable): value_type = int entity = Individu label = "Âge (en nombre d'années)" def formula(self, simulation, period): - birth = simulation.get_array('birth', period) + birth = simulation.get_array("birth", period) if birth is None: - age_en_mois = simulation.get_array('age_en_mois', period) + age_en_mois = simulation.get_array("age_en_mois", period) if age_en_mois is not None: return age_en_mois // 12 - birth = simulation.calculate('birth', period) - return (np.datetime64(period.date) - birth).astype('timedelta64[Y]') + birth = simulation.calculate("birth", period) + return (numpy.datetime64(period.date) - birth).astype("timedelta64[Y]") class dom_tom(Variable): - value_type = 'Bool' + value_type = "Bool" entity = Famille label = "La famille habite-t-elle les DOM-TOM ?" def formula(self, simulation, period): - period = period.start.period('year').offset('first-of') - city_code = simulation.calculate('city_code', period) - return np.logical_or(startswith(city_code, '97'), startswith(city_code, '98')) + period = period.start.period(DateUnit.YEAR).offset("first-of") + city_code = simulation.calculate("city_code", period) + return numpy.logical_or( + startswith(city_code, "97"), startswith(city_code, "98") + ) class revenu_disponible(Variable): @@ -126,9 +126,9 @@ class revenu_disponible(Variable): label = "Revenu disponible de l'individu" def formula(self, simulation, period): - period = period.start.period('year').offset('first-of') - rsa = simulation.calculate('rsa', period) - salaire_imposable = simulation.calculate('salaire_imposable', period) + period = period.start.period(DateUnit.YEAR).offset("first-of") + rsa = simulation.calculate("rsa", period) + salaire_imposable = simulation.calculate("salaire_imposable", period) return rsa + salaire_imposable * 0.7 @@ -138,18 +138,18 @@ class rsa(Variable): label = "RSA" def formula_2010_01_01(self, simulation, period): - period = period.start.period('month').offset('first-of') - salaire_imposable = simulation.calculate('salaire_imposable', period) + period = period.start.period(DateUnit.MONTH).offset("first-of") + salaire_imposable = simulation.calculate("salaire_imposable", period) return (salaire_imposable < 500) * 100.0 def formula_2011_01_01(self, simulation, period): - period = period.start.period('month').offset('first-of') - salaire_imposable = simulation.calculate('salaire_imposable', period) + period = period.start.period(DateUnit.MONTH).offset("first-of") + salaire_imposable = simulation.calculate("salaire_imposable", period) return (salaire_imposable < 500) * 200.0 def formula_2013_01_01(self, simulation, period): - period = period.start.period('month').offset('first-of') - salaire_imposable = simulation.calculate('salaire_imposable', period) + period = period.start.period(DateUnit.MONTH).offset("first-of") + salaire_imposable = simulation.calculate("salaire_imposable", period) return (salaire_imposable < 500) * 300 @@ -158,10 +158,10 @@ class salaire_imposable(Variable): entity = Individu label = "Salaire imposable" - def formula(individu, period): - period = period.start.period('year').offset('first-of') - dom_tom = individu.famille('dom_tom', period) - salaire_net = individu('salaire_net', period) + def formula(self, period): + period = period.start.period(DateUnit.YEAR).offset("first-of") + dom_tom = self.famille("dom_tom", period) + salaire_net = self("salaire_net", period) return salaire_net * 0.9 - 100 * dom_tom @@ -171,8 +171,8 @@ class salaire_net(Variable): label = "Salaire net" def formula(self, simulation, period): - period = period.start.period('year').offset('first-of') - salaire_brut = simulation.calculate('salaire_brut', period) + period = period.start.period(DateUnit.YEAR).offset("first-of") + salaire_brut = simulation.calculate("salaire_brut", period) return salaire_brut * 0.8 @@ -180,13 +180,26 @@ def formula(self, simulation, period): tax_benefit_system = TaxBenefitSystem([Famille, Individu]) -tax_benefit_system.add_variables(age_en_mois, birth, city_code, salaire_brut, age, - dom_tom, revenu_disponible, rsa, salaire_imposable, salaire_net) +tax_benefit_system.add_variables( + age_en_mois, + birth, + city_code, + salaire_brut, + age, + dom_tom, + revenu_disponible, + rsa, + salaire_imposable, + salaire_net, +) @timeit -def check_revenu_disponible(year, city_code, expected_revenu_disponible): - simulation = simulations.Simulation(period = periods.period(year), tax_benefit_system = tax_benefit_system) +def check_revenu_disponible(year, city_code, expected_revenu_disponible) -> None: + simulation = simulations.Simulation( + period=periods.period(year), + tax_benefit_system=tax_benefit_system, + ) famille = simulation.populations["famille"] famille.count = 3 famille.roles_count = 2 @@ -194,31 +207,84 @@ def check_revenu_disponible(year, city_code, expected_revenu_disponible): individu = simulation.populations["individu"] individu.count = 6 individu.step_size = 2 - simulation.get_or_new_holder("city_code").array = np.array([city_code, city_code, city_code]) - famille.members_entity_id = np.array([0, 0, 1, 1, 2, 2]) - simulation.get_or_new_holder("salaire_brut").array = np.array([0.0, 0.0, 50000.0, 0.0, 100000.0, 0.0]) - revenu_disponible = simulation.calculate('revenu_disponible') - assert_near(revenu_disponible, expected_revenu_disponible, absolute_error_margin = 0.005) + simulation.get_or_new_holder("city_code").array = numpy.array( + [city_code, city_code, city_code], + ) + famille.members_entity_id = numpy.array([0, 0, 1, 1, 2, 2]) + simulation.get_or_new_holder("salaire_brut").array = numpy.array( + [0.0, 0.0, 50000.0, 0.0, 100000.0, 0.0], + ) + revenu_disponible = simulation.calculate("revenu_disponible") + assert_near( + revenu_disponible, + expected_revenu_disponible, + absolute_error_margin=0.005, + ) -def main(): - parser = argparse.ArgumentParser(description = __doc__) - parser.add_argument('-v', '--verbose', action = 'store_true', default = False, help = "increase output verbosity") +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="increase output verbosity", + ) global args args = parser.parse_args() - logging.basicConfig(level = logging.DEBUG if args.verbose else logging.WARNING, stream = sys.stdout) - - check_revenu_disponible(2009, '75101', np.array([0, 0, 25200, 0, 50400, 0])) - check_revenu_disponible(2010, '75101', np.array([1200, 1200, 25200, 1200, 50400, 1200])) - check_revenu_disponible(2011, '75101', np.array([2400, 2400, 25200, 2400, 50400, 2400])) - check_revenu_disponible(2012, '75101', np.array([2400, 2400, 25200, 2400, 50400, 2400])) - check_revenu_disponible(2013, '75101', np.array([3600, 3600, 25200, 3600, 50400, 3600])) - - check_revenu_disponible(2009, '97123', np.array([-70.0, -70.0, 25130.0, -70.0, 50330.0, -70.0])) - check_revenu_disponible(2010, '97123', np.array([1130.0, 1130.0, 25130.0, 1130.0, 50330.0, 1130.0])) - check_revenu_disponible(2011, '98456', np.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0])) - check_revenu_disponible(2012, '98456', np.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0])) - check_revenu_disponible(2013, '98456', np.array([3530.0, 3530.0, 25130.0, 3530.0, 50330.0, 3530.0])) + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.WARNING, + stream=sys.stdout, + ) + + check_revenu_disponible(2009, "75101", numpy.array([0, 0, 25200, 0, 50400, 0])) + check_revenu_disponible( + 2010, + "75101", + numpy.array([1200, 1200, 25200, 1200, 50400, 1200]), + ) + check_revenu_disponible( + 2011, + "75101", + numpy.array([2400, 2400, 25200, 2400, 50400, 2400]), + ) + check_revenu_disponible( + 2012, + "75101", + numpy.array([2400, 2400, 25200, 2400, 50400, 2400]), + ) + check_revenu_disponible( + 2013, + "75101", + numpy.array([3600, 3600, 25200, 3600, 50400, 3600]), + ) + + check_revenu_disponible( + 2009, + "97123", + numpy.array([-70.0, -70.0, 25130.0, -70.0, 50330.0, -70.0]), + ) + check_revenu_disponible( + 2010, + "97123", + numpy.array([1130.0, 1130.0, 25130.0, 1130.0, 50330.0, 1130.0]), + ) + check_revenu_disponible( + 2011, + "98456", + numpy.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0]), + ) + check_revenu_disponible( + 2012, + "98456", + numpy.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0]), + ) + check_revenu_disponible( + 2013, + "98456", + numpy.array([3530.0, 3530.0, 25130.0, 3530.0, 50330.0, 3530.0]), + ) if __name__ == "__main__": diff --git a/openfisca_core/scripts/measure_performances_fancy_indexing.py b/openfisca_core/scripts/measure_performances_fancy_indexing.py index 894250ef54..7c261e2fe3 100644 --- a/openfisca_core/scripts/measure_performances_fancy_indexing.py +++ b/openfisca_core/scripts/measure_performances_fancy_indexing.py @@ -1,24 +1,26 @@ # flake8: noqa T001 -import numpy as np import timeit -from openfisca_france import CountryTaxBenefitSystem -from openfisca_core.model_api import * # noqa analysis:ignore +import numpy +from openfisca_france import CountryTaxBenefitSystem tbs = CountryTaxBenefitSystem() N = 200000 -al_plaf_acc = tbs.get_parameters_at_instant('2015-01-01').prestations.al_plaf_acc -zone_apl = np.random.choice([1, 2, 3], N) -al_nb_pac = np.random.choice(6, N) -couple = np.random.choice([True, False], N) -formatted_zone = concat('plafond_pour_accession_a_la_propriete_zone_', zone_apl) # zone_apl returns 1, 2 or 3 but the parameters have a long name +al_plaf_acc = tbs.get_parameters_at_instant("2015-01-01").prestations.al_plaf_acc +zone_apl = numpy.random.choice([1, 2, 3], N) +al_nb_pac = numpy.random.choice(6, N) +couple = numpy.random.choice([True, False], N) +formatted_zone = concat( + "plafond_pour_accession_a_la_propriete_zone_", + zone_apl, +) # zone_apl returns 1, 2 or 3 but the parameters have a long name def formula_with(): plafonds = al_plaf_acc[formatted_zone] - result = ( + return ( plafonds.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) + plafonds.menage_seul * couple * (al_nb_pac == 0) + plafonds.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) @@ -26,10 +28,10 @@ def formula_with(): + plafonds.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) + plafonds.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) + plafonds.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) - + plafonds.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) - ) - - return result + + plafonds.menage_ou_isole_par_enfant_en_plus + * (al_nb_pac > 5) + * (al_nb_pac - 5) + ) def formula_without(): @@ -37,41 +39,51 @@ def formula_without(): z2 = al_plaf_acc.plafond_pour_accession_a_la_propriete_zone_2 z3 = al_plaf_acc.plafond_pour_accession_a_la_propriete_zone_3 - return (zone_apl == 1) * ( - z1.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) - + z1.menage_seul * couple * (al_nb_pac == 0) - + z1.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) - + z1.menage_ou_isole_avec_2_enfants * (al_nb_pac == 2) - + z1.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) - + z1.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) - + z1.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) - + z1.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) - ) + (zone_apl == 2) * ( - z2.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) - + z2.menage_seul * couple * (al_nb_pac == 0) - + z2.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) - + z2.menage_ou_isole_avec_2_enfants * (al_nb_pac == 2) - + z2.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) - + z2.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) - + z2.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) - + z2.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) - ) + (zone_apl == 3) * ( - z3.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) - + z3.menage_seul * couple * (al_nb_pac == 0) - + z3.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) - + z3.menage_ou_isole_avec_2_enfants * (al_nb_pac == 2) - + z3.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) - + z3.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) - + z3.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) - + z3.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) + return ( + (zone_apl == 1) + * ( + z1.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) + + z1.menage_seul * couple * (al_nb_pac == 0) + + z1.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) + + z1.menage_ou_isole_avec_2_enfants * (al_nb_pac == 2) + + z1.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) + + z1.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) + + z1.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) + + z1.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) ) + + (zone_apl == 2) + * ( + z2.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) + + z2.menage_seul * couple * (al_nb_pac == 0) + + z2.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) + + z2.menage_ou_isole_avec_2_enfants * (al_nb_pac == 2) + + z2.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) + + z2.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) + + z2.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) + + z2.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) + ) + + (zone_apl == 3) + * ( + z3.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) + + z3.menage_seul * couple * (al_nb_pac == 0) + + z3.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) + + z3.menage_ou_isole_avec_2_enfants * (al_nb_pac == 2) + + z3.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) + + z3.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) + + z3.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) + + z3.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) + ) + ) -if __name__ == '__main__': - - time_with = timeit.timeit('formula_with()', setup = "from __main__ import formula_with", number = 50) - time_without = timeit.timeit('formula_without()', setup = "from __main__ import formula_without", number = 50) - - print("Computing with dynamic legislation computing took {}".format(time_with)) - print("Computing without dynamic legislation computing took {}".format(time_without)) - print("Ratio: {}".format(time_with / time_without)) +if __name__ == "__main__": + time_with = timeit.timeit( + "formula_with()", + setup="from __main__ import formula_with", + number=50, + ) + time_without = timeit.timeit( + "formula_without()", + setup="from __main__ import formula_without", + number=50, + ) diff --git a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py index 6e8f672988..38538d644a 100644 --- a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py +++ b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py @@ -1,34 +1,30 @@ -# -*- coding: utf-8 -*- - -''' xml_to_yaml_country_template.py : Parse XML parameter files for Country-Template and convert them to YAML files. Comments are NOT transformed. +"""xml_to_yaml_country_template.py : Parse XML parameter files for Country-Template and convert them to YAML files. Comments are NOT transformed. Usage : `python xml_to_yaml_country_template.py output_dir` or just (output is written in a directory called `yaml_parameters`): `python xml_to_yaml_country_template.py` -''' -import sys +""" + import os +import sys + +from openfisca_country_template import COUNTRY_DIR, CountryTaxBenefitSystem -from openfisca_country_template import CountryTaxBenefitSystem, COUNTRY_DIR from . import xml_to_yaml tax_benefit_system = CountryTaxBenefitSystem() -if len(sys.argv) > 1: - target_path = sys.argv[1] -else: - target_path = 'yaml_parameters' +target_path = sys.argv[1] if len(sys.argv) > 1 else "yaml_parameters" -param_dir = os.path.join(COUNTRY_DIR, 'parameters') +param_dir = os.path.join(COUNTRY_DIR, "parameters") param_files = [ - 'benefits.xml', - 'general.xml', - 'taxes.xml', - ] + "benefits.xml", + "general.xml", + "taxes.xml", +] legislation_xml_info_list = [ - (os.path.join(param_dir, param_file), []) - for param_file in param_files - ] + (os.path.join(param_dir, param_file), []) for param_file in param_files +] xml_to_yaml.write_parameters(legislation_xml_info_list, target_path) diff --git a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py index 91144ed6a0..0b57c19016 100644 --- a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py +++ b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py @@ -1,31 +1,26 @@ -# -*- coding: utf-8 -*- - -''' xml_to_yaml_extension_template.py : Parse XML parameter files for Extension-Template and convert them to YAML files. Comments are NOT transformed. +"""xml_to_yaml_extension_template.py : Parse XML parameter files for Extension-Template and convert them to YAML files. Comments are NOT transformed. Usage : `python xml_to_yaml_extension_template.py output_dir` or just (output is written in a directory called `yaml_parameters`): `python xml_to_yaml_extension_template.py` -''' +""" -import sys import os +import sys -from . import xml_to_yaml import openfisca_extension_template -if len(sys.argv) > 1: - target_path = sys.argv[1] -else: - target_path = 'yaml_parameters' +from . import xml_to_yaml + +target_path = sys.argv[1] if len(sys.argv) > 1 else "yaml_parameters" param_dir = os.path.dirname(openfisca_extension_template.__file__) param_files = [ - 'parameters.xml', - ] + "parameters.xml", +] legislation_xml_info_list = [ - (os.path.join(param_dir, param_file), []) - for param_file in param_files - ] + (os.path.join(param_dir, param_file), []) for param_file in param_files +] xml_to_yaml.write_parameters(legislation_xml_info_list, target_path) diff --git a/openfisca_core/scripts/migrations/v24_to_25.py b/openfisca_core/scripts/migrations/v24_to_25.py index 853c4e9a94..08bbeddc3b 100644 --- a/openfisca_core/scripts/migrations/v24_to_25.py +++ b/openfisca_core/scripts/migrations/v24_to_25.py @@ -1,37 +1,52 @@ -# -*- coding: utf-8 -*- # flake8: noqa T001 import argparse -import os import glob +import os +from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedSeq -from openfisca_core.scripts import add_tax_benefit_system_arguments, build_tax_benefit_system +from openfisca_core.scripts import ( + add_tax_benefit_system_arguments, + build_tax_benefit_system, +) -from ruamel.yaml import YAML yaml = YAML() yaml.default_flow_style = False yaml.width = 4096 -TEST_METADATA = {'period', 'name', 'reforms', 'only_variables', 'ignore_variables', 'absolute_error_margin', 'relative_error_margin', 'description', 'keywords'} +TEST_METADATA = { + "period", + "name", + "reforms", + "only_variables", + "ignore_variables", + "absolute_error_margin", + "relative_error_margin", + "description", + "keywords", +} def build_parser(): parser = argparse.ArgumentParser() - parser.add_argument('path', help = "paths (files or directories) of tests to execute", nargs = '+') - parser = add_tax_benefit_system_arguments(parser) + parser.add_argument( + "path", + help="paths (files or directories) of tests to execute", + nargs="+", + ) + return add_tax_benefit_system_arguments(parser) - return parser - -class Migrator(object): - - def __init__(self, tax_benefit_system): +class Migrator: + def __init__(self, tax_benefit_system) -> None: self.tax_benefit_system = tax_benefit_system - self.entities_by_plural = {entity.plural: entity for entity in self.tax_benefit_system.entities} + self.entities_by_plural = { + entity.plural: entity for entity in self.tax_benefit_system.entities + } - def migrate(self, path): + def migrate(self, path) -> None: if isinstance(path, list): for item in path: self.migrate(item) @@ -49,8 +64,6 @@ def migrate(self, path): return - print('Migrating {}.'.format(path)) - with open(path) as yaml_file: tests = yaml.safe_load(yaml_file) if isinstance(tests, CommentedSeq): @@ -58,23 +71,23 @@ def migrate(self, path): else: migrated_tests = self.convert_test(tests) - with open(path, 'w') as yaml_file: + with open(path, "w") as yaml_file: yaml.dump(migrated_tests, yaml_file) def convert_test(self, test): - if test.get('output'): + if test.get("output"): # This test is already converted, ignoring it return test result = {} - outputs = test.pop('output_variables') - inputs = test.pop('input_variables', {}) + outputs = test.pop("output_variables") + inputs = test.pop("input_variables", {}) for key, value in test.items(): if key in TEST_METADATA: result[key] = value else: inputs[key] = value - result['input'] = self.convert_inputs(inputs) - result['output'] = outputs + result["input"] = self.convert_inputs(inputs) + result["output"] = outputs return result def convert_inputs(self, inputs): @@ -91,15 +104,15 @@ def convert_inputs(self, inputs): continue results[entity_plural] = self.convert_entities(entity, entities_description) - results = self.generate_missing_entities(results) - - return results + return self.generate_missing_entities(results) def convert_entities(self, entity, entities_description): return { - entity_description.get('id', "{}_{}".format(entity.key, index)): remove_id(entity_description) + entity_description.get("id", f"{entity.key}_{index}"): remove_id( + entity_description, + ) for index, entity_description in enumerate(entities_description) - } + } def generate_missing_entities(self, inputs): for entity in self.tax_benefit_system.entities: @@ -108,29 +121,33 @@ def generate_missing_entities(self, inputs): persons = inputs[self.tax_benefit_system.person_entity.plural] if len(persons) == 1: person_id = next(iter(persons)) - inputs[entity.key] = {entity.roles[0].plural or entity.roles[0].key: [person_id]} + inputs[entity.key] = { + entity.roles[0].plural or entity.roles[0].key: [person_id], + } else: inputs[entity.plural] = { - '{}_{}'.format(entity.key, index): {entity.roles[0].plural or entity.roles[0].key: [person_id]} - for index, person_id in enumerate(persons.keys()) + f"{entity.key}_{index}": { + entity.roles[0].plural or entity.roles[0].key: [person_id], } + for index, person_id in enumerate(persons.keys()) + } return inputs def remove_id(input_dict): - return { - key: value - for (key, value) in input_dict.items() - if key != "id" - } + return {key: value for (key, value) in input_dict.items() if key != "id"} -def main(): +def main() -> None: parser = build_parser() args = parser.parse_args() paths = [os.path.abspath(path) for path in args.path] - tax_benefit_system = build_tax_benefit_system(args.country_package, args.extensions, args.reforms) + tax_benefit_system = build_tax_benefit_system( + args.country_package, + args.extensions, + args.reforms, + ) Migrator(tax_benefit_system).migrate(paths) diff --git a/openfisca_core/scripts/openfisca_command.py b/openfisca_core/scripts/openfisca_command.py index 786b73b35f..d82e0aef61 100644 --- a/openfisca_core/scripts/openfisca_command.py +++ b/openfisca_core/scripts/openfisca_command.py @@ -1,8 +1,9 @@ import argparse -import warnings import sys +import warnings from openfisca_core.scripts import add_tax_benefit_system_arguments + """ Define the `openfisca` command line interface. """ @@ -11,62 +12,157 @@ def get_parser(): parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers(help = 'Available commands', dest = 'command') - subparsers.required = True # Can be added as an argument of add_subparsers in Python 3 + subparsers = parser.add_subparsers(help="Available commands", dest="command") + subparsers.required = ( + True # Can be added as an argument of add_subparsers in Python 3 + ) def build_serve_parser(parser): # Define OpenFisca modules configuration parser = add_tax_benefit_system_arguments(parser) # Define server configuration - parser.add_argument('-p', '--port', action = 'store', help = "port to serve on (use --bind to specify host and port)", type = int) - parser.add_argument('--tracker-url', action = 'store', help = "tracking service url", type = str) - parser.add_argument('--tracker-idsite', action = 'store', help = "tracking service id site", type = int) - parser.add_argument('--tracker-token', action = 'store', help = "tracking service authentication token", type = str) - parser.add_argument('--welcome-message', action = 'store', help = "welcome message users will get when visiting the API root", type = str) - parser.add_argument('-f', '--configuration-file', action = 'store', help = "configuration file", type = str) + parser.add_argument( + "-p", + "--port", + action="store", + help="port to serve on (use --bind to specify host and port)", + type=int, + ) + parser.add_argument( + "--tracker-url", + action="store", + help="tracking service url", + type=str, + ) + parser.add_argument( + "--tracker-idsite", + action="store", + help="tracking service id site", + type=int, + ) + parser.add_argument( + "--tracker-token", + action="store", + help="tracking service authentication token", + type=str, + ) + parser.add_argument( + "--welcome-message", + action="store", + help="welcome message users will get when visiting the API root", + type=str, + ) + parser.add_argument( + "-f", + "--configuration-file", + action="store", + help="configuration file", + type=str, + ) return parser - parser_serve = subparsers.add_parser('serve', help = 'Run the OpenFisca Web API') + parser_serve = subparsers.add_parser("serve", help="Run the OpenFisca Web API") parser_serve = build_serve_parser(parser_serve) def build_test_parser(parser): - parser.add_argument('path', help = "paths (files or directories) of tests to execute", nargs = '+') + parser.add_argument( + "path", + help="paths (files or directories) of tests to execute", + nargs="+", + ) parser = add_tax_benefit_system_arguments(parser) - parser.add_argument('-n', '--name_filter', default = None, help = "partial name of tests to execute. Only tests with the given name_filter in their name, file name, or keywords will be run.") - parser.add_argument('-p', '--pdb', action = 'store_true', default = False, help = "drop into debugger on failures or errors") - parser.add_argument('--performance-graph', '--performance', action = 'store_true', default = False, help = "output a performance graph in a 'performance_graph.html' file") - parser.add_argument('--performance-tables', action = 'store_true', default = False, help = "output performance CSV tables") - parser.add_argument('-v', '--verbose', action = 'store_true', default = False, help = "increase output verbosity") - parser.add_argument('-o', '--only-variables', nargs = '*', default = None, help = "variables to test. If specified, only test the given variables.") - parser.add_argument('-i', '--ignore-variables', nargs = '*', default = None, help = "variables to ignore. If specified, do not test the given variables.") + parser.add_argument( + "-n", + "--name_filter", + default=None, + help="partial name of tests to execute. Only tests with the given name_filter in their name, file name, or keywords will be run.", + ) + parser.add_argument( + "-p", + "--pdb", + action="store_true", + default=False, + help="drop into debugger on failures or errors", + ) + parser.add_argument( + "--performance-graph", + "--performance", + action="store_true", + default=False, + help="output a performance graph in a 'performance_graph.html' file", + ) + parser.add_argument( + "--performance-tables", + action="store_true", + default=False, + help="output performance CSV tables", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="increase output verbosity. If specified, output the entire calculation trace.", + ) + parser.add_argument( + "-a", + "--aggregate", + action="store_true", + default=False, + help="increase output verbosity to aggregate. If specified, output the avg, max, and min values of the calculation trace. This flag has no effect without --verbose.", + ) + parser.add_argument( + "-d", + "--max-depth", + type=int, + default=None, + help="set maximal verbosity depth. If specified, output the calculation trace up to the provided depth. This flag has no effect without --verbose.", + ) + parser.add_argument( + "-o", + "--only-variables", + nargs="*", + default=None, + help="variables to test. If specified, only test the given variables.", + ) + parser.add_argument( + "-i", + "--ignore-variables", + nargs="*", + default=None, + help="variables to ignore. If specified, do not test the given variables.", + ) return parser - parser_test = subparsers.add_parser('test', help = 'Run OpenFisca YAML tests') + parser_test = subparsers.add_parser("test", help="Run OpenFisca YAML tests") parser_test = build_test_parser(parser_test) return parser def main(): - if sys.argv[0].endswith('openfisca-run-test'): - sys.argv[0:1] = ['openfisca', 'test'] + if sys.argv[0].endswith("openfisca-run-test"): + sys.argv[0:1] = ["openfisca", "test"] message = "The 'openfisca-run-test' command has been deprecated in favor of 'openfisca test' since version 25.0, and will be removed in the future." - warnings.warn(message, Warning) + warnings.warn(message, Warning, stacklevel=2) parser = get_parser() args, _ = parser.parse_known_args() - if args.command == 'serve': + if args.command == "serve": from openfisca_web_api.scripts.serve import main + return sys.exit(main(parser)) - if args.command == 'test': + if args.command == "test": from openfisca_core.scripts.run_test import main + return sys.exit(main(parser)) + return None -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/openfisca_core/scripts/remove_fuzzy.py b/openfisca_core/scripts/remove_fuzzy.py index 669af2d01b..a4827aef39 100755 --- a/openfisca_core/scripts/remove_fuzzy.py +++ b/openfisca_core/scripts/remove_fuzzy.py @@ -1,28 +1,26 @@ # remove_fuzzy.py : Remove the fuzzy attribute in xml files and add END tags. # See https://github.com/openfisca/openfisca-core/issues/437 -import re import datetime +import re import sys -import numpy as np -assert(len(sys.argv) == 2) +import numpy + +assert len(sys.argv) == 2 filename = sys.argv[1] -with open(filename, 'r') as f: +with open(filename) as f: lines = f.readlines() # Remove fuzzy -lines_2 = [ - line.replace(' fuzzy="true"', '') - for line in lines - ] +lines_2 = [line.replace(' fuzzy="true"', "") for line in lines] -regex_indent = r'^(\s*)\n$' -bool_code = [ - bool(re.search(regex_code, line)) - for line in lines_5 - ] +bool_code = [bool(re.search(regex_code, line)) for line in lines_5] -bool_code_end = [ - bool(re.search(regex_code_end, line)) - for line in lines_5 - ] +bool_code_end = [bool(re.search(regex_code_end, line)) for line in lines_5] list_value = [] for line in lines_5: @@ -227,19 +194,19 @@ to_remove = [] for i in range(len(lines_5) - 1): - if (list_value[i] is not None) and (list_value[i + 1] is not None) and (list_value[i] == list_value[i + 1]): + if ( + (list_value[i] is not None) + and (list_value[i + 1] is not None) + and (list_value[i] == list_value[i + 1]) + ): to_remove.append(i) to_remove_set = set(to_remove) -lines_6 = [ - line - for j, line in enumerate(lines_5) - if j not in to_remove_set - ] +lines_6 = [line for j, line in enumerate(lines_5) if j not in to_remove_set] # Write -with open(filename, 'w') as f: +with open(filename, "w") as f: for line in lines_6: f.write(line) diff --git a/openfisca_core/scripts/run_test.py b/openfisca_core/scripts/run_test.py index 77c4140899..458dc7e50e 100644 --- a/openfisca_core/scripts/run_test.py +++ b/openfisca_core/scripts/run_test.py @@ -1,28 +1,35 @@ -# -*- coding: utf-8 -*- - import logging -import sys import os +import sys -from openfisca_core.tools.test_runner import run_tests from openfisca_core.scripts import build_tax_benefit_system +from openfisca_core.tools.test_runner import run_tests -def main(parser): +def main(parser) -> None: args = parser.parse_args() - logging.basicConfig(level = logging.DEBUG if args.verbose else logging.WARNING, stream = sys.stdout) + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.WARNING, + stream=sys.stdout, + ) - tax_benefit_system = build_tax_benefit_system(args.country_package, args.extensions, args.reforms) + tax_benefit_system = build_tax_benefit_system( + args.country_package, + args.extensions, + args.reforms, + ) options = { - 'pdb': args.pdb, - 'performance_graph': args.performance_graph, - 'performance_tables': args.performance_tables, - 'verbose': args.verbose, - 'name_filter': args.name_filter, - 'only_variables': args.only_variables, - 'ignore_variables': args.ignore_variables, - } + "pdb": args.pdb, + "performance_graph": args.performance_graph, + "performance_tables": args.performance_tables, + "verbose": args.verbose, + "aggregate": args.aggregate, + "max_depth": args.max_depth, + "name_filter": args.name_filter, + "only_variables": args.only_variables, + "ignore_variables": args.ignore_variables, + } paths = [os.path.abspath(path) for path in args.path] sys.exit(run_tests(tax_benefit_system, paths, options)) diff --git a/openfisca_core/scripts/simulation_generator.py b/openfisca_core/scripts/simulation_generator.py index 489f42356f..eca2fa30d1 100644 --- a/openfisca_core/scripts/simulation_generator.py +++ b/openfisca_core/scripts/simulation_generator.py @@ -1,30 +1,32 @@ -import numpy as np - import random + +import numpy + from openfisca_core.simulations import Simulation def make_simulation(tax_benefit_system, nb_persons, nb_groups, **kwargs): - """ - Generate a simulation containing nb_persons persons spread in nb_groups groups. + """Generate a simulation containing nb_persons persons spread in nb_groups groups. - Example: + Example: + >>> from openfisca_core.scripts.simulation_generator import make_simulation + >>> from openfisca_france import CountryTaxBenefitSystem + >>> tbs = CountryTaxBenefitSystem() + >>> simulation = make_simulation( + ... tbs, 400, 100 + ... ) # Create a simulation with 400 persons, spread among 100 families + >>> simulation.calculate("revenu_disponible", 2017) - >>> from openfisca_core.scripts.simulation_generator import make_simulation - >>> from openfisca_france import CountryTaxBenefitSystem - >>> tbs = CountryTaxBenefitSystem() - >>> simulation = make_simulation(tbs, 400, 100) # Create a simulation with 400 persons, spread among 100 families - >>> simulation.calculate('revenu_disponible', 2017) """ - simulation = Simulation(tax_benefit_system = tax_benefit_system, **kwargs) - simulation.persons.ids = np.arange(nb_persons) + simulation = Simulation(tax_benefit_system=tax_benefit_system, **kwargs) + simulation.persons.ids = numpy.arange(nb_persons) simulation.persons.count = nb_persons - adults = [0] + sorted(random.sample(range(1, nb_persons), nb_groups - 1)) + adults = [0, *sorted(random.sample(range(1, nb_persons), nb_groups - 1))] - members_entity_id = np.empty(nb_persons, dtype = int) + members_entity_id = numpy.empty(nb_persons, dtype=int) # A legacy role is an index that every person within an entity has. For instance, the 'demandeur' has legacy role 0, the 'conjoint' 1, the first 'child' 2, the second 3, etc. - members_legacy_role = np.empty(nb_persons, dtype = int) + members_legacy_role = numpy.empty(nb_persons, dtype=int) id_group = -1 for id_person in range(nb_persons): @@ -40,27 +42,49 @@ def make_simulation(tax_benefit_system, nb_persons, nb_groups, **kwargs): if not entity.is_person: entity.members_entity_id = members_entity_id entity.count = nb_groups - entity.members_role = np.where(members_legacy_role == 0, entity.flattened_roles[0], entity.flattened_roles[-1]) + entity.members_role = numpy.where( + members_legacy_role == 0, + entity.flattened_roles[0], + entity.flattened_roles[-1], + ) return simulation -def randomly_init_variable(simulation, variable_name, period, max_value, condition = None): - """ - Initialise a variable with random values (from 0 to max_value) for the given period. - If a condition vector is provided, only set the value of persons or groups for which condition is True. +def randomly_init_variable( + simulation, + variable_name: str, + period, + max_value, + condition=None, +) -> None: + """Initialise a variable with random values (from 0 to max_value) for the given period. + If a condition vector is provided, only set the value of persons or groups for which condition is True. - Example: + Example: + >>> from openfisca_core.scripts.simulation_generator import ( + ... make_simulation, + ... randomly_init_variable, + ... ) + >>> from openfisca_france import CountryTaxBenefitSystem + >>> tbs = CountryTaxBenefitSystem() + >>> simulation = make_simulation( + ... tbs, 400, 100 + ... ) # Create a simulation with 400 persons, spread among 100 families + >>> randomly_init_variable( + ... simulation, + ... "salaire_net", + ... 2017, + ... max_value=50000, + ... condition=simulation.persons.has_role(simulation.famille.DEMANDEUR), + ... ) # Randomly set a salaire_net for all persons between 0 and 50000? + >>> simulation.calculate("revenu_disponible", 2017) - >>> from openfisca_core.scripts.simulation_generator import make_simulation, randomly_init_variable - >>> from openfisca_france import CountryTaxBenefitSystem - >>> tbs = CountryTaxBenefitSystem() - >>> simulation = make_simulation(tbs, 400, 100) # Create a simulation with 400 persons, spread among 100 families - >>> randomly_init_variable(simulation, 'salaire_net', 2017, max_value = 50000, condition = simulation.persons.has_role(simulation.famille.DEMANDEUR)) # Randomly set a salaire_net for all persons between 0 and 50000? - >>> simulation.calculate('revenu_disponible', 2017) - """ + """ if condition is None: condition = True variable = simulation.tax_benefit_system.get_variable(variable_name) population = simulation.get_variable_population(variable_name) - value = (np.random.rand(population.count) * max_value * condition).astype(variable.dtype) + value = (numpy.random.rand(population.count) * max_value * condition).astype( + variable.dtype, + ) simulation.set_input(variable_name, period, value) diff --git a/openfisca_core/simulations/__init__.py b/openfisca_core/simulations/__init__.py index 2f7a9c6d51..9ab10f81a7 100644 --- a/openfisca_core/simulations/__init__.py +++ b/openfisca_core/simulations/__init__.py @@ -21,6 +21,25 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .helpers import calculate_output_add, calculate_output_divide, check_type, transform_to_strict_syntax # noqa: F401 -from .simulation import Simulation # noqa: F401 -from .simulation_builder import SimulationBuilder # noqa: F401 +from openfisca_core.errors import CycleError, NaNCreationError, SpiralError + +from .helpers import ( + calculate_output_add, + calculate_output_divide, + check_type, + transform_to_strict_syntax, +) +from .simulation import Simulation +from .simulation_builder import SimulationBuilder + +__all__ = [ + "CycleError", + "NaNCreationError", + "Simulation", + "SimulationBuilder", + "SpiralError", + "calculate_output_add", + "calculate_output_divide", + "check_type", + "transform_to_strict_syntax", +] diff --git a/openfisca_core/simulations/_build_default_simulation.py b/openfisca_core/simulations/_build_default_simulation.py new file mode 100644 index 0000000000..adc7cf4783 --- /dev/null +++ b/openfisca_core/simulations/_build_default_simulation.py @@ -0,0 +1,159 @@ +"""This module contains the _BuildDefaultSimulation class.""" + +from typing import Union +from typing_extensions import Self + +import numpy + +from .simulation import Simulation +from .typing import Entity, Population, TaxBenefitSystem + + +class _BuildDefaultSimulation: + """Build a default simulation. + + Args: + tax_benefit_system(TaxBenefitSystem): The tax-benefit system. + count(int): The number of periods. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 1 + >>> builder = ( + ... _BuildDefaultSimulation(tax_benefit_system, count) + ... .add_count() + ... .add_ids() + ... .add_members_entity_id() + ... ) + + >>> builder.count + 1 + + >>> sorted(builder.populations.keys()) + ['dog', 'pack'] + + >>> sorted(builder.simulation.populations.keys()) + ['dog', 'pack'] + + """ + + #: The number of Population. + count: int + + #: The built populations. + populations: dict[str, Union[Population[Entity]]] + + #: The built simulation. + simulation: Simulation + + def __init__(self, tax_benefit_system: TaxBenefitSystem, count: int) -> None: + self.count = count + self.populations = tax_benefit_system.instantiate_entities() + self.simulation = Simulation(tax_benefit_system, self.populations) + + def add_count(self) -> Self: + """Add the number of Population to the simulation. + + Returns: + _BuildDefaultSimulation: The builder. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 2 + >>> builder = _BuildDefaultSimulation(tax_benefit_system, count) + + >>> builder.add_count() + <..._BuildDefaultSimulation object at ...> + + >>> builder.populations["dog"].count + 2 + + >>> builder.populations["pack"].count + 2 + + """ + for population in self.populations.values(): + population.count = self.count + + return self + + def add_ids(self) -> Self: + """Add the populations ids to the simulation. + + Returns: + _BuildDefaultSimulation: The builder. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 2 + >>> builder = _BuildDefaultSimulation(tax_benefit_system, count) + + >>> builder.add_ids() + <..._BuildDefaultSimulation object at ...> + + >>> builder.populations["dog"].ids + array([0, 1]) + + >>> builder.populations["pack"].ids + array([0, 1]) + + """ + for population in self.populations.values(): + population.ids = numpy.array(range(self.count)) + + return self + + def add_members_entity_id(self) -> Self: + """Add ??? + + Each SingleEntity has its own GroupEntity. + + Returns: + _BuildDefaultSimulation: The builder. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 2 + >>> builder = _BuildDefaultSimulation(tax_benefit_system, count) + + >>> builder.add_members_entity_id() + <..._BuildDefaultSimulation object at ...> + + >>> population = builder.populations["pack"] + + >>> hasattr(population, "members_entity_id") + True + + >>> population.members_entity_id + array([0, 1]) + + """ + for population in self.populations.values(): + if hasattr(population, "members_entity_id"): + population.members_entity_id = numpy.array(range(self.count)) + + return self diff --git a/openfisca_core/simulations/_build_from_variables.py b/openfisca_core/simulations/_build_from_variables.py new file mode 100644 index 0000000000..20f49ce113 --- /dev/null +++ b/openfisca_core/simulations/_build_from_variables.py @@ -0,0 +1,230 @@ +"""This module contains the _BuildFromVariables class.""" + +from __future__ import annotations + +from typing_extensions import Self + +from openfisca_core import errors + +from ._build_default_simulation import _BuildDefaultSimulation +from ._type_guards import is_variable_dated +from .simulation import Simulation +from .typing import Entity, Population, TaxBenefitSystem, Variables + + +class _BuildFromVariables: + """Build a simulation from variables. + + Args: + tax_benefit_system(TaxBenefitSystem): The tax-benefit system. + params(Variables): The simulation parameters. + + Examples: + >>> from openfisca_core import entities, periods, taxbenefitsystems, variables + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + + >>> class salary(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = single_entity + ... value_type = int + + >>> class taxes(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = group_entity + ... value_type = int + + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> tax_benefit_system.load_variable(salary) + <...salary object at ...> + >>> tax_benefit_system.load_variable(taxes) + <...taxes object at ...> + >>> period = "2023-12" + >>> variables = {"salary": {period: 10000}, "taxes": 5000} + >>> builder = ( + ... _BuildFromVariables(tax_benefit_system, variables, period) + ... .add_dated_values() + ... .add_undated_values() + ... ) + + >>> dogs = builder.populations["dog"].get_holder("salary") + >>> dogs.get_array(period) + array([10000], dtype=int32) + + >>> pack = builder.populations["pack"].get_holder("taxes") + >>> pack.get_array(period) + array([5000], dtype=int32) + + """ + + #: The number of Population. + count: int + + #: The Simulation's default period. + default_period: str | None + + #: The built populations. + populations: dict[str, Population[Entity]] + + #: The built simulation. + simulation: Simulation + + #: The simulation parameters. + variables: Variables + + def __init__( + self, + tax_benefit_system: TaxBenefitSystem, + params: Variables, + default_period: str | None = None, + ) -> None: + self.count = _person_count(params) + + default_builder = ( + _BuildDefaultSimulation(tax_benefit_system, self.count) + .add_count() + .add_ids() + .add_members_entity_id() + ) + + self.variables = params + self.simulation = default_builder.simulation + self.populations = default_builder.populations + self.default_period = default_period + + def add_dated_values(self) -> Self: + """Add the dated input values to the Simulation. + + Returns: + _BuildFromVariables: The builder. + + Examples: + >>> from openfisca_core import entities, periods, taxbenefitsystems, variables + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + + + >>> class salary(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = single_entity + ... value_type = int + + >>> class taxes(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = group_entity + ... value_type = int + + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> tax_benefit_system.load_variable(salary) + <...salary object at ...> + >>> tax_benefit_system.load_variable(taxes) + <...taxes object at ...> + >>> period = "2023-12" + >>> variables = {"salary": {period: 10000}, "taxes": 5000} + >>> builder = _BuildFromVariables(tax_benefit_system, variables) + >>> builder.add_dated_values() + <..._BuildFromVariables object at ...> + + >>> dogs = builder.populations["dog"].get_holder("salary") + >>> dogs.get_array(period) + array([10000], dtype=int32) + + >>> pack = builder.populations["pack"].get_holder("taxes") + >>> pack.get_array(period) + + """ + for variable, value in self.variables.items(): + if is_variable_dated(dated_variable := value): + for period, dated_value in dated_variable.items(): + self.simulation.set_input(variable, period, dated_value) + + return self + + def add_undated_values(self) -> Self: + """Add the undated input values to the Simulation. + + Returns: + _BuildFromVariables: The builder. + + Raises: + SituationParsingError: If there is not a default period set. + + Examples: + >>> from openfisca_core import entities, periods, taxbenefitsystems, variables + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + + >>> class salary(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = single_entity + ... value_type = int + + >>> class taxes(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = group_entity + ... value_type = int + + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> tax_benefit_system.load_variable(salary) + <...salary object at ...> + >>> tax_benefit_system.load_variable(taxes) + <...taxes object at ...> + >>> period = "2023-12" + >>> variables = {"salary": {period: 10000}, "taxes": 5000} + >>> builder = _BuildFromVariables(tax_benefit_system, variables) + >>> builder.add_undated_values() + Traceback (most recent call last): + openfisca_core.errors.situation_parsing_error.SituationParsingError + >>> builder.default_period = period + >>> builder.add_undated_values() + <..._BuildFromVariables object at ...> + + >>> dogs = builder.populations["dog"].get_holder("salary") + >>> dogs.get_array(period) + + >>> pack = builder.populations["pack"].get_holder("taxes") + >>> pack.get_array(period) + array([5000], dtype=int32) + + """ + for variable, value in self.variables.items(): + if not is_variable_dated(undated_value := value): + if (period := self.default_period) is None: + message = ( + "Can't deal with type: expected object. Input " + "variables should be set for specific periods. For " + "instance: " + " {'salary': {'2017-01': 2000, '2017-02': 2500}}" + " {'birth_date': {'ETERNITY': '1980-01-01'}}" + ) + + raise errors.SituationParsingError([variable], message) + + self.simulation.set_input(variable, period, undated_value) + + return self + + +def _person_count(params: Variables) -> int: + try: + first_value = next(iter(params.values())) + + if isinstance(first_value, dict): + first_value = next(iter(first_value.values())) + + if isinstance(first_value, str): + return 1 + + return len(first_value) + + except Exception: + return 1 diff --git a/openfisca_core/simulations/_type_guards.py b/openfisca_core/simulations/_type_guards.py new file mode 100644 index 0000000000..990248213d --- /dev/null +++ b/openfisca_core/simulations/_type_guards.py @@ -0,0 +1,298 @@ +"""Type guards to help type narrowing simulation parameters.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing_extensions import TypeGuard + +from .typing import ( + Axes, + DatedVariable, + FullySpecifiedEntities, + ImplicitGroupEntities, + Params, + UndatedVariable, + Variables, +) + + +def are_entities_fully_specified( + params: Params, + items: Iterable[str], +) -> TypeGuard[FullySpecifiedEntities]: + """Check if the params contain fully specified entities. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of entities in plural form. + + Returns: + bool: True if the params contain fully specified entities. + + Examples: + >>> entities = {"persons", "households"} + + >>> params = { + ... "axes": [ + ... [ + ... { + ... "count": 2, + ... "max": 3000, + ... "min": 0, + ... "name": "rent", + ... "period": "2018-11", + ... } + ... ] + ... ], + ... "households": { + ... "housea": {"parents": ["Alicia", "Javier"]}, + ... "houseb": {"parents": ["Tom"]}, + ... }, + ... "persons": {"Alicia": {"salary": {"2018-11": 0}}, "Javier": {}, "Tom": {}}, + ... } + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}}} + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": 2000}}}} + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {"salary": 12000} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {} + + >>> are_entities_fully_specified(params, entities) + False + + """ + if not params: + return False + + return all(key in items for key in params if key != "axes") + + +def are_entities_short_form( + params: Params, + items: Iterable[str], +) -> TypeGuard[ImplicitGroupEntities]: + """Check if the params contain short form entities. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of entities in singular form. + + Returns: + bool: True if the params contain short form entities. + + Examples: + >>> entities = {"person", "household"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + False + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": 2000}}}} + + >>> are_entities_short_form(params, entities) + False + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = { + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"household": {"parents": "Javier"}} + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_short_form(params, entities) + False + + >>> params = {"salary": 12000} + + >>> are_entities_short_form(params, entities) + False + + >>> params = {} + + >>> are_entities_short_form(params, entities) + False + + """ + return bool(set(params).intersection(items)) + + +def are_entities_specified( + params: Params, + items: Iterable[str], +) -> TypeGuard[Variables]: + """Check if the params contains entities at all. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of variables. + + Returns: + bool: True if the params does not contain variables at the root level. + + Examples: + >>> variables = {"salary"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_specified(params, variables) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}}} + + >>> are_entities_specified(params, variables) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": 2000}}}} + + >>> are_entities_specified(params, variables) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_specified(params, variables) + True + + >>> params = {"salary": {"2016-10": [12000, 13000]}} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": [12000, 13000]} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": 12000} + + >>> are_entities_specified(params, variables) + False + + >>> params = {} + + >>> are_entities_specified(params, variables) + False + + """ + if not params: + return False + + return not any(key in items for key in params) + + +def has_axes(params: Params) -> TypeGuard[Axes]: + """Check if the params contains axes. + + Args: + params(Params): Simulation parameters. + + Returns: + bool: True if the params contain axes. + + Examples: + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> has_axes(params) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}}} + + >>> has_axes(params) + False + + """ + return params.get("axes", None) is not None + + +def is_variable_dated( + variable: DatedVariable | UndatedVariable, +) -> TypeGuard[DatedVariable]: + """Check if the variable is dated. + + Args: + variable(DatedVariable | UndatedVariable): A variable. + + Returns: + bool: True if the variable is dated. + + Examples: + >>> variable = {"2018-11": [2000, 3000]} + + >>> is_variable_dated(variable) + True + + >>> variable = {"2018-11": 2000} + + >>> is_variable_dated(variable) + True + + >>> variable = 2000 + + >>> is_variable_dated(variable) + False + + """ + return isinstance(variable, dict) diff --git a/openfisca_core/simulations/helpers.py b/openfisca_core/simulations/helpers.py index 683a4106b9..7929c5beda 100644 --- a/openfisca_core/simulations/helpers.py +++ b/openfisca_core/simulations/helpers.py @@ -1,27 +1,106 @@ -from openfisca_core.errors import SituationParsingError +from collections.abc import Iterable +from openfisca_core import errors -def calculate_output_add(simulation, variable_name, period): +from .typing import ParamsWithoutAxes + + +def calculate_output_add(simulation, variable_name: str, period): return simulation.calculate_add(variable_name, period) -def calculate_output_divide(simulation, variable_name, period): +def calculate_output_divide(simulation, variable_name: str, period): return simulation.calculate_divide(variable_name, period) -def check_type(input, input_type, path = None): +def check_type(input, input_type, path=None) -> None: json_type_map = { dict: "Object", list: "Array", str: "String", - } + } if path is None: path = [] if not isinstance(input, input_type): - raise SituationParsingError(path, - "Invalid type: must be of type '{}'.".format(json_type_map[input_type])) + raise errors.SituationParsingError( + path, + f"Invalid type: must be of type '{json_type_map[input_type]}'.", + ) + + +def check_unexpected_entities( + params: ParamsWithoutAxes, + entities: Iterable[str], +) -> None: + """Check if the input contains entities that are not in the system. + + Args: + params(ParamsWithoutAxes): Simulation parameters. + entities(Iterable[str]): List of entities in plural form. + + Raises: + SituationParsingError: If there are entities that are not in the system. + + Examples: + >>> entities = {"persons", "households"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... } + + >>> check_unexpected_entities(params, entities) + + >>> params = {"dogs": {"Bart": {"damages": {"2018-11": 2000}}}} + + >>> check_unexpected_entities(params, entities) + Traceback (most recent call last): + openfisca_core.errors.situation_parsing_error.SituationParsingError + + """ + if has_unexpected_entities(params, entities): + unexpected_entities = [entity for entity in params if entity not in entities] + + message = ( + "Some entities in the situation are not defined in the loaded tax " + "and benefit system. " + f"These entities are not found: {', '.join(unexpected_entities)}. " + f"The defined entities are: {', '.join(entities)}." + ) + + raise errors.SituationParsingError([unexpected_entities[0]], message) + + +def has_unexpected_entities(params: ParamsWithoutAxes, entities: Iterable[str]) -> bool: + """Check if the input contains entities that are not in the system. + + Args: + params(ParamsWithoutAxes): Simulation parameters. + entities(Iterable[str]): List of entities in plural form. + + Returns: + bool: True if the input contains entities that are not in the system. + + Examples: + >>> entities = {"persons", "households"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... } + + >>> has_unexpected_entities(params, entities) + False + + >>> params = {"dogs": {"Bart": {"damages": {"2018-11": 2000}}}} + + >>> has_unexpected_entities(params, entities) + True + + """ + return any(entity for entity in params if entity not in entities) def transform_to_strict_syntax(data): @@ -30,16 +109,3 @@ def transform_to_strict_syntax(data): if isinstance(data, list): return [str(item) if isinstance(item, int) else item for item in data] return data - - -def _get_person_count(input_dict): - try: - first_value = next(iter(input_dict.values())) - if isinstance(first_value, dict): - first_value = next(iter(first_value.values())) - if isinstance(first_value, str): - return 1 - - return len(first_value) - except Exception: - return 1 diff --git a/openfisca_core/simulations/simulation.py b/openfisca_core/simulations/simulation.py index 5dd2694292..df7716c7be 100644 --- a/openfisca_core/simulations/simulation.py +++ b/openfisca_core/simulations/simulation.py @@ -1,28 +1,42 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import NamedTuple + +from openfisca_core.types import ( + CorePopulation as Population, + TaxBenefitSystem, + Variable, +) + import tempfile import warnings import numpy -from openfisca_core import commons, periods -from openfisca_core.errors import CycleError, SpiralError -from openfisca_core.indexed_enums import Enum, EnumArray -from openfisca_core.periods import Period -from openfisca_core.tracers import FullTracer, SimpleTracer, TracingParameterNodeAtInstant -from openfisca_core.warnings import TempfileWarning +from openfisca_core import ( + commons, + errors, + indexed_enums, + periods, + tracers, + warnings as core_warnings, +) class Simulation: - """ - Represents a simulation, and handles the calculation logic - """ + """Represents a simulation, and handles the calculation logic.""" + + tax_benefit_system: TaxBenefitSystem + populations: dict[str, Population] + invalidated_caches: set[Cache] def __init__( - self, - tax_benefit_system, - populations - ): - """ - This constructor is reserved for internal use; see :any:`SimulationBuilder`, + self, + tax_benefit_system: TaxBenefitSystem, + populations: Mapping[str, Population], + ) -> None: + """This constructor is reserved for internal use; see :any:`SimulationBuilder`, which is the preferred way to obtain a Simulation initialized with a consistent set of Entities. """ @@ -38,11 +52,11 @@ def __init__( self.debug = False self.trace = False - self.tracer = SimpleTracer() + self.tracer = tracers.SimpleTracer() self.opt_out_cache = False # controls the spirals detection; check for performance impact if > 1 - self.max_spiral_loops = 1 + self.max_spiral_loops: int = 1 self.memory_config = None self._data_storage_dir = None @@ -51,42 +65,45 @@ def trace(self): return self._trace @trace.setter - def trace(self, trace): + def trace(self, trace) -> None: self._trace = trace if trace: - self.tracer = FullTracer() + self.tracer = tracers.FullTracer() else: - self.tracer = SimpleTracer() + self.tracer = tracers.SimpleTracer() - def link_to_entities_instances(self): - for _key, entity_instance in self.populations.items(): + def link_to_entities_instances(self) -> None: + for entity_instance in self.populations.values(): entity_instance.simulation = self - def create_shortcuts(self): - for _key, population in self.populations.items(): + def create_shortcuts(self) -> None: + for population in self.populations.values(): # create shortcut simulation.person and simulation.household (for instance) setattr(self, population.entity.key, population) @property def data_storage_dir(self): - """ - Temporary folder used to store intermediate calculation data in case the memory is saturated - """ + """Temporary folder used to store intermediate calculation data in case the memory is saturated.""" if self._data_storage_dir is None: - self._data_storage_dir = tempfile.mkdtemp(prefix = "openfisca_") + self._data_storage_dir = tempfile.mkdtemp(prefix="openfisca_") message = [ - ("Intermediate results will be stored on disk in {} in case of memory overflow.").format(self._data_storage_dir), - "You should remove this directory once you're done with your simulation." - ] - warnings.warn(" ".join(message), TempfileWarning) + ( + f"Intermediate results will be stored on disk in {self._data_storage_dir} in case of memory overflow." + ), + "You should remove this directory once you're done with your simulation.", + ] + warnings.warn( + " ".join(message), + core_warnings.TempfileWarning, + stacklevel=2, + ) return self._data_storage_dir # ----- Calculation methods ----- # - def calculate(self, variable_name, period): + def calculate(self, variable_name: str, period): """Calculate ``variable_name`` for ``period``.""" - - if period is not None and not isinstance(period, Period): + if period is not None and not isinstance(period, periods.Period): period = periods.period(period) self.tracer.record_calculation_start(variable_name, period) @@ -100,15 +117,22 @@ def calculate(self, variable_name, period): self.tracer.record_calculation_end() self.purge_cache_of_invalid_values() - def _calculate(self, variable_name, period: Period): - """ - Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists. + def _calculate(self, variable_name: str, period: periods.Period): + """Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists. :returns: A numpy array containing the result of the calculation """ + variable: Variable | None + population = self.get_variable_population(variable_name) holder = population.get_holder(variable_name) - variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True) + variable = self.tax_benefit_system.get_variable( + variable_name, + check_existence=True, + ) + + if variable is None: + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) self._check_period_consistency(period, variable) @@ -131,75 +155,163 @@ def _calculate(self, variable_name, period: Period): array = self._cast_formula_result(array, variable) holder.put_in_cache(array, period) - except SpiralError: + except errors.SpiralError: array = holder.default_array() return array - def purge_cache_of_invalid_values(self): + def purge_cache_of_invalid_values(self) -> None: # We wait for the end of calculate(), signalled by an empty stack, before purging the cache if self.tracer.stack: return - for (_name, _period) in self.invalidated_caches: + for _name, _period in self.invalidated_caches: holder = self.get_holder(_name) holder.delete_arrays(_period) self.invalidated_caches = set() - def calculate_add(self, variable_name, period): - variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True) + def calculate_add(self, variable_name: str, period): + variable: Variable | None + + variable = self.tax_benefit_system.get_variable( + variable_name, + check_existence=True, + ) + + if variable is None: + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) - if period is not None and not isinstance(period, Period): + if period is not None and not isinstance(period, periods.Period): period = periods.period(period) # Check that the requested period matches definition_period - if periods.unit_weight(variable.definition_period) > periods.unit_weight(period.unit): - raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' can only be computed for {2}-long periods. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format( - variable.name, - period, - variable.definition_period - )) - - if variable.definition_period not in [periods.DAY, periods.MONTH, periods.YEAR]: - raise ValueError("Unable to sum constant variable '{}' over period {}: only variables defined daily, monthly, or yearly can be summed over time.".format( - variable.name, - period)) + if periods.unit_weight(variable.definition_period) > periods.unit_weight( + period.unit, + ): + msg = ( + f"Unable to compute variable '{variable.name}' for period " + f"{period}: '{variable.name}' can only be computed for " + f"{variable.definition_period}-long periods. You can use the " + f"DIVIDE option to get an estimate of {variable.name}." + ) + raise ValueError( + msg, + ) + + if variable.definition_period not in ( + periods.DateUnit.isoformat + periods.DateUnit.isocalendar + ): + msg = ( + f"Unable to ADD constant variable '{variable.name}' over " + f"the period {period}: eternal variables can't be summed " + "over time." + ) + raise ValueError( + msg, + ) return sum( self.calculate(variable_name, sub_period) for sub_period in period.get_subperiods(variable.definition_period) - ) + ) + + def calculate_divide(self, variable_name: str, period): + variable: Variable | None - def calculate_divide(self, variable_name, period): - variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True) + variable = self.tax_benefit_system.get_variable( + variable_name, + check_existence=True, + ) + + if variable is None: + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) - if period is not None and not isinstance(period, Period): + if period is not None and not isinstance(period, periods.Period): period = periods.period(period) - # Check that the requested period matches definition_period - if variable.definition_period != periods.YEAR: - raise ValueError("Unable to divide the value of '{}' over time on period {}: only variables defined yearly can be divided over time.".format( - variable_name, - period)) + if ( + periods.unit_weight(variable.definition_period) + < periods.unit_weight(period.unit) + or period.size > 1 + ): + msg = ( + f"Can't calculate variable '{variable.name}' for period " + f"{period}: '{variable.name}' can only be computed for " + f"{variable.definition_period}-long periods. You can use the " + f"ADD option to get an estimate of {variable.name}." + ) + raise ValueError( + msg, + ) - if period.size != 1: - raise ValueError("DIVIDE option can only be used for a one-year or a one-month requested period") + if variable.definition_period not in ( + periods.DateUnit.isoformat + periods.DateUnit.isocalendar + ): + msg = ( + f"Unable to DIVIDE constant variable '{variable.name}' over " + f"the period {period}: eternal variables can't be divided " + "over time." + ) + raise ValueError( + msg, + ) - if period.unit == periods.MONTH: - computation_period = period.this_year - return self.calculate(variable_name, period = computation_period) / 12. - elif period.unit == periods.YEAR: - return self.calculate(variable_name, period) + if ( + period.unit + not in (periods.DateUnit.isoformat + periods.DateUnit.isocalendar) + or period.size != 1 + ): + msg = ( + f"Unable to DIVIDE constant variable '{variable.name}' over " + f"the period {period}: eternal variables can't be used " + "as a denominator to divide a variable over time." + ) + raise ValueError( + msg, + ) - raise ValueError("Unable to divide the value of '{}' to match period {}.".format( - variable_name, - period)) + if variable.definition_period == periods.DateUnit.YEAR: + calculation_period = period.this_year - def calculate_output(self, variable_name, period): - """ - Calculate the value of a variable using the ``calculate_output`` attribute of the variable. - """ + elif variable.definition_period == periods.DateUnit.MONTH: + calculation_period = period.first_month + + elif variable.definition_period == periods.DateUnit.DAY: + calculation_period = period.first_day + + elif variable.definition_period == periods.DateUnit.WEEK: + calculation_period = period.first_week + + else: + calculation_period = period.first_weekday - variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True) + if period.unit == periods.DateUnit.YEAR: + denominator = calculation_period.size_in_years + + elif period.unit == periods.DateUnit.MONTH: + denominator = calculation_period.size_in_months + + elif period.unit == periods.DateUnit.DAY: + denominator = calculation_period.size_in_days + + elif period.unit == periods.DateUnit.WEEK: + denominator = calculation_period.size_in_weeks + + else: + denominator = calculation_period.size_in_weekdays + + return self.calculate(variable_name, calculation_period) / denominator + + def calculate_output(self, variable_name: str, period): + """Calculate the value of a variable using the ``calculate_output`` attribute of the variable.""" + variable: Variable | None + + variable = self.tax_benefit_system.get_variable( + variable_name, + check_existence=True, + ) + + if variable is None: + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) if variable.calculate_output is None: return self.calculate(variable_name, period) @@ -207,16 +319,13 @@ def calculate_output(self, variable_name, period): return variable.calculate_output(self, variable_name, period) def trace_parameters_at_instant(self, formula_period): - return TracingParameterNodeAtInstant( + return tracers.TracingParameterNodeAtInstant( self.tax_benefit_system.get_parameters_at_instant(formula_period), - self.tracer - ) + self.tracer, + ) def _run_formula(self, variable, population, period): - """ - Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``. - """ - + """Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``.""" formula = variable.get_formula(period) if formula is None: return None @@ -233,34 +342,49 @@ def _run_formula(self, variable, population, period): return array - def _check_period_consistency(self, period, variable): - """ - Check that a period matches the variable definition_period - """ - if variable.definition_period == periods.ETERNITY: + def _check_period_consistency(self, period, variable) -> None: + """Check that a period matches the variable definition_period.""" + if variable.definition_period == periods.DateUnit.ETERNITY: return # For variables which values are constant in time, all periods are accepted - if variable.definition_period == periods.MONTH and period.unit != periods.MONTH: - raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole month. You can use the ADD option to sum '{0}' over the requested period, or change the requested period to 'period.first_month'.".format( - variable.name, - period - )) + if ( + variable.definition_period == periods.DateUnit.YEAR + and period.unit != periods.DateUnit.YEAR + ): + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {variable.name} by dividing the yearly value by 12, or change the requested period to 'period.this_year'." + raise ValueError( + msg, + ) - if variable.definition_period == periods.YEAR and period.unit != periods.YEAR: - raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format( - variable.name, - period - )) + if ( + variable.definition_period == periods.DateUnit.MONTH + and period.unit != periods.DateUnit.MONTH + ): + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole month. You can use the ADD option to sum '{variable.name}' over the requested period, or change the requested period to 'period.first_month'." + raise ValueError( + msg, + ) + + if ( + variable.definition_period == periods.DateUnit.WEEK + and period.unit != periods.DateUnit.WEEK + ): + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole week. You can use the ADD option to sum '{variable.name}' over the requested period, or change the requested period to 'period.first_week'." + raise ValueError( + msg, + ) if period.size != 1: - raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole {2}. You can use the ADD option to sum '{0}' over the requested period.".format( - variable.name, - period, - 'month' if variable.definition_period == periods.MONTH else 'year' - )) + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole {variable.definition_period}. You can use the ADD option to sum '{variable.name}' over the requested period." + raise ValueError( + msg, + ) def _cast_formula_result(self, value, variable): - if variable.value_type == Enum and not isinstance(value, EnumArray): + if variable.value_type == indexed_enums.Enum and not isinstance( + value, + indexed_enums.EnumArray, + ): return variable.possible_values.encode(value) if not isinstance(value, numpy.ndarray): @@ -274,165 +398,190 @@ def _cast_formula_result(self, value, variable): # ----- Handle circular dependencies in a calculation ----- # - def _check_for_cycle(self, variable: str, period): - """ - Raise an exception in the case of a circular definition, where evaluating a variable for + def _check_for_cycle(self, variable: str, period) -> None: + """Raise an exception in the case of a circular definition, where evaluating a variable for a given period loops around to evaluating the same variable/period pair. Also guards, as a heuristic, against "quasicircles", where the evaluation of a variable at a period involves the same variable at a different period. """ # The last frame is the current calculation, so it should be ignored from cycle detection - previous_periods = [frame['period'] for frame in self.tracer.stack[:-1] if frame['name'] == variable] + previous_periods = [ + frame["period"] + for frame in self.tracer.stack[:-1] + if frame["name"] == variable + ] if period in previous_periods: - raise CycleError("Circular definition detected on formula {}@{}".format(variable, period)) + msg = f"Circular definition detected on formula {variable}@{period}" + raise errors.CycleError( + msg, + ) spiral = len(previous_periods) >= self.max_spiral_loops if spiral: self.invalidate_spiral_variables(variable) - message = "Quasicircular definition detected on formula {}@{} involving {}".format(variable, period, self.tracer.stack) - raise SpiralError(message, variable) + message = f"Quasicircular definition detected on formula {variable}@{period} involving {self.tracer.stack}" + raise errors.SpiralError(message, variable) - def invalidate_cache_entry(self, variable: str, period): - self.invalidated_caches.add((variable, period)) + def invalidate_cache_entry(self, variable: str, period) -> None: + self.invalidated_caches.add(Cache(variable, period)) - def invalidate_spiral_variables(self, variable: str): + def invalidate_spiral_variables(self, variable: str) -> None: # Visit the stack, from the bottom (most recent) up; we know that we'll find # the variable implicated in the spiral (max_spiral_loops+1) times; we keep the # intermediate values computed (to avoid impacting performance) but we mark them # for deletion from the cache once the calculation ends. count = 0 for frame in reversed(self.tracer.stack): - self.invalidate_cache_entry(frame['name'], frame['period']) - if frame['name'] == variable: + self.invalidate_cache_entry(str(frame["name"]), frame["period"]) + if frame["name"] == variable: count += 1 if count > self.max_spiral_loops: break # ----- Methods to access stored values ----- # - def get_array(self, variable_name, period): - """ - Return the value of ``variable_name`` for ``period``, if this value is alreay in the cache (if it has been set as an input or previously calculated). + def get_array(self, variable_name: str, period): + """Return the value of ``variable_name`` for ``period``, if this value is already in the cache (if it has been set as an input or previously calculated). Unlike :meth:`.calculate`, this method *does not* trigger calculations and *does not* use any formula. """ - if period is not None and not isinstance(period, Period): + if period is not None and not isinstance(period, periods.Period): period = periods.period(period) return self.get_holder(variable_name).get_array(period) - def get_holder(self, variable_name): - """ - Get the :obj:`.Holder` associated with the variable ``variable_name`` for the simulation - """ + def get_holder(self, variable_name: str): + """Get the holder associated with the variable.""" return self.get_variable_population(variable_name).get_holder(variable_name) - def get_memory_usage(self, variables = None): - """ - Get data about the virtual memory usage of the simulation - """ - result = dict( - total_nb_bytes = 0, - by_variable = {} - ) + def get_memory_usage(self, variables=None): + """Get data about the virtual memory usage of the simulation.""" + result = {"total_nb_bytes": 0, "by_variable": {}} for entity in self.populations.values(): - entity_memory_usage = entity.get_memory_usage(variables = variables) - result['total_nb_bytes'] += entity_memory_usage['total_nb_bytes'] - result['by_variable'].update(entity_memory_usage['by_variable']) + entity_memory_usage = entity.get_memory_usage(variables=variables) + result["total_nb_bytes"] += entity_memory_usage["total_nb_bytes"] + result["by_variable"].update(entity_memory_usage["by_variable"]) return result # ----- Misc ----- # - def delete_arrays(self, variable, period = None): - """ - Delete a variable's value for a given period + def delete_arrays(self, variable, period=None) -> None: + """Delete a variable's value for a given period. :param variable: the variable to be set :param period: the period for which the value should be deleted Example: - >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) - >>> simulation.set_input('age', '2018-04', [12, 14]) - >>> simulation.set_input('age', '2018-05', [13, 14]) - >>> simulation.get_array('age', '2018-05') + >>> simulation.set_input("age", "2018-04", [12, 14]) + >>> simulation.set_input("age", "2018-05", [13, 14]) + >>> simulation.get_array("age", "2018-05") array([13, 14], dtype=int32) - >>> simulation.delete_arrays('age', '2018-05') - >>> simulation.get_array('age', '2018-04') + >>> simulation.delete_arrays("age", "2018-05") + >>> simulation.get_array("age", "2018-04") array([12, 14], dtype=int32) - >>> simulation.get_array('age', '2018-05') is None + >>> simulation.get_array("age", "2018-05") is None True - >>> simulation.set_input('age', '2018-05', [13, 14]) - >>> simulation.delete_arrays('age') - >>> simulation.get_array('age', '2018-04') is None + >>> simulation.set_input("age", "2018-05", [13, 14]) + >>> simulation.delete_arrays("age") + >>> simulation.get_array("age", "2018-04") is None True - >>> simulation.get_array('age', '2018-05') is None + >>> simulation.get_array("age", "2018-05") is None True + """ self.get_holder(variable).delete_arrays(period) def get_known_periods(self, variable): - """ - Get a list variable's known period, i.e. the periods where a value has been initialized and + """Get a list variable's known period, i.e. the periods where a value has been initialized and. :param variable: the variable to be set Example: - >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) - >>> simulation.set_input('age', '2018-04', [12, 14]) - >>> simulation.set_input('age', '2018-05', [13, 14]) - >>> simulation.get_known_periods('age') + >>> simulation.set_input("age", "2018-04", [12, 14]) + >>> simulation.set_input("age", "2018-05", [13, 14]) + >>> simulation.get_known_periods("age") [Period((u'month', Instant((2018, 5, 1)), 1)), Period((u'month', Instant((2018, 4, 1)), 1))] + """ return self.get_holder(variable).get_known_periods() - def set_input(self, variable_name, period, value): - """ - Set a variable's value for a given period + def set_input(self, variable_name: str, period, value) -> None: + """Set a variable's value for a given period. :param variable: the variable to be set :param value: the input value for the variable - :param period: the period for which the value is setted + :param period: the period for which the value is set Example: >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) - >>> simulation.set_input('age', '2018-04', [12, 14]) - >>> simulation.get_array('age', '2018-04') + >>> simulation.set_input("age", "2018-04", [12, 14]) + >>> simulation.get_array("age", "2018-04") array([12, 14], dtype=int32) If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation `_. + """ - variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True) + variable: Variable | None + + variable = self.tax_benefit_system.get_variable( + variable_name, + check_existence=True, + ) + + if variable is None: + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) + period = periods.period(period) - if ((variable.end is not None) and (period.start.date > variable.end)): + if (variable.end is not None) and (period.start.date > variable.end): return self.get_holder(variable_name).set_input(period, value) - def get_variable_population(self, variable_name): - variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True) - return self.populations[variable.entity.key] + def get_variable_population(self, variable_name: str) -> Population: + variable: Variable | None - def get_population(self, plural = None): - return next((population for population in self.populations.values() if population.entity.plural == plural), None) + variable = self.tax_benefit_system.get_variable( + variable_name, + check_existence=True, + ) + + if variable is None: + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) - def get_entity(self, plural = None): + return self.populations[variable.entity.key] + + def get_population(self, plural: str | None = None) -> Population | None: + return next( + ( + population + for population in self.populations.values() + if population.entity.plural == plural + ), + None, + ) + + def get_entity( + self, + plural: str | None = None, + ) -> Population | None: population = self.get_population(plural) return population and population.entity def describe_entities(self): - return {population.entity.plural: population.ids for population in self.populations.values()} + return { + population.entity.plural: population.ids + for population in self.populations.values() + } - def clone(self, debug = False, trace = False): - """ - Copy the simulation just enough to be able to run the copy without modifying the original simulation - """ + def clone(self, debug=False, trace=False): + """Copy the simulation just enough to be able to run the copy without modifying the original simulation.""" new = commons.empty_clone(self) new_dict = new.__dict__ for key, value in self.__dict__.items(): - if key not in ('debug', 'trace', 'tracer'): + if key not in ("debug", "trace", "tracer"): new_dict[key] = value new.persons = self.persons.clone(new) @@ -442,9 +591,18 @@ def clone(self, debug = False, trace = False): for entity in self.tax_benefit_system.group_entities: population = self.populations[entity.key].clone(new) new.populations[entity.key] = population - setattr(new, entity.key, population) # create shortcut simulation.household (for instance) + setattr( + new, + entity.key, + population, + ) # create shortcut simulation.household (for instance) new.debug = debug new.trace = trace return new + + +class Cache(NamedTuple): + variable: str + period: periods.Period diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index 88553488db..064b5b4cb6 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -1,227 +1,427 @@ +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from numpy.typing import NDArray as Array +from typing import NoReturn + import copy -import dpath -import typing +import dpath.util import numpy -from openfisca_core import periods -from openfisca_core.entities import Entity -from openfisca_core.errors import PeriodMismatchError, SituationParsingError, VariableNotFoundError -from openfisca_core.populations import Population -from openfisca_core.simulations import helpers, Simulation -from openfisca_core.variables import Variable +from openfisca_core import entities, errors, periods, populations, variables + +from . import helpers +from ._build_default_simulation import _BuildDefaultSimulation +from ._build_from_variables import _BuildFromVariables +from ._type_guards import ( + are_entities_fully_specified, + are_entities_short_form, + are_entities_specified, + has_axes, +) +from .simulation import Simulation +from .typing import ( + Axis, + Entity, + FullySpecifiedEntities, + GroupEntities, + GroupEntity, + ImplicitGroupEntities, + Params, + ParamsWithoutAxes, + Population, + Role, + SingleEntity, + TaxBenefitSystem, + Variables, +) class SimulationBuilder: - - def __init__(self): - self.default_period = None # Simulation period used for variables when no period is defined - self.persons_plural = None # Plural name for person entity in current tax and benefits system + def __init__(self) -> None: + self.default_period = ( + None # Simulation period used for variables when no period is defined + ) + self.persons_plural = ( + None # Plural name for person entity in current tax and benefits system + ) # JSON input - Memory of known input values. Indexed by variable or axis name. - self.input_buffer: typing.Dict[Variable.name, typing.Dict[str(periods.period), numpy.array]] = {} - self.populations: typing.Dict[Entity.key, Population] = {} + self.input_buffer: dict[ + variables.Variable.name, + dict[str(periods.period), numpy.array], + ] = {} + self.populations: dict[entities.Entity.key, populations.Population] = {} # JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes. - self.entity_counts: typing.Dict[Entity.plural, int] = {} - # JSON input - typing.List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. - self.entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {} + self.entity_counts: dict[entities.Entity.plural, int] = {} + # JSON input - List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. + self.entity_ids: dict[entities.Entity.plural, list[int]] = {} # Links entities with persons. For each person index in persons ids list, set entity index in entity ids id. E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id) - self.memberships: typing.Dict[Entity.plural, typing.List[int]] = {} - self.roles: typing.Dict[Entity.plural, typing.List[int]] = {} + self.memberships: dict[entities.Entity.plural, list[int]] = {} + self.roles: dict[entities.Entity.plural, list[int]] = {} - self.variable_entities: typing.Dict[Variable.name, Entity] = {} + self.variable_entities: dict[variables.Variable.name, entities.Entity] = {} self.axes = [[]] - self.axes_entity_counts: typing.Dict[Entity.plural, int] = {} - self.axes_entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {} - self.axes_memberships: typing.Dict[Entity.plural, typing.List[int]] = {} - self.axes_roles: typing.Dict[Entity.plural, typing.List[int]] = {} + self.axes_entity_counts: dict[entities.Entity.plural, int] = {} + self.axes_entity_ids: dict[entities.Entity.plural, list[int]] = {} + self.axes_memberships: dict[entities.Entity.plural, list[int]] = {} + self.axes_roles: dict[entities.Entity.plural, list[int]] = {} + + def build_from_dict( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: Params, + ) -> Simulation: + """Build a simulation from an input dictionary. + + This method uses :meth:`.SimulationBuilder.build_from_entities` if + entities are fully specified, or + :meth:`.SimulationBuilder.build_from_variables` if they are not. + + Args: + tax_benefit_system: The system to use. + input_dict: The input of the simulation. + + Returns: + Simulation: The built simulation. + + Examples: + >>> entities = {"person", "household"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> entities = {"persons", "households"} + + >>> params = { + ... "axes": [ + ... [ + ... { + ... "count": 2, + ... "max": 3000, + ... "min": 0, + ... "name": "rent", + ... "period": "2018-11", + ... } + ... ] + ... ], + ... "households": { + ... "housea": {"parents": ["Alicia", "Javier"]}, + ... "houseb": {"parents": ["Tom"]}, + ... }, + ... "persons": { + ... "Alicia": {"salary": {"2018-11": 0}}, + ... "Javier": {}, + ... "Tom": {}, + ... }, + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"salary": [12000, 13000]} + + >>> not are_entities_specified(params, {"salary"}) + True - def build_from_dict(self, tax_benefit_system, input_dict): """ - Build a simulation from ``input_dict`` + #: The plural names of the entities in the tax and benefits system. + plural: Iterable[str] = tax_benefit_system.entities_plural() - This method uses :any:`build_from_entities` if entities are fully specified, or :any:`build_from_variables` if not. + #: The singular names of the entities in the tax and benefits system. + singular: Iterable[str] = tax_benefit_system.entities_by_singular() - :param dict input_dict: A dict represeting the input of the simulation - :return: A :any:`Simulation` - """ + #: The names of the variables in the tax and benefits system. + variables: Iterable[str] = tax_benefit_system.variables.keys() - input_dict = self.explicit_singular_entities(tax_benefit_system, input_dict) - if any(key in tax_benefit_system.entities_plural() for key in input_dict.keys()): - return self.build_from_entities(tax_benefit_system, input_dict) - else: - return self.build_from_variables(tax_benefit_system, input_dict) + if are_entities_short_form(input_dict, singular): + params = self.explicit_singular_entities(tax_benefit_system, input_dict) + return self.build_from_entities(tax_benefit_system, params) - def build_from_entities(self, tax_benefit_system, input_dict): - """ - Build a simulation from a Python dict ``input_dict`` fully specifying entities. + if are_entities_fully_specified(params := input_dict, plural): + return self.build_from_entities(tax_benefit_system, params) + + if not are_entities_specified(params := input_dict, variables): + return self.build_from_variables(tax_benefit_system, params) + return None + + def build_from_entities( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: FullySpecifiedEntities, + ) -> Simulation: + """Build a simulation from a Python dict ``input_dict`` fully specifying + entities. + + Examples: + >>> entities = {"person", "household"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } - Examples: + >>> are_entities_short_form(params, entities) + True - >>> simulation_builder.build_from_entities({ - 'persons': {'Javier': { 'salary': {'2018-11': 2000}}}, - 'households': {'household': {'parents': ['Javier']}} - }) """ + # Create the populations + populations = tax_benefit_system.instantiate_entities() + + # Create the simulation + simulation = Simulation(tax_benefit_system, populations) + + # Why? input_dict = copy.deepcopy(input_dict) - simulation = Simulation(tax_benefit_system, tax_benefit_system.instantiate_entities()) + # The plural names of the entities in the tax and benefits system. + plural: Iterable[str] = tax_benefit_system.entities_plural() # Register variables so get_variable_entity can find them - for (variable_name, _variable) in tax_benefit_system.variables.items(): - self.register_variable(variable_name, simulation.get_variable_population(variable_name).entity) - - helpers.check_type(input_dict, dict, ['error']) - axes = input_dict.pop('axes', None) - - unexpected_entities = [entity for entity in input_dict if entity not in tax_benefit_system.entities_plural()] - if unexpected_entities: - unexpected_entity = unexpected_entities[0] - raise SituationParsingError([unexpected_entity], - ''.join([ - "Some entities in the situation are not defined in the loaded tax and benefit system.", - "These entities are not found: {0}.", - "The defined entities are: {1}."] - ) - .format( - ', '.join(unexpected_entities), - ', '.join(tax_benefit_system.entities_plural()) - ) - ) - persons_json = input_dict.get(tax_benefit_system.person_entity.plural, None) + self.register_variables(simulation) + + # Declare axes + axes: list[list[Axis]] | None = None + + # ? + helpers.check_type(input_dict, dict, ["error"]) + + # Remove axes from input_dict + params: ParamsWithoutAxes = { + key: value for key, value in input_dict.items() if key != "axes" + } + + # Save axes for later + if has_axes(axes_params := input_dict): + axes = copy.deepcopy(axes_params.get("axes", None)) + + # Check for unexpected entities + helpers.check_unexpected_entities(params, plural) + + person_entity: SingleEntity = tax_benefit_system.person_entity + + persons_json = params.get(person_entity.plural, None) if not persons_json: - raise SituationParsingError([tax_benefit_system.person_entity.plural], - 'No {0} found. At least one {0} must be defined to run a simulation.'.format(tax_benefit_system.person_entity.key)) + raise errors.SituationParsingError( + [person_entity.plural], + f"No {person_entity.key} found. At least one {person_entity.key} must be defined to run a simulation.", + ) persons_ids = self.add_person_entity(simulation.persons.entity, persons_json) for entity_class in tax_benefit_system.group_entities: - instances_json = input_dict.get(entity_class.plural) + instances_json = params.get(entity_class.plural) + if instances_json is not None: - self.add_group_entity(self.persons_plural, persons_ids, entity_class, instances_json) + self.add_group_entity( + self.persons_plural, + persons_ids, + entity_class, + instances_json, + ) + + elif axes is not None: + message = ( + f"We could not find any specified {entity_class.plural}. " + "In order to expand over axes, all group entities and roles " + "must be fully specified. For further support, please do " + "not hesitate to take a look at the official documentation: " + "https://openfisca.org/doc/simulate/replicate-simulation-inputs.html." + ) + + raise errors.SituationParsingError([entity_class.plural], message) + else: self.add_default_group_entity(persons_ids, entity_class) - if axes: - self.axes = axes + if axes is not None: + for axis in axes[0]: + self.add_parallel_axis(axis) + + if len(axes) >= 1: + for axis in axes[1:]: + self.add_perpendicular_axis(axis[0]) + self.expand_axes() try: self.finalize_variables_init(simulation.persons) - except PeriodMismatchError as e: + except errors.PeriodMismatchError as e: self.raise_period_mismatch(simulation.persons.entity, persons_json, e) for entity_class in tax_benefit_system.group_entities: try: population = simulation.populations[entity_class.key] self.finalize_variables_init(population) - except PeriodMismatchError as e: + except errors.PeriodMismatchError as e: self.raise_period_mismatch(population.entity, instances_json, e) return simulation - def build_from_variables(self, tax_benefit_system, input_dict): - """ - Build a simulation from a Python dict ``input_dict`` describing variables values without expliciting entities. + def build_from_variables( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: Variables, + ) -> Simulation: + """Build a simulation from a Python dict ``input_dict`` describing + variables values without expliciting entities. - This method uses :any:`build_default_simulation` to infer an entity structure + This method uses :meth:`.SimulationBuilder.build_default_simulation` to + infer an entity structure. - Example: + Args: + tax_benefit_system: The system to use. + input_dict: The input of the simulation. - >>> simulation_builder.build_from_variables( - {'salary': {'2016-10': 12000}} - ) - """ - count = helpers._get_person_count(input_dict) - simulation = self.build_default_simulation(tax_benefit_system, count) - for variable, value in input_dict.items(): - if not isinstance(value, dict): - if self.default_period is None: - raise SituationParsingError([variable], - "Can't deal with type: expected object. Input variables should be set for specific periods. For instance: {'salary': {'2017-01': 2000, '2017-02': 2500}}, or {'birth_date': {'ETERNITY': '1980-01-01'}}.") - simulation.set_input(variable, self.default_period, value) - else: - for period_str, dated_value in value.items(): - simulation.set_input(variable, period_str, dated_value) - return simulation + Returns: + Simulation: The built simulation. - def build_default_simulation(self, tax_benefit_system, count = 1): - """ - Build a simulation where: - - There are ``count`` persons - - There are ``count`` instances of each group entity, containing one person - - Every person has, in each entity, the first role - """ + Raises: + SituationParsingError: If the input is not valid. - simulation = Simulation(tax_benefit_system, tax_benefit_system.instantiate_entities()) - for population in simulation.populations.values(): - population.count = count - population.ids = numpy.array(range(count)) - if not population.entity.is_person: - population.members_entity_id = population.ids # Each person is its own group entity - return simulation + Examples: + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_specified(params, {"salary"}) + False - def create_entities(self, tax_benefit_system): + >>> params = {"salary": 12000} + + >>> are_entities_specified(params, {"salary"}) + False + + """ + return ( + _BuildFromVariables(tax_benefit_system, input_dict, self.default_period) + .add_dated_values() + .add_undated_values() + .simulation + ) + + @staticmethod + def build_default_simulation( + tax_benefit_system: TaxBenefitSystem, + count: int = 1, + ) -> Simulation: + """Build a default simulation. + + Where: + - There are ``count`` persons + - There are ``count`` of each group entity, containing one person + - Every person has, in each entity, the first role + + """ + return ( + _BuildDefaultSimulation(tax_benefit_system, count) + .add_count() + .add_ids() + .add_members_entity_id() + .simulation + ) + + def create_entities(self, tax_benefit_system) -> None: self.populations = tax_benefit_system.instantiate_entities() - def declare_person_entity(self, person_singular, persons_ids: typing.Iterable): + def declare_person_entity(self, person_singular, persons_ids: Iterable) -> None: person_instance = self.populations[person_singular] person_instance.ids = numpy.array(list(persons_ids)) person_instance.count = len(person_instance.ids) self.persons_plural = person_instance.entity.plural - def declare_entity(self, entity_singular, entity_ids: typing.Iterable): + def declare_entity(self, entity_singular, entity_ids: Iterable): entity_instance = self.populations[entity_singular] entity_instance.ids = numpy.array(list(entity_ids)) entity_instance.count = len(entity_instance.ids) return entity_instance - def nb_persons(self, entity_singular, role = None): - return self.populations[entity_singular].nb_persons(role = role) + def nb_persons(self, entity_singular, role=None): + return self.populations[entity_singular].nb_persons(role=role) - def join_with_persons(self, group_population, persons_group_assignment, roles: typing.Iterable[str]): + def join_with_persons( + self, + group_population, + persons_group_assignment, + roles: Iterable[str], + ) -> None: # Maps group's identifiers to a 0-based integer range, for indexing into members_roles (see PR#876) - group_sorted_indices = numpy.unique(persons_group_assignment, return_inverse = True)[1] - group_population.members_entity_id = numpy.argsort(group_population.ids)[group_sorted_indices] + group_sorted_indices = numpy.unique( + persons_group_assignment, + return_inverse=True, + )[1] + group_population.members_entity_id = numpy.argsort(group_population.ids)[ + group_sorted_indices + ] flattened_roles = group_population.entity.flattened_roles roles_array = numpy.array(roles) if numpy.issubdtype(roles_array.dtype, numpy.integer): group_population.members_role = numpy.array(flattened_roles)[roles_array] + elif len(flattened_roles) == 0: + group_population.members_role = numpy.int16(0) else: - if len(flattened_roles) == 0: - group_population.members_role = numpy.int64(0) - else: - group_population.members_role = numpy.select([roles_array == role.key for role in flattened_roles], flattened_roles) + group_population.members_role = numpy.select( + [roles_array == role.key for role in flattened_roles], + flattened_roles, + ) def build(self, tax_benefit_system): return Simulation(tax_benefit_system, self.populations) - def explicit_singular_entities(self, tax_benefit_system, input_dict): - """ - Preprocess ``input_dict`` to explicit entities defined using the single-entity shortcut + def explicit_singular_entities( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: ImplicitGroupEntities, + ) -> GroupEntities: + """Preprocess ``input_dict`` to explicit entities defined using the + single-entity shortcut. - Example: + Examples: + >>> params = { + ... "persons": { + ... "Javier": {}, + ... }, + ... "household": {"parents": ["Javier"]}, + ... } - >>> simulation_builder.explicit_singular_entities( - {'persons': {'Javier': {}, }, 'household': {'parents': ['Javier']}} - ) - >>> {'persons': {'Javier': {}}, 'households': {'household': {'parents': ['Javier']}} - """ + >>> are_entities_fully_specified(params, {"persons", "households"}) + False + + >>> are_entities_short_form(params, {"person", "household"}) + True - singular_keys = set(input_dict).intersection(tax_benefit_system.entities_by_singular()) - if not singular_keys: - return input_dict + >>> params = { + ... "persons": {"Javier": {}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... } + + >>> are_entities_fully_specified(params, {"persons", "households"}) + True + + >>> are_entities_short_form(params, {"person", "household"}) + False + + """ + singular_keys = set(input_dict).intersection( + tax_benefit_system.entities_by_singular(), + ) result = { entity_id: entity_description for (entity_id, entity_description) in input_dict.items() if entity_id in tax_benefit_system.entities_plural() - } # filter out the singular entities + } # filter out the singular entities for singular in singular_keys: plural = tax_benefit_system.entities_by_singular()[singular].plural @@ -230,9 +430,7 @@ def explicit_singular_entities(self, tax_benefit_system, input_dict): return result def add_person_entity(self, entity, instances_json): - """ - Add the simulation's instances of the persons entity as described in ``instances_json``. - """ + """Add the simulation's instances of the persons entity as described in ``instances_json``.""" helpers.check_type(instances_json, dict, [entity.plural]) entity_ids = list(map(str, instances_json.keys())) self.persons_plural = entity.plural @@ -245,17 +443,28 @@ def add_person_entity(self, entity, instances_json): return self.get_ids(entity.plural) - def add_default_group_entity(self, persons_ids, entity): + def add_default_group_entity( + self, + persons_ids: list[str], + entity: GroupEntity, + ) -> None: persons_count = len(persons_ids) + roles = list(entity.flattened_roles) self.entity_ids[entity.plural] = persons_ids self.entity_counts[entity.plural] = persons_count - self.memberships[entity.plural] = numpy.arange(0, persons_count, dtype = numpy.int32) - self.roles[entity.plural] = numpy.repeat(entity.flattened_roles[0], persons_count) - - def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): - """ - Add all instances of one of the model's entities as described in ``instances_json``. - """ + self.memberships[entity.plural] = list( + numpy.arange(0, persons_count, dtype=numpy.int32), + ) + self.roles[entity.plural] = [roles[0]] * persons_count + + def add_group_entity( + self, + persons_plural: str, + persons_ids: list[str], + entity: GroupEntity, + instances_json, + ) -> None: + """Add all instances of one of the model's entities as described in ``instances_json``.""" helpers.check_type(instances_json, dict, [entity.plural]) entity_ids = list(map(str, instances_json.keys())) @@ -264,8 +473,8 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): persons_count = len(persons_ids) persons_to_allocate = set(persons_ids) - self.memberships[entity.plural] = numpy.empty(persons_count, dtype = numpy.int32) - self.roles[entity.plural] = numpy.empty(persons_count, dtype = object) + self.memberships[entity.plural] = numpy.empty(persons_count, dtype=numpy.int32) + self.roles[entity.plural] = numpy.empty(persons_count, dtype=object) self.entity_ids[entity.plural] = entity_ids self.entity_counts[entity.plural] = len(entity_ids) @@ -276,18 +485,31 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): variables_json = instance_object.copy() # Don't mutate function input roles_json = { - role.plural or role.key: helpers.transform_to_strict_syntax(variables_json.pop(role.plural or role.key, [])) + role.plural + or role.key: helpers.transform_to_strict_syntax( + variables_json.pop(role.plural or role.key, []), + ) for role in entity.roles - } + } for role_id, role_definition in roles_json.items(): - helpers.check_type(role_definition, list, [entity.plural, instance_id, role_id]) + helpers.check_type( + role_definition, + list, + [entity.plural, instance_id, role_id], + ) for index, person_id in enumerate(role_definition): entity_plural = entity.plural - self.check_persons_to_allocate(persons_plural, entity_plural, - persons_ids, - person_id, instance_id, role_id, - persons_to_allocate, index) + self.check_persons_to_allocate( + persons_plural, + entity_plural, + persons_ids, + person_id, + instance_id, + role_id, + persons_to_allocate, + index, + ) persons_to_allocate.discard(person_id) @@ -298,12 +520,17 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): role = role_by_plural[role_plural] if role.max is not None and len(persons_with_role) > role.max: - raise SituationParsingError([entity.plural, instance_id, role_plural], f"There can be at most {role.max} {role_plural} in a {entity.key}. {len(persons_with_role)} were declared in '{instance_id}'.") + raise errors.SituationParsingError( + [entity.plural, instance_id, role_plural], + f"There can be at most {role.max} {role_plural} in a {entity.key}. {len(persons_with_role)} were declared in '{instance_id}'.", + ) for index_within_role, person_id in enumerate(persons_with_role): person_index = persons_ids.index(person_id) self.memberships[entity.plural][person_index] = entity_index - person_role = role.subroles[index_within_role] if role.subroles else role + person_role = ( + role.subroles[index_within_role] if role.subroles else role + ) self.roles[entity.plural][person_index] = person_role self.init_variable_values(entity, variables_json, instance_id) @@ -312,7 +539,9 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): entity_ids = entity_ids + list(persons_to_allocate) for person_id in persons_to_allocate: person_index = persons_ids.index(person_id) - self.memberships[entity.plural][person_index] = entity_ids.index(person_id) + self.memberships[entity.plural][person_index] = entity_ids.index( + person_id, + ) self.roles[entity.plural][person_index] = entity.flattened_roles[0] # Adjust previously computed ids and counts self.entity_ids[entity.plural] = entity_ids @@ -322,58 +551,87 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): self.roles[entity.plural] = self.roles[entity.plural].tolist() self.memberships[entity.plural] = self.memberships[entity.plural].tolist() - def set_default_period(self, period_str): + def set_default_period(self, period_str) -> None: if period_str: self.default_period = str(periods.period(period_str)) - def get_input(self, variable, period_str): + def get_input(self, variable: str, period_str: str) -> Array | None: if variable not in self.input_buffer: self.input_buffer[variable] = {} + return self.input_buffer[variable].get(period_str) - def check_persons_to_allocate(self, persons_plural, entity_plural, - persons_ids, - person_id, entity_id, role_id, - persons_to_allocate, index): - helpers.check_type(person_id, str, [entity_plural, entity_id, role_id, str(index)]) + def check_persons_to_allocate( + self, + persons_plural, + entity_plural, + persons_ids, + person_id, + entity_id, + role_id, + persons_to_allocate, + index, + ) -> None: + helpers.check_type( + person_id, + str, + [entity_plural, entity_id, role_id, str(index)], + ) if person_id not in persons_ids: - raise SituationParsingError([entity_plural, entity_id, role_id], - "Unexpected value: {0}. {0} has been declared in {1} {2}, but has not been declared in {3}.".format( - person_id, entity_id, role_id, persons_plural) - ) + raise errors.SituationParsingError( + [entity_plural, entity_id, role_id], + f"Unexpected value: {person_id}. {person_id} has been declared in {entity_id} {role_id}, but has not been declared in {persons_plural}.", + ) if person_id not in persons_to_allocate: - raise SituationParsingError([entity_plural, entity_id, role_id], - "{} has been declared more than once in {}".format( - person_id, entity_plural) - ) + raise errors.SituationParsingError( + [entity_plural, entity_id, role_id], + f"{person_id} has been declared more than once in {entity_plural}", + ) - def init_variable_values(self, entity, instance_object, instance_id): + def init_variable_values(self, entity, instance_object, instance_id) -> None: for variable_name, variable_values in instance_object.items(): path_in_json = [entity.plural, instance_id, variable_name] try: entity.check_variable_defined_for_entity(variable_name) except ValueError as e: # The variable is defined for another entity - raise SituationParsingError(path_in_json, e.args[0]) - except VariableNotFoundError as e: # The variable doesn't exist - raise SituationParsingError(path_in_json, str(e), code = 404) + raise errors.SituationParsingError(path_in_json, e.args[0]) + except errors.VariableNotFoundError as e: # The variable doesn't exist + raise errors.SituationParsingError(path_in_json, str(e), code=404) instance_index = self.get_ids(entity.plural).index(instance_id) if not isinstance(variable_values, dict): if self.default_period is None: - raise SituationParsingError(path_in_json, - "Can't deal with type: expected object. Input variables should be set for specific periods. For instance: {'salary': {'2017-01': 2000, '2017-02': 2500}}, or {'birth_date': {'ETERNITY': '1980-01-01'}}.") + raise errors.SituationParsingError( + path_in_json, + "Can't deal with type: expected object. Input variables should be set for specific periods. For instance: {'salary': {'2017-01': 2000, '2017-02': 2500}}, or {'birth_date': {'ETERNITY': '1980-01-01'}}.", + ) variable_values = {self.default_period: variable_values} for period_str, value in variable_values.items(): try: periods.period(period_str) except ValueError as e: - raise SituationParsingError(path_in_json, e.args[0]) + raise errors.SituationParsingError(path_in_json, e.args[0]) variable = entity.get_variable(variable_name) - self.add_variable_value(entity, variable, instance_index, instance_id, period_str, value) + self.add_variable_value( + entity, + variable, + instance_index, + instance_id, + period_str, + value, + ) - def add_variable_value(self, entity, variable, instance_index, instance_id, period_str, value): + def add_variable_value( + self, + entity, + variable, + instance_index, + instance_id, + period_str, + value, + ) -> None: path_in_json = [entity.plural, instance_id, variable.name, period_str] if value is None: @@ -388,13 +646,13 @@ def add_variable_value(self, entity, variable, instance_index, instance_id, peri try: value = variable.check_set_value(value) except ValueError as error: - raise SituationParsingError(path_in_json, *error.args) + raise errors.SituationParsingError(path_in_json, *error.args) array[instance_index] = value self.input_buffer[variable.name][str(periods.period(period_str))] = array - def finalize_variables_init(self, population): + def finalize_variables_init(self, population) -> None: # Due to set_input mechanism, we must bufferize all inputs, then actually set them, # so that the months are set first and the years last. plural_key = population.entity.plural @@ -404,15 +662,18 @@ def finalize_variables_init(self, population): if plural_key in self.memberships: population.members_entity_id = numpy.array(self.get_memberships(plural_key)) population.members_role = numpy.array(self.get_roles(plural_key)) - for variable_name in self.input_buffer.keys(): + for variable_name in self.input_buffer: try: holder = population.get_holder(variable_name) except ValueError: # Wrong entity, we can just ignore that continue buffer = self.input_buffer[variable_name] - unsorted_periods = [periods.period(period_str) for period_str in self.input_buffer[variable_name].keys()] + unsorted_periods = [ + periods.period(period_str) + for period_str in self.input_buffer[variable_name] + ] # We need to handle small periods first for set_input to work - sorted_periods = sorted(unsorted_periods, key = periods.key_period_size) + sorted_periods = sorted(unsorted_periods, key=periods.key_period_size) for period_value in sorted_periods: values = buffer[str(period_value)] # Hack to replicate the values in the persons entity @@ -424,66 +685,85 @@ def finalize_variables_init(self, population): if (variable.end is None) or (period_value.start.date <= variable.end): holder.set_input(period_value, array) - def raise_period_mismatch(self, entity, json, e): + def raise_period_mismatch(self, entity, json, e) -> NoReturn: # This error happens when we try to set a variable value for a period that doesn't match its definition period # It is only raised when we consume the buffer. We thus don't know which exact key caused the error. # We do a basic research to find the culprit path culprit_path = next( - dpath.search(json, "*/{}/{}".format(e.variable_name, str(e.period)), yielded = True), - None) + dpath.util.search( + json, + f"*/{e.variable_name}/{e.period!s}", + yielded=True, + ), + None, + ) if culprit_path: - path = [entity.plural] + culprit_path[0].split('/') + path = [entity.plural, *culprit_path[0].split("/")] else: - path = [entity.plural] # Fallback: if we can't find the culprit, just set the error at the entities level + path = [ + entity.plural, + ] # Fallback: if we can't find the culprit, just set the error at the entities level - raise SituationParsingError(path, e.message) + raise errors.SituationParsingError(path, e.message) # Returns the total number of instances of this entity, including when there is replication along axes - def get_count(self, entity_name): + def get_count(self, entity_name: str) -> int: return self.axes_entity_counts.get(entity_name, self.entity_counts[entity_name]) # Returns the ids of instances of this entity, including when there is replication along axes - def get_ids(self, entity_name): + def get_ids(self, entity_name: str) -> list[str]: return self.axes_entity_ids.get(entity_name, self.entity_ids[entity_name]) # Returns the memberships of individuals in this entity, including when there is replication along axes def get_memberships(self, entity_name): # Return empty array for the "persons" entity - return self.axes_memberships.get(entity_name, self.memberships.get(entity_name, [])) + return self.axes_memberships.get( + entity_name, + self.memberships.get(entity_name, []), + ) # Returns the roles of individuals in this entity, including when there is replication along axes - def get_roles(self, entity_name): + def get_roles(self, entity_name: str) -> Sequence[Role]: # Return empty array for the "persons" entity return self.axes_roles.get(entity_name, self.roles.get(entity_name, [])) - def add_parallel_axis(self, axis): + def add_parallel_axis(self, axis: Axis) -> None: # All parallel axes have the same count and entity. # Search for a compatible axis, if none exists, error out self.axes[0].append(axis) - def add_perpendicular_axis(self, axis): + def add_perpendicular_axis(self, axis: Axis) -> None: # This adds an axis perpendicular to all previous dimensions self.axes.append([axis]) - def expand_axes(self): + def expand_axes(self) -> None: # This method should be idempotent & allow change in axes - perpendicular_dimensions = self.axes + perpendicular_dimensions: list[list[Axis]] = self.axes + cell_count: int = 1 - cell_count = 1 for parallel_axes in perpendicular_dimensions: - first_axis = parallel_axes[0] - axis_count = first_axis['count'] + first_axis: Axis = parallel_axes[0] + axis_count: int = first_axis["count"] cell_count *= axis_count # Scale the "prototype" situation, repeating it cell_count times - for entity_name in self.entity_counts.keys(): + for entity_name in self.entity_counts: # Adjust counts - self.axes_entity_counts[entity_name] = self.get_count(entity_name) * cell_count + self.axes_entity_counts[entity_name] = ( + self.get_count(entity_name) * cell_count + ) # Adjust ids - original_ids = self.get_ids(entity_name) * cell_count - indices = numpy.arange(0, cell_count * self.entity_counts[entity_name]) - adjusted_ids = [id + str(ix) for id, ix in zip(original_ids, indices)] + original_ids: list[str] = self.get_ids(entity_name) * cell_count + indices: Array[numpy.int16] = numpy.arange( + 0, + cell_count * self.entity_counts[entity_name], + ) + adjusted_ids: list[str] = [ + original_id + str(index) + for original_id, index in zip(original_ids, indices) + ] self.axes_entity_ids[entity_name] = adjusted_ids + # Adjust roles original_roles = self.get_roles(entity_name) adjusted_roles = original_roles * cell_count @@ -492,8 +772,13 @@ def expand_axes(self): if entity_name != self.persons_plural: original_memberships = self.get_memberships(entity_name) repeated_memberships = original_memberships * cell_count - indices = numpy.repeat(numpy.arange(0, cell_count), len(original_memberships)) * self.entity_counts[entity_name] - adjusted_memberships = (numpy.array(repeated_memberships) + indices).tolist() + indices = ( + numpy.repeat(numpy.arange(0, cell_count), len(original_memberships)) + * self.entity_counts[entity_name] + ) + adjusted_memberships = ( + numpy.array(repeated_memberships) + indices + ).tolist() self.axes_memberships[entity_name] = adjusted_memberships # Now generate input values along the specified axes @@ -501,61 +786,72 @@ def expand_axes(self): if len(self.axes) == 1 and len(self.axes[0]): parallel_axes = self.axes[0] first_axis = parallel_axes[0] - axis_count: int = first_axis['count'] - axis_entity = self.get_variable_entity(first_axis['name']) + axis_count: int = first_axis["count"] + axis_entity = self.get_variable_entity(first_axis["name"]) axis_entity_step_size = self.entity_counts[axis_entity.plural] # Distribute values along axes for axis in parallel_axes: - axis_index = axis.get('index', 0) - axis_period = axis.get('period', self.default_period) - axis_name = axis['name'] + axis_index = axis.get("index", 0) + axis_period = axis.get("period", self.default_period) + axis_name = axis["name"] variable = axis_entity.get_variable(axis_name) array = self.get_input(axis_name, str(axis_period)) if array is None: array = variable.default_array(axis_count * axis_entity_step_size) elif array.size == axis_entity_step_size: array = numpy.tile(array, axis_count) - array[axis_index:: axis_entity_step_size] = numpy.linspace( - axis['min'], - axis['max'], - num = axis_count, - ) + array[axis_index::axis_entity_step_size] = numpy.linspace( + axis["min"], + axis["max"], + num=axis_count, + ) # Set input self.input_buffer[axis_name][str(axis_period)] = array else: - first_axes_count: typing.List[int] = ( - parallel_axes[0]["count"] - for parallel_axes - in self.axes - ) + first_axes_count: list[int] = ( + parallel_axes[0]["count"] for parallel_axes in self.axes + ) axes_linspaces = [ - numpy.linspace(0, axis_count - 1, num = axis_count) - for axis_count - in first_axes_count - ] + numpy.linspace(0, axis_count - 1, num=axis_count) + for axis_count in first_axes_count + ] axes_meshes = numpy.meshgrid(*axes_linspaces) for parallel_axes, mesh in zip(self.axes, axes_meshes): first_axis = parallel_axes[0] - axis_count = first_axis['count'] - axis_entity = self.get_variable_entity(first_axis['name']) + axis_count = first_axis["count"] + axis_entity = self.get_variable_entity(first_axis["name"]) axis_entity_step_size = self.entity_counts[axis_entity.plural] # Distribute values along the grid for axis in parallel_axes: - axis_index = axis.get('index', 0) - axis_period = axis['period'] or self.default_period - axis_name = axis['name'] - variable = axis_entity.get_variable(axis_name) + axis_index = axis.get("index", 0) + axis_period = axis.get("period", self.default_period) + axis_name = axis["name"] + variable = axis_entity.get_variable(axis_name, check_existence=True) array = self.get_input(axis_name, str(axis_period)) if array is None: - array = variable.default_array(cell_count * axis_entity_step_size) + array = variable.default_array( + cell_count * axis_entity_step_size, + ) elif array.size == axis_entity_step_size: array = numpy.tile(array, cell_count) - array[axis_index:: axis_entity_step_size] = axis['min'] \ - + mesh.reshape(cell_count) * (axis['max'] - axis['min']) / (axis_count - 1) + array[axis_index::axis_entity_step_size] = axis[ + "min" + ] + mesh.reshape(cell_count) * (axis["max"] - axis["min"]) / ( + axis_count - 1 + ) self.input_buffer[axis_name][str(axis_period)] = array - def get_variable_entity(self, variable_name): + def get_variable_entity(self, variable_name: str) -> Entity: return self.variable_entities[variable_name] - def register_variable(self, variable_name, entity): + def register_variable(self, variable_name: str, entity: Entity) -> None: self.variable_entities[variable_name] = entity + + def register_variables(self, simulation: Simulation) -> None: + tax_benefit_system: TaxBenefitSystem = simulation.tax_benefit_system + variables: Iterable[str] = tax_benefit_system.variables.keys() + + for name in variables: + population: Population = simulation.get_variable_population(name) + entity: Entity = population.entity + self.register_variable(name, entity) diff --git a/openfisca_core/simulations/typing.py b/openfisca_core/simulations/typing.py new file mode 100644 index 0000000000..8091994e53 --- /dev/null +++ b/openfisca_core/simulations/typing.py @@ -0,0 +1,203 @@ +"""Type aliases of OpenFisca models to use in the context of simulations.""" + +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from numpy.typing import NDArray as Array +from typing import Protocol, TypeVar, TypedDict, Union +from typing_extensions import NotRequired, Required, TypeAlias + +import datetime +from abc import abstractmethod + +from numpy import ( + bool_ as Bool, + datetime64 as Date, + float32 as Float, + int16 as Enum, + int32 as Int, + str_ as String, +) + +#: Generic type variables. +E = TypeVar("E") +G = TypeVar("G", covariant=True) +T = TypeVar("T", Bool, Date, Enum, Float, Int, String, covariant=True) +U = TypeVar("U", bool, datetime.date, float, str) +V = TypeVar("V", covariant=True) + + +#: Type alias for a simulation dictionary defining the roles. +Roles: TypeAlias = dict[str, Union[str, Iterable[str]]] + +#: Type alias for a simulation dictionary with undated variables. +UndatedVariable: TypeAlias = dict[str, object] + +#: Type alias for a simulation dictionary with dated variables. +DatedVariable: TypeAlias = dict[str, UndatedVariable] + +#: Type alias for a simulation dictionary with abbreviated entities. +Variables: TypeAlias = dict[str, Union[UndatedVariable, DatedVariable]] + +#: Type alias for a simulation with fully specified single entities. +SingleEntities: TypeAlias = dict[str, dict[str, Variables]] + +#: Type alias for a simulation dictionary with implicit group entities. +ImplicitGroupEntities: TypeAlias = dict[str, Union[Roles, Variables]] + +#: Type alias for a simulation dictionary with explicit group entities. +GroupEntities: TypeAlias = dict[str, ImplicitGroupEntities] + +#: Type alias for a simulation dictionary with fully specified entities. +FullySpecifiedEntities: TypeAlias = Union[SingleEntities, GroupEntities] + +#: Type alias for a simulation dictionary with axes parameters. +Axes: TypeAlias = dict[str, Iterable[Iterable["Axis"]]] + +#: Type alias for a simulation dictionary without axes parameters. +ParamsWithoutAxes: TypeAlias = Union[ + Variables, + ImplicitGroupEntities, + FullySpecifiedEntities, +] + +#: Type alias for a simulation dictionary with axes parameters. +ParamsWithAxes: TypeAlias = Union[Axes, ParamsWithoutAxes] + +#: Type alias for a simulation dictionary with all the possible scenarios. +Params: TypeAlias = ParamsWithAxes + + +class Axis(TypedDict, total=False): + """Interface representing an axis of a simulation.""" + + count: Required[int] + index: NotRequired[int] + max: Required[float] + min: Required[float] + name: Required[str] + period: NotRequired[str | int] + + +class Entity(Protocol): + """Interface representing an entity of a simulation.""" + + key: str + plural: str | None + + def get_variable( + self, + __variable_name: str, + __check_existence: bool = ..., + ) -> Variable[T] | None: + """Get a variable.""" + + +class SingleEntity(Entity, Protocol): + """Interface representing a single entity of a simulation.""" + + +class GroupEntity(Entity, Protocol): + """Interface representing a group entity of a simulation.""" + + @property + @abstractmethod + def flattened_roles(self) -> Iterable[Role[G]]: + """Get the flattened roles of the GroupEntity.""" + + +class Holder(Protocol[V]): + """Interface representing a holder of a simulation's computed values.""" + + @property + @abstractmethod + def variable(self) -> Variable[T]: + """Get the Variable of the Holder.""" + + def get_array(self, __period: str) -> Array[T] | None: + """Get the values of the Variable for a given Period.""" + + def set_input( + self, + __period: Period, + __array: Array[T] | Sequence[U], + ) -> Array[T] | None: + """Set values for a Variable for a given Period.""" + + +class Period(Protocol): + """Interface representing a period of a simulation.""" + + +class Population(Protocol[E]): + """Interface representing a data vector of an Entity.""" + + count: int + entity: E + ids: Array[String] + + def get_holder(self, __variable_name: str) -> Holder[V]: + """Get the holder of a Variable.""" + + +class SinglePopulation(Population[E], Protocol): + """Interface representing a data vector of a SingleEntity.""" + + +class GroupPopulation(Population[E], Protocol): + """Interface representing a data vector of a GroupEntity.""" + + members_entity_id: Array[String] + + def nb_persons(self, __role: Role[G] | None = ...) -> int: + """Get the number of persons for a given Role.""" + + +class Role(Protocol[G]): + """Interface representing a role of the group entities of a simulation.""" + + +class TaxBenefitSystem(Protocol): + """Interface representing a tax-benefit system.""" + + @property + @abstractmethod + def person_entity(self) -> SingleEntity: + """Get the person entity of the tax-benefit system.""" + + @person_entity.setter + @abstractmethod + def person_entity(self, person_entity: SingleEntity) -> None: + """Set the person entity of the tax-benefit system.""" + + @property + @abstractmethod + def variables(self) -> dict[str, V]: + """Get the variables of the tax-benefit system.""" + + def entities_by_singular(self) -> dict[str, E]: + """Get the singular form of the entities' keys.""" + + def entities_plural(self) -> Iterable[str]: + """Get the plural form of the entities' keys.""" + + def get_variable( + self, + __variable_name: str, + __check_existence: bool = ..., + ) -> V | None: + """Get a variable.""" + + def instantiate_entities( + self, + ) -> dict[str, Population[E]]: + """Instantiate the populations of each Entity.""" + + +class Variable(Protocol[T]): + """Interface representing a variable of a tax-benefit system.""" + + end: str + + def default_array(self, __array_size: int) -> Array[T]: + """Fill an array with the default value of the Variable.""" diff --git a/openfisca_core/taxbenefitsystems/__init__.py b/openfisca_core/taxbenefitsystems/__init__.py index 05a2deb36b..bf5f224c2c 100644 --- a/openfisca_core/taxbenefitsystems/__init__.py +++ b/openfisca_core/taxbenefitsystems/__init__.py @@ -21,4 +21,6 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports +from openfisca_core.errors import VariableNameConflict, VariableNotFound # noqa: F401 + from .tax_benefit_system import TaxBenefitSystem # noqa: F401 diff --git a/openfisca_core/taxbenefitsystems/tax_benefit_system.py b/openfisca_core/taxbenefitsystems/tax_benefit_system.py index 26e37a7b81..1c582e2407 100644 --- a/openfisca_core/taxbenefitsystems/tax_benefit_system.py +++ b/openfisca_core/taxbenefitsystems/tax_benefit_system.py @@ -1,20 +1,30 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +from openfisca_core.types import ParameterNodeAtInstant + +import ast import copy +import functools import glob import importlib +import importlib.metadata +import importlib.util import inspect +import linecache import logging import os -import pkg_resources +import sys import traceback -import typing -from imp import find_module, load_module from openfisca_core import commons, periods, variables from openfisca_core.entities import Entity from openfisca_core.errors import VariableNameConflictError, VariableNotFoundError from openfisca_core.parameters import ParameterNode from openfisca_core.periods import Instant, Period -from openfisca_core.populations import Population, GroupPopulation +from openfisca_core.populations import GroupPopulation, Population from openfisca_core.simulations import SimulationBuilder from openfisca_core.variables import Variable @@ -22,39 +32,45 @@ class TaxBenefitSystem: - """ - Represents the legislation. + """Represents the legislation. - It stores parameters (values defined for everyone) and variables (values defined for some given entity e.g. a person). + It stores parameters (values defined for everyone) and variables (values + defined for some given entity e.g. a person). - :param entities: Entities used by the tax benefit system. - :param string parameters: Directory containing the YAML parameter files. + Attributes: + parameters: Directory containing the YAML parameter files. + Args: + entities: Entities used by the tax benefit system. - .. attribute:: parameters - - :obj:`.ParameterNode` containing the legislation parameters """ + + person_entity: Entity + _base_tax_benefit_system = None - _parameters_at_instant_cache = None + _parameters_at_instant_cache: dict[Instant, ParameterNodeAtInstant] = {} person_key_plural = None preprocess_parameters = None baseline = None # Baseline tax-benefit system. Used only by reforms. Note: Reforms can be chained. cache_blacklist = None decomposition_file_path = None - def __init__(self, entities): + def __init__(self, entities: Sequence[Entity]) -> None: # TODO: Currently: Don't use a weakref, because they are cleared by Paste (at least) at each call. - self.parameters = None - self._parameters_at_instant_cache = {} # weakref.WeakValueDictionary() - self.variables = {} - self.open_api_config = {} + self.parameters: ParameterNode | None = None + self.variables: dict[Any, Any] = {} + self.open_api_config: dict[Any, Any] = {} # Tax benefit systems are mutable, so entities (which need to know about our variables) can't be shared among them if entities is None or len(entities) == 0: - raise Exception("A tax and benefit sytem must have at least an entity.") + msg = "A tax and benefit system must have at least an entity." + raise Exception(msg) self.entities = [copy.copy(entity) for entity in entities] - self.person_entity = [entity for entity in self.entities if entity.is_person][0] - self.group_entities = [entity for entity in self.entities if not entity.is_person] + self.person_entity = next( + entity for entity in self.entities if entity.is_person + ) + self.group_entities = [ + entity for entity in self.entities if not entity.is_person + ] for entity in self.entities: entity.set_tax_benefit_system(self) @@ -65,13 +81,15 @@ def base_tax_benefit_system(self): baseline = self.baseline if baseline is None: return self - self._base_tax_benefit_system = base_tax_benefit_system = baseline.base_tax_benefit_system + self._base_tax_benefit_system = base_tax_benefit_system = ( + baseline.base_tax_benefit_system + ) return base_tax_benefit_system def instantiate_entities(self): person = self.person_entity members = Population(person) - entities: typing.Dict[Entity.key, Entity] = {person.key: members} + entities: dict[Entity.key, Entity] = {person.key: members} for entity in self.group_entities: entities[entity.key] = GroupPopulation(entity, members) @@ -80,8 +98,8 @@ def instantiate_entities(self): # Deprecated method of constructing simulations, to be phased out in favor of SimulationBuilder def new_scenario(self): - class ScenarioAdapter(object): - def __init__(self, tax_benefit_system): + class ScenarioAdapter: + def __init__(self, tax_benefit_system) -> None: self.tax_benefit_system = tax_benefit_system def init_from_attributes(self, **attributes): @@ -91,10 +109,16 @@ def init_from_attributes(self, **attributes): def init_from_dict(self, dict): self.attributes = None self.dict = dict - self.period = dict.pop('period') + self.period = dict.pop("period") return self - def new_simulation(self, debug = False, opt_out_cache = False, use_baseline = False, trace = False): + def new_simulation( + self, + debug=False, + opt_out_cache=False, + use_baseline=False, + trace=False, + ): # Legacy from scenarios, used in reforms tax_benefit_system = self.tax_benefit_system if use_baseline: @@ -106,13 +130,19 @@ def new_simulation(self, debug = False, opt_out_cache = False, use_baseline = Fa builder = SimulationBuilder() if self.attributes: - variables = self.attributes.get('input_variables') or {} - period = self.attributes.get('period') + variables = self.attributes.get("input_variables") or {} + period = self.attributes.get("period") builder.set_default_period(period) - simulation = builder.build_from_variables(tax_benefit_system, variables) + simulation = builder.build_from_variables( + tax_benefit_system, + variables, + ) else: builder.set_default_period(self.period) - simulation = builder.build_from_entities(tax_benefit_system, self.dict) + simulation = builder.build_from_entities( + tax_benefit_system, + self.dict, + ) simulation.trace = trace simulation.debug = debug @@ -122,93 +152,134 @@ def new_simulation(self, debug = False, opt_out_cache = False, use_baseline = Fa return ScenarioAdapter(self) - def prefill_cache(self): + def prefill_cache(self) -> None: pass - def load_variable(self, variable_class, update = False): + def load_variable(self, variable_class, update=False): name = variable_class.__name__ # Check if a Variable with the same name is already registered. baseline_variable = self.get_variable(name) if baseline_variable and not update: + msg = f'Variable "{name}" is already defined. Use `update_variable` to replace it.' raise VariableNameConflictError( - 'Variable "{}" is already defined. Use `update_variable` to replace it.'.format(name)) + msg, + ) - variable = variable_class(baseline_variable = baseline_variable) + variable = variable_class(baseline_variable=baseline_variable) self.variables[variable.name] = variable return variable - def add_variable(self, variable): - """ - Adds an OpenFisca variable to the tax and benefit system. + def add_variable(self, variable: Variable) -> Variable: + """Adds an OpenFisca variable to the tax and benefit system. - :param .Variable variable: The variable to add. Must be a subclass of Variable. + Args: + variable: The variable to add. Must be a subclass of Variable. - :raises: :exc:`.VariableNameConflictError` if a variable with the same name have previously been added to the tax and benefit system. - """ - return self.load_variable(variable, update = False) + Raises: + openfisca_core.errors.VariableNameConflictError: if a variable with the same name have previously been added to the tax and benefit system. - def replace_variable(self, variable): """ - Replaces an existing OpenFisca variable in the tax and benefit system by a new one. + return self.load_variable(variable, update=False) + + def replace_variable(self, variable: Variable) -> None: + """Replaces an existing variable by a new one. The new variable must have the same name than the replaced one. - If no variable with the given name exists in the tax and benefit system, no error will be raised and the new variable will be simply added. + If no variable with the given name exists in the Tax-Benefit system, no + error will be raised and the new variable will be simply added. + + Args: + variable: The variable to replace. - :param Variable variable: New variable to add. Must be a subclass of Variable. """ name = variable.__name__ + if self.variables.get(name) is not None: del self.variables[name] - self.load_variable(variable, update = False) - def update_variable(self, variable): - """ - Updates an existing OpenFisca variable in the tax and benefit system. + self.load_variable(variable, update=False) + + def update_variable(self, variable: Variable) -> Variable: + """Update an existing variable in the Tax-Benefit system. - All attributes of the updated variable that are not explicitely overridden by the new ``variable`` will stay unchanged. + All attributes of the updated variable that are not explicitly + overridden by the new ``variable`` will stay unchanged. The new variable must have the same name than the updated one. - If no variable with the given name exists in the tax and benefit system, no error will be raised and the new variable will be simply added. + If no variable with the given name exists in the tax and benefit + system, no error will be raised and the new variable will be simply + added. - :param Variable variable: Variable to add. Must be a subclass of Variable. - """ - return self.load_variable(variable, update = True) + Args: + variable: Variable to add. Must be a subclass of Variable. + + Returns: + The added variable. - def add_variables_from_file(self, file_path): - """ - Adds all OpenFisca variables contained in a given file to the tax and benefit system. """ + return self.load_variable(variable, update=True) + + def add_variables_from_file(self, file_path) -> None: + """Adds all OpenFisca variables contained in a given file to the tax and benefit system.""" try: + source_file_path = file_path.replace( + self.get_package_metadata()["location"], + "", + ) + file_name = os.path.splitext(os.path.basename(file_path))[0] # As Python remembers loaded modules by name, in order to prevent collisions, we need to make sure that: # - Files with the same name, but located in different directories, have a different module names. Hence the file path hash in the module name. # - The same file, loaded by different tax and benefit systems, has distinct module names. Hence the `id(self)` in the module name. - module_name = '{}_{}_{}'.format(id(self), hash(os.path.abspath(file_path)), file_name) + module_name = f"{id(self)}_{hash(os.path.abspath(file_path))}_{file_name}" - module_directory = os.path.dirname(file_path) try: - module = load_module(module_name, *find_module(file_name, [module_directory])) + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + + lines = linecache.getlines(file_path, module.__dict__) + source = "".join(lines) + tree = ast.parse(source) + defs = {i.name: i for i in tree.body if isinstance(i, ast.ClassDef)} + spec.loader.exec_module(module) + except NameError as e: - logging.error(str(e) + ": if this code used to work, this error might be due to a major change in OpenFisca-Core. Checkout the changelog to learn more: ") + logging.exception( + str(e) + + ": if this code used to work, this error might be due to a major change in OpenFisca-Core. Checkout the changelog to learn more: ", + ) raise - potential_variables = [getattr(module, item) for item in dir(module) if not item.startswith('__')] + potential_variables = [ + getattr(module, item) + for item in dir(module) + if not item.startswith("__") + ] for pot_variable in potential_variables: # We only want to get the module classes defined in this module (not imported) - if inspect.isclass(pot_variable) and issubclass(pot_variable, Variable) and pot_variable.__module__ == module_name: + if ( + inspect.isclass(pot_variable) + and issubclass(pot_variable, Variable) + and pot_variable.__module__ == module_name + ): + class_def = defs[pot_variable.__name__] + pot_variable.introspection_data = ( + source_file_path, + "".join(lines[class_def.lineno - 1 : class_def.end_lineno]), + class_def.lineno - 1, + ) self.add_variable(pot_variable) except Exception: - log.error('Unable to load OpenFisca variables from file "{}"'.format(file_path)) + log.exception(f'Unable to load OpenFisca variables from file "{file_path}"') raise - def add_variables_from_directory(self, directory): - """ - Recursively explores a directory, and adds all OpenFisca variables found there to the tax and benefit system. - """ + def add_variables_from_directory(self, directory) -> None: + """Recursively explores a directory, and adds all OpenFisca variables found there to the tax and benefit system.""" py_files = glob.glob(os.path.join(directory, "*.py")) for py_file in py_files: self.add_variables_from_file(py_file) @@ -216,20 +287,18 @@ def add_variables_from_directory(self, directory): for subdirectory in subdirectories: self.add_variables_from_directory(subdirectory) - def add_variables(self, *variables): - """ - Adds a list of OpenFisca Variables to the `TaxBenefitSystem`. + def add_variables(self, *variables) -> None: + """Adds a list of OpenFisca Variables to the `TaxBenefitSystem`. See also :any:`add_variable` """ for variable in variables: self.add_variable(variable) - def load_extension(self, extension): - """ - Loads an extension to the tax and benefit system. + def load_extension(self, extension) -> None: + """Loads an extension to the tax and benefit system. - :param string extension: The extension to load. Can be an absolute path pointing to an extension directory, or the name of an OpenFisca extension installed as a pip package. + :param str extension: The extension to load. Can be an absolute path pointing to an extension directory, or the name of an OpenFisca extension installed as a pip package. """ # Load extension from installed pip package @@ -237,91 +306,130 @@ def load_extension(self, extension): package = importlib.import_module(extension) extension_directory = package.__path__[0] except ImportError: - message = os.linesep.join([traceback.format_exc(), - 'Error loading extension: `{}` is neither a directory, nor a package.'.format(extension), - 'Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.', - 'See more at .']) + message = os.linesep.join( + [ + traceback.format_exc(), + f"Error loading extension: `{extension}` is neither a directory, nor a package.", + "Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.", + "See more at .", + ], + ) raise ValueError(message) self.add_variables_from_directory(extension_directory) - param_dir = os.path.join(extension_directory, 'parameters') + param_dir = os.path.join(extension_directory, "parameters") if os.path.isdir(param_dir): - extension_parameters = ParameterNode(directory_path = param_dir) + extension_parameters = ParameterNode(directory_path=param_dir) self.parameters.merge(extension_parameters) - def apply_reform(self, reform_path): - """ - Generates a new tax and benefit system applying a reform to the tax and benefit system. + def apply_reform(self, reform_path: str) -> TaxBenefitSystem: + """Generates a new tax and benefit system applying a reform to the tax and benefit system. The current tax and benefit system is **not** mutated. - :param string reform_path: The reform to apply. Must respect the format *installed_package.sub_module.reform* + Args: + reform_path: The reform to apply. Must respect the format *installed_package.sub_module.reform* - :returns: A reformed tax and benefit system. + Returns: + TaxBenefitSystem: A reformed tax and benefit system. Example: - - >>> self.apply_reform('openfisca_france.reforms.inversion_revenus') + >>> self.apply_reform("openfisca_france.reforms.inversion_revenus") """ from openfisca_core.reforms import Reform + try: - reform_package, reform_name = reform_path.rsplit('.', 1) + reform_package, reform_name = reform_path.rsplit(".", 1) except ValueError: - raise ValueError('`{}` does not seem to be a path pointing to a reform. A path looks like `some_country_package.reforms.some_reform.`'.format(reform_path)) + msg = f"`{reform_path}` does not seem to be a path pointing to a reform. A path looks like `some_country_package.reforms.some_reform.`" + raise ValueError( + msg, + ) try: reform_module = importlib.import_module(reform_package) except ImportError: - message = os.linesep.join([traceback.format_exc(), - 'Could not import `{}`.'.format(reform_package), - 'Are you sure of this reform module name? If so, look at the stack trace above to determine the origin of this error.']) + message = os.linesep.join( + [ + traceback.format_exc(), + f"Could not import `{reform_package}`.", + "Are you sure of this reform module name? If so, look at the stack trace above to determine the origin of this error.", + ], + ) raise ValueError(message) reform = getattr(reform_module, reform_name, None) if reform is None: - raise ValueError('{} has no attribute {}'.format(reform_package, reform_name)) + msg = f"{reform_package} has no attribute {reform_name}" + raise ValueError(msg) if not issubclass(reform, Reform): - raise ValueError('`{}` does not seem to be a valid Openfisca reform.'.format(reform_path)) + msg = f"`{reform_path}` does not seem to be a valid Openfisca reform." + raise ValueError( + msg, + ) return reform(self) - def get_variable(self, variable_name, check_existence = False): - """ - Get a variable from the tax and benefit system. + def get_variable( + self, + variable_name: str, + check_existence: bool = False, + ) -> Variable | None: + """Get a variable from the tax and benefit system. :param variable_name: Name of the requested variable. :param check_existence: If True, raise an error if the requested variable does not exist. """ - variables = self.variables - found = variables.get(variable_name) - if not found and check_existence: - raise VariableNotFoundError(variable_name, self) - return found + variables: dict[str, Variable | None] = self.variables + variable: Variable | None = variables.get(variable_name) - def neutralize_variable(self, variable_name): - """ - Neutralizes an OpenFisca variable existing in the tax and benefit system. + if isinstance(variable, Variable): + return variable + + if not isinstance(variable, Variable) and not check_existence: + return variable + + raise VariableNotFoundError(variable_name, self) + + def neutralize_variable(self, variable_name: str) -> None: + """Neutralizes an OpenFisca variable existing in the tax and benefit system. A neutralized variable always returns its default value when computed. Trying to set inputs for a neutralized variable has no effect except raising a warning. """ - self.variables[variable_name] = variables.get_neutralized_variable(self.get_variable(variable_name)) + self.variables[variable_name] = variables.get_neutralized_variable( + self.get_variable(variable_name), + ) + + def annualize_variable( + self, + variable_name: str, + period: Period | None = None, + ) -> None: + check: bool + variable: Variable | None + annualised_variable: Variable + + check = bool(period) + variable = self.get_variable(variable_name, check) + + if variable is None: + raise VariableNotFoundError(variable_name, self) - def annualize_variable(self, variable_name: str, period: typing.Optional[Period] = None): - self.variables[variable_name] = variables.get_annualized_variable(self.get_variable(variable_name, period)) + annualised_variable = variables.get_annualized_variable(variable) - def load_parameters(self, path_to_yaml_dir): - """ - Loads the legislation parameter for a directory containing YAML parameters files. + self.variables[variable_name] = annualised_variable + + def load_parameters(self, path_to_yaml_dir) -> None: + """Loads the legislation parameter for a directory containing YAML parameters files. :param path_to_yaml_dir: Absolute path towards the YAML parameter directory. Example: + >>> self.load_parameters("/path/to/yaml/parameters/dir") - >>> self.load_parameters('/path/to/yaml/parameters/dir') """ - - parameters = ParameterNode('', directory_path = path_to_yaml_dir) + parameters = ParameterNode("", directory_path=path_to_yaml_dir) if self.preprocess_parameters is not None: parameters = self.preprocess_parameters(parameters) @@ -334,36 +442,48 @@ def _get_baseline_parameters_at_instant(self, instant): return self.get_parameters_at_instant(instant) return baseline._get_baseline_parameters_at_instant(instant) - def get_parameters_at_instant(self, instant): - """ - Get the parameters of the legislation at a given instant + @functools.lru_cache + def get_parameters_at_instant( + self, + instant: str | int | Period | Instant, + ) -> ParameterNodeAtInstant | None: + """Get the parameters of the legislation at a given instant. + + Args: + instant: :obj:`str` formatted "YYYY-MM-DD" or :class:`~openfisca_core.periods.Instant`. + + Returns: + The parameters of the legislation at a given instant. - :param instant: :obj:`str` of the format 'YYYY-MM-DD' or :class:`.Instant` instance. - :returns: The parameters of the legislation at a given instant. - :rtype: :class:`.ParameterNodeAtInstant` """ - if isinstance(instant, Period): - instant = instant.start + key: Instant | None + msg: str + + if isinstance(instant, Instant): + key = instant + + elif isinstance(instant, Period): + key = instant.start + elif isinstance(instant, (str, int)): - instant = periods.instant(instant) + key = periods.instant(instant) + else: - assert isinstance(instant, Instant), "Expected an Instant (e.g. Instant((2017, 1, 1)) ). Got: {}.".format(instant) + msg = f"Expected an Instant (e.g. Instant((2017, 1, 1)) ). Got: {key}." + raise AssertionError(msg) - parameters_at_instant = self._parameters_at_instant_cache.get(instant) - if parameters_at_instant is None and self.parameters is not None: - parameters_at_instant = self.parameters.get_at_instant(str(instant)) - self._parameters_at_instant_cache[instant] = parameters_at_instant - return parameters_at_instant + if self.parameters is None: + return None - def get_package_metadata(self): - """ - Gets metatada relative to the country package the tax and benefit system is built from. + return self.parameters.get_at_instant(key) - :returns: Country package metadata - :rtype: dict + def get_package_metadata(self) -> dict[str, str]: + """Gets metadata relative to the country package. - Example: + Returns: + A dictionary with the country package metadata + Example: >>> tax_benefit_system.get_package_metadata() >>> { >>> 'location': '/path/to/dir/containing/package', @@ -371,75 +491,93 @@ def get_package_metadata(self): >>> 'repository_url': 'https://github.com/openfisca/openfisca-france', >>> 'version': '17.2.0' >>> } + """ # Handle reforms if self.baseline: return self.baseline.get_package_metadata() - fallback_metadata = { - 'name': self.__class__.__name__, - 'version': '', - 'repository_url': '', - 'location': '', - } - module = inspect.getmodule(self) - if not module.__package__: - return fallback_metadata - package_name = module.__package__.split('.')[0] + try: - distribution = pkg_resources.get_distribution(package_name) - except pkg_resources.DistributionNotFound: - return fallback_metadata + source_file = inspect.getsourcefile(module) + package_name = module.__package__.split(".")[0] + distribution = importlib.metadata.distribution(package_name) + source_metadata = distribution.metadata + except Exception as e: + log.warning("Unable to load package metadata, exposing default metadata", e) + source_metadata = { + "Name": self.__class__.__name__, + "Version": "0.0.0", + "Home-page": "https://openfisca.org", + } - location = inspect.getsourcefile(module).split(package_name)[0].rstrip('/') + try: + source_file = inspect.getsourcefile(module) + location = source_file.split(package_name)[0].rstrip("/") + except Exception as e: + log.warning("Unable to load package source folder", e) + location = "_unknown_" + + repository_url = "" + if source_metadata.get("Project-URL"): # pyproject.toml metadata format + repository_url = next( + filter( + lambda url: url.startswith("Repository"), + source_metadata.get_all("Project-URL"), + ), + ).split("Repository, ")[-1] + else: # setup.py format + repository_url = source_metadata.get("Home-page") - home_page_metadatas = [ - metadata.split(':', 1)[1].strip(' ') - for metadata in distribution._get_metadata(distribution.PKG_INFO) if 'Home-page' in metadata - ] - repository_url = home_page_metadatas[0] if home_page_metadatas else '' return { - 'name': distribution.key, - 'version': distribution.version, - 'repository_url': repository_url, - 'location': location, - } + "name": source_metadata.get("Name").lower(), + "version": source_metadata.get("Version"), + "repository_url": repository_url, + "location": location, + } - def get_variables(self, entity = None): - """ - Gets all variables contained in a tax and benefit system. + def get_variables( + self, + entity: Entity | None = None, + ) -> dict[str, Variable]: + """Gets all variables contained in a tax and benefit system. - :param .Entity entity: If set, returns only the variable defined for the given entity. + Args: + entity: If set, returns the variable defined for the given entity. - :returns: A dictionnary, indexed by variable names. - :rtype: dict + Returns: + A dictionary, indexed by variable names. """ if not entity: return self.variables - else: - return { - variable_name: variable - for variable_name, variable in self.variables.items() - # TODO - because entities are copied (see constructor) they can't be compared - if variable.entity.key == entity.key - } + return { + variable_name: variable + for variable_name, variable in self.variables.items() + # TODO - because entities are copied (see constructor) they can't be compared + if variable.entity.key == entity.key + } def clone(self): new = commons.empty_clone(self) new_dict = new.__dict__ for key, value in self.__dict__.items(): - if key not in ('parameters', '_parameters_at_instant_cache', 'variables', 'open_api_config'): + if key not in ( + "parameters", + "_parameters_at_instant_cache", + "variables", + "open_api_config", + ): new_dict[key] = value - for entity in new_dict['entities']: + for entity in new_dict["entities"]: entity.set_tax_benefit_system(new) - new_dict['parameters'] = self.parameters.clone() - new_dict['_parameters_at_instant_cache'] = {} - new_dict['variables'] = self.variables.copy() - new_dict['open_api_config'] = self.open_api_config.copy() + new_dict["parameters"] = self.parameters.clone() + new_dict["_parameters_at_instant_cache"] = {} + new_dict["variables"] = self.variables.copy() + new_dict["open_api_config"] = self.open_api_config.copy() return new def entities_plural(self): diff --git a/openfisca_core/taxscales/__init__.py b/openfisca_core/taxscales/__init__.py index 0e074b2e6e..1911d20c56 100644 --- a/openfisca_core/taxscales/__init__.py +++ b/openfisca_core/taxscales/__init__.py @@ -21,13 +21,15 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .helpers import combine_tax_scales # noqa: F401 -from .tax_scale_like import TaxScaleLike # noqa: F401 -from .rate_tax_scale_like import RateTaxScaleLike # noqa: F401 -from .marginal_rate_tax_scale import MarginalRateTaxScale # noqa: F401 -from .linear_average_rate_tax_scale import LinearAverageRateTaxScale # noqa: F401 +from openfisca_core.errors import EmptyArgumentError # noqa: F401 + +from .abstract_rate_tax_scale import AbstractRateTaxScale # noqa: F401 from .abstract_tax_scale import AbstractTaxScale # noqa: F401 from .amount_tax_scale_like import AmountTaxScaleLike # noqa: F401 -from .abstract_rate_tax_scale import AbstractRateTaxScale # noqa: F401 +from .helpers import combine_tax_scales # noqa: F401 +from .linear_average_rate_tax_scale import LinearAverageRateTaxScale # noqa: F401 from .marginal_amount_tax_scale import MarginalAmountTaxScale # noqa: F401 +from .marginal_rate_tax_scale import MarginalRateTaxScale # noqa: F401 +from .rate_tax_scale_like import RateTaxScaleLike # noqa: F401 from .single_amount_tax_scale import SingleAmountTaxScale # noqa: F401 +from .tax_scale_like import TaxScaleLike # noqa: F401 diff --git a/openfisca_core/taxscales/abstract_rate_tax_scale.py b/openfisca_core/taxscales/abstract_rate_tax_scale.py index b9316273d1..9d828ed673 100644 --- a/openfisca_core/taxscales/abstract_rate_tax_scale.py +++ b/openfisca_core/taxscales/abstract_rate_tax_scale.py @@ -1,41 +1,42 @@ from __future__ import annotations import typing + import warnings -from openfisca_core.taxscales import RateTaxScaleLike +from .rate_tax_scale_like import RateTaxScaleLike if typing.TYPE_CHECKING: import numpy - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class AbstractRateTaxScale(RateTaxScaleLike): - """ - Base class for various types of rate-based tax scales: marginal rate, + """Base class for various types of rate-based tax scales: marginal rate, linear average rate... """ def __init__( - self, name: typing.Optional[str] = None, - option: typing.Any = None, - unit: typing.Any = None, - ) -> None: + self, + name: str | None = None, + option: typing.Any = None, + unit: typing.Any = None, + ) -> None: message = [ "The 'AbstractRateTaxScale' class has been deprecated since", "version 34.7.0, and will be removed in the future.", - ] + ] - warnings.warn(" ".join(message), DeprecationWarning) + warnings.warn(" ".join(message), DeprecationWarning, stacklevel=2) super().__init__(name, option, unit) def calc( - self, - tax_base: NumericalArray, - right: bool, - ) -> typing.NoReturn: + self, + tax_base: NumericalArray, + right: bool, + ) -> typing.NoReturn: + msg = "Method 'calc' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method 'calc' is not implemented for " - f"{self.__class__.__name__}", - ) + msg, + ) diff --git a/openfisca_core/taxscales/abstract_tax_scale.py b/openfisca_core/taxscales/abstract_tax_scale.py index 9cbeeb7565..de9a6348c5 100644 --- a/openfisca_core/taxscales/abstract_tax_scale.py +++ b/openfisca_core/taxscales/abstract_tax_scale.py @@ -1,55 +1,54 @@ from __future__ import annotations import typing + import warnings -from openfisca_core.taxscales import TaxScaleLike +from .tax_scale_like import TaxScaleLike if typing.TYPE_CHECKING: import numpy - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class AbstractTaxScale(TaxScaleLike): - """ - Base class for various types of tax scales: amount-based tax scales, + """Base class for various types of tax scales: amount-based tax scales, rate-based tax scales... """ def __init__( - self, - name: typing.Optional[str] = None, - option: typing.Any = None, - unit: numpy.int_ = None, - ) -> None: - + self, + name: str | None = None, + option: typing.Any = None, + unit: numpy.int16 = None, + ) -> None: message = [ "The 'AbstractTaxScale' class has been deprecated since", "version 34.7.0, and will be removed in the future.", - ] + ] - warnings.warn(" ".join(message), DeprecationWarning) + warnings.warn(" ".join(message), DeprecationWarning, stacklevel=2) super().__init__(name, option, unit) def __repr__(self) -> typing.NoReturn: + msg = "Method '__repr__' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method '__repr__' is not implemented for " - f"{self.__class__.__name__}", - ) + msg, + ) def calc( - self, - tax_base: NumericalArray, - right: bool, - ) -> typing.NoReturn: + self, + tax_base: NumericalArray, + right: bool, + ) -> typing.NoReturn: + msg = "Method 'calc' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method 'calc' is not implemented for " - f"{self.__class__.__name__}", - ) + msg, + ) def to_dict(self) -> typing.NoReturn: + msg = f"Method 'to_dict' is not implemented for {self.__class__.__name__}" raise NotImplementedError( - f"Method 'to_dict' is not implemented for " - f"{self.__class__.__name__}", - ) + msg, + ) diff --git a/openfisca_core/taxscales/amount_tax_scale_like.py b/openfisca_core/taxscales/amount_tax_scale_like.py index cfc0a6973f..1dc9acf4b3 100644 --- a/openfisca_core/taxscales/amount_tax_scale_like.py +++ b/openfisca_core/taxscales/amount_tax_scale_like.py @@ -1,26 +1,27 @@ +import typing + import abc import bisect import os -import typing from openfisca_core import tools -from openfisca_core.taxscales import TaxScaleLike + +from .tax_scale_like import TaxScaleLike class AmountTaxScaleLike(TaxScaleLike, abc.ABC): - """ - Base class for various types of amount-based tax scales: single amount, + """Base class for various types of amount-based tax scales: single amount, marginal amount... """ - amounts: typing.List + amounts: list def __init__( - self, - name: typing.Optional[str] = None, - option: typing.Any = None, - unit: typing.Any = None, - ) -> None: + self, + name: typing.Optional[str] = None, + option: typing.Any = None, + unit: typing.Any = None, + ) -> None: super().__init__(name, option, unit) self.amounts = [] @@ -29,17 +30,16 @@ def __repr__(self) -> str: os.linesep.join( [ f"- threshold: {threshold}{os.linesep} amount: {amount}" - for (threshold, amount) - in zip(self.thresholds, self.amounts) - ] - ) - ) + for (threshold, amount) in zip(self.thresholds, self.amounts) + ], + ), + ) def add_bracket( - self, - threshold: int, - amount: typing.Union[int, float], - ) -> None: + self, + threshold: int, + amount: typing.Union[int, float], + ) -> None: if threshold in self.thresholds: i = self.thresholds.index(threshold) self.amounts[i] += amount @@ -52,6 +52,5 @@ def add_bracket( def to_dict(self) -> dict: return { str(threshold): self.amounts[index] - for index, threshold - in enumerate(self.thresholds) - } + for index, threshold in enumerate(self.thresholds) + } diff --git a/openfisca_core/taxscales/helpers.py b/openfisca_core/taxscales/helpers.py index 181fbfed36..687db41a3b 100644 --- a/openfisca_core/taxscales/helpers.py +++ b/openfisca_core/taxscales/helpers.py @@ -1,8 +1,9 @@ from __future__ import annotations -import logging import typing +import logging + from openfisca_core import taxscales log = logging.getLogger(__name__) @@ -14,21 +15,19 @@ def combine_tax_scales( - node: ParameterNodeAtInstant, - combined_tax_scales: TaxScales = None, - ) -> TaxScales: - """ - Combine all the MarginalRateTaxScales in the node into a single + node: ParameterNodeAtInstant, + combined_tax_scales: TaxScales = None, +) -> TaxScales: + """Combine all the MarginalRateTaxScales in the node into a single MarginalRateTaxScale. """ - name = next(iter(node or []), None) if name is None: return combined_tax_scales if combined_tax_scales is None: - combined_tax_scales = taxscales.MarginalRateTaxScale(name = name) + combined_tax_scales = taxscales.MarginalRateTaxScale(name=name) combined_tax_scales.add_bracket(0, 0) for child_name in node: @@ -41,6 +40,6 @@ def combine_tax_scales( log.info( f"Skipping {child_name} with value {child} " "because it is not a marginal rate tax scale", - ) + ) return combined_tax_scales diff --git a/openfisca_core/taxscales/linear_average_rate_tax_scale.py b/openfisca_core/taxscales/linear_average_rate_tax_scale.py index d1fe9c8094..ffccfc2205 100644 --- a/openfisca_core/taxscales/linear_average_rate_tax_scale.py +++ b/openfisca_core/taxscales/linear_average_rate_tax_scale.py @@ -1,26 +1,27 @@ from __future__ import annotations -import logging import typing +import logging + import numpy from openfisca_core import taxscales -from openfisca_core.taxscales import RateTaxScaleLike + +from .rate_tax_scale_like import RateTaxScaleLike log = logging.getLogger(__name__) if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class LinearAverageRateTaxScale(RateTaxScaleLike): - def calc( - self, - tax_base: NumericalArray, - right: bool = False, - ) -> numpy.float_: + self, + tax_base: NumericalArray, + right: bool = False, + ) -> numpy.float32: if len(self.rates) == 1: return tax_base * self.rates[0] @@ -28,17 +29,15 @@ def calc( tiled_thresholds = numpy.tile(self.thresholds, (len(tax_base), 1)) bracket_dummy = (tiled_base >= tiled_thresholds[:, :-1]) * ( - + tiled_base - < tiled_thresholds[:, 1:] - ) + +tiled_base < tiled_thresholds[:, 1:] + ) rates_array = numpy.array(self.rates) thresholds_array = numpy.array(self.thresholds) rate_slope = (rates_array[1:] - rates_array[:-1]) / ( - + thresholds_array[1:] - - thresholds_array[:-1] - ) + +thresholds_array[1:] - thresholds_array[:-1] + ) average_rate_slope = numpy.dot(bracket_dummy, rate_slope.T) @@ -49,17 +48,16 @@ def calc( log.info(f"average_rate_slope: {average_rate_slope}") return tax_base * ( - + bracket_average_start_rate - + (tax_base - bracket_threshold) - * average_rate_slope - ) + +bracket_average_start_rate + + (tax_base - bracket_threshold) * average_rate_slope + ) def to_marginal(self) -> taxscales.MarginalRateTaxScale: marginal_tax_scale = taxscales.MarginalRateTaxScale( - name = self.name, - option = self.option, - unit = self.unit, - ) + name=self.name, + option=self.option, + unit=self.unit, + ) previous_i = 0 previous_threshold = 0 @@ -70,7 +68,7 @@ def to_marginal(self) -> taxscales.MarginalRateTaxScale: marginal_tax_scale.add_bracket( previous_threshold, (i - previous_i) / (threshold - previous_threshold), - ) + ) previous_i = i previous_threshold = threshold diff --git a/openfisca_core/taxscales/marginal_amount_tax_scale.py b/openfisca_core/taxscales/marginal_amount_tax_scale.py index 348f2445c0..aa96bff57b 100644 --- a/openfisca_core/taxscales/marginal_amount_tax_scale.py +++ b/openfisca_core/taxscales/marginal_amount_tax_scale.py @@ -4,31 +4,31 @@ import numpy -from openfisca_core.taxscales import AmountTaxScaleLike +from .amount_tax_scale_like import AmountTaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class MarginalAmountTaxScale(AmountTaxScaleLike): - def calc( - self, - tax_base: NumericalArray, - right: bool = False, - ) -> numpy.float_: - """ - Matches the input amount to a set of brackets and returns the sum of + self, + tax_base: NumericalArray, + right: bool = False, + ) -> numpy.float32: + """Matches the input amount to a set of brackets and returns the sum of cell values from the lowest bracket to the one containing the input. """ base1 = numpy.tile(tax_base, (len(self.thresholds), 1)).T thresholds1 = numpy.tile( - numpy.hstack((self.thresholds, numpy.inf)), (len(tax_base), 1) - ) + numpy.hstack((self.thresholds, numpy.inf)), + (len(tax_base), 1), + ) a = numpy.maximum( - numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0 - ) + numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], + 0, + ) return numpy.dot(self.amounts, a.T > 0) diff --git a/openfisca_core/taxscales/marginal_rate_tax_scale.py b/openfisca_core/taxscales/marginal_rate_tax_scale.py index 38331e0bb8..803a5f8547 100644 --- a/openfisca_core/taxscales/marginal_rate_tax_scale.py +++ b/openfisca_core/taxscales/marginal_rate_tax_scale.py @@ -1,44 +1,44 @@ from __future__ import annotations +import typing + import bisect import itertools -import typing import numpy from openfisca_core import taxscales -from openfisca_core.taxscales import RateTaxScaleLike + +from .rate_tax_scale_like import RateTaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class MarginalRateTaxScale(RateTaxScaleLike): - def add_tax_scale(self, tax_scale: RateTaxScaleLike) -> None: # So as not to have problems with empty scales - if (len(tax_scale.thresholds) > 0): + if len(tax_scale.thresholds) > 0: for threshold_low, threshold_high, rate in zip( - tax_scale.thresholds[:-1], - tax_scale.thresholds[1:], - tax_scale.rates, - ): + tax_scale.thresholds[:-1], + tax_scale.thresholds[1:], + tax_scale.rates, + ): self.combine_bracket(rate, threshold_low, threshold_high) # To process the last threshold self.combine_bracket( tax_scale.rates[-1], tax_scale.thresholds[-1], - ) + ) def calc( - self, - tax_base: NumericalArray, - factor: float = 1.0, - round_base_decimals: typing.Optional[int] = None, - ) -> numpy.float_: - """ - Compute the tax amount for the given tax bases by applying a taxscale. + self, + tax_base: NumericalArray, + factor: float = 1.0, + round_base_decimals: int | None = None, + ) -> numpy.float32: + """Compute the tax amount for the given tax bases by applying a taxscale. :param ndarray tax_base: Array of the tax bases. :param float factor: Factor to apply to the thresholds of the taxscale. @@ -67,31 +67,31 @@ def calc( # # numpy.finfo(float_).eps thresholds1 = numpy.outer( - factor + numpy.finfo(numpy.float_).eps, - numpy.array(self.thresholds + [numpy.inf]), - ) + factor + numpy.finfo(numpy.float64).eps, + numpy.array([*self.thresholds, numpy.inf]), + ) if round_base_decimals is not None: - thresholds1 = numpy.round_(thresholds1, round_base_decimals) + thresholds1 = numpy.round(thresholds1, round_base_decimals) a = numpy.maximum( - numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0 - ) + numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], + 0, + ) if round_base_decimals is None: return numpy.dot(self.rates, a.T) - else: - r = numpy.tile(self.rates, (len(tax_base), 1)) - b = numpy.round_(a, round_base_decimals) - return numpy.round_(r * b, round_base_decimals).sum(axis = 1) + r = numpy.tile(self.rates, (len(tax_base), 1)) + b = numpy.round(a, round_base_decimals) + return numpy.round(r * b, round_base_decimals).sum(axis=1) def combine_bracket( - self, - rate: typing.Union[int, float], - threshold_low: int = 0, - threshold_high: typing.Union[int, bool] = False, - ) -> None: + self, + rate: int | float, + threshold_low: int = 0, + threshold_high: int | bool = False, + ) -> None: # Insert threshold_low and threshold_high without modifying rates if threshold_low not in self.thresholds: index = bisect.bisect_right(self.thresholds, threshold_low) - 1 @@ -115,13 +115,12 @@ def combine_bracket( i += 1 def marginal_rates( - self, - tax_base: NumericalArray, - factor: float = 1.0, - round_base_decimals: typing.Optional[int] = None, - ) -> numpy.float_: - """ - Compute the marginal tax rates relevant for the given tax bases. + self, + tax_base: NumericalArray, + factor: float = 1.0, + round_base_decimals: int | None = None, + ) -> numpy.float32: + """Compute the marginal tax rates relevant for the given tax bases. :param ndarray tax_base: Array of the tax bases. :param float factor: Factor to apply to the thresholds of a tax scale. @@ -144,13 +143,72 @@ def marginal_rates( tax_base, factor, round_base_decimals, - ) + ) return numpy.array(self.rates)[bracket_indices] - def inverse(self) -> MarginalRateTaxScale: + def rate_from_bracket_indice( + self, + bracket_indice: numpy.int16, + ) -> numpy.float32: + """Compute the relevant tax rates for the given bracket indices. + + :param: ndarray bracket_indice: Array of the bracket indices. + + :returns: Floating array with relevant tax rates + for the given bracket indices. + + For instance: + + >>> import numpy + >>> tax_scale = MarginalRateTaxScale() + >>> tax_scale.add_bracket(0, 0) + >>> tax_scale.add_bracket(200, 0.1) + >>> tax_scale.add_bracket(500, 0.25) + >>> tax_base = numpy.array([50, 1_000, 250]) + >>> bracket_indice = tax_scale.bracket_indices(tax_base) + >>> tax_scale.rate_from_bracket_indice(bracket_indice) + array([0. , 0.25, 0.1 ]) + """ + if bracket_indice.max() > len(self.rates) - 1: + msg = ( + f"bracket_indice parameter ({bracket_indice}) " + f"contains one or more bracket indice which is unavailable " + f"inside current {self.__class__.__name__} :\n" + f"{self}" + ) + raise IndexError( + msg, + ) + + return numpy.array(self.rates)[bracket_indice] + + def rate_from_tax_base( + self, + tax_base: NumericalArray, + ) -> numpy.float32: + """Compute the relevant tax rates for the given tax bases. + + :param: ndarray tax_base: Array of the tax bases. + + :returns: Floating array with relevant tax rates + for the given tax bases. + + For instance: + + >>> import numpy + >>> tax_scale = MarginalRateTaxScale() + >>> tax_scale.add_bracket(0, 0) + >>> tax_scale.add_bracket(200, 0.1) + >>> tax_scale.add_bracket(500, 0.25) + >>> tax_base = numpy.array([1_000, 50, 450]) + >>> tax_scale.rate_from_tax_base(tax_base) + array([0.25, 0. , 0.1 ]) """ - Returns a new instance of MarginalRateTaxScale. + return self.rate_from_bracket_indice(self.bracket_indices(tax_base)) + + def inverse(self) -> MarginalRateTaxScale: + """Returns a new instance of MarginalRateTaxScale. Invert a taxscale: @@ -176,10 +234,10 @@ def inverse(self) -> MarginalRateTaxScale: # Actually 1 / (1 - global_rate) inverse = self.__class__( - name = str(self.name) + "'", - option = self.option, - unit = self.unit, - ) + name=str(self.name) + "'", + option=self.option, + unit=self.unit, + ) for threshold, rate in zip(self.thresholds, self.rates): if threshold == 0: @@ -202,10 +260,10 @@ def scale_tax_scales(self, factor: float) -> MarginalRateTaxScale: def to_average(self) -> taxscales.LinearAverageRateTaxScale: average_tax_scale = taxscales.LinearAverageRateTaxScale( - name = self.name, - option = self.option, - unit = self.unit, - ) + name=self.name, + option=self.option, + unit=self.unit, + ) average_tax_scale.add_bracket(0, 0) @@ -215,10 +273,10 @@ def to_average(self) -> taxscales.LinearAverageRateTaxScale: previous_rate = self.rates[0] for threshold, rate in itertools.islice( - zip(self.thresholds, self.rates), - 1, - None, - ): + zip(self.thresholds, self.rates), + 1, + None, + ): i += previous_rate * (threshold - previous_threshold) average_tax_scale.add_bracket(threshold, i / threshold) previous_threshold = threshold diff --git a/openfisca_core/taxscales/rate_tax_scale_like.py b/openfisca_core/taxscales/rate_tax_scale_like.py index 824a94debe..288226f11e 100644 --- a/openfisca_core/taxscales/rate_tax_scale_like.py +++ b/openfisca_core/taxscales/rate_tax_scale_like.py @@ -1,34 +1,35 @@ from __future__ import annotations +import typing + import abc import bisect import os -import typing import numpy from openfisca_core import tools from openfisca_core.errors import EmptyArgumentError -from openfisca_core.taxscales import TaxScaleLike + +from .tax_scale_like import TaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class RateTaxScaleLike(TaxScaleLike, abc.ABC): - """ - Base class for various types of rate-based tax scales: marginal rate, + """Base class for various types of rate-based tax scales: marginal rate, linear average rate... """ - rates: typing.List + rates: list def __init__( - self, - name: typing.Optional[str] = None, - option: typing.Any = None, - unit: typing.Any = None, - ) -> None: + self, + name: str | None = None, + option: typing.Any = None, + unit: typing.Any = None, + ) -> None: super().__init__(name, option, unit) self.rates = [] @@ -37,17 +38,16 @@ def __repr__(self) -> str: os.linesep.join( [ f"- threshold: {threshold}{os.linesep} rate: {rate}" - for (threshold, rate) - in zip(self.thresholds, self.rates) - ] - ) - ) + for (threshold, rate) in zip(self.thresholds, self.rates) + ], + ), + ) def add_bracket( - self, - threshold: typing.Union[int, float], - rate: typing.Union[int, float], - ) -> None: + self, + threshold: int | float, + rate: int | float, + ) -> None: if threshold in self.thresholds: i = self.thresholds.index(threshold) self.rates[i] += rate @@ -58,11 +58,11 @@ def add_bracket( self.rates.insert(i, rate) def multiply_rates( - self, - factor: float, - inplace: bool = True, - new_name: typing.Optional[str] = None, - ) -> RateTaxScaleLike: + self, + factor: float, + inplace: bool = True, + new_name: str | None = None, + ) -> RateTaxScaleLike: if inplace: assert new_name is None @@ -73,9 +73,9 @@ def multiply_rates( new_tax_scale = self.__class__( new_name or self.name, - option = self.option, - unit = self.unit, - ) + option=self.option, + unit=self.unit, + ) for threshold, rate in zip(self.thresholds, self.rates): new_tax_scale.thresholds.append(threshold) @@ -84,12 +84,12 @@ def multiply_rates( return new_tax_scale def multiply_thresholds( - self, - factor: float, - decimals: typing.Optional[int] = None, - inplace: bool = True, - new_name: typing.Optional[str] = None, - ) -> RateTaxScaleLike: + self, + factor: float, + decimals: int | None = None, + inplace: bool = True, + new_name: str | None = None, + ) -> RateTaxScaleLike: if inplace: assert new_name is None @@ -97,8 +97,8 @@ def multiply_thresholds( if decimals is not None: self.thresholds[i] = numpy.around( threshold * factor, - decimals = decimals, - ) + decimals=decimals, + ) else: self.thresholds[i] = threshold * factor @@ -107,15 +107,15 @@ def multiply_thresholds( new_tax_scale = self.__class__( new_name or self.name, - option = self.option, - unit = self.unit, - ) + option=self.option, + unit=self.unit, + ) for threshold, rate in zip(self.thresholds, self.rates): if decimals is not None: new_tax_scale.thresholds.append( - numpy.around(threshold * factor, decimals = decimals), - ) + numpy.around(threshold * factor, decimals=decimals), + ) else: new_tax_scale.thresholds.append(threshold * factor) @@ -124,13 +124,12 @@ def multiply_thresholds( return new_tax_scale def bracket_indices( - self, - tax_base: NumericalArray, - factor: float = 1.0, - round_decimals: typing.Optional[int] = None, - ) -> numpy.int_: - """ - Compute the relevant bracket indices for the given tax bases. + self, + tax_base: NumericalArray, + factor: float = 1.0, + round_decimals: int | None = None, + ) -> numpy.int32: + """Compute the relevant bracket indices for the given tax bases. :param ndarray tax_base: Array of the tax bases. :param float factor: Factor to apply to the thresholds. @@ -148,14 +147,13 @@ def bracket_indices( >>> tax_scale.bracket_indices(tax_base) [0, 1] """ - if not numpy.size(numpy.array(self.thresholds)): raise EmptyArgumentError( self.__class__.__name__, "bracket_indices", "self.thresholds", self.thresholds, - ) + ) if not numpy.size(numpy.asarray(tax_base)): raise EmptyArgumentError( @@ -163,7 +161,7 @@ def bracket_indices( "bracket_indices", "tax_base", tax_base, - ) + ) base1 = numpy.tile(tax_base, (len(self.thresholds), 1)).T factor = numpy.ones(len(tax_base)) * factor @@ -176,18 +174,42 @@ def bracket_indices( # # numpy.finfo(float_).eps thresholds1 = numpy.outer( - + factor - + numpy.finfo(numpy.float_).eps, numpy.array(self.thresholds) - ) + +factor + numpy.finfo(numpy.float64).eps, + numpy.array(self.thresholds), + ) if round_decimals is not None: - thresholds1 = numpy.round_(thresholds1, round_decimals) + thresholds1 = numpy.round(thresholds1, round_decimals) + + return (base1 - thresholds1 >= 0).sum(axis=1) - 1 + + def threshold_from_tax_base( + self, + tax_base: NumericalArray, + ) -> NumericalArray: + """Compute the relevant thresholds for the given tax bases. - return (base1 - thresholds1 >= 0).sum(axis = 1) - 1 + :param: ndarray tax_base: Array of the tax bases. + + :returns: Floating array with relevant thresholds + for the given tax bases. + + For instance: + + >>> import numpy + >>> from openfisca_core import taxscales + >>> tax_scale = taxscales.MarginalRateTaxScale() + >>> tax_scale.add_bracket(0, 0) + >>> tax_scale.add_bracket(200, 0.1) + >>> tax_scale.add_bracket(500, 0.25) + >>> tax_base = numpy.array([450, 1_150, 10]) + >>> tax_scale.threshold_from_tax_base(tax_base) + array([200, 500, 0]) + """ + return numpy.array(self.thresholds)[self.bracket_indices(tax_base)] def to_dict(self) -> dict: return { str(threshold): self.rates[index] - for index, threshold - in enumerate(self.thresholds) - } + for index, threshold in enumerate(self.thresholds) + } diff --git a/openfisca_core/taxscales/single_amount_tax_scale.py b/openfisca_core/taxscales/single_amount_tax_scale.py index bdfee48010..1c8cf69a32 100644 --- a/openfisca_core/taxscales/single_amount_tax_scale.py +++ b/openfisca_core/taxscales/single_amount_tax_scale.py @@ -7,36 +7,26 @@ from openfisca_core.taxscales import AmountTaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class SingleAmountTaxScale(AmountTaxScaleLike): - def calc( - self, - tax_base: NumericalArray, - right: bool = False, - ) -> numpy.float_: - """ - Matches the input amount to a set of brackets and returns the single + self, + tax_base: NumericalArray, + right: bool = False, + ) -> numpy.float32: + """Matches the input amount to a set of brackets and returns the single cell value that fits within that bracket. """ - guarded_thresholds = numpy.array( - [-numpy.inf] - + self.thresholds - + [numpy.inf] - ) + guarded_thresholds = numpy.array([-numpy.inf, *self.thresholds, numpy.inf]) bracket_indices = numpy.digitize( tax_base, guarded_thresholds, - right = right, - ) - - guarded_amounts = numpy.array( - [0] - + self.amounts - + [0] - ) + right=right, + ) + + guarded_amounts = numpy.array([0, *self.amounts, 0]) return guarded_amounts[bracket_indices - 1] diff --git a/openfisca_core/taxscales/tax_scale_like.py b/openfisca_core/taxscales/tax_scale_like.py index 8177ee0505..e8680b9f8f 100644 --- a/openfisca_core/taxscales/tax_scale_like.py +++ b/openfisca_core/taxscales/tax_scale_like.py @@ -1,67 +1,64 @@ from __future__ import annotations -import abc -import copy import typing -import numpy +import abc +import copy from openfisca_core import commons if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + import numpy + + NumericalArray = typing.Union[numpy.int32, numpy.float32] class TaxScaleLike(abc.ABC): - """ - Base class for various types of tax scales: amount-based tax scales, + """Base class for various types of tax scales: amount-based tax scales, rate-based tax scales... """ - name: typing.Optional[str] + name: str | None option: typing.Any unit: typing.Any - thresholds: typing.List + thresholds: list @abc.abstractmethod def __init__( - self, - name: typing.Optional[str] = None, - option: typing.Any = None, - unit: typing.Any = None, - ) -> None: + self, + name: str | None = None, + option: typing.Any = None, + unit: typing.Any = None, + ) -> None: self.name = name or "Untitled TaxScale" self.option = option self.unit = unit self.thresholds = [] def __eq__(self, _other: object) -> typing.NoReturn: + msg = "Method '__eq__' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method '__eq__' is not implemented for " - f"{self.__class__.__name__}", - ) + msg, + ) def __ne__(self, _other: object) -> typing.NoReturn: + msg = "Method '__ne__' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method '__ne__' is not implemented for " - f"{self.__class__.__name__}", - ) + msg, + ) @abc.abstractmethod - def __repr__(self) -> str: - ... + def __repr__(self) -> str: ... @abc.abstractmethod def calc( - self, - tax_base: NumericalArray, - right: bool, - ) -> numpy.float_: - ... + self, + tax_base: NumericalArray, + right: bool, + ) -> numpy.float32: ... @abc.abstractmethod - def to_dict(self) -> dict: - ... + def to_dict(self) -> dict: ... def copy(self) -> typing.Any: new = commons.empty_clone(self) diff --git a/openfisca_core/tools/__init__.py b/openfisca_core/tools/__init__.py index 9b1dd2cc5d..952dca6ebd 100644 --- a/openfisca_core/tools/__init__.py +++ b/openfisca_core/tools/__init__.py @@ -1,65 +1,70 @@ -# -*- coding: utf-8 -*- - - import os -import numexpr - +from openfisca_core import commons from openfisca_core.indexed_enums import EnumArray -def assert_near(value, target_value, absolute_error_margin = None, message = '', relative_error_margin = None): - ''' - - :param value: Value returned by the test - :param target_value: Value that the test should return to pass - :param absolute_error_margin: Absolute error margin authorized - :param message: Error message to be displayed if the test fails - :param relative_error_margin: Relative error margin authorized - - Limit : This function cannot be used to assert near periods. +def assert_near( + value, + target_value, + absolute_error_margin=None, + message="", + relative_error_margin=None, +): + """:param value: Value returned by the test + :param target_value: Value that the test should return to pass + :param absolute_error_margin: Absolute error margin authorized + :param message: Error message to be displayed if the test fails + :param relative_error_margin: Relative error margin authorized - ''' + Limit : This function cannot be used to assert near periods. - import numpy as np + """ + import numpy if absolute_error_margin is None and relative_error_margin is None: absolute_error_margin = 0 - if not isinstance(value, np.ndarray): - value = np.array(value) + if not isinstance(value, numpy.ndarray): + value = numpy.array(value) if isinstance(value, EnumArray): return assert_enum_equals(value, target_value, message) - if np.issubdtype(value.dtype, np.datetime64): - target_value = np.array(target_value, dtype = value.dtype) + if numpy.issubdtype(value.dtype, numpy.datetime64): + target_value = numpy.array(target_value, dtype=value.dtype) assert_datetime_equals(value, target_value, message) if isinstance(target_value, str): - target_value = eval_expression(target_value) + target_value = commons.eval_expression(target_value) - target_value = np.array(target_value).astype(np.float32) + target_value = numpy.array(target_value).astype(numpy.float32) - value = np.array(value).astype(np.float32) + value = numpy.array(value).astype(numpy.float32) diff = abs(target_value - value) if absolute_error_margin is not None: - assert (diff <= absolute_error_margin).all(), \ - '{}{} differs from {} with an absolute margin {} > {}'.format(message, value, target_value, - diff, absolute_error_margin) + assert ( + diff <= absolute_error_margin + ).all(), f"{message}{value} differs from {target_value} with an absolute margin {diff} > {absolute_error_margin}" if relative_error_margin is not None: - assert (diff <= abs(relative_error_margin * target_value)).all(), \ - '{}{} differs from {} with a relative margin {} > {}'.format(message, value, target_value, - diff, abs(relative_error_margin * target_value)) + assert ( + diff <= abs(relative_error_margin * target_value) + ).all(), f"{message}{value} differs from {target_value} with a relative margin {diff} > {abs(relative_error_margin * target_value)}" + return None + return None -def assert_datetime_equals(value, target_value, message = ''): - assert (value == target_value).all(), '{}{} differs from {}.'.format(message, value, target_value) +def assert_datetime_equals(value, target_value, message="") -> None: + assert ( + value == target_value + ).all(), f"{message}{value} differs from {target_value}." -def assert_enum_equals(value, target_value, message = ''): +def assert_enum_equals(value, target_value, message="") -> None: value = value.decode_to_str() - assert (value == target_value).all(), '{}{} differs from {}.'.format(message, value, target_value) + assert ( + value == target_value + ).all(), f"{message}{value} differs from {target_value}." def indent(text): - return " {}".format(text.replace(os.linesep, "{} ".format(os.linesep))) + return " {}".format(text.replace(os.linesep, f"{os.linesep} ")) def get_trace_tool_link(scenario, variables, api_url, trace_tool_url): @@ -68,18 +73,16 @@ def get_trace_tool_link(scenario, variables, api_url, trace_tool_url): scenario_json = scenario.to_json() simulation_json = { - 'scenarios': [scenario_json], - 'variables': variables, - } - url = trace_tool_url + '?' + urllib.urlencode({ - 'simulation': json.dumps(simulation_json), - 'api_url': api_url, - }) - return url - - -def eval_expression(expression): - try: - return numexpr.evaluate(expression) - except (KeyError, TypeError): - return expression + "scenarios": [scenario_json], + "variables": variables, + } + return ( + trace_tool_url + + "?" + + urllib.urlencode( + { + "simulation": json.dumps(simulation_json), + "api_url": api_url, + }, + ) + ) diff --git a/openfisca_core/tools/simulation_dumper.py b/openfisca_core/tools/simulation_dumper.py index 4b5907c0ff..84898165fd 100644 --- a/openfisca_core/tools/simulation_dumper.py +++ b/openfisca_core/tools/simulation_dumper.py @@ -1,19 +1,14 @@ -# -*- coding: utf-8 -*- - - import os -import numpy as np +import numpy -from openfisca_core.simulations import Simulation from openfisca_core.data_storage import OnDiskStorage -from openfisca_core.periods import ETERNITY +from openfisca_core.periods import DateUnit +from openfisca_core.simulations import Simulation -def dump_simulation(simulation, directory): - """ - Write simulation data to directory, so that it can be restored later. - """ +def dump_simulation(simulation, directory) -> None: + """Write simulation data to directory, so that it can be restored later.""" parent_directory = os.path.abspath(os.path.join(directory, os.pardir)) if not os.path.isdir(parent_directory): # To deal with reforms os.mkdir(parent_directory) @@ -21,7 +16,8 @@ def dump_simulation(simulation, directory): os.mkdir(directory) if os.listdir(directory): - raise ValueError("Directory '{}' is not empty".format(directory)) + msg = f"Directory '{directory}' is not empty" + raise ValueError(msg) entities_dump_dir = os.path.join(directory, "__entities__") os.mkdir(entities_dump_dir) @@ -36,10 +32,11 @@ def dump_simulation(simulation, directory): def restore_simulation(directory, tax_benefit_system, **kwargs): - """ - Restore simulation from directory - """ - simulation = Simulation(tax_benefit_system, tax_benefit_system.instantiate_entities()) + """Restore simulation from directory.""" + simulation = Simulation( + tax_benefit_system, + tax_benefit_system.instantiate_entities(), + ) entities_dump_dir = os.path.join(directory, "__entities__") for population in simulation.populations.values(): @@ -53,75 +50,84 @@ def restore_simulation(directory, tax_benefit_system, **kwargs): _restore_entity(population, entities_dump_dir) population.count = person_count - variables_to_restore = (variable for variable in os.listdir(directory) if variable != "__entities__") + variables_to_restore = ( + variable for variable in os.listdir(directory) if variable != "__entities__" + ) for variable in variables_to_restore: _restore_holder(simulation, variable, directory) return simulation -def _dump_holder(holder, directory): - disk_storage = holder.create_disk_storage(directory, preserve = True) +def _dump_holder(holder, directory) -> None: + disk_storage = holder.create_disk_storage(directory, preserve=True) for period in holder.get_known_periods(): value = holder.get_array(period) disk_storage.put(value, period) -def _dump_entity(population, directory): +def _dump_entity(population, directory) -> None: path = os.path.join(directory, population.entity.key) os.mkdir(path) - np.save(os.path.join(path, "id.npy"), population.ids) + numpy.save(os.path.join(path, "id.npy"), population.ids) if population.entity.is_person: return - np.save(os.path.join(path, "members_position.npy"), population.members_position) - np.save(os.path.join(path, "members_entity_id.npy"), population.members_entity_id) + numpy.save(os.path.join(path, "members_position.npy"), population.members_position) + numpy.save( + os.path.join(path, "members_entity_id.npy"), population.members_entity_id + ) flattened_roles = population.entity.flattened_roles if len(flattened_roles) == 0: - encoded_roles = np.int64(0) + encoded_roles = numpy.int16(0) else: - encoded_roles = np.select( + encoded_roles = numpy.select( [population.members_role == role for role in flattened_roles], [role.key for role in flattened_roles], - ) - np.save(os.path.join(path, "members_role.npy"), encoded_roles) + ) + numpy.save(os.path.join(path, "members_role.npy"), encoded_roles) def _restore_entity(population, directory): path = os.path.join(directory, population.entity.key) - population.ids = np.load(os.path.join(path, "id.npy")) + population.ids = numpy.load(os.path.join(path, "id.npy")) if population.entity.is_person: - return + return None - population.members_position = np.load(os.path.join(path, "members_position.npy")) - population.members_entity_id = np.load(os.path.join(path, "members_entity_id.npy")) - encoded_roles = np.load(os.path.join(path, "members_role.npy")) + population.members_position = numpy.load(os.path.join(path, "members_position.npy")) + population.members_entity_id = numpy.load( + os.path.join(path, "members_entity_id.npy") + ) + encoded_roles = numpy.load(os.path.join(path, "members_role.npy")) flattened_roles = population.entity.flattened_roles if len(flattened_roles) == 0: - population.members_role = np.int64(0) + population.members_role = numpy.int16(0) else: - population.members_role = np.select( + population.members_role = numpy.select( [encoded_roles == role.key for role in flattened_roles], - [role for role in flattened_roles], - ) + list(flattened_roles), + ) person_count = len(population.members_entity_id) population.count = max(population.members_entity_id) + 1 return person_count -def _restore_holder(simulation, variable, directory): +def _restore_holder(simulation, variable, directory) -> None: storage_dir = os.path.join(directory, variable) - is_variable_eternal = simulation.tax_benefit_system.get_variable(variable).definition_period == ETERNITY + is_variable_eternal = ( + simulation.tax_benefit_system.get_variable(variable).definition_period + == DateUnit.ETERNITY + ) disk_storage = OnDiskStorage( storage_dir, - is_eternal = is_variable_eternal, - preserve_storage_dir = True - ) + is_eternal=is_variable_eternal, + preserve_storage_dir=True, + ) disk_storage.restore() holder = simulation.get_holder(variable) diff --git a/openfisca_core/tools/test_runner.py b/openfisca_core/tools/test_runner.py index 286ff06991..1f7b603b6a 100644 --- a/openfisca_core/tools/test_runner.py +++ b/openfisca_core/tools/test_runner.py @@ -1,22 +1,83 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations -import warnings -import sys +from collections.abc import Sequence +from typing import Any +from typing_extensions import Literal, TypedDict + +from openfisca_core.types import TaxBenefitSystem + +import dataclasses import os -import traceback +import pathlib +import sys import textwrap -from typing import Dict, List +import traceback +import warnings import pytest -from openfisca_core.tools import assert_near +from openfisca_core.errors import SituationParsingError, VariableNotFound from openfisca_core.simulations import SimulationBuilder -from openfisca_core.errors import SituationParsingError, VariableNotFoundError +from openfisca_core.tools import assert_near from openfisca_core.warnings import LibYAMLWarning +class Options(TypedDict, total=False): + aggregate: bool + ignore_variables: Sequence[str] | None + max_depth: int | None + name_filter: str | None + only_variables: Sequence[str] | None + pdb: bool + performance_graph: bool + performance_tables: bool + verbose: bool + + +@dataclasses.dataclass(frozen=True) +class ErrorMargin: + __root__: dict[str | Literal["default"], float | None] + + def __getitem__(self, key: str) -> float | None: + if key in self.__root__: + return self.__root__[key] + + return self.__root__["default"] + + +@dataclasses.dataclass +class Test: + absolute_error_margin: ErrorMargin + relative_error_margin: ErrorMargin + name: str = "" + input: dict[str, float | dict[str, float]] = dataclasses.field(default_factory=dict) + output: dict[str, float | dict[str, float]] | None = None + period: str | None = None + reforms: Sequence[str] = dataclasses.field(default_factory=list) + keywords: Sequence[str] | None = None + extensions: Sequence[str] = dataclasses.field(default_factory=list) + description: str | None = None + max_spiral_loops: int | None = None + + +def build_test(params: dict[str, Any]) -> Test: + for key in ["absolute_error_margin", "relative_error_margin"]: + value = params.get(key) + + if value is None: + value = {"default": None} + + elif isinstance(value, (float, int, str)): + value = {"default": float(value)} + + params[key] = ErrorMargin(value) + + return Test(**params) + + def import_yaml(): import yaml + try: from yaml import CLoader as Loader except ImportError: @@ -24,33 +85,55 @@ def import_yaml(): "libyaml is not installed in your environment.", "This can make your test suite slower to run. Once you have installed libyaml, ", "run 'pip uninstall pyyaml && pip install pyyaml --no-cache-dir'", - "so that it is used in your Python environment." - ] - warnings.warn(" ".join(message), LibYAMLWarning) + "so that it is used in your Python environment.", + ] + warnings.warn(" ".join(message), LibYAMLWarning, stacklevel=2) from yaml import SafeLoader as Loader return yaml, Loader -TEST_KEYWORDS = {'absolute_error_margin', 'description', 'extensions', 'ignore_variables', 'input', 'keywords', 'max_spiral_loops', 'name', 'only_variables', 'output', 'period', 'reforms', 'relative_error_margin'} +TEST_KEYWORDS = { + "absolute_error_margin", + "description", + "extensions", + "ignore_variables", + "input", + "keywords", + "max_spiral_loops", + "name", + "only_variables", + "output", + "period", + "reforms", + "relative_error_margin", +} yaml, Loader = import_yaml() -_tax_benefit_system_cache: Dict = {} +_tax_benefit_system_cache: dict = {} +options: Options = Options() -def run_tests(tax_benefit_system, paths, options = None): - """ - Runs all the YAML tests contained in a file or a directory. - If `path` is a directory, subdirectories will be recursively explored. +def run_tests( + tax_benefit_system: TaxBenefitSystem, + paths: str | Sequence[str], + options: Options = options, +) -> int: + """Runs all the YAML tests contained in a file or a directory. - :param .TaxBenefitSystem tax_benefit_system: the tax-benefit system to use to run the tests - :param str or list paths: A path, or a list of paths, towards the files or directories containing the tests to run. If a path is a directory, subdirectories will be recursively explored. - :param dict options: See more details below. + If ``path`` is a directory, subdirectories will be recursively explored. - :raises :exc:`AssertionError`: if a test does not pass + Args: + tax_benefit_system: the tax-benefit system to use to run the tests. + paths: A path, or a list of paths, towards the files or directories containing the tests to run. If a path is a directory, subdirectories will be recursively explored. + options: See more details below. - :return: the number of sucessful tests excecuted + Returns: + The number of successful tests executed. + + Raises: + :exc:`AssertionError`: if a test does not pass. **Testing options**: @@ -63,98 +146,108 @@ def run_tests(tax_benefit_system, paths, options = None): +-------------------------------+-----------+-------------------------------------------+ """ - argv = [] + plugins = [OpenFiscaPlugin(tax_benefit_system, options)] - if options.get('pdb'): - argv.append('--pdb') + if options.get("pdb"): + argv.append("--pdb") - if options.get('verbose'): - argv.append('--verbose') + if options.get("verbose"): + argv.append("--verbose") if isinstance(paths, str): paths = [paths] - return pytest.main([*argv, *paths] if True else paths, plugins = [OpenFiscaPlugin(tax_benefit_system, options)]) + return pytest.main([*argv, *paths], plugins=plugins) class YamlFile(pytest.File): - - def __init__(self, path, fspath, parent, tax_benefit_system, options): - super(YamlFile, self).__init__(path, parent) + def __init__(self, *, tax_benefit_system, options, **kwargs) -> None: + super().__init__(**kwargs) self.tax_benefit_system = tax_benefit_system self.options = options def collect(self): try: - tests = yaml.load(self.fspath.open(), Loader = Loader) + tests = yaml.load(open(self.path), Loader=Loader) except (yaml.scanner.ScannerError, yaml.parser.ParserError, TypeError): - message = os.linesep.join([ - traceback.format_exc(), - f"'{self.fspath}' is not a valid YAML file. Check the stack trace above for more details.", - ]) + message = os.linesep.join( + [ + traceback.format_exc(), + f"'{self.path}' is not a valid YAML file. Check the stack trace above for more details.", + ], + ) raise ValueError(message) if not isinstance(tests, list): - tests: List[Dict] = [tests] + tests: Sequence[dict] = [tests] for test in tests: if not self.should_ignore(test): - yield YamlItem.from_parent(self, - name = '', - baseline_tax_benefit_system = self.tax_benefit_system, - test = test, options = self.options) + yield YamlItem.from_parent( + self, + name="", + baseline_tax_benefit_system=self.tax_benefit_system, + test=test, + options=self.options, + ) def should_ignore(self, test): - name_filter = self.options.get('name_filter') + name_filter = self.options.get("name_filter") return ( name_filter is not None - and name_filter not in os.path.splitext(self.fspath.basename)[0] - and name_filter not in test.get('name', '') - and name_filter not in test.get('keywords', []) - ) + and name_filter not in os.path.splitext(os.path.basename(self.path))[0] + and name_filter not in test.get("name", "") + and name_filter not in test.get("keywords", []) + ) class YamlItem(pytest.Item): - """ - Terminal nodes of the test collection tree. - """ + """Terminal nodes of the test collection tree.""" - def __init__(self, name, parent, baseline_tax_benefit_system, test, options): - super(YamlItem, self).__init__(name, parent) + def __init__(self, *, baseline_tax_benefit_system, test, options, **kwargs) -> None: + super().__init__(**kwargs) self.baseline_tax_benefit_system = baseline_tax_benefit_system self.options = options - self.test = test + self.test = build_test(test) self.simulation = None self.tax_benefit_system = None - def runtest(self): - self.name = self.test.get('name', '') - if not self.test.get('output'): - raise ValueError("Missing key 'output' in test '{}' in file '{}'".format(self.name, self.fspath)) + def runtest(self) -> None: + self.name = self.test.name - if not TEST_KEYWORDS.issuperset(self.test.keys()): - unexpected_keys = set(self.test.keys()).difference(TEST_KEYWORDS) - raise ValueError("Unexpected keys {} in test '{}' in file '{}'".format(unexpected_keys, self.name, self.fspath)) + if self.test.output is None: + msg = f"Missing key 'output' in test '{self.name}' in file '{self.path}'" + raise ValueError(msg) - self.tax_benefit_system = _get_tax_benefit_system(self.baseline_tax_benefit_system, self.test.get('reforms', []), self.test.get('extensions', [])) + self.tax_benefit_system = _get_tax_benefit_system( + self.baseline_tax_benefit_system, + self.test.reforms, + self.test.extensions, + ) builder = SimulationBuilder() - input = self.test.get('input', {}) - period = self.test.get('period') - max_spiral_loops = self.test.get('max_spiral_loops') - verbose = self.options.get('verbose') - performance_graph = self.options.get('performance_graph') - performance_tables = self.options.get('performance_tables') + input = self.test.input + period = self.test.period + max_spiral_loops = self.test.max_spiral_loops + verbose = self.options.get("verbose") + aggregate = self.options.get("aggregate") + max_depth = self.options.get("max_depth") + performance_graph = self.options.get("performance_graph") + performance_tables = self.options.get("performance_tables") try: builder.set_default_period(period) self.simulation = builder.build_from_dict(self.tax_benefit_system, input) - except (VariableNotFoundError, SituationParsingError): + except (VariableNotFound, SituationParsingError): raise except Exception as e: - error_message = os.linesep.join([str(e), '', f"Unexpected error raised while parsing '{self.fspath}'"]) - raise ValueError(error_message).with_traceback(sys.exc_info()[2]) from e # Keep the stack trace from the root error + error_message = os.linesep.join( + [str(e), "", f"Unexpected error raised while parsing '{self.path}'"], + ) + raise ValueError(error_message).with_traceback( + sys.exc_info()[2], + ) from e # Keep the stack trace from the root error if max_spiral_loops: self.simulation.max_spiral_loops = max_spiral_loops @@ -165,101 +258,130 @@ def runtest(self): finally: tracer = self.simulation.tracer if verbose: - self.print_computation_log(tracer) + self.print_computation_log(tracer, aggregate, max_depth) if performance_graph: self.generate_performance_graph(tracer) if performance_tables: self.generate_performance_tables(tracer) - def print_computation_log(self, tracer): - print("Computation log:") # noqa T001 - tracer.print_computation_log() + def print_computation_log(self, tracer, aggregate, max_depth) -> None: + tracer.print_computation_log(aggregate, max_depth) - def generate_performance_graph(self, tracer): - tracer.generate_performance_graph('.') + def generate_performance_graph(self, tracer) -> None: + tracer.generate_performance_graph(".") - def generate_performance_tables(self, tracer): - tracer.generate_performance_tables('.') + def generate_performance_tables(self, tracer) -> None: + tracer.generate_performance_tables(".") - def check_output(self): - output = self.test.get('output') + def check_output(self) -> None: + output = self.test.output if output is None: return for key, expected_value in output.items(): if self.tax_benefit_system.get_variable(key): # If key is a variable - self.check_variable(key, expected_value, self.test.get('period')) + self.check_variable(key, expected_value, self.test.period) elif self.simulation.populations.get(key): # If key is an entity singular for variable_name, value in expected_value.items(): - self.check_variable(variable_name, value, self.test.get('period')) + self.check_variable(variable_name, value, self.test.period) else: - population = self.simulation.get_population(plural = key) + population = self.simulation.get_population(plural=key) if population is not None: # If key is an entity plural for instance_id, instance_values in expected_value.items(): for variable_name, value in instance_values.items(): entity_index = population.get_index(instance_id) - self.check_variable(variable_name, value, self.test.get('period'), entity_index) + self.check_variable( + variable_name, + value, + self.test.period, + entity_index, + ) else: - raise VariableNotFoundError(key, self.tax_benefit_system) - - def check_variable(self, variable_name, expected_value, period, entity_index = None): + raise VariableNotFound(key, self.tax_benefit_system) + + def check_variable( + self, + variable_name: str, + expected_value, + period, + entity_index=None, + ): if self.should_ignore_variable(variable_name): - return + return None + if isinstance(expected_value, dict): for requested_period, expected_value_at_period in expected_value.items(): - self.check_variable(variable_name, expected_value_at_period, requested_period, entity_index) - return + self.check_variable( + variable_name, + expected_value_at_period, + requested_period, + entity_index, + ) + + return None actual_value = self.simulation.calculate(variable_name, period) if entity_index is not None: actual_value = actual_value[entity_index] + return assert_near( actual_value, expected_value, - absolute_error_margin = self.test.get('absolute_error_margin'), - message = f"{variable_name}@{period}: ", - relative_error_margin = self.test.get('relative_error_margin'), - ) - - def should_ignore_variable(self, variable_name): - only_variables = self.options.get('only_variables') - ignore_variables = self.options.get('ignore_variables') - variable_ignored = ignore_variables is not None and variable_name in ignore_variables - variable_not_tested = only_variables is not None and variable_name not in only_variables + self.test.absolute_error_margin[variable_name], + f"{variable_name}@{period}: ", + self.test.relative_error_margin[variable_name], + ) + + def should_ignore_variable(self, variable_name: str): + only_variables = self.options.get("only_variables") + ignore_variables = self.options.get("ignore_variables") + variable_ignored = ( + ignore_variables is not None and variable_name in ignore_variables + ) + variable_not_tested = ( + only_variables is not None and variable_name not in only_variables + ) return variable_ignored or variable_not_tested def repr_failure(self, excinfo): - if not isinstance(excinfo.value, (AssertionError, VariableNotFoundError, SituationParsingError)): - return super(YamlItem, self).repr_failure(excinfo) + if not isinstance( + excinfo.value, + (AssertionError, VariableNotFound, SituationParsingError), + ): + return super().repr_failure(excinfo) message = excinfo.value.args[0] if isinstance(excinfo.value, SituationParsingError): message = f"Could not parse situation described: {message}" - return os.linesep.join([ - f"{str(self.fspath)}:", - f" Test '{str(self.name)}':", - textwrap.indent(message, ' ') - ]) - + return os.linesep.join( + [ + f"{self.path!s}:", + f" Test '{self.name!s}':", + textwrap.indent(message, " "), + ], + ) -class OpenFiscaPlugin(object): - def __init__(self, tax_benefit_system, options): +class OpenFiscaPlugin: + def __init__(self, tax_benefit_system, options) -> None: self.tax_benefit_system = tax_benefit_system self.options = options def pytest_collect_file(self, parent, path): - """ - Called by pytest for all plugins. + """Called by pytest for all plugins. :return: The collector for test methods. """ if path.ext in [".yaml", ".yml"]: - return YamlFile.from_parent(parent, path = path, fspath = path, - tax_benefit_system = self.tax_benefit_system, - options = self.options) + return YamlFile.from_parent( + parent, + path=pathlib.Path(path), + tax_benefit_system=self.tax_benefit_system, + options=self.options, + ) + return None def _get_tax_benefit_system(baseline, reforms, extensions): @@ -269,17 +391,18 @@ def _get_tax_benefit_system(baseline, reforms, extensions): extensions = [extensions] # keep reforms order in cache, ignore extensions order - key = hash((id(baseline), ':'.join(reforms), frozenset(extensions))) + key = hash((id(baseline), ":".join(reforms), frozenset(extensions))) if _tax_benefit_system_cache.get(key): return _tax_benefit_system_cache.get(key) - current_tax_benefit_system = baseline + current_tax_benefit_system = baseline.clone() for reform_path in reforms: - current_tax_benefit_system = current_tax_benefit_system.apply_reform(reform_path) + current_tax_benefit_system = current_tax_benefit_system.apply_reform( + reform_path, + ) for extension in extensions: - current_tax_benefit_system = current_tax_benefit_system.clone() current_tax_benefit_system.load_extension(extension) _tax_benefit_system_cache[key] = current_tax_benefit_system diff --git a/openfisca_core/tracers/__init__.py b/openfisca_core/tracers/__init__.py index de489ad6d9..76e36b55cd 100644 --- a/openfisca_core/tracers/__init__.py +++ b/openfisca_core/tracers/__init__.py @@ -21,10 +21,22 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .computation_log import ComputationLog # noqa: F401 -from .flat_trace import FlatTrace # noqa: F401 -from .full_tracer import FullTracer # noqa: F401 -from .performance_log import PerformanceLog # noqa: F401 -from .simple_tracer import SimpleTracer # noqa: F401 -from .trace_node import TraceNode # noqa: F401 -from .tracing_parameter_node_at_instant import TracingParameterNodeAtInstant # noqa: F401 +from . import types +from .computation_log import ComputationLog +from .flat_trace import FlatTrace +from .full_tracer import FullTracer +from .performance_log import PerformanceLog +from .simple_tracer import SimpleTracer +from .trace_node import TraceNode +from .tracing_parameter_node_at_instant import TracingParameterNodeAtInstant + +__all__ = [ + "ComputationLog", + "FlatTrace", + "FullTracer", + "PerformanceLog", + "SimpleTracer", + "TraceNode", + "TracingParameterNodeAtInstant", + "types", +] diff --git a/openfisca_core/tracers/computation_log.py b/openfisca_core/tracers/computation_log.py index c785fd9395..9fcc09e258 100644 --- a/openfisca_core/tracers/computation_log.py +++ b/openfisca_core/tracers/computation_log.py @@ -1,103 +1,108 @@ from __future__ import annotations -import typing -from typing import List, Optional, Union +import sys import numpy -from .. import tracers from openfisca_core.indexed_enums import EnumArray -if typing.TYPE_CHECKING: - from numpy.typing import ArrayLike - - Array = Union[EnumArray, ArrayLike] +from . import types as t class ComputationLog: + _full_tracer: t.FullTracer - _full_tracer: tracers.FullTracer - - def __init__(self, full_tracer: tracers.FullTracer) -> None: + def __init__(self, full_tracer: t.FullTracer) -> None: self._full_tracer = full_tracer - def display( - self, - value: Optional[Array], - ) -> str: - if isinstance(value, EnumArray): - value = value.decode_to_str() + def lines( + self, + aggregate: bool = False, + max_depth: int = sys.maxsize, + ) -> list[str]: + depth = 1 - return numpy.array2string(value, max_line_width = float("inf")) + lines_by_tree = [ + self._get_node_log(node, depth, aggregate, max_depth) + for node in self._full_tracer.trees + ] - def _get_node_log( - self, - node: tracers.TraceNode, - depth: int, - aggregate: bool, - ) -> List[str]: + return self._flatten(lines_by_tree) - def print_line(depth: int, node: tracers.TraceNode) -> str: - indent = ' ' * depth - value = node.value + def print_log(self, aggregate: bool = False, max_depth: int = sys.maxsize) -> None: + """Print the computation log of a simulation. - if value is None: - formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}" + If ``aggregate`` is ``False`` (default), print the value of each + computed vector. - elif aggregate: - try: - formatted_value = str({ - 'avg': numpy.mean(value), - 'max': numpy.max(value), - 'min': numpy.min(value), - }) + If ``aggregate`` is ``True``, only print the minimum, maximum, and + average value of each computed vector. - except TypeError: - formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}" + This mode is more suited for simulations on a large population. - else: - formatted_value = self.display(value) + If ``max_depth`` is ``None`` (default), print the entire computation. - return f"{indent}{node.name}<{node.period}> >> {formatted_value}" + If ``max_depth`` is set, for example to ``3``, only print computed + vectors up to a depth of ``max_depth``. + """ + for _ in self.lines(aggregate, max_depth): + pass - node_log = [print_line(depth, node)] + def _get_node_log( + self, + node: t.TraceNode, + depth: int, + aggregate: bool, + max_depth: int = sys.maxsize, + ) -> list[str]: + if depth > max_depth: + return [] + + node_log = [self._print_line(depth, node, aggregate)] children_logs = [ - self._get_node_log(child, depth + 1, aggregate) - for child - in node.children - ] + self._get_node_log(child, depth + 1, aggregate, max_depth) + for child in node.children + ] return node_log + self._flatten(children_logs) - def _flatten( - self, - list_of_lists: List[List[str]], - ) -> List[str]: - return [item for _list in list_of_lists for item in _list] - - def lines(self, aggregate: bool = False) -> List[str]: - depth = 1 + def _print_line(self, depth: int, node: t.TraceNode, aggregate: bool) -> str: + indent = " " * depth + value = node.value + + if value is None: + formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}" + + elif aggregate: + try: + formatted_value = str( # pyright: ignore[reportCallIssue] + { + "avg": numpy.mean( + value + ), # pyright: ignore[reportArgumentType,reportCallIssue] + "max": numpy.max(value), + "min": numpy.min(value), + }, + ) + + except TypeError: + formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}" - lines_by_tree = [ - self._get_node_log(node, depth, aggregate) - for node - in self._full_tracer.trees - ] + else: + formatted_value = self.display(value) - return self._flatten(lines_by_tree) + return f"{indent}{node.name}<{node.period}> >> {formatted_value}" - def print_log(self, aggregate = False) -> None: - """ - Print the computation log of a simulation. + @staticmethod + def display(value: t.VarArray, max_depth: int = sys.maxsize) -> str: + if isinstance(value, EnumArray): + value = value.decode_to_str() + return numpy.array2string(value, max_line_width=max_depth) - If ``aggregate`` is ``False`` (default), print the value of each - computed vector. + @staticmethod + def _flatten(lists: list[list[str]]) -> list[str]: + return [item for list_ in lists for item in list_] - If ``aggregate`` is ``True``, only print the minimum, maximum, and - average value of each computed vector. - This mode is more suited for simulations on a large population. - """ - for line in self.lines(aggregate): - print(line) # noqa T001 +__all__ = ["ComputationLog"] diff --git a/openfisca_core/tracers/flat_trace.py b/openfisca_core/tracers/flat_trace.py index d51dd2576b..412ac8b027 100644 --- a/openfisca_core/tracers/flat_trace.py +++ b/openfisca_core/tracers/flat_trace.py @@ -1,94 +1,88 @@ from __future__ import annotations -import typing -from typing import Dict, Optional, Union - import numpy -from openfisca_core import tracers from openfisca_core.indexed_enums import EnumArray -if typing.TYPE_CHECKING: - from numpy.typing import ArrayLike - - Array = Union[EnumArray, ArrayLike] - Trace = Dict[str, dict] +from . import types as t class FlatTrace: + _full_tracer: t.FullTracer - _full_tracer: tracers.FullTracer - - def __init__(self, full_tracer: tracers.FullTracer) -> None: + def __init__(self, full_tracer: t.FullTracer) -> None: self._full_tracer = full_tracer - def key(self, node: tracers.TraceNode) -> str: - name = node.name - period = node.period - return f"{name}<{period}>" - - def get_trace(self) -> dict: - trace = {} + def get_trace(self) -> t.FlatNodeMap: + trace: t.FlatNodeMap = {} for node in self._full_tracer.browse_trace(): # We don't want cache read to overwrite data about the initial # calculation. # # We therefore use a non-overwriting update. - trace.update({ - key: node_trace - for key, node_trace in self._get_flat_trace(node).items() - if key not in trace - }) + trace.update( + { + key: node_trace + for key, node_trace in self._get_flat_trace(node).items() + if key not in trace + }, + ) return trace - def get_serialized_trace(self) -> dict: + def get_serialized_trace(self) -> t.SerializedNodeMap: return { - key: { - **flat_trace, - 'value': self.serialize(flat_trace['value']) - } + key: {**flat_trace, "value": self.serialize(flat_trace["value"])} for key, flat_trace in self.get_trace().items() - } + } + + def _get_flat_trace( + self, + node: t.TraceNode, + ) -> t.FlatNodeMap: + key = self.key(node) + return { + key: { + "dependencies": [self.key(child) for child in node.children], + "parameters": { + self.key(parameter): self.serialize(parameter.value) + for parameter in node.parameters + }, + "value": node.value, + "calculation_time": node.calculation_time(), + "formula_time": node.formula_time(), + }, + } + + @staticmethod + def key(node: t.TraceNode) -> t.NodeKey: + """Return the key of a node.""" + name = node.name + period = node.period + return t.NodeKey(f"{name}<{period}>") + + @staticmethod def serialize( - self, - value: Optional[Array], - ) -> Union[Optional[Array], list]: + value: None | t.VarArray | t.ArrayLike[object], + ) -> None | t.ArrayLike[object]: + if value is None: + return None + if isinstance(value, EnumArray): - value = value.decode_to_str() + return value.decode_to_str().tolist() - if isinstance(value, numpy.ndarray) and \ - numpy.issubdtype(value.dtype, numpy.dtype(bytes)): - value = value.astype(numpy.dtype(str)) + if isinstance(value, numpy.ndarray) and numpy.issubdtype( + value.dtype, + numpy.dtype(bytes), + ): + return value.astype(numpy.dtype(str)).tolist() if isinstance(value, numpy.ndarray): - value = value.tolist() + return value.tolist() return value - def _get_flat_trace( - self, - node: tracers.TraceNode, - ) -> Trace: - key = self.key(node) - - node_trace = { - key: { - 'dependencies': [ - self.key(child) for child in node.children - ], - 'parameters': { - self.key(parameter): - self.serialize(parameter.value) - for parameter - in node.parameters - }, - 'value': node.value, - 'calculation_time': node.calculation_time(), - 'formula_time': node.formula_time(), - }, - } - return node_trace +__all__ = ["FlatTrace"] diff --git a/openfisca_core/tracers/full_tracer.py b/openfisca_core/tracers/full_tracer.py index 3fa46de5ab..f6f793e190 100644 --- a/openfisca_core/tracers/full_tracer.py +++ b/openfisca_core/tracers/full_tracer.py @@ -1,50 +1,123 @@ from __future__ import annotations -import time -import typing -from typing import Dict, Iterator, List, Optional, Union - -from .. import tracers +from collections.abc import Iterator -if typing.TYPE_CHECKING: - from numpy.typing import ArrayLike - - from openfisca_core.periods import Period +import sys +import time - Stack = List[Dict[str, Union[str, Period]]] +from . import types as t +from .computation_log import ComputationLog +from .flat_trace import FlatTrace +from .performance_log import PerformanceLog +from .simple_tracer import SimpleTracer +from .trace_node import TraceNode class FullTracer: - - _simple_tracer: tracers.SimpleTracer - _trees: list - _current_node: Optional[tracers.TraceNode] + _simple_tracer: t.SimpleTracer + _trees: list[t.TraceNode] + _current_node: None | t.TraceNode def __init__(self) -> None: - self._simple_tracer = tracers.SimpleTracer() + self._simple_tracer = SimpleTracer() self._trees = [] self._current_node = None + @property + def stack(self) -> t.SimpleStack: + """Return the stack of traces.""" + return self._simple_tracer.stack + + @property + def trees(self) -> list[t.TraceNode]: + """Return the tree of traces.""" + return self._trees + + @property + def computation_log(self) -> t.ComputationLog: + """Return the computation log.""" + return ComputationLog(self) + + @property + def performance_log(self) -> t.PerformanceLog: + """Return the performance log.""" + return PerformanceLog(self) + + @property + def flat_trace(self) -> t.FlatTrace: + """Return the flat trace.""" + return FlatTrace(self) + def record_calculation_start( - self, - variable: str, - period: Period, - ) -> None: + self, + variable: t.VariableName, + period: t.PeriodInt | t.Period, + ) -> None: self._simple_tracer.record_calculation_start(variable, period) self._enter_calculation(variable, period) self._record_start_time() - def _enter_calculation( - self, - variable: str, - period: Period, - ) -> None: - new_node = tracers.TraceNode( - name = variable, - period = period, - parent = self._current_node, + def record_parameter_access( + self, + parameter: str, + period: t.Period, + value: t.VarArray, + ) -> None: + if self._current_node is not None: + self._current_node.parameters.append( + TraceNode(name=parameter, period=period, value=value), ) + def record_calculation_result(self, value: t.VarArray) -> None: + if self._current_node is not None: + self._current_node.value = value + + def record_calculation_end(self) -> None: + self._simple_tracer.record_calculation_end() + self._record_end_time() + self._exit_calculation() + + def print_computation_log( + self, aggregate: bool = False, max_depth: int = sys.maxsize + ) -> None: + self.computation_log.print_log(aggregate, max_depth) + + def generate_performance_graph(self, dir_path: str) -> None: + self.performance_log.generate_graph(dir_path) + + def generate_performance_tables(self, dir_path: str) -> None: + self.performance_log.generate_performance_tables(dir_path) + + def get_nb_requests(self, variable: str) -> int: + return sum(self._get_nb_requests(tree, variable) for tree in self.trees) + + def get_flat_trace(self) -> t.FlatNodeMap: + return self.flat_trace.get_trace() + + def get_serialized_flat_trace(self) -> t.SerializedNodeMap: + return self.flat_trace.get_serialized_trace() + + def browse_trace(self) -> Iterator[t.TraceNode]: + def _browse_node(node: t.TraceNode) -> Iterator[t.TraceNode]: + yield node + + for child in node.children: + yield from _browse_node(child) + + for node in self._trees: + yield from _browse_node(node) + + def _enter_calculation( + self, + variable: t.VariableName, + period: t.PeriodInt | t.Period, + ) -> None: + new_node = TraceNode( + name=variable, + period=period, + parent=self._current_node, + ) + if self._current_node is None: self._trees.append(new_node) @@ -53,41 +126,20 @@ def _enter_calculation( self._current_node = new_node - def record_parameter_access( - self, - parameter: str, - period: Period, - value: ArrayLike, - ) -> None: - - if self._current_node is not None: - self._current_node.parameters.append( - tracers.TraceNode(name = parameter, period = period, value = value), - ) - def _record_start_time( - self, - time_in_s: Optional[float] = None, - ) -> None: + self, + time_in_s: float | None = None, + ) -> None: if time_in_s is None: time_in_s = self._get_time_in_sec() if self._current_node is not None: self._current_node.start = time_in_s - def record_calculation_result(self, value: ArrayLike) -> None: - if self._current_node is not None: - self._current_node.value = value - - def record_calculation_end(self) -> None: - self._simple_tracer.record_calculation_end() - self._record_end_time() - self._exit_calculation() - def _record_end_time( - self, - time_in_s: Optional[float] = None, - ) -> None: + self, + time_in_s: None | t.Time = None, + ) -> None: if time_in_s is None: time_in_s = self._get_time_in_sec() @@ -98,68 +150,17 @@ def _exit_calculation(self) -> None: if self._current_node is not None: self._current_node = self._current_node.parent - @property - def stack(self) -> Stack: - return self._simple_tracer.stack - - @property - def trees(self) -> List[tracers.TraceNode]: - return self._trees - - @property - def computation_log(self) -> tracers.ComputationLog: - return tracers.ComputationLog(self) - - @property - def performance_log(self) -> tracers.PerformanceLog: - return tracers.PerformanceLog(self) - - @property - def flat_trace(self) -> tracers.FlatTrace: - return tracers.FlatTrace(self) - - def _get_time_in_sec(self) -> float: - return time.time_ns() / (10**9) - - def print_computation_log(self, aggregate = False): - self.computation_log.print_log(aggregate) - - def generate_performance_graph(self, dir_path: str) -> None: - self.performance_log.generate_graph(dir_path) - - def generate_performance_tables(self, dir_path: str) -> None: - self.performance_log.generate_performance_tables(dir_path) - - def _get_nb_requests(self, tree: tracers.TraceNode, variable: str) -> int: + def _get_nb_requests(self, tree: t.TraceNode, variable: str) -> int: tree_call = tree.name == variable children_calls = sum( - self._get_nb_requests(child, variable) - for child - in tree.children - ) + self._get_nb_requests(child, variable) for child in tree.children + ) return tree_call + children_calls - def get_nb_requests(self, variable: str) -> int: - return sum( - self._get_nb_requests(tree, variable) - for tree - in self.trees - ) - - def get_flat_trace(self) -> dict: - return self.flat_trace.get_trace() - - def get_serialized_flat_trace(self) -> dict: - return self.flat_trace.get_serialized_trace() - - def browse_trace(self) -> Iterator[tracers.TraceNode]: - - def _browse_node(node): - yield node + @staticmethod + def _get_time_in_sec() -> t.Time: + return time.time_ns() / (10**9) - for child in node.children: - yield from _browse_node(child) - for node in self._trees: - yield from _browse_node(node) +__all__ = ["FullTracer"] diff --git a/openfisca_core/tracers/performance_log.py b/openfisca_core/tracers/performance_log.py index 754d7f8056..f69a3dd3a2 100644 --- a/openfisca_core/tracers/performance_log.py +++ b/openfisca_core/tracers/performance_log.py @@ -1,36 +1,36 @@ from __future__ import annotations +import typing + import csv import importlib.resources import itertools import json import os -import typing -from .. import tracers +from openfisca_core import tracers if typing.TYPE_CHECKING: - Trace = typing.Dict[str, dict] - Calculation = typing.Tuple[str, dict] - SortedTrace = typing.List[Calculation] + Trace = dict[str, dict] + Calculation = tuple[str, dict] + SortedTrace = list[Calculation] class PerformanceLog: - def __init__(self, full_tracer: tracers.FullTracer) -> None: self._full_tracer = full_tracer def generate_graph(self, dir_path: str) -> None: - with open(os.path.join(dir_path, 'performance_graph.html'), 'w') as f: + with open(os.path.join(dir_path, "performance_graph.html"), "w") as f: template = importlib.resources.read_text( - 'openfisca_core.scripts.assets', - 'index.html', - ) + "openfisca_core.scripts.assets", + "index.html", + ) perf_graph_html = template.replace( - '{{data}}', + "{{data}}", json.dumps(self._json()), - ) + ) f.write(perf_graph_html) @@ -39,94 +39,95 @@ def generate_performance_tables(self, dir_path: str) -> None: csv_rows = [ { - 'name': key, - 'calculation_time': trace['calculation_time'], - 'formula_time': trace['formula_time'], - } - for key, trace - in flat_trace.items() - ] + "name": key, + "calculation_time": trace["calculation_time"], + "formula_time": trace["formula_time"], + } + for key, trace in flat_trace.items() + ] self._write_csv( - os.path.join(dir_path, 'performance_table.csv'), + os.path.join(dir_path, "performance_table.csv"), csv_rows, - ) + ) aggregated_csv_rows = [ - {'name': key, **aggregated_time} - for key, aggregated_time - in self.aggregate_calculation_times(flat_trace).items() - ] + {"name": key, **aggregated_time} + for key, aggregated_time in self.aggregate_calculation_times( + flat_trace, + ).items() + ] self._write_csv( - os.path.join(dir_path, 'aggregated_performance_table.csv'), + os.path.join(dir_path, "aggregated_performance_table.csv"), aggregated_csv_rows, - ) + ) def aggregate_calculation_times( - self, - flat_trace: Trace, - ) -> typing.Dict[str, dict]: - + self, + flat_trace: Trace, + ) -> dict[str, dict]: def _aggregate_calculations(calculations: list) -> dict: calculation_count = len(calculations) calculation_time = sum( - calculation[1]['calculation_time'] - for calculation - in calculations - ) + calculation[1]["calculation_time"] for calculation in calculations + ) formula_time = sum( - calculation[1]['formula_time'] - for calculation - in calculations - ) + calculation[1]["formula_time"] for calculation in calculations + ) return { - 'calculation_count': calculation_count, - 'calculation_time': tracers.TraceNode.round(calculation_time), - 'formula_time': tracers.TraceNode.round(formula_time), - 'avg_calculation_time': tracers.TraceNode.round(calculation_time / calculation_count), - 'avg_formula_time': tracers.TraceNode.round(formula_time / calculation_count), - } + "calculation_count": calculation_count, + "calculation_time": tracers.TraceNode.round(calculation_time), + "formula_time": tracers.TraceNode.round(formula_time), + "avg_calculation_time": tracers.TraceNode.round( + calculation_time / calculation_count, + ), + "avg_formula_time": tracers.TraceNode.round( + formula_time / calculation_count, + ), + } def _groupby(calculation: Calculation) -> str: - return calculation[0].split('<')[0] + return calculation[0].split("<")[0] all_calculations: SortedTrace = sorted(flat_trace.items()) return { variable_name: _aggregate_calculations(list(calculations)) - for variable_name, calculations - in itertools.groupby(all_calculations, _groupby) - } + for variable_name, calculations in itertools.groupby( + all_calculations, + _groupby, + ) + } def _json(self) -> dict: children = [self._json_tree(tree) for tree in self._full_tracer.trees] - calculations_total_time = sum(child['value'] for child in children) + calculations_total_time = sum(child["value"] for child in children) return { - 'name': 'All calculations', - 'value': calculations_total_time, - 'children': children, - } + "name": "All calculations", + "value": calculations_total_time, + "children": children, + } def _json_tree(self, tree: tracers.TraceNode) -> dict: calculation_total_time = tree.calculation_time() children = [self._json_tree(child) for child in tree.children] return { - 'name': f"{tree.name}<{tree.period}>", - 'value': calculation_total_time, - 'children': children, - } + "name": f"{tree.name}<{tree.period}>", + "value": calculation_total_time, + "children": children, + } - def _write_csv(self, path: str, rows: typing.List[dict]) -> None: + def _write_csv(self, path: str, rows: list[dict]) -> None: fieldnames = list(rows[0].keys()) - with open(path, 'w') as csv_file: - writer = csv.DictWriter(csv_file, fieldnames = fieldnames) + with open(path, "w") as csv_file: + writer = csv.DictWriter(csv_file, fieldnames=fieldnames) writer.writeheader() for row in rows: diff --git a/openfisca_core/tracers/simple_tracer.py b/openfisca_core/tracers/simple_tracer.py index 2fa98c6582..174dd31196 100644 --- a/openfisca_core/tracers/simple_tracer.py +++ b/openfisca_core/tracers/simple_tracer.py @@ -1,35 +1,64 @@ from __future__ import annotations -import typing -from typing import Dict, List, Union - -if typing.TYPE_CHECKING: - from numpy.typing import ArrayLike - - from openfisca_core.periods import Period - - Stack = List[Dict[str, Union[str, Period]]] +from . import types as t class SimpleTracer: + """A simple tracer that records a stack of traces.""" - _stack: Stack + #: The stack of traces. + _stack: t.SimpleStack def __init__(self) -> None: self._stack = [] - def record_calculation_start(self, variable: str, period: Period) -> None: - self.stack.append({'name': variable, 'period': period}) + @property + def stack(self) -> t.SimpleStack: + """Return the stack of traces.""" + return self._stack + + def record_calculation_start( + self, variable: t.VariableName, period: t.PeriodInt | t.Period + ) -> None: + """Record the start of a calculation. + + Args: + variable: The variable being calculated. + period: The period for which the variable is being calculated. + + Examples: + >>> from openfisca_core import tracers - def record_calculation_result(self, value: ArrayLike) -> None: - pass # ignore calculation result + >>> tracer = tracers.SimpleTracer() + >>> tracer.record_calculation_start("variable", 2020) + >>> tracer.stack + [{'name': 'variable', 'period': 2020}] - def record_parameter_access(self, parameter: str, period, value): - pass + """ + self.stack.append({"name": variable, "period": period}) + + def record_calculation_result(self, value: t.ArrayLike[object]) -> None: + """Ignore calculation result.""" + + def record_parameter_access( + self, parameter: str, period: t.Period, value: t.ArrayLike[object] + ) -> None: + """Ignore parameter access.""" def record_calculation_end(self) -> None: + """Record the end of a calculation. + + Examples: + >>> from openfisca_core import tracers + + >>> tracer = tracers.SimpleTracer() + >>> tracer.record_calculation_start("variable", 2020) + >>> tracer.record_calculation_end() + >>> tracer.stack + [] + + """ self.stack.pop() - @property - def stack(self) -> Stack: - return self._stack + +__all__ = ["SimpleTracer"] diff --git a/openfisca_core/tracers/trace_node.py b/openfisca_core/tracers/trace_node.py index 93b630886c..de81825e81 100644 --- a/openfisca_core/tracers/trace_node.py +++ b/openfisca_core/tracers/trace_node.py @@ -1,30 +1,61 @@ from __future__ import annotations import dataclasses -import typing -if typing.TYPE_CHECKING: - import numpy - - from openfisca_core.indexed_enums import EnumArray - from openfisca_core.periods import Period - - Array = typing.Union[EnumArray, numpy.typing.ArrayLike] - Time = typing.Union[float, int] +from . import types as t @dataclasses.dataclass class TraceNode: + """A node in the tracing tree.""" + + #: The name of the node. name: str - period: Period - parent: typing.Optional[TraceNode] = None - children: typing.List[TraceNode] = dataclasses.field(default_factory = list) - parameters: typing.List[TraceNode] = dataclasses.field(default_factory = list) - value: typing.Optional[Array] = None - start: float = 0 - end: float = 0 - - def calculation_time(self, round_: bool = True) -> Time: + + #: The period of the node. + period: t.PeriodInt | t.Period + + #: The parent of the node. + parent: None | t.TraceNode = None + + #: The children of the node. + children: list[t.TraceNode] = dataclasses.field(default_factory=list) + + #: The parameters of the node. + parameters: list[t.TraceNode] = dataclasses.field(default_factory=list) + + #: The value of the node. + value: None | t.VarArray = None + + #: The start time of the node. + start: t.Time = 0.0 + + #: The end time of the node. + end: t.Time = 0.0 + + def calculation_time(self, round_: bool = True) -> t.Time: + """Calculate the time spent in the node. + + Args: + round_: Whether to round the result. + + Returns: + float: The time spent in the node. + + Examples: + >>> from openfisca_core import tracers + + >>> node = tracers.TraceNode("variable", 2020) + >>> node.start = 1.123122313 + >>> node.end = 1.12312313123 + + >>> node.calculation_time() + 8.182e-07 + + >>> node.calculation_time(round_=False) + 8.182299999770493e-07 + + """ result = self.end - self.start if round_: @@ -32,23 +63,59 @@ def calculation_time(self, round_: bool = True) -> Time: return result - def formula_time(self) -> float: + def formula_time(self) -> t.Time: + """Calculate the time spent on the formula. + + Returns: + float: The time spent on the formula. + + Examples: + >>> from openfisca_core import tracers + + >>> node = tracers.TraceNode("variable", 2020) + >>> node.start = 1.123122313 * 11 + >>> node.end = 1.12312313123 * 11 + >>> child = tracers.TraceNode("variable", 2020) + >>> child.start = 1.123122313 + >>> child.end = 1.12312313123 + + >>> for i in range(10): + ... node.children = [child, *node.children] + + >>> node.formula_time() + 8.182e-07 + + """ children_calculation_time = sum( - child.calculation_time(round_ = False) - for child - in self.children - ) + child.calculation_time(round_=False) for child in self.children + ) - result = ( - + self.calculation_time(round_ = False) - - children_calculation_time - ) + result = +self.calculation_time(round_=False) - children_calculation_time return self.round(result) - def append_child(self, node: TraceNode) -> None: + def append_child(self, node: t.TraceNode) -> None: + """Append a child to the node.""" self.children.append(node) @staticmethod - def round(time: Time) -> float: - return float(f'{time:.4g}') # Keep only 4 significant figures + def round(time: t.Time) -> t.Time: + """Keep only 4 significant figures. + + Args: + time: The time to round. + + Returns: + float: The rounded time. + + Examples: + >>> from openfisca_core import tracers + + >>> tracers.TraceNode.round(0.000123456789) + 0.0001235 + + """ + return float(f"{time:.4g}") + + +__all__ = ["TraceNode"] diff --git a/openfisca_core/tracers/tracing_parameter_node_at_instant.py b/openfisca_core/tracers/tracing_parameter_node_at_instant.py index 89d9b8fb01..074c24221d 100644 --- a/openfisca_core/tracers/tracing_parameter_node_at_instant.py +++ b/openfisca_core/tracers/tracing_parameter_node_at_instant.py @@ -7,70 +7,77 @@ from openfisca_core import parameters -from .. import tracers - ParameterNode = Union[ parameters.VectorialParameterNodeAtInstant, parameters.ParameterNodeAtInstant, - ] +] if typing.TYPE_CHECKING: from numpy.typing import ArrayLike + from openfisca_core import tracers + Child = Union[ParameterNode, ArrayLike] class TracingParameterNodeAtInstant: - def __init__( - self, - parameter_node_at_instant: ParameterNode, - tracer: tracers.FullTracer, - ) -> None: + self, + parameter_node_at_instant: ParameterNode, + tracer: tracers.FullTracer, + ) -> None: self.parameter_node_at_instant = parameter_node_at_instant self.tracer = tracer def __getattr__( - self, - key: str, - ) -> Union[TracingParameterNodeAtInstant, Child]: + self, + key: str, + ) -> TracingParameterNodeAtInstant | Child: child = getattr(self.parameter_node_at_instant, key) return self.get_traced_child(child, key) + def __contains__(self, key) -> bool: + return key in self.parameter_node_at_instant + + def __iter__(self): + return iter(self.parameter_node_at_instant) + def __getitem__( - self, - key: Union[str, ArrayLike], - ) -> Union[TracingParameterNodeAtInstant, Child]: + self, + key: str | ArrayLike, + ) -> TracingParameterNodeAtInstant | Child: child = self.parameter_node_at_instant[key] return self.get_traced_child(child, key) def get_traced_child( - self, - child: Child, - key: Union[str, ArrayLike], - ) -> Union[TracingParameterNodeAtInstant, Child]: + self, + child: Child, + key: str | ArrayLike, + ) -> TracingParameterNodeAtInstant | Child: period = self.parameter_node_at_instant._instant_str if isinstance( - child, - (parameters.ParameterNodeAtInstant, parameters.VectorialParameterNodeAtInstant), - ): + child, + ( + parameters.ParameterNodeAtInstant, + parameters.VectorialParameterNodeAtInstant, + ), + ): return TracingParameterNodeAtInstant(child, self.tracer) - if not isinstance(key, str) or \ - isinstance( - self.parameter_node_at_instant, - parameters.VectorialParameterNodeAtInstant, - ): + if not isinstance(key, str) or isinstance( + self.parameter_node_at_instant, + parameters.VectorialParameterNodeAtInstant, + ): # In case of vectorization, we keep the parent node name as, for # instance, rate[status].zone1 is best described as the value of # "rate". name = self.parameter_node_at_instant._name else: - name = '.'.join([self.parameter_node_at_instant._name, key]) + name = f"{self.parameter_node_at_instant._name}.{key}" - if isinstance(child, (numpy.ndarray,) + parameters.ALLOWED_PARAM_TYPES): + if isinstance(child, (numpy.ndarray, *parameters.ALLOWED_PARAM_TYPES)): self.tracer.record_parameter_access(name, period, child) return child diff --git a/openfisca_core/tracers/types.py b/openfisca_core/tracers/types.py new file mode 100644 index 0000000000..f26c854241 --- /dev/null +++ b/openfisca_core/tracers/types.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import NewType, Protocol +from typing_extensions import TypeAlias, TypedDict + +from openfisca_core.types import ( + Array, + ArrayLike, + ParameterNode, + ParameterNodeChild, + Period, + PeriodInt, + VariableName, +) + +from numpy import generic as VarDType + +#: A type of a generic array. +VarArray: TypeAlias = Array[VarDType] + +#: A type representing a unit time. +Time: TypeAlias = float + +#: A type representing a mapping of flat traces. +FlatNodeMap: TypeAlias = dict["NodeKey", "FlatTraceMap"] + +#: A type representing a mapping of serialized traces. +SerializedNodeMap: TypeAlias = dict["NodeKey", "SerializedTraceMap"] + +#: A stack of simple traces. +SimpleStack: TypeAlias = list["SimpleTraceMap"] + +#: Key of a trace. +NodeKey = NewType("NodeKey", str) + + +class FlatTraceMap(TypedDict, total=True): + dependencies: list[NodeKey] + parameters: dict[NodeKey, None | ArrayLike[object]] + value: None | VarArray + calculation_time: Time + formula_time: Time + + +class SerializedTraceMap(TypedDict, total=True): + dependencies: list[NodeKey] + parameters: dict[NodeKey, None | ArrayLike[object]] + value: None | ArrayLike[object] + calculation_time: Time + formula_time: Time + + +class SimpleTraceMap(TypedDict, total=True): + name: VariableName + period: int | Period + + +class ComputationLog(Protocol): + def print_log(self, aggregate: bool = ..., max_depth: int = ..., /) -> None: ... + + +class FlatTrace(Protocol): + def get_trace(self, /) -> FlatNodeMap: ... + def get_serialized_trace(self, /) -> SerializedNodeMap: ... + + +class FullTracer(Protocol): + @property + def trees(self, /) -> list[TraceNode]: ... + def browse_trace(self, /) -> Iterator[TraceNode]: ... + + +class PerformanceLog(Protocol): + def generate_graph(self, dir_path: str, /) -> None: ... + def generate_performance_tables(self, dir_path: str, /) -> None: ... + + +class SimpleTracer(Protocol): + @property + def stack(self, /) -> SimpleStack: ... + def record_calculation_start( + self, variable: VariableName, period: PeriodInt | Period, / + ) -> None: ... + def record_calculation_end(self, /) -> None: ... + + +class TraceNode(Protocol): + children: list[TraceNode] + end: Time + name: str + parameters: list[TraceNode] + parent: None | TraceNode + period: PeriodInt | Period + start: Time + value: None | VarArray + + def calculation_time(self, *, round_: bool = ...) -> Time: ... + def formula_time(self, /) -> Time: ... + def append_child(self, node: TraceNode, /) -> None: ... + + +__all__ = [ + "ArrayLike", + "ParameterNode", + "ParameterNodeChild", + "PeriodInt", +] diff --git a/openfisca_core/types.py b/openfisca_core/types.py new file mode 100644 index 0000000000..9c81057416 --- /dev/null +++ b/openfisca_core/types.py @@ -0,0 +1,299 @@ +from __future__ import annotations + +from collections.abc import Iterable, Sequence, Sized +from numpy.typing import DTypeLike, NDArray +from typing import NewType, TypeVar, Union +from typing_extensions import Protocol, Required, Self, TypeAlias, TypedDict + +import abc +import enum + +import numpy +import pendulum + +#: Generic covariant type var. +_T_co = TypeVar("_T_co", covariant=True) + +# Commons + +#: Type var for numpy arrays. +_N_co = TypeVar("_N_co", covariant=True, bound="DTypeGeneric") + +#: Type representing an numpy array. +Array: TypeAlias = NDArray[_N_co] + +#: Type var for array-like objects. +_L = TypeVar("_L") + +#: Type representing an array-like object. +ArrayLike: TypeAlias = Sequence[_L] + +#: Type for bool arrays. +DTypeBool: TypeAlias = numpy.bool_ + +#: Type for int arrays. +DTypeInt: TypeAlias = numpy.int32 + +#: Type for float arrays. +DTypeFloat: TypeAlias = numpy.float32 + +#: Type for string arrays. +DTypeStr: TypeAlias = numpy.str_ + +#: Type for bytes arrays. +DTypeBytes: TypeAlias = numpy.bytes_ + +#: Type for Enum arrays. +DTypeEnum: TypeAlias = numpy.uint8 + +#: Type for date arrays. +DTypeDate: TypeAlias = numpy.datetime64 + +#: Type for "object" arrays. +DTypeObject: TypeAlias = numpy.object_ + +#: Type for "generic" arrays. +DTypeGeneric: TypeAlias = numpy.generic + +# Entities + +#: For example "person". +EntityKey = NewType("EntityKey", str) + +#: For example "persons". +EntityPlural = NewType("EntityPlural", str) + +#: For example "principal". +RoleKey = NewType("RoleKey", str) + +#: For example "parents". +RolePlural = NewType("RolePlural", str) + + +class CoreEntity(Protocol): + key: EntityKey + plural: EntityPlural + + def check_role_validity(self, role: object, /) -> None: ... + def check_variable_defined_for_entity( + self, + variable_name: VariableName, + /, + ) -> None: ... + def get_variable( + self, + variable_name: VariableName, + check_existence: bool = ..., + /, + ) -> None | Variable: ... + + +class SingleEntity(CoreEntity, Protocol): ... + + +class GroupEntity(CoreEntity, Protocol): ... + + +class Role(Protocol): + entity: GroupEntity + max: int | None + subroles: None | Iterable[Role] + + @property + def key(self, /) -> RoleKey: ... + @property + def plural(self, /) -> None | RolePlural: ... + + +# Indexed enums + + +class EnumType(enum.EnumMeta): + indices: Array[DTypeEnum] + names: Array[DTypeStr] + enums: Array[DTypeObject] + + +class Enum(enum.Enum, metaclass=EnumType): + index: int + _member_names_: list[str] + + +class EnumArray(Array[DTypeEnum], metaclass=abc.ABCMeta): + possible_values: None | type[Enum] + + @abc.abstractmethod + def __new__( + cls, input_array: Array[DTypeEnum], possible_values: type[Enum] + ) -> Self: ... + + +# Holders + + +class Holder(Protocol): + def clone(self, population: CorePopulation, /) -> Holder: ... + def get_memory_usage(self, /) -> MemoryUsage: ... + + +class MemoryUsage(TypedDict, total=False): + cell_size: int + dtype: DTypeLike + nb_arrays: int + nb_cells_by_array: int + nb_requests: int + nb_requests_by_array: int + total_nb_bytes: Required[int] + + +# Parameters + +#: A type representing a node of parameters. +ParameterNode: TypeAlias = Union[ + "ParameterNodeAtInstant", "VectorialParameterNodeAtInstant" +] + +#: A type representing a ??? +ParameterNodeChild: TypeAlias = Union[ParameterNode, ArrayLike[object]] + + +class ParameterNodeAtInstant(Protocol): + _instant_str: InstantStr + + def __contains__(self, __item: object, /) -> bool: ... + def __getitem__( + self, __index: str | Array[DTypeGeneric], / + ) -> ParameterNodeChild: ... + + +class VectorialParameterNodeAtInstant(Protocol): + _instant_str: InstantStr + + def __contains__(self, item: object, /) -> bool: ... + def __getitem__( + self, __index: str | Array[DTypeGeneric], / + ) -> ParameterNodeChild: ... + + +# Periods + +#: For example "2000-01". +InstantStr = NewType("InstantStr", str) + +#: For example 2020. +PeriodInt = NewType("PeriodInt", int) + +#: For example "1:2000-01-01:day". +PeriodStr = NewType("PeriodStr", str) + + +class Container(Protocol[_T_co]): + def __contains__(self, item: object, /) -> bool: ... + + +class Indexable(Protocol[_T_co]): + def __getitem__(self, index: int, /) -> _T_co: ... + + +class DateUnit(Container[str], Protocol): + def upper(self, /) -> str: ... + + +class Instant(Indexable[int], Iterable[int], Sized, Protocol): + @property + def year(self, /) -> int: ... + @property + def month(self, /) -> int: ... + @property + def day(self, /) -> int: ... + @property + def date(self, /) -> pendulum.Date: ... + def __lt__(self, other: object, /) -> bool: ... + def __le__(self, other: object, /) -> bool: ... + def offset(self, offset: str | int, unit: DateUnit, /) -> None | Instant: ... + + +class Period(Indexable[Union[DateUnit, Instant, int]], Protocol): + @property + def unit(self, /) -> DateUnit: ... + @property + def start(self, /) -> Instant: ... + @property + def size(self, /) -> int: ... + @property + def stop(self, /) -> Instant: ... + def contains(self, other: Period, /) -> bool: ... + def offset(self, offset: str | int, unit: None | DateUnit = None, /) -> Period: ... + + +# Populations + + +class CorePopulation(Protocol): ... + + +class SinglePopulation(CorePopulation, Protocol): + entity: SingleEntity + + def get_holder(self, variable_name: VariableName, /) -> Holder: ... + + +class GroupPopulation(CorePopulation, Protocol): ... + + +# Simulations + + +class Simulation(Protocol): + def calculate( + self, variable_name: VariableName, period: Period, / + ) -> Array[DTypeGeneric]: ... + def calculate_add( + self, variable_name: VariableName, period: Period, / + ) -> Array[DTypeGeneric]: ... + def calculate_divide( + self, variable_name: VariableName, period: Period, / + ) -> Array[DTypeGeneric]: ... + def get_population(self, plural: None | str, /) -> CorePopulation: ... + + +# Tax-Benefit systems + + +class TaxBenefitSystem(Protocol): + person_entity: SingleEntity + + def get_variable( + self, + variable_name: VariableName, + check_existence: bool = ..., + /, + ) -> None | Variable: ... + + +# Variables + +#: For example "salary". +VariableName = NewType("VariableName", str) + + +class Variable(Protocol): + entity: CoreEntity + name: VariableName + + +class Formula(Protocol): + def __call__( + self, + population: CorePopulation, + instant: Instant, + params: Params, + /, + ) -> Array[DTypeGeneric]: ... + + +class Params(Protocol): + def __call__(self, instant: Instant, /) -> ParameterNodeAtInstant: ... + + +__all__ = ["DTypeLike"] diff --git a/openfisca_core/types/__init__.py b/openfisca_core/types/__init__.py deleted file mode 100644 index e14cfea65d..0000000000 --- a/openfisca_core/types/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Data types and protocols used by OpenFisca Core. - -The type definitions included in this sub-package are intented for -contributors, to help them better understand and document contracts -and expected behaviours. - -Official Public API: - * ``ArrayLike`` - * :attr:`.ArrayType` - -Note: - How imports are being used today:: - - from openfisca_core.types import * # Bad - from openfisca_core.types.data_types.arrays import ArrayLike # Bad - - - The previous examples provoke cyclic dependency problems, that prevents us - from modularizing the different components of the library, so as to make - them easier to test and to maintain. - - How could them be used after the next major release:: - - from openfisca_core.types import ArrayLike - - ArrayLike # Good: import types as publicly exposed - - .. seealso:: `PEP8#Imports`_ and `OpenFisca's Styleguide`_. - - .. _PEP8#Imports: - https://www.python.org/dev/peps/pep-0008/#imports - - .. _OpenFisca's Styleguide: - https://github.com/openfisca/openfisca-core/blob/master/STYLEGUIDE.md - -""" - -# Official Public API - -from .data_types import ( # noqa: F401 - ArrayLike, - ArrayType, - ) - -__all__ = ["ArrayLike", "ArrayType"] diff --git a/openfisca_core/types/data_types/__init__.py b/openfisca_core/types/data_types/__init__.py deleted file mode 100644 index 6dd38194e3..0000000000 --- a/openfisca_core/types/data_types/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .arrays import ArrayLike, ArrayType # noqa: F401 diff --git a/openfisca_core/types/data_types/arrays.py b/openfisca_core/types/data_types/arrays.py deleted file mode 100644 index 5cfef639c5..0000000000 --- a/openfisca_core/types/data_types/arrays.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Sequence, TypeVar, Union - -from nptyping import types, NDArray as ArrayType - -import numpy - -T = TypeVar("T", bool, bytes, float, int, object, str) - -types._ndarray_meta._Type = Union[type, numpy.dtype, TypeVar] - -ArrayLike = Union[ArrayType[T], Sequence[T]] -""":obj:`typing.Generic`: Type of any castable to :class:`numpy.ndarray`. - -These include any :obj:`numpy.ndarray` and sequences (like -:obj:`list`, :obj:`tuple`, and so on). - -Examples: - >>> ArrayLike[float] - typing.Union[numpy.ndarray, typing.Sequence[float]] - - >>> ArrayLike[str] - typing.Union[numpy.ndarray, typing.Sequence[str]] - -Note: - It is possible since numpy version 1.21 to specify the type of an - array, thanks to `numpy.typing.NDArray`_:: - - from numpy.typing import NDArray - NDArray[numpy.float64] - - `mypy`_ provides `duck type compatibility`_, so an :obj:`int` is - considered to be valid whenever a :obj:`float` is expected. - -Todo: - * Refactor once numpy version >= 1.21 is used. - -.. versionadded:: 35.5.0 - -.. versionchanged:: 35.6.0 - Moved to :mod:`.types` - -.. _mypy: - https://mypy.readthedocs.io/en/stable/ - -.. _duck type compatibility: - https://mypy.readthedocs.io/en/stable/duck_type_compatibility.html - -.. _numpy.typing.NDArray: - https://numpy.org/doc/stable/reference/typing.html#numpy.typing.NDArray - -""" diff --git a/openfisca_core/variables/__init__.py b/openfisca_core/variables/__init__.py index 3decaf8f42..1ab191c5ce 100644 --- a/openfisca_core/variables/__init__.py +++ b/openfisca_core/variables/__init__.py @@ -21,6 +21,6 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .config import VALUE_TYPES, FORMULA_NAME_PREFIX # noqa: F401 +from .config import FORMULA_NAME_PREFIX, VALUE_TYPES # noqa: F401 from .helpers import get_annualized_variable, get_neutralized_variable # noqa: F401 from .variable import Variable # noqa: F401 diff --git a/openfisca_core/variables/config.py b/openfisca_core/variables/config.py index b260bb3dd9..54270145bf 100644 --- a/openfisca_core/variables/config.py +++ b/openfisca_core/variables/config.py @@ -5,50 +5,49 @@ from openfisca_core import indexed_enums from openfisca_core.indexed_enums import Enum - VALUE_TYPES = { bool: { - 'dtype': numpy.bool_, - 'default': False, - 'json_type': 'boolean', - 'formatted_value_type': 'Boolean', - 'is_period_size_independent': True - }, + "dtype": numpy.bool_, + "default": False, + "json_type": "boolean", + "formatted_value_type": "Boolean", + "is_period_size_independent": True, + }, int: { - 'dtype': numpy.int32, - 'default': 0, - 'json_type': 'integer', - 'formatted_value_type': 'Int', - 'is_period_size_independent': False - }, + "dtype": numpy.int32, + "default": 0, + "json_type": "integer", + "formatted_value_type": "Int", + "is_period_size_independent": False, + }, float: { - 'dtype': numpy.float32, - 'default': 0, - 'json_type': 'number', - 'formatted_value_type': 'Float', - 'is_period_size_independent': False, - }, + "dtype": numpy.float32, + "default": 0, + "json_type": "number", + "formatted_value_type": "Float", + "is_period_size_independent": False, + }, str: { - 'dtype': object, - 'default': '', - 'json_type': 'string', - 'formatted_value_type': 'String', - 'is_period_size_independent': True - }, + "dtype": object, + "default": "", + "json_type": "string", + "formatted_value_type": "String", + "is_period_size_independent": True, + }, Enum: { - 'dtype': indexed_enums.ENUM_ARRAY_DTYPE, - 'json_type': 'string', - 'formatted_value_type': 'String', - 'is_period_size_independent': True, - }, + "dtype": indexed_enums.ENUM_ARRAY_DTYPE, + "json_type": "string", + "formatted_value_type": "String", + "is_period_size_independent": True, + }, datetime.date: { - 'dtype': 'datetime64[D]', - 'default': datetime.date.fromtimestamp(0), # 0 == 1970-01-01 - 'json_type': 'string', - 'formatted_value_type': 'Date', - 'is_period_size_independent': True, - }, - } + "dtype": "datetime64[D]", + "default": datetime.date.fromtimestamp(0), # 0 == 1970-01-01 + "json_type": "string", + "formatted_value_type": "Date", + "is_period_size_independent": True, + }, +} -FORMULA_NAME_PREFIX = 'formula' +FORMULA_NAME_PREFIX = "formula" diff --git a/openfisca_core/variables/helpers.py b/openfisca_core/variables/helpers.py index 335a585498..5038a78240 100644 --- a/openfisca_core/variables/helpers.py +++ b/openfisca_core/variables/helpers.py @@ -1,23 +1,24 @@ from __future__ import annotations import sortedcontainers -from typing import Optional +from openfisca_core import variables from openfisca_core.periods import Period -from .. import variables - -def get_annualized_variable(variable: variables.Variable, annualization_period: Optional[Period] = None) -> variables.Variable: - """ - Returns a clone of ``variable`` that is annualized for the period ``annualization_period``. +def get_annualized_variable( + variable: variables.Variable, + annualization_period: Period | None = None, +) -> variables.Variable: + """Returns a clone of ``variable`` that is annualized for the period ``annualization_period``. When annualized, a variable's formula is only called for a January calculation, and the results for other months are assumed to be identical. """ - def make_annual_formula(original_formula, annualization_period = None): - + def make_annual_formula(original_formula, annualization_period=None): def annual_formula(population, period, parameters): - if period.start.month != 1 and (annualization_period is None or annualization_period.contains(period)): + if period.start.month != 1 and ( + annualization_period is None or annualization_period.contains(period) + ): return population(variable.name, period.this_year.first_month) if original_formula.__code__.co_argcount == 2: return original_formula(population, period) @@ -26,22 +27,29 @@ def annual_formula(population, period, parameters): return annual_formula new_variable = variable.clone() - new_variable.formulas = sortedcontainers.sorteddict.SortedDict({ - key: make_annual_formula(formula, annualization_period) - for key, formula in variable.formulas.items() - }) + new_variable.formulas = sortedcontainers.sorteddict.SortedDict( + { + key: make_annual_formula(formula, annualization_period) + for key, formula in variable.formulas.items() + }, + ) return new_variable def get_neutralized_variable(variable): - """ - Return a new neutralized variable (to be used by reforms). + """Return a new neutralized variable (to be used by reforms). A neutralized variable always returns its default value, and does not cache anything. """ result = variable.clone() result.is_neutralized = True - result.label = '[Neutralized]' if variable.label is None else '[Neutralized] {}'.format(variable.label), + result.label = ( + ( + "[Neutralized]" + if variable.label is None + else f"[Neutralized] {variable.label}" + ), + ) return result diff --git a/openfisca_core/variables/tests/__init__.py b/openfisca_core/variables/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfisca_core/variables/tests/test_definition_period.py b/openfisca_core/variables/tests/test_definition_period.py new file mode 100644 index 0000000000..8ef9bfaa87 --- /dev/null +++ b/openfisca_core/variables/tests/test_definition_period.py @@ -0,0 +1,43 @@ +import pytest + +from openfisca_core import periods +from openfisca_core.variables import Variable + + +@pytest.fixture +def variable(persons): + class TestVariable(Variable): + value_type = float + entity = persons + + return TestVariable + + +def test_weekday_variable(variable) -> None: + variable.definition_period = periods.WEEKDAY + assert variable() + + +def test_week_variable(variable) -> None: + variable.definition_period = periods.WEEK + assert variable() + + +def test_day_variable(variable) -> None: + variable.definition_period = periods.DAY + assert variable() + + +def test_month_variable(variable) -> None: + variable.definition_period = periods.MONTH + assert variable() + + +def test_year_variable(variable) -> None: + variable.definition_period = periods.YEAR + assert variable() + + +def test_eternity_variable(variable) -> None: + variable.definition_period = periods.ETERNITY + assert variable() diff --git a/openfisca_core/variables/variable.py b/openfisca_core/variables/variable.py index acfeb9fe70..926e4c59c1 100644 --- a/openfisca_core/variables/variable.py +++ b/openfisca_core/variables/variable.py @@ -1,22 +1,24 @@ +from __future__ import annotations + +from typing import NoReturn + import datetime -import inspect import re import textwrap -import sortedcontainers import numpy +import sortedcontainers -from openfisca_core import periods, tools -from openfisca_core.entities import Entity +from openfisca_core import commons, periods, types as t +from openfisca_core.entities import Entity, GroupEntity from openfisca_core.indexed_enums import Enum, EnumArray -from openfisca_core.periods import Period +from openfisca_core.periods import DateUnit, Period from . import config, helpers class Variable: - """ - A `variable `_ of the legislation. + """A `variable `_ of the legislation. Main attributes: @@ -34,7 +36,7 @@ class Variable: .. attribute:: definition_period - `Period `_ the variable is defined for. Possible value: ``MONTH``, ``YEAR``, ``ETERNITY``. + `Period `_ the variable is defined for. Possible value: ``DateUnit.DAY``, ``DateUnit.MONTH``, ``DateUnit.YEAR``, ``DateUnit.ETERNITY``. .. attribute:: formulas @@ -95,64 +97,137 @@ class Variable: Free multilines text field describing the variable context and usage. """ - def __init__(self, baseline_variable = None): + __name__: str + + def __init__(self, baseline_variable=None) -> None: self.name = self.__class__.__name__ attr = { - name: value for name, value in self.__class__.__dict__.items() - if not name.startswith('__')} + name: value + for name, value in self.__class__.__dict__.items() + if not name.startswith("__") + } self.baseline_variable = baseline_variable - self.value_type = self.set(attr, 'value_type', required = True, allowed_values = config.VALUE_TYPES.keys()) - self.dtype = config.VALUE_TYPES[self.value_type]['dtype'] - self.json_type = config.VALUE_TYPES[self.value_type]['json_type'] + self.value_type = self.set( + attr, + "value_type", + required=True, + allowed_values=config.VALUE_TYPES.keys(), + ) + self.dtype = config.VALUE_TYPES[self.value_type]["dtype"] + self.json_type = config.VALUE_TYPES[self.value_type]["json_type"] if self.value_type == Enum: - self.possible_values = self.set(attr, 'possible_values', required = True, setter = self.set_possible_values) + self.possible_values = self.set( + attr, + "possible_values", + required=True, + setter=self.set_possible_values, + ) if self.value_type == str: - self.max_length = self.set(attr, 'max_length', allowed_type = int) + self.max_length = self.set(attr, "max_length", allowed_type=int) if self.max_length: - self.dtype = '|S{}'.format(self.max_length) + self.dtype = f"|S{self.max_length}" if self.value_type == Enum: - self.default_value = self.set(attr, 'default_value', allowed_type = self.possible_values, required = True) + self.default_value = self.set( + attr, + "default_value", + allowed_type=self.possible_values, + required=True, + ) else: - self.default_value = self.set(attr, 'default_value', allowed_type = self.value_type, default = config.VALUE_TYPES[self.value_type].get('default')) - self.entity = self.set(attr, 'entity', required = True, setter = self.set_entity) - self.definition_period = self.set(attr, 'definition_period', required = True, allowed_values = (periods.DAY, periods.MONTH, periods.YEAR, periods.ETERNITY)) - self.label = self.set(attr, 'label', allowed_type = str, setter = self.set_label) - self.end = self.set(attr, 'end', allowed_type = str, setter = self.set_end) - self.reference = self.set(attr, 'reference', setter = self.set_reference) - self.cerfa_field = self.set(attr, 'cerfa_field', allowed_type = (str, dict)) - self.unit = self.set(attr, 'unit', allowed_type = str) - self.documentation = self.set(attr, 'documentation', allowed_type = str, setter = self.set_documentation) - self.set_input = self.set_set_input(attr.pop('set_input', None)) - self.calculate_output = self.set_calculate_output(attr.pop('calculate_output', None)) - self.is_period_size_independent = self.set(attr, 'is_period_size_independent', allowed_type = bool, default = config.VALUE_TYPES[self.value_type]['is_period_size_independent']) - - formulas_attr, unexpected_attrs = helpers._partition(attr, lambda name, value: name.startswith(config.FORMULA_NAME_PREFIX)) + self.default_value = self.set( + attr, + "default_value", + allowed_type=self.value_type, + default=config.VALUE_TYPES[self.value_type].get("default"), + ) + self.entity = self.set(attr, "entity", required=True, setter=self.set_entity) + self.definition_period = self.set( + attr, + "definition_period", + required=True, + allowed_values=DateUnit, + ) + self.label = self.set(attr, "label", allowed_type=str, setter=self.set_label) + self.end = self.set(attr, "end", allowed_type=str, setter=self.set_end) + self.reference = self.set(attr, "reference", setter=self.set_reference) + self.cerfa_field = self.set(attr, "cerfa_field", allowed_type=(str, dict)) + self.unit = self.set(attr, "unit", allowed_type=str) + self.documentation = self.set( + attr, + "documentation", + allowed_type=str, + setter=self.set_documentation, + ) + self.set_input = self.set_set_input(attr.pop("set_input", None)) + self.calculate_output = self.set_calculate_output( + attr.pop("calculate_output", None), + ) + self.is_period_size_independent = self.set( + attr, + "is_period_size_independent", + allowed_type=bool, + default=config.VALUE_TYPES[self.value_type]["is_period_size_independent"], + ) + + self.introspection_data = self.set( + attr, + "introspection_data", + ) + + formulas_attr, unexpected_attrs = helpers._partition( + attr, + lambda name, value: name.startswith(config.FORMULA_NAME_PREFIX), + ) self.formulas = self.set_formulas(formulas_attr) if unexpected_attrs: + msg = 'Unexpected attributes in definition of variable "{}": {!r}'.format( + self.name, + ", ".join(sorted(unexpected_attrs.keys())), + ) raise ValueError( - 'Unexpected attributes in definition of variable "{}": {!r}' - .format(self.name, ', '.join(sorted(unexpected_attrs.keys())))) + msg, + ) self.is_neutralized = False # ----- Setters used to build the variable ----- # - def set(self, attributes, attribute_name, required = False, allowed_values = None, allowed_type = None, setter = None, default = None): + def set( + self, + attributes, + attribute_name, + required=False, + allowed_values=None, + allowed_type=None, + setter=None, + default=None, + ): value = attributes.pop(attribute_name, None) if value is None and self.baseline_variable: return getattr(self.baseline_variable, attribute_name) if required and value is None: - raise ValueError("Missing attribute '{}' in definition of variable '{}'.".format(attribute_name, self.name)) + msg = f"Missing attribute '{attribute_name}' in definition of variable '{self.name}'." + raise ValueError( + msg, + ) if allowed_values is not None and value not in allowed_values: - raise ValueError("Invalid value '{}' for attribute '{}' in variable '{}'. Allowed values are '{}'." - .format(value, attribute_name, self.name, allowed_values)) - if allowed_type is not None and value is not None and not isinstance(value, allowed_type): + msg = f"Invalid value '{value}' for attribute '{attribute_name}' in variable '{self.name}'. Allowed values are '{allowed_values}'." + raise ValueError( + msg, + ) + if ( + allowed_type is not None + and value is not None + and not isinstance(value, allowed_type) + ): if allowed_type == float and isinstance(value, int): value = float(value) else: - raise ValueError("Invalid value '{}' for attribute '{}' in variable '{}'. Must be of type '{}'." - .format(value, attribute_name, self.name, allowed_type)) + msg = f"Invalid value '{value}' for attribute '{attribute_name}' in variable '{self.name}'. Must be of type '{allowed_type}'." + raise ValueError( + msg, + ) if setter is not None: value = setter(value) if value is None and default is not None: @@ -160,26 +235,39 @@ def set(self, attributes, attribute_name, required = False, allowed_values = Non return value def set_entity(self, entity): - if not isinstance(entity, Entity): - raise ValueError(f"Invalid value '{entity}' for attribute 'entity' in variable '{self.name}'. Must be an instance of Entity.") + if not isinstance(entity, (Entity, GroupEntity)): + msg = ( + f"Invalid value '{entity}' for attribute 'entity' in variable " + f"'{self.name}'. Must be an instance of Entity or GroupEntity." + ) + raise ValueError( + msg, + ) return entity def set_possible_values(self, possible_values): if not issubclass(possible_values, Enum): - raise ValueError("Invalid value '{}' for attribute 'possible_values' in variable '{}'. Must be a subclass of {}." - .format(possible_values, self.name, Enum)) + msg = f"Invalid value '{possible_values}' for attribute 'possible_values' in variable '{self.name}'. Must be a subclass of {Enum}." + raise ValueError( + msg, + ) return possible_values def set_label(self, label): if label: return label + return None def set_end(self, end): if end: try: - return datetime.datetime.strptime(end, '%Y-%m-%d').date() + return datetime.datetime.strptime(end, "%Y-%m-%d").date() except ValueError: - raise ValueError("Incorrect 'end' attribute format in '{}'. 'YYYY-MM-DD' expected where YYYY, MM and DD are year, month and day. Found: {}".format(self.name, end)) + msg = f"Incorrect 'end' attribute format in '{self.name}'. 'YYYY-MM-DD' expected where YYYY, MM and DD are year, month and day. Found: {end}" + raise ValueError( + msg, + ) + return None def set_reference(self, reference): if reference: @@ -190,19 +278,24 @@ def set_reference(self, reference): elif isinstance(reference, tuple): reference = list(reference) else: - raise TypeError('The reference of the variable {} is a {} instead of a String or a List of Strings.'.format(self.name, type(reference))) + msg = f"The reference of the variable {self.name} is a {type(reference)} instead of a String or a List of Strings." + raise TypeError( + msg, + ) for element in reference: if not isinstance(element, str): + msg = f"The reference of the variable {self.name} is a {type(reference)} instead of a String or a List of Strings." raise TypeError( - 'The reference of the variable {} is a {} instead of a String or a List of Strings.'.format( - self.name, type(reference))) + msg, + ) return reference def set_documentation(self, documentation): if documentation: return textwrap.dedent(documentation) + return None def set_set_input(self, set_input): if not set_input and self.baseline_variable: @@ -220,25 +313,29 @@ def set_formulas(self, formulas_attr): starting_date = self.parse_formula_name(formula_name) if self.end is not None and starting_date > self.end: - raise ValueError('You declared that "{}" ends on "{}", but you wrote a formula to calculate it from "{}" ({}). The "end" attribute of a variable must be posterior to the start dates of all its formulas.' - .format(self.name, self.end, starting_date, formula_name)) + msg = f'You declared that "{self.name}" ends on "{self.end}", but you wrote a formula to calculate it from "{starting_date}" ({formula_name}). The "end" attribute of a variable must be posterior to the start dates of all its formulas.' + raise ValueError( + msg, + ) formulas[str(starting_date)] = formula # If the variable is reforming a baseline variable, keep the formulas from the latter when they are not overridden by new formulas. if self.baseline_variable is not None: first_reform_formula_date = formulas.peekitem(0)[0] if formulas else None - formulas.update({ - baseline_start_date: baseline_formula - for baseline_start_date, baseline_formula in self.baseline_variable.formulas.items() - if first_reform_formula_date is None or baseline_start_date < first_reform_formula_date - }) + formulas.update( + { + baseline_start_date: baseline_formula + for baseline_start_date, baseline_formula in self.baseline_variable.formulas.items() + if first_reform_formula_date is None + or baseline_start_date < first_reform_formula_date + }, + ) return formulas def parse_formula_name(self, attribute_name): - """ - Returns the starting date of a formula based on its name. + """Returns the starting date of a formula based on its name. Valid dated name formats are : 'formula', 'formula_YYYY', 'formula_YYYY_MM' and 'formula_YYYY_MM_DD' where YYYY, MM and DD are a year, month and day. @@ -248,76 +345,69 @@ def parse_formula_name(self, attribute_name): - `formula_YYYY_MM` is `YYYY-MM-01` """ - def raise_error(): + def raise_error() -> NoReturn: + msg = f'Unrecognized formula name in variable "{self.name}". Expecting "formula_YYYY" or "formula_YYYY_MM" or "formula_YYYY_MM_DD where YYYY, MM and DD are year, month and day. Found: "{attribute_name}".' raise ValueError( - 'Unrecognized formula name in variable "{}". Expecting "formula_YYYY" or "formula_YYYY_MM" or "formula_YYYY_MM_DD where YYYY, MM and DD are year, month and day. Found: "{}".' - .format(self.name, attribute_name)) + msg, + ) if attribute_name == config.FORMULA_NAME_PREFIX: return datetime.date.min - FORMULA_REGEX = r'formula_(\d{4})(?:_(\d{2}))?(?:_(\d{2}))?$' # YYYY or YYYY_MM or YYYY_MM_DD + FORMULA_REGEX = r"formula_(\d{4})(?:_(\d{2}))?(?:_(\d{2}))?$" # YYYY or YYYY_MM or YYYY_MM_DD match = re.match(FORMULA_REGEX, attribute_name) if not match: raise_error() - date_str = '-'.join([match.group(1), match.group(2) or '01', match.group(3) or '01']) + date_str = "-".join( + [match.group(1), match.group(2) or "01", match.group(3) or "01"], + ) try: - return datetime.datetime.strptime(date_str, '%Y-%m-%d').date() + return datetime.datetime.strptime(date_str, "%Y-%m-%d").date() except ValueError: # formula_2005_99_99 for instance raise_error() # ----- Methods ----- # def is_input_variable(self): - """ - Returns True if the variable is an input variable. - """ + """Returns True if the variable is an input variable.""" return len(self.formulas) == 0 @classmethod - def get_introspection_data(cls, tax_benefit_system): - """ - Get instrospection data about the code of the variable. - - :returns: (comments, source file path, source code, start line number) - :rtype: tuple - - """ - comments = inspect.getcomments(cls) - - # Handle dynamically generated variable classes or Jupyter Notebooks, which have no source. + def get_introspection_data(cls): try: - absolute_file_path = inspect.getsourcefile(cls) - except TypeError: - source_file_path = None - else: - source_file_path = absolute_file_path.replace(tax_benefit_system.get_package_metadata()['location'], '') - try: - source_lines, start_line_number = inspect.getsourcelines(cls) - source_code = textwrap.dedent(''.join(source_lines)) - except (IOError, TypeError): - source_code, start_line_number = None, None + return cls.introspection_data + except AttributeError: + return "", None, 0 - return comments, source_file_path, source_code, start_line_number + def get_formula( + self, + period: None | t.Instant | t.Period | str | int = None, + ) -> None | t.Formula: + """Returns the formula to compute the variable at the given period. - def get_formula(self, period = None): - """ - Returns the formula used to compute the variable at the given period. + If no period is given and the variable has several formulas, the method + returns the oldest formula. - If no period is given and the variable has several formula, return the oldest formula. + Args: + period: The period to get the formula. - :returns: Formula used to compute the variable - :rtype: callable + Returns: + Formula used to compute the variable. """ + instant: None | t.Instant if not self.formulas: return None if period is None: - return self.formulas.peekitem(index = 0)[1] # peekitem gets the 1st key-value tuple (the oldest start_date and formula). Return the formula. + return self.formulas.peekitem( + index=0, + )[ + 1 + ] # peekitem gets the 1st key-value tuple (the oldest start_date and formula). Return the formula. if isinstance(period, Period): instant = period.start @@ -327,19 +417,22 @@ def get_formula(self, period = None): except ValueError: instant = periods.instant(period) + if instant is None: + return None + if self.end and instant.date > self.end: return None - instant = str(instant) + instant_str = str(instant) + for start_date in reversed(self.formulas): - if start_date <= instant: + if start_date <= instant_str: return self.formulas[start_date] return None def clone(self): - clone = self.__class__() - return clone + return self.__class__() def check_set_value(self, value): if self.value_type == Enum and isinstance(value, str): @@ -347,35 +440,39 @@ def check_set_value(self, value): value = self.possible_values[value].index except KeyError: possible_values = [item.name for item in self.possible_values] + msg = "'{}' is not a known value for '{}'. Possible values are ['{}'].".format( + value, + self.name, + "', '".join(possible_values), + ) raise ValueError( - "'{}' is not a known value for '{}'. Possible values are ['{}'].".format( - value, self.name, "', '".join(possible_values)) - ) + msg, + ) if self.value_type in (float, int) and isinstance(value, str): try: - value = tools.eval_expression(value) + value = commons.eval_expression(value) except SyntaxError: + msg = f"I couldn't understand '{value}' as a value for '{self.name}'" raise ValueError( - "I couldn't understand '{}' as a value for '{}'".format( - value, self.name) - ) + msg, + ) try: - value = numpy.array([value], dtype = self.dtype)[0] + value = numpy.array([value], dtype=self.dtype)[0] except (TypeError, ValueError): - if (self.value_type == datetime.date): - error_message = "Can't deal with date: '{}'.".format(value) + if self.value_type == datetime.date: + error_message = f"Can't deal with date: '{value}'." else: - error_message = "Can't deal with value: expected type {}, received '{}'.".format(self.json_type, value) + error_message = f"Can't deal with value: expected type {self.json_type}, received '{value}'." raise ValueError(error_message) - except (OverflowError): - error_message = "Can't deal with value: '{}', it's too large for type '{}'.".format(value, self.json_type) + except OverflowError: + error_message = f"Can't deal with value: '{value}', it's too large for type '{self.json_type}'." raise ValueError(error_message) return value def default_array(self, array_size): - array = numpy.empty(array_size, dtype = self.dtype) + array = numpy.empty(array_size, dtype=self.dtype) if self.value_type == Enum: array.fill(self.default_value.index) return EnumArray(array, self.possible_values) diff --git a/openfisca_core/warnings/__init__.py b/openfisca_core/warnings/__init__.py index 9e450c8702..3397fb52de 100644 --- a/openfisca_core/warnings/__init__.py +++ b/openfisca_core/warnings/__init__.py @@ -22,5 +22,4 @@ # See: https://www.python.org/dev/peps/pep-0008/#imports from .libyaml_warning import LibYAMLWarning # noqa: F401 -from .memory_warning import MemoryConfigWarning # noqa: F401 from .tempfile_warning import TempfileWarning # noqa: F401 diff --git a/openfisca_core/warnings/libyaml_warning.py b/openfisca_core/warnings/libyaml_warning.py index 361a1688ad..7ea797b667 100644 --- a/openfisca_core/warnings/libyaml_warning.py +++ b/openfisca_core/warnings/libyaml_warning.py @@ -1,5 +1,2 @@ class LibYAMLWarning(UserWarning): - """ - Custom warning for LibYAML not installed. - """ - pass + """Custom warning for LibYAML not installed.""" diff --git a/openfisca_core/warnings/memory_warning.py b/openfisca_core/warnings/memory_warning.py deleted file mode 100644 index 8fcb1f46f4..0000000000 --- a/openfisca_core/warnings/memory_warning.py +++ /dev/null @@ -1,5 +0,0 @@ -class MemoryConfigWarning(UserWarning): - """ - Custom warning for MemoryConfig. - """ - pass diff --git a/openfisca_core/warnings/tempfile_warning.py b/openfisca_core/warnings/tempfile_warning.py index cf2b9947ac..9f4aad3820 100644 --- a/openfisca_core/warnings/tempfile_warning.py +++ b/openfisca_core/warnings/tempfile_warning.py @@ -1,5 +1,2 @@ class TempfileWarning(UserWarning): - """ - Custom warning when using a tempfile on disk. - """ - pass + """Custom warning when using a tempfile on disk.""" diff --git a/openfisca_tasks/install.mk b/openfisca_tasks/install.mk index ac6bf1b34a..bb844b9d56 100644 --- a/openfisca_tasks/install.mk +++ b/openfisca_tasks/install.mk @@ -1,50 +1,21 @@ -## Install project dependencies. -install: - @${MAKE} install-deps - @${MAKE} install-dev - @${MAKE} install-core - @$(call print_pass,$@:) - -## Install common dependencies. -install-deps: - @$(call print_help,$@:) - @pip install --quiet --upgrade --constraint requirements/common pip setuptools - -## Install development dependencies. -install-dev: - @$(call print_help,$@:) - @pip install --quiet --upgrade --requirement requirements/install - @pip install --quiet --upgrade --requirement requirements/dev - -## Install package. -install-core: - @$(call print_help,$@:) - @pip uninstall --quiet --yes openfisca-core - @pip install --quiet --no-dependencies --editable . - -## Install the WebAPI tracker. -install-tracker: - @$(call print_help,$@:) - @pip install --quiet --upgrade --constraint requirements/tracker openfisca-tracker - -## Install lower-bound dependencies for compatibility check. -install-compat: +## Uninstall project's dependencies. +uninstall: @$(call print_help,$@:) - @pip install --quiet --upgrade --constraint requirements/compatibility numpy + @python -m pip freeze | grep -v "^-e" | sed "s/@.*//" | xargs pip uninstall -y -## Install coverage dependencies. -install-cov: +## Install project's overall dependencies +install-deps: @$(call print_help,$@:) - @pip install --quiet --upgrade --constraint requirements/coverage coveralls + @python -m pip install --upgrade pip -## Uninstall project dependencies. -uninstall: +## Install project's development dependencies. +install-edit: @$(call print_help,$@:) - @pip freeze | grep -v "^-e" | sed "s/@.*//" | xargs pip uninstall -y + @python -m pip install --upgrade --editable ".[dev]" ## Delete builds and compiled python files. -clean: \ - $(shell ls -d * | grep "build\|dist") \ - $(shell find . -name "*.pyc") +clean: @$(call print_help,$@:) - @rm -rf $? + @ls -d * | grep "build\|dist" | xargs rm -rf + @find . -name "__pycache__" | xargs rm -rf + @find . -name "*.pyc" | xargs rm -rf diff --git a/openfisca_tasks/lint.mk b/openfisca_tasks/lint.mk index 115c6267bb..532518dc7e 100644 --- a/openfisca_tasks/lint.mk +++ b/openfisca_tasks/lint.mk @@ -1,5 +1,5 @@ ## Lint the codebase. -lint: check-syntax-errors check-style lint-doc check-types lint-typing-strict +lint: check-syntax-errors check-style lint-doc @$(call print_pass,$@:) ## Compile python files to check for syntax errors. @@ -9,15 +9,21 @@ check-syntax-errors: . @$(call print_pass,$@:) ## Run linters to check for syntax and style errors. -check-style: $(shell git ls-files "*.py") +check-style: $(shell git ls-files "*.py" "*.pyi") @$(call print_help,$@:) - @flake8 $? + @python -m isort --check $? + @python -m black --check $? + @python -m flake8 $? + @codespell @$(call print_pass,$@:) ## Run linters to check for syntax and style errors in the doc. lint-doc: \ lint-doc-commons \ - lint-doc-types \ + lint-doc-data_storage \ + lint-doc-entities \ + lint-doc-experimental \ + lint-doc-indexed_enums \ ; ## Run linters to check for syntax and style errors in the doc. @@ -26,37 +32,30 @@ lint-doc-%: @## @## They can be integrated into setup.cfg once all checks pass. @## The reason they're here is because otherwise we wouldn't be - @## able to integrate documentation improvements progresively. + @## able to integrate documentation improvements progressively. @## @$(call print_help,$(subst $*,%,$@:)) - @flake8 --select=D101,D102,D103,DAR openfisca_core/$* - @pylint openfisca_core/$* + @python -m flake8 --select=D101,D102,D103,DAR openfisca_core/$* + @python -m pylint openfisca_core/$* @$(call print_pass,$@:) ## Run static type checkers for type errors. check-types: @$(call print_help,$@:) - @mypy --package openfisca_core --package openfisca_web_api - @$(call print_pass,$@:) - -## Run static type checkers for type errors (strict). -lint-typing-strict: \ - lint-typing-strict-commons \ - lint-typing-strict-types \ - ; - -## Run static type checkers for type errors (strict). -lint-typing-strict-%: - @$(call print_help,$(subst $*,%,$@:)) - @mypy \ - --cache-dir .mypy_cache-openfisca_core.$* \ - --implicit-reexport \ - --strict \ - --package openfisca_core.$* + @python -m mypy \ + openfisca_core/commons \ + openfisca_core/data_storage \ + openfisca_core/experimental \ + openfisca_core/entities \ + openfisca_core/indexed_enums \ + openfisca_core/periods \ + openfisca_core/types.py @$(call print_pass,$@:) ## Run code formatters to correct style errors. -format-style: $(shell git ls-files "*.py") +format-style: $(shell git ls-files "*.py" "*.pyi") @$(call print_help,$@:) - @autopep8 $? + @python -m isort $? + @python -m black $? + @codespell --write-changes @$(call print_pass,$@:) diff --git a/openfisca_tasks/publish.mk b/openfisca_tasks/publish.mk index 09686a6274..c511bd9c23 100644 --- a/openfisca_tasks/publish.mk +++ b/openfisca_tasks/publish.mk @@ -1,45 +1,17 @@ .PHONY: build -## Build openfisca-core for deployment and publishing. -build: - @## This allows us to be sure tests are run against the packaged version - @## of openfisca-core, the same we put in the hands of users and reusers. +## Install project's build dependencies. +install-dist: @$(call print_help,$@:) - @${MAKE} install-deps - @${MAKE} build-deps - @${MAKE} build-build - @${MAKE} build-install + @python -m pip install .[ci,dev] @$(call print_pass,$@:) -## Install building dependencies. -build-deps: - @$(call print_help,$@:) - @pip install --quiet --upgrade --constraint requirements/publication build - -## Build the package. -build-build: +## Build & install openfisca-core for deployment and publishing. +build: + @## This allows us to be sure tests are run against the packaged version + @## of openfisca-core, the same we put in the hands of users and reusers. @$(call print_help,$@:) @python -m build - -## Install the built package. -build-install: - @$(call print_help,$@:) - @pip uninstall --quiet --yes openfisca-core - @find dist -name "*.whl" -exec pip install --quiet --no-dependencies {} \; - -## Publish package. -publish: - @$(call print_help,$@:) - @${MAKE} publish-deps - @${MAKE} publish-upload + @python -m pip uninstall --yes openfisca-core + @find dist -name "*.whl" -exec python -m pip install --no-deps {} \; @$(call print_pass,$@:) - -## Install required publishing dependencies. -publish-deps: - @$(call print_help,$@:) - @pip install --quiet --upgrade --constraint requirements/publication twine - -## Upload package to PyPi. -publish-upload: - @$(call print_help,$@:) - twine upload dist/* --username $${PYPI_USERNAME} --password $${PYPI_PASSWORD} diff --git a/openfisca_tasks/test_code.mk b/openfisca_tasks/test_code.mk index 4abcce6aed..273dd4106f 100644 --- a/openfisca_tasks/test_code.mk +++ b/openfisca_tasks/test_code.mk @@ -1,8 +1,21 @@ ## The openfisca command module. openfisca = openfisca_core.scripts.openfisca_command -## The path to the installed packages. -python_packages = $(shell python -c "import sysconfig; print(sysconfig.get_paths()[\"purelib\"])") +## The path to the templates' tests. +ifeq ($(OS),Windows_NT) + tests = $(shell python -c "import os, $(1); print(repr(os.path.join($(1).__path__[0], 'tests')))") +else + tests = $(shell python -c "import $(1); print($(1).__path__[0])")/tests +endif + +## Run all tasks required for testing. +install: install-deps install-edit install-test + +## Enable regression testing with template repositories. +install-test: + @$(call print_help,$@:) + @python -m pip install --upgrade --no-deps openfisca-country-template + @python -m pip install --upgrade --no-deps openfisca-extension-template ## Run openfisca-core & country/extension template tests. test-code: test-core test-country test-extension @@ -20,19 +33,27 @@ test-code: test-core test-country test-extension @$(call print_pass,$@:) ## Run openfisca-core tests. -test-core: $(shell pytest --quiet --quiet --collect-only | cut -f 1 -d ":") +test-core: $(shell git ls-files "*test_*.py") @$(call print_help,$@:) + @python -m pytest --capture=no \ + openfisca_core/commons \ + openfisca_core/data_storage \ + openfisca_core/experimental \ + openfisca_core/entities \ + openfisca_core/holders \ + openfisca_core/indexed_enums \ + openfisca_core/periods \ + openfisca_core/projectors @PYTEST_ADDOPTS="$${PYTEST_ADDOPTS} ${pytest_args}" \ - coverage run -m \ - ${openfisca} test $? \ - ${openfisca_args} + python -m ${openfisca} test $? ${openfisca_args} @$(call print_pass,$@:) ## Run country-template tests. test-country: @$(call print_help,$@:) @PYTEST_ADDOPTS="$${PYTEST_ADDOPTS} ${pytest_args}" \ - openfisca test ${python_packages}/openfisca_country_template/tests \ + python -m ${openfisca} test \ + $(call tests,"openfisca_country_template") \ --country-package openfisca_country_template \ ${openfisca_args} @$(call print_pass,$@:) @@ -41,13 +62,9 @@ test-country: test-extension: @$(call print_help,$@:) @PYTEST_ADDOPTS="$${PYTEST_ADDOPTS} ${pytest_args}" \ - openfisca test ${python_packages}/openfisca_extension_template/tests \ + python -m ${openfisca} test \ + $(call tests,"openfisca_extension_template") \ --country-package openfisca_country_template \ --extensions openfisca_extension_template \ ${openfisca_args} @$(call print_pass,$@:) - -## Print the coverage report. -test-cov: - @$(call print_help,$@:) - @coverage report diff --git a/openfisca_tasks/test_doc.mk b/openfisca_tasks/test_doc.mk deleted file mode 100644 index bce952fe81..0000000000 --- a/openfisca_tasks/test_doc.mk +++ /dev/null @@ -1,78 +0,0 @@ -## The repository of the documentation. -repo = https://github.com/openfisca/openfisca-doc - -## The current working branch. -branch = $(shell git branch --show-current) - -## Check that the current changes do not break the doc. -test-doc: - @## Usage: - @## - @## make test-doc [branch=BRANCH] - @## - @## Examples: - @## - @## # Will check the current branch in openfisca-doc. - @## make test-doc - @## - @## # Will check "test-doc" in openfisca-doc. - @## make test-doc branch=test-doc - @## - @## # Will check "master" if "asdf1234" does not exist. - @## make test-doc branch=asdf1234 - @## - @$(call print_help,$@:) - @${MAKE} test-doc-checkout - @${MAKE} test-doc-install - @${MAKE} test-doc-build - @$(call print_pass,$@:) - -## Update the local copy of the doc. -test-doc-checkout: - @$(call print_help,$@:) - @[ ! -d doc ] && git clone ${repo} doc || : - @cd doc && { \ - git reset --hard ; \ - git fetch --all ; \ - [ "$$(git branch --show-current)" != "master" ] && git checkout master || : ; \ - [ "${branch}" != "master" ] \ - && { \ - { \ - >&2 echo "$(print_info) Trying to checkout the branch 'openfisca-doc/${branch}'..." ; \ - git branch -D ${branch} 2> /dev/null ; \ - git checkout ${branch} 2> /dev/null ; \ - } \ - && git pull --ff-only origin ${branch} \ - || { \ - >&2 echo "$(print_warn) The branch 'openfisca-doc/${branch}' was not found, falling back to 'openfisca-doc/master'..." ; \ - >&2 echo "" ; \ - >&2 echo "$(print_info) This is perfectly normal, one of two things can ensue:" ; \ - >&2 echo "$(print_info)" ; \ - >&2 echo "$(print_info) $$(tput setaf 2)[If tests pass]$$(tput sgr0)" ; \ - >&2 echo "$(print_info) * No further action required on your side..." ; \ - >&2 echo "$(print_info)" ; \ - >&2 echo "$(print_info) $$(tput setaf 1)[If tests fail]$$(tput sgr0)" ; \ - >&2 echo "$(print_info) * Create the branch '${branch}' in 'openfisca-doc'... " ; \ - >&2 echo "$(print_info) * Push your fixes..." ; \ - >&2 echo "$(print_info) * Run 'make test-doc' again..." ; \ - >&2 echo "" ; \ - >&2 echo "$(print_work) Checking out 'openfisca-doc/master'..." ; \ - git pull --ff-only origin master ; \ - } \ - } \ - || git pull --ff-only origin master ; \ - } 1> /dev/null - @$(call print_pass,$@:) - -## Install doc dependencies. -test-doc-install: - @$(call print_help,$@:) - @pip install --requirement doc/requirements.txt 1> /dev/null - @pip install --editable .[dev] --upgrade 1> /dev/null - @$(call print_pass,$@:) - -## Dry-build the doc. -test-doc-build: - @$(call print_help,$@:) - @sphinx-build -M dummy doc/source doc/build -n -q -W - @$(call print_pass,$@:) diff --git a/openfisca_web_api/app.py b/openfisca_web_api/app.py index e2244e9ba2..a76f255a0c 100644 --- a/openfisca_web_api/app.py +++ b/openfisca_web_api/app.py @@ -1,50 +1,57 @@ -# -*- coding: utf-8 -*- - import logging import os import traceback -from openfisca_core.errors import SituationParsingError, PeriodMismatchError -from openfisca_web_api.loader import build_data -from openfisca_web_api.errors import handle_import_error +from openfisca_core.errors import PeriodMismatchError, SituationParsingError from openfisca_web_api import handlers +from openfisca_web_api.errors import handle_import_error +from openfisca_web_api.loader import build_data try: - from flask import Flask, jsonify, abort, request, make_response + import werkzeug.exceptions + from flask import Flask, abort, jsonify, make_response, redirect, request from flask_cors import CORS from werkzeug.middleware.proxy_fix import ProxyFix - import werkzeug.exceptions except ImportError as error: handle_import_error(error) -log = logging.getLogger('gunicorn.error') +log = logging.getLogger("gunicorn.error") def init_tracker(url, idsite, tracker_token): try: from openfisca_tracker.piwik import PiwikTracker + tracker = PiwikTracker(url, idsite, tracker_token) - info = os.linesep.join(['You chose to activate the `tracker` module. ', - 'Tracking data will be sent to: ' + url, - 'For more information, see .']) + info = os.linesep.join( + [ + "You chose to activate the `tracker` module. ", + "Tracking data will be sent to: " + url, + "For more information, see .", + ], + ) log.info(info) return tracker except ImportError: - message = os.linesep.join([traceback.format_exc(), - 'You chose to activate the `tracker` module, but it is not installed.', - 'For more information, see .']) - log.warn(message) - - -def create_app(tax_benefit_system, - tracker_url = None, - tracker_idsite = None, - tracker_token = None, - welcome_message = None, - ): - + message = os.linesep.join( + [ + traceback.format_exc(), + "You chose to activate the `tracker` module, but it is not installed.", + "For more information, see .", + ], + ) + log.warning(message) + + +def create_app( + tax_benefit_system, + tracker_url=None, + tracker_idsite=None, + tracker_token=None, + welcome_message=None, +): if not tracker_url or not tracker_idsite: tracker = None else: @@ -52,88 +59,108 @@ def create_app(tax_benefit_system, app = Flask(__name__) # Fix request.remote_addr to get the real client IP address - app.wsgi_app = ProxyFix(app.wsgi_app, x_for = 1, x_host = 1) - CORS(app, origins = '*') + app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_host=1) + CORS(app, origins="*") - app.config['JSON_AS_ASCII'] = False # When False, lets jsonify encode to utf-8 + app.config["JSON_AS_ASCII"] = False # When False, lets jsonify encode to utf-8 app.url_map.strict_slashes = False # Accept url like /parameters/ app.url_map.merge_slashes = False # Do not eliminate // in paths - app.config['JSON_SORT_KEYS'] = False # Don't sort JSON keys in the Web API + app.config["JSON_SORT_KEYS"] = False # Don't sort JSON keys in the Web API data = build_data(tax_benefit_system) DEFAULT_WELCOME_MESSAGE = "This is the root of an OpenFisca Web API. To learn how to use it, check the general documentation (https://openfisca.org/doc/) and the OpenAPI specification of this instance ({}spec)." - @app.route('/') - def get_root(): - return jsonify({ - 'welcome': welcome_message or DEFAULT_WELCOME_MESSAGE.format(request.host_url) - }), 300 + @app.before_request + def before_request(): + if request.path != "/" and request.path.endswith("/"): + return redirect(request.path[:-1]) + return None - @app.route('/parameters') + @app.route("/") + def get_root(): + return ( + jsonify( + { + "welcome": welcome_message + or DEFAULT_WELCOME_MESSAGE.format(request.host_url), + }, + ), + 300, + ) + + @app.route("/parameters") def get_parameters(): parameters = { - parameter['id']: { - 'description': parameter['description'], - 'href': '{}parameter/{}'.format(request.host_url, name) - } - for name, parameter in data['parameters'].items() - if parameter.get('subparams') is None # For now and for backward compat, don't show nodes in overview + parameter["id"]: { + "description": parameter["description"], + "href": f"{request.host_url}parameter/{name}", } + for name, parameter in data["parameters"].items() + if parameter.get("subparams") + is None # For now and for backward compat, don't show nodes in overview + } return jsonify(parameters) - @app.route('/parameter/') + @app.route("/parameter/") def get_parameter(parameter_id): - parameter = data['parameters'].get(parameter_id) + parameter = data["parameters"].get(parameter_id) if parameter is None: # Try legacy route - parameter_new_id = parameter_id.replace('.', '/') - parameter = data['parameters'].get(parameter_new_id) + parameter_new_id = parameter_id.replace(".", "/") + parameter = data["parameters"].get(parameter_new_id) if parameter is None: raise abort(404) return jsonify(parameter) - @app.route('/variables') + @app.route("/variables") def get_variables(): variables = { name: { - 'description': variable['description'], - 'href': '{}variable/{}'.format(request.host_url, name) - } - for name, variable in data['variables'].items() + "description": variable["description"], + "href": f"{request.host_url}variable/{name}", } + for name, variable in data["variables"].items() + } return jsonify(variables) - @app.route('/variable/') + @app.route("/variable/") def get_variable(id): - variable = data['variables'].get(id) + variable = data["variables"].get(id) if variable is None: raise abort(404) return jsonify(variable) - @app.route('/entities') + @app.route("/entities") def get_entities(): - return jsonify(data['entities']) + return jsonify(data["entities"]) - @app.route('/spec') + @app.route("/spec") def get_spec(): - return jsonify({ - **data['openAPI_spec'], - **{'host': request.host}, - **{'schemes': [request.environ['wsgi.url_scheme']]} - }) - - def handle_invalid_json(error): - json_response = jsonify({ - 'error': 'Invalid JSON: {}'.format(error.args[0]), - }) + scheme = request.environ["wsgi.url_scheme"] + host = request.host + url = f"{scheme}://{host}" + + return jsonify( + { + **data["openAPI_spec"], + "servers": [{"url": url}], + }, + ) + + def handle_invalid_json(error) -> None: + json_response = jsonify( + { + "error": f"Invalid JSON: {error.args[0]}", + }, + ) abort(make_response(json_response, 400)) - @app.route('/calculate', methods=['POST']) + @app.route("/calculate", methods=["POST"]) def calculate(): - tax_benefit_system = data['tax_benefit_system'] + tax_benefit_system = data["tax_benefit_system"] request.on_json_loading_failed = handle_invalid_json input_data = request.get_json() try: @@ -141,12 +168,17 @@ def calculate(): except (SituationParsingError, PeriodMismatchError) as e: abort(make_response(jsonify(e.error), e.code or 400)) except (UnicodeEncodeError, UnicodeDecodeError) as e: - abort(make_response(jsonify({"error": "'" + e[1] + "' is not a valid ASCII value."}), 400)) + abort( + make_response( + jsonify({"error": "'" + e[1] + "' is not a valid ASCII value."}), + 400, + ), + ) return jsonify(result) - @app.route('/trace', methods=['POST']) + @app.route("/trace", methods=["POST"]) def trace(): - tax_benefit_system = data['tax_benefit_system'] + tax_benefit_system = data["tax_benefit_system"] request.on_json_loading_failed = handle_invalid_json input_data = request.get_json() try: @@ -157,25 +189,28 @@ def trace(): @app.after_request def apply_headers(response): - response.headers.extend({ - 'Country-Package': data['country_package_metadata']['name'], - 'Country-Package-Version': data['country_package_metadata']['version'] - }) + response.headers.extend( + { + "Country-Package": data["country_package_metadata"]["name"], + "Country-Package-Version": data["country_package_metadata"]["version"], + }, + ) return response @app.after_request def track_requests(response): - if tracker: - if request.headers.get('dnt'): + if request.headers.get("dnt"): source_ip = "" - elif request.headers.get('X-Forwarded-For'): - source_ip = request.headers['X-Forwarded-For'].split(', ')[0] + elif request.headers.get("X-Forwarded-For"): + source_ip = request.headers["X-Forwarded-For"].split(", ")[0] else: source_ip = request.remote_addr - api_version = "{}-{}".format(data['country_package_metadata']['name'], - data['country_package_metadata']['version']) + api_version = "{}-{}".format( + data["country_package_metadata"]["name"], + data["country_package_metadata"]["version"], + ) tracker.track(request.url, source_ip, api_version, request.path) return response diff --git a/openfisca_web_api/errors.py b/openfisca_web_api/errors.py index 96c95a6874..ac93ebd833 100644 --- a/openfisca_web_api/errors.py +++ b/openfisca_web_api/errors.py @@ -1,9 +1,12 @@ -# -*- coding: utf-8 -*- +from typing import NoReturn import logging -log = logging.getLogger('gunicorn.error') +log = logging.getLogger("gunicorn.error") -def handle_import_error(error): - raise ImportError("OpenFisca is missing some dependencies to run the Web API: '{}'. To install them, run `pip install openfisca_core[web-api]`.".format(error)) +def handle_import_error(error) -> NoReturn: + msg = f"OpenFisca is missing some dependencies to run the Web API: '{error}'. To install them, run `pip install openfisca_core[web-api]`." + raise ImportError( + msg, + ) diff --git a/openfisca_web_api/handlers.py b/openfisca_web_api/handlers.py index 1a6ace07db..59d035eb57 100644 --- a/openfisca_web_api/handlers.py +++ b/openfisca_web_api/handlers.py @@ -1,20 +1,24 @@ -# -*- coding: utf-8 -*- +import dpath.util -import dpath - -from openfisca_core.simulations import SimulationBuilder from openfisca_core.indexed_enums import Enum +from openfisca_core.simulations import SimulationBuilder -def calculate(tax_benefit_system, input_data): +def calculate(tax_benefit_system, input_data: dict) -> dict: + """Returns the input_data where the None values are replaced by the calculated values.""" simulation = SimulationBuilder().build_from_entities(tax_benefit_system, input_data) - - requested_computations = dpath.util.search(input_data, '*/*/*/*', afilter = lambda t: t is None, yielded = True) - computation_results = {} - + requested_computations = dpath.util.search( + input_data, + "*/*/*/*", + afilter=lambda t: t is None, + yielded=True, + ) + computation_results: dict = {} for computation in requested_computations: - path = computation[0] - entity_plural, entity_id, variable_name, period = path.split('/') + path = computation[ + 0 + ] # format: entity_plural/entity_instance_id/openfisca_variable_name/period + entity_plural, entity_id, variable_name, period = path.split("/") variable = tax_benefit_system.get_variable(variable_name) result = simulation.calculate(variable_name, period) population = simulation.get_population(entity_plural) @@ -23,15 +27,39 @@ def calculate(tax_benefit_system, input_data): if variable.value_type == Enum: entity_result = result.decode()[entity_index].name elif variable.value_type == float: - entity_result = float(str(result[entity_index])) # To turn the float32 into a regular float without adding confusing extra decimals. There must be a better way. + entity_result = float( + str(result[entity_index]), + ) # To turn the float32 into a regular float without adding confusing extra decimals. There must be a better way. elif variable.value_type == str: entity_result = str(result[entity_index]) else: entity_result = result.tolist()[entity_index] - - dpath.util.new(computation_results, path, entity_result) - - dpath.merge(input_data, computation_results) + # Don't use dpath.util.new, because there is a problem with dpath>=2.0 + # when we have a key that is numeric, like the year. + # See https://github.com/dpath-maintainers/dpath-python/issues/160 + if computation_results == {}: + computation_results = { + entity_plural: {entity_id: {variable_name: {period: entity_result}}}, + } + elif entity_plural in computation_results: + if entity_id in computation_results[entity_plural]: + if variable_name in computation_results[entity_plural][entity_id]: + computation_results[entity_plural][entity_id][variable_name][ + period + ] = entity_result + else: + computation_results[entity_plural][entity_id][variable_name] = { + period: entity_result, + } + else: + computation_results[entity_plural][entity_id] = { + variable_name: {period: entity_result}, + } + else: + computation_results[entity_plural] = { + entity_id: {variable_name: {period: entity_result}}, + } + dpath.util.merge(input_data, computation_results) return input_data @@ -41,11 +69,16 @@ def trace(tax_benefit_system, input_data): simulation.trace = True requested_calculations = [] - requested_computations = dpath.util.search(input_data, '*/*/*/*', afilter = lambda t: t is None, yielded = True) + requested_computations = dpath.util.search( + input_data, + "*/*/*/*", + afilter=lambda t: t is None, + yielded=True, + ) for computation in requested_computations: path = computation[0] - entity_plural, entity_id, variable_name, period = path.split('/') - requested_calculations.append(f"{variable_name}<{str(period)}>") + entity_plural, entity_id, variable_name, period = path.split("/") + requested_calculations.append(f"{variable_name}<{period!s}>") simulation.calculate(variable_name, period) trace = simulation.tracer.get_serialized_flat_trace() @@ -53,5 +86,5 @@ def trace(tax_benefit_system, input_data): return { "trace": trace, "entitiesDescription": simulation.describe_entities(), - "requestedCalculations": requested_calculations - } + "requestedCalculations": requested_calculations, + } diff --git a/openfisca_web_api/loader/__init__.py b/openfisca_web_api/loader/__init__.py index b86aefad57..8d9318d9ae 100644 --- a/openfisca_web_api/loader/__init__.py +++ b/openfisca_web_api/loader/__init__.py @@ -1,24 +1,22 @@ -# -*- coding: utf-8 -*- - - -from openfisca_web_api.loader.parameters import build_parameters -from openfisca_web_api.loader.variables import build_variables from openfisca_web_api.loader.entities import build_entities +from openfisca_web_api.loader.parameters import build_parameters from openfisca_web_api.loader.spec import build_openAPI_specification +from openfisca_web_api.loader.variables import build_variables def build_data(tax_benefit_system): country_package_metadata = tax_benefit_system.get_package_metadata() parameters = build_parameters(tax_benefit_system, country_package_metadata) variables = build_variables(tax_benefit_system, country_package_metadata) + entities = build_entities(tax_benefit_system) data = { - 'tax_benefit_system': tax_benefit_system, - 'country_package_metadata': tax_benefit_system.get_package_metadata(), - 'openAPI_spec': None, - 'parameters': parameters, - 'variables': variables, - 'entities': build_entities(tax_benefit_system), - } - data['openAPI_spec'] = build_openAPI_specification(data) + "tax_benefit_system": tax_benefit_system, + "country_package_metadata": country_package_metadata, + "openAPI_spec": None, + "parameters": parameters, + "variables": variables, + "entities": entities, + } + data["openAPI_spec"] = build_openAPI_specification(data) return data diff --git a/openfisca_web_api/loader/entities.py b/openfisca_web_api/loader/entities.py index 2f3194882a..98ce4e6fb9 100644 --- a/openfisca_web_api/loader/entities.py +++ b/openfisca_web_api/loader/entities.py @@ -1,39 +1,28 @@ -# -*- coding: utf-8 -*- - - def build_entities(tax_benefit_system): - entities = { - entity.key: build_entity(entity) - for entity in tax_benefit_system.entities - } - return entities + return {entity.key: build_entity(entity) for entity in tax_benefit_system.entities} def build_entity(entity): formatted_doc = entity.doc.strip() formatted_entity = { - 'plural': entity.plural, - 'description': entity.label, - 'documentation': formatted_doc - } + "plural": entity.plural, + "description": entity.label, + "documentation": formatted_doc, + } if not entity.is_person: - formatted_entity['roles'] = { - role.key: build_role(role) - for role in entity.roles - } + formatted_entity["roles"] = { + role.key: build_role(role) for role in entity.roles + } return formatted_entity def build_role(role): - formatted_role = { - 'plural': role.plural, - 'description': role.doc - } + formatted_role = {"plural": role.plural, "description": role.doc} if role.max: - formatted_role['max'] = role.max + formatted_role["max"] = role.max if role.subroles: - formatted_role['max'] = len(role.subroles) + formatted_role["max"] = len(role.subroles) return formatted_role diff --git a/openfisca_web_api/loader/parameters.py b/openfisca_web_api/loader/parameters.py index 39534f972e..193f12915f 100644 --- a/openfisca_web_api/loader/parameters.py +++ b/openfisca_web_api/loader/parameters.py @@ -1,6 +1,7 @@ -# -*- coding: utf-8 -*- +import functools +import operator -from openfisca_core.parameters import Parameter, ParameterNode, ParameterScale +from openfisca_core.parameters import Parameter, ParameterNode, Scale def build_api_values_history(values_history): @@ -12,43 +13,57 @@ def build_api_values_history(values_history): def get_value(date, values): - candidates = sorted([ - (start_date, value) - for start_date, value in values.items() - if start_date <= date # dates are lexicographically ordered and can be sorted - ], reverse = True) + candidates = sorted( + [ + (start_date, value) + for start_date, value in values.items() + if start_date + <= date # dates are lexicographically ordered and can be sorted + ], + reverse=True, + ) if candidates: return candidates[0][1] - else: - return None + return None def build_api_scale(scale, value_key_name): # preprocess brackets for a scale with 'rates' or 'amounts' - brackets = [{ - 'thresholds': build_api_values_history(bracket.threshold), - 'values': build_api_values_history(getattr(bracket, value_key_name)) - } for bracket in scale.brackets] - - dates = set(sum( - [list(bracket['thresholds'].keys()) - + list(bracket['values'].keys()) for bracket in brackets], - [])) # flatten the dates and remove duplicates + brackets = [ + { + "thresholds": build_api_values_history(bracket.threshold), + "values": build_api_values_history(getattr(bracket, value_key_name)), + } + for bracket in scale.brackets + ] + + dates = set( + functools.reduce( + operator.iadd, + [ + list(bracket["thresholds"].keys()) + list(bracket["values"].keys()) + for bracket in brackets + ], + [], + ), + ) # flatten the dates and remove duplicates # We iterate on all dates as we need to build the whole scale for each of them api_scale = {} for date in dates: for bracket in brackets: - threshold_value = get_value(date, bracket['thresholds']) + threshold_value = get_value(date, bracket["thresholds"]) if threshold_value is not None: - rate_or_amount_value = get_value(date, bracket['values']) + rate_or_amount_value = get_value(date, bracket["values"]) api_scale[date] = api_scale.get(date) or {} api_scale[date][threshold_value] = rate_or_amount_value # Handle stopped parameters: a parameter is stopped if its first bracket is stopped - latest_date_first_threshold = max(brackets[0]['thresholds'].keys()) - latest_value_first_threshold = brackets[0]['thresholds'][latest_date_first_threshold] + latest_date_first_threshold = max(brackets[0]["thresholds"].keys()) + latest_value_first_threshold = brackets[0]["thresholds"][ + latest_date_first_threshold + ] if latest_value_first_threshold is None: api_scale[latest_date_first_threshold] = None @@ -57,45 +72,51 @@ def build_api_scale(scale, value_key_name): def build_source_url(absolute_file_path, country_package_metadata): - relative_path = absolute_file_path.replace(country_package_metadata['location'], '') - return '{}/blob/{}{}'.format( - country_package_metadata['repository_url'], - country_package_metadata['version'], - relative_path - ) + relative_path = absolute_file_path.replace(country_package_metadata["location"], "") + return "{}/blob/{}{}".format( + country_package_metadata["repository_url"], + country_package_metadata["version"], + relative_path, + ) def build_api_parameter(parameter, country_package_metadata): api_parameter = { - 'description': getattr(parameter, "description", None), - 'id': parameter.name, - 'metadata': parameter.metadata - } + "description": getattr(parameter, "description", None), + "id": parameter.name, + "metadata": parameter.metadata, + } if parameter.file_path: - api_parameter['source'] = build_source_url(parameter.file_path, country_package_metadata) + api_parameter["source"] = build_source_url( + parameter.file_path, + country_package_metadata, + ) if isinstance(parameter, Parameter): if parameter.documentation: - api_parameter['documentation'] = parameter.documentation.strip() - api_parameter['values'] = build_api_values_history(parameter) - elif isinstance(parameter, ParameterScale): - if 'rate' in parameter.brackets[0].children: - api_parameter['brackets'] = build_api_scale(parameter, 'rate') - elif 'amount' in parameter.brackets[0].children: - api_parameter['brackets'] = build_api_scale(parameter, 'amount') + api_parameter["documentation"] = parameter.documentation.strip() + api_parameter["values"] = build_api_values_history(parameter) + elif isinstance(parameter, Scale): + if "rate" in parameter.brackets[0].children: + api_parameter["brackets"] = build_api_scale(parameter, "rate") + elif "amount" in parameter.brackets[0].children: + api_parameter["brackets"] = build_api_scale(parameter, "amount") elif isinstance(parameter, ParameterNode): if parameter.documentation: - api_parameter['documentation'] = parameter.documentation.strip() - api_parameter['subparams'] = { + api_parameter["documentation"] = parameter.documentation.strip() + api_parameter["subparams"] = { child_name: { - 'description': child.description, - } - for child_name, child in parameter.children.items() + "description": child.description, } + for child_name, child in parameter.children.items() + } return api_parameter def build_parameters(tax_benefit_system, country_package_metadata): return { - parameter.name.replace('.', '/'): build_api_parameter(parameter, country_package_metadata) + parameter.name.replace(".", "/"): build_api_parameter( + parameter, + country_package_metadata, + ) for parameter in tax_benefit_system.parameters.get_descendants() - } + } diff --git a/openfisca_web_api/loader/spec.py b/openfisca_web_api/loader/spec.py index fde2818c33..4a163bd91f 100644 --- a/openfisca_web_api/loader/spec.py +++ b/openfisca_web_api/loader/spec.py @@ -1,76 +1,119 @@ -# -*- coding: utf-8 -*- - import os -import yaml from copy import deepcopy -import dpath +import dpath.util +import yaml from openfisca_core.indexed_enums import Enum from openfisca_web_api import handlers - -OPEN_API_CONFIG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.pardir, 'openAPI.yml') +OPEN_API_CONFIG_FILE = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + os.path.pardir, + "openAPI.yml", +) def build_openAPI_specification(api_data): - tax_benefit_system = api_data['tax_benefit_system'] - file = open(OPEN_API_CONFIG_FILE, 'r') + tax_benefit_system = api_data["tax_benefit_system"] + file = open(OPEN_API_CONFIG_FILE) spec = yaml.safe_load(file) - country_package_name = api_data['country_package_metadata']['name'].title() - dpath.new(spec, 'info/title', spec['info']['title'].replace("{COUNTRY_PACKAGE_NAME}", country_package_name)) - dpath.new(spec, 'info/description', spec['info']['description'].replace("{COUNTRY_PACKAGE_NAME}", country_package_name)) - dpath.new(spec, 'info/version', api_data['country_package_metadata']['version']) + country_package_name = api_data["country_package_metadata"]["name"].title() + country_package_version = api_data["country_package_metadata"]["version"] + dpath.util.new( + spec, + "info/title", + spec["info"]["title"].replace("{COUNTRY_PACKAGE_NAME}", country_package_name), + ) + dpath.util.new( + spec, + "info/description", + spec["info"]["description"].replace( + "{COUNTRY_PACKAGE_NAME}", + country_package_name, + ), + ) + dpath.util.new( + spec, + "info/version", + spec["info"]["version"].replace( + "{COUNTRY_PACKAGE_VERSION}", + country_package_version, + ), + ) for entity in tax_benefit_system.entities: name = entity.key.title() - spec['definitions'][name] = get_entity_json_schema(entity, tax_benefit_system) + spec["components"]["schemas"][name] = get_entity_json_schema( + entity, + tax_benefit_system, + ) situation_schema = get_situation_json_schema(tax_benefit_system) - dpath.new(spec, 'definitions/SituationInput', situation_schema) - dpath.new(spec, 'definitions/SituationOutput', situation_schema.copy()) - dpath.new(spec, 'definitions/Trace/properties/entitiesDescription/properties', { - entity.plural: {'type': 'array', 'items': {"type": "string"}} - for entity in tax_benefit_system.entities - }) + dpath.util.new(spec, "components/schemas/SituationInput", situation_schema) + dpath.util.new(spec, "components/schemas/SituationOutput", situation_schema.copy()) + dpath.util.new( + spec, + "components/schemas/Trace/properties/entitiesDescription/properties", + { + entity.plural: {"type": "array", "items": {"type": "string"}} + for entity in tax_benefit_system.entities + }, + ) # Get example from the served tax benefist system - if tax_benefit_system.open_api_config.get('parameter_example'): - parameter_id = tax_benefit_system.open_api_config['parameter_example'] - parameter_path = parameter_id.replace('.', '/') - parameter_example = api_data['parameters'][parameter_path] + if tax_benefit_system.open_api_config.get("parameter_example"): + parameter_id = tax_benefit_system.open_api_config["parameter_example"] + parameter_path = parameter_id.replace(".", "/") + parameter_example = api_data["parameters"][parameter_path] else: - parameter_example = next(iter(api_data['parameters'].values())) - dpath.new(spec, 'definitions/Parameter/example', parameter_example) + parameter_example = next(iter(api_data["parameters"].values())) + dpath.util.new(spec, "components/schemas/Parameter/example", parameter_example) - if tax_benefit_system.open_api_config.get('variable_example'): - variable_example = api_data['variables'][tax_benefit_system.open_api_config['variable_example']] + if tax_benefit_system.open_api_config.get("variable_example"): + variable_example = api_data["variables"][ + tax_benefit_system.open_api_config["variable_example"] + ] else: - variable_example = next(iter(api_data['variables'].values())) - dpath.new(spec, 'definitions/Variable/example', variable_example) - - if tax_benefit_system.open_api_config.get('simulation_example'): - simulation_example = tax_benefit_system.open_api_config['simulation_example'] - dpath.new(spec, 'definitions/SituationInput/example', simulation_example) - dpath.new(spec, 'definitions/SituationOutput/example', handlers.calculate(tax_benefit_system, deepcopy(simulation_example))) # calculate has side-effects - dpath.new(spec, 'definitions/Trace/example', handlers.trace(tax_benefit_system, simulation_example)) + variable_example = next(iter(api_data["variables"].values())) + dpath.util.new(spec, "components/schemas/Variable/example", variable_example) + + if tax_benefit_system.open_api_config.get("simulation_example"): + simulation_example = tax_benefit_system.open_api_config["simulation_example"] + dpath.util.new( + spec, + "components/schemas/SituationInput/example", + simulation_example, + ) + dpath.util.new( + spec, + "components/schemas/SituationOutput/example", + handlers.calculate(tax_benefit_system, deepcopy(simulation_example)), + ) # calculate has side-effects + dpath.util.new( + spec, + "components/schemas/Trace/example", + handlers.trace(tax_benefit_system, simulation_example), + ) else: - message = "No simulation example has been defined for this tax and benefit system. If you are the maintainer of {}, you can define an example by following this documentation: https://openfisca.org/doc/openfisca-web-api/config-openapi.html".format(country_package_name) - dpath.new(spec, 'definitions/SituationInput/example', message) - dpath.new(spec, 'definitions/SituationOutput/example', message) - dpath.new(spec, 'definitions/Trace/example', message) + message = f"No simulation example has been defined for this tax and benefit system. If you are the maintainer of {country_package_name}, you can define an example by following this documentation: https://openfisca.org/doc/openfisca-web-api/config-openapi.html" + dpath.util.new(spec, "components/schemas/SituationInput/example", message) + dpath.util.new(spec, "components/schemas/SituationOutput/example", message) + dpath.util.new(spec, "components/schemas/Trace/example", message) return spec def get_variable_json_schema(variable): result = { - 'type': 'object', - 'additionalProperties': {'type': variable.json_type}, - } + "type": "object", + "additionalProperties": {"type": variable.json_type}, + } if variable.value_type == Enum: - result['additionalProperties']['enum'] = [item.name for item in list(variable.possible_values)] + result["additionalProperties"]["enum"] = [ + item.name for item in list(variable.possible_values) + ] return result @@ -78,46 +121,48 @@ def get_variable_json_schema(variable): def get_entity_json_schema(entity, tax_benefit_system): if entity.is_person: return { - 'type': 'object', - 'properties': { + "type": "object", + "properties": { variable_name: get_variable_json_schema(variable) - for variable_name, variable in tax_benefit_system.get_variables(entity).items() - }, - 'additionalProperties': False, - } - else: - properties = {} - properties.update({ - role.plural or role.key: { - 'type': 'array', - "items": { - "type": "string" - } - } + for variable_name, variable in tax_benefit_system.get_variables( + entity, + ).items() + }, + "additionalProperties": False, + } + properties = {} + properties.update( + { + role.plural or role.key: {"type": "array", "items": {"type": "string"}} for role in entity.roles - }) - properties.update({ + }, + ) + properties.update( + { variable_name: get_variable_json_schema(variable) - for variable_name, variable in tax_benefit_system.get_variables(entity).items() - }) - return { - 'type': 'object', - 'properties': properties, - 'additionalProperties': False, - } + for variable_name, variable in tax_benefit_system.get_variables( + entity, + ).items() + }, + ) + return { + "type": "object", + "properties": properties, + "additionalProperties": False, + } def get_situation_json_schema(tax_benefit_system): return { - 'type': 'object', - 'additionalProperties': False, - 'properties': { + "type": "object", + "additionalProperties": False, + "properties": { entity.plural: { - 'type': 'object', - 'additionalProperties': { - "$ref": "#/definitions/{}".format(entity.key.title()) - } - } - for entity in tax_benefit_system.entities + "type": "object", + "additionalProperties": { + "$ref": f"#/components/schemas/{entity.key.title()}", + }, } - } + for entity in tax_benefit_system.entities + }, + } diff --git a/openfisca_web_api/loader/tax_benefit_system.py b/openfisca_web_api/loader/tax_benefit_system.py index 856f760008..358f960501 100644 --- a/openfisca_web_api/loader/tax_benefit_system.py +++ b/openfisca_web_api/loader/tax_benefit_system.py @@ -1,8 +1,6 @@ -# -*- coding: utf-8 -*- - import importlib -import traceback import logging +import traceback from os import linesep log = logging.getLogger(__name__) @@ -12,14 +10,18 @@ def build_tax_benefit_system(country_package_name): try: country_package = importlib.import_module(country_package_name) except ImportError: - message = linesep.join([traceback.format_exc(), - 'Could not import module `{}`.'.format(country_package_name), - 'Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.', - 'See more at .', - linesep]) + message = linesep.join( + [ + traceback.format_exc(), + f"Could not import module `{country_package_name}`.", + "Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.", + "See more at .", + linesep, + ], + ) raise ValueError(message) try: return country_package.CountryTaxBenefitSystem() except NameError: # Gunicorn swallows NameErrors. Force printing the stack trace. - log.error(traceback.format_exc()) + log.exception(traceback.format_exc()) raise diff --git a/openfisca_web_api/loader/variables.py b/openfisca_web_api/loader/variables.py index d9390fb3a2..6730dc0811 100644 --- a/openfisca_web_api/loader/variables.py +++ b/openfisca_web_api/loader/variables.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import datetime import inspect import textwrap @@ -10,8 +8,8 @@ def get_next_day(date): parsed_date = date - next_day = parsed_date + datetime.timedelta(days = 1) - return next_day.isoformat().split('T')[0] + next_day = parsed_date + datetime.timedelta(days=1) + return next_day.isoformat().split("T")[0] def get_default_value(variable): @@ -25,84 +23,99 @@ def get_default_value(variable): return default_value -def build_source_url(country_package_metadata, source_file_path, start_line_number, source_code): - nb_lines = source_code.count('\n') - return '{}/blob/{}{}#L{}-L{}'.format( - country_package_metadata['repository_url'], - country_package_metadata['version'], +def build_source_url( + country_package_metadata, + source_file_path, + start_line_number, + source_code, +): + nb_lines = source_code.count("\n") + return "{}/blob/{}{}#L{}-L{}".format( + country_package_metadata["repository_url"], + country_package_metadata["version"], source_file_path, start_line_number, start_line_number + nb_lines - 1, - ) + ) -def build_formula(formula, country_package_metadata, source_file_path, tax_benefit_system): +def build_formula(formula, country_package_metadata, source_file_path): source_code, start_line_number = inspect.getsourcelines(formula) - source_code = textwrap.dedent(''.join(source_code)) + source_code = textwrap.dedent("".join(source_code)) api_formula = { - 'source': build_source_url( + "source": build_source_url( country_package_metadata, source_file_path, start_line_number, - source_code - ), - 'content': source_code, - } + source_code, + ), + "content": source_code, + } if formula.__doc__: - api_formula['documentation'] = textwrap.dedent(formula.__doc__) + api_formula["documentation"] = textwrap.dedent(formula.__doc__) return api_formula -def build_formulas(formulas, country_package_metadata, source_file_path, tax_benefit_system): +def build_formulas(formulas, country_package_metadata, source_file_path): return { - start_date: build_formula(formula, country_package_metadata, source_file_path, tax_benefit_system) + start_date: build_formula(formula, country_package_metadata, source_file_path) for start_date, formula in formulas.items() - } + } -def build_variable(variable, country_package_metadata, tax_benefit_system): - comments, source_file_path, source_code, start_line_number = variable.get_introspection_data(tax_benefit_system) +def build_variable(variable, country_package_metadata): + ( + source_file_path, + source_code, + start_line_number, + ) = variable.get_introspection_data() result = { - 'id': variable.name, - 'description': variable.label, - 'valueType': VALUE_TYPES[variable.value_type]['formatted_value_type'], - 'defaultValue': get_default_value(variable), - 'definitionPeriod': variable.definition_period.upper(), - 'entity': variable.entity.key, - } + "id": variable.name, + "description": variable.label, + "valueType": VALUE_TYPES[variable.value_type]["formatted_value_type"], + "defaultValue": get_default_value(variable), + "definitionPeriod": variable.definition_period.upper(), + "entity": variable.entity.key, + } if source_code: - result['source'] = build_source_url( + result["source"] = build_source_url( country_package_metadata, source_file_path, start_line_number, - source_code - ) + source_code, + ) if variable.documentation: - result['documentation'] = variable.documentation.strip() + result["documentation"] = variable.documentation.strip() if variable.reference: - result['references'] = variable.reference + result["references"] = variable.reference if len(variable.formulas) > 0: - result['formulas'] = build_formulas(variable.formulas, country_package_metadata, source_file_path, tax_benefit_system) + result["formulas"] = build_formulas( + variable.formulas, + country_package_metadata, + source_file_path, + ) if variable.end: - result['formulas'][get_next_day(variable.end)] = None + result["formulas"][get_next_day(variable.end)] = None if variable.value_type == Enum: - result['possibleValues'] = {item.name: item.value for item in list(variable.possible_values)} + result["possibleValues"] = { + item.name: item.value for item in list(variable.possible_values) + } return result def build_variables(tax_benefit_system, country_package_metadata): return { - name: build_variable(variable, country_package_metadata, tax_benefit_system) + name: build_variable(variable, country_package_metadata) for name, variable in tax_benefit_system.variables.items() - } + } diff --git a/openfisca_web_api/openAPI.yml b/openfisca_web_api/openAPI.yml index d0c52f9a14..ce935e5596 100644 --- a/openfisca_web_api/openAPI.yml +++ b/openfisca_web_api/openAPI.yml @@ -1,374 +1,434 @@ -swagger: "2.0" +openapi: "3.0.0" + info: title: "{COUNTRY_PACKAGE_NAME} Web API" description: "The OpenFisca Web API lets you get up-to-date information and formulas included in the {COUNTRY_PACKAGE_NAME} legislation." - version: null + version: "{COUNTRY_PACKAGE_VERSION}" termsOfService: "https://openfisca.org/doc/licence.html" contact: email: "contact@openfisca.org" license: name: "AGPL" url: "https://www.gnu.org/licenses/agpl-3.0" -host: null -schemes: null + tags: - name: "Parameters" description: "A parameter is a numeric property of the legislation that can evolve over time." externalDocs: description: "Parameters documentation" url: "https://openfisca.org/doc/key-concepts/parameters.html" + - name: "Variables" description: "A variable depends on a person, or an entity (e.g. zip code, salary, income tax)." externalDocs: description: "Variables documentation" url: "https://openfisca.org/doc/key-concepts/variables.html" + - name: "Entities" description: "An entity is a person of a group of individuals (such as a household)." externalDocs: description: "Entities documentation" url: "https://openfisca.org/doc/key-concepts/person,_entities,_role.html" + - name: "Calculations" + - name: "Documentation" + +components: + schemas: + Parameter: + type: "object" + properties: + values: + $ref: "#/components/schemas/Values" + brackets: + type: "object" + additionalProperties: + $ref: "#/components/schemas/Brackets" + subparams: + type: "object" + additionalProperties: + type: "object" + properties: + definition: + type: "string" + metadata: + type: "object" + description: + type: "string" + id: + type: "integer" + format: "string" + source: + type: "string" + + Parameters: + type: "object" + additionalProperties: + type: "object" + properties: + description: + type: "string" + href: + type: "string" + + Variable: + type: "object" + properties: + defaultValue: + type: "string" + definitionPeriod: + type: "string" + enum: + - "MONTH" + - "YEAR" + - "ETERNITY" + description: + type: "string" + entity: + type: "string" + formulas: + type: "object" + additionalProperties: + $ref: "#/components/schemas/Formula" + id: + type: "string" + reference: + type: "array" + items: + type: "string" + source: + type: "string" + valueType: + type: "string" + enum: + - "Int" + - "Float" + - "Boolean" + - "Date" + - "String" + + Variables: + type: "object" + additionalProperties: + type: "object" + properties: + description: + type: "string" + href: + type: "string" + + Formula: + type: "object" + properties: + content: + type: "string" + source: + type: "string" + + Brackets: + type: "object" + additionalProperties: + type: "number" + format: "float" + + Values: + description: "All keys are ISO dates. Values can be numbers, booleans, or arrays of a single type (number, boolean or string)." + type: "object" + additionalProperties: + $ref: "#/components/schemas/Value" + # propertyNames: # this keyword is part of JSON Schema but is not supported in OpenAPI v3.0.0 + # pattern: "^[12][0-9]{3}-[01][0-9]-[0-3][0-9]$" # all keys are ISO dates + + Value: + oneOf: + - type: "boolean" + - type: "number" + format: "float" + - type: "array" + items: + oneOf: + - type: "string" + - type: "number" + + Entities: + type: "object" + properties: + description: + type: "string" + documentation: + type: "string" + plural: + type: "string" + roles: + type: "object" + additionalProperties: + $ref: "#/components/schemas/Roles" + + Roles: + type: "object" + properties: + description: + type: "string" + max: + type: "integer" + plural: + type: "string" + + Trace: + type: "object" + properties: + requestedCalculations: + type: "array" + items: + type: "string" + entitiesDescription: + type: "object" + additionalProperties: false # Will be dynamically added by the Web API + trace: + type: "object" + additionalProperties: + type: "object" + properties: + value: + type: "array" + items: {} + dependencies: + type: "array" + items: + type: "string" + parameters: + type: "object" + additionalProperties: + type: "object" + + headers: + Country-Package: + description: "The name of the country package currently loaded in this API server" + schema: + type: "string" + + Country-Package-Version: + description: "The version of the country package currently loaded in this API server" + schema: + type: "string" + pattern: "^(0|[1-9][0-9]*)\\.(0|[1-9][0-9]*)\\.(0|[1-9][0-9]*)(?:-((?:0|[1-9][0-9]*|[0-9]*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9][0-9]*|[0-9]*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$" # adapted from https://semver.org/#is-there-a-suggested-regular-expression-regex-to-check-a-semver-string + paths: /calculate: post: summary: "Run a simulation" tags: - - Calculations + - "Calculations" operationId: "calculate" - consumes: - - "application/json" - produces: - - "application/json" - parameters: - - in: "body" - name: "Situation" + requestBody: description: "Describe the situation (persons and entities). Add the variable you wish to calculate in the proper entity, with null as the value. Learn more in our official documentation: https://openfisca.org/doc/openfisca-web-api/input-output-data.html" required: true - schema: - $ref: "#/definitions/SituationInput" + content: + application/json: + schema: + $ref: "#/components/schemas/SituationInput" responses: 200: description: "The calculation result is sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/SituationOutput" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/SituationOutput" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" 404: description: "A variable mentioned in the input situation does not exist in the loaded tax and benefit system. Details are sent back in the response body" headers: - $ref: "#/commons/Headers" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" 400: description: "The request is invalid. Details about the error are sent back in the response body" headers: - $ref: "#/commons/Headers" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /parameters: get: tags: - "Parameters" summary: "List all available parameters" operationId: "getParameters" - produces: - - "application/json" responses: 200: description: "The list of parameters is sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/Parameters" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/Parameters" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /parameter/{parameterID}: get: tags: - "Parameters" summary: "Get information about a specific parameter" operationId: "getParameter" - produces: - - "application/json" parameters: - name: "parameterID" in: "path" description: "ID of parameter. IDs can be obtained by enumerating the /parameters endpoint" required: true - type: "string" + schema: + type: "string" responses: 200: description: "The requested parameter's information is sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/Parameter" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/Parameter" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" 404: description: "The requested parameter does not exist" headers: - $ref: "#/commons/Headers" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /variables: get: tags: - "Variables" summary: "List all available variables" operationId: "getVariables" - produces: - - "application/json" responses: 200: description: "The list of variables is sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/Variables" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/Variables" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /variable/{variableID}: get: tags: - "Variables" summary: "Get information about a specific variable" operationId: "getVariable" - produces: - - "application/json" parameters: - name: "variableID" in: "path" description: "ID of a variable. IDs can be obtained by enumerating the /variables endpoint." required: true - type: "string" + schema: + type: "string" responses: 200: description: "The requested variable's information is sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/Variable" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/Variable" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" 404: description: "The requested variable does not exist" headers: - $ref: "#/commons/Headers" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /entities: get: tags: - "Entities" summary: "List all available Entities" - operationId: "getVariables" - produces: - - "application/json" + operationId: "getEntities" responses: 200: description: "The list of the entities as well as their information is sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/Entities" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/Entities" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /trace: post: summary: "Explore a simulation's steps in details." tags: - - Calculations + - "Calculations" operationId: "trace" - consumes: - - "application/json" - produces: - - "application/json" - parameters: - - in: "body" - name: "Situation" + requestBody: description: "Describe the situation (persons and entities). Add the variable you wish to calculate in the proper entity, with null as the value." required: true - schema: - $ref: "#/definitions/SituationInput" + content: + application/json: + schema: + $ref: "#/components/schemas/SituationInput" responses: 200: description: "The calculation details are sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/Trace" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/Trace" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" 404: description: "A variable mentioned in the input situation does not exist in the loaded tax and benefit system. Details are sent back in the response body" headers: - $ref: "#/commons/Headers" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" 400: description: "The request is invalid. Details about the error are sent back in the response body" headers: - $ref: "#/commons/Headers" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /spec: get: - summary: Provide the API documentation in an OpenAPI format + summary: "Provide the API documentation in an OpenAPI format" tags: - - Documentation - operationId: spec - produces: - - application/json + - "Documentation" + operationId: "spec" responses: 200: - description: The API documentation is sent back in the response body + description: "The API documentation is sent back in the response body" headers: - $ref: "#/commons/Headers" - -definitions: - Parameter: - type: "object" - properties: - values: - $ref: "#/definitions/Values" - brackets: - type: "object" - additionalProperties: - $ref: "#/definitions/Brackets" - subparams: - type: "object" - additionalProperties: - type: "object" - properties: - definition: - type: "string" - metadata: - type: "object" - description: - type: "string" - id: - type: "integer" - format: "string" - source: - type: "string" - example: null - - Parameters: - type: "object" - additionalProperties: - type: "object" - properties: - description: - type: "string" - href: - type: "string" - - Variable: - type: "object" - properties: - defaultValue: - type: "string" - definitionPeriod: - type: string - enum: - - MONTH - - YEAR - - ETERNITY - description: - type: "string" - entity: - type: "string" - formulas: - type: "object" - additionalProperties: - $ref: "#/definitions/Formula" - id: - type: "string" - reference: - type: "array" - items: - type: "string" - source: - type: "string" - valueType: - type: "string" - enum: - - Int - - Float - - Boolean - - Date - - String - example: null - - Variables: - type: "object" - additionalProperties: - type: "object" - properties: - description: - type: "string" - href: - type: "string" - - Formula: - type: "object" - properties: - content: - type: "string" - source: - type: "string" - - Brackets: - type: "object" - additionalProperties: - type: "number" - format: "float" - - Values: - description: All keys are ISO dates. Values can be numbers, booleans, or arrays of a single type (number, boolean or string). - type: "object" - additionalProperties: true -# propertyNames: # this keyword is part of JSON Schema but is not supported in OpenAPI Specification at the time of writing, see https://swagger.io/docs/specification/data-models/keywords/#unsupported -# pattern: "^[12][0-9]{3}-[01][0-9]-[0-3][0-9]$" # all keys are ISO dates - - Entities: - type: "object" - properties: - description: - type: "string" - documentation: - type: "string" - plural: - type: "string" - roles: - type: "object" - additionalProperties: - $ref: "#/definitions/Roles" - Roles: - type: "object" - properties: - description: - type: "string" - max: - type: "integer" - plural: - type: "string" - SituationInput: null - SituationOutput: null - - Trace: - type: object - properties: - requestedCalculations: - type: array - items: - type: string - entitiesDescription: - type: object - properties: null # Will be dynamically added by the Web API - trace: - type: object - additionalProperties: - type: object - properties: - value: - type: array - items: - type: any - dependencies: - type: array - items: - type: string - parameters: - type: object - additionalProperties: - type: object - - example: null - -commons: - Headers: - Country-Package: - description: "The name of the country package currently loaded in this API server" - type: "string" - Country-Package-Version: - description: "The version of the country package currently loaded in this API server" - type: "string" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" diff --git a/openfisca_web_api/scripts/serve.py b/openfisca_web_api/scripts/serve.py index 428cf2b965..6ba89f440a 100644 --- a/openfisca_web_api/scripts/serve.py +++ b/openfisca_web_api/scripts/serve.py @@ -1,15 +1,13 @@ -# -*- coding: utf-8 -*- - -import sys import logging +import sys from openfisca_core.scripts import build_tax_benefit_system from openfisca_web_api.app import create_app from openfisca_web_api.errors import handle_import_error try: - from gunicorn.app.base import BaseApplication from gunicorn import config + from gunicorn.app.base import BaseApplication except ImportError as error: handle_import_error(error) @@ -18,10 +16,10 @@ Define the `openfisca serve` command line interface. """ -DEFAULT_PORT = '5000' -HOST = '127.0.0.1' -DEFAULT_WORKERS_NUMBER = '3' -DEFAULT_TIMEOUT = 120 +DEFAULT_PORT = "5000" +HOST = "127.0.0.1" +DEFAULT_WORKERS_NUMBER = "3" +DEFAULT_TIMEOUT = 1200 log = logging.getLogger(__name__) @@ -33,7 +31,7 @@ def read_user_configuration(default_configuration, command_line_parser): if args.configuration_file: file_configuration = {} - with open(args.configuration_file, "r") as file: + with open(args.configuration_file) as file: exec(file.read(), {}, file_configuration) # Configuration file overloads default configuration @@ -42,10 +40,13 @@ def read_user_configuration(default_configuration, command_line_parser): # Command line configuration overloads all configuration gunicorn_parser = config.Config().parser() configuration = update(configuration, vars(args)) - configuration = update(configuration, vars(gunicorn_parser.parse_args(unknown_args))) - if configuration['args']: + configuration = update( + configuration, + vars(gunicorn_parser.parse_args(unknown_args)), + ) + if configuration["args"]: command_line_parser.print_help() - log.error('Unexpected positional argument {}'.format(configuration['args'])) + log.error("Unexpected positional argument {}".format(configuration["args"])) sys.exit(1) return configuration @@ -56,42 +57,43 @@ def update(configuration, new_options): if value is not None: configuration[key] = value if key == "port": - configuration['bind'] = configuration['bind'][:-4] + str(configuration['port']) + configuration["bind"] = configuration["bind"][:-4] + str( + configuration["port"], + ) return configuration class OpenFiscaWebAPIApplication(BaseApplication): - - def __init__(self, options): + def __init__(self, options) -> None: self.options = options - super(OpenFiscaWebAPIApplication, self).__init__() + super().__init__() - def load_config(self): + def load_config(self) -> None: for key, value in self.options.items(): if key in self.cfg.settings: self.cfg.set(key.lower(), value) def load(self): tax_benefit_system = build_tax_benefit_system( - self.options.get('country_package'), - self.options.get('extensions'), - self.options.get('reforms') - ) + self.options.get("country_package"), + self.options.get("extensions"), + self.options.get("reforms"), + ) return create_app( tax_benefit_system, - self.options.get('tracker_url'), - self.options.get('tracker_idsite'), - self.options.get('tracker_token'), - self.options.get('welcome_message') - ) + self.options.get("tracker_url"), + self.options.get("tracker_idsite"), + self.options.get("tracker_token"), + self.options.get("welcome_message"), + ) -def main(parser): +def main(parser) -> None: configuration = { - 'port': DEFAULT_PORT, - 'bind': '{}:{}'.format(HOST, DEFAULT_PORT), - 'workers': DEFAULT_WORKERS_NUMBER, - 'timeout': DEFAULT_TIMEOUT, - } + "port": DEFAULT_PORT, + "bind": f"{HOST}:{DEFAULT_PORT}", + "workers": DEFAULT_WORKERS_NUMBER, + "timeout": DEFAULT_TIMEOUT, + } configuration = read_user_configuration(configuration, parser) OpenFiscaWebAPIApplication(configuration).run() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..1f99960e90 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,14 @@ +[tool.black] +target-version = [ "py39", "py310", "py311", "py312" ] + +[tool.codespell] +ignore-words-list = [ + "THIRDPARTY", + "ans", + "constitue", + "exemple", + "fonction", + "impot", + "treshold", +] +skip = "./venv" diff --git a/setup.cfg b/setup.cfg index bb3ff50fc5..e6b37ba7eb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,55 +1,100 @@ -; C011X: We (progressively) document the code base. -; D10X: We (progressively) check docstrings (see https://www.pydocstyle.org/en/2.1.1/error_codes.html#grouping). -; DARXXX: We (progressively) check docstrings (see https://github.com/terrencepreilly/darglint#error-codes). -; E128/133: We prefer hang-closing visual indents. -; E251: We prefer `function(x = 1)` over `function(x=1)`. -; E501: We do not enforce a maximum line length. -; F403/405: We ignore * imports. -; R0401: We avoid cyclic imports —required for unit/doc tests. -; RST301: We use Google Python Style (see https://pypi.org/project/flake8-rst-docstrings/) -; W503/504: We break lines before binary operators (Knuth's style). +# C011X: We (progressively) document the code base. +# D10X: We (progressively) check docstrings (see https://www.pydocstyle.org/en/2.1.1/error_codes.html#grouping). +# DARXXX: We (progressively) check docstrings (see https://github.com/terrencepreilly/darglint#error-codes). +# E203: We ignore a false positive in whitespace before ":" (see https://github.com/PyCQA/pycodestyle/issues/373). +# F403/405: We ignore * imports. +# R0401: We avoid cyclic imports —required for unit/doc tests. +# RST301: We use Google Python Style (see https://pypi.org/project/flake8-rst-docstrings/). +# W503/504: We break lines before binary operators (Knuth's style). [flake8] -extend-ignore = D -hang-closing = true -ignore = E128,E251,F403,F405,E501,RST301,W503,W504 -in-place = true -include-in-doctest = openfisca_core/commons openfisca_core/types -rst-directives = attribute, deprecated, seealso, versionadded, versionchanged -rst-roles = any, attr, class, exc, func, meth, obj -strictness = short +convention = google +docstring_style = google +extend-ignore = D +ignore = + B019 + E203 + E501 + F405 + E701 + E704 + RST210 + RST212 + RST213 + RST301 + RST306 + W503 +in-place = true +include-in-doctest = + openfisca_core/commons + openfisca_core/data_storage + openfisca_core/entities + openfisca_core/experimental + openfisca_core/holders + openfisca_core/indexed_enums + openfisca_core/periods + openfisca_core/projectors +max-line-length = 88 +per-file-ignores = + */types.py:D101,D102,E301,E704,W504 + */test_*.py:D101,D102,D103 + */__init__.py:F401 + */__init__.pyi:E302,E704 +rst-directives = attribute, deprecated, seealso, versionadded, versionchanged +rst-roles = any, attr, class, exc, func, meth, mod, obj +strictness = short -[pylint.message_control] -disable = all -enable = C0115,C0116,R0401 -score = no - -[coverage:paths] -source = . */site-packages +[pylint.MASTER] +load-plugins = pylint_per_file_ignores -[coverage:run] -branch = true -source = openfisca_core, openfisca_web_api +[pylint.message_control] +disable = all +enable = C0115, C0116, R0401 +per-file-ignores = + types.py:C0115,C0116 + /tests/:C0115,C0116 +score = no -[coverage:report] -fail_under = 75 -show_missing = true -skip_covered = true -skip_empty = true +[isort] +case_sensitive = true +combine_as_imports = true +force_alphabetical_sort_within_sections = false +group_by_package = true +honor_noqa = true +include_trailing_comma = true +known_first_party = openfisca_core +known_openfisca = openfisca_country_template, openfisca_extension_template +known_typing = *collections.abc*, *typing*, *typing_extensions* +known_types = *types* +multi_line_output = 3 +profile = black +py_version = 39 +sections = FUTURE, TYPING, TYPES, STDLIB, THIRDPARTY, OPENFISCA, FIRSTPARTY, LOCALFOLDER [tool:pytest] -addopts = --doctest-modules --disable-pytest-warnings --showlocals -doctest_optionflags = ELLIPSIS IGNORE_EXCEPTION_DETAIL NUMBER NORMALIZE_WHITESPACE -python_files = **/*.py -testpaths = openfisca_core/commons openfisca_core/types tests +addopts = --disable-pytest-warnings --doctest-modules --showlocals +doctest_optionflags = ELLIPSIS IGNORE_EXCEPTION_DETAIL NUMBER NORMALIZE_WHITESPACE +python_files = **/*.py +testpaths = tests [mypy] -ignore_missing_imports = True -install_types = True -non_interactive = True - -[mypy-openfisca_core.commons.tests.*] -ignore_errors = True +check_untyped_defs = false +disallow_any_decorated = false +disallow_any_explicit = false +disallow_any_expr = false +disallow_any_unimported = false +follow_imports = skip +ignore_missing_imports = true +implicit_reexport = false +install_types = true +mypy_path = stubs +non_interactive = true +plugins = numpy.typing.mypy_plugin +pretty = true +python_version = 3.9 +strict = false +warn_no_return = true +warn_unreachable = true -[mypy-openfisca_core.scripts.*] -ignore_errors = True +[mypy-openfisca_core.*.tests.*] +ignore_errors = True diff --git a/setup.py b/setup.py index 5975877b72..40c16dbfff 100644 --- a/setup.py +++ b/setup.py @@ -1,70 +1,118 @@ -#! /usr/bin/env python +"""Package config file. -from __future__ import annotations +This file contains all package's metadata, including the current version and +its third-party dependencies. -from typing import List +Note: + For integration testing, OpenFisca-Core relies on two other packages, + listed below. Because these packages rely at the same time on + OpenFisca-Core, adding them as official dependencies creates a resolution + loop that makes it hard to contribute. We've therefore decided to install + them via the task manager (`make install-test`):: -import re -from pathlib import Path -from setuptools import setup, find_packages + openfisca-country-template = "*" + openfisca-extension-template = "*" +""" -def load_requirements_from_file(filename: str) -> List[str]: - """Allows for composable requirement files with the `-r filename` flag.""" +from pathlib import Path - file = Path(f"./requirements/{filename}").resolve() - reqs = open(file).readlines() - pattern = re.compile(r"^\s*-r\s*(?P.*)$") +from setuptools import find_packages, setup - for req in reqs: - match = pattern.match(req) +# Read the contents of our README file for PyPi +this_directory = Path(__file__).parent +long_description = (this_directory / "README.md").read_text() - if match: - reqs.remove(req) - reqs.extend(load_requirements_from_file(match.group("filename"))) +# Please make sure to cap all dependency versions, in order to avoid unwanted +# functional and integration breaks caused by external code updates. +# DO NOT add space between '>=' and version number as it break conda build. +general_requirements = [ + "PyYAML >=6.0, <7.0", + "StrEnum >=0.4.8, <0.5.0", # 3.11.x backport + "dpath >=2.1.4, <3.0", + "numexpr >=2.10.1, <3.0", + "numpy >=1.24.2, <2.0", + "pendulum >=3.0.0, <4.0.0", + "psutil >=5.9.4, <6.0", + "pytest >=8.3.3, <9.0", + "sortedcontainers >=2.4.0, <3.0", + "typing_extensions >=4.5.0, <5.0", +] - return reqs +api_requirements = [ + "Flask >=2.2.3, <3.0", + "Flask-Cors >=3.0.10, <4.0", + "gunicorn >=21.0, <22.0", + "Werkzeug >=2.2.3, <3.0", +] +dev_requirements = [ + "black >=24.8.0, <25.0", + "codespell >=2.3.0, <3.0", + "colorama >=0.4.4, <0.5", + "darglint >=1.8.1, <2.0", + "flake8 >=7.1.1, <8.0.0", + "flake8-bugbear >=24.8.19, <25.0", + "flake8-docstrings >=1.7.0, <2.0", + "flake8-print >=5.0.0, <6.0", + "flake8-rst-docstrings >=0.3.0, <0.4.0", + "idna >=3.10, <4.0", + "isort >=5.13.2, <6.0", + "mypy >=1.11.2, <2.0", + "openapi-spec-validator >=0.7.1, <0.8.0", + "pylint >=3.3.1, <4.0", + "pylint-per-file-ignores >=1.3.2, <2.0", + "pyright >=1.1.382, <2.0", + "ruff >=0.6.9, <1.0", + "ruff-lsp >=0.0.57, <1.0", + *api_requirements, +] setup( - name = 'OpenFisca-Core', - version = '36.0.0', - author = 'OpenFisca Team', - author_email = 'contact@openfisca.org', - classifiers = [ - 'Development Status :: 5 - Production/Stable', - 'License :: OSI Approved :: GNU Affero General Public License v3', - 'Operating System :: POSIX', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3.7', - 'Topic :: Scientific/Engineering :: Information Analysis', - ], - description = 'A versatile microsimulation free software', - keywords = 'benefit microsimulation social tax', - license = 'https://www.fsf.org/licensing/licenses/agpl-3.0.html', - url = 'https://github.com/openfisca/openfisca-core', - data_files = [ + name="OpenFisca-Core", + version="43.2.2", + author="OpenFisca Team", + author_email="contact@openfisca.org", + classifiers=[ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: GNU Affero General Public License v3", + "Operating System :: POSIX", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering :: Information Analysis", + ], + description="A versatile microsimulation free software", + keywords="benefit microsimulation social tax", + license="https://www.fsf.org/licensing/licenses/agpl-3.0.html", + license_files=("LICENSE",), + url="https://github.com/openfisca/openfisca-core", + long_description=long_description, + long_description_content_type="text/markdown", + data_files=[ ( - 'share/openfisca/openfisca-core', - ['CHANGELOG.md', 'LICENSE', 'README.md'], - ), + "share/openfisca/openfisca-core", + ["CHANGELOG.md", "README.md"], + ), + ], + entry_points={ + "console_scripts": [ + "openfisca=openfisca_core.scripts.openfisca_command:main", + "openfisca-run-test=openfisca_core.scripts.openfisca_command:main", + ], + }, + extras_require={ + "web-api": api_requirements, + "dev": dev_requirements, + "ci": [ + "build >=0.10.0, <0.11.0", + "twine >=5.1.1, <6.0", + "wheel >=0.40.0, <0.41.0", ], - entry_points = { - 'console_scripts': [ - 'openfisca=openfisca_core.scripts.openfisca_command:main', - 'openfisca-run-test=openfisca_core.scripts.openfisca_command:main', - ], - }, - python_requires = ">= 3.7", - install_requires = load_requirements_from_file("install"), - extras_require = { - "common": load_requirements_from_file("common"), - "coverage": load_requirements_from_file("coverage"), - "dev": load_requirements_from_file("dev"), - "publication": load_requirements_from_file("publication"), - "tracker": load_requirements_from_file("tracker"), - "web-api": load_requirements_from_file("web-api"), - }, - include_package_data = True, # Will read MANIFEST.in - packages = find_packages(exclude=['tests*']), - ) + "tracker": ["OpenFisca-Tracker >=0.4.0, <0.5.0"], + }, + include_package_data=True, # Will read MANIFEST.in + install_requires=general_requirements, + packages=find_packages(exclude=["tests*"]), +) diff --git a/stubs/numexpr/__init__.pyi b/stubs/numexpr/__init__.pyi new file mode 100644 index 0000000000..931d47ddb1 --- /dev/null +++ b/stubs/numexpr/__init__.pyi @@ -0,0 +1,10 @@ +from numpy.typing import NDArray + +import numpy + +def evaluate( + __ex: str, + /, + *__args: object, + **__kwargs: object, +) -> NDArray[numpy.bool_] | NDArray[numpy.int32] | NDArray[numpy.float32]: ... diff --git a/tests/core/parameter_validation/test_parameter_clone.py b/tests/core/parameter_validation/test_parameter_clone.py index a14630e9a0..6c77b4bb0b 100644 --- a/tests/core/parameter_validation/test_parameter_clone.py +++ b/tests/core/parameter_validation/test_parameter_clone.py @@ -6,21 +6,20 @@ year = 2016 -def test_clone(): - path = os.path.join(BASE_DIR, 'filesystem_hierarchy') - parameters = ParameterNode('', directory_path = path) - parameters_at_instant = parameters('2016-01-01') +def test_clone() -> None: + path = os.path.join(BASE_DIR, "filesystem_hierarchy") + parameters = ParameterNode("", directory_path=path) + parameters_at_instant = parameters("2016-01-01") assert parameters_at_instant.node1.param == 1.0 clone = parameters.clone() - clone_at_instant = clone('2016-01-01') + clone_at_instant = clone("2016-01-01") assert clone_at_instant.node1.param == 1.0 assert id(clone) != id(parameters) assert id(clone.node1) != id(parameters.node1) assert id(clone.node1.param) != id(parameters.node1.param) -def test_clone_parameter(tax_benefit_system): - +def test_clone_parameter(tax_benefit_system) -> None: param = tax_benefit_system.parameters.taxes.income_tax_rate clone = param.clone() @@ -31,16 +30,16 @@ def test_clone_parameter(tax_benefit_system): assert clone.values_list == param.values_list -def test_clone_parameter_node(tax_benefit_system): +def test_clone_parameter_node(tax_benefit_system) -> None: node = tax_benefit_system.parameters.taxes clone = node.clone() assert clone is not node assert clone.income_tax_rate is not node.income_tax_rate - assert clone.children['income_tax_rate'] is not node.children['income_tax_rate'] + assert clone.children["income_tax_rate"] is not node.children["income_tax_rate"] -def test_clone_scale(tax_benefit_system): +def test_clone_scale(tax_benefit_system) -> None: scale = tax_benefit_system.parameters.taxes.social_security_contribution clone = scale.clone() @@ -48,7 +47,7 @@ def test_clone_scale(tax_benefit_system): assert clone.brackets[0].rate is not scale.brackets[0].rate -def test_deep_edit(tax_benefit_system): +def test_deep_edit(tax_benefit_system) -> None: parameters = tax_benefit_system.parameters clone = parameters.clone() diff --git a/tests/core/parameter_validation/test_parameter_validation.py b/tests/core/parameter_validation/test_parameter_validation.py index 561fb28cb1..d3419312d2 100644 --- a/tests/core/parameter_validation/test_parameter_validation.py +++ b/tests/core/parameter_validation/test_parameter_validation.py @@ -1,18 +1,19 @@ -# -*- coding: utf-8 -*- - import os import pytest -from openfisca_core.errors import ParameterParsingError -from openfisca_core.parameters import load_parameter_file, ParameterNode +from openfisca_core.parameters import ( + ParameterNode, + ParameterParsingError, + load_parameter_file, +) BASE_DIR = os.path.dirname(os.path.abspath(__file__)) year = 2016 -def check_fails_with_message(file_name, keywords): - path = os.path.join(BASE_DIR, file_name) + '.yaml' +def check_fails_with_message(file_name, keywords) -> None: + path = os.path.join(BASE_DIR, file_name) + ".yaml" try: load_parameter_file(path, file_name) except ParameterParsingError as e: @@ -22,42 +23,65 @@ def check_fails_with_message(file_name, keywords): raise -@pytest.mark.parametrize("test", [ - ('indentation', {'Invalid YAML', 'indentation.yaml', 'line 2', 'mapping values are not allowed'}), - ("wrong_date", {"Error parsing parameter file", "Properties must be valid YYYY-MM-DD instants"}), - ('wrong_scale', {'Unexpected property', 'scale[1]', 'treshold'}), - ('wrong_value', {'not one of the allowed types', 'wrong_value[2015-12-01]', '1A'}), - ('unexpected_key_in_parameter', {'Unexpected property', 'unexpected_key'}), - ('wrong_type_in_parameter', {'must be of type object'}), - ('wrong_type_in_value_history', {'must be of type object'}), - ('unexpected_key_in_value_history', {'must be valid YYYY-MM-DD instants'}), - ('unexpected_key_in_value_at_instant', {'Unexpected property', 'unexpected_key'}), - ('unexpected_key_in_scale', {'Unexpected property', 'unexpected_key'}), - ('wrong_type_in_scale', {'must be of type object'}), - ('wrong_type_in_brackets', {'must be of type array'}), - ('wrong_type_in_bracket', {'must be of type object'}), - ('missing_value', {'missing', 'value'}), - ('duplicate_key', {'duplicate'}), - ]) -def test_parsing_errors(test): +@pytest.mark.parametrize( + "test", + [ + ( + "indentation", + { + "Invalid YAML", + "indentation.yaml", + "line 2", + "mapping values are not allowed", + }, + ), + ( + "wrong_date", + { + "Error parsing parameter file", + "Properties must be valid YYYY-MM-DD instants", + }, + ), + ("wrong_scale", {"Unexpected property", "scale[1]", "treshold"}), + ( + "wrong_value", + {"not one of the allowed types", "wrong_value[2015-12-01]", "1A"}, + ), + ("unexpected_key_in_parameter", {"Unexpected property", "unexpected_key"}), + ("wrong_type_in_parameter", {"must be of type object"}), + ("wrong_type_in_value_history", {"must be of type object"}), + ("unexpected_key_in_value_history", {"must be valid YYYY-MM-DD instants"}), + ( + "unexpected_key_in_value_at_instant", + {"Unexpected property", "unexpected_key"}, + ), + ("unexpected_key_in_scale", {"Unexpected property", "unexpected_key"}), + ("wrong_type_in_scale", {"must be of type object"}), + ("wrong_type_in_brackets", {"must be of type array"}), + ("wrong_type_in_bracket", {"must be of type object"}), + ("missing_value", {"missing", "value"}), + ("duplicate_key", {"duplicate"}), + ], +) +def test_parsing_errors(test) -> None: with pytest.raises(ParameterParsingError): check_fails_with_message(*test) -def test_array_type(): - path = os.path.join(BASE_DIR, 'array_type.yaml') - load_parameter_file(path, 'array_type') +def test_array_type() -> None: + path = os.path.join(BASE_DIR, "array_type.yaml") + load_parameter_file(path, "array_type") -def test_filesystem_hierarchy(): - path = os.path.join(BASE_DIR, 'filesystem_hierarchy') - parameters = ParameterNode('', directory_path = path) - parameters_at_instant = parameters('2016-01-01') +def test_filesystem_hierarchy() -> None: + path = os.path.join(BASE_DIR, "filesystem_hierarchy") + parameters = ParameterNode("", directory_path=path) + parameters_at_instant = parameters("2016-01-01") assert parameters_at_instant.node1.param == 1.0 -def test_yaml_hierarchy(): - path = os.path.join(BASE_DIR, 'yaml_hierarchy') - parameters = ParameterNode('', directory_path = path) - parameters_at_instant = parameters('2016-01-01') +def test_yaml_hierarchy() -> None: + path = os.path.join(BASE_DIR, "yaml_hierarchy") + parameters = ParameterNode("", directory_path=path) + parameters_at_instant = parameters("2016-01-01") assert parameters_at_instant.node1.param == 1.0 diff --git a/tests/core/parameters_date_indexing/__init__.py b/tests/core/parameters_date_indexing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/parameters_date_indexing/full_rate_age.yaml b/tests/core/parameters_date_indexing/full_rate_age.yaml new file mode 100644 index 0000000000..fa9377fec5 --- /dev/null +++ b/tests/core/parameters_date_indexing/full_rate_age.yaml @@ -0,0 +1,121 @@ +description: Full rate age +full_rate_age_by_birthdate: + description: Full rate age by birthdate + before_1951_07_01: + description: Born before 01/07/1951 + year: + description: Year + values: + 1983-04-01: + value: 65.0 + month: + description: Month + values: + 1983-04-01: + value: 0.0 + after_1951_07_01: + description: Born after 01/07/1951 + year: + description: Year + values: + 2011-07-01: + value: 65.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2011-07-01: + value: 4.0 + 1983-04-01: + value: null + after_1952_01_01: + description: Born after 01/01/1952 + year: + description: Year + values: + 2011-07-01: + value: 65.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2012-01-01: + value: 9.0 + 2011-07-01: + value: 8.0 + 1983-04-01: + value: null + after_1953_01_01: + description: Born after 01/01/1953 + year: + description: Year + values: + 2011-07-01: + value: 66.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2012-01-01: + value: 2.0 + 2011-07-01: + value: 0.0 + 1983-04-01: + value: null + after_1954_01_01: + description: Born after 01/01/1954 + year: + description: Year + values: + 2011-07-01: + value: 66.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2012-01-01: + value: 7.0 + 2011-07-01: + value: 4.0 + 1983-04-01: + value: null + after_1955_01_01: + description: Born after 01/01/1955 + year: + description: Year + values: + 2012-01-01: + value: 67.0 + 2011-07-01: + value: 66.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2012-01-01: + value: 0.0 + 2011-07-01: + value: 8.0 + 1983-04-01: + value: null + after_1956_01_01: + description: Born after 01/01/1956 + year: + description: Year + values: + 2011-07-01: + value: 67.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2011-07-01: + value: 0.0 + 1983-04-01: + value: null diff --git a/tests/core/parameters_date_indexing/full_rate_required_duration.yml b/tests/core/parameters_date_indexing/full_rate_required_duration.yml new file mode 100644 index 0000000000..af394ec568 --- /dev/null +++ b/tests/core/parameters_date_indexing/full_rate_required_duration.yml @@ -0,0 +1,162 @@ +description: Required contribution duration for full rate +contribution_quarters_required_by_birthdate: + description: Contribution quarters required by birthdate + before_1934_01_01: + description: before 1934 + values: + 1983-01-01: + value: 150.0 + after_1934_01_01: + description: '1934-01-01' + values: + 1994-01-01: + value: 151.0 + 1983-01-01: + value: null + after_1935_01_01: + description: '1935-01-01' + values: + 1994-01-01: + value: 152.0 + 1983-01-01: + value: null + after_1936_01_01: + description: '1936-01-01' + values: + 1994-01-01: + value: 153.0 + 1983-01-01: + value: null + after_1937_01_01: + description: '1937-01-01' + values: + 1994-01-01: + value: 154.0 + 1983-01-01: + value: null + after_1938_01_01: + description: '1938-01-01' + values: + 1994-01-01: + value: 155.0 + 1983-01-01: + value: null + after_1939_01_01: + description: '1939-01-01' + values: + 1994-01-01: + value: 156.0 + 1983-01-01: + value: null + after_1940_01_01: + description: '1940-01-01' + values: + 1994-01-01: + value: 157.0 + 1983-01-01: + value: null + after_1941_01_01: + description: '1941-01-01' + values: + 1994-01-01: + value: 158.0 + 1983-01-01: + value: null + after_1942_01_01: + description: '1942-01-01' + values: + 1994-01-01: + value: 159.0 + 1983-01-01: + value: null + after_1943_01_01: + description: '1943-01-01' + values: + 1994-01-01: + value: 160.0 + 1983-01-01: + value: null + after_1949_01_01: + description: '1949-01-01' + values: + 2009-01-01: + value: 161.0 + 1983-01-01: + value: null + after_1950_01_01: + description: '1950-01-01' + values: + 2009-01-01: + value: 162.0 + 1983-01-01: + value: null + after_1951_01_01: + description: '1951-01-01' + values: + 2009-01-01: + value: 163.0 + 1983-01-01: + value: null + after_1952_01_01: + description: '1952-01-01' + values: + 2009-01-01: + value: 164.0 + 1983-01-01: + value: null + after_1953_01_01: + description: '1953-01-01' + values: + 2012-01-01: + value: 165.0 + 1983-01-01: + value: null + after_1955_01_01: + description: '1955-01-01' + values: + 2013-01-01: + value: 166.0 + 1983-01-01: + value: null + after_1958_01_01: + description: '1958-01-01' + values: + 2015-01-01: + value: 167.0 + 1983-01-01: + value: null + after_1961_01_01: + description: '1961-01-01' + values: + 2015-01-01: + value: 168.0 + 1983-01-01: + value: null + after_1964_01_01: + description: '1964-01-01' + values: + 2015-01-01: + value: 169.0 + 1983-01-01: + value: null + after_1967_01_01: + description: '1967-01-01' + values: + 2015-01-01: + value: 170.0 + 1983-01-01: + value: null + after_1970_01_01: + description: '1970-01-01' + values: + 2015-01-01: + value: 171.0 + 1983-01-01: + value: null + after_1973_01_01: + description: '1973-01-01' + values: + 2015-01-01: + value: 172.0 + 1983-01-01: + value: null diff --git a/tests/core/parameters_date_indexing/test_date_indexing.py b/tests/core/parameters_date_indexing/test_date_indexing.py new file mode 100644 index 0000000000..cefec26648 --- /dev/null +++ b/tests/core/parameters_date_indexing/test_date_indexing.py @@ -0,0 +1,48 @@ +import os + +import numpy + +from openfisca_core.parameters import ParameterNode +from openfisca_core.tools import assert_near + +from openfisca_core.model_api import * # noqa + +LOCAL_DIR = os.path.dirname(os.path.abspath(__file__)) + +parameters = ParameterNode(directory_path=LOCAL_DIR) + + +def get_message(error): + return error.args[0] + + +def test_on_leaf() -> None: + parameter_at_instant = parameters.full_rate_required_duration("1995-01-01") + birthdate = numpy.array( + ["1930-01-01", "1935-01-01", "1940-01-01", "1945-01-01"], + dtype="datetime64[D]", + ) + assert_near( + parameter_at_instant.contribution_quarters_required_by_birthdate[birthdate], + [150, 152, 157, 160], + ) + + +def test_on_node() -> None: + birthdate = numpy.array( + ["1950-01-01", "1953-01-01", "1956-01-01", "1959-01-01"], + dtype="datetime64[D]", + ) + parameter_at_instant = parameters.full_rate_age("2012-03-01") + node = parameter_at_instant.full_rate_age_by_birthdate[birthdate] + assert_near(node.year, [65, 66, 67, 67]) + assert_near(node.month, [0, 2, 0, 0]) + + +# def test_inhomogenous(): +# birthdate = numpy.array(['1930-01-01', '1935-01-01', '1940-01-01', '1945-01-01'], dtype = 'datetime64[D]') +# parameter_at_instant = parameters..full_rate_age('2011-01-01') +# parameter_at_instant.full_rate_age_by_birthdate[birthdate] +# with pytest.raises(ValueError) as error: +# parameter_at_instant.full_rate_age_by_birthdate[birthdate] +# assert "Cannot use fancy indexing on parameter node '.full_rate_age.full_rate_age_by_birthdate'" in get_message(error.value) diff --git a/tests/core/parameters_fancy_indexing/coefficient_de_minoration.yaml b/tests/core/parameters_fancy_indexing/coefficient_de_minoration.yaml new file mode 100644 index 0000000000..9894ae64aa --- /dev/null +++ b/tests/core/parameters_fancy_indexing/coefficient_de_minoration.yaml @@ -0,0 +1,135 @@ +description: Coefficient de minoration ARRCO +coefficient_minoration_en_fonction_distance_age_annulation_decote_en_annee: + description: Coefficient de minoration à l'Arrco en fonction de la distance à l'âge d'annulation de la décote (en année) + '-10': + description: '-10' + values: + 1965-01-01: + value: 0.43 + 1957-05-15: + value: null + '-9': + description: '-9' + values: + 1965-01-01: + value: 0.5 + 1957-05-15: + value: null + '-8': + description: '-8' + values: + 1965-01-01: + value: 0.57 + 1957-05-15: + value: null + '-7': + description: '-7' + values: + 1965-01-01: + value: 0.64 + 1957-05-15: + value: null + '-6': + description: '-6' + values: + 1965-01-01: + value: 0.71 + 1957-05-15: + value: null + '-5': + description: '-5' + values: + 1965-01-01: + value: 0.78 + 1957-05-15: + value: 0.75 + '-4': + description: '-4' + values: + 1965-01-01: + value: 0.83 + 1957-05-15: + value: 0.8 + '-3': + description: '-3' + values: + 1965-01-01: + value: 0.88 + 1957-05-15: + value: 0.85 + '-2': + description: '-2' + values: + 1965-01-01: + value: 0.92 + 1957-05-15: + value: 0.9 + '-1': + description: '-1' + values: + 1965-01-01: + value: 0.96 + 1957-05-15: + value: 0.95 + '0': + description: '0' + values: + 1965-01-01: + value: 1.0 + 1957-05-15: + value: 1.05 + '1': + description: '1' + values: + 1965-01-01: + value: null + 1957-05-15: + value: 1.1 + '2': + description: '2' + values: + 1965-01-01: + value: null + 1957-05-15: + value: 1.15 + '3': + description: '3' + values: + 1965-01-01: + value: null + 1957-05-15: + value: 1.2 + '4': + description: '4' + values: + 1965-01-01: + value: null + 1957-05-15: + value: 1.25 + metadata: + order: + - '-10' + - '-9' + - '-8' + - '-7' + - '-6' + - '-5' + - '-4' + - '-3' + - '-2' + - '-1' + - '0' + - '1' + - '2' + - '3' + - '4' +metadata: + order: + - coefficient_minoration_en_fonction_distance_age_annulation_decote_en_annee + reference: + 1965-01-01: Article 18 de l'annexe A de l'Accord national interprofessionnel de retraite complémentaire du 8 décembre 1961 + 1957-05-15: Accord du 15/05/1957 pour la création de l'UNIRS + description_en: Penalty for early retirement ARRCO +documentation: | + Note: Le coefficient d'abattement (ou de majoration avant 1965) constitue une multiplication des droits de pension à l'arrco par le coefficient en question. Par exemple, un individu partant en retraite à 60 ans en 1960 touchait 75% de sa pension. A partir de 1983, une double condition d'âge et de durée d'assurance est instaurée: un individu ayant validé une durée égale à la durée d'assurance cible(voir onglet Trim_tx_plein_RG) partira sans abbattement, même s'il n'a pas atteint l'âge d'annulation de la décôte dans le régime général (voir onglet Age_ann_dec_RG). + Note : le coefficient de minoration est linéaire en nombre de trimestres, e.g. il est de 0,43 à AAD - 10 ans, de 0,4475 à AAD - 9 ans et 3 trimestres, de 0,465 à AAD - 9 ans et 2 trimestres, etc. diff --git a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py index 41f42fad88..b7e7cf4e45 100644 --- a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py +++ b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py @@ -1,149 +1,177 @@ -# -*- coding: utf-8 -*- - import os import re -import numpy as np +import numpy import pytest - -from openfisca_core.errors import ParameterNotFoundError -from openfisca_core.model_api import * # noqa -from openfisca_core.parameters import ParameterNode, Parameter +from openfisca_core.indexed_enums import Enum +from openfisca_core.parameters import Parameter, ParameterNode, ParameterNotFound from openfisca_core.tools import assert_near LOCAL_DIR = os.path.dirname(os.path.abspath(__file__)) -parameters = ParameterNode(directory_path = LOCAL_DIR) +parameters = ParameterNode(directory_path=LOCAL_DIR) -P = parameters.rate('2015-01-01') +P = parameters.rate("2015-01-01") def get_message(error): return error.args[0] -def test_on_leaf(): - zone = np.asarray(['z1', 'z2', 'z2', 'z1']) +def test_on_leaf() -> None: + zone = numpy.asarray(["z1", "z2", "z2", "z1"]) assert_near(P.single.owner[zone], [100, 200, 200, 100]) -def test_on_node(): - housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant']) +def test_on_node() -> None: + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) node = P.single[housing_occupancy_status] assert_near(node.z1, [100, 100, 300, 300]) - assert_near(node['z1'], [100, 100, 300, 300]) + assert_near(node["z1"], [100, 100, 300, 300]) -def test_double_fancy_indexing(): - zone = np.asarray(['z1', 'z2', 'z2', 'z1']) - housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant']) +def test_double_fancy_indexing() -> None: + zone = numpy.asarray(["z1", "z2", "z2", "z1"]) + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) assert_near(P.single[housing_occupancy_status][zone], [100, 200, 400, 300]) -def test_double_fancy_indexing_on_node(): - family_status = np.asarray(['single', 'couple', 'single', 'couple']) - housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant']) +def test_double_fancy_indexing_on_node() -> None: + family_status = numpy.asarray(["single", "couple", "single", "couple"]) + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) node = P[family_status][housing_occupancy_status] assert_near(node.z1, [100, 500, 300, 700]) - assert_near(node['z1'], [100, 500, 300, 700]) + assert_near(node["z1"], [100, 500, 300, 700]) assert_near(node.z2, [200, 600, 400, 800]) - assert_near(node['z2'], [200, 600, 400, 800]) - - -def test_triple_fancy_indexing(): - family_status = np.asarray(['single', 'single', 'single', 'single', 'couple', 'couple', 'couple', 'couple']) - housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant', 'owner', 'owner', 'tenant', 'tenant']) - zone = np.asarray(['z1', 'z2', 'z1', 'z2', 'z1', 'z2', 'z1', 'z2']) - assert_near(P[family_status][housing_occupancy_status][zone], [100, 200, 300, 400, 500, 600, 700, 800]) - - -def test_wrong_key(): - zone = np.asarray(['z1', 'z2', 'z2', 'toto']) - with pytest.raises(ParameterNotFoundError) as e: + assert_near(node["z2"], [200, 600, 400, 800]) + + +def test_triple_fancy_indexing() -> None: + family_status = numpy.asarray( + [ + "single", + "single", + "single", + "single", + "couple", + "couple", + "couple", + "couple", + ], + ) + housing_occupancy_status = numpy.asarray( + ["owner", "owner", "tenant", "tenant", "owner", "owner", "tenant", "tenant"], + ) + zone = numpy.asarray(["z1", "z2", "z1", "z2", "z1", "z2", "z1", "z2"]) + assert_near( + P[family_status][housing_occupancy_status][zone], + [100, 200, 300, 400, 500, 600, 700, 800], + ) + + +def test_wrong_key() -> None: + zone = numpy.asarray(["z1", "z2", "z2", "toto"]) + with pytest.raises(ParameterNotFound) as e: P.single.owner[zone] assert "'rate.single.owner.toto' was not found" in get_message(e.value) -def test_inhomogenous(): - parameters = ParameterNode(directory_path = LOCAL_DIR) - parameters.rate.couple.owner.add_child('toto', Parameter('toto', { - "values": { - "2015-01-01": { - "value": 1000 +def test_inhomogenous() -> None: + parameters = ParameterNode(directory_path=LOCAL_DIR) + parameters.rate.couple.owner.add_child( + "toto", + Parameter( + "toto", + { + "values": { + "2015-01-01": {"value": 1000}, }, - } - })) + }, + ), + ) - P = parameters.rate('2015-01-01') - housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant']) + P = parameters.rate("2015-01-01") + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) with pytest.raises(ValueError) as error: P.couple[housing_occupancy_status] assert "'rate.couple.owner.toto' exists" in get_message(error.value) assert "'rate.couple.tenant.toto' doesn't" in get_message(error.value) -def test_inhomogenous_2(): - parameters = ParameterNode(directory_path = LOCAL_DIR) - parameters.rate.couple.tenant.add_child('toto', Parameter('toto', { - "values": { - "2015-01-01": { - "value": 1000 +def test_inhomogenous_2() -> None: + parameters = ParameterNode(directory_path=LOCAL_DIR) + parameters.rate.couple.tenant.add_child( + "toto", + Parameter( + "toto", + { + "values": { + "2015-01-01": {"value": 1000}, }, - } - })) + }, + ), + ) - P = parameters.rate('2015-01-01') - housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant']) + P = parameters.rate("2015-01-01") + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) with pytest.raises(ValueError) as e: P.couple[housing_occupancy_status] assert "'rate.couple.tenant.toto' exists" in get_message(e.value) assert "'rate.couple.owner.toto' doesn't" in get_message(e.value) -def test_inhomogenous_3(): - parameters = ParameterNode(directory_path = LOCAL_DIR) - parameters.rate.couple.tenant.add_child('z4', ParameterNode('toto', data = { - 'amount': { - 'values': { - "2015-01-01": {'value': 550}, - "2016-01-01": {'value': 600} - } - } - })) +def test_inhomogenous_3() -> None: + parameters = ParameterNode(directory_path=LOCAL_DIR) + parameters.rate.couple.tenant.add_child( + "z4", + ParameterNode( + "toto", + data={ + "amount": { + "values": { + "2015-01-01": {"value": 550}, + "2016-01-01": {"value": 600}, + }, + }, + }, + ), + ) - P = parameters.rate('2015-01-01') - zone = np.asarray(['z1', 'z2', 'z2', 'z1']) + P = parameters.rate("2015-01-01") + zone = numpy.asarray(["z1", "z2", "z2", "z1"]) with pytest.raises(ValueError) as e: P.couple.tenant[zone] assert "'rate.couple.tenant.z4' is a node" in get_message(e.value) assert re.findall(r"'rate.couple.tenant.z(1|2|3)' is not", get_message(e.value)) -P_2 = parameters.local_tax('2015-01-01') +P_2 = parameters.local_tax("2015-01-01") -def test_with_properties_starting_by_number(): - city_code = np.asarray(['75012', '75007', '75015']) +def test_with_properties_starting_by_number() -> None: + city_code = numpy.asarray(["75012", "75007", "75015"]) assert_near(P_2[city_code], [100, 300, 200]) -P_3 = parameters.bareme('2015-01-01') +P_3 = parameters.bareme("2015-01-01") -def test_with_bareme(): - city_code = np.asarray(['75012', '75007', '75015']) +def test_with_bareme() -> None: + city_code = numpy.asarray(["75012", "75007", "75015"]) with pytest.raises(NotImplementedError) as e: P_3[city_code] - assert re.findall(r"'bareme.7501\d' is a 'MarginalRateTaxScale'", get_message(e.value)) + assert re.findall( + r"'bareme.7501\d' is a 'MarginalRateTaxScale'", + get_message(e.value), + ) assert "has not been implemented" in get_message(e.value) -def test_with_enum(): - +def test_with_enum() -> None: class TypesZone(Enum): z1 = "Zone 1" z2 = "Zone 2" - zone = np.asarray([TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]) + zone = numpy.asarray([TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]) assert_near(P.single.owner[zone], [100, 200, 200, 100]) diff --git a/tests/core/tax_scales/test_abstract_rate_tax_scale.py b/tests/core/tax_scales/test_abstract_rate_tax_scale.py index 3d284a49e9..c966aa30f3 100644 --- a/tests/core/tax_scales/test_abstract_rate_tax_scale.py +++ b/tests/core/tax_scales/test_abstract_rate_tax_scale.py @@ -1,9 +1,9 @@ -from openfisca_core import taxscales - import pytest +from openfisca_core import taxscales + -def test_abstract_tax_scale(): +def test_abstract_tax_scale() -> None: with pytest.warns(DeprecationWarning): result = taxscales.AbstractRateTaxScale() - assert type(result) == taxscales.AbstractRateTaxScale + assert isinstance(result, taxscales.AbstractRateTaxScale) diff --git a/tests/core/tax_scales/test_abstract_tax_scale.py b/tests/core/tax_scales/test_abstract_tax_scale.py index f6834e7dc7..aad04d58ed 100644 --- a/tests/core/tax_scales/test_abstract_tax_scale.py +++ b/tests/core/tax_scales/test_abstract_tax_scale.py @@ -1,9 +1,9 @@ -from openfisca_core import taxscales - import pytest +from openfisca_core import taxscales + -def test_abstract_tax_scale(): +def test_abstract_tax_scale() -> None: with pytest.warns(DeprecationWarning): result = taxscales.AbstractTaxScale() - assert type(result) == taxscales.AbstractTaxScale + assert isinstance(result, taxscales.AbstractTaxScale) diff --git a/tests/core/tax_scales/test_linear_average_rate_tax_scale.py b/tests/core/tax_scales/test_linear_average_rate_tax_scale.py index 74b2762963..6205d6de9b 100644 --- a/tests/core/tax_scales/test_linear_average_rate_tax_scale.py +++ b/tests/core/tax_scales/test_linear_average_rate_tax_scale.py @@ -1,13 +1,10 @@ import numpy - -from openfisca_core import taxscales -from openfisca_core import tools -from openfisca_core.errors import EmptyArgumentError - import pytest +from openfisca_core import taxscales, tools + -def test_bracket_indices(): +def test_bracket_indices() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -19,50 +16,50 @@ def test_bracket_indices(): tools.assert_near(result, [0, 0, 0, 1, 1, 2]) -def test_bracket_indices_with_factor(): +def test_bracket_indices_with_factor() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(2, 0) tax_scale.add_bracket(4, 0) - result = tax_scale.bracket_indices(tax_base, factor = 2.0) + result = tax_scale.bracket_indices(tax_base, factor=2.0) tools.assert_near(result, [0, 0, 0, 0, 1, 1]) -def test_bracket_indices_with_round_decimals(): +def test_bracket_indices_with_round_decimals() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(2, 0) tax_scale.add_bracket(4, 0) - result = tax_scale.bracket_indices(tax_base, round_decimals = 0) + result = tax_scale.bracket_indices(tax_base, round_decimals=0) tools.assert_near(result, [0, 0, 1, 1, 2, 2]) -def test_bracket_indices_without_tax_base(): +def test_bracket_indices_without_tax_base() -> None: tax_base = numpy.array([]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(2, 0) tax_scale.add_bracket(4, 0) - with pytest.raises(EmptyArgumentError): + with pytest.raises(taxscales.EmptyArgumentError): tax_scale.bracket_indices(tax_base) -def test_bracket_indices_without_brackets(): +def test_bracket_indices_without_brackets() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() - with pytest.raises(EmptyArgumentError): + with pytest.raises(taxscales.EmptyArgumentError): tax_scale.bracket_indices(tax_base) -def test_to_dict(): +def test_to_dict() -> None: tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) @@ -72,7 +69,7 @@ def test_to_dict(): assert result == {"0": 0.0, "100": 0.1} -def test_to_marginal(): +def test_to_marginal() -> None: tax_base = numpy.array([1, 1.5, 2, 2.5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -82,9 +79,9 @@ def test_to_marginal(): result = tax_scale.to_marginal() assert result.thresholds == [0, 1, 2] - tools.assert_near(result.rates, [0.1, 0.3, 0.2], absolute_error_margin = 0) + tools.assert_near(result.rates, [0.1, 0.3, 0.2], absolute_error_margin=0) tools.assert_near( result.calc(tax_base), [0.1, 0.25, 0.4, 0.5], - absolute_error_margin = 0, - ) + absolute_error_margin=0, + ) diff --git a/tests/core/tax_scales/test_marginal_amount_tax_scale.py b/tests/core/tax_scales/test_marginal_amount_tax_scale.py index cdd7cc4f27..0a3275c901 100644 --- a/tests/core/tax_scales/test_marginal_amount_tax_scale.py +++ b/tests/core/tax_scales/test_marginal_amount_tax_scale.py @@ -1,12 +1,8 @@ from numpy import array - -from openfisca_core import parameters -from openfisca_core import periods -from openfisca_core import taxscales -from openfisca_core import tools - from pytest import fixture +from openfisca_core import parameters, periods, taxscales, tools + @fixture def data(): @@ -16,13 +12,15 @@ def data(): "brackets": [ { "threshold": {"2017-10-01": {"value": 0.23}}, - "amount": {"2017-10-01": {"value": 6}, }, - } - ], - } + "amount": { + "2017-10-01": {"value": 6}, + }, + }, + ], + } -def test_calc(): +def test_calc() -> None: tax_base = array([1, 8, 10]) tax_scale = taxscales.MarginalAmountTaxScale() tax_scale.add_bracket(6, 0.23) @@ -34,8 +32,8 @@ def test_calc(): # TODO: move, as we're testing Scale, not MarginalAmountTaxScale -def test_dispatch_scale_type_on_creation(data): - scale = parameters.ParameterScale("amount_scale", data, "") +def test_dispatch_scale_type_on_creation(data) -> None: + scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) result = scale.get_at_instant(first_jan) diff --git a/tests/core/tax_scales/test_marginal_rate_tax_scale.py b/tests/core/tax_scales/test_marginal_rate_tax_scale.py index 505d103348..7696e95fc4 100644 --- a/tests/core/tax_scales/test_marginal_rate_tax_scale.py +++ b/tests/core/tax_scales/test_marginal_rate_tax_scale.py @@ -1,13 +1,10 @@ import numpy - -from openfisca_core import taxscales -from openfisca_core import tools -from openfisca_core.errors import EmptyArgumentError - import pytest +from openfisca_core import taxscales, tools -def test_bracket_indices(): + +def test_bracket_indices() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -19,50 +16,50 @@ def test_bracket_indices(): tools.assert_near(result, [0, 0, 0, 1, 1, 2]) -def test_bracket_indices_with_factor(): +def test_bracket_indices_with_factor() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(2, 0) tax_scale.add_bracket(4, 0) - result = tax_scale.bracket_indices(tax_base, factor = 2.0) + result = tax_scale.bracket_indices(tax_base, factor=2.0) tools.assert_near(result, [0, 0, 0, 0, 1, 1]) -def test_bracket_indices_with_round_decimals(): +def test_bracket_indices_with_round_decimals() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(2, 0) tax_scale.add_bracket(4, 0) - result = tax_scale.bracket_indices(tax_base, round_decimals = 0) + result = tax_scale.bracket_indices(tax_base, round_decimals=0) tools.assert_near(result, [0, 0, 1, 1, 2, 2]) -def test_bracket_indices_without_tax_base(): +def test_bracket_indices_without_tax_base() -> None: tax_base = numpy.array([]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(2, 0) tax_scale.add_bracket(4, 0) - with pytest.raises(EmptyArgumentError): + with pytest.raises(taxscales.EmptyArgumentError): tax_scale.bracket_indices(tax_base) -def test_bracket_indices_without_brackets(): +def test_bracket_indices_without_brackets() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() - with pytest.raises(EmptyArgumentError): + with pytest.raises(taxscales.EmptyArgumentError): tax_scale.bracket_indices(tax_base) -def test_to_dict(): +def test_to_dict() -> None: tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) @@ -72,7 +69,7 @@ def test_to_dict(): assert result == {"0": 0.0, "100": 0.1} -def test_calc(): +def test_calc() -> None: tax_base = numpy.array([1, 1.5, 2, 2.5, 3.0, 4.0]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -85,11 +82,11 @@ def test_calc(): tools.assert_near( result, [0, 0.05, 0.1, 0.2, 0.3, 0.3], - absolute_error_margin = 1e-10, - ) + absolute_error_margin=1e-10, + ) -def test_calc_without_round(): +def test_calc_without_round() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -100,56 +97,56 @@ def test_calc_without_round(): tools.assert_near( result, [10, 10.02, 10.0002, 10.06, 10.0006, 10.05, 10.0005], - absolute_error_margin = 1e-10, - ) + absolute_error_margin=1e-10, + ) -def test_calc_when_round_is_1(): +def test_calc_when_round_is_1() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) - result = tax_scale.calc(tax_base, round_base_decimals = 1) + result = tax_scale.calc(tax_base, round_base_decimals=1) tools.assert_near( result, [10, 10.0, 10.0, 10.1, 10.0, 10, 10.0], - absolute_error_margin = 1e-10, - ) + absolute_error_margin=1e-10, + ) -def test_calc_when_round_is_2(): +def test_calc_when_round_is_2() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) - result = tax_scale.calc(tax_base, round_base_decimals = 2) + result = tax_scale.calc(tax_base, round_base_decimals=2) tools.assert_near( result, [10, 10.02, 10.0, 10.06, 10.00, 10.05, 10], - absolute_error_margin = 1e-10, - ) + absolute_error_margin=1e-10, + ) -def test_calc_when_round_is_3(): +def test_calc_when_round_is_3() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) - result = tax_scale.calc(tax_base, round_base_decimals = 3) + result = tax_scale.calc(tax_base, round_base_decimals=3) tools.assert_near( result, [10, 10.02, 10.0, 10.06, 10.001, 10.05, 10], - absolute_error_margin = 1e-10, - ) + absolute_error_margin=1e-10, + ) -def test_marginal_rates(): +def test_marginal_rates() -> None: tax_base = numpy.array([0, 10, 50, 125, 250]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -161,7 +158,7 @@ def test_marginal_rates(): tools.assert_near(result, [0, 0, 0, 0.1, 0.2]) -def test_inverse(): +def test_inverse() -> None: gross_tax_base = numpy.array([1, 2, 3, 4, 5, 6]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -174,7 +171,7 @@ def test_inverse(): tools.assert_near(result.calc(net_tax_base), gross_tax_base, 1e-15) -def test_scale_tax_scales(): +def test_scale_tax_scales() -> None: tax_base = numpy.array([1, 2, 3]) tax_base_scale = 12.345 scaled_tax_base = tax_base * tax_base_scale @@ -188,7 +185,7 @@ def test_scale_tax_scales(): tools.assert_near(result.thresholds, scaled_tax_base) -def test_inverse_scaled_marginal_tax_scales(): +def test_inverse_scaled_marginal_tax_scales() -> None: gross_tax_base = numpy.array([1, 2, 3, 4, 5, 6]) gross_tax_base_scale = 12.345 scaled_gross_tax_base = gross_tax_base * gross_tax_base_scale @@ -197,17 +194,16 @@ def test_inverse_scaled_marginal_tax_scales(): tax_scale.add_bracket(1, 0.1) tax_scale.add_bracket(3, 0.05) scaled_tax_scale = tax_scale.scale_tax_scales(gross_tax_base_scale) - scaled_net_tax_base = ( - + scaled_gross_tax_base - - scaled_tax_scale.calc(scaled_gross_tax_base) - ) + scaled_net_tax_base = +scaled_gross_tax_base - scaled_tax_scale.calc( + scaled_gross_tax_base, + ) result = scaled_tax_scale.inverse() tools.assert_near(result.calc(scaled_net_tax_base), scaled_gross_tax_base, 1e-13) -def test_to_average(): +def test_to_average() -> None: tax_base = numpy.array([1, 1.5, 2, 2.5]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -222,5 +218,33 @@ def test_to_average(): tools.assert_near( result.calc(tax_base), [0, 0.0375, 0.1, 0.125], - absolute_error_margin = 1e-10, - ) + absolute_error_margin=1e-10, + ) + + +def test_rate_from_bracket_indice() -> None: + tax_base = numpy.array([0, 1_000, 1_500, 50_000]) + tax_scale = taxscales.MarginalRateTaxScale() + tax_scale.add_bracket(0, 0) + tax_scale.add_bracket(400, 0.1) + tax_scale.add_bracket(15_000, 0.4) + + bracket_indice = tax_scale.bracket_indices(tax_base) + result = tax_scale.rate_from_bracket_indice(bracket_indice) + + assert isinstance(result, numpy.ndarray) + assert (result == numpy.array([0.0, 0.1, 0.1, 0.4])).all() + + +def test_rate_from_tax_base() -> None: + tax_base = numpy.array([0, 3_000, 15_500, 500_000]) + tax_scale = taxscales.MarginalRateTaxScale() + tax_scale.add_bracket(0, 0) + tax_scale.add_bracket(400, 0.1) + tax_scale.add_bracket(15_000, 0.4) + tax_scale.add_bracket(200_000, 0.6) + + result = tax_scale.rate_from_tax_base(tax_base) + + assert isinstance(result, numpy.ndarray) + assert (result == numpy.array([0.0, 0.1, 0.4, 0.6])).all() diff --git a/tests/core/tax_scales/test_rate_tax_scale_like.py b/tests/core/tax_scales/test_rate_tax_scale_like.py new file mode 100644 index 0000000000..9f5bc61286 --- /dev/null +++ b/tests/core/tax_scales/test_rate_tax_scale_like.py @@ -0,0 +1,17 @@ +import numpy + +from openfisca_core import taxscales + + +def test_threshold_from_tax_base() -> None: + tax_base = numpy.array([0, 33_000, 500, 400_000]) + tax_scale = taxscales.LinearAverageRateTaxScale() + tax_scale.add_bracket(0, 0) + tax_scale.add_bracket(400, 0.1) + tax_scale.add_bracket(15_000, 0.4) + tax_scale.add_bracket(200_000, 0.6) + + result = tax_scale.threshold_from_tax_base(tax_base) + + assert isinstance(result, numpy.ndarray) + assert (result == numpy.array([0, 15_000, 400, 200_000])).all() diff --git a/tests/core/tax_scales/test_single_amount_tax_scale.py b/tests/core/tax_scales/test_single_amount_tax_scale.py index 0eb63c1f26..2b384f6374 100644 --- a/tests/core/tax_scales/test_single_amount_tax_scale.py +++ b/tests/core/tax_scales/test_single_amount_tax_scale.py @@ -1,12 +1,8 @@ import numpy - -from openfisca_core import parameters -from openfisca_core import periods -from openfisca_core import taxscales -from openfisca_core import tools - from pytest import fixture +from openfisca_core import parameters, periods, taxscales, tools + @fixture def data(): @@ -16,17 +12,19 @@ def data(): "type": "single_amount", "threshold_unit": "currency-EUR", "rate_unit": "/1", - }, + }, "brackets": [ { "threshold": {"2017-10-01": {"value": 0.23}}, - "amount": {"2017-10-01": {"value": 6}, }, - } - ], - } + "amount": { + "2017-10-01": {"value": 6}, + }, + }, + ], + } -def test_calc(): +def test_calc() -> None: tax_base = numpy.array([1, 8, 10]) tax_scale = taxscales.SingleAmountTaxScale() tax_scale.add_bracket(6, 0.23) @@ -37,7 +35,7 @@ def test_calc(): tools.assert_near(result, [0, 0.23, 0.29]) -def test_to_dict(): +def test_to_dict() -> None: tax_scale = taxscales.SingleAmountTaxScale() tax_scale.add_bracket(6, 0.23) tax_scale.add_bracket(9, 0.29) @@ -48,8 +46,8 @@ def test_to_dict(): # TODO: move, as we're testing Scale, not SingleAmountTaxScale -def test_assign_thresholds_on_creation(data): - scale = parameters.ParameterScale("amount_scale", data, "") +def test_assign_thresholds_on_creation(data) -> None: + scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) scale_at_instant = scale.get_at_instant(first_jan) @@ -59,8 +57,8 @@ def test_assign_thresholds_on_creation(data): # TODO: move, as we're testing Scale, not SingleAmountTaxScale -def test_assign_amounts_on_creation(data): - scale = parameters.ParameterScale("amount_scale", data, "") +def test_assign_amounts_on_creation(data) -> None: + scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) scale_at_instant = scale.get_at_instant(first_jan) @@ -70,8 +68,8 @@ def test_assign_amounts_on_creation(data): # TODO: move, as we're testing Scale, not SingleAmountTaxScale -def test_dispatch_scale_type_on_creation(data): - scale = parameters.ParameterScale("amount_scale", data, "") +def test_dispatch_scale_type_on_creation(data) -> None: + scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) result = scale.get_at_instant(first_jan) diff --git a/tests/core/tax_scales/test_tax_scales_commons.py b/tests/core/tax_scales/test_tax_scales_commons.py index d45bdd894a..544e5a07fe 100644 --- a/tests/core/tax_scales/test_tax_scales_commons.py +++ b/tests/core/tax_scales/test_tax_scales_commons.py @@ -1,32 +1,30 @@ -from openfisca_core import parameters -from openfisca_core import taxscales -from openfisca_core import tools - import pytest +from openfisca_core import parameters, taxscales, tools + @pytest.fixture def node(): return parameters.ParameterNode( "baremes", - data = { + data={ "health": { "brackets": [ {"rate": {"2015-01-01": 0.05}, "threshold": {"2015-01-01": 0}}, {"rate": {"2015-01-01": 0.10}, "threshold": {"2015-01-01": 2000}}, - ] - }, + ], + }, "retirement": { "brackets": [ {"rate": {"2015-01-01": 0.02}, "threshold": {"2015-01-01": 0}}, {"rate": {"2015-01-01": 0.04}, "threshold": {"2015-01-01": 3000}}, - ] - }, + ], }, - )(2015) + }, + )(2015) -def test_combine_tax_scales(node): +def test_combine_tax_scales(node) -> None: result = taxscales.combine_tax_scales(node) tools.assert_near(result.thresholds, [0, 2000, 3000]) diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index f106a82a5b..11590daf51 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -1,192 +1,338 @@ import pytest +from openfisca_core import errors from openfisca_core.simulations import SimulationBuilder from openfisca_core.tools import test_runner - # With periods -def test_add_axis_without_period(persons): +def test_add_axis_without_period(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.set_default_period('2018-11') - simulation_builder.add_person_entity(persons, {'Alicia': {}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000}) + simulation_builder.set_default_period("2018-11") + simulation_builder.add_person_entity(persons, {"Alicia": {}}) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "salary", "min": 0, "max": 3000}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 1500, 3000], + ) # With variables -def test_add_axis_on_a_non_existing_variable(persons): +def test_add_axis_on_a_non_existing_variable(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}}) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'ubi', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity(persons, {"Alicia": {}}) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "ubi", "min": 0, "max": 3000, "period": "2018-11"}, + ) with pytest.raises(KeyError): simulation_builder.expand_axes() -def test_add_axis_on_an_existing_variable_with_input(persons): +def test_add_axis_on_an_existing_variable_with_input(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {'salary': {'2018-11': 1000}}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity( + persons, + {"Alicia": {"salary": {"2018-11": 1000}}}, + ) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) - assert simulation_builder.get_count('persons') == 3 - assert simulation_builder.get_ids('persons') == ['Alicia0', 'Alicia1', 'Alicia2'] + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 1500, 3000], + ) + assert simulation_builder.get_count("persons") == 3 + assert simulation_builder.get_ids("persons") == ["Alicia0", "Alicia1", "Alicia2"] # With entities -def test_add_axis_on_persons(persons): +def test_add_axis_on_persons(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity(persons, {"Alicia": {}}) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) - assert simulation_builder.get_count('persons') == 3 - assert simulation_builder.get_ids('persons') == ['Alicia0', 'Alicia1', 'Alicia2'] + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 1500, 3000], + ) + assert simulation_builder.get_count("persons") == 3 + assert simulation_builder.get_ids("persons") == ["Alicia0", "Alicia1", "Alicia2"] -def test_add_two_axes(persons): +def test_add_two_axes(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'}) + simulation_builder.add_person_entity(persons, {"Alicia": {}}) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) - assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 1000, 2000]) + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 1500, 3000], + ) + assert simulation_builder.get_input("pension", "2018-11") == pytest.approx( + [0, 1000, 2000], + ) -def test_add_axis_with_group(persons): +def test_add_axis_with_group(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11', 'index': 1}) + simulation_builder.add_person_entity(persons, {"Alicia": {}, "Javier": {}}) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) + simulation_builder.add_parallel_axis( + { + "count": 2, + "name": "salary", + "min": 0, + "max": 3000, + "period": "2018-11", + "index": 1, + }, + ) simulation_builder.expand_axes() - assert simulation_builder.get_count('persons') == 4 - assert simulation_builder.get_ids('persons') == ['Alicia0', 'Javier1', 'Alicia2', 'Javier3'] - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 0, 3000, 3000]) - - -def test_add_axis_with_group_int_period(persons): + assert simulation_builder.get_count("persons") == 4 + assert simulation_builder.get_ids("persons") == [ + "Alicia0", + "Javier1", + "Alicia2", + "Javier3", + ] + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 0, 3000, 3000], + ) + + +def test_add_axis_with_group_int_period(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': 2018}) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': 2018, 'index': 1}) + simulation_builder.add_person_entity(persons, {"Alicia": {}, "Javier": {}}) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": 2018}, + ) + simulation_builder.add_parallel_axis( + { + "count": 2, + "name": "salary", + "min": 0, + "max": 3000, + "period": 2018, + "index": 1, + }, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018') == pytest.approx([0, 0, 3000, 3000]) + assert simulation_builder.get_input("salary", "2018") == pytest.approx( + [0, 0, 3000, 3000], + ) -def test_add_axis_on_households(persons, households): +def test_add_axis_on_households(persons, households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}, 'Tom': {}}) - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Tom'], households, { - 'housea': {'parents': ['Alicia', 'Javier']}, - 'houseb': {'parents': ['Tom']}, - }) - simulation_builder.register_variable('rent', households) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'rent', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity( + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, + ) + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Tom"], + households, + { + "housea": {"parents": ["Alicia", "Javier"]}, + "houseb": {"parents": ["Tom"]}, + }, + ) + simulation_builder.register_variable("rent", households) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_count('households') == 4 - assert simulation_builder.get_ids('households') == ['housea0', 'houseb1', 'housea2', 'houseb3'] - assert simulation_builder.get_input('rent', '2018-11') == pytest.approx([0, 0, 3000, 0]) - - -def test_axis_on_group_expands_persons(persons, households): + assert simulation_builder.get_count("households") == 4 + assert simulation_builder.get_ids("households") == [ + "housea0", + "houseb1", + "housea2", + "houseb3", + ] + assert simulation_builder.get_input("rent", "2018-11") == pytest.approx( + [0, 0, 3000, 0], + ) + + +def test_axis_on_group_expands_persons(persons, households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}, 'Tom': {}}) - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Tom'], households, { - 'housea': {'parents': ['Alicia', 'Javier']}, - 'houseb': {'parents': ['Tom']}, - }) - simulation_builder.register_variable('rent', households) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'rent', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity( + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, + ) + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Tom"], + households, + { + "housea": {"parents": ["Alicia", "Javier"]}, + "houseb": {"parents": ["Tom"]}, + }, + ) + simulation_builder.register_variable("rent", households) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_count('persons') == 6 + assert simulation_builder.get_count("persons") == 6 -def test_add_axis_distributes_roles(persons, households): +def test_add_axis_distributes_roles(persons, households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}, 'Tom': {}}) - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Tom'], households, { - 'housea': {'parents': ['Alicia']}, - 'houseb': {'parents': ['Tom'], 'children': ['Javier']}, - }) - simulation_builder.register_variable('rent', households) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'rent', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity( + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, + ) + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Tom"], + households, + { + "housea": {"parents": ["Alicia"]}, + "houseb": {"parents": ["Tom"], "children": ["Javier"]}, + }, + ) + simulation_builder.register_variable("rent", households) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert [role.key for role in simulation_builder.get_roles('households')] == ['parent', 'child', 'parent', 'parent', 'child', 'parent'] + assert [role.key for role in simulation_builder.get_roles("households")] == [ + "parent", + "child", + "parent", + "parent", + "child", + "parent", + ] -def test_add_axis_on_persons_distributes_roles(persons, households): +def test_add_axis_on_persons_distributes_roles(persons, households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}, 'Tom': {}}) - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Tom'], households, { - 'housea': {'parents': ['Alicia']}, - 'houseb': {'parents': ['Tom'], 'children': ['Javier']}, - }) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity( + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, + ) + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Tom"], + households, + { + "housea": {"parents": ["Alicia"]}, + "houseb": {"parents": ["Tom"], "children": ["Javier"]}, + }, + ) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert [role.key for role in simulation_builder.get_roles('households')] == ['parent', 'child', 'parent', 'parent', 'child', 'parent'] + assert [role.key for role in simulation_builder.get_roles("households")] == [ + "parent", + "child", + "parent", + "parent", + "child", + "parent", + ] -def test_add_axis_distributes_memberships(persons, households): +def test_add_axis_distributes_memberships(persons, households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}, 'Tom': {}}) - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Tom'], households, { - 'housea': {'parents': ['Alicia']}, - 'houseb': {'parents': ['Tom'], 'children': ['Javier']}, - }) - simulation_builder.register_variable('rent', households) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'rent', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity( + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, + ) + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Tom"], + households, + { + "housea": {"parents": ["Alicia"]}, + "houseb": {"parents": ["Tom"], "children": ["Javier"]}, + }, + ) + simulation_builder.register_variable("rent", households) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_memberships('households') == [0, 1, 1, 2, 3, 3] + assert simulation_builder.get_memberships("households") == [0, 1, 1, 2, 3, 3] -def test_add_perpendicular_axes(persons): +def test_add_perpendicular_axes(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.register_variable('pension', persons) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) - simulation_builder.add_perpendicular_axis({'count': 2, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'}) + simulation_builder.add_person_entity(persons, {"Alicia": {}}) + simulation_builder.register_variable("salary", persons) + simulation_builder.register_variable("pension", persons) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) + simulation_builder.add_perpendicular_axis( + {"count": 2, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000, 0, 1500, 3000]) - assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 0, 0, 2000, 2000, 2000]) + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 1500, 3000, 0, 1500, 3000], + ) + assert simulation_builder.get_input("pension", "2018-11") == pytest.approx( + [0, 0, 0, 2000, 2000, 2000], + ) -def test_add_perpendicular_axis_on_an_existing_variable_with_input(persons): +def test_add_perpendicular_axis_on_an_existing_variable_with_input(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, { - 'Alicia': { - 'salary': {'2018-11': 1000}, - 'pension': {'2018-11': 1000}, + simulation_builder.add_person_entity( + persons, + { + "Alicia": { + "salary": {"2018-11": 1000}, + "pension": {"2018-11": 1000}, }, - },) - simulation_builder.register_variable('salary', persons) - simulation_builder.register_variable('pension', persons) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) - simulation_builder.add_perpendicular_axis({'count': 2, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'}) + }, + ) + simulation_builder.register_variable("salary", persons) + simulation_builder.register_variable("pension", persons) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) + simulation_builder.add_perpendicular_axis( + {"count": 2, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000, 0, 1500, 3000]) - assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 0, 0, 2000, 2000, 2000]) + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 1500, 3000, 0, 1500, 3000], + ) + assert simulation_builder.get_input("pension", "2018-11") == pytest.approx( + [0, 0, 0, 2000, 2000, 2000], + ) -# Integration test +# Integration tests -def test_simulation_with_axes(tax_benefit_system): +def test_simulation_with_axes(tax_benefit_system) -> None: input_yaml = """ persons: Alicia: {salary: {2018-11: 0}} @@ -207,5 +353,31 @@ def test_simulation_with_axes(tax_benefit_system): """ data = test_runner.yaml.safe_load(input_yaml) simulation = SimulationBuilder().build_from_dict(tax_benefit_system, data) - assert simulation.get_array('salary', '2018-11') == pytest.approx([0, 0, 0, 0, 0, 0]) - assert simulation.get_array('rent', '2018-11') == pytest.approx([0, 0, 3000, 0]) + assert simulation.get_array("salary", "2018-11") == pytest.approx( + [0, 0, 0, 0, 0, 0], + ) + assert simulation.get_array("rent", "2018-11") == pytest.approx([0, 0, 3000, 0]) + + +# Test for missing group entities with build_from_entities() + + +def test_simulation_with_axes_missing_entities(tax_benefit_system) -> None: + input_yaml = """ + persons: + Alicia: {salary: {2018-11: 0}} + Javier: {} + Tom: {} + axes: + - + - count: 2 + name: rent + min: 0 + max: 3000 + period: 2018-11 + """ + data = test_runner.yaml.safe_load(input_yaml) + with pytest.raises(errors.SituationParsingError) as error: + SimulationBuilder().build_from_dict(tax_benefit_system, data) + assert "In order to expand over axes" in error.value() + assert "all group entities and roles must be fully specified" in error.value() diff --git a/tests/core/test_calculate_output.py b/tests/core/test_calculate_output.py index 6a11a27d84..54d868ba92 100644 --- a/tests/core/test_calculate_output.py +++ b/tests/core/test_calculate_output.py @@ -2,57 +2,67 @@ from openfisca_country_template import entities, situation_examples -from openfisca_core import periods, simulations, tools +from openfisca_core import simulations, tools +from openfisca_core.periods import DateUnit from openfisca_core.simulations import SimulationBuilder from openfisca_core.variables import Variable class simple_variable(Variable): entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH value_type = int class variable_with_calculate_output_add(Variable): entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH value_type = int calculate_output = simulations.calculate_output_add class variable_with_calculate_output_divide(Variable): entity = entities.Person - definition_period = periods.YEAR + definition_period = DateUnit.YEAR value_type = int calculate_output = simulations.calculate_output_divide -@pytest.fixture(scope = "module", autouse = True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +@pytest.fixture(scope="module", autouse=True) +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables( simple_variable, variable_with_calculate_output_add, - variable_with_calculate_output_divide - ) + variable_with_calculate_output_divide, + ) @pytest.fixture def simulation(tax_benefit_system): - return SimulationBuilder().build_from_entities(tax_benefit_system, situation_examples.single) + return SimulationBuilder().build_from_entities( + tax_benefit_system, + situation_examples.single, + ) -def test_calculate_output_default(simulation): +def test_calculate_output_default(simulation) -> None: with pytest.raises(ValueError): - simulation.calculate_output('simple_variable', 2017) + simulation.calculate_output("simple_variable", 2017) -def test_calculate_output_add(simulation): - simulation.set_input('variable_with_calculate_output_add', '2017-01', [10]) - simulation.set_input('variable_with_calculate_output_add', '2017-05', [20]) - simulation.set_input('variable_with_calculate_output_add', '2017-12', [70]) - tools.assert_near(simulation.calculate_output('variable_with_calculate_output_add', 2017), 100) +def test_calculate_output_add(simulation) -> None: + simulation.set_input("variable_with_calculate_output_add", "2017-01", [10]) + simulation.set_input("variable_with_calculate_output_add", "2017-05", [20]) + simulation.set_input("variable_with_calculate_output_add", "2017-12", [70]) + tools.assert_near( + simulation.calculate_output("variable_with_calculate_output_add", 2017), + 100, + ) -def test_calculate_output_divide(simulation): - simulation.set_input('variable_with_calculate_output_divide', 2017, [12000]) - tools.assert_near(simulation.calculate_output('variable_with_calculate_output_divide', '2017-06'), 1000) +def test_calculate_output_divide(simulation) -> None: + simulation.set_input("variable_with_calculate_output_divide", 2017, [12000]) + tools.assert_near( + simulation.calculate_output("variable_with_calculate_output_divide", "2017-06"), + 1000, + ) diff --git a/tests/core/test_countries.py b/tests/core/test_countries.py index aeb4d762c7..d206a8cb35 100644 --- a/tests/core/test_countries.py +++ b/tests/core/test_countries.py @@ -2,55 +2,56 @@ from openfisca_core import periods, populations, tools from openfisca_core.errors import VariableNameConflictError, VariableNotFoundError +from openfisca_core.periods import DateUnit from openfisca_core.simulations import SimulationBuilder from openfisca_core.variables import Variable PERIOD = periods.period("2016-01") -@pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect = True) -def test_input_variable(simulation): +@pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect=True) +def test_input_variable(simulation) -> None: result = simulation.calculate("salary", PERIOD) - tools.assert_near(result, [2000], absolute_error_margin = 0.01) + tools.assert_near(result, [2000], absolute_error_margin=0.01) -@pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect = True) -def test_basic_calculation(simulation): +@pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect=True) +def test_basic_calculation(simulation) -> None: result = simulation.calculate("income_tax", PERIOD) - tools.assert_near(result, [300], absolute_error_margin = 0.01) + tools.assert_near(result, [300], absolute_error_margin=0.01) -@pytest.mark.parametrize("simulation", [({"salary": 24000}, PERIOD)], indirect = True) -def test_calculate_add(simulation): +@pytest.mark.parametrize("simulation", [({"salary": 24000}, PERIOD)], indirect=True) +def test_calculate_add(simulation) -> None: result = simulation.calculate_add("income_tax", PERIOD) - tools.assert_near(result, [3600], absolute_error_margin = 0.01) + tools.assert_near(result, [3600], absolute_error_margin=0.01) @pytest.mark.parametrize( "simulation", [({"accommodation_size": 100, "housing_occupancy_status": "tenant"}, PERIOD)], - indirect = True, - ) -def test_calculate_divide(simulation): + indirect=True, +) +def test_calculate_divide(simulation) -> None: result = simulation.calculate_divide("housing_tax", PERIOD) - tools.assert_near(result, [1000 / 12.], absolute_error_margin = 0.01) + tools.assert_near(result, [1000 / 12.0], absolute_error_margin=0.01) -@pytest.mark.parametrize("simulation", [({"salary": 20000}, PERIOD)], indirect = True) -def test_bareme(simulation): +@pytest.mark.parametrize("simulation", [({"salary": 20000}, PERIOD)], indirect=True) +def test_bareme(simulation) -> None: result = simulation.calculate("social_security_contribution", PERIOD) expected = [0.02 * 6000 + 0.06 * 6400 + 0.12 * 7600] - tools.assert_near(result, expected, absolute_error_margin = 0.01) + tools.assert_near(result, expected, absolute_error_margin=0.01) -@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect = True) -def test_non_existing_variable(simulation): +@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True) +def test_non_existing_variable(simulation) -> None: with pytest.raises(VariableNotFoundError): simulation.calculate("non_existent_variable", PERIOD) -@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect = True) -def test_calculate_variable_with_wrong_definition_period(simulation): +@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True) +def test_calculate_variable_with_wrong_definition_period(simulation) -> None: year = str(PERIOD.this_year) with pytest.raises(ValueError) as error: @@ -60,30 +61,28 @@ def test_calculate_variable_with_wrong_definition_period(simulation): expected_words = ["period", year, "month", "basic_income", "ADD"] for word in expected_words: - assert word in error_message, f"Expected '{word}' in error message '{error_message}'" + assert ( + word in error_message + ), f"Expected '{word}' in error message '{error_message}'" -@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect = True) -def test_divide_option_on_month_defined_variable(simulation): - with pytest.raises(ValueError): - simulation.person("disposable_income", PERIOD, options = [populations.DIVIDE]) - - -@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect = True) -def test_divide_option_with_complex_period(simulation): +@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True) +def test_divide_option_with_complex_period(simulation) -> None: quarter = PERIOD.last_3_months with pytest.raises(ValueError) as error: - simulation.household("housing_tax", quarter, options = [populations.DIVIDE]) + simulation.household("housing_tax", quarter, options=[populations.DIVIDE]) error_message = str(error.value) - expected_words = ["DIVIDE", "one-year", "one-month", "period"] + expected_words = ["Can't", "calculate", "month", "year"] for word in expected_words: - assert word in error_message, f"Expected '{word}' in error message '{error_message}'" + assert ( + word in error_message + ), f"Expected '{word}' in error message '{error_message}'" -def test_input_with_wrong_period(tax_benefit_system): +def test_input_with_wrong_period(tax_benefit_system) -> None: year = str(PERIOD.this_year) variables = {"basic_income": {year: 12000}} simulation_builder = SimulationBuilder() @@ -93,7 +92,7 @@ def test_input_with_wrong_period(tax_benefit_system): simulation_builder.build_from_variables(tax_benefit_system, variables) -def test_variable_with_reference(make_simulation, isolated_tax_benefit_system): +def test_variable_with_reference(make_simulation, isolated_tax_benefit_system) -> None: variables = {"salary": 4000} simulation = make_simulation(isolated_tax_benefit_system, variables, PERIOD) @@ -102,10 +101,10 @@ def test_variable_with_reference(make_simulation, isolated_tax_benefit_system): assert result > 0 class disposable_income(Variable): - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(household, period): - return household.empty_array() + def formula(self, period): + return self.empty_array() isolated_tax_benefit_system.update_variable(disposable_income) simulation = make_simulation(isolated_tax_benefit_system, variables, PERIOD) @@ -115,14 +114,13 @@ def formula(household, period): assert result == 0 -def test_variable_name_conflict(tax_benefit_system): - +def test_variable_name_conflict(tax_benefit_system) -> None: class disposable_income(Variable): reference = "disposable_income" - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(household, period): - return household.empty_array() + def formula(self, period): + return self.empty_array() with pytest.raises(VariableNameConflictError): tax_benefit_system.add_variable(disposable_income) diff --git a/tests/core/test_cycles.py b/tests/core/test_cycles.py index 1c4361ded2..acb08c6424 100644 --- a/tests/core/test_cycles.py +++ b/tests/core/test_cycles.py @@ -4,13 +4,14 @@ from openfisca_core import periods, tools from openfisca_core.errors import CycleError +from openfisca_core.periods import DateUnit from openfisca_core.simulations import SimulationBuilder from openfisca_core.variables import Variable @pytest.fixture def reference_period(): - return periods.period('2013-01') + return periods.period("2013-01") @pytest.fixture @@ -22,38 +23,38 @@ def simulation(tax_benefit_system): class variable1(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - return person('variable2', period) + def formula(self, period): + return self("variable2", period) class variable2(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - return person('variable1', period) + def formula(self, period): + return self("variable1", period) # 3 <--> 4 with a period offset class variable3(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - return person('variable4', period.last_month) + def formula(self, period): + return self("variable4", period.last_month) class variable4(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - return person('variable3', period) + def formula(self, period): + return self("variable3", period) # 5 -f-> 6 with a period offset @@ -61,30 +62,30 @@ def formula(person, period): class variable5(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - variable6 = person('variable6', period.last_month) + def formula(self, period): + variable6 = self("variable6", period.last_month) return 5 + variable6 class variable6(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - variable5 = person('variable5', period) + def formula(self, period): + variable5 = self("variable5", period) return 6 + variable5 class variable7(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - variable5 = person('variable5', period) + def formula(self, period): + variable5 = self("variable5", period) return 7 + variable5 @@ -92,17 +93,16 @@ def formula(person, period): class cotisation(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): + def formula(self, period): if period.start.month == 12: - return 2 * person('cotisation', period.last_month) - else: - return person.empty_array() + 1 + return 2 * self("cotisation", period.last_month) + return self.empty_array() + 1 -@pytest.fixture(scope = "module", autouse = True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +@pytest.fixture(scope="module", autouse=True) +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables( variable1, variable2, @@ -112,35 +112,38 @@ def add_variables_to_tax_benefit_system(tax_benefit_system): variable6, variable7, cotisation, - ) + ) -def test_pure_cycle(simulation, reference_period): +def test_pure_cycle(simulation, reference_period) -> None: with pytest.raises(CycleError): - simulation.calculate('variable1', period = reference_period) + simulation.calculate("variable1", period=reference_period) -def test_spirals_result_in_default_value(simulation, reference_period): - variable3 = simulation.calculate('variable3', period = reference_period) +def test_spirals_result_in_default_value(simulation, reference_period) -> None: + variable3 = simulation.calculate("variable3", period=reference_period) tools.assert_near(variable3, [0]) -def test_spiral_heuristic(simulation, reference_period): - variable5 = simulation.calculate('variable5', period = reference_period) - variable6 = simulation.calculate('variable6', period = reference_period) - variable6_last_month = simulation.calculate('variable6', reference_period.last_month) +def test_spiral_heuristic(simulation, reference_period) -> None: + variable5 = simulation.calculate("variable5", period=reference_period) + variable6 = simulation.calculate("variable6", period=reference_period) + variable6_last_month = simulation.calculate( + "variable6", + reference_period.last_month, + ) tools.assert_near(variable5, [11]) tools.assert_near(variable6, [11]) tools.assert_near(variable6_last_month, [11]) -def test_spiral_cache(simulation, reference_period): - simulation.calculate('variable7', period = reference_period) - cached_variable7 = simulation.get_holder('variable7').get_array(reference_period) +def test_spiral_cache(simulation, reference_period) -> None: + simulation.calculate("variable7", period=reference_period) + cached_variable7 = simulation.get_holder("variable7").get_array(reference_period) assert cached_variable7 is not None -def test_cotisation_1_level(simulation, reference_period): +def test_cotisation_1_level(simulation, reference_period) -> None: month = reference_period.last_month - cotisation = simulation.calculate('cotisation', period = month) + cotisation = simulation.calculate("cotisation", period=month) tools.assert_near(cotisation, [0]) diff --git a/tests/core/test_dump_restore.py b/tests/core/test_dump_restore.py index 5d377913c9..c84044165c 100644 --- a/tests/core/test_dump_restore.py +++ b/tests/core/test_dump_restore.py @@ -9,10 +9,13 @@ from openfisca_core.tools import simulation_dumper -def test_dump(tax_benefit_system): - directory = tempfile.mkdtemp(prefix = "openfisca_") - simulation = SimulationBuilder().build_from_entities(tax_benefit_system, situation_examples.couple) - calculated_value = simulation.calculate('disposable_income', '2018-01') +def test_dump(tax_benefit_system) -> None: + directory = tempfile.mkdtemp(prefix="openfisca_") + simulation = SimulationBuilder().build_from_entities( + tax_benefit_system, + situation_examples.couple, + ) + calculated_value = simulation.calculate("disposable_income", "2018-01") simulation_dumper.dump_simulation(simulation, directory) simulation_2 = simulation_dumper.restore_simulation(directory, tax_benefit_system) @@ -23,14 +26,23 @@ def test_dump(tax_benefit_system): testing.assert_array_equal(simulation.person.count, simulation_2.person.count) testing.assert_array_equal(simulation.household.ids, simulation_2.household.ids) testing.assert_array_equal(simulation.household.count, simulation_2.household.count) - testing.assert_array_equal(simulation.household.members_position, simulation_2.household.members_position) - testing.assert_array_equal(simulation.household.members_entity_id, simulation_2.household.members_entity_id) - testing.assert_array_equal(simulation.household.members_role, simulation_2.household.members_role) + testing.assert_array_equal( + simulation.household.members_position, + simulation_2.household.members_position, + ) + testing.assert_array_equal( + simulation.household.members_entity_id, + simulation_2.household.members_entity_id, + ) + testing.assert_array_equal( + simulation.household.members_role, + simulation_2.household.members_role, + ) # Check calculated values are in cache - disposable_income_holder = simulation_2.person.get_holder('disposable_income') - cached_value = disposable_income_holder.get_array('2018-01') + disposable_income_holder = simulation_2.person.get_holder("disposable_income") + cached_value = disposable_income_holder.get_array("2018-01") assert cached_value is not None testing.assert_array_equal(cached_value, calculated_value) diff --git a/tests/core/test_entities.py b/tests/core/test_entities.py index b15653b055..aba17dc4dc 100644 --- a/tests/core/test_entities.py +++ b/tests/core/test_entities.py @@ -7,17 +7,17 @@ from openfisca_core.tools import test_runner TEST_CASE = { - 'persons': {'ind0': {}, 'ind1': {}, 'ind2': {}, 'ind3': {}, 'ind4': {}, 'ind5': {}}, - 'households': { - 'h1': {'children': ['ind2', 'ind3'], 'parents': ['ind0', 'ind1']}, - 'h2': {'children': ['ind5'], 'parents': ['ind4']} - }, - } + "persons": {"ind0": {}, "ind1": {}, "ind2": {}, "ind3": {}, "ind4": {}, "ind5": {}}, + "households": { + "h1": {"children": ["ind2", "ind3"], "parents": ["ind0", "ind1"]}, + "h2": {"children": ["ind5"], "parents": ["ind4"]}, + }, +} TEST_CASE_AGES = deepcopy(TEST_CASE) AGES = [40, 37, 7, 9, 54, 20] -for (individu, age) in zip(TEST_CASE_AGES['persons'].values(), AGES): - individu['age'] = age +for individu, age in zip(TEST_CASE_AGES["persons"].values(), AGES): + individu["age"] = age FIRST_PARENT = entities.Household.FIRST_PARENT SECOND_PARENT = entities.Household.SECOND_PARENT @@ -28,22 +28,25 @@ MONTH = "2016-01" -def new_simulation(tax_benefit_system, test_case, period = MONTH): +def new_simulation(tax_benefit_system, test_case, period=MONTH): simulation_builder = SimulationBuilder() simulation_builder.set_default_period(period) return simulation_builder.build_from_entities(tax_benefit_system, test_case) -def test_role_index_and_positions(tax_benefit_system): +def test_role_index_and_positions(tax_benefit_system) -> None: simulation = new_simulation(tax_benefit_system, TEST_CASE) tools.assert_near(simulation.household.members_entity_id, [0, 0, 0, 0, 1, 1]) - assert((simulation.household.members_role == [FIRST_PARENT, SECOND_PARENT, CHILD, CHILD, FIRST_PARENT, CHILD]).all()) + assert ( + simulation.household.members_role + == [FIRST_PARENT, SECOND_PARENT, CHILD, CHILD, FIRST_PARENT, CHILD] + ).all() tools.assert_near(simulation.household.members_position, [0, 1, 2, 3, 0, 1]) - assert(simulation.person.ids == ["ind0", "ind1", "ind2", "ind3", "ind4", "ind5"]) - assert(simulation.household.ids == ['h1', 'h2']) + assert simulation.person.ids == ["ind0", "ind1", "ind2", "ind3", "ind4", "ind5"] + assert simulation.household.ids == ["h1", "h2"] -def test_entity_structure_with_constructor(tax_benefit_system): +def test_entity_structure_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: {} @@ -64,16 +67,22 @@ def test_entity_structure_with_constructor(tax_benefit_system): - claudia """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), + ) household = simulation.household tools.assert_near(household.members_entity_id, [0, 0, 1, 0, 0]) - assert((household.members_role == [FIRST_PARENT, SECOND_PARENT, FIRST_PARENT, CHILD, CHILD]).all()) + assert ( + household.members_role + == [FIRST_PARENT, SECOND_PARENT, FIRST_PARENT, CHILD, CHILD] + ).all() tools.assert_near(household.members_position, [0, 1, 0, 2, 3]) -def test_entity_variables_with_constructor(tax_benefit_system): +def test_entity_variables_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: {} @@ -98,12 +107,15 @@ def test_entity_variables_with_constructor(tax_benefit_system): 2017-06: 600 """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), + ) household = simulation.household - tools.assert_near(household('rent', "2017-06"), [800, 600]) + tools.assert_near(household("rent", "2017-06"), [800, 600]) -def test_person_variable_with_constructor(tax_benefit_system): +def test_person_variable_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: @@ -131,13 +143,16 @@ def test_person_variable_with_constructor(tax_benefit_system): - claudia """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), + ) person = simulation.person - tools.assert_near(person('salary', "2017-11"), [1500, 0, 3000, 0, 0]) - tools.assert_near(person('salary', "2017-12"), [2000, 0, 4000, 0, 0]) + tools.assert_near(person("salary", "2017-11"), [1500, 0, 3000, 0, 0]) + tools.assert_near(person("salary", "2017-12"), [2000, 0, 4000, 0, 0]) -def test_set_input_with_constructor(tax_benefit_system): +def test_set_input_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: @@ -170,136 +185,148 @@ def test_set_input_with_constructor(tax_benefit_system): - claudia """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), + ) person = simulation.person - tools.assert_near(person('salary', "2017-12"), [2000, 0, 4000, 0, 0]) - tools.assert_near(person('salary', "2017-10"), [2000, 3000, 1600, 0, 0]) + tools.assert_near(person("salary", "2017-12"), [2000, 0, 4000, 0, 0]) + tools.assert_near(person("salary", "2017-10"), [2000, 3000, 1600, 0, 0]) -def test_has_role(tax_benefit_system): +def test_has_role(tax_benefit_system) -> None: simulation = new_simulation(tax_benefit_system, TEST_CASE) individu = simulation.persons tools.assert_near(individu.has_role(CHILD), [False, False, True, True, False, True]) -def test_has_role_with_subrole(tax_benefit_system): +def test_has_role_with_subrole(tax_benefit_system) -> None: simulation = new_simulation(tax_benefit_system, TEST_CASE) individu = simulation.persons - tools.assert_near(individu.has_role(PARENT), [True, True, False, False, True, False]) - tools.assert_near(individu.has_role(FIRST_PARENT), [True, False, False, False, True, False]) - tools.assert_near(individu.has_role(SECOND_PARENT), [False, True, False, False, False, False]) - - -def test_project(tax_benefit_system): + tools.assert_near( + individu.has_role(PARENT), + [True, True, False, False, True, False], + ) + tools.assert_near( + individu.has_role(FIRST_PARENT), + [True, False, False, False, True, False], + ) + tools.assert_near( + individu.has_role(SECOND_PARENT), + [False, True, False, False, False, False], + ) + + +def test_project(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) - test_case['households']['h1']['housing_tax'] = 20000 + test_case["households"]["h1"]["housing_tax"] = 20000 simulation = new_simulation(tax_benefit_system, test_case, YEAR) household = simulation.household - housing_tax = household('housing_tax', YEAR) + housing_tax = household("housing_tax", YEAR) projected_housing_tax = household.project(housing_tax) tools.assert_near(projected_housing_tax, [20000, 20000, 20000, 20000, 0, 0]) - housing_tax_projected_on_parents = household.project(housing_tax, role = PARENT) + housing_tax_projected_on_parents = household.project(housing_tax, role=PARENT) tools.assert_near(housing_tax_projected_on_parents, [20000, 20000, 0, 0, 0, 0]) -def test_implicit_projection(tax_benefit_system): +def test_implicit_projection(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) - test_case['households']['h1']['housing_tax'] = 20000 + test_case["households"]["h1"]["housing_tax"] = 20000 simulation = new_simulation(tax_benefit_system, test_case, YEAR) individu = simulation.person - housing_tax = individu.household('housing_tax', YEAR) + housing_tax = individu.household("housing_tax", YEAR) tools.assert_near(housing_tax, [20000, 20000, 20000, 20000, 0, 0]) -def test_sum(tax_benefit_system): +def test_sum(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) - test_case['persons']['ind0']['salary'] = 1000 - test_case['persons']['ind1']['salary'] = 1500 - test_case['persons']['ind4']['salary'] = 3000 - test_case['persons']['ind5']['salary'] = 500 + test_case["persons"]["ind0"]["salary"] = 1000 + test_case["persons"]["ind1"]["salary"] = 1500 + test_case["persons"]["ind4"]["salary"] = 3000 + test_case["persons"]["ind5"]["salary"] = 500 simulation = new_simulation(tax_benefit_system, test_case, MONTH) household = simulation.household - salary = household.members('salary', "2016-01") + salary = household.members("salary", "2016-01") total_salary_by_household = household.sum(salary) tools.assert_near(total_salary_by_household, [2500, 3500]) - total_salary_parents_by_household = household.sum(salary, role = PARENT) + total_salary_parents_by_household = household.sum(salary, role=PARENT) tools.assert_near(total_salary_parents_by_household, [2500, 3000]) -def test_any(tax_benefit_system): +def test_any(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - age = household.members('age', period = MONTH) - condition_age = (age <= 18) + age = household.members("age", period=MONTH) + condition_age = age <= 18 has_household_member_with_age_inf_18 = household.any(condition_age) tools.assert_near(has_household_member_with_age_inf_18, [True, False]) - condition_age_2 = (age > 18) - has_household_CHILD_with_age_sup_18 = household.any(condition_age_2, role = CHILD) + condition_age_2 = age > 18 + has_household_CHILD_with_age_sup_18 = household.any(condition_age_2, role=CHILD) tools.assert_near(has_household_CHILD_with_age_sup_18, [False, True]) -def test_all(tax_benefit_system): +def test_all(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - age = household.members('age', period = MONTH) + age = household.members("age", period=MONTH) - condition_age = (age >= 18) + condition_age = age >= 18 all_persons_age_sup_18 = household.all(condition_age) tools.assert_near(all_persons_age_sup_18, [False, True]) - all_parents_age_sup_18 = household.all(condition_age, role = PARENT) + all_parents_age_sup_18 = household.all(condition_age, role=PARENT) tools.assert_near(all_parents_age_sup_18, [True, True]) -def test_max(tax_benefit_system): +def test_max(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - age = household.members('age', period = MONTH) + age = household.members("age", period=MONTH) age_max = household.max(age) tools.assert_near(age_max, [40, 54]) - age_max_child = household.max(age, role = CHILD) + age_max_child = household.max(age, role=CHILD) tools.assert_near(age_max_child, [9, 20]) -def test_min(tax_benefit_system): +def test_min(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - age = household.members('age', period = MONTH) + age = household.members("age", period=MONTH) age_min = household.min(age) tools.assert_near(age_min, [7, 20]) - age_min_parents = household.min(age, role = PARENT) + age_min_parents = household.min(age, role=PARENT) tools.assert_near(age_min_parents, [37, 54]) -def test_value_nth_person(tax_benefit_system): +def test_value_nth_person(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - array = household.members('age', MONTH) + array = household.members("age", MONTH) result0 = household.value_nth_person(0, array, default=-1) tools.assert_near(result0, [40, 54]) @@ -314,141 +341,157 @@ def test_value_nth_person(tax_benefit_system): tools.assert_near(result3, [9, -1]) -def test_rank(tax_benefit_system): +def test_rank(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) person = simulation.person - age = person('age', MONTH) # [40, 37, 7, 9, 54, 20] + age = person("age", MONTH) # [40, 37, 7, 9, 54, 20] rank = person.get_rank(person.household, age) tools.assert_near(rank, [3, 2, 0, 1, 1, 0]) - rank_in_siblings = person.get_rank(person.household, - age, condition = person.has_role(entities.Household.CHILD)) + rank_in_siblings = person.get_rank( + person.household, + -age, + condition=person.has_role(entities.Household.CHILD), + ) tools.assert_near(rank_in_siblings, [-1, -1, 1, 0, -1, 0]) -def test_partner(tax_benefit_system): +def test_partner(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) - test_case['persons']['ind0']['salary'] = 1000 - test_case['persons']['ind1']['salary'] = 1500 - test_case['persons']['ind4']['salary'] = 3000 - test_case['persons']['ind5']['salary'] = 500 + test_case["persons"]["ind0"]["salary"] = 1000 + test_case["persons"]["ind1"]["salary"] = 1500 + test_case["persons"]["ind4"]["salary"] = 3000 + test_case["persons"]["ind5"]["salary"] = 500 simulation = new_simulation(tax_benefit_system, test_case) persons = simulation.persons - salary = persons('salary', period = MONTH) + salary = persons("salary", period=MONTH) salary_second_parent = persons.value_from_partner(salary, persons.household, PARENT) tools.assert_near(salary_second_parent, [1500, 1000, 0, 0, 0, 0]) -def test_value_from_first_person(tax_benefit_system): +def test_value_from_first_person(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) - test_case['persons']['ind0']['salary'] = 1000 - test_case['persons']['ind1']['salary'] = 1500 - test_case['persons']['ind4']['salary'] = 3000 - test_case['persons']['ind5']['salary'] = 500 + test_case["persons"]["ind0"]["salary"] = 1000 + test_case["persons"]["ind1"]["salary"] = 1500 + test_case["persons"]["ind4"]["salary"] = 3000 + test_case["persons"]["ind5"]["salary"] = 500 simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - salaries = household.members('salary', period = MONTH) + salaries = household.members("salary", period=MONTH) salary_first_person = household.value_from_first_person(salaries) tools.assert_near(salary_first_person, [1000, 3000]) -def test_projectors_methods(tax_benefit_system): - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, situation_examples.couple) +def test_projectors_methods(tax_benefit_system) -> None: + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + situation_examples.couple, + ) household = simulation.household person = simulation.person projected_vector = household.first_parent.has_role(entities.Household.FIRST_PARENT) - assert(len(projected_vector) == 1) # Must be of a household dimension + assert len(projected_vector) == 1 # Must be of a household dimension - salary_i = person.household.members('salary', '2017-01') - assert(len(person.household.sum(salary_i)) == 2) # Must be of a person dimension - assert(len(person.household.max(salary_i)) == 2) # Must be of a person dimension - assert(len(person.household.min(salary_i)) == 2) # Must be of a person dimension - assert(len(person.household.all(salary_i)) == 2) # Must be of a person dimension - assert(len(person.household.any(salary_i)) == 2) # Must be of a person dimension - assert(len(household.first_parent.get_rank(household, salary_i)) == 1) # Must be of a person dimension + salary_i = person.household.members("salary", "2017-01") + assert len(person.household.sum(salary_i)) == 2 # Must be of a person dimension + assert len(person.household.max(salary_i)) == 2 # Must be of a person dimension + assert len(person.household.min(salary_i)) == 2 # Must be of a person dimension + assert len(person.household.all(salary_i)) == 2 # Must be of a person dimension + assert len(person.household.any(salary_i)) == 2 # Must be of a person dimension + assert ( + len(household.first_parent.get_rank(household, salary_i)) == 1 + ) # Must be of a person dimension -def test_sum_following_bug_ipp_1(tax_benefit_system): +def test_sum_following_bug_ipp_1(tax_benefit_system) -> None: test_case = { - 'persons': {'ind0': {}, 'ind1': {}, 'ind2': {}, 'ind3': {}}, - 'households': { - 'h1': {'parents': ['ind0']}, - 'h2': {'parents': ['ind1'], 'children': ['ind2', 'ind3']} - }, - } - test_case['persons']['ind0']['salary'] = 2000 - test_case['persons']['ind1']['salary'] = 2000 - test_case['persons']['ind2']['salary'] = 1000 - test_case['persons']['ind3']['salary'] = 1000 + "persons": {"ind0": {}, "ind1": {}, "ind2": {}, "ind3": {}}, + "households": { + "h1": {"parents": ["ind0"]}, + "h2": {"parents": ["ind1"], "children": ["ind2", "ind3"]}, + }, + } + test_case["persons"]["ind0"]["salary"] = 2000 + test_case["persons"]["ind1"]["salary"] = 2000 + test_case["persons"]["ind2"]["salary"] = 1000 + test_case["persons"]["ind3"]["salary"] = 1000 simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - eligible_i = household.members('salary', period = MONTH) < 1500 - nb_eligibles_by_household = household.sum(eligible_i, role = CHILD) + eligible_i = household.members("salary", period=MONTH) < 1500 + nb_eligibles_by_household = household.sum(eligible_i, role=CHILD) tools.assert_near(nb_eligibles_by_household, [0, 2]) -def test_sum_following_bug_ipp_2(tax_benefit_system): +def test_sum_following_bug_ipp_2(tax_benefit_system) -> None: test_case = { - 'persons': {'ind0': {}, 'ind1': {}, 'ind2': {}, 'ind3': {}}, - 'households': { - 'h1': {'parents': ['ind1'], 'children': ['ind2', 'ind3']}, - 'h2': {'parents': ['ind0']}, - }, - } - test_case['persons']['ind0']['salary'] = 2000 - test_case['persons']['ind1']['salary'] = 2000 - test_case['persons']['ind2']['salary'] = 1000 - test_case['persons']['ind3']['salary'] = 1000 + "persons": {"ind0": {}, "ind1": {}, "ind2": {}, "ind3": {}}, + "households": { + "h1": {"parents": ["ind1"], "children": ["ind2", "ind3"]}, + "h2": {"parents": ["ind0"]}, + }, + } + test_case["persons"]["ind0"]["salary"] = 2000 + test_case["persons"]["ind1"]["salary"] = 2000 + test_case["persons"]["ind2"]["salary"] = 1000 + test_case["persons"]["ind3"]["salary"] = 1000 simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - eligible_i = household.members('salary', period = MONTH) < 1500 - nb_eligibles_by_household = household.sum(eligible_i, role = CHILD) + eligible_i = household.members("salary", period=MONTH) < 1500 + nb_eligibles_by_household = household.sum(eligible_i, role=CHILD) tools.assert_near(nb_eligibles_by_household, [2, 0]) -def test_get_memory_usage(tax_benefit_system): +def test_get_memory_usage(tax_benefit_system) -> None: test_case = deepcopy(situation_examples.single) test_case["persons"]["Alicia"]["salary"] = {"2017-01": 0} simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_case) - simulation.calculate('disposable_income', '2017-01') - memory_usage = simulation.person.get_memory_usage(variables = ['salary']) - assert(memory_usage['total_nb_bytes'] > 0) - assert(len(memory_usage['by_variable']) == 1) + simulation.calculate("disposable_income", "2017-01") + memory_usage = simulation.person.get_memory_usage(variables=["salary"]) + assert memory_usage["total_nb_bytes"] > 0 + assert len(memory_usage["by_variable"]) == 1 -def test_unordered_persons(tax_benefit_system): +def test_unordered_persons(tax_benefit_system) -> None: test_case = { - 'persons': {'ind4': {}, 'ind3': {}, 'ind1': {}, 'ind2': {}, 'ind5': {}, 'ind0': {}}, - 'households': { - 'h1': {'children': ['ind2', 'ind3'], 'parents': ['ind0', 'ind1']}, - 'h2': {'children': ['ind5'], 'parents': ['ind4']} - }, - } + "persons": { + "ind4": {}, + "ind3": {}, + "ind1": {}, + "ind2": {}, + "ind5": {}, + "ind0": {}, + }, + "households": { + "h1": {"children": ["ind2", "ind3"], "parents": ["ind0", "ind1"]}, + "h2": {"children": ["ind5"], "parents": ["ind4"]}, + }, + } # 1st family - test_case['persons']['ind0']['salary'] = 1000 - test_case['persons']['ind1']['salary'] = 1500 - test_case['persons']['ind2']['salary'] = 20 - test_case['households']['h1']['accommodation_size'] = 160 + test_case["persons"]["ind0"]["salary"] = 1000 + test_case["persons"]["ind1"]["salary"] = 1500 + test_case["persons"]["ind2"]["salary"] = 20 + test_case["households"]["h1"]["accommodation_size"] = 160 # 2nd family - test_case['persons']['ind4']['salary'] = 3000 - test_case['persons']['ind5']['salary'] = 500 - test_case['households']['h2']['accommodation_size'] = 60 + test_case["persons"]["ind4"]["salary"] = 3000 + test_case["persons"]["ind5"]["salary"] = 500 + test_case["households"]["h2"]["accommodation_size"] = 60 # household.members_entity_id == [1, 0, 0, 0, 1, 0] @@ -456,8 +499,8 @@ def test_unordered_persons(tax_benefit_system): household = simulation.household person = simulation.person - salary = household.members('salary', "2016-01") # [ 3000, 0, 1500, 20, 500, 1000 ] - accommodation_size = household('accommodation_size', "2016-01") # [ 160, 60 ] + salary = household.members("salary", "2016-01") # [ 3000, 0, 1500, 20, 500, 1000 ] + accommodation_size = household("accommodation_size", "2016-01") # [ 160, 60 ] # Aggregation/Projection persons -> entity @@ -466,30 +509,42 @@ def test_unordered_persons(tax_benefit_system): tools.assert_near(household.min(salary), [0, 500]) tools.assert_near(household.all(salary > 0), [False, True]) tools.assert_near(household.any(salary > 2000), [False, True]) - tools.assert_near(household.first_person('salary', "2016-01"), [0, 3000]) - tools.assert_near(household.first_parent('salary', "2016-01"), [1000, 3000]) - tools.assert_near(household.second_parent('salary', "2016-01"), [1500, 0]) - tools.assert_near(person.value_from_partner(salary, person.household, PARENT), [0, 0, 1000, 0, 0, 1500]) - - tools.assert_near(household.sum(salary, role = PARENT), [2500, 3000]) - tools.assert_near(household.sum(salary, role = CHILD), [20, 500]) - tools.assert_near(household.max(salary, role = PARENT), [1500, 3000]) - tools.assert_near(household.max(salary, role = CHILD), [20, 500]) - tools.assert_near(household.min(salary, role = PARENT), [1000, 3000]) - tools.assert_near(household.min(salary, role = CHILD), [0, 500]) - tools.assert_near(household.all(salary > 0, role = PARENT), [True, True]) - tools.assert_near(household.all(salary > 0, role = CHILD), [False, True]) - tools.assert_near(household.any(salary < 1500, role = PARENT), [True, False]) - tools.assert_near(household.any(salary > 200, role = CHILD), [False, True]) + tools.assert_near(household.first_person("salary", "2016-01"), [0, 3000]) + tools.assert_near(household.first_parent("salary", "2016-01"), [1000, 3000]) + tools.assert_near(household.second_parent("salary", "2016-01"), [1500, 0]) + tools.assert_near( + person.value_from_partner(salary, person.household, PARENT), + [0, 0, 1000, 0, 0, 1500], + ) + + tools.assert_near(household.sum(salary, role=PARENT), [2500, 3000]) + tools.assert_near(household.sum(salary, role=CHILD), [20, 500]) + tools.assert_near(household.max(salary, role=PARENT), [1500, 3000]) + tools.assert_near(household.max(salary, role=CHILD), [20, 500]) + tools.assert_near(household.min(salary, role=PARENT), [1000, 3000]) + tools.assert_near(household.min(salary, role=CHILD), [0, 500]) + tools.assert_near(household.all(salary > 0, role=PARENT), [True, True]) + tools.assert_near(household.all(salary > 0, role=CHILD), [False, True]) + tools.assert_near(household.any(salary < 1500, role=PARENT), [True, False]) + tools.assert_near(household.any(salary > 200, role=CHILD), [False, True]) # nb_persons tools.assert_near(household.nb_persons(), [4, 2]) - tools.assert_near(household.nb_persons(role = PARENT), [2, 1]) - tools.assert_near(household.nb_persons(role = CHILD), [2, 1]) + tools.assert_near(household.nb_persons(role=PARENT), [2, 1]) + tools.assert_near(household.nb_persons(role=CHILD), [2, 1]) # Projection entity -> persons - tools.assert_near(household.project(accommodation_size), [60, 160, 160, 160, 60, 160]) - tools.assert_near(household.project(accommodation_size, role = PARENT), [60, 0, 160, 0, 0, 160]) - tools.assert_near(household.project(accommodation_size, role = CHILD), [0, 160, 0, 160, 60, 0]) + tools.assert_near( + household.project(accommodation_size), + [60, 160, 160, 160, 60, 160], + ) + tools.assert_near( + household.project(accommodation_size, role=PARENT), + [60, 0, 160, 0, 0, 160], + ) + tools.assert_near( + household.project(accommodation_size, role=CHILD), + [0, 160, 0, 160, 60, 0], + ) diff --git a/tests/core/test_extensions.py b/tests/core/test_extensions.py index 5c3da81d66..4854815ac3 100644 --- a/tests/core/test_extensions.py +++ b/tests/core/test_extensions.py @@ -1,24 +1,26 @@ import pytest -def test_load_extension(tax_benefit_system): +def test_load_extension(tax_benefit_system) -> None: tbs = tax_benefit_system.clone() - assert tbs.get_variable('local_town_child_allowance') is None + assert tbs.get_variable("local_town_child_allowance") is None - tbs.load_extension('openfisca_extension_template') + tbs.load_extension("openfisca_extension_template") - assert tbs.get_variable('local_town_child_allowance') is not None - assert tax_benefit_system.get_variable('local_town_child_allowance') is None + assert tbs.get_variable("local_town_child_allowance") is not None + assert tax_benefit_system.get_variable("local_town_child_allowance") is None -def test_access_to_parameters(tax_benefit_system): +def test_access_to_parameters(tax_benefit_system) -> None: tbs = tax_benefit_system.clone() - tbs.load_extension('openfisca_extension_template') + tbs.load_extension("openfisca_extension_template") - assert tbs.parameters('2016-01').local_town.child_allowance.amount == 100.0 - assert tbs.parameters.local_town.child_allowance.amount('2016-01') == 100.0 + assert tbs.parameters("2016-01").local_town.child_allowance.amount == 100.0 + assert tbs.parameters.local_town.child_allowance.amount("2016-01") == 100.0 -def test_failure_to_load_extension_when_directory_doesnt_exist(tax_benefit_system): +def test_failure_to_load_extension_when_directory_doesnt_exist( + tax_benefit_system, +) -> None: with pytest.raises(ValueError): - tax_benefit_system.load_extension('/this/is/not/a/real/path') + tax_benefit_system.load_extension("/this/is/not/a/real/path") diff --git a/tests/core/test_formulas.py b/tests/core/test_formulas.py index 8851671755..32e6fd35e7 100644 --- a/tests/core/test_formulas.py +++ b/tests/core/test_formulas.py @@ -1,102 +1,108 @@ import numpy +from pytest import approx, fixture from openfisca_country_template import entities -from openfisca_core import commons, periods +from openfisca_core import commons +from openfisca_core.periods import DateUnit from openfisca_core.simulations import SimulationBuilder from openfisca_core.variables import Variable -from pytest import fixture, approx - class choice(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH class uses_multiplication(Variable): value_type = int entity = entities.Person - label = 'Variable with formula that uses multiplication' - definition_period = periods.MONTH + label = "Variable with formula that uses multiplication" + definition_period = DateUnit.MONTH - def formula(person, period): - choice = person('choice', period) - result = (choice == 1) * 80 + (choice == 2) * 90 - return result + def formula(self, period): + choice = self("choice", period) + return (choice == 1) * 80 + (choice == 2) * 90 class returns_scalar(Variable): value_type = int entity = entities.Person - label = 'Variable with formula that returns a scalar value' - definition_period = periods.MONTH + label = "Variable with formula that returns a scalar value" + definition_period = DateUnit.MONTH - def formula(person, period): + def formula(self, period) -> int: return 666 class uses_switch(Variable): value_type = int entity = entities.Person - label = 'Variable with formula that uses switch' - definition_period = periods.MONTH + label = "Variable with formula that uses switch" + definition_period = DateUnit.MONTH - def formula(person, period): - choice = person('choice', period) - result = commons.switch( + def formula(self, period): + choice = self("choice", period) + return commons.switch( choice, { 1: 80, 2: 90, - }, - ) - return result + }, + ) -@fixture(scope = "module", autouse = True) -def add_variables_to_tax_benefit_system(tax_benefit_system): - tax_benefit_system.add_variables(choice, uses_multiplication, uses_switch, returns_scalar) +@fixture(scope="module", autouse=True) +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: + tax_benefit_system.add_variables( + choice, + uses_multiplication, + uses_switch, + returns_scalar, + ) @fixture -def month(): - return '2013-01' +def month() -> str: + return "2013-01" @fixture def simulation(tax_benefit_system, month): simulation_builder = SimulationBuilder() simulation_builder.default_period = month - simulation = simulation_builder.build_from_variables(tax_benefit_system, {'choice': numpy.random.randint(2, size = 1000) + 1}) + simulation = simulation_builder.build_from_variables( + tax_benefit_system, + {"choice": numpy.random.randint(2, size=1000) + 1}, + ) simulation.debug = True return simulation -def test_switch(simulation, month): - uses_switch = simulation.calculate('uses_switch', period = month) +def test_switch(simulation, month) -> None: + uses_switch = simulation.calculate("uses_switch", period=month) assert isinstance(uses_switch, numpy.ndarray) -def test_multiplication(simulation, month): - uses_multiplication = simulation.calculate('uses_multiplication', period = month) +def test_multiplication(simulation, month) -> None: + uses_multiplication = simulation.calculate("uses_multiplication", period=month) assert isinstance(uses_multiplication, numpy.ndarray) -def test_broadcast_scalar(simulation, month): - array_value = simulation.calculate('returns_scalar', period = month) +def test_broadcast_scalar(simulation, month) -> None: + array_value = simulation.calculate("returns_scalar", period=month) assert isinstance(array_value, numpy.ndarray) assert array_value == approx(numpy.repeat(666, 1000)) -def test_compare_multiplication_and_switch(simulation, month): - uses_multiplication = simulation.calculate('uses_multiplication', period = month) - uses_switch = simulation.calculate('uses_switch', period = month) +def test_compare_multiplication_and_switch(simulation, month) -> None: + uses_multiplication = simulation.calculate("uses_multiplication", period=month) + uses_switch = simulation.calculate("uses_switch", period=month) assert numpy.all(uses_switch == uses_multiplication) -def test_group_encapsulation(): +def test_group_encapsulation() -> None: """Projects a calculation to all members of an entity. When a household contains more than one family @@ -104,37 +110,41 @@ def test_group_encapsulation(): And calculations are projected to all the member families. """ - from openfisca_core.taxbenefitsystems import TaxBenefitSystem from openfisca_core.entities import build_entity - from openfisca_core.periods import ETERNITY + from openfisca_core.periods import DateUnit + from openfisca_core.taxbenefitsystems import TaxBenefitSystem person_entity = build_entity( key="person", plural="people", label="A person", is_person=True, - ) + ) family_entity = build_entity( key="family", plural="families", label="A family (all members in the same household)", containing_entities=["household"], - roles=[{ - "key": "member", - "plural": "members", - "label": "Member", - }] - ) + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) household_entity = build_entity( key="household", plural="households", label="A household, containing one or more families", - roles=[{ - "key": "member", - "plural": "members", - "label": "Member", - }] - ) + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) entities = [person_entity, family_entity, household_entity] @@ -143,40 +153,35 @@ def test_group_encapsulation(): class household_level_variable(Variable): value_type = int entity = household_entity - definition_period = ETERNITY + definition_period = DateUnit.ETERNITY class projected_family_level_variable(Variable): value_type = int entity = family_entity - definition_period = ETERNITY + definition_period = DateUnit.ETERNITY - def formula(family, period): - return family.household("household_level_variable", period) + def formula(self, period): + return self.household("household_level_variable", period) system.add_variables(household_level_variable, projected_family_level_variable) - simulation = SimulationBuilder().build_from_dict(system, { - "people": { - "person1": {}, - "person2": {}, - "person3": {} + simulation = SimulationBuilder().build_from_dict( + system, + { + "people": {"person1": {}, "person2": {}, "person3": {}}, + "families": { + "family1": {"members": ["person1", "person2"]}, + "family2": {"members": ["person3"]}, }, - "families": { - "family1": { - "members": ["person1", "person2"] - }, - "family2": { - "members": ["person3"] + "households": { + "household1": { + "members": ["person1", "person2", "person3"], + "household_level_variable": {"eternity": 5}, }, }, - "households": { - "household1": { - "members": ["person1", "person2", "person3"], - "household_level_variable": { - "eternity": 5 - } - } - } - }) - - assert (simulation.calculate("projected_family_level_variable", "2021-01-01") == 5).all() + }, + ) + + assert ( + simulation.calculate("projected_family_level_variable", "2021-01-01") == 5 + ).all() diff --git a/tests/core/test_holders.py b/tests/core/test_holders.py index d06ce34f04..b784aea41b 100644 --- a/tests/core/test_holders.py +++ b/tests/core/test_holders.py @@ -1,6 +1,5 @@ -import pytest - import numpy +import pytest from openfisca_country_template import situation_examples from openfisca_country_template.variables import housing @@ -8,178 +7,212 @@ from openfisca_core import holders, periods, tools from openfisca_core.errors import PeriodMismatchError from openfisca_core.experimental import MemoryConfig -from openfisca_core.simulations import SimulationBuilder from openfisca_core.holders import Holder +from openfisca_core.periods import DateUnit +from openfisca_core.simulations import SimulationBuilder @pytest.fixture def single(tax_benefit_system): - return \ - SimulationBuilder() \ - .build_from_entities(tax_benefit_system, situation_examples.single) + return SimulationBuilder().build_from_entities( + tax_benefit_system, + situation_examples.single, + ) @pytest.fixture def couple(tax_benefit_system): - return \ - SimulationBuilder(). \ - build_from_entities(tax_benefit_system, situation_examples.couple) + return SimulationBuilder().build_from_entities( + tax_benefit_system, + situation_examples.couple, + ) -period = periods.period('2017-12') +period = periods.period("2017-12") -def test_set_input_enum_string(couple): +def test_set_input_enum_string(couple) -> None: simulation = couple - status_occupancy = numpy.asarray(['free_lodger']) - simulation.household.get_holder('housing_occupancy_status').set_input(period, status_occupancy) - result = simulation.calculate('housing_occupancy_status', period) + status_occupancy = numpy.asarray(["free_lodger"]) + simulation.household.get_holder("housing_occupancy_status").set_input( + period, + status_occupancy, + ) + result = simulation.calculate("housing_occupancy_status", period) assert result == housing.HousingOccupancyStatus.free_lodger -def test_set_input_enum_int(couple): +def test_set_input_enum_int(couple) -> None: simulation = couple - status_occupancy = numpy.asarray([2], dtype = numpy.int16) - simulation.household.get_holder('housing_occupancy_status').set_input(period, status_occupancy) - result = simulation.calculate('housing_occupancy_status', period) + status_occupancy = numpy.asarray([2], dtype=numpy.int16) + simulation.household.get_holder("housing_occupancy_status").set_input( + period, + status_occupancy, + ) + result = simulation.calculate("housing_occupancy_status", period) assert result == housing.HousingOccupancyStatus.free_lodger -def test_set_input_enum_item(couple): +def test_set_input_enum_item(couple) -> None: simulation = couple status_occupancy = numpy.asarray([housing.HousingOccupancyStatus.free_lodger]) - simulation.household.get_holder('housing_occupancy_status').set_input(period, status_occupancy) - result = simulation.calculate('housing_occupancy_status', period) + simulation.household.get_holder("housing_occupancy_status").set_input( + period, + status_occupancy, + ) + result = simulation.calculate("housing_occupancy_status", period) assert result == housing.HousingOccupancyStatus.free_lodger -def test_yearly_input_month_variable(couple): +def test_yearly_input_month_variable(couple) -> None: with pytest.raises(PeriodMismatchError) as error: - couple.set_input('rent', 2019, 3000) - assert 'Unable to set a value for variable "rent" for year-long period' in error.value.message + couple.set_input("rent", 2019, 3000) + assert ( + 'Unable to set a value for variable "rent" for year-long period' + in error.value.message + ) -def test_3_months_input_month_variable(couple): +def test_3_months_input_month_variable(couple) -> None: with pytest.raises(PeriodMismatchError) as error: - couple.set_input('rent', 'month:2019-01:3', 3000) - assert 'Unable to set a value for variable "rent" for 3-months-long period' in error.value.message + couple.set_input("rent", "month:2019-01:3", 3000) + assert ( + 'Unable to set a value for variable "rent" for 3-months-long period' + in error.value.message + ) -def test_month_input_year_variable(couple): +def test_month_input_year_variable(couple) -> None: with pytest.raises(PeriodMismatchError) as error: - couple.set_input('housing_tax', '2019-01', 3000) - assert 'Unable to set a value for variable "housing_tax" for month-long period' in error.value.message + couple.set_input("housing_tax", "2019-01", 3000) + assert ( + 'Unable to set a value for variable "housing_tax" for month-long period' + in error.value.message + ) -def test_enum_dtype(couple): +def test_enum_dtype(couple) -> None: simulation = couple - status_occupancy = numpy.asarray([2], dtype = numpy.int16) - simulation.household.get_holder('housing_occupancy_status').set_input(period, status_occupancy) - result = simulation.calculate('housing_occupancy_status', period) + status_occupancy = numpy.asarray([2], dtype=numpy.int16) + simulation.household.get_holder("housing_occupancy_status").set_input( + period, + status_occupancy, + ) + result = simulation.calculate("housing_occupancy_status", period) assert result.dtype.kind is not None -def test_permanent_variable_empty(single): +def test_permanent_variable_empty(single) -> None: simulation = single - holder = simulation.person.get_holder('birth') + holder = simulation.person.get_holder("birth") assert holder.get_array(None) is None -def test_permanent_variable_filled(single): +def test_permanent_variable_filled(single) -> None: simulation = single - holder = simulation.person.get_holder('birth') - value = numpy.asarray(['1980-01-01'], dtype = holder.variable.dtype) - holder.set_input(periods.period(periods.ETERNITY), value) + holder = simulation.person.get_holder("birth") + value = numpy.asarray(["1980-01-01"], dtype=holder.variable.dtype) + holder.set_input(periods.period(DateUnit.ETERNITY), value) assert holder.get_array(None) == value - assert holder.get_array(periods.ETERNITY) == value - assert holder.get_array('2016-01') == value + assert holder.get_array(DateUnit.ETERNITY) == value + assert holder.get_array("2016-01") == value -def test_delete_arrays(single): +def test_delete_arrays(single) -> None: simulation = single - salary_holder = simulation.person.get_holder('salary') + salary_holder = simulation.person.get_holder("salary") salary_holder.set_input(periods.period(2017), numpy.asarray([30000])) salary_holder.set_input(periods.period(2018), numpy.asarray([60000])) - assert simulation.person('salary', '2017-01') == 2500 - assert simulation.person('salary', '2018-01') == 5000 - salary_holder.delete_arrays(period = 2018) + assert simulation.person("salary", "2017-01") == 2500 + assert simulation.person("salary", "2018-01") == 5000 + salary_holder.delete_arrays(period=2018) + + salary_array = simulation.get_array("salary", "2017-01") + assert salary_array is not None + salary_array = simulation.get_array("salary", "2018-01") + assert salary_array is None + salary_holder.set_input(periods.period(2018), numpy.asarray([15000])) - assert simulation.person('salary', '2017-01') == 2500 - assert simulation.person('salary', '2018-01') == 1250 + assert simulation.person("salary", "2017-01") == 2500 + assert simulation.person("salary", "2018-01") == 1250 -def test_get_memory_usage(single): +def test_get_memory_usage(single) -> None: simulation = single - salary_holder = simulation.person.get_holder('salary') + salary_holder = simulation.person.get_holder("salary") memory_usage = salary_holder.get_memory_usage() - assert memory_usage['total_nb_bytes'] == 0 + assert memory_usage["total_nb_bytes"] == 0 salary_holder.set_input(periods.period(2017), numpy.asarray([30000])) memory_usage = salary_holder.get_memory_usage() - assert memory_usage['nb_cells_by_array'] == 1 - assert memory_usage['cell_size'] == 4 # float 32 - assert memory_usage['nb_cells_by_array'] == 1 # one person - assert memory_usage['nb_arrays'] == 12 # 12 months - assert memory_usage['total_nb_bytes'] == 4 * 12 * 1 + assert memory_usage["nb_cells_by_array"] == 1 + assert memory_usage["cell_size"] == 4 # float 32 + assert memory_usage["nb_cells_by_array"] == 1 # one person + assert memory_usage["nb_arrays"] == 12 # 12 months + assert memory_usage["total_nb_bytes"] == 4 * 12 * 1 -def test_get_memory_usage_with_trace(single): +def test_get_memory_usage_with_trace(single) -> None: simulation = single simulation.trace = True - salary_holder = simulation.person.get_holder('salary') + salary_holder = simulation.person.get_holder("salary") salary_holder.set_input(periods.period(2017), numpy.asarray([30000])) - simulation.calculate('salary', '2017-01') - simulation.calculate('salary', '2017-01') - simulation.calculate('salary', '2017-02') - simulation.calculate_add('salary', '2017') # 12 calculations + simulation.calculate("salary", "2017-01") + simulation.calculate("salary", "2017-01") + simulation.calculate("salary", "2017-02") + simulation.calculate_add("salary", "2017") # 12 calculations memory_usage = salary_holder.get_memory_usage() - assert memory_usage['nb_requests'] == 15 - assert memory_usage['nb_requests_by_array'] == 1.25 # 15 calculations / 12 arrays + assert memory_usage["nb_requests"] == 15 + assert memory_usage["nb_requests_by_array"] == 1.25 # 15 calculations / 12 arrays -def test_set_input_dispatch_by_period(single): +def test_set_input_dispatch_by_period(single) -> None: simulation = single - variable = simulation.tax_benefit_system.get_variable('housing_occupancy_status') + variable = simulation.tax_benefit_system.get_variable("housing_occupancy_status") entity = simulation.household holder = Holder(variable, entity) - holders.set_input_dispatch_by_period(holder, periods.period(2019), 'owner') - assert holder.get_array('2019-01') == holder.get_array('2019-12') # Check the feature - assert holder.get_array('2019-01') is holder.get_array('2019-12') # Check that the vectors are the same in memory, to avoid duplication + holders.set_input_dispatch_by_period(holder, periods.period(2019), "owner") + assert holder.get_array("2019-01") == holder.get_array( + "2019-12", + ) # Check the feature + assert holder.get_array("2019-01") is holder.get_array( + "2019-12", + ) # Check that the vectors are the same in memory, to avoid duplication -force_storage_on_disk = MemoryConfig(max_memory_occupation = 0) +force_storage_on_disk = MemoryConfig(max_memory_occupation=0) -def test_delete_arrays_on_disk(single): +def test_delete_arrays_on_disk(single) -> None: simulation = single simulation.memory_config = force_storage_on_disk - salary_holder = simulation.person.get_holder('salary') + salary_holder = simulation.person.get_holder("salary") salary_holder.set_input(periods.period(2017), numpy.asarray([30000])) salary_holder.set_input(periods.period(2018), numpy.asarray([60000])) - assert simulation.person('salary', '2017-01') == 2500 - assert simulation.person('salary', '2018-01') == 5000 - salary_holder.delete_arrays(period = 2018) + assert simulation.person("salary", "2017-01") == 2500 + assert simulation.person("salary", "2018-01") == 5000 + salary_holder.delete_arrays(period=2018) salary_holder.set_input(periods.period(2018), numpy.asarray([15000])) - assert simulation.person('salary', '2017-01') == 2500 - assert simulation.person('salary', '2018-01') == 1250 + assert simulation.person("salary", "2017-01") == 2500 + assert simulation.person("salary", "2018-01") == 1250 -def test_cache_disk(couple): +def test_cache_disk(couple) -> None: simulation = couple simulation.memory_config = force_storage_on_disk - month = periods.period('2017-01') - holder = simulation.person.get_holder('disposable_income') + month = periods.period("2017-01") + holder = simulation.person.get_holder("disposable_income") data = numpy.asarray([2000, 3000]) holder.put_in_cache(data, month) stored_data = holder.get_array(month) tools.assert_near(data, stored_data) -def test_known_periods(couple): +def test_known_periods(couple) -> None: simulation = couple simulation.memory_config = force_storage_on_disk - month = periods.period('2017-01') - month_2 = periods.period('2017-02') - holder = simulation.person.get_holder('disposable_income') + month = periods.period("2017-01") + month_2 = periods.period("2017-02") + holder = simulation.person.get_holder("disposable_income") data = numpy.asarray([2000, 3000]) holder.put_in_cache(data, month) holder._memory_storage.put(data, month_2) @@ -187,28 +220,34 @@ def test_known_periods(couple): assert sorted(holder.get_known_periods()), [month == month_2] -def test_cache_enum_on_disk(single): +def test_cache_enum_on_disk(single) -> None: simulation = single simulation.memory_config = force_storage_on_disk - month = periods.period('2017-01') - simulation.calculate('housing_occupancy_status', month) # First calculation - housing_occupancy_status = simulation.calculate('housing_occupancy_status', month) # Read from cache + month = periods.period("2017-01") + simulation.calculate("housing_occupancy_status", month) # First calculation + housing_occupancy_status = simulation.calculate( + "housing_occupancy_status", + month, + ) # Read from cache assert housing_occupancy_status == housing.HousingOccupancyStatus.tenant -def test_set_not_cached_variable(single): - dont_cache_variable = MemoryConfig(max_memory_occupation = 1, variables_to_drop = ['salary']) +def test_set_not_cached_variable(single) -> None: + dont_cache_variable = MemoryConfig( + max_memory_occupation=1, + variables_to_drop=["salary"], + ) simulation = single simulation.memory_config = dont_cache_variable - holder = simulation.person.get_holder('salary') + holder = simulation.person.get_holder("salary") array = numpy.asarray([2000]) - holder.set_input('2015-01', array) - assert simulation.calculate('salary', '2015-01') == array + holder.set_input("2015-01", array) + assert simulation.calculate("salary", "2015-01") == array -def test_set_input_float_to_int(single): +def test_set_input_float_to_int(single) -> None: simulation = single age = numpy.asarray([50.6]) - simulation.person.get_holder('age').set_input(period, age) - result = simulation.calculate('age', period) + simulation.person.get_holder("age").set_input(period, age) + result = simulation.calculate("age", period) assert result == numpy.asarray([50]) diff --git a/tests/core/test_opt_out_cache.py b/tests/core/test_opt_out_cache.py index b4eab3e5a5..2f61da2898 100644 --- a/tests/core/test_opt_out_cache.py +++ b/tests/core/test_opt_out_cache.py @@ -3,10 +3,9 @@ from openfisca_country_template.entities import Person from openfisca_core import periods -from openfisca_core.periods import MONTH +from openfisca_core.periods import DateUnit from openfisca_core.variables import Variable - PERIOD = periods.period("2016-01") @@ -14,57 +13,57 @@ class input(Variable): value_type = int entity = Person label = "Input variable" - definition_period = MONTH + definition_period = DateUnit.MONTH class intermediate(Variable): value_type = int entity = Person label = "Intermediate result that don't need to be cached" - definition_period = MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - return person('input', period) + def formula(self, period): + return self("input", period) class output(Variable): value_type = int entity = Person - label = 'Output variable' - definition_period = MONTH + label = "Output variable" + definition_period = DateUnit.MONTH - def formula(person, period): - return person('intermediate', period) + def formula(self, period): + return self("intermediate", period) -@pytest.fixture(scope = "module", autouse = True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +@pytest.fixture(scope="module", autouse=True) +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables(input, intermediate, output) -@pytest.fixture(scope = "module", autouse = True) -def add_variables_to_cache_blakclist(tax_benefit_system): - tax_benefit_system.cache_blacklist = set(['intermediate']) +@pytest.fixture(scope="module", autouse=True) +def add_variables_to_cache_blakclist(tax_benefit_system) -> None: + tax_benefit_system.cache_blacklist = {"intermediate"} -@pytest.mark.parametrize("simulation", [({'input': 1}, PERIOD)], indirect = True) -def test_without_cache_opt_out(simulation): - simulation.calculate('output', period = PERIOD) - intermediate_cache = simulation.persons.get_holder('intermediate') - assert(intermediate_cache.get_array(PERIOD) is not None) +@pytest.mark.parametrize("simulation", [({"input": 1}, PERIOD)], indirect=True) +def test_without_cache_opt_out(simulation) -> None: + simulation.calculate("output", period=PERIOD) + intermediate_cache = simulation.persons.get_holder("intermediate") + assert intermediate_cache.get_array(PERIOD) is not None -@pytest.mark.parametrize("simulation", [({'input': 1}, PERIOD)], indirect = True) -def test_with_cache_opt_out(simulation): +@pytest.mark.parametrize("simulation", [({"input": 1}, PERIOD)], indirect=True) +def test_with_cache_opt_out(simulation) -> None: simulation.debug = True simulation.opt_out_cache = True - simulation.calculate('output', period = PERIOD) - intermediate_cache = simulation.persons.get_holder('intermediate') - assert(intermediate_cache.get_array(PERIOD) is None) + simulation.calculate("output", period=PERIOD) + intermediate_cache = simulation.persons.get_holder("intermediate") + assert intermediate_cache.get_array(PERIOD) is None -@pytest.mark.parametrize("simulation", [({'input': 1}, PERIOD)], indirect = True) -def test_with_no_blacklist(simulation): - simulation.calculate('output', period = PERIOD) - intermediate_cache = simulation.persons.get_holder('intermediate') - assert(intermediate_cache.get_array(PERIOD) is not None) +@pytest.mark.parametrize("simulation", [({"input": 1}, PERIOD)], indirect=True) +def test_with_no_blacklist(simulation) -> None: + simulation.calculate("output", period=PERIOD) + intermediate_cache = simulation.persons.get_holder("intermediate") + assert intermediate_cache.get_array(PERIOD) is not None diff --git a/tests/core/test_parameters.py b/tests/core/test_parameters.py index 4f74f9d907..7fe63a8180 100644 --- a/tests/core/test_parameters.py +++ b/tests/core/test_parameters.py @@ -2,107 +2,135 @@ import pytest -from openfisca_core.errors import ParameterNotFoundError -from openfisca_core.parameters import ParameterNode, ParameterNodeAtInstant, load_parameter_file +from openfisca_core.parameters import ( + ParameterNode, + ParameterNodeAtInstant, + ParameterNotFound, + load_parameter_file, +) -def test_get_at_instant(tax_benefit_system): +def test_get_at_instant(tax_benefit_system) -> None: parameters = tax_benefit_system.parameters assert isinstance(parameters, ParameterNode), parameters - parameters_at_instant = parameters('2016-01-01') - assert isinstance(parameters_at_instant, ParameterNodeAtInstant), parameters_at_instant + parameters_at_instant = parameters("2016-01-01") + assert isinstance( + parameters_at_instant, + ParameterNodeAtInstant, + ), parameters_at_instant assert parameters_at_instant.taxes.income_tax_rate == 0.15 assert parameters_at_instant.benefits.basic_income == 600 -def test_param_values(tax_benefit_system): +def test_param_values(tax_benefit_system) -> None: dated_values = { - '2015-01-01': 0.15, - '2014-01-01': 0.14, - '2013-01-01': 0.13, - '2012-01-01': 0.16, - } + "2015-01-01": 0.15, + "2014-01-01": 0.14, + "2013-01-01": 0.13, + "2012-01-01": 0.16, + } for date, value in dated_values.items(): - assert tax_benefit_system.get_parameters_at_instant(date).taxes.income_tax_rate == value + assert ( + tax_benefit_system.get_parameters_at_instant(date).taxes.income_tax_rate + == value + ) -def test_param_before_it_is_defined(tax_benefit_system): - with pytest.raises(ParameterNotFoundError): - tax_benefit_system.get_parameters_at_instant('1997-12-31').taxes.income_tax_rate +def test_param_before_it_is_defined(tax_benefit_system) -> None: + with pytest.raises(ParameterNotFound): + tax_benefit_system.get_parameters_at_instant("1997-12-31").taxes.income_tax_rate # The placeholder should have no effect on the parameter computation -def test_param_with_placeholder(tax_benefit_system): - assert tax_benefit_system.get_parameters_at_instant('2018-01-01').taxes.income_tax_rate == 0.15 +def test_param_with_placeholder(tax_benefit_system) -> None: + assert ( + tax_benefit_system.get_parameters_at_instant("2018-01-01").taxes.income_tax_rate + == 0.15 + ) -def test_stopped_parameter_before_end_value(tax_benefit_system): - assert tax_benefit_system.get_parameters_at_instant('2011-12-31').benefits.housing_allowance == 0.25 +def test_stopped_parameter_before_end_value(tax_benefit_system) -> None: + assert ( + tax_benefit_system.get_parameters_at_instant( + "2011-12-31", + ).benefits.housing_allowance + == 0.25 + ) -def test_stopped_parameter_after_end_value(tax_benefit_system): - with pytest.raises(ParameterNotFoundError): - tax_benefit_system.get_parameters_at_instant('2016-12-01').benefits.housing_allowance +def test_stopped_parameter_after_end_value(tax_benefit_system) -> None: + with pytest.raises(ParameterNotFound): + tax_benefit_system.get_parameters_at_instant( + "2016-12-01", + ).benefits.housing_allowance -def test_parameter_for_period(tax_benefit_system): +def test_parameter_for_period(tax_benefit_system) -> None: income_tax_rate = tax_benefit_system.parameters.taxes.income_tax_rate assert income_tax_rate("2015") == income_tax_rate("2015-01-01") -def test_wrong_value(tax_benefit_system): +def test_wrong_value(tax_benefit_system) -> None: income_tax_rate = tax_benefit_system.parameters.taxes.income_tax_rate with pytest.raises(ValueError): income_tax_rate("test") -def test_parameter_repr(tax_benefit_system): +def test_parameter_repr(tax_benefit_system) -> None: parameters = tax_benefit_system.parameters - tf = tempfile.NamedTemporaryFile(delete = False) - tf.write(repr(parameters).encode('utf-8')) + tf = tempfile.NamedTemporaryFile(delete=False) + tf.write(repr(parameters).encode("utf-8")) tf.close() - tf_parameters = load_parameter_file(file_path = tf.name) + tf_parameters = load_parameter_file(file_path=tf.name) assert repr(parameters) == repr(tf_parameters) -def test_parameters_metadata(tax_benefit_system): +def test_parameters_metadata(tax_benefit_system) -> None: parameter = tax_benefit_system.parameters.benefits.basic_income - assert parameter.metadata['reference'] == 'https://law.gov.example/basic-income/amount' - assert parameter.metadata['unit'] == 'currency-EUR' - assert parameter.values_list[0].metadata['reference'] == 'https://law.gov.example/basic-income/amount/2015-12' - assert parameter.values_list[0].metadata['unit'] == 'currency-EUR' + assert ( + parameter.metadata["reference"] == "https://law.gov.example/basic-income/amount" + ) + assert parameter.metadata["unit"] == "currency-EUR" + assert ( + parameter.values_list[0].metadata["reference"] + == "https://law.gov.example/basic-income/amount/2015-12" + ) + assert parameter.values_list[0].metadata["unit"] == "currency-EUR" scale = tax_benefit_system.parameters.taxes.social_security_contribution - assert scale.metadata['threshold_unit'] == 'currency-EUR' - assert scale.metadata['rate_unit'] == '/1' + assert scale.metadata["threshold_unit"] == "currency-EUR" + assert scale.metadata["rate_unit"] == "/1" -def test_parameter_node_metadata(tax_benefit_system): +def test_parameter_node_metadata(tax_benefit_system) -> None: parameter = tax_benefit_system.parameters.benefits - assert parameter.description == 'Social benefits' + assert parameter.description == "Social benefits" parameter_2 = tax_benefit_system.parameters.taxes.housing_tax - assert parameter_2.description == 'Housing tax' + assert parameter_2.description == "Housing tax" -def test_parameter_documentation(tax_benefit_system): +def test_parameter_documentation(tax_benefit_system) -> None: parameter = tax_benefit_system.parameters.benefits.housing_allowance - assert parameter.documentation == 'A fraction of the rent.\nFrom the 1st of Dec 2016, the housing allowance no longer exists.\n' + assert ( + parameter.documentation + == "A fraction of the rent.\nFrom the 1st of Dec 2016, the housing allowance no longer exists.\n" + ) -def test_get_descendants(tax_benefit_system): - all_parameters = {parameter.name for parameter in tax_benefit_system.parameters.get_descendants()} - assert all_parameters.issuperset({'taxes', 'taxes.housing_tax', 'taxes.housing_tax.minimal_amount'}) +def test_get_descendants(tax_benefit_system) -> None: + all_parameters = { + parameter.name for parameter in tax_benefit_system.parameters.get_descendants() + } + assert all_parameters.issuperset( + {"taxes", "taxes.housing_tax", "taxes.housing_tax.minimal_amount"}, + ) -def test_name(): +def test_name() -> None: parameter_data = { "description": "Parameter indexed by a numeric key", - "2010": { - "values": { - '2006-01-01': 0.0075 - } - } - } - parameter = ParameterNode('root', data = parameter_data) + "2010": {"values": {"2006-01-01": 0.0075}}, + } + parameter = ParameterNode("root", data=parameter_data) assert parameter.children["2010"].name == "root.2010" diff --git a/tests/core/test_periods.py b/tests/core/test_periods.py deleted file mode 100644 index 2c125d527c..0000000000 --- a/tests/core/test_periods.py +++ /dev/null @@ -1,203 +0,0 @@ -# -*- coding: utf-8 -*- - - -import pytest - -from openfisca_core.periods import Period, Instant, YEAR, MONTH, DAY, period - -first_jan = Instant((2014, 1, 1)) -first_march = Instant((2014, 3, 1)) - - -''' -Test Period -> String -''' - - -# Years - -def test_year(): - assert str(Period((YEAR, first_jan, 1))) == '2014' - - -def test_12_months_is_a_year(): - assert str(Period((MONTH, first_jan, 12))) == '2014' - - -def test_rolling_year(): - assert str(Period((MONTH, first_march, 12))) == 'year:2014-03' - assert str(Period((YEAR, first_march, 1))) == 'year:2014-03' - - -def test_several_years(): - assert str(Period((YEAR, first_jan, 3))) == 'year:2014:3' - assert str(Period((YEAR, first_march, 3))) == 'year:2014-03:3' - - -# Months - -def test_month(): - assert str(Period((MONTH, first_jan, 1))) == '2014-01' - - -def test_several_months(): - assert str(Period((MONTH, first_jan, 3))) == 'month:2014-01:3' - assert str(Period((MONTH, first_march, 3))) == 'month:2014-03:3' - - -# Days - -def test_day(): - assert str(Period((DAY, first_jan, 1))) == '2014-01-01' - - -def test_several_days(): - assert str(Period((DAY, first_jan, 3))) == 'day:2014-01-01:3' - assert str(Period((DAY, first_march, 3))) == 'day:2014-03-01:3' - - -''' -Test String -> Period -''' - - -# Years - -def test_parsing_year(): - assert period('2014') == Period((YEAR, first_jan, 1)) - - -def test_parsing_rolling_year(): - assert period('year:2014-03') == Period((YEAR, first_march, 1)) - - -def test_parsing_several_years(): - assert period('year:2014:2') == Period((YEAR, first_jan, 2)) - - -def test_wrong_syntax_several_years(): - with pytest.raises(ValueError): - period('2014:2') - - -# Months - -def test_parsing_month(): - assert period('2014-01') == Period((MONTH, first_jan, 1)) - - -def test_parsing_several_months(): - assert period('month:2014-03:3') == Period((MONTH, first_march, 3)) - - -def test_wrong_syntax_several_months(): - with pytest.raises(ValueError): - period('2014-3:3') - - -# Days - -def test_parsing_day(): - assert period('2014-01-01') == Period((DAY, first_jan, 1)) - - -def test_parsing_several_days(): - assert period('day:2014-03-01:3') == Period((DAY, first_march, 3)) - - -def test_wrong_syntax_several_days(): - with pytest.raises(ValueError): - period('2014-2-3:2') - - -def test_day_size_in_days(): - assert Period(('day', Instant((2014, 12, 31)), 1)).size_in_days == 1 - - -def test_3_day_size_in_days(): - assert Period(('day', Instant((2014, 12, 31)), 3)).size_in_days == 3 - - -def test_month_size_in_days(): - assert Period(('month', Instant((2014, 12, 1)), 1)).size_in_days == 31 - - -def test_leap_month_size_in_days(): - assert Period(('month', Instant((2012, 2, 3)), 1)).size_in_days == 29 - - -def test_3_month_size_in_days(): - assert Period(('month', Instant((2013, 1, 3)), 3)).size_in_days == 31 + 28 + 31 - - -def test_leap_3_month_size_in_days(): - assert Period(('month', Instant((2012, 1, 3)), 3)).size_in_days == 31 + 29 + 31 - - -def test_year_size_in_days(): - assert Period(('year', Instant((2014, 12, 1)), 1)).size_in_days == 365 - - -def test_leap_year_size_in_days(): - assert Period(('year', Instant((2012, 1, 1)), 1)).size_in_days == 366 - - -def test_2_years_size_in_days(): - assert Period(('year', Instant((2014, 1, 1)), 2)).size_in_days == 730 - -# Misc - - -def test_wrong_date(): - with pytest.raises(ValueError): - period("2006-31-03") - - -def test_ambiguous_period(): - with pytest.raises(ValueError): - period('month:2014') - - -def test_deprecated_signature(): - with pytest.raises(TypeError): - period(MONTH, 2014) - - -def test_wrong_argument(): - with pytest.raises(ValueError): - period({}) - - -def test_wrong_argument_1(): - with pytest.raises(ValueError): - period([]) - - -def test_none(): - with pytest.raises(ValueError): - period(None) - - -def test_empty_string(): - with pytest.raises(ValueError): - period('') - - -@pytest.mark.parametrize("test", [ - (period('year:2014:2'), YEAR, 2, period('2014'), period('2015')), - (period(2017), MONTH, 12, period('2017-01'), period('2017-12')), - (period('year:2014:2'), MONTH, 24, period('2014-01'), period('2015-12')), - (period('month:2014-03:3'), MONTH, 3, period('2014-03'), period('2014-05')), - (period(2017), DAY, 365, period('2017-01-01'), period('2017-12-31')), - (period('year:2014:2'), DAY, 730, period('2014-01-01'), period('2015-12-31')), - (period('month:2014-03:3'), DAY, 92, period('2014-03-01'), period('2014-05-31')), - ]) -def test_subperiods(test): - - def check_subperiods(period, unit, length, first, last): - subperiods = period.get_subperiods(unit) - assert len(subperiods) == length - assert subperiods[0] == first - assert subperiods[-1] == last - - check_subperiods(*test) diff --git a/tests/core/test_projectors.py b/tests/core/test_projectors.py index be401bbec8..c62e49d3a7 100644 --- a/tests/core/test_projectors.py +++ b/tests/core/test_projectors.py @@ -1,13 +1,15 @@ +import numpy + +from openfisca_core.entities import build_entity +from openfisca_core.indexed_enums import Enum +from openfisca_core.periods import DateUnit from openfisca_core.simulations.simulation_builder import SimulationBuilder from openfisca_core.taxbenefitsystems import TaxBenefitSystem -from openfisca_core.entities import build_entity -from openfisca_core.model_api import Enum, Variable, ETERNITY -import numpy as np +from openfisca_core.variables import Variable -def test_shortcut_to_containing_entity_provided(): - """ - Tests that, when an entity provides a containing entity, +def test_shortcut_to_containing_entity_provided() -> None: + """Tests that, when an entity provides a containing entity, the shortcut to that containing entity is provided. """ person_entity = build_entity( @@ -15,28 +17,32 @@ def test_shortcut_to_containing_entity_provided(): plural="people", label="A person", is_person=True, - ) + ) family_entity = build_entity( key="family", plural="families", label="A family (all members in the same household)", containing_entities=["household"], - roles=[{ - "key": "member", - "plural": "members", - "label": "Member", - }] - ) + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) household_entity = build_entity( key="household", plural="households", label="A household, containing one or more families", - roles=[{ - "key": "member", - "plural": "members", - "label": "Member", - }] - ) + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) entities = [person_entity, family_entity, household_entity] @@ -45,9 +51,8 @@ def test_shortcut_to_containing_entity_provided(): assert simulation.populations["family"].household.entity.key == "household" -def test_shortcut_to_containing_entity_not_provided(): - """ - Tests that, when an entity doesn't provide a containing +def test_shortcut_to_containing_entity_not_provided() -> None: + """Tests that, when an entity doesn't provide a containing entity, the shortcut to that containing entity is not provided. """ person_entity = build_entity( @@ -55,28 +60,32 @@ def test_shortcut_to_containing_entity_not_provided(): plural="people", label="A person", is_person=True, - ) + ) family_entity = build_entity( key="family", plural="families", label="A family (all members in the same household)", containing_entities=[], - roles=[{ - "key": "member", - "plural": "members", - "label": "Member", - }] - ) + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) household_entity = build_entity( key="household", plural="households", label="A household, containing one or more families", - roles=[{ - "key": "member", - "plural": "members", - "label": "Member", - }] - ) + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) entities = [person_entity, family_entity, household_entity] @@ -84,33 +93,33 @@ def test_shortcut_to_containing_entity_not_provided(): simulation = SimulationBuilder().build_from_dict(system, {}) try: simulation.populations["family"].household - raise AssertionError() + raise AssertionError except AttributeError: pass -def test_enum_projects_downwards(): - """ - Test that an Enum-type household-level variable projects +def test_enum_projects_downwards() -> None: + """Test that an Enum-type household-level variable projects values onto its members correctly. """ - person = build_entity( key="person", plural="people", label="A person", is_person=True, - ) + ) household = build_entity( key="household", plural="households", label="A household", - roles=[{ - "key": "member", - "plural": "members", - "label": "Member", - }] - ) + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) entities = [person, household] @@ -125,61 +134,61 @@ class household_enum_variable(Variable): possible_values = enum default_value = enum.FIRST_OPTION entity = household - definition_period = ETERNITY + definition_period = DateUnit.ETERNITY class projected_enum_variable(Variable): value_type = Enum possible_values = enum default_value = enum.FIRST_OPTION entity = person - definition_period = ETERNITY + definition_period = DateUnit.ETERNITY - def formula(person, period): - return person.household("household_enum_variable", period) + def formula(self, period): + return self.household("household_enum_variable", period) system.add_variables(household_enum_variable, projected_enum_variable) - simulation = SimulationBuilder().build_from_dict(system, { - "people": { - "person1": {}, - "person2": {}, - "person3": {} + simulation = SimulationBuilder().build_from_dict( + system, + { + "people": {"person1": {}, "person2": {}, "person3": {}}, + "households": { + "household1": { + "members": ["person1", "person2", "person3"], + "household_enum_variable": {"eternity": "SECOND_OPTION"}, + }, }, - "households": { - "household1": { - "members": ["person1", "person2", "person3"], - "household_enum_variable": { - "eternity": "SECOND_OPTION" - } - } - } - }) + }, + ) - assert (simulation.calculate("projected_enum_variable", "2021-01-01").decode_to_str() == np.array(["SECOND_OPTION"] * 3)).all() + assert ( + simulation.calculate("projected_enum_variable", "2021-01-01").decode_to_str() + == numpy.array(["SECOND_OPTION"] * 3) + ).all() -def test_enum_projects_upwards(): - """ - Test that an Enum-type person-level variable projects +def test_enum_projects_upwards() -> None: + """Test that an Enum-type person-level variable projects values onto its household (from the first person) correctly. """ - person = build_entity( key="person", plural="people", label="A person", is_person=True, - ) + ) household = build_entity( key="household", plural="households", label="A household", - roles=[{ - "key": "member", - "plural": "members", - "label": "Member", - }] - ) + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) entities = [person, household] @@ -194,73 +203,82 @@ class household_projected_variable(Variable): possible_values = enum default_value = enum.FIRST_OPTION entity = household - definition_period = ETERNITY + definition_period = DateUnit.ETERNITY - def formula(household, period): - return household.value_from_first_person(household.members("person_enum_variable", period)) + def formula(self, period): + return self.value_from_first_person( + self.members("person_enum_variable", period), + ) class person_enum_variable(Variable): value_type = Enum possible_values = enum default_value = enum.FIRST_OPTION entity = person - definition_period = ETERNITY + definition_period = DateUnit.ETERNITY system.add_variables(household_projected_variable, person_enum_variable) - simulation = SimulationBuilder().build_from_dict(system, { - "people": { - "person1": { - "person_enum_variable": { - "ETERNITY": "SECOND_OPTION" - } + simulation = SimulationBuilder().build_from_dict( + system, + { + "people": { + "person1": {"person_enum_variable": {"ETERNITY": "SECOND_OPTION"}}, + "person2": {}, + "person3": {}, + }, + "households": { + "household1": { + "members": ["person1", "person2", "person3"], }, - "person2": {}, - "person3": {} }, - "households": { - "household1": { - "members": ["person1", "person2", "person3"], - } - } - }) + }, + ) - assert (simulation.calculate("household_projected_variable", "2021-01-01").decode_to_str() == np.array(["SECOND_OPTION"])).all() + assert ( + simulation.calculate( + "household_projected_variable", + "2021-01-01", + ).decode_to_str() + == numpy.array(["SECOND_OPTION"]) + ).all() -def test_enum_projects_between_containing_groups(): - """ - Test that an Enum-type person-level variable projects +def test_enum_projects_between_containing_groups() -> None: + """Test that an Enum-type person-level variable projects values onto its household (from the first person) correctly. """ - person_entity = build_entity( key="person", plural="people", label="A person", is_person=True, - ) + ) family_entity = build_entity( key="family", plural="families", label="A family (all members in the same household)", containing_entities=["household"], - roles=[{ - "key": "member", - "plural": "members", - "label": "Member", - }] - ) + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) household_entity = build_entity( key="household", plural="households", label="A household, containing one or more families", - roles=[{ - "key": "member", - "plural": "members", - "label": "Member", - }] - ) + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) entities = [person_entity, family_entity, household_entity] @@ -275,42 +293,57 @@ class household_level_variable(Variable): possible_values = enum default_value = enum.FIRST_OPTION entity = household_entity - definition_period = ETERNITY + definition_period = DateUnit.ETERNITY class projected_family_level_variable(Variable): value_type = Enum possible_values = enum default_value = enum.FIRST_OPTION entity = family_entity - definition_period = ETERNITY + definition_period = DateUnit.ETERNITY - def formula(family, period): - return family.household("household_level_variable", period) + def formula(self, period): + return self.household("household_level_variable", period) - system.add_variables(household_level_variable, projected_family_level_variable) - - simulation = SimulationBuilder().build_from_dict(system, { - "people": { - "person1": {}, - "person2": {}, - "person3": {} + class decoded_projected_family_level_variable(Variable): + value_type = str + entity = family_entity + definition_period = DateUnit.ETERNITY + + def formula(self, period): + return self.household("household_level_variable", period).decode_to_str() + + system.add_variables( + household_level_variable, + projected_family_level_variable, + decoded_projected_family_level_variable, + ) + + simulation = SimulationBuilder().build_from_dict( + system, + { + "people": {"person1": {}, "person2": {}, "person3": {}}, + "families": { + "family1": {"members": ["person1", "person2"]}, + "family2": {"members": ["person3"]}, }, - "families": { - "family1": { - "members": ["person1", "person2"] - }, - "family2": { - "members": ["person3"] + "households": { + "household1": { + "members": ["person1", "person2", "person3"], + "household_level_variable": {"eternity": "SECOND_OPTION"}, }, }, - "households": { - "household1": { - "members": ["person1", "person2", "person3"], - "household_level_variable": { - "eternity": "SECOND_OPTION" - } - } - } - }) - - assert (simulation.calculate("projected_family_level_variable", "2021-01-01").decode_to_str() == np.array(["SECOND_OPTION"])).all() + }, + ) + + assert ( + simulation.calculate( + "projected_family_level_variable", + "2021-01-01", + ).decode_to_str() + == numpy.array(["SECOND_OPTION"]) + ).all() + assert ( + simulation.calculate("decoded_projected_family_level_variable", "2021-01-01") + == numpy.array(["SECOND_OPTION"]) + ).all() diff --git a/tests/core/test_reforms.py b/tests/core/test_reforms.py index 8735cee18f..1f31bcde2a 100644 --- a/tests/core/test_reforms.py +++ b/tests/core/test_reforms.py @@ -2,12 +2,14 @@ import pytest -from openfisca_core import periods -from openfisca_core.periods import Instant -from openfisca_core.tools import assert_near -from openfisca_core.parameters import ValuesHistory, ParameterNode from openfisca_country_template.entities import Household, Person -from openfisca_core.model_api import * # noqa analysis:ignore + +from openfisca_core import holders, periods, simulations +from openfisca_core.parameters import ParameterNode, ValuesHistory +from openfisca_core.periods import DateUnit, Instant +from openfisca_core.reforms import Reform +from openfisca_core.tools import assert_near +from openfisca_core.variables import Variable class goes_to_school(Variable): @@ -15,333 +17,463 @@ class goes_to_school(Variable): default_value = True entity = Person label = "The person goes to school (only relevant for children)" - definition_period = MONTH + definition_period = DateUnit.MONTH class WithBasicIncomeNeutralized(Reform): - def apply(self): - self.neutralize_variable('basic_income') + def apply(self) -> None: + self.neutralize_variable("basic_income") -@pytest.fixture(scope = "module", autouse = True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +@pytest.fixture(scope="module", autouse=True) +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables(goes_to_school) -def test_formula_neutralization(make_simulation, tax_benefit_system): +def test_formula_neutralization(make_simulation, tax_benefit_system) -> None: reform = WithBasicIncomeNeutralized(tax_benefit_system) - period = '2017-01' + period = "2017-01" simulation = make_simulation(reform.base_tax_benefit_system, {}, period) simulation.debug = True - basic_income = simulation.calculate('basic_income', period = period) + basic_income = simulation.calculate("basic_income", period=period) assert_near(basic_income, 600) - disposable_income = simulation.calculate('disposable_income', period = period) + disposable_income = simulation.calculate("disposable_income", period=period) assert disposable_income > 0 reform_simulation = make_simulation(reform, {}, period) reform_simulation.debug = True - basic_income_reform = reform_simulation.calculate('basic_income', period = '2013-01') - assert_near(basic_income_reform, 0, absolute_error_margin = 0) - disposable_income_reform = reform_simulation.calculate('disposable_income', period = period) + basic_income_reform = reform_simulation.calculate("basic_income", period="2013-01") + assert_near(basic_income_reform, 0, absolute_error_margin=0) + disposable_income_reform = reform_simulation.calculate( + "disposable_income", + period=period, + ) assert_near(disposable_income_reform, 0) -def test_neutralization_variable_with_default_value(make_simulation, tax_benefit_system): +def test_neutralization_variable_with_default_value( + make_simulation, + tax_benefit_system, +) -> None: class test_goes_to_school_neutralization(Reform): - def apply(self): - self.neutralize_variable('goes_to_school') + def apply(self) -> None: + self.neutralize_variable("goes_to_school") reform = test_goes_to_school_neutralization(tax_benefit_system) period = "2017-01" simulation = make_simulation(reform.base_tax_benefit_system, {}, period) - goes_to_school = simulation.calculate('goes_to_school', period) - assert_near(goes_to_school, [True], absolute_error_margin = 0) + goes_to_school = simulation.calculate("goes_to_school", period) + assert_near(goes_to_school, [True], absolute_error_margin=0) -def test_neutralization_optimization(make_simulation, tax_benefit_system): +def test_neutralization_optimization(make_simulation, tax_benefit_system) -> None: reform = WithBasicIncomeNeutralized(tax_benefit_system) - period = '2017-01' + period = "2017-01" simulation = make_simulation(reform, {}, period) simulation.debug = True - simulation.calculate('basic_income', period = '2013-01') - simulation.calculate_add('basic_income', period = '2013') + simulation.calculate("basic_income", period="2013-01") + simulation.calculate_add("basic_income", period="2013") # As basic_income is neutralized, it should not be cached - basic_income_holder = simulation.persons.get_holder('basic_income') + basic_income_holder = simulation.persons.get_holder("basic_income") assert basic_income_holder.get_known_periods() == [] -def test_input_variable_neutralization(make_simulation, tax_benefit_system): - +def test_input_variable_neutralization(make_simulation, tax_benefit_system) -> None: class test_salary_neutralization(Reform): - def apply(self): - self.neutralize_variable('salary') + def apply(self) -> None: + self.neutralize_variable("salary") reform = test_salary_neutralization(tax_benefit_system) - period = '2017-01' + period = "2017-01" reform = test_salary_neutralization(tax_benefit_system) with warnings.catch_warnings(record=True) as raised_warnings: - reform_simulation = make_simulation(reform, {'salary': [1200, 1000]}, period) - assert 'You cannot set a value for the variable' in raised_warnings[0].message.args[0] - salary = reform_simulation.calculate('salary', period) - assert_near(salary, [0, 0],) - disposable_income_reform = reform_simulation.calculate('disposable_income', period = period) + reform_simulation = make_simulation(reform, {"salary": [1200, 1000]}, period) + assert ( + "You cannot set a value for the variable" + in raised_warnings[0].message.args[0] + ) + salary = reform_simulation.calculate("salary", period) + assert_near( + salary, + [0, 0], + ) + disposable_income_reform = reform_simulation.calculate( + "disposable_income", + period=period, + ) assert_near(disposable_income_reform, [600, 600]) -def test_permanent_variable_neutralization(make_simulation, tax_benefit_system): - +def test_permanent_variable_neutralization(make_simulation, tax_benefit_system) -> None: class test_date_naissance_neutralization(Reform): - def apply(self): - self.neutralize_variable('birth') + def apply(self) -> None: + self.neutralize_variable("birth") reform = test_date_naissance_neutralization(tax_benefit_system) - period = '2017-01' - simulation = make_simulation(reform.base_tax_benefit_system, {'birth': '1980-01-01'}, period) + period = "2017-01" + simulation = make_simulation( + reform.base_tax_benefit_system, + {"birth": "1980-01-01"}, + period, + ) with warnings.catch_warnings(record=True) as raised_warnings: - reform_simulation = make_simulation(reform, {'birth': '1980-01-01'}, period) - assert 'You cannot set a value for the variable' in raised_warnings[0].message.args[0] - assert str(simulation.calculate('birth', None)[0]) == '1980-01-01' - assert str(reform_simulation.calculate('birth', None)[0]) == '1970-01-01' - - -def test_update_items(): - def check_update_items(description, value_history, start_instant, stop_instant, value, expected_items): - value_history.update(period=None, start=start_instant, stop=stop_instant, value=value) + reform_simulation = make_simulation(reform, {"birth": "1980-01-01"}, period) + assert ( + "You cannot set a value for the variable" + in raised_warnings[0].message.args[0] + ) + assert str(simulation.calculate("birth", None)[0]) == "1980-01-01" + assert str(reform_simulation.calculate("birth", None)[0]) == "1970-01-01" + + +def test_update_items() -> None: + def check_update_items( + description, + value_history, + start_instant, + stop_instant, + value, + expected_items, + ) -> None: + value_history.update( + period=None, + start=start_instant, + stop=stop_instant, + value=value, + ) assert value_history == expected_items check_update_items( - 'Replace an item by a new item', - ValuesHistory('dummy_name', {"2013-01-01": {'value': 0.0}, "2014-01-01": {'value': None}}), + "Replace an item by a new item", + ValuesHistory( + "dummy_name", + {"2013-01-01": {"value": 0.0}, "2014-01-01": {"value": None}}, + ), periods.period(2013).start, periods.period(2013).stop, 1.0, - ValuesHistory('dummy_name', {"2013-01-01": {'value': 1.0}, "2014-01-01": {'value': None}}), - ) + ValuesHistory( + "dummy_name", + {"2013-01-01": {"value": 1.0}, "2014-01-01": {"value": None}}, + ), + ) check_update_items( - 'Replace an item by a new item in a list of items, the last being open', - ValuesHistory('dummy_name', {"2014-01-01": {'value': 9.53}, "2015-01-01": {'value': 9.61}, "2016-01-01": {'value': 9.67}}), + "Replace an item by a new item in a list of items, the last being open", + ValuesHistory( + "dummy_name", + { + "2014-01-01": {"value": 9.53}, + "2015-01-01": {"value": 9.61}, + "2016-01-01": {"value": 9.67}, + }, + ), periods.period(2015).start, periods.period(2015).stop, 1.0, - ValuesHistory('dummy_name', {"2014-01-01": {'value': 9.53}, "2015-01-01": {'value': 1.0}, "2016-01-01": {'value': 9.67}}), - ) + ValuesHistory( + "dummy_name", + { + "2014-01-01": {"value": 9.53}, + "2015-01-01": {"value": 1.0}, + "2016-01-01": {"value": 9.67}, + }, + ), + ) check_update_items( - 'Open the stop instant to the future', - ValuesHistory('dummy_name', {"2013-01-01": {'value': 0.0}, "2014-01-01": {'value': None}}), + "Open the stop instant to the future", + ValuesHistory( + "dummy_name", + {"2013-01-01": {"value": 0.0}, "2014-01-01": {"value": None}}, + ), periods.period(2013).start, None, # stop instant 1.0, - ValuesHistory('dummy_name', {"2013-01-01": {'value': 1.0}}), - ) + ValuesHistory("dummy_name", {"2013-01-01": {"value": 1.0}}), + ) check_update_items( - 'Insert a new item in the middle of an existing item', - ValuesHistory('dummy_name', {"2010-01-01": {'value': 0.0}, "2014-01-01": {'value': None}}), + "Insert a new item in the middle of an existing item", + ValuesHistory( + "dummy_name", + {"2010-01-01": {"value": 0.0}, "2014-01-01": {"value": None}}, + ), periods.period(2011).start, periods.period(2011).stop, 1.0, - ValuesHistory('dummy_name', {"2010-01-01": {'value': 0.0}, "2011-01-01": {'value': 1.0}, "2012-01-01": {'value': 0.0}, "2014-01-01": {'value': None}}), - ) + ValuesHistory( + "dummy_name", + { + "2010-01-01": {"value": 0.0}, + "2011-01-01": {"value": 1.0}, + "2012-01-01": {"value": 0.0}, + "2014-01-01": {"value": None}, + }, + ), + ) check_update_items( - 'Insert a new open item coming after the last open item', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new open item coming after the last open item", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2015).start, None, # stop instant 1.0, - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}, "2015-01-01": {'value': 1.0}}), - ) + ValuesHistory( + "dummy_name", + { + "2006-01-01": {"value": 0.055}, + "2014-01-01": {"value": 0.14}, + "2015-01-01": {"value": 1.0}, + }, + ), + ) check_update_items( - 'Insert a new item starting at the same date than the last open item', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new item starting at the same date than the last open item", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2014).start, periods.period(2014).stop, 1.0, - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 1.0}, "2015-01-01": {'value': 0.14}}), - ) + ValuesHistory( + "dummy_name", + { + "2006-01-01": {"value": 0.055}, + "2014-01-01": {"value": 1.0}, + "2015-01-01": {"value": 0.14}, + }, + ), + ) check_update_items( - 'Insert a new open item starting at the same date than the last open item', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new open item starting at the same date than the last open item", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2014).start, None, # stop instant 1.0, - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 1.0}}), - ) + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 1.0}}, + ), + ) check_update_items( - 'Insert a new item coming before the first item', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new item coming before the first item", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2005).start, periods.period(2005).stop, 1.0, - ValuesHistory('dummy_name', {"2005-01-01": {'value': 1.0}, "2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), - ) + ValuesHistory( + "dummy_name", + { + "2005-01-01": {"value": 1.0}, + "2006-01-01": {"value": 0.055}, + "2014-01-01": {"value": 0.14}, + }, + ), + ) check_update_items( - 'Insert a new item coming before the first item with a hole', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new item coming before the first item with a hole", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2003).start, periods.period(2003).stop, 1.0, - ValuesHistory('dummy_name', {"2003-01-01": {'value': 1.0}, "2004-01-01": {'value': None}, "2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), - ) + ValuesHistory( + "dummy_name", + { + "2003-01-01": {"value": 1.0}, + "2004-01-01": {"value": None}, + "2006-01-01": {"value": 0.055}, + "2014-01-01": {"value": 0.14}, + }, + ), + ) check_update_items( - 'Insert a new open item starting before the start date of the first item', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new open item starting before the start date of the first item", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2005).start, None, # stop instant 1.0, - ValuesHistory('dummy_name', {"2005-01-01": {'value': 1.0}}), - ) + ValuesHistory("dummy_name", {"2005-01-01": {"value": 1.0}}), + ) check_update_items( - 'Insert a new open item starting at the same date than the first item', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new open item starting at the same date than the first item", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2006).start, None, # stop instant 1.0, - ValuesHistory('dummy_name', {"2006-01-01": {'value': 1.0}}), - ) + ValuesHistory("dummy_name", {"2006-01-01": {"value": 1.0}}), + ) -def test_add_variable(make_simulation, tax_benefit_system): +def test_add_variable(make_simulation, tax_benefit_system) -> None: class new_variable(Variable): value_type = int label = "Nouvelle variable introduite par la réforme" entity = Household - definition_period = MONTH + definition_period = DateUnit.MONTH - def formula(household, period): - return household.empty_array() + 10 + def formula(self, period): + return self.empty_array() + 10 class test_add_variable(Reform): - - def apply(self): + def apply(self) -> None: self.add_variable(new_variable) reform = test_add_variable(tax_benefit_system) - assert tax_benefit_system.get_variable('new_variable') is None + assert tax_benefit_system.get_variable("new_variable") is None reform_simulation = make_simulation(reform, {}, 2013) reform_simulation.debug = True - new_variable1 = reform_simulation.calculate('new_variable', period = '2013-01') - assert_near(new_variable1, 10, absolute_error_margin = 0) + new_variable1 = reform_simulation.calculate("new_variable", period="2013-01") + assert_near(new_variable1, 10, absolute_error_margin=0) -def test_add_dated_variable(make_simulation, tax_benefit_system): +def test_add_dated_variable(make_simulation, tax_benefit_system) -> None: class new_dated_variable(Variable): value_type = int label = "Nouvelle variable introduite par la réforme" entity = Household - definition_period = MONTH + definition_period = DateUnit.MONTH - def formula_2010_01_01(household, period): - return household.empty_array() + 10 + def formula_2010_01_01(self, period): + return self.empty_array() + 10 - def formula_2011_01_01(household, period): - return household.empty_array() + 15 + def formula_2011_01_01(self, period): + return self.empty_array() + 15 class test_add_variable(Reform): - def apply(self): + def apply(self) -> None: self.add_variable(new_dated_variable) reform = test_add_variable(tax_benefit_system) - reform_simulation = make_simulation(reform, {}, '2013-01') + reform_simulation = make_simulation(reform, {}, "2013-01") reform_simulation.debug = True - new_dated_variable1 = reform_simulation.calculate('new_dated_variable', period = '2013-01') - assert_near(new_dated_variable1, 15, absolute_error_margin = 0) - + new_dated_variable1 = reform_simulation.calculate( + "new_dated_variable", + period="2013-01", + ) + assert_near(new_dated_variable1, 15, absolute_error_margin=0) -def test_update_variable(make_simulation, tax_benefit_system): +def test_update_variable(make_simulation, tax_benefit_system) -> None: class disposable_income(Variable): - definition_period = MONTH + definition_period = DateUnit.MONTH - def formula_2018(household, period): - return household.empty_array() + 10 + def formula_2018(self, period): + return self.empty_array() + 10 class test_update_variable(Reform): - def apply(self): + def apply(self) -> None: self.update_variable(disposable_income) reform = test_update_variable(tax_benefit_system) - disposable_income_reform = reform.get_variable('disposable_income') - disposable_income_baseline = tax_benefit_system.get_variable('disposable_income') + disposable_income_reform = reform.get_variable("disposable_income") + disposable_income_baseline = tax_benefit_system.get_variable("disposable_income") assert disposable_income_reform is not None - assert disposable_income_reform.entity.plural == disposable_income_baseline.entity.plural + assert ( + disposable_income_reform.entity.plural + == disposable_income_baseline.entity.plural + ) assert disposable_income_reform.name == disposable_income_baseline.name assert disposable_income_reform.label == disposable_income_baseline.label reform_simulation = make_simulation(reform, {}, 2018) - disposable_income1 = reform_simulation.calculate('disposable_income', period = '2018-01') - assert_near(disposable_income1, 10, absolute_error_margin = 0) - - disposable_income2 = reform_simulation.calculate('disposable_income', period = '2017-01') + disposable_income1 = reform_simulation.calculate( + "disposable_income", + period="2018-01", + ) + assert_near(disposable_income1, 10, absolute_error_margin=0) + + disposable_income2 = reform_simulation.calculate( + "disposable_income", + period="2017-01", + ) # Before 2018, the former formula is used - assert(disposable_income2 > 100) - + assert disposable_income2 > 100 -def test_replace_variable(tax_benefit_system): +def test_replace_variable(tax_benefit_system) -> None: class disposable_income(Variable): - definition_period = MONTH + definition_period = DateUnit.MONTH entity = Person label = "Disposable income" value_type = float - def formula_2018(household, period): - return household.empty_array() + 10 + def formula_2018(self, period): + return self.empty_array() + 10 class test_update_variable(Reform): - def apply(self): + def apply(self) -> None: self.replace_variable(disposable_income) reform = test_update_variable(tax_benefit_system) - disposable_income_reform = reform.get_variable('disposable_income') - assert disposable_income_reform.get_formula('2017') is None + disposable_income_reform = reform.get_variable("disposable_income") + assert disposable_income_reform.get_formula("2017") is None -def test_wrong_reform(tax_benefit_system): +def test_wrong_reform(tax_benefit_system) -> None: class wrong_reform(Reform): # A Reform must implement an `apply` method pass - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017 wrong_reform(tax_benefit_system) -def test_modify_parameters(tax_benefit_system): - +def test_modify_parameters(tax_benefit_system) -> None: def modify_parameters(reference_parameters): reform_parameters_subtree = ParameterNode( - 'new_node', - data = { - 'new_param': { - 'values': {"2000-01-01": {'value': True}, "2015-01-01": {'value': None}} + "new_node", + data={ + "new_param": { + "values": { + "2000-01-01": {"value": True}, + "2015-01-01": {"value": None}, }, }, - ) - reference_parameters.children['new_node'] = reform_parameters_subtree + }, + ) + reference_parameters.children["new_node"] = reform_parameters_subtree return reference_parameters class test_modify_parameters(Reform): - def apply(self): - self.modify_parameters(modifier_function = modify_parameters) + def apply(self) -> None: + self.modify_parameters(modifier_function=modify_parameters) reform = test_modify_parameters(tax_benefit_system) - parameters_new_node = reform.parameters.children['new_node'] + parameters_new_node = reform.parameters.children["new_node"] assert parameters_new_node is not None instant = Instant((2013, 1, 1)) @@ -349,15 +481,14 @@ def apply(self): assert parameters_at_instant.new_node.new_param is True -def test_attributes_conservation(tax_benefit_system): - +def test_attributes_conservation(tax_benefit_system) -> None: class some_variable(Variable): value_type = int entity = Person label = "Variable with many attributes" - definition_period = MONTH - set_input = set_input_divide_by_period - calculate_output = calculate_output_add + definition_period = DateUnit.MONTH + set_input = holders.set_input_divide_by_period + calculate_output = simulations.calculate_output_add tax_benefit_system.add_variable(some_variable) @@ -365,12 +496,12 @@ class reform(Reform): class some_variable(Variable): default_value = 10 - def apply(self): + def apply(self) -> None: self.update_variable(some_variable) reformed_tbs = reform(tax_benefit_system) - reform_variable = reformed_tbs.get_variable('some_variable') - baseline_variable = tax_benefit_system.get_variable('some_variable') + reform_variable = reformed_tbs.get_variable("some_variable") + baseline_variable = tax_benefit_system.get_variable("some_variable") assert reform_variable.value_type == baseline_variable.value_type assert reform_variable.entity == baseline_variable.entity assert reform_variable.label == baseline_variable.label @@ -379,18 +510,17 @@ def apply(self): assert reform_variable.calculate_output == baseline_variable.calculate_output -def test_formulas_removal(tax_benefit_system): +def test_formulas_removal(tax_benefit_system) -> None: class reform(Reform): - def apply(self): - + def apply(self) -> None: class basic_income(Variable): pass self.update_variable(basic_income) - self.variables['basic_income'].formulas.clear() + self.variables["basic_income"].formulas.clear() reformed_tbs = reform(tax_benefit_system) - reform_variable = reformed_tbs.get_variable('basic_income') - baseline_variable = tax_benefit_system.get_variable('basic_income') + reform_variable = reformed_tbs.get_variable("basic_income") + baseline_variable = tax_benefit_system.get_variable("basic_income") assert len(reform_variable.formulas) == 0 assert len(baseline_variable.formulas) > 0 diff --git a/tests/core/test_simulation_builder.py b/tests/core/test_simulation_builder.py index b6a558751d..507d10e707 100644 --- a/tests/core/test_simulation_builder.py +++ b/tests/core/test_simulation_builder.py @@ -1,13 +1,15 @@ +from collections.abc import Iterable + import datetime -from typing import Iterable import pytest from openfisca_country_template import entities, situation_examples -from openfisca_core import periods, tools +from openfisca_core import tools from openfisca_core.errors import SituationParsingError from openfisca_core.indexed_enums import Enum +from openfisca_core.periods import DateUnit from openfisca_core.populations import Population from openfisca_core.simulations import Simulation, SimulationBuilder from openfisca_core.tools import test_runner @@ -16,13 +18,12 @@ @pytest.fixture def int_variable(persons): - class intvar(Variable): - definition_period = periods.ETERNITY + definition_period = DateUnit.ETERNITY value_type = int entity = persons - def __init__(self): + def __init__(self) -> None: super().__init__() return intvar() @@ -30,13 +31,12 @@ def __init__(self): @pytest.fixture def date_variable(persons): - class datevar(Variable): - definition_period = periods.ETERNITY + definition_period = DateUnit.ETERNITY value_type = datetime.date entity = persons - def __init__(self): + def __init__(self) -> None: super().__init__() return datevar() @@ -44,213 +44,339 @@ def __init__(self): @pytest.fixture def enum_variable(): + class _TestEnum(Enum): + foo = "bar" class TestEnum(Variable): - definition_period = periods.ETERNITY + definition_period = DateUnit.ETERNITY value_type = Enum - dtype = 'O' - default_value = '0' + dtype = "O" + default_value = _TestEnum.foo is_neutralized = False set_input = None - possible_values = Enum('foo', 'bar') + possible_values = _TestEnum name = "enum" - def __init__(self): + def __init__(self) -> None: pass return TestEnum() -def test_build_default_simulation(tax_benefit_system): - one_person_simulation = SimulationBuilder().build_default_simulation(tax_benefit_system, 1) +def test_build_default_simulation(tax_benefit_system) -> None: + one_person_simulation = SimulationBuilder().build_default_simulation( + tax_benefit_system, + 1, + ) assert one_person_simulation.persons.count == 1 assert one_person_simulation.household.count == 1 assert one_person_simulation.household.members_entity_id == [0] - assert one_person_simulation.household.members_role == entities.Household.FIRST_PARENT + assert ( + one_person_simulation.household.members_role == entities.Household.FIRST_PARENT + ) - several_persons_simulation = SimulationBuilder().build_default_simulation(tax_benefit_system, 4) + several_persons_simulation = SimulationBuilder().build_default_simulation( + tax_benefit_system, + 4, + ) assert several_persons_simulation.persons.count == 4 assert several_persons_simulation.household.count == 4 - assert (several_persons_simulation.household.members_entity_id == [0, 1, 2, 3]).all() - assert (several_persons_simulation.household.members_role == entities.Household.FIRST_PARENT).all() + assert ( + several_persons_simulation.household.members_entity_id == [0, 1, 2, 3] + ).all() + assert ( + several_persons_simulation.household.members_role + == entities.Household.FIRST_PARENT + ).all() -def test_explicit_singular_entities(tax_benefit_system): +def test_explicit_singular_entities(tax_benefit_system) -> None: assert SimulationBuilder().explicit_singular_entities( tax_benefit_system, - {'persons': {'Javier': {}}, 'household': {'parents': ['Javier']}} - ) == {'persons': {'Javier': {}}, 'households': {'household': {'parents': ['Javier']}}} + {"persons": {"Javier": {}}, "household": {"parents": ["Javier"]}}, + ) == { + "persons": {"Javier": {}}, + "households": {"household": {"parents": ["Javier"]}}, + } -def test_add_person_entity(persons): - persons_json = {'Alicia': {'salary': {}}, 'Javier': {}} +def test_add_person_entity(persons) -> None: + persons_json = {"Alicia": {"salary": {}}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) - assert simulation_builder.get_count('persons') == 2 - assert simulation_builder.get_ids('persons') == ['Alicia', 'Javier'] + assert simulation_builder.get_count("persons") == 2 + assert simulation_builder.get_ids("persons") == ["Alicia", "Javier"] -def test_numeric_ids(persons): - persons_json = {1: {'salary': {}}, 2: {}} +def test_numeric_ids(persons) -> None: + persons_json = {1: {"salary": {}}, 2: {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) - assert simulation_builder.get_count('persons') == 2 - assert simulation_builder.get_ids('persons') == ['1', '2'] + assert simulation_builder.get_count("persons") == 2 + assert simulation_builder.get_ids("persons") == ["1", "2"] -def test_add_person_entity_with_values(persons): - persons_json = {'Alicia': {'salary': {'2018-11': 3000}}, 'Javier': {}} +def test_add_person_entity_with_values(persons) -> None: + persons_json = {"Alicia": {"salary": {"2018-11": 3000}}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) - tools.assert_near(simulation_builder.get_input('salary', '2018-11'), [3000, 0]) + tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0]) -def test_add_person_values_with_default_period(persons): - persons_json = {'Alicia': {'salary': 3000}, 'Javier': {}} +def test_add_person_values_with_default_period(persons) -> None: + persons_json = {"Alicia": {"salary": 3000}, "Javier": {}} simulation_builder = SimulationBuilder() - simulation_builder.set_default_period('2018-11') + simulation_builder.set_default_period("2018-11") simulation_builder.add_person_entity(persons, persons_json) - tools.assert_near(simulation_builder.get_input('salary', '2018-11'), [3000, 0]) + tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0]) -def test_add_person_values_with_default_period_old_syntax(persons): - persons_json = {'Alicia': {'salary': 3000}, 'Javier': {}} +def test_add_person_values_with_default_period_old_syntax(persons) -> None: + persons_json = {"Alicia": {"salary": 3000}, "Javier": {}} simulation_builder = SimulationBuilder() - simulation_builder.set_default_period('month:2018-11') + simulation_builder.set_default_period("month:2018-11") simulation_builder.add_person_entity(persons, persons_json) - tools.assert_near(simulation_builder.get_input('salary', '2018-11'), [3000, 0]) + tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0]) -def test_add_group_entity(households): +def test_add_group_entity(households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Sarah', 'Tom'], households, { - 'Household_1': {'parents': ['Alicia', 'Javier']}, - 'Household_2': {'parents': ['Tom'], 'children': ['Sarah']}, - }) - assert simulation_builder.get_count('households') == 2 - assert simulation_builder.get_ids('households') == ['Household_1', 'Household_2'] - assert simulation_builder.get_memberships('households') == [0, 0, 1, 1] - assert [role.key for role in simulation_builder.get_roles('households')] == ['parent', 'parent', 'child', 'parent'] - - -def test_add_group_entity_loose_syntax(households): + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Sarah", "Tom"], + households, + { + "Household_1": {"parents": ["Alicia", "Javier"]}, + "Household_2": {"parents": ["Tom"], "children": ["Sarah"]}, + }, + ) + assert simulation_builder.get_count("households") == 2 + assert simulation_builder.get_ids("households") == ["Household_1", "Household_2"] + assert simulation_builder.get_memberships("households") == [0, 0, 1, 1] + assert [role.key for role in simulation_builder.get_roles("households")] == [ + "parent", + "parent", + "child", + "parent", + ] + + +def test_add_group_entity_loose_syntax(households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Sarah', '1'], households, { - 'Household_1': {'parents': ['Alicia', 'Javier']}, - 'Household_2': {'parents': 1, 'children': 'Sarah'}, - }) - assert simulation_builder.get_count('households') == 2 - assert simulation_builder.get_ids('households') == ['Household_1', 'Household_2'] - assert simulation_builder.get_memberships('households') == [0, 0, 1, 1] - assert [role.key for role in simulation_builder.get_roles('households')] == ['parent', 'parent', 'child', 'parent'] - - -def test_add_variable_value(persons): - salary = persons.get_variable('salary') + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Sarah", "1"], + households, + { + "Household_1": {"parents": ["Alicia", "Javier"]}, + "Household_2": {"parents": 1, "children": "Sarah"}, + }, + ) + assert simulation_builder.get_count("households") == 2 + assert simulation_builder.get_ids("households") == ["Household_1", "Household_2"] + assert simulation_builder.get_memberships("households") == [0, 0, 1, 1] + assert [role.key for role in simulation_builder.get_roles("households")] == [ + "parent", + "parent", + "child", + "parent", + ] + + +def test_add_variable_value(persons) -> None: + salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 - simulation_builder.add_variable_value(persons, salary, instance_index, 'Alicia', '2018-11', 3000) - input_array = simulation_builder.get_input('salary', '2018-11') + simulation_builder.entity_counts["persons"] = 1 + simulation_builder.add_variable_value( + persons, + salary, + instance_index, + "Alicia", + "2018-11", + 3000, + ) + input_array = simulation_builder.get_input("salary", "2018-11") assert input_array[instance_index] == pytest.approx(3000) -def test_add_variable_value_as_expression(persons): - salary = persons.get_variable('salary') +def test_add_variable_value_as_expression(persons) -> None: + salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 - simulation_builder.add_variable_value(persons, salary, instance_index, 'Alicia', '2018-11', '3 * 1000') - input_array = simulation_builder.get_input('salary', '2018-11') + simulation_builder.entity_counts["persons"] = 1 + simulation_builder.add_variable_value( + persons, + salary, + instance_index, + "Alicia", + "2018-11", + "3 * 1000", + ) + input_array = simulation_builder.get_input("salary", "2018-11") assert input_array[instance_index] == pytest.approx(3000) -def test_fail_on_wrong_data(persons): - salary = persons.get_variable('salary') +def test_fail_on_wrong_data(persons) -> None: + salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 + simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: - simulation_builder.add_variable_value(persons, salary, instance_index, 'Alicia', '2018-11', 'alicia') - assert excinfo.value.error == {'persons': {'Alicia': {'salary': {'2018-11': "Can't deal with value: expected type number, received 'alicia'."}}}} - - -def test_fail_on_ill_formed_expression(persons): - salary = persons.get_variable('salary') + simulation_builder.add_variable_value( + persons, + salary, + instance_index, + "Alicia", + "2018-11", + "alicia", + ) + assert excinfo.value.error == { + "persons": { + "Alicia": { + "salary": { + "2018-11": "Can't deal with value: expected type number, received 'alicia'.", + }, + }, + }, + } + + +def test_fail_on_ill_formed_expression(persons) -> None: + salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 + simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: - simulation_builder.add_variable_value(persons, salary, instance_index, 'Alicia', '2018-11', '2 * / 1000') - assert excinfo.value.error == {'persons': {'Alicia': {'salary': {'2018-11': "I couldn't understand '2 * / 1000' as a value for 'salary'"}}}} - - -def test_fail_on_integer_overflow(persons, int_variable): + simulation_builder.add_variable_value( + persons, + salary, + instance_index, + "Alicia", + "2018-11", + "2 * / 1000", + ) + assert excinfo.value.error == { + "persons": { + "Alicia": { + "salary": { + "2018-11": "I couldn't understand '2 * / 1000' as a value for 'salary'", + }, + }, + }, + } + + +def test_fail_on_integer_overflow(persons, int_variable) -> None: instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 + simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: - simulation_builder.add_variable_value(persons, int_variable, instance_index, 'Alicia', '2018-11', 9223372036854775808) - assert excinfo.value.error == {'persons': {'Alicia': {'intvar': {'2018-11': "Can't deal with value: '9223372036854775808', it's too large for type 'integer'."}}}} - - -def test_fail_on_date_parsing(persons, date_variable): + simulation_builder.add_variable_value( + persons, + int_variable, + instance_index, + "Alicia", + "2018-11", + 9223372036854775808, + ) + assert excinfo.value.error == { + "persons": { + "Alicia": { + "intvar": { + "2018-11": "Can't deal with value: '9223372036854775808', it's too large for type 'integer'.", + }, + }, + }, + } + + +def test_fail_on_date_parsing(persons, date_variable) -> None: instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 + simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: - simulation_builder.add_variable_value(persons, date_variable, instance_index, 'Alicia', '2018-11', '2019-02-30') - assert excinfo.value.error == {'persons': {'Alicia': {'datevar': {'2018-11': "Can't deal with date: '2019-02-30'."}}}} + simulation_builder.add_variable_value( + persons, + date_variable, + instance_index, + "Alicia", + "2018-11", + "2019-02-30", + ) + assert excinfo.value.error == { + "persons": { + "Alicia": {"datevar": {"2018-11": "Can't deal with date: '2019-02-30'."}}, + }, + } -def test_add_unknown_enum_variable_value(persons, enum_variable): +def test_add_unknown_enum_variable_value(persons, enum_variable) -> None: instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 + simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError): - simulation_builder.add_variable_value(persons, enum_variable, instance_index, 'Alicia', '2018-11', 'baz') + simulation_builder.add_variable_value( + persons, + enum_variable, + instance_index, + "Alicia", + "2018-11", + "baz", + ) -def test_finalize_person_entity(persons): - persons_json = {'Alicia': {'salary': {'2018-11': 3000}}, 'Javier': {}} +def test_finalize_person_entity(persons) -> None: + persons_json = {"Alicia": {"salary": {"2018-11": 3000}}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) population = Population(persons) simulation_builder.finalize_variables_init(population) - tools.assert_near(population.get_holder('salary').get_array('2018-11'), [3000, 0]) + tools.assert_near(population.get_holder("salary").get_array("2018-11"), [3000, 0]) assert population.count == 2 - assert population.ids == ['Alicia', 'Javier'] + assert population.ids == ["Alicia", "Javier"] -def test_canonicalize_period_keys(persons): - persons_json = {'Alicia': {'salary': {'year:2018-01': 100}}} +def test_canonicalize_period_keys(persons) -> None: + persons_json = {"Alicia": {"salary": {"year:2018-01": 100}}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) population = Population(persons) simulation_builder.finalize_variables_init(population) - tools.assert_near(population.get_holder('salary').get_array('2018-12'), [100]) + tools.assert_near(population.get_holder("salary").get_array("2018-12"), [100]) -def test_finalize_households(tax_benefit_system): - simulation = Simulation(tax_benefit_system, tax_benefit_system.instantiate_entities()) +def test_finalize_households(tax_benefit_system) -> None: + simulation = Simulation( + tax_benefit_system, + tax_benefit_system.instantiate_entities(), + ) simulation_builder = SimulationBuilder() - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Sarah', 'Tom'], simulation.household.entity, { - 'Household_1': {'parents': ['Alicia', 'Javier']}, - 'Household_2': {'parents': ['Tom'], 'children': ['Sarah']}, - }) + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Sarah", "Tom"], + simulation.household.entity, + { + "Household_1": {"parents": ["Alicia", "Javier"]}, + "Household_2": {"parents": ["Tom"], "children": ["Sarah"]}, + }, + ) simulation_builder.finalize_variables_init(simulation.household) tools.assert_near(simulation.household.members_entity_id, [0, 0, 1, 1]) - tools.assert_near(simulation.persons.has_role(entities.Household.PARENT), [True, True, False, True]) - - -def test_check_persons_to_allocate(): - entity_plural = 'familles' - persons_plural = 'individus' - person_id = 'Alicia' - entity_id = 'famille1' - role_id = 'parents' - persons_to_allocate = ['Alicia'] - persons_ids = ['Alicia'] + tools.assert_near( + simulation.persons.has_role(entities.Household.PARENT), + [True, True, False, True], + ) + + +def test_check_persons_to_allocate() -> None: + entity_plural = "familles" + persons_plural = "individus" + person_id = "Alicia" + entity_id = "famille1" + role_id = "parents" + persons_to_allocate = ["Alicia"] + persons_ids = ["Alicia"] index = 0 SimulationBuilder().check_persons_to_allocate( persons_plural, @@ -261,135 +387,196 @@ def test_check_persons_to_allocate(): role_id, persons_to_allocate, index, - ) + ) -def test_allocate_undeclared_person(): - entity_plural = 'familles' - persons_plural = 'individus' - person_id = 'Alicia' - entity_id = 'famille1' - role_id = 'parents' - persons_to_allocate = ['Alicia'] +def test_allocate_undeclared_person() -> None: + entity_plural = "familles" + persons_plural = "individus" + person_id = "Alicia" + entity_id = "famille1" + role_id = "parents" + persons_to_allocate = ["Alicia"] persons_ids = [] index = 0 with pytest.raises(SituationParsingError) as exception: SimulationBuilder().check_persons_to_allocate( - persons_plural, entity_plural, + persons_plural, + entity_plural, persons_ids, - person_id, entity_id, role_id, - persons_to_allocate, index) - assert exception.value.error == {'familles': {'famille1': {'parents': 'Unexpected value: Alicia. Alicia has been declared in famille1 parents, but has not been declared in individus.'}}} - - -def test_allocate_person_twice(): - entity_plural = 'familles' - persons_plural = 'individus' - person_id = 'Alicia' - entity_id = 'famille1' - role_id = 'parents' + person_id, + entity_id, + role_id, + persons_to_allocate, + index, + ) + assert exception.value.error == { + "familles": { + "famille1": { + "parents": "Unexpected value: Alicia. Alicia has been declared in famille1 parents, but has not been declared in individus.", + }, + }, + } + + +def test_allocate_person_twice() -> None: + entity_plural = "familles" + persons_plural = "individus" + person_id = "Alicia" + entity_id = "famille1" + role_id = "parents" persons_to_allocate = [] - persons_ids = ['Alicia'] + persons_ids = ["Alicia"] index = 0 with pytest.raises(SituationParsingError) as exception: SimulationBuilder().check_persons_to_allocate( - persons_plural, entity_plural, + persons_plural, + entity_plural, persons_ids, - person_id, entity_id, role_id, - persons_to_allocate, index) - assert exception.value.error == {'familles': {'famille1': {'parents': 'Alicia has been declared more than once in familles'}}} - - -def test_one_person_without_household(tax_benefit_system): - simulation_dict = {'persons': {'Alicia': {}}} - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, simulation_dict) + person_id, + entity_id, + role_id, + persons_to_allocate, + index, + ) + assert exception.value.error == { + "familles": { + "famille1": { + "parents": "Alicia has been declared more than once in familles", + }, + }, + } + + +def test_one_person_without_household(tax_benefit_system) -> None: + simulation_dict = {"persons": {"Alicia": {}}} + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + simulation_dict, + ) assert simulation.household.count == 1 - parents_in_households = simulation.household.nb_persons(role = entities.Household.PARENT) - assert parents_in_households.tolist() == [1] # household member default role is first_parent + parents_in_households = simulation.household.nb_persons( + role=entities.Household.PARENT, + ) + assert parents_in_households.tolist() == [ + 1, + ] # household member default role is first_parent -def test_some_person_without_household(tax_benefit_system): +def test_some_person_without_household(tax_benefit_system) -> None: input_yaml = """ persons: {'Alicia': {}, 'Bob': {}} household: {'parents': ['Alicia']} """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(input_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), + ) assert simulation.household.count == 2 - parents_in_households = simulation.household.nb_persons(role = entities.Household.PARENT) - assert parents_in_households.tolist() == [1, 1] # household member default role is first_parent + parents_in_households = simulation.household.nb_persons( + role=entities.Household.PARENT, + ) + assert parents_in_households.tolist() == [ + 1, + 1, + ] # household member default role is first_parent -def test_nb_persons_in_households(tax_benefit_system): +def test_nb_persons_in_households(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] - households_ids: Iterable = ['c', 'a', 'b'] - persons_households: Iterable = ['c', 'a', 'a', 'b', 'a'] + households_ids: Iterable = ["c", "a", "b"] + persons_households: Iterable = ["c", "a", "a", "b", "a"] simulation_builder = SimulationBuilder() simulation_builder.create_entities(tax_benefit_system) - simulation_builder.declare_person_entity('person', persons_ids) - household_instance = simulation_builder.declare_entity('household', households_ids) - simulation_builder.join_with_persons(household_instance, persons_households, ['first_parent'] * 5) + simulation_builder.declare_person_entity("person", persons_ids) + household_instance = simulation_builder.declare_entity("household", households_ids) + simulation_builder.join_with_persons( + household_instance, + persons_households, + ["first_parent"] * 5, + ) - persons_in_households = simulation_builder.nb_persons('household') + persons_in_households = simulation_builder.nb_persons("household") assert persons_in_households.tolist() == [1, 3, 1] -def test_nb_persons_no_role(tax_benefit_system): +def test_nb_persons_no_role(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] - households_ids: Iterable = ['c', 'a', 'b'] - persons_households: Iterable = ['c', 'a', 'a', 'b', 'a'] + households_ids: Iterable = ["c", "a", "b"] + persons_households: Iterable = ["c", "a", "a", "b", "a"] simulation_builder = SimulationBuilder() simulation_builder.create_entities(tax_benefit_system) - simulation_builder.declare_person_entity('person', persons_ids) - household_instance = simulation_builder.declare_entity('household', households_ids) + simulation_builder.declare_person_entity("person", persons_ids) + household_instance = simulation_builder.declare_entity("household", households_ids) - simulation_builder.join_with_persons(household_instance, persons_households, ['first_parent'] * 5) - parents_in_households = household_instance.nb_persons(role = entities.Household.PARENT) + simulation_builder.join_with_persons( + household_instance, + persons_households, + ["first_parent"] * 5, + ) + parents_in_households = household_instance.nb_persons( + role=entities.Household.PARENT, + ) - assert parents_in_households.tolist() == [1, 3, 1] # household member default role is first_parent + assert parents_in_households.tolist() == [ + 1, + 3, + 1, + ] # household member default role is first_parent -def test_nb_persons_by_role(tax_benefit_system): +def test_nb_persons_by_role(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] - households_ids: Iterable = ['c', 'a', 'b'] - persons_households: Iterable = ['c', 'a', 'a', 'b', 'a'] - persons_households_roles: Iterable = ['child', 'first_parent', 'second_parent', 'first_parent', 'child'] + households_ids: Iterable = ["c", "a", "b"] + persons_households: Iterable = ["c", "a", "a", "b", "a"] + persons_households_roles: Iterable = [ + "child", + "first_parent", + "second_parent", + "first_parent", + "child", + ] simulation_builder = SimulationBuilder() simulation_builder.create_entities(tax_benefit_system) - simulation_builder.declare_person_entity('person', persons_ids) - household_instance = simulation_builder.declare_entity('household', households_ids) + simulation_builder.declare_person_entity("person", persons_ids) + household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( household_instance, persons_households, - persons_households_roles - ) - parents_in_households = household_instance.nb_persons(role = entities.Household.FIRST_PARENT) + persons_households_roles, + ) + parents_in_households = household_instance.nb_persons( + role=entities.Household.FIRST_PARENT, + ) assert parents_in_households.tolist() == [0, 1, 1] -def test_integral_roles(tax_benefit_system): +def test_integral_roles(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] - households_ids: Iterable = ['c', 'a', 'b'] - persons_households: Iterable = ['c', 'a', 'a', 'b', 'a'] + households_ids: Iterable = ["c", "a", "b"] + persons_households: Iterable = ["c", "a", "a", "b", "a"] # Same roles as test_nb_persons_by_role persons_households_roles: Iterable = [2, 0, 1, 0, 2] simulation_builder = SimulationBuilder() simulation_builder.create_entities(tax_benefit_system) - simulation_builder.declare_person_entity('person', persons_ids) - household_instance = simulation_builder.declare_entity('household', households_ids) + simulation_builder.declare_person_entity("person", persons_ids) + household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( household_instance, persons_households, - persons_households_roles - ) - parents_in_households = household_instance.nb_persons(role = entities.Household.FIRST_PARENT) + persons_households_roles, + ) + parents_in_households = household_instance.nb_persons( + role=entities.Household.FIRST_PARENT, + ) assert parents_in_households.tolist() == [0, 1, 1] @@ -397,66 +584,79 @@ def test_integral_roles(tax_benefit_system): # Test Intégration -def test_from_person_variable_to_group(tax_benefit_system): +def test_from_person_variable_to_group(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] - households_ids: Iterable = ['c', 'a', 'b'] + households_ids: Iterable = ["c", "a", "b"] - persons_households: Iterable = ['c', 'a', 'a', 'b', 'a'] + persons_households: Iterable = ["c", "a", "a", "b", "a"] persons_salaries: Iterable = [6000, 2000, 1000, 1500, 1500] households_rents = [1036.6667, 781.6667, 271.6667] - period = '2018-12' + period = "2018-12" simulation_builder = SimulationBuilder() simulation_builder.create_entities(tax_benefit_system) - simulation_builder.declare_person_entity('person', persons_ids) + simulation_builder.declare_person_entity("person", persons_ids) - household_instance = simulation_builder.declare_entity('household', households_ids) - simulation_builder.join_with_persons(household_instance, persons_households, ['first_parent'] * 5) + household_instance = simulation_builder.declare_entity("household", households_ids) + simulation_builder.join_with_persons( + household_instance, + persons_households, + ["first_parent"] * 5, + ) simulation = simulation_builder.build(tax_benefit_system) - simulation.set_input('salary', period, persons_salaries) - simulation.set_input('rent', period, households_rents) + simulation.set_input("salary", period, persons_salaries) + simulation.set_input("rent", period, households_rents) - total_taxes = simulation.calculate('total_taxes', period) + total_taxes = simulation.calculate("total_taxes", period) assert total_taxes == pytest.approx(households_rents) - assert total_taxes / simulation.calculate('rent', period) == pytest.approx(1) + assert total_taxes / simulation.calculate("rent", period) == pytest.approx(1) -def test_simulation(tax_benefit_system): +def test_simulation(tax_benefit_system) -> None: input_yaml = """ salary: 2016-10: 12000 """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(input_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), + ) assert simulation.get_array("salary", "2016-10") == 12000 simulation.calculate("income_tax", "2016-10") simulation.calculate("total_taxes", "2016-10") -def test_vectorial_input(tax_benefit_system): +def test_vectorial_input(tax_benefit_system) -> None: input_yaml = """ salary: 2016-10: [12000, 20000] """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(input_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), + ) tools.assert_near(simulation.get_array("salary", "2016-10"), [12000, 20000]) simulation.calculate("income_tax", "2016-10") simulation.calculate("total_taxes", "2016-10") -def test_fully_specified_entities(tax_benefit_system): - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, situation_examples.couple) +def test_fully_specified_entities(tax_benefit_system) -> None: + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + situation_examples.couple, + ) assert simulation.household.count == 1 assert simulation.persons.count == 2 -def test_single_entity_shortcut(tax_benefit_system): +def test_single_entity_shortcut(tax_benefit_system) -> None: input_yaml = """ persons: Alicia: {} @@ -465,11 +665,14 @@ def test_single_entity_shortcut(tax_benefit_system): parents: [Alicia, Javier] """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(input_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), + ) assert simulation.household.count == 1 -def test_order_preserved(tax_benefit_system): +def test_order_preserved(tax_benefit_system) -> None: input_yaml = """ persons: Javier: {} @@ -484,10 +687,10 @@ def test_order_preserved(tax_benefit_system): data = test_runner.yaml.safe_load(input_yaml) simulation = SimulationBuilder().build_from_dict(tax_benefit_system, data) - assert simulation.persons.ids == ['Javier', 'Alicia', 'Sarah', 'Tom'] + assert simulation.persons.ids == ["Javier", "Alicia", "Sarah", "Tom"] -def test_inconsistent_input(tax_benefit_system): +def test_inconsistent_input(tax_benefit_system) -> None: input_yaml = """ salary: 2016-10: [12000, 20000] @@ -495,5 +698,8 @@ def test_inconsistent_input(tax_benefit_system): 2016-10: [100, 200, 300] """ with pytest.raises(ValueError) as error: - SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(input_yaml)) + SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), + ) assert "its length is 3 while there are 2" in error.value.args[0] diff --git a/tests/core/test_simulations.py b/tests/core/test_simulations.py index d5f3ac8008..7f4897e776 100644 --- a/tests/core/test_simulations.py +++ b/tests/core/test_simulations.py @@ -1,46 +1,47 @@ +import pytest + from openfisca_country_template.situation_examples import single +from openfisca_core import errors, periods from openfisca_core.simulations import SimulationBuilder -def test_calculate_full_tracer(tax_benefit_system): +def test_calculate_full_tracer(tax_benefit_system) -> None: simulation = SimulationBuilder().build_default_simulation(tax_benefit_system) simulation.trace = True - simulation.calculate('income_tax', '2017-01') + simulation.calculate("income_tax", "2017-01") income_tax_node = simulation.tracer.trees[0] - assert income_tax_node.name == 'income_tax' - assert str(income_tax_node.period) == '2017-01' + assert income_tax_node.name == "income_tax" + assert str(income_tax_node.period) == "2017-01" assert income_tax_node.value == 0 salary_node = income_tax_node.children[0] - assert salary_node.name == 'salary' - assert str(salary_node.period) == '2017-01' + assert salary_node.name == "salary" + assert str(salary_node.period) == "2017-01" assert salary_node.parameters == [] assert len(income_tax_node.parameters) == 1 - assert income_tax_node.parameters[0].name == 'taxes.income_tax_rate' - assert income_tax_node.parameters[0].period == '2017-01-01' + assert income_tax_node.parameters[0].name == "taxes.income_tax_rate" + assert income_tax_node.parameters[0].period == "2017-01-01" assert income_tax_node.parameters[0].value == 0.15 -def test_get_entity_not_found(tax_benefit_system): +def test_get_entity_not_found(tax_benefit_system) -> None: simulation = SimulationBuilder().build_default_simulation(tax_benefit_system) - assert simulation.get_entity(plural = "no_such_entities") is None - - -def test_clone(tax_benefit_system): - simulation = SimulationBuilder().build_from_entities(tax_benefit_system, - { - "persons": { - "bill": {"salary": {"2017-01": 3000}}, - }, - "households": { - "household": { - "parents": ["bill"] - } - } - }) + assert simulation.get_entity(plural="no_such_entities") is None + + +def test_clone(tax_benefit_system) -> None: + simulation = SimulationBuilder().build_from_entities( + tax_benefit_system, + { + "persons": { + "bill": {"salary": {"2017-01": 3000}}, + }, + "households": {"household": {"parents": ["bill"]}}, + }, + ) simulation_clone = simulation.clone() assert simulation != simulation_clone @@ -50,17 +51,31 @@ def test_clone(tax_benefit_system): assert simulation.persons != simulation_clone.persons - salary_holder = simulation.person.get_holder('salary') - salary_holder_clone = simulation_clone.person.get_holder('salary') + salary_holder = simulation.person.get_holder("salary") + salary_holder_clone = simulation_clone.person.get_holder("salary") assert salary_holder != salary_holder_clone assert salary_holder_clone.simulation == simulation_clone assert salary_holder_clone.population == simulation_clone.persons -def test_get_memory_usage(tax_benefit_system): +def test_get_memory_usage(tax_benefit_system) -> None: simulation = SimulationBuilder().build_from_entities(tax_benefit_system, single) - simulation.calculate('disposable_income', '2017-01') - memory_usage = simulation.get_memory_usage(variables = ['salary']) - assert(memory_usage['total_nb_bytes'] > 0) - assert(len(memory_usage['by_variable']) == 1) + simulation.calculate("disposable_income", "2017-01") + memory_usage = simulation.get_memory_usage(variables=["salary"]) + assert memory_usage["total_nb_bytes"] > 0 + assert len(memory_usage["by_variable"]) == 1 + + +def test_invalidate_cache_when_spiral_error_detected(tax_benefit_system) -> None: + simulation = SimulationBuilder().build_default_simulation(tax_benefit_system) + tracer = simulation.tracer + + tracer.record_calculation_start("a", periods.period(2017)) + tracer.record_calculation_start("b", periods.period(2016)) + tracer.record_calculation_start("a", periods.period(2016)) + + with pytest.raises(errors.SpiralError): + simulation._check_for_cycle("a", periods.period(2016)) + + assert len(simulation.invalidated_caches) == 3 diff --git a/tests/core/test_tracers.py b/tests/core/test_tracers.py index 383723d20b..c9af9ecee0 100644 --- a/tests/core/test_tracers.py +++ b/tests/core/test_tracers.py @@ -1,44 +1,51 @@ -# -*- coding: utf-8 -*- - +import csv import json import os -import csv -import numpy as np -from pytest import fixture, mark, raises, approx -from openfisca_core.errors import CycleError, SpiralError -from openfisca_core.simulations import Simulation -from openfisca_core.tracers import SimpleTracer, FullTracer, TracingParameterNodeAtInstant, TraceNode +import numpy +from pytest import approx, fixture, mark, raises + from openfisca_country_template.variables.housing import HousingOccupancyStatus + +from openfisca_core import periods +from openfisca_core.simulations import CycleError, Simulation, SpiralError +from openfisca_core.tracers import ( + FullTracer, + SimpleTracer, + TraceNode, + TracingParameterNodeAtInstant, +) + from .parameters_fancy_indexing.test_fancy_indexing import parameters -class StubSimulation(Simulation): +class TestException(Exception): ... + - def __init__(self): +class StubSimulation(Simulation): + def __init__(self) -> None: self.exception = None self.max_spiral_loops = 1 - def _calculate(self, variable, period): + def _calculate(self, variable, period) -> None: if self.exception: raise self.exception - def invalidate_cache_entry(self, variable, period): + def invalidate_cache_entry(self, variable, period) -> None: pass - def purge_cache_of_invalid_values(self): + def purge_cache_of_invalid_values(self) -> None: pass class MockTracer: - - def record_calculation_start(self, variable, period): + def record_calculation_start(self, variable, period) -> None: self.calculation_start_recorded = True - def record_calculation_result(self, value): + def record_calculation_result(self, value) -> None: self.recorded_result = True - def record_calculation_end(self): + def record_calculation_end(self) -> None: self.calculation_end_recorded = True @@ -48,113 +55,134 @@ def tracer(): @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_stack_one_level(tracer): - tracer.record_calculation_start('a', 2017) +def test_stack_one_level(tracer) -> None: + tracer.record_calculation_start("a", 2017) + assert len(tracer.stack) == 1 - assert tracer.stack == [{'name': 'a', 'period': 2017}] + assert tracer.stack == [{"name": "a", "period": 2017}] tracer.record_calculation_end() + assert tracer.stack == [] @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_stack_two_levels(tracer): - tracer.record_calculation_start('a', 2017) - tracer.record_calculation_start('b', 2017) +def test_stack_two_levels(tracer) -> None: + tracer.record_calculation_start("a", 2017) + tracer.record_calculation_start("b", 2017) + assert len(tracer.stack) == 2 - assert tracer.stack == [{'name': 'a', 'period': 2017}, {'name': 'b', 'period': 2017}] + assert tracer.stack == [ + {"name": "a", "period": 2017}, + {"name": "b", "period": 2017}, + ] tracer.record_calculation_end() + assert len(tracer.stack) == 1 - assert tracer.stack == [{'name': 'a', 'period': 2017}] + assert tracer.stack == [{"name": "a", "period": 2017}] @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_tracer_contract(tracer): +def test_tracer_contract(tracer) -> None: simulation = StubSimulation() simulation.tracer = MockTracer() - simulation.calculate('a', 2017) + simulation.calculate("a", 2017) assert simulation.tracer.calculation_start_recorded assert simulation.tracer.calculation_end_recorded -def test_exception_robustness(): +def test_exception_robustness() -> None: simulation = StubSimulation() simulation.tracer = MockTracer() - simulation.exception = Exception(":-o") + simulation.exception = TestException(":-o") - with raises(Exception): - simulation.calculate('a', 2017) + with raises(TestException): + simulation.calculate("a", 2017) assert simulation.tracer.calculation_start_recorded assert simulation.tracer.calculation_end_recorded @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_cycle_error(tracer): +def test_cycle_error(tracer) -> None: simulation = StubSimulation() simulation.tracer = tracer - tracer.record_calculation_start('a', 2017) - simulation._check_for_cycle('a', 2017) - tracer.record_calculation_start('a', 2017) + tracer.record_calculation_start("a", 2017) + + assert not simulation._check_for_cycle("a", 2017) + + tracer.record_calculation_start("a", 2017) + with raises(CycleError): - simulation._check_for_cycle('a', 2017) + simulation._check_for_cycle("a", 2017) + + assert len(tracer.stack) == 2 + assert tracer.stack == [ + {"name": "a", "period": 2017}, + {"name": "a", "period": 2017}, + ] @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_spiral_error(tracer): +def test_spiral_error(tracer) -> None: simulation = StubSimulation() simulation.tracer = tracer - tracer.record_calculation_start('a', 2017) - tracer.record_calculation_start('a', 2016) - tracer.record_calculation_start('a', 2015) + + tracer.record_calculation_start("a", periods.period(2017)) + tracer.record_calculation_start("b", periods.period(2016)) + tracer.record_calculation_start("a", periods.period(2016)) with raises(SpiralError): - simulation._check_for_cycle('a', 2015) + simulation._check_for_cycle("a", periods.period(2016)) + + assert len(tracer.stack) == 3 + assert tracer.stack == [ + {"name": "a", "period": periods.period(2017)}, + {"name": "b", "period": periods.period(2016)}, + {"name": "a", "period": periods.period(2016)}, + ] -def test_full_tracer_one_calculation(tracer): - tracer._enter_calculation('a', 2017) +def test_full_tracer_one_calculation(tracer) -> None: + tracer._enter_calculation("a", 2017) tracer._exit_calculation() + assert tracer.stack == [] assert len(tracer.trees) == 1 - assert tracer.trees[0].name == 'a' + assert tracer.trees[0].name == "a" assert tracer.trees[0].period == 2017 assert tracer.trees[0].children == [] -def test_full_tracer_2_branches(tracer): - tracer._enter_calculation('a', 2017) - - tracer._enter_calculation('b', 2017) +def test_full_tracer_2_branches(tracer) -> None: + tracer._enter_calculation("a", 2017) + tracer._enter_calculation("b", 2017) tracer._exit_calculation() - - tracer._enter_calculation('c', 2017) + tracer._enter_calculation("c", 2017) tracer._exit_calculation() - tracer._exit_calculation() assert len(tracer.trees) == 1 assert len(tracer.trees[0].children) == 2 -def test_full_tracer_2_trees(tracer): - tracer._enter_calculation('b', 2017) +def test_full_tracer_2_trees(tracer) -> None: + tracer._enter_calculation("b", 2017) tracer._exit_calculation() - - tracer._enter_calculation('c', 2017) + tracer._enter_calculation("c", 2017) tracer._exit_calculation() assert len(tracer.trees) == 2 -def test_full_tracer_3_generations(tracer): - tracer._enter_calculation('a', 2017) - tracer._enter_calculation('b', 2017) - tracer._enter_calculation('c', 2017) +def test_full_tracer_3_generations(tracer) -> None: + tracer._enter_calculation("a", 2017) + tracer._enter_calculation("b", 2017) + tracer._enter_calculation("c", 2017) tracer._exit_calculation() tracer._exit_calculation() tracer._exit_calculation() @@ -164,117 +192,118 @@ def test_full_tracer_3_generations(tracer): assert len(tracer.trees[0].children[0].children) == 1 -def test_full_tracer_variable_nb_requests(tracer): - tracer._enter_calculation('a', '2017-01') - tracer._enter_calculation('a', '2017-02') +def test_full_tracer_variable_nb_requests(tracer) -> None: + tracer._enter_calculation("a", "2017-01") + tracer._enter_calculation("a", "2017-02") - assert tracer.get_nb_requests('a') == 2 + assert tracer.get_nb_requests("a") == 2 -def test_simulation_calls_record_calculation_result(): +def test_simulation_calls_record_calculation_result() -> None: simulation = StubSimulation() simulation.tracer = MockTracer() - simulation.calculate('a', 2017) + simulation.calculate("a", 2017) assert simulation.tracer.recorded_result -def test_record_calculation_result(tracer): - tracer._enter_calculation('a', 2017) - tracer.record_calculation_result(np.asarray(100)) +def test_record_calculation_result(tracer) -> None: + tracer._enter_calculation("a", 2017) + tracer.record_calculation_result(numpy.asarray(100)) tracer._exit_calculation() assert tracer.trees[0].value == 100 -def test_flat_trace(tracer): - tracer._enter_calculation('a', 2019) - tracer._enter_calculation('b', 2019) +def test_flat_trace(tracer) -> None: + tracer._enter_calculation("a", 2019) + tracer._enter_calculation("b", 2019) tracer._exit_calculation() tracer._exit_calculation() trace = tracer.get_flat_trace() assert len(trace) == 2 - assert trace['a<2019>']['dependencies'] == ['b<2019>'] - assert trace['b<2019>']['dependencies'] == [] + assert trace["a<2019>"]["dependencies"] == ["b<2019>"] + assert trace["b<2019>"]["dependencies"] == [] -def test_flat_trace_serialize_vectorial_values(tracer): - tracer._enter_calculation('a', 2019) - tracer.record_parameter_access('x.y.z', 2019, np.asarray([100, 200, 300])) - tracer.record_calculation_result(np.asarray([10, 20, 30])) +def test_flat_trace_serialize_vectorial_values(tracer) -> None: + tracer._enter_calculation("a", 2019) + tracer.record_parameter_access("x.y.z", 2019, numpy.asarray([100, 200, 300])) + tracer.record_calculation_result(numpy.asarray([10, 20, 30])) tracer._exit_calculation() trace = tracer.get_serialized_flat_trace() - assert json.dumps(trace['a<2019>']['value']) - assert json.dumps(trace['a<2019>']['parameters']['x.y.z<2019>']) + assert json.dumps(trace["a<2019>"]["value"]) + assert json.dumps(trace["a<2019>"]["parameters"]["x.y.z<2019>"]) -def test_flat_trace_with_parameter(tracer): - tracer._enter_calculation('a', 2019) - tracer.record_parameter_access('p', '2019-01-01', 100) +def test_flat_trace_with_parameter(tracer) -> None: + tracer._enter_calculation("a", 2019) + tracer.record_parameter_access("p", "2019-01-01", 100) tracer._exit_calculation() trace = tracer.get_flat_trace() assert len(trace) == 1 - assert trace['a<2019>']['parameters'] == {'p<2019-01-01>': 100} + assert trace["a<2019>"]["parameters"] == {"p<2019-01-01>": 100} -def test_flat_trace_with_cache(tracer): - tracer._enter_calculation('a', 2019) - tracer._enter_calculation('b', 2019) - tracer._enter_calculation('c', 2019) +def test_flat_trace_with_cache(tracer) -> None: + tracer._enter_calculation("a", 2019) + tracer._enter_calculation("b", 2019) + tracer._enter_calculation("c", 2019) tracer._exit_calculation() tracer._exit_calculation() tracer._exit_calculation() - tracer._enter_calculation('b', 2019) + tracer._enter_calculation("b", 2019) tracer._exit_calculation() trace = tracer.get_flat_trace() - assert trace['b<2019>']['dependencies'] == ['c<2019>'] + assert trace["b<2019>"]["dependencies"] == ["c<2019>"] -def test_calculation_time(): +def test_calculation_time() -> None: tracer = FullTracer() - tracer._enter_calculation('a', 2019) + tracer._enter_calculation("a", 2019) tracer._record_start_time(1500) tracer._record_end_time(2500) tracer._exit_calculation() - performance_json = tracer.performance_log._json() - assert performance_json['name'] == 'All calculations' - assert performance_json['value'] == 1000 - simulation_children = performance_json['children'] - assert simulation_children[0]['name'] == 'a<2019>' - assert simulation_children[0]['value'] == 1000 + assert performance_json["name"] == "All calculations" + assert performance_json["value"] == 1000 + + simulation_children = performance_json["children"] + + assert simulation_children[0]["name"] == "a<2019>" + assert simulation_children[0]["value"] == 1000 @fixture def tracer_calc_time(): tracer = FullTracer() - tracer._enter_calculation('a', 2019) + tracer._enter_calculation("a", 2019) tracer._record_start_time(1500) - tracer._enter_calculation('b', 2019) + tracer._enter_calculation("b", 2019) tracer._record_start_time(1600) tracer._record_end_time(2300) tracer._exit_calculation() - tracer._enter_calculation('c', 2019) + tracer._enter_calculation("c", 2019) tracer._record_start_time(2300) tracer._record_end_time(2400) tracer._exit_calculation() # Cache call - tracer._enter_calculation('c', 2019) + tracer._enter_calculation("c", 2019) tracer._record_start_time(2400) tracer._record_end_time(2410) tracer._exit_calculation() @@ -282,7 +311,7 @@ def tracer_calc_time(): tracer._record_end_time(2500) tracer._exit_calculation() - tracer._enter_calculation('a', 2018) + tracer._enter_calculation("a", 2018) tracer._record_start_time(1800) tracer._record_end_time(1800 + 200) tracer._exit_calculation() @@ -290,185 +319,243 @@ def tracer_calc_time(): return tracer -def test_calculation_time_with_depth(tracer_calc_time): +def test_calculation_time_with_depth(tracer_calc_time) -> None: tracer = tracer_calc_time performance_json = tracer.performance_log._json() - simulation_grand_children = performance_json['children'][0]['children'] + simulation_grand_children = performance_json["children"][0]["children"] - assert simulation_grand_children[0]['name'] == 'b<2019>' - assert simulation_grand_children[0]['value'] == 700 + assert simulation_grand_children[0]["name"] == "b<2019>" + assert simulation_grand_children[0]["value"] == 700 -def test_flat_trace_calc_time(tracer_calc_time): +def test_flat_trace_calc_time(tracer_calc_time) -> None: tracer = tracer_calc_time flat_trace = tracer.get_flat_trace() - assert flat_trace['a<2019>']['calculation_time'] == 1000 - assert flat_trace['b<2019>']['calculation_time'] == 700 - assert flat_trace['c<2019>']['calculation_time'] == 100 - assert flat_trace['a<2019>']['formula_time'] == 190 # 1000 - 700 - 100 - 10 - assert flat_trace['b<2019>']['formula_time'] == 700 - assert flat_trace['c<2019>']['formula_time'] == 100 + assert flat_trace["a<2019>"]["calculation_time"] == 1000 + assert flat_trace["b<2019>"]["calculation_time"] == 700 + assert flat_trace["c<2019>"]["calculation_time"] == 100 + assert flat_trace["a<2019>"]["formula_time"] == 190 # 1000 - 700 - 100 - 10 + assert flat_trace["b<2019>"]["formula_time"] == 700 + assert flat_trace["c<2019>"]["formula_time"] == 100 -def test_generate_performance_table(tracer_calc_time, tmpdir): +def test_generate_performance_table(tracer_calc_time, tmpdir) -> None: tracer = tracer_calc_time tracer.generate_performance_tables(tmpdir) - with open(os.path.join(tmpdir, 'performance_table.csv'), 'r') as csv_file: + + with open(os.path.join(tmpdir, "performance_table.csv")) as csv_file: csv_reader = csv.DictReader(csv_file) csv_rows = list(csv_reader) + assert len(csv_rows) == 4 - a_row = next(row for row in csv_rows if row['name'] == 'a<2019>') - assert float(a_row['calculation_time']) == 1000 - assert float(a_row['formula_time']) == 190 - with open(os.path.join(tmpdir, 'aggregated_performance_table.csv'), 'r') as csv_file: + a_row = next(row for row in csv_rows if row["name"] == "a<2019>") + + assert float(a_row["calculation_time"]) == 1000 + assert float(a_row["formula_time"]) == 190 + + with open(os.path.join(tmpdir, "aggregated_performance_table.csv")) as csv_file: aggregated_csv_reader = csv.DictReader(csv_file) aggregated_csv_rows = list(aggregated_csv_reader) + assert len(aggregated_csv_rows) == 3 - a_row = next(row for row in aggregated_csv_rows if row['name'] == 'a') - assert float(a_row['calculation_time']) == 1000 + 200 - assert float(a_row['formula_time']) == 190 + 200 + a_row = next(row for row in aggregated_csv_rows if row["name"] == "a") -def test_get_aggregated_calculation_times(tracer_calc_time): - perf_log = tracer_calc_time.performance_log - aggregated_calculation_times = perf_log.aggregate_calculation_times(tracer_calc_time.get_flat_trace()) + assert float(a_row["calculation_time"]) == 1000 + 200 + assert float(a_row["formula_time"]) == 190 + 200 - assert aggregated_calculation_times['a']['calculation_time'] == 1000 + 200 - assert aggregated_calculation_times['a']['formula_time'] == 190 + 200 - assert aggregated_calculation_times['a']['avg_calculation_time'] == (1000 + 200) / 2 - assert aggregated_calculation_times['a']['avg_formula_time'] == (190 + 200) / 2 +def test_get_aggregated_calculation_times(tracer_calc_time) -> None: + perf_log = tracer_calc_time.performance_log + aggregated_calculation_times = perf_log.aggregate_calculation_times( + tracer_calc_time.get_flat_trace(), + ) + + assert aggregated_calculation_times["a"]["calculation_time"] == 1000 + 200 + assert aggregated_calculation_times["a"]["formula_time"] == 190 + 200 + assert aggregated_calculation_times["a"]["avg_calculation_time"] == (1000 + 200) / 2 + assert aggregated_calculation_times["a"]["avg_formula_time"] == (190 + 200) / 2 -def test_rounding(): - node_a = TraceNode('a', 2017) +def test_rounding() -> None: + node_a = TraceNode("a", 2017) node_a.start = 1.23456789 node_a.end = node_a.start + 1.23456789e-03 assert node_a.calculation_time() == 1.235e-03 # Keep only 3 significant figures - node_b = TraceNode('b', 2017) + node_b = TraceNode("b", 2017) node_b.start = node_a.start node_b.end = node_a.end - 1.23456789e-08 node_a.children = [node_b] - assert node_a.formula_time() == 1.235e-08 # The rounding should not prevent from calculating a precise formula_time + assert ( + node_a.formula_time() == 1.235e-08 + ) # The rounding should not prevent from calculating a precise formula_time -def test_variable_stats(tracer): +def test_variable_stats(tracer) -> None: tracer._enter_calculation("A", 2017) tracer._enter_calculation("B", 2017) tracer._enter_calculation("B", 2017) tracer._enter_calculation("B", 2016) - assert tracer.get_nb_requests('B') == 3 - assert tracer.get_nb_requests('A') == 1 - assert tracer.get_nb_requests('C') == 0 + assert tracer.get_nb_requests("B") == 3 + assert tracer.get_nb_requests("A") == 1 + assert tracer.get_nb_requests("C") == 0 -def test_log_format(tracer): +def test_log_format(tracer) -> None: tracer._enter_calculation("A", 2017) tracer._enter_calculation("B", 2017) - tracer.record_calculation_result(np.asarray([1])) + tracer.record_calculation_result(numpy.asarray([1])) tracer._exit_calculation() - tracer.record_calculation_result(np.asarray([2])) + tracer.record_calculation_result(numpy.asarray([2])) tracer._exit_calculation() - lines = tracer.computation_log.lines() - assert lines[0] == ' A<2017> >> [2]' - assert lines[1] == ' B<2017> >> [1]' + + assert lines[0] == " A<2017> >> [2]" + assert lines[1] == " B<2017> >> [1]" -def test_log_format_forest(tracer): +def test_log_format_forest(tracer) -> None: tracer._enter_calculation("A", 2017) - tracer.record_calculation_result(np.asarray([1])) + tracer.record_calculation_result(numpy.asarray([1])) tracer._exit_calculation() - tracer._enter_calculation("B", 2017) - tracer.record_calculation_result(np.asarray([2])) + tracer.record_calculation_result(numpy.asarray([2])) tracer._exit_calculation() - lines = tracer.computation_log.lines() - assert lines[0] == ' A<2017> >> [1]' - assert lines[1] == ' B<2017> >> [2]' + assert lines[0] == " A<2017> >> [1]" + assert lines[1] == " B<2017> >> [2]" -def test_log_aggregate(tracer): + +def test_log_aggregate(tracer) -> None: tracer._enter_calculation("A", 2017) - tracer.record_calculation_result(np.asarray([1])) + tracer.record_calculation_result(numpy.asarray([1])) tracer._exit_calculation() + lines = tracer.computation_log.lines(aggregate=True) - lines = tracer.computation_log.lines(aggregate = True) assert lines[0] == " A<2017> >> {'avg': 1.0, 'max': 1, 'min': 1}" -def test_log_aggregate_with_enum(tracer): +def test_log_aggregate_with_enum(tracer) -> None: tracer._enter_calculation("A", 2017) - tracer.record_calculation_result(HousingOccupancyStatus.encode(np.repeat('tenant', 100))) + tracer.record_calculation_result( + HousingOccupancyStatus.encode(numpy.repeat("tenant", 100)), + ) tracer._exit_calculation() + lines = tracer.computation_log.lines(aggregate=True) - lines = tracer.computation_log.lines(aggregate = True) - assert lines[0] == " A<2017> >> {'avg': EnumArray(HousingOccupancyStatus.tenant), 'max': EnumArray(HousingOccupancyStatus.tenant), 'min': EnumArray(HousingOccupancyStatus.tenant)}" + assert ( + lines[0] + == " A<2017> >> {'avg': EnumArray([HousingOccupancyStatus.tenant]), 'max': EnumArray([HousingOccupancyStatus.tenant]), 'min': EnumArray([HousingOccupancyStatus.tenant])}" + ) -def test_log_aggregate_with_strings(tracer): +def test_log_aggregate_with_strings(tracer) -> None: tracer._enter_calculation("A", 2017) - tracer.record_calculation_result(np.repeat('foo', 100)) + tracer.record_calculation_result(numpy.repeat("foo", 100)) tracer._exit_calculation() + lines = tracer.computation_log.lines(aggregate=True) - lines = tracer.computation_log.lines(aggregate = True) assert lines[0] == " A<2017> >> {'avg': '?', 'max': '?', 'min': '?'}" -def test_no_wrapping(tracer): +def test_log_max_depth(tracer) -> None: tracer._enter_calculation("A", 2017) - tracer.record_calculation_result(HousingOccupancyStatus.encode(np.repeat('tenant', 100))) + tracer._enter_calculation("B", 2017) + tracer._enter_calculation("C", 2017) + tracer.record_calculation_result(numpy.asarray([3])) tracer._exit_calculation() + tracer.record_calculation_result(numpy.asarray([2])) + tracer._exit_calculation() + tracer.record_calculation_result(numpy.asarray([1])) + tracer._exit_calculation() + + assert len(tracer.computation_log.lines()) == 3 + assert len(tracer.computation_log.lines(max_depth=4)) == 3 + assert len(tracer.computation_log.lines(max_depth=3)) == 3 + assert len(tracer.computation_log.lines(max_depth=2)) == 2 + assert len(tracer.computation_log.lines(max_depth=1)) == 1 + assert len(tracer.computation_log.lines(max_depth=0)) == 0 + +def test_no_wrapping(tracer) -> None: + tracer._enter_calculation("A", 2017) + tracer.record_calculation_result( + HousingOccupancyStatus.encode(numpy.repeat("tenant", 100)), + ) + tracer._exit_calculation() lines = tracer.computation_log.lines() + assert "'tenant'" in lines[0] assert "\n" not in lines[0] -def test_trace_enums(tracer): +def test_trace_enums(tracer) -> None: tracer._enter_calculation("A", 2017) - tracer.record_calculation_result(HousingOccupancyStatus.encode(np.array(['tenant']))) + tracer.record_calculation_result( + HousingOccupancyStatus.encode(numpy.array(["tenant"])), + ) tracer._exit_calculation() - lines = tracer.computation_log.lines() + assert lines[0] == " A<2017> >> ['tenant']" # Tests on tracing with fancy indexing -zone = np.asarray(['z1', 'z2', 'z2', 'z1']) -housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant']) -family_status = np.asarray(['single', 'couple', 'single', 'couple']) +zone = numpy.asarray(["z1", "z2", "z2", "z1"]) +housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) +family_status = numpy.asarray(["single", "couple", "single", "couple"]) -def check_tracing_params(accessor, param_key): +def check_tracing_params(accessor, param_key) -> None: tracer = FullTracer() - tracer._enter_calculation('A', '2015-01') - tracingParams = TracingParameterNodeAtInstant(parameters('2015-01-01'), tracer) + + tracer._enter_calculation("A", "2015-01") + + tracingParams = TracingParameterNodeAtInstant(parameters("2015-01-01"), tracer) param = accessor(tracingParams) + assert tracer.trees[0].parameters[0].name == param_key assert tracer.trees[0].parameters[0].value == approx(param) -@mark.parametrize("test", [ - (lambda P: P.rate.single.owner.z1, 'rate.single.owner.z1'), # basic case - (lambda P: P.rate.single.owner[zone], 'rate.single.owner'), # fancy indexing on leaf - (lambda P: P.rate.single[housing_occupancy_status].z1, 'rate.single'), # on a node - (lambda P: P.rate.single[housing_occupancy_status][zone], 'rate.single'), # double fancy indexing - (lambda P: P.rate[family_status][housing_occupancy_status].z2, 'rate'), # double + node - (lambda P: P.rate[family_status][housing_occupancy_status][zone], 'rate'), # triple - ]) -def test_parameters(test): +@mark.parametrize( + "test", + [ + (lambda P: P.rate.single.owner.z1, "rate.single.owner.z1"), # basic case + ( + lambda P: P.rate.single.owner[zone], + "rate.single.owner", + ), # fancy indexing on leaf + ( + lambda P: P.rate.single[housing_occupancy_status].z1, + "rate.single", + ), # on a node + ( + lambda P: P.rate.single[housing_occupancy_status][zone], + "rate.single", + ), # double fancy indexing + ( + lambda P: P.rate[family_status][housing_occupancy_status].z2, + "rate", + ), # double + node + ( + lambda P: P.rate[family_status][housing_occupancy_status][zone], + "rate", + ), # triple + ], +) +def test_parameters(test) -> None: check_tracing_params(*test) -def test_browse_trace(): +def test_browse_trace() -> None: tracer = FullTracer() tracer._enter_calculation("B", 2017) @@ -481,6 +568,6 @@ def test_browse_trace(): tracer._enter_calculation("F", 2017) tracer._exit_calculation() tracer._exit_calculation() - browsed_nodes = [node.name for node in tracer.browse_trace()] - assert browsed_nodes == ['B', 'C', 'D', 'E', 'F'] + + assert browsed_nodes == ["B", "C", "D", "E", "F"] diff --git a/tests/core/test_yaml.py b/tests/core/test_yaml.py index f63e37ff39..0560941b4c 100644 --- a/tests/core/test_yaml.py +++ b/tests/core/test_yaml.py @@ -1,11 +1,12 @@ import os import subprocess +import sys import pytest + import openfisca_extension_template from openfisca_core.tools.test_runner import run_tests - from tests.fixtures import yaml_tests yaml_tests_dir = os.path.dirname(yaml_tests.__file__) @@ -13,93 +14,124 @@ EXIT_TESTSFAILED = 1 -def run_yaml_test(tax_benefit_system, path, options = None): +def run_yaml_test(tax_benefit_system, path, options=None): yaml_path = os.path.join(yaml_tests_dir, path) if options is None: options = {} - result = run_tests(tax_benefit_system, yaml_path, options) - return result + return run_tests(tax_benefit_system, yaml_path, options) -def test_success(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_success.yml') == EXIT_OK +def test_success(tax_benefit_system) -> None: + assert run_yaml_test(tax_benefit_system, "test_success.yml") == EXIT_OK -def test_fail(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_failure.yaml') == EXIT_TESTSFAILED +def test_fail(tax_benefit_system) -> None: + assert run_yaml_test(tax_benefit_system, "test_failure.yaml") == EXIT_TESTSFAILED -def test_relative_error_margin_success(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_relative_error_margin.yaml') == EXIT_OK +def test_relative_error_margin_success(tax_benefit_system) -> None: + assert ( + run_yaml_test(tax_benefit_system, "test_relative_error_margin.yaml") == EXIT_OK + ) -def test_relative_error_margin_fail(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'failing_test_relative_error_margin.yaml') == EXIT_TESTSFAILED +def test_relative_error_margin_fail(tax_benefit_system) -> None: + assert ( + run_yaml_test(tax_benefit_system, "failing_test_relative_error_margin.yaml") + == EXIT_TESTSFAILED + ) -def test_absolute_error_margin_success(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_absolute_error_margin.yaml') == EXIT_OK +def test_absolute_error_margin_success(tax_benefit_system) -> None: + assert ( + run_yaml_test(tax_benefit_system, "test_absolute_error_margin.yaml") == EXIT_OK + ) -def test_absolute_error_margin_fail(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'failing_test_absolute_error_margin.yaml') == EXIT_TESTSFAILED +def test_absolute_error_margin_fail(tax_benefit_system) -> None: + assert ( + run_yaml_test(tax_benefit_system, "failing_test_absolute_error_margin.yaml") + == EXIT_TESTSFAILED + ) -def test_run_tests_from_directory(tax_benefit_system): - dir_path = os.path.join(yaml_tests_dir, 'directory') +def test_run_tests_from_directory(tax_benefit_system) -> None: + dir_path = os.path.join(yaml_tests_dir, "directory") assert run_yaml_test(tax_benefit_system, dir_path) == EXIT_OK -def test_with_reform(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_with_reform.yaml') == EXIT_OK +def test_with_reform(tax_benefit_system) -> None: + assert run_yaml_test(tax_benefit_system, "test_with_reform.yaml") == EXIT_OK -def test_with_extension(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_with_extension.yaml') == EXIT_OK +def test_with_extension(tax_benefit_system) -> None: + assert run_yaml_test(tax_benefit_system, "test_with_extension.yaml") == EXIT_OK -def test_with_anchors(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_with_anchors.yaml') == EXIT_OK +def test_with_anchors(tax_benefit_system) -> None: + assert run_yaml_test(tax_benefit_system, "test_with_anchors.yaml") == EXIT_OK -def test_run_tests_from_directory_fail(tax_benefit_system): +def test_run_tests_from_directory_fail(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, yaml_tests_dir) == EXIT_TESTSFAILED -def test_name_filter(tax_benefit_system): - assert run_yaml_test( - tax_benefit_system, - yaml_tests_dir, - options = {'name_filter': 'success'} - ) == EXIT_OK +def test_name_filter(tax_benefit_system) -> None: + assert ( + run_yaml_test( + tax_benefit_system, + yaml_tests_dir, + options={"name_filter": "success"}, + ) + == EXIT_OK + ) -def test_shell_script(): - yaml_path = os.path.join(yaml_tests_dir, 'test_success.yml') - command = ['openfisca', 'test', yaml_path, '-c', 'openfisca_country_template'] - with open(os.devnull, 'wb') as devnull: - subprocess.check_call(command, stdout = devnull, stderr = devnull) +def test_shell_script() -> None: + yaml_path = os.path.join(yaml_tests_dir, "test_success.yml") + command = ["openfisca", "test", yaml_path, "-c", "openfisca_country_template"] + with open(os.devnull, "wb") as devnull: + subprocess.check_call(command, stdout=devnull, stderr=devnull) -def test_failing_shell_script(): - yaml_path = os.path.join(yaml_tests_dir, 'test_failure.yaml') - command = ['openfisca', 'test', yaml_path, '-c', 'openfisca_dummy_country'] - with open(os.devnull, 'wb') as devnull: +def test_failing_shell_script() -> None: + yaml_path = os.path.join(yaml_tests_dir, "test_failure.yaml") + command = ["openfisca", "test", yaml_path, "-c", "openfisca_dummy_country"] + with open(os.devnull, "wb") as devnull: with pytest.raises(subprocess.CalledProcessError): - subprocess.check_call(command, stdout = devnull, stderr = devnull) - - -def test_shell_script_with_reform(): - yaml_path = os.path.join(yaml_tests_dir, 'test_with_reform_2.yaml') - command = ['openfisca', 'test', yaml_path, '-c', 'openfisca_country_template', '-r', 'openfisca_country_template.reforms.removal_basic_income.removal_basic_income'] - with open(os.devnull, 'wb') as devnull: - subprocess.check_call(command, stdout = devnull, stderr = devnull) - - -def test_shell_script_with_extension(): - tests_dir = os.path.join(openfisca_extension_template.__path__[0], 'tests') - command = ['openfisca', 'test', tests_dir, '-c', 'openfisca_country_template', '-e', 'openfisca_extension_template'] - with open(os.devnull, 'wb') as devnull: - subprocess.check_call(command, stdout = devnull, stderr = devnull) + subprocess.check_call(command, stdout=devnull, stderr=devnull) + + +def test_shell_script_with_reform() -> None: + yaml_path = os.path.join(yaml_tests_dir, "test_with_reform_2.yaml") + command = [ + "openfisca", + "test", + yaml_path, + "-c", + "openfisca_country_template", + "-r", + "openfisca_country_template.reforms.removal_basic_income.removal_basic_income", + ] + with open(os.devnull, "wb") as devnull: + subprocess.check_call(command, stdout=devnull, stderr=devnull) + + +# TODO(Mauko Quiroga-Alvarado): Fix this test +# https://github.com/openfisca/openfisca-core/issues/962 +@pytest.mark.skipif(sys.platform == "win32", reason="Does not work on Windows.") +def test_shell_script_with_extension() -> None: + tests_dir = os.path.join(openfisca_extension_template.__path__[0], "tests") + command = [ + "openfisca", + "test", + tests_dir, + "-c", + "openfisca_country_template", + "-e", + "openfisca_extension_template", + ] + with open(os.devnull, "wb") as devnull: + subprocess.check_call(command, stdout=devnull, stderr=devnull) diff --git a/tests/core/tools/test_assert_near.py b/tests/core/tools/test_assert_near.py index eecf9d1d1f..c351be0f9c 100644 --- a/tests/core/tools/test_assert_near.py +++ b/tests/core/tools/test_assert_near.py @@ -1,21 +1,25 @@ -import numpy as np +import numpy from openfisca_core.tools import assert_near -def test_date(): - assert_near(np.array("2012-03-24", dtype = 'datetime64[D]'), "2012-03-24") +def test_date() -> None: + assert_near(numpy.array("2012-03-24", dtype="datetime64[D]"), "2012-03-24") -def test_enum(tax_benefit_system): - possible_values = tax_benefit_system.variables['housing_occupancy_status'].possible_values - value = possible_values.encode(np.array(['tenant'])) - expected_value = 'tenant' +def test_enum(tax_benefit_system) -> None: + possible_values = tax_benefit_system.variables[ + "housing_occupancy_status" + ].possible_values + value = possible_values.encode(numpy.array(["tenant"])) + expected_value = "tenant" assert_near(value, expected_value) -def test_enum_2(tax_benefit_system): - possible_values = tax_benefit_system.variables['housing_occupancy_status'].possible_values - value = possible_values.encode(np.array(['tenant', 'owner'])) - expected_value = ['tenant', 'owner'] +def test_enum_2(tax_benefit_system) -> None: + possible_values = tax_benefit_system.variables[ + "housing_occupancy_status" + ].possible_values + value = possible_values.encode(numpy.array(["tenant", "owner"])) + expected_value = ["tenant", "owner"] assert_near(value, expected_value) diff --git a/tests/core/tools/test_runner/test_yaml_runner.py b/tests/core/tools/test_runner/test_yaml_runner.py index a8cd55c154..6a02d14cef 100644 --- a/tests/core/tools/test_runner/test_yaml_runner.py +++ b/tests/core/tools/test_runner/test_yaml_runner.py @@ -1,21 +1,20 @@ import os -from typing import List +import numpy import pytest -import numpy as np -from openfisca_core.tools.test_runner import _get_tax_benefit_system, YamlItem, YamlFile -from openfisca_core.errors import VariableNotFoundError -from openfisca_core.variables import Variable -from openfisca_core.populations import Population +from openfisca_core import errors from openfisca_core.entities import Entity -from openfisca_core.periods import ETERNITY +from openfisca_core.periods import DateUnit +from openfisca_core.populations import Population +from openfisca_core.tools.test_runner import YamlFile, YamlItem, _get_tax_benefit_system +from openfisca_core.variables import Variable class TaxBenefitSystem: - def __init__(self): - self.variables = {'salary': TestVariable()} - self.person_entity = Entity('person', 'persons', None, "") + def __init__(self) -> None: + self.variables = {"salary": TestVariable()} + self.person_entity = Entity("person", "persons", None, "") self.person_entity.set_tax_benefit_system(self) def get_package_metadata(self): @@ -24,7 +23,7 @@ def get_package_metadata(self): def apply_reform(self, path): return Reform(self) - def load_extension(self, extension): + def load_extension(self, extension) -> None: pass def entities_by_singular(self): @@ -34,9 +33,9 @@ def entities_plural(self): return {} def instantiate_entities(self): - return {'person': Population(self.person_entity)} + return {"person": Population(self.person_entity)} - def get_variable(self, variable_name, check_existence = True): + def get_variable(self, variable_name: str, check_existence=True): return self.variables.get(variable_name) def clone(self): @@ -44,106 +43,118 @@ def clone(self): class Reform(TaxBenefitSystem): - def __init__(self, baseline): + def __init__(self, baseline) -> None: self.baseline = baseline class Simulation: - def __init__(self): + def __init__(self) -> None: self.populations = {"person": None} - def get_population(self, plural = None): + def get_population(self, plural=None) -> None: return None class TestFile(YamlFile): - - def __init__(self): + def __init__(self) -> None: self.config = None self.session = None - self._nodeid = 'testname' + self._nodeid = "testname" class TestItem(YamlItem): - def __init__(self, test): - super().__init__('', TestFile(), TaxBenefitSystem(), test, {}) + def __init__(self, test) -> None: + super().__init__("", TestFile(), TaxBenefitSystem(), test, {}) self.tax_benefit_system = self.baseline_tax_benefit_system self.simulation = Simulation() class TestVariable(Variable): - definition_period = ETERNITY + definition_period = DateUnit.ETERNITY value_type = float - def __init__(self): + def __init__(self) -> None: self.end = None - self.entity = Entity('person', 'persons', None, "") + self.entity = Entity("person", "persons", None, "") self.is_neutralized = False self.set_input = None - self.dtype = np.float32 + self.dtype = numpy.float32 -def test_variable_not_found(): +@pytest.mark.skip(reason="Deprecated node constructor") +def test_variable_not_found() -> None: test = {"output": {"unknown_variable": 0}} - with pytest.raises(VariableNotFoundError) as excinfo: + with pytest.raises(errors.VariableNotFoundError) as excinfo: test_item = TestItem(test) test_item.check_output() assert excinfo.value.variable_name == "unknown_variable" -def test_tax_benefit_systems_with_reform_cache(): +def test_tax_benefit_systems_with_reform_cache() -> None: baseline = TaxBenefitSystem() - ab_tax_benefit_system = _get_tax_benefit_system(baseline, 'ab', []) - ba_tax_benefit_system = _get_tax_benefit_system(baseline, 'ba', []) + ab_tax_benefit_system = _get_tax_benefit_system(baseline, "ab", []) + ba_tax_benefit_system = _get_tax_benefit_system(baseline, "ba", []) assert ab_tax_benefit_system != ba_tax_benefit_system -def test_reforms_formats(): +def test_reforms_formats() -> None: baseline = TaxBenefitSystem() - lonely_reform_tbs = _get_tax_benefit_system(baseline, 'lonely_reform', []) - list_lonely_reform_tbs = _get_tax_benefit_system(baseline, ['lonely_reform'], []) + lonely_reform_tbs = _get_tax_benefit_system(baseline, "lonely_reform", []) + list_lonely_reform_tbs = _get_tax_benefit_system(baseline, ["lonely_reform"], []) assert lonely_reform_tbs == list_lonely_reform_tbs -def test_reforms_order(): +def test_reforms_order() -> None: baseline = TaxBenefitSystem() - abba_tax_benefit_system = _get_tax_benefit_system(baseline, ['ab', 'ba'], []) - baab_tax_benefit_system = _get_tax_benefit_system(baseline, ['ba', 'ab'], []) - assert abba_tax_benefit_system != baab_tax_benefit_system # keep reforms order in cache + abba_tax_benefit_system = _get_tax_benefit_system(baseline, ["ab", "ba"], []) + baab_tax_benefit_system = _get_tax_benefit_system(baseline, ["ba", "ab"], []) + assert ( + abba_tax_benefit_system != baab_tax_benefit_system + ) # keep reforms order in cache -def test_tax_benefit_systems_with_extensions_cache(): +def test_tax_benefit_systems_with_extensions_cache() -> None: baseline = TaxBenefitSystem() - xy_tax_benefit_system = _get_tax_benefit_system(baseline, [], 'xy') - yx_tax_benefit_system = _get_tax_benefit_system(baseline, [], 'yx') + xy_tax_benefit_system = _get_tax_benefit_system(baseline, [], "xy") + yx_tax_benefit_system = _get_tax_benefit_system(baseline, [], "yx") assert xy_tax_benefit_system != yx_tax_benefit_system -def test_extensions_formats(): +def test_extensions_formats() -> None: baseline = TaxBenefitSystem() - lonely_extension_tbs = _get_tax_benefit_system(baseline, [], 'lonely_extension') - list_lonely_extension_tbs = _get_tax_benefit_system(baseline, [], ['lonely_extension']) + lonely_extension_tbs = _get_tax_benefit_system(baseline, [], "lonely_extension") + list_lonely_extension_tbs = _get_tax_benefit_system( + baseline, + [], + ["lonely_extension"], + ) assert lonely_extension_tbs == list_lonely_extension_tbs -def test_extensions_order(): +def test_extensions_order() -> None: baseline = TaxBenefitSystem() - xy_tax_benefit_system = _get_tax_benefit_system(baseline, [], ['x', 'y']) - yx_tax_benefit_system = _get_tax_benefit_system(baseline, [], ['y', 'x']) - assert xy_tax_benefit_system == yx_tax_benefit_system # extensions order is ignored in cache + xy_tax_benefit_system = _get_tax_benefit_system(baseline, [], ["x", "y"]) + yx_tax_benefit_system = _get_tax_benefit_system(baseline, [], ["y", "x"]) + assert ( + xy_tax_benefit_system == yx_tax_benefit_system + ) # extensions order is ignored in cache -def test_performance_graph_option_output(): - test = {'input': {'salary': {'2017-01': 2000}}, 'output': {'salary': {'2017-01': 2000}}} +@pytest.mark.skip(reason="Deprecated node constructor") +def test_performance_graph_option_output() -> None: + test = { + "input": {"salary": {"2017-01": 2000}}, + "output": {"salary": {"2017-01": 2000}}, + } test_item = TestItem(test) - test_item.options = {'performance_graph': True} + test_item.options = {"performance_graph": True} paths = ["./performance_graph.html"] @@ -158,10 +169,14 @@ def test_performance_graph_option_output(): clean_performance_files(paths) -def test_performance_tables_option_output(): - test = {'input': {'salary': {'2017-01': 2000}}, 'output': {'salary': {'2017-01': 2000}}} +@pytest.mark.skip(reason="Deprecated node constructor") +def test_performance_tables_option_output() -> None: + test = { + "input": {"salary": {"2017-01": 2000}}, + "output": {"salary": {"2017-01": 2000}}, + } test_item = TestItem(test) - test_item.options = {'performance_tables': True} + test_item.options = {"performance_tables": True} paths = ["performance_table.csv", "aggregated_performance_table.csv"] @@ -176,7 +191,7 @@ def test_performance_tables_option_output(): clean_performance_files(paths) -def clean_performance_files(paths: List[str]): +def clean_performance_files(paths: list[str]) -> None: for path in paths: if os.path.isfile(path): os.remove(path) diff --git a/tests/core/variables/test_annualize.py b/tests/core/variables/test_annualize.py index 62b0a79b14..58ea1372dd 100644 --- a/tests/core/variables/test_annualize.py +++ b/tests/core/variables/test_annualize.py @@ -1,25 +1,25 @@ -import numpy as np +import numpy from pytest import fixture -from openfisca_core import periods -from openfisca_core.model_api import * # noqa analysis:ignore from openfisca_country_template.entities import Person -from openfisca_core.variables import get_annualized_variable + +from openfisca_core import periods +from openfisca_core.periods import DateUnit +from openfisca_core.variables import Variable, get_annualized_variable @fixture def monthly_variable(): - calculation_count = 0 class monthly_variable(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH - def formula(person, period, parameters): + def formula(self, period, parameters): variable.calculation_count += 1 - return np.asarray([100]) + return numpy.asarray([100]) variable = monthly_variable() variable.calculation_count = calculation_count @@ -30,55 +30,57 @@ def formula(person, period, parameters): class PopulationMock: # Simulate a population for whom a variable has already been put in cache for January. - def __init__(self, variable): + def __init__(self, variable) -> None: self.variable = variable - def __call__(self, variable_name, period): + def __call__(self, variable_name: str, period): if period.start.month == 1: - return np.asarray([100]) - else: - return self.variable.get_formula(period)(self, period, None) + return numpy.asarray([100]) + return self.variable.get_formula(period)(self, period, None) -def test_without_annualize(monthly_variable): +def test_without_annualize(monthly_variable) -> None: period = periods.period(2019) person = PopulationMock(monthly_variable) yearly_sum = sum( - person('monthly_variable', month) - for month in period.get_subperiods(MONTH) - ) + person("monthly_variable", month) + for month in period.get_subperiods(DateUnit.MONTH) + ) assert monthly_variable.calculation_count == 11 assert yearly_sum == 1200 -def test_with_annualize(monthly_variable): +def test_with_annualize(monthly_variable) -> None: period = periods.period(2019) annualized_variable = get_annualized_variable(monthly_variable) person = PopulationMock(annualized_variable) yearly_sum = sum( - person('monthly_variable', month) - for month in period.get_subperiods(MONTH) - ) + person("monthly_variable", month) + for month in period.get_subperiods(DateUnit.MONTH) + ) assert monthly_variable.calculation_count == 0 assert yearly_sum == 100 * 12 -def test_with_partial_annualize(monthly_variable): - period = periods.period('year:2018:2') - annualized_variable = get_annualized_variable(monthly_variable, periods.period(2018)) +def test_with_partial_annualize(monthly_variable) -> None: + period = periods.period("year:2018:2") + annualized_variable = get_annualized_variable( + monthly_variable, + periods.period(2018), + ) person = PopulationMock(annualized_variable) yearly_sum = sum( - person('monthly_variable', month) - for month in period.get_subperiods(MONTH) - ) + person("monthly_variable", month) + for month in period.get_subperiods(DateUnit.MONTH) + ) assert monthly_variable.calculation_count == 11 assert yearly_sum == 100 * 12 * 2 diff --git a/tests/core/variables/test_definition_period.py b/tests/core/variables/test_definition_period.py new file mode 100644 index 0000000000..8ef9bfaa87 --- /dev/null +++ b/tests/core/variables/test_definition_period.py @@ -0,0 +1,43 @@ +import pytest + +from openfisca_core import periods +from openfisca_core.variables import Variable + + +@pytest.fixture +def variable(persons): + class TestVariable(Variable): + value_type = float + entity = persons + + return TestVariable + + +def test_weekday_variable(variable) -> None: + variable.definition_period = periods.WEEKDAY + assert variable() + + +def test_week_variable(variable) -> None: + variable.definition_period = periods.WEEK + assert variable() + + +def test_day_variable(variable) -> None: + variable.definition_period = periods.DAY + assert variable() + + +def test_month_variable(variable) -> None: + variable.definition_period = periods.MONTH + assert variable() + + +def test_year_variable(variable) -> None: + variable.definition_period = periods.YEAR + assert variable() + + +def test_eternity_variable(variable) -> None: + variable.definition_period = periods.ETERNITY + assert variable() diff --git a/tests/core/variables/test_variables.py b/tests/core/variables/test_variables.py index f01ce7c480..475071218b 100644 --- a/tests/core/variables/test_variables.py +++ b/tests/core/variables/test_variables.py @@ -1,17 +1,15 @@ -# -*- coding: utf-8 -*- - import datetime -from openfisca_core.model_api import Variable -from openfisca_core.periods import MONTH, ETERNITY -from openfisca_core.simulations import SimulationBuilder -from openfisca_core.tools import assert_near +from pytest import fixture, mark, raises import openfisca_country_template as country_template import openfisca_country_template.situation_examples from openfisca_country_template.entities import Person -from pytest import fixture, raises, mark +from openfisca_core.periods import DateUnit +from openfisca_core.simulations import SimulationBuilder +from openfisca_core.tools import assert_near +from openfisca_core.variables import Variable # Check which date is applied whether it comes from Variable attribute (end) # or formula(s) dates. @@ -22,27 +20,39 @@ # HELPERS + @fixture def couple(): - return SimulationBuilder().build_from_entities(tax_benefit_system, openfisca_country_template.situation_examples.couple) + return SimulationBuilder().build_from_entities( + tax_benefit_system, + openfisca_country_template.situation_examples.couple, + ) @fixture def simulation(): - return SimulationBuilder().build_from_entities(tax_benefit_system, openfisca_country_template.situation_examples.single) + return SimulationBuilder().build_from_entities( + tax_benefit_system, + openfisca_country_template.situation_examples.single, + ) def vectorize(individu, number): return individu.filled_array(number) -def check_error_at_add_variable(tax_benefit_system, variable, error_message_prefix): +def check_error_at_add_variable( + tax_benefit_system, variable, error_message_prefix +) -> None: try: tax_benefit_system.add_variable(variable) except ValueError as e: message = get_message(e) if not message or not message.startswith(error_message_prefix): - raise AssertionError('Incorrect error message. Was expecting something starting by "{}". Got: "{}"'.format(error_message_prefix, message)) + msg = f'Incorrect error message. Was expecting something starting by "{error_message_prefix}". Got: "{message}"' + raise AssertionError( + msg, + ) def get_message(error): @@ -58,104 +68,111 @@ def get_message(error): class variable__no_date(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable without date." -def test_before_add__variable__no_date(): - assert tax_benefit_system.variables.get('variable__no_date') is None +def test_before_add__variable__no_date() -> None: + assert tax_benefit_system.variables.get("variable__no_date") is None -def test_variable__no_date(): +def test_variable__no_date() -> None: tax_benefit_system.add_variable(variable__no_date) - variable = tax_benefit_system.variables['variable__no_date'] + variable = tax_benefit_system.variables["variable__no_date"] assert variable.end is None assert len(variable.formulas) == 0 + # end, no formula class variable__strange_end_attribute(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable with dubious end attribute, no formula." - end = '1989-00-00' + end = "1989-00-00" -def test_variable__strange_end_attribute(): +def test_variable__strange_end_attribute() -> None: try: tax_benefit_system.add_variable(variable__strange_end_attribute) except ValueError as e: message = get_message(e) - assert message.startswith("Incorrect 'end' attribute format in 'variable__strange_end_attribute'.") + assert message.startswith( + "Incorrect 'end' attribute format in 'variable__strange_end_attribute'.", + ) # Check that Error at variable adding prevents it from registration in the taxbenefitsystem. - assert not tax_benefit_system.variables.get('variable__strange_end_attribute') + assert not tax_benefit_system.variables.get("variable__strange_end_attribute") # end, no formula + class variable__end_attribute(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable with end attribute, no formula." - end = '1989-12-31' + end = "1989-12-31" tax_benefit_system.add_variable(variable__end_attribute) -def test_variable__end_attribute(): - variable = tax_benefit_system.variables['variable__end_attribute'] +def test_variable__end_attribute() -> None: + variable = tax_benefit_system.variables["variable__end_attribute"] assert variable.end == datetime.date(1989, 12, 31) -def test_variable__end_attribute_set_input(simulation): - month_before_end = '1989-01' - month_after_end = '1990-01' - simulation.set_input('variable__end_attribute', month_before_end, 10) - simulation.set_input('variable__end_attribute', month_after_end, 10) - assert simulation.calculate('variable__end_attribute', month_before_end) == 10 - assert simulation.calculate('variable__end_attribute', month_after_end) == 0 +def test_variable__end_attribute_set_input(simulation) -> None: + month_before_end = "1989-01" + month_after_end = "1990-01" + simulation.set_input("variable__end_attribute", month_before_end, 10) + simulation.set_input("variable__end_attribute", month_after_end, 10) + assert simulation.calculate("variable__end_attribute", month_before_end) == 10 + assert simulation.calculate("variable__end_attribute", month_after_end) == 0 # end, one formula without date + class end_attribute__one_simple_formula(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable with end attribute, one formula without date." - end = '1989-12-31' + end = "1989-12-31" - def formula(individu, period): - return vectorize(individu, 100) + def formula(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(end_attribute__one_simple_formula) -def test_formulas_attributes_single_formula(): - formulas = tax_benefit_system.variables['end_attribute__one_simple_formula'].formulas - assert formulas['0001-01-01'] is not None +def test_formulas_attributes_single_formula() -> None: + formulas = tax_benefit_system.variables[ + "end_attribute__one_simple_formula" + ].formulas + assert formulas["0001-01-01"] is not None -def test_call__end_attribute__one_simple_formula(simulation): - month = '1979-12' - assert simulation.calculate('end_attribute__one_simple_formula', month) == 100 +def test_call__end_attribute__one_simple_formula(simulation) -> None: + month = "1979-12" + assert simulation.calculate("end_attribute__one_simple_formula", month) == 100 - month = '1989-12' - assert simulation.calculate('end_attribute__one_simple_formula', month) == 100 + month = "1989-12" + assert simulation.calculate("end_attribute__one_simple_formula", month) == 100 - month = '1990-01' - assert simulation.calculate('end_attribute__one_simple_formula', month) == 0 + month = "1990-01" + assert simulation.calculate("end_attribute__one_simple_formula", month) == 0 -def test_dates__end_attribute__one_simple_formula(): - variable = tax_benefit_system.variables['end_attribute__one_simple_formula'] +def test_dates__end_attribute__one_simple_formula() -> None: + variable = tax_benefit_system.variables["end_attribute__one_simple_formula"] assert variable.end == datetime.date(1989, 12, 31) assert len(variable.formulas) == 1 @@ -167,86 +184,93 @@ def test_dates__end_attribute__one_simple_formula(): # formula, strange name + class no_end_attribute__one_formula__strange_name(Variable): value_type = int entity = Person - definition_period = MONTH - label = "Variable without end attribute, one stangely named formula." + definition_period = DateUnit.MONTH + label = "Variable without end attribute, one strangely named formula." - def formula_2015_toto(individu, period): - return vectorize(individu, 100) + def formula_2015_toto(self, period): + return vectorize(self, 100) -def test_add__no_end_attribute__one_formula__strange_name(): - check_error_at_add_variable(tax_benefit_system, no_end_attribute__one_formula__strange_name, - 'Unrecognized formula name in variable "no_end_attribute__one_formula__strange_name". Expecting "formula_YYYY" or "formula_YYYY_MM" or "formula_YYYY_MM_DD where YYYY, MM and DD are year, month and day. Found: ') +def test_add__no_end_attribute__one_formula__strange_name() -> None: + check_error_at_add_variable( + tax_benefit_system, + no_end_attribute__one_formula__strange_name, + 'Unrecognized formula name in variable "no_end_attribute__one_formula__strange_name". Expecting "formula_YYYY" or "formula_YYYY_MM" or "formula_YYYY_MM_DD where YYYY, MM and DD are year, month and day. Found: ', + ) # formula, start + class no_end_attribute__one_formula__start(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable without end attribute, one dated formula." - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(no_end_attribute__one_formula__start) -def test_call__no_end_attribute__one_formula__start(simulation): - month = '1999-12' - assert simulation.calculate('no_end_attribute__one_formula__start', month) == 0 +def test_call__no_end_attribute__one_formula__start(simulation) -> None: + month = "1999-12" + assert simulation.calculate("no_end_attribute__one_formula__start", month) == 0 - month = '2000-05' - assert simulation.calculate('no_end_attribute__one_formula__start', month) == 100 + month = "2000-05" + assert simulation.calculate("no_end_attribute__one_formula__start", month) == 100 - month = '2020-01' - assert simulation.calculate('no_end_attribute__one_formula__start', month) == 100 + month = "2020-01" + assert simulation.calculate("no_end_attribute__one_formula__start", month) == 100 -def test_dates__no_end_attribute__one_formula__start(): - variable = tax_benefit_system.variables['no_end_attribute__one_formula__start'] +def test_dates__no_end_attribute__one_formula__start() -> None: + variable = tax_benefit_system.variables["no_end_attribute__one_formula__start"] assert variable.end is None assert len(variable.formulas) == 1 - assert variable.formulas.keys()[0] == '2000-01-01' + assert variable.formulas.keys()[0] == "2000-01-01" class no_end_attribute__one_formula__eternity(Variable): value_type = int entity = Person - definition_period = ETERNITY # For this entity, this variable shouldn't evolve through time + definition_period = ( + DateUnit.ETERNITY + ) # For this entity, this variable shouldn't evolve through time label = "Variable without end attribute, one dated formula." - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(no_end_attribute__one_formula__eternity) @mark.xfail() -def test_call__no_end_attribute__one_formula__eternity(simulation): - month = '1999-12' - assert simulation.calculate('no_end_attribute__one_formula__eternity', month) == 0 +def test_call__no_end_attribute__one_formula__eternity(simulation) -> None: + month = "1999-12" + assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 0 # This fails because a definition period of "ETERNITY" caches for all periods - month = '2000-01' - assert simulation.calculate('no_end_attribute__one_formula__eternity', month) == 100 + month = "2000-01" + assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 100 -def test_call__no_end_attribute__one_formula__eternity_before(simulation): - month = '1999-12' - assert simulation.calculate('no_end_attribute__one_formula__eternity', month) == 0 +def test_call__no_end_attribute__one_formula__eternity_before(simulation) -> None: + month = "1999-12" + assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 0 -def test_call__no_end_attribute__one_formula__eternity_after(simulation): - month = '2000-01' - assert simulation.calculate('no_end_attribute__one_formula__eternity', month) == 100 +def test_call__no_end_attribute__one_formula__eternity_after(simulation) -> None: + month = "2000-01" + assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 100 # formula, different start formats @@ -255,97 +279,123 @@ def test_call__no_end_attribute__one_formula__eternity_after(simulation): class no_end_attribute__formulas__start_formats(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable without end attribute, multiple dated formulas." - def formula_2000(individu, period): - return vectorize(individu, 100) + def formula_2000(self, period): + return vectorize(self, 100) - def formula_2010_01(individu, period): - return vectorize(individu, 200) + def formula_2010_01(self, period): + return vectorize(self, 200) tax_benefit_system.add_variable(no_end_attribute__formulas__start_formats) -def test_formulas_attributes_dated_formulas(): - formulas = tax_benefit_system.variables['no_end_attribute__formulas__start_formats'].formulas - assert(len(formulas) == 2) - assert formulas['2000-01-01'] is not None - assert formulas['2010-01-01'] is not None +def test_formulas_attributes_dated_formulas() -> None: + formulas = tax_benefit_system.variables[ + "no_end_attribute__formulas__start_formats" + ].formulas + assert len(formulas) == 2 + assert formulas["2000-01-01"] is not None + assert formulas["2010-01-01"] is not None -def test_get_formulas(): - variable = tax_benefit_system.variables['no_end_attribute__formulas__start_formats'] - formula_2000 = variable.formulas['2000-01-01'] - formula_2010 = variable.formulas['2010-01-01'] +def test_get_formulas() -> None: + variable = tax_benefit_system.variables["no_end_attribute__formulas__start_formats"] + formula_2000 = variable.formulas["2000-01-01"] + formula_2010 = variable.formulas["2010-01-01"] - assert variable.get_formula('1999-01') is None - assert variable.get_formula('2000-01') == formula_2000 - assert variable.get_formula('2009-12') == formula_2000 - assert variable.get_formula('2009-12-31') == formula_2000 - assert variable.get_formula('2010-01') == formula_2010 - assert variable.get_formula('2010-01-01') == formula_2010 + assert variable.get_formula("1999-01") is None + assert variable.get_formula("2000-01") == formula_2000 + assert variable.get_formula("2009-12") == formula_2000 + assert variable.get_formula("2009-12-31") == formula_2000 + assert variable.get_formula("2010-01") == formula_2010 + assert variable.get_formula("2010-01-01") == formula_2010 -def test_call__no_end_attribute__formulas__start_formats(simulation): - month = '1999-12' - assert simulation.calculate('no_end_attribute__formulas__start_formats', month) == 0 +def test_call__no_end_attribute__formulas__start_formats(simulation) -> None: + month = "1999-12" + assert simulation.calculate("no_end_attribute__formulas__start_formats", month) == 0 - month = '2000-01' - assert simulation.calculate('no_end_attribute__formulas__start_formats', month) == 100 + month = "2000-01" + assert ( + simulation.calculate("no_end_attribute__formulas__start_formats", month) == 100 + ) - month = '2009-12' - assert simulation.calculate('no_end_attribute__formulas__start_formats', month) == 100 + month = "2009-12" + assert ( + simulation.calculate("no_end_attribute__formulas__start_formats", month) == 100 + ) - month = '2010-01' - assert simulation.calculate('no_end_attribute__formulas__start_formats', month) == 200 + month = "2010-01" + assert ( + simulation.calculate("no_end_attribute__formulas__start_formats", month) == 200 + ) # Multiple formulas, different names with date overlap + class no_attribute__formulas__different_names__dates_overlap(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable, no end attribute, multiple dated formulas with different names but same dates." - def formula_2000(individu, period): - return vectorize(individu, 100) + def formula_2000(self, period): + return vectorize(self, 100) - def formula_2000_01_01(individu, period): - return vectorize(individu, 200) + def formula_2000_01_01(self, period): + return vectorize(self, 200) -def test_add__no_attribute__formulas__different_names__dates_overlap(): +def test_add__no_attribute__formulas__different_names__dates_overlap() -> None: # Variable isn't registered in the taxbenefitsystem - check_error_at_add_variable(tax_benefit_system, no_attribute__formulas__different_names__dates_overlap, "Dated formulas overlap") + check_error_at_add_variable( + tax_benefit_system, + no_attribute__formulas__different_names__dates_overlap, + "Dated formulas overlap", + ) # formula(start), different names, no date overlap + class no_attribute__formulas__different_names__no_overlap(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable, no end attribute, multiple dated formulas with different names and no date overlap." - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) - def formula_2010_01_01(individu, period): - return vectorize(individu, 200) + def formula_2010_01_01(self, period): + return vectorize(self, 200) tax_benefit_system.add_variable(no_attribute__formulas__different_names__no_overlap) -def test_call__no_attribute__formulas__different_names__no_overlap(simulation): - month = '2009-12' - assert simulation.calculate('no_attribute__formulas__different_names__no_overlap', month) == 100 +def test_call__no_attribute__formulas__different_names__no_overlap(simulation) -> None: + month = "2009-12" + assert ( + simulation.calculate( + "no_attribute__formulas__different_names__no_overlap", + month, + ) + == 100 + ) - month = '2015-05' - assert simulation.calculate('no_attribute__formulas__different_names__no_overlap', month) == 200 + month = "2015-05" + assert ( + simulation.calculate( + "no_attribute__formulas__different_names__no_overlap", + month, + ) + == 200 + ) # END ATTRIBUTE - DATED FORMULA(S) @@ -353,123 +403,145 @@ def test_call__no_attribute__formulas__different_names__no_overlap(simulation): # formula, start. + class end_attribute__one_formula__start(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable with end attribute, one dated formula." - end = '2001-12-31' + end = "2001-12-31" - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(end_attribute__one_formula__start) -def test_call__end_attribute__one_formula__start(simulation): - month = '1980-01' - assert simulation.calculate('end_attribute__one_formula__start', month) == 0 +def test_call__end_attribute__one_formula__start(simulation) -> None: + month = "1980-01" + assert simulation.calculate("end_attribute__one_formula__start", month) == 0 - month = '2000-01' - assert simulation.calculate('end_attribute__one_formula__start', month) == 100 + month = "2000-01" + assert simulation.calculate("end_attribute__one_formula__start", month) == 100 - month = '2002-01' - assert simulation.calculate('end_attribute__one_formula__start', month) == 0 + month = "2002-01" + assert simulation.calculate("end_attribute__one_formula__start", month) == 0 # end < formula, start. + class stop_attribute_before__one_formula__start(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable with stop attribute only coming before formula start." - end = '1990-01-01' + end = "1990-01-01" - def formula_2000_01_01(individu, period): - return vectorize(individu, 0) + def formula_2000_01_01(self, period): + return vectorize(self, 0) -def test_add__stop_attribute_before__one_formula__start(): - check_error_at_add_variable(tax_benefit_system, stop_attribute_before__one_formula__start, 'You declared that "stop_attribute_before__one_formula__start" ends on "1990-01-01", but you wrote a formula to calculate it from "2000-01-01"') +def test_add__stop_attribute_before__one_formula__start() -> None: + check_error_at_add_variable( + tax_benefit_system, + stop_attribute_before__one_formula__start, + 'You declared that "stop_attribute_before__one_formula__start" ends on "1990-01-01", but you wrote a formula to calculate it from "2000-01-01"', + ) # end, formula with dates intervals overlap. + class end_attribute_restrictive__one_formula(Variable): value_type = int entity = Person - definition_period = MONTH - label = "Variable with end attribute, one dated formula and dates intervals overlap." - end = '2001-01-01' + definition_period = DateUnit.MONTH + label = ( + "Variable with end attribute, one dated formula and dates intervals overlap." + ) + end = "2001-01-01" - def formula_2001_01_01(individu, period): - return vectorize(individu, 100) + def formula_2001_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(end_attribute_restrictive__one_formula) -def test_call__end_attribute_restrictive__one_formula(simulation): - month = '2000-12' - assert simulation.calculate('end_attribute_restrictive__one_formula', month) == 0 +def test_call__end_attribute_restrictive__one_formula(simulation) -> None: + month = "2000-12" + assert simulation.calculate("end_attribute_restrictive__one_formula", month) == 0 - month = '2001-01' - assert simulation.calculate('end_attribute_restrictive__one_formula', month) == 100 + month = "2001-01" + assert simulation.calculate("end_attribute_restrictive__one_formula", month) == 100 - month = '2000-05' - assert simulation.calculate('end_attribute_restrictive__one_formula', month) == 0 + month = "2000-05" + assert simulation.calculate("end_attribute_restrictive__one_formula", month) == 0 # formulas of different names (without dates overlap on formulas) + class end_attribute__formulas__different_names(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable with end attribute, multiple dated formulas with different names." - end = '2010-12-31' + end = "2010-12-31" - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) - def formula_2005_01_01(individu, period): - return vectorize(individu, 200) + def formula_2005_01_01(self, period): + return vectorize(self, 200) - def formula_2010_01_01(individu, period): - return vectorize(individu, 300) + def formula_2010_01_01(self, period): + return vectorize(self, 300) tax_benefit_system.add_variable(end_attribute__formulas__different_names) -def test_call__end_attribute__formulas__different_names(simulation): - month = '2000-01' - assert simulation.calculate('end_attribute__formulas__different_names', month) == 100 +def test_call__end_attribute__formulas__different_names(simulation) -> None: + month = "2000-01" + assert ( + simulation.calculate("end_attribute__formulas__different_names", month) == 100 + ) - month = '2005-01' - assert simulation.calculate('end_attribute__formulas__different_names', month) == 200 + month = "2005-01" + assert ( + simulation.calculate("end_attribute__formulas__different_names", month) == 200 + ) - month = '2010-12' - assert simulation.calculate('end_attribute__formulas__different_names', month) == 300 + month = "2010-12" + assert ( + simulation.calculate("end_attribute__formulas__different_names", month) == 300 + ) -def test_get_formula(simulation): +def test_get_formula(simulation) -> None: person = simulation.person - disposable_income_formula = tax_benefit_system.get_variable('disposable_income').get_formula() - disposable_income = person('disposable_income', '2017-01') - disposable_income_2 = disposable_income_formula(person, '2017-01', None) # No need for parameters here + disposable_income_formula = tax_benefit_system.get_variable( + "disposable_income", + ).get_formula() + disposable_income = person("disposable_income", "2017-01") + disposable_income_2 = disposable_income_formula( + person, + "2017-01", + None, + ) # No need for parameters here assert_near(disposable_income, disposable_income_2) -def test_unexpected_attr(): +def test_unexpected_attr() -> None: class variable_with_strange_attr(Variable): value_type = int entity = Person - definition_period = MONTH - unexpected = '???' + definition_period = DateUnit.MONTH + unexpected = "???" with raises(ValueError): tax_benefit_system.add_variable(variable_with_strange_attr) diff --git a/tests/fixtures/appclient.py b/tests/fixtures/appclient.py index a140e0f938..692747d393 100644 --- a/tests/fixtures/appclient.py +++ b/tests/fixtures/appclient.py @@ -5,7 +5,7 @@ @pytest.fixture(scope="module") def test_client(tax_benefit_system): - """ This module-scoped fixture creates an API client for the TBS defined in the `tax_benefit_system` + """This module-scoped fixture creates an API client for the TBS defined in the `tax_benefit_system` fixture. This `tax_benefit_system` is mutable, so you can add/update variables. Example: @@ -15,20 +15,22 @@ def test_client(tax_benefit_system): from openfisca_country_template import entities from openfisca_core import periods from openfisca_core.variables import Variable + ... + class new_variable(Variable): value_type = float entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH label = "New variable" reference = "https://law.gov.example/new_variable" # Always use the most official source + tax_benefit_system.add_variable(new_variable) flask_app = app.create_app(tax_benefit_system) """ - # Create the test API client flask_app = app.create_app(tax_benefit_system) return flask_app.test_client() diff --git a/tests/fixtures/entities.py b/tests/fixtures/entities.py index 99b6196599..6670a68da1 100644 --- a/tests/fixtures/entities.py +++ b/tests/fixtures/entities.py @@ -6,22 +6,30 @@ class TestEntity(Entity): - def get_variable(self, variable_name): + def get_variable( + self, + variable_name: str, + check_existence: bool = False, + ) -> TestVariable: result = TestVariable(self) result.name = variable_name return result - def check_variable_defined_for_entity(self, variable_name): + def check_variable_defined_for_entity(self, variable_name: str) -> bool: return True class TestGroupEntity(GroupEntity): - def get_variable(self, variable_name): + def get_variable( + self, + variable_name: str, + check_existence: bool = False, + ) -> TestVariable: result = TestVariable(self) result.name = variable_name return result - def check_variable_defined_for_entity(self, variable_name): + def check_variable_defined_for_entity(self, variable_name: str) -> bool: return True @@ -32,13 +40,9 @@ def persons(): @pytest.fixture def households(): - roles = [{ - 'key': 'parent', - 'plural': 'parents', - 'max': 2 - }, { - 'key': 'child', - 'plural': 'children' - }] + roles = [ + {"key": "parent", "plural": "parents", "max": 2}, + {"key": "child", "plural": "children"}, + ] return TestGroupEntity("household", "households", "", "", roles) diff --git a/tests/fixtures/extensions.py b/tests/fixtures/extensions.py new file mode 100644 index 0000000000..bc4e85fe72 --- /dev/null +++ b/tests/fixtures/extensions.py @@ -0,0 +1,18 @@ +from importlib import metadata + +import pytest + + +@pytest.fixture +def test_country_package_name() -> str: + return "openfisca_country_template" + + +@pytest.fixture +def test_extension_package_name() -> str: + return "openfisca_extension_template" + + +@pytest.fixture +def distribution(test_country_package_name): + return metadata.distribution(test_country_package_name) diff --git a/tests/fixtures/simulations.py b/tests/fixtures/simulations.py index 9d343d5ac0..53120b60d9 100644 --- a/tests/fixtures/simulations.py +++ b/tests/fixtures/simulations.py @@ -14,7 +14,7 @@ def simulation(tax_benefit_system, request): tax_benefit_system, variables, period, - ) + ) @pytest.fixture @@ -24,8 +24,4 @@ def make_simulation(): def _simulation(simulation_builder, tax_benefit_system, variables, period): simulation_builder.set_default_period(period) - simulation = \ - simulation_builder \ - .build_from_variables(tax_benefit_system, variables) - - return simulation + return simulation_builder.build_from_variables(tax_benefit_system, variables) diff --git a/tests/fixtures/taxbenefitsystems.py b/tests/fixtures/taxbenefitsystems.py index c2c47071ca..d29dfd73fd 100644 --- a/tests/fixtures/taxbenefitsystems.py +++ b/tests/fixtures/taxbenefitsystems.py @@ -3,7 +3,7 @@ from openfisca_country_template import CountryTaxBenefitSystem -@pytest.fixture(scope = "module") +@pytest.fixture(scope="module") def tax_benefit_system(): return CountryTaxBenefitSystem() diff --git a/tests/fixtures/variables.py b/tests/fixtures/variables.py index cd0d9b70ce..2deccf5891 100644 --- a/tests/fixtures/variables.py +++ b/tests/fixtures/variables.py @@ -1,11 +1,11 @@ -from openfisca_core import periods +from openfisca_core.periods import DateUnit from openfisca_core.variables import Variable class TestVariable(Variable): - definition_period = periods.ETERNITY + definition_period = DateUnit.ETERNITY value_type = float - def __init__(self, entity): + def __init__(self, entity) -> None: self.__class__.entity = entity super().__init__() diff --git a/tests/fixtures/yaml_tests/failing_test_absolute_error_margin.yaml b/tests/fixtures/yaml_tests/failing_test_absolute_error_margin.yaml index a51ae6894e..4928b06711 100644 --- a/tests/fixtures/yaml_tests/failing_test_absolute_error_margin.yaml +++ b/tests/fixtures/yaml_tests/failing_test_absolute_error_margin.yaml @@ -5,3 +5,14 @@ salary: 2000 output: income_tax: 351 # 300 + + +- name: "Failing test: result out of variable specific absolute error margin" + period: 2015-01 + absolute_error_margin: + default: 100 + income_tax: 50 + input: + salary: 2000 + output: + income_tax: 351 # 300 diff --git a/tests/fixtures/yaml_tests/failing_test_relative_error_margin.yaml b/tests/fixtures/yaml_tests/failing_test_relative_error_margin.yaml index 9258946c3d..c0788cfa96 100644 --- a/tests/fixtures/yaml_tests/failing_test_relative_error_margin.yaml +++ b/tests/fixtures/yaml_tests/failing_test_relative_error_margin.yaml @@ -5,3 +5,14 @@ salary: 2000 output: income_tax: 316 # 300 + + +- name: "Failing test: result out of variable specific relative error margin" + period: 2015-01 + relative_error_margin: + default: 1 + income_tax: 0.05 + input: + salary: 2000 + output: + income_tax: 316 # 300 diff --git a/tests/fixtures/yaml_tests/test_absolute_error_margin.yaml b/tests/fixtures/yaml_tests/test_absolute_error_margin.yaml index be7de2d5cb..65dbb308e3 100644 --- a/tests/fixtures/yaml_tests/test_absolute_error_margin.yaml +++ b/tests/fixtures/yaml_tests/test_absolute_error_margin.yaml @@ -5,3 +5,14 @@ salary: 2000 output: income_tax: 350 # 300 + + +- name: "Result within absolute error margin" + period: 2015-01 + absolute_error_margin: + default: 100 + income_tax: 50 + input: + salary: 2000 + output: + income_tax: 350 # 300 diff --git a/tests/fixtures/yaml_tests/test_name_filter.yaml b/tests/fixtures/yaml_tests/test_name_filter.yaml index 9bca0e050d..e1aa1894a7 100644 --- a/tests/fixtures/yaml_tests/test_name_filter.yaml +++ b/tests/fixtures/yaml_tests/test_name_filter.yaml @@ -1,11 +1,11 @@ -- name: "Test that sould be run because the magic word success is in its title" +- name: "Test that should be run because the magic word success is in its title" period: 2015-01 input: salary: 2000 output: income_tax: 0.15 * 2000 -- name: "Test that sould be run because the magic word is in its keywords" +- name: "Test that should be run because the magic word is in its keywords" keywords: - some keyword - success diff --git a/tests/fixtures/yaml_tests/test_relative_error_margin.yaml b/tests/fixtures/yaml_tests/test_relative_error_margin.yaml index 7845d6f361..d39a9e4143 100644 --- a/tests/fixtures/yaml_tests/test_relative_error_margin.yaml +++ b/tests/fixtures/yaml_tests/test_relative_error_margin.yaml @@ -5,3 +5,14 @@ salary: 2000 output: income_tax: 290 # 300 + + +- name: "Result within variable relative error margin" + period: 2015-01 + relative_error_margin: + default: .001 + income_tax: 0.05 + input: + salary: 2000 + output: + income_tax: 290 # 300 diff --git a/tests/web_api/__init__.py b/tests/web_api/__init__.py index 8098c2a5a2..e69de29bb2 100644 --- a/tests/web_api/__init__.py +++ b/tests/web_api/__init__.py @@ -1,4 +0,0 @@ -import pkg_resources - -TEST_COUNTRY_PACKAGE_NAME = 'openfisca_country_template' -distribution = pkg_resources.get_distribution(TEST_COUNTRY_PACKAGE_NAME) diff --git a/tests/web_api/basic_case/__init__.py b/tests/web_api/basic_case/__init__.py deleted file mode 100644 index 4114c06467..0000000000 --- a/tests/web_api/basic_case/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -import pkg_resources -from openfisca_web_api.app import create_app -from openfisca_core.scripts import build_tax_benefit_system - -TEST_COUNTRY_PACKAGE_NAME = 'openfisca_country_template' -distribution = pkg_resources.get_distribution(TEST_COUNTRY_PACKAGE_NAME) -tax_benefit_system = build_tax_benefit_system(TEST_COUNTRY_PACKAGE_NAME, extensions = None, reforms = None) -subject = create_app(tax_benefit_system).test_client() diff --git a/tests/web_api/case_with_extension/test_extensions.py b/tests/web_api/case_with_extension/test_extensions.py index 4da94bf45c..2c688232f8 100644 --- a/tests/web_api/case_with_extension/test_extensions.py +++ b/tests/web_api/case_with_extension/test_extensions.py @@ -1,28 +1,39 @@ -# -*- coding: utf-8 -*- +from http import client + +import pytest -from http.client import OK from openfisca_core.scripts import build_tax_benefit_system from openfisca_web_api.app import create_app -TEST_COUNTRY_PACKAGE_NAME = 'openfisca_country_template' -TEST_EXTENSION_PACKAGE_NAMES = ['openfisca_extension_template'] +@pytest.fixture +def tax_benefit_system(test_country_package_name, test_extension_package_name): + return build_tax_benefit_system( + test_country_package_name, + extensions=[test_extension_package_name], + reforms=None, + ) -tax_benefit_system = build_tax_benefit_system(TEST_COUNTRY_PACKAGE_NAME, extensions = TEST_EXTENSION_PACKAGE_NAMES, reforms = None) -extended_subject = create_app(tax_benefit_system).test_client() +@pytest.fixture +def extended_subject(tax_benefit_system): + return create_app(tax_benefit_system).test_client() -def test_return_code(): - parameters_response = extended_subject.get('/parameters') - assert parameters_response.status_code == OK +def test_return_code(extended_subject) -> None: + parameters_response = extended_subject.get("/parameters") + assert parameters_response.status_code == client.OK -def test_return_code_existing_parameter(): - extension_parameter_response = extended_subject.get('/parameter/local_town.child_allowance.amount') - assert extension_parameter_response.status_code == OK +def test_return_code_existing_parameter(extended_subject) -> None: + extension_parameter_response = extended_subject.get( + "/parameter/local_town.child_allowance.amount", + ) + assert extension_parameter_response.status_code == client.OK -def test_return_code_existing_variable(): - extension_variable_response = extended_subject.get('/variable/local_town_child_allowance') - assert extension_variable_response.status_code == OK +def test_return_code_existing_variable(extended_subject) -> None: + extension_variable_response = extended_subject.get( + "/variable/local_town_child_allowance", + ) + assert extension_variable_response.status_code == client.OK diff --git a/tests/web_api/case_with_reform/test_reforms.py b/tests/web_api/case_with_reform/test_reforms.py index 5037a4b395..f0895cf189 100644 --- a/tests/web_api/case_with_reform/test_reforms.py +++ b/tests/web_api/case_with_reform/test_reforms.py @@ -1,62 +1,65 @@ import http + import pytest from openfisca_core import scripts from openfisca_web_api import app -TEST_COUNTRY_PACKAGE_NAME = "openfisca_country_template" -TEST_REFORMS_PATHS = [ - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.add_dynamic_variable.add_dynamic_variable", - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.add_new_tax.add_new_tax", - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.flat_social_security_contribution.flat_social_security_contribution", - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.modify_social_security_taxation.modify_social_security_taxation", - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.removal_basic_income.removal_basic_income", + +@pytest.fixture +def test_reforms_path(test_country_package_name): + return [ + f"{test_country_package_name}.reforms.add_dynamic_variable.add_dynamic_variable", + f"{test_country_package_name}.reforms.add_new_tax.add_new_tax", + f"{test_country_package_name}.reforms.flat_social_security_contribution.flat_social_security_contribution", + f"{test_country_package_name}.reforms.modify_social_security_taxation.modify_social_security_taxation", + f"{test_country_package_name}.reforms.removal_basic_income.removal_basic_income", ] # Create app as in 'openfisca serve' script @pytest.fixture -def client(): +def client(test_country_package_name, test_reforms_path): tax_benefit_system = scripts.build_tax_benefit_system( - TEST_COUNTRY_PACKAGE_NAME, - extensions = None, - reforms = TEST_REFORMS_PATHS, - ) + test_country_package_name, + extensions=None, + reforms=test_reforms_path, + ) return app.create_app(tax_benefit_system).test_client() -def test_return_code_of_dynamic_variable(client): +def test_return_code_of_dynamic_variable(client) -> None: result = client.get("/variable/goes_to_school") assert result.status_code == http.client.OK -def test_return_code_of_has_car_variable(client): +def test_return_code_of_has_car_variable(client) -> None: result = client.get("/variable/has_car") assert result.status_code == http.client.OK -def test_return_code_of_new_tax_variable(client): +def test_return_code_of_new_tax_variable(client) -> None: result = client.get("/variable/new_tax") assert result.status_code == http.client.OK -def test_return_code_of_social_security_contribution_variable(client): +def test_return_code_of_social_security_contribution_variable(client) -> None: result = client.get("/variable/social_security_contribution") assert result.status_code == http.client.OK -def test_return_code_of_social_security_contribution_parameter(client): +def test_return_code_of_social_security_contribution_parameter(client) -> None: result = client.get("/parameter/taxes.social_security_contribution") assert result.status_code == http.client.OK -def test_return_code_of_basic_income_variable(client): +def test_return_code_of_basic_income_variable(client) -> None: result = client.get("/variable/basic_income") assert result.status_code == http.client.OK diff --git a/tests/web_api/loader/test_parameters.py b/tests/web_api/loader/test_parameters.py index 232bd24c26..f44632ce49 100644 --- a/tests/web_api/loader/test_parameters.py +++ b/tests/web_api/loader/test_parameters.py @@ -1,35 +1,70 @@ -# -*- coding: utf-8 -*- +from openfisca_core.parameters import Scale +from openfisca_web_api.loader.parameters import build_api_parameter, build_api_scale -from openfisca_core.parameters import ParameterScale -from openfisca_web_api.loader.parameters import build_api_scale, build_api_parameter +def test_build_rate_scale() -> None: + """Extracts a 'rate' children from a bracket collection.""" + data = { + "brackets": [ + { + "rate": {"2014-01-01": {"value": 0.5}}, + "threshold": {"2014-01-01": {"value": 1}}, + }, + ], + } + rate = Scale("this rate", data, None) + assert build_api_scale(rate, "rate") == {"2014-01-01": {1: 0.5}} -def test_build_rate_scale(): - '''Extracts a 'rate' children from a bracket collection''' - data = {'brackets': [{'rate': {'2014-01-01': {'value': 0.5}}, 'threshold': {'2014-01-01': {'value': 1}}}]} - rate = ParameterScale('this rate', data, None) - assert build_api_scale(rate, 'rate') == {'2014-01-01': {1: 0.5}} +def test_build_amount_scale() -> None: + """Extracts an 'amount' children from a bracket collection.""" + data = { + "brackets": [ + { + "amount": {"2014-01-01": {"value": 0}}, + "threshold": {"2014-01-01": {"value": 1}}, + }, + ], + } + rate = Scale("that amount", data, None) + assert build_api_scale(rate, "amount") == {"2014-01-01": {1: 0}} -def test_build_amount_scale(): - '''Extracts an 'amount' children from a bracket collection''' - data = {'brackets': [{'amount': {'2014-01-01': {'value': 0}}, 'threshold': {'2014-01-01': {'value': 1}}}]} - rate = ParameterScale('that amount', data, None) - assert build_api_scale(rate, 'amount') == {'2014-01-01': {1: 0}} - - -def test_full_rate_scale(): - '''Serializes a 'rate' scale parameter''' - data = {'brackets': [{'rate': {'2014-01-01': {'value': 0.5}}, 'threshold': {'2014-01-01': {'value': 1}}}]} - scale = ParameterScale('rate', data, None) +def test_full_rate_scale() -> None: + """Serializes a 'rate' scale parameter.""" + data = { + "brackets": [ + { + "rate": {"2014-01-01": {"value": 0.5}}, + "threshold": {"2014-01-01": {"value": 1}}, + }, + ], + } + scale = Scale("rate", data, None) api_scale = build_api_parameter(scale, {}) - assert api_scale == {'description': None, 'id': 'rate', 'metadata': {}, 'brackets': {'2014-01-01': {1: 0.5}}} + assert api_scale == { + "description": None, + "id": "rate", + "metadata": {}, + "brackets": {"2014-01-01": {1: 0.5}}, + } -def test_walk_node_amount_scale(): - '''Serializes an 'amount' scale parameter ''' - data = {'brackets': [{'amount': {'2014-01-01': {'value': 0}}, 'threshold': {'2014-01-01': {'value': 1}}}]} - scale = ParameterScale('amount', data, None) +def test_walk_node_amount_scale() -> None: + """Serializes an 'amount' scale parameter.""" + data = { + "brackets": [ + { + "amount": {"2014-01-01": {"value": 0}}, + "threshold": {"2014-01-01": {"value": 1}}, + }, + ], + } + scale = Scale("amount", data, None) api_scale = build_api_parameter(scale, {}) - assert api_scale == {'description': None, 'id': 'amount', 'metadata': {}, 'brackets': {'2014-01-01': {1: 0}}} + assert api_scale == { + "description": None, + "id": "amount", + "metadata": {}, + "brackets": {"2014-01-01": {1: 0}}, + } diff --git a/tests/web_api/test_calculate.py b/tests/web_api/test_calculate.py index b3415810a7..4c82de5448 100644 --- a/tests/web_api/test_calculate.py +++ b/tests/web_api/test_calculate.py @@ -1,351 +1,464 @@ import copy -import dpath import json -from http import client import os +from http import client + +import dpath.util import pytest from openfisca_country_template.situation_examples import couple -def post_json(client, data = None, file = None): +def post_json(client, data=None, file=None): if file: - file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', file) - with open(file_path, 'r') as file: + file_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "assets", + file, + ) + with open(file_path) as file: data = file.read() - return client.post('/calculate', data = data, content_type = 'application/json') + return client.post("/calculate", data=data, content_type="application/json") -def check_response(client, data, expected_error_code, path_to_check, content_to_check): +def check_response( + client, data, expected_error_code, path_to_check, content_to_check +) -> None: response = post_json(client, data) assert response.status_code == expected_error_code - json_response = json.loads(response.data.decode('utf-8')) + json_response = json.loads(response.data.decode("utf-8")) if path_to_check: content = dpath.util.get(json_response, path_to_check) assert content_to_check in content -@pytest.mark.parametrize("test", [ - ('{"a" : "x", "b"}', client.BAD_REQUEST, 'error', 'Invalid JSON'), - ('["An", "array"]', client.BAD_REQUEST, 'error', 'Invalid type'), - ('{"persons": {}}', client.BAD_REQUEST, 'persons', 'At least one person'), - ('{"persons": {"bob": {}}, "unknown_entity": {}}', client.BAD_REQUEST, 'unknown_entity', 'entities are not found',), - ('{"persons": {"bob": {}}, "households": {"dupont": {"parents": {}}}}', client.BAD_REQUEST, 'households/dupont/parents', 'type',), - ('{"persons": {"bob": {"unknown_variable": {}}}}', client.NOT_FOUND, 'persons/bob/unknown_variable', 'You tried to calculate or to set',), - ('{"persons": {"bob": {"housing_allowance": {}}}}', client.BAD_REQUEST, 'persons/bob/housing_allowance', "You tried to compute the variable 'housing_allowance' for the entity 'persons'",), - ('{"persons": {"bob": {"salary": 4000 }}}', client.BAD_REQUEST, 'persons/bob/salary', 'period',), - ('{"persons": {"bob": {"salary": {"2017-01": "toto"} }}}', client.BAD_REQUEST, 'persons/bob/salary/2017-01', 'expected type number',), - ('{"persons": {"bob": {"salary": {"2017-01": {}} }}}', client.BAD_REQUEST, 'persons/bob/salary/2017-01', 'expected type number',), - ('{"persons": {"bob": {"age": {"2017-01": "toto"} }}}', client.BAD_REQUEST, 'persons/bob/age/2017-01', 'expected type integer',), - ('{"persons": {"bob": {"birth": {"2017-01": "toto"} }}}', client.BAD_REQUEST, 'persons/bob/birth/2017-01', 'Can\'t deal with date',), - ('{"persons": {"bob": {}}, "households": {"household": {"parents": ["unexpected_person_id"]}}}', client.BAD_REQUEST, 'households/household/parents', 'has not been declared in persons',), - ('{"persons": {"bob": {}}, "households": {"household": {"parents": ["bob", "bob"]}}}', client.BAD_REQUEST, 'households/household/parents', 'has been declared more than once',), - ('{"persons": {"bob": {}}, "households": {"household": {"parents": ["bob", {}]}}}', client.BAD_REQUEST, 'households/household/parents/1', 'Invalid type',), - ('{"persons": {"bob": {"salary": {"invalid period": 2000 }}}}', client.BAD_REQUEST, 'persons/bob/salary', 'Expected a period',), - ('{"persons": {"bob": {"salary": {"invalid period": null }}}}', client.BAD_REQUEST, 'persons/bob/salary', 'Expected a period',), - ('{"persons": {"bob": {"basic_income": {"2017": 2000 }}}, "households": {"household": {"parents": ["bob"]}}}', client.BAD_REQUEST, 'persons/bob/basic_income/2017', '"basic_income" can only be set for one month',), - ('{"persons": {"bob": {"salary": {"ETERNITY": 2000 }}}, "households": {"household": {"parents": ["bob"]}}}', client.BAD_REQUEST, 'persons/bob/salary/ETERNITY', 'salary is only defined for months',), - ('{"persons": {"alice": {}, "bob": {}, "charlie": {}}, "households": {"_": {"parents": ["alice", "bob", "charlie"]}}}', client.BAD_REQUEST, 'households/_/parents', 'at most 2 parents in a household',), - ]) -def test_responses(test_client, test): +@pytest.mark.parametrize( + "test", + [ + ('{"a" : "x", "b"}', client.BAD_REQUEST, "error", "Invalid JSON"), + ('["An", "array"]', client.BAD_REQUEST, "error", "Invalid type"), + ('{"persons": {}}', client.BAD_REQUEST, "persons", "At least one person"), + ( + '{"persons": {"bob": {}}, "unknown_entity": {}}', + client.BAD_REQUEST, + "unknown_entity", + "entities are not found", + ), + ( + '{"persons": {"bob": {}}, "households": {"dupont": {"parents": {}}}}', + client.BAD_REQUEST, + "households/dupont/parents", + "type", + ), + ( + '{"persons": {"bob": {"unknown_variable": {}}}}', + client.NOT_FOUND, + "persons/bob/unknown_variable", + "You tried to calculate or to set", + ), + ( + '{"persons": {"bob": {"housing_allowance": {}}}}', + client.BAD_REQUEST, + "persons/bob/housing_allowance", + "You tried to compute the variable 'housing_allowance' for the entity 'persons'", + ), + ( + '{"persons": {"bob": {"salary": 4000 }}}', + client.BAD_REQUEST, + "persons/bob/salary", + "period", + ), + ( + '{"persons": {"bob": {"salary": {"2017-01": "toto"} }}}', + client.BAD_REQUEST, + "persons/bob/salary/2017-01", + "expected type number", + ), + ( + '{"persons": {"bob": {"salary": {"2017-01": {}} }}}', + client.BAD_REQUEST, + "persons/bob/salary/2017-01", + "expected type number", + ), + ( + '{"persons": {"bob": {"age": {"2017-01": "toto"} }}}', + client.BAD_REQUEST, + "persons/bob/age/2017-01", + "expected type integer", + ), + ( + '{"persons": {"bob": {"birth": {"2017-01": "toto"} }}}', + client.BAD_REQUEST, + "persons/bob/birth/2017-01", + "Can't deal with date", + ), + ( + '{"persons": {"bob": {}}, "households": {"household": {"parents": ["unexpected_person_id"]}}}', + client.BAD_REQUEST, + "households/household/parents", + "has not been declared in persons", + ), + ( + '{"persons": {"bob": {}}, "households": {"household": {"parents": ["bob", "bob"]}}}', + client.BAD_REQUEST, + "households/household/parents", + "has been declared more than once", + ), + ( + '{"persons": {"bob": {}}, "households": {"household": {"parents": ["bob", {}]}}}', + client.BAD_REQUEST, + "households/household/parents/1", + "Invalid type", + ), + ( + '{"persons": {"bob": {"salary": {"invalid period": 2000 }}}}', + client.BAD_REQUEST, + "persons/bob/salary", + "Expected a period", + ), + ( + '{"persons": {"bob": {"salary": {"invalid period": null }}}}', + client.BAD_REQUEST, + "persons/bob/salary", + "Expected a period", + ), + ( + '{"persons": {"bob": {"basic_income": {"2017": 2000 }}}, "households": {"household": {"parents": ["bob"]}}}', + client.BAD_REQUEST, + "persons/bob/basic_income/2017", + '"basic_income" can only be set for one month', + ), + ( + '{"persons": {"bob": {"salary": {"ETERNITY": 2000 }}}, "households": {"household": {"parents": ["bob"]}}}', + client.BAD_REQUEST, + "persons/bob/salary/ETERNITY", + "salary is only defined for months", + ), + ( + '{"persons": {"alice": {}, "bob": {}, "charlie": {}}, "households": {"_": {"parents": ["alice", "bob", "charlie"]}}}', + client.BAD_REQUEST, + "households/_/parents", + "at most 2 parents in a household", + ), + ], +) +def test_responses(test_client, test) -> None: check_response(test_client, *test) -def test_basic_calculation(test_client): - simulation_json = json.dumps({ - "persons": { - "bill": { - "birth": { - "2017-12": "1980-01-01" - }, - "age": { - "2017-12": None - }, - "salary": { - "2017-12": 2000 - }, - "basic_income": { - "2017-12": None - }, - "income_tax": { - "2017-12": None - } +def test_basic_calculation(test_client) -> None: + simulation_json = json.dumps( + { + "persons": { + "bill": { + "birth": {"2017-12": "1980-01-01"}, + "age": {"2017-12": None}, + "salary": {"2017-12": 2000}, + "basic_income": {"2017-12": None}, + "income_tax": {"2017-12": None}, }, - "bob": { - "salary": { - "2017-12": 15000 - }, - "basic_income": { - "2017-12": None - }, - "social_security_contribution": { - "2017-12": None - } + "bob": { + "salary": {"2017-12": 15000}, + "basic_income": {"2017-12": None}, + "social_security_contribution": {"2017-12": None}, }, }, - "households": { - "first_household": { - "parents": ['bill', 'bob'], - "housing_tax": { - "2017": None - }, - "accommodation_size": { - "2017-01": 300 - } + "households": { + "first_household": { + "parents": ["bill", "bob"], + "housing_tax": {"2017": None}, + "accommodation_size": {"2017-01": 300}, }, - } - }) + }, + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.OK - response_json = json.loads(response.data.decode('utf-8')) - assert dpath.get(response_json, 'persons/bill/basic_income/2017-12') == 600 # Universal basic income - assert dpath.get(response_json, 'persons/bill/income_tax/2017-12') == 300 # 15% of the salary - assert dpath.get(response_json, 'persons/bill/age/2017-12') == 37 # 15% of the salary - assert dpath.get(response_json, 'persons/bob/basic_income/2017-12') == 600 - assert dpath.get(response_json, 'persons/bob/social_security_contribution/2017-12') == 816 # From social_security_contribution.yaml test - assert dpath.get(response_json, 'households/first_household/housing_tax/2017') == 3000 - - -def test_enums_sending_identifier(test_client): - simulation_json = json.dumps({ - "persons": { - "bill": {} + response_json = json.loads(response.data.decode("utf-8")) + assert ( + dpath.util.get(response_json, "persons/bill/basic_income/2017-12") == 600 + ) # Universal basic income + assert ( + dpath.util.get(response_json, "persons/bill/income_tax/2017-12") == 300 + ) # 15% of the salary + assert ( + dpath.util.get(response_json, "persons/bill/age/2017-12") == 37 + ) # 15% of the salary + assert dpath.util.get(response_json, "persons/bob/basic_income/2017-12") == 600 + assert ( + dpath.util.get( + response_json, + "persons/bob/social_security_contribution/2017-12", + ) + == 816 + ) # From social_security_contribution.yaml test + assert ( + dpath.util.get(response_json, "households/first_household/housing_tax/2017") + == 3000 + ) + + +def test_enums_sending_identifier(test_client) -> None: + simulation_json = json.dumps( + { + "persons": {"bill": {}}, + "households": { + "_": { + "parents": ["bill"], + "housing_tax": {"2017": None}, + "accommodation_size": {"2017-01": 300}, + "housing_occupancy_status": {"2017-01": "free_lodger"}, + }, }, - "households": { - "_": { - "parents": ["bill"], - "housing_tax": { - "2017": None - }, - "accommodation_size": { - "2017-01": 300 - }, - "housing_occupancy_status": { - "2017-01": "free_lodger" - } - } - } - }) + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.OK - response_json = json.loads(response.data.decode('utf-8')) - assert dpath.get(response_json, 'households/_/housing_tax/2017') == 0 + response_json = json.loads(response.data.decode("utf-8")) + assert dpath.util.get(response_json, "households/_/housing_tax/2017") == 0 -def test_enum_output(test_client): - simulation_json = json.dumps({ - "persons": { - "bill": {}, +def test_enum_output(test_client) -> None: + simulation_json = json.dumps( + { + "persons": { + "bill": {}, }, - "households": { - "_": { - "parents": ["bill"], - "housing_occupancy_status": { - "2017-01": None - } + "households": { + "_": { + "parents": ["bill"], + "housing_occupancy_status": {"2017-01": None}, }, - } - }) + }, + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.OK - response_json = json.loads(response.data.decode('utf-8')) - assert dpath.get(response_json, "households/_/housing_occupancy_status/2017-01") == "tenant" - - -def test_enum_wrong_value(test_client): - simulation_json = json.dumps({ - "persons": { - "bill": {}, + response_json = json.loads(response.data.decode("utf-8")) + assert ( + dpath.util.get(response_json, "households/_/housing_occupancy_status/2017-01") + == "tenant" + ) + + +def test_enum_wrong_value(test_client) -> None: + simulation_json = json.dumps( + { + "persons": { + "bill": {}, }, - "households": { - "_": { - "parents": ["bill"], - "housing_occupancy_status": { - "2017-01": "Unknown value lodger" - } + "households": { + "_": { + "parents": ["bill"], + "housing_occupancy_status": {"2017-01": "Unknown value lodger"}, }, - } - }) + }, + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.BAD_REQUEST - response_json = json.loads(response.data.decode('utf-8')) + response_json = json.loads(response.data.decode("utf-8")) message = "Possible values are ['owner', 'tenant', 'free_lodger', 'homeless']" - text = dpath.get(response_json, "households/_/housing_occupancy_status/2017-01") + text = dpath.util.get( + response_json, + "households/_/housing_occupancy_status/2017-01", + ) assert message in text -def test_encoding_variable_value(test_client): - simulation_json = json.dumps({ - "persons": { - "toto": {} - }, - "households": { - "_": { - "housing_occupancy_status": { - "2017-07": "Locataire ou sous-locataire d‘un logement loué vide non-HLM" - +def test_encoding_variable_value(test_client) -> None: + simulation_json = json.dumps( + { + "persons": {"toto": {}}, + "households": { + "_": { + "housing_occupancy_status": { + "2017-07": "Locataire ou sous-locataire d‘un logement loué vide non-HLM", }, - "parent": [ - "toto", - ] - } - } - }) + "parent": [ + "toto", + ], + }, + }, + }, + ) # No UnicodeDecodeError response = post_json(test_client, simulation_json) - assert response.status_code == client.BAD_REQUEST, response.data.decode('utf-8') - response_json = json.loads(response.data.decode('utf-8')) + assert response.status_code == client.BAD_REQUEST, response.data.decode("utf-8") + response_json = json.loads(response.data.decode("utf-8")) message = "'Locataire ou sous-locataire d‘un logement loué vide non-HLM' is not a known value for 'housing_occupancy_status'. Possible values are " - text = dpath.get(response_json, 'households/_/housing_occupancy_status/2017-07') + text = dpath.util.get( + response_json, + "households/_/housing_occupancy_status/2017-07", + ) assert message in text -def test_encoding_entity_name(test_client): - simulation_json = json.dumps({ - "persons": { - "O‘Ryan": {}, - "Renée": {} - }, - "households": { - "_": { - "parents": [ - "O‘Ryan", - "Renée" - ] - } - } - }) +def test_encoding_entity_name(test_client) -> None: + simulation_json = json.dumps( + { + "persons": {"O‘Ryan": {}, "Renée": {}}, + "households": {"_": {"parents": ["O‘Ryan", "Renée"]}}, + }, + ) # No UnicodeDecodeError response = post_json(test_client, simulation_json) - response_json = json.loads(response.data.decode('utf-8')) + response_json = json.loads(response.data.decode("utf-8")) # In Python 3, there is no encoding issue. if response.status_code != client.OK: message = "'O‘Ryan' is not a valid ASCII value." - text = response_json['error'] + text = response_json["error"] assert message in text -def test_encoding_period_id(test_client): - simulation_json = json.dumps({ - "persons": { - "bill": { - "salary": { - "2017": 60000 - } +def test_encoding_period_id(test_client) -> None: + simulation_json = json.dumps( + { + "persons": { + "bill": {"salary": {"2017": 60000}}, + "bell": {"salary": {"2017": 60000}}, + }, + "households": { + "_": { + "parents": ["bill", "bell"], + "housing_tax": {"à": 400}, + "accommodation_size": {"2017-01": 300}, + "housing_occupancy_status": {"2017-01": "tenant"}, }, - "bell": { - "salary": { - "2017": 60000 - } - } }, - "households": { - "_": { - "parents": ["bill", "bell"], - "housing_tax": { - "à": 400 - }, - "accommodation_size": { - "2017-01": 300 - }, - "housing_occupancy_status": { - "2017-01": "tenant" - } - } - } - }) + }, + ) # No UnicodeDecodeError response = post_json(test_client, simulation_json) assert response.status_code == client.BAD_REQUEST - response_json = json.loads(response.data.decode('utf-8')) + response_json = json.loads(response.data.decode("utf-8")) # In Python 3, there is no encoding issue. if "Expected a period" not in str(response.data): message = "'à' is not a valid ASCII value." - text = response_json['error'] + text = response_json["error"] assert message in text -def test_str_variable(test_client): +def test_str_variable(test_client) -> None: new_couple = copy.deepcopy(couple) - new_couple['households']['_']['postal_code'] = {'2017-01': None} + new_couple["households"]["_"]["postal_code"] = {"2017-01": None} simulation_json = json.dumps(new_couple) - response = test_client.post('/calculate', data = simulation_json, content_type = 'application/json') + response = test_client.post( + "/calculate", + data=simulation_json, + content_type="application/json", + ) assert response.status_code == client.OK -def test_periods(test_client): - simulation_json = json.dumps({ - "persons": { - "bill": {} +def test_periods(test_client) -> None: + simulation_json = json.dumps( + { + "persons": {"bill": {}}, + "households": { + "_": { + "parents": ["bill"], + "housing_tax": {"2017": None}, + "housing_occupancy_status": {"2017-01": None}, + }, }, - "households": { - "_": { - "parents": ["bill"], - "housing_tax": { - "2017": None - }, - "housing_occupancy_status": { - "2017-01": None - } - } - } - }) + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.OK - response_json = json.loads(response.data.decode('utf-8')) + response_json = json.loads(response.data.decode("utf-8")) - yearly_variable = dpath.get(response_json, 'households/_/housing_tax') # web api year is an int - assert yearly_variable == {'2017': 200.0} + yearly_variable = dpath.util.get( + response_json, + "households/_/housing_tax", + ) # web api year is an int + assert yearly_variable == {"2017": 200.0} - monthly_variable = dpath.get(response_json, 'households/_/housing_occupancy_status') # web api month is a string - assert monthly_variable == {'2017-01': 'tenant'} + monthly_variable = dpath.util.get( + response_json, + "households/_/housing_occupancy_status", + ) # web api month is a string + assert monthly_variable == {"2017-01": "tenant"} -def test_handle_period_mismatch_error(test_client): +def test_two_periods(test_client) -> None: + """Test `calculate` on a request with mixed types periods: yearly periods following + monthly or daily periods to check dpath limitation on numeric keys (yearly periods). + Made to test the case where we have more than one path with a numeric in it. + See https://github.com/dpath-maintainers/dpath-python/issues/160 for more information. + """ + simulation_json = json.dumps( + { + "persons": {"bill": {}}, + "households": { + "_": { + "parents": ["bill"], + "housing_tax": {"2017": None, "2018": None}, + "housing_occupancy_status": {"2017-01": None, "2018-01": None}, + }, + }, + }, + ) + response = post_json(test_client, simulation_json) + assert response.status_code == client.OK + + response_json = json.loads(response.data.decode("utf-8")) + + yearly_variable = dpath.util.get( + response_json, + "households/_/housing_tax", + ) # web api year is an int + assert yearly_variable == {"2017": 200.0, "2018": 200.0} + + monthly_variable = dpath.util.get( + response_json, + "households/_/housing_occupancy_status", + ) # web api month is a string + assert monthly_variable == {"2017-01": "tenant", "2018-01": "tenant"} + + +def test_handle_period_mismatch_error(test_client) -> None: variable = "housing_tax" period = "2017-01" - simulation_json = json.dumps({ - "persons": { - "bill": {} + simulation_json = json.dumps( + { + "persons": {"bill": {}}, + "households": { + "_": { + "parents": ["bill"], + variable: {period: 400}, + }, }, - "households": { - "_": { - "parents": ["bill"], - variable: { - period: 400 - }, - } - } - }) + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.BAD_REQUEST response_json = json.loads(response.data) - error = dpath.get(response_json, f'households/_/housing_tax/{period}') + error = dpath.util.get(response_json, f"households/_/housing_tax/{period}") message = f'Unable to set a value for variable "{variable}" for month-long period "{period}"' assert message in error -def test_gracefully_handle_unexpected_errors(test_client): - """ - Context +def test_gracefully_handle_unexpected_errors(test_client) -> None: + """Context. ======= Whenever an exception is raised by the calculation engine, the API will try @@ -368,19 +481,21 @@ def test_gracefully_handle_unexpected_errors(test_client): variable = "housing_tax" period = "1234-05-06" - simulation_json = json.dumps({ - "persons": { - "bill": {}, + simulation_json = json.dumps( + { + "persons": { + "bill": {}, }, - "households": { - "_": { - "parents": ["bill"], - variable: { - period: None, + "households": { + "_": { + "parents": ["bill"], + variable: { + period: None, }, - } - } - }) + }, + }, + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.INTERNAL_SERVER_ERROR diff --git a/tests/web_api/test_entities.py b/tests/web_api/test_entities.py index 6f8153ed37..e7d0ef5b9b 100644 --- a/tests/web_api/test_entities.py +++ b/tests/web_api/test_entities.py @@ -1,37 +1,34 @@ -# -*- coding: utf-8 -*- - -from http import client import json +from http import client from openfisca_country_template import entities - # /entities -def test_return_code(test_client): - entities_response = test_client.get('/entities') +def test_return_code(test_client) -> None: + entities_response = test_client.get("/entities") assert entities_response.status_code == client.OK -def test_response_data(test_client): - entities_response = test_client.get('/entities') - entities_dict = json.loads(entities_response.data.decode('utf-8')) +def test_response_data(test_client) -> None: + entities_response = test_client.get("/entities") + entities_dict = json.loads(entities_response.data.decode("utf-8")) test_documentation = entities.Household.doc.strip() - assert entities_dict['household'] == { - 'description': 'All the people in a family or group who live together in the same place.', - 'documentation': test_documentation, - 'plural': 'households', - 'roles': { - 'child': { - 'description': 'Other individuals living in the household.', - 'plural': 'children', - }, - 'parent': { - 'description': 'The one or two adults in charge of the household.', - 'plural': 'parents', - 'max': 2, - } - } - } + assert entities_dict["household"] == { + "description": "All the people in a family or group who live together in the same place.", + "documentation": test_documentation, + "plural": "households", + "roles": { + "child": { + "description": "Other individuals living in the household.", + "plural": "children", + }, + "parent": { + "description": "The one or two adults in charge of the household.", + "plural": "parents", + "max": 2, + }, + }, + } diff --git a/tests/web_api/test_headers.py b/tests/web_api/test_headers.py index 54bbfd0df8..dc95437a09 100644 --- a/tests/web_api/test_headers.py +++ b/tests/web_api/test_headers.py @@ -1,13 +1,10 @@ -# -*- coding: utf-8 -*- +def test_package_name_header(test_client, distribution) -> None: + name = distribution.metadata.get("Name").lower() + parameters_response = test_client.get("/parameters") + assert parameters_response.headers.get("Country-Package") == name -from . import distribution - -def test_package_name_header(test_client): - parameters_response = test_client.get('/parameters') - assert parameters_response.headers.get('Country-Package') == distribution.key - - -def test_package_version_header(test_client): - parameters_response = test_client.get('/parameters') - assert parameters_response.headers.get('Country-Package-Version') == distribution.version +def test_package_version_header(test_client, distribution) -> None: + version = distribution.metadata.get("Version") + parameters_response = test_client.get("/parameters") + assert parameters_response.headers.get("Country-Package-Version") == version diff --git a/tests/web_api/test_helpers.py b/tests/web_api/test_helpers.py index cb049a0822..a1725cdfbf 100644 --- a/tests/web_api/test_helpers.py +++ b/tests/web_api/test_helpers.py @@ -1,53 +1,51 @@ import os -from openfisca_web_api.loader import parameters - from openfisca_core.parameters import load_parameter_file +from openfisca_web_api.loader import parameters - -dir_path = os.path.join(os.path.dirname(__file__), 'assets') +dir_path = os.path.join(os.path.dirname(__file__), "assets") -def test_build_api_values_history(): - file_path = os.path.join(dir_path, 'test_helpers.yaml') - parameter = load_parameter_file(name='dummy_name', file_path=file_path) +def test_build_api_values_history() -> None: + file_path = os.path.join(dir_path, "test_helpers.yaml") + parameter = load_parameter_file(name="dummy_name", file_path=file_path) values = { - '2017-01-01': 0.02, - '2015-01-01': 0.04, - '2013-01-01': 0.03, - } + "2017-01-01": 0.02, + "2015-01-01": 0.04, + "2013-01-01": 0.03, + } assert parameters.build_api_values_history(parameter) == values -def test_build_api_values_history_with_stop_date(): - file_path = os.path.join(dir_path, 'test_helpers_with_stop_date.yaml') - parameter = load_parameter_file(name='dummy_name', file_path=file_path) +def test_build_api_values_history_with_stop_date() -> None: + file_path = os.path.join(dir_path, "test_helpers_with_stop_date.yaml") + parameter = load_parameter_file(name="dummy_name", file_path=file_path) values = { - '2018-01-01': None, - '2017-01-01': 0.02, - '2015-01-01': 0.04, - '2013-01-01': 0.03, - } + "2018-01-01": None, + "2017-01-01": 0.02, + "2015-01-01": 0.04, + "2013-01-01": 0.03, + } assert parameters.build_api_values_history(parameter) == values -def test_get_value(): - values = {'2013-01-01': 0.03, '2017-01-01': 0.02, '2015-01-01': 0.04} +def test_get_value() -> None: + values = {"2013-01-01": 0.03, "2017-01-01": 0.02, "2015-01-01": 0.04} - assert parameters.get_value('2013-01-01', values) == 0.03 - assert parameters.get_value('2014-01-01', values) == 0.03 - assert parameters.get_value('2015-02-01', values) == 0.04 - assert parameters.get_value('2016-12-31', values) == 0.04 - assert parameters.get_value('2017-01-01', values) == 0.02 - assert parameters.get_value('2018-01-01', values) == 0.02 + assert parameters.get_value("2013-01-01", values) == 0.03 + assert parameters.get_value("2014-01-01", values) == 0.03 + assert parameters.get_value("2015-02-01", values) == 0.04 + assert parameters.get_value("2016-12-31", values) == 0.04 + assert parameters.get_value("2017-01-01", values) == 0.02 + assert parameters.get_value("2018-01-01", values) == 0.02 -def test_get_value_with_none(): - values = {'2015-01-01': 0.04, '2017-01-01': None} +def test_get_value_with_none() -> None: + values = {"2015-01-01": 0.04, "2017-01-01": None} - assert parameters.get_value('2016-12-31', values) == 0.04 - assert parameters.get_value('2017-01-01', values) is None - assert parameters.get_value('2011-01-01', values) is None + assert parameters.get_value("2016-12-31", values) == 0.04 + assert parameters.get_value("2017-01-01", values) is None + assert parameters.get_value("2011-01-01", values) is None diff --git a/tests/web_api/test_parameters.py b/tests/web_api/test_parameters.py index 8f65cca9af..2f9a00a642 100644 --- a/tests/web_api/test_parameters.py +++ b/tests/web_api/test_parameters.py @@ -1,128 +1,168 @@ -from http import client import json -import pytest import re +import sys +from http import client +import pytest # /parameters -GITHUB_URL_REGEX = r'^https://github\.com/openfisca/country-template/blob/\d+\.\d+\.\d+((.dev|rc)\d+)?/openfisca_country_template/parameters/(.)+\.yaml$' +GITHUB_URL_REGEX = r"^https://github\.com/openfisca/country-template/blob/\d+\.\d+\.\d+((.dev|rc)\d+)?/openfisca_country_template/parameters/(.)+\.yaml$" -def test_return_code(test_client): - parameters_response = test_client.get('/parameters') +def test_return_code(test_client) -> None: + parameters_response = test_client.get("/parameters") assert parameters_response.status_code == client.OK -def test_response_data(test_client): - parameters_response = test_client.get('/parameters') - parameters = json.loads(parameters_response.data.decode('utf-8')) +def test_response_data(test_client) -> None: + parameters_response = test_client.get("/parameters") + parameters = json.loads(parameters_response.data.decode("utf-8")) - assert parameters['taxes.income_tax_rate'] == { - 'description': 'Income tax rate', - 'href': 'http://localhost/parameter/taxes/income_tax_rate' - } - assert parameters.get('taxes') is None + assert parameters["taxes.income_tax_rate"] == { + "description": "Income tax rate", + "href": "http://localhost/parameter/taxes/income_tax_rate", + } + assert parameters.get("taxes") is None # /parameter/ -def test_error_code_non_existing_parameter(test_client): - response = test_client.get('/parameter/non/existing.parameter') + +def test_error_code_non_existing_parameter(test_client) -> None: + response = test_client.get("/parameter/non/existing.parameter") assert response.status_code == client.NOT_FOUND -def test_return_code_existing_parameter(test_client): - response = test_client.get('/parameter/taxes/income_tax_rate') +def test_return_code_existing_parameter(test_client) -> None: + response = test_client.get("/parameter/taxes/income_tax_rate") assert response.status_code == client.OK -def test_legacy_parameter_route(test_client): - response = test_client.get('/parameter/taxes.income_tax_rate') +def test_legacy_parameter_route(test_client) -> None: + response = test_client.get("/parameter/taxes.income_tax_rate") assert response.status_code == client.OK -def test_parameter_values(test_client): - response = test_client.get('/parameter/taxes/income_tax_rate') +# TODO(Mauko Quiroga-Alvarado): Fix this test +# https://github.com/openfisca/openfisca-core/issues/962 +@pytest.mark.skipif(sys.platform == "win32", reason="Does not work on Windows.") +def test_parameter_values(test_client) -> None: + response = test_client.get("/parameter/taxes/income_tax_rate") parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), ['description', 'id', 'metadata', 'source', 'values'] - assert parameter['id'] == 'taxes.income_tax_rate' - assert parameter['description'] == 'Income tax rate' - assert parameter['values'] == {'2015-01-01': 0.15, '2014-01-01': 0.14, '2013-01-01': 0.13, '2012-01-01': 0.16} - assert parameter['metadata'] == {'unit': '/1'} - assert re.match(GITHUB_URL_REGEX, parameter['source']) - assert 'taxes/income_tax_rate.yaml' in parameter['source'] + assert sorted(parameter.keys()), [ + "description", + "id", + "metadata", + "source", + "values", + ] + assert parameter["id"] == "taxes.income_tax_rate" + assert parameter["description"] == "Income tax rate" + assert parameter["values"] == { + "2015-01-01": 0.15, + "2014-01-01": 0.14, + "2013-01-01": 0.13, + "2012-01-01": 0.16, + } + assert parameter["metadata"] == {"unit": "/1"} + assert re.match(GITHUB_URL_REGEX, parameter["source"]) + assert "taxes/income_tax_rate.yaml" in parameter["source"] # 'documentation' attribute exists only when a value is defined - response = test_client.get('/parameter/benefits/housing_allowance') + response = test_client.get("/parameter/benefits/housing_allowance") parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), ['description', 'documentation', 'id', 'metadata', 'source' == 'values'] + assert sorted(parameter.keys()), [ + "description", + "documentation", + "id", + "metadata", + "source" == "values", + ] assert ( - parameter['documentation'] == - 'A fraction of the rent.\nFrom the 1st of Dec 2016, the housing allowance no longer exists.' - ) + parameter["documentation"] + == "A fraction of the rent.\nFrom the 1st of Dec 2016, the housing allowance no longer exists." + ) -def test_parameter_node(tax_benefit_system, test_client): - response = test_client.get('/parameter/benefits') +def test_parameter_node(tax_benefit_system, test_client) -> None: + response = test_client.get("/parameter/benefits") assert response.status_code == client.OK parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), ['description', 'documentation', 'id', 'metadata', 'source' == 'subparams'] - assert parameter['documentation'] == ( + assert sorted(parameter.keys()), [ + "description", + "documentation", + "id", + "metadata", + "source" == "subparams", + ] + assert parameter["documentation"] == ( "Government support for the citizens and residents of society." "\nThey may be provided to people of any income level, as with social security," "\nbut usually it is intended to ensure that everyone can meet their basic human needs" "\nsuch as food and shelter.\n(See https://en.wikipedia.org/wiki/Welfare)" - ) + ) model_benefits = tax_benefit_system.parameters.benefits - assert parameter['subparams'].keys() == model_benefits.children.keys(), parameter['subparams'].keys() + assert parameter["subparams"].keys() == model_benefits.children.keys(), parameter[ + "subparams" + ].keys() - assert 'description' in parameter['subparams']['basic_income'] - assert parameter['subparams']['basic_income']['description'] == getattr( - model_benefits.basic_income, "description", None - ), parameter['subparams']['basic_income']['description'] + assert "description" in parameter["subparams"]["basic_income"] + assert parameter["subparams"]["basic_income"]["description"] == getattr( + model_benefits.basic_income, + "description", + None, + ), parameter["subparams"]["basic_income"]["description"] -def test_stopped_parameter_values(test_client): - response = test_client.get('/parameter/benefits/housing_allowance') +def test_stopped_parameter_values(test_client) -> None: + response = test_client.get("/parameter/benefits/housing_allowance") parameter = json.loads(response.data) - assert parameter['values'] == {'2016-12-01': None, '2010-01-01': 0.25} + assert parameter["values"] == {"2016-12-01": None, "2010-01-01": 0.25} -def test_scale(test_client): - response = test_client.get('/parameter/taxes/social_security_contribution') +def test_scale(test_client) -> None: + response = test_client.get("/parameter/taxes/social_security_contribution") parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), ['brackets', 'description', 'id', 'metadata' == 'source'] - assert parameter['brackets'] == { - '2013-01-01': {"0.0": 0.03, "12000.0": 0.10}, - '2014-01-01': {"0.0": 0.03, "12100.0": 0.10}, - '2015-01-01': {"0.0": 0.04, "12200.0": 0.12}, - '2016-01-01': {"0.0": 0.04, "12300.0": 0.12}, - '2017-01-01': {"0.0": 0.02, "6000.0": 0.06, "12400.0": 0.12}, - } - - -def check_code(client, route, code): + assert sorted(parameter.keys()), [ + "brackets", + "description", + "id", + "metadata" == "source", + ] + assert parameter["brackets"] == { + "2013-01-01": {"0.0": 0.03, "12000.0": 0.10}, + "2014-01-01": {"0.0": 0.03, "12100.0": 0.10}, + "2015-01-01": {"0.0": 0.04, "12200.0": 0.12}, + "2016-01-01": {"0.0": 0.04, "12300.0": 0.12}, + "2017-01-01": {"0.0": 0.02, "6000.0": 0.06, "12400.0": 0.12}, + } + + +def check_code(client, route, code) -> None: response = client.get(route) assert response.status_code == code -@pytest.mark.parametrize("expected_code", [ - ('/parameters/', client.OK), - ('/parameter', client.NOT_FOUND), - ('/parameter/', client.NOT_FOUND), - ('/parameter/with-ÜNı©ød€', client.NOT_FOUND), - ('/parameter/with%20url%20encoding', client.NOT_FOUND), - ('/parameter/taxes/income_tax_rate/', client.OK), - ('/parameter/taxes/income_tax_rate/too-much-nesting', client.NOT_FOUND), - ('/parameter//taxes/income_tax_rate/', client.NOT_FOUND), - ]) -def test_routes_robustness(test_client, expected_code): +@pytest.mark.parametrize( + "expected_code", + [ + ("/parameters/", client.FOUND), + ("/parameter", client.NOT_FOUND), + ("/parameter/", client.FOUND), + ("/parameter/with-ÜNı©ød€", client.NOT_FOUND), + ("/parameter/with%20url%20encoding", client.NOT_FOUND), + ("/parameter/taxes/income_tax_rate/", client.FOUND), + ("/parameter/taxes/income_tax_rate/too-much-nesting", client.NOT_FOUND), + ("/parameter//taxes/income_tax_rate/", client.FOUND), + ], +) +def test_routes_robustness(test_client, expected_code) -> None: check_code(test_client, *expected_code) -def test_parameter_encoding(test_client): - parameter_response = test_client.get('/parameter/general/age_of_retirement') +def test_parameter_encoding(test_client) -> None: + parameter_response = test_client.get("/parameter/general/age_of_retirement") assert parameter_response.status_code == client.OK diff --git a/tests/web_api/test_spec.py b/tests/web_api/test_spec.py index 5e19752119..75a0f00e64 100644 --- a/tests/web_api/test_spec.py +++ b/tests/web_api/test_spec.py @@ -1,55 +1,77 @@ -import dpath import json -import pytest from http import client +import dpath.util +import pytest +from openapi_spec_validator import OpenAPIV30SpecValidator + -def assert_items_equal(x, y): +def assert_items_equal(x, y) -> None: assert sorted(x) == sorted(y) -def test_return_code(test_client): - openAPI_response = test_client.get('/spec') +def test_return_code(test_client) -> None: + openAPI_response = test_client.get("/spec") assert openAPI_response.status_code == client.OK @pytest.fixture(scope="module") def body(test_client): - openAPI_response = test_client.get('/spec') - return json.loads(openAPI_response.data.decode('utf-8')) + openAPI_response = test_client.get("/spec") + return json.loads(openAPI_response.data.decode("utf-8")) -def test_paths(body): +def test_paths(body) -> None: assert_items_equal( - body['paths'], - ["/parameter/{parameterID}", - "/parameters", - "/variable/{variableID}", - "/variables", - "/entities", - "/trace", - "/calculate", - "/spec"] - ) + body["paths"], + [ + "/parameter/{parameterID}", + "/parameters", + "/variable/{variableID}", + "/variables", + "/entities", + "/trace", + "/calculate", + "/spec", + ], + ) -def test_entity_definition(body): - assert 'parents' in dpath.get(body, 'definitions/Household/properties') - assert 'children' in dpath.get(body, 'definitions/Household/properties') - assert 'salary' in dpath.get(body, 'definitions/Person/properties') - assert 'rent' in dpath.get(body, 'definitions/Household/properties') - assert 'number' == dpath.get(body, 'definitions/Person/properties/salary/additionalProperties/type') +def test_entity_definition(body) -> None: + assert "parents" in dpath.util.get(body, "components/schemas/Household/properties") + assert "children" in dpath.util.get(body, "components/schemas/Household/properties") + assert "salary" in dpath.util.get(body, "components/schemas/Person/properties") + assert "rent" in dpath.util.get(body, "components/schemas/Household/properties") + assert ( + dpath.util.get( + body, + "components/schemas/Person/properties/salary/additionalProperties/type", + ) + == "number" + ) -def test_situation_definition(body): - situation_input = body['definitions']['SituationInput'] - situation_output = body['definitions']['SituationOutput'] +def test_situation_definition(body) -> None: + situation_input = body["components"]["schemas"]["SituationInput"] + situation_output = body["components"]["schemas"]["SituationOutput"] for situation in situation_input, situation_output: - assert 'households' in dpath.get(situation, '/properties') - assert 'persons' in dpath.get(situation, '/properties') - assert "#/definitions/Household" == dpath.get(situation, '/properties/households/additionalProperties/$ref') - assert "#/definitions/Person" == dpath.get(situation, '/properties/persons/additionalProperties/$ref') + assert "households" in dpath.util.get(situation, "/properties") + assert "persons" in dpath.util.get(situation, "/properties") + assert ( + dpath.util.get( + situation, + "/properties/households/additionalProperties/$ref", + ) + == "#/components/schemas/Household" + ) + assert ( + dpath.util.get( + situation, + "/properties/persons/additionalProperties/$ref", + ) + == "#/components/schemas/Person" + ) -def test_host(body): - assert 'http' not in body['host'] +def test_respects_spec(body) -> None: + assert not list(OpenAPIV30SpecValidator(body).iter_errors()) diff --git a/tests/web_api/test_trace.py b/tests/web_api/test_trace.py index b59fbdb5f0..9463e69dfb 100644 --- a/tests/web_api/test_trace.py +++ b/tests/web_api/test_trace.py @@ -1,80 +1,129 @@ import copy -import dpath -from http import client import json +from http import client + +import dpath.util -from openfisca_country_template.situation_examples import single, couple +from openfisca_country_template.situation_examples import couple, single -def assert_items_equal(x, y): +def assert_items_equal(x, y) -> None: assert set(x) == set(y) -def test_trace_basic(test_client): +def test_trace_basic(test_client) -> None: simulation_json = json.dumps(single) - response = test_client.post('/trace', data = simulation_json, content_type = 'application/json') + response = test_client.post( + "/trace", + data=simulation_json, + content_type="application/json", + ) assert response.status_code == client.OK - response_json = json.loads(response.data.decode('utf-8')) - disposable_income_value = dpath.util.get(response_json, 'trace/disposable_income<2017-01>/value') + response_json = json.loads(response.data.decode("utf-8")) + disposable_income_value = dpath.util.get( + response_json, + "trace/disposable_income<2017-01>/value", + ) assert isinstance(disposable_income_value, list) assert isinstance(disposable_income_value[0], float) - disposable_income_dep = dpath.util.get(response_json, 'trace/disposable_income<2017-01>/dependencies') + disposable_income_dep = dpath.util.get( + response_json, + "trace/disposable_income<2017-01>/dependencies", + ) assert_items_equal( disposable_income_dep, - ['salary<2017-01>', 'basic_income<2017-01>', 'income_tax<2017-01>', 'social_security_contribution<2017-01>'] - ) - basic_income_dep = dpath.util.get(response_json, 'trace/basic_income<2017-01>/dependencies') - assert_items_equal(basic_income_dep, ['age<2017-01>']) - - -def test_trace_enums(test_client): + [ + "salary<2017-01>", + "basic_income<2017-01>", + "income_tax<2017-01>", + "social_security_contribution<2017-01>", + ], + ) + basic_income_dep = dpath.util.get( + response_json, + "trace/basic_income<2017-01>/dependencies", + ) + assert_items_equal(basic_income_dep, ["age<2017-01>"]) + + +def test_trace_enums(test_client) -> None: new_single = copy.deepcopy(single) - new_single['households']['_']['housing_occupancy_status'] = {"2017-01": None} + new_single["households"]["_"]["housing_occupancy_status"] = {"2017-01": None} simulation_json = json.dumps(new_single) - response = test_client.post('/trace', data = simulation_json, content_type = 'application/json') + response = test_client.post( + "/trace", + data=simulation_json, + content_type="application/json", + ) response_json = json.loads(response.data) - housing_status = dpath.util.get(response_json, 'trace/housing_occupancy_status<2017-01>/value') - assert housing_status[0] == 'tenant' # The default value + housing_status = dpath.util.get( + response_json, + "trace/housing_occupancy_status<2017-01>/value", + ) + assert housing_status[0] == "tenant" # The default value -def test_entities_description(test_client): +def test_entities_description(test_client) -> None: simulation_json = json.dumps(couple) - response = test_client.post('/trace', data = simulation_json, content_type = 'application/json') - response_json = json.loads(response.data.decode('utf-8')) + response = test_client.post( + "/trace", + data=simulation_json, + content_type="application/json", + ) + response_json = json.loads(response.data.decode("utf-8")) assert_items_equal( - dpath.util.get(response_json, 'entitiesDescription/persons'), - ['Javier', "Alicia"] - ) + dpath.util.get(response_json, "entitiesDescription/persons"), + ["Javier", "Alicia"], + ) -def test_root_nodes(test_client): +def test_root_nodes(test_client) -> None: simulation_json = json.dumps(couple) - response = test_client.post('/trace', data = simulation_json, content_type = 'application/json') - response_json = json.loads(response.data.decode('utf-8')) + response = test_client.post( + "/trace", + data=simulation_json, + content_type="application/json", + ) + response_json = json.loads(response.data.decode("utf-8")) assert_items_equal( - dpath.util.get(response_json, 'requestedCalculations'), - ['disposable_income<2017-01>', 'total_benefits<2017-01>', 'total_taxes<2017-01>'] - ) + dpath.util.get(response_json, "requestedCalculations"), + [ + "disposable_income<2017-01>", + "total_benefits<2017-01>", + "total_taxes<2017-01>", + ], + ) -def test_str_variable(test_client): +def test_str_variable(test_client) -> None: new_couple = copy.deepcopy(couple) - new_couple['households']['_']['postal_code'] = {'2017-01': None} + new_couple["households"]["_"]["postal_code"] = {"2017-01": None} simulation_json = json.dumps(new_couple) - response = test_client.post('/trace', data = simulation_json, content_type = 'application/json') + response = test_client.post( + "/trace", + data=simulation_json, + content_type="application/json", + ) assert response.status_code == client.OK -def test_trace_parameters(test_client): +def test_trace_parameters(test_client) -> None: new_couple = copy.deepcopy(couple) - new_couple['households']['_']['housing_tax'] = {'2017': None} + new_couple["households"]["_"]["housing_tax"] = {"2017": None} simulation_json = json.dumps(new_couple) - response = test_client.post('/trace', data = simulation_json, content_type = 'application/json') - response_json = json.loads(response.data.decode('utf-8')) - - assert len(dpath.util.get(response_json, 'trace/housing_tax<2017>/parameters')) > 0 - taxes__housing_tax__minimal_amount = dpath.util.get(response_json, 'trace/housing_tax<2017>/parameters/taxes.housing_tax.minimal_amount<2017-01-01>') + response = test_client.post( + "/trace", + data=simulation_json, + content_type="application/json", + ) + response_json = json.loads(response.data.decode("utf-8")) + + assert len(dpath.util.get(response_json, "trace/housing_tax<2017>/parameters")) > 0 + taxes__housing_tax__minimal_amount = dpath.util.get( + response_json, + "trace/housing_tax<2017>/parameters/taxes.housing_tax.minimal_amount<2017-01-01>", + ) assert taxes__housing_tax__minimal_amount == 200 diff --git a/tests/web_api/test_variables.py b/tests/web_api/test_variables.py index 4581608aa8..a521e8a1b1 100644 --- a/tests/web_api/test_variables.py +++ b/tests/web_api/test_variables.py @@ -1,14 +1,16 @@ -from http import client import json -import pytest import re +import sys +from http import client + +import pytest -def assert_items_equal(x, y): +def assert_items_equal(x, y) -> None: assert set(x) == set(y) -GITHUB_URL_REGEX = r'^https://github\.com/openfisca/country-template/blob/\d+\.\d+\.\d+((.dev|rc)\d+)?/openfisca_country_template/variables/(.)+\.py#L\d+-L\d+$' +GITHUB_URL_REGEX = r"^https://github\.com/openfisca/country-template/blob/\d+\.\d+\.\d+((.dev|rc)\d+)?/openfisca_country_template/variables/(.)+\.py#L\d+-L\d+$" # /variables @@ -16,147 +18,160 @@ def assert_items_equal(x, y): @pytest.fixture(scope="module") def variables_response(test_client): - variables_response = test_client.get("/variables") - return variables_response + return test_client.get("/variables") -def test_return_code(variables_response): +def test_return_code(variables_response) -> None: assert variables_response.status_code == client.OK -def test_response_data(variables_response): - variables = json.loads(variables_response.data.decode('utf-8')) - assert variables['birth'] == { - 'description': 'Birth date', - 'href': 'http://localhost/variable/birth' - } +def test_response_data(variables_response) -> None: + variables = json.loads(variables_response.data.decode("utf-8")) + assert variables["birth"] == { + "description": "Birth date", + "href": "http://localhost/variable/birth", + } # /variable/ -def test_error_code_non_existing_variable(test_client): - response = test_client.get('/variable/non_existing_variable') +def test_error_code_non_existing_variable(test_client) -> None: + response = test_client.get("/variable/non_existing_variable") assert response.status_code == client.NOT_FOUND @pytest.fixture(scope="module") def input_variable_response(test_client): - input_variable_response = test_client.get('/variable/birth') - return input_variable_response + return test_client.get("/variable/birth") -def test_return_code_existing_input_variable(input_variable_response): +def test_return_code_existing_input_variable(input_variable_response) -> None: assert input_variable_response.status_code == client.OK -def check_input_variable_value(key, expected_value, input_variable=None): +def check_input_variable_value(key, expected_value, input_variable=None) -> None: assert input_variable[key] == expected_value -@pytest.mark.parametrize("expected_values", [ - ('description', 'Birth date'), - ('valueType', 'Date'), - ('defaultValue', '1970-01-01'), - ('definitionPeriod', 'ETERNITY'), - ('entity', 'person'), - ('references', ['https://en.wiktionary.org/wiki/birthdate']), - ]) -def test_input_variable_value(expected_values, input_variable_response): - input_variable = json.loads(input_variable_response.data.decode('utf-8')) +@pytest.mark.parametrize( + "expected_values", + [ + ("description", "Birth date"), + ("valueType", "Date"), + ("defaultValue", "1970-01-01"), + ("definitionPeriod", "ETERNITY"), + ("entity", "person"), + ("references", ["https://en.wiktionary.org/wiki/birthdate"]), + ], +) +def test_input_variable_value(expected_values, input_variable_response) -> None: + input_variable = json.loads(input_variable_response.data.decode("utf-8")) check_input_variable_value(*expected_values, input_variable=input_variable) -def test_input_variable_github_url(test_client): - input_variable_response = test_client.get('/variable/income_tax') - input_variable = json.loads(input_variable_response.data.decode('utf-8')) +# TODO(Mauko Quiroga-Alvarado): Fix this test +# https://github.com/openfisca/openfisca-core/issues/962 +@pytest.mark.skipif(sys.platform == "win32", reason="Does not work on Windows.") +def test_input_variable_github_url(test_client) -> None: + input_variable_response = test_client.get("/variable/income_tax") + input_variable = json.loads(input_variable_response.data.decode("utf-8")) - assert re.match(GITHUB_URL_REGEX, input_variable['source']) + assert re.match(GITHUB_URL_REGEX, input_variable["source"]) -def test_return_code_existing_variable(test_client): - variable_response = test_client.get('/variable/income_tax') +def test_return_code_existing_variable(test_client) -> None: + variable_response = test_client.get("/variable/income_tax") assert variable_response.status_code == client.OK -def check_variable_value(key, expected_value, variable=None): +def check_variable_value(key, expected_value, variable=None) -> None: assert variable[key] == expected_value -@pytest.mark.parametrize("expected_values", [ - ('description', 'Income tax'), - ('valueType', 'Float'), - ('defaultValue', 0), - ('definitionPeriod', 'MONTH'), - ('entity', 'person'), - ]) -def test_variable_value(expected_values, test_client): - variable_response = test_client.get('/variable/income_tax') - variable = json.loads(variable_response.data.decode('utf-8')) +@pytest.mark.parametrize( + "expected_values", + [ + ("description", "Income tax"), + ("valueType", "Float"), + ("defaultValue", 0), + ("definitionPeriod", "MONTH"), + ("entity", "person"), + ], +) +def test_variable_value(expected_values, test_client) -> None: + variable_response = test_client.get("/variable/income_tax") + variable = json.loads(variable_response.data.decode("utf-8")) check_variable_value(*expected_values, variable=variable) -def test_variable_formula_github_link(test_client): - variable_response = test_client.get('/variable/income_tax') - variable = json.loads(variable_response.data.decode('utf-8')) - assert re.match(GITHUB_URL_REGEX, variable['formulas']['0001-01-01']['source']) +# TODO(Mauko Quiroga-Alvarado): Fix this test +# https://github.com/openfisca/openfisca-core/issues/962 +@pytest.mark.skipif(sys.platform == "win32", reason="Does not work on Windows.") +def test_variable_formula_github_link(test_client) -> None: + variable_response = test_client.get("/variable/income_tax") + variable = json.loads(variable_response.data.decode("utf-8")) + assert re.match(GITHUB_URL_REGEX, variable["formulas"]["0001-01-01"]["source"]) -def test_variable_formula_content(test_client): - variable_response = test_client.get('/variable/income_tax') - variable = json.loads(variable_response.data.decode('utf-8')) - content = variable['formulas']['0001-01-01']['content'] +def test_variable_formula_content(test_client) -> None: + variable_response = test_client.get("/variable/income_tax") + variable = json.loads(variable_response.data.decode("utf-8")) + content = variable["formulas"]["0001-01-01"]["content"] assert "def formula(person, period, parameters):" in content - assert "return person(\"salary\", period) * parameters(period).taxes.income_tax_rate" in content + assert ( + 'return person("salary", period) * parameters(period).taxes.income_tax_rate' + in content + ) -def test_null_values_are_dropped(test_client): - variable_response = test_client.get('/variable/age') - variable = json.loads(variable_response.data.decode('utf-8')) - assert 'references' not in variable.keys() +def test_null_values_are_dropped(test_client) -> None: + variable_response = test_client.get("/variable/age") + variable = json.loads(variable_response.data.decode("utf-8")) + assert "references" not in variable -def test_variable_with_start_and_stop_date(test_client): - response = test_client.get('/variable/housing_allowance') - variable = json.loads(response.data.decode('utf-8')) - assert_items_equal(variable['formulas'], ['1980-01-01', '2016-12-01']) - assert variable['formulas']['2016-12-01'] is None - assert 'formula' in variable['formulas']['1980-01-01']['content'] +def test_variable_with_start_and_stop_date(test_client) -> None: + response = test_client.get("/variable/housing_allowance") + variable = json.loads(response.data.decode("utf-8")) + assert_items_equal(variable["formulas"], ["1980-01-01", "2016-12-01"]) + assert variable["formulas"]["2016-12-01"] is None + assert "formula" in variable["formulas"]["1980-01-01"]["content"] -def test_variable_with_enum(test_client): - response = test_client.get('/variable/housing_occupancy_status') - variable = json.loads(response.data.decode('utf-8')) - assert variable['valueType'] == 'String' - assert variable['defaultValue'] == 'tenant' - assert 'possibleValues' in variable.keys() - assert variable['possibleValues'] == { - 'free_lodger': 'Free lodger', - 'homeless': 'Homeless', - 'owner': 'Owner', - 'tenant': 'Tenant'} +def test_variable_with_enum(test_client) -> None: + response = test_client.get("/variable/housing_occupancy_status") + variable = json.loads(response.data.decode("utf-8")) + assert variable["valueType"] == "String" + assert variable["defaultValue"] == "tenant" + assert "possibleValues" in variable + assert variable["possibleValues"] == { + "free_lodger": "Free lodger", + "homeless": "Homeless", + "owner": "Owner", + "tenant": "Tenant", + } @pytest.fixture(scope="module") def dated_variable_response(test_client): - dated_variable_response = test_client.get('/variable/basic_income') - return dated_variable_response + return test_client.get("/variable/basic_income") -def test_return_code_existing_dated_variable(dated_variable_response): +def test_return_code_existing_dated_variable(dated_variable_response) -> None: assert dated_variable_response.status_code == client.OK -def test_dated_variable_formulas_dates(dated_variable_response): - dated_variable = json.loads(dated_variable_response.data.decode('utf-8')) - assert_items_equal(dated_variable['formulas'], ['2016-12-01', '2015-12-01']) +def test_dated_variable_formulas_dates(dated_variable_response) -> None: + dated_variable = json.loads(dated_variable_response.data.decode("utf-8")) + assert_items_equal(dated_variable["formulas"], ["2016-12-01", "2015-12-01"]) -def test_dated_variable_formulas_content(dated_variable_response): - dated_variable = json.loads(dated_variable_response.data.decode('utf-8')) - formula_code_2016 = dated_variable['formulas']['2016-12-01']['content'] - formula_code_2015 = dated_variable['formulas']['2015-12-01']['content'] +def test_dated_variable_formulas_content(dated_variable_response) -> None: + dated_variable = json.loads(dated_variable_response.data.decode("utf-8")) + formula_code_2016 = dated_variable["formulas"]["2016-12-01"]["content"] + formula_code_2015 = dated_variable["formulas"]["2015-12-01"]["content"] assert "def formula_2016_12(person, period, parameters):" in formula_code_2016 assert "return" in formula_code_2016 @@ -164,16 +179,22 @@ def test_dated_variable_formulas_content(dated_variable_response): assert "return" in formula_code_2015 -def test_variable_encoding(test_client): - variable_response = test_client.get('/variable/pension') +def test_variable_encoding(test_client) -> None: + variable_response = test_client.get("/variable/pension") assert variable_response.status_code == client.OK -def test_variable_documentation(test_client): - response = test_client.get('/variable/housing_allowance') - variable = json.loads(response.data.decode('utf-8')) - assert variable['documentation'] == "This allowance was introduced on the 1st of Jan 1980.\nIt disappeared in Dec 2016." +def test_variable_documentation(test_client) -> None: + response = test_client.get("/variable/housing_allowance") + variable = json.loads(response.data.decode("utf-8")) + assert ( + variable["documentation"] + == "This allowance was introduced on the 1st of Jan 1980.\nIt disappeared in Dec 2016." + ) - formula_documentation = variable['formulas']['1980-01-01']['documentation'] + formula_documentation = variable["formulas"]["1980-01-01"]["documentation"] assert "Housing allowance." in formula_documentation - assert "Calculating it before this date will always return the variable default value, 0." in formula_documentation + assert ( + "before this date will always return the variable default value, 0." + in formula_documentation + )