Skip to content

Commit ac86029

Browse files
committed
Correct types of source expression functions
1 parent 22e57fa commit ac86029

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

src/django_mysql/models/expressions.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33
from collections.abc import Iterable
44
from typing import Any
5+
from typing import Sequence
56

67
from django.db.backends.base.base import BaseDatabaseWrapper
78
from django.db.models import F
89
from django.db.models import Value
910
from django.db.models.expressions import BaseExpression
11+
from django.db.models.expressions import Combinable
12+
from django.db.models.expressions import Expression
1013
from django.db.models.sql.compiler import SQLCompiler
1114

1215
from django_mysql.utils import collapse_spaces
@@ -18,10 +21,10 @@ def __init__(self, lhs: BaseExpression, rhs: BaseExpression) -> None:
1821
self.lhs = lhs
1922
self.rhs = rhs
2023

21-
def get_source_expressions(self) -> list[BaseExpression]:
24+
def get_source_expressions(self) -> list[Expression]:
2225
return [self.lhs, self.rhs]
2326

24-
def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
27+
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
2528
self.lhs, self.rhs = exprs
2629

2730

@@ -138,10 +141,10 @@ def __init__(self, lhs: BaseExpression) -> None:
138141
super().__init__()
139142
self.lhs = lhs
140143

141-
def get_source_expressions(self) -> list[BaseExpression]:
144+
def get_source_expressions(self) -> list[Expression]:
142145
return [self.lhs]
143146

144-
def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
147+
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
145148
(self.lhs,) = exprs
146149

147150
def as_sql(
@@ -170,10 +173,10 @@ def __init__(self, lhs: BaseExpression) -> None:
170173
super().__init__()
171174
self.lhs = lhs
172175

173-
def get_source_expressions(self) -> list[BaseExpression]:
176+
def get_source_expressions(self) -> list[Expression]:
174177
return [self.lhs]
175178

176-
def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
179+
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
177180
(self.lhs,) = exprs
178181

179182
def as_sql(

0 commit comments

Comments
 (0)