@@ -92,15 +92,16 @@ struct SDParams {
9292
9393 std::string prompt;
9494 std::string negative_prompt;
95- float min_cfg = 1 .0f ;
96- float cfg_scale = 7 .0f ;
97- float guidance = 3 .5f ;
98- float eta = 0 .f;
99- float style_ratio = 20 .f;
100- int clip_skip = -1 ; // <= 0 represents unspecified
101- int width = 512 ;
102- int height = 512 ;
103- int batch_count = 1 ;
95+ float min_cfg = 1 .0f ;
96+ float cfg_scale = 7 .0f ;
97+ float img_cfg_scale = INFINITY;
98+ float guidance = 3 .5f ;
99+ float eta = 0 .f;
100+ float style_ratio = 20 .f;
101+ int clip_skip = -1 ; // <= 0 represents unspecified
102+ int width = 512 ;
103+ int height = 512 ;
104+ int batch_count = 1 ;
104105
105106 int video_frames = 6 ;
106107 int motion_bucket_id = 127 ;
@@ -163,6 +164,7 @@ void print_params(SDParams params) {
163164 printf (" negative_prompt: %s\n " , params.negative_prompt .c_str ());
164165 printf (" min_cfg: %.2f\n " , params.min_cfg );
165166 printf (" cfg_scale: %.2f\n " , params.cfg_scale );
167+ printf (" img_cfg_scale: %.2f\n " , params.img_cfg_scale );
166168 printf (" slg_scale: %.2f\n " , params.slg_scale );
167169 printf (" guidance: %.2f\n " , params.guidance );
168170 printf (" eta: %.2f\n " , params.eta );
@@ -212,7 +214,8 @@ void print_usage(int argc, const char* argv[]) {
212214 printf (" -p, --prompt [PROMPT] the prompt to render\n " );
213215 printf (" -n, --negative-prompt PROMPT the negative prompt (default: \"\" )\n " );
214216 printf (" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n " );
215- printf (" --guidance SCALE guidance scale for img2img (default: 3.5)\n " );
217+ printf (" --img-cfg-scale SCALE image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)\n " );
218+ printf (" --guidance SCALE distilled guidance scale for models with guidance input (default: 3.5)\n " );
216219 printf (" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n " );
217220 printf (" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n " );
218221 printf (" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n " );
@@ -439,6 +442,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
439442 break ;
440443 }
441444 params.cfg_scale = std::stof (argv[i]);
445+ } else if (arg == " --img-cfg-scale" ) {
446+ if (++i >= argc) {
447+ invalid_arg = true ;
448+ break ;
449+ }
450+ params.img_cfg_scale = std::stof (argv[i]);
442451 } else if (arg == " --guidance" ) {
443452 if (++i >= argc) {
444453 invalid_arg = true ;
@@ -698,6 +707,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
698707 params.output_path = " output.gguf" ;
699708 }
700709 }
710+
711+ if (!isfinite (params.img_cfg_scale )) {
712+ params.img_cfg_scale = params.cfg_scale ;
713+ }
701714}
702715
703716static std::string sd_basename (const std::string& path) {
@@ -792,6 +805,18 @@ int main(int argc, const char* argv[]) {
792805
793806 parse_args (argc, argv, params);
794807
808+ sd_guidance_params_t guidance_params = {params.cfg_scale ,
809+ params.img_cfg_scale ,
810+ params.min_cfg ,
811+ params.guidance ,
812+ {
813+ params.skip_layers .data (),
814+ params.skip_layers .size (),
815+ params.skip_layer_start ,
816+ params.skip_layer_end ,
817+ params.slg_scale ,
818+ }};
819+
795820 sd_set_log_callback (sd_log_cb, (void *)¶ms);
796821
797822 if (params.verbose ) {
@@ -949,8 +974,7 @@ int main(int argc, const char* argv[]) {
949974 params.prompt .c_str (),
950975 params.negative_prompt .c_str (),
951976 params.clip_skip ,
952- params.cfg_scale ,
953- params.guidance ,
977+ guidance_params,
954978 params.eta ,
955979 params.width ,
956980 params.height ,
@@ -962,12 +986,7 @@ int main(int argc, const char* argv[]) {
962986 params.control_strength ,
963987 params.style_ratio ,
964988 params.normalize_input ,
965- params.input_id_images_path .c_str (),
966- params.skip_layers .data (),
967- params.skip_layers .size (),
968- params.slg_scale ,
969- params.skip_layer_start ,
970- params.skip_layer_end );
989+ params.input_id_images_path .c_str ());
971990 } else {
972991 sd_image_t input_image = {(uint32_t )params.width ,
973992 (uint32_t )params.height ,
@@ -983,8 +1002,7 @@ int main(int argc, const char* argv[]) {
9831002 params.motion_bucket_id ,
9841003 params.fps ,
9851004 params.augmentation_level ,
986- params.min_cfg ,
987- params.cfg_scale ,
1005+ guidance_params,
9881006 params.sample_method ,
9891007 params.sample_steps ,
9901008 params.strength ,
@@ -1017,8 +1035,7 @@ int main(int argc, const char* argv[]) {
10171035 params.prompt .c_str (),
10181036 params.negative_prompt .c_str (),
10191037 params.clip_skip ,
1020- params.cfg_scale ,
1021- params.guidance ,
1038+ guidance_params,
10221039 params.eta ,
10231040 params.width ,
10241041 params.height ,
@@ -1031,12 +1048,7 @@ int main(int argc, const char* argv[]) {
10311048 params.control_strength ,
10321049 params.style_ratio ,
10331050 params.normalize_input ,
1034- params.input_id_images_path .c_str (),
1035- params.skip_layers .data (),
1036- params.skip_layers .size (),
1037- params.slg_scale ,
1038- params.skip_layer_start ,
1039- params.skip_layer_end );
1051+ params.input_id_images_path .c_str ());
10401052 }
10411053 }
10421054
@@ -1075,19 +1087,19 @@ int main(int argc, const char* argv[]) {
10751087
10761088 std::string dummy_name, ext, lc_ext;
10771089 bool is_jpg;
1078- size_t last = params.output_path .find_last_of (" ." );
1090+ size_t last = params.output_path .find_last_of (" ." );
10791091 size_t last_path = std::min (params.output_path .find_last_of (" /" ),
10801092 params.output_path .find_last_of (" \\ " ));
1081- if (last != std::string::npos // filename has extension
1082- && (last_path == std::string::npos || last > last_path)) {
1093+ if (last != std::string::npos // filename has extension
1094+ && (last_path == std::string::npos || last > last_path)) {
10831095 dummy_name = params.output_path .substr (0 , last);
10841096 ext = lc_ext = params.output_path .substr (last);
10851097 std::transform (ext.begin (), ext.end (), lc_ext.begin (), ::tolower);
10861098 is_jpg = lc_ext == " .jpg" || lc_ext == " .jpeg" || lc_ext == " .jpe" ;
10871099 } else {
10881100 dummy_name = params.output_path ;
10891101 ext = lc_ext = " " ;
1090- is_jpg = false ;
1102+ is_jpg = false ;
10911103 }
10921104 // appending ".png" to absent or unknown extension
10931105 if (!is_jpg && lc_ext != " .png" ) {
@@ -1099,7 +1111,7 @@ int main(int argc, const char* argv[]) {
10991111 continue ;
11001112 }
11011113 std::string final_image_path = i > 0 ? dummy_name + " _" + std::to_string (i + 1 ) + ext : dummy_name + ext;
1102- if (is_jpg) {
1114+ if (is_jpg) {
11031115 stbi_write_jpg (final_image_path.c_str (), results[i].width , results[i].height , results[i].channel ,
11041116 results[i].data , 90 , get_image_params (params, params.seed + i).c_str ());
11051117 printf (" save result JPEG image to '%s'\n " , final_image_path.c_str ());
0 commit comments