Skip to content

Commit 37d5e27

Browse files
committed
Support 2-cond cfg properly in cli
1 parent 94dea6c commit 37d5e27

File tree

3 files changed

+100
-119
lines changed

3 files changed

+100
-119
lines changed

examples/cli/main.cpp

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,16 @@ struct SDParams {
9292

9393
std::string prompt;
9494
std::string negative_prompt;
95-
float min_cfg = 1.0f;
96-
float cfg_scale = 7.0f;
97-
float guidance = 3.5f;
98-
float eta = 0.f;
99-
float style_ratio = 20.f;
100-
int clip_skip = -1; // <= 0 represents unspecified
101-
int width = 512;
102-
int height = 512;
103-
int batch_count = 1;
95+
float min_cfg = 1.0f;
96+
float cfg_scale = 7.0f;
97+
float img_cfg_scale = INFINITY;
98+
float guidance = 3.5f;
99+
float eta = 0.f;
100+
float style_ratio = 20.f;
101+
int clip_skip = -1; // <= 0 represents unspecified
102+
int width = 512;
103+
int height = 512;
104+
int batch_count = 1;
104105

105106
int video_frames = 6;
106107
int motion_bucket_id = 127;
@@ -163,6 +164,7 @@ void print_params(SDParams params) {
163164
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
164165
printf(" min_cfg: %.2f\n", params.min_cfg);
165166
printf(" cfg_scale: %.2f\n", params.cfg_scale);
167+
printf(" img_cfg_scale: %.2f\n", params.img_cfg_scale);
166168
printf(" slg_scale: %.2f\n", params.slg_scale);
167169
printf(" guidance: %.2f\n", params.guidance);
168170
printf(" eta: %.2f\n", params.eta);
@@ -212,7 +214,8 @@ void print_usage(int argc, const char* argv[]) {
212214
printf(" -p, --prompt [PROMPT] the prompt to render\n");
213215
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
214216
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
215-
printf(" --guidance SCALE guidance scale for img2img (default: 3.5)\n");
217+
printf(" --img-cfg-scale SCALE image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)\n");
218+
printf(" --guidance SCALE distilled guidance scale for models with guidance input (default: 3.5)\n");
216219
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
217220
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
218221
printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n");
@@ -439,6 +442,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
439442
break;
440443
}
441444
params.cfg_scale = std::stof(argv[i]);
445+
} else if (arg == "--img-cfg-scale") {
446+
if (++i >= argc) {
447+
invalid_arg = true;
448+
break;
449+
}
450+
params.img_cfg_scale = std::stof(argv[i]);
442451
} else if (arg == "--guidance") {
443452
if (++i >= argc) {
444453
invalid_arg = true;
@@ -698,6 +707,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
698707
params.output_path = "output.gguf";
699708
}
700709
}
710+
711+
if (!isfinite(params.img_cfg_scale)) {
712+
params.img_cfg_scale = params.cfg_scale;
713+
}
701714
}
702715

703716
static std::string sd_basename(const std::string& path) {
@@ -792,6 +805,18 @@ int main(int argc, const char* argv[]) {
792805

793806
parse_args(argc, argv, params);
794807

808+
sd_guidance_params_t guidance_params = {params.cfg_scale,
809+
params.img_cfg_scale,
810+
params.min_cfg,
811+
params.guidance,
812+
{
813+
params.skip_layers.data(),
814+
params.skip_layers.size(),
815+
params.skip_layer_start,
816+
params.skip_layer_end,
817+
params.slg_scale,
818+
}};
819+
795820
sd_set_log_callback(sd_log_cb, (void*)&params);
796821

797822
if (params.verbose) {
@@ -949,8 +974,7 @@ int main(int argc, const char* argv[]) {
949974
params.prompt.c_str(),
950975
params.negative_prompt.c_str(),
951976
params.clip_skip,
952-
params.cfg_scale,
953-
params.guidance,
977+
guidance_params,
954978
params.eta,
955979
params.width,
956980
params.height,
@@ -962,12 +986,7 @@ int main(int argc, const char* argv[]) {
962986
params.control_strength,
963987
params.style_ratio,
964988
params.normalize_input,
965-
params.input_id_images_path.c_str(),
966-
params.skip_layers.data(),
967-
params.skip_layers.size(),
968-
params.slg_scale,
969-
params.skip_layer_start,
970-
params.skip_layer_end);
989+
params.input_id_images_path.c_str());
971990
} else {
972991
sd_image_t input_image = {(uint32_t)params.width,
973992
(uint32_t)params.height,
@@ -983,8 +1002,7 @@ int main(int argc, const char* argv[]) {
9831002
params.motion_bucket_id,
9841003
params.fps,
9851004
params.augmentation_level,
986-
params.min_cfg,
987-
params.cfg_scale,
1005+
guidance_params,
9881006
params.sample_method,
9891007
params.sample_steps,
9901008
params.strength,
@@ -1017,8 +1035,7 @@ int main(int argc, const char* argv[]) {
10171035
params.prompt.c_str(),
10181036
params.negative_prompt.c_str(),
10191037
params.clip_skip,
1020-
params.cfg_scale,
1021-
params.guidance,
1038+
guidance_params,
10221039
params.eta,
10231040
params.width,
10241041
params.height,
@@ -1031,12 +1048,7 @@ int main(int argc, const char* argv[]) {
10311048
params.control_strength,
10321049
params.style_ratio,
10331050
params.normalize_input,
1034-
params.input_id_images_path.c_str(),
1035-
params.skip_layers.data(),
1036-
params.skip_layers.size(),
1037-
params.slg_scale,
1038-
params.skip_layer_start,
1039-
params.skip_layer_end);
1051+
params.input_id_images_path.c_str());
10401052
}
10411053
}
10421054

@@ -1075,19 +1087,19 @@ int main(int argc, const char* argv[]) {
10751087

10761088
std::string dummy_name, ext, lc_ext;
10771089
bool is_jpg;
1078-
size_t last = params.output_path.find_last_of(".");
1090+
size_t last = params.output_path.find_last_of(".");
10791091
size_t last_path = std::min(params.output_path.find_last_of("/"),
10801092
params.output_path.find_last_of("\\"));
1081-
if (last != std::string::npos // filename has extension
1082-
&& (last_path == std::string::npos || last > last_path)) {
1093+
if (last != std::string::npos // filename has extension
1094+
&& (last_path == std::string::npos || last > last_path)) {
10831095
dummy_name = params.output_path.substr(0, last);
10841096
ext = lc_ext = params.output_path.substr(last);
10851097
std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower);
10861098
is_jpg = lc_ext == ".jpg" || lc_ext == ".jpeg" || lc_ext == ".jpe";
10871099
} else {
10881100
dummy_name = params.output_path;
10891101
ext = lc_ext = "";
1090-
is_jpg = false;
1102+
is_jpg = false;
10911103
}
10921104
// appending ".png" to absent or unknown extension
10931105
if (!is_jpg && lc_ext != ".png") {
@@ -1099,7 +1111,7 @@ int main(int argc, const char* argv[]) {
10991111
continue;
11001112
}
11011113
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext;
1102-
if(is_jpg) {
1114+
if (is_jpg) {
11031115
stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
11041116
results[i].data, 90, get_image_params(params, params.seed + i).c_str());
11051117
printf("save result JPEG image to '%s'\n", final_image_path.c_str());

0 commit comments

Comments
 (0)