Skip to content

Commit f5c3fcf

Browse files
committed
Add gradient estimation sampler
1 parent 5498cc0 commit f5c3fcf

File tree

4 files changed

+63
-3
lines changed

4 files changed

+63
-3
lines changed

denoiser.hpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,8 @@ static void sample_k_diffusion(sample_method_t method,
576576
ggml_tensor* x,
577577
std::vector<float> sigmas,
578578
std::shared_ptr<RNG> rng,
579-
float eta) {
579+
float eta,
580+
float ge_gamma) {
580581
size_t steps = sigmas.size() - 1;
581582
// sample_euler_ancestral
582583
switch (method) {
@@ -1462,7 +1463,52 @@ static void sample_k_diffusion(sample_method_t method,
14621463
}
14631464
}
14641465
} break;
1466+
case GRADIENT_ESTIMATION: {
1467+
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
1468+
struct ggml_tensor* old_d = ggml_dup_tensor(work_ctx, x);
1469+
bool has_old_d = false;
14651470

1471+
for (int i = 0; i < steps; i++) {
1472+
float sigma = sigmas[i];
1473+
1474+
ggml_tensor* denoised = model(x, sigma, i + 1);
1475+
1476+
// d = (x - denoised) / sigma
1477+
float* vec_d = (float*)d->data;
1478+
float* vec_x = (float*)x->data;
1479+
float* vec_denoised = (float*)denoised->data;
1480+
1481+
for (int j = 0; j < ggml_nelements(d); j++) {
1482+
vec_d[j] = (vec_x[j] - vec_denoised[j]) / sigma;
1483+
}
1484+
1485+
float dt = sigmas[i + 1] - sigma;
1486+
1487+
if (sigmas[i + 1] == 0) {
1488+
// Denoising step
1489+
for (int j = 0; j < ggml_nelements(x); j++) {
1490+
vec_x[j] = vec_denoised[j];
1491+
}
1492+
} else {
1493+
// Euler method
1494+
for (int j = 0; j < ggml_nelements(x); j++) {
1495+
vec_x[j] = vec_x[j] + vec_d[j] * dt;
1496+
}
1497+
}
1498+
1499+
if (has_old_d) {
1500+
// Gradient estimation
1501+
float* vec_old_d = (float*)old_d->data;
1502+
for (int j = 0; j < ggml_nelements(x); j++) {
1503+
float d_bar = (ge_gamma - 1.f) * (vec_d[j] - vec_old_d[j]);
1504+
vec_x[j] = vec_x[j] + d_bar * dt;
1505+
}
1506+
}
1507+
// old_d = d
1508+
copy_ggml_tensor(old_d, d);
1509+
has_old_d = true;
1510+
}
1511+
} break;
14661512
default:
14671513
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
14681514
abort();

examples/cli/main.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
716716
"eta in DDIM, only for DDIM and TCD (default: 0)",
717717
&params.sample_params.eta},
718718
{"",
719-
"--high-noise-cfg-scale",
719+
"--ge-gamma", "", &params.sample_params.ge_gamma},
720+
{"", "--high-noise-cfg-scale",
720721
"(high noise) unconditional guidance scale: (default: 7.0)",
721722
&params.high_noise_sample_params.guidance.txt_cfg},
722723
{"",

stable-diffusion.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ const char* sampling_methods_str[] = {
6060
"DDIM \"trailing\"",
6161
"TCD",
6262
"Euler A",
63+
"Gradient Estimation",
6364
};
6465

6566
/*================================================== Helper Functions ================================================*/
@@ -1481,6 +1482,7 @@ class StableDiffusionGGML {
14811482
float eta,
14821483
int shifted_timestep,
14831484
sample_method_t method,
1485+
float ge_gamma,
14841486
const std::vector<float>& sigmas,
14851487
int start_merge_step,
14861488
SDCondition id_cond,
@@ -1836,7 +1838,7 @@ class StableDiffusionGGML {
18361838
return denoised;
18371839
};
18381840

1839-
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta);
1841+
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta, ge_gamma);
18401842

18411843
if (easycache_enabled) {
18421844
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
@@ -2288,6 +2290,7 @@ const char* sample_method_to_str[] = {
22882290
"ddim_trailing",
22892291
"tcd",
22902292
"euler_a",
2293+
"gradient_estimation",
22912294
};
22922295

22932296
const char* sd_sample_method_name(enum sample_method_t sample_method) {
@@ -2518,6 +2521,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
25182521
sample_params->scheduler = DEFAULT;
25192522
sample_params->sample_method = SAMPLE_METHOD_DEFAULT;
25202523
sample_params->sample_steps = 20;
2524+
sample_params->ge_gamma = 2.0f;
25212525
}
25222526

25232527
char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
@@ -2538,6 +2542,7 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
25382542
"sample_method: %s, "
25392543
"sample_steps: %d, "
25402544
"eta: %.2f, "
2545+
"ge_gamma: %.2f, "
25412546
"shifted_timestep: %d)",
25422547
sample_params->guidance.txt_cfg,
25432548
std::isfinite(sample_params->guidance.img_cfg)
@@ -2552,6 +2557,7 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
25522557
sd_sample_method_name(sample_params->sample_method),
25532558
sample_params->sample_steps,
25542559
sample_params->eta,
2560+
sample_params->ge_gamma,
25552561
sample_params->shifted_timestep);
25562562

25572563
return buf;
@@ -2695,6 +2701,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
26952701
int width,
26962702
int height,
26972703
enum sample_method_t sample_method,
2704+
float ge_gamma,
26982705
const std::vector<float>& sigmas,
26992706
int64_t seed,
27002707
int batch_count,
@@ -2990,6 +2997,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
29902997
eta,
29912998
shifted_timestep,
29922999
sample_method,
3000+
ge_gamma,
29933001
sigmas,
29943002
start_merge_step,
29953003
id_cond,
@@ -3305,6 +3313,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
33053313
width,
33063314
height,
33073315
sample_method,
3316+
sd_img_gen_params->sample_params.ge_gamma,
33083317
sigmas,
33093318
seed,
33103319
sd_img_gen_params->batch_count,
@@ -3632,6 +3641,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
36323641
sd_vid_gen_params->high_noise_sample_params.eta,
36333642
sd_vid_gen_params->high_noise_sample_params.shifted_timestep,
36343643
sd_vid_gen_params->high_noise_sample_params.sample_method,
3644+
sd_vid_gen_params->high_noise_sample_params.ge_gamma,
36353645
high_noise_sigmas,
36363646
-1,
36373647
{},
@@ -3669,6 +3679,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
36693679
sd_vid_gen_params->sample_params.eta,
36703680
sd_vid_gen_params->sample_params.shifted_timestep,
36713681
sd_vid_gen_params->sample_params.sample_method,
3682+
sd_vid_gen_params->sample_params.ge_gamma,
36723683
sigmas,
36733684
-1,
36743685
{},

stable-diffusion.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ enum sample_method_t {
4949
DDIM_TRAILING,
5050
TCD,
5151
EULER_A,
52+
GRADIENT_ESTIMATION,
5253
SAMPLE_METHOD_COUNT
5354
};
5455

@@ -219,6 +220,7 @@ typedef struct {
219220
enum sample_method_t sample_method;
220221
int sample_steps;
221222
float eta;
223+
float ge_gamma;
222224
int shifted_timestep;
223225
} sd_sample_params_t;
224226

0 commit comments

Comments
 (0)