Skip to content

Commit fab2ff0

Browse files
committed
sync sd.cpp to e370258
1 parent 49535fd commit fab2ff0

File tree

3 files changed

+276
-146
lines changed

otherarch/sdcpp/main.cpp

Lines changed: 121 additions & 80 deletions
Original file line number | Diff line number | Diff line change
@@ -41,13 +41,15 @@ const char* modes_str[] = {
4141
"img_gen",
4242
"vid_gen",
4343
"convert",
44+
"upscale",
4445
};
45-
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert"
46+
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale"
4647

4748
enum SDMode {
4849
IMG_GEN,
4950
VID_GEN,
5051
CONVERT,
52+
UPSCALE,
5153
MODE_COUNT
5254
};
5355

@@ -82,6 +84,7 @@ struct SDParams {
8284

8385
std::string prompt;
8486
std::string negative_prompt;
87+
8588
int clip_skip = -1; // <= 0 represents unspecified
8689
int width = 512;
8790
int height = 512;
@@ -125,6 +128,8 @@ struct SDParams {
125128
int chroma_t5_mask_pad = 1;
126129
float flow_shift = INFINITY;
127130

131+
prediction_t prediction = DEFAULT_PRED;
132+
128133
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
129134

130135
SDParams() {
@@ -186,6 +191,7 @@ void print_params(SDParams params) {
186191
printf(" sample_params: %s\n", SAFE_STR(sample_params_str));
187192
printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str));
188193
printf(" moe_boundary: %.3f\n", params.moe_boundary);
194+
printf(" prediction: %s\n", sd_prediction_name(params.prediction));
189195
printf(" flow_shift: %.2f\n", params.flow_shift);
190196
printf(" strength(img2img): %.2f\n", params.strength);
191197
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
@@ -208,7 +214,7 @@ void print_usage(int argc, const char* argv[]) {
208214
printf("\n");
209215
printf("arguments:\n");
210216
printf(" -h, --help show this help message and exit\n");
211-
printf(" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, convert], default: img_gen\n");
217+
printf(" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, upscale, convert], default: img_gen\n");
212218
printf(" -t, --threads N number of threads to use during computation (default: -1)\n");
213219
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
214220
printf(" --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed\n");
@@ -225,7 +231,7 @@ void print_usage(int argc, const char* argv[]) {
225231
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
226232
printf(" --control-net [CONTROL_PATH] path to control net model\n");
227233
printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n");
228-
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n");
234+
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. For img_gen mode, upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n");
229235
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
230236
printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
231237
printf(" If not specified, the default is the type of the weight file\n");
@@ -279,6 +285,7 @@ void print_usage(int argc, const char* argv[]) {
279285
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
280286
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
281287
printf(" -b, --batch-count COUNT number of images to generate\n");
288+
printf(" --prediction {eps, v, edm_v, sd3_flow, flux_flow} Prediction type override.\n");
282289
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
283290
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
284291
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
@@ -649,6 +656,20 @@ void parse_args(int argc, const char** argv, SDParams& params) {
649656
return 1;
650657
};
651658

659+
auto on_prediction_arg = [&](int argc, const char** argv, int index) {
660+
if (++index >= argc) {
661+
return -1;
662+
}
663+
const char* arg = argv[index];
664+
params.prediction = str_to_prediction(arg);
665+
if (params.prediction == PREDICTION_COUNT) {
666+
fprintf(stderr, "error: invalid prediction type %s\n",
667+
arg);
668+
return -1;
669+
}
670+
return 1;
671+
};
672+
652673
auto on_sample_method_arg = [&](int argc, const char** argv, int index) {
653674
if (++index >= argc) {
654675
return -1;
@@ -805,6 +826,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
805826
{"", "--rng", "", on_rng_arg},
806827
{"-s", "--seed", "", on_seed_arg},
807828
{"", "--sampling-method", "", on_sample_method_arg},
829+
{"", "--prediction", "", on_prediction_arg},
808830
{"", "--scheduler", "", on_schedule_arg},
809831
{"", "--skip-layers", "", on_skip_layers_arg},
810832
{"", "--high-noise-sampling-method", "", on_high_noise_sample_method_arg},
@@ -825,13 +847,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
825847
params.n_threads = sd_get_num_physical_cores();
826848
}
827849

828-
if (params.mode != CONVERT && params.mode != VID_GEN && params.prompt.length() == 0) {
850+
if ((params.mode == IMG_GEN || params.mode == VID_GEN) && params.prompt.length() == 0) {
829851
fprintf(stderr, "error: the following arguments are required: prompt\n");
830852
print_usage(argc, argv);
831853
exit(1);
832854
}
833855

834-
if (params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) {
856+
if (params.mode != UPSCALE && params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) {
835857
fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
836858
print_usage(argc, argv);
837859
exit(1);
@@ -891,6 +913,17 @@ void parse_args(int argc, const char** argv, SDParams& params) {
891913
exit(1);
892914
}
893915

916+
if (params.mode == UPSCALE) {
917+
if (params.esrgan_path.length() == 0) {
918+
fprintf(stderr, "error: upscale mode needs an upscaler model (--upscale-model)\n");
919+
exit(1);
920+
}
921+
if (params.init_image_path.length() == 0) {
922+
fprintf(stderr, "error: upscale mode needs an init image (--init-img)\n");
923+
exit(1);
924+
}
925+
}
926+
894927
if (params.seed < 0) {
895928
srand((int)time(NULL));
896929
params.seed = rand();
@@ -901,14 +934,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
901934
params.output_path = "output.gguf";
902935
}
903936
}
904-
905-
if (!isfinite(params.sample_params.guidance.img_cfg)) {
906-
params.sample_params.guidance.img_cfg = params.sample_params.guidance.txt_cfg;
907-
}
908-
909-
if (!isfinite(params.high_noise_sample_params.guidance.img_cfg)) {
910-
params.high_noise_sample_params.guidance.img_cfg = params.high_noise_sample_params.guidance.txt_cfg;
911-
}
912937
}
913938

914939
static std::string sd_basename(const std::string& path) {
@@ -1349,6 +1374,7 @@ int main(int argc, const char* argv[]) {
13491374
params.n_threads,
13501375
params.wtype,
13511376
params.rng_type,
1377+
params.prediction,
13521378
params.offload_params_to_cpu,
13531379
params.clip_on_cpu,
13541380
params.control_net_cpu,
@@ -1362,76 +1388,92 @@ int main(int argc, const char* argv[]) {
13621388
params.flow_shift,
13631389
};
13641390

1365-
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
1391+
sd_image_t* results = nullptr;
1392+
int num_results = 0;
13661393

1367-
if (sd_ctx == NULL) {
1368-
printf("new_sd_ctx_t failed\n");
1369-
release_all_resources();
1370-
return 1;
1371-
}
1394+
if (params.mode == UPSCALE) {
1395+
num_results = 1;
1396+
results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t));
1397+
if (results == NULL) {
1398+
printf("failed to allocate results array\n");
1399+
release_all_resources();
1400+
return 1;
1401+
}
13721402

1373-
if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
1374-
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
1375-
}
1403+
results[0] = init_image;
1404+
init_image.data = NULL;
1405+
} else {
1406+
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
13761407

1377-
sd_image_t* results;
1378-
int num_results = 1;
1379-
if (params.mode == IMG_GEN) {
1380-
sd_img_gen_params_t img_gen_params = {
1381-
params.prompt.c_str(),
1382-
params.negative_prompt.c_str(),
1383-
params.clip_skip,
1384-
init_image,
1385-
ref_images.data(),
1386-
(int)ref_images.size(),
1387-
params.increase_ref_index,
1388-
mask_image,
1389-
params.width,
1390-
params.height,
1391-
params.sample_params,
1392-
params.strength,
1393-
params.seed,
1394-
params.batch_count,
1395-
control_image,
1396-
params.control_strength,
1397-
{
1398-
pmid_images.data(),
1399-
(int)pmid_images.size(),
1400-
params.pm_id_embed_path.c_str(),
1401-
params.pm_style_strength,
1402-
}, // pm_params
1403-
params.vae_tiling_params,
1404-
};
1405-
1406-
results = generate_image(sd_ctx, &img_gen_params);
1407-
num_results = params.batch_count;
1408-
} else if (params.mode == VID_GEN) {
1409-
sd_vid_gen_params_t vid_gen_params = {
1410-
params.prompt.c_str(),
1411-
params.negative_prompt.c_str(),
1412-
params.clip_skip,
1413-
init_image,
1414-
end_image,
1415-
control_frames.data(),
1416-
(int)control_frames.size(),
1417-
params.width,
1418-
params.height,
1419-
params.sample_params,
1420-
params.high_noise_sample_params,
1421-
params.moe_boundary,
1422-
params.strength,
1423-
params.seed,
1424-
params.video_frames,
1425-
params.vace_strength,
1426-
};
1427-
1428-
results = generate_video(sd_ctx, &vid_gen_params, &num_results);
1429-
}
1408+
if (sd_ctx == NULL) {
1409+
printf("new_sd_ctx_t failed\n");
1410+
release_all_resources();
1411+
return 1;
1412+
}
1413+
1414+
if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
1415+
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
1416+
}
1417+
1418+
if (params.mode == IMG_GEN) {
1419+
sd_img_gen_params_t img_gen_params = {
1420+
params.prompt.c_str(),
1421+
params.negative_prompt.c_str(),
1422+
params.clip_skip,
1423+
init_image,
1424+
ref_images.data(),
1425+
(int)ref_images.size(),
1426+
params.increase_ref_index,
1427+
mask_image,
1428+
params.width,
1429+
params.height,
1430+
params.sample_params,
1431+
params.strength,
1432+
params.seed,
1433+
params.batch_count,
1434+
control_image,
1435+
params.control_strength,
1436+
{
1437+
pmid_images.data(),
1438+
(int)pmid_images.size(),
1439+
params.pm_id_embed_path.c_str(),
1440+
params.pm_style_strength,
1441+
}, // pm_params
1442+
params.vae_tiling_params,
1443+
};
1444+
1445+
results = generate_image(sd_ctx, &img_gen_params);
1446+
num_results = params.batch_count;
1447+
} else if (params.mode == VID_GEN) {
1448+
sd_vid_gen_params_t vid_gen_params = {
1449+
params.prompt.c_str(),
1450+
params.negative_prompt.c_str(),
1451+
params.clip_skip,
1452+
init_image,
1453+
end_image,
1454+
control_frames.data(),
1455+
(int)control_frames.size(),
1456+
params.width,
1457+
params.height,
1458+
params.sample_params,
1459+
params.high_noise_sample_params,
1460+
params.moe_boundary,
1461+
params.strength,
1462+
params.seed,
1463+
params.video_frames,
1464+
params.vace_strength,
1465+
};
1466+
1467+
results = generate_video(sd_ctx, &vid_gen_params, &num_results);
1468+
}
1469+
1470+
if (results == NULL) {
1471+
printf("generate failed\n");
1472+
free_sd_ctx(sd_ctx);
1473+
return 1;
1474+
}
14301475

1431-
if (results == NULL) {
1432-
printf("generate failed\n");
14331476
free_sd_ctx(sd_ctx);
1434-
return 1;
14351477
}
14361478

14371479
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
@@ -1444,7 +1486,7 @@ int main(int argc, const char* argv[]) {
14441486
if (upscaler_ctx == NULL) {
14451487
printf("new_upscaler_ctx failed\n");
14461488
} else {
1447-
for (int i = 0; i < params.batch_count; i++) {
1489+
for (int i = 0; i < num_results; i++) {
14481490
if (results[i].data == NULL) {
14491491
continue;
14501492
}
@@ -1530,7 +1572,6 @@ int main(int argc, const char* argv[]) {
15301572
results[i].data = NULL;
15311573
}
15321574
free(results);
1533-
free_sd_ctx(sd_ctx);
15341575

15351576
release_all_resources();
15361577

0 commit comments

Comments
 (0)