Skip to content

Commit 0743a1b

Browse files
authored
fix: fix vae tiling for flux2 (#1025)
1 parent 34a6fd4 commit 0743a1b

File tree

1 file changed

+10 / -8 lines changed

1 file changed

+10 / -8 lines changed

stable-diffusion.cpp

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2096,8 +2096,9 @@ class StableDiffusionGGML {
20962096
ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) {
20972097
int64_t t0 = ggml_time_ms();
20982098
ggml_tensor* result = nullptr;
2099-
int W = x->ne[0] / get_vae_scale_factor();
2100-
int H = x->ne[1] / get_vae_scale_factor();
2099+
const int vae_scale_factor = get_vae_scale_factor();
2100+
int W = x->ne[0] / vae_scale_factor;
2101+
int H = x->ne[1] / vae_scale_factor;
21012102
int C = get_latent_channel();
21022103
if (vae_tiling_params.enabled && !encode_video) {
21032104
// TODO wan2.2 vae support?
@@ -2133,7 +2134,7 @@ class StableDiffusionGGML {
21332134
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
21342135
first_stage_model->compute(n_threads, in, false, &out, work_ctx);
21352136
};
2136-
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
2137+
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
21372138
} else {
21382139
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
21392140
}
@@ -2144,7 +2145,7 @@ class StableDiffusionGGML {
21442145
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
21452146
tae_first_stage->compute(n_threads, in, false, &out, nullptr);
21462147
};
2147-
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
2148+
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling);
21482149
} else {
21492150
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
21502151
}
@@ -2220,8 +2221,9 @@ class StableDiffusionGGML {
22202221
}
22212222

22222223
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
2223-
int64_t W = x->ne[0] * get_vae_scale_factor();
2224-
int64_t H = x->ne[1] * get_vae_scale_factor();
2224+
const int vae_scale_factor = get_vae_scale_factor();
2225+
int64_t W = x->ne[0] * vae_scale_factor;
2226+
int64_t H = x->ne[1] * vae_scale_factor;
22252227
int64_t C = 3;
22262228
ggml_tensor* result = nullptr;
22272229
if (decode_video) {
@@ -2261,7 +2263,7 @@ class StableDiffusionGGML {
22612263
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
22622264
first_stage_model->compute(n_threads, in, true, &out, nullptr);
22632265
};
2264-
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
2266+
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
22652267
} else {
22662268
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
22672269
}
@@ -2273,7 +2275,7 @@ class StableDiffusionGGML {
22732275
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
22742276
tae_first_stage->compute(n_threads, in, true, &out);
22752277
};
2276-
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
2278+
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling);
22772279
} else {
22782280
tae_first_stage->compute(n_threads, x, true, &result);
22792281
}

0 commit comments

Comments
 (0)