Skip to content

Commit a897add

Browse files
committed
Connection options
1 parent f3a18b7 commit a897add

File tree

5 files changed

+63
-17
lines changed

5 files changed

+63
-17
lines changed

databases/backends/mysql.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,31 @@
1717

1818

1919
class MySQLBackend(DatabaseBackend):
20-
def __init__(self, database_url: typing.Union[DatabaseURL, str]) -> None:
20+
def __init__(
21+
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
22+
) -> None:
2123
self._database_url = DatabaseURL(database_url)
24+
self._options = options
2225
self._dialect = pymysql.dialect(paramstyle="pyformat")
2326
self._pool = None
2427

2528
def _get_connection_kwargs(self) -> dict:
26-
options = self._database_url.options
29+
url_options = self._database_url.options
2730

2831
kwargs = {}
29-
min_size = options.get("min_size")
30-
max_size = options.get("max_size")
31-
ssl = options.get("ssl")
32+
min_size = url_options.get("min_size")
33+
max_size = url_options.get("max_size")
34+
ssl = url_options.get("ssl")
3235

3336
if min_size is not None:
3437
kwargs["minsize"] = int(min_size)
3538
if max_size is not None:
3639
kwargs["maxsize"] = int(max_size)
3740
if ssl is not None:
3841
kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()]
42+
43+
kwargs.update(self._options)
44+
3945
return kwargs
4046

4147
async def connect(self) -> None:

databases/backends/postgres.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@
1919

2020

2121
class PostgresBackend(DatabaseBackend):
22-
def __init__(self, database_url: typing.Union[DatabaseURL, str]) -> None:
22+
def __init__(
23+
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
24+
) -> None:
2325
self._database_url = DatabaseURL(database_url)
26+
self._options = options
2427
self._dialect = self._get_dialect()
2528
self._pool = None
2629

@@ -37,19 +40,22 @@ def _get_dialect(self) -> Dialect:
3740
return dialect
3841

3942
def _get_connection_kwargs(self) -> dict:
40-
options = self._database_url.options
43+
url_options = self._database_url.options
4144

4245
kwargs = {}
43-
min_size = options.get("min_size")
44-
max_size = options.get("max_size")
45-
ssl = options.get("ssl")
46+
min_size = url_options.get("min_size")
47+
max_size = url_options.get("max_size")
48+
ssl = url_options.get("ssl")
4649

4750
if min_size is not None:
4851
kwargs["min_size"] = int(min_size)
4952
if max_size is not None:
5053
kwargs["max_size"] = int(max_size)
5154
if ssl is not None:
5255
kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()]
56+
57+
kwargs.update(self._options)
58+
5359
return kwargs
5460

5561
async def connect(self) -> None:

databases/backends/sqlite.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616

1717

1818
class SQLiteBackend(DatabaseBackend):
19-
def __init__(self, database_url: typing.Union[DatabaseURL, str]) -> None:
19+
def __init__(
20+
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
21+
) -> None:
2022
self._database_url = DatabaseURL(database_url)
23+
self._options = options
2124
self._dialect = pysqlite.dialect(paramstyle="qmark")
22-
self._pool = SQLitePool(self._database_url)
25+
self._pool = SQLitePool(self._database_url, **self._options)
2326

2427
async def connect(self) -> None:
2528
pass
@@ -45,12 +48,13 @@ def connection(self) -> "SQLiteConnection":
4548

4649

4750
class SQLitePool:
48-
def __init__(self, url: DatabaseURL) -> None:
51+
def __init__(self, url: DatabaseURL, **options: typing.Any) -> None:
4952
self._url = url
53+
self._options = options
5054

5155
async def acquire(self) -> aiosqlite.Connection:
5256
connection = aiosqlite.connect(
53-
database=self._url.database, isolation_level=None
57+
database=self._url.database, isolation_level=None, **self._options
5458
)
5559
await connection.__aenter__()
5660
return connection

databases/core.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,22 @@ class Database:
2525
}
2626

2727
def __init__(
28-
self, url: typing.Union[str, "DatabaseURL"], *, force_rollback: bool = False
28+
self,
29+
url: typing.Union[str, "DatabaseURL"],
30+
*,
31+
force_rollback: bool = False,
32+
**options: typing.Any,
2933
):
3034
self.url = DatabaseURL(url)
31-
self._force_rollback = force_rollback
35+
self.options = options
3236
self.is_connected = False
3337

38+
self._force_rollback = force_rollback
39+
3440
backend_str = self.SUPPORTED_BACKENDS[self.url.dialect]
3541
backend_cls = import_from_string(backend_str)
3642
assert issubclass(backend_cls, DatabaseBackend)
37-
self._backend = backend_cls(self.url)
43+
self._backend = backend_cls(self.url, **self.options)
3844

3945
# Connections are stored as task-local state.
4046
self._connection_context = ContextVar("connection_context") # type: ContextVar

tests/test_connection_options.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,43 @@ def test_postgres_pool_size():
1212
assert kwargs == {"min_size": 1, "max_size": 20}
1313

1414

15+
def test_postgres_explicit_pool_size():
16+
backend = PostgresBackend("postgres://localhost/database", min_size=1, max_size=20)
17+
kwargs = backend._get_connection_kwargs()
18+
assert kwargs == {"min_size": 1, "max_size": 20}
19+
20+
1521
def test_postgres_ssl():
1622
backend = PostgresBackend("postgres://localhost/database?ssl=true")
1723
kwargs = backend._get_connection_kwargs()
1824
assert kwargs == {"ssl": True}
1925

2026

27+
def test_postgres_explicit_ssl():
28+
backend = PostgresBackend("postgres://localhost/database", ssl=True)
29+
kwargs = backend._get_connection_kwargs()
30+
assert kwargs == {"ssl": True}
31+
32+
2133
def test_mysql_pool_size():
2234
backend = MySQLBackend("mysql://localhost/database?min_size=1&max_size=20")
2335
kwargs = backend._get_connection_kwargs()
2436
assert kwargs == {"minsize": 1, "maxsize": 20}
2537

2638

39+
def test_mysql_explicit_pool_size():
40+
backend = MySQLBackend("mysql://localhost/database", min_size=1, max_size=20)
41+
kwargs = backend._get_connection_kwargs()
42+
assert kwargs == {"minsize": 1, "maxsize": 20}
43+
44+
2745
def test_mysql_ssl():
2846
backend = MySQLBackend("postgres://localhost/database?ssl=true")
2947
kwargs = backend._get_connection_kwargs()
3048
assert kwargs == {"ssl": True}
49+
50+
51+
def test_mysql_explicit_ssl():
52+
backend = MySQLBackend("postgres://localhost/database", ssl=True)
53+
kwargs = backend._get_connection_kwargs()
54+
assert kwargs == {"ssl": True}

0 commit comments

Comments
 (0)