Skip to content

Commit d74c16e

Browse files
authored
enable flash attention for image generation (LostRuins#1633)
1 parent bc3e4c1 commit d74c16e

File tree

3 files changed

+8
-0
lines changed

3 files changed

+8
-0
lines changed

expose.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -161,6 +161,7 @@ struct sd_load_model_inputs
161161
const char * vulkan_info = nullptr;
162162
const int threads = 0;
163163
const int quant = 0;
164+
const bool flash_attention = false;
164165
const bool taesd = false;
165166
const int tiled_vae_threshold = 0;
166167
const char * t5xxl_filename = nullptr;

koboldcpp.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -273,6 +273,7 @@ class sd_load_model_inputs(ctypes.Structure):
273273
("vulkan_info", ctypes.c_char_p),
274274
("threads", ctypes.c_int),
275275
("quant", ctypes.c_int),
276+
("flash_attention", ctypes.c_bool),
276277
("taesd", ctypes.c_bool),
277278
("tiled_vae_threshold", ctypes.c_int),
278279
("t5xxl_filename", ctypes.c_char_p),
@@ -1624,6 +1625,7 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
16241625

16251626
inputs.threads = thds
16261627
inputs.quant = quant
1628+
inputs.flash_attention = args.flashattention
16271629
inputs.taesd = True if args.sdvaeauto else False
16281630
inputs.tiled_vae_threshold = args.sdtiledvae
16291631
inputs.vae_filename = vae_filename.encode("UTF-8")

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -179,6 +179,10 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
179179
printf("With PhotoMaker Model: %s\n",photomaker_filename.c_str());
180180
photomaker_enabled = true;
181181
}
182+
if(inputs.flash_attention)
183+
{
184+
printf("Flash Attention is enabled\n");
185+
}
182186
if(inputs.quant)
183187
{
184188
printf("Note: Loading a pre-quantized model is always faster than using compress weights!\n");
@@ -213,6 +217,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
213217
sd_params->model_path = inputs.model_filename;
214218
sd_params->wtype = (inputs.quant==0?SD_TYPE_COUNT:SD_TYPE_Q4_0);
215219
sd_params->n_threads = inputs.threads; //if -1 use physical cores
220+
sd_params->diffusion_flash_attn = inputs.flash_attention;
216221
sd_params->input_path = ""; //unused
217222
sd_params->batch_count = 1;
218223
sd_params->vae_path = vaefilename;

0 commit comments

Comments
 (0)