diff --git a/docs/distilled_sd.md b/docs/distilled_sd.md new file mode 100644 index 000000000..7e38cb35f --- /dev/null +++ b/docs/distilled_sd.md @@ -0,0 +1,18 @@ +# Running distilled SDXL models: SSD1B + +### Preface + +This kind of models has a reduced U-Net part. Unlike other SDXL models the U-Net has only one middle block and lesser attention layers in up and down blocks, resulting in relatively smaller files. Running these models saves more than 33% of the time. For more details, refer to Segmind's paper on https://arxiv.org/abs/2401.02677v1 . + +### How to Use + +Unfortunately not all of this models follow the standard model parameter naming mapping. +Anyway there are some useful SSD1B models available online, such as: + + * https://huggingface.co/segmind/SSD-1B/resolve/main/SSD-1B-A1111.safetensors + * https://huggingface.co/hassenhamdi/SSD-1B-fp8_e4m3fn/resolve/main/SSD-1B_fp8_e4m3fn.safetensors + +Also there are useful LORAs available: + + * https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors + * https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors diff --git a/model.cpp b/model.cpp index b45493cc4..1751cb66b 100644 --- a/model.cpp +++ b/model.cpp @@ -1859,7 +1859,12 @@ SDVersion ModelLoader::get_sd_version() { if (is_ip2p) { return VERSION_SDXL_PIX2PIX; } - return VERSION_SDXL; + for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.name.find("model.diffusion_model.middle_block.1") != std::string::npos) { + return VERSION_SDXL; // found a missing tensor in SSD1B, so it is SDXL + } + } + return VERSION_SDXL_SSD1B; } if (is_flux) { diff --git a/model.h b/model.h index 069bb0c21..c6637dd0b 100644 --- a/model.h +++ b/model.h @@ -27,6 +27,7 @@ enum SDVersion { VERSION_SDXL, VERSION_SDXL_INPAINT, VERSION_SDXL_PIX2PIX, + VERSION_SDXL_SSD1B, VERSION_SVD, VERSION_SD3, VERSION_FLUX, @@ -55,7 +56,7 @@ static inline bool sd_version_is_sd2(SDVersion version) { } static inline bool sd_version_is_sdxl(SDVersion version) { - if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX) { + if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) { return true; } return false; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 87b6a3779..0429624a3 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -33,6 +33,7 @@ const char* model_version_to_str[] = { "SDXL", "SDXL Inpaint", "SDXL Instruct-Pix2Pix", + "SDXL (SSD1B)", "SVD", "SD3.x", "Flux", diff --git a/unet.hpp b/unet.hpp index 19bedb32b..6dde9bcc8 100644 --- a/unet.hpp +++ b/unet.hpp @@ -270,10 +270,14 @@ class UnetModelBlock : public GGMLBlock { n_head = ch / d_head; } std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1"; + int td=transformer_depth[i]; + if (version == VERSION_SDXL_SSD1B) { + if (i==2) td=4; + } blocks[name] = std::shared_ptr(get_attention_layer(ch, n_head, d_head, - transformer_depth[i], + td, context_dim)); } input_block_chans.push_back(ch); @@ -296,13 +300,14 @@ class UnetModelBlock : public GGMLBlock { n_head = ch / d_head; } blocks["middle_block.0"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch)); - blocks["middle_block.1"] = std::shared_ptr(get_attention_layer(ch, - n_head, - d_head, - transformer_depth[transformer_depth.size() - 1], - context_dim)); - blocks["middle_block.2"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch)); - + if (version != VERSION_SDXL_SSD1B) { + blocks["middle_block.1"] = std::shared_ptr(get_attention_layer(ch, + n_head, + d_head, + transformer_depth[transformer_depth.size() - 1], + context_dim)); + blocks["middle_block.2"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch)); + } // output_blocks int output_block_idx = 0; for (int i = (int)len_mults - 1; i >= 0; i--) { @@ -324,7 +329,12 @@ class UnetModelBlock : public GGMLBlock { n_head = ch / d_head; } std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1"; - blocks[name] = std::shared_ptr(get_attention_layer(ch, n_head, d_head, transformer_depth[i], context_dim)); + int td = transformer_depth[i]; + if (version == VERSION_SDXL_SSD1B) { + if (i==2 && (j==0 || j==1)) td=4; + if (i==1 && (j==1 || j==2)) td=1; + } + blocks[name] = std::shared_ptr(get_attention_layer(ch, n_head, d_head, td, context_dim)); up_sample_idx++; } @@ -478,9 +488,10 @@ class UnetModelBlock : public GGMLBlock { // middle_block h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] - h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8] - h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] - + if (version != VERSION_SDXL_SSD1B) { + h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8] + h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] + } if (controls.size() > 0) { auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength); h = ggml_add(ctx, h, cs); // middle control