diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c62b7a8e..e4f6eca3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,6 +23,8 @@ jobs: cache-dependency-path: pyproject.toml - name: Install dependencies run: pip install -e .[test] + - name: Install sqlmodel dependency + run: pip install -e .[sqlmodel] - name: Test with pytest run: coverage run -m pytest - name: Upload Coverage diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index 21eadb63..6edcb530 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -292,8 +292,21 @@ def group_imports(self) -> list[list[str]]: stdlib_imports: list[str] = [] thirdparty_imports: list[str] = [] + has_imports: set[str] = set() for package in sorted(self.imports): - imports = ", ".join(sorted(self.imports[package])) + if duplicate_import := self.imports[package] & has_imports: + print( + f"WARN: Duplicate imports `{duplicate_import}` are detected " + f"from the package `{package}` and will be filtered, " + f"which may cause abnormal behavior." + ) + + current_import = sorted( + name for name in self.imports[package] if name not in duplicate_import + ) + has_imports = has_imports | set(current_import) + + imports = ", ".join(current_import) collection = thirdparty_imports if package == "__future__": collection = future_imports @@ -448,7 +461,11 @@ def render_column( kwargs["key"] = column.key if is_primary: kwargs["primary_key"] = True - if not column.nullable and not is_sole_pk and is_table: + if ( + not column.nullable + and not is_sole_pk + and (is_table or isinstance(self, SQLModelGenerator)) + ): kwargs["nullable"] = False if is_unique: @@ -482,10 +499,11 @@ def render_column( if comment: kwargs["comment"] = repr(comment) - if is_table: + if is_table or isinstance(self, SQLModelGenerator): self.add_import(Column) return render_callable("Column", *args, kwargs=kwargs) else: + self.add_literal_import("sqlalchemy.orm", "mapped_column") return render_callable("mapped_column", *args, kwargs=kwargs) def render_column_type(self, coltype: object) -> str: