Skip to content

Commit d485089

Browse files
committed
Use a single config option for Conv2D Direct
1 parent 62043df commit d485089

File tree

1 file changed

+34
-14
lines changed

1 file changed

+34
-14
lines changed

koboldcpp.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,6 +1638,19 @@ def generate(genparams, stream_flag=False):
16381638
outstr = outstr[:sindex]
16391639
return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}
16401640

1641+
sd_convdirect_choices = ['disabled', 'vaeonly', 'enabled']
1642+
1643+
def sd_convdirect_option(value):
1644+
if not value:
1645+
value = ''
1646+
value = value.lower()
1647+
if value in ['disabled', 'disable', 'none', 'off', '0', '']:
1648+
return 'disabled'
1649+
elif value in ['vae', 'vaeonly']:
1650+
return 'vaeonly'
1651+
elif value in ['enabled', 'enable', 'on', 'full']:
1652+
return 'enabled'
1653+
raise argparse.ArgumentTypeError(f"Invalid sdconvdirect option \"{value}\". Must be one of {sd_convdirect_choices}.")
16411654

16421655
def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl_filename,clipg_filename,photomaker_filename):
16431656
global args
@@ -1656,8 +1669,9 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
16561669
inputs.threads = thds
16571670
inputs.quant = quant
16581671
inputs.flash_attention = args.sdflashattention
1659-
inputs.diffusion_conv_direct = args.sddiffusionconvdir
1660-
inputs.vae_conv_direct = args.sdvaeconvdir
1672+
sdconvdirect = sd_convdirect_option(args.sdconvdirect)
1673+
inputs.diffusion_conv_direct = sdconvdirect == 'enabled'
1674+
inputs.vae_conv_direct = sdconvdirect in ['vaeonly', 'enabled']
16611675
inputs.taesd = True if args.sdvaeauto else False
16621676
inputs.tiled_vae_threshold = args.sdtiledvae
16631677
inputs.vae_filename = vae_filename.encode("UTF-8")
@@ -4574,7 +4588,7 @@ def hide_tooltip(event):
45744588
sd_diffusion_convdir_var = ctk.IntVar(value=0)
45754589
sd_vaeauto_var = ctk.IntVar(value=0)
45764590
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
4577-
sd_vae_convdir_var = ctk.IntVar(value=0)
4591+
sd_convdirect_var = ctk.StringVar(value='disabled')
45784592
sd_clamped_var = ctk.StringVar(value="0")
45794593
sd_clamped_soft_var = ctk.StringVar(value="0")
45804594
sd_threads_var = ctk.StringVar(value=str(default_threads))
@@ -4639,6 +4653,18 @@ def makecheckbox(parent, text, variable=None, row=0, column=0, command=None, pad
46394653
temp.bind("<Leave>", hide_tooltip)
46404654
return temp
46414655

4656+
def makelabelcombobox(parent, text, variable=None, row=0, width=50, command=None, padx=8,tooltiptxt="", values=[], labelpadx=8):
4657+
label = makelabel(parent, text, row, 0, tooltiptxt, padx=labelpadx)
4658+
label=None
4659+
combo = ctk.CTkComboBox(parent, variable=variable, width=width, values=values, state="readonly")
4660+
if command is not None and variable is not None:
4661+
variable.trace_add("write", command)
4662+
combo.grid(row=row,column=0, padx=padx, sticky="nw")
4663+
if tooltiptxt!="":
4664+
combo.bind("<Enter>", lambda event: show_tooltip(event, tooltiptxt))
4665+
combo.bind("<Leave>", hide_tooltip)
4666+
return combo, label
4667+
46424668
def makelabel(parent, text, row, column=0, tooltiptxt="", columnspan=1, padx=8):
46434669
temp = ctk.CTkLabel(parent, text=text)
46444670
temp.grid(row=row, column=column, padx=padx, pady=1, stick="nw", columnspan=columnspan)
@@ -5334,10 +5360,9 @@ def toggletaesd(a,b,c):
53345360
sdvaeitem2.grid()
53355361
sdvaeitem3.grid()
53365362
makecheckbox(images_tab, "TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
5337-
makecheckbox(images_tab, "Conv2D Direct for VAE", sd_vae_convdir_var, 42, padx=220, tooltiptxt="Enable Conv2D Direct for VAE. Saves memory and improves performance.\nMight crash if not supported by the backend.")
5363+
makelabelcombobox(images_tab, "Conv2D Direct:", sd_convdirect_var, row=42, labelpadx=220, padx=310, width=90, tooltiptxt="Use Conv2D Direct operation. May save memory or improve performance.\nMight crash if not supported by the backend.\n", values=sd_convdirect_choices)
53385364
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
5339-
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 48, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
5340-
makecheckbox(images_tab, "Conv2D Direct for Diffusion", sd_diffusion_convdir_var, 48, padx=220, tooltiptxt="Enable Conv2D Direct for diffusion. May save memory or improve performance.\nMight crash if not supported by the backend.")
5365+
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 46, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
53415366

53425367
# audio tab
53435368
audio_tab = tabcontent["Audio"]
@@ -5576,8 +5601,6 @@ def export_vars():
55765601

55775602
if sd_flash_attention_var.get()==1:
55785603
args.sdflashattention = True
5579-
if sd_diffusion_convdir_var.get()==1:
5580-
args.sddiffusionconvdir = True
55815604
args.sdthreads = (0 if sd_threads_var.get()=="" else int(sd_threads_var.get()))
55825605
args.sdclamped = (0 if int(sd_clamped_var.get())<=0 else int(sd_clamped_var.get()))
55835606
args.sdclampedsoft = (0 if int(sd_clamped_soft_var.get())<=0 else int(sd_clamped_soft_var.get()))
@@ -5590,8 +5613,7 @@ def export_vars():
55905613
args.sdvae = ""
55915614
if sd_vae_var.get() != "":
55925615
args.sdvae = sd_vae_var.get()
5593-
if sd_vae_convdir_var.get()==1:
5594-
args.sdvaeconvdir = True
5616+
args.sdconvdirect = sd_convdirect_option(sd_convdirect_var.get())
55955617
if sd_t5xxl_var.get() != "":
55965618
args.sdt5xxl = sd_t5xxl_var.get()
55975619
if sd_clipl_var.get() != "":
@@ -5813,15 +5835,14 @@ def import_vars(dict):
58135835
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
58145836
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
58155837
sd_flash_attention_var.set(1 if ("sdflashattention" in dict and dict["sdflashattention"]) else 0)
5816-
sd_diffusion_convdir_var.set(1 if ("sddiffusionconvdir" in dict and dict["sddiffusionconvdir"]) else 0)
5838+
sd_convdirect_var.set(sd_convdirect_option(dict.get("sdconvdirect")))
58175839
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
58185840
sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
58195841
sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
58205842
sd_clipg_var.set(dict["sdclipg"] if ("sdclipg" in dict and dict["sdclipg"]) else "")
58215843
sd_photomaker_var.set(dict["sdphotomaker"] if ("sdphotomaker" in dict and dict["sdphotomaker"]) else "")
58225844
sd_vaeauto_var.set(1 if ("sdvaeauto" in dict and dict["sdvaeauto"]) else 0)
58235845
sd_tiled_vae_var.set(str(dict["sdtiledvae"]) if ("sdtiledvae" in dict and dict["sdtiledvae"]) else str(default_vae_tile_threshold))
5824-
sd_vae_convdir_var.set(1 if ("sdvaeconvdir" in dict and dict["sdvaeconvdir"]) else 0)
58255846

58265847
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
58275848
sd_loramult_var.set(str(dict["sdloramult"]) if ("sdloramult" in dict and dict["sdloramult"]) else "1.0")
@@ -7617,11 +7638,10 @@ def range_checker(arg: str):
76177638
sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
76187639
sdparsergroup.add_argument("--sdphotomaker", metavar=('[filename]'), help="PhotoMaker is a model that allows face cloning. Specify a PhotoMaker safetensors model which will be applied replacing img2img. SDXL models only. Leave blank if unused.", default="")
76197640
sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true')
7620-
sdparsergroup.add_argument("--sddiffusionconvdir", help="Enables Conv2D Direct for the image diffusion model. May improve performance or reduce memory usage. Might crash if not supported by the backend.", action='store_true')
7641+
sdparsergroup.add_argument("--sdconvdirect", help="Enables Conv2D Direct. May improve performance or reduce memory usage. Might crash if not supported by the backend. Can be 'disabled' (default) to disable, 'enabled' to turn it on for all operations, or 'vaeonly' to enable only for the VAE.", type=sd_convdirect_option, choices=sd_convdirect_choices, default=sd_convdirect_choices[0])
76217642
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
76227643
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
76237644
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')
7624-
sdparsergroupvae.add_argument("--sdvaeconvdir", help="Enables Conv2D Direct for the image diffusion model. Should improve performance and reduce memory usage. Might crash if not supported by the backend.", action='store_true')
76257645
sdparsergrouplora = sdparsergroup.add_mutually_exclusive_group()
76267646
sdparsergrouplora.add_argument("--sdquant", help="If specified, loads the model quantized to save memory.", action='store_true')
76277647
sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify an image generation LORA safetensors model to be applied.", default="")

0 commit comments

Comments
 (0)