Skip to content

Commit 6e42ffa

Browse files
committed
add skip layer guidance support (mmdit only)
1 parent 4d70e16 commit 6e42ffa

File tree

5 files changed

+177
-23
lines changed

5 files changed

+177
-23
lines changed

diffusion_model.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ struct DiffusionModel {
1717
std::vector<struct ggml_tensor*> controls = {},
1818
float control_strength = 0.f,
1919
struct ggml_tensor** output = NULL,
20-
struct ggml_context* output_ctx = NULL) = 0;
20+
struct ggml_context* output_ctx = NULL,
21+
std::vector<int> skip_layers = std::vector<int>()) = 0;
2122
virtual void alloc_params_buffer() = 0;
2223
virtual void free_params_buffer() = 0;
2324
virtual void free_compute_buffer() = 0;
@@ -70,7 +71,8 @@ struct UNetModel : public DiffusionModel {
7071
std::vector<struct ggml_tensor*> controls = {},
7172
float control_strength = 0.f,
7273
struct ggml_tensor** output = NULL,
73-
struct ggml_context* output_ctx = NULL) {
74+
struct ggml_context* output_ctx = NULL,
75+
std::vector<int> skip_layers = std::vector<int>()) {
7476
return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
7577
}
7678
};
@@ -119,8 +121,9 @@ struct MMDiTModel : public DiffusionModel {
119121
std::vector<struct ggml_tensor*> controls = {},
120122
float control_strength = 0.f,
121123
struct ggml_tensor** output = NULL,
122-
struct ggml_context* output_ctx = NULL) {
123-
return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx);
124+
struct ggml_context* output_ctx = NULL,
125+
std::vector<int> skip_layers = std::vector<int>()) {
126+
return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers);
124127
}
125128
};
126129

@@ -168,7 +171,8 @@ struct FluxModel : public DiffusionModel {
168171
std::vector<struct ggml_tensor*> controls = {},
169172
float control_strength = 0.f,
170173
struct ggml_tensor** output = NULL,
171-
struct ggml_context* output_ctx = NULL) {
174+
struct ggml_context* output_ctx = NULL,
175+
std::vector<int> skip_layers = std::vector<int>()) {
172176
return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx);
173177
}
174178
};

examples/cli/main.cpp

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ struct SDParams {
119119
bool canny_preprocess = false;
120120
bool color = false;
121121
int upscale_repeats = 1;
122+
123+
std::vector<int> skip_layers = {7, 8, 9};
124+
float slg_scale = 2.5;
125+
float skip_layer_start = 0.01;
126+
float skip_layer_end = 0.2;
122127
};
123128

124129
void print_params(SDParams params) {
@@ -197,6 +202,11 @@ void print_usage(int argc, const char* argv[]) {
197202
printf(" -p, --prompt [PROMPT] the prompt to render\n");
198203
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
199204
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
205+
printf(" --slg enable skip layer guidance (CFG variant)\n");
206+
printf(" --skip_layers LAYERS Layers to skip for skip layer CFG (requires --slg): (default: [7,8,9])\n");
207+
printf(" --slg-scale SCALE skip layer guidance scale (requires --slg): (default: 2.5)\n");
208+
printf(" --skip_layer_start START skip layer enabling point (* steps) (requires --slg): (default: 0.01)\n");
209+
printf(" --skip_layer_end END skip layer enabling point (* steps) (requires --slg): (default: 0.2)\n");
200210
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
201211
printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n");
202212
printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n");
@@ -223,6 +233,7 @@ void print_usage(int argc, const char* argv[]) {
223233

224234
void parse_args(int argc, const char** argv, SDParams& params) {
225235
bool invalid_arg = false;
236+
bool cfg_skip = false;
226237
std::string arg;
227238
for (int i = 1; i < argc; i++) {
228239
arg = argv[i];
@@ -534,6 +545,63 @@ void parse_args(int argc, const char** argv, SDParams& params) {
534545
params.verbose = true;
535546
} else if (arg == "--color") {
536547
params.color = true;
548+
} else if (arg == "--slg") {
549+
cfg_skip = true;
550+
} else if (arg == "--skip-layers") {
551+
if (++i >= argc) {
552+
invalid_arg = true;
553+
break;
554+
}
555+
if (argv[i][0] != '[') {
556+
invalid_arg = true;
557+
break;
558+
}
559+
std::string layers_str = argv[i];
560+
while (layers_str.back() != ']') {
561+
if (++i >= argc) {
562+
invalid_arg = true;
563+
break;
564+
}
565+
layers_str += " " + std::string(argv[i]);
566+
}
567+
layers_str = layers_str.substr(1, layers_str.size() - 2);
568+
569+
std::regex regex("[, ]+");
570+
std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1);
571+
std::sregex_token_iterator end;
572+
std::vector<std::string> tokens(iter, end);
573+
std::vector<int> layers;
574+
for (const auto& token : tokens) {
575+
try {
576+
layers.push_back(std::stoi(token));
577+
} catch (const std::invalid_argument& e) {
578+
invalid_arg = true;
579+
break;
580+
}
581+
}
582+
params.skip_layers = layers;
583+
584+
if (invalid_arg) {
585+
break;
586+
}
587+
} else if (arg == "--slg-scale") {
588+
if (++i >= argc) {
589+
invalid_arg = true;
590+
break;
591+
}
592+
params.slg_scale = std::stof(argv[i]);
593+
} else if (arg == "--skip-layer-start") {
594+
if (++i >= argc) {
595+
invalid_arg = true;
596+
break;
597+
}
598+
params.skip_layer_start = std::stof(argv[i]);
599+
} else if (arg == "--skip-layer-end") {
600+
if (++i >= argc) {
601+
invalid_arg = true;
602+
break;
603+
}
604+
params.skip_layer_end = std::stof(argv[i]);
537605
} else {
538606
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
539607
print_usage(argc, argv);
@@ -549,6 +617,11 @@ void parse_args(int argc, const char** argv, SDParams& params) {
549617
params.n_threads = get_num_physical_cores();
550618
}
551619

620+
if (!cfg_skip) {
621+
// set skip_layers to empty
622+
params.skip_layers.clear();
623+
}
624+
552625
if (params.mode != CONVERT && params.mode != IMG2VID && params.prompt.length() == 0) {
553626
fprintf(stderr, "error: the following arguments are required: prompt\n");
554627
print_usage(argc, argv);
@@ -840,7 +913,11 @@ int main(int argc, const char* argv[]) {
840913
params.control_strength,
841914
params.style_ratio,
842915
params.normalize_input,
843-
params.input_id_images_path.c_str());
916+
params.input_id_images_path.c_str(),
917+
params.skip_layers,
918+
params.slg_scale,
919+
params.skip_layer_start,
920+
params.skip_layer_end);
844921
} else {
845922
sd_image_t input_image = {(uint32_t)params.width,
846923
(uint32_t)params.height,

mmdit.hpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -801,14 +801,20 @@ struct MMDiT : public GGMLBlock {
801801
struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx,
802802
struct ggml_tensor* x,
803803
struct ggml_tensor* c_mod,
804-
struct ggml_tensor* context) {
804+
struct ggml_tensor* context,
805+
std::vector<int> skip_layers = std::vector<int>()) {
805806
// x: [N, H*W, hidden_size]
806807
// context: [N, n_context, d_context]
807808
// c: [N, hidden_size]
808809
// return: [N, H*W, patch_size * patch_size * out_channels]
809810
auto final_layer = std::dynamic_pointer_cast<FinalLayer>(blocks["final_layer"]);
810811

811812
for (int i = 0; i < depth; i++) {
813+
// skip iteration if i is in skip_layers
814+
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
815+
continue;
816+
}
817+
812818
auto block = std::dynamic_pointer_cast<JointBlock>(blocks["joint_blocks." + std::to_string(i)]);
813819

814820
auto context_x = block->forward(ctx, context, x, c_mod);
@@ -824,8 +830,9 @@ struct MMDiT : public GGMLBlock {
824830
struct ggml_tensor* forward(struct ggml_context* ctx,
825831
struct ggml_tensor* x,
826832
struct ggml_tensor* t,
827-
struct ggml_tensor* y = NULL,
828-
struct ggml_tensor* context = NULL) {
833+
struct ggml_tensor* y = NULL,
834+
struct ggml_tensor* context = NULL,
835+
std::vector<int> skip_layers = std::vector<int>()) {
829836
// Forward pass of DiT.
830837
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
831838
// t: (N,) tensor of diffusion timesteps
@@ -856,7 +863,7 @@ struct MMDiT : public GGMLBlock {
856863
context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536]
857864
}
858865

859-
x = forward_core_with_concat(ctx, x, c, context); // (N, H*W, patch_size ** 2 * out_channels)
866+
x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)
860867

861868
x = unpatchify(ctx, x, h, w); // [N, C, H, W]
862869

@@ -885,7 +892,8 @@ struct MMDiTRunner : public GGMLRunner {
885892
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
886893
struct ggml_tensor* timesteps,
887894
struct ggml_tensor* context,
888-
struct ggml_tensor* y) {
895+
struct ggml_tensor* y,
896+
std::vector<int> skip_layers = std::vector<int>()) {
889897
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, MMDIT_GRAPH_SIZE, false);
890898

891899
x = to_backend(x);
@@ -897,7 +905,8 @@ struct MMDiTRunner : public GGMLRunner {
897905
x,
898906
timesteps,
899907
y,
900-
context);
908+
context,
909+
skip_layers);
901910

902911
ggml_build_forward_expand(gf, out);
903912

@@ -910,13 +919,14 @@ struct MMDiTRunner : public GGMLRunner {
910919
struct ggml_tensor* context,
911920
struct ggml_tensor* y,
912921
struct ggml_tensor** output = NULL,
913-
struct ggml_context* output_ctx = NULL) {
922+
struct ggml_context* output_ctx = NULL,
923+
std::vector<int> skip_layers = std::vector<int>()) {
914924
// x: [N, in_channels, h, w]
915925
// timesteps: [N, ]
916926
// context: [N, max_position, hidden_size]([N, 154, 4096]) or [1, max_position, hidden_size]
917927
// y: [N, adm_in_channels] or [1, adm_in_channels]
918928
auto get_graph = [&]() -> struct ggml_cgraph* {
919-
return build_graph(x, timesteps, context, y);
929+
return build_graph(x, timesteps, context, y, skip_layers);
920930
};
921931

922932
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);

stable-diffusion.cpp

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,11 @@ class StableDiffusionGGML {
772772
sample_method_t method,
773773
const std::vector<float>& sigmas,
774774
int start_merge_step,
775-
SDCondition id_cond) {
775+
SDCondition id_cond,
776+
std::vector<int> skip_layers = {},
777+
float slg_scale = 2.5,
778+
float skip_layer_start = 0.01,
779+
float skip_layer_end = 0.2) {
776780
size_t steps = sigmas.size() - 1;
777781
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
778782
// print_ggml_tensor(noise);
@@ -870,6 +874,30 @@ class StableDiffusionGGML {
870874
&out_uncond);
871875
negative_data = (float*)out_uncond->data;
872876
}
877+
878+
bool has_skiplayer = skip_layers.size() > 0;
879+
int stepCount = sigmas.size();
880+
has_skiplayer = has_skiplayer && step > (int)(skip_layer_start * stepCount) && step < (int)(skip_layer_end * stepCount);
881+
float* skip_layer_data = NULL;
882+
if (has_skiplayer) {
883+
LOG_DEBUG("Skipping layers at step %d\n", step);
884+
ggml_tensor* out_skip = ggml_dup_tensor(work_ctx, x);
885+
// skip-layer pass (same inputs as the conditioned pass, plus skip_layers)
886+
diffusion_model->compute(n_threads,
887+
noised_input,
888+
timesteps,
889+
cond.c_crossattn,
890+
cond.c_concat,
891+
cond.c_vector,
892+
guidance_tensor,
893+
-1,
894+
controls,
895+
control_strength,
896+
&out_skip,
897+
NULL,
898+
skip_layers);
899+
skip_layer_data = (float*)out_skip->data;
900+
}
873901
float* vec_denoised = (float*)denoised->data;
874902
float* vec_input = (float*)input->data;
875903
float* positive_data = (float*)out_cond->data;
@@ -886,6 +914,9 @@ class StableDiffusionGGML {
886914
latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]);
887915
}
888916
}
917+
if (has_skiplayer) {
918+
latent_result = latent_result + (positive_data[i] - skip_layer_data[i]) * slg_scale;
919+
}
889920
// v = latent_result, eps = latent_result
890921
// denoised = (v * c_out + input * c_skip) or (input + eps * c_out)
891922
vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip;
@@ -1112,7 +1143,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
11121143
float control_strength,
11131144
float style_ratio,
11141145
bool normalize_input,
1115-
std::string input_id_images_path) {
1146+
std::string input_id_images_path,
1147+
std::vector<int> skip_layers = {},
1148+
float slg_scale = 2.5,
1149+
float skip_layer_start = 0.01,
1150+
float skip_layer_end = 0.2) {
11161151
if (seed < 0) {
11171152
// Generally, when using the provided command line, the seed is always >0.
11181153
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1321,7 +1356,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13211356
sample_method,
13221357
sigmas,
13231358
start_merge_step,
1324-
id_cond);
1359+
id_cond,
1360+
skip_layers,
1361+
slg_scale,
1362+
skip_layer_start,
1363+
skip_layer_end);
13251364
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
13261365
// print_ggml_tensor(x_0);
13271366
int64_t sampling_end = ggml_time_ms();
@@ -1387,7 +1426,11 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
13871426
float control_strength,
13881427
float style_ratio,
13891428
bool normalize_input,
1390-
const char* input_id_images_path_c_str) {
1429+
const char* input_id_images_path_c_str,
1430+
std::vector<int> skip_layers,
1431+
float slg_scale,
1432+
float skip_layer_start,
1433+
float skip_layer_end) {
13911434
LOG_DEBUG("txt2img %dx%d", width, height);
13921435
if (sd_ctx == NULL) {
13931436
return NULL;
@@ -1455,7 +1498,11 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
14551498
control_strength,
14561499
style_ratio,
14571500
normalize_input,
1458-
input_id_images_path_c_str);
1501+
input_id_images_path_c_str,
1502+
skip_layers,
1503+
slg_scale,
1504+
skip_layer_start,
1505+
skip_layer_end);
14591506

14601507
size_t t1 = ggml_time_ms();
14611508

@@ -1482,7 +1529,11 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
14821529
float control_strength,
14831530
float style_ratio,
14841531
bool normalize_input,
1485-
const char* input_id_images_path_c_str) {
1532+
const char* input_id_images_path_c_str,
1533+
std::vector<int> skip_layers,
1534+
float slg_scale,
1535+
float skip_layer_start,
1536+
float skip_layer_end) {
14861537
LOG_DEBUG("img2img %dx%d", width, height);
14871538
if (sd_ctx == NULL) {
14881539
return NULL;
@@ -1556,7 +1607,11 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
15561607
control_strength,
15571608
style_ratio,
15581609
normalize_input,
1559-
input_id_images_path_c_str);
1610+
input_id_images_path_c_str,
1611+
skip_layers,
1612+
slg_scale,
1613+
skip_layer_start,
1614+
skip_layer_end);
15601615

15611616
size_t t2 = ggml_time_ms();
15621617

0 commit comments

Comments
 (0)