Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,22 @@ def get_async_engine(app: web.Application) -> AsyncEngine:
return engine


async def connect_to_db(app: web.Application, settings: PostgresSettings) -> None:
async def connect_to_db(
app: web.Application, settings: PostgresSettings, application_name: str
) -> None:
"""
- db services up, data migrated and ready to use
- sets an engine in app state (use `get_async_engine(app)` to retrieve)
"""
if settings.POSTGRES_CLIENT_NAME:
settings = settings.model_copy(
update={"POSTGRES_CLIENT_NAME": settings.POSTGRES_CLIENT_NAME + "-asyncpg"}
)

with log_context(
_logger,
logging.INFO,
"Connecting app[APP_DB_ASYNC_ENGINE_KEY] to postgres with %s",
f"{settings=}",
):
engine = await create_async_engine_and_database_ready(settings)
engine = await create_async_engine_and_database_ready(
settings, application_name
)
_set_async_engine_to_app_state(app, engine)

_logger.info(
Expand Down
26 changes: 12 additions & 14 deletions packages/service-library/src/servicelib/db_asyncpg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

@retry(**PostgresRetryPolicyUponInitialization(_logger).kwargs)
async def create_async_engine_and_database_ready(
settings: PostgresSettings,
settings: PostgresSettings, application_name: str
) -> AsyncEngine:
"""
- creates asyncio engine
Expand All @@ -31,15 +31,11 @@ async def create_async_engine_and_database_ready(
)

server_settings = {
"jit": "off"
} # see https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#disabling-the-postgresql-jit-to-improve-enum-datatype-handling
if settings.POSTGRES_CLIENT_NAME:
assert isinstance(settings.POSTGRES_CLIENT_NAME, str) # nosec
server_settings.update(
{
"application_name": settings.POSTGRES_CLIENT_NAME,
}
)
"jit": "off",
"application_name": settings.client_name(
f"{application_name}", suffix="asyncpg"
),
}

engine = create_async_engine(
settings.dsn_with_async_sqlalchemy,
Expand Down Expand Up @@ -75,7 +71,7 @@ async def check_postgres_liveness(engine: AsyncEngine) -> LivenessResult:

@contextlib.asynccontextmanager
async def with_async_pg_engine(
settings: PostgresSettings,
settings: PostgresSettings, *, application_name: str
) -> AsyncIterator[AsyncEngine]:
"""
Creates an asyncpg engine and ensures it is properly closed after use.
Expand All @@ -86,9 +82,11 @@ async def with_async_pg_engine(
logging.DEBUG,
f"connection to db {settings.dsn_with_async_sqlalchemy}",
):
server_settings = None
if settings.POSTGRES_CLIENT_NAME:
assert isinstance(settings.POSTGRES_CLIENT_NAME, str)
server_settings = {
"application_name": settings.client_name(
application_name, suffix="asyncpg"
),
}

engine = create_async_engine(
settings.dsn_with_async_sqlalchemy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
_logger = logging.getLogger(__name__)


async def connect_to_db(app: FastAPI, settings: PostgresSettings) -> None:
async def connect_to_db(
app: FastAPI, settings: PostgresSettings, application_name: str
) -> None:
warnings.warn(
"The 'connect_to_db' function is deprecated and will be removed in a future release. "
"Please use 'postgres_lifespan' instead for managing the database connection lifecycle.",
Expand All @@ -27,7 +29,9 @@ async def connect_to_db(app: FastAPI, settings: PostgresSettings) -> None:
logging.DEBUG,
f"Connecting and migraging {settings.dsn_with_async_sqlalchemy}",
):
engine = await create_async_engine_and_database_ready(settings)
engine = await create_async_engine_and_database_ready(
settings, application_name
)

app.state.engine = engine
_logger.debug(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def create_postgres_database_input_state(settings: PostgresSettings) -> State:
return {PostgresLifespanState.POSTGRES_SETTINGS: settings}


async def postgres_database_lifespan(_: FastAPI, state: State) -> AsyncIterator[State]:
async def postgres_database_lifespan(
app: FastAPI, state: State
) -> AsyncIterator[State]:

_lifespan_name = f"{__name__}.{postgres_database_lifespan.__name__}"

Expand All @@ -43,7 +45,7 @@ async def postgres_database_lifespan(_: FastAPI, state: State) -> AsyncIterator[

# connect to database
async_engine: AsyncEngine = await create_async_engine_and_database_ready(
settings
settings, app.title
)

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ async def test_lifespan_postgres_database_in_an_app(
mock_create_async_engine_and_database_ready: MockType,
app_lifespan: LifespanManager,
):

app = FastAPI(lifespan=app_lifespan)

async with ASGILifespanManager(
Expand All @@ -93,7 +92,7 @@ async def test_lifespan_postgres_database_in_an_app(
) as asgi_manager:
# Verify that the async engine was created
mock_create_async_engine_and_database_ready.assert_called_once_with(
app.state.settings.CATALOG_POSTGRES
app.state.settings.CATALOG_POSTGRES, app.title
)

# Verify that the async engine is in the lifespan manager state
Expand Down
18 changes: 9 additions & 9 deletions packages/settings-library/src/settings_library/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,19 @@ def dsn_with_async_sqlalchemy(self) -> str:
)
return f"{url}"

@cached_property
def dsn_with_query(self) -> str:
def dsn_with_query(self, application_name: str, *, suffix: str | None) -> str:
"""Some clients do not support queries in the dsn"""
dsn = self.dsn
return self._update_query(dsn)
return self._update_query(dsn, application_name, suffix=suffix)

def client_name(self, application_name: str, *, suffix: str | None) -> str:
return f"{application_name}{'-' if self.POSTGRES_CLIENT_NAME else ''}{self.POSTGRES_CLIENT_NAME or ''}{'-' + suffix if suffix else ''}"

def _update_query(self, uri: str) -> str:
def _update_query(self, uri: str, application_name: str, suffix: str | None) -> str:
# SEE https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
new_params: dict[str, str] = {}
if self.POSTGRES_CLIENT_NAME:
new_params = {
"application_name": self.POSTGRES_CLIENT_NAME,
}
new_params: dict[str, str] = {
"application_name": self.client_name(application_name, suffix=suffix),
}

if new_params:
parsed_uri = urlparse(uri)
Expand Down
15 changes: 10 additions & 5 deletions packages/settings-library/tests/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from urllib.parse import urlparse

import pytest
from faker import Faker
from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict
from pytest_simcore.helpers.typing_env import EnvVarsDict
from settings_library.postgres import PostgresSettings
Expand All @@ -24,7 +25,6 @@ def mock_environment(mock_environment: EnvVarsDict, monkeypatch: pytest.MonkeyPa


def test_cached_property_dsn(mock_environment: EnvVarsDict):

settings = PostgresSettings.create_from_envs()

# all are upper-case
Expand All @@ -36,22 +36,27 @@ def test_cached_property_dsn(mock_environment: EnvVarsDict):
assert "dsn" not in settings.model_dump()


def test_dsn_with_query(mock_environment: EnvVarsDict, monkeypatch: pytest.MonkeyPatch):
def test_dsn_with_query(
mock_environment: EnvVarsDict, monkeypatch: pytest.MonkeyPatch, faker: Faker
):
settings = PostgresSettings()

assert settings.POSTGRES_CLIENT_NAME
assert settings.dsn == "postgresql://foo:secret@localhost:5432/foodb"
app_name = faker.pystr()
assert (
settings.dsn_with_query
== "postgresql://foo:secret@localhost:5432/foodb?application_name=Some+%2643+funky+name"
settings.dsn_with_query(app_name, suffix="my-suffix")
== f"postgresql://foo:secret@localhost:5432/foodb?application_name={app_name}-Some+%2643+funky+name-my-suffix"
)

with monkeypatch.context() as patch:
patch.delenv("POSTGRES_CLIENT_NAME")
settings = PostgresSettings()

assert not settings.POSTGRES_CLIENT_NAME
assert settings.dsn == settings.dsn_with_query
assert f"{settings.dsn}?application_name=blah" == settings.dsn_with_query(
"blah", suffix=None
)


def test_dsn_with_async_sqlalchemy_has_query(
Expand Down
11 changes: 8 additions & 3 deletions packages/simcore-sdk/src/simcore_sdk/node_data/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ async def _state_metadata_entry_exists(


async def _delete_legacy_archive(
project_id: ProjectID, node_uuid: NodeID, path: Path
project_id: ProjectID, node_uuid: NodeID, path: Path, *, application_name: str
) -> None:
"""removes the .zip state archive from storage"""
s3_object = __create_s3_object_key(
Expand All @@ -180,13 +180,15 @@ async def _delete_legacy_archive(
# NOTE: if service is opened by a person which the users shared it with,
# they will not have the permission to delete the node
# Removing it via it's owner allows to always have access to the delete operation.
owner_id = await DBManager().get_project_owner_user_id(project_id)
owner_id = await DBManager(
application_name=application_name
).get_project_owner_user_id(project_id)
await filemanager.delete_file(
user_id=owner_id, store_id=SIMCORE_LOCATION, s3_object=s3_object
)


async def push(
async def push( # pylint: disable=too-many-arguments
user_id: UserID,
project_id: ProjectID,
node_uuid: NodeID,
Expand All @@ -198,6 +200,7 @@ async def push(
progress_bar: ProgressBarData,
aws_s3_cli_settings: AwsS3CliSettings | None,
legacy_state: LegacyState | None,
application_name: str,
) -> None:
"""pushes and removes the legacy archive if present"""

Expand Down Expand Up @@ -226,6 +229,7 @@ async def push(
project_id=project_id,
node_uuid=node_uuid,
path=source_path,
application_name=application_name,
)

if legacy_state:
Expand All @@ -244,6 +248,7 @@ async def push(
project_id=project_id,
node_uuid=node_uuid,
path=legacy_state.old_state_path,
application_name=application_name,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,22 +82,25 @@ async def _update_comp_run_snapshot_tasks_if_computational(


class DBContextManager:
def __init__(self, db_engine: AsyncEngine | None = None) -> None:
def __init__(
self, db_engine: AsyncEngine | None = None, *, application_name: str
) -> None:
self._db_engine: AsyncEngine | None = db_engine
self._db_engine_created: bool = False
self._application_name: str = application_name

@staticmethod
async def _create_db_engine() -> AsyncEngine:
async def _create_db_engine(application_name: str) -> AsyncEngine:
settings = NodePortsSettings.create_from_envs()
engine = await create_async_engine_and_database_ready(
settings.POSTGRES_SETTINGS
settings.POSTGRES_SETTINGS, f"{application_name}-simcore-sdk"
)
assert isinstance(engine, AsyncEngine) # nosec
return engine

async def __aenter__(self) -> AsyncEngine:
if not self._db_engine:
self._db_engine = await self._create_db_engine()
self._db_engine = await self._create_db_engine(self._application_name)
self._db_engine_created = True
return self._db_engine

Expand All @@ -107,8 +110,9 @@ async def __aexit__(self, exc_type, exc, tb) -> None:


class DBManager:
def __init__(self, db_engine: AsyncEngine | None = None):
def __init__(self, db_engine: AsyncEngine | None = None, *, application_name: str):
self._db_engine = db_engine
self._application_name = application_name

async def write_ports_configuration(
self,
Expand All @@ -124,7 +128,9 @@ async def write_ports_configuration(

node_configuration = json_loads(json_configuration)
async with (
DBContextManager(self._db_engine) as engine,
DBContextManager(
self._db_engine, application_name=self._application_name
) as engine,
engine.begin() as connection,
):
# 1. Update comp_tasks table
Expand Down Expand Up @@ -154,7 +160,9 @@ async def get_ports_configuration_from_node_uuid(
"Getting ports configuration of node %s from comp_tasks table", node_uuid
)
async with (
DBContextManager(self._db_engine) as engine,
DBContextManager(
self._db_engine, application_name=self._application_name
) as engine,
engine.connect() as connection,
):
node = await _get_node_from_db(project_id, node_uuid, connection)
Expand All @@ -171,7 +179,9 @@ async def get_ports_configuration_from_node_uuid(

async def get_project_owner_user_id(self, project_id: ProjectID) -> UserID:
async with (
DBContextManager(self._db_engine) as engine,
DBContextManager(
self._db_engine, application_name=self._application_name
) as engine,
engine.connect() as connection,
):
prj_owner = await connection.scalar(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,11 @@ async def ports(
project_id: ProjectIDStr,
node_uuid: NodeIDStr,
*,
db_manager: DBManager | None = None,
db_manager: DBManager,
r_clone_settings: RCloneSettings | None = None,
io_log_redirect_cb: LogRedirectCB | None = None,
aws_s3_cli_settings: AwsS3CliSettings | None = None
) -> Nodeports:
log.debug("creating node_ports_v2 object using provided dbmanager: %s", db_manager)
# NOTE: warning every dbmanager create a new db engine!
if db_manager is None: # NOTE: keeps backwards compatibility
log.debug("no db manager provided, creating one...")
db_manager = DBManager()

return await load(
db_manager=db_manager,
user_id=user_id,
Expand Down
6 changes: 6 additions & 0 deletions packages/simcore-sdk/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import pytest
import simcore_sdk
from faker import Faker
from helpers.utils_port_v2 import CONSTANT_UUID
from pytest_mock.plugin import MockerFixture
from pytest_simcore.helpers.postgres_tools import PostgresTestConfig
Expand Down Expand Up @@ -85,3 +86,8 @@ def constant_uuid4(mocker: MockerFixture) -> None:
"simcore_sdk.node_ports_common.data_items_utils.uuid4",
return_value=CONSTANT_UUID,
)


@pytest.fixture
def mock_app_name(faker: Faker) -> str:
return faker.pystr()
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ async def test_delete_legacy_archive(
project_id=project_id,
node_uuid=node_uuid,
path=content_path,
application_name=faker.pystr(),
)

assert (
Expand Down
Loading
Loading