Pull request #628: pushing brave search multi turn rl code (base: `main`)
New file (+117 lines): the `BrowseEnv` environment, added under the `browse` directory.

```python
from skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput, ConversationType
from skyrl_gym.envs.search.utils import compute_score
from skyrl_gym.tools import BraveSearchToolGroup
from skyrl_gym.tools.brave_search import DEFAULT_TIMEOUT
from typing import Any, Dict, Optional, List
from omegaconf import DictConfig
import re


class BrowseEnv(BaseTextEnv):
    """
    Environment for search execution tasks.

    Based on the Verl + Search-R1 integration.
    """

    def __init__(self, env_config: DictConfig, extras: Dict[str, Any] = {}):
        super().__init__()

        assert "reward_spec" in extras, "reward_spec field is required"
        assert "ground_truth" in extras["reward_spec"], "ground_truth is required in reward_spec field"
        self.ground_truth = extras["reward_spec"]["ground_truth"]
        self.max_turns = extras["max_turns"] if "max_turns" in extras else 2

        # Initialize the tools.
        # The name is hardcoded to "BraveSearchToolGroup", with tool name "browse".
        self.tool_group = BraveSearchToolGroup(
            search_url=env_config.get("search_url", "http://127.0.0.1:8000/retrieve"),
            topk=env_config.get("topk", 3),
            timeout=env_config.get("timeout", DEFAULT_TIMEOUT),
            log_requests=env_config.get("log_requests", True),
        )
        self.init_tool_groups([self.tool_group])

        # Chat history:
        # role (user, assistant), content (tool observation or LLM response)
        self.chat_history: ConversationType = []

    def _parse_action(self, action: str) -> List[Optional[str]]:
        match = None
        if "<search>" in action and "</search>" in action:
            match = re.search(r"<search>(.*?)</search>", action, re.DOTALL)
        return [match.group(1)] if match else [None]

    def _get_reward(self, action: str, done: bool) -> float:
        if done:
            # Concatenate all chat history into a single string and compute the reward.
            chat_history_str = "".join([item["content"] for item in self.chat_history])
            return compute_score(chat_history_str, self.ground_truth)
        else:
            # No reward for intermediate steps in search tasks.
            return 0

    def _is_done(self, action: str) -> bool:
        if self.turns >= self.max_turns:
            return True
        return "<answer>" in action and "</answer>" in action

    def _validate_action(self, action: str):
        stop_tags = ["</search>", "</answer>"]
        for tag in stop_tags:
            if tag in action:
                assert action.split(tag, 1)[1] == "", (
                    f"{tag} detected in the response but it is not the last string generated. "
                    f"Use {stop_tags} as stop strings in the configuration."
                )

    def _execute_tool(self, tool_group_name: str, tool_name: str, tool_input: Any) -> str:
        tool_output = super()._execute_tool(tool_group_name, tool_name, tool_input)
        return "\n<information>" + tool_output + "</information>\n"

    def step(self, action: str) -> BaseTextEnvStepOutput:
        self.turns += 1
        self._validate_action(action)
        self.chat_history.append({"role": "assistant", "content": action})

        error = None
        done = self._is_done(action)
        reward = self._get_reward(action, done)

        if done:
            return BaseTextEnvStepOutput(observations=[], reward=reward, done=done, metadata={})

        try:
            query = self._parse_action(action)
            observation = self._execute_tool("BraveSearchToolGroup", "browse", query)
        except Exception as e:
            error = str(e)
            observation = None

        # Wrap the observation properly as a message.
        if observation:
            new_obs = {"role": "user", "content": observation}
        elif error:
            # Give the error as the observation, if any.
            new_obs = {"role": "user", "content": error}
        else:
            new_obs = None

        info = {
            "tool_group": "BraveSearchToolGroup",
            "tool_name": "browse",
            "tool_input": query,
        }

        # Update chat history.
        if new_obs:
            self.chat_history.append(new_obs)

        return BaseTextEnvStepOutput(
            observations=[new_obs] if new_obs else [],
            reward=reward,
            done=done,
            metadata=info,
        )
```
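For orientation, here is a minimal, hypothetical driver loop for this environment. It is a sketch only: it assumes a retrieval service is running at the default `search_url`, that `BaseTextEnv` initializes `self.turns` to 0, and that `BaseTextEnvStepOutput` supports dict-style access — none of which is shown in this diff.

```python
from omegaconf import OmegaConf

# Sketch only: config keys mirror the env_config.get(...) calls above.
env_config = OmegaConf.create({"topk": 3, "log_requests": False})
extras = {"reward_spec": {"ground_truth": {"target": ["Paris"]}}, "max_turns": 3}

env = BrowseEnv(env_config, extras)

# Turn 1: the policy issues a search; the tool result comes back
# wrapped in <information>...</information> as a user message.
out = env.step("I need to look this up. <search>capital of France</search>")
print(out["observations"])

# Turn 2: the policy answers; the episode ends and the EM reward is computed.
out = env.step("<answer>Paris</answer>")
print(out["reward"], out["done"])  # 1.0 on exact match, True
```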
Contributor comment on lines +86 to +117:
The current implementation of the `step` method executes the search tool even when `_parse_action` finds no valid `<search>` call and returns `[None]`. The logic should be updated to only execute a tool if a valid tool call is parsed from the action. I've also refactored to initialize `observation` and `info` up front, so the tool metadata is only populated when a tool actually runs:

```python
query = self._parse_action(action)
observation = None
error = None  # initialized here so the `elif error` branch below is always defined
info = {}
if query[0] is not None:
    try:
        observation = self._execute_tool("BraveSearchToolGroup", "browse", query)
        info = {
            "tool_group": "BraveSearchToolGroup",
            "tool_name": "browse",
            "tool_input": query,
        }
    except Exception as e:
        error = str(e)
        observation = None

# Wrap the observation properly as a message
if observation:
    new_obs = {"role": "user", "content": observation}
elif error:
    # Give error as observation if any
    new_obs = {"role": "user", "content": error}
else:
    new_obs = None

# Update chat history
if new_obs:
    self.chat_history.append(new_obs)

return BaseTextEnvStepOutput(
    observations=[new_obs] if new_obs else [],
    reward=reward,
    done=done,
    metadata=info,
)
```
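Under this revision, a purely conversational action (no `<search>` or `<answer>` tag) produces no tool call at all. A sketch of the expected behavior, assuming the suggestion is applied and that `BaseTextEnvStepOutput` supports dict-style access (neither is shown in the diff):

```python
out = env.step("Let me reason about this before searching...")
assert out["observations"] == []  # no tool was invoked
assert out["metadata"] == {}      # info stays empty when no call is parsed
```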
New file (+118 lines): `utils.py` in the `browse` directory, adapted from Search-R1's `qa_em.py` reward scoring.
```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 Search-R1 Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/verl/utils/reward_score/qa_em.py

import re
import string


def normalize_answer(s):
    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def em_check(prediction, golden_answers):
    if isinstance(golden_answers, str):
        golden_answers = [golden_answers]
    normalized_prediction = normalize_answer(prediction)
    score = 0
    for golden_answer in golden_answers:
        golden_answer = normalize_answer(golden_answer)
        if golden_answer == normalized_prediction:
            score = 1
            break
    return score


def subem_check(prediction, golden_answers):
    if isinstance(golden_answers, str):
        golden_answers = [golden_answers]
    normalized_prediction = normalize_answer(prediction)
    score = 0
    for golden_answer in golden_answers:
        golden_answer = normalize_answer(golden_answer)
        if golden_answer in normalized_prediction:
            score = 1
            break
    return score


def extract_solution(solution_str):
    """Extract the final answer from the solution string."""
    answer_pattern = r"<answer>(.*?)</answer>"
    match = re.finditer(answer_pattern, solution_str, re.DOTALL)
    matches = list(match)

    # If there are no matches, return None
    if len(matches) < 1:
        return None

    # If there are two or more matches, return the last one
    return matches[-1].group(1).strip()


def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0):
    """The scoring function for exact match (EM).

    Args:
        solution_str: the solution text
        ground_truth: the ground truth
        method: the method to extract the solution, choices are 'strict' and 'flexible'
        format_score: the score for the format
        score: the score for the correct answer
    """
    answer = extract_solution(solution_str=solution_str)

    if answer is None:
        return 0
    else:
        if em_check(answer, ground_truth["target"]):
            return score
        else:
            return format_score


def compute_score_subem(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0):
    """The scoring function for substring exact match (EM).

    Args:
        solution_str: the solution text
        ground_truth: the ground truth
        method: the method to extract the solution, choices are 'strict' and 'flexible'
        format_score: the score for the format
        score: the score for the correct answer
    """
    answer = extract_solution(solution_str=solution_str)

    if answer is None:
        return 0
    else:
        if subem_check(answer, ground_truth["target"]):
            return score
        else:
            return format_score
```
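To make the scoring behavior concrete, a few illustrative calls (these examples are mine, not part of the diff; expected outputs are shown in comments):

```python
print(normalize_answer("The Eiffel Tower!"))  # "eiffel tower"

print(em_check("the Eiffel Tower", ["Eiffel Tower"]))                    # 1: exact match after normalization
print(subem_check("it is the Eiffel Tower in Paris", ["Eiffel Tower"]))  # 1: substring match

rollout = "I looked it up. <answer>Eiffel Tower</answer>"
print(compute_score(rollout, {"target": ["Eiffel Tower"]}))         # 1.0
print(compute_score("no tags here", {"target": ["Eiffel Tower"]}))  # 0: no <answer> block found
```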
Changes to the tools package `__init__.py`, exporting the new tool group:

```diff
@@ -1,5 +1,6 @@
 from .sql import SQLCodeExecutorToolGroup
 from .search import SearchToolGroup
 from .python import PythonCodeExecutorToolGroup
+from .brave_search import BraveSearchToolGroup

-__all__ = ["SQLCodeExecutorToolGroup", "SearchToolGroup", "PythonCodeExecutorToolGroup"]
+__all__ = ["SQLCodeExecutorToolGroup", "SearchToolGroup", "PythonCodeExecutorToolGroup", "BraveSearchToolGroup"]
```
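With the export in place, the tool group can be imported from the package root, as the environment code above already does; a quick sanity check (illustrative only):

```python
import skyrl_gym.tools as tools

# The new tool group is now part of the package's public surface.
assert "BraveSearchToolGroup" in tools.__all__
from skyrl_gym.tools import BraveSearchToolGroup
```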
Contributor: The `compute_score` function is being imported from `skyrl_gym.envs.search.utils`, but a new `utils.py` file with this function is also being added in the `browse` directory. To ensure the correct utility is used and to follow better module structure, you should import from the local `utils.py` file.
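A one-line sketch of that suggestion, assuming the new module resolves to `skyrl_gym.envs.browse.utils` (the exact package path is not shown in the diff):

```python
# In the browse environment module, replacing the current import:
from skyrl_gym.envs.browse.utils import compute_score
```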