Skip to content

Commit e69878c

Browse files
committed
Add config option for Conv2D Direct
1 parent 86e79d2 commit e69878c

File tree

3 files changed

+53
-1
lines changed

3 files changed

+53
-1
lines changed

expose.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ struct sd_load_model_inputs
166166
const int threads = 0;
167167
const int quant = 0;
168168
const bool flash_attention = false;
169+
const bool diffusion_conv_direct = false;
170+
const bool vae_conv_direct = false;
169171
const bool taesd = false;
170172
const int tiled_vae_threshold = 0;
171173
const char * t5xxl_filename = nullptr;

koboldcpp.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ class sd_load_model_inputs(ctypes.Structure):
280280
("threads", ctypes.c_int),
281281
("quant", ctypes.c_int),
282282
("flash_attention", ctypes.c_bool),
283+
("diffusion_conv_direct", ctypes.c_bool),
284+
("vae_conv_direct", ctypes.c_bool),
283285
("taesd", ctypes.c_bool),
284286
("tiled_vae_threshold", ctypes.c_int),
285287
("t5xxl_filename", ctypes.c_char_p),
@@ -1636,6 +1638,19 @@ def generate(genparams, stream_flag=False):
16361638
outstr = outstr[:sindex]
16371639
return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}
16381640

1641+
sd_convdirect_choices = ['disabled', 'vaeonly', 'enabled']
1642+
1643+
def sd_convdirect_option(value):
    """Normalize a --sdconvdirect option string to its canonical form.

    Accepts several aliases per mode (case-insensitive; None/empty means
    disabled) and returns one of 'disabled', 'vaeonly', or 'enabled'.
    Raises argparse.ArgumentTypeError for anything unrecognized, so it can
    be used directly as an argparse `type=` callable.
    """
    normalized = (value or '').lower()
    # Canonical name -> accepted aliases (checked in order).
    alias_table = {
        'disabled': ('disabled', 'disable', 'none', 'off', '0', ''),
        'vaeonly': ('vae', 'vaeonly'),
        'enabled': ('enabled', 'enable', 'on', 'full'),
    }
    for canonical, aliases in alias_table.items():
        if normalized in aliases:
            return canonical
    raise argparse.ArgumentTypeError(f"Invalid sdconvdirect option \"{normalized}\". Must be one of {sd_convdirect_choices}.")
16391654

16401655
def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl_filename,clipg_filename,photomaker_filename):
16411656
global args
@@ -1654,6 +1669,9 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
16541669
inputs.threads = thds
16551670
inputs.quant = quant
16561671
inputs.flash_attention = args.sdflashattention
1672+
sdconvdirect = sd_convdirect_option(args.sdconvdirect)
1673+
inputs.diffusion_conv_direct = sdconvdirect == 'enabled'
1674+
inputs.vae_conv_direct = sdconvdirect in ['vaeonly', 'enabled']
16571675
inputs.taesd = True if args.sdvaeauto else False
16581676
inputs.tiled_vae_threshold = args.sdtiledvae
16591677
inputs.vae_filename = vae_filename.encode("UTF-8")
@@ -4567,8 +4585,10 @@ def hide_tooltip(event):
45674585
sd_clipg_var = ctk.StringVar()
45684586
sd_photomaker_var = ctk.StringVar()
45694587
sd_flash_attention_var = ctk.IntVar(value=0)
4588+
sd_diffusion_convdir_var = ctk.IntVar(value=0)
45704589
sd_vaeauto_var = ctk.IntVar(value=0)
45714590
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
4591+
sd_convdirect_var = ctk.StringVar(value='disabled')
45724592
sd_clamped_var = ctk.StringVar(value="0")
45734593
sd_clamped_soft_var = ctk.StringVar(value="0")
45744594
sd_threads_var = ctk.StringVar(value=str(default_threads))
@@ -4633,6 +4653,18 @@ def makecheckbox(parent, text, variable=None, row=0, column=0, command=None, pad
46334653
temp.bind("<Leave>", hide_tooltip)
46344654
return temp
46354655

4656+
def makelabelcombobox(parent, text, variable=None, row=0, width=50, command=None, padx=8, tooltiptxt="", values=None, labelpadx=8):
    """Create a labeled read-only combobox in the GUI grid.

    Builds a label via makelabel() and a CTkComboBox bound to `variable`,
    optionally wiring `command` as a write-trace on the variable and
    attaching show/hide tooltip handlers when `tooltiptxt` is non-empty.
    Returns (combo, None) — the label handle is deliberately not returned,
    matching the original (it assigned the label and then set it to None).
    `values` defaults to None instead of a shared mutable [] (fixes the
    mutable-default-argument pitfall); None is treated as an empty list.
    """
    values = [] if values is None else values
    # Label handle intentionally dropped; original code overwrote it with None.
    makelabel(parent, text, row, 0, tooltiptxt, padx=labelpadx)
    combo = ctk.CTkComboBox(parent, variable=variable, width=width, values=values, state="readonly")
    if command is not None and variable is not None:
        variable.trace_add("write", command)
    combo.grid(row=row, column=0, padx=padx, sticky="nw")
    if tooltiptxt != "":
        combo.bind("<Enter>", lambda event: show_tooltip(event, tooltiptxt))
        combo.bind("<Leave>", hide_tooltip)
    return combo, None
4667+
46364668
def makelabel(parent, text, row, column=0, tooltiptxt="", columnspan=1, padx=8):
46374669
temp = ctk.CTkLabel(parent, text=text)
46384670
temp.grid(row=row, column=column, padx=padx, pady=1, stick="nw", columnspan=columnspan)
@@ -5327,7 +5359,8 @@ def toggletaesd(a,b,c):
53275359
sdvaeitem1.grid()
53285360
sdvaeitem2.grid()
53295361
sdvaeitem3.grid()
5330-
makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
5362+
makecheckbox(images_tab, "TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
5363+
makelabelcombobox(images_tab, "Conv2D Direct:", sd_convdirect_var, row=42, labelpadx=220, padx=310, width=90, tooltiptxt="Use Conv2D Direct operation. May save memory or improve performance.\nMight crash if not supported by the backend.\n", values=sd_convdirect_choices)
53315364
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
53325365
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 46, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
53335366

@@ -5580,6 +5613,7 @@ def export_vars():
55805613
args.sdvae = ""
55815614
if sd_vae_var.get() != "":
55825615
args.sdvae = sd_vae_var.get()
5616+
args.sdconvdirect = sd_convdirect_option(sd_convdirect_var.get())
55835617
if sd_t5xxl_var.get() != "":
55845618
args.sdt5xxl = sd_t5xxl_var.get()
55855619
if sd_clipl_var.get() != "":
@@ -5801,6 +5835,7 @@ def import_vars(dict):
58015835
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
58025836
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
58035837
sd_flash_attention_var.set(1 if ("sdflashattention" in dict and dict["sdflashattention"]) else 0)
5838+
sd_convdirect_var.set(sd_convdirect_option(dict.get("sdconvdirect")))
58045839
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
58055840
sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
58065841
sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
@@ -7603,6 +7638,7 @@ def range_checker(arg: str):
76037638
sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
76047639
sdparsergroup.add_argument("--sdphotomaker", metavar=('[filename]'), help="PhotoMaker is a model that allows face cloning. Specify a PhotoMaker safetensors model which will be applied replacing img2img. SDXL models only. Leave blank if unused.", default="")
76057640
sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true')
7641+
sdparsergroup.add_argument("--sdconvdirect", help="Enables Conv2D Direct. May improve performance or reduce memory usage. Might crash if not supported by the backend. Can be 'disabled' (default) to disable, 'enabled' to turn it on for all operations, or 'vaeonly' to enable only for the VAE.", type=sd_convdirect_option, choices=sd_convdirect_choices, default=sd_convdirect_choices[0])
76067642
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
76077643
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
76087644
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixes bad VAEs.", action='store_true')

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ struct SDParams {
9999
bool clip_on_cpu = false;
100100
bool vae_on_cpu = false;
101101
bool diffusion_flash_attn = false;
102+
bool diffusion_conv_direct = false;
103+
bool vae_conv_direct = false;
102104
bool canny_preprocess = false;
103105
bool color = false;
104106
int upscale_repeats = 1;
@@ -211,6 +213,14 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
211213
{
212214
printf("Flash Attention is enabled\n");
213215
}
216+
if(inputs.diffusion_conv_direct)
217+
{
218+
printf("Conv2D Direct for diffusion model is enabled\n");
219+
}
220+
if(inputs.vae_conv_direct)
221+
{
222+
printf("Conv2D Direct for VAE model is enabled\n");
223+
}
214224
if(inputs.quant)
215225
{
216226
printf("Note: Loading a pre-quantized model is always faster than using compress weights!\n");
@@ -246,6 +256,8 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
246256
sd_params->wtype = (inputs.quant==0?SD_TYPE_COUNT:SD_TYPE_Q4_0);
247257
sd_params->n_threads = inputs.threads; //if -1 use physical cores
248258
sd_params->diffusion_flash_attn = inputs.flash_attention;
259+
sd_params->diffusion_conv_direct = inputs.diffusion_conv_direct;
260+
sd_params->vae_conv_direct = inputs.vae_conv_direct;
249261
sd_params->input_path = ""; //unused
250262
sd_params->batch_count = 1;
251263
sd_params->vae_path = vaefilename;
@@ -316,6 +328,8 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
316328
params.keep_control_net_on_cpu = sd_params->control_net_cpu;
317329
params.keep_vae_on_cpu = sd_params->vae_on_cpu;
318330
params.diffusion_flash_attn = sd_params->diffusion_flash_attn;
331+
params.diffusion_conv_direct = sd_params->diffusion_conv_direct;
332+
params.vae_conv_direct = sd_params->vae_conv_direct;
319333
params.chroma_use_dit_mask = sd_params->chroma_use_dit_mask;
320334
params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask;
321335
params.chroma_t5_mask_pad = sd_params->chroma_t5_mask_pad;

0 commit comments

Comments
 (0)