Skip to content

Commit c2a2ef0

Browse files
EthanKim8683sheinbergonagronholm
authored
Fixed same-name imports from wrong package (#411)
* Fixed same-name imports from wrong package * Revert "Fixed same-name imports from wrong package" This reverts commit b181bd7. * Use render_python_column_type in SQLModelGenerator * Use stdlib-list instead of sys.stdlib_module_names * Update dependencies * Use sys.stdlib_module_names for Python 3.10 or newer * Use ClassVar instead of getter * Use utility instead of ClassVar * Fix indentation in pyproject.toml * Update src/sqlacodegen/generators.py Co-authored-by: Alex Grönholm <[email protected]> * PR Fixes * Update CHANGES.rst --------- Co-authored-by: Idan Sheinberg <[email protected]> Co-authored-by: Alex Grönholm <[email protected]>
1 parent 240e5b7 commit c2a2ef0

File tree

7 files changed

+74
-53
lines changed

7 files changed

+74
-53
lines changed

CHANGES.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ Version history
1212
- Prevent double pluralization (PR by @dkratzert)
1313
- Fixes DOMAIN extending JSON/JSONB data types (PR by @sheinbergon)
1414
- Temporarily restrict SQLAlchemy version to 2.0.41 (PR by @sheinbergon)
15+
- Fixes ``add_import`` behavior when adding imports from sqlalchemy and overall better
16+
alignment of import behavior(s) across generators
1517

1618
**3.0.0**
1719

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ dependencies = [
3232
"SQLAlchemy >= 2.0.29,<2.0.42",
3333
"inflect >= 4.0.0",
3434
"importlib_metadata; python_version < '3.10'",
35+
"stdlib-list; python_version < '3.10'"
3536
]
3637
dynamic = ["version"]
3738

src/sqlacodegen/generators.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
get_common_fk_constraints,
6060
get_compiled_expression,
6161
get_constraint_sort_key,
62+
get_stdlib_module_names,
6263
qualified_table_name,
6364
render_callable,
6465
uses_default_name,
@@ -119,9 +120,7 @@ def generate(self) -> str:
119120
@dataclass(eq=False)
120121
class TablesGenerator(CodeGenerator):
121122
valid_options: ClassVar[set[str]] = {"noindexes", "noconstraints", "nocomments"}
122-
builtin_module_names: ClassVar[set[str]] = set(sys.builtin_module_names) | {
123-
"dataclasses"
124-
}
123+
stdlib_module_names: ClassVar[set[str]] = get_stdlib_module_names()
125124

126125
def __init__(
127126
self,
@@ -276,7 +275,7 @@ def add_import(self, obj: Any) -> None:
276275

277276
if type_.__name__ in dialect_pkg.__all__:
278277
pkgname = dialect_pkgname
279-
elif type_.__name__ in dir(sqlalchemy):
278+
elif type_ is getattr(sqlalchemy, type_.__name__, None):
280279
pkgname = "sqlalchemy"
281280
else:
282281
pkgname = type_.__module__
@@ -300,21 +299,26 @@ def group_imports(self) -> list[list[str]]:
300299
stdlib_imports: list[str] = []
301300
thirdparty_imports: list[str] = []
302301

303-
for package in sorted(self.imports):
304-
imports = ", ".join(sorted(self.imports[package]))
302+
def get_collection(package: str) -> list[str]:
305303
collection = thirdparty_imports
306304
if package == "__future__":
307305
collection = future_imports
308-
elif package in self.builtin_module_names:
306+
elif package in self.stdlib_module_names:
309307
collection = stdlib_imports
310308
elif package in sys.modules:
311309
if "site-packages" not in (sys.modules[package].__file__ or ""):
312310
collection = stdlib_imports
311+
return collection
313312

313+
for package in sorted(self.imports):
314+
imports = ", ".join(sorted(self.imports[package]))
315+
316+
collection = get_collection(package)
314317
collection.append(f"from {package} import {imports}")
315318

316319
for module in sorted(self.module_imports):
317-
thirdparty_imports.append(f"import {module}")
320+
collection = get_collection(module)
321+
collection.append(f"import {module}")
318322

319323
return [
320324
group
@@ -1212,10 +1216,7 @@ def render_table_args(self, table: Table) -> str:
12121216
else:
12131217
return ""
12141218

1215-
def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
1216-
column = column_attr.column
1217-
rendered_column = self.render_column(column, column_attr.name != column.name)
1218-
1219+
def render_column_python_type(self, column: Column[Any]) -> str:
12191220
def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
12201221
column_type = column.type
12211222
pre: list[str] = []
@@ -1254,7 +1255,14 @@ def render_python_type(column_type: TypeEngine[Any]) -> str:
12541255

12551256
pre, col_type, post = get_type_qualifiers()
12561257
column_python_type = f"{pre}{render_python_type(col_type)}{post}"
1257-
return f"{column_attr.name}: Mapped[{column_python_type}] = {rendered_column}"
1258+
return column_python_type
1259+
1260+
def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
1261+
column = column_attr.column
1262+
rendered_column = self.render_column(column, column_attr.name != column.name)
1263+
rendered_column_python_type = self.render_column_python_type(column)
1264+
1265+
return f"{column_attr.name}: Mapped[{rendered_column_python_type}] = {rendered_column}"
12581266

12591267
def render_relationship(self, relationship: RelationshipAttribute) -> str:
12601268
def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str:
@@ -1444,15 +1452,6 @@ def collect_imports_for_model(self, model: Model) -> None:
14441452
if model.relationships:
14451453
self.add_literal_import("sqlmodel", "Relationship")
14461454

1447-
def collect_imports_for_column(self, column: Column[Any]) -> None:
1448-
super().collect_imports_for_column(column)
1449-
try:
1450-
python_type = column.type.python_type
1451-
except NotImplementedError:
1452-
self.add_literal_import("typing", "Any")
1453-
else:
1454-
self.add_import(python_type)
1455-
14561455
def render_module_variables(self, models: list[Model]) -> str:
14571456
declarations: list[str] = []
14581457
if any(not isinstance(model, ModelClass) for model in models):
@@ -1485,25 +1484,17 @@ def render_class_variables(self, model: ModelClass) -> str:
14851484

14861485
def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
14871486
column = column_attr.column
1488-
try:
1489-
python_type = column.type.python_type
1490-
except NotImplementedError:
1491-
python_type_name = "Any"
1492-
else:
1493-
python_type_name = python_type.__name__
1487+
rendered_column = self.render_column(column, True)
1488+
rendered_column_python_type = self.render_column_python_type(column)
14941489

14951490
kwargs: dict[str, Any] = {}
1496-
if (
1497-
column.autoincrement and column.name in column.table.primary_key
1498-
) or column.nullable:
1499-
self.add_literal_import("typing", "Optional")
1491+
if column.nullable:
15001492
kwargs["default"] = None
1501-
python_type_name = f"Optional[{python_type_name}]"
1502-
1503-
rendered_column = self.render_column(column, True)
15041493
kwargs["sa_column"] = f"{rendered_column}"
1494+
15051495
rendered_field = render_callable("Field", kwargs=kwargs)
1506-
return f"{column_attr.name}: {python_type_name} = {rendered_field}"
1496+
1497+
return f"{column_attr.name}: {rendered_column_python_type} = {rendered_field}"
15071498

15081499
def render_relationship(self, relationship: RelationshipAttribute) -> str:
15091500
rendered = super().render_relationship(relationship).partition(" = ")[2]

src/sqlacodegen/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import re
4+
import sys
45
from collections.abc import Mapping
56
from typing import Any, Literal, cast
67

@@ -206,3 +207,13 @@ def decode_postgresql_sequence(clause: TextClause) -> tuple[str | None, str | No
206207
schema, sequence = sequence, ""
207208

208209
return schema, sequence
210+
211+
212+
def get_stdlib_module_names() -> set[str]:
213+
major, minor = sys.version_info.major, sys.version_info.minor
214+
if (major, minor) > (3, 9):
215+
return set(sys.builtin_module_names) | set(sys.stdlib_module_names)
216+
else:
217+
from stdlib_list import stdlib_list
218+
219+
return set(sys.builtin_module_names) | set(stdlib_list(f"{major}.{minor}"))

tests/test_cli.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,11 @@ def test_cli_sqlmodels(db_path: Path, tmp_path: Path) -> None:
137137
assert (
138138
output_path.read_text()
139139
== """\
140-
from typing import Optional
141-
142140
from sqlalchemy import Column, Integer, Text
143141
from sqlmodel import Field, SQLModel
144142
145143
class Foo(SQLModel, table=True):
146-
id: Optional[int] = Field(default=None, sa_column=Column('id', Integer, \
147-
primary_key=True))
144+
id: int = Field(sa_column=Column('id', Integer, primary_key=True))
148145
name: str = Field(sa_column=Column('name', Text))
149146
"""
150147
)

tests/test_generator_dataclass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,11 @@ def test_uuid_type_annotation(generator: CodeGenerator) -> None:
251251
validate_code(
252252
generator.generate(),
253253
"""\
254+
import uuid
255+
254256
from sqlalchemy import UUID
255257
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
256258
mapped_column
257-
import uuid
258259
259260
class Base(MappedAsDataclass, DeclarativeBase):
260261
pass

tests/test_generator_sqlmodel.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44
from _pytest.fixtures import FixtureRequest
5+
from sqlalchemy import Uuid
56
from sqlalchemy.engine import Engine
67
from sqlalchemy.schema import (
78
CheckConstraint,
@@ -56,8 +57,7 @@ class Item(SQLModel, table=True):
5657
Index('idx_text_number', 'text', 'number')
5758
)
5859
59-
id: Optional[int] = Field(default=None, sa_column=Column(\
60-
'id', Integer, primary_key=True))
60+
id: int = Field(sa_column=Column('id', Integer, primary_key=True))
6161
number: Optional[int] = Field(default=None, sa_column=Column(\
6262
'number', Integer))
6363
text: Optional[str] = Field(default=None, sa_column=Column(\
@@ -91,8 +91,7 @@ class SimpleConstraints(SQLModel, table=True):
9191
UniqueConstraint('id', 'number')
9292
)
9393
94-
id: Optional[int] = Field(default=None, sa_column=Column(\
95-
'id', Integer, primary_key=True))
94+
id: int = Field(sa_column=Column('id', Integer, primary_key=True))
9695
number: Optional[int] = Field(default=None, sa_column=Column(\
9796
'number', Integer))
9897
""",
@@ -124,8 +123,7 @@ def test_onetomany(generator: CodeGenerator) -> None:
124123
class SimpleContainers(SQLModel, table=True):
125124
__tablename__ = 'simple_containers'
126125
127-
id: Optional[int] = Field(default=None, sa_column=Column(\
128-
'id', Integer, primary_key=True))
126+
id: int = Field(sa_column=Column('id', Integer, primary_key=True))
129127
130128
simple_goods: list['SimpleGoods'] = Relationship(\
131129
back_populates='container')
@@ -134,8 +132,7 @@ class SimpleContainers(SQLModel, table=True):
134132
class SimpleGoods(SQLModel, table=True):
135133
__tablename__ = 'simple_goods'
136134
137-
id: Optional[int] = Field(default=None, sa_column=Column(\
138-
'id', Integer, primary_key=True))
135+
id: int = Field(sa_column=Column('id', Integer, primary_key=True))
139136
container_id: Optional[int] = Field(default=None, sa_column=Column(\
140137
'container_id', ForeignKey('simple_containers.id')))
141138
@@ -167,8 +164,7 @@ def test_onetoone(generator: CodeGenerator) -> None:
167164
class OtherItems(SQLModel, table=True):
168165
__tablename__ = 'other_items'
169166
170-
id: Optional[int] = Field(default=None, sa_column=Column(\
171-
'id', Integer, primary_key=True))
167+
id: int = Field(sa_column=Column('id', Integer, primary_key=True))
172168
173169
simple_onetoone: Optional['SimpleOnetoone'] = Relationship(\
174170
sa_relationship_kwargs={'uselist': False}, back_populates='other_item')
@@ -177,12 +173,34 @@ class OtherItems(SQLModel, table=True):
177173
class SimpleOnetoone(SQLModel, table=True):
178174
__tablename__ = 'simple_onetoone'
179175
180-
id: Optional[int] = Field(default=None, sa_column=Column(\
181-
'id', Integer, primary_key=True))
176+
id: int = Field(sa_column=Column('id', Integer, primary_key=True))
182177
other_item_id: Optional[int] = Field(default=None, sa_column=Column(\
183178
'other_item_id', ForeignKey('other_items.id'), unique=True))
184179
185180
other_item: Optional['OtherItems'] = Relationship(\
186181
back_populates='simple_onetoone')
187182
""",
188183
)
184+
185+
186+
def test_uuid(generator: CodeGenerator) -> None:
187+
Table(
188+
"simple_uuid",
189+
generator.metadata,
190+
Column("id", Uuid, primary_key=True),
191+
)
192+
193+
validate_code(
194+
generator.generate(),
195+
"""\
196+
import uuid
197+
198+
from sqlalchemy import Column, Uuid
199+
from sqlmodel import Field, SQLModel
200+
201+
class SimpleUuid(SQLModel, table=True):
202+
__tablename__ = 'simple_uuid'
203+
204+
id: uuid.UUID = Field(sa_column=Column('id', Uuid, primary_key=True))
205+
""",
206+
)

0 commit comments

Comments
 (0)