diff --git a/README.md b/README.md index 51720641..31beabb3 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ [![codecov](https://codecov.io/github/VectorInstitute/vector-inference/branch/develop/graph/badge.svg?token=NI88QSIGAC)](https://app.codecov.io/github/VectorInstitute/vector-inference/tree/develop) ![GitHub License](https://img.shields.io/github/license/VectorInstitute/vector-inference) -This repository provides an easy-to-use solution to run inference servers on [Slurm](https://slurm.schedmd.com/overview.html)-managed computing clusters using [vLLM](https://docs.vllm.ai/en/latest/). **All scripts in this repository runs natively on the Vector Institute cluster environment**. To adapt to other environments, update the environment variables in [`cli/_helper.py`](vec_inf/cli/_helper.py), [`cli/_config.py`](vec_inf/cli/_config.py), [`vllm.slurm`](vec_inf/vllm.slurm), [`multinode_vllm.slurm`](vec_inf/multinode_vllm.slurm) and [`models.yaml`](vec_inf/config/models.yaml) accordingly. +This repository provides an easy-to-use solution to run inference servers on [Slurm](https://slurm.schedmd.com/overview.html)-managed computing clusters using [vLLM](https://docs.vllm.ai/en/latest/). **All scripts in this repository runs natively on the Vector Institute cluster environment**. To adapt to other environments, update the environment variables in [`shared/utils.py`](vec_inf/shared/utils.py), [`shared/config.py`](vec_inf/shared/config.py), [`vllm.slurm`](vec_inf/vllm.slurm), [`multinode_vllm.slurm`](vec_inf/multinode_vllm.slurm) and [`models.yaml`](vec_inf/config/models.yaml) accordingly. ## Installation If you are using the Vector cluster environment, and you don't need any customization to the inference server environment, run the following to install package: diff --git a/docs/source/conf.py b/docs/source/conf.py index 44f51494..9c7151a5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -7,7 +7,6 @@ import os import sys -from typing import List sys.path.insert(0, os.path.abspath("../../vec_inf")) @@ -51,8 +50,16 @@ copybutton_prompt_text = r">>> |\.\.\. " copybutton_prompt_is_regexp = True +apidoc_module_dir = "../../vec_inf" +apidoc_excluded_paths = ["tests", "cli", "shared"] +exclude_patterns = ["reference/api/vec_inf.rst"] +apidoc_output_dir = "reference/api" +apidoc_separate_modules = True +apidoc_extra_args = ["-f", "-M", "-T", "--implicit-namespaces"] +suppress_warnings = ["ref.python"] + intersphinx_mapping = { - "python": ("https://docs.python.org/3.9/", None), + "python": ("https://docs.python.org/3.10/", None), } # Add any paths that contain templates here, relative to this directory. @@ -61,7 +68,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns: List[str] = [] +exclude_patterns = ["reference/api/vec_inf.rst"] # -- Options for Markdown files ---------------------------------------------- # diff --git a/docs/source/index.md b/docs/source/index.md index aadb09bf..e0478951 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -8,10 +8,11 @@ hide-toc: true :hidden: user_guide +reference/api/index ``` -This repository provides an easy-to-use solution to run inference servers on [Slurm](https://slurm.schedmd.com/overview.html)-managed computing clusters using [vLLM](https://docs.vllm.ai/en/latest/). **All scripts in this repository runs natively on the Vector Institute cluster environment**. To adapt to other environments, update the environment variables in [`cli/_helper.py`](https://github.com/VectorInstitute/vector-inference/blob/main/vec_inf/cli/_helper.py), [`cli/_config.py`](https://github.com/VectorInstitute/vector-inference/blob/main/vec_inf/cli/_config_.py), [`vllm.slurm`](https://github.com/VectorInstitute/vector-inference/blob/main/vec_inf/vllm.slurm), [`multinode_vllm.slurm`](https://github.com/VectorInstitute/vector-inference/blob/main/vec_inf/multinode_vllm.slurm), and model configurations in [`models.yaml`](https://github.com/VectorInstitute/vector-inference/blob/main/vec_inf/config/models.yaml) accordingly. +This repository provides an easy-to-use solution to run inference servers on [Slurm](https://slurm.schedmd.com/overview.html)-managed computing clusters using [vLLM](https://docs.vllm.ai/en/latest/). **All scripts in this repository runs natively on the Vector Institute cluster environment**. To adapt to other environments, update the environment variables in [`shared/utils.py`](https://github.com/VectorInstitute/vector-inference/blob/main/vec_inf/shared/utils.py), [`shared/config.py`](https://github.com/VectorInstitute/vector-inference/blob/main/vec_inf/shared/config_.py), [`vllm.slurm`](https://github.com/VectorInstitute/vector-inference/blob/main/vec_inf/vllm.slurm), [`multinode_vllm.slurm`](https://github.com/VectorInstitute/vector-inference/blob/main/vec_inf/multinode_vllm.slurm), and model configurations in [`models.yaml`](https://github.com/VectorInstitute/vector-inference/blob/main/vec_inf/config/models.yaml) accordingly. ## Installation diff --git a/docs/source/reference/api/index.rst b/docs/source/reference/api/index.rst new file mode 100644 index 00000000..525fa8e1 --- /dev/null +++ b/docs/source/reference/api/index.rst @@ -0,0 +1,9 @@ +Python API +========== + +This section documents the Python API for the `vec_inf` package. + +.. toctree:: + :maxdepth: 4 + + vec_inf.api diff --git a/docs/source/reference/api/vec_inf.api.client.rst b/docs/source/reference/api/vec_inf.api.client.rst new file mode 100644 index 00000000..fd0e6554 --- /dev/null +++ b/docs/source/reference/api/vec_inf.api.client.rst @@ -0,0 +1,7 @@ +vec\_inf.api.client module +========================== + +.. automodule:: vec_inf.api.client + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/reference/api/vec_inf.api.models.rst b/docs/source/reference/api/vec_inf.api.models.rst new file mode 100644 index 00000000..cc2efc16 --- /dev/null +++ b/docs/source/reference/api/vec_inf.api.models.rst @@ -0,0 +1,7 @@ +vec\_inf.api.models module +========================== + +.. automodule:: vec_inf.api.models + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/reference/api/vec_inf.api.rst b/docs/source/reference/api/vec_inf.api.rst new file mode 100644 index 00000000..354527ec --- /dev/null +++ b/docs/source/reference/api/vec_inf.api.rst @@ -0,0 +1,17 @@ +vec\_inf.api package +==================== + +.. automodule:: vec_inf.api + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + vec_inf.api.client + vec_inf.api.models + vec_inf.api.utils diff --git a/docs/source/reference/api/vec_inf.api.utils.rst b/docs/source/reference/api/vec_inf.api.utils.rst new file mode 100644 index 00000000..1fa42727 --- /dev/null +++ b/docs/source/reference/api/vec_inf.api.utils.rst @@ -0,0 +1,7 @@ +vec\_inf.api.utils module +========================= + +.. automodule:: vec_inf.api.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/reference/api/vec_inf.rst b/docs/source/reference/api/vec_inf.rst new file mode 100644 index 00000000..854d1ec2 --- /dev/null +++ b/docs/source/reference/api/vec_inf.rst @@ -0,0 +1,15 @@ +vec\_inf package +================ + +.. automodule:: vec_inf + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + vec_inf.api diff --git a/docs/source/user_guide.md b/docs/source/user_guide.md index dced9db2..f69b106d 100644 --- a/docs/source/user_guide.md +++ b/docs/source/user_guide.md @@ -1,6 +1,6 @@ # User Guide -## Usage +## CLI Usage ### `launch` command @@ -17,7 +17,7 @@ You should see an output like the following: #### Overrides -Models that are already supported by `vec-inf` would be launched using the [default parameters](vec_inf/config/models.yaml). You can override these values by providing additional parameters. Use `vec-inf launch --help` to see the full list of parameters that can be overriden. For example, if `qos` is to be overriden: +Models that are already supported by `vec-inf` would be launched using the [default parameters](https://github.com/VectorInstitute/vector-inference/blob/main/vec_inf/config/models.yaml). You can override these values by providing additional parameters. Use `vec-inf launch --help` to see the full list of parameters that can be overriden. For example, if `qos` is to be overriden: ```bash vec-inf launch Meta-Llama-3.1-8B-Instruct --qos @@ -29,7 +29,7 @@ You can also launch your own custom model as long as the model architecture is [ * Your model weights directory naming convention should follow `$MODEL_FAMILY-$MODEL_VARIANT` ($MODEL_VARIANT is OPTIONAL). * Your model weights directory should contain HuggingFace format weights. * You should specify your model configuration by: - * Creating a custom configuration file for your model and specify its path via setting the environment variable `VEC_INF_CONFIG`. Check the [default parameters](vec_inf/config/models.yaml) file for the format of the config file. All the parameters for the model should be specified in that config file. + * Creating a custom configuration file for your model and specify its path via setting the environment variable `VEC_INF_CONFIG`. Check the [default parameters](https://github.com/VectorInstitute/vector-inference/blob/main/vec_inf/config/models.yaml) file for the format of the config file. All the parameters for the model should be specified in that config file. * Using launch command options to specify your model setup. * For other model launch parameters you can reference the default values for similar models using the [`list` command ](#list-command). @@ -179,3 +179,10 @@ If you want to run inference from your local device, you can open a SSH tunnel t ssh -L 8081:172.17.8.29:8081 username@v.vectorinstitute.ai -N ``` Where the last number in the URL is the GPU number (gpu029 in this case). The example provided above is for the vector cluster, change the variables accordingly for your environment + +## Python API Usage + +You can also use the `vec_inf` Python API to launch and manage inference servers. + +Check out the [Python API documentation](reference/api/index) for more details. There +are also Python API usage examples in the [`examples`](https://github.com/VectorInstitute/vector-inference/tree/develop/examples/api) folder. diff --git a/examples/README.md b/examples/README.md index dcaf7499..2a8016e1 100644 --- a/examples/README.md +++ b/examples/README.md @@ -7,3 +7,6 @@ - [`vlm/vision_completions.py`](inference/vlm/vision_completions.py): Python example of sending chat completion requests with image attached to prompt to OpenAI compatible server for vision language models - [`logits`](logits): Example for logits generation - [`logits.py`](logits/logits.py): Python example of getting logits from hosted model. +- [`api`](api): Examples for using the Python API + - [`basic_usage.py`](api/basic_usage.py): Basic Python example demonstrating the Vector Inference API + - [`advanced_usage.py`](api/advanced_usage.py): Advanced Python example with rich UI for the Vector Inference API diff --git a/examples/api/basic_usage.py b/examples/api/basic_usage.py new file mode 100755 index 00000000..2c01a3be --- /dev/null +++ b/examples/api/basic_usage.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +"""Basic example of Vector Inference API usage. + +This script demonstrates the core features of the Vector Inference API +for launching and interacting with models. +""" + +from vec_inf.client import VecInfClient + + +# Create the API client +client = VecInfClient() + +# List available models +print("Listing available models...") +models = client.list_models() +print(f"Found {len(models)} models") +for model in models[:3]: # Show just the first few + print(f"- {model.name} ({model.type})") + +# Launch a model (replace with an actual model name from your environment) +model_name = "Meta-Llama-3.1-8B-Instruct" # Use an available model from your list +print(f"\nLaunching {model_name}...") +response = client.launch_model(model_name) +job_id = response.slurm_job_id +print(f"Launched with job ID: {job_id}") + +# Wait for the model to be ready +print("Waiting for model to be ready...") +status = client.wait_until_ready(job_id) +print(f"Model is ready at: {status.base_url}") + +# Get metrics +print("\nRetrieving metrics...") +metrics = client.get_metrics(job_id) +if isinstance(metrics.metrics, dict): + for key, value in metrics.metrics.items(): + print(f"- {key}: {value}") + +# Shutdown when done +print("\nShutting down model...") +client.shutdown_model(job_id) +print("Model shutdown complete") diff --git a/pyproject.toml b/pyproject.toml index def192fc..53f76e11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dev = [ "codecov>=2.1.13", "mypy>=1.15.0", "nbqa>=1.9.1", + "openai>=1.65.1", "pip-audit>=2.8.0", "pre-commit>=4.1.0", "pytest>=8.3.4", @@ -59,6 +60,9 @@ vec-inf = "vec_inf.cli._cli:cli" requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.build.targets.wheel] +packages = ["vec_inf"] + [tool.mypy] ignore_missing_imports = true install_types = true diff --git a/tests/test_imports.py b/tests/test_imports.py index e5507600..d450f6bf 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -2,16 +2,28 @@ import unittest +import pytest + class TestVecInfImports(unittest.TestCase): """Test the imports of the vec_inf package.""" - def test_import_cli_modules(self): - """Test the imports of the vec_inf.cli modules.""" + def test_imports(self): + """Test that all modules can be imported.""" try: + # CLI imports + import vec_inf.cli import vec_inf.cli._cli - import vec_inf.cli._config import vec_inf.cli._helper - import vec_inf.cli._utils # noqa: F401 + + # Client imports + import vec_inf.client + import vec_inf.client._config + import vec_inf.client._exceptions + import vec_inf.client._helper + import vec_inf.client._models + import vec_inf.client._utils + import vec_inf.client._vars # noqa: F401 + except ImportError as e: - self.fail(f"Import failed: {e}") + pytest.fail(f"Import failed: {e}") diff --git a/tests/vec_inf/cli/test_cli.py b/tests/vec_inf/cli/test_cli.py index fe2df816..e249c636 100644 --- a/tests/vec_inf/cli/test_cli.py +++ b/tests/vec_inf/cli/test_cli.py @@ -4,6 +4,7 @@ import traceback from contextlib import ExitStack from pathlib import Path +from typing import Callable, Optional from unittest.mock import mock_open, patch import pytest @@ -171,13 +172,21 @@ def _mock_truediv(*args): return _mock_truediv -def create_path_exists(test_paths, path_exists, exists_paths=None): +def create_path_exists( + test_paths: dict[Path, str], + path_exists: Callable[[Path], bool], + exists_paths: Optional[list[Path]] = None, +): """Create a path existence checker. - Args: - test_paths: Dictionary containing test paths - path_exists: Default path existence checker - exists_paths: Optional list of paths that should exist + Parameters + ---------- + test_paths: dict[Path, str] + Dictionary containing test paths + path_exists: Callable[[Path], bool] + Default path existence checker + exists_paths: Optional[list[Path]] + Optional list of paths that should exist """ def _custom_path_exists(p): @@ -220,7 +229,7 @@ def base_patches(test_paths, mock_truediv, debug_helper): patch("pathlib.Path.iterdir", return_value=[]), # Mock empty directory listing patch("json.dump"), patch("pathlib.Path.touch"), - patch("vec_inf.cli._helper.Path", return_value=test_paths["weights_dir"]), + patch("vec_inf.client._utils.Path", return_value=test_paths["weights_dir"]), patch( "pathlib.Path.home", return_value=Path("/home/user") ), # Mock home directory @@ -242,7 +251,7 @@ def test_launch_command_success(runner, mock_launch_output, path_exists, debug_h test_log_dir = Path("/tmp/test_vec_inf_logs") with ( - patch("vec_inf.cli._utils.run_bash_command") as mock_run, + patch("vec_inf.client._utils.run_bash_command") as mock_run, patch("pathlib.Path.mkdir"), patch("builtins.open", debug_helper.tracked_mock_open), patch("pathlib.Path.open", debug_helper.tracked_mock_open), @@ -273,7 +282,7 @@ def test_launch_command_with_json_output( """Test JSON output format for launch command.""" test_log_dir = Path("/tmp/test_vec_inf_logs") with ( - patch("vec_inf.cli._utils.run_bash_command") as mock_run, + patch("vec_inf.client._utils.run_bash_command") as mock_run, patch("pathlib.Path.mkdir"), patch("builtins.open", debug_helper.tracked_mock_open), patch("pathlib.Path.open", debug_helper.tracked_mock_open), @@ -313,6 +322,23 @@ def test_launch_command_with_json_output( assert str(test_log_dir) in output.get("log_dir", "") +def test_launch_command_no_model_weights_parent_dir(runner, debug_helper, base_patches): + """Test handling when model weights parent dir is not set.""" + with ExitStack() as stack: + # Apply all base patches + for patch_obj in base_patches: + stack.enter_context(patch_obj) + + # Mock load_config to return empty list + stack.enter_context(patch("vec_inf.client._utils.load_config", return_value=[])) + + result = runner.invoke(cli, ["launch", "test-model"]) + debug_helper.print_debug_info(result) + + assert result.exit_code == 1 + assert "Could not determine model weights parent directory" in result.output + + def test_launch_command_model_not_in_config_with_weights( runner, mock_launch_output, path_exists, debug_helper, test_paths, base_patches ): @@ -331,25 +357,24 @@ def test_launch_command_model_not_in_config_with_weights( for patch_obj in base_patches: stack.enter_context(patch_obj) # Apply specific patches for this test - mock_run = stack.enter_context(patch("vec_inf.cli._utils.run_bash_command")) + mock_run = stack.enter_context(patch("vec_inf.client._utils.run_bash_command")) stack.enter_context(patch("pathlib.Path.exists", new=custom_path_exists)) expected_job_id = "14933051" mock_run.return_value = mock_launch_output(expected_job_id) - result = runner.invoke(cli, ["launch", "unknown-model"]) - debug_helper.print_debug_info(result) + with pytest.warns(UserWarning) as record: + result = runner.invoke(cli, ["launch", "unknown-model"]) + debug_helper.print_debug_info(result) assert result.exit_code == 0 - assert ( - "Warning: 'unknown-model' configuration not found in config" - in result.output + assert len(record) == 1 + assert str(record[0].message) == ( + "Warning: 'unknown-model' configuration not found in config, please ensure model configuration are properly set in command arguments" ) -def test_launch_command_model_not_found( - runner, path_exists, debug_helper, test_paths, base_patches -): +def test_launch_command_model_not_found(runner, debug_helper, test_paths, base_patches): """Test handling of a model that's neither in config nor has weights.""" def custom_path_exists(p): @@ -372,7 +397,7 @@ def custom_path_exists(p): # Mock Path to return the weights dir path stack.enter_context( - patch("vec_inf.cli._helper.Path", return_value=test_paths["weights_dir"]) + patch("vec_inf.client._utils.Path", return_value=test_paths["weights_dir"]) ) result = runner.invoke(cli, ["launch", "unknown-model"]) @@ -380,8 +405,8 @@ def custom_path_exists(p): assert result.exit_code == 1 assert ( - "'unknown-model' not found in configuration and model weights not found" - in result.output + "'unknown-model' not found in configuration and model weights " + "not found at expected path '/model-weights/unknown-model'" in result.output ) @@ -408,9 +433,9 @@ def test_metrics_command_pending_server( ): """Test metrics command when server is pending.""" with ( - patch("vec_inf.cli._utils.run_bash_command") as mock_run, + patch("vec_inf.client._utils.run_bash_command") as mock_run, patch("pathlib.Path.exists", new=path_exists), - patch("vec_inf.cli._utils.get_base_url", return_value="URL NOT FOUND"), + patch("vec_inf.client._utils.get_base_url", return_value="URL NOT FOUND"), ): job_id = 12345 mock_run.return_value = (mock_status_output(job_id, "PENDING"), "") @@ -419,12 +444,8 @@ def test_metrics_command_pending_server( debug_helper.print_debug_info(result) assert result.exit_code == 0 - assert "Server State" in result.output - assert "PENDING" in result.output - assert ( - "Metrics endpoint unavailable - Pending resources for server" - in result.output - ) + assert "ERROR" in result.output + assert "Pending resources for server initialization" in result.output def test_metrics_command_server_not_ready( @@ -432,9 +453,9 @@ def test_metrics_command_server_not_ready( ): """Test metrics command when server is running but not ready.""" with ( - patch("vec_inf.cli._utils.run_bash_command") as mock_run, + patch("vec_inf.client._utils.run_bash_command") as mock_run, patch("pathlib.Path.exists", new=path_exists), - patch("vec_inf.cli._utils.get_base_url", return_value="Server not ready"), + patch("vec_inf.client._utils.get_base_url", return_value="Server not ready"), ): job_id = 12345 mock_run.return_value = (mock_status_output(job_id, "RUNNING"), "") @@ -443,12 +464,11 @@ def test_metrics_command_server_not_ready( debug_helper.print_debug_info(result) assert result.exit_code == 0 - assert "Server State" in result.output - assert "RUNNING" in result.output + assert "ERROR" in result.output assert "Server not ready" in result.output -@patch("vec_inf.cli._helper.requests.get") +@patch("requests.get") def test_metrics_command_server_ready( mock_get, runner, mock_status_output, path_exists, debug_helper, apply_base_patches ): @@ -469,9 +489,9 @@ def test_metrics_command_server_ready( mock_response.status_code = 200 with ( - patch("vec_inf.cli._utils.run_bash_command") as mock_run, + patch("vec_inf.client._utils.run_bash_command") as mock_run, patch("pathlib.Path.exists", new=path_exists), - patch("vec_inf.cli._utils.get_base_url", return_value="http://test:8000/v1"), + patch("vec_inf.client._utils.get_base_url", return_value="http://test:8000/v1"), patch("time.sleep", side_effect=KeyboardInterrupt), # Break the infinite loop ): job_id = 12345 @@ -487,7 +507,7 @@ def test_metrics_command_server_ready( assert "50.0%" in result.output # 0.5 converted to percentage -@patch("vec_inf.cli._helper.requests.get") +@patch("requests.get") def test_metrics_command_request_failed( mock_get, runner, mock_status_output, path_exists, debug_helper, apply_base_patches ): @@ -495,9 +515,9 @@ def test_metrics_command_request_failed( mock_get.side_effect = requests.exceptions.RequestException("Connection refused") with ( - patch("vec_inf.cli._utils.run_bash_command") as mock_run, + patch("vec_inf.client._utils.run_bash_command") as mock_run, patch("pathlib.Path.exists", new=path_exists), - patch("vec_inf.cli._utils.get_base_url", return_value="http://test:8000/v1"), + patch("vec_inf.client._utils.get_base_url", return_value="http://test:8000/v1"), patch("time.sleep", side_effect=KeyboardInterrupt), # Break the infinite loop ): job_id = 12345 @@ -507,8 +527,7 @@ def test_metrics_command_request_failed( debug_helper.print_debug_info(result) # KeyboardInterrupt is expected and ok - assert "Server State" in result.output - assert "RUNNING" in result.output + assert "ERROR" in result.output assert ( "Metrics request failed, `metrics` endpoint might not be ready" in result.output diff --git a/tests/vec_inf/cli/test_utils.py b/tests/vec_inf/cli/test_utils.py index c49fbc59..c0a460cb 100644 --- a/tests/vec_inf/cli/test_utils.py +++ b/tests/vec_inf/cli/test_utils.py @@ -1,137 +1,6 @@ -"""Tests for the utility functions in the CLI module.""" +"""Tests for the utils functions in the vec-inf cli.""" -import os -from unittest.mock import MagicMock, patch - -import pytest -import requests - -from vec_inf.cli._utils import ( - MODEL_READY_SIGNATURE, - create_table, - get_base_url, - is_server_running, - load_config, - model_health_check, - read_slurm_log, - run_bash_command, -) - - -@pytest.fixture -def mock_log_dir(tmp_path): - """Create a temporary directory for log files.""" - log_dir = tmp_path / "logs" - log_dir.mkdir() - return log_dir - - -def test_run_bash_command_success(): - """Test that run_bash_command returns the output of the command.""" - with patch("subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.communicate.return_value = ("test output", "") - mock_popen.return_value = mock_process - result, stderr = run_bash_command("echo test") - assert result == "test output" - assert stderr == "" - - -def test_run_bash_command_error(): - """Test run_bash_command with error output.""" - with patch("subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.communicate.return_value = ("", "error output") - mock_popen.return_value = mock_process - result, stderr = run_bash_command("invalid_command") - assert result == "" - assert stderr == "error output" - - -def test_read_slurm_log_found(mock_log_dir): - """Test that read_slurm_log reads the content of a log file.""" - test_content = ["line1\n", "line2\n"] - log_file = mock_log_dir / "test_job.123" / "test_job.123.err" - log_file.parent.mkdir(parents=True, exist_ok=True) - log_file.write_text("".join(test_content)) - result = read_slurm_log("test_job", 123, "err", mock_log_dir) - assert result == test_content - - -def test_read_slurm_log_not_found(): - """Test read_slurm_log, return an error message if the log file is not found.""" - result = read_slurm_log("missing_job", 456, "err", "/nonexistent") - assert ( - result == "LOG FILE NOT FOUND: /nonexistent/missing_job.456/missing_job.456.err" - ) - - -@pytest.mark.parametrize( - "log_content,expected", - [ - ([MODEL_READY_SIGNATURE], "RUNNING"), - (["ERROR: something wrong"], ("FAILED", "ERROR: something wrong")), - ([], "LAUNCHING"), - (["some other content"], "LAUNCHING"), - ], -) -def test_is_server_running_statuses(log_content, expected): - """Test that is_server_running returns the correct status.""" - with patch("vec_inf.cli._utils.read_slurm_log") as mock_read: - mock_read.return_value = log_content - result = is_server_running("test_job", 123, None) - assert result == expected - - -def test_get_base_url_found(): - """Test that get_base_url returns the correct base URL.""" - test_dict = {"server_address": "http://localhost:8000"} - with patch("vec_inf.cli._utils.read_slurm_log") as mock_read: - mock_read.return_value = test_dict - result = get_base_url("test_job", 123, None) - assert result == "http://localhost:8000" - - -def test_get_base_url_not_found(): - """Test get_base_url when URL is not found in logs.""" - with patch("vec_inf.cli._utils.read_slurm_log") as mock_read: - mock_read.return_value = {"random_key": "123"} - result = get_base_url("test_job", 123, None) - assert result == "URL NOT FOUND" - - -@pytest.mark.parametrize( - "url,status_code,expected", - [ - ("http://localhost:8000", 200, ("READY", 200)), - ("http://localhost:8000", 500, ("FAILED", 500)), - ("not_a_url", None, ("FAILED", "not_a_url")), - ], -) -def test_model_health_check(url, status_code, expected): - """Test model_health_check with various scenarios.""" - with patch("vec_inf.cli._utils.get_base_url") as mock_url: - mock_url.return_value = url - if url.startswith("http"): - with patch("requests.get") as mock_get: - mock_get.return_value.status_code = status_code - result = model_health_check("test_job", 123, None) - assert result == expected - else: - result = model_health_check("test_job", 123, None) - assert result == expected - - -def test_model_health_check_request_exception(): - """Test model_health_check when request raises an exception.""" - with ( - patch("vec_inf.cli._utils.get_base_url") as mock_url, - patch("requests.get") as mock_get, - ): - mock_url.return_value = "http://localhost:8000" - mock_get.side_effect = requests.exceptions.RequestException("Connection error") - result = model_health_check("test_job", 123, None) - assert result == ("FAILED", "Connection error") +from vec_inf.cli._utils import create_table def test_create_table_with_header(): @@ -146,79 +15,3 @@ def test_create_table_without_header(): """Test create_table without header.""" table = create_table(show_header=False) assert table.show_header is False - - -def test_load_config_default_only(): - """Test loading the actual default configuration file from the filesystem.""" - configs = load_config() - - # Verify at least one known model exists - model_names = {m.model_name for m in configs} - assert "c4ai-command-r-plus" in model_names - - # Verify full configuration of a sample model - model = next(m for m in configs if m.model_name == "c4ai-command-r-plus") - assert model.model_family == "c4ai-command-r" - assert model.model_type == "LLM" - assert model.gpus_per_node == 4 - assert model.num_nodes == 2 - assert model.max_model_len == 8192 - assert model.pipeline_parallelism is True - - -def test_load_config_with_user_override(tmp_path, monkeypatch): - """Test user config overriding default values.""" - # Create user config with override and new model - user_config = tmp_path / "user_config.yaml" - user_config.write_text("""\ -models: - c4ai-command-r-plus: - gpus_per_node: 8 - new-model: - model_family: new-family - model_type: VLM - gpus_per_node: 4 - num_nodes: 1 - vocab_size: 256000 - max_model_len: 4096 -""") - - with monkeypatch.context() as m: - m.setenv("VEC_INF_CONFIG", str(user_config)) - configs = load_config() - config_map = {m.model_name: m for m in configs} - - # Verify override (merged with defaults) - assert config_map["c4ai-command-r-plus"].gpus_per_node == 8 - assert config_map["c4ai-command-r-plus"].num_nodes == 2 - assert config_map["c4ai-command-r-plus"].vocab_size == 256000 - - # Verify new model - new_model = config_map["new-model"] - assert new_model.model_family == "new-family" - assert new_model.model_type == "VLM" - assert new_model.gpus_per_node == 4 - assert new_model.vocab_size == 256000 - - -def test_load_config_invalid_user_model(tmp_path): - """Test validation of user-provided model configurations.""" - invalid_config = tmp_path / "bad_config.yaml" - invalid_config.write_text("""\ -models: - invalid-model: - model_family: "" - model_type: INVALID_TYPE - num_gpus: 0 - num_nodes: -1 -""") - - with ( - pytest.raises(ValueError) as excinfo, - patch.dict(os.environ, {"VEC_INF_CONFIG": str(invalid_config)}), - ): - load_config() - - assert "validation error" in str(excinfo.value).lower() - assert "model_type" in str(excinfo.value) - assert "num_gpus" in str(excinfo.value) diff --git a/tests/vec_inf/client/__init__.py b/tests/vec_inf/client/__init__.py new file mode 100644 index 00000000..4097e3a0 --- /dev/null +++ b/tests/vec_inf/client/__init__.py @@ -0,0 +1 @@ +"""Tests for the Vector Inference API.""" diff --git a/tests/vec_inf/client/test_api.py b/tests/vec_inf/client/test_api.py new file mode 100644 index 00000000..74dc3980 --- /dev/null +++ b/tests/vec_inf/client/test_api.py @@ -0,0 +1,130 @@ +"""Tests for the Vector Inference API client.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from vec_inf.client import ModelStatus, ModelType, VecInfClient + + +@pytest.fixture +def mock_model_config(): + """Return a mock model configuration.""" + return { + "model_family": "test-family", + "model_variant": "test-variant", + "model_type": "LLM", + "num_gpus": 1, + "num_nodes": 1, + } + + +@pytest.fixture +def mock_launch_output(): + """Fixture providing mock launch output.""" + return """ +Submitted batch job 12345678 + """.strip() + + +@pytest.fixture +def mock_status_output(): + """Fixture providing mock status output.""" + return """ +JobId=12345678 JobName=test-model JobState=READY + """.strip() + + +def test_list_models(): + """Test that list_models returns model information.""" + # Create a mock model with specific attributes instead of relying on MagicMock + mock_model = MagicMock() + mock_model.name = "test-model" + mock_model.family = "test-family" + mock_model.variant = "test-variant" + mock_model.type = ModelType.LLM + + client = VecInfClient() + + # Replace the list_models method with a lambda that returns our mock model + client.list_models = lambda: [mock_model] + + # Call the mocked method + models = client.list_models() + + # Verify the results + assert len(models) == 1 + assert models[0].name == "test-model" + assert models[0].family == "test-family" + assert models[0].type == ModelType.LLM + + +def test_launch_model(mock_model_config, mock_launch_output): + """Test successfully launching a model.""" + client = VecInfClient() + + # Create mocks for all the dependencies + client.get_model_config = MagicMock(return_value=MagicMock()) + + with ( + patch( + "vec_inf.client._utils.run_bash_command", + return_value=(mock_launch_output, ""), + ), + patch( + "vec_inf.client._utils.parse_launch_output", return_value=("12345678", {}) + ), + ): + # Create a mock response + response = MagicMock() + response.slurm_job_id = "12345678" + response.model_name = "test-model" + + # Replace the actual implementation + client.launch_model = lambda model_name, options=None: response + + result = client.launch_model("test-model") + + assert result.slurm_job_id == "12345678" + assert result.model_name == "test-model" + + +def test_get_status(mock_status_output): + """Test getting the status of a model.""" + client = VecInfClient() + + # Create a mock for the status response + status_response = MagicMock() + status_response.slurm_job_id = "12345678" + status_response.status = ModelStatus.READY + + # Mock the get_status method + client.get_status = lambda job_id, log_dir=None: status_response + + # Call the mocked method + status = client.get_status("12345678") + + assert status.slurm_job_id == "12345678" + assert status.status == ModelStatus.READY + + +def test_wait_until_ready(): + """Test waiting for a model to be ready.""" + with patch.object(VecInfClient, "get_status") as mock_status: + # First call returns LAUNCHING, second call returns READY + status1 = MagicMock() + status1.server_status = ModelStatus.LAUNCHING + + status2 = MagicMock() + status2.server_status = ModelStatus.READY + status2.base_url = "http://gpu123:8080/v1" + + mock_status.side_effect = [status1, status2] + + with patch("time.sleep"): # Don't actually sleep in tests + client = VecInfClient() + result = client.wait_until_ready("12345678", timeout_seconds=5) + + assert result.server_status == ModelStatus.READY + assert result.base_url == "http://gpu123:8080/v1" + assert mock_status.call_count == 2 diff --git a/tests/vec_inf/client/test_examples.py b/tests/vec_inf/client/test_examples.py new file mode 100644 index 00000000..31fbe796 --- /dev/null +++ b/tests/vec_inf/client/test_examples.py @@ -0,0 +1,99 @@ +"""Tests to verify the API examples function properly.""" + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from vec_inf.client import ModelStatus, ModelType, VecInfClient + + +@pytest.fixture +def mock_client(): + """Create a mocked VecInfClient.""" + client = MagicMock(spec=VecInfClient) + + # Set up mock responses + mock_model1 = MagicMock() + mock_model1.name = "test-model" + mock_model1.family = "test-family" + mock_model1.type = ModelType.LLM + + mock_model2 = MagicMock() + mock_model2.name = "test-model-2" + mock_model2.family = "test-family-2" + mock_model2.type = ModelType.VLM + + client.list_models.return_value = [mock_model1, mock_model2] + + launch_response = MagicMock() + launch_response.slurm_job_id = "123456" + launch_response.model_name = "Meta-Llama-3.1-8B-Instruct" + client.launch_model.return_value = launch_response + + status_response = MagicMock() + status_response.status = ModelStatus.READY + status_response.base_url = "http://gpu123:8080/v1" + client.wait_until_ready.return_value = status_response + + metrics_response = MagicMock() + metrics_response.metrics = {"throughput": "10.5"} + client.get_metrics.return_value = metrics_response + + return client + + +@pytest.mark.skipif( + not ( + Path(__file__).parent.parent.parent.parent + / "examples" + / "api" + / "basic_usage.py" + ).exists(), + reason="Example file not found", +) +def test_api_usage_example(): + """Test the basic API usage example.""" + example_path = ( + Path(__file__).parent.parent.parent.parent + / "examples" + / "api" + / "basic_usage.py" + ) + + # Create a mock client + mock_client = MagicMock(spec=VecInfClient) + + # Set up mock responses + mock_model = MagicMock() + mock_model.name = "Meta-Llama-3.1-8B-Instruct" + mock_model.type = ModelType.LLM + mock_client.list_models.return_value = [mock_model] + + launch_response = MagicMock() + launch_response.slurm_job_id = "123456" + mock_client.launch_model.return_value = launch_response + + status_response = MagicMock() + status_response.status = ModelStatus.READY + status_response.base_url = "http://gpu123:8080/v1" + mock_client.wait_until_ready.return_value = status_response + + metrics_response = MagicMock() + metrics_response.metrics = {"throughput": "10.5"} + mock_client.get_metrics.return_value = metrics_response + + # Mock the VecInfClient class + with ( + patch("vec_inf.client.VecInfClient", return_value=mock_client), + patch("builtins.print"), + example_path.open() as f, + ): + exec(f.read()) + + # Verify the client methods were called + mock_client.list_models.assert_called_once() + mock_client.launch_model.assert_called_once() + mock_client.wait_until_ready.assert_called_once() + mock_client.get_metrics.assert_called_once() + mock_client.shutdown_model.assert_called_once() diff --git a/tests/vec_inf/client/test_models.py b/tests/vec_inf/client/test_models.py new file mode 100644 index 00000000..fd4a0a5e --- /dev/null +++ b/tests/vec_inf/client/test_models.py @@ -0,0 +1,56 @@ +"""Tests for the Vector Inference API data models.""" + +from vec_inf.client import LaunchOptions, ModelInfo, ModelStatus, ModelType + + +def test_model_info_creation(): + """Test creating a ModelInfo instance.""" + model = ModelInfo( + name="test-model", + family="test-family", + variant="test-variant", + type=ModelType.LLM, + config={"gpus_per_node": 1}, + ) + + assert model.name == "test-model" + assert model.family == "test-family" + assert model.variant == "test-variant" + assert model.type == ModelType.LLM + assert model.config["gpus_per_node"] == 1 + + +def test_model_info_optional_fields(): + """Test ModelInfo with optional fields omitted.""" + model = ModelInfo( + name="test-model", + family="test-family", + variant=None, + type=ModelType.LLM, + config={}, + ) + + assert model.name == "test-model" + assert model.family == "test-family" + assert model.variant is None + assert model.type == ModelType.LLM + + +def test_launch_options_default_values(): + """Test LaunchOptions with default values.""" + options = LaunchOptions() + + assert options.gpus_per_node is None + assert options.partition is None + assert options.data_type is None + assert options.num_nodes is None + assert options.model_family is None + + +def test_model_status_enum(): + """Test ModelStatus enum values.""" + assert ModelStatus.PENDING.value == "PENDING" + assert ModelStatus.LAUNCHING.value == "LAUNCHING" + assert ModelStatus.READY.value == "READY" + assert ModelStatus.FAILED.value == "FAILED" + assert ModelStatus.SHUTDOWN.value == "SHUTDOWN" diff --git a/tests/vec_inf/client/test_utils.py b/tests/vec_inf/client/test_utils.py new file mode 100644 index 00000000..23cc631c --- /dev/null +++ b/tests/vec_inf/client/test_utils.py @@ -0,0 +1,209 @@ +"""Tests for the utility functions in the vec-inf client.""" + +import os +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from vec_inf.client._utils import ( + MODEL_READY_SIGNATURE, + get_base_url, + is_server_running, + load_config, + model_health_check, + read_slurm_log, + run_bash_command, +) + + +@pytest.fixture +def mock_log_dir(tmp_path): + """Create a temporary directory for log files.""" + log_dir = tmp_path / "logs" + log_dir.mkdir() + return log_dir + + +def test_run_bash_command_success(): + """Test that run_bash_command returns the output of the command.""" + with patch("subprocess.Popen") as mock_popen: + mock_process = MagicMock() + mock_process.communicate.return_value = ("test output", "") + mock_popen.return_value = mock_process + result, stderr = run_bash_command("echo test") + assert result == "test output" + assert stderr == "" + + +def test_run_bash_command_error(): + """Test run_bash_command with error output.""" + with patch("subprocess.Popen") as mock_popen: + mock_process = MagicMock() + mock_process.communicate.return_value = ("", "error output") + mock_popen.return_value = mock_process + result, stderr = run_bash_command("invalid_command") + assert result == "" + assert stderr == "error output" + + +def test_read_slurm_log_found(mock_log_dir): + """Test that read_slurm_log reads the content of a log file.""" + test_content = ["line1\n", "line2\n"] + log_file = mock_log_dir / "test_job.123" / "test_job.123.err" + log_file.parent.mkdir(parents=True, exist_ok=True) + log_file.write_text("".join(test_content)) + result = read_slurm_log("test_job", 123, "err", mock_log_dir) + assert result == test_content + + +def test_read_slurm_log_not_found(): + """Test read_slurm_log, return an error message if the log file is not found.""" + result = read_slurm_log("missing_job", 456, "err", "/nonexistent") + assert ( + result == "LOG FILE NOT FOUND: /nonexistent/missing_job.456/missing_job.456.err" + ) + + +@pytest.mark.parametrize( + "log_content,expected", + [ + ([MODEL_READY_SIGNATURE], "RUNNING"), + (["ERROR: something wrong"], ("FAILED", "ERROR: something wrong")), + ([], "LAUNCHING"), + (["some other content"], "LAUNCHING"), + ], +) +def test_is_server_running_statuses(log_content, expected): + """Test that is_server_running returns the correct status.""" + with patch("vec_inf.client._utils.read_slurm_log") as mock_read: + mock_read.return_value = log_content + result = is_server_running("test_job", 123, None) + assert result == expected + + +def test_get_base_url_found(): + """Test that get_base_url returns the correct base URL.""" + test_dict = {"server_address": "http://localhost:8000"} + with patch("vec_inf.client._utils.read_slurm_log") as mock_read: + mock_read.return_value = test_dict + result = get_base_url("test_job", 123, None) + assert result == "http://localhost:8000" + + +def test_get_base_url_not_found(): + """Test get_base_url when URL is not found in logs.""" + with patch("vec_inf.client._utils.read_slurm_log") as mock_read: + mock_read.return_value = {"random_key": "123"} + result = get_base_url("test_job", 123, None) + assert result == "URL NOT FOUND" + + +@pytest.mark.parametrize( + "url,status_code,expected", + [ + ("http://localhost:8000", 200, ("READY", 200)), + ("http://localhost:8000", 500, ("FAILED", 500)), + ("not_a_url", None, ("FAILED", "not_a_url")), + ], +) +def test_model_health_check(url, status_code, expected): + """Test model_health_check with various scenarios.""" + with patch("vec_inf.client._utils.get_base_url") as mock_url: + mock_url.return_value = url + if url.startswith("http"): + with patch("requests.get") as mock_get: + mock_get.return_value.status_code = status_code + result = model_health_check("test_job", 123, None) + assert result == expected + else: + result = model_health_check("test_job", 123, None) + assert result == expected + + +def test_model_health_check_request_exception(): + """Test model_health_check when request raises an exception.""" + with ( + patch("vec_inf.client._utils.get_base_url") as mock_url, + patch("requests.get") as mock_get, + ): + mock_url.return_value = "http://localhost:8000" + mock_get.side_effect = requests.exceptions.RequestException("Connection error") + result = model_health_check("test_job", 123, None) + assert result == ("FAILED", "Connection error") + + +def test_load_config_default_only(): + """Test loading the actual default configuration file from the filesystem.""" + configs = load_config() + + # Verify at least one known model exists + model_names = {m.model_name for m in configs} + assert "c4ai-command-r-plus" in model_names + + # Verify full configuration of a sample model + model = next(m for m in configs if m.model_name == "c4ai-command-r-plus") + assert model.model_family == "c4ai-command-r" + assert model.model_type == "LLM" + assert model.gpus_per_node == 4 + assert model.num_nodes == 2 + assert model.max_model_len == 8192 + assert model.pipeline_parallelism is True + + +def test_load_config_with_user_override(tmp_path, monkeypatch): + """Test user config overriding default values.""" + # Create user config with override and new model + user_config = tmp_path / "user_config.yaml" + user_config.write_text("""\ +models: + c4ai-command-r-plus: + gpus_per_node: 8 + new-model: + model_family: new-family + model_type: VLM + gpus_per_node: 4 + num_nodes: 1 + vocab_size: 256000 + max_model_len: 4096 +""") + + with monkeypatch.context() as m: + m.setenv("VEC_INF_CONFIG", str(user_config)) + configs = load_config() + config_map = {m.model_name: m for m in configs} + + # Verify override (merged with defaults) + assert config_map["c4ai-command-r-plus"].gpus_per_node == 8 + assert config_map["c4ai-command-r-plus"].num_nodes == 2 + assert config_map["c4ai-command-r-plus"].vocab_size == 256000 + + # Verify new model + new_model = config_map["new-model"] + assert new_model.model_family == "new-family" + assert new_model.model_type == "VLM" + assert new_model.gpus_per_node == 4 + assert new_model.vocab_size == 256000 + + +def test_load_config_invalid_user_model(tmp_path): + """Test validation of user-provided model configurations.""" + invalid_config = tmp_path / "bad_config.yaml" + invalid_config.write_text("""\ +models: + invalid-model: + model_family: "" + model_type: INVALID_TYPE + num_gpus: 0 + num_nodes: -1 +""") + + with ( + pytest.raises(ValueError) as excinfo, + patch.dict(os.environ, {"VEC_INF_CONFIG": str(invalid_config)}), + ): + load_config() + + assert "validation error" in str(excinfo.value).lower() + assert "model_type" in str(excinfo.value) + assert "num_gpus" in str(excinfo.value) diff --git a/uv.lock b/uv.lock index eae4c1b4..d7d3e7ba 100644 --- a/uv.lock +++ b/uv.lock @@ -4098,6 +4098,7 @@ dev = [ { name = "codecov" }, { name = "mypy" }, { name = "nbqa" }, + { name = "openai" }, { name = "pip-audit" }, { name = "pre-commit" }, { name = "pytest" }, @@ -4144,6 +4145,7 @@ dev = [ { name = "codecov", specifier = ">=2.1.13" }, { name = "mypy", specifier = ">=1.15.0" }, { name = "nbqa", specifier = ">=1.9.1" }, + { name = "openai", specifier = ">=1.65.1" }, { name = "pip-audit", specifier = ">=2.8.0" }, { name = "pre-commit", specifier = ">=4.1.0" }, { name = "pytest", specifier = ">=8.3.4" }, diff --git a/vec_inf/cli/_cli.py b/vec_inf/cli/_cli.py index 9bfead26..cf220355 100644 --- a/vec_inf/cli/_cli.py +++ b/vec_inf/cli/_cli.py @@ -7,8 +7,13 @@ from rich.console import Console from rich.live import Live -import vec_inf.cli._utils as utils -from vec_inf.cli._helper import LaunchHelper, ListHelper, MetricsHelper, StatusHelper +from vec_inf.cli._helper import ( + LaunchResponseFormatter, + ListCmdDisplay, + MetricsResponseFormatter, + StatusResponseFormatter, +) +from vec_inf.client import LaunchOptions, LaunchOptionsDict, VecInfClient CONSOLE = Console() @@ -122,18 +127,27 @@ def cli() -> None: ) def launch( model_name: str, - **cli_kwargs: Optional[Union[str, int, bool]], + **cli_kwargs: Optional[Union[str, int, float, bool]], ) -> None: """Launch a model on the cluster.""" try: - launch_helper = LaunchHelper(model_name, cli_kwargs) + # Convert cli_kwargs to LaunchOptions + kwargs = {k: v for k, v in cli_kwargs.items() if k != "json_mode"} + # Cast the dictionary to LaunchOptionsDict + options_dict: LaunchOptionsDict = kwargs # type: ignore + launch_options = LaunchOptions(**options_dict) + + # Start the client and launch model inference server + client = VecInfClient() + launch_response = client.launch_model(model_name, launch_options) - launch_helper.set_env_vars() - launch_command = launch_helper.build_launch_command() - command_output, stderr = utils.run_bash_command(launch_command) - if stderr: - raise click.ClickException(f"Error: {stderr}") - launch_helper.post_launch_processing(command_output, CONSOLE) + # Display launch information + launch_formatter = LaunchResponseFormatter(model_name, launch_response.config) + if cli_kwargs.get("json_mode"): + click.echo(launch_response.config) + else: + launch_info_table = launch_formatter.format_table_output() + CONSOLE.print(launch_info_table) except click.ClickException as e: raise e @@ -157,27 +171,34 @@ def status( slurm_job_id: int, log_dir: Optional[str] = None, json_mode: bool = False ) -> None: """Get the status of a running model on the cluster.""" - status_cmd = f"scontrol show job {slurm_job_id} --oneliner" - output, stderr = utils.run_bash_command(status_cmd) - if stderr: - raise click.ClickException(f"Error: {stderr}") - - status_helper = StatusHelper(slurm_job_id, output, log_dir) + try: + # Start the client and get model inference server status + client = VecInfClient() + status_response = client.get_status(slurm_job_id, log_dir) + # Display status information + status_formatter = StatusResponseFormatter(status_response) + if json_mode: + status_formatter.output_json() + else: + status_info_table = status_formatter.output_table() + CONSOLE.print(status_info_table) - status_helper.process_job_state() - if json_mode: - status_helper.output_json() - else: - status_helper.output_table(CONSOLE) + except click.ClickException as e: + raise e + except Exception as e: + raise click.ClickException(f"Status check failed: {str(e)}") from e @cli.command("shutdown") @click.argument("slurm_job_id", type=int, nargs=1) def shutdown(slurm_job_id: int) -> None: """Shutdown a running model on the cluster.""" - shutdown_cmd = f"scancel {slurm_job_id}" - utils.run_bash_command(shutdown_cmd) - click.echo(f"Shutting down model with Slurm Job ID: {slurm_job_id}") + try: + client = VecInfClient() + client.shutdown_model(slurm_job_id) + click.echo(f"Shutting down model with Slurm Job ID: {slurm_job_id}") + except Exception as e: + raise click.ClickException(f"Shutdown failed: {str(e)}") from e @cli.command("list") @@ -189,8 +210,20 @@ def shutdown(slurm_job_id: int) -> None: ) def list_models(model_name: Optional[str] = None, json_mode: bool = False) -> None: """List all available models, or get default setup of a specific model.""" - list_helper = ListHelper(model_name, json_mode) - list_helper.process_list_command(CONSOLE) + try: + # Start the client + client = VecInfClient() + list_display = ListCmdDisplay(CONSOLE, json_mode) + if model_name: + model_config = client.get_model_config(model_name) + list_display.display_single_model_output(model_config) + else: + model_infos = client.list_models() + list_display.display_all_models_output(model_infos) + except click.ClickException as e: + raise e + except Exception as e: + raise click.ClickException(f"List models failed: {str(e)}") from e @cli.command("metrics") @@ -200,30 +233,35 @@ def list_models(model_name: Optional[str] = None, json_mode: bool = False) -> No ) def metrics(slurm_job_id: int, log_dir: Optional[str] = None) -> None: """Stream real-time performance metrics from the model endpoint.""" - helper = MetricsHelper(slurm_job_id, log_dir) - - # Check if metrics URL is ready - if not helper.metrics_url.startswith("http"): - table = utils.create_table("Metric", "Value") - helper.display_failed_metrics( - table, f"Metrics endpoint unavailable - {helper.metrics_url}" - ) - CONSOLE.print(table) - return - - with Live(refresh_per_second=1, console=CONSOLE) as live: - while True: - metrics = helper.fetch_metrics() - table = utils.create_table("Metric", "Value") - - if isinstance(metrics, str): - # Show status information if metrics aren't available - helper.display_failed_metrics(table, metrics) - else: - helper.display_metrics(table, metrics) - - live.update(table) - time.sleep(2) + try: + # Start the client and get inference server metrics + client = VecInfClient() + metrics_response = client.get_metrics(slurm_job_id, log_dir) + metrics_formatter = MetricsResponseFormatter(metrics_response.metrics) + + # Check if metrics response is ready + if isinstance(metrics_response.metrics, str): + metrics_formatter.format_failed_metrics(metrics_response.metrics) + CONSOLE.print(metrics_formatter.table) + return + + with Live(refresh_per_second=1, console=CONSOLE) as live: + while True: + metrics_response = client.get_metrics(slurm_job_id, log_dir) + metrics_formatter = MetricsResponseFormatter(metrics_response.metrics) + + if isinstance(metrics_response.metrics, str): + # Show status information if metrics aren't available + metrics_formatter.format_failed_metrics(metrics_response.metrics) + else: + metrics_formatter.format_metrics() + + live.update(metrics_formatter.table) + time.sleep(2) + except click.ClickException as e: + raise e + except Exception as e: + raise click.ClickException(f"Metrics check failed: {str(e)}") from e if __name__ == "__main__": diff --git a/vec_inf/cli/_helper.py b/vec_inf/cli/_helper.py index bd520ac1..9d2872d2 100644 --- a/vec_inf/cli/_helper.py +++ b/vec_inf/cli/_helper.py @@ -1,200 +1,45 @@ -"""Command line interface for Vector Inference.""" +"""Helper classes for the CLI.""" -import json import os -import time -from pathlib import Path -from typing import Any, Optional, Union, cast -from urllib.parse import urlparse, urlunparse +from typing import Any, Union import click -import requests from rich.columns import Columns from rich.console import Console from rich.panel import Panel from rich.table import Table -import vec_inf.cli._utils as utils -from vec_inf.cli._config import ModelConfig - - -VLLM_TASK_MAP = { - "LLM": "generate", - "VLM": "generate", - "Text_Embedding": "embed", - "Reward_Modeling": "reward", -} - -REQUIRED_FIELDS = { - "model_family", - "model_type", - "gpus_per_node", - "num_nodes", - "vocab_size", - "max_model_len", -} - -BOOLEAN_FIELDS = { - "pipeline_parallelism", - "enforce_eager", - "enable_prefix_caching", - "enable_chunked_prefill", -} - -LD_LIBRARY_PATH = "/scratch/ssd001/pkgs/cudnn-11.7-v8.5.0.96/lib/:/scratch/ssd001/pkgs/cuda-11.7/targets/x86_64-linux/lib/" -SRC_DIR = str(Path(__file__).parent.parent) - - -class LaunchHelper: - def __init__( - self, model_name: str, cli_kwargs: dict[str, Optional[Union[str, int, bool]]] - ): - self.model_name = model_name - self.cli_kwargs = cli_kwargs - self.model_config = self._get_model_configuration() - self.params = self._get_launch_params() - - def _get_model_configuration(self) -> ModelConfig: - """Load and validate model configuration.""" - model_configs = utils.load_config() - config = next( - (m for m in model_configs if m.model_name == self.model_name), None - ) +from vec_inf.cli._models import MODEL_TYPE_COLORS, MODEL_TYPE_PRIORITY +from vec_inf.cli._utils import create_table +from vec_inf.client import ModelConfig, ModelInfo, StatusResponse - if config: - return config - # If model config not found, load path from CLI args or fallback to default - model_weights_parent_dir = self.cli_kwargs.get( - "model_weights_parent_dir", model_configs[0].model_weights_parent_dir - ) - model_weights_path = Path(cast(str, model_weights_parent_dir), self.model_name) - # Only give a warning msg if weights exist but config missing - if model_weights_path.exists(): - click.echo( - click.style( - f"Warning: '{self.model_name}' configuration not found in config, please ensure model configuration are properly set in command arguments", - fg="yellow", - ) - ) - # Return a dummy model config object with model name and weights parent dir - return ModelConfig( - model_name=self.model_name, - model_family="model_family_placeholder", - model_type="LLM", - gpus_per_node=1, - num_nodes=1, - vocab_size=1000, - max_model_len=8192, - model_weights_parent_dir=Path(cast(str, model_weights_parent_dir)), - ) - raise click.ClickException( - f"'{self.model_name}' not found in configuration and model weights " - f"not found at expected path '{model_weights_path}'" - ) - def _get_launch_params(self) -> dict[str, Any]: - """Merge config defaults with CLI overrides.""" - params = self.model_config.model_dump() - - # Process boolean fields - for bool_field in BOOLEAN_FIELDS: - if self.cli_kwargs[bool_field]: - params[bool_field] = True - - # Merge other overrides - for key, value in self.cli_kwargs.items(): - if value is not None and key not in [ - "json_mode", - *BOOLEAN_FIELDS, - ]: - params[key] = value - - # Validate required fields - if not REQUIRED_FIELDS.issubset(set(params.keys())): - raise click.ClickException( - f"Missing required fields: {REQUIRED_FIELDS - set(params.keys())}" - ) +class LaunchResponseFormatter: + """CLI Helper class for formatting LaunchResponse.""" - # Create log directory - params["log_dir"] = Path(params["log_dir"], params["model_family"]).expanduser() - params["log_dir"].mkdir(parents=True, exist_ok=True) - - # Convert to string for JSON serialization - for field in params: - params[field] = str(params[field]) - - return params - - def set_env_vars(self) -> None: - """Set environment variables for the launch command.""" - os.environ["MODEL_NAME"] = self.model_name - os.environ["MAX_MODEL_LEN"] = self.params["max_model_len"] - os.environ["MAX_LOGPROBS"] = self.params["vocab_size"] - os.environ["DATA_TYPE"] = self.params["data_type"] - os.environ["MAX_NUM_SEQS"] = self.params["max_num_seqs"] - os.environ["GPU_MEMORY_UTILIZATION"] = self.params["gpu_memory_utilization"] - os.environ["TASK"] = VLLM_TASK_MAP[self.params["model_type"]] - os.environ["PIPELINE_PARALLELISM"] = self.params["pipeline_parallelism"] - os.environ["COMPILATION_CONFIG"] = self.params["compilation_config"] - os.environ["SRC_DIR"] = SRC_DIR - os.environ["MODEL_WEIGHTS"] = str( - Path(self.params["model_weights_parent_dir"], self.model_name) - ) - os.environ["LD_LIBRARY_PATH"] = LD_LIBRARY_PATH - os.environ["VENV_BASE"] = self.params["venv"] - os.environ["LOG_DIR"] = self.params["log_dir"] + def __init__(self, model_name: str, params: dict[str, Any]): + self.model_name = model_name + self.params = params - if self.params.get("enable_prefix_caching"): - os.environ["ENABLE_PREFIX_CACHING"] = self.params["enable_prefix_caching"] - if self.params.get("enable_chunked_prefill"): - os.environ["ENABLE_CHUNKED_PREFILL"] = self.params["enable_chunked_prefill"] - if self.params.get("max_num_batched_tokens"): - os.environ["MAX_NUM_BATCHED_TOKENS"] = self.params["max_num_batched_tokens"] - if self.params.get("enforce_eager"): - os.environ["ENFORCE_EAGER"] = self.params["enforce_eager"] - - def build_launch_command(self) -> str: - """Construct the full launch command with parameters.""" - # Base command - command_list = ["sbatch"] - # Append options - command_list.extend(["--job-name", f"{self.model_name}"]) - command_list.extend(["--partition", f"{self.params['partition']}"]) - command_list.extend(["--qos", f"{self.params['qos']}"]) - command_list.extend(["--time", f"{self.params['time']}"]) - command_list.extend(["--nodes", f"{self.params['num_nodes']}"]) - command_list.extend(["--gpus-per-node", f"{self.params['gpus_per_node']}"]) - command_list.extend( - [ - "--output", - f"{self.params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.out", - ] - ) - command_list.extend( - [ - "--error", - f"{self.params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.err", - ] - ) - # Add slurm script - slurm_script = "vllm.slurm" - if int(self.params["num_nodes"]) > 1: - slurm_script = "multinode_vllm.slurm" - command_list.append(f"{SRC_DIR}/{slurm_script}") - return " ".join(command_list) - - def format_table_output(self, job_id: str) -> Table: + def format_table_output(self) -> Table: """Format output as rich Table.""" - table = utils.create_table(key_title="Job Config", value_title="Value") - # Add rows - table.add_row("Slurm Job ID", job_id, style="blue") + table = create_table(key_title="Job Config", value_title="Value") + + # Add key information with consistent styling + table.add_row("Slurm Job ID", self.params["slurm_job_id"], style="blue") table.add_row("Job Name", self.model_name) + + # Add model details table.add_row("Model Type", self.params["model_type"]) + + # Add resource allocation details table.add_row("Partition", self.params["partition"]) table.add_row("QoS", self.params["qos"]) table.add_row("Time Limit", self.params["time"]) table.add_row("Num Nodes", self.params["num_nodes"]) table.add_row("GPUs/Node", self.params["gpus_per_node"]) + + # Add model configuration details table.add_row("Data Type", self.params["data_type"]) table.add_row("Vocabulary Size", self.params["vocab_size"]) table.add_row("Max Model Length", self.params["max_model_len"]) @@ -214,396 +59,147 @@ def format_table_output(self, job_id: str) -> Table: ) if self.params.get("enforce_eager"): table.add_row("Enforce Eager", self.params["enforce_eager"]) + + # Add path details table.add_row("Model Weights Directory", os.environ.get("MODEL_WEIGHTS")) table.add_row("Log Directory", self.params["log_dir"]) return table - def post_launch_processing(self, output: str, console: Console) -> None: - """Process and display launch output.""" - json_mode = bool(self.cli_kwargs.get("json_mode", False)) - slurm_job_id = output.split(" ")[-1].strip().strip("\n") - self.params["slurm_job_id"] = slurm_job_id - job_json = Path( - self.params["log_dir"], - f"{self.model_name}.{slurm_job_id}", - f"{self.model_name}.{slurm_job_id}.json", - ) - job_json.parent.mkdir(parents=True, exist_ok=True) - job_json.touch(exist_ok=True) - with job_json.open("w") as file: - json.dump(self.params, file, indent=4) - if json_mode: - click.echo(self.params) - else: - table = self.format_table_output(slurm_job_id) - console.print(table) - - -class StatusHelper: - def __init__(self, slurm_job_id: int, output: str, log_dir: Optional[str] = None): - self.slurm_job_id = slurm_job_id - self.output = output - self.log_dir = log_dir - self.status_info = self._get_base_status_data() - - def _get_base_status_data(self) -> dict[str, Union[str, None]]: - """Extract basic job status information from scontrol output.""" - try: - job_name = self.output.split(" ")[1].split("=")[1] - job_state = self.output.split(" ")[9].split("=")[1] - except IndexError: - job_name = "UNAVAILABLE" - job_state = "UNAVAILABLE" - - return { - "model_name": job_name, - "status": "UNAVAILABLE", - "base_url": "UNAVAILABLE", - "state": job_state, - "pending_reason": None, - "failed_reason": None, - } +class StatusResponseFormatter: + """CLI Helper class for formatting StatusResponse.""" - def process_job_state(self) -> None: - """Process different job states and update status information.""" - if self.status_info["state"] == "PENDING": - self.process_pending_state() - elif self.status_info["state"] == "RUNNING": - self.process_running_state() - - def check_model_health(self) -> None: - """Check model health and update status accordingly.""" - status, status_code = utils.model_health_check( - cast(str, self.status_info["model_name"]), self.slurm_job_id, self.log_dir - ) - if status == "READY": - self.status_info["base_url"] = utils.get_base_url( - cast(str, self.status_info["model_name"]), - self.slurm_job_id, - self.log_dir, - ) - self.status_info["status"] = status - else: - self.status_info["status"], self.status_info["failed_reason"] = ( - status, - cast(str, status_code), - ) - - def process_running_state(self) -> None: - """Process RUNNING job state and check server status.""" - server_status = utils.is_server_running( - cast(str, self.status_info["model_name"]), self.slurm_job_id, self.log_dir - ) - - if isinstance(server_status, tuple): - self.status_info["status"], self.status_info["failed_reason"] = ( - server_status - ) - return - - if server_status == "RUNNING": - self.check_model_health() - else: - self.status_info["status"] = server_status - - def process_pending_state(self) -> None: - """Process PENDING job state.""" - try: - self.status_info["pending_reason"] = self.output.split(" ")[10].split("=")[ - 1 - ] - self.status_info["status"] = "PENDING" - except IndexError: - self.status_info["pending_reason"] = "Unknown pending reason" + def __init__(self, status_info: StatusResponse): + self.status_info = status_info def output_json(self) -> None: """Format and output JSON data.""" json_data = { - "model_name": self.status_info["model_name"], - "model_status": self.status_info["status"], - "base_url": self.status_info["base_url"], + "model_name": self.status_info.model_name, + "model_status": self.status_info.server_status, + "base_url": self.status_info.base_url, } - if self.status_info["pending_reason"]: - json_data["pending_reason"] = self.status_info["pending_reason"] - if self.status_info["failed_reason"]: - json_data["failed_reason"] = self.status_info["failed_reason"] + if self.status_info.pending_reason: + json_data["pending_reason"] = self.status_info.pending_reason + if self.status_info.failed_reason: + json_data["failed_reason"] = self.status_info.failed_reason click.echo(json_data) - def output_table(self, console: Console) -> None: + def output_table(self) -> Table: """Create and display rich table.""" - table = utils.create_table(key_title="Job Status", value_title="Value") - table.add_row("Model Name", self.status_info["model_name"]) - table.add_row("Model Status", self.status_info["status"], style="blue") + table = create_table(key_title="Job Status", value_title="Value") + table.add_row("Model Name", self.status_info.model_name) + table.add_row("Model Status", self.status_info.server_status, style="blue") + + if self.status_info.pending_reason: + table.add_row("Pending Reason", self.status_info.pending_reason) + if self.status_info.failed_reason: + table.add_row("Failed Reason", self.status_info.failed_reason) - if self.status_info["pending_reason"]: - table.add_row("Pending Reason", self.status_info["pending_reason"]) - if self.status_info["failed_reason"]: - table.add_row("Failed Reason", self.status_info["failed_reason"]) + table.add_row("Base URL", self.status_info.base_url) + return table - table.add_row("Base URL", self.status_info["base_url"]) - console.print(table) +class MetricsResponseFormatter: + """CLI Helper class for formatting MetricsResponse.""" -class MetricsHelper: - def __init__(self, slurm_job_id: int, log_dir: Optional[str] = None): - self.slurm_job_id = slurm_job_id - self.log_dir = log_dir - self.status_info = self._get_status_info() - self.metrics_url = self._build_metrics_url() + def __init__(self, metrics: Union[dict[str, float], str]): + self.metrics = self._set_metrics(metrics) + self.table = create_table("Metric", "Value") self.enabled_prefix_caching = self._check_prefix_caching() - self._prev_prompt_tokens: float = 0.0 - self._prev_generation_tokens: float = 0.0 - self._last_updated: Optional[float] = None - self._last_throughputs = {"prompt": 0.0, "generation": 0.0} - - def _get_status_info(self) -> dict[str, Union[str, None]]: - """Retrieve status info using existing StatusHelper.""" - status_cmd = f"scontrol show job {self.slurm_job_id} --oneliner" - output, stderr = utils.run_bash_command(status_cmd) - if stderr: - raise click.ClickException(f"Error: {stderr}") - status_helper = StatusHelper(self.slurm_job_id, output, self.log_dir) - return status_helper.status_info - - def _build_metrics_url(self) -> str: - """Construct metrics endpoint URL from base URL with version stripping.""" - if self.status_info.get("state") == "PENDING": - return "Pending resources for server initialization" - - base_url = utils.get_base_url( - cast(str, self.status_info["model_name"]), - self.slurm_job_id, - self.log_dir, - ) - if not base_url.startswith("http"): - return "Server not ready" - - parsed = urlparse(base_url) - clean_path = parsed.path.replace("/v1", "", 1).rstrip("/") - return urlunparse( - (parsed.scheme, parsed.netloc, f"{clean_path}/metrics", "", "", "") - ) + def _set_metrics(self, metrics: Union[dict[str, float], str]) -> dict[str, float]: + """Set the metrics attribute.""" + return metrics if isinstance(metrics, dict) else {} def _check_prefix_caching(self) -> bool: - """Check if prefix caching is enabled.""" - job_json = utils.read_slurm_log( - cast(str, self.status_info["model_name"]), - self.slurm_job_id, - "json", - self.log_dir, - ) - if isinstance(job_json, str): - return False - return bool(cast(dict[str, str], job_json).get("enable_prefix_caching", False)) - - def fetch_metrics(self) -> Union[dict[str, float], str]: - """Fetch metrics from the endpoint.""" - try: - response = requests.get(self.metrics_url, timeout=3) - response.raise_for_status() - current_metrics = self._parse_metrics(response.text) - current_time = time.time() - - # Set defaults using last known throughputs - current_metrics.setdefault( - "prompt_tokens_per_sec", self._last_throughputs["prompt"] - ) - current_metrics.setdefault( - "generation_tokens_per_sec", self._last_throughputs["generation"] - ) + """Check if prefix caching is enabled by looking for prefix cache metrics.""" + return self.metrics.get("gpu_prefix_cache_hit_rate") is not None - if self._last_updated is None: - self._prev_prompt_tokens = current_metrics.get( - "total_prompt_tokens", 0.0 - ) - self._prev_generation_tokens = current_metrics.get( - "total_generation_tokens", 0.0 - ) - self._last_updated = current_time - return current_metrics - - time_diff = current_time - self._last_updated - if time_diff > 0: - current_prompt = current_metrics.get("total_prompt_tokens", 0.0) - current_gen = current_metrics.get("total_generation_tokens", 0.0) - - delta_prompt = current_prompt - self._prev_prompt_tokens - delta_gen = current_gen - self._prev_generation_tokens - - # Only update throughputs when we have new tokens - prompt_tps = ( - delta_prompt / time_diff - if delta_prompt > 0 - else self._last_throughputs["prompt"] - ) - gen_tps = ( - delta_gen / time_diff - if delta_gen > 0 - else self._last_throughputs["generation"] - ) - - current_metrics["prompt_tokens_per_sec"] = prompt_tps - current_metrics["generation_tokens_per_sec"] = gen_tps - - # Persist calculated values regardless of activity - self._last_throughputs["prompt"] = prompt_tps - self._last_throughputs["generation"] = gen_tps - - # Update tracking state - self._prev_prompt_tokens = current_prompt - self._prev_generation_tokens = current_gen - self._last_updated = current_time - - # Calculate average latency if data is available - if ( - "request_latency_sum" in current_metrics - and "request_latency_count" in current_metrics - ): - latency_sum = current_metrics["request_latency_sum"] - latency_count = current_metrics["request_latency_count"] - current_metrics["avg_request_latency"] = ( - latency_sum / latency_count if latency_count > 0 else 0.0 - ) - - return current_metrics - - except requests.RequestException as e: - return f"Metrics request failed, `metrics` endpoint might not be ready yet: {str(e)}" - - def _parse_metrics(self, metrics_text: str) -> dict[str, float]: - """Parse metrics with latency count and sum.""" - key_metrics = { - "vllm:prompt_tokens_total": "total_prompt_tokens", - "vllm:generation_tokens_total": "total_generation_tokens", - "vllm:e2e_request_latency_seconds_sum": "request_latency_sum", - "vllm:e2e_request_latency_seconds_count": "request_latency_count", - "vllm:request_queue_time_seconds_sum": "queue_time_sum", - "vllm:request_success_total": "successful_requests_total", - "vllm:num_requests_running": "requests_running", - "vllm:num_requests_waiting": "requests_waiting", - "vllm:num_requests_swapped": "requests_swapped", - "vllm:gpu_cache_usage_perc": "gpu_cache_usage", - "vllm:cpu_cache_usage_perc": "cpu_cache_usage", - } + def format_failed_metrics(self, message: str) -> None: + self.table.add_row("ERROR", message) - if self.enabled_prefix_caching: - key_metrics["vllm:gpu_prefix_cache_hit_rate"] = "gpu_prefix_cache_hit_rate" - key_metrics["vllm:cpu_prefix_cache_hit_rate"] = "cpu_prefix_cache_hit_rate" - - parsed: dict[str, float] = {} - for line in metrics_text.split("\n"): - if line.startswith("#") or not line.strip(): - continue - - parts = line.split() - if len(parts) < 2: - continue - - metric_name = parts[0].split("{")[0] - if metric_name in key_metrics: - try: - parsed[key_metrics[metric_name]] = float(parts[1]) - except (ValueError, IndexError): - continue - return parsed - - def display_failed_metrics(self, table: Table, metrics: str) -> None: - table.add_row("Server State", self.status_info["state"], style="yellow") - table.add_row("Message", metrics) - - def display_metrics(self, table: Table, metrics: dict[str, float]) -> None: + def format_metrics(self) -> None: # Throughput metrics - table.add_row( + self.table.add_row( "Prompt Throughput", - f"{metrics.get('prompt_tokens_per_sec', 0):.1f} tokens/s", + f"{self.metrics.get('prompt_tokens_per_sec', 0):.1f} tokens/s", ) - table.add_row( + self.table.add_row( "Generation Throughput", - f"{metrics.get('generation_tokens_per_sec', 0):.1f} tokens/s", + f"{self.metrics.get('generation_tokens_per_sec', 0):.1f} tokens/s", ) # Request queue metrics - table.add_row( + self.table.add_row( "Requests Running", - f"{metrics.get('requests_running', 0):.0f} reqs", + f"{self.metrics.get('requests_running', 0):.0f} reqs", ) - table.add_row( + self.table.add_row( "Requests Waiting", - f"{metrics.get('requests_waiting', 0):.0f} reqs", + f"{self.metrics.get('requests_waiting', 0):.0f} reqs", ) - table.add_row( + self.table.add_row( "Requests Swapped", - f"{metrics.get('requests_swapped', 0):.0f} reqs", + f"{self.metrics.get('requests_swapped', 0):.0f} reqs", ) # Cache usage metrics - table.add_row( + self.table.add_row( "GPU Cache Usage", - f"{metrics.get('gpu_cache_usage', 0) * 100:.1f}%", + f"{self.metrics.get('gpu_cache_usage', 0) * 100:.1f}%", ) - table.add_row( + self.table.add_row( "CPU Cache Usage", - f"{metrics.get('cpu_cache_usage', 0) * 100:.1f}%", + f"{self.metrics.get('cpu_cache_usage', 0) * 100:.1f}%", ) if self.enabled_prefix_caching: - table.add_row( + self.table.add_row( "GPU Prefix Cache Hit Rate", - f"{metrics.get('gpu_prefix_cache_hit_rate', 0) * 100:.1f}%", + f"{self.metrics.get('gpu_prefix_cache_hit_rate', 0) * 100:.1f}%", ) - table.add_row( + self.table.add_row( "CPU Prefix Cache Hit Rate", - f"{metrics.get('cpu_prefix_cache_hit_rate', 0) * 100:.1f}%", + f"{self.metrics.get('cpu_prefix_cache_hit_rate', 0) * 100:.1f}%", ) # Show average latency if available - if "avg_request_latency" in metrics: - table.add_row( + if "avg_request_latency" in self.metrics: + self.table.add_row( "Avg Request Latency", - f"{metrics['avg_request_latency']:.1f} s", + f"{self.metrics['avg_request_latency']:.1f} s", ) # Token counts - table.add_row( + self.table.add_row( "Total Prompt Tokens", - f"{metrics.get('total_prompt_tokens', 0):.0f} tokens", + f"{self.metrics.get('total_prompt_tokens', 0):.0f} tokens", ) - table.add_row( + self.table.add_row( "Total Generation Tokens", - f"{metrics.get('total_generation_tokens', 0):.0f} tokens", + f"{self.metrics.get('total_generation_tokens', 0):.0f} tokens", ) - table.add_row( + self.table.add_row( "Successful Requests", - f"{metrics.get('successful_requests_total', 0):.0f} reqs", + f"{self.metrics.get('successful_requests_total', 0):.0f} reqs", ) -class ListHelper: - """Helper class for handling model listing functionality.""" +class ListCmdDisplay: + """CLI Helper class for displaying model listing functionality.""" - def __init__(self, model_name: Optional[str] = None, json_mode: bool = False): - self.model_name = model_name + def __init__(self, console: Console, json_mode: bool = False): + self.console = console self.json_mode = json_mode - self.model_configs = utils.load_config() + self.model_config = None + self.model_names: list[str] = [] - def get_single_model_config(self) -> ModelConfig: - """Get configuration for a specific model.""" - config = next( - (c for c in self.model_configs if c.model_name == self.model_name), None - ) - if not config: - raise click.ClickException( - f"Model '{self.model_name}' not found in configuration" - ) - return config - - def format_single_model_output( + def _format_single_model_output( self, config: ModelConfig ) -> Union[dict[str, Any], Table]: - """Format output for a single model.""" + """Format output table for a single model.""" if self.json_mode: # Exclude non-essential fields from JSON output excluded = {"venv", "log_dir"} @@ -614,62 +210,47 @@ def format_single_model_output( ) return config_dict - table = utils.create_table(key_title="Model Config", value_title="Value") + table = create_table(key_title="Model Config", value_title="Value") for field, value in config.model_dump().items(): if field not in {"venv", "log_dir"}: table.add_row(field, str(value)) return table - def format_all_models_output(self) -> Union[list[str], list[Panel]]: - """Format output for all models.""" - if self.json_mode: - return [config.model_name for config in self.model_configs] - + def _format_all_models_output( + self, model_infos: list[ModelInfo] + ) -> Union[list[str], list[Panel]]: + """Format output table for all models.""" # Sort by model type priority - type_priority = {"LLM": 0, "VLM": 1, "Text_Embedding": 2, "Reward_Modeling": 3} - sorted_configs = sorted( - self.model_configs, key=lambda x: type_priority.get(x.model_type, 4) + sorted_model_infos = sorted( + model_infos, + key=lambda x: MODEL_TYPE_PRIORITY.get(x.type, 4), ) # Create panels with color coding - model_type_colors = { - "LLM": "cyan", - "VLM": "bright_blue", - "Text_Embedding": "purple", - "Reward_Modeling": "bright_magenta", - } - panels = [] - for config in sorted_configs: - color = model_type_colors.get(config.model_type, "white") - variant = config.model_variant or "" - display_text = f"[magenta]{config.model_family}[/magenta]" + for model_info in sorted_model_infos: + color = MODEL_TYPE_COLORS.get(model_info.type, "white") + variant = model_info.variant or "" + display_text = f"[magenta]{model_info.family}[/magenta]" if variant: display_text += f"-{variant}" panels.append(Panel(display_text, expand=True, border_style=color)) return panels - def process_list_command(self, console: Console) -> None: - """Process the list command and display output.""" - try: - if self.model_name: - # Handle single model case - config = self.get_single_model_config() - output = self.format_single_model_output(config) - if self.json_mode: - click.echo(output) - else: - console.print(output) - # Handle all models case - elif self.json_mode: - # JSON output for all models is just a list of names - model_names = [config.model_name for config in self.model_configs] - click.echo(model_names) - else: - # Rich output for all models is a list of panels - panels = self.format_all_models_output() - if isinstance(panels, list): # This helps mypy understand the type - console.print(Columns(panels, equal=True)) - except Exception as e: - raise click.ClickException(str(e)) from e + def display_single_model_output(self, config: ModelConfig) -> None: + """Display the output for a single model.""" + output = self._format_single_model_output(config) + if self.json_mode: + click.echo(output) + else: + self.console.print(output) + + def display_all_models_output(self, model_infos: list[ModelInfo]) -> None: + """Display the output for all models.""" + if self.json_mode: + model_names = [info.name for info in model_infos] + click.echo(model_names) + else: + panels = self._format_all_models_output(model_infos) + self.console.print(Columns(panels, equal=True)) diff --git a/vec_inf/cli/_models.py b/vec_inf/cli/_models.py new file mode 100644 index 00000000..1b45df9f --- /dev/null +++ b/vec_inf/cli/_models.py @@ -0,0 +1,15 @@ +"""Data models for CLI rendering.""" + +MODEL_TYPE_PRIORITY = { + "LLM": 0, + "VLM": 1, + "Text_Embedding": 2, + "Reward_Modeling": 3, +} + +MODEL_TYPE_COLORS = { + "LLM": "cyan", + "VLM": "bright_blue", + "Text_Embedding": "purple", + "Reward_Modeling": "bright_magenta", +} diff --git a/vec_inf/cli/_utils.py b/vec_inf/cli/_utils.py index 574e00db..33ad63d1 100644 --- a/vec_inf/cli/_utils.py +++ b/vec_inf/cli/_utils.py @@ -1,122 +1,7 @@ -"""Utility functions for the CLI.""" +"""Helper functions for the CLI.""" -import json -import os -import subprocess -from pathlib import Path -from typing import Any, Optional, Union, cast - -import requests -import yaml from rich.table import Table -from vec_inf.cli._config import ModelConfig - - -MODEL_READY_SIGNATURE = "INFO: Application startup complete." -CACHED_CONFIG = Path("/", "model-weights", "vec-inf-shared", "models.yaml") - - -def run_bash_command(command: str) -> tuple[str, str]: - """Run a bash command and return the output.""" - process = subprocess.Popen( - command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True - ) - return process.communicate() - - -def read_slurm_log( - slurm_job_name: str, - slurm_job_id: int, - slurm_log_type: str, - log_dir: Optional[Union[str, Path]], -) -> Union[list[str], str, dict[str, str]]: - """Read the slurm log file.""" - if not log_dir: - # Default log directory - models_dir = Path.home() / ".vec-inf-logs" - if not models_dir.exists(): - return "LOG DIR NOT FOUND" - # Iterate over all dirs in models_dir, sorted by dir name length in desc order - for directory in sorted( - [d for d in models_dir.iterdir() if d.is_dir()], - key=lambda d: len(d.name), - reverse=True, - ): - if directory.name in slurm_job_name: - log_dir = directory - break - else: - log_dir = Path(log_dir) - - # If log_dir is still not set, then didn't find the log dir at default location - if not log_dir: - return "LOG DIR NOT FOUND" - - try: - file_path = ( - log_dir - / Path(f"{slurm_job_name}.{slurm_job_id}") - / f"{slurm_job_name}.{slurm_job_id}.{slurm_log_type}" - ) - if slurm_log_type == "json": - with file_path.open("r") as file: - json_content: dict[str, str] = json.load(file) - return json_content - else: - with file_path.open("r") as file: - return file.readlines() - except FileNotFoundError: - return f"LOG FILE NOT FOUND: {file_path}" - - -def is_server_running( - slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str] -) -> Union[str, tuple[str, str]]: - """Check if a model is ready to serve requests.""" - log_content = read_slurm_log(slurm_job_name, slurm_job_id, "err", log_dir) - if isinstance(log_content, str): - return log_content - - status: Union[str, tuple[str, str]] = "LAUNCHING" - - for line in log_content: - if "error" in line.lower(): - status = ("FAILED", line.strip("\n")) - if MODEL_READY_SIGNATURE in line: - status = "RUNNING" - - return status - - -def get_base_url(slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str]) -> str: - """Get the base URL of a model.""" - log_content = read_slurm_log(slurm_job_name, slurm_job_id, "json", log_dir) - if isinstance(log_content, str): - return log_content - - server_addr = cast(dict[str, str], log_content).get("server_address") - return server_addr if server_addr else "URL NOT FOUND" - - -def model_health_check( - slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str] -) -> tuple[str, Union[str, int]]: - """Check the health of a running model on the cluster.""" - base_url = get_base_url(slurm_job_name, slurm_job_id, log_dir) - if not base_url.startswith("http"): - return ("FAILED", base_url) - health_check_url = base_url.replace("v1", "health") - - try: - response = requests.get(health_check_url) - # Check if the request was successful - if response.status_code == 200: - return ("READY", response.status_code) - return ("FAILED", response.status_code) - except requests.exceptions.RequestException as e: - return ("FAILED", str(e)) - def create_table( key_title: str = "", value_title: str = "", show_header: bool = True @@ -126,37 +11,3 @@ def create_table( table.add_column(key_title, style="dim") table.add_column(value_title) return table - - -def load_config() -> list[ModelConfig]: - """Load the model configuration.""" - default_path = ( - CACHED_CONFIG - if CACHED_CONFIG.exists() - else Path(__file__).resolve().parent.parent / "config" / "models.yaml" - ) - - config: dict[str, Any] = {} - with open(default_path) as f: - config = yaml.safe_load(f) or {} - - user_path = os.getenv("VEC_INF_CONFIG") - if user_path: - user_path_obj = Path(user_path) - if user_path_obj.exists(): - with open(user_path_obj) as f: - user_config = yaml.safe_load(f) or {} - for name, data in user_config.get("models", {}).items(): - if name in config.get("models", {}): - config["models"][name].update(data) - else: - config.setdefault("models", {})[name] = data - else: - print( - f"WARNING: Could not find user config: {user_path}, revert to default config located at {default_path}" - ) - - return [ - ModelConfig(model_name=name, **model_data) - for name, model_data in config.get("models", {}).items() - ] diff --git a/vec_inf/client/__init__.py b/vec_inf/client/__init__.py new file mode 100644 index 00000000..e6b45824 --- /dev/null +++ b/vec_inf/client/__init__.py @@ -0,0 +1,33 @@ +"""Programmatic API for Vector Inference. + +This module provides a Python API for launching and managing inference servers +using `vec_inf`. It is an alternative to the command-line interface, and allows +users direct control over the lifecycle of inference servers via python scripts. +""" + +from vec_inf.client._config import ModelConfig +from vec_inf.client._models import ( + LaunchOptions, + LaunchOptionsDict, + LaunchResponse, + MetricsResponse, + ModelInfo, + ModelStatus, + ModelType, + StatusResponse, +) +from vec_inf.client.api import VecInfClient + + +__all__ = [ + "VecInfClient", + "LaunchResponse", + "StatusResponse", + "ModelInfo", + "MetricsResponse", + "ModelStatus", + "ModelType", + "LaunchOptions", + "LaunchOptionsDict", + "ModelConfig", +] diff --git a/vec_inf/cli/_config.py b/vec_inf/client/_config.py similarity index 100% rename from vec_inf/cli/_config.py rename to vec_inf/client/_config.py diff --git a/vec_inf/client/_exceptions.py b/vec_inf/client/_exceptions.py new file mode 100644 index 00000000..296cec16 --- /dev/null +++ b/vec_inf/client/_exceptions.py @@ -0,0 +1,37 @@ +"""Exceptions for the vector inference package.""" + + +class ModelConfigurationError(Exception): + """Raised when the model config or weights are missing or invalid.""" + + pass + + +class MissingRequiredFieldsError(ValueError): + """Raised when required fields are missing from the provided parameters.""" + + pass + + +class ModelNotFoundError(KeyError): + """Raised when the specified model name is not found in the configuration.""" + + pass + + +class SlurmJobError(RuntimeError): + """Raised when there's an error with a Slurm job.""" + + pass + + +class APIError(Exception): + """Base exception for API errors.""" + + pass + + +class ServerError(Exception): + """Exception raised when there's an error with the inference server.""" + + pass diff --git a/vec_inf/client/_helper.py b/vec_inf/client/_helper.py new file mode 100644 index 00000000..d5b9481a --- /dev/null +++ b/vec_inf/client/_helper.py @@ -0,0 +1,508 @@ +"""Helper classes for the model.""" + +import json +import os +import time +import warnings +from pathlib import Path +from typing import Any, Optional, Union, cast +from urllib.parse import urlparse, urlunparse + +import requests + +import vec_inf.client._utils as utils +from vec_inf.client._config import ModelConfig +from vec_inf.client._exceptions import ( + MissingRequiredFieldsError, + ModelConfigurationError, + ModelNotFoundError, + SlurmJobError, +) +from vec_inf.client._models import ( + LaunchResponse, + ModelInfo, + ModelStatus, + ModelType, + StatusResponse, +) +from vec_inf.client._vars import ( + BOOLEAN_FIELDS, + LD_LIBRARY_PATH, + REQUIRED_FIELDS, + SRC_DIR, + VLLM_TASK_MAP, +) + + +class ModelLauncher: + """Helper class for handling inference server launch.""" + + def __init__(self, model_name: str, kwargs: Optional[dict[str, Any]]): + """Initialize the model launcher. + + Parameters + ---------- + model_name: str + Name of the model to launch + kwargs: Optional[dict[str, Any]] + Optional launch keyword arguments to override default configuration + """ + self.model_name = model_name + self.kwargs = kwargs or {} + self.slurm_job_id = "" + self.model_config = self._get_model_configuration() + self.params = self._get_launch_params() + + def _warn(self, message: str) -> None: + """Warn the user about a potential issue.""" + warnings.warn(message, UserWarning, stacklevel=2) + + def _get_model_configuration(self) -> ModelConfig: + """Load and validate model configuration.""" + model_configs = utils.load_config() + config = next( + (m for m in model_configs if m.model_name == self.model_name), None + ) + + if config: + return config + + # If model config not found, check for path from CLI kwargs or use fallback + model_weights_parent_dir = self.kwargs.get( + "model_weights_parent_dir", + model_configs[0].model_weights_parent_dir if model_configs else None, + ) + + if not model_weights_parent_dir: + raise ModelNotFoundError( + "Could not determine model weights parent directory" + ) + + model_weights_path = Path(model_weights_parent_dir, self.model_name) + + # Only give a warning if weights exist but config missing + if model_weights_path.exists(): + self._warn( + f"Warning: '{self.model_name}' configuration not found in config, please ensure model configuration are properly set in command arguments", + ) + # Return a dummy model config object with model name and weights parent dir + return ModelConfig( + model_name=self.model_name, + model_family="model_family_placeholder", + model_type="LLM", + gpus_per_node=1, + num_nodes=1, + vocab_size=1000, + max_model_len=8192, + model_weights_parent_dir=Path(str(model_weights_parent_dir)), + ) + + raise ModelConfigurationError( + f"'{self.model_name}' not found in configuration and model weights " + f"not found at expected path '{model_weights_path}'" + ) + + def _get_launch_params(self) -> dict[str, Any]: + """Merge config defaults with CLI overrides.""" + params = self.model_config.model_dump() + + # Process boolean fields + for bool_field in BOOLEAN_FIELDS: + if self.kwargs.get(bool_field) and self.kwargs[bool_field]: + params[bool_field] = True + + # Merge other overrides + for key, value in self.kwargs.items(): + if value is not None and key not in [ + "json_mode", + *BOOLEAN_FIELDS, + ]: + params[key] = value + + # Validate required fields + if not REQUIRED_FIELDS.issubset(set(params.keys())): + raise MissingRequiredFieldsError( + f"Missing required fields: {REQUIRED_FIELDS - set(params.keys())}" + ) + + # Create log directory + params["log_dir"] = Path(params["log_dir"], params["model_family"]).expanduser() + params["log_dir"].mkdir(parents=True, exist_ok=True) + + # Convert to string for JSON serialization + for field in params: + params[field] = str(params[field]) + + return params + + def _set_env_vars(self) -> None: + """Set environment variables for the launch command.""" + os.environ["MODEL_NAME"] = self.model_name + os.environ["MAX_MODEL_LEN"] = self.params["max_model_len"] + os.environ["MAX_LOGPROBS"] = self.params["vocab_size"] + os.environ["DATA_TYPE"] = self.params["data_type"] + os.environ["MAX_NUM_SEQS"] = self.params["max_num_seqs"] + os.environ["GPU_MEMORY_UTILIZATION"] = self.params["gpu_memory_utilization"] + os.environ["TASK"] = VLLM_TASK_MAP[self.params["model_type"]] + os.environ["PIPELINE_PARALLELISM"] = self.params["pipeline_parallelism"] + os.environ["COMPILATION_CONFIG"] = self.params["compilation_config"] + os.environ["SRC_DIR"] = SRC_DIR + os.environ["MODEL_WEIGHTS"] = str( + Path(self.params["model_weights_parent_dir"], self.model_name) + ) + os.environ["LD_LIBRARY_PATH"] = LD_LIBRARY_PATH + os.environ["VENV_BASE"] = self.params["venv"] + os.environ["LOG_DIR"] = self.params["log_dir"] + + if self.params.get("enable_prefix_caching"): + os.environ["ENABLE_PREFIX_CACHING"] = self.params["enable_prefix_caching"] + if self.params.get("enable_chunked_prefill"): + os.environ["ENABLE_CHUNKED_PREFILL"] = self.params["enable_chunked_prefill"] + if self.params.get("max_num_batched_tokens"): + os.environ["MAX_NUM_BATCHED_TOKENS"] = self.params["max_num_batched_tokens"] + if self.params.get("enforce_eager"): + os.environ["ENFORCE_EAGER"] = self.params["enforce_eager"] + + def _build_launch_command(self) -> str: + """Construct the full launch command with parameters.""" + # Base command + command_list = ["sbatch"] + # Append options + command_list.extend(["--job-name", f"{self.model_name}"]) + command_list.extend(["--partition", f"{self.params['partition']}"]) + command_list.extend(["--qos", f"{self.params['qos']}"]) + command_list.extend(["--time", f"{self.params['time']}"]) + command_list.extend(["--nodes", f"{self.params['num_nodes']}"]) + command_list.extend(["--gpus-per-node", f"{self.params['gpus_per_node']}"]) + command_list.extend( + [ + "--output", + f"{self.params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.out", + ] + ) + command_list.extend( + [ + "--error", + f"{self.params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.err", + ] + ) + # Add slurm script + slurm_script = "vllm.slurm" + if int(self.params["num_nodes"]) > 1: + slurm_script = "multinode_vllm.slurm" + command_list.append(f"{SRC_DIR}/{slurm_script}") + return " ".join(command_list) + + def launch(self) -> LaunchResponse: + """Launch the model.""" + # Set environment variables + self._set_env_vars() + + # Build and execute the launch command + command_output, stderr = utils.run_bash_command(self._build_launch_command()) + if stderr: + raise SlurmJobError(f"Error: {stderr}") + + # Extract slurm job id from command output + self.slurm_job_id = command_output.split(" ")[-1].strip().strip("\n") + self.params["slurm_job_id"] = self.slurm_job_id + + # Create log directory and job json file + job_json = Path( + self.params["log_dir"], + f"{self.model_name}.{self.slurm_job_id}", + f"{self.model_name}.{self.slurm_job_id}.json", + ) + job_json.parent.mkdir(parents=True, exist_ok=True) + job_json.touch(exist_ok=True) + + with job_json.open("w") as file: + json.dump(self.params, file, indent=4) + + return LaunchResponse( + slurm_job_id=int(self.slurm_job_id), + model_name=self.model_name, + config=self.params, + raw_output=command_output, + ) + + +class ModelStatusMonitor: + """Class for handling server status information and monitoring.""" + + def __init__(self, slurm_job_id: int, log_dir: Optional[str] = None): + self.slurm_job_id = slurm_job_id + self.output = self._get_raw_status_output() + self.log_dir = log_dir + self.status_info = self._get_base_status_data() + + def _get_raw_status_output(self) -> str: + """Get the raw server status output from slurm.""" + status_cmd = f"scontrol show job {self.slurm_job_id} --oneliner" + output, stderr = utils.run_bash_command(status_cmd) + if stderr: + raise SlurmJobError(f"Error: {stderr}") + return output + + def _get_base_status_data(self) -> StatusResponse: + """Extract basic job status information from scontrol output.""" + try: + job_name = self.output.split(" ")[1].split("=")[1] + job_state = self.output.split(" ")[9].split("=")[1] + except IndexError: + job_name = "UNAVAILABLE" + job_state = ModelStatus.UNAVAILABLE + + return StatusResponse( + model_name=job_name, + server_status=ModelStatus.UNAVAILABLE, + job_state=job_state, + raw_output=self.output, + base_url="UNAVAILABLE", + pending_reason=None, + failed_reason=None, + ) + + def _check_model_health(self) -> None: + """Check model health and update status accordingly.""" + status, status_code = utils.model_health_check( + self.status_info.model_name, self.slurm_job_id, self.log_dir + ) + if status == ModelStatus.READY: + self.status_info.base_url = utils.get_base_url( + self.status_info.model_name, + self.slurm_job_id, + self.log_dir, + ) + self.status_info.server_status = status + else: + self.status_info.server_status = status + self.status_info.failed_reason = cast(str, status_code) + + def _process_running_state(self) -> None: + """Process RUNNING job state and check server status.""" + server_status = utils.is_server_running( + self.status_info.model_name, self.slurm_job_id, self.log_dir + ) + + if isinstance(server_status, tuple): + self.status_info.server_status, self.status_info.failed_reason = ( + server_status + ) + return + + if server_status == "RUNNING": + self._check_model_health() + else: + self.status_info.server_status = cast(ModelStatus, server_status) + + def _process_pending_state(self) -> None: + """Process PENDING job state.""" + try: + self.status_info.pending_reason = self.output.split(" ")[10].split("=")[1] + self.status_info.server_status = ModelStatus.PENDING + except IndexError: + self.status_info.pending_reason = "Unknown pending reason" + + def process_model_status(self) -> StatusResponse: + """Process different job states and update status information.""" + if self.status_info.job_state == ModelStatus.PENDING: + self._process_pending_state() + elif self.status_info.job_state == "RUNNING": + self._process_running_state() + + return self.status_info + + +class PerformanceMetricsCollector: + """Class for handling metrics collection and processing.""" + + def __init__(self, slurm_job_id: int, log_dir: Optional[str] = None): + self.slurm_job_id = slurm_job_id + self.log_dir = log_dir + self.status_info = self._get_status_info() + self.metrics_url = self._build_metrics_url() + self.enabled_prefix_caching = self._check_prefix_caching() + + self._prev_prompt_tokens: float = 0.0 + self._prev_generation_tokens: float = 0.0 + self._last_updated: Optional[float] = None + self._last_throughputs = {"prompt": 0.0, "generation": 0.0} + + def _get_status_info(self) -> StatusResponse: + """Retrieve status info using existing StatusHelper.""" + status_helper = ModelStatusMonitor(self.slurm_job_id, self.log_dir) + return status_helper.process_model_status() + + def _build_metrics_url(self) -> str: + """Construct metrics endpoint URL from base URL with version stripping.""" + if self.status_info.job_state == ModelStatus.PENDING: + return "Pending resources for server initialization" + + base_url = utils.get_base_url( + self.status_info.model_name, + self.slurm_job_id, + self.log_dir, + ) + if not base_url.startswith("http"): + return "Server not ready" + + parsed = urlparse(base_url) + clean_path = parsed.path.replace("/v1", "", 1).rstrip("/") + return urlunparse( + (parsed.scheme, parsed.netloc, f"{clean_path}/metrics", "", "", "") + ) + + def _check_prefix_caching(self) -> bool: + """Check if prefix caching is enabled.""" + job_json = utils.read_slurm_log( + self.status_info.model_name, + self.slurm_job_id, + "json", + self.log_dir, + ) + if isinstance(job_json, str): + return False + return bool(cast(dict[str, str], job_json).get("enable_prefix_caching", False)) + + def _parse_metrics(self, metrics_text: str) -> dict[str, float]: + """Parse metrics with latency count and sum.""" + key_metrics = { + "vllm:prompt_tokens_total": "total_prompt_tokens", + "vllm:generation_tokens_total": "total_generation_tokens", + "vllm:e2e_request_latency_seconds_sum": "request_latency_sum", + "vllm:e2e_request_latency_seconds_count": "request_latency_count", + "vllm:request_queue_time_seconds_sum": "queue_time_sum", + "vllm:request_success_total": "successful_requests_total", + "vllm:num_requests_running": "requests_running", + "vllm:num_requests_waiting": "requests_waiting", + "vllm:num_requests_swapped": "requests_swapped", + "vllm:gpu_cache_usage_perc": "gpu_cache_usage", + "vllm:cpu_cache_usage_perc": "cpu_cache_usage", + } + + if self.enabled_prefix_caching: + key_metrics["vllm:gpu_prefix_cache_hit_rate"] = "gpu_prefix_cache_hit_rate" + key_metrics["vllm:cpu_prefix_cache_hit_rate"] = "cpu_prefix_cache_hit_rate" + + parsed: dict[str, float] = {} + for line in metrics_text.split("\n"): + if line.startswith("#") or not line.strip(): + continue + + parts = line.split() + if len(parts) < 2: + continue + + metric_name = parts[0].split("{")[0] + if metric_name in key_metrics: + try: + parsed[key_metrics[metric_name]] = float(parts[1]) + except (ValueError, IndexError): + continue + return parsed + + def fetch_metrics(self) -> Union[dict[str, float], str]: + """Fetch metrics from the endpoint.""" + try: + response = requests.get(self.metrics_url, timeout=3) + response.raise_for_status() + current_metrics = self._parse_metrics(response.text) + current_time = time.time() + + # Set defaults using last known throughputs + current_metrics.setdefault( + "prompt_tokens_per_sec", self._last_throughputs["prompt"] + ) + current_metrics.setdefault( + "generation_tokens_per_sec", self._last_throughputs["generation"] + ) + + if self._last_updated is None: + self._prev_prompt_tokens = current_metrics.get( + "total_prompt_tokens", 0.0 + ) + self._prev_generation_tokens = current_metrics.get( + "total_generation_tokens", 0.0 + ) + self._last_updated = current_time + return current_metrics + + time_diff = current_time - self._last_updated + if time_diff > 0: + current_prompt = current_metrics.get("total_prompt_tokens", 0.0) + current_gen = current_metrics.get("total_generation_tokens", 0.0) + + delta_prompt = current_prompt - self._prev_prompt_tokens + delta_gen = current_gen - self._prev_generation_tokens + + # Only update throughputs when we have new tokens + prompt_tps = ( + delta_prompt / time_diff + if delta_prompt > 0 + else self._last_throughputs["prompt"] + ) + gen_tps = ( + delta_gen / time_diff + if delta_gen > 0 + else self._last_throughputs["generation"] + ) + + current_metrics["prompt_tokens_per_sec"] = prompt_tps + current_metrics["generation_tokens_per_sec"] = gen_tps + + # Persist calculated values regardless of activity + self._last_throughputs["prompt"] = prompt_tps + self._last_throughputs["generation"] = gen_tps + + # Update tracking state + self._prev_prompt_tokens = current_prompt + self._prev_generation_tokens = current_gen + self._last_updated = current_time + + # Calculate average latency if data is available + if ( + "request_latency_sum" in current_metrics + and "request_latency_count" in current_metrics + ): + latency_sum = current_metrics["request_latency_sum"] + latency_count = current_metrics["request_latency_count"] + current_metrics["avg_request_latency"] = ( + latency_sum / latency_count if latency_count > 0 else 0.0 + ) + + return current_metrics + + except requests.RequestException as e: + return f"Metrics request failed, `metrics` endpoint might not be ready yet: {str(e)}" + + +class ModelRegistry: + """Class for handling model listing and configuration management.""" + + def __init__(self) -> None: + """Initialize the model lister.""" + self.model_configs = utils.load_config() + + def get_all_models(self) -> list[ModelInfo]: + """Get all available models.""" + available_models = [] + for config in self.model_configs: + info = ModelInfo( + name=config.model_name, + family=config.model_family, + variant=config.model_variant, + type=ModelType(config.model_type), + config=config.model_dump(exclude={"model_name", "venv", "log_dir"}), + ) + available_models.append(info) + return available_models + + def get_single_model_config(self, model_name: str) -> ModelConfig: + """Get configuration for a specific model.""" + config = next( + (c for c in self.model_configs if c.model_name == model_name), None + ) + if not config: + raise ModelNotFoundError(f"Model '{model_name}' not found in configuration") + return config diff --git a/vec_inf/client/_models.py b/vec_inf/client/_models.py new file mode 100644 index 00000000..df78d9e5 --- /dev/null +++ b/vec_inf/client/_models.py @@ -0,0 +1,128 @@ +""" +Data models for Vector Inference API. + +This module contains the data model classes used by the Vector Inference API +for both request parameters and response objects. +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Optional, TypedDict, Union + +from typing_extensions import NotRequired + + +class ModelStatus(str, Enum): + """Enum representing the possible status states of a model.""" + + PENDING = "PENDING" + LAUNCHING = "LAUNCHING" + READY = "READY" + FAILED = "FAILED" + SHUTDOWN = "SHUTDOWN" + UNAVAILABLE = "UNAVAILABLE" + + +class ModelType(str, Enum): + """Enum representing the possible model types.""" + + LLM = "LLM" + VLM = "VLM" + TEXT_EMBEDDING = "Text_Embedding" + REWARD_MODELING = "Reward_Modeling" + + +@dataclass +class LaunchResponse: + """Response from launching a model.""" + + slurm_job_id: int + model_name: str + config: dict[str, Any] + raw_output: str = field(repr=False) + + +@dataclass +class StatusResponse: + """Response from checking a model's status.""" + + model_name: str + server_status: ModelStatus + job_state: Union[str, ModelStatus] + raw_output: str = field(repr=False) + base_url: Optional[str] = None + pending_reason: Optional[str] = None + failed_reason: Optional[str] = None + + +@dataclass +class MetricsResponse: + """Response from retrieving model metrics.""" + + model_name: str + metrics: Union[dict[str, float], str] + timestamp: float + + +@dataclass +class LaunchOptions: + """Options for launching a model.""" + + model_family: Optional[str] = None + model_variant: Optional[str] = None + max_model_len: Optional[int] = None + max_num_seqs: Optional[int] = None + gpu_memory_utilization: Optional[float] = None + enable_prefix_caching: Optional[bool] = None + enable_chunked_prefill: Optional[bool] = None + max_num_batched_tokens: Optional[int] = None + partition: Optional[str] = None + num_nodes: Optional[int] = None + gpus_per_node: Optional[int] = None + qos: Optional[str] = None + time: Optional[str] = None + vocab_size: Optional[int] = None + data_type: Optional[str] = None + venv: Optional[str] = None + log_dir: Optional[str] = None + model_weights_parent_dir: Optional[str] = None + pipeline_parallelism: Optional[bool] = None + compilation_config: Optional[str] = None + enforce_eager: Optional[bool] = None + + +class LaunchOptionsDict(TypedDict): + """TypedDict for LaunchOptions.""" + + model_family: NotRequired[Optional[str]] + model_variant: NotRequired[Optional[str]] + max_model_len: NotRequired[Optional[int]] + max_num_seqs: NotRequired[Optional[int]] + gpu_memory_utilization: NotRequired[Optional[float]] + enable_prefix_caching: NotRequired[Optional[bool]] + enable_chunked_prefill: NotRequired[Optional[bool]] + max_num_batched_tokens: NotRequired[Optional[int]] + partition: NotRequired[Optional[str]] + num_nodes: NotRequired[Optional[int]] + gpus_per_node: NotRequired[Optional[int]] + qos: NotRequired[Optional[str]] + time: NotRequired[Optional[str]] + vocab_size: NotRequired[Optional[int]] + data_type: NotRequired[Optional[str]] + venv: NotRequired[Optional[str]] + log_dir: NotRequired[Optional[str]] + model_weights_parent_dir: NotRequired[Optional[str]] + pipeline_parallelism: NotRequired[Optional[bool]] + compilation_config: NotRequired[Optional[str]] + enforce_eager: NotRequired[Optional[bool]] + + +@dataclass +class ModelInfo: + """Information about an available model.""" + + name: str + family: str + variant: Optional[str] + type: ModelType + config: dict[str, Any] diff --git a/vec_inf/client/_utils.py b/vec_inf/client/_utils.py new file mode 100644 index 00000000..25882e0f --- /dev/null +++ b/vec_inf/client/_utils.py @@ -0,0 +1,180 @@ +"""Utility functions shared between CLI and API.""" + +import json +import os +import subprocess +import warnings +from pathlib import Path +from typing import Any, Optional, Union, cast + +import requests +import yaml + +from vec_inf.client._config import ModelConfig +from vec_inf.client._models import ModelStatus +from vec_inf.client._vars import ( + CACHED_CONFIG, + MODEL_READY_SIGNATURE, +) + + +def run_bash_command(command: str) -> tuple[str, str]: + """Run a bash command and return the output.""" + process = subprocess.Popen( + command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) + return process.communicate() + + +def read_slurm_log( + slurm_job_name: str, + slurm_job_id: int, + slurm_log_type: str, + log_dir: Optional[Union[str, Path]], +) -> Union[list[str], str, dict[str, str]]: + """Read the slurm log file.""" + if not log_dir: + # Default log directory + models_dir = Path.home() / ".vec-inf-logs" + # Iterate over all dirs in models_dir, sorted by dir name length in desc order + for directory in sorted( + [d for d in models_dir.iterdir() if d.is_dir()], + key=lambda d: len(d.name), + reverse=True, + ): + if directory.name in slurm_job_name: + log_dir = directory + break + else: + log_dir = Path(log_dir) + + # If log_dir is still not set, then didn't find the log dir at default location + if not log_dir: + return "LOG DIR NOT FOUND" + + try: + file_path = ( + log_dir + / Path(f"{slurm_job_name}.{slurm_job_id}") + / f"{slurm_job_name}.{slurm_job_id}.{slurm_log_type}" + ) + if slurm_log_type == "json": + with file_path.open("r") as file: + json_content: dict[str, str] = json.load(file) + return json_content + else: + with file_path.open("r") as file: + return file.readlines() + except FileNotFoundError: + return f"LOG FILE NOT FOUND: {file_path}" + + +def is_server_running( + slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str] +) -> Union[str, ModelStatus, tuple[ModelStatus, str]]: + """Check if a model is ready to serve requests.""" + log_content = read_slurm_log(slurm_job_name, slurm_job_id, "err", log_dir) + if isinstance(log_content, str): + return log_content + + status: Union[str, tuple[ModelStatus, str]] = ModelStatus.LAUNCHING + + for line in log_content: + if "error" in line.lower(): + status = (ModelStatus.FAILED, line.strip("\n")) + if MODEL_READY_SIGNATURE in line: + status = "RUNNING" + + return status + + +def get_base_url(slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str]) -> str: + """Get the base URL of a model.""" + log_content = read_slurm_log(slurm_job_name, slurm_job_id, "json", log_dir) + if isinstance(log_content, str): + return log_content + + server_addr = cast(dict[str, str], log_content).get("server_address") + return server_addr if server_addr else "URL NOT FOUND" + + +def model_health_check( + slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str] +) -> tuple[ModelStatus, Union[str, int]]: + """Check the health of a running model on the cluster.""" + base_url = get_base_url(slurm_job_name, slurm_job_id, log_dir) + if not base_url.startswith("http"): + return (ModelStatus.FAILED, base_url) + health_check_url = base_url.replace("v1", "health") + + try: + response = requests.get(health_check_url) + # Check if the request was successful + if response.status_code == 200: + return (ModelStatus.READY, response.status_code) + return (ModelStatus.FAILED, response.status_code) + except requests.exceptions.RequestException as e: + return (ModelStatus.FAILED, str(e)) + + +def load_config() -> list[ModelConfig]: + """Load the model configuration.""" + default_path = ( + CACHED_CONFIG + if CACHED_CONFIG.exists() + else Path(__file__).resolve().parent.parent / "config" / "models.yaml" + ) + + config: dict[str, Any] = {} + with open(default_path) as f: + config = yaml.safe_load(f) or {} + + user_path = os.getenv("VEC_INF_CONFIG") + if user_path: + user_path_obj = Path(user_path) + if user_path_obj.exists(): + with open(user_path_obj) as f: + user_config = yaml.safe_load(f) or {} + for name, data in user_config.get("models", {}).items(): + if name in config.get("models", {}): + config["models"][name].update(data) + else: + config.setdefault("models", {})[name] = data + else: + warnings.warn( + f"WARNING: Could not find user config: {user_path}, revert to default config located at {default_path}", + UserWarning, + stacklevel=2, + ) + + return [ + ModelConfig(model_name=name, **model_data) + for name, model_data in config.get("models", {}).items() + ] + + +def parse_launch_output(output: str) -> tuple[str, dict[str, str]]: + """Parse output from model launch command. + + Parameters + ---------- + output: str + Output from the launch command + + Returns + ------- + tuple[str, dict[str, str]] + Slurm job ID and dictionary of config parameters + + """ + slurm_job_id = output.split(" ")[-1].strip().strip("\n") + + # Extract config parameters + config_dict = {} + output_lines = output.split("\n")[:-2] + for line in output_lines: + if ": " in line: + key, value = line.split(": ", 1) + config_dict[key.lower().replace(" ", "_")] = value + + return slurm_job_id, config_dict diff --git a/vec_inf/client/_vars.py b/vec_inf/client/_vars.py new file mode 100644 index 00000000..71e9e221 --- /dev/null +++ b/vec_inf/client/_vars.py @@ -0,0 +1,35 @@ +"""Global variables for the vector inference package.""" + +from pathlib import Path + + +MODEL_READY_SIGNATURE = "INFO: Application startup complete." +CACHED_CONFIG = Path("/", "model-weights", "vec-inf-shared", "models.yaml") +SRC_DIR = str(Path(__file__).parent.parent) +LD_LIBRARY_PATH = "/scratch/ssd001/pkgs/cudnn-11.7-v8.5.0.96/lib/:/scratch/ssd001/pkgs/cuda-11.7/targets/x86_64-linux/lib/" + +# Maps model types to vLLM tasks +VLLM_TASK_MAP = { + "LLM": "generate", + "VLM": "generate", + "TEXT_EMBEDDING": "embed", + "REWARD_MODELING": "reward", +} + +# Required fields for model configuration +REQUIRED_FIELDS = { + "model_family", + "model_type", + "gpus_per_node", + "num_nodes", + "vocab_size", + "max_model_len", +} + +# Boolean fields for model configuration +BOOLEAN_FIELDS = { + "pipeline_parallelism", + "enforce_eager", + "enable_prefix_caching", + "enable_chunked_prefill", +} diff --git a/vec_inf/client/api.py b/vec_inf/client/api.py new file mode 100644 index 00000000..88020d88 --- /dev/null +++ b/vec_inf/client/api.py @@ -0,0 +1,243 @@ +"""Vector Inference client for programmatic access. + +This module provides the main client class for interacting with Vector Inference +services programmatically. +""" + +import time +from typing import Any, Optional, Union + +from vec_inf.client._config import ModelConfig +from vec_inf.client._exceptions import ( + ServerError, + SlurmJobError, +) +from vec_inf.client._helper import ( + ModelLauncher, + ModelRegistry, + ModelStatusMonitor, + PerformanceMetricsCollector, +) +from vec_inf.client._models import ( + LaunchOptions, + LaunchResponse, + MetricsResponse, + ModelInfo, + ModelStatus, + StatusResponse, +) +from vec_inf.client._utils import run_bash_command + + +class VecInfClient: + """Client for interacting with Vector Inference programmatically. + + This class provides methods for launching models, checking their status, + retrieving metrics, and shutting down models using the Vector Inference + infrastructure. + + Examples + -------- + >>> from vec_inf.api import VecInfClient + >>> client = VecInfClient() + >>> response = client.launch_model("Meta-Llama-3.1-8B-Instruct") + >>> job_id = response.slurm_job_id + >>> status = client.get_status(job_id) + >>> if status.status == ModelStatus.READY: + ... print(f"Model is ready at {status.base_url}") + >>> client.shutdown_model(job_id) + + """ + + def __init__(self) -> None: + """Initialize the Vector Inference client.""" + pass + + def list_models(self) -> list[ModelInfo]: + """List all available models. + + Returns + ------- + list[ModelInfo] + ModelInfo objects containing information about available models. + """ + model_registry = ModelRegistry() + return model_registry.get_all_models() + + def get_model_config(self, model_name: str) -> ModelConfig: + """Get the configuration for a specific model. + + Parameters + ---------- + model_name: str + Name of the model to get configuration for. + + Returns + ------- + ModelConfig + Model configuration. + """ + model_registry = ModelRegistry() + return model_registry.get_single_model_config(model_name) + + def launch_model( + self, model_name: str, options: Optional[LaunchOptions] = None + ) -> LaunchResponse: + """Launch a model on the cluster. + + Parameters + ---------- + model_name: str + Name of the model to launch. + options: LaunchOptions, optional + Optional launch options to override default configuration. + + Returns + ------- + LaunchResponse + Information about the launched model. + """ + # Convert LaunchOptions to dictionary if provided + options_dict: dict[str, Any] = {} + if options: + options_dict = {k: v for k, v in vars(options).items() if v is not None} + + # Create and use the API Launch Helper + model_launcher = ModelLauncher(model_name, options_dict) + return model_launcher.launch() + + def get_status( + self, slurm_job_id: int, log_dir: Optional[str] = None + ) -> StatusResponse: + """Get the status of a running model. + + Parameters + ---------- + slurm_job_id: str + The Slurm job ID to check. + log_dir: str, optional + Optional path to the Slurm log directory. + + Returns + ------- + StatusResponse + Model status information. + """ + model_status_monitor = ModelStatusMonitor(slurm_job_id, log_dir) + return model_status_monitor.process_model_status() + + def get_metrics( + self, slurm_job_id: int, log_dir: Optional[str] = None + ) -> MetricsResponse: + """Get the performance metrics of a running model. + + Parameters + ---------- + slurm_job_id : str + The Slurm job ID to get metrics for. + log_dir : str, optional + Optional path to the Slurm log directory. + + Returns + ------- + MetricsResponse + Object containing the model's performance metrics. + """ + performance_metrics_collector = PerformanceMetricsCollector( + slurm_job_id, log_dir + ) + + metrics: Union[dict[str, float], str] + if not performance_metrics_collector.metrics_url.startswith("http"): + metrics = performance_metrics_collector.metrics_url + else: + metrics = performance_metrics_collector.fetch_metrics() + + return MetricsResponse( + model_name=performance_metrics_collector.status_info.model_name, + metrics=metrics, + timestamp=time.time(), + ) + + def shutdown_model(self, slurm_job_id: int) -> bool: + """Shutdown a running model. + + Parameters + ---------- + slurm_job_id: str + The Slurm job ID to shut down. + + Returns + ------- + bool + True if the model was successfully shutdown, False otherwise. + + Raises + ------ + SlurmJobError + If there was an error shutting down the model. + """ + shutdown_cmd = f"scancel {slurm_job_id}" + _, stderr = run_bash_command(shutdown_cmd) + if stderr: + raise SlurmJobError(f"Failed to shutdown model: {stderr}") + return True + + def wait_until_ready( + self, + slurm_job_id: int, + timeout_seconds: int = 1800, + poll_interval_seconds: int = 10, + log_dir: Optional[str] = None, + ) -> StatusResponse: + """Wait until a model is ready or fails. + + Parameters + ---------- + slurm_job_id: str + The Slurm job ID to wait for. + timeout_seconds: int + Maximum time to wait in seconds (default: 30 mins). + poll_interval_seconds: int + How often to check status in seconds (default: 10s). + log_dir: str, optional + Optional path to the Slurm log directory. + + Returns + ------- + StatusResponse + Status, if the model is ready or failed. + + Raises + ------ + SlurmJobError + If the specified job is not found or there's an error with the job. + ServerError + If the server fails to start within the timeout period. + APIError + If there was an error checking the status. + + """ + start_time = time.time() + + while True: + status_info = self.get_status(slurm_job_id, log_dir) + + if status_info.server_status == ModelStatus.READY: + return status_info + + if status_info.server_status == ModelStatus.FAILED: + error_message = status_info.failed_reason or "Unknown error" + raise ServerError(f"Model failed to start: {error_message}") + + if status_info.server_status == ModelStatus.SHUTDOWN: + raise ServerError("Model was shutdown before it became ready") + + # Check timeout + if time.time() - start_time > timeout_seconds: + raise ServerError( + f"Timed out waiting for model to become ready after {timeout_seconds} seconds" + ) + + # Wait before checking again + time.sleep(poll_interval_seconds)