@@ -637,7 +637,6 @@ struct FinalLayer : public GGMLBlock {
 struct MMDiT : public GGMLBlock {
     // Diffusion model with a Transformer backbone.
 protected:
-    SDVersion version                = VERSION_SD3_2B;
     int64_t input_size               = -1;
     int64_t patch_size               = 2;
     int64_t in_channels              = 16;
@@ -659,8 +658,7 @@ struct MMDiT : public GGMLBlock {
     }

 public:
-    MMDiT(SDVersion version = VERSION_SD3_2B)
-        : version(version) {
+    MMDiT(std::map<std::string, enum ggml_type>& tensor_types) {
         // input_size is always None
         // learn_sigma is always False
         // register_length is always 0
@@ -672,48 +670,48 @@ struct MMDiT : public GGMLBlock {
         // pos_embed_scaling_factor is not used
         // pos_embed_offset is not used
         // context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}
-        if (version == VERSION_SD3_2B) {
-            input_size               = -1;
-            patch_size               = 2;
-            in_channels              = 16;
-            depth                    = 24;
-            mlp_ratio                = 4.0f;
-            adm_in_channels          = 2048;
-            out_channels             = 16;
-            pos_embed_max_size       = 192;
-            num_patchs               = 36864;  // 192 * 192
-            context_size             = 4096;
-            context_embedder_out_dim = 1536;
-        } else if (version == VERSION_SD3_5_8B) {
-            input_size               = -1;
-            patch_size               = 2;
-            in_channels              = 16;
-            depth                    = 38;
-            mlp_ratio                = 4.0f;
-            adm_in_channels          = 2048;
-            out_channels             = 16;
-            pos_embed_max_size       = 192;
-            num_patchs               = 36864;  // 192 * 192
-            context_size             = 4096;
-            context_embedder_out_dim = 2432;
-            qk_norm                  = "rms";
-        } else if (version == VERSION_SD3_5_2B) {
-            input_size               = -1;
-            patch_size               = 2;
-            in_channels              = 16;
-            depth                    = 24;
-            d_self                   = 12;
-            mlp_ratio                = 4.0f;
-            adm_in_channels          = 2048;
-            out_channels             = 16;
-            pos_embed_max_size       = 384;
-            num_patchs               = 147456;
-            context_size             = 4096;
-            context_embedder_out_dim = 1536;
-            qk_norm                  = "rms";
+
+        // derive the model configuration from the tensor names in the checkpoint
+        for (const auto& pair : tensor_types) {
+            std::string tensor_name = pair.first;
+            if (tensor_name.find("model.diffusion_model.") == std::string::npos)
+                continue;
+            size_t jb = tensor_name.find("joint_blocks.");
+            if (jb != std::string::npos) {
+                tensor_name = tensor_name.substr(jb);  // remove prefix
+                // the block index follows "joint_blocks." (13 characters)
+                int block_depth = atoi(tensor_name.substr(13, tensor_name.find(".", 13) - 13).c_str());
+                if (block_depth + 1 > depth) {
+                    depth = block_depth + 1;
+                }
+                // a bias on the QK norm indicates LayerNorm, otherwise RMSNorm
+                if (tensor_name.find("attn.ln") != std::string::npos) {
+                    if (tensor_name.find(".bias") != std::string::npos) {
+                        qk_norm = "ln";
+                    } else {
+                        qk_norm = "rms";
+                    }
+                }
+                // blocks with a second self-attention (attn2) are MMDiT-x blocks
+                if (tensor_name.find("attn2") != std::string::npos) {
+                    if (block_depth > d_self) {
+                        d_self = block_depth;
+                    }
+                }
+            }
         }
+
+        // MMDiT-x checkpoints use a doubled position-embedding grid (384 vs 192 per side)
+        if (d_self >= 0) {
+            pos_embed_max_size *= 2;
+            num_patchs *= 4;
+        }
+
+        LOG_INFO("MMDiT layers: %d (including %d MMDiT-x layers)", depth, d_self + 1);
+
         int64_t default_out_channels = in_channels;
         hidden_size                  = 64 * depth;
+        context_embedder_out_dim     = 64 * depth;
         int64_t num_heads            = depth;

         blocks["x_embedder"] = std::shared_ptr<GGMLBlock>(new PatchEmbed(input_size, patch_size, in_channels, hidden_size, true));
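Instead of dispatching on a hard-coded SDVersion enum, the constructor now derives depth, qk_norm, and the MMDiT-x block count directly from the checkpoint's tensor names. Below is a minimal standalone sketch of that scanning heuristic, under the assumption that names follow the "model.diffusion_model.joint_blocks.N.*" convention; the tensor names here are hypothetical and the map values are placeholders (the real map carries ggml tensor types):

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <string>

int main() {
    // Hypothetical subset of a checkpoint's tensor names (values are placeholders).
    std::map<std::string, int> tensor_types = {
        {"model.diffusion_model.joint_blocks.0.x_block.attn.ln_q.weight", 0},
        {"model.diffusion_model.joint_blocks.11.x_block.attn2.qkv.weight", 0},
        {"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight", 0},
    };

    int depth  = 0;       // number of joint blocks = highest block index + 1
    int d_self = -1;      // highest block index that has a second self-attention
    std::string qk_norm;  // "ln" if the QK norm carries a bias, else "rms"

    for (const auto& pair : tensor_types) {
        std::string name = pair.first;
        size_t jb = name.find("joint_blocks.");
        if (jb == std::string::npos)
            continue;
        name = name.substr(jb);
        // the block index starts right after "joint_blocks." (13 characters)
        int block_index = atoi(name.substr(13, name.find(".", 13) - 13).c_str());
        depth = std::max(depth, block_index + 1);
        if (name.find("attn.ln") != std::string::npos)
            qk_norm = (name.find(".bias") != std::string::npos) ? "ln" : "rms";
        if (name.find("attn2") != std::string::npos)
            d_self = std::max(d_self, block_index);
    }

    // prints: depth=24 d_self=11 qk_norm=rms
    printf("depth=%d d_self=%d qk_norm=%s\n", depth, d_self, qk_norm.c_str());
    return 0;
}

The payoff is that a new checkpoint can be loaded without adding an enum entry first, as long as its tensor names follow the expected layout.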
@@ -879,9 +873,8 @@ struct MMDiTRunner : public GGMLRunner {

     MMDiTRunner(ggml_backend_t backend,
                 std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
-                const std::string prefix = "",
-                SDVersion version        = VERSION_SD3_2B)
-        : GGMLRunner(backend), mmdit(version) {
+                const std::string prefix = "")
+        : GGMLRunner(backend), mmdit(tensor_types) {
         mmdit.init(params_ctx, tensor_types, prefix);
     }

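Since num_heads equals depth and each attention head is 64-dimensional, both hidden_size and the new context_embedder_out_dim reduce to 64 * depth. A quick arithmetic check that this rule reproduces the constants hard-coded in the deleted per-version branches:

#include <cassert>

int main() {
    assert(64 * 24 == 1536);  // depth 24 (SD3 2B / SD3.5 2B): context_embedder_out_dim = 1536
    assert(64 * 38 == 2432);  // depth 38 (SD3.5 8B): context_embedder_out_dim = 2432
    return 0;
}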