3 changes: 2 additions & 1 deletion diffusion_model.hpp
@@ -95,8 +95,9 @@ struct MMDiTModel : public DiffusionModel {

MMDiTModel(ggml_backend_t backend,
bool offload_params_to_cpu,
bool flash_attn = false,
const String2GGMLType& tensor_types = {})
: mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
: mmdit(backend, offload_params_to_cpu, flash_attn, tensor_types, "model.diffusion_model") {
}

std::string get_desc() {
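Note: the new flash_attn parameter is added ahead of the optional tensor_types argument, so both the mmdit member initializer here and the call site in stable-diffusion.cpp (further down) pass it positionally in that slot. A minimal construction sketch, using only names visible in this diff; the backend line is illustrative:

// Sketch: constructing the MMDiT diffusion model with the flash-attention path requested.
ggml_backend_t backend     = ggml_backend_cpu_init();  // or any other ggml backend
bool offload_params_to_cpu = false;
bool flash_attn            = true;
// tensor_types is optional and defaults to {}
auto diffusion_model = std::make_shared<MMDiTModel>(backend, offload_params_to_cpu, flash_attn);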
53 changes: 33 additions & 20 deletions mmdit.hpp
@@ -147,14 +147,16 @@ class SelfAttention : public GGMLBlock {
int64_t num_heads;
bool pre_only;
std::string qk_norm;
bool flash_attn;

public:
SelfAttention(int64_t dim,
int64_t num_heads = 8,
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false)
: num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
bool pre_only = false,
bool flash_attn = false)
: num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm), flash_attn(flash_attn) {
int64_t d_head = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
if (!pre_only) {
@@ -206,8 +208,8 @@ class SelfAttention : public GGMLBlock {
ggml_backend_t backend,
struct ggml_tensor* x) {
auto qkv = pre_attention(ctx, x);
x = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
x = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
return x;
}
};
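Note: the extra positional arguments now threaded into ggml_nn_attention_ext are easier to read with names attached. The helper itself (in ggml_extend.hpp) is not part of this diff, so the meanings annotated below are an assumption inferred from the call pattern, not the authoritative signature:

// Annotated form of the attention call used throughout this file; comments are assumptions.
x = ggml_nn_attention_ext(ctx, backend,
                          qkv[0], qkv[1], qkv[2],  // q, k, v
                          num_heads,
                          NULL,                    // attention mask: none
                          false,                   // no causal/diagonal masking
                          false,                   // do not skip the internal head-split reshape
                          flash_attn);             // select the flash-attention code path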
@@ -232,6 +234,7 @@ struct DismantledBlock : public GGMLBlock {
int64_t num_heads;
bool pre_only;
bool self_attn;
bool flash_attn;

public:
DismantledBlock(int64_t hidden_size,
@@ -240,16 +243,17 @@ struct DismantledBlock : public GGMLBlock {
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false,
bool self_attn = false)
bool self_attn = false,
bool flash_attn = false)
: num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
: num_heads(num_heads), pre_only(pre_only), self_attn(self_attn), flash_attn(flash_attn) {
// rmsnorm is always False
// scale_mod_only is always False
// swiglu is always False
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only, flash_attn));

if (self_attn) {
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false, flash_attn));
}

if (!pre_only) {
@@ -435,8 +439,8 @@ struct DismantledBlock : public GGMLBlock {
auto qkv2 = std::get<1>(qkv_intermediates);
auto intermediates = std::get<2>(qkv_intermediates);

auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads); // [N, n_token, dim]
auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
x = post_attention_x(ctx,
attn_out,
attn2_out,
@@ -452,7 +456,7 @@ struct DismantledBlock : public GGMLBlock {
auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second;

auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
x = post_attention(ctx,
attn_out,
intermediates[0],
@@ -468,6 +472,7 @@ struct DismantledBlock : public GGMLBlock {
__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
block_mixing(struct ggml_context* ctx,
ggml_backend_t backend,
bool flash_attn,
struct ggml_tensor* context,
struct ggml_tensor* x,
struct ggml_tensor* c,
@@ -497,8 +502,8 @@ block_mixing(struct ggml_context* ctx,
qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
}

auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads); // [N, n_context + n_token, hidden_size]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, NULL, false, false, flash_attn); // [N, n_context + n_token, hidden_size]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
auto context_attn = ggml_view_3d(ctx,
attn,
attn->ne[0],
@@ -556,16 +561,20 @@ block_mixing(struct ggml_context* ctx,
}

struct JointBlock : public GGMLBlock {
bool flash_attn;

public:
JointBlock(int64_t hidden_size,
int64_t num_heads,
float mlp_ratio = 4.0,
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false,
bool self_attn_x = false) {
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
bool self_attn_x = false,
bool flash_attn = false)
: flash_attn(flash_attn) {
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false, flash_attn));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x, flash_attn));
}

std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
@@ -576,7 +585,7 @@ struct JointBlock : public GGMLBlock {
auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]);
auto x_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]);

return block_mixing(ctx, backend, context, x, c, context_block, x_block);
return block_mixing(ctx, backend, flash_attn, context, x, c, context_block, x_block);
}
};

@@ -634,14 +643,16 @@ struct MMDiT : public GGMLBlock {
int64_t context_embedder_out_dim = 1536;
int64_t hidden_size;
std::string qk_norm;
bool flash_attn = false;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
enum ggml_type wtype = GGML_TYPE_F32;
params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1);
}

public:
MMDiT(const String2GGMLType& tensor_types = {}) {
MMDiT(bool flash_attn = false, const String2GGMLType& tensor_types = {})
: flash_attn(flash_attn) {
// input_size is always None
// learn_sigma is always False
// register_length is always 0
@@ -709,7 +720,8 @@ struct MMDiT : public GGMLBlock {
qk_norm,
true,
i == depth - 1,
i <= d_self));
i <= d_self,
flash_attn));
}

blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
@@ -856,9 +868,10 @@ struct MMDiTRunner : public GGMLRunner {

MMDiTRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
bool flash_attn,
const String2GGMLType& tensor_types = {},
const std::string prefix = "")
: GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
: GGMLRunner(backend, offload_params_to_cpu), mmdit(flash_attn, tensor_types) {
mmdit.init(params_ctx, tensor_types, prefix);
}

@@ -957,7 +970,7 @@ struct MMDiTRunner : public GGMLRunner {
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_F16;
std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false));
std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false, false));
{
LOG_INFO("loading from '%s'", file_path.c_str());

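Note: in the standalone load/test hunk above, the MMDiTRunner constructor's third argument is the new flash_attn switch (after backend and offload_params_to_cpu). A sketch of the same construction with flash attention requested; whether a fused kernel is actually used presumably still depends on what ggml_nn_attention_ext supports for the chosen backend:

// Same as the test code above, but requesting flash attention.
std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(
    new MMDiTRunner(backend, /*offload_params_to_cpu*/ false, /*flash_attn*/ true));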
1 change: 1 addition & 0 deletions stable-diffusion.cpp
@@ -349,6 +349,7 @@ class StableDiffusionGGML {
model_loader.tensor_storages_types);
diffusion_model = std::make_shared<MMDiTModel>(backend,
offload_params_to_cpu,
sd_ctx_params->diffusion_flash_attn,
model_loader.tensor_storages_types);
} else if (sd_version_is_flux(version)) {
bool is_chroma = false;
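Note: with this change, the MMDiTModel branch receives the user-facing diffusion_flash_attn setting from the context parameters. A usage sketch at the public API level, assuming the sd_ctx_params_t struct from stable-diffusion.h; only the diffusion_flash_attn field is confirmed by this diff, while sd_ctx_params_init, new_sd_ctx, model_path, and the --diffusion-fa CLI flag are assumptions that may differ between versions:

// Hypothetical end-to-end usage; verify the names against stable-diffusion.h.
sd_ctx_params_t params;
sd_ctx_params_init(&params);                             // assumed helper that fills in defaults
params.model_path           = "sd3_medium.safetensors";  // illustrative model file
params.diffusion_flash_attn = true;                      // reaches MMDiTModel via the change above
sd_ctx_t* sd_ctx = new_sd_ctx(&params);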