Skip to content

Commit 7179e49

Browse files
committed
1 parent 60d3cc7 commit 7179e49

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

otherarch/sdcpp/stable-diffusion.cpp

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2689,14 +2689,12 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
26892689
sd_image_to_ggml_tensor(sd_img_gen_params->mask_image, mask_img);
26902690
sd_image_to_ggml_tensor(sd_img_gen_params->init_image, init_img);
26912691

2692-
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
2693-
26942692
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
26952693
int64_t mask_channels = 1;
26962694
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
2697-
mask_channels = 8 * 8; // flatten the whole mask
2695+
mask_channels = vae_scale_factor * vae_scale_factor; // flatten the whole mask
26982696
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
2699-
mask_channels = 1 + init_latent->ne[2];
2697+
mask_channels = 1 + sd_ctx->sd->get_latent_channel();
27002698
}
27012699
ggml_tensor* masked_latent = nullptr;
27022700

@@ -2705,8 +2703,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
27052703
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
27062704
ggml_ext_tensor_apply_mask(init_img, mask_img, masked_img);
27072705
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
2706+
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
27082707
} else {
27092708
// mask after vae
2709+
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
27102710
masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
27112711
ggml_ext_tensor_apply_mask(init_latent, mask_img, masked_latent, 0.);
27122712
}
@@ -2747,9 +2747,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
27472747
for (int k = 0; k < masked_latent->ne[2]; k++) {
27482748
ggml_ext_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
27492749
}
2750+
} else {
2751+
float m = ggml_ext_tensor_get_f32(mask_img, mx, my);
2752+
ggml_ext_tensor_set_f32(concat_latent, m, ix, iy, 0);
2753+
for (int k = 0; k < masked_latent->ne[2];k++) {
2754+
float v = ggml_ext_tensor_get_f32(masked_latent, ix, iy, k);
2755+
ggml_ext_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
2756+
}
27502757
}
27512758
}
27522759
}
2760+
} else {
2761+
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
27532762
}
27542763

27552764
{

0 commit comments

Comments (0)