|
23 | 23 | #include "latent-preview.h" |
24 | 24 | #include "name_conversion.h" |
25 | 25 |
|
| 26 | +#include <atomic> |
| 27 | + |
26 | 28 | const char* model_version_to_str[] = { |
27 | 29 | "SD 1.x", |
28 | 30 | "SD 1.x Inpaint", |
@@ -106,6 +108,9 @@ static float get_cache_reuse_threshold(const sd_cache_params_t& params) { |
106 | 108 |
|
107 | 109 | /*=============================================== StableDiffusionGGML ================================================*/ |
108 | 110 |
|
| 111 | +static_assert(std::atomic<sd_cancel_mode_t>::is_always_lock_free, |
| 112 | + "sd_cancel_mode_t must be lock-free"); |
| 113 | + |
109 | 114 | class StableDiffusionGGML { |
110 | 115 | public: |
111 | 116 | ggml_backend_t backend = nullptr; // general backend |
@@ -171,6 +176,20 @@ class StableDiffusionGGML { |
171 | 176 | ggml_backend_free(backend); |
172 | 177 | } |
173 | 178 |
|
| 179 | + std::atomic<sd_cancel_mode_t> cancellation_flag; |
| 180 | + |
| 181 | + void set_cancel_flag(enum sd_cancel_mode_t flag) { |
| 182 | + cancellation_flag.store(flag, std::memory_order_release); |
| 183 | + } |
| 184 | + |
| 185 | + void reset_cancel_flag() { |
| 186 | + set_cancel_flag(SD_CANCEL_RESET); |
| 187 | + } |
| 188 | + |
| 189 | + enum sd_cancel_mode_t get_cancel_flag() { |
| 190 | + return cancellation_flag.load(std::memory_order_acquire); |
| 191 | + } |
| 192 | + |
174 | 193 | void init_backend() { |
175 | 194 | #ifdef SD_USE_CUDA |
176 | 195 | LOG_DEBUG("Using CUDA backend"); |
@@ -1646,6 +1665,12 @@ class StableDiffusionGGML { |
1646 | 1665 | SamplePreviewContext preview = prepare_sample_preview_context(); |
1647 | 1666 |
|
1648 | 1667 | auto denoise = [&](const sd::Tensor<float>& x, float sigma, int step) -> sd::Tensor<float> { |
| 1668 | + enum sd_cancel_mode_t cancel_flag = get_cancel_flag(); |
| 1669 | + if (cancel_flag != SD_CANCEL_RESET) { |
| 1670 | + LOG_DEBUG("cancelling generation"); |
| 1671 | + return {}; |
| 1672 | + } |
| 1673 | + |
1649 | 1674 | if (step == 1 || step == -1) { |
1650 | 1675 | pretty_progress(0, (int)steps, 0); |
1651 | 1676 | } |
@@ -2480,6 +2505,15 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { |
2480 | 2505 | free(sd_ctx); |
2481 | 2506 | } |
2482 | 2507 |
|
| 2508 | +SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode) { |
| 2509 | + if (sd_ctx && sd_ctx->sd) { |
| 2510 | + if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET) { |
| 2511 | + mode = SD_CANCEL_ALL; |
| 2512 | + } |
| 2513 | + sd_ctx->sd->set_cancel_flag(mode); |
| 2514 | + } |
| 2515 | +} |
| 2516 | + |
2483 | 2517 | SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx) { |
2484 | 2518 | if (sd_ctx == nullptr || sd_ctx->sd == nullptr) { |
2485 | 2519 | return false; |
@@ -3222,6 +3256,10 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx, |
3222 | 3256 | int64_t t0 = ggml_time_ms(); |
3223 | 3257 |
|
3224 | 3258 | for (size_t i = 0; i < final_latents.size(); i++) { |
| 3259 | + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { |
| 3260 | + LOG_ERROR("cancelling latent decodings"); |
| 3261 | + break; |
| 3262 | + } |
3225 | 3263 | int64_t t1 = ggml_time_ms(); |
3226 | 3264 | sd::Tensor<float> image = sd_ctx->sd->decode_first_stage(final_latents[i]); |
3227 | 3265 | if (image.empty()) { |
@@ -3389,6 +3427,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s |
3389 | 3427 | return nullptr; |
3390 | 3428 | } |
3391 | 3429 |
|
| 3430 | + sd_ctx->sd->reset_cancel_flag(); |
| 3431 | + |
3392 | 3432 | int64_t t0 = ggml_time_ms(); |
3393 | 3433 | sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params; |
3394 | 3434 | GenerationRequest request(sd_ctx, sd_img_gen_params); |
@@ -3424,6 +3464,12 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s |
3424 | 3464 | std::vector<sd::Tensor<float>> final_latents; |
3425 | 3465 | int64_t denoise_start = ggml_time_ms(); |
3426 | 3466 | for (int b = 0; b < request.batch_count; b++) { |
| 3467 | + sd_cancel_mode_t cancel = sd_ctx->sd->get_cancel_flag(); |
| 3468 | + if (cancel == SD_CANCEL_NEW_LATENTS || cancel == SD_CANCEL_ALL) { |
| 3469 | + LOG_ERROR("cancelling generation"); |
| 3470 | + break; |
| 3471 | + } |
| 3472 | + |
3427 | 3473 | int64_t sampling_start = ggml_time_ms(); |
3428 | 3474 | int64_t cur_seed = request.seed + b; |
3429 | 3475 | LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, request.batch_count, cur_seed); |
@@ -3876,6 +3922,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s |
3876 | 3922 | if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) { |
3877 | 3923 | return nullptr; |
3878 | 3924 | } |
| 3925 | + |
| 3926 | + sd_ctx->sd->reset_cancel_flag(); |
| 3927 | + |
3879 | 3928 | if (num_frames_out != nullptr) { |
3880 | 3929 | *num_frames_out = 0; |
3881 | 3930 | } |
|
0 commit comments