|
1 |
| -from django.db.backends.sqlite3.base import DatabaseWrapper as BaseDatabaseWrapper |
| 1 | +from __future__ import unicode_literals |
| 2 | + |
| 3 | +from django.db.backends.sqlite3.base import DatabaseWrapper as BaseDatabaseWrapper, \ |
| 4 | + _sqlite_date_extract, _sqlite_date_trunc, _sqlite_datetime_cast_date, \ |
| 5 | + _sqlite_datetime_extract, _sqlite_datetime_trunc, _sqlite_time_extract, \ |
| 6 | + _sqlite_regexp, _sqlite_format_dtdelta, _sqlite_power, FORMAT_QMARK_REGEX |
2 | 7 |
|
3 | 8 | from ..signals import setup
|
4 | 9 |
|
| 10 | +from pysqlcipher import dbapi2 as Database |
| 11 | + |
| 12 | + |
| 13 | +import datetime |
| 14 | +import decimal |
| 15 | +import warnings |
| 16 | + |
| 17 | +from django.conf import settings |
| 18 | +from django.db.backends import utils as backend_utils |
| 19 | +from django.utils import six, timezone |
| 20 | +from django.utils.dateparse import ( |
| 21 | + parse_date, parse_datetime, parse_time, |
| 22 | +) |
| 23 | +from django.utils.deprecation import RemovedInDjango20Warning |
| 24 | +from django.utils.safestring import SafeBytes |
| 25 | + |
| 26 | +try: |
| 27 | + import pytz |
| 28 | +except ImportError: |
| 29 | + pytz = None |
| 30 | + |
| 31 | +DatabaseError = Database.DatabaseError |
| 32 | +IntegrityError = Database.IntegrityError |
| 33 | + |
| 34 | + |
| 35 | +def adapt_datetime_warn_on_aware_datetime(value): |
| 36 | + # Remove this function and rely on the default adapter in Django 2.0. |
| 37 | + if settings.USE_TZ and timezone.is_aware(value): |
| 38 | + warnings.warn( |
| 39 | + "The SQLite database adapter received an aware datetime (%s), " |
| 40 | + "probably from cursor.execute(). Update your code to pass a " |
| 41 | + "naive datetime in the database connection's time zone (UTC by " |
| 42 | + "default).", RemovedInDjango20Warning) |
| 43 | + # This doesn't account for the database connection's timezone, |
| 44 | + # which isn't known. (That's why this adapter is deprecated.) |
| 45 | + value = value.astimezone(timezone.utc).replace(tzinfo=None) |
| 46 | + return value.isoformat(str(" ")) |
| 47 | + |
| 48 | + |
| 49 | +def decoder(conv_func): |
| 50 | + """ The Python sqlite3 interface returns always byte strings. |
| 51 | + This function converts the received value to a regular string before |
| 52 | + passing it to the receiver function. |
| 53 | + """ |
| 54 | + return lambda s: conv_func(s.decode('utf-8')) |
| 55 | + |
| 56 | + |
| 57 | +Database.register_converter(str("bool"), decoder(lambda s: s == '1')) |
| 58 | +Database.register_converter(str("time"), decoder(parse_time)) |
| 59 | +Database.register_converter(str("date"), decoder(parse_date)) |
| 60 | +Database.register_converter(str("datetime"), decoder(parse_datetime)) |
| 61 | +Database.register_converter(str("timestamp"), decoder(parse_datetime)) |
| 62 | +Database.register_converter(str("TIMESTAMP"), decoder(parse_datetime)) |
| 63 | +Database.register_converter(str("decimal"), decoder(backend_utils.typecast_decimal)) |
| 64 | + |
| 65 | +Database.register_adapter(datetime.datetime, adapt_datetime_warn_on_aware_datetime) |
| 66 | +Database.register_adapter(decimal.Decimal, backend_utils.rev_typecast_decimal) |
| 67 | +if six.PY2: |
| 68 | + Database.register_adapter(str, lambda s: s.decode('utf-8')) |
| 69 | + Database.register_adapter(SafeBytes, lambda s: s.decode('utf-8')) |
| 70 | + |
5 | 71 |
|
6 | 72 | class DatabaseWrapper(BaseDatabaseWrapper):
|
| 73 | + Database = Database |
| 74 | + |
7 | 75 | def _cursor(self):
|
8 | 76 | if self.connection is None:
|
9 | 77 | setup()
|
10 | 78 | return super(DatabaseWrapper, self)._cursor()
|
| 79 | + |
| 80 | + def get_new_connection(self, conn_params): |
| 81 | + conn = Database.connect(**conn_params) |
| 82 | + conn.create_function("django_date_extract", 2, _sqlite_date_extract) |
| 83 | + conn.create_function("django_date_trunc", 2, _sqlite_date_trunc) |
| 84 | + conn.create_function("django_datetime_cast_date", 2, _sqlite_datetime_cast_date) |
| 85 | + conn.create_function("django_datetime_extract", 3, _sqlite_datetime_extract) |
| 86 | + conn.create_function("django_datetime_trunc", 3, _sqlite_datetime_trunc) |
| 87 | + conn.create_function("django_time_extract", 2, _sqlite_time_extract) |
| 88 | + conn.create_function("regexp", 2, _sqlite_regexp) |
| 89 | + conn.create_function("django_format_dtdelta", 3, _sqlite_format_dtdelta) |
| 90 | + conn.create_function("django_power", 2, _sqlite_power) |
| 91 | + return conn |
| 92 | + |
| 93 | + def create_cursor(self): |
| 94 | + return self.connection.cursor(factory=SQLiteCursorWrapper) |
| 95 | + |
| 96 | + |
| 97 | +class SQLiteCursorWrapper(Database.Cursor): |
| 98 | + """ |
| 99 | + Django uses "format" style placeholders, but pysqlite2 uses "qmark" style. |
| 100 | + This fixes it -- but note that if you want to use a literal "%s" in a query, |
| 101 | + you'll need to use "%%s". |
| 102 | + """ |
| 103 | + def execute(self, query, params=None): |
| 104 | + if params is None: |
| 105 | + return Database.Cursor.execute(self, query) |
| 106 | + query = self.convert_query(query) |
| 107 | + return Database.Cursor.execute(self, query, params) |
| 108 | + |
| 109 | + def executemany(self, query, param_list): |
| 110 | + query = self.convert_query(query) |
| 111 | + return Database.Cursor.executemany(self, query, param_list) |
| 112 | + |
| 113 | + def convert_query(self, query): |
| 114 | + return FORMAT_QMARK_REGEX.sub('?', query).replace('%%', '%') |
0 commit comments