Skip to content

Commit 15a4f9f

Browse files
committed
Use a single config option for Conv2D Direct
1 parent 62043df commit 15a4f9f

File tree

1 file changed

+33
-14
lines changed

1 file changed

+33
-14
lines changed

koboldcpp.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,6 +1638,19 @@ def generate(genparams, stream_flag=False):
16381638
outstr = outstr[:sindex]
16391639
return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}
16401640

1641+
sd_convdirect_choices = ['disabled', 'vaeonly', 'enabled']
1642+
1643+
def sd_convdirect_option(value):
1644+
if not value:
1645+
value = ''
1646+
value = value.lower()
1647+
if value in ['disabled', 'disable', 'none', 'off', '0', '']:
1648+
return 'disabled'
1649+
elif value in ['vae', 'vaeonly']:
1650+
return 'vaeonly'
1651+
elif value in ['enabled', 'enable', 'on', 'full']:
1652+
return 'enabled'
1653+
raise argparse.ArgumentTypeError(f"Invalid sdconvdirect option \"{value}\". Must be one of {sd_convdirect_choices}.")
16411654

16421655
def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl_filename,clipg_filename,photomaker_filename):
16431656
global args
@@ -1656,8 +1669,8 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
16561669
inputs.threads = thds
16571670
inputs.quant = quant
16581671
inputs.flash_attention = args.sdflashattention
1659-
inputs.diffusion_conv_direct = args.sddiffusionconvdir
1660-
inputs.vae_conv_direct = args.sdvaeconvdir
1672+
inputs.diffusion_conv_direct = sdconvdirect == 'enabled'
1673+
inputs.vae_conv_direct = sdconvdirect in ['vaeonly', 'enabled']
16611674
inputs.taesd = True if args.sdvaeauto else False
16621675
inputs.tiled_vae_threshold = args.sdtiledvae
16631676
inputs.vae_filename = vae_filename.encode("UTF-8")
@@ -4574,7 +4587,7 @@ def hide_tooltip(event):
45744587
sd_diffusion_convdir_var = ctk.IntVar(value=0)
45754588
sd_vaeauto_var = ctk.IntVar(value=0)
45764589
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
4577-
sd_vae_convdir_var = ctk.IntVar(value=0)
4590+
sd_convdirect_var = ctk.StringVar(value='disabled')
45784591
sd_clamped_var = ctk.StringVar(value="0")
45794592
sd_clamped_soft_var = ctk.StringVar(value="0")
45804593
sd_threads_var = ctk.StringVar(value=str(default_threads))
@@ -4639,6 +4652,18 @@ def makecheckbox(parent, text, variable=None, row=0, column=0, command=None, pad
46394652
temp.bind("<Leave>", hide_tooltip)
46404653
return temp
46414654

4655+
def makelabelcombobox(parent, text, variable=None, row=0, width=50, command=None, padx=8,tooltiptxt="", values=[], labelpadx=8):
4656+
label = makelabel(parent, text, row, 0, tooltiptxt, padx=labelpadx)
4657+
label=None
4658+
combo = ctk.CTkComboBox(parent, variable=variable, width=width, values=values, state="readonly")
4659+
if command is not None and variable is not None:
4660+
variable.trace_add("write", command)
4661+
combo.grid(row=row,column=0, padx=padx, sticky="nw")
4662+
if tooltiptxt!="":
4663+
combo.bind("<Enter>", lambda event: show_tooltip(event, tooltiptxt))
4664+
combo.bind("<Leave>", hide_tooltip)
4665+
return combo, label
4666+
46424667
def makelabel(parent, text, row, column=0, tooltiptxt="", columnspan=1, padx=8):
46434668
temp = ctk.CTkLabel(parent, text=text)
46444669
temp.grid(row=row, column=column, padx=padx, pady=1, stick="nw", columnspan=columnspan)
@@ -5334,10 +5359,9 @@ def toggletaesd(a,b,c):
53345359
sdvaeitem2.grid()
53355360
sdvaeitem3.grid()
53365361
makecheckbox(images_tab, "TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
5337-
makecheckbox(images_tab, "Conv2D Direct for VAE", sd_vae_convdir_var, 42, padx=220, tooltiptxt="Enable Conv2D Direct for VAE. Saves memory and improves performance.\nMight crash if not supported by the backend.")
5362+
makelabelcombobox(images_tab, "Conv2D Direct:", sd_convdirect_var, row=42, labelpadx=220, padx=310, width=90, tooltiptxt="Use Conv2D Direct operation. May save memory or improve performance.\nMight crash if not supported by the backend.\n", values=sd_convdirect_choices)
53385363
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
5339-
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 48, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
5340-
makecheckbox(images_tab, "Conv2D Direct for Diffusion", sd_diffusion_convdir_var, 48, padx=220, tooltiptxt="Enable Conv2D Direct for diffusion. May save memory or improve performance.\nMight crash if not supported by the backend.")
5364+
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 46, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
53415365

53425366
# audio tab
53435367
audio_tab = tabcontent["Audio"]
@@ -5576,8 +5600,6 @@ def export_vars():
55765600

55775601
if sd_flash_attention_var.get()==1:
55785602
args.sdflashattention = True
5579-
if sd_diffusion_convdir_var.get()==1:
5580-
args.sddiffusionconvdir = True
55815603
args.sdthreads = (0 if sd_threads_var.get()=="" else int(sd_threads_var.get()))
55825604
args.sdclamped = (0 if int(sd_clamped_var.get())<=0 else int(sd_clamped_var.get()))
55835605
args.sdclampedsoft = (0 if int(sd_clamped_soft_var.get())<=0 else int(sd_clamped_soft_var.get()))
@@ -5590,8 +5612,7 @@ def export_vars():
55905612
args.sdvae = ""
55915613
if sd_vae_var.get() != "":
55925614
args.sdvae = sd_vae_var.get()
5593-
if sd_vae_convdir_var.get()==1:
5594-
args.sdvaeconvdir = True
5615+
args.sdconvdirect = sd_convdirect_option(sd_convdirect_var.get())
55955616
if sd_t5xxl_var.get() != "":
55965617
args.sdt5xxl = sd_t5xxl_var.get()
55975618
if sd_clipl_var.get() != "":
@@ -5813,15 +5834,14 @@ def import_vars(dict):
58135834
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
58145835
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
58155836
sd_flash_attention_var.set(1 if ("sdflashattention" in dict and dict["sdflashattention"]) else 0)
5816-
sd_diffusion_convdir_var.set(1 if ("sddiffusionconvdir" in dict and dict["sddiffusionconvdir"]) else 0)
5837+
sd_convdirect_var.set(sd_convdirect_option(dict.get("sdconvdirect")))
58175838
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
58185839
sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
58195840
sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
58205841
sd_clipg_var.set(dict["sdclipg"] if ("sdclipg" in dict and dict["sdclipg"]) else "")
58215842
sd_photomaker_var.set(dict["sdphotomaker"] if ("sdphotomaker" in dict and dict["sdphotomaker"]) else "")
58225843
sd_vaeauto_var.set(1 if ("sdvaeauto" in dict and dict["sdvaeauto"]) else 0)
58235844
sd_tiled_vae_var.set(str(dict["sdtiledvae"]) if ("sdtiledvae" in dict and dict["sdtiledvae"]) else str(default_vae_tile_threshold))
5824-
sd_vae_convdir_var.set(1 if ("sdvaeconvdir" in dict and dict["sdvaeconvdir"]) else 0)
58255845

58265846
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
58275847
sd_loramult_var.set(str(dict["sdloramult"]) if ("sdloramult" in dict and dict["sdloramult"]) else "1.0")
@@ -7617,11 +7637,10 @@ def range_checker(arg: str):
76177637
sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
76187638
sdparsergroup.add_argument("--sdphotomaker", metavar=('[filename]'), help="PhotoMaker is a model that allows face cloning. Specify a PhotoMaker safetensors model which will be applied replacing img2img. SDXL models only. Leave blank if unused.", default="")
76197639
sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true')
7620-
sdparsergroup.add_argument("--sddiffusionconvdir", help="Enables Conv2D Direct for the image diffusion model. May improve performance or reduce memory usage. Might crash if not supported by the backend.", action='store_true')
7640+
sdparsergroup.add_argument("--sdconvdirect", help="Enables Conv2D Direct. May improve performance or reduce memory usage. Might crash if not supported by the backend. Can be 'disabled' (default) to disable, 'enabled' to turn it on for all operations, or 'vaeonly' to enable only for the VAE.", type=sd_convdirect_option, choices=sd_convdirect_choices, default=sd_convdirect_choices[0])
76217641
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
76227642
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
76237643
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')
7624-
sdparsergroupvae.add_argument("--sdvaeconvdir", help="Enables Conv2D Direct for the image diffusion model. Should improve performance and reduce memory usage. Might crash if not supported by the backend.", action='store_true')
76257644
sdparsergrouplora = sdparsergroup.add_mutually_exclusive_group()
76267645
sdparsergrouplora.add_argument("--sdquant", help="If specified, loads the model quantized to save memory.", action='store_true')
76277646
sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify an image generation LORA safetensors model to be applied.", default="")

0 commit comments

Comments
 (0)