Skip to content

Commit f0c6d71

Browse files
committed
Add gradient estimation sampler
1 parent 1c32fa0 commit f0c6d71

File tree

5 files changed

+68
-6
lines changed

5 files changed

+68
-6
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,11 +326,12 @@ arguments:
326326
--slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)
327327
0 means disabled, a value of 2.5 is nice for sd3.5 medium
328328
--eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)
329+
--ge-gamma SCALE gamma for gradient estimation sampler: (default: 2.0)
329330
--skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])
330331
--skip-layer-start START SLG enabling point: (default: 0.01)
331332
--skip-layer-end END SLG disabling point: (default: 0.2)
332333
--scheduler {discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple} Denoiser sigma scheduler (default: discrete)
333-
--sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}
334+
--sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, gradient_estimation}
334335
sampling method (default: "euler" for Flux/SD3/Wan, "euler_a" otherwise)
335336
--timestep-shift N shift timestep for NitroFusion models, default: 0, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant
336337
--steps STEPS number of sample steps (default: 20)
@@ -344,7 +345,7 @@ arguments:
344345
--high-noise-skip-layer-start (high noise) SLG enabling point: (default: 0.01)
345346
--high-noise-skip-layer-end END (high noise) SLG disabling point: (default: 0.2)
346347
--high-noise-scheduler {discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple} Denoiser sigma scheduler (default: discrete)
347-
--high-noise-sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}
348+
--high-noise-sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, gradient_estimation}
348349
(high noise) sampling method (default: "euler_a")
349350
--high-noise-steps STEPS (high noise) number of sample steps (default: -1 = auto)
350351
SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])

denoiser.hpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,8 @@ static void sample_k_diffusion(sample_method_t method,
576576
ggml_tensor* x,
577577
std::vector<float> sigmas,
578578
std::shared_ptr<RNG> rng,
579-
float eta) {
579+
float eta,
580+
float ge_gamma) {
580581
size_t steps = sigmas.size() - 1;
581582
// sample_euler_ancestral
582583
switch (method) {
@@ -1462,7 +1463,52 @@ static void sample_k_diffusion(sample_method_t method,
14621463
}
14631464
}
14641465
} break;
1466+
case GRADIENT_ESTIMATION: {
1467+
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
1468+
struct ggml_tensor* old_d = ggml_dup_tensor(work_ctx, x);
1469+
bool has_old_d = false;
14651470

1471+
for (int i = 0; i < steps; i++) {
1472+
float sigma = sigmas[i];
1473+
1474+
ggml_tensor* denoised = model(x, sigma, i + 1);
1475+
1476+
// d = (x - denoised) / sigma
1477+
float* vec_d = (float*)d->data;
1478+
float* vec_x = (float*)x->data;
1479+
float* vec_denoised = (float*)denoised->data;
1480+
1481+
for (int j = 0; j < ggml_nelements(d); j++) {
1482+
vec_d[j] = (vec_x[j] - vec_denoised[j]) / sigma;
1483+
}
1484+
1485+
float dt = sigmas[i + 1] - sigma;
1486+
1487+
if (sigmas[i + 1] == 0) {
1488+
// Denoising step
1489+
for (int j = 0; j < ggml_nelements(x); j++) {
1490+
vec_x[j] = vec_denoised[j];
1491+
}
1492+
} else {
1493+
// Euler method
1494+
for (int j = 0; j < ggml_nelements(x); j++) {
1495+
vec_x[j] = vec_x[j] + vec_d[j] * dt;
1496+
}
1497+
}
1498+
1499+
if (has_old_d) {
1500+
// Gradient estimation
1501+
float* vec_old_d = (float*)old_d->data;
1502+
for (int j = 0; j < ggml_nelements(x); j++) {
1503+
float d_bar = (ge_gamma - 1.f) * (vec_d[j] - vec_old_d[j]);
1504+
vec_x[j] = vec_x[j] + d_bar * dt;
1505+
}
1506+
}
1507+
// old_d = d
1508+
copy_ggml_tensor(old_d, d);
1509+
has_old_d = true;
1510+
}
1511+
} break;
14661512
default:
14671513
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
14681514
abort();

examples/cli/main.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,12 @@ void print_usage(int argc, const char* argv[]) {
248248
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
249249
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
250250
printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n");
251+
printf(" --ge-gamma SCALE gamma for gradient estimation sampler: (default: 2.0)\n");
251252
printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n");
252253
printf(" --skip-layer-start START SLG enabling point: (default: 0.01)\n");
253254
printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n");
254255
printf(" --scheduler {discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple} Denoiser sigma scheduler (default: discrete)\n");
255-
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
256+
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, gradient_estimation}\n");
256257
printf(" sampling method (default: \"euler\" for Flux/SD3/Wan, \"euler_a\" otherwise)\n");
257258
printf(" --timestep-shift N shift timestep for NitroFusion models, default: 0, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant\n");
258259
printf(" --steps STEPS number of sample steps (default: 20)\n");
@@ -266,7 +267,7 @@ void print_usage(int argc, const char* argv[]) {
266267
printf(" --high-noise-skip-layer-start (high noise) SLG enabling point: (default: 0.01)\n");
267268
printf(" --high-noise-skip-layer-end END (high noise) SLG disabling point: (default: 0.2)\n");
268269
printf(" --high-noise-scheduler {discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple} Denoiser sigma scheduler (default: discrete)\n");
269-
printf(" --high-noise-sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
270+
printf(" --high-noise-sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, gradient_estimation}\n");
270271
printf(" (high noise) sampling method (default: \"euler_a\")\n");
271272
printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: -1 = auto)\n");
272273
printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n");
@@ -535,6 +536,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
535536
{"", "--skip-layer-start", "", &params.sample_params.guidance.slg.layer_start},
536537
{"", "--skip-layer-end", "", &params.sample_params.guidance.slg.layer_end},
537538
{"", "--eta", "", &params.sample_params.eta},
539+
{"", "--ge-gamma", "", &params.sample_params.ge_gamma},
538540
{"", "--high-noise-cfg-scale", "", &params.high_noise_sample_params.guidance.txt_cfg},
539541
{"", "--high-noise-img-cfg-scale", "", &params.high_noise_sample_params.guidance.img_cfg},
540542
{"", "--high-noise-guidance", "", &params.high_noise_sample_params.guidance.distilled_guidance},

stable-diffusion.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ const char* sampling_methods_str[] = {
5959
"DDIM \"trailing\"",
6060
"TCD",
6161
"Euler A",
62+
"Gradient Estimation",
6263
};
6364

6465
/*================================================== Helper Functions ================================================*/
@@ -1071,6 +1072,7 @@ class StableDiffusionGGML {
10711072
float eta,
10721073
int shifted_timestep,
10731074
sample_method_t method,
1075+
float ge_gamma,
10741076
const std::vector<float>& sigmas,
10751077
int start_merge_step,
10761078
SDCondition id_cond,
@@ -1299,7 +1301,7 @@ class StableDiffusionGGML {
12991301
return denoised;
13001302
};
13011303

1302-
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);
1304+
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta, ge_gamma);
13031305

13041306
if (inverse_noise_scaling) {
13051307
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
@@ -1670,6 +1672,7 @@ const char* sample_method_to_str[] = {
16701672
"ddim_trailing",
16711673
"tcd",
16721674
"euler_a",
1675+
"gradient_estimation",
16731676
};
16741677

16751678
const char* sd_sample_method_name(enum sample_method_t sample_method) {
@@ -1812,6 +1815,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
18121815
sample_params->scheduler = DEFAULT;
18131816
sample_params->sample_method = SAMPLE_METHOD_DEFAULT;
18141817
sample_params->sample_steps = 20;
1818+
sample_params->ge_gamma = 2.0f;
18151819
}
18161820

18171821
char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
@@ -1832,6 +1836,7 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
18321836
"sample_method: %s, "
18331837
"sample_steps: %d, "
18341838
"eta: %.2f, "
1839+
"ge_gamma: %.2f, "
18351840
"shifted_timestep: %d)",
18361841
sample_params->guidance.txt_cfg,
18371842
isfinite(sample_params->guidance.img_cfg)
@@ -1846,6 +1851,7 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
18461851
sd_sample_method_name(sample_params->sample_method),
18471852
sample_params->sample_steps,
18481853
sample_params->eta,
1854+
sample_params->ge_gamma,
18491855
sample_params->shifted_timestep);
18501856

18511857
return buf;
@@ -1979,6 +1985,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
19791985
int width,
19801986
int height,
19811987
enum sample_method_t sample_method,
1988+
float ge_gamma,
19821989
const std::vector<float>& sigmas,
19831990
int64_t seed,
19841991
int batch_count,
@@ -2266,6 +2273,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
22662273
eta,
22672274
shifted_timestep,
22682275
sample_method,
2276+
ge_gamma,
22692277
sigmas,
22702278
start_merge_step,
22712279
id_cond,
@@ -2570,6 +2578,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
25702578
width,
25712579
height,
25722580
sample_method,
2581+
sd_img_gen_params->sample_params.ge_gamma,
25732582
sigmas,
25742583
seed,
25752584
sd_img_gen_params->batch_count,
@@ -2902,6 +2911,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
29022911
sd_vid_gen_params->high_noise_sample_params.eta,
29032912
sd_vid_gen_params->high_noise_sample_params.shifted_timestep,
29042913
sd_vid_gen_params->high_noise_sample_params.sample_method,
2914+
sd_vid_gen_params->high_noise_sample_params.ge_gamma,
29052915
high_noise_sigmas,
29062916
-1,
29072917
{},
@@ -2938,6 +2948,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
29382948
sd_vid_gen_params->sample_params.eta,
29392949
sd_vid_gen_params->sample_params.shifted_timestep,
29402950
sd_vid_gen_params->sample_params.sample_method,
2951+
sd_vid_gen_params->sample_params.ge_gamma,
29412952
sigmas,
29422953
-1,
29432954
{},

stable-diffusion.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ enum sample_method_t {
4848
DDIM_TRAILING,
4949
TCD,
5050
EULER_A,
51+
GRADIENT_ESTIMATION,
5152
SAMPLE_METHOD_COUNT
5253
};
5354

@@ -186,6 +187,7 @@ typedef struct {
186187
enum sample_method_t sample_method;
187188
int sample_steps;
188189
float eta;
190+
float ge_gamma;
189191
int shifted_timestep;
190192
} sd_sample_params_t;
191193

0 commit comments

Comments
 (0)