Skip to content

Commit b3d858a

Browse files
committed
Add focus_agent, embedding and bm25 agents
1 parent 5ff585f commit b3d858a

25 files changed

+5506
-3026
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ dependencies = [
5454
"anthropic>=0.62.0",
5555
"litellm>=1.75.3",
5656
"python-dotenv>=1.1.1",
57+
"bm25s>=0.2.14",
58+
"langchain>=0.3.27",
5759
]
5860

5961
[project.optional-dependencies]
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# BM25Agent
2+
3+
A retrieval-augmented agent that uses BM25 (Best Matching 25) algorithm to filter and retrieve the most relevant parts of the accessibility tree (AXTree) based on the current goal and task history.
4+
5+
## Overview
6+
7+
``BM25RetrieverAgent`` extends ``GenericAgent`` with intelligent content retrieval capabilities. Instead of processing the entire accessibility tree, it chunks the content and uses BM25 ranking to retrieve only the most relevant sections, reducing token usage and improving focus on task-relevant elements.
8+
9+
## Key Features
10+
11+
- **BM25-based retrieval**: Uses the BM25 algorithm to rank and retrieve relevant content chunks
12+
- **Token-aware chunking**: Splits accessibility trees using tiktoken for optimal token usage
13+
- **Configurable parameters**: Adjustable chunk size, overlap, and top-k retrieval
14+
- **History integration**: Can optionally include task history in retrieval queries
15+
- **Memory efficient**: Reduces context size by filtering irrelevant content
16+
17+
## Architecture
18+
19+
```text
20+
Query (goal + history) → BM25 Retriever → Top-K Chunks → LLM → Action
21+
22+
AXTree
23+
```
24+
25+
## Usage
26+
27+
### Basic Configuration
28+
29+
```python
30+
from agentlab.agents.bm25_agent import BM25RetrieverAgent, BM25RetrieverAgentArgs
31+
from agentlab.agents.bm25_agent.bm25_retriever import BM25RetrieverArgs
32+
from agentlab.agents.bm25_agent.bm25_retriever_agent import BM25RetrieverAgentFlags
33+
34+
# Configure retriever parameters
35+
retriever_args = BM25RetrieverArgs(
36+
chunk_size=200, # Tokens per chunk
37+
overlap=10, # Token overlap between chunks
38+
top_k=10, # Number of chunks to retrieve
39+
use_recursive_text_splitter=False # Set True to use LangChain's recursive text splitter
40+
)
41+
42+
# Configure agent flags
43+
retriever_flags = BM25RetrieverAgentFlags(
44+
use_history=True # Include task history in queries
45+
)
46+
47+
# Create agent
48+
agent_args = BM25RetrieverAgentArgs(
49+
chat_model_args=your_chat_model_args,
50+
flags=your_flags,
51+
retriever_args=retriever_args,
52+
retriever_flags=retriever_flags
53+
)
54+
55+
agent = agent_args.make_agent()
56+
```
57+
58+
### Pre-configured Agents
59+
60+
```python
61+
from agentlab.agents.bm25_agent.agent_configs import (
62+
BM25_RETRIEVER_AGENT, # Chunk size is 200 tokens
63+
BM25_RETRIEVER_AGENT_100 # Chunk size is 100 tokens
64+
)
65+
66+
# Use default configuration
67+
agent = BM25_RETRIEVER_AGENT.make_agent()
68+
```
69+
70+
## Configuration Parameters
71+
72+
### BM25RetrieverArgs
73+
74+
- `chunk_size` (int, default=100): Number of tokens per chunk
75+
- `overlap` (int, default=10): Token overlap between consecutive chunks
76+
- `top_k` (int, default=5): Number of most relevant chunks to retrieve
77+
- `use_recursive_text_splitter` (bool, default=False): Use LangChain's recursive text splitter. Using this text splitter will override the ``chunk_size`` and ``overlap`` parameters.
78+
79+
### BM25RetrieverAgentFlags
80+
81+
- `use_history` (bool, default=False): Include interaction history in retrieval queries
82+
83+
## Citation
84+
85+
If you use this agent in your work, please consider citing:
86+
87+
```bibtex
88+
89+
```
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .bm25_retriever_agent import BM25RetrieverAgent, BM25RetrieverAgentArgs
2+
from .bm25_retriever import BM25RetrieverArgs
3+
from .agent_configs import (
4+
BM25_RETRIEVER_AGENT,
5+
BM25_RETRIEVER_AGENT_100,
6+
BM25_RETRIEVER_AGENT_50,
7+
BM25_RETRIEVER_AGENT_500,
8+
)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_4o
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

from .bm25_retriever import BM25RetrieverArgs
from .bm25_retriever_agent import BM25RetrieverAgentArgs, BM25RetrieverAgentFlags

# Rebind a copy so the imported GPT-4o flags are not mutated in place.
# NOTE(review): confirm that .copy() deep-copies the nested .obs flags — if
# the copy is shallow, the mutation below leaks into the shared FLAGS_GPT_4o
# used by other agents. TODO confirm against GenericPromptFlags.copy().
FLAGS_GPT_4o = FLAGS_GPT_4o.copy()
FLAGS_GPT_4o.obs.use_think_history = True

# Pre-configured BM25 retriever agents. All four use the same GPT-4.1 chat
# model, the same prompt flags, top_k=10, and history-augmented queries;
# they differ only in chunking granularity (chunk size / overlap in tokens).

# Default configuration: 200-token chunks, 10-token overlap.
BM25_RETRIEVER_AGENT = BM25RetrieverAgentArgs(
    agent_name="BM25RetrieverAgent-4.1",
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
    flags=FLAGS_GPT_4o,
    retriever_args=BM25RetrieverArgs(
        top_k=10,
        chunk_size=200,
        overlap=10,
        use_recursive_text_splitter=False,
    ),
    retriever_flags=BM25RetrieverAgentFlags(
        use_history=True,
    ),
)

# Finer-grained variant: 100-token chunks, 10-token overlap.
BM25_RETRIEVER_AGENT_100 = BM25RetrieverAgentArgs(
    agent_name="BM25RetrieverAgent-4.1-100",
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
    flags=FLAGS_GPT_4o,
    retriever_args=BM25RetrieverArgs(
        top_k=10,
        chunk_size=100,
        overlap=10,
        use_recursive_text_splitter=False,
    ),
    retriever_flags=BM25RetrieverAgentFlags(
        use_history=True,
    ),
)

# Finest variant: 50-token chunks, 5-token overlap.
BM25_RETRIEVER_AGENT_50 = BM25RetrieverAgentArgs(
    agent_name="BM25RetrieverAgent-4.1-50",
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
    flags=FLAGS_GPT_4o,
    retriever_args=BM25RetrieverArgs(
        top_k=10,
        chunk_size=50,
        overlap=5,
        use_recursive_text_splitter=False,
    ),
    retriever_flags=BM25RetrieverAgentFlags(
        use_history=True,
    ),
)

# Coarsest variant: 500-token chunks, 10-token overlap.
BM25_RETRIEVER_AGENT_500 = BM25RetrieverAgentArgs(
    agent_name="BM25RetrieverAgent-4.1-500",
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
    flags=FLAGS_GPT_4o,
    retriever_args=BM25RetrieverArgs(
        top_k=10,
        chunk_size=500,
        overlap=10,
        use_recursive_text_splitter=False,
    ),
    retriever_flags=BM25RetrieverAgentFlags(
        use_history=True,
    ),
)
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from dataclasses import dataclass
2+
import re
3+
4+
import bm25s
5+
import tiktoken # Added import for tiktoken
6+
7+
from .utils import get_chunks_from_tokenizer
8+
9+
10+
def count_tokens(text: str) -> int:
    """Count the number of tokens in *text* using tiktoken's GPT-4 encoding.

    The encoding object is created lazily on first call and cached on the
    function itself, since ``tiktoken.encoding_for_model`` is comparatively
    expensive and the encoding never changes.

    Args:
        text: The string to tokenize.

    Returns:
        The number of GPT-4 tokens in ``text``.
    """
    encoding = getattr(count_tokens, "_encoding", None)
    if encoding is None:
        encoding = tiktoken.encoding_for_model("gpt-4")
        count_tokens._encoding = encoding
    return len(encoding.encode(text))
15+
16+
17+
@dataclass
18+
class BM25RetrieverArgs:
19+
chunk_size: int = 100
20+
overlap: int = 10
21+
top_k: int = 5
22+
use_recursive_text_splitter: bool = False
23+
24+
25+
class BM25SRetriever:
26+
"""Simple retriever using BM25S to retrieve the most relevant lines"""
27+
28+
def __init__(
29+
self,
30+
tree: str,
31+
chunk_size: int,
32+
overlap: int,
33+
top_k: int,
34+
use_recursive_text_splitter: bool,
35+
):
36+
self.chunk_size = chunk_size
37+
self.overlap = overlap
38+
self.top_k = top_k
39+
self.use_recursive_text_splitter = use_recursive_text_splitter
40+
corpus = get_chunks_from_tokenizer(tree)
41+
self.retriever = bm25s.BM25(corpus=corpus)
42+
tokenized_corpus = bm25s.tokenize(corpus)
43+
self.retriever.index(tokenized_corpus)
44+
45+
def retrieve(self, query):
46+
tokenized_query = bm25s.tokenize(query)
47+
if self.top_k > len(self.retriever.corpus):
48+
results, _ = self.retriever.retrieve(
49+
query_tokens=tokenized_query, k=len(self.retriever.corpus)
50+
)
51+
else:
52+
results, _ = self.retriever.retrieve(query_tokens=tokenized_query, k=self.top_k)
53+
return [str(res) for res in results[0]]
54+
55+
def create_text_chunks(self, axtree, chunk_size=200, overlap=50):
56+
if self.use_recursive_text_splitter:
57+
from langchain_text_splitters.character import RecursiveCharacterTextSplitter
58+
59+
text_splitter = RecursiveCharacterTextSplitter()
60+
return text_splitter.split_text(axtree)
61+
else:
62+
return get_chunks_from_tokenizer(axtree, self.chunk_size, self.overlap)
63+
64+
@staticmethod
65+
def extract_bid(line):
66+
"""
67+
Extracts the bid from a line in the format '[bid] textarea ...'.
68+
69+
Parameters:
70+
line (str): The input line containing the bid in square brackets.
71+
72+
Returns:
73+
str: The extracted bid, or None if no bid is found.
74+
"""
75+
match = re.search(r"\[([a-zA-Z0-9]+)\]", line)
76+
if match:
77+
return match.group(1)
78+
return None
79+
80+
@classmethod
81+
def get_elements_around(cls, tree, element_id, n):
82+
"""
83+
Get n elements around the given element_id from the AXTree while preserving its indentation structure.
84+
85+
:param tree: String representing the AXTree with indentations.
86+
:param element_id: The element ID to center around (can include alphanumeric IDs like 'a203').
87+
:param n: The number of elements to include before and after.
88+
:return: String of the AXTree elements around the given element ID, preserving indentation.
89+
"""
90+
# Split the tree into lines
91+
lines = tree.splitlines()
92+
93+
# Extract the line indices and content containing element IDs
94+
id_lines = [(i, line) for i, line in enumerate(lines) if "[" in line and "]" in line]
95+
96+
# Parse the IDs from the lines
97+
parsed_ids = []
98+
for idx, line in id_lines:
99+
try:
100+
element_id_in_line = line.split("[")[1].split("]")[0]
101+
parsed_ids.append((idx, element_id_in_line, line))
102+
except IndexError:
103+
continue
104+
105+
# Find the index of the element with the given ID
106+
target_idx = next(
107+
(i for i, (_, eid, _) in enumerate(parsed_ids) if eid == element_id), None
108+
)
109+
110+
if target_idx is None:
111+
raise ValueError(f"Element ID {element_id} not found in the tree.")
112+
113+
# Calculate the range of elements to include
114+
start_idx = max(0, target_idx - n)
115+
end_idx = min(len(parsed_ids), target_idx + n + 1)
116+
117+
# Collect the lines to return
118+
result_lines = []
119+
for idx in range(start_idx, end_idx):
120+
line_idx = parsed_ids[idx][0]
121+
result_lines.append(lines[line_idx])
122+
123+
return "\n".join(result_lines)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from copy import copy
2+
from dataclasses import dataclass
3+
4+
from browsergym.experiments import Agent
5+
6+
import agentlab.agents.dynamic_prompting as dp
7+
from agentlab.agents.generic_agent.generic_agent import GenericAgent, GenericAgentArgs
8+
from agentlab.agents.generic_agent.generic_agent_prompt import GenericPromptFlags
9+
from agentlab.llm.chat_api import ChatModelArgs
10+
11+
from .bm25_retriever import BM25RetrieverArgs, BM25SRetriever
12+
13+
14+
@dataclass
class BM25RetrieverAgentFlags:
    """Behavior toggles specific to the BM25 retriever agent."""

    # When True, the serialized interaction history is appended to the goal
    # to form the BM25 retrieval query.
    use_history: bool = False
17+
18+
19+
@dataclass
class BM25RetrieverAgentArgs(GenericAgentArgs):
    """Arguments for constructing a :class:`BM25RetrieverAgent`.

    Extends ``GenericAgentArgs`` with the BM25 retriever configuration.
    """

    flags: GenericPromptFlags = None  # prompt flags forwarded to GenericAgent
    chat_model_args: ChatModelArgs = None  # LLM configuration
    retriever_args: BM25RetrieverArgs = None  # chunking / top-k settings
    retriever_flags: BM25RetrieverAgentFlags = None  # retrieval behavior toggles
    max_retry: int = 4  # forwarded to GenericAgent
    agent_name: str = None  # derived from the model name when not given

    def __post_init__(self):
        # NOTE(review): this overrides GenericAgentArgs.__post_init__ without
        # calling super() — confirm the parent has no required init logic.
        # Also assumes chat_model_args is set when agent_name is omitted.
        if self.agent_name is None:
            self.agent_name = f"BM25RetrieverAgent-{self.chat_model_args.model_name}".replace(
                "/", "_"
            )

    def make_agent(self) -> Agent:
        """Instantiate the configured :class:`BM25RetrieverAgent`."""
        return BM25RetrieverAgent(
            self.chat_model_args,
            self.flags,
            self.retriever_args,
            self.retriever_flags,
            self.max_retry,
        )
42+
43+
44+
class BM25RetrieverAgent(GenericAgent):
    """GenericAgent variant that prunes the AXTree with BM25 retrieval.

    Before prompting the LLM, the accessibility tree (or pruned DOM) is
    chunked and only the top-k chunks most relevant to the current goal —
    and, when enabled, the interaction history — are kept.
    """

    def __init__(
        self,
        chat_model_args: ChatModelArgs,
        flags,
        retriever_args: BM25RetrieverArgs,
        retriever_flags: BM25RetrieverAgentFlags,
        max_retry: int = 4,
    ):
        """Store retriever configuration and defer the rest to GenericAgent."""
        super().__init__(chat_model_args, flags, max_retry)
        self.retriever_args = retriever_args
        self.retriever_flags = retriever_flags

    def _build_query(self, obs: dict) -> str:
        """Build the BM25 query: the goal, plus the history when enabled."""
        if self.retriever_flags.use_history:
            return obs["goal"] + "\n" + obs["history"]
        return obs["goal"]

    def get_new_obs(self, obs: dict) -> str:
        """Return a pruned page representation: the top-k BM25 chunks."""
        query = self._build_query(obs)
        axtree_txt: str = obs["axtree_txt"] if self.flags.obs.use_ax_tree else obs["pruned_dom"]
        # A fresh retriever is built every step because the page content changes.
        retriever = BM25SRetriever(
            axtree_txt,
            chunk_size=self.retriever_args.chunk_size,
            overlap=self.retriever_args.overlap,
            top_k=self.retriever_args.top_k,
            use_recursive_text_splitter=self.retriever_args.use_recursive_text_splitter,
        )
        relevant_chunks = retriever.retrieve(query)
        return "".join(f"\n\nChunk {i}:\n{chunk}" for i, chunk in enumerate(relevant_chunks))

    def get_action(self, obs: dict):
        """Prune the observation via BM25, then delegate to GenericAgent.

        Side effects: mutates ``obs`` in place (adds ``history``, replaces
        ``axtree_txt`` with the pruned tree) and records the pruned tree and
        the retrieval query in ``info.extra_info``.
        """
        obs_history_copy = copy(self.obs_history)
        obs_history_copy.append(obs)
        history = dp.History(
            history_obs=obs_history_copy,
            actions=self.actions,
            memories=self.memories,
            thoughts=self.thoughts,
            flags=self.flags.obs,
        )
        obs["history"] = history.prompt
        obs["axtree_txt"] = self.get_new_obs(obs)
        action, info = super().get_action(obs)
        info.extra_info["pruned_tree"] = obs["axtree_txt"]
        # Fix: log the query actually used for retrieval. Previously this was
        # always goal+history, even when use_history was False.
        info.extra_info["retriever_query"] = self._build_query(obs)
        return action, info

0 commit comments

Comments
 (0)