Skip to content

Commit 5c9ebcb

Browse files
jinhongyiiMasterJH5574CharlieFRuanyingchen21
authored
[Serve] MicroServing Implementation (#3064)
This PR introduces MicroServing API. MicroServing introduces simple yet effective REST APIs to support fine-grained sub-request level actions. A programmable router transforms user requests into sub-request calls, lifting fine-grained scheduling to the API level, thus enabling the dynamic reconfiguration of different orchestration patterns. Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Charlie Ruan <[email protected]> Co-authored-by: Yingcheng Wang <[email protected]>
1 parent 49dcd4a commit 5c9ebcb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+2497
-198
lines changed

cpp/metadata/model.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata,
8989
result.tensor_parallel_shards = json::Lookup<int64_t>(metadata, "tensor_parallel_shards");
9090
result.pipeline_parallel_stages =
9191
json::LookupOrDefault<int64_t>(metadata, "pipeline_parallel_stages", 1);
92+
result.disaggregation = json::LookupOrDefault<bool>(metadata, "disaggregation", false);
9293
result.kv_state_kind = KVStateKindFromString(
9394
json::LookupOrDefault<std::string>(metadata, "kv_state_kind", "kv_cache"));
9495
if (result.kv_state_kind != KVStateKind::kNone &&

cpp/metadata/model.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ struct ModelMetadata {
8080
int64_t sliding_window_size;
8181
int64_t tensor_parallel_shards;
8282
int64_t pipeline_parallel_stages;
83+
bool disaggregation;
8384
int64_t attention_sink_size;
8485
std::vector<Param> params;
8586
std::unordered_map<std::string, int64_t> memory_usage;

cpp/serve/config.cc

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "../json_ffi/openai_api_protocol.h"
1515
#include "../support/json_parser.h"
16+
#include "../support/utils.h"
1617
#include "data.h"
1718

1819
namespace mlc {
@@ -62,6 +63,105 @@ picojson::object ResponseFormat::AsJSON() const {
6263
return config;
6364
}
6465

66+
/****************** DisaggConfig ******************/
67+
68+
/*!
 * \brief Create a disaggregation config from its JSON object form.
 *
 * Recognized keys (all optional):
 * - "kind": one of "prepare_prefill", "remote_prefill" or "start_decode".
 * - "kv_append_metadata": a base64-encoded JSON array of integers, laid out as
 *   consecutive groups `[num_segments, v_1, ..., v_{2 * num_segments}]`.
 *   Each group becomes one IntTuple (including its leading `num_segments`).
 * - "kv_window_begin", "kv_window_end", "dst_group_offset": integers.
 *
 * \param config The JSON object to parse.
 * \return The parsed config, or an error when the kind is unknown or the
 * "kv_append_metadata" payload is malformed.
 */
Result<DisaggConfig> DisaggConfig::FromJSON(const picojson::object& config) {
  using TResult = Result<DisaggConfig>;
  DisaggConfig res;
  std::optional<std::string> kind = json::LookupOptional<std::string>(config, "kind");
  if (kind.has_value()) {
    if (kind.value() == "prepare_prefill") {
      res.kind = DisaggRequestKind::kPreparePrefill;
    } else if (kind.value() == "remote_prefill") {
      res.kind = DisaggRequestKind::kRemotePrefill;
    } else if (kind.value() == "start_decode") {
      res.kind = DisaggRequestKind::kStartDecode;
    } else {
      return TResult::Error("Unknown disaggregation request kind " + kind.value());
    }
  }
  std::optional<std::string> kv_append_metadata_encoded =
      json::LookupOptional<std::string>(config, "kv_append_metadata");
  if (kv_append_metadata_encoded.has_value()) {
    picojson::value parse_result;
    std::string err =
        picojson::parse(parse_result, Base64Decode(kv_append_metadata_encoded.value()));
    if (!err.empty()) {
      return TResult::Error("kv_append_metadata parse error: " + err);
    }
    if (!parse_result.is<picojson::array>()) {
      return TResult::Error("kv_append_metadata is not array of integer.");
    }
    const picojson::array& kv_append_metadata_arr = parse_result.get<picojson::array>();
    // All index arithmetic is done in 64 bits: the counts come straight from
    // the request payload and must not be truncated or allowed to overflow.
    const int64_t arr_size = static_cast<int64_t>(kv_append_metadata_arr.size());
    std::vector<IntTuple> kv_append_metadata;
    int64_t ptr = 0;
    while (ptr < arr_size) {
      if (!kv_append_metadata_arr[ptr].is<int64_t>()) {
        return TResult::Error("Invalid kv append metadata value in kv_append_metadata array");
      }
      const int64_t num_segments = kv_append_metadata_arr[ptr].get<int64_t>();
      // A negative count would move `ptr` backwards (out-of-bounds read on the
      // next iteration). The upper bound is the overflow-free form of
      // `ptr + num_segments * 2 + 1 > arr_size`.
      if (num_segments < 0 || num_segments > (arr_size - ptr - 1) / 2) {
        return TResult::Error("Invalid kv append metadata compression in kv_append_metadata");
      }
      const int64_t group_size = num_segments * 2 + 1;
      std::vector<int64_t> compressed_kv_append_metadata{num_segments};
      compressed_kv_append_metadata.reserve(static_cast<size_t>(group_size));
      for (int64_t i = 1; i < group_size; ++i) {
        if (!kv_append_metadata_arr[ptr + i].is<int64_t>()) {
          return TResult::Error("Invalid kv append metadata value in kv_append_metadata array");
        }
        compressed_kv_append_metadata.push_back(kv_append_metadata_arr[ptr + i].get<int64_t>());
      }
      kv_append_metadata.push_back(IntTuple(std::move(compressed_kv_append_metadata)));
      ptr += group_size;
    }
    res.kv_append_metadata = std::move(kv_append_metadata);
  }
  res.kv_window_begin = json::LookupOptional<int64_t>(config, "kv_window_begin");
  res.kv_window_end = json::LookupOptional<int64_t>(config, "kv_window_end");
  res.dst_group_offset = json::LookupOptional<int64_t>(config, "dst_group_offset");
  return TResult::Ok(res);
}
124+
125+
/*!
 * \brief Serialize this disaggregation config into a JSON object.
 *
 * Only fields that carry information are emitted: "kind" is written for the
 * three named request kinds, "kv_append_metadata" when the list is non-empty
 * (flattened into one integer array, JSON-serialized and base64-encoded), and
 * each optional integer field when it holds a value.
 *
 * \return The JSON object form of this config.
 */
picojson::object DisaggConfig::AsJSON() const {
  picojson::object json_obj;

  // Resolve the string name of the request kind. Any other value (including
  // kNone) leaves the name unset, so no "kind" entry is written.
  const char* kind_name = nullptr;
  if (kind == DisaggRequestKind::kPreparePrefill) {
    kind_name = "prepare_prefill";
  } else if (kind == DisaggRequestKind::kRemotePrefill) {
    kind_name = "remote_prefill";
  } else if (kind == DisaggRequestKind::kStartDecode) {
    kind_name = "start_decode";
  }
  if (kind_name != nullptr) {
    json_obj["kind"] = picojson::value(kind_name);
  }

  if (!kv_append_metadata.empty()) {
    // Concatenate every tuple into a single flat integer array.
    picojson::array flattened;
    for (const IntTuple& group : kv_append_metadata) {
      for (int64_t elem : group) {
        flattened.push_back(picojson::value(elem));
      }
    }
    const std::string serialized = picojson::value(flattened).serialize();
    json_obj["kv_append_metadata"] = picojson::value(Base64Encode(serialized));
  }

  if (kv_window_begin) {
    json_obj["kv_window_begin"] = picojson::value(static_cast<int64_t>(*kv_window_begin));
  }
  if (kv_window_end) {
    json_obj["kv_window_end"] = picojson::value(static_cast<int64_t>(*kv_window_end));
  }
  if (dst_group_offset) {
    json_obj["dst_group_offset"] = picojson::value(static_cast<int64_t>(*dst_group_offset));
  }
  return json_obj;
}
164+
65165
/****************** DebugConfig ******************/
66166

67167
Result<DebugConfig> DebugConfig::FromJSON(const picojson::object& config) {
@@ -74,7 +174,7 @@ Result<DebugConfig> DebugConfig::FromJSON(const picojson::object& config) {
74174
if (special_request == "query_engine_metrics") {
75175
res.special_request = SpecialRequestKind::kQueryEngineMetrics;
76176
} else {
77-
return TResult::Error("Uknown special request " + special_request);
177+
return TResult::Error("Unknown special request " + special_request);
78178
}
79179
}
80180
std::string grammar_execution_mode =
@@ -84,8 +184,14 @@ Result<DebugConfig> DebugConfig::FromJSON(const picojson::object& config) {
84184
} else if (grammar_execution_mode == "constraint") {
85185
res.grammar_execution_mode = GrammarExecutionMode::kConstraint;
86186
} else {
87-
return TResult::Error("Uknown grammar execution mode " + grammar_execution_mode);
187+
return TResult::Error("Unknown grammar execution mode " + grammar_execution_mode);
88188
}
189+
Result<DisaggConfig> disagg_config =
190+
DisaggConfig::FromJSON(json::Lookup<picojson::object>(config, "disagg_config"));
191+
if (disagg_config.IsErr()) {
192+
return TResult::Error(disagg_config.UnwrapErr());
193+
}
194+
res.disagg_config = disagg_config.Unwrap();
89195
return TResult::Ok(res);
90196
}
91197

@@ -114,6 +220,9 @@ picojson::object DebugConfig::AsJSON() const {
114220
break;
115221
}
116222
}
223+
if (disagg_config.kind != DisaggRequestKind::kNone) {
224+
config["disagg_config"] = picojson::value(disagg_config.AsJSON());
225+
}
117226
return config;
118227
}
119228

cpp/serve/config.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ enum class SpecialRequestKind : int {
4646
kQueryEngineMetrics = 1,
4747
};
4848

49+
/*!
 * \brief The kind of a disaggregation request.
 * The JSON spellings accepted by DisaggConfig::FromJSON are "prepare_prefill",
 * "remote_prefill" and "start_decode".
 */
enum class DisaggRequestKind : int {
50+
/*! \brief Not a disaggregation request; the default of DisaggConfig::kind. */
kNone = 0,
51+
/*! \brief Serialized as "prepare_prefill". */
kPreparePrefill = 1,
52+
/*! \brief Serialized as "remote_prefill". */
kRemotePrefill = 2,
53+
/*! \brief Serialized as "start_decode". */
kStartDecode = 3,
54+
};
55+
4956
/*! \brief Controls the behavior of inference with grammar constraint. */
5057
enum class GrammarExecutionMode : int {
5158
/*! \brief If grammar is provided for a request, use the grammar to constrain the output token. */
@@ -55,6 +62,28 @@ enum class GrammarExecutionMode : int {
5562
kJumpForward = 1,
5663
};
5764

65+
/*! \brief The config for disaggregation requests. */
66+
class DisaggConfig {
67+
public:
68+
/*! \brief The kind of the disaggregation request; kNone when the request is not one. */
DisaggRequestKind kind = DisaggRequestKind::kNone;
69+
/*!
 * \brief The KV append metadata. FromJSON stores one tuple per group, shaped
 * as `[num_segments, v_1, ..., v_{2 * num_segments}]`.
 * NOTE(review): the meaning of the values after `num_segments` is not visible
 * in this file -- confirm against the producer of this field.
 */
std::vector<IntTuple> kv_append_metadata;
70+
// "kv_window_begin" and "kv_window_end" denote the KV interval of interest.
71+
// "kv_window_end" supports Python-style negative indexing.
72+
// The concrete meaning varies with the disaggregation request kind:
73+
// - For "prepare_prefill", the begin is always 0, and "[0:end]" denotes
74+
// the KV range to prefill on a prefill instance.
75+
// - For "remote_prefill", "[begin:end]" means the KV range to compute prefill
76+
// and send to the decode instance.
77+
// - For "start_decode", the end is always nullopt, and "[begin:]" denotes
78+
// the KV range to prefill locally on the decode instance.
79+
// NOTE(review): FromJSON reads the three fields below as int64_t and stores
// them into std::optional<int>, so values outside the int range are narrowed.
std::optional<int> kv_window_begin = std::nullopt;
80+
std::optional<int> kv_window_end = std::nullopt;
81+
// NOTE(review): semantics of "dst_group_offset" are not visible here -- confirm with its users.
std::optional<int> dst_group_offset = std::nullopt;
82+
83+
/*!
 * \brief Create a disaggregation config from JSON.
 * \param config_json The JSON object holding the config fields.
 * \return The parsed config, or an error message on invalid input.
 */
static Result<DisaggConfig> FromJSON(const picojson::object& config_json);
84+
/*! \brief Convert the config to its JSON object form. */
picojson::object AsJSON() const;
85+
};
86+
5887
/*! \brief The debug configuration of a request. */
5988
class DebugConfig {
6089
public:
@@ -63,6 +92,7 @@ class DebugConfig {
6392
SpecialRequestKind special_request = SpecialRequestKind::kNone;
6493
/*! \brief The grammar execution mode. */
6594
GrammarExecutionMode grammar_execution_mode = GrammarExecutionMode::kJumpForward;
95+
DisaggConfig disagg_config;
6696

6797
/*!
6898
* \brief Create debug config from JSON.

cpp/serve/data.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,45 @@ namespace serve {
1616

1717
TVM_REGISTER_OBJECT_TYPE(DataNode);
1818

19+
/*!
 * \brief Split the given data array into two arrays at the "split_pos" position.
 *
 * Data entries are moved from the back of the left-hand side to the
 * right-hand side until the left-hand side has length `split_pos`. When the
 * split point falls inside an entry, that entry must be a TokenData and is cut
 * into two TokenData pieces.
 *
 * \param original_data The data array to split.
 * \param total_length The total length of `original_data`, as provided by the
 * caller. Must be no less than `split_pos`.
 * \param split_pos The length of the left-hand side after the split. Must be
 * non-negative.
 * \return The pair `(lhs, rhs)`, both in the original data order.
 */
std::pair<Array<Data>, Array<Data>> SplitData(const Array<Data>& original_data, int total_length,
                                              int split_pos) {
  CHECK_GE(split_pos, 0);
  CHECK_GE(total_length, split_pos)
      << "Cannot truncate when the current length is already less than the target length";
  std::vector<Data> lhs(original_data.begin(), original_data.end());
  std::vector<Data> rhs;
  while (total_length > split_pos) {
    ICHECK(!lhs.empty());
    Data last_data = lhs.back();
    int last_data_length = last_data->GetLength();
    ICHECK_GE(total_length - last_data_length, 0);
    if (total_length - last_data_length >= split_pos) {
      // Pop the entire last data.
      rhs.push_back(lhs.back());
      lhs.pop_back();
      total_length -= last_data_length;
      continue;
    }
    // Partially truncate the last data.
    const auto* token_data = last_data.as<TokenDataNode>();
    CHECK(token_data != nullptr) << "Only TokenData supports partial truncation.";
    int length_to_truncate = total_length - split_pos;
    CHECK_GT(length_to_truncate, 0);
    CHECK_LT(length_to_truncate, last_data_length);
    TokenData lhs_token_data(
        IntTuple{token_data->token_ids.begin(), token_data->token_ids.end() - length_to_truncate});
    TokenData rhs_token_data(
        IntTuple{token_data->token_ids.end() - length_to_truncate, token_data->token_ids.end()});
    CHECK_EQ(total_length - last_data_length + lhs_token_data->GetLength(), split_pos);
    lhs.pop_back();
    lhs.push_back(lhs_token_data);
    rhs.push_back(rhs_token_data);
    total_length = split_pos;
  }
  // "rhs" was collected back-to-front. Restore the original order here, after
  // the loop, so that it also happens when the split lands exactly on a data
  // boundary and the partial-truncation path above is never taken.
  std::reverse(rhs.begin(), rhs.end());
  return {lhs, rhs};
}
57+
1958
/****************** TextData ******************/
2059

2160
TVM_REGISTER_OBJECT_TYPE(TextDataNode);

cpp/serve/data.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ class Data : public ObjectRef {
5757
TVM_DEFINE_OBJECT_REF_METHODS(Data, ObjectRef, DataNode);
5858
};
5959

60+
/*!
 * \brief Split the given data array into two arrays at the "split_pos" position.
 * \param original_data The data array to split.
 * \param total_length The total length of the input data, supplied by the caller
 * (presumably the sum of the lengths of all entries -- confirm at call sites).
 * It is required to be no less than "split_pos".
 * \param split_pos The non-negative length of the first returned array.
 * \return A pair of arrays: the data covering the first "split_pos" positions,
 * and the remaining data.
 * \note Only TokenData can be cut in the middle; a split point that falls
 * inside any other kind of data fails a check.
 */
61+
std::pair<Array<Data>, Array<Data>> SplitData(const Array<Data>& original_data, int total_length,
62+
int split_pos);
63+
6064
/****************** TextDataNode ******************/
6165

6266
/*! \brief The class of text data, containing a text string. */

0 commit comments

Comments
 (0)