Skip to content

Conversation

@pranavraja99
Copy link

No description provided.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new BrowseEnv for multi-turn reinforcement learning using Brave Search, along with the necessary tool integrations and an example script. The overall structure is good, but there are several critical issues that need addressing. The BraveSearchToolGroup implementation incorrectly handles API requests by not using the configured session or URL, and its retry logic is flawed. The new BrowseEnv has a bug in its step method that can cause crashes when no tool is called. Additionally, there are some unused parameters and leftover debugging print statements that should be cleaned up. I've provided specific comments and suggestions to resolve these issues.

Comment on lines +69 to +240
def _search(
self,
keywords: str,
max_results: Optional[int] = 10,
region: Optional[str] = "wt-wt",
) -> Union[List[Dict], Dict]:
"""
Queries the Brave Search API for the provided keywords and region.
Args:
keywords (str): The keywords to search for.
max_results (int, optional): The maximum number of search results to return. Defaults to 10.
region (str, optional): The region to search in. Defaults to "wt-wt". Possible values include:
- xa-ar for Arabia
- xa-en for Arabia (en)
- ar-es for Argentina
- au-en for Australia
- at-de for Austria
- be-fr for Belgium (fr)
- be-nl for Belgium (nl)
- br-pt for Brazil
- bg-bg for Bulgaria
- ca-en for Canada
- ca-fr for Canada (fr)
- ct-ca for Catalan
- cl-es for Chile
- cn-zh for China
- co-es for Colombia
- hr-hr for Croatia
- cz-cs for Czech Republic
- dk-da for Denmark
- ee-et for Estonia
- fi-fi for Finland
- fr-fr for France
- de-de for Germany
- gr-el for Greece
- hk-tzh for Hong Kong
- hu-hu for Hungary
- in-en for India
- id-id for Indonesia
- id-en for Indonesia (en)
- ie-en for Ireland
- il-he for Israel
- it-it for Italy
- jp-jp for Japan
- kr-kr for Korea
- lv-lv for Latvia
- lt-lt for Lithuania
- xl-es for Latin America
- my-ms for Malaysia
- my-en for Malaysia (en)
- mx-es for Mexico
- nl-nl for Netherlands
- nz-en for New Zealand
- no-no for Norway
- pe-es for Peru
- ph-en for Philippines
- ph-tl for Philippines (tl)
- pl-pl for Poland
- pt-pt for Portugal
- ro-ro for Romania
- ru-ru for Russia
- sg-en for Singapore
- sk-sk for Slovak Republic
- sl-sl for Slovenia
- za-en for South Africa
- es-es for Spain
- se-sv for Sweden
- ch-de for Switzerland (de)
- ch-fr for Switzerland (fr)
- ch-it for Switzerland (it)
- tw-tzh for Taiwan
- th-th for Thailand
- tr-tr for Turkey
- ua-uk for Ukraine
- uk-en for United Kingdom
- us-en for United States
- ue-es for United States (es)
- ve-es for Venezuela
- vn-vi for Vietnam
- wt-wt for No region
Returns:
list: A list of search result dictionaries, each containing:
- 'title' (str): The title of the search result.
- 'href' (str): The URL of the search result.
- 'body' (str): A brief description or snippet from the search result.
Or a dict with 'error' key if an error occurred.
"""
brave_api_key = os.getenv("BRAVE_API_KEY")

if not brave_api_key:
return {"error": "No BRAVE_API_KEY environment variable found. Please set it to use this function."}

backoff = 2 # initial back-off in seconds

# Map region codes to Brave Search country codes (ISO 3166-1 alpha-2)
region_mapping = {
"xa-ar": "SA", "xa-en": "SA", "ar-es": "AR", "au-en": "AU", "at-de": "AT",
"be-fr": "BE", "be-nl": "BE", "br-pt": "BR", "bg-bg": "BG", "ca-en": "CA",
"ca-fr": "CA", "ct-ca": "ES", "cl-es": "CL", "cn-zh": "CN", "co-es": "CO",
"hr-hr": "HR", "cz-cs": "CZ", "dk-da": "DK", "ee-et": "EE", "fi-fi": "FI",
"fr-fr": "FR", "de-de": "DE", "gr-el": "GR", "hk-tzh": "HK", "hu-hu": "HU",
"in-en": "IN", "id-id": "ID", "id-en": "ID", "ie-en": "IE", "il-he": "IL",
"it-it": "IT", "jp-jp": "JP", "kr-kr": "KR", "lv-lv": "LV", "lt-lt": "LT",
"xl-es": "MX", "my-ms": "MY", "my-en": "MY", "mx-es": "MX", "nl-nl": "NL",
"nz-en": "NZ", "no-no": "NO", "pe-es": "PE", "ph-en": "PH", "ph-tl": "PH",
"pl-pl": "PL", "pt-pt": "PT", "ro-ro": "RO", "ru-ru": "RU", "sg-en": "SG",
"sk-sk": "SK", "sl-sl": "SI", "za-en": "ZA", "es-es": "ES", "se-sv": "SE",
"ch-de": "CH", "ch-fr": "CH", "ch-it": "CH", "tw-tzh": "TW", "th-th": "TH",
"tr-tr": "TR", "ua-uk": "UA", "uk-en": "GB", "us-en": "US", "ue-es": "US",
"ve-es": "VE", "vn-vi": "VN", "wt-wt": "ALL"
}

country = region_mapping.get(region, "ALL")

headers = {
"Accept": "application/json",
"Accept-Encoding": "gzip",
"X-Subscription-Token": brave_api_key
}

params = {
"q": keywords,
"count": max_results,
"search_lang": "en",
"country": country,
}

# Infinite retry loop with exponential backoff for rate limits
while True:
try:
response = requests.get(
"https://api.search.brave.com/res/v1/web/search",
headers=headers,
params=params,
timeout=30
)
response.raise_for_status()
break # Success
except requests.exceptions.HTTPError as e:
if response.status_code == 429:
wait_time = backoff + random.uniform(0, backoff)
print(f"⚠️ Rate limit hit (429). Retrying in {wait_time:.1f} seconds...")
time.sleep(wait_time)
backoff = min(backoff * 2, 120) # cap the back-off at 2 minutes
continue
else:
return {"error": f"HTTP error occurred: {str(e)}"}
except Exception as e:
return {"error": f"An error occurred: {str(e)}"}

try:
search_results = response.json()
except Exception as e:
return {"error": f"Failed to parse response JSON: {str(e)}"}

if "web" not in search_results or "results" not in search_results["web"]:
return {"error": "No results found in the response."}

web_results = search_results["web"]["results"]

# Convert the search results to the desired format
results = []
for result in web_results[:max_results]:
results.append({
"title": result.get("title", ""),
"href": result.get("url", ""),
"body": result.get("description", ""),
})

return results No newline at end of file
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The _search method has several critical issues that deviate from the class design and introduce bugs:

  1. Session Not Used: It does not use the self.session object for making requests, instead calling requests.get() directly. This negates the benefits of the connection pooling configured in __init__.
  2. URL Hardcoded: It uses a hardcoded URL to the Brave Search API and ignores self.search_url from __init__. This makes the search_url parameter misleading, especially since the example scripts configure it for a local retriever.

This method should be refactored to use self.session.get(self.search_url, ...) to align with the class's design and configuration.

@@ -0,0 +1,117 @@
from skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput, ConversationType
from typing import Any
from skyrl_gym.envs.search.utils import compute_score
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The compute_score function is being imported from skyrl_gym.envs.search.utils, but a new utils.py file with this function is also being added in the browse directory. To ensure the correct utility is used and to follow better module structure, you should import from the local utils.py file.

Suggested change
from skyrl_gym.envs.search.utils import compute_score
from .utils import compute_score

Comment on lines +86 to +117
try:
query = self._parse_action(action)
observation = self._execute_tool("BraveSearchToolGroup", "browse", query)
except Exception as e:
error = str(e)
observation = None

# Wrap the observation properly as a message
if observation:
new_obs = {"role": "user", "content": observation}
elif error:
# Give error as observation if any
new_obs = {"role": "user", "content": error}
else:
new_obs = None

info = {
"tool_group": "BraveSearchToolGroup",
"tool_name": "browse",
"tool_input": query,
}

# Update chat history
if new_obs:
self.chat_history.append(new_obs)

return BaseTextEnvStepOutput(
observations=[new_obs] if new_obs else [],
reward=reward,
done=done,
metadata=info,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of the step method has a bug where it will raise an exception if the model's action does not contain a <search> tag. This is because _parse_action returns [None], and _execute_tool is then called with None as input, which is not handled. The agent should be able to produce a response without calling a tool.

The logic should be updated to only execute a tool if a valid tool call is parsed from the action. I've also refactored to initialize info only when a tool is actually called.

        query = self._parse_action(action)
        observation = None
        info = {}

        if query[0] is not None:
            try:
                observation = self._execute_tool("BraveSearchToolGroup", "browse", query)
                info = {
                    "tool_group": "BraveSearchToolGroup",
                    "tool_name": "browse",
                    "tool_input": query,
                }
            except Exception as e:
                error = str(e)
                observation = None

        # Wrap the observation properly as a message
        if observation:
            new_obs = {"role": "user", "content": observation}
        elif error:
            # Give error as observation if any
            new_obs = {"role": "user", "content": error}
        else:
            new_obs = None

        # Update chat history
        if new_obs:
            self.chat_history.append(new_obs)

        return BaseTextEnvStepOutput(
            observations=[new_obs] if new_obs else [],
            reward=reward,
            done=done,
            metadata=info,
        )

Comment on lines +199 to +219
while True:
try:
response = requests.get(
"https://api.search.brave.com/res/v1/web/search",
headers=headers,
params=params,
timeout=30
)
response.raise_for_status()
break # Success
except requests.exceptions.HTTPError as e:
if response.status_code == 429:
wait_time = backoff + random.uniform(0, backoff)
print(f"⚠️ Rate limit hit (429). Retrying in {wait_time:.1f} seconds...")
time.sleep(wait_time)
backoff = min(backoff * 2, 120) # cap the back-off at 2 minutes
continue
else:
return {"error": f"HTTP error occurred: {str(e)}"}
except Exception as e:
return {"error": f"An error occurred: {str(e)}"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The retry logic in _search has several issues:

  1. Infinite Loop: The while True loop for retries is not bounded, which could lead to an infinite loop on repeated 429 errors. The MAX_RETRIES constant should be used to limit retries.
  2. Hardcoded Timeout: The request timeout is hardcoded to 30 instead of using self.timeout.
  3. Logging: print is used for logging rate limit warnings. It's better to use the configured logger (e.g., logger.warning(...)).
  4. Unused Constants: MAX_RETRIES and INITIAL_RETRY_DELAY are defined but not used.

The retry logic should be updated to use a bounded loop and use the class attributes for configuration.

return matches[-1].group(1).strip()


def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The method parameter is defined but not used within this function. It should be removed to avoid confusion.

Suggested change
def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0):
def compute_score(solution_str, ground_truth, format_score=0.0, score=1.0):

return format_score


def compute_score_subem(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The method parameter is defined but not used within this function. It should be removed to avoid confusion.

Suggested change
def compute_score_subem(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0):
def compute_score_subem(solution_str, ground_truth, format_score=0.0, score=1.0):

+environment.skyrl_gym.browse.timeout=30 \
trainer.logger="wandb" \
trainer.project_name="skyrl-browse" \
trainer.run_name="skyrl-browse_8turns_maxgeneratelen_500" \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's a mismatch between the run_name which specifies '8turns' and the generator.max_turns parameter which is set to 6 on line 47. To avoid confusion when tracking experiments, these should be consistent.

Suggested change
trainer.run_name="skyrl-browse_8turns_maxgeneratelen_500" \
trainer.run_name="skyrl-browse_6turns_maxgeneratelen_500" \

Comment on lines +266 to +268
print('self.cfg.environment.env_class',self.cfg.environment.env_class)
generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client)
print('GOT TO GENERATOR')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These print statements appear to be for debugging and should be removed before merging.

Suggested change
print('self.cfg.environment.env_class',self.cfg.environment.env_class)
generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client)
print('GOT TO GENERATOR')
generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants