
Commit 9378c2f

match both tokens and context text
1 parent d51ea08 commit 9378c2f

2 files changed: +142 -10 lines changed


optillm/autothink/processor.py

Lines changed: 13 additions & 2 deletions
@@ -245,11 +245,16 @@ def process(self, messages: List[Dict[str, str]]) -> str:
             return_tensors="pt"
         ).to(self.model.device)
 
-        # Update token history in steering hooks
+        # Reset and update token history in steering hooks
         if self.steering_hooks:
             token_ids = tokens[0].tolist()
+            prompt_text = self.tokenizer.decode(token_ids)
             for hook, _ in self.steering_hooks:
+                # Reset the hook state for a new generation
+                hook.reset()
+                # Update both token history and text context buffer
                 hook.update_token_history(token_ids)
+                hook.update_context(prompt_text)
                 # Try to match with a steering vector
                 hook.try_match()
 
@@ -343,13 +348,19 @@ def process(self, messages: List[Dict[str, str]]) -> str:
             # Update steering hooks with new token
             if self.steering_hooks:
                 for hook, _ in self.steering_hooks:
-                    # Update token history with the new token
+                    # Update both token history and text context
                     hook.update_token_history([next_token])
+                    hook.update_context(next_str)
                     # Check for matches on EVERY token
                     hook.try_match()
 
             tokens = torch.tensor([[next_token]]).to(tokens.device)
 
+        # Reset and clean up steering hooks
+        if self.steering_hooks:
+            for hook, _ in self.steering_hooks:
+                hook.reset()
+
         # Clean up steering hooks
         self._cleanup_steering()
 
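Taken together, the two hunks above give each steering hook a clear per-request lifecycle: reset, prime with the prompt (token ids plus decoded text), update and try to match on every generated token, then reset again after generation. The sketch below is an illustrative stand-in, not the optillm code: `ToyHook` only records calls, and the token ids and strings are made up; only the call order mirrors the diff.

```python
# Illustrative stand-in for the hook lifecycle wired up above. ToyHook merely
# records calls; the real steering hook lives in optillm/autothink/steering.py,
# and the token ids / strings below are made up.
class ToyHook:
    def __init__(self):
        self.calls = []
    def reset(self):
        self.calls.append("reset")
    def update_token_history(self, token_ids):
        self.calls.append(("tokens", list(token_ids)))
    def update_context(self, text):
        self.calls.append(("text", text))
    def try_match(self):
        self.calls.append("try_match")

hook = ToyHook()

# Prompt phase: reset stale state, then prime with the full prompt
# (token ids plus the decoded prompt text), and attempt a first match.
hook.reset()
hook.update_token_history([101, 2023, 2003])
hook.update_context("This is")
hook.try_match()

# Decoding phase: every generated token updates both representations,
# and a match is attempted on every step.
for token_id, token_str in [(1037, " a"), (3231, " test")]:
    hook.update_token_history([token_id])
    hook.update_context(token_str)
    hook.try_match()

# Cleanup phase: reset again so no state leaks into the next request.
hook.reset()
print(hook.calls)
```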

optillm/autothink/steering.py

Lines changed: 129 additions & 8 deletions
@@ -539,23 +539,60 @@ def update_token_history(self, new_tokens: List[int]):
         # Log token updates periodically
         if random.random() < 0.01:
             logger.debug(f"STEERING: Token history updated, now has {len(self.token_history)} tokens")
+
+    def update_context(self, new_tokens: str):
+        """
+        Update the context buffer with new tokens.
+
+        Args:
+            new_tokens: New tokens to add to the context.
+        """
+        # Both methods - text-based and token-based
+        if self.tokenizer is not None:
+            # Token-based approach (similar to guided mode)
+            # Tokenize the new text
+            token_ids = self.tokenizer.encode(new_tokens, add_special_tokens=False)
+
+            if token_ids:  # Only proceed if we got tokens
+                # Add to token history
+                self.token_history.extend(token_ids)
+
+                # Trim history if needed
+                if len(self.token_history) > self.max_history:
+                    self.token_history = self.token_history[-self.max_history:]
+
+                # Log token updates periodically
+                if random.random() < 0.01:
+                    logger.debug(f"STEERING: Token history updated, now has {len(self.token_history)} tokens")
+
+        # Text-based approach (always update)
+        # Update context buffer
+        self.context_buffer += new_tokens
+
+        # Keep only the last 500 characters
+        if len(self.context_buffer) > 500:
+            self.context_buffer = self.context_buffer[-500:]
+            logger.debug(f"STEERING: Context buffer trimmed to {len(self.context_buffer)} chars")
 
     def try_match(self):
         """
         Try to match the current context with a steering vector.
         Only allows one pattern to be selected for the entire generation.
+        Tries both token-based and text-based matching approaches.
         """
         # If we already have an active pattern, don't try to match again
         if self.active_pattern:
             return False
 
-        # Use token-based matching or text-based matching as appropriate
+        # Try both token-based and text-based matching
        match_result = False
+
+        # First try token-based matching if available
         if self.tokenizer is not None and hasattr(self.manager, 'tokenized_contexts') and self.manager.tokenized_contexts:
-            # Token-based matching (similar to guided mode)
             match_result = self._try_token_match()
-        else:
-            # Text-based matching as fallback
+
+        # If token matching fails, try text-based matching
+        if not match_result:
             match_result = self._try_text_match()
 
         # Set generation started flag AFTER trying to match
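The new `update_context` keeps two rolling views of the same stream in sync: a token history capped at `self.max_history` entries and a text buffer capped at 500 characters, which `try_match` then consults token-first with a text fallback. The snippet below is a minimal, self-contained illustration of that dual bookkeeping; the whitespace split stands in for `self.tokenizer.encode(...)`, and `MAX_TOKEN_HISTORY` is an arbitrary placeholder for `self.max_history`, while the 500-character cap matches the diff.

```python
# Minimal, self-contained illustration of the dual bookkeeping in update_context():
# one text chunk feeds both a capped token history and a capped character buffer.
# The whitespace split stands in for self.tokenizer.encode(...); MAX_TOKEN_HISTORY
# is a placeholder for self.max_history, while the 500-character cap matches the diff.
MAX_TOKEN_HISTORY = 10
MAX_CONTEXT_CHARS = 500

token_history: list[str] = []
context_buffer = ""

def update_context(new_text: str) -> None:
    global context_buffer
    # Token-based bookkeeping (only done when a tokenizer is available).
    tokens = new_text.split()
    if tokens:
        token_history.extend(tokens)
        del token_history[:-MAX_TOKEN_HISTORY]   # keep only the most recent entries

    # Text-based bookkeeping (always done).
    context_buffer = (context_buffer + new_text)[-MAX_CONTEXT_CHARS:]

for chunk in ["Let me think", " about this", " step by step."]:
    update_context(chunk)

print(token_history)   # ['Let', 'me', 'think', 'about', 'this', 'step', 'by', 'step.']
print(context_buffer)  # 'Let me think about this step by step.'
```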
@@ -638,6 +675,46 @@ def _try_token_match(self):
 
                 return True
 
+        # If no match, try fuzzy matching with 70% similarity threshold
+        if len(self.token_history) >= 8 and not self.match_found:
+            logger.debug("STEERING: No exact match found, trying fuzzy matching")
+            for tokenized_context, vector in self.manager.tokenized_contexts.items():
+                token_list = list(tokenized_context)
+                token_len = len(token_list)
+
+                if token_len >= 8:  # Only try fuzzy matching for contexts with enough tokens
+                    match_len = min(len(self.token_history), token_len)
+                    last_tokens = self.token_history[-match_len:]
+                    context_tokens = token_list[-match_len:]
+
+                    # Count matching tokens
+                    matches = sum(1 for a, b in zip(last_tokens, context_tokens) if a == b)
+                    similarity = matches / match_len
+
+                    if similarity >= 0.7:  # 70% similarity threshold
+                        if match_len > best_match['length']:
+                            best_match = {
+                                'length': match_len,
+                                'vector': vector,
+                                'is_partial': True,
+                                'match_len': match_len,
+                                'token_len': token_len,
+                                'similarity': similarity
+                            }
+
+            # Apply fuzzy match if found
+            if best_match['vector'] is not None:
+                self.match_found = True
+                self.current_vector = best_match['vector']
+                pattern = best_match['vector'].get("reasoning_pattern", "unknown")
+                pivot_token = best_match['vector'].get("pivot_token", "")
+                similarity = best_match.get('similarity', 0.0)
+
+                logger.info(f"STEERING: Found fuzzy match ({similarity:.2f} similarity) for {pattern} pattern")
+                logger.info(f"STEERING: Pivot token: '{pivot_token}'")
+
+                return True
+
         # If no match, try fuzzy matching with 70% similarity threshold
         if len(self.token_history) >= 8 and not self.match_found:
             logger.debug("STEERING: No exact match found, trying fuzzy matching")
@@ -682,6 +759,10 @@ def _try_token_match(self):
 
     def _try_text_match(self):
         """Try to match using text-based context (original approach)."""
+        # Skip if context buffer is too short
+        if len(self.context_buffer) < 10:  # Require at least 10 chars for matching
+            return False
+
         # Get the last 100 characters as the match key
         match_key = self.context_buffer[-100:] if len(self.context_buffer) >= 100 else self.context_buffer
 
@@ -697,22 +778,62 @@ def _try_text_match(self):
             self.match_found = True
             self.current_vector = vector
             pattern = vector.get("reasoning_pattern", "unknown")
-            logger.info(f"STEERING: Found text match for {pattern} reasoning pattern: '{vector.get('pivot_token', '')}'")
+            pivot_token = vector.get("pivot_token", "")
+            logger.info(f"STEERING: Found text match for {pattern} reasoning pattern")
+            logger.info(f"STEERING: Pivot token: '{pivot_token}'")
             return True
+
+        # Attempt fuzzy text matching as a fallback
+        if len(match_key) >= 20:  # Only try for reasonably sized contexts
+            # Try each steering vector for approximate match
+            best_match = None
+            best_similarity = 0.0
+
+            for vector in self.manager.steering_vectors:
+                vector_context = vector.get("pivot_context", "")
+                if not vector_context or len(vector_context) < 20:
+                    continue
+
+                # Get the end of the vector context (last 100 chars)
+                vector_key = vector_context[-100:] if len(vector_context) >= 100 else vector_context
+
+                # Calculate simple character-level similarity
+                min_length = min(len(match_key), len(vector_key))
+                matching_chars = sum(1 for a, b in zip(match_key, vector_key) if a == b)
+                similarity = matching_chars / min_length if min_length > 0 else 0
+
+                # Keep track of best match above threshold
+                if similarity >= 0.7 and similarity > best_similarity:  # 70% similarity threshold
+                    best_similarity = similarity
+                    best_match = vector
+
+            # Use the best match if found
+            if best_match is not None:
+                self.match_found = True
+                self.current_vector = best_match
+                pattern = best_match.get("reasoning_pattern", "unknown")
+                pivot_token = best_match.get("pivot_token", "")
+                logger.info(f"STEERING: Found fuzzy text match ({best_similarity:.2f} similarity) for {pattern} pattern")
+                logger.info(f"STEERING: Pivot token: '{pivot_token}'")
+                return True
 
         return False
 
     def reset(self):
-        """Reset the hook state."""
+        """Reset the hook state for a new generation."""
         self.match_found = False
         self.current_vector = None
+
+        # Clear both text and token histories
         self.context_buffer = ""
         self.token_history = []
-        self.last_pattern = None
 
-        # Reset pattern tracking
+        # Reset pattern and state tracking
+        self.last_pattern = None
         self.active_pattern = None
         self.generation_started = False
+
+        logger.info("STEERING: Hook state reset for new generation")
 
 def install_steering_hooks(model, manager: SteeringVectorManager, tokenizer=None) -> List[Tuple]:
     """

0 commit comments
