Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 28ea805

Browse files
committed
➕ Add support for SQL Server
1 parent deedd13 commit 28ea805

File tree

10 files changed

+433
-10
lines changed

10 files changed

+433
-10
lines changed

.github/workflows/test-suite.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,30 @@ jobs:
3838
- 5432:5432
3939
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
4040

41+
mssql:
42+
image: mcr.microsoft.com/mssql/server:2019-GA-ubuntu-16.04
43+
env:
44+
MSSQL_SA_PASSWORD: "mssql123mssql"
45+
ACCEPT_EULA: "Y"
46+
MSSQL_PID: "Developer"
47+
ports:
48+
- "1433:1433"
49+
4150
steps:
4251
- uses: "actions/checkout@v3"
4352
- uses: "actions/setup-python@v4"
4453
with:
4554
python-version: "${{ matrix.python-version }}"
55+
- name: "Install drivers"
56+
run: |
57+
curl https://packages.microsoft.com/keys/microsoft.asc | apt-key add -
58+
curl https://packages.microsoft.com/config/ubuntu/$(lsb_release -rs)/prod.list > /etc/apt/sources.list.d/mssql-release.list
59+
sudo apt-get update -y
60+
sudo ACCEPT_EULA=Y apt-get install -y msodbcsql17
61+
sudo ACCEPT_EULA=Y apt-get install -y mssql-tools
62+
echo 'export PATH="$PATH:/opt/mssql-tools/bin"' >> ~/.bashrc
63+
source ~/.bashrc
64+
sudo apt-get install -y unixodbc-dev
4665
- name: "Install dependencies"
4766
run: "scripts/install"
4867
- name: "Run linting checks"
@@ -60,4 +79,7 @@ jobs:
6079
postgresql://username:password@localhost:5432/testsuite,
6180
postgresql+aiopg://username:[email protected]:5432/testsuite,
6281
postgresql+asyncpg://username:password@localhost:5432/testsuite
82+
mssql://sa:mssql123mssql@localhost:1433/master?driver=ODBC+Driver+17+for+SQL+Server,
83+
mssql+pyodbc://sa:mssql123mssql@localhost:1433/master?driver=ODBC+Driver+17+for+SQL+Server,
84+
mssql+aioodbc://sa:mssql123mssql@localhost:1433/master?driver=ODBC+Driver+17+for+SQL+Server
6385
run: "scripts/test"

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Database drivers supported are:
3636
* [aiomysql][aiomysql]
3737
* [asyncmy][asyncmy]
3838
* [aiosqlite][aiosqlite]
39+
* [aioodbc][aioodbc]
3940

4041
You can install the required database drivers with:
4142

@@ -45,9 +46,10 @@ $ pip install databases[aiopg]
4546
$ pip install databases[aiomysql]
4647
$ pip install databases[asyncmy]
4748
$ pip install databases[aiosqlite]
49+
$ pip install databases[aioodbc]
4850
```
4951

50-
Note that if you are using any synchronous SQLAlchemy functions such as `engine.create_all()` or [alembic][alembic] migrations then you still have to install a synchronous DB driver: [psycopg2][psycopg2] for PostgreSQL and [pymysql][pymysql] for MySQL.
52+
Note that if you are using any synchronous SQLAlchemy functions such as `engine.create_all()` or [alembic][alembic] migrations then you still have to install a synchronous DB driver: [psycopg2][psycopg2] for PostgreSQL, [pymysql][pymysql] for MySQL and [pyodbc][pyodbc] for SQL Server.
5153

5254
---
5355

@@ -103,11 +105,13 @@ for examples of how to start using databases together with SQLAlchemy core expre
103105
[alembic]: https://alembic.sqlalchemy.org/en/latest/
104106
[psycopg2]: https://www.psycopg.org/
105107
[pymysql]: https://github.com/PyMySQL/PyMySQL
108+
[pyodbc]: https://github.com/mkleehammer/pyodbc
106109
[asyncpg]: https://github.com/MagicStack/asyncpg
107110
[aiopg]: https://github.com/aio-libs/aiopg
108111
[aiomysql]: https://github.com/aio-libs/aiomysql
109112
[asyncmy]: https://github.com/long2ice/asyncmy
110113
[aiosqlite]: https://github.com/omnilib/aiosqlite
114+
[aioodbc]: https://aioodbc.readthedocs.io/en/latest/
111115

112116
[starlette]: https://github.com/encode/starlette
113117
[sanic]: https://github.com/huge-success/sanic

databases/backends/mssql.py

Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
import getpass
2+
import logging
3+
import typing
4+
import uuid
5+
6+
import aioodbc
7+
from sqlalchemy.dialects.mssql import pyodbc
8+
from sqlalchemy.engine.cursor import CursorResultMetaData
9+
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
10+
from sqlalchemy.sql import ClauseElement
11+
from sqlalchemy.sql.ddl import DDLElement
12+
13+
from databases.backends.common.records import Record, Row, create_column_maps
14+
from databases.core import LOG_EXTRA, DatabaseURL
15+
from databases.interfaces import (
16+
ConnectionBackend,
17+
DatabaseBackend,
18+
Record as RecordInterface,
19+
TransactionBackend,
20+
)
21+
22+
logger = logging.getLogger("databases")
23+
24+
25+
class MSSQLBackend(DatabaseBackend):
26+
def __init__(
27+
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
28+
) -> None:
29+
self._database_url = DatabaseURL(database_url)
30+
self._options = options
31+
self._dialect = pyodbc.dialect(paramstyle="pyformat")
32+
self._dialect.supports_native_decimal = True
33+
self._pool: aioodbc.Pool = None
34+
35+
def _get_connection_kwargs(self) -> dict:
36+
url_options = self._database_url.options
37+
38+
kwargs = {}
39+
min_size = url_options.get("min_size")
40+
max_size = url_options.get("max_size")
41+
pool_recycle = url_options.get("pool_recycle")
42+
ssl = url_options.get("ssl")
43+
driver = url_options.get("driver")
44+
trusted_connection = url_options.get("trusted_connection", "no")
45+
46+
assert driver is not None, "The driver must be specified"
47+
48+
if min_size is not None:
49+
kwargs["minsize"] = int(min_size)
50+
if max_size is not None:
51+
kwargs["maxsize"] = int(max_size)
52+
if pool_recycle is not None:
53+
kwargs["pool_recycle"] = int(pool_recycle)
54+
if ssl is not None:
55+
kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()]
56+
57+
kwargs["trusted_connection"] = trusted_connection.lower()
58+
kwargs["driver"] = driver
59+
60+
for key, value in self._options.items():
61+
# Coerce 'min_size' and 'max_size' for consistency.
62+
if key == "min_size":
63+
key = "minsize"
64+
elif key == "max_size":
65+
key = "maxsize"
66+
kwargs[key] = value
67+
68+
return kwargs
69+
70+
async def connect(self) -> None:
71+
assert self._pool is None, "DatabaseBackend is already running"
72+
kwargs = self._get_connection_kwargs()
73+
74+
driver = kwargs["driver"]
75+
database = self._database_url.database
76+
hostname = self._database_url.hostname
77+
port = self._database_url.port or 1433
78+
user = self._database_url.username or getpass.getuser()
79+
password = self._database_url.password
80+
81+
dsn = f"Driver={driver};Database={database};Server={hostname};UID={user};PWD={password};Port={port}"
82+
83+
self._pool = await aioodbc.create_pool(
84+
dsn=dsn,
85+
autocommit=True,
86+
**kwargs,
87+
)
88+
89+
async def disconnect(self) -> None:
90+
assert self._pool is not None, "DatabaseBackend is not running"
91+
self._pool.close()
92+
await self._pool.wait_closed()
93+
self._pool = None
94+
95+
def connection(self) -> "MSSQLConnection":
96+
return MSSQLConnection(self, self._dialect)
97+
98+
99+
class CompilationContext:
100+
def __init__(self, context: ExecutionContext):
101+
self.context = context
102+
103+
104+
class MSSQLConnection(ConnectionBackend):
105+
def __init__(self, database: MSSQLBackend, dialect: Dialect) -> None:
106+
self._database = database
107+
self._dialect = dialect
108+
self._connection: typing.Optional[aioodbc.Connection] = None
109+
110+
async def acquire(self) -> None:
111+
assert self._connection is None, "Connection is already acquired"
112+
assert self._database._pool is not None, "DatabaseBackend is not running"
113+
self._connection = await self._database._pool.acquire()
114+
115+
async def release(self) -> None:
116+
assert self._connection is not None, "Connection is not acquired"
117+
assert self._database._pool is not None, "DatabaseBackend is not running"
118+
await self._database._pool.release(self._connection)
119+
self._connection = None
120+
121+
async def fetch_all(self, query: ClauseElement) -> typing.List["RecordInterface"]:
122+
assert self._connection is not None, "Connection is not acquired"
123+
query_str, args, result_columns, context = self._compile(query)
124+
column_maps = create_column_maps(result_columns)
125+
dialect = self._dialect
126+
cursor = await self._connection.cursor()
127+
try:
128+
await cursor.execute(query_str, args)
129+
rows = await cursor.fetchall()
130+
metadata = CursorResultMetaData(context, cursor.description)
131+
rows = [
132+
Row(
133+
metadata,
134+
metadata._processors,
135+
metadata._keymap,
136+
Row._default_key_style,
137+
row,
138+
)
139+
for row in rows
140+
]
141+
return [Record(row, result_columns, dialect, column_maps) for row in rows]
142+
finally:
143+
await cursor.close()
144+
145+
async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
146+
assert self._connection is not None, "Connection is not acquired"
147+
query_str, args, result_columns, context = self._compile(query)
148+
column_maps = create_column_maps(result_columns)
149+
dialect = self._dialect
150+
cursor = await self._connection.cursor()
151+
try:
152+
await cursor.execute(query_str, args)
153+
row = await cursor.fetchone()
154+
if row is None:
155+
return None
156+
metadata = CursorResultMetaData(context, cursor.description)
157+
row = Row(
158+
metadata,
159+
metadata._processors,
160+
metadata._keymap,
161+
Row._default_key_style,
162+
row,
163+
)
164+
return Record(row, result_columns, dialect, column_maps)
165+
finally:
166+
await cursor.close()
167+
168+
async def execute(self, query: ClauseElement) -> typing.Any:
169+
assert self._connection is not None, "Connection is not acquired"
170+
query_str, args, _, _ = self._compile(query)
171+
cursor = await self._connection.cursor()
172+
try:
173+
values = await cursor.execute(query_str, args)
174+
try:
175+
values = await values.fetchone()
176+
return values[0]
177+
except Exception:
178+
...
179+
finally:
180+
await cursor.close()
181+
182+
async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
183+
assert self._connection is not None, "Connection is not acquired"
184+
cursor = await self._connection.cursor()
185+
try:
186+
for single_query in queries:
187+
single_query, args, _, _ = self._compile(single_query)
188+
await cursor.execute(single_query, args)
189+
finally:
190+
await cursor.close()
191+
192+
async def iterate(
193+
self, query: ClauseElement
194+
) -> typing.AsyncGenerator[typing.Any, None]:
195+
assert self._connection is not None, "Connection is not acquired"
196+
query_str, args, result_columns, context = self._compile(query)
197+
column_maps = create_column_maps(result_columns)
198+
dialect = self._dialect
199+
cursor = await self._connection.cursor()
200+
try:
201+
await cursor.execute(query_str, args)
202+
metadata = CursorResultMetaData(context, cursor.description)
203+
async for row in cursor:
204+
record = Row(
205+
metadata,
206+
metadata._processors,
207+
metadata._keymap,
208+
Row._default_key_style,
209+
row,
210+
)
211+
yield Record(record, result_columns, dialect, column_maps)
212+
finally:
213+
await cursor.close()
214+
215+
def transaction(self) -> TransactionBackend:
216+
return MSSQLTransaction(self)
217+
218+
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
219+
compiled = query.compile(
220+
dialect=self._dialect, compile_kwargs={"render_postcompile": True}
221+
)
222+
223+
execution_context = self._dialect.execution_ctx_cls()
224+
execution_context.dialect = self._dialect
225+
226+
if not isinstance(query, DDLElement):
227+
compiled_params = compiled.params.items()
228+
229+
mapping = {key: "?" for _, (key, _) in enumerate(compiled_params, start=1)}
230+
compiled_query = compiled.string % mapping
231+
232+
processors = compiled._bind_processors
233+
args = [
234+
processors[key](val) if key in processors else val
235+
for key, val in compiled_params
236+
]
237+
238+
execution_context.result_column_struct = (
239+
compiled._result_columns,
240+
compiled._ordered_columns,
241+
compiled._textual_ordered_columns,
242+
compiled._ad_hoc_textual,
243+
compiled._loose_column_name_matching,
244+
)
245+
246+
result_map = compiled._result_columns
247+
else:
248+
compiled_query = compiled.string
249+
args = []
250+
result_map = None
251+
252+
query_message = compiled_query.replace(" \n", " ").replace("\n", " ")
253+
logger.debug(
254+
"Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
255+
)
256+
return compiled_query, args, result_map, CompilationContext(execution_context)
257+
258+
@property
259+
def raw_connection(self) -> aioodbc.connection.Connection:
260+
assert self._connection is not None, "Connection is not acquired"
261+
return self._connection
262+
263+
264+
class MSSQLTransaction(TransactionBackend):
265+
def __init__(self, connection: MSSQLConnection):
266+
self._connection = connection
267+
self._is_root = False
268+
self._savepoint_name = ""
269+
270+
async def start(
271+
self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
272+
) -> None:
273+
assert self._connection._connection is not None, "Connection is not acquired"
274+
self._is_root = is_root
275+
cursor = await self._connection._connection.cursor()
276+
if self._is_root:
277+
await cursor.execute("BEGIN TRANSACTION")
278+
else:
279+
id = str(uuid.uuid4()).replace("-", "_")[:12]
280+
self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}"
281+
try:
282+
await cursor.execute(f"SAVE TRANSACTION {self._savepoint_name}")
283+
finally:
284+
cursor.close()
285+
286+
async def commit(self) -> None:
287+
assert self._connection._connection is not None, "Connection is not acquired"
288+
cursor = await self._connection._connection.cursor()
289+
if self._is_root:
290+
await cursor.execute("COMMIT TRANSACTION")
291+
else:
292+
try:
293+
await cursor.execute(f"COMMIT TRANSACTION {self._savepoint_name}")
294+
finally:
295+
cursor.close()
296+
297+
async def rollback(self) -> None:
298+
assert self._connection._connection is not None, "Connection is not acquired"
299+
cursor = await self._connection._connection.cursor()
300+
if self._is_root:
301+
await cursor.execute("BEGIN TRANSACTION")
302+
await cursor.execute("ROLLBACK TRANSACTION")
303+
else:
304+
try:
305+
await cursor.execute(f"ROLLBACK TRANSACTION {self._savepoint_name}")
306+
finally:
307+
cursor.close()

databases/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ class Database:
4242
"postgres": "databases.backends.postgres:PostgresBackend",
4343
"mysql": "databases.backends.mysql:MySQLBackend",
4444
"mysql+asyncmy": "databases.backends.asyncmy:AsyncMyBackend",
45+
"mssql": "databases.backends.mssql:MSSQLBackend",
46+
"mssql+pyodbc": "databases.backends.mssql:MSSQLBackend",
47+
"mssql+aioodbc": "databases.backends.mssql:MSSQLBackend",
4548
"sqlite": "databases.backends.sqlite:SQLiteBackend",
4649
}
4750

0 commit comments

Comments
 (0)