Skip to content

Commit e109929

Browse files
committed
Add config for Conv2D Direct for the VAE
1 parent e861feb commit e109929

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

expose.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -166,6 +166,7 @@ struct sd_load_model_inputs
166166
const int threads = 0;
167167
const int quant = 0;
168168
const bool flash_attention = false;
169+
const bool vae_conv_direct = false;
169170
const bool taesd = false;
170171
const int tiled_vae_threshold = 0;
171172
const char * t5xxl_filename = nullptr;

koboldcpp.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -280,6 +280,7 @@ class sd_load_model_inputs(ctypes.Structure):
280280
("threads", ctypes.c_int),
281281
("quant", ctypes.c_int),
282282
("flash_attention", ctypes.c_bool),
283+
("vae_conv_direct", ctypes.c_bool),
283284
("taesd", ctypes.c_bool),
284285
("tiled_vae_threshold", ctypes.c_int),
285286
("t5xxl_filename", ctypes.c_char_p),
@@ -1654,6 +1655,7 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
16541655
inputs.threads = thds
16551656
inputs.quant = quant
16561657
inputs.flash_attention = args.sdflashattention
1658+
inputs.vae_conv_direct = args.sdvaeconvdir
16571659
inputs.taesd = True if args.sdvaeauto else False
16581660
inputs.tiled_vae_threshold = args.sdtiledvae
16591661
inputs.vae_filename = vae_filename.encode("UTF-8")
@@ -4569,6 +4571,7 @@ def hide_tooltip(event):
45694571
sd_flash_attention_var = ctk.IntVar(value=0)
45704572
sd_vaeauto_var = ctk.IntVar(value=0)
45714573
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
4574+
sd_vae_convdir_var = ctk.IntVar(value=0)
45724575
sd_clamped_var = ctk.StringVar(value="0")
45734576
sd_clamped_soft_var = ctk.StringVar(value="0")
45744577
sd_threads_var = ctk.StringVar(value=str(default_threads))
@@ -5327,7 +5330,8 @@ def toggletaesd(a,b,c):
53275330
sdvaeitem1.grid()
53285331
sdvaeitem2.grid()
53295332
sdvaeitem3.grid()
5330-
makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
5333+
makecheckbox(images_tab, "TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
5334+
makecheckbox(images_tab, "Conv2D Direct for VAE", sd_vae_convdir_var, 42, padx=220, tooltiptxt="Enable Conv2D Direct for VAE. Saves memory and improves performance.\nMight crash if not supported by the backend.")
53315335
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
53325336
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 48, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
53335337

@@ -5580,6 +5584,8 @@ def export_vars():
55805584
args.sdvae = ""
55815585
if sd_vae_var.get() != "":
55825586
args.sdvae = sd_vae_var.get()
5587+
if sd_vae_convdir_var.get()==1:
5588+
args.sdvaeconvdir = True
55835589
if sd_t5xxl_var.get() != "":
55845590
args.sdt5xxl = sd_t5xxl_var.get()
55855591
if sd_clipl_var.get() != "":
@@ -5808,6 +5814,7 @@ def import_vars(dict):
58085814
sd_photomaker_var.set(dict["sdphotomaker"] if ("sdphotomaker" in dict and dict["sdphotomaker"]) else "")
58095815
sd_vaeauto_var.set(1 if ("sdvaeauto" in dict and dict["sdvaeauto"]) else 0)
58105816
sd_tiled_vae_var.set(str(dict["sdtiledvae"]) if ("sdtiledvae" in dict and dict["sdtiledvae"]) else str(default_vae_tile_threshold))
5817+
sd_vae_convdir_var.set(1 if ("sdvaeconvdir" in dict and dict["sdvaeconvdir"]) else 0)
58115818

58125819
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
58135820
sd_loramult_var.set(str(dict["sdloramult"]) if ("sdloramult" in dict and dict["sdloramult"]) else "1.0")
@@ -7606,6 +7613,7 @@ def range_checker(arg: str):
76067613
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
76077614
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
76087615
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')
7616+
sdparsergroupvae.add_argument("--sdvaeconvdir", help="Enables Conv2D Direct for the image diffusion model. Should improve performance and reduce memory usage. Might crash if not supported by the backend.", action='store_true')
76097617
sdparsergrouplora = sdparsergroup.add_mutually_exclusive_group()
76107618
sdparsergrouplora.add_argument("--sdquant", help="If specified, loads the model quantized to save memory.", action='store_true')
76117619
sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify an image generation LORA safetensors model to be applied.", default="")

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,7 @@ struct SDParams {
9999
bool clip_on_cpu = false;
100100
bool vae_on_cpu = false;
101101
bool diffusion_flash_attn = false;
102+
bool vae_conv_direct = false;
102103
bool canny_preprocess = false;
103104
bool color = false;
104105
int upscale_repeats = 1;
@@ -211,6 +212,10 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
211212
{
212213
printf("Flash Attention is enabled\n");
213214
}
215+
if(inputs.vae_conv_direct)
216+
{
217+
printf("Conv2D Direct for VAE model is enabled\n");
218+
}
214219
if(inputs.quant)
215220
{
216221
printf("Note: Loading a pre-quantized model is always faster than using compress weights!\n");
@@ -246,6 +251,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
246251
sd_params->wtype = (inputs.quant==0?SD_TYPE_COUNT:SD_TYPE_Q4_0);
247252
sd_params->n_threads = inputs.threads; //if -1 use physical cores
248253
sd_params->diffusion_flash_attn = inputs.flash_attention;
254+
sd_params->vae_conv_direct = inputs.vae_conv_direct;
249255
sd_params->input_path = ""; //unused
250256
sd_params->batch_count = 1;
251257
sd_params->vae_path = vaefilename;
@@ -316,6 +322,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
316322
params.keep_control_net_on_cpu = sd_params->control_net_cpu;
317323
params.keep_vae_on_cpu = sd_params->vae_on_cpu;
318324
params.diffusion_flash_attn = sd_params->diffusion_flash_attn;
325+
params.vae_conv_direct = sd_params->vae_conv_direct;
319326
params.chroma_use_dit_mask = sd_params->chroma_use_dit_mask;
320327
params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask;
321328
params.chroma_t5_mask_pad = sd_params->chroma_t5_mask_pad;

0 commit comments

Comments
 (0)