Skip to content

Commit 62043df

Browse files
committed
Add config for Conv2D Direct for the diffusion model
1 parent e109929 commit 62043df

File tree

3 files changed

+16
-0
lines changed

3 files changed

+16
-0
lines changed

expose.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ struct sd_load_model_inputs
166166
const int threads = 0;
167167
const int quant = 0;
168168
const bool flash_attention = false;
169+
const bool diffusion_conv_direct = false;
169170
const bool vae_conv_direct = false;
170171
const bool taesd = false;
171172
const int tiled_vae_threshold = 0;

koboldcpp.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ class sd_load_model_inputs(ctypes.Structure):
280280
("threads", ctypes.c_int),
281281
("quant", ctypes.c_int),
282282
("flash_attention", ctypes.c_bool),
283+
("diffusion_conv_direct", ctypes.c_bool),
283284
("vae_conv_direct", ctypes.c_bool),
284285
("taesd", ctypes.c_bool),
285286
("tiled_vae_threshold", ctypes.c_int),
@@ -1655,6 +1656,7 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
16551656
inputs.threads = thds
16561657
inputs.quant = quant
16571658
inputs.flash_attention = args.sdflashattention
1659+
inputs.diffusion_conv_direct = args.sddiffusionconvdir
16581660
inputs.vae_conv_direct = args.sdvaeconvdir
16591661
inputs.taesd = True if args.sdvaeauto else False
16601662
inputs.tiled_vae_threshold = args.sdtiledvae
@@ -4569,6 +4571,7 @@ def hide_tooltip(event):
45694571
sd_clipg_var = ctk.StringVar()
45704572
sd_photomaker_var = ctk.StringVar()
45714573
sd_flash_attention_var = ctk.IntVar(value=0)
4574+
sd_diffusion_convdir_var = ctk.IntVar(value=0)
45724575
sd_vaeauto_var = ctk.IntVar(value=0)
45734576
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
45744577
sd_vae_convdir_var = ctk.IntVar(value=0)
@@ -5334,6 +5337,7 @@ def toggletaesd(a,b,c):
53345337
makecheckbox(images_tab, "Conv2D Direct for VAE", sd_vae_convdir_var, 42, padx=220, tooltiptxt="Enable Conv2D Direct for VAE. Saves memory and improves performance.\nMight crash if not supported by the backend.")
53355338
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
53365339
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 48, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
5340+
makecheckbox(images_tab, "Conv2D Direct for Diffusion", sd_diffusion_convdir_var, 48, padx=220, tooltiptxt="Enable Conv2D Direct for diffusion. May save memory or improve performance.\nMight crash if not supported by the backend.")
53375341

53385342
# audio tab
53395343
audio_tab = tabcontent["Audio"]
@@ -5572,6 +5576,8 @@ def export_vars():
55725576

55735577
if sd_flash_attention_var.get()==1:
55745578
args.sdflashattention = True
5579+
if sd_diffusion_convdir_var.get()==1:
5580+
args.sddiffusionconvdir = True
55755581
args.sdthreads = (0 if sd_threads_var.get()=="" else int(sd_threads_var.get()))
55765582
args.sdclamped = (0 if int(sd_clamped_var.get())<=0 else int(sd_clamped_var.get()))
55775583
args.sdclampedsoft = (0 if int(sd_clamped_soft_var.get())<=0 else int(sd_clamped_soft_var.get()))
@@ -5807,6 +5813,7 @@ def import_vars(dict):
58075813
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
58085814
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
58095815
sd_flash_attention_var.set(1 if ("sdflashattention" in dict and dict["sdflashattention"]) else 0)
5816+
sd_diffusion_convdir_var.set(1 if ("sddiffusionconvdir" in dict and dict["sddiffusionconvdir"]) else 0)
58105817
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
58115818
sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
58125819
sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
@@ -7610,6 +7617,7 @@ def range_checker(arg: str):
76107617
sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
76117618
sdparsergroup.add_argument("--sdphotomaker", metavar=('[filename]'), help="PhotoMaker is a model that allows face cloning. Specify a PhotoMaker safetensors model which will be applied replacing img2img. SDXL models only. Leave blank if unused.", default="")
76127619
sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true')
7620+
sdparsergroup.add_argument("--sddiffusionconvdir", help="Enables Conv2D Direct for the image diffusion model. May improve performance or reduce memory usage. Might crash if not supported by the backend.", action='store_true')
76137621
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
76147622
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
76157623
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ struct SDParams {
9999
bool clip_on_cpu = false;
100100
bool vae_on_cpu = false;
101101
bool diffusion_flash_attn = false;
102+
bool diffusion_conv_direct = false;
102103
bool vae_conv_direct = false;
103104
bool canny_preprocess = false;
104105
bool color = false;
@@ -212,6 +213,10 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
212213
{
213214
printf("Flash Attention is enabled\n");
214215
}
216+
if(inputs.diffusion_conv_direct)
217+
{
218+
printf("Conv2D Direct for diffusion model is enabled\n");
219+
}
215220
if(inputs.vae_conv_direct)
216221
{
217222
printf("Conv2D Direct for VAE model is enabled\n");
@@ -251,6 +256,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
251256
sd_params->wtype = (inputs.quant==0?SD_TYPE_COUNT:SD_TYPE_Q4_0);
252257
sd_params->n_threads = inputs.threads; //if -1 use physical cores
253258
sd_params->diffusion_flash_attn = inputs.flash_attention;
259+
sd_params->diffusion_conv_direct = inputs.diffusion_conv_direct;
254260
sd_params->vae_conv_direct = inputs.vae_conv_direct;
255261
sd_params->input_path = ""; //unused
256262
sd_params->batch_count = 1;
@@ -322,6 +328,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
322328
params.keep_control_net_on_cpu = sd_params->control_net_cpu;
323329
params.keep_vae_on_cpu = sd_params->vae_on_cpu;
324330
params.diffusion_flash_attn = sd_params->diffusion_flash_attn;
331+
params.diffusion_conv_direct = sd_params->diffusion_conv_direct;
325332
params.vae_conv_direct = sd_params->vae_conv_direct;
326333
params.chroma_use_dit_mask = sd_params->chroma_use_dit_mask;
327334
params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask;

0 commit comments

Comments
 (0)