Skip to content

Commit 69cebbb

Browse files
committed
Improve function types
1 parent 78b4087 commit 69cebbb

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/django_mysql/models/functions.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,8 @@ def get(cls, using: str = DEFAULT_DB_ALIAS) -> int:
233233
# database connections in Django, and the reason was not clear
234234
with connections[using].cursor() as cursor:
235235
cursor.execute("SELECT LAST_INSERT_ID()")
236-
return cursor.fetchone()[0]
236+
id_: int = cursor.fetchone()[0]
237+
return id_
237238

238239

239240
# JSON Functions
@@ -322,7 +323,11 @@ def as_sql(
322323
if connection.vendor != "mysql": # pragma: no cover
323324
raise AssertionError("JSONValue only supports MySQL/MariaDB")
324325
json_string = json.dumps(self._data, allow_nan=False)
325-
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
326+
if (
327+
connection.vendor == "mysql"
328+
# type narrowed by vendor check
329+
and connection.mysql_is_mariadb # type: ignore [attr-defined]
330+
):
326331
# MariaDB doesn't support explicit cast to JSON.
327332
return "JSON_EXTRACT(%s, '$')", (json_string,)
328333
else:
@@ -334,7 +339,7 @@ def __init__(
334339
self,
335340
expression: ExpressionArgument,
336341
data: dict[
337-
str,
342+
ExpressionArgument,
338343
(
339344
ExpressionArgument
340345
| None
@@ -352,12 +357,12 @@ def __init__(
352357
exprs = [expression]
353358

354359
for path, value in data.items():
355-
if not hasattr(path, "resolve_expression"):
360+
if not isinstance(path, Expression):
356361
path = Value(path)
357362

358363
exprs.append(path)
359364

360-
if not hasattr(value, "resolve_expression"):
365+
if not isinstance(value, Expression):
361366
value = JSONValue(value)
362367

363368
exprs.append(value)
@@ -456,19 +461,20 @@ def __init__(
456461
self,
457462
expression: ExpressionArgument,
458463
to_add: dict[
459-
str, ExpressionArgument | float | int | dt.date | dt.time | dt.datetime
464+
ExpressionArgument,
465+
ExpressionArgument | float | int | dt.date | dt.time | dt.datetime,
460466
],
461467
) -> None:
462468
from django_mysql.models.fields import DynamicField
463469

464470
expressions = [expression]
465471
for name, value in to_add.items():
466-
if not hasattr(name, "resolve_expression"):
472+
if not isinstance(name, Expression):
467473
name = Value(name)
468474

469-
if isinstance(value, dict): # type: ignore [unreachable]
475+
if isinstance(value, dict):
470476
raise ValueError("ColumnAdd with nested values is not supported")
471-
if not hasattr(value, "resolve_expression"):
477+
if not isinstance(value, Expression):
472478
value = Value(value)
473479

474480
expressions.extend((name, value))

0 commit comments

Comments
 (0)