Fix Proxy Mode with extra_body support, enforcing max_tokens, switching to individual mode, and fix prompt template and tool call parsing #775
base: dhh/tau2-example
Conversation
…est, (2) enforcing 'max_tokens' on proxy server side, (3) switch to 'individual' mode in tool call scenarios, (4) fix prompt template and tool call parsing in completions.create function
Summary of Changes

Hello @samjia2000, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the proxy mode's capabilities by introducing support for 'extra_body' in completion requests and enforcing 'max_tokens' at the proxy level. It also refines the client's prompt templating and tool call parsing logic, ensuring more robust and flexible interactions with language models. Furthermore, the system now defaults to an 'individual' export mode for tool call scenarios and includes several minor bug fixes and configuration adjustments, particularly within the Tau2 example, to improve overall stability and control.
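For context, this is roughly what passing extra fields through an OpenAI-compatible client looks like from the caller's side. The sketch below uses the standard `openai` SDK's `extra_body` parameter; the endpoint URL, model name, and the specific sampling fields shown are illustrative assumptions, not values taken from this PR.

```python
from openai import OpenAI

# Point the standard OpenAI client at an OpenAI-compatible proxy endpoint
# (URL, API key, and model name below are placeholders for illustration).
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Book a flight to SFO."}],
    max_tokens=256,
    # Fields the OpenAI SDK does not model natively are forwarded verbatim
    # in the request body via extra_body; the proxy can then parse them out.
    extra_body={"top_k": 20, "repetition_penalty": 1.05},
)
print(response.choices[0].message.content)
```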
Code Review
This pull request introduces several important fixes and features, including support for extra_body in proxy mode, enforcement of max_tokens, and changes to prompt templating and tool call parsing. The changes generally improve the robustness and functionality of the system.
My review focuses on a few key areas:
- A critical bug in `areal/experimental/openai/cache.py` where the filtering logic doesn't update the cache object correctly for the `'concat'` export style.
- An opportunity to improve code clarity in `areal/experimental/openai/client.py` by removing an unused parameter from the new `_ensure_dict` function.
- A suggestion to improve the robustness of request parsing in `areal/utils/proxy_utils.py` by using the recommended Pydantic API.
- A minor issue with a weak assertion in the example `examples/tau2/tau2_train.py`.
Overall, the PR is a good step forward. Addressing these points will further improve the code quality and prevent potential bugs.
```diff
 for id, interaction in self.items():
     if interaction.interaction_id != id:
-        raise ValueError(
-            f"Interaction ID mismatch: {interaction.interaction_id} != {id}"
+        logger.warning(
+            f"Interaction ID mismatch: {interaction.interaction_id} != {id}. It is possibly due to generation failure during trajectory generation."
         )
+cache = {
+    id: interaction
+    for id, interaction in cache.items()
+    if id == interaction.interaction_id
+}
```
The change from raise ValueError to logger.warning improves robustness. However, the current filtering implementation has a flaw. The cache variable is rebound to a new dict, but self (the InteractionCache instance) is not modified. The subsequent code for style == 'concat' still uses self, which contains the unfiltered interactions, leading to incorrect behavior. For style == 'individual', it correctly uses the filtered cache.
To fix this, you should modify self in-place. The suggested change also makes the filtering more efficient by using a single pass.
```python
for id, interaction in list(self.items()):
    if interaction.interaction_id != id:
        logger.warning(
            f"Interaction ID mismatch: {interaction.interaction_id} != {id}. It is possibly due to generation failure during trajectory generation. Removing it."
        )
        del self[id]
cache = self
```

```python
def _ensure_dict(
    name: str,
    item: Any,
) -> Any:
    _item = None
    if isinstance(item, dict):
        _item = {k: _ensure_dict(name, v) for k, v in item.items() if v is not None}
    elif isinstance(item, BaseModel):
        _item = item.model_dump(exclude_none=True, mode="json")
    elif type(item).__name__ == "ValidatorIterator" or isinstance(item, list):
        _item = [_ensure_dict(name, i) for i in item]
    else:
        _item = item
    return _item
```
The name parameter in _ensure_dict is not used. It's passed in recursive calls but its value is never read. It should be removed to clean up the code. The call sites at lines 209, 236, and 518 in this file should be updated accordingly to _ensure_dict(messages_list_raw), _ensure_dict(tools), and _ensure_dict(input) respectively.
```diff
 def _ensure_dict(
-    name: str,
     item: Any,
 ) -> Any:
     _item = None
     if isinstance(item, dict):
-        _item = {k: _ensure_dict(name, v) for k, v in item.items() if v is not None}
+        _item = {k: _ensure_dict(v) for k, v in item.items() if v is not None}
     elif isinstance(item, BaseModel):
         _item = item.model_dump(exclude_none=True, mode="json")
     elif type(item).__name__ == "ValidatorIterator" or isinstance(item, list):
-        _item = [_ensure_dict(name, i) for i in item]
+        _item = [_ensure_dict(i) for i in item]
     else:
         _item = item
     return _item
```
```python
known_fields = {k: v for k, v in data.items() if k in cls.__annotations__.keys()}

# Extract extra fields
extra_body = {k: v for k, v in data.items() if k not in cls.__annotations__.keys()}
```
Using cls.__annotations__.keys() to get the fields of a Pydantic model is not robust. It might not work correctly with field aliases, inherited fields, or other advanced Pydantic features. The recommended way to get a model's fields in Pydantic v2+ is to use cls.model_fields.
```diff
-known_fields = {k: v for k, v in data.items() if k in cls.__annotations__.keys()}
+known_fields = {k: v for k, v in data.items() if k in cls.model_fields}
 # Extract extra fields
-extra_body = {k: v for k, v in data.items() if k not in cls.__annotations__.keys()}
+extra_body = {k: v for k, v in data.items() if k not in cls.model_fields}
```
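To illustrate the reviewer's point, here is a minimal, self-contained sketch (the `Base`/`ChatRequest` models and their fields are invented for demonstration) showing that `__annotations__` only sees fields declared directly on the class, while `model_fields` also covers inherited fields, which is what the known/extra split above actually needs:

```python
from pydantic import BaseModel

class Base(BaseModel):
    temperature: float = 1.0

class ChatRequest(Base):
    max_tokens: int = 16

# __annotations__ misses the inherited 'temperature' field.
print(ChatRequest.__annotations__.keys())  # dict_keys(['max_tokens'])

# model_fields includes inherited fields as well.
print(ChatRequest.model_fields.keys())     # dict_keys(['temperature', 'max_tokens'])
```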
```python
assert len(run_infos) == len(rewards), (
    len(run_infos),
    len(rewards),
    self.group_size,
)
```
The assertion len(run_infos) == len(rewards) is likely always true, as both run_infos and rewards are derived from self.group_size. This doesn't effectively check if the number of results matches the expected group size. A more explicit assertion against self.group_size with a descriptive message would be more useful.
```python
assert len(run_infos) == self.group_size and len(rewards) == self.group_size, (
    f"Expected {self.group_size} run_infos and rewards, but got {len(run_infos)} and {len(rewards)}"
)
```
nuzant left a comment
```python
logger.warning(
    f"Interaction ID mismatch: {interaction.interaction_id} != {id}. It is possibly due to generation failure during trajectory generation."
```
Is this only happening in individual mode? It seems this should never happen. If generation fails, the interaction should not be added into the cache.
> Is this only happening in individual mode? It seems this should never happen. If generation fails, the interaction should not be added into the cache.
Generation could fail when enforcing max_tokens on the proxy server side.
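As a rough illustration of how a proxy-side cap like this can surface as a failed generation (the request shape, function name, and cap value below are assumptions for the sketch, not the PR's actual implementation):

```python
# Hypothetical sketch: the proxy clamps the client's max_tokens to a
# server-side budget and rejects requests whose budget is already exhausted.
SERVER_MAX_TOKENS = 4096  # assumed server-side cap, not from this PR

def enforce_max_tokens(request: dict, prompt_len: int) -> dict:
    budget = SERVER_MAX_TOKENS - prompt_len
    if budget <= 0:
        # The client sees a failed generation, so no interaction is produced
        # for this turn -- which is how an ID mismatch can show up in the cache.
        raise ValueError("prompt already exceeds the server-side token budget")
    request["max_tokens"] = min(request.get("max_tokens", budget), budget)
    return request
```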
```python
def _ensure_dict(
    name: str,
    item: Any,
) -> Any:
    _item = None
    if isinstance(item, dict):
        _item = {k: _ensure_dict(name, v) for k, v in item.items() if v is not None}
    elif isinstance(item, BaseModel):
        _item = item.model_dump(exclude_none=True, mode="json")
    elif type(item).__name__ == "ValidatorIterator" or isinstance(item, list):
        _item = [_ensure_dict(name, i) for i in item]
    else:
        _item = item
```
We could rewrite a new ensure_input_type function and remove _ensure_message_dict_list.
> We could rewrite a new `ensure_input_type` function and remove `_ensure_message_dict_list`.
Agree
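For discussion, a rough sketch of what a combined `ensure_input_type` helper might look like, mirroring the existing `_ensure_dict` logic quoted above; the exact name, signature, and placement are assumptions rather than an agreed design:

```python
from typing import Any
from pydantic import BaseModel

def ensure_input_type(item: Any) -> Any:
    """Recursively normalize SDK inputs (Pydantic models, validator iterators,
    nested containers) into plain dicts/lists that can be serialized for the proxy."""
    if isinstance(item, dict):
        return {k: ensure_input_type(v) for k, v in item.items() if v is not None}
    if isinstance(item, BaseModel):
        return item.model_dump(exclude_none=True, mode="json")
    if type(item).__name__ == "ValidatorIterator" or isinstance(item, list):
        return [ensure_input_type(i) for i in item]
    return item
```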
Description
Fix proxy mode with extra_body support, enforcing max_tokens on the proxy server side, switching to 'individual' mode in tool call scenarios, and fixing prompt template and tool call parsing in the completions.create function.
Related Issue
Fixes #(issue)
Additional Context
Training is validated on the Tau2 Airline domain.