Skip to content

Commit fc3b771

Browse files
committed
reorganize sampler selection in sdtype_adapter
1 parent bbc8d8f commit fc3b771

File tree

1 file changed

+39
-37
lines changed

1 file changed

+39
-37
lines changed

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,42 @@ static void sd_fix_resolution(int &width, int &height, int img_hard_limit, int i
500500
}
501501
}
502502

503+
static enum sample_method_t sampler_from_name(const std::string& sampler)
504+
{
505+
if(sampler=="euler a"||sampler=="k_euler_a"||sampler=="euler_a") //all lowercase
506+
{
507+
return sample_method_t::EULER_A;
508+
}
509+
else if(sampler=="euler"||sampler=="k_euler")
510+
{
511+
return sample_method_t::EULER;
512+
}
513+
else if(sampler=="heun"||sampler=="k_heun")
514+
{
515+
return sample_method_t::HEUN;
516+
}
517+
else if(sampler=="dpm2"||sampler=="k_dpm_2")
518+
{
519+
return sample_method_t::DPM2;
520+
}
521+
else if(sampler=="lcm"||sampler=="k_lcm")
522+
{
523+
return sample_method_t::LCM;
524+
}
525+
else if(sampler=="ddim")
526+
{
527+
return sample_method_t::DDIM_TRAILING;
528+
}
529+
else if(sampler=="dpm++ 2m karras" || sampler=="dpm++ 2m" || sampler=="k_dpmpp_2m")
530+
{
531+
return sample_method_t::DPMPP2M;
532+
}
533+
else
534+
{
535+
return sample_method_t::EULER_A;
536+
}
537+
}
538+
503539
sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
504540
{
505541
sd_generation_outputs output;
@@ -525,8 +561,6 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
525561
extra_image_data.push_back(std::string(inputs.extra_images[i]));
526562
}
527563

528-
std::string sampler = inputs.sample_method;
529-
530564
sd_params->prompt = cleanprompt;
531565
sd_params->negative_prompt = cleannegprompt;
532566
sd_params->cfg_scale = inputs.cfg_scale;
@@ -536,6 +570,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
536570
sd_params->height = inputs.height;
537571
sd_params->strength = inputs.denoising_strength;
538572
sd_params->clip_skip = inputs.clip_skip;
573+
sd_params->sample_method = sampler_from_name(inputs.sample_method);
539574
sd_params->mode = (img2img_data==""?SDMode::TXT2IMG:SDMode::IMG2IMG);
540575

541576
auto loadedsdver = get_loaded_sd_version(sd_ctx);
@@ -548,12 +583,12 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
548583
}
549584
sd_params->cfg_scale = 1.0f;
550585
}
551-
if (sampler == "euler a" || sampler == "k_euler_a" || sampler == "euler_a") {
586+
if (sd_params->sample_method == sample_method_t::EULER_A) {
552587
//euler a broken on flux
553588
if (!sd_is_quiet && sddebugmode) {
554589
printf("Flux: switching Euler A to Euler\n");
555590
}
556-
sampler = "euler";
591+
sd_params->sample_method = sample_method_t::EULER;
557592
}
558593
}
559594

@@ -617,39 +652,6 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
617652

618653
fflush(stdout);
619654

620-
if(sampler=="euler a"||sampler=="k_euler_a"||sampler=="euler_a") //all lowercase
621-
{
622-
sd_params->sample_method = sample_method_t::EULER_A;
623-
}
624-
else if(sampler=="euler"||sampler=="k_euler")
625-
{
626-
sd_params->sample_method = sample_method_t::EULER;
627-
}
628-
else if(sampler=="heun"||sampler=="k_heun")
629-
{
630-
sd_params->sample_method = sample_method_t::HEUN;
631-
}
632-
else if(sampler=="dpm2"||sampler=="k_dpm_2")
633-
{
634-
sd_params->sample_method = sample_method_t::DPM2;
635-
}
636-
else if(sampler=="lcm"||sampler=="k_lcm")
637-
{
638-
sd_params->sample_method = sample_method_t::LCM;
639-
}
640-
else if(sampler=="ddim")
641-
{
642-
sd_params->sample_method = sample_method_t::DDIM_TRAILING;
643-
}
644-
else if(sampler=="dpm++ 2m karras" || sampler=="dpm++ 2m" || sampler=="k_dpmpp_2m")
645-
{
646-
sd_params->sample_method = sample_method_t::DPMPP2M;
647-
}
648-
else
649-
{
650-
sd_params->sample_method = sample_method_t::EULER_A;
651-
}
652-
653655
if(extra_image_data.size()>0)
654656
{
655657
if(input_extraimage_buffers.size()>0) //just in time free old buffer

0 commit comments

Comments
 (0)