Skip to content

Commit e6ad74d

Browse files
author
John Lyu
committed
fix lint
1 parent 48f2a88 commit e6ad74d

File tree

2 files changed

+21
-18
lines changed

2 files changed

+21
-18
lines changed

sqlmodel/_compat.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pydantic import BaseModel
2323
from pydantic.fields import FieldInfo
2424
from sqlalchemy import inspect
25+
from sqlalchemy.orm import Mapper
2526
from typing_extensions import Annotated, get_args, get_origin
2627

2728
# Reassign variable to make it reexported for mypy
@@ -293,20 +294,21 @@ def sqlmodel_table_construct(
293294
# End SQLModel override
294295
# Override polymorphic_on default value
295296
mapper = inspect(cls)
296-
polymorphic_on = mapper.polymorphic_on
297-
if polymorphic_on is not None:
298-
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
299-
field_info = cls.model_fields.get(polymorphic_property.key)
300-
if field_info:
301-
v = values.get(polymorphic_property.key)
302-
# if model is inherited or polymorphic_on is not explicitly set
303-
# set the polymorphic_on by default
304-
if mapper.inherits or v is None:
305-
setattr(
306-
self_instance,
307-
polymorphic_property.key,
308-
mapper.polymorphic_identity,
309-
)
297+
if isinstance(mapper, Mapper):
298+
polymorphic_on = mapper.polymorphic_on
299+
if polymorphic_on is not None:
300+
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
301+
field_info = cls.model_fields.get(polymorphic_property.key)
302+
if field_info:
303+
v = values.get(polymorphic_property.key)
304+
# if model is inherited or polymorphic_on is not explicitly set
305+
# set the polymorphic_on by default
306+
if mapper.inherits or v is None:
307+
setattr(
308+
self_instance,
309+
polymorphic_property.key,
310+
mapper.polymorphic_identity,
311+
)
310312
return self_instance
311313

312314
def sqlmodel_validate(

sqlmodel/main.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -551,9 +551,10 @@ def __new__(
551551
for base in bases[::-1]:
552552
if issubclass(base, BaseModel):
553553
base_fields.update(base.model_fields)
554-
for k, v in new_cls.model_fields.items():
554+
fields = get_model_fields(new_cls)
555+
for k, v in fields.items():
555556
if isinstance(v.default, InstrumentedAttribute):
556-
new_cls.model_fields[k] = base_fields.get(k)
557+
fields[k] = base_fields.get(k, FieldInfo())
557558

558559
def get_config(name: str) -> Any:
559560
config_class_value = get_config_value(
@@ -572,7 +573,7 @@ def get_config(name: str) -> Any:
572573
# or if __tablename__ is in __annotations__. Only set __tablename__ if it's
573574
# a table model
574575
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
575-
new_cls.__tablename__ = new_cls.__name__.lower()
576+
setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010
576577
# If it was passed by kwargs, ensure it's also set in config
577578
set_config_value(model=new_cls, parameter="table", value=config_table)
578579
for k, v in get_model_fields(new_cls).items():
@@ -731,7 +732,7 @@ def get_sqlalchemy_type(field: Any) -> Any:
731732
raise ValueError(f"{type_} has no matching SQLAlchemy type")
732733

733734

734-
def get_column_from_field(field: Any) -> Column: # type: ignore
735+
def get_column_from_field(field: Any) -> Column | MappedColumn: # type: ignore
735736
if IS_PYDANTIC_V2:
736737
field_info = field
737738
else:

0 commit comments

Comments
 (0)