Skip to content

Commit 941e502

Browse files
committed
Add config for Conv2D Direct for the diffusion model
1 parent c70b0df commit 941e502

File tree

3 files changed

+16
-0
lines changed

3 files changed

+16
-0
lines changed

expose.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ struct sd_load_model_inputs
165165
const int threads = 0;
166166
const int quant = 0;
167167
const bool flash_attention = false;
168+
const bool diffusion_conv_direct = false;
168169
const bool vae_conv_direct = false;
169170
const bool taesd = false;
170171
const int tiled_vae_threshold = 0;

koboldcpp.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ class sd_load_model_inputs(ctypes.Structure):
278278
("threads", ctypes.c_int),
279279
("quant", ctypes.c_int),
280280
("flash_attention", ctypes.c_bool),
281+
("diffusion_conv_direct", ctypes.c_bool),
281282
("vae_conv_direct", ctypes.c_bool),
282283
("taesd", ctypes.c_bool),
283284
("tiled_vae_threshold", ctypes.c_int),
@@ -1647,6 +1648,7 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
16471648
inputs.threads = thds
16481649
inputs.quant = quant
16491650
inputs.flash_attention = args.sdflashattention
1651+
inputs.diffusion_conv_direct = args.sddiffusionconvdir
16501652
inputs.vae_conv_direct = args.sdvaeconvdir
16511653
inputs.taesd = True if args.sdvaeauto else False
16521654
inputs.tiled_vae_threshold = args.sdtiledvae
@@ -4559,6 +4561,7 @@ def hide_tooltip(event):
45594561
sd_clipg_var = ctk.StringVar()
45604562
sd_photomaker_var = ctk.StringVar()
45614563
sd_flash_attention_var = ctk.IntVar(value=0)
4564+
sd_diffusion_convdir_var = ctk.IntVar(value=0)
45624565
sd_vaeauto_var = ctk.IntVar(value=0)
45634566
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
45644567
sd_vae_convdir_var = ctk.IntVar(value=0)
@@ -5309,6 +5312,7 @@ def toggletaesd(a,b,c):
53095312
makecheckbox(images_tab, "Conv2D Direct for VAE", sd_vae_convdir_var, 42, padx=220, tooltiptxt="Enable Conv2D Direct for VAE. Saves memory and improves performance.\nMight crash if not supported by the backend.")
53105313
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
53115314
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 48, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
5315+
makecheckbox(images_tab, "Conv2D Direct for Diffusion", sd_diffusion_convdir_var, 48, padx=220, tooltiptxt="Enable Conv2D Direct for diffusion. May save memory or improve performance.\nMight crash if not supported by the backend.")
53125316

53135317
# audio tab
53145318
audio_tab = tabcontent["Audio"]
@@ -5541,6 +5545,8 @@ def export_vars():
55415545

55425546
if sd_flash_attention_var.get()==1:
55435547
args.sdflashattention = True
5548+
if sd_diffusion_convdir_var.get()==1:
5549+
args.sddiffusionconvdir = True
55445550
args.sdthreads = (0 if sd_threads_var.get()=="" else int(sd_threads_var.get()))
55455551
args.sdclamped = (0 if int(sd_clamped_var.get())<=0 else int(sd_clamped_var.get()))
55465552
args.sdclampedsoft = (0 if int(sd_clamped_soft_var.get())<=0 else int(sd_clamped_soft_var.get()))
@@ -5765,6 +5771,7 @@ def import_vars(dict):
57655771
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
57665772
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
57675773
sd_flash_attention_var.set(1 if ("sdflashattention" in dict and dict["sdflashattention"]) else 0)
5774+
sd_diffusion_convdir_var.set(1 if ("sddiffusionconvdir" in dict and dict["sddiffusionconvdir"]) else 0)
57685775
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
57695776
sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
57705777
sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
@@ -7568,6 +7575,7 @@ def range_checker(arg: str):
75687575
sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
75697576
sdparsergroup.add_argument("--sdphotomaker", metavar=('[filename]'), help="PhotoMaker is a model that allows face cloning. Specify a PhotoMaker safetensors model which will be applied replacing img2img. SDXL models only. Leave blank if unused.", default="")
75707577
sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true')
7578+
sdparsergroup.add_argument("--sddiffusionconvdir", help="Enables Conv2D Direct for the image diffusion model. May improve performance or reduce memory usage. Might crash if not supported by the backend.", action='store_true')
75717579
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
75727580
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
75737581
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ struct SDParams {
9999
bool clip_on_cpu = false;
100100
bool vae_on_cpu = false;
101101
bool diffusion_flash_attn = false;
102+
bool diffusion_conv_direct = false;
102103
bool vae_conv_direct = false;
103104
bool canny_preprocess = false;
104105
bool color = false;
@@ -212,6 +213,10 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
212213
{
213214
printf("Flash Attention is enabled\n");
214215
}
216+
if(inputs.diffusion_conv_direct)
217+
{
218+
printf("Conv2D Direct for diffusion model is enabled\n");
219+
}
215220
if(inputs.vae_conv_direct)
216221
{
217222
printf("Conv2D Direct for VAE model is enabled\n");
@@ -251,6 +256,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
251256
sd_params->wtype = (inputs.quant==0?SD_TYPE_COUNT:SD_TYPE_Q4_0);
252257
sd_params->n_threads = inputs.threads; //if -1 use physical cores
253258
sd_params->diffusion_flash_attn = inputs.flash_attention;
259+
sd_params->diffusion_conv_direct = inputs.diffusion_conv_direct;
254260
sd_params->vae_conv_direct = inputs.vae_conv_direct;
255261
sd_params->input_path = ""; //unused
256262
sd_params->batch_count = 1;
@@ -322,6 +328,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
322328
params.keep_control_net_on_cpu = sd_params->control_net_cpu;
323329
params.keep_vae_on_cpu = sd_params->vae_on_cpu;
324330
params.diffusion_flash_attn = sd_params->diffusion_flash_attn;
331+
params.diffusion_conv_direct = sd_params->diffusion_conv_direct;
325332
params.vae_conv_direct = sd_params->vae_conv_direct;
326333
params.chroma_use_dit_mask = sd_params->chroma_use_dit_mask;
327334
params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask;

0 commit comments

Comments
 (0)