Skip to content

Commit 02f9efa

Browse files
author
Timon Viola
committed
test: add mlflow integration tests
1 parent 8784b5c commit 02f9efa

File tree

10 files changed

+201
-104
lines changed

10 files changed

+201
-104
lines changed

.containerignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
!share
2+
*

.pre-commit-config.yaml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@ repos:
1515
- id: typos
1616
stages: [commit]
1717
exclude: "^typings/.*"
18-
- repo: https://github.com/pre-commit/mirrors-mypy
19-
rev: v1.10.1 # Use the sha / tag you want to point at
18+
- repo: local
2019
hooks:
21-
- id: mypy
22-
language: system
23-
pass_filenames: false
24-
args: ['.']
20+
- id: mypy-hatch
21+
name: mypy-hatch
22+
language: system
23+
pass_filenames: false
24+
args: ['.']
25+
entry: hatch
26+
args: ["run", "types:check"]
2527
- repo: https://github.com/compilerla/conventional-pre-commit
2628
rev: v3.3.0
2729
hooks:

docker-compose.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@ version: '3.8'
99
services:
1010
airflow:
1111
build: .
12+
#image: pytest74406-airflow:latest
1213
ports:
1314
- "8080:8080"
1415
container_name: airflow
1516
restart: always
16-
depends_on:
17-
- db
1817
volumes:
1918
- ./tests/dags:/opt/airflow/dags
2019
- ./:/opt/dagcellent/

pyproject.toml

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
[build-system]
2-
requires = ["hatchling", "versioningit"]
2+
requires = ["hatchling"]
33
build-backend = "hatchling.build"
44

55
[project]
66
name = "dagcellent"
7-
dynamic = ["version"]
7+
version = "0.2.0"
88
description = ''
99
readme = "README.md"
1010
requires-python = ">=3.11"
@@ -28,7 +28,7 @@ classifiers = [
2828
dependencies = [
2929
"apache-airflow>=2.9.1",
3030
"apache-airflow-providers-amazon>=7.2.0",
31-
"apache-airflow-providers-microsoft-mssql >=3.2.0",
31+
"apache-airflow-providers-microsoft-mssql >=4.0.0",
3232
"apache-airflow-providers-snowflake >= 5.1",
3333
"apache-airflow-providers-common-sql[pandas,openlineage]",
3434
"tomli >= 2.0.1",
@@ -64,18 +64,17 @@ extra-dependencies = [
6464
"pre-commit == 3.7.*",
6565
"ruff == 0.4.4",
6666
"mypy == 1.10.*",
67-
"versioningit == 3.1.*",
6867
"towncrier == 23.11.*",
6968
]
7069

7170
[tool.hatch.envs.dev.scripts]
7271
install = "pre-commit install --hook-type commit-msg"
73-
version = "versioningit"
7472
changelog = "git-cliff -o CHANGELOG.md"
7573

7674
[tool.hatch.envs.types]
7775
extra-dependencies = [
7876
"mypy>=1.0.0",
77+
"types-requests",
7978
]
8079
[tool.hatch.envs.types.scripts]
8180
check = "mypy --install-types --non-interactive {args:src/dagcellent tests}"
@@ -96,16 +95,15 @@ extra-dependencies = [
9695
"pytest >= 8.0.0",
9796
"pytest-cov",
9897
"pytest-mock >= 3.14.0",
98+
"pytest-docker",
9999
]
100100

101101
[[tool.hatch.envs.test.matrix]]
102102
python = ["3.11", "3.12"]
103103

104104
[tool.hatch.envs.test.scripts]
105-
test = "pytest --cov=dagcellent --cov-report=term-missing --cov-report=xml --cov-report=html tests"
105+
test = "pytest --cov=dagcellent --cov-report=term-missing --cov-report=xml --cov-report=html --container-scope module tests"
106106

107-
[tool.hatch.version]
108-
source = "versioningit"
109107

110108
[tool.coverage.run]
111109
source_pkgs = ["dagcellent", "tests"]
@@ -311,24 +309,4 @@ exclude = [
311309
"^typings/*"
312310
]
313311

314-
[tool.versioningit]
315-
default-version = "0.0.0-unknown"
316-
317-
[tool.versioningit.vcs]
318-
# The method key:
319-
method = "git" # <- The method name
320-
321-
# Parameters to pass to the method:
322-
match = ["v*"]
323-
default-tag = "0.0.1"
324-
325-
[tool.hatch.build.hooks.versioningit-onbuild]
326-
source-file = "src/dagcellent/__init__.py"
327-
build-file = "dagcellent/__init__.py"
328-
require-match = true
329-
330-
[tool.versioningit.format]
331-
distance = "{next_version}.dev{distance}+{vcs}{rev}"
332-
dirty = "{version}+dirty"
333-
distance-dirty = "{next_version}.dev{distance}+{vcs}{rev}.dirty"
334312

src/dagcellent/operators/mlflow/hooks.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
def _mlflow_request_wrapper(query: Callable[P, T]) -> T:
3131
"""Wrap requests around MlflowException."""
3232
try:
33-
res = query()
33+
res = query() # type: ignore[call-arg]
3434
except mlflow.MlflowException as exc:
3535
_msg = "Error during mlflow query."
3636
_LOGGER.error(_msg, exc_info=exc)
@@ -63,11 +63,17 @@ def get_latest_model_version(
6363
6464
Returns:
6565
dict: hashmap with version and run_id of latest model
66+
67+
Raises:
68+
ValueError: No model returned with given name.
6669
"""
6770
query = functools.partial(
6871
self.client.search_model_versions, f"name = '{model_name}'"
6972
)
7073
model_reigstry_info = _mlflow_request_wrapper(query)
74+
if len(model_reigstry_info) == 0:
75+
_msg = f"No models found with name {model_name}."
76+
raise ValueError(_msg)
7177
latest_version = functools.reduce(
7278
lambda x, y: x if int(x.version) > int(y.version) else y,
7379
model_reigstry_info,
@@ -110,6 +116,11 @@ def transition_model_version_stage(
110116
Returns:
111117
mlflow.entities.model_registry.ModelVersion: model version
112118
"""
119+
warn(
120+
"This function is deprecated in mlflow 2.9.0",
121+
DeprecationWarning,
122+
stacklevel=2,
123+
)
113124
query = functools.partial(
114125
self.client.transition_model_version_stage,
115126
name,
@@ -135,7 +146,11 @@ def get_latest_versions(
135146
Returns:
136147
list[mlflow.entities.model_registry.ModelVersion]: list of model versions
137148
"""
138-
warn("This is deprecated in mlflow 2.9.0", DeprecationWarning, stacklevel=2)
149+
warn(
150+
"This function is deprecated in mlflow 2.9.0",
151+
DeprecationWarning,
152+
stacklevel=2,
153+
)
139154
query = functools.partial(self.client.get_latest_versions, name, stages)
140155
return _mlflow_request_wrapper(query)
141156

tests/integration/docker-compose.override.mssql.yaml

Lines changed: 0 additions & 39 deletions
This file was deleted.

tests/integration/docker-compose.override.psql.yaml

Lines changed: 0 additions & 27 deletions
This file was deleted.
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
from warnings import warn
5+
6+
import pytest
7+
import requests
8+
from requests.exceptions import ConnectionError
9+
10+
11+
def _is_responsive(url: str):
12+
try:
13+
response = requests.get(url)
14+
if response.status_code == 200:
15+
return True
16+
except ConnectionError:
17+
return False
18+
19+
20+
def _remove_path(path: Path):
21+
if path.is_file() or path.is_symlink():
22+
path.unlink()
23+
return
24+
for p in path.iterdir():
25+
_remove_path(p)
26+
path.rmdir()
27+
28+
29+
@pytest.fixture(scope="session")
30+
def docker_compose_file(pytestconfig: pytest.Config):
31+
return [
32+
str(pytestconfig.rootpath / "docker-compose.yaml"),
33+
str(Path(__file__).parent / "docker-compose.override.mlflow.yaml"),
34+
]
35+
36+
37+
@pytest.fixture(scope="module", autouse=True)
38+
def mlflow_service(docker_ip, docker_services, pytestconfig: pytest.Config):
39+
"""Ensure that HTTP service is up and responsive."""
40+
41+
# `port_for` takes a container port and returns the corresponding host port
42+
port = docker_services.port_for("mlflow", 5000)
43+
url = f"http://{docker_ip}:{port}"
44+
docker_services.wait_until_responsive(
45+
timeout=30.0, pause=0.1, check=lambda: _is_responsive(url)
46+
)
47+
yield url
48+
# after run, remove the root 'mlruns' folder
49+
_p = pytestconfig.rootpath / "mlruns"
50+
if not _p.exists():
51+
return
52+
53+
if _p.is_dir():
54+
_remove_path(_p)
55+
return
56+
warn(
57+
"Could not clean up `mlruns` artifact. This might cause unexpected"
58+
"behaviour in the next test suite execution.",
59+
stacklevel=2,
60+
)
61+
62+
63+
@pytest.fixture(scope="module")
64+
def mlflow_hook(mlflow_service: str):
65+
"""Get MLFlow client wrapper."""
66+
from dagcellent.operators.mlflow.hooks import MlflowHook
67+
68+
return MlflowHook(mlflow_service)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
version: '3.8'
2+
3+
services:
4+
# minio:
5+
# image: minio/minio:latest
6+
# ports:
7+
# - "9000:9000"
8+
# environment:
9+
# - "MINIO_ACCESS_KEY=mlflow_access_key"
10+
# - "MINIO_SECRET_KEY=mlflow_secret_key"
11+
# volumes:
12+
# - minio_data:/data
13+
# command: server /data
14+
#
15+
mlflow:
16+
image: ghcr.io/mlflow/mlflow:v2.20.3
17+
# build:
18+
# context: ..
19+
# dockerfile: ghcr.io/mlflow/mlflow:v2.20.3
20+
# args:
21+
# - MLFLOW_VERSION=">=2.2"
22+
container_name: mlflow
23+
restart: on-failure
24+
ports:
25+
- "5005:5000"
26+
#depends_on:
27+
# - postgres
28+
environment:
29+
- GUNICORN_CMD_ARGS="--bind=0.0.0.0"
30+
# - MLFLOW_S3_ENDPOINT_URL=http://minio:9000
31+
command: mlflow server --backend-store-uri=sqlite:///mlruns.db --default-artifact-root=file:mlruns --host 0.0.0.0 --port 5000
32+
#command: mlflow ui --host 0.0.0.0 --port 8888
33+
# command: mlflow server --backend-store-uri postgresql://mlflow:mlflow@postgres:5432/mlflow --default-artifact-root s3://mlflow/ --host 0.0.0.0
34+
35+
#volumes:
36+
# postgres-db-volume:
37+
# logs:
38+
# minio_data:
39+

0 commit comments

Comments
 (0)