@@ -196,27 +196,91 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
196196 }
197197
198198 std::vector<int > convert_token_to_id (std::string text) {
199+ size_t search_pos = 0 ;
199200 auto on_new_token_cb = [&](std::string& str, std::vector<int32_t >& bpe_tokens) -> bool {
200- size_t word_end = str.find (" ," );
201- std::string embd_name = word_end == std::string::npos ? str : str.substr (0 , word_end);
202- embd_name = trim (embd_name);
203- std::string embd_path = get_full_path (embd_dir, embd_name + " .pt" );
204- if (embd_path.size () == 0 ) {
205- embd_path = get_full_path (embd_dir, embd_name + " .ckpt" );
201+ std::string token_str;
202+ size_t consumed_len = 0 ;
203+ bool is_embed_tag = false ;
204+
205+ // The tokenizer gives us chunks of text. We only process the first potential embedding token in that chunk.
206+ std::string trimmed_str = trim (str);
207+ size_t leading_spaces = str.length () - trimmed_str.length ();
208+
209+ if (starts_with (trimmed_str, " <embed:" )) {
210+ size_t tag_end = trimmed_str.find (" >" );
211+ if (tag_end == std::string::npos) {
212+ return false ; // Incomplete tag.
213+ }
214+ std::string lower_tag = trimmed_str.substr (0 , tag_end + 1 );
215+ token_str = lower_tag; // Fallback to lowercased version
216+
217+ if (text.length () >= lower_tag.length ()) {
218+ for (size_t i = search_pos; i <= text.length () - lower_tag.length (); ++i) {
219+ bool match = true ;
220+ for (size_t j = 0 ; j < lower_tag.length (); ++j) {
221+ if (std::tolower (text[i + j]) != lower_tag[j]) {
222+ match = false ;
223+ break ;
224+ }
225+ }
226+ if (match) {
227+ token_str = text.substr (i, lower_tag.length ());
228+ search_pos = i + token_str.length ();
229+ break ;
230+ }
231+ }
232+ }
233+ consumed_len = leading_spaces + token_str.length ();
234+ is_embed_tag = true ;
235+ } else {
236+ // Not a tag. Could be a plain trigger word.
237+ size_t first_delim = trimmed_str.find_first_of (" ," );
238+ token_str = (first_delim == std::string::npos) ? trimmed_str : trimmed_str.substr (0 , first_delim);
239+ consumed_len = leading_spaces + token_str.length ();
240+ }
241+
242+ std::string embd_name = trim (token_str);
243+ if (is_embed_tag) {
244+ embd_name = embd_name.substr (strlen (" <embed:" ), embd_name.length () - strlen (" <embed:" ) - 1 );
206245 }
207- if (embd_path.size () == 0 ) {
208- embd_path = get_full_path (embd_dir, embd_name + " .safetensors" );
246+
247+ std::string embd_path;
248+ bool is_path = contains (embd_name, " /" ) || contains (embd_name, " \\ " );
249+
250+ if (is_path) {
251+ if (file_exists (embd_name)) {
252+ embd_path = embd_name;
253+ } else if (file_exists (embd_name + " .safetensors" )) {
254+ embd_path = embd_name + " .safetensors" ;
255+ } else if (file_exists (embd_name + " .pt" )) {
256+ embd_path = embd_name + " .pt" ;
257+ } else if (file_exists (embd_name + " .ckpt" )) {
258+ embd_path = embd_name + " .ckpt" ;
259+ }
260+ } else {
261+ embd_path = get_full_path (embd_dir, embd_name + " .pt" );
262+ if (embd_path.size () == 0 ) {
263+ embd_path = get_full_path (embd_dir, embd_name + " .ckpt" );
264+ }
265+ if (embd_path.size () == 0 ) {
266+ embd_path = get_full_path (embd_dir, embd_name + " .safetensors" );
267+ }
209268 }
269+
210270 if (embd_path.size () > 0 ) {
211271 if (load_embedding (embd_name, embd_path, bpe_tokens)) {
212- if (word_end != std::string::npos) {
213- str = str.substr (word_end);
214- } else {
215- str = " " ;
216- }
272+ str = str.substr (consumed_len);
217273 return true ;
218274 }
219275 }
276+
277+ if (is_embed_tag) {
278+ LOG_WARN (" could not load embedding '%s'" , embd_name.c_str ());
279+ str = str.substr (consumed_len);
280+ return true ; // Consume the failed tag so the tokenizer doesn't try to parse it as text.
281+ }
282+
283+ // It was not a tag and we couldn't find a file for it as a trigger word.
220284 return false ;
221285 };
222286 std::vector<int > curr_tokens = tokenizer.encode (text, on_new_token_cb);
@@ -245,30 +309,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
245309 LOG_DEBUG (" parse '%s' to %s" , text.c_str (), ss.str ().c_str ());
246310 }
247311
248- auto on_new_token_cb = [&](std::string& str, std::vector<int32_t >& bpe_tokens) -> bool {
249- size_t word_end = str.find (" ," );
250- std::string embd_name = word_end == std::string::npos ? str : str.substr (0 , word_end);
251- embd_name = trim (embd_name);
252- std::string embd_path = get_full_path (embd_dir, embd_name + " .pt" );
253- if (embd_path.size () == 0 ) {
254- embd_path = get_full_path (embd_dir, embd_name + " .ckpt" );
255- }
256- if (embd_path.size () == 0 ) {
257- embd_path = get_full_path (embd_dir, embd_name + " .safetensors" );
258- }
259- if (embd_path.size () > 0 ) {
260- if (load_embedding (embd_name, embd_path, bpe_tokens)) {
261- if (word_end != std::string::npos) {
262- str = str.substr (word_end);
263- } else {
264- str = " " ;
265- }
266- return true ;
267- }
268- }
269- return false ;
270- };
271-
272312 std::vector<int > tokens;
273313 std::vector<float > weights;
274314 std::vector<bool > class_token_mask;
@@ -278,6 +318,93 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
278318 std::vector<int > clean_input_ids;
279319 const std::string& curr_text = item.first ;
280320 float curr_weight = item.second ;
321+ size_t search_pos = 0 ;
322+ auto on_new_token_cb = [&](std::string& str, std::vector<int32_t >& bpe_tokens) -> bool {
323+ std::string token_str;
324+ size_t consumed_len = 0 ;
325+ bool is_embed_tag = false ;
326+
327+ // The tokenizer gives us chunks of text. We only process the first potential embedding token in that chunk.
328+ std::string trimmed_str = trim (str);
329+ size_t leading_spaces = str.length () - trimmed_str.length ();
330+
331+ if (starts_with (trimmed_str, " <embed:" )) {
332+ size_t tag_end = trimmed_str.find (" >" );
333+ if (tag_end == std::string::npos) {
334+ return false ; // Incomplete tag.
335+ }
336+ std::string lower_tag = trimmed_str.substr (0 , tag_end + 1 );
337+ token_str = lower_tag; // Fallback to lowercased version
338+
339+ if (curr_text.length () >= lower_tag.length ()) {
340+ for (size_t i = search_pos; i <= curr_text.length () - lower_tag.length (); ++i) {
341+ bool match = true ;
342+ for (size_t j = 0 ; j < lower_tag.length (); ++j) {
343+ if (std::tolower (curr_text[i + j]) != lower_tag[j]) {
344+ match = false ;
345+ break ;
346+ }
347+ }
348+ if (match) {
349+ token_str = curr_text.substr (i, lower_tag.length ());
350+ search_pos = i + token_str.length ();
351+ break ;
352+ }
353+ }
354+ }
355+ consumed_len = leading_spaces + token_str.length ();
356+ is_embed_tag = true ;
357+ } else {
358+ // Not a tag. Could be a plain trigger word.
359+ size_t first_delim = trimmed_str.find_first_of (" ," );
360+ token_str = (first_delim == std::string::npos) ? trimmed_str : trimmed_str.substr (0 , first_delim);
361+ consumed_len = leading_spaces + token_str.length ();
362+ }
363+
364+ std::string embd_name = trim (token_str);
365+ if (is_embed_tag) {
366+ embd_name = embd_name.substr (strlen (" <embed:" ), embd_name.length () - strlen (" <embed:" ) - 1 );
367+ }
368+
369+ std::string embd_path;
370+ bool is_path = contains (embd_name, " /" ) || contains (embd_name, " \\ " );
371+
372+ if (is_path) {
373+ if (file_exists (embd_name)) {
374+ embd_path = embd_name;
375+ } else if (file_exists (embd_name + " .safetensors" )) {
376+ embd_path = embd_name + " .safetensors" ;
377+ } else if (file_exists (embd_name + " .pt" )) {
378+ embd_path = embd_name + " .pt" ;
379+ } else if (file_exists (embd_name + " .ckpt" )) {
380+ embd_path = embd_name + " .ckpt" ;
381+ }
382+ } else {
383+ embd_path = get_full_path (embd_dir, embd_name + " .pt" );
384+ if (embd_path.size () == 0 ) {
385+ embd_path = get_full_path (embd_dir, embd_name + " .ckpt" );
386+ }
387+ if (embd_path.size () == 0 ) {
388+ embd_path = get_full_path (embd_dir, embd_name + " .safetensors" );
389+ }
390+ }
391+
392+ if (embd_path.size () > 0 ) {
393+ if (load_embedding (embd_name, embd_path, bpe_tokens)) {
394+ str = str.substr (consumed_len);
395+ return true ;
396+ }
397+ }
398+
399+ if (is_embed_tag) {
400+ LOG_WARN (" could not load embedding '%s'" , embd_name.c_str ());
401+ str = str.substr (consumed_len);
402+ return true ; // Consume the failed tag so the tokenizer doesn't try to parse it as text.
403+ }
404+
405+ // It was not a tag and we couldn't find a file for it as a trigger word.
406+ return false ;
407+ };
281408 // printf(" %s: %f \n", curr_text.c_str(), curr_weight);
282409 std::vector<int > curr_tokens = tokenizer.encode (curr_text, on_new_token_cb);
283410 int32_t clean_index = 0 ;
@@ -359,35 +486,98 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
359486 LOG_DEBUG (" parse '%s' to %s" , text.c_str (), ss.str ().c_str ());
360487 }
361488
362- auto on_new_token_cb = [&](std::string& str, std::vector<int32_t >& bpe_tokens) -> bool {
363- size_t word_end = str.find (" ," );
364- std::string embd_name = word_end == std::string::npos ? str : str.substr (0 , word_end);
365- embd_name = trim (embd_name);
366- std::string embd_path = get_full_path (embd_dir, embd_name + " .pt" );
367- if (embd_path.size () == 0 ) {
368- embd_path = get_full_path (embd_dir, embd_name + " .ckpt" );
369- }
370- if (embd_path.size () == 0 ) {
371- embd_path = get_full_path (embd_dir, embd_name + " .safetensors" );
372- }
373- if (embd_path.size () > 0 ) {
374- if (load_embedding (embd_name, embd_path, bpe_tokens)) {
375- if (word_end != std::string::npos) {
376- str = str.substr (word_end);
377- } else {
378- str = " " ;
379- }
380- return true ;
381- }
382- }
383- return false ;
384- };
385-
386489 std::vector<int > tokens;
387490 std::vector<float > weights;
388491 for (const auto & item : parsed_attention) {
389492 const std::string& curr_text = item.first ;
390493 float curr_weight = item.second ;
494+ size_t search_pos = 0 ;
495+ auto on_new_token_cb = [&](std::string& str, std::vector<int32_t >& bpe_tokens) -> bool {
496+ std::string token_str;
497+ size_t consumed_len = 0 ;
498+ bool is_embed_tag = false ;
499+
500+ // The tokenizer gives us chunks of text. We only process the first potential embedding token in that chunk.
501+ std::string trimmed_str = trim (str);
502+ size_t leading_spaces = str.length () - trimmed_str.length ();
503+
504+ if (starts_with (trimmed_str, " <embed:" )) {
505+ size_t tag_end = trimmed_str.find (" >" );
506+ if (tag_end == std::string::npos) {
507+ return false ; // Incomplete tag.
508+ }
509+ std::string lower_tag = trimmed_str.substr (0 , tag_end + 1 );
510+ token_str = lower_tag; // Fallback to lowercased version
511+
512+ if (curr_text.length () >= lower_tag.length ()) {
513+ for (size_t i = search_pos; i <= curr_text.length () - lower_tag.length (); ++i) {
514+ bool match = true ;
515+ for (size_t j = 0 ; j < lower_tag.length (); ++j) {
516+ if (std::tolower (curr_text[i + j]) != lower_tag[j]) {
517+ match = false ;
518+ break ;
519+ }
520+ }
521+ if (match) {
522+ token_str = curr_text.substr (i, lower_tag.length ());
523+ search_pos = i + token_str.length ();
524+ break ;
525+ }
526+ }
527+ }
528+ consumed_len = leading_spaces + token_str.length ();
529+ is_embed_tag = true ;
530+ } else {
531+ // Not a tag. Could be a plain trigger word.
532+ size_t first_delim = trimmed_str.find_first_of (" ," );
533+ token_str = (first_delim == std::string::npos) ? trimmed_str : trimmed_str.substr (0 , first_delim);
534+ consumed_len = leading_spaces + token_str.length ();
535+ }
536+
537+ std::string embd_name = trim (token_str);
538+ if (is_embed_tag) {
539+ embd_name = embd_name.substr (strlen (" <embed:" ), embd_name.length () - strlen (" <embed:" ) - 1 );
540+ }
541+
542+ std::string embd_path;
543+ bool is_path = contains (embd_name, " /" ) || contains (embd_name, " \\ " );
544+
545+ if (is_path) {
546+ if (file_exists (embd_name)) {
547+ embd_path = embd_name;
548+ } else if (file_exists (embd_name + " .safetensors" )) {
549+ embd_path = embd_name + " .safetensors" ;
550+ } else if (file_exists (embd_name + " .pt" )) {
551+ embd_path = embd_name + " .pt" ;
552+ } else if (file_exists (embd_name + " .ckpt" )) {
553+ embd_path = embd_name + " .ckpt" ;
554+ }
555+ } else {
556+ embd_path = get_full_path (embd_dir, embd_name + " .pt" );
557+ if (embd_path.size () == 0 ) {
558+ embd_path = get_full_path (embd_dir, embd_name + " .ckpt" );
559+ }
560+ if (embd_path.size () == 0 ) {
561+ embd_path = get_full_path (embd_dir, embd_name + " .safetensors" );
562+ }
563+ }
564+
565+ if (embd_path.size () > 0 ) {
566+ if (load_embedding (embd_name, embd_path, bpe_tokens)) {
567+ str = str.substr (consumed_len);
568+ return true ;
569+ }
570+ }
571+
572+ if (is_embed_tag) {
573+ LOG_WARN (" could not load embedding '%s'" , embd_name.c_str ());
574+ str = str.substr (consumed_len);
575+ return true ; // Consume the failed tag so the tokenizer doesn't try to parse it as text.
576+ }
577+
578+ // It was not a tag and we couldn't find a file for it as a trigger word.
579+ return false ;
580+ };
391581 std::vector<int > curr_tokens = tokenizer.encode (curr_text, on_new_token_cb);
392582 tokens.insert (tokens.end (), curr_tokens.begin (), curr_tokens.end ());
393583 weights.insert (weights.end (), curr_tokens.size (), curr_weight);