Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,30 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
const std::string& curr_text = item.first;
float curr_weight = item.second;
// printf(" %s: %f \n", curr_text.c_str(), curr_weight);
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
int32_t clean_index = 0;
if(curr_text == "BREAK" && curr_weight == -1.0f) {
// Pad token array up to chunk size at this point.
// TODO: chunk_len is hardcoded here, as it is in stable-diffusion.cpp; consider making it a parameter in the future.
// Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
int padding_size = 75 - (tokens_acc % 75);
for (int j = 0; j < padding_size; j++) {
clean_input_ids.push_back(tokenizer.EOS_TOKEN_ID);
clean_index++;
}

// After padding, continue to the next iteration to process the following text as a new segment
tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end());
weights.insert(weights.end(), padding_size, curr_weight);
continue;
}

// Regular token, process normally
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
for (uint32_t i = 0; i < curr_tokens.size(); i++) {
int token_id = curr_tokens[i];
if (token_id == image_token)
if (token_id == image_token) {
class_token_index.push_back(clean_index - 1);
else {
} else {
clean_input_ids.push_back(token_id);
clean_index++;
}
Expand Down Expand Up @@ -379,6 +396,22 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
for (const auto& item : parsed_attention) {
const std::string& curr_text = item.first;
float curr_weight = item.second;

if(curr_text == "BREAK" && curr_weight == -1.0f) {
// Pad token array up to chunk size at this point.
// TODO: chunk_len is hardcoded here, as it is in stable-diffusion.cpp; consider making it a parameter in the future.
// Also, this is 75 instead of 77 to leave room for BOS and EOS tokens.
size_t current_size = tokens.size();
size_t padding_size = (75 - (current_size % 75)) % 75; // Ensure no negative padding

if (padding_size > 0) {
LOG_DEBUG("BREAK token encountered, padding current chunk by %zu tokens.", padding_size);
tokens.insert(tokens.end(), padding_size, tokenizer.EOS_TOKEN_ID);
weights.insert(weights.end(), padding_size, 1.0f);
}
continue; // Skip to the next item after handling BREAK
}

std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
Expand Down
11 changes: 8 additions & 3 deletions util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <codecvt>
#include <fstream>
#include <locale>
#include <regex>
#include <sstream>
#include <string>
#include <thread>
Expand Down Expand Up @@ -513,6 +514,8 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int targe
// (abc) - increases attention to abc by a multiplier of 1.1
// (abc:3.12) - increases attention to abc by a multiplier of 3.12
// [abc] - decreases attention to abc by a multiplier of 1.1
// BREAK - separates the prompt into conceptually distinct parts for sequential processing
// B - internal helper pattern; prevents 'B' in 'BREAK' from being consumed as normal text
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is enough to make the regex a little less magical. Thanks!

// \( - literal character '('
// \[ - literal character '['
// \) - literal character ')'
Expand Down Expand Up @@ -548,7 +551,7 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
float round_bracket_multiplier = 1.1f;
float square_bracket_multiplier = 1 / 1.1f;

std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|[^\\()\[\]:]+|:)");
std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|\bBREAK\b|[^\\()\[\]:B]+|:|\bB)");
std::regex re_break(R"(\s*\bBREAK\b\s*)");

auto multiply_range = [&](int start_position, float multiplier) {
Expand All @@ -557,7 +560,7 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
}
};

std::smatch m;
std::smatch m,m2;
std::string remaining_text = text;

while (std::regex_search(remaining_text, m, re_attention)) {
Expand All @@ -581,6 +584,8 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
square_brackets.pop_back();
} else if (text == "\\(") {
res.push_back({text.substr(1), 1.0f});
} else if (std::regex_search(text, m2, re_break)) {
res.push_back({"BREAK", -1.0f});
} else {
res.push_back({text, 1.0f});
}
Expand Down Expand Up @@ -611,4 +616,4 @@ std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::str
}

return res;
}
}
Loading