4 changes: 4 additions & 0 deletions README.md
@@ -341,6 +341,10 @@ arguments:
--diffusion-fa use flash attention in the diffusion model (for low vram)
Might lower quality, since it implies converting k and v to f16.
This might crash if it is not supported by the backend.
--diffusion-conv-direct use Conv2d direct in the diffusion model
This might crash if it is not supported by the backend.
--vae-conv-direct use Conv2d direct in the vae model (should improve performance)
This might crash if it is not supported by the backend.
--control-net-cpu keep controlnet in cpu (for low vram)
--canny apply canny preprocessor (edge detection)
--color colors the logging tags according to level
11 changes: 11 additions & 0 deletions control.hpp
@@ -323,6 +323,17 @@ struct ControlNet : public GGMLRunner {
control_net.init(params_ctx, tensor_types, "");
}

void enable_conv2d_direct() {
std::vector<GGMLBlock*> blocks;
control_net.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}

~ControlNet() {
free_control_ctx();
}
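The same enable_conv2d_direct() walker is added, essentially verbatim, to the ESRGAN, TAE, UNet and VAE runners below. A minimal sketch of the shared pattern as a hypothetical free function (the PR itself repeats the method per runner; GGMLBlock and Conv2d come from ggml_extend.hpp):

static void enable_conv2d_direct_on(GGMLBlock& root) {
    std::vector<GGMLBlock*> blocks;
    root.get_all_blocks(blocks);  // root plus every nested child block
    for (GGMLBlock* block : blocks) {
        if (block->get_desc() == "Conv2d") {  // blocks self-identify via get_desc()
            static_cast<Conv2d*>(block)->enable_direct();
        }
    }
}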
11 changes: 11 additions & 0 deletions esrgan.hpp
@@ -147,6 +147,17 @@ struct ESRGAN : public GGMLRunner {
rrdb_net.init(params_ctx, tensor_types, "");
}

void enable_conv2d_direct() {
std::vector<GGMLBlock*> blocks;
rrdb_net.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}

std::string get_desc() {
return "esrgan";
}
15 changes: 14 additions & 1 deletion examples/cli/main.cpp
@@ -97,6 +97,8 @@ struct SDParams {
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool diffusion_flash_attn = false;
bool diffusion_conv_direct = false;
bool vae_conv_direct = false;
bool canny_preprocess = false;
bool color = false;
int upscale_repeats = 1;
@@ -142,6 +144,8 @@ void print_params(SDParams params) {
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false");
printf(" diffusion Conv2d direct:%s\n", params.diffusion_conv_direct ? "true" : "false");
printf(" vae Conv2d direct:%s\n", params.vae_conv_direct ? "true" : "false");
printf(" strength(control): %.2f\n", params.control_strength);
printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
@@ -232,6 +236,10 @@ void print_usage(int argc, const char* argv[]) {
printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n");
printf(" Might lower quality, since it implies converting k and v to f16.\n");
printf(" This might crash if it is not supported by the backend.\n");
printf(" --diffusion-conv-direct use Conv2d direct in the diffusion model");
printf(" This might crash if it is not supported by the backend.\n");
printf(" --vae-conv-direct use Conv2d direct in the vae model (should improve the performance)");
printf(" This might crash if it is not supported by the backend.\n");
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
printf(" --canny apply canny preprocessor (edge detection)\n");
printf(" --color colors the logging tags according to level\n");
@@ -422,6 +430,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--clip-on-cpu", "", true, &params.clip_on_cpu},
{"", "--vae-on-cpu", "", true, &params.vae_on_cpu},
{"", "--diffusion-fa", "", true, &params.diffusion_flash_attn},
{"", "--diffusion-conv-direct", "", true, &params.diffusion_conv_direct},
{"", "--vae-conv-direct", "", true, &params.vae_conv_direct},
{"", "--canny", "", true, &params.canny_preprocess},
{"-v", "--verbos", "", true, &params.verbose},
{"", "--color", "", true, &params.color},
@@ -901,6 +911,8 @@ int main(int argc, const char* argv[]) {
params.control_net_cpu,
params.vae_on_cpu,
params.diffusion_flash_attn,
params.diffusion_conv_direct,
params.vae_conv_direct,
params.chroma_use_dit_mask,
params.chroma_use_t5_mask,
params.chroma_t5_mask_pad,
@@ -1012,7 +1024,8 @@ int main(int argc, const char* argv[]) {
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) {
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
params.n_threads);
params.n_threads,
params.diffusion_conv_direct);

if (upscaler_ctx == NULL) {
printf("new_upscaler_ctx failed\n");
47 changes: 46 additions & 1 deletion ggml_extend.hpp
@@ -706,6 +706,25 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
return x;
}

__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d_direct(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
struct ggml_tensor* b,
int s0 = 1,
int s1 = 1,
int p0 = 0,
int p1 = 0,
int d0 = 1,
int d1 = 1) {
x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
if (b != NULL) {
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
// b = ggml_repeat(ctx, b, x);
x = ggml_add(ctx, x, b);
}
return x;
}
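// Note: ggml_conv_2d_direct returns x with ne = {OW, OH, OC, N}; reshaping
// b from [OC] to ne = {1, 1, OC, 1} lets ggml_add broadcast the bias over
// N, OH and OW, which is why the ggml_repeat above stays commented out.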

// w: [OC,IC, KD, 1 * 1]
// x: [N, IC, IH, IW]
// b: [OC,]
@@ -1375,6 +1394,19 @@ class GGMLBlock {
tensors[prefix + pair.first] = pair.second;
}
}

virtual std::string get_desc() {
return "GGMLBlock";
}

void get_all_blocks(std::vector<GGMLBlock*>& result) {
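// depth-first: record this block, then recurse into every child,
// so callers can filter the flattened list by get_desc()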
result.push_back(this);
for (auto& block_iter : blocks) {
if (block_iter.second) {
block_iter.second->get_all_blocks(result);
}
}
}
};

class UnaryBlock : public GGMLBlock {
Expand Down Expand Up @@ -1464,6 +1496,7 @@ class Conv2d : public UnaryBlock {
std::pair<int, int> padding;
std::pair<int, int> dilation;
bool bias;
bool direct = false;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
enum ggml_type wtype = GGML_TYPE_F16;
@@ -1490,13 +1523,25 @@ class Conv2d : public UnaryBlock {
dilation(dilation),
bias(bias) {}

void enable_direct() {
direct = true;
}

std::string get_desc() {
return "Conv2d";
}

struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
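// build-time dispatch: forward() runs on every graph build, so a prior
// enable_direct() call routes this block through ggml_conv_2d_direct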
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = NULL;
if (bias) {
b = params["bias"];
}
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
if (direct) {
return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
} else {
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
}
}
};

16 changes: 16 additions & 0 deletions stable-diffusion.cpp
@@ -374,6 +374,10 @@ class StableDiffusionGGML {
model_loader.tensor_storages_types,
version,
sd_ctx_params->diffusion_flash_attn);
if (sd_ctx_params->diffusion_conv_direct) {
LOG_INFO("Using Conv2d direct in the diffusion model");
std::dynamic_pointer_cast<UNetModel>(diffusion_model)->unet.enable_conv2d_direct();
}
}

cond_stage_model->alloc_params_buffer();
@@ -395,6 +399,10 @@
vae_decode_only,
false,
version);
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the vae model");
first_stage_model->enable_conv2d_direct();
}
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else {
@@ -403,6 +411,10 @@
"decoder.layers",
vae_decode_only,
version);
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the tae model");
tae_first_stage->enable_conv2d_direct();
}
}
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");

@@ -415,6 +427,10 @@
controlnet_backend = backend;
}
control_net = std::make_shared<ControlNet>(controlnet_backend, model_loader.tensor_storages_types, version);
if (sd_ctx_params->diffusion_conv_direct) {
LOG_INFO("Using Conv2d direct in the control net");
control_net->enable_conv2d_direct();
}
}

if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
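Client code reaches these branches through the two new fields on sd_ctx_params_t (declared in stable-diffusion.h below). A minimal caller-side sketch, assuming the public sd_ctx_params_init()/new_sd_ctx() entry points and a placeholder model path:

sd_ctx_params_t params;
sd_ctx_params_init(&params);              // assumed public default-initializer
params.model_path = "model.safetensors";  // placeholder
params.diffusion_conv_direct = true;      // UNet, plus ControlNet if loaded
params.vae_conv_direct = true;            // VAE, or TAE when taesd is used
sd_ctx_t* sd_ctx = new_sd_ctx(&params);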
5 changes: 4 additions & 1 deletion stable-diffusion.h
@@ -134,6 +136,8 @@ typedef struct {
bool keep_control_net_on_cpu;
bool keep_vae_on_cpu;
bool diffusion_flash_attn;
bool diffusion_conv_direct;
bool vae_conv_direct;
bool chroma_use_dit_mask;
bool chroma_use_t5_mask;
int chroma_t5_mask_pad;
@@ -236,7 +238,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
typedef struct upscaler_ctx_t upscaler_ctx_t;

SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
int n_threads);
int n_threads,
bool direct);
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);

SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
11 changes: 11 additions & 0 deletions tae.hpp
@@ -206,6 +206,17 @@ struct TinyAutoEncoder : public GGMLRunner {
taesd.init(params_ctx, tensor_types, prefix);
}

void enable_conv2d_direct() {
std::vector<GGMLBlock*> blocks;
taesd.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}

std::string get_desc() {
return "taesd";
}
12 changes: 12 additions & 0 deletions unet.hpp
@@ -546,6 +546,18 @@ struct UNetModelRunner : public GGMLRunner {
unet.init(params_ctx, tensor_types, prefix);
}

void enable_conv2d_direct() {
std::vector<GGMLBlock*> blocks;
unet.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
LOG_DEBUG("block %s", block->get_desc().c_str());
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}

std::string get_desc() {
return "unet";
}
15 changes: 11 additions & 4 deletions upscaler.cpp
@@ -9,9 +9,12 @@ struct UpscalerGGML {
std::shared_ptr<ESRGAN> esrgan_upscaler;
std::string esrgan_path;
int n_threads;
bool direct = false;

UpscalerGGML(int n_threads)
: n_threads(n_threads) {
UpscalerGGML(int n_threads,
bool direct = false)
: n_threads(n_threads),
direct(direct) {
}

bool load_from_file(const std::string& esrgan_path) {
@@ -47,6 +50,9 @@
}
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
esrgan_upscaler = std::make_shared<ESRGAN>(backend, model_loader.tensor_storages_types);
if (direct) {
esrgan_upscaler->enable_conv2d_direct();
}
if (!esrgan_upscaler->load_from_file(esrgan_path)) {
return false;
}
@@ -104,14 +110,15 @@ struct upscaler_ctx_t {
};

upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
int n_threads) {
int n_threads,
bool direct = false) {
upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t));
if (upscaler_ctx == NULL) {
return NULL;
}
std::string esrgan_path(esrgan_path_c_str);

upscaler_ctx->upscaler = new UpscalerGGML(n_threads);
upscaler_ctx->upscaler = new UpscalerGGML(n_threads, direct);
if (upscaler_ctx->upscaler == NULL) {
return NULL;
}
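One C++ subtlety worth noting: the direct = false default above appears only on the definition, so it is visible only inside upscaler.cpp; code compiled against stable-diffusion.h, which declares no default, must pass all three arguments (the CLI forwards params.diffusion_conv_direct). A usage sketch with placeholder values:

upscaler_ctx_t* up_ctx = new_upscaler_ctx("RealESRGAN_x4plus_anime_6B.pth", 8, true);
if (up_ctx != NULL) {
    sd_image_t upscaled = upscale(up_ctx, input, 4);  // input: an sd_image_t loaded elsewhere
    free_upscaler_ctx(up_ctx);
}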
11 changes: 11 additions & 0 deletions vae.hpp
@@ -534,6 +534,17 @@ struct AutoEncoderKL : public GGMLRunner {
ae.init(params_ctx, tensor_types, prefix);
}

void enable_conv2d_direct() {
std::vector<GGMLBlock*> blocks;
ae.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}

std::string get_desc() {
return "vae";
}