Skip to content

Commit bc1551b

Browse files
committed
allow the use of absolute paths for lora and embeddings
1 parent fce6afc commit bc1551b

File tree

2 files changed

+283
-91
lines changed

2 files changed

+283
-91
lines changed

conditioner.hpp

Lines changed: 251 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -196,27 +196,91 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
196196
}
197197

198198
std::vector<int> convert_token_to_id(std::string text) {
199+
size_t search_pos = 0;
199200
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
200-
size_t word_end = str.find(",");
201-
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
202-
embd_name = trim(embd_name);
203-
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
204-
if (embd_path.size() == 0) {
205-
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
201+
std::string token_str;
202+
size_t consumed_len = 0;
203+
bool is_embed_tag = false;
204+
205+
// The tokenizer gives us chunks of text. We only process the first potential embedding token in that chunk.
206+
std::string trimmed_str = trim(str);
207+
size_t leading_spaces = str.length() - trimmed_str.length();
208+
209+
if (starts_with(trimmed_str, "<embed:")) {
210+
size_t tag_end = trimmed_str.find(">");
211+
if (tag_end == std::string::npos) {
212+
return false; // Incomplete tag.
213+
}
214+
std::string lower_tag = trimmed_str.substr(0, tag_end + 1);
215+
token_str = lower_tag; // Fallback to lowercased version
216+
217+
if (text.length() >= lower_tag.length()) {
218+
for (size_t i = search_pos; i <= text.length() - lower_tag.length(); ++i) {
219+
bool match = true;
220+
for (size_t j = 0; j < lower_tag.length(); ++j) {
221+
if (std::tolower(text[i + j]) != lower_tag[j]) {
222+
match = false;
223+
break;
224+
}
225+
}
226+
if (match) {
227+
token_str = text.substr(i, lower_tag.length());
228+
search_pos = i + token_str.length();
229+
break;
230+
}
231+
}
232+
}
233+
consumed_len = leading_spaces + token_str.length();
234+
is_embed_tag = true;
235+
} else {
236+
// Not a tag. Could be a plain trigger word.
237+
size_t first_delim = trimmed_str.find_first_of(" ,");
238+
token_str = (first_delim == std::string::npos) ? trimmed_str : trimmed_str.substr(0, first_delim);
239+
consumed_len = leading_spaces + token_str.length();
240+
}
241+
242+
std::string embd_name = trim(token_str);
243+
if (is_embed_tag) {
244+
embd_name = embd_name.substr(strlen("<embed:"), embd_name.length() - strlen("<embed:") - 1);
206245
}
207-
if (embd_path.size() == 0) {
208-
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
246+
247+
std::string embd_path;
248+
bool is_path = contains(embd_name, "/") || contains(embd_name, "\\");
249+
250+
if (is_path) {
251+
if (file_exists(embd_name)) {
252+
embd_path = embd_name;
253+
} else if (file_exists(embd_name + ".safetensors")) {
254+
embd_path = embd_name + ".safetensors";
255+
} else if (file_exists(embd_name + ".pt")) {
256+
embd_path = embd_name + ".pt";
257+
} else if (file_exists(embd_name + ".ckpt")) {
258+
embd_path = embd_name + ".ckpt";
259+
}
260+
} else {
261+
embd_path = get_full_path(embd_dir, embd_name + ".pt");
262+
if (embd_path.size() == 0) {
263+
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
264+
}
265+
if (embd_path.size() == 0) {
266+
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
267+
}
209268
}
269+
210270
if (embd_path.size() > 0) {
211271
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
212-
if (word_end != std::string::npos) {
213-
str = str.substr(word_end);
214-
} else {
215-
str = "";
216-
}
272+
str = str.substr(consumed_len);
217273
return true;
218274
}
219275
}
276+
277+
if (is_embed_tag) {
278+
LOG_WARN("could not load embedding '%s'", embd_name.c_str());
279+
str = str.substr(consumed_len);
280+
return true; // Consume the failed tag so the tokenizer doesn't try to parse it as text.
281+
}
282+
283+
// It was not a tag and we couldn't find a file for it as a trigger word.
220284
return false;
221285
};
222286
std::vector<int> curr_tokens = tokenizer.encode(text, on_new_token_cb);
@@ -245,30 +309,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
245309
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
246310
}
247311

248-
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
249-
size_t word_end = str.find(",");
250-
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
251-
embd_name = trim(embd_name);
252-
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
253-
if (embd_path.size() == 0) {
254-
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
255-
}
256-
if (embd_path.size() == 0) {
257-
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
258-
}
259-
if (embd_path.size() > 0) {
260-
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
261-
if (word_end != std::string::npos) {
262-
str = str.substr(word_end);
263-
} else {
264-
str = "";
265-
}
266-
return true;
267-
}
268-
}
269-
return false;
270-
};
271-
272312
std::vector<int> tokens;
273313
std::vector<float> weights;
274314
std::vector<bool> class_token_mask;
@@ -278,6 +318,93 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
278318
std::vector<int> clean_input_ids;
279319
const std::string& curr_text = item.first;
280320
float curr_weight = item.second;
321+
size_t search_pos = 0;
322+
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
323+
std::string token_str;
324+
size_t consumed_len = 0;
325+
bool is_embed_tag = false;
326+
327+
// The tokenizer gives us chunks of text. We only process the first potential embedding token in that chunk.
328+
std::string trimmed_str = trim(str);
329+
size_t leading_spaces = str.length() - trimmed_str.length();
330+
331+
if (starts_with(trimmed_str, "<embed:")) {
332+
size_t tag_end = trimmed_str.find(">");
333+
if (tag_end == std::string::npos) {
334+
return false; // Incomplete tag.
335+
}
336+
std::string lower_tag = trimmed_str.substr(0, tag_end + 1);
337+
token_str = lower_tag; // Fallback to lowercased version
338+
339+
if (curr_text.length() >= lower_tag.length()) {
340+
for (size_t i = search_pos; i <= curr_text.length() - lower_tag.length(); ++i) {
341+
bool match = true;
342+
for (size_t j = 0; j < lower_tag.length(); ++j) {
343+
if (std::tolower(curr_text[i + j]) != lower_tag[j]) {
344+
match = false;
345+
break;
346+
}
347+
}
348+
if (match) {
349+
token_str = curr_text.substr(i, lower_tag.length());
350+
search_pos = i + token_str.length();
351+
break;
352+
}
353+
}
354+
}
355+
consumed_len = leading_spaces + token_str.length();
356+
is_embed_tag = true;
357+
} else {
358+
// Not a tag. Could be a plain trigger word.
359+
size_t first_delim = trimmed_str.find_first_of(" ,");
360+
token_str = (first_delim == std::string::npos) ? trimmed_str : trimmed_str.substr(0, first_delim);
361+
consumed_len = leading_spaces + token_str.length();
362+
}
363+
364+
std::string embd_name = trim(token_str);
365+
if (is_embed_tag) {
366+
embd_name = embd_name.substr(strlen("<embed:"), embd_name.length() - strlen("<embed:") - 1);
367+
}
368+
369+
std::string embd_path;
370+
bool is_path = contains(embd_name, "/") || contains(embd_name, "\\");
371+
372+
if (is_path) {
373+
if (file_exists(embd_name)) {
374+
embd_path = embd_name;
375+
} else if (file_exists(embd_name + ".safetensors")) {
376+
embd_path = embd_name + ".safetensors";
377+
} else if (file_exists(embd_name + ".pt")) {
378+
embd_path = embd_name + ".pt";
379+
} else if (file_exists(embd_name + ".ckpt")) {
380+
embd_path = embd_name + ".ckpt";
381+
}
382+
} else {
383+
embd_path = get_full_path(embd_dir, embd_name + ".pt");
384+
if (embd_path.size() == 0) {
385+
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
386+
}
387+
if (embd_path.size() == 0) {
388+
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
389+
}
390+
}
391+
392+
if (embd_path.size() > 0) {
393+
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
394+
str = str.substr(consumed_len);
395+
return true;
396+
}
397+
}
398+
399+
if (is_embed_tag) {
400+
LOG_WARN("could not load embedding '%s'", embd_name.c_str());
401+
str = str.substr(consumed_len);
402+
return true; // Consume the failed tag so the tokenizer doesn't try to parse it as text.
403+
}
404+
405+
// It was not a tag and we couldn't find a file for it as a trigger word.
406+
return false;
407+
};
281408
// printf(" %s: %f \n", curr_text.c_str(), curr_weight);
282409
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
283410
int32_t clean_index = 0;
@@ -359,35 +486,98 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
359486
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
360487
}
361488

362-
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
363-
size_t word_end = str.find(",");
364-
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
365-
embd_name = trim(embd_name);
366-
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
367-
if (embd_path.size() == 0) {
368-
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
369-
}
370-
if (embd_path.size() == 0) {
371-
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
372-
}
373-
if (embd_path.size() > 0) {
374-
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
375-
if (word_end != std::string::npos) {
376-
str = str.substr(word_end);
377-
} else {
378-
str = "";
379-
}
380-
return true;
381-
}
382-
}
383-
return false;
384-
};
385-
386489
std::vector<int> tokens;
387490
std::vector<float> weights;
388491
for (const auto& item : parsed_attention) {
389492
const std::string& curr_text = item.first;
390493
float curr_weight = item.second;
494+
size_t search_pos = 0;
495+
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
496+
std::string token_str;
497+
size_t consumed_len = 0;
498+
bool is_embed_tag = false;
499+
500+
// The tokenizer gives us chunks of text. We only process the first potential embedding token in that chunk.
501+
std::string trimmed_str = trim(str);
502+
size_t leading_spaces = str.length() - trimmed_str.length();
503+
504+
if (starts_with(trimmed_str, "<embed:")) {
505+
size_t tag_end = trimmed_str.find(">");
506+
if (tag_end == std::string::npos) {
507+
return false; // Incomplete tag.
508+
}
509+
std::string lower_tag = trimmed_str.substr(0, tag_end + 1);
510+
token_str = lower_tag; // Fallback to lowercased version
511+
512+
if (curr_text.length() >= lower_tag.length()) {
513+
for (size_t i = search_pos; i <= curr_text.length() - lower_tag.length(); ++i) {
514+
bool match = true;
515+
for (size_t j = 0; j < lower_tag.length(); ++j) {
516+
if (std::tolower(curr_text[i + j]) != lower_tag[j]) {
517+
match = false;
518+
break;
519+
}
520+
}
521+
if (match) {
522+
token_str = curr_text.substr(i, lower_tag.length());
523+
search_pos = i + token_str.length();
524+
break;
525+
}
526+
}
527+
}
528+
consumed_len = leading_spaces + token_str.length();
529+
is_embed_tag = true;
530+
} else {
531+
// Not a tag. Could be a plain trigger word.
532+
size_t first_delim = trimmed_str.find_first_of(" ,");
533+
token_str = (first_delim == std::string::npos) ? trimmed_str : trimmed_str.substr(0, first_delim);
534+
consumed_len = leading_spaces + token_str.length();
535+
}
536+
537+
std::string embd_name = trim(token_str);
538+
if (is_embed_tag) {
539+
embd_name = embd_name.substr(strlen("<embed:"), embd_name.length() - strlen("<embed:") - 1);
540+
}
541+
542+
std::string embd_path;
543+
bool is_path = contains(embd_name, "/") || contains(embd_name, "\\");
544+
545+
if (is_path) {
546+
if (file_exists(embd_name)) {
547+
embd_path = embd_name;
548+
} else if (file_exists(embd_name + ".safetensors")) {
549+
embd_path = embd_name + ".safetensors";
550+
} else if (file_exists(embd_name + ".pt")) {
551+
embd_path = embd_name + ".pt";
552+
} else if (file_exists(embd_name + ".ckpt")) {
553+
embd_path = embd_name + ".ckpt";
554+
}
555+
} else {
556+
embd_path = get_full_path(embd_dir, embd_name + ".pt");
557+
if (embd_path.size() == 0) {
558+
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
559+
}
560+
if (embd_path.size() == 0) {
561+
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
562+
}
563+
}
564+
565+
if (embd_path.size() > 0) {
566+
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
567+
str = str.substr(consumed_len);
568+
return true;
569+
}
570+
}
571+
572+
if (is_embed_tag) {
573+
LOG_WARN("could not load embedding '%s'", embd_name.c_str());
574+
str = str.substr(consumed_len);
575+
return true; // Consume the failed tag so the tokenizer doesn't try to parse it as text.
576+
}
577+
578+
// It was not a tag and we couldn't find a file for it as a trigger word.
579+
return false;
580+
};
391581
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
392582
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
393583
weights.insert(weights.end(), curr_tokens.size(), curr_weight);

0 commit comments

Comments
 (0)