Skip to content

Commit fbf6563

Browse files
committed
Refactor latent preview + support tae/vae preview
1 parent 6fdc230 commit fbf6563

File tree

4 files changed

+247
-96
lines changed

4 files changed

+247
-96
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ test/
1111
output*.png
1212
models*
1313
*.log
14-
latent-preview.png
14+
preview.png

examples/cli/main.cpp

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
#include "flux.hpp"
1111
#include "stable-diffusion.h"
1212

13-
#include "latent-preview.h"
14-
1513
#define STB_IMAGE_IMPLEMENTATION
1614
#define STB_IMAGE_STATIC
1715
#include "stb_image.h"
@@ -60,6 +58,13 @@ const char* modes_str[] = {
6058
"convert",
6159
};
6260

61+
const char* previews_str[] = {
62+
"none",
63+
"proj",
64+
"tae",
65+
"vae",
66+
};
67+
6368
enum SDMode {
6469
TXT2IMG,
6570
IMG2IMG,
@@ -127,6 +132,11 @@ struct SDParams {
127132
float slg_scale = 0.;
128133
float skip_layer_start = 0.01;
129134
float skip_layer_end = 0.2;
135+
136+
sd_preview_policy_t preview_method = SD_PREVIEW_NONE;
137+
int preview_interval = 1;
138+
std::string preview_path = "preview.png";
139+
bool taesd_preview = false;
130140
};
131141

132142
void print_params(SDParams params) {
@@ -488,6 +498,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
488498
params.diffusion_flash_attn = true; // can reduce MEM significantly
489499
} else if (arg == "--canny") {
490500
params.canny_preprocess = true;
501+
} else if (arg == "--taesd-preview-only") {
502+
params.taesd_preview = true;
491503
} else if (arg == "-b" || arg == "--batch-count") {
492504
if (++i >= argc) {
493505
invalid_arg = true;
@@ -610,6 +622,35 @@ void parse_args(int argc, const char** argv, SDParams& params) {
610622
break;
611623
}
612624
params.skip_layer_end = std::stof(argv[i]);
625+
} else if (arg == "--preview") {
626+
if (++i >= argc) {
627+
invalid_arg = true;
628+
break;
629+
}
630+
const char* preview = argv[i];
631+
int preview_method = -1;
632+
for (int m = 0; m < N_PREVIEWS; m++) {
633+
if (!strcmp(preview, previews_str[m])) {
634+
preview_method = m;
635+
}
636+
}
637+
if (preview_method == -1) {
638+
invalid_arg = true;
639+
break;
640+
}
641+
params.preview_method = (sd_preview_policy_t)preview_method;
642+
} else if (arg == "--preview-interval") {
643+
if (++i >= argc) {
644+
invalid_arg = true;
645+
break;
646+
}
647+
params.preview_interval = std::stoi(argv[i]);
648+
} else if (arg == "--preview-path") {
649+
if (++i >= argc) {
650+
invalid_arg = true;
651+
break;
652+
}
653+
params.preview_path = argv[i];
613654
} else {
614655
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
615656
print_usage(argc, argv);
@@ -767,52 +808,17 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
767808
fflush(out_stream);
768809
}
769810

770-
void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version) {
771-
const int channel = 3;
772-
int width = latents->ne[0];
773-
int height = latents->ne[1];
774-
int dim = latents->ne[2];
775-
776-
const float(*latent_rgb_proj)[channel];
777-
778-
if (dim == 16) {
779-
// 16 channels VAE -> Flux or SD3
780-
781-
if (sd_version_is_sd3(version)) {
782-
latent_rgb_proj = sd3_latent_rgb_proj;
783-
} else if (sd_version_is_flux(version)) {
784-
latent_rgb_proj = flux_latent_rgb_proj;
785-
} else {
786-
// unknown model
787-
return;
788-
}
811+
const char* preview_path;
789812

790-
} else if (dim == 4) {
791-
// 4 channels VAE
792-
if (version == VERSION_SDXL) {
793-
latent_rgb_proj = sdxl_latent_rgb_proj;
794-
} else if (version == VERSION_SD1 || version == VERSION_SD2) {
795-
latent_rgb_proj = sd_latent_rgb_proj;
796-
} else {
797-
// unknown model
798-
return;
799-
}
800-
} else {
801-
// unknown latent space
802-
return;
803-
}
804-
uint8_t* data = (uint8_t*)malloc(width * height * channel * sizeof(uint8_t));
805-
806-
preview_latent_image(data, latents, latent_rgb_proj, width, height, dim);
807-
808-
stbi_write_png("latent-preview.png", width, height, channel, data, 0);
809-
free(data);
813+
void step_callback(int step, sd_image_t image) {
814+
stbi_write_png(preview_path, image.width, image.height, image.channel, image.data, 0);
810815
}
811816

812817
int main(int argc, const char* argv[]) {
813818
SDParams params;
814819

815820
parse_args(argc, argv, params);
821+
preview_path = params.preview_path.c_str();
816822

817823
sd_set_log_callback(sd_log_cb, (void*)&params);
818824

@@ -920,7 +926,8 @@ int main(int argc, const char* argv[]) {
920926
params.clip_on_cpu,
921927
params.control_net_cpu,
922928
params.vae_on_cpu,
923-
params.diffusion_flash_attn);
929+
params.diffusion_flash_attn,
930+
params.taesd_preview);
924931

925932
if (sd_ctx == NULL) {
926933
printf("new_sd_ctx_t failed\n");
@@ -975,6 +982,8 @@ int main(int argc, const char* argv[]) {
975982
params.slg_scale,
976983
params.skip_layer_start,
977984
params.skip_layer_end,
985+
params.preview_method,
986+
params.preview_interval,
978987
(step_callback_t)step_callback);
979988
} else {
980989
sd_image_t input_image = {(uint32_t)params.width,
@@ -996,8 +1005,7 @@ int main(int argc, const char* argv[]) {
9961005
params.sample_method,
9971006
params.sample_steps,
9981007
params.strength,
999-
params.seed,
1000-
(step_callback_t)step_callback);
1008+
params.seed);
10011009
if (results == NULL) {
10021010
printf("generate failed\n");
10031011
free_sd_ctx(sd_ctx);
@@ -1044,6 +1052,8 @@ int main(int argc, const char* argv[]) {
10441052
params.slg_scale,
10451053
params.skip_layer_start,
10461054
params.skip_layer_end,
1055+
params.preview_method,
1056+
params.preview_interval,
10471057
(step_callback_t)step_callback);
10481058
}
10491059
}

0 commit comments

Comments
 (0)