Skip to content

Commit 8967889

Browse files
committed
add edit mode
1 parent c3d50c7 commit 8967889

File tree

5 files changed

+400
-213
lines changed

5 files changed

+400
-213
lines changed

diffusion_model.hpp

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ struct DiffusionModel {
1313
struct ggml_tensor* c_concat,
1414
struct ggml_tensor* y,
1515
struct ggml_tensor* guidance,
16-
int num_video_frames = -1,
17-
std::vector<struct ggml_tensor*> controls = {},
18-
float control_strength = 0.f,
19-
std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
20-
struct ggml_tensor** output = NULL,
21-
struct ggml_context* output_ctx = NULL,
22-
std::vector<int> skip_layers = std::vector<int>()) = 0;
16+
std::vector<ggml_tensor*> ref_latents = {},
17+
int num_video_frames = -1,
18+
std::vector<struct ggml_tensor*> controls = {},
19+
float control_strength = 0.f,
20+
struct ggml_tensor** output = NULL,
21+
struct ggml_context* output_ctx = NULL,
22+
std::vector<int> skip_layers = std::vector<int>()) = 0;
2323
virtual void alloc_params_buffer() = 0;
2424
virtual void free_params_buffer() = 0;
2525
virtual void free_compute_buffer() = 0;
@@ -69,13 +69,13 @@ struct UNetModel : public DiffusionModel {
6969
struct ggml_tensor* c_concat,
7070
struct ggml_tensor* y,
7171
struct ggml_tensor* guidance,
72-
int num_video_frames = -1,
73-
std::vector<struct ggml_tensor*> controls = {},
74-
float control_strength = 0.f,
75-
std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
76-
struct ggml_tensor** output = NULL,
77-
struct ggml_context* output_ctx = NULL,
78-
std::vector<int> skip_layers = std::vector<int>()) {
72+
std::vector<ggml_tensor*> ref_latents = {},
73+
int num_video_frames = -1,
74+
std::vector<struct ggml_tensor*> controls = {},
75+
float control_strength = 0.f,
76+
struct ggml_tensor** output = NULL,
77+
struct ggml_context* output_ctx = NULL,
78+
std::vector<int> skip_layers = std::vector<int>()) {
7979
(void)skip_layers; // SLG doesn't work with UNet models
8080
return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
8181
}
@@ -120,13 +120,13 @@ struct MMDiTModel : public DiffusionModel {
120120
struct ggml_tensor* c_concat,
121121
struct ggml_tensor* y,
122122
struct ggml_tensor* guidance,
123-
int num_video_frames = -1,
124-
std::vector<struct ggml_tensor*> controls = {},
125-
float control_strength = 0.f,
126-
std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
127-
struct ggml_tensor** output = NULL,
128-
struct ggml_context* output_ctx = NULL,
129-
std::vector<int> skip_layers = std::vector<int>()) {
123+
std::vector<ggml_tensor*> ref_latents = {},
124+
int num_video_frames = -1,
125+
std::vector<struct ggml_tensor*> controls = {},
126+
float control_strength = 0.f,
127+
struct ggml_tensor** output = NULL,
128+
struct ggml_context* output_ctx = NULL,
129+
std::vector<int> skip_layers = std::vector<int>()) {
130130
return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers);
131131
}
132132
};
@@ -172,14 +172,14 @@ struct FluxModel : public DiffusionModel {
172172
struct ggml_tensor* c_concat,
173173
struct ggml_tensor* y,
174174
struct ggml_tensor* guidance,
175-
int num_video_frames = -1,
176-
std::vector<struct ggml_tensor*> controls = {},
177-
float control_strength = 0.f,
178-
std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
179-
struct ggml_tensor** output = NULL,
180-
struct ggml_context* output_ctx = NULL,
181-
std::vector<int> skip_layers = std::vector<int>()) {
182-
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, kontext_imgs, output, output_ctx, skip_layers);
175+
std::vector<ggml_tensor*> ref_latents = {},
176+
int num_video_frames = -1,
177+
std::vector<struct ggml_tensor*> controls = {},
178+
float control_strength = 0.f,
179+
struct ggml_tensor** output = NULL,
180+
struct ggml_context* output_ctx = NULL,
181+
std::vector<int> skip_layers = std::vector<int>()) {
182+
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, ref_latents, output, output_ctx, skip_layers);
183183
}
184184
};
185185

examples/cli/main.cpp

Lines changed: 84 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,15 @@ const char* modes_str[] = {
5757
"txt2img",
5858
"img2img",
5959
"img2vid",
60+
"edit",
6061
"convert",
6162
};
6263

6364
enum SDMode {
6465
TXT2IMG,
6566
IMG2IMG,
6667
IMG2VID,
68+
EDIT,
6769
CONVERT,
6870
MODE_COUNT
6971
};
@@ -89,8 +91,7 @@ struct SDParams {
8991
std::string input_path;
9092
std::string mask_path;
9193
std::string control_image_path;
92-
93-
std::vector<std::string> kontext_image_paths;
94+
std::vector<std::string> ref_image_paths;
9495

9596
std::string prompt;
9697
std::string negative_prompt;
@@ -156,6 +157,10 @@ void print_params(SDParams params) {
156157
printf(" init_img: %s\n", params.input_path.c_str());
157158
printf(" mask_img: %s\n", params.mask_path.c_str());
158159
printf(" control_image: %s\n", params.control_image_path.c_str());
160+
printf(" ref_images_paths:\n");
161+
for (auto& path : params.ref_image_paths) {
162+
printf(" %s\n", path.c_str());
163+
};
159164
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
160165
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
161166
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
@@ -210,6 +215,7 @@ void print_usage(int argc, const char* argv[]) {
210215
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
211216
printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
212217
printf(" --control-image [IMAGE] path to image condition, control net\n");
218+
printf(" -r, --ref_image [PATH] reference image for Flux Kontext models (can be used multiple times) \n");
213219
printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
214220
printf(" -p, --prompt [PROMPT] the prompt to render\n");
215221
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
@@ -245,9 +251,8 @@ void print_usage(int argc, const char* argv[]) {
245251
printf(" This might crash if it is not supported by the backend.\n");
246252
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
247253
printf(" --canny apply canny preprocessor (edge detection)\n");
248-
printf(" --color Colors the logging tags according to level\n");
254+
printf(" --color colors the logging tags according to level\n");
249255
printf(" -v, --verbose print extra info\n");
250-
printf(" -ki, --kontext_img [PATH] Reference image for Flux Kontext models (can be used multiple times) \n");
251256
}
252257

253258
void parse_args(int argc, const char** argv, SDParams& params) {
@@ -632,12 +637,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
632637
break;
633638
}
634639
params.skip_layer_end = std::stof(argv[i]);
635-
} else if (arg == "-ki" || arg == "--kontext-img") {
640+
} else if (arg == "-r" || arg == "--ref-image") {
636641
if (++i >= argc) {
637642
invalid_arg = true;
638643
break;
639644
}
640-
params.kontext_image_paths.push_back(argv[i]);
645+
params.ref_image_paths.push_back(argv[i]);
641646
} else {
642647
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
643648
print_usage(argc, argv);
@@ -666,7 +671,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
666671
}
667672

668673
if ((params.mode == IMG2IMG || params.mode == IMG2VID) && params.input_path.length() == 0) {
669-
fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n");
674+
fprintf(stderr, "error: when using the img2img/img2vid mode, the following arguments are required: init-img\n");
675+
print_usage(argc, argv);
676+
exit(1);
677+
}
678+
679+
if (params.mode == EDIT && params.ref_image_paths.size() == 0) {
680+
fprintf(stderr, "error: when using the edit mode, the following arguments are required: ref-image\n");
670681
print_usage(argc, argv);
671682
exit(1);
672683
}
@@ -830,43 +841,12 @@ int main(int argc, const char* argv[]) {
830841
fprintf(stderr, "SVD support is broken, do not use it!!!\n");
831842
return 1;
832843
}
833-
bool vae_decode_only = true;
834-
835-
std::vector<sd_image_t> kontext_imgs;
836-
for (auto& path : params.kontext_image_paths) {
837-
vae_decode_only = false;
838-
int c = 0;
839-
int width = 0;
840-
int height = 0;
841-
uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3);
842-
if (image_buffer == NULL) {
843-
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
844-
return 1;
845-
}
846-
if (c < 3) {
847-
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
848-
free(image_buffer);
849-
return 1;
850-
}
851-
if (width <= 0) {
852-
fprintf(stderr, "error: the width of image must be greater than 0\n");
853-
free(image_buffer);
854-
return 1;
855-
}
856-
if (height <= 0) {
857-
fprintf(stderr, "error: the height of image must be greater than 0\n");
858-
free(image_buffer);
859-
return 1;
860-
}
861-
kontext_imgs.push_back({(uint32_t)width,
862-
(uint32_t)height,
863-
3,
864-
image_buffer});
865-
}
866844

845+
bool vae_decode_only = true;
867846
uint8_t* input_image_buffer = NULL;
868847
uint8_t* control_image_buffer = NULL;
869848
uint8_t* mask_image_buffer = NULL;
849+
std::vector<sd_image_t> ref_images;
870850

871851
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
872852
vae_decode_only = false;
@@ -918,6 +898,37 @@ int main(int argc, const char* argv[]) {
918898
free(input_image_buffer);
919899
input_image_buffer = resized_image_buffer;
920900
}
901+
} else if (params.mode == EDIT) {
902+
vae_decode_only = false;
903+
for (auto& path : params.ref_image_paths) {
904+
int c = 0;
905+
int width = 0;
906+
int height = 0;
907+
uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3);
908+
if (image_buffer == NULL) {
909+
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
910+
return 1;
911+
}
912+
if (c < 3) {
913+
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
914+
free(image_buffer);
915+
return 1;
916+
}
917+
if (width <= 0) {
918+
fprintf(stderr, "error: the width of image must be greater than 0\n");
919+
free(image_buffer);
920+
return 1;
921+
}
922+
if (height <= 0) {
923+
fprintf(stderr, "error: the height of image must be greater than 0\n");
924+
free(image_buffer);
925+
return 1;
926+
}
927+
ref_images.push_back({(uint32_t)width,
928+
(uint32_t)height,
929+
3,
930+
image_buffer});
931+
}
921932
}
922933

923934
sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
@@ -1004,13 +1015,12 @@ int main(int argc, const char* argv[]) {
10041015
params.style_ratio,
10051016
params.normalize_input,
10061017
params.input_id_images_path.c_str(),
1007-
kontext_imgs.data(), kontext_imgs.size(),
10081018
params.skip_layers.data(),
10091019
params.skip_layers.size(),
10101020
params.slg_scale,
10111021
params.skip_layer_start,
10121022
params.skip_layer_end);
1013-
} else {
1023+
} else if (params.mode == IMG2IMG || params.mode == IMG2VID) {
10141024
sd_image_t input_image = {(uint32_t)params.width,
10151025
(uint32_t)params.height,
10161026
3,
@@ -1074,13 +1084,38 @@ int main(int argc, const char* argv[]) {
10741084
params.style_ratio,
10751085
params.normalize_input,
10761086
params.input_id_images_path.c_str(),
1077-
kontext_imgs.data(), kontext_imgs.size(),
10781087
params.skip_layers.data(),
10791088
params.skip_layers.size(),
10801089
params.slg_scale,
10811090
params.skip_layer_start,
10821091
params.skip_layer_end);
10831092
}
1093+
} else { // EDIT
1094+
results = edit(sd_ctx,
1095+
ref_images.data(),
1096+
ref_images.size(),
1097+
params.prompt.c_str(),
1098+
params.negative_prompt.c_str(),
1099+
params.clip_skip,
1100+
params.cfg_scale,
1101+
params.guidance,
1102+
params.eta,
1103+
params.width,
1104+
params.height,
1105+
params.sample_method,
1106+
params.sample_steps,
1107+
params.strength,
1108+
params.seed,
1109+
params.batch_count,
1110+
control_image,
1111+
params.control_strength,
1112+
params.style_ratio,
1113+
params.normalize_input,
1114+
params.skip_layers.data(),
1115+
params.skip_layers.size(),
1116+
params.slg_scale,
1117+
params.skip_layer_start,
1118+
params.skip_layer_end);
10841119
}
10851120

10861121
if (results == NULL) {
@@ -1118,19 +1153,19 @@ int main(int argc, const char* argv[]) {
11181153

11191154
std::string dummy_name, ext, lc_ext;
11201155
bool is_jpg;
1121-
size_t last = params.output_path.find_last_of(".");
1156+
size_t last = params.output_path.find_last_of(".");
11221157
size_t last_path = std::min(params.output_path.find_last_of("/"),
11231158
params.output_path.find_last_of("\\"));
1124-
if (last != std::string::npos // filename has extension
1125-
&& (last_path == std::string::npos || last > last_path)) {
1159+
if (last != std::string::npos // filename has extension
1160+
&& (last_path == std::string::npos || last > last_path)) {
11261161
dummy_name = params.output_path.substr(0, last);
11271162
ext = lc_ext = params.output_path.substr(last);
11281163
std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower);
11291164
is_jpg = lc_ext == ".jpg" || lc_ext == ".jpeg" || lc_ext == ".jpe";
11301165
} else {
11311166
dummy_name = params.output_path;
11321167
ext = lc_ext = "";
1133-
is_jpg = false;
1168+
is_jpg = false;
11341169
}
11351170
// appending ".png" to absent or unknown extension
11361171
if (!is_jpg && lc_ext != ".png") {
@@ -1142,7 +1177,7 @@ int main(int argc, const char* argv[]) {
11421177
continue;
11431178
}
11441179
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext;
1145-
if (is_jpg) {
1180+
if(is_jpg) {
11461181
stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
11471182
results[i].data, 90, get_image_params(params, params.seed + i).c_str());
11481183
printf("save result JPEG image to '%s'\n", final_image_path.c_str());
@@ -1160,4 +1195,4 @@ int main(int argc, const char* argv[]) {
11601195
free(input_image_buffer);
11611196

11621197
return 0;
1163-
}
1198+
}

0 commit comments

Comments
 (0)