Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions expose.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ struct sd_load_model_inputs
const int threads = 0;
const int quant = 0;
const bool flash_attention = false;
const bool diffusion_conv_direct = false;
const bool vae_conv_direct = false;
const bool taesd = false;
const int tiled_vae_threshold = 0;
const char * t5xxl_filename = nullptr;
Expand Down
45 changes: 43 additions & 2 deletions koboldcpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ class sd_load_model_inputs(ctypes.Structure):
("threads", ctypes.c_int),
("quant", ctypes.c_int),
("flash_attention", ctypes.c_bool),
("diffusion_conv_direct", ctypes.c_bool),
("vae_conv_direct", ctypes.c_bool),
("taesd", ctypes.c_bool),
("tiled_vae_threshold", ctypes.c_int),
("t5xxl_filename", ctypes.c_char_p),
Expand Down Expand Up @@ -1637,6 +1639,19 @@ def generate(genparams, stream_flag=False):
outstr = outstr[:sindex]
return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}

sd_convdirect_choices = ['off', 'vaeonly', 'full']

def sd_convdirect_option(value):
if not value:
value = ''
value = value.lower()
if value in ['disabled', 'disable', 'none', 'off', '0', '']:
return 'off'
elif value in ['vae', 'vaeonly']:
return 'vaeonly'
elif value in ['enabled', 'enable', 'on', 'full']:
return 'full'
raise argparse.ArgumentTypeError(f"Invalid sdconvdirect option \"{value}\". Must be one of {sd_convdirect_choices}.")

def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl_filename,clipg_filename,photomaker_filename):
global args
Expand All @@ -1654,7 +1669,10 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl

inputs.threads = thds
inputs.quant = quant
inputs.flash_attention = args.flashattention
inputs.flash_attention = args.sdflashattention
sdconvdirect = sd_convdirect_option(args.sdconvdirect)
inputs.diffusion_conv_direct = sdconvdirect == 'full'
inputs.vae_conv_direct = sdconvdirect in ['vaeonly', 'full']
inputs.taesd = True if args.sdvaeauto else False
inputs.tiled_vae_threshold = args.sdtiledvae
inputs.vae_filename = vae_filename.encode("UTF-8")
Expand Down Expand Up @@ -4568,8 +4586,10 @@ def hide_tooltip(event):
sd_clipl_var = ctk.StringVar()
sd_clipg_var = ctk.StringVar()
sd_photomaker_var = ctk.StringVar()
sd_flash_attention_var = ctk.IntVar(value=0)
sd_vaeauto_var = ctk.IntVar(value=0)
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
sd_convdirect_var = ctk.StringVar(value='disabled')
sd_clamped_var = ctk.StringVar(value="0")
sd_clamped_soft_var = ctk.StringVar(value="0")
sd_threads_var = ctk.StringVar(value=str(default_threads))
Expand Down Expand Up @@ -4634,6 +4654,18 @@ def makecheckbox(parent, text, variable=None, row=0, column=0, command=None, pad
temp.bind("<Leave>", hide_tooltip)
return temp

def makelabelcombobox(parent, text, variable=None, row=0, width=50, command=None, padx=8,tooltiptxt="", values=[], labelpadx=8):
label = makelabel(parent, text, row, 0, tooltiptxt, padx=labelpadx)
label=None
combo = ctk.CTkComboBox(parent, variable=variable, width=width, values=values, state="readonly")
if command is not None and variable is not None:
variable.trace_add("write", command)
combo.grid(row=row,column=0, padx=padx, sticky="nw")
if tooltiptxt!="":
combo.bind("<Enter>", lambda event: show_tooltip(event, tooltiptxt))
combo.bind("<Leave>", hide_tooltip)
return combo, label

def makelabel(parent, text, row, column=0, tooltiptxt="", columnspan=1, padx=8):
temp = ctk.CTkLabel(parent, text=text)
temp.grid(row=row, column=column, padx=padx, pady=1, stick="nw", columnspan=columnspan)
Expand Down Expand Up @@ -5328,8 +5360,10 @@ def toggletaesd(a,b,c):
sdvaeitem1.grid()
sdvaeitem2.grid()
sdvaeitem3.grid()
makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
makecheckbox(images_tab, "TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
makelabelcombobox(images_tab, "Conv2D Direct:", sd_convdirect_var, row=42, labelpadx=220, padx=310, width=90, tooltiptxt="Use Conv2D Direct operation. May save memory or improve performance.\nMight crash if not supported by the backend.\n", values=sd_convdirect_choices)
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 46, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")

# audio tab
audio_tab = tabcontent["Audio"]
Expand Down Expand Up @@ -5566,6 +5600,8 @@ def export_vars():
if sd_model_var.get() != "":
args.sdmodel = sd_model_var.get()

if sd_flash_attention_var.get()==1:
args.sdflashattention = True
args.sdthreads = (0 if sd_threads_var.get()=="" else int(sd_threads_var.get()))
args.sdclamped = (0 if int(sd_clamped_var.get())<=0 else int(sd_clamped_var.get()))
args.sdclampedsoft = (0 if int(sd_clamped_soft_var.get())<=0 else int(sd_clamped_soft_var.get()))
Expand All @@ -5578,6 +5614,7 @@ def export_vars():
args.sdvae = ""
if sd_vae_var.get() != "":
args.sdvae = sd_vae_var.get()
args.sdconvdirect = sd_convdirect_option(sd_convdirect_var.get())
if sd_t5xxl_var.get() != "":
args.sdt5xxl = sd_t5xxl_var.get()
if sd_clipl_var.get() != "":
Expand Down Expand Up @@ -5798,6 +5835,8 @@ def import_vars(dict):
sd_clamped_soft_var.set(int(dict["sdclampedsoft"]) if ("sdclampedsoft" in dict and dict["sdclampedsoft"]) else 0)
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
sd_flash_attention_var.set(1 if ("sdflashattention" in dict and dict["sdflashattention"]) else 0)
sd_convdirect_var.set(sd_convdirect_option(dict.get("sdconvdirect")))
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
Expand Down Expand Up @@ -7600,6 +7639,8 @@ def range_checker(arg: str):
sdparsergroup.add_argument("--sdclipl", metavar=('[filename]'), help="Specify a Clip-L safetensors model for use in SD3 or Flux. Leave blank if prebaked or unused.", default="")
sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
sdparsergroup.add_argument("--sdphotomaker", metavar=('[filename]'), help="PhotoMaker is a model that allows face cloning. Specify a PhotoMaker safetensors model which will be applied replacing img2img. SDXL models only. Leave blank if unused.", default="")
sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true')
sdparsergroup.add_argument("--sdconvdirect", help="Enables Conv2D Direct. May improve performance or reduce memory usage. Might crash if not supported by the backend. Can be 'off' (default) to disable, 'full' to turn it on for all operations, or 'vaeonly' to enable only for the VAE.", type=sd_convdirect_option, choices=sd_convdirect_choices, default=sd_convdirect_choices[0])
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')
Expand Down
14 changes: 14 additions & 0 deletions otherarch/sdcpp/sdtype_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ struct SDParams {
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool diffusion_flash_attn = false;
bool diffusion_conv_direct = false;
bool vae_conv_direct = false;
bool canny_preprocess = false;
bool color = false;
int upscale_repeats = 1;
Expand Down Expand Up @@ -211,6 +213,14 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
{
printf("Flash Attention is enabled\n");
}
if(inputs.diffusion_conv_direct)
{
printf("Conv2D Direct for diffusion model is enabled\n");
}
if(inputs.vae_conv_direct)
{
printf("Conv2D Direct for VAE model is enabled\n");
}
if(inputs.quant)
{
printf("Note: Loading a pre-quantized model is always faster than using compress weights!\n");
Expand Down Expand Up @@ -246,6 +256,8 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
sd_params->wtype = (inputs.quant==0?SD_TYPE_COUNT:SD_TYPE_Q4_0);
sd_params->n_threads = inputs.threads; //if -1 use physical cores
sd_params->diffusion_flash_attn = inputs.flash_attention;
sd_params->diffusion_conv_direct = inputs.diffusion_conv_direct;
sd_params->vae_conv_direct = inputs.vae_conv_direct;
sd_params->input_path = ""; //unused
sd_params->batch_count = 1;
sd_params->vae_path = vaefilename;
Expand Down Expand Up @@ -316,6 +328,8 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
params.keep_control_net_on_cpu = sd_params->control_net_cpu;
params.keep_vae_on_cpu = sd_params->vae_on_cpu;
params.diffusion_flash_attn = sd_params->diffusion_flash_attn;
params.diffusion_conv_direct = sd_params->diffusion_conv_direct;
params.vae_conv_direct = sd_params->vae_conv_direct;
params.chroma_use_dit_mask = sd_params->chroma_use_dit_mask;
params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask;
params.chroma_t5_mask_pad = sd_params->chroma_t5_mask_pad;
Expand Down