Skip to content

Commit 93b6f82

Browse files
authored
Refactor how we provide cursor factory (#376)
1 parent 6225027 commit 93b6f82

File tree

3 files changed

+23
-20
lines changed

3 files changed

+23
-20
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from django import VERSION
2+
3+
4+
def get_postgres_cursor_class():
5+
if VERSION < (4, 2):
6+
from psycopg2.extensions import cursor as cursor_cls
7+
else:
8+
from django.db.backends.postgresql.base import Cursor as cursor_cls
9+
return cursor_cls

django_prometheus/db/backends/postgis/base.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
1-
from django import VERSION
21
from django.contrib.gis.db.backends.postgis import base
32

3+
from django_prometheus.db.backends.common import get_postgres_cursor_class
44
from django_prometheus.db.common import DatabaseWrapperMixin, ExportingCursorWrapper
55

6-
if VERSION < (4, 2):
7-
from psycopg2.extensions import cursor as cursor_cls
8-
else:
9-
from django.db.backends.postgresql.base import Cursor as cursor_cls
10-
116

127
class DatabaseWrapper(DatabaseWrapperMixin, base.DatabaseWrapper):
13-
def get_connection_params(self):
14-
conn_params = super().get_connection_params()
15-
conn_params["cursor_factory"] = ExportingCursorWrapper(cursor_cls, "postgis", self.vendor)
16-
return conn_params
8+
def get_new_connection(self, *args, **kwargs):
9+
conn = super().get_new_connection(*args, **kwargs)
10+
conn.cursor_factory = ExportingCursorWrapper(
11+
conn.cursor_factory or get_postgres_cursor_class(), "postgis", self.vendor
12+
)
13+
return conn
1714

1815
def create_cursor(self, name=None):
1916
# cursor_factory is a kwarg to connect() so restore create_cursor()'s

django_prometheus/db/backends/postgresql/base.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
1-
from django import VERSION
21
from django.contrib.gis.db.backends.postgis import base
32

3+
from django_prometheus.db.backends.common import get_postgres_cursor_class
44
from django_prometheus.db.common import DatabaseWrapperMixin, ExportingCursorWrapper
55

6-
if VERSION < (4, 2):
7-
from psycopg2.extensions import cursor as cursor_cls
8-
else:
9-
from django.db.backends.postgresql.base import Cursor as cursor_cls
10-
116

127
class DatabaseWrapper(DatabaseWrapperMixin, base.DatabaseWrapper):
13-
def get_connection_params(self):
14-
conn_params = super().get_connection_params()
15-
conn_params["cursor_factory"] = ExportingCursorWrapper(cursor_cls, self.alias, self.vendor)
16-
return conn_params
8+
def get_new_connection(self, *args, **kwargs):
9+
conn = super().get_new_connection(*args, **kwargs)
10+
conn.cursor_factory = ExportingCursorWrapper(
11+
conn.cursor_factory or get_postgres_cursor_class(), self.alias, self.vendor
12+
)
13+
return conn
1714

1815
def create_cursor(self, name=None):
1916
# cursor_factory is a kwarg to connect() so restore create_cursor()'s

0 commit comments

Comments
 (0)