13
13
14
14
#include " ../json_ffi/openai_api_protocol.h"
15
15
#include " ../support/json_parser.h"
16
+ #include " ../support/utils.h"
16
17
#include " data.h"
17
18
18
19
namespace mlc {
@@ -62,6 +63,105 @@ picojson::object ResponseFormat::AsJSON() const {
62
63
return config;
63
64
}
64
65
66
+ /* ***************** DisaggConfig ******************/
67
+
68
+ Result<DisaggConfig> DisaggConfig::FromJSON (const picojson::object& config) {
69
+ using TResult = Result<DisaggConfig>;
70
+ DisaggConfig res;
71
+ std::optional<std::string> kind = json::LookupOptional<std::string>(config, " kind" );
72
+ if (kind.has_value ()) {
73
+ if (kind.value () == " prepare_prefill" ) {
74
+ res.kind = DisaggRequestKind::kPreparePrefill ;
75
+ } else if (kind.value () == " remote_prefill" ) {
76
+ res.kind = DisaggRequestKind::kRemotePrefill ;
77
+ } else if (kind.value () == " start_decode" ) {
78
+ res.kind = DisaggRequestKind::kStartDecode ;
79
+ } else {
80
+ return TResult::Error (" Unknown disaggregation request kind " + kind.value ());
81
+ }
82
+ }
83
+ std::optional<std::string> kv_append_metadata_encoded =
84
+ json::LookupOptional<std::string>(config, " kv_append_metadata" );
85
+ if (kv_append_metadata_encoded.has_value ()) {
86
+ picojson::value parse_result;
87
+ std::string err =
88
+ picojson::parse (parse_result, Base64Decode (kv_append_metadata_encoded.value ()));
89
+ if (!err.empty ()) {
90
+ return TResult::Error (" kv_append_metadata parse error: " + err);
91
+ }
92
+ if (!parse_result.is <picojson::array>()) {
93
+ return TResult::Error (" kv_append_metadata is not array of integer." );
94
+ }
95
+ picojson::array kv_append_metadata_arr = parse_result.get <picojson::array>();
96
+ std::vector<IntTuple> kv_append_metadata;
97
+ int ptr = 0 ;
98
+ while (ptr < static_cast <int >(kv_append_metadata_arr.size ())) {
99
+ if (!kv_append_metadata_arr[ptr].is <int64_t >()) {
100
+ return TResult::Error (" Invalid kv append metadata value in kv_append_metadata array" );
101
+ }
102
+ int num_segments = kv_append_metadata_arr[ptr].get <int64_t >();
103
+ if (ptr + num_segments * 2 + 1 > static_cast <int >(kv_append_metadata_arr.size ())) {
104
+ return TResult::Error (" Invalid kv append metadata compression in kv_append_metadata" );
105
+ }
106
+ std::vector<int64_t > compressed_kv_append_metadata{num_segments};
107
+ compressed_kv_append_metadata.reserve (num_segments * 2 + 1 );
108
+ for (int i = 1 ; i <= num_segments * 2 ; ++i) {
109
+ if (!kv_append_metadata_arr[ptr + i].is <int64_t >()) {
110
+ return TResult::Error (" Invalid kv append metadata value in kv_append_metadata array" );
111
+ }
112
+ compressed_kv_append_metadata.push_back (kv_append_metadata_arr[ptr + i].get <int64_t >());
113
+ }
114
+ kv_append_metadata.push_back (IntTuple (std::move (compressed_kv_append_metadata)));
115
+ ptr += num_segments * 2 + 1 ;
116
+ }
117
+ res.kv_append_metadata = std::move (kv_append_metadata);
118
+ }
119
+ res.kv_window_begin = json::LookupOptional<int64_t >(config, " kv_window_begin" );
120
+ res.kv_window_end = json::LookupOptional<int64_t >(config, " kv_window_end" );
121
+ res.dst_group_offset = json::LookupOptional<int64_t >(config, " dst_group_offset" );
122
+ return TResult::Ok (res);
123
+ }
124
+
125
+ picojson::object DisaggConfig::AsJSON () const {
126
+ picojson::object config;
127
+ switch (kind) {
128
+ case DisaggRequestKind::kPreparePrefill : {
129
+ config[" kind" ] = picojson::value (" prepare_prefill" );
130
+ break ;
131
+ }
132
+ case DisaggRequestKind::kRemotePrefill : {
133
+ config[" kind" ] = picojson::value (" remote_prefill" );
134
+ break ;
135
+ }
136
+ case DisaggRequestKind::kStartDecode : {
137
+ config[" kind" ] = picojson::value (" start_decode" );
138
+ break ;
139
+ }
140
+ default :
141
+ break ;
142
+ }
143
+ if (!kv_append_metadata.empty ()) {
144
+ picojson::array kv_append_metadata_arr;
145
+ for (const IntTuple& compressed_kv_append_metadata : kv_append_metadata) {
146
+ for (int64_t value : compressed_kv_append_metadata) {
147
+ kv_append_metadata_arr.push_back (picojson::value (value));
148
+ }
149
+ }
150
+ config[" kv_append_metadata" ] =
151
+ picojson::value (Base64Encode (picojson::value (kv_append_metadata_arr).serialize ()));
152
+ }
153
+ if (kv_window_begin.has_value ()) {
154
+ config[" kv_window_begin" ] = picojson::value (static_cast <int64_t >(kv_window_begin.value ()));
155
+ }
156
+ if (kv_window_end.has_value ()) {
157
+ config[" kv_window_end" ] = picojson::value (static_cast <int64_t >(kv_window_end.value ()));
158
+ }
159
+ if (dst_group_offset.has_value ()) {
160
+ config[" dst_group_offset" ] = picojson::value (static_cast <int64_t >(dst_group_offset.value ()));
161
+ }
162
+ return config;
163
+ }
164
+
65
165
/* ***************** DebugConfig ******************/
66
166
67
167
Result<DebugConfig> DebugConfig::FromJSON (const picojson::object& config) {
@@ -74,7 +174,7 @@ Result<DebugConfig> DebugConfig::FromJSON(const picojson::object& config) {
74
174
if (special_request == " query_engine_metrics" ) {
75
175
res.special_request = SpecialRequestKind::kQueryEngineMetrics ;
76
176
} else {
77
- return TResult::Error (" Uknown special request " + special_request);
177
+ return TResult::Error (" Unknown special request " + special_request);
78
178
}
79
179
}
80
180
std::string grammar_execution_mode =
@@ -84,8 +184,14 @@ Result<DebugConfig> DebugConfig::FromJSON(const picojson::object& config) {
84
184
} else if (grammar_execution_mode == " constraint" ) {
85
185
res.grammar_execution_mode = GrammarExecutionMode::kConstraint ;
86
186
} else {
87
- return TResult::Error (" Uknown grammar execution mode " + grammar_execution_mode);
187
+ return TResult::Error (" Unknown grammar execution mode " + grammar_execution_mode);
88
188
}
189
+ Result<DisaggConfig> disagg_config =
190
+ DisaggConfig::FromJSON (json::Lookup<picojson::object>(config, " disagg_config" ));
191
+ if (disagg_config.IsErr ()) {
192
+ return TResult::Error (disagg_config.UnwrapErr ());
193
+ }
194
+ res.disagg_config = disagg_config.Unwrap ();
89
195
return TResult::Ok (res);
90
196
}
91
197
@@ -114,6 +220,9 @@ picojson::object DebugConfig::AsJSON() const {
114
220
break ;
115
221
}
116
222
}
223
+ if (disagg_config.kind != DisaggRequestKind::kNone ) {
224
+ config[" disagg_config" ] = picojson::value (disagg_config.AsJSON ());
225
+ }
117
226
return config;
118
227
}
119
228
0 commit comments