Add focus_agent, embedding and bm25 agents #302
Open
imenelydiaker wants to merge 3 commits into ServiceNow:main from imenelydiaker:add-focus-agent
Changes from 2 commits
New file (89 lines):

# BM25Agent

A retrieval-augmented agent that uses the BM25 (Best Matching 25) algorithm to filter the accessibility tree (AXTree) and retrieve the parts most relevant to the current goal and task history.

## Overview

`BM25Agent` extends `GenericAgent` with content-retrieval capabilities. Instead of processing the entire accessibility tree, it chunks the content and uses BM25 ranking to retrieve only the most relevant sections, reducing token usage and keeping the prompt focused on task-relevant elements.

## Key Features

- **BM25-based retrieval**: ranks and retrieves relevant content chunks with the BM25 algorithm
- **Token-aware chunking**: splits accessibility trees with tiktoken for predictable token budgets
- **Configurable parameters**: adjustable chunk size, overlap, and top-k retrieval
- **History integration**: can optionally include task history in retrieval queries
- **Memory efficient**: reduces context size by filtering out irrelevant content

## Architecture

```text
Query (goal + history) → BM25 Retriever → Top-K Chunks → LLM → Action
                              ↑
                           AXTree
```

## Usage

### Basic Configuration

```python
from agentlab.agents.bm25_agent import BM25RetrieverAgent, BM25RetrieverAgentArgs
from agentlab.agents.bm25_agent.bm25_retriever import BM25RetrieverArgs
from agentlab.agents.bm25_agent.bm25_retriever_agent import BM25RetrieverAgentFlags

# Configure retriever parameters
retriever_args = BM25RetrieverArgs(
    chunk_size=200,  # Tokens per chunk
    overlap=10,      # Token overlap between chunks
    top_k=10,        # Number of chunks to retrieve
    use_recursive_text_splitter=False,  # Set True to use LangChain's splitter instead
)

# Configure agent flags
retriever_flags = BM25RetrieverAgentFlags(
    use_history=True  # Include task history in queries
)

# Create agent
agent_args = BM25RetrieverAgentArgs(
    chat_model_args=your_chat_model_args,
    flags=your_flags,
    retriever_args=retriever_args,
    retriever_flags=retriever_flags,
)

agent = agent_args.make_agent()
```

### Pre-configured Agents

```python
from agentlab.agents.bm25_agent.agent_configs import (
    BM25_RETRIEVER_AGENT,      # Chunk size is 200 tokens
    BM25_RETRIEVER_AGENT_100,  # Chunk size is 100 tokens
)

# Use the default configuration
agent = BM25_RETRIEVER_AGENT.make_agent()
```

## Configuration Parameters

### BM25RetrieverArgs

- `chunk_size` (int, default=100): number of tokens per chunk
- `overlap` (int, default=10): token overlap between consecutive chunks
- `top_k` (int, default=5): number of most relevant chunks to retrieve
- `use_recursive_text_splitter` (bool, default=False): use LangChain's recursive text splitter; enabling it overrides the `chunk_size` and `overlap` parameters

### BM25RetrieverAgentFlags

- `use_history` (bool, default=False): include interaction history in retrieval queries

## Citation

If you use this agent in your work, please consider citing:

```bibtex

```
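The ranking step this README describes can be sketched in a few lines of pure Python. This is an illustrative Okapi BM25 implementation over made-up AXTree chunks, not the `bm25s` code the agent actually uses; the `tokenize` helper is a hypothetical stand-in for the library's tokenizer.

```python
import math
import re
from collections import Counter


def tokenize(text):
    """Lowercase alphanumeric tokens; a stand-in for bm25s's tokenizer."""
    return re.findall(r"[a-z0-9]+", text.lower())


def bm25_rank(chunks, query, k1=1.5, b=0.75):
    """Rank chunks against a query with Okapi BM25 (illustrative only)."""
    docs = [tokenize(c) for c in chunks]
    n = len(docs)
    avgdl = sum(len(d) for d in docs) / n
    df = Counter(t for d in docs for t in set(d))  # document frequency per term
    scores = []
    for d in docs:
        tf = Counter(d)
        score = 0.0
        for t in tokenize(query):
            if t not in tf:
                continue
            idf = math.log((n - df[t] + 0.5) / (df[t] + 0.5) + 1)
            score += idf * tf[t] * (k1 + 1) / (tf[t] + k1 * (1 - b + b * len(d) / avgdl))
        scores.append(score)
    # Highest-scoring chunks first; stable sort preserves order on ties
    order = sorted(range(n), key=lambda i: scores[i], reverse=True)
    return [chunks[i] for i in order]


axtree_chunks = [
    "[a12] button 'Submit order'",
    "[a13] link 'Privacy policy'",
    "[a14] textbox 'Search products'",
]
print(bm25_rank(axtree_chunks, "search the product catalog")[0])
# → [a14] textbox 'Search products'
```

Only the chunk sharing the query term "search" scores above zero here, which is exactly the filtering effect the agent relies on to shrink the AXTree before it reaches the LLM.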
New file (8 lines):

```python
from .bm25_retriever_agent import BM25RetrieverAgent, BM25RetrieverAgentArgs
from .bm25_retriever import BM25RetrieverArgs
from .agent_configs import (
    BM25_RETRIEVER_AGENT,
    BM25_RETRIEVER_AGENT_100,
    BM25_RETRIEVER_AGENT_50,
    BM25_RETRIEVER_AGENT_500,
)
```
New file (68 lines):

```python
from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_4o
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

from .bm25_retriever import BM25RetrieverArgs
from .bm25_retriever_agent import BM25RetrieverAgentArgs, BM25RetrieverAgentFlags

FLAGS_GPT_4o = FLAGS_GPT_4o.copy()
FLAGS_GPT_4o.obs.use_think_history = True

BM25_RETRIEVER_AGENT = BM25RetrieverAgentArgs(
    agent_name="BM25RetrieverAgent-4.1",
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
    flags=FLAGS_GPT_4o,
    retriever_args=BM25RetrieverArgs(
        top_k=10,
        chunk_size=200,
        overlap=10,
        use_recursive_text_splitter=False,
    ),
    retriever_flags=BM25RetrieverAgentFlags(
        use_history=True,
    ),
)

BM25_RETRIEVER_AGENT_100 = BM25RetrieverAgentArgs(
    agent_name="BM25RetrieverAgent-4.1-100",
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
    flags=FLAGS_GPT_4o,
    retriever_args=BM25RetrieverArgs(
        top_k=10,
        chunk_size=100,
        overlap=10,
        use_recursive_text_splitter=False,
    ),
    retriever_flags=BM25RetrieverAgentFlags(
        use_history=True,
    ),
)

BM25_RETRIEVER_AGENT_50 = BM25RetrieverAgentArgs(
    agent_name="BM25RetrieverAgent-4.1-50",
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
    flags=FLAGS_GPT_4o,
    retriever_args=BM25RetrieverArgs(
        top_k=10,
        chunk_size=50,
        overlap=5,
        use_recursive_text_splitter=False,
    ),
    retriever_flags=BM25RetrieverAgentFlags(
        use_history=True,
    ),
)

BM25_RETRIEVER_AGENT_500 = BM25RetrieverAgentArgs(
    agent_name="BM25RetrieverAgent-4.1-500",
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
    flags=FLAGS_GPT_4o,
    retriever_args=BM25RetrieverArgs(
        top_k=10,
        chunk_size=500,
        overlap=10,
        use_recursive_text_splitter=False,
    ),
    retriever_flags=BM25RetrieverAgentFlags(
        use_history=True,
    ),
)
```
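The four configurations above differ only in chunk granularity. Assuming the chunker is a plain sliding window (which `get_chunks_from_tokenizer` may or may not be), an N-token tree yields roughly `ceil((N - overlap) / (chunk_size - overlap))` chunks; a quick back-of-envelope check for a hypothetical 4000-token AXTree:

```python
import math


def approx_num_chunks(n_tokens, chunk_size, overlap):
    """Chunk count for a sliding window that advances chunk_size - overlap tokens."""
    step = chunk_size - overlap
    return max(1, math.ceil((n_tokens - overlap) / step))


# One line per pre-configured agent's (chunk_size, overlap) setting
for size, ov in [(50, 5), (100, 10), (200, 10), (500, 10)]:
    print(f"chunk_size={size}: ~{approx_num_chunks(4000, size, ov)} chunks")
# chunk_size=50: ~89 chunks, 100: ~45, 200: ~21, 500: ~9
```

Smaller chunks give BM25 finer-grained filtering (more chunks to rank, each more topically pure) at the cost of losing surrounding context within each retrieved piece.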
New file (133 lines):

```python
import re
from dataclasses import dataclass

try:
    import bm25s
except ImportError:
    raise ImportError(
        "bm25s is not installed. Please install it using `pip install agentlab[retrievers]`."
    )
import tiktoken

from .utils import get_chunks_from_tokenizer


def count_tokens(text: str) -> int:
    """Count the number of tokens in the text using tiktoken for GPT-4."""
    encoding = tiktoken.encoding_for_model("gpt-4")
    tokens = encoding.encode(text)
    return len(tokens)


@dataclass
class BM25RetrieverArgs:
    chunk_size: int = 100
    overlap: int = 10
    top_k: int = 5
    use_recursive_text_splitter: bool = False


class BM25SRetriever:
    """Simple retriever using BM25S to retrieve the most relevant lines."""

    def __init__(
        self,
        tree: str,
        chunk_size: int,
        overlap: int,
        top_k: int,
        use_recursive_text_splitter: bool,
    ):
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.top_k = top_k
        self.use_recursive_text_splitter = use_recursive_text_splitter
        corpus = get_chunks_from_tokenizer(tree)
        self.retriever = bm25s.BM25(corpus=corpus)
        tokenized_corpus = bm25s.tokenize(corpus)
        self.retriever.index(tokenized_corpus)

    def retrieve(self, query):
        tokenized_query = bm25s.tokenize(query)
        # Never request more results than there are chunks in the corpus
        k = min(self.top_k, len(self.retriever.corpus))
        results, _ = self.retriever.retrieve(query_tokens=tokenized_query, k=k)
        return [str(res) for res in results[0]]

    def create_text_chunks(self, axtree, chunk_size=200, overlap=50):
        if self.use_recursive_text_splitter:
            try:
                from langchain.text_splitter import RecursiveCharacterTextSplitter
            except ImportError:
                raise ImportError(
                    "langchain is not installed. Please install it using "
                    "`pip install agentlab[retrievers]`."
                )

            text_splitter = RecursiveCharacterTextSplitter()
            return text_splitter.split_text(axtree)
        else:
            return get_chunks_from_tokenizer(axtree, self.chunk_size, self.overlap)

    @staticmethod
    def extract_bid(line):
        """
        Extracts the bid from a line in the format '[bid] textarea ...'.

        Parameters:
            line (str): The input line containing the bid in square brackets.

        Returns:
            str: The extracted bid, or None if no bid is found.
        """
        match = re.search(r"\[([a-zA-Z0-9]+)\]", line)
        if match:
            return match.group(1)
        return None

    @classmethod
    def get_elements_around(cls, tree, element_id, n):
        """
        Get n elements around the given element_id from the AXTree while preserving
        its indentation structure.

        :param tree: String representing the AXTree with indentations.
        :param element_id: The element ID to center around (can include alphanumeric IDs like 'a203').
        :param n: The number of elements to include before and after.
        :return: String of the AXTree elements around the given element ID, preserving indentation.
        """
        # Split the tree into lines
        lines = tree.splitlines()

        # Extract the line indices and content containing element IDs
        id_lines = [(i, line) for i, line in enumerate(lines) if "[" in line and "]" in line]

        # Parse the IDs from the lines
        parsed_ids = []
        for idx, line in id_lines:
            try:
                element_id_in_line = line.split("[")[1].split("]")[0]
                parsed_ids.append((idx, element_id_in_line, line))
            except IndexError:
                continue

        # Find the index of the element with the given ID
        target_idx = next(
            (i for i, (_, eid, _) in enumerate(parsed_ids) if eid == element_id), None
        )

        if target_idx is None:
            raise ValueError(f"Element ID {element_id} not found in the tree.")

        # Calculate the range of elements to include
        start_idx = max(0, target_idx - n)
        end_idx = min(len(parsed_ids), target_idx + n + 1)

        # Collect the lines to return
        result_lines = []
        for idx in range(start_idx, end_idx):
            line_idx = parsed_ids[idx][0]
            result_lines.append(lines[line_idx])

        return "\n".join(result_lines)
```
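The bid-extraction pattern in this file can be exercised on its own. This standalone sketch copies the same regex to show what it matches on typical AXTree lines (the example lines are made up):

```python
import re


def extract_bid(line):
    """Return the first bracketed alphanumeric id in a line, or None."""
    match = re.search(r"\[([a-zA-Z0-9]+)\]", line)
    return match.group(1) if match else None


print(extract_bid("[a203] textarea 'Add a comment'"))  # a203
print(extract_bid("StaticText 'no bid here'"))  # None
```

Note the pattern grabs the first bracketed token only, so nested or repeated brackets on a line resolve to the leftmost id.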
It would be nice to have the bm25_agent and embedding_agent in the focus_agent subdirectory, as they are related baselines. @recursix Do you have any thoughts about this?