@@ -404,8 +404,8 @@ int main(int argc, const char** argv) {
404404 std::string size = j.value (" size" , " " );
405405 std::string output_format = j.value (" output_format" , " png" );
406406 int output_compression = j.value (" output_compression" , 100 );
407- int width = 512 ;
408- int height = 512 ;
407+ int width = default_gen_params.width > 0 ? default_gen_params.width : 512;  // use the positive width from default_gen_params, else fall back to 512
408+ int height = default_gen_params.height > 0 ? default_gen_params.height : 512;  // test height itself (not width) before using it, else fall back to 512
409409 if (!size.empty ()) {
410410 auto pos = size.find (' x' );
411411 if (pos != std::string::npos) {
@@ -593,7 +593,7 @@ int main(int argc, const char** argv) {
593593 n = std::clamp (n, 1 , 8 );
594594
595595 std::string size = req.form .get_field (" size" );
596- int width = 512 , height = 512 ;
596+ int width = - 1 , height = - 1 ;
597597 if (!size.empty ()) {
598598 auto pos = size.find (' x' );
599599 if (pos != std::string::npos) {
@@ -650,15 +650,31 @@ int main(int argc, const char** argv) {
650650
651651 LOG_DEBUG (" %s\n " , gen_params.to_string ().c_str ());
652652
653- sd_image_t init_image = {( uint32_t )gen_params. width , ( uint32_t )gen_params. height , 3 , nullptr };
654- sd_image_t control_image = {( uint32_t )gen_params. width , ( uint32_t )gen_params. height , 3 , nullptr };
653+ sd_image_t init_image = {0 , 0 , 3 , nullptr };
654+ sd_image_t control_image = {0 , 0 , 3 , nullptr };
655655 std::vector<sd_image_t > pmid_images;
656656
657+ auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {  // Picks the effective width; both objects are captured by reference, so each call reads their current values.
658+ if (gen_params.width > 0 )  // a positive width in gen_params takes precedence
659+ return gen_params.width ;
660+ if (default_gen_params.width > 0 )  // otherwise use a positive width from default_gen_params
661+ return default_gen_params.width ;
662+ return 512 ;  // fallback when neither width is positive
663+ };
664+ auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {  // Picks the effective height; both objects are captured by reference, so each call reads their current values.
665+ if (gen_params.height > 0 )  // a positive height in gen_params takes precedence
666+ return gen_params.height ;
667+ if (default_gen_params.height > 0 )  // otherwise use a positive height from default_gen_params
668+ return default_gen_params.height ;
669+ return 512 ;  // fallback when neither height is positive
670+ };
671+
657672 std::vector<sd_image_t > ref_images;
658673 ref_images.reserve (images_bytes.size ());
659674 for (auto & bytes : images_bytes) {
660- int img_w = width;
661- int img_h = height;
675+ int img_w;
676+ int img_h;
677+
662678 uint8_t * raw_pixels = load_image_from_memory (
663679 reinterpret_cast <const char *>(bytes.data ()),
664680 static_cast <int >(bytes.size ()),
@@ -670,22 +686,31 @@ int main(int argc, const char** argv) {
670686 }
671687
672688 sd_image_t img{(uint32_t )img_w, (uint32_t )img_h, 3 , raw_pixels};
689+ gen_params.set_width_and_height_if_unset (img.width , img.height );
673690 ref_images.push_back (img);
674691 }
675692
676693 sd_image_t mask_image = {0 };
677694 if (!mask_bytes.empty ()) {
678- int mask_w = width;
679- int mask_h = height;
695+ int expected_width = 0 ;
696+ int expected_height = 0 ;
697+ if (gen_params.width_and_height_are_set ()) {
698+ expected_width = gen_params.width ;
699+ expected_height = gen_params.height ;
700+ }
701+ int mask_w;
702+ int mask_h;
703+
680704 uint8_t * mask_raw = load_image_from_memory (
681705 reinterpret_cast <const char *>(mask_bytes.data ()),
682706 static_cast <int >(mask_bytes.size ()),
683707 mask_w, mask_h,
684- width, height , 1 );
708+ expected_width, expected_height , 1 );
685709 mask_image = {(uint32_t )mask_w, (uint32_t )mask_h, 1 , mask_raw};
710+ gen_params.set_width_and_height_if_unset (mask_image.width , mask_image.height );
686711 } else {
687- mask_image.width = width ;
688- mask_image.height = height ;
712+ mask_image.width = get_resolved_width () ;
713+ mask_image.height = get_resolved_height () ;
689714 mask_image.channel = 1 ;
690715 mask_image.data = nullptr ;
691716 }
@@ -702,8 +727,8 @@ int main(int argc, const char** argv) {
702727 gen_params.auto_resize_ref_image ,
703728 gen_params.increase_ref_index ,
704729 mask_image,
705- gen_params. width ,
706- gen_params. height ,
730+ get_resolved_width () ,
731+ get_resolved_height () ,
707732 gen_params.sample_params ,
708733 gen_params.strength ,
709734 gen_params.seed ,
@@ -886,8 +911,6 @@ int main(int argc, const char** argv) {
886911 SDGenerationParams gen_params = default_gen_params;
887912 gen_params.prompt = prompt;
888913 gen_params.negative_prompt = negative_prompt;
889- gen_params.width = width;
890- gen_params.height = height;
891914 gen_params.seed = seed;
892915 gen_params.sample_params .sample_steps = steps;
893916 gen_params.batch_count = batch_size;
@@ -905,38 +928,66 @@ int main(int argc, const char** argv) {
905928 gen_params.sample_params .scheduler = scheduler;
906929 }
907930
931+ // Re-read width/height with -1 meaning "not provided", so that the 512
932+ // fallback is applied only after the provided images and/or the server command-line defaults
933+ gen_params.width = j.value (" width" , -1 );
934+ gen_params.height = j.value (" height" , -1 );
935+
908936 LOG_DEBUG (" %s\n " , gen_params.to_string ().c_str ());
909937
910- sd_image_t init_image = {( uint32_t )gen_params. width , ( uint32_t )gen_params. height , 3 , nullptr };
911- sd_image_t control_image = {( uint32_t )gen_params. width , ( uint32_t )gen_params. height , 3 , nullptr };
912- sd_image_t mask_image = {( uint32_t )gen_params. width , ( uint32_t )gen_params. height , 1 , nullptr };
938+ sd_image_t init_image = {0 , 0 , 3 , nullptr };
939+ sd_image_t control_image = {0 , 0 , 3 , nullptr };
940+ sd_image_t mask_image = {0 , 0 , 1 , nullptr };
913941 std::vector<uint8_t > mask_data;
914942 std::vector<sd_image_t > pmid_images;
915943 std::vector<sd_image_t > ref_images;
916944
917- if (img2img) {
918- auto decode_image = [](sd_image_t & image, std::string encoded) -> bool {
919- // remove data URI prefix if present ("data:image/png;base64,")
920- auto comma_pos = encoded.find (' ,' );
921- if (comma_pos != std::string::npos) {
922- encoded = encoded.substr (comma_pos + 1 );
945+ auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {  // Picks the effective width; both objects are captured by reference, so each call reads their current values.
946+ if (gen_params.width > 0 )  // a positive width in gen_params takes precedence
947+ return gen_params.width ;
948+ if (default_gen_params.width > 0 )  // otherwise use a positive width from default_gen_params
949+ return default_gen_params.width ;
950+ return 512 ;  // fallback when neither width is positive
951+ };
952+ auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {  // Picks the effective height; both objects are captured by reference, so each call reads their current values.
953+ if (gen_params.height > 0 )  // a positive height in gen_params takes precedence
954+ return gen_params.height ;
955+ if (default_gen_params.height > 0 )  // otherwise use a positive height from default_gen_params
956+ return default_gen_params.height ;
957+ return 512 ;  // fallback when neither height is positive
958+ };
959+
960+ auto decode_image = [&gen_params](sd_image_t & image, std::string encoded) -> bool {  // Decodes a base64 image string into `image`; returns true on success, false otherwise. `encoded` is taken by value, so trimming it below does not affect the caller.
961+ // drop everything up to and including the first comma, i.e. a data URI prefix if present ("data:image/png;base64,")
962+ auto comma_pos = encoded.find (' ,' );
963+ if (comma_pos != std::string::npos) {
964+ encoded = encoded.substr (comma_pos + 1 );
965+ }
966+ std::vector<uint8_t > img_data = base64_decode (encoded);
967+ if (!img_data.empty ()) {  // an empty decode result falls through to `return false`
968+ int expected_width = 0 ;  // 0/0 is passed when gen_params has no width/height set; presumably means "no expected size" — TODO confirm against load_image_from_memory
969+ int expected_height = 0 ;
970+ if (gen_params.width_and_height_are_set ()) {
971+ expected_width = gen_params.width ;
972+ expected_height = gen_params.height ;
923973 }
924- std::vector<uint8_t > img_data = base64_decode (encoded);
925- if (!img_data.empty ()) {
926- int img_w = image.width ;
927- int img_h = image.height ;
928- uint8_t * raw_data = load_image_from_memory (
929- (const char *)img_data.data (), (int )img_data.size (),
930- img_w, img_h,
931- image.width , image.height , image.channel );
932- if (raw_data) {
933- image = {(uint32_t )img_w, (uint32_t )img_h, image.channel , raw_data};
934- return true ;
935- }
974+ int img_w;  // NOTE(review): left uninitialized; assumed to be written by load_image_from_memory — confirm
975+ int img_h;
976+
977+ uint8_t * raw_data = load_image_from_memory (
978+ (const char *)img_data.data (), (int )img_data.size (),
979+ img_w, img_h,
980+ expected_width, expected_height, image.channel );  // only image.channel is read from the incoming `image`
981+ if (raw_data) {
982+ image = {(uint32_t )img_w, (uint32_t )img_h, image.channel , raw_data};  // overwrite the placeholder with the loaded size and pixel pointer
983+ gen_params.set_width_and_height_if_unset (image.width , image.height );  // presumably lets the first decoded image define the size when none was given — verify in SDGenerationParams
984+ return true ;
985+ }
937- return false ;
938- };
986+ }
987+ return false ;  // reached when the decode result is empty or the loader returned null
988+ };
939989
990+ if (img2img) {
940991 if (j.contains (" init_images" ) && j[" init_images" ].is_array () && !j[" init_images" ].empty ()) {
941992 std::string encoded = j[" init_images" ][0 ].get <std::string>();
942993 decode_image (init_image, encoded);
@@ -952,30 +1003,32 @@ int main(int argc, const char** argv) {
9521003 }
9531004 }
9541005 } else {
955- mask_data = std::vector<uint8_t >(width * height, 255 );
956- mask_image.width = width;
957- mask_image.height = height;
1006+ int m_width = get_resolved_width ();
1007+ int m_height = get_resolved_height ();
1008+ mask_data = std::vector<uint8_t >(m_width * m_height, 255 );
1009+ mask_image.width = m_width;
1010+ mask_image.height = m_height;
9581011 mask_image.channel = 1 ;
9591012 mask_image.data = mask_data.data ();
9601013 }
9611014
962- if (j.contains (" extra_images" ) && j[" extra_images" ].is_array ()) {
963- for (auto extra_image : j[" extra_images" ]) {
964- std::string encoded = extra_image.get <std::string>();
965- sd_image_t tmp_image = {(uint32_t )gen_params.width , (uint32_t )gen_params.height , 3 , nullptr };
966- if (decode_image (tmp_image, encoded)) {
967- ref_images.push_back (tmp_image);
968- }
969- }
970- }
971-
9721015 float denoising_strength = j.value (" denoising_strength" , -1 .f );
9731016 if (denoising_strength >= 0 .f ) {
9741017 denoising_strength = std::min (denoising_strength, 1 .0f );
9751018 gen_params.strength = denoising_strength;
9761019 }
9771020 }
9781021
1022+ if (j.contains (" extra_images" ) && j[" extra_images" ].is_array ()) {  // collect optional reference images from the request body
1023+ for (const auto& extra_image : j[" extra_images" ]) {  // iterate by const reference to avoid copying each JSON element
1024+ std::string encoded = extra_image.get <std::string>();
1025+ sd_image_t tmp_image = {0 , 0 , 3 , nullptr };  // placeholder: decode_image reads only .channel; avoids casting a possibly -1 (unset) width/height to uint32_t
1026+ if (decode_image (tmp_image, encoded)) {  // keep only images that decoded successfully
1027+ ref_images.push_back (tmp_image);
1028+ }
1029+ }
1030+ }
1031+
9791032 sd_img_gen_params_t img_gen_params = {
9801033 sd_loras.data (),
9811034 static_cast <uint32_t >(sd_loras.size ()),
@@ -988,8 +1041,8 @@ int main(int argc, const char** argv) {
9881041 gen_params.auto_resize_ref_image ,
9891042 gen_params.increase_ref_index ,
9901043 mask_image,
991- gen_params. width ,
992- gen_params. height ,
1044+ get_resolved_width () ,
1045+ get_resolved_height () ,
9931046 gen_params.sample_params ,
9941047 gen_params.strength ,
9951048 gen_params.seed ,
0 commit comments