1
1
from django .apps .registry import Apps
2
+ from django .db import DatabaseError
2
3
from django .db import models as django_models
3
4
from django .db .migrations import Migration
4
- from django .db .migrations .exceptions import IrreversibleError
5
+ from django .db .migrations .exceptions import IrreversibleError , MigrationSchemaMissing
5
6
from django .db .migrations .operations .fields import FieldOperation
6
7
from django .db .migrations .operations .models import (
7
8
DeleteModel ,
15
16
__all__ = ["patch_migrations" , "patch_migration_recorder" , "patch_migration" ]
16
17
17
18
19
+ def _should_distribute_migrations (connection ):
20
+ """
21
+ Check if the connection is configured for distributed migrations.
22
+ """
23
+ return getattr (connection , "distributed_migrations" , False ) and getattr (
24
+ connection , "migration_cluster" , None
25
+ )
26
+
27
+
28
+ def _get_model_table_name (connection ):
29
+ """
30
+ Return the name of the table that will be used by the MigrationRecorder.
31
+ If distributed migrations are enabled, return the distributed table name.
32
+ Otherwise, return the regular django_migrations table name.
33
+ """
34
+ if _should_distribute_migrations (connection ):
35
+ return "distributed_django_migrations"
36
+ return "django_migrations"
37
+
38
+
39
+ def _check_replicas (connection ):
40
+ """
41
+ Check if the connection has replicas configured for the migration cluster.
42
+ """
43
+ if hasattr (connection , "has_replicas" ):
44
+ return connection .has_replicas
45
+
46
+ with connection .cursor () as cursor :
47
+ cursor .execute (
48
+ f"select replica_num from system.clusters where cluster={ connection .migration_cluster } "
49
+ )
50
+ (replica_count ,) = cursor .fetchone ()
51
+ return replica_count >= 1
52
+
53
+
18
54
def patch_migrations ():
19
55
patch_migration_recorder ()
20
56
patch_migration ()
@@ -29,22 +65,75 @@ def Migration(self):
29
65
if self ._migration_class is None :
30
66
if self .connection .vendor == "clickhouse" :
31
67
from clickhouse_backend import models
68
+ from clickhouse_backend .models import currentDatabase
32
69
33
- class Migration ( models . ClickhouseModel ):
34
- app = models . StringField ( max_length = 255 )
35
- name = models . StringField ( max_length = 255 )
36
- applied = models . DateTime64Field ( default = now )
37
- deleted = models . BoolField ( default = False )
70
+ # Only create a distributed migration model if the connection
71
+ # has distributed migrations enabled and a migration cluster is set.
72
+ # otherwise, create a regular merge tree.
73
+ if _should_distribute_migrations ( self . connection ):
74
+ has_replicas = _check_replicas ( self . connection )
38
75
39
- class Meta :
40
- apps = Apps ()
41
- app_label = "migrations"
42
- db_table = "django_migrations"
43
- engine = models .MergeTree (order_by = ("app" , "name" ))
44
- cluster = getattr (self .connection , "migration_cluster" , None )
76
+ Engine = models .MergeTree
77
+ if has_replicas :
78
+ Engine = models .ReplicatedMergeTree
45
79
46
- def __str__ (self ):
47
- return "Migration %s for %s" % (self .name , self .app )
80
+ self .connection .has_replicas = has_replicas
81
+
82
+ class _Migration (models .ClickhouseModel ):
83
+ app = models .StringField (max_length = 255 )
84
+ name = models .StringField (max_length = 255 )
85
+ applied = models .DateTime64Field (default = now )
86
+ deleted = models .BoolField (default = False )
87
+
88
+ class Meta :
89
+ apps = Apps ()
90
+ app_label = "migrations"
91
+ db_table = "django_migrations"
92
+ engine = Engine (order_by = ("app" , "name" ))
93
+ cluster = self .connection .migration_cluster
94
+
95
+ def __str__ (self ):
96
+ return "Migration %s for %s" % (self .name , self .app )
97
+
98
+ class Migration (models .ClickhouseModel ):
99
+ app = models .StringField (max_length = 255 )
100
+ name = models .StringField (max_length = 255 )
101
+ applied = models .DateTime64Field (default = now )
102
+ deleted = models .BoolField (default = False )
103
+
104
+ class Meta :
105
+ apps = Apps ()
106
+ app_label = "migrations"
107
+ db_table = _get_model_table_name (self .connection )
108
+ engine = models .Distributed (
109
+ self .connection .migration_cluster ,
110
+ currentDatabase (),
111
+ _Migration ._meta .db_table ,
112
+ models .Rand (),
113
+ )
114
+ cluster = self .connection .migration_cluster
115
+
116
+ Migration ._meta .local_model_class = _Migration
117
+
118
+ else :
119
+
120
+ class Migration (models .ClickhouseModel ):
121
+ app = models .StringField (max_length = 255 )
122
+ name = models .StringField (max_length = 255 )
123
+ applied = models .DateTime64Field (default = now )
124
+ deleted = models .BoolField (default = False )
125
+
126
+ class Meta :
127
+ apps = Apps ()
128
+ app_label = "migrations"
129
+ db_table = _get_model_table_name (self .connection )
130
+ engine = models .MergeTree (order_by = ("app" , "name" ))
131
+ cluster = getattr (
132
+ self .connection , "migration_cluster" , None
133
+ )
134
+
135
+ def __str__ (self ):
136
+ return "Migration %s for %s" % (self .name , self .app )
48
137
49
138
else :
50
139
@@ -69,15 +158,45 @@ def has_table(self):
69
158
# Assert migration table won't be deleted once created.
70
159
if not getattr (self , "_has_table" , False ):
71
160
with self .connection .cursor () as cursor :
161
+ table = self .Migration ._meta .db_table
72
162
tables = self .connection .introspection .table_names (cursor )
73
- self ._has_table = self . Migration . _meta . db_table in tables
163
+ self ._has_table = table in tables
74
164
if self ._has_table and self .connection .vendor == "clickhouse" :
75
165
# fix https://github.com/jayvynl/django-clickhouse-backend/issues/51
76
166
cursor .execute (
77
- "ALTER table django_migrations ADD COLUMN IF NOT EXISTS deleted Bool"
167
+ f "ALTER table { table } ADD COLUMN IF NOT EXISTS deleted Bool"
78
168
)
79
169
return self ._has_table
80
170
171
+ def ensure_schema (self ):
172
+ """Ensure the table exists and has the correct schema."""
173
+ # If the table's there, that's fine - we've never changed its schema
174
+ # in the codebase.
175
+ if self .has_table ():
176
+ return
177
+
178
+ # In case of distributed migrations, we need to ensure the local model exists first and
179
+ # then create the distributed model.
180
+ try :
181
+ with self .connection .schema_editor () as editor :
182
+ if (
183
+ editor .connection .vendor == "clickhouse"
184
+ and _should_distribute_migrations (editor .connection )
185
+ ):
186
+ with editor .connection .cursor () as cursor :
187
+ tables = editor .connection .introspection .table_names (cursor )
188
+ local_model_class = self .Migration ._meta .local_model_class
189
+ local_table = local_model_class ._meta .db_table
190
+ if local_table not in tables :
191
+ # Create the local model first
192
+ editor .create_model (self .Migration ._meta .local_model_class )
193
+
194
+ editor .create_model (self .Migration )
195
+ except DatabaseError as exc :
196
+ raise MigrationSchemaMissing (
197
+ "Unable to create the django_migrations table (%s)" % exc
198
+ )
199
+
81
200
def migration_qs (self ):
82
201
if self .connection .vendor == "clickhouse" :
83
202
return self .Migration .objects .using (self .connection .alias ).filter (
@@ -118,6 +237,7 @@ def flush(self):
118
237
119
238
MigrationRecorder .Migration = property (Migration )
120
239
MigrationRecorder .has_table = has_table
240
+ MigrationRecorder .ensure_schema = ensure_schema
121
241
MigrationRecorder .migration_qs = property (migration_qs )
122
242
MigrationRecorder .record_applied = record_applied
123
243
MigrationRecorder .record_unapplied = record_unapplied
@@ -136,13 +256,15 @@ def apply(self, project_state, schema_editor, collect_sql=False):
136
256
"""
137
257
applied_on_remote = False
138
258
if getattr (schema_editor .connection , "migration_cluster" , None ):
259
+ _table = _get_model_table_name (schema_editor .connection )
260
+
139
261
with schema_editor .connection .cursor () as cursor :
140
262
cursor .execute (
141
263
"select EXISTS(select 1 from clusterAllReplicas(%s, currentDatabase(), %s)"
142
264
" where app=%s and name=%s and deleted=false)" ,
143
265
[
144
266
schema_editor .connection .migration_cluster ,
145
- "django_migrations" ,
267
+ _table ,
146
268
self .app_label ,
147
269
self .name ,
148
270
],
@@ -203,13 +325,15 @@ def unapply(self, project_state, schema_editor, collect_sql=False):
203
325
"""
204
326
unapplied_on_remote = False
205
327
if getattr (schema_editor .connection , "migration_cluster" , None ):
328
+ _table = _get_model_table_name (schema_editor .connection )
329
+
206
330
with schema_editor .connection .cursor () as cursor :
207
331
cursor .execute (
208
332
"select EXISTS(select 1 from clusterAllReplicas(%s, currentDatabase(), %s)"
209
333
" where app=%s and name=%s and deleted=true)" ,
210
334
[
211
335
schema_editor .connection .migration_cluster ,
212
- "django_migrations" ,
336
+ _table ,
213
337
self .app_label ,
214
338
self .name ,
215
339
],
0 commit comments