Skip to content

Commit e30c7ef

Browse files
authored
✨ Update type annotations and upgrade mypy (#173)
1 parent 02da85c commit e30c7ef

File tree

10 files changed

+90
-76
lines changed

10 files changed

+90
-76
lines changed

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ sqlalchemy2-stubs = {version = "*", allow-prereleases = true}
3737

3838
[tool.poetry.dev-dependencies]
3939
pytest = "^6.2.4"
40-
mypy = "^0.812"
40+
mypy = "^0.910"
4141
flake8 = "^3.9.2"
4242
black = {version = "^21.5-beta.1", python = "^3.7"}
4343
mkdocs = "^1.2.1"
@@ -98,3 +98,7 @@ warn_return_any = true
9898
implicit_reexport = false
9999
strict_equality = true
100100
# --strict end
101+
102+
[[tool.mypy.overrides]]
103+
module = "sqlmodel.sql.expression"
104+
warn_unused_ignores = false

sqlmodel/engine/create.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,4 @@ def create_engine(
136136
if not isinstance(query_cache_size, _DefaultPlaceholder):
137137
current_kwargs["query_cache_size"] = query_cache_size
138138
current_kwargs.update(kwargs)
139-
return _create_engine(url, **current_kwargs)
139+
return _create_engine(url, **current_kwargs) # type: ignore

sqlmodel/engine/result.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __iter__(self) -> Iterator[_T]:
2323
return super().__iter__()
2424

2525
def __next__(self) -> _T:
26-
return super().__next__()
26+
return super().__next__() # type: ignore
2727

2828
def first(self) -> Optional[_T]:
2929
return super().first()
@@ -32,7 +32,7 @@ def one_or_none(self) -> Optional[_T]:
3232
return super().one_or_none()
3333

3434
def one(self) -> _T:
35-
return super().one()
35+
return super().one() # type: ignore
3636

3737

3838
class Result(_Result, Generic[_T]):
@@ -70,10 +70,10 @@ def scalar_one(self) -> _T:
7070
return super().scalar_one() # type: ignore
7171

7272
def scalar_one_or_none(self) -> Optional[_T]:
73-
return super().scalar_one_or_none() # type: ignore
73+
return super().scalar_one_or_none()
7474

7575
def one(self) -> _T: # type: ignore
7676
return super().one() # type: ignore
7777

7878
def scalar(self) -> Optional[_T]:
79-
return super().scalar() # type: ignore
79+
return super().scalar()

sqlmodel/ext/asyncio/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(
2121
self,
2222
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
2323
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
24-
**kw,
24+
**kw: Any,
2525
):
2626
# All the same code of the original AsyncSession
2727
kw["future"] = True
@@ -52,7 +52,7 @@ async def exec(
5252
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
5353
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
5454

55-
return await greenlet_spawn( # type: ignore
55+
return await greenlet_spawn(
5656
self.sync_session.exec,
5757
statement,
5858
params=params,

sqlmodel/main.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def __init__(
101101
*,
102102
back_populates: Optional[str] = None,
103103
link_model: Optional[Any] = None,
104-
sa_relationship: Optional[RelationshipProperty] = None,
104+
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
105105
sa_relationship_args: Optional[Sequence[Any]] = None,
106106
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
107107
) -> None:
@@ -127,32 +127,32 @@ def Field(
127127
default: Any = Undefined,
128128
*,
129129
default_factory: Optional[NoArgAnyCallable] = None,
130-
alias: str = None,
131-
title: str = None,
132-
description: str = None,
130+
alias: Optional[str] = None,
131+
title: Optional[str] = None,
132+
description: Optional[str] = None,
133133
exclude: Union[
134134
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
135135
] = None,
136136
include: Union[
137137
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
138138
] = None,
139-
const: bool = None,
140-
gt: float = None,
141-
ge: float = None,
142-
lt: float = None,
143-
le: float = None,
144-
multiple_of: float = None,
145-
min_items: int = None,
146-
max_items: int = None,
147-
min_length: int = None,
148-
max_length: int = None,
139+
const: Optional[bool] = None,
140+
gt: Optional[float] = None,
141+
ge: Optional[float] = None,
142+
lt: Optional[float] = None,
143+
le: Optional[float] = None,
144+
multiple_of: Optional[float] = None,
145+
min_items: Optional[int] = None,
146+
max_items: Optional[int] = None,
147+
min_length: Optional[int] = None,
148+
max_length: Optional[int] = None,
149149
allow_mutation: bool = True,
150-
regex: str = None,
150+
regex: Optional[str] = None,
151151
primary_key: bool = False,
152152
foreign_key: Optional[Any] = None,
153153
nullable: Union[bool, UndefinedType] = Undefined,
154154
index: Union[bool, UndefinedType] = Undefined,
155-
sa_column: Union[Column, UndefinedType] = Undefined,
155+
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
156156
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
157157
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
158158
schema_extra: Optional[Dict[str, Any]] = None,
@@ -195,7 +195,7 @@ def Relationship(
195195
*,
196196
back_populates: Optional[str] = None,
197197
link_model: Optional[Any] = None,
198-
sa_relationship: Optional[RelationshipProperty] = None,
198+
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
199199
sa_relationship_args: Optional[Sequence[Any]] = None,
200200
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
201201
) -> Any:
@@ -217,19 +217,25 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
217217

218218
# Replicate SQLAlchemy
219219
def __setattr__(cls, name: str, value: Any) -> None:
220-
if getattr(cls.__config__, "table", False): # type: ignore
220+
if getattr(cls.__config__, "table", False):
221221
DeclarativeMeta.__setattr__(cls, name, value)
222222
else:
223223
super().__setattr__(name, value)
224224

225225
def __delattr__(cls, name: str) -> None:
226-
if getattr(cls.__config__, "table", False): # type: ignore
226+
if getattr(cls.__config__, "table", False):
227227
DeclarativeMeta.__delattr__(cls, name)
228228
else:
229229
super().__delattr__(name)
230230

231231
# From Pydantic
232-
def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any:
232+
def __new__(
233+
cls,
234+
name: str,
235+
bases: Tuple[Type[Any], ...],
236+
class_dict: Dict[str, Any],
237+
**kwargs: Any,
238+
) -> Any:
233239
relationships: Dict[str, RelationshipInfo] = {}
234240
dict_for_pydantic = {}
235241
original_annotations = resolve_annotations(
@@ -342,7 +348,7 @@ def __init__(
342348
)
343349
relationship_to = temp_field.type_
344350
if isinstance(temp_field.type_, ForwardRef):
345-
relationship_to = temp_field.type_.__forward_arg__ # type: ignore
351+
relationship_to = temp_field.type_.__forward_arg__
346352
rel_kwargs: Dict[str, Any] = {}
347353
if rel_info.back_populates:
348354
rel_kwargs["back_populates"] = rel_info.back_populates
@@ -360,7 +366,7 @@ def __init__(
360366
rel_args.extend(rel_info.sa_relationship_args)
361367
if rel_info.sa_relationship_kwargs:
362368
rel_kwargs.update(rel_info.sa_relationship_kwargs)
363-
rel_value: RelationshipProperty = relationship(
369+
rel_value: RelationshipProperty = relationship( # type: ignore
364370
relationship_to, *rel_args, **rel_kwargs
365371
)
366372
dict_used[rel_name] = rel_value
@@ -408,7 +414,7 @@ def get_sqlachemy_type(field: ModelField) -> Any:
408414
return GUID
409415

410416

411-
def get_column_from_field(field: ModelField) -> Column:
417+
def get_column_from_field(field: ModelField) -> Column: # type: ignore
412418
sa_column = getattr(field.field_info, "sa_column", Undefined)
413419
if isinstance(sa_column, Column):
414420
return sa_column
@@ -440,10 +446,10 @@ def get_column_from_field(field: ModelField) -> Column:
440446
kwargs["default"] = sa_default
441447
sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
442448
if sa_column_args is not Undefined:
443-
args.extend(list(cast(Sequence, sa_column_args)))
449+
args.extend(list(cast(Sequence[Any], sa_column_args)))
444450
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
445451
if sa_column_kwargs is not Undefined:
446-
kwargs.update(cast(dict, sa_column_kwargs))
452+
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
447453
return Column(sa_type, *args, **kwargs)
448454

449455

@@ -452,24 +458,27 @@ def get_column_from_field(field: ModelField) -> Column:
452458
default_registry = registry()
453459

454460

455-
def _value_items_is_true(v) -> bool:
461+
def _value_items_is_true(v: Any) -> bool:
456462
# Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of
457463
# the current latest, Pydantic 1.8.2
458464
return v is True or v is ...
459465

460466

467+
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
468+
469+
461470
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
462471
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
463472
__slots__ = ("__weakref__",)
464473
__tablename__: ClassVar[Union[str, Callable[..., str]]]
465-
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
474+
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore
466475
__name__: ClassVar[str]
467476
metadata: ClassVar[MetaData]
468477

469478
class Config:
470479
orm_mode = True
471480

472-
def __new__(cls, *args, **kwargs) -> Any:
481+
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
473482
new_object = super().__new__(cls)
474483
# SQLAlchemy doesn't call __init__ on the base class
475484
# Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
@@ -520,7 +529,9 @@ def __setattr__(self, name: str, value: Any) -> None:
520529
super().__setattr__(name, value)
521530

522531
@classmethod
523-
def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
532+
def from_orm(
533+
cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
534+
) -> _TSQLModel:
524535
# Duplicated from Pydantic
525536
if not cls.__config__.orm_mode:
526537
raise ConfigError(
@@ -533,7 +544,7 @@ def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
533544
# End SQLModel support dict
534545
if not getattr(cls.__config__, "table", False):
535546
# If not table, normal Pydantic code
536-
m = cls.__new__(cls)
547+
m: _TSQLModel = cls.__new__(cls)
537548
else:
538549
# If table, create the new instance normally to make SQLAlchemy create
539550
# the _sa_instance_state attribute
@@ -554,7 +565,7 @@ def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
554565

555566
@classmethod
556567
def parse_obj(
557-
cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None
568+
cls: Type["SQLModel"], obj: Any, update: Optional[Dict[str, Any]] = None
558569
) -> "SQLModel":
559570
obj = cls._enforce_dict_if_root(obj)
560571
# SQLModel, support update dict

sqlmodel/orm/session.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def exec(
6060
results = super().execute(
6161
statement,
6262
params=params,
63-
execution_options=execution_options, # type: ignore
63+
execution_options=execution_options,
6464
bind_arguments=bind_arguments,
6565
_parent_execute_state=_parent_execute_state,
6666
_add_event=_add_event,
@@ -74,7 +74,7 @@ def execute(
7474
self,
7575
statement: _Executable,
7676
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
77-
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
77+
execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
7878
bind_arguments: Optional[Mapping[str, Any]] = None,
7979
_parent_execute_state: Optional[Any] = None,
8080
_add_event: Optional[Any] = None,
@@ -101,7 +101,7 @@ def execute(
101101
return super().execute( # type: ignore
102102
statement,
103103
params=params,
104-
execution_options=execution_options, # type: ignore
104+
execution_options=execution_options,
105105
bind_arguments=bind_arguments,
106106
_parent_execute_state=_parent_execute_state,
107107
_add_event=_add_event,

sqlmodel/sql/base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,4 @@
66

77

88
class Executable(_Executable, Generic[_T]):
9-
def __init__(self, *args, **kwargs):
10-
self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
11-
super(_Executable, self).__init__(*args, **kwargs)
9+
pass

sqlmodel/sql/expression.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ class SelectOfScalar(_Select, Generic[_TSelect]):
4545
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
4646
pass
4747

48-
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
48+
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
4949
pass
5050

51-
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
51+
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
5252
pass
5353

5454
# Cast them for editors to work correctly, from several tricks tried, this works
@@ -65,9 +65,9 @@ class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMet
6565

6666
_TScalar_0 = TypeVar(
6767
"_TScalar_0",
68-
Column,
69-
Sequence,
70-
Mapping,
68+
Column, # type: ignore
69+
Sequence, # type: ignore
70+
Mapping, # type: ignore
7171
UUID,
7272
datetime,
7373
float,
@@ -83,9 +83,9 @@ class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMet
8383

8484
_TScalar_1 = TypeVar(
8585
"_TScalar_1",
86-
Column,
87-
Sequence,
88-
Mapping,
86+
Column, # type: ignore
87+
Sequence, # type: ignore
88+
Mapping, # type: ignore
8989
UUID,
9090
datetime,
9191
float,
@@ -101,9 +101,9 @@ class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMet
101101

102102
_TScalar_2 = TypeVar(
103103
"_TScalar_2",
104-
Column,
105-
Sequence,
106-
Mapping,
104+
Column, # type: ignore
105+
Sequence, # type: ignore
106+
Mapping, # type: ignore
107107
UUID,
108108
datetime,
109109
float,
@@ -119,9 +119,9 @@ class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMet
119119

120120
_TScalar_3 = TypeVar(
121121
"_TScalar_3",
122-
Column,
123-
Sequence,
124-
Mapping,
122+
Column, # type: ignore
123+
Sequence, # type: ignore
124+
Mapping, # type: ignore
125125
UUID,
126126
datetime,
127127
float,
@@ -446,14 +446,14 @@ def select( # type: ignore
446446
# Generated overloads end
447447

448448

449-
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
449+
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
450450
if len(entities) == 1:
451451
return SelectOfScalar._create(*entities, **kw) # type: ignore
452452
return Select._create(*entities, **kw) # type: ignore
453453

454454

455455
# TODO: add several @overload from Python types to SQLAlchemy equivalents
456-
def col(column_expression: Any) -> ColumnClause:
456+
def col(column_expression: Any) -> ColumnClause: # type: ignore
457457
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
458458
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
459459
return column_expression

0 commit comments

Comments
 (0)