Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ dev = [
hint = [
"sentence-transformers>=5.0.0",
]
retrievers = [
"bm25s>=0.2.14",
"langchain>=0.3.27",
]


[project.scripts]
Expand Down
89 changes: 89 additions & 0 deletions src/agentlab/agents/bm25_agent/README.md
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have the the bm25_agent and embedding_agent in the focus_agent subdirectory, as they are related baselines. @recursix Do you have any thoughts about this?

Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# BM25Agent

A retrieval-augmented agent that uses BM25 (Best Matching 25) algorithm to filter and retrieve the most relevant parts of the accessibility tree (AXTree) based on the current goal and task history.

## Overview

``BM25Agent`` extends ``GenericAgent`` with intelligent content retrieval capabilities. Instead of processing the entire accessibility tree, it chunks the content and uses BM25 ranking to retrieve only the most relevant sections, reducing token usage and improving focus on task-relevant elements.

## Key Features

- **BM25-based retrieval**: Uses the BM25 algorithm to rank and retrieve relevant content chunks
- **Token-aware chunking**: Splits accessibility trees using tiktoken for optimal token usage
- **Configurable parameters**: Adjustable chunk size, overlap, and top-k retrieval
- **History integration**: Can optionally include task history in retrieval queries
- **Memory efficient**: Reduces context size by filtering irrelevant content

## Architecture

```text
Query (goal + history) → BM25 Retriever → Top-K Chunks → LLM → Action
AXTree
```

## Usage

### Basic Configuration

```python
from agentlab.agents.bm25_agent import BM25RetrieverAgent, BM25RetrieverAgentArgs
from agentlab.agents.bm25_agent.bm25_retriever import BM25RetrieverArgs
from agentlab.agents.bm25_agent.bm25_retriever_agent import BM25RetrieverAgentFlags

# Configure retriever parameters
retriever_args = BM25RetrieverArgs(
chunk_size=200, # Tokens per chunk
overlap=10, # Token overlap between chunks
top_k=10, # Number of chunks to retrieve
use_recursive_text_splitter=False # Use Langchain text splitter
)

# Configure agent flags
retriever_flags = BM25RetrieverAgentFlags(
use_history=True # Include task history in queries
)

# Create agent
agent_args = BM25RetrieverAgentArgs(
chat_model_args=your_chat_model_args,
flags=your_flags,
retriever_args=retriever_args,
retriever_flags=retriever_flags
)

agent = agent_args.make_agent()
```

### Pre-configured Agents

```python
from agentlab.agents.bm25_agent.agent_configs import (
BM25_RETRIEVER_AGENT, # Chunk size is 200 tokens
BM25_RETRIEVER_AGENT_100 # Chunk size is 100 tokens
)

# Use default configuration
agent = BM25_RETRIEVER_AGENT.make_agent()
```

## Configuration Parameters

### BM25RetrieverArgs

- `chunk_size` (int, default=100): Number of tokens per chunk
- `overlap` (int, default=10): Token overlap between consecutive chunks
- `top_k` (int, default=5): Number of most relevant chunks to retrieve
- `use_recursive_text_splitter` (bool, default=False): Use LangChain's recursive text splitter. Using this text splitter will override the ``chunk_size`` an ``overlap`` parameters.

### BM25RetrieverAgentFlags

- `use_history` (bool, default=False): Include interaction history in retrieval queries

## Citation

If you use this agent in your work, please consider citing:

```bibtex

```
8 changes: 8 additions & 0 deletions src/agentlab/agents/bm25_agent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .bm25_retriever_agent import BM25RetrieverAgent, BM25RetrieverAgentArgs
from .bm25_retriever import BM25RetrieverArgs
from .agent_configs import (
BM25_RETRIEVER_AGENT,
BM25_RETRIEVER_AGENT_100,
BM25_RETRIEVER_AGENT_50,
BM25_RETRIEVER_AGENT_500,
)
68 changes: 68 additions & 0 deletions src/agentlab/agents/bm25_agent/agent_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_4o
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

from .bm25_retriever import BM25RetrieverArgs
from .bm25_retriever_agent import BM25RetrieverAgentArgs, BM25RetrieverAgentFlags

FLAGS_GPT_4o = FLAGS_GPT_4o.copy()
FLAGS_GPT_4o.obs.use_think_history = True

BM25_RETRIEVER_AGENT = BM25RetrieverAgentArgs(
agent_name="BM25RetrieverAgent-4.1",
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
flags=FLAGS_GPT_4o,
retriever_args=BM25RetrieverArgs(
top_k=10,
chunk_size=200,
overlap=10,
use_recursive_text_splitter=False,
),
retriever_flags=BM25RetrieverAgentFlags(
use_history=True,
),
)

BM25_RETRIEVER_AGENT_100 = BM25RetrieverAgentArgs(
agent_name="BM25RetrieverAgent-4.1-100",
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
flags=FLAGS_GPT_4o,
retriever_args=BM25RetrieverArgs(
top_k=10,
chunk_size=100,
overlap=10,
use_recursive_text_splitter=False,
),
retriever_flags=BM25RetrieverAgentFlags(
use_history=True,
),
)

BM25_RETRIEVER_AGENT_50 = BM25RetrieverAgentArgs(
agent_name="BM25RetrieverAgent-4.1-50",
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
flags=FLAGS_GPT_4o,
retriever_args=BM25RetrieverArgs(
top_k=10,
chunk_size=50,
overlap=5,
use_recursive_text_splitter=False,
),
retriever_flags=BM25RetrieverAgentFlags(
use_history=True,
),
)

BM25_RETRIEVER_AGENT_500 = BM25RetrieverAgentArgs(
agent_name="BM25RetrieverAgent-4.1-500",
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"],
flags=FLAGS_GPT_4o,
retriever_args=BM25RetrieverArgs(
top_k=10,
chunk_size=500,
overlap=10,
use_recursive_text_splitter=False,
),
retriever_flags=BM25RetrieverAgentFlags(
use_history=True,
),
)
133 changes: 133 additions & 0 deletions src/agentlab/agents/bm25_agent/bm25_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import re
from dataclasses import dataclass

try:
import bm25s
except ImportError:
raise ImportError("bm25s is not installed. Please install it using `pip agentlab[retrievers]`.")
import tiktoken # Added import for tiktoken

from .utils import get_chunks_from_tokenizer


def count_tokens(text: str) -> int:
"""Count the number of tokens in the text using tiktoken for GPT-4."""
encoding = tiktoken.encoding_for_model("gpt-4")
tokens = encoding.encode(text)
return len(tokens)


@dataclass
class BM25RetrieverArgs:
chunk_size: int = 100
overlap: int = 10
top_k: int = 5
use_recursive_text_splitter: bool = False


class BM25SRetriever:
"""Simple retriever using BM25S to retrieve the most relevant lines"""

def __init__(
self,
tree: str,
chunk_size: int,
overlap: int,
top_k: int,
use_recursive_text_splitter: bool,
):
self.chunk_size = chunk_size
self.overlap = overlap
self.top_k = top_k
self.use_recursive_text_splitter = use_recursive_text_splitter
corpus = get_chunks_from_tokenizer(tree)
self.retriever = bm25s.BM25(corpus=corpus)
tokenized_corpus = bm25s.tokenize(corpus)
self.retriever.index(tokenized_corpus)

def retrieve(self, query):
tokenized_query = bm25s.tokenize(query)
if self.top_k > len(self.retriever.corpus):
results, _ = self.retriever.retrieve(
query_tokens=tokenized_query, k=len(self.retriever.corpus)
)
else:
results, _ = self.retriever.retrieve(query_tokens=tokenized_query, k=self.top_k)
return [str(res) for res in results[0]]

def create_text_chunks(self, axtree, chunk_size=200, overlap=50):
if self.use_recursive_text_splitter:
try:
from langchain.text_splitter import (
RecursiveCharacterTextSplitter,
)
except ImportError:
raise ImportError(
"langchain is not installed. Please install it using `pip agentlab[retrievers]`."
)

text_splitter = RecursiveCharacterTextSplitter()
return text_splitter.split_text(axtree)
else:
return get_chunks_from_tokenizer(axtree, self.chunk_size, self.overlap)

@staticmethod
def extract_bid(line):
"""
Extracts the bid from a line in the format '[bid] textarea ...'.

Parameters:
line (str): The input line containing the bid in square brackets.

Returns:
str: The extracted bid, or None if no bid is found.
"""
match = re.search(r"\[([a-zA-Z0-9]+)\]", line)
if match:
return match.group(1)
return None

@classmethod
def get_elements_around(cls, tree, element_id, n):
"""
Get n elements around the given element_id from the AXTree while preserving its indentation structure.

:param tree: String representing the AXTree with indentations.
:param element_id: The element ID to center around (can include alphanumeric IDs like 'a203').
:param n: The number of elements to include before and after.
:return: String of the AXTree elements around the given element ID, preserving indentation.
"""
# Split the tree into lines
lines = tree.splitlines()

# Extract the line indices and content containing element IDs
id_lines = [(i, line) for i, line in enumerate(lines) if "[" in line and "]" in line]

# Parse the IDs from the lines
parsed_ids = []
for idx, line in id_lines:
try:
element_id_in_line = line.split("[")[1].split("]")[0]
parsed_ids.append((idx, element_id_in_line, line))
except IndexError:
continue

# Find the index of the element with the given ID
target_idx = next(
(i for i, (_, eid, _) in enumerate(parsed_ids) if eid == element_id), None
)

if target_idx is None:
raise ValueError(f"Element ID {element_id} not found in the tree.")

# Calculate the range of elements to include
start_idx = max(0, target_idx - n)
end_idx = min(len(parsed_ids), target_idx + n + 1)

# Collect the lines to return
result_lines = []
for idx in range(start_idx, end_idx):
line_idx = parsed_ids[idx][0]
result_lines.append(lines[line_idx])

return "\n".join(result_lines)
Loading
Loading