Skip to content

Commit 4a9e10f

Browse files
author
anyshu
committed
更新
1 parent 71bf10b commit 4a9e10f

File tree

1 file changed

+28
-26
lines changed

1 file changed

+28
-26
lines changed

tools/server/server-diffusion.cpp

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -635,36 +635,38 @@ struct server_task {
635635
params.sampling.samplers = defaults.sampling.samplers;
636636
}
637637
}
638-
639-
// Diffusion parameters
640-
params.diffusion_steps = json_value(data, "diffusion_steps", params.diffusion_steps);
641638

642-
// Parse diffusion algorithm from string or int
643-
const auto diffusion_alg = data.find("diffusion_algorithm");
644-
if (diffusion_alg != data.end()) {
645-
if (diffusion_alg->is_string()) {
646-
std::string alg_str = diffusion_alg->get<std::string>();
647-
if (alg_str == "origin") params.diffusion_algo = ORIGIN;
648-
else if (alg_str == "entropy") params.diffusion_algo = ENTROPY_BASED;
649-
else if (alg_str == "margin") params.diffusion_algo = MARGIN_BASED;
650-
else if (alg_str == "random") params.diffusion_algo = RANDOM;
651-
else if (alg_str == "confidence") params.diffusion_algo = CONFIDENCE_BASED;
652-
} else if (diffusion_alg->is_number_integer()) {
653-
int alg_int = diffusion_alg->get<int>();
654-
if (alg_int >= 0 && alg_int <= 4) {
655-
params.diffusion_algo = static_cast<diffusion_algorithm>(alg_int);
639+
{
640+
// Diffusion parameters
641+
params.diffusion_steps = json_value(data, "diffusion_steps", params.diffusion_steps);
642+
643+
// Parse diffusion algorithm from string or int
644+
const auto diffusion_alg = data.find("diffusion_algorithm");
645+
if (diffusion_alg != data.end()) {
646+
if (diffusion_alg->is_string()) {
647+
std::string alg_str = diffusion_alg->get<std::string>();
648+
if (alg_str == "origin") params.diffusion_algo = ORIGIN;
649+
else if (alg_str == "entropy") params.diffusion_algo = ENTROPY_BASED;
650+
else if (alg_str == "margin") params.diffusion_algo = MARGIN_BASED;
651+
else if (alg_str == "random") params.diffusion_algo = RANDOM;
652+
else if (alg_str == "confidence") params.diffusion_algo = CONFIDENCE_BASED;
653+
} else if (diffusion_alg->is_number_integer()) {
654+
int alg_int = diffusion_alg->get<int>();
655+
if (alg_int >= 0 && alg_int <= 4) {
656+
params.diffusion_algo = static_cast<diffusion_algorithm>(alg_int);
657+
}
656658
}
657659
}
660+
661+
params.diffusion_eps = json_value(data, "diffusion_eps", params.diffusion_eps);
662+
params.diffusion_block_len = json_value(data, "diffusion_block_length", params.diffusion_block_len);
663+
params.diffusion_cfg_scale = json_value(data, "cfg_scale", params.diffusion_cfg_scale);
664+
params.diffusion_alg_temp = json_value(data, "diffusion_temperature", params.diffusion_alg_temp);
665+
params.diffusion_visual = json_value(data, "visual_mode", params.diffusion_visual);
666+
params.diffusion_shift_logits = json_value(data, "shift_logits", params.diffusion_shift_logits);
667+
params.diffusion_add_gumbel_noise = json_value(data, "add_gumbel_noise", params.diffusion_add_gumbel_noise);
668+
params.diffusion_max_length = json_value(data, "max_length", params.diffusion_max_length);
658669
}
659-
660-
params.diffusion_eps = json_value(data, "diffusion_eps", params.diffusion_eps);
661-
params.diffusion_block_len = json_value(data, "diffusion_block_length", params.diffusion_block_len);
662-
params.diffusion_cfg_scale = json_value(data, "cfg_scale", params.diffusion_cfg_scale);
663-
params.diffusion_alg_temp = json_value(data, "diffusion_temperature", params.diffusion_alg_temp);
664-
params.diffusion_visual = json_value(data, "visual_mode", params.diffusion_visual);
665-
params.diffusion_shift_logits = json_value(data, "shift_logits", params.diffusion_shift_logits);
666-
params.diffusion_add_gumbel_noise = json_value(data, "add_gumbel_noise", params.diffusion_add_gumbel_noise);
667-
params.diffusion_max_length = json_value(data, "max_length", params.diffusion_max_length);
668670

669671
std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
670672
params.oaicompat_model = json_value(data, "model", model_name);

0 commit comments

Comments
 (0)