diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index 21eadb63..883cb713 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -128,9 +128,14 @@ def __init__( options: Sequence[str], *, indentation: str = " ", + dynamic_schema_import_path: str | None = None, + dynamic_schema_value: str | None = None, ): super().__init__(metadata, bind, options) self.indentation: str = indentation + # TODO add check if there is a "." in the value if set? + self.dynamic_schema_import_path: str | None = dynamic_schema_import_path + self.dynamic_schema_value: str | None = dynamic_schema_value self.imports: dict[str, set[str]] = defaultdict(set) self.module_imports: set[str] = set() @@ -197,6 +202,8 @@ def collect_imports(self, models: Iterable[Model]) -> None: for model in models: self.collect_imports_for_model(model) + if self.dynamic_schema_import_path: + self.add_literal_import(*self.dynamic_schema_import_path.rsplit(".", 1)) def collect_imports_for_model(self, model: Model) -> None: if model.__class__ is Model: @@ -374,7 +381,9 @@ def render_table(self, table: Table) -> str: if len(index.columns) > 1 or not uses_default_name(index): args.append(self.render_index(index)) - if table.schema: + if self.dynamic_schema_value: + kwargs["schema"] = self.dynamic_schema_value + elif table.schema: kwargs["schema"] = repr(table.schema) table_comment = getattr(table, "comment", None) @@ -722,9 +731,18 @@ def __init__( options: Sequence[str], *, indentation: str = " ", + dynamic_schema_import_path: str | None = None, + dynamic_schema_value: str | None = None, base_class_name: str = "Base", ): - super().__init__(metadata, bind, options, indentation=indentation) + super().__init__( + metadata, + bind, + options, + indentation=indentation, + dynamic_schema_import_path=dynamic_schema_import_path, + dynamic_schema_value=dynamic_schema_value, + ) self.base_class_name: str = base_class_name self.inflect_engine = inflect.engine() @@ -1159,14 +1177,23 @@ def render_table_args(self, table: Table) -> str: if len(index.columns) > 1 or not uses_default_name(index): args.append(self.render_index(index)) - if table.schema: + if self.dynamic_schema_value: + kwargs["schema"] = self.dynamic_schema_value + elif table.schema: kwargs["schema"] = table.schema if table.comment: kwargs["comment"] = table.comment if kwargs: - formatted_kwargs = pformat(kwargs) + # NB: using pformat on the dict turns schema value (python code) to a string + formatted_kwargs = f",\n{self.indentation}".join( + f"'{k}': {pformat(v)}" + if v != self.dynamic_schema_value + else f"'{k}': {v}" + for k, v in kwargs.items() + ) + formatted_kwargs = f"{{{formatted_kwargs}}}" if not args: return formatted_kwargs else: @@ -1309,6 +1336,8 @@ def __init__( options: Sequence[str], *, indentation: str = " ", + dynamic_schema_import_path: str | None = None, + dynamic_schema_value: str | None = None, base_class_name: str = "Base", quote_annotations: bool = False, metadata_key: str = "sa", @@ -1318,6 +1347,8 @@ def __init__( bind, options, indentation=indentation, + dynamic_schema_import_path=dynamic_schema_import_path, + dynamic_schema_value=dynamic_schema_value, base_class_name=base_class_name, ) self.metadata_key: str = metadata_key @@ -1348,6 +1379,8 @@ def __init__( options: Sequence[str], *, indentation: str = " ", + dynamic_schema_import_path: str | None = None, + dynamic_schema_value: str | None = None, base_class_name: str = "SQLModel", ): super().__init__( @@ -1355,9 +1388,89 @@ def __init__( bind, options, indentation=indentation, + dynamic_schema_import_path=dynamic_schema_import_path, + dynamic_schema_value=dynamic_schema_value, base_class_name=base_class_name, ) + def generate_models(self) -> list[Model]: + models_by_table_name: dict[str, Model] = {} + + # Pick association tables from the metadata into their own set, don't process + # them normally + links: defaultdict[str, list[Model]] = defaultdict(lambda: []) + for table in self.metadata.sorted_tables: + qualified_name = qualified_table_name(table) + + # Link tables have exactly two foreign key constraints and all columns are + # involved in them + fk_constraints = sorted( + table.foreign_key_constraints, key=get_constraint_sort_key + ) + if len(fk_constraints) == 2 and all( + col.foreign_keys for col in table.columns + ): + model = models_by_table_name[qualified_name] = Model(table) + tablename = fk_constraints[0].elements[0].column.table.name + links[tablename].append(model) + continue + + # Only form model classes for tables that have a primary key and are not + # association tables + if not table.primary_key: + models_by_table_name[qualified_name] = Model(table) + else: + model = ModelClass(table) + models_by_table_name[qualified_name] = model + + # Fill in the columns + for column in table.c: + column_attr = ColumnAttribute(model, column) + model.columns.append(column_attr) + + # Add relationships + for model in models_by_table_name.values(): + if isinstance(model, ModelClass): + self.generate_relationships( + model, models_by_table_name, links[model.table.name] + ) + + # Nest inherited classes in their superclasses to ensure proper ordering + if "nojoined" not in self.options: + for model in list(models_by_table_name.values()): + if not isinstance(model, ModelClass): + continue + + pk_column_names = {col.name for col in model.table.primary_key.columns} + for constraint in model.table.foreign_key_constraints: + if set(get_column_names(constraint)) == pk_column_names: + target = models_by_table_name[ + qualified_table_name(constraint.elements[0].column.table) + ] + if isinstance(target, ModelClass): + model.parent_class = target + target.children.append(model) + + # Change base if we have both tables and model classes + if any( + not isinstance(model, ModelClass) for model in models_by_table_name.values() + ): + TablesGenerator.generate_base(self) + + # Collect the imports + self.collect_imports(models_by_table_name.values()) + + # Rename models and their attributes that conflict with imports or other + # attributes + global_names = { + name for namespace in self.imports.values() for name in namespace + } + for model in models_by_table_name.values(): + self.generate_model_name(model, global_names) + global_names.add(model.name) + + return list(models_by_table_name.values()) + def generate_base(self) -> None: self.base = Base( literal_imports=[], @@ -1368,7 +1481,6 @@ def generate_base(self) -> None: def collect_imports(self, models: Iterable[Model]) -> None: super(DeclarativeGenerator, self).collect_imports(models) if any(isinstance(model, ModelClass) for model in models): - self.remove_literal_import("sqlalchemy", "MetaData") self.add_literal_import("sqlmodel", "SQLModel") self.add_literal_import("sqlmodel", "Field") @@ -1400,7 +1512,7 @@ def collect_imports_for_column(self, column: Column[Any]) -> None: self.add_import(python_type) def render_module_variables(self, models: list[Model]) -> str: - declarations: list[str] = [] + declarations: list[str] = self.base.declarations if any(not isinstance(model, ModelClass) for model in models): if self.base.table_metadata_declaration is not None: declarations.append(self.base.table_metadata_declaration) @@ -1446,7 +1558,7 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str: kwargs["default"] = None python_type_name = f"Optional[{python_type_name}]" - rendered_column = self.render_column(column, True) + rendered_column = self.render_column(column, True, is_table=True) kwargs["sa_column"] = f"{rendered_column}" rendered_field = render_callable("Field", kwargs=kwargs) return f"{column_attr.name}: {python_type_name} = {rendered_field}" diff --git a/tests/conftest.py b/tests/conftest.py index 022e786c..ae3822d4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from textwrap import dedent import pytest @@ -31,3 +32,12 @@ def validate_code(generated_code: str, expected_code: str) -> None: configure_mappers() finally: clear_mappers() + + +@dataclass +class SchemaObject: + name: str + + +# NB: not a fixture on purpose +schema_obj = SchemaObject(name="best_schema") diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index d9bf7b53..7ce25c89 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -1,5 +1,7 @@ from __future__ import annotations +from textwrap import dedent + import pytest from _pytest.fixtures import FixtureRequest from sqlalchemy import PrimaryKeyConstraint @@ -30,6 +32,20 @@ def generator( return DeclarativeGenerator(metadata, engine, options) +@pytest.fixture +def generator_dynamic_schema( + request: FixtureRequest, metadata: MetaData, engine: Engine +) -> CodeGenerator: + schema_import_path, schema_value = getattr(request, "param", (None, None)) + return DeclarativeGenerator( + metadata, + engine, + [], + dynamic_schema_import_path=schema_import_path, + dynamic_schema_value=schema_value, + ) + + def test_indexes(generator: CodeGenerator) -> None: simple_items = Table( "simple_items", @@ -1509,3 +1525,37 @@ class Simple(Base): server_default=text("'test'")) """, ) + + +@pytest.mark.parametrize( + "generator_dynamic_schema", + [[".conftest.schema_obj", "schema_obj.name"]], + indirect=True, +) +def test_use_dynamic_schema(generator_dynamic_schema: CodeGenerator) -> None: + Table( + "simple_items", + generator_dynamic_schema.metadata, + Column("id", INTEGER, primary_key=True), + ) + + expected_code = """\ +from .conftest import schema_obj +from sqlalchemy import Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = {'schema': schema_obj.name} + + id: Mapped[int] = mapped_column(Integer, primary_key=True) +""" + generated_code = generator_dynamic_schema.generate() + expected_code = dedent(expected_code) + assert generated_code == expected_code + # TODO: code execution fails with KeyError: "'__name__' not in globals", any idea? + # validate_code(generated_code, expected_code)