Commit 6003e90

Add flash attention and conv2d direct controls for image generation (#1678)
* Add separate flash attention config for image generation
* Add config option for Conv2D Direct
1 parent 35707f4 commit 6003e90
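In brief, image generation now gets its own flash attention switch (--sdflashattention, decoupled from the text-model --flashattention flag it previously reused) plus a three-way Conv2D Direct setting (--sdconvdirect) accepting 'off', 'vaeonly', or 'full'. A minimal sketch of how that choice maps onto the two new backend booleans, mirroring the logic added to sd_load_model below (convdirect_flags is a hypothetical helper name used only for illustration):

# Hypothetical illustration of the mapping added in this commit:
# 'full' enables Conv2D Direct for both the diffusion model and the VAE,
# 'vaeonly' enables it for the VAE only, 'off' disables it entirely.
def convdirect_flags(choice: str):
    diffusion_conv_direct = (choice == 'full')
    vae_conv_direct = choice in ['vaeonly', 'full']
    return diffusion_conv_direct, vae_conv_direct

assert convdirect_flags('off') == (False, False)
assert convdirect_flags('vaeonly') == (False, True)
assert convdirect_flags('full') == (True, True)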

File tree

3 files changed (+59, -2 lines):

expose.h
koboldcpp.py
otherarch/sdcpp/sdtype_adapter.cpp

expose.h

Lines changed: 2 additions & 0 deletions
@@ -166,6 +166,8 @@ struct sd_load_model_inputs
     const int threads = 0;
     const int quant = 0;
     const bool flash_attention = false;
+    const bool diffusion_conv_direct = false;
+    const bool vae_conv_direct = false;
     const bool taesd = false;
     const int tiled_vae_threshold = 0;
     const char * t5xxl_filename = nullptr;

koboldcpp.py

Lines changed: 43 additions & 2 deletions
@@ -280,6 +280,8 @@ class sd_load_model_inputs(ctypes.Structure):
                 ("threads", ctypes.c_int),
                 ("quant", ctypes.c_int),
                 ("flash_attention", ctypes.c_bool),
+                ("diffusion_conv_direct", ctypes.c_bool),
+                ("vae_conv_direct", ctypes.c_bool),
                 ("taesd", ctypes.c_bool),
                 ("tiled_vae_threshold", ctypes.c_int),
                 ("t5xxl_filename", ctypes.c_char_p),
@@ -1637,6 +1639,19 @@ def generate(genparams, stream_flag=False):
         outstr = outstr[:sindex]
     return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}

+sd_convdirect_choices = ['off', 'vaeonly', 'full']
+
+def sd_convdirect_option(value):
+    if not value:
+        value = ''
+    value = value.lower()
+    if value in ['disabled', 'disable', 'none', 'off', '0', '']:
+        return 'off'
+    elif value in ['vae', 'vaeonly']:
+        return 'vaeonly'
+    elif value in ['enabled', 'enable', 'on', 'full']:
+        return 'full'
+    raise argparse.ArgumentTypeError(f"Invalid sdconvdirect option \"{value}\". Must be one of {sd_convdirect_choices}.")

 def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl_filename,clipg_filename,photomaker_filename):
     global args
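The normalizer above deliberately accepts aliases as well as None/empty input, so values coming from argparse, the GUI combobox, or an older saved config all collapse to one of the three canonical choices. A quick usage sketch, assuming sd_convdirect_option is defined as in the hunk above:

# Assumes sd_convdirect_option from the hunk above is in scope.
# Unknown values raise argparse.ArgumentTypeError.
for raw in (None, '', 'disabled', 'OFF', 'vae', 'VAEONLY', 'enable', 'full'):
    print(repr(raw), '->', sd_convdirect_option(raw))
# None, '', 'disabled', 'OFF'  -> 'off'
# 'vae', 'VAEONLY'             -> 'vaeonly'
# 'enable', 'full'             -> 'full'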
@@ -1654,7 +1669,10 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl

     inputs.threads = thds
     inputs.quant = quant
-    inputs.flash_attention = args.flashattention
+    inputs.flash_attention = args.sdflashattention
+    sdconvdirect = sd_convdirect_option(args.sdconvdirect)
+    inputs.diffusion_conv_direct = sdconvdirect == 'full'
+    inputs.vae_conv_direct = sdconvdirect in ['vaeonly', 'full']
     inputs.taesd = True if args.sdvaeauto else False
     inputs.tiled_vae_threshold = args.sdtiledvae
     inputs.vae_filename = vae_filename.encode("UTF-8")
@@ -4568,8 +4586,10 @@ def hide_tooltip(event):
     sd_clipl_var = ctk.StringVar()
     sd_clipg_var = ctk.StringVar()
     sd_photomaker_var = ctk.StringVar()
+    sd_flash_attention_var = ctk.IntVar(value=0)
     sd_vaeauto_var = ctk.IntVar(value=0)
     sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
+    sd_convdirect_var = ctk.StringVar(value='disabled')
     sd_clamped_var = ctk.StringVar(value="0")
     sd_clamped_soft_var = ctk.StringVar(value="0")
     sd_threads_var = ctk.StringVar(value=str(default_threads))
@@ -4634,6 +4654,18 @@ def makecheckbox(parent, text, variable=None, row=0, column=0, command=None, pad
         temp.bind("<Leave>", hide_tooltip)
         return temp

+    def makelabelcombobox(parent, text, variable=None, row=0, width=50, command=None, padx=8,tooltiptxt="", values=[], labelpadx=8):
+        label = makelabel(parent, text, row, 0, tooltiptxt, padx=labelpadx)
+        label=None
+        combo = ctk.CTkComboBox(parent, variable=variable, width=width, values=values, state="readonly")
+        if command is not None and variable is not None:
+            variable.trace_add("write", command)
+        combo.grid(row=row,column=0, padx=padx, sticky="nw")
+        if tooltiptxt!="":
+            combo.bind("<Enter>", lambda event: show_tooltip(event, tooltiptxt))
+            combo.bind("<Leave>", hide_tooltip)
+        return combo, label
+
     def makelabel(parent, text, row, column=0, tooltiptxt="", columnspan=1, padx=8):
         temp = ctk.CTkLabel(parent, text=text)
         temp.grid(row=row, column=column, padx=padx, pady=1, stick="nw", columnspan=columnspan)
@@ -5328,8 +5360,10 @@ def toggletaesd(a,b,c):
             sdvaeitem1.grid()
             sdvaeitem2.grid()
             sdvaeitem3.grid()
-    makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
+    makecheckbox(images_tab, "TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
+    makelabelcombobox(images_tab, "Conv2D Direct:", sd_convdirect_var, row=42, labelpadx=220, padx=310, width=90, tooltiptxt="Use Conv2D Direct operation. May save memory or improve performance.\nMight crash if not supported by the backend.\n", values=sd_convdirect_choices)
     makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
+    makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 46, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")

     # audio tab
     audio_tab = tabcontent["Audio"]
@@ -5566,6 +5600,8 @@ def export_vars():
         if sd_model_var.get() != "":
             args.sdmodel = sd_model_var.get()

+        if sd_flash_attention_var.get()==1:
+            args.sdflashattention = True
         args.sdthreads = (0 if sd_threads_var.get()=="" else int(sd_threads_var.get()))
         args.sdclamped = (0 if int(sd_clamped_var.get())<=0 else int(sd_clamped_var.get()))
         args.sdclampedsoft = (0 if int(sd_clamped_soft_var.get())<=0 else int(sd_clamped_soft_var.get()))
@@ -5578,6 +5614,7 @@ def export_vars():
         args.sdvae = ""
         if sd_vae_var.get() != "":
             args.sdvae = sd_vae_var.get()
+        args.sdconvdirect = sd_convdirect_option(sd_convdirect_var.get())
         if sd_t5xxl_var.get() != "":
             args.sdt5xxl = sd_t5xxl_var.get()
         if sd_clipl_var.get() != "":
@@ -5798,6 +5835,8 @@ def import_vars(dict):
         sd_clamped_soft_var.set(int(dict["sdclampedsoft"]) if ("sdclampedsoft" in dict and dict["sdclampedsoft"]) else 0)
         sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
         sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
+        sd_flash_attention_var.set(1 if ("sdflashattention" in dict and dict["sdflashattention"]) else 0)
+        sd_convdirect_var.set(sd_convdirect_option(dict.get("sdconvdirect")))
         sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
         sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
         sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
@@ -7600,6 +7639,8 @@ def range_checker(arg: str):
     sdparsergroup.add_argument("--sdclipl", metavar=('[filename]'), help="Specify a Clip-L safetensors model for use in SD3 or Flux. Leave blank if prebaked or unused.", default="")
     sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
     sdparsergroup.add_argument("--sdphotomaker", metavar=('[filename]'), help="PhotoMaker is a model that allows face cloning. Specify a PhotoMaker safetensors model which will be applied replacing img2img. SDXL models only. Leave blank if unused.", default="")
+    sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true')
+    sdparsergroup.add_argument("--sdconvdirect", help="Enables Conv2D Direct. May improve performance or reduce memory usage. Might crash if not supported by the backend. Can be 'off' (default) to disable, 'full' to turn it on for all operations, or 'vaeonly' to enable only for the VAE.", type=sd_convdirect_option, choices=sd_convdirect_choices, default=sd_convdirect_choices[0])
     sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
     sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
     sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')
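One detail worth noting in the --sdconvdirect definition above: argparse runs the type= converter before validating against choices=, so aliases such as 'vae' or 'disable' are normalized by sd_convdirect_option first and only the canonical value is checked. A self-contained sketch of that pattern (a condensed restatement, not the actual KoboldCpp parser):

import argparse

def sd_convdirect_option(value):
    # Condensed restatement of the helper added above.
    value = (value or '').lower()
    if value in ['disabled', 'disable', 'none', 'off', '0', '']:
        return 'off'
    if value in ['vae', 'vaeonly']:
        return 'vaeonly'
    if value in ['enabled', 'enable', 'on', 'full']:
        return 'full'
    raise argparse.ArgumentTypeError(f"Invalid sdconvdirect option \"{value}\".")

parser = argparse.ArgumentParser()
parser.add_argument("--sdflashattention", action='store_true')
parser.add_argument("--sdconvdirect", type=sd_convdirect_option,
                    choices=['off', 'vaeonly', 'full'], default='off')
args = parser.parse_args(["--sdflashattention", "--sdconvdirect", "vae"])
print(args.sdconvdirect, args.sdflashattention)  # vaeonly True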

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 14 additions & 0 deletions
@@ -99,6 +99,8 @@ struct SDParams {
     bool clip_on_cpu = false;
     bool vae_on_cpu = false;
     bool diffusion_flash_attn = false;
+    bool diffusion_conv_direct = false;
+    bool vae_conv_direct = false;
     bool canny_preprocess = false;
     bool color = false;
     int upscale_repeats = 1;
@@ -211,6 +213,14 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
     {
         printf("Flash Attention is enabled\n");
     }
+    if(inputs.diffusion_conv_direct)
+    {
+        printf("Conv2D Direct for diffusion model is enabled\n");
+    }
+    if(inputs.vae_conv_direct)
+    {
+        printf("Conv2D Direct for VAE model is enabled\n");
+    }
     if(inputs.quant)
     {
         printf("Note: Loading a pre-quantized model is always faster than using compress weights!\n");
@@ -246,6 +256,8 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
     sd_params->wtype = (inputs.quant==0?SD_TYPE_COUNT:SD_TYPE_Q4_0);
     sd_params->n_threads = inputs.threads; //if -1 use physical cores
     sd_params->diffusion_flash_attn = inputs.flash_attention;
+    sd_params->diffusion_conv_direct = inputs.diffusion_conv_direct;
+    sd_params->vae_conv_direct = inputs.vae_conv_direct;
     sd_params->input_path = ""; //unused
     sd_params->batch_count = 1;
     sd_params->vae_path = vaefilename;
@@ -316,6 +328,8 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
     params.keep_control_net_on_cpu = sd_params->control_net_cpu;
     params.keep_vae_on_cpu = sd_params->vae_on_cpu;
     params.diffusion_flash_attn = sd_params->diffusion_flash_attn;
+    params.diffusion_conv_direct = sd_params->diffusion_conv_direct;
+    params.vae_conv_direct = sd_params->vae_conv_direct;
     params.chroma_use_dit_mask = sd_params->chroma_use_dit_mask;
     params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask;
     params.chroma_t5_mask_pad = sd_params->chroma_t5_mask_pad;
