|
| 1 | +import re |
| 2 | + |
| 3 | +from django.core.management.commands.inspectdb import Command as DCommand |
| 4 | +from django.db import connections |
| 5 | + |
| 6 | +from clickhouse_backend import compat, models |
| 7 | +from clickhouse_backend.utils.encoding import ensure_str |
| 8 | + |
| 9 | + |
| 10 | +class Command(DCommand): |
| 11 | + db_module = "clickhouse_backend" |
| 12 | + |
def handle_inspection(self, options):
    """Generate the source lines of a models module for one database.

    Introspects the tables (and, when ``include_views`` is set, the views)
    reported by the connection named in ``options["database"]`` and yields,
    line by line, one ``models.ClickhouseModel`` subclass per table.

    Args:
        options: Parsed command options. Reads ``database``, ``table``,
            ``include_views`` and the optional stealth option
            ``table_name_filter`` (a callable that receives a table name
            and returns a truthy value to keep the table).

    Yields:
        str: Successive lines of the generated module.
    """
    connection = connections[options["database"]]
    introspection = connection.introspection
    # 'table_name_filter' is a stealth option
    table_name_filter = options.get("table_name_filter")

    def table2model(table_name):
        # "my_table" -> "MyTable": title-case, then drop non-alphanumerics.
        return re.sub(r"[^a-zA-Z0-9]", "", table_name.title())

    with connection.cursor() as cursor:
        yield "# This is an auto-generated Django model module."
        yield "# You'll have to do the following manually to clean this up:"
        yield "# * Rearrange models' order"
        yield "# * Make sure each model has one field with primary_key=True"
        yield (
            "# * Remove `managed = False` lines if you wish to allow "
            "Django to create, modify, and delete the table"
        )
        yield (
            "# Feel free to rename the models, but don't rename db_table values or "
            "field names."
        )
        yield "from %s import models" % self.db_module

        # Column comments, table comments and the ``comment`` argument of
        # get_meta() are all guarded by the same version flag, so that the
        # three places that deal with comments cannot disagree.
        with_comments = compat.dj_ge42 and connection.features.supports_comments

        # Determine types of tables and/or views to be introspected.
        types = {"t"}
        if options["include_views"]:
            types.add("v")
        table_info = {
            info.name: info
            for info in introspection.get_table_list(cursor)
            if info.type in types
        }

        for table_name in options["table"] or sorted(table_info):
            if callable(table_name_filter) and not table_name_filter(table_name):
                continue
            try:
                try:
                    constraints = introspection.get_constraints(cursor, table_name)
                except NotImplementedError:
                    constraints = {}
                table_description = introspection.get_table_description(
                    cursor, table_name
                )
            except Exception as e:
                # Deliberately broad: one table that cannot be inspected is
                # reported in the output and must not abort the whole run.
                yield "# Unable to inspect table '%s'" % table_name
                yield "# The error was: %s" % e
                continue

            yield ""
            yield ""
            yield "class %s(models.ClickhouseModel):" % table2model(table_name)
            used_column_names = []  # Holds column names used in the table so far
            column_to_field_name = {}  # Maps column names to names of model fields

            for row in table_description:
                column_name = row.name

                (
                    att_name,
                    extra_params,  # Holds Field parameters such as 'db_column'.
                    comment_notes,  # Holds Field notes, to be displayed in a Python comment.
                ) = self.normalize_col_name(column_name, used_column_names, False)

                used_column_names.append(att_name)
                column_to_field_name[column_name] = att_name

                # Add comment.
                if with_comments and row.comment:
                    extra_params["db_comment"] = row.comment

                # An empty dict joins to "", i.e. a field without arguments.
                param = ", ".join(f"{k}={v!r}" for k, v in extra_params.items())

                field_define = "".join(self.inspect_field_type(row.type_code, param))
                field_desc = f"{att_name} = {field_define}"
                if comment_notes:
                    field_desc += " # " + " ".join(comment_notes)
                # Four-space indent so the field sits inside the class body.
                yield "    %s" % field_desc

            comment = None
            is_view = False
            if info := table_info.get(table_name):
                is_view = info.type == "v"
                if with_comments:
                    comment = info.comment

            # get_meta() accepts the trailing ``comment`` argument only on
            # Django >= 4.2; older versions take five positional arguments.
            meta_args = [table_name, constraints, column_to_field_name, is_view, False]
            if compat.dj_ge42:
                meta_args.append(comment)
            yield from self.get_meta(*meta_args)
| 119 | + |
def inspect_field_type(self, column_type, param=""):
    """Render the leading type of ``column_type`` as field source code.

    This is a generator. The yielded fragments, joined together, form the
    ``models.<Name>Field(...)`` expression for the first type found in
    ``column_type``. The generator's return value (what ``yield from``
    evaluates to) is the text of ``column_type`` that follows that type,
    which is how wrapper and container types parse their inner types.

    Args:
        column_type: Type text, passed through ``ensure_str`` first.
        param: Already rendered keyword arguments to put inside the
            field's parentheses.
    """
    type_str = ensure_str(column_type)

    if type_str.startswith("LowCardinality"):  # LowCardinality(Int16)
        inner = type_str[len("LowCardinality("):]
        rest = yield from self.inspect_field_type(
            inner, self.merge_params(param, "low_cardinality=True")
        )
        return rest[1:]  # drop the closing ")"

    if type_str.startswith("Nullable"):  # Nullable(Int16)
        inner = type_str[len("Nullable("):]
        rest = yield from self.inspect_field_type(
            inner, self.merge_params(param, "null=True", "blank=True")
        )
        return rest[1:]

    if type_str.startswith("FixedString"):  # FixedString(20)
        start = len("FixedString(")
        end = start
        while type_str[end].isdigit():
            end += 1
        kwargs = self.merge_params(param, f"max_bytes={type_str[start:end]}")
        yield f"models.FixedStringField({kwargs})"
        return type_str[end + 1 :]

    if type_str.startswith("DateTime64"):  # DateTime64(6, 'UTC') or DateTime64(9)
        digit_at = len("DateTime64(")
        precision = type_str[digit_at]
        kwargs = param
        if int(precision) != models.DateTime64Field.DEFAULT_PRECISION:
            kwargs = self.merge_params(param, f"precision={precision}")
        yield f"models.DateTime64Field({kwargs})"

        if type_str[digit_at + 1] != ",":
            return type_str[digit_at + 2 :]
        # Skip ", '" and scan to the quote that closes the timezone name.
        pos = digit_at + 4
        while type_str[pos] != "'":
            pos += 1
        return type_str[pos + 2 :]

    if type_str.startswith("DateTime"):  # DateTime('UTC') or DateTime
        name_len = len("DateTime")
        yield f"models.DateTimeField({param})"
        if type_str[name_len : name_len + 1] != "(":
            return type_str[name_len:]
        pos = name_len + 2
        while type_str[pos] != "'":
            pos += 1
        return type_str[pos + 2 :]

    if type_str.startswith("Decimal"):  # Decimal(9, 3)
        digits_from = len("Decimal(")
        pos = digits_from
        while type_str[pos].isdigit():
            pos += 1
        max_digits = f"max_digits={type_str[digits_from:pos]}"
        pos += 2  # step over ", "
        places_from = pos
        while type_str[pos].isdigit():
            pos += 1
        decimal_places = f"decimal_places={type_str[places_from:pos]}"
        kwargs = self.merge_params(param, max_digits, decimal_places)
        yield f"models.DecimalField({kwargs})"
        return type_str[pos + 1 :]

    if type_str.startswith("Enum"):  # Enum8('a' = 1, 'b' = 2)
        pos = len("Enum")
        while type_str[pos].isdigit():
            pos += 1
        enum_name = type_str[:pos]
        pairs = []
        rest = type_str[pos:]
        skip = 1  # "(" before the first member, ", " before the others
        while True:
            label, number, rest = self.consume_enum_choice(rest[skip:])
            pairs.append(f"({number}, {label})")
            if rest[0] == ")":
                break
            skip = 2
        kwargs = self.merge_params(param, f"choices=[{', '.join(pairs)}]")
        yield f"models.{enum_name}Field({kwargs})"
        return rest[1:]

    if type_str.startswith("Array"):  # Array(Tuple(String, Enum8('a' = 1, 'b' = 2)))
        yield "models.ArrayField("
        rest = yield from self.inspect_field_type(type_str[len("Array("):])
        if param:
            yield f", {param}"
        yield ")"
        return rest[1:]

    if type_str.startswith("Tuple"):  # Tuple(String, Enum8('a' = 1, 'b' = 2))
        yield "models.TupleField(["
        rest = yield from self.inspect_field_type(type_str[len("Tuple("):])
        while rest[0] == ",":
            yield ", "
            rest = yield from self.inspect_field_type(rest[2:])
        yield "]"
        if param:
            yield f", {param}"
        yield ")"
        return rest[1:]

    if type_str.startswith("Map"):  # Map(String, Int8)
        yield "models.MapField("
        after_key = yield from self.inspect_field_type(type_str[len("Map("):])
        yield ", "
        rest = yield from self.inspect_field_type(after_key[2:])
        if param:
            yield f", {param}"
        yield ")"
        return rest[1:]

    if type_str.startswith("Object('json')"):
        yield f"models.JSONField({param})"
        return type_str[len("Object('json')"):]

    # Anything else: the leading alphanumeric run is used as the field name.
    end = next(
        (idx for idx, ch in enumerate(type_str) if not ch.isalnum()),
        len(type_str),
    )
    yield f"models.{type_str[:end]}Field({param})"
    return type_str[end:]
| 227 | + |
def consume_enum_choice(self, s):  # 'a' = 1
    """Parse one ``'name' = value`` pair from the front of ``s``.

    Args:
        s: Text that starts at the opening quote of an enum member, for
            example ``"'a' = 1, 'b' = 2)"``.

    Returns:
        tuple[str, str, str]: ``(name, value, remain)``.

        * ``name`` is Python source for the member name. It is the quoted
          text as found, unless the text contains ``\\x`` escapes: then it
          is the ``repr`` of the UTF-8 decoded string, or a ``b'...'``
          literal when the bytes are not valid UTF-8.
        * ``value`` is the integer as text, including a leading ``-`` for
          negative members.
        * ``remain`` is the part of ``s`` after the value.
    """
    import ast  # local import: only needed for the byte-escape case

    has_bytes = False
    i = 1  # position 0 is the opening quote
    while s[i] != "'":
        if s[i] == "\\":  # escape char: skip it and the escaped character
            if s[i + 1] == "x":
                has_bytes = True
            i += 2
        else:
            i += 1
    i += 1  # include the closing quote
    name = s[:i]

    # try decoding bytes to utf8 string.
    if has_bytes:
        try:
            # literal_eval only accepts literals, so the quoted text can
            # never be executed as code.
            decoded = ast.literal_eval(f"b{name}").decode("utf-8")
        except UnicodeDecodeError:
            name = f"b{name}"
        else:
            name = repr(decoded)

    i += 3  # skip " = "
    j = i
    if s[i : i + 1] == "-":  # enum values are signed integers
        i += 1
    while i < len(s) and s[i].isdigit():
        i += 1
    value = s[j:i]
    return name, value, s[i:]
| 257 | + |
def merge_params(self, *params):
    """Join the non-empty items of ``params`` with ``", "``.

    Falsy items (such as ``""``) are dropped, so no stray separators
    appear in the result.
    """
    kept = [item for item in params if item]
    return ", ".join(kept)
0 commit comments