Skip to content

Commit 8e2c4b0

Browse files
committed
Simplify
1 parent c3c8912 commit 8e2c4b0

File tree

4 files changed

+5
-14
lines changed

4 files changed

+5
-14
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ disallow_untyped_defs = false
109109
disallow_untyped_calls = false
110110

111111
[tool.ruff.lint]
112-
typing-modules = ["sqlmodel._compat"]
113112
select = [
114113
"E", # pycodestyle errors
115114
"W", # pycodestyle warnings

sqlmodel/_compat.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
from pydantic.fields import FieldInfo
2525
from typing_extensions import Annotated, get_args, get_origin
2626

27-
from .sql.sqltypes import AutoString
28-
2927
# Reassign variable to make it reexported for mypy
3028
PYDANTIC_VERSION = P_VERSION
3129
PYDANTIC_MINOR_VERSION = tuple(int(i) for i in P_VERSION.split(".")[:2])
@@ -469,8 +467,8 @@ def is_field_noneable(field: "FieldInfo") -> bool:
469467

470468
def get_sa_type_from_field(field: Any) -> Any:
471469
if get_origin(field.type_) is Literal:
472-
return AutoString
473-
elif isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
470+
return Literal
471+
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
474472
return field.type_
475473
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
476474

sqlmodel/main.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
inspect,
4141
)
4242
from sqlalchemy import Enum as sa_Enum
43-
from sqlalchemy import types as sa_types
4443
from sqlalchemy.orm import (
4544
Mapped,
4645
RelationshipProperty,
@@ -53,13 +52,12 @@
5352
from sqlalchemy.orm.instrumentation import is_instrumented
5453
from sqlalchemy.sql.schema import MetaData
5554
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
56-
from typing_extensions import TypeAlias, deprecated, get_origin
55+
from typing_extensions import Literal, TypeAlias, deprecated, get_origin
5756

5857
from ._compat import ( # type: ignore[attr-defined]
5958
IS_PYDANTIC_V2,
6059
PYDANTIC_MINOR_VERSION,
6160
BaseConfig,
62-
Literal,
6361
ModelField,
6462
ModelMetaclass,
6563
Representation,
@@ -657,16 +655,12 @@ def get_sqlalchemy_type(field: Any) -> Any:
657655
type_ = get_sa_type_from_field(field)
658656
metadata = get_field_metadata(field)
659657

660-
# If it's already an SQLAlchemy type (eg. AutoString), use it directly
661-
if isinstance(type_, type) and issubclass(type_, sa_types.TypeEngine):
662-
return type_
663-
664658
# Checks for `Literal` type annotation
665659
if type_ is Literal:
666660
return AutoString
667661
# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
668662
if issubclass(type_, Enum):
669-
return sa_Enum(cast(Type[Enum], type_))
663+
return sa_Enum(type_)
670664
if issubclass(
671665
type_,
672666
(

tests/test_main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sqlalchemy.exc import IntegrityError
55
from sqlalchemy.orm import RelationshipProperty
66
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
7-
from sqlmodel._compat import Literal
7+
from typing_extensions import Literal
88

99

1010
def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel):

0 commit comments

Comments
 (0)