Skip to content

Commit 985aedd

Browse files
authored
refactor: optimize the handling of pred type (leejet#1048)
1 parent 3f3610b commit 985aedd

File tree

3 files changed

+68
-99
lines changed

3 files changed

+68
-99
lines changed

examples/cli/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ struct SDCliParams {
409409
return -1;
410410
}
411411
const char* preview = argv[index];
412-
int preview_found = -1;
412+
int preview_found = -1;
413413
for (int m = 0; m < PREVIEW_COUNT; m++) {
414414
if (!strcmp(preview, previews_str[m])) {
415415
preview_found = m;
@@ -515,7 +515,7 @@ struct SDContextParams {
515515
bool chroma_use_t5_mask = false;
516516
int chroma_t5_mask_pad = 1;
517517

518-
prediction_t prediction = DEFAULT_PRED;
518+
prediction_t prediction = PREDICTION_COUNT;
519519
lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;
520520

521521
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};

stable-diffusion.cpp

Lines changed: 65 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ class StableDiffusionGGML {
707707
return false;
708708
}
709709

710-
// LOG_DEBUG("model size = %.2fMB", total_size / 1024.0 / 1024.0);
710+
LOG_DEBUG("finished loaded file");
711711

712712
{
713713
size_t clip_params_mem_size = cond_stage_model->get_params_buffer_size();
@@ -782,8 +782,59 @@ class StableDiffusionGGML {
782782
ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM");
783783
}
784784

785-
if (sd_ctx_params->prediction != DEFAULT_PRED) {
786-
switch (sd_ctx_params->prediction) {
785+
// init denoiser
786+
{
787+
prediction_t pred_type = sd_ctx_params->prediction;
788+
float flow_shift = sd_ctx_params->flow_shift;
789+
790+
if (pred_type == PREDICTION_COUNT) {
791+
if (sd_version_is_sd2(version)) {
792+
// check is_using_v_parameterization_for_sd2
793+
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
794+
pred_type = V_PRED;
795+
} else {
796+
pred_type = EPS_PRED;
797+
}
798+
} else if (sd_version_is_sdxl(version)) {
799+
if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
800+
// CosXL models
801+
// TODO: get sigma_min and sigma_max values from file
802+
pred_type = EDM_V_PRED;
803+
} else if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
804+
pred_type = V_PRED;
805+
} else {
806+
pred_type = EPS_PRED;
807+
}
808+
} else if (sd_version_is_sd3(version) ||
809+
sd_version_is_wan(version) ||
810+
sd_version_is_qwen_image(version) ||
811+
sd_version_is_z_image(version)) {
812+
pred_type = FLOW_PRED;
813+
if (flow_shift == INFINITY) {
814+
if (sd_version_is_wan(version)) {
815+
flow_shift = 5.f;
816+
} else {
817+
flow_shift = 3.f;
818+
}
819+
}
820+
} else if (sd_version_is_flux(version)) {
821+
pred_type = FLUX_FLOW_PRED;
822+
if (flow_shift == INFINITY) {
823+
flow_shift = 1.0f; // TODO: validate
824+
for (const auto& [name, tensor_storage] : tensor_storage_map) {
825+
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
826+
flow_shift = 1.15f;
827+
}
828+
}
829+
}
830+
} else if (sd_version_is_flux2(version)) {
831+
pred_type = FLUX2_FLOW_PRED;
832+
} else {
833+
pred_type = EPS_PRED;
834+
}
835+
}
836+
837+
switch (pred_type) {
787838
case EPS_PRED:
788839
LOG_INFO("running in eps-prediction mode");
789840
break;
@@ -795,22 +846,14 @@ class StableDiffusionGGML {
795846
LOG_INFO("running in v-prediction EDM mode");
796847
denoiser = std::make_shared<EDMVDenoiser>();
797848
break;
798-
case SD3_FLOW_PRED: {
849+
case FLOW_PRED: {
799850
LOG_INFO("running in FLOW mode");
800-
float shift = sd_ctx_params->flow_shift;
801-
if (shift == INFINITY) {
802-
shift = 3.0;
803-
}
804-
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
851+
denoiser = std::make_shared<DiscreteFlowDenoiser>(flow_shift);
805852
break;
806853
}
807854
case FLUX_FLOW_PRED: {
808855
LOG_INFO("running in Flux FLOW mode");
809-
float shift = sd_ctx_params->flow_shift;
810-
if (shift == INFINITY) {
811-
shift = 3.0;
812-
}
813-
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
856+
denoiser = std::make_shared<FluxFlowDenoiser>(flow_shift);
814857
break;
815858
}
816859
case FLUX2_FLOW_PRED: {
@@ -819,93 +862,21 @@ class StableDiffusionGGML {
819862
break;
820863
}
821864
default: {
822-
LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction);
865+
LOG_ERROR("Unknown predition type %i", pred_type);
866+
ggml_free(ctx);
823867
return false;
824868
}
825869
}
826-
} else {
827-
if (sd_version_is_sd2(version)) {
828-
// check is_using_v_parameterization_for_sd2
829-
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
830-
is_using_v_parameterization = true;
831-
}
832-
} else if (sd_version_is_sdxl(version)) {
833-
if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
834-
// CosXL models
835-
// TODO: get sigma_min and sigma_max values from file
836-
is_using_edm_v_parameterization = true;
837-
}
838-
if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
839-
is_using_v_parameterization = true;
840-
}
841-
} else if (version == VERSION_SVD) {
842-
// TODO: V_PREDICTION_EDM
843-
is_using_v_parameterization = true;
844-
}
845870

846-
if (sd_version_is_sd3(version)) {
847-
LOG_INFO("running in FLOW mode");
848-
float shift = sd_ctx_params->flow_shift;
849-
if (shift == INFINITY) {
850-
shift = 3.0;
851-
}
852-
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
853-
} else if (sd_version_is_flux(version)) {
854-
LOG_INFO("running in Flux FLOW mode");
855-
float shift = sd_ctx_params->flow_shift;
856-
if (shift == INFINITY) {
857-
shift = 1.0f; // TODO: validate
858-
for (const auto& [name, tensor_storage] : tensor_storage_map) {
859-
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
860-
shift = 1.15f;
861-
}
862-
}
871+
auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
872+
if (comp_vis_denoiser) {
873+
for (int i = 0; i < TIMESTEPS; i++) {
874+
comp_vis_denoiser->sigmas[i] = std::sqrt((1 - ((float*)alphas_cumprod_tensor->data)[i]) / ((float*)alphas_cumprod_tensor->data)[i]);
875+
comp_vis_denoiser->log_sigmas[i] = std::log(comp_vis_denoiser->sigmas[i]);
863876
}
864-
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
865-
} else if (sd_version_is_flux2(version)) {
866-
LOG_INFO("running in Flux2 FLOW mode");
867-
denoiser = std::make_shared<Flux2FlowDenoiser>();
868-
} else if (sd_version_is_wan(version)) {
869-
LOG_INFO("running in FLOW mode");
870-
float shift = sd_ctx_params->flow_shift;
871-
if (shift == INFINITY) {
872-
shift = 5.0;
873-
}
874-
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
875-
} else if (sd_version_is_qwen_image(version)) {
876-
LOG_INFO("running in FLOW mode");
877-
float shift = sd_ctx_params->flow_shift;
878-
if (shift == INFINITY) {
879-
shift = 3.0;
880-
}
881-
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
882-
} else if (sd_version_is_z_image(version)) {
883-
LOG_INFO("running in FLOW mode");
884-
float shift = sd_ctx_params->flow_shift;
885-
if (shift == INFINITY) {
886-
shift = 3.0f;
887-
}
888-
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
889-
} else if (is_using_v_parameterization) {
890-
LOG_INFO("running in v-prediction mode");
891-
denoiser = std::make_shared<CompVisVDenoiser>();
892-
} else if (is_using_edm_v_parameterization) {
893-
LOG_INFO("running in v-prediction EDM mode");
894-
denoiser = std::make_shared<EDMVDenoiser>();
895-
} else {
896-
LOG_INFO("running in eps-prediction mode");
897877
}
898878
}
899879

900-
auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
901-
if (comp_vis_denoiser) {
902-
for (int i = 0; i < TIMESTEPS; i++) {
903-
comp_vis_denoiser->sigmas[i] = std::sqrt((1 - ((float*)alphas_cumprod_tensor->data)[i]) / ((float*)alphas_cumprod_tensor->data)[i]);
904-
comp_vis_denoiser->log_sigmas[i] = std::log(comp_vis_denoiser->sigmas[i]);
905-
}
906-
}
907-
908-
LOG_DEBUG("finished loaded file");
909880
ggml_free(ctx);
910881
use_tiny_autoencoder = use_tiny_autoencoder && !sd_ctx_params->tae_preview_only;
911882
return true;
@@ -2426,7 +2397,6 @@ enum scheduler_t str_to_scheduler(const char* str) {
24262397
}
24272398

24282399
const char* prediction_to_str[] = {
2429-
"default",
24302400
"eps",
24312401
"v",
24322402
"edm_v",
@@ -2512,7 +2482,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
25122482
sd_ctx_params->wtype = SD_TYPE_COUNT;
25132483
sd_ctx_params->rng_type = CUDA_RNG;
25142484
sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT;
2515-
sd_ctx_params->prediction = DEFAULT_PRED;
2485+
sd_ctx_params->prediction = PREDICTION_COUNT;
25162486
sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
25172487
sd_ctx_params->offload_params_to_cpu = false;
25182488
sd_ctx_params->keep_clip_on_cpu = false;

stable-diffusion.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,10 @@ enum scheduler_t {
6565
};
6666

6767
enum prediction_t {
68-
DEFAULT_PRED,
6968
EPS_PRED,
7069
V_PRED,
7170
EDM_V_PRED,
72-
SD3_FLOW_PRED,
71+
FLOW_PRED,
7372
FLUX_FLOW_PRED,
7473
FLUX2_FLOW_PRED,
7574
PREDICTION_COUNT

0 commit comments

Comments
 (0)