@@ -126,9 +126,13 @@ struct SDParams {
126126 int upscale_repeats = 1 ;
127127
128128 std::vector<int > skip_layers = {7 , 8 , 9 };
129- float slg_scale = 0 .f ;
129+ float slg_scale = 0 .0f ;
130130 float skip_layer_start = 0 .01f ;
131131 float skip_layer_end = 0 .2f ;
132+
133+ float apg_eta = 1 .0f ;
134+ float apg_momentum = 0 .0f ;
135+ float apg_norm_treshold = 0 .0f ;
132136};
133137
134138void print_params (SDParams params) {
@@ -213,6 +217,9 @@ void print_usage(int argc, const char* argv[]) {
213217 printf (" -n, --negative-prompt PROMPT the negative prompt (default: \"\" )\n " );
214218 printf (" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n " );
215219 printf (" --guidance SCALE guidance scale for img2img (default: 3.5)\n " );
220+ printf (" --apg-eta VALUE parallel projected guidance scale for APG (default: 1.0, recommended: between 0 and 1)\n " );
221+ printf (" --apg-momentum VALUE CFG update direction momentum for APG (default: 0, recommended: around -0.5)\n " );
222+ printf (" --apg-nt, --apg-rescale VALUE CFG update direction norm threshold for APG (default: 0 = disabled, recommended: 4-15)\n " );
216223 printf (" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n " );
217224 printf (" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n " );
218225 printf (" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n " );
@@ -629,6 +636,24 @@ void parse_args(int argc, const char** argv, SDParams& params) {
629636 break ;
630637 }
631638 params.skip_layer_end = std::stof (argv[i]);
639+ } else if (arg == " --apg-eta" ) {
640+ if (++i >= argc) {
641+ invalid_arg = true ;
642+ break ;
643+ }
644+ params.apg_eta = std::stof (argv[i]);
645+ } else if (arg == " --apg-momentum" ) {
646+ if (++i >= argc) {
647+ invalid_arg = true ;
648+ break ;
649+ }
650+ params.apg_momentum = std::stof (argv[i]);
651+ } else if (arg == " --apg-nt" || arg == " --apg-rescale" ) {
652+ if (++i >= argc) {
653+ invalid_arg = true ;
654+ break ;
655+ }
656+ params.apg_norm_treshold = std::stof (argv[i]);
632657 } else {
633658 fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
634659 print_usage (argc, argv);
@@ -968,7 +993,9 @@ int main(int argc, const char* argv[]) {
968993 params.slg_scale ,
969994 params.skip_layer_start ,
970995 params.skip_layer_end },
971- sd_apg_params_t {1 , 0 , 0 });
996+ sd_apg_params_t {params.apg_eta ,
997+ params.apg_momentum ,
998+ params.apg_norm_treshold });
972999 } else {
9731000 sd_image_t input_image = {(uint32_t )params.width ,
9741001 (uint32_t )params.height ,
@@ -1038,7 +1065,9 @@ int main(int argc, const char* argv[]) {
10381065 params.slg_scale ,
10391066 params.skip_layer_start ,
10401067 params.skip_layer_end },
1041- sd_apg_params_t {1 , 0 , 0 });
1068+ sd_apg_params_t {params.apg_eta ,
1069+ params.apg_momentum ,
1070+ params.apg_norm_treshold });
10421071 }
10431072 }
10441073
0 commit comments