Skip to content

Commit 32a4d61

Browse files
author
Timon Viola
committed
feat(operators): add mlflow integration
1 parent ccffe74 commit 32a4d61

File tree

8 files changed

+217
-229
lines changed

8 files changed

+217
-229
lines changed

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,7 @@ ignore = [
243243
"D214",
244244
"D215",
245245
"E731", # Do not assign a lambda expression, use a def
246-
"TCH003", # Do not move imports from stdlib to TYPE_CHECKING block
247-
"PT004", # Fixture does not return anything, add leading underscore
248-
"PT005", # Fixture returns a value, remove leading underscore
246+
"TC003", # Do not move imports from stdlib to TYPE_CHECKING block
249247
"PT006", # Wrong type of names in @pytest.mark.parametrize
250248
"PT007", # Wrong type of values in @pytest.mark.parametrize
251249
"PT011", # pytest.raises() is too broad, set the match parameter

src/dagcellent/operators/mlflow/_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
"""Mlflow operator."""
1+
"""Mlflow helpers."""
22
# ruff: noqa: G004
3+
# typing: ignore
34
from __future__ import annotations
45

56
from enum import Enum
@@ -11,15 +12,15 @@
1112

1213

1314
class SlimModelVersion(TypedDict):
14-
"""Slim, json serializable type of mlflow.entities.model_registry.ModelVersion."""
15+
"""Slim, JSON serializable type of mlflow.entities.model_registry.ModelVersion."""
1516

1617
name: str
1718
version: str
1819
run_id: str
1920
tags: dict[str, str]
2021

2122

22-
def _serialize_model_version(
23+
def serialize_model_version(
2324
model_version: mlflow.entities.model_registry.ModelVersion,
2425
) -> SlimModelVersion:
2526
return {

src/dagcellent/operators/mlflow/get_latest_model_version.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

src/dagcellent/operators/mlflow/get_model_meta_data.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

src/dagcellent/operators/mlflow/get_model_version_by_name_and_stage.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

src/dagcellent/operators/mlflow/hooks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
from typing_extensions import ParamSpec
1515

1616
if TYPE_CHECKING:
17-
from collections.abc import Sequence
18-
1917
# NOTE ruff fails for this check
2018
import mlflow.entities.model_registry # noqa: TCH004
2119

20+
from dagcellent.operators.mlflow._utils import MlflowModelStage
21+
2222

2323
_LOGGER = logging.getLogger(__name__)
2424

@@ -120,7 +120,7 @@ def transition_model_version_stage(
120120
def get_latest_versions(
121121
self: MlflowHook,
122122
name: str,
123-
stages: Sequence[str],
123+
stages: list[str],
124124
) -> list[mlflow.entities.model_registry.ModelVersion]:
125125
"""Get latest model version. See MLFlow docs.
126126

0 commit comments

Comments
 (0)