5 changes: 5 additions & 0 deletions skyrl-gym/skyrl_gym/envs/__init__.py
@@ -27,6 +27,11 @@
entry_point="skyrl_gym.envs.search.env:SearchEnv",
)

register(
id="browse",
entry_point="skyrl_gym.envs.browse.env:BrowseEnv",
)

register(
id="lcb",
entry_point="skyrl_gym.envs.lcb.env:LCBEnv",
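For orientation, register here follows the gymnasium-style registry pattern: the id string is lazily resolved to the entry-point class on construction. A minimal sketch of building the new environment through the registry, assuming skyrl_gym exposes a gymnasium-like make helper (the make name and all config values below are illustrative assumptions, not confirmed API):

    from omegaconf import DictConfig
    import skyrl_gym

    # Hypothetical: assumes a gymnasium-style make() that resolves the "browse"
    # id registered above to skyrl_gym.envs.browse.env:BrowseEnv.
    env = skyrl_gym.make(
        "browse",
        env_config=DictConfig({"search_url": "http://127.0.0.1:8000/retrieve", "topk": 3}),
        extras={"reward_spec": {"ground_truth": {"target": "Paris"}}, "max_turns": 2},
    )
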
117 changes: 117 additions & 0 deletions skyrl-gym/skyrl_gym/envs/browse/env.py
@@ -0,0 +1,117 @@
from skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput, ConversationType
from typing import Any
from skyrl_gym.envs.search.utils import compute_score
Contributor review comment (severity: high):

The compute_score function is imported from skyrl_gym.envs.search.utils, but a new utils.py providing the same function is added in the browse directory below. To ensure the correct utility is used and to keep the module self-contained, import from the local utils.py instead.

Suggested change:
-from skyrl_gym.envs.search.utils import compute_score
+from .utils import compute_score
from skyrl_gym.tools import BraveSearchToolGroup
from skyrl_gym.tools.brave_search import DEFAULT_TIMEOUT
import re
from typing import Dict, Optional, List
from omegaconf import DictConfig


class BrowseEnv(BaseTextEnv):
    """
    Environment for Search execution tasks.

    Based on Verl + Search-R1 integration
    """

    def __init__(self, env_config: DictConfig, extras: Dict[str, Any] = {}):
        super().__init__()

        assert "reward_spec" in extras, "reward_spec field is required"
        assert "ground_truth" in extras["reward_spec"], "ground_truth is required in reward_spec field"
        self.ground_truth = extras["reward_spec"]["ground_truth"]
        self.max_turns = extras["max_turns"] if "max_turns" in extras else 2

        # Initialize the tools
        # name is hardcoded to "BraveSearchToolGroup", with tool name "browse"
        self.tool_group = BraveSearchToolGroup(
            search_url=env_config.get("search_url", "http://127.0.0.1:8000/retrieve"),
            topk=env_config.get("topk", 3),
            timeout=env_config.get("timeout", DEFAULT_TIMEOUT),
            log_requests=env_config.get("log_requests", True),
        )
        self.init_tool_groups([self.tool_group])

        # Chat history
        # role (user, assistant), content (tool observation or LLM response)
        self.chat_history: ConversationType = []

    def _parse_action(self, action: str) -> List[Optional[str]]:
        match = None
        if "<search>" in action and "</search>" in action:
            match = re.search(r"<search>(.*?)</search>", action, re.DOTALL)
        return [match.group(1)] if match else [None]

    def _get_reward(self, action: str, done: bool) -> float:
        if done:
            # Concat all chat history into a single string and compute reward
            chat_history_str = "".join([item["content"] for item in self.chat_history])
            return compute_score(chat_history_str, self.ground_truth)
        else:
            # No reward for intermediate steps for Search tasks
            return 0

    def _is_done(self, action: str) -> bool:
        if self.turns >= self.max_turns:
            return True
        return "<answer>" in action and "</answer>" in action

    def _validate_action(self, action: str):
        stop_tags = ["</search>", "</answer>"]
        for tag in stop_tags:
            if tag in action:
                assert action.split(tag, 1)[1] == "", (
                    f"{tag} detected in the response but it is not the last string generated. "
                    f"Use {stop_tags} as stop strings in the configuration."
                )

    def _execute_tool(self, tool_group_name: str, tool_name: str, tool_input: Any) -> str:
        tool_output = super()._execute_tool(tool_group_name, tool_name, tool_input)

        return "\n<information>" + tool_output + "</information>\n"

    def step(self, action: str) -> BaseTextEnvStepOutput:
        self.turns += 1
        self._validate_action(action)
        self.chat_history.append({"role": "assistant", "content": action})

        error = None
        done = self._is_done(action)
        reward = self._get_reward(action, done)

        if done:
            return BaseTextEnvStepOutput(observations=[], reward=reward, done=done, metadata={})

        try:
            query = self._parse_action(action)
            observation = self._execute_tool("BraveSearchToolGroup", "browse", query)
        except Exception as e:
            error = str(e)
            observation = None

        # Wrap the observation properly as a message
        if observation:
            new_obs = {"role": "user", "content": observation}
        elif error:
            # Give error as observation if any
            new_obs = {"role": "user", "content": error}
        else:
            new_obs = None

        info = {
            "tool_group": "BraveSearchToolGroup",
            "tool_name": "browse",
            "tool_input": query,
        }

        # Update chat history
        if new_obs:
            self.chat_history.append(new_obs)

        return BaseTextEnvStepOutput(
            observations=[new_obs] if new_obs else [],
            reward=reward,
            done=done,
            metadata=info,
        )
Contributor review comment on lines +86 to +117 (severity: high):

The current implementation of the step method has a bug: it raises an exception when the model's action contains no <search> tag, because _parse_action returns [None] and _execute_tool is then called with None as input, which is not handled. The agent should be able to produce a response without calling a tool.

The logic should be updated to only execute a tool when a valid tool call is parsed from the action. The refactor below also initializes info only when a tool is actually called.

        query = self._parse_action(action)
        observation = None
        info = {}

        if query[0] is not None:
            try:
                observation = self._execute_tool("BraveSearchToolGroup", "browse", query)
                info = {
                    "tool_group": "BraveSearchToolGroup",
                    "tool_name": "browse",
                    "tool_input": query,
                }
            except Exception as e:
                error = str(e)
                observation = None

        # Wrap the observation properly as a message
        if observation:
            new_obs = {"role": "user", "content": observation}
        elif error:
            # Give error as observation if any
            new_obs = {"role": "user", "content": error}
        else:
            new_obs = None

        # Update chat history
        if new_obs:
            self.chat_history.append(new_obs)

        return BaseTextEnvStepOutput(
            observations=[new_obs] if new_obs else [],
            reward=reward,
            done=done,
            metadata=info,
        )
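
To make the flagged failure mode concrete, here is a minimal driving loop for BrowseEnv as written. This is a sketch under stated assumptions: a retrieval service reachable at the default search_url, BaseTextEnv initializing self.turns to 0, and illustrative config values throughout.

    from omegaconf import DictConfig
    from skyrl_gym.envs.browse.env import BrowseEnv

    env = BrowseEnv(
        env_config=DictConfig({"search_url": "http://127.0.0.1:8000/retrieve"}),
        extras={"reward_spec": {"ground_truth": {"target": "Paris"}}, "max_turns": 3},
    )

    # Well-formed tool call: the query inside <search> tags is parsed and executed.
    out = env.step("Let me look this up. <search>capital of France</search>")

    # No tool call: _parse_action returns [None] and _execute_tool is still invoked
    # with that input -- the case the review comment above flags.
    out = env.step("Thinking out loud, with no search and no answer.")

    # Final answer: _is_done returns True and the reward is computed over the
    # accumulated chat history via compute_score.
    out = env.step("<answer>Paris</answer>")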

118 changes: 118 additions & 0 deletions skyrl-gym/skyrl_gym/envs/browse/utils.py
@@ -0,0 +1,118 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 Search-R1 Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/verl/utils/reward_score/qa_em.py

import re
import string


def normalize_answer(s):
    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def em_check(prediction, golden_answers):
    if isinstance(golden_answers, str):
        golden_answers = [golden_answers]
    normalized_prediction = normalize_answer(prediction)
    score = 0
    for golden_answer in golden_answers:
        golden_answer = normalize_answer(golden_answer)
        if golden_answer == normalized_prediction:
            score = 1
            break
    return score


def subem_check(prediction, golden_answers):
    if isinstance(golden_answers, str):
        golden_answers = [golden_answers]
    normalized_prediction = normalize_answer(prediction)
    score = 0
    for golden_answer in golden_answers:
        golden_answer = normalize_answer(golden_answer)
        if golden_answer in normalized_prediction:
            score = 1
            break
    return score


def extract_solution(solution_str):
    """Extract the final answer from the solution string."""
    answer_pattern = r"<answer>(.*?)</answer>"
    match = re.finditer(answer_pattern, solution_str, re.DOTALL)
    matches = list(match)

    # If there are no matches, return None
    if len(matches) < 1:
        return None

    # Otherwise, return the last match
    return matches[-1].group(1).strip()
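
A quick illustration of how these helpers compose (the values in the comments follow directly from the definitions above; the inputs are made up for illustration):

    print(normalize_answer("The Eiffel Tower!"))                   # "eiffel tower"
    print(em_check("The Eiffel Tower!", "eiffel tower"))           # 1: exact match after normalization
    print(em_check("It is the Eiffel Tower", "eiffel tower"))      # 0: extra words break exact match
    print(subem_check("It is the Eiffel Tower", "eiffel tower"))   # 1: substring match suffices
    print(extract_solution("<answer>Paris</answer> <answer>Lyon</answer>"))  # "Lyon": last match wins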


def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0):
Contributor review comment (severity: medium):

The method parameter is defined but not used within this function. It should be removed to avoid confusion.

Suggested change:
-def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0):
+def compute_score(solution_str, ground_truth, format_score=0.0, score=1.0):

"""The scoring function for exact match (EM).

Args:
solution_str: the solution text
ground_truth: the ground truth
method: the method to extract the solution, choices are 'strict' and 'flexible'
format_score: the score for the format
score: the score for the correct answer
"""
answer = extract_solution(solution_str=solution_str)

if answer is None:
return 0
else:
if em_check(answer, ground_truth["target"]):
return score
else:
return format_score


def compute_score_subem(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0):
Contributor review comment (severity: medium):

The method parameter is defined but not used within this function. It should be removed to avoid confusion.

Suggested change:
-def compute_score_subem(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0):
+def compute_score_subem(solution_str, ground_truth, format_score=0.0, score=1.0):

"""The scoring function for substring exact match (EM).

Args:
solution_str: the solution text
ground_truth: the ground truth
method: the method to extract the solution, choices are 'strict' and 'flexible'
format_score: the score for the format
score: the score for the correct answer
"""
answer = extract_solution(solution_str=solution_str)

if answer is None:
return 0
else:
if subem_check(answer, ground_truth["target"]):
return score
else:
return format_score
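
Both scorers expect ground_truth to be a mapping with a "target" key, which matches the shape BrowseEnv forwards from extras["reward_spec"]["ground_truth"]. A quick sanity check with illustrative values:

    traj = "<information>...</information>\n<answer>The Eiffel Tower</answer>"
    print(compute_score(traj, {"target": "eiffel tower"}))   # 1.0: EM after normalization
    print(compute_score_subem(traj, {"target": "tower"}))    # 1.0: substring match
    print(compute_score("no answer tags", {"target": "x"}))  # 0: no <answer> block found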
3 changes: 2 additions & 1 deletion skyrl-gym/skyrl_gym/tools/__init__.py
@@ -1,5 +1,6 @@
 from .sql import SQLCodeExecutorToolGroup
 from .search import SearchToolGroup
 from .python import PythonCodeExecutorToolGroup
+from .brave_search import BraveSearchToolGroup

-__all__ = ["SQLCodeExecutorToolGroup", "SearchToolGroup", "PythonCodeExecutorToolGroup"]
+__all__ = ["SQLCodeExecutorToolGroup", "SearchToolGroup", "PythonCodeExecutorToolGroup", "BraveSearchToolGroup"]