@@ -2689,14 +2689,12 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
26892689 sd_image_to_ggml_tensor (sd_img_gen_params->mask_image , mask_img);
26902690 sd_image_to_ggml_tensor (sd_img_gen_params->init_image , init_img);
26912691
2692- init_latent = sd_ctx->sd ->encode_first_stage (work_ctx, init_img);
2693-
26942692 if (sd_version_is_inpaint (sd_ctx->sd ->version )) {
26952693 int64_t mask_channels = 1 ;
26962694 if (sd_ctx->sd ->version == VERSION_FLUX_FILL) {
2697- mask_channels = 8 * 8 ; // flatten the whole mask
2695+ mask_channels = vae_scale_factor * vae_scale_factor ; // flatten the whole mask
26982696 } else if (sd_ctx->sd ->version == VERSION_FLEX_2) {
2699- mask_channels = 1 + init_latent-> ne [ 2 ] ;
2697+ mask_channels = 1 + sd_ctx-> sd -> get_latent_channel () ;
27002698 }
27012699 ggml_tensor* masked_latent = nullptr ;
27022700
@@ -2705,8 +2703,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
27052703 ggml_tensor* masked_img = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, width, height, 3 , 1 );
27062704 ggml_ext_tensor_apply_mask (init_img, mask_img, masked_img);
27072705 masked_latent = sd_ctx->sd ->encode_first_stage (work_ctx, masked_img);
2706+ init_latent = sd_ctx->sd ->encode_first_stage (work_ctx, init_img);
27082707 } else {
27092708 // mask after vae
2709+ init_latent = sd_ctx->sd ->encode_first_stage (work_ctx, init_img);
27102710 masked_latent = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, init_latent->ne [0 ], init_latent->ne [1 ], init_latent->ne [2 ], 1 );
27112711 ggml_ext_tensor_apply_mask (init_latent, mask_img, masked_latent, 0 .);
27122712 }
@@ -2747,9 +2747,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
27472747 for (int k = 0 ; k < masked_latent->ne [2 ]; k++) {
27482748 ggml_ext_tensor_set_f32 (concat_latent, 0 , ix, iy, masked_latent->ne [2 ] + 1 + k);
27492749 }
2750+ } else {
2751+ float m = ggml_ext_tensor_get_f32 (mask_img, mx, my);
2752+ ggml_ext_tensor_set_f32 (concat_latent, m, ix, iy, 0 );
2753+ for (int k = 0 ; k < masked_latent->ne [2 ];k++) {
2754+ float v = ggml_ext_tensor_get_f32 (masked_latent, ix, iy, k);
2755+ ggml_ext_tensor_set_f32 (concat_latent, v, ix, iy, k + mask_channels);
2756+ }
27502757 }
27512758 }
27522759 }
2760+ } else {
2761+ init_latent = sd_ctx->sd ->encode_first_stage (work_ctx, init_img);
27532762 }
27542763
27552764 {
0 commit comments