Skip to content

Commit ccf1271

Browse files
committed
Fixed conflict between server_default and column attributes
Fixes #185.
1 parent 19cc0ac commit ccf1271

File tree

3 files changed

+61
-30
lines changed

3 files changed

+61
-30
lines changed

CHANGES.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Version history
1010
- Fixed improper handling of schema prefixes in sequence names in server defaults
1111
- Fixed identically named tables from different schemas resulting in invalid generated
1212
code
13+
- Fixed imports caused by ``server_default`` conflicting with class attribute names
1314
- Worked around PostgreSQL UUID columns getting ``Any`` as the type annotation
1415

1516
**3.0.0b3**

src/sqlacodegen/generators.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,15 @@ def collect_imports_for_column(self, column: Column[Any]) -> None:
198198
):
199199
self.add_import(column.type.astext_type)
200200

201+
if column.default:
202+
self.add_import(column.default)
203+
204+
if column.server_default:
205+
if isinstance(column.server_default, (Computed, Identity)):
206+
self.add_import(column.server_default)
207+
elif isinstance(column.server_default, DefaultClause):
208+
self.add_literal_import("sqlalchemy", "text")
209+
201210
def collect_imports_for_constraint(self, constraint: Constraint | Index) -> None:
202211
if isinstance(constraint, Index):
203212
if len(constraint.columns) > 1 or not uses_default_name(constraint):
@@ -382,6 +391,9 @@ def render_column(self, column: Column[Any], show_name: bool) -> str:
382391
for fk in dedicated_fks:
383392
args.append(self.render_constraint(fk))
384393

394+
if column.default:
395+
args.append(repr(column.default))
396+
385397
if column.key != column.name:
386398
kwargs["key"] = column.key
387399
if is_primary:
@@ -401,31 +413,6 @@ def render_column(self, column: Column[Any], show_name: bool) -> str:
401413
kwargs["server_default"] = render_callable(
402414
"text", repr(column.server_default.arg.text)
403415
)
404-
if isinstance(column.server_default.arg, TextClause):
405-
if self.bind.dialect.name == "postgresql":
406-
match = _re_postgresql_nextval_sequence.match(
407-
column.server_default.arg.text
408-
)
409-
if match:
410-
# Add an explicit sequence
411-
if match.group(2) != f"{column.table.name}_{column.name}_seq":
412-
callable_kwargs = {}
413-
if match.group(1):
414-
callable_kwargs["schema"] = repr(match.group(1))
415-
416-
args.append(
417-
render_callable(
418-
"Sequence",
419-
repr(match.group(2)),
420-
kwargs=callable_kwargs,
421-
)
422-
)
423-
self.add_literal_import("sqlalchemy", "Sequence")
424-
425-
del kwargs["server_default"]
426-
427-
if "server_default" in kwargs:
428-
self.add_literal_import("sqlalchemy", "text")
429416
elif isinstance(column.server_default, Computed):
430417
expression = str(column.server_default.sqltext)
431418

@@ -436,10 +423,8 @@ def render_column(self, column: Column[Any], show_name: bool) -> str:
436423
args.append(
437424
render_callable("Computed", repr(expression), kwargs=computed_kwargs)
438425
)
439-
self.add_import(Computed)
440426
elif isinstance(column.server_default, Identity):
441427
args.append(repr(column.server_default))
442-
self.add_import(Identity)
443428
elif column.server_default:
444429
kwargs["server_default"] = repr(column.server_default)
445430

@@ -603,6 +588,23 @@ def fix_column_types(self, table: Table) -> None:
603588
except CompileError:
604589
pass
605590

591+
# PostgreSQL specific fix: detect sequences from server_default
592+
if column.server_default and self.bind.dialect.name == "postgresql":
593+
if isinstance(column.server_default, DefaultClause) and isinstance(
594+
column.server_default.arg, TextClause
595+
):
596+
match = _re_postgresql_nextval_sequence.match(
597+
column.server_default.arg.text
598+
)
599+
if match:
600+
# Add an explicit sequence
601+
if match.group(2) != f"{column.table.name}_{column.name}_seq":
602+
column.default = sqlalchemy.Sequence(
603+
match.group(2), schema=match.group(1)
604+
)
605+
606+
column.server_default = None
607+
606608
def get_adapted_type(self, coltype: Any) -> Any:
607609
compiled_type = coltype.compile(self.bind.engine.dialect)
608610
for supercls in coltype.__class__.__mro__:

tests/test_generators.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2208,6 +2208,34 @@ class SimpleItems(Base):
22082208
""",
22092209
)
22102210

2211+
# @pytest.mark.xfail(strict=True)
2212+
def test_colname_import_conflict(self, generator: CodeGenerator) -> None:
2213+
Table(
2214+
"simple",
2215+
generator.metadata,
2216+
Column("id", INTEGER, primary_key=True),
2217+
Column("text", VARCHAR),
2218+
Column("textwithdefault", VARCHAR, server_default=text("'test'")),
2219+
)
2220+
2221+
validate_code(
2222+
generator.generate(),
2223+
"""\
2224+
from sqlalchemy import Column, Integer, String, text
2225+
from sqlalchemy.orm import declarative_base
2226+
2227+
Base = declarative_base()
2228+
2229+
2230+
class Simple(Base):
2231+
__tablename__ = 'simple'
2232+
2233+
id = Column(Integer, primary_key=True)
2234+
text_ = Column('text', String)
2235+
textwithdefault = Column(String, server_default=text("'test'"))
2236+
""",
2237+
)
2238+
22112239

22122240
class TestDataclassGenerator:
22132241
@pytest.fixture
@@ -2255,7 +2283,7 @@ def test_mandatory_field_last(self, generator: CodeGenerator) -> None:
22552283
"simple",
22562284
generator.metadata,
22572285
Column("id", INTEGER, primary_key=True),
2258-
Column("name", VARCHAR(20), default="foo"),
2286+
Column("name", VARCHAR(20), server_default=text("foo")),
22592287
Column("age", INTEGER, nullable=False),
22602288
)
22612289

@@ -2267,7 +2295,7 @@ def test_mandatory_field_last(self, generator: CodeGenerator) -> None:
22672295
from dataclasses import dataclass, field
22682296
from typing import Optional
22692297
2270-
from sqlalchemy import Column, Integer, String
2298+
from sqlalchemy import Column, Integer, String, text
22712299
from sqlalchemy.orm import registry
22722300
22732301
mapper_registry = registry()
@@ -2281,7 +2309,7 @@ class Simple:
22812309
22822310
id: int = field(init=False, metadata={'sa': Column(Integer, primary_key=True)})
22832311
age: int = field(metadata={'sa': Column(Integer, nullable=False)})
2284-
name: Optional[str] = field(default=None, metadata={'sa': Column(String(20))})
2312+
name: Optional[str] = field(default=None, metadata={'sa': Column(String(20), server_default=text('foo'))})
22852313
""",
22862314
)
22872315

0 commit comments

Comments
 (0)