Skip to content

Commit 9e4c4f8

Browse files
wbrunadonington
andcommitted
feat(server): cancel current generation on client disconnect
Co-authored-by: donington <jandastroy@gmail.com>
1 parent 2612f62 commit 9e4c4f8

1 file changed

Lines changed: 25 additions & 6 deletions

File tree

examples/server/main.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <chrono>
33
#include <filesystem>
44
#include <fstream>
5+
#include <future>
56
#include <iomanip>
67
#include <iostream>
78
#include <mutex>
@@ -384,6 +385,18 @@ int main(int argc, const char** argv) {
384385
return httplib::Server::HandlerResponse::Unhandled;
385386
});
386387

388+
auto wait_for_generation = [](std::future<void>& ft, sd_ctx_t* sd_ctx, const httplib::Request& req) {
389+
std::future_status ft_status;
390+
do {
391+
if (!ft.valid())
392+
break;
393+
ft_status = ft.wait_for(std::chrono::milliseconds(1000));
394+
if (req.is_connection_closed()) {
395+
sd_cancel_generation(sd_ctx, SD_CANCEL_ALL);
396+
}
397+
} while (ft_status != std::future_status::ready);
398+
};
399+
387400
// index html
388401
std::string index_html;
389402
#ifdef HAVE_INDEX_HTML
@@ -532,11 +545,13 @@ int main(int argc, const char** argv) {
532545
sd_image_t* results = nullptr;
533546
int num_results = 0;
534547

535-
{
548+
std::future<void> ft = std::async(std::launch::async, [&]() {
536549
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
537550
results = generate_image(sd_ctx, &img_gen_params);
538551
num_results = gen_params.batch_count;
539-
}
552+
});
553+
554+
wait_for_generation(ft, sd_ctx, req);
540555

541556
for (int i = 0; i < num_results; i++) {
542557
if (results[i].data == nullptr) {
@@ -779,11 +794,13 @@ int main(int argc, const char** argv) {
779794
sd_image_t* results = nullptr;
780795
int num_results = 0;
781796

782-
{
797+
std::future<void> ft = std::async(std::launch::async, [&]() {
783798
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
784799
results = generate_image(sd_ctx, &img_gen_params);
785800
num_results = gen_params.batch_count;
786-
}
801+
});
802+
803+
wait_for_generation(ft, sd_ctx, req);
787804

788805
json out;
789806
out["created"] = static_cast<long long>(std::time(nullptr));
@@ -1095,11 +1112,13 @@ int main(int argc, const char** argv) {
10951112
sd_image_t* results = nullptr;
10961113
int num_results = 0;
10971114

1098-
{
1115+
std::future<void> ft = std::async(std::launch::async, [&]() {
10991116
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
11001117
results = generate_image(sd_ctx, &img_gen_params);
11011118
num_results = gen_params.batch_count;
1102-
}
1119+
});
1120+
1121+
wait_for_generation(ft, sd_ctx, req);
11031122

11041123
json out;
11051124
out["images"] = json::array();

0 commit comments

Comments
 (0)