Skip to content

Commit 2d09c83

Browse files
committed
feat: support for canceling the ongoing generation
1 parent b8bdffc commit 2d09c83

2 files changed

Lines changed: 57 additions & 0 deletions

File tree

include/stable-diffusion.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,14 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
414414
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
415415
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);
416416

417+
enum sd_cancel_mode_t {
418+
SD_CANCEL_ALL,
419+
SD_CANCEL_NEW_LATENTS,
420+
SD_CANCEL_RESET
421+
};
422+
423+
SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode);
424+
417425
SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
418426
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out);
419427

src/stable-diffusion.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include "latent-preview.h"
2424
#include "name_conversion.h"
2525

26+
#include <atomic>
27+
2628
const char* model_version_to_str[] = {
2729
"SD 1.x",
2830
"SD 1.x Inpaint",
@@ -106,6 +108,9 @@ static float get_cache_reuse_threshold(const sd_cache_params_t& params) {
106108

107109
/*=============================================== StableDiffusionGGML ================================================*/
108110

111+
static_assert(std::atomic<sd_cancel_mode_t>::is_always_lock_free,
112+
"sd_cancel_mode_t must be lock-free");
113+
109114
class StableDiffusionGGML {
110115
public:
111116
ggml_backend_t backend = nullptr; // general backend
@@ -171,6 +176,20 @@ class StableDiffusionGGML {
171176
ggml_backend_free(backend);
172177
}
173178

179+
std::atomic<sd_cancel_mode_t> cancellation_flag;
180+
181+
void set_cancel_flag(enum sd_cancel_mode_t flag) {
182+
cancellation_flag.store(flag, std::memory_order_release);
183+
}
184+
185+
void reset_cancel_flag() {
186+
set_cancel_flag(SD_CANCEL_RESET);
187+
}
188+
189+
enum sd_cancel_mode_t get_cancel_flag() {
190+
return cancellation_flag.load(std::memory_order_acquire);
191+
}
192+
174193
void init_backend() {
175194
#ifdef SD_USE_CUDA
176195
LOG_DEBUG("Using CUDA backend");
@@ -1646,6 +1665,12 @@ class StableDiffusionGGML {
16461665
SamplePreviewContext preview = prepare_sample_preview_context();
16471666

16481667
auto denoise = [&](const sd::Tensor<float>& x, float sigma, int step) -> sd::Tensor<float> {
1668+
enum sd_cancel_mode_t cancel_flag = get_cancel_flag();
1669+
if (cancel_flag != SD_CANCEL_RESET) {
1670+
LOG_DEBUG("cancelling generation");
1671+
return {};
1672+
}
1673+
16491674
if (step == 1 || step == -1) {
16501675
pretty_progress(0, (int)steps, 0);
16511676
}
@@ -2480,6 +2505,15 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
24802505
free(sd_ctx);
24812506
}
24822507

2508+
SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode) {
2509+
if (sd_ctx && sd_ctx->sd) {
2510+
if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET) {
2511+
mode = SD_CANCEL_ALL;
2512+
}
2513+
sd_ctx->sd->set_cancel_flag(mode);
2514+
}
2515+
}
2516+
24832517
SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx) {
24842518
if (sd_ctx == nullptr || sd_ctx->sd == nullptr) {
24852519
return false;
@@ -3222,6 +3256,10 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx,
32223256
int64_t t0 = ggml_time_ms();
32233257

32243258
for (size_t i = 0; i < final_latents.size(); i++) {
3259+
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
3260+
LOG_ERROR("cancelling latent decodings");
3261+
break;
3262+
}
32253263
int64_t t1 = ggml_time_ms();
32263264
sd::Tensor<float> image = sd_ctx->sd->decode_first_stage(final_latents[i]);
32273265
if (image.empty()) {
@@ -3389,6 +3427,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
33893427
return nullptr;
33903428
}
33913429

3430+
sd_ctx->sd->reset_cancel_flag();
3431+
33923432
int64_t t0 = ggml_time_ms();
33933433
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
33943434
GenerationRequest request(sd_ctx, sd_img_gen_params);
@@ -3424,6 +3464,12 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
34243464
std::vector<sd::Tensor<float>> final_latents;
34253465
int64_t denoise_start = ggml_time_ms();
34263466
for (int b = 0; b < request.batch_count; b++) {
3467+
sd_cancel_mode_t cancel = sd_ctx->sd->get_cancel_flag();
3468+
if (cancel == SD_CANCEL_NEW_LATENTS || cancel == SD_CANCEL_ALL) {
3469+
LOG_ERROR("cancelling generation");
3470+
break;
3471+
}
3472+
34273473
int64_t sampling_start = ggml_time_ms();
34283474
int64_t cur_seed = request.seed + b;
34293475
LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, request.batch_count, cur_seed);
@@ -3876,6 +3922,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
38763922
if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) {
38773923
return nullptr;
38783924
}
3925+
3926+
sd_ctx->sd->reset_cancel_flag();
3927+
38793928
if (num_frames_out != nullptr) {
38803929
*num_frames_out = 0;
38813930
}

0 commit comments

Comments
 (0)