
Commit 044c86f

ITEP-69976: Remove MLFlow from OTX trainer (#490)
1 parent ca00dfa commit 044c86f

23 files changed: +2421 −2407 lines

deploy/charts/geti-tools/chart/charts/seaweed-fs/values.yaml

Lines changed: 1 addition & 0 deletions
@@ -289,6 +289,7 @@ services:
       - "Read:temporaryfiles"
       - "Write:mlflowexperiments"
       - "List:mlflowexperiments"
+      - "Read:mlflowexperiments"
       - "List:vpsreferencefeatures"
       - "Read:vpsreferencefeatures"
       - "Read:pretrainedweights"

interactive_ai/workflows/geti_domain/common/jobs_common/k8s_helpers/trainer_pod_definition.py

Lines changed: 13 additions & 17 deletions
@@ -40,22 +40,10 @@


 def _create_sidecar_env(
-    organization_id: str,
-    workspace_id: str,
-    project_id: str,
-    job_id: str,
+    identifier_json: str,
     namespace: str,
     role: str = "training_operator",
 ) -> list[V1EnvVar]:
-    identifier_json = json.dumps(
-        {
-            "organization_id": organization_id,
-            "workspace_id": workspace_id,
-            "project_id": project_id,
-            "job_id": job_id,
-        }
-    )
-
     # NOTE: vars below is inherited by the Flyte task who renders this sidecar
     var_s3_host = V1EnvVar(
         name="S3_HOST",
@@ -165,15 +153,21 @@ def create_flyte_container_task( # noqa: PLR0913
     sidecar_container_image = trainer_image_info.to_sidecar_image_full_name()
     logger.info(f"Create sidecar_container_image={sidecar_container_image}")

+    identifier_json = json.dumps(
+        {
+            "organization_id": str(session.organization_id),
+            "workspace_id": str(session.workspace_id),
+            "project_id": project_id,
+            "job_id": job_id,
+        }
+    )
+
     env_from = [
         V1EnvFromSource(config_map_ref=V1ConfigMapEnvSource(name=f"{namespace}-feature-flags")),
         V1EnvFromSource(config_map_ref=V1ConfigMapEnvSource(name=f"{namespace}-s3-bucket-names")),
     ]
     sidecar_env = _create_sidecar_env(
-        organization_id=str(session.organization_id),
-        workspace_id=str(session.workspace_id),
-        project_id=project_id,
-        job_id=job_id,
+        identifier_json=identifier_json,
         namespace=namespace,
     )

@@ -218,6 +212,8 @@ def create_flyte_container_task( # noqa: PLR0913
         name=PRIMARY_CONTAINER_NAME,
         image=primary_container_image,
         env=[
+            # Identifier JSON
+            V1EnvVar(name="IDENTIFIER_JSON", value=identifier_json),
             V1EnvVar(name="SHARD_FILES_DIR", value="/shard_files"),
             V1EnvVar(name="MLFLOW_TRACKING_URI", value="http://localhost:5000"),
             V1EnvVar(name="MLFLOW_EXPERIMENT_ID", value=project_id),

interactive_ai/workflows/geti_domain/common/jobs_common_extras/mlflow/adapters/geti_otx_interface.py

Lines changed: 10 additions & 34 deletions
@@ -3,7 +3,6 @@

 """This module defines a command to prepare MLFlow Experiment directory in the S3 bucket."""

-import io
 import json
 import logging
 import os
@@ -12,7 +11,6 @@
 from typing import Any

 import numpy as np
-import pyarrow as pa
 from geti_telemetry_tools import unified_tracing
 from geti_types import ProjectIdentifier
 from iai_core.adapters.binary_interpreters import RAWBinaryInterpreter
@@ -24,7 +22,6 @@
 from iai_core.entities.model import Model, ModelFormat, ModelOptimizationType, ModelStatus
 from iai_core.repos.model_repo import ModelRepo
 from iai_core.repos.project_repo import ProjectRepo
-from pandas import DataFrame

 # NOTE: workaround for CVS-156400 -> the following imports are needed for the workaround
 from jobs_common.tasks.utils.progress import report_progress
@@ -40,7 +37,6 @@
     MLFlowLifecycleStage,
     MLFlowRunStatus,
 )
-from jobs_common_extras.mlflow.adapters.metrics_mapper import PerformanceDeserializer
 from jobs_common_extras.mlflow.repos.binary_repo import MLFlowExperimentBinaryRepo

 logger = logging.getLogger(__name__)
@@ -445,42 +441,25 @@ def pull_metrics(self) -> Performance | None:
         :return: Performance object, or None if it cannot be loaded.
         """

-        # Metrics can be found either in outputs/models/performance.pickle or live_metrics/metrics.arrow
-        model_prefix = os.path.join(self.dst_path_prefix, "outputs", "models")
-        performance_filepath = os.path.join(model_prefix, "performance-json.bin")
+        # Metrics can be found in live_metrics/metrics.json
         live_metrics_prefix = os.path.join(self.dst_path_prefix, "live_metrics")
-        metrics_filepath = os.path.join(live_metrics_prefix, "metrics.arrow")
+        metrics_filepath = os.path.join(live_metrics_prefix, "metrics.json")

         performance: Performance | None = None
-        if self.binary_repo.exists(performance_filepath):
-            logger.info("Reading performance metrics from %s", performance_filepath)
+        if self.binary_repo.exists(metrics_filepath):
+            logger.info("Reading performance metrics from %s", metrics_filepath)
             try:
                 data = self.binary_repo.get_by_filename(
-                    filename=performance_filepath,
+                    filename=metrics_filepath,
                     binary_interpreter=RAWBinaryInterpreter(),
                 )
-                performance = PerformanceDeserializer.backward(json.loads(data.decode()))
-            except Exception:
-                logger.exception(f"Failed to extract performance metrics from {performance_filepath}")
-        elif self.binary_repo.exists(metrics_filepath):
-            logger.info("Reading performance metrics from %s", metrics_filepath)
-            try:
-                obj = self.binary_repo.storage_client.client.get_object(  # type: ignore
-                    bucket_name=self.binary_repo.storage_client.bucket_name,  # type: ignore
-                    object_name=os.path.join(
-                        self.binary_repo.storage_client.object_name_base,  # type: ignore[attr-defined]
-                        metrics_filepath,
-                    ),  # type: ignore
-                )
-                table = pa.ipc.RecordBatchFileReader(io.BytesIO(obj.data)).read_all()
-                data_frame = table.to_pandas()
-                performance = self._create_performance_from_arrow(data_frame)
+                metrics_json = json.loads(data)
+                performance = self._create_performance_from_json(metrics_json)
             except Exception:
                 logger.exception(f"Failed to extract performance metrics from {metrics_filepath}")
         else:
             logger.error(
-                "Cannot find any file to extract performance metrics; both `%s` and `%s` are missing.",
-                performance_filepath,
+                "Cannot find file to extract performance metrics; `%s` is missing.",
                 metrics_filepath,
             )

@@ -550,12 +529,9 @@ def _create_progress_json(self) -> dict[str, str | float]:
             "progress": 0.0,
         }

-    def _create_performance_from_arrow(self, data_frame: DataFrame) -> Performance:
-        grouped = data_frame.groupby("key")
-
+    def _create_performance_from_json(self, metrics_json: dict[str, list[float]]) -> Performance:
         dashboard_metrics = []
-        for name, group in grouped:
-            ys = group["value"].tolist()
+        for name, ys in metrics_json.items():
             xs = [float(x) for x in range(1, len(ys) + 1)]
             metric = CurveMetric(name=name, ys=ys, xs=xs)
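
As a hedged illustration of the new format (the metric names below are made up), live_metrics/metrics.json is a flat mapping from metric name to the per-step values, which _create_performance_from_json turns into one curve per key:

# Illustrative payload only; actual metric names depend on what the trainer logs.
import json

metrics_json = json.loads('{"train/loss": [0.9, 0.6, 0.4], "val/accuracy": [0.71, 0.78, 0.81]}')

for name, ys in metrics_json.items():
    xs = [float(x) for x in range(1, len(ys) + 1)]  # implicit 1..N x-axis, as in _create_performance_from_json
    print(name, xs, ys)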

interactive_ai/workflows/geti_domain/common/tests/unit/extras/mlflow/test_adapters.py

Lines changed: 2 additions & 24 deletions
@@ -23,7 +23,7 @@ class TestGetiOTXInterfaceAdapter:
     @pytest.fixture()
     def fxt_performance(self):
         return Performance(
-            score=ScoreMetric("dummy", 1.0),
+            score=ScoreMetric(name="Model accuracy", value=0.5),
             dashboard_metrics=[
                 LineMetricsGroup(
                     metrics=[CurveMetric(name="dummy", ys=[1, 2, 3], xs=[1, 2, 3])],
@@ -444,29 +444,7 @@ def test_pull_metrics(
         # Arrange
         mock_project_repo.return_value.get_by_id.return_value = fxt_project
         mock_repo.return_value.organization_id = fxt_organization_id
-        performance_dict = {
-            "dashboard_metrics": [
-                {
-                    "metrics": [
-                        {
-                            "name": "dummy",
-                            "type": "curve",
-                            "xs": [1.0, 2.0, 3.0],
-                            "ys": [1.0, 2.0, 3.0],
-                        }
-                    ],
-                    "visualization_info": {
-                        "name": "dummy",
-                        "palette": "DEFAULT",
-                        "type": "LINE",
-                        "x_axis_label": "x",
-                        "y_axis_label": "y",
-                    },
-                }
-            ],
-            "score": {"label_id": None, "name": "dummy", "type": "score", "value": 1.0},
-            "type": "Performance",
-        }
+        performance_dict = {"dummy": [1.0, 2.0, 3.0]}
         mock_repo.return_value.get_by_filename.return_value = json.dumps(performance_dict).encode()

         # Act

interactive_ai/workflows/otx_domain/trainer/otx_v2/pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@ description = "OTX trainer"
 requires-python = ">=3.10, <3.11"

 dependencies = [
-    "mlflow==2.19.0",
     "minio~=7.1.0",
     "numpy==1.26.4",
     "requests==2.32.3",
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+# Copyright (C) 2022-2025 Intel Corporation
+# LIMITED EDGE SOFTWARE DISTRIBUTION LICENSE
+import json
+import logging
+from argparse import Namespace
+from pathlib import Path
+from typing import Any
+
+from lightning.pytorch.loggers.logger import Logger
+from otx_io import upload_model_artifact
+
+logger = logging.getLogger(__name__)
+
+
+class OTXMetricsLogger(Logger):
+    def __init__(self, file_path: Path):
+        self.file_path = file_path
+        self.metrics: dict[str, list[float]] = {}
+        logger.info(f"Writing live metrics to {file_path}")
+
+    @property
+    def name(self) -> str | None:
+        return None
+
+    @property
+    def version(self) -> int | str | None:
+        return None
+
+    def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None:  # noqa: ARG002
+        for key, value in metrics.items():
+            self.metrics.setdefault(key, []).append(value)
+
+        with open(self.file_path, "w") as f:
+            json.dump(self.metrics, f)
+
+    def log_hyperparams(self, params: dict[str, Any] | Namespace, *args: Any, **kwargs: Any) -> None:
+        pass
+
+    def save(self) -> None:
+        print(self.metrics)
+
+    def finalize(self, status: str) -> None:  # noqa: ARG002
+        upload_model_artifact(src_filepath=self.file_path, dst_filepath=Path("live_metrics/metrics.json"))
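
For context, a minimal sketch of how this logger could be attached to a Lightning Trainer in the OTX trainer entrypoint; the module name, file path, and trainer wiring below are assumptions, not part of this diff:

# Hypothetical wiring; the actual OTX entrypoint may configure the trainer differently.
from pathlib import Path

from lightning.pytorch import Trainer

from otx_metrics_logger import OTXMetricsLogger  # assumed module name for the new file

metrics_logger = OTXMetricsLogger(file_path=Path("/tmp/metrics.json"))
trainer = Trainer(logger=metrics_logger, max_epochs=1)
# Each logged metric is appended to the local JSON file during training,
# and finalize() uploads it to live_metrics/metrics.json when training ends.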

interactive_ai/workflows/otx_domain/trainer/otx_v2/scripts/minio_util.py

Lines changed: 0 additions & 105 deletions
This file was deleted.
