Skip to content

Commit c385a6c

Browse files
committed
override get_multi_foreign_keys()
1 parent a4cf3cb commit c385a6c

File tree

2 files changed

+119
-2
lines changed

2 files changed

+119
-2
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ jobs:
4242
crdb-version: [
4343
# "cockroach:latest-v24.3",
4444
# "cockroach:latest-v25.2",
45-
# "cockroach:latest-v25.4",
46-
"cockroach:latest-v26.1"
45+
"cockroach:latest-v25.4"
46+
# "cockroach:latest-v26.1"
4747
]
4848
db-alias: [
4949
"psycopg2",

sqlalchemy_cockroachdb/base.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
import re
33
import threading
44
from sqlalchemy import text
5+
from sqlalchemy import util
56
from sqlalchemy.dialects.postgresql.base import PGDialect
67
from sqlalchemy.dialects.postgresql import ARRAY
78
from sqlalchemy.dialects.postgresql import INET
89
from sqlalchemy.dialects.postgresql import UUID
910
from sqlalchemy.dialects.postgresql import JSONB
11+
from sqlalchemy.engine.reflection import ReflectionDefaults
1012
from sqlalchemy.ext.compiler import compiles
1113
from sqlalchemy.util import warn
1214

@@ -363,6 +365,121 @@ def get_multi_indexes(
363365
result.pop(k, None)
364366
return result
365367

368+
@util.memoized_property
369+
def _fk_regex_pattern(self):
370+
# optionally quoted token
371+
qtoken = r'(?:"[^"]+"|[\w]+?)'
372+
373+
# https://www.postgresql.org/docs/current/static/sql-createtable.html
374+
return re.compile(
375+
r"FOREIGN KEY \((.*?)\) "
376+
rf"REFERENCES (?:({qtoken})\.)?({qtoken})\(((?:{qtoken}(?: *, *)?)+)\)" # noqa: E501
377+
r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?"
378+
r"[\s]?(ON DELETE "
379+
r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?"
380+
r"[\s]?(ON UPDATE "
381+
r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?"
382+
r"[\s]?(DEFERRABLE|NOT DEFERRABLE)?"
383+
r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?"
384+
)
385+
386+
def get_multi_foreign_keys(
387+
self,
388+
connection,
389+
schema,
390+
filter_names,
391+
scope,
392+
kind,
393+
postgresql_ignore_search_path=False,
394+
**kw,
395+
):
396+
preparer = self.identifier_preparer
397+
398+
has_filter_names, params = self._prepare_filter_names(filter_names)
399+
query = self._foreing_key_query(schema, has_filter_names, scope, kind)
400+
result = connection.execute(query, params)
401+
402+
FK_REGEX = self._fk_regex_pattern
403+
404+
fkeys = collections.defaultdict(list)
405+
default = ReflectionDefaults.foreign_keys
406+
for table_name, conname, condef, conschema, comment in result:
407+
# ensure that each table has an entry, even if it has
408+
# no foreign keys
409+
if conname is None:
410+
fkeys[(schema, table_name)] = default()
411+
continue
412+
table_fks = fkeys[(schema, table_name)]
413+
m = re.search(FK_REGEX, condef).groups()
414+
415+
(
416+
constrained_columns,
417+
referred_schema,
418+
referred_table,
419+
referred_columns,
420+
_,
421+
match,
422+
_,
423+
ondelete,
424+
_,
425+
onupdate,
426+
deferrable,
427+
_,
428+
initially,
429+
) = m
430+
431+
if deferrable is not None:
432+
deferrable = True if deferrable == "DEFERRABLE" else False
433+
constrained_columns = [
434+
preparer._unquote_identifier(x)
435+
for x in re.split(r"\s*,\s*", constrained_columns)
436+
]
437+
438+
if postgresql_ignore_search_path:
439+
# when ignoring search path, we use the actual schema
440+
# provided it isn't the "default" schema
441+
if conschema != self.default_schema_name:
442+
referred_schema = conschema
443+
else:
444+
referred_schema = schema
445+
elif referred_schema:
446+
# referred_schema is the schema that we regexp'ed from
447+
# pg_get_constraintdef(). If the schema is in the search
448+
# path, pg_get_constraintdef() will give us None.
449+
referred_schema = preparer._unquote_identifier(referred_schema)
450+
elif schema is not None and schema == conschema:
451+
# If the actual schema matches the schema of the table
452+
# we're reflecting, then we will use that.
453+
referred_schema = schema
454+
455+
referred_table = preparer._unquote_identifier(referred_table)
456+
referred_columns = [
457+
preparer._unquote_identifier(x)
458+
for x in re.split(r"\s*,\s", referred_columns)
459+
]
460+
options = {
461+
k: v
462+
for k, v in [
463+
("onupdate", onupdate),
464+
("ondelete", ondelete),
465+
("initially", initially),
466+
("deferrable", deferrable),
467+
("match", match),
468+
]
469+
if v is not None and v != "NO ACTION"
470+
}
471+
fkey_d = {
472+
"name": conname,
473+
"constrained_columns": constrained_columns,
474+
"referred_schema": referred_schema,
475+
"referred_table": referred_table,
476+
"referred_columns": referred_columns,
477+
"options": options,
478+
"comment": comment,
479+
}
480+
table_fks.append(fkey_d)
481+
return fkeys.items()
482+
366483
def get_pk_constraint(self, conn, table_name, schema=None, **kw):
367484
if self._is_v21plus:
368485
return super().get_pk_constraint(conn, table_name, schema, **kw)

0 commit comments

Comments
 (0)