Skip to content

Commit 5011832

Browse files
committed
vae tiling: refactor again, base on smaller buffer for alignment
1 parent dc990a7 commit 5011832

File tree

2 files changed

+77
-44
lines changed

2 files changed

+77
-44
lines changed

ggml_extend.hpp

Lines changed: 65 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -737,62 +737,67 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_contex
737737
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
738738

739739
// Tiling
740-
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true) {
740+
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
741741
output = ggml_set_f32(output, 0);
742742

743743
int input_width = (int)input->ne[0];
744744
int input_height = (int)input->ne[1];
745745
int output_width = (int)output->ne[0];
746746
int output_height = (int)output->ne[1];
747747

748-
int input_tile_size, output_tile_size;
749-
if (scaled_out) {
750-
input_tile_size = tile_size;
751-
output_tile_size = tile_size * scale;
752-
} else {
753-
input_tile_size = tile_size * scale;
754-
output_tile_size = tile_size;
748+
GGML_ASSERT(input_width / output_width == input_height / output_height && output_width / input_width == output_height / input_height);
749+
GGML_ASSERT(input_width / output_width == scale || output_width / input_width == scale);
750+
751+
int small_width = output_width;
752+
int small_height = output_height;
753+
754+
bool big_out = output_width > input_width;
755+
if (big_out) {
756+
// Ex: decode
757+
small_width = input_width;
758+
small_height = input_height;
755759
}
756-
int tile_overlap = (input_tile_size * tile_overlap_factor);
757-
int non_tile_overlap = input_tile_size - tile_overlap;
758760

759-
int num_tiles_x = (input_width - tile_overlap) / non_tile_overlap;
760-
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % input_width;
761+
int tile_overlap = (tile_size * tile_overlap_factor);
762+
int non_tile_overlap = tile_size - tile_overlap;
763+
764+
int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
765+
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % small_width;
761766

762-
if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (input_tile_size / 2 - tile_overlap))) {
767+
if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
763768
// if tiles don't fit perfectly using the desired overlap
764769
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
765770
num_tiles_x++;
766771
}
767772

768-
float tile_overlap_factor_x = (float)(input_tile_size * num_tiles_x - input_width) / (float)(input_tile_size * (num_tiles_x - 1));
773+
float tile_overlap_factor_x = (float)(tile_size * num_tiles_x - small_width) / (float)(tile_size * (num_tiles_x - 1));
769774
if (num_tiles_x <= 2) {
770-
if (input_width <= input_tile_size) {
775+
if (small_width <= tile_size) {
771776
num_tiles_x = 1;
772777
tile_overlap_factor_x = 0;
773778
} else {
774779
num_tiles_x = 2;
775-
tile_overlap_factor_x = (2 * input_tile_size - input_width) / (float)input_tile_size;
780+
tile_overlap_factor_x = (2 * tile_size - small_width) / (float)tile_size;
776781
}
777782
}
778783

779-
int num_tiles_y = (input_height - tile_overlap) / non_tile_overlap;
780-
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % input_height;
784+
int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
785+
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % small_height;
781786

782-
if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (input_tile_size / 2 - tile_overlap))) {
787+
if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
783788
// if tiles don't fit perfectly using the desired overlap
784789
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
785790
num_tiles_y++;
786791
}
787792

788-
float tile_overlap_factor_y = (float)(input_tile_size * num_tiles_y - input_height) / (float)(input_tile_size * (num_tiles_y - 1));
793+
float tile_overlap_factor_y = (float)(tile_size * num_tiles_y - small_height) / (float)(tile_size * (num_tiles_y - 1));
789794
if (num_tiles_y <= 2) {
790-
if (input_height <= input_tile_size) {
795+
if (small_height <= tile_size) {
791796
num_tiles_y = 1;
792797
tile_overlap_factor_y = 0;
793798
} else {
794799
num_tiles_y = 2;
795-
tile_overlap_factor_y = (2 * input_tile_size - input_height) / (float)input_tile_size;
800+
tile_overlap_factor_y = (2 * tile_size - small_height) / (float)tile_size;
796801
}
797802
}
798803

@@ -801,11 +806,20 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
801806

802807
GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2
803808

804-
int tile_overlap_x = (int32_t)(input_tile_size * tile_overlap_factor_x);
805-
int non_tile_overlap_x = input_tile_size - tile_overlap_x;
809+
int tile_overlap_x = (int32_t)(tile_size * tile_overlap_factor_x);
810+
int non_tile_overlap_x = tile_size - tile_overlap_x;
806811

807-
int tile_overlap_y = (int32_t)(input_tile_size * tile_overlap_factor_y);
808-
int non_tile_overlap_y = input_tile_size - tile_overlap_y;
812+
int tile_overlap_y = (int32_t)(tile_size * tile_overlap_factor_y);
813+
int non_tile_overlap_y = tile_size - tile_overlap_y;
814+
815+
int input_tile_size = tile_size;
816+
int output_tile_size = tile_size;
817+
818+
if (big_out) {
819+
output_tile_size *= scale;
820+
} else {
821+
input_tile_size *= scale;
822+
}
809823

810824
struct ggml_init_params params = {};
811825
params.mem_size += input_tile_size * input_tile_size * input->ne[2] * sizeof(float); // input chunk
@@ -826,37 +840,48 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
826840
// tiling
827841
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size, input->ne[2], 1);
828842
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size, output->ne[2], 1);
829-
on_processing(input_tile, NULL, true);
830843
int num_tiles = num_tiles_x * num_tiles_y;
831844
LOG_INFO("processing %i tiles", num_tiles);
832-
pretty_progress(1, num_tiles, 0.0f);
845+
pretty_progress(0, num_tiles, 0.0f);
833846
int tile_count = 1;
834847
bool last_y = false, last_x = false;
835848
float last_time = 0.0f;
836-
for (int y = 0; y < input_height && !last_y; y += non_tile_overlap_y) {
849+
for (int y = 0; y < small_height && !last_y; y += non_tile_overlap_y) {
837850
int dy = 0;
838-
if (y + input_tile_size >= input_height) {
851+
if (y + tile_size >= small_height) {
839852
int _y = y;
840-
y = input_height - input_tile_size;
853+
y = small_height - tile_size;
841854
dy = _y - y;
855+
if (big_out) {
856+
dy *= scale;
857+
}
842858
last_y = true;
843859
}
844-
for (int x = 0; x < input_width && !last_x; x += non_tile_overlap_x) {
860+
for (int x = 0; x < small_width && !last_x; x += non_tile_overlap_x) {
845861
int dx = 0;
846-
if (x + input_tile_size >= input_width) {
862+
if (x + tile_size >= small_width) {
847863
int _x = x;
848-
x = input_width - input_tile_size;
864+
x = small_width - tile_size;
849865
dx = _x - x;
866+
if (big_out) {
867+
dx *= scale;
868+
}
850869
last_x = true;
851870
}
871+
872+
int x_in = big_out ? x : scale * x;
873+
int y_in = big_out ? y : scale * y;
874+
int x_out = big_out ? x * scale : x;
875+
int y_out = big_out ? y * scale : y;
876+
877+
int overlap_x_out = big_out ? tile_overlap_x * scale : tile_overlap_x;
878+
int overlap_y_out = big_out ? tile_overlap_y * scale : tile_overlap_y;
879+
852880
int64_t t1 = ggml_time_ms();
853-
ggml_split_tensor_2d(input, input_tile, x, y);
881+
ggml_split_tensor_2d(input, input_tile, x_in, y_in);
854882
on_processing(input_tile, output_tile, false);
855-
if (scaled_out) {
856-
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap_x * scale, tile_overlap_y * scale, dx * scale, dy * scale);
857-
} else {
858-
ggml_merge_tensor_2d(output_tile, output, x / scale, y / scale, tile_overlap_x / scale, tile_overlap_y / scale, dx / scale, dy / scale);
859-
}
883+
ggml_merge_tensor_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy);
884+
860885
int64_t t2 = ggml_time_ms();
861886
last_time = (t2 - t1) / 1000.0f;
862887
pretty_progress(tile_count, num_tiles, last_time);

stable-diffusion.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,13 +1318,21 @@ class StableDiffusionGGML {
13181318
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13191319
first_stage_model->compute(n_threads, in, true, &out, NULL);
13201320
};
1321-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling, false);
1321+
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
13221322
} else {
13231323
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
13241324
}
13251325
first_stage_model->free_compute_buffer();
13261326
} else {
1327-
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
1327+
if (vae_tiling && !decode_video) {
1328+
// split latent in 32x32 tiles and compute in several steps
1329+
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
1330+
tae_first_stage->compute(n_threads, in, true, &out, NULL);
1331+
};
1332+
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
1333+
} else {
1334+
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
1335+
}
13281336
tae_first_stage->free_compute_buffer();
13291337
}
13301338

@@ -1463,7 +1471,7 @@ class StableDiffusionGGML {
14631471
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14641472
first_stage_model->compute(n_threads, in, true, &out, NULL);
14651473
};
1466-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling, true);
1474+
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
14671475
} else {
14681476
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
14691477
}
@@ -1475,7 +1483,7 @@ class StableDiffusionGGML {
14751483
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14761484
tae_first_stage->compute(n_threads, in, true, &out);
14771485
};
1478-
sd_tiling(x, result, 8, 64, 0.5f, on_tiling, true);
1486+
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
14791487
} else {
14801488
tae_first_stage->compute(n_threads, x, true, &result);
14811489
}

0 commit comments

Comments
 (0)