@@ -41,13 +41,15 @@ const char* modes_str[] = {
4141 " img_gen" ,
4242 " vid_gen" ,
4343 " convert" ,
44+ " upscale" ,
4445};
45- #define SD_ALL_MODES_STR " img_gen, vid_gen, convert"
46+ #define SD_ALL_MODES_STR " img_gen, vid_gen, convert, upscale "
4647
4748enum SDMode {
4849 IMG_GEN,
4950 VID_GEN,
5051 CONVERT,
52+ UPSCALE,
5153 MODE_COUNT
5254};
5355
@@ -82,6 +84,7 @@ struct SDParams {
8284
8385 std::string prompt;
8486 std::string negative_prompt;
87+
8588 int clip_skip = -1 ; // <= 0 represents unspecified
8689 int width = 512 ;
8790 int height = 512 ;
@@ -125,6 +128,8 @@ struct SDParams {
125128 int chroma_t5_mask_pad = 1 ;
126129 float flow_shift = INFINITY;
127130
131+ prediction_t prediction = DEFAULT_PRED;
132+
128133 sd_tiling_params_t vae_tiling_params = {false , 0 , 0 , 0 .5f , 0 .0f , 0 .0f };
129134
130135 SDParams () {
@@ -186,6 +191,7 @@ void print_params(SDParams params) {
186191 printf (" sample_params: %s\n " , SAFE_STR (sample_params_str));
187192 printf (" high_noise_sample_params: %s\n " , SAFE_STR (high_noise_sample_params_str));
188193 printf (" moe_boundary: %.3f\n " , params.moe_boundary );
194+ printf (" prediction: %s\n " , sd_prediction_name (params.prediction ));
189195 printf (" flow_shift: %.2f\n " , params.flow_shift );
190196 printf (" strength(img2img): %.2f\n " , params.strength );
191197 printf (" rng: %s\n " , sd_rng_type_name (params.rng_type ));
@@ -208,7 +214,7 @@ void print_usage(int argc, const char* argv[]) {
208214 printf (" \n " );
209215 printf (" arguments:\n " );
210216 printf (" -h, --help show this help message and exit\n " );
211- printf (" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, convert], default: img_gen\n " );
217+ printf (" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, upscale, convert], default: img_gen\n " );
212218 printf (" -t, --threads N number of threads to use during computation (default: -1)\n " );
213219 printf (" If threads <= 0, then threads will be set to the number of CPU physical cores\n " );
214220 printf (" --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed\n " );
@@ -225,7 +231,7 @@ void print_usage(int argc, const char* argv[]) {
225231 printf (" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n " );
226232 printf (" --control-net [CONTROL_PATH] path to control net model\n " );
227233 printf (" --embd-dir [EMBEDDING_PATH] path to embeddings\n " );
228- printf (" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n " );
234+ printf (" --upscale-model [ESRGAN_PATH] path to esrgan model. For img_gen mode, upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n " );
229235 printf (" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n " );
230236 printf (" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n " );
231237 printf (" If not specified, the default is the type of the weight file\n " );
@@ -279,6 +285,7 @@ void print_usage(int argc, const char* argv[]) {
279285 printf (" --rng {std_default, cuda} RNG (default: cuda)\n " );
280286 printf (" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n " );
281287 printf (" -b, --batch-count COUNT number of images to generate\n " );
288+ printf (" --prediction {eps, v, edm_v, sd3_flow, flux_flow} Prediction type override.\n " );
282289 printf (" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n " );
283290 printf (" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n " );
284291 printf (" --vae-tiling process vae in tiles to reduce memory usage\n " );
@@ -649,6 +656,20 @@ void parse_args(int argc, const char** argv, SDParams& params) {
649656 return 1 ;
650657 };
651658
659+ auto on_prediction_arg = [&](int argc, const char ** argv, int index) {
660+ if (++index >= argc) {
661+ return -1 ;
662+ }
663+ const char * arg = argv[index];
664+ params.prediction = str_to_prediction (arg);
665+ if (params.prediction == PREDICTION_COUNT) {
666+ fprintf (stderr, " error: invalid prediction type %s\n " ,
667+ arg);
668+ return -1 ;
669+ }
670+ return 1 ;
671+ };
672+
652673 auto on_sample_method_arg = [&](int argc, const char ** argv, int index) {
653674 if (++index >= argc) {
654675 return -1 ;
@@ -805,6 +826,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
805826 {" " , " --rng" , " " , on_rng_arg},
806827 {" -s" , " --seed" , " " , on_seed_arg},
807828 {" " , " --sampling-method" , " " , on_sample_method_arg},
829+ {" " , " --prediction" , " " , on_prediction_arg},
808830 {" " , " --scheduler" , " " , on_schedule_arg},
809831 {" " , " --skip-layers" , " " , on_skip_layers_arg},
810832 {" " , " --high-noise-sampling-method" , " " , on_high_noise_sample_method_arg},
@@ -825,13 +847,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
825847 params.n_threads = sd_get_num_physical_cores ();
826848 }
827849
828- if (params.mode != CONVERT && params.mode != VID_GEN && params.prompt .length () == 0 ) {
850+ if (( params.mode == IMG_GEN || params.mode == VID_GEN) && params.prompt .length () == 0 ) {
829851 fprintf (stderr, " error: the following arguments are required: prompt\n " );
830852 print_usage (argc, argv);
831853 exit (1 );
832854 }
833855
834- if (params.model_path .length () == 0 && params.diffusion_model_path .length () == 0 ) {
856+ if (params.mode != UPSCALE && params. model_path .length () == 0 && params.diffusion_model_path .length () == 0 ) {
835857 fprintf (stderr, " error: the following arguments are required: model_path/diffusion_model\n " );
836858 print_usage (argc, argv);
837859 exit (1 );
@@ -891,6 +913,17 @@ void parse_args(int argc, const char** argv, SDParams& params) {
891913 exit (1 );
892914 }
893915
916+ if (params.mode == UPSCALE) {
917+ if (params.esrgan_path .length () == 0 ) {
918+ fprintf (stderr, " error: upscale mode needs an upscaler model (--upscale-model)\n " );
919+ exit (1 );
920+ }
921+ if (params.init_image_path .length () == 0 ) {
922+ fprintf (stderr, " error: upscale mode needs an init image (--init-img)\n " );
923+ exit (1 );
924+ }
925+ }
926+
894927 if (params.seed < 0 ) {
895928 srand ((int )time (NULL ));
896929 params.seed = rand ();
@@ -901,14 +934,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
901934 params.output_path = " output.gguf" ;
902935 }
903936 }
904-
905- if (!isfinite (params.sample_params .guidance .img_cfg )) {
906- params.sample_params .guidance .img_cfg = params.sample_params .guidance .txt_cfg ;
907- }
908-
909- if (!isfinite (params.high_noise_sample_params .guidance .img_cfg )) {
910- params.high_noise_sample_params .guidance .img_cfg = params.high_noise_sample_params .guidance .txt_cfg ;
911- }
912937}
913938
914939static std::string sd_basename (const std::string& path) {
@@ -1349,6 +1374,7 @@ int main(int argc, const char* argv[]) {
13491374 params.n_threads ,
13501375 params.wtype ,
13511376 params.rng_type ,
1377+ params.prediction ,
13521378 params.offload_params_to_cpu ,
13531379 params.clip_on_cpu ,
13541380 params.control_net_cpu ,
@@ -1362,76 +1388,92 @@ int main(int argc, const char* argv[]) {
13621388 params.flow_shift ,
13631389 };
13641390
1365- sd_ctx_t * sd_ctx = new_sd_ctx (&sd_ctx_params);
1391+ sd_image_t * results = nullptr ;
1392+ int num_results = 0 ;
13661393
1367- if (sd_ctx == NULL ) {
1368- printf (" new_sd_ctx_t failed\n " );
1369- release_all_resources ();
1370- return 1 ;
1371- }
1394+ if (params.mode == UPSCALE) {
1395+ num_results = 1 ;
1396+ results = (sd_image_t *)calloc (num_results, sizeof (sd_image_t ));
1397+ if (results == NULL ) {
1398+ printf (" failed to allocate results array\n " );
1399+ release_all_resources ();
1400+ return 1 ;
1401+ }
13721402
1373- if (params.sample_params .sample_method == SAMPLE_METHOD_DEFAULT) {
1374- params.sample_params .sample_method = sd_get_default_sample_method (sd_ctx);
1375- }
1403+ results[0 ] = init_image;
1404+ init_image.data = NULL ;
1405+ } else {
1406+ sd_ctx_t * sd_ctx = new_sd_ctx (&sd_ctx_params);
13761407
1377- sd_image_t * results;
1378- int num_results = 1 ;
1379- if (params.mode == IMG_GEN) {
1380- sd_img_gen_params_t img_gen_params = {
1381- params.prompt .c_str (),
1382- params.negative_prompt .c_str (),
1383- params.clip_skip ,
1384- init_image,
1385- ref_images.data (),
1386- (int )ref_images.size (),
1387- params.increase_ref_index ,
1388- mask_image,
1389- params.width ,
1390- params.height ,
1391- params.sample_params ,
1392- params.strength ,
1393- params.seed ,
1394- params.batch_count ,
1395- control_image,
1396- params.control_strength ,
1397- {
1398- pmid_images.data (),
1399- (int )pmid_images.size (),
1400- params.pm_id_embed_path .c_str (),
1401- params.pm_style_strength ,
1402- }, // pm_params
1403- params.vae_tiling_params ,
1404- };
1405-
1406- results = generate_image (sd_ctx, &img_gen_params);
1407- num_results = params.batch_count ;
1408- } else if (params.mode == VID_GEN) {
1409- sd_vid_gen_params_t vid_gen_params = {
1410- params.prompt .c_str (),
1411- params.negative_prompt .c_str (),
1412- params.clip_skip ,
1413- init_image,
1414- end_image,
1415- control_frames.data (),
1416- (int )control_frames.size (),
1417- params.width ,
1418- params.height ,
1419- params.sample_params ,
1420- params.high_noise_sample_params ,
1421- params.moe_boundary ,
1422- params.strength ,
1423- params.seed ,
1424- params.video_frames ,
1425- params.vace_strength ,
1426- };
1427-
1428- results = generate_video (sd_ctx, &vid_gen_params, &num_results);
1429- }
1408+ if (sd_ctx == NULL ) {
1409+ printf (" new_sd_ctx_t failed\n " );
1410+ release_all_resources ();
1411+ return 1 ;
1412+ }
1413+
1414+ if (params.sample_params .sample_method == SAMPLE_METHOD_DEFAULT) {
1415+ params.sample_params .sample_method = sd_get_default_sample_method (sd_ctx);
1416+ }
1417+
1418+ if (params.mode == IMG_GEN) {
1419+ sd_img_gen_params_t img_gen_params = {
1420+ params.prompt .c_str (),
1421+ params.negative_prompt .c_str (),
1422+ params.clip_skip ,
1423+ init_image,
1424+ ref_images.data (),
1425+ (int )ref_images.size (),
1426+ params.increase_ref_index ,
1427+ mask_image,
1428+ params.width ,
1429+ params.height ,
1430+ params.sample_params ,
1431+ params.strength ,
1432+ params.seed ,
1433+ params.batch_count ,
1434+ control_image,
1435+ params.control_strength ,
1436+ {
1437+ pmid_images.data (),
1438+ (int )pmid_images.size (),
1439+ params.pm_id_embed_path .c_str (),
1440+ params.pm_style_strength ,
1441+ }, // pm_params
1442+ params.vae_tiling_params ,
1443+ };
1444+
1445+ results = generate_image (sd_ctx, &img_gen_params);
1446+ num_results = params.batch_count ;
1447+ } else if (params.mode == VID_GEN) {
1448+ sd_vid_gen_params_t vid_gen_params = {
1449+ params.prompt .c_str (),
1450+ params.negative_prompt .c_str (),
1451+ params.clip_skip ,
1452+ init_image,
1453+ end_image,
1454+ control_frames.data (),
1455+ (int )control_frames.size (),
1456+ params.width ,
1457+ params.height ,
1458+ params.sample_params ,
1459+ params.high_noise_sample_params ,
1460+ params.moe_boundary ,
1461+ params.strength ,
1462+ params.seed ,
1463+ params.video_frames ,
1464+ params.vace_strength ,
1465+ };
1466+
1467+ results = generate_video (sd_ctx, &vid_gen_params, &num_results);
1468+ }
1469+
1470+ if (results == NULL ) {
1471+ printf (" generate failed\n " );
1472+ free_sd_ctx (sd_ctx);
1473+ return 1 ;
1474+ }
14301475
1431- if (results == NULL ) {
1432- printf (" generate failed\n " );
14331476 free_sd_ctx (sd_ctx);
1434- return 1 ;
14351477 }
14361478
14371479 int upscale_factor = 4 ; // unused for RealESRGAN_x4plus_anime_6B.pth
@@ -1444,7 +1486,7 @@ int main(int argc, const char* argv[]) {
14441486 if (upscaler_ctx == NULL ) {
14451487 printf (" new_upscaler_ctx failed\n " );
14461488 } else {
1447- for (int i = 0 ; i < params. batch_count ; i++) {
1489+ for (int i = 0 ; i < num_results ; i++) {
14481490 if (results[i].data == NULL ) {
14491491 continue ;
14501492 }
@@ -1530,7 +1572,6 @@ int main(int argc, const char* argv[]) {
15301572 results[i].data = NULL ;
15311573 }
15321574 free (results);
1533- free_sd_ctx (sd_ctx);
15341575
15351576 release_all_resources ();
15361577
0 commit comments