Multi-turn native tool use RL #643
base: main
Conversation
Code Review
This pull request introduces native function calling for multi-turn conversations, which is a significant enhancement. The implementation adds a new 'browse' example with a Brave Search tool, environment, and tool parsers. The changes are well structured, adding new components for the browsing task and modifying the existing generator and inference client to plumb through tool specifications. My review includes suggestions to remove some dead code and debug statements, fix a bug in API key handling, and optimize a loop in the tool parser. Overall, this is a solid contribution.
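For context on what "plumbing through tool specifications" means here: in OpenAI-style native tool use, each tool is described to the inference engine as a JSON schema passed alongside the chat messages. A minimal sketch; the field values are illustrative, not the PR's exact schema:

    # Illustrative tool spec for a web-search tool, in the common OpenAI-style format.
    brave_search_tool = {
        "type": "function",
        "function": {
            "name": "search",
            "description": "Query the Brave Search API and return the top results.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "The search query"},
                    "region": {"type": "string", "description": "Optional region code, e.g. 'US'"},
                },
                "required": ["query"],
            },
        },
    }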
brave_api_key = os.getenv("BRAVE_API_KEY")

if not brave_api_key:
    return {"error": "No BRAVE_API_KEY environment variable found. Please set it to use this function."}
The search method re-fetches the API key from the environment variable, ignoring self.api_key which was set in the constructor. This is redundant and can lead to incorrect behavior if the API key was provided via the constructor. You should use self.api_key throughout this method. Also, remember to update its usage on line 212.
Suggested change:
- brave_api_key = os.getenv("BRAVE_API_KEY")
- if not brave_api_key:
-     return {"error": "No BRAVE_API_KEY environment variable found. Please set it to use this function."}
+ if not self.api_key:
+     return {"error": "No BRAVE_API_KEY environment variable found. Please set it or pass the api_key to the constructor."}
country = region_mapping.get(region, "ALL")

headers = {"Accept": "application/json", "Accept-Encoding": "gzip", "X-Subscription-Token": brave_api_key}
To be consistent with using self.api_key from the constructor, this line should also be updated to use self.api_key instead of the local brave_api_key variable.
| headers = {"Accept": "application/json", "Accept-Encoding": "gzip", "X-Subscription-Token": brave_api_key} | |
| headers = {"Accept": "application/json", "Accept-Encoding": "gzip", "X-Subscription-Token": self.api_key} |
else:
    tool_calls = []
    for tool_call in result:
        try:
            evaled_tool_dict = json.loads(tool_call)
        except Exception:
            try:
                evaled_tool_dict = ast.literal_eval(tool_call)
            except Exception:
                evaled_tool_dict = None

        if evaled_tool_dict and evaled_tool_dict.get("name") and evaled_tool_dict.get("arguments"):
            tool_calls.append(evaled_tool_dict)

        content = model_output[: model_output.find(self.tool_call_start_token)]

    return {"tools_called": bool(tool_calls), "tool_calls": tool_calls, "content": content}
The calculation of content on line 104 is inside the for loop, which means it's being re-calculated on every iteration with the same result. This is inefficient. You should move this calculation outside and before the loop.
Suggested change:
  else:
      tool_calls = []
+     content = model_output[: model_output.find(self.tool_call_start_token)]
      for tool_call in result:
          try:
              evaled_tool_dict = json.loads(tool_call)
          except Exception:
              try:
                  evaled_tool_dict = ast.literal_eval(tool_call)
              except Exception:
                  evaled_tool_dict = None
          if evaled_tool_dict and evaled_tool_dict.get("name") and evaled_tool_dict.get("arguments"):
              tool_calls.append(evaled_tool_dict)
-         content = model_output[: model_output.find(self.tool_call_start_token)]
      return {"tools_called": bool(tool_calls), "tool_calls": tool_calls, "content": content}
| print("self.cfg.environment.env_class", self.cfg.environment.env_class) | ||
| generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client) | ||
| print("GOT TO GENERATOR") |
These print statements appear to be for debugging purposes. They should be removed from the final code to keep the output clean. If this information is valuable, consider using the logger with an appropriate level (e.g., logger.debug).
| print("self.cfg.environment.env_class", self.cfg.environment.env_class) | |
| generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client) | |
| print("GOT TO GENERATOR") | |
| generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client) |
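If the information is worth keeping, a sketch of the logger-based alternative suggested above, reusing the names from the snippet and assuming a module-level logger; not the PR's actual code:

    import logging

    logger = logging.getLogger(__name__)

    # Emitted only when the log level is DEBUG, so normal runs stay quiet.
    logger.debug("env_class=%s", self.cfg.environment.env_class)
    generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client)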
def _parse_action(self, action: str) -> List[Optional[str]]:
    match = None
    if "<search>" in action and "</search>" in action:
        match = re.search(r"<search>(.*?)</search>", action, re.DOTALL)
    return [match.group(1)] if match else [None]
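For context, a standalone rendering of the same logic with example inputs and outputs; the free-function name and examples are illustrative:

    import re
    from typing import List, Optional


    def parse_action(action: str) -> List[Optional[str]]:
        # Extract the query between <search> tags, if both tags are present.
        match = None
        if "<search>" in action and "</search>" in action:
            match = re.search(r"<search>(.*?)</search>", action, re.DOTALL)
        return [match.group(1)] if match else [None]


    print(parse_action("Let me check. <search>multi-turn tool use RL</search>"))  # ['multi-turn tool use RL']
    print(parse_action("no search tags here"))  # [None]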
trainer.eval_before_train=false \
generator.eval_sampling_params.temperature=0 \
generator.eval_sampling_params.stop='["</search>", "</answer>"]' \
trainer.export_path="$HOME/skyrl-search_6turns_maxgeneratelen_1536_Qwen3-8B/exports" \
There appears to be a typo in the trainer.export_path. It's set to a path containing skyrl-search, but other related paths and names use skyrl-browse. For consistency, this should probably be skyrl-browse.
Suggested change:
- trainer.export_path="$HOME/skyrl-search_6turns_maxgeneratelen_1536_Qwen3-8B/exports" \
+ trainer.export_path="$HOME/skyrl-browse_6turns_maxgeneratelen_1536_Qwen3-8B/exports" \
trainer.export_path="$HOME/skyrl-search_6turns_maxgeneratelen_1536_Qwen3-8B/exports" \
trainer.eval_interval=50 \
$@
(no newline at end of file)
trainer.eval_before_train=false \
generator.eval_sampling_params.temperature=0 \
generator.eval_sampling_params.stop='["</search>", "</answer>"]' \
trainer.export_path="$HOME/skyrl-search_6turns_maxgeneratelen_1536_Qwen3-8B_lora/exports" \
There appears to be a typo in the trainer.export_path. It's set to a path containing skyrl-search, but other related paths and names use skyrl-browse. For consistency, this should probably be skyrl-browse.
Suggested change:
- trainer.export_path="$HOME/skyrl-search_6turns_maxgeneratelen_1536_Qwen3-8B_lora/exports" \
+ trainer.export_path="$HOME/skyrl-browse_6turns_maxgeneratelen_1536_Qwen3-8B_lora/exports" \
trainer.export_path="$HOME/skyrl-search_6turns_maxgeneratelen_1536_Qwen3-8B_lora/exports" \
trainer.eval_interval=50 \
$@
(no newline at end of file)
# Close the environment
await self._run_in_executor_if_available(env.close)

# decoded_text = self.tokenizer.decode(input_ids)
This implements native function calling with true multi-turn interaction using the tool role. Parsing logic is implemented to parse the tool calls out of the model output, execute them, and feed the results back to the model in its native format.
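For readers unfamiliar with the flow, a minimal sketch of what feeding a tool result back "in its native format" typically looks like with OpenAI-style chat messages; the message contents and tool name are illustrative, not taken from this PR:

    import json

    # Assistant turn containing a parsed tool call.
    assistant_msg = {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "call_0",
                "type": "function",
                "function": {"name": "search", "arguments": json.dumps({"query": "multi-turn tool use RL"})},
            }
        ],
    }

    # The tool's output is appended with the tool role, referencing the call id,
    # and the extended message list is sent back for the next generation turn.
    tool_msg = {
        "role": "tool",
        "tool_call_id": "call_0",
        "content": json.dumps({"results": ["..."]}),
    }

    messages = [
        {"role": "user", "content": "Find recent work on multi-turn tool-use RL."},
        assistant_msg,
        tool_msg,
    ]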