Skip to content

Commit 9d1356b

Browse files
Altering Ondelete and OnUpdate with MigrationManager (#1183)
* fix OnDelete and OnUpdate in migrations * fix test docstring * remove for loop * refactor `add_foreign_key_constraint` If a `ForeignKey` is passed in, we can get most of the info we need from it, rather than requiring additional params to be passed in. * add `get_fk_constraint_name` * rename pk constraint on table renaming in migrations * `RenamePkConstraint` -> `RenameConstraint` * skip renaming pk constraint in SQLite * revert changes * update param names in `rename_constraint` method * add docstring * add `get_fk_constraint_rules` function So the logic for getting constraints is centralised. * add TODO for supporting SQLite * add TODO for moving logic from `generate.py` to `constraints.py` * remove redundant lookup in `generate.py` We should refactor a bunch of this - I came across this dict lookup, which isn't required because we can just call the enum directly. --------- Co-authored-by: Daniel Townsend <[email protected]>
1 parent 7cbd871 commit 9d1356b

File tree

8 files changed

+276
-67
lines changed

8 files changed

+276
-67
lines changed

piccolo/apps/migrations/auto/migration_manager.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from piccolo.engine import engine_finder
2020
from piccolo.query import Query
2121
from piccolo.query.base import DDL
22+
from piccolo.query.constraints import get_fk_constraint_name
2223
from piccolo.schema import SchemaDDLBase
2324
from piccolo.table import Table, create_table_class, sort_table_classes
2425
from piccolo.utils.warnings import colored_warning
@@ -423,8 +424,8 @@ async def _print_query(query: t.Union[DDL, Query, SchemaDDLBase]):
423424

424425
async def _run_query(self, query: t.Union[DDL, Query, SchemaDDLBase]):
425426
"""
426-
If MigrationManager is not in the preview mode,
427-
executes the queries. else, prints the query.
427+
If MigrationManager is in preview mode then it just print the query
428+
instead of executing it.
428429
"""
429430
if self.preview:
430431
await self._print_query(query)
@@ -534,6 +535,39 @@ async def _run_alter_columns(self, backwards: bool = False):
534535

535536
###############################################################
536537

538+
on_delete = params.get("on_delete")
539+
on_update = params.get("on_update")
540+
if on_delete is not None or on_update is not None:
541+
existing_table = await self.get_table_from_snapshot(
542+
table_class_name=table_class_name,
543+
app_name=self.app_name,
544+
)
545+
546+
fk_column = existing_table._meta.get_column_by_name(
547+
alter_column.column_name
548+
)
549+
550+
assert isinstance(fk_column, ForeignKey)
551+
552+
# First drop the existing foreign key constraint
553+
constraint_name = await get_fk_constraint_name(
554+
column=fk_column
555+
)
556+
await self._run_query(
557+
_Table.alter().drop_constraint(
558+
constraint_name=constraint_name
559+
)
560+
)
561+
562+
# Then add a new foreign key constraint
563+
await self._run_query(
564+
_Table.alter().add_foreign_key_constraint(
565+
column=fk_column,
566+
on_delete=on_delete,
567+
on_update=on_update,
568+
)
569+
)
570+
537571
null = params.get("null")
538572
if null is not None:
539573
await self._run_query(

piccolo/apps/schema/commands/generate.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,8 @@ async def get_fk_triggers(
501501
Any Table subclass - just used to execute raw queries on the database.
502502
503503
"""
504+
# TODO - Move this query to `piccolo.query.constraints` or use:
505+
# `piccolo.query.constraints.referential_constraints`
504506
triggers = await table_class.raw(
505507
(
506508
"SELECT tc.constraint_name, "
@@ -537,23 +539,6 @@ async def get_fk_triggers(
537539
)
538540

539541

540-
ONDELETE_MAP = {
541-
"NO ACTION": OnDelete.no_action,
542-
"RESTRICT": OnDelete.restrict,
543-
"CASCADE": OnDelete.cascade,
544-
"SET NULL": OnDelete.set_null,
545-
"SET DEFAULT": OnDelete.set_default,
546-
}
547-
548-
ONUPDATE_MAP = {
549-
"NO ACTION": OnUpdate.no_action,
550-
"RESTRICT": OnUpdate.restrict,
551-
"CASCADE": OnUpdate.cascade,
552-
"SET NULL": OnUpdate.set_null,
553-
"SET DEFAULT": OnUpdate.set_default,
554-
}
555-
556-
557542
async def get_constraints(
558543
table_class: t.Type[Table], tablename: str, schema_name: str = "public"
559544
) -> TableConstraints:
@@ -765,8 +750,8 @@ async def create_table_class_from_db(
765750
column_name, constraint_table.name
766751
)
767752
if trigger:
768-
kwargs["on_update"] = ONUPDATE_MAP[trigger.on_update]
769-
kwargs["on_delete"] = ONDELETE_MAP[trigger.on_delete]
753+
kwargs["on_update"] = OnUpdate(trigger.on_update)
754+
kwargs["on_delete"] = OnDelete(trigger.on_delete)
770755
else:
771756
output_schema.trigger_warnings.append(
772757
f"{tablename}.{column_name}"

piccolo/query/constraints.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from dataclasses import dataclass
2+
3+
from piccolo.columns import ForeignKey
4+
from piccolo.columns.base import OnDelete, OnUpdate
5+
6+
7+
async def get_fk_constraint_name(column: ForeignKey) -> str:
8+
"""
9+
Checks what the foreign key constraint is called in the database.
10+
"""
11+
12+
table = column._meta.table
13+
14+
if table._meta.db.engine_type == "sqlite":
15+
# TODO - add the query for SQLite
16+
raise ValueError("SQLite isn't currently supported.")
17+
18+
schema = table._meta.schema or "public"
19+
table_name = table._meta.tablename
20+
column_name = column._meta.db_column_name
21+
22+
constraints = await table.raw(
23+
"""
24+
SELECT
25+
kcu.constraint_name AS fk_constraint_name
26+
FROM
27+
information_schema.referential_constraints AS rc
28+
INNER JOIN
29+
information_schema.key_column_usage AS kcu
30+
ON kcu.constraint_catalog = rc.constraint_catalog
31+
AND kcu.constraint_schema = rc.constraint_schema
32+
AND kcu.constraint_name = rc.constraint_name
33+
WHERE
34+
kcu.table_schema = {} AND
35+
kcu.table_name = {} AND
36+
kcu.column_name = {}
37+
""",
38+
schema,
39+
table_name,
40+
column_name,
41+
)
42+
43+
return constraints[0]["fk_constraint_name"]
44+
45+
46+
@dataclass
47+
class ConstraintRules:
48+
on_delete: OnDelete
49+
on_update: OnUpdate
50+
51+
52+
async def get_fk_constraint_rules(column: ForeignKey) -> ConstraintRules:
53+
"""
54+
Checks the constraint rules for this foreign key in the database.
55+
"""
56+
table = column._meta.table
57+
58+
if table._meta.db.engine_type == "sqlite":
59+
# TODO - add the query for SQLite
60+
raise ValueError("SQLite isn't currently supported.")
61+
62+
schema = table._meta.schema or "public"
63+
table_name = table._meta.tablename
64+
column_name = column._meta.db_column_name
65+
66+
constraints = await table.raw(
67+
"""
68+
SELECT
69+
kcu.constraint_name,
70+
kcu.table_name,
71+
kcu.column_name,
72+
rc.update_rule,
73+
rc.delete_rule
74+
FROM
75+
information_schema.key_column_usage AS kcu
76+
INNER JOIN
77+
information_schema.referential_constraints AS rc
78+
ON kcu.constraint_name = rc.constraint_name
79+
WHERE
80+
kcu.table_schema = {} AND
81+
kcu.table_name = {} AND
82+
kcu.column_name = {}
83+
""",
84+
schema,
85+
table_name,
86+
column_name,
87+
)
88+
89+
return ConstraintRules(
90+
on_delete=OnDelete(constraints[0]["delete_rule"]),
91+
on_update=OnUpdate(constraints[0]["update_rule"]),
92+
)

piccolo/query/methods/alter.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ def ddl(self) -> str:
3636
return f"RENAME TO {self.new_name}"
3737

3838

39+
@dataclass
40+
class RenameConstraint(AlterStatement):
41+
__slots__ = ("old_name", "new_name")
42+
43+
old_name: str
44+
new_name: str
45+
46+
@property
47+
def ddl(self) -> str:
48+
return f"RENAME CONSTRAINT {self.old_name} TO {self.new_name}"
49+
50+
3951
@dataclass
4052
class AlterColumnStatement(AlterStatement):
4153
__slots__ = ("column",)
@@ -194,16 +206,17 @@ class AddForeignKeyConstraint(AlterStatement):
194206
"constraint_name",
195207
"foreign_key_column_name",
196208
"referenced_table_name",
209+
"referenced_column_name",
197210
"on_delete",
198211
"on_update",
199212
)
200213

201214
constraint_name: str
202215
foreign_key_column_name: str
203216
referenced_table_name: str
217+
referenced_column_name: str
204218
on_delete: t.Optional[OnDelete]
205219
on_update: t.Optional[OnUpdate]
206-
referenced_column_name: str = "id"
207220

208221
@property
209222
def ddl(self) -> str:
@@ -273,8 +286,8 @@ def ddl(self) -> str:
273286

274287
class Alter(DDL):
275288
__slots__ = (
276-
"_add_foreign_key_constraint",
277289
"_add",
290+
"_add_foreign_key_constraint",
278291
"_drop_constraint",
279292
"_drop_default",
280293
"_drop_table",
@@ -288,6 +301,7 @@ class Alter(DDL):
288301
"_set_null",
289302
"_set_schema",
290303
"_set_unique",
304+
"_rename_constraint",
291305
)
292306

293307
def __init__(self, table: t.Type[Table], **kwargs):
@@ -307,6 +321,7 @@ def __init__(self, table: t.Type[Table], **kwargs):
307321
self._set_null: t.List[SetNull] = []
308322
self._set_schema: t.List[SetSchema] = []
309323
self._set_unique: t.List[SetUnique] = []
324+
self._rename_constraint: t.List[RenameConstraint] = []
310325

311326
def add_column(self: Self, name: str, column: Column) -> Self:
312327
"""
@@ -372,6 +387,24 @@ def rename_table(self, new_name: str) -> Alter:
372387
self._rename_table = [RenameTable(new_name=new_name)]
373388
return self
374389

390+
def rename_constraint(self, old_name: str, new_name: str) -> Alter:
391+
"""
392+
Rename a constraint on the table::
393+
394+
>>> await Band.alter().rename_constraint(
395+
... 'old_constraint_name',
396+
... 'new_constraint_name',
397+
... )
398+
399+
"""
400+
self._rename_constraint = [
401+
RenameConstraint(
402+
old_name=old_name,
403+
new_name=new_name,
404+
)
405+
]
406+
return self
407+
375408
def rename_column(
376409
self, column: t.Union[str, Column], new_name: str
377410
) -> Alter:
@@ -488,7 +521,7 @@ def set_length(self, column: t.Union[str, Varchar], length: int) -> Alter:
488521
def _get_constraint_name(self, column: t.Union[str, ForeignKey]) -> str:
489522
column_name = AlterColumnStatement(column=column).column_name
490523
tablename = self.table._meta.tablename
491-
return f"{tablename}_{column_name}_fk"
524+
return f"{tablename}_{column_name}_fkey"
492525

493526
def drop_constraint(self, constraint_name: str) -> Alter:
494527
self._drop_constraint.append(
@@ -500,37 +533,58 @@ def drop_foreign_key_constraint(
500533
self, column: t.Union[str, ForeignKey]
501534
) -> Alter:
502535
constraint_name = self._get_constraint_name(column=column)
503-
return self.drop_constraint(constraint_name=constraint_name)
536+
self._drop_constraint.append(
537+
DropConstraint(constraint_name=constraint_name)
538+
)
539+
return self
504540

505541
def add_foreign_key_constraint(
506542
self,
507543
column: t.Union[str, ForeignKey],
508-
referenced_table_name: str,
544+
referenced_table_name: t.Optional[str] = None,
545+
referenced_column_name: t.Optional[str] = None,
546+
constraint_name: t.Optional[str] = None,
509547
on_delete: t.Optional[OnDelete] = None,
510548
on_update: t.Optional[OnUpdate] = None,
511-
referenced_column_name: str = "id",
512549
) -> Alter:
513550
"""
514551
Add a new foreign key constraint::
515552
516553
>>> await Band.alter().add_foreign_key_constraint(
517554
... Band.manager,
518-
... referenced_table_name='manager',
519555
... on_delete=OnDelete.cascade
520556
... )
521557
522558
"""
523-
constraint_name = self._get_constraint_name(column=column)
559+
constraint_name = constraint_name or self._get_constraint_name(
560+
column=column
561+
)
524562
column_name = AlterColumnStatement(column=column).column_name
525563

564+
if referenced_column_name is None:
565+
if isinstance(column, ForeignKey):
566+
referenced_column_name = (
567+
column._foreign_key_meta.resolved_target_column._meta.db_column_name # noqa: E501
568+
)
569+
else:
570+
raise ValueError("Please pass in `referenced_column_name`.")
571+
572+
if referenced_table_name is None:
573+
if isinstance(column, ForeignKey):
574+
referenced_table_name = (
575+
column._foreign_key_meta.resolved_references._meta.tablename # noqa: E501
576+
)
577+
else:
578+
raise ValueError("Please pass in `referenced_table_name`.")
579+
526580
self._add_foreign_key_constraint.append(
527581
AddForeignKeyConstraint(
528582
constraint_name=constraint_name,
529583
foreign_key_column_name=column_name,
530584
referenced_table_name=referenced_table_name,
585+
referenced_column_name=referenced_column_name,
531586
on_delete=on_delete,
532587
on_update=on_update,
533-
referenced_column_name=referenced_column_name,
534588
)
535589
)
536590
return self
@@ -579,9 +633,12 @@ def default_ddl(self) -> t.Sequence[str]:
579633
i.ddl
580634
for i in itertools.chain(
581635
self._add,
636+
self._add_foreign_key_constraint,
582637
self._rename_columns,
583638
self._rename_table,
639+
self._rename_constraint,
584640
self._drop,
641+
self._drop_constraint,
585642
self._drop_default,
586643
self._set_column_type,
587644
self._set_unique,

piccolo/utils/sync.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22

33
import asyncio
44
import typing as t
5-
from concurrent.futures import ThreadPoolExecutor
5+
from concurrent.futures import Future, ThreadPoolExecutor
66

7+
ReturnType = t.TypeVar("ReturnType")
78

8-
def run_sync(coroutine: t.Coroutine):
9+
10+
def run_sync(
11+
coroutine: t.Coroutine[t.Any, t.Any, ReturnType],
12+
) -> ReturnType:
913
"""
1014
Run the coroutine synchronously - trying to accommodate as many edge cases
1115
as possible.
@@ -20,5 +24,5 @@ def run_sync(coroutine: t.Coroutine):
2024
except RuntimeError:
2125
# An event loop already exists.
2226
with ThreadPoolExecutor(max_workers=1) as executor:
23-
future = executor.submit(asyncio.run, coroutine)
27+
future: Future = executor.submit(asyncio.run, coroutine)
2428
return future.result()

0 commit comments

Comments
 (0)