Skip to content

Commit e1c3e0f

Browse files
authored
Update pre-commit hooks (#219)
1 parent 6ac2991 commit e1c3e0f

File tree

12 files changed

+52
-52
lines changed

12 files changed

+52
-52
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
repos:
55
- repo: https://github.com/pre-commit/pre-commit-hooks
6-
rev: v5.0.0
6+
rev: v6.0.0
77
hooks:
88
- id: check-ast
99
- id: check-toml
@@ -18,15 +18,15 @@ repos:
1818
# - id: no-commit-to-branch
1919

2020
- repo: https://github.com/pre-commit/mirrors-mypy
21-
rev: 'v1.11.2'
21+
rev: 'v1.19.1'
2222
hooks:
2323
- id: mypy
2424
additional_dependencies:
2525
- fastapi
2626
- pytest
2727

2828
- repo: https://github.com/astral-sh/ruff-pre-commit
29-
rev: 'v0.6.9'
29+
rev: 'v0.14.10'
3030
hooks:
3131
- id: ruff
3232
args: [--fix]

pyproject.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,15 @@ line-length = 100
5454
# The D (doc) and DTZ (datetime zone) lint classes current heavily violated - fix later
5555
select = ["ALL"]
5656
ignore = [
57-
"ANN101", # style choice - no annotation for self
58-
"ANN102", # style choice - no annotation for cls
5957
"CPY", # we do not require copyright in every file
6058
"D", # todo: docstring linting
6159
"D203",
6260
"D204",
6361
"D213",
6462
"DTZ", # To add
6563
# Linter does not detect when types are used for Pydantic
66-
"TCH001",
67-
"TCH003",
64+
"TC001",
65+
"TC003",
6866
]
6967

7068
[tool.ruff.lint.per-file-ignores]

src/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ def _apply_defaults_to_siblings(configuration: TomlTable) -> TomlTable:
2222

2323
@functools.cache
2424
def _load_configuration(file: Path) -> TomlTable:
25-
return typing.cast(TomlTable, tomllib.loads(file.read_text()))
25+
return tomllib.loads(file.read_text())
2626

2727

2828
def load_routing_configuration(file: Path = CONFIG_PATH) -> TomlTable:
29-
return typing.cast(TomlTable, _load_configuration(file)["routing"])
29+
return typing.cast("TomlTable", _load_configuration(file)["routing"])
3030

3131

3232
@functools.cache

src/database/evaluations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]:
1111
return cast(
12-
Sequence[Row],
12+
"Sequence[Row]",
1313
connection.execute(
1414
text(
1515
"""

src/database/flows.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
def get_subflows(for_flow: int, expdb: Connection) -> Sequence[Row]:
88
return cast(
9-
Sequence[Row],
9+
"Sequence[Row]",
1010
expdb.execute(
1111
text(
1212
"""
@@ -36,7 +36,7 @@ def get_tags(flow_id: int, expdb: Connection) -> list[str]:
3636

3737
def get_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
3838
return cast(
39-
Sequence[Row],
39+
"Sequence[Row]",
4040
expdb.execute(
4141
text(
4242
"""

src/database/studies.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
4343
"""
4444
if study.type_ == StudyType.TASK:
4545
return cast(
46-
Sequence[Row],
46+
"Sequence[Row]",
4747
expdb.execute(
4848
text(
4949
"""
@@ -56,7 +56,7 @@ def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
5656
).all(),
5757
)
5858
return cast(
59-
Sequence[Row],
59+
"Sequence[Row]",
6060
expdb.execute(
6161
text(
6262
"""
@@ -103,7 +103,7 @@ def create(study: CreateStudy, user: User, expdb: Connection) -> int:
103103
},
104104
)
105105
(study_id,) = expdb.execute(text("""SELECT LAST_INSERT_ID();""")).one()
106-
return cast(int, study_id)
106+
return cast("int", study_id)
107107

108108

109109
def attach_task(task_id: int, study_id: int, user: User, expdb: Connection) -> None:

src/database/tasks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def get(id_: int, expdb: Connection) -> Row | None:
1919

2020
def get_task_types(expdb: Connection) -> Sequence[Row]:
2121
return cast(
22-
Sequence[Row],
22+
"Sequence[Row]",
2323
expdb.execute(
2424
text(
2525
"""
@@ -46,7 +46,7 @@ def get_task_type(task_type_id: int, expdb: Connection) -> Row | None:
4646

4747
def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Row]:
4848
return cast(
49-
Sequence[Row],
49+
"Sequence[Row]",
5050
expdb.execute(
5151
text(
5252
"""
@@ -62,7 +62,7 @@ def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Ro
6262

6363
def get_input_for_task(id_: int, expdb: Connection) -> Sequence[Row]:
6464
return cast(
65-
Sequence[Row],
65+
"Sequence[Row]",
6666
expdb.execute(
6767
text(
6868
"""
@@ -78,7 +78,7 @@ def get_input_for_task(id_: int, expdb: Connection) -> Sequence[Row]:
7878

7979
def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> Sequence[Row]:
8080
return cast(
81-
Sequence[Row],
81+
"Sequence[Row]",
8282
expdb.execute(
8383
text(
8484
"""

src/routers/openml/tasks.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import re
33
from http import HTTPStatus
4-
from typing import Annotated, Any
4+
from typing import Annotated, cast
55

66
import xmltodict
77
from fastapi import APIRouter, Depends, HTTPException
@@ -15,22 +15,24 @@
1515

1616
router = APIRouter(prefix="/tasks", tags=["tasks"])
1717

18+
type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None
1819

19-
def convert_template_xml_to_json(xml_template: str) -> Any: # noqa: ANN401
20+
21+
def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]:
2022
json_template = xmltodict.parse(xml_template.replace("oml:", ""))
2123
json_str = json.dumps(json_template)
2224
# To account for the differences between PHP and Python conversions:
2325
for py, php in [("@name", "name"), ("#text", "value"), ("@type", "type")]:
2426
json_str = json_str.replace(py, php)
25-
return json.loads(json_str)
27+
return cast("dict[str, JSON]", json.loads(json_str))
2628

2729

2830
def fill_template(
2931
template: str,
3032
task: RowMapping,
31-
task_inputs: dict[str, str],
33+
task_inputs: dict[str, str | int],
3234
connection: Connection,
33-
) -> Any: # noqa: ANN401
35+
) -> dict[str, JSON]:
3436
"""Fill in the XML template as used for task descriptions and return the result,
3537
converted to JSON.
3638
@@ -79,22 +81,25 @@ def fill_template(
7981
}
8082
"""
8183
json_template = convert_template_xml_to_json(template)
82-
return _fill_json_template(
83-
json_template,
84-
task,
85-
task_inputs,
86-
fetched_data={},
87-
connection=connection,
84+
return cast(
85+
"dict[str, JSON]",
86+
_fill_json_template(
87+
json_template,
88+
task,
89+
task_inputs,
90+
fetched_data={},
91+
connection=connection,
92+
),
8893
)
8994

9095

9196
def _fill_json_template(
92-
template: dict[str, Any],
97+
template: JSON,
9398
task: RowMapping,
94-
task_inputs: dict[str, str],
95-
fetched_data: dict[str, Any],
99+
task_inputs: dict[str, str | int],
100+
fetched_data: dict[str, str],
96101
connection: Connection,
97-
) -> dict[str, Any] | list[dict[str, Any]] | str:
102+
) -> JSON:
98103
if isinstance(template, dict):
99104
return {
100105
k: _fill_json_template(v, task, task_inputs, fetched_data, connection)
@@ -115,7 +120,7 @@ def _fill_json_template(
115120
if match.string == template:
116121
# How do we know the default value? probably ttype_io table?
117122
return task_inputs.get(field, [])
118-
template = template.replace(match.group(), task_inputs[field])
123+
template = template.replace(match.group(), str(task_inputs[field]))
119124
if match := re.search(r"\[LOOKUP:(.*)]", template):
120125
(field,) = match.groups()
121126
if field not in fetched_data:
@@ -176,7 +181,7 @@ def get_task(
176181
tags = database.tasks.get_tags(task_id, expdb)
177182
name = f"Task {task_id} ({task_type.name})"
178183
dataset_id = task_inputs.get("source_data")
179-
if dataset_id and (dataset := database.datasets.get(dataset_id, expdb)):
184+
if isinstance(dataset_id, int) and (dataset := database.datasets.get(dataset_id, expdb)):
180185
name = f"Task {task_id}: {dataset.name} ({task_type.name})"
181186

182187
return Task(

src/routers/openml/tasktype.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ def get_task_type(
5353
task_type = _normalize_task_type(task_type_record)
5454
# Some names are quoted, or have typos in their comma-separation (e.g. 'A ,B')
5555
task_type["creator"] = [
56-
creator.strip(' "') for creator in cast(str, task_type["creator"]).split(",")
56+
creator.strip(' "') for creator in cast("str", task_type["creator"]).split(",")
5757
]
5858
if contributors := task_type.pop("contributors"):
5959
task_type["contributor"] = [
60-
creator.strip(' "') for creator in cast(str, contributors).split(",")
60+
creator.strip(' "') for creator in cast("str", contributors).split(",")
6161
]
6262
task_type["creation_date"] = task_type.pop("creationDate")
6363
task_type_inputs = get_input_for_task_type(task_type_id, expdb)

src/schemas/datasets/mldcat_ap.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from abc import ABC
1212
from enum import StrEnum
13-
from typing import Generic, Literal, TypeVar
13+
from typing import Literal
1414

1515
from pydantic import BaseModel, Field, HttpUrl, field_serializer, model_serializer
1616

@@ -41,10 +41,7 @@ class JsonLDObject(BaseModel, ABC):
4141
}
4242

4343

44-
T = TypeVar("T", bound=JsonLDObject)
45-
46-
47-
class JsonLDObjectReference(BaseModel, Generic[T]):
44+
class JsonLDObjectReference[T: JsonLDObject](BaseModel):
4845
id_: str = Field(serialization_alias="@id")
4946

5047
model_config = {"populate_by_name": True, "extra": "forbid"}
@@ -275,7 +272,7 @@ class DataService(JsonLDObject):
275272

276273

277274
class JsonLDGraph(BaseModel):
278-
context: str | dict[str, HttpUrl] = Field(default_factory=dict, serialization_alias="@context") # type: ignore[arg-type]
275+
context: str | dict[str, HttpUrl] = Field(default_factory=dict, serialization_alias="@context")
279276
graph: list[Distribution | DataService | Dataset | Quality | Feature | Agent | MD5Checksum] = (
280277
Field(default_factory=list, serialization_alias="@graph")
281278
)

0 commit comments

Comments
 (0)