-
Notifications
You must be signed in to change notification settings - Fork 215
pushing brave search multi turn rl code #628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a new BrowseEnv for multi-turn reinforcement learning using Brave Search, along with the necessary tool integrations and an example script. The overall structure is good, but there are several critical issues that need addressing. The BraveSearchToolGroup implementation incorrectly handles API requests by not using the configured session or URL, and its retry logic is flawed. The new BrowseEnv has a bug in its step method that can cause crashes when no tool is called. Additionally, there are some unused parameters and leftover debugging print statements that should be cleaned up. I've provided specific comments and suggestions to resolve these issues.
| def _search( | ||
| self, | ||
| keywords: str, | ||
| max_results: Optional[int] = 10, | ||
| region: Optional[str] = "wt-wt", | ||
| ) -> Union[List[Dict], Dict]: | ||
| """ | ||
| Queries the Brave Search API for the provided keywords and region. | ||
| Args: | ||
| keywords (str): The keywords to search for. | ||
| max_results (int, optional): The maximum number of search results to return. Defaults to 10. | ||
| region (str, optional): The region to search in. Defaults to "wt-wt". Possible values include: | ||
| - xa-ar for Arabia | ||
| - xa-en for Arabia (en) | ||
| - ar-es for Argentina | ||
| - au-en for Australia | ||
| - at-de for Austria | ||
| - be-fr for Belgium (fr) | ||
| - be-nl for Belgium (nl) | ||
| - br-pt for Brazil | ||
| - bg-bg for Bulgaria | ||
| - ca-en for Canada | ||
| - ca-fr for Canada (fr) | ||
| - ct-ca for Catalan | ||
| - cl-es for Chile | ||
| - cn-zh for China | ||
| - co-es for Colombia | ||
| - hr-hr for Croatia | ||
| - cz-cs for Czech Republic | ||
| - dk-da for Denmark | ||
| - ee-et for Estonia | ||
| - fi-fi for Finland | ||
| - fr-fr for France | ||
| - de-de for Germany | ||
| - gr-el for Greece | ||
| - hk-tzh for Hong Kong | ||
| - hu-hu for Hungary | ||
| - in-en for India | ||
| - id-id for Indonesia | ||
| - id-en for Indonesia (en) | ||
| - ie-en for Ireland | ||
| - il-he for Israel | ||
| - it-it for Italy | ||
| - jp-jp for Japan | ||
| - kr-kr for Korea | ||
| - lv-lv for Latvia | ||
| - lt-lt for Lithuania | ||
| - xl-es for Latin America | ||
| - my-ms for Malaysia | ||
| - my-en for Malaysia (en) | ||
| - mx-es for Mexico | ||
| - nl-nl for Netherlands | ||
| - nz-en for New Zealand | ||
| - no-no for Norway | ||
| - pe-es for Peru | ||
| - ph-en for Philippines | ||
| - ph-tl for Philippines (tl) | ||
| - pl-pl for Poland | ||
| - pt-pt for Portugal | ||
| - ro-ro for Romania | ||
| - ru-ru for Russia | ||
| - sg-en for Singapore | ||
| - sk-sk for Slovak Republic | ||
| - sl-sl for Slovenia | ||
| - za-en for South Africa | ||
| - es-es for Spain | ||
| - se-sv for Sweden | ||
| - ch-de for Switzerland (de) | ||
| - ch-fr for Switzerland (fr) | ||
| - ch-it for Switzerland (it) | ||
| - tw-tzh for Taiwan | ||
| - th-th for Thailand | ||
| - tr-tr for Turkey | ||
| - ua-uk for Ukraine | ||
| - uk-en for United Kingdom | ||
| - us-en for United States | ||
| - ue-es for United States (es) | ||
| - ve-es for Venezuela | ||
| - vn-vi for Vietnam | ||
| - wt-wt for No region | ||
| Returns: | ||
| list: A list of search result dictionaries, each containing: | ||
| - 'title' (str): The title of the search result. | ||
| - 'href' (str): The URL of the search result. | ||
| - 'body' (str): A brief description or snippet from the search result. | ||
| Or a dict with 'error' key if an error occurred. | ||
| """ | ||
| brave_api_key = os.getenv("BRAVE_API_KEY") | ||
|
|
||
| if not brave_api_key: | ||
| return {"error": "No BRAVE_API_KEY environment variable found. Please set it to use this function."} | ||
|
|
||
| backoff = 2 # initial back-off in seconds | ||
|
|
||
| # Map region codes to Brave Search country codes (ISO 3166-1 alpha-2) | ||
| region_mapping = { | ||
| "xa-ar": "SA", "xa-en": "SA", "ar-es": "AR", "au-en": "AU", "at-de": "AT", | ||
| "be-fr": "BE", "be-nl": "BE", "br-pt": "BR", "bg-bg": "BG", "ca-en": "CA", | ||
| "ca-fr": "CA", "ct-ca": "ES", "cl-es": "CL", "cn-zh": "CN", "co-es": "CO", | ||
| "hr-hr": "HR", "cz-cs": "CZ", "dk-da": "DK", "ee-et": "EE", "fi-fi": "FI", | ||
| "fr-fr": "FR", "de-de": "DE", "gr-el": "GR", "hk-tzh": "HK", "hu-hu": "HU", | ||
| "in-en": "IN", "id-id": "ID", "id-en": "ID", "ie-en": "IE", "il-he": "IL", | ||
| "it-it": "IT", "jp-jp": "JP", "kr-kr": "KR", "lv-lv": "LV", "lt-lt": "LT", | ||
| "xl-es": "MX", "my-ms": "MY", "my-en": "MY", "mx-es": "MX", "nl-nl": "NL", | ||
| "nz-en": "NZ", "no-no": "NO", "pe-es": "PE", "ph-en": "PH", "ph-tl": "PH", | ||
| "pl-pl": "PL", "pt-pt": "PT", "ro-ro": "RO", "ru-ru": "RU", "sg-en": "SG", | ||
| "sk-sk": "SK", "sl-sl": "SI", "za-en": "ZA", "es-es": "ES", "se-sv": "SE", | ||
| "ch-de": "CH", "ch-fr": "CH", "ch-it": "CH", "tw-tzh": "TW", "th-th": "TH", | ||
| "tr-tr": "TR", "ua-uk": "UA", "uk-en": "GB", "us-en": "US", "ue-es": "US", | ||
| "ve-es": "VE", "vn-vi": "VN", "wt-wt": "ALL" | ||
| } | ||
|
|
||
| country = region_mapping.get(region, "ALL") | ||
|
|
||
| headers = { | ||
| "Accept": "application/json", | ||
| "Accept-Encoding": "gzip", | ||
| "X-Subscription-Token": brave_api_key | ||
| } | ||
|
|
||
| params = { | ||
| "q": keywords, | ||
| "count": max_results, | ||
| "search_lang": "en", | ||
| "country": country, | ||
| } | ||
|
|
||
| # Infinite retry loop with exponential backoff for rate limits | ||
| while True: | ||
| try: | ||
| response = requests.get( | ||
| "https://api.search.brave.com/res/v1/web/search", | ||
| headers=headers, | ||
| params=params, | ||
| timeout=30 | ||
| ) | ||
| response.raise_for_status() | ||
| break # Success | ||
| except requests.exceptions.HTTPError as e: | ||
| if response.status_code == 429: | ||
| wait_time = backoff + random.uniform(0, backoff) | ||
| print(f"⚠️ Rate limit hit (429). Retrying in {wait_time:.1f} seconds...") | ||
| time.sleep(wait_time) | ||
| backoff = min(backoff * 2, 120) # cap the back-off at 2 minutes | ||
| continue | ||
| else: | ||
| return {"error": f"HTTP error occurred: {str(e)}"} | ||
| except Exception as e: | ||
| return {"error": f"An error occurred: {str(e)}"} | ||
|
|
||
| try: | ||
| search_results = response.json() | ||
| except Exception as e: | ||
| return {"error": f"Failed to parse response JSON: {str(e)}"} | ||
|
|
||
| if "web" not in search_results or "results" not in search_results["web"]: | ||
| return {"error": "No results found in the response."} | ||
|
|
||
| web_results = search_results["web"]["results"] | ||
|
|
||
| # Convert the search results to the desired format | ||
| results = [] | ||
| for result in web_results[:max_results]: | ||
| results.append({ | ||
| "title": result.get("title", ""), | ||
| "href": result.get("url", ""), | ||
| "body": result.get("description", ""), | ||
| }) | ||
|
|
||
| return results No newline at end of file |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The _search method has several critical issues that deviate from the class design and introduce bugs:
- Session Not Used: It does not use the
self.sessionobject for making requests, instead callingrequests.get()directly. This negates the benefits of the connection pooling configured in__init__. - URL Hardcoded: It uses a hardcoded URL to the Brave Search API and ignores
self.search_urlfrom__init__. This makes thesearch_urlparameter misleading, especially since the example scripts configure it for a local retriever.
This method should be refactored to use self.session.get(self.search_url, ...) to align with the class's design and configuration.
| @@ -0,0 +1,117 @@ | |||
| from skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput, ConversationType | |||
| from typing import Any | |||
| from skyrl_gym.envs.search.utils import compute_score | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The compute_score function is being imported from skyrl_gym.envs.search.utils, but a new utils.py file with this function is also being added in the browse directory. To ensure the correct utility is used and to follow better module structure, you should import from the local utils.py file.
| from skyrl_gym.envs.search.utils import compute_score | |
| from .utils import compute_score |
| try: | ||
| query = self._parse_action(action) | ||
| observation = self._execute_tool("BraveSearchToolGroup", "browse", query) | ||
| except Exception as e: | ||
| error = str(e) | ||
| observation = None | ||
|
|
||
| # Wrap the observation properly as a message | ||
| if observation: | ||
| new_obs = {"role": "user", "content": observation} | ||
| elif error: | ||
| # Give error as observation if any | ||
| new_obs = {"role": "user", "content": error} | ||
| else: | ||
| new_obs = None | ||
|
|
||
| info = { | ||
| "tool_group": "BraveSearchToolGroup", | ||
| "tool_name": "browse", | ||
| "tool_input": query, | ||
| } | ||
|
|
||
| # Update chat history | ||
| if new_obs: | ||
| self.chat_history.append(new_obs) | ||
|
|
||
| return BaseTextEnvStepOutput( | ||
| observations=[new_obs] if new_obs else [], | ||
| reward=reward, | ||
| done=done, | ||
| metadata=info, | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation of the step method has a bug where it will raise an exception if the model's action does not contain a <search> tag. This is because _parse_action returns [None], and _execute_tool is then called with None as input, which is not handled. The agent should be able to produce a response without calling a tool.
The logic should be updated to only execute a tool if a valid tool call is parsed from the action. I've also refactored to initialize info only when a tool is actually called.
query = self._parse_action(action)
observation = None
info = {}
if query[0] is not None:
try:
observation = self._execute_tool("BraveSearchToolGroup", "browse", query)
info = {
"tool_group": "BraveSearchToolGroup",
"tool_name": "browse",
"tool_input": query,
}
except Exception as e:
error = str(e)
observation = None
# Wrap the observation properly as a message
if observation:
new_obs = {"role": "user", "content": observation}
elif error:
# Give error as observation if any
new_obs = {"role": "user", "content": error}
else:
new_obs = None
# Update chat history
if new_obs:
self.chat_history.append(new_obs)
return BaseTextEnvStepOutput(
observations=[new_obs] if new_obs else [],
reward=reward,
done=done,
metadata=info,
)| while True: | ||
| try: | ||
| response = requests.get( | ||
| "https://api.search.brave.com/res/v1/web/search", | ||
| headers=headers, | ||
| params=params, | ||
| timeout=30 | ||
| ) | ||
| response.raise_for_status() | ||
| break # Success | ||
| except requests.exceptions.HTTPError as e: | ||
| if response.status_code == 429: | ||
| wait_time = backoff + random.uniform(0, backoff) | ||
| print(f"⚠️ Rate limit hit (429). Retrying in {wait_time:.1f} seconds...") | ||
| time.sleep(wait_time) | ||
| backoff = min(backoff * 2, 120) # cap the back-off at 2 minutes | ||
| continue | ||
| else: | ||
| return {"error": f"HTTP error occurred: {str(e)}"} | ||
| except Exception as e: | ||
| return {"error": f"An error occurred: {str(e)}"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The retry logic in _search has several issues:
- Infinite Loop: The
while Trueloop for retries is not bounded, which could lead to an infinite loop on repeated 429 errors. TheMAX_RETRIESconstant should be used to limit retries. - Hardcoded Timeout: The request timeout is hardcoded to
30instead of usingself.timeout. - Logging:
printis used for logging rate limit warnings. It's better to use the configuredlogger(e.g.,logger.warning(...)). - Unused Constants:
MAX_RETRIESandINITIAL_RETRY_DELAYare defined but not used.
The retry logic should be updated to use a bounded loop and use the class attributes for configuration.
| return matches[-1].group(1).strip() | ||
|
|
||
|
|
||
| def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| return format_score | ||
|
|
||
|
|
||
| def compute_score_subem(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The method parameter is defined but not used within this function. It should be removed to avoid confusion.
| def compute_score_subem(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0): | |
| def compute_score_subem(solution_str, ground_truth, format_score=0.0, score=1.0): |
| +environment.skyrl_gym.browse.timeout=30 \ | ||
| trainer.logger="wandb" \ | ||
| trainer.project_name="skyrl-browse" \ | ||
| trainer.run_name="skyrl-browse_8turns_maxgeneratelen_500" \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a mismatch between the run_name which specifies '8turns' and the generator.max_turns parameter which is set to 6 on line 47. To avoid confusion when tracking experiments, these should be consistent.
| trainer.run_name="skyrl-browse_8turns_maxgeneratelen_500" \ | |
| trainer.run_name="skyrl-browse_6turns_maxgeneratelen_500" \ |
| print('self.cfg.environment.env_class',self.cfg.environment.env_class) | ||
| generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client) | ||
| print('GOT TO GENERATOR') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These print statements appear to be for debugging and should be removed before merging.
| print('self.cfg.environment.env_class',self.cfg.environment.env_class) | |
| generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client) | |
| print('GOT TO GENERATOR') | |
| generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client) |
No description provided.