|
2 | 2 | #include <chrono> |
3 | 3 | #include <filesystem> |
4 | 4 | #include <fstream> |
| 5 | +#include <future> |
5 | 6 | #include <iomanip> |
6 | 7 | #include <iostream> |
7 | 8 | #include <mutex> |
@@ -384,6 +385,18 @@ int main(int argc, const char** argv) { |
384 | 385 | return httplib::Server::HandlerResponse::Unhandled; |
385 | 386 | }); |
386 | 387 |
|
| 388 | + auto wait_for_generation = [](std::future<void>& ft, sd_ctx_t* sd_ctx, const httplib::Request& req) { |
| 389 | + std::future_status ft_status; |
| 390 | + do { |
| 391 | + if (!ft.valid()) |
| 392 | + break; |
| 393 | + ft_status = ft.wait_for(std::chrono::milliseconds(1000)); |
| 394 | + if (req.is_connection_closed()) { |
| 395 | + sd_cancel_generation(sd_ctx, SD_CANCEL_ALL); |
| 396 | + } |
| 397 | + } while (ft_status != std::future_status::ready); |
| 398 | + }; |
| 399 | + |
387 | 400 | // index html |
388 | 401 | std::string index_html; |
389 | 402 | #ifdef HAVE_INDEX_HTML |
@@ -532,11 +545,13 @@ int main(int argc, const char** argv) { |
532 | 545 | sd_image_t* results = nullptr; |
533 | 546 | int num_results = 0; |
534 | 547 |
|
535 | | - { |
| 548 | + std::future<void> ft = std::async(std::launch::async, [&]() { |
536 | 549 | std::lock_guard<std::mutex> lock(sd_ctx_mutex); |
537 | 550 | results = generate_image(sd_ctx, &img_gen_params); |
538 | 551 | num_results = gen_params.batch_count; |
539 | | - } |
| 552 | + }); |
| 553 | + |
| 554 | + wait_for_generation(ft, sd_ctx, req); |
540 | 555 |
|
541 | 556 | for (int i = 0; i < num_results; i++) { |
542 | 557 | if (results[i].data == nullptr) { |
@@ -779,11 +794,13 @@ int main(int argc, const char** argv) { |
779 | 794 | sd_image_t* results = nullptr; |
780 | 795 | int num_results = 0; |
781 | 796 |
|
782 | | - { |
| 797 | + std::future<void> ft = std::async(std::launch::async, [&]() { |
783 | 798 | std::lock_guard<std::mutex> lock(sd_ctx_mutex); |
784 | 799 | results = generate_image(sd_ctx, &img_gen_params); |
785 | 800 | num_results = gen_params.batch_count; |
786 | | - } |
| 801 | + }); |
| 802 | + |
| 803 | + wait_for_generation(ft, sd_ctx, req); |
787 | 804 |
|
788 | 805 | json out; |
789 | 806 | out["created"] = static_cast<long long>(std::time(nullptr)); |
@@ -1095,11 +1112,13 @@ int main(int argc, const char** argv) { |
1095 | 1112 | sd_image_t* results = nullptr; |
1096 | 1113 | int num_results = 0; |
1097 | 1114 |
|
1098 | | - { |
| 1115 | + std::future<void> ft = std::async(std::launch::async, [&]() { |
1099 | 1116 | std::lock_guard<std::mutex> lock(sd_ctx_mutex); |
1100 | 1117 | results = generate_image(sd_ctx, &img_gen_params); |
1101 | 1118 | num_results = gen_params.batch_count; |
1102 | | - } |
| 1119 | + }); |
| 1120 | + |
| 1121 | + wait_for_generation(ft, sd_ctx, req); |
1103 | 1122 |
|
1104 | 1123 | json out; |
1105 | 1124 | out["images"] = json::array(); |
|
0 commit comments