@@ -539,23 +539,60 @@ def update_token_history(self, new_tokens: List[int]):
539539 # Log token updates periodically
540540 if random .random () < 0.01 :
541541 logger .debug (f"STEERING: Token history updated, now has { len (self .token_history )} tokens" )
542+
def update_context(self, new_tokens: str):
    """
    Append newly generated text to both matching histories.

    Despite its name, ``new_tokens`` is text, not token ids. When a
    tokenizer is configured, the text is encoded without special tokens and
    the resulting ids are appended to ``self.token_history``, which is then
    cut down to its most recent ``self.max_history`` entries. The text is
    always appended to ``self.context_buffer``, which is cut down to its
    last 500 characters.

    Args:
        new_tokens: Newly generated text. ``None`` or ``""`` contributes
            nothing and is accepted without raising.
    """
    max_context_chars = 500

    # Normalise missing input so a chunk without text cannot raise on the
    # string concatenation below and never reaches the tokenizer.
    text = new_tokens or ""

    # Token-based history (only when a tokenizer is available).
    if text and self.tokenizer is not None:
        token_ids = self.tokenizer.encode(text, add_special_tokens=False)

        if token_ids:  # Only proceed if we got tokens
            self.token_history.extend(token_ids)

            # Keep only the newest max_history ids.
            if len(self.token_history) > self.max_history:
                self.token_history = self.token_history[-self.max_history:]

            # Log roughly 1% of updates to keep debug output manageable.
            if random.random() < 0.01:
                # Lazy %-args: nothing is formatted unless debug is enabled.
                logger.debug(
                    "STEERING: Token history updated, now has %d tokens",
                    len(self.token_history),
                )

    # Text-based history (always updated).
    self.context_buffer += text

    if len(self.context_buffer) > max_context_chars:
        self.context_buffer = self.context_buffer[-max_context_chars:]
        logger.debug(
            "STEERING: Context buffer trimmed to %d chars",
            len(self.context_buffer),
        )
542576
543577 def try_match (self ):
544578 """
545579 Try to match the current context with a steering vector.
546580 Only allows one pattern to be selected for the entire generation.
581+ Tries both token-based and text-based matching approaches.
547582 """
548583 # If we already have an active pattern, don't try to match again
549584 if self .active_pattern :
550585 return False
551586
552- # Use token-based matching or text-based matching as appropriate
587+ # Try both token-based and text-based matching
553588 match_result = False
589+
590+ # First try token-based matching if available
554591 if self .tokenizer is not None and hasattr (self .manager , 'tokenized_contexts' ) and self .manager .tokenized_contexts :
555- # Token-based matching (similar to guided mode)
556592 match_result = self ._try_token_match ()
557- else :
558- # Text-based matching as fallback
593+
594+ # If token matching fails, try text-based matching
595+ if not match_result :
559596 match_result = self ._try_text_match ()
560597
561598 # Set generation started flag AFTER trying to match
@@ -638,6 +675,46 @@ def _try_token_match(self):
638675
639676 return True
640677
678+ # If no match, try fuzzy matching with 70% similarity threshold
679+ if len (self .token_history ) >= 8 and not self .match_found :
680+ logger .debug ("STEERING: No exact match found, trying fuzzy matching" )
681+ for tokenized_context , vector in self .manager .tokenized_contexts .items ():
682+ token_list = list (tokenized_context )
683+ token_len = len (token_list )
684+
685+ if token_len >= 8 : # Only try fuzzy matching for contexts with enough tokens
686+ match_len = min (len (self .token_history ), token_len )
687+ last_tokens = self .token_history [- match_len :]
688+ context_tokens = token_list [- match_len :]
689+
690+ # Count matching tokens
691+ matches = sum (1 for a , b in zip (last_tokens , context_tokens ) if a == b )
692+ similarity = matches / match_len
693+
694+ if similarity >= 0.7 : # 70% similarity threshold
695+ if match_len > best_match ['length' ]:
696+ best_match = {
697+ 'length' : match_len ,
698+ 'vector' : vector ,
699+ 'is_partial' : True ,
700+ 'match_len' : match_len ,
701+ 'token_len' : token_len ,
702+ 'similarity' : similarity
703+ }
704+
705+ # Apply fuzzy match if found
706+ if best_match ['vector' ] is not None :
707+ self .match_found = True
708+ self .current_vector = best_match ['vector' ]
709+ pattern = best_match ['vector' ].get ("reasoning_pattern" , "unknown" )
710+ pivot_token = best_match ['vector' ].get ("pivot_token" , "" )
711+ similarity = best_match .get ('similarity' , 0.0 )
712+
713+ logger .info (f"STEERING: Found fuzzy match ({ similarity :.2f} similarity) for { pattern } pattern" )
714+ logger .info (f"STEERING: Pivot token: '{ pivot_token } '" )
715+
716+ return True
717+
641718 # If no match, try fuzzy matching with 70% similarity threshold
642719 if len (self .token_history ) >= 8 and not self .match_found :
643720 logger .debug ("STEERING: No exact match found, trying fuzzy matching" )
@@ -682,6 +759,10 @@ def _try_token_match(self):
682759
683760 def _try_text_match (self ):
684761 """Try to match using text-based context (original approach)."""
762+ # Skip if context buffer is too short
763+ if len (self .context_buffer ) < 10 : # Require at least 10 chars for matching
764+ return False
765+
685766 # Get the last 100 characters as the match key
686767 match_key = self .context_buffer [- 100 :] if len (self .context_buffer ) >= 100 else self .context_buffer
687768
@@ -697,22 +778,62 @@ def _try_text_match(self):
697778 self .match_found = True
698779 self .current_vector = vector
699780 pattern = vector .get ("reasoning_pattern" , "unknown" )
700- logger .info (f"STEERING: Found text match for { pattern } reasoning pattern: '{ vector .get ('pivot_token' , '' )} '" )
781+ pivot_token = vector .get ("pivot_token" , "" )
782+ logger .info (f"STEERING: Found text match for { pattern } reasoning pattern" )
783+ logger .info (f"STEERING: Pivot token: '{ pivot_token } '" )
701784 return True
785+
786+ # Attempt fuzzy text matching as a fallback
787+ if len (match_key ) >= 20 : # Only try for reasonably sized contexts
788+ # Try each steering vector for approximate match
789+ best_match = None
790+ best_similarity = 0.0
791+
792+ for vector in self .manager .steering_vectors :
793+ vector_context = vector .get ("pivot_context" , "" )
794+ if not vector_context or len (vector_context ) < 20 :
795+ continue
796+
797+ # Get the end of the vector context (last 100 chars)
798+ vector_key = vector_context [- 100 :] if len (vector_context ) >= 100 else vector_context
799+
800+ # Calculate simple character-level similarity
801+ min_length = min (len (match_key ), len (vector_key ))
802+ matching_chars = sum (1 for a , b in zip (match_key , vector_key ) if a == b )
803+ similarity = matching_chars / min_length if min_length > 0 else 0
804+
805+ # Keep track of best match above threshold
806+ if similarity >= 0.7 and similarity > best_similarity : # 70% similarity threshold
807+ best_similarity = similarity
808+ best_match = vector
809+
810+ # Use the best match if found
811+ if best_match is not None :
812+ self .match_found = True
813+ self .current_vector = best_match
814+ pattern = best_match .get ("reasoning_pattern" , "unknown" )
815+ pivot_token = best_match .get ("pivot_token" , "" )
816+ logger .info (f"STEERING: Found fuzzy text match ({ best_similarity :.2f} similarity) for { pattern } pattern" )
817+ logger .info (f"STEERING: Pivot token: '{ pivot_token } '" )
818+ return True
702819
703820 return False
704821
def reset(self):
    """Return the hook to its initial state ahead of a new generation."""
    # The mapping is rebuilt on every call, so each reset installs a brand
    # new empty token list rather than reusing a previous one.
    pristine_state = {
        # Match result
        "match_found": False,
        "current_vector": None,
        # Text and token histories
        "context_buffer": "",
        "token_history": [],
        # Pattern and generation tracking
        "last_pattern": None,
        "active_pattern": None,
        "generation_started": False,
    }
    for attribute, value in pristine_state.items():
        setattr(self, attribute, value)

    logger.info("STEERING: Hook state reset for new generation")
716837
717838def install_steering_hooks (model , manager : SteeringVectorManager , tokenizer = None ) -> List [Tuple ]:
718839 """
0 commit comments