diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 07b59bb8a..c1978f34f 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -167,7 +167,7 @@ class StableDiffusionGGML { for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) { backend = ggml_backend_vk_init(device); } - if(!backend) { + if (!backend) { LOG_WARN("Failed to initialize Vulkan backend"); } #endif @@ -181,7 +181,7 @@ class StableDiffusionGGML { backend = ggml_backend_cpu_init(); } #ifdef SD_USE_FLASH_ATTENTION -#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined (SD_USE_SYCL) || defined(SD_USE_VULKAN) +#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) || defined(SD_USE_VULKAN) LOG_WARN("Flash Attention not supported with GPU Backend"); #else LOG_INFO("Flash Attention enabled"); @@ -762,7 +762,8 @@ class StableDiffusionGGML { sample_method_t method, const std::vector<float>& sigmas, int start_merge_step, - SDCondition id_cond) { + SDCondition id_cond, + size_t batch_num = 0) { size_t steps = sigmas.size() - 1; // noise = load_tensor_from_file(work_ctx, "./rand0.bin"); // print_ggml_tensor(noise); @@ -885,6 +886,9 @@ class StableDiffusionGGML { pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f); // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000); } + + send_result_step_callback(denoised, batch_num, step); + return denoised; }; @@ -998,6 +1002,47 @@ class StableDiffusionGGML { ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x) { return compute_first_stage(work_ctx, x, true); } + + sd_result_cb_t result_cb = nullptr; + void* result_cb_data = nullptr; + + void send_result_callback(ggml_context* work_ctx, ggml_tensor* x, size_t number) { + if (result_cb == nullptr) { + return; + } + + struct ggml_tensor* img = decode_first_stage(work_ctx, x); + auto image_data = sd_tensor_to_image(img); + + result_cb(number, image_data, result_cb_data); + } + + sd_result_step_cb_t 
result_step_cb = nullptr; + void* result_step_cb_data = nullptr; + + void send_result_step_callback(ggml_tensor* x, size_t number, size_t step) { + if (result_step_cb == nullptr) { + return; + } + + struct ggml_init_params params {}; + params.mem_size = static_cast<size_t>(10 * 1024) * 1024; + params.mem_buffer = nullptr; + params.no_alloc = false; + + struct ggml_context* work_ctx = ggml_init(params); + if (!work_ctx) { + return; + } + + struct ggml_tensor* result = ggml_dup_tensor(work_ctx, x); + copy_ggml_tensor(result, x); + + struct ggml_tensor* img = decode_first_stage(work_ctx, result); + result_step_cb(number, step, sd_tensor_to_image(img), result_step_cb_data); + + ggml_free(work_ctx); + } }; /*================================================= SD API ==================================================*/ @@ -1081,6 +1126,16 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { free(sd_ctx); } +void sd_ctx_set_result_callback(sd_ctx_t* sd_ctx, sd_result_cb_t cb, void* data) { + sd_ctx->sd->result_cb = cb; + sd_ctx->sd->result_cb_data = data; +} + +void sd_ctx_set_result_step_callback(sd_ctx_t* sd_ctx, sd_result_step_cb_t cb, void* data) { + sd_ctx->sd->result_step_cb = cb; + sd_ctx->sd->result_step_cb_data = data; +} + sd_image_t* generate_image(sd_ctx_t* sd_ctx, struct ggml_context* work_ctx, ggml_tensor* init_latent, @@ -1308,11 +1363,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, sample_method, sigmas, start_merge_step, - id_cond); + id_cond, + b); // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); // print_ggml_tensor(x_0); int64_t sampling_end = ggml_time_ms(); LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); + + if (sd_ctx->sd->result_cb != nullptr) { + sd_ctx->sd->send_result_callback(work_ctx, x_0, b); + continue; + } + final_latents.push_back(x_0); } @@ -1322,6 +1384,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, int64_t t3 = ggml_time_ms(); LOG_INFO("generating %" PRId64 " latent images 
completed, taking %.2fs", final_latents.size(), (t3 - t1) * 1.0f / 1000); + if (sd_ctx->sd->result_cb != nullptr) { + return nullptr; + } + // Decode to image LOG_INFO("decoding %zu latents", final_latents.size()); std::vector<struct ggml_tensor*> decoded_images; // collect decoded images diff --git a/stable-diffusion.h b/stable-diffusion.h index 0d4cc1fda..d7bd9dc4d 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -107,6 +107,8 @@ enum sd_log_level_t { typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data); typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data); +typedef void (*sd_result_cb_t)(size_t number, uint8_t* image_data, void* data); +typedef void (*sd_result_step_cb_t)(size_t number, size_t step, uint8_t* image_data, void* data); SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data); SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data); @@ -144,6 +146,8 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path, bool keep_vae_on_cpu); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); +SD_API void sd_ctx_set_result_callback(sd_ctx_t* sd_ctx, sd_result_cb_t cb, void* data); +SD_API void sd_ctx_set_result_step_callback(sd_ctx_t* sd_ctx, sd_result_step_cb_t cb, void* data); SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, const char* prompt,