Skip to content

Commit c48999f

Browse files
authored
additional options for image generation (#1765)
* sd: add backend support for choosing the default sampler * use the default sampler on the API * sd: add backend support for the scheduler * sd: add backend support for distilled guidance * sd: add backend support for timestep-shift * sd: add a config field to set default image gen options
1 parent 75272f6 commit c48999f

File tree

3 files changed

+131
-15
lines changed

3 files changed

+131
-15
lines changed

expose.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,14 @@ struct sd_generation_inputs
197197
const bool flip_mask = false;
198198
const float denoising_strength = 0.0f;
199199
const float cfg_scale = 0.0f;
200+
const float distilled_guidance = -1.0f;
201+
const int shifted_timestep = 0;
200202
const int sample_steps = 0;
201203
const int width = 0;
202204
const int height = 0;
203205
const int seed = 0;
204206
const char * sample_method = nullptr;
207+
const char * scheduler = nullptr;
205208
const int clip_skip = -1;
206209
const int vid_req_frames = 1;
207210
const int vid_req_avi = 0;

koboldcpp.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,11 +311,14 @@ class sd_generation_inputs(ctypes.Structure):
311311
("flip_mask", ctypes.c_bool),
312312
("denoising_strength", ctypes.c_float),
313313
("cfg_scale", ctypes.c_float),
314+
("distilled_guidance", ctypes.c_float),
315+
("shifted_timestep", ctypes.c_int),
314316
("sample_steps", ctypes.c_int),
315317
("width", ctypes.c_int),
316318
("height", ctypes.c_int),
317319
("seed", ctypes.c_int),
318320
("sample_method", ctypes.c_char_p),
321+
("scheduler", ctypes.c_char_p),
319322
("clip_skip", ctypes.c_int),
320323
("vid_req_frames", ctypes.c_int),
321324
("vid_req_avi", ctypes.c_int)]
@@ -393,6 +396,8 @@ class embeddings_generation_outputs(ctypes.Structure):
393396
("count", ctypes.c_int),
394397
("data", ctypes.c_char_p)]
395398

399+
400+
396401
def getdirpath():
397402
return os.path.dirname(os.path.realpath(__file__))
398403
def getabspath():
@@ -1788,9 +1793,58 @@ def sd_comfyui_tranform_params(genparams):
17881793
print("Warning: ComfyUI Payload Missing!")
17891794
return genparams
17901795

1796+
def sd_process_meta_fields(fields, config):
    """Normalize image-generation meta fields and keep only the accepted ones.

    fields: iterable of (key, value) pairs.
    config: when truthy, 'sampler_name' and 'cfg_scale' are accepted in
        addition to the fields that are always allowed.

    Returns a dict keyed by the canonical parameter names. When the same
    canonical name appears more than once, the last value wins.
    """
    # alternative spellings, matching the sd.cpp command-line options
    canonical_names = {
        'cfg-scale': 'cfg_scale',
        'guidance': 'distilled_guidance',
        'sampler': 'sampler_name',
        'sampling-method': 'sampler_name',
        'timestep-shift': 'shifted_timestep',
    }

    # parameters that are always accepted
    accepted = ['scheduler', 'shifted_timestep', 'distilled_guidance']
    if config:
        # note: the current UI always sets these
        accepted.extend(['sampler_name', 'cfg_scale'])

    result = {}
    for key, value in fields:
        name = canonical_names.get(key, key)
        if name in accepted:
            result[name] = value
    return result
1813+
1814+
# json with top-level dict
1815+
def sd_parse_meta_field(prompt, config=False):
    """Parse a JSON meta prompt into a dict of accepted generation fields.

    prompt: text holding either a JSON object or a bare list of
        "field": value pairs without the surrounding braces. An empty
        string (or None) yields no fields.
    config: forwarded to sd_process_meta_fields to select which fields
        are accepted.

    Returns the dict produced by sd_process_meta_fields. Input that cannot
    be interpreted as a JSON object yields that function's result for an
    empty set of fields, after printing a warning.
    """
    if prompt is None:
        # treat a missing prompt like an empty one instead of failing on
        # the string concatenation below
        prompt = ''

    parsed = None
    if isinstance(prompt, str):
        try:
            parsed = json.loads(prompt)
        except json.JSONDecodeError:
            # accept "field":"value",... without {} (also empty strings)
            try:
                parsed = json.loads('{ ' + prompt + ' }')
            except json.JSONDecodeError:
                print("Warning: couldn't parse meta prompt; it should be valid JSON.")
                parsed = {}

    if not isinstance(parsed, dict):
        # valid JSON that is not an object (list, number, string...) or a
        # non-string prompt: report it rather than dropping it silently
        print("Warning: meta prompt ignored; it should be a JSON object.")
        parsed = {}

    return sd_process_meta_fields(parsed.items(), config)
1829+
1830+
17911831
def sd_generate(genparams):
17921832
global maxctx, args, currentusergenkey, totalgens, pendingabortkey, chatcompl_adapter
17931833

1834+
sdgendefaults = sd_parse_meta_field(args.sdgendefaults or '', config=True)
1835+
params = dict()
1836+
defparams = dict()
1837+
for k, v in sdgendefaults.items():
1838+
if k in ['sampler_name', 'scheduler']:
1839+
# these can be explicitly set to 'default'; process later
1840+
# TODO should we consider values like 'clip_skip=-1' as 'default' too?
1841+
defparams[k] = v
1842+
else:
1843+
params[k] = v
1844+
# apply most of the defaults
1845+
params.update(genparams)
1846+
genparams = params
1847+
17941848
default_adapter = {} if chatcompl_adapter is None else chatcompl_adapter
17951849
adapter_obj = genparams.get('adapter', default_adapter)
17961850
forced_negprompt = adapter_obj.get("add_sd_negative_prompt", "")
@@ -1816,13 +1870,20 @@ def sd_generate(genparams):
18161870
flip_mask = genparams.get("inpainting_mask_invert", 0)
18171871
denoising_strength = tryparsefloat(genparams.get("denoising_strength", 0.6),0.6)
18181872
cfg_scale = tryparsefloat(genparams.get("cfg_scale", 5),5)
1873+
distilled_guidance = tryparsefloat(genparams.get("distilled_guidance", None), None)
1874+
shifted_timestep = tryparseint(genparams.get("shifted_timestep", None), None)
18191875
sample_steps = tryparseint(genparams.get("steps", 20),20)
18201876
width = tryparseint(genparams.get("width", 512),512)
18211877
height = tryparseint(genparams.get("height", 512),512)
18221878
seed = tryparseint(genparams.get("seed", -1),-1)
18231879
if seed < 0:
18241880
seed = random.randint(100000, 999999)
1825-
sample_method = genparams.get("sampler_name", "k_euler_a")
1881+
sample_method = (genparams.get("sampler_name") or "default").lower()
1882+
if sample_method == 'default' and 'sampler_name' in defparams:
1883+
sample_method = (defparams.get("sampler_name") or "default").lower()
1884+
scheduler = (genparams.get("scheduler") or "default").lower()
1885+
if scheduler == 'default' and 'scheduler' in defparams:
1886+
scheduler = (defparams.get("scheduler") or "default").lower()
18261887
clip_skip = tryparseint(genparams.get("clip_skip", -1),-1)
18271888
vid_req_frames = tryparseint(genparams.get("frames", 1),1)
18281889
vid_req_frames = 1 if (not vid_req_frames or vid_req_frames < 1) else vid_req_frames
@@ -1834,6 +1895,10 @@ def sd_generate(genparams):
18341895

18351896
#clean vars
18361897
cfg_scale = (1 if cfg_scale < 1 else (25 if cfg_scale > 25 else cfg_scale))
1898+
if distilled_guidance is not None and (distilled_guidance < 0 or distilled_guidance > 100):
1899+
distilled_guidance = None # fall back to the default
1900+
if shifted_timestep is not None and (shifted_timestep < 0 or shifted_timestep > 1000):
1901+
shifted_timestep = None # fall back to the default
18371902
sample_steps = (1 if sample_steps < 1 else (forced_steplimit if sample_steps > forced_steplimit else sample_steps))
18381903
vid_req_frames = (1 if vid_req_frames < 1 else (100 if vid_req_frames > 100 else vid_req_frames))
18391904

@@ -1852,12 +1917,17 @@ def sd_generate(genparams):
18521917
inputs.extra_images[n] = extra_image.encode("UTF-8")
18531918
inputs.flip_mask = flip_mask
18541919
inputs.cfg_scale = cfg_scale
1920+
if distilled_guidance is not None:
1921+
inputs.distilled_guidance = distilled_guidance
18551922
inputs.denoising_strength = denoising_strength
1923+
if shifted_timestep is not None:
1924+
inputs.shifted_timestep = shifted_timestep
18561925
inputs.sample_steps = sample_steps
18571926
inputs.width = width
18581927
inputs.height = height
18591928
inputs.seed = seed
1860-
inputs.sample_method = sample_method.lower().encode("UTF-8")
1929+
inputs.sample_method = sample_method.encode("UTF-8")
1930+
inputs.scheduler = scheduler.encode("UTF-8")
18611931
inputs.clip_skip = clip_skip
18621932
inputs.vid_req_frames = vid_req_frames
18631933
inputs.vid_req_avi = vid_req_avi
@@ -4675,6 +4745,7 @@ def hide_tooltip(event):
46754745
sd_clamped_soft_var = ctk.StringVar(value="0")
46764746
sd_threads_var = ctk.StringVar(value=str(default_threads))
46774747
sd_quant_var = ctk.StringVar(value=sd_quant_choices[0])
4748+
sd_gen_defaults_var = ctk.StringVar()
46784749

46794750
whisper_model_var = ctk.StringVar()
46804751
tts_model_var = ctk.StringVar()
@@ -5451,6 +5522,7 @@ def toggletaesd(a,b,c):
54515522
makecheckbox(images_tab, "Model CPU Offload", sd_offload_cpu_var, 50,padx=8, tooltiptxt="Offload image weights in RAM to save VRAM, swap into VRAM when needed.")
54525523
makecheckbox(images_tab, "VAE on CPU", sd_vae_cpu_var, 50,padx=160, tooltiptxt="Force VAE to CPU only for image generation.")
54535524
makecheckbox(images_tab, "CLIP on GPU", sd_clip_gpu_var, 50,padx=280, tooltiptxt="Put CLIP and T5 to GPU for image generation. Otherwise, CLIP will use CPU.")
5525+
makelabelentry(images_tab, "Default Params:", sd_gen_defaults_var, 52, 280, padx=110, singleline=True, tooltip='Default image generation parameters when not specified by the UI or API.\nSpecified as JSON fields: {"KEY1":"VALUE1", "KEY2":"VALUE2"...}')
54545526

54555527
# audio tab
54565528
audio_tab = tabcontent["Audio"]
@@ -5725,6 +5797,7 @@ def export_vars():
57255797
args.sdloramult = float(sd_loramult_var.get())
57265798
else:
57275799
args.sdlora = ""
5800+
args.sdgendefaults = sd_gen_defaults_var.get()
57285801

57295802
if whisper_model_var.get() != "":
57305803
args.whispermodel = whisper_model_var.get()
@@ -5951,6 +6024,7 @@ def import_vars(dict):
59516024

59526025
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
59536026
sd_loramult_var.set(str(dict["sdloramult"]) if ("sdloramult" in dict and dict["sdloramult"]) else "1.0")
6027+
sd_gen_defaults_var.set(dict.get("sdgendefaults", ""))
59546028

59556029
whisper_model_var.set(dict["whispermodel"] if ("whispermodel" in dict and dict["whispermodel"]) else "")
59566030

@@ -7787,6 +7861,7 @@ def range_checker(arg: str):
77877861
sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify an image generation LORA safetensors model to be applied.", default="")
77887862
sdparsergroup.add_argument("--sdloramult", metavar=('[amount]'), help="Multiplier for the image LORA model to be applied.", type=float, default=1.0)
77897863
sdparsergroup.add_argument("--sdtiledvae", metavar=('[maxres]'), help="Adjust the automatic VAE tiling trigger for images above this size. 0 disables vae tiling.", type=int, default=default_vae_tile_threshold)
7864+
sdparsergroup.add_argument("--sdgendefaults", metavar=('{"parameter":"value",...}'), help="Sets default parameters for image generation, as a JSON string.", default="")
77907865
whisperparsergroup = parser.add_argument_group('Whisper Transcription Commands')
77917866
whisperparsergroup.add_argument("--whispermodel", metavar=('[filename]'), help="Specify a Whisper .bin model to enable Speech-To-Text transcription.", default="")
77927867

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,11 @@ struct SDParams {
6767
int width = 512;
6868
int height = 512;
6969

70-
sample_method_t sample_method = EULER_A;
70+
sample_method_t sample_method = SAMPLE_METHOD_DEFAULT;
71+
scheduler_t scheduler = scheduler_t::DEFAULT;
7172
int sample_steps = 20;
73+
float distilled_guidance = -1.0f;
74+
float shifted_timestep = 0;
7275
float strength = 0.75f;
7376
int64_t seed = 42;
7477
bool clip_on_cpu = false;
@@ -404,20 +407,24 @@ std::string clean_input_prompt(const std::string& input) {
404407
}
405408

406409
static std::string get_image_params(const sd_img_gen_params_t & params) {
407-
std::stringstream parameter_string;
408-
parameter_string << std::setprecision(3)
410+
std::stringstream ss;
411+
ss << std::setprecision(3)
409412
<< "Prompt: " << params.prompt
410413
<< " | NegativePrompt: " << params.negative_prompt
411414
<< " | Steps: " << params.sample_params.sample_steps
412415
<< " | CFGScale: " << params.sample_params.guidance.txt_cfg
413416
<< " | Guidance: " << params.sample_params.guidance.distilled_guidance
414417
<< " | Seed: " << params.seed
415418
<< " | Size: " << params.width << "x" << params.height
416-
<< " | Sampler: " << sd_sample_method_name(params.sample_params.sample_method)
417-
<< " | Clip skip: " << params.clip_skip
419+
<< " | Sampler: " << sd_sample_method_name(params.sample_params.sample_method);
420+
if (params.sample_params.scheduler != scheduler_t::DEFAULT)
421+
ss << " " << sd_schedule_name(params.sample_params.scheduler);
422+
if (params.sample_params.shifted_timestep != 0)
423+
ss << "| Timestep Shift: " << params.sample_params.shifted_timestep;
424+
ss << " | Clip skip: " << params.clip_skip
418425
<< " | Model: " << sdmodelfilename
419426
<< " | Version: KoboldCpp";
420-
return parameter_string.str();
427+
return ss.str();
421428
}
422429

423430
static inline int rounddown_64(int n) {
@@ -519,23 +526,29 @@ static void sd_fix_resolution(int &width, int &height, int img_hard_limit, int i
519526

520527
static enum sample_method_t sampler_from_name(const std::string& sampler)
521528
{
522-
if(sampler=="euler a"||sampler=="k_euler_a"||sampler=="euler_a") //all lowercase
529+
// all lowercase
530+
enum sample_method_t result = str_to_sample_method(sampler.c_str());
531+
if (result != sample_method_t::SAMPLE_METHOD_COUNT)
532+
{
533+
return result;
534+
}
535+
else if(sampler=="euler a"||sampler=="k_euler_a")
523536
{
524537
return sample_method_t::EULER_A;
525538
}
526-
else if(sampler=="euler"||sampler=="k_euler")
539+
else if(sampler=="k_euler")
527540
{
528541
return sample_method_t::EULER;
529542
}
530-
else if(sampler=="heun"||sampler=="k_heun")
543+
else if(sampler=="k_heun")
531544
{
532545
return sample_method_t::HEUN;
533546
}
534-
else if(sampler=="dpm2"||sampler=="k_dpm_2")
547+
else if(sampler=="k_dpm_2")
535548
{
536549
return sample_method_t::DPM2;
537550
}
538-
else if(sampler=="lcm"||sampler=="k_lcm")
551+
else if(sampler=="k_lcm")
539552
{
540553
return sample_method_t::LCM;
541554
}
@@ -549,11 +562,10 @@ static enum sample_method_t sampler_from_name(const std::string& sampler)
549562
}
550563
else
551564
{
552-
return sample_method_t::EULER_A;
565+
return sample_method_t::SAMPLE_METHOD_DEFAULT;
553566
}
554567
}
555568

556-
557569
uint8_t* load_image_from_b64(const std::string & b64str, int& width, int& height, int expected_width = 0, int expected_height = 0, int expected_channel = 3)
558570
{
559571
std::vector<uint8_t> decoded_buf = kcpp_base64_decode(b64str);
@@ -644,6 +656,19 @@ uint8_t* load_image_from_b64(const std::string & b64str, int& width, int& height
644656
image_buffer = resized_image_buffer;
645657
}
646658
return image_buffer;
659+
660+
}
661+
662+
static enum scheduler_t scheduler_from_name(const char * scheduler)
663+
{
664+
if (scheduler) {
665+
enum scheduler_t result = str_to_schedule(scheduler);
666+
if (result != scheduler_t::SCHEDULE_COUNT)
667+
{
668+
return result;
669+
}
670+
}
671+
return scheduler_t::DEFAULT;
647672
}
648673

649674
sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
@@ -674,13 +699,20 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
674699
sd_params->prompt = cleanprompt;
675700
sd_params->negative_prompt = cleannegprompt;
676701
sd_params->cfg_scale = inputs.cfg_scale;
702+
sd_params->distilled_guidance = inputs.distilled_guidance;
677703
sd_params->sample_steps = inputs.sample_steps;
704+
sd_params->shifted_timestep = inputs.shifted_timestep;
678705
sd_params->seed = inputs.seed;
679706
sd_params->width = inputs.width;
680707
sd_params->height = inputs.height;
681708
sd_params->strength = inputs.denoising_strength;
682709
sd_params->clip_skip = inputs.clip_skip;
683710
sd_params->sample_method = sampler_from_name(inputs.sample_method);
711+
sd_params->scheduler = scheduler_from_name(inputs.scheduler);
712+
713+
if (sd_params->sample_method == SAMPLE_METHOD_DEFAULT) {
714+
sd_params->sample_method = sd_get_default_sample_method(sd_ctx);
715+
}
684716

685717
auto loadedsdver = get_loaded_sd_version(sd_ctx);
686718
bool is_img2img = img2img_data != "";
@@ -841,10 +873,15 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
841873
params.clip_skip = sd_params->clip_skip;
842874
params.sample_params.guidance.txt_cfg = sd_params->cfg_scale;
843875
params.sample_params.guidance.img_cfg = sd_params->cfg_scale;
876+
if (sd_params->distilled_guidance >= 0.f) {
877+
params.sample_params.guidance.distilled_guidance = sd_params->distilled_guidance;
878+
}
844879
params.width = sd_params->width;
845880
params.height = sd_params->height;
846881
params.sample_params.sample_method = sd_params->sample_method;
882+
params.sample_params.scheduler = sd_params->scheduler;
847883
params.sample_params.sample_steps = sd_params->sample_steps;
884+
params.sample_params.shifted_timestep = sd_params->shifted_timestep;
848885
params.seed = sd_params->seed;
849886
params.strength = sd_params->strength;
850887
params.vae_tiling_params.enabled = dotile;
@@ -922,6 +959,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
922959
<< "\nCFGSCLE:" << params.sample_params.guidance.txt_cfg
923960
<< "\nSIZE:" << params.width << "x" << params.height
924961
<< "\nSM:" << sd_sample_method_name(params.sample_params.sample_method)
962+
<< "\nSCHED:" << sd_schedule_name(params.sample_params.scheduler)
925963
<< "\nSTEP:" << params.sample_params.sample_steps
926964
<< "\nSEED:" << params.seed
927965
<< "\nBATCH:" << params.batch_count

0 commit comments

Comments
 (0)