Skip to content

Commit 5570cba

Browse files
authored
Merge pull request #85 from encode/connection-options
Connection options
2 parents f3a18b7 + 83d1bb2 commit 5570cba

File tree

6 files changed

+77
-18
lines changed

6 files changed

+77
-18
lines changed

README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ It allows you to make queries using the powerful [SQLAlchemy Core][sqlalchemy-co
1818
expression language, and provides support for PostgreSQL, MySQL, and SQLite.
1919

2020
Databases is suitable for integrating against any async Web framework, such as [Starlette][starlette],
21-
[Sanic][sanic], [Responder][responder], [Quart][quart], [aiohttp][aiohttp], [Tornado][tornado], [FastAPI][fastapi],
21+
[Sanic][sanic], [Responder][responder], [Quart][quart], [aiohttp][aiohttp], [Tornado][tornado], [FastAPI][fastapi],
2222
or [Bocadillo][bocadillo].
2323

2424
**Requirements**: Python 3.6+
@@ -220,6 +220,13 @@ database = Database('postgresql://localhost/example?ssl=true')
220220
database = Database('mysql://localhost/example?min_size=5&max_size=20')
221221
```
222222

223+
You can also use keyword arguments to pass in any connection options.
224+
Available keyword arguments may differ between database backends.
225+
226+
```python
227+
database = Database('postgresql://localhost/example', ssl=True, min_size=5, max_size=20)
228+
```
229+
223230
## Test isolation
224231

225232
For strict test isolation you will always want to rollback the test database

databases/backends/mysql.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,37 @@
1717

1818

1919
class MySQLBackend(DatabaseBackend):
20-
def __init__(self, database_url: typing.Union[DatabaseURL, str]) -> None:
20+
def __init__(
21+
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
22+
) -> None:
2123
self._database_url = DatabaseURL(database_url)
24+
self._options = options
2225
self._dialect = pymysql.dialect(paramstyle="pyformat")
2326
self._pool = None
2427

2528
def _get_connection_kwargs(self) -> dict:
26-
options = self._database_url.options
29+
url_options = self._database_url.options
2730

2831
kwargs = {}
29-
min_size = options.get("min_size")
30-
max_size = options.get("max_size")
31-
ssl = options.get("ssl")
32+
min_size = url_options.get("min_size")
33+
max_size = url_options.get("max_size")
34+
ssl = url_options.get("ssl")
3235

3336
if min_size is not None:
3437
kwargs["minsize"] = int(min_size)
3538
if max_size is not None:
3639
kwargs["maxsize"] = int(max_size)
3740
if ssl is not None:
3841
kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()]
42+
43+
for key, value in self._options.items():
44+
# Coerce 'min_size' and 'max_size' for consistency.
45+
if key == "min_size":
46+
key = "minsize"
47+
elif key == "max_size":
48+
key = "maxsize"
49+
kwargs[key] = value
50+
3951
return kwargs
4052

4153
async def connect(self) -> None:

databases/backends/postgres.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@
1919

2020

2121
class PostgresBackend(DatabaseBackend):
22-
def __init__(self, database_url: typing.Union[DatabaseURL, str]) -> None:
22+
def __init__(
23+
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
24+
) -> None:
2325
self._database_url = DatabaseURL(database_url)
26+
self._options = options
2427
self._dialect = self._get_dialect()
2528
self._pool = None
2629

@@ -37,19 +40,22 @@ def _get_dialect(self) -> Dialect:
3740
return dialect
3841

3942
def _get_connection_kwargs(self) -> dict:
40-
options = self._database_url.options
43+
url_options = self._database_url.options
4144

4245
kwargs = {}
43-
min_size = options.get("min_size")
44-
max_size = options.get("max_size")
45-
ssl = options.get("ssl")
46+
min_size = url_options.get("min_size")
47+
max_size = url_options.get("max_size")
48+
ssl = url_options.get("ssl")
4649

4750
if min_size is not None:
4851
kwargs["min_size"] = int(min_size)
4952
if max_size is not None:
5053
kwargs["max_size"] = int(max_size)
5154
if ssl is not None:
5255
kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()]
56+
57+
kwargs.update(self._options)
58+
5359
return kwargs
5460

5561
async def connect(self) -> None:

databases/backends/sqlite.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616

1717

1818
class SQLiteBackend(DatabaseBackend):
19-
def __init__(self, database_url: typing.Union[DatabaseURL, str]) -> None:
19+
def __init__(
20+
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
21+
) -> None:
2022
self._database_url = DatabaseURL(database_url)
23+
self._options = options
2124
self._dialect = pysqlite.dialect(paramstyle="qmark")
22-
self._pool = SQLitePool(self._database_url)
25+
self._pool = SQLitePool(self._database_url, **self._options)
2326

2427
async def connect(self) -> None:
2528
pass
@@ -45,12 +48,13 @@ def connection(self) -> "SQLiteConnection":
4548

4649

4750
class SQLitePool:
48-
def __init__(self, url: DatabaseURL) -> None:
51+
def __init__(self, url: DatabaseURL, **options: typing.Any) -> None:
4952
self._url = url
53+
self._options = options
5054

5155
async def acquire(self) -> aiosqlite.Connection:
5256
connection = aiosqlite.connect(
53-
database=self._url.database, isolation_level=None
57+
database=self._url.database, isolation_level=None, **self._options
5458
)
5559
await connection.__aenter__()
5660
return connection

databases/core.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,22 @@ class Database:
2525
}
2626

2727
def __init__(
28-
self, url: typing.Union[str, "DatabaseURL"], *, force_rollback: bool = False
28+
self,
29+
url: typing.Union[str, "DatabaseURL"],
30+
*,
31+
force_rollback: bool = False,
32+
**options: typing.Any,
2933
):
3034
self.url = DatabaseURL(url)
31-
self._force_rollback = force_rollback
35+
self.options = options
3236
self.is_connected = False
3337

38+
self._force_rollback = force_rollback
39+
3440
backend_str = self.SUPPORTED_BACKENDS[self.url.dialect]
3541
backend_cls = import_from_string(backend_str)
3642
assert issubclass(backend_cls, DatabaseBackend)
37-
self._backend = backend_cls(self.url)
43+
self._backend = backend_cls(self.url, **self.options)
3844

3945
# Connections are stored as task-local state.
4046
self._connection_context = ContextVar("connection_context") # type: ContextVar

tests/test_connection_options.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,43 @@ def test_postgres_pool_size():
1212
assert kwargs == {"min_size": 1, "max_size": 20}
1313

1414

15+
def test_postgres_explicit_pool_size():
16+
backend = PostgresBackend("postgres://localhost/database", min_size=1, max_size=20)
17+
kwargs = backend._get_connection_kwargs()
18+
assert kwargs == {"min_size": 1, "max_size": 20}
19+
20+
1521
def test_postgres_ssl():
1622
backend = PostgresBackend("postgres://localhost/database?ssl=true")
1723
kwargs = backend._get_connection_kwargs()
1824
assert kwargs == {"ssl": True}
1925

2026

27+
def test_postgres_explicit_ssl():
28+
backend = PostgresBackend("postgres://localhost/database", ssl=True)
29+
kwargs = backend._get_connection_kwargs()
30+
assert kwargs == {"ssl": True}
31+
32+
2133
def test_mysql_pool_size():
2234
backend = MySQLBackend("mysql://localhost/database?min_size=1&max_size=20")
2335
kwargs = backend._get_connection_kwargs()
2436
assert kwargs == {"minsize": 1, "maxsize": 20}
2537

2638

39+
def test_mysql_explicit_pool_size():
40+
backend = MySQLBackend("mysql://localhost/database", min_size=1, max_size=20)
41+
kwargs = backend._get_connection_kwargs()
42+
assert kwargs == {"minsize": 1, "maxsize": 20}
43+
44+
2745
def test_mysql_ssl():
2846
backend = MySQLBackend("postgres://localhost/database?ssl=true")
2947
kwargs = backend._get_connection_kwargs()
3048
assert kwargs == {"ssl": True}
49+
50+
51+
def test_mysql_explicit_ssl():
52+
backend = MySQLBackend("postgres://localhost/database", ssl=True)
53+
kwargs = backend._get_connection_kwargs()
54+
assert kwargs == {"ssl": True}

0 commit comments

Comments
 (0)