diff --git a/environments/wiki_search/wiki_search.py b/environments/wiki_search/wiki_search.py
index fde69e790..546cfc989 100644
--- a/environments/wiki_search/wiki_search.py
+++ b/environments/wiki_search/wiki_search.py
@@ -1,190 +1,156 @@
 import os
-from typing import cast
+from typing import cast, List, Dict, Any
 
-import chromadb
-from chromadb.api.types import Embeddable, EmbeddingFunction
-from chromadb.utils import embedding_functions
 from datasets import load_dataset
 from openai import OpenAI
 
 import verifiers as vf
 from verifiers.rubrics.judge_rubric import JudgeRubric
 
-CHROMA_DB_DIR = ".chroma_db"
-
-
 def normalize_id(text: str) -> str:
-    """Normalize free text into an id: lowercased with spaces as underscores.
-
-    Mirrors the section id normalization used elsewhere in this module.
-    """
+    """Normalize text to create section IDs."""
     return text.strip().lower().replace(" ", "_")
 
 
-def load_environment(
-    max_turns: int = 10,
-    judge_model: str = "gpt-4.1-mini",
-    judge_base_url: str = "https://api.openai.com/v1",
-    judge_api_key_var: str = "OPENAI_API_KEY",
-    embed_model: str = "text-embedding-3-small",
-    embed_base_url: str = "https://api.openai.com/v1",
-    embed_api_key_var: str = "OPENAI_API_KEY",
-    corpus_dataset: str = "willcb/rare-wiki-pages",
-    corpus_split: str = "train",
-    chroma_db_dir: str = CHROMA_DB_DIR,
-) -> vf.Environment:
-    # load corpus into memory and build page_id -> row index
-    corpus = load_dataset(corpus_dataset, split=corpus_split)
-    page_id_to_title: dict[str, str] = {}
-    page_id_to_content: dict[str, str] = {}
-    for row in corpus:
-        row = cast(dict, row)
-        pid = row["id"]
-        title = row["title"]
-        content = row["content"]
-        page_id_to_title[pid] = title
-        page_id_to_content[pid] = content
-
-    # initialize persistent chroma collection with title embeddings
-    openai_ef = embedding_functions.OpenAIEmbeddingFunction(
-        model_name=embed_model,
-        api_base=embed_base_url,
-        api_key=os.getenv(embed_api_key_var, "EMPTY"),
-    )
-    db_client = chromadb.PersistentClient(path=chroma_db_dir)
-    collection = db_client.get_or_create_collection(
-        name="wiki_titles",
-        embedding_function=cast(EmbeddingFunction[Embeddable], openai_ef),
-    )
-
-    # upsert missing pages
-    all_ids = list(page_id_to_title.keys())
-    existing: set[str] = set()
-    for i in range(0, len(all_ids), 500):
-        batch = all_ids[i : i + 500]
-        got = collection.get(ids=batch)
-        existing.update(got.get("ids", []))
-    missing = [pid for pid in all_ids if pid not in existing]
-    if missing:
-        documents = []
+class WikiSemanticSearchEnv(vf.SemanticSearchEnv):
+    def __init__(
+        self,
+        corpus_dataset: str = "willcb/rare-wiki-pages",
+        corpus_split: str = "train",
+        collection_name: str = "wiki_titles",
+        **kwargs
+    ):
+        self.corpus_dataset = corpus_dataset
+        self.corpus_split = corpus_split
+        self.page_id_to_title: Dict[str, str] = {}
+        self.page_id_to_content: Dict[str, str] = {}
+
+        super().__init__(collection_name=collection_name, **kwargs)
+
+    def prep_corpus(self) -> None:
+        corpus = load_dataset(self.corpus_dataset, split=self.corpus_split)
+
+        page_ids = []
+        titles = []
         metadatas = []
-        for pid in missing:
-            title = str(page_id_to_title[pid]).strip()
-            if not title:
-                raise ValueError(f"Empty title for page_id {pid}")
-            documents.append(title)
-            metadatas.append({"title": title})
-        bs = 100
-        for i in range(0, len(missing), bs):
-            print(f"Upserting {len(missing[i : i + bs])} pages")
-            collection.upsert(
-                ids=missing[i : i + bs],
-                documents=documents[i : i + bs],
-                metadatas=metadatas[i : i + bs],
-            )
-
-    # define tools
-    def search_pages(query: str) -> list[dict]:
-        """Search for top 10 relevant articles using title embedding similarity.
-
-        args:
-            query (str): The query to search for.
-
-        returns:
-            list[dict]: A list of dicts with page_id and title.
-
-        example:
-            "basketball" -> [{"page_id": "basketball", "title": "Basketball"}, {"page_id": "basketball_rules", "title": "Basketball Rules"}, ...]
+
+        for row in corpus:
+            row = cast(dict, row)
+            page_id = row["id"]
+            title = row["title"]
+            content = row["content"]
+
+            self.page_id_to_title[page_id] = title
+            self.page_id_to_content[page_id] = content
+
+            page_ids.append(page_id)
+            titles.append(title.strip())  # Use title as embedding text
+            metadatas.append({
+                "title": title,
+            })
+
+        self.upsert_documents(
+            document_ids=page_ids,
+            documents=titles,
+            metadatas=metadatas,
+            batch_size=500
+        )
+
+    async def search_pages(self, query: str) -> List[Dict]:
+        """Search for relevant Wikipedia pages.
+
+        Args:
+            query: Search query
+
+        Returns:
+            List of dicts with page_id and title
+
+        Example:
+            "basketball" -> [{"page_id": "basketball", "title": "Basketball"}, ...]
         """
-        results = collection.query(query_texts=[query], n_results=10)
-        if not results:
-            raise ValueError(f"No results found for query: {query}")
-        if not results["metadatas"]:
-            raise ValueError(f"No results metadata found for query: {query}")
+        results = await self.search_documents(
+            query=query,
+            n_results=10,
+            return_contents=False,
+            return_metadata=True
+        )
+
+        # wiki search output format
         output = []
-        for i in range(len(results["ids"][0])):
-            output.append(
-                {
-                    "page_id": results["ids"][0][i],
-                    "title": results["metadatas"][0][i]["title"],
-                }
-            )
-
+        for result in results:
+            output.append({
+                "page_id": result["document_id"],
+                "title": result["metadata"]["title"] if result["metadata"] else "Unknown"
+            })
+
         return output
-
-    def view_sections(page_id: str) -> list[dict]:
-        """View the sections of a page.
-
-        args:
-            page_id (str): The ID of the page to view.
-
-        returns:
-            list[dict]: A list of dicts with section_id and section_name.
-
-        example:
+
+    async def view_sections(self, page_id: str) -> List[Dict]:
+        """View sections of a Wikipedia page.
+
+        Args:
+            page_id: The page ID
+
+        Returns:
+            List of dicts with section_id and section_name
+
+        Example:
             "basketball" -> [{"section_id": "basketball:history", "section_name": "History"}, ...]
         """
-        content = page_id_to_content[page_id]
+        if page_id not in self.page_id_to_content:
+            raise ValueError(f"Page not found: {page_id}")
+
+        content = self.page_id_to_content[page_id]
         sections = []
         lines = content.split("\n")
+
        for i, line in enumerate(lines):
             if line.startswith("#"):
                 section_name = line.lstrip("#").strip()
                 section_id = f"{page_id}:{normalize_id(section_name)}"
-                sections.append(
-                    {
-                        "section_id": section_id,
-                        "section_name": section_name,
-                        "start_line": i,
-                    }
-                )
-
-        # if no sections found, return the whole page as one section
+                sections.append({
+                    "section_id": section_id,
+                    "section_name": section_name
+                })
+
+        # If no sections, return whole page
         if not sections:
-            sections.append(
-                {
-                    "section_id": f"{page_id}:full",
-                    "section_name": "Full Page",
-                    "start_line": 0,
-                }
-            )
-
-        return [
-            {"section_id": s["section_id"], "section_name": s["section_name"]}
-            for s in sections
-        ]
-
-    def read_section(section_id: str) -> str:
-        """Read a section of a page.
-
-        args:
-            section_id (str): The ID of the section to read.
-
-        returns:
-            str: The content of the section.
-
-        example:
-            "baseball:finnish_baseball" -> "Finnish baseball is a sport that is played in Finland..."
+            sections.append({
+                "section_id": f"{page_id}:full",
+                "section_name": "Full Page"
+            })
+
+        return sections
+
+    async def read_section(self, section_id: str) -> str:
+        """Read content of a specific section.
+
+        Args:
+            section_id: Section ID (format: "page_id:section_name")
+
+        Returns:
+            Content of the section
+
+        Example:
+            "basketball:history" -> "Basketball was invented in 1891..."
         """
         if ":" not in section_id:
-            raise ValueError(
-                "Invalid section_id format. Expected: page_id:section_name"
-            )
-
+            raise ValueError("Invalid section_id format. Expected: page_id:section_name")
+
         page_id, section_name_id = section_id.split(":", 1)
-
-        # get Markdown content
-        content = page_id_to_content[page_id]
+
+        if page_id not in self.page_id_to_content:
+            raise ValueError(f"Page not found: {page_id}")
+
+        content = self.page_id_to_content[page_id]
         lines = content.split("\n")
-
-        # special case for "full" section
+
         if section_name_id == "full":
             return content
-
-        # find section
+
         section_start = None
         section_end = None
-
+
         for i, line in enumerate(lines):
             if line.startswith("#"):
                 current_section = normalize_id(line.lstrip("#").strip())
@@ -193,7 +159,7 @@ def read_section(section_id: str) -> str:
                 elif section_start is not None and section_end is None:
                     section_end = i
                     break
-
+
         if section_start is not None:
             if section_end is None:
                 section_end = len(lines)
@@ -201,34 +167,46 @@
         else:
             raise ValueError(f"Section not found: {section_id}")
 
-    tools = [
-        search_pages,
-        view_sections,
-        read_section,
-    ]
-
-    dataset = load_dataset("willcb/wiki-trivia-questions", split="train")
+async def judge_reward_func(judge, prompt, completion, answer, state) -> float:
+    judge_response = await judge(prompt, completion, answer, state)
+    return 1.0 if "yes" in judge_response.lower() else 0.0
+
 
-    vf_env = vf.ToolEnv(
+def load_environment(
+    max_turns: int = 10,
+    judge_model: str = "gpt-4o-mini",
+    judge_base_url: str = "https://api.openai.com/v1",
+    judge_api_key_var: str = "OPENAI_API_KEY",
+    **kwargs
+) -> vf.Environment:
+    dataset = load_dataset("willcb/wiki-trivia-questions", split="train")
+
+    vf_env = WikiSemanticSearchEnv(
         dataset=dataset,
-        tools=tools,
-        parser=vf.Parser(),
         max_turns=max_turns,
+        **kwargs
     )
 
-    judge_client = OpenAI(base_url=judge_base_url, api_key=os.getenv(judge_api_key_var))
+    # use a variation of search_documents: the original env searches over
+    # page titles and returns titles rather than document contents; maybe
+    # update this later to do more standard document search via content embeddings
+    vf_env.remove_tool(vf_env.search_documents)
+    vf_env.add_tool(vf_env.search_pages)
+    vf_env.add_tool(vf_env.view_sections)
+    vf_env.add_tool(vf_env.read_section)
+
+    judge_client = OpenAI(
+        base_url=judge_base_url,
+        api_key=os.getenv(judge_api_key_var)
+    )
     judge_rubric = JudgeRubric(
-        judge_client=judge_client, judge_model=judge_model, parser=vf_env.parser
+        judge_client=judge_client,
+        judge_model=judge_model,
+        parser=vf_env.parser
     )
-
-    async def judge_reward_func(judge, prompt, completion, answer, state) -> float:
-        judge_response = await judge(prompt, completion, answer, state)
-        if "yes" in judge_response.lower():
-            return 1.0
-        else:
-            return 0.0
-
+
     judge_rubric.add_reward_func(judge_reward_func, weight=1.0)
     vf_env.rubric = vf.RubricGroup(rubrics=[judge_rubric, vf_env.rubric])
-
+
     return vf_env
diff --git a/pyproject.toml b/pyproject.toml
index e29e4c2f6..9089aacfe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,6 +72,7 @@ envs = [
     "brave-search",
     "nltk",
     "textarena",
+    "chromadb>=0.5.0",
 ]
 all = [
     "torch>=2.7.0",
diff --git a/verifiers/__init__.py b/verifiers/__init__.py
index 9433e2bc3..3b988ece0 100644
--- a/verifiers/__init__.py
+++ b/verifiers/__init__.py
@@ -12,6 +12,7 @@
 from .envs.singleturn_env import SingleTurnEnv
 from .envs.stateful_tool_env import StatefulToolEnv
 from .envs.tool_env import ToolEnv
+from .envs.semantic_search_env import SemanticSearchEnv
 from .parsers.parser import Parser
 from .parsers.think_parser import ThinkParser
 from .parsers.xml_parser import XMLParser
@@ -82,6 +83,7 @@ def setup_logging(
     "SandboxEnv",
     "StatefulToolEnv",
     "ToolEnv",
+    "SemanticSearchEnv",
     "EnvGroup",
     "extract_boxed_answer",
     "extract_hash_answer",
diff --git a/verifiers/envs/semantic_search_env.py b/verifiers/envs/semantic_search_env.py
new file mode 100644
index 000000000..93c79b80e
--- /dev/null
+++ b/verifiers/envs/semantic_search_env.py
@@ -0,0 +1,163 @@
+import os
+from abc import ABC, abstractmethod
+from typing import Optional, Any, Dict, List
+
+import chromadb
+from chromadb.utils import embedding_functions
+
+import verifiers as vf
+
+
+class SemanticSearchEnv(vf.ToolEnv):
+    def __init__(
+        self,
+        collection_name: str = "documents",
+        embed_model: str = "text-embedding-3-small",
+        embed_base_url: str = "https://api.openai.com/v1",
+        embed_api_key_var: str = "OPENAI_API_KEY",
+        chroma_db_dir: str = ".chroma_db",
+        chroma_server_port: int = 8000,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.collection_name = collection_name
+        self.chroma_db_dir = chroma_db_dir
+        self.emb_fn = embedding_functions.OpenAIEmbeddingFunction(
+            model_name=embed_model,
+            api_base=embed_base_url,
+            api_key=os.getenv(embed_api_key_var, "EMPTY")
+        )
+        self.setup_vector_db()
+        self.prep_corpus()
+        self.async_client = None
+
+        self.chroma_server_port = chroma_server_port
+        self.check_server_running()
+
+        self.add_tool(self.search_documents)
+
+    @abstractmethod
+    def prep_corpus(self) -> None:
+        """Prepare and load the corpus into the vector database.
+
+        This method should:
+        1. Load data from the source
+        2. Upsert documents into the ChromaDB collection
+        """
+        raise NotImplementedError("Subclasses must implement prep_corpus")
+
+    def setup_vector_db(self):
+        """Persistent client for creating and adding documents during init"""
+        self.setup_client = chromadb.PersistentClient(path=self.chroma_db_dir)
+        self.collection = self.setup_client.get_or_create_collection(
+            name=self.collection_name,
+            embedding_function=self.emb_fn
+        )
+
+    async def get_async_client(self):
+        """Async client for querying via client-server"""
+        if self.async_client is None:
+            self.async_client = await chromadb.AsyncHttpClient(
+                host="localhost",
+                port=self.chroma_server_port
+            )
+        return self.async_client
+
+    def check_server_running(self):
+        try:
+            client = chromadb.HttpClient(host="localhost", port=self.chroma_server_port)
+            client.heartbeat()
+        except Exception:
+            raise RuntimeError(
+                f"ChromaDB server is not running at localhost:{self.chroma_server_port}. "
+                f"Please start the server: chroma run --path {self.chroma_db_dir}"
+            ) from None
+
+    # using upsert as that's what the original wiki search env used;
+    # this could probably just be .add since we only insert new documents,
+    # but upsert is specifically for updating existing documents
+
+    def upsert_documents(
+        self,
+        document_ids: List[str],
+        documents: List[str],
+        metadatas: Optional[List[Dict[str, Any]]] = None,
+        batch_size: int = 500
+    ) -> None:
+        all_ids = document_ids
+        existing: set[str] = set()
+
+        for i in range(0, len(all_ids), batch_size):
+            batch = all_ids[i : i + batch_size]
+            got = self.collection.get(ids=batch)
+            existing.update(got.get("ids", []))
+
+        # Filter to only new documents
+        to_upsert_indices = [i for i, doc_id in enumerate(document_ids) if doc_id not in existing]
+
+        if to_upsert_indices:
+            to_upsert_ids = [document_ids[i] for i in to_upsert_indices]
+            to_upsert_docs = [documents[i] for i in to_upsert_indices]
+            to_upsert_meta = [metadatas[i] for i in to_upsert_indices] if metadatas else None
+
+            # Upsert in batches
+            for i in range(0, len(to_upsert_ids), batch_size):
+                end_idx = min(i + batch_size, len(to_upsert_ids))
+                print(f"Upserting {end_idx - i} documents...")
+
+                upsert_args = {
+                    "ids": to_upsert_ids[i:end_idx],
+                    "documents": to_upsert_docs[i:end_idx],
+                }
+
+                if to_upsert_meta:
+                    upsert_args["metadatas"] = to_upsert_meta[i:end_idx]
+
+                self.collection.upsert(**upsert_args)
+
+    ##########################
+    ## SEMANTIC SEARCH TOOL ##
+    ##########################
+
+    async def search_documents(
+        self,
+        query: str,
+        return_contents: bool = True,
+        return_metadata: bool = True,
+    ) -> list[dict]:
+        """Search for relevant documents using embedding similarity.
+
+        Args:
+            query (str): The query to search for.
+
+        Returns:
+            list[dict]: A list of dicts with document_id and, optionally, content and/or metadata.
+        """
+        include = []
+        if return_contents:
+            include.append("documents")
+        if return_metadata:
+            include.append("metadatas")
+        if not include:
+            include = None
+
+        async_client = await self.get_async_client()
+        collection = await async_client.get_collection(
+            name=self.collection_name,
+            embedding_function=self.emb_fn
+        )
+        results = await collection.query(
+            query_texts=[query],
+            include=include,
+            n_results=10
+        )
+        if not results:
+            raise ValueError(f"No results found for query: {query}")
+        output = []
+        for i in range(len(results["ids"][0])):
+            output.append({
+                "document_id": results["ids"][0][i],
+                "metadata": results["metadatas"][0][i] if return_metadata else None,
+                "content": results["documents"][0][i] if return_contents else None,
+            })
+        return output
\ No newline at end of file