Commit 5233e8e

sd 3.5 medium
1 parent f32a874 commit 5233e8e

File tree

5 files changed, +206 -51 lines

otherarch/sdcpp/mmdit.hpp

Lines changed: 188 additions & 38 deletions
@@ -252,21 +252,27 @@ struct DismantledBlock : public GGMLBlock {
 public:
     int64_t num_heads;
     bool pre_only;
+    bool self_attn;
 
 public:
     DismantledBlock(int64_t hidden_size,
                     int64_t num_heads,
                     float mlp_ratio = 4.0,
                     std::string qk_norm = "",
                     bool qkv_bias = false,
-                    bool pre_only = false)
-        : num_heads(num_heads), pre_only(pre_only) {
+                    bool pre_only = false,
+                    bool self_attn = false)
+        : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
         // rmsnorm is always Flase
         // scale_mod_only is always Flase
         // swiglu is always Flase
         blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
         blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
 
+        if (self_attn) {
+            blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
+        }
+
         if (!pre_only) {
             blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
             int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio);
@@ -277,9 +283,52 @@ struct DismantledBlock : public GGMLBlock {
         if (pre_only) {
             n_mods = 2;
         }
+        if (self_attn) {
+            n_mods = 9;
+        }
         blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, n_mods * hidden_size));
     }
 
+    std::tuple<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention_x(struct ggml_context* ctx,
+                                                                                                                                      struct ggml_tensor* x,
+                                                                                                                                      struct ggml_tensor* c) {
+        GGML_ASSERT(self_attn);
+        // x: [N, n_token, hidden_size]
+        // c: [N, hidden_size]
+        auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
+        auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
+        auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
+        auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
+
+        int64_t n_mods = 9;
+        auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c));  // [N, n_mods * hidden_size]
+        m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]);       // [N, n_mods, hidden_size]
+        m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3));          // [n_mods, N, hidden_size]
+
+        int64_t offset = m->nb[1] * m->ne[1];
+        auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);  // [N, hidden_size]
+        auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);  // [N, hidden_size]
+        auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2);   // [N, hidden_size]
+
+        auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3);  // [N, hidden_size]
+        auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4);  // [N, hidden_size]
+        auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5);   // [N, hidden_size]
+
+        auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6);  // [N, hidden_size]
+        auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7);  // [N, hidden_size]
+        auto gate_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8);   // [N, hidden_size]
+
+        auto x_norm = norm1->forward(ctx, x);
+
+        auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa);
+        auto qkv = attn->pre_attention(ctx, attn_in);
+
+        auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2);
+        auto qkv2 = attn2->pre_attention(ctx, attn2_in);
+
+        return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}};
+    }
+
     std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention(struct ggml_context* ctx,
                                                                                                 struct ggml_tensor* x,
                                                                                                 struct ggml_tensor* c) {
@@ -319,6 +368,44 @@ struct DismantledBlock : public GGMLBlock {
         }
     }
 
+    struct ggml_tensor* post_attention_x(struct ggml_context* ctx,
+                                         struct ggml_tensor* attn_out,
+                                         struct ggml_tensor* attn2_out,
+                                         struct ggml_tensor* x,
+                                         struct ggml_tensor* gate_msa,
+                                         struct ggml_tensor* shift_mlp,
+                                         struct ggml_tensor* scale_mlp,
+                                         struct ggml_tensor* gate_mlp,
+                                         struct ggml_tensor* gate_msa2) {
+        // attn_out: [N, n_token, hidden_size]
+        // x: [N, n_token, hidden_size]
+        // gate_msa: [N, hidden_size]
+        // shift_mlp: [N, hidden_size]
+        // scale_mlp: [N, hidden_size]
+        // gate_mlp: [N, hidden_size]
+        // return: [N, n_token, hidden_size]
+        GGML_ASSERT(!pre_only);
+
+        auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
+        auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
+        auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
+        auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
+
+        gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]);      // [N, 1, hidden_size]
+        gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]);      // [N, 1, hidden_size]
+        gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]);  // [N, 1, hidden_size]
+
+        attn_out = attn->post_attention(ctx, attn_out);
+        attn2_out = attn2->post_attention(ctx, attn2_out);
+
+        x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
+        x = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2));
+        auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
+        x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));
+
+        return x;
+    }
+
     struct ggml_tensor* post_attention(struct ggml_context* ctx,
                                        struct ggml_tensor* attn_out,
                                        struct ggml_tensor* x,
@@ -357,40 +444,71 @@ struct DismantledBlock : public GGMLBlock {
         // return: [N, n_token, hidden_size]
 
         auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
-
-        auto qkv_intermediates = pre_attention(ctx, x, c);
-        auto qkv = qkv_intermediates.first;
-        auto intermediates = qkv_intermediates.second;
-
-        auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
-        x = post_attention(ctx,
-                           attn_out,
-                           intermediates[0],
-                           intermediates[1],
-                           intermediates[2],
-                           intermediates[3],
-                           intermediates[4]);
-        return x;  // [N, n_token, dim]
+        if (self_attn) {
+            auto qkv_intermediates = pre_attention_x(ctx, x, c);
+            // auto qkv = qkv_intermediates.first;
+            // auto intermediates = qkv_intermediates.second;
+            // no longer a pair, but a tuple
+            auto qkv = std::get<0>(qkv_intermediates);
+            auto qkv2 = std::get<1>(qkv_intermediates);
+            auto intermediates = std::get<2>(qkv_intermediates);
+
+            auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);      // [N, n_token, dim]
+            auto attn2_out = ggml_nn_attention_ext(ctx, qkv2[0], qkv2[1], qkv2[2], num_heads);  // [N, n_token, dim]
+            x = post_attention_x(ctx,
+                                 attn_out,
+                                 attn2_out,
+                                 intermediates[0],
+                                 intermediates[1],
+                                 intermediates[2],
+                                 intermediates[3],
+                                 intermediates[4],
+                                 intermediates[5]);
+            return x;  // [N, n_token, dim]
+        } else {
+            auto qkv_intermediates = pre_attention(ctx, x, c);
+            auto qkv = qkv_intermediates.first;
+            auto intermediates = qkv_intermediates.second;
+
+            auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
+            x = post_attention(ctx,
+                               attn_out,
+                               intermediates[0],
+                               intermediates[1],
+                               intermediates[2],
+                               intermediates[3],
+                               intermediates[4]);
+            return x;  // [N, n_token, dim]
+        }
     }
 };
 
-__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixing(struct ggml_context* ctx,
-                                                                                   struct ggml_tensor* context,
-                                                                                   struct ggml_tensor* x,
-                                                                                   struct ggml_tensor* c,
-                                                                                   std::shared_ptr<DismantledBlock> context_block,
-                                                                                   std::shared_ptr<DismantledBlock> x_block) {
+__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
+block_mixing(struct ggml_context* ctx,
+             struct ggml_tensor* context,
+             struct ggml_tensor* x,
+             struct ggml_tensor* c,
+             std::shared_ptr<DismantledBlock> context_block,
+             std::shared_ptr<DismantledBlock> x_block) {
     // context: [N, n_context, hidden_size]
     // x: [N, n_token, hidden_size]
     // c: [N, hidden_size]
     auto context_qkv_intermediates = context_block->pre_attention(ctx, context, c);
     auto context_qkv = context_qkv_intermediates.first;
     auto context_intermediates = context_qkv_intermediates.second;
 
-    auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
-    auto x_qkv = x_qkv_intermediates.first;
-    auto x_intermediates = x_qkv_intermediates.second;
+    std::vector<ggml_tensor*> x_qkv, x_qkv2, x_intermediates;
 
+    if (x_block->self_attn) {
+        auto x_qkv_intermediates = x_block->pre_attention_x(ctx, x, c);
+        x_qkv = std::get<0>(x_qkv_intermediates);
+        x_qkv2 = std::get<1>(x_qkv_intermediates);
+        x_intermediates = std::get<2>(x_qkv_intermediates);
+    } else {
+        auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
+        x_qkv = x_qkv_intermediates.first;
+        x_intermediates = x_qkv_intermediates.second;
+    }
     std::vector<struct ggml_tensor*> qkv;
     for (int i = 0; i < 3; i++) {
         qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
@@ -429,13 +547,27 @@ __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> block_mixi
         context = NULL;
     }
 
-    x = x_block->post_attention(ctx,
-                                x_attn,
-                                x_intermediates[0],
-                                x_intermediates[1],
-                                x_intermediates[2],
-                                x_intermediates[3],
-                                x_intermediates[4]);
+    if (x_block->self_attn) {
+        auto attn2 = ggml_nn_attention_ext(ctx, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads);  // [N, n_token, hidden_size]
+
+        x = x_block->post_attention_x(ctx,
+                                      x_attn,
+                                      attn2,
+                                      x_intermediates[0],
+                                      x_intermediates[1],
+                                      x_intermediates[2],
+                                      x_intermediates[3],
+                                      x_intermediates[4],
+                                      x_intermediates[5]);
+    } else {
+        x = x_block->post_attention(ctx,
+                                    x_attn,
+                                    x_intermediates[0],
+                                    x_intermediates[1],
+                                    x_intermediates[2],
+                                    x_intermediates[3],
+                                    x_intermediates[4]);
+    }
 
     return {context, x};
 }
@@ -447,9 +579,10 @@ struct JointBlock : public GGMLBlock {
                float mlp_ratio = 4.0,
                std::string qk_norm = "",
                bool qkv_bias = false,
-               bool pre_only = false) {
+               bool pre_only = false,
+               bool self_attn_x = false) {
         blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
-        blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false));
+        blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
     }
 
     std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
@@ -507,6 +640,7 @@ struct MMDiT : public GGMLBlock {
     int64_t input_size = -1;
     int64_t patch_size = 2;
     int64_t in_channels = 16;
+    int64_t d_self = -1;  // >=0 for MMdiT-X
     int64_t depth = 24;
     float mlp_ratio = 4.0f;
     int64_t adm_in_channels = 2048;
@@ -561,6 +695,20 @@ struct MMDiT : public GGMLBlock {
             context_size = 4096;
             context_embedder_out_dim = 2432;
             qk_norm = "rms";
+        } else if (version == VERSION_SD3_5_2B) {
+            input_size = -1;
+            patch_size = 2;
+            in_channels = 16;
+            depth = 24;
+            d_self = 12;
+            mlp_ratio = 4.0f;
+            adm_in_channels = 2048;
+            out_channels = 16;
+            pos_embed_max_size = 384;
+            num_patchs = 147456;
+            context_size = 4096;
+            context_embedder_out_dim = 1536;
+            qk_norm = "rms";
         }
         int64_t default_out_channels = in_channels;
         hidden_size = 64 * depth;
@@ -581,15 +729,17 @@ struct MMDiT : public GGMLBlock {
                                                                            mlp_ratio,
                                                                            qk_norm,
                                                                            true,
-                                                                           i == depth - 1));
+                                                                           i == depth - 1,
+                                                                           i <= d_self));
         }
 
         blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
     }
 
-    struct ggml_tensor* cropped_pos_embed(struct ggml_context* ctx,
-                                          int64_t h,
-                                          int64_t w) {
+    struct ggml_tensor*
+    cropped_pos_embed(struct ggml_context* ctx,
+                      int64_t h,
+                      int64_t w) {
         auto pos_embed = params["pos_embed"];
 
         h = (h + 1) / patch_size;
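
Note on the mmdit.hpp changes: this ports the MMDiT-X ("SD 3.5 Medium") block into DismantledBlock. When self_attn is set, the x branch gains a second SelfAttention ("attn2"), adaLN_modulation.1 emits 9 modulation vectors instead of 6, and pre_attention_x slices them out of the permuted [n_mods, N, hidden_size] tensor with ggml_view_2d at byte offsets offset * i, where offset = m->nb[1] * m->ne[1]. Only joint blocks with index i <= d_self (d_self = 12 for VERSION_SD3_5_2B) are built with the extra attention. The sketch below mirrors just that 9-way chunk layout on flat CPU buffers for a single sample; it is an illustration, not part of the commit, and chunk_mods() is a hypothetical helper, not a ggml or sdcpp API.

    // Illustrative only: the 9-way adaLN split used by pre_attention_x,
    // reproduced on a plain std::vector for N = 1. chunk_mods() is hypothetical.
    #include <cassert>
    #include <cstdio>
    #include <vector>

    // m holds n_mods contiguous chunks of hidden_size floats, i.e. the
    // adaLN_modulation.1 output after the reshape/permute shown in the diff.
    static std::vector<std::vector<float>> chunk_mods(const std::vector<float>& m,
                                                      int n_mods, int hidden_size) {
        assert((int)m.size() == n_mods * hidden_size);
        std::vector<std::vector<float>> out(n_mods);
        for (int i = 0; i < n_mods; i++) {  // analogous to ggml_view_2d at offset * i
            out[i].assign(m.begin() + (size_t)i * hidden_size,
                          m.begin() + (size_t)(i + 1) * hidden_size);
        }
        return out;
    }

    int main() {
        const int hidden_size = 4;  // toy size; SD 3.5 Medium uses 64 * depth = 1536
        const int n_mods = 9;       // self_attn == true; the plain block uses 6
        std::vector<float> m(n_mods * hidden_size);
        for (size_t i = 0; i < m.size(); i++) m[i] = (float)i;

        auto mods = chunk_mods(m, n_mods, hidden_size);
        // Index -> role, matching the views in pre_attention_x:
        // 0 shift_msa, 1 scale_msa, 2 gate_msa,
        // 3 shift_mlp, 4 scale_mlp, 5 gate_mlp,
        // 6 shift_msa2, 7 scale_msa2, 8 gate_msa2
        std::printf("gate_msa2 starts at value %.0f\n", mods[8][0]);  // prints 32
        return 0;
    }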

otherarch/sdcpp/model.cpp

Lines changed: 3 additions & 0 deletions
@@ -1376,6 +1376,9 @@ SDVersion ModelLoader::get_sd_version() {
         if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
             is_flux = true;
         }
+        if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
+            return VERSION_SD3_5_2B;
+        }
         if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) {
             return VERSION_SD3_5_8B;
         }
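
Note on the model.cpp change: checkpoint detection in get_sd_version() is purely tensor-name probing. The presence of joint_blocks.0.x_block.attn2.ln_q.weight (the extra x-branch attention that only the MMDiT-X variant has) now tags a file as VERSION_SD3_5_2B, checked before the existing joint_blocks.37 probe that maps to VERSION_SD3_5_8B. A minimal standalone sketch of that idea follows; detect_version() and the returned strings are hypothetical and not part of ModelLoader.

    // Illustrative only: toy tensor-name probing in the spirit of
    // ModelLoader::get_sd_version(). detect_version() is hypothetical.
    #include <cstdio>
    #include <string>
    #include <vector>

    static std::string detect_version(const std::vector<std::string>& tensor_names) {
        for (const auto& name : tensor_names) {
            // Same order as the diff: the attn2 marker is checked first.
            if (name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
                return "SD3_5_2B";  // x_block has a second attention -> MMDiT-X
            }
            if (name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) {
                return "SD3_5_8B";  // a 38th joint block implies the larger model
            }
        }
        return "unknown";
    }

    int main() {
        std::vector<std::string> names = {
            "model.diffusion_model.joint_blocks.0.x_block.attn2.ln_q.weight",
        };
        std::printf("%s\n", detect_version(names).c_str());  // prints SD3_5_2B
        return 0;
    }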

otherarch/sdcpp/model.h

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ enum SDVersion {
     VERSION_FLUX_DEV,
     VERSION_FLUX_SCHNELL,
     VERSION_SD3_5_8B,
+    VERSION_SD3_5_2B,
     VERSION_COUNT,
 };
3132
