Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions packages/aws-library/src/aws_library/s3/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,30 @@ async def create(
cls, settings: S3Settings, s3_max_concurrency: int = _S3_MAX_CONCURRENCY_DEFAULT
) -> "SimcoreS3API":
session = aioboto3.Session()
session_client = session.client( # type: ignore[call-overload]
"s3",
endpoint_url=f"{settings.S3_ENDPOINT}",
aws_access_key_id=settings.S3_ACCESS_KEY,
aws_secret_access_key=settings.S3_SECRET_KEY,
region_name=settings.S3_REGION,
config=Config(signature_version="s3v4"),
)
assert isinstance(session_client, ClientCreatorContext) # nosec
session_client = None
exit_stack = contextlib.AsyncExitStack()
s3_client = cast(S3Client, await exit_stack.enter_async_context(session_client))
# NOTE: this triggers a botocore.exception.ClientError in case the connection is not made to the S3 backend
await s3_client.list_buckets()
try:
session_client = session.client( # type: ignore[call-overload]
"s3",
endpoint_url=f"{settings.S3_ENDPOINT}",
aws_access_key_id=settings.S3_ACCESS_KEY,
aws_secret_access_key=settings.S3_SECRET_KEY,
region_name=settings.S3_REGION,
config=Config(signature_version="s3v4"),
)
assert isinstance(session_client, ClientCreatorContext) # nosec

return cls(s3_client, session, exit_stack, s3_max_concurrency)
s3_client = cast(
S3Client, await exit_stack.enter_async_context(session_client)
)
# NOTE: this triggers a botocore.exception.ClientError in case the connection is not made to the S3 backend
await s3_client.list_buckets()

return cls(s3_client, session, exit_stack, s3_max_concurrency)
except Exception:
await exit_stack.aclose()

raise

async def close(self) -> None:
await self._exit_stack.aclose()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import asyncio
import datetime
import logging
import threading
from typing import Final

from asgi_lifespan import LifespanManager
from celery import Celery # type: ignore[import-untyped]
from fastapi import FastAPI
from servicelib.async_utils import cancel_wait_task
from servicelib.logging_utils import log_context

from ...core.application import create_app
from ...core.settings import ApplicationSettings
from ...modules.celery import get_event_loop, set_event_loop
from ...modules.celery import set_event_loop
from ...modules.celery.utils import (
get_fastapi_app,
set_celery_worker,
Expand All @@ -20,52 +21,57 @@

_logger = logging.getLogger(__name__)

_LIFESPAN_TIMEOUT: Final[int] = 10
_SHUTDOWN_TIMEOUT: Final[float] = datetime.timedelta(seconds=10).total_seconds()
_STARTUP_TIMEOUT: Final[float] = datetime.timedelta(minutes=1).total_seconds()


def on_worker_init(sender, **_kwargs) -> None:
def _init_fastapi() -> None:
startup_complete_event = threading.Event()

def _init_fastapi(startup_complete_event: threading.Event) -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
shutdown_event = asyncio.Event()

fastapi_app = create_app(ApplicationSettings.create_from_envs())

async def lifespan():
async def lifespan(
startup_complete_event: threading.Event, shutdown_event: asyncio.Event
) -> None:
async with LifespanManager(
fastapi_app,
startup_timeout=_LIFESPAN_TIMEOUT,
shutdown_timeout=_LIFESPAN_TIMEOUT,
startup_timeout=_STARTUP_TIMEOUT,
shutdown_timeout=_SHUTDOWN_TIMEOUT,
):
try:
_logger.info("fastapi APP started!")
startup_complete_event.set()
await shutdown_event.wait()
except asyncio.CancelledError:
_logger.warning("Lifespan task cancelled")

lifespan_task = loop.create_task(lifespan())
fastapi_app.state.lifespan_task = lifespan_task
fastapi_app.state.shutdown_event = shutdown_event
set_event_loop(fastapi_app, loop)

set_fastapi_app(sender.app, fastapi_app)
set_celery_worker(sender.app, CeleryTaskQueueWorker(sender.app))

loop.run_forever()

thread = threading.Thread(target=_init_fastapi, daemon=True)
loop.run_until_complete(lifespan(startup_complete_event, shutdown_event))

thread = threading.Thread(
group=None,
target=_init_fastapi,
name="fastapi_app",
args=(startup_complete_event,),
daemon=True,
)
thread.start()
# ensure the fastapi app is ready before going on
startup_complete_event.wait(_STARTUP_TIMEOUT * 1.1)


def on_worker_shutdown(sender, **_kwargs):
assert isinstance(sender.app, Celery)

fastapi_app = get_fastapi_app(sender.app)
assert isinstance(fastapi_app, FastAPI)
event_loop = get_event_loop(fastapi_app)

async def shutdown():
def on_worker_shutdown(sender, **_kwargs) -> None:
with log_context(_logger, logging.INFO, "Worker Shuts-down"):
assert isinstance(sender.app, Celery)
fastapi_app = get_fastapi_app(sender.app)
assert isinstance(fastapi_app, FastAPI)
fastapi_app.state.shutdown_event.set()

await cancel_wait_task(fastapi_app.state.lifespan_task, max_delay=5)

asyncio.run_coroutine_threadsafe(shutdown(), event_loop)
Loading