Commit 2e78c7f

Merge remote-tracking branch 'upstream/main' into files_folder_pattern

2 parents 62efeec + 003041f

13 files changed: +481 -37 lines
.gitignore

Lines changed: 3 additions & 0 deletions

```diff
@@ -32,6 +32,9 @@ venv/
 # Aim
 .aim
 
+# Mlflow
+mlruns/
+
 # Backup files and folders
 *.bkp
 *.bkp.*
```

README.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -823,12 +823,13 @@ For details about how you can use set a custom stopping criteria and perform cus
 
 ## Experiment Tracking
 
-Experiment tracking in fms-hf-tuning allows users to track their experiments with known trackers like [Aimstack](https://aimstack.io/) or custom trackers built into the code like
+Experiment tracking in fms-hf-tuning allows users to track their experiments with known trackers like [Aimstack](https://aimstack.io/), [MLflow Tracking](https://mlflow.org/docs/latest/tracking.html) or custom trackers built into the code like
 [FileLoggingTracker](./tuning/trackers/filelogging_tracker.py)
 
 The code supports currently two trackers out of the box,
 * `FileLoggingTracker` : A built in tracker which supports logging training loss to a file.
 * `Aimstack` : A popular opensource tracker which can be used to track any metrics or metadata from the experiments.
+* `MLflow Tracking` : Another popular opensource tracker which stores metrics, metadata or even artifacts from experiments.
 
 Further details on enabling and using the trackers mentioned above can be found [here](docs/experiment-tracking.md).
```

build/Dockerfile

Lines changed: 6 additions & 1 deletion

```diff
@@ -19,8 +19,9 @@ ARG USER=tuning
 ARG USER_UID=1000
 ARG PYTHON_VERSION=3.11
 ARG WHEEL_VERSION=""
-## Enable Aimstack if requested via ENABLE_AIM set to "true"
+## Enable Aimstack or MLflow if requested via ENABLE_AIM/MLFLOW set to "true"
 ARG ENABLE_AIM=false
+ARG ENABLE_MLFLOW=false
 ARG ENABLE_FMS_ACCELERATION=true
 
 ## Base Layer ##################################################################
@@ -151,6 +152,10 @@ RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \
     python -m pip install --user "$(head bdist_name)[aim]"; \
     fi
 
+RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \
+    python -m pip install --user "$(head bdist_name)[mlflow]"; \
+    fi
+
 # Clean up the wheel module. It's only needed by flash-attn install
 RUN python -m pip uninstall wheel build -y && \
     # Cleanup the bdist whl file
```
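Not part of the commit, but for context: the new build arg is used the same way as the existing `ENABLE_AIM` one. A minimal sketch of how the guard behaves (the image tag is illustrative, and the hard-coded wheel name below is a placeholder — the Dockerfile actually reads it from the `bdist_name` file):

```shell
# Build the image with the MLflow extra baked in (tag illustrative):
#   docker build --build-arg ENABLE_MLFLOW=true -f build/Dockerfile -t fms-hf-tuning:mlflow .

# The RUN guard above reduces to this shell pattern: the [mlflow] extra is
# appended to the wheel name only when the build arg is the string "true".
ENABLE_MLFLOW=true
bdist_name="fms_hf_tuning"   # placeholder for $(head bdist_name)
pip_target="${bdist_name}"
if [ "${ENABLE_MLFLOW}" = "true" ]; then
    pip_target="${bdist_name}[mlflow]"
fi
echo "${pip_target}"
```

Any value other than the exact string `"true"` (including unset) leaves the extras out, keeping the default image slim.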

docs/experiment-tracking.md

Lines changed: 30 additions & 2 deletions

````diff
@@ -115,6 +115,34 @@ sft_trainer.train(train_args=training_args, tracker_configs=tracker_configs,....
 The code expects either the `local` or `remote` repo to be specified and will result in a `ValueError` otherwise.
 See [AimConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/a9b8ec8d1d50211873e63fa4641054f704be8712/tuning/config/tracker_configs.py#L25) for more details.
 
+## MLflow Tracker
+
+To enable [MLflow Tracking](https://mlflow.org/docs/latest/tracking.html) users need to pass `"mlflow"` as the requested tracker as part of the [training argument](https://github.com/foundation-model-stack/fms-hf-tuning/blob/a9b8ec8d1d50211873e63fa4641054f704be8712/tuning/config/configs.py#L131).
+
+When using MLflow, users need to specify additional arguments which specify [mlflow tracking uri](https://mlflow.org/docs/latest/tracking.html#common-setups) location where either a [mlflow supported database](https://mlflow.org/docs/latest/tracking/backend-stores.html#supported-store-types) or [mlflow remote tracking server](https://mlflow.org/docs/latest/tracking/server.html) is running.
+
+Example
+```
+from tuning import sft_trainer
+from tuning.config.tracker_configs import MLflowConfig, TrackerConfigFactory
+
+training_args = TrainingArguments(
+    ...,
+    trackers = ["mlflow"],
+)
+
+tracker_configs = TrackerConfigFactory(
+    mlflow_config=MLflowConfig(
+        mlflow_experiment="experiment-name",
+        mlflow_tracking_uri=<tracking uri>
+    )
+)
+
+sft_trainer.train(train_args=training_args, tracker_configs=tracker_configs,....)
+```
+
+The code expects a valid uri to be specified and will result in a `ValueError` otherwise.
 
 ## Running the code via command line `tuning/sft_trainer::main` function
 
@@ -123,10 +151,10 @@ If running the code via main function of [sft_trainer.py](../tuning/sft_trainer.
 To enable tracking please pass
 
 ```
---tracker <aim/file_logger>
+--tracker <aim/file_logger/mlflow>
 ```
 
-To further customise tracking you can specify additional arguments needed by the tracker like
+To further customise tracking you can specify additional arguments needed by the tracker like (example shows aim follow similarly for mlflow)
 
 ```
 --tracker aim --aim_repo <path-to-aimrepo> --experiment <experiment-name>
````
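Beyond the docs added above, the new tracker also exports the MLflow run uri as JSON (per the config comments and tests in this commit: a `mlflow_tracker.json` file containing a `run_uri` key, written to the export path or `output_dir`). A minimal sketch of consuming that file downstream — the run-uri value below is invented for illustration, and any payload beyond the `run_uri` key is an assumption:

```python
import json
import os
import tempfile

# Simulate what the tracker writes: <output_dir>/mlflow_tracker.json
# containing at least a "run_uri" key (example value is made up).
output_dir = tempfile.mkdtemp()
export_path = os.path.join(output_dir, "mlflow_tracker.json")
with open(export_path, "w", encoding="utf-8") as f:
    json.dump({"run_uri": "mlflow-runs:/0/abc123"}, f)

# A downstream script can then recover the run uri like this:
with open(export_path, "r", encoding="utf-8") as f:
    content = json.loads(f.read())
run_uri = content["run_uri"]
print(run_uri)  # mlflow-runs:/0/abc123
```

This mirrors the assertions in the new `test_e2e_run_with_mlflow_runuri_export_default_path` test, which checks the file exists, is non-empty, and contains `"run_uri"`.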

pyproject.toml

Lines changed: 1 addition & 0 deletions

```diff
@@ -44,6 +44,7 @@ dependencies = [
 dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<25", "ninja>=1.11.1.1,<2.0", "scikit-learn>=1.0, <2.0", "boto3>=1.34, <2.0"]
 flash-attn = ["flash-attn>=2.5.3,<3.0"]
 aim = ["aim>=3.19.0,<4.0"]
+mlflow = ["mlflow"]
 fms-accel = ["fms-acceleration>=0.1"]
 gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]
```

tests/test_sft_trainer.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -361,6 +361,7 @@ def test_parse_arguments(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_copy)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert data_args.dataset_text_field == "output"
@@ -386,6 +387,7 @@ def test_parse_arguments_defaults(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_defaults)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert model_args.use_flash_attn is False
@@ -396,14 +398,14 @@ def test_parse_arguments_peft_method(job_config):
     parser = sft_trainer.get_parser()
     job_config_pt = copy.deepcopy(job_config)
     job_config_pt["peft_method"] = "pt"
-    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_pt
     )
     assert isinstance(tune_config, peft_config.PromptTuningConfig)
 
     job_config_lora = copy.deepcopy(job_config)
     job_config_lora["peft_method"] = "lora"
-    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_lora
    )
    assert isinstance(tune_config, peft_config.LoraConfig)
```

tests/trackers/test_aim_tracker.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -58,7 +58,7 @@ def fixture_aimrepo():
 
 
 @pytest.mark.skipif(aim_not_available, reason="Requires aimstack to be installed")
-def test_run_with_good_tracker_name_but_no_args():
+def test_run_with_aim_tracker_name_but_no_args():
     """Ensure that train() raises error with aim tracker name but no args"""
 
     with tempfile.TemporaryDirectory() as tempdir:
```
Lines changed: 138 additions & 0 deletions (new file)

```python
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Standard
import copy
import json
import os
import tempfile

# Third Party
from transformers.utils.import_utils import _is_package_available
import pytest

# First Party
from tests.test_sft_trainer import (
    DATA_ARGS,
    MODEL_ARGS,
    TRAIN_ARGS,
    _get_checkpoint_path,
    _test_run_inference,
    _validate_training,
)

# Local
from tuning import sft_trainer
from tuning.config.tracker_configs import MLflowConfig, TrackerConfigFactory

mlflow_not_available = not _is_package_available("mlflow")


@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed")
def test_run_with_mlflow_tracker_name_but_no_args():
    """Ensure that train() raises error with mlflow tracker name but no args"""

    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir

        train_args.trackers = ["mlflow"]

        with pytest.raises(
            ValueError,
            match="mlflow tracker requested but mlflow_uri is not specified.",
        ):
            sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args)


@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed")
def test_e2e_run_with_mlflow_tracker():
    """Ensure that training succeeds with mlflow tracker"""

    # mlflow performs a cleanup at callback close time which happens post the
    # delete of this directory so we run into two issues
    # 1. the temp directory cannot be cleared as it has open pointer by mlflow
    # 2. mlflow complaints that it cannot find a run which it just created.
    # this is a race condition which is fixed with mkdtemp() which doesn't delete
    tempdir = tempfile.mkdtemp()

    train_args = copy.deepcopy(TRAIN_ARGS)
    train_args.output_dir = tempdir

    # This should not mean file logger is not present.
    # code will add it by default
    # The below validate_training check will test for that too.
    train_args.trackers = ["mlflow"]

    mlflow_path = os.path.join(tempdir, "mlflow")

    tracker_configs = TrackerConfigFactory(
        mlflow_config=MLflowConfig(
            mlflow_experiment="unit_test",
            mlflow_tracking_uri=f"file://{mlflow_path}",
        )
    )

    sft_trainer.train(
        MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
    )

    # validate ft tuning configs
    _validate_training(tempdir)

    assert os.path.exists(mlflow_path) and os.path.isdir(mlflow_path)

    # validate inference
    _test_run_inference(checkpoint_path=_get_checkpoint_path(tempdir))


@pytest.mark.skipif(mlflow_not_available, reason="Requires mlflow to be installed")
def test_e2e_run_with_mlflow_runuri_export_default_path():
    """Ensure that mlflow outputs run uri in the output dir by default"""

    tempdir = tempfile.mkdtemp()
    train_args = copy.deepcopy(TRAIN_ARGS)
    train_args.output_dir = tempdir

    train_args.trackers = ["mlflow"]

    mlflow_path = os.path.join(tempdir, "mlflow")

    tracker_configs = TrackerConfigFactory(
        mlflow_config=MLflowConfig(
            mlflow_experiment="unit_test",
            mlflow_tracking_uri=f"file://{mlflow_path}",
        )
    )

    sft_trainer.train(
        MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
    )

    # validate ft tuning configs
    _validate_training(tempdir)

    assert os.path.exists(mlflow_path) and os.path.isdir(mlflow_path)

    run_uri_file = os.path.join(tempdir, "mlflow_tracker.json")

    assert os.path.exists(run_uri_file) is True
    assert os.path.getsize(run_uri_file) > 0

    with open(run_uri_file, "r", encoding="utf-8") as f:
        content = json.loads(f.read())
        assert "run_uri" in content
```

tuning/config/tracker_configs.py

Lines changed: 17 additions & 4 deletions

```diff
@@ -24,7 +24,7 @@ class FileLoggingTrackerConfig:
 @dataclass
 class AimConfig:
     # Name of the experiment
-    experiment: str = None
+    experiment: str = "fms-hf-tuning"
     # aim_repo can point to a locally accessible directory
     # or a remote repository hosted on a server.
     # When 'aim_remote_server_ip' or 'aim_remote_server_port' is set,
@@ -47,9 +47,6 @@ class AimConfig:
     aim_run_id_export_path: str = None
 
     def __post_init__(self):
-        if self.experiment is None:
-            self.experiment = "fms-hf-tuning"
-
         if (
             self.aim_remote_server_ip is not None
             and self.aim_remote_server_port is not None
@@ -63,7 +60,23 @@ def __post_init__(self):
         )
 
 
+@dataclass
+class MLflowConfig:
+    # Name of the experiment
+    mlflow_experiment: str = "fms-hf-tuning"
+    mlflow_tracking_uri: str = None
+    # Location of where mlflow's run uri is to be exported.
+    # If mlflow_run_uri_export_path is set the run uri will be output in a json format
+    # to the location pointed to by `mlflow_run_uri_export_path/mlflow_tracker.json`
+    # If this is not set then the default location where run uri will be exported
+    # is training_args.output_dir/mlflow_tracker.json
+    # Run uri is not exported if mlflow_run_uri_export_path variable is not set
+    # and output_dir is not specified.
+    mlflow_run_uri_export_path: str = None
+
+
 @dataclass
 class TrackerConfigFactory:
     file_logger_config: FileLoggingTrackerConfig = None
     aim_config: AimConfig = None
+    mlflow_config: MLflowConfig = None
```
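The diff above also simplifies `AimConfig`: the default experiment name moves from a `__post_init__` backfill into the field default, and `MLflowConfig` adopts the same pattern. A standalone sketch of the resulting behavior (class body abridged to the fields shown in the diff; the real class lives in `tuning/config/tracker_configs.py`):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MLflowConfig:
    # Field default replaces the old "if None: backfill" __post_init__ idiom:
    # every instance gets the experiment name unless the caller overrides it.
    mlflow_experiment: str = "fms-hf-tuning"
    mlflow_tracking_uri: Optional[str] = None
    mlflow_run_uri_export_path: Optional[str] = None

cfg = MLflowConfig(mlflow_tracking_uri="file:///tmp/mlruns")
print(cfg.mlflow_experiment)  # fms-hf-tuning
```

Putting the default on the field (rather than in `__post_init__`) keeps `repr`, `==`, and generated `__init__` signatures honest about what the effective default is.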
