diff --git a/python/sqlcommenter-python/google/cloud/sqlcommenter/django/middleware.py b/python/sqlcommenter-python/google/cloud/sqlcommenter/django/middleware.py index 90d80995..033399f9 100644 --- a/python/sqlcommenter-python/google/cloud/sqlcommenter/django/middleware.py +++ b/python/sqlcommenter-python/google/cloud/sqlcommenter/django/middleware.py @@ -19,7 +19,6 @@ import django from django.db import connections -from django.db.backends.utils import CursorDebugWrapper from google.cloud.sqlcommenter import add_sql_comment from google.cloud.sqlcommenter.opencensus import get_opencensus_values from google.cloud.sqlcommenter.opentelemetry import get_opentelemetry_values @@ -90,8 +89,4 @@ def __call__(self, execute, sql, params, many, context): # * https://github.com/basecamp/marginalia/issues/61 # * https://github.com/basecamp/marginalia/pull/80 - # Add the query to the query log if debugging. - if isinstance(context['cursor'], CursorDebugWrapper): - context['connection'].queries_log.append(sql) - return execute(sql, params, many, context) diff --git a/python/sqlcommenter-python/tests/django/tests.py b/python/sqlcommenter-python/tests/django/tests.py index e61cc205..984a304c 100644 --- a/python/sqlcommenter-python/tests/django/tests.py +++ b/python/sqlcommenter-python/tests/django/tests.py @@ -17,11 +17,9 @@ import django from django.db import connection, connections from django.http import HttpRequest -from django.test import TestCase, modify_settings, override_settings +from django.test import TestCase, modify_settings from django.urls import resolve, reverse -from google.cloud.sqlcommenter.django.middleware import ( - QueryWrapper, SqlCommenter, -) +from google.cloud.sqlcommenter.django.middleware import QueryWrapper from ..compat import mock from ..opencensus_mock import mock_opencensus_tracer @@ -43,8 +41,16 @@ def __call__(self, request): return self.get_response(request) -# Query log only active if DEBUG=True. -@override_settings(DEBUG=True) +class SqlCaptureWrapper: + """Wrapper to capture the SQL after comments are added by QueryWrapper.""" + def __init__(self): + self.captured_sql = None + + def __call__(self, execute, sql, params, many, context): + self.captured_sql = sql + return execute(sql, params, many, context) + + class Tests(TestCase): databases = '__all__' @@ -55,18 +61,22 @@ def get_request(path): return request def get_query(self, path='/'): - SqlCommenter(views.home)(self.get_request(path)) - # Query with comment added by QueryWrapper and unaltered query added - # by Django's CursorDebugWrapper. - self.assertEqual(len(connection.queries), 2) - return connection.queries[0] + # Use a capture wrapper to intercept the SQL after QueryWrapper adds comments. + # QueryWrapper must be added first (outer), then capture (inner) to see modified SQL. + capture = SqlCaptureWrapper() + request = self.get_request(path) + with connection.execute_wrapper(QueryWrapper(request)): + with connection.execute_wrapper(capture): + views.home(request) + return capture.captured_sql def get_query_other_db(self, path='/', connection_name='default'): - SqlCommenter(views.home_other_db)(self.get_request(path)) - # Query with comment added by QueryWrapper and unaltered query added - # by Django's CursorDebugWrapper. - self.assertEqual(len(connections[connection_name].queries), 2) - return connections[connection_name].queries[0] + capture = SqlCaptureWrapper() + request = self.get_request(path) + with connections[connection_name].execute_wrapper(QueryWrapper(request)): + with connections[connection_name].execute_wrapper(capture): + views.home_other_db(request) + return capture.captured_sql def assertRoute(self, route, query): # route available in Django 2.2 and later.