Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/distilled_sd.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Running distilled SDXL models: SSD1B

### Preface

These models have a reduced U-Net. Unlike other SDXL models, the U-Net has only one middle-block stage and fewer attention layers in the up and down blocks, resulting in noticeably smaller files. Running these models saves more than 33% of inference time. For more details, refer to Segmind's paper at https://arxiv.org/abs/2401.02677v1 .

### How to Use

Unfortunately, not all of these models follow the standard parameter naming convention.
Nevertheless, there are some working SSD1B models available online, such as:

* https://huggingface.co/segmind/SSD-1B/resolve/main/SSD-1B-A1111.safetensors
* https://huggingface.co/hassenhamdi/SSD-1B-fp8_e4m3fn/resolve/main/SSD-1B_fp8_e4m3fn.safetensors

There are also useful LoRAs available:

* https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors
* https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors
7 changes: 6 additions & 1 deletion model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1859,7 +1859,12 @@ SDVersion ModelLoader::get_sd_version() {
if (is_ip2p) {
return VERSION_SDXL_PIX2PIX;
}
return VERSION_SDXL;
for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("model.diffusion_model.middle_block.1") != std::string::npos) {
return VERSION_SDXL; // found a missing tensor in SSD1B, so it is SDXL
}
}
return VERSION_SDXL_SSD1B;
}

if (is_flux) {
Expand Down
3 changes: 2 additions & 1 deletion model.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ enum SDVersion {
VERSION_SDXL,
VERSION_SDXL_INPAINT,
VERSION_SDXL_PIX2PIX,
VERSION_SDXL_SSD1B,
VERSION_SVD,
VERSION_SD3,
VERSION_FLUX,
Expand Down Expand Up @@ -55,7 +56,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
}

static inline bool sd_version_is_sdxl(SDVersion version) {
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX) {
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) {
return true;
}
return false;
Expand Down
1 change: 1 addition & 0 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const char* model_version_to_str[] = {
"SDXL",
"SDXL Inpaint",
"SDXL Instruct-Pix2Pix",
"SDXL (SSD1B)",
"SVD",
"SD3.x",
"Flux",
Expand Down
35 changes: 23 additions & 12 deletions unet.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,14 @@ class UnetModelBlock : public GGMLBlock {
n_head = ch / d_head;
}
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
int td=transformer_depth[i];
if (version == VERSION_SDXL_SSD1B) {
if (i==2) td=4;
}
blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
n_head,
d_head,
transformer_depth[i],
td,
context_dim));
}
input_block_chans.push_back(ch);
Expand All @@ -296,13 +300,14 @@ class UnetModelBlock : public GGMLBlock {
n_head = ch / d_head;
}
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
n_head,
d_head,
transformer_depth[transformer_depth.size() - 1],
context_dim));
blocks["middle_block.2"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));

if (version != VERSION_SDXL_SSD1B) {
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
n_head,
d_head,
transformer_depth[transformer_depth.size() - 1],
context_dim));
blocks["middle_block.2"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
}
// output_blocks
int output_block_idx = 0;
for (int i = (int)len_mults - 1; i >= 0; i--) {
Expand All @@ -324,7 +329,12 @@ class UnetModelBlock : public GGMLBlock {
n_head = ch / d_head;
}
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch, n_head, d_head, transformer_depth[i], context_dim));
int td = transformer_depth[i];
if (version == VERSION_SDXL_SSD1B) {
if (i==2 && (j==0 || j==1)) td=4;
if (i==1 && (j==1 || j==2)) td=1;
}
blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch, n_head, d_head, td, context_dim));

up_sample_idx++;
}
Expand Down Expand Up @@ -478,9 +488,10 @@ class UnetModelBlock : public GGMLBlock {

// middle_block
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]

if (version != VERSION_SDXL_SSD1B) {
h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
}
if (controls.size() > 0) {
auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength);
h = ggml_add(ctx, h, cs); // middle control
Expand Down