Skip to content

Commit 6cc2956

Browse files
committed
server : add server_task_type field to server_task_result
This commit adds a server_task_type field, together with a virtual get_server_task_type() accessor, to the server_task_result struct. They are used to identify the type of the server task that produced a result. The motivation for adding this is that it might allow us to avoid using dynamic_casts when checking the type of a server_task_result. For example, the dynamic_cast assertions could then be replaced with checks like this:
```c++
GGML_ASSERT(result.get() != nullptr);
GGML_ASSERT(result.get()->get_server_task_type() == type);
```
1 parent 2d2d076 commit 6cc2956

File tree

1 file changed

+46
-6
lines changed

1 file changed

+46
-6
lines changed

examples/server/server.cpp

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ enum server_task_type {
6868
SERVER_TASK_TYPE_SLOT_RESTORE,
6969
SERVER_TASK_TYPE_SLOT_ERASE,
7070
SERVER_TASK_TYPE_SET_LORA,
71+
SERVER_TASK_TYPE_NONE,
7172
};
7273

7374
enum oaicompat_type {
@@ -480,6 +481,7 @@ struct result_timings {
480481
struct server_task_result {
481482
int id = -1;
482483
int id_slot = -1;
484+
server_task_type type;
483485
virtual bool is_error() {
484486
// only used by server_task_result_error
485487
return false;
@@ -491,6 +493,7 @@ struct server_task_result {
491493
virtual int get_index() {
492494
return -1;
493495
}
496+
virtual server_task_type get_server_task_type() = 0;
494497
virtual json to_json() = 0;
495498
virtual ~server_task_result() = default;
496499
};
@@ -794,6 +797,10 @@ struct server_task_result_cmpl_final : server_task_result {
794797

795798
return ret;
796799
}
800+
801+
server_task_type get_server_task_type() {
802+
return SERVER_TASK_TYPE_COMPLETION;
803+
}
797804
};
798805

799806
struct server_task_result_cmpl_partial : server_task_result {
@@ -962,6 +969,10 @@ struct server_task_result_cmpl_partial : server_task_result {
962969

963970
return std::vector<json>({ret});
964971
}
972+
973+
server_task_type get_server_task_type() {
974+
return SERVER_TASK_TYPE_NONE;
975+
}
965976
};
966977

967978
struct server_task_result_embd : server_task_result {
@@ -997,6 +1008,10 @@ struct server_task_result_embd : server_task_result {
9971008
{"tokens_evaluated", n_tokens},
9981009
};
9991010
}
1011+
1012+
server_task_type get_server_task_type() {
1013+
return SERVER_TASK_TYPE_EMBEDDING;
1014+
}
10001015
};
10011016

10021017
struct server_task_result_rerank : server_task_result {
@@ -1016,6 +1031,10 @@ struct server_task_result_rerank : server_task_result {
10161031
{"tokens_evaluated", n_tokens},
10171032
};
10181033
}
1034+
1035+
server_task_type get_server_task_type() {
1036+
return SERVER_TASK_TYPE_RERANK;
1037+
}
10191038
};
10201039

10211040
// this function maybe used outside of server_task_result_error
@@ -1071,6 +1090,10 @@ struct server_task_result_error : server_task_result {
10711090
virtual json to_json() override {
10721091
return format_error_response(err_msg, err_type);
10731092
}
1093+
1094+
server_task_type get_server_task_type() {
1095+
return SERVER_TASK_TYPE_NONE;
1096+
}
10741097
};
10751098

10761099
struct server_task_result_metrics : server_task_result {
@@ -1127,6 +1150,10 @@ struct server_task_result_metrics : server_task_result {
11271150
{ "slots", slots_data },
11281151
};
11291152
}
1153+
1154+
server_task_type get_server_task_type() {
1155+
return SERVER_TASK_TYPE_METRICS;
1156+
}
11301157
};
11311158

11321159
struct server_task_result_slot_save_load : server_task_result {
@@ -1160,6 +1187,10 @@ struct server_task_result_slot_save_load : server_task_result {
11601187
};
11611188
}
11621189
}
1190+
1191+
server_task_type get_server_task_type() {
1192+
return SERVER_TASK_TYPE_SLOT_SAVE;
1193+
}
11631194
};
11641195

11651196
struct server_task_result_slot_erase : server_task_result {
@@ -1171,12 +1202,20 @@ struct server_task_result_slot_erase : server_task_result {
11711202
{ "n_erased", n_erased },
11721203
};
11731204
}
1205+
1206+
server_task_type get_server_task_type() {
1207+
return SERVER_TASK_TYPE_SLOT_ERASE;
1208+
}
11741209
};
11751210

11761211
struct server_task_result_apply_lora : server_task_result {
11771212
virtual json to_json() override {
11781213
return json {{ "success", true }};
11791214
}
1215+
1216+
server_task_type get_server_task_type() {
1217+
return SERVER_TASK_TYPE_NONE;
1218+
}
11801219
};
11811220

11821221
struct server_slot {
@@ -2751,6 +2790,11 @@ struct server_context {
27512790
res->id = task.id;
27522791
queue_results.send(std::move(res));
27532792
} break;
2793+
case SERVER_TASK_TYPE_NONE:
2794+
{
2795+
// do nothing
2796+
GGML_ASSERT(false && "Invalid task.type (SERVER_TASK_TYPE_NONE)\n");
2797+
} break;
27542798
}
27552799
}
27562800

@@ -3693,12 +3737,8 @@ int main(int argc, char ** argv) {
36933737
res_error(res, result->to_json());
36943738
return;
36953739
}
3696-
3697-
if (type == SERVER_TASK_TYPE_SLOT_SAVE) {
3698-
GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
3699-
} else if (type == SERVER_TASK_TYPE_SLOT_ERASE) {
3700-
GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
3701-
}
3740+
GGML_ASSERT(result.get() != nullptr);
3741+
GGML_ASSERT(result.get()->get_server_task_type() == type);
37023742
res_ok(res, result->to_json());
37033743
};
37043744

0 commit comments

Comments
 (0)