Claude/fix raft function registration 01 maq5c t7dvc g wg h3e n shz tu #387
Conversation
    # Conflicts:
    #   benchmark/config/countdown-template.yaml
    #   docs/sphinx_doc/source/tutorial/example_step_wise.md
    #   docs/sphinx_doc/source_zh/tutorial/example_step_wise.md
    #   examples/agentscope_react/gsm8k.yaml
    #   examples/agentscope_tool_react/agentscopev0_tool_react_dapo.yaml
    #   examples/agentscope_tool_react/agentscopev0_tool_react_gsm8k.yaml
    #   examples/agentscope_tool_react/agentscopev1_tool_react_dapo.yaml
    #   examples/agentscope_websearch/agentscopev1_websearch_agent.yaml
## Summary
- Import R3L algorithm implementation from featureA branch
- Create backup directory (R3L-back) for safety
- Fix critical syntax error in Countdown environment
## Changes
### 1. Imported from featureA
- examples/R3L/ - All R3L configuration files for 5 environments
- trinity/common/workflows/envs/R3L/ - Complete R3L workflow implementations
### 2. Backup
- trinity/common/workflows/envs/R3L-back/ - Full backup before modifications
### 3. Countdown Syntax Fix
**File**: trinity/common/workflows/envs/R3L/countdown/utils.py
- Fixed f-string syntax error in format_countdown_prompt_with_guidance()
- Error: f-string expression part cannot include a backslash
- Solution: Extract the split operation into a variable before the f-string
### 4. Environment Status
- Alfworld: ✓ Complete (reference implementation)
- Countdown: ✓ Fixed syntax error
- DAPO: ✓ No changes needed (math_verify import is correct)
- ScienceWorld: ✓ No changes needed
- WebShop: ✓ No changes needed
## Testing
All Python files pass syntax validation:
✓ All R3L_workflow.py files
✓ All utils.py files
See R3L_Fix_Summary.md for detailed documentation.
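The backslash limitation behind the Countdown fix can be sketched as follows. The variable names are illustrative, not the actual contents of utils.py; Python versions before 3.12 reject a backslash inside an f-string expression:

```python
numbers = ["3", "7", "25"]

# Before (SyntaxError on Python < 3.12), because the expression
# part of the f-string contains a backslash:
#   prompt = f"Numbers: {'\n'.join(numbers)}"

# After: extract the operation into a variable, then interpolate it.
joined = "\n".join(numbers)
prompt = f"Numbers: {joined}"
```

The same pattern applies to any `\`-containing expression, such as a `split("\n")` call, which is why extracting to a variable fixes the reported error.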
## Summary
Unified all 4 environments (countdown, dapo, scienceworld, webshop) to use the same reflection schema as Alfworld (the reference implementation).
## Changes
### Reflection Prompts (prompts/reflection.j2)
All environments now use the simplified Alfworld JSON schema:
- trajectory_summary: Concise overview (1-3 sentences)
- root_cause_analysis: Deep causal analysis using 'why' questioning
- trajectory_outcome: success | success_but_inefficient | failure
- improvement_suggestion: Generalizable, context-complete principle
- retry_from_step: Integer 0 to N-1 (top-level field)
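A report conforming to the unified schema might look like the following. All field values here are hypothetical, written only to show the shape the validators expect:

```python
import json

# Hypothetical reflection report following the unified Alfworld schema.
report = {
    "trajectory_summary": "The agent searched three drawers but never checked the countertop.",
    "root_cause_analysis": "Why did the search fail? The agent assumed the mug was stored away, "
                           "so it never inspected open surfaces.",
    "trajectory_outcome": "failure",
    "improvement_suggestion": "When looking for small household items, inspect open surfaces "
                              "before searching closed containers.",
    "retry_from_step": 2,  # top-level integer in [0, N-1]
}
print(json.dumps(report, indent=2, ensure_ascii=False))
```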
### Validation Functions (utils.py::validate_reflect_report)
Updated to validate the Alfworld schema:
- Check for top-level fields: trajectory_summary, root_cause_analysis, trajectory_outcome
- Validate improvement_suggestion and retry_from_step for non-success outcomes
- Consistent validation logic across all environments
- Environment-specific log prefixes for debugging
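The shared validation logic can be sketched roughly as below. This is an illustrative reconstruction from the bullet points above, not the actual code in any environment's utils.py; the log prefix and error strings are assumptions:

```python
from typing import Any, Dict, Tuple

VALID_OUTCOMES = {"success", "success_but_inefficient", "failure"}

def validate_reflect_report(report: Dict[str, Any], total_steps: int) -> Tuple[bool, str]:
    """Sketch of the unified validator; real code uses an env-specific log prefix."""
    prefix = "[R3L Validation]"
    # Top-level fields required for every outcome
    for field in ("trajectory_summary", "root_cause_analysis", "trajectory_outcome"):
        if field not in report:
            return False, f"{prefix} missing field: {field}"
    outcome = report["trajectory_outcome"]
    if outcome not in VALID_OUTCOMES:
        return False, f"{prefix} invalid trajectory_outcome: {outcome}"
    if outcome != "success":
        # Non-success outcomes must also carry a suggestion and a valid retry point
        if "improvement_suggestion" not in report or "retry_from_step" not in report:
            return False, f"{prefix} missing improvement_suggestion or retry_from_step"
        step = report["retry_from_step"]
        if not isinstance(step, int) or not (0 <= step < total_steps):
            return False, f"{prefix} retry_from_step out of range: {step}"
    return True, ""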
### Workflow Integration (R3L_workflow.py)
Updated retry_step extraction:
- Old: reflect_checklist["analysis"]["retry_strategy"]["retry_step"]
- New: reflect_checklist.get("retry_from_step", 0)
- Top-level field access aligns with Alfworld schema
### Environment-Specific Examples
Each reflection.j2 includes domain-appropriate examples:
- Countdown: Number equation problems
- DAPO: Mathematical problem solving
- ScienceWorld: Interactive science experiments
- WebShop: E-commerce navigation
## Testing
All environments pass Python syntax validation:
✓ alfworld (reference)
✓ countdown
✓ dapo
✓ scienceworld
✓ webshop
## Impact
- Consistent reflection format across all R3L environments
- Simplified schema reduces complexity
- Easier to maintain and extend
- Aligns with the carefully tuned Alfworld implementation
…unction
## Summary
Completed the R3L environment unification by standardizing self_correction prompts and the reflect_report_to_guidance_prompt function across all environments to match Alfworld.
## Changes
### 1. Self-Correction Prompts (self_correction.j2)
Unified all 4 environments to use Alfworld's concise template:
**Before** (3 different versions):
- Countdown/DAPO: Detailed instructional format with bullet points
- ScienceWorld/WebShop: "Internal Monologue Directive" format
- Alfworld: Simple, direct format
**After** (All environments now use):
```
Your previous attempt encountered issues. Below is a reflection based on user and environment feedback:
{{ report }}
Apply the lessons learned from this reflection to avoid repeating the same mistakes. Do not mention or reference this guidance in your response.
```
### 2. DAPO reflect_report_to_guidance_prompt Function
Updated DAPO's complex implementation to match the simple template-based approach used by other environments:
**Before** (DAPO only):
- Manually constructed guidance from nested schema fields
- Tried to extract from analysis.flaw_analysis.lessons_learned
- 53 lines of complex string building logic
**After** (All environments):
```python
def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
    report_str = json.dumps(report, indent=2, ensure_ascii=False)
    jinja_env = _get_jinja_env()
    template = jinja_env.get_template("self_correction.j2")
    return template.render(report=report_str)
```
### 3. System Prompts (xxx_system.j2)
**No changes needed** - These are environment-specific:
- Interactive environments (alfworld, scienceworld, webshop): Use <think> + <action>
- Single-step environments (countdown, dapo): Use <think> + <answer>
This difference is intentional and correct.
## Testing
All environments pass Python syntax validation:
✓ alfworld
✓ countdown
✓ dapo
✓ scienceworld
✓ webshop
## Impact
- Complete consistency across all R3L prompts and functions
- Simpler, more maintainable codebase
- All environments now fully aligned with Alfworld reference implementation
## Summary
Fixed 3 critical issues in R3L workflows to ensure correct training behavior and consistency across all environments.
## 🔴 Critical Fixes
### 1. Fix ScienceWorld missing max_reflect_tokens (RUNTIME ERROR)
**File**: trinity/common/workflows/envs/R3L/scienceworld/R3L_workflow.py:41
- **Problem**: Used `max_reflect_tokens` at line 122 but never defined it
- **Fix**: Added `self.max_reflect_tokens = 4096` in __init__
- **Impact**: Prevents AttributeError at runtime
### 2. Fix Alfworld unsafe retry_step extraction (POTENTIAL KeyError)
**File**: trinity/common/workflows/envs/R3L/alfworld/R3L_workflow.py:246
- **Problem**: Used `reflect_checklist["retry_from_step"]` (unsafe dictionary access)
- **Fix**: Changed to `reflect_checklist.get("retry_from_step", 0)`
- **Impact**: Prevents KeyError if field is missing
### 3. Add _adjust_action_mask_for_retry to Countdown and DAPO (TRAINING DATA QUALITY)
**Files**:
- trinity/common/workflows/envs/R3L/countdown/R3L_workflow.py:153-192, 320-327
- trinity/common/workflows/envs/R3L/dapo/R3L_workflow.py:172-211, 339-346
- **Problem**: Missing critical method that ensures retry prefixes are excluded from training
- **Fix**: Added method definition and calls in run() method
- **Impact**: Ensures only improved retry attempts are trained, not failed first attempts
- **Details**:
- Identifies assistant response segments in action_mask
- Marks first retry_step segments as non-trainable (mask=0)
- Applied to both second_exp and corresponding first_exp for consistency
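The masking logic described above can be sketched as follows. This is an illustrative standalone version operating on a plain list; the actual method works on tensors inside the workflow classes:

```python
from typing import List

def adjust_action_mask_for_retry(action_mask: List[int], retry_step: int) -> List[int]:
    """Illustrative sketch: zero out the first `retry_step` assistant segments
    so the reused prefix is excluded from training.

    `action_mask` is 1 on assistant tokens and 0 on prompt/environment tokens;
    each maximal run of 1s is treated as one assistant response segment.
    """
    mask = list(action_mask)
    segment = 0
    in_segment = False
    for i, token_mask in enumerate(mask):
        if token_mask == 1 and not in_segment:
            in_segment = True
            segment += 1
        elif token_mask == 0:
            in_segment = False
        if token_mask == 1 and segment <= retry_step:
            mask[i] = 0  # prefix segment: not trainable
    return mask
```

With `retry_step=1`, only the first assistant segment is masked out, leaving later (retried) segments trainable.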
## Testing
All environments pass Python syntax validation:
✓ alfworld/R3L_workflow.py
✓ countdown/R3L_workflow.py
✓ dapo/R3L_workflow.py
✓ scienceworld/R3L_workflow.py
✓ webshop/R3L_workflow.py
## Impact
- ScienceWorld: Fixed runtime crash
- Alfworld: Improved robustness
- Countdown/DAPO: Correct training mask behavior (critical for SFT quality)
## Configuration Notes
- max_tokens: DAPO and Countdown use 4096 (math problems need longer reasoning); other envs use 512
- max_reflect_tokens: All environments use 4096 (consistent reflection capacity)
Align webshop and scienceworld with alfworld's approach of reminding the model about the required format (<think>...</think> <action>...</action>) in every user prompt. This improves format compliance during multi-turn interactions.
Changes:
- webshop/utils.py: Enhanced format_observation to include a format reminder
- scienceworld/utils.py: Enhanced format_observation to include a format reminder
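The per-turn reminder pattern can be sketched as below. The reminder wording and function body are assumptions; only the function name `format_observation` and the tag format come from the commit message:

```python
FORMAT_REMINDER = (
    "Respond with your reasoning in <think>...</think>, "
    "followed by exactly one <action>...</action>."
)

def format_observation(observation: str) -> str:
    # Hypothetical sketch: append the format reminder to every user turn,
    # not just the system prompt.
    return f"{observation}\n\n{FORMAT_REMINDER}"
```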
…ronments
This commit addresses several critical bugs and inconsistencies found during
rigorous comparison of all R3L environments against the alfworld reference.
Critical Bug Fixes (Priority 1):
1. alfworld/utils.py: Fixed duplicate variable check in validate_reflect_report
- Line 565 was checking retry_from_step twice instead of checking both
improvement_suggestion and retry_from_step
2. alfworld/utils.py: Added missing else clause for invalid trajectory_outcome
- Previously would not handle invalid outcome values explicitly
3. webshop/R3L_workflow.py: Removed hardcoded absolute path
- Replaced with environment variable WEBSHOP_PATH for portability
Standardization Fixes (Priority 2):
1. alfworld/utils.py: Standardized validate_reflect_report function signature
- Changed parameter from max_steps (optional) to total_steps (required)
- Changed return type hint from tuple to Tuple for consistency
- Added Tuple to imports
2. alfworld/utils.py: Added [R3L Alfworld Validation] prefix to all validation
error messages for consistency with other environments
3. alfworld/prompts/self_correction.j2: Added trailing newline for consistency
4. alfworld/prompts/reflection.j2: Added environment-specific hint
"(ALFWorld interactive tasks)" in improvement_suggestion field description
5. countdown/utils.py: Fixed reflect_report_to_guidance_prompt code order
- Now converts report before loading template (consistent with others)
- Updated docstring to full detailed version
6. scienceworld/R3L_workflow.py: Added Args section to _adjust_action_mask_for_retry
docstring for consistency
All environments now follow consistent patterns and conventions.
Keep the original hardcoded path as default, but allow override via WEBSHOP_PATH environment variable for flexibility in different environments.
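The fallback pattern can be sketched in one line. The default path below is a placeholder, not the actual hardcoded value:

```python
import os

# Keep the original hardcoded path as the default; WEBSHOP_PATH overrides it when set.
DEFAULT_WEBSHOP_PATH = "/path/to/webshop"  # placeholder for the original hardcoded path
webshop_path = os.environ.get("WEBSHOP_PATH", DEFAULT_WEBSHOP_PATH)
```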
WebShop Hardcoded Path Fixes:
- grpo_workflow.py: Add environment variable WEBSHOP_PATH support with fallback
- opmd_workflow.py: Add environment variable WEBSHOP_PATH support with fallback
- raft_workflow.py: Add environment variable WEBSHOP_PATH support with fallback
DAPO Format Reminder Enhancement:
- dapo/utils.py: Add format_dapo_prompt() function for per-turn format reminders
- Update first_rollout() to use format_dapo_prompt for the initial prompt
- Update first_rollout() feedback messages to include format reminders
- Update second_rollout() to use format_dapo_prompt for the initial prompt
- Update second_rollout() feedback messages to include format reminders
Now DAPO aligns with the other environments (alfworld, countdown, scienceworld, webshop) by reminding the model about the required format (<think>...</think> <answer>...</answer>) in every user turn, not just the system prompt.
RAFT Error Handling (align to alfworld pattern):
- countdown/raft_workflow.py: Fix to only append successful samples (reward >= 1.0)
  - Success: append exp
  - Failure: append default_exp
  - Exception: append default_exp
- dapo/raft_workflow.py: Same fix as countdown
- scienceworld/raft_workflow.py: Same fix as countdown
- webshop/raft_workflow.py: Same fix as countdown
The previous behavior appended all exps regardless of success/failure and ignored exceptions entirely. The new behavior matches alfworld's RAFT pattern.
WebShop Fixes:
- webshop/grpo_workflow.py: Remove unnecessary max_reflect_tokens=4096 (GRPO doesn't use reflection; only R3L does)
- webshop/raft_workflow.py: Fix registration name from "RAFT_baseline" to "raft_baseline" for consistency with other environments
All environments now follow consistent RAFT training logic.
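The success/failure/exception handling described above can be sketched as a collection loop. The function and helper names (`collect_raft_samples`, `rollout`, `make_default_exp`) are hypothetical; only the reward threshold and the three branches come from the commit message:

```python
def collect_raft_samples(tasks, rollout, make_default_exp, success_threshold=1.0):
    """Sketch of the alfworld-style RAFT pattern: keep real samples only on success."""
    exps = []
    for task in tasks:
        try:
            exp = rollout(task)
            if exp.reward >= success_threshold:
                exps.append(exp)                      # success: keep the real sample
            else:
                exps.append(make_default_exp(task))   # failure: placeholder sample
        except Exception:
            exps.append(make_default_exp(task))       # exception: placeholder sample
    return exps
```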
DAPO Workflow Creation:
- countdown/dapo_workflow.py: Created DAPO workflow for the countdown environment
  - Uses max_attempts=3, max_tokens=4096
  - Computes overlong penalty on response length
  - Returns accuracy + format_score as total reward
- dapo/dapo_workflow.py: Created DAPO workflow for the dapo environment
  - Uses max_attempts=3, max_tokens=4096
  - Handles 3 different prompt formats (prompt/question/problem)
  - Same DAPO penalty computation as other environments
- scienceworld/dapo_workflow.py: Created DAPO workflow for scienceworld
  - Uses max_env_steps=30, max_tokens=16384
  - Interactive environment with step-by-step execution
  - Computes overlong penalty on response tokens
- webshop/dapo_workflow.py: Created DAPO workflow for webshop
  - Uses max_env_steps=15, max_tokens=512
  - Includes WebShop environment initialization with path config
  - Compatible with existing webshop utils.first_rollout()
All 4 DAPO workflows follow the same pattern as alfworld/dapo_workflow.py:
- DAPO-style overlong penalty based on response token length
- Configurable penalty parameters (penalty_factor, max_response_length, cache_length)
- Total reward = accuracy_score + format_score
- Clean metrics tracking (accuracy, format_score, response_length, total_reward)
WebShop RAFT Fix:
- webshop/raft_workflow.py: Reverted registration name from "raft_baseline" back to "RAFT_baseline" to avoid breaking config file dependencies
Now all 5 environments have complete baseline implementations: R3L, GRPO, OPMD, RAFT, and DAPO workflows.
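The DAPO-style soft overlong penalty can be sketched as below. The parameter names match the configurable ones listed above, but the default values and exact formula are assumptions modeled on DAPO's soft overlong punishment, not a copy of the actual workflow code:

```python
def overlong_penalty(response_length: int,
                     max_response_length: int = 4096,
                     cache_length: int = 512,
                     penalty_factor: float = 1.0) -> float:
    """Sketch of a DAPO-style soft overlong penalty.

    Responses under (max_response_length - cache_length) get no penalty; inside
    the cache_length buffer the penalty grows linearly; fully overlong responses
    get the maximum penalty.
    """
    threshold = max_response_length - cache_length
    if response_length <= threshold:
        return 0.0
    if response_length <= max_response_length:
        return (threshold - response_length) / cache_length * penalty_factor
    return -penalty_factor
```

The penalty would then be added to `accuracy_score + format_score` when forming the total reward.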
- Unified all RAFT workflow registration names to use uppercase 'RAFT_baseline_'
- Updated workflow registration in trinity/common/workflows/envs/R3L for:
  * alfworld/raft_workflow.py
  * countdown/raft_workflow.py
  * dapo/raft_workflow.py
  * scienceworld/raft_workflow.py
- Updated corresponding config files in examples/R3L to match:
  * All RAFT_7B.yaml and RAFT_1.5B.yaml files now use 'RAFT_baseline_xxx_workflow'
- This ensures consistency between function registration names, __init__.py imports, and configuration file workflow_type declarations
Summary of Changes
Hello @shiweijiezero, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly expands the framework's capabilities by integrating advanced R3L (Reinforcement Learning from Reflection and Reasoning) workflows for diverse environments. It introduces new configurations and utility scripts to support these sophisticated learning paradigms, aiming to boost model efficiency and effectiveness through iterative self-correction. Concurrently, it refines existing system parameters by enabling key performance optimizations and adjusting synchronization settings for improved robustness.
Highlights
Code Review
This pull request introduces a significant number of new configuration files and workflows for different environments (ALFWorld, Countdown, DAPO, ScienceWorld, WebShop) under the R3L directory. It also includes several new Python scripts for data generation and new algorithm implementations like OPMDReweightAdvGroupAdvantage. Additionally, many existing configuration files have been updated to enable performance features like prefix caching and to increase timeouts.
My review focuses on the new code and configurations. I've identified several issues:
- Hardcoded paths: Some data generation scripts contain hardcoded local paths, which is a critical issue for portability.
- Fragile JSON parsing: Several new workflow files use a fragile method to parse JSON from model outputs, which could easily break.
- Code Quality: There are some instances of redundant code and hardcoded values that should be replaced with configurable parameters.
- Typo: A minor typo was found in a data generation script.
Overall, this is a substantial contribution. Addressing the identified issues will improve the robustness and maintainability of the new code.
    # FIX 1: change the default values to None
    def create_dataset_files(output_dir, train_size=None, test_size=None):
        alfworld_data_root = "/export/project/shiweijie/weijie/trinity/alfworld/json_2.1.1"
    if __name__ == "__main__":
        # NOTE: Manually set the jar path here.
        jar_path = "/your/path/ScienceWorld/scienceworld/scienceworld.jar"
            "jar_path": jar_path,
        }
        task_desc = json.dumps(task_config)
        train_data.append({"task_desc": task_desc, "targe": ""})
            "jar_path": jar_path,
        }
        task_desc = json.dumps(task_config)
        test_data.append({"task_desc": task_desc, "targe": ""})
    reward_mean = torch.mean(group_rewards)
    if len(exps) == 1:
        group_baseline = torch.tensor(0.0)
        group_rewards = torch.tensor([exps[0].reward], dtype=torch.float32)
    first_brace = reflection_text.find('{')
    last_brace = reflection_text.rfind('}')

    if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
        json_str = reflection_text[first_brace:last_brace + 1]
    else:
        json_str = reflection_text
This method of finding the first and last brace to extract a JSON string is fragile. It can fail if the model's response includes other text with braces outside the main JSON object. A more robust approach is to use a regular expression to find a JSON block, especially one enclosed in markdown code fences (a fenced `json` block). For example: ``re.search(r'```json\s*(\{.*?\})\s*```', reflection_text, re.DOTALL)``.
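The reviewer's suggestion could be combined with the existing brace heuristic as a fallback. This is a sketch, not a drop-in replacement for the workflow code:

```python
import json
import re

def extract_json(reflection_text: str):
    """Prefer a fenced json block; fall back to the outermost-brace heuristic."""
    match = re.search(r"```json\s*(\{.*?\})\s*```", reflection_text, re.DOTALL)
    if match:
        candidate = match.group(1)
    else:
        first = reflection_text.find("{")
        last = reflection_text.rfind("}")
        if first != -1 and last != -1 and first < last:
            candidate = reflection_text[first:last + 1]
        else:
            candidate = reflection_text
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        return None  # caller decides how to handle an unparseable reflection
```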