Skip to content

Commit adea272

Browse files
authored
feat(server): use image and command-line dimensions by default on server (#1262)
1 parent 45ce78a commit adea272

1 file changed

Lines changed: 107 additions & 54 deletions

File tree

examples/server/main.cpp

Lines changed: 107 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,8 @@ int main(int argc, const char** argv) {
404404
std::string size = j.value("size", "");
405405
std::string output_format = j.value("output_format", "png");
406406
int output_compression = j.value("output_compression", 100);
407-
int width = 512;
408-
int height = 512;
407+
int width = default_gen_params.width > 0 ? default_gen_params.width : 512;
408+
// Fall back to 512 only when the command-line default height itself is unset; guard on height (not width) so each dimension is resolved independently.
int height = default_gen_params.height > 0 ? default_gen_params.height : 512;
409409
if (!size.empty()) {
410410
auto pos = size.find('x');
411411
if (pos != std::string::npos) {
@@ -593,7 +593,7 @@ int main(int argc, const char** argv) {
593593
n = std::clamp(n, 1, 8);
594594

595595
std::string size = req.form.get_field("size");
596-
int width = 512, height = 512;
596+
int width = -1, height = -1;
597597
if (!size.empty()) {
598598
auto pos = size.find('x');
599599
if (pos != std::string::npos) {
@@ -650,15 +650,31 @@ int main(int argc, const char** argv) {
650650

651651
LOG_DEBUG("%s\n", gen_params.to_string().c_str());
652652

653-
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
654-
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
653+
sd_image_t init_image = {0, 0, 3, nullptr};
654+
sd_image_t control_image = {0, 0, 3, nullptr};
655655
std::vector<sd_image_t> pmid_images;
656656

657+
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
658+
if (gen_params.width > 0)
659+
return gen_params.width;
660+
if (default_gen_params.width > 0)
661+
return default_gen_params.width;
662+
return 512;
663+
};
664+
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
665+
if (gen_params.height > 0)
666+
return gen_params.height;
667+
if (default_gen_params.height > 0)
668+
return default_gen_params.height;
669+
return 512;
670+
};
671+
657672
std::vector<sd_image_t> ref_images;
658673
ref_images.reserve(images_bytes.size());
659674
for (auto& bytes : images_bytes) {
660-
int img_w = width;
661-
int img_h = height;
675+
int img_w;
676+
int img_h;
677+
662678
uint8_t* raw_pixels = load_image_from_memory(
663679
reinterpret_cast<const char*>(bytes.data()),
664680
static_cast<int>(bytes.size()),
@@ -670,22 +686,31 @@ int main(int argc, const char** argv) {
670686
}
671687

672688
sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels};
689+
gen_params.set_width_and_height_if_unset(img.width, img.height);
673690
ref_images.push_back(img);
674691
}
675692

676693
sd_image_t mask_image = {0};
677694
if (!mask_bytes.empty()) {
678-
int mask_w = width;
679-
int mask_h = height;
695+
int expected_width = 0;
696+
int expected_height = 0;
697+
if (gen_params.width_and_height_are_set()) {
698+
expected_width = gen_params.width;
699+
expected_height = gen_params.height;
700+
}
701+
int mask_w;
702+
int mask_h;
703+
680704
uint8_t* mask_raw = load_image_from_memory(
681705
reinterpret_cast<const char*>(mask_bytes.data()),
682706
static_cast<int>(mask_bytes.size()),
683707
mask_w, mask_h,
684-
width, height, 1);
708+
expected_width, expected_height, 1);
685709
mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw};
710+
gen_params.set_width_and_height_if_unset(mask_image.width, mask_image.height);
686711
} else {
687-
mask_image.width = width;
688-
mask_image.height = height;
712+
mask_image.width = get_resolved_width();
713+
mask_image.height = get_resolved_height();
689714
mask_image.channel = 1;
690715
mask_image.data = nullptr;
691716
}
@@ -702,8 +727,8 @@ int main(int argc, const char** argv) {
702727
gen_params.auto_resize_ref_image,
703728
gen_params.increase_ref_index,
704729
mask_image,
705-
gen_params.width,
706-
gen_params.height,
730+
get_resolved_width(),
731+
get_resolved_height(),
707732
gen_params.sample_params,
708733
gen_params.strength,
709734
gen_params.seed,
@@ -886,8 +911,6 @@ int main(int argc, const char** argv) {
886911
SDGenerationParams gen_params = default_gen_params;
887912
gen_params.prompt = prompt;
888913
gen_params.negative_prompt = negative_prompt;
889-
gen_params.width = width;
890-
gen_params.height = height;
891914
gen_params.seed = seed;
892915
gen_params.sample_params.sample_steps = steps;
893916
gen_params.batch_count = batch_size;
@@ -905,38 +928,66 @@ int main(int argc, const char** argv) {
905928
gen_params.sample_params.scheduler = scheduler;
906929
}
907930

931+
// re-read to avoid applying 512 as default before the provided
932+
// images and/or server command-line
933+
gen_params.width = j.value("width", -1);
934+
gen_params.height = j.value("height", -1);
935+
908936
LOG_DEBUG("%s\n", gen_params.to_string().c_str());
909937

910-
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
911-
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
912-
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
938+
sd_image_t init_image = {0, 0, 3, nullptr};
939+
sd_image_t control_image = {0, 0, 3, nullptr};
940+
sd_image_t mask_image = {0, 0, 1, nullptr};
913941
std::vector<uint8_t> mask_data;
914942
std::vector<sd_image_t> pmid_images;
915943
std::vector<sd_image_t> ref_images;
916944

917-
if (img2img) {
918-
auto decode_image = [](sd_image_t& image, std::string encoded) -> bool {
919-
// remove data URI prefix if present ("data:image/png;base64,")
920-
auto comma_pos = encoded.find(',');
921-
if (comma_pos != std::string::npos) {
922-
encoded = encoded.substr(comma_pos + 1);
945+
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
946+
if (gen_params.width > 0)
947+
return gen_params.width;
948+
if (default_gen_params.width > 0)
949+
return default_gen_params.width;
950+
return 512;
951+
};
952+
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
953+
if (gen_params.height > 0)
954+
return gen_params.height;
955+
if (default_gen_params.height > 0)
956+
return default_gen_params.height;
957+
return 512;
958+
};
959+
960+
auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool {
961+
// remove data URI prefix if present ("data:image/png;base64,")
962+
auto comma_pos = encoded.find(',');
963+
if (comma_pos != std::string::npos) {
964+
encoded = encoded.substr(comma_pos + 1);
965+
}
966+
std::vector<uint8_t> img_data = base64_decode(encoded);
967+
if (!img_data.empty()) {
968+
int expected_width = 0;
969+
int expected_height = 0;
970+
if (gen_params.width_and_height_are_set()) {
971+
expected_width = gen_params.width;
972+
expected_height = gen_params.height;
923973
}
924-
std::vector<uint8_t> img_data = base64_decode(encoded);
925-
if (!img_data.empty()) {
926-
int img_w = image.width;
927-
int img_h = image.height;
928-
uint8_t* raw_data = load_image_from_memory(
929-
(const char*)img_data.data(), (int)img_data.size(),
930-
img_w, img_h,
931-
image.width, image.height, image.channel);
932-
if (raw_data) {
933-
image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
934-
return true;
935-
}
974+
int img_w;
975+
int img_h;
976+
977+
uint8_t* raw_data = load_image_from_memory(
978+
(const char*)img_data.data(), (int)img_data.size(),
979+
img_w, img_h,
980+
expected_width, expected_height, image.channel);
981+
if (raw_data) {
982+
image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
983+
gen_params.set_width_and_height_if_unset(image.width, image.height);
984+
return true;
936985
}
937-
return false;
938-
};
986+
}
987+
return false;
988+
};
939989

990+
if (img2img) {
940991
if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) {
941992
std::string encoded = j["init_images"][0].get<std::string>();
942993
decode_image(init_image, encoded);
@@ -952,30 +1003,32 @@ int main(int argc, const char** argv) {
9521003
}
9531004
}
9541005
} else {
955-
mask_data = std::vector<uint8_t>(width * height, 255);
956-
mask_image.width = width;
957-
mask_image.height = height;
1006+
int m_width = get_resolved_width();
1007+
int m_height = get_resolved_height();
1008+
mask_data = std::vector<uint8_t>(m_width * m_height, 255);
1009+
mask_image.width = m_width;
1010+
mask_image.height = m_height;
9581011
mask_image.channel = 1;
9591012
mask_image.data = mask_data.data();
9601013
}
9611014

962-
if (j.contains("extra_images") && j["extra_images"].is_array()) {
963-
for (auto extra_image : j["extra_images"]) {
964-
std::string encoded = extra_image.get<std::string>();
965-
sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
966-
if (decode_image(tmp_image, encoded)) {
967-
ref_images.push_back(tmp_image);
968-
}
969-
}
970-
}
971-
9721015
float denoising_strength = j.value("denoising_strength", -1.f);
9731016
if (denoising_strength >= 0.f) {
9741017
denoising_strength = std::min(denoising_strength, 1.0f);
9751018
gen_params.strength = denoising_strength;
9761019
}
9771020
}
9781021

1022+
if (j.contains("extra_images") && j["extra_images"].is_array()) {
1023+
for (auto extra_image : j["extra_images"]) {
1024+
std::string encoded = extra_image.get<std::string>();
1025+
sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
1026+
if (decode_image(tmp_image, encoded)) {
1027+
ref_images.push_back(tmp_image);
1028+
}
1029+
}
1030+
}
1031+
9791032
sd_img_gen_params_t img_gen_params = {
9801033
sd_loras.data(),
9811034
static_cast<uint32_t>(sd_loras.size()),
@@ -988,8 +1041,8 @@ int main(int argc, const char** argv) {
9881041
gen_params.auto_resize_ref_image,
9891042
gen_params.increase_ref_index,
9901043
mask_image,
991-
gen_params.width,
992-
gen_params.height,
1044+
get_resolved_width(),
1045+
get_resolved_height(),
9931046
gen_params.sample_params,
9941047
gen_params.strength,
9951048
gen_params.seed,

0 commit comments

Comments
 (0)