Skip to content

Commit 7fedd50

Browse files
committed
Add separate flash attention config for image generation
1 parent 419dd92 commit 7fedd50

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

koboldcpp.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1645,7 +1645,7 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
16451645

16461646
inputs.threads = thds
16471647
inputs.quant = quant
1648-
inputs.flash_attention = args.flashattention
1648+
inputs.flash_attention = args.sdflashattention
16491649
inputs.taesd = True if args.sdvaeauto else False
16501650
inputs.tiled_vae_threshold = args.sdtiledvae
16511651
inputs.vae_filename = vae_filename.encode("UTF-8")
@@ -4556,6 +4556,7 @@ def hide_tooltip(event):
45564556
sd_clipl_var = ctk.StringVar()
45574557
sd_clipg_var = ctk.StringVar()
45584558
sd_photomaker_var = ctk.StringVar()
4559+
sd_flash_attention_var = ctk.IntVar(value=0)
45594560
sd_vaeauto_var = ctk.IntVar(value=0)
45604561
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
45614562
sd_clamped_var = ctk.StringVar(value="0")
@@ -5303,6 +5304,7 @@ def toggletaesd(a,b,c):
53035304
sdvaeitem3.grid()
53045305
makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
53055306
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
5307+
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 48, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
53065308

53075309
# audio tab
53085310
audio_tab = tabcontent["Audio"]
@@ -5533,6 +5535,8 @@ def export_vars():
55335535
if sd_model_var.get() != "":
55345536
args.sdmodel = sd_model_var.get()
55355537

5538+
if sd_flash_attention_var.get()==1:
5539+
args.sdflashattention = True
55365540
args.sdthreads = (0 if sd_threads_var.get()=="" else int(sd_threads_var.get()))
55375541
args.sdclamped = (0 if int(sd_clamped_var.get())<=0 else int(sd_clamped_var.get()))
55385542
args.sdclampedsoft = (0 if int(sd_clamped_soft_var.get())<=0 else int(sd_clamped_soft_var.get()))
@@ -5754,6 +5758,7 @@ def import_vars(dict):
57545758
sd_clamped_soft_var.set(int(dict["sdclampedsoft"]) if ("sdclampedsoft" in dict and dict["sdclampedsoft"]) else 0)
57555759
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
57565760
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
5761+
sd_flash_attention_var.set(1 if ("sdflashattention" in dict and dict["sdflashattention"]) else 0)
57575762
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
57585763
sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
57595764
sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
@@ -7555,6 +7560,7 @@ def range_checker(arg: str):
75557560
sdparsergroup.add_argument("--sdclipl", metavar=('[filename]'), help="Specify a Clip-L safetensors model for use in SD3 or Flux. Leave blank if prebaked or unused.", default="")
75567561
sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
75577562
sdparsergroup.add_argument("--sdphotomaker", metavar=('[filename]'), help="PhotoMaker is a model that allows face cloning. Specify a PhotoMaker safetensors model which will be applied replacing img2img. SDXL models only. Leave blank if unused.", default="")
7563+
sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true')
75587564
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
75597565
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
75607566
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')

0 commit comments

Comments
 (0)