Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 67 additions & 5 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ struct SDParams {
rng_type_t rng_type = CUDA_RNG;
int64_t seed = 42;
bool verbose = false;
bool vae_tiling = false;
bool offload_params_to_cpu = false;
bool control_net_cpu = false;
bool normalize_input = false;
Expand All @@ -119,6 +118,8 @@ struct SDParams {
int chroma_t5_mask_pad = 1;
float flow_shift = INFINITY;

sd_tiling_params_t vae_tiling_params = {false, 32, 32, 0.5f, false, 0.0f, 0.0f};

SDParams() {
sd_sample_params_init(&sample_params);
sd_sample_params_init(&high_noise_sample_params);
Expand Down Expand Up @@ -180,7 +181,7 @@ void print_params(SDParams params) {
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
printf(" seed: %ld\n", params.seed);
printf(" batch_count: %d\n", params.batch_count);
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
printf(" vae_tiling: %s\n", params.vae_tiling_params.enabled ? "true" : "false");
printf(" upscale_repeats: %d\n", params.upscale_repeats);
printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
Expand Down Expand Up @@ -268,6 +269,9 @@ void print_usage(int argc, const char* argv[]) {
printf(" --clip-skip N ignore last_dot_pos layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
printf(" --vae-tile-size [X]x[Y] tile size for vae tiling (default: 32x32)\n");
printf(" --vae-relative-tile-size [X]x[Y] relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)\n");
printf(" --vae-tile-overlap OVERLAP tile overlap for vae tiling, in fraction of tile size (default: 0.5)\n");
printf(" --vae-on-cpu keep vae in cpu (for low vram)\n");
printf(" --clip-on-cpu keep clip in cpu (for low vram)\n");
printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n");
Expand Down Expand Up @@ -485,7 +489,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"-o", "--output", "", &params.output_path},
{"-p", "--prompt", "", &params.prompt},
{"-n", "--negative-prompt", "", &params.negative_prompt},

{"", "--upscale-model", "", &params.esrgan_path},
};

Expand Down Expand Up @@ -526,7 +529,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
};

options.bool_options = {
{"", "--vae-tiling", "", true, &params.vae_tiling},
{"", "--vae-tiling", "", true, &params.vae_tiling_params.enabled},
{"", "--offload-to-cpu", "", true, &params.offload_params_to_cpu},
{"", "--control-net-cpu", "", true, &params.control_net_cpu},
{"", "--normalize-input", "", true, &params.normalize_input},
Expand Down Expand Up @@ -726,6 +729,62 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1;
};

auto on_tile_size_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
std::string tile_size_str = argv[index];
size_t x_pos = tile_size_str.find('x');
try {
if (x_pos != std::string::npos) {
std::string tile_x_str = tile_size_str.substr(0, x_pos);
std::string tile_y_str = tile_size_str.substr(x_pos + 1);
params.vae_tiling_params.tile_size_x = std::stoi(tile_x_str);
params.vae_tiling_params.tile_size_y = std::stoi(tile_y_str);
} else {
params.vae_tiling_params.tile_size_x = params.vae_tiling_params.tile_size_y = std::stoi(tile_size_str);
}
} catch (const std::invalid_argument& e) {
return -1;
} catch (const std::out_of_range& e) {
return -1;
}
params.vae_tiling_params.relative = false;
return 1;
};

auto on_relative_tile_size_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
std::string rel_size_str = argv[index];
size_t x_pos = rel_size_str.find('x');
try {
if (x_pos != std::string::npos) {
std::string rel_x_str = rel_size_str.substr(0, x_pos);
std::string rel_y_str = rel_size_str.substr(x_pos + 1);
params.vae_tiling_params.rel_size_x = std::stof(rel_x_str);
params.vae_tiling_params.rel_size_y = std::stof(rel_y_str);
} else {
params.vae_tiling_params.rel_size_x = params.vae_tiling_params.rel_size_y = std::stof(rel_size_str);
}
} catch (const std::invalid_argument& e) {
return -1;
} catch (const std::out_of_range& e) {
return -1;
}
params.vae_tiling_params.relative = true;
return 1;
};

auto on_tile_overlap_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
params.vae_tiling_params.target_overlap = std::stof(argv[index]);
return 1;
};

options.manual_options = {
{"-M", "--mode", "", on_mode_arg},
{"", "--type", "", on_type_arg},
Expand All @@ -739,6 +798,9 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--high-noise-skip-layers", "", on_high_noise_skip_layers_arg},
{"-r", "--ref-image", "", on_ref_image_arg},
{"-h", "--help", "", on_help_arg},
{"", "--vae-tile-size", "", on_tile_size_arg},
{"", "--vae-relative-tile-size", "", on_relative_tile_size_arg},
{"", "--vae-tile-overlap", "", on_tile_overlap_arg},
};

if (!parse_options(argc, argv, options)) {
Expand Down Expand Up @@ -1176,7 +1238,6 @@ int main(int argc, const char* argv[]) {
params.embedding_dir.c_str(),
params.stacked_id_embed_dir.c_str(),
vae_decode_only,
params.vae_tiling,
true,
params.n_threads,
params.wtype,
Expand Down Expand Up @@ -1225,6 +1286,7 @@ int main(int argc, const char* argv[]) {
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str(),
params.vae_tiling_params,
};

results = generate_image(sd_ctx, &img_gen_params);
Expand Down
155 changes: 129 additions & 26 deletions ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,10 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
struct ggml_tensor* output,
int x,
int y,
int overlap) {
int overlap_x,
int overlap_y,
int x_skip = 0,
int y_skip = 0) {
int64_t width = input->ne[0];
int64_t height = input->ne[1];
int64_t channels = input->ne[2];
Expand All @@ -503,17 +506,17 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
int64_t img_height = output->ne[1];

GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
for (int iy = y_skip; iy < height; iy++) {
for (int ix = x_skip; ix < width; ix++) {
for (int k = 0; k < channels; k++) {
float new_value = ggml_tensor_get_f32(input, ix, iy, k);
if (overlap > 0) { // blend colors in overlapped area
if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area
float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k);

const float x_f_0 = (x > 0) ? ix / float(overlap) : 1;
const float x_f_1 = (x < (img_width - width)) ? (width - ix) / float(overlap) : 1;
const float y_f_0 = (y > 0) ? iy / float(overlap) : 1;
const float y_f_1 = (y < (img_height - height)) ? (height - iy) / float(overlap) : 1;
const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1;
const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1;
const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1;

const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
Expand Down Expand Up @@ -745,22 +748,96 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_contex

typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;

__STATIC_INLINE__ void
sd_tiling_calc_tiles(int &num_tiles_dim, float& tile_overlap_factor_dim, int small_dim, int tile_size, const float tile_overlap_factor) {

int tile_overlap = (tile_size * tile_overlap_factor);
int non_tile_overlap = tile_size - tile_overlap;

num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim;

if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) {
// if tiles don't fit perfectly using the desired overlap
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
num_tiles_dim++;
}

tile_overlap_factor_dim = (float)(tile_size * num_tiles_dim - small_dim) / (float)(tile_size * (num_tiles_dim - 1));
if (num_tiles_dim <= 2) {
if (small_dim <= tile_size) {
num_tiles_dim = 1;
tile_overlap_factor_dim = 0;
} else {
num_tiles_dim = 2;
tile_overlap_factor_dim = (2 * tile_size - small_dim) / (float)tile_size;
}
}
}

// Tiling
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
__STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, ggml_tensor* output, const int scale,
const int p_tile_size_x, const int p_tile_size_y,
const float tile_overlap_factor, on_tile_process on_processing) {

output = ggml_set_f32(output, 0);

int input_width = (int)input->ne[0];
int input_height = (int)input->ne[1];
int output_width = (int)output->ne[0];
int output_height = (int)output->ne[1];

GGML_ASSERT(input_width / output_width == input_height / output_height && output_width / input_width == output_height / input_height);
GGML_ASSERT(input_width / output_width == scale || output_width / input_width == scale);

int small_width = output_width;
int small_height = output_height;

bool big_out = output_width > input_width;
if (big_out) {
// Ex: decode
small_width = input_width;
small_height = input_height;
}

int num_tiles_x;
float tile_overlap_factor_x;
sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor);

int num_tiles_y;
float tile_overlap_factor_y;
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor);

LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);

GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2

int tile_overlap = (int32_t)(tile_size * tile_overlap_factor);
int non_tile_overlap = tile_size - tile_overlap;
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;

int tile_overlap_y = (int32_t)(p_tile_size_y * tile_overlap_factor_y);
int non_tile_overlap_y = p_tile_size_y - tile_overlap_y;

int tile_size_x = p_tile_size_x < small_width ? p_tile_size_x : small_width;
int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height;

int input_tile_size_x = tile_size_x;
int input_tile_size_y = tile_size_y;
int output_tile_size_x = tile_size_x;
int output_tile_size_y = tile_size_y;

if (big_out) {
output_tile_size_x *= scale;
output_tile_size_y *= scale;
} else {
input_tile_size_x *= scale;
input_tile_size_y *= scale;
}

struct ggml_init_params params = {};
params.mem_size += tile_size * tile_size * input->ne[2] * sizeof(float); // input chunk
params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne[2] * sizeof(float); // output chunk
params.mem_size += input_tile_size_x * input_tile_size_y * input->ne[2] * sizeof(float); // input chunk
params.mem_size += output_tile_size_x * output_tile_size_y * output->ne[2] * sizeof(float); // output chunk
params.mem_size += 3 * ggml_tensor_overhead();
params.mem_buffer = NULL;
params.no_alloc = false;
Expand All @@ -775,29 +852,50 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
}

// tiling
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size, tile_size, input->ne[2], 1);
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale, output->ne[2], 1);
on_processing(input_tile, NULL, true);
int num_tiles = ceil((float)input_width / non_tile_overlap) * ceil((float)input_height / non_tile_overlap);
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], 1);
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], 1);
int num_tiles = num_tiles_x * num_tiles_y;
LOG_INFO("processing %i tiles", num_tiles);
pretty_progress(1, num_tiles, 0.0f);
pretty_progress(0, num_tiles, 0.0f);
int tile_count = 1;
bool last_y = false, last_x = false;
float last_time = 0.0f;
for (int y = 0; y < input_height && !last_y; y += non_tile_overlap) {
if (y + tile_size >= input_height) {
y = input_height - tile_size;
for (int y = 0; y < small_height && !last_y; y += non_tile_overlap_y) {
int dy = 0;
if (y + tile_size_y >= small_height) {
int _y = y;
y = small_height - tile_size_y;
dy = _y - y;
if (big_out) {
dy *= scale;
}
last_y = true;
}
for (int x = 0; x < input_width && !last_x; x += non_tile_overlap) {
if (x + tile_size >= input_width) {
x = input_width - tile_size;
for (int x = 0; x < small_width && !last_x; x += non_tile_overlap_x) {
int dx = 0;
if (x + tile_size_x >= small_width) {
int _x = x;
x = small_width - tile_size_x;
dx = _x - x;
if (big_out) {
dx *= scale;
}
last_x = true;
}

int x_in = big_out ? x : scale * x;
int y_in = big_out ? y : scale * y;
int x_out = big_out ? x * scale : x;
int y_out = big_out ? y * scale : y;

int overlap_x_out = big_out ? tile_overlap_x * scale : tile_overlap_x;
int overlap_y_out = big_out ? tile_overlap_y * scale : tile_overlap_y;

int64_t t1 = ggml_time_ms();
ggml_split_tensor_2d(input, input_tile, x, y);
ggml_split_tensor_2d(input, input_tile, x_in, y_in);
on_processing(input_tile, output_tile, false);
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap * scale);
ggml_merge_tensor_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy);

int64_t t2 = ggml_time_ms();
last_time = (t2 - t1) / 1000.0f;
pretty_progress(tile_count, num_tiles, last_time);
Expand All @@ -811,6 +909,11 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
ggml_free(tiles_ctx);
}

__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale,
const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing);
}

__STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
struct ggml_tensor* a) {
const float eps = 1e-6f; // default eps parameter
Expand Down
Loading
Loading