Skip to content

Commit 544b97e

Browse files
committed
Add config for Conv2D Direct for the VAE
1 parent 7fedd50 commit 544b97e

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

expose.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ struct sd_load_model_inputs
165165
const int threads = 0;
166166
const int quant = 0;
167167
const bool flash_attention = false;
168+
const bool vae_conv_direct = false;
168169
const bool taesd = false;
169170
const int tiled_vae_threshold = 0;
170171
const char * t5xxl_filename = nullptr;

koboldcpp.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ class sd_load_model_inputs(ctypes.Structure):
278278
("threads", ctypes.c_int),
279279
("quant", ctypes.c_int),
280280
("flash_attention", ctypes.c_bool),
281+
("vae_conv_direct", ctypes.c_bool),
281282
("taesd", ctypes.c_bool),
282283
("tiled_vae_threshold", ctypes.c_int),
283284
("t5xxl_filename", ctypes.c_char_p),
@@ -1646,6 +1647,7 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
16461647
inputs.threads = thds
16471648
inputs.quant = quant
16481649
inputs.flash_attention = args.sdflashattention
1650+
inputs.vae_conv_direct = args.sdvaeconvdir
16491651
inputs.taesd = True if args.sdvaeauto else False
16501652
inputs.tiled_vae_threshold = args.sdtiledvae
16511653
inputs.vae_filename = vae_filename.encode("UTF-8")
@@ -4559,6 +4561,7 @@ def hide_tooltip(event):
45594561
sd_flash_attention_var = ctk.IntVar(value=0)
45604562
sd_vaeauto_var = ctk.IntVar(value=0)
45614563
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
4564+
sd_vae_convdir_var = ctk.IntVar(value=0)
45624565
sd_clamped_var = ctk.StringVar(value="0")
45634566
sd_clamped_soft_var = ctk.StringVar(value="0")
45644567
sd_threads_var = ctk.StringVar(value=str(default_threads))
@@ -5302,7 +5305,8 @@ def toggletaesd(a,b,c):
53025305
sdvaeitem1.grid()
53035306
sdvaeitem2.grid()
53045307
sdvaeitem3.grid()
5305-
makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
5308+
makecheckbox(images_tab, "TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
5309+
makecheckbox(images_tab, "Conv2D Direct for VAE", sd_vae_convdir_var, 42, padx=220, tooltiptxt="Enable Conv2D Direct for VAE. Saves memory and improves performance.\nMight crash if not supported by the backend.")
53065310
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
53075311
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 48, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
53085312

@@ -5549,6 +5553,8 @@ def export_vars():
55495553
args.sdvae = ""
55505554
if sd_vae_var.get() != "":
55515555
args.sdvae = sd_vae_var.get()
5556+
# Always write the flag from the checkbox state, so that unchecking the box clears a value that was enabled earlier instead of leaving it stuck at True.
args.sdvaeconvdir = (sd_vae_convdir_var.get() == 1)
55525558
if sd_t5xxl_var.get() != "":
55535559
args.sdt5xxl = sd_t5xxl_var.get()
55545560
if sd_clipl_var.get() != "":
@@ -5766,6 +5772,7 @@ def import_vars(dict):
57665772
sd_photomaker_var.set(dict["sdphotomaker"] if ("sdphotomaker" in dict and dict["sdphotomaker"]) else "")
57675773
sd_vaeauto_var.set(1 if ("sdvaeauto" in dict and dict["sdvaeauto"]) else 0)
57685774
sd_tiled_vae_var.set(str(dict["sdtiledvae"]) if ("sdtiledvae" in dict and dict["sdtiledvae"]) else str(default_vae_tile_threshold))
5775+
sd_vae_convdir_var.set(1 if ("sdvaeconvdir" in dict and dict["sdvaeconvdir"]) else 0)
57695776

57705777
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
57715778
sd_loramult_var.set(str(dict["sdloramult"]) if ("sdloramult" in dict and dict["sdloramult"]) else "1.0")
@@ -7564,6 +7571,7 @@ def range_checker(arg: str):
75647571
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
75657572
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
75667573
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')
7574+
# Registered on the general image-generation group, not the mutually exclusive VAE group, so that it can be combined with --sdvae.
sdparsergroup.add_argument("--sdvaeconvdir", help="Enables Conv2D Direct for the VAE of the image generation model. Should improve performance and reduce memory usage. Might crash if not supported by the backend.", action='store_true')
75677575
sdparsergrouplora = sdparsergroup.add_mutually_exclusive_group()
75687576
sdparsergrouplora.add_argument("--sdquant", help="If specified, loads the model quantized to save memory.", action='store_true')
75697577
sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify an image generation LORA safetensors model to be applied.", default="")

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ struct SDParams {
9999
bool clip_on_cpu = false;
100100
bool vae_on_cpu = false;
101101
bool diffusion_flash_attn = false;
102+
bool vae_conv_direct = false;
102103
bool canny_preprocess = false;
103104
bool color = false;
104105
int upscale_repeats = 1;
@@ -211,6 +212,10 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
211212
{
212213
printf("Flash Attention is enabled\n");
213214
}
215+
if(inputs.vae_conv_direct)
216+
{
217+
printf("Conv2D Direct for VAE model is enabled\n");
218+
}
214219
if(inputs.quant)
215220
{
216221
printf("Note: Loading a pre-quantized model is always faster than using compress weights!\n");
@@ -246,6 +251,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
246251
sd_params->wtype = (inputs.quant==0?SD_TYPE_COUNT:SD_TYPE_Q4_0);
247252
sd_params->n_threads = inputs.threads; //if -1 use physical cores
248253
sd_params->diffusion_flash_attn = inputs.flash_attention;
254+
sd_params->vae_conv_direct = inputs.vae_conv_direct;
249255
sd_params->input_path = ""; //unused
250256
sd_params->batch_count = 1;
251257
sd_params->vae_path = vaefilename;
@@ -312,6 +318,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
312318
params.keep_control_net_on_cpu = sd_params->control_net_cpu;
313319
params.keep_vae_on_cpu = sd_params->vae_on_cpu;
314320
params.diffusion_flash_attn = sd_params->diffusion_flash_attn;
321+
params.vae_conv_direct = sd_params->vae_conv_direct;
315322
params.chroma_use_dit_mask = sd_params->chroma_use_dit_mask;
316323
params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask;
317324
params.chroma_t5_mask_pad = sd_params->chroma_t5_mask_pad;

0 commit comments

Comments
 (0)