Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 555e4bf

Browse files
committed
Some refectoring and polishing for raw_connection #9
1 parent b75004b commit 555e4bf

File tree

7 files changed

+61
-80
lines changed

7 files changed

+61
-80
lines changed

databases/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from databases.core import Database, DatabaseURL
22

3-
43
__version__ = "0.1.9"
54
__all__ = ["Database", "DatabaseURL"]

databases/backends/mysql.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,6 @@ async def execute_many(self, query: ClauseElement, values: list) -> None:
135135
finally:
136136
await cursor.close()
137137

138-
async def expose_backend_connection(self) -> aiomysql.connection.Connection:
139-
assert self._connection is not None, "Connection is not acquired"
140-
return self._connection
141-
142138
async def iterate(
143139
self, query: ClauseElement
144140
) -> typing.AsyncGenerator[typing.Any, None]:
@@ -153,6 +149,10 @@ async def iterate(
153149
finally:
154150
await cursor.close()
155151

152+
async def raw_connection(self) -> aiomysql.connection.Connection:
153+
assert self._connection is not None, "Connection is not acquired"
154+
return self._connection
155+
156156
def transaction(self) -> TransactionBackend:
157157
return MySQLTransaction(self)
158158

databases/backends/postgres.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -74,21 +74,18 @@ def __init__(self, row: tuple, result_columns: tuple, dialect: Dialect) -> None:
7474
for idx, (column_name, _, _, datatype) in enumerate(self._result_columns)
7575
}
7676

77-
def __getitem__(self, key: typing.Any) -> typing.Any:
77+
def __getitem__(self, key: str) -> typing.Any:
78+
idx, datatype = self._column_map[key]
79+
raw = self._row[idx]
7880
try:
79-
idx, datatype = self._column_map[key]
80-
raw = self._row[idx]
81-
try:
82-
processor = _result_processors[datatype]
83-
except KeyError:
84-
processor = datatype.result_processor(self._dialect, None)
85-
_result_processors[datatype] = processor
86-
87-
if processor is not None:
88-
return processor(raw)
89-
return raw
81+
processor = _result_processors[datatype]
9082
except KeyError:
91-
return self._row[key]
83+
processor = datatype.result_processor(self._dialect, None)
84+
_result_processors[datatype] = processor
85+
86+
if processor is not None:
87+
return processor(raw)
88+
return raw
9289

9390
def __iter__(self) -> typing.Iterator:
9491
return iter(self._column_map)
@@ -145,10 +142,6 @@ async def execute_many(self, query: ClauseElement, values: list) -> None:
145142
single_query, args, result_columns = self._compile(single_query)
146143
await self._connection.execute(single_query, *args)
147144

148-
async def expose_backend_connection(self) -> asyncpg.connection.Connection:
149-
assert self._connection is not None, "Connection is not acquired"
150-
return self._connection
151-
152145
async def iterate(
153146
self, query: ClauseElement
154147
) -> typing.AsyncGenerator[typing.Any, None]:
@@ -157,6 +150,10 @@ async def iterate(
157150
async for row in self._connection.cursor(query, *args):
158151
yield Record(row, result_columns, self._dialect)
159152

153+
async def raw_connection(self) -> asyncpg.connection.Connection:
154+
assert self._connection is not None, "Connection is not acquired"
155+
return self._connection
156+
160157
def transaction(self) -> TransactionBackend:
161158
return PostgresTransaction(connection=self)
162159

databases/backends/sqlite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ async def execute_many(self, query: ClauseElement, values: list) -> None:
116116
for value in values:
117117
await self.execute(query, value)
118118

119-
async def expose_backend_connection(self) -> aiosqlite.core.Connection:
119+
async def raw_connection(self) -> aiosqlite.core.Connection:
120120
assert self._connection is not None, "Connection is not acquired"
121121
return self._connection
122122

databases/core.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,17 @@ async def execute_many(self, query: ClauseElement, values: list) -> None:
104104
async with self.connection() as connection:
105105
return await connection.execute_many(query=query, values=values)
106106

107-
async def expose_backend_connection(self) -> typing.Any:
108-
async with self.connection() as connection:
109-
return await connection.expose_backend_connection()
110-
111107
async def iterate(
112108
self, query: ClauseElement
113109
) -> typing.AsyncGenerator[RowProxy, None]:
114110
async with self.connection() as connection:
115111
async for record in connection.iterate(query):
116112
yield record
117113

114+
async def raw_connection(self) -> typing.Any:
115+
async with self.connection() as connection:
116+
return await connection.raw_connection()
117+
118118
def connection(self) -> "Connection":
119119
if self._global_connection is not None:
120120
return self._global_connection
@@ -172,8 +172,8 @@ async def execute(self, query: ClauseElement, values: dict = None) -> typing.Any
172172
async def execute_many(self, query: ClauseElement, values: list) -> None:
173173
await self._connection.execute_many(query, values)
174174

175-
async def expose_backend_connection(self) -> typing.Any:
176-
return await self._connection.expose_backend_connection()
175+
async def raw_connection(self) -> typing.Any:
176+
return await self._connection.raw_connection()
177177

178178
async def iterate(
179179
self, query: ClauseElement

databases/interfaces.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,6 @@ async def execute(self, query: ClauseElement, values: dict = None) -> typing.Any
3333
async def execute_many(self, query: ClauseElement, values: list) -> None:
3434
raise NotImplementedError() # pragma: no cover
3535

36-
async def expose_backend_connection(self) -> typing.Any:
37-
raise NotImplementedError() # pragma: no cover
38-
3936
async def iterate(
4037
self, query: ClauseElement
4138
) -> typing.AsyncGenerator[typing.Mapping, None]:
@@ -44,6 +41,9 @@ async def iterate(
4441
# https://github.com/python/mypy/issues/5385#issuecomment-407281656
4542
yield True # pragma: no cover
4643

44+
async def raw_connection(self) -> typing.Any:
45+
raise NotImplementedError() # pragma: no cover
46+
4747
def transaction(self) -> "TransactionBackend":
4848
raise NotImplementedError() # pragma: no cover
4949

tests/test_databases.py

Lines changed: 33 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,25 @@ async def get_connection_2():
516516
await task_2
517517

518518

519+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
520+
@async_adapter
521+
async def test_connection_context_with_raw_connection(database_url):
522+
"""
523+
Test connection contexts with respect to the raw connection.
524+
"""
525+
async with Database(database_url) as database:
526+
async with database.connection() as connection_1:
527+
async with database.connection() as connection_2:
528+
assert connection_1 is connection_2
529+
530+
raw_connection_0 = await database.raw_connection()
531+
raw_connection_1 = await connection_1.raw_connection()
532+
raw_connection_2 = await connection_2.raw_connection()
533+
534+
assert raw_connection_0 is raw_connection_1 is raw_connection_2
535+
assert raw_connection_0 is connection_1._connection._connection
536+
537+
519538
@pytest.mark.parametrize("database_url", DATABASE_URLS)
520539
@async_adapter
521540
async def test_queries_with_expose_backend_connection(database_url):
@@ -525,29 +544,30 @@ async def test_queries_with_expose_backend_connection(database_url):
525544
"""
526545
async with Database(database_url) as database:
527546
async with database.transaction(force_rollback=True):
547+
# Get the raw connection
548+
con = await database.raw_connection()
549+
528550
# Insert query
529-
if str(database_url).startswith('mysql'):
551+
if str(database_url).startswith("mysql"):
530552
insert_query = "INSERT INTO notes (text, completed) VALUES (%s, %s)"
531553
else:
532554
insert_query = "INSERT INTO notes (text, completed) VALUES ($1, $2)"
533555

534556
# execute()
535557
values = ("example1", True)
536558

537-
con = await database.expose_backend_connection()
538-
539-
if str(database_url).startswith('postgresql'):
559+
if str(database_url).startswith("postgresql"):
540560
await con.execute(insert_query, *values)
541-
elif str(database_url).startswith('mysql'):
561+
elif str(database_url).startswith("mysql"):
542562
cursor = await con.cursor()
543563
await cursor.execute(insert_query, values)
544-
elif str(database_url).startswith('sqlite'):
564+
elif str(database_url).startswith("sqlite"):
545565
await con.execute(insert_query, values)
546566

547567
# execute_many()
548568
values = [("example2", False), ("example3", True)]
549-
550-
if str(database_url).startswith('mysql'):
569+
570+
if str(database_url).startswith("mysql"):
551571
cursor = await con.cursor()
552572
await cursor.executemany(insert_query, values)
553573
else:
@@ -557,13 +577,13 @@ async def test_queries_with_expose_backend_connection(database_url):
557577
select_query = "SELECT notes.id, notes.text, notes.completed FROM notes"
558578

559579
# fetch_all()
560-
if str(database_url).startswith('postgresql'):
580+
if str(database_url).startswith("postgresql"):
561581
results = await con.fetch(select_query)
562-
elif str(database_url).startswith('mysql'):
582+
elif str(database_url).startswith("mysql"):
563583
cursor = await con.cursor()
564584
await cursor.execute(select_query)
565585
results = await cursor.fetchall()
566-
elif str(database_url).startswith('sqlite'):
586+
elif str(database_url).startswith("sqlite"):
567587
results = await con.execute_fetchall(select_query)
568588

569589
assert len(results) == 3
@@ -576,48 +596,13 @@ async def test_queries_with_expose_backend_connection(database_url):
576596
assert results[2][2] == True
577597

578598
# fetch_one()
579-
if str(database_url).startswith('postgresql'):
599+
if str(database_url).startswith("postgresql"):
580600
result = await con.fetchrow(select_query)
581601
else:
582602
cursor = await con.cursor()
583603
await cursor.execute(select_query)
584604
result = await cursor.fetchone()
585-
605+
586606
# Raw output for the raw request
587607
assert result[1] == "example1"
588608
assert result[2] == True
589-
590-
591-
@pytest.mark.parametrize("database_url", DATABASE_URLS)
592-
@async_adapter
593-
async def test_queries_with_sqlalchemy_test(database_url):
594-
"""
595-
Test for inserting and retriving the data using the `sqlalchemy.text` query object.
596-
"""
597-
async with Database(database_url) as database:
598-
async with database.transaction(force_rollback=True):
599-
# Insert query
600-
insert_query = "INSERT INTO notes (text, completed) VALUES (:text, :completed)"
601-
602-
# execute_many()
603-
query = sqlalchemy.text(insert_query)
604-
values = [("example1", True), ("example2", False), ("example3", True)]
605-
for text, completed in values:
606-
current_query = query.bindparams(text=text, completed=completed)
607-
await database.execute(current_query)
608-
609-
# Select query
610-
select_query = "SELECT notes.id, notes.text, notes.completed FROM notes"
611-
612-
# fetch_all()
613-
query = sqlalchemy.text(select_query)
614-
results = await database.fetch_all(query)
615-
616-
assert len(results) == 3
617-
# Raw output for the raw request
618-
assert results[0][1] == "example1"
619-
assert results[0][2] == True
620-
assert results[1][1] == "example2"
621-
assert results[1][2] == False
622-
assert results[2][1] == "example3"
623-
assert results[2][2] == True

0 commit comments

Comments
 (0)