@@ -57,13 +57,15 @@ const char* modes_str[] = {
5757 " txt2img" ,
5858 " img2img" ,
5959 " img2vid" ,
60+ " edit" ,
6061 " convert" ,
6162};
6263
// Operating modes of the CLI. The enumerators appear in the same order as the
// string entries of modes_str (txt2img, img2img, img2vid, edit, convert);
// presumably a mode name's index in modes_str is used as its SDMode value, so
// the two lists must stay in sync — confirm against the mode-parsing code.
6364enum SDMode {
6465 TXT2IMG,
6566 IMG2IMG,
6667 IMG2VID,
// EDIT: generation driven by reference images; parse_args rejects this mode
// when no reference image path was supplied.
68+ EDIT,
6769 CONVERT,
// MODE_COUNT is the last enumerator and no enumerator has an explicit value,
// so it equals the number of real modes above. Keep it last.
6870 MODE_COUNT
6971};
@@ -89,6 +91,7 @@ struct SDParams {
8991 std::string input_path;
9092 std::string mask_path;
9193 std::string control_image_path;
94+ std::vector<std::string> ref_image_paths;
9295
9396 std::string prompt;
9497 std::string negative_prompt;
@@ -154,6 +157,10 @@ void print_params(SDParams params) {
154157 printf (" init_img: %s\n " , params.input_path .c_str ());
155158 printf (" mask_img: %s\n " , params.mask_path .c_str ());
156159 printf (" control_image: %s\n " , params.control_image_path .c_str ());
160+ printf (" ref_images_paths:\n " );
161+ for (auto & path : params.ref_image_paths ) {
162+ printf (" %s\n " , path.c_str ());
163+ };
157164 printf (" clip on cpu: %s\n " , params.clip_on_cpu ? " true" : " false" );
158165 printf (" controlnet cpu: %s\n " , params.control_net_cpu ? " true" : " false" );
159166 printf (" vae decoder on cpu:%s\n " , params.vae_on_cpu ? " true" : " false" );
@@ -208,6 +215,7 @@ void print_usage(int argc, const char* argv[]) {
208215 printf (" -i, --init-img [IMAGE] path to the input image, required by img2img\n " );
209216 printf (" --mask [MASK] path to the mask image, required by img2img with mask\n " );
210217 printf (" --control-image [IMAGE] path to image condition, control net\n " );
// The long flag is spelled with a hyphen ("--ref-image"): that is the only
// long spelling parse_args accepts, and it matches the other hyphenated options.
218+ printf (" -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) \n " );
211219 printf (" -o, --output OUTPUT path to write result image to (default: ./output.png)\n " );
212220 printf (" -p, --prompt [PROMPT] the prompt to render\n " );
213221 printf (" -n, --negative-prompt PROMPT the negative prompt (default: \"\" )\n " );
@@ -243,7 +251,7 @@ void print_usage(int argc, const char* argv[]) {
243251 printf (" This might crash if it is not supported by the backend.\n " );
244252 printf (" --control-net-cpu keep controlnet in cpu (for low vram)\n " );
245253 printf (" --canny apply canny preprocessor (edge detection)\n " );
246- printf (" --color Colors the logging tags according to level\n " );
254+ printf (" --color colors the logging tags according to level\n " );
247255 printf (" -v, --verbose print extra info\n " );
248256}
249257
@@ -629,6 +637,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
629637 break ;
630638 }
631639 params.skip_layer_end = std::stof (argv[i]);
640+ } else if (arg == " -r" || arg == " --ref-image" ) {
641+ if (++i >= argc) {
642+ invalid_arg = true ;
643+ break ;
644+ }
645+ params.ref_image_paths .push_back (argv[i]);
632646 } else {
633647 fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
634648 print_usage (argc, argv);
@@ -657,7 +671,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
657671 }
658672
659673 if ((params.mode == IMG2IMG || params.mode == IMG2VID) && params.input_path .length () == 0 ) {
660- fprintf (stderr, " error: when using the img2img mode, the following arguments are required: init-img\n " );
674+ fprintf (stderr, " error: when using the img2img/img2vid mode, the following arguments are required: init-img\n " );
675+ print_usage (argc, argv);
676+ exit (1 );
677+ }
678+
679+ if (params.mode == EDIT && params.ref_image_paths .size () == 0 ) {
680+ fprintf (stderr, " error: when using the edit mode, the following arguments are required: ref-image\n " );
661681 print_usage (argc, argv);
662682 exit (1 );
663683 }
@@ -826,6 +846,7 @@ int main(int argc, const char* argv[]) {
826846 uint8_t * input_image_buffer = NULL ;
827847 uint8_t * control_image_buffer = NULL ;
828848 uint8_t * mask_image_buffer = NULL ;
849+ std::vector<sd_image_t > ref_images;
829850
830851 if (params.mode == IMG2IMG || params.mode == IMG2VID) {
831852 vae_decode_only = false ;
@@ -877,6 +898,37 @@ int main(int argc, const char* argv[]) {
877898 free (input_image_buffer);
878899 input_image_buffer = resized_image_buffer;
879900 }
901+ } else if (params.mode == EDIT) {
902+ vae_decode_only = false ;
903+ for (auto & path : params.ref_image_paths ) {
904+ int c = 0 ;
905+ int width = 0 ;
906+ int height = 0 ;
907+ uint8_t * image_buffer = stbi_load (path.c_str (), &width, &height, &c, 3 );
908+ if (image_buffer == NULL ) {
909+ fprintf (stderr, " load image from '%s' failed\n " , path.c_str ());
910+ return 1 ;
911+ }
912+ if (c < 3 ) {
913+ fprintf (stderr, " the number of channels for the input image must be >= 3, but got %d channels\n " , c);
914+ free (image_buffer);
915+ return 1 ;
916+ }
917+ if (width <= 0 ) {
918+ fprintf (stderr, " error: the width of image must be greater than 0\n " );
919+ free (image_buffer);
920+ return 1 ;
921+ }
922+ if (height <= 0 ) {
923+ fprintf (stderr, " error: the height of image must be greater than 0\n " );
924+ free (image_buffer);
925+ return 1 ;
926+ }
927+ ref_images.push_back ({(uint32_t )width,
928+ (uint32_t )height,
929+ 3 ,
930+ image_buffer});
931+ }
880932 }
881933
882934 sd_ctx_t * sd_ctx = new_sd_ctx (params.model_path .c_str (),
@@ -968,7 +1020,7 @@ int main(int argc, const char* argv[]) {
9681020 params.slg_scale ,
9691021 params.skip_layer_start ,
9701022 params.skip_layer_end );
971- } else {
1023+ } else if (params. mode == IMG2IMG || params. mode == IMG2VID) {
9721024 sd_image_t input_image = {(uint32_t )params.width ,
9731025 (uint32_t )params.height ,
9741026 3 ,
@@ -1038,6 +1090,32 @@ int main(int argc, const char* argv[]) {
10381090 params.skip_layer_start ,
10391091 params.skip_layer_end );
10401092 }
1093+ } else { // EDIT
1094+ results = edit (sd_ctx,
1095+ ref_images.data (),
1096+ ref_images.size (),
1097+ params.prompt .c_str (),
1098+ params.negative_prompt .c_str (),
1099+ params.clip_skip ,
1100+ params.cfg_scale ,
1101+ params.guidance ,
1102+ params.eta ,
1103+ params.width ,
1104+ params.height ,
1105+ params.sample_method ,
1106+ params.sample_steps ,
1107+ params.strength ,
1108+ params.seed ,
1109+ params.batch_count ,
1110+ control_image,
1111+ params.control_strength ,
1112+ params.style_ratio ,
1113+ params.normalize_input ,
1114+ params.skip_layers .data (),
1115+ params.skip_layers .size (),
1116+ params.slg_scale ,
1117+ params.skip_layer_start ,
1118+ params.skip_layer_end );
10411119 }
10421120
10431121 if (results == NULL ) {
@@ -1117,4 +1195,4 @@ int main(int argc, const char* argv[]) {
11171195 free (input_image_buffer);
11181196
11191197 return 0 ;
1120- }
1198+ }
0 commit comments