Skip to content

Commit 4b37d5d

Browse files
authored
Better nullability definitions for all generators (#419)
* Better nullability definitios for all generators * Better nullability definitions for all generators * Improved SQLModels testing * PR Fixes
1 parent 2965556 commit 4b37d5d

File tree

6 files changed

+46
-12
lines changed

6 files changed

+46
-12
lines changed

CHANGES.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ Version history
1414
- Temporarily restrict SQLAlchemy version to 2.0.41 (PR by @sheinbergon)
1515
- Fixes ``add_import`` behavior when adding imports from sqlalchemy and overall better
1616
alignment of import behavior(s) across generators
17+
- Fixes ``nullable`` column behavior for non-null columns for both
18+
``sqlmodels`` and ``declarative`` generators (PR by @sheinbergon)
1719

1820
**3.0.0**
1921

src/sqlacodegen/generators.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,9 @@ def render_column(
410410
args = []
411411
kwargs: dict[str, Any] = {}
412412
kwarg = []
413-
is_sole_pk = column.primary_key and len(column.table.primary_key) == 1
413+
is_part_of_composite_pk = (
414+
column.primary_key and len(column.table.primary_key) > 1
415+
)
414416
dedicated_fks = [
415417
c
416418
for c in column.foreign_keys
@@ -460,8 +462,10 @@ def render_column(
460462
kwargs["key"] = column.key
461463
if is_primary:
462464
kwargs["primary_key"] = True
463-
if not column.nullable and not is_sole_pk and is_table:
465+
if not column.nullable and not column.primary_key:
464466
kwargs["nullable"] = False
467+
if column.nullable and is_part_of_composite_pk:
468+
kwargs["nullable"] = True
465469

466470
if is_unique:
467471
column.unique = True

tests/test_cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class Foo(Base):
8282
__tablename__ = 'foo'
8383
8484
id: Mapped[int] = mapped_column(Integer, primary_key=True)
85-
name: Mapped[str] = mapped_column(Text)
85+
name: Mapped[str] = mapped_column(Text, nullable=False)
8686
"""
8787
)
8888

@@ -115,7 +115,7 @@ class Foo(Base):
115115
__tablename__ = 'foo'
116116
117117
id: Mapped[int] = mapped_column(Integer, primary_key=True)
118-
name: Mapped[str] = mapped_column(Text)
118+
name: Mapped[str] = mapped_column(Text, nullable=False)
119119
"""
120120
)
121121

@@ -142,7 +142,7 @@ def test_cli_sqlmodels(db_path: Path, tmp_path: Path) -> None:
142142
143143
class Foo(SQLModel, table=True):
144144
id: int = Field(sa_column=Column('id', Integer, primary_key=True))
145-
name: str = Field(sa_column=Column('name', Text))
145+
name: str = Field(sa_column=Column('name', Text, nullable=False))
146146
"""
147147
)
148148

tests/test_generator_dataclass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class Simple(Base):
7777
__tablename__ = 'simple'
7878
7979
id: Mapped[int] = mapped_column(Integer, primary_key=True)
80-
age: Mapped[int] = mapped_column(Integer)
80+
age: Mapped[int] = mapped_column(Integer, nullable=False)
8181
name: Mapped[Optional[str]] = mapped_column(String(20), \
8282
server_default=text('foo'))
8383
""",

tests/test_generator_declarative.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ class SimpleItems(Base):
341341
342342
id: Mapped[int] = mapped_column(Integer, primary_key=True)
343343
top_container_id: Mapped[int] = \
344-
mapped_column(ForeignKey('simple_containers.id'))
344+
mapped_column(ForeignKey('simple_containers.id'), nullable=False)
345345
parent_container_id: Mapped[Optional[int]] = \
346346
mapped_column(ForeignKey('simple_containers.id'))
347347
@@ -812,6 +812,34 @@ class SimpleItems(Base):
812812
)
813813

814814

815+
def test_composite_nullable_pk(generator: CodeGenerator) -> None:
816+
Table(
817+
"simple_items",
818+
generator.metadata,
819+
Column("id1", INTEGER, primary_key=True),
820+
Column("id2", INTEGER, primary_key=True, nullable=True),
821+
)
822+
validate_code(
823+
generator.generate(),
824+
"""\
825+
from typing import Optional
826+
827+
from sqlalchemy import Integer
828+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
829+
830+
class Base(DeclarativeBase):
831+
pass
832+
833+
834+
class SimpleItems(Base):
835+
__tablename__ = 'simple_items'
836+
837+
id1: Mapped[int] = mapped_column(Integer, primary_key=True)
838+
id2: Mapped[Optional[int]] = mapped_column(Integer, primary_key=True, nullable=True)
839+
""",
840+
)
841+
842+
815843
def test_joined_inheritance(generator: CodeGenerator) -> None:
816844
Table(
817845
"simple_sub_items",
@@ -1045,7 +1073,7 @@ class Group(Base):
10451073
)
10461074
10471075
groups_id: Mapped[int] = mapped_column(Integer, primary_key=True)
1048-
group_name: Mapped[str] = mapped_column(Text(50))
1076+
group_name: Mapped[str] = mapped_column(Text(50), nullable=False)
10491077
10501078
users: Mapped[list['User']] = relationship('User', back_populates='group')
10511079
@@ -1590,7 +1618,7 @@ class WithItems(Base):
15901618
__tablename__ = 'with_items'
15911619
15921620
id: Mapped[int] = mapped_column(Integer, primary_key=True)
1593-
int_items_not_optional: Mapped[list[int]] = mapped_column(ARRAY(INTEGER()))
1621+
int_items_not_optional: Mapped[list[int]] = mapped_column(ARRAY(INTEGER()), nullable=False)
15941622
str_matrix: Mapped[Optional[list[list[str]]]] = mapped_column(ARRAY(VARCHAR(), dimensions=2))
15951623
""",
15961624
)

tests/test_generator_sqlmodel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_indexes(generator: CodeGenerator) -> None:
3333
"item",
3434
generator.metadata,
3535
Column("id", INTEGER, primary_key=True),
36-
Column("number", INTEGER),
36+
Column("number", INTEGER, nullable=False),
3737
Column("text", VARCHAR),
3838
)
3939
simple_items.indexes.add(Index("idx_number", simple_items.c.number))
@@ -58,8 +58,8 @@ class Item(SQLModel, table=True):
5858
)
5959
6060
id: int = Field(sa_column=Column('id', Integer, primary_key=True))
61-
number: Optional[int] = Field(default=None, sa_column=Column(\
62-
'number', Integer))
61+
number: int = Field(sa_column=Column(\
62+
'number', Integer, nullable=False))
6363
text: Optional[str] = Field(default=None, sa_column=Column(\
6464
'text', String))
6565
""",

0 commit comments

Comments
 (0)