Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ arguments:
--rng {std_default, cuda} RNG (default: cuda)
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
-b, --batch-count COUNT number of images to generate
--schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)
--schedule {discrete, karras, exponential, ays, gits, simple, sgm_uniform} Denoiser sigma schedule (default: discrete)
--clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
--vae-tiling process vae in tiles to reduce memory usage
Expand All @@ -347,6 +347,7 @@ arguments:
--chroma-disable-dit-mask disable dit mask for chroma
--chroma-enable-t5-mask enable t5 mask for chroma
--chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma
--timestep-shift N shift timestep for NitroFusion models, default: -1 off, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant
-v, --verbose print extra info
```

Expand Down
84 changes: 82 additions & 2 deletions denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,25 @@ struct GITSSchedule : SigmaSchedule {
}
};

// SGM-uniform sigma schedule.
//
// Picks n timesteps uniformly spaced from TIMESTEPS - 1 downwards, maps each
// through the supplied timestep->sigma function, and appends a final 0.
struct SGMUniformSchedule : SigmaSchedule {
    // n               - number of sampling steps; the result has n + 1 entries.
    // sigma_min_in    - unused by this schedule (kept for the common interface).
    // sigma_max_in    - unused by this schedule (kept for the common interface).
    // t_to_sigma_func - converts a (fractional) timestep into a sigma.
    //
    // Returns {0} when n == 0.
    std::vector<float> get_sigmas(uint32_t n, float sigma_min_in, float sigma_max_in, t_to_sigma_t t_to_sigma_func) override {
        std::vector<float> result;
        if (n == 0) {
            result.push_back(0.0f);
            return result;
        }
        result.reserve(n + 1);

        const float t_max = static_cast<float>(TIMESTEPS - 1);
        // Space n + 1 points over [t_max, 0] and drop the last one, i.e. the
        // step is t_max / n (not t_max / (n - 1)). The terminal point is
        // covered by the explicit 0 sigma appended below, and this matches
        // the per-step delta Denoiser::get_sigmas uses for this schedule.
        const float step = t_max / static_cast<float>(n);
        for (uint32_t i = 0; i < n; ++i) {
            result.push_back(t_to_sigma_func(t_max - step * static_cast<float>(i)));
        }
        result.push_back(0.0f);
        return result;
    }
};

struct KarrasSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
// These *COULD* be function arguments here,
Expand All @@ -251,6 +270,36 @@ struct KarrasSchedule : SigmaSchedule {
}
};

// "Simple" sigma schedule: strides through the integer timesteps from the
// last one downwards and converts each picked timestep to a sigma.
struct SimpleSchedule : SigmaSchedule {
    // n          - number of sampling steps.
    // sigma_min  - not used by this schedule.
    // sigma_max  - not used by this schedule.
    // t_to_sigma - converts a timestep into a sigma.
    //
    // Returns an empty vector for n == 0, otherwise n sigmas followed by 0.
    std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
        std::vector<float> sigmas;
        if (n != 0) {
            sigmas.reserve(n + 1);

            const int total_timesteps = TIMESTEPS;
            const float stride = static_cast<float>(total_timesteps) / static_cast<float>(n);

            for (uint32_t k = 0; k < n; ++k) {
                // Truncate toward zero, then count back from the last timestep.
                const int skipped = static_cast<int>(static_cast<float>(k) * stride);
                const int picked = total_timesteps - 1 - skipped;
                // Never index below the first timestep.
                sigmas.push_back(t_to_sigma(static_cast<float>(picked < 0 ? 0 : picked)));
            }
            sigmas.push_back(0.0f);
        }
        return sigmas;
    }
};

struct Denoiser {
std::shared_ptr<SigmaSchedule> schedule = std::make_shared<DiscreteSchedule>();
virtual float sigma_min() = 0;
Expand All @@ -262,8 +311,39 @@ struct Denoiser {
virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0;

// Builds the sigma sequence for n sampling steps.
//
// When the active schedule is an SGMUniformSchedule the sigmas are computed
// here with this denoiser's own sigma_to_t / t_to_sigma: n timesteps evenly
// spaced from sigma_to_t(sigma_max()) towards sigma_to_t(sigma_min()) (the
// end point itself is excluded), followed by a trailing 0. For n == 0 that
// path returns {0}.
//
// Every other schedule is delegated to schedule->get_sigmas().
virtual std::vector<float> get_sigmas(uint32_t n) {
    if (std::dynamic_pointer_cast<SGMUniformSchedule>(schedule)) {
        std::vector<float> sigs;
        if (n == 0) {
            sigs.push_back(0.0f);
            return sigs;
        }
        sigs.reserve(n + 1);

        const float start_t_val = this->sigma_to_t(this->sigma_max());
        const float end_t_val   = this->sigma_to_t(this->sigma_min());
        // n is known to be non-zero here, so the division is safe.
        const float dt_per_step = (end_t_val - start_t_val) / static_cast<float>(n);

        for (uint32_t i = 0; i < n; ++i) {
            const float current_t = start_t_val + static_cast<float>(i) * dt_per_step;
            sigs.push_back(this->t_to_sigma(current_t));
        }

        sigs.push_back(0.0f);
        return sigs;
    }

    // For all other schedules, use the existing virtual dispatch.
    auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
    return schedule->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
}
};

Expand Down
21 changes: 18 additions & 3 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ struct SDParams {
float slg_scale = 0.f;
float skip_layer_start = 0.01f;
float skip_layer_end = 0.2f;
int shifted_timestep = -1;

bool chroma_use_dit_mask = true;
bool chroma_use_t5_mask = false;
Expand Down Expand Up @@ -163,6 +164,7 @@ void print_params(SDParams params) {
printf(" batch_count: %d\n", params.batch_count);
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
printf(" upscale_repeats: %d\n", params.upscale_repeats);
printf(" timestep_shift: %d\n", params.shifted_timestep);
printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
Expand Down Expand Up @@ -223,7 +225,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
printf(" -b, --batch-count COUNT number of images to generate\n");
printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n");
printf(" --schedule {discrete, karras, exponential, ays, gits, sgm_uniform, simple} Denoiser sigma schedule (default: discrete)\n");
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
Expand All @@ -235,6 +237,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
printf(" --canny apply canny preprocessor (edge detection)\n");
printf(" --color colors the logging tags according to level\n");
printf(" --timestep-shift N shift timestep for NitroFusion models, default: -1 off, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant\n");
printf(" --chroma-disable-dit-mask disable dit mask for chroma\n");
printf(" --chroma-enable-t5-mask enable t5 mask for chroma\n");
printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
Expand Down Expand Up @@ -487,7 +490,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
const char* arg = argv[index];
params.schedule = str_to_schedule(arg);
if (params.schedule == SCHEDULE_COUNT) {
fprintf(stderr, "error: invalid schedule %s\n",
fprintf(stderr, "error: invalid schedule %s, must be one of [discrete, karras, exponential, ays, gits, sgm_uniform, simple]\n",
arg);
return -1;
}
Expand Down Expand Up @@ -568,7 +571,18 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"-r", "--ref-image", "", on_ref_image_arg},
{"-h", "--help", "", on_help_arg},
};

// Handler for `--timestep-shift N`.
// Consumes the argument following the flag and stores it in
// params.shifted_timestep. Accepts -1 (disabled) or a value in [1, 1000].
// Returns 1 on success and -1 on a missing or invalid value.
auto on_timestep_shift_arg = [&](int argc, const char** argv, int index) {
    if (++index >= argc) {
        return -1;
    }
    try {
        params.shifted_timestep = std::stoi(argv[index]);
    } catch (const std::exception&) {
        // std::stoi throws on non-numeric or out-of-range input; report it
        // as a normal argument error instead of terminating the program.
        fprintf(stderr, "error: invalid timestep-shift value %s\n", argv[index]);
        return -1;
    }
    if (params.shifted_timestep != -1 && (params.shifted_timestep < 1 || params.shifted_timestep > 1000)) {
        fprintf(stderr, "error: timestep-shift must be between 1 and 1000, or -1 to disable\n");
        return -1;
    }
    return 1;
};
options.manual_options.push_back({"", "--timestep-shift", "", on_timestep_shift_arg});
if (!parse_options(argc, argv, options)) {
print_usage(argc, argv);
exit(1);
Expand Down Expand Up @@ -979,6 +993,7 @@ int main(int argc, const char* argv[]) {
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str(),
params.shifted_timestep,
};

results = generate_image(sd_ctx, &img_gen_params);
Expand Down
Loading
Loading