conditioner.hpp (312 changes: 251 additions & 61 deletions)
@@ -196,27 +196,91 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
}

std::vector<int> convert_token_to_id(std::string text) {
size_t search_pos = 0;
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
size_t word_end = str.find(",");
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
embd_name = trim(embd_name);
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
std::string token_str;
size_t consumed_len = 0;
bool is_embed_tag = false;

// The tokenizer gives us chunks of text. We only process the first potential embedding token in that chunk.
std::string trimmed_str = trim(str);
size_t leading_spaces = str.length() - trimmed_str.length();

if (starts_with(trimmed_str, "<embed:")) {
size_t tag_end = trimmed_str.find(">");
if (tag_end == std::string::npos) {
return false; // Incomplete tag.
}
std::string lower_tag = trimmed_str.substr(0, tag_end + 1);
token_str = lower_tag; // Fallback to lowercased version

if (text.length() >= lower_tag.length()) {
for (size_t i = search_pos; i <= text.length() - lower_tag.length(); ++i) {
bool match = true;
for (size_t j = 0; j < lower_tag.length(); ++j) {
if (std::tolower(text[i + j]) != lower_tag[j]) {
match = false;
break;
}
}
if (match) {
token_str = text.substr(i, lower_tag.length());
search_pos = i + token_str.length();
break;
}
}
}
consumed_len = leading_spaces + token_str.length();
is_embed_tag = true;
} else {
// Not a tag. Could be a plain trigger word.
size_t first_delim = trimmed_str.find_first_of(" ,");
token_str = (first_delim == std::string::npos) ? trimmed_str : trimmed_str.substr(0, first_delim);
consumed_len = leading_spaces + token_str.length();
}

std::string embd_name = trim(token_str);
if (is_embed_tag) {
embd_name = embd_name.substr(strlen("<embed:"), embd_name.length() - strlen("<embed:") - 1);
}
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");

std::string embd_path;
bool is_path = contains(embd_name, "/") || contains(embd_name, "\\");

if (is_path) {
if (file_exists(embd_name)) {
embd_path = embd_name;
} else if (file_exists(embd_name + ".safetensors")) {
embd_path = embd_name + ".safetensors";
} else if (file_exists(embd_name + ".pt")) {
embd_path = embd_name + ".pt";
} else if (file_exists(embd_name + ".ckpt")) {
embd_path = embd_name + ".ckpt";
}
} else {
embd_path = get_full_path(embd_dir, embd_name + ".pt");
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
}
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
}
}

if (embd_path.size() > 0) {
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
if (word_end != std::string::npos) {
str = str.substr(word_end);
} else {
str = "";
}
str = str.substr(consumed_len);
return true;
}
}

if (is_embed_tag) {
LOG_WARN("could not load embedding '%s'", embd_name.c_str());
str = str.substr(consumed_len);
return true; // Consume the failed tag so the tokenizer doesn't try to parse it as text.
}

// It was not a tag and we couldn't find a file for it as a trigger word.
return false;
};
std::vector<int> curr_tokens = tokenizer.encode(text, on_new_token_cb);
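Note: the nested loop above does a manual case-insensitive search because the callback receives lowercased text from the tokenizer, while the embedding file name on disk may be mixed-case; the scan recovers the tag's original casing from the source prompt. A minimal standalone sketch of that matching step, under a hypothetical helper name (find_original_tag is not part of the patch):

#include <cctype>
#include <string>

// Given the full prompt and a lowercased tag such as "<embed:easynegative>",
// return the original-case substring (e.g. "<embed:EasyNegative>") found at or
// after search_pos, advancing search_pos past the match. Falls back to the
// lowercased tag when no match exists, as the patch does.
static std::string find_original_tag(const std::string& text,
                                     const std::string& lower_tag,
                                     size_t& search_pos) {
    if (text.length() >= lower_tag.length()) {
        for (size_t i = search_pos; i <= text.length() - lower_tag.length(); ++i) {
            bool match = true;
            for (size_t j = 0; j < lower_tag.length(); ++j) {
                if (std::tolower(static_cast<unsigned char>(text[i + j])) !=
                    lower_tag[j]) {
                    match = false;
                    break;
                }
            }
            if (match) {
                search_pos = i + lower_tag.length();
                return text.substr(i, lower_tag.length());
            }
        }
    }
    return lower_tag;
}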
@@ -245,30 +309,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
}

auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
size_t word_end = str.find(",");
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
embd_name = trim(embd_name);
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
}
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
}
if (embd_path.size() > 0) {
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
if (word_end != std::string::npos) {
str = str.substr(word_end);
} else {
str = "";
}
return true;
}
}
return false;
};

std::vector<int> tokens;
std::vector<float> weights;
std::vector<bool> class_token_mask;
@@ -278,6 +318,93 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
std::vector<int> clean_input_ids;
const std::string& curr_text = item.first;
float curr_weight = item.second;
size_t search_pos = 0;
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
std::string token_str;
size_t consumed_len = 0;
bool is_embed_tag = false;

// The tokenizer gives us chunks of text. We only process the first potential embedding token in that chunk.
std::string trimmed_str = trim(str);
size_t leading_spaces = str.length() - trimmed_str.length();

if (starts_with(trimmed_str, "<embed:")) {
size_t tag_end = trimmed_str.find(">");
if (tag_end == std::string::npos) {
return false; // Incomplete tag.
}
std::string lower_tag = trimmed_str.substr(0, tag_end + 1);
token_str = lower_tag; // Fallback to lowercased version

if (curr_text.length() >= lower_tag.length()) {
for (size_t i = search_pos; i <= curr_text.length() - lower_tag.length(); ++i) {
bool match = true;
for (size_t j = 0; j < lower_tag.length(); ++j) {
if (std::tolower(curr_text[i + j]) != lower_tag[j]) {
match = false;
break;
}
}
if (match) {
token_str = curr_text.substr(i, lower_tag.length());
search_pos = i + token_str.length();
break;
}
}
}
consumed_len = leading_spaces + token_str.length();
is_embed_tag = true;
} else {
// Not a tag. Could be a plain trigger word.
size_t first_delim = trimmed_str.find_first_of(" ,");
token_str = (first_delim == std::string::npos) ? trimmed_str : trimmed_str.substr(0, first_delim);
consumed_len = leading_spaces + token_str.length();
}

std::string embd_name = trim(token_str);
if (is_embed_tag) {
embd_name = embd_name.substr(strlen("<embed:"), embd_name.length() - strlen("<embed:") - 1);
}

std::string embd_path;
bool is_path = contains(embd_name, "/") || contains(embd_name, "\\");

if (is_path) {
if (file_exists(embd_name)) {
embd_path = embd_name;
} else if (file_exists(embd_name + ".safetensors")) {
embd_path = embd_name + ".safetensors";
} else if (file_exists(embd_name + ".pt")) {
embd_path = embd_name + ".pt";
} else if (file_exists(embd_name + ".ckpt")) {
embd_path = embd_name + ".ckpt";
}
} else {
embd_path = get_full_path(embd_dir, embd_name + ".pt");
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
}
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
}
}

if (embd_path.size() > 0) {
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
str = str.substr(consumed_len);
return true;
}
}

if (is_embed_tag) {
LOG_WARN("could not load embedding '%s'", embd_name.c_str());
str = str.substr(consumed_len);
return true; // Consume the failed tag so the tokenizer doesn't try to parse it as text.
}

// It was not a tag and we couldn't find a file for it as a trigger word.
return false;
};
// printf(" %s: %f \n", curr_text.c_str(), curr_weight);
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
int32_t clean_index = 0;
@@ -359,35 +486,98 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
}

auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
size_t word_end = str.find(",");
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
embd_name = trim(embd_name);
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
}
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
}
if (embd_path.size() > 0) {
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
if (word_end != std::string::npos) {
str = str.substr(word_end);
} else {
str = "";
}
return true;
}
}
return false;
};

std::vector<int> tokens;
std::vector<float> weights;
for (const auto& item : parsed_attention) {
const std::string& curr_text = item.first;
float curr_weight = item.second;
size_t search_pos = 0;
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
std::string token_str;
size_t consumed_len = 0;
bool is_embed_tag = false;

// The tokenizer gives us chunks of text. We only process the first potential embedding token in that chunk.
std::string trimmed_str = trim(str);
size_t leading_spaces = str.length() - trimmed_str.length();

if (starts_with(trimmed_str, "<embed:")) {
size_t tag_end = trimmed_str.find(">");
if (tag_end == std::string::npos) {
return false; // Incomplete tag.
}
std::string lower_tag = trimmed_str.substr(0, tag_end + 1);
token_str = lower_tag; // Fallback to lowercased version

if (curr_text.length() >= lower_tag.length()) {
for (size_t i = search_pos; i <= curr_text.length() - lower_tag.length(); ++i) {
bool match = true;
for (size_t j = 0; j < lower_tag.length(); ++j) {
if (std::tolower(curr_text[i + j]) != lower_tag[j]) {
match = false;
break;
}
}
if (match) {
token_str = curr_text.substr(i, lower_tag.length());
search_pos = i + token_str.length();
break;
}
}
}
consumed_len = leading_spaces + token_str.length();
is_embed_tag = true;
} else {
// Not a tag. Could be a plain trigger word.
size_t first_delim = trimmed_str.find_first_of(" ,");
token_str = (first_delim == std::string::npos) ? trimmed_str : trimmed_str.substr(0, first_delim);
consumed_len = leading_spaces + token_str.length();
}

std::string embd_name = trim(token_str);
if (is_embed_tag) {
embd_name = embd_name.substr(strlen("<embed:"), embd_name.length() - strlen("<embed:") - 1);
}

std::string embd_path;
bool is_path = contains(embd_name, "/") || contains(embd_name, "\\");

if (is_path) {
if (file_exists(embd_name)) {
embd_path = embd_name;
} else if (file_exists(embd_name + ".safetensors")) {
embd_path = embd_name + ".safetensors";
} else if (file_exists(embd_name + ".pt")) {
embd_path = embd_name + ".pt";
} else if (file_exists(embd_name + ".ckpt")) {
embd_path = embd_name + ".ckpt";
}
} else {
embd_path = get_full_path(embd_dir, embd_name + ".pt");
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
}
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
}
}

if (embd_path.size() > 0) {
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
str = str.substr(consumed_len);
return true;
}
}

if (is_embed_tag) {
LOG_WARN("could not load embedding '%s'", embd_name.c_str());
str = str.substr(consumed_len);
return true; // Consume the failed tag so the tokenizer doesn't try to parse it as text.
}

// It was not a tag and we couldn't find a file for it as a trigger word.
return false;
};
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
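For reference, the file-lookup order shared by all three copies of the callback can be factored as below. contains, file_exists, and get_full_path are the helpers the patch already calls (signatures assumed for this sketch); resolve_embedding_path is a hypothetical name for illustration only:

#include <string>

// Helpers from the surrounding codebase; signatures assumed for this sketch.
bool contains(const std::string& haystack, const std::string& needle);
bool file_exists(const std::string& path);
std::string get_full_path(const std::string& dir, const std::string& filename);

// Mirrors the lookup order in the patch: an embd_name containing a path
// separator is treated as an explicit file path, tried as-is and then with
// each known extension; a plain name is searched inside embd_dir.
static std::string resolve_embedding_path(const std::string& embd_dir,
                                          const std::string& embd_name) {
    bool is_path = contains(embd_name, "/") || contains(embd_name, "\\");
    if (is_path) {
        for (const std::string& candidate : {embd_name,
                                             embd_name + ".safetensors",
                                             embd_name + ".pt",
                                             embd_name + ".ckpt"}) {
            if (file_exists(candidate)) {
                return candidate;
            }
        }
        return "";
    }
    std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
    if (embd_path.empty()) {
        embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
    }
    if (embd_path.empty()) {
        embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
    }
    return embd_path;
}

Note the two branches probe extensions in different orders (.safetensors first for explicit paths, last for directory lookups), matching the patch as written. A prompt would then reference an embedding as, e.g., "a photo of a cat, <embed:EasyNegative>" (file name illustrative), resolved from embd_dir, or as "<embed:path/to/style.safetensors>" to load a file directly.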
stable-diffusion.cpp (25 changes: 21 additions & 4 deletions)
@@ -810,17 +810,34 @@ class StableDiffusionGGML {
is_high_noise = true;
LOG_DEBUG("high noise lora: %s", lora_name.c_str());
}
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
std::string st_file_path;
std::string ckpt_file_path;
std::string file_path;
if (file_exists(st_file_path)) {
bool is_path = contains(lora_name, "/") || contains(lora_name, "\\");

if (is_path) {
st_file_path = lora_name + ".safetensors";
ckpt_file_path = lora_name + ".ckpt";
} else {
st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
}

if (is_path && file_exists(lora_name)) {
file_path = lora_name;
} else if (file_exists(st_file_path)) {
file_path = st_file_path;
} else if (file_exists(ckpt_file_path)) {
file_path = ckpt_file_path;
} else {
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
if (is_path) {
LOG_WARN("can not find lora file %s, %s or %s", lora_name.c_str(), st_file_path.c_str(), ckpt_file_path.c_str());
} else {
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
}
return;
}

LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : "");
if (!lora.load_from_file()) {
LOG_WARN("load lora tensors from %s failed", file_path.c_str());