Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,8 @@ struct DiscreteFlowDenoiser : public Denoiser {

float sigma_data = 1.0f;

DiscreteFlowDenoiser() {
// Constructs the denoiser with a caller-supplied `shift`; when no argument is
// given the default of 3.0f is used. The parameter is stored in the member of
// the same name, after which set_parameters() is invoked.
// NOTE(review): set_parameters() is defined outside this view; presumably it
// derives its values from `shift`, which is why the member is initialized
// before the call — confirm.
DiscreteFlowDenoiser(float shift = 3.0f)
: shift(shift) {
set_parameters();
}

Expand Down
9 changes: 7 additions & 2 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ struct SDParams {
bool chroma_use_dit_mask = true;
bool chroma_use_t5_mask = false;
int chroma_t5_mask_pad = 1;
float flow_shift = INFINITY;

SDParams() {
sd_sample_params_init(&sample_params);
Expand Down Expand Up @@ -171,6 +172,7 @@ void print_params(SDParams params) {
printf(" sample_params: %s\n", SAFE_STR(sample_params_str));
printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str));
printf(" moe_boundary: %.3f\n", params.moe_boundary);
printf(" flow_shift: %.2f\n", params.flow_shift);
printf(" strength(img2img): %.2f\n", params.strength);
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
printf(" seed: %ld\n", params.seed);
Expand Down Expand Up @@ -278,8 +280,9 @@ void print_usage(int argc, const char* argv[]) {
printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
printf(" --video-frames video frames (default: 1)\n");
printf(" --fps fps (default: 24)\n");
printf(" --moe-boundary BOUNDARY Timestep boundary for Wan2.2 MoE model. (default: 0.875)\n");
printf(" Only enabled if `--high-noise-steps` is set to -1\n");
printf(" --moe-boundary BOUNDARY timestep boundary for Wan2.2 MoE model. (default: 0.875)\n");
printf(" only enabled if `--high-noise-steps` is set to -1\n");
printf(" --flow-shift SHIFT shift value for Flow models like SD3.x or WAN (default: auto)\n");
printf(" -v, --verbose print extra info\n");
}

Expand Down Expand Up @@ -514,6 +517,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--style-ratio", "", &params.style_ratio},
{"", "--control-strength", "", &params.control_strength},
{"", "--moe-boundary", "", &params.moe_boundary},
{"", "--flow-shift", "", &params.flow_shift},
};

options.bool_options = {
Expand Down Expand Up @@ -1181,6 +1185,7 @@ int main(int argc, const char* argv[]) {
params.chroma_use_dit_mask,
params.chroma_use_t5_mask,
params.chroma_t5_mask_pad,
params.flow_shift,
};

sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
Expand Down
13 changes: 11 additions & 2 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,11 @@ class StableDiffusionGGML {

if (sd_version_is_sd3(version)) {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 3.0;
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
} else if (sd_version_is_flux(version)) {
LOG_INFO("running in Flux FLOW mode");
float shift = 1.0f; // TODO: validate
Expand All @@ -694,7 +698,11 @@ class StableDiffusionGGML {
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
} else if (sd_version_is_wan(version)) {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 5.0;
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
} else if (is_using_v_parameterization) {
LOG_INFO("running in v-prediction mode");
denoiser = std::make_shared<CompVisVDenoiser>();
Expand Down Expand Up @@ -1553,6 +1561,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->chroma_use_dit_mask = true;
sd_ctx_params->chroma_use_t5_mask = false;
sd_ctx_params->chroma_t5_mask_pad = 1;
sd_ctx_params->flow_shift = INFINITY;
}

char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
Expand Down
1 change: 1 addition & 0 deletions stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ typedef struct {
bool chroma_use_dit_mask;
bool chroma_use_t5_mask;
int chroma_t5_mask_pad;
float flow_shift;
} sd_ctx_params_t;

typedef struct {
Expand Down
Loading