1010#include " flux.hpp"
1111#include " stable-diffusion.h"
1212
13- #include " latent-preview.h"
14-
1513#define STB_IMAGE_IMPLEMENTATION
1614#define STB_IMAGE_STATIC
1715#include " stb_image.h"
@@ -60,6 +58,13 @@ const char* modes_str[] = {
6058 " convert" ,
6159};
6260
61+ const char * previews_str[] = {
62+ " none" ,
63+ " proj" ,
64+ " tae" ,
65+ " vae" ,
66+ };
67+
6368enum SDMode {
6469 TXT2IMG,
6570 IMG2IMG,
@@ -127,6 +132,11 @@ struct SDParams {
127132 float slg_scale = 0 .;
128133 float skip_layer_start = 0.01 ;
129134 float skip_layer_end = 0.2 ;
135+
136+ sd_preview_policy_t preview_method = SD_PREVIEW_NONE;
137+ int preview_interval = 1 ;
138+ std::string preview_path = " preview.png" ;
139+ bool taesd_preview = false ;
130140};
131141
132142void print_params (SDParams params) {
@@ -488,6 +498,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
488498 params.diffusion_flash_attn = true ; // can reduce MEM significantly
489499 } else if (arg == " --canny" ) {
490500 params.canny_preprocess = true ;
501+ } else if (arg == " --taesd-preview-only" ) {
502+ params.taesd_preview = true ;
491503 } else if (arg == " -b" || arg == " --batch-count" ) {
492504 if (++i >= argc) {
493505 invalid_arg = true ;
@@ -610,6 +622,35 @@ void parse_args(int argc, const char** argv, SDParams& params) {
610622 break ;
611623 }
612624 params.skip_layer_end = std::stof (argv[i]);
625+ } else if (arg == " --preview" ) {
626+ if (++i >= argc) {
627+ invalid_arg = true ;
628+ break ;
629+ }
630+ const char * preview = argv[i];
631+ int preview_method = -1 ;
632+ for (int m = 0 ; m < N_PREVIEWS; m++) {
633+ if (!strcmp (preview, previews_str[m])) {
634+ preview_method = m;
635+ }
636+ }
637+ if (preview_method == -1 ) {
638+ invalid_arg = true ;
639+ break ;
640+ }
641+ params.preview_method = (sd_preview_policy_t )preview_method;
642+ } else if (arg == " --preview-interval" ) {
643+ if (++i >= argc) {
644+ invalid_arg = true ;
645+ break ;
646+ }
647+ params.preview_interval = std::stoi (argv[i]);
648+ } else if (arg == " --preview-path" ) {
649+ if (++i >= argc) {
650+ invalid_arg = true ;
651+ break ;
652+ }
653+ params.preview_path = argv[i];
613654 } else {
614655 fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
615656 print_usage (argc, argv);
@@ -767,52 +808,17 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
767808 fflush (out_stream);
768809}
769810
770- void step_callback (int step, struct ggml_tensor * latents, enum SDVersion version) {
771- const int channel = 3 ;
772- int width = latents->ne [0 ];
773- int height = latents->ne [1 ];
774- int dim = latents->ne [2 ];
775-
776- const float (*latent_rgb_proj)[channel];
777-
778- if (dim == 16 ) {
779- // 16 channels VAE -> Flux or SD3
780-
781- if (sd_version_is_sd3 (version)) {
782- latent_rgb_proj = sd3_latent_rgb_proj;
783- } else if (sd_version_is_flux (version)) {
784- latent_rgb_proj = flux_latent_rgb_proj;
785- } else {
786- // unknown model
787- return ;
788- }
811+ const char * preview_path;
789812
790- } else if (dim == 4 ) {
791- // 4 channels VAE
792- if (version == VERSION_SDXL) {
793- latent_rgb_proj = sdxl_latent_rgb_proj;
794- } else if (version == VERSION_SD1 || version == VERSION_SD2) {
795- latent_rgb_proj = sd_latent_rgb_proj;
796- } else {
797- // unknown model
798- return ;
799- }
800- } else {
801- // unknown latent space
802- return ;
803- }
804- uint8_t * data = (uint8_t *)malloc (width * height * channel * sizeof (uint8_t ));
805-
806- preview_latent_image (data, latents, latent_rgb_proj, width, height, dim);
807-
808- stbi_write_png (" latent-preview.png" , width, height, channel, data, 0 );
809- free (data);
813+ void step_callback (int step, sd_image_t image) {
814+ stbi_write_png (preview_path, image.width , image.height , image.channel , image.data , 0 );
810815}
811816
812817int main (int argc, const char * argv[]) {
813818 SDParams params;
814819
815820 parse_args (argc, argv, params);
821+ preview_path = params.preview_path .c_str ();
816822
817823 sd_set_log_callback (sd_log_cb, (void *)¶ms);
818824
@@ -920,7 +926,8 @@ int main(int argc, const char* argv[]) {
920926 params.clip_on_cpu ,
921927 params.control_net_cpu ,
922928 params.vae_on_cpu ,
923- params.diffusion_flash_attn );
929+ params.diffusion_flash_attn ,
930+ params.taesd_preview );
924931
925932 if (sd_ctx == NULL ) {
926933 printf (" new_sd_ctx_t failed\n " );
@@ -975,6 +982,8 @@ int main(int argc, const char* argv[]) {
975982 params.slg_scale ,
976983 params.skip_layer_start ,
977984 params.skip_layer_end ,
985+ params.preview_method ,
986+ params.preview_interval ,
978987 (step_callback_t )step_callback);
979988 } else {
980989 sd_image_t input_image = {(uint32_t )params.width ,
@@ -996,8 +1005,7 @@ int main(int argc, const char* argv[]) {
9961005 params.sample_method ,
9971006 params.sample_steps ,
9981007 params.strength ,
999- params.seed ,
1000- (step_callback_t )step_callback);
1008+ params.seed );
10011009 if (results == NULL ) {
10021010 printf (" generate failed\n " );
10031011 free_sd_ctx (sd_ctx);
@@ -1044,6 +1052,8 @@ int main(int argc, const char* argv[]) {
10441052 params.slg_scale ,
10451053 params.skip_layer_start ,
10461054 params.skip_layer_end ,
1055+ params.preview_method ,
1056+ params.preview_interval ,
10471057 (step_callback_t )step_callback);
10481058 }
10491059 }
0 commit comments