|
20 | 20 | import argparse |
21 | 21 | import sys |
22 | 22 |
|
23 | | -from sqlalchemy import create_engine, inspect |
| 23 | +from sqlalchemy import String, create_engine, inspect, text |
24 | 24 | from sqlalchemy.orm import Session |
25 | 25 |
|
26 | 26 | from microSALT.store.orm_models import ( |
|
50 | 50 | ] |
51 | 51 |
|
52 | 52 |
|
| 53 | +def _widen_varchar_columns(dst_engine) -> None: |
| 54 | + """ALTER any MySQL VARCHAR column that is narrower than the ORM definition. |
| 55 | +
|
| 56 | + Called after create_all so that tables are guaranteed to exist. |
| 57 | + Handles the case where the schema was already created with an older, |
| 58 | + narrower column definition. |
| 59 | + """ |
| 60 | + dst_inspector = inspect(dst_engine) |
| 61 | + existing_tables = dst_inspector.get_table_names() |
| 62 | + |
| 63 | + with dst_engine.connect() as conn: |
| 64 | + for model in TABLES: |
| 65 | + table_name = model.__tablename__ |
| 66 | + if table_name not in existing_tables: |
| 67 | + continue |
| 68 | + |
| 69 | + actual_cols = {c["name"]: c for c in dst_inspector.get_columns(table_name)} |
| 70 | + |
| 71 | + for col in inspect(model).mapper.columns: |
| 72 | + if not isinstance(col.type, String): |
| 73 | + continue |
| 74 | + orm_len = col.type.length |
| 75 | + if orm_len is None: |
| 76 | + continue |
| 77 | + actual = actual_cols.get(col.name) |
| 78 | + if actual is None: |
| 79 | + continue |
| 80 | + actual_len = getattr(actual["type"], "length", None) |
| 81 | + if actual_len is not None and actual_len < orm_len: |
| 82 | + print( |
| 83 | + f" Widening {table_name}.{col.name}: " |
| 84 | + f"VARCHAR({actual_len}) → VARCHAR({orm_len})" |
| 85 | + ) |
| 86 | + conn.execute( |
| 87 | + text( |
| 88 | + f"ALTER TABLE `{table_name}` MODIFY COLUMN" |
| 89 | + f" `{col.name}` VARCHAR({orm_len})" |
| 90 | + ) |
| 91 | + ) |
| 92 | + conn.commit() |
| 93 | + |
| 94 | + |
53 | 95 | def _columns(model: type) -> list[str]: |
54 | 96 | """Return the list of column attribute names for an ORM model.""" |
55 | 97 | return [c.key for c in inspect(model).mapper.column_attrs] |
@@ -132,9 +174,11 @@ def main() -> int: |
132 | 174 | src_engine = create_engine(sqlite_uri, pool_pre_ping=True) |
133 | 175 | dst_engine = create_engine(args.mysql, pool_pre_ping=True) |
134 | 176 |
|
135 | | - # Ensure all ORM tables exist in the destination |
| 177 | + # Ensure all ORM tables exist in the destination, then widen any columns |
| 178 | + # that are narrower in MySQL than the current ORM definition. |
136 | 179 | if not args.dry_run: |
137 | 180 | Base.metadata.create_all(dst_engine) |
| 181 | + _widen_varchar_columns(dst_engine) |
138 | 182 |
|
139 | 183 | total_inserted = 0 |
140 | 184 | total_skipped = 0 |
|
0 commit comments