Skip to content

Commit b50420e

Browse files
akleineSkutteOleg
authored andcommitted
feat: add code and doc for running SSD1B models
1 parent d45ee32 commit b50420e

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

model.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1858,7 +1858,12 @@ SDVersion ModelLoader::get_sd_version() {
18581858
if (is_ip2p) {
18591859
return VERSION_SDXL_PIX2PIX;
18601860
}
1861-
return VERSION_SDXL;
1861+
for (auto& tensor_storage : tensor_storages) {
1862+
if (tensor_storage.name.find("model.diffusion_model.middle_block.1") != std::string::npos) {
1863+
return VERSION_SDXL; // found a missing tensor in SSD1B, so it is SDXL
1864+
}
1865+
}
1866+
return VERSION_SDXL_SSD1B;
18621867
}
18631868

18641869
if (is_flux) {

model.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ enum SDVersion {
2828
VERSION_SDXL,
2929
VERSION_SDXL_INPAINT,
3030
VERSION_SDXL_PIX2PIX,
31+
VERSION_SDXL_SSD1B,
3132
VERSION_SVD,
3233
VERSION_SD3,
3334
VERSION_FLUX,
@@ -56,7 +57,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
5657
}
5758

5859
static inline bool sd_version_is_sdxl(SDVersion version) {
59-
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX) {
60+
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) {
6061
return true;
6162
}
6263
return false;

stable-diffusion.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ const char* model_version_to_str[] = {
5858
"SDXL",
5959
"SDXL Inpaint",
6060
"SDXL Instruct-Pix2Pix",
61+
"SDXL (SSD1B)",
6162
"SVD",
6263
"SD3.x",
6364
"Flux",

unet.hpp

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,14 @@ class UnetModelBlock : public GGMLBlock {
270270
n_head = ch / d_head;
271271
}
272272
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
273+
int td=transformer_depth[i];
274+
if (version == VERSION_SDXL_SSD1B) {
275+
if (i==2) td=4;
276+
}
273277
blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
274278
n_head,
275279
d_head,
276-
transformer_depth[i],
280+
td,
277281
context_dim));
278282
}
279283
input_block_chans.push_back(ch);
@@ -296,13 +300,14 @@ class UnetModelBlock : public GGMLBlock {
296300
n_head = ch / d_head;
297301
}
298302
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
299-
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
300-
n_head,
301-
d_head,
302-
transformer_depth[transformer_depth.size() - 1],
303-
context_dim));
304-
blocks["middle_block.2"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
305-
303+
if (version != VERSION_SDXL_SSD1B) {
304+
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
305+
n_head,
306+
d_head,
307+
transformer_depth[transformer_depth.size() - 1],
308+
context_dim));
309+
blocks["middle_block.2"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
310+
}
306311
// output_blocks
307312
int output_block_idx = 0;
308313
for (int i = (int)len_mults - 1; i >= 0; i--) {
@@ -324,7 +329,12 @@ class UnetModelBlock : public GGMLBlock {
324329
n_head = ch / d_head;
325330
}
326331
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
327-
blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch, n_head, d_head, transformer_depth[i], context_dim));
332+
int td = transformer_depth[i];
333+
if (version == VERSION_SDXL_SSD1B) {
334+
if (i==2 && (j==0 || j==1)) td=4;
335+
if (i==1 && (j==1 || j==2)) td=1;
336+
}
337+
blocks[name] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch, n_head, d_head, td, context_dim));
328338

329339
up_sample_idx++;
330340
}
@@ -478,9 +488,10 @@ class UnetModelBlock : public GGMLBlock {
478488

479489
// middle_block
480490
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
481-
h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
482-
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
483-
491+
if (version != VERSION_SDXL_SSD1B) {
492+
h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
493+
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
494+
}
484495
if (controls.size() > 0) {
485496
auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength);
486497
h = ggml_add(ctx, h, cs); // middle control

0 commit comments

Comments
 (0)