Skip to content

Commit e861feb

Browse files
committed
Add separate flash attention config for image generation
1 parent d876898 commit e861feb

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

koboldcpp.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1653,7 +1653,7 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
16531653

16541654
inputs.threads = thds
16551655
inputs.quant = quant
1656-
inputs.flash_attention = args.flashattention
1656+
inputs.flash_attention = args.sdflashattention
16571657
inputs.taesd = True if args.sdvaeauto else False
16581658
inputs.tiled_vae_threshold = args.sdtiledvae
16591659
inputs.vae_filename = vae_filename.encode("UTF-8")
@@ -4566,6 +4566,7 @@ def hide_tooltip(event):
45664566
sd_clipl_var = ctk.StringVar()
45674567
sd_clipg_var = ctk.StringVar()
45684568
sd_photomaker_var = ctk.StringVar()
4569+
sd_flash_attention_var = ctk.IntVar(value=0)
45694570
sd_vaeauto_var = ctk.IntVar(value=0)
45704571
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
45714572
sd_clamped_var = ctk.StringVar(value="0")
@@ -5328,6 +5329,7 @@ def toggletaesd(a,b,c):
53285329
sdvaeitem3.grid()
53295330
makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
53305331
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
5332+
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 48, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
53315333

53325334
# audio tab
53335335
audio_tab = tabcontent["Audio"]
@@ -5564,6 +5566,8 @@ def export_vars():
55645566
if sd_model_var.get() != "":
55655567
args.sdmodel = sd_model_var.get()
55665568

5569+
if sd_flash_attention_var.get()==1:
5570+
args.sdflashattention = True
55675571
args.sdthreads = (0 if sd_threads_var.get()=="" else int(sd_threads_var.get()))
55685572
args.sdclamped = (0 if int(sd_clamped_var.get())<=0 else int(sd_clamped_var.get()))
55695573
args.sdclampedsoft = (0 if int(sd_clamped_soft_var.get())<=0 else int(sd_clamped_soft_var.get()))
@@ -5796,6 +5800,7 @@ def import_vars(dict):
57965800
sd_clamped_soft_var.set(int(dict["sdclampedsoft"]) if ("sdclampedsoft" in dict and dict["sdclampedsoft"]) else 0)
57975801
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
57985802
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
5803+
sd_flash_attention_var.set(1 if ("sdflashattention" in dict and dict["sdflashattention"]) else 0)
57995804
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
58005805
sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
58015806
sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
@@ -7597,6 +7602,7 @@ def range_checker(arg: str):
75977602
sdparsergroup.add_argument("--sdclipl", metavar=('[filename]'), help="Specify a Clip-L safetensors model for use in SD3 or Flux. Leave blank if prebaked or unused.", default="")
75987603
sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
75997604
sdparsergroup.add_argument("--sdphotomaker", metavar=('[filename]'), help="PhotoMaker is a model that allows face cloning. Specify a PhotoMaker safetensors model which will be applied replacing img2img. SDXL models only. Leave blank if unused.", default="")
7605+
sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true')
76007606
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
76017607
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
76027608
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')

0 commit comments

Comments
 (0)