Skip to content

Commit e7eabd3

Browse files
committed
Refactor: Flexible sd3 arch
1 parent 8e7fbf8 commit e7eabd3

File tree

5 files changed

+46
-68
lines changed

5 files changed

+46
-68
lines changed

diffusion_model.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,8 @@ struct MMDiTModel : public DiffusionModel {
8383
MMDiTRunner mmdit;
8484

8585
MMDiTModel(ggml_backend_t backend,
86-
std::map<std::string, enum ggml_type>& tensor_types,
87-
SDVersion version = VERSION_SD3_2B)
88-
: mmdit(backend, tensor_types, "model.diffusion_model", version) {
86+
std::map<std::string, enum ggml_type>& tensor_types)
87+
: mmdit(backend, tensor_types, "model.diffusion_model") {
8988
}
9089

9190
void alloc_params_buffer() {

mmdit.hpp

Lines changed: 38 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,6 @@ struct FinalLayer : public GGMLBlock {
637637
struct MMDiT : public GGMLBlock {
638638
// Diffusion model with a Transformer backbone.
639639
protected:
640-
SDVersion version = VERSION_SD3_2B;
641640
int64_t input_size = -1;
642641
int64_t patch_size = 2;
643642
int64_t in_channels = 16;
@@ -659,8 +658,7 @@ struct MMDiT : public GGMLBlock {
659658
}
660659

661660
public:
662-
MMDiT(SDVersion version = VERSION_SD3_2B)
663-
: version(version) {
661+
MMDiT(std::map<std::string, enum ggml_type>& tensor_types) {
664662
// input_size is always None
665663
// learn_sigma is always False
666664
// register_length is always 0
@@ -672,48 +670,44 @@ struct MMDiT : public GGMLBlock {
672670
// pos_embed_scaling_factor is not used
673671
// pos_embed_offset is not used
674672
// context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}
675-
if (version == VERSION_SD3_2B) {
676-
input_size = -1;
677-
patch_size = 2;
678-
in_channels = 16;
679-
depth = 24;
680-
mlp_ratio = 4.0f;
681-
adm_in_channels = 2048;
682-
out_channels = 16;
683-
pos_embed_max_size = 192;
684-
num_patchs = 36864; // 192 * 192
685-
context_size = 4096;
686-
context_embedder_out_dim = 1536;
687-
} else if (version == VERSION_SD3_5_8B) {
688-
input_size = -1;
689-
patch_size = 2;
690-
in_channels = 16;
691-
depth = 38;
692-
mlp_ratio = 4.0f;
693-
adm_in_channels = 2048;
694-
out_channels = 16;
695-
pos_embed_max_size = 192;
696-
num_patchs = 36864; // 192 * 192
697-
context_size = 4096;
698-
context_embedder_out_dim = 2432;
699-
qk_norm = "rms";
700-
} else if (version == VERSION_SD3_5_2B) {
701-
input_size = -1;
702-
patch_size = 2;
703-
in_channels = 16;
704-
depth = 24;
705-
d_self = 12;
706-
mlp_ratio = 4.0f;
707-
adm_in_channels = 2048;
708-
out_channels = 16;
709-
pos_embed_max_size = 384;
710-
num_patchs = 147456;
711-
context_size = 4096;
712-
context_embedder_out_dim = 1536;
713-
qk_norm = "rms";
673+
674+
// read tensors from tensor_types
675+
for (auto pair : tensor_types) {
676+
std::string tensor_name = pair.first;
677+
if (tensor_name.find("model.diffusion_model.") == std::string::npos)
678+
continue;
679+
size_t jb = tensor_name.find("joint_blocks.");
680+
if (jb != std::string::npos) {
681+
tensor_name = tensor_name.substr(jb); // remove prefix
682+
int block_depth = atoi(tensor_name.substr(13, tensor_name.find(".", 13)).c_str());
683+
if (block_depth + 1 > depth) {
684+
depth = block_depth + 1;
685+
}
686+
if (tensor_name.find("attn.ln") != std::string::npos) {
687+
if (tensor_name.find(".bias") != std::string::npos) {
688+
qk_norm = "ln";
689+
} else {
690+
qk_norm = "rms";
691+
}
692+
}
693+
if (tensor_name.find("attn2") != std::string::npos) {
694+
if (block_depth > d_self) {
695+
d_self = block_depth;
696+
}
697+
}
698+
}
714699
}
700+
701+
if (d_self >= 0) {
702+
pos_embed_max_size *= 2;
703+
num_patchs *= 4;
704+
}
705+
706+
LOG_INFO("MMDiT layers: %d (including %d MMDiT-x layers)", depth, d_self + 1);
707+
715708
int64_t default_out_channels = in_channels;
716709
hidden_size = 64 * depth;
710+
context_embedder_out_dim = 64 * depth;
717711
int64_t num_heads = depth;
718712

719713
blocks["x_embedder"] = std::shared_ptr<GGMLBlock>(new PatchEmbed(input_size, patch_size, in_channels, hidden_size, true));
@@ -879,9 +873,8 @@ struct MMDiTRunner : public GGMLRunner {
879873

880874
MMDiTRunner(ggml_backend_t backend,
881875
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
882-
const std::string prefix = "",
883-
SDVersion version = VERSION_SD3_2B)
884-
: GGMLRunner(backend), mmdit(version) {
876+
const std::string prefix = "")
877+
: GGMLRunner(backend), mmdit(tensor_types) {
885878
mmdit.init(params_ctx, tensor_types, prefix);
886879
}
887880

model.cpp

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1462,7 +1462,6 @@ SDVersion ModelLoader::get_sd_version() {
14621462
bool is_flux = false;
14631463
bool is_schnell = true;
14641464
bool is_lite = true;
1465-
bool is_sd3 = false;
14661465
for (auto& tensor_storage : tensor_storages) {
14671466
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
14681467
is_schnell = false;
@@ -1473,14 +1472,8 @@ SDVersion ModelLoader::get_sd_version() {
14731472
if (tensor_storage.name.find("model.diffusion_model.double_blocks.8") != std::string::npos) {
14741473
is_lite = false;
14751474
}
1476-
if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
1477-
return VERSION_SD3_5_2B;
1478-
}
1479-
if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) {
1480-
return VERSION_SD3_5_8B;
1481-
}
1482-
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.23.") != std::string::npos) {
1483-
is_sd3 = true;
1475+
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
1476+
return VERSION_SD3;
14841477
}
14851478
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
14861479
return VERSION_SDXL;
@@ -1512,9 +1505,6 @@ SDVersion ModelLoader::get_sd_version() {
15121505
return VERSION_FLUX_DEV;
15131506
}
15141507
}
1515-
if (is_sd3) {
1516-
return VERSION_SD3_2B;
1517-
}
15181508
if (token_embedding_weight.ne[0] == 768) {
15191509
return VERSION_SD1;
15201510
} else if (token_embedding_weight.ne[0] == 1024) {

model.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,9 @@ enum SDVersion {
2222
VERSION_SD2,
2323
VERSION_SDXL,
2424
VERSION_SVD,
25-
VERSION_SD3_2B,
25+
VERSION_SD3,
2626
VERSION_FLUX_DEV,
2727
VERSION_FLUX_SCHNELL,
28-
VERSION_SD3_5_8B,
29-
VERSION_SD3_5_2B,
3028
VERSION_FLUX_LITE,
3129
VERSION_COUNT,
3230
};
@@ -39,7 +37,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
3937
}
4038

4139
static inline bool sd_version_is_sd3(SDVersion version) {
42-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
40+
if (version == VERSION_SD3) {
4341
return true;
4442
}
4543
return false;

stable-diffusion.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,9 @@ const char* model_version_to_str[] = {
2929
"SD 2.x",
3030
"SDXL",
3131
"SVD",
32-
"SD3 2B",
32+
"SD3.x",
3333
"Flux Dev",
3434
"Flux Schnell",
35-
"SD3.5 8B",
36-
"SD3.5 2B",
3735
"Flux Lite 8B"};
3836

3937
const char* sampling_methods_str[] = {
@@ -330,7 +328,7 @@ class StableDiffusionGGML {
330328
LOG_WARN("flash attention in this diffusion model is currently unsupported!");
331329
}
332330
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
333-
diffusion_model = std::make_shared<MMDiTModel>(backend, model_loader.tensor_storages_types, version);
331+
diffusion_model = std::make_shared<MMDiTModel>(backend, model_loader.tensor_storages_types);
334332
} else if (sd_version_is_flux(version)) {
335333
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
336334
diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);

0 commit comments

Comments
 (0)