@@ -802,7 +802,7 @@ class StableDiffusionGGML {
802802 SDCondition id_cond,
803803 sd_slg_params_t slg_params = {NULL , 0 , 0 , 0 , 0 },
804804 sd_apg_params_t apg_params = {1 , 0 , 0 },
805- ggml_tensor* noise_mask = nullptr ) {
805+ ggml_tensor* noise_mask = nullptr ) {
806806 std::vector<int > skip_layers (slg_params.skip_layers , slg_params.skip_layers + slg_params.skip_layers_count );
807807
808808 LOG_DEBUG (" Sample" );
@@ -963,39 +963,41 @@ class StableDiffusionGGML {
963963 float diff_norm = 0 ;
964964 float cond_norm_sq = 0 ;
965965 float dot = 0 ;
966- for (int i = 0 ; i < ne_elements; i++) {
967- float delta = positive_data[i] - negative_data[i];
968- if (apg_params.momentum != 0 ) {
969- delta += apg_params.momentum * apg_momentum_buffer[i];
970- apg_momentum_buffer[i] = delta;
966+ if (has_unconditioned) {
967+ for (int i = 0 ; i < ne_elements; i++) {
968+ float delta = positive_data[i] - negative_data[i];
969+ if (apg_params.momentum != 0 ) {
970+ delta += apg_params.momentum * apg_momentum_buffer[i];
971+ apg_momentum_buffer[i] = delta;
972+ }
973+ if (apg_params.norm_treshold > 0 ) {
974+ diff_norm += delta * delta;
975+ }
976+ if (apg_params.eta != 1 .0f ) {
977+ cond_norm_sq += positive_data[i] * positive_data[i];
978+ dot += positive_data[i] * delta;
979+ }
980+ deltas[i] = delta;
971981 }
972982 if (apg_params.norm_treshold > 0 ) {
973- diff_norm += delta * delta;
983+ diff_norm = std::sqrtf (diff_norm);
984+ apg_scale_factor = std::min (1 .0f , apg_params.norm_treshold / diff_norm);
974985 }
975986 if (apg_params.eta != 1 .0f ) {
976- cond_norm_sq += positive_data[i] * positive_data[i];
977- dot += positive_data[i] * delta;
987+ dot *= apg_scale_factor;
988+ // pre-normalize (avoids one square root and ne_elements extra divs)
989+ dot /= cond_norm_sq;
978990 }
979- deltas[i] = delta;
980- }
981- if (apg_params.norm_treshold > 0 ) {
982- diff_norm = std::sqrtf (diff_norm);
983- apg_scale_factor = std::min (1 .0f , apg_params.norm_treshold / diff_norm);
984- }
985- if (apg_params.eta != 1 .0f ) {
986- dot *= apg_scale_factor;
987- // pre-normalize (avoids one square root and ne_elements extra divs)
988- dot /= cond_norm_sq;
989- }
990991
991- for (int i = 0 ; i < ne_elements; i++) {
992- deltas[i] *= apg_scale_factor;
993- if (apg_params.eta != 1 .0f ) {
994- float apg_parallel = dot * positive_data[i];
995- float apg_orthogonal = deltas[i] - apg_parallel;
992+ for (int i = 0 ; i < ne_elements; i++) {
993+ deltas[i] *= apg_scale_factor;
994+ if (apg_params.eta != 1 .0f ) {
995+ float apg_parallel = dot * positive_data[i];
996+ float apg_orthogonal = deltas[i] - apg_parallel;
996997
997- // tweak deltas
998- deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
998+ // tweak deltas
999+ deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
1000+ }
9991001 }
10001002 }
10011003
0 commit comments