Skip to content

Commit 269534c

Browse files
committed
Improve function types
1 parent 1bbc110 commit 269534c

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/django_mysql/models/functions.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ def get(cls, using: str = DEFAULT_DB_ALIAS) -> int:
172172
# database connections in Django, and the reason was not clear
173173
with connections[using].cursor() as cursor:
174174
cursor.execute("SELECT LAST_INSERT_ID()")
175-
return cursor.fetchone()[0]
175+
id_: int = cursor.fetchone()[0]
176+
return id_
176177

177178

178179
# JSON Functions
@@ -261,7 +262,11 @@ def as_sql(
261262
if connection.vendor != "mysql": # pragma: no cover
262263
raise AssertionError("JSONValue only supports MySQL/MariaDB")
263264
json_string = json.dumps(self._data, allow_nan=False)
264-
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
265+
if (
266+
connection.vendor == "mysql"
267+
# type narrowed by vendor check
268+
and connection.mysql_is_mariadb # type: ignore [attr-defined]
269+
):
265270
# MariaDB doesn't support explicit cast to JSON.
266271
return "JSON_EXTRACT(%s, '$')", (json_string,)
267272
else:
@@ -273,7 +278,7 @@ def __init__(
273278
self,
274279
expression: ExpressionArgument,
275280
data: dict[
276-
str,
281+
ExpressionArgument,
277282
(
278283
ExpressionArgument
279284
| None
@@ -291,12 +296,12 @@ def __init__(
291296
exprs = [expression]
292297

293298
for path, value in data.items():
294-
if not hasattr(path, "resolve_expression"):
299+
if not isinstance(path, Expression):
295300
path = Value(path)
296301

297302
exprs.append(path)
298303

299-
if not hasattr(value, "resolve_expression"):
304+
if not isinstance(value, Expression):
300305
value = JSONValue(value)
301306

302307
exprs.append(value)
@@ -395,19 +400,20 @@ def __init__(
395400
self,
396401
expression: ExpressionArgument,
397402
to_add: dict[
398-
str, ExpressionArgument | float | int | dt.date | dt.time | dt.datetime
403+
ExpressionArgument,
404+
ExpressionArgument | float | int | dt.date | dt.time | dt.datetime,
399405
],
400406
) -> None:
401407
from django_mysql.models.fields import DynamicField
402408

403409
expressions = [expression]
404410
for name, value in to_add.items():
405-
if not hasattr(name, "resolve_expression"):
411+
if not isinstance(name, Expression):
406412
name = Value(name)
407413

408-
if isinstance(value, dict): # type: ignore [unreachable]
414+
if isinstance(value, dict):
409415
raise ValueError("ColumnAdd with nested values is not supported")
410-
if not hasattr(value, "resolve_expression"):
416+
if not isinstance(value, Expression):
411417
value = Value(value)
412418

413419
expressions.extend((name, value))

0 commit comments

Comments
 (0)