diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fa838a8..832c092 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9, 3.8, 3.7, 3.6] + python-version: [3.9, 3.8, 3.7] os: [ubuntu-latest, macos-latest, windows-latest] steps: - uses: actions/checkout@v2 diff --git a/fastapi_security/api.py b/fastapi_security/api.py index 0fa9d52..97f250a 100644 --- a/fastapi_security/api.py +++ b/fastapi_security/api.py @@ -1,11 +1,15 @@ import logging -from typing import Callable, Dict, Iterable, List, Optional, Type +from typing import Callable, Dict, Iterable, List, Optional, Type, Union from fastapi import Depends, HTTPException from fastapi.security.http import HTTPAuthorizationCredentials from starlette.datastructures import Headers -from .basic import BasicAuthValidator, IterableOfHTTPBasicCredentials +from .basic import ( + BasicAuthValidator, + BasicAuthWithDigestValidator, + IterableOfHTTPBasicCredentials, +) from .entities import AuthMethod, User, UserAuth, UserInfo from .exceptions import AuthNotConfigured from .oauth2 import Oauth2JwtAccessTokenValidator @@ -25,6 +29,7 @@ class FastAPISecurity: """ def __init__(self, *, user_permission_class: Type[UserPermission] = UserPermission): + self.basic_auth: Union[BasicAuthValidator, BasicAuthWithDigestValidator] self.basic_auth = BasicAuthValidator() self.oauth2_jwt = Oauth2JwtAccessTokenValidator() self.oidc_discovery = OpenIdConnectDiscovery() @@ -35,7 +40,19 @@ def __init__(self, *, user_permission_class: Type[UserPermission] = UserPermissi self._oauth2_audiences: List[str] = [] def init_basic_auth(self, basic_auth_credentials: IterableOfHTTPBasicCredentials): - self.basic_auth.init(basic_auth_credentials) + new_basic_auth = BasicAuthValidator() + new_basic_auth.init(basic_auth_credentials) + self.basic_auth = new_basic_auth + + def init_basic_auth_with_digest( + self, + basic_auth_with_digest_credentials: IterableOfHTTPBasicCredentials, + *, + salt: str, + ): + new_basic_auth = BasicAuthWithDigestValidator() + new_basic_auth.init(basic_auth_with_digest_credentials, salt=salt) + self.basic_auth = new_basic_auth def init_oauth2_through_oidc( self, oidc_discovery_url: str, *, audiences: Iterable[str] = None @@ -185,7 +202,7 @@ async def dependency( return self._maybe_override_permissions( UserAuth.from_jwt_access_token(access_token) ) - elif http_credentials is not None and self.basic_auth.is_configured(): + elif http_credentials is not None: if self.basic_auth.validate(http_credentials): return self._maybe_override_permissions( UserAuth( diff --git a/fastapi_security/basic.py b/fastapi_security/basic.py index b4e8f32..690801b 100644 --- a/fastapi_security/basic.py +++ b/fastapi_security/basic.py @@ -1,9 +1,13 @@ import secrets +from base64 import urlsafe_b64encode from typing import Dict, Iterable, List, Union +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.hashes import SHA512, Hash from fastapi.security.http import HTTPBasicCredentials -__all__ = ("HTTPBasicCredentials",) +__all__ = ("HTTPBasicCredentials", "generate_digest") + IterableOfHTTPBasicCredentials = Iterable[Union[HTTPBasicCredentials, Dict]] @@ -13,7 +17,7 @@ def __init__(self): self._credentials = [] def init(self, credentials: IterableOfHTTPBasicCredentials): - self._credentials = self._make_credentials(credentials) + self._credentials = _make_credentials(credentials) def is_configured(self) -> bool: return len(self._credentials) > 0 @@ -29,10 +33,49 @@ def validate(self, credentials: HTTPBasicCredentials) -> bool: for c in self._credentials ) - def _make_credentials( - self, credentials: IterableOfHTTPBasicCredentials - ) -> List[HTTPBasicCredentials]: - return [ - c if isinstance(c, HTTPBasicCredentials) else HTTPBasicCredentials(**c) - for c in credentials - ] + +class BasicAuthWithDigestValidator: + def __init__(self): + self._credentials = [] + self._salt = None + + def init(self, credentials: IterableOfHTTPBasicCredentials, *, salt: str): + self._credentials = _make_credentials(credentials) + self._salt = salt + + def is_configured(self) -> bool: + return self._salt and len(self._credentials) > 0 + + def validate(self, credentials: HTTPBasicCredentials) -> bool: + if not self.is_configured(): + return False + return any( + ( + secrets.compare_digest(c.username, credentials.username) + and c.password == self.generate_digest(credentials.password) + ) + for c in self._credentials + ) + + def generate_digest(self, secret: str): + if not self._salt: + raise ValueError( + "BasicAuthWithDigestValidator: cannot generate digest, salt is empty" + ) + return generate_digest(secret, salt=self._salt) + + +def _make_credentials( + credentials: IterableOfHTTPBasicCredentials, +) -> List[HTTPBasicCredentials]: + return [ + c if isinstance(c, HTTPBasicCredentials) else HTTPBasicCredentials(**c) + for c in credentials + ] + + +def generate_digest(secret: str, *, salt: str): + hash_obj = Hash(algorithm=SHA512(), backend=default_backend()) + hash_obj.update((salt + secret).encode("latin1")) + result = hash_obj.finalize() + return urlsafe_b64encode(result).decode("latin1") diff --git a/fastapi_security/cli.py b/fastapi_security/cli.py new file mode 100644 index 0000000..8aa4fba --- /dev/null +++ b/fastapi_security/cli.py @@ -0,0 +1,93 @@ +"""fastapi_security command-line interface""" + +import argparse +import sys +import textwrap +from getpass import getpass +from typing import Optional, Sequence, Text + +from fastapi_security.basic import generate_digest + + +def _wrap_paragraphs(s): + paragraphs = s.strip().split("\n\n") + wrapped_paragraphs = [ + "\n".join(textwrap.wrap(paragraph)) for paragraph in paragraphs + ] + return "\n\n".join(wrapped_paragraphs) + + +main_parser = argparse.ArgumentParser( + description=_wrap_paragraphs(__doc__), + formatter_class=argparse.RawDescriptionHelpFormatter, +) +subcommand_parsers = main_parser.add_subparsers( + help="Specify a sub-command", + dest="subcommand", + required=True, +) + +gendigest_description = """ +Generate digest for basic_auth_with_digest credentials. + +Example: + +$ fastapi-security gendigest --salt=very-strong-salt +Password: +Confirm password: + +Here is your digest: +0jFS-cNapwQf_lpyULF7_hEelbl_zreNVHbxqKwKIFmPRQ09bYTEDQLrr_UEWZc9fdYFiU5F3il3rovJQ_UEpg== + +$ cat fastapi_security_conf.py +from fastapi_security import FastAPISecurity + +security = FastAPISecurity() +security.init_basic_auth_with_digest( + [ + {'user': 'me', 'password': '0jFS-cNapwQf_lpyULF7_hEelbl_zreNVHbxqKwKIFmPRQ09bYTEDQLrr_UEWZc9fdYFiU5F3il3rovJQ_UEpg=='} + ], + salt='very-strong-salt', +) +""" + +gendigest_parser = subcommand_parsers.add_parser( + "gendigest", + description=gendigest_description, + formatter_class=argparse.RawDescriptionHelpFormatter, +) +gendigest_parser.add_argument( + "--salt", + help="Salt value used in fastapi_security configuration.", + required=True, +) + + +def gendigest(parsed_args): + # if not parsed_args.salt: + # print("Cannot generate digest: --salt must be non-empty", + # file=sys.stderr) + # sys.exit(1) + + password = getpass(prompt="Password: ") + password_confirmation = getpass(prompt="Confirm password: ") + + if password != password_confirmation: + print("Cannot generate digest: passwords don't match", file=sys.stderr) + sys.exit(1) + + print("\nHere is your digest:", file=sys.stderr) + print(generate_digest(password, salt=parsed_args.salt)) + + +def main(args: Optional[Sequence[Text]] = None): + parsed_args = main_parser.parse_args(args) + if parsed_args.subcommand == "gendigest": + return gendigest(parsed_args) + + main_parser.print_usage(file=sys.stderr) + sys.exit(2) # invalid usage: missing subcommand + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 010fae7..4ec77ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,12 +32,16 @@ classifiers = [ "Typing :: Typed", ] +[tool.poetry.scripts] +fastapi-security = 'fastapi_security.cli:main' + [tool.poetry.dependencies] python = "^3.6" aiohttp = "^3" fastapi = "^0" pydantic = "^1" PyJWT = {version = "^2", extras = ["crypto"]} +cryptography = "^3.4.7" [tool.poetry.dev-dependencies] aioresponses = "^0.7.2" diff --git a/tests/integration/test_basic_auth.py b/tests/integration/test_basic_auth.py index 1b69c7a..6013af3 100644 --- a/tests/integration/test_basic_auth.py +++ b/tests/integration/test_basic_auth.py @@ -1,7 +1,7 @@ from fastapi import Depends from fastapi_security import FastAPISecurity, HTTPBasicCredentials, User -from fastapi_security.basic import BasicAuthValidator +from fastapi_security.basic import BasicAuthValidator, generate_digest from ..helpers.jwks import dummy_audience, dummy_jwks_uri @@ -65,3 +65,61 @@ def get_products(user: User = Depends(security.authenticated_user_or_401)): resp = client.get("/", auth=("user", "pass")) assert resp.status_code == 200 + + +def test_that_basic_auth_with_digest_rejects_credentials_with_wrong_user_or_password( + app, client +): + security = FastAPISecurity() + + @app.get("/") + def get_products(user: User = Depends(security.authenticated_user_or_401)): + return [] + + pass_digest = generate_digest("pass", salt="salt123") + credentials = [{"username": "user", "password": pass_digest}] + security.init_basic_auth_with_digest(credentials, salt="salt123") + + resp = client.get("/") + assert resp.status_code == 401 + + resp = client.get("/", auth=("user", "")) + assert resp.status_code == 401 + + resp = client.get("/", auth=("", "pass")) + assert resp.status_code == 401 + + resp = client.get("/", auth=("abc", "123")) + assert resp.status_code == 401 + + +def test_that_basic_auth_with_digest_rejects_credentials_when_salt_does_not_match( + app, client +): + security = FastAPISecurity() + + @app.get("/") + def get_products(user: User = Depends(security.authenticated_user_or_401)): + return [] + + pass_digest = generate_digest("pass", salt="salt123") + credentials = [{"username": "user", "password": pass_digest}] + security.init_basic_auth_with_digest(credentials, salt="salt456") + + resp = client.get("/", auth=("user", "pass")) + assert resp.status_code == 401 + + +def test_that_basic_auth_with_digest_accepts_correct_credentials(app, client): + security = FastAPISecurity() + + @app.get("/") + def get_products(user: User = Depends(security.authenticated_user_or_401)): + return [] + + pass_digest = generate_digest("pass", salt="salt123") + credentials = [{"username": "user", "password": pass_digest}] + security.init_basic_auth_with_digest(credentials, salt="salt123") + + resp = client.get("/", auth=("user", "pass")) + assert resp.status_code == 200 diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py new file mode 100644 index 0000000..d944122 --- /dev/null +++ b/tests/integration/test_cli.py @@ -0,0 +1,55 @@ +import subprocess +from unittest import mock + +from fastapi_security import cli + + +def test_usage_output_without_params(): + result = subprocess.run(["fastapi-security"], capture_output=True) + assert result.returncode == 2 + assert result.stdout.decode().splitlines() == [] + assert result.stderr.decode().splitlines() == [ + "usage: fastapi-security [-h] {gendigest} ...", + "fastapi-security: error: the following arguments are required: subcommand", + ] + + +def test_usage_with_help_param(): + result = subprocess.run(["fastapi-security", "-h"], capture_output=True) + assert result.returncode == 0 + assert result.stdout.decode().splitlines() == [ + "usage: fastapi-security [-h] {gendigest} ...", + "", + "fastapi_security command-line interface", + "", + "positional arguments:", + " {gendigest} Specify a sub-command", + "", + "optional arguments:", + " -h, --help show this help message and exit", + ] + assert result.stderr.decode().splitlines() == [] + + +def test_gendigest_without_params(): + result = subprocess.run(["fastapi-security", "gendigest"], capture_output=True) + assert result.returncode == 2 + assert result.stdout.decode().splitlines() == [] + assert result.stderr.decode().splitlines() == [ + "usage: fastapi-security gendigest [-h] --salt SALT", + "fastapi-security gendigest: error: the following arguments are required: --salt", + ] + + +def test_gendigest_smoke_test(capsys, monkeypatch): + # gendigest smoke test is performed not in a subprocess, because getpass + # uses /dev/tty instead of stdin/stdout for security reasons, and it is + # much more tricky to intercept it, so instead go for a simple monkeypatch. + monkeypatch.setattr(cli, "getpass", mock.Mock(return_value="hello")) + cli.main(["gendigest", "--salt=very-strong-salt"]) + captured = capsys.readouterr() + assert ( + captured.out + == "xRPfDaQHwpcXlzfWeR_uqOBTytcjEAUMv98SDnbHmpajmT_AxeJTHX6FyeM8H1T4otOe81PMWAOqAD5_tO4gYg==\n" + ) + assert captured.err == "\nHere is your digest:\n"