|
11 | 11 | import secrets |
12 | 12 | import shutil |
13 | 13 | import signal |
| 14 | +import sqlite3 |
14 | 15 | import subprocess |
15 | 16 | import sys |
| 17 | +from typing import Any, Callable |
16 | 18 |
|
17 | 19 | import tornado |
18 | | -import uvloop as uvloop |
19 | 20 | from jupyterhub.log import log_request |
20 | | -from sqlalchemy import create_engine |
| 21 | +from sqlalchemy import Engine, create_engine, event |
21 | 22 | from sqlalchemy.orm import scoped_session, sessionmaker |
22 | 23 | from tornado.httpserver import HTTPServer |
23 | 24 | from traitlets import ( |
@@ -63,6 +64,47 @@ def get_session_maker(url) -> scoped_session: |
63 | 64 | return scoped_session(sessionmaker(bind=engine)) |
64 | 65 |
|
65 | 66 |
|
| 67 | +def enable_foreign_keys_for_sqlite() -> Callable[[sqlite3.Connection, Any], None] | None: |
| 68 | + """ |
| 69 | + Register a listener that enables foreign key constraints for SQLite databases. |
| 70 | +
|
| 71 | + It returns the function that sets the foreign keys pragma, and which is |
| 72 | + registered with the event listener on Engine.connect. The function can be |
| 73 | + used to later remove the listener: |
| 74 | +
|
| 75 | + e.g.:: |
| 76 | +
|
| 77 | + from sqlalchemy import Engine, event |
| 78 | +
|
| 79 | + event.remove(Engine, "connect", set_sqlite_pragma) |
| 80 | +
|
| 81 | + This function is a no-op if the DATABASE_TYPE environment variable is not set to "sqlite". |
| 82 | + In this case it returns `None`. |
| 83 | + """ |
| 84 | + database_type = os.getenv("DATABASE_TYPE") |
| 85 | + if database_type == "sqlite": |
| 86 | + # The following function is needed to enable foreign key constraints in SQLite. |
| 87 | + # The code was copied from the SQLAlchemy documentation: |
| 88 | + # https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#foreign-key-support |
| 89 | + @event.listens_for(Engine, "connect") |
| 90 | + def set_sqlite_pragma(dbapi_connection, connection_record): |
| 91 | + # the sqlite3 driver will not set PRAGMA foreign_keys |
| 92 | + # if autocommit=False; set to True temporarily |
| 93 | + ac = dbapi_connection.autocommit |
| 94 | + dbapi_connection.autocommit = True |
| 95 | + |
| 96 | + cursor = dbapi_connection.cursor() |
| 97 | + # Note: this is a SQLite-specific pragma |
| 98 | + cursor.execute("PRAGMA foreign_keys=ON") |
| 99 | + cursor.close() |
| 100 | + |
| 101 | + # restore previous autocommit setting |
| 102 | + dbapi_connection.autocommit = ac |
| 103 | + |
| 104 | + return set_sqlite_pragma |
| 105 | + return None |
| 106 | + |
| 107 | + |
66 | 108 | class GraderService(config.Application): |
67 | 109 | name = "grader-service" |
68 | 110 | version = __version__ |
@@ -111,13 +153,13 @@ def _default_db_url(self): |
111 | 153 | allow_none=False, |
112 | 154 | ).tag(config=True) |
113 | 155 |
|
114 | | - max_body_size = Int(104857600, help="Sets the max buffer size in bytes, default to 100mb").tag( |
| 156 | + max_body_size = Int(104857600, help="Sets the max body size in bytes, default to 100mb").tag( |
115 | 157 | config=True |
116 | 158 | ) |
117 | 159 |
|
118 | | - max_buffer_size = Int(104857600, help="Sets the max body size in bytes, default to 100mb").tag( |
119 | | - config=True |
120 | | - ) |
| 160 | + max_buffer_size = Int( |
| 161 | + 104857600, help="Sets the max buffer size in bytes, default to 100mb" |
| 162 | + ).tag(config=True) |
121 | 163 |
|
122 | 164 | service_git_username = Unicode("grader-service", allow_none=False).tag(config=True) |
123 | 165 |
|
@@ -329,6 +371,8 @@ def initialize(self, argv, *args, **kwargs): |
329 | 371 | self.load_config_file(self.config_file) |
330 | 372 | self.setup_loggers(self.log_level) |
331 | 373 |
|
| 374 | + enable_foreign_keys_for_sqlite() |
| 375 | + |
332 | 376 | self.session_maker = get_session_maker(self.db_url) |
333 | 377 | self.init_roles() |
334 | 378 | # use uvloop instead of default asyncio loop |
@@ -398,7 +442,10 @@ def init_roles(self): |
398 | 442 |
|
399 | 443 | user = db.query(User).filter(User.name == username).one_or_none() |
400 | 444 | if user is None: |
401 | | - self.log.info(f"Adding new user with username {username} and display name {display_name}") |
| 445 | + self.log.info( |
| 446 | + f"Adding new user with username {username} " |
| 447 | + f"and display name {display_name}" |
| 448 | + ) |
402 | 449 | user = User() |
403 | 450 | user.name = username |
404 | 451 | user.display_name = display_name |
|
0 commit comments