diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 564e216492c4e..a090f8e6aabd8 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -18,5 +18,6 @@ fastapi # for `ServableModuleValidator` # not setting version as re-defined in uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App tensorboard >=2.11, <2.21.0 # for `TensorBoardLogger` +mlflow >=3.0.0, <4.0.0 # for `MLFlowLogger` torch-tensorrt; platform_system == "Linux" and python_version >= "3.12" diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index ff9b2b0d7e542..7f5c95b58c48c 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -233,10 +233,9 @@ def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: params = _flatten_dict(params) from mlflow.entities import Param + from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH - # Truncate parameter values to 250 characters. - # TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0 - params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()] + params_list = [Param(key=k, value=str(v)[:MAX_PARAM_VAL_LENGTH]) for k, v in params.items()] # Log in chunks of 100 parameters (the maximum allowed by MLflow). for idx in range(0, len(params_list), 100): diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index c7f9dbe1fe2c6..2908ebf6a09b8 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -317,16 +317,17 @@ def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path): @mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) def test_mlflow_logger_with_long_param_value(mlflow_mock, tmp_path): - """Test that long parameter values are truncated to 250 characters.""" + """Test that long parameter values are truncated using MLflow's MAX_PARAM_VAL_LENGTH.""" + from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH def _check_value_length(value, *args, **kwargs): - assert len(value) <= 250 + assert len(value) <= MAX_PARAM_VAL_LENGTH mlflow_mock.entities.Param.side_effect = _check_value_length logger = MLFlowLogger("test", save_dir=str(tmp_path)) - params = {"test": "test_param" * 50} + params = {"test": "test_param" * 1000} logger.log_hyperparams(params) # assert_called_once_with() won't properly check the parameter value.