@@ -847,6 +847,15 @@ class StableDiffusionGGML {
847847 }
848848 struct ggml_tensor * denoised = ggml_dup_tensor (work_ctx, x);
849849
850+ // TODO do not hardcode
851+ float apg_eta = .08f ;
852+ float apg_momentum = -.5f ;
853+ float apg_norm_treshold = 15 .0f ;
854+
855+ std::vector<float > apg_momentum_buffer;
856+ if (apg_momentum != 0 )
857+ apg_momentum_buffer.resize ((size_t )ggml_nelements (denoised));
858+
850859 auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
851860 if (step == 1 ) {
852861 pretty_progress (0 , (int )steps, 0 );
@@ -951,6 +960,50 @@ class StableDiffusionGGML {
951960 float * vec_input = (float *)input->data ;
952961 float * positive_data = (float *)out_cond->data ;
953962 int ne_elements = (int )ggml_nelements (denoised);
963+
964+ float * deltas = vec_denoised;
965+
966+ // https://arxiv.org/pdf/2410.02416
967+ float apg_scale_factor = 1 .;
968+ float diff_norm = 0 ;
969+ float cond_norm_sq = 0 ;
970+ float dot = 0 ;
971+ for (int i = 0 ; i < ne_elements; i++) {
972+ float delta = positive_data[i] - negative_data[i];
973+ if (apg_momentum != 0 ) {
974+ delta += apg_momentum * apg_momentum_buffer[i];
975+ apg_momentum_buffer[i] = delta;
976+ }
977+ if (apg_norm_treshold > 0 ) {
978+ diff_norm += delta * delta;
979+ }
980+ if (apg_eta != 1 .0f ) {
981+ cond_norm_sq += positive_data[i] * positive_data[i];
982+ dot += positive_data[i] * delta;
983+ }
984+ deltas[i] = delta;
985+ }
986+ if (apg_norm_treshold > 0 ) {
987+ diff_norm = std::sqrtf (diff_norm);
988+ apg_scale_factor = std::min (1 .0f , apg_norm_treshold / diff_norm);
989+ }
990+ if (apg_eta != 1 .0f ) {
991+ dot *= apg_scale_factor;
992+ // pre-normalize (avoids one square root and ne_elements extra divs)
993+ dot /= cond_norm_sq;
994+ }
995+
996+ for (int i = 0 ; i < ne_elements; i++) {
997+ deltas[i] *= apg_scale_factor;
998+ if (apg_eta != 1 .0f ) {
999+ float apg_parallel = dot * positive_data[i];
1000+ float apg_orthogonal = deltas[i] - apg_parallel;
1001+
1002+ // tweak deltas
1003+ deltas[i] = apg_orthogonal + apg_eta * apg_parallel;
1004+ }
1005+ }
1006+
9541007 for (int i = 0 ; i < ne_elements; i++) {
9551008 float latent_result = positive_data[i];
9561009 if (has_unconditioned) {
@@ -960,7 +1013,9 @@ class StableDiffusionGGML {
9601013 int64_t i3 = i / out_cond->ne [0 ] * out_cond->ne [1 ] * out_cond->ne [2 ];
9611014 float scale = min_cfg + (cfg_scale - min_cfg) * (i3 * 1 .0f / ne3);
9621015 } else {
963- latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]);
1016+ float delta = deltas[i];
1017+
1018+ latent_result = positive_data[i] + (cfg_scale - 1 ) * delta;
9641019 }
9651020 }
9661021 if (is_skiplayer_step) {
@@ -1004,7 +1059,8 @@ class StableDiffusionGGML {
10041059 }
10051060
10061061 // ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding
1007- ggml_tensor* get_first_stage_encoding (ggml_context* work_ctx, ggml_tensor* moments) {
1062+ ggml_tensor*
1063+ get_first_stage_encoding (ggml_context* work_ctx, ggml_tensor* moments) {
10081064 // ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample
10091065 ggml_tensor* latent = ggml_new_tensor_4d (work_ctx, moments->type , moments->ne [0 ], moments->ne [1 ], moments->ne [2 ] / 2 , moments->ne [3 ]);
10101066 struct ggml_tensor * noise = ggml_dup_tensor (work_ctx, latent);
0 commit comments