Skip to content

Commit d9de1da

Browse files
committed
Add gradient estimation sampler
1 parent 2034588 commit d9de1da

File tree

4 files changed

+63
-3
lines changed

4 files changed

+63
-3
lines changed

denoiser.hpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,8 @@ static void sample_k_diffusion(sample_method_t method,
636636
ggml_tensor* x,
637637
std::vector<float> sigmas,
638638
std::shared_ptr<RNG> rng,
639-
float eta) {
639+
float eta,
640+
float ge_gamma) {
640641
size_t steps = sigmas.size() - 1;
641642
// sample_euler_ancestral
642643
switch (method) {
@@ -1522,7 +1523,52 @@ static void sample_k_diffusion(sample_method_t method,
15221523
}
15231524
}
15241525
} break;
1526+
case GRADIENT_ESTIMATION: {
    // Gradient-estimation sampler.
    //
    // Each step takes an Euler step along d = (x - denoised) / sigma and, from
    // the second step on, adds the correction
    //     (ge_gamma - 1) * (d - old_d) * dt
    // built from the derivative of the previous step. When the next sigma is 0
    // the step collapses to x = denoised and no correction is applied.
    struct ggml_tensor* d     = ggml_dup_tensor(work_ctx, x);
    struct ggml_tensor* old_d = ggml_dup_tensor(work_ctx, x);
    bool has_old_d            = false;

    // d and old_d are duplicated from x, so a single element count is used for
    // every loop below; hoisted so it is not re-queried per element.
    const int64_t n_elements = ggml_nelements(x);

    // NOTE(review): the raw float* accesses assume the tensors hold F32 data — confirm.
    float* vec_d     = (float*)d->data;
    float* vec_old_d = (float*)old_d->data;

    for (int i = 0; i < steps; i++) {
        float sigma = sigmas[i];

        ggml_tensor* denoised = model(x, sigma, i + 1);

        float* vec_x        = (float*)x->data;
        float* vec_denoised = (float*)denoised->data;

        // d = (x - denoised) / sigma
        for (int64_t j = 0; j < n_elements; j++) {
            vec_d[j] = (vec_x[j] - vec_denoised[j]) / sigma;
        }

        float dt = sigmas[i + 1] - sigma;

        if (sigmas[i + 1] == 0) {
            // Denoising step: the result is exactly the model's denoised output.
            // The gradient-estimation correction is deliberately NOT applied
            // here; with dt == -sigma it would push x away from `denoised`
            // right after it has been assigned.
            for (int64_t j = 0; j < n_elements; j++) {
                vec_x[j] = vec_denoised[j];
            }
        } else {
            for (int64_t j = 0; j < n_elements; j++) {
                // Euler method
                vec_x[j] = vec_x[j] + vec_d[j] * dt;

                if (has_old_d) {
                    // Gradient estimation (needs the previous step's derivative,
                    // so it is skipped on the first step).
                    float d_bar = (ge_gamma - 1.f) * (vec_d[j] - vec_old_d[j]);
                    vec_x[j]    = vec_x[j] + d_bar * dt;
                }
            }
        }

        // old_d = d
        copy_ggml_tensor(old_d, d);
        has_old_d = true;
    }
} break;
15261572
default:
15271573
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
15281574
abort();

examples/cli/main.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
716716
"eta in DDIM, only for DDIM and TCD (default: 0)",
717717
&params.sample_params.eta},
718718
{"",
719-
"--high-noise-cfg-scale",
719+
"--ge-gamma", "", &params.sample_params.ge_gamma},
720+
{"", "--high-noise-cfg-scale",
720721
"(high noise) unconditional guidance scale: (default: 7.0)",
721722
&params.high_noise_sample_params.guidance.txt_cfg},
722723
{"",

stable-diffusion.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ const char* sampling_methods_str[] = {
5959
"LCM",
6060
"DDIM \"trailing\"",
6161
"TCD",
62+
"Gradient Estimation",
6263
};
6364

6465
/*================================================== Helper Functions ================================================*/
@@ -1433,6 +1434,7 @@ class StableDiffusionGGML {
14331434
float eta,
14341435
int shifted_timestep,
14351436
sample_method_t method,
1437+
float ge_gamma,
14361438
const std::vector<float>& sigmas,
14371439
int start_merge_step,
14381440
SDCondition id_cond,
@@ -1788,7 +1790,7 @@ class StableDiffusionGGML {
17881790
return denoised;
17891791
};
17901792

1791-
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta);
1793+
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta, ge_gamma);
17921794

17931795
if (easycache_enabled) {
17941796
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
@@ -2239,6 +2241,7 @@ const char* sample_method_to_str[] = {
22392241
"lcm",
22402242
"ddim_trailing",
22412243
"tcd",
2244+
"gradient_estimation",
22422245
};
22432246

22442247
const char* sd_sample_method_name(enum sample_method_t sample_method) {
@@ -2469,6 +2472,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
24692472
sample_params->scheduler = SCHEDULER_COUNT;
24702473
sample_params->sample_method = SAMPLE_METHOD_COUNT;
24712474
sample_params->sample_steps = 20;
2475+
sample_params->ge_gamma = 2.0f;
24722476
}
24732477

24742478
char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
@@ -2489,6 +2493,7 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
24892493
"sample_method: %s, "
24902494
"sample_steps: %d, "
24912495
"eta: %.2f, "
2496+
"ge_gamma: %.2f, "
24922497
"shifted_timestep: %d)",
24932498
sample_params->guidance.txt_cfg,
24942499
std::isfinite(sample_params->guidance.img_cfg)
@@ -2503,6 +2508,7 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
25032508
sd_sample_method_name(sample_params->sample_method),
25042509
sample_params->sample_steps,
25052510
sample_params->eta,
2511+
sample_params->ge_gamma,
25062512
sample_params->shifted_timestep);
25072513

25082514
return buf;
@@ -2654,6 +2660,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
26542660
int width,
26552661
int height,
26562662
enum sample_method_t sample_method,
2663+
float ge_gamma,
26572664
const std::vector<float>& sigmas,
26582665
int64_t seed,
26592666
int batch_count,
@@ -2948,6 +2955,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
29482955
eta,
29492956
shifted_timestep,
29502957
sample_method,
2958+
ge_gamma,
29512959
sigmas,
29522960
start_merge_step,
29532961
id_cond,
@@ -3262,6 +3270,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
32623270
width,
32633271
height,
32643272
sample_method,
3273+
sd_img_gen_params->sample_params.ge_gamma,
32653274
sigmas,
32663275
seed,
32673276
sd_img_gen_params->batch_count,
@@ -3598,6 +3607,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
35983607
sd_vid_gen_params->high_noise_sample_params.eta,
35993608
sd_vid_gen_params->high_noise_sample_params.shifted_timestep,
36003609
high_noise_sample_method,
3610+
sd_vid_gen_params->high_noise_sample_params.ge_gamma,
36013611
high_noise_sigmas,
36023612
-1,
36033613
{},
@@ -3635,6 +3645,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
36353645
sd_vid_gen_params->sample_params.eta,
36363646
sd_vid_gen_params->sample_params.shifted_timestep,
36373647
sample_method,
3648+
sd_vid_gen_params->sample_params.ge_gamma,
36383649
sigmas,
36393650
-1,
36403651
{},

stable-diffusion.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ enum sample_method_t {
4848
LCM_SAMPLE_METHOD,
4949
DDIM_TRAILING_SAMPLE_METHOD,
5050
TCD_SAMPLE_METHOD,
51+
GRADIENT_ESTIMATION,
5152
SAMPLE_METHOD_COUNT
5253
};
5354

@@ -218,6 +219,7 @@ typedef struct {
218219
enum sample_method_t sample_method;
219220
int sample_steps;
220221
float eta;
222+
float ge_gamma;
221223
int shifted_timestep;
222224
} sd_sample_params_t;
223225

0 commit comments

Comments
 (0)