Skip to content

Commit 961a7b2

Browse files
refactor: improve type-strictness
Done by enabling mypy-strict mode
1 parent 61a3e36 commit 961a7b2

File tree

17 files changed

+85
-65
lines changed

17 files changed

+85
-65
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ repos:
102102
- types-requests
103103
- types-ujson
104104
- types-toml
105+
- types-click
106+
- types-python-jose
107+
- pymongo
108+
- pydantic
109+
- fastapi
105110

106111
# The path to the venv python interpreter differ between linux and windows. An if/else is used to find it on either.
107112
- repo: local

api/pyproject.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,15 @@ requires = ["poetry-core>=1.0.0"]
6565
build-backend = "poetry.core.masonry.api"
6666

6767
[tool.mypy]
68+
69+
plugins = ["pydantic.mypy"]
70+
6871
ignore_missing_imports = true
69-
warn_return_any = true
70-
warn_unused_configs = true
7172
namespace_packages = true
7273
explicit_package_bases = true
74+
allow_subclassing_any = true
75+
76+
strict = true
7377

7478

7579
[tool.ruff]

api/src/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ def create_app() -> FastAPI:
6363

6464

6565
@click.group()
66-
def cli():
66+
def cli() -> None:
6767
pass
6868

6969

7070
@cli.command()
71-
def run():
71+
def run() -> None:
7272
import uvicorn
7373

7474
uvicorn.run(

api/src/authentication/authentication.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
@cached(cache=TTLCache(maxsize=32, ttl=86400))
21-
def fetch_openid_configuration() -> dict:
21+
def fetch_openid_configuration() -> dict[str, str]:
2222
try:
2323
oid_conf_response = httpx.get(config.OAUTH_WELL_KNOWN)
2424
oid_conf_response.raise_for_status()

api/src/authentication/mock_token_generator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
"""
5151

5252

53-
def generate_mock_token(user: User = default_user):
53+
def generate_mock_token(user: User = default_user) -> str:
5454
"""
5555
This function is for testing purposes only
5656
Used for behave testing
@@ -64,5 +64,4 @@ def generate_mock_token(user: User = default_user):
6464
"roles": user.roles,
6565
"iss": "mock-auth-server",
6666
}
67-
token = jwt.encode(payload, mock_rsa_private_key, algorithm="RS256")
68-
return token
67+
return jwt.encode(payload, mock_rsa_private_key, algorithm="RS256")

api/src/authentication/models.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from enum import IntEnum
2+
from typing import Any
23

34
from pydantic import BaseModel, GetJsonSchemaHandler
45
from pydantic_core import core_schema
@@ -15,11 +16,11 @@ def check_privilege(self, required_level: "AccessLevel") -> bool:
1516
return False
1617

1718
@classmethod
18-
def __get_validators__(cls):
19+
def __get_validators__(cls): # type:ignore
1920
yield cls.validate
2021

2122
@classmethod
22-
def validate(cls, v):
23+
def validate(cls, v: str) -> "AccessLevel":
2324
if isinstance(v, cls):
2425
return v
2526
try:
@@ -28,7 +29,9 @@ def validate(cls, v):
2829
raise ValueError("invalid AccessLevel enum value ")
2930

3031
@classmethod
31-
def __get_pydantic_json_schema__(cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler):
32+
def __get_pydantic_json_schema__(
33+
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
34+
) -> dict[str, Any]:
3235
"""
3336
Add a custom field type to the class representing the Enum's field names
3437
Ref: https://pydantic-docs.helpmanual.io/usage/schema/#modifying-schema-in-custom-fields
@@ -50,7 +53,7 @@ class User(BaseModel):
5053
roles: list[str] = []
5154
scope: AccessLevel = AccessLevel.WRITE
5255

53-
def __hash__(self):
56+
def __hash__(self) -> int:
5457
return hash(type(self.user_id))
5558

5659

@@ -70,7 +73,7 @@ class ACL(BaseModel):
7073
users: dict[str, AccessLevel] = {}
7174
others: AccessLevel = AccessLevel.READ
7275

73-
def dict(self, **kwargs):
76+
def dict(self, **kwargs: Any) -> dict[str, str | dict[str, AccessLevel | str]]:
7477
return {
7578
"owner": self.owner,
7679
"roles": {k: v.name for k, v in self.roles.items()},

api/src/common/exception_handlers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def validation_exception_handler(request: Request, exc: RequestValidationError)
8888
)
8989

9090

91-
def http_exception_handler(request: Request, exc: HTTPStatusError):
91+
def http_exception_handler(request: Request, exc: HTTPStatusError) -> JSONResponse:
9292
logger.error(exc)
9393
return JSONResponse(
9494
ErrorResponse(

api/src/common/exceptions.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class ErrorResponse(BaseModel):
1717
type: str = "ApplicationException"
1818
message: str = "The requested operation failed"
1919
debug: str = "An unknown and unhandled exception occurred in the API"
20-
extra: dict | None = None
20+
extra: dict[str, str] | None = None
2121

2222

2323
class ApplicationException(Exception):
@@ -26,13 +26,13 @@ class ApplicationException(Exception):
2626
type: str = "ApplicationException"
2727
message: str = "The requested operation failed"
2828
debug: str = "An unknown and unhandled exception occurred in the API"
29-
extra: dict | None = None
29+
extra: dict[str, str] | None = None
3030

3131
def __init__(
3232
self,
3333
message: str = "The requested operation failed",
3434
debug: str = "An unknown and unhandled exception occurred in the API",
35-
extra: dict | None = None,
35+
extra: dict[str, str] | None = None,
3636
status: int = 500,
3737
severity: ExceptionSeverity = ExceptionSeverity.ERROR,
3838
):
@@ -43,7 +43,7 @@ def __init__(
4343
self.extra = extra
4444
self.severity = severity
4545

46-
def dict(self):
46+
def dict(self) -> dict[str, int | str | dict[str, str] | None]:
4747
return {
4848
"status": self.status,
4949
"type": self.type,
@@ -58,7 +58,7 @@ def __init__(
5858
self,
5959
message: str = "You do not have the required permissions",
6060
debug: str = "Action denied because of insufficient permissions",
61-
extra: dict | None = None,
61+
extra: dict[str, str] | None = None,
6262
):
6363
super().__init__(message, debug, extra, request_status.HTTP_403_FORBIDDEN, severity=ExceptionSeverity.WARNING)
6464
self.type = self.__class__.__name__
@@ -69,7 +69,7 @@ def __init__(
6969
self,
7070
message: str = "The requested resource could not be found",
7171
debug: str = "The requested resource could not be found",
72-
extra: dict | None = None,
72+
extra: dict[str, str] | None = None,
7373
):
7474
super().__init__(message, debug, extra, request_status.HTTP_404_NOT_FOUND)
7575
self.type = self.__class__.__name__
@@ -80,7 +80,7 @@ def __init__(
8080
self,
8181
message: str = "Invalid data for the operation",
8282
debug: str = "Unable to complete the requested operation with the given input values.",
83-
extra: dict | None = None,
83+
extra: dict[str, str] | None = None,
8484
):
8585
super().__init__(message, debug, extra, request_status.HTTP_400_BAD_REQUEST)
8686
self.type = self.__class__.__name__
@@ -91,7 +91,7 @@ def __init__(
9191
self,
9292
message: str = "The received data is invalid",
9393
debug: str = "Values are invalid for requested operation.",
94-
extra: dict | None = None,
94+
extra: dict[str, str] | None = None,
9595
):
9696
super().__init__(message, debug, extra, request_status.HTTP_422_UNPROCESSABLE_ENTITY)
9797
self.type = self.__class__.__name__

api/src/common/middleware.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from opencensus.trace.samplers import ProbabilitySampler
77
from opencensus.trace.tracer import Tracer
88
from starlette.datastructures import MutableHeaders
9+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
910

1011
from common.logger import logger
1112
from config import config
@@ -15,20 +16,20 @@
1516
# Middleware inheriting from the "BaseHTTPMiddleware" class does not work with Starlettes BackgroundTasks.
1617
# see: https://github.com/encode/starlette/issues/919
1718
class LocalLoggerMiddleware:
18-
def __init__(self, app):
19+
def __init__(self, app: ASGIApp):
1920
self.app = app
2021

21-
async def __call__(self, scope, receive, send):
22+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
2223
if scope["type"] != "http":
2324
return await self.app(scope, receive, send)
2425

2526
start_time = time.time()
2627
process_time = ""
2728
path = scope["path"]
2829
method = scope["method"]
29-
response = {}
30+
response: Message = {}
3031

31-
async def inner_send(message):
32+
async def inner_send(message: Message) -> None:
3233
nonlocal process_time
3334
nonlocal response
3435
if message["type"] == "http.response.start":
@@ -49,19 +50,19 @@ class OpenCensusRequestLoggingMiddleware:
4950
exporter = AzureExporter(connection_string=config.APPINSIGHTS_CONSTRING) if config.APPINSIGHTS_CONSTRING else None
5051
sampler = ProbabilitySampler(1.0)
5152

52-
def __init__(self, app):
53+
def __init__(self, app: ASGIApp):
5354
self.app = app
5455

55-
async def __call__(self, scope, receive, send):
56+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
5657
if scope["type"] != "http":
5758
return await self.app(scope, receive, send)
5859

5960
tracer = Tracer(exporter=self.exporter, sampler=self.sampler)
6061

6162
path = scope["path"]
62-
response = {}
63+
response: Message = {}
6364

64-
async def inner_send(message):
65+
async def inner_send(message: Message) -> None:
6566
nonlocal response
6667
if message["type"] == "http.response.start":
6768
response = message

api/src/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class Config(BaseSettings):
3737
MICROSOFT_AUTH_PROVIDER: str = "login.microsoftonline.com"
3838

3939

40-
config = Config() # type: ignore[call-arg]
40+
config = Config()
4141

4242
if config.AUTH_ENABLED and not all((config.OAUTH_AUTH_ENDPOINT, config.OAUTH_TOKEN_ENDPOINT, config.OAUTH_WELL_KNOWN)):
4343
raise ValueError("Authentication was enabled, but some auth configuration parameters are missing")

0 commit comments

Comments
 (0)