Merged
Changes from 15 commits
4 changes: 2 additions & 2 deletions clip.hpp
@@ -6,7 +6,7 @@

/*================================================== CLIPTokenizer ===================================================*/

-std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
+__STATIC_INLINE__ std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
std::regex re("<lora:([^:]+):([^>]+)>");
std::smatch matches;
std::unordered_map<std::string, float> filename2multiplier;
@@ -31,7 +31,7 @@ std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remov
return std::make_pair(filename2multiplier, text);
}

-std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
+__STATIC_INLINE__ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
std::set<int> byte_set;
for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
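Marking these header-defined helpers `__STATIC_INLINE__` gives them internal linkage, so including clip.hpp from more than one translation unit no longer risks duplicate-symbol errors at link time. The macro's definition is not part of this diff; presumably it expands to something like this sketch:

```cpp
// Assumed definition (not shown in this diff): `static inline` gives the
// helper its own internal copy per translation unit, avoiding ODR/link
// errors when the header is included from several .cpp files.
#ifndef __STATIC_INLINE__
#define __STATIC_INLINE__ static inline
#endif
```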
44 changes: 39 additions & 5 deletions common.hpp
@@ -177,7 +177,7 @@ class ResBlock : public GGMLBlock {
}
};

-class GEGLU : public GGMLBlock {
+class GEGLU : public UnaryBlock {
protected:
int64_t dim_in;
int64_t dim_out;
@@ -216,14 +216,42 @@ class GEGLU : public GGMLBlock {
}
};

+class GELU : public UnaryBlock {
+public:
+GELU(int64_t dim_in, int64_t dim_out, bool bias = true) {
+blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim_in, dim_out, bias));
+}
+
+struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+// x: [ne3, ne2, ne1, dim_in]
+// return: [ne3, ne2, ne1, dim_out]
+auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
+
+x = proj->forward(ctx, x);
+x = ggml_gelu_inplace(ctx, x);
+return x;
+}
+};

class FeedForward : public GGMLBlock {
public:
+enum class Activation {
+GEGLU,
+GELU
+};
FeedForward(int64_t dim,
int64_t dim_out,
-int64_t mult = 4) {
+int64_t mult = 4,
+Activation activation = Activation::GEGLU,
+bool force_prec_f32 = false) {
int64_t inner_dim = dim * mult;
+SD_UNUSED(force_prec_f32);
+if (activation == Activation::GELU) {
+blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GELU(dim, inner_dim));
+} else {
+blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GEGLU(dim, inner_dim));
+}

-blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GEGLU(dim, inner_dim));
// net_1 is nn.Dropout(), skip for inference
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out));
}
@@ -232,11 +260,17 @@ class FeedForward : public GGMLBlock {
// x: [ne3, ne2, ne1, dim]
// return: [ne3, ne2, ne1, dim_out]

-auto net_0 = std::dynamic_pointer_cast<GEGLU>(blocks["net.0"]);
+auto net_0 = std::dynamic_pointer_cast<UnaryBlock>(blocks["net.0"]);
auto net_2 = std::dynamic_pointer_cast<Linear>(blocks["net.2"]);

x = net_0->forward(ctx, x); // [ne3, ne2, ne1, inner_dim]
-x = net_2->forward(ctx, x); // [ne3, ne2, ne1, dim_out]
+// The purpose of the scale here is to prevent NaN issues in certain situations.
+// For example, when using Vulkan without enabling force_prec_f32,
+// or when using CUDA but the weights are k-quants.
+float scale = 1.f / 128.f;
+x = ggml_scale(ctx, x, scale);
+x = net_2->forward(ctx, x); // [ne3, ne2, ne1, dim_out]
+x = ggml_scale(ctx, x, 1.f / scale);
Contributor:
Curious which part in the CUDA backend causes the issue here? I assume you are working around some FP overflow?

Owner Author:
It's likely that ggml_mul_mat has a precision issue when the weights are k-quants.

Contributor:
I wonder why Jeff's ggml_mul_mat_set_prec fix worked for Vulkan but not CUDA. Could CUDA be ignoring that?

Reply:
The CUDA approach to matmul is pretty different (see #851 (comment)). Anecdotally it seems to be less prone to precision issues, but I guess it can still run into problems.

return x;
}
};
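For a pure matmul the rescaling above is exact: ggml_mul_mat is linear in its input, so (W · (x · s)) / s reproduces W · x up to rounding, while the shrunken inputs keep intermediate low-precision products far below the fp16 maximum of 65504. A minimal sketch of the pattern, assuming only the public ggml API (the helper name is hypothetical, and the exact overflow mechanism is the assumption discussed in the thread above):

```cpp
// Sketch: scale-compensated matmul. scaled_mul_mat is a hypothetical
// helper, not part of this PR.
#include "ggml.h"

static struct ggml_tensor* scaled_mul_mat(struct ggml_context* ctx,
                                          struct ggml_tensor* w,   // weights, possibly fp16/k-quant
                                          struct ggml_tensor* x) { // f32 activations
    const float scale = 1.f / 128.f;
    // Shrink activations so intermediate products in the low-precision
    // matmul path stay well below the fp16 maximum (65504).
    x = ggml_scale(ctx, x, scale);
    struct ggml_tensor* y = ggml_mul_mat(ctx, w, x);
    // ggml_mul_mat is linear in x, so undoing the scale on the output
    // recovers the unscaled result up to rounding error.
    return ggml_scale(ctx, y, 1.f / scale);
}
```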
192 changes: 138 additions & 54 deletions conditioner.hpp
@@ -2,6 +2,7 @@
#define __CONDITIONER_HPP__

#include "clip.hpp"
#include "qwenvl.hpp"
#include "t5.hpp"

struct SDCondition {
@@ -22,11 +23,11 @@ struct Conditioner {
int width,
int height,
int adm_in_channels = -1,
bool zero_out_masked = false) = 0;
virtual void alloc_params_buffer() = 0;
virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
const std::string& text,
Expand All @@ -35,9 +36,13 @@ struct Conditioner {
int height,
int num_input_imgs,
int adm_in_channels = -1,
-bool zero_out_masked = false) = 0;
+bool zero_out_masked = false) {
+GGML_ABORT("Not implemented yet!");
+}
virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx,
const std::string& prompt) = 0;
const std::string& prompt) {
GGML_ABORT("Not implemented yet!");
}
};

// ldm.modules.encoders.modules.FrozenCLIPEmbedder
@@ -978,23 +983,6 @@ struct SD3CLIPEmbedder : public Conditioner {
auto tokens_and_weights = tokenize(text, 77, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
}

-std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
-int n_threads,
-const std::string& text,
-int clip_skip,
-int width,
-int height,
-int num_input_imgs,
-int adm_in_channels = -1,
-bool zero_out_masked = false) {
-GGML_ASSERT(0 && "Not implemented yet!");
-}
-
-std::string remove_trigger_from_prompt(ggml_context* work_ctx,
-const std::string& prompt) {
-GGML_ASSERT(0 && "Not implemented yet!");
-}
};

struct FluxCLIPEmbedder : public Conditioner {
@@ -1195,23 +1183,6 @@ struct FluxCLIPEmbedder : public Conditioner {
auto tokens_and_weights = tokenize(text, chunk_len, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
}

-std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
-int n_threads,
-const std::string& text,
-int clip_skip,
-int width,
-int height,
-int num_input_imgs,
-int adm_in_channels = -1,
-bool zero_out_masked = false) {
-GGML_ASSERT(0 && "Not implemented yet!");
-}
-
-std::string remove_trigger_from_prompt(ggml_context* work_ctx,
-const std::string& prompt) {
-GGML_ASSERT(0 && "Not implemented yet!");
-}
};

struct T5CLIPEmbedder : public Conditioner {
@@ -1398,22 +1369,135 @@ struct T5CLIPEmbedder : public Conditioner {
auto tokens_and_weights = tokenize(text, chunk_len, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
}
-
-std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
-int n_threads,
-const std::string& text,
-int clip_skip,
-int width,
-int height,
-int num_input_imgs,
-int adm_in_channels = -1,
-bool zero_out_masked = false) {
-GGML_ASSERT(0 && "Not implemented yet!");
-}
-
-std::string remove_trigger_from_prompt(ggml_context* work_ctx,
-const std::string& prompt) {
-GGML_ASSERT(0 && "Not implemented yet!");
-}
};
+
+struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
+Qwen::Qwen2Tokenizer tokenizer;
+std::shared_ptr<Qwen::Qwen2_5_VLRunner> qwenvl;
+int prompt_template_encode_start_idx = 34;
+
+Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
+bool offload_params_to_cpu,
+const String2GGMLType& tensor_types = {},
+const std::string prefix = "") {
+qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.qwen2vl");
+}
+
+void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
+qwenvl->get_param_tensors(tensors, "text_encoders.qwen2vl");
+}
+
+void alloc_params_buffer() {
+qwenvl->alloc_params_buffer();
+}
+
+void free_params_buffer() {
+qwenvl->free_params_buffer();
+}
+
+size_t get_params_buffer_size() {
+size_t buffer_size = 0;
+buffer_size += qwenvl->get_params_buffer_size();
+return buffer_size;
+}
+
+std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
+size_t max_length = 0,
+bool padding = false) {
+auto parsed_attention = parse_prompt_attention(text);
+
+{
+std::stringstream ss;
+ss << "[";
+for (const auto& item : parsed_attention) {
+ss << "['" << item.first << "', " << item.second << "], ";
+}
+ss << "]";
+LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
+}
+
+std::vector<int> tokens;
+std::vector<float> weights;
+for (const auto& item : parsed_attention) {
+const std::string& curr_text = item.first;
+float curr_weight = item.second;
+std::vector<int> curr_tokens = tokenizer.tokenize(curr_text, nullptr);
+tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
+weights.insert(weights.end(), curr_tokens.size(), curr_weight);
+}
+
+tokenizer.pad_tokens(tokens, weights, max_length, padding);
+
+// for (int i = 0; i < tokens.size(); i++) {
+//     std::cout << tokens[i] << ":" << weights[i] << ", ";
+// }
+// std::cout << std::endl;
+
+return {tokens, weights};
+}
+
+SDCondition get_learned_condition_common(ggml_context* work_ctx,
+int n_threads,
+std::tuple<std::vector<int>, std::vector<float>> token_and_weights,
+int clip_skip,
+bool zero_out_masked = false) {
+auto& tokens = std::get<0>(token_and_weights);
+auto& weights = std::get<1>(token_and_weights);
+
+int64_t t0 = ggml_time_ms();
+struct ggml_tensor* hidden_states = NULL; // [N, n_token, 3584]
+
+auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
+
+qwenvl->compute(n_threads,
+input_ids,
+&hidden_states,
+work_ctx);
+{
+auto tensor = hidden_states;
+float original_mean = ggml_tensor_mean(tensor);
+for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
+value *= weights[i1];
+ggml_tensor_set_f32(tensor, value, i0, i1, i2);
+}
+}
+}
+float new_mean = ggml_tensor_mean(tensor);
+ggml_tensor_scale(tensor, (original_mean / new_mean));
+}
+
+GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
+
+ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
+GGML_TYPE_F32,
+hidden_states->ne[0],
+hidden_states->ne[1] - prompt_template_encode_start_idx,
+hidden_states->ne[2]);
+
+ggml_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+float value = ggml_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
+ggml_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
+});
+
+int64_t t1 = ggml_time_ms();
+LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
+return SDCondition(new_hidden_states, nullptr, nullptr);
+}
+
+SDCondition get_learned_condition(ggml_context* work_ctx,
+int n_threads,
+const std::string& text,
+int clip_skip,
+int width,
+int height,
+int adm_in_channels = -1,
+bool zero_out_masked = false) {
+std::string prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + text + "<|im_end|>\n<|im_start|>assistant\n";
+auto tokens_and_weights = tokenize(prompt, 0, false);
+return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
+}
+};

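Like the other conditioners, Qwen2_5_VLCLIPEmbedder applies parse_prompt_attention weights per token and then rescales the hidden states so their overall mean matches the unweighted mean. A toy, self-contained illustration of that mean-preserving reweighting (plain C++; the values are made up, with one scalar standing in for each token's embedding):

```cpp
// Toy illustration of mean-preserving prompt weighting: scale each token
// by its parsed weight, then rescale everything so the overall mean
// matches the unweighted one.
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> emb = {1.0f, 2.0f, 3.0f, 4.0f}; // one value per token
    std::vector<float> w   = {1.0f, 1.2f, 1.2f, 1.0f}; // attention weights

    float orig_mean = 0.f, new_mean = 0.f;
    for (float v : emb) orig_mean += v;
    orig_mean /= emb.size();

    for (size_t i = 0; i < emb.size(); i++) emb[i] *= w[i]; // apply weights
    for (float v : emb) new_mean += v;
    new_mean /= emb.size();

    for (float& v : emb) v *= orig_mean / new_mean; // restore original mean
    for (float v : emb) printf("%.3f ", v);
    printf("\n");
    return 0;
}
```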
55 changes: 55 additions & 0 deletions diffusion_model.hpp
@@ -3,6 +3,7 @@

#include "flux.hpp"
#include "mmdit.hpp"
#include "qwen_image.hpp"
#include "unet.hpp"
#include "wan.hpp"

@@ -263,4 +264,58 @@ struct WanModel : public DiffusionModel {
}
};

+struct QwenImageModel : public DiffusionModel {
+std::string prefix;
+Qwen::QwenImageRunner qwen_image;
+
+QwenImageModel(ggml_backend_t backend,
+bool offload_params_to_cpu,
+const String2GGMLType& tensor_types = {},
+const std::string prefix = "model.diffusion_model",
+SDVersion version = VERSION_QWEN_IMAGE,
+bool flash_attn = false)
+: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
+}
+
+std::string get_desc() {
+return qwen_image.get_desc();
+}
+
+void alloc_params_buffer() {
+qwen_image.alloc_params_buffer();
+}
+
+void free_params_buffer() {
+qwen_image.free_params_buffer();
+}
+
+void free_compute_buffer() {
+qwen_image.free_compute_buffer();
+}
+
+void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
+qwen_image.get_param_tensors(tensors, prefix);
+}
+
+size_t get_params_buffer_size() {
+return qwen_image.get_params_buffer_size();
+}
+
+int64_t get_adm_in_channels() {
+return 768;
+}
+
+void compute(int n_threads,
+DiffusionParams diffusion_params,
+struct ggml_tensor** output = NULL,
+struct ggml_context* output_ctx = NULL) {
+return qwen_image.compute(n_threads,
+diffusion_params.x,
+diffusion_params.timesteps,
+diffusion_params.context,
+output,
+output_ctx);
+}
+};

#endif