Skip to content

Commit 4ec0e0f

Browse files
committed
Now accepts multiple reference images
1 parent 2e14338 commit 4ec0e0f

File tree

6 files changed

+128
-76
lines changed

6 files changed

+128
-76
lines changed

expose.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ struct sd_generation_inputs
181181
const char * negative_prompt = nullptr;
182182
const char * init_images = "";
183183
const char * mask = "";
184-
const char * extra_image = "";
184+
const int extra_images_len = 0;
185+
const char ** extra_images = nullptr;
185186
const bool flip_mask = false;
186187
const float denoising_strength = 0.0f;
187188
const float cfg_scale = 0.0f;

kcpp_sdui.embd

Lines changed: 16 additions & 16 deletions
Large diffs are not rendered by default.

koboldcpp.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
ban_token_max = 768
6060
logit_bias_max = 512
6161
dry_seq_break_max = 128
62+
extra_images_max = 4
6263

6364
# global vars
6465
KcppVersion = "1.94.2"
@@ -291,7 +292,8 @@ class sd_generation_inputs(ctypes.Structure):
291292
("negative_prompt", ctypes.c_char_p),
292293
("init_images", ctypes.c_char_p),
293294
("mask", ctypes.c_char_p),
294-
("extra_image", ctypes.c_char_p),
295+
("extra_images_len", ctypes.c_int),
296+
("extra_images", ctypes.POINTER(ctypes.c_char_p)),
295297
("flip_mask", ctypes.c_bool),
296298
("denoising_strength", ctypes.c_float),
297299
("cfg_scale", ctypes.c_float),
@@ -1714,7 +1716,9 @@ def sd_generate(genparams):
17141716
seed = random.randint(100000, 999999)
17151717
sample_method = genparams.get("sampler_name", "k_euler_a")
17161718
clip_skip = tryparseint(genparams.get("clip_skip", -1),-1)
1717-
extra_image = strip_base64_prefix(genparams.get("extra_image", ""))
1719+
extra_images_arr = genparams.get("extra_images", [])
1720+
extra_images_arr = ([] if not extra_images_arr else extra_images_arr)
1721+
extra_images_arr = extra_images_arr[:extra_images_max]
17181722

17191723
#clean vars
17201724
cfg_scale = (1 if cfg_scale < 1 else (25 if cfg_scale > 25 else cfg_scale))
@@ -1728,7 +1732,11 @@ def sd_generate(genparams):
17281732
inputs.negative_prompt = negative_prompt.encode("UTF-8")
17291733
inputs.init_images = init_images.encode("UTF-8")
17301734
inputs.mask = "".encode("UTF-8") if not mask else mask.encode("UTF-8")
1731-
inputs.extra_image = "".encode("UTF-8") if not extra_image else extra_image.encode("UTF-8")
1735+
inputs.extra_images_len = len(extra_images_arr)
1736+
inputs.extra_images = (ctypes.c_char_p * inputs.extra_images_len)()
1737+
for n, estr in enumerate(extra_images_arr):
1738+
extra_image = strip_base64_prefix(estr)
1739+
inputs.extra_images[n] = extra_image.encode("UTF-8")
17321740
inputs.flip_mask = flip_mask
17331741
inputs.cfg_scale = cfg_scale
17341742
inputs.denoising_strength = denoising_strength

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 69 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ static int sddebugmode = 0;
116116
static std::string recent_data = "";
117117
static uint8_t * input_image_buffer = NULL;
118118
static uint8_t * input_mask_buffer = NULL;
119-
static uint8_t * input_extraimage_buffer = NULL;
119+
static std::vector<uint8_t *> input_extraimage_buffers;
120+
const int max_extra_images = 4;
120121

121122
static std::string sdplatformenv, sddeviceenv, sdvulkandeviceenv;
122123
static int cfg_tiled_vae_threshold = 0;
@@ -288,8 +289,9 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
288289
sd_ctx->sd->apply_lora_from_file(lorafilename,inputs.lora_multiplier);
289290
}
290291

291-
return true;
292+
input_extraimage_buffers.reserve(max_extra_images);
292293

294+
return true;
293295
}
294296

295297
std::string clean_input_prompt(const std::string& input) {
@@ -434,7 +436,12 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
434436
std::string cleannegprompt = clean_input_prompt(inputs.negative_prompt);
435437
std::string img2img_data = std::string(inputs.init_images);
436438
std::string img2img_mask = std::string(inputs.mask);
437-
std::string extra_image_data = std::string(inputs.extra_image);
439+
std::vector<std::string> extra_image_data;
440+
for(int i=0;i<inputs.extra_images_len;++i)
441+
{
442+
extra_image_data.push_back(std::string(inputs.extra_images[i]));
443+
}
444+
438445
std::string sampler = inputs.sample_method;
439446

440447
sd_params->prompt = cleanprompt;
@@ -503,17 +510,20 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
503510

504511
//for img2img
505512
sd_image_t input_image = {0,0,0,nullptr};
506-
sd_image_t extraimage_reference = {0,0,0,nullptr};
513+
std::vector<sd_image_t> extraimage_references;
514+
extraimage_references.reserve(max_extra_images);
507515
std::vector<uint8_t> image_buffer;
508516
std::vector<uint8_t> image_mask_buffer;
509-
std::vector<uint8_t> extraimage_buffer;
517+
std::vector<std::vector<uint8_t>> extraimage_buffers;
518+
extraimage_buffers.reserve(max_extra_images);
519+
510520
int nx, ny, nc;
511521
int img2imgW = sd_params->width; //for img2img input
512522
int img2imgH = sd_params->height;
513523
int img2imgC = 3; // Assuming RGB image
514524
std::vector<uint8_t> resized_image_buf(img2imgW * img2imgH * img2imgC);
515525
std::vector<uint8_t> resized_mask_buf(img2imgW * img2imgH * img2imgC);
516-
std::vector<uint8_t> resized_extraimage_buf(img2imgW * img2imgH * img2imgC);
526+
std::vector<std::vector<uint8_t>> resized_extraimage_bufs(max_extra_images, std::vector<uint8_t>(img2imgW * img2imgH * img2imgC));
517527

518528
std::string ts = get_timestamp_str();
519529
if(!sd_is_quiet)
@@ -558,29 +568,39 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
558568
sd_params->sample_method = sample_method_t::EULER_A;
559569
}
560570

561-
if(extra_image_data!="")
571+
if(extra_image_data.size()>0)
562572
{
563-
if(input_extraimage_buffer!=nullptr) //just in time free old buffer
573+
if(input_extraimage_buffers.size()>0) //just in time free old buffer
564574
{
565-
stbi_image_free(input_extraimage_buffer);
566-
input_extraimage_buffer = nullptr;
575+
for(int i=0;i<input_extraimage_buffers.size();++i)
576+
{
577+
stbi_image_free(input_extraimage_buffers[i]);
578+
}
579+
input_extraimage_buffers.clear();
567580
}
568-
int nx2, ny2, nc2;
569-
int desiredchannels = 3;
570-
extraimage_buffer = kcpp_base64_decode(extra_image_data);
571-
input_extraimage_buffer = stbi_load_from_memory(extraimage_buffer.data(), extraimage_buffer.size(), &nx2, &ny2, &nc2, desiredchannels);
572-
// Resize the image
573-
int resok = stbir_resize_uint8(input_extraimage_buffer, nx2, ny2, 0, resized_extraimage_buf.data(), img2imgW, img2imgH, 0, desiredchannels);
574-
if (!resok) {
575-
printf("\nKCPP SD: resize extra image failed!\n");
576-
output.data = "";
577-
output.status = 0;
578-
return output;
581+
extraimage_buffers.clear();
582+
extraimage_references.clear();
583+
for(int i=0;i<extra_image_data.size() && i<max_extra_images;++i)
584+
{
585+
int nx2, ny2, nc2;
586+
int desiredchannels = 3;
587+
extraimage_buffers.push_back(kcpp_base64_decode(extra_image_data[i]));
588+
input_extraimage_buffers.push_back(stbi_load_from_memory(extraimage_buffers[i].data(), extraimage_buffers[i].size(), &nx2, &ny2, &nc2, desiredchannels));
589+
// Resize the image
590+
int resok = stbir_resize_uint8(input_extraimage_buffers[i], nx2, ny2, 0, resized_extraimage_bufs[i].data(), img2imgW, img2imgH, 0, desiredchannels);
591+
if (!resok) {
592+
printf("\nKCPP SD: resize extra image failed!\n");
593+
output.data = "";
594+
output.status = 0;
595+
return output;
596+
}
597+
sd_image_t extraimage_reference;
598+
extraimage_reference.width = img2imgW;
599+
extraimage_reference.height = img2imgH;
600+
extraimage_reference.channel = desiredchannels;
601+
extraimage_reference.data = resized_extraimage_bufs[i].data();
602+
extraimage_references.push_back(extraimage_reference);
579603
}
580-
extraimage_reference.width = img2imgW;
581-
extraimage_reference.height = img2imgH;
582-
extraimage_reference.channel = desiredchannels;
583-
extraimage_reference.data = resized_extraimage_buf.data();
584604

585605
//ensure prompt has img keyword, otherwise append it
586606
if(photomaker_enabled)
@@ -595,9 +615,29 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
595615
}
596616

597617
std::vector<sd_image_t> kontext_imgs;
598-
if(extra_image_data!="" && loadedsdver==SDVersion::VERSION_FLUX && !sd_loaded_chroma())
618+
if(extra_image_data.size()>0 && loadedsdver==SDVersion::VERSION_FLUX && !sd_loaded_chroma())
599619
{
600-
kontext_imgs.push_back(extraimage_reference);
620+
for(int i=0;i<extra_image_data.size();++i)
621+
{
622+
kontext_imgs.push_back(extraimage_references[i]);
623+
}
624+
if(!sd_is_quiet && sddebugmode==1)
625+
{
626+
printf("\nFlux Kontext: Using %d reference images\n",kontext_imgs.size());
627+
}
628+
}
629+
630+
std::vector<sd_image_t*> photomaker_imgs;
631+
if(photomaker_enabled && extra_image_data.size()>0)
632+
{
633+
for(int i=0;i<extra_image_data.size();++i)
634+
{
635+
photomaker_imgs.push_back(&extraimage_references[i]);
636+
}
637+
if(!sd_is_quiet && sddebugmode==1)
638+
{
639+
printf("\nPhotomaker: Using %d reference images\n",photomaker_imgs.size());
640+
}
601641
}
602642

603643
if (sd_params->mode == TXT2IMG) {
@@ -644,7 +684,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
644684
sd_params->slg_scale,
645685
sd_params->skip_layer_start,
646686
sd_params->skip_layer_end,
647-
(photomaker_enabled && extra_image_data!=""?(&extraimage_reference):nullptr));
687+
photomaker_imgs);
648688
} else {
649689

650690
if (sd_params->width <= 0 || sd_params->width % 64 != 0 || sd_params->height <= 0 || sd_params->height % 64 != 0) {
@@ -769,7 +809,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
769809
sd_params->slg_scale,
770810
sd_params->skip_layer_start,
771811
sd_params->skip_layer_end,
772-
(photomaker_enabled && extra_image_data!=""?(&extraimage_reference):nullptr));
812+
photomaker_imgs);
773813
}
774814

775815
if (results == NULL) {

otherarch/sdcpp/stable-diffusion.cpp

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,7 +1422,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
14221422
float skip_layer_start = 0.01,
14231423
float skip_layer_end = 0.2,
14241424
ggml_tensor* masked_image = NULL,
1425-
const sd_image_t* photomaker_reference = nullptr) {
1425+
const std::vector<sd_image_t*> photomaker_references = std::vector<sd_image_t*>()) {
14261426
if (seed < 0) {
14271427
// Generally, when using the provided command line, the seed is always >0.
14281428
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1465,7 +1465,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
14651465
ggml_tensor* init_img = NULL;
14661466
SDCondition id_cond;
14671467
std::vector<bool> class_tokens_mask;
1468-
if (sd_ctx->sd->pmid_model && photomaker_reference!=nullptr)
1468+
if (sd_ctx->sd->pmid_model && photomaker_references.size()>0)
14691469
{
14701470
sd_ctx->sd->stacked_id = true; //turn on photomaker if needed
14711471
}
@@ -1512,26 +1512,29 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
15121512
}
15131513
}
15141514

1515-
// handle single photomaker image passed in by kcpp
1516-
if (sd_ctx->sd->pmid_model && photomaker_reference!=nullptr)
1515+
// handle multiple photomaker images passed in by kcpp
1516+
if (sd_ctx->sd->pmid_model && photomaker_references.size()>0)
15171517
{
1518-
int c = 0;
1519-
int width, height;
1520-
width = photomaker_reference->width;
1521-
height = photomaker_reference->height;
1522-
c = photomaker_reference->channel;
1523-
uint8_t* input_image_buffer = photomaker_reference->data;
1524-
sd_image_t* input_image = NULL;
1525-
input_image = new sd_image_t{(uint32_t)width,
1526-
(uint32_t)height,
1527-
3,
1528-
input_image_buffer};
1529-
input_image = preprocess_id_image(input_image);
1530-
if (input_image == NULL) {
1531-
LOG_ERROR("\npreprocess input id image from kcpp photomaker failed\n");
1532-
} else {
1533-
LOG_INFO("\nPhotoMaker loaded image from kcpp\n");
1534-
input_id_images.push_back(input_image);
1518+
for(int i=0;i<photomaker_references.size();++i)
1519+
{
1520+
int c = 0;
1521+
int width, height;
1522+
width = photomaker_references[i]->width;
1523+
height = photomaker_references[i]->height;
1524+
c = photomaker_references[i]->channel;
1525+
uint8_t* input_image_buffer = photomaker_references[i]->data;
1526+
sd_image_t* input_image = NULL;
1527+
input_image = new sd_image_t{(uint32_t)width,
1528+
(uint32_t)height,
1529+
3,
1530+
input_image_buffer};
1531+
input_image = preprocess_id_image(input_image);
1532+
if (input_image == NULL) {
1533+
LOG_ERROR("\npreprocess input id image from kcpp photomaker failed\n");
1534+
} else {
1535+
LOG_INFO("\nPhotoMaker loaded image from kcpp\n");
1536+
input_id_images.push_back(input_image);
1537+
}
15351538
}
15361539
}
15371540

@@ -1790,7 +1793,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
17901793
float slg_scale = 0,
17911794
float skip_layer_start = 0.01,
17921795
float skip_layer_end = 0.2,
1793-
const sd_image_t* photomaker_reference = nullptr) {
1796+
const std::vector<sd_image_t*> photomaker_references = std::vector<sd_image_t*>()) {
17941797
std::vector<int> skip_layers_vec(skip_layers, skip_layers + skip_layers_count);
17951798
LOG_DEBUG("txt2img %dx%d", width, height);
17961799
if (sd_ctx == NULL) {
@@ -1887,7 +1890,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
18871890
skip_layer_start,
18881891
skip_layer_end,
18891892
nullptr,
1890-
photomaker_reference);
1893+
photomaker_references);
18911894

18921895
size_t t1 = ggml_time_ms();
18931896

@@ -1924,7 +1927,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
19241927
float slg_scale = 0,
19251928
float skip_layer_start = 0.01,
19261929
float skip_layer_end = 0.2,
1927-
const sd_image_t* photomaker_reference = nullptr) {
1930+
const std::vector<sd_image_t*> photomaker_references = std::vector<sd_image_t*>()) {
19281931
std::vector<int> skip_layers_vec(skip_layers, skip_layers + skip_layers_count);
19291932
LOG_DEBUG("img2img %dx%d", width, height);
19301933
if (sd_ctx == NULL) {
@@ -2089,7 +2092,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
20892092
skip_layer_start,
20902093
skip_layer_end,
20912094
masked_image,
2092-
photomaker_reference);
2095+
photomaker_references);
20932096

20942097
size_t t2 = ggml_time_ms();
20952098

otherarch/sdcpp/stable-diffusion.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
183183
float slg_scale,
184184
float skip_layer_start,
185185
float skip_layer_end,
186-
const sd_image_t* photomaker_reference);
186+
const std::vector<sd_image_t*> photomaker_references);
187187

188188
SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
189189
sd_image_t init_image,
@@ -213,7 +213,7 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
213213
float slg_scale,
214214
float skip_layer_start,
215215
float skip_layer_end,
216-
const sd_image_t* photomaker_reference);
216+
const std::vector<sd_image_t*> photomaker_references);
217217

218218
SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
219219
sd_image_t init_image,

0 commit comments

Comments
 (0)