diff --git a/src/gino/dialects/asyncpg.py b/src/gino/dialects/asyncpg.py index f7e2fefb..795f8d9b 100644 --- a/src/gino/dialects/asyncpg.py +++ b/src/gino/dialects/asyncpg.py @@ -245,14 +245,9 @@ def __init__(self, *pargs, **kwargs): self.baked_queries = {} args.update( - loop=self._loop, - host=self._url.host, - port=self._url.port, - user=self._url.username, - database=self._url.database, - password=self._url.password, - connection_class=Connection, + connection_class=Connection, dsn=str(self._url), loop=self._loop, ) + if self._prebake and self._bakery: self._init_hook = args.pop("init", None) args["init"] = self._bake diff --git a/src/gino/strategies.py b/src/gino/strategies.py index eeb1cd85..63bee32b 100644 --- a/src/gino/strategies.py +++ b/src/gino/strategies.py @@ -1,11 +1,12 @@ import asyncio from copy import copy -from sqlalchemy.engine import url from sqlalchemy import util +from sqlalchemy.engine.url import make_url from sqlalchemy.engine.strategies import EngineStrategy from .engine import GinoEngine +from .dialects.asyncpg import AsyncpgDialect class GinoStrategy(EngineStrategy): @@ -14,8 +15,7 @@ class GinoStrategy(EngineStrategy): This strategy is initialized automatically as :mod:`gino` is imported. If :func:`sqlalchemy.create_engine` uses ``strategy="gino"``, it will return a - :class:`~collections.abc.Coroutine`, and treat URL prefix ``postgresql://`` or - ``postgres://`` as ``postgresql+asyncpg://``. + :class:`~collections.abc.Coroutine`. """ name = "gino" @@ -23,14 +23,16 @@ class GinoStrategy(EngineStrategy): async def create(self, name_or_url, loop=None, **kwargs): engine_cls = self.engine_cls - u = url.make_url(name_or_url) + url = make_url(name_or_url) if loop is None: loop = asyncio.get_event_loop() - if u.drivername in {"postgresql", "postgres"}: - u = copy(u) - u.drivername = "postgresql+asyncpg" - dialect_cls = u.get_dialect() + # The postgresql dialect is already taken by the PGDialect_psycopg2 + # we need to force ourone. + if url.drivername in ("postgresql", "postgres"): + dialect_cls = AsyncpgDialect + else: + dialect_cls = url.get_dialect() pop_kwarg = kwargs.pop @@ -52,7 +54,7 @@ async def create(self, name_or_url, loop=None, **kwargs): dialect = dialect_cls(**dialect_args) pool_class = kwargs.pop("pool_class", None) - pool = await dialect.init_pool(u, loop, pool_class=pool_class) + pool = await dialect.init_pool(url, loop, pool_class=pool_class) engine_args = dict(loop=loop) for k in util.get_cls_kwargs(engine_cls): diff --git a/tests/test_bind.py b/tests/test_bind.py index 53ae8a80..7685fad9 100644 --- a/tests/test_bind.py +++ b/tests/test_bind.py @@ -4,7 +4,7 @@ from gino.exceptions import UninitializedError from sqlalchemy.engine.url import make_url -from .models import db, PG_URL, User +from .models import db, DB_ARGS, PG_URL, User pytestmark = pytest.mark.asyncio @@ -55,9 +55,28 @@ async def test_db_api(bind, random_name): assert params[0] == 3 -async def test_bind_url(): - url = make_url(PG_URL) - assert url.drivername == "postgresql" - await db.set_bind(PG_URL) - assert url.drivername == "postgresql" +@pytest.mark.parametrize( + "dsn, driver_name", + ( + ( + "postgresql://{user}:{password}@{host}:{port}/{database}".format(**DB_ARGS), + "postgresql", + ), + ( + "postgres://{user}:{password}@{host}:{port}/{database}".format(**DB_ARGS), + "postgres", + ), + ( + "postgres://{user}:{password}@/{database}?host={host}&port={port}".format( + **DB_ARGS + ), + "postgres", + ), + ), +) +async def test_bind_url(dsn, driver_name): + url = make_url(dsn) + assert url.drivername == driver_name + await db.set_bind(dsn) + assert url.drivername == driver_name await db.pop_bind().close()