328 changes: 153 additions & 175 deletions environments/wiki_search/wiki_search.py
@@ -1,190 +1,156 @@
import os
from typing import cast, List, Dict, Any

import chromadb
from chromadb.api.types import Embeddable, EmbeddingFunction
from chromadb.utils import embedding_functions
from datasets import load_dataset
from openai import OpenAI

import verifiers as vf
from verifiers.rubrics.judge_rubric import JudgeRubric

CHROMA_DB_DIR = ".chroma_db"


def normalize_id(text: str) -> str:
"""Normalize free text into an id: lowercased with spaces as underscores.

Mirrors the section id normalization used elsewhere in this module.
"""
"""Normalize text to create section IDs."""
return text.strip().lower().replace(" ", "_")
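# Example (illustrative only): normalize_id("Finnish Baseball") -> "finnish_baseball"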


def load_environment(
max_turns: int = 10,
judge_model: str = "gpt-4.1-mini",
judge_base_url: str = "https://api.openai.com/v1",
judge_api_key_var: str = "OPENAI_API_KEY",
embed_model: str = "text-embedding-3-small",
embed_base_url: str = "https://api.openai.com/v1",
embed_api_key_var: str = "OPENAI_API_KEY",
corpus_dataset: str = "willcb/rare-wiki-pages",
corpus_split: str = "train",
chroma_db_dir: str = CHROMA_DB_DIR,
) -> vf.Environment:
# load corpus into memory and build page_id -> row index
corpus = load_dataset(corpus_dataset, split=corpus_split)
page_id_to_title: dict[str, str] = {}
page_id_to_content: dict[str, str] = {}
for row in corpus:
row = cast(dict, row)
pid = row["id"]
title = row["title"]
content = row["content"]
page_id_to_title[pid] = title
page_id_to_content[pid] = content

# initialize persistent chroma collection with title embeddings
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
model_name=embed_model,
api_base=embed_base_url,
api_key=os.getenv(embed_api_key_var, "EMPTY"),
)
db_client = chromadb.PersistentClient(path=chroma_db_dir)
collection = db_client.get_or_create_collection(
name="wiki_titles",
embedding_function=cast(EmbeddingFunction[Embeddable], openai_ef),
)

# upsert missing pages
all_ids = list(page_id_to_title.keys())
existing: set[str] = set()
for i in range(0, len(all_ids), 500):
batch = all_ids[i : i + 500]
got = collection.get(ids=batch)
existing.update(got.get("ids", []))
missing = [pid for pid in all_ids if pid not in existing]
if missing:
documents = []
class WikiSemanticSearchEnv(vf.SemanticSearchEnv):
def __init__(
self,
corpus_dataset: str = "willcb/rare-wiki-pages",
corpus_split: str = "train",
collection_name: str = "wiki_titles",
**kwargs
):
self.corpus_dataset = corpus_dataset
self.corpus_split = corpus_split
self.page_id_to_title: Dict[str, str] = {}
self.page_id_to_content: Dict[str, str] = {}

super().__init__(collection_name=collection_name, **kwargs)

def prep_corpus(self) -> None:
corpus = load_dataset(self.corpus_dataset, split=self.corpus_split)

page_ids = []
titles = []
metadatas = []
for pid in missing:
title = str(page_id_to_title[pid]).strip()
if not title:
raise ValueError(f"Empty title for page_id {pid}")
documents.append(title)
metadatas.append({"title": title})
bs = 100
for i in range(0, len(missing), bs):
print(f"Upserting {len(missing[i : i + bs])} pages")
collection.upsert(
ids=missing[i : i + bs],
documents=documents[i : i + bs],
metadatas=metadatas[i : i + bs],
)

# define tools
def search_pages(query: str) -> list[dict]:
"""Search for top 10 relevant articles using title embedding similarity.

args:
query (str): The query to search for.

returns:
list[dict]: A list of dicts with page_id and title.

example:
"basketball" -> [{"page_id": "basketball", "title": "Basketball"}, {"page_id": "basketball_rules", "title": "Basketball Rules"}, ...]

for row in corpus:
row = cast(dict, row)
page_id = row["id"]
title = row["title"]
content = row["content"]

self.page_id_to_title[page_id] = title
self.page_id_to_content[page_id] = content

page_ids.append(page_id)
titles.append(title.strip()) # Use title as embedding text
metadatas.append({
"title": title,
})

self.upsert_documents(
document_ids=page_ids,
documents=titles,
metadatas=metadatas,
batch_size=500
)

async def search_pages(self, query: str) -> List[Dict]:
"""Search for relevant Wikipedia pages.

Args:
query: Search query

Returns:
List of dicts with page_id and title

Example:
"basketball" -> [{"page_id": "basketball", "title": "Basketball"}, ...]
"""
results = collection.query(query_texts=[query], n_results=10)
if not results:
raise ValueError(f"No results found for query: {query}")
if not results["metadatas"]:
raise ValueError(f"No results metadata found for query: {query}")
results = await self.search_documents(
query=query,
n_results=10,
return_contents=False,
return_metadata=True
)

# wiki search format
output = []
for i in range(len(results["ids"][0])):
output.append(
{
"page_id": results["ids"][0][i],
"title": results["metadatas"][0][i]["title"],
}
)

for result in results:
output.append({
"page_id": result["document_id"],
"title": result["metadata"]["title"] if result["metadata"] else "Unknown"
})

return output

def view_sections(page_id: str) -> list[dict]:
"""View the sections of a page.

args:
page_id (str): The ID of the page to view.

returns:
list[dict]: A list of dicts with section_id and section_name.

example:
async def view_sections(self, page_id: str) -> List[Dict]:
"""View sections of a Wikipedia page.
Args:
page_id: The page ID
Returns:
List of dicts with section_id and section_name
Example:
"basketball" -> [{"section_id": "basketball:history", "section_name": "History"}, ...]
"""
if page_id not in self.page_id_to_content:
raise ValueError(f"Page not found: {page_id}")

content = self.page_id_to_content[page_id]
sections = []
lines = content.split("\n")

for i, line in enumerate(lines):
if line.startswith("#"):
section_name = line.lstrip("#").strip()
section_id = f"{page_id}:{normalize_id(section_name)}"
sections.append(
{
"section_id": section_id,
"section_name": section_name,
"start_line": i,
}
)

sections.append({
"section_id": section_id,
"section_name": section_name
})

# If no sections found, return the whole page as one section
if not sections:
sections.append(
{
"section_id": f"{page_id}:full",
"section_name": "Full Page",
"start_line": 0,
}
)

return [
{"section_id": s["section_id"], "section_name": s["section_name"]}
for s in sections
]

def read_section(section_id: str) -> str:
"""Read a section of a page.

args:
section_id (str): The ID of the section to read.

returns:
str: The content of the section.

example:
"baseball:finnish_baseball" -> "Finnish baseball is a sport that is played in Finland..."
sections.append({
"section_id": f"{page_id}:full",
"section_name": "Full Page"
})

return sections

async def read_section(self, section_id: str) -> str:
"""Read content of a specific section.

Args:
section_id: Section ID (format: "page_id:section_name")

Returns:
Content of the section

Example:
"basketball:history" -> "Basketball was invented in 1891..."
"""
if ":" not in section_id:
raise ValueError("Invalid section_id format. Expected: page_id:section_name")

page_id, section_name_id = section_id.split(":", 1)

# get the page's Markdown content
if page_id not in self.page_id_to_content:
raise ValueError(f"Page not found: {page_id}")

content = self.page_id_to_content[page_id]
lines = content.split("\n")

# special case for "full" section

if section_name_id == "full":
return content

# find section

section_start = None
section_end = None

for i, line in enumerate(lines):
if line.startswith("#"):
current_section = normalize_id(line.lstrip("#").strip())
@@ -193,42 +159,54 @@ def read_section(section_id: str) -> str:
elif section_start is not None and section_end is None:
section_end = i
break

if section_start is not None:
if section_end is None:
section_end = len(lines)
return "\n".join(lines[section_start:section_end])
else:
raise ValueError(f"Section not found: {section_id}")

tools = [
search_pages,
view_sections,
read_section,
]

dataset = load_dataset("willcb/wiki-trivia-questions", split="train")
async def judge_reward_func(judge, prompt, completion, answer, state) -> float:
judge_response = await judge(prompt, completion, answer, state)
return 1.0 if "yes" in judge_response.lower() else 0.0


vf_env = vf.ToolEnv(
def load_environment(
max_turns: int = 10,
judge_model: str = "gpt-4o-mini",
judge_base_url: str = "https://api.openai.com/v1",
judge_api_key_var: str = "OPENAI_API_KEY",
**kwargs
) -> vf.Environment:
dataset = load_dataset("willcb/wiki-trivia-questions", split="train")

vf_env = WikiSemanticSearchEnv(
dataset=dataset,
tools=tools,
parser=vf.Parser(),
max_turns=max_turns,
**kwargs
)

# Use a variation of the base document search as this env's search tool:
# search_pages queries title embeddings and returns page ids and titles
# instead of document contents. We may update this later to do more
# standard document search over content embeddings.
vf_env.remove_tool(vf_env.search_documents)
vf_env.add_tool(vf_env.search_pages)
vf_env.add_tool(vf_env.view_sections)
vf_env.add_tool(vf_env.read_section)

judge_client = OpenAI(
base_url=judge_base_url,
api_key=os.getenv(judge_api_key_var)
)
judge_rubric = JudgeRubric(
judge_client=judge_client,
judge_model=judge_model,
parser=vf_env.parser
)

async def judge_reward_func(judge, prompt, completion, answer, state) -> float:
judge_response = await judge(prompt, completion, answer, state)
if "yes" in judge_response.lower():
return 1.0
else:
return 0.0


judge_rubric.add_reward_func(judge_reward_func, weight=1.0)
vf_env.rubric = vf.RubricGroup(rubrics=[judge_rubric, vf_env.rubric])

return vf_env
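
For reference, a minimal usage sketch of the refactored environment (not part of this diff). It assumes OPENAI_API_KEY is exported, that this module is importable as `wiki_search`, and that attribute names such as `tools` exist on the verifiers ToolEnv interface; treat those as assumptions rather than confirmed API.

# Hypothetical usage sketch, not part of this PR.
from wiki_search import load_environment

env = load_environment(max_turns=10, judge_model="gpt-4o-mini")
print(type(env).__name__)  # expected: WikiSemanticSearchEnv
# Tools registered above via env.add_tool(...); `tools` is an assumed attribute.
print([getattr(t, "__name__", repr(t)) for t in getattr(env, "tools", [])])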
1 change: 1 addition & 0 deletions pyproject.toml
@@ -72,6 +72,7 @@ envs = [
"brave-search",
"nltk",
"textarena",
"chromadb>=0.5.0",
]
all = [
"torch>=2.7.0",