Skip to content

Commit 76ce0a6

Browse files
committed
Add gradient estimation sampler
1 parent aff9d2d commit 76ce0a6

File tree

4 files changed

+63
-3
lines changed

4 files changed

+63
-3
lines changed

denoiser.hpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,8 @@ static void sample_k_diffusion(sample_method_t method,
672672
ggml_tensor* x,
673673
std::vector<float> sigmas,
674674
std::shared_ptr<RNG> rng,
675-
float eta) {
675+
float eta,
676+
float ge_gamma) {
676677
size_t steps = sigmas.size() - 1;
677678
// sample_euler_ancestral
678679
switch (method) {
@@ -1558,7 +1559,52 @@ static void sample_k_diffusion(sample_method_t method,
15581559
}
15591560
}
15601561
} break;
1562+
case GRADIENT_ESTIMATION: {  // Each step: update x from the current slope d, then add a correction built from the previous slope, scaled by (ge_gamma - 1).
1563+
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);  // slope computed in the current step
1564+
struct ggml_tensor* old_d = ggml_dup_tensor(work_ctx, x);  // slope saved from the previous step
1565+
bool has_old_d = false;  // stays false until the first step has stored a slope in old_d
15611566

1567+
for (int i = 0; i < steps; i++) {
1568+
float sigma = sigmas[i];
1569+
1570+
ggml_tensor* denoised = model(x, sigma, i + 1);  // the step index handed to the model is 1-based
1571+
1572+
// d = (x - denoised) / sigma, computed element by element through raw float pointers
1573+
float* vec_d = (float*)d->data;
1574+
float* vec_x = (float*)x->data;
1575+
float* vec_denoised = (float*)denoised->data;  // NOTE(review): assumes all three tensors hold float data with matching element counts - confirm
1576+
1577+
for (int j = 0; j < ggml_nelements(d); j++) {
1578+
vec_d[j] = (vec_x[j] - vec_denoised[j]) / sigma;
1579+
}
1580+
1581+
float dt = sigmas[i + 1] - sigma;  // step size in sigma; reads sigmas[i + 1], in range because steps == sigmas.size() - 1
1582+
1583+
if (sigmas[i + 1] == 0) {
1584+
// Denoising step: the next sigma is exactly 0, so x takes the model output directly
1585+
for (int j = 0; j < ggml_nelements(x); j++) {
1586+
vec_x[j] = vec_denoised[j];
1587+
}
1588+
} else {
1589+
// Euler method: x += d * dt
1590+
for (int j = 0; j < ggml_nelements(x); j++) {
1591+
vec_x[j] = vec_x[j] + vec_d[j] * dt;
1592+
}
1593+
}
1594+
1595+
if (has_old_d) {
1596+
// Gradient estimation: adds (ge_gamma - 1) * (d - old_d) * dt on top of the update above. NOTE(review): this sits outside the if/else, so it also runs after the denoising branch - confirm that is intended.
1597+
float* vec_old_d = (float*)old_d->data;
1598+
for (int j = 0; j < ggml_nelements(x); j++) {
1599+
float d_bar = (ge_gamma - 1.f) * (vec_d[j] - vec_old_d[j]);
1600+
vec_x[j] = vec_x[j] + d_bar * dt;
1601+
}
1602+
}
1603+
// old_d = d: keep this step's slope so the next iteration can form the correction
1604+
copy_ggml_tensor(old_d, d);
1605+
has_old_d = true;
1606+
}
1607+
} break;
15621608
default:
15631609
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
15641610
abort();

examples/cli/main.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
725725
"eta in DDIM, only for DDIM and TCD (default: 0)",
726726
&params.sample_params.eta},
727727
{"",
728-
"--high-noise-cfg-scale",
728+
"--ge-gamma", "", &params.sample_params.ge_gamma},
729+
{"", "--high-noise-cfg-scale",
729730
"(high noise) unconditional guidance scale: (default: 7.0)",
730731
&params.high_noise_sample_params.guidance.txt_cfg},
731732
{"",

stable-diffusion.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ const char* sampling_methods_str[] = {
6161
"LCM",
6262
"DDIM \"trailing\"",
6363
"TCD",
64+
"Gradient Estimation",
6465
};
6566

6667
/*================================================== Helper Functions ================================================*/
@@ -1479,6 +1480,7 @@ class StableDiffusionGGML {
14791480
float eta,
14801481
int shifted_timestep,
14811482
sample_method_t method,
1483+
float ge_gamma,
14821484
const std::vector<float>& sigmas,
14831485
int start_merge_step,
14841486
SDCondition id_cond,
@@ -1837,7 +1839,7 @@ class StableDiffusionGGML {
18371839
return denoised;
18381840
};
18391841

1840-
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta);
1842+
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta, ge_gamma);
18411843

18421844
if (easycache_enabled) {
18431845
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
@@ -2343,6 +2345,7 @@ const char* sample_method_to_str[] = {
23432345
"lcm",
23442346
"ddim_trailing",
23452347
"tcd",
2348+
"gradient_estimation",
23462349
};
23472350

23482351
const char* sd_sample_method_name(enum sample_method_t sample_method) {
@@ -2574,6 +2577,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
25742577
sample_params->scheduler = SCHEDULER_COUNT;
25752578
sample_params->sample_method = SAMPLE_METHOD_COUNT;
25762579
sample_params->sample_steps = 20;
2580+
sample_params->ge_gamma = 2.0f;
25772581
}
25782582

25792583
char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
@@ -2594,6 +2598,7 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
25942598
"sample_method: %s, "
25952599
"sample_steps: %d, "
25962600
"eta: %.2f, "
2601+
"ge_gamma: %.2f, "
25972602
"shifted_timestep: %d)",
25982603
sample_params->guidance.txt_cfg,
25992604
std::isfinite(sample_params->guidance.img_cfg)
@@ -2608,6 +2613,7 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
26082613
sd_sample_method_name(sample_params->sample_method),
26092614
sample_params->sample_steps,
26102615
sample_params->eta,
2616+
sample_params->ge_gamma,
26112617
sample_params->shifted_timestep);
26122618

26132619
return buf;
@@ -2759,6 +2765,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
27592765
int width,
27602766
int height,
27612767
enum sample_method_t sample_method,
2768+
float ge_gamma,
27622769
const std::vector<float>& sigmas,
27632770
int64_t seed,
27642771
int batch_count,
@@ -3053,6 +3060,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
30533060
eta,
30543061
shifted_timestep,
30553062
sample_method,
3063+
ge_gamma,
30563064
sigmas,
30573065
start_merge_step,
30583066
id_cond,
@@ -3370,6 +3378,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
33703378
width,
33713379
height,
33723380
sample_method,
3381+
sd_img_gen_params->sample_params.ge_gamma,
33733382
sigmas,
33743383
seed,
33753384
sd_img_gen_params->batch_count,
@@ -3706,6 +3715,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
37063715
sd_vid_gen_params->high_noise_sample_params.eta,
37073716
sd_vid_gen_params->high_noise_sample_params.shifted_timestep,
37083717
high_noise_sample_method,
3718+
sd_vid_gen_params->high_noise_sample_params.ge_gamma,
37093719
high_noise_sigmas,
37103720
-1,
37113721
{},
@@ -3743,6 +3753,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
37433753
sd_vid_gen_params->sample_params.eta,
37443754
sd_vid_gen_params->sample_params.shifted_timestep,
37453755
sample_method,
3756+
sd_vid_gen_params->sample_params.ge_gamma,
37463757
sigmas,
37473758
-1,
37483759
{},

stable-diffusion.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ enum sample_method_t {
4848
LCM_SAMPLE_METHOD,
4949
DDIM_TRAILING_SAMPLE_METHOD,
5050
TCD_SAMPLE_METHOD,
51+
GRADIENT_ESTIMATION,
5152
SAMPLE_METHOD_COUNT
5253
};
5354

@@ -219,6 +220,7 @@ typedef struct {
219220
enum sample_method_t sample_method;
220221
int sample_steps;
221222
float eta;
223+
float ge_gamma;
222224
int shifted_timestep;
223225
} sd_sample_params_t;
224226

0 commit comments

Comments
 (0)