Commit a096c91
[Serving] Support tensor parallel shards override in command line (#2533)
This PR adds support for command-line overrides for model JIT compilation. This is especially helpful for enabling tensor parallelism out of the box, so people do not need to manually tweak `mlc-chat-config.json` to use tensor parallelism.
1 parent 5f71aa9 commit a096c91
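
As an illustration, the command below is taken from the docs/deploy/cli.rst update in this commit; the shard count of 2 is only an example and should match the number of available GPUs:

  # Run the model across 2 GPUs without editing mlc-chat-config.json;
  # the model library is JIT-compiled with the overridden setting.
  mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --overrides "tensor_parallel_shards=2"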

File tree

22 files changed: +347 −135 lines changed
cpp/serve/engine.cc

Lines changed: 47 additions & 20 deletions
@@ -7,6 +7,7 @@
 #include <dlpack/dlpack.h>
 #include <tvm/runtime/logging.h>
+#include <tvm/runtime/memory/memory_manager.h>
 #include <tvm/runtime/module.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
@@ -21,6 +22,7 @@
 #include "../grammar/grammar_state_matcher.h"
 #include "../support/json_parser.h"
 #include "../support/result.h"
+#include "../support/utils.h"
 #include "../tokenizers/tokenizers.h"
 #include "engine_actions/action.h"
 #include "engine_actions/action_commons.h"
@@ -278,22 +280,27 @@ class EngineImpl : public Engine {
     std::vector<std::pair<std::string, std::string>> models_and_model_libs =
         models_and_model_libs_res.Unwrap();
 
-    ICHECK_GE(models_and_model_libs.size(), 1);
+    int num_model = models_and_model_libs.size();
+    ICHECK_GE(num_model, 1);
     // - Initialize singleton states inside the engine.
     n->estate_->Reset();
     n->request_stream_callback_ = std::move(request_stream_callback);
     n->trace_recorder_ = trace_recorder;
     n->device_ = device;
     // - Load model config, create a shared disco session when tensor
     // parallelism is enabled.
+    std::vector<std::string> model_libs;
     std::vector<picojson::object> model_configs;
-    for (int i = 0; i < static_cast<int>(models_and_model_libs.size()); ++i) {
+    model_libs.reserve(num_model);
+    model_configs.reserve(num_model);
+    for (int i = 0; i < num_model; ++i) {
       const auto& [model_str, model_lib] = models_and_model_libs[i];
       Result<picojson::object> model_config_res = Model::LoadModelConfig(model_str);
       if (model_config_res.IsErr()) {
         return TResult::Error("Model " + std::to_string(i) +
                               " has invalid mlc-chat-config.json: " + model_config_res.UnwrapErr());
       }
+      model_libs.push_back(model_lib);
       model_configs.push_back(model_config_res.Unwrap());
     }
 
@@ -303,13 +310,14 @@ class EngineImpl : public Engine {
                                      model_configs[0]);
     }
 
-    Optional<Session> session = n->CreateDiscoSession(model_configs, device);
+    auto [session, num_shards] = n->CreateDiscoSession(model_libs, model_configs, device);
     // - Initialize each model independently.
     n->models_.clear();
-    for (int i = 0; i < static_cast<int>(models_and_model_libs.size()); ++i) {
+    for (int i = 0; i < num_model; ++i) {
       const auto& [model_str, model_lib] = models_and_model_libs[i];
-      Model model = Model::Create(model_lib, model_str, model_configs[i], device, session,
-                                  /*trace_enabled=*/trace_recorder.defined());
+      Model model =
+          Model::Create(model_lib, model_str, model_configs[i], device, session, num_shards,
+                        /*trace_enabled=*/trace_recorder.defined());
       n->models_.push_back(model);
     }
     // - Automatically infer the missing fields in EngineConfig JSON strings
@@ -622,25 +630,44 @@ class EngineImpl : public Engine {
   }
 
   /************** Utility Functions **************/
-  Optional<Session> CreateDiscoSession(const std::vector<picojson::object>& model_configs,
-                                       Device device) {
+  std::pair<Optional<Session>, int> CreateDiscoSession(
+      const std::vector<std::string>& model_libs,
+      const std::vector<picojson::object>& model_configs, Device device) {
     const auto& base_model_config = model_configs[0];
 
-    auto f_get_num_shards = [](const picojson::object& model_config) -> int {
-      constexpr auto kNumShardsKey = "tensor_parallel_shards";
-      if (model_config.count(kNumShardsKey)) {
-        const auto& val = model_config.at(kNumShardsKey);
-        CHECK(val.is<int64_t>());
-        return static_cast<int>(val.get<int64_t>());
+    auto f_get_num_shards = [&device](const std::string& model_lib,
+                                      const picojson::object& model_config) -> int {
+      if (!StartsWith(model_lib, "system://")) {
+        Module executable = tvm::runtime::Module::LoadFromFile(model_lib);
+        PackedFunc fload_exec = executable->GetFunction("vm_load_executable");
+        ICHECK(fload_exec.defined()) << "TVM runtime cannot find vm_load_executable";
+        Module local_vm = fload_exec();
+        local_vm->GetFunction("vm_initialization")(
+            static_cast<int>(device.device_type), device.device_id,
+            static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled),
+            static_cast<int>(kDLCPU), 0,
+            static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled));
+        return ModelMetadata::FromModule(local_vm, std::move(model_config)).tensor_parallel_shards;
       } else {
-        LOG(FATAL) << "Key \"tensor_parallel_shards\" not found.";
+        return 1;
       }
-      throw;
     };
 
-    int num_shards = std::transform_reduce(
-        model_configs.begin(), model_configs.end(), 1, [](int a, int b) { return std::max(a, b); },
-        f_get_num_shards);
+    int num_shards = -1;
+    ICHECK_EQ(model_libs.size(), model_configs.size());
+    for (int i = 0; i < static_cast<int>(model_libs.size()); ++i) {
+      int model_num_shards = f_get_num_shards(model_libs[i], model_configs[i]);
+      if (i == 0) {
+        num_shards = model_num_shards;
+      } else {
+        CHECK_EQ(model_num_shards, num_shards)
+            << "Inconsistent tensor_parallel_shards values across models. Some model is compiled "
+               "with tensor_parallel_shards "
+            << num_shards << " and some other model is compiled with tensor_parallel_shards "
+            << model_num_shards;
+      }
+    }
+
     Optional<Session> session = NullOpt;
     if (num_shards > 1) {
       constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool";
@@ -664,7 +691,7 @@ class EngineImpl : public Engine {
       session = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker");
       session.value()->InitCCL(ccl, ShapeTuple(device_ids));
     }
-    return session;
+    return {session, num_shards};
   }
 
   /************** Debug/Profile **************/

cpp/serve/function_table.cc

Lines changed: 4 additions & 10 deletions
@@ -70,22 +70,14 @@ PackedFunc FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func,
 }
 
 void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config,
-                         Optional<Session> session) {
+                         Optional<Session> session, int num_shards) {
   local_gpu_device = device;
   Device null_device{DLDeviceType(0), 0};
-  int num_shards;
-  {
-    if (model_config.count("tensor_parallel_shards")) {
-      CHECK(model_config["tensor_parallel_shards"].is<int64_t>());
-      num_shards = model_config["tensor_parallel_shards"].get<int64_t>();
-    } else {
-      num_shards = 1;
-    }
-  }
   this->model_config = model_config;
   this->cached_buffers = Map<String, ObjectRef>();
 
   if (num_shards > 1) {
+    ICHECK(session.defined());
     this->sess = session.value();
     this->use_disco = true;
     this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"),
@@ -111,6 +103,7 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object
         ModelMetadata::FromModule(this->disco_mod->DebugGetFromRemote(0), std::move(model_config));
     this->_InitFunctions();
   } else {
+    ICHECK(!session.defined());
     Module executable{nullptr};
     PackedFunc fload_exec{nullptr};
     if (StartsWith(reload_lib_path, "system://")) {
@@ -145,6 +138,7 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object
     this->model_metadata_ = ModelMetadata::FromModule(this->local_vm, std::move(model_config));
     this->_InitFunctions();
   }
+  ICHECK_EQ(this->model_metadata_.tensor_parallel_shards, num_shards);
 }
 
 ObjectRef FunctionTable::LoadParams(const std::string& model_path, Device device) {

cpp/serve/function_table.h

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ struct FunctionTable {
   static PackedFunc SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name);
 
   void Init(String reload_lib_path, Device device, picojson::object model_config,
-            Optional<Session> session);
+            Optional<Session> session, int num_shards);
 
   ObjectRef LoadParams(const std::string& model_path, Device device);

cpp/serve/model.cc

Lines changed: 6 additions & 4 deletions
@@ -27,9 +27,10 @@ class ModelImpl;
 TVM_REGISTER_OBJECT_TYPE(ModelObj);
 
 Model Model::Create(String reload_lib_path, String model_path, const picojson::object& model_config,
-                    DLDevice device, const Optional<Session>& session, bool trace_enabled) {
+                    DLDevice device, const Optional<Session>& session, int num_shards,
+                    bool trace_enabled) {
   return Model(make_object<ModelImpl>(reload_lib_path, model_path, model_config, device, session,
-                                      trace_enabled));
+                                      num_shards, trace_enabled));
 }
 
 Result<picojson::object> Model::LoadModelConfig(const String& model_path) {
@@ -56,14 +57,15 @@ class ModelImpl : public ModelObj {
    * \sa Model::Create
    */
   explicit ModelImpl(String reload_lib_path, String model_path, picojson::object model_config,
-                     DLDevice device, const Optional<Session>& session, bool trace_enabled)
+                     DLDevice device, const Optional<Session>& session, int num_shards,
+                     bool trace_enabled)
       : model_(model_path), device_(device) {
     // Step 1. Process model config json string.
     LoadModelConfigJSON(model_config);
     // Step 2. Initialize vm, we use the packed function mechanism
     // so there is no explicit abi dependency on these extra
    // classes other than basic tvm runtime.
-    this->ft_.Init(reload_lib_path, device_, model_config, session);
+    this->ft_.Init(reload_lib_path, device_, model_config, session, num_shards);
     // Step 3. Reset
     this->Reset();
     // Step 4. Set model type

cpp/serve/model.h

Lines changed: 2 additions & 1 deletion
@@ -368,12 +368,13 @@ class Model : public ObjectRef {
    * \param model_config The model config json object.
    * \param device The device to run the model on.
    * \param session The session to run the model on.
+   * \param num_shards The number of tensor parallel shards of the model.
    * \param trace_enabled A boolean indicating whether tracing is enabled.
    * \return The created runtime module.
    */
   static Model Create(String reload_lib_path, String model_path,
                       const picojson::object& model_config, DLDevice device,
-                      const Optional<Session>& session, bool trace_enabled);
+                      const Optional<Session>& session, int num_shards, bool trace_enabled);
 
   /*!
    * Load the model config from the given model path.

docs/deploy/cli.rst

Lines changed: 51 additions & 64 deletions
@@ -3,102 +3,89 @@
 CLI
 ===============
 
-MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box.
+MLC Chat CLI is the command line tool to run MLC-compiled LLMs out of the box interactively.
 
 .. contents:: Table of Contents
   :local:
  :depth: 2
 
-Option 1. Conda Prebuilt
-~~~~~~~~~~~~~~~~~~~~~~~~
+Install MLC-LLM Package
+------------------------
 
-The prebuilt package supports Metal on macOS and Vulkan on Linux and Windows, and can be installed via Conda one-liner.
+Chat CLI is a part of the MLC-LLM package.
+To use the chat CLI, first install MLC LLM by following the instructions :ref:`here <install-mlc-packages>`.
+Once you have installed the MLC-LLM package, you can run the following command to check if the installation was successful:
 
-To use other GPU runtimes, e.g. CUDA, please instead :ref:`build it from source <mlcchat_build_from_source>`.
+.. code:: bash
 
-.. code:: shell
+  mlc_llm chat --help
 
-    conda activate your-environment
-    python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly
-    mlc_llm chat -h
+You should see the chat CLI help message if the installation was successful.
 
-.. note::
-   The prebuilt package supports **Metal** on macOS and **Vulkan** on Linux and Windows. It is possible to use other GPU runtimes such as **CUDA** by compiling MLCChat CLI from the source.
-
-
-Option 2. Build MLC Runtime from Source
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-We also provide options to build mlc runtime libraries and ``mlc_llm`` from source.
-This step is useful if the prebuilt is unavailable on your platform, or if you would like to build a runtime
-that supports other GPU runtime than the prebuilt version. We can build a customized version
-of mlc chat runtime. You only need to do this if you choose not to use the prebuilt.
-
-First, make sure you install TVM unity (following the instruction in :ref:`install-tvm-unity`).
-Then please follow the instructions in :ref:`mlcchat_build_from_source` to build the necessary libraries.
-
-.. `|` adds a blank line
-
-|
+Quick Start
+------------
 
-Run Models through MLCChat CLI
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+This section provides a quick start guide to work with the MLC-LLM chat CLI.
+To launch the CLI session, run the following command:
 
-Once ``mlc_llm`` is installed, you are able to run any MLC-compiled model on the command line.
+.. code:: bash
 
-To run a model with MLC LLM in any platform, you can either:
+  mlc_llm chat MODEL [--model-lib PATH-TO-MODEL-LIB]
 
-- Use off-the-shelf model prebuilts from the MLC Huggingface repo (see :ref:`Model Prebuilts` for details).
-- Use locally compiled model weights and libraries following :doc:`the model compilation page </compilation/compile_models>`.
+where ``MODEL`` is the model folder after compiling with the :ref:`MLC-LLM build process <compile-model-libraries>`. Information about other arguments can be found in the next section.
 
-**Option 1: Use model prebuilts**
-
-To run ``mlc_llm``, you can specify the Huggingface MLC prebuilt model repo path with the prefix ``HF://``.
-For example, to run the MLC Llama 3 8B Q4F16_1 model (`Repo link <https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC>`_),
-simply use ``HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC``. The model weights and library will be downloaded
-automatically from Huggingface.
-
-.. code:: shell
-
-  mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --device "cuda:0" --overrides context_window_size=1024
+Once the chat CLI is ready, you can enter the prompt to interact with the model.
 
 .. code::
 
   You can use the following special commands:
    /help               print the special commands
    /exit               quit the cli
-    /stats              print out the latest stats (token/sec)
+    /stats              print out stats of last request (token/sec)
+    /metrics            print out full engine metrics
    /reset              restart a fresh chat
    /set [overrides]    override settings in the generation config. For example,
-                        `/set temperature=0.5;max_gen_len=100;stop=end,stop`
+                        `/set temperature=0.5;top_p=0.8;seed=23;max_tokens=100;stop=str1,str2`
                        Note: Separate stop words in the `stop` option with commas (,).
    Multi-line input: Use escape+enter to start a new line.
 
-  user: What's the meaning of life
-  assistant:
-   What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life.
+  >>> What's the meaning of life?
+  The meaning of life is a philosophical and metaphysical question related to the purpose or significance of life or existence in general...
+
+.. note::
+
+  If you want to enable tensor parallelism to run LLMs on multiple GPUs,
+  please specify the argument ``--overrides "tensor_parallel_shards=$NGPU"``.
+  For example,
+
+  .. code:: shell
 
-  The concept of the meaning of life has been debated and...
+    mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --overrides "tensor_parallel_shards=2"
 
 
-**Option 2: Use locally compiled model weights and libraries**
+The ``mlc_llm chat`` Command
+----------------------------
 
-For models other than the prebuilt ones we provided:
+The chat CLI interface is listed below for reference.
 
-1. If the model is a variant to an existing model library (e.g. ``WizardMathV1.1`` and ``OpenHermes`` are variants of ``Mistral``),
-   follow :ref:`convert-weights-via-MLC` to convert the weights and reuse existing model libraries.
-2. Otherwise, follow :ref:`compile-model-libraries` to compile both the model library and weights.
+.. code:: bash
 
-Once you have the model locally compiled with a model library and model weights, to run ``mlc_llm``, simply
+  mlc_llm chat MODEL [--model-lib PATH-TO-MODEL-LIB] [--device DEVICE] [--overrides OVERRIDES]
 
-- Specify the path to ``mlc-chat-config.json`` and the converted model weights to ``--model``
-- Specify the path to the compiled model library (e.g. a .so file) to ``--model-lib``
 
-.. code:: shell
+MODEL                  The model folder after compiling with the MLC-LLM build process. The parameter
+                       can either be the model name with its quantization scheme
+                       (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model
+                       folder. In the former case, we will use the provided name to search
+                       for the model folder over possible paths.
 
-  mlc_llm chat dist/Llama-2-7b-chat-hf-q4f16_1-MLC \
-    --device "cuda:0" --overrides context_window_size=1024 \
-    --model-lib dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-vulkan.so
-    # CUDA on Linux: dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so
-    # Metal on macOS: dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-metal.so
-    # Same rule applies for other platforms
+--model-lib            A field to specify the full path to the model library file to use (e.g. a ``.so`` file).
+--device               The description of the device to run on. The user should provide a string in the
+                       form of ``device_name:device_id`` or ``device_name``, where ``device_name`` is one of
+                       ``cuda``, ``metal``, ``vulkan``, ``rocm``, ``opencl``, ``auto`` (automatically detect the
+                       local device), and ``device_id`` is the device id to run on. The default value is ``auto``,
+                       with the device id set to 0 by default.
+--overrides            Model configuration override. Supports overriding
+                       ``context_window_size``, ``prefill_chunk_size``, ``sliding_window_size``, ``attention_sink_size``,
+                       ``max_batch_size`` and ``tensor_parallel_shards``. The overrides can be explicitly
+                       specified via detailed knobs, e.g. --overrides ``context_window_size=1024;prefill_chunk_size=128``.
