Skip to content

Commit df2a43f

Browse files
committed
Add config option for Conv2D Direct
1 parent 7444c3c commit df2a43f

File tree

3 files changed

+52
-1
lines changed

(summary repeated above: 3 files changed, +52 −1 lines changed)

expose.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ struct sd_load_model_inputs
166166
const int threads = 0;
167167
const int quant = 0;
168168
const bool flash_attention = false;
169+
const bool diffusion_conv_direct = false;
170+
const bool vae_conv_direct = false;
169171
const bool taesd = false;
170172
const int tiled_vae_threshold = 0;
171173
const char * t5xxl_filename = nullptr;

koboldcpp.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ class sd_load_model_inputs(ctypes.Structure):
280280
("threads", ctypes.c_int),
281281
("quant", ctypes.c_int),
282282
("flash_attention", ctypes.c_bool),
283+
("diffusion_conv_direct", ctypes.c_bool),
284+
("vae_conv_direct", ctypes.c_bool),
283285
("taesd", ctypes.c_bool),
284286
("tiled_vae_threshold", ctypes.c_int),
285287
("t5xxl_filename", ctypes.c_char_p),
@@ -1637,6 +1639,19 @@ def generate(genparams, stream_flag=False):
16371639
outstr = outstr[:sindex]
16381640
return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}
16391641

1642+
sd_convdirect_choices = ['disabled', 'vaeonly', 'enabled']
1643+
1644+
def sd_convdirect_option(value):
1645+
if not value:
1646+
value = ''
1647+
value = value.lower()
1648+
if value in ['disabled', 'disable', 'none', 'off', '0', '']:
1649+
return 'disabled'
1650+
elif value in ['vae', 'vaeonly']:
1651+
return 'vaeonly'
1652+
elif value in ['enabled', 'enable', 'on', 'full']:
1653+
return 'enabled'
1654+
raise argparse.ArgumentTypeError(f"Invalid sdconvdirect option \"{value}\". Must be one of {sd_convdirect_choices}.")
16401655

16411656
def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl_filename,clipg_filename,photomaker_filename):
16421657
global args
@@ -1655,6 +1670,9 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
16551670
inputs.threads = thds
16561671
inputs.quant = quant
16571672
inputs.flash_attention = args.sdflashattention
1673+
sdconvdirect = sd_convdirect_option(args.sdconvdirect)
1674+
inputs.diffusion_conv_direct = sdconvdirect == 'enabled'
1675+
inputs.vae_conv_direct = sdconvdirect in ['vaeonly', 'enabled']
16581676
inputs.taesd = True if args.sdvaeauto else False
16591677
inputs.tiled_vae_threshold = args.sdtiledvae
16601678
inputs.vae_filename = vae_filename.encode("UTF-8")
@@ -4571,6 +4589,7 @@ def hide_tooltip(event):
45714589
sd_flash_attention_var = ctk.IntVar(value=0)
45724590
sd_vaeauto_var = ctk.IntVar(value=0)
45734591
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
4592+
sd_convdirect_var = ctk.StringVar(value='disabled')
45744593
sd_clamped_var = ctk.StringVar(value="0")
45754594
sd_clamped_soft_var = ctk.StringVar(value="0")
45764595
sd_threads_var = ctk.StringVar(value=str(default_threads))
@@ -4635,6 +4654,18 @@ def makecheckbox(parent, text, variable=None, row=0, column=0, command=None, pad
46354654
temp.bind("<Leave>", hide_tooltip)
46364655
return temp
46374656

4657+
def makelabelcombobox(parent, text, variable=None, row=0, width=50, command=None, padx=8,tooltiptxt="", values=[], labelpadx=8):
4658+
label = makelabel(parent, text, row, 0, tooltiptxt, padx=labelpadx)
4659+
label=None
4660+
combo = ctk.CTkComboBox(parent, variable=variable, width=width, values=values, state="readonly")
4661+
if command is not None and variable is not None:
4662+
variable.trace_add("write", command)
4663+
combo.grid(row=row,column=0, padx=padx, sticky="nw")
4664+
if tooltiptxt!="":
4665+
combo.bind("<Enter>", lambda event: show_tooltip(event, tooltiptxt))
4666+
combo.bind("<Leave>", hide_tooltip)
4667+
return combo, label
4668+
46384669
def makelabel(parent, text, row, column=0, tooltiptxt="", columnspan=1, padx=8):
46394670
temp = ctk.CTkLabel(parent, text=text)
46404671
temp.grid(row=row, column=column, padx=padx, pady=1, stick="nw", columnspan=columnspan)
@@ -5329,7 +5360,8 @@ def toggletaesd(a,b,c):
53295360
sdvaeitem1.grid()
53305361
sdvaeitem2.grid()
53315362
sdvaeitem3.grid()
5332-
makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
5363+
makecheckbox(images_tab, "TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
5364+
makelabelcombobox(images_tab, "Conv2D Direct:", sd_convdirect_var, row=42, labelpadx=220, padx=310, width=90, tooltiptxt="Use Conv2D Direct operation. May save memory or improve performance.\nMight crash if not supported by the backend.\n", values=sd_convdirect_choices)
53335365
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
53345366
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 46, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
53355367

@@ -5582,6 +5614,7 @@ def export_vars():
55825614
args.sdvae = ""
55835615
if sd_vae_var.get() != "":
55845616
args.sdvae = sd_vae_var.get()
5617+
args.sdconvdirect = sd_convdirect_option(sd_convdirect_var.get())
55855618
if sd_t5xxl_var.get() != "":
55865619
args.sdt5xxl = sd_t5xxl_var.get()
55875620
if sd_clipl_var.get() != "":
@@ -5803,6 +5836,7 @@ def import_vars(dict):
58035836
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
58045837
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
58055838
sd_flash_attention_var.set(1 if ("sdflashattention" in dict and dict["sdflashattention"]) else 0)
5839+
sd_convdirect_var.set(sd_convdirect_option(dict.get("sdconvdirect")))
58065840
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
58075841
sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
58085842
sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
@@ -7606,6 +7640,7 @@ def range_checker(arg: str):
76067640
sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
76077641
sdparsergroup.add_argument("--sdphotomaker", metavar=('[filename]'), help="PhotoMaker is a model that allows face cloning. Specify a PhotoMaker safetensors model which will be applied replacing img2img. SDXL models only. Leave blank if unused.", default="")
76087642
sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true')
7643+
sdparsergroup.add_argument("--sdconvdirect", help="Enables Conv2D Direct. May improve performance or reduce memory usage. Might crash if not supported by the backend. Can be 'disabled' (default) to disable, 'enabled' to turn it on for all operations, or 'vaeonly' to enable only for the VAE.", type=sd_convdirect_option, choices=sd_convdirect_choices, default=sd_convdirect_choices[0])
76097644
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
76107645
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
76117646
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixes bad VAEs.", action='store_true')

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ struct SDParams {
9999
bool clip_on_cpu = false;
100100
bool vae_on_cpu = false;
101101
bool diffusion_flash_attn = false;
102+
bool diffusion_conv_direct = false;
103+
bool vae_conv_direct = false;
102104
bool canny_preprocess = false;
103105
bool color = false;
104106
int upscale_repeats = 1;
@@ -211,6 +213,14 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
211213
{
212214
printf("Flash Attention is enabled\n");
213215
}
216+
if(inputs.diffusion_conv_direct)
217+
{
218+
printf("Conv2D Direct for diffusion model is enabled\n");
219+
}
220+
if(inputs.vae_conv_direct)
221+
{
222+
printf("Conv2D Direct for VAE model is enabled\n");
223+
}
214224
if(inputs.quant)
215225
{
216226
printf("Note: Loading a pre-quantized model is always faster than using compress weights!\n");
@@ -246,6 +256,8 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
246256
sd_params->wtype = (inputs.quant==0?SD_TYPE_COUNT:SD_TYPE_Q4_0);
247257
sd_params->n_threads = inputs.threads; //if -1 use physical cores
248258
sd_params->diffusion_flash_attn = inputs.flash_attention;
259+
sd_params->diffusion_conv_direct = inputs.diffusion_conv_direct;
260+
sd_params->vae_conv_direct = inputs.vae_conv_direct;
249261
sd_params->input_path = ""; //unused
250262
sd_params->batch_count = 1;
251263
sd_params->vae_path = vaefilename;
@@ -316,6 +328,8 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
316328
params.keep_control_net_on_cpu = sd_params->control_net_cpu;
317329
params.keep_vae_on_cpu = sd_params->vae_on_cpu;
318330
params.diffusion_flash_attn = sd_params->diffusion_flash_attn;
331+
params.diffusion_conv_direct = sd_params->diffusion_conv_direct;
332+
params.vae_conv_direct = sd_params->vae_conv_direct;
319333
params.chroma_use_dit_mask = sd_params->chroma_use_dit_mask;
320334
params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask;
321335
params.chroma_t5_mask_pad = sd_params->chroma_t5_mask_pad;

0 commit comments

Comments
 (0)