Skip to content

Commit de983de

Browse files
committed
Correct types of source expression functions
1 parent 07f44d0 commit de983de

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

src/django_mysql/models/expressions.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from __future__ import annotations
22

33
from typing import Any
4-
from typing import Iterable
4+
from typing import Sequence
55

66
from django.db.backends.base.base import BaseDatabaseWrapper
77
from django.db.models import F
88
from django.db.models import Value
99
from django.db.models.expressions import BaseExpression
10+
from django.db.models.expressions import Combinable
11+
from django.db.models.expressions import Expression
1012
from django.db.models.sql.compiler import SQLCompiler
1113

1214
from django_mysql.utils import collapse_spaces
@@ -18,10 +20,10 @@ def __init__(self, lhs: BaseExpression, rhs: BaseExpression) -> None:
1820
self.lhs = lhs
1921
self.rhs = rhs
2022

21-
def get_source_expressions(self) -> list[BaseExpression]:
23+
def get_source_expressions(self) -> list[Expression]:
2224
return [self.lhs, self.rhs]
2325

24-
def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
26+
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
2527
self.lhs, self.rhs = exprs
2628

2729

@@ -138,10 +140,10 @@ def __init__(self, lhs: BaseExpression) -> None:
138140
super().__init__()
139141
self.lhs = lhs
140142

141-
def get_source_expressions(self) -> list[BaseExpression]:
143+
def get_source_expressions(self) -> list[Expression]:
142144
return [self.lhs]
143145

144-
def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
146+
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
145147
(self.lhs,) = exprs
146148

147149
def as_sql(
@@ -170,10 +172,10 @@ def __init__(self, lhs: BaseExpression) -> None:
170172
super().__init__()
171173
self.lhs = lhs
172174

173-
def get_source_expressions(self) -> list[BaseExpression]:
175+
def get_source_expressions(self) -> list[Expression]:
174176
return [self.lhs]
175177

176-
def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
178+
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
177179
(self.lhs,) = exprs
178180

179181
def as_sql(

0 commit comments

Comments
 (0)