Skip to content

Commit c41c91c

Browse files
committed
Improve function types
1 parent 68a4fcb commit c41c91c

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/django_mysql/models/functions.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ def get(cls, using: str = DEFAULT_DB_ALIAS) -> int:
171171
# database connections in Django, and the reason was not clear
172172
with connections[using].cursor() as cursor:
173173
cursor.execute("SELECT LAST_INSERT_ID()")
174-
return cursor.fetchone()[0]
174+
id_: int = cursor.fetchone()[0]
175+
return id_
175176

176177

177178
# JSON Functions
@@ -260,7 +261,11 @@ def as_sql(
260261
if connection.vendor != "mysql": # pragma: no cover
261262
raise AssertionError("JSONValue only supports MySQL/MariaDB")
262263
json_string = json.dumps(self._data, allow_nan=False)
263-
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
264+
if (
265+
connection.vendor == "mysql"
266+
# type narrowed by vendor check
267+
and connection.mysql_is_mariadb # type: ignore [attr-defined]
268+
):
264269
# MariaDB doesn't support explicit cast to JSON.
265270
return "JSON_EXTRACT(%s, '$')", (json_string,)
266271
else:
@@ -272,7 +277,7 @@ def __init__(
272277
self,
273278
expression: ExpressionArgument,
274279
data: dict[
275-
str,
280+
ExpressionArgument,
276281
(
277282
ExpressionArgument
278283
| None
@@ -290,12 +295,12 @@ def __init__(
290295
exprs = [expression]
291296

292297
for path, value in data.items():
293-
if not hasattr(path, "resolve_expression"):
298+
if not isinstance(path, Expression):
294299
path = Value(path)
295300

296301
exprs.append(path)
297302

298-
if not hasattr(value, "resolve_expression"):
303+
if not isinstance(value, Expression):
299304
value = JSONValue(value)
300305

301306
exprs.append(value)
@@ -394,19 +399,20 @@ def __init__(
394399
self,
395400
expression: ExpressionArgument,
396401
to_add: dict[
397-
str, ExpressionArgument | float | int | dt.date | dt.time | dt.datetime
402+
ExpressionArgument,
403+
ExpressionArgument | float | int | dt.date | dt.time | dt.datetime,
398404
],
399405
) -> None:
400406
from django_mysql.models.fields import DynamicField
401407

402408
expressions = [expression]
403409
for name, value in to_add.items():
404-
if not hasattr(name, "resolve_expression"):
410+
if not isinstance(name, Expression):
405411
name = Value(name)
406412

407-
if isinstance(value, dict): # type: ignore [unreachable]
413+
if isinstance(value, dict):
408414
raise ValueError("ColumnAdd with nested values is not supported")
409-
if not hasattr(value, "resolve_expression"):
415+
if not isinstance(value, Expression):
410416
value = Value(value)
411417

412418
expressions.extend((name, value))

0 commit comments

Comments
 (0)