@@ -252,21 +252,27 @@ struct DismantledBlock : public GGMLBlock {
 public:
     int64_t num_heads;
     bool pre_only;
+    bool self_attn;
 
 public:
     DismantledBlock(int64_t hidden_size,
                     int64_t num_heads,
                     float mlp_ratio     = 4.0,
                     std::string qk_norm = "",
                     bool qkv_bias       = false,
-                    bool pre_only       = false)
-        : num_heads(num_heads), pre_only(pre_only) {
+                    bool pre_only       = false,
+                    bool self_attn      = false)
+        : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
        // rmsnorm is always Flase
        // scale_mod_only is always Flase
        // swiglu is always Flase
        blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
        blocks["attn"]  = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
 
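+        // MMDiT-X: optional second self-attention branch ("attn2"); only built when self_attn is set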
+        if (self_attn) {
+            blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
+        }
+
        if (!pre_only) {
            blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
            int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio);
@@ -277,9 +283,52 @@ struct DismantledBlock : public GGMLBlock {
        if (pre_only) {
            n_mods = 2;
        }
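+        // MMDiT-X: adaLN produces shift/scale/gate for attn, mlp and the extra attn2 branch -> 3 * 3 = 9 modulation vectors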
+        if (self_attn) {
+            n_mods = 9;
+        }
        blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, n_mods * hidden_size));
    }
 
+    std::tuple<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention_x(struct ggml_context* ctx,
+                                                                                                                                     struct ggml_tensor* x,
+                                                                                                                                     struct ggml_tensor* c) {
+        GGML_ASSERT(self_attn);
+        // x: [N, n_token, hidden_size]
+        // c: [N, hidden_size]
+        auto norm1              = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
+        auto attn               = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
+        auto attn2              = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
+        auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
+
+        int64_t n_mods = 9;
+        auto m         = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c));  // [N, n_mods * hidden_size]
+        m              = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]);  // [N, n_mods, hidden_size]
+        m              = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3));     // [n_mods, N, hidden_size]
+
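+        // each of the 9 modulation vectors is a contiguous [N, hidden_size] slab of m, selected below by byte offset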
+        int64_t offset = m->nb[1] * m->ne[1];
+        auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);  // [N, hidden_size]
+        auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);  // [N, hidden_size]
+        auto gate_msa  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2);  // [N, hidden_size]
+
+        auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3);  // [N, hidden_size]
+        auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4);  // [N, hidden_size]
+        auto gate_mlp  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5);  // [N, hidden_size]
+
+        auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6);  // [N, hidden_size]
+        auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7);  // [N, hidden_size]
+        auto gate_msa2  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8);  // [N, hidden_size]
+
+        auto x_norm = norm1->forward(ctx, x);
+
+        auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa);
+        auto qkv     = attn->pre_attention(ctx, attn_in);
+
+        auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2);
+        auto qkv2     = attn2->pre_attention(ctx, attn2_in);
+
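+        // qkv feeds "attn", qkv2 feeds "attn2"; the third element bundles the residual input and the gate/shift/scale vectors consumed by post_attention_x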
+        return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}};
+    }
+
    std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention(struct ggml_context* ctx,
                                                                                                struct ggml_tensor* x,
                                                                                                struct ggml_tensor* c) {
@@ -319,6 +368,44 @@ struct DismantledBlock : public GGMLBlock {
        }
    }
 
+    struct ggml_tensor* post_attention_x(struct ggml_context* ctx,
+                                         struct ggml_tensor* attn_out,
+                                         struct ggml_tensor* attn2_out,
+                                         struct ggml_tensor* x,
+                                         struct ggml_tensor* gate_msa,
+                                         struct ggml_tensor* shift_mlp,
+                                         struct ggml_tensor* scale_mlp,
+                                         struct ggml_tensor* gate_mlp,
+                                         struct ggml_tensor* gate_msa2) {
+        // attn_out: [N, n_token, hidden_size]
+        // x: [N, n_token, hidden_size]
+        // gate_msa: [N, hidden_size]
+        // shift_mlp: [N, hidden_size]
+        // scale_mlp: [N, hidden_size]
+        // gate_mlp: [N, hidden_size]
+        // return: [N, n_token, hidden_size]
+        GGML_ASSERT(!pre_only);
+
+        auto attn  = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
+        auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
+        auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
+        auto mlp   = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
+
+        gate_msa  = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]);     // [N, 1, hidden_size]
+        gate_mlp  = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]);     // [N, 1, hidden_size]
+        gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]);  // [N, 1, hidden_size]
+
+        attn_out  = attn->post_attention(ctx, attn_out);
+        attn2_out = attn2->post_attention(ctx, attn2_out);
+
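+        // two gated residual branches (joint-attention output and extra self-attention output), then the gated MLP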
+        x            = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
+        x            = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2));
+        auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
+        x            = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));
+
+        return x;
+    }
+
    struct ggml_tensor* post_attention(struct ggml_context* ctx,
                                       struct ggml_tensor* attn_out,
                                       struct ggml_tensor* x,
@@ -357,40 +444,71 @@ struct DismantledBlock : public GGMLBlock {
        // return: [N, n_token, hidden_size]
 
        auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
-
-        auto qkv_intermediates = pre_attention(ctx, x, c);
-        auto qkv               = qkv_intermediates.first;
-        auto intermediates     = qkv_intermediates.second;
-
-        auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
-        x             = post_attention(ctx,
-                                       attn_out,
-                                       intermediates[0],
-                                       intermediates[1],
-                                       intermediates[2],
-                                       intermediates[3],
-                                       intermediates[4]);
-        return x;  // [N, n_token, dim]
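+        // MMDiT-X path: run both attention branches and merge them in post_attention_x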
+        if (self_attn) {
+            auto qkv_intermediates = pre_attention_x(ctx, x, c);
+            // pre_attention_x returns a tuple (qkv, qkv2, intermediates) rather than the pair returned by pre_attention
+            auto qkv           = std::get<0>(qkv_intermediates);
+            auto qkv2          = std::get<1>(qkv_intermediates);
+            auto intermediates = std::get<2>(qkv_intermediates);
+
+            auto attn_out  = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);     // [N, n_token, dim]
+            auto attn2_out = ggml_nn_attention_ext(ctx, qkv2[0], qkv2[1], qkv2[2], num_heads);  // [N, n_token, dim]
+            x              = post_attention_x(ctx,
+                                              attn_out,
+                                              attn2_out,
+                                              intermediates[0],
+                                              intermediates[1],
+                                              intermediates[2],
+                                              intermediates[3],
+                                              intermediates[4],
+                                              intermediates[5]);
+            return x;  // [N, n_token, dim]
+        } else {
+            auto qkv_intermediates = pre_attention(ctx, x, c);
+            auto qkv               = qkv_intermediates.first;
+            auto intermediates     = qkv_intermediates.second;
+
+            auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
+            x             = post_attention(ctx,
+                                           attn_out,
+                                           intermediates[0],
+                                           intermediates[1],
+                                           intermediates[2],
+                                           intermediates[3],
+                                           intermediates[4]);
+            return x;  // [N, n_token, dim]
+        }
    }
 };
 
-__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixing(struct ggml_context* ctx,
-                                                                                   struct ggml_tensor* context,
-                                                                                   struct ggml_tensor* x,
-                                                                                   struct ggml_tensor* c,
-                                                                                   std::shared_ptr<DismantledBlock> context_block,
-                                                                                   std::shared_ptr<DismantledBlock> x_block) {
+__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
+block_mixing(struct ggml_context* ctx,
+             struct ggml_tensor* context,
+             struct ggml_tensor* x,
+             struct ggml_tensor* c,
+             std::shared_ptr<DismantledBlock> context_block,
+             std::shared_ptr<DismantledBlock> x_block) {
    // context: [N, n_context, hidden_size]
    // x: [N, n_token, hidden_size]
    // c: [N, hidden_size]
    auto context_qkv_intermediates = context_block->pre_attention(ctx, context, c);
    auto context_qkv               = context_qkv_intermediates.first;
    auto context_intermediates     = context_qkv_intermediates.second;
 
-    auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
-    auto x_qkv               = x_qkv_intermediates.first;
-    auto x_intermediates     = x_qkv_intermediates.second;
+    std::vector<ggml_tensor*> x_qkv, x_qkv2, x_intermediates;
 
+    if (x_block->self_attn) {
+        auto x_qkv_intermediates = x_block->pre_attention_x(ctx, x, c);
+        x_qkv                    = std::get<0>(x_qkv_intermediates);
+        x_qkv2                   = std::get<1>(x_qkv_intermediates);
+        x_intermediates          = std::get<2>(x_qkv_intermediates);
+    } else {
+        auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
+        x_qkv                    = x_qkv_intermediates.first;
+        x_intermediates          = x_qkv_intermediates.second;
+    }
    std::vector<struct ggml_tensor*> qkv;
    for (int i = 0; i < 3; i++) {
        qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
@@ -429,13 +547,27 @@ __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixi
        context = NULL;
    }
 
-    x = x_block->post_attention(ctx,
-                                x_attn,
-                                x_intermediates[0],
-                                x_intermediates[1],
-                                x_intermediates[2],
-                                x_intermediates[3],
-                                x_intermediates[4]);
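+    // MMDiT-X x-block: the second attention runs over the image tokens only, then both gated branches are merged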
+    if (x_block->self_attn) {
+        auto attn2 = ggml_nn_attention_ext(ctx, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads);  // [N, n_token, hidden_size]
+
+        x = x_block->post_attention_x(ctx,
+                                      x_attn,
+                                      attn2,
+                                      x_intermediates[0],
+                                      x_intermediates[1],
+                                      x_intermediates[2],
+                                      x_intermediates[3],
+                                      x_intermediates[4],
+                                      x_intermediates[5]);
+    } else {
+        x = x_block->post_attention(ctx,
+                                    x_attn,
+                                    x_intermediates[0],
+                                    x_intermediates[1],
+                                    x_intermediates[2],
+                                    x_intermediates[3],
+                                    x_intermediates[4]);
+    }
 
    return {context, x};
 }
@@ -447,9 +579,10 @@ struct JointBlock : public GGMLBlock {
               float mlp_ratio     = 4.0,
               std::string qk_norm = "",
               bool qkv_bias       = false,
-               bool pre_only       = false) {
+               bool pre_only       = false,
+               bool self_attn_x    = false) {
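+        // self_attn_x only augments the x (image-token) block; the context block is unchanged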
        blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
-        blocks["x_block"]       = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false));
+        blocks["x_block"]       = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
    }
 
    std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
@@ -507,6 +640,7 @@ struct MMDiT : public GGMLBlock {
    int64_t input_size               = -1;
    int64_t patch_size               = 2;
    int64_t in_channels              = 16;
+    int64_t d_self                   = -1;  // >=0 for MMdiT-X
    int64_t depth                    = 24;
    float mlp_ratio                  = 4.0f;
    int64_t adm_in_channels          = 2048;
@@ -561,6 +695,20 @@ struct MMDiT : public GGMLBlock {
            context_size             = 4096;
            context_embedder_out_dim = 2432;
            qk_norm                  = "rms";
+        } else if (version == VERSION_SD3_5_2B) {
+            input_size               = -1;
+            patch_size               = 2;
+            in_channels              = 16;
+            depth                    = 24;
+            d_self                   = 12;
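+            // joint blocks 0..d_self (the first 13 of the 24) use the dual-attention x-block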
+            mlp_ratio                = 4.0f;
+            adm_in_channels          = 2048;
+            out_channels             = 16;
+            pos_embed_max_size       = 384;
+            num_patchs               = 147456;
+            context_size             = 4096;
+            context_embedder_out_dim = 1536;
+            qk_norm                  = "rms";
        }
        int64_t default_out_channels = in_channels;
        hidden_size                  = 64 * depth;
@@ -581,15 +729,17 @@ struct MMDiT : public GGMLBlock {
                                                                           mlp_ratio,
                                                                           qk_norm,
                                                                           true,
-                                                                           i == depth - 1));
+                                                                           i == depth - 1,
+                                                                           i <= d_self));
        }
 
        blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
    }
 
-    struct ggml_tensor* cropped_pos_embed(struct ggml_context* ctx,
-                                          int64_t h,
-                                          int64_t w) {
+    struct ggml_tensor*
+    cropped_pos_embed(struct ggml_context* ctx,
+                      int64_t h,
+                      int64_t w) {
        auto pos_embed = params["pos_embed"];
 
        h = (h + 1) / patch_size;