diff --git a/.claude/settings.local.json b/.claude/settings.local.json
new file mode 100644
index 0000000000..08a5d97728
--- /dev/null
+++ b/.claude/settings.local.json
@@ -0,0 +1,9 @@
+{
+ "permissions": {
+ "allow": [
+ "WebFetch(domain:arxiv.org)"
+ ],
+ "deny": [],
+ "ask": []
+ }
+}
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 97342e6637..8332481c10 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,7 +15,7 @@ repos:
rev: 23.7.0
hooks:
- id: black
- language_version: python3.10
+# language_version: python3.10
args: [--line-length=100]
- repo: https://github.com/pycqa/isort
diff --git a/benchmark/config/gsm8k-template.yaml b/benchmark/config/gsm8k-template.yaml
index 9e602bfe52..1cd2918312 100644
--- a/benchmark/config/gsm8k-template.yaml
+++ b/benchmark/config/gsm8k-template.yaml
@@ -60,8 +60,8 @@ explorer:
engine_num: 2
tensor_parallel_size: 1
enforce_eager: false
- enable_prefix_caching: false
- enable_chunked_prefill: false
+ enable_prefix_caching: true
+ enable_chunked_prefill: true
gpu_memory_utilization: 0.9
dtype: bfloat16
seed: 42
diff --git a/examples/R3L/alfworld/RAFT_1.5B.yaml b/examples/R3L/alfworld/RAFT_1.5B.yaml
new file mode 100644
index 0000000000..f51c9df323
--- /dev/null
+++ b/examples/R3L/alfworld/RAFT_1.5B.yaml
@@ -0,0 +1,72 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_1.5B_RAFT_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: raft
+ repeat_times: 1
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'RAFT_baseline_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_raft_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_raft_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/alfworld/RAFT_7B.yaml b/examples/R3L/alfworld/RAFT_7B.yaml
new file mode 100644
index 0000000000..197a6360d6
--- /dev/null
+++ b/examples/R3L/alfworld/RAFT_7B.yaml
@@ -0,0 +1,72 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_7B_RAFT_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: raft
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 1
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'RAFT_baseline_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_raft_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_raft_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/alfworld/dapo_1.5B.yaml b/examples/R3L/alfworld/dapo_1.5B.yaml
new file mode 100644
index 0000000000..961b6dde42
--- /dev/null
+++ b/examples/R3L/alfworld/dapo_1.5B.yaml
@@ -0,0 +1,88 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_1.5B_DAPO"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+ lr_warmup_steps: 20
+ policy_loss_fn_args:
+ clip_range_low: 0.2
+ clip_range_high: 0.28
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 48 # 96 causes OOM on this setup
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ workflow_args:
+ enable_overlong_penalty: true
+ penalty_factor: 1.0
+ max_response_length: 512
+ cache_length: 400
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'dapo_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_dapo_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_dapo_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 42
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/alfworld/dapo_7B.yaml b/examples/R3L/alfworld/dapo_7B.yaml
new file mode 100644
index 0000000000..2c671880c0
--- /dev/null
+++ b/examples/R3L/alfworld/dapo_7B.yaml
@@ -0,0 +1,88 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_7B_DAPO"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+ lr_warmup_steps: 20
+ policy_loss_fn_args:
+ clip_range_low: 0.2
+ clip_range_high: 0.28
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ workflow_args:
+ enable_overlong_penalty: true
+ penalty_factor: 1.0
+ max_response_length: 512
+ cache_length: 400
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'dapo_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_dapo_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_dapo_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 42
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/alfworld/get_alfworld_data.py b/examples/R3L/alfworld/get_alfworld_data.py
new file mode 100644
index 0000000000..2a432f92cc
--- /dev/null
+++ b/examples/R3L/alfworld/get_alfworld_data.py
@@ -0,0 +1,94 @@
+import glob
+import json
+import os
+import random
+
+random.seed(42)
+
+
+# Defaults of None mean "use all available game files".
+def create_dataset_files(output_dir, train_size=None, test_size=None):
+ # Environment-specific default; override via ALFWORLD_DATA_ROOT (env var name assumed here).
+ data_root_default = "/export/project/shiweijie/weijie/trinity/alfworld/json_2.1.1"
+ alfworld_data_root = os.environ.get("ALFWORLD_DATA_ROOT", data_root_default)
+
+ train_game_files = glob.glob(os.path.join(alfworld_data_root, "train/*/*/game.tw-pddl"))
+ test_game_files = glob.glob(os.path.join(alfworld_data_root, "valid_seen/*/*/game.tw-pddl"))
+
+ train_game_files = sorted([os.path.abspath(file) for file in train_game_files])
+ test_game_files = sorted([os.path.abspath(file) for file in test_game_files])
+
+ print(f"Total train game files found: {len(train_game_files)}")
+ print(f"Total test game files found: {len(test_game_files)}")
+
+ # If a size is not specified, fall back to all available files
+ if train_size is None:
+ train_size = len(train_game_files)
+ print(f"train_size not set, defaulting to all: {train_size}")
+ if test_size is None:
+ test_size = len(test_game_files)
+ print(f"test_size not set, defaulting to all: {test_size}")
+
+ # check sizes
+ assert train_size <= len(
+ train_game_files
+ ), f"train_size {train_size} > available {len(train_game_files)}"
+ assert test_size <= len(
+ test_game_files
+ ), f"test_size {test_size} > available {len(test_game_files)}"
+
+ # Randomly sample the requested number of game files
+ selected_train_files = random.sample(train_game_files, train_size)
+ selected_test_files = random.sample(test_game_files, test_size)
+
+ os.makedirs(output_dir, exist_ok=True)
+
+ def _create_data_list(game_files):
+ data = []
+ for game_file_path in game_files:
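+ # 'target' stays empty; rewards presumably come from the ALFWorld environment itself.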
+ data.append({"game_file": game_file_path, "target": ""})
+ return data
+
+ train_data = _create_data_list(selected_train_files)
+ test_data = _create_data_list(selected_test_files)
+
+ dataset_dict = {"train": train_data, "test": test_data}
+
+ for split, data in dataset_dict.items():
+ output_file = os.path.join(output_dir, f"{split}.jsonl")
+ with open(output_file, "w") as f:
+ for item in data:
+ f.write(json.dumps(item) + "\n")
+
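+ # dataset_dict.json records split sizes in an HF-datasets-style index (format assumed here).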
+ dataset_info = {
+ "citation": "",
+ "description": "Custom dataset",
+ "splits": {
+ "train": {"name": "train", "num_examples": len(train_data)},
+ "test": {"name": "test", "num_examples": len(test_data)},
+ },
+ }
+
+ with open(os.path.join(output_dir, "dataset_dict.json"), "w") as f:
+ json.dump(dataset_info, f, indent=2)
+
+ print(f"Created dataset with {len(train_data)} train and {len(test_data)} test examples.")
+
+
+if __name__ == "__main__":
+ current_file_dir = os.path.dirname(os.path.abspath(__file__))
+ output_dir = f"{current_file_dir}/alfworld_data"
+
+ # 1. Use the full dataset (train=3553, test=140)
+ print("--- Creating full dataset ---")
+ create_dataset_files(output_dir)
+
+ # # 2. A specific subset size can still be requested
+ # print("\n--- Creating subset dataset ---")
+ # create_dataset_files(output_dir, train_size=2953, test_size=100)
+ #
+ # # 3. Specify only the train size (the test set uses all files)
+ # print("\n--- Creating partial subset dataset ---")
+ # create_dataset_files(output_dir, train_size=100)
diff --git a/examples/R3L/alfworld/grpo_1.5B.yaml b/examples/R3L/alfworld/grpo_1.5B.yaml
new file mode 100644
index 0000000000..3aa20235d3
--- /dev/null
+++ b/examples/R3L/alfworld/grpo_1.5B.yaml
@@ -0,0 +1,79 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_1.5B_GRPO_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'grpo_baseline_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_grpo_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_grpo_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/alfworld/grpo_7B.yaml b/examples/R3L/alfworld/grpo_7B.yaml
new file mode 100644
index 0000000000..df42417ac7
--- /dev/null
+++ b/examples/R3L/alfworld/grpo_7B.yaml
@@ -0,0 +1,79 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_7B_GRPO_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'grpo_baseline_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_grpo_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_grpo_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/alfworld/gspo_1.5B.yaml b/examples/R3L/alfworld/gspo_1.5B.yaml
new file mode 100644
index 0000000000..082339eec7
--- /dev/null
+++ b/examples/R3L/alfworld/gspo_1.5B.yaml
@@ -0,0 +1,84 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_1.5B_GSPO"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ policy_loss_fn: gspo
+ policy_loss_fn_args:
+ clip_range_low: 0.0003
+ clip_range_high: 0.0004
+ loss_agg_mode: seq-mean-token-mean
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 5
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'grpo_baseline_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_gspo_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_gspo_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 42
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/alfworld/gspo_7B.yaml b/examples/R3L/alfworld/gspo_7B.yaml
new file mode 100644
index 0000000000..4862c28182
--- /dev/null
+++ b/examples/R3L/alfworld/gspo_7B.yaml
@@ -0,0 +1,84 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_7B_GSPO"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ policy_loss_fn: gspo
+ policy_loss_fn_args:
+ clip_range_low: 0.0003
+ clip_range_high: 0.0004
+ loss_agg_mode: seq-mean-token-mean
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'grpo_baseline_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_gspo_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_gspo_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 42
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/alfworld/opmd_1.5B.yaml b/examples/R3L/alfworld/opmd_1.5B.yaml
new file mode 100644
index 0000000000..855405d4b9
--- /dev/null
+++ b/examples/R3L/alfworld/opmd_1.5B.yaml
@@ -0,0 +1,76 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_1.5B_OPMD_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_opmd_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_opmd_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/alfworld/opmd_7B.yaml b/examples/R3L/alfworld/opmd_7B.yaml
new file mode 100644
index 0000000000..ad84d37ac1
--- /dev/null
+++ b/examples/R3L/alfworld/opmd_7B.yaml
@@ -0,0 +1,76 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_7B_OPMD_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_opmd_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_opmd_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/alfworld/opmd_R3L_1.5B.yaml b/examples/R3L/alfworld/opmd_R3L_1.5B.yaml
new file mode 100644
index 0000000000..e417d6dca3
--- /dev/null
+++ b/examples/R3L/alfworld/opmd_R3L_1.5B.yaml
@@ -0,0 +1,76 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_1.5B_R3L"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'R3L_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_R3L_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_R3L_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 6
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/alfworld/opmd_R3L_7B.yaml b/examples/R3L/alfworld/opmd_R3L_7B.yaml
new file mode 100644
index 0000000000..8e06b4aa5c
--- /dev/null
+++ b/examples/R3L/alfworld/opmd_R3L_7B.yaml
@@ -0,0 +1,76 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_7B_R3L"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'R3L_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_R3L_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_R3L_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/alfworld/opmd_reweight_adv_1.5B.yaml b/examples/R3L/alfworld/opmd_reweight_adv_1.5B.yaml
new file mode 100644
index 0000000000..91ae2bdccf
--- /dev/null
+++ b/examples/R3L/alfworld/opmd_reweight_adv_1.5B.yaml
@@ -0,0 +1,76 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_1.5B_OPMD_Reweight_Adv"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_opmd_reweight_adv_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_opmd_reweight_adv_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/alfworld/opmd_reweight_adv_7B.yaml b/examples/R3L/alfworld/opmd_reweight_adv_7B.yaml
new file mode 100644
index 0000000000..666b3fad54
--- /dev/null
+++ b/examples/R3L/alfworld/opmd_reweight_adv_7B.yaml
@@ -0,0 +1,76 @@
+project: "ALFWORLD"
+name: "ALFWORLD_RFT_Qwen_7B_OPMD_Reweight_Adv"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: alfworld
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: 'examples/R3L/alfworld/alfworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: alfworld_opmd_reweight_adv_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///alfworld_opmd_reweight_adv_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 10240
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/countdown/RAFT_1.5B.yaml b/examples/R3L/countdown/RAFT_1.5B.yaml
new file mode 100644
index 0000000000..efa89d6e98
--- /dev/null
+++ b/examples/R3L/countdown/RAFT_1.5B.yaml
@@ -0,0 +1,74 @@
+project: "COUNTDOWN"
+name: "COUNTDOWN_RFT_Qwen_1.5B_RAFT_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: raft
+ repeat_times: 1
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: countdown
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'train'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: countdown-eval
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'test'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'RAFT_baseline_countdown_workflow'
+ trainer_input:
+ experience_buffer:
+ name: countdown_raft_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///countdown_raft_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/countdown/RAFT_7B.yaml b/examples/R3L/countdown/RAFT_7B.yaml
new file mode 100644
index 0000000000..dacabd0745
--- /dev/null
+++ b/examples/R3L/countdown/RAFT_7B.yaml
@@ -0,0 +1,74 @@
+project: "COUNTDOWN"
+name: "COUNTDOWN_RFT_Qwen_7B_RAFT_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: raft
+ repeat_times: 1
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: countdown
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'train'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: countdown-eval
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'test'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'RAFT_baseline_countdown_workflow'
+ trainer_input:
+ experience_buffer:
+ name: countdown_raft_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///countdown_raft_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/countdown/grpo_1.5B.yaml b/examples/R3L/countdown/grpo_1.5B.yaml
new file mode 100644
index 0000000000..7687d779f4
--- /dev/null
+++ b/examples/R3L/countdown/grpo_1.5B.yaml
@@ -0,0 +1,81 @@
+project: "COUNTDOWN"
+name: "COUNTDOWN_RFT_Qwen_1.5B_GRPO_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: countdown
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'train'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: countdown-eval
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'test'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'grpo_baseline_countdown_workflow'
+ trainer_input:
+ experience_buffer:
+ name: countdown_grpo_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///countdown_grpo_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 4
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/countdown/grpo_7B.yaml b/examples/R3L/countdown/grpo_7B.yaml
new file mode 100644
index 0000000000..9462bcc708
--- /dev/null
+++ b/examples/R3L/countdown/grpo_7B.yaml
@@ -0,0 +1,81 @@
+project: "COUNTDOWN"
+name: "COUNTDOWN_RFT_Qwen_7B_GRPO_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: countdown
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'train'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: countdown-eval
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'test'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'grpo_baseline_countdown_workflow'
+ trainer_input:
+ experience_buffer:
+ name: countdown_grpo_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///countdown_grpo_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 4
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/countdown/opmd_1.5B.yaml b/examples/R3L/countdown/opmd_1.5B.yaml
new file mode 100644
index 0000000000..9b1e23b0ba
--- /dev/null
+++ b/examples/R3L/countdown/opmd_1.5B.yaml
@@ -0,0 +1,78 @@
+project: "COUNTDOWN"
+name: "COUNTDOWN_RFT_Qwen_1.5B_OPMD_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: countdown
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'train'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: countdown-eval
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'test'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_countdown_workflow'
+ trainer_input:
+ experience_buffer:
+ name: countdown_opmd_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///countdown_opmd_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/countdown/opmd_7B.yaml b/examples/R3L/countdown/opmd_7B.yaml
new file mode 100644
index 0000000000..8a0f696d25
--- /dev/null
+++ b/examples/R3L/countdown/opmd_7B.yaml
@@ -0,0 +1,78 @@
+project: "COUNTDOWN"
+name: "COUNTDOWN_RFT_Qwen_7B_OPMD_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: countdown
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'train'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: countdown-eval
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'test'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_countdown_workflow'
+ trainer_input:
+ experience_buffer:
+ name: countdown_opmd_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///countdown_opmd_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/countdown/opmd_R3L_1.5B.yaml b/examples/R3L/countdown/opmd_R3L_1.5B.yaml
new file mode 100644
index 0000000000..c59aa7f44a
--- /dev/null
+++ b/examples/R3L/countdown/opmd_R3L_1.5B.yaml
@@ -0,0 +1,78 @@
+project: "COUNTDOWN"
+name: "COUNTDOWN_RFT_Qwen_1.5B_R3L"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: countdown
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'train'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: countdown-eval
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'test'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'R3L_countdown_workflow'
+ trainer_input:
+ experience_buffer:
+ name: countdown_R3L_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///countdown_R3L_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/countdown/opmd_R3L_7B.yaml b/examples/R3L/countdown/opmd_R3L_7B.yaml
new file mode 100644
index 0000000000..106dc134a2
--- /dev/null
+++ b/examples/R3L/countdown/opmd_R3L_7B.yaml
@@ -0,0 +1,78 @@
+project: "COUNTDOWN"
+name: "COUNTDOWN_RFT_Qwen_7B_R3L"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: countdown
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'train'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: countdown-eval
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'test'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'R3L_countdown_workflow'
+ trainer_input:
+ experience_buffer:
+ name: countdown_R3L_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///countdown_R3L_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/countdown/opmd_reweight_adv_1.5B.yaml b/examples/R3L/countdown/opmd_reweight_adv_1.5B.yaml
new file mode 100644
index 0000000000..c0fd64d414
--- /dev/null
+++ b/examples/R3L/countdown/opmd_reweight_adv_1.5B.yaml
@@ -0,0 +1,78 @@
+project: "COUNTDOWN"
+name: "COUNTDOWN_RFT_Qwen_1.5B_OPMD_Reweight_Adv"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: countdown
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'train'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: countdown-eval
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'test'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_countdown_workflow'
+ trainer_input:
+ experience_buffer:
+ name: countdown_opmd_reweight_adv_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///countdown_opmd_reweight_adv_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/countdown/opmd_reweight_adv_7B.yaml b/examples/R3L/countdown/opmd_reweight_adv_7B.yaml
new file mode 100644
index 0000000000..02f17a68c5
--- /dev/null
+++ b/examples/R3L/countdown/opmd_reweight_adv_7B.yaml
@@ -0,0 +1,78 @@
+project: "COUNTDOWN"
+name: "COUNTDOWN_RFT_Qwen_7B_OPMD_Reweight_Adv"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: countdown
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'train'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: countdown-eval
+ storage_type: file
+ path: 'justinphan3110/Countdown-Tasks-3to4'
+ split: 'test'
+ format:
+ prompt_key: 'nums'
+ response_key: 'target'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_countdown_workflow'
+ trainer_input:
+ experience_buffer:
+ name: countdown_opmd_reweight_adv_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///countdown_opmd_reweight_adv_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/dapo/RAFT_1.5B.yaml b/examples/R3L/dapo/RAFT_1.5B.yaml
new file mode 100644
index 0000000000..fd36db7dcc
--- /dev/null
+++ b/examples/R3L/dapo/RAFT_1.5B.yaml
@@ -0,0 +1,137 @@
+project: "DAPO"
+name: "DAPO_RFT_Qwen_1.5B_RAFT_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: raft
+ repeat_times: 1
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 1
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: dapo
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'train'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: dapo-eval
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'test'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime24
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime24'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime25
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime25'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-amc23
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_amc23'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-gsm8k
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_gsm8k'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-math500
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_math500'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-minervamath
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_minervamath'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-olympiadbench
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_olympiadbench'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'RAFT_baseline_dapo_workflow'
+ trainer_input:
+ experience_buffer:
+ name: dapo_raft_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///dapo_raft_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/dapo/RAFT_7B.yaml b/examples/R3L/dapo/RAFT_7B.yaml
new file mode 100644
index 0000000000..de144c3d8f
--- /dev/null
+++ b/examples/R3L/dapo/RAFT_7B.yaml
@@ -0,0 +1,137 @@
+project: "DAPO"
+name: "DAPO_RFT_Qwen_7B_RAFT_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: raft
+ repeat_times: 1
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 1
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: dapo
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'train'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: dapo-eval
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'test'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime24
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime24'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime25
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime25'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-amc23
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_amc23'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-gsm8k
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_gsm8k'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-math500
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_math500'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-minervamath
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_minervamath'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-olympiadbench
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_olympiadbench'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'RAFT_baseline_dapo_workflow'
+ trainer_input:
+ experience_buffer:
+ name: dapo_raft_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///dapo_raft_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/dapo/grpo_1.5B.yaml b/examples/R3L/dapo/grpo_1.5B.yaml
new file mode 100644
index 0000000000..d25c947d16
--- /dev/null
+++ b/examples/R3L/dapo/grpo_1.5B.yaml
@@ -0,0 +1,144 @@
+project: "DAPO"
+name: "DAPO_RFT_Qwen_1.5B_GRPO_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 1
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: dapo
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'train'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: dapo-eval
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'test'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime24
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime24'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime25
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime25'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-amc23
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_amc23'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-gsm8k
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_gsm8k'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-math500
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_math500'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-minervamath
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_minervamath'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-olympiadbench
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_olympiadbench'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'grpo_baseline_dapo_workflow'
+ trainer_input:
+ experience_buffer:
+ name: dapo_grpo_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///dapo_grpo_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 6
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/dapo/grpo_7B.yaml b/examples/R3L/dapo/grpo_7B.yaml
new file mode 100644
index 0000000000..ea77d5f1df
--- /dev/null
+++ b/examples/R3L/dapo/grpo_7B.yaml
@@ -0,0 +1,144 @@
+project: "DAPO"
+name: "DAPO_RFT_Qwen_7B_GRPO_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 1
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: dapo
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'train'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: dapo-eval
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'test'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime24
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime24'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime25
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime25'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-amc23
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_amc23'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-gsm8k
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_gsm8k'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-math500
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_math500'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-minervamath
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_minervamath'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-olympiadbench
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_olympiadbench'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'grpo_baseline_dapo_workflow'
+ trainer_input:
+ experience_buffer:
+ name: dapo_grpo_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///dapo_grpo_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 6
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/dapo/opmd_1.5B.yaml b/examples/R3L/dapo/opmd_1.5B.yaml
new file mode 100644
index 0000000000..0d5ec806ca
--- /dev/null
+++ b/examples/R3L/dapo/opmd_1.5B.yaml
@@ -0,0 +1,141 @@
+project: "DAPO"
+name: "DAPO_RFT_Qwen_1.5B_OPMD_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 1
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: dapo
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'train'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: dapo-eval
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'test'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime24
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime24'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime25
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime25'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-amc23
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_amc23'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-gsm8k
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_gsm8k'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-math500
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_math500'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-minervamath
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_minervamath'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-olympiadbench
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_olympiadbench'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_dapo_workflow'
+ trainer_input:
+ experience_buffer:
+ name: dapo_opmd_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///dapo_opmd_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 4
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/dapo/opmd_7B.yaml b/examples/R3L/dapo/opmd_7B.yaml
new file mode 100644
index 0000000000..31cdc3a518
--- /dev/null
+++ b/examples/R3L/dapo/opmd_7B.yaml
@@ -0,0 +1,141 @@
+project: "DAPO"
+name: "DAPO_RFT_Qwen_7B_OPMD_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 1
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: dapo
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'train'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: dapo-eval
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'test'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime24
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime24'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime25
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime25'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-amc23
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_amc23'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-gsm8k
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_gsm8k'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-math500
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_math500'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-minervamath
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_minervamath'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-olympiadbench
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_olympiadbench'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_dapo_workflow'
+ trainer_input:
+ experience_buffer:
+ name: dapo_opmd_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///dapo_opmd_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 4
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/dapo/opmd_R3L_1.5B.yaml b/examples/R3L/dapo/opmd_R3L_1.5B.yaml
new file mode 100644
index 0000000000..394592d24f
--- /dev/null
+++ b/examples/R3L/dapo/opmd_R3L_1.5B.yaml
@@ -0,0 +1,141 @@
+project: "DAPO"
+name: "DAPO_RFT_Qwen_1.5B_R3L"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 1
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: dapo
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'train'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: dapo-eval
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'test'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime24
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime24'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime25
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime25'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-amc23
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_amc23'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-gsm8k
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_gsm8k'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-math500
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_math500'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-minervamath
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_minervamath'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-olympiadbench
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_olympiadbench'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'R3L_dapo_workflow'
+ trainer_input:
+ experience_buffer:
+ name: dapo_R3L_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///dapo_R3L_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 4
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/dapo/opmd_R3L_7B.yaml b/examples/R3L/dapo/opmd_R3L_7B.yaml
new file mode 100644
index 0000000000..fba20ac335
--- /dev/null
+++ b/examples/R3L/dapo/opmd_R3L_7B.yaml
@@ -0,0 +1,141 @@
+project: "DAPO"
+name: "DAPO_RFT_Qwen_7B_R3L"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 1
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: dapo
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'train'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: dapo-eval
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'test'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime24
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime24'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime25
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime25'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-amc23
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_amc23'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-gsm8k
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_gsm8k'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-math500
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_math500'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-minervamath
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_minervamath'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-olympiadbench
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_olympiadbench'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'R3L_dapo_workflow'
+ trainer_input:
+ experience_buffer:
+ name: dapo_R3L_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///dapo_R3L_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 4
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/dapo/opmd_reweight_adv_1.5B.yaml b/examples/R3L/dapo/opmd_reweight_adv_1.5B.yaml
new file mode 100644
index 0000000000..40312eb22e
--- /dev/null
+++ b/examples/R3L/dapo/opmd_reweight_adv_1.5B.yaml
@@ -0,0 +1,141 @@
+project: "DAPO"
+name: "DAPO_RFT_Qwen_1.5B_OPMD_Reweight_Adv"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 1
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: dapo
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'train'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: dapo-eval
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'test'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime24
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime24'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime25
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime25'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-amc23
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_amc23'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-gsm8k
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_gsm8k'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-math500
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_math500'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-minervamath
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_minervamath'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-olympiadbench
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_olympiadbench'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_dapo_workflow'
+ trainer_input:
+ experience_buffer:
+ name: dapo_opmd_reweight_adv_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///dapo_opmd_reweight_adv_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 4
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/dapo/opmd_reweight_adv_7B.yaml b/examples/R3L/dapo/opmd_reweight_adv_7B.yaml
new file mode 100644
index 0000000000..fadb952e1d
--- /dev/null
+++ b/examples/R3L/dapo/opmd_reweight_adv_7B.yaml
@@ -0,0 +1,141 @@
+project: "DAPO"
+name: "DAPO_RFT_Qwen_7B_OPMD_Reweight_Adv"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 1
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: dapo
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'train'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: dapo-eval
+ storage_type: file
+ path: 'weijiezz/DAPO-Math-17k-split'
+ split: 'test'
+ format:
+ prompt_key: 'prompt'
+ response_key: 'reward_model'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime24
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime24'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-aime25
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_aime25'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-amc23
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_amc23'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-gsm8k
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_gsm8k'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-math500
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_math500'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-minervamath
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_minervamath'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ - name: math-olympiadbench
+ storage_type: file
+ path: 'weijiezz/math-datasets-100k'
+ split: 'test_olympiadbench'
+ format:
+ prompt_key: 'question'
+ response_key: 'answer'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_dapo_workflow'
+ trainer_input:
+ experience_buffer:
+ name: dapo_opmd_reweight_adv_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///dapo_opmd_reweight_adv_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 4
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/scienceworld/RAFT_1.5B.yaml b/examples/R3L/scienceworld/RAFT_1.5B.yaml
new file mode 100644
index 0000000000..c2de9c3e35
--- /dev/null
+++ b/examples/R3L/scienceworld/RAFT_1.5B.yaml
@@ -0,0 +1,72 @@
+project: "SCIENCEWORLD"
+name: "SCIENCEWORLD_RFT_Qwen_1.5B_RAFT_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: raft
+ repeat_times: 1
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: scienceworld
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: scienceworld-eval
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'RAFT_baseline_scienceworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: scienceworld_raft_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///scienceworld_raft_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/scienceworld/RAFT_7B.yaml b/examples/R3L/scienceworld/RAFT_7B.yaml
new file mode 100644
index 0000000000..d5afb2476c
--- /dev/null
+++ b/examples/R3L/scienceworld/RAFT_7B.yaml
@@ -0,0 +1,72 @@
+project: "SCIENCEWORLD"
+name: "SCIENCEWORLD_RFT_Qwen_7B_RAFT_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: raft
+ repeat_times: 1
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: scienceworld
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: scienceworld-eval
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'RAFT_baseline_scienceworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: scienceworld_raft_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///scienceworld_raft_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/scienceworld/get_sciworld_data.py b/examples/R3L/scienceworld/get_sciworld_data.py
new file mode 100644
index 0000000000..a10da8d2ee
--- /dev/null
+++ b/examples/R3L/scienceworld/get_sciworld_data.py
@@ -0,0 +1,132 @@
+"""
+We use this script to create Hugging Face-format dataset files for the ScienceWorld dataset.
+NOTE: You need to install ScienceWorld first: https://github.com/allenai/ScienceWorld
+"""
+import json
+import os
+import random
+
+random.seed(42)
+
+task_variations = {
+ "boil": 30,
+ "melt": 30,
+ "freeze": 30,
+ "change-the-state-of-matter-of": 30,
+ "use-thermometer": 540,
+ "measure-melting-point-known-substance": 436,
+ "measure-melting-point-unknown-substance": 300,
+ "power-component": 20,
+ "power-component-renewable-vs-nonrenewable-energy": 20,
+ "test-conductivity": 900,
+ "test-conductivity-of-unknown-substances": 600,
+ "find-living-thing": 300,
+ "find-non-living-thing": 300,
+ "find-plant": 300,
+ "find-animal": 300,
+ "grow-plant": 126,
+ "grow-fruit": 126,
+ "chemistry-mix": 32,
+ "chemistry-mix-paint-secondary-color": 36,
+ "chemistry-mix-paint-tertiary-color": 36,
+ "lifespan-longest-lived": 125,
+ "lifespan-shortest-lived": 125,
+ "lifespan-longest-lived-then-shortest-lived": 125,
+ "identify-life-stages-1": 14,
+ "identify-life-stages-2": 10,
+ "inclined-plane-determine-angle": 168,
+ "inclined-plane-friction-named-surfaces": 1386,
+ "inclined-plane-friction-unnamed-surfaces": 162,
+ "mendelian-genetics-known-plant": 120,
+ "mendelian-genetics-unknown-plant": 480,
+}
+
+
+def create_dataset_files(output_dir, train_task_names, test_task_names, jar_path, percentage=0.6):
+ # make the output directory
+ os.makedirs(output_dir, exist_ok=True)
+
+ train_data = []
+ test_data = []
+
+ for task_name in train_task_names:
+ total_var = task_variations[task_name]
+ for i in range(int(total_var * percentage)):
+ task_config = {
+ "task_name": task_name,
+ "var_num": i,
+ "jar_path": jar_path,
+ }
+ task_desc = json.dumps(task_config)
+ train_data.append({"task_desc": task_desc, "target": ""})
+
+ random.shuffle(train_data)
+
+ for task_name in test_task_names:
+ total_var = task_variations[task_name]
+ for i in range(int(total_var * percentage)):
+ task_config = {
+ "task_name": task_name,
+ "var_num": i,
+ "jar_path": jar_path,
+ }
+ task_desc = json.dumps(task_config)
+ test_data.append({"task_desc": task_desc, "target": ""})
+
+ random.shuffle(test_data)
+
+ # create dataset_dict
+ dataset_dict = {"train": train_data, "test": test_data}
+
+ for split, data in dataset_dict.items():
+ output_file = os.path.join(output_dir, f"{split}.jsonl")
+ with open(output_file, "w") as f:
+ for item in data:
+ f.write(json.dumps(item) + "\n")
+
+ # create dataset_dict.json
+ dataset_info = {
+ "citation": "",
+ "description": "Custom dataset",
+ "splits": {
+ "train": {"name": "train", "num_examples": len(train_data)},
+ "test": {"name": "test", "num_examples": len(test_data)},
+ },
+ }
+
+ with open(os.path.join(output_dir, "dataset_dict.json"), "w") as f:
+ json.dump(dataset_info, f, indent=2)
+
+
+if __name__ == "__main__":
+ # NOTE: Manually set the jar path here.
+ jar_path = "/your/path/ScienceWorld/scienceworld/scienceworld.jar"
+ # Check that the jar file exists; raise an error if it doesn't.
+ if not os.path.exists(jar_path):
+ raise FileNotFoundError(
+ f"JAR file not found at {jar_path}, please set the jar path mannually."
+ )
+
+ current_file_dir = os.path.dirname(os.path.abspath(__file__))
+ output_dir = f"{current_file_dir}/sciworld_data"
+ train_task_names = [
+ "boil",
+ "melt",
+ "change-the-state-of-matter-of",
+ "use-thermometer",
+ "measure-melting-point-known-substance",
+ "power-component",
+ "test-conductivity",
+ "find-living-thing",
+ "find-plant",
+ "grow-plant",
+ "chemistry-mix",
+ "chemistry-mix-paint-secondary-color",
+ "lifespan-shortest-lived",
+ "identify-life-stages-2",
+ "inclined-plane-determine-angle",
+ "inclined-plane-friction-named-surfaces",
+ "mendelian-genetics-known-plant",
+ ]
+ test_task_names = list(task_variations.keys() - set(train_task_names))
+ create_dataset_files(output_dir, train_task_names, test_task_names, jar_path, percentage=0.5)
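A minimal sketch (not part of the patch) of how the generated ScienceWorld splits could be sanity-checked, assuming the Hugging Face `datasets` package is available. The directory name follows the script's own output (`sciworld_data`); note that the ScienceWorld YAML configs above read from `examples/R3L/scienceworld/scienceworld_data`, so the path may need to be aligned.

# Hypothetical sanity check for the files written by get_sciworld_data.py (paths are assumptions).
import json

from datasets import load_dataset  # Hugging Face Datasets

data_dir = "examples/R3L/scienceworld/sciworld_data"  # output_dir used by the script
ds = load_dataset(
    "json",
    data_files={
        "train": f"{data_dir}/train.jsonl",
        "test": f"{data_dir}/test.jsonl",
    },
)
# Each record stores the ScienceWorld task configuration as a JSON string under "task_desc".
first = json.loads(ds["train"][0]["task_desc"])
print(first["task_name"], first["var_num"], first["jar_path"])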
diff --git a/examples/R3L/scienceworld/grpo_1.5B.yaml b/examples/R3L/scienceworld/grpo_1.5B.yaml
new file mode 100644
index 0000000000..84961fcfd8
--- /dev/null
+++ b/examples/R3L/scienceworld/grpo_1.5B.yaml
@@ -0,0 +1,79 @@
+project: "SCIENCEWORLD"
+name: "SCIENCEWORLD_RFT_Qwen_1.5B_GRPO_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: scienceworld
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: scienceworld-eval
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'grpo_baseline_scienceworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: scienceworld_grpo_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///scienceworld_grpo_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/scienceworld/grpo_7B.yaml b/examples/R3L/scienceworld/grpo_7B.yaml
new file mode 100644
index 0000000000..f41d14e7cc
--- /dev/null
+++ b/examples/R3L/scienceworld/grpo_7B.yaml
@@ -0,0 +1,79 @@
+project: "SCIENCEWORLD"
+name: "SCIENCEWORLD_RFT_Qwen_7B_GRPO_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: scienceworld
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: scienceworld-eval
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'grpo_baseline_scienceworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: scienceworld_grpo_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///scienceworld_grpo_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/scienceworld/opmd_1.5B.yaml b/examples/R3L/scienceworld/opmd_1.5B.yaml
new file mode 100644
index 0000000000..1606439826
--- /dev/null
+++ b/examples/R3L/scienceworld/opmd_1.5B.yaml
@@ -0,0 +1,76 @@
+project: "SCIENCEWORLD"
+name: "SCIENCEWORLD_RFT_Qwen_1.5B_OPMD_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: scienceworld
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: scienceworld-eval
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_scienceworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: scienceworld_opmd_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///scienceworld_opmd_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/scienceworld/opmd_7B.yaml b/examples/R3L/scienceworld/opmd_7B.yaml
new file mode 100644
index 0000000000..c7f8be22c3
--- /dev/null
+++ b/examples/R3L/scienceworld/opmd_7B.yaml
@@ -0,0 +1,76 @@
+project: "SCIENCEWORLD"
+name: "SCIENCEWORLD_RFT_Qwen_7B_OPMD_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: scienceworld
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: scienceworld-eval
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_scienceworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: scienceworld_opmd_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///scienceworld_opmd_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/scienceworld/opmd_R3L_1.5B.yaml b/examples/R3L/scienceworld/opmd_R3L_1.5B.yaml
new file mode 100644
index 0000000000..d6f99be0fb
--- /dev/null
+++ b/examples/R3L/scienceworld/opmd_R3L_1.5B.yaml
@@ -0,0 +1,76 @@
+project: "SCIENCEWORLD"
+name: "SCIENCEWORLD_RFT_Qwen_1.5B_R3L"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: scienceworld
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: scienceworld-eval
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'R3L_scienceworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: scienceworld_R3L_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///scienceworld_R3L_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/scienceworld/opmd_R3L_7B.yaml b/examples/R3L/scienceworld/opmd_R3L_7B.yaml
new file mode 100644
index 0000000000..19a4baed64
--- /dev/null
+++ b/examples/R3L/scienceworld/opmd_R3L_7B.yaml
@@ -0,0 +1,76 @@
+project: "SCIENCEWORLD"
+name: "SCIENCEWORLD_RFT_Qwen_7B_R3L"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: scienceworld
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: scienceworld-eval
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'R3L_scienceworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: scienceworld_R3L_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///scienceworld_R3L_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/scienceworld/opmd_reweight_adv_1.5B.yaml b/examples/R3L/scienceworld/opmd_reweight_adv_1.5B.yaml
new file mode 100644
index 0000000000..264f09e6b9
--- /dev/null
+++ b/examples/R3L/scienceworld/opmd_reweight_adv_1.5B.yaml
@@ -0,0 +1,76 @@
+project: "SCIENCEWORLD"
+name: "SCIENCEWORLD_RFT_Qwen_1.5B_OPMD_Reweight_Adv"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: scienceworld
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: scienceworld-eval
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_scienceworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: scienceworld_opmd_reweight_adv_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///scienceworld_opmd_reweight_adv_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/scienceworld/opmd_reweight_adv_7B.yaml b/examples/R3L/scienceworld/opmd_reweight_adv_7B.yaml
new file mode 100644
index 0000000000..68fdb50d3c
--- /dev/null
+++ b/examples/R3L/scienceworld/opmd_reweight_adv_7B.yaml
@@ -0,0 +1,76 @@
+project: "SCIENCEWORLD"
+name: "SCIENCEWORLD_RFT_Qwen_7B_OPMD_Reweight_Adv"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: scienceworld
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: scienceworld-eval
+ storage_type: file
+ path: 'examples/R3L/scienceworld/scienceworld_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_scienceworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: scienceworld_opmd_reweight_adv_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///scienceworld_opmd_reweight_adv_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/webshop/RAFT_1.5B.yaml b/examples/R3L/webshop/RAFT_1.5B.yaml
new file mode 100644
index 0000000000..f0d5bca0c4
--- /dev/null
+++ b/examples/R3L/webshop/RAFT_1.5B.yaml
@@ -0,0 +1,72 @@
+project: "WEBSHOP"
+name: "WEBSHOP_RFT_Qwen_1.5B_RAFT_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: raft
+ repeat_times: 1
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: webshop
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: webshop-eval
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'RAFT_baseline_webshop_workflow'
+ trainer_input:
+ experience_buffer:
+ name: webshop_raft_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///webshop_raft_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/webshop/RAFT_7B.yaml b/examples/R3L/webshop/RAFT_7B.yaml
new file mode 100644
index 0000000000..3feccf10cb
--- /dev/null
+++ b/examples/R3L/webshop/RAFT_7B.yaml
@@ -0,0 +1,72 @@
+project: "WEBSHOP"
+name: "WEBSHOP_RFT_Qwen_7B_RAFT_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: raft
+ repeat_times: 1
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: webshop
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: webshop-eval
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'RAFT_baseline_webshop_workflow'
+ trainer_input:
+ experience_buffer:
+ name: webshop_raft_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///webshop_raft_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/webshop/get_webshop_data.py b/examples/R3L/webshop/get_webshop_data.py
new file mode 100644
index 0000000000..61ea93a5ea
--- /dev/null
+++ b/examples/R3L/webshop/get_webshop_data.py
@@ -0,0 +1,47 @@
+"""
+This script creates HuggingFace-format dataset files for the WebShop dataset.
+"""
+import json
+import os
+
+
+def create_dataset_files(output_dir, train_size=4096, test_size=100):
+ # make the output directory
+ os.makedirs(output_dir, exist_ok=True)
+
+ # For the WebShop dataset, the session id alone serves as the task id.
+ all_data = []
+ for task_id in range(train_size + test_size):
+ all_data.append({"task_id": task_id, "target": ""})
+
+ # split into train and test sets
+ train_data = all_data[:train_size]
+ test_data = all_data[train_size : train_size + test_size]
+
+ # create dataset_dict
+ dataset_dict = {"train": train_data, "test": test_data}
+
+ for split, data in dataset_dict.items():
+ output_file = os.path.join(output_dir, f"{split}.jsonl")
+ with open(output_file, "w") as f:
+ for item in data:
+ f.write(json.dumps(item) + "\n")
+
+ # create dataset_dict.json
+ dataset_info = {
+ "citation": "",
+ "description": "Custom dataset",
+ "splits": {
+ "train": {"name": "train", "num_examples": len(train_data)},
+ "test": {"name": "test", "num_examples": len(test_data)},
+ },
+ }
+
+ with open(os.path.join(output_dir, "dataset_dict.json"), "w") as f:
+ json.dump(dataset_info, f, indent=2)
+
+
+if __name__ == "__main__":
+ current_file_dir = os.path.dirname(os.path.abspath(__file__))
+ output_dir = f"{current_file_dir}/webshop_data"
+ create_dataset_files(output_dir, train_size=4096, test_size=100)
diff --git a/examples/R3L/webshop/grpo_1.5B.yaml b/examples/R3L/webshop/grpo_1.5B.yaml
new file mode 100644
index 0000000000..6e59962af2
--- /dev/null
+++ b/examples/R3L/webshop/grpo_1.5B.yaml
@@ -0,0 +1,79 @@
+project: "WEBSHOP"
+name: "WEBSHOP_RFT_Qwen_1.5B_GRPO_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: webshop
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: webshop-eval
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'grpo_baseline_webshop_workflow'
+ trainer_input:
+ experience_buffer:
+ name: webshop_grpo_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///webshop_grpo_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/webshop/grpo_7B.yaml b/examples/R3L/webshop/grpo_7B.yaml
new file mode 100644
index 0000000000..6df621b07e
--- /dev/null
+++ b/examples/R3L/webshop/grpo_7B.yaml
@@ -0,0 +1,79 @@
+project: "WEBSHOP"
+name: "WEBSHOP_RFT_Qwen_7B_GRPO_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: grpo
+ kl_loss_fn: k3
+ kl_loss_fn_args:
+ kl_coef: 0.01
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: webshop
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: webshop-eval
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'grpo_baseline_webshop_workflow'
+ trainer_input:
+ experience_buffer:
+ name: webshop_grpo_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///webshop_grpo_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/webshop/opmd_1.5B.yaml b/examples/R3L/webshop/opmd_1.5B.yaml
new file mode 100644
index 0000000000..f503f44150
--- /dev/null
+++ b/examples/R3L/webshop/opmd_1.5B.yaml
@@ -0,0 +1,76 @@
+project: "WEBSHOP"
+name: "WEBSHOP_RFT_Qwen_1.5B_OPMD_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: webshop
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: webshop-eval
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_webshop_workflow'
+ trainer_input:
+ experience_buffer:
+ name: webshop_opmd_baseline_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///webshop_opmd_baseline_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/webshop/opmd_7B.yaml b/examples/R3L/webshop/opmd_7B.yaml
new file mode 100644
index 0000000000..c3e2dbee92
--- /dev/null
+++ b/examples/R3L/webshop/opmd_7B.yaml
@@ -0,0 +1,76 @@
+project: "WEBSHOP"
+name: "WEBSHOP_RFT_Qwen_7B_OPMD_Baseline"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: webshop
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: webshop-eval
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_webshop_workflow'
+ trainer_input:
+ experience_buffer:
+ name: webshop_opmd_baseline_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///webshop_opmd_baseline_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/webshop/opmd_R3L_1.5B.yaml b/examples/R3L/webshop/opmd_R3L_1.5B.yaml
new file mode 100644
index 0000000000..3848dac764
--- /dev/null
+++ b/examples/R3L/webshop/opmd_R3L_1.5B.yaml
@@ -0,0 +1,76 @@
+project: "WEBSHOP"
+name: "WEBSHOP_RFT_Qwen_1.5B_R3L"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
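+ # "opmd_reweight_adv" is the reweighted OPMD advantage/algorithm variant added in
+ # trinity/algorithm (OPMDReweightAdvGroupAdvantage / OPMDReweightAdvAlgorithm)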
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: webshop
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: webshop-eval
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'R3L_webshop_workflow'
+ trainer_input:
+ experience_buffer:
+ name: webshop_R3L_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///webshop_R3L_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/webshop/opmd_R3L_7B.yaml b/examples/R3L/webshop/opmd_R3L_7B.yaml
new file mode 100644
index 0000000000..637a6b26a2
--- /dev/null
+++ b/examples/R3L/webshop/opmd_R3L_7B.yaml
@@ -0,0 +1,76 @@
+project: "WEBSHOP"
+name: "WEBSHOP_RFT_Qwen_7B_R3L"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: webshop
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: webshop-eval
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'R3L_webshop_workflow'
+ trainer_input:
+ experience_buffer:
+ name: webshop_R3L_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///webshop_R3L_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/webshop/opmd_reweight_adv_1.5B.yaml b/examples/R3L/webshop/opmd_reweight_adv_1.5B.yaml
new file mode 100644
index 0000000000..7f839e4e7b
--- /dev/null
+++ b/examples/R3L/webshop/opmd_reweight_adv_1.5B.yaml
@@ -0,0 +1,76 @@
+project: "WEBSHOP"
+name: "WEBSHOP_RFT_Qwen_1.5B_OPMD_Reweight_Adv"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: webshop
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: webshop-eval
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_webshop_workflow'
+ trainer_input:
+ experience_buffer:
+ name: webshop_opmd_reweight_adv_1.5B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///webshop_opmd_reweight_adv_1.5B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/R3L/webshop/opmd_reweight_adv_7B.yaml b/examples/R3L/webshop/opmd_reweight_adv_7B.yaml
new file mode 100644
index 0000000000..ba32c12dcb
--- /dev/null
+++ b/examples/R3L/webshop/opmd_reweight_adv_7B.yaml
@@ -0,0 +1,76 @@
+project: "WEBSHOP"
+name: "WEBSHOP_RFT_Qwen_7B_OPMD_Reweight_Adv"
+checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
+algorithm:
+ algorithm_type: opmd_reweight_adv
+ repeat_times: 8
+ optimizer:
+ lr: 1e-6
+model:
+ model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
+ max_response_tokens: 512
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 4
+buffer:
+ total_epochs: 20
+ batch_size: 96
+ explorer_input:
+ taskset:
+ name: webshop
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'train'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets:
+ - name: webshop-eval
+ storage_type: file
+ path: 'examples/R3L/webshop/webshop_data'
+ split: 'test'
+ format:
+ prompt_key: 'task_id'
+ rollout_args:
+ temperature: 0.4
+ default_workflow_type: 'opmd_baseline_webshop_workflow'
+ trainer_input:
+ experience_buffer:
+ name: webshop_opmd_reweight_adv_7B_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: true
+ priority_fn: decay_limit_randomization
+ path: 'sqlite:///webshop_opmd_reweight_adv_7B.db'
+explorer:
+ runner_per_model: 32
+ eval_interval: 20
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 0
+ gpu_memory_utilization: 0.7
+ enable_chunked_prefill: true
+data_processor:
+ experience_pipeline:
+ operators:
+ - name: "OPMD_filter"
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 12000
+trainer:
+ save_interval: 20
+ grad_clip: 1.0
+ use_dynamic_bsz: true
+ max_token_len_per_gpu: 16384
+ ulysses_sequence_parallel_size: 1
+monitor:
+ monitor_type: wandb
diff --git a/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml b/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
index 7fc2445eaf..83803ef1ba 100644
--- a/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
+++ b/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
@@ -65,7 +65,7 @@ synchronizer:
sync_style: dynamic_by_explorer
sync_method: 'nccl'
sync_interval: 4
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 100000
grad_clip: 1.0
diff --git a/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml b/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml
index 30a115cada..f37af321f0 100644
--- a/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml
+++ b/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml
@@ -56,7 +56,7 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
+ enable_prefix_caching: true
enforce_eager: false
dtype: bfloat16
gpu_memory_utilization: 0.86
@@ -65,7 +65,7 @@ synchronizer:
sync_style: dynamic_by_explorer
sync_method: 'nccl'
sync_interval: 4
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 100000
grad_clip: 1.0
diff --git a/examples/asymre_gsm8k/gsm8k.yaml b/examples/asymre_gsm8k/gsm8k.yaml
index d108ca24cd..3244db9ed8 100644
--- a/examples/asymre_gsm8k/gsm8k.yaml
+++ b/examples/asymre_gsm8k/gsm8k.yaml
@@ -57,14 +57,14 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: nccl
sync_interval: 20
- sync_timeout: 1200
+ sync_timeout: 12000
sync_offset: 0
trainer:
save_interval: 100
diff --git a/examples/asymre_math/math.yaml b/examples/asymre_math/math.yaml
index 8ad903030a..d3197f489d 100644
--- a/examples/asymre_math/math.yaml
+++ b/examples/asymre_math/math.yaml
@@ -59,8 +59,8 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1 # Each engine uses 1 GPU
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
diff --git a/examples/async_gsm8k/explorer.yaml b/examples/async_gsm8k/explorer.yaml
index 9314b5340c..e1e774ab4c 100644
--- a/examples/async_gsm8k/explorer.yaml
+++ b/examples/async_gsm8k/explorer.yaml
@@ -40,8 +40,8 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
diff --git a/examples/cispo_gsm8k/gsm8k.yaml b/examples/cispo_gsm8k/gsm8k.yaml
index aff4c50f97..dfa114dc6e 100644
--- a/examples/cispo_gsm8k/gsm8k.yaml
+++ b/examples/cispo_gsm8k/gsm8k.yaml
@@ -49,14 +49,14 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 4
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 100
grad_clip: 1.0
diff --git a/examples/dapo_math/dapo.yaml b/examples/dapo_math/dapo.yaml
index 6418c22d56..b95c8284f8 100644
--- a/examples/dapo_math/dapo.yaml
+++ b/examples/dapo_math/dapo.yaml
@@ -65,14 +65,14 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 16
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 100
grad_clip: 1.0
diff --git a/examples/dpo_human_in_the_loop/dpo.yaml b/examples/dpo_human_in_the_loop/dpo.yaml
index f13dfc539e..052684ea8b 100644
--- a/examples/dpo_human_in_the_loop/dpo.yaml
+++ b/examples/dpo_human_in_the_loop/dpo.yaml
@@ -72,7 +72,7 @@ buffer:
synchronizer:
sync_method: 'checkpoint'
sync_interval: 30
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 30
total_steps: 200
diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml
index f9948becca..abbc7c3313 100644
--- a/examples/dpo_humanlike/dpo.yaml
+++ b/examples/dpo_humanlike/dpo.yaml
@@ -37,7 +37,7 @@ buffer:
synchronizer:
sync_method: 'checkpoint'
sync_interval: 30
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 30
total_steps: 200
diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml
index 77ba65d555..c32b3ad868 100644
--- a/examples/grpo_alfworld/alfworld.yaml
+++ b/examples/grpo_alfworld/alfworld.yaml
@@ -38,8 +38,8 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 2
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
gpu_memory_utilization: 0.7
diff --git a/examples/grpo_alfworld_general_multi_step/alfworld.yaml b/examples/grpo_alfworld_general_multi_step/alfworld.yaml
index 5427b6829e..395a54aaf5 100644
--- a/examples/grpo_alfworld_general_multi_step/alfworld.yaml
+++ b/examples/grpo_alfworld_general_multi_step/alfworld.yaml
@@ -45,8 +45,8 @@ explorer:
enable_history: true
engine_num: 2
tensor_parallel_size: 2
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
gpu_memory_utilization: 0.7
diff --git a/examples/grpo_email_search/email_search.yaml b/examples/grpo_email_search/email_search.yaml
index 24a58a8c96..84ff6f2d88 100644
--- a/examples/grpo_email_search/email_search.yaml
+++ b/examples/grpo_email_search/email_search.yaml
@@ -88,8 +88,8 @@ explorer:
tool_call_parser: hermes
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
gpu_memory_utilization: 0.7
diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml
index b0640f089c..0833429b61 100644
--- a/examples/grpo_gsm8k/gsm8k.yaml
+++ b/examples/grpo_gsm8k/gsm8k.yaml
@@ -49,14 +49,14 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 1
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
trainer_type: 'verl'
save_interval: 100
diff --git a/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml b/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml
index e633f7bbc1..ce0db55b09 100644
--- a/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml
+++ b/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml
@@ -69,14 +69,14 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 1
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 100
grad_clip: 1.0
diff --git a/examples/grpo_gsm8k_ruler/gsm8k_ruler.yaml b/examples/grpo_gsm8k_ruler/gsm8k_ruler.yaml
index e0297e099b..9b765413ac 100644
--- a/examples/grpo_gsm8k_ruler/gsm8k_ruler.yaml
+++ b/examples/grpo_gsm8k_ruler/gsm8k_ruler.yaml
@@ -50,8 +50,8 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
auxiliary_models:
diff --git a/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
index e519e5c806..81ad7d1ff5 100644
--- a/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
+++ b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
@@ -67,14 +67,14 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 1
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 100
grad_clip: 1.0
diff --git a/examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml b/examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml
index 9fc071ba67..7aff5e1d8d 100644
--- a/examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml
+++ b/examples/grpo_gsm8k_trainable_ruler/gsm8k_ruler.yaml
@@ -52,8 +52,8 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
diff --git a/examples/grpo_lora_gsm8k/gsm8k.yaml b/examples/grpo_lora_gsm8k/gsm8k.yaml
index 6818d82a74..8955ad4742 100644
--- a/examples/grpo_lora_gsm8k/gsm8k.yaml
+++ b/examples/grpo_lora_gsm8k/gsm8k.yaml
@@ -52,14 +52,14 @@ explorer:
rollout_model:
engine_num: 1
tensor_parallel_size: 4
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'checkpoint'
sync_interval: 1
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
trainer_type: 'verl'
save_interval: 100
diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml
index 1ec35ce86c..6069dc74ff 100644
--- a/examples/grpo_math/math.yaml
+++ b/examples/grpo_math/math.yaml
@@ -42,14 +42,14 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 1
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 100
grad_clip: 1.0
diff --git a/examples/grpo_rubric_as_reward/rubric.yaml b/examples/grpo_rubric_as_reward/rubric.yaml
index 48e6909ba0..6b6231a6aa 100644
--- a/examples/grpo_rubric_as_reward/rubric.yaml
+++ b/examples/grpo_rubric_as_reward/rubric.yaml
@@ -44,8 +44,8 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
auxiliary_models:
diff --git a/examples/grpo_sciworld/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml
index 09bf683132..f13ccc1d5c 100644
--- a/examples/grpo_sciworld/sciworld.yaml
+++ b/examples/grpo_sciworld/sciworld.yaml
@@ -37,8 +37,8 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 2
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
gpu_memory_utilization: 0.7
@@ -46,7 +46,7 @@ explorer:
synchronizer:
sync_method: 'nccl'
sync_interval: 8
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 10
grad_clip: 1.0
diff --git a/examples/grpo_toolcall/toolace.yaml b/examples/grpo_toolcall/toolace.yaml
index 05e8a7e7e4..d72a2ff7eb 100644
--- a/examples/grpo_toolcall/toolace.yaml
+++ b/examples/grpo_toolcall/toolace.yaml
@@ -37,8 +37,8 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
diff --git a/examples/grpo_vlm/vlm.yaml b/examples/grpo_vlm/vlm.yaml
index 7dda5517b4..55214209cb 100644
--- a/examples/grpo_vlm/vlm.yaml
+++ b/examples/grpo_vlm/vlm.yaml
@@ -41,14 +41,14 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 1
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 100
grad_clip: 1.0
diff --git a/examples/grpo_webshop/webshop.yaml b/examples/grpo_webshop/webshop.yaml
index 7357002bcb..c8f7258d23 100644
--- a/examples/grpo_webshop/webshop.yaml
+++ b/examples/grpo_webshop/webshop.yaml
@@ -37,8 +37,8 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 2
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
gpu_memory_utilization: 0.7
@@ -46,7 +46,7 @@ explorer:
synchronizer:
sync_method: 'nccl'
sync_interval: 8
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 10
grad_clip: 1.0
diff --git a/examples/mix_chord/mix_chord.yaml b/examples/mix_chord/mix_chord.yaml
index 53ff7e5a8a..16f3be535b 100644
--- a/examples/mix_chord/mix_chord.yaml
+++ b/examples/mix_chord/mix_chord.yaml
@@ -73,14 +73,14 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 1
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 50
grad_clip: 1.0
diff --git a/examples/mix_chord/mix_chord_toolace.yaml b/examples/mix_chord/mix_chord_toolace.yaml
index 9380c82c36..be051f5467 100644
--- a/examples/mix_chord/mix_chord_toolace.yaml
+++ b/examples/mix_chord/mix_chord_toolace.yaml
@@ -68,14 +68,14 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 1
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 50
grad_clip: 1.0
diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml
index 10b242fb15..096e950c61 100644
--- a/examples/mix_math/mix_math.yaml
+++ b/examples/mix_math/mix_math.yaml
@@ -72,14 +72,14 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 1
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 50
grad_clip: 1.0
diff --git a/examples/mix_vlm/mix_vlm.yaml b/examples/mix_vlm/mix_vlm.yaml
index a617c9ab8d..f2a89f8100 100644
--- a/examples/mix_vlm/mix_vlm.yaml
+++ b/examples/mix_vlm/mix_vlm.yaml
@@ -78,8 +78,8 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
diff --git a/examples/opmd_gsm8k/opmd_gsm8k.yaml b/examples/opmd_gsm8k/opmd_gsm8k.yaml
index c9e2b9563e..46ebb9ae1f 100644
--- a/examples/opmd_gsm8k/opmd_gsm8k.yaml
+++ b/examples/opmd_gsm8k/opmd_gsm8k.yaml
@@ -41,14 +41,14 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 10
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 100
grad_clip: 1.0
diff --git a/examples/ppo_countdown/countdown.yaml b/examples/ppo_countdown/countdown.yaml
index a890c7e532..69bede4a00 100644
--- a/examples/ppo_countdown/countdown.yaml
+++ b/examples/ppo_countdown/countdown.yaml
@@ -40,14 +40,14 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 10
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 100
grad_clip: 1.0
diff --git a/examples/ppo_countdown_exp_replay/countdown.yaml b/examples/ppo_countdown_exp_replay/countdown.yaml
index c3871bd93d..c2b43f0a30 100644
--- a/examples/ppo_countdown_exp_replay/countdown.yaml
+++ b/examples/ppo_countdown_exp_replay/countdown.yaml
@@ -44,8 +44,8 @@ explorer:
rollout_model:
engine_num: 1 # allocate 1 GPU for explorer
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
diff --git a/examples/ppo_countdown_megatron/countdown.yaml b/examples/ppo_countdown_megatron/countdown.yaml
index 1d8ffa2b13..7910ef7b4e 100644
--- a/examples/ppo_countdown_megatron/countdown.yaml
+++ b/examples/ppo_countdown_megatron/countdown.yaml
@@ -39,14 +39,14 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 10
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 100
grad_clip: 1.0
diff --git a/examples/rec_gsm8k/gsm8k.yaml b/examples/rec_gsm8k/gsm8k.yaml
index c0dd6a3f07..122c87932d 100644
--- a/examples/rec_gsm8k/gsm8k.yaml
+++ b/examples/rec_gsm8k/gsm8k.yaml
@@ -59,14 +59,14 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: nccl
sync_interval: 20
- sync_timeout: 1200
+ sync_timeout: 12000
sync_offset: 0
trainer:
trainer_type: verl
diff --git a/examples/sppo_gsm8k/gsm8k.yaml b/examples/sppo_gsm8k/gsm8k.yaml
index 295c82116e..62473a8736 100644
--- a/examples/sppo_gsm8k/gsm8k.yaml
+++ b/examples/sppo_gsm8k/gsm8k.yaml
@@ -54,14 +54,14 @@ explorer:
rollout_model:
engine_num: 4
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: nccl
sync_interval: 20
- sync_timeout: 1200
+ sync_timeout: 12000
sync_offset: 0
trainer:
save_interval: 100
diff --git a/examples/topr_gsm8k/gsm8k.yaml b/examples/topr_gsm8k/gsm8k.yaml
index 252b5e7d7a..f04c469fb0 100644
--- a/examples/topr_gsm8k/gsm8k.yaml
+++ b/examples/topr_gsm8k/gsm8k.yaml
@@ -49,14 +49,14 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
- enable_prefix_caching: false
- enforce_eager: true
+ enable_prefix_caching: true
+ enforce_eager: false
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 8
- sync_timeout: 1200
+ sync_timeout: 12000
trainer:
save_interval: 100
grad_clip: 1.0
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index 0f085a25ea..201208e0f5 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -38,7 +38,7 @@ explorer:
rollout_model:
engine_num: 2
tensor_parallel_size: 1
- enable_prefix_caching: false
+ enable_prefix_caching: true
enforce_eager: true
dtype: bfloat16
seed: 42
diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py
index f8349062f1..3a30583a69 100644
--- a/trinity/algorithm/advantage_fn/__init__.py
+++ b/trinity/algorithm/advantage_fn/__init__.py
@@ -14,6 +14,7 @@
from trinity.algorithm.advantage_fn.opmd_advantage import (
OPMDAdvantageFn,
OPMDGroupAdvantage,
+ OPMDReweightAdvGroupAdvantage,
)
from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn
from trinity.algorithm.advantage_fn.rec_advantage import RECGroupedAdvantage
@@ -37,6 +38,7 @@
"RLOOAdvantageFn",
"OPMDAdvantageFn",
"OPMDGroupAdvantage",
+ "OPMDReweightAdvGroupAdvantage",
"REINFORCEGroupAdvantage",
"ASYMREAdvantageFn",
"RECGroupedAdvantage",
diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py
index 82bea6c90d..bfdbe6f2ad 100644
--- a/trinity/algorithm/advantage_fn/opmd_advantage.py
+++ b/trinity/algorithm/advantage_fn/opmd_advantage.py
@@ -123,6 +123,7 @@ def calculate_group_advantage(
reward_mean = torch.mean(group_rewards)
if len(exps) == 1:
group_baseline = torch.tensor(0.0)
+ group_rewards = torch.tensor([exps[0].reward], dtype=torch.float32)
else:
if self.opmd_baseline == "mean":
group_baseline = reward_mean
@@ -144,3 +145,49 @@ def calculate_group_advantage(
@classmethod
def default_args(cls) -> dict:
return {"opmd_baseline": "mean", "tau": 1.0}
+
+
+@ADVANTAGE_FN.register_module("opmd_reweight_adv")
+class OPMDReweightAdvGroupAdvantage(GroupAdvantage):
+ """OPMD Group Advantage computation with reweighting"""
+
+ def __init__(self, opmd_baseline: str = "mean", tau: float = 1.0, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.opmd_baseline = opmd_baseline
+ self.tau = tau
+
+ def group_experiences(self, exps):
+ return group_by(exps, id_type="task")
+
+ def calculate_group_advantage(
+ self, group_id: str, exps: List[Experience]
+ ) -> Tuple[List[Experience], Dict]:
+ with torch.no_grad():
+ if len(exps) == 1:
+ group_baseline = torch.tensor(0.0)
+ group_rewards = torch.tensor([exps[0].reward], dtype=torch.float32)
+ else:
+ group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)
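+ # Baseline: either the group mean reward or a temperature-scaled
+ # logsumexp ("soft-max") of the group rewards, selected via opmd_baseline.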
+ if self.opmd_baseline == "mean":
+ group_baseline = torch.mean(group_rewards)
+ else:
+ group_baseline = self.tau * (
+ torch.logsumexp(group_rewards / self.tau, dim=-1)
+ - torch.log(torch.tensor(len(exps)))
+ )
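+ # Reweighting rule specific to this variant: the score of a fully successful
+ # rollout (reward >= 1.0) is clamped to 1.0, then any non-negative score is
+ # scaled by 3 before being broadcast over the action mask.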
+ for exp in exps:
+ score = exp.reward - group_baseline
+ if exp.reward >= 1.0:
+ score = 1.0
+ if score >= 0:
+ score = score * 3
+ exp.advantages = score * exp.action_mask
+ exp.returns = exp.advantages.clone()
+ metrics = {
+ "group_baseline": group_baseline.item(),
+ "reward_mean": torch.mean(group_rewards).item(),
+ }
+ return exps, metrics
+
+ @classmethod
+ def default_args(cls) -> dict:
+ return {"opmd_baseline": "mean", "tau": 1.0}
\ No newline at end of file
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index 4384da7b8e..9b950913b6 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -136,6 +136,28 @@ def default_config(cls) -> Dict:
"entropy_loss_fn": "default",
}
+@ALGORITHM_TYPE.register_module("opmd_reweight_adv")
+class OPMDReweightAdvAlgorithm(AlgorithmType):
+ """OPMD with reweighting advantage algorithm."""
+
+ use_critic: bool = False
+ use_reference: bool = True
+ compute_advantage_in_trainer: bool = False
+ can_balance_batch: bool = True
+ schema: type = "experience"
+
+ @classmethod
+ def default_config(cls) -> Dict:
+ return {
+ "repeat_times": 2,
+ "advantage_fn": "opmd_reweight_adv",
+ "sample_strategy": "default",
+ "policy_loss_fn": "opmd",
+ "kl_penalty_fn": "none",
+ "kl_loss_fn": "k2",
+ "entropy_loss_fn": "default",
+ }
@ALGORITHM_TYPE.register_module("asymre")
class AsymREAlgorithm(AlgorithmType):
diff --git a/trinity/buffer/operators/__init__.py b/trinity/buffer/operators/__init__.py
index 4153c049b2..de1eb41969 100644
--- a/trinity/buffer/operators/__init__.py
+++ b/trinity/buffer/operators/__init__.py
@@ -6,6 +6,8 @@
from trinity.buffer.operators.filters.reward_filter import RewardFilter, RewardSTDFilter
from trinity.buffer.operators.mappers.pass_rate_calculator import PassRateCalculator
from trinity.buffer.operators.mappers.reward_shaping_mapper import RewardShapingMapper
+from trinity.buffer.operators.filters.RAFT_filter import RAFTFilter
+from trinity.buffer.operators.filters.OPMD_filter import OPMDFilter
__all__ = [
"ExperienceOperator",
@@ -15,4 +17,6 @@
"RewardShapingMapper",
"PassRateCalculator",
"DataJuicerOperator",
+ "RAFTFilter",
+ "OPMDFilter",
]
diff --git a/trinity/buffer/operators/filters/OPMD_filter.py b/trinity/buffer/operators/filters/OPMD_filter.py
new file mode 100644
index 0000000000..e3b7f683aa
--- /dev/null
+++ b/trinity/buffer/operators/filters/OPMD_filter.py
@@ -0,0 +1,19 @@
+from typing import List, Tuple
+
+from trinity.buffer.operators import EXPERIENCE_OPERATORS, ExperienceOperator
+from trinity.common.experience import Experience
+
+
+@EXPERIENCE_OPERATORS.register_module("OPMD_filter")
+class OPMDFilter(ExperienceOperator):
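+ """Drop experiences that are unusable downstream (missing reward or degenerate prompt)."""
+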
+ def __init__(self):
+ pass
+
+ def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]:
+ # Filter out invalid experiences (no reward or degenerate prompt)
+ filtered_exps = [exp for exp in exps if exp.reward is not None and exp.prompt_length > 2]
+ metrics = {"filtered_count": len(exps) - len(filtered_exps)}
+ return filtered_exps, metrics
+
diff --git a/trinity/buffer/operators/filters/RAFT_filter.py b/trinity/buffer/operators/filters/RAFT_filter.py
new file mode 100644
index 0000000000..b71eb9ef9f
--- /dev/null
+++ b/trinity/buffer/operators/filters/RAFT_filter.py
@@ -0,0 +1,20 @@
+from typing import List, Tuple
+
+from trinity.buffer.operators import EXPERIENCE_OPERATORS, ExperienceOperator
+from trinity.common.experience import Experience
+
+
+@EXPERIENCE_OPERATORS.register_module("RAFT_filter")
+class RAFTFilter(ExperienceOperator):
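+ """Keep only experiences whose reward clears the RAFT acceptance threshold."""
+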
+ def __init__(self):
+ pass
+
+ def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]:
+ # Keep only experiences whose reward is set and clears the acceptance threshold
+ # (0.3 here; use 1.0 to keep strictly successful rollouts only).
+ filtered_exps = [exp for exp in exps if exp.reward is not None and exp.reward >= 0.3]
+ metrics = {"filtered_count": len(exps) - len(filtered_exps)}
+ return filtered_exps, metrics
+
diff --git a/trinity/common/config.py b/trinity/common/config.py
index c722959b96..4d1e338f55 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -175,7 +175,7 @@ class StorageConfig:
task_selector: TaskSelectorConfig = field(default_factory=TaskSelectorConfig)
# enable progress bar (tqdm) for _HFBatchReader
- enable_progress_bar: Optional[bool] = False
+ enable_progress_bar: Optional[bool] = True
# get storage from existing experiment
ray_namespace: Optional[str] = None
@@ -224,7 +224,7 @@ class TasksetConfig:
max_retry_times: int = 3
max_retry_interval: int = 1
- enable_progress_bar: bool = False
+ enable_progress_bar: bool = True
# ! DO NOT SET, automatically load from checkpoint
index: int = 0
@@ -287,7 +287,7 @@ class ExperienceBufferConfig:
split: str = "train"
subset_name: Optional[str] = None
format: FormatConfig = field(default_factory=FormatConfig)
- enable_progress_bar: Optional[bool] = False
+ enable_progress_bar: Optional[bool] = True
# ! DO NOT SET, automatically set
schema_type: Optional[str] = None
diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index 1277f041f9..ed43ada567 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -29,8 +29,75 @@
RAFTReflectAlfworldWorkflow,
)
from trinity.common.workflows.envs.email_searcher.workflow import EmailSearchWorkflow
+from trinity.common.workflows.envs.R3L.alfworld.grpo_workflow import (
+ GRPOBaselineAlfworldWorkflow,
+)
+from trinity.common.workflows.envs.R3L.alfworld.opmd_workflow import (
+ OPMDBaselineAlfworldWorkflow,
+)
+
+# Alfworld R3L workflows
+from trinity.common.workflows.envs.R3L.alfworld.dapo_workflow import (
+ DAPOAlfworldWorkflow,
+)
+from trinity.common.workflows.envs.R3L.alfworld.R3L_workflow import R3LAlfworldWorkflow
+from trinity.common.workflows.envs.R3L.alfworld.raft_workflow import (
+ RAFTBaselineAlfworldWorkflow,
+)
+from trinity.common.workflows.envs.R3L.countdown.grpo_workflow import (
+ GRPOBaselineCountdownWorkflow,
+)
+from trinity.common.workflows.envs.R3L.countdown.opmd_workflow import (
+ OPMDBaselineCountdownWorkflow,
+)
+
+# Countdown R3L workflows
+from trinity.common.workflows.envs.R3L.countdown.R3L_workflow import (
+ R3LCountdownWorkflow,
+)
+from trinity.common.workflows.envs.R3L.countdown.raft_workflow import (
+ RAFTBaselineCountdownWorkflow,
+)
+from trinity.common.workflows.envs.R3L.dapo.grpo_workflow import (
+ GRPOBaselineDapoWorkflow,
+)
+from trinity.common.workflows.envs.R3L.dapo.opmd_workflow import (
+ OPMDBaselineDapoWorkflow,
+)
+
+# DAPO R3L workflows
+from trinity.common.workflows.envs.R3L.dapo.R3L_workflow import R3LDapoWorkflow
+from trinity.common.workflows.envs.R3L.dapo.raft_workflow import (
+ RAFTBaselineDapoWorkflow,
+)
+from trinity.common.workflows.envs.R3L.scienceworld.grpo_workflow import (
+ GRPOBaselineScienceWorldWorkflow,
+)
+from trinity.common.workflows.envs.R3L.scienceworld.opmd_workflow import (
+ OPMDBaselineScienceWorldWorkflow,
+)
+
+# ScienceWorld R3L workflows
+from trinity.common.workflows.envs.R3L.scienceworld.R3L_workflow import (
+ R3LScienceWorldWorkflow,
+)
+from trinity.common.workflows.envs.R3L.scienceworld.raft_workflow import (
+ RAFTBaselineScienceWorldWorkflow,
+)
+from trinity.common.workflows.envs.R3L.webshop.grpo_workflow import (
+ GRPOBaselineWebshopWorkflow,
+)
+from trinity.common.workflows.envs.R3L.webshop.opmd_workflow import (
+ OPMDBaselineWebshopWorkflow,
+)
+from trinity.common.workflows.envs.R3L.webshop.R3L_workflow import R3LWebshopWorkflow
+from trinity.common.workflows.envs.R3L.webshop.raft_workflow import (
+ RAFTBaselineWebshopWorkflow,
+)
from trinity.common.workflows.envs.sciworld.sciworld_workflow import SciWorldWorkflow
from trinity.common.workflows.envs.webshop.webshop_workflow import WebShopWorkflow
from trinity.common.workflows.eval_workflow import (
AsyncMathEvalWorkflow,
MathEvalWorkflow,
@@ -70,6 +137,32 @@
"AsyncMathWorkflow",
"MathWorkflow",
"WebShopWorkflow",
+ "R3LWebshopWorkflow",
+ "GRPOBaselineWebshopWorkflow",
+ "OPMDBaselineWebshopWorkflow",
+ "RAFTBaselineWebshopWorkflow",
+ # Alfworld R3L workflows
+ "R3LAlfworldWorkflow",
+ "GRPOBaselineAlfworldWorkflow",
+ "OPMDBaselineAlfworldWorkflow",
+ "RAFTBaselineAlfworldWorkflow",
+ "DAPOAlfworldWorkflow",
+ # DAPO R3L workflows
+ "R3LDapoWorkflow",
+ "GRPOBaselineDapoWorkflow",
+ "OPMDBaselineDapoWorkflow",
+ "RAFTBaselineDapoWorkflow",
+ # ScienceWorld R3L workflows
+ "R3LScienceWorldWorkflow",
+ "GRPOBaselineScienceWorldWorkflow",
+ "OPMDBaselineScienceWorldWorkflow",
+ "RAFTBaselineScienceWorldWorkflow",
+ # Countdown R3L workflows
+ "R3LCountdownWorkflow",
+ "GRPOBaselineCountdownWorkflow",
+ "OPMDBaselineCountdownWorkflow",
+ "RAFTBaselineCountdownWorkflow",
+ # Original workflows
"AlfworldWorkflow",
"StepWiseAlfworldWorkflow",
"RAFTAlfworldWorkflow",
diff --git a/trinity/common/workflows/envs/R3L-back/__init__.py b/trinity/common/workflows/envs/R3L-back/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld-bak/R3L_workflow.py b/trinity/common/workflows/envs/R3L-back/alfworld-bak/R3L_workflow.py
new file mode 100644
index 0000000000..a73023f78f
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld-bak/R3L_workflow.py
@@ -0,0 +1,376 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("R3L_alfworld_workflow")
+class R3LAlfworldWorkflow(Workflow):
+ """
+ R3L workflow for alfworld
+ """
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 50
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+
+ self.whether_save_data = False
+ # Create data directories
+ self.data_dir = "R3L_alfworld_data"
+ self.eval_dir = os.path.join(self.data_dir, "eval")
+ self.train_dir = os.path.join(self.data_dir, "train")
+
+ os.makedirs(self.eval_dir, exist_ok=True)
+ os.makedirs(self.train_dir, exist_ok=True)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+ self.reflection_template = self.jinja_env.get_template("reflection.j2")
+
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ self.default_second_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "second_success": 0.0,
+ "second_reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ print(
+ f"Initializing R3LAlfworldWorkflow with experience learning, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def get_reflect(self, trajectory: List[Dict[str, str]]) -> tuple[
+ Optional[Dict[str, Any]], Optional[str], Optional[Any]]:
+ """
+ Generates a comprehensive reflection report using a single, unified self-interrogation prompt.
+ The model first assesses its own performance and then follows the appropriate reflection path.
+ """
+ # Format trajectory for LLM reading
+ formatted_trajectory = utils.format_trajectory_for_reflection(trajectory)
+
+ # Use Jinja2 template to render reflection prompt
+ reflect_prompt = self.reflection_template.render()
+
+ # Call model and parse results
+ try:
+ responses = self.model.chat(
+ [
+ {"role": "system", "content": reflect_prompt},
+ {"role": "user", "content": "Here is last attempt trajectory log: \n\n" + formatted_trajectory}
+ ],
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_reflect_tokens,
+ )
+ reflection_text = responses[0].response_text.strip()
+
+ # Parse JSON
+ json_match = re.search(r"```json\s*(\{.*?\})\s*```", reflection_text, re.DOTALL)
+ if json_match:
+ json_str = json_match.group(1)
+ else:
+ json_str = reflection_text
+
+ reflection_data = json.loads(json_str)
+ return reflection_data, reflection_text, responses[0]
+
+ except Exception as e:
+ print(f"[R3L] Failed to generate or parse reflection: {e}")
+ return None, None, None
+
+ def _adjust_action_mask_for_retry(self, experience: Experience, retry_step: int):
+ """
+ Adjust action_mask in-place to exclude retry prefix from training.
+ Only tokens from retry_step onwards should be trained.
+
+ Args:
+ experience: The experience object with action_mask to adjust
+ retry_step: The step from which training should start
+ """
+ if retry_step <= 0:
+ return
+
+ # Note: experience.action_mask already excludes prompt tokens
+ action_mask = experience.action_mask
+
+ # Find all assistant response regions and mark the first 'retry_step' as non-trainable
+ if torch.any(action_mask == 1):
+ # Find all segments where action_mask == 1 (assistant responses)
+ assistant_segments = []
+ in_segment = False
+ segment_start = 0
+
+ for i, mask_val in enumerate(action_mask):
+ if mask_val == 1 and not in_segment:
+ # Start of a new segment
+ segment_start = i
+ in_segment = True
+ elif mask_val == 0 and in_segment:
+ # End of current segment
+ assistant_segments.append((segment_start, i))
+ in_segment = False
+
+ # Handle case where sequence ends with assistant response
+ if in_segment:
+ assistant_segments.append((segment_start, len(action_mask)))
+
+ # Set the first 'retry_step' assistant segments to 0 (non-trainable)
+ for i in range(min(retry_step, len(assistant_segments))):
+ start, end = assistant_segments[i]
+ action_mask[start:end] = 0
+
+ def run(self) -> List[Experience]:
+ """Run the R3L alfworld workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Generate unique task ID
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+
+ env = utils.create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ exp_lst = []
+ for i in range(self.n // 2): # Half for rollout, half for reflection + retry
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[R3L] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ # Set eid
+ exp.eid.task = str(self.task.task_id) + "_explore"
+ exp_run_id = len(exp_lst) + self.run_id_base
+ exp.eid.run = exp_run_id
+ exp_lst.append(exp)
+
+ if self.whether_save_data:
+ # Save first attempt experience data
+ first_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="first"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_first",
+ experience_data=first_record,
+ data_dir=self.train_dir
+ )
+
+ # Reflect on first attempt
+ reflect_checklist, reflection_text, reflect_exp = self.get_reflect(trajectory)
+ is_valid, is_perfect = utils.validate_reflect_report(reflect_checklist, steps)
+
+ if not is_valid or is_perfect:
+ # If first attempt reward is 1.0 and reflection gives perfect, record reflection exp
+ if reward >= 1.0 and is_perfect and reflect_exp is not None:
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ if not is_valid:
+ # Do another rollout to ensure the batch has enough data
+ try:
+ retry_env = utils.create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ retry_trajectory, retry_reward, retry_done, retry_steps, retry_format_valid = utils.first_rollout(
+ self, retry_env
+ )
+
+ retry_exp = self.model.convert_messages_to_experience(retry_trajectory[:-1])
+ retry_exp.reward = retry_reward
+ retry_exp.metrics = {
+ "success": 1.0 if retry_reward >= 1.0 else 0.0,
+ "steps": retry_steps,
+ "reward": retry_reward,
+ }
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + "_explore"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ if self.whether_save_data:
+ # Save retry attempt experience data
+ retry_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=retry_trajectory,
+ reward=retry_reward,
+ steps=retry_steps,
+ success=retry_reward >= 1.0,
+ attempt_type="retry_after_invalid_reflection"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_retry",
+ experience_data=retry_record,
+ data_dir=self.train_dir
+ )
+ except Exception as e:
+ print(f"Retry rollout after invalid reflection failed: {e}")
+
+ else:
+ guidance_prompt = utils.reflect_report_to_guidance_prompt(reflect_checklist)
+ # Extract retry_step from validated reflection report
+ retry_step = reflect_checklist["analysis"]["retry_strategy"]["retry_step"]
+
+ try:
+ second_env = utils.create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ (
+ distill_trajectory,
+ second_trajectory,
+ second_reward,
+ second_done,
+ second_steps,
+ second_format_valid,
+ ) = utils.second_rollout(
+ self, second_env, guidance_prompt, trajectory, retry_step
+ )
+ print(f"[R3L] Second rollout - reward: {second_reward}, steps: {second_steps}, improve: {second_reward > reward}")
+ second_exp = self.model.convert_messages_to_experience(distill_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(second_exp, retry_step)
+ # Also adjust first rollout exp for fair comparison
+ for existing_exp in exp_lst:
+ if existing_exp.eid.run == exp_run_id:
+ self._adjust_action_mask_for_retry(existing_exp, retry_step)
+ break
+
+ second_exp.reward = second_reward
+ second_exp.metrics = {
+ "second_success": 1.0 if second_reward >= 1.0 else 0.0,
+ "second_steps": second_steps,
+ "second_reward": second_reward,
+ "second_improve": 1.0 if second_reward > reward else 0.0,
+ "second_reward_diff": second_reward - reward,
+ }
+ # Set eid
+ second_exp.eid.task = str(self.task.task_id) + f"_explore"
+ second_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(second_exp)
+
+ if self.whether_save_data:
+ # Save second attempt experience data
+ second_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=second_trajectory,
+ reward=second_reward,
+ steps=second_steps,
+ success=second_reward >= 1.0,
+ attempt_type="second",
+ additional_metrics={
+ "first_reward": reward,
+ "improvement": second_reward > reward,
+ "reward_difference": second_reward - reward,
+ "step_difference": second_steps - steps
+ }
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_second",
+ experience_data=second_record,
+ data_dir=self.train_dir
+ )
+
+ # If second attempt score is higher than first, or second is perfect with fewer steps,
+ # record reflection and retry data
+ if (second_reward > reward and second_reward >= 1.0) or (second_reward >= 1.0 and second_steps < steps):
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Convert retry data to exp
+ retry_exp = self.model.convert_messages_to_experience(second_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(retry_exp, retry_step)
+
+ retry_exp.reward = 1.0
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + f"_retry_{i}"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ print("Reflection and retry led to improvement, recording both...")
+                    except Exception as e:
+                        print(f"[R3L] Guided second rollout failed: {e}")
+            except Exception as e:
+                print(f"[R3L] Rollout {i} failed: {e}")
+
+ return exp_lst
+
+ def resettable(self) -> bool:
+ """Indicate that this workflow can be reset to avoid re-initialization"""
+ return True
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld-bak/__init__.py b/trinity/common/workflows/envs/R3L-back/alfworld-bak/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld-bak/grpo_workflow.py b/trinity/common/workflows/envs/R3L-back/alfworld-bak/grpo_workflow.py
new file mode 100644
index 0000000000..2de739c0aa
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld-bak/grpo_workflow.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("grpo_baseline_alfworld_workflow")
+class GRPOBaselineAlfworldWorkflow(Workflow):
+ """
+ GRPO Baseline Workflow for Alfworld environment.
+ Performs simple rollouts without reflection or learning from experience.
+ """
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 50
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # History length for sliding window (verl-agent uses 2)
+ self.history_length = 2
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+ f"Initializing GRPOBaselineAlfworldWorkflow, temperature={self.temperature}, history_length={self.history_length}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the GRPO baseline workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Single rollout execution
+ env = utils.create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ # print(f"trajectory: {trajectory}")
+ print(f"[GRPO] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception as e:
+ print(f"[GRPO] Rollout {i} failed with exception: {e}")
+ # exp = Experience(
+ # tokens=torch.tensor([0, 0], dtype=torch.long),
+ # prompt_length=1,
+ # action_mask=torch.tensor([False], dtype=torch.bool),
+ # logprobs=torch.tensor([0.0], dtype=torch.float),
+ # metrics={
+ # "success": 0.0,
+ # "reward": 0.0,
+ # }
+ # )
+ # exp.reward = 0.0
+ return exp_lst
+
+ def resettable(self) -> bool:
+ """Indicate that this workflow can be reset to avoid re-initialization"""
+ return True
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld-bak/opmd_workflow.py b/trinity/common/workflows/envs/R3L-back/alfworld-bak/opmd_workflow.py
new file mode 100644
index 0000000000..9edb8a7573
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld-bak/opmd_workflow.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("opmd_baseline_alfworld_workflow")
+class OPMDBaselineAlfworldWorkflow(Workflow):
+ """
+ OPMD Baseline workflow for Alfworld environment.
+    Performs plain rollouts for the OPMD baseline.
+ """
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 50
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+            f"Initializing OPMDBaselineAlfworldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the OPMD workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Single rollout execution
+ env = utils.create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[OPMD] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+            except Exception as e:
+                print(f"[OPMD] Rollout {i} failed with exception: {e}")
+
+ return exp_lst
+
+ def resettable(self) -> bool:
+ """Indicate that this workflow can be reset to avoid re-initialization"""
+ return True
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld-bak/prompts/alfworld_system.j2 b/trinity/common/workflows/envs/R3L-back/alfworld-bak/prompts/alfworld_system.j2
new file mode 100644
index 0000000000..b0879979b1
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld-bak/prompts/alfworld_system.j2
@@ -0,0 +1,5 @@
+You are an expert agent operating in the ALFRED Embodied Environment.
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for the current step and present it within <action> </action> tags.
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld-bak/prompts/reflection.j2 b/trinity/common/workflows/envs/R3L-back/alfworld-bak/prompts/reflection.j2
new file mode 100644
index 0000000000..87987361b4
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld-bak/prompts/reflection.j2
@@ -0,0 +1,133 @@
+# Metacognitive Analyst AI Prompt
+
+You are a Metacognitive Analyst AI. Your core mission is to analyze a "Trajectory Log" containing a series of thoughts and actions. Your goal is to extract deep insights, summarize lessons learned, and formulate actionable principles for future improvement.
+
+You will receive a trajectory log. Your final output must be a structurally complete JSON object.
+
+## Your Internal Monologue & Analysis Protocol (MANDATORY)
+
+You will now begin your structured self-interrogation. Your analysis process must first review the trajectory globally before focusing on key points.
+
+### Part 1: Global Review & Analysis
+
+First, you must understand the entire trajectory from a macro perspective, especially the feedback from the user and the environment.
+
+**Question 1.1: Conduct a Panoramic Trajectory Analysis**
+Read through the entire trajectory log and summarize in one or two sentences what the overall strategy was and what result it ultimately led to.
+
+**Question 1.2: Identify Key Issues**
+Based on your global understanding, identify the main problems or inefficiencies in the trajectory. What were the key mistakes or missed opportunities? If the execution was flawless, answer None.
+
+### Part 2: Deep Analysis of Key Issues
+
+Next, you will conduct this deep analysis if and only if key issues were identified in Part 1.
+
+**Question 2.1: Diagnose the Primary Flaw**
+What was the fundamental nature of the primary flaw? Categorize it into ONE of the following:
+- Strategy Flaw: The overall plan was misguided.
+- Reasoning Flaw: The interpretation of information was incorrect.
+- Execution Flaw: The intent was correct, but the resulting action was clumsy or ineffective.
+- Knowledge Gap: Lacked critical information necessary to solve the problem.
+- Inefficiency: The goal was achieved, but via a redundant or convoluted path.
+- Invalid Format: The response was syntactically incorrect or violated protocol.
+
+**Question 2.2: Uncover the Root Cause**
+Conduct a flexible root cause inquiry to uncover the core flawed assumption or problematic mental model that led to the flaw. Continuously ask "Why?" until the most fundamental cause is revealed.
+
+**Question 2.3: Formulate Better Approach**
+What would have been the optimal overall strategy or approach for this task?
+What series of positive effects would likely have followed from using this better approach?
+
+### Part 3: Synthesis, Verdict, and Lessons Learned
+
+Finally, after completing all the above analysis, you will synthesize your findings and render a final judgment.
+
+**Question 3.1: Formulate a Corrective Principle**
+
+Based on the root cause and better approach identified above, formulate an impactful Corrective Principle.
+
+**CRITICAL REQUIREMENTS for Principle Formulation:**
+
+1. **Context Completeness**: The principle must be self-contained and include ALL necessary context. It should be understandable and applicable without requiring external knowledge of the specific trajectory.
+ - ❌ **BAD**: "Click operations tend to cause failures"
+ - ✅ **GOOD**: "In the xxx environment, when click operations are not available in the action space, attempting to execute click will cause failures"
+
+2. **Domain Specificity**: Clearly specify the environment, system, or context where this principle applies.
+ - Include environment name
+ - Include relevant constraints or conditions
+
+3. **Causal Chain Awareness**: The principle should consider not just the immediate impact but also downstream consequences.
+ - Consider how the corrective action affects subsequent steps
+ - Anticipate potential cascading effects
+
+4. **Actionable Structure**: The principle should be actionable and clear, typically including:
+ - The specific environment or context
+ - Clear trigger conditions or situations
+ - The recommended action or approach
+ - The reasoning and expected benefits
+
+ **Note**: The exact format can vary based on the nature of the insight. It could be a prescriptive rule ("When X, do Y"), a cautionary guideline ("Avoid X in situation Y"), or a strategic insight ("Prioritize X because Y"). Choose the format that best captures the lesson learned.
+
+5. **Independence Test**: The principle should be meaningful and correct even if read in isolation, without access to the original trajectory.
+
+**Question 3.2: Render the Final Verdict**
+
+Now, and only now, based on your complete analysis, classify the outcome of this task into one of the following more precise categories:
+
+- **OPTIMAL**: Flawlessly and efficiently achieved the goal; a textbook execution.
+- **SUBOPTIMAL_SUCCESS**: Achieved the goal, but with correctable inefficiencies or minor flaws.
+- **PARTIAL**: Made significant progress but did not fully meet the final goal.
+- **INEFFECTIVE**: Fully failed to achieve the primary goal.
+
+## Final Output Format (Strictly Adhere to the Unified Schema)
+
+Your final output must strictly contain the following two parts: Part One is your detailed analysis process (in text form), and Part Two is the summary JSON report.
+
+### Part One: Detailed Analysis Process
+
+You must answer all questions from the protocol one by one here, showing your complete chain of thought.
+
+**1. Global Review & Analysis**
+- 1.1 Panoramic Trajectory Analysis: Fill in your macro summary of the trajectory here
+- 1.2 Key Issues Identification: Fill in the identified key issues and the reasoning here
+
+**2. Deep Analysis of Key Issues**
+- 2.1 Primary Flaw Diagnosis: Fill in the flaw's classification here
+- 2.2 Root Cause: Fill in the result of the root cause inquiry here
+- 2.3 Better Approach: Fill in the analysis of the optimal strategy and its expected benefits here
+
+**3. Synthesis, Verdict, and Lessons Learned**
+- 3.1 Corrective Principle: Fill in the corrective principle you formulated here (MUST meet all 5 critical requirements)
+- 3.2 Final Verdict: Fill in the final classification verdict you rendered here
+
+### Part Two: Final JSON Report
+
+After completing the detailed analysis above, synthesize all conclusions and populate the following JSON structure (```json is mandatory as JSON prefix):
+
+```json
+{
+ "outcome_assessment": "OPTIMAL | SUBOPTIMAL_SUCCESS | PARTIAL | INEFFECTIVE",
+ "analysis": {
+ "summary": "Summary of the trajectory's strategy, outcome, and core insight.",
+ "flaw_analysis": {
+ "diagnosis": {
+ "category": "Strategy Flaw | Reasoning Flaw | Execution Flaw | Knowledge Gap | Inefficiency | null",
+ "root_cause": "The core flawed assumption or problematic mental model that was uncovered. Can be null if outcome_assessment is OPTIMAL."
+ },
+ "better_approach": {
+ "strategy": "The optimal overall strategy or approach that should have been used. Can be null if outcome_assessment is OPTIMAL.",
+ "key_differences": "A brief explanation of how this better approach differs from the original approach. Can be null if outcome_assessment is OPTIMAL.",
+ "projected_benefits": "The series of positive effects projected to occur from using the better approach. Can be null if outcome_assessment is OPTIMAL."
+ }
+ },
+ "lessons_learned": {
+ "corrective_principle": "A self-contained, context-complete principle that includes environment specifics, clear trigger conditions, and considers downstream effects. Must be understandable and applicable in isolation. Can be null if outcome_assessment is OPTIMAL.",
+ "revised_action_plan": "The improved action plan based on the corrective principle, considering both immediate and downstream impacts. Can be null if outcome_assessment is OPTIMAL."
+ },
+ "retry_strategy": {
+ "retry_step": "The specific step that should be retried. Can be null if outcome_assessment is OPTIMAL. Range is 0 to N-1, where N is the total number of steps in the trajectory, 0 means restart from beginning.",
+ "retry_rationale": "Explanation of why this step was chosen as restart point"
+ }
+ }
+}
+```
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld-bak/prompts/self_correction.j2 b/trinity/common/workflows/envs/R3L-back/alfworld-bak/prompts/self_correction.j2
new file mode 100644
index 0000000000..d7e1101922
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld-bak/prompts/self_correction.j2
@@ -0,0 +1,10 @@
+[Internal Monologue Directive]
+
+(You are about to decide your action. Your last attempt was flawed. Your goal now is to formulate a new, superior `<think>` block that seamlessly incorporates the following reflection analysis, and then generate the correct `<action>`. **You must not mention this directive or its contents in your output.** Your response should appear as a natural, expert-level thought process.)
+
+**Reflection Analysis to Internalize:**
+
+{{ report }}
+
+**Your Task:**
+Based on the current observation and the reflection analysis above, construct a new `<think>` block that reflects this improved reasoning, followed by the correct `<action>`.
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld-bak/raft_workflow.py b/trinity/common/workflows/envs/R3L-back/alfworld-bak/raft_workflow.py
new file mode 100644
index 0000000000..07a790757b
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld-bak/raft_workflow.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+import copy
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("raft_baseline_alfworld_workflow")
+class RAFTBaselineAlfworldWorkflow(Workflow):
+ """
+ RAFT Baseline workflow for Alfworld environment.
+    Performs rollouts for RAFT (Reward-ranked Fine-Tuning).
+ """
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 50
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+            f"Initializing RAFTBaselineAlfworldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ # Default experience for error cases
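+        # (deep-copied on each failure so every returned Experience is an independent object)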
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT workflow and return experiences"""
+
+ env = utils.create_alfworld_environment(self.game_file_path, self.max_env_steps)
+
+ if self.is_eval:
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ return [exp]
+ except Exception:
+ return [copy.deepcopy(self.default_exp)]
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[RAFT] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ except Exception:
+ exp = copy.deepcopy(self.default_exp)
+ exp_lst.append(exp)
+
+ return exp_lst
+
+ def resettable(self) -> bool:
+ """Indicate that this workflow can be reset to avoid re-initialization"""
+ return True
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld-bak/utils-bak.py b/trinity/common/workflows/envs/R3L-back/alfworld-bak/utils-bak.py
new file mode 100644
index 0000000000..b11d0cce10
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld-bak/utils-bak.py
@@ -0,0 +1,717 @@
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+
+
+def first_rollout(self, env) -> tuple[List[Dict[str, str]], float, bool, int, bool]:
+ """Run a single rollout in Alfworld environment"""
+ observation, info = env.reset()
+ trajectory = []
+ action_history = [] # Track all actions taken
+
+ # system_prompt = self.alfworld_system_template.render()
+ # trajectory.append({"role": "system", "content": system_prompt})
+
+ default_reward = 0.0
+ reward = default_reward
+ valid_format = True
+ step = 0
+
+ # Extract task description from the initial observation (only once at the beginning)
+ task_description = extract_task_description(observation)
+
+ for step in range(self.max_env_steps):
+ # Extract admissible actions from info if available
+ admissible_actions = info.get("admissible_commands", []) if isinstance(info, dict) else []
+
+ trajectory.append(
+ {
+ "role": "user",
+ "content": format_observation(
+ current_observation=observation,
+ task_description=task_description,
+ current_step=step,
+ action_history=action_history,
+ admissible_actions=admissible_actions,
+ ),
+ }
+ )
+
+ # Get model response
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+            # The tokenizer inside chat truncates the input, so terminate as soon as the token count reaches the length limit
+            # print(f"[first_rollout] Token count {responses[0].tokens.shape[0]} reaches the limit at step {step}")
+ return trajectory, reward, False, step, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response components
+ think, action = parse_response(response_text)
+ if think is None or action is None:
+ valid_format = False
+ # Only add format feedback if not in the last step (to avoid duplicate feedback)
+ if step < self.max_env_steps - 1:
+                feedback = "Invalid response format, missing valid <think> or <action> tags, please ensure to follow the output format strictly: <think>...</think> <action>...</action>"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ continue
+
+ # Execute action in environment
+ observation, reward, done, info = env.step(action)
+ # print(f"[first_rollout] Step {step}: Action taken: {action}, Reward: {reward}, Done: {done}, Observation: {observation}, Info: {info.get('admissible_commands')}")
+
+ # Track successfully executed actions for history (only if action was valid)
+ if "Nothing happens." not in observation:
+ action_history.append(action)
+
+ if done:
+ break
+
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ else:
+ feedback = f"Task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ # print(f"[first_rollout] reward: {reward}, steps: {step + 1}, valid_format: {valid_format}")
+ return trajectory, reward, False, step + 1, valid_format
+
+
+def second_rollout(
+ self,
+ env,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, int, bool]:
+ """
+ Performs rollout starting from a specific retry step, reusing previous responses.
+
+ Args:
+ env: The environment instance.
+ guidance_prompt: The pre-generated guidance from reflection.
+ first_trajectory: The full log of the initial attempt.
+ retry_step: The step to start retry from (0-based, 0 means from beginning).
+
+ Returns:
+ A tuple containing (distill_trajectory, second_trajectory, reward, done status,
+ step count, and format validity).
+ """
+
+ # Reset environment to start fresh
+ observation, info = env.reset()
+ trajectory = []
+ distill_trajectory = []
+ action_history = [] # Track all actions taken
+
+ # Prepare system prompts
+ original_system_prompt = self.alfworld_system_template.render()
+
+ default_reward = 0.0
+ reward = default_reward
+ valid_format = True
+
+ # Extract task description from the initial observation (only once at the beginning)
+ task_description = extract_task_description(observation)
+
+ # Copy responses from first trajectory up to retry_step
+ step = 0
+ done = False
+ if retry_step > 0:
+ # Add original system prompt only
+ trajectory.append({"role": "system", "content": original_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ # Replay first trajectory up to retry_step to restore environment state
+ first_step = 0
+ for msg in first_trajectory[1:]: # Skip system message
+ if msg["role"] == "user":
+ # This is an observation - copy it and continue
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+ elif msg["role"] == "assistant":
+ if first_step < retry_step:
+ # Copy the assistant response from first trajectory
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+
+ # Execute the action to restore environment state
+ think, action = parse_response(msg["content"])
+ if think is not None and action is not None:
+ observation, reward, done, info = env.step(action)
+ action_history.append(action)
+ first_step += 1
+ step = first_step
+
+ if done:
+ # If environment finished during replay, no need to continue
+ return distill_trajectory, trajectory, reward, done, step, valid_format
+ else:
+ break
+
+ # Add guidance prompt as a separate system message before retry point
+ guidance_system_msg = {
+ "role": "system",
+ "content": f"# Previous Attempt Analysis & Guidance\n{guidance_prompt}",
+ }
+ trajectory.append(guidance_system_msg)
+ # Don't add guidance to distill_trajectory to keep it clean
+
+ else:
+ # Starting from beginning - add system prompt with guidance
+ merged_system_prompt = (
+ f"{original_system_prompt}\n\n# Previous Attempt Analysis & Guidance\n{guidance_prompt}"
+ )
+ trajectory.append({"role": "system", "content": merged_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
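+    # Continue the rollout from the retry point (`step` carries over from the replay above), now guided by the reflection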
+ for step in range(step, self.max_env_steps):
+ # Extract admissible actions from info if available
+ admissible_actions = info.get("admissible_commands", []) if isinstance(info, dict) else []
+
+ formatted_obs = format_observation(
+ current_observation=observation,
+ task_description=task_description,
+ current_step=step,
+ action_history=action_history,
+ admissible_actions=admissible_actions,
+ )
+
+ trajectory.append({"role": "user", "content": formatted_obs})
+ distill_trajectory.append({"role": "user", "content": formatted_obs})
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+ print(
+                f"[second_rollout] Token count {responses[0].tokens.shape[0]} reaches the limit at step {step}"
+ )
+ return distill_trajectory, trajectory, reward, True, step, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response
+ think, action = parse_response(response_text)
+ if think is None or action is None:
+ valid_format = False
+ # Only add format feedback if not in the last step (to avoid duplicate feedback)
+ if step < self.max_env_steps - 1:
+                feedback = "Invalid response format, missing valid <think> or <action> tags, please ensure to follow the output format strictly: <think>...</think> <action>...</action>"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ continue
+
+ # Execute action in environment
+ observation, reward, done, info = env.step(action)
+
+ # Track successfully executed actions for history (only if action was valid)
+ if "Nothing happens." not in observation:
+ action_history.append(action)
+
+ if done:
+ break
+
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ else:
+ feedback = f"Task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+
+ return distill_trajectory, trajectory, reward, False, step + 1, valid_format
+
+
+def eval_alfworld(self) -> List[Experience]:
+ """Evaluate a single alfworld trajectory"""
+ try:
+ env = create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ trajectory, reward, done, steps, valid_format = first_rollout(self, env)
+ print(
+ f"[Eval Alfworld] Trajectory completed with reward: {reward}, steps: {steps}, valid_format: {valid_format}"
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ print(f"[Eval] First rollout - reward: {reward}, steps: {steps}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="evaluation",
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval", experience_data=eval_record, data_dir=self.eval_dir
+ )
+ except Exception as e:
+ # logger.warning(f"Single rollout failed during eval: {e}")
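+        # Fall back to a minimal placeholder experience so evaluation still returns a well-formed result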
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ )
+ exp.reward = 0.0
+ return [exp]
+
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def extract_task_description(observation: str) -> str:
+ """
+ Extract task description from the initial observation.
+ The task description is typically in the format: "Your task is to: ..."
+
+ Args:
+ observation: Initial observation from environment
+
+ Returns:
+ Extracted task description string
+ """
+ # Look for pattern "Your task is to: "
+ match = re.search(r"Your task is to:\s*(.+?)(?:\n|$)", observation, re.IGNORECASE)
+ if match:
+ return match.group(1).strip()
+ # Fallback: return a portion of the observation
+ return observation.split("\n")[-1] if "\n" in observation else observation
+
+
+def format_observation(
+ current_observation: str,
+ task_description: str = "",
+ current_step: int = 0,
+ action_history: List[str] = None,
+ admissible_actions: List[str] = None,
+ history_length: int = 2,
+):
+ """
+ Format observation string with task context and limited action history.
+
+ Args:
+ current_observation: Current observation from environment
+ task_description: Description of the task to complete
+ current_step: Current step number
+ action_history: List of all previous actions taken
+ admissible_actions: List of currently admissible actions
+        history_length: Maximum number of recent actions to display (default: 2)
+ """
+ if action_history is None:
+ action_history = []
+ if admissible_actions is None:
+ admissible_actions = []
+
+ # Format admissible actions
+ admissible_actions_str = (
+ ", ".join(admissible_actions) if admissible_actions else "All standard actions available"
+ )
+
+ # Check if this is the first step (no history)
+ if current_step == 0 or not action_history:
+ # First step - no history version
+ prompt = f"""You are an expert agent operating in the ALFRED Embodied Environment.
+Your current observation is: {current_observation}
+Your admissible actions of the current situation are: [{admissible_actions_str}].
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for the current step and present it within <action> </action> tags."""
+ else:
+ # Limit action history to most recent history_length items
+ recent_actions = (
+ action_history[-history_length:]
+ if len(action_history) > history_length
+ else action_history
+ )
+
+ # Format action history as a structured list with observations
+ action_history_str = "\n".join(
+ [
+ f" Step {current_step - len(recent_actions) + i}: {action}"
+ for i, action in enumerate(recent_actions)
+ ]
+ )
+
+ # Create formatted prompt with limited history
+ prompt = f"""You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description}
+Prior to this step, you have already taken {len(action_history)} step(s). Below are the most recent {len(recent_actions)} actions you took:
+{action_history_str}
+You are now at step {current_step} and your current observation is: {current_observation}
+Your admissible actions of the current situation are: [{admissible_actions_str}].
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for the current step and present it within <action> </action> tags."""
+
+ return prompt
+
+
+def parse_response(response):
+ """Parse think and action components from response"""
+ try:
+ # Use regex to extract think and action components
+        think_pattern = r"<think>\s*(.*?)\s*</think>"
+        action_pattern = r"<action>\s*(.*?)\s*</action>"
+
+ think_match = re.search(think_pattern, response, re.DOTALL)
+ action_match = re.search(action_pattern, response, re.DOTALL)
+
+ think = think_match.group(1).strip() if think_match else None
+ action = action_match.group(1).strip() if action_match else None
+
+ return think, action
+ except Exception:
+ return None, None
+
+
+def create_alfworld_environment(game_file, max_episode_steps=50):
+ """
+ Create alfworld environment
+
+ Args:
+ game_file: Path to the game file
+ max_episode_steps: Maximum number of steps per episode (default: 50)
+ """
+ try:
+ import textworld
+ import textworld.gym
+ from alfworld.agents.environment.alfred_tw_env import (
+ AlfredDemangler,
+ AlfredExpert,
+ AlfredExpertType,
+ )
+
+ expert = AlfredExpert(expert_type=AlfredExpertType.HANDCODED)
+ request_infos = textworld.EnvInfos(
+ description=True, inventory=True, admissible_commands=True
+ )
+
+ env_id = textworld.gym.register_game(
+ game_file,
+ request_infos,
+ max_episode_steps=max_episode_steps,
+ asynchronous=True,
+ wrappers=[AlfredDemangler(), expert],
+ )
+ env = textworld.gym.make(env_id)
+
+ return env
+ except ImportError as e:
+ raise ImportError(
+ f"Failed to import alfworld dependencies: {e}. "
+ "Please install alfworld following the instructions at https://github.com/alfworld/alfworld"
+ )
+
+
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+    Formats the trajectory for reflection, including the system prompt,
+    and numbers the user/assistant turns.
+ """
+ formatted_lines = []
+    # Use a counter to track the user/assistant interaction turns
+    turn_counter = 0  # Start counting from 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+            # The system prompt does not count as a step, but must be shown first as the rules
+ formatted_lines.append(f"**System Rules & Context:**\n{content}\n" + "=" * 30)
+ elif role == "user":
+            # Mark the start of a new turn
+ formatted_lines.append(f"\n**Step {turn_counter}**")
+ formatted_lines.append(f" - User Observation/Feedback:\n {content.strip()}")
+ elif role == "assistant":
+            # The assistant's thought and action
+ formatted_lines.append(f" - Assistant Thought & Action:\n {content.strip()}")
+            # Increment the turn counter after a complete user-assistant exchange
+ turn_counter += 1
+
+ return "\n".join(formatted_lines)
+
+
+def validate_reflect_report(report: Dict[str, Any], max_steps: int = None) -> tuple[bool, bool]:
+ """
+ Validates the structure and content of the reflection report
+ based on the new reflection.j2 schema.
+
+ Args:
+ report: The reflection report to validate
+ max_steps: Maximum number of steps in trajectory for retry_step bounds checking (optional)
+
+ Returns:
+ tuple[bool, bool]: (is_valid, is_perfect)
+ - is_valid: Whether the report structure is valid
+ - is_perfect: Whether the report indicates the trajectory is perfect (only meaningful if is_valid is True)
+ """
+ if (
+ not isinstance(report, dict)
+ or "outcome_assessment" not in report
+ or "analysis" not in report
+ ):
+ print("Validation failed: Report is not a dict or missing top-level keys.")
+ return False, False
+
+ outcome = report["outcome_assessment"]
+ analysis = report["analysis"]
+
+ # Check for required top-level analysis keys
+ if "summary" not in analysis:
+ print("Validation failed: Missing 'summary' in analysis.")
+ return False, False
+
+ if outcome == "OPTIMAL":
+ # For OPTIMAL, we only need summary and no flaw analysis
+ print("OPTIMAL report validation successful.")
+ return True, True
+
+ elif outcome in ["SUBOPTIMAL_SUCCESS", "PARTIAL", "INEFFECTIVE"]:
+ # For non-optimal outcomes, validate flaw_analysis structure
+ flaw_analysis = analysis.get("flaw_analysis", {})
+
+ # Validate diagnosis
+ diagnosis = flaw_analysis.get("diagnosis", {})
+ valid_categories = [
+ "Strategy Flaw",
+ "Reasoning Flaw",
+ "Execution Flaw",
+ "Knowledge Gap",
+ "Inefficiency",
+ ]
+ if (
+ diagnosis.get("category") not in valid_categories
+ and diagnosis.get("category") != "null"
+ ):
+ print(f"Validation failed: Invalid 'category'. Got: {diagnosis.get('category')}")
+ return False, False
+
+ # Validate better_approach
+ better_approach = flaw_analysis.get("better_approach", {})
+ required_better_approach_keys = ["strategy", "key_differences", "projected_benefits"]
+ for key in required_better_approach_keys:
+ if key not in better_approach:
+ print(
+ f"Validation failed: Missing '{key}' in better_approach. Got: {better_approach}"
+ )
+ return False, False
+
+ # Validate lessons_learned
+ lessons_learned = analysis.get("lessons_learned", {})
+ if not (
+ "corrective_principle" in lessons_learned and "revised_action_plan" in lessons_learned
+ ):
+ print(f"Validation failed: Invalid 'lessons_learned'. Got: {lessons_learned}")
+ return False, False
+
+ # Validate retry_strategy
+ retry_strategy = analysis.get("retry_strategy", {})
+ if not retry_strategy:
+ print("Validation failed: Missing 'retry_strategy' in analysis.")
+ return False, False
+
+ # Validate retry_step
+ if "retry_step" not in retry_strategy:
+ print("Validation failed: Missing 'retry_step' in retry_strategy.")
+ return False, False
+
+ retry_step = retry_strategy["retry_step"]
+ if retry_step is not None:
+ try:
+ retry_step = int(retry_step)
+ except (ValueError, TypeError):
+ print(
+ f"Validation failed: 'retry_step' must be an integer or null. Got: {retry_step}"
+ )
+ return False, False
+ if not isinstance(retry_step, int) or retry_step < 0:
+ print(
+ f"Validation failed: 'retry_step' must be a non-negative integer or null. Got: {retry_step}"
+ )
+ return False, False
+
+ # Check trajectory bounds if max_steps is provided
+ if max_steps is not None:
+ if retry_step >= max_steps:
+ print(
+ f"Validation failed: 'retry_step' ({retry_step}) exceeds trajectory bounds (0 to {max_steps - 1})."
+ )
+ return False, False
+
+ # Validate retry_rationale
+ if "retry_rationale" not in retry_strategy:
+ print("Validation failed: Missing 'retry_rationale' in retry_strategy.")
+ return False, False
+
+ print(f"{outcome} report validation successful.")
+ return True, False
+
+ else:
+ print(f"Validation failed: Unknown 'outcome_assessment': {outcome}")
+ return False, False
+
+
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Converts a validated reflection report into a structured, actionable
+ guidance prompt for the agent's second attempt. This prompt is framed
+ as an internal directive to ensure the model's output is clean for SFT.
+ """
+ # Convert the report dictionary to a formatted string
+ report_str = json.dumps(report, indent=2, ensure_ascii=False)
+
+ # Load and render template
+ jinja_env = _get_jinja_env()
+ template = jinja_env.get_template("self_correction.j2")
+
+ return template.render(report=report_str)
+
+
+def save_experience_data(task_id: str, experience_data: Dict, data_dir: str) -> str:
+ """
+ Save experience data including trajectory, rewards, and steps to a JSON file.
+
+ Args:
+ task_id: Unique identifier for the task
+ experience_data: Dictionary containing experience information
+ data_dir: Directory to save the data
+
+ Returns:
+ Path to the saved file
+ """
+ os.makedirs(data_dir, exist_ok=True)
+
+    # Build the output path from the task id (the save timestamp is stored inside the record below)
+ filename = f"{task_id}.json"
+ filepath = os.path.join(data_dir, filename)
+
+ # Ensure experience_data is JSON serializable
+ serializable_data = {}
+ for key, value in experience_data.items():
+ if isinstance(value, torch.Tensor):
+ serializable_data[key] = value.tolist()
+ elif hasattr(value, "__dict__"):
+ # For complex objects, convert to dict representation
+ serializable_data[key] = str(value)
+ else:
+ serializable_data[key] = value
+
+ # Add metadata
+ serializable_data["saved_at"] = datetime.now().isoformat()
+ serializable_data["task_id"] = task_id
+
+ try:
+ with open(filepath, "w", encoding="utf-8") as f:
+ json.dump(serializable_data, f, indent=2, ensure_ascii=False)
+ # print(f"Experience data saved to: {filepath}")
+ return filepath
+ except Exception as e:
+ # print(f"Failed to save experience data: {e}")
+ return ""
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ steps: int,
+ success: bool,
+ attempt_type: str = "first",
+ reflection_data: Optional[Dict] = None,
+ additional_metrics: Optional[Dict] = None,
+) -> Dict:
+ """
+ Create a structured experience record for saving.
+
+ Args:
+ task_id: Unique identifier for the task
+ trajectory: List of conversation messages
+ reward: Final reward received
+ steps: Number of steps taken
+ success: Whether the task was completed successfully
+ attempt_type: Type of attempt ("first", "second", "reflect")
+ reflection_data: Optional reflection analysis data
+ additional_metrics: Additional metrics to record
+
+ Returns:
+ Dictionary containing structured experience data
+ """
+ experience_record = {
+ "task_id": task_id,
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "metrics": {
+ "reward": reward,
+ "steps": steps,
+ "success": success,
+ "trajectory_length": len(trajectory),
+ },
+ "created_at": datetime.now().isoformat(),
+ }
+
+ if reflection_data:
+ experience_record["reflection"] = reflection_data
+
+ if additional_metrics:
+ experience_record["metrics"].update(additional_metrics)
+
+ return experience_record
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld-bak/utils.py b/trinity/common/workflows/envs/R3L-back/alfworld-bak/utils.py
new file mode 100644
index 0000000000..800ccc47e4
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld-bak/utils.py
@@ -0,0 +1,719 @@
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+
+
+def first_rollout(self, env) -> tuple[List[Dict[str, str]], float, bool, int, bool]:
+ """
+ Run a single rollout in Alfworld environment.
+ Uses sliding window approach like verl-agent: each step is a single-turn call.
+ """
+ observation, info = env.reset()
+
+ # Store complete trajectory for training data
+ full_trajectory = []
+
+ # Store history as (observation, action) pairs for sliding window
+ history = []
+
+ default_reward = 0.0
+ reward = default_reward
+ valid_format = True
+ step = 0
+
+ # Extract task description from the initial observation
+ task_description = extract_task_description(observation)
+ history_length = getattr(self, 'history_length', 2) # Use configurable history_length
+
+ for step in range(self.max_env_steps):
+ # Extract admissible actions from info
+ admissible_actions = info.get("admissible_commands", []) if isinstance(info, dict) else []
+
+ # Build prompt with sliding window history (verl-agent style)
+ prompt = format_observation(
+ current_observation=observation,
+ task_description=task_description,
+ current_step=step,
+ action_history=history, # Pass full history, format_observation will handle sliding window
+ admissible_actions=admissible_actions,
+ history_length=history_length
+ )
+
+ # Single-turn chat call (verl-agent style: each step is independent)
+ single_turn_messages = [{"role": "user", "content": prompt}]
+
+ # Get model response
+ responses = self.model.chat(
+ single_turn_messages,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
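+        # Stop once the count reaches 2560 - 512 = 2048 tokens, presumably leaving room for the 512-token response budget (self.max_tokens)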
+ if responses[0].tokens.shape[0] >= 2560 - 512:
+ return full_trajectory, reward, False, step, False
+
+ response_text = responses[0].response_text.strip()
+
+ # Store in full trajectory for training
+ full_trajectory.append({"role": "user", "content": prompt})
+ full_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response components
+ think, action = parse_response(response_text)
+ if think is None or action is None:
+ valid_format = False
+ if step < self.max_env_steps - 1:
+                feedback = "Invalid response format, missing valid <think> or <action> tags, please ensure to follow the output format strictly: <think>...</think> <action>...</action>"
+ full_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ continue
+
+ # Store in history (only observation and action, NOT thinking - verl-agent style)
+ history.append({'observation': observation, 'action': action})
+
+ # Execute action in environment
+ observation, reward, done, info = env.step(action)
+
+ if done:
+ break
+
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ else:
+ feedback = f"Task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+
+ # Add feedback to trajectory
+ full_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+
+ return full_trajectory, reward, False, step + 1, valid_format
+
+
+def second_rollout(
+ self,
+ env,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, int, bool]:
+ """
+ Performs rollout starting from a specific retry step, reusing previous responses.
+
+ Args:
+ env: The environment instance.
+ guidance_prompt: The pre-generated guidance from reflection.
+ first_trajectory: The full log of the initial attempt.
+ retry_step: The step to start retry from (0-based, 0 means from beginning).
+
+ Returns:
+ A tuple containing (distill_trajectory, second_trajectory, reward, done status,
+ step count, and format validity).
+ """
+
+ # Reset environment to start fresh
+ observation, info = env.reset()
+ trajectory = []
+ distill_trajectory = []
+ action_history = [] # Track all actions taken
+
+ # Prepare system prompts
+ original_system_prompt = self.alfworld_system_template.render()
+
+ default_reward = 0.0
+ reward = default_reward
+ valid_format = True
+
+ # Extract task description from the initial observation (only once at the beginning)
+ task_description = extract_task_description(observation)
+
+ # Copy responses from first trajectory up to retry_step
+ step = 0
+ done = False
+ if retry_step > 0:
+ # Add original system prompt only
+ trajectory.append({"role": "system", "content": original_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ # Replay first trajectory up to retry_step to restore environment state
+ first_step = 0
+ for msg in first_trajectory[1:]: # Skip system message
+ if msg["role"] == "user":
+ # This is an observation - copy it and continue
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+ elif msg["role"] == "assistant":
+ if first_step < retry_step:
+ # Copy the assistant response from first trajectory
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+
+ # Execute the action to restore environment state
+ think, action = parse_response(msg["content"])
+ if think is not None and action is not None:
+ observation, reward, done, info = env.step(action)
+ action_history.append(action)
+ first_step += 1
+ step = first_step
+
+ if done:
+ # If environment finished during replay, no need to continue
+ return distill_trajectory, trajectory, reward, done, step, valid_format
+ else:
+ break
+
+ # Add guidance prompt as a separate system message before retry point
+ guidance_system_msg = {"role": "system",
+ "content": f"# Previous Attempt Analysis & Guidance\n{guidance_prompt}"}
+ trajectory.append(guidance_system_msg)
+ # Don't add guidance to distill_trajectory to keep it clean
+
+ else:
+ # Starting from beginning - add system prompt with guidance
+ merged_system_prompt = f"{original_system_prompt}\n\n# Previous Attempt Analysis & Guidance\n{guidance_prompt}"
+ trajectory.append({"role": "system", "content": merged_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
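+    # Resume the rollout from the retry point (`step` carries over from the replay above), with the reflection guidance in context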
+ for step in range(step, self.max_env_steps):
+ # Extract admissible actions from info if available
+ admissible_actions = info.get("admissible_commands", []) if isinstance(info, dict) else []
+
+ formatted_obs = format_observation(
+ current_observation=observation,
+ task_description=task_description,
+ current_step=step,
+ action_history=action_history,
+ admissible_actions=admissible_actions
+ )
+
+ trajectory.append({"role": "user", "content": formatted_obs})
+ distill_trajectory.append({"role": "user", "content": formatted_obs})
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 2560 - 512:
+ print(f"[second_rollout] Token count {responses[0].tokens.shape[0]} exceeds 2048 at step {step}")
+ return distill_trajectory, trajectory, reward, True, step, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response
+ think, action = parse_response(response_text)
+ if think is None or action is None:
+ valid_format = False
+ # Only add format feedback if not in the last step (to avoid duplicate feedback)
+ if step < self.max_env_steps - 1:
+                feedback = "Invalid response format, missing valid <think> or <action> tags, please ensure to follow the output format strictly: <think>...</think> <action>...</action>"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ continue
+
+ # Execute action in environment
+ observation, reward, done, info = env.step(action)
+
+ # Track successfully executed actions for history (only if action was valid)
+ if "Nothing happens." not in observation:
+ action_history.append(action)
+
+ if done:
+ break
+
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+ f"Task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ else:
+ feedback = (
+ f"Task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+
+ return distill_trajectory, trajectory, reward, False, step + 1, valid_format
+
+
+def eval_alfworld(self) -> List[Experience]:
+ """Evaluate a single alfworld trajectory"""
+ try:
+ env = create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ trajectory, reward, done, steps, valid_format = first_rollout(
+ self, env
+ )
+ print(
+ f"[Eval Alfworld] Trajectory completed with reward: {reward}, steps: {steps}, valid_format: {valid_format}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ print(f"[Eval] First rollout - reward: {reward}, steps: {steps}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="evaluation"
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval",
+ experience_data=eval_record,
+ data_dir=self.eval_dir
+ )
+ except Exception as e:
+ # logger.warning(f"Single rollout failed during eval: {e}")
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ }
+ )
+ exp.reward = 0.0
+ return [exp]
+
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def extract_task_description(observation: str) -> str:
+ """
+ Extract task description from the initial observation.
+ The task description is typically in the format: "Your task is to: ..."
+
+ Args:
+ observation: Initial observation from environment
+
+ Returns:
+ Extracted task description string
+ """
+ # Look for pattern "Your task is to: "
+ match = re.search(r"Your task is to:\s*(.+?)(?:\n|$)", observation, re.IGNORECASE)
+ if match:
+ return match.group(1).strip()
+ # Fallback: return a portion of the observation
+ return observation.split('\n')[-1] if '\n' in observation else observation
+
+
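+def _demo_extract_task_description() -> None:
+    # Illustrative sketch only (not called anywhere in this module): shows the
+    # expected behaviour of extract_task_description on a typical ALFWorld reset
+    # observation. The observation text below is a made-up example.
+    obs = (
+        "-= Welcome to TextWorld, ALFRED! =-\n"
+        "You are in the middle of a room. Looking around you, you see a fridge 1.\n"
+        "Your task is to: put a clean mug in coffeemachine."
+    )
+    assert extract_task_description(obs) == "put a clean mug in coffeemachine."
+
+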
+def format_observation(
+ current_observation: str,
+ task_description: str = "",
+ current_step: int = 0,
+ action_history: Optional[List] = None,
+ admissible_actions: Optional[List[str]] = None,
+ history_length: int = 2
+):
+ """
+ Format observation string with task context and limited action history.
+ Adapted from the verl-agent style: history entries are action strings ({observation, action} dicts are also accepted).
+
+ Args:
+ current_observation: Current observation from environment
+ task_description: Description of the task to complete
+ current_step: Current step number
+ action_history: List of previous actions (strings, or {observation, action} dicts)
+ admissible_actions: List of currently admissible actions
+ history_length: Maximum number of recent steps to display (default: 2)
+ """
+ if action_history is None:
+ action_history = []
+ if admissible_actions is None:
+ admissible_actions = []
+
+ # Format admissible actions (remove 'help' like verl-agent does)
+ filtered_actions = [a for a in admissible_actions if a != 'help']
+ admissible_actions_str = ", ".join(f"'{a}'" for a in filtered_actions) if filtered_actions else "All standard actions available"
+
+ # Check if this is the first step (no history)
+ if current_step == 0 or not action_history:
+ # First step - no history version (ALFWORLD_TEMPLATE_NO_HIS style)
+ prompt = f"""You are an expert agent operating in the ALFRED Embodied Environment.
+Your current observation is: {current_observation}
+Your admissible actions of the current situation are: [{admissible_actions_str}].
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for the current step and present it within <action> </action> tags."""
+ else:
+ # Limit history to most recent history_length items (sliding window)
+ recent_history = action_history[-history_length:] if len(action_history) > history_length else action_history
+ valid_history_length = len(recent_history)
+ start_idx = len(action_history) - valid_history_length
+
+ # Format action history: only show actions, NOT observations (verl-agent style)
+ action_history_lines = []
+ for i, record in enumerate(recent_history):
+ step_num = start_idx + i
+ # Entries are plain action strings here; {observation, action} dicts are also
+ # accepted for compatibility with the verl-agent history format. The observation
+ # could be rendered too, but for simplicity only the action is shown.
+ action = record["action"] if isinstance(record, dict) else record
+ action_history_lines.append(f" Step {step_num}: {action}")
+
+ action_history_str = "\n".join(action_history_lines)
+
+ # Create formatted prompt with limited history (ALFWORLD_TEMPLATE style)
+ prompt = f"""You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description}
+Prior to this step, you have already taken {len(action_history)} step(s). Below are the most recent {valid_history_length} actions you took:
+{action_history_str}
+You are now at step {current_step} and your current observation is: {current_observation}
+Your admissible actions of the current situation are: [{admissible_actions_str}].
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for the current step and present it within <action> </action> tags."""
+
+ return prompt
+
+
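+def _demo_format_observation() -> None:
+    # Illustrative sketch only (not called anywhere in this module): the first call
+    # renders the no-history template, the second shows the sliding-window history
+    # rendering. History entries are plain action strings, as appended by the
+    # rollout functions in this module; all values below are made up.
+    print(format_observation("You arrive at fridge 1.", task_description="heat some egg"))
+    print(format_observation(
+        "You open the fridge 1.",
+        task_description="heat some egg",
+        current_step=2,
+        action_history=["go to fridge 1", "open fridge 1"],
+        admissible_actions=["take egg 1 from fridge 1", "close fridge 1"],
+    ))
+
+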
+def parse_response(response):
+ """Parse think and action components from response"""
+ try:
+ # Use regex to extract think and action components
+ think_pattern = r"\s*(.*?)\s*"
+ action_pattern = r"\s*(.*?)\s*"
+
+ think_match = re.search(think_pattern, response, re.DOTALL)
+ action_match = re.search(action_pattern, response, re.DOTALL)
+
+ think = think_match.group(1).strip() if think_match else None
+ action = action_match.group(1).strip() if action_match else None
+
+ return think, action
+ except Exception:
+ return None, None
+
+
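+def _demo_parse_response() -> None:
+    # Illustrative sketch only (not called anywhere in this module): a well-formed
+    # response yields both components, a response without tags yields (None, None).
+    think, action = parse_response(
+        "<think>The mug is likely in the fridge.</think><action>open fridge 1</action>"
+    )
+    assert think == "The mug is likely in the fridge."
+    assert action == "open fridge 1"
+    assert parse_response("no tags here") == (None, None)
+
+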
+def create_alfworld_environment(game_file, max_episode_steps=50):
+ """
+ Create alfworld environment
+
+ Args:
+ game_file: Path to the game file
+ max_episode_steps: Maximum number of steps per episode (default: 50)
+ """
+ try:
+ import textworld
+ import textworld.gym
+ from alfworld.agents.environment.alfred_tw_env import (
+ AlfredDemangler,
+ AlfredExpert,
+ AlfredExpertType,
+ )
+
+ expert = AlfredExpert(expert_type=AlfredExpertType.HANDCODED)
+ request_infos = textworld.EnvInfos(
+ description=True, inventory=True, admissible_commands=True
+ )
+
+ env_id = textworld.gym.register_game(
+ game_file, request_infos,
+ max_episode_steps=max_episode_steps,
+ asynchronous=True,
+ wrappers=[AlfredDemangler(), expert]
+ )
+ env = textworld.gym.make(env_id)
+
+ return env
+ except ImportError as e:
+ raise ImportError(
+ f"Failed to import alfworld dependencies: {e}. "
+ "Please install alfworld following the instructions at https://github.com/alfworld/alfworld"
+ )
+
+
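+def _demo_create_alfworld_environment(game_file: str) -> None:
+    # Minimal usage sketch (assumes alfworld/textworld are installed and that
+    # game_file points to a valid ALFWorld game file); not called anywhere in
+    # this module.
+    env = create_alfworld_environment(game_file, max_episode_steps=5)
+    observation, info = env.reset()
+    print(extract_task_description(observation))
+    observation, reward, done, info = env.step("look")
+    env.close()
+
+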
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+ Correctly formats the trajectory for reflection, including the system prompt
+ and numbering the user/assistant turns.
+ """
+ formatted_lines = []
+ # Use a counter to track the user/assistant interaction turns
+ turn_counter = 0 # Counting starts at 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+ # The system prompt does not count as a step, but must be shown first as the rules
+ formatted_lines.append(f"**System Rules & Context:**\n{content}\n" + "=" * 30)
+ elif role == "user":
+ # Mark the start of a new turn
+ formatted_lines.append(f"\n**Step {turn_counter}**")
+ formatted_lines.append(f" - User Observation/Feedback:\n {content.strip()}")
+ elif role == "assistant":
+ # The assistant's thought and action
+ formatted_lines.append(f" - Assistant Thought & Action:\n {content.strip()}")
+ # Increment the turn counter once a full user-assistant exchange is complete
+ turn_counter += 1
+
+ return "\n".join(formatted_lines)
+
+
+def validate_reflect_report(report: Dict[str, Any], max_steps: int = None) -> tuple[bool, bool]:
+ """
+ Validates the structure and content of the reflection report
+ based on the new reflection.j2 schema.
+
+ Args:
+ report: The reflection report to validate
+ max_steps: Maximum number of steps in trajectory for retry_step bounds checking (optional)
+
+ Returns:
+ tuple[bool, bool]: (is_valid, is_perfect)
+ - is_valid: Whether the report structure is valid
+ - is_perfect: Whether the report indicates the trajectory is perfect (only meaningful if is_valid is True)
+ """
+ if (
+ not isinstance(report, dict)
+ or "outcome_assessment" not in report
+ or "analysis" not in report
+ ):
+ print("Validation failed: Report is not a dict or missing top-level keys.")
+ return False, False
+
+ outcome = report["outcome_assessment"]
+ analysis = report["analysis"]
+
+ # Check for required top-level analysis keys
+ if "summary" not in analysis:
+ print("Validation failed: Missing 'summary' in analysis.")
+ return False, False
+
+ if outcome == "OPTIMAL":
+ # For OPTIMAL, we only need summary and no flaw analysis
+ print("OPTIMAL report validation successful.")
+ return True, True
+
+ elif outcome in ["SUBOPTIMAL_SUCCESS", "PARTIAL", "INEFFECTIVE"]:
+ # For non-optimal outcomes, validate flaw_analysis structure
+ flaw_analysis = analysis.get("flaw_analysis", {})
+
+ # Validate diagnosis
+ diagnosis = flaw_analysis.get("diagnosis", {})
+ valid_categories = [
+ "Strategy Flaw",
+ "Reasoning Flaw",
+ "Execution Flaw",
+ "Knowledge Gap",
+ "Inefficiency"
+ ]
+ if diagnosis.get("category") not in valid_categories and diagnosis.get("category") != "null":
+ print(f"Validation failed: Invalid 'category'. Got: {diagnosis.get('category')}")
+ return False, False
+
+ # Validate better_approach
+ better_approach = flaw_analysis.get("better_approach", {})
+ required_better_approach_keys = ["strategy", "key_differences", "projected_benefits"]
+ for key in required_better_approach_keys:
+ if key not in better_approach:
+ print(f"Validation failed: Missing '{key}' in better_approach. Got: {better_approach}")
+ return False, False
+
+ # Validate lessons_learned
+ lessons_learned = analysis.get("lessons_learned", {})
+ if not (
+ "corrective_principle" in lessons_learned
+ and "revised_action_plan" in lessons_learned
+ ):
+ print(f"Validation failed: Invalid 'lessons_learned'. Got: {lessons_learned}")
+ return False, False
+
+ # Validate retry_strategy
+ retry_strategy = analysis.get("retry_strategy", {})
+ if not retry_strategy:
+ print("Validation failed: Missing 'retry_strategy' in analysis.")
+ return False, False
+
+ # Validate retry_step
+ if "retry_step" not in retry_strategy:
+ print("Validation failed: Missing 'retry_step' in retry_strategy.")
+ return False, False
+
+ retry_step = retry_strategy["retry_step"]
+ if retry_step is not None:
+ try:
+ retry_step = int(retry_step)
+ except (ValueError, TypeError):
+ print(f"Validation failed: 'retry_step' must be an integer or null. Got: {retry_step}")
+ return False, False
+ if not isinstance(retry_step, int) or retry_step < 0:
+ print(f"Validation failed: 'retry_step' must be a non-negative integer or null. Got: {retry_step}")
+ return False, False
+
+ # Check trajectory bounds if max_steps is provided
+ if max_steps is not None:
+ if retry_step >= max_steps:
+ print(
+ f"Validation failed: 'retry_step' ({retry_step}) exceeds trajectory bounds (0 to {max_steps - 1}).")
+ return False, False
+
+ # Validate retry_rationale
+ if "retry_rationale" not in retry_strategy:
+ print("Validation failed: Missing 'retry_rationale' in retry_strategy.")
+ return False, False
+
+ print(f"{outcome} report validation successful.")
+ return True, False
+
+ else:
+ print(f"Validation failed: Unknown 'outcome_assessment': {outcome}")
+ return False, False
+
+
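+# A minimal report that passes validate_reflect_report (illustrative only; the
+# field values are made-up placeholders, not the output of a real reflection run):
+_EXAMPLE_REFLECT_REPORT: Dict[str, Any] = {
+    "outcome_assessment": "SUBOPTIMAL_SUCCESS",
+    "analysis": {
+        "summary": "Task solved, but with redundant navigation steps.",
+        "flaw_analysis": {
+            "diagnosis": {"category": "Inefficiency"},
+            "better_approach": {
+                "strategy": "Go directly to the fridge before searching other receptacles.",
+                "key_differences": "Skips the redundant visits to the countertops.",
+                "projected_benefits": "Fewer steps to task completion.",
+            },
+        },
+        "lessons_learned": {
+            "corrective_principle": "Check the most likely receptacle for the target object first.",
+            "revised_action_plan": "go to fridge 1 -> open fridge 1 -> take mug 1 from fridge 1",
+        },
+        "retry_strategy": {
+            "retry_step": 2,
+            "retry_rationale": "Step 2 is where the detour began.",
+        },
+    },
+}
+# validate_reflect_report(_EXAMPLE_REFLECT_REPORT, max_steps=10) returns (True, False).
+
+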
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Converts a validated reflection report into a structured, actionable
+ guidance prompt for the agent's second attempt. This prompt is framed
+ as an internal directive to ensure the model's output is clean for SFT.
+ """
+ # Convert the report dictionary to a formatted string
+ report_str = json.dumps(report, indent=2, ensure_ascii=False)
+
+ # Load and render template
+ jinja_env = _get_jinja_env()
+ template = jinja_env.get_template("self_correction.j2")
+
+ return template.render(report=report_str)
+
+
+def save_experience_data(
+ task_id: str,
+ experience_data: Dict,
+ data_dir: str
+) -> str:
+ """
+ Save experience data including trajectory, rewards, and steps to a JSON file.
+
+ Args:
+ task_id: Unique identifier for the task
+ experience_data: Dictionary containing experience information
+ data_dir: Directory to save the data
+
+ Returns:
+ Path to the saved file
+ """
+ os.makedirs(data_dir, exist_ok=True)
+
+ # Build the output path (a save timestamp is recorded in the metadata below)
+ filename = f"{task_id}.json"
+ filepath = os.path.join(data_dir, filename)
+
+ # Ensure experience_data is JSON serializable
+ serializable_data = {}
+ for key, value in experience_data.items():
+ if isinstance(value, torch.Tensor):
+ serializable_data[key] = value.tolist()
+ elif hasattr(value, '__dict__'):
+ # For complex objects, convert to dict representation
+ serializable_data[key] = str(value)
+ else:
+ serializable_data[key] = value
+
+ # Add metadata
+ serializable_data["saved_at"] = datetime.now().isoformat()
+ serializable_data["task_id"] = task_id
+
+ try:
+ with open(filepath, 'w', encoding='utf-8') as f:
+ json.dump(serializable_data, f, indent=2, ensure_ascii=False)
+ # print(f"Experience data saved to: {filepath}")
+ return filepath
+ except Exception as e:
+ # print(f"Failed to save experience data: {e}")
+ return ""
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ steps: int,
+ success: bool,
+ attempt_type: str = "first",
+ reflection_data: Optional[Dict] = None,
+ additional_metrics: Optional[Dict] = None
+) -> Dict:
+ """
+ Create a structured experience record for saving.
+
+ Args:
+ task_id: Unique identifier for the task
+ trajectory: List of conversation messages
+ reward: Final reward received
+ steps: Number of steps taken
+ success: Whether the task was completed successfully
+ attempt_type: Type of attempt ("first", "second", "reflect")
+ reflection_data: Optional reflection analysis data
+ additional_metrics: Additional metrics to record
+
+ Returns:
+ Dictionary containing structured experience data
+ """
+ experience_record = {
+ "task_id": task_id,
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "metrics": {
+ "reward": reward,
+ "steps": steps,
+ "success": success,
+ "trajectory_length": len(trajectory)
+ },
+ "created_at": datetime.now().isoformat()
+ }
+
+ if reflection_data:
+ experience_record["reflection"] = reflection_data
+
+ if additional_metrics:
+ experience_record["metrics"].update(additional_metrics)
+
+ return experience_record
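+
+
+def _demo_save_experience_record() -> None:
+    # Illustrative sketch only (not called anywhere in this module): builds a
+    # record for a toy one-step trajectory and writes it under a hypothetical
+    # output directory.
+    trajectory = [
+        {"role": "system", "content": "You are an agent interacting with a virtual text-based environment."},
+        {"role": "user", "content": "Your task is to: put a clean mug in coffeemachine."},
+        {"role": "assistant", "content": "<think>Find the mug first.</think><action>go to fridge 1</action>"},
+    ]
+    record = create_experience_record(
+        task_id="demo_task_0",
+        trajectory=trajectory,
+        reward=1.0,
+        steps=1,
+        success=True,
+        attempt_type="first",
+    )
+    save_experience_data(task_id="demo_task_0", experience_data=record, data_dir="/tmp/R3L_demo")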
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld/R3L_workflow.py b/trinity/common/workflows/envs/R3L-back/alfworld/R3L_workflow.py
new file mode 100644
index 0000000000..d6dee60a94
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld/R3L_workflow.py
@@ -0,0 +1,340 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("R3L_alfworld_workflow")
+class R3LAlfworldWorkflow(Workflow):
+ """
+ R3L workflow for alfworld
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 25
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+
+ self.whether_save_data = True
+ # Create data directories
+ self.data_dir = f"R3L_alfworld_data"
+ self.eval_dir = os.path.join(self.data_dir, "eval")
+ self.train_dir = os.path.join(self.data_dir, "train")
+
+ os.makedirs(self.eval_dir, exist_ok=True)
+ os.makedirs(self.train_dir, exist_ok=True)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+ self.reflection_template = self.jinja_env.get_template("reflection.j2")
+
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ self.default_second_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "second_success": 0.0,
+ "second_reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ print(
+ f"Initializing R3LAlfworldWorkflow with experience learning, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.task = task
+ self.n = task.repeat_times
+
+ def get_reflect(self, trajectory: List[Dict[str, str]]) -> tuple[
+ Optional[Dict[str, Any]], Optional[str], Optional[Any]]:
+ """
+ Generates a comprehensive reflection report using a single, unified self-interrogation prompt.
+ The model first assesses its own performance and then follows the appropriate reflection path.
+ """
+ # Format trajectory for LLM reading
+ formatted_trajectory = utils.format_trajectory_for_reflection(trajectory)
+
+ # Use Jinja2 template to render reflection prompt
+ reflect_prompt = self.reflection_template.render()
+
+ # Call model and parse results
+ try:
+ responses = self.model.chat(
+ [
+ {"role": "system", "content": reflect_prompt},
+ {"role": "user", "content": "Here is last attempt trajectory log: \n\n" + formatted_trajectory + "\n\nPlease output in the specified JSON format."}
+ ],
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_reflect_tokens,
+ )
+ reflection_text = responses[0].response_text.strip()
+
+ # Find first '{' and last '}'
+ first_brace = reflection_text.find('{')
+ last_brace = reflection_text.rfind('}')
+
+ if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
+ json_str = reflection_text[first_brace:last_brace + 1]
+ else:
+ json_str = reflection_text
+
+ reflection_data = json.loads(json_str)
+ return reflection_data, reflection_text, responses[0]
+
+ except Exception as e:
+ return None, None, None
+
+ def _adjust_action_mask_for_retry(self, experience: Experience, retry_step: int):
+ """
+ Adjust action_mask in-place to exclude retry prefix from training.
+ Only tokens from retry_step onwards should be trained.
+
+ Args:
+ experience: The experience object with action_mask to adjust
+ retry_step: The step from which training should start
+ """
+ if retry_step <= 0:
+ return
+
+ # Note: experience.action_mask already excludes prompt tokens
+ action_mask = experience.action_mask
+
+ # Find all assistant response regions and mark the first 'retry_step' as non-trainable
+ if torch.any(action_mask == 1):
+ # Find all segments where action_mask == 1 (assistant responses)
+ assistant_segments = []
+ in_segment = False
+ segment_start = 0
+
+ for i, mask_val in enumerate(action_mask):
+ if mask_val == 1 and not in_segment:
+ # Start of a new segment
+ segment_start = i
+ in_segment = True
+ elif mask_val == 0 and in_segment:
+ # End of current segment
+ assistant_segments.append((segment_start, i))
+ in_segment = False
+
+ # Handle case where sequence ends with assistant response
+ if in_segment:
+ assistant_segments.append((segment_start, len(action_mask)))
+
+ # Set the first 'retry_step' assistant segments to 0 (non-trainable)
+ for i in range(min(retry_step, len(assistant_segments))):
+ start, end = assistant_segments[i]
+ action_mask[start:end] = 0
+
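+    # Worked example for _adjust_action_mask_for_retry (illustrative values, not
+    # tied to any real rollout): with action_mask
+    #     [1, 1, 0, 0, 1, 1, 1, 0, 1, 1]
+    # the assistant segments are (0, 2), (4, 7) and (8, 10). Calling the method
+    # with retry_step=2 zeroes the first two segments, leaving
+    #     [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
+    # so only the regenerated suffix after the retry point is trained on.
+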
+ def run(self) -> List[Experience]:
+ """Run the R3L alfworld workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Generate unique task ID
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+
+ env = utils.create_alfworld_environment(self.game_file_path)
+ exp_lst = []
+ for i in range(self.n // 2): # Half for rollout, half for reflection + retry
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[R3L] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ # Set eid
+ exp.eid.task = str(self.task.task_id) + f"_explore"
+ exp_run_id = len(exp_lst) + self.run_id_base
+ exp.eid.run = exp_run_id
+ exp_lst.append(exp)
+
+ if self.whether_save_data:
+ # Save first attempt experience data
+ first_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="first"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_first",
+ experience_data=first_record,
+ data_dir=self.train_dir
+ )
+
+ # Reflect on first attempt
+ reflect_checklist, reflection_text, reflect_exp = self.get_reflect(trajectory)
+ is_valid, is_perfect = utils.validate_reflect_report(reflect_checklist, steps)
+
+ if not is_valid or is_perfect:
+ # If first attempt reward is 1.0 and reflection gives perfect, record reflection exp
+ if reward >= 1.0 and is_perfect and reflect_exp is not None:
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ else:
+ guidance_prompt = utils.reflect_report_to_guidance_prompt(reflect_checklist)
+ # Extract retry_step from validated reflection report
+ retry_step = reflect_checklist["retry_from_step"]
+
+ try:
+ second_env = utils.create_alfworld_environment(self.game_file_path)
+ (
+ distill_trajectory,
+ second_trajectory,
+ second_reward,
+ second_done,
+ second_steps,
+ second_format_valid,
+ ) = utils.second_rollout(
+ self, second_env, guidance_prompt, trajectory, retry_step
+ )
+ print(f"[R3L] Second rollout - reward: {second_reward}, steps: {second_steps}, improve: {second_reward > reward}")
+ second_exp = self.model.convert_messages_to_experience(distill_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(second_exp, retry_step)
+ # Also adjust first rollout exp for fair comparison
+ for existing_exp in exp_lst:
+ if existing_exp.eid.run == exp_run_id:
+ self._adjust_action_mask_for_retry(existing_exp, retry_step)
+ break
+
+ second_exp.reward = second_reward
+ second_exp.metrics = {
+ "second_success": 1.0 if second_reward >= 1.0 else 0.0,
+ "second_steps": second_steps,
+ "second_reward": second_reward,
+ "second_improve": 1.0 if second_reward > reward else 0.0,
+ "second_reward_diff": second_reward - reward,
+ }
+ # Set eid
+ second_exp.eid.task = str(self.task.task_id) + f"_explore"
+ second_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(second_exp)
+
+ if self.whether_save_data:
+ # Save second attempt experience data
+ second_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=second_trajectory,
+ reward=second_reward,
+ steps=second_steps,
+ success=second_reward >= 1.0,
+ attempt_type="second",
+ additional_metrics={
+ "first_reward": reward,
+ "improvement": second_reward > reward,
+ "reward_difference": second_reward - reward,
+ "step_difference": second_steps - steps
+ }
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_second",
+ experience_data=second_record,
+ data_dir=self.train_dir
+ )
+
+ # If second attempt score is higher than first, or second is perfect with fewer steps,
+ # record reflection and retry data
+ if (second_reward > reward and second_reward >= 1.0) or (second_reward >= 1.0 and second_steps < steps):
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Convert retry data to exp
+ retry_exp = self.model.convert_messages_to_experience(second_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(retry_exp, retry_step)
+
+ retry_exp.reward = 1.0
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + f"_retry_{i}"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ print("Reflection and retry led to improvement, recording both...")
+ except Exception as e:
+ print(f"Second rollout after reflection failed: {e}")
+ except Exception as e:
+ print(f"First rollout failed: {e}")
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld/dapo_workflow.py b/trinity/common/workflows/envs/R3L-back/alfworld/dapo_workflow.py
new file mode 100644
index 0000000000..d7a35e55a5
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld/dapo_workflow.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("dapo_alfworld_workflow")
+class DAPOAlfworldWorkflow(Workflow):
+ """
+ DAPO Workflow for Alfworld environment.
+ Performs rollouts with DAPO-style overlong penalty on response length.
+ No separate reward function needed - penalty computed directly in workflow.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 25
+ self.max_tokens = 512
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # DAPO overlong penalty parameters
+ workflow_args = task.workflow_args or {}
+ self.enable_overlong_penalty = workflow_args.get("enable_overlong_penalty", True)
+ self.penalty_factor = workflow_args.get("penalty_factor", 1.0)
+ self.max_response_length = workflow_args.get("max_response_length", 512)
+ self.cache_length = workflow_args.get("cache_length", 100)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+ f"Initializing DAPOAlfworldWorkflow, temperature={self.temperature}, "
+ f"overlong_penalty={'enabled' if self.enable_overlong_penalty else 'disabled'}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.task = task
+ self.n = task.repeat_times
+
+ # Update DAPO parameters if provided
+ workflow_args = task.workflow_args or {}
+ if "enable_overlong_penalty" in workflow_args:
+ self.enable_overlong_penalty = workflow_args["enable_overlong_penalty"]
+ if "penalty_factor" in workflow_args:
+ self.penalty_factor = workflow_args["penalty_factor"]
+ if "max_response_length" in workflow_args:
+ self.max_response_length = workflow_args["max_response_length"]
+ if "cache_length" in workflow_args:
+ self.cache_length = workflow_args["cache_length"]
+
+ def compute_overlong_penalty(self, response_tokens: torch.Tensor) -> float:
+ """
+ Compute DAPO-style overlong penalty based on response token length.
+
+ Args:
+ response_tokens: Response tokens (tensor)
+
+ Returns:
+ Penalty score (non-positive float)
+ """
+ if not self.enable_overlong_penalty:
+ return 0.0
+
+ response_len = len(response_tokens)
+ expected_len = self.max_response_length - self.cache_length
+
+ if response_len < expected_len:
+ # No penalty for short responses
+ return 0.0
+ elif response_len > self.max_response_length:
+ # Fixed penalty for excessively long responses
+ return -self.penalty_factor
+ else:
+ # Linear penalty in the transition zone
+ return (expected_len - response_len) / self.cache_length * self.penalty_factor
+
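+    # Worked example for compute_overlong_penalty (illustrative, assuming the
+    # defaults max_response_length=512, cache_length=100, penalty_factor=1.0):
+    #   len(response_tokens) = 300 -> 0.0   (below 512 - 100 = 412, no penalty)
+    #   len(response_tokens) = 462 -> (412 - 462) / 100 * 1.0 = -0.5
+    #   len(response_tokens) = 600 -> -1.0  (above max_response_length)
+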
+ def run(self) -> List[Experience]:
+ """Run the DAPO workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Single rollout execution
+ env = utils.create_alfworld_environment(self.game_file_path)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[DAPO] Rollout - reward: {reward}, steps: {steps}")
+
+ # Convert trajectory to experience
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+
+ # Extract response tokens from experience
+ response_tokens = exp.tokens[exp.prompt_length:]
+
+ # Compute DAPO overlong penalty (format score)
+ format_score = self.compute_overlong_penalty(response_tokens)
+
+ # Calculate accuracy score
+ accuracy_score = 1.0 if reward >= 1.0 else 0.0
+
+ # Total reward = accuracy + format_score
+ total_reward = accuracy_score + format_score
+
+ # Update experience reward and metrics
+ exp.reward = total_reward
+ exp.metrics = {
+ "success": accuracy_score,
+ "steps": steps,
+ "env_reward": reward,
+ "accuracy": accuracy_score,
+ "format_score": format_score,
+ "response_length": len(response_tokens),
+ "total_reward": total_reward,
+ }
+
+ # Set experience ID
+ exp.eid.task = str(self.task.task_id)
+ exp.eid.run = i + self.run_id_base
+
+ exp_lst.append(exp)
+ except Exception as e:
+ print(f"[DAPO] Rollout failed: {e}")
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld/grpo_workflow.py b/trinity/common/workflows/envs/R3L-back/alfworld/grpo_workflow.py
new file mode 100644
index 0000000000..bbc34b716a
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld/grpo_workflow.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("grpo_baseline_alfworld_workflow")
+class GRPOBaselineAlfworldWorkflow(Workflow):
+ """
+ GRPO Baseline Workflow for Alfworld environment.
+ Performs simple rollouts without reflection or learning from experience.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 25
+ self.max_tokens = 512
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+ f"Initializing GRPOBaselineAlfworldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.task = task
+ self.n = task.repeat_times
+
+ def run(self) -> List[Experience]:
+ """Run the GRPO baseline workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Single rollout execution
+ env = utils.create_alfworld_environment(self.game_file_path)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[GRPO] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld/opmd_workflow.py b/trinity/common/workflows/envs/R3L-back/alfworld/opmd_workflow.py
new file mode 100644
index 0000000000..61663f685f
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld/opmd_workflow.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("opmd_baseline_alfworld_workflow")
+class OPMDBaselineAlfworldWorkflow(Workflow):
+ """
+ OPMD Baseline workflow for Alfworld environment.
+ Performs rollouts for offline policy model distillation.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 25
+ self.max_tokens = 512
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+ f"Initializing OPMDAlfworldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.task = task
+ self.n = task.repeat_times
+
+ def run(self) -> List[Experience]:
+ """Run the OPMD workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Single rollout execution
+ env = utils.create_alfworld_environment(self.game_file_path)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[OPMD] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld/prompts/alfworld_system.j2 b/trinity/common/workflows/envs/R3L-back/alfworld/prompts/alfworld_system.j2
new file mode 100644
index 0000000000..05d4532bf6
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld/prompts/alfworld_system.j2
@@ -0,0 +1,44 @@
+You are an agent interacting with a virtual text-based environment.
+
+## Response Format:
+You MUST use this exact format for every response. Both tags are REQUIRED in sequential order:
+
+<think>your analytical reasoning and thought process</think>
+<action>exactly one specific action command</action>
+
+## Action Commands:
+Your <action> must be one of the following, strictly following the command (argument) format.
+
+### Navigation & Observation:
+- look: Look around your current location to get more details.
+- inventory: Check the object you are currently holding (you can only hold one).
+- go to (receptacle): Move to a receptacle (e.g., table, fridge, sink).
+
+### Interacting with Receptacles:
+- open (receptacle): Open a receptacle.
+- close (receptacle): Close a receptacle.
+
+### Interacting with Objects:
+- take (object) from (receptacle): Pick up an object from a receptacle.
+- move (object) to (receptacle): Place the object you are holding into or onto a receptacle.
+- examine (object): Examine an object closely to learn its properties.
+
+### Changing Object States:
+- heat (object) with (receptacle): Heat an object with a device (e.g., microwave).
+- cool (object) with (receptacle): Cool an object with a device (e.g., fridge).
+- clean (object) with (receptacle): Clean an object with a device (e.g., sink).
+- slice (object) with (object): Slice an object using a sharp object (e.g., knife).
+
+For example, your output should look like this:
+<think>your reasoning process here</think>
+<action>look</action>
+
+<think>your reasoning process here</think>
+<action>go to sofa 1</action>
+
+## Critical Rules & Constraints
+- Single Item Inventory: You can only hold one object at a time. You must put down the current object before taking a new one.
+- Examine Before Acting: Before performing an action on an object (like take, heat, or clean), it is best to examine it first to confirm its properties.
+- Use Exact Names: The (object) and (receptacle) arguments in your command MUST exactly match the names seen in your Observation, including any numbers (e.g., apple 1, desk 2).
+- Systematic Thinking: Break down complex tasks into smaller, manageable sub-goals. Clearly outline your plan in the <think> block.
+- Step Limit: You must complete the task within 25 steps.
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld/prompts/reflection.j2 b/trinity/common/workflows/envs/R3L-back/alfworld/prompts/reflection.j2
new file mode 100644
index 0000000000..4590a3deec
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld/prompts/reflection.j2
@@ -0,0 +1,28 @@
+You are a Reflector that analyzes trajectory logs based on user and environment feedback. Your goal is to identify what went wrong, trace root causes, and extract reusable principles for future improvement. Review the trajectory and feedback to understand the strategy and outcome. Through Socratic-style iterative "why" questioning, trace issues back to their fundamental flawed assumptions or mental models. Then formulate an actionable principle and suggest where to retry if needed.
+
+Please output in the following JSON format:
+
+```json
+{
+"trajectory_summary": "Concise overview in 1-3 sentences covering: (1) the strategy or approach employed by the agent, (2) the final result or outcome achieved, (3) key observations about execution quality (e.g., efficiency, correctness, optimality).",
+"root_cause_analysis": "Deep causal analysis using iterative 'why' questioning to trace from observable symptoms back to the fundamental root cause (flawed assumption, incorrect mental model, or critical knowledge gap). Chain your reasoning explicitly (e.g., 'Why X? Because Y. Why Y? Because Z.'). Identify the deepest underlying issue, not just surface-level errors. Set to null only if execution was truly flawless.",
+"trajectory_outcome": "Classification of the trajectory result. Must be EXACTLY one of these three values (case-sensitive, including underscores): 'success' (goal fully achieved with optimal execution quality), 'success_but_inefficient' (goal achieved but with unnecessary steps, redundant actions, or suboptimal approach), 'failure' (goal not achieved or task incomplete).",
+"improvement_suggestion": "A generalizable, context-complete principle for avoiding similar issues in future attempts. Must be self-contained and actionable without reference to this specific trajectory. Include: (1) the specific environment/system/domain name, (2) the triggering conditions or scenario when this applies, (3) the specific problem or pitfall to avoid, (4) the recommended solution or approach with clear rationale. Frame as reusable knowledge. Set to null if and only if trajectory_outcome is 'success'.",
+"retry_from_step": "Integer from 0 to N-1 identifying the earliest step where the root cause first manifested or where a corrected decision could alter the outcome. This represents the optimal restart point if given one opportunity to retry. Use 0 when the root cause traces to initial strategy selection or foundational assumptions. Set to null if trajectory_outcome is 'success' or if retry would not be beneficial."
+}
+```
+
+## Example
+
+**Scenario**: Solving the equation 3x² - 12x + 9 = 0
+
+**Example Output**:
+```json
+{
+"trajectory_summary": "The agent attempted to solve a quadratic equation by immediately applying the quadratic formula with a=3, b=-12, c=9. The calculation resulted in x = (12 ± √(144-108))/6 = (12 ± 6)/6, yielding x=3 or x=1. However, the agent failed to verify the solution and missed that the equation could be simplified first by factoring out 3, leading to a more elegant solution path.",
+"root_cause_analysis": "Why was the approach suboptimal? Because the agent jumped directly to the quadratic formula without checking for simplifications. Why skip simplification? Because it saw standard form ax²+bx+c=0 and immediately pattern-matched to 'use quadratic formula'. Why this pattern-matching? Because the agent treated the quadratic formula as a universal first-choice method rather than one tool among many. Root cause: Lack of strategic problem assessment - the agent optimized for immediate formula application rather than problem structure analysis, missing that all coefficients shared a common factor of 3.",
+"trajectory_outcome": "success_but_inefficient",
+"improvement_suggestion": "In mathematical problem-solving environments, always perform a structural analysis before applying solution methods: (1) check for common factors in all terms, (2) look for special patterns (perfect squares, difference of squares, sum/product relationships), (3) assess whether simplification reduces computational complexity. For quadratic equations specifically, factor out GCD first - this often reveals simpler factorizations or reduces calculation errors. Example: 3x²-12x+9=0 becomes 3(x²-4x+3)=0, then 3(x-1)(x-3)=0, directly yielding x=1 or x=3 without formula computation. Apply formula only when factoring is not immediately apparent.",
+"retry_from_step": 0
+}
+```
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld/prompts/self_correction.j2 b/trinity/common/workflows/envs/R3L-back/alfworld/prompts/self_correction.j2
new file mode 100644
index 0000000000..10c0cf6627
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld/prompts/self_correction.j2
@@ -0,0 +1,5 @@
+Your previous attempt encountered issues. Below is a reflection based on user and environment feedback:
+
+{{ report }}
+
+Apply the lessons learned from this reflection to avoid repeating the same mistakes. Do not mention or reference this guidance in your response.
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld/raft_workflow.py b/trinity/common/workflows/envs/R3L-back/alfworld/raft_workflow.py
new file mode 100644
index 0000000000..ec06ee4658
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld/raft_workflow.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+import copy
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("raft_baseline_alfworld_workflow")
+class RAFTBaselineAlfworldWorkflow(Workflow):
+ """
+ RAFT Baseline workflow for Alfworld environment.
+ Performs rollouts for Reward-rAnked FineTuning (RAFT): only rollouts that complete the task are kept as training experiences.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 25
+ self.max_tokens = 512
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+ f"Initializing RAFTAlfworldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ # Default experience for error cases
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.task = task
+ self.n = task.repeat_times
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT workflow and return experiences"""
+
+ env = utils.create_alfworld_environment(self.game_file_path)
+
+ if self.is_eval:
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ return [exp]
+ except Exception:
+ return [copy.deepcopy(self.default_exp)]
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[RAFT] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ if reward >= 1.0:
+ exp_lst.append(exp)
+ else:
+ exp_lst.append(copy.deepcopy(self.default_exp))
+ except Exception:
+ exp_lst.append(copy.deepcopy(self.default_exp))
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/alfworld/utils.py b/trinity/common/workflows/envs/R3L-back/alfworld/utils.py
new file mode 100644
index 0000000000..b665749c81
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/alfworld/utils.py
@@ -0,0 +1,696 @@
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+import torch
+from trinity.common.experience import Experience
+
+
+def first_rollout(self, env) -> tuple[List[Dict[str, str]], float, bool, int, bool]:
+ """Run a single rollout in Alfworld environment"""
+ observation, info = env.reset()
+ trajectory = []
+ action_history = [] # Track all actions taken
+
+ system_prompt = self.alfworld_system_template.render()
+ trajectory.append({"role": "system", "content": system_prompt})
+
+ default_reward = 0.0
+ done = False
+ reward = default_reward
+ valid_format = True
+ step = 0
+
+ # Extract task description from the initial observation (only once at the beginning)
+ task_description = extract_task_description(observation)
+
+ for step in range(self.max_env_steps):
+ # Extract admissible actions from info if available
+ admissible_actions = info.get("admissible_commands", []) if isinstance(info, dict) else []
+
+ trajectory.append(
+ {"role": "user", "content": format_observation(
+ current_observation=observation,
+ task_description=task_description,
+ current_step=step,
+ action_history=action_history,
+ admissible_actions=admissible_actions
+ )}
+ )
+
+ # Get model response
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+ # The tokenizer inside chat truncates internally, so terminate as soon as the length reaches the limit
+ return trajectory, default_reward, False, step + 1, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response components
+ think, action, error_msg = parse_response(response_text)
+ if error_msg is not None:
+ valid_format = False
+ observation = f"{error_msg}"
+ # Keep reward, done and info at their default (or previous) values
+ trajectory.append({"role": "user", "content": f"Feedback: {error_msg}"})
+ return trajectory, default_reward, False, step + 1, valid_format
+ else:
+ # Execute action in environment
+ observation, reward, done, info = env.step(action)
+ if action not in admissible_actions:
+ valid_format = False
+ observation = f"Invalid action '{action}' not in admissible actions."
+ trajectory.append({"role": "user", "content": f"Feedback: {observation}"})
+ return trajectory, default_reward, False, step + 1, valid_format
+
+ # Track successfully executed actions for history
+ if valid_format:
+ action_history.append(action)
+
+ # Check for consecutive action repetition (last 3 actions)
+ if len(action_history) >= 3 and all(
+ a == action_history[-1] for a in action_history[-3:]
+ ):
+ repeated_action = action_history[-1]
+ feedback = f"Repeated invalid action {repeated_action} multiple times, task failed"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ valid_format = False
+ return trajectory, default_reward, False, step + 1, valid_format
+
+ if done:
+ break
+
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+ f"Task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ else:
+ feedback = (
+ f"Task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ return trajectory, reward, False, step + 1, valid_format
+
+
+def second_rollout(
+ self,
+ env,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, int, bool]:
+ """
+ Performs rollout starting from a specific retry step, reusing previous responses.
+
+ Args:
+ env: The environment instance.
+ guidance_prompt: The pre-generated guidance from reflection.
+ first_trajectory: The full log of the initial attempt.
+ retry_step: The step to start retry from (0-based, 0 means from beginning).
+
+ Returns:
+ A tuple containing (distill_trajectory, second_trajectory, reward, done status,
+ step count, and format validity).
+ """
+
+ # Reset environment to start fresh
+ observation, info = env.reset()
+ trajectory = []
+ distill_trajectory = []
+ action_history = [] # Track all actions taken
+
+ # Prepare system prompts
+ original_system_prompt = self.alfworld_system_template.render()
+
+ default_reward = 0.0
+ done = False
+ reward = default_reward
+ valid_format = True
+
+ # Extract task description from the initial observation (only once at the beginning)
+ task_description = extract_task_description(observation)
+
+ # Copy responses from first trajectory up to retry_step
+ step = 0
+ if retry_step > 0:
+ # Add original system prompt only
+ trajectory.append({"role": "system", "content": original_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ # Replay first trajectory up to retry_step to restore environment state
+ first_step = 0
+ for msg in first_trajectory[1:]: # Skip system message
+ if msg["role"] == "user":
+ # This is an observation - copy it and continue
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+ elif msg["role"] == "assistant":
+ if first_step < retry_step:
+ # Copy the assistant response from first trajectory
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+
+ # Execute the action to restore environment state
+ think, action, _ = parse_response(msg["content"])
+ if think is not None and action is not None:
+ observation, reward, done, info = env.step(action)
+ action_history.append(action)
+ first_step += 1
+ step = first_step
+
+ if done:
+ # If environment finished during replay, no need to continue
+ return distill_trajectory, trajectory, reward, done, step, valid_format
+ else:
+ break
+
+ # Add guidance prompt as a separate system message before retry point
+ guidance_system_msg = {"role": "system",
+ "content": f"# Previous Attempt Analysis & Guidance\n{guidance_prompt}"}
+ trajectory.append(guidance_system_msg)
+ # Don't add guidance to distill_trajectory to keep it clean
+
+ else:
+ # Starting from beginning - add system prompt with guidance
+ merged_system_prompt = f"{original_system_prompt}\n\n# Previous Attempt Analysis & Guidance\n{guidance_prompt}"
+ trajectory.append({"role": "system", "content": merged_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ for step in range(step, self.max_env_steps):
+ # Extract admissible actions from info if available
+ admissible_actions = info.get("admissible_commands", []) if isinstance(info, dict) else []
+
+ formatted_obs = format_observation(
+ current_observation=observation,
+ task_description=task_description,
+ current_step=step,
+ action_history=action_history,
+ admissible_actions=admissible_actions
+ )
+
+ trajectory.append({"role": "user", "content": formatted_obs})
+ distill_trajectory.append({"role": "user", "content": formatted_obs})
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+ # chat() truncates internally via its tokenizer, so terminate as soon as the token count reaches the limit (20480 max length minus the 512-token response budget)
+ return distill_trajectory, trajectory, default_reward, False, step + 1, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response
+ think, action, error_msg = parse_response(response_text)
+ if error_msg is not None:
+ valid_format = False
+ observation = f"{error_msg}"
+ else:
+ # Execute action in environment
+ observation, reward, done, info = env.step(action)
+ if action not in admissible_actions:
+ valid_format = False
+ observation = f"Invalid action '{action}' not in admissible actions."
+
+ # Track successfully executed actions for history
+ if valid_format:
+ action_history.append(action)
+
+ # Check for consecutive action repetition (last 3 actions)
+ if len(action_history) >= 3 and all(
+ a == action_history[-1] for a in action_history[-3:]
+ ):
+ repeated_action = action_history[-1]
+ feedback = f"Repeated invalid action {repeated_action} multiple times, task failed"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ valid_format = False
+ return distill_trajectory, trajectory, default_reward, False, step + 1, valid_format
+
+ if done:
+ break
+
+ print(f"[Second Rollout] - reward: {reward}, steps: {step + 1}")
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+ f"Task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ else:
+ feedback = (
+ f"Task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+
+ return distill_trajectory, trajectory, reward, False, step + 1, valid_format
+
+
+def eval_alfworld(self) -> List[Experience]:
+ """Evaluate a single alfworld trajectory"""
+ try:
+ env = create_alfworld_environment(self.game_file_path)
+ trajectory, reward, done, steps, valid_format = first_rollout(
+ self, env
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ print(f"[Eval] First rollout - reward: {reward}, steps: {steps}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="evaluation"
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval",
+ experience_data=eval_record,
+ data_dir=self.eval_dir
+ )
+ except Exception as e:
+ print(f"[Eval] Single rollout failed during eval: {e}")
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ }
+ )
+ exp.reward = 0.0
+ return [exp]
+
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def extract_task_description(observation: str) -> str:
+ """
+ Extract task description from the initial observation.
+ The task description is typically in the format: "Your task is to: ..."
+
+ Args:
+ observation: Initial observation from environment
+
+ Returns:
+ Extracted task description string
+ """
+ # Look for pattern "Your task is to: "
+ match = re.search(r"Your task is to:\s*(.+?)(?:\n|$)", observation, re.IGNORECASE)
+ if match:
+ return match.group(1).strip()
+ # Fallback: return a portion of the observation
+ return observation.split('\n')[-1] if '\n' in observation else observation
+
+
+def format_observation(
+ current_observation: str,
+ task_description: str = "",
+ current_step: int = 0,
+ action_history: Optional[List[str]] = None,
+ admissible_actions: Optional[List[str]] = None,
+ history_length: int = 4
+):
+ """
+ Format observation string with task context and limited action history.
+
+ Args:
+ current_observation: Current observation from environment
+ task_description: Description of the task to complete
+ current_step: Current step number
+ action_history: List of all previous actions taken
+ admissible_actions: List of currently admissible actions
+ history_length: Maximum number of recent actions to display (default: 4)
+ """
+ if action_history is None:
+ action_history = []
+ if admissible_actions is None:
+ admissible_actions = []
+
+ # Check if this is the first step (no history)
+ if current_step == 0 or not action_history:
+ # First step - no history version
+ prompt = f"""You are an expert agent operating in the ALFRED Embodied Environment.
+Your current observation is: {current_observation}
+Your admissible actions of the current situation are: {admissible_actions}.
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for the current step and present it within <action> </action> tags.
+
+Format: <think>your reasoning process</think> <action>your chosen action</action>"""
+ else:
+ # Limit action history to most recent history_length items
+ recent_actions = action_history[-history_length:] if len(action_history) > history_length else action_history
+
+ # Format action history as a structured list
+ action_history_str = "\n".join([f" Step {current_step - len(recent_actions) + i}: {action}"
+ for i, action in enumerate(recent_actions)])
+
+ # Create formatted prompt with limited history
+ prompt = f"""You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description}
+Prior to this step, you have already taken {len(action_history)} step(s). Below are the most recent {len(recent_actions)} actions you took:
+{action_history_str}
+You are now at step {current_step} and your current observation is: {current_observation}
+Your admissible actions of the current situation are: {admissible_actions}.
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for the current step and present it within <action> </action> tags.
+
+Format: <think>your reasoning process</think> <action>your chosen action</action>"""
+
+ return prompt
+
+
+def parse_response(response):
+ """
+ Parse think and action components from response.
+ Returns (think, action, error_message) tuple.
+ - If successful: (think_content, action_content, None)
+ - If error: (None, None, error_message)
+ """
+ try:
+ # Use regex to extract all think and action components
+ think_pattern = r"<think>\s*(.*?)\s*</think>"
+ action_pattern = r"<action>\s*(.*?)\s*</action>"
+
+ think_matches = re.findall(think_pattern, response, re.DOTALL)
+ action_matches = re.findall(action_pattern, response, re.DOTALL)
+
+ # Check for multiple think tags
+ if len(think_matches) > 1:
+ return None, None, f"Multiple <think> tags found ({len(think_matches)}). Only one pair is allowed."
+
+ # Check for multiple action tags
+ if len(action_matches) > 1:
+ return None, None, f"Multiple <action> tags found ({len(action_matches)}). Only one pair is allowed."
+
+ # Check if tags are missing
+ if len(think_matches) == 0 and len(action_matches) == 0:
+ return None, None, "Invalid response format, missing valid <think> or <action> tags, please ensure you follow the output format strictly: <think>...</think> <action>...</action>"
+ elif len(think_matches) == 0:
+ return None, None, "Invalid response format, missing valid <think> tag, please ensure you follow the output format strictly: <think>...</think> <action>...</action>"
+ elif len(action_matches) == 0:
+ return None, None, "Invalid response format, missing valid <action> tag, please ensure you follow the output format strictly: <think>...</think> <action>...</action>"
+
+ think = think_matches[0].strip()
+ action = action_matches[0].strip()
+
+ return think, action, None
+ except Exception:
+ return None, None, "Unexpected error occurred while parsing response format."
+
+
+def create_alfworld_environment(game_file, max_episode_steps=25):
+ """
+ Create alfworld environment
+
+ Args:
+ game_file: Path to the game file
+ max_episode_steps: Maximum number of steps per episode (default: 25)
+ """
+ try:
+ import textworld
+ import textworld.gym
+ from alfworld.agents.environment.alfred_tw_env import (
+ AlfredDemangler,
+ AlfredExpert,
+ AlfredExpertType,
+ )
+
+ expert = AlfredExpert(expert_type=AlfredExpertType.HANDCODED)
+ request_infos = textworld.EnvInfos(
+ description=True, inventory=True, admissible_commands=True
+ )
+
+ env_id = textworld.gym.register_game(
+ game_file, request_infos,
+ max_episode_steps=max_episode_steps,
+ asynchronous=True,
+ wrappers=[AlfredDemangler(), expert]
+ )
+ env = textworld.gym.make(env_id)
+
+ return env
+ except ImportError as e:
+ raise ImportError(
+ f"Failed to import alfworld dependencies: {e}. "
+ "Please install alfworld following the instructions at https://github.com/alfworld/alfworld"
+ )
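+
+# Illustrative usage (sketch, assuming game_file points to a TextWorld/ALFRED game file):
+#   env = create_alfworld_environment(self.game_file_path)
+#   observation, info = env.reset()
+#   observation, reward, done, info = env.step(action)
+# which matches how the rollout functions above drive the environment.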
+
+
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+ Correctly formats the trajectory for reflection, including the system prompt
+ and numbering the user/assistant turns.
+ """
+ formatted_lines = []
+ # Use a counter to track user/assistant interaction turns
+ turn_counter = 0 # Counting starts from 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+ # The system prompt does not count as a step, but must be shown first as the rules
+ formatted_lines.append(f"**System Rules & Context:**\n{content}\n" + "=" * 30)
+ elif role == "user":
+ # Mark the start of a new turn
+ formatted_lines.append(f"\n**Step {turn_counter}**")
+ formatted_lines.append(f" - User Observation/Feedback:\n {content.strip()}")
+ elif role == "assistant":
+ # The assistant's thought and action
+ formatted_lines.append(f" - Assistant Thought & Action:\n {content.strip()}")
+ # Increment the turn counter after a complete user-assistant exchange
+ turn_counter += 1
+
+ return "\n".join(formatted_lines)
+
+
+def validate_reflect_report(report: Dict[str, Any], max_steps: int = None) -> tuple[bool, bool]:
+ """
+ Validates the structure and content of the reflection report
+ based on the new reflection.j2 schema.
+
+ Args:
+ report: The reflection report to validate
+ max_steps: Maximum number of steps in trajectory for retry_step bounds checking (optional)
+
+ Returns:
+ tuple[bool, bool]: (is_valid, is_perfect)
+ - is_valid: Whether the report structure is valid
+ - is_perfect: Whether the report indicates the trajectory is perfect (only meaningful if is_valid is True)
+ """
+ if (
+ not isinstance(report, dict)
+ or "trajectory_summary" not in report
+ or "root_cause_analysis" not in report
+ or "trajectory_outcome" not in report
+ ):
+ print("Validation failed: Report is not a dict or missing keys.")
+ return False, False
+
+ outcome = report["trajectory_outcome"]
+ analysis = report["root_cause_analysis"]
+
+ if outcome == "success":
+ # For a 'success' outcome, only the summary is needed; no flaw analysis is required
+ print("success report validation successful.")
+ return True, True
+
+ elif outcome in ["success_but_inefficient", "failure"]:
+ # For non-optimal outcomes, validate flaw_analysis structure
+ improvement_suggestion = report.get("improvement_suggestion", None)
+ retry_from_step = report.get("retry_from_step", None)
+
+ if improvement_suggestion is None or retry_from_step is None:
+ print("Validation failed: Missing 'improvement_suggestion' or 'retry_from_step'.")
+ return False, False
+
+ # check retry from step
+ try:
+ retry_from_step = int(retry_from_step)
+ except (ValueError, TypeError):
+ print(f"Validation failed: 'retry_from_step' must be an integer. Got: {retry_from_step}")
+ return False, False
+ if retry_from_step < 0:
+ print(f"Validation failed: 'retry_from_step' must be a non-negative integer. Got: {retry_from_step}")
+ return False, False
+ # Check trajectory bounds if max_steps is provided
+ if max_steps is not None:
+ if retry_from_step >= max_steps:
+ print(
+ f"Validation failed: 'retry_from_step' ({retry_from_step}) exceeds trajectory bounds (0 to {max_steps - 1}).")
+ return False, False
+ print(f"{outcome} report validation successful.")
+ return True, False
+
+
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Converts a validated reflection report into a structured, actionable
+ guidance prompt for the agent's second attempt. This prompt is framed
+ as an internal directive to ensure the model's output is clean for SFT.
+ """
+ # Convert the report dictionary to a formatted string
+ report_str = json.dumps(report, indent=2, ensure_ascii=False)
+
+ # Load and render template
+ jinja_env = _get_jinja_env()
+ template = jinja_env.get_template("self_correction.j2")
+
+ return template.render(report=report_str)
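+
+# Typical flow (sketch): a validated reflection report is turned into guidance and fed
+# into the retry, e.g.
+#   guidance = reflect_report_to_guidance_prompt(report)
+#   distill_traj, second_traj, reward, done, steps, ok = second_rollout(self, env, guidance, first_traj, retry_step)
+# where self, env, first_traj and retry_step come from the surrounding workflow.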
+
+
+def save_experience_data(
+ task_id: str,
+ experience_data: Dict,
+ data_dir: str
+) -> str:
+ """
+ Save experience data including trajectory, rewards, and steps to a JSON file.
+
+ Args:
+ task_id: Unique identifier for the task
+ experience_data: Dictionary containing experience information
+ data_dir: Directory to save the data
+
+ Returns:
+ Path to the saved file
+ """
+ os.makedirs(data_dir, exist_ok=True)
+
+ # Add timestamp for uniqueness
+ filename = f"{task_id}.json"
+ filepath = os.path.join(data_dir, filename)
+
+ # Ensure experience_data is JSON serializable
+ serializable_data = {}
+ for key, value in experience_data.items():
+ if isinstance(value, torch.Tensor):
+ serializable_data[key] = value.tolist()
+ elif hasattr(value, '__dict__'):
+ # For complex objects, convert to dict representation
+ serializable_data[key] = str(value)
+ else:
+ serializable_data[key] = value
+
+ # Add metadata
+ serializable_data["saved_at"] = datetime.now().isoformat()
+ serializable_data["task_id"] = task_id
+
+ try:
+ with open(filepath, 'w', encoding='utf-8') as f:
+ json.dump(serializable_data, f, indent=2, ensure_ascii=False)
+ # print(f"Experience data saved to: {filepath}")
+ return filepath
+ except Exception as e:
+ # print(f"Failed to save experience data: {e}")
+ return ""
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ steps: int,
+ success: bool,
+ attempt_type: str = "first",
+ reflection_data: Optional[Dict] = None,
+ additional_metrics: Optional[Dict] = None
+) -> Dict:
+ """
+ Create a structured experience record for saving.
+
+ Args:
+ task_id: Unique identifier for the task
+ trajectory: List of conversation messages
+ reward: Final reward received
+ steps: Number of steps taken
+ success: Whether the task was completed successfully
+ attempt_type: Type of attempt ("first", "second", "reflect")
+ reflection_data: Optional reflection analysis data
+ additional_metrics: Additional metrics to record
+
+ Returns:
+ Dictionary containing structured experience data
+ """
+ experience_record = {
+ "task_id": task_id,
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "metrics": {
+ "reward": reward,
+ "steps": steps,
+ "success": success,
+ "trajectory_length": len(trajectory)
+ },
+ "created_at": datetime.now().isoformat()
+ }
+
+ if reflection_data:
+ experience_record["reflection"] = reflection_data
+
+ if additional_metrics:
+ experience_record["metrics"].update(additional_metrics)
+
+ return experience_record
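+
+# Illustrative usage (sketch): records built by create_experience_record are what
+# save_experience_data persists; "t1" and traj below are placeholder values only.
+#   record = create_experience_record(task_id="t1", trajectory=traj, reward=1.0, steps=7,
+#                                     success=True, attempt_type="second")
+#   save_experience_data(task_id="t1_second", experience_data=record, data_dir=self.train_dir)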
diff --git a/trinity/common/workflows/envs/R3L-back/countdown/R3L_workflow.py b/trinity/common/workflows/envs/R3L-back/countdown/R3L_workflow.py
new file mode 100644
index 0000000000..25c03921b0
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/countdown/R3L_workflow.py
@@ -0,0 +1,354 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.countdown import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("R3L_countdown_workflow")
+class R3LCountdownWorkflow(Workflow):
+ """
+ R3L workflow for Countdown mathematical problem solving
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+
+ self.whether_save_data = False
+ # Create data directories
+ self.data_dir = f"R3L_countdown_data"
+ self.eval_dir = os.path.join(self.data_dir, "eval")
+ self.train_dir = os.path.join(self.data_dir, "train")
+
+ os.makedirs(self.eval_dir, exist_ok=True)
+ os.makedirs(self.train_dir, exist_ok=True)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.countdown_system_template = self.jinja_env.get_template("countdown_system.j2")
+ self.reflection_template = self.jinja_env.get_template("reflection.j2")
+
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ self.default_second_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "second_success": 0.0,
+ "second_reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ print(
+ f"Initializing R3LCountdownWorkflow with experience learning, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract numbers and target from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Countdown format: direct access to nums and target fields
+ self.numbers = raw_task.get("nums", [])
+ self.target = raw_task.get("target", 0)
+ else:
+ self.numbers = []
+ self.target = 0
+
+ def get_reflect(self, trajectory: List[Dict[str, str]]) -> tuple[
+ Optional[Dict[str, Any]], Optional[str], Optional[Any]]:
+ """
+ Generates a comprehensive reflection report using a single, unified self-interrogation prompt.
+ """
+ # Format trajectory for LLM reading
+ formatted_trajectory = utils.format_trajectory_for_reflection(trajectory)
+
+ # Use Jinja2 template to render reflection prompt
+ reflect_prompt = self.reflection_template.render()
+
+ # Call model and parse results
+ try:
+ responses = self.model.chat(
+ [
+ {"role": "system", "content": reflect_prompt},
+ {"role": "user", "content": "Here is last attempt trajectory log: \n\n" + formatted_trajectory + "\n\nPlease output in the specified JSON format."}
+ ],
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_reflect_tokens,
+ )
+ reflection_text = responses[0].response_text.strip()
+
+ # Find first '{' and last '}'
+ first_brace = reflection_text.find('{')
+ last_brace = reflection_text.rfind('}')
+
+ if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
+ json_str = reflection_text[first_brace:last_brace + 1]
+ else:
+ json_str = reflection_text
+
+ reflection_data = json.loads(json_str)
+ return reflection_data, reflection_text, responses[0]
+
+ except Exception as e:
+ print(f"[R3L] Reflection failed - Error: {str(e)}")
+ return None, None, None
+
+ def run(self) -> List[Experience]:
+ """Run the R3L countdown workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_countdown(self)
+
+ # Generate unique task ID
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+
+ exp_lst = []
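+ # Each iteration below: (1) a first rollout, (2) a reflection pass over that trajectory,
+ # and (3) either a guided second rollout (valid, non-perfect reflection) or an extra plain
+ # rollout (invalid or perfect reflection). Reflection and retry experiences are only kept
+ # when the retry reaches full reward and improves on the first attempt.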
+ for i in range(self.n // 2): # Half for rollout, half for reflection + retry
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[R3L Countdown] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ # Set eid
+ exp.eid.task = str(self.task.task_id) + f"_explore"
+ exp_run_id = len(exp_lst) + self.run_id_base
+ exp.eid.run = exp_run_id
+ exp_lst.append(exp)
+
+ if self.whether_save_data:
+ # Save first attempt experience data
+ first_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=trajectory,
+ reward=reward,
+ success=success,
+ predicted_answer=predicted_answer,
+ ground_truth=ground_truth,
+ attempt_type="first"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_first",
+ experience_data=first_record,
+ data_dir=self.train_dir
+ )
+
+ # Reflect on first attempt
+ print(f"[R3L] Starting reflection on first attempt (reward: {reward})...")
+ reflect_checklist, reflection_text, reflect_exp = self.get_reflect(trajectory)
+ is_valid, is_perfect = utils.validate_reflect_report(reflect_checklist, attempts)
+
+ if reflect_checklist is None:
+ print(f"[R3L] Reflection failed - No valid reflection data generated")
+ elif is_valid and not is_perfect:
+ print(f"[R3L] Reflection successful - Valid reflection generated")
+ elif is_perfect:
+ print(f"[R3L] Reflection indicates perfect first attempt - No retry needed")
+ elif not is_valid:
+ print(f"[R3L] Reflection validation failed - Invalid reflection data")
+
+ if not is_valid or is_perfect:
+ print(f"[R3L] Skip second rollout due to invalid ({not is_valid}) or perfect ({is_perfect}) reflection.")
+ # If first attempt reward is 1.0 and reflection gives perfect, record reflection exp
+ if reward >= 1.0 and is_perfect and reflect_exp is not None:
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Do another rollout to ensure the batch has enough data
+ print(f"[R3L] Performing additional rollout...")
+ try:
+ retry_trajectory, retry_reward, retry_success, retry_predicted_answer, retry_ground_truth, retry_attempts = utils.first_rollout(self)
+ print(f"[R3L] Additional rollout completed - reward: {retry_reward}, attempts: {retry_attempts}")
+
+ retry_exp = self.model.convert_messages_to_experience(retry_trajectory[:-1])
+ retry_exp.reward = retry_reward
+ retry_exp.metrics = {
+ "success": 1.0 if retry_success else 0.0,
+ "reward": retry_reward,
+ "attempts": retry_attempts,
+ }
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + f"_explore"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ if self.whether_save_data:
+ # Save retry attempt experience data
+ retry_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=retry_trajectory,
+ reward=retry_reward,
+ success=retry_success,
+ predicted_answer=retry_predicted_answer,
+ ground_truth=retry_ground_truth,
+ attempt_type="retry_after_invalid_reflection"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_retry",
+ experience_data=retry_record,
+ data_dir=self.train_dir
+ )
+ except Exception as e:
+ print(f"[R3L] Retry rollout after invalid reflection failed - Error: {e}")
+
+ else:
+ print("[R3L] Valid reflection obtained, proceeding to second rollout...")
+ guidance_prompt = utils.reflect_report_to_guidance_prompt(reflect_checklist)
+ # Extract retry_step from validated reflection report
+ retry_step = reflect_checklist["analysis"]["retry_strategy"]["retry_step"] if "retry_strategy" in reflect_checklist.get("analysis", {}) else 0
+
+ try:
+ (
+ distill_trajectory,
+ second_trajectory,
+ second_reward,
+ second_success,
+ second_predicted_answer,
+ second_ground_truth,
+ second_attempts,
+ ) = utils.second_rollout(
+ self, guidance_prompt, trajectory, retry_step
+ )
+ print(f"[R3L] Second rollout - reward: {second_reward}, attempts: {second_attempts}, improve: {second_reward > reward}")
+ second_exp = self.model.convert_messages_to_experience(distill_trajectory[:-1])
+
+ second_exp.reward = second_reward
+ second_exp.metrics = {
+ "second_success": 1.0 if second_success else 0.0,
+ "second_reward": second_reward,
+ "second_attempts": second_attempts,
+ "second_improve": 1.0 if second_reward > reward else 0.0,
+ "second_reward_diff": second_reward - reward,
+ }
+ # Set eid
+ second_exp.eid.task = str(self.task.task_id) + f"_explore"
+ second_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(second_exp)
+
+ if self.whether_save_data:
+ # Save second attempt experience data
+ second_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=second_trajectory,
+ reward=second_reward,
+ success=second_success,
+ predicted_answer=second_predicted_answer,
+ ground_truth=second_ground_truth,
+ attempt_type="second",
+ additional_metrics={
+ "first_reward": reward,
+ "improvement": second_reward > reward,
+ "reward_difference": second_reward - reward,
+ }
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_second",
+ experience_data=second_record,
+ data_dir=self.train_dir
+ )
+
+ # If second attempt score is higher than first, record reflection and retry data
+ if second_reward > reward and second_reward >= 1.0:
+ print(f"[R3L] Second attempt successful improvement - Recording reflection and retry experiences")
+ print(f"[R3L] Reward improvement: {reward} -> {second_reward} (+{second_reward - reward:.2f})")
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Convert retry data to exp
+ retry_exp = self.model.convert_messages_to_experience(second_trajectory[:-1])
+
+ retry_exp.reward = 1.0
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + f"_retry_{i}"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ print("[R3L] Reflection and retry led to improvement, recording both...")
+ elif second_reward <= reward:
+ print(f"[R3L] Second attempt did not improve - First reward: {reward}, Second reward: {second_reward}")
+ else:
+ print(f"[R3L] Second attempt improved but below threshold - Reward: {second_reward} (need >= 1.0)")
+ except Exception as e:
+ print(f"[R3L] Second rollout failed - Error: {str(e)}")
+ except Exception as e:
+ print(f"[R3L] Rollout iteration {i} failed - Error: {str(e)}")
+
+ # Print summary statistics
+ print(f"\n[R3L Summary] Generated {len(exp_lst)} experiences")
+ total_reward = sum(exp.reward for exp in exp_lst)
+ avg_reward = total_reward / len(exp_lst) if exp_lst else 0.0
+ print(f"[R3L Summary] Total reward: {total_reward:.2f}, Average reward: {avg_reward:.2f}")
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/countdown/__init__.py b/trinity/common/workflows/envs/R3L-back/countdown/__init__.py
new file mode 100644
index 0000000000..5c3d1be19a
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/countdown/__init__.py
@@ -0,0 +1,2 @@
+# -*- coding: utf-8 -*-
+"""Countdown R3L workflows"""
diff --git a/trinity/common/workflows/envs/R3L-back/countdown/grpo_workflow.py b/trinity/common/workflows/envs/R3L-back/countdown/grpo_workflow.py
new file mode 100644
index 0000000000..62e87c8833
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/countdown/grpo_workflow.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.countdown import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("grpo_baseline_countdown_workflow")
+class GRPOBaselineCountdownWorkflow(Workflow):
+ """
+ GRPO Baseline Workflow for Countdown environment.
+ Performs simple rollouts without reflection or learning from experience.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.countdown_system_template = self.jinja_env.get_template("countdown_system.j2")
+
+ print(
+ f"Initializing GRPOBaselineCountdownWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract numbers and target from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Countdown format: direct access to nums and target fields
+ self.numbers = raw_task.get("nums", [])
+ self.target = raw_task.get("target", 0)
+ else:
+ self.numbers = []
+ self.target = 0
+
+ def run(self) -> List[Experience]:
+ """Run the GRPO baseline workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_countdown(self)
+
+ # Multiple rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[GRPO Countdown] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/countdown/opmd_workflow.py b/trinity/common/workflows/envs/R3L-back/countdown/opmd_workflow.py
new file mode 100644
index 0000000000..0d13287070
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/countdown/opmd_workflow.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.countdown import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("opmd_baseline_countdown_workflow")
+class OPMDBaselineCountdownWorkflow(Workflow):
+ """
+ OPMD Baseline workflow for Countdown mathematical problem solving.
+ Performs simple rollouts for the OPMD baseline without reflection or retry.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.countdown_system_template = self.jinja_env.get_template("countdown_system.j2")
+
+ print(
+ f"Initializing OPMDCountdownWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract numbers and target from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Countdown format: direct access to nums and target fields
+ self.numbers = raw_task.get("nums", [])
+ self.target = raw_task.get("target", 0)
+ else:
+ self.numbers = []
+ self.target = 0
+
+ def run(self) -> List[Experience]:
+ """Run the OPMD workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_countdown(self)
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[OPMD Countdown] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/countdown/prompts/countdown_system.j2 b/trinity/common/workflows/envs/R3L-back/countdown/prompts/countdown_system.j2
new file mode 100644
index 0000000000..f497326641
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/countdown/prompts/countdown_system.j2
@@ -0,0 +1,27 @@
+You are a mathematical problem solver. Your task is to create equations using given numbers to reach a target value.
+
+## Response Format:
+You MUST use this exact format for every response. Both tags are REQUIRED in sequential order:
+
+<think>your analytical reasoning and thought process</think>
+<answer>your final equation</answer>
+
+## Task Description:
+Given a set of numbers and a target value, you need to create an equation using basic arithmetic operations (+, -, *, /) where:
+- Each given number can only be used once
+- You can use parentheses to control order of operations
+- The equation must equal the target value
+
+## Example:
+For numbers [44, 19, 35] and target 98:
+<think>Let me try different combinations:
+- 44 + 19 + 35 = 98 ✓ This works!</think>
+
+<answer>(44 + 19 + 35)</answer>
+
+## Critical Rules:
+- Use each number exactly once
+- Only use the four basic operations: +, -, *, /
+- Your answer must be in the <answer> </answer> tags
+- Show your reasoning process in the <think> </think> tags
+- The equation in <answer> must be valid and evaluate to the target
diff --git a/trinity/common/workflows/envs/R3L-back/countdown/prompts/reflection.j2 b/trinity/common/workflows/envs/R3L-back/countdown/prompts/reflection.j2
new file mode 100644
index 0000000000..4517ba643c
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/countdown/prompts/reflection.j2
@@ -0,0 +1,138 @@
+# Metacognitive Analyst AI Prompt for Countdown Task
+
+You are a Metacognitive Analyst AI. Your core mission is to analyze a "Trajectory Log" containing mathematical problem-solving attempts. Your goal is to extract deep insights, summarize lessons learned, and formulate actionable principles for future improvement.
+
+You will receive a trajectory log showing attempts to solve countdown equations. Your final output must be a structurally complete JSON object.
+
+## Your Internal Monologue & Analysis Protocol (MANDATORY)
+
+You will now begin your structured self-interrogation. Your analysis process must first review the trajectory globally before focusing on key points.
+
+### Part 1: Global Review & Analysis
+
+First, you must understand the entire trajectory from a macro perspective, especially the feedback about equation correctness.
+
+**Question 1.1: Conduct a Panoramic Trajectory Analysis**
+Read through the entire trajectory log and summarize in one or two sentences what the overall strategy was and what result it ultimately led to.
+
+**Question 1.2: Identify Key Issues**
+Based on your global understanding, identify the main problems or inefficiencies in the trajectory. What were the key mistakes (wrong numbers used, incorrect operations, format errors)? If the execution was flawless, this is None.
+
+### Part 2: Deep Analysis of Key Issues
+
+Next, you will conduct this deep analysis if and only if key issues were identified in Part 1.
+
+**Question 2.1: Diagnose the Primary Flaw**
+What was the fundamental nature of the primary flaw? Categorize it into ONE of the following:
+- Strategy Flaw: The overall mathematical approach was misguided.
+- Reasoning Flaw: The arithmetic calculation was incorrect.
+- Execution Flaw: The equation format was wrong or numbers were misused.
+- Knowledge Gap: Lacked critical understanding of arithmetic operations.
+- Inefficiency: The goal was achieved, but via an overly complex equation.
+- Invalid Format: The response violated the required format.
+
+**Question 2.2: Uncover the Root Cause**
+Conduct a flexible root cause inquiry to uncover why the equation failed. Was it:
+- Using wrong numbers?
+- Using numbers multiple times?
+- Incorrect arithmetic operations?
+- Wrong order of operations?
+- Format issues?
+
+**Question 2.3: Formulate Better Approach**
+What would have been the optimal mathematical strategy for this countdown problem?
+What arithmetic operations or combination would lead to the correct answer?
+
+### Part 3: Synthesis, Verdict, and Lessons Learned
+
+Finally, after completing all the above analysis, you will synthesize your findings and render a final judgment.
+
+**Question 3.1: Formulate a Corrective Principle**
+
+Based on the analysis, formulate an impactful Corrective Principle.
+
+**CRITICAL REQUIREMENTS for Principle Formulation:**
+
+1. **Context Completeness**: The principle must be self-contained for countdown mathematical tasks.
+ - ❌ **BAD**: "Check if numbers are correct"
+ - ✅ **GOOD**: "In countdown tasks, before finalizing an equation, verify that each given number is used exactly once and no extra numbers are introduced"
+
+2. **Domain Specificity**: Clearly specify this applies to countdown equation problems.
+
+3. **Actionable Structure**: The principle should include:
+ - The specific context (countdown equation task)
+ - Clear trigger conditions
+ - The recommended action or approach
+ - The reasoning and expected benefits
+
+4. **Independence Test**: The principle should be meaningful even if read in isolation.
+
+**Question 3.2: Render the Final Verdict**
+
+Classify the outcome into one of the following categories:
+
+- **OPTIMAL**: Correctly solved the equation; perfect execution.
+- **SUBOPTIMAL_SUCCESS**: Got the right answer, but used an overly complex approach.
+- **PARTIAL**: Made progress but the final equation was incorrect.
+- **INEFFECTIVE**: Completely failed to produce a valid equation.
+
+## Final Output Format (Strictly Adhere to the Unified Schema)
+
+Your final output must strictly contain the following two parts: Part One is your detailed analysis process (in text form), and Part Two is the summary JSON report.
+
+### Part One: Detailed Analysis Process
+
+You must answer all questions from the protocol one by one here, showing your complete chain of thought.
+
+**1. Global Review & Analysis**
+- 1.1 Panoramic Trajectory Analysis: Fill in your macro summary of the trajectory here
+- 1.2 Key Issues Identification: Fill in the identified key issues and the reasoning here
+
+**2. Deep Analysis of Key Issues**
+- 2.1 Primary Flaw Diagnosis: Fill in the flaw's classification here
+- 2.2 Root Cause: Fill in the result of the root cause inquiry here
+- 2.3 Better Approach: Fill in the analysis of the optimal strategy and its expected benefits here
+
+**3. Synthesis, Verdict, and Lessons Learned**
+- 3.1 Corrective Principle: Fill in the corrective principle you formulated here (MUST meet all requirements)
+- 3.2 Final Verdict: Fill in the final classification verdict you rendered here
+
+### Part Two: Final JSON Report
+
+After completing the detailed analysis above, synthesize all conclusions and populate the following JSON structure (the ```json code fence prefix is mandatory):
+
+```json
+{
+ "outcome_assessment": "OPTIMAL | SUBOPTIMAL_SUCCESS | PARTIAL | INEFFECTIVE",
+ "analysis": {
+ "summary": "Summary of the mathematical approach, outcome, and core insight.",
+ "flaw_analysis": {
+ "diagnosis": {
+ "category": "Strategy Flaw | Reasoning Flaw | Execution Flaw | Knowledge Gap | Inefficiency | null",
+ "root_cause": "The core mathematical error or problematic approach. Can be null if outcome_assessment is OPTIMAL."
+ },
+ "better_approach": {
+ "strategy": "The optimal mathematical strategy that should have been used. Can be null if outcome_assessment is OPTIMAL.",
+ "key_differences": "How this approach differs from the failed attempt. Can be null if outcome_assessment is OPTIMAL.",
+ "projected_benefits": "The expected positive outcomes from using the better approach. Can be null if outcome_assessment is OPTIMAL."
+ }
+ },
+ "lessons_learned": {
+ "corrective_principle": "A self-contained principle for solving countdown equations more effectively. Can be null if outcome_assessment is OPTIMAL.",
+ "revised_action_plan": "The improved mathematical approach based on the corrective principle. Can be null if outcome_assessment is OPTIMAL."
+ },
+ "retry_strategy": {
+ "retry_step": "For countdown tasks, this is typically 0 (restart from beginning) since it's a single-step problem. Can be null if outcome_assessment is OPTIMAL.",
+ "retry_rationale": "Explanation of why restarting is needed"
+ }
+ }
+}
+```
+
+## Quality Check for Corrective Principles
+
+Before finalizing, verify your corrective principle against this checklist:
+- [ ] Can someone unfamiliar with this specific problem understand and apply this principle?
+- [ ] Does it specify it applies to countdown equation tasks?
+- [ ] Does it include clear, actionable guidance?
+- [ ] Is it specific enough to be actionable but general enough to be reusable?
diff --git a/trinity/common/workflows/envs/R3L-back/countdown/prompts/self_correction.j2 b/trinity/common/workflows/envs/R3L-back/countdown/prompts/self_correction.j2
new file mode 100644
index 0000000000..fe7d377a39
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/countdown/prompts/self_correction.j2
@@ -0,0 +1,15 @@
+# Previous Attempt Analysis & Guidance
+
+Based on analysis of a previous attempt at a similar problem, here is guidance to help you solve this problem more effectively:
+
+{{ report }}
+
+## Instructions for This Attempt
+
+- Apply the insights and strategies mentioned above
+- Be especially careful about the identified error types
+- Follow the recommended approach if applicable
+- Double-check your work at each step
+- Remember: This guidance is based on common pitfalls, use it to inform your strategy
+
+**Important**: Present your solution naturally, showing your reasoning process. Do not explicitly reference this guidance or mention it's a retry.
diff --git a/trinity/common/workflows/envs/R3L-back/countdown/raft_workflow.py b/trinity/common/workflows/envs/R3L-back/countdown/raft_workflow.py
new file mode 100644
index 0000000000..32a973d78b
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/countdown/raft_workflow.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+import copy
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.countdown import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("raft_baseline_countdown_workflow")
+class RAFTBaselineCountdownWorkflow(Workflow):
+ """
+ RAFT Baseline workflow for Countdown environment.
+ Performs simple rollouts for RAFT (Reward rAnked FineTuning) baseline training.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.countdown_system_template = self.jinja_env.get_template("countdown_system.j2")
+
+ print(
+ f"Initializing RAFTBaselineCountdownWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ # Default experience for error cases
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract numbers and target from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Countdown format: direct access to nums and target fields
+ self.numbers = raw_task.get("nums", [])
+ self.target = raw_task.get("target", 0)
+ else:
+ self.numbers = []
+ self.target = 0
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT workflow and return experiences"""
+
+ if self.is_eval:
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ return [exp]
+ except Exception:
+ return [copy.deepcopy(self.default_exp)]
+
+ # Multiple rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[RAFT Countdown] Rollout {i} - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/countdown/utils.py b/trinity/common/workflows/envs/R3L-back/countdown/utils.py
new file mode 100644
index 0000000000..f98f7e3c86
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/countdown/utils.py
@@ -0,0 +1,583 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.utils.eval_utils import evaluate_equation, validate_equation
+
+
+def first_rollout(self) -> tuple[List[Dict[str, str]], float, bool, str, str, int]:
+ """Run countdown problem solving with multiple attempts (max 3 attempts) using multi-round interaction"""
+ trajectory = []
+ attempt_history = [] # Track attempt history for limited history display
+
+ final_reward = 0.0
+ final_success = False
+ final_predicted_answer = ""
+ attempt_count = 0
+
+ # Try up to 3 attempts
+ for attempt in range(self.max_attempts):
+ attempt_count = attempt + 1
+
+ # Format user prompt with history (limited to history_length)
+ user_prompt = format_countdown_prompt(
+ numbers=self.numbers,
+ target=self.target,
+ current_step=attempt,
+ attempt_history=attempt_history,
+ history_length=getattr(self, 'history_length', 4)
+ )
+ trajectory.append({"role": "user", "content": user_prompt})
+
+ # Get model response
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 4096:
+ # chat() truncates internally via its tokenizer, so terminate as soon as the token count reaches the limit (20480 max length minus the 4096-token response budget)
+ return trajectory, final_reward, final_success, final_predicted_answer, str(self.target), attempt_count
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse think and answer
+ think, predicted_answer = parse_response(response_text)
+
+ if think is None or predicted_answer is None:
+ # Invalid format
+ feedback = "Invalid response format. Please ensure you provide both ... and ... tags."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ # Record this failed attempt in history
+ attempt_history.append({
+ "equation": "Invalid format",
+ "feedback": feedback
+ })
+ continue
+
+ # Verify answer
+ is_correct = countdown_verify(predicted_answer, self.numbers, self.target)
+
+ if is_correct:
+ final_reward = 1.0
+ final_success = True
+ final_predicted_answer = predicted_answer
+ feedback = f"Correct! Your equation {predicted_answer} successfully equals {self.target} using the numbers {self.numbers}."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ break
+ else:
+ # Wrong answer
+ if attempt < self.max_attempts - 1:
+ feedback = f"Incorrect. Your equation {predicted_answer} does not work. Please try again."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ else:
+ # Last attempt
+ feedback = f"Incorrect. Your equation {predicted_answer} does not match the target {self.target}. Maximum attempts reached."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ final_predicted_answer = predicted_answer
+
+ # Record this failed attempt in history
+ attempt_history.append({
+ "equation": predicted_answer,
+ "feedback": feedback
+ })
+
+ return trajectory, final_reward, final_success, final_predicted_answer, str(self.target), attempt_count
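+
+# Illustrative return values (sketch): a first-try success for numbers [44, 19, 35] and
+# target 98 would yield roughly (trajectory, 1.0, True, "(44 + 19 + 35)", "98", 1),
+# while three failed attempts end with reward 0.0, success False and attempt_count 3.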
+
+
+def second_rollout(
+ self,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, str, str, int]:
+ """
+ Performs rollout with guidance from reflection.
+ For countdown problems, we typically start from the beginning with guidance.
+ """
+ trajectory = []
+ distill_trajectory = []
+ attempt_history = [] # Track attempt history for limited history display
+
+ final_reward = 0.0
+ final_success = False
+ final_predicted_answer = ""
+ attempt_count = 0
+
+ # Try up to 3 attempts
+ for attempt in range(self.max_attempts):
+ attempt_count = attempt + 1
+
+ # Format user prompt with history and guidance
+ if attempt == 0:
+ # First attempt includes guidance
+ user_prompt = format_countdown_prompt_with_guidance(
+ numbers=self.numbers,
+ target=self.target,
+ current_step=attempt,
+ attempt_history=attempt_history,
+ guidance_prompt=guidance_prompt,
+ history_length=getattr(self, 'history_length', 4)
+ )
+ # For distill trajectory, use prompt without guidance
+ distill_user_prompt = format_countdown_prompt(
+ numbers=self.numbers,
+ target=self.target,
+ current_step=attempt,
+ attempt_history=attempt_history,
+ history_length=getattr(self, 'history_length', 4)
+ )
+ else:
+ # Subsequent attempts don't repeat guidance
+ user_prompt = format_countdown_prompt(
+ numbers=self.numbers,
+ target=self.target,
+ current_step=attempt,
+ attempt_history=attempt_history,
+ history_length=getattr(self, 'history_length', 4)
+ )
+ distill_user_prompt = user_prompt
+
+ trajectory.append({"role": "user", "content": user_prompt})
+ distill_trajectory.append({"role": "user", "content": distill_user_prompt})
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 4096:
+ # chat() truncates internally via its tokenizer, so terminate as soon as the token count reaches the limit (20480 max length minus the 4096-token response budget)
+ return distill_trajectory, trajectory, final_reward, final_success, final_predicted_answer, str(self.target), attempt_count
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse think and answer
+ think, predicted_answer = parse_response(response_text)
+
+ if think is None or predicted_answer is None:
+ # Invalid format
+ feedback = "Invalid response format. Please ensure you provide both ... and ... tags."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ # Record this failed attempt in history
+ attempt_history.append({
+ "equation": "Invalid format",
+ "feedback": feedback
+ })
+ continue
+
+ # Verify answer
+ is_correct = countdown_verify(predicted_answer, self.numbers, self.target)
+
+ if is_correct:
+ final_reward = 1.0
+ final_success = True
+ final_predicted_answer = predicted_answer
+ feedback = f"Correct! Your equation {predicted_answer} successfully equals {self.target}."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ break
+ else:
+ # Wrong answer
+ if attempt < self.max_attempts - 1:
+ feedback = f"Incorrect. Your equation {predicted_answer} does not work. Please try again."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ else:
+ # Last attempt
+ feedback = f"Incorrect. Your equation {predicted_answer} does not match the target {self.target}. Maximum attempts reached."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ final_predicted_answer = predicted_answer
+
+ # Record this failed attempt in history
+ attempt_history.append({
+ "equation": predicted_answer,
+ "feedback": feedback
+ })
+
+ return distill_trajectory, trajectory, final_reward, final_success, final_predicted_answer, str(self.target), attempt_count
+
+
+def eval_countdown(self) -> List[Experience]:
+ """Evaluate a single countdown problem"""
+ print("[R3L Countdown Eval] Starting evaluation...")
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = first_rollout(self)
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ print(f"[R3L Countdown Eval] Completed - Reward: {reward}, Success: {success}, Attempts: {attempts}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ success=success,
+ predicted_answer=predicted_answer,
+ ground_truth=ground_truth,
+ attempt_type="evaluation"
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval",
+ experience_data=eval_record,
+ data_dir=self.eval_dir
+ )
+ except Exception as e:
+ print(f"[R3L Countdown Eval] Evaluation failed - Error: {str(e)}")
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ }
+ )
+ exp.reward = 0.0
+ return [exp]
+
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def parse_response(response: str) -> Tuple[Optional[str], Optional[str]]:
+ """Parse think and answer from countdown response"""
+ try:
+ # Extract think section
+ think_pattern = r"\s*(.*?)\s*"
+ think_match = re.search(think_pattern, response, re.DOTALL)
+ think = think_match.group(1).strip() if think_match else None
+
+ # Extract answer from tags
+ answer_pattern = r"\s*(.*?)\s*"
+ answer_match = re.search(answer_pattern, response, re.DOTALL | re.IGNORECASE)
+ if answer_match:
+ answer = answer_match.group(1).strip()
+ else:
+ answer = None
+
+ return think, answer
+ except Exception as e:
+ print(f"Error parsing response: {e}")
+ return None, None
+
+
+def countdown_verify(predicted_answer: str, numbers: List[int], target: int) -> bool:
+ """
+ Verify if the predicted countdown equation is correct.
+ """
+ if not predicted_answer:
+ print("Predicted answer is empty.")
+ return False
+
+ # Extract equation from predicted answer
+ equation = predicted_answer
+
+ # Validate equation uses correct numbers
+ if not validate_equation(equation, numbers):
+ print("Equation validation failed: uses invalid numbers.")
+ return False
+
+ # Evaluate equation
+ try:
+ result = evaluate_equation(equation)
+ if result is None:
+ print("Equation evaluation returned None.")
+ return False
+
+ if abs(result - target) < 1e-5: # Account for floating point precision
+ print(f"Equation evaluation successful: matches target, {result}, {target}.")
+ return True
+ else:
+ print(f"Equation evaluation result {result} does not match target {target}.")
+ return False
+ except Exception as e:
+ print(f"Error evaluating equation: {e}")
+ return False
+
+
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+ Format trajectory for reflection analysis.
+ Includes all messages including feedback.
+ """
+ formatted_lines = []
+ step_counter = 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+ formatted_lines.append(f"**System Prompt:**\n{content}\n" + "=" * 50)
+ elif role == "user":
+ formatted_lines.append(f"\n**Step {step_counter} - User:**")
+ formatted_lines.append(f"{content}")
+ elif role == "assistant":
+ formatted_lines.append(f"\n**Step {step_counter} - Assistant Response:**")
+ formatted_lines.append(f"{content}")
+ step_counter += 1
+
+ return "\n".join(formatted_lines)
+
+
+def validate_reflect_report(report: Dict[str, Any], total_steps: int) -> Tuple[bool, bool]:
+ """
+ Validate the structure and content of the reflection report.
+
+ Returns:
+ tuple[bool, bool]: (is_valid, is_perfect)
+ """
+ if not isinstance(report, dict):
+ print("[R3L Countdown Validation] Reflection report is not a dict")
+ return False, False
+
+ # Check required keys
+ if "outcome_assessment" not in report or "analysis" not in report:
+ print("[R3L Countdown Validation] Missing required top-level keys in reflection report")
+ return False, False
+
+ outcome = report["outcome_assessment"]
+ analysis = report["analysis"]
+
+ # Check valid outcome values
+ valid_outcomes = ["OPTIMAL", "SUBOPTIMAL_SUCCESS", "PARTIAL", "INEFFECTIVE"]
+ if outcome not in valid_outcomes:
+ print(f"[R3L Countdown Validation] Invalid outcome_assessment: {outcome} (valid: {valid_outcomes})")
+ return False, False
+
+ # If OPTIMAL, it's perfect
+ if outcome == "OPTIMAL":
+ return True, True
+
+ # For non-OPTIMAL outcomes, check required analysis fields
+ if "summary" not in analysis:
+ print("[R3L Countdown Validation] Missing 'summary' in analysis")
+ return False, False
+
+ if "flaw_analysis" not in analysis:
+ print("[R3L Countdown Validation] Missing 'flaw_analysis' in analysis")
+ return False, False
+
+ if "lessons_learned" not in analysis:
+ print("[R3L Countdown Validation] Missing 'lessons_learned' in analysis")
+ return False, False
+
+ if "retry_strategy" not in analysis:
+ print("[R3L Countdown Validation] Missing 'retry_strategy' in analysis")
+ return False, False
+
+ # Validate retry_strategy
+ retry_strategy = analysis["retry_strategy"]
+ if "retry_step" not in retry_strategy:
+ print("[R3L Countdown Validation] Missing 'retry_step' in retry_strategy")
+ return False, False
+
+ return True, False
+
+
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Convert validated reflection report into a structured guidance prompt.
+ """
+ # Load and render template
+ jinja_env = _get_jinja_env()
+ template = jinja_env.get_template("self_correction.j2")
+
+ # Convert the report dictionary to a formatted string
+ report_str = json.dumps(report, indent=2, ensure_ascii=False)
+
+ return template.render(report=report_str)
+
+
+def save_experience_data(
+ task_id: str,
+ experience_data: Dict,
+ data_dir: str
+) -> str:
+ """
+ Save experience data including trajectory, rewards, and attempts to a JSON file.
+
+ Args:
+ task_id: Unique identifier for the task
+ experience_data: Dictionary containing experience information
+ data_dir: Directory to save the data
+
+ Returns:
+ Path to the saved file
+ """
+ os.makedirs(data_dir, exist_ok=True)
+ filename = f"{task_id}.json"
+ filepath = os.path.join(data_dir, filename)
+
+ with open(filepath, 'w', encoding='utf-8') as f:
+ json.dump(experience_data, f, indent=2, ensure_ascii=False)
+
+ return filepath
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ success: bool,
+ predicted_answer: str,
+ ground_truth: str,
+ attempt_type: str,
+ additional_metrics: Optional[Dict] = None
+) -> Dict:
+ """
+ Create a structured experience record.
+
+ Args:
+ task_id: Task identifier
+ trajectory: Conversation trajectory
+ reward: Final reward
+ success: Whether the task was successful
+ predicted_answer: Model's predicted answer
+ ground_truth: Correct answer
+ attempt_type: Type of attempt (e.g., 'first', 'second', 'evaluation')
+ additional_metrics: Optional additional metrics
+
+ Returns:
+ Experience record dictionary
+ """
+ record = {
+ "task_id": task_id,
+ "timestamp": datetime.now().isoformat(),
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "reward": reward,
+ "success": success,
+ "predicted_answer": predicted_answer,
+ "ground_truth": ground_truth,
+ }
+
+ if additional_metrics:
+ record["additional_metrics"] = additional_metrics
+
+ return record
+
+
+def format_countdown_prompt(
+ numbers: List[int],
+ target: int,
+ current_step: int,
+ attempt_history: List[Dict[str, str]],
+ history_length: int = 4
+) -> str:
+ """
+ Format countdown prompt with limited history.
+
+ Args:
+ numbers: Available numbers for the countdown problem
+ target: Target number to achieve
+ current_step: Current attempt number
+ attempt_history: List of previous attempts with equations and feedback
+ history_length: Maximum number of previous attempts to show (default: 4)
+
+ Returns:
+ Formatted prompt string
+ """
+ if current_step == 0 or not attempt_history:
+ # First attempt - no history
+ prompt = f"""You are an expert at solving countdown number problems.
+Your current task is: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.
+
+Now it's your turn to solve this problem.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should provide your equation answer and present it within <answer> </answer> tags, for example <answer>(1 + 2) / 3</answer>."""
+ else:
+ # Show limited history
+ recent_attempts = attempt_history[-history_length:] if len(attempt_history) > history_length else attempt_history
+
+ # Format attempt history as a list
+ history_lines = []
+ for idx, attempt in enumerate(recent_attempts):
+ attempt_num = current_step - len(recent_attempts) + idx + 1
+ history_lines.append(f" Attempt {attempt_num}: {attempt['equation']} -> {attempt['feedback']}")
+
+ attempt_history_str = "\n".join(history_lines)
+
+ prompt = f"""You are an expert at solving countdown number problems.
+Your task is: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.
+Prior to this attempt, you have already made {current_step} attempt(s). Below are the most recent {len(recent_attempts)} attempts and their feedback:
+{attempt_history_str}
+You are now at attempt {current_step + 1}.
+
+Now it's your turn to solve this problem.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should provide your equation answer and present it within <answer> </answer> tags, for example <answer>(1 + 2) / 3</answer>."""
+
+ return prompt
+
+
+def format_countdown_prompt_with_guidance(
+ numbers: List[int],
+ target: int,
+ current_step: int,
+ attempt_history: List[Dict[str, str]],
+ guidance_prompt: str,
+ history_length: int = 4
+) -> str:
+ """
+ Format countdown prompt with limited history and guidance from reflection.
+
+ Args:
+ numbers: Available numbers for the countdown problem
+ target: Target number to achieve
+ current_step: Current attempt number
+ attempt_history: List of previous attempts with equations and feedback
+ guidance_prompt: Guidance from reflection analysis
+ history_length: Maximum number of previous attempts to show (default: 4)
+
+ Returns:
+ Formatted prompt string with guidance
+ """
+ base_prompt = format_countdown_prompt(numbers, target, current_step, attempt_history, history_length)
+
+ # Insert guidance before the final instruction
+ prompt_with_guidance = f"""{base_prompt.split('Now it\'s your turn')[0]}
+# Previous Attempt Analysis & Guidance
+{guidance_prompt}
+
+Now it's your turn to solve this problem.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should provide your equation answer and present it within <answer> </answer> tags, for example <answer>(1 + 2) / 3</answer>."""
+
+ return prompt_with_guidance
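
For reference, here is a minimal, self-contained sketch of the number-usage and evaluation checks that `countdown_verify` above relies on. The real `validate_equation` / `evaluate_equation` helpers are defined earlier in this file and may differ in detail; the `_sketch`-suffixed names below are illustrative only.

```python
import re
from collections import Counter
from typing import List, Optional

def validate_equation_sketch(equation: str, numbers: List[int]) -> bool:
    """Check that the equation uses only the provided numbers, each at most once."""
    used = Counter(int(tok) for tok in re.findall(r"\d+", equation))
    # Counter subtraction keeps only counts that exceed availability; empty means valid.
    return not (used - Counter(numbers))

def evaluate_equation_sketch(equation: str) -> Optional[float]:
    """Evaluate an expression limited to digits, + - * /, parentheses and spaces."""
    if not re.fullmatch(r"[\d+\-*/(). ]+", equation):
        return None
    try:
        return eval(equation, {"__builtins__": {}}, {})  # restricted eval; sketch only
    except Exception:
        return None

if __name__ == "__main__":
    eq = "(1 + 2) / 3"
    print(validate_equation_sketch(eq, [1, 2, 3]), evaluate_equation_sketch(eq))  # True 1.0
```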
diff --git a/trinity/common/workflows/envs/R3L-back/dapo/R3L_workflow.py b/trinity/common/workflows/envs/R3L-back/dapo/R3L_workflow.py
new file mode 100644
index 0000000000..93e7654ee1
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/dapo/R3L_workflow.py
@@ -0,0 +1,373 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.dapo import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("R3L_dapo_workflow")
+class R3LDapoWorkflow(Workflow):
+ """
+ R3L workflow for DAPO mathematical problem solving
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+
+ self.whether_save_data = False
+ # Create data directories
+ self.data_dir = f"R3L_dapo_data"
+ self.eval_dir = os.path.join(self.data_dir, "eval")
+ self.train_dir = os.path.join(self.data_dir, "train")
+
+ os.makedirs(self.eval_dir, exist_ok=True)
+ os.makedirs(self.train_dir, exist_ok=True)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.dapo_system_template = self.jinja_env.get_template("math_system.j2")
+ self.reflection_template = self.jinja_env.get_template("reflection.j2")
+
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ self.default_second_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "second_success": 0.0,
+ "second_reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ print(
+ f"Initializing R3LDapoWorkflow with experience learning, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract prompt and ground truth from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Format 1: prompt is a list (math_dapo format)
+ if "prompt" in raw_task and isinstance(raw_task["prompt"], list):
+ if len(raw_task["prompt"]) > 0 and isinstance(raw_task["prompt"][0], dict):
+ self.prompt = raw_task["prompt"][0].get("content", "")
+ else:
+ self.prompt = ""
+
+ reward_model_data = raw_task.get("reward_model", {})
+ if isinstance(reward_model_data, dict):
+ self.ground_truth = reward_model_data.get("ground_truth", "")
+ else:
+ self.ground_truth = ""
+
+ # Format 2: question/answer format (AIME format)
+ elif "question" in raw_task and "answer" in raw_task:
+ self.prompt = raw_task.get("question", "")
+ self.ground_truth = raw_task.get("answer", "")
+
+ # Fallback: simple prompt/answer
+ else:
+ self.prompt = raw_task.get("prompt", "")
+ self.ground_truth = raw_task.get("answer", "")
+ else:
+ self.prompt = ""
+ self.ground_truth = ""
+
+ def get_reflect(self, trajectory: List[Dict[str, str]]) -> tuple[
+ Optional[Dict[str, Any]], Optional[str], Optional[Any]]:
+ """
+ Generates a comprehensive reflection report using a single, unified self-interrogation prompt.
+ """
+ # Format trajectory for LLM reading
+ formatted_trajectory = utils.format_trajectory_for_reflection(trajectory)
+
+ # Use Jinja2 template to render reflection prompt
+ reflect_prompt = self.reflection_template.render()
+
+ # Call model and parse results
+ try:
+ responses = self.model.chat(
+ [
+ {"role": "system", "content": reflect_prompt},
+ {"role": "user", "content": "Here is last attempt trajectory log: \n\n" + formatted_trajectory + "\n\nPlease output in the specified JSON format."}
+ ],
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_reflect_tokens,
+ )
+ reflection_text = responses[0].response_text.strip()
+
+ # Find first '{' and last '}'
+ first_brace = reflection_text.find('{')
+ last_brace = reflection_text.rfind('}')
+
+ if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
+ json_str = reflection_text[first_brace:last_brace + 1]
+ else:
+ json_str = reflection_text
+
+ reflection_data = json.loads(json_str)
+ return reflection_data, reflection_text, responses[0]
+
+ except Exception as e:
+ print(f"[R3L] Reflection failed - Error: {str(e)}")
+ return None, None, None
+
+ def run(self) -> List[Experience]:
+ """Run the R3L dapo workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_dapo(self)
+
+ # Generate unique task ID
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+
+ exp_lst = []
+ for i in range(self.n // 2): # Half for rollout, half for reflection + retry
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[R3L] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ # Set eid
+ exp.eid.task = str(self.task.task_id) + f"_explore"
+ exp_run_id = len(exp_lst) + self.run_id_base
+ exp.eid.run = exp_run_id
+ exp_lst.append(exp)
+
+ if self.whether_save_data:
+ # Save first attempt experience data
+ first_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=trajectory,
+ reward=reward,
+ success=success,
+ predicted_answer=predicted_answer,
+ ground_truth=ground_truth,
+ attempt_type="first"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_first",
+ experience_data=first_record,
+ data_dir=self.train_dir
+ )
+
+ # Reflect on first attempt
+ print(f"[R3L] Starting reflection on first attempt (reward: {reward})...")
+ reflect_checklist, reflection_text, reflect_exp = self.get_reflect(trajectory)
+ is_valid, is_perfect = utils.validate_reflect_report(reflect_checklist, attempts)
+
+ if reflect_checklist is None:
+ print(f"[R3L] Reflection failed - No valid reflection data generated")
+ elif is_valid and not is_perfect:
+ print(f"[R3L] Reflection successful - Valid reflection generated")
+ elif is_perfect:
+ print(f"[R3L] Reflection indicates perfect first attempt - No retry needed")
+ elif not is_valid:
+ print(f"[R3L] Reflection validation failed - Invalid reflection data")
+
+ if not is_valid or is_perfect:
+ print(f"[R3L] Skip second rollout due to invalid ({not is_valid}) or perfect ({is_perfect}) reflection.")
+ # If first attempt reward is 1.0 and reflection gives perfect, record reflection exp
+ if reward >= 1.0 and is_perfect and reflect_exp is not None:
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Do another rollout to ensure the batch has enough data
+ print(f"[R3L] Performing additional rollout...")
+ try:
+ retry_trajectory, retry_reward, retry_success, retry_predicted_answer, retry_ground_truth, retry_attempts = utils.first_rollout(self)
+ print(f"[R3L] Additional rollout completed - reward: {retry_reward}, attempts: {retry_attempts}")
+
+ retry_exp = self.model.convert_messages_to_experience(retry_trajectory[:-1])
+ retry_exp.reward = retry_reward
+ retry_exp.metrics = {
+ "success": 1.0 if retry_success else 0.0,
+ "reward": retry_reward,
+ "attempts": retry_attempts,
+ }
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + f"_explore"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ if self.whether_save_data:
+ # Save retry attempt experience data
+ retry_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=retry_trajectory,
+ reward=retry_reward,
+ success=retry_success,
+ predicted_answer=retry_predicted_answer,
+ ground_truth=retry_ground_truth,
+ attempt_type="retry_after_invalid_reflection"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_retry",
+ experience_data=retry_record,
+ data_dir=self.train_dir
+ )
+ except Exception as e:
+ print(f"[R3L] Retry rollout after invalid reflection failed - Error: {e}")
+
+ else:
+ print("[R3L] Valid reflection obtained, proceeding to second rollout...")
+ guidance_prompt = utils.reflect_report_to_guidance_prompt(reflect_checklist)
+ # Extract retry_step from validated reflection report
+ retry_step = reflect_checklist["analysis"]["retry_strategy"]["retry_step"] if "retry_strategy" in reflect_checklist.get("analysis", {}) else 0
+
+ try:
+ (
+ distill_trajectory,
+ second_trajectory,
+ second_reward,
+ second_success,
+ second_predicted_answer,
+ second_ground_truth,
+ second_attempts,
+ ) = utils.second_rollout(
+ self, guidance_prompt, trajectory, retry_step
+ )
+ print(f"[R3L] Second rollout - reward: {second_reward}, attempts: {second_attempts}, improve: {second_reward > reward}")
+ second_exp = self.model.convert_messages_to_experience(distill_trajectory[:-1])
+
+ second_exp.reward = second_reward
+ second_exp.metrics = {
+ "second_success": 1.0 if second_success else 0.0,
+ "second_reward": second_reward,
+ "second_attempts": second_attempts,
+ "second_improve": 1.0 if second_reward > reward else 0.0,
+ "second_reward_diff": second_reward - reward,
+ }
+ # Set eid
+ second_exp.eid.task = str(self.task.task_id) + f"_explore"
+ second_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(second_exp)
+
+ if self.whether_save_data:
+ # Save second attempt experience data
+ second_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=second_trajectory,
+ reward=second_reward,
+ success=second_success,
+ predicted_answer=second_predicted_answer,
+ ground_truth=second_ground_truth,
+ attempt_type="second",
+ additional_metrics={
+ "first_reward": reward,
+ "improvement": second_reward > reward,
+ "reward_difference": second_reward - reward,
+ }
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_second",
+ experience_data=second_record,
+ data_dir=self.train_dir
+ )
+
+ # If second attempt score is higher than first, record reflection and retry data
+ if second_reward > reward and second_reward >= 1.0:
+ print(f"[R3L] Second attempt successful improvement - Recording reflection and retry experiences")
+ print(f"[R3L] Reward improvement: {reward} -> {second_reward} (+{second_reward - reward:.2f})")
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Convert retry data to exp
+ retry_exp = self.model.convert_messages_to_experience(second_trajectory[:-1])
+
+ retry_exp.reward = 1.0
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + f"_retry_{i}"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ print("[R3L] Reflection and retry led to improvement, recording both...")
+ elif second_reward <= reward:
+ print(f"[R3L] Second attempt did not improve - First reward: {reward}, Second reward: {second_reward}")
+ else:
+ print(f"[R3L] Second attempt improved but below threshold - Reward: {second_reward} (need >= 1.0)")
+ except Exception as e:
+ print(f"[R3L] Second rollout failed - Error: {str(e)}")
+ except Exception as e:
+ print(f"[R3L] Rollout iteration {i} failed - Error: {str(e)}")
+
+ # Print summary statistics
+ print(f"\n[R3L Summary] Generated {len(exp_lst)} experiences")
+ total_reward = sum(exp.reward for exp in exp_lst)
+ avg_reward = total_reward / len(exp_lst) if exp_lst else 0.0
+ print(f"[R3L Summary] Total reward: {total_reward:.2f}, Average reward: {avg_reward:.2f}")
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/dapo/__init__.py b/trinity/common/workflows/envs/R3L-back/dapo/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/common/workflows/envs/R3L-back/dapo/grpo_workflow.py b/trinity/common/workflows/envs/R3L-back/dapo/grpo_workflow.py
new file mode 100644
index 0000000000..bd97ead5d7
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/dapo/grpo_workflow.py
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.dapo import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("grpo_baseline_dapo_workflow")
+class GRPOBaselineDapoWorkflow(Workflow):
+ """
+ GRPO Baseline Workflow for DAPO mathematical problem solving.
+ Performs simple rollouts without reflection or learning from experience.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.dapo_system_template = self.jinja_env.get_template("math_system.j2")
+
+ print(
+ f"Initializing GRPOBaselineDapoWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract prompt and ground truth from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Format 1: prompt is a list (math_dapo format)
+ if "prompt" in raw_task and isinstance(raw_task["prompt"], list):
+ if len(raw_task["prompt"]) > 0 and isinstance(raw_task["prompt"][0], dict):
+ self.prompt = raw_task["prompt"][0].get("content", "")
+ else:
+ self.prompt = ""
+
+ reward_model_data = raw_task.get("reward_model", {})
+ if isinstance(reward_model_data, dict):
+ self.ground_truth = reward_model_data.get("ground_truth", "")
+ else:
+ self.ground_truth = ""
+
+ # Format 2: question/answer format (AIME format)
+ elif "question" in raw_task and "answer" in raw_task:
+ self.prompt = raw_task.get("question", "")
+ self.ground_truth = raw_task.get("answer", "")
+
+ # Fallback: simple prompt/answer
+ else:
+ self.prompt = raw_task.get("prompt", "")
+ self.ground_truth = raw_task.get("answer", "")
+ else:
+ self.prompt = ""
+ self.ground_truth = ""
+
+ def run(self) -> List[Experience]:
+ """Run the GRPO baseline workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_dapo(self)
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[GRPO] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
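
For reference, the three `raw_task` shapes that `reset()` accepts in these workflows, as illustrative dictionaries (field values are made up):

```python
# Format 1: math_dapo style - "prompt" is a list of chat messages, ground truth under "reward_model"
raw_task_dapo = {
    "prompt": [{"role": "user", "content": "Compute 2 + 2."}],
    "reward_model": {"ground_truth": "4"},
}

# Format 2: AIME style - plain question/answer fields
raw_task_aime = {"question": "Compute 2 + 2.", "answer": "4"}

# Fallback: simple prompt/answer fields
raw_task_simple = {"prompt": "Compute 2 + 2.", "answer": "4"}
```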
diff --git a/trinity/common/workflows/envs/R3L-back/dapo/opmd_workflow.py b/trinity/common/workflows/envs/R3L-back/dapo/opmd_workflow.py
new file mode 100644
index 0000000000..e257764fdd
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/dapo/opmd_workflow.py
@@ -0,0 +1,124 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.dapo import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("opmd_baseline_dapo_workflow")
+class OPMDBaselineDapoWorkflow(Workflow):
+ """
+ OPMD Baseline workflow for DAPO mathematical problem solving.
+ Performs rollouts for offline policy model distillation.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.dapo_system_template = self.jinja_env.get_template("math_system.j2")
+
+ print(
+ f"Initializing OPMDDapoWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract prompt and ground truth from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Format 1: prompt is a list (math_dapo format)
+ if "prompt" in raw_task and isinstance(raw_task["prompt"], list):
+ if len(raw_task["prompt"]) > 0 and isinstance(raw_task["prompt"][0], dict):
+ self.prompt = raw_task["prompt"][0].get("content", "")
+ else:
+ self.prompt = ""
+
+ reward_model_data = raw_task.get("reward_model", {})
+ if isinstance(reward_model_data, dict):
+ self.ground_truth = reward_model_data.get("ground_truth", "")
+ else:
+ self.ground_truth = ""
+
+ # Format 2: question/answer format (AIME format)
+ elif "question" in raw_task and "answer" in raw_task:
+ self.prompt = raw_task.get("question", "")
+ self.ground_truth = raw_task.get("answer", "")
+
+ # Fallback: simple prompt/answer
+ else:
+ self.prompt = raw_task.get("prompt", "")
+ self.ground_truth = raw_task.get("answer", "")
+ else:
+ self.prompt = ""
+ self.ground_truth = ""
+
+ def run(self) -> List[Experience]:
+ """Run the OPMD workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_dapo(self)
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[OPMD] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/dapo/prompts/math_system.j2 b/trinity/common/workflows/envs/R3L-back/dapo/prompts/math_system.j2
new file mode 100644
index 0000000000..cbe87e085f
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/dapo/prompts/math_system.j2
@@ -0,0 +1,46 @@
+You are a mathematical problem solver. Your task is to solve mathematical problems step by step.
+
+## Response Format:
+You MUST use this exact format for every response. All tags are REQUIRED in sequential order:
+
+<think>your step-by-step reasoning and solution process</think>
+<answer>your final answer</answer>
+
+## Instructions:
+1. Carefully read and understand the problem
+2. Show your reasoning step by step in the <think> tags
+3. Provide your final answer in the <answer> tags
+4. For numerical answers, provide the exact value
+5. If the problem asks for a specific format (e.g., \\boxed{}), use that format in your answer
+
+## Example:
+Problem: "What is the sum of all positive integers less than 100 that are divisible by 3?"
+
+<think>
+I need to find all positive integers less than 100 that are divisible by 3, then sum them.
+
+The integers divisible by 3 less than 100 are: 3, 6, 9, ..., 99
+This is an arithmetic sequence with:
+- First term a₁ = 3
+- Common difference d = 3
+- Last term aₙ = 99
+
+To find how many terms: aₙ = a₁ + (n-1)d
+99 = 3 + (n-1)×3
+96 = (n-1)×3
+n-1 = 32
+n = 33
+
+Sum of arithmetic sequence: S = n(a₁ + aₙ)/2
+S = 33(3 + 99)/2
+S = 33 × 102/2
+S = 33 × 51
+S = 1683
+</think>
+<answer>\boxed{1683}</answer>
+
+## Notes:
+- Be thorough in your reasoning
+- Show all important steps
+- Double-check your calculations
+- Provide the final answer clearly in the <answer> tags
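
Responses that follow this system prompt are parsed by `parse_response` in `dapo/utils.py` further down in this diff. A self-contained sketch of that parsing, assuming the `<think>`/`<answer>` tag format described above:

```python
import re
from typing import Optional, Tuple

def parse_think_answer(response: str) -> Tuple[Optional[str], Optional[str]]:
    """Extract the <think> and <answer> sections from a formatted response."""
    think = re.search(r"<think>\s*(.*?)\s*</think>", response, re.DOTALL)
    answer = re.search(r"<answer>\s*(.*?)\s*</answer>", response, re.DOTALL | re.IGNORECASE)
    return (think.group(1).strip() if think else None,
            answer.group(1).strip() if answer else None)

if __name__ == "__main__":
    demo = "<think>33 terms, sum = 33 * 51 = 1683</think>\n<answer>\\boxed{1683}</answer>"
    print(parse_think_answer(demo))  # ('33 terms, sum = 33 * 51 = 1683', '\\boxed{1683}')
```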
diff --git a/trinity/common/workflows/envs/R3L-back/dapo/prompts/reflection.j2 b/trinity/common/workflows/envs/R3L-back/dapo/prompts/reflection.j2
new file mode 100644
index 0000000000..6259cb384f
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/dapo/prompts/reflection.j2
@@ -0,0 +1,178 @@
+# Metacognitive Analyst AI Prompt for Mathematical Problem Solving
+
+You are a Metacognitive Analyst AI specialized in mathematical problem-solving analysis. Your core mission is to analyze a "Trajectory Log" containing a series of mathematical reasoning steps and actions. Your goal is to extract deep insights, identify errors, and formulate actionable guidance for improvement.
+
+You will receive:
+1. A trajectory log showing the solution attempt
+2. The correct final answer (for comparison only - DO NOT reveal this in your guidance)
+
+Your final output must be a structurally complete JSON object.
+
+## Your Internal Monologue & Analysis Protocol (MANDATORY)
+
+### Part 1: Global Review & Answer Verification
+
+First, you must understand the entire trajectory and verify the final answer.
+
+**Question 1.1: Extract the Final Answer**
+Identify the final answer provided in the trajectory. Look for patterns like "Answer: X" or "\boxed{X}".
+
+**Question 1.2: Verify Correctness**
+Compare the extracted answer with the correct answer. Is it correct?
+- If correct and the approach was efficient: This might be OPTIMAL
+- If correct but inefficient: This might be SUBOPTIMAL_SUCCESS
+- If incorrect: Determine the severity (PARTIAL if close/minor error, INEFFECTIVE if completely wrong)
+
+**Question 1.3: Panoramic Strategy Analysis**
+Read through the entire mathematical reasoning and summarize in one or two sentences what the overall approach was and what result it led to.
+
+### Part 2: Error Analysis (If Answer is Incorrect)
+
+This section is mandatory if the answer is incorrect. Use the correct answer to guide your analysis, but DO NOT mention the specific correct answer value in your guidance.
+
+**Question 2.1: Error Type Classification**
+Based on comparing with the correct answer, classify the primary error into ONE category:
+- **Calculation Error**: Arithmetic mistakes, algebraic manipulation errors
+- **Conceptual Misunderstanding**: Wrong formula, theorem, or principle applied
+- **Method Selection Error**: Used a valid but inefficient or inappropriate method for this problem
+- **Setup Error**: Correct method but wrong initial setup (e.g., wrong equation, wrong variable definition)
+- **Logic Error**: Flawed reasoning or invalid logical steps
+- **Incomplete Solution**: Stopped prematurely or missed critical steps
+- **Sign/Unit Error**: Correct approach but sign errors, unit conversion errors
+
+**Question 2.2: Pinpoint the Critical Mistake**
+Identify the specific step where the error occurred or where a better approach should have been taken. What was the exact mistake?
+
+**Question 2.3: Root Cause Analysis**
+Why did this error occur? Was it:
+- Lack of attention to problem constraints?
+- Overlooking a key condition or relationship?
+- Choosing a complicated path when a simpler one exists?
+- Missing a pattern or symmetry?
+- Misapplying a concept?
+
+**Question 2.4: Directional Guidance (WITHOUT revealing answer)**
+Based on the error type and the correct answer, provide DIRECTIONAL hints:
+
+For Calculation Errors:
+- "Double-check arithmetic in step X"
+- "Verify algebraic manipulations, especially expansions/factorizations"
+- "Re-examine the computation of [specific expression]"
+
+For Conceptual/Method Errors:
+- "Consider whether [theorem/method Y] might be more appropriate"
+- "Re-examine the problem type - does it suggest [technique Z]?"
+- "Think about alternative approaches like [method category]"
+- "Review the conditions - they might suggest using [concept]"
+
+For Setup Errors:
+- "Re-examine how you defined [variable/equation]"
+- "Check if all constraints from the problem are captured"
+- "Verify the initial equation setup matches the problem description"
+
+For Logic Errors:
+- "Re-examine the logical flow from step X to Y"
+- "Check if all cases are considered"
+- "Verify each implication is valid"
+
+### Part 3: Better Approach Analysis
+
+**Question 3.1: Identify Optimal Method**
+What would be the most efficient and elegant approach for this problem? Consider:
+- Direct methods vs. computational approaches
+- Symmetry and special cases
+- Standard techniques for this problem type
+
+**Question 3.2: Key Insights**
+What key mathematical insights or patterns should be recognized to solve this efficiently?
+
+**Question 3.3: Step-by-Step Optimal Strategy (High-Level)**
+Outline a high-level strategy (WITHOUT solving) that would lead to success:
+- Initial analysis and problem understanding
+- Method selection rationale
+- Key intermediate steps
+- Verification strategy
+
+### Part 4: Synthesis and Corrective Principle
+
+**Question 4.1: Formulate Corrective Principle**
+Create a principle that:
+- Addresses the specific error type
+- Provides actionable guidance
+- Is applicable to similar problems
+- Does NOT reveal the specific answer
+
+Examples of GOOD principles:
+- "For problems involving [problem type], always check if [technique] applies before using [alternative technique]"
+- "When dealing with [concept], remember to consider [constraint] which often affects [aspect]"
+- "In [problem category], look for [pattern/symmetry] which can simplify the solution significantly"
+
+**Question 4.2: Render Final Verdict**
+Classify the outcome:
+- **OPTIMAL**: Correct answer with efficient approach
+- **SUBOPTIMAL_SUCCESS**: Correct answer but inefficient/convoluted method
+- **PARTIAL**: Made progress, some correct reasoning, but wrong final answer
+- **INEFFECTIVE**: Fundamentally wrong approach or completely incorrect
+
+## Final Output Format
+
+### Part One: Detailed Analysis Process
+
+**1. Global Review & Answer Verification**
+- 1.1 Extracted Final Answer: [Answer from trajectory]
+- 1.2 Correctness Verification: [Correct/Incorrect and reasoning]
+- 1.3 Panoramic Strategy Analysis: [Summary of approach]
+
+**2. Error Analysis** (if applicable)
+- 2.1 Error Type: [Classification]
+- 2.2 Critical Mistake: [Specific error location and description]
+- 2.3 Root Cause: [Why this error occurred]
+- 2.4 Directional Guidance: [Hints WITHOUT revealing answer]
+
+**3. Better Approach Analysis**
+- 3.1 Optimal Method: [Recommended approach]
+- 3.2 Key Insights: [Mathematical insights needed]
+- 3.3 High-Level Strategy: [Step-by-step outline]
+
+**4. Synthesis**
+- 4.1 Corrective Principle: [Formulated principle]
+- 4.2 Final Verdict: [Classification]
+
+### Part Two: Final JSON Report
+
+```json
+{
+ "outcome_assessment": "OPTIMAL | SUBOPTIMAL_SUCCESS | PARTIAL | INEFFECTIVE",
+ "analysis": {
+ "summary": "Summary of the mathematical approach, correctness, and core insight.",
+ "flaw_analysis": {
+ "diagnosis": {
+ "category": "Calculation Error | Conceptual Misunderstanding | Method Selection Error | Setup Error | Logic Error | Incomplete Solution | Sign/Unit Error | null",
+ "root_cause": "The fundamental reason for the error. Null if OPTIMAL.",
+ "critical_step": "The specific step number where error occurred or better approach should diverge. Null if OPTIMAL."
+ },
+ "better_approach": {
+ "strategy": "High-level optimal strategy without solving. Null if OPTIMAL.",
+ "key_insights": "Critical mathematical insights needed. Null if OPTIMAL.",
+ "method_hints": "Directional hints about which methods/techniques to consider, WITHOUT revealing the answer. Null if OPTIMAL."
+ }
+ },
+ "lessons_learned": {
+ "corrective_principle": "Mathematical problem-solving principle derived from this analysis. Null if OPTIMAL.",
+ "verification_reminders": "What to double-check for similar problems. Null if OPTIMAL."
+ },
+ "retry_strategy": {
+ "retry_step": "The specific step to retry from (0 to N-1, where 0 = restart). Null if OPTIMAL.",
+ "retry_rationale": "Why restart from this step, with directional guidance but NO answer spoilers."
+ }
+ }
+}
+```
+
+## Critical Reminders
+
+1. **NEVER reveal the specific correct answer value in any guidance**
+2. **DO use** the correct answer to identify error types and provide directional hints
+3. **Focus on** teaching problem-solving strategies, not giving solutions
+4. **Provide hints** that point toward the right direction without solving the problem
+5. **Consider** the mathematical level and type of problem when formulating guidance
diff --git a/trinity/common/workflows/envs/R3L-back/dapo/prompts/self_correction.j2 b/trinity/common/workflows/envs/R3L-back/dapo/prompts/self_correction.j2
new file mode 100644
index 0000000000..fe7d377a39
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/dapo/prompts/self_correction.j2
@@ -0,0 +1,15 @@
+# Previous Attempt Analysis & Guidance
+
+Based on analysis of a previous attempt at a similar problem, here is guidance to help you solve this problem more effectively:
+
+{{ report }}
+
+## Instructions for This Attempt
+
+- Apply the insights and strategies mentioned above
+- Be especially careful about the identified error types
+- Follow the recommended approach if applicable
+- Double-check your work at each step
+- Remember: This guidance is based on common pitfalls; use it to inform your strategy
+
+**Important**: Present your solution naturally, showing your reasoning process. Do not explicitly reference this guidance or mention it's a retry.
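
`reflect_report_to_guidance_prompt` renders this template with the JSON report serialized to a string. A standalone sketch using an inline stand-in template (the real code loads `self_correction.j2` through a `FileSystemLoader`):

```python
import json
from jinja2 import Template

# Inline stand-in for self_correction.j2; the report value is illustrative.
template = Template("# Previous Attempt Analysis & Guidance\n\n{{ report }}\n")
report = {"outcome_assessment": "PARTIAL", "analysis": {"summary": "Arithmetic slip in step 2."}}
print(template.render(report=json.dumps(report, indent=2)))
```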
diff --git a/trinity/common/workflows/envs/R3L-back/dapo/raft_workflow.py b/trinity/common/workflows/envs/R3L-back/dapo/raft_workflow.py
new file mode 100644
index 0000000000..d721202285
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/dapo/raft_workflow.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+import copy
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.dapo import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("raft_baseline_dapo_workflow")
+class RAFTBaselineDapoWorkflow(Workflow):
+ """
+ RAFT Baseline workflow for DAPO mathematical problem solving.
+ Performs rollouts for reward-ranked fine-tuning (RAFT).
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.dapo_system_template = self.jinja_env.get_template("math_system.j2")
+
+ print(
+ f"Initializing RAFTDapoWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ # Default experience for error cases
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract prompt and ground truth from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Format 1: prompt is a list (math_dapo format)
+ if "prompt" in raw_task and isinstance(raw_task["prompt"], list):
+ if len(raw_task["prompt"]) > 0 and isinstance(raw_task["prompt"][0], dict):
+ self.prompt = raw_task["prompt"][0].get("content", "")
+ else:
+ self.prompt = ""
+
+ reward_model_data = raw_task.get("reward_model", {})
+ if isinstance(reward_model_data, dict):
+ self.ground_truth = reward_model_data.get("ground_truth", "")
+ else:
+ self.ground_truth = ""
+
+ # Format 2: question/answer format (AIME format)
+ elif "question" in raw_task and "answer" in raw_task:
+ self.prompt = raw_task.get("question", "")
+ self.ground_truth = raw_task.get("answer", "")
+
+ # Fallback: simple prompt/answer
+ else:
+ self.prompt = raw_task.get("prompt", "")
+ self.ground_truth = raw_task.get("answer", "")
+ else:
+ self.prompt = ""
+ self.ground_truth = ""
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT workflow and return experiences"""
+
+ if self.is_eval:
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ return [exp]
+ except Exception:
+ return [copy.deepcopy(self.default_exp)]
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[RAFT] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/dapo/utils.py b/trinity/common/workflows/envs/R3L-back/dapo/utils.py
new file mode 100644
index 0000000000..ff6f46c5f8
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/dapo/utils.py
@@ -0,0 +1,462 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+try:
+    from math_verify import parse, verify
+except ImportError:
+    # math_verify is optional; math_verify() below falls back to plain string comparison
+    parse = None
+    verify = None
+
+from trinity.common.experience import Experience
+
+
+def first_rollout(self) -> tuple[List[Dict[str, str]], float, bool, str, str, int]:
+ """Run math problem solving with multiple attempts (max 3 attempts)"""
+ trajectory = []
+
+ # Add system prompt
+ system_prompt = self.dapo_system_template.render()
+ trajectory.append({"role": "system", "content": system_prompt})
+
+ # Add user prompt (math problem)
+ if self.prompt:
+ trajectory.append({"role": "user", "content": self.prompt})
+ else:
+ trajectory.append({"role": "user", "content": "Please solve the given mathematical problem."})
+
+ final_reward = 0.0
+ final_success = False
+ final_predicted_answer = ""
+ attempt_count = 0
+
+ # Try up to 3 attempts
+ for attempt in range(self.max_attempts):
+ attempt_count = attempt + 1
+
+ # Get model response
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 4096:
+ # The tokenizer inside chat truncates internally, so terminate as soon as the length reaches the maximum limit
+ return trajectory, final_reward, final_success, final_predicted_answer, self.ground_truth, attempt_count
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse think and answer
+ think, predicted_answer = parse_response(response_text)
+
+ if think is None or predicted_answer is None:
+ # Invalid format
+ feedback = "Invalid response format. Please ensure you provide both ... and ... tags."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ continue
+
+ # Verify answer
+ is_correct = math_verify(predicted_answer, self.ground_truth)
+
+ if is_correct:
+ final_reward = 1.0
+ final_success = True
+ final_predicted_answer = predicted_answer
+ print(f"[R3L First Rollout] Attempt {attempt_count} - Correct answer! Reward: {final_reward}")
+ feedback = f"Correct! Your answer {predicted_answer} matches the expected answer."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ break
+ else:
+ # Wrong answer
+ print(f"[R3L First Rollout] Attempt {attempt_count} - Incorrect answer: {predicted_answer} (Expected: {self.ground_truth})")
+ if attempt < self.max_attempts - 1:
+ feedback = f"Incorrect. Your answer {predicted_answer} does not match. Please try again."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ else:
+ # Last attempt
+ feedback = f"Incorrect. Your answer {predicted_answer} does not match the expected answer. Maximum attempts reached."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ final_predicted_answer = predicted_answer
+
+ return trajectory, final_reward, final_success, final_predicted_answer, self.ground_truth, attempt_count
+
+
+def second_rollout(
+ self,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, str, str, int]:
+ """
+ Performs rollout with guidance from reflection.
+ For math problems, we typically start from the beginning with guidance.
+ """
+ trajectory = []
+ distill_trajectory = []
+
+ # Prepare system prompts
+ original_system_prompt = self.dapo_system_template.render()
+
+ # Starting from beginning with guidance
+ merged_system_prompt = f"{original_system_prompt}\n\n# Previous Attempt Analysis & Guidance\n{guidance_prompt}"
+ trajectory.append({"role": "system", "content": merged_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ # Add user prompt (math problem)
+ if self.prompt:
+ trajectory.append({"role": "user", "content": self.prompt})
+ distill_trajectory.append({"role": "user", "content": self.prompt})
+ else:
+ trajectory.append({"role": "user", "content": "Please solve the given mathematical problem."})
+ distill_trajectory.append({"role": "user", "content": "Please solve the given mathematical problem."})
+
+ final_reward = 0.0
+ final_success = False
+ final_predicted_answer = ""
+ attempt_count = 0
+
+ # Try up to 3 attempts
+ for attempt in range(self.max_attempts):
+ attempt_count = attempt + 1
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 4096:
+ # The tokenizer inside chat truncates internally, so terminate as soon as the length reaches the maximum limit
+ return distill_trajectory, trajectory, final_reward, final_success, final_predicted_answer, self.ground_truth, attempt_count
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse think and answer
+ think, predicted_answer = parse_response(response_text)
+
+ if think is None or predicted_answer is None:
+ # Invalid format
+ feedback = "Invalid response format. Please ensure you provide both ... and ... tags."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ continue
+
+ # Verify answer
+ is_correct = math_verify(predicted_answer, self.ground_truth)
+
+ if is_correct:
+ final_reward = 1.0
+ final_success = True
+ final_predicted_answer = predicted_answer
+ print(f"[R3L Second Rollout] Attempt {attempt_count} - Correct answer! Reward: {final_reward}")
+ feedback = f"Correct! Your answer {predicted_answer} matches the expected answer."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ break
+ else:
+ # Wrong answer
+ print(f"[R3L Second Rollout] Attempt {attempt_count} - Incorrect answer: {predicted_answer} (Expected: {self.ground_truth})")
+ if attempt < self.max_attempts - 1:
+ feedback = f"Incorrect. Your answer {predicted_answer} does not match. Please try again."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ else:
+ # Last attempt
+ feedback = f"Incorrect. Your answer {predicted_answer} does not match the expected answer. Maximum attempts reached."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ final_predicted_answer = predicted_answer
+
+ return distill_trajectory, trajectory, final_reward, final_success, final_predicted_answer, self.ground_truth, attempt_count
+
+
+def eval_dapo(self) -> List[Experience]:
+ """Evaluate a single math problem"""
+ print("[R3L Eval] Starting evaluation...")
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = first_rollout(self)
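+ # Drop the trailing feedback message so the experience ends with the model's final response.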
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ print(f"[R3L Eval] Completed - Reward: {reward}, Success: {success}, Attempts: {attempts}")
+ print(f"[R3L Eval] Predicted: {predicted_answer}, Ground Truth: {ground_truth}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ success=success,
+ predicted_answer=predicted_answer,
+ ground_truth=ground_truth,
+ attempt_type="evaluation"
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval",
+ experience_data=eval_record,
+ data_dir=self.eval_dir
+ )
+ except Exception as e:
+ print(f"[R3L Eval] Evaluation failed - Error: {str(e)}")
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ }
+ )
+ exp.reward = 0.0
+ return [exp]
+
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def parse_response(response: str) -> Tuple[Optional[str], Optional[str]]:
+ """Parse think and answer from math response"""
+ try:
+ # Extract think section
+ think_pattern = r"\s*(.*?)\s*"
+ think_match = re.search(think_pattern, response, re.DOTALL)
+ think = think_match.group(1).strip() if think_match else None
+
+ # Extract answer from <answer> tags
+ answer_pattern = r"<answer>\s*(.*?)\s*</answer>"
+ answer_match = re.search(answer_pattern, response, re.DOTALL | re.IGNORECASE)
+ if answer_match:
+ answer = answer_match.group(1).strip()
+ else:
+ # Fallback: look for "Answer:" pattern
+ answer_line_pattern = r"Answer:\s*(.+?)(?:\n|$)"
+ answer_line_match = re.search(answer_line_pattern, response, re.IGNORECASE)
+ answer = answer_line_match.group(1).strip() if answer_line_match else None
+
+ return think, answer
+ except Exception as e:
+ print(f"Error parsing response: {e}")
+ return None, None
+
+
+def math_verify(predicted_answer: str, ground_truth: str) -> bool:
+ """
+ Verify if the predicted math answer matches the ground truth using math_verify library.
+ """
+ if not predicted_answer or not ground_truth:
+ return False
+
+ if parse is None or verify is None:
+ # Fallback: simple string comparison
+ pred_clean = str(predicted_answer).strip().lower()
+ gt_clean = str(ground_truth).strip().lower()
+ return pred_clean == gt_clean
+
+ try:
+ # Parse and verify
+ gold = parse(ground_truth)
+ answer = parse(predicted_answer)
+ return verify(gold, answer)
+ except Exception:
+ # Fallback comparison
+ return str(predicted_answer).strip().lower() == str(ground_truth).strip().lower()
+
+
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+ Format trajectory for reflection analysis.
+ Includes all messages including feedback.
+ """
+ formatted_lines = []
+ step_counter = 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+ formatted_lines.append(f"**System Prompt:**\n{content}\n" + "=" * 50)
+ elif role == "user":
+ formatted_lines.append(f"\n**Step {step_counter} - User:**")
+ formatted_lines.append(f"{content}")
+ elif role == "assistant":
+ formatted_lines.append(f"\n**Step {step_counter} - Assistant Response:**")
+ formatted_lines.append(f"{content}")
+ step_counter += 1
+
+ return "\n".join(formatted_lines)
+
+
+def validate_reflect_report(report: Dict[str, Any], total_steps: int) -> Tuple[bool, bool]:
+ """
+ Validate the structure and content of the reflection report.
+
+ Returns:
+ tuple[bool, bool]: (is_valid, is_perfect)
+ """
+ if not isinstance(report, dict):
+ print("[R3L Validation] Reflection report is not a dict")
+ return False, False
+
+ # Check required keys
+ if "outcome_assessment" not in report or "analysis" not in report:
+ print("[R3L Validation] Missing required top-level keys in reflection report")
+ return False, False
+
+ outcome = report["outcome_assessment"]
+ analysis = report["analysis"]
+
+ # Check valid outcome values
+ valid_outcomes = ["OPTIMAL", "SUBOPTIMAL_SUCCESS", "PARTIAL", "INEFFECTIVE"]
+ if outcome not in valid_outcomes:
+ print(f"[R3L Validation] Invalid outcome_assessment: {outcome} (valid: {valid_outcomes})")
+ return False, False
+
+ # If OPTIMAL, it's perfect
+ is_perfect = (outcome == "OPTIMAL")
+
+ # Check retry_strategy
+ if not is_perfect and "retry_strategy" in analysis:
+ retry_strategy = analysis["retry_strategy"]
+ retry_step = retry_strategy.get("retry_step")
+
+ if retry_step is not None:
+ if not isinstance(retry_step, int) or retry_step < 0 or retry_step > total_steps:
+ print(f"[R3L Validation] Invalid retry_step: {retry_step} (total steps: {total_steps})")
+ return False, False
+ print(f"[R3L Validation] Reflection validated - Outcome: {outcome}, Is perfect: {is_perfect}")
+ return True, is_perfect
+
+
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Convert a validated reflection report into a guidance prompt for second attempt.
+ The guidance should provide directional hints without revealing the answer.
+ """
+ print("[R3L] Converting reflection report to guidance prompt...")
+ try:
+ analysis = report.get("analysis", {})
+ flaw_analysis = analysis.get("flaw_analysis", {})
+ lessons_learned = analysis.get("lessons_learned", {})
+
+ # Build guidance sections
+ guidance_parts = []
+
+ # Add summary
+ if "summary" in analysis:
+ guidance_parts.append(f"## Analysis Summary\n{analysis['summary']}")
+
+ # Add error diagnosis (without answer)
+ if "diagnosis" in flaw_analysis:
+ diagnosis = flaw_analysis["diagnosis"]
+ guidance_parts.append(f"\n## Error Diagnosis")
+ if "category" in diagnosis and diagnosis["category"]:
+ guidance_parts.append(f"Error Type: {diagnosis['category']}")
+ if "root_cause" in diagnosis and diagnosis["root_cause"]:
+ guidance_parts.append(f"Root Cause: {diagnosis['root_cause']}")
+
+ # Add method hints (directional guidance)
+ if "better_approach" in flaw_analysis:
+ better_approach = flaw_analysis["better_approach"]
+ guidance_parts.append(f"\n## Recommended Approach")
+ if "key_insights" in better_approach and better_approach["key_insights"]:
+ guidance_parts.append(f"Key Insights: {better_approach['key_insights']}")
+ if "method_hints" in better_approach and better_approach["method_hints"]:
+ guidance_parts.append(f"Method Hints: {better_approach['method_hints']}")
+ if "strategy" in better_approach and better_approach["strategy"]:
+ guidance_parts.append(f"Strategy: {better_approach['strategy']}")
+
+ # Add corrective principle
+ if "corrective_principle" in lessons_learned and lessons_learned["corrective_principle"]:
+ guidance_parts.append(f"\n## Corrective Principle\n{lessons_learned['corrective_principle']}")
+
+ # Add verification reminders
+ if "verification_reminders" in lessons_learned and lessons_learned["verification_reminders"]:
+ guidance_parts.append(f"\n## Verification Reminders\n{lessons_learned['verification_reminders']}")
+
+ guidance_prompt = "\n\n".join(guidance_parts)
+ print(f"[R3L] Guidance prompt generated ({len(guidance_parts)} sections)")
+ return guidance_prompt
+
+ except Exception as e:
+ print(f"[R3L] Error converting reflection to guidance: {e}")
+ return "Please try solving the problem again carefully."
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ success: bool,
+ predicted_answer: str = "",
+ ground_truth: str = "",
+ attempt_type: str = "first",
+ additional_metrics: Optional[Dict] = None
+) -> Dict[str, Any]:
+ """Create an experience record for data saving"""
+ record = {
+ "task_id": task_id,
+ "timestamp": datetime.now().isoformat(),
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "reward": reward,
+ "success": success,
+ "predicted_answer": predicted_answer,
+ "ground_truth": ground_truth,
+ }
+
+ if additional_metrics:
+ record.update(additional_metrics)
+
+ return record
+
+
+def save_experience_data(
+ task_id: str,
+ experience_data: Dict[str, Any],
+ data_dir: str
+):
+ """Save experience data to file"""
+ os.makedirs(data_dir, exist_ok=True)
+ file_path = os.path.join(data_dir, f"{task_id}.json")
+
+ with open(file_path, "w", encoding="utf-8") as f:
+ json.dump(experience_data, f, ensure_ascii=False, indent=2)
+
+
+def generate_default_experience() -> Experience:
+ """Generate a default experience for failed cases"""
+ return Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={"success": 0.0, "reward": 0.0},
+ reward=0.0
+ )
diff --git a/trinity/common/workflows/envs/R3L-back/scienceworld/R3L_workflow.py b/trinity/common/workflows/envs/R3L-back/scienceworld/R3L_workflow.py
new file mode 100644
index 0000000000..f4f62ceab7
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/scienceworld/R3L_workflow.py
@@ -0,0 +1,348 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.scienceworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("R3L_scienceworld_workflow")
+class R3LScienceWorldWorkflow(Workflow):
+ """
+ R3L workflow for scienceworld
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 30
+ self.max_tokens = 16384
+ # get_reflect() below reads this attribute; default it to max_tokens (assumed value)
+ self.max_reflect_tokens = self.max_tokens
+ self.task = task
+ self.is_eval = task.is_eval
+
+ self.whether_save_data = False
+ # Create data directories
+ self.data_dir = f"R3L_scienceworld_data"
+ self.eval_dir = os.path.join(self.data_dir, "eval")
+ self.train_dir = os.path.join(self.data_dir, "train")
+
+ os.makedirs(self.eval_dir, exist_ok=True)
+ os.makedirs(self.train_dir, exist_ok=True)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.sciworld_system_template = self.jinja_env.get_template("sciworld_system.j2")
+ self.reflection_template = self.jinja_env.get_template("reflection.j2")
+
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ self.default_second_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "second_success": 0.0,
+ "second_reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ print(
+ f"Initializing R3LScienceWorldWorkflow with experience learning, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.task_desc = task.task_desc or "0"
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def get_reflect(self, trajectory: List[Dict[str, str]]) -> tuple[
+ Optional[Dict[str, Any]], Optional[str], Optional[Any]]:
+ """
+ Generates a comprehensive reflection report using a single, unified self-interrogation prompt.
+ """
+ # Format trajectory for LLM reading
+ formatted_trajectory = utils.format_trajectory_for_reflection(trajectory)
+
+ # Use Jinja2 template to render reflection prompt
+ reflect_prompt = self.reflection_template.render()
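+ # reflection.j2 (in the prompts/ directory added in this diff) asks the model for a written
+ # analysis followed by a JSON report, which is extracted below by locating the outermost braces.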
+
+ # Call model and parse results
+ try:
+ responses = self.model.chat(
+ [
+ {"role": "system", "content": reflect_prompt},
+ {"role": "user", "content": "Here is last attempt trajectory log: \n\n" + formatted_trajectory + "\n\nPlease output in the specified JSON format."}
+ ],
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_reflect_tokens,
+ )
+ reflection_text = responses[0].response_text.strip()
+
+ # Find first '{' and last '}'
+ first_brace = reflection_text.find('{')
+ last_brace = reflection_text.rfind('}')
+
+ if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
+ json_str = reflection_text[first_brace:last_brace + 1]
+ else:
+ json_str = reflection_text
+
+ reflection_data = json.loads(json_str)
+ return reflection_data, reflection_text, responses[0]
+
+ except Exception as e:
+ return None, None, None
+
+ def _adjust_action_mask_for_retry(self, experience: Experience, retry_step: int):
+ """
+ Adjust action_mask in-place to exclude retry prefix from training.
+ Only tokens from retry_step onwards should be trained.
+ """
+ if retry_step <= 0:
+ return
+
+ action_mask = experience.action_mask
+
+ # Find all assistant response regions and mark the first 'retry_step' as non-trainable
+ if torch.any(action_mask == 1):
+ assistant_segments = []
+ in_segment = False
+ segment_start = 0
+
+ for i, mask_val in enumerate(action_mask):
+ if mask_val == 1 and not in_segment:
+ segment_start = i
+ in_segment = True
+ elif mask_val == 0 and in_segment:
+ assistant_segments.append((segment_start, i))
+ in_segment = False
+
+ if in_segment:
+ assistant_segments.append((segment_start, len(action_mask)))
+
+ # Set the first 'retry_step' assistant segments to 0 (non-trainable)
+ for i in range(min(retry_step, len(assistant_segments))):
+ start, end = assistant_segments[i]
+ action_mask[start:end] = 0
+
+ def run(self) -> List[Experience]:
+ """Run the R3L scienceworld workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_sciworld(self)
+
+ # Generate unique task ID
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+
+ env = utils.create_sciworld_environment(self.task_desc)
+ exp_lst = []
+ for i in range(self.n // 2): # Half for rollout, half for reflection + retry
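+ # Each iteration can contribute several experiences: the first rollout, optionally the
+ # reflection itself (when it proves useful), and a retry or guided second rollout.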
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[R3L] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ # Set eid
+ exp.eid.task = str(self.task.task_id) + f"_explore"
+ exp_run_id = len(exp_lst) + self.run_id_base
+ exp.eid.run = exp_run_id
+ exp_lst.append(exp)
+
+ if self.whether_save_data:
+ first_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="first"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_first",
+ experience_data=first_record,
+ data_dir=self.train_dir
+ )
+
+ # Reflect on first attempt
+ reflect_checklist, reflection_text, reflect_exp = self.get_reflect(trajectory)
+ is_valid, is_perfect = utils.validate_reflect_report(reflect_checklist, steps)
+
+ if not is_valid or is_perfect:
+ if reward >= 1.0 and is_perfect and reflect_exp is not None:
+ reflect_exp.reward = 1.0
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ try:
+ retry_env = utils.create_sciworld_environment(self.task_desc)
+ retry_trajectory, retry_reward, retry_done, retry_steps, retry_format_valid = utils.first_rollout(
+ self, retry_env
+ )
+
+ retry_exp = self.model.convert_messages_to_experience(retry_trajectory[:-1])
+ retry_exp.reward = retry_reward
+ retry_exp.metrics = {
+ "success": 1.0 if retry_reward >= 1.0 else 0.0,
+ "steps": retry_steps,
+ "reward": retry_reward,
+ }
+ retry_exp.eid.task = str(self.task.task_id) + f"_explore"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ if self.whether_save_data:
+ retry_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=retry_trajectory,
+ reward=retry_reward,
+ steps=retry_steps,
+ success=retry_reward >= 1.0,
+ attempt_type="retry_after_invalid_reflection"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_retry",
+ experience_data=retry_record,
+ data_dir=self.train_dir
+ )
+ except Exception as e:
+ print(f"Retry rollout after invalid reflection failed: {e}")
+
+ else:
+ guidance_prompt = utils.reflect_report_to_guidance_prompt(reflect_checklist)
+ retry_step = reflect_checklist["analysis"]["retry_strategy"]["retry_step"]
+
+ try:
+ second_env = utils.create_sciworld_environment(self.task_desc)
+ (
+ distill_trajectory,
+ second_trajectory,
+ second_reward,
+ second_done,
+ second_steps,
+ second_format_valid,
+ ) = utils.second_rollout(
+ self, second_env, guidance_prompt, trajectory, retry_step
+ )
+ print(f"[R3L] Second rollout - reward: {second_reward}, steps: {second_steps}, improve: {second_reward > reward}")
+ second_exp = self.model.convert_messages_to_experience(distill_trajectory[:-1])
+
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(second_exp, retry_step)
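+ # Also mask the replayed prefix in the matching first-rollout experience (looked up by run id).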
+ for existing_exp in exp_lst:
+ if existing_exp.eid.run == exp_run_id:
+ self._adjust_action_mask_for_retry(existing_exp, retry_step)
+ break
+
+ second_exp.reward = second_reward
+ second_exp.metrics = {
+ "second_success": 1.0 if second_reward >= 1.0 else 0.0,
+ "second_steps": second_steps,
+ "second_reward": second_reward,
+ "second_improve": 1.0 if second_reward > reward else 0.0,
+ "second_reward_diff": second_reward - reward,
+ }
+ second_exp.eid.task = str(self.task.task_id) + f"_explore"
+ second_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(second_exp)
+
+ if self.whether_save_data:
+ second_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=second_trajectory,
+ reward=second_reward,
+ steps=second_steps,
+ success=second_reward >= 1.0,
+ attempt_type="second",
+ additional_metrics={
+ "first_reward": reward,
+ "improvement": second_reward > reward,
+ "reward_difference": second_reward - reward,
+ "step_difference": second_steps - steps
+ }
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_second",
+ experience_data=second_record,
+ data_dir=self.train_dir
+ )
+
+ if (second_reward > reward and second_reward >= 1.0) or (second_reward >= 1.0 and second_steps < steps):
+ reflect_exp.reward = 1.0
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ retry_exp = self.model.convert_messages_to_experience(second_trajectory[:-1])
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(retry_exp, retry_step)
+
+ retry_exp.reward = 1.0
+ retry_exp.eid.task = str(self.task.task_id) + f"_retry_{i}"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ print("Reflection and retry led to improvement, recording both...")
+ except Exception:
+ pass
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/scienceworld/__init__.py b/trinity/common/workflows/envs/R3L-back/scienceworld/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/common/workflows/envs/R3L-back/scienceworld/grpo_workflow.py b/trinity/common/workflows/envs/R3L-back/scienceworld/grpo_workflow.py
new file mode 100644
index 0000000000..3a9a00fd8e
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/scienceworld/grpo_workflow.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.scienceworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("grpo_baseline_scienceworld_workflow")
+class GRPOBaselineScienceWorldWorkflow(Workflow):
+ """
+ GRPO Baseline Workflow for ScienceWorld environment.
+ Performs simple rollouts without reflection or learning from experience.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 30
+ self.max_tokens = 16384
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.sciworld_system_template = self.jinja_env.get_template("sciworld_system.j2")
+
+ print(
+ f"Initializing GRPOBaselineScienceWorldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.task_desc = task.task_desc or "0"
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the GRPO baseline workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_sciworld(self)
+
+ # Single rollout execution
+ env = utils.create_sciworld_environment(self.task_desc)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[GRPO] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/scienceworld/opmd_workflow.py b/trinity/common/workflows/envs/R3L-back/scienceworld/opmd_workflow.py
new file mode 100644
index 0000000000..d2b0aef8a2
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/scienceworld/opmd_workflow.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.scienceworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("opmd_baseline_scienceworld_workflow")
+class OPMDBaselineScienceWorldWorkflow(Workflow):
+ """
+ OPMD Baseline workflow for ScienceWorld environment.
+ Performs rollouts for offline policy model distillation.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 30
+ self.max_tokens = 16384
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.sciworld_system_template = self.jinja_env.get_template("sciworld_system.j2")
+
+ print(
+ f"Initializing OPMDScienceWorldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.task_desc = task.task_desc or "0"
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the OPMD workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_sciworld(self)
+
+ # Single rollout execution
+ env = utils.create_sciworld_environment(self.task_desc)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[OPMD] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/scienceworld/prompts/reflection.j2 b/trinity/common/workflows/envs/R3L-back/scienceworld/prompts/reflection.j2
new file mode 100644
index 0000000000..4518e035fc
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/scienceworld/prompts/reflection.j2
@@ -0,0 +1,150 @@
+# Metacognitive Analyst AI Prompt
+
+You are a Metacognitive Analyst AI. Your core mission is to analyze a "Trajectory Log" containing a series of thoughts and actions in a SciWorld scientific experiment environment. Your goal is to extract deep insights, summarize lessons learned, and formulate actionable principles for future improvement.
+
+You will receive a trajectory log. Your final output must end with a structurally complete JSON object.
+
+## Your Internal Monologue & Analysis Protocol (MANDATORY)
+
+You will now begin your structured self-interrogation. Your analysis process must first review the trajectory globally before focusing on key points.
+
+### Part 1: Global Review & Analysis
+
+First, you must understand the entire trajectory from a macro perspective, paying particular attention to the feedback from the user and the environment.
+
+**Question 1.1: Conduct a Panoramic Trajectory Analysis**
+Read through the entire trajectory log and summarize in one or two sentences what the overall strategy was and what result it ultimately led to.
+
+**Question 1.2: Identify Key Issues**
+Based on your global understanding, identify the main problems or inefficiencies in the trajectory. What were the key mistakes or missed opportunities? If the execution was flawless, state that there are none.
+
+### Part 2: Deep Analysis of Key Issues
+
+Next, you will conduct this deep analysis if and only if key issues were identified in Part 1.
+
+**Question 2.1: Diagnose the Primary Flaw**
+What was the fundamental nature of the primary flaw? Categorize it into ONE of the following:
+- Strategy Flaw: The overall plan was misguided.
+- Reasoning Flaw: The interpretation of information was incorrect.
+- Execution Flaw: The intent was correct, but the resulting action was clumsy or ineffective.
+- Knowledge Gap: Lacked critical information necessary to solve the problem.
+- Inefficiency: The goal was achieved, but via a redundant or convoluted path.
+- Invalid Format: The response was syntactically incorrect or violated protocol.
+
+**Question 2.2: Uncover the Root Cause**
+Conduct a flexible root cause inquiry to uncover the core flawed assumption or problematic mental model that led to the flaw. Continuously ask "Why?" until the most fundamental cause is revealed.
+
+**Question 2.3: Formulate Better Approach**
+What would have been the optimal overall strategy or approach for this scientific experiment task?
+What series of positive effects would likely have followed from using this better approach?
+
+### Part 3: Synthesis, Verdict, and Lessons Learned
+
+Finally, after completing all the above analysis, you will synthesize your findings and render a final judgment.
+
+**Question 3.1: Formulate a Corrective Principle**
+
+Based on the root cause and better approach identified above, formulate an impactful Corrective Principle.
+
+**CRITICAL REQUIREMENTS for Principle Formulation:**
+
+1. **Context Completeness**: The principle must be self-contained and include ALL necessary context. It should be understandable and applicable without requiring external knowledge of the specific trajectory.
+ - ❌ **BAD**: "Check objects more carefully"
+ - ✅ **GOOD**: "In the SciWorld environment, when conducting experiments, always examine objects before using them to verify their properties match the task requirements"
+
+2. **Domain Specificity**: Clearly specify the environment, system, or context where this principle applies.
+ - Include environment name (SciWorld)
+ - Include relevant constraints or conditions
+
+3. **Causal Chain Awareness**: The principle should consider not just the immediate impact but also downstream consequences.
+ - Consider how the corrective action affects subsequent steps
+ - Anticipate potential cascading effects
+
+4. **Actionable Structure**: The principle should be actionable and clear, typically including:
+ - The specific environment or context
+ - Clear trigger conditions or situations
+ - The recommended action or approach
+ - The reasoning and expected benefits
+
+ **Note**: The exact format can vary based on the nature of the insight. It could be a prescriptive rule ("When X, do Y"), a cautionary guideline ("Avoid X in situation Y"), or a strategic insight ("Prioritize X because Y"). Choose the format that best captures the lesson learned.
+
+5. **Independence Test**: The principle should be meaningful and correct even if read in isolation, without access to the original trajectory.
+
+**Question 3.2: Render the Final Verdict**
+
+Now, and only now, based on your complete analysis, classify the outcome of this task into one of the following more precise categories:
+
+- **OPTIMAL**: Flawlessly and efficiently achieved the goal; a textbook execution.
+- **SUBOPTIMAL_SUCCESS**: Achieved the goal, but with correctable inefficiencies or minor flaws.
+- **PARTIAL**: Made significant progress but did not fully meet the final goal.
+- **INEFFECTIVE**: Fully failed to achieve the primary goal.
+
+**Question 3.3: Determine Retry Strategy**
+
+Based on your analysis, determine the optimal retry strategy:
+- Identify the specific step where the retry should begin
+- Explain why this step was chosen as the restart point
+- Consider whether restarting from the beginning (step 0) or from a specific problematic step would be more beneficial
+
+## Final Output Format (Strictly Adhere to the Unified Schema)
+
+Your final output must strictly contain the following two parts: Part One is your detailed analysis process (in text form), and Part Two is the summary JSON report.
+
+### Part One: Detailed Analysis Process
+
+You must answer all questions from the protocol one by one here, showing your complete chain of thought.
+
+**1. Global Review & Analysis**
+- 1.1 Panoramic Trajectory Analysis: Fill in your macro summary of the trajectory here
+- 1.2 Key Issues Identification: Fill in the identified key issues and the reasoning here
+
+**2. Deep Analysis of Key Issues**
+- 2.1 Primary Flaw Diagnosis: Fill in the flaw's classification here
+- 2.2 Root Cause: Fill in the result of the root cause inquiry here
+- 2.3 Better Approach: Fill in the analysis of the optimal strategy and its expected benefits here
+
+**3. Synthesis, Verdict, and Lessons Learned**
+- 3.1 Corrective Principle: Fill in the corrective principle you formulated here (MUST meet all 5 critical requirements)
+- 3.2 Final Verdict: Fill in the final classification verdict you rendered here
+- 3.3 Retry Strategy: Fill in the retry step and rationale here
+
+### Part Two: Final JSON Report
+
+After completing the detailed analysis above, synthesize all conclusions and populate the following JSON structure (the ```json code fence prefix is mandatory):
+
+```json
+{
+ "outcome_assessment": "OPTIMAL | SUBOPTIMAL_SUCCESS | PARTIAL | INEFFECTIVE",
+ "analysis": {
+ "summary": "Summary of the trajectory's strategy, outcome, and core insight.",
+ "flaw_analysis": {
+ "diagnosis": {
+ "category": "Strategy Flaw | Reasoning Flaw | Execution Flaw | Knowledge Gap | Inefficiency | null",
+ "root_cause": "The core flawed assumption or problematic mental model that was uncovered. Can be null if outcome_assessment is OPTIMAL."
+ },
+ "better_approach": {
+ "strategy": "The optimal overall strategy or approach that should have been used. Can be null if outcome_assessment is OPTIMAL.",
+ "key_differences": "A brief explanation of how this better approach differs from the original approach. Can be null if outcome_assessment is OPTIMAL.",
+ "projected_benefits": "The series of positive effects projected to occur from using the better approach. Can be null if outcome_assessment is OPTIMAL."
+ }
+ },
+ "lessons_learned": {
+ "corrective_principle": "A self-contained, context-complete principle that includes environment specifics, clear trigger conditions, and considers downstream effects. Must be understandable and applicable in isolation. Can be null if outcome_assessment is OPTIMAL.",
+ "revised_action_plan": "The improved action plan based on the corrective principle, considering both immediate and downstream impacts. Can be null if outcome_assessment is OPTIMAL."
+ },
+ "retry_strategy": {
+ "retry_step": "The specific step that should be retried. Can be null if outcome_assessment is OPTIMAL. Range is 0 to N-1, where N is the total number of steps in the trajectory, 0 means restart from beginning.",
+ "retry_rationale": "Explanation of why this step was chosen as restart point"
+ }
+ }
+}
+```
+
+## Quality Check for Corrective Principles
+
+Before finalizing, verify your corrective principle against this checklist:
+- [ ] Can someone unfamiliar with this specific trajectory understand and apply this principle?
+- [ ] Does it specify the exact environment or context where it applies (SciWorld)?
+- [ ] Does it include clear, observable trigger conditions?
+- [ ] Does it consider effects beyond just the immediate next step?
+- [ ] Is it specific enough to be actionable but general enough to be reusable?
diff --git a/trinity/common/workflows/envs/R3L-back/scienceworld/prompts/sciworld_system.j2 b/trinity/common/workflows/envs/R3L-back/scienceworld/prompts/sciworld_system.j2
new file mode 100644
index 0000000000..daa6544dee
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/scienceworld/prompts/sciworld_system.j2
@@ -0,0 +1,42 @@
+You are an agent whose job is to carry out scientific experiments in a virtual text-based environment.
+
+## Response Format:
+You MUST use this exact format for every response. All tags are REQUIRED in sequential order:
+
+<think>your analytical reasoning and thought process</think>
+<action>exactly one specific action command</action>
+
+## Notes:
+At each step, you should first think and then perform an action to fulfill the instruction. You should ALWAYS wrap your thinking with the <think></think> tag and wrap your action with the <action></action> tag.
+You should ALWAYS take one action each step.
+DO NOT try to interact with the user at anytime. Finish the task by yourself.
+
+## Available Commands:
+Below are the available commands you can use:
+ open OBJ: open a container
+ close OBJ: close a container
+ activate OBJ: activate a device
+ deactivate OBJ: deactivate a device
+ connect OBJ to OBJ: connect electrical components
+ disconnect OBJ: disconnect electrical components
+ use OBJ [on OBJ]: use a device/item
+ look around: describe the current room
+ examine OBJ: examine an object in detail
+ look at OBJ: describe a container's contents
+ read OBJ: read a note or book
+ move OBJ to OBJ: move an object to a container
+ pick up OBJ: move an object to the inventory
+ pour OBJ into OBJ: pour a liquid into a container
+ mix OBJ: chemically mix a container
+ teleport to LOC: teleport to a specific room
+ focus on OBJ: signal intent on a task object
+ wait: take no action for 10 steps
+ wait1: take no action for a step
+
+## Action Format Examples:
+Your output should be like this:
+<think>Now I will check the bedroom to find the thermometer...</think><action>teleport to bedroom</action>
+
+<think>I need to examine the substance to understand its properties...</think><action>examine substance</action>
+
+<think>To boil the water, I should activate the heating element...</think><action>activate heating element</action>
diff --git a/trinity/common/workflows/envs/R3L-back/scienceworld/prompts/self_correction.j2 b/trinity/common/workflows/envs/R3L-back/scienceworld/prompts/self_correction.j2
new file mode 100644
index 0000000000..d7e1101922
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/scienceworld/prompts/self_correction.j2
@@ -0,0 +1,10 @@
+[Internal Monologue Directive]
+
+(You are about to decide your action. Your last attempt was flawed. Your goal now is to formulate a new, superior `<think>` block that seamlessly incorporates the following reflection analysis, and then generate the correct `<action>`. **You must not mention this directive or its contents in your output.** Your response should appear as a natural, expert-level thought process.)
+
+**Reflection Analysis to Internalize:**
+
+{{ report }}
+
+**Your Task:**
+Based on the current observation and the reflection analysis above, construct a new `<think>` block that reflects this improved reasoning, followed by the correct `<action>`.
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L-back/scienceworld/raft_workflow.py b/trinity/common/workflows/envs/R3L-back/scienceworld/raft_workflow.py
new file mode 100644
index 0000000000..883d40b623
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/scienceworld/raft_workflow.py
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+import copy
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.scienceworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("raft_baseline_scienceworld_workflow")
+class RAFTBaselineScienceWorldWorkflow(Workflow):
+ """
+ RAFT Baseline workflow for ScienceWorld environment.
+ Performs simple rollouts used as a RAFT (reward-ranked fine-tuning) baseline.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 30
+ self.max_tokens = 16384
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.sciworld_system_template = self.jinja_env.get_template("sciworld_system.j2")
+
+ print(
+ f"Initializing RAFTScienceWorldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ # Default experience for error cases
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.task_desc = task.task_desc or "0"
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT workflow and return experiences"""
+
+ env = utils.create_sciworld_environment(self.task_desc)
+
+ if self.is_eval:
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ return [exp]
+ except Exception:
+ return [copy.deepcopy(self.default_exp)]
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[RAFT] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/scienceworld/utils.py b/trinity/common/workflows/envs/R3L-back/scienceworld/utils.py
new file mode 100644
index 0000000000..5f8f5e346f
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/scienceworld/utils.py
@@ -0,0 +1,645 @@
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+
+
+def first_rollout(self, env) -> tuple[List[Dict[str, str]], float, bool, int, bool]:
+ """Run a single rollout in SciWorld environment"""
+ observation, info = env.reset()
+ observation = (
+ "Task Description: " + str(env.get_task_description()) + "\n" + observation
+ )
+
+ trajectory = []
+ action_history = [] # Track last actions for repetition detection
+
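+ # sciworld_system.j2 (added below) defines the <think>/<action> response protocol and the
+ # list of commands the agent may issue.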
+ system_prompt = self.sciworld_system_template.render()
+ trajectory.append({"role": "system", "content": system_prompt})
+
+ default_reward = 0.0
+ final_reward = 0.0
+ current_reward = 0.0
+ valid_format = True
+ step = 0
+ done = False
+
+ for step in range(self.max_env_steps):
+ trajectory.append(
+ {"role": "user", "content": format_observation(observation)}
+ )
+
+ # Get model response
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+ # The tokenizer inside chat truncates the input, so terminate immediately once the length reaches the limit
+ return trajectory, default_reward, False, step + 1, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response components
+ think, action = parse_response(response_text)
+ if action is None:
+ valid_format = False
+ feedback = "Invalid response format, missing valid or tags, please ensure to follow the output format strictly: ... ..."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ return trajectory, default_reward, False, step + 1, valid_format
+
+ # Check for consecutive action repetition
+ action_history.append(action)
+ if len(action_history) > 3:
+ action_history.pop(0)
+
+ # If last 3 actions are the same, terminate with failure
+ if len(action_history) >= 3 and all(
+ action == action_history[0] for action in action_history
+ ):
+ feedback = f"Repeated invalid action {action} multiple times, task failed"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ valid_format = False
+ return trajectory, default_reward, False, step + 1, valid_format
+
+ # Validate and execute action in environment
+ action_valid, error_msg = validate_action(action)
+ if action_valid:
+ observation, reward, done, info = env.step(action)
+ # Keep the best (highest) reward observed so far
+ if reward > current_reward:
+ final_reward = reward
+ current_reward = reward
+ else:
+ observation, reward, done = error_msg, default_reward, False
+
+ if done:
+ break
+
+ # Generate feedback
+ if final_reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {final_reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif final_reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Task completed successfully (reward: {final_reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif final_reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+ f"Task not completed (reward: {final_reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ else:
+ feedback = (
+ f"Task not completed (reward: {final_reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ return trajectory, final_reward, False, step + 1, valid_format
+
+
+def second_rollout(
+ self,
+ env,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, int, bool]:
+ """
+ Performs rollout starting from a specific retry step, reusing previous responses.
+ """
+
+ # Reset environment to start fresh
+ observation, info = env.reset()
+ observation = (
+ "Task Description: " + str(env.get_task_description()) + "\n" + observation
+ )
+ trajectory = []
+ distill_trajectory = []
+ action_history = [] # Track last 3 actions for repetition detection
+
+ # Prepare system prompts
+ original_system_prompt = self.sciworld_system_template.render()
+
+ default_reward = 0.0
+ final_reward = 0.0
+ current_reward = 0.0
+ valid_format = True
+ done = False
+
+ # Copy responses from first trajectory up to retry_step
+ step = 0
+ if retry_step > 0:
+ # Add original system prompt only
+ trajectory.append({"role": "system", "content": original_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ # Replay first trajectory up to retry_step to restore environment state
+ first_step = 0
+ for msg in first_trajectory[1:]: # Skip system message
+ if msg["role"] == "user":
+ # This is an observation - copy it and continue
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+ elif msg["role"] == "assistant":
+ if first_step < retry_step:
+ # Copy the assistant response from first trajectory
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+
+ # Execute the action to restore environment state
+ think, action = parse_response(msg["content"])
+ if think is not None and action is not None:
+ action_valid, error_msg = validate_action(action)
+ if action_valid:
+ observation, reward, done, info = env.step(action)
+ if reward > current_reward:
+ final_reward = reward
+ current_reward = reward
+ action_history.append(action)
+ if len(action_history) > 3:
+ action_history.pop(0)
+ else:
+ # If action becomes invalid during replay, start from beginning
+ retry_step = 0
+ break
+ first_step += 1
+ step = first_step
+
+ if done:
+ # If environment finished during replay, no need to continue
+ return distill_trajectory, trajectory, final_reward, done, step, valid_format
+ else:
+ break
+
+ # Add guidance prompt as a separate system message before retry point
+ guidance_system_msg = {"role": "system", "content": f"# Previous Attempt Analysis & Guidance\n{guidance_prompt}"}
+ trajectory.append(guidance_system_msg)
+ # Don't add guidance to distill_trajectory to keep it clean
+
+ else:
+ # Starting from beginning - add system prompt with guidance
+ merged_system_prompt = f"{original_system_prompt}\n\n# Previous Attempt Analysis & Guidance\n{guidance_prompt}"
+ trajectory.append({"role": "system", "content": merged_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
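+ # Continue the episode from the first un-replayed step (step is 0 when starting fresh).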
+ for step in range(step, self.max_env_steps):
+ trajectory.append(
+ {"role": "user", "content": format_observation(observation)}
+ )
+ distill_trajectory.append(
+ {"role": "user", "content": format_observation(observation)}
+ )
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+ # The tokenizer inside chat truncates the input, so terminate immediately once the length reaches the limit
+ return distill_trajectory, trajectory, default_reward, False, step + 1, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response
+ think, action = parse_response(response_text)
+ if think is None or action is None:
+ valid_format = False
+ feedback = "Invalid response format, missing valid or tags, please ensure to follow the output format strictly: ... ..."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ return distill_trajectory, trajectory, default_reward, False, step + 1, valid_format
+
+ # Check for consecutive action repetition
+ action_history.append(action)
+ if len(action_history) > 3:
+ action_history.pop(0)
+
+ # If last 3 actions are the same, terminate with failure
+ if len(action_history) >= 3 and all(
+ action == action_history[0] for action in action_history
+ ):
+ feedback = f"Repeated invalid action {action} multiple times, task failed"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ valid_format = False
+ return distill_trajectory, trajectory, default_reward, False, step + 1, valid_format
+
+ # Validate and execute action in environment
+ action_valid, error_msg = validate_action(action)
+ if action_valid:
+ observation, reward, done, info = env.step(action)
+ if reward > current_reward:
+ final_reward = reward
+ current_reward = reward
+ else:
+ observation, reward, done = error_msg, default_reward, False
+
+ if done:
+ break
+
+ # Generate feedback
+ if final_reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {final_reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif final_reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Task completed successfully (reward: {final_reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif final_reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+ f"Task not completed (reward: {final_reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ else:
+ feedback = (
+ f"Task not completed (reward: {final_reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+
+ return distill_trajectory, trajectory, final_reward, False, step + 1, valid_format
+
+
+def eval_sciworld(self) -> List[Experience]:
+ """Evaluate a single sciworld trajectory"""
+ try:
+ env = create_sciworld_environment(self.task_desc)
+ trajectory, reward, done, steps, valid_format = first_rollout(
+ self, env
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ print(f"[Eval] First rollout - reward: {reward}, steps: {steps}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="evaluation"
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval",
+ experience_data=eval_record,
+ data_dir=self.eval_dir
+ )
+ except Exception as e:
+ print(f"[Eval] Evaluation failed - Error: {str(e)}")
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ }
+ )
+ exp.reward = 0.0
+ return [exp]
+
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def format_observation(observation: str):
+ """Format observation for SciWorld environment"""
+ return "Observation: \n" + observation
+
+
+def parse_response(response):
+ """Parse all three components from response with a single regex"""
+ think, action = None, None
+ try:
+ # Use a single regex to extract both components at once
+ pattern = r"<think>\s*(.*?)\s*</think>.*?<action>\s*(.*?)\s*</action>"
+ match = re.search(pattern, response, re.DOTALL)
+
+ if match:
+ think, action = match.group(1).strip(), match.group(2).strip()
+ except Exception:
+ pass
+ return think, action
+
+
+def validate_action(action):
+ """
+ Validate action format for SciWorld environment.
+ SciWorld actions don't need validation against available_actions like WebShop.
+ We just check if the action is non-empty.
+ """
+ if not action or not action.strip():
+ return False, "Action cannot be empty"
+
+ # SciWorld accepts any non-empty action string
+ # The environment itself will handle invalid actions
+ return True, ""
+
+
+def create_sciworld_environment(task_desc):
+ """Create sciworld environment"""
+ try:
+ from scienceworld import ScienceWorldEnv
+
+ # Parse task_desc to get task name and variation
+ # Format: "task_name-variation_number" or just "task_name"
+ if '-' in task_desc:
+ parts = task_desc.split('-')
+ task_name = parts[0]
+ variation = int(parts[1]) if len(parts) > 1 else 0
+ else:
+ task_name = task_desc
+ variation = 0
+
+ env = ScienceWorldEnv(task_name, serverPath="")
+ env.load(task_name, variation, generateGoldPath=True)
+
+ return env
+ except ImportError as e:
+ raise ImportError(
+ f"Failed to import scienceworld dependencies: {e}. "
+ "Please install scienceworld following the instructions at https://github.com/allenai/ScienceWorld"
+ )
+
+
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+ Correctly formats the trajectory for reflection, including the system prompt
+ and numbering the user/assistant turns.
+ """
+ formatted_lines = []
+ # Use a counter to track user/assistant interaction turns
+ turn_counter = 0 # counting starts at 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+ # The system prompt does not count as a step, but must be shown first as the rules
+ formatted_lines.append(f"**System Rules & Context:**\n{content}\n" + "=" * 30)
+ elif role == "user":
+ # Mark the start of a new turn
+ formatted_lines.append(f"\n**Step {turn_counter}**")
+ formatted_lines.append(f" - User Observation/Feedback:\n {content.strip()}")
+ elif role == "assistant":
+ # The assistant's thought and action
+ formatted_lines.append(f" - Assistant Thought & Action:\n {content.strip()}")
+ # A full user-assistant turn has completed, so increment the turn counter
+ turn_counter += 1
+
+ return "\n".join(formatted_lines)
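+
+# Illustrative output shape (a sketch): for a trajectory of
+#   [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]
+# the rendered text looks roughly like:
+#   **System Rules & Context:** ... ==============================
+#   **Step 0**
+#    - User Observation/Feedback: ...
+#    - Assistant Thought & Action: ...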
+
+
+def validate_reflect_report(report: Dict[str, Any], max_steps: int = None) -> tuple[bool, bool]:
+ """
+ Validates the structure and content of the reflection report
+ based on the new reflection.j2 schema.
+
+ Args:
+ report: The reflection report to validate
+ max_steps: Maximum number of steps in trajectory for retry_step bounds checking (optional)
+
+ Returns:
+ tuple[bool, bool]: (is_valid, is_perfect)
+ - is_valid: Whether the report structure is valid
+ - is_perfect: Whether the report indicates the trajectory is perfect (only meaningful if is_valid is True)
+ """
+ if (
+ not isinstance(report, dict)
+ or "outcome_assessment" not in report
+ or "analysis" not in report
+ ):
+ print("Validation failed: Report is not a dict or missing top-level keys.")
+ return False, False
+
+ outcome = report["outcome_assessment"]
+ analysis = report["analysis"]
+
+ # Check for required top-level analysis keys
+ if "summary" not in analysis:
+ print("Validation failed: Missing 'summary' in analysis.")
+ return False, False
+
+ if outcome == "OPTIMAL":
+ # For OPTIMAL, we only need summary and no flaw analysis
+ print("OPTIMAL report validation successful.")
+ return True, True
+
+ elif outcome in ["SUBOPTIMAL_SUCCESS", "PARTIAL", "INEFFECTIVE"]:
+ # For non-optimal outcomes, validate flaw_analysis structure
+ flaw_analysis = analysis.get("flaw_analysis", {})
+
+ # Validate diagnosis
+ diagnosis = flaw_analysis.get("diagnosis", {})
+ valid_categories = [
+ "Strategy Flaw",
+ "Reasoning Flaw",
+ "Execution Flaw",
+ "Knowledge Gap",
+ "Inefficiency"
+ ]
+ if diagnosis.get("category") not in valid_categories and diagnosis.get("category") != "null":
+ print(f"Validation failed: Invalid 'category'. Got: {diagnosis.get('category')}")
+ return False, False
+
+ # Validate better_approach
+ better_approach = flaw_analysis.get("better_approach", {})
+ required_better_approach_keys = ["strategy", "key_differences", "projected_benefits"]
+ for key in required_better_approach_keys:
+ if key not in better_approach:
+ print(f"Validation failed: Missing '{key}' in better_approach. Got: {better_approach}")
+ return False, False
+
+ # Validate lessons_learned
+ lessons_learned = analysis.get("lessons_learned", {})
+ if not (
+ "corrective_principle" in lessons_learned
+ and "revised_action_plan" in lessons_learned
+ ):
+ print(f"Validation failed: Invalid 'lessons_learned'. Got: {lessons_learned}")
+ return False, False
+
+ # Validate retry_strategy
+ retry_strategy = analysis.get("retry_strategy", {})
+ if not retry_strategy:
+ print("Validation failed: Missing 'retry_strategy' in analysis.")
+ return False, False
+
+ # Validate retry_step
+ if "retry_step" not in retry_strategy:
+ print("Validation failed: Missing 'retry_step' in retry_strategy.")
+ return False, False
+
+ retry_step = retry_strategy["retry_step"]
+ if retry_step is not None:
+ try:
+ retry_step = int(retry_step)
+ except (ValueError, TypeError):
+ print(f"Validation failed: 'retry_step' must be an integer or null. Got: {retry_step}")
+ return False, False
+ if not isinstance(retry_step, int) or retry_step < 0:
+ print(f"Validation failed: 'retry_step' must be a non-negative integer or null. Got: {retry_step}")
+ return False, False
+
+ # Check trajectory bounds if max_steps is provided
+ if max_steps is not None:
+ if retry_step >= max_steps:
+ print(
+ f"Validation failed: 'retry_step' ({retry_step}) exceeds trajectory bounds (0 to {max_steps - 1}).")
+ return False, False
+
+ # Validate retry_rationale
+ if "retry_rationale" not in retry_strategy:
+ print("Validation failed: Missing 'retry_rationale' in retry_strategy.")
+ return False, False
+
+ print(f"{outcome} report validation successful.")
+ return True, False
+
+ else:
+ print(f"Validation failed: Unknown 'outcome_assessment': {outcome}")
+ return False, False
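+
+# Minimal report that passes validate_reflect_report (an illustrative sketch; free-text fields abbreviated):
+# {
+#   "outcome_assessment": "SUBOPTIMAL_SUCCESS",
+#   "analysis": {
+#     "summary": "...",
+#     "flaw_analysis": {
+#       "diagnosis": {"category": "Inefficiency", "root_cause": "..."},
+#       "better_approach": {"strategy": "...", "key_differences": "...", "projected_benefits": "..."}
+#     },
+#     "lessons_learned": {"corrective_principle": "...", "revised_action_plan": "..."},
+#     "retry_strategy": {"retry_step": 0, "retry_rationale": "..."}
+#   }
+# }
+# -> validate_reflect_report(report, max_steps=6) returns (True, False)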
+
+
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Converts a validated reflection report into a structured, actionable
+ guidance prompt for the agent's second attempt. This prompt is framed
+ as an internal directive to ensure the model's output is clean for SFT.
+ """
+ # Convert the report dictionary to a formatted string
+ report_str = json.dumps(report, indent=2, ensure_ascii=False)
+
+ # Load and render template
+ jinja_env = _get_jinja_env()
+ template = jinja_env.get_template("self_correction.j2")
+
+ return template.render(report=report_str)
+
+
+def save_experience_data(
+ task_id: str,
+ experience_data: Dict,
+ data_dir: str
+) -> str:
+ """
+ Save experience data including trajectory, rewards, and steps to a JSON file.
+
+ Args:
+ task_id: Unique identifier for the task
+ experience_data: Dictionary containing experience information
+ data_dir: Directory to save the data
+
+ Returns:
+ Path to the saved file
+ """
+ os.makedirs(data_dir, exist_ok=True)
+
+ # Build the output path from the task id (one JSON file per task)
+ filename = f"{task_id}.json"
+ filepath = os.path.join(data_dir, filename)
+
+ # Ensure experience_data is JSON serializable
+ serializable_data = {}
+ for key, value in experience_data.items():
+ if isinstance(value, torch.Tensor):
+ serializable_data[key] = value.tolist()
+ elif hasattr(value, '__dict__'):
+ # For complex objects, convert to dict representation
+ serializable_data[key] = str(value)
+ else:
+ serializable_data[key] = value
+
+ # Add metadata
+ serializable_data["saved_at"] = datetime.now().isoformat()
+ serializable_data["task_id"] = task_id
+
+ try:
+ with open(filepath, 'w', encoding='utf-8') as f:
+ json.dump(serializable_data, f, indent=2, ensure_ascii=False)
+ # print(f"Experience data saved to: {filepath}")
+ return filepath
+ except Exception as e:
+ # print(f"Failed to save experience data: {e}")
+ return ""
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ steps: int,
+ success: bool,
+ attempt_type: str = "first",
+ reflection_data: Optional[Dict] = None,
+ additional_metrics: Optional[Dict] = None
+) -> Dict:
+ """
+ Create a structured experience record for saving.
+
+ Args:
+ task_id: Unique identifier for the task
+ trajectory: List of conversation messages
+ reward: Final reward received
+ steps: Number of steps taken
+ success: Whether the task was completed successfully
+ attempt_type: Type of attempt ("first", "second", "reflect")
+ reflection_data: Optional reflection analysis data
+ additional_metrics: Additional metrics to record
+
+ Returns:
+ Dictionary containing structured experience data
+ """
+ experience_record = {
+ "task_id": task_id,
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "metrics": {
+ "reward": reward,
+ "steps": steps,
+ "success": success,
+ "trajectory_length": len(trajectory)
+ },
+ "created_at": datetime.now().isoformat()
+ }
+
+ if reflection_data:
+ experience_record["reflection"] = reflection_data
+
+ if additional_metrics:
+ experience_record["metrics"].update(additional_metrics)
+
+ return experience_record
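+
+# Illustrative usage (a sketch; the values below are placeholders):
+#   record = create_experience_record(
+#       task_id="batch0_task1",
+#       trajectory=[{"role": "system", "content": "..."}],
+#       reward=0.8, steps=6, success=False, attempt_type="first",
+#   )
+#   save_experience_data(task_id="batch0_task1_first", experience_data=record, data_dir="./R3L_data/train")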
diff --git a/trinity/common/workflows/envs/R3L-back/webshop/R3L_workflow.py b/trinity/common/workflows/envs/R3L-back/webshop/R3L_workflow.py
new file mode 100644
index 0000000000..ddb5b0c711
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/webshop/R3L_workflow.py
@@ -0,0 +1,411 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.webshop import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("R3L_webshop_workflow")
+class R3LWebshopWorkflow(Workflow):
+ """
+ R3L workflow for webshop
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 15
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+
+ self.whether_save_data = False
+ # Create data directories
+ self.data_dir = "R3L_webshop_data"
+ self.eval_dir = os.path.join(self.data_dir, "eval")
+ self.train_dir = os.path.join(self.data_dir, "train")
+
+ os.makedirs(self.eval_dir, exist_ok=True)
+ os.makedirs(self.train_dir, exist_ok=True)
+
+ # Initialize WebShop environment
+ try:
+ import sys
+ # sys.path.append("/nas/shiweijie/trinity/webshop")
+ sys.path.append("/home/wshiah/code/shiweijie/weijie/trinity/webshop")
+ # Try gymnasium first, fallback to gym
+ import gym
+ from web_agent_site.envs import WebAgentTextEnv # noqa: F401
+
+ # NOTE: Hosting the env requires ~15GB CPU memory.
+ # If you want an easier env, you can set num_products to 1000 or 100000.
+ self.env = gym.make(
+ "WebAgentTextEnv-v0",
+ observation_mode="text_rich",
+ num_products=None,
+ human_goals=True,
+ )
+ except Exception as e:
+ error_message = f"Error importing WebAgentTextEnv {str(e)}. Please make sure you have installed the web_agent_site package, following the instructions in https://github.com/princeton-nlp/WebShop"
+ raise ImportError(error_message)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.webshop_system_template = self.jinja_env.get_template("webshop_system.j2")
+ self.reflection_template = self.jinja_env.get_template("reflection.j2")
+
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": -0.1,
+ },
+ reward=-0.1 # Default minimum reward for webshop tasks
+ )
+
+ self.default_second_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "second_success": 0.0,
+ "second_reward": -0.1,
+ },
+ reward=-0.1
+ )
+
+ print(
+ f"Initializing R3LWebshopWorkflow with experience learning, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.session_id = int(task.task_desc or "0")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def get_reflect(self, trajectory: List[Dict[str, str]]) -> tuple[
+ Optional[Dict[str, Any]], Optional[str], Optional[Any]]:
+ """
+ Generates a comprehensive reflection report using a single, unified self-interrogation prompt.
+ The model first assesses its own performance and then follows the appropriate reflection path.
+ """
+ # print("Generating reflection report using the unified self-interrogation prompt...")
+
+ # Format the trajectory so the LLM can read it
+ formatted_trajectory = utils.format_trajectory_for_reflection(trajectory)
+
+ # Render the reflection prompt from the Jinja2 template
+ reflect_prompt = self.reflection_template.render()
+
+ # Call the model and parse the result
+ try:
+ responses = self.model.chat(
+ [
+ {"role": "system", "content": reflect_prompt},
+ {"role": "user", "content": "Here is last attempt trajectory log: \n\n" + formatted_trajectory + "\n\nPlease output in the specified JSON format."}
+ ],
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_reflect_tokens,
+ )
+ reflection_text = responses[0].response_text.strip()
+
+ # print(f"raw reflection text: {reflection_text}")
+
+ # Find first '{' and last '}'
+ first_brace = reflection_text.find('{')
+ last_brace = reflection_text.rfind('}')
+
+ if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
+ json_str = reflection_text[first_brace:last_brace + 1]
+ else:
+ json_str = reflection_text
+
+ reflection_data = json.loads(json_str)
+ return reflection_data, reflection_text, responses[0]
+
+ except Exception as e:
+ # print(f"Failed during unified reflection process: {e}")
+ return None, None, None
+
+ def _adjust_action_mask_for_retry(self, experience: Experience, retry_step: int):
+ """
+ Adjust action_mask in-place to exclude retry prefix from training.
+ Only tokens from retry_step onwards should be trained.
+
+ Args:
+ experience: The experience object with action_mask to adjust
+ retry_step: The step from which training should start
+ """
+ if retry_step <= 0:
+ return
+
+ # Note: experience.action_mask already excludes prompt tokens
+ action_mask = experience.action_mask
+
+ # Find all assistant response regions and mark the first 'retry_step' as non-trainable
+ if torch.any(action_mask == 1):
+ # Find all segments where action_mask == 1 (assistant responses)
+ assistant_segments = []
+ in_segment = False
+ segment_start = 0
+
+ for i, mask_val in enumerate(action_mask):
+ if mask_val == 1 and not in_segment:
+ # Start of a new segment
+ segment_start = i
+ in_segment = True
+ elif mask_val == 0 and in_segment:
+ # End of current segment
+ assistant_segments.append((segment_start, i))
+ in_segment = False
+
+ # Handle case where sequence ends with assistant response
+ if in_segment:
+ assistant_segments.append((segment_start, len(action_mask)))
+
+ # Set the first 'retry_step' assistant segments to 0 (non-trainable)
+ for i in range(min(retry_step, len(assistant_segments))):
+ start, end = assistant_segments[i]
+ action_mask[start:end] = 0
+
+ def run(self) -> List[Experience]:
+ """Run the experience learning webshop workflow and return experiences"""
+
+ if self.is_eval:
+ # print("pass evaluation mode")
+ # return [opmd_reflect_enhanced_restart_utils.generate_default_experience()]
+ return self.eval_webshop()
+
+ # Generate unique task ID using timestamp
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+
+ exp_lst = []
+ for i in range(self.n // 2): # half of the repeats go to plain rollouts; the other half adds reflection plus a second rollout on top of them
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, self.env, self.session_id
+ )
+ print(f"[R3L] First rollout - reward: {reward}, steps: {steps}")
+
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ # exp.info = {"valid": format_valid}
+ # print(exp.info)
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ # Set the experience id
+ exp.eid.task = str(self.task.task_id) + "_explore"
+ exp_run_id = len(exp_lst) + self.run_id_base
+ exp.eid.run = exp_run_id
+ exp_lst.append(exp)
+
+ if self.whether_save_data:
+ # Save first attempt experience data
+ first_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="first"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_first",
+ experience_data=first_record,
+ data_dir=self.train_dir
+ )
+
+ # Reflect on the first attempt
+ reflect_checklist, reflection_text, reflect_exp = self.get_reflect(trajectory)
+ is_valid, is_perfect = utils.validate_reflect_report(reflect_checklist, steps)
+
+ if not is_valid or is_perfect:
+ # print("Reflect report is invalid or indicates perfection, skipping second rollout")
+ # If the first attempt's reward is 1.0 and the reflection says it was perfect, record the reflection exp
+ if reward >= 1.0 and is_perfect and reflect_exp is not None:
+ reflect_exp.reward = 1.0
+ # Set the experience id
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Run one more rollout so the whole batch still has enough data
+ try:
+ retry_trajectory, retry_reward, retry_done, retry_steps, retry_format_valid = utils.first_rollout(
+ self, self.env, self.session_id
+ )
+
+ retry_exp = self.model.convert_messages_to_experience(retry_trajectory[:-1])
+ retry_exp.reward = retry_reward
+ retry_exp.metrics = {
+ "success": 1.0 if retry_reward >= 1.0 else 0.0,
+ "steps": retry_steps,
+ "reward": retry_reward,
+ }
+ # Set the experience id
+ retry_exp.eid.task = str(self.task.task_id) + "_explore"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ if self.whether_save_data:
+ # Save retry attempt experience data
+ retry_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=retry_trajectory,
+ reward=retry_reward,
+ steps=retry_steps,
+ success=retry_reward >= 1.0,
+ attempt_type="retry_after_invalid_reflection"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_retry",
+ experience_data=retry_record,
+ data_dir=self.train_dir
+ )
+ except Exception as e:
+ print(f"Retry rollout after invalid reflection failed: {e}")
+
+ else:
+ guidance_prompt = utils.reflect_report_to_guidance_prompt(reflect_checklist)
+ # Extract retry_step from validated reflection report
+ retry_step = reflect_checklist["analysis"]["retry_strategy"]["retry_step"]
+
+ try:
+ (
+ distill_trajectory,
+ second_trajectory,
+ second_reward,
+ second_done,
+ second_steps,
+ second_format_valid,
+ ) = utils.second_rollout(
+ self, self.env, self.session_id, guidance_prompt, trajectory, retry_step
+ )
+ print(f"[R3L] Second rollout - reward: {second_reward}, steps: {second_steps}, improve: {second_reward > reward}")
+ second_exp = self.model.convert_messages_to_experience(distill_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(second_exp, retry_step)
+ # Also adjust first rollout exp for fair comparison
+ # Find and modify the exp that was already added to exp_lst
+ for existing_exp in exp_lst:
+ if existing_exp.eid.run == exp_run_id:
+ self._adjust_action_mask_for_retry(existing_exp, retry_step)
+ break
+
+ second_exp.reward = second_reward
+ # second_exp.info = {"valid": second_format_valid}
+ second_exp.metrics = {
+ "second_success": 1.0 if second_reward >= 1.0 else 0.0,
+ "second_steps": second_steps,
+ "second_reward": second_reward,
+ "second_improve": 1.0 if second_reward > reward else 0.0,
+ "second_reward_diff": second_reward - reward,
+ }
+ # Set the experience id
+ second_exp.eid.task = str(self.task.task_id) + "_explore"
+ second_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(second_exp)
+
+ if self.whether_save_data:
+ # Save second attempt experience data
+ second_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=second_trajectory,
+ reward=second_reward,
+ steps=second_steps,
+ success=second_reward >= 1.0,
+ attempt_type="second",
+ additional_metrics={
+ "first_reward": reward,
+ "improvement": second_reward > reward,
+ "reward_difference": second_reward - reward,
+ "step_difference": second_steps - steps
+ }
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_second",
+ experience_data=second_record,
+ data_dir=self.train_dir
+ )
+
+ # If the second attempt scores higher than the first, or reaches full score in fewer steps, record both the reflection and retry data
+ if (second_reward > reward and second_reward >= 1.0) or (second_reward >= 1.0 and second_steps < steps):
+ # Convert the reflection data into an exp
+ # reflect_exp.reward = second_reward - reward
+ reflect_exp.reward = 1.0
+ # Set the experience id
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Convert the retry data into an exp
+ retry_exp = self.model.convert_messages_to_experience(second_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(retry_exp, retry_step)
+
+ # retry_exp.reward = second_reward - reward
+ retry_exp.reward = 1.0
+ # Set the experience id
+ retry_exp.eid.task = str(self.task.task_id) + f"_retry_{i}"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ # print
+ print("Reflection and retry led to improvement, recording both...")
+ except Exception:
+ pass
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/webshop/__init__.py b/trinity/common/workflows/envs/R3L-back/webshop/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/common/workflows/envs/R3L-back/webshop/grpo_workflow.py b/trinity/common/workflows/envs/R3L-back/webshop/grpo_workflow.py
new file mode 100644
index 0000000000..d1dec54c50
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/webshop/grpo_workflow.py
@@ -0,0 +1,121 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.webshop import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("grpo_baseline_webshop_workflow")
+class GRPOBaselineWebshopWorkflow(Workflow):
+ """
+ GRPO Baseline Workflow for WebShop environment.
+ Performs simple rollouts without reflection or learning from experience.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 15
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize WebShop environment
+ try:
+ import sys
+ sys.path.append("/home/wshiah/code/shiweijie/weijie/trinity/webshop")
+ import gym
+ from web_agent_site.envs import WebAgentTextEnv # noqa: F401
+
+ # NOTE: Hosting the env requires ~15GB CPU memory.
+ # If you want easier env, you can set the num_products to 1000 or 100000.
+ self.env = gym.make(
+ "WebAgentTextEnv-v0",
+ observation_mode="text_rich",
+ num_products=None,
+ human_goals=True,
+ )
+ except Exception as e:
+ error_message = (
+ f"Error importing WebAgentTextEnv {str(e)}. "
+ f"Please make sure you have installed the web_agent_site package, "
+ f"following the instructions in https://github.com/princeton-nlp/WebShop"
+ )
+ raise ImportError(error_message)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.webshop_system_template = self.jinja_env.get_template("webshop_system.j2")
+
+ print(
+ f"Initializing GRPOBaselineWebshopWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.session_id = int(task.task_desc or "0")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the GRPO baseline workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_webshop(self)
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, self.env, self.session_id
+ )
+ print(f"[GRPO] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/webshop/opmd_workflow.py b/trinity/common/workflows/envs/R3L-back/webshop/opmd_workflow.py
new file mode 100644
index 0000000000..b27cb2da0f
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/webshop/opmd_workflow.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.webshop import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("opmd_baseline_webshop_workflow")
+class OPMDBaselineWebshopWorkflow(Workflow):
+ """
+ OPMD Baseline workflow for WebShop environment.
+ Performs plain rollouts for the OPMD baseline.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 15
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize WebShop environment
+ try:
+ import sys
+ sys.path.append("/home/wshiah/code/shiweijie/weijie/trinity/webshop")
+ import gym
+ from web_agent_site.envs import WebAgentTextEnv # noqa: F401
+
+ # NOTE: Hosting the env requires ~15GB CPU memory.
+ # If you want easier env, you can set the num_products to 1000 or 100000.
+ self.env = gym.make(
+ "WebAgentTextEnv-v0",
+ observation_mode="text_rich",
+ num_products=None,
+ human_goals=True,
+ )
+ except Exception as e:
+ error_message = (
+ f"Error importing WebAgentTextEnv {str(e)}. "
+ f"Please make sure you have installed the web_agent_site package, "
+ f"following the instructions in https://github.com/princeton-nlp/WebShop"
+ )
+ raise ImportError(error_message)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.webshop_system_template = self.jinja_env.get_template("webshop_system.j2")
+
+ print(
+ f"Initializing OPMDWebshopWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.session_id = int(task.task_desc or "0")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the OPMD workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_webshop(self)
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, self.env, self.session_id
+ )
+ print(f"[OPMD] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/webshop/prompts/reflection.j2 b/trinity/common/workflows/envs/R3L-back/webshop/prompts/reflection.j2
new file mode 100644
index 0000000000..996d4bd34f
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/webshop/prompts/reflection.j2
@@ -0,0 +1,142 @@
+# Metacognitive Analyst AI Prompt
+
+You are a Metacognitive Analyst AI. Your core mission is to analyze a "Trajectory Log" containing a series of thoughts and actions. Your goal is to extract deep insights, summarize lessons learned, and formulate actionable principles for future improvement.
+
+You will receive a trajectory log. Your final output must be a structurally complete JSON object.
+
+## Your Internal Monologue & Analysis Protocol (MANDATORY)
+
+You will now begin your structured self-interrogation. Your analysis process must first review the trajectory globally before focusing on key points.
+
+### Part 1: Global Review & Analysis
+
+First, you must understand the entire trajectory from a macro perspective, especially the feedback from the user and the environment.
+
+**Question 1.1: Conduct a Panoramic Trajectory Analysis**
+Read through the entire trajectory log and summarize in one or two sentences what the overall strategy was and what result it ultimately led to.
+
+**Question 1.2: Identify Key Issues**
+Based on your global understanding, identify the main problems or inefficiencies in the trajectory. What were the key mistakes or missed opportunities? If the execution was flawless, state that there are none.
+
+### Part 2: Deep Analysis of Key Issues
+
+Next, you will conduct this deep analysis if and only if key issues were identified in Part 1.
+
+**Question 2.1: Diagnose the Primary Flaw**
+What was the fundamental nature of the primary flaw? Categorize it into ONE of the following:
+- Strategy Flaw: The overall plan was misguided.
+- Reasoning Flaw: The interpretation of information was incorrect.
+- Execution Flaw: The intent was correct, but the resulting action was clumsy or ineffective.
+- Knowledge Gap: Lacked critical information necessary to solve the problem.
+- Inefficiency: The goal was achieved, but via a redundant or convoluted path.
+- Invalid Format: The response was syntactically incorrect or violated protocol.
+
+**Question 2.2: Uncover the Root Cause**
+Conduct a flexible root cause inquiry to uncover the core flawed assumption or problematic mental model that led to the flaw. Continuously ask "Why?" until the most fundamental cause is revealed.
+
+**Question 2.3: Formulate Better Approach**
+What would have been the optimal overall strategy or approach for this task?
+What series of positive effects would likely have followed from using this better approach?
+
+### Part 3: Synthesis, Verdict, and Lessons Learned
+
+Finally, after completing all the above analysis, you will synthesize your findings and render a final judgment.
+
+**Question 3.1: Formulate a Corrective Principle**
+
+Based on the root cause and better approach identified above, formulate an impactful Corrective Principle.
+
+**CRITICAL REQUIREMENTS for Principle Formulation:**
+
+1. **Context Completeness**: The principle must be self-contained and include ALL necessary context. It should be understandable and applicable without requiring external knowledge of the specific trajectory.
+ - ❌ **BAD**: "Click operations tend to cause failures"
+ - ✅ **GOOD**: "In the xxx environment, when click operations are not available in the action space, attempting to execute click will cause failures"
+
+2. **Domain Specificity**: Clearly specify the environment, system, or context where this principle applies.
+ - Include environment name
+ - Include relevant constraints or conditions
+
+3. **Causal Chain Awareness**: The principle should consider not just the immediate impact but also downstream consequences.
+ - Consider how the corrective action affects subsequent steps
+ - Anticipate potential cascading effects
+
+4. **Actionable Structure**: The principle should be actionable and clear, typically including:
+ - The specific environment or context
+ - Clear trigger conditions or situations
+ - The recommended action or approach
+ - The reasoning and expected benefits
+
+ **Note**: The exact format can vary based on the nature of the insight. It could be a prescriptive rule ("When X, do Y"), a cautionary guideline ("Avoid X in situation Y"), or a strategic insight ("Prioritize X because Y"). Choose the format that best captures the lesson learned.
+
+5. **Independence Test**: The principle should be meaningful and correct even if read in isolation, without access to the original trajectory.
+
+**Question 3.2: Render the Final Verdict**
+
+Now, and only now, based on your complete analysis, classify the outcome of this task into one of the following more precise categories:
+
+- **OPTIMAL**: Flawlessly and efficiently achieved the goal; a textbook execution.
+- **SUBOPTIMAL_SUCCESS**: Achieved the goal, but with correctable inefficiencies or minor flaws.
+- **PARTIAL**: Made significant progress but did not fully meet the final goal.
+- **INEFFECTIVE**: Fully failed to achieve the primary goal.
+
+## Final Output Format (Strictly Adhere to the Unified Schema)
+
+Your final output must strictly contain the following two parts: Part One is your detailed analysis process (in text form), and Part Two is the summary JSON report.
+
+### Part One: Detailed Analysis Process
+
+You must answer all questions from the protocol one by one here, showing your complete chain of thought.
+
+**1. Global Review & Analysis**
+- 1.1 Panoramic Trajectory Analysis: Fill in your macro summary of the trajectory here
+- 1.2 Key Issues Identification: Fill in the identified key issues and the reasoning here
+
+**2. Deep Analysis of Key Issues**
+- 2.1 Primary Flaw Diagnosis: Fill in the flaw's classification here
+- 2.2 Root Cause: Fill in the result of the root cause inquiry here
+- 2.3 Better Approach: Fill in the analysis of the optimal strategy and its expected benefits here
+
+**3. Synthesis, Verdict, and Lessons Learned**
+- 3.1 Corrective Principle: Fill in the corrective principle you formulated here (MUST meet all 5 critical requirements)
+- 3.2 Final Verdict: Fill in the final classification verdict you rendered here
+
+### Part Two: Final JSON Report
+
+After completing the detailed analysis above, synthesize all conclusions and populate the following JSON structure (the ```json fence is mandatory as the prefix of the JSON report):
+
+```json
+{
+ "outcome_assessment": "OPTIMAL | SUBOPTIMAL_SUCCESS | PARTIAL | INEFFECTIVE",
+ "analysis": {
+ "summary": "Summary of the trajectory's strategy, outcome, and core insight.",
+ "flaw_analysis": {
+ "diagnosis": {
+ "category": "Strategy Flaw | Reasoning Flaw | Execution Flaw | Knowledge Gap | Inefficiency | null",
+ "root_cause": "The core flawed assumption or problematic mental model that was uncovered. Can be null if outcome_assessment is OPTIMAL."
+ },
+ "better_approach": {
+ "strategy": "The optimal overall strategy or approach that should have been used. Can be null if outcome_assessment is OPTIMAL.",
+ "key_differences": "A brief explanation of how this better approach differs from the original approach. Can be null if outcome_assessment is OPTIMAL.",
+ "projected_benefits": "The series of positive effects projected to occur from using the better approach. Can be null if outcome_assessment is OPTIMAL."
+ }
+ },
+ "lessons_learned": {
+ "corrective_principle": "A self-contained, context-complete principle that includes environment specifics, clear trigger conditions, and considers downstream effects. Must be understandable and applicable in isolation. Can be null if outcome_assessment is OPTIMAL.",
+ "revised_action_plan": "The improved action plan based on the corrective principle, considering both immediate and downstream impacts. Can be null if outcome_assessment is OPTIMAL."
+ },
+ "retry_strategy": {
+ "retry_step": "The specific step that should be retried. Can be null if outcome_assessment is OPTIMAL. Range is 0 to N-1, where N is the total number of steps in the trajectory, 0 means restart from beginning.",
+ "retry_rationale": "Explanation of why this step was chosen as restart point"
+ }
+ }
+}
+```
+
+## Quality Check for Corrective Principles
+
+Before finalizing, verify your corrective principle against this checklist:
+- [ ] Can someone unfamiliar with this specific trajectory understand and apply this principle?
+- [ ] Does it specify the exact environment or context where it applies?
+- [ ] Does it include clear, observable trigger conditions?
+- [ ] Does it consider effects beyond just the immediate next step?
+- [ ] Is it specific enough to be actionable but general enough to be reusable?
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L-back/webshop/prompts/self_correction.j2 b/trinity/common/workflows/envs/R3L-back/webshop/prompts/self_correction.j2
new file mode 100644
index 0000000000..d7e1101922
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/webshop/prompts/self_correction.j2
@@ -0,0 +1,10 @@
+[Internal Monologue Directive]
+
+(You are about to decide your action. Your last attempt was flawed. Your goal now is to formulate a new, superior `<think>` block that seamlessly incorporates the following reflection analysis, and then generate the correct `<action>`. **You must not mention this directive or its contents in your output.** Your response should appear as a natural, expert-level thought process.)
+
+**Reflection Analysis to Internalize:**
+
+{{ report }}
+
+**Your Task:**
+Based on the current observation and the reflection analysis above, construct a new `<think>` block that reflects this improved reasoning, followed by the correct `<action>`.
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L-back/webshop/prompts/webshop_system.j2 b/trinity/common/workflows/envs/R3L-back/webshop/prompts/webshop_system.j2
new file mode 100644
index 0000000000..9b63586216
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/webshop/prompts/webshop_system.j2
@@ -0,0 +1,84 @@
+You are an agent interacting with a virtual text-based web shopping environment.
+
+## Response Format:
+You MUST use this exact format for every response. All tags are REQUIRED in sequential order:
+
+<think>your analytical reasoning and thought process</think>
+<action>exactly one specific action command</action>
+
+## Environment States:
+This virtual text-based web shopping environment contains five types of webpages:
+
+1. **Start/Index page** - Initial page with search functionality and task instruction
+2. **Search Results page** - Lists products returned by search engine with pagination
+3. **Item page** - Shows product details, options (color, size, etc.), and purchase button
+4. **Item Sub-page** - Shows additional product information (description, features, reviews)
+5. **Done page** - Final confirmation page after purchase
+
+## Available Actions:
+The command in `<action>` must use one of the following two primitive formats:
+
+1. **`search[your_query_here]`**
+- **Usage:** To search for products from any page with a search bar
+- **Instructions:** Replace with specific search terms (can be multi-word)
+- **Example:** `search[blue cotton t-shirt medium]`
+
+2. **`click[exact_button_text_here]`**
+- **Usage:** To click on any clickable element (buttons, product links, options)
+- **Instructions:** Use the exact text as shown in observation (case-insensitive)
+- **Examples:**
+- `click[Buy Now]`
+- `click[Next >]`
+- `click[Size: Large]`
+- `click[Color: Red]`
+
+## Complete State Transition Table:
+
+| Current State | Action Type | Argument | Next State | Notes |
+|---------------|-------------|----------|------------|-------|
+| Start/Index | search | [Query] | Search Results | Initial search from homepage |
+| Search Results | search | [Query] | Search Results | New search resets results |
+| Search Results | click | [Product Title/ASIN] | Item Page | Select specific product |
+| Search Results | click | Next > | Search Results | Next page of results |
+| Search Results | click | < Prev | Search Results | Previous page of results |
+| Item Page | click | [Option Value] | Item Page | Select size/color/etc. (radio buttons) |
+| Item Page | click | Description | Item Sub-page | View product description |
+| Item Page | click | Features | Item Sub-page | View product features |
+| Item Page | click | Reviews | Item Sub-page | View product reviews |
+| Item Page | click | Buy Now | Done Page | **Purchase and end episode** |
+| Item Page | click | < Back to Search | Search Results | Return to search results |
+| Item Page | click | < Prev | Search Results | Return to search results |
+| Item Sub-page | click | < Prev | Item Page | Return to main product page |
+| Any Page | search | [Query] | Search Results | Start new search |
+
+## Key Implementation Details:
+
+### Clickable Elements:
+- **Buttons:** `[button] Text [button_]` → use `click[Text]`
+- **Product Links:** Product titles/ASINs → use `click[Product Name]`
+- **Options:** Radio buttons for size, color, etc. → use `click[Option Value]`
+- **Navigation:** `< Prev`, `Next >`, `< Back to Search`
+
+### Page Identification:
+You can identify the current page type by observing:
+- **Start page:** Contains initial instruction and search bar
+- **Search Results:** Lists multiple products with pagination controls
+- **Item Page:** Shows single product with options and "Buy Now" button
+- **Item Sub-page:** Shows detailed info without "Buy Now" button
+- **Done Page:** Shows purchase confirmation
+
+### Important Navigation Rules:
+1. **From Item Sub-pages:** You MUST click `< Prev` to return to Item Page before purchasing
+2. **Option Selection:** Selecting options (size, color) stays on the same Item Page
+3. **Search Resets:** Using search from any page starts a new product search
+4. **Purchase Requirement:** You can only purchase from the Item Page, not sub-pages
+
+## Task Completion:
+- **Goal:** Find and purchase an item matching the given instruction within 15 steps
+- **Success:** Episode ends when you click "Buy Now" with appropriate product and options
+
+## Observation Format:
+- Clickable elements appear as: `[button] Text [button_]`
+- Selected options may show as: `[clicked button] Text [clicked button_]`
+- Regular text appears without special formatting
+- The instruction text shows your shopping goal
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L-back/webshop/raft_workflow.py b/trinity/common/workflows/envs/R3L-back/webshop/raft_workflow.py
new file mode 100644
index 0000000000..9488a8156a
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/webshop/raft_workflow.py
@@ -0,0 +1,148 @@
+# -*- coding: utf-8 -*-
+import copy
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.webshop import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("RAFT_baseline_webshop_workflow")
+class RAFTBaselineWebshopWorkflow(Workflow):
+ """
+ RAFT Baseline workflow for WebShop environment.
+ Performs rollouts for RAFT (Reward rAnked FineTuning).
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 15
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize WebShop environment
+ try:
+ import sys
+ sys.path.append("/home/wshiah/code/shiweijie/weijie/trinity/webshop")
+ import gym
+ from web_agent_site.envs import WebAgentTextEnv # noqa: F401
+
+ # NOTE: Hosting the env requires ~15GB CPU memory.
+ # If you want easier env, you can set the num_products to 1000 or 100000.
+ self.env = gym.make(
+ "WebAgentTextEnv-v0",
+ observation_mode="text_rich",
+ num_products=None,
+ human_goals=True,
+ )
+ except Exception as e:
+ error_message = (
+ f"Error importing WebAgentTextEnv {str(e)}. "
+ f"Please make sure you have installed the web_agent_site package, "
+ f"following the instructions in https://github.com/princeton-nlp/WebShop"
+ )
+ raise ImportError(error_message)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.webshop_system_template = self.jinja_env.get_template("webshop_system.j2")
+
+ print(
+ f"Initializing RAFTWebshopWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ # Default experience for error cases
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": -0.1,
+ },
+ reward=-0.1
+ )
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.session_id = int(task.task_desc or "0")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT workflow and return experiences"""
+
+ if self.is_eval:
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, self.env, self.session_id
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ return [exp]
+ except Exception:
+ return [copy.deepcopy(self.default_exp)]
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, self.env, self.session_id
+ )
+ print(f"[RAFT] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L-back/webshop/utils.py b/trinity/common/workflows/envs/R3L-back/webshop/utils.py
new file mode 100644
index 0000000000..7c20d6cdb9
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L-back/webshop/utils.py
@@ -0,0 +1,658 @@
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+import torch
+from trinity.common.experience import Experience
+
+
+def first_rollout(self, env, session_id) -> tuple[List[Dict[str, str]], float, bool, int, bool]:
+ """Run a single rollout"""
+ # print(f"About to reset env with session_id: {session_id}")
+ env.reset(session=session_id)
+ observation = env.observation
+ trajectory = []
+ action_history = [] # Track the last 2 actions for repetition detection
+
+ system_prompt = self.webshop_system_template.render()
+ trajectory.append({"role": "system", "content": system_prompt})
+
+ default_reward = -0.1
+ reward = default_reward
+ valid_format = True
+ step = 0
+
+ for step in range(self.max_env_steps):
+ available_actions = env.get_available_actions()
+ trajectory.append(
+ {"role": "user", "content": format_observation(observation, available_actions)}
+ )
+
+ # Get model response with experience guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+ # The tokenizer inside chat truncates the input, so terminate as soon as the length reaches the limit
+ return trajectory, default_reward, False, step + 1, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the think/action components for action execution
+ think, action = parse_response(response_text)
+ if action is None:
+ valid_format = False
+ feedback = "Invalid response format, missing valid <think> or <action> tags, please make sure to follow the output format strictly: <think>...</think> <action>...</action>"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ # print(f"Terminating due to invalid response format: {response_text}")
+ return trajectory, default_reward, False, step + 1, valid_format
+
+ # Check for consecutive action repetition
+ action_history.append(action)
+ if len(action_history) > 2:
+ action_history.pop(0)
+
+ # If the last 2 actions are identical (and not navigation/search), terminate with failure
+ if len(action_history) >= 2 and all(
+ a == action_history[0] for a in action_history
+ ) and "next" not in action.lower() and "prev" not in action.lower() and "search" not in action.lower():
+ feedback = f"Repeated invalid action {action} multiple times, shopping task failed"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ # print(f"Terminating due to consecutive identical actions: {action}")
+ valid_format = False
+ return trajectory, default_reward, False, step + 1, valid_format
+
+ # Validate and execute action in environment
+ action_valid, error_msg = validate_action(action, available_actions)
+ if action_valid:
+ observation, reward, done, info = env.step(action)
+ else:
+ observation, reward, done = error_msg, default_reward, False
+
+ if done:
+ break
+
+ # Generate final feedback on the task outcome and step usage
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Shopping task completed successfully (reward: {reward}/1.0), within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Shopping task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+ f"Shopping task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps). It may not satisfy the Attribute Matching, Option Matching, or Price Matching requirements, please carefully check and ensure all requirements are satisfied."
+ )
+ else:
+ feedback = (
+ f"Shopping task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps). It may not satisfy the Attribute Matching, Option Matching, or Price Matching requirements, please carefully check and ensure all requirements are satisfied."
+ )
+
+ # Add timeout feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ return trajectory, reward, False, step + 1, valid_format
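+
+# Return shape (illustrative): (trajectory, reward, done, steps, valid_format), e.g.
+#   ([{"role": "system", ...}, {"role": "user", ...}, ...], 0.75, False, 6, True)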
+
+def second_rollout(
+ self,
+ env,
+ session_id: int,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, int, bool]:
+ """
+ Performs rollout starting from a specific retry step, reusing previous responses.
+
+ Args:
+ env: The environment instance.
+ session_id: The ID for the current task session.
+ guidance_prompt: The pre-generated guidance from reflection.
+ first_trajectory: The full log of the initial attempt.
+ retry_step: The step to start retry from (0-based, 0 means from beginning).
+
+ Returns:
+ A tuple containing (distill_trajectory, second_trajectory, reward, done status,
+ step count, and format validity).
+ """
+
+ # Reset environment to start fresh
+ env.reset(session=session_id)
+ observation = env.observation
+ trajectory = []
+ distill_trajectory = []
+ action_history = [] # Track the last 2 actions for repetition detection
+
+ # Prepare system prompts
+ original_system_prompt = self.webshop_system_template.render()
+
+ default_reward = -0.1
+ reward = default_reward
+ done = False # initialized here so the replay loop below can check it even if a replayed response fails to parse
+ valid_format = True
+
+ # Copy responses from first trajectory up to retry_step
+ step = 0
+ if retry_step > 0:
+ # Add original system prompt only
+ trajectory.append({"role": "system", "content": original_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ # Replay first trajectory up to retry_step to restore environment state
+ first_step = 0
+ for msg in first_trajectory[1:]: # Skip system message
+ if msg["role"] == "user":
+ # This is an observation - copy it and continue
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+ elif msg["role"] == "assistant":
+ if first_step < retry_step:
+ # Copy the assistant response from first trajectory
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+
+ # Execute the action to restore environment state
+ think, action = parse_response(msg["content"])
+ if think is not None and action is not None:
+ action_valid, error_msg = validate_action(action, env.get_available_actions())
+ if action_valid:
+ observation, reward, done, info = env.step(action)
+ action_history.append(action)
+ if len(action_history) > 2:
+ action_history.pop(0)
+ else:
+ # If action becomes invalid during replay, start from beginning
+ retry_step = 0
+ break
+ first_step += 1
+ step = first_step
+
+ if done:
+ # If environment finished during replay, no need to continue
+ return distill_trajectory, trajectory, reward, done, step, valid_format
+ else:
+ break
+
+ # Add guidance prompt as a separate system message before retry point
+ guidance_system_msg = {"role": "system", "content": f"# Previous Attempt Analysis & Guidance\n{guidance_prompt}"}
+ trajectory.append(guidance_system_msg)
+ # Don't add guidance to distill_trajectory to keep it clean
+
+ else:
+ # Starting from beginning - add system prompt with guidance
+ merged_system_prompt = f"{original_system_prompt}\n\n# Previous Attempt Analysis & Guidance\n{guidance_prompt}"
+ trajectory.append({"role": "system", "content": merged_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ for step in range(step, self.max_env_steps):
+ available_actions = env.get_available_actions()
+ trajectory.append(
+ {"role": "user", "content": format_observation(observation, available_actions)}
+ )
+ distill_trajectory.append(
+ {"role": "user", "content": format_observation(observation, available_actions)}
+ )
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+ # The tokenizer inside chat truncates the input, so terminate as soon as the length reaches the limit
+ return distill_trajectory, trajectory, default_reward, False, step + 1, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response
+ think, action = parse_response(response_text)
+ if think is None or action is None:
+ valid_format = False
+ feedback = "Invalid response format, missing valid <think> or <action> tags, please make sure to follow the output format strictly: <think>...</think> <action>...</action>"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ return distill_trajectory, trajectory, default_reward, False, step + 1, valid_format
+
+ # Check for consecutive action repetition
+ action_history.append(action)
+ if len(action_history) > 2:
+ action_history.pop(0)
+
+ # If the last 2 actions are identical (and not navigation/search), terminate with failure
+ if len(action_history) >= 2 and all(
+ a == action_history[0] for a in action_history
+ ) and "next" not in action.lower() and "prev" not in action.lower() and "search" not in action.lower():
+ feedback = f"Repeated invalid action {action} multiple times, shopping task failed"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ valid_format = False
+ return distill_trajectory, trajectory, default_reward, False, step + 1, valid_format
+
+ # Validate and execute action in environment
+ action_valid, error_msg = validate_action(action, available_actions)
+ if action_valid:
+ observation, reward, done, info = env.step(action)
+ else:
+ observation, reward, done = error_msg, default_reward, False
+
+ if done:
+ break
+
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Shopping task completed successfully (reward: {reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Shopping task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+                f"Shopping task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps). It may not satisfy the Attribute Matching, Option Matching, or Price Matching requirements; please carefully check and ensure all requirements are satisfied."
+ )
+ else:
+ feedback = (
+                f"Shopping task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps). It may not satisfy the Attribute Matching, Option Matching, or Price Matching requirements; please carefully check and ensure all requirements are satisfied."
+ )
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+
+    # Return both the clean distill_trajectory (guidance message omitted) and the full retry trajectory
+ return distill_trajectory, trajectory, reward, False, step + 1, valid_format
+
+def eval_webshop(self) -> List[Experience]:
+ """Evaluate a single webshop trajectory"""
+ try:
+ trajectory, reward, done, steps, valid_format = first_rollout(
+ self, self.env, self.session_id
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ print(f"[WebShop Eval] Rollout - reward: {reward}, steps: {steps}, valid_format: {valid_format}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="evaluation"
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval",
+ experience_data=eval_record,
+ data_dir=self.eval_dir
+ )
+ except Exception as e:
+ # logger.warning(f"Single rollout failed during eval: {e}")
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": -0.1,
+ }
+ )
+ exp.reward = -0.1
+ return [exp]
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def format_observation(observation: str, available_actions: dict):
+ return f"Environment Observation: {observation} \n Available Actions: {available_actions}"
+
+
+def parse_response(response):
+    """Parse the think and action components from the response with a single regex"""
+ think, action = None, None
+ try:
+        # Extract the <think> and <action> contents in one pass
+        pattern = r"<think>\s*(.*?)\s*</think>.*?<action>\s*(.*?)\s*</action>"
+ match = re.search(pattern, response, re.DOTALL)
+
+ if match:
+ think, action = match.group(1).strip(), match.group(2).strip()
+ except Exception:
+ pass
+ return think, action
+
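+# Illustrative example (assuming responses use <think>...</think><action>...</action> tags):
+#   parse_response("<think>Search for the item first.</think><action>search[red dress]</action>")
+#   -> ("Search for the item first.", "search[red dress]")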
+
+def validate_action(action, available_actions):
+ """Validate action format and availability"""
+ import re
+
+ # Parse action format: action_name[action_arg]
+ pattern = re.compile(r"(.+)\[(.+)\]")
+ m = re.match(pattern, action)
+ if m is None:
+ return (
+ False,
+ "Invalid action format. You should use format: action_name[action_arg], like search[query] or click[button].",
+ )
+
+ action_name, action_arg = m.groups()
+ action_name = action_name.strip()
+ action_arg = action_arg.strip()
+
+ # Validate search action
+ if action_name == "search":
+ if not action_arg:
+ return (
+ False,
+ "Invalid search action, please type in the query you want to search in the square brackets.",
+ )
+ if not available_actions["has_search_bar"]:
+ return (
+ False,
+ "Cannot perform search action without search bar. Please click the Back to Search button first.",
+ )
+ return True, ""
+
+ # Validate click action
+ elif action_name == "click":
+ if not action_arg:
+ return (
+ False,
+ "Invalid click action, please specify the button name in the square brackets.",
+ )
+ # Convert to lowercase for comparison (as clickables are typically lowercase)
+ action_arg_lower = action_arg.lower()
+ if action_arg_lower not in available_actions["clickables"]:
+ return (
+ False,
+ f"Button '{action_arg}' not found on current page. Available buttons: {available_actions['clickables']}",
+ )
+ return True, ""
+
+ # Unknown action
+ else:
+ return (
+ False,
+ f"Unknown action '{action_name}'. Only 'search' and 'click' actions are supported.",
+ )
+
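+# Illustrative usage (hypothetical available_actions dict, matching the checks above):
+#   actions = {"has_search_bar": True, "clickables": ["buy now", "back to search"]}
+#   validate_action("search[red dress]", actions)  -> (True, "")
+#   validate_action("click[checkout]", actions)    -> (False, "Button 'checkout' not found on current page. ...")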
+
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+ Correctly formats the trajectory for reflection, including the system prompt
+ and numbering the user/assistant turns.
+ """
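+    # Illustrative output shape (content abbreviated):
+    #   **System Rules & Context:**
+    #   <system prompt text>
+    #   ==============================
+    #   **Step 0**
+    #     - User Observation/Feedback: ...
+    #     - Assistant Thought & Action: ...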
+ formatted_lines = []
+    # Use a counter to track user/assistant interaction turns
+    turn_counter = 0  # Count turns starting from 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+            # The system prompt does not count as a step, but is shown first as the rules
+ formatted_lines.append(f"**System Rules & Context:**\n{content}\n" + "=" * 30)
+ elif role == "user":
+            # Mark the start of a new turn
+ formatted_lines.append(f"\n**Step {turn_counter}**")
+ formatted_lines.append(f" - User Observation/Feedback:\n {content.strip()}")
+ elif role == "assistant":
+            # The assistant's thought and action
+ formatted_lines.append(f" - Assistant Thought & Action:\n {content.strip()}")
+            # Increment the turn counter after a complete user-assistant exchange
+ turn_counter += 1
+
+ return "\n".join(formatted_lines)
+
+
+def validate_reflect_report(report: Dict[str, Any], max_steps: int = None) -> tuple[bool, bool]:
+ """
+ Validates the structure and content of the reflection report
+ based on the new reflection.j2 schema.
+
+ Args:
+ report: The reflection report to validate
+ max_steps: Maximum number of steps in trajectory for retry_step bounds checking (optional)
+
+ Returns:
+ tuple[bool, bool]: (is_valid, is_perfect)
+ - is_valid: Whether the report structure is valid
+ - is_perfect: Whether the report indicates the trajectory is perfect (only meaningful if is_valid is True)
+ """
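+    # Minimal report sketch that passes validation for a non-OPTIMAL outcome
+    # (illustrative values only; field names follow the checks below):
+    #   {
+    #     "outcome_assessment": "PARTIAL",
+    #     "analysis": {
+    #       "summary": "...",
+    #       "flaw_analysis": {"diagnosis": {"category": "Strategy Flaw"},
+    #                         "better_approach": {"strategy": "...", "key_differences": "...",
+    #                                             "projected_benefits": "..."}},
+    #       "lessons_learned": {"corrective_principle": "...", "revised_action_plan": "..."},
+    #       "retry_strategy": {"retry_step": 0, "retry_rationale": "..."}
+    #     }
+    #   }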
+ if (
+ not isinstance(report, dict)
+ or "outcome_assessment" not in report
+ or "analysis" not in report
+ ):
+ print("Validation failed: Report is not a dict or missing top-level keys.")
+ return False, False
+
+ outcome = report["outcome_assessment"]
+ analysis = report["analysis"]
+
+ # Check for required top-level analysis keys
+ if "summary" not in analysis:
+ print("Validation failed: Missing 'summary' in analysis.")
+ return False, False
+
+ if outcome == "OPTIMAL":
+ # For OPTIMAL, we only need summary and no flaw analysis
+ print("OPTIMAL report validation successful.")
+ return True, True
+
+ elif outcome in ["SUBOPTIMAL_SUCCESS", "PARTIAL", "INEFFECTIVE"]:
+ # For non-optimal outcomes, validate flaw_analysis structure
+ flaw_analysis = analysis.get("flaw_analysis", {})
+
+ # Validate diagnosis
+ diagnosis = flaw_analysis.get("diagnosis", {})
+ valid_categories = [
+ "Strategy Flaw",
+ "Reasoning Flaw",
+ "Execution Flaw",
+ "Knowledge Gap",
+ "Inefficiency"
+ ]
+ if diagnosis.get("category") not in valid_categories and diagnosis.get("category") != "null":
+ print(f"Validation failed: Invalid 'category'. Got: {diagnosis.get('category')}")
+ return False, False
+
+ # Validate better_approach
+ better_approach = flaw_analysis.get("better_approach", {})
+ required_better_approach_keys = ["strategy", "key_differences", "projected_benefits"]
+ for key in required_better_approach_keys:
+ if key not in better_approach:
+ print(f"Validation failed: Missing '{key}' in better_approach. Got: {better_approach}")
+ return False, False
+
+ # Validate lessons_learned
+ lessons_learned = analysis.get("lessons_learned", {})
+ if not (
+ "corrective_principle" in lessons_learned
+ and "revised_action_plan" in lessons_learned
+ ):
+ print(f"Validation failed: Invalid 'lessons_learned'. Got: {lessons_learned}")
+ return False, False
+
+ # Validate retry_strategy
+ retry_strategy = analysis.get("retry_strategy", {})
+ if not retry_strategy:
+ print("Validation failed: Missing 'retry_strategy' in analysis.")
+ return False, False
+
+ # Validate retry_step
+ if "retry_step" not in retry_strategy:
+ print("Validation failed: Missing 'retry_step' in retry_strategy.")
+ return False, False
+
+ retry_step = retry_strategy["retry_step"]
+ if retry_step is not None:
+ try:
+ retry_step = int(retry_step)
+ except (ValueError, TypeError):
+ print(f"Validation failed: 'retry_step' must be an integer or null. Got: {retry_step}")
+ return False, False
+ if not isinstance(retry_step, int) or retry_step < 0:
+ print(f"Validation failed: 'retry_step' must be a non-negative integer or null. Got: {retry_step}")
+ return False, False
+
+ # Check trajectory bounds if max_steps is provided
+ if max_steps is not None:
+ if retry_step >= max_steps:
+ print(
+ f"Validation failed: 'retry_step' ({retry_step}) exceeds trajectory bounds (0 to {max_steps - 1}).")
+ return False, False
+
+ # Validate retry_rationale
+ if "retry_rationale" not in retry_strategy:
+ print("Validation failed: Missing 'retry_rationale' in retry_strategy.")
+ return False, False
+
+ print(f"{outcome} report validation successful.")
+ return True, False
+
+ else:
+ print(f"Validation failed: Unknown 'outcome_assessment': {outcome}")
+ return False, False
+
+
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Converts a validated reflection report into a structured, actionable
+ guidance prompt for the agent's second attempt. This prompt is framed
+ as an internal directive to ensure the model's output is clean for SFT.
+ """
+ # Convert the report dictionary to a formatted string
+ report_str = json.dumps(report, indent=2, ensure_ascii=False)
+
+ # Load and render template
+ jinja_env = _get_jinja_env()
+ template = jinja_env.get_template("self_correction.j2")
+
+ return template.render(report=report_str)
+
+
+def save_experience_data(
+ task_id: str,
+ experience_data: Dict,
+ data_dir: str
+) -> str:
+ """
+ Save experience data including trajectory, rewards, and steps to a JSON file.
+
+ Args:
+ task_id: Unique identifier for the task
+ experience_data: Dictionary containing experience information
+ data_dir: Directory to save the data
+
+ Returns:
+ Path to the saved file
+ """
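+    # Illustrative usage (hypothetical paths):
+    #   save_experience_data("batch0_task3_first", record, data_dir="R3L_webshop_data/train")
+    #   would write "R3L_webshop_data/train/batch0_task3_first.json" and return that path,
+    #   or return "" if writing fails.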
+ os.makedirs(data_dir, exist_ok=True)
+
+ # Add timestamp for uniqueness
+    # Build the output file path (one JSON file per task_id)
+ filepath = os.path.join(data_dir, filename)
+
+ # Ensure experience_data is JSON serializable
+ serializable_data = {}
+ for key, value in experience_data.items():
+ if isinstance(value, torch.Tensor):
+ serializable_data[key] = value.tolist()
+ elif hasattr(value, '__dict__'):
+ # For complex objects, convert to dict representation
+ serializable_data[key] = str(value)
+ else:
+ serializable_data[key] = value
+
+ # Add metadata
+ serializable_data["saved_at"] = datetime.now().isoformat()
+ serializable_data["task_id"] = task_id
+
+ try:
+ with open(filepath, 'w', encoding='utf-8') as f:
+ json.dump(serializable_data, f, indent=2, ensure_ascii=False)
+ # print(f"Experience data saved to: {filepath}")
+ return filepath
+ except Exception as e:
+ # print(f"Failed to save experience data: {e}")
+ return ""
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ steps: int,
+ success: bool,
+ attempt_type: str = "first",
+ reflection_data: Optional[Dict] = None,
+ additional_metrics: Optional[Dict] = None
+) -> Dict:
+ """
+ Create a structured experience record for saving.
+
+ Args:
+ task_id: Unique identifier for the task
+ trajectory: List of conversation messages
+ reward: Final reward received
+ steps: Number of steps taken
+ success: Whether the task was completed successfully
+ attempt_type: Type of attempt ("first", "second", "reflect")
+ reflection_data: Optional reflection analysis data
+ additional_metrics: Additional metrics to record
+
+ Returns:
+ Dictionary containing structured experience data
+ """
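+    # Illustrative usage (hypothetical arguments):
+    #   record = create_experience_record("batch0_task3", trajectory, reward=1.0,
+    #                                     steps=12, success=True, attempt_type="first")
+    #   record["metrics"]["trajectory_length"] == len(trajectory)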
+ experience_record = {
+ "task_id": task_id,
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "metrics": {
+ "reward": reward,
+ "steps": steps,
+ "success": success,
+ "trajectory_length": len(trajectory)
+ },
+ "created_at": datetime.now().isoformat()
+ }
+
+ if reflection_data:
+ experience_record["reflection"] = reflection_data
+
+ if additional_metrics:
+ experience_record["metrics"].update(additional_metrics)
+
+ return experience_record
diff --git a/trinity/common/workflows/envs/R3L/__init__.py b/trinity/common/workflows/envs/R3L/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/common/workflows/envs/R3L/alfworld-bak/R3L_workflow.py b/trinity/common/workflows/envs/R3L/alfworld-bak/R3L_workflow.py
new file mode 100644
index 0000000000..a73023f78f
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld-bak/R3L_workflow.py
@@ -0,0 +1,376 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("R3L_alfworld_workflow")
+class R3LAlfworldWorkflow(Workflow):
+ """
+ R3L workflow for alfworld
+ """
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 50
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+
+ self.whether_save_data = False
+ # Create data directories
+ self.data_dir = f"R3L_alfworld_data"
+ self.eval_dir = os.path.join(self.data_dir, "eval")
+ self.train_dir = os.path.join(self.data_dir, "train")
+
+ os.makedirs(self.eval_dir, exist_ok=True)
+ os.makedirs(self.train_dir, exist_ok=True)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+ self.reflection_template = self.jinja_env.get_template("reflection.j2")
+
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ self.default_second_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "second_success": 0.0,
+ "second_reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ print(
+ f"Initializing R3LAlfworldWorkflow with experience learning, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def get_reflect(self, trajectory: List[Dict[str, str]]) -> tuple[
+ Optional[Dict[str, Any]], Optional[str], Optional[Any]]:
+ """
+ Generates a comprehensive reflection report using a single, unified self-interrogation prompt.
+ The model first assesses its own performance and then follows the appropriate reflection path.
+ """
+ # Format trajectory for LLM reading
+ formatted_trajectory = utils.format_trajectory_for_reflection(trajectory)
+
+ # Use Jinja2 template to render reflection prompt
+ reflect_prompt = self.reflection_template.render()
+
+ # Call model and parse results
+ try:
+ responses = self.model.chat(
+ [
+ {"role": "system", "content": reflect_prompt},
+ {"role": "user", "content": "Here is last attempt trajectory log: \n\n" + formatted_trajectory}
+ ],
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_reflect_tokens,
+ )
+ reflection_text = responses[0].response_text.strip()
+
+ # Parse JSON
+ json_match = re.search(r"```json\s*(\{.*?\})\s*```", reflection_text, re.DOTALL)
+ if json_match:
+ json_str = json_match.group(1)
+ else:
+ json_str = reflection_text
+
+ reflection_data = json.loads(json_str)
+ return reflection_data, reflection_text, responses[0]
+
+ except Exception as e:
+ return None, None, None
+
+ def _adjust_action_mask_for_retry(self, experience: Experience, retry_step: int):
+ """
+ Adjust action_mask in-place to exclude retry prefix from training.
+ Only tokens from retry_step onwards should be trained.
+
+ Args:
+ experience: The experience object with action_mask to adjust
+ retry_step: The step from which training should start
+ """
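+        # Illustrative example (assumed shapes): with action_mask = [1,1,0,0,1,1,0,1]
+        # there are three assistant segments -- (0,2), (4,6), (7,8); with retry_step = 2
+        # the first two segments are zeroed, leaving only the final segment trainable.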
+ if retry_step <= 0:
+ return
+
+ # Note: experience.action_mask already excludes prompt tokens
+ action_mask = experience.action_mask
+
+ # Find all assistant response regions and mark the first 'retry_step' as non-trainable
+ if torch.any(action_mask == 1):
+ # Find all segments where action_mask == 1 (assistant responses)
+ assistant_segments = []
+ in_segment = False
+ segment_start = 0
+
+ for i, mask_val in enumerate(action_mask):
+ if mask_val == 1 and not in_segment:
+ # Start of a new segment
+ segment_start = i
+ in_segment = True
+ elif mask_val == 0 and in_segment:
+ # End of current segment
+ assistant_segments.append((segment_start, i))
+ in_segment = False
+
+ # Handle case where sequence ends with assistant response
+ if in_segment:
+ assistant_segments.append((segment_start, len(action_mask)))
+
+ # Set the first 'retry_step' assistant segments to 0 (non-trainable)
+ for i in range(min(retry_step, len(assistant_segments))):
+ start, end = assistant_segments[i]
+ action_mask[start:end] = 0
+
+ def run(self) -> List[Experience]:
+ """Run the R3L alfworld workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Generate unique task ID
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+
+ env = utils.create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ exp_lst = []
+ for i in range(self.n // 2): # Half for rollout, half for reflection + retry
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[R3L] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ # Set eid
+ exp.eid.task = str(self.task.task_id) + f"_explore"
+ exp_run_id = len(exp_lst) + self.run_id_base
+ exp.eid.run = exp_run_id
+ exp_lst.append(exp)
+
+ if self.whether_save_data:
+ # Save first attempt experience data
+ first_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="first"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_first",
+ experience_data=first_record,
+ data_dir=self.train_dir
+ )
+
+ # Reflect on first attempt
+ reflect_checklist, reflection_text, reflect_exp = self.get_reflect(trajectory)
+ is_valid, is_perfect = utils.validate_reflect_report(reflect_checklist, steps)
+
+ if not is_valid or is_perfect:
+ # If first attempt reward is 1.0 and reflection gives perfect, record reflection exp
+ if reward >= 1.0 and is_perfect and reflect_exp is not None:
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ if not is_valid:
+ # Do another rollout to ensure the batch has enough data
+ try:
+ retry_env = utils.create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ retry_trajectory, retry_reward, retry_done, retry_steps, retry_format_valid = utils.first_rollout(
+ self, retry_env
+ )
+
+ retry_exp = self.model.convert_messages_to_experience(retry_trajectory[:-1])
+ retry_exp.reward = retry_reward
+ retry_exp.metrics = {
+ "success": 1.0 if retry_reward >= 1.0 else 0.0,
+ "steps": retry_steps,
+ "reward": retry_reward,
+ }
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + f"_explore"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ if self.whether_save_data:
+ # Save retry attempt experience data
+ retry_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=retry_trajectory,
+ reward=retry_reward,
+ steps=retry_steps,
+ success=retry_reward >= 1.0,
+ attempt_type="retry_after_invalid_reflection"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_retry",
+ experience_data=retry_record,
+ data_dir=self.train_dir
+ )
+ except Exception as e:
+ print(f"Retry rollout after invalid reflection failed: {e}")
+
+ else:
+ guidance_prompt = utils.reflect_report_to_guidance_prompt(reflect_checklist)
+ # Extract retry_step from validated reflection report
+ retry_step = reflect_checklist["analysis"]["retry_strategy"]["retry_step"]
+
+ try:
+ second_env = utils.create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ (
+ distill_trajectory,
+ second_trajectory,
+ second_reward,
+ second_done,
+ second_steps,
+ second_format_valid,
+ ) = utils.second_rollout(
+ self, second_env, guidance_prompt, trajectory, retry_step
+ )
+ print(f"[R3L] Second rollout - reward: {second_reward}, steps: {second_steps}, improve: {second_reward > reward}")
+ second_exp = self.model.convert_messages_to_experience(distill_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(second_exp, retry_step)
+ # Also adjust first rollout exp for fair comparison
+ for existing_exp in exp_lst:
+ if existing_exp.eid.run == exp_run_id:
+ self._adjust_action_mask_for_retry(existing_exp, retry_step)
+ break
+
+ second_exp.reward = second_reward
+ second_exp.metrics = {
+ "second_success": 1.0 if second_reward >= 1.0 else 0.0,
+ "second_steps": second_steps,
+ "second_reward": second_reward,
+ "second_improve": 1.0 if second_reward > reward else 0.0,
+ "second_reward_diff": second_reward - reward,
+ }
+ # Set eid
+ second_exp.eid.task = str(self.task.task_id) + f"_explore"
+ second_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(second_exp)
+
+ if self.whether_save_data:
+ # Save second attempt experience data
+ second_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=second_trajectory,
+ reward=second_reward,
+ steps=second_steps,
+ success=second_reward >= 1.0,
+ attempt_type="second",
+ additional_metrics={
+ "first_reward": reward,
+ "improvement": second_reward > reward,
+ "reward_difference": second_reward - reward,
+ "step_difference": second_steps - steps
+ }
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_second",
+ experience_data=second_record,
+ data_dir=self.train_dir
+ )
+
+ # If second attempt score is higher than first, or second is perfect with fewer steps,
+ # record reflection and retry data
+ if (second_reward > reward and second_reward >= 1.0) or (second_reward >= 1.0 and second_steps < steps):
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Convert retry data to exp
+ retry_exp = self.model.convert_messages_to_experience(second_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(retry_exp, retry_step)
+
+ retry_exp.reward = 1.0
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + f"_retry_{i}"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ print("Reflection and retry led to improvement, recording both...")
+ except Exception:
+ pass
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def resettable(self) -> bool:
+ """Indicate that this workflow can be reset to avoid re-initialization"""
+ return True
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
diff --git a/trinity/common/workflows/envs/R3L/alfworld-bak/__init__.py b/trinity/common/workflows/envs/R3L/alfworld-bak/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/common/workflows/envs/R3L/alfworld-bak/grpo_workflow.py b/trinity/common/workflows/envs/R3L/alfworld-bak/grpo_workflow.py
new file mode 100644
index 0000000000..2de739c0aa
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld-bak/grpo_workflow.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("grpo_baseline_alfworld_workflow")
+class GRPOBaselineAlfworldWorkflow(Workflow):
+ """
+ GRPO Baseline Workflow for Alfworld environment.
+ Performs simple rollouts without reflection or learning from experience.
+ """
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 50
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # History length for sliding window (verl-agent uses 2)
+ self.history_length = 2
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+ f"Initializing GRPOBaselineAlfworldWorkflow, temperature={self.temperature}, history_length={self.history_length}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the GRPO baseline workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Single rollout execution
+ env = utils.create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ # print(f"trajectory: {trajectory}")
+ print(f"[GRPO] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception as e:
+ print(f"[GRPO] Rollout {i} failed with exception: {e}")
+ # exp = Experience(
+ # tokens=torch.tensor([0, 0], dtype=torch.long),
+ # prompt_length=1,
+ # action_mask=torch.tensor([False], dtype=torch.bool),
+ # logprobs=torch.tensor([0.0], dtype=torch.float),
+ # metrics={
+ # "success": 0.0,
+ # "reward": 0.0,
+ # }
+ # )
+ # exp.reward = 0.0
+ return exp_lst
+
+ def resettable(self) -> bool:
+ """Indicate that this workflow can be reset to avoid re-initialization"""
+ return True
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
diff --git a/trinity/common/workflows/envs/R3L/alfworld-bak/opmd_workflow.py b/trinity/common/workflows/envs/R3L/alfworld-bak/opmd_workflow.py
new file mode 100644
index 0000000000..9edb8a7573
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld-bak/opmd_workflow.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("opmd_baseline_alfworld_workflow")
+class OPMDBaselineAlfworldWorkflow(Workflow):
+ """
+ OPMD Baseline workflow for Alfworld environment.
+    Performs rollouts for the OPMD (Online Policy Mirror Descent) baseline.
+ """
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 50
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+ f"Initializing OPMDAlfworldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the OPMD workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Single rollout execution
+ env = utils.create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[OPMD] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def resettable(self) -> bool:
+ """Indicate that this workflow can be reset to avoid re-initialization"""
+ return True
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
diff --git a/trinity/common/workflows/envs/R3L/alfworld-bak/prompts/alfworld_system.j2 b/trinity/common/workflows/envs/R3L/alfworld-bak/prompts/alfworld_system.j2
new file mode 100644
index 0000000000..b0879979b1
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld-bak/prompts/alfworld_system.j2
@@ -0,0 +1,5 @@
+You are an expert agent operating in the ALFRED Embodied Environment.
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for the current step and present it within <action> </action> tags.
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L/alfworld-bak/prompts/reflection.j2 b/trinity/common/workflows/envs/R3L/alfworld-bak/prompts/reflection.j2
new file mode 100644
index 0000000000..87987361b4
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld-bak/prompts/reflection.j2
@@ -0,0 +1,133 @@
+# Metacognitive Analyst AI Prompt
+
+You are a Metacognitive Analyst AI. Your core mission is to analyze a "Trajectory Log" containing a series of thoughts and actions. Your goal is to extract deep insights, summarize lessons learned, and formulate actionable principles for future improvement.
+
+You will receive a trajectory log. Your final output must be a structurally complete JSON object.
+
+## Your Internal Monologue & Analysis Protocol (MANDATORY)
+
+You will now begin your structured self-interrogation. Your analysis process must first review the trajectory globally before focusing on key points.
+
+### Part 1: Global Review & Analysis
+
+First, you must understand the entire trajectory from a macro perspective, especially feedbacks from user and environment.
+
+**Question 1.1: Conduct a Panoramic Trajectory Analysis**
+Read through the entire trajectory log and summarize in one or two sentences what the overall strategy was and what result it ultimately led to.
+
+**Question 1.2: Identify Key Issues**
+Based on your global understanding, identify the main problems or inefficiencies in the trajectory. What were the key mistakes or missed opportunities? If the execution was flawless, this is None.
+
+### Part 2: Deep Analysis of Key Issues
+
+Next, you will conduct this deep analysis if and only if key issues were identified in Part 1.
+
+**Question 2.1: Diagnose the Primary Flaw**
+What was the fundamental nature of the primary flaw? Categorize it into ONE of the following:
+- Strategy Flaw: The overall plan was misguided.
+- Reasoning Flaw: The interpretation of information was incorrect.
+- Execution Flaw: The intent was correct, but the resulting action was clumsy or ineffective.
+- Knowledge Gap: Lacked critical information necessary to solve the problem.
+- Inefficiency: The goal was achieved, but via a redundant or convoluted path.
+- Invalid Format: The response was syntactically incorrect or violated protocol.
+
+**Question 2.2: Uncover the Root Cause**
+Conduct a flexible root cause inquiry to uncover the core flawed assumption or problematic mental model that led to the flaw. Continuously ask "Why?" until the most fundamental cause is revealed.
+
+**Question 2.3: Formulate Better Approach**
+What would have been the optimal overall strategy or approach for this task?
+What series of positive effects would likely have followed from using this better approach?
+
+### Part 3: Synthesis, Verdict, and Lessons Learned
+
+Finally, after completing all the above analysis, you will synthesize your findings and render a final judgment.
+
+**Question 3.1: Formulate a Corrective Principle**
+
+Based on the analysis of the "Leverage Point," formulate an impactful Corrective Principle.
+
+**CRITICAL REQUIREMENTS for Principle Formulation:**
+
+1. **Context Completeness**: The principle must be self-contained and include ALL necessary context. It should be understandable and applicable without requiring external knowledge of the specific trajectory.
+ - ❌ **BAD**: "Click operations tend to cause failures"
+ - ✅ **GOOD**: "In the xxx environment, when click operations are not available in the action space, attempting to execute click will cause failures"
+
+2. **Domain Specificity**: Clearly specify the environment, system, or context where this principle applies.
+ - Include environment name
+ - Include relevant constraints or conditions
+
+3. **Causal Chain Awareness**: The principle should consider not just the immediate impact but also downstream consequences.
+ - Consider how the corrective action affects subsequent steps
+ - Anticipate potential cascading effects
+
+4. **Actionable Structure**: The principle should be actionable and clear, typically including:
+ - The specific environment or context
+ - Clear trigger conditions or situations
+ - The recommended action or approach
+ - The reasoning and expected benefits
+
+ **Note**: The exact format can vary based on the nature of the insight. It could be a prescriptive rule ("When X, do Y"), a cautionary guideline ("Avoid X in situation Y"), or a strategic insight ("Prioritize X because Y"). Choose the format that best captures the lesson learned.
+
+5. **Independence Test**: The principle should be meaningful and correct even if read in isolation, without access to the original trajectory.
+
+**Question 3.2: Render the Final Verdict**
+
+Now, and only now, based on your complete analysis, classify the outcome of this task into one of the following more precise categories:
+
+- **OPTIMAL**: Flawlessly and efficiently achieved the goal; a textbook execution.
+- **SUBOPTIMAL_SUCCESS**: Achieved the goal, but with correctable inefficiencies or minor flaws.
+- **PARTIAL**: Made significant progress but did not fully meet the final goal.
+- **INEFFECTIVE**: Fully failed to achieve the primary goal.
+
+## Final Output Format (Strictly Adhere to the Unified Schema)
+
+Your final output must strictly contain the following two parts: Part One is your detailed analysis process (in text form), and Part Two is the summary JSON report.
+
+### Part One: Detailed Analysis Process
+
+You must answer all questions from the protocol one by one here, showing your complete chain of thought.
+
+**1. Global Review & Analysis**
+- 1.1 Panoramic Trajectory Analysis: Fill in your macro summary of the trajectory here
+- 1.2 Key Issues Identification: Fill in the identified key issues and the reasoning here
+
+**2. Deep Analysis of Key Issues**
+- 2.1 Primary Flaw Diagnosis: Fill in the flaw's classification here
+- 2.2 Root Cause: Fill in the result of the root cause inquiry here
+- 2.3 Better Approach: Fill in the analysis of the optimal strategy and its expected benefits here
+
+**3. Synthesis, Verdict, and Lessons Learned**
+- 3.1 Corrective Principle: Fill in the corrective principle you formulated here (MUST meet all 5 critical requirements)
+- 3.2 Final Verdict: Fill in the final classification verdict you rendered here
+
+### Part Two: Final JSON Report
+
+After completing the detailed analysis above, synthesize all conclusions and populate the following JSON structure (```json is mandatory as JSON prefix):
+
+```json
+{
+ "outcome_assessment": "OPTIMAL | SUBOPTIMAL_SUCCESS | PARTIAL | INEFFECTIVE",
+ "analysis": {
+ "summary": "Summary of the trajectory's strategy, outcome, and core insight.",
+ "flaw_analysis": {
+ "diagnosis": {
+ "category": "Strategy Flaw | Reasoning Flaw | Execution Flaw | Knowledge Gap | Inefficiency | null",
+ "root_cause": "The core flawed assumption or problematic mental model that was uncovered. Can be null if outcome_assessment is OPTIMAL."
+ },
+ "better_approach": {
+ "strategy": "The optimal overall strategy or approach that should have been used. Can be null if outcome_assessment is OPTIMAL.",
+ "key_differences": "A brief explanation of how this better approach differs from the original approach. Can be null if outcome_assessment is OPTIMAL.",
+ "projected_benefits": "The series of positive effects projected to occur from using the better approach. Can be null if outcome_assessment is OPTIMAL."
+ }
+ },
+ "lessons_learned": {
+ "corrective_principle": "A self-contained, context-complete principle that includes environment specifics, clear trigger conditions, and considers downstream effects. Must be understandable and applicable in isolation. Can be null if outcome_assessment is OPTIMAL.",
+ "revised_action_plan": "The improved action plan based on the corrective principle, considering both immediate and downstream impacts. Can be null if outcome_assessment is OPTIMAL."
+ },
+ "retry_strategy": {
+ "retry_step": "The specific step that should be retried. Can be null if outcome_assessment is OPTIMAL. Range is 0 to N-1, where N is the total number of steps in the trajectory, 0 means restart from beginning.",
+ "retry_rationale": "Explanation of why this step was chosen as restart point"
+ }
+ }
+}
+```
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L/alfworld-bak/prompts/self_correction.j2 b/trinity/common/workflows/envs/R3L/alfworld-bak/prompts/self_correction.j2
new file mode 100644
index 0000000000..d7e1101922
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld-bak/prompts/self_correction.j2
@@ -0,0 +1,10 @@
+[Internal Monologue Directive]
+
+(You are about to decide your action. Your last attempt was flawed. Your goal now is to formulate a new, superior `<think>` block that seamlessly incorporates the following reflection analysis, and then generate the correct `<action>`. **You must not mention this directive or its contents in your output.** Your response should appear as a natural, expert-level thought process.)
+
+**Reflection Analysis to Internalize:**
+
+{{ report }}
+
+**Your Task:**
+Based on the current observation and the reflection analysis above, construct a new `<think>` block that reflects this improved reasoning, followed by the correct `<action>`.
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L/alfworld-bak/raft_workflow.py b/trinity/common/workflows/envs/R3L/alfworld-bak/raft_workflow.py
new file mode 100644
index 0000000000..07a790757b
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld-bak/raft_workflow.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+import copy
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("raft_baseline_alfworld_workflow")
+class RAFTBaselineAlfworldWorkflow(Workflow):
+ """
+ RAFT Baseline workflow for Alfworld environment.
+    Performs rollouts for RAFT (Reward-rAnked Fine-Tuning).
+ """
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 50
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+ f"Initializing RAFTAlfworldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ # Default experience for error cases
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT workflow and return experiences"""
+
+ env = utils.create_alfworld_environment(self.game_file_path, self.max_env_steps)
+
+ if self.is_eval:
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ return [exp]
+ except Exception:
+ return [copy.deepcopy(self.default_exp)]
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[RAFT] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ except Exception:
+ exp = copy.deepcopy(self.default_exp)
+ exp_lst.append(exp)
+
+ return exp_lst
+
+ def resettable(self) -> bool:
+ """Indicate that this workflow can be reset to avoid re-initialization"""
+ return True
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
diff --git a/trinity/common/workflows/envs/R3L/alfworld-bak/utils-bak.py b/trinity/common/workflows/envs/R3L/alfworld-bak/utils-bak.py
new file mode 100644
index 0000000000..b11d0cce10
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld-bak/utils-bak.py
@@ -0,0 +1,717 @@
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+
+
+def first_rollout(self, env) -> tuple[List[Dict[str, str]], float, bool, int, bool]:
+ """Run a single rollout in Alfworld environment"""
+ observation, info = env.reset()
+ trajectory = []
+ action_history = [] # Track all actions taken
+
+ # system_prompt = self.alfworld_system_template.render()
+ # trajectory.append({"role": "system", "content": system_prompt})
+
+ default_reward = 0.0
+ reward = default_reward
+ valid_format = True
+ step = 0
+
+ # Extract task description from the initial observation (only once at the beginning)
+ task_description = extract_task_description(observation)
+
+ for step in range(self.max_env_steps):
+ # Extract admissible actions from info if available
+ admissible_actions = info.get("admissible_commands", []) if isinstance(info, dict) else []
+
+ trajectory.append(
+ {
+ "role": "user",
+ "content": format_observation(
+ current_observation=observation,
+ task_description=task_description,
+ current_step=step,
+ action_history=action_history,
+ admissible_actions=admissible_actions,
+ ),
+ }
+ )
+
+ # Get model response
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+            # The tokenizer inside model.chat truncates long inputs, so terminate as soon as the token count reaches the context limit
+            # print(f"[first_rollout] Token count {responses[0].tokens.shape[0]} exceeds the context limit at step {step}")
+ return trajectory, reward, False, step, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response components
+ think, action = parse_response(response_text)
+ if think is None or action is None:
+ valid_format = False
+ # Only add format feedback if not in the last step (to avoid duplicate feedback)
+ if step < self.max_env_steps - 1:
+                feedback = "Invalid response format, missing valid <think> or <action> tags, please ensure to follow the output format strictly: <think> ... </think> <action> ... </action>"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ continue
+
+ # Execute action in environment
+ observation, reward, done, info = env.step(action)
+ # print(f"[first_rollout] Step {step}: Action taken: {action}, Reward: {reward}, Done: {done}, Observation: {observation}, Info: {info.get('admissible_commands')}")
+
+ # Track successfully executed actions for history (only if action was valid)
+ if "Nothing happens." not in observation:
+ action_history.append(action)
+
+ if done:
+ break
+
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ else:
+ feedback = f"Task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ # print(f"[first_rollout] reward: {reward}, steps: {step + 1}, valid_format: {valid_format}")
+ return trajectory, reward, False, step + 1, valid_format
+
+
+def second_rollout(
+ self,
+ env,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, int, bool]:
+ """
+ Performs rollout starting from a specific retry step, reusing previous responses.
+
+ Args:
+ env: The environment instance.
+ guidance_prompt: The pre-generated guidance from reflection.
+ first_trajectory: The full log of the initial attempt.
+ retry_step: The step to start retry from (0-based, 0 means from beginning).
+
+ Returns:
+ A tuple containing (distill_trajectory, second_trajectory, reward, done status,
+ step count, and format validity).
+ """
+
+ # Reset environment to start fresh
+ observation, info = env.reset()
+ trajectory = []
+ distill_trajectory = []
+ action_history = [] # Track all actions taken
+
+ # Prepare system prompts
+ original_system_prompt = self.alfworld_system_template.render()
+
+ default_reward = 0.0
+ reward = default_reward
+ valid_format = True
+
+ # Extract task description from the initial observation (only once at the beginning)
+ task_description = extract_task_description(observation)
+
+ # Copy responses from first trajectory up to retry_step
+ step = 0
+ done = False
+ if retry_step > 0:
+ # Add original system prompt only
+ trajectory.append({"role": "system", "content": original_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ # Replay first trajectory up to retry_step to restore environment state
+ first_step = 0
+ for msg in first_trajectory[1:]: # Skip system message
+ if msg["role"] == "user":
+ # This is an observation - copy it and continue
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+ elif msg["role"] == "assistant":
+ if first_step < retry_step:
+ # Copy the assistant response from first trajectory
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+
+ # Execute the action to restore environment state
+ think, action = parse_response(msg["content"])
+ if think is not None and action is not None:
+ observation, reward, done, info = env.step(action)
+ action_history.append(action)
+ first_step += 1
+ step = first_step
+
+ if done:
+ # If environment finished during replay, no need to continue
+ return distill_trajectory, trajectory, reward, done, step, valid_format
+ else:
+ break
+
+ # Add guidance prompt as a separate system message before retry point
+ guidance_system_msg = {
+ "role": "system",
+ "content": f"# Previous Attempt Analysis & Guidance\n{guidance_prompt}",
+ }
+ trajectory.append(guidance_system_msg)
+ # Don't add guidance to distill_trajectory to keep it clean
+
+ else:
+ # Starting from beginning - add system prompt with guidance
+ merged_system_prompt = (
+ f"{original_system_prompt}\n\n# Previous Attempt Analysis & Guidance\n{guidance_prompt}"
+ )
+ trajectory.append({"role": "system", "content": merged_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ for step in range(step, self.max_env_steps):
+ # Extract admissible actions from info if available
+ admissible_actions = info.get("admissible_commands", []) if isinstance(info, dict) else []
+
+ formatted_obs = format_observation(
+ current_observation=observation,
+ task_description=task_description,
+ current_step=step,
+ action_history=action_history,
+ admissible_actions=admissible_actions,
+ )
+
+ trajectory.append({"role": "user", "content": formatted_obs})
+ distill_trajectory.append({"role": "user", "content": formatted_obs})
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+ print(
+                f"[second_rollout] Token count {responses[0].tokens.shape[0]} exceeds the context limit at step {step}"
+ )
+ return distill_trajectory, trajectory, reward, True, step, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response
+ think, action = parse_response(response_text)
+ if think is None or action is None:
+ valid_format = False
+ # Only add format feedback if not in the last step (to avoid duplicate feedback)
+ if step < self.max_env_steps - 1:
+                feedback = "Invalid response format, missing valid <think> or <action> tags, please ensure to follow the output format strictly: <think> ... </think> <action> ... </action>"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ continue
+
+ # Execute action in environment
+ observation, reward, done, info = env.step(action)
+
+ # Track successfully executed actions for history (only if action was valid)
+ if "Nothing happens." not in observation:
+ action_history.append(action)
+
+ if done:
+ break
+
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ else:
+ feedback = f"Task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+
+ return distill_trajectory, trajectory, reward, False, step + 1, valid_format
+
+
+def eval_alfworld(self) -> List[Experience]:
+ """Evaluate a single alfworld trajectory"""
+ try:
+ env = create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ trajectory, reward, done, steps, valid_format = first_rollout(self, env)
+ print(
+ f"[Eval Alfworld] Trajectory completed with reward: {reward}, steps: {steps}, valid_format: {valid_format}"
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ print(f"[Eval] First rollout - reward: {reward}, steps: {steps}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="evaluation",
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval", experience_data=eval_record, data_dir=self.eval_dir
+ )
+ except Exception as e:
+ # logger.warning(f"Single rollout failed during eval: {e}")
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ )
+ exp.reward = 0.0
+ return [exp]
+
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def extract_task_description(observation: str) -> str:
+ """
+ Extract task description from the initial observation.
+ The task description is typically in the format: "Your task is to: ..."
+
+ Args:
+ observation: Initial observation from environment
+
+ Returns:
+ Extracted task description string
+ """
+ # Look for pattern "Your task is to: "
+ match = re.search(r"Your task is to:\s*(.+?)(?:\n|$)", observation, re.IGNORECASE)
+ if match:
+ return match.group(1).strip()
+ # Fallback: return a portion of the observation
+ return observation.split("\n")[-1] if "\n" in observation else observation
+
+
+def format_observation(
+ current_observation: str,
+ task_description: str = "",
+ current_step: int = 0,
+ action_history: List[str] = None,
+ admissible_actions: List[str] = None,
+ history_length: int = 2,
+):
+ """
+ Format observation string with task context and limited action history.
+
+ Args:
+ current_observation: Current observation from environment
+ task_description: Description of the task to complete
+ current_step: Current step number
+ action_history: List of all previous actions taken
+ admissible_actions: List of currently admissible actions
+        history_length: Maximum number of recent actions to display (default: 2)
+ """
+ if action_history is None:
+ action_history = []
+ if admissible_actions is None:
+ admissible_actions = []
+
+ # Format admissible actions
+ admissible_actions_str = (
+ ", ".join(admissible_actions) if admissible_actions else "All standard actions available"
+ )
+
+ # Check if this is the first step (no history)
+ if current_step == 0 or not action_history:
+ # First step - no history version
+ prompt = f"""You are an expert agent operating in the ALFRED Embodied Environment.
+Your current observation is: {current_observation}
+Your admissible actions of the current situation are: [{admissible_actions_str}].
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for current step and present it within <action> </action> tags."""
+ else:
+ # Limit action history to most recent history_length items
+ recent_actions = (
+ action_history[-history_length:]
+ if len(action_history) > history_length
+ else action_history
+ )
+
+ # Format action history as a structured list with observations
+ action_history_str = "\n".join(
+ [
+ f" Step {current_step - len(recent_actions) + i}: {action}"
+ for i, action in enumerate(recent_actions)
+ ]
+ )
+
+ # Create formatted prompt with limited history
+ prompt = f"""You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description}
+Prior to this step, you have already taken {len(action_history)} step(s). Below are the most recent {len(recent_actions)} actions you took:
+{action_history_str}
+You are now at step {current_step} and your current observation is: {current_observation}
+Your admissible actions of the current situation are: [{admissible_actions_str}].
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for current step and present it within <action> </action> tags."""
+
+ return prompt
+
+
+def parse_response(response):
+ """Parse think and action components from response"""
+ try:
+ # Use regex to extract think and action components
+ think_pattern = r"\s*(.*?)\s*"
+ action_pattern = r"\s*(.*?)\s*"
+
+ think_match = re.search(think_pattern, response, re.DOTALL)
+ action_match = re.search(action_pattern, response, re.DOTALL)
+
+ think = think_match.group(1).strip() if think_match else None
+ action = action_match.group(1).strip() if action_match else None
+
+ return think, action
+ except Exception:
+ return None, None
+
+
+def create_alfworld_environment(game_file, max_episode_steps=50):
+ """
+ Create alfworld environment
+
+ Args:
+ game_file: Path to the game file
+ max_episode_steps: Maximum number of steps per episode (default: 50)
+ """
+ try:
+ import textworld
+ import textworld.gym
+ from alfworld.agents.environment.alfred_tw_env import (
+ AlfredDemangler,
+ AlfredExpert,
+ AlfredExpertType,
+ )
+
+ expert = AlfredExpert(expert_type=AlfredExpertType.HANDCODED)
+ request_infos = textworld.EnvInfos(
+ description=True, inventory=True, admissible_commands=True
+ )
+
+ env_id = textworld.gym.register_game(
+ game_file,
+ request_infos,
+ max_episode_steps=max_episode_steps,
+ asynchronous=True,
+ wrappers=[AlfredDemangler(), expert],
+ )
+ env = textworld.gym.make(env_id)
+
+ return env
+ except ImportError as e:
+ raise ImportError(
+ f"Failed to import alfworld dependencies: {e}. "
+ "Please install alfworld following the instructions at https://github.com/alfworld/alfworld"
+ )
+
+
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+ Correctly formats the trajectory for reflection, including the system prompt
+ and numbering the user/assistant turns.
+ """
+ formatted_lines = []
+    # Use a counter to track user/assistant interaction turns
+    turn_counter = 0  # start counting from 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+            # The system prompt does not count as a step, but must be shown first as the rules
+ formatted_lines.append(f"**System Rules & Context:**\n{content}\n" + "=" * 30)
+ elif role == "user":
+            # Mark the start of a new turn
+ formatted_lines.append(f"\n**Step {turn_counter}**")
+ formatted_lines.append(f" - User Observation/Feedback:\n {content.strip()}")
+ elif role == "assistant":
+            # The assistant's thought and action
+ formatted_lines.append(f" - Assistant Thought & Action:\n {content.strip()}")
+            # Increment the turn counter after a complete user-assistant exchange
+ turn_counter += 1
+
+ return "\n".join(formatted_lines)
+
+
+def validate_reflect_report(report: Dict[str, Any], max_steps: int = None) -> tuple[bool, bool]:
+ """
+ Validates the structure and content of the reflection report
+ based on the new reflection.j2 schema.
+
+ Args:
+ report: The reflection report to validate
+ max_steps: Maximum number of steps in trajectory for retry_step bounds checking (optional)
+
+ Returns:
+ tuple[bool, bool]: (is_valid, is_perfect)
+ - is_valid: Whether the report structure is valid
+ - is_perfect: Whether the report indicates the trajectory is perfect (only meaningful if is_valid is True)
+ """
+ if (
+ not isinstance(report, dict)
+ or "outcome_assessment" not in report
+ or "analysis" not in report
+ ):
+ print("Validation failed: Report is not a dict or missing top-level keys.")
+ return False, False
+
+ outcome = report["outcome_assessment"]
+ analysis = report["analysis"]
+
+ # Check for required top-level analysis keys
+ if "summary" not in analysis:
+ print("Validation failed: Missing 'summary' in analysis.")
+ return False, False
+
+ if outcome == "OPTIMAL":
+ # For OPTIMAL, we only need summary and no flaw analysis
+ print("OPTIMAL report validation successful.")
+ return True, True
+
+ elif outcome in ["SUBOPTIMAL_SUCCESS", "PARTIAL", "INEFFECTIVE"]:
+ # For non-optimal outcomes, validate flaw_analysis structure
+ flaw_analysis = analysis.get("flaw_analysis", {})
+
+ # Validate diagnosis
+ diagnosis = flaw_analysis.get("diagnosis", {})
+ valid_categories = [
+ "Strategy Flaw",
+ "Reasoning Flaw",
+ "Execution Flaw",
+ "Knowledge Gap",
+ "Inefficiency",
+ ]
+ if (
+ diagnosis.get("category") not in valid_categories
+ and diagnosis.get("category") != "null"
+ ):
+ print(f"Validation failed: Invalid 'category'. Got: {diagnosis.get('category')}")
+ return False, False
+
+ # Validate better_approach
+ better_approach = flaw_analysis.get("better_approach", {})
+ required_better_approach_keys = ["strategy", "key_differences", "projected_benefits"]
+ for key in required_better_approach_keys:
+ if key not in better_approach:
+ print(
+ f"Validation failed: Missing '{key}' in better_approach. Got: {better_approach}"
+ )
+ return False, False
+
+ # Validate lessons_learned
+ lessons_learned = analysis.get("lessons_learned", {})
+ if not (
+ "corrective_principle" in lessons_learned and "revised_action_plan" in lessons_learned
+ ):
+ print(f"Validation failed: Invalid 'lessons_learned'. Got: {lessons_learned}")
+ return False, False
+
+ # Validate retry_strategy
+ retry_strategy = analysis.get("retry_strategy", {})
+ if not retry_strategy:
+ print("Validation failed: Missing 'retry_strategy' in analysis.")
+ return False, False
+
+ # Validate retry_step
+ if "retry_step" not in retry_strategy:
+ print("Validation failed: Missing 'retry_step' in retry_strategy.")
+ return False, False
+
+ retry_step = retry_strategy["retry_step"]
+ if retry_step is not None:
+ try:
+ retry_step = int(retry_step)
+ except (ValueError, TypeError):
+ print(
+ f"Validation failed: 'retry_step' must be an integer or null. Got: {retry_step}"
+ )
+ return False, False
+ if not isinstance(retry_step, int) or retry_step < 0:
+ print(
+ f"Validation failed: 'retry_step' must be a non-negative integer or null. Got: {retry_step}"
+ )
+ return False, False
+
+ # Check trajectory bounds if max_steps is provided
+ if max_steps is not None:
+ if retry_step >= max_steps:
+ print(
+ f"Validation failed: 'retry_step' ({retry_step}) exceeds trajectory bounds (0 to {max_steps - 1})."
+ )
+ return False, False
+
+ # Validate retry_rationale
+ if "retry_rationale" not in retry_strategy:
+ print("Validation failed: Missing 'retry_rationale' in retry_strategy.")
+ return False, False
+
+ print(f"{outcome} report validation successful.")
+ return True, False
+
+ else:
+ print(f"Validation failed: Unknown 'outcome_assessment': {outcome}")
+ return False, False
+
+
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Converts a validated reflection report into a structured, actionable
+ guidance prompt for the agent's second attempt. This prompt is framed
+ as an internal directive to ensure the model's output is clean for SFT.
+ """
+ # Convert the report dictionary to a formatted string
+ report_str = json.dumps(report, indent=2, ensure_ascii=False)
+
+ # Load and render template
+ jinja_env = _get_jinja_env()
+ template = jinja_env.get_template("self_correction.j2")
+
+ return template.render(report=report_str)
+
+
+def save_experience_data(task_id: str, experience_data: Dict, data_dir: str) -> str:
+ """
+ Save experience data including trajectory, rewards, and steps to a JSON file.
+
+ Args:
+ task_id: Unique identifier for the task
+ experience_data: Dictionary containing experience information
+ data_dir: Directory to save the data
+
+ Returns:
+ Path to the saved file
+ """
+ os.makedirs(data_dir, exist_ok=True)
+
+    # Build the output path (one JSON file per task_id)
+ filename = f"{task_id}.json"
+ filepath = os.path.join(data_dir, filename)
+
+ # Ensure experience_data is JSON serializable
+ serializable_data = {}
+ for key, value in experience_data.items():
+ if isinstance(value, torch.Tensor):
+ serializable_data[key] = value.tolist()
+ elif hasattr(value, "__dict__"):
+ # For complex objects, convert to dict representation
+ serializable_data[key] = str(value)
+ else:
+ serializable_data[key] = value
+
+ # Add metadata
+ serializable_data["saved_at"] = datetime.now().isoformat()
+ serializable_data["task_id"] = task_id
+
+ try:
+ with open(filepath, "w", encoding="utf-8") as f:
+ json.dump(serializable_data, f, indent=2, ensure_ascii=False)
+ # print(f"Experience data saved to: {filepath}")
+ return filepath
+ except Exception as e:
+ # print(f"Failed to save experience data: {e}")
+ return ""
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ steps: int,
+ success: bool,
+ attempt_type: str = "first",
+ reflection_data: Optional[Dict] = None,
+ additional_metrics: Optional[Dict] = None,
+) -> Dict:
+ """
+ Create a structured experience record for saving.
+
+ Args:
+ task_id: Unique identifier for the task
+ trajectory: List of conversation messages
+ reward: Final reward received
+ steps: Number of steps taken
+ success: Whether the task was completed successfully
+ attempt_type: Type of attempt ("first", "second", "reflect")
+ reflection_data: Optional reflection analysis data
+ additional_metrics: Additional metrics to record
+
+ Returns:
+ Dictionary containing structured experience data
+ """
+ experience_record = {
+ "task_id": task_id,
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "metrics": {
+ "reward": reward,
+ "steps": steps,
+ "success": success,
+ "trajectory_length": len(trajectory),
+ },
+ "created_at": datetime.now().isoformat(),
+ }
+
+ if reflection_data:
+ experience_record["reflection"] = reflection_data
+
+ if additional_metrics:
+ experience_record["metrics"].update(additional_metrics)
+
+ return experience_record
diff --git a/trinity/common/workflows/envs/R3L/alfworld-bak/utils.py b/trinity/common/workflows/envs/R3L/alfworld-bak/utils.py
new file mode 100644
index 0000000000..800ccc47e4
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld-bak/utils.py
@@ -0,0 +1,719 @@
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+import torch
+from trinity.common.experience import Experience
+
+
+def first_rollout(self, env) -> tuple[List[Dict[str, str]], float, bool, int, bool]:
+ """
+ Run a single rollout in Alfworld environment.
+ Uses sliding window approach like verl-agent: each step is a single-turn call.
+ """
+ observation, info = env.reset()
+
+ # Store complete trajectory for training data
+ full_trajectory = []
+
+ # Store history as (observation, action) pairs for sliding window
+ history = []
+
+ default_reward = 0.0
+ reward = default_reward
+ valid_format = True
+ step = 0
+
+ # Extract task description from the initial observation
+ task_description = extract_task_description(observation)
+ history_length = getattr(self, 'history_length', 2) # Use configurable history_length
+
+ for step in range(self.max_env_steps):
+ # Extract admissible actions from info
+ admissible_actions = info.get("admissible_commands", []) if isinstance(info, dict) else []
+
+ # Build prompt with sliding window history (verl-agent style)
+ prompt = format_observation(
+ current_observation=observation,
+ task_description=task_description,
+ current_step=step,
+ action_history=history, # Pass full history, format_observation will handle sliding window
+ admissible_actions=admissible_actions,
+ history_length=history_length
+ )
+
+ # Single-turn chat call (verl-agent style: each step is independent)
+ single_turn_messages = [{"role": "user", "content": prompt}]
+
+ # Get model response
+ responses = self.model.chat(
+ single_turn_messages,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 2560 - 512:
+ return full_trajectory, reward, False, step, False
+
+ response_text = responses[0].response_text.strip()
+
+ # Store in full trajectory for training
+ full_trajectory.append({"role": "user", "content": prompt})
+ full_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response components
+ think, action = parse_response(response_text)
+ if think is None or action is None:
+ valid_format = False
+ if step < self.max_env_steps - 1:
+ feedback = "Invalid response format, missing valid or tags, please ensure to follow the output format strictly: ... ..."
+ full_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ continue
+
+ # Store in history (only observation and action, NOT thinking - verl-agent style)
+ history.append({'observation': observation, 'action': action})
+
+ # Execute action in environment
+ observation, reward, done, info = env.step(action)
+
+ if done:
+ break
+
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ else:
+ feedback = f"Task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+
+ # Add feedback to trajectory
+ full_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+
+ return full_trajectory, reward, False, step + 1, valid_format
+
+
+def second_rollout(
+ self,
+ env,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, int, bool]:
+ """
+ Performs rollout starting from a specific retry step, reusing previous responses.
+
+ Args:
+ env: The environment instance.
+ guidance_prompt: The pre-generated guidance from reflection.
+ first_trajectory: The full log of the initial attempt.
+ retry_step: The step to start retry from (0-based, 0 means from beginning).
+
+ Returns:
+ A tuple containing (distill_trajectory, second_trajectory, reward, done status,
+ step count, and format validity).
+ """
+
+ # Reset environment to start fresh
+ observation, info = env.reset()
+ trajectory = []
+ distill_trajectory = []
+ action_history = [] # Track all actions taken
+
+ # Prepare system prompts
+ original_system_prompt = self.alfworld_system_template.render()
+
+ default_reward = 0.0
+ reward = default_reward
+ valid_format = True
+
+ # Extract task description from the initial observation (only once at the beginning)
+ task_description = extract_task_description(observation)
+
+ # Copy responses from first trajectory up to retry_step
+ step = 0
+ done = False
+ if retry_step > 0:
+ # Add original system prompt only
+ trajectory.append({"role": "system", "content": original_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ # Replay first trajectory up to retry_step to restore environment state
+ first_step = 0
+ for msg in first_trajectory[1:]: # Skip system message
+ if msg["role"] == "user":
+ # This is an observation - copy it and continue
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+ elif msg["role"] == "assistant":
+ if first_step < retry_step:
+ # Copy the assistant response from first trajectory
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+
+ # Execute the action to restore environment state
+ think, action = parse_response(msg["content"])
+ if think is not None and action is not None:
+ observation, reward, done, info = env.step(action)
+ action_history.append(action)
+ first_step += 1
+ step = first_step
+
+ if done:
+ # If environment finished during replay, no need to continue
+ return distill_trajectory, trajectory, reward, done, step, valid_format
+ else:
+ break
+
+ # Add guidance prompt as a separate system message before retry point
+ guidance_system_msg = {"role": "system",
+ "content": f"# Previous Attempt Analysis & Guidance\n{guidance_prompt}"}
+ trajectory.append(guidance_system_msg)
+ # Don't add guidance to distill_trajectory to keep it clean
+
+ else:
+ # Starting from beginning - add system prompt with guidance
+ merged_system_prompt = f"{original_system_prompt}\n\n# Previous Attempt Analysis & Guidance\n{guidance_prompt}"
+ trajectory.append({"role": "system", "content": merged_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ for step in range(step, self.max_env_steps):
+ # Extract admissible actions from info if available
+ admissible_actions = info.get("admissible_commands", []) if isinstance(info, dict) else []
+
+ formatted_obs = format_observation(
+ current_observation=observation,
+ task_description=task_description,
+ current_step=step,
+ action_history=action_history,
+ admissible_actions=admissible_actions
+ )
+
+ trajectory.append({"role": "user", "content": formatted_obs})
+ distill_trajectory.append({"role": "user", "content": formatted_obs})
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 2560 - 512:
+ print(f"[second_rollout] Token count {responses[0].tokens.shape[0]} exceeds 2048 at step {step}")
+ return distill_trajectory, trajectory, reward, True, step, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response
+ think, action = parse_response(response_text)
+ if think is None or action is None:
+ valid_format = False
+ # Only add format feedback if not in the last step (to avoid duplicate feedback)
+ if step < self.max_env_steps - 1:
+ feedback = "Invalid response format, missing valid or tags, please ensure to follow the output format strictly: ... ..."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ continue
+
+ # Execute action in environment
+ observation, reward, done, info = env.step(action)
+
+ # Track successfully executed actions for history (only if action was valid)
+ if "Nothing happens." not in observation:
+ action_history.append(action)
+
+ if done:
+ break
+
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+ f"Task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ else:
+ feedback = (
+ f"Task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+
+ return distill_trajectory, trajectory, reward, False, step + 1, valid_format
+
+
+def eval_alfworld(self) -> List[Experience]:
+ """Evaluate a single alfworld trajectory"""
+ try:
+ env = create_alfworld_environment(self.game_file_path, self.max_env_steps)
+ trajectory, reward, done, steps, valid_format = first_rollout(
+ self, env
+ )
+ print(
+ f"[Eval Alfworld] Trajectory completed with reward: {reward}, steps: {steps}, valid_format: {valid_format}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ print(f"[Eval] First rollout - reward: {reward}, steps: {steps}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="evaluation"
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval",
+ experience_data=eval_record,
+ data_dir=self.eval_dir
+ )
+ except Exception as e:
+ # logger.warning(f"Single rollout failed during eval: {e}")
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ }
+ )
+ exp.reward = 0.0
+ return [exp]
+
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def extract_task_description(observation: str) -> str:
+ """
+ Extract task description from the initial observation.
+ The task description is typically in the format: "Your task is to: ..."
+
+ Args:
+ observation: Initial observation from environment
+
+ Returns:
+ Extracted task description string
+ """
+ # Look for pattern "Your task is to: "
+ match = re.search(r"Your task is to:\s*(.+?)(?:\n|$)", observation, re.IGNORECASE)
+ if match:
+ return match.group(1).strip()
+ # Fallback: return a portion of the observation
+ return observation.split('\n')[-1] if '\n' in observation else observation
+
+
+def format_observation(
+ current_observation: str,
+ task_description: str = "",
+ current_step: int = 0,
+ action_history: List[Dict[str, str]] = None,
+ admissible_actions: List[str] = None,
+ history_length: int = 2
+):
+ """
+ Format observation string with task context and limited action history.
+ Adapted to verl-agent style: history is a list of {observation, action} dicts.
+
+ Args:
+ current_observation: Current observation from environment
+ task_description: Description of the task to complete
+ current_step: Current step number
+ action_history: List of {observation, action} dicts from previous steps
+ admissible_actions: List of currently admissible actions
+ history_length: Maximum number of recent steps to display (default: 2)
+ """
+ if action_history is None:
+ action_history = []
+ if admissible_actions is None:
+ admissible_actions = []
+
+ # Format admissible actions (remove 'help' like verl-agent does)
+ filtered_actions = [a for a in admissible_actions if a != 'help']
+ admissible_actions_str = ", ".join(f"'{a}'" for a in filtered_actions) if filtered_actions else "All standard actions available"
+
+ # Check if this is the first step (no history)
+ if current_step == 0 or not action_history:
+ # First step - no history version (ALFWORLD_TEMPLATE_NO_HIS style)
+ prompt = f"""You are an expert agent operating in the ALFRED Embodied Environment.
+Your current observation is: {current_observation}
+Your admissible actions of the current situation are: [{admissible_actions_str}].
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for current step and present it within <action> </action> tags."""
+ else:
+ # Limit history to most recent history_length items (sliding window)
+ recent_history = action_history[-history_length:] if len(action_history) > history_length else action_history
+ valid_history_length = len(recent_history)
+ start_idx = len(action_history) - valid_history_length
+
+ # Format action history: only show actions, NOT observations (verl-agent style)
+ action_history_lines = []
+ for i, record in enumerate(recent_history):
+ step_num = start_idx + i
+ action = record['action']
+ # Note: We could include observation here like verl-agent does:
+ # obs = record['observation']
+ # action_history_lines.append(f"[Observation {step_num}: '{obs}', Action {step_num}: '{action}']")
+ # But for simplicity, just show actions:
+ action_history_lines.append(f" Step {step_num}: {action}")
+
+ action_history_str = "\n".join(action_history_lines)
+
+ # Create formatted prompt with limited history (ALFWORLD_TEMPLATE style)
+ prompt = f"""You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description}
+Prior to this step, you have already taken {len(action_history)} step(s). Below are the most recent {valid_history_length} actions you took:
+{action_history_str}
+You are now at step {current_step} and your current observation is: {current_observation}
+Your admissible actions of the current situation are: [{admissible_actions_str}].
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for current step and present it within <action> </action> tags."""
+
+ return prompt
+
+
+def parse_response(response):
+ """Parse think and action components from response"""
+ try:
+ # Use regex to extract think and action components
+ think_pattern = r"\s*(.*?)\s*"
+ action_pattern = r"\s*(.*?)\s*"
+
+ think_match = re.search(think_pattern, response, re.DOTALL)
+ action_match = re.search(action_pattern, response, re.DOTALL)
+
+ think = think_match.group(1).strip() if think_match else None
+ action = action_match.group(1).strip() if action_match else None
+
+ return think, action
+ except Exception:
+ return None, None
+
+
+def create_alfworld_environment(game_file, max_episode_steps=50):
+ """
+ Create alfworld environment
+
+ Args:
+ game_file: Path to the game file
+ max_episode_steps: Maximum number of steps per episode (default: 50)
+ """
+ try:
+ import textworld
+ import textworld.gym
+ from alfworld.agents.environment.alfred_tw_env import (
+ AlfredDemangler,
+ AlfredExpert,
+ AlfredExpertType,
+ )
+
+ expert = AlfredExpert(expert_type=AlfredExpertType.HANDCODED)
+ request_infos = textworld.EnvInfos(
+ description=True, inventory=True, admissible_commands=True
+ )
+
+ env_id = textworld.gym.register_game(
+ game_file, request_infos,
+ max_episode_steps=max_episode_steps,
+ asynchronous=True,
+ wrappers=[AlfredDemangler(), expert]
+ )
+ env = textworld.gym.make(env_id)
+
+ return env
+ except ImportError as e:
+ raise ImportError(
+ f"Failed to import alfworld dependencies: {e}. "
+ "Please install alfworld following the instructions at https://github.com/alfworld/alfworld"
+ )
+
+
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+ Correctly formats the trajectory for reflection, including the system prompt
+ and numbering the user/assistant turns.
+ """
+ formatted_lines = []
+    # Use a counter to track user/assistant interaction turns
+    turn_counter = 0  # start counting from 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+            # The system prompt does not count as a step, but must be shown first as the rules
+ formatted_lines.append(f"**System Rules & Context:**\n{content}\n" + "=" * 30)
+ elif role == "user":
+            # Mark the start of a new turn
+ formatted_lines.append(f"\n**Step {turn_counter}**")
+ formatted_lines.append(f" - User Observation/Feedback:\n {content.strip()}")
+ elif role == "assistant":
+            # The assistant's thought and action
+ formatted_lines.append(f" - Assistant Thought & Action:\n {content.strip()}")
+            # Increment the turn counter after a complete user-assistant exchange
+ turn_counter += 1
+
+ return "\n".join(formatted_lines)
+
+
+def validate_reflect_report(report: Dict[str, Any], max_steps: int = None) -> tuple[bool, bool]:
+ """
+ Validates the structure and content of the reflection report
+ based on the new reflection.j2 schema.
+
+ Args:
+ report: The reflection report to validate
+ max_steps: Maximum number of steps in trajectory for retry_step bounds checking (optional)
+
+ Returns:
+ tuple[bool, bool]: (is_valid, is_perfect)
+ - is_valid: Whether the report structure is valid
+ - is_perfect: Whether the report indicates the trajectory is perfect (only meaningful if is_valid is True)
+ """
+ if (
+ not isinstance(report, dict)
+ or "outcome_assessment" not in report
+ or "analysis" not in report
+ ):
+ print("Validation failed: Report is not a dict or missing top-level keys.")
+ return False, False
+
+ outcome = report["outcome_assessment"]
+ analysis = report["analysis"]
+
+ # Check for required top-level analysis keys
+ if "summary" not in analysis:
+ print("Validation failed: Missing 'summary' in analysis.")
+ return False, False
+
+ if outcome == "OPTIMAL":
+ # For OPTIMAL, we only need summary and no flaw analysis
+ print("OPTIMAL report validation successful.")
+ return True, True
+
+ elif outcome in ["SUBOPTIMAL_SUCCESS", "PARTIAL", "INEFFECTIVE"]:
+ # For non-optimal outcomes, validate flaw_analysis structure
+ flaw_analysis = analysis.get("flaw_analysis", {})
+
+ # Validate diagnosis
+ diagnosis = flaw_analysis.get("diagnosis", {})
+ valid_categories = [
+ "Strategy Flaw",
+ "Reasoning Flaw",
+ "Execution Flaw",
+ "Knowledge Gap",
+ "Inefficiency"
+ ]
+ if diagnosis.get("category") not in valid_categories and diagnosis.get("category") != "null":
+ print(f"Validation failed: Invalid 'category'. Got: {diagnosis.get('category')}")
+ return False, False
+
+ # Validate better_approach
+ better_approach = flaw_analysis.get("better_approach", {})
+ required_better_approach_keys = ["strategy", "key_differences", "projected_benefits"]
+ for key in required_better_approach_keys:
+ if key not in better_approach:
+ print(f"Validation failed: Missing '{key}' in better_approach. Got: {better_approach}")
+ return False, False
+
+ # Validate lessons_learned
+ lessons_learned = analysis.get("lessons_learned", {})
+ if not (
+ "corrective_principle" in lessons_learned
+ and "revised_action_plan" in lessons_learned
+ ):
+ print(f"Validation failed: Invalid 'lessons_learned'. Got: {lessons_learned}")
+ return False, False
+
+ # Validate retry_strategy
+ retry_strategy = analysis.get("retry_strategy", {})
+ if not retry_strategy:
+ print("Validation failed: Missing 'retry_strategy' in analysis.")
+ return False, False
+
+ # Validate retry_step
+ if "retry_step" not in retry_strategy:
+ print("Validation failed: Missing 'retry_step' in retry_strategy.")
+ return False, False
+
+ retry_step = retry_strategy["retry_step"]
+ if retry_step is not None:
+ try:
+ retry_step = int(retry_step)
+ except (ValueError, TypeError):
+ print(f"Validation failed: 'retry_step' must be an integer or null. Got: {retry_step}")
+ return False, False
+ if not isinstance(retry_step, int) or retry_step < 0:
+ print(f"Validation failed: 'retry_step' must be a non-negative integer or null. Got: {retry_step}")
+ return False, False
+
+ # Check trajectory bounds if max_steps is provided
+ if max_steps is not None:
+ if retry_step >= max_steps:
+ print(
+ f"Validation failed: 'retry_step' ({retry_step}) exceeds trajectory bounds (0 to {max_steps - 1}).")
+ return False, False
+
+ # Validate retry_rationale
+ if "retry_rationale" not in retry_strategy:
+ print("Validation failed: Missing 'retry_rationale' in retry_strategy.")
+ return False, False
+
+ print(f"{outcome} report validation successful.")
+ return True, False
+
+ else:
+ print(f"Validation failed: Unknown 'outcome_assessment': {outcome}")
+ return False, False
+
+
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Converts a validated reflection report into a structured, actionable
+ guidance prompt for the agent's second attempt. This prompt is framed
+ as an internal directive to ensure the model's output is clean for SFT.
+ """
+ # Convert the report dictionary to a formatted string
+ report_str = json.dumps(report, indent=2, ensure_ascii=False)
+
+ # Load and render template
+ jinja_env = _get_jinja_env()
+ template = jinja_env.get_template("self_correction.j2")
+
+ return template.render(report=report_str)
+
+
+def save_experience_data(
+ task_id: str,
+ experience_data: Dict,
+ data_dir: str
+) -> str:
+ """
+ Save experience data including trajectory, rewards, and steps to a JSON file.
+
+ Args:
+ task_id: Unique identifier for the task
+ experience_data: Dictionary containing experience information
+ data_dir: Directory to save the data
+
+ Returns:
+ Path to the saved file
+ """
+ os.makedirs(data_dir, exist_ok=True)
+
+    # Build the output path (one JSON file per task_id)
+ filename = f"{task_id}.json"
+ filepath = os.path.join(data_dir, filename)
+
+ # Ensure experience_data is JSON serializable
+ serializable_data = {}
+ for key, value in experience_data.items():
+ if isinstance(value, torch.Tensor):
+ serializable_data[key] = value.tolist()
+ elif hasattr(value, '__dict__'):
+ # For complex objects, convert to dict representation
+ serializable_data[key] = str(value)
+ else:
+ serializable_data[key] = value
+
+ # Add metadata
+ serializable_data["saved_at"] = datetime.now().isoformat()
+ serializable_data["task_id"] = task_id
+
+ try:
+ with open(filepath, 'w', encoding='utf-8') as f:
+ json.dump(serializable_data, f, indent=2, ensure_ascii=False)
+ # print(f"Experience data saved to: {filepath}")
+ return filepath
+ except Exception as e:
+ # print(f"Failed to save experience data: {e}")
+ return ""
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ steps: int,
+ success: bool,
+ attempt_type: str = "first",
+ reflection_data: Optional[Dict] = None,
+ additional_metrics: Optional[Dict] = None
+) -> Dict:
+ """
+ Create a structured experience record for saving.
+
+ Args:
+ task_id: Unique identifier for the task
+ trajectory: List of conversation messages
+ reward: Final reward received
+ steps: Number of steps taken
+ success: Whether the task was completed successfully
+ attempt_type: Type of attempt ("first", "second", "reflect")
+ reflection_data: Optional reflection analysis data
+ additional_metrics: Additional metrics to record
+
+ Returns:
+ Dictionary containing structured experience data
+ """
+ experience_record = {
+ "task_id": task_id,
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "metrics": {
+ "reward": reward,
+ "steps": steps,
+ "success": success,
+ "trajectory_length": len(trajectory)
+ },
+ "created_at": datetime.now().isoformat()
+ }
+
+ if reflection_data:
+ experience_record["reflection"] = reflection_data
+
+ if additional_metrics:
+ experience_record["metrics"].update(additional_metrics)
+
+ return experience_record
diff --git a/trinity/common/workflows/envs/R3L/alfworld/R3L_workflow.py b/trinity/common/workflows/envs/R3L/alfworld/R3L_workflow.py
new file mode 100644
index 0000000000..ce7d7e246c
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld/R3L_workflow.py
@@ -0,0 +1,340 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("R3L_alfworld_workflow")
+class R3LAlfworldWorkflow(Workflow):
+ """
+ R3L workflow for alfworld
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 25
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+
+ self.whether_save_data = True
+ # Create data directories
+ self.data_dir = f"R3L_alfworld_data"
+ self.eval_dir = os.path.join(self.data_dir, "eval")
+ self.train_dir = os.path.join(self.data_dir, "train")
+
+ os.makedirs(self.eval_dir, exist_ok=True)
+ os.makedirs(self.train_dir, exist_ok=True)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+ self.reflection_template = self.jinja_env.get_template("reflection.j2")
+
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ self.default_second_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "second_success": 0.0,
+ "second_reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ print(
+ f"Initializing R3LAlfworldWorkflow with experience learning, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.task = task
+ self.n = task.repeat_times
+
+    def get_reflect(
+        self, trajectory: List[Dict[str, str]]
+    ) -> tuple[Optional[Dict[str, Any]], Optional[str], Optional[Any]]:
+ """
+ Generates a comprehensive reflection report using a single, unified self-interrogation prompt.
+ The model first assesses its own performance and then follows the appropriate reflection path.
+ """
+ # Format trajectory for LLM reading
+ formatted_trajectory = utils.format_trajectory_for_reflection(trajectory)
+
+ # Use Jinja2 template to render reflection prompt
+ reflect_prompt = self.reflection_template.render()
+
+ # Call model and parse results
+ try:
+ responses = self.model.chat(
+ [
+ {"role": "system", "content": reflect_prompt},
+ {"role": "user", "content": "Here is last attempt trajectory log: \n\n" + formatted_trajectory + "\n\nPlease output in the specified JSON format."}
+ ],
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_reflect_tokens,
+ )
+ reflection_text = responses[0].response_text.strip()
+
+ # Find first '{' and last '}'
+ first_brace = reflection_text.find('{')
+ last_brace = reflection_text.rfind('}')
+
+ if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
+ json_str = reflection_text[first_brace:last_brace + 1]
+ else:
+ json_str = reflection_text
+
+ reflection_data = json.loads(json_str)
+ return reflection_data, reflection_text, responses[0]
+
+ except Exception as e:
+ return None, None, None
+
+ def _adjust_action_mask_for_retry(self, experience: Experience, retry_step: int):
+ """
+ Adjust action_mask in-place to exclude retry prefix from training.
+ Only tokens from retry_step onwards should be trained.
+
+ Args:
+ experience: The experience object with action_mask to adjust
+ retry_step: The step from which training should start
+ """
+ if retry_step <= 0:
+ return
+
+ # Note: experience.action_mask already excludes prompt tokens
+ action_mask = experience.action_mask
+
+ # Find all assistant response regions and mark the first 'retry_step' as non-trainable
+ if torch.any(action_mask == 1):
+ # Find all segments where action_mask == 1 (assistant responses)
+ assistant_segments = []
+ in_segment = False
+ segment_start = 0
+
+ for i, mask_val in enumerate(action_mask):
+ if mask_val == 1 and not in_segment:
+ # Start of a new segment
+ segment_start = i
+ in_segment = True
+ elif mask_val == 0 and in_segment:
+ # End of current segment
+ assistant_segments.append((segment_start, i))
+ in_segment = False
+
+ # Handle case where sequence ends with assistant response
+ if in_segment:
+ assistant_segments.append((segment_start, len(action_mask)))
+
+ # Set the first 'retry_step' assistant segments to 0 (non-trainable)
+ for i in range(min(retry_step, len(assistant_segments))):
+ start, end = assistant_segments[i]
+ action_mask[start:end] = 0
+
+ def run(self) -> List[Experience]:
+ """Run the R3L alfworld workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Generate unique task ID
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+
+ env = utils.create_alfworld_environment(self.game_file_path)
+ exp_lst = []
+ for i in range(self.n // 2): # Half for rollout, half for reflection + retry
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[R3L] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ # Set eid
+ exp.eid.task = str(self.task.task_id) + f"_explore"
+ exp_run_id = len(exp_lst) + self.run_id_base
+ exp.eid.run = exp_run_id
+ exp_lst.append(exp)
+
+ if self.whether_save_data:
+ # Save first attempt experience data
+ first_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="first"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_first",
+ experience_data=first_record,
+ data_dir=self.train_dir
+ )
+
+ # Reflect on first attempt
+ reflect_checklist, reflection_text, reflect_exp = self.get_reflect(trajectory)
+ is_valid, is_perfect = utils.validate_reflect_report(reflect_checklist, steps)
+
+ if not is_valid or is_perfect:
+ # If first attempt reward is 1.0 and reflection gives perfect, record reflection exp
+ if reward >= 1.0 and is_perfect and reflect_exp is not None:
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ else:
+ guidance_prompt = utils.reflect_report_to_guidance_prompt(reflect_checklist)
+                    # Extract retry_step from the validated reflection report
+                    # (analysis.retry_strategy.retry_step in the reflection.j2 schema)
+                    retry_step = (
+                        reflect_checklist.get("analysis", {})
+                        .get("retry_strategy", {})
+                        .get("retry_step")
+                    )
+                    retry_step = int(retry_step) if retry_step is not None else 0
+
+ try:
+ second_env = utils.create_alfworld_environment(self.game_file_path)
+ (
+ distill_trajectory,
+ second_trajectory,
+ second_reward,
+ second_done,
+ second_steps,
+ second_format_valid,
+ ) = utils.second_rollout(
+ self, second_env, guidance_prompt, trajectory, retry_step
+ )
+ print(f"[R3L] Second rollout - reward: {second_reward}, steps: {second_steps}, improve: {second_reward > reward}")
+ second_exp = self.model.convert_messages_to_experience(distill_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(second_exp, retry_step)
+ # Also adjust first rollout exp for fair comparison
+ for existing_exp in exp_lst:
+ if existing_exp.eid.run == exp_run_id:
+ self._adjust_action_mask_for_retry(existing_exp, retry_step)
+ break
+
+ second_exp.reward = second_reward
+ second_exp.metrics = {
+ "second_success": 1.0 if second_reward >= 1.0 else 0.0,
+ "second_steps": second_steps,
+ "second_reward": second_reward,
+ "second_improve": 1.0 if second_reward > reward else 0.0,
+ "second_reward_diff": second_reward - reward,
+ }
+ # Set eid
+ second_exp.eid.task = str(self.task.task_id) + f"_explore"
+ second_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(second_exp)
+
+ if self.whether_save_data:
+ # Save second attempt experience data
+ second_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=second_trajectory,
+ reward=second_reward,
+ steps=second_steps,
+ success=second_reward >= 1.0,
+ attempt_type="second",
+ additional_metrics={
+ "first_reward": reward,
+ "improvement": second_reward > reward,
+ "reward_difference": second_reward - reward,
+ "step_difference": second_steps - steps
+ }
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_second",
+ experience_data=second_record,
+ data_dir=self.train_dir
+ )
+
+ # If second attempt score is higher than first, or second is perfect with fewer steps,
+ # record reflection and retry data
+ if (second_reward > reward and second_reward >= 1.0) or (second_reward >= 1.0 and second_steps < steps):
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Convert retry data to exp
+ retry_exp = self.model.convert_messages_to_experience(second_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(retry_exp, retry_step)
+
+ retry_exp.reward = 1.0
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + f"_retry_{i}"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ print("Reflection and retry led to improvement, recording both...")
+ except Exception as e:
+ print(f"Second rollout after reflection failed: {e}")
+ except Exception as e:
+ print(f"First rollout failed: {e}")
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/alfworld/dapo_workflow.py b/trinity/common/workflows/envs/R3L/alfworld/dapo_workflow.py
new file mode 100644
index 0000000000..d7a35e55a5
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld/dapo_workflow.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("dapo_alfworld_workflow")
+class DAPOAlfworldWorkflow(Workflow):
+ """
+ DAPO Workflow for Alfworld environment.
+ Performs rollouts with DAPO-style overlong penalty on response length.
+ No separate reward function needed - penalty computed directly in workflow.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 25
+ self.max_tokens = 512
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # DAPO overlong penalty parameters
+ workflow_args = task.workflow_args or {}
+ self.enable_overlong_penalty = workflow_args.get("enable_overlong_penalty", True)
+ self.penalty_factor = workflow_args.get("penalty_factor", 1.0)
+ self.max_response_length = workflow_args.get("max_response_length", 512)
+ self.cache_length = workflow_args.get("cache_length", 100)
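+        # These knobs can be overridden per task, e.g. (illustrative YAML snippet,
+        # assuming the usual workflow_args passthrough in the taskset config):
+        #   default_workflow_type: 'dapo_alfworld_workflow'
+        #   workflow_args:
+        #     enable_overlong_penalty: true
+        #     penalty_factor: 1.0
+        #     max_response_length: 512
+        #     cache_length: 100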
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+ f"Initializing DAPOAlfworldWorkflow, temperature={self.temperature}, "
+ f"overlong_penalty={'enabled' if self.enable_overlong_penalty else 'disabled'}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.task = task
+ self.n = task.repeat_times
+
+ # Update DAPO parameters if provided
+ workflow_args = task.workflow_args or {}
+ if "enable_overlong_penalty" in workflow_args:
+ self.enable_overlong_penalty = workflow_args["enable_overlong_penalty"]
+ if "penalty_factor" in workflow_args:
+ self.penalty_factor = workflow_args["penalty_factor"]
+ if "max_response_length" in workflow_args:
+ self.max_response_length = workflow_args["max_response_length"]
+ if "cache_length" in workflow_args:
+ self.cache_length = workflow_args["cache_length"]
+
+ def compute_overlong_penalty(self, response_tokens: torch.Tensor) -> float:
+ """
+ Compute DAPO-style overlong penalty based on response token length.
+
+ Args:
+ response_tokens: Response tokens (tensor)
+
+ Returns:
+ Penalty score (non-positive float)
+ """
+ if not self.enable_overlong_penalty:
+ return 0.0
+
+ response_len = len(response_tokens)
+ expected_len = self.max_response_length - self.cache_length
+
+ if response_len < expected_len:
+ # No penalty for short responses
+ return 0.0
+ elif response_len > self.max_response_length:
+ # Fixed penalty for excessively long responses
+ return -self.penalty_factor
+ else:
+ # Linear penalty in the transition zone
+ return (expected_len - response_len) / self.cache_length * self.penalty_factor
+
+ def run(self) -> List[Experience]:
+ """Run the DAPO workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Single rollout execution
+ env = utils.create_alfworld_environment(self.game_file_path)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[DAPO] Rollout - reward: {reward}, steps: {steps}")
+
+ # Convert trajectory to experience
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+
+ # Extract response tokens from experience
+ response_tokens = exp.tokens[exp.prompt_length:]
+
+ # Compute DAPO overlong penalty (format score)
+ format_score = self.compute_overlong_penalty(response_tokens)
+
+ # Calculate accuracy score
+ accuracy_score = 1.0 if reward >= 1.0 else 0.0
+
+ # Total reward = accuracy + format_score
+ total_reward = accuracy_score + format_score
+
+ # Update experience reward and metrics
+ exp.reward = total_reward
+ exp.metrics = {
+ "success": accuracy_score,
+ "steps": steps,
+ "env_reward": reward,
+ "accuracy": accuracy_score,
+ "format_score": format_score,
+ "response_length": len(response_tokens),
+ "total_reward": total_reward,
+ }
+
+ # Set experience ID
+ exp.eid.task = str(self.task.task_id)
+ exp.eid.run = i + self.run_id_base
+
+ exp_lst.append(exp)
+ except Exception as e:
+ print(f"[DAPO] Rollout failed: {e}")
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/alfworld/grpo_workflow.py b/trinity/common/workflows/envs/R3L/alfworld/grpo_workflow.py
new file mode 100644
index 0000000000..bbc34b716a
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld/grpo_workflow.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("grpo_baseline_alfworld_workflow")
+class GRPOBaselineAlfworldWorkflow(Workflow):
+ """
+ GRPO Baseline Workflow for Alfworld environment.
+ Performs simple rollouts without reflection or learning from experience.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 25
+ self.max_tokens = 512
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+ f"Initializing GRPOBaselineAlfworldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.task = task
+ self.n = task.repeat_times
+
+ def run(self) -> List[Experience]:
+ """Run the GRPO baseline workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Single rollout execution
+ env = utils.create_alfworld_environment(self.game_file_path)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[GRPO] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+            except Exception as e:
+                print(f"[GRPO] Rollout failed: {e}")
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/alfworld/opmd_workflow.py b/trinity/common/workflows/envs/R3L/alfworld/opmd_workflow.py
new file mode 100644
index 0000000000..61663f685f
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld/opmd_workflow.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("opmd_baseline_alfworld_workflow")
+class OPMDBaselineAlfworldWorkflow(Workflow):
+ """
+ OPMD Baseline workflow for Alfworld environment.
+    Performs simple rollouts used as the OPMD baseline (no reflection or retry).
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 25
+ self.max_tokens = 512
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+            f"Initializing OPMDBaselineAlfworldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.task = task
+ self.n = task.repeat_times
+
+ def run(self) -> List[Experience]:
+ """Run the OPMD workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_alfworld(self)
+
+ # Single rollout execution
+ env = utils.create_alfworld_environment(self.game_file_path)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[OPMD] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+            except Exception as e:
+                print(f"[OPMD] Rollout failed: {e}")
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/alfworld/prompts/alfworld_system.j2 b/trinity/common/workflows/envs/R3L/alfworld/prompts/alfworld_system.j2
new file mode 100644
index 0000000000..05d4532bf6
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld/prompts/alfworld_system.j2
@@ -0,0 +1,44 @@
+You are an agent interacting with a virtual text-based environment.
+
+## Response Format:
+You MUST use this exact format for every response. Both tags are REQUIRED in sequential order:
+
+<think>your analytical reasoning and thought process</think>
+<action>exactly one specific action command</action>
+
+## Action Commands:
+Your <action> must be one of the following, strictly following the command (argument) format.
+
+### Navigation & Observation:
+- look: Look around your current location to get more details.
+- inventory: Check the object you are currently holding (you can only hold one).
+- go to (receptacle): Move to a receptacle (e.g., table, fridge, sink).
+
+### Interacting with Receptacles:
+- open (receptacle): Open a receptacle.
+- close (receptacle): Close a receptacle.
+
+### Interacting with Objects:
+- take (object) from (receptacle): Pick up an object from a receptacle.
+- move (object) to (receptacle): Place the object you are holding into or onto a receptacle.
+- examine (object): Examine an object closely to learn its properties.
+
+### Changing Object States:
+- heat (object) with (receptacle): Heat an object with a device (e.g., microwave).
+- cool (object) with (receptacle): Cool an object with a device (e.g., fridge).
+- clean (object) with (receptacle): Clean an object with a device (e.g., sink).
+- slice (object) with (object): Slice an object using a sharp object (e.g., knife).
+
+For example, your output should look like this:
+<think>your reasoning process here</think>
+<action>look</action>
+
+<think>your reasoning process here</think>
+<action>go to sofa 1</action>
+
+## Critical Rules & Constraints
+- Single Item Inventory: You can only hold one object at a time. You must put down the current object before taking a new one.
+- Examine Before Acting: Before performing an action on an object (like take, heat, or clean), it is best to examine it first to confirm its properties.
+- Use Exact Names: The (object) and (receptacle) arguments in your command MUST exactly match the names seen in your Observation, including any numbers (e.g., apple 1, desk 2).
+- Systematic Thinking: Break down complex tasks into smaller, manageable sub-goals. Clearly outline your plan in the <think> block.
+- Step Limit: You must complete the task within 25 steps.
diff --git a/trinity/common/workflows/envs/R3L/alfworld/prompts/reflection.j2 b/trinity/common/workflows/envs/R3L/alfworld/prompts/reflection.j2
new file mode 100644
index 0000000000..e182a1f0a7
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld/prompts/reflection.j2
@@ -0,0 +1,28 @@
+You are a Reflector that analyzes trajectory logs based on user and environment feedback. Your goal is to identify what went wrong, trace root causes, and extract reusable principles for future improvement. Review the trajectory and feedback to understand the strategy and outcome. Through Socratic-style iterative "why" questioning, trace issues back to their fundamental flawed assumptions or mental models. Then formulate an actionable principle and suggest where to retry if needed.
+
+Please output in the following JSON format:
+
+```json
+{
+"trajectory_summary": "Concise overview in 1-3 sentences covering: (1) the strategy or approach employed by the agent, (2) the final result or outcome achieved, (3) key observations about execution quality (e.g., efficiency, correctness, optimality).",
+"root_cause_analysis": "Deep causal analysis using iterative 'why' questioning to trace from observable symptoms back to the fundamental root cause (flawed assumption, incorrect mental model, or critical knowledge gap). Chain your reasoning explicitly (e.g., 'Why X? Because Y. Why Y? Because Z.'). Identify the deepest underlying issue, not just surface-level errors. Set to null only if execution was truly flawless.",
+"trajectory_outcome": "Classification of the trajectory result. Must be EXACTLY one of these three values (case-sensitive, including underscores): 'success' (goal fully achieved with optimal execution quality), 'success_but_inefficient' (goal achieved but with unnecessary steps, redundant actions, or suboptimal approach), 'failure' (goal not achieved or task incomplete).",
+"improvement_suggestion": "A generalizable, context-complete principle for avoiding similar issues in future attempts. Must be self-contained and actionable without reference to this specific trajectory. Include: (1) the specific environment/system/domain name (ALFWorld interactive tasks), (2) the triggering conditions or scenario when this applies, (3) the specific problem or pitfall to avoid, (4) the recommended solution or approach with clear rationale. Frame as reusable knowledge. Set to null if and only if trajectory_outcome is 'success'.",
+"retry_from_step": "Integer from 0 to N-1 identifying the earliest step where the root cause first manifested or where a corrected decision could alter the outcome. This represents the optimal restart point if given one opportunity to retry. Use 0 when the root cause traces to initial strategy selection or foundational assumptions. Set to null if trajectory_outcome is 'success' or if retry would not be beneficial."
+}
+```
+
+## Example
+
+**Scenario**: Solving the equation 3x² - 12x + 9 = 0
+
+**Example Output**:
+```json
+{
+"trajectory_summary": "The agent attempted to solve a quadratic equation by immediately applying the quadratic formula with a=3, b=-12, c=9. The calculation resulted in x = (12 ± √(144-108))/6 = (12 ± 6)/6, yielding x=3 or x=1. However, the agent failed to verify the solution and missed that the equation could be simplified first by factoring out 3, leading to a more elegant solution path.",
+"root_cause_analysis": "Why was the approach suboptimal? Because the agent jumped directly to the quadratic formula without checking for simplifications. Why skip simplification? Because it saw standard form ax²+bx+c=0 and immediately pattern-matched to 'use quadratic formula'. Why this pattern-matching? Because the agent treated the quadratic formula as a universal first-choice method rather than one tool among many. Root cause: Lack of strategic problem assessment - the agent optimized for immediate formula application rather than problem structure analysis, missing that all coefficients shared a common factor of 3.",
+"trajectory_outcome": "success_but_inefficient",
+"improvement_suggestion": "In mathematical problem-solving environments, always perform a structural analysis before applying solution methods: (1) check for common factors in all terms, (2) look for special patterns (perfect squares, difference of squares, sum/product relationships), (3) assess whether simplification reduces computational complexity. For quadratic equations specifically, factor out GCD first - this often reveals simpler factorizations or reduces calculation errors. Example: 3x²-12x+9=0 becomes 3(x²-4x+3)=0, then 3(x-1)(x-3)=0, directly yielding x=1 or x=3 without formula computation. Apply formula only when factoring is not immediately apparent.",
+"retry_from_step": 0
+}
+```
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L/alfworld/prompts/self_correction.j2 b/trinity/common/workflows/envs/R3L/alfworld/prompts/self_correction.j2
new file mode 100644
index 0000000000..3c0a34b676
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld/prompts/self_correction.j2
@@ -0,0 +1,5 @@
+Your previous attempt encountered issues. Below is a reflection based on user and environment feedback:
+
+{{ report }}
+
+Apply the lessons learned from this reflection to avoid repeating the same mistakes. Do not mention or reference this guidance in your response.
diff --git a/trinity/common/workflows/envs/R3L/alfworld/raft_workflow.py b/trinity/common/workflows/envs/R3L/alfworld/raft_workflow.py
new file mode 100644
index 0000000000..c92a5fae87
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld/raft_workflow.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+import copy
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.alfworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("RAFT_baseline_alfworld_workflow")
+class RAFTBaselineAlfworldWorkflow(Workflow):
+ """
+ RAFT Baseline workflow for Alfworld environment.
+    Performs rollouts and keeps only successful (reward >= 1.0) trajectories as RAFT training data.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 25
+ self.max_tokens = 512
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+            f"Initializing RAFTBaselineAlfworldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ # Default experience for error cases
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.task = task
+ self.n = task.repeat_times
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT workflow and return experiences"""
+
+ env = utils.create_alfworld_environment(self.game_file_path)
+
+ if self.is_eval:
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ return [exp]
+ except Exception:
+ return [copy.deepcopy(self.default_exp)]
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[RAFT] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ if reward >= 1.0:
+ exp_lst.append(exp)
+ else:
+ exp_lst.append(copy.deepcopy(self.default_exp))
+ except Exception:
+ exp_lst.append(copy.deepcopy(self.default_exp))
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/alfworld/utils.py b/trinity/common/workflows/envs/R3L/alfworld/utils.py
new file mode 100644
index 0000000000..58107a924c
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/alfworld/utils.py
@@ -0,0 +1,698 @@
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+from jinja2 import Environment, FileSystemLoader
+import torch
+from trinity.common.experience import Experience
+
+
+def first_rollout(self, env) -> tuple[List[Dict[str, str]], float, bool, int, bool]:
+ """Run a single rollout in Alfworld environment"""
+ observation, info = env.reset()
+ trajectory = []
+ action_history = [] # Track all actions taken
+
+ system_prompt = self.alfworld_system_template.render()
+ trajectory.append({"role": "system", "content": system_prompt})
+
+ default_reward = 0.0
+ done = False
+ reward = default_reward
+ valid_format = True
+ step = 0
+
+ # Extract task description from the initial observation (only once at the beginning)
+ task_description = extract_task_description(observation)
+
+ for step in range(self.max_env_steps):
+ # Extract admissible actions from info if available
+ admissible_actions = info.get("admissible_commands", []) if isinstance(info, dict) else []
+
+ trajectory.append(
+ {"role": "user", "content": format_observation(
+ current_observation=observation,
+ task_description=task_description,
+ current_step=step,
+ action_history=action_history,
+ admissible_actions=admissible_actions
+ )}
+ )
+
+ # Get model response
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+            # The tokenizer inside chat truncates inputs, so terminate as soon as the length reaches the limit
+ return trajectory, default_reward, False, step + 1, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response components
+ think, action, error_msg = parse_response(response_text)
+ if error_msg is not None:
+ valid_format = False
+ observation = f"{error_msg}"
+            # Keep the default (or previous) values for reward, done, and info
+ trajectory.append({"role": "user", "content": f"Feedback: {error_msg}"})
+ return trajectory, default_reward, False, step + 1, valid_format
+ else:
+ # Execute action in environment
+ observation, reward, done, info = env.step(action)
+ if action not in admissible_actions:
+ valid_format = False
+ observation = f"Invalid action '{action}' not in admissible actions."
+ trajectory.append({"role": "user", "content": f"Feedback: {observation}"})
+ return trajectory, default_reward, False, step + 1, valid_format
+
+ # Track successfully executed actions for history
+ if valid_format:
+ action_history.append(action)
+
+ # Check for consecutive action repetition (last 3 actions)
+ if len(action_history) >= 3 and all(
+ a == action_history[-1] for a in action_history[-3:]
+ ):
+ repeated_action = action_history[-1]
+            feedback = f"Repeated the action '{repeated_action}' multiple times in a row, task failed"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ valid_format = False
+ return trajectory, default_reward, False, step + 1, valid_format
+
+ if done:
+ break
+
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+        feedback = f"Task completed successfully (reward: {reward}/1.0), and within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+ f"Task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ else:
+ feedback = (
+ f"Task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ return trajectory, reward, False, step + 1, valid_format
+
+
+def second_rollout(
+ self,
+ env,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, int, bool]:
+ """
+ Performs rollout starting from a specific retry step, reusing previous responses.
+
+ Args:
+ env: The environment instance.
+ guidance_prompt: The pre-generated guidance from reflection.
+ first_trajectory: The full log of the initial attempt.
+ retry_step: The step to start retry from (0-based, 0 means from beginning).
+
+ Returns:
+ A tuple containing (distill_trajectory, second_trajectory, reward, done status,
+ step count, and format validity).
+ """
+
+ # Reset environment to start fresh
+ observation, info = env.reset()
+ trajectory = []
+ distill_trajectory = []
+ action_history = [] # Track all actions taken
+
+ # Prepare system prompts
+ original_system_prompt = self.alfworld_system_template.render()
+
+ default_reward = 0.0
+ done = False
+ reward = default_reward
+ valid_format = True
+
+ # Extract task description from the initial observation (only once at the beginning)
+ task_description = extract_task_description(observation)
+
+ # Copy responses from first trajectory up to retry_step
+ step = 0
+ if retry_step > 0:
+ # Add original system prompt only
+ trajectory.append({"role": "system", "content": original_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ # Replay first trajectory up to retry_step to restore environment state
+ first_step = 0
+ for msg in first_trajectory[1:]: # Skip system message
+ if msg["role"] == "user":
+ # This is an observation - copy it and continue
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+ elif msg["role"] == "assistant":
+ if first_step < retry_step:
+ # Copy the assistant response from first trajectory
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+
+ # Execute the action to restore environment state
+ think, action, _ = parse_response(msg["content"])
+ if think is not None and action is not None:
+ observation, reward, done, info = env.step(action)
+ action_history.append(action)
+ first_step += 1
+ step = first_step
+
+ if done:
+ # If environment finished during replay, no need to continue
+ return distill_trajectory, trajectory, reward, done, step, valid_format
+ else:
+ break
+
+ # Add guidance prompt as a separate system message before retry point
+ guidance_system_msg = {"role": "system",
+ "content": f"# Previous Attempt Analysis & Guidance\n{guidance_prompt}"}
+ trajectory.append(guidance_system_msg)
+ # Don't add guidance to distill_trajectory to keep it clean
+
+ else:
+ # Starting from beginning - add system prompt with guidance
+ merged_system_prompt = f"{original_system_prompt}\n\n# Previous Attempt Analysis & Guidance\n{guidance_prompt}"
+ trajectory.append({"role": "system", "content": merged_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ for step in range(step, self.max_env_steps):
+ # Extract admissible actions from info if available
+ admissible_actions = info.get("admissible_commands", []) if isinstance(info, dict) else []
+
+ formatted_obs = format_observation(
+ current_observation=observation,
+ task_description=task_description,
+ current_step=step,
+ action_history=action_history,
+ admissible_actions=admissible_actions
+ )
+
+ trajectory.append({"role": "user", "content": formatted_obs})
+ distill_trajectory.append({"role": "user", "content": formatted_obs})
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+            # The tokenizer inside chat truncates inputs, so terminate as soon as the length reaches the limit
+ return distill_trajectory, trajectory, default_reward, False, step + 1, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response
+ think, action, error_msg = parse_response(response_text)
+ if error_msg is not None:
+ valid_format = False
+ observation = f"{error_msg}"
+ else:
+ # Execute action in environment
+ observation, reward, done, info = env.step(action)
+ if action not in admissible_actions:
+ valid_format = False
+ observation = f"Invalid action '{action}' not in admissible actions."
+
+ # Track successfully executed actions for history
+ if valid_format:
+ action_history.append(action)
+
+ # Check for consecutive action repetition (last 3 actions)
+ if len(action_history) >= 3 and all(
+ a == action_history[-1] for a in action_history[-3:]
+ ):
+ repeated_action = action_history[-1]
+            feedback = f"Repeated the action '{repeated_action}' multiple times in a row, task failed"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ valid_format = False
+ return distill_trajectory, trajectory, default_reward, False, step + 1, valid_format
+
+ if done:
+ break
+
+ print(f"[Second Rollout] - reward: {reward}, steps: {step + 1}")
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+        feedback = f"Task completed successfully (reward: {reward}/1.0), and within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+ f"Task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ else:
+ feedback = (
+ f"Task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+
+ return distill_trajectory, trajectory, reward, False, step + 1, valid_format
+
+
+def eval_alfworld(self) -> List[Experience]:
+ """Evaluate a single alfworld trajectory"""
+ try:
+ env = create_alfworld_environment(self.game_file_path)
+ trajectory, reward, done, steps, valid_format = first_rollout(
+ self, env
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ print(f"[Eval] First rollout - reward: {reward}, steps: {steps}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="evaluation"
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval",
+ experience_data=eval_record,
+ data_dir=self.eval_dir
+ )
+ except Exception as e:
+        print(f"[Eval] Rollout failed during eval: {e}")
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ }
+ )
+ exp.reward = 0.0
+ return [exp]
+
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def extract_task_description(observation: str) -> str:
+ """
+ Extract task description from the initial observation.
+ The task description is typically in the format: "Your task is to: ..."
+
+ Args:
+ observation: Initial observation from environment
+
+ Returns:
+ Extracted task description string
+ """
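+    # Example (illustrative):
+    #   extract_task_description("-= Kitchen =-\nYour task is to: put a clean apple in fridge.")
+    #   -> "put a clean apple in fridge."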
+ # Look for pattern "Your task is to: "
+ match = re.search(r"Your task is to:\s*(.+?)(?:\n|$)", observation, re.IGNORECASE)
+ if match:
+ return match.group(1).strip()
+ # Fallback: return a portion of the observation
+ return observation.split('\n')[-1] if '\n' in observation else observation
+
+
+def format_observation(
+ current_observation: str,
+ task_description: str = "",
+ current_step: int = 0,
+    action_history: Optional[List[str]] = None,
+    admissible_actions: Optional[List[str]] = None,
+ history_length: int = 4
+):
+ """
+ Format observation string with task context and limited action history.
+
+ Args:
+ current_observation: Current observation from environment
+ task_description: Description of the task to complete
+ current_step: Current step number
+ action_history: List of all previous actions taken
+ admissible_actions: List of currently admissible actions
+ history_length: Maximum number of recent actions to display (default: 4)
+ """
+ if action_history is None:
+ action_history = []
+ if admissible_actions is None:
+ admissible_actions = []
+
+ # Check if this is the first step (no history)
+ if current_step == 0 or not action_history:
+ # First step - no history version
+ prompt = f"""You are an expert agent operating in the ALFRED Embodied Environment.
+Your current observation is: {current_observation}
+Your admissible actions for the current situation are: {admissible_actions}.
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for the current step and present it within <action> </action> tags.
+
+Format: <think>your reasoning process</think> <action>your chosen action</action>"""
+ else:
+ # Limit action history to most recent history_length items
+ recent_actions = action_history[-history_length:] if len(action_history) > history_length else action_history
+
+ # Format action history as a structured list
+ action_history_str = "\n".join([f" Step {current_step - len(recent_actions) + i}: {action}"
+ for i, action in enumerate(recent_actions)])
+
+ # Create formatted prompt with limited history
+ prompt = f"""You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description}
+Prior to this step, you have already taken {len(action_history)} step(s). Below are the most recent {len(recent_actions)} actions you took:
+{action_history_str}
+You are now at step {current_step} and your current observation is: {current_observation}
+Your admissible actions for the current situation are: {admissible_actions}.
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should choose an admissible action for the current step and present it within <action> </action> tags.
+
+Format: <think>your reasoning process</think> <action>your chosen action</action>"""
+
+ return prompt
+
+
+def parse_response(response):
+ """
+ Parse think and action components from response.
+ Returns (think, action, error_message) tuple.
+ - If successful: (think_content, action_content, None)
+ - If error: (None, None, error_message)
+ """
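+    # Examples (illustrative):
+    #   parse_response("<think>check the fridge</think><action>go to fridge 1</action>")
+    #     -> ("check the fridge", "go to fridge 1", None)
+    #   parse_response("go to fridge 1")
+    #     -> (None, None, "Invalid response format, missing valid <think> or <action> tags, ...")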
+ try:
+ # Use regex to extract all think and action components
+        think_pattern = r"<think>\s*(.*?)\s*</think>"
+        action_pattern = r"<action>\s*(.*?)\s*</action>"
+
+ think_matches = re.findall(think_pattern, response, re.DOTALL)
+ action_matches = re.findall(action_pattern, response, re.DOTALL)
+
+ # Check for multiple think tags
+ if len(think_matches) > 1:
+            return None, None, f"Multiple <think> tags found ({len(think_matches)}). Only one <think></think> pair is allowed."
+
+ # Check for multiple action tags
+ if len(action_matches) > 1:
+            return None, None, f"Multiple <action> tags found ({len(action_matches)}). Only one <action></action> pair is allowed."
+
+ # Check if tags are missing
+ if len(think_matches) == 0 and len(action_matches) == 0:
+            return None, None, "Invalid response format, missing valid <think> or <action> tags, please ensure to follow the output format strictly: <think>...</think> <action>...</action>"
+        elif len(think_matches) == 0:
+            return None, None, "Invalid response format, missing valid <think> tag, please ensure to follow the output format strictly: <think>...</think> <action>...</action>"
+        elif len(action_matches) == 0:
+            return None, None, "Invalid response format, missing valid <action> tag, please ensure to follow the output format strictly: <think>...</think> <action>...</action>"
+
+ think = think_matches[0].strip()
+ action = action_matches[0].strip()
+
+ return think, action, None
+ except Exception:
+ return None, None, "Unexpected error occurred while parsing response format."
+
+
+def create_alfworld_environment(game_file, max_episode_steps=25):
+ """
+ Create alfworld environment
+
+ Args:
+ game_file: Path to the game file
+        max_episode_steps: Maximum number of steps per episode (default: 25)
+ """
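+    # Example usage (illustrative; the game file path is hypothetical):
+    #   env = create_alfworld_environment("alfworld_data/train/pick_clean_then_place/game.tw-pddl")
+    #   observation, info = env.reset()
+    #   observation, reward, done, info = env.step("look")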
+ try:
+ import textworld
+ import textworld.gym
+ from alfworld.agents.environment.alfred_tw_env import (
+ AlfredDemangler,
+ AlfredExpert,
+ AlfredExpertType,
+ )
+
+ expert = AlfredExpert(expert_type=AlfredExpertType.HANDCODED)
+ request_infos = textworld.EnvInfos(
+ description=True, inventory=True, admissible_commands=True
+ )
+
+ env_id = textworld.gym.register_game(
+ game_file, request_infos,
+ max_episode_steps=max_episode_steps,
+ asynchronous=True,
+ wrappers=[AlfredDemangler(), expert]
+ )
+ env = textworld.gym.make(env_id)
+
+ return env
+ except ImportError as e:
+ raise ImportError(
+ f"Failed to import alfworld dependencies: {e}. "
+ "Please install alfworld following the instructions at https://github.com/alfworld/alfworld"
+ )
+
+
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+ Correctly formats the trajectory for reflection, including the system prompt
+ and numbering the user/assistant turns.
+ """
+ formatted_lines = []
+    # Use a counter to track user/assistant interaction turns
+    turn_counter = 0  # start counting from 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+            # The system prompt does not count as a step, but must be shown first as the rules
+ formatted_lines.append(f"**System Rules & Context:**\n{content}\n" + "=" * 30)
+ elif role == "user":
+            # Mark the start of a new turn
+ formatted_lines.append(f"\n**Step {turn_counter}**")
+ formatted_lines.append(f" - User Observation/Feedback:\n {content.strip()}")
+ elif role == "assistant":
+            # The assistant's thought and action
+ formatted_lines.append(f" - Assistant Thought & Action:\n {content.strip()}")
+            # Increment the turn counter after a complete user-assistant exchange
+ turn_counter += 1
+
+ return "\n".join(formatted_lines)
+
+
+def validate_reflect_report(report: Dict[str, Any], total_steps: int) -> Tuple[bool, bool]:
+ """
+ Validates the structure and content of the reflection report
+ based on the new reflection.j2 schema.
+
+ Args:
+ report: The reflection report to validate
+ total_steps: Total number of steps in trajectory for retry_step bounds checking
+
+ Returns:
+ Tuple[bool, bool]: (is_valid, is_perfect)
+ - is_valid: Whether the report structure is valid
+ - is_perfect: Whether the report indicates the trajectory is perfect (only meaningful if is_valid is True)
+ """
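+    # Example (illustrative): a minimal report for a failed trajectory of 10 steps,
+    #   {"trajectory_summary": "...", "root_cause_analysis": "...",
+    #    "trajectory_outcome": "failure", "improvement_suggestion": "...",
+    #    "retry_from_step": 2}
+    # would pass validation here and return (True, False).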
+ if (
+ not isinstance(report, dict)
+ or "trajectory_summary" not in report
+ or "root_cause_analysis" not in report
+ or "trajectory_outcome" not in report
+ ):
+ print("[R3L Alfworld Validation] Validation failed: Report is not a dict or missing keys.")
+ return False, False
+
+ outcome = report["trajectory_outcome"]
+ analysis = report["root_cause_analysis"]
+
+ if outcome == "success":
+        # For 'success', only the summary is required; no further analysis is needed
+ print("[R3L Alfworld Validation] success report validation successful.")
+ return True, True
+
+ elif outcome in ["success_but_inefficient", "failure"]:
+        # For non-success outcomes, validate the improvement_suggestion and retry_from_step fields
+ improvement_suggestion = report.get("improvement_suggestion", None)
+ retry_from_step = report.get("retry_from_step", None)
+
+ if improvement_suggestion is None or retry_from_step is None:
+ print("[R3L Alfworld Validation] Validation failed: Missing 'improvement_suggestion' or 'retry_from_step'.")
+ return False, False
+
+ # check retry from step
+ try:
+ retry_from_step = int(retry_from_step)
+ except (ValueError, TypeError):
+ print(f"[R3L Alfworld Validation] Validation failed: 'retry_from_step' must be an integer. Got: {retry_from_step}")
+ return False, False
+ if not isinstance(retry_from_step, int) or retry_from_step < 0:
+ print(f"[R3L Alfworld Validation] Validation failed: 'retry_from_step' must be a non-negative integer. Got: {retry_from_step}")
+ return False, False
+ # Check trajectory bounds
+ if retry_from_step >= total_steps:
+ print(
+ f"[R3L Alfworld Validation] Validation failed: 'retry_from_step' ({retry_from_step}) exceeds trajectory bounds (0 to {total_steps - 1}).")
+ return False, False
+ print(f"[R3L Alfworld Validation] {outcome} report validation successful.")
+ return True, False
+ else:
+ print(f"[R3L Alfworld Validation] Invalid trajectory_outcome: {outcome}")
+ return False, False
+
+
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Converts a validated reflection report into a structured, actionable
+ guidance prompt for the agent's second attempt. This prompt is framed
+ as an internal directive to ensure the model's output is clean for SFT.
+ """
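+    # The rendered guidance is simply self_correction.j2 with the JSON report dropped
+    # in place of {{ report }} (see prompts/self_correction.j2 earlier in this change).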
+ # Convert the report dictionary to a formatted string
+ report_str = json.dumps(report, indent=2, ensure_ascii=False)
+
+ # Load and render template
+ jinja_env = _get_jinja_env()
+ template = jinja_env.get_template("self_correction.j2")
+
+ return template.render(report=report_str)
+
+
+def save_experience_data(
+ task_id: str,
+ experience_data: Dict,
+ data_dir: str
+) -> str:
+ """
+ Save experience data including trajectory, rewards, and steps to a JSON file.
+
+ Args:
+ task_id: Unique identifier for the task
+ experience_data: Dictionary containing experience information
+ data_dir: Directory to save the data
+
+ Returns:
+ Path to the saved file
+ """
+ os.makedirs(data_dir, exist_ok=True)
+
+    # Build the output filename from the task id
+ filename = f"{task_id}.json"
+ filepath = os.path.join(data_dir, filename)
+
+ # Ensure experience_data is JSON serializable
+ serializable_data = {}
+ for key, value in experience_data.items():
+ if isinstance(value, torch.Tensor):
+ serializable_data[key] = value.tolist()
+ elif hasattr(value, '__dict__'):
+ # For complex objects, convert to dict representation
+ serializable_data[key] = str(value)
+ else:
+ serializable_data[key] = value
+
+ # Add metadata
+ serializable_data["saved_at"] = datetime.now().isoformat()
+ serializable_data["task_id"] = task_id
+
+ try:
+ with open(filepath, 'w', encoding='utf-8') as f:
+ json.dump(serializable_data, f, indent=2, ensure_ascii=False)
+ # print(f"Experience data saved to: {filepath}")
+ return filepath
+ except Exception as e:
+ # print(f"Failed to save experience data: {e}")
+ return ""
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ steps: int,
+ success: bool,
+ attempt_type: str = "first",
+ reflection_data: Optional[Dict] = None,
+ additional_metrics: Optional[Dict] = None
+) -> Dict:
+ """
+ Create a structured experience record for saving.
+
+ Args:
+ task_id: Unique identifier for the task
+ trajectory: List of conversation messages
+ reward: Final reward received
+ steps: Number of steps taken
+ success: Whether the task was completed successfully
+ attempt_type: Type of attempt ("first", "second", "reflect")
+ reflection_data: Optional reflection analysis data
+ additional_metrics: Additional metrics to record
+
+ Returns:
+ Dictionary containing structured experience data
+ """
+ experience_record = {
+ "task_id": task_id,
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "metrics": {
+ "reward": reward,
+ "steps": steps,
+ "success": success,
+ "trajectory_length": len(trajectory)
+ },
+ "created_at": datetime.now().isoformat()
+ }
+
+ if reflection_data:
+ experience_record["reflection"] = reflection_data
+
+ if additional_metrics:
+ experience_record["metrics"].update(additional_metrics)
+
+ return experience_record
diff --git a/trinity/common/workflows/envs/R3L/countdown/R3L_workflow.py b/trinity/common/workflows/envs/R3L/countdown/R3L_workflow.py
new file mode 100644
index 0000000000..ba901afdd1
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/countdown/R3L_workflow.py
@@ -0,0 +1,404 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.countdown import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("R3L_countdown_workflow")
+class R3LCountdownWorkflow(Workflow):
+ """
+ R3L workflow for Countdown mathematical problem solving
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+
+ self.whether_save_data = False
+ # Create data directories
+        self.data_dir = "R3L_countdown_data"
+ self.eval_dir = os.path.join(self.data_dir, "eval")
+ self.train_dir = os.path.join(self.data_dir, "train")
+
+ os.makedirs(self.eval_dir, exist_ok=True)
+ os.makedirs(self.train_dir, exist_ok=True)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.countdown_system_template = self.jinja_env.get_template("countdown_system.j2")
+ self.reflection_template = self.jinja_env.get_template("reflection.j2")
+
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ self.default_second_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "second_success": 0.0,
+ "second_reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ print(
+ f"Initializing R3LCountdownWorkflow with experience learning, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract numbers and target from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Countdown format: direct access to nums and target fields
+ self.numbers = raw_task.get("nums", [])
+ self.target = raw_task.get("target", 0)
+ else:
+ self.numbers = []
+ self.target = 0
+
+ def get_reflect(self, trajectory: List[Dict[str, str]]) -> tuple[
+ Optional[Dict[str, Any]], Optional[str], Optional[Any]]:
+ """
+ Generates a comprehensive reflection report using a single, unified self-interrogation prompt.
+ """
+ # Format trajectory for LLM reading
+ formatted_trajectory = utils.format_trajectory_for_reflection(trajectory)
+
+ # Use Jinja2 template to render reflection prompt
+ reflect_prompt = self.reflection_template.render()
+
+ # Call model and parse results
+ try:
+ responses = self.model.chat(
+ [
+ {"role": "system", "content": reflect_prompt},
+ {"role": "user", "content": "Here is last attempt trajectory log: \n\n" + formatted_trajectory + "\n\nPlease output in the specified JSON format."}
+ ],
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_reflect_tokens,
+ )
+ reflection_text = responses[0].response_text.strip()
+
+ # Find first '{' and last '}'
+ first_brace = reflection_text.find('{')
+ last_brace = reflection_text.rfind('}')
+
+ if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
+ json_str = reflection_text[first_brace:last_brace + 1]
+ else:
+ json_str = reflection_text
+
+ reflection_data = json.loads(json_str)
+ return reflection_data, reflection_text, responses[0]
+
+ except Exception as e:
+ print(f"[R3L] Reflection failed - Error: {str(e)}")
+ return None, None, None
+
+ def _adjust_action_mask_for_retry(self, experience: Experience, retry_step: int):
+ """
+ Adjust action_mask in-place to exclude retry prefix from training.
+ Only tokens from retry_step onwards should be trained.
+
+ Args:
+ experience: The experience object with action_mask to adjust
+ retry_step: The step from which training should start
+ """
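+        # Worked example (illustrative): if action_mask contains assistant segments at
+        # indices [3:6] and [9:12] and retry_step=1, the first segment [3:6] is zeroed,
+        # so only responses generated from the retry point onwards are trained on.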
+ if retry_step <= 0:
+ return
+
+ # Note: experience.action_mask already excludes prompt tokens
+ action_mask = experience.action_mask
+
+ # Find all assistant response regions and mark the first 'retry_step' as non-trainable
+ if torch.any(action_mask == 1):
+ # Find all segments where action_mask == 1 (assistant responses)
+ assistant_segments = []
+ in_segment = False
+ segment_start = 0
+
+ for i, mask_val in enumerate(action_mask):
+ if mask_val == 1 and not in_segment:
+ # Start of a new segment
+ segment_start = i
+ in_segment = True
+ elif mask_val == 0 and in_segment:
+ # End of current segment
+ assistant_segments.append((segment_start, i))
+ in_segment = False
+
+ # Handle case where sequence ends with assistant response
+ if in_segment:
+ assistant_segments.append((segment_start, len(action_mask)))
+
+ # Set the first 'retry_step' assistant segments to 0 (non-trainable)
+ for i in range(min(retry_step, len(assistant_segments))):
+ start, end = assistant_segments[i]
+ action_mask[start:end] = 0
+
+ def run(self) -> List[Experience]:
+ """Run the R3L countdown workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_countdown(self)
+
+ # Generate unique task ID
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+
+ exp_lst = []
+ for i in range(self.n // 2): # Half for rollout, half for reflection + retry
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[R3L Countdown] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ # Set eid
+ exp.eid.task = str(self.task.task_id) + f"_explore"
+ exp_run_id = len(exp_lst) + self.run_id_base
+ exp.eid.run = exp_run_id
+ exp_lst.append(exp)
+
+ if self.whether_save_data:
+ # Save first attempt experience data
+ first_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=trajectory,
+ reward=reward,
+ success=success,
+ predicted_answer=predicted_answer,
+ ground_truth=ground_truth,
+ attempt_type="first"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_first",
+ experience_data=first_record,
+ data_dir=self.train_dir
+ )
+
+ # Reflect on first attempt
+ print(f"[R3L] Starting reflection on first attempt (reward: {reward})...")
+ reflect_checklist, reflection_text, reflect_exp = self.get_reflect(trajectory)
+ is_valid, is_perfect = utils.validate_reflect_report(reflect_checklist, attempts)
+
+ if reflect_checklist is None:
+ print(f"[R3L] Reflection failed - No valid reflection data generated")
+ elif is_valid and not is_perfect:
+ print(f"[R3L] Reflection successful - Valid reflection generated")
+ elif is_perfect:
+ print(f"[R3L] Reflection indicates perfect first attempt - No retry needed")
+ elif not is_valid:
+ print(f"[R3L] Reflection validation failed - Invalid reflection data")
+
+ if not is_valid or is_perfect:
+ print(f"[R3L] Skip second rollout due to invalid ({not is_valid}) or perfect ({is_perfect}) reflection.")
+ # If first attempt reward is 1.0 and reflection gives perfect, record reflection exp
+ if reward >= 1.0 and is_perfect and reflect_exp is not None:
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Do another rollout to ensure the batch has enough data
+ print(f"[R3L] Performing additional rollout...")
+ try:
+ retry_trajectory, retry_reward, retry_success, retry_predicted_answer, retry_ground_truth, retry_attempts = utils.first_rollout(self)
+ print(f"[R3L] Additional rollout completed - reward: {retry_reward}, attempts: {retry_attempts}")
+
+ retry_exp = self.model.convert_messages_to_experience(retry_trajectory[:-1])
+ retry_exp.reward = retry_reward
+ retry_exp.metrics = {
+ "success": 1.0 if retry_success else 0.0,
+ "reward": retry_reward,
+ "attempts": retry_attempts,
+ }
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + f"_explore"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ if self.whether_save_data:
+ # Save retry attempt experience data
+ retry_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=retry_trajectory,
+ reward=retry_reward,
+ success=retry_success,
+ predicted_answer=retry_predicted_answer,
+ ground_truth=retry_ground_truth,
+ attempt_type="retry_after_invalid_reflection"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_retry",
+ experience_data=retry_record,
+ data_dir=self.train_dir
+ )
+ except Exception as e:
+ print(f"[R3L] Retry rollout after invalid reflection failed - Error: {e}")
+
+ else:
+ print("[R3L] Valid reflection obtained, proceeding to second rollout...")
+ guidance_prompt = utils.reflect_report_to_guidance_prompt(reflect_checklist)
+                    # Extract retry_step from the validated reflection report (top-level field in the reflection schema)
+ retry_step = reflect_checklist.get("retry_from_step", 0)
+
+ try:
+ (
+ distill_trajectory,
+ second_trajectory,
+ second_reward,
+ second_success,
+ second_predicted_answer,
+ second_ground_truth,
+ second_attempts,
+ ) = utils.second_rollout(
+ self, guidance_prompt, trajectory, retry_step
+ )
+ print(f"[R3L] Second rollout - reward: {second_reward}, attempts: {second_attempts}, improve: {second_reward > reward}")
+ second_exp = self.model.convert_messages_to_experience(distill_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(second_exp, retry_step)
+ # Also adjust first rollout exp for fair comparison
+ for existing_exp in exp_lst:
+ if existing_exp.eid.run == exp_run_id:
+ self._adjust_action_mask_for_retry(existing_exp, retry_step)
+ break
+
+ second_exp.reward = second_reward
+ second_exp.metrics = {
+ "second_success": 1.0 if second_success else 0.0,
+ "second_reward": second_reward,
+ "second_attempts": second_attempts,
+ "second_improve": 1.0 if second_reward > reward else 0.0,
+ "second_reward_diff": second_reward - reward,
+ }
+ # Set eid
+ second_exp.eid.task = str(self.task.task_id) + f"_explore"
+ second_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(second_exp)
+
+ if self.whether_save_data:
+ # Save second attempt experience data
+ second_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=second_trajectory,
+ reward=second_reward,
+ success=second_success,
+ predicted_answer=second_predicted_answer,
+ ground_truth=second_ground_truth,
+ attempt_type="second",
+ additional_metrics={
+ "first_reward": reward,
+ "improvement": second_reward > reward,
+ "reward_difference": second_reward - reward,
+ }
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_second",
+ experience_data=second_record,
+ data_dir=self.train_dir
+ )
+
+ # If second attempt score is higher than first, record reflection and retry data
+ if second_reward > reward and second_reward >= 1.0:
+ print(f"[R3L] Second attempt successful improvement - Recording reflection and retry experiences")
+ print(f"[R3L] Reward improvement: {reward} -> {second_reward} (+{second_reward - reward:.2f})")
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Convert retry data to exp
+ retry_exp = self.model.convert_messages_to_experience(second_trajectory[:-1])
+
+ retry_exp.reward = 1.0
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + f"_retry_{i}"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ print("[R3L] Reflection and retry led to improvement, recording both...")
+ elif second_reward <= reward:
+ print(f"[R3L] Second attempt did not improve - First reward: {reward}, Second reward: {second_reward}")
+ else:
+ print(f"[R3L] Second attempt improved but below threshold - Reward: {second_reward} (need >= 1.0)")
+ except Exception as e:
+ print(f"[R3L] Second rollout failed - Error: {str(e)}")
+ except Exception as e:
+ print(f"[R3L] Rollout iteration {i} failed - Error: {str(e)}")
+
+ # Print summary statistics
+ print(f"\n[R3L Summary] Generated {len(exp_lst)} experiences")
+ total_reward = sum(exp.reward for exp in exp_lst)
+ avg_reward = total_reward / len(exp_lst) if exp_lst else 0.0
+ print(f"[R3L Summary] Total reward: {total_reward:.2f}, Average reward: {avg_reward:.2f}")
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/countdown/__init__.py b/trinity/common/workflows/envs/R3L/countdown/__init__.py
new file mode 100644
index 0000000000..5c3d1be19a
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/countdown/__init__.py
@@ -0,0 +1,2 @@
+# -*- coding: utf-8 -*-
+"""Countdown R3L workflows"""
diff --git a/trinity/common/workflows/envs/R3L/countdown/dapo_workflow.py b/trinity/common/workflows/envs/R3L/countdown/dapo_workflow.py
new file mode 100644
index 0000000000..4480803594
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/countdown/dapo_workflow.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.countdown import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("dapo_countdown_workflow")
+class DAPOCountdownWorkflow(Workflow):
+ """
+ DAPO Workflow for Countdown environment.
+ Performs rollouts with DAPO-style overlong penalty on response length.
+ No separate reward function needed - penalty computed directly in workflow.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # DAPO overlong penalty parameters
+ workflow_args = task.workflow_args or {}
+ self.enable_overlong_penalty = workflow_args.get("enable_overlong_penalty", True)
+ self.penalty_factor = workflow_args.get("penalty_factor", 1.0)
+ self.max_response_length = workflow_args.get("max_response_length", 4096)
+ self.cache_length = workflow_args.get("cache_length", 100)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.countdown_system_template = self.jinja_env.get_template("countdown_system.j2")
+
+ print(
+ f"Initializing DAPOCountdownWorkflow, temperature={self.temperature}, "
+ f"overlong_penalty={'enabled' if self.enable_overlong_penalty else 'disabled'}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract numbers and target from raw_task
+ raw_task = task.raw_task or {}
+ self.numbers = raw_task.get("numbers", [])
+ self.target = raw_task.get("target", 0)
+
+ # Update DAPO parameters if provided
+ workflow_args = task.workflow_args or {}
+ if "enable_overlong_penalty" in workflow_args:
+ self.enable_overlong_penalty = workflow_args["enable_overlong_penalty"]
+ if "penalty_factor" in workflow_args:
+ self.penalty_factor = workflow_args["penalty_factor"]
+ if "max_response_length" in workflow_args:
+ self.max_response_length = workflow_args["max_response_length"]
+ if "cache_length" in workflow_args:
+ self.cache_length = workflow_args["cache_length"]
+
+ def compute_overlong_penalty(self, response_tokens: torch.Tensor) -> float:
+ """
+ Compute DAPO-style overlong penalty based on response token length.
+
+ Args:
+ response_tokens: Response tokens (tensor)
+
+ Returns:
+ Penalty score (non-positive float): 0.0, a fixed -penalty_factor, or a linear ramp between the two
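+
+ Illustrative values (assuming the defaults above: max_response_length=4096,
+ cache_length=100, penalty_factor=1.0, so expected_len=3996):
+ 3900 tokens -> 0.0; 4050 tokens -> (3996 - 4050) / 100 = -0.54; 4200 tokens -> -1.0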
+ """
+ if not self.enable_overlong_penalty:
+ return 0.0
+
+ response_len = len(response_tokens)
+ expected_len = self.max_response_length - self.cache_length
+
+ if response_len < expected_len:
+ # No penalty for short responses
+ return 0.0
+ elif response_len > self.max_response_length:
+ # Fixed penalty for excessively long responses
+ return -self.penalty_factor
+ else:
+ # Linear penalty in the transition zone
+ return (expected_len - response_len) / self.cache_length * self.penalty_factor
+
+ def run(self) -> List[Experience]:
+ """Run the DAPO workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_countdown(self)
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[DAPO Countdown] Rollout - reward: {reward}, attempts: {attempts}")
+
+ # Convert trajectory to experience
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+
+ # Extract response tokens from experience
+ response_tokens = exp.tokens[exp.prompt_length:]
+
+ # Compute DAPO overlong penalty (format score)
+ format_score = self.compute_overlong_penalty(response_tokens)
+
+ # Calculate accuracy score
+ accuracy_score = 1.0 if reward >= 1.0 else 0.0
+
+ # Total reward = accuracy + format_score
+ total_reward = accuracy_score + format_score
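+ # Illustrative: with the default penalty settings, a correct answer whose
+ # response runs to 4050 tokens would score 1.0 + (-0.54) = 0.46.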
+
+ # Update experience reward and metrics
+ exp.reward = total_reward
+ exp.metrics = {
+ "success": accuracy_score,
+ "attempts": attempts,
+ "accuracy": accuracy_score,
+ "format_score": format_score,
+ "response_length": len(response_tokens),
+ "total_reward": total_reward,
+ }
+
+ # Set experience ID
+ exp.eid.task = str(self.task.task_id)
+ exp.eid.run = i + self.run_id_base
+
+ exp_lst.append(exp)
+ except Exception as e:
+ print(f"[DAPO Countdown] Rollout failed: {e}")
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/countdown/grpo_workflow.py b/trinity/common/workflows/envs/R3L/countdown/grpo_workflow.py
new file mode 100644
index 0000000000..62e87c8833
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/countdown/grpo_workflow.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.countdown import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("grpo_baseline_countdown_workflow")
+class GRPOBaselineCountdownWorkflow(Workflow):
+ """
+ GRPO Baseline Workflow for Countdown environment.
+ Performs simple rollouts without reflection or learning from experience.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.countdown_system_template = self.jinja_env.get_template("countdown_system.j2")
+
+ print(
+ f"Initializing GRPOBaselineCountdownWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract numbers and target from task
+ raw_task = task.raw_task or {}
+
+ # Countdown format: direct access to nums and target fields
+ self.numbers = raw_task.get("nums", [])
+ self.target = raw_task.get("target", 0)
+
+ def run(self) -> List[Experience]:
+ """Run the GRPO baseline workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_countdown(self)
+
+ # Multiple rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[GRPO Countdown] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/countdown/opmd_workflow.py b/trinity/common/workflows/envs/R3L/countdown/opmd_workflow.py
new file mode 100644
index 0000000000..0d13287070
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/countdown/opmd_workflow.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.countdown import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("opmd_baseline_countdown_workflow")
+class OPMDBaselineCountdownWorkflow(Workflow):
+ """
+ OPMD Baseline workflow for Countdown mathematical problem solving.
+ Performs plain rollouts for the OPMD (Online Policy Mirror Descent) baseline.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.countdown_system_template = self.jinja_env.get_template("countdown_system.j2")
+
+ print(
+ f"Initializing OPMDCountdownWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract numbers and target from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Countdown format: direct access to nums and target fields
+ self.numbers = raw_task.get("nums", [])
+ self.target = raw_task.get("target", 0)
+ else:
+ self.numbers = []
+ self.target = 0
+
+ def run(self) -> List[Experience]:
+ """Run the OPMD workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_countdown(self)
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[OPMD Countdown] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/countdown/prompts/countdown_system.j2 b/trinity/common/workflows/envs/R3L/countdown/prompts/countdown_system.j2
new file mode 100644
index 0000000000..f497326641
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/countdown/prompts/countdown_system.j2
@@ -0,0 +1,27 @@
+You are a mathematical problem solver. Your task is to create equations using given numbers to reach a target value.
+
+## Response Format:
+You MUST use this exact format for every response. Both tags are REQUIRED in sequential order:
+
+<think>your analytical reasoning and thought process</think>
+<answer>your final equation</answer>
+
+## Task Description:
+Given a set of numbers and a target value, you need to create an equation using basic arithmetic operations (+, -, *, /) where:
+- Each given number can only be used once
+- You can use parentheses to control order of operations
+- The equation must equal the target value
+
+## Example:
+For numbers [44, 19, 35] and target 98:
+<think>Let me try different combinations:
+- 44 + 19 + 35 = 98 ✓ This works!</think>
+
+<answer>(44 + 19 + 35)</answer>
+
+## Critical Rules:
+- Use each number exactly once
+- Only use the four basic operations: +, -, *, /
+- Your answer must be in the <answer> </answer> tags
+- Show your reasoning process in the <think> </think> tags
+- The equation in <answer> </answer> must be valid and evaluate to the target
diff --git a/trinity/common/workflows/envs/R3L/countdown/prompts/reflection.j2 b/trinity/common/workflows/envs/R3L/countdown/prompts/reflection.j2
new file mode 100644
index 0000000000..db4b2883ce
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/countdown/prompts/reflection.j2
@@ -0,0 +1,28 @@
+You are a Reflector that analyzes trajectory logs based on user and environment feedback. Your goal is to identify what went wrong, trace root causes, and extract reusable principles for future improvement. Review the trajectory and feedback to understand the strategy and outcome. Through Socratic-style iterative "why" questioning, trace issues back to their fundamental flawed assumptions or mental models. Then formulate an actionable principle and suggest where to retry if needed.
+
+Please output in the following JSON format:
+
+```json
+{
+"trajectory_summary": "Concise overview in 1-3 sentences covering: (1) the strategy or approach employed by the agent, (2) the final result or outcome achieved, (3) key observations about execution quality (e.g., efficiency, correctness, optimality).",
+"root_cause_analysis": "Deep causal analysis using iterative 'why' questioning to trace from observable symptoms back to the fundamental root cause (flawed assumption, incorrect mental model, or critical knowledge gap). Chain your reasoning explicitly (e.g., 'Why X? Because Y. Why Y? Because Z.'). Identify the deepest underlying issue, not just surface-level errors. Set to null only if execution was truly flawless.",
+"trajectory_outcome": "Classification of the trajectory result. Must be EXACTLY one of these three values (case-sensitive, including underscores): 'success' (goal fully achieved with optimal execution quality), 'success_but_inefficient' (goal achieved but with unnecessary steps, redundant actions, or suboptimal approach), 'failure' (goal not achieved or task incomplete).",
+"improvement_suggestion": "A generalizable, context-complete principle for avoiding similar issues in future attempts. Must be self-contained and actionable without reference to this specific trajectory. Include: (1) the specific environment/system/domain name (Countdown number problems), (2) the triggering conditions or scenario when this applies, (3) the specific problem or pitfall to avoid, (4) the recommended solution or approach with clear rationale. Frame as reusable knowledge. Set to null if and only if trajectory_outcome is 'success'.",
+"retry_from_step": "Integer from 0 to N-1 identifying the earliest step where the root cause first manifested or where a corrected decision could alter the outcome. This represents the optimal restart point if given one opportunity to retry. Use 0 when the root cause traces to initial strategy selection or foundational assumptions. For countdown problems, this is typically 0 since they are single-attempt tasks. Set to null if trajectory_outcome is 'success' or if retry would not be beneficial."
+}
+```
+
+## Example
+
+**Scenario**: Solving a countdown problem to reach 24 using numbers [3, 8, 3, 8]
+
+**Example Output**:
+```json
+{
+"trajectory_summary": "The agent attempted to solve a countdown problem by trying the equation (8-3)*(8-3) which equals 25, not the target 24. The agent used each number once as required, but the arithmetic result was incorrect. The approach showed understanding of the rules (each number used once) but failed to verify the calculation matched the target.",
+"root_cause_analysis": "Why did the solution fail? Because (8-3)*(8-3) = 5*5 = 25, not 24. Why didn't the agent catch this? Because it focused on constructing a valid equation structure but didn't verify the numerical result. Why skip verification? Because the agent treated 'forming an equation' as success rather than 'reaching the exact target'. Root cause: Premature satisfaction with equation structure without rigorous numerical verification - the agent optimized for syntactic correctness (valid equation form) rather than semantic correctness (correct numerical result).",
+"trajectory_outcome": "failure",
+"improvement_suggestion": "In Countdown number problems, always verify the final numerical result before submitting. The task requires: (1) using each given number exactly once, AND (2) the equation evaluating to the exact target value. Both conditions must be satisfied. After constructing any equation, compute its value step-by-step and compare with the target. If the result differs even by 1, the equation is incorrect regardless of its structural validity. For [3,8,3,8]→24: while (8-3)*(8-3)=25 satisfies the usage rule, it fails the value requirement. A correct solution is 8/(3-8/3)=24, which both uses each number once and evaluates to exactly 24.",
+"retry_from_step": 0
+}
+```
diff --git a/trinity/common/workflows/envs/R3L/countdown/prompts/self_correction.j2 b/trinity/common/workflows/envs/R3L/countdown/prompts/self_correction.j2
new file mode 100644
index 0000000000..3c0a34b676
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/countdown/prompts/self_correction.j2
@@ -0,0 +1,5 @@
+Your previous attempt encountered issues. Below is a reflection based on user and environment feedback:
+
+{{ report }}
+
+Apply the lessons learned from this reflection to avoid repeating the same mistakes. Do not mention or reference this guidance in your response.
diff --git a/trinity/common/workflows/envs/R3L/countdown/raft_workflow.py b/trinity/common/workflows/envs/R3L/countdown/raft_workflow.py
new file mode 100644
index 0000000000..0e8428104d
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/countdown/raft_workflow.py
@@ -0,0 +1,135 @@
+# -*- coding: utf-8 -*-
+import copy
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.countdown import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("RAFT_baseline_countdown_workflow")
+class RAFTBaselineCountdownWorkflow(Workflow):
+ """
+ RAFT Baseline workflow for Countdown environment.
+ Performs rollouts for RAFT (Reward-Ranked Fine-Tuning); only successful rollouts are kept for training, while failures are replaced by a dummy experience.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.countdown_system_template = self.jinja_env.get_template("countdown_system.j2")
+
+ print(
+ f"Initializing RAFTBaselineCountdownWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ # Default experience for error cases
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract numbers and target from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Countdown format: direct access to nums and target fields
+ self.numbers = raw_task.get("nums", [])
+ self.target = raw_task.get("target", 0)
+ else:
+ self.numbers = []
+ self.target = 0
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT workflow and return experiences"""
+
+ if self.is_eval:
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ return [exp]
+ except Exception:
+ return [copy.deepcopy(self.default_exp)]
+
+ # Multiple rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[RAFT Countdown] Rollout {i} - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ # RAFT only uses successful samples
+ if reward >= 1.0:
+ exp_lst.append(exp)
+ else:
+ exp_lst.append(copy.deepcopy(self.default_exp))
+ except Exception:
+ exp_lst.append(copy.deepcopy(self.default_exp))
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/countdown/utils.py b/trinity/common/workflows/envs/R3L/countdown/utils.py
new file mode 100644
index 0000000000..8c1cdd354c
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/countdown/utils.py
@@ -0,0 +1,595 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.utils.eval_utils import evaluate_equation, validate_equation
+
+
+def first_rollout(self) -> tuple[List[Dict[str, str]], float, bool, str, str, int]:
+ """Run countdown problem solving with multiple attempts (max 3 attempts) using multi-round interaction"""
+ trajectory = []
+ attempt_history = [] # Track attempt history for limited history display
+
+ final_reward = 0.0
+ final_success = False
+ final_predicted_answer = ""
+ attempt_count = 0
+
+ # Try up to 3 attempts
+ for attempt in range(self.max_attempts):
+ attempt_count = attempt + 1
+
+ # Format user prompt with history (limited to history_length)
+ user_prompt = format_countdown_prompt(
+ numbers=self.numbers,
+ target=self.target,
+ current_step=attempt,
+ attempt_history=attempt_history,
+ history_length=getattr(self, 'history_length', 4)
+ )
+ trajectory.append({"role": "user", "content": user_prompt})
+
+ # Get model response
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 4096:
+ # The tokenizer inside chat truncates over-long inputs, so terminate as soon as the length reaches the limit (20480 max_model_len minus the 4096-token generation budget)
+ return trajectory, final_reward, final_success, final_predicted_answer, str(self.target), attempt_count
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse think and answer
+ think, predicted_answer = parse_response(response_text)
+
+ if think is None or predicted_answer is None:
+ # Invalid format
+ feedback = "Invalid response format. Please ensure you provide both ... and ... tags."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ # Record this failed attempt in history
+ attempt_history.append({
+ "equation": "Invalid format",
+ "feedback": feedback
+ })
+ continue
+
+ # Verify answer
+ is_correct = countdown_verify(predicted_answer, self.numbers, self.target)
+
+ if is_correct:
+ final_reward = 1.0
+ final_success = True
+ final_predicted_answer = predicted_answer
+ feedback = f"Correct! Your equation {predicted_answer} successfully equals {self.target} using the numbers {self.numbers}."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ break
+ else:
+ # Wrong answer
+ if attempt < self.max_attempts - 1:
+ feedback = f"Incorrect. Your equation {predicted_answer} does not work. Please try again."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ else:
+ # Last attempt
+ feedback = f"Incorrect. Your equation {predicted_answer} does not match the target {self.target}. Maximum attempts reached."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ final_predicted_answer = predicted_answer
+
+ # Record this failed attempt in history
+ attempt_history.append({
+ "equation": predicted_answer,
+ "feedback": feedback
+ })
+
+ return trajectory, final_reward, final_success, final_predicted_answer, str(self.target), attempt_count
+
+
+def second_rollout(
+ self,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, str, str, int]:
+ """
+ Performs rollout with guidance from reflection.
+ For countdown problems, we typically start from the beginning with guidance.
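+
+ Returns (distill_trajectory, trajectory, reward, success, predicted_answer,
+ ground_truth, attempt_count). distill_trajectory mirrors trajectory except that
+ its first user turn omits the guidance, so it can serve as guidance-free training data.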
+ """
+ trajectory = []
+ distill_trajectory = []
+ attempt_history = [] # Track attempt history for limited history display
+
+ final_reward = 0.0
+ final_success = False
+ final_predicted_answer = ""
+ attempt_count = 0
+
+ # Try up to 3 attempts
+ for attempt in range(self.max_attempts):
+ attempt_count = attempt + 1
+
+ # Format user prompt with history and guidance
+ if attempt == 0:
+ # First attempt includes guidance
+ user_prompt = format_countdown_prompt_with_guidance(
+ numbers=self.numbers,
+ target=self.target,
+ current_step=attempt,
+ attempt_history=attempt_history,
+ guidance_prompt=guidance_prompt,
+ history_length=getattr(self, 'history_length', 4)
+ )
+ # For distill trajectory, use prompt without guidance
+ distill_user_prompt = format_countdown_prompt(
+ numbers=self.numbers,
+ target=self.target,
+ current_step=attempt,
+ attempt_history=attempt_history,
+ history_length=getattr(self, 'history_length', 4)
+ )
+ else:
+ # Subsequent attempts don't repeat guidance
+ user_prompt = format_countdown_prompt(
+ numbers=self.numbers,
+ target=self.target,
+ current_step=attempt,
+ attempt_history=attempt_history,
+ history_length=getattr(self, 'history_length', 4)
+ )
+ distill_user_prompt = user_prompt
+
+ trajectory.append({"role": "user", "content": user_prompt})
+ distill_trajectory.append({"role": "user", "content": distill_user_prompt})
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 4096:
+ # The tokenizer inside chat truncates over-long inputs, so terminate as soon as the length reaches the limit (20480 max_model_len minus the 4096-token generation budget)
+ return distill_trajectory, trajectory, final_reward, final_success, final_predicted_answer, str(self.target), attempt_count
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse think and answer
+ think, predicted_answer = parse_response(response_text)
+
+ if think is None or predicted_answer is None:
+ # Invalid format
+ feedback = "Invalid response format. Please ensure you provide both ... and ... tags."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ # Record this failed attempt in history
+ attempt_history.append({
+ "equation": "Invalid format",
+ "feedback": feedback
+ })
+ continue
+
+ # Verify answer
+ is_correct = countdown_verify(predicted_answer, self.numbers, self.target)
+
+ if is_correct:
+ final_reward = 1.0
+ final_success = True
+ final_predicted_answer = predicted_answer
+ feedback = f"Correct! Your equation {predicted_answer} successfully equals {self.target}."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ break
+ else:
+ # Wrong answer
+ if attempt < self.max_attempts - 1:
+ feedback = f"Incorrect. Your equation {predicted_answer} does not work. Please try again."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ else:
+ # Last attempt
+ feedback = f"Incorrect. Your equation {predicted_answer} does not match the target {self.target}. Maximum attempts reached."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ final_predicted_answer = predicted_answer
+
+ # Record this failed attempt in history
+ attempt_history.append({
+ "equation": predicted_answer,
+ "feedback": feedback
+ })
+
+ return distill_trajectory, trajectory, final_reward, final_success, final_predicted_answer, str(self.target), attempt_count
+
+
+def eval_countdown(self) -> List[Experience]:
+ """Evaluate a single countdown problem"""
+ print("[R3L Countdown Eval] Starting evaluation...")
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = first_rollout(self)
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ print(f"[R3L Countdown Eval] Completed - Reward: {reward}, Success: {success}, Attempts: {attempts}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ success=success,
+ predicted_answer=predicted_answer,
+ ground_truth=ground_truth,
+ attempt_type="evaluation"
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval",
+ experience_data=eval_record,
+ data_dir=self.eval_dir
+ )
+ except Exception as e:
+ print(f"[R3L Countdown Eval] Evaluation failed - Error: {str(e)}")
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ }
+ )
+ exp.reward = 0.0
+ return [exp]
+
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def parse_response(response: str) -> Tuple[Optional[str], Optional[str]]:
+ """Parse think and answer from countdown response"""
+ try:
+ # Extract think section
+ think_pattern = r"<think>\s*(.*?)\s*</think>"
+ think_match = re.search(think_pattern, response, re.DOTALL)
+ think = think_match.group(1).strip() if think_match else None
+
+ # Extract answer from tags
+ answer_pattern = r"<answer>\s*(.*?)\s*</answer>"
+ answer_match = re.search(answer_pattern, response, re.DOTALL | re.IGNORECASE)
+ if answer_match:
+ answer = answer_match.group(1).strip()
+ else:
+ answer = None
+
+ return think, answer
+ except Exception as e:
+ print(f"Error parsing response: {e}")
+ return None, None
+
+
+def countdown_verify(predicted_answer: str, numbers: List[int], target: int) -> bool:
+ """
+ Verify if the predicted countdown equation is correct.
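+
+ Illustrative example (assuming the eval_utils helpers parse the expression):
+ countdown_verify("(44 + 19 + 35)", [44, 19, 35], 98) -> True, since each
+ number is used exactly once and the equation evaluates to 98.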
+ """
+ if not predicted_answer:
+ print("Predicted answer is empty.")
+ return False
+
+ # Extract equation from predicted answer
+ equation = predicted_answer
+
+ # Validate equation uses correct numbers
+ if not validate_equation(equation, numbers):
+ print("Equation validation failed: uses invalid numbers.")
+ return False
+
+ # Evaluate equation
+ try:
+ result = evaluate_equation(equation)
+ if result is None:
+ print("Equation evaluation returned None.")
+ return False
+
+ if abs(result - target) < 1e-5: # Account for floating point precision
+ print(f"Equation evaluation successful: matches target, {result}, {target}.")
+ return True
+ else:
+ print(f"Equation evaluation result {result} does not match target {target}.")
+ return False
+ except Exception as e:
+ print(f"Error evaluating equation: {e}")
+ return False
+
+
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+ Format trajectory for reflection analysis.
+ Includes all messages including feedback.
+ """
+ formatted_lines = []
+ step_counter = 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+ formatted_lines.append(f"**System Prompt:**\n{content}\n" + "=" * 50)
+ elif role == "user":
+ formatted_lines.append(f"\n**Step {step_counter} - User:**")
+ formatted_lines.append(f"{content}")
+ elif role == "assistant":
+ formatted_lines.append(f"\n**Step {step_counter} - Assistant Response:**")
+ formatted_lines.append(f"{content}")
+ step_counter += 1
+
+ return "\n".join(formatted_lines)
+
+
+def validate_reflect_report(report: Dict[str, Any], total_steps: int) -> Tuple[bool, bool]:
+ """
+ Validates the structure and content of the reflection report
+ based on the countdown reflection.j2 schema.
+
+ Args:
+ report: The reflection report to validate
+ total_steps: Maximum number of steps in trajectory for retry_step bounds checking
+
+ Returns:
+ tuple[bool, bool]: (is_valid, is_perfect)
+ - is_valid: Whether the report structure is valid
+ - is_perfect: Whether the report indicates the trajectory is perfect (only meaningful if is_valid is True)
+ """
+ if (
+ not isinstance(report, dict)
+ or "trajectory_summary" not in report
+ or "root_cause_analysis" not in report
+ or "trajectory_outcome" not in report
+ ):
+ print("[R3L Countdown Validation] Report is not a dict or missing keys.")
+ return False, False
+
+ outcome = report["trajectory_outcome"]
+
+ if outcome == "success":
+ # For success, we only need summary and no flaw analysis
+ print("[R3L Countdown Validation] success report validation successful.")
+ return True, True
+
+ elif outcome in ["success_but_inefficient", "failure"]:
+ # For non-optimal outcomes, validate required fields
+ improvement_suggestion = report.get("improvement_suggestion", None)
+ retry_from_step = report.get("retry_from_step", None)
+
+ if improvement_suggestion is None or retry_from_step is None:
+ print("[R3L Countdown Validation] Missing 'improvement_suggestion' or 'retry_from_step'.")
+ return False, False
+
+ # check retry from step
+ try:
+ retry_from_step = int(retry_from_step)
+ except (ValueError, TypeError):
+ print(f"[R3L Countdown Validation] 'retry_from_step' must be an integer. Got: {retry_from_step}")
+ return False, False
+ if retry_from_step < 0:
+ print(f"[R3L Countdown Validation] 'retry_from_step' must be a non-negative integer. Got: {retry_from_step}")
+ return False, False
+ # Check trajectory bounds if total_steps is provided
+ if total_steps is not None:
+ if retry_from_step >= total_steps:
+ print(
+ f"[R3L Countdown Validation] 'retry_from_step' ({retry_from_step}) exceeds trajectory bounds (0 to {total_steps - 1}).")
+ return False, False
+ print(f"[R3L Countdown Validation] {outcome} report validation successful.")
+ return True, False
+ else:
+ print(f"[R3L Countdown Validation] Invalid trajectory_outcome: {outcome}")
+ return False, False
+
+
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Converts a validated reflection report into a structured, actionable
+ guidance prompt for the agent's second attempt. This prompt is framed
+ as an internal directive to ensure the model's output is clean for SFT.
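+
+ The rendered text embeds the JSON report inside the self_correction.j2
+ scaffold ("Your previous attempt encountered issues. Below is a reflection
+ based on user and environment feedback: ...").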
+ """
+ # Convert the report dictionary to a formatted string
+ report_str = json.dumps(report, indent=2, ensure_ascii=False)
+
+ # Load and render template
+ jinja_env = _get_jinja_env()
+ template = jinja_env.get_template("self_correction.j2")
+
+ return template.render(report=report_str)
+
+
+def save_experience_data(
+ task_id: str,
+ experience_data: Dict,
+ data_dir: str
+) -> str:
+ """
+ Save experience data including trajectory, rewards, and attempts to a JSON file.
+
+ Args:
+ task_id: Unique identifier for the task
+ experience_data: Dictionary containing experience information
+ data_dir: Directory to save the data
+
+ Returns:
+ Path to the saved file
+ """
+ os.makedirs(data_dir, exist_ok=True)
+ filename = f"{task_id}.json"
+ filepath = os.path.join(data_dir, filename)
+
+ with open(filepath, 'w', encoding='utf-8') as f:
+ json.dump(experience_data, f, indent=2, ensure_ascii=False)
+
+ return filepath
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ success: bool,
+ predicted_answer: str,
+ ground_truth: str,
+ attempt_type: str,
+ additional_metrics: Optional[Dict] = None
+) -> Dict:
+ """
+ Create a structured experience record.
+
+ Args:
+ task_id: Task identifier
+ trajectory: Conversation trajectory
+ reward: Final reward
+ success: Whether the task was successful
+ predicted_answer: Model's predicted answer
+ ground_truth: Correct answer
+ attempt_type: Type of attempt (e.g., 'first', 'second', 'evaluation')
+ additional_metrics: Optional additional metrics
+
+ Returns:
+ Experience record dictionary
+ """
+ record = {
+ "task_id": task_id,
+ "timestamp": datetime.now().isoformat(),
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "reward": reward,
+ "success": success,
+ "predicted_answer": predicted_answer,
+ "ground_truth": ground_truth,
+ }
+
+ if additional_metrics:
+ record["additional_metrics"] = additional_metrics
+
+ return record
+
+
+def format_countdown_prompt(
+ numbers: List[int],
+ target: int,
+ current_step: int,
+ attempt_history: List[Dict[str, str]],
+ history_length: int = 4
+) -> str:
+ """
+ Format countdown prompt with limited history.
+
+ Args:
+ numbers: Available numbers for the countdown problem
+ target: Target number to achieve
+ current_step: Current attempt number
+ attempt_history: List of previous attempts with equations and feedback
+ history_length: Maximum number of previous attempts to show (default: 4)
+
+ Returns:
+ Formatted prompt string
+ """
+ if current_step == 0 or not attempt_history:
+ # First attempt - no history
+ prompt = f"""You are an expert at solving countdown number problems.
+Your current task is: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.
+
+Now it's your turn to solve this problem.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should provide your equation answer and present it within <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>."""
+ else:
+ # Show limited history
+ recent_attempts = attempt_history[-history_length:] if len(attempt_history) > history_length else attempt_history
+
+ # Format attempt history as a list
+ history_lines = []
+ for idx, attempt in enumerate(recent_attempts):
+ attempt_num = current_step - len(recent_attempts) + idx + 1
+ history_lines.append(f" Attempt {attempt_num}: {attempt['equation']} -> {attempt['feedback']}")
+
+ attempt_history_str = "\n".join(history_lines)
+
+ prompt = f"""You are an expert at solving countdown number problems.
+Your task is: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.
+Prior to this attempt, you have already made {current_step} attempt(s). Below are the most recent {len(recent_attempts)} attempts and their feedback:
+{attempt_history_str}
+You are now at attempt {current_step + 1}.
+
+Now it's your turn to solve this problem.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should provide your equation answer and present it within <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>."""
+
+ return prompt
+
+
+def format_countdown_prompt_with_guidance(
+ numbers: List[int],
+ target: int,
+ current_step: int,
+ attempt_history: List[Dict[str, str]],
+ guidance_prompt: str,
+ history_length: int = 4
+) -> str:
+ """
+ Format countdown prompt with limited history and guidance from reflection.
+
+ Args:
+ numbers: Available numbers for the countdown problem
+ target: Target number to achieve
+ current_step: Current attempt number
+ attempt_history: List of previous attempts with equations and feedback
+ guidance_prompt: Guidance from reflection analysis
+ history_length: Maximum number of previous attempts to show (default: 4)
+
+ Returns:
+ Formatted prompt string with guidance
+ """
+ base_prompt = format_countdown_prompt(numbers, target, current_step, attempt_history, history_length)
+
+ # Insert guidance before the final instruction
+ split_marker = "Now it's your turn"
+ base_prefix = base_prompt.split(split_marker)[0]
+
+ prompt_with_guidance = f"""{base_prefix}
+# Previous Attempt Analysis & Guidance
+{guidance_prompt}
+
+Now it's your turn to solve this problem.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should provide your equation answer and present it within <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>."""
+
+ return prompt_with_guidance
diff --git a/trinity/common/workflows/envs/R3L/dapo/R3L_workflow.py b/trinity/common/workflows/envs/R3L/dapo/R3L_workflow.py
new file mode 100644
index 0000000000..ed013b312c
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/dapo/R3L_workflow.py
@@ -0,0 +1,423 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.dapo import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("R3L_dapo_workflow")
+class R3LDapoWorkflow(Workflow):
+ """
+ R3L workflow for DAPO mathematical problem solving
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+
+ self.whether_save_data = False
+ # Create data directories
+ self.data_dir = f"R3L_dapo_data"
+ self.eval_dir = os.path.join(self.data_dir, "eval")
+ self.train_dir = os.path.join(self.data_dir, "train")
+
+ os.makedirs(self.eval_dir, exist_ok=True)
+ os.makedirs(self.train_dir, exist_ok=True)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.dapo_system_template = self.jinja_env.get_template("math_system.j2")
+ self.reflection_template = self.jinja_env.get_template("reflection.j2")
+
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ self.default_second_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "second_success": 0.0,
+ "second_reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ print(
+ f"Initializing R3LDapoWorkflow with experience learning, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract prompt and ground truth from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Format 1: prompt is a list (math_dapo format)
+ if "prompt" in raw_task and isinstance(raw_task["prompt"], list):
+ if len(raw_task["prompt"]) > 0 and isinstance(raw_task["prompt"][0], dict):
+ self.prompt = raw_task["prompt"][0].get("content", "")
+ else:
+ self.prompt = ""
+
+ reward_model_data = raw_task.get("reward_model", {})
+ if isinstance(reward_model_data, dict):
+ self.ground_truth = reward_model_data.get("ground_truth", "")
+ else:
+ self.ground_truth = ""
+
+ # Format 2: question/answer format (AIME format)
+ elif "question" in raw_task and "answer" in raw_task:
+ self.prompt = raw_task.get("question", "")
+ self.ground_truth = raw_task.get("answer", "")
+
+ # Fallback: simple prompt/answer
+ else:
+ self.prompt = raw_task.get("prompt", "")
+ self.ground_truth = raw_task.get("answer", "")
+ else:
+ self.prompt = ""
+ self.ground_truth = ""
+
+ def get_reflect(self, trajectory: List[Dict[str, str]]) -> tuple[
+ Optional[Dict[str, Any]], Optional[str], Optional[Any]]:
+ """
+ Generates a comprehensive reflection report using a single, unified self-interrogation prompt.
+ """
+ # Format trajectory for LLM reading
+ formatted_trajectory = utils.format_trajectory_for_reflection(trajectory)
+
+ # Use Jinja2 template to render reflection prompt
+ reflect_prompt = self.reflection_template.render()
+
+ # Call model and parse results
+ try:
+ responses = self.model.chat(
+ [
+ {"role": "system", "content": reflect_prompt},
+ {"role": "user", "content": "Here is last attempt trajectory log: \n\n" + formatted_trajectory + "\n\nPlease output in the specified JSON format."}
+ ],
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_reflect_tokens,
+ )
+ reflection_text = responses[0].response_text.strip()
+
+ # Find first '{' and last '}'
+ first_brace = reflection_text.find('{')
+ last_brace = reflection_text.rfind('}')
+
+ if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
+ json_str = reflection_text[first_brace:last_brace + 1]
+ else:
+ json_str = reflection_text
+
+ reflection_data = json.loads(json_str)
+ return reflection_data, reflection_text, responses[0]
+
+ except Exception as e:
+ print(f"[R3L] Reflection failed - Error: {str(e)}")
+ return None, None, None
+
+ def _adjust_action_mask_for_retry(self, experience: Experience, retry_step: int):
+ """
+ Adjust action_mask in-place to exclude retry prefix from training.
+ Only tokens from retry_step onwards should be trained.
+
+ Args:
+ experience: The experience object with action_mask to adjust
+ retry_step: The step from which training should start
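+
+ Illustrative example: with action_mask [1, 1, 0, 1, 1] (two assistant segments)
+ and retry_step=1, the first segment is zeroed, giving [0, 0, 0, 1, 1], so only
+ the second assistant response remains trainable.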
+ """
+ if retry_step <= 0:
+ return
+
+ # Note: experience.action_mask already excludes prompt tokens
+ action_mask = experience.action_mask
+
+ # Find all assistant response regions and mark the first 'retry_step' as non-trainable
+ if torch.any(action_mask == 1):
+ # Find all segments where action_mask == 1 (assistant responses)
+ assistant_segments = []
+ in_segment = False
+ segment_start = 0
+
+ for i, mask_val in enumerate(action_mask):
+ if mask_val == 1 and not in_segment:
+ # Start of a new segment
+ segment_start = i
+ in_segment = True
+ elif mask_val == 0 and in_segment:
+ # End of current segment
+ assistant_segments.append((segment_start, i))
+ in_segment = False
+
+ # Handle case where sequence ends with assistant response
+ if in_segment:
+ assistant_segments.append((segment_start, len(action_mask)))
+
+ # Set the first 'retry_step' assistant segments to 0 (non-trainable)
+ for i in range(min(retry_step, len(assistant_segments))):
+ start, end = assistant_segments[i]
+ action_mask[start:end] = 0
+
+ def run(self) -> List[Experience]:
+ """Run the R3L dapo workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_dapo(self)
+
+ # Generate unique task ID
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+
+ exp_lst = []
+ for i in range(self.n // 2): # Half for rollout, half for reflection + retry
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[R3L] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ # Set eid
+ exp.eid.task = str(self.task.task_id) + f"_explore"
+ exp_run_id = len(exp_lst) + self.run_id_base
+ exp.eid.run = exp_run_id
+ exp_lst.append(exp)
+
+ if self.whether_save_data:
+ # Save first attempt experience data
+ first_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=trajectory,
+ reward=reward,
+ success=success,
+ predicted_answer=predicted_answer,
+ ground_truth=ground_truth,
+ attempt_type="first"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_first",
+ experience_data=first_record,
+ data_dir=self.train_dir
+ )
+
+ # Reflect on first attempt
+ print(f"[R3L] Starting reflection on first attempt (reward: {reward})...")
+ reflect_checklist, reflection_text, reflect_exp = self.get_reflect(trajectory)
+ is_valid, is_perfect = utils.validate_reflect_report(reflect_checklist, attempts)
+
+ if reflect_checklist is None:
+ print(f"[R3L] Reflection failed - No valid reflection data generated")
+ elif is_valid and not is_perfect:
+ print(f"[R3L] Reflection successful - Valid reflection generated")
+ elif is_perfect:
+ print(f"[R3L] Reflection indicates perfect first attempt - No retry needed")
+ elif not is_valid:
+ print(f"[R3L] Reflection validation failed - Invalid reflection data")
+
+ if not is_valid or is_perfect:
+ print(f"[R3L] Skip second rollout due to invalid ({not is_valid}) or perfect ({is_perfect}) reflection.")
+ # If first attempt reward is 1.0 and reflection gives perfect, record reflection exp
+ if reward >= 1.0 and is_perfect and reflect_exp is not None:
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Do another rollout to ensure the batch has enough data
+ print(f"[R3L] Performing additional rollout...")
+ try:
+ retry_trajectory, retry_reward, retry_success, retry_predicted_answer, retry_ground_truth, retry_attempts = utils.first_rollout(self)
+ print(f"[R3L] Additional rollout completed - reward: {retry_reward}, attempts: {retry_attempts}")
+
+ retry_exp = self.model.convert_messages_to_experience(retry_trajectory[:-1])
+ retry_exp.reward = retry_reward
+ retry_exp.metrics = {
+ "success": 1.0 if retry_success else 0.0,
+ "reward": retry_reward,
+ "attempts": retry_attempts,
+ }
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + f"_explore"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ if self.whether_save_data:
+ # Save retry attempt experience data
+ retry_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=retry_trajectory,
+ reward=retry_reward,
+ success=retry_success,
+ predicted_answer=retry_predicted_answer,
+ ground_truth=retry_ground_truth,
+ attempt_type="retry_after_invalid_reflection"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_retry",
+ experience_data=retry_record,
+ data_dir=self.train_dir
+ )
+ except Exception as e:
+ print(f"[R3L] Retry rollout after invalid reflection failed - Error: {e}")
+
+ else:
+ print("[R3L] Valid reflection obtained, proceeding to second rollout...")
+ guidance_prompt = utils.reflect_report_to_guidance_prompt(reflect_checklist)
+ # Extract retry_step from the validated reflection report (top-level field in the reflection schema)
+ retry_step = reflect_checklist.get("retry_from_step", 0)
+
+ try:
+ (
+ distill_trajectory,
+ second_trajectory,
+ second_reward,
+ second_success,
+ second_predicted_answer,
+ second_ground_truth,
+ second_attempts,
+ ) = utils.second_rollout(
+ self, guidance_prompt, trajectory, retry_step
+ )
+ print(f"[R3L] Second rollout - reward: {second_reward}, attempts: {second_attempts}, improve: {second_reward > reward}")
+ second_exp = self.model.convert_messages_to_experience(distill_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(second_exp, retry_step)
+ # Also adjust first rollout exp for fair comparison
+ for existing_exp in exp_lst:
+ if existing_exp.eid.run == exp_run_id:
+ self._adjust_action_mask_for_retry(existing_exp, retry_step)
+ break
+
+ second_exp.reward = second_reward
+ second_exp.metrics = {
+ "second_success": 1.0 if second_success else 0.0,
+ "second_reward": second_reward,
+ "second_attempts": second_attempts,
+ "second_improve": 1.0 if second_reward > reward else 0.0,
+ "second_reward_diff": second_reward - reward,
+ }
+ # Set eid
+ second_exp.eid.task = str(self.task.task_id) + f"_explore"
+ second_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(second_exp)
+
+ if self.whether_save_data:
+ # Save second attempt experience data
+ second_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=second_trajectory,
+ reward=second_reward,
+ success=second_success,
+ predicted_answer=second_predicted_answer,
+ ground_truth=second_ground_truth,
+ attempt_type="second",
+ additional_metrics={
+ "first_reward": reward,
+ "improvement": second_reward > reward,
+ "reward_difference": second_reward - reward,
+ }
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_second",
+ experience_data=second_record,
+ data_dir=self.train_dir
+ )
+
+ # If second attempt score is higher than first, record reflection and retry data
+ if second_reward > reward and second_reward >= 1.0:
+ print(f"[R3L] Second attempt successful improvement - Recording reflection and retry experiences")
+ print(f"[R3L] Reward improvement: {reward} -> {second_reward} (+{second_reward - reward:.2f})")
+ reflect_exp.reward = 1.0
+ # Set eid
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Convert retry data to exp
+ retry_exp = self.model.convert_messages_to_experience(second_trajectory[:-1])
+
+ retry_exp.reward = 1.0
+ # Set eid
+ retry_exp.eid.task = str(self.task.task_id) + f"_retry_{i}"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ print("[R3L] Reflection and retry led to improvement, recording both...")
+ elif second_reward <= reward:
+ print(f"[R3L] Second attempt did not improve - First reward: {reward}, Second reward: {second_reward}")
+ else:
+ print(f"[R3L] Second attempt improved but below threshold - Reward: {second_reward} (need >= 1.0)")
+ except Exception as e:
+ print(f"[R3L] Second rollout failed - Error: {str(e)}")
+ except Exception as e:
+ print(f"[R3L] Rollout iteration {i} failed - Error: {str(e)}")
+
+ # Print summary statistics
+ print(f"\n[R3L Summary] Generated {len(exp_lst)} experiences")
+ total_reward = sum(exp.reward for exp in exp_lst)
+ avg_reward = total_reward / len(exp_lst) if exp_lst else 0.0
+ print(f"[R3L Summary] Total reward: {total_reward:.2f}, Average reward: {avg_reward:.2f}")
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/dapo/__init__.py b/trinity/common/workflows/envs/R3L/dapo/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/common/workflows/envs/R3L/dapo/dapo_workflow.py b/trinity/common/workflows/envs/R3L/dapo/dapo_workflow.py
new file mode 100644
index 0000000000..be57ee1f80
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/dapo/dapo_workflow.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.dapo import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("dapo_dapo_workflow")
+class DAPODapoWorkflow(Workflow):
+ """
+ DAPO Workflow for DAPO environment.
+ Performs rollouts with DAPO-style overlong penalty on response length.
+ No separate reward function needed - penalty computed directly in workflow.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # DAPO overlong penalty parameters
+ workflow_args = task.workflow_args or {}
+ self.enable_overlong_penalty = workflow_args.get("enable_overlong_penalty", True)
+ self.penalty_factor = workflow_args.get("penalty_factor", 1.0)
+ self.max_response_length = workflow_args.get("max_response_length", 4096)
+ self.cache_length = workflow_args.get("cache_length", 100)
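+        # Illustrative workflow_args block in the task config (keys match the lookups
+        # above; the values shown are just the defaults, not prescribed settings):
+        #   workflow_args:
+        #     enable_overlong_penalty: true
+        #     penalty_factor: 1.0
+        #     max_response_length: 4096
+        #     cache_length: 100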
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.dapo_system_template = self.jinja_env.get_template("math_system.j2")
+
+ print(
+ f"Initializing DAPODapoWorkflow, temperature={self.temperature}, "
+ f"overlong_penalty={'enabled' if self.enable_overlong_penalty else 'disabled'}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract prompt and ground truth from raw_task
+ raw_task = task.raw_task or {}
+
+ # Handle different formats of raw_task
+ if "prompt" in raw_task:
+ self.prompt = raw_task["prompt"]
+ self.ground_truth = raw_task.get("ground_truth", "")
+ elif "question" in raw_task:
+ # Alternative format
+ self.prompt = raw_task["question"]
+ self.ground_truth = raw_task.get("answer", "")
+ elif "problem" in raw_task:
+ # Another alternative format
+ self.prompt = raw_task["problem"]
+ self.ground_truth = raw_task.get("solution", raw_task.get("answer", ""))
+ else:
+ self.prompt = ""
+ self.ground_truth = ""
+
+ # Update DAPO parameters if provided
+ workflow_args = task.workflow_args or {}
+ if "enable_overlong_penalty" in workflow_args:
+ self.enable_overlong_penalty = workflow_args["enable_overlong_penalty"]
+ if "penalty_factor" in workflow_args:
+ self.penalty_factor = workflow_args["penalty_factor"]
+ if "max_response_length" in workflow_args:
+ self.max_response_length = workflow_args["max_response_length"]
+ if "cache_length" in workflow_args:
+ self.cache_length = workflow_args["cache_length"]
+
+ def compute_overlong_penalty(self, response_tokens: torch.Tensor) -> float:
+ """
+ Compute DAPO-style overlong penalty based on response token length.
+
+ Args:
+ response_tokens: Response tokens (tensor)
+
+ Returns:
+ Penalty score (non-positive float)
+ """
+ if not self.enable_overlong_penalty:
+ return 0.0
+
+ response_len = len(response_tokens)
+ expected_len = self.max_response_length - self.cache_length
+
+ if response_len < expected_len:
+ # No penalty for short responses
+ return 0.0
+ elif response_len > self.max_response_length:
+ # Fixed penalty for excessively long responses
+ return -self.penalty_factor
+ else:
+ # Linear penalty in the transition zone
+ return (expected_len - response_len) / self.cache_length * self.penalty_factor
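+        # Worked example (illustrative, using the defaults above: max_response_length=4096,
+        # cache_length=100, penalty_factor=1.0): expected_len = 3996, so a 3,900-token
+        # response gets 0.0, a 4,050-token response gets (3996 - 4050) / 100 = -0.54,
+        # and a 4,200-token response gets the fixed penalty -1.0.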
+
+ def run(self) -> List[Experience]:
+ """Run the DAPO workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_dapo(self)
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[DAPO] Rollout - reward: {reward}, attempts: {attempts}")
+
+ # Convert trajectory to experience
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+
+ # Extract response tokens from experience
+ response_tokens = exp.tokens[exp.prompt_length:]
+
+ # Compute DAPO overlong penalty (format score)
+ format_score = self.compute_overlong_penalty(response_tokens)
+
+ # Calculate accuracy score
+ accuracy_score = 1.0 if reward >= 1.0 else 0.0
+
+ # Total reward = accuracy + format_score
+ total_reward = accuracy_score + format_score
+
+ # Update experience reward and metrics
+ exp.reward = total_reward
+ exp.metrics = {
+ "success": accuracy_score,
+ "attempts": attempts,
+ "accuracy": accuracy_score,
+ "format_score": format_score,
+ "response_length": len(response_tokens),
+ "total_reward": total_reward,
+ }
+
+ # Set experience ID
+ exp.eid.task = str(self.task.task_id)
+ exp.eid.run = i + self.run_id_base
+
+ exp_lst.append(exp)
+ except Exception as e:
+ print(f"[DAPO] Rollout failed: {e}")
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/dapo/grpo_workflow.py b/trinity/common/workflows/envs/R3L/dapo/grpo_workflow.py
new file mode 100644
index 0000000000..bd97ead5d7
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/dapo/grpo_workflow.py
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.dapo import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("grpo_baseline_dapo_workflow")
+class GRPOBaselineDapoWorkflow(Workflow):
+ """
+ GRPO Baseline Workflow for DAPO mathematical problem solving.
+ Performs simple rollouts without reflection or learning from experience.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.dapo_system_template = self.jinja_env.get_template("math_system.j2")
+
+ print(
+ f"Initializing GRPOBaselineDapoWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract prompt and ground truth from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
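+            # Illustrative raw_task shapes handled by the branches below (field values are
+            # made-up examples, not real dataset records):
+            #   math_dapo: {"prompt": [{"role": "user", "content": "Compute ..."}],
+            #               "reward_model": {"ground_truth": "49"}}
+            #   AIME-style: {"question": "Compute ...", "answer": "49"}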
+
+ # Format 1: prompt is a list (math_dapo format)
+ if "prompt" in raw_task and isinstance(raw_task["prompt"], list):
+ if len(raw_task["prompt"]) > 0 and isinstance(raw_task["prompt"][0], dict):
+ self.prompt = raw_task["prompt"][0].get("content", "")
+ else:
+ self.prompt = ""
+
+ reward_model_data = raw_task.get("reward_model", {})
+ if isinstance(reward_model_data, dict):
+ self.ground_truth = reward_model_data.get("ground_truth", "")
+ else:
+ self.ground_truth = ""
+
+ # Format 2: question/answer format (AIME format)
+ elif "question" in raw_task and "answer" in raw_task:
+ self.prompt = raw_task.get("question", "")
+ self.ground_truth = raw_task.get("answer", "")
+
+ # Fallback: simple prompt/answer
+ else:
+ self.prompt = raw_task.get("prompt", "")
+ self.ground_truth = raw_task.get("answer", "")
+ else:
+ self.prompt = ""
+ self.ground_truth = ""
+
+ def run(self) -> List[Experience]:
+ """Run the GRPO baseline workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_dapo(self)
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[GRPO] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/dapo/opmd_workflow.py b/trinity/common/workflows/envs/R3L/dapo/opmd_workflow.py
new file mode 100644
index 0000000000..e257764fdd
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/dapo/opmd_workflow.py
@@ -0,0 +1,124 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.dapo import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("opmd_baseline_dapo_workflow")
+class OPMDBaselineDapoWorkflow(Workflow):
+ """
+ OPMD Baseline workflow for DAPO mathematical problem solving.
+ Performs rollouts for offline policy model distillation.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.dapo_system_template = self.jinja_env.get_template("math_system.j2")
+
+ print(
+            f"Initializing OPMDBaselineDapoWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract prompt and ground truth from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Format 1: prompt is a list (math_dapo format)
+ if "prompt" in raw_task and isinstance(raw_task["prompt"], list):
+ if len(raw_task["prompt"]) > 0 and isinstance(raw_task["prompt"][0], dict):
+ self.prompt = raw_task["prompt"][0].get("content", "")
+ else:
+ self.prompt = ""
+
+ reward_model_data = raw_task.get("reward_model", {})
+ if isinstance(reward_model_data, dict):
+ self.ground_truth = reward_model_data.get("ground_truth", "")
+ else:
+ self.ground_truth = ""
+
+ # Format 2: question/answer format (AIME format)
+ elif "question" in raw_task and "answer" in raw_task:
+ self.prompt = raw_task.get("question", "")
+ self.ground_truth = raw_task.get("answer", "")
+
+ # Fallback: simple prompt/answer
+ else:
+ self.prompt = raw_task.get("prompt", "")
+ self.ground_truth = raw_task.get("answer", "")
+ else:
+ self.prompt = ""
+ self.ground_truth = ""
+
+ def run(self) -> List[Experience]:
+ """Run the OPMD workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_dapo(self)
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[OPMD] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/dapo/prompts/math_system.j2 b/trinity/common/workflows/envs/R3L/dapo/prompts/math_system.j2
new file mode 100644
index 0000000000..cbe87e085f
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/dapo/prompts/math_system.j2
@@ -0,0 +1,46 @@
+You are a mathematical problem solver. Your task is to solve mathematical problems step by step.
+
+## Response Format:
+You MUST use this exact format for every response. All tags are REQUIRED in sequential order:
+
+<think>your step-by-step reasoning and solution process</think>
+<answer>your final answer</answer>
+
+## Instructions:
+1. Carefully read and understand the problem
+2. Show your reasoning step by step in the <think> tags
+3. Provide your final answer in the <answer> tags
+4. For numerical answers, provide the exact value
+5. If the problem asks for a specific format (e.g., \\boxed{}), use that format in your answer
+
+## Example:
+Problem: "What is the sum of all positive integers less than 100 that are divisible by 3?"
+
+<think>
+I need to find all positive integers less than 100 that are divisible by 3, then sum them.
+
+The integers divisible by 3 less than 100 are: 3, 6, 9, ..., 99
+This is an arithmetic sequence with:
+- First term a₁ = 3
+- Common difference d = 3
+- Last term aₙ = 99
+
+To find how many terms: aₙ = a₁ + (n-1)d
+99 = 3 + (n-1)×3
+96 = (n-1)×3
+n-1 = 32
+n = 33
+
+Sum of arithmetic sequence: S = n(a₁ + aₙ)/2
+S = 33(3 + 99)/2
+S = 33 × 102/2
+S = 33 × 51
+S = 1683
+</think>
+<answer>\boxed{1683}</answer>
+
+## Notes:
+- Be thorough in your reasoning
+- Show all important steps
+- Double-check your calculations
+- Provide the final answer clearly in the <answer> tags
diff --git a/trinity/common/workflows/envs/R3L/dapo/prompts/reflection.j2 b/trinity/common/workflows/envs/R3L/dapo/prompts/reflection.j2
new file mode 100644
index 0000000000..8d4d25e035
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/dapo/prompts/reflection.j2
@@ -0,0 +1,28 @@
+You are a Reflector that analyzes trajectory logs based on user and environment feedback. Your goal is to identify what went wrong, trace root causes, and extract reusable principles for future improvement. Review the trajectory and feedback to understand the strategy and outcome. Through Socratic-style iterative "why" questioning, trace issues back to their fundamental flawed assumptions or mental models. Then formulate an actionable principle and suggest where to retry if needed.
+
+Please output in the following JSON format:
+
+```json
+{
+"trajectory_summary": "Concise overview in 1-3 sentences covering: (1) the strategy or approach employed by the agent, (2) the final result or outcome achieved, (3) key observations about execution quality (e.g., efficiency, correctness, optimality).",
+"root_cause_analysis": "Deep causal analysis using iterative 'why' questioning to trace from observable symptoms back to the fundamental root cause (flawed assumption, incorrect mental model, or critical knowledge gap). Chain your reasoning explicitly (e.g., 'Why X? Because Y. Why Y? Because Z.'). Identify the deepest underlying issue, not just surface-level errors. Set to null only if execution was truly flawless.",
+"trajectory_outcome": "Classification of the trajectory result. Must be EXACTLY one of these three values (case-sensitive, including underscores): 'success' (goal fully achieved with optimal execution quality), 'success_but_inefficient' (goal achieved but with unnecessary steps, redundant actions, or suboptimal approach), 'failure' (goal not achieved or task incomplete).",
+"improvement_suggestion": "A generalizable, context-complete principle for avoiding similar issues in future attempts. Must be self-contained and actionable without reference to this specific trajectory. Include: (1) the specific environment/system/domain name (mathematical problem solving), (2) the triggering conditions or scenario when this applies, (3) the specific problem or pitfall to avoid, (4) the recommended solution or approach with clear rationale. Frame as reusable knowledge. Set to null if and only if trajectory_outcome is 'success'.",
+"retry_from_step": "Integer from 0 to N-1 identifying the earliest step where the root cause first manifested or where a corrected decision could alter the outcome. This represents the optimal restart point if given one opportunity to retry. Use 0 when the root cause traces to initial strategy selection or foundational assumptions. For mathematical problems, this typically indicates the step where the mathematical error occurred. Set to null if trajectory_outcome is 'success' or if retry would not be beneficial."
+}
+```
+
+## Example
+
+**Scenario**: Solving a quadratic equation 2x² + 5x - 3 = 0
+
+**Example Output**:
+```json
+{
+"trajectory_summary": "The agent attempted to solve a quadratic equation by applying the quadratic formula with a=2, b=5, c=-3. However, during the calculation of the discriminant, the agent computed b² - 4ac = 25 - 4(2)(-3) = 25 - 24 = 1 incorrectly (should be 25 + 24 = 49). This arithmetic error propagated through, yielding x = (-5 ± 1)/4, giving incorrect solutions x=-1 and x=-1.5 instead of the correct x=0.5 and x=-3.",
+"root_cause_analysis": "Why was the solution incorrect? Because the discriminant was computed as 1 instead of 49. Why was it computed as 1? Because the agent evaluated 4(2)(-3) as +24 instead of -24, then subtracted it: 25-24=1. Why this sign error? Because the agent failed to recognize that multiplying 4×2×(-3) = -24, and subtracting a negative is equivalent to addition: 25-(-24) = 25+24. Root cause: Insufficient attention to sign rules when handling negative numbers in multi-step arithmetic - the agent treated the operation mechanically without tracking the semantic meaning of 'subtracting 4ac when c is negative'.",
+"trajectory_outcome": "failure",
+"improvement_suggestion": "In mathematical problem solving, especially when applying formulas involving negative numbers, explicitly track and verify signs at each arithmetic step. For the quadratic formula discriminant b²-4ac: (1) when c is negative, 4ac becomes negative, (2) subtracting a negative number equals adding its absolute value, (3) always double-check: b²-4ac with c<0 becomes b²+4|a||c|. Before proceeding with the formula, verify the discriminant value independently. For 2x²+5x-3=0: b²-4ac = 25-4(2)(-3) = 25-(-24) = 25+24 = 49, not 1. This verification step catches sign errors early and prevents error propagation through subsequent calculations.",
+"retry_from_step": 0
+}
+```
diff --git a/trinity/common/workflows/envs/R3L/dapo/prompts/self_correction.j2 b/trinity/common/workflows/envs/R3L/dapo/prompts/self_correction.j2
new file mode 100644
index 0000000000..3c0a34b676
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/dapo/prompts/self_correction.j2
@@ -0,0 +1,5 @@
+Your previous attempt encountered issues. Below is a reflection based on user and environment feedback:
+
+{{ report }}
+
+Apply the lessons learned from this reflection to avoid repeating the same mistakes. Do not mention or reference this guidance in your response.
diff --git a/trinity/common/workflows/envs/R3L/dapo/raft_workflow.py b/trinity/common/workflows/envs/R3L/dapo/raft_workflow.py
new file mode 100644
index 0000000000..2fd65d578b
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/dapo/raft_workflow.py
@@ -0,0 +1,154 @@
+# -*- coding: utf-8 -*-
+import copy
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.dapo import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("RAFT_baseline_dapo_workflow")
+class RAFTBaselineDapoWorkflow(Workflow):
+ """
+ RAFT Baseline workflow for DAPO mathematical problem solving.
+    Performs rollouts for RAFT (Reward-rAnked Fine-Tuning), keeping only successful samples for training.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_attempts = 3
+ self.max_tokens = 4096
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.dapo_system_template = self.jinja_env.get_template("math_system.j2")
+
+ print(
+            f"Initializing RAFTBaselineDapoWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ # Default experience for error cases
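+        # A minimal two-token placeholder (one prompt token plus one fully masked response
+        # token) so that failed or filtered rollouts still return a well-formed Experience.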
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Extract prompt and ground truth from task
+ if hasattr(task, 'raw_task') and task.raw_task:
+ raw_task = task.raw_task
+
+ # Format 1: prompt is a list (math_dapo format)
+ if "prompt" in raw_task and isinstance(raw_task["prompt"], list):
+ if len(raw_task["prompt"]) > 0 and isinstance(raw_task["prompt"][0], dict):
+ self.prompt = raw_task["prompt"][0].get("content", "")
+ else:
+ self.prompt = ""
+
+ reward_model_data = raw_task.get("reward_model", {})
+ if isinstance(reward_model_data, dict):
+ self.ground_truth = reward_model_data.get("ground_truth", "")
+ else:
+ self.ground_truth = ""
+
+ # Format 2: question/answer format (AIME format)
+ elif "question" in raw_task and "answer" in raw_task:
+ self.prompt = raw_task.get("question", "")
+ self.ground_truth = raw_task.get("answer", "")
+
+ # Fallback: simple prompt/answer
+ else:
+ self.prompt = raw_task.get("prompt", "")
+ self.ground_truth = raw_task.get("answer", "")
+ else:
+ self.prompt = ""
+ self.ground_truth = ""
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT workflow and return experiences"""
+
+ if self.is_eval:
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ return [exp]
+ except Exception:
+ return [copy.deepcopy(self.default_exp)]
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = utils.first_rollout(self)
+ print(f"[RAFT] First rollout - reward: {reward}, attempts: {attempts}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ # RAFT only uses successful samples
+ if reward >= 1.0:
+ exp_lst.append(exp)
+ else:
+ exp_lst.append(copy.deepcopy(self.default_exp))
+ except Exception:
+ exp_lst.append(copy.deepcopy(self.default_exp))
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/dapo/utils.py b/trinity/common/workflows/envs/R3L/dapo/utils.py
new file mode 100644
index 0000000000..00ca16ad1d
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/dapo/utils.py
@@ -0,0 +1,472 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+try:
+    from math_verify import parse, verify
+except ImportError:  # math_verify is optional; fall back to plain string comparison in math_verify() below
+    parse = None
+    verify = None
+
+from trinity.common.experience import Experience
+
+
+def first_rollout(self) -> tuple[List[Dict[str, str]], float, bool, str, str, int]:
+ """Run math problem solving with multiple attempts (max 3 attempts)"""
+ trajectory = []
+
+ # Add system prompt
+ system_prompt = self.dapo_system_template.render()
+ trajectory.append({"role": "system", "content": system_prompt})
+
+ # Add user prompt (math problem) with format reminder
+ problem_prompt = self.prompt if self.prompt else "Please solve the given mathematical problem."
+ formatted_prompt = format_dapo_prompt(problem_prompt, attempt=0)
+ trajectory.append({"role": "user", "content": formatted_prompt})
+
+ final_reward = 0.0
+ final_success = False
+ final_predicted_answer = ""
+ attempt_count = 0
+
+ # Try up to 3 attempts
+ for attempt in range(self.max_attempts):
+ attempt_count = attempt + 1
+
+ # Get model response
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 4096:
+            # The tokenizer inside chat truncates the input, so terminate as soon as the length reaches the limit
+ return trajectory, final_reward, final_success, final_predicted_answer, self.ground_truth, attempt_count
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse think and answer
+ think, predicted_answer = parse_response(response_text)
+
+ if think is None or predicted_answer is None:
+ # Invalid format
+            feedback = "Invalid response format. Please ensure you provide both <think>...</think> and <answer>...</answer> tags."
+ formatted_feedback = format_dapo_prompt("", attempt=attempt_count, feedback=feedback)
+ trajectory.append({"role": "user", "content": formatted_feedback})
+ continue
+
+ # Verify answer
+ is_correct = math_verify(predicted_answer, self.ground_truth)
+
+ if is_correct:
+ final_reward = 1.0
+ final_success = True
+ final_predicted_answer = predicted_answer
+ print(f"[R3L First Rollout] Attempt {attempt_count} - Correct answer! Reward: {final_reward}")
+ feedback = f"Correct! Your answer {predicted_answer} matches the expected answer."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ break
+ else:
+ # Wrong answer
+ print(f"[R3L First Rollout] Attempt {attempt_count} - Incorrect answer: {predicted_answer} (Expected: {self.ground_truth})")
+ if attempt < self.max_attempts - 1:
+ feedback = f"Incorrect. Your answer {predicted_answer} does not match. Please try again."
+ formatted_feedback = format_dapo_prompt("", attempt=attempt_count, feedback=feedback)
+ trajectory.append({"role": "user", "content": formatted_feedback})
+ else:
+ # Last attempt
+ feedback = f"Incorrect. Your answer {predicted_answer} does not match the expected answer. Maximum attempts reached."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ final_predicted_answer = predicted_answer
+
+ return trajectory, final_reward, final_success, final_predicted_answer, self.ground_truth, attempt_count
+
+
+def second_rollout(
+ self,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, str, str, int]:
+ """
+ Performs rollout with guidance from reflection.
+ For math problems, we typically start from the beginning with guidance.
+ """
+ trajectory = []
+ distill_trajectory = []
+
+ # Prepare system prompts
+ original_system_prompt = self.dapo_system_template.render()
+
+ # Starting from beginning with guidance
+ merged_system_prompt = f"{original_system_prompt}\n\n# Previous Attempt Analysis & Guidance\n{guidance_prompt}"
+ trajectory.append({"role": "system", "content": merged_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ # Add user prompt (math problem) with format reminder
+ problem_prompt = self.prompt if self.prompt else "Please solve the given mathematical problem."
+ formatted_prompt = format_dapo_prompt(problem_prompt, attempt=0)
+ trajectory.append({"role": "user", "content": formatted_prompt})
+ distill_trajectory.append({"role": "user", "content": formatted_prompt})
+
+ final_reward = 0.0
+ final_success = False
+ final_predicted_answer = ""
+ attempt_count = 0
+
+ # Try up to 3 attempts
+ for attempt in range(self.max_attempts):
+ attempt_count = attempt + 1
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 4096:
+            # The tokenizer inside chat truncates the input, so terminate as soon as the length reaches the limit
+ return distill_trajectory, trajectory, final_reward, final_success, final_predicted_answer, self.ground_truth, attempt_count
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse think and answer
+ think, predicted_answer = parse_response(response_text)
+
+ if think is None or predicted_answer is None:
+ # Invalid format
+            feedback = "Invalid response format. Please ensure you provide both <think>...</think> and <answer>...</answer> tags."
+ formatted_feedback = format_dapo_prompt("", attempt=attempt_count, feedback=feedback)
+ trajectory.append({"role": "user", "content": formatted_feedback})
+ distill_trajectory.append({"role": "user", "content": formatted_feedback})
+ continue
+
+ # Verify answer
+ is_correct = math_verify(predicted_answer, self.ground_truth)
+
+ if is_correct:
+ final_reward = 1.0
+ final_success = True
+ final_predicted_answer = predicted_answer
+ print(f"[R3L Second Rollout] Attempt {attempt_count} - Correct answer! Reward: {final_reward}")
+ feedback = f"Correct! Your answer {predicted_answer} matches the expected answer."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ break
+ else:
+ # Wrong answer
+ print(f"[R3L Second Rollout] Attempt {attempt_count} - Incorrect answer: {predicted_answer} (Expected: {self.ground_truth})")
+ if attempt < self.max_attempts - 1:
+ feedback = f"Incorrect. Your answer {predicted_answer} does not match. Please try again."
+ formatted_feedback = format_dapo_prompt("", attempt=attempt_count, feedback=feedback)
+ trajectory.append({"role": "user", "content": formatted_feedback})
+ distill_trajectory.append({"role": "user", "content": formatted_feedback})
+ else:
+ # Last attempt
+ feedback = f"Incorrect. Your answer {predicted_answer} does not match the expected answer. Maximum attempts reached."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ final_predicted_answer = predicted_answer
+
+ return distill_trajectory, trajectory, final_reward, final_success, final_predicted_answer, self.ground_truth, attempt_count
+
+
+def eval_dapo(self) -> List[Experience]:
+ """Evaluate a single math problem"""
+ print("[R3L Eval] Starting evaluation...")
+ try:
+ trajectory, reward, success, predicted_answer, ground_truth, attempts = first_rollout(self)
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if success else 0.0,
+ "reward": reward,
+ "attempts": attempts,
+ }
+ print(f"[R3L Eval] Completed - Reward: {reward}, Success: {success}, Attempts: {attempts}")
+ print(f"[R3L Eval] Predicted: {predicted_answer}, Ground Truth: {ground_truth}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ success=success,
+ predicted_answer=predicted_answer,
+ ground_truth=ground_truth,
+ attempt_type="evaluation"
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval",
+ experience_data=eval_record,
+ data_dir=self.eval_dir
+ )
+ except Exception as e:
+ print(f"[R3L Eval] Evaluation failed - Error: {str(e)}")
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ }
+ )
+ exp.reward = 0.0
+ return [exp]
+
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def format_dapo_prompt(prompt: str, attempt: int = 0, feedback: Optional[str] = None) -> str:
+ """
+ Format DAPO prompt with format reminder for each user turn.
+
+ Args:
+ prompt: The math problem prompt
+ attempt: Current attempt number (0-based)
+ feedback: Optional feedback from previous attempt
+
+ Returns:
+ Formatted prompt string with format reminder
+ """
+ if attempt == 0 or feedback is None:
+ # First attempt - just the problem with format reminder
+ return f"""{prompt}
+
+Now it's your turn to solve this problem.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should provide your final answer and present it within <answer> </answer> tags."""
+ else:
+ # Subsequent attempt - include feedback and format reminder
+ return f"""Feedback: {feedback}
+
+Now it's your turn to try again.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
+Once you've finished your reasoning, you should provide your final answer and present it within <answer> </answer> tags."""
+
+
+def parse_response(response: str) -> Tuple[Optional[str], Optional[str]]:
+ """Parse think and answer from math response"""
+ try:
+ # Extract think section
+        think_pattern = r"<think>\s*(.*?)\s*</think>"
+ think_match = re.search(think_pattern, response, re.DOTALL)
+ think = think_match.group(1).strip() if think_match else None
+
+        # Extract answer from <answer> tags
+        answer_pattern = r"<answer>\s*(.*?)\s*</answer>"
+ answer_match = re.search(answer_pattern, response, re.DOTALL | re.IGNORECASE)
+ if answer_match:
+ answer = answer_match.group(1).strip()
+ else:
+ # Fallback: look for "Answer:" pattern
+ answer_line_pattern = r"Answer:\s*(.+?)(?:\n|$)"
+ answer_line_match = re.search(answer_line_pattern, response, re.IGNORECASE)
+ answer = answer_line_match.group(1).strip() if answer_line_match else None
+
+ return think, answer
+ except Exception as e:
+ print(f"Error parsing response: {e}")
+ return None, None
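+# Example (illustrative, assuming the tag-based format above):
+#   parse_response("<think>2+2=4</think><answer>4</answer>") -> ("2+2=4", "4")
+#   parse_response("Answer: 4") falls back to the "Answer:" pattern -> (None, "4")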
+
+
+def math_verify(predicted_answer: str, ground_truth: str) -> bool:
+ """
+ Verify if the predicted math answer matches the ground truth using math_verify library.
+ """
+ if not predicted_answer or not ground_truth:
+ return False
+
+ if parse is None or verify is None:
+ # Fallback: simple string comparison
+ pred_clean = str(predicted_answer).strip().lower()
+ gt_clean = str(ground_truth).strip().lower()
+ return pred_clean == gt_clean
+
+ try:
+ # Parse and verify
+ gold = parse(ground_truth)
+ answer = parse(predicted_answer)
+ return verify(gold, answer)
+ except Exception:
+ # Fallback comparison
+ return str(predicted_answer).strip().lower() == str(ground_truth).strip().lower()
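+# Example (illustrative): math_verify("\\boxed{49}", "49") defers to math_verify's
+# parse/verify, which can treat a boxed expression and the bare value as equivalent;
+# the plain-string fallback would compare the raw strings and return False for this pair.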
+
+
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+ Format trajectory for reflection analysis.
+ Includes all messages including feedback.
+ """
+ formatted_lines = []
+ step_counter = 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+ formatted_lines.append(f"**System Prompt:**\n{content}\n" + "=" * 50)
+ elif role == "user":
+ formatted_lines.append(f"\n**Step {step_counter} - User:**")
+ formatted_lines.append(f"{content}")
+ elif role == "assistant":
+ formatted_lines.append(f"\n**Step {step_counter} - Assistant Response:**")
+ formatted_lines.append(f"{content}")
+ step_counter += 1
+
+ return "\n".join(formatted_lines)
+
+
+def validate_reflect_report(report: Dict[str, Any], total_steps: int) -> Tuple[bool, bool]:
+ """
+ Validates the structure and content of the reflection report
+ based on the alfworld reflection.j2 schema.
+
+ Args:
+ report: The reflection report to validate
+ total_steps: Maximum number of steps in trajectory for retry_step bounds checking
+
+ Returns:
+ tuple[bool, bool]: (is_valid, is_perfect)
+ - is_valid: Whether the report structure is valid
+ - is_perfect: Whether the report indicates the trajectory is perfect (only meaningful if is_valid is True)
+ """
+ if (
+ not isinstance(report, dict)
+ or "trajectory_summary" not in report
+ or "root_cause_analysis" not in report
+ or "trajectory_outcome" not in report
+ ):
+ print("[R3L DAPO Validation] Report is not a dict or missing keys.")
+ return False, False
+
+ outcome = report["trajectory_outcome"]
+
+ if outcome == "success":
+ # For success, we only need summary and no flaw analysis
+ print("[R3L DAPO Validation] success report validation successful.")
+ return True, True
+
+ elif outcome in ["success_but_inefficient", "failure"]:
+ # For non-optimal outcomes, validate required fields
+ improvement_suggestion = report.get("improvement_suggestion", None)
+ retry_from_step = report.get("retry_from_step", None)
+
+ if improvement_suggestion is None or retry_from_step is None:
+ print("[R3L DAPO Validation] Missing 'improvement_suggestion' or 'retry_from_step'.")
+ return False, False
+
+ # check retry from step
+ try:
+ retry_from_step = int(retry_from_step)
+ except (ValueError, TypeError):
+ print(f"[R3L DAPO Validation] 'retry_from_step' must be an integer. Got: {retry_from_step}")
+ return False, False
+        if retry_from_step < 0:
+ print(f"[R3L DAPO Validation] 'retry_from_step' must be a non-negative integer. Got: {retry_from_step}")
+ return False, False
+ # Check trajectory bounds if total_steps is provided
+ if total_steps is not None:
+ if retry_from_step >= total_steps:
+ print(
+ f"[R3L DAPO Validation] 'retry_from_step' ({retry_from_step}) exceeds trajectory bounds (0 to {total_steps - 1}).")
+ return False, False
+ print(f"[R3L DAPO Validation] {outcome} report validation successful.")
+ return True, False
+ else:
+ print(f"[R3L DAPO Validation] Invalid trajectory_outcome: {outcome}")
+ return False, False
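+# Example (illustrative): a report {"trajectory_summary": "...", "root_cause_analysis": None,
+# "trajectory_outcome": "success"} validates as (True, True); a "failure" report missing
+# 'retry_from_step' validates as (False, False).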
+
+
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Converts a validated reflection report into a structured, actionable
+ guidance prompt for the agent's second attempt. This prompt is framed
+ as an internal directive to ensure the model's output is clean for SFT.
+ """
+ # Convert the report dictionary to a formatted string
+ report_str = json.dumps(report, indent=2, ensure_ascii=False)
+
+ # Load and render template
+ jinja_env = _get_jinja_env()
+ template = jinja_env.get_template("self_correction.j2")
+
+ return template.render(report=report_str)
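+# Example (illustrative): passing a validated failure report here yields the
+# self_correction.j2 text with the JSON-serialized report substituted for {{ report }},
+# which second_rollout then merges into the system prompt for the guided retry.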
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ success: bool,
+ predicted_answer: str = "",
+ ground_truth: str = "",
+ attempt_type: str = "first",
+ additional_metrics: Optional[Dict] = None
+) -> Dict[str, Any]:
+ """Create an experience record for data saving"""
+ record = {
+ "task_id": task_id,
+ "timestamp": datetime.now().isoformat(),
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "reward": reward,
+ "success": success,
+ "predicted_answer": predicted_answer,
+ "ground_truth": ground_truth,
+ }
+
+ if additional_metrics:
+ record.update(additional_metrics)
+
+ return record
+
+
+def save_experience_data(
+ task_id: str,
+ experience_data: Dict[str, Any],
+ data_dir: str
+):
+ """Save experience data to file"""
+ os.makedirs(data_dir, exist_ok=True)
+ file_path = os.path.join(data_dir, f"{task_id}.json")
+
+ with open(file_path, "w", encoding="utf-8") as f:
+ json.dump(experience_data, f, ensure_ascii=False, indent=2)
+
+
+def generate_default_experience() -> Experience:
+ """Generate a default experience for failed cases"""
+ return Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={"success": 0.0, "reward": 0.0},
+ reward=0.0
+ )
diff --git a/trinity/common/workflows/envs/R3L/scienceworld/R3L_workflow.py b/trinity/common/workflows/envs/R3L/scienceworld/R3L_workflow.py
new file mode 100644
index 0000000000..674ef37ed4
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/scienceworld/R3L_workflow.py
@@ -0,0 +1,354 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.scienceworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("R3L_scienceworld_workflow")
+class R3LScienceWorldWorkflow(Workflow):
+ """
+ R3L workflow for scienceworld
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 30
+ self.max_tokens = 16384
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+
+ self.whether_save_data = False
+ # Create data directories
+        self.data_dir = "R3L_scienceworld_data"
+ self.eval_dir = os.path.join(self.data_dir, "eval")
+ self.train_dir = os.path.join(self.data_dir, "train")
+
+ os.makedirs(self.eval_dir, exist_ok=True)
+ os.makedirs(self.train_dir, exist_ok=True)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.sciworld_system_template = self.jinja_env.get_template("sciworld_system.j2")
+ self.reflection_template = self.jinja_env.get_template("reflection.j2")
+
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ self.default_second_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "second_success": 0.0,
+ "second_reward": 0.0,
+ },
+ reward=0.0
+ )
+
+ print(
+ f"Initializing R3LScienceWorldWorkflow with experience learning, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.task_desc = task.task_desc or "0"
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def get_reflect(self, trajectory: List[Dict[str, str]]) -> tuple[
+ Optional[Dict[str, Any]], Optional[str], Optional[Any]]:
+ """
+ Generates a comprehensive reflection report using a single, unified self-interrogation prompt.
+ """
+ # Format trajectory for LLM reading
+ formatted_trajectory = utils.format_trajectory_for_reflection(trajectory)
+
+ # Use Jinja2 template to render reflection prompt
+ reflect_prompt = self.reflection_template.render()
+
+ # Call model and parse results
+ try:
+ responses = self.model.chat(
+ [
+ {"role": "system", "content": reflect_prompt},
+ {"role": "user", "content": "Here is last attempt trajectory log: \n\n" + formatted_trajectory + "\n\nPlease output in the specified JSON format."}
+ ],
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_reflect_tokens,
+ )
+ reflection_text = responses[0].response_text.strip()
+
+ # Find first '{' and last '}'
+ first_brace = reflection_text.find('{')
+ last_brace = reflection_text.rfind('}')
+
+ if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
+ json_str = reflection_text[first_brace:last_brace + 1]
+ else:
+ json_str = reflection_text
+
+ reflection_data = json.loads(json_str)
+ return reflection_data, reflection_text, responses[0]
+
+ except Exception as e:
+ return None, None, None
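+    # Note on the brace extraction above (illustrative): for a reply like
+    # 'Here is my analysis:\n{"trajectory_outcome": "failure", ...}\nHope this helps',
+    # json_str is the substring from the first '{' to the last '}', which json.loads
+    # then parses; anything outside the braces is discarded.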
+
+ def _adjust_action_mask_for_retry(self, experience: Experience, retry_step: int):
+ """
+ Adjust action_mask in-place to exclude retry prefix from training.
+ Only tokens from retry_step onwards should be trained.
+
+ Args:
+ experience: The experience object with action_mask to adjust
+ retry_step: The step from which training should start
+ """
+ if retry_step <= 0:
+ return
+
+ action_mask = experience.action_mask
+
+ # Find all assistant response regions and mark the first 'retry_step' as non-trainable
+ if torch.any(action_mask == 1):
+ assistant_segments = []
+ in_segment = False
+ segment_start = 0
+
+ for i, mask_val in enumerate(action_mask):
+ if mask_val == 1 and not in_segment:
+ segment_start = i
+ in_segment = True
+ elif mask_val == 0 and in_segment:
+ assistant_segments.append((segment_start, i))
+ in_segment = False
+
+ if in_segment:
+ assistant_segments.append((segment_start, len(action_mask)))
+
+ # Set the first 'retry_step' assistant segments to 0 (non-trainable)
+ for i in range(min(retry_step, len(assistant_segments))):
+ start, end = assistant_segments[i]
+ action_mask[start:end] = 0
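+    # Worked example (illustrative mask layout): with action_mask
+    # [0, 0, 1, 1, 0, 1, 1, 0, 1, 1] (three assistant segments) and retry_step=2, the first
+    # two segments are zeroed in place, leaving only the third segment's tokens trainable.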
+
+ def run(self) -> List[Experience]:
+ """Run the R3L scienceworld workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_sciworld(self)
+
+ # Generate unique task ID
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+
+ env = utils.create_sciworld_environment(self.task_desc)
+ exp_lst = []
+ for i in range(self.n // 2): # Half for rollout, half for reflection + retry
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[R3L] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ # Set eid
+                exp.eid.task = str(self.task.task_id) + "_explore"
+ exp_run_id = len(exp_lst) + self.run_id_base
+ exp.eid.run = exp_run_id
+ exp_lst.append(exp)
+
+ if self.whether_save_data:
+ first_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="first"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_first",
+ experience_data=first_record,
+ data_dir=self.train_dir
+ )
+
+ # Reflect on first attempt
+ reflect_checklist, reflection_text, reflect_exp = self.get_reflect(trajectory)
+ is_valid, is_perfect = utils.validate_reflect_report(reflect_checklist, steps)
+
+ if not is_valid or is_perfect:
+ if reward >= 1.0 and is_perfect and reflect_exp is not None:
+ reflect_exp.reward = 1.0
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ try:
+ retry_env = utils.create_sciworld_environment(self.task_desc)
+ retry_trajectory, retry_reward, retry_done, retry_steps, retry_format_valid = utils.first_rollout(
+ self, retry_env
+ )
+
+ retry_exp = self.model.convert_messages_to_experience(retry_trajectory[:-1])
+ retry_exp.reward = retry_reward
+ retry_exp.metrics = {
+ "success": 1.0 if retry_reward >= 1.0 else 0.0,
+ "steps": retry_steps,
+ "reward": retry_reward,
+ }
+                        retry_exp.eid.task = str(self.task.task_id) + "_explore"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ if self.whether_save_data:
+ retry_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=retry_trajectory,
+ reward=retry_reward,
+ steps=retry_steps,
+ success=retry_reward >= 1.0,
+ attempt_type="retry_after_invalid_reflection"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_retry",
+ experience_data=retry_record,
+ data_dir=self.train_dir
+ )
+ except Exception as e:
+ print(f"Retry rollout after invalid reflection failed: {e}")
+
+ else:
+ guidance_prompt = utils.reflect_report_to_guidance_prompt(reflect_checklist)
+                    # Extract retry_step from the validated reflection report (top-level field in the reflection.j2 schema)
+ retry_step = reflect_checklist.get("retry_from_step", 0)
+
+ try:
+ second_env = utils.create_sciworld_environment(self.task_desc)
+ (
+ distill_trajectory,
+ second_trajectory,
+ second_reward,
+ second_done,
+ second_steps,
+ second_format_valid,
+ ) = utils.second_rollout(
+ self, second_env, guidance_prompt, trajectory, retry_step
+ )
+ print(f"[R3L] Second rollout - reward: {second_reward}, steps: {second_steps}, improve: {second_reward > reward}")
+ second_exp = self.model.convert_messages_to_experience(distill_trajectory[:-1])
+
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(second_exp, retry_step)
+ for existing_exp in exp_lst:
+ if existing_exp.eid.run == exp_run_id:
+ self._adjust_action_mask_for_retry(existing_exp, retry_step)
+ break
+
+ second_exp.reward = second_reward
+ second_exp.metrics = {
+ "second_success": 1.0 if second_reward >= 1.0 else 0.0,
+ "second_steps": second_steps,
+ "second_reward": second_reward,
+ "second_improve": 1.0 if second_reward > reward else 0.0,
+ "second_reward_diff": second_reward - reward,
+ }
+                        second_exp.eid.task = str(self.task.task_id) + "_explore"
+ second_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(second_exp)
+
+ if self.whether_save_data:
+ second_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=second_trajectory,
+ reward=second_reward,
+ steps=second_steps,
+ success=second_reward >= 1.0,
+ attempt_type="second",
+ additional_metrics={
+ "first_reward": reward,
+ "improvement": second_reward > reward,
+ "reward_difference": second_reward - reward,
+ "step_difference": second_steps - steps
+ }
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_second",
+ experience_data=second_record,
+ data_dir=self.train_dir
+ )
+
+ if (second_reward > reward and second_reward >= 1.0) or (second_reward >= 1.0 and second_steps < steps):
+ reflect_exp.reward = 1.0
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ retry_exp = self.model.convert_messages_to_experience(second_trajectory[:-1])
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(retry_exp, retry_step)
+
+ retry_exp.reward = 1.0
+ retry_exp.eid.task = str(self.task.task_id) + f"_retry_{i}"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ print("Reflection and retry led to improvement, recording both...")
+ except Exception:
+ pass
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/scienceworld/__init__.py b/trinity/common/workflows/envs/R3L/scienceworld/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/common/workflows/envs/R3L/scienceworld/dapo_workflow.py b/trinity/common/workflows/envs/R3L/scienceworld/dapo_workflow.py
new file mode 100644
index 0000000000..d9212215e8
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/scienceworld/dapo_workflow.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.scienceworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("dapo_scienceworld_workflow")
+class DAPOScienceWorldWorkflow(Workflow):
+ """
+ DAPO Workflow for ScienceWorld environment.
+ Performs rollouts with DAPO-style overlong penalty on response length.
+ No separate reward function needed - penalty computed directly in workflow.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 30
+ self.max_tokens = 16384
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # DAPO overlong penalty parameters
+ workflow_args = task.workflow_args or {}
+ self.enable_overlong_penalty = workflow_args.get("enable_overlong_penalty", True)
+ self.penalty_factor = workflow_args.get("penalty_factor", 1.0)
+ self.max_response_length = workflow_args.get("max_response_length", 16384)
+ self.cache_length = workflow_args.get("cache_length", 100)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.sciworld_system_template = self.jinja_env.get_template("sciworld_system.j2")
+
+ print(
+ f"Initializing DAPOScienceWorldWorkflow, temperature={self.temperature}, "
+ f"overlong_penalty={'enabled' if self.enable_overlong_penalty else 'disabled'}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.task_desc = task.task_desc or ""
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Update DAPO parameters if provided
+ workflow_args = task.workflow_args or {}
+ if "enable_overlong_penalty" in workflow_args:
+ self.enable_overlong_penalty = workflow_args["enable_overlong_penalty"]
+ if "penalty_factor" in workflow_args:
+ self.penalty_factor = workflow_args["penalty_factor"]
+ if "max_response_length" in workflow_args:
+ self.max_response_length = workflow_args["max_response_length"]
+ if "cache_length" in workflow_args:
+ self.cache_length = workflow_args["cache_length"]
+
+ def compute_overlong_penalty(self, response_tokens: torch.Tensor) -> float:
+ """
+ Compute DAPO-style overlong penalty based on response token length.
+
+ Args:
+ response_tokens: Response tokens (tensor)
+
+ Returns:
+ Penalty score (non-positive float)
+ """
+ if not self.enable_overlong_penalty:
+ return 0.0
+
+ response_len = len(response_tokens)
+ expected_len = self.max_response_length - self.cache_length
+
+ if response_len < expected_len:
+ # No penalty for short responses
+ return 0.0
+ elif response_len > self.max_response_length:
+ # Fixed penalty for excessively long responses
+ return -self.penalty_factor
+ else:
+ # Linear penalty in the transition zone
+ return (expected_len - response_len) / self.cache_length * self.penalty_factor
+
+ def run(self) -> List[Experience]:
+ """Run the DAPO workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_sciworld(self)
+
+ # Single rollout execution
+ env = utils.create_sciworld_environment(self.task_desc)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[DAPO ScienceWorld] Rollout - reward: {reward}, steps: {steps}")
+
+ # Convert trajectory to experience
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+
+ # Extract response tokens from experience
+ response_tokens = exp.tokens[exp.prompt_length:]
+
+ # Compute DAPO overlong penalty (format score)
+ format_score = self.compute_overlong_penalty(response_tokens)
+
+ # Calculate accuracy score
+ accuracy_score = 1.0 if reward >= 1.0 else 0.0
+
+ # Total reward = accuracy + format_score
+ total_reward = accuracy_score + format_score
+
+ # Update experience reward and metrics
+ exp.reward = total_reward
+ exp.metrics = {
+ "success": accuracy_score,
+ "steps": steps,
+ "env_reward": reward,
+ "accuracy": accuracy_score,
+ "format_score": format_score,
+ "response_length": len(response_tokens),
+ "total_reward": total_reward,
+ }
+
+ # Set experience ID
+ exp.eid.task = str(self.task.task_id)
+ exp.eid.run = i + self.run_id_base
+
+ exp_lst.append(exp)
+ except Exception as e:
+ print(f"[DAPO ScienceWorld] Rollout failed: {e}")
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/scienceworld/grpo_workflow.py b/trinity/common/workflows/envs/R3L/scienceworld/grpo_workflow.py
new file mode 100644
index 0000000000..3a9a00fd8e
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/scienceworld/grpo_workflow.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.scienceworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("grpo_baseline_scienceworld_workflow")
+class GRPOBaselineScienceWorldWorkflow(Workflow):
+ """
+ GRPO Baseline Workflow for ScienceWorld environment.
+ Performs simple rollouts without reflection or learning from experience.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 30
+ self.max_tokens = 16384
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.sciworld_system_template = self.jinja_env.get_template("sciworld_system.j2")
+
+ print(
+ f"Initializing GRPOBaselineScienceWorldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.task_desc = task.task_desc or "0"
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the GRPO baseline workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_sciworld(self)
+
+ # Single rollout execution
+ env = utils.create_sciworld_environment(self.task_desc)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[GRPO] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/scienceworld/opmd_workflow.py b/trinity/common/workflows/envs/R3L/scienceworld/opmd_workflow.py
new file mode 100644
index 0000000000..d2b0aef8a2
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/scienceworld/opmd_workflow.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.scienceworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("opmd_baseline_scienceworld_workflow")
+class OPMDBaselineScienceWorldWorkflow(Workflow):
+ """
+ OPMD Baseline workflow for ScienceWorld environment.
+ Performs simple rollouts without reflection, used as the OPMD baseline.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 30
+ self.max_tokens = 16384
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.sciworld_system_template = self.jinja_env.get_template("sciworld_system.j2")
+
+ print(
+ f"Initializing OPMDScienceWorldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.task_desc = task.task_desc or "0"
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the OPMD workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_sciworld(self)
+
+ # Single rollout execution
+ env = utils.create_sciworld_environment(self.task_desc)
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[OPMD] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/scienceworld/prompts/reflection.j2 b/trinity/common/workflows/envs/R3L/scienceworld/prompts/reflection.j2
new file mode 100644
index 0000000000..2ff44858cc
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/scienceworld/prompts/reflection.j2
@@ -0,0 +1,28 @@
+You are a Reflector that analyzes trajectory logs based on user and environment feedback. Your goal is to identify what went wrong, trace root causes, and extract reusable principles for future improvement. Review the trajectory and feedback to understand the strategy and outcome. Through Socratic-style iterative "why" questioning, trace issues back to their fundamental flawed assumptions or mental models. Then formulate an actionable principle and suggest where to retry if needed.
+
+Please output in the following JSON format:
+
+```json
+{
+"trajectory_summary": "Concise overview in 1-3 sentences covering: (1) the strategy or approach employed by the agent, (2) the final result or outcome achieved, (3) key observations about execution quality (e.g., efficiency, correctness, optimality).",
+"root_cause_analysis": "Deep causal analysis using iterative 'why' questioning to trace from observable symptoms back to the fundamental root cause (flawed assumption, incorrect mental model, or critical knowledge gap). Chain your reasoning explicitly (e.g., 'Why X? Because Y. Why Y? Because Z.'). Identify the deepest underlying issue, not just surface-level errors. Set to null only if execution was truly flawless.",
+"trajectory_outcome": "Classification of the trajectory result. Must be EXACTLY one of these three values (case-sensitive, including underscores): 'success' (goal fully achieved with optimal execution quality), 'success_but_inefficient' (goal achieved but with unnecessary steps, redundant actions, or suboptimal approach), 'failure' (goal not achieved or task incomplete).",
+"improvement_suggestion": "A generalizable, context-complete principle for avoiding similar issues in future attempts. Must be self-contained and actionable without reference to this specific trajectory. Include: (1) the specific environment/system/domain name (ScienceWorld interactive science experiments), (2) the triggering conditions or scenario when this applies, (3) the specific problem or pitfall to avoid, (4) the recommended solution or approach with clear rationale. Frame as reusable knowledge. Set to null if and only if trajectory_outcome is 'success'.",
+"retry_from_step": "Integer from 0 to N-1 identifying the earliest step where the root cause first manifested or where a corrected decision could alter the outcome. This represents the optimal restart point if given one opportunity to retry. Use 0 when the root cause traces to initial strategy selection or foundational assumptions. Set to null if trajectory_outcome is 'success' or if retry would not be beneficial."
+}
+```
+
+## Example
+
+**Scenario**: Growing a plant in ScienceWorld - task is to grow a plant to fruiting stage
+
+**Example Output**:
+```json
+{
+"trajectory_summary": "The agent attempted to grow a plant by placing a seed in soil and watering it once. However, the plant did not progress to fruiting stage within the allowed steps. The agent focused solely on initial planting and watering but failed to monitor plant growth stages or provide continued care (light, water, temperature) throughout the growth cycle.",
+"root_cause_analysis": "Why did the plant not fruit? Because it did not receive continued care after initial planting. Why was continued care not provided? Because the agent treated planting as a one-time action rather than an ongoing process. Why this one-time mindset? Because the agent lacked understanding that plant growth in ScienceWorld requires monitoring multiple growth stages and providing appropriate environmental conditions at each stage. Root cause: Incomplete mental model of plant growth dynamics - the agent optimized for task initiation (planting) rather than task completion (ensuring full growth cycle to fruiting), missing that biological processes require sustained interaction over time.",
+"trajectory_outcome": "failure",
+"improvement_suggestion": "In ScienceWorld plant growth experiments, treat growth as a multi-stage process requiring continuous monitoring and intervention. After planting: (1) regularly check plant status to identify current growth stage, (2) ensure continuous environmental needs (water, light, appropriate temperature) are met at each stage, (3) recognize that progression from seed→sprout→plant→flower→fruit requires multiple game cycles with appropriate care at each transition. For example, after planting a seed in soil, check plant every few steps, water when soil is dry, ensure adequate light exposure, and wait for natural stage transitions. A single watering is insufficient - plants need sustained attention throughout their growth cycle.",
+"retry_from_step": 2
+}
+```
diff --git a/trinity/common/workflows/envs/R3L/scienceworld/prompts/sciworld_system.j2 b/trinity/common/workflows/envs/R3L/scienceworld/prompts/sciworld_system.j2
new file mode 100644
index 0000000000..daa6544dee
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/scienceworld/prompts/sciworld_system.j2
@@ -0,0 +1,42 @@
+You are an agent whose job is to conduct scientific experiments in a virtual text-based environment.
+
+## Response Format:
+You MUST use this exact format for every response. All tags are REQUIRED in sequential order:
+
+<think>your analytical reasoning and thought process</think>
+<action>exactly one specific action command</action>
+
+## Notes:
+At each step, you should first think and then perform an action to fulfill the instruction. You should ALWAYS wrap your thinking in <think></think> tags and wrap your action in <action></action> tags.
+You should ALWAYS take one action each step.
+DO NOT try to interact with the user at anytime. Finish the task by yourself.
+
+## Available Commands:
+Below are the available commands you can use:
+ open OBJ: open a container
+ close OBJ: close a container
+ activate OBJ: activate a device
+ deactivate OBJ: deactivate a device
+ connect OBJ to OBJ: connect electrical components
+ disconnect OBJ: disconnect electrical components
+ use OBJ [on OBJ]: use a device/item
+ look around: describe the current room
+ examine OBJ: examine an object in detail
+ look at OBJ: describe a container's contents
+ read OBJ: read a note or book
+ move OBJ to OBJ: move an object to a container
+ pick up OBJ: move an object to the inventory
+ pour OBJ into OBJ: pour a liquid into a container
+ mix OBJ: chemically mix a container
+ teleport to LOC: teleport to a specific room
+ focus on OBJ: signal intent on a task object
+ wait: take no action for 10 steps
+ wait1: take no action for a step
+
+## Action Format Examples:
+Your output should be like this:
+<think>Now I will check the bedroom to find the thermometer...</think><action>teleport to bedroom</action>
+
+<think>I need to examine the substance to understand its properties...</think><action>examine substance</action>
+
+<think>To boil the water, I should activate the heating element...</think><action>activate heating element</action>
diff --git a/trinity/common/workflows/envs/R3L/scienceworld/prompts/self_correction.j2 b/trinity/common/workflows/envs/R3L/scienceworld/prompts/self_correction.j2
new file mode 100644
index 0000000000..3c0a34b676
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/scienceworld/prompts/self_correction.j2
@@ -0,0 +1,5 @@
+Your previous attempt encountered issues. Below is a reflection based on user and environment feedback:
+
+{{ report }}
+
+Apply the lessons learned from this reflection to avoid repeating the same mistakes. Do not mention or reference this guidance in your response.
diff --git a/trinity/common/workflows/envs/R3L/scienceworld/raft_workflow.py b/trinity/common/workflows/envs/R3L/scienceworld/raft_workflow.py
new file mode 100644
index 0000000000..d788ff7f6c
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/scienceworld/raft_workflow.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+import copy
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.scienceworld import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("RAFT_baseline_scienceworld_workflow")
+class RAFTBaselineScienceWorldWorkflow(Workflow):
+ """
+ RAFT Baseline workflow for ScienceWorld environment.
+ Performs rollouts and keeps only successful trajectories, following reward-ranked fine-tuning (RAFT).
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 30
+ self.max_tokens = 16384
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.sciworld_system_template = self.jinja_env.get_template("sciworld_system.j2")
+
+ print(
+ f"Initializing RAFTScienceWorldWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ # Default experience for error cases
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ },
+ reward=0.0
+ )
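+ # The placeholder has no trainable tokens (its action_mask is all False), so failed or
+ # unsuccessful rollouts still keep the number of returned experiences equal to repeat_times
+ # without contributing any gradient signal.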
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.task_desc = task.task_desc or "0"
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT workflow and return experiences"""
+
+ env = utils.create_sciworld_environment(self.task_desc)
+
+ if self.is_eval:
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ return [exp]
+ except Exception:
+ return [copy.deepcopy(self.default_exp)]
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, env
+ )
+ print(f"[RAFT] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ # RAFT only uses successful samples
+ if reward >= 1.0:
+ exp_lst.append(exp)
+ else:
+ exp_lst.append(copy.deepcopy(self.default_exp))
+ except Exception:
+ exp_lst.append(copy.deepcopy(self.default_exp))
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/scienceworld/utils.py b/trinity/common/workflows/envs/R3L/scienceworld/utils.py
new file mode 100644
index 0000000000..cc57ed2a87
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/scienceworld/utils.py
@@ -0,0 +1,603 @@
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+import torch
+from trinity.common.experience import Experience
+
+
+def first_rollout(self, env) -> tuple[List[Dict[str, str]], float, bool, int, bool]:
+ """Run a single rollout in SciWorld environment"""
+ observation, info = env.reset()
+ observation = (
+ "Task Description: " + str(env.get_task_description()) + "\n" + observation
+ )
+
+ trajectory = []
+ action_history = [] # Track last actions for repetition detection
+
+ system_prompt = self.sciworld_system_template.render()
+ trajectory.append({"role": "system", "content": system_prompt})
+
+ default_reward = 0.0
+ final_reward = 0.0
+ current_reward = 0.0
+ valid_format = True
+ step = 0
+ done = False
+
+ for step in range(self.max_env_steps):
+ trajectory.append(
+ {"role": "user", "content": format_observation(observation)}
+ )
+
+ # Get model response
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+ # The tokenizer inside chat truncates the input, so terminate as soon as the length reaches the limit
+ return trajectory, default_reward, False, step + 1, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response components
+ think, action = parse_response(response_text)
+ if action is None:
+ valid_format = False
+ feedback = "Invalid response format, missing valid or tags, please ensure to follow the output format strictly: ... ..."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ return trajectory, default_reward, False, step + 1, valid_format
+
+ # Check for consecutive action repetition
+ action_history.append(action)
+ if len(action_history) > 3:
+ action_history.pop(0)
+
+ # If last 3 actions are the same, terminate with failure
+ if len(action_history) >= 3 and all(
+ action == action_history[0] for action in action_history
+ ):
+ feedback = f"Repeated invalid action {action} multiple times, task failed"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ valid_format = False
+ return trajectory, default_reward, False, step + 1, valid_format
+
+ # Validate and execute action in environment
+ action_valid, error_msg = validate_action(action)
+ if action_valid:
+ observation, reward, done, info = env.step(action)
+ # Track cumulative reward
+ if reward > current_reward:
+ final_reward = reward
+ current_reward = reward
+ else:
+ observation, reward, done = error_msg, default_reward, False
+
+ if done:
+ break
+
+ # Generate feedback
+ if final_reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {final_reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif final_reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Task completed successfully (reward: {final_reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif final_reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+ f"Task not completed (reward: {final_reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ else:
+ feedback = (
+ f"Task not completed (reward: {final_reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ return trajectory, final_reward, False, step + 1, valid_format
+
+
+def second_rollout(
+ self,
+ env,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, int, bool]:
+ """
+ Performs rollout starting from a specific retry step, reusing previous responses.
+ """
+
+ # Reset environment to start fresh
+ observation, info = env.reset()
+ observation = (
+ "Task Description: " + str(env.get_task_description()) + "\n" + observation
+ )
+ trajectory = []
+ distill_trajectory = []
+ action_history = [] # Track last 3 actions for repetition detection
+
+ # Prepare system prompts
+ original_system_prompt = self.sciworld_system_template.render()
+
+ default_reward = 0.0
+ final_reward = 0.0
+ current_reward = 0.0
+ valid_format = True
+ done = False
+
+ # Copy responses from first trajectory up to retry_step
+ step = 0
+ if retry_step > 0:
+ # Add original system prompt only
+ trajectory.append({"role": "system", "content": original_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ # Replay first trajectory up to retry_step to restore environment state
+ first_step = 0
+ for msg in first_trajectory[1:]: # Skip system message
+ if msg["role"] == "user":
+ # This is an observation - copy it and continue
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+ elif msg["role"] == "assistant":
+ if first_step < retry_step:
+ # Copy the assistant response from first trajectory
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+
+ # Execute the action to restore environment state
+ think, action = parse_response(msg["content"])
+ if think is not None and action is not None:
+ action_valid, error_msg = validate_action(action)
+ if action_valid:
+ observation, reward, done, info = env.step(action)
+ if reward > current_reward:
+ final_reward = reward
+ current_reward = reward
+ action_history.append(action)
+ if len(action_history) > 3:
+ action_history.pop(0)
+ else:
+ # If action becomes invalid during replay, start from beginning
+ retry_step = 0
+ break
+ first_step += 1
+ step = first_step
+
+ if done:
+ # If environment finished during replay, no need to continue
+ return distill_trajectory, trajectory, final_reward, done, step, valid_format
+ else:
+ break
+
+ # Add guidance prompt as a separate system message before retry point
+ guidance_system_msg = {"role": "system", "content": f"# Previous Attempt Analysis & Guidance\n{guidance_prompt}"}
+ trajectory.append(guidance_system_msg)
+ # Don't add guidance to distill_trajectory to keep it clean
+
+ else:
+ # Starting from beginning - add system prompt with guidance
+ merged_system_prompt = f"{original_system_prompt}\n\n# Previous Attempt Analysis & Guidance\n{guidance_prompt}"
+ trajectory.append({"role": "system", "content": merged_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ for step in range(step, self.max_env_steps):
+ trajectory.append(
+ {"role": "user", "content": format_observation(observation)}
+ )
+ distill_trajectory.append(
+ {"role": "user", "content": format_observation(observation)}
+ )
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+ # The tokenizer inside chat truncates the input, so terminate as soon as the length reaches the limit
+ return distill_trajectory, trajectory, default_reward, False, step + 1, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response
+ think, action = parse_response(response_text)
+ if think is None or action is None:
+ valid_format = False
+ feedback = "Invalid response format, missing valid or tags, please ensure to follow the output format strictly: ... ..."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ return distill_trajectory, trajectory, default_reward, False, step + 1, valid_format
+
+ # Check for consecutive action repetition
+ action_history.append(action)
+ if len(action_history) > 3:
+ action_history.pop(0)
+
+ # If last 3 actions are the same, terminate with failure
+ if len(action_history) >= 3 and all(
+ action == action_history[0] for action in action_history
+ ):
+ feedback = f"Repeated invalid action {action} multiple times, task failed"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ valid_format = False
+ return distill_trajectory, trajectory, default_reward, False, step + 1, valid_format
+
+ # Validate and execute action in environment
+ action_valid, error_msg = validate_action(action)
+ if action_valid:
+ observation, reward, done, info = env.step(action)
+ if reward > current_reward:
+ final_reward = reward
+ current_reward = reward
+ else:
+ observation, reward, done = error_msg, default_reward, False
+
+ if done:
+ break
+
+ # Generate feedback
+ if final_reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Task completed successfully (reward: {final_reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif final_reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Task completed successfully (reward: {final_reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif final_reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+ f"Task not completed (reward: {final_reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ else:
+ feedback = (
+ f"Task not completed (reward: {final_reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+
+ return distill_trajectory, trajectory, final_reward, False, step + 1, valid_format
+
+
+def eval_sciworld(self) -> List[Experience]:
+ """Evaluate a single sciworld trajectory"""
+ try:
+ env = create_sciworld_environment(self.task_desc)
+ trajectory, reward, done, steps, valid_format = first_rollout(
+ self, env
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ print(f"[Eval] First rollout - reward: {reward}, steps: {steps}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="evaluation"
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval",
+ experience_data=eval_record,
+ data_dir=self.eval_dir
+ )
+ except Exception as e:
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": 0.0,
+ }
+ )
+ exp.reward = 0.0
+ return [exp]
+
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def format_observation(observation: str):
+ """Format observation for SciWorld environment with format reminder"""
+ formatted_prompt = f"""Observation:
+{observation}
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think></think> tags.
+Once you've finished your reasoning, you should choose an action and present it within <action></action> tags.
+
+Format: <think>your reasoning process</think><action>your chosen action</action>"""
+ return formatted_prompt
+
+
+def parse_response(response):
+ """Parse all three components from response with a single regex"""
+ think, action = None, None
+ try:
+ # Use a single regex to extract both components at once
+ pattern = r"<think>\s*(.*?)\s*</think>.*?<action>\s*(.*?)\s*</action>"
+ match = re.search(pattern, response, re.DOTALL)
+
+ if match:
+ think, action = match.group(1).strip(), match.group(2).strip()
+ except Exception:
+ pass
+ return think, action
+
+
+def validate_action(action):
+ """
+ Validate action format for SciWorld environment.
+ SciWorld actions don't need validation against available_actions like WebShop.
+ We just check if the action is non-empty.
+ """
+ if not action or not action.strip():
+ return False, "Action cannot be empty"
+
+ # SciWorld accepts any non-empty action string
+ # The environment itself will handle invalid actions
+ return True, ""
+
+
+def create_sciworld_environment(task_desc):
+ """Create sciworld environment"""
+ try:
+ from scienceworld import ScienceWorldEnv
+
+ # Parse task_desc to get task name and variation
+ # Format: "task_name-variation_number" or just "task_name"
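+ # e.g. a task_desc of "boil-3" is parsed as task_name="boil" with variation=3,
+ # while "boil" alone defaults to variation 0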
+ if '-' in task_desc:
+ parts = task_desc.split('-')
+ task_name = parts[0]
+ variation = int(parts[1]) if len(parts) > 1 else 0
+ else:
+ task_name = task_desc
+ variation = 0
+
+ env = ScienceWorldEnv(task_name, serverPath="")
+ env.load(task_name, variation, generateGoldPath=True)
+
+ return env
+ except ImportError as e:
+ raise ImportError(
+ f"Failed to import scienceworld dependencies: {e}. "
+ "Please install scienceworld following the instructions at https://github.com/allenai/ScienceWorld"
+ )
+
+
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+ Correctly formats the trajectory for reflection, including the system prompt
+ and numbering the user/assistant turns.
+ """
+ formatted_lines = []
+ # Use a counter to track user/assistant interaction turns
+ turn_counter = 0 # Start counting from 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+ # The system prompt does not count as a step, but must be shown first as the rules
+ formatted_lines.append(f"**System Rules & Context:**\n{content}\n" + "=" * 30)
+ elif role == "user":
+ # Mark the start of a new turn
+ formatted_lines.append(f"\n**Step {turn_counter}**")
+ formatted_lines.append(f" - User Observation/Feedback:\n {content.strip()}")
+ elif role == "assistant":
+ # The assistant's thought and action
+ formatted_lines.append(f" - Assistant Thought & Action:\n {content.strip()}")
+ # Increment the turn count after a complete user-assistant exchange
+ turn_counter += 1
+
+ return "\n".join(formatted_lines)
+
+
+def validate_reflect_report(report: Dict[str, Any], total_steps: int) -> tuple[bool, bool]:
+ """
+ Validates the structure and content of the reflection report
+ based on the reflection.j2 schema.
+
+ Args:
+ report: The reflection report to validate
+ total_steps: Maximum number of steps in trajectory for retry_step bounds checking
+
+ Returns:
+ tuple[bool, bool]: (is_valid, is_perfect)
+ - is_valid: Whether the report structure is valid
+ - is_perfect: Whether the report indicates the trajectory is perfect (only meaningful if is_valid is True)
+ """
+ if (
+ not isinstance(report, dict)
+ or "trajectory_summary" not in report
+ or "root_cause_analysis" not in report
+ or "trajectory_outcome" not in report
+ ):
+ print("[R3L ScienceWorld Validation] Report is not a dict or missing keys.")
+ return False, False
+
+ outcome = report["trajectory_outcome"]
+
+ if outcome == "success":
+ # For success, we only need summary and no flaw analysis
+ print("[R3L ScienceWorld Validation] success report validation successful.")
+ return True, True
+
+ elif outcome in ["success_but_inefficient", "failure"]:
+ # For non-optimal outcomes, validate required fields
+ improvement_suggestion = report.get("improvement_suggestion", None)
+ retry_from_step = report.get("retry_from_step", None)
+
+ if improvement_suggestion is None or retry_from_step is None:
+ print("[R3L ScienceWorld Validation] Missing 'improvement_suggestion' or 'retry_from_step'.")
+ return False, False
+
+ # check retry from step
+ try:
+ retry_from_step = int(retry_from_step)
+ except (ValueError, TypeError):
+ print(f"[R3L ScienceWorld Validation] 'retry_from_step' must be an integer. Got: {retry_from_step}")
+ return False, False
+ if not isinstance(retry_from_step, int) or retry_from_step < 0:
+ print(f"[R3L ScienceWorld Validation] 'retry_from_step' must be a non-negative integer. Got: {retry_from_step}")
+ return False, False
+ # Check trajectory bounds if total_steps is provided
+ if total_steps is not None:
+ if retry_from_step >= total_steps:
+ print(
+ f"[R3L ScienceWorld Validation] 'retry_from_step' ({retry_from_step}) exceeds trajectory bounds (0 to {total_steps - 1}).")
+ return False, False
+ print(f"[R3L ScienceWorld Validation] {outcome} report validation successful.")
+ return True, False
+ else:
+ print(f"[R3L ScienceWorld Validation] Invalid trajectory_outcome: {outcome}")
+ return False, False
+
+
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Converts a validated reflection report into a structured, actionable
+ guidance prompt for the agent's second attempt. This prompt is framed
+ as an internal directive to ensure the model's output is clean for SFT.
+ """
+ # Convert the report dictionary to a formatted string
+ report_str = json.dumps(report, indent=2, ensure_ascii=False)
+
+ # Load and render template
+ jinja_env = _get_jinja_env()
+ template = jinja_env.get_template("self_correction.j2")
+
+ return template.render(report=report_str)
+
+
+def save_experience_data(
+ task_id: str,
+ experience_data: Dict,
+ data_dir: str
+) -> str:
+ """
+ Save experience data including trajectory, rewards, and steps to a JSON file.
+
+ Args:
+ task_id: Unique identifier for the task
+ experience_data: Dictionary containing experience information
+ data_dir: Directory to save the data
+
+ Returns:
+ Path to the saved file
+ """
+ os.makedirs(data_dir, exist_ok=True)
+
+ # Build the output path from the task id
+ filename = f"{task_id}.json"
+ filepath = os.path.join(data_dir, filename)
+
+ # Ensure experience_data is JSON serializable
+ serializable_data = {}
+ for key, value in experience_data.items():
+ if isinstance(value, torch.Tensor):
+ serializable_data[key] = value.tolist()
+ elif hasattr(value, '__dict__'):
+ # For complex objects, convert to dict representation
+ serializable_data[key] = str(value)
+ else:
+ serializable_data[key] = value
+
+ # Add metadata
+ serializable_data["saved_at"] = datetime.now().isoformat()
+ serializable_data["task_id"] = task_id
+
+ try:
+ with open(filepath, 'w', encoding='utf-8') as f:
+ json.dump(serializable_data, f, indent=2, ensure_ascii=False)
+ # print(f"Experience data saved to: {filepath}")
+ return filepath
+ except Exception as e:
+ # print(f"Failed to save experience data: {e}")
+ return ""
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ steps: int,
+ success: bool,
+ attempt_type: str = "first",
+ reflection_data: Optional[Dict] = None,
+ additional_metrics: Optional[Dict] = None
+) -> Dict:
+ """
+ Create a structured experience record for saving.
+
+ Args:
+ task_id: Unique identifier for the task
+ trajectory: List of conversation messages
+ reward: Final reward received
+ steps: Number of steps taken
+ success: Whether the task was completed successfully
+ attempt_type: Type of attempt ("first", "second", "reflect")
+ reflection_data: Optional reflection analysis data
+ additional_metrics: Additional metrics to record
+
+ Returns:
+ Dictionary containing structured experience data
+ """
+ experience_record = {
+ "task_id": task_id,
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "metrics": {
+ "reward": reward,
+ "steps": steps,
+ "success": success,
+ "trajectory_length": len(trajectory)
+ },
+ "created_at": datetime.now().isoformat()
+ }
+
+ if reflection_data:
+ experience_record["reflection"] = reflection_data
+
+ if additional_metrics:
+ experience_record["metrics"].update(additional_metrics)
+
+ return experience_record
diff --git a/trinity/common/workflows/envs/R3L/webshop/R3L_workflow.py b/trinity/common/workflows/envs/R3L/webshop/R3L_workflow.py
new file mode 100644
index 0000000000..b5a37b3b48
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/webshop/R3L_workflow.py
@@ -0,0 +1,416 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.webshop import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("R3L_webshop_workflow")
+class R3LWebshopWorkflow(Workflow):
+ """
+ R3L workflow for webshop
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 15
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+
+ self.whether_save_data = False
+ # Create data directories
+ self.data_dir = "R3L_webshop_data"
+ self.eval_dir = os.path.join(self.data_dir, "eval")
+ self.train_dir = os.path.join(self.data_dir, "train")
+
+ os.makedirs(self.eval_dir, exist_ok=True)
+ os.makedirs(self.train_dir, exist_ok=True)
+
+ # Initialize WebShop environment
+ try:
+ import sys
+ # Add WebShop path - can be overridden via WEBSHOP_PATH environment variable
+ webshop_path = os.environ.get("WEBSHOP_PATH")
+ if webshop_path:
+ sys.path.append(webshop_path)
+ else:
+ # sys.path.append("/nas/shiweijie/trinity/webshop")
+ sys.path.append("/home/wshiah/code/shiweijie/weijie/trinity/webshop")
+ # Import gym and the WebShop text environment
+ import gym
+ from web_agent_site.envs import WebAgentTextEnv # noqa: F401
+
+ # NOTE: Hosting the env requires ~15GB CPU memory.
+ # If you want easier env, you can set the num_products to 1000 or 100000.
+ self.env = gym.make(
+ "WebAgentTextEnv-v0",
+ observation_mode="text_rich",
+ num_products=None,
+ human_goals=True,
+ )
+ except Exception as e:
+ error_message = f"Error importing WebAgentTextEnv {str(e)}. Please make sure you have installed the web_agent_site package, following the instructions in https://github.com/princeton-nlp/WebShop"
+ raise ImportError(error_message)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.webshop_system_template = self.jinja_env.get_template("webshop_system.j2")
+ self.reflection_template = self.jinja_env.get_template("reflection.j2")
+
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": -0.1,
+ },
+ reward=-0.1 # Default minimum reward for webshop tasks
+ )
+
+ self.default_second_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "second_success": 0.0,
+ "second_reward": -0.1,
+ },
+ reward=-0.1
+ )
+
+ print(
+ f"Initializing ExpLearnWebshopWorkflow with experience learning, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.session_id = int(task.task_desc or "0")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def get_reflect(self, trajectory: List[Dict[str, str]]) -> tuple[
+ Optional[Dict[str, Any]], Optional[str], Optional[Any]]:
+ """
+ Generates a comprehensive reflection report using a single, unified self-interrogation prompt.
+ The model first assesses its own performance and then follows the appropriate reflection path.
+ """
+ # print("Generating reflection report using the unified self-interrogation prompt...")
+
+ # Format the trajectory for the LLM to read
+ formatted_trajectory = utils.format_trajectory_for_reflection(trajectory)
+
+ # Render the reflection prompt with the Jinja2 template
+ reflect_prompt = self.reflection_template.render()
+
+ # Call the model and parse the result
+ try:
+ responses = self.model.chat(
+ [
+ {"role": "system", "content": reflect_prompt},
+ {"role": "user", "content": "Here is last attempt trajectory log: \n\n" + formatted_trajectory + "\n\nPlease output in the specified JSON format."}
+ ],
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_reflect_tokens,
+ )
+ reflection_text = responses[0].response_text.strip()
+
+ # print(f"raw reflection text: {reflection_text}")
+
+ # Find first '{' and last '}'
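+ # The model may wrap the JSON in markdown fences or surrounding prose, so keep only
+ # the outermost braces before parsing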
+ first_brace = reflection_text.find('{')
+ last_brace = reflection_text.rfind('}')
+
+ if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
+ json_str = reflection_text[first_brace:last_brace + 1]
+ else:
+ json_str = reflection_text
+
+ reflection_data = json.loads(json_str)
+ return reflection_data, reflection_text, responses[0]
+
+ except Exception as e:
+ # print(f"Failed during unified reflection process: {e}")
+ return None, None, None
+
+ def _adjust_action_mask_for_retry(self, experience: Experience, retry_step: int):
+ """
+ Adjust action_mask in-place to exclude retry prefix from training.
+ Only tokens from retry_step onwards should be trained.
+
+ Args:
+ experience: The experience object with action_mask to adjust
+ retry_step: The step from which training should start
+ """
+ if retry_step <= 0:
+ return
+
+ # Note: experience.action_mask already excludes prompt tokens
+ action_mask = experience.action_mask
+
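+ # retry_step counts assistant turns (messages), not tokens: each contiguous run of 1s in
+ # action_mask corresponds to one assistant response, and the first `retry_step` of those
+ # runs are zeroed out below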
+ # Find all assistant response regions and mark the first 'retry_step' as non-trainable
+ if torch.any(action_mask == 1):
+ # Find all segments where action_mask == 1 (assistant responses)
+ assistant_segments = []
+ in_segment = False
+ segment_start = 0
+
+ for i, mask_val in enumerate(action_mask):
+ if mask_val == 1 and not in_segment:
+ # Start of a new segment
+ segment_start = i
+ in_segment = True
+ elif mask_val == 0 and in_segment:
+ # End of current segment
+ assistant_segments.append((segment_start, i))
+ in_segment = False
+
+ # Handle case where sequence ends with assistant response
+ if in_segment:
+ assistant_segments.append((segment_start, len(action_mask)))
+
+ # Set the first 'retry_step' assistant segments to 0 (non-trainable)
+ for i in range(min(retry_step, len(assistant_segments))):
+ start, end = assistant_segments[i]
+ action_mask[start:end] = 0
+
+ def run(self) -> List[Experience]:
+ """Run the experience learning webshop workflow and return experiences"""
+
+ if self.is_eval:
+ # print("pass evaluation mode")
+ # return [opmd_reflect_enhanced_restart_utils.generate_default_experience()]
+ return utils.eval_webshop(self)
+
+ # Generate unique task ID using timestamp
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+
+ exp_lst = []
+ for i in range(self.n // 2): # half the budget for plain rollouts, half for reflection-guided retries built on them
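+ # Each iteration records the first-rollout exp, then either retries plainly (when the
+ # reflection is invalid or reports perfection) or performs a guided second rollout; the
+ # reflection itself is only recorded as an exp when it proves useful (a perfect first
+ # attempt, or a retry that improves on the first attempt).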
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, self.env, self.session_id
+ )
+ print(f"[R3L] First rollout - reward: {reward}, steps: {steps}")
+
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ # exp.info = {"valid": format_valid}
+ # print(exp.info)
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ # Set experience ID (eid)
+ exp.eid.task = str(self.task.task_id) + "_explore"
+ exp_run_id = len(exp_lst) + self.run_id_base
+ exp.eid.run = exp_run_id
+ exp_lst.append(exp)
+
+ if self.whether_save_data:
+ # Save first attempt experience data
+ first_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="first"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_first",
+ experience_data=first_record,
+ data_dir=self.train_dir
+ )
+
+ # 对首次尝试进行反思
+ reflect_checklist, reflection_text, reflect_exp = self.get_reflect(trajectory)
+ is_valid, is_perfect = utils.validate_reflect_report(reflect_checklist, steps)
+
+ if not is_valid or is_perfect:
+ # print("Reflect report is invalid or indicates perfection, skipping second rollout")
+ # If the first attempt's reward is 1.0 and the reflection reports perfection, record the reflection exp
+ if reward >= 1.0 and is_perfect and reflect_exp is not None:
+ reflect_exp.reward = 1.0
+ # Set experience ID (eid)
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Run one more rollout so the whole batch still has enough data
+ try:
+ retry_trajectory, retry_reward, retry_done, retry_steps, retry_format_valid = utils.first_rollout(
+ self, self.env, self.session_id
+ )
+
+ retry_exp = self.model.convert_messages_to_experience(retry_trajectory[:-1])
+ retry_exp.reward = retry_reward
+ retry_exp.metrics = {
+ "success": 1.0 if retry_reward >= 1.0 else 0.0,
+ "steps": retry_steps,
+ "reward": retry_reward,
+ }
+ # Set experience ID (eid)
+ retry_exp.eid.task = str(self.task.task_id) + "_explore"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ if self.whether_save_data:
+ # Save retry attempt experience data
+ retry_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=retry_trajectory,
+ reward=retry_reward,
+ steps=retry_steps,
+ success=retry_reward >= 1.0,
+ attempt_type="retry_after_invalid_reflection"
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_retry",
+ experience_data=retry_record,
+ data_dir=self.train_dir
+ )
+ except Exception as e:
+ print(f"Retry rollout after invalid reflection failed: {e}")
+
+ else:
+ guidance_prompt = utils.reflect_report_to_guidance_prompt(reflect_checklist)
+ # Extract retry_step from the validated reflection report (a top-level field in the reflection schema)
+ retry_step = reflect_checklist.get("retry_from_step", 0)
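+ # retry_from_step indexes assistant turns in the first trajectory; 0 means the second
+ # rollout restarts from scratch with only the guidance added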
+
+ try:
+ (
+ distill_trajectory,
+ second_trajectory,
+ second_reward,
+ second_done,
+ second_steps,
+ second_format_valid,
+ ) = utils.second_rollout(
+ self, self.env, self.session_id, guidance_prompt, trajectory, retry_step
+ )
+ print(f"[R3L] Second rollout - reward: {second_reward}, steps: {second_steps}, improve: {second_reward > reward}")
+ second_exp = self.model.convert_messages_to_experience(distill_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(second_exp, retry_step)
+ # Also adjust first rollout exp for fair comparison
+ # Find and modify the exp that was already added to exp_lst
+ for existing_exp in exp_lst:
+ if existing_exp.eid.run == exp_run_id:
+ self._adjust_action_mask_for_retry(existing_exp, retry_step)
+ break
+
+ second_exp.reward = second_reward
+ # second_exp.info = {"valid": second_format_valid}
+ second_exp.metrics = {
+ "second_success": 1.0 if second_reward >= 1.0 else 0.0,
+ "second_steps": second_steps,
+ "second_reward": second_reward,
+ "second_improve": 1.0 if second_reward > reward else 0.0,
+ "second_reward_diff": second_reward - reward,
+ }
+ # Set experience ID (eid)
+ second_exp.eid.task = str(self.task.task_id) + "_explore"
+ second_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(second_exp)
+
+ if self.whether_save_data:
+ # Save second attempt experience data
+ second_record = utils.create_experience_record(
+ task_id=task_id,
+ trajectory=second_trajectory,
+ reward=second_reward,
+ steps=second_steps,
+ success=second_reward >= 1.0,
+ attempt_type="second",
+ additional_metrics={
+ "first_reward": reward,
+ "improvement": second_reward > reward,
+ "reward_difference": second_reward - reward,
+ "step_difference": second_steps - steps
+ }
+ )
+ utils.save_experience_data(
+ task_id=f"{task_id}_attempt_{i}_second",
+ experience_data=second_record,
+ data_dir=self.train_dir
+ )
+
+ # If the second attempt scores higher than the first, or reaches a perfect score in fewer steps, record the reflection and retry data
+ if (second_reward > reward and second_reward >= 1.0) or (second_reward >= 1.0 and second_steps < steps):
+ # Convert the reflection data into an exp
+ # reflect_exp.reward = second_reward - reward
+ reflect_exp.reward = 1.0
+ # Set experience ID (eid)
+ reflect_exp.eid.task = str(self.task.task_id) + f"_reflect_{i}"
+ reflect_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(reflect_exp)
+
+ # Convert the retry data into an exp
+ retry_exp = self.model.convert_messages_to_experience(second_trajectory[:-1])
+
+ # Adjust action_mask to exclude retry prefix from training
+ if retry_step > 0:
+ self._adjust_action_mask_for_retry(retry_exp, retry_step)
+
+ # retry_exp.reward = second_reward - reward
+ retry_exp.reward = 1.0
+ # Set experience ID (eid)
+ retry_exp.eid.task = str(self.task.task_id) + f"_retry_{i}"
+ retry_exp.eid.run = len(exp_lst) + self.run_id_base
+ exp_lst.append(retry_exp)
+
+ # print
+ print("Reflection and retry led to improvement, recording both...")
+ except Exception:
+ pass
+ except Exception:
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/webshop/__init__.py b/trinity/common/workflows/envs/R3L/webshop/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/common/workflows/envs/R3L/webshop/dapo_workflow.py b/trinity/common/workflows/envs/R3L/webshop/dapo_workflow.py
new file mode 100644
index 0000000000..a717b2224d
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/webshop/dapo_workflow.py
@@ -0,0 +1,199 @@
+# -*- coding: utf-8 -*-
+import os
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.webshop import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("dapo_webshop_workflow")
+class DAPOWebshopWorkflow(Workflow):
+ """
+ DAPO Workflow for WebShop environment.
+ Performs rollouts with DAPO-style overlong penalty on response length.
+ No separate reward function needed - penalty computed directly in workflow.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 15
+ self.max_tokens = 512
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # DAPO overlong penalty parameters
+ workflow_args = task.workflow_args or {}
+ self.enable_overlong_penalty = workflow_args.get("enable_overlong_penalty", True)
+ self.penalty_factor = workflow_args.get("penalty_factor", 1.0)
+ self.max_response_length = workflow_args.get("max_response_length", 512)
+ self.cache_length = workflow_args.get("cache_length", 100)
+
+ # Initialize WebShop environment
+ try:
+ import sys
+ # Add WebShop path - can be overridden via WEBSHOP_PATH environment variable
+ webshop_path = os.environ.get("WEBSHOP_PATH")
+ if webshop_path:
+ sys.path.append(webshop_path)
+ else:
+ # sys.path.append("/nas/shiweijie/trinity/webshop")
+ sys.path.append("/home/wshiah/code/shiweijie/weijie/trinity/webshop")
+ import gym
+ from web_agent_site.envs import WebAgentTextEnv # noqa: F401
+
+ # NOTE: Hosting the env requires ~15GB CPU memory.
+ # If you want easier env, you can set the num_products to 1000 or 100000.
+ self.env = gym.make(
+ "WebAgentTextEnv-v0",
+ observation_mode="text_rich",
+ num_products=None,
+ human_goals=True,
+ )
+ except Exception as e:
+ error_message = (
+ f"Failed to initialize WebShop environment: {e}. "
+ f"Please ensure web_agent_site is installed and accessible."
+ )
+ print(error_message)
+ raise RuntimeError(error_message)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.webshop_system_template = self.jinja_env.get_template("webshop_system.j2")
+
+ print(
+ f"Initializing DAPOWebshopWorkflow, temperature={self.temperature}, "
+ f"overlong_penalty={'enabled' if self.enable_overlong_penalty else 'disabled'}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.session_id = task.task_desc or "0"
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ # Update DAPO parameters if provided
+ workflow_args = task.workflow_args or {}
+ if "enable_overlong_penalty" in workflow_args:
+ self.enable_overlong_penalty = workflow_args["enable_overlong_penalty"]
+ if "penalty_factor" in workflow_args:
+ self.penalty_factor = workflow_args["penalty_factor"]
+ if "max_response_length" in workflow_args:
+ self.max_response_length = workflow_args["max_response_length"]
+ if "cache_length" in workflow_args:
+ self.cache_length = workflow_args["cache_length"]
+
+ def compute_overlong_penalty(self, response_tokens: torch.Tensor) -> float:
+ """
+ Compute DAPO-style overlong penalty based on response token length.
+
+ Args:
+ response_tokens: Response tokens (tensor)
+
+ Returns:
+ Penalty score (non-positive float)
+ """
+ if not self.enable_overlong_penalty:
+ return 0.0
+
+ response_len = len(response_tokens)
+ expected_len = self.max_response_length - self.cache_length
+
+ if response_len < expected_len:
+ # No penalty for short responses
+ return 0.0
+ elif response_len > self.max_response_length:
+ # Fixed penalty for excessively long responses
+ return -self.penalty_factor
+ else:
+ # Linear penalty in the transition zone
+ return (expected_len - response_len) / self.cache_length * self.penalty_factor
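+
+ # Worked example (illustrative sketch using the default parameters above:
+ # max_response_length=512, cache_length=100, penalty_factor=1.0, so expected_len=412):
+ # a 300-token response gets 0.0, a 462-token response gets (412 - 462) / 100 * 1.0 = -0.5,
+ # and a 600-token response gets the fixed penalty -1.0.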
+
+ def run(self) -> List[Experience]:
+ """Run the DAPO workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_webshop(self)
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, self.env, self.session_id
+ )
+ print(f"[DAPO WebShop] Rollout - reward: {reward}, steps: {steps}")
+
+ # Convert trajectory to experience
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+
+ # Extract response tokens from experience
+ response_tokens = exp.tokens[exp.prompt_length:]
+
+ # Compute DAPO overlong penalty (format score)
+ format_score = self.compute_overlong_penalty(response_tokens)
+
+ # Calculate accuracy score
+ accuracy_score = 1.0 if reward >= 1.0 else 0.0
+
+ # Total reward = accuracy + format_score
+ total_reward = accuracy_score + format_score
+
+ # Update experience reward and metrics
+ exp.reward = total_reward
+ exp.metrics = {
+ "success": accuracy_score,
+ "steps": steps,
+ "env_reward": reward,
+ "accuracy": accuracy_score,
+ "format_score": format_score,
+ "response_length": len(response_tokens),
+ "total_reward": total_reward,
+ }
+
+ # Set experience ID
+ exp.eid.task = str(self.task.task_id)
+ exp.eid.run = i + self.run_id_base
+
+ exp_lst.append(exp)
+ except Exception as e:
+ print(f"[DAPO WebShop] Rollout failed: {e}")
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/webshop/grpo_workflow.py b/trinity/common/workflows/envs/R3L/webshop/grpo_workflow.py
new file mode 100644
index 0000000000..2b418db3af
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/webshop/grpo_workflow.py
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+import os
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.webshop import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("grpo_baseline_webshop_workflow")
+class GRPOBaselineWebshopWorkflow(Workflow):
+ """
+ GRPO Baseline Workflow for WebShop environment.
+ Performs simple rollouts without reflection or learning from experience.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 15
+ self.max_tokens = 512
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize WebShop environment
+ try:
+ import sys
+ # Add WebShop path - can be overridden via WEBSHOP_PATH environment variable
+ webshop_path = os.environ.get("WEBSHOP_PATH")
+ if webshop_path:
+ sys.path.append(webshop_path)
+ else:
+ # sys.path.append("/nas/shiweijie/trinity/webshop")
+ sys.path.append("/home/wshiah/code/shiweijie/weijie/trinity/webshop")
+ import gym
+ from web_agent_site.envs import WebAgentTextEnv # noqa: F401
+
+ # NOTE: Hosting the env requires ~15GB CPU memory.
+ # If you want an easier environment, set num_products to 1000 or 100000.
+ self.env = gym.make(
+ "WebAgentTextEnv-v0",
+ observation_mode="text_rich",
+ num_products=None,
+ human_goals=True,
+ )
+ except Exception as e:
+ error_message = (
+ f"Error importing WebAgentTextEnv {str(e)}. "
+ f"Please make sure you have installed the web_agent_site package, "
+ f"following the instructions in https://github.com/princeton-nlp/WebShop"
+ )
+ raise ImportError(error_message)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.webshop_system_template = self.jinja_env.get_template("webshop_system.j2")
+
+ print(
+ f"Initializing GRPOBaselineWebshopWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.session_id = int(task.task_desc or "0")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the GRPO baseline workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_webshop(self)
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, self.env, self.session_id
+ )
+ print(f"[GRPO] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ # Skip failed rollouts; the remaining repeats still run
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/webshop/opmd_workflow.py b/trinity/common/workflows/envs/R3L/webshop/opmd_workflow.py
new file mode 100644
index 0000000000..58df14b9dd
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/webshop/opmd_workflow.py
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+import os
+from pathlib import Path
+from typing import List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.webshop import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("opmd_baseline_webshop_workflow")
+class OPMDBaselineWebshopWorkflow(Workflow):
+ """
+ OPMD Baseline workflow for WebShop environment.
+ Performs rollouts for offline policy model distillation.
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 15
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize WebShop environment
+ try:
+ import sys
+ # Add WebShop path - can be overridden via WEBSHOP_PATH environment variable
+ webshop_path = os.environ.get("WEBSHOP_PATH")
+ if webshop_path:
+ sys.path.append(webshop_path)
+ else:
+ # sys.path.append("/nas/shiweijie/trinity/webshop")
+ sys.path.append("/home/wshiah/code/shiweijie/weijie/trinity/webshop")
+ import gym
+ from web_agent_site.envs import WebAgentTextEnv # noqa: F401
+
+ # NOTE: Hosting the env requires ~15GB CPU memory.
+ # If you want an easier environment, set num_products to 1000 or 100000.
+ self.env = gym.make(
+ "WebAgentTextEnv-v0",
+ observation_mode="text_rich",
+ num_products=None,
+ human_goals=True,
+ )
+ except Exception as e:
+ error_message = (
+ f"Error importing WebAgentTextEnv {str(e)}. "
+ f"Please make sure you have installed the web_agent_site package, "
+ f"following the instructions in https://github.com/princeton-nlp/WebShop"
+ )
+ raise ImportError(error_message)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.webshop_system_template = self.jinja_env.get_template("webshop_system.j2")
+
+ print(
+ f"Initializing OPMDWebshopWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.session_id = int(task.task_desc or "0")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the OPMD workflow and return experiences"""
+
+ if self.is_eval:
+ return utils.eval_webshop(self)
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, self.env, self.session_id
+ )
+ print(f"[OPMD] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ exp_lst.append(exp)
+ except Exception:
+ # Skip failed rollouts; the remaining repeats still run
+ pass
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/webshop/prompts/reflection.j2 b/trinity/common/workflows/envs/R3L/webshop/prompts/reflection.j2
new file mode 100644
index 0000000000..7137e32c69
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/webshop/prompts/reflection.j2
@@ -0,0 +1,28 @@
+You are a Reflector that analyzes trajectory logs based on user and environment feedback. Your goal is to identify what went wrong, trace root causes, and extract reusable principles for future improvement. Review the trajectory and feedback to understand the strategy and outcome. Through Socratic-style iterative "why" questioning, trace issues back to their fundamental flawed assumptions or mental models. Then formulate an actionable principle and suggest where to retry if needed.
+
+Please output in the following JSON format:
+
+```json
+{
+"trajectory_summary": "Concise overview in 1-3 sentences covering: (1) the strategy or approach employed by the agent, (2) the final result or outcome achieved, (3) key observations about execution quality (e.g., efficiency, correctness, optimality).",
+"root_cause_analysis": "Deep causal analysis using iterative 'why' questioning to trace from observable symptoms back to the fundamental root cause (flawed assumption, incorrect mental model, or critical knowledge gap). Chain your reasoning explicitly (e.g., 'Why X? Because Y. Why Y? Because Z.'). Identify the deepest underlying issue, not just surface-level errors. Set to null only if execution was truly flawless.",
+"trajectory_outcome": "Classification of the trajectory result. Must be EXACTLY one of these three values (case-sensitive, including underscores): 'success' (goal fully achieved with optimal execution quality), 'success_but_inefficient' (goal achieved but with unnecessary steps, redundant actions, or suboptimal approach), 'failure' (goal not achieved or task incomplete).",
+"improvement_suggestion": "A generalizable, context-complete principle for avoiding similar issues in future attempts. Must be self-contained and actionable without reference to this specific trajectory. Include: (1) the specific environment/system/domain name (WebShop e-commerce navigation), (2) the triggering conditions or scenario when this applies, (3) the specific problem or pitfall to avoid, (4) the recommended solution or approach with clear rationale. Frame as reusable knowledge. Set to null if and only if trajectory_outcome is 'success'.",
+"retry_from_step": "Integer from 0 to N-1 identifying the earliest step where the root cause first manifested or where a corrected decision could alter the outcome. This represents the optimal restart point if given one opportunity to retry. Use 0 when the root cause traces to initial strategy selection or foundational assumptions. Set to null if trajectory_outcome is 'success' or if retry would not be beneficial."
+}
+```
+
+## Example
+
+**Scenario**: Buying "wireless mouse under $20" in WebShop
+
+**Example Output**:
+```json
+{
+"trajectory_summary": "The agent searched for 'wireless mouse' and immediately clicked on the first search result without checking the price. After viewing the product page showing a $35 wireless mouse, the agent added it to cart and purchased it. The purchase was completed successfully but violated the price constraint (<$20), resulting in task failure despite correct product type.",
+"root_cause_analysis": "Why did the agent buy an overpriced mouse? Because it added to cart without price verification. Why skip price verification? Because after finding a product matching 'wireless mouse', the agent treated type-matching as sufficient for task completion. Why ignore price constraints? Because the agent parsed the task instruction incompletely, extracting only the product type keyword while discarding the critical price constraint '<$20'. Root cause: Incomplete task decomposition - the agent optimized for keyword matching (wireless mouse) rather than constraint satisfaction (wireless mouse AND price<$20), treating multi-constraint tasks as single-attribute search problems.",
+"trajectory_outcome": "failure",
+"improvement_suggestion": "In WebShop product search tasks, extract and validate ALL constraints before purchase. For queries with multiple requirements: (1) parse task instruction to identify all constraints (product type, price range, features, ratings), (2) during product search, filter or verify each constraint systematically, (3) on product pages, explicitly check constraint satisfaction (read price, verify features) before adding to cart, (4) treat any constraint violation as disqualifying. For 'wireless mouse under $20': first search 'wireless mouse', then on each product page check price<$20 before proceeding. Only items satisfying ALL constraints should reach checkout. A correct-type but wrong-price purchase is a failure, not a partial success.",
+"retry_from_step": 1
+}
+```
diff --git a/trinity/common/workflows/envs/R3L/webshop/prompts/self_correction.j2 b/trinity/common/workflows/envs/R3L/webshop/prompts/self_correction.j2
new file mode 100644
index 0000000000..3c0a34b676
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/webshop/prompts/self_correction.j2
@@ -0,0 +1,5 @@
+Your previous attempt encountered issues. Below is a reflection based on user and environment feedback:
+
+{{ report }}
+
+Apply the lessons learned from this reflection to avoid repeating the same mistakes. Do not mention or reference this guidance in your response.
diff --git a/trinity/common/workflows/envs/R3L/webshop/prompts/webshop_system.j2 b/trinity/common/workflows/envs/R3L/webshop/prompts/webshop_system.j2
new file mode 100644
index 0000000000..9b63586216
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/webshop/prompts/webshop_system.j2
@@ -0,0 +1,84 @@
+You are an agent interacting with a virtual text-based web shopping environment.
+
+## Response Format:
+You MUST use this exact format for every response. All tags are REQUIRED in sequential order:
+
+<think>your analytical reasoning and thought process</think>
+<action>exactly one specific action command</action>
+
+## Environment States:
+This virtual text-based web shopping environment contains five types of webpages:
+
+1. **Start/Index page** - Initial page with search functionality and task instruction
+2. **Search Results page** - Lists products returned by search engine with pagination
+3. **Item page** - Shows product details, options (color, size, etc.), and purchase button
+4. **Item Sub-page** - Shows additional product information (description, features, reviews)
+5. **Done page** - Final confirmation page after purchase
+
+## Available Actions:
+The command in `<action>` must use one of the following two primitive formats:
+
+1. **`search[your_query_here]`**
+- **Usage:** To search for products from any page with a search bar
+- **Instructions:** Replace with specific search terms (can be multi-word)
+- **Example:** `search[blue cotton t-shirt medium]`
+
+2. **`click[exact_button_text_here]`**
+- **Usage:** To click on any clickable element (buttons, product links, options)
+- **Instructions:** Use the exact text as shown in observation (case-insensitive)
+- **Examples:**
+- `click[Buy Now]`
+- `click[Next >]`
+- `click[Size: Large]`
+- `click[Color: Red]`
+
+## Complete State Transition Table:
+
+| Current State | Action Type | Argument | Next State | Notes |
+|---------------|-------------|----------|------------|-------|
+| Start/Index | search | [Query] | Search Results | Initial search from homepage |
+| Search Results | search | [Query] | Search Results | New search resets results |
+| Search Results | click | [Product Title/ASIN] | Item Page | Select specific product |
+| Search Results | click | Next > | Search Results | Next page of results |
+| Search Results | click | < Prev | Search Results | Previous page of results |
+| Item Page | click | [Option Value] | Item Page | Select size/color/etc. (radio buttons) |
+| Item Page | click | Description | Item Sub-page | View product description |
+| Item Page | click | Features | Item Sub-page | View product features |
+| Item Page | click | Reviews | Item Sub-page | View product reviews |
+| Item Page | click | Buy Now | Done Page | **Purchase and end episode** |
+| Item Page | click | < Back to Search | Search Results | Return to search results |
+| Item Page | click | < Prev | Search Results | Return to search results |
+| Item Sub-page | click | < Prev | Item Page | Return to main product page |
+| Any Page | search | [Query] | Search Results | Start new search |
+
+## Key Implementation Details:
+
+### Clickable Elements:
+- **Buttons:** `[button] Text [button_]` → use `click[Text]`
+- **Product Links:** Product titles/ASINs → use `click[Product Name]`
+- **Options:** Radio buttons for size, color, etc. → use `click[Option Value]`
+- **Navigation:** `< Prev`, `Next >`, `< Back to Search`
+
+### Page Identification:
+You can identify the current page type by observing:
+- **Start page:** Contains initial instruction and search bar
+- **Search Results:** Lists multiple products with pagination controls
+- **Item Page:** Shows single product with options and "Buy Now" button
+- **Item Sub-page:** Shows detailed info without "Buy Now" button
+- **Done Page:** Shows purchase confirmation
+
+### Important Navigation Rules:
+1. **From Item Sub-pages:** You MUST click `< Prev` to return to Item Page before purchasing
+2. **Option Selection:** Selecting options (size, color) stays on the same Item Page
+3. **Search Resets:** Using search from any page starts a new product search
+4. **Purchase Requirement:** You can only purchase from the Item Page, not sub-pages
+
+## Task Completion:
+- **Goal:** Find and purchase an item matching the given instruction within 15 steps
+- **Success:** Episode ends when you click "Buy Now" with appropriate product and options
+
+## Observation Format:
+- Clickable elements appear as: `[button] Text [button_]`
+- Selected options may show as: `[clicked button] Text [clicked button_]`
+- Regular text appears without special formatting
+- The instruction text shows your shopping goal
\ No newline at end of file
diff --git a/trinity/common/workflows/envs/R3L/webshop/raft_workflow.py b/trinity/common/workflows/envs/R3L/webshop/raft_workflow.py
new file mode 100644
index 0000000000..5cc501a59e
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/webshop/raft_workflow.py
@@ -0,0 +1,158 @@
+# -*- coding: utf-8 -*-
+import copy
+import os
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.R3L.webshop import utils
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("RAFT_baseline_webshop_workflow")
+class RAFTBaselineWebshopWorkflow(Workflow):
+ """
+ RAFT Baseline workflow for WebShop environment.
+ Performs rollouts and keeps only successful trajectories (reward-ranked fine-tuning).
+ """
+
+ can_reset: bool = True
+ can_repeat: bool = True
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.max_env_steps = 15
+ self.max_tokens = 512
+ self.max_reflect_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+ self.whether_save_data = False
+
+ # Initialize WebShop environment
+ try:
+ import sys
+ # Add WebShop path - can be overridden via WEBSHOP_PATH environment variable
+ webshop_path = os.environ.get("WEBSHOP_PATH")
+ if webshop_path:
+ sys.path.append(webshop_path)
+ else:
+ # sys.path.append("/nas/shiweijie/trinity/webshop")
+ sys.path.append("/home/wshiah/code/shiweijie/weijie/trinity/webshop")
+ import gym
+ from web_agent_site.envs import WebAgentTextEnv # noqa: F401
+
+ # NOTE: Hosting the env requires ~15GB CPU memory.
+ # If you want an easier environment, set num_products to 1000 or 100000.
+ self.env = gym.make(
+ "WebAgentTextEnv-v0",
+ observation_mode="text_rich",
+ num_products=None,
+ human_goals=True,
+ )
+ except Exception as e:
+ error_message = (
+ f"Error importing WebAgentTextEnv {str(e)}. "
+ f"Please make sure you have installed the web_agent_site package, "
+ f"following the instructions in https://github.com/princeton-nlp/WebShop"
+ )
+ raise ImportError(error_message)
+
+ # Initialize Jinja2 templates
+ prompts_dir = Path(__file__).parent / "prompts"
+ self.jinja_env = Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+ # Cache templates to avoid repeated loading
+ self.webshop_system_template = self.jinja_env.get_template("webshop_system.j2")
+
+ print(
+ f"Initializing RAFTWebshopWorkflow, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ # Default experience for error cases
+ self.default_exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": -0.1,
+ },
+ reward=-0.1
+ )
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.session_id = int(task.task_desc or "0")
+ self.is_eval = task.is_eval
+ self.task = task
+ self.n = task.repeat_times
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT workflow and return experiences"""
+
+ if self.is_eval:
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, self.env, self.session_id
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ return [exp]
+ except Exception:
+ return [copy.deepcopy(self.default_exp)]
+
+ # Single rollout execution
+ exp_lst = []
+ for i in range(self.n):
+ try:
+ trajectory, reward, done, steps, format_valid = utils.first_rollout(
+ self, self.env, self.session_id
+ )
+ print(f"[RAFT] First rollout - reward: {reward}, steps: {steps}")
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ # RAFT only uses successful samples
+ if reward >= 1.0:
+ exp_lst.append(exp)
+ else:
+ exp_lst.append(copy.deepcopy(self.default_exp))
+ except Exception:
+ exp_lst.append(copy.deepcopy(self.default_exp))
+
+ return exp_lst
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
+ self.n = repeat_times
diff --git a/trinity/common/workflows/envs/R3L/webshop/utils.py b/trinity/common/workflows/envs/R3L/webshop/utils.py
new file mode 100644
index 0000000000..fba272dd98
--- /dev/null
+++ b/trinity/common/workflows/envs/R3L/webshop/utils.py
@@ -0,0 +1,617 @@
+import json
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from jinja2 import Environment, FileSystemLoader
+import torch
+from trinity.common.experience import Experience
+
+
+def first_rollout(self, env, session_id) -> tuple[List[Dict[str, str]], float, bool, int, bool]:
+ """Run a single rollout"""
+ # print(f"About to reset env with session_id: {session_id}")
+ env.reset(session=session_id)
+ observation = env.observation
+ trajectory = []
+ action_history = [] # Track the last 2 actions for repetition detection
+
+ system_prompt = self.webshop_system_template.render()
+ trajectory.append({"role": "system", "content": system_prompt})
+
+ default_reward = -0.1
+ reward = default_reward
+ valid_format = True
+ step = 0
+
+ for step in range(self.max_env_steps):
+ available_actions = env.get_available_actions()
+ trajectory.append(
+ {"role": "user", "content": format_observation(observation, available_actions)}
+ )
+
+ # Get model response with experience guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+ # The tokenizer inside chat truncates the input, so terminate as soon as the length limit is reached
+ return trajectory, default_reward, False, step + 1, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the three components for action execution
+ think, action = parse_response(response_text)
+ if action is None:
+ valid_format = False
+ feedback = "Invalid response format, missing valid or tags, please ensure to follow the output format strictly: ... ..."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ # print(f"Terminating due to invalid response format: {response_text}")
+ return trajectory, default_reward, False, step + 1, valid_format
+
+ # Check for consecutive action repetition
+ action_history.append(action)
+ if len(action_history) > 2:
+ action_history.pop(0)
+
+ # If last 2 actions are the same, terminate with failure
+ if len(action_history) >= 2 and all(
+ prev_action == action_history[0] for prev_action in action_history
+ ) and "next" not in action.lower() and "prev" not in action.lower() and "search" not in action.lower():
+ feedback = f"Repeated invalid action {action} multiple times, shopping task failed"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ # print(f"Terminating due to 5 consecutive identical actions: {action_text}")
+ valid_format = False
+ return trajectory, default_reward, False, step + 1, valid_format
+
+ # Validate and execute action in environment
+ action_valid, error_msg = validate_action(action, available_actions)
+ if action_valid:
+ observation, reward, done, info = env.step(action)
+ else:
+ observation, reward, done = error_msg, default_reward, False
+
+ if done:
+ break
+
+ # Generate final outcome feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Shopping task completed successfully (reward: {reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Shopping task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+ f"Shopping task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps). It may not satisfy the Attribute Matching, Option Matching, or Price Matching requirements, please you carefully check and ensure all requirements are satisfied."
+ )
+ else:
+ feedback = (
+ f"Shopping task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps). It may not satisfy the Attribute Matching, Option Matching, or Price Matching requirements, please you carefully check and ensure all requirements are satisfied."
+ )
+
+ # Add final outcome feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ return trajectory, reward, False, step + 1, valid_format
+
+def second_rollout(
+ self,
+ env,
+ session_id: int,
+ guidance_prompt: str,
+ first_trajectory: List[Dict[str, str]],
+ retry_step: int = 0,
+) -> tuple[List[Dict[str, str]], List[Dict[str, str]], float, bool, int, bool]:
+ """
+ Performs rollout starting from a specific retry step, reusing previous responses.
+
+ Args:
+ env: The environment instance.
+ session_id: The ID for the current task session.
+ guidance_prompt: The pre-generated guidance from reflection.
+ first_trajectory: The full log of the initial attempt.
+ retry_step: The step to start retry from (0-based, 0 means from beginning).
+
+ Returns:
+ A tuple containing (distill_trajectory, second_trajectory, reward, done status,
+ step count, and format validity).
+ """
+
+ # Reset environment to start fresh
+ env.reset(session=session_id)
+ observation = env.observation
+ trajectory = []
+ distill_trajectory = []
+ action_history = [] # Track the last 2 actions for repetition detection
+
+ # Prepare system prompts
+ original_system_prompt = self.webshop_system_template.render()
+
+ default_reward = -0.1
+ reward = default_reward
+ done = False
+ valid_format = True
+
+ # Copy responses from first trajectory up to retry_step
+ step = 0
+ if retry_step > 0:
+ # Add original system prompt only
+ trajectory.append({"role": "system", "content": original_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ # Replay first trajectory up to retry_step to restore environment state
+ first_step = 0
+ for msg in first_trajectory[1:]: # Skip system message
+ if msg["role"] == "user":
+ # This is an observation - copy it and continue
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+ elif msg["role"] == "assistant":
+ if first_step < retry_step:
+ # Copy the assistant response from first trajectory
+ trajectory.append(msg)
+ distill_trajectory.append(msg)
+
+ # Execute the action to restore environment state
+ think, action = parse_response(msg["content"])
+ if think is not None and action is not None:
+ action_valid, error_msg = validate_action(action, env.get_available_actions())
+ if action_valid:
+ observation, reward, done, info = env.step(action)
+ action_history.append(action)
+ if len(action_history) > 2:
+ action_history.pop(0)
+ else:
+ # If action becomes invalid during replay, start from beginning
+ retry_step = 0
+ break
+ first_step += 1
+ step = first_step
+
+ if done:
+ # If environment finished during replay, no need to continue
+ return distill_trajectory, trajectory, reward, done, step, valid_format
+ else:
+ break
+
+ # Add guidance prompt as a separate system message before retry point
+ guidance_system_msg = {"role": "system", "content": f"# Previous Attempt Analysis & Guidance\n{guidance_prompt}"}
+ trajectory.append(guidance_system_msg)
+ # Don't add guidance to distill_trajectory to keep it clean
+
+ else:
+ # Starting from beginning - add system prompt with guidance
+ merged_system_prompt = f"{original_system_prompt}\n\n# Previous Attempt Analysis & Guidance\n{guidance_prompt}"
+ trajectory.append({"role": "system", "content": merged_system_prompt})
+ distill_trajectory.append({"role": "system", "content": original_system_prompt})
+
+ for step in range(step, self.max_env_steps):
+ available_actions = env.get_available_actions()
+ trajectory.append(
+ {"role": "user", "content": format_observation(observation, available_actions)}
+ )
+ distill_trajectory.append(
+ {"role": "user", "content": format_observation(observation, available_actions)}
+ )
+
+ # Get model response with guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ # Check if tokens exceed limit
+ if responses[0].tokens.shape[0] >= 20480 - 512:
+ # The tokenizer inside chat truncates the input, so terminate as soon as the length limit is reached
+ return distill_trajectory, trajectory, default_reward, False, step + 1, False
+
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+ distill_trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the response
+ think, action = parse_response(response_text)
+ if think is None or action is None:
+ valid_format = False
+ feedback = "Invalid response format, missing valid or tags, please ensure to follow the output format strictly: ... ..."
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ return distill_trajectory, trajectory, default_reward, False, step + 1, valid_format
+
+ # Check for consecutive action repetition
+ action_history.append(action)
+ if len(action_history) > 2:
+ action_history.pop(0)
+
+ # If last 2 actions are the same, terminate with failure
+ if len(action_history) >= 2 and all(
+ prev_action == action_history[0] for prev_action in action_history
+ ) and "next" not in action.lower() and "prev" not in action.lower() and "search" not in action.lower():
+ feedback = f"Repeated invalid action {action} multiple times, shopping task failed"
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ valid_format = False
+ return distill_trajectory, trajectory, default_reward, False, step + 1, valid_format
+
+ # Validate and execute action in environment
+ action_valid, error_msg = validate_action(action, available_actions)
+ if action_valid:
+ observation, reward, done, info = env.step(action)
+ else:
+ observation, reward, done = error_msg, default_reward, False
+
+ if done:
+ break
+
+ # Generate feedback
+ if reward >= 1.0 and step + 1 < self.max_env_steps:
+ feedback = f"Shopping task completed successfully (reward: {reward}/1.0), and satisfying the step limit ({step + 1}/{self.max_env_steps} steps)"
+ elif reward >= 1.0 and step + 1 >= self.max_env_steps:
+ feedback = (
+ f"Shopping task completed successfully (reward: {reward}/1.0), but exceeded the step limit ({step + 1}/{self.max_env_steps} steps)"
+ )
+ elif reward < 1.0 and step + 1 < self.max_env_steps:
+ feedback = (
+ f"Shopping task not completed (reward: {reward}/1.0), but within the step limit ({step + 1}/{self.max_env_steps} steps). It may not satisfy the Attribute Matching, Option Matching, or Price Matching requirements, please you carefully check and ensure all requirements are satisfied."
+ )
+ else:
+ feedback = (
+ f"Shopping task not completed (reward: {reward}/1.0), and exceeded the step limit ({step + 1}/{self.max_env_steps} steps). It may not satisfy the Attribute Matching, Option Matching, or Price Matching requirements, please you carefully check and ensure all requirements are satisfied."
+ )
+
+ # Add feedback to trajectory
+ trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+ distill_trajectory.append({"role": "user", "content": f"Feedback: {feedback}"})
+
+ # Return the distilled trajectory (guidance messages excluded) alongside the full guided trajectory
+ return distill_trajectory, trajectory, reward, False, step + 1, valid_format
+
+def eval_webshop(self) -> List[Experience]:
+ """Evaluate a single webshop trajectory"""
+ try:
+ trajectory, reward, done, steps, valid_format = first_rollout(
+ self, self.env, self.session_id
+ )
+ exp = self.model.convert_messages_to_experience(trajectory[:-1])
+ exp.reward = reward
+ exp.metrics = {
+ "success": 1.0 if reward >= 1.0 else 0.0,
+ "steps": steps,
+ "reward": reward,
+ }
+ print(f"[WebShop Eval] Rollout - reward: {reward}, steps: {steps}, valid_format: {valid_format}")
+
+ if self.whether_save_data:
+ # Save evaluation data
+ eval_task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ eval_record = create_experience_record(
+ task_id=eval_task_id,
+ trajectory=trajectory,
+ reward=reward,
+ steps=steps,
+ success=reward >= 1.0,
+ attempt_type="evaluation"
+ )
+ save_experience_data(
+ task_id=f"{eval_task_id}_eval",
+ experience_data=eval_record,
+ data_dir=self.eval_dir
+ )
+ return [exp]
+ except Exception as e:
+ # logger.warning(f"Single rollout failed during eval: {e}")
+ task_id = f"{str(self.task.batch_id).replace('/', '_')}_{self.task.task_id}"
+ exp = Experience(
+ tokens=torch.tensor([0, 0], dtype=torch.long),
+ prompt_length=1,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ logprobs=torch.tensor([0.0], dtype=torch.float),
+ metrics={
+ "success": 0.0,
+ "reward": -0.1,
+ }
+ )
+ exp.reward = -0.1
+ return [exp]
+
+def _get_jinja_env():
+ """Initialize Jinja2 environment for template loading."""
+ prompts_dir = Path(__file__).parent / "prompts"
+ return Environment(
+ loader=FileSystemLoader(str(prompts_dir)),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+
+
+def format_observation(observation: str, available_actions: dict):
+ """Format observation with format reminder for each turn"""
+ formatted_prompt = f"""Environment Observation: {observation}
+Available Actions: {available_actions}
+
+Now it's your turn to take an action.
+You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> tags.
+Once you've finished your reasoning, you should choose an action and present it within <action> tags.
+
+Format: <think>your reasoning process</think> <action>your chosen action</action>"""
+ return formatted_prompt
+
+
+def parse_response(response):
+ """Parse all three components from response with a single regex"""
+ think, action = None, None
+ try:
+ # Use a single regex to extract both components at once
+ pattern = r"<think>\s*(.*?)\s*</think>.*?<action>\s*(.*?)\s*</action>"
+ match = re.search(pattern, response, re.DOTALL)
+
+ if match:
+ think, action = match.group(1).strip(), match.group(2).strip()
+ except Exception:
+ pass
+ return think, action
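+
+# Illustrative usage (a sketch; the tag names assume the <think>/<action> format used above):
+# parse_response("<think>price is under $20</think><action>click[Buy Now]</action>")
+# returns ("price is under $20", "click[Buy Now]"), while a response missing either
+# tag returns (None, None).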
+
+
+def validate_action(action, available_actions):
+ """Validate action format and availability"""
+
+ # Parse action format: action_name[action_arg]
+ pattern = re.compile(r"(.+)\[(.+)\]")
+ m = re.match(pattern, action)
+ if m is None:
+ return (
+ False,
+ "Invalid action format. You should use format: action_name[action_arg], like search[query] or click[button].",
+ )
+
+ action_name, action_arg = m.groups()
+ action_name = action_name.strip()
+ action_arg = action_arg.strip()
+
+ # Validate search action
+ if action_name == "search":
+ if not action_arg:
+ return (
+ False,
+ "Invalid search action, please type in the query you want to search in the square brackets.",
+ )
+ if not available_actions["has_search_bar"]:
+ return (
+ False,
+ "Cannot perform search action without search bar. Please click the Back to Search button first.",
+ )
+ return True, ""
+
+ # Validate click action
+ elif action_name == "click":
+ if not action_arg:
+ return (
+ False,
+ "Invalid click action, please specify the button name in the square brackets.",
+ )
+ # Convert to lowercase for comparison (as clickables are typically lowercase)
+ action_arg_lower = action_arg.lower()
+ if action_arg_lower not in available_actions["clickables"]:
+ return (
+ False,
+ f"Button '{action_arg}' not found on current page. Available buttons: {available_actions['clickables']}",
+ )
+ return True, ""
+
+ # Unknown action
+ else:
+ return (
+ False,
+ f"Unknown action '{action_name}'. Only 'search' and 'click' actions are supported.",
+ )
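+
+# Illustrative usage (a sketch with hypothetical values): given
+# available_actions = {"has_search_bar": True, "clickables": ["buy now", "back to search"]},
+# validate_action("click[Buy Now]", available_actions) returns (True, "") because the
+# argument is lowercased before matching the clickables, whereas
+# validate_action("buy[item]", available_actions) returns (False, "Unknown action 'buy'. ...").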
+
+
+def format_trajectory_for_reflection(trajectory: List[Dict[str, str]]) -> str:
+ """
+ Correctly formats the trajectory for reflection, including the system prompt
+ and numbering the user/assistant turns.
+ """
+ formatted_lines = []
+ # Use a counter to track user/assistant interaction turns
+ turn_counter = 0 # counting starts from 0
+
+ for msg in trajectory:
+ role = msg["role"]
+ content = msg["content"]
+
+ if role == "system":
+ # The system prompt does not count as a step, but must be shown first as the rules/context
+ formatted_lines.append(f"**System Rules & Context:**\n{content}\n" + "=" * 30)
+ elif role == "user":
+ # Mark the start of a new turn
+ formatted_lines.append(f"\n**Step {turn_counter}**")
+ formatted_lines.append(f" - User Observation/Feedback:\n {content.strip()}")
+ elif role == "assistant":
+ # The assistant's thought and action
+ formatted_lines.append(f" - Assistant Thought & Action:\n {content.strip()}")
+ # Increment the turn counter after a complete user-assistant exchange
+ turn_counter += 1
+
+ return "\n".join(formatted_lines)
+
+
+def validate_reflect_report(report: Dict[str, Any], total_steps: int) -> tuple[bool, bool]:
+ """
+ Validates the structure and content of the reflection report
+ based on the webshop reflection.j2 schema.
+
+ Args:
+ report: The reflection report to validate
+ total_steps: Maximum number of steps in trajectory for retry_step bounds checking
+
+ Returns:
+ tuple[bool, bool]: (is_valid, is_perfect)
+ - is_valid: Whether the report structure is valid
+ - is_perfect: Whether the report indicates the trajectory is perfect (only meaningful if is_valid is True)
+ """
+ if (
+ not isinstance(report, dict)
+ or "trajectory_summary" not in report
+ or "root_cause_analysis" not in report
+ or "trajectory_outcome" not in report
+ ):
+ print("[R3L WebShop Validation] Report is not a dict or missing keys.")
+ return False, False
+
+ outcome = report["trajectory_outcome"]
+
+ if outcome == "success":
+ # For success, we only need summary and no flaw analysis
+ print("[R3L WebShop Validation] success report validation successful.")
+ return True, True
+
+ elif outcome in ["success_but_inefficient", "failure"]:
+ # For non-optimal outcomes, validate required fields
+ improvement_suggestion = report.get("improvement_suggestion", None)
+ retry_from_step = report.get("retry_from_step", None)
+
+ if improvement_suggestion is None or retry_from_step is None:
+ print("[R3L WebShop Validation] Missing 'improvement_suggestion' or 'retry_from_step'.")
+ return False, False
+
+ # check retry from step
+ try:
+ retry_from_step = int(retry_from_step)
+ except (ValueError, TypeError):
+ print(f"[R3L WebShop Validation] 'retry_from_step' must be an integer. Got: {retry_from_step}")
+ return False, False
+ if retry_from_step < 0:
+ print(f"[R3L WebShop Validation] 'retry_from_step' must be a non-negative integer. Got: {retry_from_step}")
+ return False, False
+ # Check trajectory bounds if total_steps is provided
+ if total_steps is not None:
+ if retry_from_step >= total_steps:
+ print(
+ f"[R3L WebShop Validation] 'retry_from_step' ({retry_from_step}) exceeds trajectory bounds (0 to {total_steps - 1}).")
+ return False, False
+ print(f"[R3L WebShop Validation] {outcome} report validation successful.")
+ return True, False
+ else:
+ print(f"[R3L WebShop Validation] Invalid trajectory_outcome: {outcome}")
+ return False, False
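+
+# Illustrative usage (a sketch with a made-up report): a dict such as
+# {"trajectory_summary": "...", "root_cause_analysis": "...", "trajectory_outcome": "failure",
+#  "improvement_suggestion": "...", "retry_from_step": 1} passed with total_steps=5
+# returns (True, False); an outcome of "success" returns (True, True) without
+# checking the retry/improvement fields.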
+
+
+def reflect_report_to_guidance_prompt(report: Dict[str, Any]) -> str:
+ """
+ Converts a validated reflection report into a structured, actionable
+ guidance prompt for the agent's second attempt. This prompt is framed
+ as an internal directive to ensure the model's output is clean for SFT.
+ """
+ # Convert the report dictionary to a formatted string
+ report_str = json.dumps(report, indent=2, ensure_ascii=False)
+
+ # Load and render template
+ jinja_env = _get_jinja_env()
+ template = jinja_env.get_template("self_correction.j2")
+
+ return template.render(report=report_str)
+
+
+def save_experience_data(
+ task_id: str,
+ experience_data: Dict,
+ data_dir: str
+) -> str:
+ """
+ Save experience data including trajectory, rewards, and steps to a JSON file.
+
+ Args:
+ task_id: Unique identifier for the task
+ experience_data: Dictionary containing experience information
+ data_dir: Directory to save the data
+
+ Returns:
+ Path to the saved file
+ """
+ os.makedirs(data_dir, exist_ok=True)
+
+ # Add timestamp for uniqueness
+ filename = f"{task_id}.json"
+ filepath = os.path.join(data_dir, filename)
+
+ # Ensure experience_data is JSON serializable
+ serializable_data = {}
+ for key, value in experience_data.items():
+ if isinstance(value, torch.Tensor):
+ serializable_data[key] = value.tolist()
+ elif hasattr(value, '__dict__'):
+ # For complex objects, convert to dict representation
+ serializable_data[key] = str(value)
+ else:
+ serializable_data[key] = value
+
+ # Add metadata
+ serializable_data["saved_at"] = datetime.now().isoformat()
+ serializable_data["task_id"] = task_id
+
+ try:
+ with open(filepath, 'w', encoding='utf-8') as f:
+ json.dump(serializable_data, f, indent=2, ensure_ascii=False)
+ # print(f"Experience data saved to: {filepath}")
+ return filepath
+ except Exception as e:
+ # print(f"Failed to save experience data: {e}")
+ return ""
+
+
+def create_experience_record(
+ task_id: str,
+ trajectory: List[Dict[str, str]],
+ reward: float,
+ steps: int,
+ success: bool,
+ attempt_type: str = "first",
+ reflection_data: Optional[Dict] = None,
+ additional_metrics: Optional[Dict] = None
+) -> Dict:
+ """
+ Create a structured experience record for saving.
+
+ Args:
+ task_id: Unique identifier for the task
+ trajectory: List of conversation messages
+ reward: Final reward received
+ steps: Number of steps taken
+ success: Whether the task was completed successfully
+ attempt_type: Type of attempt ("first", "second", "reflect")
+ reflection_data: Optional reflection analysis data
+ additional_metrics: Additional metrics to record
+
+ Returns:
+ Dictionary containing structured experience data
+ """
+ experience_record = {
+ "task_id": task_id,
+ "attempt_type": attempt_type,
+ "trajectory": trajectory,
+ "metrics": {
+ "reward": reward,
+ "steps": steps,
+ "success": success,
+ "trajectory_length": len(trajectory)
+ },
+ "created_at": datetime.now().isoformat()
+ }
+
+ if reflection_data:
+ experience_record["reflection"] = reflection_data
+
+ if additional_metrics:
+ experience_record["metrics"].update(additional_metrics)
+
+ return experience_record
diff --git a/writing/acl.sty b/writing/acl.sty
new file mode 100644
index 0000000000..d9b74d0e6d
--- /dev/null
+++ b/writing/acl.sty
@@ -0,0 +1,312 @@
+% This is the LaTex style file for *ACL.
+% The official sources can be found at
+%
+% https://github.com/acl-org/acl-style-files/
+%
+% This package is activated by adding
+%
+% \usepackage{acl}
+%
+% to your LaTeX file. When submitting your paper for review, add the "review" option:
+%
+% \usepackage[review]{acl}
+
+\newif\ifacl@finalcopy
+\newif\ifacl@anonymize
+\newif\ifacl@linenumbers
+\newif\ifacl@pagenumbers
+\DeclareOption{final}{\acl@finalcopytrue\acl@anonymizefalse\acl@linenumbersfalse\acl@pagenumbersfalse}
+\DeclareOption{review}{\acl@finalcopyfalse\acl@anonymizetrue\acl@linenumberstrue\acl@pagenumberstrue}
+\DeclareOption{preprint}{\acl@finalcopytrue\acl@anonymizefalse\acl@linenumbersfalse\acl@pagenumberstrue}
+\ExecuteOptions{final} % final copy is the default
+
+% include hyperref, unless user specifies nohyperref option like this:
+% \usepackage[nohyperref]{acl}
+\newif\ifacl@hyperref
+\DeclareOption{hyperref}{\acl@hyperreftrue}
+\DeclareOption{nohyperref}{\acl@hyperreffalse}
+\ExecuteOptions{hyperref} % default is to use hyperref
+\ProcessOptions\relax
+
+\typeout{Conference Style for ACL}
+
+\usepackage{xcolor}
+
+\ifacl@linenumbers
+ % Add draft line numbering via the lineno package
+ % https://texblog.org/2012/02/08/adding-line-numbers-to-documents/
+ \usepackage[switch,mathlines]{lineno}
+
+ % Line numbers in gray Helvetica 8pt
+ \font\aclhv = phvb at 8pt
+ \renewcommand\linenumberfont{\aclhv\color{lightgray}}
+
+ % Zero-fill line numbers
+ % NUMBER with left flushed zeros \fillzeros[]
+ \newcount\cv@tmpc@ \newcount\cv@tmpc
+ \def\fillzeros[#1]#2{\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi
+ \cv@tmpc=1 %
+ \loop\ifnum\cv@tmpc@<10 \else \divide\cv@tmpc@ by 10 \advance\cv@tmpc by 1 \fi
+ \ifnum\cv@tmpc@=10\relax\cv@tmpc@=11\relax\fi \ifnum\cv@tmpc@>10 \repeat
+ \ifnum#2<0\advance\cv@tmpc1\relax-\fi
+ \loop\ifnum\cv@tmpc<#1\relax0\advance\cv@tmpc1\relax\fi \ifnum\cv@tmpc<#1 \repeat
+ \cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi \relax\the\cv@tmpc@}%
+ \renewcommand\thelinenumber{\fillzeros[3]{\arabic{linenumber}}}
+ \AtBeginDocument{\linenumbers}
+
+ \setlength{\linenumbersep}{1.6cm}
+
+ % Bug: An equation with $$ ... $$ isn't numbered, nor is the previous line.
+
+ % Patch amsmath commands so that the previous line and the equation itself
+ % are numbered. Bug: multline has an extra line number.
+ % https://tex.stackexchange.com/questions/461186/how-to-use-lineno-with-amsmath-align
+ \usepackage{etoolbox} %% <- for \pretocmd, \apptocmd and \patchcmd
+
+ \newcommand*\linenomathpatch[1]{%
+ \expandafter\pretocmd\csname #1\endcsname {\linenomath}{}{}%
+ \expandafter\pretocmd\csname #1*\endcsname {\linenomath}{}{}%
+ \expandafter\apptocmd\csname end#1\endcsname {\endlinenomath}{}{}%
+ \expandafter\apptocmd\csname end#1*\endcsname {\endlinenomath}{}{}%
+ }
+ \newcommand*\linenomathpatchAMS[1]{%
+ \expandafter\pretocmd\csname #1\endcsname {\linenomathAMS}{}{}%
+ \expandafter\pretocmd\csname #1*\endcsname {\linenomathAMS}{}{}%
+ \expandafter\apptocmd\csname end#1\endcsname {\endlinenomath}{}{}%
+ \expandafter\apptocmd\csname end#1*\endcsname {\endlinenomath}{}{}%
+ }
+
+ %% Definition of \linenomathAMS depends on whether the mathlines option is provided
+ \expandafter\ifx\linenomath\linenomathWithnumbers
+ \let\linenomathAMS\linenomathWithnumbers
+ %% The following line gets rid of an extra line numbers at the bottom:
+ \patchcmd\linenomathAMS{\advance\postdisplaypenalty\linenopenalty}{}{}{}
+ \else
+ \let\linenomathAMS\linenomathNonumbers
+ \fi
+
+ \AtBeginDocument{%
+ \linenomathpatch{equation}%
+ \linenomathpatchAMS{gather}%
+ \linenomathpatchAMS{multline}%
+ \linenomathpatchAMS{align}%
+ \linenomathpatchAMS{alignat}%
+ \linenomathpatchAMS{flalign}%
+ }
+\else
+ % Hack to ignore these commands, which review mode puts into the .aux file.
+ \newcommand{\@LN@col}[1]{}
+ \newcommand{\@LN}[2]{}
+ \newcommand{\nolinenumbers}{}
+\fi
+
+\PassOptionsToPackage{a4paper,margin=2.5cm,heightrounded=true}{geometry}
+\RequirePackage{geometry}
+
+\setlength\columnsep{0.6cm}
+\newlength\titlebox
+\setlength\titlebox{11\baselineskip}
+% \titlebox should be a multiple of \baselineskip so that
+% column height remaining fits an exact number of lines of text
+
+\flushbottom \twocolumn \sloppy
+
+% We're never going to need a table of contents, so just flush it to
+% save space --- suggested by drstrip@sandia-2
+\def\addcontentsline#1#2#3{}
+
+\ifacl@pagenumbers
+ \pagenumbering{arabic}
+\else
+ \thispagestyle{empty}
+ \pagestyle{empty}
+\fi
+
+%% Title and Authors %%
+
+\let\Thanks\thanks % \Thanks and \thanks used to be different, but keep this for backwards compatibility.
+
+\newcommand\outauthor{%
+ \begin{tabular}[t]{c}
+ \ifacl@anonymize
+ \bfseries Anonymous ACL submission
+ \else
+ \bfseries\@author
+ \fi
+ \end{tabular}}
+
+% Mostly taken from deproc.
+\AtBeginDocument{
+\def\maketitle{\par
+ \begingroup
+ \def\thefootnote{\fnsymbol{footnote}}
+ \twocolumn[\@maketitle]
+ \@thanks
+ \endgroup
+ \setcounter{footnote}{0}
+ \let\maketitle\relax
+ \let\@maketitle\relax
+ \gdef\@thanks{}\gdef\@author{}\gdef\@title{}\let\thanks\relax}
+\def\@maketitle{\vbox to \titlebox{\hsize\textwidth
+ \linewidth\hsize \vskip 0.125in minus 0.125in \centering
+ {\Large\bfseries \@title \par} \vskip 0.2in plus 1fil minus 0.1in
+ {\def\and{\unskip\enspace{\rmfamily and}\enspace}%
+ \def\And{\end{tabular}\hss \egroup \hskip 1in plus 2fil
+ \hbox to 0pt\bgroup\hss \begin{tabular}[t]{c}\bfseries}%
+ \def\AND{\end{tabular}\hss\egroup \hfil\hfil\egroup
+ \vskip 0.25in plus 1fil minus 0.125in
+ \hbox to \linewidth\bgroup\large \hfil\hfil
+ \hbox to 0pt\bgroup\hss \begin{tabular}[t]{c}\bfseries}
+ \hbox to \linewidth\bgroup\large \hfil\hfil
+ \hbox to 0pt\bgroup\hss
+ \outauthor
+ \hss\egroup
+ \hfil\hfil\egroup}
+ \vskip 0.3in plus 2fil minus 0.1in
+}}
+}
+
+% margins and font size for abstract
+\renewenvironment{abstract}%
+ {\begin{center}\large\textbf{\abstractname}\end{center}%
+ \begin{list}{}%
+ {\setlength{\rightmargin}{0.6cm}%
+ \setlength{\leftmargin}{0.6cm}}%
+ \item[]\ignorespaces%
+ \@setsize\normalsize{12pt}\xpt\@xpt
+ }%
+ {\unskip\end{list}}
+
+% Resizing figure and table captions - SL
+% Support for interacting with the caption, subfigure, and subcaption packages - SL
+\RequirePackage{caption}
+\DeclareCaptionFont{10pt}{\fontsize{10pt}{12pt}\selectfont}
+\captionsetup{font=10pt}
+
+\RequirePackage{natbib}
+% for citation commands in the .tex, authors can use:
+% \citep, \citet, and \citeyearpar for compatibility with natbib, or
+% \cite, \newcite, and \shortcite for compatibility with older ACL .sty files
+\renewcommand\cite{\citep} % to get "(Author Year)" with natbib
+\newcommand\shortcite{\citeyearpar}% to get "(Year)" with natbib
+\newcommand\newcite{\citet} % to get "Author (Year)" with natbib
+\newcommand{\citeposs}[1]{\citeauthor{#1}'s (\citeyear{#1})} % to get "Author's (Year)"
+
+\bibliographystyle{acl_natbib}
+
+% Bibliography
+
+% Don't put a label in the bibliography at all. Just use the unlabeled format
+% instead.
+\def\thebibliography#1{\vskip\parskip%
+\vskip\baselineskip%
+\def\baselinestretch{1}%
+\ifx\@currsize\normalsize\@normalsize\else\@currsize\fi%
+\vskip-\parskip%
+\vskip-\baselineskip%
+\section*{References\@mkboth
+ {References}{References}}\list
+ {}{\setlength{\labelwidth}{0pt}\setlength{\leftmargin}{\parindent}
+ \setlength{\itemindent}{-\parindent}}
+ \def\newblock{\hskip .11em plus .33em minus -.07em}
+ \sloppy\clubpenalty4000\widowpenalty4000
+ \sfcode`\.=1000\relax}
+\let\endthebibliography=\endlist
+
+
+% Allow for a bibliography of sources of attested examples
+\def\thesourcebibliography#1{\vskip\parskip%
+\vskip\baselineskip%
+\def\baselinestretch{1}%
+\ifx\@currsize\normalsize\@normalsize\else\@currsize\fi%
+\vskip-\parskip%
+\vskip-\baselineskip%
+\section*{Sources of Attested Examples\@mkboth
+ {Sources of Attested Examples}{Sources of Attested Examples}}\list
+ {}{\setlength{\labelwidth}{0pt}\setlength{\leftmargin}{\parindent}
+ \setlength{\itemindent}{-\parindent}}
+ \def\newblock{\hskip .11em plus .33em minus -.07em}
+ \sloppy\clubpenalty4000\widowpenalty4000
+ \sfcode`\.=1000\relax}
+\let\endthesourcebibliography=\endlist
+
+% sections with less space
+\def\section{\@startsection {section}{1}{\z@}{-2.0ex plus
+ -0.5ex minus -.2ex}{1.5ex plus 0.3ex minus .2ex}{\large\bfseries\raggedright}}
+\def\subsection{\@startsection{subsection}{2}{\z@}{-1.8ex plus
+ -0.5ex minus -.2ex}{0.8ex plus .2ex}{\normalsize\bfseries\raggedright}}
+%% changed by KO to - values to get the initial parindent right
+\def\subsubsection{\@startsection{subsubsection}{3}{\z@}{-1.5ex plus
+ -0.5ex minus -.2ex}{0.5ex plus .2ex}{\normalsize\bfseries\raggedright}}
+\def\paragraph{\@startsection{paragraph}{4}{\z@}{1.5ex plus
+ 0.5ex minus .2ex}{-1em}{\normalsize\bfseries}}
+\def\subparagraph{\@startsection{subparagraph}{5}{\parindent}{1.5ex plus
+ 0.5ex minus .2ex}{-1em}{\normalsize\bfseries}}
+
+% Footnotes
+\footnotesep 6.65pt %
+\skip\footins 9pt plus 4pt minus 2pt
+\def\footnoterule{\kern-3pt \hrule width 5pc \kern 2.6pt }
+\setcounter{footnote}{0}
+
+% Lists and paragraphs
+\parindent 1em
+\topsep 4pt plus 1pt minus 2pt
+\partopsep 1pt plus 0.5pt minus 0.5pt
+\itemsep 2pt plus 1pt minus 0.5pt
+\parsep 2pt plus 1pt minus 0.5pt
+
+\leftmargin 2em \leftmargini\leftmargin \leftmarginii 2em
+\leftmarginiii 1.5em \leftmarginiv 1.0em \leftmarginv .5em \leftmarginvi .5em
+\labelwidth\leftmargini\advance\labelwidth-\labelsep \labelsep 5pt
+
+\def\@listi{\leftmargin\leftmargini}
+\def\@listii{\leftmargin\leftmarginii
+ \labelwidth\leftmarginii\advance\labelwidth-\labelsep
+ \topsep 2pt plus 1pt minus 0.5pt
+ \parsep 1pt plus 0.5pt minus 0.5pt
+ \itemsep \parsep}
+\def\@listiii{\leftmargin\leftmarginiii
+ \labelwidth\leftmarginiii\advance\labelwidth-\labelsep
+ \topsep 1pt plus 0.5pt minus 0.5pt
+ \parsep \z@ \partopsep 0.5pt plus 0pt minus 0.5pt
+ \itemsep \topsep}
+\def\@listiv{\leftmargin\leftmarginiv
+ \labelwidth\leftmarginiv\advance\labelwidth-\labelsep}
+\def\@listv{\leftmargin\leftmarginv
+ \labelwidth\leftmarginv\advance\labelwidth-\labelsep}
+\def\@listvi{\leftmargin\leftmarginvi
+ \labelwidth\leftmarginvi\advance\labelwidth-\labelsep}
+
+\abovedisplayskip 7pt plus2pt minus5pt%
+\belowdisplayskip \abovedisplayskip
+\abovedisplayshortskip 0pt plus3pt%
+\belowdisplayshortskip 4pt plus3pt minus3pt%
+
+% Less leading in most fonts (due to the narrow columns)
+% The choices were between 1-pt and 1.5-pt leading
+\def\@normalsize{\@setsize\normalsize{11pt}\xpt\@xpt}
+\def\small{\@setsize\small{10pt}\ixpt\@ixpt}
+\def\footnotesize{\@setsize\footnotesize{10pt}\ixpt\@ixpt}
+\def\scriptsize{\@setsize\scriptsize{8pt}\viipt\@viipt}
+\def\tiny{\@setsize\tiny{7pt}\vipt\@vipt}
+\def\large{\@setsize\large{14pt}\xiipt\@xiipt}
+\def\Large{\@setsize\Large{16pt}\xivpt\@xivpt}
+\def\LARGE{\@setsize\LARGE{20pt}\xviipt\@xviipt}
+\def\huge{\@setsize\huge{23pt}\xxpt\@xxpt}
+\def\Huge{\@setsize\Huge{28pt}\xxvpt\@xxvpt}
+
+% The hyperref manual (section 9) says hyperref should be loaded after natbib
+\ifacl@hyperref
+ \PassOptionsToPackage{breaklinks}{hyperref}
+ \RequirePackage{hyperref}
+ % make links dark blue
+ \definecolor{darkblue}{rgb}{0, 0, 0.5}
+ \hypersetup{colorlinks=true, citecolor=darkblue, linkcolor=darkblue, urlcolor=darkblue}
+\else
+ % This definition is used if the hyperref package is not loaded.
+  % It provides a backup, no-op definition of \href.
+  % This is necessary because the \href command is used in the acl_natbib.bst file.
+ \def\href#1#2{{#2}}
+ \usepackage{url}
+\fi
diff --git a/writing/acl_latex.tex b/writing/acl_latex.tex
new file mode 100644
index 0000000000..534c1f2587
--- /dev/null
+++ b/writing/acl_latex.tex
@@ -0,0 +1,353 @@
+\documentclass[11pt]{article}
+
+% Change "review" to "final" to generate the final (sometimes called camera-ready) version.
+% Change to "preprint" to generate a non-anonymous version with page numbers.
+\usepackage[review]{acl}
+
+% Standard package includes
+\usepackage{times}
+\usepackage{latexsym}
+
+% For proper rendering and hyphenation of words containing Latin characters (including in bib files)
+\usepackage[T1]{fontenc}
+% For Vietnamese characters
+% \usepackage[T5]{fontenc}
+% See https://www.latex-project.org/help/documentation/encguide.pdf for other character sets
+
+% This assumes your files are encoded as UTF8
+\usepackage[utf8]{inputenc}
+
+% This is not strictly necessary, and may be commented out,
+% but it will improve the layout of the manuscript,
+% and will typically save some space.
+\usepackage{microtype}
+\usepackage{inconsolata}
+
+\usepackage{algorithmic}
+\usepackage{graphicx}
+\usepackage{textcomp}
+\usepackage{xcolor}
+\usepackage{physics}
+\usepackage{mathdots}
+\usepackage{algorithm}
+\usepackage{subfigure}
+\usepackage{pgfplots}
+\usepackage{tikz}
+\usetikzlibrary{patterns}
+\usepackage{multirow}
+\usepackage{multicol}
+% \usepackage{authblk}
+\usepackage{array}
+\pgfplotsset{compat=1.18}
+
+\usepackage{tcolorbox}
+\tcbuselibrary{breakable}
+\usepackage{listings}
+\usepackage{cuted} % provides the strip environment (full-width strips in two-column mode)
+\usepackage{booktabs} % for clean table rules (\toprule/\midrule/\bottomrule)
+
+\lstset{
+ basicstyle=\footnotesize\ttfamily,
+ breaklines=true,
+ breakatwhitespace=false,
+ columns=flexible,
+ keepspaces=true,
+ showstringspaces=false,
+ tabsize=2,
+ frame=none
+}
+
+% Add this line here: clearing \Bbbk avoids a redefinition clash when amssymb is loaded
+\let\Bbbk\relax
+\usepackage{amsmath,amssymb,amsfonts}
+\DeclareMathOperator*{\argmax}{arg\,max}
+% If the title and author information does not fit in the area allocated, uncomment the following
+%
+%\setlength\titlebox{}
+%
+% and set to something 5cm or larger.
+
+\title{R$^3$L: Reflect, Retry, and Reinforce Learning via Language-Guided Exploration and Positive Preference Optimization}
+
+
+\author{First Author \\
+ Affiliation / Address line 1 \\
+ \texttt{email@domain} \\\And
+ Second Author \\
+ Affiliation / Address line 1 \\
+ \texttt{email@domain} \\}
+
+
+\begin{document}
+\maketitle
+\begin{abstract}
+Group-relative policy optimization for reinforcement learning has emerged as a powerful technique to improve reasoning and agentic capabilities in large language models. However, this approach is limited by inefficient exploration, coarse credit assignment, and training instability when learning from failure-dominated data. To this end, we propose R$^3$L: Reflect, Retry, and Reinforce Learning. To improve exploration, R$^3$L employs a reflect-then-retry mechanism that uses language feedback to identify failure points and synthesize high-quality trajectories. To refine credit assignment, R$^3$L applies Pivotal Credit Assignment, which leverages these identified failure points to apply updates only after the error, preserving valid steps. To ensure stable learning from this synthesized data, R$^3$L introduces Positive Preference Optimization that amplifies gradients from successful trajectories, ensuring they dominate the learning signal. Experiments on reasoning and agentic tasks demonstrate xx\% improvements over baselines while maintaining training stability.
+\end{abstract}
+
+\section{Introduction}
+Group-Relative Policy Optimization (GRPO) \cite{shao2024deepseekmath} for reinforcement learning has emerged as a powerful technique to improve reasoning and agentic capabilities in large language models \cite{cui2025enhancing,shi2025semantic,plaat2025agentic}. By comparing sampled responses within each group, GRPO eliminates the need for critic models while providing stable learning signals. Recent works have demonstrated its success, including DeepSeek-R1 \cite{guo2025deepseek}, DeepSeek-Math \cite{shao2024deepseekmath}, and Search-R1 \cite{jin2025search}. Despite these successes, performance on complex multi-step tasks remains limited by the following issues:
+\begin{itemize}
+ \item \textbf{Inefficient exploration.} Stochastic sampling produces predominantly failed trajectories on difficult problems. When all samples in a group fail, the reward variance becomes zero, yielding null gradients that stall learning \cite{nan2025ngrpo}. Even when some samples succeed, the scarcity of high-reward trajectories limits learning efficiency. Scalar rewards indicate correctness but provide no actionable guidance on why solutions failed or how to discover better ones \cite{zhang2025critique}. This calls for a guided exploration mechanism that leverages language feedback to efficiently synthesize successful trajectories.
+    \item \textbf{Coarse credit assignment.} Trajectory-level rewards fail to distinguish correct intermediate steps from incorrect ones. A trajectory with valid reasoning but a wrong final answer is penalized in its entirety, which inappropriately suppresses the correct reasoning steps that preceded the failure point and discourages the model from learning sound intermediate behaviors.
+ \item \textbf{Training instability from low-reward trajectory dominance.} When failed trajectories dominate training data, they create a fundamental asymmetry. Learning signals from failures only suppress incorrect actions without providing positive guidance on what the model should generate instead. This asymmetry destabilizes training because the model receives overwhelming suppression signals that push the distribution in unpredictable directions, while lacking sufficient directional signals from successful trajectories to guide it toward desired behaviors \cite{wu2025learning}. This problem is exacerbated with off-policy data and with ultra-long, multi-turn interactions. To stabilize learning under such conditions, the training objective must ensure that high-reward trajectories dominate the learning signal, providing clear directional guidance even when outnumbered by failures.
+\end{itemize}
+
+In this paper, we propose R$^3$L, a Reflect, Retry, and Reinforce Learning framework that enhances exploration through language-guided trajectory synthesis, refines credit assignment through pivotal updates, and ensures training stability via positive preference optimization. To improve exploration, R$^3$L employs a reflect-then-retry mechanism that uses language feedback to identify precise failure points in unsuccessful trajectories, then restarts generation from these pivots with corrective guidance to synthesize high-reward trajectories. To refine credit assignment, we introduce Pivotal Credit Assignment that applies updates exclusively to tokens after identified failure points, preserving valid prefixes while correcting critical errors. To ensure stable learning from this synthesized off-policy data, we propose Positive Preference Optimization that reweights advantages to amplify gradients from successful trajectories, ensuring they dominate the learning signal even when outnumbered by failures. Reflection and retry skills are maintained through auxiliary meta-tasks trained on successful corrections. Extensive experiments demonstrate xx improvements over baselines. Our contributions are summarized as follows:
+\begin{itemize}
+ \item We propose a language-guided reflect-then-retry mechanism that synthesizes high-reward trajectories by identifying failure points and restarting generation with corrective guidance, significantly improving exploration efficiency.
+ \item We present Pivotal Credit Assignment that applies updates exclusively after identified failure points, preserving valid reasoning prefixes while learning from critical errors.
+ \item We introduce Positive Preference Optimization that ensures stable off-policy learning by amplifying gradients from successful trajectories to dominate the learning signal.
+\end{itemize}
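To make the exploration mechanism concrete, the following minimal Python sketch mirrors the reflect-then-retry loop described above. The `rollout` and `reflect` callables, the `Trajectory` container, and all parameter names are illustrative placeholders, not identifiers from the paper or any released code.

    from dataclasses import dataclass
    from typing import Callable, List, Tuple

    @dataclass
    class Trajectory:
        turns: List[Tuple[str, str]]      # (environment input x_k, response y_k) per turn
        reward: float = 0.0

    # rollout(task, prefix_turns, guidance) -> Trajectory
    RolloutFn = Callable[[str, List[Tuple[str, str]], str], Trajectory]
    # reflect(trajectory) -> (failure turn index k*, language feedback)
    ReflectFn = Callable[[Trajectory], Tuple[int, str]]

    def reflect_then_retry(task: str, rollout: RolloutFn, reflect: ReflectFn,
                           n_samples: int = 8):
        """Sample a group, then synthesize corrected trajectories for the failures."""
        group = [rollout(task, [], "") for _ in range(n_samples)]
        retries = []
        for traj in group:
            if traj.reward > 0:                      # successful rollouts need no correction
                continue
            k_star, feedback = reflect(traj)         # locate the failure turn, explain the error
            prefix = traj.turns[:k_star]             # valid steps before the pivot are preserved
            retry = rollout(task, prefix, feedback)  # restart from the pivot with guidance
            if retry.reward > traj.reward:           # keep only retries that improve the reward
                retries.append(retry)
        return group, retries

In this reading, the retries feed the trainer alongside the original group, while the (failure turn, feedback) pairs also supply the auxiliary reflection and retry meta-tasks mentioned above.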
+
+\input{figure/framework}
+
+\section{Related Work}
+Compared to Proximal Policy Optimization \cite{schulman2017proximal}, Group Relative Policy Optimization (GRPO) eliminates the critic model and estimates advantages by normalizing rewards across sampled responses. Despite its efficiency, two fundamental challenges limit its performance: inefficient exploration and training instability.
+
+\subsection{Inefficient Exploration in GRPO}
+GRPO suffers from inefficient exploration due to null gradients when all samples fail \cite{bamba2025xrpo} and scarcity of high-reward trajectories \cite{wang2025slow}. Sampling-based methods like DAPO \cite{yu2025dapo} and RAFT \cite{dongraft} use oversampling and filtering to ensure gradient validity, but incur significant computational cost. Correction-based methods actively generate improved trajectories through external feedback or self-reflection. HINT \cite{wang2025hint} and Agent-RLVR \cite{da2025agent} use heuristic guidance and external critics. Goedel-Prover-V2 \cite{lin2025goedel} employs scaffolded synthesis. Reflect-Retry-Reward \cite{bensal2025reflect} uses trained self-reflection models. While these methods improve sample quality, they generate substantial off-policy data that can degrade training stability \cite{zheng2025prosperity}. We leverage language feedback to synthesize high-reward trajectories and introduce Positive Preference Optimization to ensure stable off-policy learning.
+
+\subsection{Training Instability in GRPO}
+Training instability stems from high-variance gradients \cite{liu2025understanding} and coarse credit assignment \cite{parthasarathi2025grpo}. To address gradient variance, GSPO \cite{zheng2025group} replaces token-level importance weights with sequence-level ratios. BAPO \cite{xi2025bapo} identifies that off-policy negative samples dominate gradients and proposes adaptive clipping. For credit assignment, trajectory-level rewards incorrectly penalize entire sequences for single errors. Process reward models \cite{wang2024math} provide step-level supervision but are costly and prone to noise \cite{xiong2024watch}. GiGPO \cite{feng2025group} and VinePPO \cite{kazemnejad2024vineppo} offer finer-grained credit assignment through anchor states or Monte Carlo estimates. We stabilize gradients through Positive Preference Optimization and refine credit assignment through Pivotal Credit Assignment that isolates failure points.
+
+\section{Preliminary}
+\subsection{Policy Gradient for LLM}
+A multi-turn trajectory $\tau$ consists of $K$ turns, where each turn $k$ contains an environment input $x_k$ and a model response $y_k$:
+\begin{equation}
+ \tau = (x_1, y_1, x_2, y_2, \ldots, x_K, y_K)
+\end{equation}
+Given a reward function $R(\cdot)$, the objective is to maximize the expected reward:
+\begin{equation}
+ J(\theta) = \mathbb{E}_{\tau \sim \pi_{\theta}} [R(\tau)]
+\end{equation}
+Each response $y_k$ is generated as a sequence of $T_k$ tokens. Let $h_k = (x_1, y_1, \ldots, x_{k-1}, y_{k-1}, x_k)$ denote the history up to turn $k$. The policy gradient is:
+\begin{equation}
+    \nabla_{\theta} J(\theta) \!\!=\!\! \mathbb{E}_{\tau \sim \pi_{\theta}} \!\!\left[ \sum_{k=1}^{K} \!\! \sum_{t=1}^{T_k} \!\! \nabla_{\theta} \log \pi_{\theta}(y_k^t | h_k, y_k^{<t}) \, R(\tau) \right]
+\end{equation}
+
+Successful corrections, in which the retried trajectory $\tau_r$ outscores the original rollout $\tau_i$ (i.e., $R(\tau_r) > R(\tau_i)$), populate the reflection and retry meta-task datasets:
+\begin{equation}
+\begin{split}
+ \mathcal{D}_{reflect} &= \{([\tau_i, f_i], r_i)\} \\
+    \mathcal{D}_{retry} &= \{(h_{<k^*}, \tau_r)\}
+\end{split}
+\end{equation}
+where $h_{<k^*}$ is the history preceding the identified failure turn $k^*$, from which the corrected continuation $\tau_r$ is regenerated. To keep successful trajectories dominant in the learning signal, we reweight the group-relative advantage:
+\begin{equation}
+    \hat{A}(\tau_i) =
+    \begin{cases}
+        \alpha & \text{if } R(\tau_i) = R_{\max} \\
+        \alpha \cdot A(\tau_i) & \text{if } A(\tau_i) > 0 \\
+        A(\tau_i) & \text{otherwise}
+    \end{cases}
+\end{equation}
+where $A(\tau_i) = R(\tau_i) - \bar{R}$ is the simplified advantage from Eq.~\ref{eq:grpo_advantage}, and $\alpha > 1$ is the amplification factor. Perfect trajectories receive a constant advantage of $\alpha$ regardless of group statistics, ensuring they provide strong positive signals even when training stagnates. Above-average trajectories receive amplified advantages, while below-average trajectories retain their original negative advantages.
+
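As a quick numerical illustration of the reweighting above, the sketch below applies the three cases to a failure-dominated group. Treating the maximum attainable reward `r_max` as the test for a "perfect" trajectory is our assumption, and `numpy` merely stands in for whatever tensor library the training code actually uses.

    import numpy as np

    def positive_preference_advantages(rewards, r_max, alpha=2.0):
        """Reweight group-relative advantages so successful trajectories dominate."""
        rewards = np.asarray(rewards, dtype=float)
        adv = rewards - rewards.mean()                 # simplified advantage A(tau_i) = R - R_bar
        out = np.where(adv > 0, alpha * adv, adv)      # amplify above-average trajectories
        out = np.where(rewards >= r_max, alpha, out)   # perfect trajectories get constant alpha
        return out

    # A group with one success among three failures still yields a dominant positive signal:
    print(positive_preference_advantages([0.0, 0.0, 0.0, 1.0], r_max=1.0, alpha=2.0))
    # -> [-0.25 -0.25 -0.25  2.  ]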
+This reweighting addresses the asymmetry between learning from successes and failures. By amplifying positive advantages, we ensure that their gradient contribution dominates. We now combine this positive-preferred advantage $\hat{A}$ with the pivotal credit mask of Section~\ref{credit assignment} to obtain the final R$^3$L objective:
+\begin{equation}
+\begin{split}
+ \mathcal{L}_{R^3L} &= -\mathbb{E}_{\tau \sim \mathcal{G}_{explore}} \\
+    & \left[ \frac{1}{|\tau|} \sum_{k,t} \text{mask}_k^t \cdot \hat{A}_k^t \log \pi_{\theta}(y_k^t | h_k, y_k^{<t}) \right]
+\end{split}
+\end{equation}
+where $\text{mask}_k^t$ equals $1$ only for tokens after the identified failure point and $\hat{A}_k^t$ is the positive-preferred advantage assigned to token $t$ of turn $k$.
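Read as code, the objective is a masked, advantage-weighted negative log-likelihood. The PyTorch sketch below is a per-trajectory illustration under the assumption that the trajectory-level advantage is simply broadcast to every unmasked token; the function and variable names are ours, not the paper's.

    import torch

    def r3l_loss(logprobs, pivot_mask, advantages):
        """Masked, advantage-weighted NLL for one trajectory.

        logprobs   : (T,) log pi_theta(y_t | h, y_<t) for each response token
        pivot_mask : (T,) 1.0 for tokens after the identified failure point, else 0.0
        advantages : (T,) positive-preferred advantage A_hat broadcast to tokens
        """
        weighted = pivot_mask * advantages * logprobs
        return -weighted.sum() / logprobs.numel()      # 1/|tau| normalization

    # Toy example: only the three tokens after the pivot receive gradient signal.
    T = 6
    lp = torch.log(torch.rand(T))                      # stand-in token log-probabilities
    mask = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
    adv = torch.full((T,), 2.0)                        # amplified advantage of a successful retry
    print(r3l_loss(lp, mask, adv))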
+%%% Modifications Copyright 2002–23, Norman Gray,
+%%% and distributed under the terms of the LPPL; see README for discussion.
+%%%
+%%% Added webpage entry type, and url and lastchecked fields.
+%%% Added eprint support.
+%%% Added DOI support.
+%%% Added PUBMED support.
+%%% Added hyperref support.
+%%% Original headers follow...
+
+%%
+%% This is file `acl_natbib_basic.bst',
+%% generated with the docstrip utility.
+%%
+%% The original source files were:
+%%
+%% merlin.mbs (with options: `ay,nat,pres,ed-au,keyxyr,blkyear,dt-beg,yr-per,note-yr,num-xser,pre-edn,xedn,nfss')
+%% ----------------------------------------
+%% *** Intended for ACL conferences ***
+%%
+%% Copyright 1994-2011 Patrick W Daly
+ % ===============================================================
+ % IMPORTANT NOTICE:
+ % This bibliographic style (bst) file has been generated from one or
+ % more master bibliographic style (mbs) files, listed above.
+ %
+ % This generated file can be redistributed and/or modified under the terms
+ % of the LaTeX Project Public License Distributed from CTAN
+ % archives in directory macros/latex/base/lppl.txt; either
+ % version 1 of the License, or any later version.
+ % ===============================================================
+ % Name and version information of the main mbs file:
+ % \ProvidesFile{merlin.mbs}[2011/11/18 4.33 (PWD, AO, DPC)]
+ % For use with BibTeX version 0.99a or later
+ %-------------------------------------------------------------------
+ % This bibliography style file is intended for texts in ENGLISH
+ % This is an author-year citation style bibliography. As such, it is
+ % non-standard LaTeX, and requires a special package file to function properly.
+ % Such a package is natbib.sty by Patrick W. Daly
+ % The form of the \bibitem entries is
+ % \bibitem[Jones et al.(1990)]{key}...
+ % \bibitem[Jones et al.(1990)Jones, Baker, and Smith]{key}...
+ % The essential feature is that the label (the part in brackets) consists
+ % of the author names, as they should appear in the citation, with the year
+ % in parentheses following. There must be no space before the opening
+ % parenthesis!
+ % With natbib v5.3, a full list of authors may also follow the year.
+ % In natbib.sty, it is possible to define the type of enclosures that is
+ % really wanted (brackets or parentheses), but in either case, there must
+ % be parentheses in the label.
+ % The \cite command functions as follows:
+ % \citet{key} ==>> Jones et al. (1990)
+ % \citet*{key} ==>> Jones, Baker, and Smith (1990)
+ % \citep{key} ==>> (Jones et al., 1990)
+ % \citep*{key} ==>> (Jones, Baker, and Smith, 1990)
+ % \citep[chap. 2]{key} ==>> (Jones et al., 1990, chap. 2)
+ % \citep[e.g.][]{key} ==>> (e.g. Jones et al., 1990)
+ % \citep[e.g.][p. 32]{key} ==>> (e.g. Jones et al., 1990, p. 32)
+ % \citeauthor{key} ==>> Jones et al.
+ % \citeauthor*{key} ==>> Jones, Baker, and Smith
+ % \citeyear{key} ==>> 1990
+ %---------------------------------------------------------------------
+
+%% 2025 modified to truncate author lists of more than 20 authors
+
+ENTRY
+ { address
+ archivePrefix
+ author
+ booktitle
+ chapter
+ edition
+ editor
+ eid
+ eprint
+ eprinttype % = archivePrefix
+ howpublished
+ institution
+ journal
+ key
+ month
+ note
+ number
+ organization
+ pages
+ publisher
+ school
+ series
+ title
+ type
+ volume
+ year
+ doi % urlbst
+ pubmed % urlbst
+ url % urlbst
+ lastchecked % urlbst
+ }
+ {}
+ { label extra.label sort.label short.list }
+INTEGERS { output.state before.all mid.sentence after.sentence after.block }
+% urlbst...
+% urlbst constants and state variables
+STRINGS { urlintro
+ eprinturl eprintprefix doiprefix doiurl pubmedprefix pubmedurl
+ citedstring onlinestring linktextstring
+ openinlinelink closeinlinelink }
+INTEGERS { hrefform doiform inlinelinks makeinlinelink
+ addeprints adddoi addpubmed }
+FUNCTION {init.urlbst.variables}
+{
+ % The following constants may be adjusted by hand, if desired
+
+ % The first set allow you to enable or disable certain functionality.
+ #1 'addeprints := % 0=no eprints; 1=include eprints
+ #2 'hrefform := % 0=no crossrefs; 1=hypertex hrefs; 2=hyperref hrefs
+ #1 'inlinelinks := % 0=URLs explicit; 1=URLs attached to titles
+ #1 'adddoi := % 0=no DOI resolver; 1=include it
+ #1 'addpubmed := % 0=no PUBMED resolver; 1=include it
+ #0 'doiform := % 0=with href; 1=with \doi{}
+
+ % String constants, which you _might_ want to tweak.
+ "online" 'onlinestring := % label that a resource is online
+ "[link]" 'linktextstring := % anonymous link text
+ "http://www.ncbi.nlm.nih.gov/pubmed/" 'pubmedurl := % prefix to make URL from PUBMED
+ "https://doi.org/" 'doiurl := % prefix to make URL from DOI
+ "doi:" 'doiprefix := % printed text to introduce DOI
+ "https://arxiv.org/abs/" 'eprinturl := % prefix to make URL from eprint ref
+ "cited " 'citedstring := % label in "lastchecked" remark
+ "arXiv:" 'eprintprefix := % text prefix printed before eprint ref
+ "PMID:" 'pubmedprefix := % text prefix printed before PUBMED ref
+ "URL: " 'urlintro := % text prefix before URL
+
+ % The following are internal state variables, not configuration constants,
+ % so they shouldn't be fiddled with.
+ #0 'makeinlinelink := % state variable managed by possibly.setup.inlinelink
+ "" 'openinlinelink := % ditto
+ "" 'closeinlinelink := % ditto
+}
+INTEGERS {
+ bracket.state
+ outside.brackets
+ open.brackets
+ within.brackets
+ close.brackets
+}
+% ...urlbst to here
+FUNCTION {init.state.consts}
+{ #0 'outside.brackets := % urlbst...
+ #1 'open.brackets :=
+ #2 'within.brackets :=
+ #3 'close.brackets := % ...urlbst to here
+
+ #0 'before.all :=
+ #1 'mid.sentence :=
+ #2 'after.sentence :=
+ #3 'after.block :=
+}
+STRINGS { s t}
+% urlbst
+FUNCTION {output.nonnull.original}
+{ 's :=
+ output.state mid.sentence =
+ { ", " * write$ }
+ { output.state after.block =
+ { add.period$ write$
+ newline$
+ "\newblock " write$
+ }
+ { output.state before.all =
+ 'write$
+ { add.period$ " " * write$ }
+ if$
+ }
+ if$
+ mid.sentence 'output.state :=
+ }
+ if$
+ s
+}
+
+% urlbst...
+% Minimal DOI parsing.
+% Given a DOI on the stack, check whether it starts with 'doiurl' or not.
+% In either case, leave on the stack first a DOI with, and then a DOI without, the URL prefix.
+FUNCTION {parse.doi}
+{
+ #1 doiurl text.length$ substring$
+ doiurl =
+ { doi
+ doi doiurl text.length$ #1 + #999 substring$ }
+ { doiurl doi *
+ doi }
+ if$
+}
+% The following three functions are for handling inlinelink. They wrap
+% a block of text which is potentially output with write$ by multiple
+% other functions, so we don't know the content a priori.
+% They communicate between each other using the variables makeinlinelink
+% (which is true if a link should be made), and closeinlinelink (which holds
+% the string which should close any current link. They can be called
+% at any time, but start.inlinelink will be a no-op unless something has
+% previously set makeinlinelink true, and the two ...end.inlinelink functions
+% will only do their stuff if start.inlinelink has previously set
+% closeinlinelink to be non-empty.
+% (thanks to 'ijvm' for suggested code here)
+FUNCTION {uand}
+{ 'skip$ { pop$ #0 } if$ } % 'and' (which isn't defined at this point in the file)
+FUNCTION {possibly.setup.inlinelink}
+{ makeinlinelink hrefform #0 > uand
+ { doi empty$ adddoi uand
+ { pubmed empty$ addpubmed uand
+ { eprint empty$ addeprints uand
+ { url empty$
+ { "" }
+ { url }
+ if$ }
+ { eprinturl eprint * }
+ if$ }
+ { pubmedurl pubmed * }
+ if$ }
+% { doiurl doi * }
+ { doi empty$
+ { "XXX" }
+ { doi parse.doi pop$ }
+ if$
+ }
+ if$
+ % an appropriately-formatted URL is now on the stack
+ hrefform #1 = % hypertex
+ { "\special {html: }{" * 'openinlinelink :=
+ "\special {html:}" 'closeinlinelink := }
+ { "\href {" swap$ * "} {" * 'openinlinelink := % hrefform=#2 -- hyperref
+ % the space between "} {" matters: a URL of just the right length can cause "\% newline em"
+ "}" 'closeinlinelink := }
+ if$
+ #0 'makeinlinelink :=
+ }
+ 'skip$
+ if$ % makeinlinelink
+}
+FUNCTION {add.inlinelink}
+{ openinlinelink empty$
+ 'skip$
+ { openinlinelink swap$ * closeinlinelink *
+ "" 'openinlinelink :=
+ }
+ if$
+}
+FUNCTION {output.nonnull}
+{ % Save the thing we've been asked to output
+ 's :=
+ % If the bracket-state is close.brackets, then add a close-bracket to
+ % what is currently at the top of the stack, and set bracket.state
+ % to outside.brackets
+ bracket.state close.brackets =
+ { "]" *
+ outside.brackets 'bracket.state :=
+ }
+ 'skip$
+ if$
+ bracket.state outside.brackets =
+ { % We're outside all brackets -- this is the normal situation.
+ % Write out what's currently at the top of the stack, using the
+ % original output.nonnull function.
+ s
+ add.inlinelink
+ output.nonnull.original % invoke the original output.nonnull
+ }
+ { % Still in brackets. Add open-bracket or (continuation) comma, add the
+ % new text (in s) to the top of the stack, and move to the close-brackets
+ % state, ready for next time (unless inbrackets resets it). If we come
+ % into this branch, then output.state is carefully undisturbed.
+ bracket.state open.brackets =
+ { " [" * }
+ { ", " * } % bracket.state will be within.brackets
+ if$
+ s *
+ close.brackets 'bracket.state :=
+ }
+ if$
+}
+
+% Call this function just before adding something which should be presented in
+% brackets. bracket.state is handled specially within output.nonnull.
+FUNCTION {inbrackets}
+{ bracket.state close.brackets =
+ { within.brackets 'bracket.state := } % reset the state: not open nor closed
+ { open.brackets 'bracket.state := }
+ if$
+}
+
+FUNCTION {format.lastchecked}
+{ lastchecked empty$
+ { "" }
+ { inbrackets citedstring lastchecked * }
+ if$
+}
+% ...urlbst to here
+FUNCTION {output}
+{ duplicate$ empty$
+ 'pop$
+ 'output.nonnull
+ if$
+}
+FUNCTION {output.check}
+{ 't :=
+ duplicate$ empty$
+ { pop$ "empty " t * " in " * cite$ * warning$ }
+ 'output.nonnull
+ if$
+}
+FUNCTION {fin.entry.original} % urlbst (renamed from fin.entry, so it can be wrapped below)
+{ add.period$
+ write$
+ newline$
+}
+
+FUNCTION {new.block}
+{ output.state before.all =
+ 'skip$
+ { after.block 'output.state := }
+ if$
+}
+FUNCTION {new.sentence}
+{ output.state after.block =
+ 'skip$
+ { output.state before.all =
+ 'skip$
+ { after.sentence 'output.state := }
+ if$
+ }
+ if$
+}
+FUNCTION {add.blank}
+{ " " * before.all 'output.state :=
+}
+
+FUNCTION {date.block}
+{
+ new.block
+}
+
+FUNCTION {not}
+{ { #0 }
+ { #1 }
+ if$
+}
+FUNCTION {and}
+{ 'skip$
+ { pop$ #0 }
+ if$
+}
+FUNCTION {or}
+{ { pop$ #1 }
+ 'skip$
+ if$
+}
+FUNCTION {new.block.checkb}
+{ empty$
+ swap$ empty$
+ and
+ 'skip$
+ 'new.block
+ if$
+}
+FUNCTION {field.or.null}
+{ duplicate$ empty$
+ { pop$ "" }
+ 'skip$
+ if$
+}
+FUNCTION {emphasize}
+{ duplicate$ empty$
+ { pop$ "" }
+ { "\emph{" swap$ * "}" * }
+ if$
+}
+FUNCTION {tie.or.space.prefix} % puts ~ before the preceding part if it is of length <3
+{ duplicate$ text.length$ #3 <
+ { "~" }
+ { " " }
+ if$
+ swap$
+}
+
+FUNCTION {capitalize}
+{ "u" change.case$ "t" change.case$ }
+
+FUNCTION {space.word}
+{ " " swap$ * " " * }
+ % Here are the language-specific definitions for explicit words.
+ % Each function has a name bbl.xxx where xxx is the English word.
+ % The language selected here is ENGLISH
+FUNCTION {bbl.and}
+{ "and"}
+
+FUNCTION {bbl.etal}
+{ "et~al." }
+
+FUNCTION {bbl.editors}
+{ "editors" }
+
+FUNCTION {bbl.editor}
+{ "editor" }
+
+FUNCTION {bbl.edby}
+{ "edited by" }
+
+FUNCTION {bbl.edition}
+{ "edition" }
+
+FUNCTION {bbl.volume}
+{ "volume" }
+
+FUNCTION {bbl.of}
+{ "of" }
+
+FUNCTION {bbl.number}
+{ "number" }
+
+FUNCTION {bbl.nr}
+{ "no." }
+
+FUNCTION {bbl.in}
+{ "in" }
+
+FUNCTION {bbl.pages}
+{ "pages" }
+
+FUNCTION {bbl.page}
+{ "page" }
+
+FUNCTION {bbl.chapter}
+{ "chapter" }
+
+FUNCTION {bbl.techrep}
+{ "Technical Report" }
+
+FUNCTION {bbl.mthesis}
+{ "Master's thesis" }
+
+FUNCTION {bbl.phdthesis}
+{ "Ph.D. thesis" }
+
+MACRO {jan} {"January"}
+
+MACRO {feb} {"February"}
+
+MACRO {mar} {"March"}
+
+MACRO {apr} {"April"}
+
+MACRO {may} {"May"}
+
+MACRO {jun} {"June"}
+
+MACRO {jul} {"July"}
+
+MACRO {aug} {"August"}
+
+MACRO {sep} {"September"}
+
+MACRO {oct} {"October"}
+
+MACRO {nov} {"November"}
+
+MACRO {dec} {"December"}
+
+MACRO {acmcs} {"ACM Computing Surveys"}
+
+MACRO {acta} {"Acta Informatica"}
+
+MACRO {cacm} {"Communications of the ACM"}
+
+MACRO {ibmjrd} {"IBM Journal of Research and Development"}
+
+MACRO {ibmsj} {"IBM Systems Journal"}
+
+MACRO {ieeese} {"IEEE Transactions on Software Engineering"}
+
+MACRO {ieeetc} {"IEEE Transactions on Computers"}
+
+MACRO {ieeetcad}
+ {"IEEE Transactions on Computer-Aided Design of Integrated Circuits"}
+
+MACRO {ipl} {"Information Processing Letters"}
+
+MACRO {jacm} {"Journal of the ACM"}
+
+MACRO {jcss} {"Journal of Computer and System Sciences"}
+
+MACRO {scp} {"Science of Computer Programming"}
+
+MACRO {sicomp} {"SIAM Journal on Computing"}
+
+MACRO {tocs} {"ACM Transactions on Computer Systems"}
+
+MACRO {tods} {"ACM Transactions on Database Systems"}
+
+MACRO {tog} {"ACM Transactions on Graphics"}
+
+MACRO {toms} {"ACM Transactions on Mathematical Software"}
+
+MACRO {toois} {"ACM Transactions on Office Information Systems"}
+
+MACRO {toplas} {"ACM Transactions on Programming Languages and Systems"}
+
+MACRO {tcs} {"Theoretical Computer Science"}
+
+% bibinfo.check avoids acting on missing fields while bibinfo.warn will
+% issue a warning message if a missing field is detected. Prior to calling
+% the bibinfo functions, the user should push the field value and then its
+% name string, in that order.
+FUNCTION {bibinfo.check}
+{ swap$
+ duplicate$ missing$
+ {
+ pop$ pop$
+ ""
+ }
+ { duplicate$ empty$
+ {
+ swap$ pop$
+ }
+ { swap$
+ pop$
+ }
+ if$
+ }
+ if$
+}
+FUNCTION {bibinfo.warn}
+{ swap$
+ duplicate$ missing$
+ {
+ swap$ "missing " swap$ * " in " * cite$ * warning$ pop$
+ ""
+ }
+ { duplicate$ empty$
+ {
+ swap$ "empty " swap$ * " in " * cite$ * warning$
+ }
+ { swap$
+ pop$
+ }
+ if$
+ }
+ if$
+}
+INTEGERS { nameptr namesleft numnames }
+
+
+STRINGS { bibinfo}
+
+FUNCTION {format.names}
+{ 'bibinfo :=
+ duplicate$ empty$ 'skip$ {
+ 's :=
+ "" 't :=
+ #1 'nameptr :=
+ s num.names$ 'numnames :=
+ numnames 'namesleft :=
+ { namesleft #0 > }
+ { s nameptr
+ "{ff~}{vv~}{ll}{, jj}" % first name first for all authors
+ format.name$
+ bibinfo bibinfo.check
+ 't :=
+ nameptr #1 >
+ {
+ nameptr #19 % truncate after 19 names
+ #1 + =
+ numnames #20 % if there are more than 20 names
+ > and
+ { "others" 't :=
+ #1 'namesleft := }
+ 'skip$
+ if$ % end truncation of long list of names
+ namesleft #1 >
+ { ", " * t * }
+ {
+ s nameptr "{ll}" format.name$ duplicate$ "others" =
+ { 't := }
+ { pop$ }
+ if$
+ numnames #2 >
+ { "," * }
+ 'skip$
+ if$
+ t "others" =
+ {
+ %% " " * bbl.etal *
+ % compute the number of remaining authors
+ " and " * numnames nameptr - #1 + int.to.str$ * " others" *
+ }
+ {
+ bbl.and
+ space.word * t *
+ }
+ if$
+ }
+ if$
+ }
+ 't
+ if$
+ nameptr #1 + 'nameptr :=
+ namesleft #1 - 'namesleft :=
+ }
+ while$
+ } if$
+}
+FUNCTION {format.names.ed}
+{
+ format.names
+}
+FUNCTION {format.key}
+{ empty$
+ { key field.or.null }
+ { "" }
+ if$
+}
+
+FUNCTION {format.authors}
+{ author "author" format.names
+}
+FUNCTION {get.bbl.editor}
+{ editor num.names$ #1 > 'bbl.editors 'bbl.editor if$ }
+
+FUNCTION {format.editors}
+{ editor "editor" format.names duplicate$ empty$ 'skip$
+ {
+ "," *
+ " " *
+ get.bbl.editor
+ *
+ }
+ if$
+}
+FUNCTION {format.note}
+{
+ note empty$
+ { "" }
+ { note #1 #1 substring$
+ duplicate$ "{" =
+ 'skip$
+ { output.state mid.sentence =
+ { "l" }
+ { "u" }
+ if$
+ change.case$
+ }
+ if$
+ note #2 global.max$ substring$ * "note" bibinfo.check
+ }
+ if$
+}
+
+FUNCTION {format.title}
+{ title
+ duplicate$ empty$ 'skip$
+ { "t" change.case$ }
+ if$
+ "title" bibinfo.check
+}
+FUNCTION {format.full.names}
+{'s :=
+ "" 't :=
+ #1 'nameptr :=
+ s num.names$ 'numnames :=
+ numnames 'namesleft :=
+ { namesleft #0 > }
+ { s nameptr
+ "{vv~}{ll}" format.name$
+ 't :=
+ nameptr #1 >
+ {
+ namesleft #1 >
+ { ", " * t * }
+ {
+ s nameptr "{ll}" format.name$ duplicate$ "others" =
+ { 't := }
+ { pop$ }
+ if$
+ t "others" =
+ {
+ " " * bbl.etal *
+ }
+ {
+ numnames #2 >
+ { "," * }
+ 'skip$
+ if$
+ bbl.and
+ space.word * t *
+ }
+ if$
+ }
+ if$
+ }
+ 't
+ if$
+ nameptr #1 + 'nameptr :=
+ namesleft #1 - 'namesleft :=
+ }
+ while$
+}
+
+FUNCTION {author.editor.key.full}
+{ author empty$
+ { editor empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { editor format.full.names }
+ if$
+ }
+ { author format.full.names }
+ if$
+}
+
+FUNCTION {author.key.full}
+{ author empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { author format.full.names }
+ if$
+}
+
+FUNCTION {editor.key.full}
+{ editor empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { editor format.full.names }
+ if$
+}
+
+FUNCTION {make.full.names}
+{ type$ "book" =
+ type$ "inbook" =
+ or
+ 'author.editor.key.full
+ { type$ "proceedings" =
+ 'editor.key.full
+ 'author.key.full
+ if$
+ }
+ if$
+}
+
+FUNCTION {output.bibitem.original} % urlbst (renamed from output.bibitem, so it can be wrapped below)
+{ newline$
+ "\bibitem[{" write$
+ label write$
+ ")" make.full.names duplicate$ short.list =
+ { pop$ }
+ { * }
+ if$
+ "}]{" * write$
+ cite$ write$
+ "}" write$
+ newline$
+ ""
+ before.all 'output.state :=
+}
+
+FUNCTION {n.dashify}
+{
+ 't :=
+ ""
+ { t empty$ not }
+ { t #1 #1 substring$ "-" =
+ { t #1 #2 substring$ "--" = not
+ { "--" *
+ t #2 global.max$ substring$ 't :=
+ }
+ { { t #1 #1 substring$ "-" = }
+ { "-" *
+ t #2 global.max$ substring$ 't :=
+ }
+ while$
+ }
+ if$
+ }
+ { t #1 #1 substring$ *
+ t #2 global.max$ substring$ 't :=
+ }
+ if$
+ }
+ while$
+}
+
+FUNCTION {word.in}
+{ bbl.in capitalize
+ " " * }
+
+FUNCTION {format.date}
+{ year "year" bibinfo.check duplicate$ empty$
+ {
+ }
+ 'skip$
+ if$
+ extra.label *
+ before.all 'output.state :=
+ after.sentence 'output.state :=
+}
+FUNCTION {format.btitle}
+{ title "title" bibinfo.check
+ duplicate$ empty$ 'skip$
+ {
+ emphasize
+ }
+ if$
+}
+FUNCTION {either.or.check}
+{ empty$
+ 'pop$
+ { "can't use both " swap$ * " fields in " * cite$ * warning$ }
+ if$
+}
+FUNCTION {format.bvolume}
+{ volume empty$
+ { "" }
+ { bbl.volume volume tie.or.space.prefix
+ "volume" bibinfo.check * *
+ series "series" bibinfo.check
+ duplicate$ empty$ 'pop$
+ { swap$ bbl.of space.word * swap$
+ emphasize * }
+ if$
+ "volume and number" number either.or.check
+ }
+ if$
+}
+FUNCTION {format.number.series}
+{ volume empty$
+ { number empty$
+ { series field.or.null }
+ { series empty$
+ { number "number" bibinfo.check }
+ { output.state mid.sentence =
+ { bbl.number }
+ { bbl.number capitalize }
+ if$
+ number tie.or.space.prefix "number" bibinfo.check * *
+ bbl.in space.word *
+ series "series" bibinfo.check *
+ }
+ if$
+ }
+ if$
+ }
+ { "" }
+ if$
+}
+
+FUNCTION {format.edition}
+{ edition duplicate$ empty$ 'skip$
+ {
+ output.state mid.sentence =
+ { "l" }
+ { "t" }
+ if$ change.case$
+ "edition" bibinfo.check
+ " " * bbl.edition *
+ }
+ if$
+}
+INTEGERS { multiresult }
+FUNCTION {multi.page.check}
+{ 't :=
+ #0 'multiresult :=
+ { multiresult not
+ t empty$ not
+ and
+ }
+ { t #1 #1 substring$
+ duplicate$ "-" =
+ swap$ duplicate$ "," =
+ swap$ "+" =
+ or or
+ { #1 'multiresult := }
+ { t #2 global.max$ substring$ 't := }
+ if$
+ }
+ while$
+ multiresult
+}
+FUNCTION {format.pages}
+{ pages duplicate$ empty$ 'skip$
+ { duplicate$ multi.page.check
+ {
+ bbl.pages swap$
+ n.dashify
+ }
+ {
+ bbl.page swap$
+ }
+ if$
+ tie.or.space.prefix
+ "pages" bibinfo.check
+ * *
+ }
+ if$
+}
+FUNCTION {format.journal.pages}
+{ pages duplicate$ empty$ 'pop$
+ { swap$ duplicate$ empty$
+ { pop$ pop$ format.pages }
+ {
+ ":" *
+ swap$
+ n.dashify
+ "pages" bibinfo.check
+ *
+ }
+ if$
+ }
+ if$
+}
+FUNCTION {format.journal.eid}
+{ eid "eid" bibinfo.check
+ duplicate$ empty$ 'pop$
+ { swap$ duplicate$ empty$ 'skip$
+ {
+ ":" *
+ }
+ if$
+ swap$ *
+ }
+ if$
+}
+FUNCTION {format.vol.num.pages}
+{ volume field.or.null
+ duplicate$ empty$ 'skip$
+ {
+ "volume" bibinfo.check
+ }
+ if$
+ number "number" bibinfo.check duplicate$ empty$ 'skip$
+ {
+ swap$ duplicate$ empty$
+ { "there's a number but no volume in " cite$ * warning$ }
+ 'skip$
+ if$
+ swap$
+ "(" swap$ * ")" *
+ }
+ if$ *
+ eid empty$
+ { format.journal.pages }
+ { format.journal.eid }
+ if$
+}
+
+FUNCTION {format.chapter}
+{ chapter empty$
+ 'format.pages
+ { type empty$
+ { bbl.chapter }
+ { type "l" change.case$
+ "type" bibinfo.check
+ }
+ if$
+ chapter tie.or.space.prefix
+ "chapter" bibinfo.check
+ * *
+ }
+ if$
+}
+
+FUNCTION {format.chapter.pages}
+{ chapter empty$
+ 'format.pages
+ { type empty$
+ { bbl.chapter }
+ { type "l" change.case$
+ "type" bibinfo.check
+ }
+ if$
+ chapter tie.or.space.prefix
+ "chapter" bibinfo.check
+ * *
+ pages empty$
+ 'skip$
+ { ", " * format.pages * }
+ if$
+ }
+ if$
+}
+
+FUNCTION {format.booktitle}
+{
+ booktitle "booktitle" bibinfo.check
+ emphasize
+}
+FUNCTION {format.in.booktitle}
+{ format.booktitle duplicate$ empty$ 'skip$
+ {
+ word.in swap$ *
+ }
+ if$
+}
+FUNCTION {format.in.ed.booktitle}
+{ format.booktitle duplicate$ empty$ 'skip$
+ {
+ editor "editor" format.names.ed duplicate$ empty$ 'pop$
+ {
+ "," *
+ " " *
+ get.bbl.editor
+ ", " *
+ * swap$
+ * }
+ if$
+ word.in swap$ *
+ }
+ if$
+}
+FUNCTION {format.thesis.type}
+{ type duplicate$ empty$
+ 'pop$
+ { swap$ pop$
+ "t" change.case$ "type" bibinfo.check
+ }
+ if$
+}
+FUNCTION {format.tr.number}
+{ number "number" bibinfo.check
+ type duplicate$ empty$
+ { pop$ bbl.techrep }
+ 'skip$
+ if$
+ "type" bibinfo.check
+ swap$ duplicate$ empty$
+ { pop$ "t" change.case$ }
+ { tie.or.space.prefix * * }
+ if$
+}
+FUNCTION {format.article.crossref}
+{
+ word.in
+ " \cite{" * crossref * "}" *
+}
+FUNCTION {format.book.crossref}
+{ volume duplicate$ empty$
+ { "empty volume in " cite$ * "'s crossref of " * crossref * warning$
+ pop$ word.in
+ }
+ { bbl.volume
+ capitalize
+ swap$ tie.or.space.prefix "volume" bibinfo.check * * bbl.of space.word *
+ }
+ if$
+ " \cite{" * crossref * "}" *
+}
+FUNCTION {format.incoll.inproc.crossref}
+{
+ word.in
+ " \cite{" * crossref * "}" *
+}
+FUNCTION {format.org.or.pub}
+{ 't :=
+ ""
+ address empty$ t empty$ and
+ 'skip$
+ {
+ t empty$
+ { address "address" bibinfo.check *
+ }
+ { t *
+ address empty$
+ 'skip$
+ { ", " * address "address" bibinfo.check * }
+ if$
+ }
+ if$
+ }
+ if$
+}
+FUNCTION {format.publisher.address}
+{ publisher "publisher" bibinfo.warn format.org.or.pub
+}
+
+FUNCTION {format.organization.address}
+{ organization "organization" bibinfo.check format.org.or.pub
+}
+
+FUNCTION {archiveprefix.or.eprinttype} % holder for eprinttype with archiveprefix precedence
+{
+ archiveprefix empty$
+ {
+ eprinttype empty$
+ { "" } % not using 'skip$ to reduce errors like "nothing to pop from stack"
+ { eprinttype }
+ if$
+ }
+ { archiveprefix }
+ if$
+}
+
+FUNCTION {output.eprint} % this is only used with the @misc record type (common for arXiv and other preprint server bibtex records)
+{
+ eprint empty$
+ {% if eprint field is empty
+ publisher field.or.null "arXiv" = % field.or.null here helps when no publisher field in the record
+ { publisher " preprint" * } % add " preprint" to publisher with the idea that publisher is the name of the preprint server
+ { "" } % if publisher != "arXiv" then empty output
+ if$
+ emphasize % no output function after emphasize because nothing goes after this
+ }
+ {% if eprint field is not empty
+ archiveprefix.or.eprinttype empty$
+ { "" } % not using 'skip$ to reduce errors like "nothing to pop from stack"
+ {% if archiveprefix or eprinttype fields are not empty
+ journal empty$
+ { "Preprint" } % if journal field is empty: output just "Preprint" emphasized like a journal name
+ { journal } % if journal field is not empty, output it (takes precedence)
+ if$
+ emphasize output % emphasize what we formed before, setting output as a border to the subblock that follows with the comma delimiter
+ archiveprefix.or.eprinttype ":" * eprint * % subblock with eprinttype and eprint number
+ }
+ if$
+ }
+ if$
+}
+
+% urlbst...
+% Functions for making hypertext links.
+% In all cases, the stack has (link-text href-url)
+%
+% make 'null' specials
+FUNCTION {make.href.null}
+{
+ pop$
+}
+% make hypertex specials
+FUNCTION {make.href.hypertex}
+{
+ "\special {html: }" * swap$ *
+ "\special {html:}" *
+}
+% make hyperref specials
+FUNCTION {make.href.hyperref}
+{
+ "\href {" swap$ * "} {\path{" * swap$ * "}}" *
+}
+FUNCTION {make.href}
+{ hrefform #2 =
+ 'make.href.hyperref % hrefform = 2
+ { hrefform #1 =
+ 'make.href.hypertex % hrefform = 1
+ 'make.href.null % hrefform = 0 (or anything else)
+ if$
+ }
+ if$
+}
+
+% If inlinelinks is true, then format.url should be a no-op, since it's
+% (a) redundant, and (b) could end up as a link-within-a-link.
+FUNCTION {format.url}
+{ inlinelinks #1 = url empty$ or
+ { "" }
+ { hrefform #1 =
+ { % special case -- add HyperTeX specials
+ urlintro "\url{" url * "}" * url make.href.hypertex * }
+ { urlintro "\url{" * url * "}" * }
+ if$
+ }
+ if$
+}
+FUNCTION {format.eprint}
+{ eprint empty$
+ { "" }
+ { eprintprefix eprint * eprinturl eprint * make.href }
+ if$
+}
+
+FUNCTION {format.doi}
+{ doi empty$
+ { "" }
+ { doi parse.doi % leaves "https://doi.org/DOI" DOI on the stack
+ 's := 't :=
+ doiform #1 =
+ { "\doi{" s * "}" * }
+ { doiprefix s * t make.href }
+ if$
+ }
+ if$
+}
+
+FUNCTION {format.pubmed}
+{ pubmed empty$
+ { "" }
+ { pubmedprefix pubmed * pubmedurl pubmed * make.href }
+ if$
+}
+
+% Output a URL. We can't use the more normal idiom (something like
+% `format.url output'), because the `inbrackets' within
+% format.lastchecked applies to everything between calls to `output',
+% so that `format.url format.lastchecked * output' ends up with both
+% the URL and the lastchecked in brackets.
+FUNCTION {output.url}
+{ url empty$
+ 'skip$
+ { new.block
+ format.url output
+ format.lastchecked output
+ }
+ if$
+}
+
+FUNCTION {output.web.refs}
+{
+ new.block
+ inlinelinks
+ 'skip$ % links were inline -- don't repeat them
+ { % If the generated DOI will be the same as the URL,
+ % then don't print the URL (thanks to Joseph Wright
+ % for (the original version of) this code,
+ % at http://tex.stackexchange.com/questions/5660)
+ adddoi
+ doi empty$ { "X" } { doi parse.doi pop$ } if$ % DOI URL to be generated
+ url empty$ { "Y" } { url } if$ % the URL, or "Y" if empty
+ = % are the strings equal?
+ and
+ 'skip$
+ { output.url }
+ if$
+ addeprints eprint empty$ not and
+ { format.eprint output.nonnull }
+ 'skip$
+ if$
+ adddoi doi empty$ not and
+ { format.doi output.nonnull }
+ 'skip$
+ if$
+ addpubmed pubmed empty$ not and
+ { format.pubmed output.nonnull }
+ 'skip$
+ if$
+ }
+ if$
+}
+
+% Wrapper for output.bibitem.original.
+% If the URL field is not empty, set makeinlinelink to be true,
+% so that an inline link will be started at the next opportunity
+FUNCTION {output.bibitem}
+{ outside.brackets 'bracket.state :=
+ output.bibitem.original
+ inlinelinks url empty$ not doi empty$ not or pubmed empty$ not or eprint empty$ not or and
+ { #1 'makeinlinelink := }
+ { #0 'makeinlinelink := }
+ if$
+}
+
+% Wrapper for fin.entry.original
+FUNCTION {fin.entry}
+{ output.web.refs % urlbst
+ makeinlinelink % ooops, it appears we didn't have a title for inlinelink
+ { possibly.setup.inlinelink % add some artificial link text here, as a fallback
+ linktextstring output.nonnull }
+ 'skip$
+ if$
+ bracket.state close.brackets = % urlbst
+ { "]" * }
+ 'skip$
+ if$
+ fin.entry.original
+}
+
+% Webpage entry type.
+% Title and url fields required;
+% author, note, year, month, and lastchecked fields optional
+% See references
+% ISO 690-2 http://www.nlc-bnc.ca/iso/tc46sc9/standard/690-2e.htm
+% http://www.classroom.net/classroom/CitingNetResources.html
+% http://neal.ctstateu.edu/history/cite.html
+% http://www.cas.usf.edu/english/walker/mla.html
+% for citation formats for web pages.
+FUNCTION {webpage}
+{ output.bibitem
+ author empty$
+ { editor empty$
+ 'skip$ % author and editor both optional
+ { format.editors output.nonnull }
+ if$
+ }
+ { editor empty$
+ { format.authors output.nonnull }
+ { "can't use both author and editor fields in " cite$ * warning$ }
+ if$
+ }
+ if$
+ new.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$
+ format.title "title" output.check
+ inbrackets onlinestring output
+ new.block
+ year empty$
+ 'skip$
+ { format.date "year" output.check }
+ if$
+ % We don't need to output the URL details ('lastchecked' and 'url'),
+ % because fin.entry does that for us, using output.web.refs. The only
+ % reason we would want to put them here is if we were to decide that
+ % they should go in front of the rather miscellaneous information in 'note'.
+ new.block
+ note output
+ fin.entry
+}
+% ...urlbst to here
+
+
+FUNCTION {article}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ crossref missing$
+ {
+ journal
+ "journal" bibinfo.check
+ emphasize
+ "journal" output.check
+ possibly.setup.inlinelink format.vol.num.pages output% urlbst
+ }
+ { format.article.crossref output.nonnull
+ format.pages output
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {book}
+{ output.bibitem
+ author empty$
+ { format.editors "author and editor" output.check
+ editor format.key output
+ }
+ { format.authors output.nonnull
+ crossref missing$
+ { "author and editor" editor either.or.check }
+ 'skip$
+ if$
+ }
+ if$
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle "title" output.check
+ format.edition output
+ crossref missing$
+ { format.bvolume output
+ new.block
+ format.number.series output
+ new.sentence
+ format.publisher.address output
+ }
+ {
+ new.block
+ format.book.crossref output.nonnull
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {booklet}
+{ output.bibitem
+ format.authors output
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ howpublished "howpublished" bibinfo.check output
+ address "address" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {inbook}
+{ output.bibitem
+ author empty$
+ { format.editors "author and editor" output.check
+ editor format.key output
+ }
+ { format.authors output.nonnull
+ crossref missing$
+ { "author and editor" editor either.or.check }
+ 'skip$
+ if$
+ }
+ if$
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle "title" output.check
+ crossref missing$
+ {
+ format.edition output
+ format.bvolume output
+ format.chapter "chapter" output.check
+ new.block
+ format.number.series output
+ new.sentence
+ format.publisher.address output
+ }
+ {
+ format.chapter "chapter" output.check
+ new.block
+ format.book.crossref output.nonnull
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {incollection}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ crossref missing$
+ { format.in.ed.booktitle "booktitle" output.check
+ format.edition output
+ format.bvolume output
+ format.number.series output
+ format.chapter.pages output
+ new.sentence
+ format.publisher.address output
+ }
+ { format.incoll.inproc.crossref output.nonnull
+ format.chapter.pages output
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {inproceedings}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ crossref missing$
+ { format.in.booktitle "booktitle" output.check
+ format.bvolume output
+ format.number.series output
+ format.pages output
+ address "address" bibinfo.check output
+ new.sentence
+ organization "organization" bibinfo.check output
+ publisher "publisher" bibinfo.check output
+ }
+ { format.incoll.inproc.crossref output.nonnull
+ format.pages output
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {conference} { inproceedings }
+FUNCTION {manual}
+{ output.bibitem
+ format.authors output
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle "title" output.check
+ format.edition output
+ organization address new.block.checkb
+ organization "organization" bibinfo.check output
+ address "address" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {mastersthesis}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title
+ "title" output.check
+ new.block
+ bbl.mthesis format.thesis.type output.nonnull
+ school "school" bibinfo.warn output
+ address "address" bibinfo.check output
+ month "month" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {misc}
+{ output.bibitem
+ format.authors output
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title output
+ new.block
+ howpublished "howpublished" bibinfo.check output
+ new.block
+ output.eprint output
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {phdthesis}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle
+ "title" output.check
+ new.block
+ bbl.phdthesis format.thesis.type output.nonnull
+ school "school" bibinfo.warn output
+ address "address" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {presentation}
+{ output.bibitem
+ format.authors output
+ author format.key output
+ new.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title output
+ new.block
+ format.organization.address "organization and address" output.check
+ month "month" output.check
+ year "year" output.check
+ new.block
+ format.note output
+ new.sentence
+ type missing$ 'skip$
+ {"(" type capitalize * ")" * output}
+ if$
+ fin.entry
+}
+
+FUNCTION {proceedings}
+{ output.bibitem
+ format.editors output
+ editor format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle "title" output.check
+ format.bvolume output
+ format.number.series output
+ new.sentence
+ publisher empty$
+ { format.organization.address output }
+ { organization "organization" bibinfo.check output
+ new.sentence
+ format.publisher.address output
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {techreport}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title
+ "title" output.check
+ new.block
+ format.tr.number output.nonnull
+ institution "institution" bibinfo.warn output
+ address "address" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {unpublished}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ format.note "note" output.check
+ fin.entry
+}
+
+FUNCTION {default.type} { misc }
+READ
+FUNCTION {sortify}
+{ purify$
+ "l" change.case$
+}
+INTEGERS { len }
+FUNCTION {chop.word}
+{ 's :=
+ 'len :=
+ s #1 len substring$ =
+ { s len #1 + global.max$ substring$ }
+ 's
+ if$
+}
+FUNCTION {format.lab.names}
+{ 's :=
+ "" 't :=
+ s #1 "{vv~}{ll}" format.name$
+ s num.names$ duplicate$
+ #2 >
+ { pop$
+ " " * bbl.etal *
+ }
+ { #2 <
+ 'skip$
+ { s #2 "{ff }{vv }{ll}{ jj}" format.name$ "others" =
+ {
+ " " * bbl.etal *
+ }
+ { bbl.and space.word * s #2 "{vv~}{ll}" format.name$
+ * }
+ if$
+ }
+ if$
+ }
+ if$
+}
+
+FUNCTION {author.key.label}
+{ author empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { author format.lab.names }
+ if$
+}
+
+FUNCTION {author.editor.key.label}
+{ author empty$
+ { editor empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { editor format.lab.names }
+ if$
+ }
+ { author format.lab.names }
+ if$
+}
+
+FUNCTION {editor.key.label}
+{ editor empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { editor format.lab.names }
+ if$
+}
+
+FUNCTION {calc.short.authors}
+{ type$ "book" =
+ type$ "inbook" =
+ or
+ 'author.editor.key.label
+ { type$ "proceedings" =
+ 'editor.key.label
+ 'author.key.label
+ if$
+ }
+ if$
+ 'short.list :=
+}
+
+FUNCTION {calc.label}
+{ calc.short.authors
+ short.list
+ "("
+ *
+ year duplicate$ empty$
+ short.list key field.or.null = or
+ { pop$ "" }
+ 'skip$
+ if$
+ *
+ 'label :=
+}
+
+FUNCTION {sort.format.names}
+{ 's :=
+ #1 'nameptr :=
+ ""
+ s num.names$ 'numnames :=
+ numnames 'namesleft :=
+ { namesleft #0 > }
+ { s nameptr
+ "{vv{ } }{ll{ }}{ ff{ }}{ jj{ }}"
+ format.name$ 't :=
+ nameptr #1 >
+ {
+ " " *
+ namesleft #1 = t "others" = and
+ { "zzzzz" 't := }
+ 'skip$
+ if$
+ t sortify *
+ }
+ { t sortify * }
+ if$
+ nameptr #1 + 'nameptr :=
+ namesleft #1 - 'namesleft :=
+ }
+ while$
+}
+
+FUNCTION {sort.format.title}
+{ 't :=
+ "A " #2
+ "An " #3
+ "The " #4 t chop.word
+ chop.word
+ chop.word
+ sortify
+ #1 global.max$ substring$
+}
+FUNCTION {author.sort}
+{ author empty$
+ { key empty$
+ { "to sort, need author or key in " cite$ * warning$
+ ""
+ }
+ { key sortify }
+ if$
+ }
+ { author sort.format.names }
+ if$
+}
+FUNCTION {author.editor.sort}
+{ author empty$
+ { editor empty$
+ { key empty$
+ { "to sort, need author, editor, or key in " cite$ * warning$
+ ""
+ }
+ { key sortify }
+ if$
+ }
+ { editor sort.format.names }
+ if$
+ }
+ { author sort.format.names }
+ if$
+}
+FUNCTION {editor.sort}
+{ editor empty$
+ { key empty$
+ { "to sort, need editor or key in " cite$ * warning$
+ ""
+ }
+ { key sortify }
+ if$
+ }
+ { editor sort.format.names }
+ if$
+}
+FUNCTION {presort}
+{ calc.label
+ label sortify
+ " "
+ *
+ type$ "book" =
+ type$ "inbook" =
+ or
+ 'author.editor.sort
+ { type$ "proceedings" =
+ 'editor.sort
+ 'author.sort
+ if$
+ }
+ if$
+ #1 entry.max$ substring$
+ 'sort.label :=
+ sort.label
+ *
+ " "
+ *
+ title field.or.null
+ sort.format.title
+ *
+ #1 entry.max$ substring$
+ 'sort.key$ :=
+}
+
+ITERATE {presort}
+SORT
+STRINGS { last.label next.extra }
+INTEGERS { last.extra.num last.extra.num.extended last.extra.num.blank number.label }
+FUNCTION {initialize.extra.label.stuff}
+{ #0 int.to.chr$ 'last.label :=
+ "" 'next.extra :=
+ #0 'last.extra.num :=
+ "a" chr.to.int$ #1 - 'last.extra.num.blank :=
+ last.extra.num.blank 'last.extra.num.extended :=
+ #0 'number.label :=
+}
+FUNCTION {forward.pass}
+{ last.label label =
+ { last.extra.num #1 + 'last.extra.num :=
+ last.extra.num "z" chr.to.int$ >
+ { "a" chr.to.int$ 'last.extra.num :=
+ last.extra.num.extended #1 + 'last.extra.num.extended :=
+ }
+ 'skip$
+ if$
+ last.extra.num.extended last.extra.num.blank >
+ { last.extra.num.extended int.to.chr$
+ last.extra.num int.to.chr$
+ * 'extra.label := }
+ { last.extra.num int.to.chr$ 'extra.label := }
+ if$
+ }
+ { "a" chr.to.int$ 'last.extra.num :=
+ "" 'extra.label :=
+ label 'last.label :=
+ }
+ if$
+ number.label #1 + 'number.label :=
+}
+FUNCTION {reverse.pass}
+{ next.extra "b" =
+ { "a" 'extra.label := }
+ 'skip$
+ if$
+ extra.label 'next.extra :=
+ extra.label
+ duplicate$ empty$
+ 'skip$
+ { year field.or.null #-1 #1 substring$ chr.to.int$ #65 <
+ { "{\natexlab{" swap$ * "}}" * }
+ { "{(\natexlab{" swap$ * "})}" * }
+ if$ }
+ if$
+ 'extra.label :=
+ label extra.label * 'label :=
+}
+EXECUTE {initialize.extra.label.stuff}
+ITERATE {forward.pass}
+REVERSE {reverse.pass}
+FUNCTION {bib.sort.order}
+{ sort.label
+ " "
+ *
+ year field.or.null sortify
+ *
+ " "
+ *
+ title field.or.null
+ sort.format.title
+ *
+ #1 entry.max$ substring$
+ 'sort.key$ :=
+}
+ITERATE {bib.sort.order}
+SORT
+FUNCTION {begin.bib}
+{ preamble$ empty$
+ 'skip$
+ { preamble$ write$ newline$ }
+ if$
+ "\begin{thebibliography}{" number.label int.to.str$ * "}" *
+ write$ newline$
+ "\providecommand{\natexlab}[1]{#1}"
+ write$ newline$
+}
+EXECUTE {begin.bib}
+EXECUTE {init.urlbst.variables} % urlbst
+EXECUTE {init.state.consts}
+ITERATE {call.type$}
+FUNCTION {end.bib}
+{ newline$
+ "\end{thebibliography}" write$ newline$
+}
+EXECUTE {end.bib}
+%% End of customized bst file
+%%
+%% End of file `acl_natbib_basic.bst'.
diff --git a/writing/custom.bib b/writing/custom.bib
new file mode 100644
index 0000000000..7073d82119
--- /dev/null
+++ b/writing/custom.bib
@@ -0,0 +1,284 @@
+% Use this file for citations not found in the ACL Anthology (contained in "anthology.bib").
+
+@article{schulman2017proximal,
+ title={Proximal policy optimization algorithms},
+ author={Schulman, John and Wolski, Filip and Dhariwal, Prafulla and Radford, Alec and Klimov, Oleg},
+ journal={arXiv preprint arXiv:1707.06347},
+ year={2017}
+}
+
+@article{shao2024deepseekmath,
+ title={Deepseekmath: Pushing the limits of mathematical reasoning in open language models},
+ author={Shao, Zhihong and Wang, Peiyi and Zhu, Qihao and Xu, Runxin and Song, Junxiao and Bi, Xiao and Zhang, Haowei and Zhang, Mingchuan and Li, YK and Wu, Yang and others},
+ journal={arXiv preprint arXiv:2402.03300},
+ year={2024}
+}
+
+@article{yu2025dapo,
+ title={Dapo: An open-source llm reinforcement learning system at scale},
+ author={Yu, Qiying and Zhang, Zheng and Zhu, Ruofei and Yuan, Yufeng and Zuo, Xiaochen and Yue, Yu and Dai, Weinan and Fan, Tiantian and Liu, Gaohong and Liu, Lingjun and others},
+ journal={arXiv preprint arXiv:2503.14476},
+ year={2025}
+}
+
+@article{dongraft,
+ title={RAFT: Reward rAnked FineTuning for Generative Foundation Model Alignment},
+ author={Dong, Hanze and Xiong, Wei and Goyal, Deepanshu and Zhang, Yihan and Chow, Winnie and Pan, Rui and Diao, Shizhe and Zhang, Jipeng and Shum, Kashun and Zhang, Tong},
+ journal={Transactions on Machine Learning Research},
+ year={2023}
+}
+
+@article{wang2025hint,
+ title={HINT: Helping Ineffective Rollouts Navigate Towards Effectiveness},
+ author={Wang, Xinyi and Han, Jinyi and Jiang, Zishang and Li, Tingyun and Liang, Jiaqing and Jiang, Sihang and Dai, Zhaoqian and Ma, Shuguang and Yu, Fei and Xiao, Yanghua},
+ journal={arXiv preprint arXiv:2510.09388},
+ year={2025}
+}
+
+@article{da2025agent,
+ title={Agent-RLVR: Training Software Engineering Agents via Guidance and Environment Rewards},
+ author={Da, Jeff and Wang, Clinton and Deng, Xiang and Ma, Yuntao and Barhate, Nikhil and Hendryx, Sean},
+ journal={arXiv preprint arXiv:2506.11425},
+ year={2025}
+}
+
+@article{lin2025goedel,
+ title={Goedel-prover-v2: Scaling formal theorem proving with scaffolded data synthesis and self-correction},
+ author={Lin, Yong and Tang, Shange and Lyu, Bohan and Yang, Ziran and Chung, Jui-Hui and Zhao, Haoyu and Jiang, Lai and Geng, Yihan and Ge, Jiawei and Sun, Jingruo and others},
+ journal={arXiv preprint arXiv:2508.03613},
+ year={2025}
+}
+
+@article{zheng2025group,
+ title={Group sequence policy optimization},
+ author={Zheng, Chujie and Liu, Shixuan and Li, Mingze and Chen, Xiong-Hui and Yu, Bowen and Gao, Chang and Dang, Kai and Liu, Yuqiong and Men, Rui and Yang, An and others},
+ journal={arXiv preprint arXiv:2507.18071},
+ year={2025}
+}
+
+@article{xi2025bapo,
+ title={BAPO: Stabilizing Off-Policy Reinforcement Learning for LLMs via Balanced Policy Optimization with Adaptive Clipping},
+ author={Xi, Zhiheng and Guo, Xin and Nan, Yang and Zhou, Enyu and Shen, Junrui and Chen, Wenxiang and Liu, Jiaqi and Huang, Jixuan and Zhang, Zhihao and Guo, Honglin and others},
+ journal={arXiv preprint arXiv:2510.18927},
+ year={2025}
+}
+
+@inproceedings{wang2024math,
+ title={Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations},
+ author={Wang, Peiyi and Li, Lei and Shao, Zhihong and Xu, Runxin and Dai, Damai and Li, Yifei and Chen, Deli and Wu, Yu and Sui, Zhifang},
+ booktitle={ACL},
+ pages={9426--9439},
+ year={2024}
+}
+
+@inproceedings{xiong2024watch,
+ title={Watch Every Step! LLM Agent Learning via Iterative Step-level Process Refinement},
+ author={Xiong, Weimin and Song, Yifan and Zhao, Xiutian and Wu, Wenhao and Wang, Xun and Wang, Ke and Li, Cheng and Peng, Wei and Li, Sujian},
+ booktitle={EMNLP},
+ pages={1556--1572},
+ year={2024}
+}
+
+@article{feng2025group,
+ title={Group-in-group policy optimization for llm agent training},
+ author={Feng, Lang and Xue, Zhenghai and Liu, Tingcong and An, Bo},
+ journal={arXiv preprint arXiv:2505.10978},
+ year={2025}
+}
+
+@article{kazemnejad2024vineppo,
+ title={VinePPO: Refining Credit Assignment in RL Training of LLMs},
+ author={Kazemnejad, Amirhossein and Aghajohari, Milad and Portelance, Eva and Sordoni, Alessandro and Reddy, Siva and Courville, Aaron and Roux, Nicolas Le},
+ journal={arXiv preprint arXiv:2410.01679},
+ year={2024}
+}
+
+@article{lee2020stochastic,
+ title={Stochastic latent actor-critic: Deep reinforcement learning with a latent variable model},
+ author={Lee, Alex X and Nagabandi, Anusha and Abbeel, Pieter and Levine, Sergey},
+ journal={Advances in Neural Information Processing Systems},
+ volume={33},
+ pages={741--752},
+ year={2020}
+}
+
+@article{bensal2025reflect,
+ title={Reflect, Retry, Reward: Self-Improving LLMs via Reinforcement Learning},
+ author={Bensal, Shelly and Jamil, Umar and Bryant, Christopher and Russak, Melisa and Kamble, Kiran and Mozolevskyi, Dmytro and Ali, Muayad and AlShikh, Waseem},
+ journal={arXiv preprint arXiv:2505.24726},
+ year={2025}
+}
+
+@article{zheng2025prosperity,
+ title={Prosperity before Collapse: How Far Can Off-Policy RL Reach with Stale Data on LLMs?},
+ author={Zheng, Haizhong and Zhao, Jiawei and Chen, Bedi},
+ journal={arXiv preprint arXiv:2510.01161},
+ year={2025}
+}
+
+@article{liu2025understanding,
+ title={Understanding r1-zero-like training: A critical perspective},
+ author={Liu, Zichen and Chen, Changyu and Li, Wenjun and Qi, Penghui and Pang, Tianyu and Du, Chao and Lee, Wee Sun and Lin, Min},
+ journal={arXiv preprint arXiv:2503.20783},
+ year={2025}
+}
+
+@article{parthasarathi2025grpo,
+ title={GRPO-$\lambda$: Credit Assignment improves LLM Reasoning},
+ author={Parthasarathi, Prasanna and Reymond, Mathieu and Chen, Boxing and Cui, Yufei and Chandar, Sarath},
+ journal={arXiv preprint arXiv:2510.00194},
+ year={2025}
+}
+
+@article{bamba2025xrpo,
+ title={XRPO: Pushing the limits of GRPO with Targeted Exploration and Exploitation},
+ author={Bamba, Udbhav and Fang, Minghao and Yu, Yifan and Zheng, Haizhong and Lai, Fan},
+ journal={arXiv preprint arXiv:2510.06672},
+ year={2025}
+}
+
+@article{wang2025slow,
+ title={Slow-Fast Policy Optimization: Reposition-Before-Update for LLM Reasoning},
+ author={Wang, Ziyan and Wang, Zheng and Fu, Jie and Qu, Xingwei and Cheng, Qi and Tang, Shengpu and Zhang, Minjia and Huo, Xiaoming},
+ journal={arXiv preprint arXiv:2510.04072},
+ year={2025}
+}
+
+@article{cui2025enhancing,
+ title={Enhancing Tool Learning in Large Language Models with Hierarchical Error Checklists},
+ author={Cui, Yue and Yao, Liuyi and Tao, Shuchang and Shi, Weijie and Li, Yaliang and Ding, Bolin and Zhou, Xiaofang},
+ journal={arXiv preprint arXiv:2506.00042},
+ year={2025}
+}
+
+@article{shi2025semantic,
+ title={Semantic-guided Diverse Decoding for Large Language Model},
+ author={Shi, Weijie and Cui, Yue and Wu, Yaguang and Fang, Jingzhi and Zhang, Shibo and Li, Mengze and Han, Sirui and Zhu, Jia and Xu, Jiajie and Zhou, Xiaofang},
+ journal={arXiv preprint arXiv:2506.23601},
+ year={2025}
+}
+
+@article{plaat2025agentic,
+ title={Agentic large language models, a survey},
+ author={Plaat, Aske and van Duijn, Max and van Stein, Niki and Preuss, Mike and van der Putten, Peter and Batenburg, Kees Joost},
+ journal={arXiv preprint arXiv:2503.23037},
+ year={2025}
+}
+
+@article{guo2025deepseek,
+ title={Deepseek-r1: Incentivizing reasoning capability in llms via reinforcement learning},
+ author={Guo, Daya and Yang, Dejian and Zhang, Haowei and Song, Junxiao and Zhang, Ruoyu and Xu, Runxin and Zhu, Qihao and Ma, Shirong and Wang, Peiyi and Bi, Xiao and others},
+ journal={arXiv preprint arXiv:2501.12948},
+ year={2025}
+}
+
+@article{jin2025search,
+ title={Search-r1: Training llms to reason and leverage search engines with reinforcement learning},
+ author={Jin, Bowen and Zeng, Hansi and Yue, Zhenrui and Yoon, Jinsung and Arik, Sercan and Wang, Dong and Zamani, Hamed and Han, Jiawei},
+ journal={arXiv preprint arXiv:2503.09516},
+ year={2025}
+}
+
+@article{nan2025ngrpo,
+ title={Ngrpo: Negative-enhanced group relative policy optimization},
+ author={Nan, Gongrui and Chen, Siye and Huang, Jing and Lu, Mengyu and Wang, Dexun and Xie, Chunmei and Xiong, Weiqi and Zeng, Xianzhou and Zhou, Qixuan and Li, Yadong and others},
+ journal={arXiv preprint arXiv:2509.18851},
+ year={2025}
+}
+
+@article{zhang2025critique,
+ title={Critique-grpo: Advancing llm reasoning with natural language and numerical feedback},
+ author={Zhang, Xiaoying and Sun, Hao and Zhang, Yipeng and Feng, Kaituo and Lu, Chaochao and Yang, Chao and Meng, Helen},
+ journal={arXiv preprint arXiv:2506.03106},
+ year={2025}
+}
+
+@inproceedings{wu2025learning,
+ title={Learning from imperfect demonstrations with self-supervision for robotic manipulation},
+ author={Wu, Kun and Liu, Ning and Zhao, Zhen and Qiu, Di and Li, Jinming and Che, Zhengping and Xu, Zhiyuan and Tang, Jian},
+ booktitle={2025 IEEE International Conference on Robotics and Automation (ICRA)},
+ pages={16899--16906},
+ year={2025},
+ organization={IEEE}
+}
+
+@article{qiu2025agentdistill,
+ title={AgentDistill: Training-Free Agent Distillation with Generalizable MCP Boxes},
+ author={Qiu, Jiahao and Juan, Xinzhe and Wang, Yimin and Yang, Ling and Qi, Xuan and Zhang, Tongcheng and Guo, Jiacheng and Lu, Yifu and Yao, Zixin and Wang, Hongru and others},
+ journal={arXiv preprint arXiv:2506.14728},
+ year={2025}
+}
+
+@article{kujanpaa2025efficient,
+ title={Efficient Knowledge Injection in LLMs via Self-Distillation},
+ author={Kujanp{\"a}{\"a}, Kalle and Marttinen, Pekka and Valpola, Harri and Ilin, Alexander},
+ journal={Transactions on Machine Learning Research},
+ year={2025}
+}
+
+@article{shridhar2020alfworld,
+ title={Alfworld: Aligning text and embodied environments for interactive learning},
+ author={Shridhar, Mohit and Yuan, Xingdi and C{\^o}t{\'e}, Marc-Alexandre and Bisk, Yonatan and Trischler, Adam and Hausknecht, Matthew},
+ journal={arXiv preprint arXiv:2010.03768},
+ year={2020}
+}
+
+@article{yao2022webshop,
+ title={Webshop: Towards scalable real-world web interaction with grounded language agents},
+ author={Yao, Shunyu and Chen, Howard and Yang, John and Narasimhan, Karthik},
+ journal={Advances in Neural Information Processing Systems},
+ volume={35},
+ pages={20744--20757},
+ year={2022}
+}
+
+@article{wang2022scienceworld,
+ title={Scienceworld: Is your agent smarter than a 5th grader?},
+ author={Wang, Ruoyao and Jansen, Peter and C{\^o}t{\'e}, Marc-Alexandre and Ammanabrolu, Prithviraj},
+ journal={arXiv preprint arXiv:2203.07540},
+ year={2022}
+}
+
+@article{cobbe2021training,
+ title={Training verifiers to solve math word problems},
+ author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Chen, Mark and Jun, Heewoo and Kaiser, Lukasz and Plappert, Matthias and Tworek, Jerry and Hilton, Jacob and Nakano, Reiichiro and others},
+ journal={arXiv preprint arXiv:2110.14168},
+ year={2021}
+}
+
+@inproceedings{lightman2023let,
+ title={Let's verify step by step},
+ author={Lightman, Hunter and Kosaraju, Vineet and Burda, Yuri and Edwards, Harrison and Baker, Bowen and Lee, Teddy and Leike, Jan and Schulman, John and Sutskever, Ilya and Cobbe, Karl},
+ booktitle={The Twelfth International Conference on Learning Representations},
+ year={2023}
+}
+
+@article{lewkowycz2022solving,
+ title={Solving quantitative reasoning problems with language models},
+ author={Lewkowycz, Aitor and Andreassen, Anders and Dohan, David and Dyer, Ethan and Michalewski, Henryk and Ramasesh, Vinay and Slone, Ambrose and Anil, Cem and Schlag, Imanol and Gutman-Solo, Theo and others},
+ journal={Advances in neural information processing systems},
+ volume={35},
+ pages={3843--3857},
+ year={2022}
+}
+
+@article{gao2024omni,
+ title={Omni-math: A universal olympiad level mathematic benchmark for large language models},
+ author={Gao, Bofei and Song, Feifan and Yang, Zhe and Cai, Zefan and Miao, Yibo and Dong, Qingxiu and Li, Lei and Ma, Chenghao and Chen, Liang and Xu, Runxin and others},
+ journal={arXiv preprint arXiv:2410.07985},
+ year={2024}
+}
+
+@article{yao2025group,
+ title={Group-relative reinforce is secretly an off-policy algorithm: Demystifying some myths about grpo and its friends},
+ author={Yao, Chaorui and Chen, Yanxi and Sun, Yuchang and Chen, Yushuo and Zhang, Wenhao and Pan, Xuchen and Li, Yaliang and Ding, Bolin},
+ journal={arXiv preprint arXiv:2509.24203},
+ year={2025}
+}
+
+@article{pan2025trinity,
+ title={Trinity-rft: A general-purpose and unified framework for reinforcement fine-tuning of large language models},
+ author={Pan, Xuchen and Chen, Yanxi and Chen, Yushuo and Sun, Yuchang and Chen, Daoyuan and Zhang, Wenhao and Xie, Yuexiang and Huang, Yilun and Zhang, Yilei and Gao, Dawei and others},
+ journal={arXiv preprint arXiv:2505.17826},
+ year={2025}
+}
\ No newline at end of file
diff --git a/writing/figure/exp-main-result.tex b/writing/figure/exp-main-result.tex
new file mode 100644
index 0000000000..672261bb6e
--- /dev/null
+++ b/writing/figure/exp-main-result.tex
@@ -0,0 +1,35 @@
+\begin{table*}[t]
+\centering
+\caption{Main results on agentic environments and mathematical reasoning benchmarks. All metrics are success rates. R$^3$L significantly outperforms the baselines on complex tasks.}
+\label{tab:main_results}
+\resizebox{\textwidth}{!}{% <--- key: force the table to scale to the page width
+\begin{tabular}{ll ccc cccccc}
+\toprule
+& & \multicolumn{3}{c}{\textbf{Agentic Environments}} & \multicolumn{6}{c}{\textbf{Mathematical Reasoning}} \\
+\cmidrule(lr){3-5} \cmidrule(lr){6-11}
+\textbf{Model} & \textbf{Method} & \textbf{ALFWorld} & \textbf{WebShop} & \textbf{ScienceWorld} & \textbf{GSM8K} & \textbf{Math500} & \textbf{MinervaMath} & \textbf{Olympiad} & \textbf{AMC23} & \textbf{DAPO} \\
+\midrule
+
+% --- 1.5B Model Data ---
+\multirow{6}{*}{\texttt{Qwen2.5-1.5B-Ins}}
+& RAFT & 0.826 & 0.450 & 0.001 & 0.434 & 0.204 & 0.051 & 0.053 & 0.125 & 0.086 \\
+& OPMD & 0.835 & 0.224 & 0.016 & 0.465 & 0.292 & 0.063 & 0.123 & 0.125 & 0.070 \\
+& GRPO & 0.720 & 0.260 & 0.049 & 0.474 & 0.304 & 0.099 & 0.090 & 0.200 & 0.136 \\
+& DAPO & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} \\
+& GSPO & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} \\
+& \textbf{R$^3$L (Ours)} & \textbf{0.810} & \textbf{0.355} & \textbf{0.122} & \textbf{0.721} & \textbf{0.439} & \textbf{0.120} & \textbf{0.168} & \textbf{0.250} & \textbf{0.156} \\
+\midrule
+
+% --- 7B Model Data (Placeholders) ---
+\multirow{6}{*}{\texttt{Qwen2.5-7B-Ins}}
+& RAFT & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} \\
+& OPMD & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} \\
+& GRPO & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} \\
+& DAPO & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} \\
+& GSPO & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} & \textit{xx.x} \\
+& \textbf{R$^3$L (Ours)} & \textbf{\textit{xx.x}} & \textbf{\textit{xx.x}} & \textbf{\textit{xx.x}} & \textbf{\textit{xx.x}} & \textbf{\textit{xx.x}} & \textbf{\textit{xx.x}} & \textbf{\textit{xx.x}} & \textbf{\textit{xx.x}} & \textbf{\textit{xx.x}} \\
+
+\bottomrule
+\end{tabular}
+} % <-- end of \resizebox
+\end{table*}
\ No newline at end of file
diff --git a/writing/figure/framework.tex b/writing/figure/framework.tex
new file mode 100644
index 0000000000..003bdefc7f
--- /dev/null
+++ b/writing/figure/framework.tex
@@ -0,0 +1,10 @@
+\begin{figure*}[t]
+\centering
+\begin{minipage}[t]{1\linewidth}
+\centering
+\includegraphics[width=1.0\textwidth]{figure/src/framework.pdf}
+\end{minipage}
+\centering
+\caption{Overview of the framework.}
+\label{fig:architecture}
+\end{figure*}
\ No newline at end of file
diff --git a/writing/figure/reflect_prompt.tex b/writing/figure/reflect_prompt.tex
new file mode 100644
index 0000000000..7673b40cd3
--- /dev/null
+++ b/writing/figure/reflect_prompt.tex
@@ -0,0 +1,149 @@
+\section{Reflection Prompt}
+\label{sec:reflection_prompt}
+
+The reflection prompt guides the model through structured self-analysis to identify failure points and formulate corrective principles.
+
+\begin{strip}
+\begin{tcolorbox}[
+ colback=gray!5,
+ colframe=gray!75,
+ title=Reflection Prompt Template,
+ fonttitle=\bfseries,
+ breakable,
+ width=\textwidth
+]
+\begin{lstlisting}
+# Metacognitive Analyst AI Prompt
+
+You are a Metacognitive Analyst AI. Your core mission is to analyze a "Trajectory Log" containing a series of thoughts and actions. Your goal is to extract deep insights, summarize lessons learned, and formulate actionable principles for future improvement.
+
+You will receive a trajectory log. Your final output must be a structurally complete JSON object.
+
+## Your Internal Monologue & Analysis Protocol (MANDATORY)
+
+You will now begin your structured self-interrogation. Your analysis process must first review the trajectory globally before focusing on key points.
+
+### Part 1: Global Review & Analysis
+
+First, you must understand the entire trajectory from a macro perspective, paying particular attention to the feedback from the user and the environment.
+
+**Question 1.1: Conduct a Panoramic Trajectory Analysis**
+Read through the entire trajectory log and summarize in one or two sentences what the overall strategy was and what result it ultimately led to.
+
+**Question 1.2: Identify Key Issues**
+Based on your global understanding, identify the main problems or inefficiencies in the trajectory. What were the key mistakes or missed opportunities? If the execution was flawless, this is None.
+
+### Part 2: Deep Analysis of Key Issues
+
+Next, you will conduct this deep analysis if and only if key issues were identified in Part 1.
+
+**Question 2.1: Diagnose the Primary Flaw**
+What was the fundamental nature of the primary flaw? Categorize it into ONE of the following:
+- Strategy Flaw: The overall plan was misguided.
+- Reasoning Flaw: The interpretation of information was incorrect.
+- Execution Flaw: The intent was correct, but the resulting action was clumsy or ineffective.
+- Knowledge Gap: Lacked critical information necessary to solve the problem.
+- Inefficiency: The goal was achieved, but via a redundant or convoluted path.
+- Invalid Format: The response was syntactically incorrect or violated protocol.
+
+**Question 2.2: Uncover the Root Cause**
+Conduct a flexible root cause inquiry to uncover the core flawed assumption or problematic mental model that led to the flaw. Continuously ask "Why?" until the most fundamental cause is revealed.
+
+**Question 2.3: Formulate Better Approach**
+What would have been the optimal overall strategy or approach for this task?
+What series of positive effects would likely have followed from using this better approach?
+
+### Part 3: Synthesis, Verdict, and Lessons Learned
+
+Finally, after completing all the above analysis, you will synthesize your findings and render a final judgment.
+
+**Question 3.1: Formulate a Corrective Principle**
+
+Based on the analysis of the key issues above, formulate an impactful Corrective Principle.
+
+**CRITICAL REQUIREMENTS for Principle Formulation:**
+
+1. **Context Completeness**: The principle must be self-contained and include ALL necessary context. It should be understandable and applicable without requiring external knowledge of the specific trajectory.
+ - BAD: "Click operations tend to cause failures"
+ - GOOD: "In the xxx environment, when click operations are not available in the action space, attempting to execute click will cause failures"
+
+2. **Domain Specificity**: Clearly specify the environment, system, or context where this principle applies.
+ - Include environment name
+ - Include relevant constraints or conditions
+
+3. **Causal Chain Awareness**: The principle should consider not just the immediate impact but also downstream consequences.
+ - Consider how the corrective action affects subsequent steps
+ - Anticipate potential cascading effects
+
+4. **Actionable Structure**: The principle should be actionable and clear, typically including:
+ - The specific environment or context
+ - Clear trigger conditions or situations
+ - The recommended action or approach
+ - The reasoning and expected benefits
+
+ **Note**: The exact format can vary based on the nature of the insight. It could be a prescriptive rule ("When X, do Y"), a cautionary guideline ("Avoid X in situation Y"), or a strategic insight ("Prioritize X because Y"). Choose the format that best captures the lesson learned.
+
+5. **Independence Test**: The principle should be meaningful and correct even if read in isolation, without access to the original trajectory.
+
+**Question 3.2: Render the Final Verdict**
+
+Now, and only now, based on your complete analysis, classify the outcome of this task into one of the following more precise categories:
+
+- **OPTIMAL**: Flawlessly and efficiently achieved the goal; a textbook execution.
+- **SUBOPTIMAL_SUCCESS**: Achieved the goal, but with correctable inefficiencies or minor flaws.
+- **PARTIAL**: Made significant progress but did not fully meet the final goal.
+- **INEFFECTIVE**: Fully failed to achieve the primary goal.
+
+## Final Output Format (Strictly Adhere to the Unified Schema)
+
+Your final output must strictly contain the following two parts: Part One is your detailed analysis process (in text form), and Part Two is the summary JSON report.
+
+### Part One: Detailed Analysis Process
+
+You must answer all questions from the protocol one by one here, showing your complete chain of thought.
+
+**1. Global Review & Analysis**
+- 1.1 Panoramic Trajectory Analysis: Fill in your macro summary of the trajectory here
+- 1.2 Key Issues Identification: Fill in the identified key issues and the reasoning here
+
+**2. Deep Analysis of Key Issues**
+- 2.1 Primary Flaw Diagnosis: Fill in the flaw's classification here
+- 2.2 Root Cause: Fill in the result of the root cause inquiry here
+- 2.3 Better Approach: Fill in the analysis of the optimal strategy and its expected benefits here
+
+**3. Synthesis, Verdict, and Lessons Learned**
+- 3.1 Corrective Principle: Fill in the corrective principle you formulated here (MUST meet all 5 critical requirements)
+- 3.2 Final Verdict: Fill in the final classification verdict you rendered here
+
+### Part Two: Final JSON Report
+
+After completing the detailed analysis above, synthesize all conclusions and populate the following JSON structure (the ```json prefix is mandatory before the report):
+
+{
+ "outcome_assessment": "OPTIMAL | SUBOPTIMAL_SUCCESS | PARTIAL | INEFFECTIVE",
+ "analysis": {
+ "summary": "Summary of the trajectory's strategy, outcome, and core insight.",
+ "flaw_analysis": {
+ "diagnosis": {
+ "category": "Strategy Flaw | Reasoning Flaw | Execution Flaw | Knowledge Gap | Inefficiency | null",
+ "root_cause": "The core flawed assumption or problematic mental model that was uncovered. Can be null if outcome_assessment is OPTIMAL."
+ },
+ "better_approach": {
+ "strategy": "The optimal overall strategy or approach that should have been used. Can be null if outcome_assessment is OPTIMAL.",
+ "key_differences": "A brief explanation of how this better approach differs from the original approach. Can be null if outcome_assessment is OPTIMAL.",
+ "projected_benefits": "The series of positive effects projected to occur from using the better approach. Can be null if outcome_assessment is OPTIMAL."
+ }
+ },
+ "lessons_learned": {
+ "corrective_principle": "A self-contained, context-complete principle that includes environment specifics, clear trigger conditions, and considers downstream effects. Must be understandable and applicable in isolation. Can be null if outcome_assessment is OPTIMAL.",
+ "revised_action_plan": "The improved action plan based on the corrective principle, considering both immediate and downstream impacts. Can be null if outcome_assessment is OPTIMAL."
+ },
+ "retry_strategy": {
+ "retry_step": "The specific step that should be retried. Can be null if outcome_assessment is OPTIMAL. Range is 0 to N-1, where N is the total number of steps in the trajectory, 0 means restart from beginning.",
+ "retry_rationale": "Explanation of why this step was chosen as restart point"
+ }
+ }
+}
+\end{lstlisting}
+\end{tcolorbox}
+\end{strip}
\ No newline at end of file
diff --git "a/writing/figure/src/framework - \345\211\257\346\234\254.pptx" "b/writing/figure/src/framework - \345\211\257\346\234\254.pptx"
new file mode 100644
index 0000000000..b73fcbc4ae
Binary files /dev/null and "b/writing/figure/src/framework - \345\211\257\346\234\254.pptx" differ
diff --git a/writing/figure/src/framework.pdf b/writing/figure/src/framework.pdf
new file mode 100644
index 0000000000..09b4abdfbb
Binary files /dev/null and b/writing/figure/src/framework.pdf differ
diff --git a/writing/figure/src/framework.pptx b/writing/figure/src/framework.pptx
new file mode 100644
index 0000000000..de6b445c0c
Binary files /dev/null and b/writing/figure/src/framework.pptx differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/README" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/README"
new file mode 100644
index 0000000000..7e59600739
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/README"
@@ -0,0 +1 @@
+# README
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/README.md" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/README.md"
new file mode 100644
index 0000000000..025577f497
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/README.md"
@@ -0,0 +1,53 @@
+# *ACL Paper Styles
+
+This directory contains the latest LaTeX and Word templates for *ACL
+conferences.
+
+## Instructions for authors
+
+Paper submissions to *ACL conferences must use the official ACL style
+templates.
+
+The LaTeX style files are available
+
+- as an [Overleaf template](https://www.overleaf.com/latex/templates/association-for-computational-linguistics-acl-conference/jvxskxpnznfj)
+- in this repository, in the [`latex`](https://github.com/acl-org/acl-style-files/blob/master/latex) subdirectory
+- as a [.zip file](https://github.com/acl-org/acl-style-files/archive/refs/heads/master.zip)
+
+Please see [`latex/acl_latex.tex`](https://github.com/acl-org/acl-style-files/blob/master/acl_latex.tex) for an example.
+
+The Microsoft Word template is available in this repository at [`word/acl.docx`](https://github.com/acl-org/acl-style-files/blob/master/word/acl.docx).
+
+Please follow the paper formatting guidelines general to *ACL
+conferences:
+
+- [Paper formatting guidelines](https://acl-org.github.io/ACLPUB/formatting.html)
+
+Authors may not modify these style files or use templates designed for
+other conferences.
+
+## Instructions for publications chairs
+
+To adapt the style files for your conference, please fork this repository and
+make necessary changes. Minimally, you'll need to update the name of
+the conference and rename the files.
+
+If you make improvements to the templates that should be propagated to
+future conferences, please submit a pull request. Thank you in
+advance!
+
+In older versions of the templates, authors were asked to fill in the
+START submission ID so that it would be stamped at the top of each
+page of the anonymized version. This is no longer needed, because it
+is now possible to do this stamping automatically within
+START. Currently, the way to do this is for the program chair to email
+support@softconf.com and request it.
+
+## Instructions for making changes to style files
+
+- Merge the pull request on GitHub, or push to GitHub.
+- Run `git pull` from GitHub into a local repository.
+- Then run `git push` from your local repository to the Overleaf project.
+  - Overleaf project: https://www.overleaf.com/project/5f64f1fb97c4c50001b60549
+  - Overleaf git URL: https://git.overleaf.com/5f64f1fb97c4c50001b60549
+- Then click "Submit" and "Submit as Template" in Overleaf so that the Overleaf template is updated from the Overleaf project.
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/anthology.bib.txt" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/anthology.bib.txt"
new file mode 100644
index 0000000000..14f228c6eb
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/anthology.bib.txt"
@@ -0,0 +1,8 @@
+% Please download the latest anthology.bib from the following URL:
+%
+% http://aclweb.org/anthology/anthology.bib
+%
+% From the command line, this can be done with curl or wget.
+%
+% If you are using Overleaf, go to "New File -> From External URL".
+% You will then be able to use it directly, and to periodically update it by clicking Overleaf's convenient "refresh" button.
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/formatting.md" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/formatting.md"
new file mode 100644
index 0000000000..eeb1ce1548
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/formatting.md"
@@ -0,0 +1,326 @@
+# Instructions for *ACL Proceedings
+
+The following instructions are for authors of papers submitted for review to ACL conferences (hereafter, "review version") or paper accepted for publication in its proceedings (hereafter, "final version").
+All authors are required to adhere to these specifications.
+
+## Style Files
+
+*ACL provides style files for LaTeX and Microsoft Word that meet these requirements. They can be found at:
+
+> https://acl-org.github.io/ACLPUB/
+
+We strongly recommend the use of these style files, which have been appropriately tailored for the *ACL proceedings.
+
+## Paper Length
+
+The conference accepts submissions of long papers and short papers.
+Review versions of long papers may have up to eight (8) pages of content plus unlimited pages for references.
+Upon acceptance, final versions of long papers will be given one additional page -- up to nine (9) pages of content plus unlimited pages for acknowledgements and references -- so that reviewers' comments can be taken into account.
+Review versions of short papers may have up to four (4) pages of content, plus unlimited pages for references.
+Final versions of short papers may have up to five (5) pages, plus unlimited pages for acknowledgements and references.
+For both long and short papers, all figures and tables that are part of the main text must fit within these page limits.
+
+The conference encourages submission of appendices and supplementary material, which are not required to fit within these page limits. However, review versions of papers must be self-contained: it is optional for reviewers to look at appendices or supplementary material. Please see [Appendices](#appendices) and [Supplementary Material](#supplementary-material) for more information.
+
+Review versions should not refer, for further detail, to documents, code or data resources that are not available to the reviewers.
+
+Papers that do not conform to these requirements may be rejected without review.
+
+Workshop chairs may have different rules for allowed length and whether appendices or supplementary materials are welcome.
+As always, the respective call for papers is the authoritative source.
+
+## Anonymity
+
+As reviewing will be double-blind, review versions must not include any identifying information about the authors (such as names, affiliations, or URLs).
+Self-references that reveal the author's identity, e.g.,
+
+> We previously showed (Gusfield, 1997)...
+
+must be avoided, and anonymous citations, e.g.,
+
+> We previously showed (Anonymous, 1997)...
+
+should also be avoided. Instead, use citations such as
+
+> Gusfield (1997) previously showed...
+
+Review versions must not include acknowledgements.
+
+**Papers that do not conform to these requirements may be rejected without review.**
+
+Any preliminary non-archival versions of submitted papers should be listed in the submission form but not in the review version of the paper.
+Reviewers are generally aware that authors may present preliminary versions of their work in other venues, but will not be provided the list of previous presentations from the submission form.
+
+Once a paper has been accepted to the conference, the final version should include the author's names and affiliations, and is allowed to use self-references.
+
+## Multiple Submission
+
+Papers that have been or will be submitted to other meetings or publications must indicate this at submission time in the START submission form, and must be withdrawn from the other venues if accepted by *ACL.
+Authors of papers accepted for presentation at *ACL must notify the program chairs by the deadline for final versions ("camera-ready deadline") whether the paper will be presented.
+We will not accept for publication or presentation any papers that overlap significantly in content or results with papers that will be (or have been) published elsewhere.
+
+Authors submitting more than one paper to *ACL must ensure that submissions do not overlap significantly (>25%) with each other in content or results.
+
+## Formatting Instructions
+
+### File Format
+
+Papers must be in Adobe Portable Document Format (PDF).
+Please make sure that your PDF file embeds all necessary fonts (especially for tree diagrams, symbols, and Asian languages).
+When you print or create the PDF file, there is usually an option in your printer setup to include none, all or just non-standard fonts.
+Please make sure that you select the option of including *all* the fonts.
+**Before sending it, test your PDF by printing it from a computer different from the one where it was created.**
+
+Some word processors may generate very large PDF files, where each page is rendered as an image.
+Such images may reproduce poorly.
+In this case, try alternative ways to obtain the PDF.
+
+All papers must use **A4 paper format** (21 cm x 29.7 cm).
+Papers must not be submitted with any other paper size.
+
+If you cannot meet the above requirements, please contact the publication chairs as soon as possible.
+
+### Layout
+
+All text except for page numbers must fit within the margins.
+
+Review versions should have page numbers, centered in the bottom margin, but **pages should not be numbered in the final version.**
+
+Manuscripts must be set in two columns.
+Exceptions to the two-column format include the title, authors' names and complete addresses, which must be centered at the top of the first page, and any full-width figures or tables.
+
+The exact dimensions for a page on A4 paper are:
+
+* Left margin: 2.5 cm
+* Right margin: 2.5 cm
+* Top margin: 2.5 cm
+* Bottom margin: 2.5 cm
+* Column width: 7.7 cm
+* Column height: 24.7 cm
+* Gap between columns: 0.6 cm
+
+In the review version, a ruler (line numbers in the left and right margins of the article) should be printed, so that reviewers may comment on particular lines in the paper.
+The ruler should not change the appearance of any other content on the page.
+The final version should not contain a ruler.
+
+### Fonts
+
+All text (except non-Latin scripts and mathematical formulas) should be set in **Times Roman**.
+If Times Roman is unavailable, you may use **Times New Roman** or **Computer Modern Roman.**
+
+The following table specifies what font sizes and styles must be used for each type of text in the manuscript.
+
+| Type of Text | Font Size | Style |
+| --------------------- | --------- | ----- |
+| paper title | 15 pt | bold |
+| author names | 12 pt | bold |
+| author affiliation | 12 pt | |
+| the word "Abstract" | 12 pt | bold |
+| section titles | 12 pt | bold |
+| subsection titles | 11 pt | bold |
+| document text | 11 pt | |
+| captions | 10 pt | |
+| abstract text | 10 pt | |
+| bibliography | 10 pt | |
+| footnotes | 9 pt | |
+
+### Title and Authors
+
+Center the title, author's name(s) and affiliation(s) across both columns.
+
+Place the title centered at the top of the first page, in 15-point bold.
+Long titles should be typed on two lines without a blank line intervening.
+Put the title 2.5 cm from the top of the page.
+Write the title in [title case](https://apastyle.apa.org/style-grammar-guidelines/capitalization/title-case); do not write the title in all capital letters, except for acronyms (e.g., "BLEU") or proper nouns ("English") that are normally uppercased or capitalized.
+
+Place the author name(s) and affiliation(s) under the title.
+Write authors' full names; do not abbreviate given names to initials, unless they are normally written as initials ("Margaret Mitchell", not "M. Mitchell").
+Do not format surnames in all capitals ("Mitchell", not "MITCHELL").
+
+Do not use footnotes for affiliations.
+The affiliation should contain the author's complete address, and if possible, an electronic mail address.
+
+The title, author names and addresses should be completely identical to those entered to the paper submission website in order to maintain the consistency of author information among all publications of the conference.
+If they are different, the publication chairs may resolve the difference without consulting with you; so it is in your own interest to double-check that the information is consistent.
+
+Start the body of the first page 7.5 cm from the top of the page.
+**Even in the review version of the paper, you should maintain space for names and addresses so that they will fit in the final version.**
+
+### Abstract
+
+Type the abstract at the beginning of the first column.
+Center the word **Abstract** in 12 point bold above the body of the abstract.
+The width of the abstract should be smaller than the
+normal column width by 0.6 cm on each side.
+The abstract text should be 10 point roman, single-spaced.
+
+The abstract should be a concise summary of the general thesis and conclusions of the paper.
+It should be no longer than 200 words.
+
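+As a quick illustration of the title block and abstract described above, here is a minimal, hypothetical skeleton using a standard 11 pt article class with the `acl` package from this directory (all titles, names, and addresses are placeholders, not taken from any actual submission):
+
+```latex
+% Minimal sketch; assumes acl.sty (and acl_natbib.bst for BibTeX) is on the TeX path.
+\documentclass[11pt]{article}
+\usepackage[review]{acl}  % switch to [final] for the camera-ready version
+
+\title{A Hypothetical Paper Title in Title Case}
+\author{First Author \\ Affiliation One \\ \texttt{first@example.org}
+  \And
+  Second Author \\ Affiliation Two \\ \texttt{second@example.org}}
+
+\begin{document}
+\maketitle
+\begin{abstract}
+A concise summary of the paper, set by the style file in 10 point with the
+narrower margins described above; keep it under 200 words.
+\end{abstract}
+
+\section{Introduction}
+Body text in two columns begins here.
+
+\end{document}
+```
+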
+### Text
+
+Begin typing the main body of the text immediately after the abstract, continuing in two columns.
+The text should be 11 point roman, single-spaced.
+
+Indent 0.4 cm when starting a new paragraph, except for the first paragraph in a section.
+
+### Sections
+
+Use numbered sections (Arabic numerals) to facilitate cross references.
+Number subsections with the section number and the subsection number separated by a dot, in Arabic numerals, e.g.,
+
+> 1 Introduction
+
+or
+
+> 6.1 File Format
+
+### Footnotes
+Put footnotes at the bottom of the page and use 9 point font.
+They may be numbered or referred to by asterisks or other symbols.
+Footnotes should be separated from the text by a line.
+
+### Figures and tables
+
+Place figures and tables in the paper near where they are first discussed, rather than at the end, if possible.
+Wide figures/tables may run across both columns.
+
+To accommodate people who are color-blind (as well as those printing with black-and-white printers), grayscale readability is strongly encouraged.
+Color is not forbidden, but authors should ensure that tables and figures do not rely solely on color to convey critical distinctions.
+
+**Captions:**
+Provide a caption for every figure/table; number each one sequentially in the form:
+
+> Figure 1: Caption of the Figure.
+
+and
+
+> Table 1: Caption of the Table.
+
+Captions should be placed below figures/tables, in 10 point roman type.
+Captions that are one line are centered.
+Captions longer than one line are left-aligned.
+
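+For reference, a minimal sketch of a captioned, single-column table that follows the numbering and placement rules above (the methods and scores are placeholders; a full-width table would use the `table*` environment instead, as in `figure/exp-main-result.tex`):
+
+```latex
+% The caption goes below the table; acl.sty sets captions in 10 pt.
+\begin{table}[t]
+  \centering
+  \begin{tabular}{lc}
+    \hline
+    \textbf{Method} & \textbf{Score} \\
+    \hline
+    Hypothetical baseline & 0.50 \\
+    Hypothetical method   & 0.75 \\
+    \hline
+  \end{tabular}
+  \caption{Caption of the Table.}
+  \label{tab:example}
+\end{table}
+```
+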
+### Hyperlinks
+
+Within-document and external hyperlinks should be dark blue (hex #000099), not underlined or boxed.
+
+### Non-English Text
+
+Text in languages other than English should be accompanied by translations into English, and text in scripts other than Latin should *also* be accompanied by transliterations into Latin script, since not all readers can recognize non-Latin characters easily.
+
+For example, παράδειγμα *paradeigma* ‘example’ is a Greek word, and this is a Greek sentence:
+
+> Αυτό είναι ένα παράδειγμα.
+> auto einai ena paradeigma.
+> ‘This is an example.’
+
+### Citations
+
+Citations within the text appear in parentheses (Gusfield, 1997), or, if the author's name appears in the text itself: Gusfield (1997).
+Append lowercase letters to the year in cases of ambiguities.
+Cite papers with two authors using both authors' names (Aho and Ullman, 1972), but cite papers with more than two authors by the first author's name and "et al." (Chandra et al., 1981).
+Collapse multiple citations into a single pair of parentheses (Gusfield, 1997; Aho and Ullman, 1972).
+
+Refrain from using full citations as sentence constituents.
+Instead of
+
+> (Gusfield, 1997) showed that ...
+> In (Gusfield, 1997), ...
+
+write
+
+> Gusfield (1997) showed that ...
+> In Gusfield (1997), ...
+
+Submissions should accurately reference prior and related work, including code and data.
+If a piece of prior work appeared in multiple venues, the version that appeared in a refereed, archival venue should be referenced.
+If multiple versions of a piece of prior work exist, the one used by the authors should be referenced.
+
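+The `acl.sty` in this repository exposes the usual natbib commands for producing these citation forms; a small sketch (the BibTeX keys are hypothetical and must exist in your `.bib` files):
+
+```latex
+% \citep and \citet come from natbib (loaded by acl.sty);
+% \citeposs is an acl.sty convenience macro for the possessive form.
+\citep{Gusfield:97}                  % -> (Gusfield, 1997)
+\citet{Gusfield:97} showed that ...  % -> Gusfield (1997) showed that ...
+\citep{Gusfield:97,Aho:72}           % -> (Gusfield, 1997; Aho and Ullman, 1972)
+\citeposs{Gusfield:97} algorithm     % -> Gusfield's (1997) algorithm
+```
+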
+### Acknowledgments
+
+The acknowledgments should go immediately before the references.
+Do not number the acknowledgments section.
+Do not include this section in the review version.
+
+### References
+
+Gather the full set of references together under the unnumbered section heading **References**.
+Place the References section before any Appendices.
+Arrange the references alphabetically by first author, rather than by order of occurrence in the text.
+
+Provide as complete a citation as possible, using a consistent format, such as the [one for Computational Linguistics](http://cljournal.org/style_guide_refs.html) or the one in the [Publication Manual of the American Psychological Association](https://apastyle.apa.org/products/publication-manual-7th-edition).
+Use full names for authors, not just initials.
+Authors should not rely on automated citation indices to provide accurate references for prior and related work.
+
+As part of our work to make ACL materials more widely used and cited outside of our discipline, ACL has registered as a CrossRef member, as a registrant of Digital Object Identifiers (DOIs), the standard for registering permanent URNs for referencing scholarly materials.
+
+All references are required to contain DOIs of all cited works when possible, or, as a second resort, links to ACL Anthology pages.
+Appropriate records should be found for most materials in the current [ACL Anthology](https://aclweb.org/anthology/).
+
+Example article in a journal:
+
+> Rie Kubota Ando and Tong Zhang. 2005. [A framework for learning predictive structures from multiple tasks and unlabeled data](https://www.jmlr.org/papers/v6/ando05a.html). *Journal of Machine Learning Research*, 6:1817–1853.
+
+Example paper in non-ACL proceedings, with DOI:
+
+> Galen Andrew and Jianfeng Gao. 2007. [Scalable training of L1-regularized log-linear models](https://doi.org/10.1145/1273496.1273501). In *Proceedings of the 24th International Conference on Machine Learning*, pages 33–40.
+
+Example ACL Anthology paper with DOI:
+
+> James Goodman, Andreas Vlachos, and Jason Naradowsky. 2016. [Noise reduction and targeted exploration in imitation learning for Abstract Meaning Representation parsing](http://dx.doi.org/10.18653/v1/P16-1001). In *Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)*, pages 1–11, Berlin, Germany. Association for Computational Linguistics.
+
+Example ACL Anthology paper without DOI:
+
+> Benjamin Börschinger and Mark Johnson. 2011. [A particle filter algorithm for Bayesian word segmentation](https://www.aclweb.org/anthology/U11-1004/). In *Proceedings of the Australasian Language Technology Association Workshop 2011*, pages 10–18, Canberra, Australia.
+
+Example arXiv paper:
+
+> Mohammad Sadegh Rasooli and Joel R. Tetreault. 2015. [Yara parser: A fast and accurate dependency parser](http://arxiv.org/abs/1503.06733). *Computing Research Repository*, arXiv:1503.06733. Version 2.
+
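+To produce this References section with the style files above, the paper only needs to point BibTeX at its databases, since `acl.sty` already selects the `acl_natbib` bibliography style. A minimal sketch, assuming the databases are the `anthology.bib` and `custom.bib` files used in this repository:
+
+```latex
+% Placed where the reference list should appear, after the acknowledgements.
+\bibliography{anthology,custom}
+```
+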
+## Appendices
+
+Appendices are material that can be read, and include lemmas, formulas, proofs, and tables that are not critical to the reading and understanding of the paper.
+Letter them in sequence and provide an informative title:
+
+> Appendix A. Title of Appendix
+
+The appendices come after the references.
+
+Review versions of appendices must follow the same anonymity guidelines as the main paper.
+
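+A minimal sketch of a lettered appendix, using the standard LaTeX `\appendix` switch so that subsequent sections are lettered A, B, and so on (title and content are placeholders):
+
+```latex
+\appendix
+\section{Title of Appendix}
+\label{sec:appendix}
+Lemmas, proofs, or extra tables that are not critical to the main text.
+```
+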
+## Supplementary Material
+
+Submissions may include non-readable supplementary material used in the work and described in the paper.
+Any accompanying software and/or data should include licenses and documentation of research review as appropriate.
+Supplementary material may report preprocessing decisions, model parameters, and other details necessary for the replication of the experiments reported in the paper.
+Seemingly small preprocessing decisions can sometimes make a large difference in performance, so it is crucial to record such decisions to precisely characterize state-of-the-art methods.
+
+Nonetheless, supplementary material should be supplementary (rather than central) to the paper.
+**Submissions that misuse the supplementary material may be rejected without review.**
+Supplementary material may include explanations or details of proofs or derivations that do not fit into the paper, lists of features or feature templates, sample inputs and outputs for a system, pseudo-code or source code, and data.
+(Source code and data should be separate uploads, rather than part of the paper).
+
+The paper should not rely on the supplementary material: while the paper may refer to and cite the supplementary material and the supplementary material will be available to the reviewers, they will not be asked to review the supplementary material.
+
+Review versions of supplementary material must follow the same anonymity guidelines as the main paper.
+
+## Credits
+
+This document has been adapted from the instructions for earlier ACL and NAACL proceedings, including those for
+ACL 2020 by Steven Bethard, Ryan Cotterell and Rui Yan,
+ACL 2019 by Douwe Kiela and Ivan Vulić,
+NAACL 2019 by Stephanie Lukin and Alla Roskovskaya,
+ACL 2018 by Shay Cohen, Kevin Gimpel, and Wei Lu,
+NAACL 2018 by Margaret Mitchell and Stephanie Lukin,
+BibTeX suggestions for (NA)ACL 2017/2018 from Jason Eisner,
+ACL 2017 by Dan Gildea and Min-Yen Kan,
+NAACL 2017 by Margaret Mitchell,
+ACL 2012 by Maggie Li and Michael White,
+ACL 2010 by Jing-Shin Chang and Philipp Koehn,
+ACL 2008 by Johanna D. Moore, Simone Teufel, James Allan, and Sadaoki Furui,
+ACL 2005 by Hwee Tou Ng and Kemal Oflazer,
+ACL 2002 by Eugene Charniak and Dekang Lin,
+and earlier ACL and EACL formats written by several people, including
+John Chen, Henry S. Thompson and Donald Walker.
+Additional elements were taken from the formatting instructions of the *International Joint Conference on Artificial Intelligence* and the *Conference on Computer Vision and Pattern Recognition*.
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/acl.sty" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/acl.sty"
new file mode 100644
index 0000000000..c494e0a838
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/acl.sty"
@@ -0,0 +1,321 @@
+% This is the LaTex style file for *ACL.
+% The official sources can be found at
+%
+% https://github.com/acl-org/acl-style-files/
+%
+% This package is activated by adding
+%
+% \usepackage{acl}
+%
+% to your LaTeX file. When submitting your paper for review, add the "review" option:
+%
+% \usepackage[review]{acl}
+
+\newif\ifacl@finalcopy
+\newif\ifacl@anonymize
+\newif\ifacl@linenumbers
+\newif\ifacl@pagenumbers
+\DeclareOption{final}{\acl@finalcopytrue\acl@anonymizefalse\acl@linenumbersfalse\acl@pagenumbersfalse}
+\DeclareOption{review}{\acl@finalcopyfalse\acl@anonymizetrue\acl@linenumberstrue\acl@pagenumberstrue}
+\DeclareOption{preprint}{\acl@finalcopytrue\acl@anonymizefalse\acl@linenumbersfalse\acl@pagenumberstrue}
+\ExecuteOptions{final} % final copy is the default
+
+% include hyperref, unless user specifies nohyperref option like this:
+% \usepackage[nohyperref]{acl}
+\newif\ifacl@hyperref
+\DeclareOption{hyperref}{\acl@hyperreftrue}
+\DeclareOption{nohyperref}{\acl@hyperreffalse}
+\ExecuteOptions{hyperref} % default is to use hyperref
+\ProcessOptions\relax
+
+\typeout{Conference Style for ACL}
+
+\usepackage{xcolor}
+
+\ifacl@linenumbers
+ % Add draft line numbering via the lineno package
+ % https://texblog.org/2012/02/08/adding-line-numbers-to-documents/
+ \usepackage[switch,mathlines]{lineno}
+
+ % Line numbers in gray Helvetica 8pt
+ \font\aclhv = phvb at 8pt
+ \renewcommand\linenumberfont{\aclhv\color{lightgray}}
+
+ % Zero-fill line numbers
+ % NUMBER with left flushed zeros \fillzeros[]
+ \newcount\cv@tmpc@ \newcount\cv@tmpc
+ \def\fillzeros[#1]#2{\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi
+ \cv@tmpc=1 %
+ \loop\ifnum\cv@tmpc@<10 \else \divide\cv@tmpc@ by 10 \advance\cv@tmpc by 1 \fi
+ \ifnum\cv@tmpc@=10\relax\cv@tmpc@=11\relax\fi \ifnum\cv@tmpc@>10 \repeat
+ \ifnum#2<0\advance\cv@tmpc1\relax-\fi
+ \loop\ifnum\cv@tmpc<#1\relax0\advance\cv@tmpc1\relax\fi \ifnum\cv@tmpc<#1 \repeat
+ \cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi \relax\the\cv@tmpc@}%
+ \renewcommand\thelinenumber{\fillzeros[3]{\arabic{linenumber}}}
+ \linenumbers
+
+ \setlength{\linenumbersep}{1.6cm}
+
+ % Bug: An equation with $$ ... $$ isn't numbered, nor is the previous line.
+
+ % Patch amsmath commands so that the previous line and the equation itself
+ % are numbered. Bug: multline has an extra line number.
+ % https://tex.stackexchange.com/questions/461186/how-to-use-lineno-with-amsmath-align
+ \usepackage{etoolbox} %% <- for \pretocmd, \apptocmd and \patchcmd
+
+ \newcommand*\linenomathpatch[1]{%
+ \expandafter\pretocmd\csname #1\endcsname {\linenomath}{}{}%
+ \expandafter\pretocmd\csname #1*\endcsname {\linenomath}{}{}%
+ \expandafter\apptocmd\csname end#1\endcsname {\endlinenomath}{}{}%
+ \expandafter\apptocmd\csname end#1*\endcsname {\endlinenomath}{}{}%
+ }
+ \newcommand*\linenomathpatchAMS[1]{%
+ \expandafter\pretocmd\csname #1\endcsname {\linenomathAMS}{}{}%
+ \expandafter\pretocmd\csname #1*\endcsname {\linenomathAMS}{}{}%
+ \expandafter\apptocmd\csname end#1\endcsname {\endlinenomath}{}{}%
+ \expandafter\apptocmd\csname end#1*\endcsname {\endlinenomath}{}{}%
+ }
+
+ %% Definition of \linenomathAMS depends on whether the mathlines option is provided
+ \expandafter\ifx\linenomath\linenomathWithnumbers
+ \let\linenomathAMS\linenomathWithnumbers
+ %% The following line gets rid of an extra line numbers at the bottom:
+ \patchcmd\linenomathAMS{\advance\postdisplaypenalty\linenopenalty}{}{}{}
+ \else
+ \let\linenomathAMS\linenomathNonumbers
+ \fi
+
+ \AtBeginDocument{%
+ \linenomathpatch{equation}%
+ \linenomathpatchAMS{gather}%
+ \linenomathpatchAMS{multline}%
+ \linenomathpatchAMS{align}%
+ \linenomathpatchAMS{alignat}%
+ \linenomathpatchAMS{flalign}%
+ }
+\else
+ % Hack to ignore these commands, which review mode puts into the .aux file.
+ \newcommand{\@LN@col}[1]{}
+ \newcommand{\@LN}[2]{}
+\fi
+
+\iffalse
+\PassOptionsToPackage{
+ a4paper,
+ top=2.21573cm,left=2.54cm,
+ textheight=704.60031pt, % 51 * \baselineskip + \topskip
+ textwidth=16.0cm,
+ headheight=0.17573cm,headsep=0cm
+}{geometry}
+\fi
+\PassOptionsToPackage{a4paper,margin=2.5cm,heightrounded=true}{geometry}
+\RequirePackage{geometry}
+
+\setlength\columnsep{0.6cm}
+\newlength\titlebox
+\setlength\titlebox{11\baselineskip}
+% \titlebox should be a multiple of \baselineskip so that
+% column height remaining fits an exact number of lines of text
+
+\flushbottom \twocolumn \sloppy
+
+% We're never going to need a table of contents, so just flush it to
+% save space --- suggested by drstrip@sandia-2
+\def\addcontentsline#1#2#3{}
+
+\ifacl@pagenumbers
+ \pagenumbering{arabic}
+\else
+ \thispagestyle{empty}
+ \pagestyle{empty}
+\fi
+
+%% Title and Authors %%
+
+\let\Thanks\thanks % \Thanks and \thanks used to be different, but keep this for backwards compatibility.
+
+\newcommand\outauthor{%
+ \begin{tabular}[t]{c}
+ \ifacl@anonymize
+ \bf Anonymous ACL submission
+ \else
+ \bf\@author
+ \fi
+ \end{tabular}}
+
+% Mostly taken from deproc.
+\AtBeginDocument{
+\def\maketitle{\par
+ \begingroup
+ \def\thefootnote{\fnsymbol{footnote}}
+ \twocolumn[\@maketitle] \@thanks
+ \endgroup
+ \setcounter{footnote}{0}
+ \let\maketitle\relax \let\@maketitle\relax
+ \gdef\@thanks{}\gdef\@author{}\gdef\@title{}\let\thanks\relax}
+\def\@maketitle{\vbox to \titlebox{\hsize\textwidth
+ \linewidth\hsize \vskip 0.125in minus 0.125in \centering
+ {\Large\bf \@title \par} \vskip 0.2in plus 1fil minus 0.1in
+ {\def\and{\unskip\enspace{\rm and}\enspace}%
+ \def\And{\end{tabular}\hss \egroup \hskip 1in plus 2fil
+ \hbox to 0pt\bgroup\hss \begin{tabular}[t]{c}\bf}%
+ \def\AND{\end{tabular}\hss\egroup \hfil\hfil\egroup
+ \vskip 0.25in plus 1fil minus 0.125in
+ \hbox to \linewidth\bgroup\large \hfil\hfil
+ \hbox to 0pt\bgroup\hss \begin{tabular}[t]{c}\bf}
+ \hbox to \linewidth\bgroup\large \hfil\hfil
+ \hbox to 0pt\bgroup\hss
+ \outauthor
+ \hss\egroup
+ \hfil\hfil\egroup}
+ \vskip 0.3in plus 2fil minus 0.1in
+}}
+}
+
+% margins and font size for abstract
+\renewenvironment{abstract}%
+ {\centerline{\large\bf Abstract}%
+ \begin{list}{}%
+ {\setlength{\rightmargin}{0.6cm}%
+ \setlength{\leftmargin}{0.6cm}}%
+ \item[]\ignorespaces%
+ \@setsize\normalsize{12pt}\xpt\@xpt
+ }%
+ {\unskip\end{list}}
+
+%\renewenvironment{abstract}{\centerline{\large\bf
+% Abstract}\vspace{0.5ex}\begin{quote}}{\par\end{quote}\vskip 1ex}
+
+% Resizing figure and table captions - SL
+% Support for interacting with the caption, subfigure, and subcaption packages - SL
+\RequirePackage{caption}
+\DeclareCaptionFont{10pt}{\fontsize{10pt}{12pt}\selectfont}
+\captionsetup{font=10pt}
+
+\RequirePackage{natbib}
+% for citation commands in the .tex, authors can use:
+% \citep, \citet, and \citeyearpar for compatibility with natbib, or
+% \cite, \newcite, and \shortcite for compatibility with older ACL .sty files
+\renewcommand\cite{\citep} % to get "(Author Year)" with natbib
+\newcommand\shortcite{\citeyearpar}% to get "(Year)" with natbib
+\newcommand\newcite{\citet} % to get "Author (Year)" with natbib
+\newcommand{\citeposs}[1]{\citeauthor{#1}'s (\citeyear{#1})} % to get "Author's (Year)"
+
+\bibliographystyle{acl_natbib}
+
+% Bibliography
+
+% Don't put a label in the bibliography at all. Just use the unlabeled format
+% instead.
+\def\thebibliography#1{\vskip\parskip%
+\vskip\baselineskip%
+\def\baselinestretch{1}%
+\ifx\@currsize\normalsize\@normalsize\else\@currsize\fi%
+\vskip-\parskip%
+\vskip-\baselineskip%
+\section*{References\@mkboth
+ {References}{References}}\list
+ {}{\setlength{\labelwidth}{0pt}\setlength{\leftmargin}{\parindent}
+ \setlength{\itemindent}{-\parindent}}
+ \def\newblock{\hskip .11em plus .33em minus -.07em}
+ \sloppy\clubpenalty4000\widowpenalty4000
+ \sfcode`\.=1000\relax}
+\let\endthebibliography=\endlist
+
+
+% Allow for a bibliography of sources of attested examples
+\def\thesourcebibliography#1{\vskip\parskip%
+\vskip\baselineskip%
+\def\baselinestretch{1}%
+\ifx\@currsize\normalsize\@normalsize\else\@currsize\fi%
+\vskip-\parskip%
+\vskip-\baselineskip%
+\section*{Sources of Attested Examples\@mkboth
+ {Sources of Attested Examples}{Sources of Attested Examples}}\list
+ {}{\setlength{\labelwidth}{0pt}\setlength{\leftmargin}{\parindent}
+ \setlength{\itemindent}{-\parindent}}
+ \def\newblock{\hskip .11em plus .33em minus -.07em}
+ \sloppy\clubpenalty4000\widowpenalty4000
+ \sfcode`\.=1000\relax}
+\let\endthesourcebibliography=\endlist
+
+% sections with less space
+\def\section{\@startsection {section}{1}{\z@}{-2.0ex plus
+ -0.5ex minus -.2ex}{1.5ex plus 0.3ex minus .2ex}{\large\bf\raggedright}}
+\def\subsection{\@startsection{subsection}{2}{\z@}{-1.8ex plus
+ -0.5ex minus -.2ex}{0.8ex plus .2ex}{\normalsize\bf\raggedright}}
+%% changed by KO to - values to get the initial parindent right
+\def\subsubsection{\@startsection{subsubsection}{3}{\z@}{-1.5ex plus
+ -0.5ex minus -.2ex}{0.5ex plus .2ex}{\normalsize\bf\raggedright}}
+\def\paragraph{\@startsection{paragraph}{4}{\z@}{1.5ex plus
+ 0.5ex minus .2ex}{-1em}{\normalsize\bf}}
+\def\subparagraph{\@startsection{subparagraph}{5}{\parindent}{1.5ex plus
+ 0.5ex minus .2ex}{-1em}{\normalsize\bf}}
+
+% Footnotes
+\footnotesep 6.65pt %
+\skip\footins 9pt plus 4pt minus 2pt
+\def\footnoterule{\kern-3pt \hrule width 5pc \kern 2.6pt }
+\setcounter{footnote}{0}
+
+% Lists and paragraphs
+\parindent 1em
+\topsep 4pt plus 1pt minus 2pt
+\partopsep 1pt plus 0.5pt minus 0.5pt
+\itemsep 2pt plus 1pt minus 0.5pt
+\parsep 2pt plus 1pt minus 0.5pt
+
+\leftmargin 2em \leftmargini\leftmargin \leftmarginii 2em
+\leftmarginiii 1.5em \leftmarginiv 1.0em \leftmarginv .5em \leftmarginvi .5em
+\labelwidth\leftmargini\advance\labelwidth-\labelsep \labelsep 5pt
+
+\def\@listi{\leftmargin\leftmargini}
+\def\@listii{\leftmargin\leftmarginii
+ \labelwidth\leftmarginii\advance\labelwidth-\labelsep
+ \topsep 2pt plus 1pt minus 0.5pt
+ \parsep 1pt plus 0.5pt minus 0.5pt
+ \itemsep \parsep}
+\def\@listiii{\leftmargin\leftmarginiii
+ \labelwidth\leftmarginiii\advance\labelwidth-\labelsep
+ \topsep 1pt plus 0.5pt minus 0.5pt
+ \parsep \z@ \partopsep 0.5pt plus 0pt minus 0.5pt
+ \itemsep \topsep}
+\def\@listiv{\leftmargin\leftmarginiv
+ \labelwidth\leftmarginiv\advance\labelwidth-\labelsep}
+\def\@listv{\leftmargin\leftmarginv
+ \labelwidth\leftmarginv\advance\labelwidth-\labelsep}
+\def\@listvi{\leftmargin\leftmarginvi
+ \labelwidth\leftmarginvi\advance\labelwidth-\labelsep}
+
+\abovedisplayskip 7pt plus2pt minus5pt%
+\belowdisplayskip \abovedisplayskip
+\abovedisplayshortskip 0pt plus3pt%
+\belowdisplayshortskip 4pt plus3pt minus3pt%
+
+% Less leading in most fonts (due to the narrow columns)
+% The choices were between 1-pt and 1.5-pt leading
+\def\@normalsize{\@setsize\normalsize{11pt}\xpt\@xpt}
+\def\small{\@setsize\small{10pt}\ixpt\@ixpt}
+\def\footnotesize{\@setsize\footnotesize{10pt}\ixpt\@ixpt}
+\def\scriptsize{\@setsize\scriptsize{8pt}\viipt\@viipt}
+\def\tiny{\@setsize\tiny{7pt}\vipt\@vipt}
+\def\large{\@setsize\large{14pt}\xiipt\@xiipt}
+\def\Large{\@setsize\Large{16pt}\xivpt\@xivpt}
+\def\LARGE{\@setsize\LARGE{20pt}\xviipt\@xviipt}
+\def\huge{\@setsize\huge{23pt}\xxpt\@xxpt}
+\def\Huge{\@setsize\Huge{28pt}\xxvpt\@xxvpt}
+
+% The hyperref manual (section 9) says hyperref should be loaded after natbib
+\ifacl@hyperref
+ \PassOptionsToPackage{breaklinks}{hyperref}
+ \RequirePackage{hyperref}
+ % make links dark blue
+ \definecolor{darkblue}{rgb}{0, 0, 0.5}
+ \hypersetup{colorlinks=true, citecolor=darkblue, linkcolor=darkblue, urlcolor=darkblue}
+\else
+ % This definition is used if the hyperref package is not loaded.
+  % It provides a backup, no-op definition of \href.
+  % This is necessary because the \href command is used in the acl_natbib.bst file.
+ \def\href#1#2{{#2}}
+ \usepackage{url}
+\fi
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/acl_latex.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/acl_latex.tex"
new file mode 100644
index 0000000000..5f8c01bc90
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/acl_latex.tex"
@@ -0,0 +1,507 @@
+% This must be in the first 5 lines to tell arXiv to use pdfLaTeX, which is strongly recommended.
+\pdfoutput=1
+% In particular, the hyperref package requires pdfLaTeX in order to break URLs across lines.
+
+\documentclass[11pt]{article}
+
+% Change "review" to "final" to generate the final (sometimes called camera-ready) version.
+% Change to "preprint" to generate a non-anonymous version with page numbers.
+\usepackage[final]{acl}
+
+% Standard package includes
+\usepackage{times}
+\usepackage{latexsym}
+\usepackage{algorithmic}
+\usepackage{graphicx}
+\usepackage{textcomp}
+\usepackage{xcolor}
+\usepackage{physics}
+\usepackage{mathdots}
+\usepackage{algorithm}
+\usepackage{subfigure}
+\usepackage{pgfplots}
+\usepackage{tikz}
+\usetikzlibrary{patterns}
+\usepackage{multirow}
+\usepackage{multicol}
+% \usepackage{authblk}
+\usepackage{array}
+\pgfplotsset{compat=1.18}
+
+% Add this line here
+\let\Bbbk\relax
+\usepackage{amsmath,amssymb,amsfonts}
+\DeclareMathOperator*{\argmax}{arg\,max}
+
+% For proper rendering and hyphenation of words containing Latin characters (including in bib files)
+\usepackage[T1]{fontenc}
+% For Vietnamese characters
+% \usepackage[T5]{fontenc}
+% See https://www.latex-project.org/help/documentation/encguide.pdf for other character sets
+
+% This assumes your files are encoded as UTF8
+\usepackage[utf8]{inputenc}
+
+% This is not strictly necessary, and may be commented out,
+% but it will improve the layout of the manuscript,
+% and will typically save some space.
+\usepackage{microtype}
+
+% This is also not strictly necessary, and may be commented out.
+% However, it will improve the aesthetics of text in
+% the typewriter font.
+\usepackage{inconsolata}
+
+%Including images in your LaTeX document requires adding
+%additional package(s)
+\usepackage{graphicx}
+
+% If the title and author information does not fit in the area allocated, uncomment the following
+%
+%\setlength\titlebox{}
+%
+% and set to something 5cm or larger.
+
+\title{Making RALM Robust to Irrelevant Contexts via Layer Knowledge Guided Attention}
+
+% Author information can be set in various styles:
+% For several authors from the same institution:
+% \author{Author 1 \and ... \and Author n \\
+% Address line \\ ... \\ Address line}
+% if the names do not fit well on one line use
+% Author 1 \\ {\bf Author 2} \\ ... \\ {\bf Author n} \\
+% For authors from different institutions:
+% \author{Author 1 \\ Address line \\ ... \\ Address line
+% \And ... \And
+% Author n \\ Address line \\ ... \\ Address line}
+% To start a separate ``row'' of authors use \AND, as in
+% \author{Author 1 \\ Address line \\ ... \\ Address line
+% \AND
+% Author 2 \\ Address line \\ ... \\ Address line \And
+% Author 3 \\ Address line \\ ... \\ Address line}
+
+\author{
+ \textbf{Weijie Shi\textsuperscript{1}\thanks{Co-authors: Hao Chen, Yao Zhao}\thanks{\small{
+ \textbf{Email:} \href{mailto:wshiah@connect.ust.hk}{wshiah@connect.ust.hk}
+ }}},
+ \textbf{Hao Chen\textsuperscript{2}\footnotemark[1]},
+ \textbf{Jiaming Li\textsuperscript{2}},
+ \textbf{Yao Zhao\textsuperscript{2}\footnotemark[1]},
+\\
+ \textbf{Yazhong Zhang\textsuperscript{2}\thanks{Corresponding authors: Yazhong Zhang, Qijin Chen, Jiajie Xu, Jia Zhu}},
+ \textbf{Qijin Chen\textsuperscript{2}\footnotemark[3]},
+ \textbf{Jipeng Zhang\textsuperscript{1}},
+ \textbf{Ruiyuan Zhang\textsuperscript{1}},
+\\
+ \textbf{Jia Zhu\textsuperscript{3}\footnotemark[3]},
+ \textbf{Jiajie Xu\textsuperscript{4}\footnotemark[3]},
+ \textbf{Xiaofang Zhou\textsuperscript{1}}
+\\
+\\
+ \textsuperscript{1}The Hong Kong University of Science and Technology,
+ \textsuperscript{2}Alibaba Group, \\
+ \textsuperscript{3}Zhejiang Key Laboratory of Intelligent Education Technology and Application, \\ Zhejiang Normal University,
+ \textsuperscript{4}Soochow University
+}
+%\author{
+% \textbf{First Author\textsuperscript{1}},
+% \textbf{Second Author\textsuperscript{1,2}},
+% \textbf{Third T. Author\textsuperscript{1}},
+% \textbf{Fourth Author\textsuperscript{1}},
+%\\
+% \textbf{Fifth Author\textsuperscript{1,2}},
+% \textbf{Sixth Author\textsuperscript{1}},
+% \textbf{Seventh Author\textsuperscript{1}},
+% \textbf{Eighth Author \textsuperscript{1,2,3,4}},
+%\\
+% \textbf{Ninth Author\textsuperscript{1}},
+% \textbf{Tenth Author\textsuperscript{1}},
+% \textbf{Eleventh E. Author\textsuperscript{1,2,3,4,5}},
+% \textbf{Twelfth Author\textsuperscript{1}},
+%\\
+% \textbf{Thirteenth Author\textsuperscript{3}},
+% \textbf{Fourteenth F. Author\textsuperscript{2,4}},
+% \textbf{Fifteenth Author\textsuperscript{1}},
+% \textbf{Sixteenth Author\textsuperscript{1}},
+%\\
+% \textbf{Seventeenth S. Author\textsuperscript{4,5}},
+% \textbf{Eighteenth Author\textsuperscript{3,4}},
+% \textbf{Nineteenth N. Author\textsuperscript{2,5}},
+% \textbf{Twentieth Author\textsuperscript{1}}
+%\\
+%\\
+% \textsuperscript{1}Affiliation 1,
+% \textsuperscript{2}Affiliation 2,
+% \textsuperscript{3}Affiliation 3,
+% \textsuperscript{4}Affiliation 4,
+% \textsuperscript{5}Affiliation 5
+%\\
+% \small{
+% \textbf{Correspondence:} \href{mailto:email@domain}{email@domain}
+% }
+%}
+
+\begin{document}
+\maketitle
+\begin{abstract}
+Retrieval-augmented language models (RALMs) incorporate external knowledge to address the factual hallucination and knowledge obsolescence faced by large language models (LLMs). Inevitably, passages retrieved by similarity search may be irrelevant to the given question, and aggregating such passages can confuse the model and lead it away from the correct answer. To improve RALM performance under these conditions, we propose layer-knowledge guided attention for RALMs, which harnesses the layer-wise knowledge of LLMs to optimize per-layer attention over useful passages, making the model attend to the most relevant content while ignoring irrelevant passages. Specifically, we first systematically study LLM attention patterns and their relationship with the accuracy of RALM responses, finding that middle-focused attention plays a crucial role in selectively gathering relevant information. Based on this, a layer-wise passage estimator leverages the varied knowledge encoded across LLM layers to assess not only passage relevance scores but also their associated confidences. Finally, a relevance-aware passage fusion enables selective attention to relevant passages, mitigating the distractibility and positional bias of causal attention. Experiments show that our method outperforms existing methods on RALM benchmarks.
+\end{abstract}
+
+\section{Introduction}
+Large language models (LLMs) have demonstrated remarkable performance, scalability, and adaptability in various natural language processing tasks \cite{bang2023multitask,guo2023close,chowdhery2022palm}. However, LLMs encounter significant challenges when tackling knowledge-intensive tasks, including factual hallucination \cite{cao2020factual,raunak2021curious,ji2023survey}, knowledge obsolescence \cite{he2022rethinking}, and a lack of domain-specific expertise \cite{shen2023chatgpt,li2023chatgpt}. To address these issues, retrieval-augmented language models (RALMs) have emerged as a mainstream approach, leveraging a retrieve-then-read pipeline to supply external information to the LLM when answering questions.
+
+\input{figure/fig-motivation}
+
+Despite RALM's potential, LLMs struggle to handle retrieved passages, which often contain irrelevant ones, hindering performance in two respects:
+\begin{itemize}
+    \item \textbf{Attention Distractibility}: As shown in Figure \ref{fig:retrieved passages number}, while increasing the number of retrieved passages improves recall linearly, LLM accuracy plateaus or declines due to attention disruption from irrelevant content \cite{shi2023large}. The attention of question tokens becomes scattered across noisy information in the passages.
+ \item \textbf{Positional Bias}: As illustrated in Figure \ref{fig:answer position}, LLM performance exhibits a U-shaped curve based on passage position, with better handling of information at the start and end while missing crucial middle content \cite{liu2024lost}. This stems from LLM attention's over-reliance on positional information.
+\end{itemize}
+While methods such as Noise-Resistant RALM \cite{yoran2024making} and RankRAG \cite{yu2024rankrag} attempt to address these issues by filtering irrelevant passages and optimizing passage placement, they remain workarounds rather than fundamental fixes to how LLM attention processes retrieved passages.
+
+In this paper, we propose Layer-Knowledge Guided Attention for RALM (LKG-RALM), which harnesses the layer-wise knowledge of LLMs to optimize attention on useful passages. To effectively guide the LLM's attention, accurately assessing the relevance of retrieved passages is crucial. Recent works demonstrate that LLM-based embeddings significantly outperform BERT-like models on the MTEB leaderboard due to superior scaling and comprehensive pre-training. Furthermore, \citet{rome,chuang2023dola} indicate that different LLM layers encode varied knowledge, from grammatical understanding in lower layers to reasoning capabilities in higher ones. Building on these insights, we propose a layer-wise passage estimator, which fully leverages the varied knowledge of LLM layers to accurately predict both relevance and estimation confidence. Since not all layers' knowledge contributes equally to relevance assessment, an entropy-based layer-knowledge selection is proposed to dynamically determine which layers' knowledge is suitable for each passage. To mitigate distractibility and positional bias from irrelevant passages, a relevance-aware passage fusion employs a relevance-guided attention mask so that question tokens in middle-focused attention heads selectively attend to relevant passages. Experiments demonstrate that LKG-RALM achieves substantial performance improvements across RALM datasets. Our contributions are summarized as follows:
+\begin{itemize}
+    \item We present the first systematic study of the relation between RALM's attention patterns and performance. Based on these findings, LKG-RALM leverages layer-wise knowledge to guide middle-focused attention toward relevant passages, thereby enhancing the understanding of retrieved information.
+ \item We propose a layer-wise passage estimator to utilize LLM layer-specific knowledge to assess reliable and adaptable passage relevance.
+ \item We propose relevance-aware passage fusion to enable question tokens to selectively attend to relevant passages, mitigating distractibility and positional bias.
+\end{itemize}
+
+\input{figure/fig-attention distribution}
+\input{figure/fig-attention evidence}
+
+\section{Related Work}
+\subsection{Retrieval-augmented Language Model}
+Retrieval-augmented language models (RALMs) \cite{survey1,zhao2023survey,survey2} enhance generation by incorporating retrieved passages through three main approaches: query-based fusion, which concatenates passages with input queries \cite{REPLUG,RALM} or features \cite{FID,liu2023recap}; logits-based fusion, which combines probability distributions from input and retrieved passages \cite{khandelwal2019generalization,huang2023k}; and latent fusion, which integrates passages into hidden states via attention \cite{wang2023shall} or weighted additions \cite{wu2024improving}.
+
+Recent work has focused on addressing noise in retrieved passages. \citet{liu2024lost} analyzed position bias across model types and query positions, while \citet{shi2023large,wu2024instructing} attempted to incorporate passage relevance into context. Other approaches include filtering irrelevant passages \cite{zhang2021drop,yoran2024making} and developing noise-resistant fine-tuning strategies \cite{liu2024chatqa,yu2024rankrag}. However, these methods remain constrained by reranking accuracy and fail to address the fundamental limitations of causal attention. Our work investigates the relationship between attention patterns and RALM performance, leading to our LKG-RALM approach that leverages layer-wise knowledge for improved passage attention.
+
+\subsection{Passage Relevance Assessment}
+While traditional methods like BM25 \cite{bm25} and BERT-based models \cite{DPR,contriever,bge} have advanced text representation, they face scaling challenges in representation training. Recent approaches \cite{wang2023improving,behnamghader2024llm2vec,springer2024repetition} have shown promise in adapting decoder-only LLMs as text encoders through contrastive learning. However, even state-of-the-art models achieve only 62\% accuracy on the MTEB leaderboard \cite{MTEB}, highlighting the need for more nuanced relevance assessment approaches.
+
+\citet{rome,chuang2023dola,zhang2024comprehensive} have shown that LLMs encode layer-specific knowledge, ranging from grammatical structures in lower layers to complex reasoning in higher ones. Building on this insight, we propose a layer-wise passage estimator that leverages this hierarchical knowledge structure to provide comprehensive relevance assessments with reliability measures.
+
+
+
+\input{figure/fig-llm framework}
+
+\section{Preliminaries}
+\subsection{Problem Formalization}
+Our method is formulated in the open question-answering (open-QA) setting: given a question $q$ and $n$ retrieved passages $[p_1,\dots,p_n]$, the goal is to predict an answer $y_{ans}$.
+
+\subsection{Analysis of Attention Patterns of RALM}
+To address the challenges of attention distractibility and positional bias in RALM, it is crucial to systematically investigate its attention patterns. The attention mechanism selects specific tokens to gather information from retrieved passages for the generation of the next token. Following the methodology of \citet{fu2024attentionpattern}, we conduct a systematic study of the attention distribution of LLAMA-3.1-8B using 2000 samples from the NQ and TriviaQA datasets (details in Appendix \ref{Attention pattern analysis}). Figure \ref{fig:attention pattern} reveals three distinct attention patterns that potentially impact the model's ability to process retrieved passages.
+
+
+\textbf{Edge-focused attention}, observed in 78\% of attention heads, shows over 99\% of attention concentrating on the beginning and end of the context. \citet{xiao2023efficient} demonstrated that this phenomenon persists even when replacing the initial tokens with meaningless ones, indicating that the model emphasizes absolute position rather than semantic value. This pattern correlates strongly with positional bias, hindering the model's ability to process crucial information in the middle of the input sequence.
+
+\textbf{Uniform attention}, accounting for 5.37\% of patterns, distributes attention almost uniformly across all tokens in the context. While appearing to provide equal consideration to all information, this pattern potentially contributes to the model's distractibility by failing to focus on the most relevant parts of the input.
+
+\textbf{Middle-focused attention}, though present in only 6\% of attention heads, manifests in two variants: ``scattered over middle'' and ``concentrated on middle''. The former distributes attention across several tokens, while the latter concentrates on only one or two tokens. This pattern plays a crucial role in selectively gathering information from the context, essential for comprehending retrieved passages.
+
+To analyze the relationship between these patterns and RALM performance, we examined the correlation between attention weight sums and model accuracy. Figure \ref{fig:attention pattern}(d) reveals that increased attention on relevant passages in edge-focused and uniform patterns yielded no performance gains, while middle-focused patterns demonstrated a strong positive correlation with RALM accuracy. Manipulation experiments further supported these findings: artificially replacing edge-focused and uniform patterns with middle-focused attention on relevant passages disrupted the model's attention structure, leading to performance degradation. As shown in Figure \ref{fig:attention manipulation}, deliberately redirecting middle-focused patterns to irrelevant passages significantly decreased performance, while concentrating this attention on relevant passages improved it. These results suggest that guiding middle-focused attention towards relevant passages could significantly enhance RALM's effectiveness.
+
+% \input{figure/fig-llm pipeline}
+
+\input{figure/fig-passage relevance estimator}
+
+\section{Methodology}
+\subsection{Overview}
+The vanilla attention of LLMs often suffers from distractibility and positional bias, which is unsuitable for open-QA with retrieved passages. We take advantage of layer-wise knowledge of LLMs to assess passage relevance, then guide the LLM's attention to generate answers, effectively mitigating these issues. The overall framework is illustrated in Figure \ref{fig:llm framework}.
+
+% It first utilizes a layer-wise passage estimator to mine diverse knowledge encoded in different LLM layers, assessing the relevance of retrieved passages along with the reliability and adaptability of these relevance scores. Next, these assessments guide the LLM's attention through a relevance-aware passage integration component, enhancing robustness against irrelevant passages. Finally, three tailored losses optimize the relevance estimations, while a language model loss jointly fine-tunes the estimator and the LLM for improved performance.
+
+\subsection{Adding Special Tokens}
+To clearly delineate the boundary of the given question and each passage, we introduce trainable special tokens into the sequence. Specifically, we add \textcolor{blue}{$[d]$} and \textcolor{orange}{$[e_i]$} tokens as boundary markers at the beginning and end of each retrieved passage, respectively, while \textcolor{blue}{$[q]$} and \textcolor{yellow!70!black}{$[e_q]$} tokens demarcate the question.
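+
+For concreteness, a minimal sketch of how such boundary markers could be registered with a Hugging Face tokenizer is shown below; the model identifier, the exact surface forms of the markers, and the number of passage-end markers are illustrative assumptions rather than our released implementation.
+\begin{verbatim}
+from transformers import (AutoModelForCausalLM,
+                          AutoTokenizer)
+
+# Illustrative marker strings; the real surface forms
+# are an implementation choice.
+SPECIAL = ["[d]", "[e_1]", "[e_2]", "[q]", "[e_q]"]
+
+name = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder
+tok = AutoTokenizer.from_pretrained(name)
+model = AutoModelForCausalLM.from_pretrained(name)
+
+# Register markers and grow the embedding table so
+# they receive trainable embedding rows.
+tok.add_special_tokens({"additional_special_tokens": SPECIAL})
+model.resize_token_embeddings(len(tok))
+
+# A two-passage input then reads:
+# [d] <p1> [e_1] [d] <p2> [e_2] [q] <question> [e_q]
+\end{verbatim}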
+
+\subsection{Layer-wise Passage Estimator}
+Traditional passage estimators often rely on BERT-like structures, but these methods typically yield low accuracy and fail to leverage the rich, layer-specific knowledge embedded in LLMs. We propose a layer-wise passage estimator, shown in Figure \ref{fig:passage relevance estimator}, which utilizes the per-layer knowledge of the LLM to assess relevance scores from multifaceted views, along with their associated confidences. Additionally, it incorporates an entropy-based layer-knowledge selection, which analyzes the attention distribution to determine the applicability of each layer's knowledge to the passages. By combining these comprehensive estimations, our approach provides trustworthy guidance for the LLM's attention.
+
+\subsubsection{Layer-wise Relevance and Confidence Estimation.}
+For a given layer $l$, we leverage the LLM's internal representations to compute relevance and confidence scores. To adapt the LLM's parameters to relevance assessment, we add trainable low-rank weights (LoRA) to each decoder layer. To enhance contextual understanding within passages, we follow previous work \cite{behnamghader2024llm2vec} and adopt a blocked bidirectional attention mask rather than causal attention. To aggregate sentence-level information, we extract the hidden states of the last special tokens as sentence embeddings: $e_i$ for each passage and $e_q$ for the question. Adapter and dropout components are then used to enhance their robustness. Finally, two cross-attention components compute the relevance scores $r_1,\dots,r_n$ and confidence scores $c_1,\dots,c_n$ between the passages and the question, respectively.
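+
+As an illustrative sketch only (the hidden sizes, head counts, and the use of a standard multi-head cross-attention module are assumptions, not the exact released architecture), the per-layer relevance and confidence heads could look as follows:
+\begin{verbatim}
+import torch
+import torch.nn as nn
+
+class LayerPassageEstimator(nn.Module):
+    # One estimator head per decoder layer (sketch).
+    def __init__(self, hidden=4096, adapter=1024, p_drop=0.1):
+        super().__init__()
+        self.adapter = nn.Sequential(
+            nn.Linear(hidden, adapter), nn.GELU(),
+            nn.Linear(adapter, hidden), nn.Dropout(p_drop))
+        self.rel_attn = nn.MultiheadAttention(
+            hidden, num_heads=8, batch_first=True)
+        self.conf_attn = nn.MultiheadAttention(
+            hidden, num_heads=8, batch_first=True)
+        self.rel_out = nn.Linear(hidden, 1)
+        self.conf_out = nn.Linear(hidden, 1)
+
+    def forward(self, e_q, e_p):
+        # e_q: (B, 1, H) from [e_q];  e_p: (B, n, H) from [e_i].
+        e_q, e_p = self.adapter(e_q), self.adapter(e_p)
+        # Each passage embedding cross-attends to the question.
+        rel, _ = self.rel_attn(e_p, e_q, e_q)
+        conf, _ = self.conf_attn(e_p, e_q, e_q)
+        r = torch.sigmoid(self.rel_out(rel)).squeeze(-1)    # (B, n)
+        c = torch.sigmoid(self.conf_out(conf)).squeeze(-1)  # (B, n)
+        return r, c
+\end{verbatim}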
+
+\subsubsection{Optimizing Estimator.}
+To optimize our layer-wise passage estimator, we introduce three specialized loss functions, each addressing a crucial aspect of effective relevance estimation:
+
+\paragraph{Relevance Loss.} To ensure the model accurately identifies relevant passages, we employ a relevance loss. This loss function encourages the estimated relevance scores to closely align with the ground truth, thereby improving the model's ability to distinguish between relevant and irrelevant passages:
+\begin{equation}
+L_{relevance} = -\frac{1}{n} \sum_{i=1}^n [\bar{r_i} \log(r_i) + (1-\bar{r_i}) \log(1-r_i)]
+\end{equation}
+where $\bar{r_i}$ is the ground truth label, and $r_i$ is the estimated relevance score.
+
+\paragraph{Confidence Loss.} Recognizing that not all relevance predictions are equally reliable, we introduce a confidence loss. We posit that the model should exhibit high confidence on samples that are easy to classify, while maintaining lower confidence on more challenging and confusing cases. To this end, we leverage external models (such as BGE \cite{bge}) to assist in determining sample difficulty:
+\begin{equation}
+L_{confidence} = -\sum_{i=1}^n \Big( c_i - \frac{1}{K} \sum_{k=1}^K I\big(M_k(r_i \mid q) = \bar{r_i}\big) \Big)
+\end{equation}
+where $c_i$ is the estimated confidence score, $M_k$ denotes the $k$-th of $K$ external models, and $I(\cdot)$ is the indicator function. This loss trains the model to produce confidence scores that accurately reflect the trustworthiness of its relevance predictions.
+
+\paragraph{Diversity Loss.} To ensure a comprehensive utilization of layer-specific knowledge and avoid overly homogeneous relevance guidance, we employ a diversity loss based on the entropy of the final relevance guidance:
+\begin{equation}
+L_{diversity} = -H(\alpha_1, \ldots, \alpha_n)
+\end{equation}
+where $H(\cdot)$ is the entropy function, and $\alpha_1, \ldots, \alpha_n$ are the final relevance guidance weights.
+
+Combining these loss functions through simple addition, our estimator learns to provide accurate, confident, and diverse relevance assessments across different layers of the LLM. The Relevance Loss helps to quickly narrow down the search space to the most pertinent passages, while the Diversity Loss encourages a broader exploration of potentially relevant information, increasing the chances of recalling the correct answer. Although the relevance and diversity losses may seem antagonistic, their balanced combination leads to a more robust and comprehensive relevance assessment.
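+
+A minimal PyTorch sketch of the three losses, under assumed tensor layouts (a batch of $B$ questions, $n$ passages, $K$ external models) and with the guidance weights normalized so that the entropy term is well defined, is given below; it mirrors the formulas above rather than our exact training code.
+\begin{verbatim}
+import torch
+import torch.nn.functional as F
+
+def estimator_losses(r, c, r_gt, ext_preds, alpha, eps=1e-8):
+    # r, c:      (B, n) relevance / confidence scores in [0, 1]
+    # r_gt:      (B, n) binary ground-truth relevance labels
+    # ext_preds: (B, K, n) binary predictions of K external models
+    # alpha:     (B, n) final relevance-guidance weights
+    l_rel = F.binary_cross_entropy(r, r_gt)
+
+    # Fraction of external models agreeing with the ground
+    # truth, used as a per-passage difficulty proxy.
+    agree = (ext_preds == r_gt.unsqueeze(1)).float().mean(dim=1)
+    l_conf = -(c - agree).sum(dim=-1).mean()
+
+    # Negative entropy of the (normalized) guidance weights.
+    p = alpha / alpha.sum(dim=-1, keepdim=True).clamp_min(eps)
+    l_div = (p * p.clamp_min(eps).log()).sum(dim=-1).mean()
+
+    return l_rel + l_conf + l_div
+\end{verbatim}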
+
+\subsubsection{Entropy-based Layer-Knowledge Selection.}
+Inspired by \citet{hyeon2023scratching}, to ensure the effective utilization of layer-specific knowledge in passage assessment, we propose an entropy-based layer-knowledge selection to identify which layers provide the most informative and contextually rich representations for each passage.
+
+Specifically, for each passage $p_i$, we calculate the entropy $H_i$ of the attention distribution from each passage's last special token to other tokens in the sequence:
+\begin{equation}
+H_i = -K\sum_{j=1}^{n} w_i^j \log w_i^j
+\end{equation}
+where $w_i^j$ denotes the attention weight from the last special token $e_i$ of passage $p_i$ to its $j$-th token, $n$ here denotes the number of tokens in passage $p_i$, and $K$ is a scaling factor. A higher entropy value indicates that the sentence embedding gathers a broader range of contextual information.
+
+Finally, we use a selection weight to aggregate the layer-wise relevance and confidence scores to update the relevance guidance:
+\begin{equation}
+\begin{aligned}
+\alpha_i^l = \beta &\left(\log(1 + H_i) \cdot \log(1 + r_i) \cdot \log(1 + c_i)\right) \\
+ &+ (1-\beta) \alpha_i^{l-1}
+\end{aligned}
+\end{equation}
+where $\alpha_i^l$ is the updated relevance guidance for passage $p_i$ at layer $l$, and $\beta$ balances current and previous layer assessments. To mitigate numerical oversensitivity, we apply a logarithmic transformation to each factor before multiplying them. This approach combines layer-wise relevance and confidence estimation with entropy-based layer-knowledge selection, enabling our estimator to leverage diverse knowledge across LLM layers and provide robust guidance for the LLM's attention mechanism.
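+
+For illustration, the entropy computation and the guidance update above can be sketched as follows; the tensor shapes, function name, and clamping constant are assumptions.
+\begin{verbatim}
+import torch
+
+def update_guidance(attn_w, r, c, alpha_prev, beta=0.2, K=1.0):
+    # attn_w:     (n, T) attention from each passage's last
+    #             special token to the T tokens it attends to
+    # r, c:       (n,) layer-wise relevance / confidence scores
+    # alpha_prev: (n,) guidance from the previous layer
+    w = attn_w.clamp_min(1e-12)
+    H = -K * (w * w.log()).sum(dim=-1)   # attention entropy
+    new = torch.log1p(H) * torch.log1p(r) * torch.log1p(c)
+    return beta * new + (1.0 - beta) * alpha_prev
+\end{verbatim}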
+
+\subsection{Relevance-aware Passage Fusion}
+To mitigate the issues of distractibility and positional bias, we propose a Relevance-aware Passage Fusion that selectively directs LLM attention to relevant passages based on the relevance guidance obtained from the Layer-wise Passage Estimator.
+
+To effectively guide the LLM's attention towards relevant passages while mitigating the effects of distractibility and positional bias inherent in traditional attention frameworks, we introduce a relevance-guided attention mask. This mask dynamically modulates query-passage attention weights based on estimated relevance, preserves intra-passage context, and inhibits cross-passage interference, thereby enhancing the model's capacity to prioritize salient information. Formally, for each layer $l$, we define the attention mask $M^l$ as:
+\begin{equation}
+M^l_{ij} = \begin{cases}
+\begin{aligned}
+&\alpha_k^l, & &\text{if } i \in q \text{ and } j \in p_k \\
+& & &\text{(middle-focused attention heads)} \\
+&1, & &\text{if } i \in q \text{ and } j \in p_k \\
+& & &\text{(other attention heads)} \\
+&1, & &\text{if } i \in p_k \text{ and } j \in p_k \\
+& & &\text{(same passage)} \\
+&0, & &\text{if } i \in p_k \text{ and } j \in p_m \text{ where } k \neq m \\
+& & &\text{(different passages)}
+\end{aligned}
+\end{cases}
+\end{equation}
+where $q$ represents the set of query token positions, $p_k$ is the set of token positions for passage $k$, and $\alpha_k^l$ is the relevance guidance for passage $k$ at layer $l$. Guided by our analysis of RALM attention, we apply the relevance-guided attention mask to middle-focused attention heads only, while preserving the behavior of edge-focused and uniform attention heads. Finally, we use the standard language-modeling loss to jointly fine-tune the LLM.
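+
+A minimal sketch of how the mask for a single layer and head could be assembled is shown below; the function, its arguments, and the assumption that the mask is multiplied into the attention weights are illustrative rather than the released implementation.
+\begin{verbatim}
+import torch
+
+def relevance_guided_mask(q_pos, passage_pos, alpha,
+                          seq_len, middle_focused):
+    # q_pos:        question-token positions
+    # passage_pos:  list of position lists, one per passage
+    # alpha:        (n,) relevance guidance at this layer
+    M = torch.ones(seq_len, seq_len)
+    q_idx = torch.tensor(q_pos).unsqueeze(1)
+    for k, pos_k in enumerate(passage_pos):
+        col_k = torch.tensor(pos_k)
+        # Block cross-passage interaction.
+        for m, pos_m in enumerate(passage_pos):
+            if m != k:
+                M[col_k.unsqueeze(1), torch.tensor(pos_m)] = 0.0
+        # Scale question -> passage attention only for
+        # middle-focused heads.
+        if middle_focused:
+            M[q_idx, col_k] = alpha[k]
+    return M
+\end{verbatim}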
+
+\section{Experiments}
+\subsection{Experimental Setting}
+\subsubsection{Datasets}
+To assess performance across diverse data characteristics, we employ a range of representative datasets for RALM evaluation. These include Natural Questions (NQ) \cite{NQ}, TriviaQA \cite{TQA}, StrategyQA \cite{strategyQA}, HotpotQA \cite{HotpotQA}, PopQA \cite{PopQA}, and 2WikiMQA \cite{ho2020constructing}. Detailed descriptions are provided in Appendix \ref{dataset description}.
+
+\subsubsection{Baselines}
+We categorize our baselines into three groups: closed-book LLM without retrieval, LLM with retrieval, and robust RALM. The first two groups include LLAMA-3.1 \cite{dubey2024llama}, Qwen-2.5 \cite{yang2024qwen2}, ChatGPT, GPT-4, and Claude-3-Sonnet. The third group comprises REPLUG \cite{REPLUG}, Self-RAG \cite{asai2023self}, RA-ISF \cite{liu2024ra}, Noise-Resistant RALM \cite{yoran2024making}, ChatQA-1.5 \cite{liu2024chatqa}, and RankRAG \cite{yu2024rankrag}. Comprehensive descriptions of these baselines are provided in Appendix \ref{baseline setting}.
+
+\subsubsection{Evaluation criteria}
+To evaluate the quality of the predicted answers, we employ the standard exact match (EM) metric in a 5-shot setting, following previous work \cite{DPR,atlas}. The generated answer is first normalized by lowercasing and by removing articles, punctuation, and duplicated whitespace. The EM score for a single question is binary: 1 if the normalized prediction matches the ground-truth answer exactly, and 0 otherwise. The EM scores are then averaged across all questions in the test set and multiplied by 100 to obtain the final scores.
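+
+For reference, a commonly used implementation of this normalization and scoring (assuming possibly multiple gold answers per question) is sketched below:
+\begin{verbatim}
+import re
+import string
+
+def normalize(text):
+    text = text.lower()
+    text = re.sub(r"\b(a|an|the)\b", " ", text)
+    text = "".join(ch for ch in text
+                   if ch not in set(string.punctuation))
+    return " ".join(text.split())
+
+def exact_match(pred, golds):
+    return int(any(normalize(pred) == normalize(g)
+                   for g in golds))
+
+def em_score(preds, gold_lists):
+    hits = sum(exact_match(p, g)
+               for p, g in zip(preds, gold_lists))
+    return 100.0 * hits / len(preds)
+\end{verbatim}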
+
+\subsection{Overall Performance}
+\input{figure/table-ODQAperformance}
+The results in Table 1 demonstrate varying performance across different model types. Closed-book LLMs show strong baseline performance but face limitations in their knowledge base. Adding retrieval generally improves performance, as seen with LLAMA-3.1-70B improving from 21.8 to 42.7 on NQ, and further fine-tuning brings additional gains (reaching 44.9). However, this improvement is not consistent across all models and datasets. For instance, GPT-4 with retrieval shows decreased performance on TriviaQA (87.0 to 75.0).
+
+The combination of retrieval and fine-tuning shows promising results, particularly for larger models. Qwen-2.5-72B benefits significantly from both enhancements, with performance on NQ improving from 39.9 (base) to 45.1 (+ retrieval) to 47.6 (+ retrieval \& fine-tuning). Claude-3-Sonnet with retrieval achieves strong results, reaching 55.1 on NQ and 90.8 on TriviaQA. Robust RALM methods, particularly RankRAG, demonstrate effective utilization of retrieved passages, showing consistent improvements across datasets. RankRAG achieves strong performance with 54.2 on NQ and 59.9 on PopQA, outperforming many traditional retrieval-augmented approaches.
+
+LKG-RALM outperforms baseline models across all datasets. Compared to closed-book LLMs, it shows substantial gains, ranging from 12.3 percentage points on NQ to 10.0 points on 2WikiMQA. When compared to retrieval-augmented and fine-tuned models, LKG-RALM still demonstrates superior performance, with the Qwen-2.5-72B based LKG-RALM achieving the best results across most metrics (61.5 on NQ, 46.1 on HotpotQA). The performance gap between the 8B and 70B-scale variants (2.3-5.7 percentage points) suggests that larger models can better leverage our approach, particularly on complex tasks like PopQA.
+
+
+\subsection{Ablation Results}
+\input{figure/q2-2}
+\subsubsection{Effect of Designed Components}
+Table \ref{table:q2-2} shows the ablation results of LKG-RALM with LLAMA-3.1-8B as the backbone. All proposed components contribute significantly to the final performance. Replacing the Layer-wise Passage Estimator with a reranking model causes a substantial performance drop across all tasks, with an average decrease of 3.96. This highlights its crucial role in using layer-wise LLM knowledge to assess passage relevance. The Entropy-based Layer-Knowledge Selection mechanism proves effective, as its removal leads to an average EM decrease of 1.52, showing the importance of dynamically selecting informative layer representations for each passage. Ablating the Relevance-aware Passage Fusion component results in significant performance degradation, with an average EM decrease of 22.22. This demonstrates our approach's effectiveness in reducing distractibility and positional bias when processing multiple passages, compared to traditional attention mechanisms. Finally, the auxiliary losses improve performance across most tasks, by an average of 1.06 points, indicating their value in guiding the model to consider passage relevance, prediction confidence, and diverse utilization of layer knowledge during training.
+
+
+\subsection{Robustness Analysis}
+\subsubsection{Increased Number of Retrieved Passages}
+To assess the scalability and efficiency of LKG-RALM in handling larger amounts of retrieved information, we conducted experiments varying the number of retrieved passages from 0 to 50. Figure \ref{fig:q4-1} illustrates the performance trends of LKG-RALM compared to baseline models across different datasets.
+
+The experiment reveals that LKG-RALM demonstrates superior scalability and maintains high performance even as the number of retrieved passages increases significantly. LKG-RALM shows a steady increase in EM scores from 24.8 to 61.5 as the number of passages grows, with only a slight plateau effect beyond 35 passages, indicating effective utilization of additional information without suffering from information overload. In contrast, baseline models like RankRAG and Self-RAG initially show improvements with more passages, but their performance begins to degrade or plateau beyond 25 passages. RankRAG reaches 53.7 at 25 passages and merely plateaus around 54.2 at 50 passages, while Self-RAG peaks at 43.5 at 25 passages before sharply declining to 28.4 at 50 passages. GPT-4 with Retrieval shows remarkable stability, maintaining a nearly constant performance (around 40.4) regardless of the number of passages, indicating its strong innate knowledge but potential limitations in effectively utilizing additional retrieved information.
+
+LKG-RALM maintains its edge, with a 6.8 EM score advantage over RankRAG at 50 passages (61.0 vs 54.2). Despite the increased potential for irrelevant information with more passages, LKG-RALM's performance remains robust, underscoring the effectiveness of its relevance-aware passage fusion mechanism.
+
+\subsubsection{Higher Proportions of Irrelevant Passages}
+\input{figure/q4}
+Our method internally directs attention to the passages relevant to the given question, facilitating evidence-seeking in noisy contexts. To evaluate its robustness and noise tolerance, we conducted adversarial testing by incrementally replacing the 50 retrieved passages with irrelevant passages, ranging from 0\% to 100\% substitution.
+
+From Figure \ref{fig:q4-2}, LKG-RALM showed strong resilience against irrelevant information. When increasing irrelevant passages to 100\%, the EM score only gradually decreased from 61.0 to 42.7, significantly outperforming other retrieval-based models. Even with 80\% irrelevant input, LKG-RALM maintained a strong EM score of 54.8. In comparison, models without explicit relevance modeling like RankRAG saw sharp performance drops, falling from 54.2 to 15.6 with fully irrelevant passages. While GPT-4 with Retrieval showed high noise tolerance, dropping only from 40.4 to 31.5 under fully irrelevant conditions, it did not leverage relevant information as effectively as LKG-RALM, as shown in our earlier experiment. LKG-RALM's superior performance stems from its explicit relevance modeling, which helps it focus on pertinent information while filtering out noise. This allows it to effectively balance the use of retrieved knowledge with its inherent model capabilities.
+
+
+
+
+
+% \input{figure/q5}
+% \subsection{(Q5) Visual Analysis}
+% % 不同attention层的可视化效果,参考ROME论文,看attention权重找文章准不准,对比相关性添加前后的可视化效果
+% To offer transparent insights into the inner workings of the HC-LLM architecture, we conduct explanatory visual analysis investigating the cross-attention distributions modeled within the Question-Passage Interaction component. Specifically, we analyze the impact of incorporating supervised question-passage relevance on steering the model's evidence extraction towards pertinent passages.
+
+% In the default setting without explicit relevance modeling, the attention weights are dispersed across all input passages regardless of pertinence. However, after integrating relevance supervision that imparts physical interpretation to the attention distributions, we observe a gradual concentration of the attention weights towards key evidence-bearing passages across layers, effectively filtering out distracting irrelevant ones. For example in Figure \ref{fig:visual}, the relevance-incorporated attention heatmap clearly illustrates shrinking weights allocation on non-critical passages. Meanwhile, passages containing answer-bearing evidence receive concentrated attention. This validates that the relevance modeling successfully provides effective top-down guidance for HC-LLM to focus on salient passages and mitigate noise interference.
+
+% Remarkably, we also notice substantial attention retained on the question itself in deeper layers. This intriguing phenomenon suggests that without useful evidence contexts, HC-LLM can still leverage its internal pre-trained knowledge and reasoning strength to deeply comprehend the question semantics. It can then generate answers by exploiting the parametric connections between query concepts and world knowledge within the foundation language model. Such intrinsic capability offers fail-safe robustness against absent or noisy passages. Overall, these visualizations offer explanatory power regarding how the relevance mechanism interacts with pre-trained knowledge to facilitate noise-resilient evidence extraction for ODQA.
+
+\section{Conclusion}
+In this work, we proposed LKG-RALM, which leverages layer-wise knowledge within LLMs to guide attention toward relevant passages, addressing distractibility and positional bias in handling retrieved passages. A layer-wise passage estimator evaluates passage relevance by utilizing diverse layer knowledge within the LLM. Entropy-based layer-knowledge selection dynamically identifies the most relevant layers for accurate passage assessment. Relevance-aware passage fusion selectively prioritizes crucial content, reducing the impact of irrelevant passages and overcoming positional bias. Extensive experiments across multiple datasets demonstrate that LKG-RALM achieves notable improvements in accuracy and robustness for knowledge-intensive tasks.
+
+\section*{Limitations}
+Our work has several important limitations that should be acknowledged:
+
+First, while our layer-wise passage estimator significantly improves RALM performance, it introduces additional computational overhead. The need to process passages through multiple layers for relevance assessment increases both memory usage and inference time. Although this overhead is relatively small compared to the base LLM inference, it may impact real-time applications or resource-constrained environments. Future work could explore more efficient methods for leveraging layer-wise knowledge without significant computational costs.
+
+Second, our approach relies heavily on the quality of retrieved passages. While LKG-RALM shows improved robustness to irrelevant passages, its performance still degrades when the retrieval quality is poor or when dealing with queries requiring information beyond the knowledge cutoff date of the retrieval corpus. This limitation is particularly evident in rapidly evolving domains where the retrieved information may become outdated quickly.
+
+Third, the effectiveness of our layer-knowledge selection mechanism may vary across different LLM architectures and sizes. While we demonstrated strong performance with LLAMA-3.1 and Qwen-2.5, the optimal configuration of layer-wise knowledge utilization might need to be adjusted for different model architectures. Additionally, our current approach to entropy-based layer selection may not capture all aspects of layer-specific knowledge representation.
+
+\section*{Ethics Statement}
+Our work utilizes publicly available datasets and pre-trained language models, adhering to established data usage guidelines. However, several ethical considerations deserve attention. While LKG-RALM shows improved robustness in handling retrieved information, it inherits potential biases present in both the pre-trained language models and the retrieval corpus, which could affect the model's responses across different demographic groups or topic areas. We emphasize that our work primarily focuses on technical improvements in retrieval-augmented language modeling and should be complemented with dedicated bias mitigation strategies. Additionally, the improved performance of our model in handling retrieved passages raises questions about information authenticity and attribution. While LKG-RALM can better identify and utilize relevant information, users should be aware that the model's responses are based on retrieved passages that may contain inaccuracies or outdated information. We recommend implementing clear attribution mechanisms and confidence indicators in practical applications.
+
+\section*{Acknowledgments}
+This work is supported by the National Natural Science Foundation of China (Grant No. 62102277), the National Key R\&D Program of China under Grant No. 2022YFC3303600, the Zhejiang Provincial Natural Science Foundation of China under Grant No. LY23F020010.
+
+\bibliography{custom}
+
+\appendix
+
+\section{Attention Analysis Setting} \label{Attention pattern analysis}
+\subsection{Model and Dataset Selection}
+Our analysis of attention patterns in Retrieval-Augmented Language Models (RALMs) was conducted using LLAMA-3.1-8B-instruct as the base model. We randomly selected a sample of 2000 queries from the Natural Questions and TriviaQA datasets, which consist of real-world queries (collected by Google Search and the Allen Institute for AI, respectively) paired with high-quality human-annotated answers extracted from Wikipedia pages.
+
+\subsection{Attention Pattern Classification}
+In our study, we focused on three main categories of attention patterns: edge-focused attention, uniform attention, and middle-focused attention. To analyze these patterns, we examined the attention distribution across all attention heads in the model. For each attention head, we calculated the percentage of attention allocated to different parts of the input sequence, specifically the beginning, middle, and end. Based on this distribution, we classified each attention head into one of the three main categories. We then quantified the prevalence of each attention pattern type across all attention heads to gain a comprehensive understanding of the model's attention behavior.
+
+\subsection{Correlation Analysis}
+To investigate the relationship between attention patterns and model performance, we conducted a correlation analysis. This involved calculating the sum of attention weights for each pattern type and measuring the model's accuracy on the test set. We then computed the correlation between these attention weight sums and the model's accuracy for each pattern type. This analysis allowed us to identify which attention patterns were most strongly associated with improved model performance.
+
+\subsection{Manipulation Experiments}
+To further validate our findings and explore the causal relationships between attention patterns and model performance, we conducted two types of manipulation experiments. In the first experiment, we artificially replaced edge-focused and uniform attention patterns with middle-focused attention on relevant passages. This allowed us to observe how redirecting attention to potentially more informative parts of the input affected the model's performance. In the second experiment, we deliberately redirected middle-focused attention patterns to both irrelevant and relevant passages. By comparing the model's performance under these different conditions, we were able to assess the impact of focused attention on specific parts of the input.
+
+Throughout our analysis, we created detailed visualizations to illustrate the different attention patterns and their impact on model performance. These visualizations, presented in Figure \ref{fig:attention pattern} and Figure \ref{fig:attention replacement}, provide a clear and intuitive representation of our findings, helping to elucidate the complex relationships between attention mechanisms and RALM performance. By systematically examining different types of attention patterns, their prevalence, and their relationship to model accuracy, we have identified potential areas for improvement in model design and training, particularly in guiding attention to relevant parts of the input for enhanced performance in open-domain question answering tasks.
+
+\section{Dataset Description} \label{dataset description}
+To comprehensively evaluate the performance of Retrieval-Augmented Language Models (RALMs) across diverse data characteristics, we employ a range of representative datasets covering various aspects of question-answering tasks, from factoid questions to multi-hop reasoning and strategy-based inquiries. Below, we provide detailed descriptions of each dataset:
+
+\begin{itemize}
+ \item \textit{Natural Questions (NQ)}: Developed by Google Research \cite{NQ}, this dataset comprises real-world queries submitted to Google Search, accompanied by high-quality human-annotated answers extracted from Wikipedia pages. NQ offers a rich mixture of long and short answer formats, reflecting authentic information-seeking behaviors across a broad range of subjects. Its structure, featuring both comprehensive passages and concise answer spans, provides a nuanced testing ground for RALMs.
+ \item \textit{TriviaQA}: Crafted by researchers at the Allen Institute for AI \cite{TQA}, TriviaQA presents a formidable challenge with its extensive collection of question-answer pairs. These are sourced from trivia enthusiasts and paired with supporting evidence from Wikipedia and web searches. The dataset's hallmark is its high lexical and syntactic variance between questions and answers, necessitating robust retrieval and reasoning capabilities from models. By spanning both web and Wikipedia domains, it offers a comprehensive evaluation landscape.
+ \item \textit{StrategyQA}: Developed by the Allen Institute for AI \cite{strategyQA}, it focuses on multi-hop reasoning questions that demand implicit strategic thinking. StrategyQA's questions often require common sense reasoning, with answers typically being binary (yes/no) but necessitating complex cognitive processes. It is specifically designed to challenge and evaluate models' strategic thinking abilities, pushing the boundaries of AI reasoning.
+    \item \textit{HotpotQA}: Developed in a collaboration led by Carnegie Mellon University \cite{HotpotQA}, HotpotQA features Wikipedia-based question-answer pairs that explicitly require reasoning across multiple supporting documents. It includes sentence-level supporting facts for answer explanation and maintains a balance across different reasoning types, such as bridging and comparison. This structure makes HotpotQA particularly effective in assessing multi-hop reasoning capabilities.
+ \item \textit{PopQA}: Created by researchers at the University of Washington \cite{PopQA}, PopQA centers on questions about popular culture, including movies, music, celebrities, and current events. This dataset is crucial for testing models' ability to handle contemporary and rapidly evolving information. It challenges RALMs to navigate ambiguity and context-dependent information, reflecting the dynamic nature of real-world knowledge.
+ \item \textit{2WikiMQA}: Developed by the Graduate University for Advanced Studies \cite{ho2020constructing}, 2WikiMQA is a multi-hop open-domain question-answering dataset constructed from Wikipedia. It features questions that necessitate reasoning across multiple Wikipedia pages and includes complex queries that cannot be answered by a single fact. This dataset is designed to simultaneously test both retrieval accuracy and advanced reasoning capabilities of RALMs.
+\end{itemize}
+
+By employing this diverse set of benchmarks, we aim to provide a holistic assessment of model capabilities, from factual recall to complex reasoning and strategic thinking.
+
+\section{Baseline Settings} \label{baseline setting}
+Our baseline methods for open-QA tasks represent a diverse range of approaches, from pure language models to sophisticated retrieval-augmented systems. We categorize these baselines into three groups: closed-book LLM without retrieval, LLM with retrieval, and robust RALM. Each group showcases different strategies for tackling RAG challenges.
+
+\subsection{Closed-book LLM without retrieval and LLM with retrieval}
+The first two groups encompass state-of-the-art language models that have demonstrated exceptional capabilities in various natural language processing tasks. For models in the LLM with retrieval category, we employ a straightforward approach of concatenating retrieved content to the context, allowing the LLM to process the augmented input:
+\begin{itemize}
+    \item \textit{LLAMA-3.1} (2024) \cite{dubey2024llama}: The latest iteration in the LLaMA series, LLAMA-3.1 is trained on roughly 15 trillion tokens and achieves state-of-the-art performance among open-source models.
+    \item \textit{Qwen-2.5} (2024) \cite{yang2024qwen2}: Developed by Alibaba, Qwen-2.5 represents a significant advancement in multilingual capabilities, trained on 18 trillion tokens to achieve state-of-the-art performance across various tasks.
+ \item \textit{ChatGPT} (2022): Developed by OpenAI, this model has gained widespread recognition for its conversational prowess and extensive knowledge base across diverse domains.
+ \item \textit{GPT-4} (2023) \cite{gpt4}: A large-scale, multimodal model developed by OpenAI, capable of accepting image and text inputs and producing text outputs. It exhibits human-level performance on various professional and academic benchmarks.
+ \item \textit{Claude-3-Sonnet} (2024): An advanced AI model from Anthropic, part of the Claude 3 model family, known for its strong performance across a wide range of tasks.
+\end{itemize}
+
+\subsection{Robust RALM}
+The third group comprises advanced retrieval-augmented language models that enhance the robustness and effectiveness of RAG:
+\begin{itemize}
+ \item \textit{REPLUG} (2023) \cite{REPLUG}: A retrieval-augmented language modeling framework that treats the language model as a black box and augments it with a tuneable retrieval model. It simply prepends retrieved documents to the input for the frozen black-box LM.
+ \item \textit{Self-RAG} (2023) \cite{asai2023self}: A framework that enhances an LM's quality and factuality through retrieval and self-reflection. It trains a single arbitrary LM that adaptively retrieves passages on-demand, and generates and reflects on retrieved passages and its own generations using special tokens.
+ \item \textit{RA-ISF} (2024) \cite{liu2024ra}: A framework that iteratively decomposes tasks and processes them in three submodules to enhance the model's problem-solving capabilities. It aims to improve factual reasoning capabilities and reduce hallucinations.
+ \item \textit{Noise-Resistant RALM} (2024) \cite{yoran2024making}: This approach focuses on making retrieval-augmented language models robust to irrelevant context. It proposes two methods: a simple baseline that filters out retrieved passages using an NLI model, and a method for automatically generating data to fine-tune the language model.
+ \item \textit{ChatQA-1.5} (2024) \cite{liu2024chatqa}: An evolution of the ChatQA model, this version introduces refinements aimed at enhancing effectiveness in question-answering tasks, particularly in conversational contexts.
+    \item \textit{RankRAG} (2024) \cite{yu2024rankrag}: An instruction fine-tuning framework that instruction-tunes a single LLM for the dual purposes of context ranking and answer generation in RAG.
+\end{itemize}
+
+In our experimental setup, the RankRAG results are referenced from the original paper using LLaMA-3-70B in a zero-shot setting and are supported by the authors; ChatQA\footnote{https://huggingface.co/nvidia/Llama3-ChatQA-1.5-70B} leverages LLaMA-3-70B in a five-shot setting; Noise-Resistant RALM\footnote{https://huggingface.co/datasets/Ori/strategyqa-ret-robust} is reproduced using LLaMA-3.1-8B; and RA-ISF\footnote{https://github.com/OceannTwT/ra-isf} is implemented with ChatGPT-3.5.
+
+\section{Implementation Details}
+Our model foundation utilizes LLAMA and Qwen. For retrieval, we follow ATLAS \cite{atlas} by using the Wikipedia dump from December 20, 2018, as our external corpus, comprising 28 million passages. We adopt a hybrid retrieval setup \cite{arivazhagan2023hybrid}, where BM25 is built on Elasticsearch \cite{es} and the dense retriever is based on a FAISS index \cite{FAISS}. The training data follows Self-RAG \cite{asai2023self}. The trainable low-rank weights are implemented with LoRA \cite{hu2021lora}, using a rank dimension of 256. The hidden size of the adaptor is set to 4096. We optimize all trainable parameters using the AdamW optimizer with a learning rate of 1e-5. The batch size is set to 32, and a warmup ratio of 0.1 is employed along with a cosine learning rate scheduler. Three external relevance scores are obtained from BGE-M3 \footnote{https://huggingface.co/BAAI/bge-m3}, E5-mistral-7b-instruct \footnote{https://huggingface.co/intfloat/e5-mistral-7b-instruct}, and GTE-Qwen2-7B-instruct \footnote{https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct}. The updating factor $\beta$ for layer-wise relevance guidance is set to 0.2.
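+
+For reference, the adaptor and optimizer settings above map onto standard tooling. The sketch below is a minimal illustration only, assuming Hugging Face \texttt{transformers} and \texttt{peft}; the base model name, target modules, and step count are placeholders rather than our exact training script.
+\begin{verbatim}
+# Minimal sketch of the fine-tuning setup (illustrative; placeholders marked).
+import torch
+from transformers import AutoModelForCausalLM, get_cosine_schedule_with_warmup
+from peft import LoraConfig, get_peft_model
+
+model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.1-8B-Instruct",           # placeholder base model
+    torch_dtype=torch.bfloat16)
+lora_cfg = LoraConfig(r=256, lora_alpha=256,      # rank dimension 256
+                      target_modules=["q_proj", "v_proj"],  # placeholder
+                      task_type="CAUSAL_LM")
+model = get_peft_model(model, lora_cfg)
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
+total_steps = 10_000                              # placeholder step budget
+scheduler = get_cosine_schedule_with_warmup(
+    optimizer,
+    num_warmup_steps=int(0.1 * total_steps),      # warmup ratio 0.1
+    num_training_steps=total_steps)
+# Training loop with batch size 32 omitted.
+\end{verbatim}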
+
+Note that separate LLMs can be used for passage estimation and answer generation in parallel. A lighter estimator (e.g., 1.5B) paired with a larger generator (e.g., 8B) minimizes overhead; because the two models have different layer counts, the generator shares each relevance guidance signal across several of its layers.
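+
+How guidance from a shallower estimator is spread over a deeper generator is an implementation detail; one simple proportional mapping, given here purely as an illustrative assumption rather than the only possible scheme, is:
+\begin{verbatim}
+# Illustrative assumption: map each generator layer to an estimator layer
+# proportionally, so adjacent generator layers share one guidance signal.
+def guidance_layer(gen_layer: int, n_gen: int, n_est: int) -> int:
+    """Return the estimator layer whose relevance score guides gen_layer."""
+    return min(n_est - 1, gen_layer * n_est // n_gen)
+
+# Example: a 28-layer 1.5B estimator guiding a 32-layer 8B generator;
+# several adjacent generator layers map to the same estimator layer.
+mapping = [guidance_layer(l, 32, 28) for l in range(32)]
+\end{verbatim}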
+
+For the attention pattern analysis, we define the first 3 tokens as the head and the last 3 tokens as the tail, with the remaining tokens classified as the middle. We employ threshold-based metrics to distinguish between attention patterns (a classification sketch follows the list):
+\begin{itemize}
+ \item Edge-focused: Combined attention weights of head and tail exceed 75\%.
+ \item Uniform: Middle attention weights exceed 90\%, with over 40\% of tokens having attention weights greater than 1/(input length), and no single token's attention weight exceeding 10\%. The term ``uniform'' is used loosely here: it denotes a pattern in which no single token receives exceptionally high attention.
+ \item Middle-focused: Middle attention weights exceed 90\%, with either one or two tokens having attention weights above 30\%, or three or more tokens having attention weights above 10\%.
+\end{itemize}
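+
+The thresholds above translate directly into a simple classifier over a normalized per-token attention distribution. The sketch below is a minimal illustration of that rule set (function and variable names are ours, not from any released code):
+\begin{verbatim}
+import numpy as np
+
+def classify_attention_pattern(attn: np.ndarray) -> str:
+    """Classify a normalized attention distribution over the input tokens.
+
+    attn: 1-D array of attention weights summing to 1; thresholds follow
+    the definitions above.
+    """
+    n = attn.size
+    head, tail, middle = attn[:3], attn[-3:], attn[3:-3]
+    if head.sum() + tail.sum() > 0.75:
+        return "edge-focused"
+    if middle.sum() > 0.90:
+        if 1 <= (middle > 0.30).sum() <= 2 or (middle > 0.10).sum() >= 3:
+            return "middle-focused"
+        if (attn > 1.0 / n).mean() > 0.40 and attn.max() <= 0.10:
+            return "uniform"
+    return "unclassified"
+\end{verbatim}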
+
+
+\section{Performance on General NLP Tasks}
+Beyond open-QA, we assess the capabilities of the LKG-RALM architecture on broader NLP benchmarks. Specifically, we evaluate on Massive Multitask Language Understanding (MMLU) and language modeling, two challenging tasks covering both understanding and generation.
+
+\subsection{MMLU}
+\input{figure/q6-1}
+We evaluated LKG-RALM on the Massive Multitask Language Understanding (MMLU) benchmark \cite{mmlu}, a comprehensive multiple-choice QA dataset consisting of 57 natural language understanding tasks, including elementary mathematics, US history, computer science, law, and more. Following previous work \cite{REPLUG}, we group these tasks into four categories: Humanities, Social Science, STEM, and Other. We still use the Wikipedia dump as the external corpus for retrieving information to improve performance on MMLU.
+
+
+As shown in Table \ref{table:MMLU}, the results demonstrate that LKG-RALM outperforms the original LLAMA model by a significant margin across all tasks. Specifically, we observe average accuracy improvements of 1.5\% on Humanities, 7.2\% on Social Science, 2.8\% on STEM, and 4.4\% on Other tasks over LLAMA-3.1-8B. Moreover, LKG-RALM achieves competitive performance against other models: LKG-RALM-8B outperforms ChatGPT by 4.8\% on average and surpasses Self-RAG by 10.1\%. Compared with RankRAG, we obtain a 1.3-point absolute improvement on average. This substantial performance boost can be attributed to two key factors. First, the retrieved passages from Wikipedia provide useful external knowledge and context for the model to better understand the input texts. Second, the relevance-guided architecture enables more effective encoding of and reasoning over lengthy in-context passages, facilitating passage understanding for solving complex MMLU tasks.
+
+\subsection{Language Modeling}
+\input{figure/q6-2}
+As a crucial touchstone for evaluating general language generation capabilities, we assess the LKG-RALM architecture on language modeling benchmarks from the Pile dataset, spanning diverse domains including websites, academic writing, code, and dialogue. These benchmarks require predicting subsequent tokens based on the preceding textual context, evaluating model fluency, coherence, and grounding. A key challenge arises from context segmentation across long documents, which hinders encoding the full history needed to produce logically consistent continuations. To enable the relevance-aware passage fusion component, we segment lengthy sequences into 100-token passages.
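+
+The segmentation itself is straightforward; as a sketch (assuming the text has already been tokenized into a flat list of token IDs):
+\begin{verbatim}
+def segment_into_passages(token_ids, passage_len=100):
+    """Split a long token sequence into consecutive 100-token passages."""
+    return [token_ids[i:i + passage_len]
+            for i in range(0, len(token_ids), passage_len)]
+\end{verbatim}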
+
+Following prior work \cite{REPLUG}, we report the standard bits-per-byte (BPB) metric, which normalizes cross-entropy by the number of bytes in the raw text; a smaller BPB indicates better performance. As shown in Table \ref{table:LM}, LKG-RALM substantially enhances base LLMs such as GPT-2, Qwen, and LLAMA across all categories of the Pile benchmark, reducing average BPB by 9.06\% (GPT-2), 5.17\% (Qwen), and 3.45\% (LLAMA). This demonstrates that LKG-RALM successfully encodes rich multi-granularity semantics to produce logical and human-like text continuations. The significant perplexity reductions validate that the hierarchical encoding mechanism enhances the language model's capacity to track lengthy preceding context for coherent generation. By effectively navigating long-range dependencies, LKG-RALM generates higher-quality and better-grounded natural language.
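+
+As background, BPB can be obtained from the average per-token cross-entropy by a standard conversion (stated here in our own notation, not as the exact evaluation script):
+\[
+\mathrm{BPB} = \frac{N_{\mathrm{tok}}}{N_{\mathrm{byte}}} \cdot \frac{\mathcal{L}_{\mathrm{CE}}}{\ln 2},
+\]
+where $\mathcal{L}_{\mathrm{CE}}$ is the average per-token cross-entropy in nats, and $N_{\mathrm{tok}}$ and $N_{\mathrm{byte}}$ denote the number of tokens and UTF-8 bytes in the evaluation text, respectively.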
+
+\section{Efficiency and Accuracy Trade-off}
+\input{figure/table-efficient}
+To contextualize our method's efficiency, we compare LKG-RALM's performance with existing models in Table \ref{table:efficiency-accuracy}. Self-RAG and RA-ISF require multiple rounds of retrieval and question-decomposition-based multi-turn dialogue, respectively. Their low parallelism results in approximately 3.7x the inference time of LLAMA-3.1-8B, with Self-RAG taking 3.07 seconds per query and RA-ISF requiring 3.44 seconds per query. Noise-Resistant-8B and RankRAG-70B use BERT-based NLI models and reranking models, respectively, to assist in passage filtering or sorting. This introduces an additional 0.8\% inference latency and 1.9\% computational overhead for RankRAG-70B, increasing its processing time to 1.25 seconds per query and its computational cost to 145.4 TFLOPs for 1024 tokens. Similarly, LKG-RALM-70B employs the Qwen-2.5-7B model for passage relevance analysis, resulting in a 2.4\% increase in inference latency (from 1.24 to 1.27 s/query) and a 10.2\% increase in computational cost (from 142.6 to 157.2 TFLOPs). The low latency is attributed to the layer-wise passage estimator's ability to run with high parallelism alongside LLM inference.
+
+To further evaluate the trade-off between efficiency and accuracy, we conducted experiments using different LLM sizes for the layer-wise passage estimator. Table \ref{table:efficiency-accuracy} shows that increasing the size of the passage estimator from Qwen-2.5-500M to Qwen-2.5-7B yields consistent improvements in EM scores for both LLAMA-3.1-8B and LLAMA-3.1-70B base models. For LLAMA-3.1-8B, the EM score improves from 53.6 to 55.3 as we scale up the estimator, with a modest increase in processing time from 0.83 to 0.91 seconds per query. The computational cost rises from 18.1 to 30.9 TFLOPs. Notably, even with the largest Qwen-2.5-7B estimator, LKG-RALM-70B maintains competitive efficiency compared to RankRAG-70B (1.27 vs 1.25 s/query) while achieving superior EM scores (61.0 vs 54.2). LKG-RALM's flexible framework allows users to balance accuracy and efficiency by selecting appropriate estimator sizes. For example, Qwen-2.5-1.5B with LLAMA-3.1-70B improves EM score by 0.7 over the 500M version, with minimal increases in query time and computational cost.
+
+\section{Effects of Training Data Size}
+\input{figure/fig-training-data-size}
+We conducted an analysis to understand how the scale of training data affects the model's performance. Specifically, we randomly sampled 30k, 60k, 90k, 160k, and 220k instances from our original 440k training instances and fine-tuned five LKG-RALM-70B variants on these subsets. We then compared their performance on NQ and HotpotQA with our final LKG-RALM trained on the full 440k instances. We also evaluated LLAMA-3.1-70B fine-tuned on the same data subsets as a baseline. Figure \ref{fig:data-size-impact} shows the performance of models trained on different amounts of data. For NQ, LKG-RALM-70B's accuracy improves from 50.5 with 30k training instances to 61.0 with the full 440k dataset. In contrast, LLAMA-3.1-FineTune shows minimal improvement, from 44.6 to 45.7, a mere 1.1-point increase. The performance gap between LKG-RALM-70B and LLAMA-3.1-FineTune widens significantly, from 5.9 at 30k instances to 15.3 at 440k instances.
+
+The model's strong performance largely comes from its effective pre-training parameter space. With only 30k fine-tuning examples, LKG-RALM-70B shows impressive results, reaching 50.5 EM on NQ and 37.1 on HotpotQA. This indicates that minimal fine-tuning data can activate the model's core capabilities in passage estimation and robustness. The performance difference between LKG-RALM-70B and LLAMA-3.1-FineTune is clear even with limited data, showing our approach's effectiveness. As training data increases, LKG-RALM-70B's accuracy steadily improves, though gains slow after 220k examples. For NQ, the improvement from 220k to 440k is just 0.7, versus 1.8 from 160k to 220k. This shows that while more data helps performance, the benefits decrease with larger datasets. The model's architecture and pre-training are key to its success, enabling strong results with limited fine-tuning data, while additional training data has diminishing returns.
+
+\section{Impact of Relevance Score Quantity}
+To investigate how the number of relevance scores affects model performance, we conducted experiments that preserve only the first k distinct relevance scores (k ranging from 1 to 7); any subsequent relevance score, whether similar to an earlier one or simply appearing later in the sequence, is overwritten by these preserved values. In our analysis of the 32-layer LLAMA-3.1-8B, we observed significantly different relevance scores (L1 distance >0.3) distributed across layers. The distribution of these distinct scores is presented in Table \ref{tab:score-distribution}.
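+
+A minimal sketch of this truncation procedure, under our reading of the setup (scalar per-layer scores and an L1 distance threshold of 0.3; names are illustrative), is:
+\begin{verbatim}
+def keep_first_k_distinct(layer_scores, k, eps=0.3):
+    """Preserve the first k distinct per-layer relevance scores
+    (distinct = L1 distance > eps from every preserved score); all
+    other layers reuse the most recently preserved value."""
+    kept, out = [], []
+    for s in layer_scores:
+        distinct = all(abs(s - p) > eps for p in kept)
+        if distinct and len(kept) < k:
+            kept.append(s)
+            out.append(s)
+        else:
+            out.append(kept[-1] if kept else s)
+    return out
+
+# scores = estimator_relevance_scores(passage)  # hypothetical helper
+# guided = keep_first_k_distinct(scores, k=3)   # keep 3 distinct scores
+\end{verbatim}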
+
+\setlength{\tabcolsep}{4pt} % default is 6pt
+\begin{table}[t]
+\scriptsize
+\centering
+\caption{Distribution of Distinct Relevance Scores Across Model Layers}
+\label{tab:score-distribution}
+\begin{tabular}{lcccccccc}
+\hline
+Number of Scores & 1 & 2 & 3 & 4 & 5 & 6 & 7 & Others \\
+\hline
+Percentage (\%) & 4.2 & 17.2 & 21.8 & 35.9 & 10.4 & 5.6 & 3.2 & 1.7 \\
+\hline
+\end{tabular}
+\end{table}
+
+\setlength{\tabcolsep}{3pt} % default is 6pt
+\begin{table}[t]
+\scriptsize
+\centering
+\caption{Model Performance with Different Numbers of Relevance Scores}
+\label{tab:score-impact}
+\begin{tabular}{lccccc}
+\hline
+Model & NQ & TriviaQA & HotpotQA & PopQA & 2WikiMQA \\
+\hline
+LLAMA-3.1-8B & 18.7 & 78.5 & 16.5 & 22.1 & 13.9 \\
++ Retrieval \& Fine-tuning & 35.7 & 77.4 & 28.9 & 37.1 & 25.3 \\
+\hline
++ LKG-RALM (k=1) & 36.4 & 77.3 & 29.6 & 38.6 & 26.2 \\
++ LKG-RALM (k=2) & 40.7 & 78.5 & 31.3 & 42.3 & 26.5 \\
++ LKG-RALM (k=3) & 47.2 & 85.4 & 37.4 & 48.0 & 33.4 \\
++ LKG-RALM (k=4) & 50.1 & 86.0 & 39.3 & 51.6 & 34.6 \\
++ LKG-RALM (k=5) & 51.6 & 86.0 & 40.5 & 52.3 & 35.0 \\
++ LKG-RALM (k=6) & 51.9 & 86.1 & 40.5 & 53.2 & 35.4 \\
++ LKG-RALM (k=7) & 51.8 & 86.2 & 40.1 & 53.3 & 35.5 \\
++ LKG-RALM (all) & 55.3 & 88.6 & 43.1 & 57.2 & 39.0 \\
+\hline
+\end{tabular}
+\end{table}
+
+The impact of varying k values on model performance across different datasets is shown in Table \ref{tab:score-impact}. The results reveal that LKG-RALM's accuracy consistently improved across metrics as more relevance scores were incorporated into the attention guidance mechanism. Notably, significant performance improvements were observed at several key transitions. When increasing from k=2 to k=3, we observed substantial gains across all datasets, with NQ accuracy improving by 6.5 points (from 40.7 to 47.2) and TriviaQA showing a remarkable 6.9-point increase (from 78.5 to 85.4). The transition from k=3 to k=4 brought further improvements, particularly on NQ (2.9 points) and PopQA (3.6 points). Interestingly, while incremental improvements continued beyond k=4, they became more modest, with gains typically under 1.5 points per step.
+
+The most striking performance boost was achieved when utilizing all relevance scores instead of limiting to the first seven distinct scores. This configuration led to substantial improvements across all datasets: NQ improved by 3.5 points (from 51.8 to 55.3), TriviaQA by 2.4 points (from 86.2 to 88.6), and 2WikiMQA by 3.5 points (from 35.5 to 39.0). These results strongly suggest that while the first few distinct relevance scores contribute significantly to model performance, the additional nuanced guidance from higher layers plays a crucial role in maximizing the model's capabilities.
+
+While our experimental design focused on retaining the first k relevance scores, this approach may not be optimal as higher-layer knowledge representations could provide essential guidance for complex reasoning tasks. The superior performance achieved when utilizing all attention guidance signals validates their collective importance in enhancing model accuracy and demonstrates the value of maintaining a diverse set of relevance scores across different layers of the model architecture.
+
+
+
+\end{document}
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/acl_natbib.bst" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/acl_natbib.bst"
new file mode 100644
index 0000000000..cad5a5e9d7
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/acl_natbib.bst"
@@ -0,0 +1,1928 @@
+%%% Modification of BibTeX style file acl_natbib_nourl.bst
+%%% ... by urlbst, version 0.9.1 (marked with "% urlbst")
+%%% See and repository
+%%% Modifications Copyright 2002–23, Norman Gray,
+%%% and distributed under the terms of the LPPL; see README for discussion.
+%%%
+%%% Added webpage entry type, and url and lastchecked fields.
+%%% Added eprint support.
+%%% Added DOI support.
+%%% Added PUBMED support.
+%%% Added hyperref support.
+%%% Original headers follow...
+
+%%
+%% This is file `acl_natbib_basic.bst',
+%% generated with the docstrip utility.
+%%
+%% The original source files were:
+%%
+%% merlin.mbs (with options: `ay,nat,pres,ed-au,keyxyr,blkyear,dt-beg,yr-per,note-yr,num-xser,pre-edn,xedn,nfss')
+%% ----------------------------------------
+%% *** Intended for ACL conferences ***
+%%
+%% Copyright 1994-2011 Patrick W Daly
+ % ===============================================================
+ % IMPORTANT NOTICE:
+ % This bibliographic style (bst) file has been generated from one or
+ % more master bibliographic style (mbs) files, listed above.
+ %
+ % This generated file can be redistributed and/or modified under the terms
+ % of the LaTeX Project Public License Distributed from CTAN
+ % archives in directory macros/latex/base/lppl.txt; either
+ % version 1 of the License, or any later version.
+ % ===============================================================
+ % Name and version information of the main mbs file:
+ % \ProvidesFile{merlin.mbs}[2011/11/18 4.33 (PWD, AO, DPC)]
+ % For use with BibTeX version 0.99a or later
+ %-------------------------------------------------------------------
+ % This bibliography style file is intended for texts in ENGLISH
+ % This is an author-year citation style bibliography. As such, it is
+ % non-standard LaTeX, and requires a special package file to function properly.
+ % Such a package is natbib.sty by Patrick W. Daly
+ % The form of the \bibitem entries is
+ % \bibitem[Jones et al.(1990)]{key}...
+ % \bibitem[Jones et al.(1990)Jones, Baker, and Smith]{key}...
+ % The essential feature is that the label (the part in brackets) consists
+ % of the author names, as they should appear in the citation, with the year
+ % in parentheses following. There must be no space before the opening
+ % parenthesis!
+ % With natbib v5.3, a full list of authors may also follow the year.
+ % In natbib.sty, it is possible to define the type of enclosures that is
+ % really wanted (brackets or parentheses), but in either case, there must
+ % be parentheses in the label.
+ % The \cite command functions as follows:
+ % \citet{key} ==>> Jones et al. (1990)
+ % \citet*{key} ==>> Jones, Baker, and Smith (1990)
+ % \citep{key} ==>> (Jones et al., 1990)
+ % \citep*{key} ==>> (Jones, Baker, and Smith, 1990)
+ % \citep[chap. 2]{key} ==>> (Jones et al., 1990, chap. 2)
+ % \citep[e.g.][]{key} ==>> (e.g. Jones et al., 1990)
+ % \citep[e.g.][p. 32]{key} ==>> (e.g. Jones et al., 1990, p. 32)
+ % \citeauthor{key} ==>> Jones et al.
+ % \citeauthor*{key} ==>> Jones, Baker, and Smith
+ % \citeyear{key} ==>> 1990
+ %---------------------------------------------------------------------
+
+ENTRY
+ { address
+ archivePrefix
+ author
+ booktitle
+ chapter
+ edition
+ editor
+ eid
+ eprint
+ eprinttype % = archivePrefix
+ howpublished
+ institution
+ journal
+ key
+ month
+ note
+ number
+ organization
+ pages
+ publisher
+ school
+ series
+ title
+ type
+ volume
+ year
+ doi % urlbst
+ pubmed % urlbst
+ url % urlbst
+ lastchecked % urlbst
+ }
+ {}
+ { label extra.label sort.label short.list }
+INTEGERS { output.state before.all mid.sentence after.sentence after.block }
+% urlbst...
+% urlbst constants and state variables
+STRINGS { urlintro
+ eprinturl eprintprefix doiprefix doiurl pubmedprefix pubmedurl
+ citedstring onlinestring linktextstring
+ openinlinelink closeinlinelink }
+INTEGERS { hrefform doiform inlinelinks makeinlinelink
+ addeprints adddoi addpubmed }
+FUNCTION {init.urlbst.variables}
+{
+ % The following constants may be adjusted by hand, if desired
+
+ % The first set allow you to enable or disable certain functionality.
+ #1 'addeprints := % 0=no eprints; 1=include eprints
+ #2 'hrefform := % 0=no crossrefs; 1=hypertex hrefs; 2=hyperref hrefs
+ #1 'inlinelinks := % 0=URLs explicit; 1=URLs attached to titles
+ #1 'adddoi := % 0=no DOI resolver; 1=include it
+ #1 'addpubmed := % 0=no PUBMED resolver; 1=include it
+ #0 'doiform := % 0=with href; 1=with \doi{}
+
+ % String constants, which you _might_ want to tweak.
+ "online" 'onlinestring := % label that a resource is online
+ "[link]" 'linktextstring := % anonymous link text
+ "http://www.ncbi.nlm.nih.gov/pubmed/" 'pubmedurl := % prefix to make URL from PUBMED
+ "https://doi.org/" 'doiurl := % prefix to make URL from DOI
+ "doi:" 'doiprefix := % printed text to introduce DOI
+ "https://arxiv.org/abs/" 'eprinturl := % prefix to make URL from eprint ref
+ "cited " 'citedstring := % label in "lastchecked" remark
+ "arXiv:" 'eprintprefix := % text prefix printed before eprint ref
+ "PMID:" 'pubmedprefix := % text prefix printed before PUBMED ref
+ "URL: " 'urlintro := % text prefix before URL
+
+ % The following are internal state variables, not configuration constants,
+ % so they shouldn't be fiddled with.
+ #0 'makeinlinelink := % state variable managed by possibly.setup.inlinelink
+ "" 'openinlinelink := % ditto
+ "" 'closeinlinelink := % ditto
+}
+INTEGERS {
+ bracket.state
+ outside.brackets
+ open.brackets
+ within.brackets
+ close.brackets
+}
+% ...urlbst to here
+FUNCTION {init.state.consts}
+{ #0 'outside.brackets := % urlbst...
+ #1 'open.brackets :=
+ #2 'within.brackets :=
+ #3 'close.brackets := % ...urlbst to here
+
+ #0 'before.all :=
+ #1 'mid.sentence :=
+ #2 'after.sentence :=
+ #3 'after.block :=
+}
+STRINGS { s t}
+% urlbst
+FUNCTION {output.nonnull.original}
+{ 's :=
+ output.state mid.sentence =
+ { ", " * write$ }
+ { output.state after.block =
+ { add.period$ write$
+ newline$
+ "\newblock " write$
+ }
+ { output.state before.all =
+ 'write$
+ { add.period$ " " * write$ }
+ if$
+ }
+ if$
+ mid.sentence 'output.state :=
+ }
+ if$
+ s
+}
+
+% urlbst...
+% Minimal DOI parsing.
+% Given a DOI on the stack, check whether it starts with 'doiurl' or not.
+% In either case, leave on the stack first a DOI with, and then a DOI without, the URL prefix.
+FUNCTION {parse.doi}
+{
+ #1 doiurl text.length$ substring$
+ doiurl =
+ { doi
+ doi doiurl text.length$ #1 + #999 substring$ }
+ { doiurl doi *
+ doi }
+ if$
+}
+% The following three functions are for handling inlinelink. They wrap
+% a block of text which is potentially output with write$ by multiple
+% other functions, so we don't know the content a priori.
+% They communicate between each other using the variables makeinlinelink
+% (which is true if a link should be made), and closeinlinelink (which holds
+% the string which should close any current link. They can be called
+% at any time, but start.inlinelink will be a no-op unless something has
+% previously set makeinlinelink true, and the two ...end.inlinelink functions
+% will only do their stuff if start.inlinelink has previously set
+% closeinlinelink to be non-empty.
+% (thanks to 'ijvm' for suggested code here)
+FUNCTION {uand}
+{ 'skip$ { pop$ #0 } if$ } % 'and' (which isn't defined at this point in the file)
+FUNCTION {possibly.setup.inlinelink}
+{ makeinlinelink hrefform #0 > uand
+ { doi empty$ adddoi uand
+ { pubmed empty$ addpubmed uand
+ { eprint empty$ addeprints uand
+ { url empty$
+ { "" }
+ { url }
+ if$ }
+ { eprinturl eprint * }
+ if$ }
+ { pubmedurl pubmed * }
+ if$ }
+% { doiurl doi * }
+ { doi empty$
+ { "XXX" }
+ { doi parse.doi pop$ }
+ if$
+ }
+ if$
+ % an appropriately-formatted URL is now on the stack
+ hrefform #1 = % hypertex
+ { "\special {html: }{" * 'openinlinelink :=
+ "\special {html:}" 'closeinlinelink := }
+ { "\href {" swap$ * "} {" * 'openinlinelink := % hrefform=#2 -- hyperref
+ % the space between "} {" matters: a URL of just the right length can cause "\% newline em"
+ "}" 'closeinlinelink := }
+ if$
+ #0 'makeinlinelink :=
+ }
+ 'skip$
+ if$ % makeinlinelink
+}
+FUNCTION {add.inlinelink}
+{ openinlinelink empty$
+ 'skip$
+ { openinlinelink swap$ * closeinlinelink *
+ "" 'openinlinelink :=
+ }
+ if$
+}
+FUNCTION {output.nonnull}
+{ % Save the thing we've been asked to output
+ 's :=
+ % If the bracket-state is close.brackets, then add a close-bracket to
+ % what is currently at the top of the stack, and set bracket.state
+ % to outside.brackets
+ bracket.state close.brackets =
+ { "]" *
+ outside.brackets 'bracket.state :=
+ }
+ 'skip$
+ if$
+ bracket.state outside.brackets =
+ { % We're outside all brackets -- this is the normal situation.
+ % Write out what's currently at the top of the stack, using the
+ % original output.nonnull function.
+ s
+ add.inlinelink
+ output.nonnull.original % invoke the original output.nonnull
+ }
+ { % Still in brackets. Add open-bracket or (continuation) comma, add the
+ % new text (in s) to the top of the stack, and move to the close-brackets
+ % state, ready for next time (unless inbrackets resets it). If we come
+ % into this branch, then output.state is carefully undisturbed.
+ bracket.state open.brackets =
+ { " [" * }
+ { ", " * } % bracket.state will be within.brackets
+ if$
+ s *
+ close.brackets 'bracket.state :=
+ }
+ if$
+}
+
+% Call this function just before adding something which should be presented in
+% brackets. bracket.state is handled specially within output.nonnull.
+FUNCTION {inbrackets}
+{ bracket.state close.brackets =
+ { within.brackets 'bracket.state := } % reset the state: not open nor closed
+ { open.brackets 'bracket.state := }
+ if$
+}
+
+FUNCTION {format.lastchecked}
+{ lastchecked empty$
+ { "" }
+ { inbrackets citedstring lastchecked * }
+ if$
+}
+% ...urlbst to here
+FUNCTION {output}
+{ duplicate$ empty$
+ 'pop$
+ 'output.nonnull
+ if$
+}
+FUNCTION {output.check}
+{ 't :=
+ duplicate$ empty$
+ { pop$ "empty " t * " in " * cite$ * warning$ }
+ 'output.nonnull
+ if$
+}
+FUNCTION {fin.entry.original} % urlbst (renamed from fin.entry, so it can be wrapped below)
+{ add.period$
+ write$
+ newline$
+}
+
+FUNCTION {new.block}
+{ output.state before.all =
+ 'skip$
+ { after.block 'output.state := }
+ if$
+}
+FUNCTION {new.sentence}
+{ output.state after.block =
+ 'skip$
+ { output.state before.all =
+ 'skip$
+ { after.sentence 'output.state := }
+ if$
+ }
+ if$
+}
+FUNCTION {add.blank}
+{ " " * before.all 'output.state :=
+}
+
+FUNCTION {date.block}
+{
+ new.block
+}
+
+FUNCTION {not}
+{ { #0 }
+ { #1 }
+ if$
+}
+FUNCTION {and}
+{ 'skip$
+ { pop$ #0 }
+ if$
+}
+FUNCTION {or}
+{ { pop$ #1 }
+ 'skip$
+ if$
+}
+FUNCTION {new.block.checkb}
+{ empty$
+ swap$ empty$
+ and
+ 'skip$
+ 'new.block
+ if$
+}
+FUNCTION {field.or.null}
+{ duplicate$ empty$
+ { pop$ "" }
+ 'skip$
+ if$
+}
+FUNCTION {emphasize}
+{ duplicate$ empty$
+ { pop$ "" }
+ { "\emph{" swap$ * "}" * }
+ if$
+}
+FUNCTION {tie.or.space.prefix} % puts ~ before the preceding part if it is of length <3
+{ duplicate$ text.length$ #3 <
+ { "~" }
+ { " " }
+ if$
+ swap$
+}
+
+FUNCTION {capitalize}
+{ "u" change.case$ "t" change.case$ }
+
+FUNCTION {space.word}
+{ " " swap$ * " " * }
+ % Here are the language-specific definitions for explicit words.
+ % Each function has a name bbl.xxx where xxx is the English word.
+ % The language selected here is ENGLISH
+FUNCTION {bbl.and}
+{ "and"}
+
+FUNCTION {bbl.etal}
+{ "et~al." }
+
+FUNCTION {bbl.editors}
+{ "editors" }
+
+FUNCTION {bbl.editor}
+{ "editor" }
+
+FUNCTION {bbl.edby}
+{ "edited by" }
+
+FUNCTION {bbl.edition}
+{ "edition" }
+
+FUNCTION {bbl.volume}
+{ "volume" }
+
+FUNCTION {bbl.of}
+{ "of" }
+
+FUNCTION {bbl.number}
+{ "number" }
+
+FUNCTION {bbl.nr}
+{ "no." }
+
+FUNCTION {bbl.in}
+{ "in" }
+
+FUNCTION {bbl.pages}
+{ "pages" }
+
+FUNCTION {bbl.page}
+{ "page" }
+
+FUNCTION {bbl.chapter}
+{ "chapter" }
+
+FUNCTION {bbl.techrep}
+{ "Technical Report" }
+
+FUNCTION {bbl.mthesis}
+{ "Master's thesis" }
+
+FUNCTION {bbl.phdthesis}
+{ "Ph.D. thesis" }
+
+MACRO {jan} {"January"}
+
+MACRO {feb} {"February"}
+
+MACRO {mar} {"March"}
+
+MACRO {apr} {"April"}
+
+MACRO {may} {"May"}
+
+MACRO {jun} {"June"}
+
+MACRO {jul} {"July"}
+
+MACRO {aug} {"August"}
+
+MACRO {sep} {"September"}
+
+MACRO {oct} {"October"}
+
+MACRO {nov} {"November"}
+
+MACRO {dec} {"December"}
+
+MACRO {acmcs} {"ACM Computing Surveys"}
+
+MACRO {acta} {"Acta Informatica"}
+
+MACRO {cacm} {"Communications of the ACM"}
+
+MACRO {ibmjrd} {"IBM Journal of Research and Development"}
+
+MACRO {ibmsj} {"IBM Systems Journal"}
+
+MACRO {ieeese} {"IEEE Transactions on Software Engineering"}
+
+MACRO {ieeetc} {"IEEE Transactions on Computers"}
+
+MACRO {ieeetcad}
+ {"IEEE Transactions on Computer-Aided Design of Integrated Circuits"}
+
+MACRO {ipl} {"Information Processing Letters"}
+
+MACRO {jacm} {"Journal of the ACM"}
+
+MACRO {jcss} {"Journal of Computer and System Sciences"}
+
+MACRO {scp} {"Science of Computer Programming"}
+
+MACRO {sicomp} {"SIAM Journal on Computing"}
+
+MACRO {tocs} {"ACM Transactions on Computer Systems"}
+
+MACRO {tods} {"ACM Transactions on Database Systems"}
+
+MACRO {tog} {"ACM Transactions on Graphics"}
+
+MACRO {toms} {"ACM Transactions on Mathematical Software"}
+
+MACRO {toois} {"ACM Transactions on Office Information Systems"}
+
+MACRO {toplas} {"ACM Transactions on Programming Languages and Systems"}
+
+MACRO {tcs} {"Theoretical Computer Science"}
+
+% bibinfo.check avoids acting on missing fields while bibinfo.warn will
+% issue a warning message if a missing field is detected. Prior to calling
+% the bibinfo functions, the user should push the field value and then its
+% name string, in that order.
+FUNCTION {bibinfo.check}
+{ swap$
+ duplicate$ missing$
+ {
+ pop$ pop$
+ ""
+ }
+ { duplicate$ empty$
+ {
+ swap$ pop$
+ }
+ { swap$
+ pop$
+ }
+ if$
+ }
+ if$
+}
+FUNCTION {bibinfo.warn}
+{ swap$
+ duplicate$ missing$
+ {
+ swap$ "missing " swap$ * " in " * cite$ * warning$ pop$
+ ""
+ }
+ { duplicate$ empty$
+ {
+ swap$ "empty " swap$ * " in " * cite$ * warning$
+ }
+ { swap$
+ pop$
+ }
+ if$
+ }
+ if$
+}
+INTEGERS { nameptr namesleft numnames }
+
+
+STRINGS { bibinfo}
+
+FUNCTION {format.names}
+{ 'bibinfo :=
+ duplicate$ empty$ 'skip$ {
+ 's :=
+ "" 't :=
+ #1 'nameptr :=
+ s num.names$ 'numnames :=
+ numnames 'namesleft :=
+ { namesleft #0 > }
+ { s nameptr
+ "{ff~}{vv~}{ll}{, jj}" % first name first for all authors
+ format.name$
+ bibinfo bibinfo.check
+ 't :=
+ nameptr #1 >
+ {
+ namesleft #1 >
+ { ", " * t * }
+ {
+ s nameptr "{ll}" format.name$ duplicate$ "others" =
+ { 't := }
+ { pop$ }
+ if$
+ numnames #2 >
+ { "," * }
+ 'skip$
+ if$
+ t "others" =
+ {
+ " " * bbl.etal *
+ }
+ {
+ bbl.and
+ space.word * t *
+ }
+ if$
+ }
+ if$
+ }
+ 't
+ if$
+ nameptr #1 + 'nameptr :=
+ namesleft #1 - 'namesleft :=
+ }
+ while$
+ } if$
+}
+FUNCTION {format.names.ed}
+{
+ format.names
+}
+FUNCTION {format.key}
+{ empty$
+ { key field.or.null }
+ { "" }
+ if$
+}
+
+FUNCTION {format.authors}
+{ author "author" format.names
+}
+FUNCTION {get.bbl.editor}
+{ editor num.names$ #1 > 'bbl.editors 'bbl.editor if$ }
+
+FUNCTION {format.editors}
+{ editor "editor" format.names duplicate$ empty$ 'skip$
+ {
+ "," *
+ " " *
+ get.bbl.editor
+ *
+ }
+ if$
+}
+FUNCTION {format.note}
+{
+ note empty$
+ { "" }
+ { note #1 #1 substring$
+ duplicate$ "{" =
+ 'skip$
+ { output.state mid.sentence =
+ { "l" }
+ { "u" }
+ if$
+ change.case$
+ }
+ if$
+ note #2 global.max$ substring$ * "note" bibinfo.check
+ }
+ if$
+}
+
+FUNCTION {format.title}
+{ title
+ duplicate$ empty$ 'skip$
+ { "t" change.case$ }
+ if$
+ "title" bibinfo.check
+}
+FUNCTION {format.full.names}
+{'s :=
+ "" 't :=
+ #1 'nameptr :=
+ s num.names$ 'numnames :=
+ numnames 'namesleft :=
+ { namesleft #0 > }
+ { s nameptr
+ "{vv~}{ll}" format.name$
+ 't :=
+ nameptr #1 >
+ {
+ namesleft #1 >
+ { ", " * t * }
+ {
+ s nameptr "{ll}" format.name$ duplicate$ "others" =
+ { 't := }
+ { pop$ }
+ if$
+ t "others" =
+ {
+ " " * bbl.etal *
+ }
+ {
+ numnames #2 >
+ { "," * }
+ 'skip$
+ if$
+ bbl.and
+ space.word * t *
+ }
+ if$
+ }
+ if$
+ }
+ 't
+ if$
+ nameptr #1 + 'nameptr :=
+ namesleft #1 - 'namesleft :=
+ }
+ while$
+}
+
+FUNCTION {author.editor.key.full}
+{ author empty$
+ { editor empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { editor format.full.names }
+ if$
+ }
+ { author format.full.names }
+ if$
+}
+
+FUNCTION {author.key.full}
+{ author empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { author format.full.names }
+ if$
+}
+
+FUNCTION {editor.key.full}
+{ editor empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { editor format.full.names }
+ if$
+}
+
+FUNCTION {make.full.names}
+{ type$ "book" =
+ type$ "inbook" =
+ or
+ 'author.editor.key.full
+ { type$ "proceedings" =
+ 'editor.key.full
+ 'author.key.full
+ if$
+ }
+ if$
+}
+
+FUNCTION {output.bibitem.original} % urlbst (renamed from output.bibitem, so it can be wrapped below)
+{ newline$
+ "\bibitem[{" write$
+ label write$
+ ")" make.full.names duplicate$ short.list =
+ { pop$ }
+ { * }
+ if$
+ "}]{" * write$
+ cite$ write$
+ "}" write$
+ newline$
+ ""
+ before.all 'output.state :=
+}
+
+FUNCTION {n.dashify}
+{
+ 't :=
+ ""
+ { t empty$ not }
+ { t #1 #1 substring$ "-" =
+ { t #1 #2 substring$ "--" = not
+ { "--" *
+ t #2 global.max$ substring$ 't :=
+ }
+ { { t #1 #1 substring$ "-" = }
+ { "-" *
+ t #2 global.max$ substring$ 't :=
+ }
+ while$
+ }
+ if$
+ }
+ { t #1 #1 substring$ *
+ t #2 global.max$ substring$ 't :=
+ }
+ if$
+ }
+ while$
+}
+
+FUNCTION {word.in}
+{ bbl.in capitalize
+ " " * }
+
+FUNCTION {format.date}
+{ year "year" bibinfo.check duplicate$ empty$
+ {
+ }
+ 'skip$
+ if$
+ extra.label *
+ before.all 'output.state :=
+ after.sentence 'output.state :=
+}
+FUNCTION {format.btitle}
+{ title "title" bibinfo.check
+ duplicate$ empty$ 'skip$
+ {
+ emphasize
+ }
+ if$
+}
+FUNCTION {either.or.check}
+{ empty$
+ 'pop$
+ { "can't use both " swap$ * " fields in " * cite$ * warning$ }
+ if$
+}
+FUNCTION {format.bvolume}
+{ volume empty$
+ { "" }
+ { bbl.volume volume tie.or.space.prefix
+ "volume" bibinfo.check * *
+ series "series" bibinfo.check
+ duplicate$ empty$ 'pop$
+ { swap$ bbl.of space.word * swap$
+ emphasize * }
+ if$
+ "volume and number" number either.or.check
+ }
+ if$
+}
+FUNCTION {format.number.series}
+{ volume empty$
+ { number empty$
+ { series field.or.null }
+ { series empty$
+ { number "number" bibinfo.check }
+ { output.state mid.sentence =
+ { bbl.number }
+ { bbl.number capitalize }
+ if$
+ number tie.or.space.prefix "number" bibinfo.check * *
+ bbl.in space.word *
+ series "series" bibinfo.check *
+ }
+ if$
+ }
+ if$
+ }
+ { "" }
+ if$
+}
+
+FUNCTION {format.edition}
+{ edition duplicate$ empty$ 'skip$
+ {
+ output.state mid.sentence =
+ { "l" }
+ { "t" }
+ if$ change.case$
+ "edition" bibinfo.check
+ " " * bbl.edition *
+ }
+ if$
+}
+INTEGERS { multiresult }
+FUNCTION {multi.page.check}
+{ 't :=
+ #0 'multiresult :=
+ { multiresult not
+ t empty$ not
+ and
+ }
+ { t #1 #1 substring$
+ duplicate$ "-" =
+ swap$ duplicate$ "," =
+ swap$ "+" =
+ or or
+ { #1 'multiresult := }
+ { t #2 global.max$ substring$ 't := }
+ if$
+ }
+ while$
+ multiresult
+}
+FUNCTION {format.pages}
+{ pages duplicate$ empty$ 'skip$
+ { duplicate$ multi.page.check
+ {
+ bbl.pages swap$
+ n.dashify
+ }
+ {
+ bbl.page swap$
+ }
+ if$
+ tie.or.space.prefix
+ "pages" bibinfo.check
+ * *
+ }
+ if$
+}
+FUNCTION {format.journal.pages}
+{ pages duplicate$ empty$ 'pop$
+ { swap$ duplicate$ empty$
+ { pop$ pop$ format.pages }
+ {
+ ":" *
+ swap$
+ n.dashify
+ "pages" bibinfo.check
+ *
+ }
+ if$
+ }
+ if$
+}
+FUNCTION {format.journal.eid}
+{ eid "eid" bibinfo.check
+ duplicate$ empty$ 'pop$
+ { swap$ duplicate$ empty$ 'skip$
+ {
+ ":" *
+ }
+ if$
+ swap$ *
+ }
+ if$
+}
+FUNCTION {format.vol.num.pages}
+{ volume field.or.null
+ duplicate$ empty$ 'skip$
+ {
+ "volume" bibinfo.check
+ }
+ if$
+ number "number" bibinfo.check duplicate$ empty$ 'skip$
+ {
+ swap$ duplicate$ empty$
+ { "there's a number but no volume in " cite$ * warning$ }
+ 'skip$
+ if$
+ swap$
+ "(" swap$ * ")" *
+ }
+ if$ *
+ eid empty$
+ { format.journal.pages }
+ { format.journal.eid }
+ if$
+}
+
+FUNCTION {format.chapter}
+{ chapter empty$
+ 'format.pages
+ { type empty$
+ { bbl.chapter }
+ { type "l" change.case$
+ "type" bibinfo.check
+ }
+ if$
+ chapter tie.or.space.prefix
+ "chapter" bibinfo.check
+ * *
+ }
+ if$
+}
+
+FUNCTION {format.chapter.pages}
+{ chapter empty$
+ 'format.pages
+ { type empty$
+ { bbl.chapter }
+ { type "l" change.case$
+ "type" bibinfo.check
+ }
+ if$
+ chapter tie.or.space.prefix
+ "chapter" bibinfo.check
+ * *
+ pages empty$
+ 'skip$
+ { ", " * format.pages * }
+ if$
+ }
+ if$
+}
+
+FUNCTION {format.booktitle}
+{
+ booktitle "booktitle" bibinfo.check
+ emphasize
+}
+FUNCTION {format.in.booktitle}
+{ format.booktitle duplicate$ empty$ 'skip$
+ {
+ word.in swap$ *
+ }
+ if$
+}
+FUNCTION {format.in.ed.booktitle}
+{ format.booktitle duplicate$ empty$ 'skip$
+ {
+ editor "editor" format.names.ed duplicate$ empty$ 'pop$
+ {
+ "," *
+ " " *
+ get.bbl.editor
+ ", " *
+ * swap$
+ * }
+ if$
+ word.in swap$ *
+ }
+ if$
+}
+FUNCTION {format.thesis.type}
+{ type duplicate$ empty$
+ 'pop$
+ { swap$ pop$
+ "t" change.case$ "type" bibinfo.check
+ }
+ if$
+}
+FUNCTION {format.tr.number}
+{ number "number" bibinfo.check
+ type duplicate$ empty$
+ { pop$ bbl.techrep }
+ 'skip$
+ if$
+ "type" bibinfo.check
+ swap$ duplicate$ empty$
+ { pop$ "t" change.case$ }
+ { tie.or.space.prefix * * }
+ if$
+}
+FUNCTION {format.article.crossref}
+{
+ word.in
+ " \cite{" * crossref * "}" *
+}
+FUNCTION {format.book.crossref}
+{ volume duplicate$ empty$
+ { "empty volume in " cite$ * "'s crossref of " * crossref * warning$
+ pop$ word.in
+ }
+ { bbl.volume
+ capitalize
+ swap$ tie.or.space.prefix "volume" bibinfo.check * * bbl.of space.word *
+ }
+ if$
+ " \cite{" * crossref * "}" *
+}
+FUNCTION {format.incoll.inproc.crossref}
+{
+ word.in
+ " \cite{" * crossref * "}" *
+}
+FUNCTION {format.org.or.pub}
+{ 't :=
+ ""
+ address empty$ t empty$ and
+ 'skip$
+ {
+ t empty$
+ { address "address" bibinfo.check *
+ }
+ { t *
+ address empty$
+ 'skip$
+ { ", " * address "address" bibinfo.check * }
+ if$
+ }
+ if$
+ }
+ if$
+}
+FUNCTION {format.publisher.address}
+{ publisher "publisher" bibinfo.warn format.org.or.pub
+}
+
+FUNCTION {format.organization.address}
+{ organization "organization" bibinfo.check format.org.or.pub
+}
+
+FUNCTION {archiveprefix.or.eprinttype} % holder for eprinttype with archiveprefix precedence
+{
+ archiveprefix empty$
+ {
+ eprinttype empty$
+ { "" } % not using 'skip$ to reduce errors like "nothing to pop from stack"
+ { eprinttype }
+ if$
+ }
+ { archiveprefix }
+ if$
+}
+
+FUNCTION {output.eprint} % this is only used with the @misc record type (common for arXiv and other preprint server bibtex records)
+{
+ eprint empty$
+ {% if eprint field is empty
+ publisher field.or.null "arXiv" = % field.or.null here helps when no publisher field in the record
+ { publisher " preprint" * } % add " preprint" to publisher with the idea that publisher is the name of the preprint server
+ { "" } % if publisher != "arXiv" then empty output
+ if$
+ emphasize % no output function after emphasize because nothing goes after this
+ }
+ {% if eprint field is not empty
+ archiveprefix.or.eprinttype empty$
+ { "" } % not using 'skip$ to reduce errors like "nothing to pop from stack"
+ {% if archiveprefix or eprinttype fields are not empty
+ journal empty$
+ { "Preprint" } % if journal field is empty: output just "Preprint" emphasized like a journal name
+ { journal } % if journal field is not empty, output it (takes precedence)
+ if$
+ emphasize output % emphasize what we formed before, setting output as a border to the subblock that follows with the comma delimiter
+ archiveprefix.or.eprinttype ":" * eprint * % subblock with eprinttype and eprint number
+ }
+ if$
+ }
+ if$
+}
+
+% urlbst...
+% Functions for making hypertext links.
+% In all cases, the stack has (link-text href-url)
+%
+% make 'null' specials
+FUNCTION {make.href.null}
+{
+ pop$
+}
+% make hypertex specials
+FUNCTION {make.href.hypertex}
+{
+ "\special {html: }" * swap$ *
+ "\special {html:}" *
+}
+% make hyperref specials
+FUNCTION {make.href.hyperref}
+{
+ "\href {" swap$ * "} {\path{" * swap$ * "}}" *
+}
+FUNCTION {make.href}
+{ hrefform #2 =
+ 'make.href.hyperref % hrefform = 2
+ { hrefform #1 =
+ 'make.href.hypertex % hrefform = 1
+ 'make.href.null % hrefform = 0 (or anything else)
+ if$
+ }
+ if$
+}
+
+% If inlinelinks is true, then format.url should be a no-op, since it's
+% (a) redundant, and (b) could end up as a link-within-a-link.
+FUNCTION {format.url}
+{ inlinelinks #1 = url empty$ or
+ { "" }
+ { hrefform #1 =
+ { % special case -- add HyperTeX specials
+ urlintro "\url{" url * "}" * url make.href.hypertex * }
+ { urlintro "\url{" * url * "}" * }
+ if$
+ }
+ if$
+}
+FUNCTION {format.eprint}
+{ eprint empty$
+ { "" }
+ { eprintprefix eprint * eprinturl eprint * make.href }
+ if$
+}
+
+FUNCTION {format.doi}
+{ doi empty$
+ { "" }
+ { doi parse.doi % leaves "https://doi.org/DOI" DOI on the stack
+ 's := 't :=
+ doiform #1 =
+ { "\doi{" s * "}" * }
+ { doiprefix s * t make.href }
+ if$
+ }
+ if$
+}
+
+FUNCTION {format.pubmed}
+{ pubmed empty$
+ { "" }
+ { pubmedprefix pubmed * pubmedurl pubmed * make.href }
+ if$
+}
+
+% Output a URL. We can't use the more normal idiom (something like
+% `format.url output'), because the `inbrackets' within
+% format.lastchecked applies to everything between calls to `output',
+% so that `format.url format.lastchecked * output' ends up with both
+% the URL and the lastchecked in brackets.
+FUNCTION {output.url}
+{ url empty$
+ 'skip$
+ { new.block
+ format.url output
+ format.lastchecked output
+ }
+ if$
+}
+
+FUNCTION {output.web.refs}
+{
+ new.block
+ inlinelinks
+ 'skip$ % links were inline -- don't repeat them
+ { % If the generated DOI will be the same as the URL,
+ % then don't print the URL (thanks to Joseph Wright
+ % for (the original version of) this code,
+ % at http://tex.stackexchange.com/questions/5660)
+ adddoi
+ doi empty$ { "X" } { doi parse.doi pop$ } if$ % DOI URL to be generated
+ url empty$ { "Y" } { url } if$ % the URL, or "Y" if empty
+ = % are the strings equal?
+ and
+ 'skip$
+ { output.url }
+ if$
+ addeprints eprint empty$ not and
+ { format.eprint output.nonnull }
+ 'skip$
+ if$
+ adddoi doi empty$ not and
+ { format.doi output.nonnull }
+ 'skip$
+ if$
+ addpubmed pubmed empty$ not and
+ { format.pubmed output.nonnull }
+ 'skip$
+ if$
+ }
+ if$
+}
+
+% Wrapper for output.bibitem.original.
+% If the URL field is not empty, set makeinlinelink to be true,
+% so that an inline link will be started at the next opportunity
+FUNCTION {output.bibitem}
+{ outside.brackets 'bracket.state :=
+ output.bibitem.original
+ inlinelinks url empty$ not doi empty$ not or pubmed empty$ not or eprint empty$ not or and
+ { #1 'makeinlinelink := }
+ { #0 'makeinlinelink := }
+ if$
+}
+
+% Wrapper for fin.entry.original
+FUNCTION {fin.entry}
+{ output.web.refs % urlbst
+ makeinlinelink % ooops, it appears we didn't have a title for inlinelink
+ { possibly.setup.inlinelink % add some artificial link text here, as a fallback
+ linktextstring output.nonnull }
+ 'skip$
+ if$
+ bracket.state close.brackets = % urlbst
+ { "]" * }
+ 'skip$
+ if$
+ fin.entry.original
+}
+
+% Webpage entry type.
+% Title and url fields required;
+% author, note, year, month, and lastchecked fields optional
+% See references
+% ISO 690-2 http://www.nlc-bnc.ca/iso/tc46sc9/standard/690-2e.htm
+% http://www.classroom.net/classroom/CitingNetResources.html
+% http://neal.ctstateu.edu/history/cite.html
+% http://www.cas.usf.edu/english/walker/mla.html
+% for citation formats for web pages.
+FUNCTION {webpage}
+{ output.bibitem
+ author empty$
+ { editor empty$
+ 'skip$ % author and editor both optional
+ { format.editors output.nonnull }
+ if$
+ }
+ { editor empty$
+ { format.authors output.nonnull }
+ { "can't use both author and editor fields in " cite$ * warning$ }
+ if$
+ }
+ if$
+ new.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$
+ format.title "title" output.check
+ inbrackets onlinestring output
+ new.block
+ year empty$
+ 'skip$
+ { format.date "year" output.check }
+ if$
+ % We don't need to output the URL details ('lastchecked' and 'url'),
+ % because fin.entry does that for us, using output.web.refs. The only
+ % reason we would want to put them here is if we were to decide that
+ % they should go in front of the rather miscellaneous information in 'note'.
+ new.block
+ note output
+ fin.entry
+}
+% ...urlbst to here
+
+
+FUNCTION {article}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ crossref missing$
+ {
+ journal
+ "journal" bibinfo.check
+ emphasize
+ "journal" output.check
+ possibly.setup.inlinelink format.vol.num.pages output% urlbst
+ }
+ { format.article.crossref output.nonnull
+ format.pages output
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {book}
+{ output.bibitem
+ author empty$
+ { format.editors "author and editor" output.check
+ editor format.key output
+ }
+ { format.authors output.nonnull
+ crossref missing$
+ { "author and editor" editor either.or.check }
+ 'skip$
+ if$
+ }
+ if$
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle "title" output.check
+ format.edition output
+ crossref missing$
+ { format.bvolume output
+ new.block
+ format.number.series output
+ new.sentence
+ format.publisher.address output
+ }
+ {
+ new.block
+ format.book.crossref output.nonnull
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {booklet}
+{ output.bibitem
+ format.authors output
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ howpublished "howpublished" bibinfo.check output
+ address "address" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {inbook}
+{ output.bibitem
+ author empty$
+ { format.editors "author and editor" output.check
+ editor format.key output
+ }
+ { format.authors output.nonnull
+ crossref missing$
+ { "author and editor" editor either.or.check }
+ 'skip$
+ if$
+ }
+ if$
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle "title" output.check
+ crossref missing$
+ {
+ format.edition output
+ format.bvolume output
+ format.chapter "chapter" output.check
+ new.block
+ format.number.series output
+ new.sentence
+ format.publisher.address output
+ }
+ {
+ format.chapter "chapter" output.check
+ new.block
+ format.book.crossref output.nonnull
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {incollection}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ crossref missing$
+ { format.in.ed.booktitle "booktitle" output.check
+ format.edition output
+ format.bvolume output
+ format.number.series output
+ format.chapter.pages output
+ new.sentence
+ format.publisher.address output
+ }
+ { format.incoll.inproc.crossref output.nonnull
+ format.chapter.pages output
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {inproceedings}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ crossref missing$
+ { format.in.booktitle "booktitle" output.check
+ format.bvolume output
+ format.number.series output
+ format.pages output
+ address "address" bibinfo.check output
+ new.sentence
+ organization "organization" bibinfo.check output
+ publisher "publisher" bibinfo.check output
+ }
+ { format.incoll.inproc.crossref output.nonnull
+ format.pages output
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {conference} { inproceedings }
+FUNCTION {manual}
+{ output.bibitem
+ format.authors output
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle "title" output.check
+ format.edition output
+ organization address new.block.checkb
+ organization "organization" bibinfo.check output
+ address "address" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {mastersthesis}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title
+ "title" output.check
+ new.block
+ bbl.mthesis format.thesis.type output.nonnull
+ school "school" bibinfo.warn output
+ address "address" bibinfo.check output
+ month "month" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {misc}
+{ output.bibitem
+ format.authors output
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title output
+ new.block
+ howpublished "howpublished" bibinfo.check output
+ new.block
+ output.eprint output
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {phdthesis}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle
+ "title" output.check
+ new.block
+ bbl.phdthesis format.thesis.type output.nonnull
+ school "school" bibinfo.warn output
+ address "address" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {presentation}
+{ output.bibitem
+ format.authors output
+ author format.key output
+ new.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title output
+ new.block
+ format.organization.address "organization and address" output.check
+ month "month" output.check
+ year "year" output.check
+ new.block
+ format.note output
+ new.sentence
+ type missing$ 'skip$
+ {"(" type capitalize * ")" * output}
+ if$
+ fin.entry
+}
+
+FUNCTION {proceedings}
+{ output.bibitem
+ format.editors output
+ editor format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle "title" output.check
+ format.bvolume output
+ format.number.series output
+ new.sentence
+ publisher empty$
+ { format.organization.address output }
+ { organization "organization" bibinfo.check output
+ new.sentence
+ format.publisher.address output
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {techreport}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title
+ "title" output.check
+ new.block
+ format.tr.number output.nonnull
+ institution "institution" bibinfo.warn output
+ address "address" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {unpublished}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ format.note "note" output.check
+ fin.entry
+}
+
+FUNCTION {default.type} { misc }
+READ
+FUNCTION {sortify}
+{ purify$
+ "l" change.case$
+}
+INTEGERS { len }
+FUNCTION {chop.word}
+{ 's :=
+ 'len :=
+ s #1 len substring$ =
+ { s len #1 + global.max$ substring$ }
+ 's
+ if$
+}
+FUNCTION {format.lab.names}
+{ 's :=
+ "" 't :=
+ s #1 "{vv~}{ll}" format.name$
+ s num.names$ duplicate$
+ #2 >
+ { pop$
+ " " * bbl.etal *
+ }
+ { #2 <
+ 'skip$
+ { s #2 "{ff }{vv }{ll}{ jj}" format.name$ "others" =
+ {
+ " " * bbl.etal *
+ }
+ { bbl.and space.word * s #2 "{vv~}{ll}" format.name$
+ * }
+ if$
+ }
+ if$
+ }
+ if$
+}
+
+FUNCTION {author.key.label}
+{ author empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { author format.lab.names }
+ if$
+}
+
+FUNCTION {author.editor.key.label}
+{ author empty$
+ { editor empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { editor format.lab.names }
+ if$
+ }
+ { author format.lab.names }
+ if$
+}
+
+FUNCTION {editor.key.label}
+{ editor empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { editor format.lab.names }
+ if$
+}
+
+FUNCTION {calc.short.authors}
+{ type$ "book" =
+ type$ "inbook" =
+ or
+ 'author.editor.key.label
+ { type$ "proceedings" =
+ 'editor.key.label
+ 'author.key.label
+ if$
+ }
+ if$
+ 'short.list :=
+}
+
+FUNCTION {calc.label}
+{ calc.short.authors
+ short.list
+ "("
+ *
+ year duplicate$ empty$
+ short.list key field.or.null = or
+ { pop$ "" }
+ 'skip$
+ if$
+ *
+ 'label :=
+}
+
+FUNCTION {sort.format.names}
+{ 's :=
+ #1 'nameptr :=
+ ""
+ s num.names$ 'numnames :=
+ numnames 'namesleft :=
+ { namesleft #0 > }
+ { s nameptr
+ "{vv{ } }{ll{ }}{ ff{ }}{ jj{ }}"
+ format.name$ 't :=
+ nameptr #1 >
+ {
+ " " *
+ namesleft #1 = t "others" = and
+ { "zzzzz" 't := }
+ 'skip$
+ if$
+ t sortify *
+ }
+ { t sortify * }
+ if$
+ nameptr #1 + 'nameptr :=
+ namesleft #1 - 'namesleft :=
+ }
+ while$
+}
+
+FUNCTION {sort.format.title}
+{ 't :=
+ "A " #2
+ "An " #3
+ "The " #4 t chop.word
+ chop.word
+ chop.word
+ sortify
+ #1 global.max$ substring$
+}
+FUNCTION {author.sort}
+{ author empty$
+ { key empty$
+ { "to sort, need author or key in " cite$ * warning$
+ ""
+ }
+ { key sortify }
+ if$
+ }
+ { author sort.format.names }
+ if$
+}
+FUNCTION {author.editor.sort}
+{ author empty$
+ { editor empty$
+ { key empty$
+ { "to sort, need author, editor, or key in " cite$ * warning$
+ ""
+ }
+ { key sortify }
+ if$
+ }
+ { editor sort.format.names }
+ if$
+ }
+ { author sort.format.names }
+ if$
+}
+FUNCTION {editor.sort}
+{ editor empty$
+ { key empty$
+ { "to sort, need editor or key in " cite$ * warning$
+ ""
+ }
+ { key sortify }
+ if$
+ }
+ { editor sort.format.names }
+ if$
+}
+FUNCTION {presort}
+{ calc.label
+ label sortify
+ " "
+ *
+ type$ "book" =
+ type$ "inbook" =
+ or
+ 'author.editor.sort
+ { type$ "proceedings" =
+ 'editor.sort
+ 'author.sort
+ if$
+ }
+ if$
+ #1 entry.max$ substring$
+ 'sort.label :=
+ sort.label
+ *
+ " "
+ *
+ title field.or.null
+ sort.format.title
+ *
+ #1 entry.max$ substring$
+ 'sort.key$ :=
+}
+
+ITERATE {presort}
+SORT
+STRINGS { last.label next.extra }
+INTEGERS { last.extra.num last.extra.num.extended last.extra.num.blank number.label }
+FUNCTION {initialize.extra.label.stuff}
+{ #0 int.to.chr$ 'last.label :=
+ "" 'next.extra :=
+ #0 'last.extra.num :=
+ "a" chr.to.int$ #1 - 'last.extra.num.blank :=
+ last.extra.num.blank 'last.extra.num.extended :=
+ #0 'number.label :=
+}
+FUNCTION {forward.pass}
+{ last.label label =
+ { last.extra.num #1 + 'last.extra.num :=
+ last.extra.num "z" chr.to.int$ >
+ { "a" chr.to.int$ 'last.extra.num :=
+ last.extra.num.extended #1 + 'last.extra.num.extended :=
+ }
+ 'skip$
+ if$
+ last.extra.num.extended last.extra.num.blank >
+ { last.extra.num.extended int.to.chr$
+ last.extra.num int.to.chr$
+ * 'extra.label := }
+ { last.extra.num int.to.chr$ 'extra.label := }
+ if$
+ }
+ { "a" chr.to.int$ 'last.extra.num :=
+ "" 'extra.label :=
+ label 'last.label :=
+ }
+ if$
+ number.label #1 + 'number.label :=
+}
+FUNCTION {reverse.pass}
+{ next.extra "b" =
+ { "a" 'extra.label := }
+ 'skip$
+ if$
+ extra.label 'next.extra :=
+ extra.label
+ duplicate$ empty$
+ 'skip$
+ { year field.or.null #-1 #1 substring$ chr.to.int$ #65 <
+ { "{\natexlab{" swap$ * "}}" * }
+ { "{(\natexlab{" swap$ * "})}" * }
+ if$ }
+ if$
+ 'extra.label :=
+ label extra.label * 'label :=
+}
+EXECUTE {initialize.extra.label.stuff}
+ITERATE {forward.pass}
+REVERSE {reverse.pass}
+FUNCTION {bib.sort.order}
+{ sort.label
+ " "
+ *
+ year field.or.null sortify
+ *
+ " "
+ *
+ title field.or.null
+ sort.format.title
+ *
+ #1 entry.max$ substring$
+ 'sort.key$ :=
+}
+ITERATE {bib.sort.order}
+SORT
+FUNCTION {begin.bib}
+{ preamble$ empty$
+ 'skip$
+ { preamble$ write$ newline$ }
+ if$
+ "\begin{thebibliography}{" number.label int.to.str$ * "}" *
+ write$ newline$
+ "\providecommand{\natexlab}[1]{#1}"
+ write$ newline$
+}
+EXECUTE {begin.bib}
+EXECUTE {init.urlbst.variables} % urlbst
+EXECUTE {init.state.consts}
+ITERATE {call.type$}
+FUNCTION {end.bib}
+{ newline$
+ "\end{thebibliography}" write$ newline$
+}
+EXECUTE {end.bib}
+%% End of customized bst file
+%%
+%% End of file `acl_natbib_basic.bst'.
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/custom.bib" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/custom.bib"
new file mode 100644
index 0000000000..21aa2fb86b
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/custom.bib"
@@ -0,0 +1,744 @@
+@article{gao2020pile,
+ title={The pile: An 800gb dataset of diverse text for language modeling},
+ author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and others},
+ journal={arXiv},
+ year={2020}
+}
+
+@article{hendrycks2020measuring,
+ title={Measuring massive multitask language understanding},
+ author={Hendrycks, Dan and Burns, Collin and Basart, Steven and Zou, Andy and Mazeika, Mantas and Song, Dawn and Steinhardt, Jacob},
+ journal={arXiv},
+ year={2020}
+}
+
+@article{hu2021lora,
+ title={Lora: Low-rank adaptation of large language models},
+ author={Hu, Edward J and Shen, Yelong and Wallis, Phillip and Allen-Zhu, Zeyuan and Li, Yuanzhi and Wang, Shean and Wang, Lu and Chen, Weizhu},
+ journal={arXiv},
+ year={2021}
+}
+
+@article{bang2023multitask,
+ title={A multitask, multilingual, multimodal evaluation of chatgpt on reasoning, hallucination, and interactivity},
+ author={Bang, Yejin and Cahyawijaya, Samuel and Lee, Nayeon and Dai, Wenliang and Su, Dan and Wilie, Bryan and Lovenia, Holy and Ji, Ziwei and Yu, Tiezheng and Chung, Willy and others},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{guo2023close,
+ title={How close is chatgpt to human experts? comparison corpus, evaluation, and detection},
+ author={Guo, Biyang and Zhang, Xin and Wang, Ziyuan and Jiang, Minqi and Nie, Jinran and Ding, Yuxuan and Yue, Jianwei and Wu, Yupeng},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{ouyang2022training,
+ title={Training language models to follow instructions with human feedback},
+ author={Ouyang, Long and Wu, Jeffrey and Jiang, Xu and Almeida, Diogo and Wainwright, Carroll and Mishkin, Pamela and Zhang, Chong and Agarwal, Sandhini and Slama, Katarina and Ray, Alex and others},
+ journal={NeurIPS},
+ volume={35},
+ pages={27730--27744},
+ year={2022}
+}
+
+
+
+@article{cao2020factual,
+ title={Factual error correction for abstractive summarization models},
+ author={Cao, Meng and Dong, Yue and Wu, Jiapeng and Cheung, Jackie Chi Kit},
+ journal={arXiv},
+ year={2020}
+}
+
+@article{raunak2021curious,
+ title={The curious case of hallucinations in neural machine translation},
+ author={Raunak, Vikas and Menezes, Arul and Junczys-Dowmunt, Marcin},
+ journal={arXiv},
+ year={2021}
+}
+
+@article{ji2023survey,
+ title={Survey of hallucination in natural language generation},
+ author={Ji, Ziwei and Lee, Nayeon and Frieske, Rita and Yu, Tiezheng and Su, Dan and Xu, Yan and Ishii, Etsuko and Bang, Ye Jin and Madotto, Andrea and Fung, Pascale},
+ journal={ACM Computing Surveys},
+ volume={55},
+ number={12},
+ pages={1--38},
+ year={2023},
+ publisher={ACM New York, NY}
+}
+
+@article{chowdhery2022palm,
+ title={Palm: Scaling language modeling with pathways},
+ author={Chowdhery, Aakanksha and Narang, Sharan and Devlin, Jacob and Bosma, Maarten and Mishra, Gaurav and Roberts, Adam and Barham, Paul and Chung, Hyung Won and Sutton, Charles and Gehrmann, Sebastian and others},
+ journal={arXiv},
+ year={2022}
+}
+
+@article{he2022rethinking,
+ title={Rethinking with retrieval: Faithful large language model inference},
+ author={He, Hangfeng and Zhang, Hongming and Roth, Dan},
+ journal={arXiv},
+ year={2022}
+}
+
+@article{shen2023chatgpt,
+ title={In chatgpt we trust? measuring and characterizing the reliability of chatgpt},
+ author={Shen, Xinyue and Chen, Zeyuan and Backes, Michael and Zhang, Yang},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{li2023chatgpt,
+ title={Are ChatGPT and GPT-4 General-Purpose Solvers for Financial Text Analytics? An Examination on Several Typical Tasks},
+ author={Li, Xianzhi and Zhu, Xiaodan and Ma, Zhiqiang and Liu, Xiaomo and Shah, Sameena},
+ journal={arXiv},
+ year={2023}
+}
+
+@inproceedings{cai2022recent,
+ title={Recent advances in retrieval-augmented text generation},
+ author={Cai, Deng and Wang, Yan and Liu, Lemao and Shi, Shuming},
+ booktitle={ACM SIGIR},
+ pages={3417--3419},
+ year={2022}
+}
+
+@article{lewis2020retrieval,
+ title={Retrieval-augmented generation for knowledge-intensive nlp tasks},
+ author={Lewis, Patrick and Perez, Ethan and Piktus, Aleksandra and Petroni, Fabio and Karpukhin, Vladimir and Goyal, Naman and K{\"u}ttler, Heinrich and Lewis, Mike and Yih, Wen-tau and Rockt{\"a}schel, Tim and others},
+ journal={NeurIPS},
+ volume={33},
+ pages={9459--9474},
+ year={2020}
+}
+
+@article{zhu2021retrieving,
+ title={Retrieving and reading: A comprehensive survey on open-domain question answering},
+ author={Zhu, Fengbin and Lei, Wenqiang and Wang, Chao and Zheng, Jianming and Poria, Soujanya and Chua, Tat-Seng},
+ journal={arXiv},
+ year={2021}
+}
+
+@article{zhang2022survey,
+ title={A survey for efficient open domain question answering},
+ author={Zhang, Qin and Chen, Shangsi and Xu, Dongkuan and Cao, Qingqing and Chen, Xiaojun and Cohn, Trevor and Fang, Meng},
+ journal={arXiv},
+ year={2022}
+}
+
+@article{FiD,
+ title={Leveraging passage retrieval with generative models for open domain question answering},
+ author={Izacard, Gautier and Grave, Edouard},
+ journal={arXiv},
+ year={2020}
+}
+
+@inproceedings{FiE,
+ title={FiE: Building a Global Probability Space by Leveraging Early Fusion in Encoder for Open-Domain Question Answering},
+ author={Kedia, Akhil and Zaidi, Mohd Abbas and Lee, Haejun},
+ booktitle={EMNLP},
+ pages={4246--4260},
+ year={2022}
+}
+
+@article{R2D2,
+ title={R2-D2: A modular baseline for open-domain question answering},
+ author={Fajcik, Martin and Docekal, Martin and Ondrej, Karel and Smrz, Pavel},
+ journal={arXiv},
+ year={2021}
+}
+
+@inproceedings{REALM,
+ title={Retrieval augmented language model pre-training},
+ author={Guu, Kelvin and Lee, Kenton and Tung, Zora and Pasupat, Panupong and Chang, Mingwei},
+ booktitle={ICML},
+ pages={3929--3938},
+ year={2020},
+ organization={PMLR}
+}
+
+@article{DPR,
+ title={Dense passage retrieval for open-domain question answering},
+ author={Karpukhin, Vladimir and O{\u{g}}uz, Barlas and Min, Sewon and Lewis, Patrick and Wu, Ledell and Edunov, Sergey and Chen, Danqi and Yih, Wen-tau},
+ journal={arXiv},
+ year={2020}
+}
+
+@article{REPLUG,
+ title={Replug: Retrieval-augmented black-box language models},
+ author={Shi, Weijia and Min, Sewon and Yasunaga, Michihiro and Seo, Minjoon and James, Rich and Lewis, Mike and Zettlemoyer, Luke and Yih, Wen-tau},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{atlas,
+ title={Few-shot learning with retrieval augmented language models},
+ author={Izacard, Gautier and Lewis, Patrick and Lomeli, Maria and Hosseini, Lucas and Petroni, Fabio and Schick, Timo and Dwivedi-Yu, Jane and Joulin, Armand and Riedel, Sebastian and Grave, Edouard},
+ journal={arXiv},
+ year={2022}
+}
+
+@article{RALM,
+ title={In-context retrieval-augmented language models},
+ author={Ram, Ori and Levine, Yoav and Dalmedigos, Itay and Muhlgay, Dor and Shashua, Amnon and Leyton-Brown, Kevin and Shoham, Yoav},
+ journal={arXiv},
+ year={2023}
+}
+
+@inproceedings{RETRO,
+ title={Improving language models by retrieving from trillions of tokens},
+ author={Borgeaud, Sebastian and Mensch, Arthur and Hoffmann, Jordan and Cai, Trevor and Rutherford, Eliza and Millican, Katie and Van Den Driessche, George Bm and Lespiau, Jean-Baptiste and Damoc, Bogdan and Clark, Aidan and others},
+ booktitle={ICML},
+ pages={2206--2240},
+ year={2022},
+ organization={PMLR}
+}
+
+@article{luo2023empirical,
+ title={An Empirical Study of Catastrophic Forgetting in Large Language Models During Continual Fine-tuning},
+ author={Luo, Yun and Yang, Zhen and Meng, Fandong and Li, Yafu and Zhou, Jie and Zhang, Yue},
+ journal={arXiv},
+ pages={arXiv--2308},
+ year={2023}
+}
+
+@article{cai2018skeleton,
+ title={Skeleton-to-response: Dialogue generation guided by retrieval memory},
+ author={Cai, Deng and Wang, Yan and Bi, Victoria and Tu, Zhaopeng and Liu, Xiaojiang and Lam, Wai and Shi, Shuming},
+ journal={arXiv},
+ year={2018}
+}
+
+@article{maynez2020faithfulness,
+ title={On faithfulness and factuality in abstractive summarization},
+ author={Maynez, Joshua and Narayan, Shashi and Bohnet, Bernd and McDonald, Ryan},
+ journal={arXiv},
+ year={2020}
+}
+
+@article{llama2,
+ title={Llama 2: Open foundation and fine-tuned chat models},
+ author={Touvron, Hugo and Martin, Louis and Stone, Kevin and Albert, Peter and Almahairi, Amjad and Babaei, Yasmine and Bashlykov, Nikolay and Batra, Soumya and Bhargava, Prajjwal and Bhosale, Shruti and others},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{bloom,
+ title={Bloom: A 176b-parameter open-access multilingual language model},
+ author={Scao, Teven Le and Fan, Angela and Akiki, Christopher and Pavlick, Ellie and Ili{\'c}, Suzana and Hesslow, Daniel and Castagn{\'e}, Roman and Luccioni, Alexandra Sasha and Yvon, Fran{\c{c}}ois and Gall{\'e}, Matthias and others},
+ journal={arXiv},
+ year={2022}
+}
+
+@article{baichuan,
+ title={Baichuan 2: Open large-scale language models},
+ author={Yang, Aiyuan and Xiao, Bin and Wang, Bingning and Zhang, Borong and Yin, Chao and Lv, Chenxu and Pan, Da and Wang, Dian and Yan, Dong and Yang, Fan and others},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{contriever,
+ title={Unsupervised dense information retrieval with contrastive learning},
+ author={Izacard, Gautier and Caron, Mathilde and Hosseini, Lucas and Riedel, Sebastian and Bojanowski, Piotr and Joulin, Armand and Grave, Edouard},
+ journal={arXiv},
+ year={2021}
+}
+
+@inproceedings{arivazhagan2023hybrid,
+ title={Hybrid Hierarchical Retrieval for Open-Domain Question Answering},
+ author={Arivazhagan, Manoj Ghuhan and Liu, Lan and Qi, Peng and Chen, Xinchi and Wang, William Yang and Huang, Zhiheng},
+ booktitle={Findings of ACL},
+ pages={10680--10689},
+ year={2023}
+}
+
+@article{bm25,
+ title={The probabilistic relevance framework: BM25 and beyond},
+ author={Robertson, Stephen and Zaragoza, Hugo and others},
+ journal={Foundations and Trends{\textregistered} in Information Retrieval},
+ volume={3},
+ number={4},
+ pages={333--389},
+ year={2009},
+ publisher={Now Publishers, Inc.}
+}
+
+@article{rome,
+ title={Locating and Editing Factual Associations in {GPT}},
+ author={Kevin Meng and David Bau and Alex Andonian and Yonatan Belinkov},
+ journal={NeurIPS},
+ volume={35},
+ year={2022}
+}
+
+@article{ORQA,
+ title={Latent retrieval for weakly supervised open domain question answering},
+ author={Lee, Kenton and Chang, Ming-Wei and Toutanova, Kristina},
+ journal={arXiv},
+ year={2019}
+}
+
+@article{NQ,
+ title={Natural questions: a benchmark for question answering research},
+ author={Kwiatkowski, Tom and Palomaki, Jennimaria and Redfield, Olivia and Collins, Michael and Parikh, Ankur and Alberti, Chris and Epstein, Danielle and Polosukhin, Illia and Devlin, Jacob and Lee, Kenton and others},
+ journal={TACL},
+ volume={7},
+ pages={453--466},
+ year={2019},
+}
+
+@article{TQA,
+ title={Triviaqa: A large scale distantly supervised challenge dataset for reading comprehension},
+ author={Joshi, Mandar and Choi, Eunsol and Weld, Daniel S and Zettlemoyer, Luke},
+ journal={arXiv},
+ year={2017}
+}
+
+@article{PALM,
+ title={Palm: Scaling language modeling with pathways},
+ author={Chowdhery, Aakanksha and Narang, Sharan and Devlin, Jacob and Bosma, Maarten and Mishra, Gaurav and Roberts, Adam and Barham, Paul and Chung, Hyung Won and Sutton, Charles and Gehrmann, Sebastian and others},
+ journal={arXiv},
+ year={2022}
+}
+
+@article{Codex,
+ title={Evaluating large language models trained on code},
+ author={Chen, Mark and Tworek, Jerry and Jun, Heewoo and Yuan, Qiming and Pinto, Henrique Ponde de Oliveira and Kaplan, Jared and Edwards, Harri and Burda, Yuri and Joseph, Nicholas and Brockman, Greg and others},
+ journal={arXiv},
+ year={2021}
+}
+
+@article{chinchilla,
+ title={Training compute-optimal large language models},
+ author={Hoffmann, Jordan and Borgeaud, Sebastian and Mensch, Arthur and Buchatskaya, Elena and Cai, Trevor and Rutherford, Eliza and Casas, Diego de Las and Hendricks, Lisa Anne and Welbl, Johannes and Clark, Aidan and others},
+ journal={arXiv},
+ year={2022}
+}
+
+@article{squad,
+ title={Squad: 100,000+ questions for machine comprehension of text},
+ author={Rajpurkar, Pranav and Zhang, Jian and Lopyrev, Konstantin and Liang, Percy},
+ journal={arXiv},
+ year={2016}
+}
+
+@book{es,
+ title={Elasticsearch: the definitive guide: a distributed real-time search and analytics engine},
+ author={Gormley, Clinton and Tong, Zachary},
+ year={2015},
+ publisher={" O'Reilly Media, Inc."}
+}
+
+@article{FAISS,
+ title={Billion-scale similarity search with gpus},
+ author={Johnson, Jeff and Douze, Matthijs and J{\'e}gou, Herv{\'e}},
+ journal={IEEE Transactions on Big Data},
+ volume={7},
+ number={3},
+ pages={535--547},
+ year={2019},
+ publisher={IEEE}
+}
+
+@article{UnitedQA,
+ title={UnitedQA: A hybrid approach for open domain question answering},
+ author={Cheng, Hao and Shen, Yelong and Liu, Xiaodong and He, Pengcheng and Chen, Weizhu and Gao, Jianfeng},
+ journal={arXiv},
+ year={2021}
+}
+
+@article{dhr,
+ title={Dense hierarchical retrieval for open-domain question answering},
+ author={Liu, Ye and Hashimoto, Kazuma and Zhou, Yingbo and Yavuz, Semih and Xiong, Caiming and Yu, Philip S},
+ journal={arXiv},
+ year={2021}
+}
+
+@article{bert,
+ title={Bert: Pre-training of deep bidirectional transformers for language understanding},
+ author={Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina},
+ journal={arXiv},
+ year={2018}
+}
+
+@article{roberta,
+ title={Roberta: A robustly optimized bert pretraining approach},
+ author={Liu, Yinhan and Ott, Myle and Goyal, Naman and Du, Jingfei and Joshi, Mandar and Chen, Danqi and Levy, Omer and Lewis, Mike and Zettlemoyer, Luke and Stoyanov, Veselin},
+ journal={arXiv},
+ year={2019}
+}
+
+@article{albert,
+ title={Albert: A lite bert for self-supervised learning of language representations},
+ author={Lan, Zhenzhong and Chen, Mingda and Goodman, Sebastian and Gimpel, Kevin and Sharma, Piyush and Soricut, Radu},
+ journal={arXiv},
+ year={2019}
+}
+
+@article{kobayashi2000information,
+ title={Information retrieval on the web},
+ author={Kobayashi, Mei and Takeda, Koichi},
+ journal={ACM Computing Surveys (CSUR)},
+ volume={32},
+ number={2},
+ pages={144--173},
+ year={2000},
+ publisher={ACM New York, NY, USA}
+}
+
+@article{baradaran2022survey,
+ title={A survey on machine reading comprehension systems},
+ author={Baradaran, Razieh and Ghiasi, Razieh and Amirkhani, Hossein},
+ journal={Natural Language Engineering},
+ volume={28},
+ number={6},
+ pages={683--732},
+ year={2022},
+ publisher={Cambridge University Press}
+}
+
+@article{kg-fid,
+ title={Kg-fid: Infusing knowledge graph in fusion-in-decoder for open-domain question answering},
+ author={Yu, Donghan and Zhu, Chenguang and Fang, Yuwei and Yu, Wenhao and Wang, Shuohang and Xu, Yichong and Ren, Xiang and Yang, Yiming and Zeng, Michael},
+ journal={arXiv},
+ year={2021}
+}
+
+@article{zhao2023survey,
+ title={A survey of large language models},
+ author={Zhao, Wayne Xin and Zhou, Kun and Li, Junyi and Tang, Tianyi and Wang, Xiaolei and Hou, Yupeng and Min, Yingqian and Zhang, Beichen and Zhang, Junjie and Dong, Zican and others},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{gpt2,
+ title={Language models are unsupervised multitask learners},
+ author={Radford, Alec and Wu, Jeffrey and Child, Rewon and Luan, David and Amodei, Dario and Sutskever, Ilya and others},
+ journal={OpenAI blog},
+ volume={1},
+ number={8},
+ pages={9},
+ year={2019}
+}
+
+@article{gpt1,
+ title={Improving language understanding by generative pre-training},
+ author={Radford, Alec and Narasimhan, Karthik and Salimans, Tim and Sutskever, Ilya and others},
+ year={2018},
+ publisher={OpenAI}
+}
+
+@article{gpt3,
+ title={Language models are few-shot learners},
+ author={Brown, Tom and Mann, Benjamin and Ryder, Nick and Subbiah, Melanie and Kaplan, Jared D and Dhariwal, Prafulla and Neelakantan, Arvind and Shyam, Pranav and Sastry, Girish and Askell, Amanda and others},
+ journal={NeurIPS},
+ volume={33},
+ pages={1877--1901},
+ year={2020}
+}
+
+@article{t5,
+ title={Exploring the limits of transfer learning with a unified text-to-text transformer},
+ author={Raffel, Colin and Shazeer, Noam and Roberts, Adam and Lee, Katherine and Narang, Sharan and Matena, Michael and Zhou, Yanqi and Li, Wei and Liu, Peter J},
+ journal={The Journal of Machine Learning Research},
+ volume={21},
+ number={1},
+ pages={5485--5551},
+ year={2020},
+ publisher={JMLR.org}
+}
+
+@article{MTEB,
+ doi = {10.48550/ARXIV.2210.07316},
+ url = {https://arxiv.org/abs/2210.07316},
+ author = {Muennighoff, Niklas and Tazi, Nouamane and Magne, Lo{\"\i}c and Reimers, Nils},
+ title = {MTEB: Massive Text Embedding Benchmark},
+ publisher = {arXiv},
+ journal={arXiv},
+ year = {2022}
+}
+
+@misc{SFR-embedding-2,
+ title = {SFR-Embedding-2: Advanced Text Embedding with Multi-stage Training},
+ author = {Rui Meng and Ye Liu and Shafiq Rayhan Joty and Caiming Xiong and Yingbo Zhou and Semih Yavuz},
+ year = {2024},
+ url = {https://huggingface.co/Salesforce/SFR-Embedding-2_R}
+}
+
+@article{liu2024lost,
+ title={Lost in the middle: How language models use long contexts},
+ author={Liu, Nelson F and Lin, Kevin and Hewitt, John and Paranjape, Ashwin and Bevilacqua, Michele and Petroni, Fabio and Liang, Percy},
+ journal={TACL},
+ volume={12},
+ pages={157--173},
+ year={2024},
+ publisher={MIT Press}
+}
+
+@inproceedings{shi2023large,
+ title={Large language models can be easily distracted by irrelevant context},
+ author={Shi, Freda and Chen, Xinyun and Misra, Kanishka and Scales, Nathan and Dohan, David and Chi, Ed H and Sch{\"a}rli, Nathanael and Zhou, Denny},
+ booktitle={ICML},
+ pages={31210--31227},
+ year={2023},
+ organization={PMLR}
+}
+
+@inproceedings{yoran2024making,
+ title={Making Retrieval-Augmented Language Models Robust to Irrelevant Context},
+ author={Yoran, Ori and Wolfson, Tomer and Ram, Ori and Berant, Jonathan},
+ booktitle={ICLR},
+ year={2024},
+}
+
+@article{wu2024instructing,
+ title={Instructing large language models to identify and ignore irrelevant conditions},
+ author={Wu, Zhenyu and Shen, Chao and Jiang, Meng},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{liu2024ra,
+ title={RA-ISF: Learning to Answer and Understand from Retrieval Augmentation via Iterative Self-Feedback},
+ author={Liu, Yanming and Peng, Xinyue and Zhang, Xuhong and Liu, Weihao and Yin, Jianwei and Cao, Jiannan and Du, Tianyu},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{chuang2023dola,
+ title={Dola: Decoding by contrasting layers improves factuality in large language models},
+ author={Chuang, Yung-Sung and Xie, Yujia and Luo, Hongyin and Kim, Yoon and Glass, James and He, Pengcheng},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{survey1,
+ title={Retrieval-augmented generation for ai-generated content: A survey},
+ author={Zhao, Penghao and Zhang, Hailin and Yu, Qinhan and Wang, Zhengren and Geng, Yunteng and Fu, Fangcheng and Yang, Ling and Zhang, Wentao and Cui, Bin},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{survey2,
+ title={Retrieval-augmented generation for large language models: A survey},
+ author={Gao, Yunfan and Xiong, Yun and Gao, Xinyu and Jia, Kangxiang and Pan, Jinliu and Bi, Yuxi and Dai, Yi and Sun, Jiawei and Wang, Haofen},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{survey3,
+ title={Retrieval-Augmented Generation for Natural Language Processing: A Survey},
+ author={Wu, Shangyu and Xiong, Ying and Cui, Yufei and Wu, Haolun and Chen, Can and Yuan, Ye and Huang, Lianming and Liu, Xue and Kuo, Tei-Wei and Guan, Nan and others},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{liu2023recap,
+ title={Recap: Retrieval-enhanced context-aware prefix encoder for personalized dialogue response generation},
+ author={Liu, Shuai and Cho, Hyundong J and Freedman, Marjorie and Ma, Xuezhe and May, Jonathan},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{khandelwal2019generalization,
+ title={Generalization through memorization: Nearest neighbor language models},
+ author={Khandelwal, Urvashi and Levy, Omer and Jurafsky, Dan and Zettlemoyer, Luke and Lewis, Mike},
+ journal={arXiv},
+ year={2019}
+}
+
+@article{huang2023k,
+ title={$k$NN-Adapter: Efficient Domain Adaptation for Black-Box Language Models},
+ author={Huang, Yangsibo and Liu, Daogao and Zhong, Zexuan and Shi, Weijia and Lee, Yin Tat},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{wang2023shall,
+ title={Shall we pretrain autoregressive language models with retrieval? a comprehensive study},
+ author={Wang, Boxin and Ping, Wei and Xu, Peng and McAfee, Lawrence and Liu, Zihan and Shoeybi, Mohammad and Dong, Yi and Kuchaiev, Oleksii and Li, Bo and Xiao, Chaowei and others},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{wu2024improving,
+ title={Improving natural language understanding with computation-efficient retrieval representation fusion},
+ author={Wu, Shangyu and Xiong, Ying and Cui, Yufei and Liu, Xue and Tang, Buzhou and Kuo, Tei-Wei and Xue, Chun Jason},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{bge,
+ title={Bge m3-embedding: Multi-lingual, multi-functionality, multi-granularity text embeddings through self-knowledge distillation},
+ author={Chen, Jianlv and Xiao, Shitao and Zhang, Peitian and Luo, Kun and Lian, Defu and Liu, Zheng},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{wang2023improving,
+ title={Improving text embeddings with large language models},
+ author={Wang, Liang and Yang, Nan and Huang, Xiaolong and Yang, Linjun and Majumder, Rangan and Wei, Furu},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{behnamghader2024llm2vec,
+ title={Llm2vec: Large language models are secretly powerful text encoders},
+ author={BehnamGhader, Parishad and Adlakha, Vaibhav and Mosbach, Marius and Bahdanau, Dzmitry and Chapados, Nicolas and Reddy, Siva},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{springer2024repetition,
+ title={Repetition improves language model embeddings},
+ author={Springer, Jacob Mitchell and Kotha, Suhas and Fried, Daniel and Neubig, Graham and Raghunathan, Aditi},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{zhang2024comprehensive,
+ title={A comprehensive study of knowledge editing for large language models},
+ author={Zhang, Ningyu and Yao, Yunzhi and Tian, Bozhong and Wang, Peng and Deng, Shumin and Wang, Mengru and Xi, Zekun and Mao, Shengyu and Zhang, Jintian and Ni, Yuansheng and others},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{nawrot2024dynamic,
+ title={Dynamic memory compression: Retrofitting llms for accelerated inference},
+ author={Nawrot, Piotr and {\L}a{\'n}cucki, Adrian and Chochowski, Marcin and Tarjan, David and Ponti, Edoardo M},
+ journal={arXiv},
+ year={2024}
+}
+
+@inproceedings{zhai2023stabilizing,
+ title={Stabilizing transformer training by preventing attention entropy collapse},
+ author={Zhai, Shuangfei and Likhomanenko, Tatiana and Littwin, Etai and Busbridge, Dan and Ramapuram, Jason and Zhang, Yizhe and Gu, Jiatao and Susskind, Joshua M},
+ booktitle={ICML},
+ pages={40770--40803},
+ year={2023},
+ organization={PMLR}
+}
+
+@article{fu2024attentionpattern,
+ title = "How Do Language Models put Attention Weights over Long Context?",
+ author = "Fu, Yao",
+ journal = "Yao Fu’s Notion",
+ year = "2024",
+ month = "Mar",
+ url = "https://yaofu.notion.site/How-Do-Language-Models-put-Attention-Weights-over-Long-Context-10250219d5ce42e8b465087c383a034e?pvs=4"
+}
+
+@inproceedings{zhang2021drop,
+ title={Drop redundant, shrink irrelevant: Selective knowledge injection for language pretraining.},
+ author={Zhang, Ningyu and Deng, Shumin and Cheng, Xu and Chen, Xi and Zhang, Yichi and Zhang, Wei and Chen, Huajun and Center, Hangzhou Innovation},
+ booktitle={IJCAI},
+ pages={4007--4014},
+ year={2021}
+}
+
+@article{xiao2023efficient,
+ title={Efficient streaming language models with attention sinks},
+ author={Xiao, Guangxuan and Tian, Yuandong and Chen, Beidi and Han, Song and Lewis, Mike},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{strategyQA,
+ title={Did aristotle use a laptop? a question answering benchmark with implicit reasoning strategies},
+ author={Geva, Mor and Khashabi, Daniel and Segal, Elad and Khot, Tushar and Roth, Dan and Berant, Jonathan},
+ journal={Transactions of the Association for Computational Linguistics},
+ volume={9},
+ pages={346--361},
+ year={2021},
+ publisher={MIT Press}
+}
+
+@article{HotpotQA,
+ title={HotpotQA: A dataset for diverse, explainable multi-hop question answering},
+ author={Yang, Zhilin and Qi, Peng and Zhang, Saizheng and Bengio, Yoshua and Cohen, William W and Salakhutdinov, Ruslan and Manning, Christopher D},
+ journal={arXiv},
+ year={2018}
+}
+
+@article{PopQA,
+ title={When not to trust language models: Investigating effectiveness of parametric and non-parametric memories},
+ author={Mallen, Alex and Asai, Akari and Zhong, Victor and Das, Rajarshi and Khashabi, Daniel and Hajishirzi, Hannaneh},
+ journal={arXiv},
+ year={2022}
+}
+
+@article{ho2020constructing,
+ title={Constructing a multi-hop QA dataset for comprehensive evaluation of reasoning steps},
+ author={Ho, Xanh and Nguyen, Anh-Khoa Duong and Sugawara, Saku and Aizawa, Akiko},
+ journal={arXiv},
+ year={2020}
+}
+
+@article{dubey2024llama,
+ title={The Llama 3 Herd of Models},
+ author={Dubey, Abhimanyu and Jauhri, Abhinav and Pandey, Abhinav and Kadian, Abhishek and Al-Dahle, Ahmad and Letman, Aiesha and Mathur, Akhil and Schelten, Alan and Yang, Amy and Fan, Angela and others},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{yang2024qwen2,
+ title={Qwen2 technical report},
+ author={Yang, An and Yang, Baosong and Hui, Binyuan and Zheng, Bo and Yu, Bowen and Zhou, Chang and Li, Chengpeng and Li, Chengyuan and Liu, Dayiheng and Huang, Fei and others},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{asai2023self,
+ title={Self-rag: Learning to retrieve, generate, and critique through self-reflection},
+ author={Asai, Akari and Wu, Zeqiu and Wang, Yizhong and Sil, Avirup and Hajishirzi, Hannaneh},
+ journal={arXiv preprint arXiv:2310.11511},
+ year={2023}
+}
+
+@article{yu2024rankrag,
+ title={RankRAG: Unifying Context Ranking with Retrieval-Augmented Generation in LLMs},
+ author={Yu, Yue and Ping, Wei and Liu, Zihan and Wang, Boxin and You, Jiaxuan and Zhang, Chao and Shoeybi, Mohammad and Catanzaro, Bryan},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{liu2024chatqa,
+ title={Chatqa: Building gpt-4 level conversational qa models},
+ author={Liu, Zihan and Ping, Wei and Roy, Rajarshi and Xu, Peng and Shoeybi, Mohammad and Catanzaro, Bryan},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{ROPE,
+ title={Roformer: Enhanced transformer with rotary position embedding},
+ author={Su, Jianlin and Ahmed, Murtadha and Lu, Yu and Pan, Shengfeng and Bo, Wen and Liu, Yunfeng},
+ journal={Neurocomputing},
+ volume={568},
+ pages={127063},
+ year={2024},
+ publisher={Elsevier}
+}
+
+@article{gpt4,
+ title={Gpt-4 technical report},
+ author={Achiam, Josh and Adler, Steven and Agarwal, Sandhini and Ahmad, Lama and Akkaya, Ilge and Aleman, Florencia Leoni and Almeida, Diogo and Altenschmidt, Janko and Altman, Sam and Anadkat, Shyamal and others},
+ journal={arXiv},
+ year={2023}
+}
+
+@inproceedings{hyeon2023scratching,
+ title={Scratching visual transformer's back with uniform attention},
+ author={Hyeon-Woo, Nam and Yu-Ji, Kim and Heo, Byeongho and Han, Dongyoon and Oh, Seong Joon and Oh, Tae-Hyun},
+ booktitle={ICCV},
+ pages={5807--5818},
+ year={2023}
+}
+
+@article{mmlu,
+ title={Measuring massive multitask language understanding},
+ author={Hendrycks, Dan and Burns, Collin and Basart, Steven and Zou, Andy and Mazeika, Mantas and Song, Dawn and Steinhardt, Jacob},
+ journal={arXiv},
+ year={2020}
+}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/algorithm1.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/algorithm1.tex"
new file mode 100644
index 0000000000..20d2103bfa
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/algorithm1.tex"
@@ -0,0 +1,17 @@
+\begin{algorithm}[t]
+\caption{Hierarchical Retrieval with LLM}
+\begin{algorithmic}[1]
+\REQUIRE Question $q$, Corpus $\mathcal{C}$
+\STATE Generate three rewritten questions $q_{rewrite}$ using the LLM
+\STATE Concatenate $q$ and $q_{rewrite}$ to form the expanded query
+\STATE Retrieve top-$k/2$ documents using the sparse retriever
+\STATE Retrieve top-$k/2$ documents using the dense retriever
+\STATE Extract core entities from question $q$ using the LLM
+\STATE Filter the retrieved documents, keeping those that contain the core entities
+\STATE Further filter the documents with the LLM based on title and question
+\STATE Retrieve top-$n/2$ passages from the filtered documents using the sparse retriever
+\STATE Retrieve top-$n/2$ passages from the filtered documents using the dense retriever
+\ENSURE Retrieved passages $\mathcal{P}$
+\end{algorithmic}
+\label{alg1}
+\end{algorithm}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/architecture.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/architecture.tex"
new file mode 100644
index 0000000000..0fb8742522
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/architecture.tex"
@@ -0,0 +1,10 @@
+\begin{figure*}[t]
+\centering
+\begin{minipage}[t]{1\linewidth}
+\centering
+\includegraphics[width=1.0\textwidth]{latex/figure/src/mixture-architecture.pdf}
+\end{minipage}
+\centering
+\caption{Overview of the proposed architecture.}
+\label{fig:architecture}
+\end{figure*}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-attention distribution.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-attention distribution.tex"
new file mode 100644
index 0000000000..598b82129b
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-attention distribution.tex"
@@ -0,0 +1,33 @@
+\begin{figure*}[t]
+\centering
+\subfigure[Edge-focused Attention]{
+ \begin{minipage}[t]{0.22\linewidth}
+ \centering
+ \includegraphics[width=0.95\linewidth]{figure/src/edge_focused_attention.pdf}
+ \end{minipage}
+ \label{fig:edge-focused}
+}
+\subfigure[Uniform Attention]{
+ \begin{minipage}[t]{0.22\linewidth}
+ \centering
+ \includegraphics[width=0.95\linewidth]{figure/src/uniform_attention.pdf}
+ \end{minipage}
+ \label{fig:uniform}
+}
+\subfigure[Middle-focused Attention]{
+ \begin{minipage}[t]{0.22\linewidth}
+ \centering
+ \includegraphics[width=0.95\linewidth]{figure/src/middle_focused_attention.pdf}
+ \end{minipage}
+ \label{fig:middle-focused}
+}
+\subfigure[Attention Distribution]{
+ \begin{minipage}[t]{0.22\linewidth}
+ \centering
+ \includegraphics[width=0.9\linewidth]{figure/src/attention_distribution_pie.pdf}
+ \end{minipage}
+ \label{fig:distribution}
+}
+\caption{Attention patterns of RALM.}
+\label{fig:attention pattern}
+\end{figure*}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-attention evidence.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-attention evidence.tex"
new file mode 100644
index 0000000000..604ebb6747
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-attention evidence.tex"
@@ -0,0 +1,106 @@
+\begin{figure}[t]
+\centering
+\subfigure[Relation between attention pattern and RALM accuracy]{
+ \begin{minipage}[t]{0.95\linewidth}
+ \centering
+ \includegraphics[width=\linewidth]{figure/src/ralm_accuracy_plot.pdf}
+ \end{minipage}
+ \label{fig:ralm-accuracy}
+}
+\subfigure[
+Impact of replacing edge-focused and uniform attention with middle-focused attention on relevant passages]{
+\begin{minipage}[t]{0.45\linewidth}
+\centering
+\begin{tikzpicture}[scale=0.46]
+\begin{axis}[
+xlabel=Number of Replaced Attentions,
+ylabel=EM Scores,
+% y label style={at={(-0.00,0.5)}},
+ymajorgrids=true,
+font=\Large,
+legend style={at={(0.5,1.05)},anchor=south,legend columns=2},
+% xmin=1, xmax=10,
+]
+\addplot+[color={rgb,255:red,255;green,110;blue,110},line width=1.5pt,mark=diamond*,mark size=2pt] coordinates {
+(1,42.0)
+(2,41.5)
+(3,40.8)
+(4,39.9)
+(5,38.7)
+(6,37.2)
+(7,35.5)
+(8,33.8)
+(9,32.1)
+(10,30.4)
+};
+\addlegendentry{Edge-focused}
+\addplot+[color={rgb,255:red,100;green,150;blue,220},line width=1.5pt,mark=square*,mark size=1.5pt] coordinates {
+(1,42.0)
+(2,41.8)
+(3,41.5)
+(4,41.0)
+(5,40.3)
+(6,39.5)
+(7,38.5)
+(8,37.5)
+(9,36.5)
+(10,35.5)
+};
+\addlegendentry{Uniform}
+\end{axis}
+\end{tikzpicture}
+\label{fig:attention replacement}
+\end{minipage}
+}
+\hfill
+\subfigure[
+Impact of manipulating middle-focused attention patterns on relevant passages
+]{
+\begin{minipage}[t]{0.45\linewidth}
+\centering
+\begin{tikzpicture}[scale=0.46]
+\begin{axis}[
+xlabel=Number of Manipulated Attentions,
+ylabel=EM Scores,
+% y label style={at={(-0.00,0.5)}},
+ymajorgrids=true,
+font=\Large,
+legend style={at={(0.5,1.05)},anchor=south,legend columns=2},
+legend cell align={left},
+% xmin=1, xmax=10,
+]
+\addplot+[color={rgb,255:red,255;green,110;blue,110},line width=1.5pt,mark=diamond*,mark size=2pt] coordinates {
+(1,42.0)
+(2,39.5)
+(3,36.0)
+(4,31.5)
+(5,26.0)
+(6,19.5)
+(7,12.0)
+(8,5.5)
+(9,2.0)
+(10,0.5)
+};
+\addlegendentry{Disrupting}
+\addplot+ [color={rgb,255:red,160;green,120;blue,190},line width=1.5pt,mark=triangle*,mark size=2pt] coordinates {
+(1,42.0)
+(2,45.5)
+(3,49.0)
+(4,53.5)
+(5,59.0)
+(6,65.5)
+(7,73.0)
+(8,81.5)
+(9,91.0)
+(10,98.0)
+};
+\addlegendentry{Concentrating}
+\end{axis}
+\end{tikzpicture}
+\label{fig:attention manipulation}
+\end{minipage}
+}
+\caption{Impact of manipulating attention patterns on RALM performance.}
+\label{fig:attention evidence}
+\vspace{-0.5cm}
+\end{figure}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-llm framework.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-llm framework.tex"
new file mode 100644
index 0000000000..b3df4199ef
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-llm framework.tex"
@@ -0,0 +1,10 @@
+\begin{figure*}[t]
+\centering
+\begin{minipage}[t]{1\linewidth}
+\centering
+\includegraphics[width=1.03\textwidth]{figure/src/llm framework.pdf}
+\end{minipage}
+\centering
+\caption{The architecture of LKG-RALM.}
+\label{fig:llm framework}
+\end{figure*}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-llm location.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-llm location.tex"
new file mode 100644
index 0000000000..5e75f5d294
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-llm location.tex"
@@ -0,0 +1,10 @@
+\begin{figure}[t]
+\centering
+\begin{minipage}[t]{1\linewidth}
+\centering
+\includegraphics[width=1.0\textwidth]{figure/src/llm location.pdf}
+\end{minipage}
+\centering
+\caption{Comparison of ODQA architectures.}
+\label{fig:llm location}
+\end{figure}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-llm pipeline.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-llm pipeline.tex"
new file mode 100644
index 0000000000..bc9a6e1f5f
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-llm pipeline.tex"
@@ -0,0 +1,10 @@
+\begin{figure}[t]
+\centering
+\begin{minipage}[t]{1\linewidth}
+\centering
+\includegraphics[width=1.0\textwidth]{figure/src/OQDApipeline.pdf}
+\end{minipage}
+\centering
+\caption{Comparison between naive passage concatenation (right) and relevance-aware passage integration (left) in RALM. Due to low retrieval accuracy, the LLM is distracted by a vast number of \textcolor{red!80!black}{noisy tokens}. In contrast, explicitly guiding the LLM's attention to more \textcolor{green!60!black}{relevant passages} can effectively enhance its comprehension.}
+\label{fig:llm pipeline}
+\end{figure}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-motivation.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-motivation.tex"
new file mode 100644
index 0000000000..3693f1e1a6
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-motivation.tex"
@@ -0,0 +1,110 @@
+\begin{figure}[t]
+\centering
+\ref{combined_legend}
+\subfigure[
+Accuracy plateaus due to \textit{attention's distractibility}.
+]{
+\begin{minipage}[t]{0.45\linewidth}
+\centering
+\begin{tikzpicture}[scale=0.44]
+ \begin{axis}[
+ xlabel=Number of Retrieved Passages,
+ ylabel=Recall \& EM Scores,
+ % y label style={at={(-0.00,0.5)}},
+ ymajorgrids=true,
+ font=\Large,
+ legend style={at={(0.8,1.05)},anchor=south,legend columns=3,font={\fontsize{5pt}{5pt}\selectfont},/tikz/every even column/.append style={column sep=0.2cm}},
+ legend to name=combined_legend
+ ]
+ \addplot+[color={rgb,255:red,100;green,150;blue,220},line width=1.5pt,mark=square*,mark size=2pt] coordinates {
+ (0,0)
+ (1,19.8)
+ (3,35.1)
+ (5,42.3)
+ (10,52.1)
+ (20,61.6)
+ (30,65.8)
+ (40,68.2)
+ (50,71.5)
+ };
+ \addlegendentry{Retrieval Recall}
+ \addplot+[color={rgb,255:red,255;green,110;blue,110},line width=1.5pt,mark=diamond*,mark size=2pt] coordinates {
+ (0,21.7)
+ (1,33.6)
+ (3,39.3)
+ (5,42.3)
+ (10,42.9)
+ (20,43.5)
+ (30,42.2)
+ (40,41.7)
+ (50,40.8)
+ };
+ \addlegendentry{LLAMA with Retrieval}
+ \addplot+[color={rgb,255:red,160;green,120;blue,190},line width=1.5pt,mark=triangle*,mark size=2pt] coordinates {
+ (0,24.8)
+ (1,34.2)
+ (3,42.3)
+ (5,44.2)
+ (10,47.5)
+ (20,56.2)
+ (30,59.1)
+ (40,60.7)
+ (50,61.5)
+ };
+ \addlegendentry{LKG-RALM with Retrieval}
+ \end{axis}
+\end{tikzpicture}
+\label{fig:retrieved passages number}
+\end{minipage}
+}
+\hfill
+\subfigure[U-shaped performance due to \textit{positional bias}.]{
+\begin{minipage}[t]{0.45\linewidth}
+\centering
+\begin{tikzpicture}[scale=0.44]
+ \begin{axis}[
+ xlabel=Position of Passage with the Answer,
+ ylabel=EM Scores,
+ % y label style={at={(-0.00,0.5)}},
+ ymajorgrids=true,
+ font=\Large,
+ legend style={at={(0.5,1.05)},anchor=south,legend columns=3},
+ legend cell align={left},
+ ]
+ \addplot+ [color={rgb,255:red,255;green,110;blue,110},line width=1.5pt,mark=diamond*,mark size=2pt] coordinates {
+ (1,51.1)
+ (2,46.7)
+ (3,43.5)
+ (4,41.4)
+ (5,38.0)
+ (6,36.2)
+ (7,35.7)
+ (8,35.6)
+ (9,34.8)
+ (10,36.7)
+ (11,35.5)
+ (12,37.1)
+ (13,38.6)
+ (14,39.4)
+ (15,38.3)
+ (16,42.8)
+ (17,44.9)
+ (18,45.5)
+ (19,49.3)
+ (20,50.8)
+ };
+ \addplot+ [color={rgb,255:red,160;green,120;blue,190},line width=1.5pt,mark=triangle*,mark size=2pt] coordinates {
+ (1,51.2)
+ (5,51.0)
+ (10,51.1)
+ (15,51.2)
+ (20,51.2)
+ };
+ \end{axis}
+\end{tikzpicture}
+\label{fig:answer position}
+\end{minipage}
+}
+\caption{Comparison of LLAMA's and LKG-RALM's performance with retrieved passages.}
+\label{fig:LLM challenges}
+\end{figure}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-passage relevance estimator.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-passage relevance estimator.tex"
new file mode 100644
index 0000000000..89811eae79
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-passage relevance estimator.tex"
@@ -0,0 +1,10 @@
+\begin{figure}[t]
+\centering
+\begin{minipage}[t]{1\linewidth}
+\centering
+\includegraphics[width=0.95\textwidth]{figure/src/Passage Relevance Estimator.pdf}
+\end{minipage}
+\centering
+\caption{The layer-wise passage relevance estimator.}
+\label{fig:passage relevance estimator}
+\end{figure}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-training-data-size.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-training-data-size.tex"
new file mode 100644
index 0000000000..b773e150b2
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/fig-training-data-size.tex"
@@ -0,0 +1,76 @@
+\begin{figure}[t]
+\centering
+\subfigure[NQ]{
+\begin{minipage}[t]{0.46\linewidth}
+\centering
+\begin{tikzpicture}[scale=0.45]
+ \begin{axis}[
+ xlabel=Training Data Size (k),
+ ylabel=EM Scores,
+ % y label style={at={(-0.0,0.5)}},
+ ymajorgrids=true,
+ font=\Large,
+ legend style={at={(0.5,1.05)}, anchor=south, legend columns=2, draw=black, fill=white,align=left,font=\normalsize},
+ ]
+ \addplot+[color={rgb,255:red,255;green,110;blue,110}, line width=1.5pt, mark=diamond*, mark size=2.5pt] coordinates {
+ (30,50.5)
+ (60,53.4)
+ (90,56.2)
+ (160,58.5)
+ (220,60.3)
+ (440,61.0)
+ };
+ \addlegendentry{LKG-RALM};
+ \addplot+[color={rgb,255:red,100;green,150;blue,220}, line width=1.5pt, mark=square*, mark size=2pt] coordinates {
+ (30,44.6)
+ (60,45.3)
+ (90,45.4)
+ (160,45.5)
+ (220,45.5)
+ (440,45.7)
+ };
+ \addlegendentry{LLAMA-FineTune};
+ \end{axis}
+\end{tikzpicture}
+\label{fig:nq-data-size}
+\end{minipage}
+}
+\hfill
+\subfigure[HotpotQA]{
+\begin{minipage}[t]{0.46\linewidth}
+\centering
+\begin{tikzpicture}[scale=0.45]
+ \begin{axis}[
+ xlabel=Training Data Size (k),
+ ylabel=EM Scores,
+ % y label style={at={(-0.0,0.5)}},
+ ymajorgrids=true,
+ font=\Large,
+ legend style={at={(0.5,1.05)}, anchor=south, legend columns=2, draw=black, fill=white,align=left,font=\normalsize},
+ ]
+ \addplot+[color={rgb,255:red,255;green,110;blue,110}, line width=1.5pt, mark=diamond*, mark size=2.5pt] coordinates {
+ (30,37.1)
+ (60,40.7)
+ (90,43.2)
+ (160,45.0)
+ (220,45.6)
+ (440,45.8)
+ };
+ \addlegendentry{LKG-RALM};
+ \addplot+[color={rgb,255:red,100;green,150;blue,220}, line width=1.5pt, mark=square*, mark size=2pt] coordinates {
+ (30,36.7)
+ (60,37.1)
+ (90,37.7)
+ (160,37.9)
+ (220,38.2)
+ (440,38.1)
+ };
+ \addlegendentry{LLAMA-FineTune};
+ \end{axis}
+\end{tikzpicture}
+\label{fig:hotpotqa-data-size}
+\end{minipage}
+}
+\caption{Effects of training data size.}
+\label{fig:data-size-impact}
+\end{figure}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q2-1.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q2-1.tex"
new file mode 100644
index 0000000000..2245744b93
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q2-1.tex"
@@ -0,0 +1,18 @@
+\begin{table}[t]
+\begin{center}
+\begin{tabular}{lccc}
+\hline
+\textbf{Retrieval Variant} & \textbf{Recall@5} & \textbf{Recall@20} & \textbf{Recall@50} \\
+\hline
+HC-LLM & 74.2 & 86.3 & 89.8 \\
+\hline
+w/o hybrid retriever & 71.5 & 82.7 & 87.6 \\
+w/o hierarchical granularity & 70.3 & 81.5 & 85.5 \\
+w/o question rewriting & 72.3 & 83.5 & 87.6 \\
+w/o document filtering & 71.5 & 84.1 & 87.6 \\
+\hline
+\end{tabular}
+\end{center}
+\caption{Ablation results of retrieval.}
+\label{table:q2}
+\end{table}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q2-2.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q2-2.tex"
new file mode 100644
index 0000000000..ab8608a6ef
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q2-2.tex"
@@ -0,0 +1,20 @@
+\setlength{\tabcolsep}{3.2pt} % default is 6pt
+\begin{table}[t]
+\begin{center}
+\scriptsize
+\caption{Ablation results of LKG-RALM, where ``LPR'', ``ELS'', and ``RPF'' stand for Layer-wise Passage Relevance, Entropy-based Layer-Knowledge Selection, and Relevance-aware Passage Fusion, respectively.}
+\begin{tabular}{lccccc}
+\hline
+\textbf{Model Variant} & \textbf{NQ} & \textbf{TriviaQA} & \textbf{HotpotQA} & \textbf{PopQA} & \textbf{2WikiMQA} \\
+\hline
+LKG-RALM-8B & 55.3 & 88.6 & 43.1 & 57.2 & 39.0 \\
+\hline
+w/o LPR & 50.6 & 86.0 & 39.5 & 52.5 & 34.8 \\
+w/o ELS & 53.1 & 87.7 & 41.8 & 55.8 & 37.2 \\
+w/o RPF & 30.9 & 70.7 & 26.0 & 34.9 & 9.6 \\
+w/o Auxiliary Loss & 54.2 & 87.3 & 42.3 & 56.4 & 37.8 \\
+\hline
+\end{tabular}
+\label{table:q2-2}
+\end{center}
+\end{table}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q3-1.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q3-1.tex"
new file mode 100644
index 0000000000..e0bea5c2fd
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q3-1.tex"
@@ -0,0 +1,52 @@
+\begin{figure}[t]
+\centering
+\begin{minipage}[t]{0.23\textwidth}
+\begin{tikzpicture}[scale=0.47]
+ \begin{axis}[
+ xlabel=number $m$,
+ ylabel=EM,
+ y label style={at={(-0.05,0.5)}},
+ ymajorgrids=true,
+ font=\LARGE,
+ ]
+ \addplot+[color=red,line width=1.5pt,mark size=3pt] coordinates {
+ (4,59.2)
+ (8,60.7)
+ (12,61.0)
+ (16,61.9)
+ (20,61.8)
+ (24,60.3)
+ (28,60.5)
+ (32,60.1)
+ };
+ \end{axis}
+\end{tikzpicture}
+\caption{Impact of the number of passage-level embeddings.}
+\label{fig:global token number}
+\end{minipage}
+\begin{minipage}[t]{0.23\textwidth}
+\begin{tikzpicture}[scale=0.47]
+ \begin{axis}[
+ xlabel=layer number $M$,
+ ylabel=EM,
+ y label style={at={(-0.05,0.5)}},
+ ymajorgrids=true,
+ font=\LARGE,
+ % ymin=0.12,
+ % ymax=0.15,
+ ]
+ \addplot+ [color=red,line width=1.5pt,mark size=3pt] coordinates {
+ (5,54.3)
+ (10,61.9)
+ (15,60.8)
+ (20,61.2)
+ (25,59.4)
+ (30,51.5)
+ (35,36.2)
+ };
+ \end{axis}
+\end{tikzpicture}
+\caption{Impact of the number of cyclic passage-question co-encoding layers.}
+\label{fig:local encoding layers number}
+\end{minipage}
+\end{figure}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q3-3.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q3-3.tex"
new file mode 100644
index 0000000000..46c332b6da
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q3-3.tex"
@@ -0,0 +1,60 @@
+\begin{figure}[t]
+\subfigure[Impact of varying the number of retrieved passages on retrieval recall]{
+\begin{minipage}[t][2.5cm][t]{0.46\linewidth}
+\begin{tikzpicture}[scale=0.46]
+ \begin{axis}[
+ xlabel=Number of retrieved passages,
+ ylabel=Recall,
+ y label style={at={(-0.05,0.5)}},
+ ymajorgrids=true,
+ % xlabel style={text width=8cm},
+ font=\LARGE,
+ ]
+ \addplot+[color=blue,line width=1.5pt,mark size=3pt] coordinates {
+ (5,74.2)
+ (10,81.0)
+ (15,83.7)
+ (20,86.3)
+ (25,87.1)
+ (30,87.9)
+ (35,88.5)
+ (40,88.9)
+ (45,89.4)
+ (50,89.8)
+ };
+ \end{axis}
+\end{tikzpicture}
+\label{fig:q3-3-1}
+\end{minipage}
+}
+\subfigure[Impact of varying the number of retrieved passages on HC-LLM]{
+\begin{minipage}[t][2.5cm][t]{0.46\linewidth}
+\begin{tikzpicture}[scale=0.46]
+ \begin{axis}[
+ xlabel=Number of retrieved passages,
+ ylabel=EM,
+ y label style={at={(-0.05,0.5)}},
+ ymajorgrids=true,
+ font=\LARGE,
+ % ymin=0.12,
+ % ymax=0.18,
+ ]
+ \addplot+[color=blue,line width=1.5pt,mark size=3pt] coordinates {
+ (5,40.4)
+ (10,53.8)
+ (15,57.1)
+ (20,61.9)
+ (25,61.7)
+ (30,61.8)
+ (35,61.2)
+ (40,59.5)
+ (45,58.4)
+ (50,58.2)
+ };
+ \end{axis}
+\end{tikzpicture}
+\label{fig:q3-3-2}
+\end{minipage}
+}
+\caption{Impact of the number of retrieved passages.}
+\end{figure}
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q4.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q4.tex"
new file mode 100644
index 0000000000..56918cb7a9
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q4.tex"
@@ -0,0 +1,152 @@
+\begin{figure}[t]
+\centering
+\subfigure[Impact of Increasing Passage Number]{
+\begin{minipage}[t]{0.45\linewidth}
+\centering
+\begin{tikzpicture}[scale=0.45]
+ \begin{axis}[
+ xlabel=Retrieved Passage Number,
+ ylabel=EM Scores,
+ % y label style={at={(-0.0,0.5)}},
+ ymajorgrids=true,
+ font=\Large,
+ legend style={at={(0.5,1.05)}, anchor=south, legend columns=2, draw=black, fill=white, align=left, font=\normalsize},
+ ]
+ \addplot+[color={rgb,255:red,255;green,110;blue,110}, line width=1.5pt, mark=diamond*, mark size=2.5pt] coordinates {
+ (0,24.8)
+ (5,44.2)
+ (10,47.5)
+ (15,52.8)
+ (20,56.2)
+ (25,57.9)
+ (30,59.1)
+ (35,60.3)
+ (40,60.7)
+ (45,61.0)
+ (50,61.5)
+ };
+ \addlegendentry{LKG-RALM};
+ \addplot+[color={rgb,255:red,100;green,150;blue,220}, line width=1.5pt, mark=square*, mark size=2pt] coordinates {
+ (0,30.0)
+ (5,38.6)
+ (10,45.3)
+ (15,48.7)
+ (20,51.2)
+ (25,53.7)
+ (30,53.4)
+ (35,52.9)
+ (40,53.3)
+ (45,53.8)
+ (50,54.2)
+ };
+ \addlegendentry{RankRAG};
+ \addplot+[color={rgb,255:red,100;green,200;blue,100}, line width=1.5pt, mark=triangle*, mark size=2.5pt] coordinates {
+ (0,35.0)
+ (5,39.4)
+ (10,42.8)
+ (15,43.6)
+ (20,43.5)
+ (25,40.2)
+ (30,37.5)
+ (35,33.7)
+ (40,31.2)
+ (45,29.6)
+ (50,28.4)
+ };
+ \addlegendentry{Self-RAG};
+ \addplot+[color=orange!60, line width=1.5pt, mark=*, mark size=2pt] coordinates {
+ (0,40.3)
+ (5,40.4)
+ (10,40.5)
+ (15,40.3)
+ (20,40.1)
+ (25,40.4)
+ (30,40.6)
+ (35,40.3)
+ (40,40.2)
+ (45,40.3)
+ (50,40.4)
+ };
+ \addlegendentry{GPT-4 with Retrieval};
+ \end{axis}
+\end{tikzpicture}
+\label{fig:q4-1}
+\end{minipage}
+}
+\hfill
+\subfigure[Impact of Irrelevant Passages]{
+\begin{minipage}[t]{0.45\linewidth}
+\centering
+\begin{tikzpicture}[scale=0.45]
+ \begin{axis}[
+ xlabel=Proportion of Irrelevant Passages (\%),
+ ylabel=EM Scores,
+ % y label style={at={(-0.0,0.5)}},
+ ymajorgrids=true,
+ font=\Large,
+ legend style={at={(0.5,1.05)}, anchor=south, legend columns=2, draw=black, fill=white,align=left,font=\normalsize},
+ ]
+ \addplot+[color={rgb,255:red,255;green,110;blue,110}, line width=1.5pt, mark=diamond*, mark size=2.5pt] coordinates {
+ (0,61.0)
+ (10,60.8)
+ (20,60.4)
+ (30,59.9)
+ (40,59.2)
+ (50,58.4)
+ (60,57.3)
+ (70,56.2)
+ (80,54.8)
+ (90,50.1)
+ (100,42.7)
+ };
+ \addlegendentry{LKG-RALM};
+ \addplot+[color={rgb,255:red,100;green,150;blue,220}, line width=1.5pt, mark=square*, mark size=2pt] coordinates {
+ (0,54.2)
+ (10,53.6)
+ (20,52.7)
+ (30,51.2)
+ (40,48.8)
+ (50,46.1)
+ (60,42.6)
+ (70,38.7)
+ (80,34.1)
+ (90,26.3)
+ (100,15.6)
+ };
+ \addlegendentry{RankRAG};
+ \addplot+[color={rgb,255:red,100;green,200;blue,100}, line width=1.5pt, mark=triangle*, mark size=2.5pt] coordinates {
+ (0,28.4)
+ (10,28.1)
+ (20,27.6)
+ (30,26.8)
+ (40,25.6)
+ (50,24.1)
+ (60,22.3)
+ (70,20.2)
+ (80,17.8)
+ (90,14.5)
+ (100,10.2)
+ };
+ \addlegendentry{Self-RAG};
+ \addplot+[color=orange!60, line width=1.5pt, mark=*, mark size=2pt] coordinates {
+ (0,40.4)
+ (10,40.2)
+ (20,39.8)
+ (30,39.3)
+ (40,38.7)
+ (50,38.0)
+ (60,37.2)
+ (70,36.2)
+ (80,34.9)
+ (90,33.4)
+ (100,31.5)
+ };
+ \addlegendentry{GPT-4 with Retrieval};
+ \end{axis}
+\end{tikzpicture}
+\label{fig:q4-2}
+\end{minipage}
+}
+\caption{Robustness to the number of retrieved passages and the proportion of irrelevant passages.}
+\label{fig:impact_passages}
+\end{figure}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q5.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q5.tex"
new file mode 100644
index 0000000000..0fa3b7b2bc
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q5.tex"
@@ -0,0 +1,10 @@
+\begin{figure}[t]
+\centering
+\begin{minipage}[t]{1\linewidth}
+\centering
+\includegraphics[width=1.0\textwidth]{figure/src/attention_visual.pdf}
+\end{minipage}
+\centering
+\caption{Visualization of QPI attention weights in HC-LLM. The heatmap shows the attention weights between the question and the passages across model layers: the x-axis corresponds to the passages and the question serving as attention keys/values, while the y-axis indexes the HC-LLM layers; a darker color indicates a larger attention weight. In this case, $p_5$ and $p_{13}$ contain the answer evidence and receive higher attention weights, while the irrelevant passages receive lower ones, demonstrating how question-guided focusing identifies and prioritizes the most relevant contexts for answering.}
+\label{fig:visual}
+\end{figure}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q6-1.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q6-1.tex"
new file mode 100644
index 0000000000..7fa8e2a3be
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q6-1.tex"
@@ -0,0 +1,20 @@
+\begin{table}[t]
+\footnotesize
+\begin{center}
+\begin{tabular}{lccccc}
+\hline
+\textbf{Model} & \textbf{Hum.} & \textbf{Social.} & \textbf{STEM} & \textbf{Other} & \textbf{All} \\
+\hline
+LLAMA-3-8B & 73.8 & 75.2 & 69.5 & 73.5 & 73.0 \\
+ChatGPT & 71.2 & 73.6 & 65.8 & 69.4 & 70.0 \\
+GPT-4 & 85.7 & 87.9 & 84.2 & 87.8 & 86.4 \\
+Self-RAG & 64.5 & 65.8 & 63.1 & 65.4 & 64.7 \\
+RankRAG & 74.1 & 75.6 & 70.8 & 73.5 & 73.5 \\
+\hline
+LKG-RALM-8B & 75.3 & 77.2 & 71.9 & 74.8 & 74.8 \\
+\hline
+\end{tabular}
+\end{center}
+\caption{Performance on the MMLU task.}
+\label{table:MMLU}
+\end{table}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q6-2.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q6-2.tex"
new file mode 100644
index 0000000000..7572c4612a
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/q6-2.tex"
@@ -0,0 +1,25 @@
+\setlength{\tabcolsep}{1.5pt} % default is 6pt
+\begin{table}[t]
+\footnotesize
+\begin{center}
+\begin{tabular}{llccc}
+\hline
+\textbf{Model} & \textbf{\# Params} & \textbf{Original} & \textbf{+LKG-RALM} & \textbf{Gain \%}\\
+\hline
+\multirow{4}{*}{GPT-2} & 117M & 1.33 & 1.22 & 8.27 \\
+ & 345M & 1.20 & 1.13 & 10.83 \\
+ & 774M & 1.19 & 1.14 & 4.20 \\
+ & 1.5B & 1.16 & 1.01 & 12.93 \\
+\hline
+\multirow{3}{*}{Qwen-2.5} & 7B & 0.95 & 0.90 & 5.26 \\
+& 14B & 0.88 & 0.84 & 4.54 \\
+& 72B & 0.70 & 0.66 & 5.71 \\
+\hline
+\multirow{2}{*}{LLAMA-3.1} & 8B & 0.97 & 0.93 & 4.12 \\
+ & 70B & 0.72 & 0.70 & 2.77 \\
+\hline
+\end{tabular}
+\end{center}
+\caption{Performance on the language modeling task.}
+\label{table:LM}
+\end{table}
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/OQDApipeline.pdf" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/OQDApipeline.pdf"
new file mode 100644
index 0000000000..80c9547ebf
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/OQDApipeline.pdf" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/OQDApipeline.pptx" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/OQDApipeline.pptx"
new file mode 100644
index 0000000000..0dfe1769a1
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/OQDApipeline.pptx" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/Passage Relevance Estimator.pdf" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/Passage Relevance Estimator.pdf"
new file mode 100644
index 0000000000..5cac3c4ccb
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/Passage Relevance Estimator.pdf" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/Passage Relevance Estimator.pptx" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/Passage Relevance Estimator.pptx"
new file mode 100644
index 0000000000..76f3a18ef3
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/Passage Relevance Estimator.pptx" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/attention distribution.pdf" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/attention distribution.pdf"
new file mode 100644
index 0000000000..6d2b26f301
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/attention distribution.pdf" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/attention_distribution_pie.pdf" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/attention_distribution_pie.pdf"
new file mode 100644
index 0000000000..b6b02f8a62
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/attention_distribution_pie.pdf" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/attention_visual.pdf" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/attention_visual.pdf"
new file mode 100644
index 0000000000..e2f1cf23f9
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/attention_visual.pdf" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/edge_focused_attention.pdf" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/edge_focused_attention.pdf"
new file mode 100644
index 0000000000..6cee77cf4a
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/edge_focused_attention.pdf" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f1.py" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f1.py"
new file mode 100644
index 0000000000..f9ecdcc681
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f1.py"
@@ -0,0 +1,44 @@
+import matplotlib.pyplot as plt
+import numpy as np
+
+edge_focused = [26, 22, 17, 1, 1.5, 0.7, 1.5, 0.4, 0.8, 8, 12, 16]
+
+colors = {'start': '#FF8C98', 'middle': '#7FD4FF', 'end': '#86E088'}
+
+def create_bar_plot(data, title):
+ fig, ax = plt.subplots(figsize=(12, 10)) # Further increased figure size
+ total = sum(data)
+ percentages = data
+
+ # Adjusted bar width
+ bar_width = 0.8
+
+ ax.bar(range(3), percentages[:3], color=colors['start'], label='Start Tokens', width=bar_width)
+ ax.bar(range(3, 9), percentages[3:9], color=colors['middle'], label='Middle Tokens', width=bar_width)
+ ax.bar(range(9, 12), percentages[9:], color=colors['end'], label='End Tokens', width=bar_width)
+
+ ax.set_xlabel('Input Context', fontsize=50) # Increased font size
+ ax.set_ylabel('Attention Weight (%)', fontsize=50) # Increased font size
+ ax.set_ylim(0, max(percentages) * 1.2)
+
+ ax.set_xticks([1, 6, 10])
+ ax.set_xticklabels(['Start', 'Middle', 'End'], fontsize=44) # Increased font size
+
+ # Increase y-axis tick label size
+ ax.tick_params(axis='y', labelsize=44) # Increased font size
+
+ start = 65.2
+ middle = 4.1
+ end = 30.7
+
+ legend_labels = [
+ f'Start Tokens ({start:.1f}%)',
+ f'Middle Tokens ({middle:.1f}%)',
+ f'End Tokens ({end:.1f}%)'
+ ]
+ ax.legend(labels=legend_labels, loc='upper right', fontsize=43) # Increased font size
+
+ plt.tight_layout()
+ plt.savefig('edge_focused_attention.pdf', bbox_inches='tight')
+
+create_bar_plot(edge_focused, 'Edge-focused Attention')
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f2.py" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f2.py"
new file mode 100644
index 0000000000..5effae06c6
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f2.py"
@@ -0,0 +1,47 @@
+import matplotlib.pyplot as plt
+import numpy as np
+
+uniform = [0.1, 0.07, 0.03, 11.2, 7.4, 6.6, 7.8, 9.4, 8.3, 0.1, 0.2, 0.25]
+
+colors = {'start': '#FF8C98', 'middle': '#7FD4FF', 'end': '#86E088'}
+
+def create_bar_plot(data, title):
+ fig, ax = plt.subplots(figsize=(12, 10)) # Increased figure size
+ total = sum(data)
+ percentages = data
+
+ # Adjusted bar width
+ bar_width = 0.8
+
+ ax.bar(range(3), percentages[:3], color=colors['start'], label='Start Tokens', width=bar_width)
+ ax.bar(range(3, 9), percentages[3:9], color=colors['middle'], label='Middle Tokens', width=bar_width)
+ ax.bar(range(9, 12), percentages[9:], color=colors['end'], label='End Tokens', width=bar_width)
+
+ ax.set_xlabel('Input Context', fontsize=50) # Increased font size
+ ax.set_ylabel('Attention Weight (%)', fontsize=50) # Increased font size
+ ax.set_ylim(0, max(percentages) * 1.2)
+
+ ax.set_xticks([1, 6, 10])
+ ax.set_xticklabels(['Start', 'Middle', 'End'], fontsize=44) # Increased font size
+
+ # Increase y-axis tick label size
+ ax.tick_params(axis='y', labelsize=44) # Increased font size
+
+ start = 0.1
+ middle = 99.8
+ end = 0.1
+
+ legend_labels = [
+ f'Start Tokens ({start:.1f}%)',
+ f'Middle Tokens ({middle:.1f}%)',
+ f'End Tokens ({end:.1f}%)'
+ ]
+
+    # Add transparency to the legend background
+    legend = ax.legend(labels=legend_labels, loc='upper right', fontsize=43)
+    legend.get_frame().set_alpha(0.4)  # set the transparency of the legend background
+
+ plt.tight_layout()
+ plt.savefig('uniform_attention.pdf', bbox_inches='tight')
+
+create_bar_plot(uniform, 'Uniform Attention')
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f3.py" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f3.py"
new file mode 100644
index 0000000000..6d2ec09c2b
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f3.py"
@@ -0,0 +1,46 @@
+import matplotlib.pyplot as plt
+import numpy as np
+
+middle_focused = [0.1, 1.07, 0.53, 1.2, 2.8, 31.5, 3.4, 1.3, 19.2, 1.1, 0.2, 1.25]
+
+colors = {'start': '#FF8C98', 'middle': '#7FD4FF', 'end': '#86E088'}
+
+def create_bar_plot(data, title):
+ fig, ax = plt.subplots(figsize=(12, 10)) # Increased figure size
+ total = sum(data)
+ percentages = data
+
+ # Adjusted bar width
+ bar_width = 0.8
+
+ ax.bar(range(3), percentages[:3], color=colors['start'], label='Start Tokens', width=bar_width)
+ ax.bar(range(3, 9), percentages[3:9], color=colors['middle'], label='Middle Tokens', width=bar_width)
+ ax.bar(range(9, 12), percentages[9:], color=colors['end'], label='End Tokens', width=bar_width)
+
+ ax.set_xlabel('Input Context', fontsize=50) # Increased font size
+ ax.set_ylabel('Attention Weight (%)', fontsize=50) # Increased font size
+ ax.set_ylim(0, max(percentages) * 1.2)
+
+ ax.set_xticks([1, 6, 10])
+ ax.set_xticklabels(['Start', 'Middle', 'End'], fontsize=44) # Increased font size
+
+ # Increase y-axis tick label size
+ ax.tick_params(axis='y', labelsize=44) # Increased font size
+
+ start = 1.7
+ middle = 95.8
+ end = 2.5
+
+ legend_labels = [
+ f'Start Tokens ({start:.1f}%)',
+ f'Middle Tokens ({middle:.1f}%)',
+ f'End Tokens ({end:.1f}%)'
+ ]
+    # Add transparency to the legend background
+    legend = ax.legend(labels=legend_labels, loc='upper right', fontsize=43)
+    legend.get_frame().set_alpha(0.4)  # set the transparency of the legend background
+
+ plt.tight_layout()
+ plt.savefig('middle_focused_attention.pdf', bbox_inches='tight')
+
+create_bar_plot(middle_focused, 'Middle-focused Attention')
\ No newline at end of file
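[Editor's note] Aside from the legend-transparency line, f1.py, f2.py, and f3.py differ only in their data arrays, legend percentages, and output filenames. A single parameterized helper could replace the three copies; the sketch below reuses the hard-coded values from the scripts above, but the helper name and structure are illustrative and not part of this commit.

```python
import matplotlib.pyplot as plt

COLORS = {'start': '#FF8C98', 'middle': '#7FD4FF', 'end': '#86E088'}

def attention_bar_chart(data, legend_pcts, out_path):
    """Plot one 12-bar attention distribution (3 start, 6 middle, 3 end positions)."""
    fig, ax = plt.subplots(figsize=(12, 10))
    groups = [('Start', range(3), data[:3], 'start'),
              ('Middle', range(3, 9), data[3:9], 'middle'),
              ('End', range(9, 12), data[9:], 'end')]
    for (name, xs, ys, key), pct in zip(groups, legend_pcts):
        ax.bar(xs, ys, color=COLORS[key], width=0.8,
               label=f'{name} Tokens ({pct:.1f}%)')
    ax.set_xlabel('Input Context', fontsize=50)
    ax.set_ylabel('Attention Weight (%)', fontsize=50)
    ax.set_ylim(0, max(data) * 1.2)
    ax.set_xticks([1, 6, 10])
    ax.set_xticklabels(['Start', 'Middle', 'End'], fontsize=44)
    ax.tick_params(axis='y', labelsize=44)
    legend = ax.legend(loc='upper right', fontsize=43)
    legend.get_frame().set_alpha(0.4)  # semi-transparent legend background
    plt.tight_layout()
    plt.savefig(out_path, bbox_inches='tight')
    plt.close(fig)

# The three figures produced by f1.py-f3.py, generated from one helper:
attention_bar_chart([26, 22, 17, 1, 1.5, 0.7, 1.5, 0.4, 0.8, 8, 12, 16],
                    (65.2, 4.1, 30.7), 'edge_focused_attention.pdf')
attention_bar_chart([0.1, 0.07, 0.03, 11.2, 7.4, 6.6, 7.8, 9.4, 8.3, 0.1, 0.2, 0.25],
                    (0.1, 99.8, 0.1), 'uniform_attention.pdf')
attention_bar_chart([0.1, 1.07, 0.53, 1.2, 2.8, 31.5, 3.4, 1.3, 19.2, 1.1, 0.2, 1.25],
                    (1.7, 95.8, 2.5), 'middle_focused_attention.pdf')
```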
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f4.py" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f4.py"
new file mode 100644
index 0000000000..1f07d5b92e
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f4.py"
@@ -0,0 +1,27 @@
+import matplotlib.pyplot as plt
+
+pie_data = [87.37, 5.97, 6.67]
+pie_labels = ['Edge-focused', 'Uniform', 'Middle-focused']
+
+# Use lighter colors
+colors = {
+    'start': '#FF8C98',   # light pink
+    'middle': '#7FD4FF',  # light blue
+    'end': '#86E088'      # light green
+}
+
+fig, ax = plt.subplots(figsize=(16, 12))  # figure size (wide enough for labels and legend)
+pie_colors = [colors['start'], colors['middle'], colors['end']]
+wedges, texts, autotexts = ax.pie(pie_data, labels=pie_labels, autopct='%1.1f%%', startangle=90, colors=pie_colors, textprops={'fontsize': 44})
+
+# Move the legend to the top of the chart
+ax.legend(wedges, pie_labels,
+          loc="upper center",
+          bbox_to_anchor=(0.1, 1.1),
+          ncol=1,  # single-column legend (items stacked vertically)
+          fontsize=43)
+
+plt.setp(autotexts, size=44)
+
+plt.tight_layout()
+plt.savefig('attention_distribution_pie.pdf', bbox_inches='tight')
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f5.py" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f5.py"
new file mode 100644
index 0000000000..336158916f
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/f5.py"
@@ -0,0 +1,45 @@
+import matplotlib.pyplot as plt
+import numpy as np
+
+ralm_data = {
+ 'Edge-focused': [40.6, 40.3, 40.5, 40.2, 40.35, 40.4, 40.25, 40.45],
+ 'Uniform': [42, 41.7, 42.3, 42.2, 41.5, 41.9, 42.5, 41.8, 42.1, 41.6, 42.4],
+ 'Middle-focused': [6.2, 18.9, 29.6, 37.4, 43.6, 49.8, 58.4, 66.3, 72.5, 76.2, 79.8]
+}
+
+# Updated standard-deviation data
+std_devs = {
+ 'Edge-focused': [1.7, 1.63, 1.68, 1.65, 1.67, 1.66, 1.64, 1.69],
+ 'Uniform': [0.55, 0.5, 0.59, 0.55, 0.53, 0.57, 0.59, 0.52, 0.56, 0.51, 0.58],
+ 'Middle-focused': [4.45, 4.1, 3.89, 5.1, 5.5, 6.22, 4.34, 3.41, 3.58, 3.95, 3.75]
+}
+
+x_values = {
+ 'Edge-focused': [0, 1.43, 2.86, 4.29, 5.71, 7.14, 8.57, 10],
+ 'Uniform': [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30],
+ 'Middle-focused': [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30]
+}
+
+# Plotting code (unchanged from the other figure scripts)
+fig, ax = plt.subplots(figsize=(15, 9))
+
+colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
+
+for (label, data), color in zip(ralm_data.items(), colors):
+ x = x_values[label]
+ y = data
+ std = std_devs[label]
+
+ ax.plot(x, y, marker='o', label=label, linewidth=3, markersize=10, color=color)
+ ax.fill_between(x, np.array(y) - np.array(std), np.array(y) + np.array(std),
+ alpha=0.3, color=color)
+
+ax.set_xlabel('Cumulative attention weight on relevant passages', fontsize=35)
+ax.set_ylabel('RALM accuracy (%)', fontsize=38)
+ax.legend(fontsize=34, loc='lower right')
+ax.grid(True, linestyle='--', alpha=0.7)
+
+ax.tick_params(axis='both', which='major', labelsize=38)
+
+plt.tight_layout()
+plt.savefig('ralm_accuracy_plot.pdf', bbox_inches='tight')
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/llm framework.pdf" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/llm framework.pdf"
new file mode 100644
index 0000000000..bbac04419e
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/llm framework.pdf" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/llm framework.pptx" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/llm framework.pptx"
new file mode 100644
index 0000000000..ef25184345
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/llm framework.pptx" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/llm location.pdf" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/llm location.pdf"
new file mode 100644
index 0000000000..edcef9865f
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/llm location.pdf" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/middle_focused_attention.pdf" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/middle_focused_attention.pdf"
new file mode 100644
index 0000000000..b6f9f8d1f9
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/middle_focused_attention.pdf" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/mixture-architecture.pdf" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/mixture-architecture.pdf"
new file mode 100644
index 0000000000..d0bf56f9dd
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/mixture-architecture.pdf" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/mixture-architecture.pptx" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/mixture-architecture.pptx"
new file mode 100644
index 0000000000..7f363c9e15
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/mixture-architecture.pptx" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/ralm_accuracy_plot.pdf" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/ralm_accuracy_plot.pdf"
new file mode 100644
index 0000000000..271546ca79
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/ralm_accuracy_plot.pdf" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/uniform_attention.pdf" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/uniform_attention.pdf"
new file mode 100644
index 0000000000..c7fa63b096
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/src/uniform_attention.pdf" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/table-ODQAperformance.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/table-ODQAperformance.tex"
new file mode 100644
index 0000000000..d1e8fe744a
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/table-ODQAperformance.tex"
@@ -0,0 +1,50 @@
+\setlength{\tabcolsep}{3.2pt} % default is 6pt
+\begin{table}[t]
+\scriptsize
+\begin{center}
+\caption{Overall Performance. \textbf{Bold} numbers indicate the best score across all models, while underlined numbers represent the best score within each category.}
+\begin{tabular}{lccccc}
+\hline
+\textbf{Model} & \textbf{NQ} & \textbf{TriviaQA} & \textbf{HotpotQA} & \textbf{PopQA} & \textbf{2WikiMQA} \\
+\hline
+\multicolumn{6}{c}{\textit{General LLMs}} \\
+LLAMA-3.1-8B & 18.7 & 78.5 & 16.5 & 22.1 & 13.9 \\
+\quad + Retrieval & 30.9 & 70.7 & 26.0 & 34.9 & 9.6 \\
+\quad + Fine-tuning & 35.7 & 77.4 & 28.9 & 37.1 & 25.3 \\
+LLAMA-3.1-70B & 21.8 & 89.7 & 24.1 & 27.5 & 21.6 \\
+\quad + Retrieval & 42.7 & 82.4 & 35.5 & 45.3 & 13.5 \\
+\quad + Fine-tuning & 44.9 & 89.1 & 38.5 & 50.3 & 28.4 \\
+Qwen-2.5-7B & 37.5 & 80.2 & 20.3 & 24.8 & 16.2 \\
+\quad + Retrieval & 44.3 & 83.5 & 29.7 & 37.8 & 17.8 \\
+\quad + Fine-tuning & 46.1 & 86.3 & 30.8 & 40.6 & 28.8 \\
+Qwen-2.5-72B & 39.9 & 90.5 & 26.3 & 29.8 & 24.1 \\
+\quad + Retrieval & 45.1 & 90.6 & 37.2 & 47.9 & 27.7 \\
+\quad + Fine-tuning & 47.6 & 90.5 & \underline{39.4} & 52.2 & \underline{36.0} \\
+ChatGPT & 38.6 & 82.9 & 29.9 & 28.4 & 23.9 \\
+\quad + Retrieval & 46.7 & 79.7 & 31.2 & 49.9 & 27.2 \\
+GPT-4 & 40.3 & 87.0 & 34.5 & 31.3 & 29.8 \\
+\quad + Retrieval & 40.4 & 75.0 & 27.6 & 44.3 & 14.4 \\
+Claude-3-Sonnet & 49.2 & 87.5 & 32.8 & 33.4 & 31.4 \\
+\quad + Retrieval & \underline{55.1} & \textbf{90.8} & 33.3 & \underline{52.4} & 32.6 \\
+\hline
+\multicolumn{6}{c}{\textit{Robust RALM}} \\
+REPLUG & 23.8 & 58.6 & 21.8 & 40.1 & 25.7 \\
+Self-RAG & 28.4 & 61.6 & 25.4 & 44.8 & 30.2 \\
+RA-ISF & 31.3 & 63.2 & 28.9 & 46.8 & 31.7 \\
+Noise-Resistant RALM & 45.7 & 80.3 & 34.4 & 48.1 & 34.7 \\
+ChatQA-1.5 & 47.0 & 85.6 & 35.5 & 45.3 & 13.5 \\
+RankRAG & \underline{54.2} & \underline{86.5} & \underline{42.7} & \underline{59.9} & \underline{38.2} \\
+\hline
+\multicolumn{6}{c}{\textit{LKG-RALM}} \\
+LLAMA-3-8B & 53.6 & 87.9 & 42.4 & 56.5 & 38.7 \\
+LLAMA-3-70B & 59.9 & 89.5 & 44.7 & 62.0 & 41.1 \\
+LLAMA-3.1-8B & 55.3 & 88.6 & 43.1 & 57.2 & 39.0 \\
+LLAMA-3.1-70B & 61.0 & 89.9 & 45.8 & 62.6 & 41.3 \\
+Qwen-2.5-7B & 55.4 & 88.1 & 43.1 & 57.4 & 39.7 \\
+Qwen-2.5-72B & \textbf{61.5} & \underline{90.0} & \textbf{46.1} & \textbf{62.7} & \textbf{41.4} \\
+\hline
+\end{tabular}
+\label{table:overall performance}
+\end{center}
+\end{table}
+
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/table-efficient.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/table-efficient.tex"
new file mode 100644
index 0000000000..f53fa8dd8a
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/ARR_RAG_camera/latex/figure/table-efficient.tex"
@@ -0,0 +1,29 @@
+\begin{table}[t]
+\footnotesize
+\begin{center}
+\caption{Efficiency and Accuracy Trade-off for LKG-RALM and Baseline Models under 1024 Context Tokens.}
+\begin{tabular}{lccc}
+\hline
+\textbf{Model} & \textbf{EM} & \textbf{Speed (s/query)} & \textbf{TFLOPs} \\
+\hline
+Self-RAG & 28.4 & 3.07 & 20.5 \\
+RA-ISF & 31.3 & 3.44 & 63.8 \\
+Robust-RALM & 45.7 & 0.74 & 14.6 \\
+RankRAG & 54.2 & 1.25 & 145.4 \\
+\hline
+\multicolumn{4}{c}{\textbf{LKG-RALM with Different Passage Estimators}} \\
+\hline
+LLAMA-3.1-8B & 30.9 & 0.82 & 16.3 \\
+\hspace{0.1cm} + Qwen-2.5-500M & 53.6 & 0.83 & 18.1 \\
+\hspace{0.1cm} + Qwen-2.5-1.5B & 54.8 & 0.85 & 20.0 \\
+\hspace{0.1cm} + Qwen-2.5-7B & 55.3 & 0.91 & 30.9 \\
+\hline
+LLAMA-3.1-70B & 42.7 & 1.24 & 142.6 \\
+\hspace{0.1cm} + Qwen-2.5-500M & 60.0 & 1.24 & 144.4 \\
+\hspace{0.1cm} + Qwen-2.5-1.5B & 60.7 & 1.25 & 146.3 \\
+\hspace{0.1cm} + Qwen-2.5-7B & 61.0 & 1.27 & 157.2 \\
+\hline
+\end{tabular}
+\label{table:efficiency-accuracy}
+\end{center}
+\end{table}
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/README" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/README"
new file mode 100644
index 0000000000..7e59600739
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/README"
@@ -0,0 +1 @@
+# README
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/README.md" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/README.md"
new file mode 100644
index 0000000000..025577f497
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/README.md"
@@ -0,0 +1,53 @@
+# *ACL Paper Styles
+
+This directory contains the latest LaTeX and Word templates for *ACL
+conferences.
+
+## Instructions for authors
+
+Paper submissions to *ACL conferences must use the official ACL style
+templates.
+
+The LaTeX style files are available
+
+- as an [Overleaf template](https://www.overleaf.com/latex/templates/association-for-computational-linguistics-acl-conference/jvxskxpnznfj)
+- in this repository, in the [`latex`](https://github.com/acl-org/acl-style-files/blob/master/latex) subdirectory
+- as a [.zip file](https://github.com/acl-org/acl-style-files/archive/refs/heads/master.zip)
+
+Please see [`latex/acl_latex.tex`](https://github.com/acl-org/acl-style-files/blob/master/acl_latex.tex) for an example.
+
+The Microsoft Word template is available in this repository at [`word/acl.docx`](https://github.com/acl-org/acl-style-files/blob/master/word/acl.docx).
+
+Please follow the paper formatting guidelines general to *ACL
+conferences:
+
+- [Paper formatting guidelines](https://acl-org.github.io/ACLPUB/formatting.html)
+
+Authors may not modify these style files or use templates designed for
+other conferences.
+
+## Instructions for publications chairs
+
+To adapt the style files for your conference, please fork this repository and
+make necessary changes. Minimally, you'll need to update the name of
+the conference and rename the files.
+
+If you make improvements to the templates that should be propagated to
+future conferences, please submit a pull request. Thank you in
+advance!
+
+In older versions of the templates, authors were asked to fill in the
+START submission ID so that it would be stamped at the top of each
+page of the anonymized version. This is no longer needed, because it
+is now possible to do this stamping automatically within
+START. Currently, the way to do this is for the program chair to email
+support@softconf.com and request it.
+
+## Instructions for making changes to style files
+
+- merge pull request in github, or push to github
+- git pull from github to a local repository
+- then, git push from your local repository to overleaf project
+ - Overleaf project is https://www.overleaf.com/project/5f64f1fb97c4c50001b60549
+ - Overleaf git url is https://git.overleaf.com/5f64f1fb97c4c50001b60549
+- then, click "Submit" and then "Submit as Template" in Overleaf in order to ask Overleaf to update the Overleaf template from the Overleaf project
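[Editor's note] The sync steps listed above can be scripted. A minimal sketch, assuming a local clone whose `origin` remote points at GitHub and a second remote named `overleaf` added with the Overleaf git URL above; the remote name and this script are illustrative, not part of the repository:

```python
import subprocess

def git(*args):
    """Run a git command, echo it, and stop on the first failure."""
    print('$ git', ' '.join(args))
    subprocess.run(['git', *args], check=True)

# One-time setup (assumed): git remote add overleaf https://git.overleaf.com/5f64f1fb97c4c50001b60549
git('pull', 'origin', 'master')    # bring the clone up to date after the PR is merged on GitHub
git('push', 'overleaf', 'master')  # mirror the same history to the Overleaf project
```

After the push, the final "Submit as Template" step still has to be done in the Overleaf web interface.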
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/anthology.bib.txt" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/anthology.bib.txt"
new file mode 100644
index 0000000000..14f228c6eb
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/anthology.bib.txt"
@@ -0,0 +1,8 @@
+% Please download the latest anthology.bib from the following URL:
+%
+% http://aclweb.org/anthology/anthology.bib
+%
+% From the command line, this can be done with curl or wget.
+%
+% If you are using Overleaf, go to "New File -> From External URL".
+% You will then be able to use it directly, and to periodically update it by clicking Overleaf's convenient "refresh" button.
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/formatting.md" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/formatting.md"
new file mode 100644
index 0000000000..eeb1ce1548
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/formatting.md"
@@ -0,0 +1,326 @@
+# Instructions for *ACL Proceedings
+
+The following instructions are for authors of papers submitted for review to ACL conferences (hereafter, "review version") or papers accepted for publication in their proceedings (hereafter, "final version").
+All authors are required to adhere to these specifications.
+
+## Style Files
+
+*ACL provides style files for LaTeX and Microsoft Word that meet these requirements. They can be found at:
+
+> https://acl-org.github.io/ACLPUB/
+
+We strongly recommend the use of these style files, which have been appropriately tailored for the *ACL proceedings.
+
+## Paper Length
+
+The conference accepts submissions of long papers and short papers.
+Review versions of long papers may have up to eight (8) pages of content plus unlimited pages for references.
+Upon acceptance, final versions of long papers will be given one additional page -- up to nine (9) pages of content plus unlimited pages for acknowledgements and references -- so that reviewers' comments can be taken into account.
+Review versions of short papers may have up to four (4) pages of content, plus unlimited pages for references.
+Final versions of short papers may have up to five (5) pages, plus unlimited pages for acknowledgements and references.
+For both long and short papers, all figures and tables that are part of the main text must fit within these page limits.
+
+The conference encourages submission of appendices and supplementary material, which are not required to fit within these page limits. However, review versions of papers must be self-contained: it is optional for reviewers to look at appendices or supplementary material. Please see [Appendices](#appendices) and [Supplementary Material](#supplementary-material) for more information.
+
+Review versions should not refer, for further detail, to documents, code or data resources that are not available to the reviewers.
+
+Papers that do not conform to these requirements may be rejected without review.
+
+Workshop chairs may have different rules for allowed length and whether appendices or supplementary materials are welcome.
+As always, the respective call for papers is the authoritative source.
+
+## Anonymity
+
+As reviewing will be double-blind, review versions must not include any identifying information about the authors (such as names, affiliations, or URLs).
+Self-references that reveal the author's identity, e.g.,
+
+> We previously showed (Gusfield, 1997)...
+
+must be avoided, and anonymous citations, e.g.,
+
+> We previously showed (Anonymous, 1997)...
+
+should also be avoided. Instead, use citations such as
+
+> Gusfield (1997) previously showed...
+
+Review versions must not include acknowledgements.
+
+**Papers that do not conform to these requirements may be rejected without review.**
+
+Any preliminary non-archival versions of submitted papers should be listed in the submission form but not in the review version of the paper.
+Reviewers are generally aware that authors may present preliminary versions of their work in other venues, but will not be provided the list of previous presentations from the submission form.
+
+Once a paper has been accepted to the conference, the final version should include the author's names and affiliations, and is allowed to use self-references.
+
+## Multiple Submission
+
+Papers that have been or will be submitted to other meetings or publications must indicate this at submission time in the START submission form, and must be withdrawn from the other venues if accepted by *ACL.
+Authors of papers accepted for presentation at *ACL must notify the program chairs by the deadline for final versions ("camera-ready deadline") whether the paper will be presented.
+We will not accept for publication or presentation any papers that overlap significantly in content or results with papers that will be (or have been) published elsewhere.
+
+Authors submitting more than one paper to *ACL must ensure that submissions do not overlap significantly (>25%) with each other in content or results.
+
+## Formatting Instructions
+
+### File Format
+
+Papers must be in Adobe Portable Document Format (PDF).
+Please make sure that your PDF file embeds all necessary fonts (especially for tree diagrams, symbols, and Asian languages).
+When you print or create the PDF file, there is usually an option in your printer setup to include none, all or just non-standard fonts.
+Please make sure that you select the option of including *all* the fonts.
+**Before sending it, test your PDF by printing it from a computer different from the one where it was created.**
+
+Some word processors may generate very large PDF files, where each page is rendered as an image.
+Such images may reproduce poorly.
+In this case, try alternative ways to obtain the PDF.
+
+All papers must use **A4 paper format** (21 cm x 29.7 cm).
+Papers must not be submitted with any other paper size.
+
+If you cannot meet the above requirements, please contact the publication chairs as soon as possible.
+
+### Layout
+
+All text except for page numbers must fit within the margins.
+
+Review versions should have page numbers, centered in the bottom margin, but **pages should not be numbered in the final version.**
+
+Manuscripts must be set in two columns.
+Exceptions to the two-column format include the title, authors' names and complete addresses, which must be centered at the top of the first page, and any full-width figures or tables.
+
+The exact dimensions for a page on A4 paper are:
+
+* Left margin: 2.5 cm
+* Right margin: 2.5 cm
+* Top margin: 2.5 cm
+* Bottom margin: 2.5 cm
+* Column width: 7.7 cm
+* Column height: 24.7 cm
+* Gap between columns: 0.6 cm
+
+In the review version, a ruler (line numbers in the left and right margins of the article) should be printed, so that reviewers may comment on particular lines in the paper.
+The ruler should not change the appearance of any other content on the page.
+The final version should not contain a ruler.
+
+### Fonts
+
+All text (except non-Latin scripts and mathematical formulas) should be set in **Times Roman**.
+If Times Roman is unavailable, you may use **Times New Roman** or **Computer Modern Roman.**
+
+The following table specifies what font sizes and styles must be used for each type of text in the manuscript.
+
+| Type of Text | Font Size | Style |
+| --------------------- | --------- | ----- |
+| paper title | 15 pt | bold |
+| author names | 12 pt | bold |
+| author affiliation | 12 pt | |
+| the word "Abstract"   | 12 pt     | bold  |
+| section titles | 12 pt | bold |
+| subsection titles | 11 pt | bold |
+| document text | 11 pt | |
+| captions | 10 pt | |
+| abstract text | 10 pt | |
+| bibliography | 10 pt | |
+| footnotes | 9 pt | |
+
+### Title and Authors
+
+Center the title, author's name(s) and affiliation(s) across both columns.
+
+Place the title centered at the top of the first page, in 15-point bold.
+Long titles should be typed on two lines without a blank line intervening.
+Put the title 2.5 cm from the top of the page.
+Write the title in [title case](https://apastyle.apa.org/style-grammar-guidelines/capitalization/title-case); do not write the title in all capital letters, except for acronyms (e.g., "BLEU") or proper nouns ("English") that are normally uppercased or capitalized.
+
+Place the author name(s) and affiliation(s) under the title.
+Write authors' full names; do not abbreviate given names to initials, unless they are normally written as initials ("Margaret Mitchell", not "M. Mitchell").
+Do not format surnames in all capitals ("Mitchell", not "MITCHELL").
+
+Do not use footnotes for affiliations.
+The affiliation should contain the author's complete address, and if possible, an electronic mail address.
+
+The title, author names and addresses should be completely identical to those entered to the paper submission website in order to maintain the consistency of author information among all publications of the conference.
+If they are different, the publication chairs may resolve the difference without consulting with you; so it is in your own interest to double-check that the information is consistent.
+
+Start the body of the first page 7.5 cm from the top of the page.
+**Even in the review version of the paper, you should maintain space for names and addresses so that they will fit in the final version.**
+
+### Abstract
+
+Type the abstract at the beginning of the first column.
+Center the word **Abstract** in 12 point bold above the body of the abstract.
+The width of the abstract should be smaller than the
+normal column width by 0.6 cm on each side.
+The abstract text should be 10 point roman, single-spaced.
+
+The abstract should be a concise summary of the general thesis and conclusions of the paper.
+It should be no longer than 200 words.
+
+### Text
+
+Begin typing the main body of the text immediately after the abstract, continuing in two columns.
+The text should be 11 point roman, single-spaced.
+
+Indent 0.4 cm when starting a new paragraph, except for the first paragraph in a section.
+
+### Sections
+
+Use numbered sections (Arabic numerals) to facilitate cross references.
+Number subsections with the section number and the subsection number separated by a dot, in Arabic numerals, e.g.,
+
+> 1 Introduction
+
+or
+
+> 6.1 File Format
+
+### Footnotes
+Put footnotes at the bottom of the page and use 9 point font.
+They may be numbered or referred to by asterisks or other symbols.
+Footnotes should be separated from the text by a line.
+
+### Figures and tables
+
+Place figures and tables in the paper near where they are first discussed, rather than at the end, if possible.
+Wide figures/tables may run across both columns.
+
+To accommodate people who are color-blind (as well as those printing with black-and-white printers), grayscale readability is strongly encouraged.
+Color is not forbidden, but authors should ensure that tables and figures do not rely solely on color to convey critical distinctions.
+
+**Captions:**
+Provide a caption for every figure/table; number each one sequentially in the form:
+
+> Figure 1: Caption of the Figure.
+
+and
+
+> Table 1: Caption of the Table.
+
+Captions should be placed below figures/tables, in 10 point roman type.
+Captions that are one line are centered.
+Captions longer than one line are left-aligned.
+
+### Hyperlinks
+
+Within-document and external hyperlinks should be dark blue (hex #000099), not underlined or boxed.
+
+### Non-English Text
+
+Text in languages other than English should be accompanied by translations into English, and text in scripts other than Latin should *also* be accompanied by transliterations into Latin script, since not all readers can recognize non-Latin characters easily.
+
+For example, παράδειγμα *paradeigma* ‘example’ is a Greek word, and this is a Greek sentence:
+
+> Αυτό είναι ένα παράδειγμα.
+> auto einai ena paradeigma.
+> ‘This is an example.’
+
+### Citations
+
+Citations within the text appear in parentheses (Gusfield, 1997), or, if the author's name appears in the text itself: Gusfield (1997).
+Append lowercase letters to the year in cases of ambiguities.
+Cite papers with two authors using both authors' names (Aho and Ullman, 1972), but cite papers with more than two authors by the first author's name and "et al." (Chandra et al., 1981).
+Collapse multiple citations into a single pair of parentheses (Gusfield, 1997; Aho and Ullman, 1972).
+
+Refrain from using full citations as sentence constituents.
+Instead of
+
+> (Gusfield, 1997) showed that ...
+> In (Gusfield, 1997), ...
+
+write
+
+> Gusfield (1997) showed that ...
+> In Gusfield (1997), ...
+
+Submissions should accurately reference prior and related work, including code and data.
+If a piece of prior work appeared in multiple venues, the version that appeared in a refereed, archival venue should be referenced.
+If multiple versions of a piece of prior work exist, the one used by the authors should be referenced.
+
+### Acknowledgments
+
+The acknowledgments should go immediately before the references.
+Do not number the acknowledgments section.
+Do not include this section in the review version.
+
+### References
+
+Gather the full set of references together under the unnumbered section heading **References**.
+Place the References section before any Appendices.
+Arrange the references alphabetically by first author, rather than by order of occurrence in the text.
+
+Provide as complete a citation as possible, using a consistent format, such as the [one for Computational Linguistics](http://cljournal.org/style_guide_refs.html) or the one in the [Publication Manual of the American Psychological Association](https://apastyle.apa.org/products/publication-manual-7th-edition).
+Use full names for authors, not just initials.
+Authors should not rely on automated citation indices to provide accurate references for prior and related work.
+
+As part of our work to make ACL materials more widely used and cited outside of our discipline, ACL has registered as a CrossRef member, as a registrant of Digital Object Identifiers (DOIs), the standard for registering permanent URNs for referencing scholarly materials.
+
+All references are required to contain DOIs of all cited works when possible, or, as a second resort, links to ACL Anthology pages.
+Appropriate records should be found for most materials in the current [ACL Anthology](https://aclweb.org/anthology/).
+
+Example article in a journal:
+
+> Rie Kubota Ando and Tong Zhang. 2005. [A framework for learning predictive structures from multiple tasks and unlabeled data](https://www.jmlr.org/papers/v6/ando05a.html). *Journal of Machine Learning Research*, 6:1817–1853.
+
+Example paper in non-ACL proceedings, with DOI:
+
+> Galen Andrew and Jianfeng Gao. 2007. [Scalable training of L1-regularized log-linear models](https://doi.org/10.1145/1273496.1273501). In *Proceedings of the 24th International Conference on Machine Learning*, pages 33–40.
+
+Example ACL Anthology paper with DOI:
+
+> James Goodman, Andreas Vlachos, and Jason Naradowsky. 2016. [Noise reduction and targeted exploration in imitation learning for Abstract Meaning Representation parsing](http://dx.doi.org/10.18653/v1/P16-1001). In *Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)*, pages 1–11, Berlin, Germany. Association for Computational Linguistics.
+
+Example ACL Anthology paper without DOI:
+
+> Benjamin Börschinger and Mark Johnson. 2011. [A particle filter algorithm for Bayesian word segmentation](https://www.aclweb.org/anthology/U11-1004/). In *Proceedings of the Australasian Language Technology Association Workshop 2011*, pages 10–18, Canberra, Australia.
+
+Example arXiv paper:
+
+> Mohammad Sadegh Rasooli and Joel R. Tetreault. 2015. [Yara parser: A fast and accurate dependency parser](http://arxiv.org/abs/1503.06733). *Computing Research Repository*, arXiv:1503.06733. Version 2.
+
+## Appendices
+
+Appendices are material that can be read, and include lemmas, formulas, proofs, and tables that are not critical to the reading and understanding of the paper.
+Letter them in sequence and provide an informative title:
+
+> Appendix A. Title of Appendix
+
+The appendices come after the references.
+
+Review versions of appendices must follow the same anonymity guidelines as the main paper.
+
+## Supplementary Material
+
+Submissions may include non-readable supplementary material used in the work and described in the paper.
+Any accompanying software and/or data should include licenses and documentation of research review as appropriate.
+Supplementary material may report preprocessing decisions, model parameters, and other details necessary for the replication of the experiments reported in the paper.
+Seemingly small preprocessing decisions can sometimes make a large difference in performance, so it is crucial to record such decisions to precisely characterize state-of-the-art methods.
+
+Nonetheless, supplementary material should be supplementary (rather than central) to the paper.
+**Submissions that misuse the supplementary material may be rejected without review.**
+Supplementary material may include explanations or details of proofs or derivations that do not fit into the paper, lists of features or feature templates, sample inputs and outputs for a system, pseudo-code or source code, and data.
+(Source code and data should be separate uploads, rather than part of the paper).
+
+The paper should not rely on the supplementary material: while the paper may refer to and cite the supplementary material and the supplementary material will be available to the reviewers, they will not be asked to review the supplementary material.
+
+Review versions of supplementary material must follow the same anonymity guidelines as the main paper.
+
+## Credits
+
+This document has been adapted from the instructions for earlier ACL and NAACL proceedings, including those for
+ACL 2020 by Steven Bethard, Ryan Cotterell and Rui Yan,
+ACL 2019 by Douwe Kiela and Ivan Vulić,
+NAACL 2019 by Stephanie Lukin and Alla Roskovskaya,
+ACL 2018 by Shay Cohen, Kevin Gimpel, and Wei Lu,
+NAACL 2018 by Margaret Mitchell and Stephanie Lukin,
+BibTeX suggestions for (NA)ACL 2017/2018 from Jason Eisner,
+ACL 2017 by Dan Gildea and Min-Yen Kan,
+NAACL 2017 by Margaret Mitchell,
+ACL 2012 by Maggie Li and Michael White,
+ACL 2010 by Jing-Shin Chang and Philipp Koehn,
+ACL 2008 by Johanna D. Moore, Simone Teufel, James Allan, and Sadaoki Furui,
+ACL 2005 by Hwee Tou Ng and Kemal Oflazer,
+ACL 2002 by Eugene Charniak and Dekang Lin,
+and earlier ACL and EACL formats written by several people, including
+John Chen, Henry S. Thompson and Donald Walker.
+Additional elements were taken from the formatting instructions of the *International Joint Conference on Artificial Intelligence* and the *Conference on Computer Vision and Pattern Recognition*.
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/acl.sty" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/acl.sty"
new file mode 100644
index 0000000000..c494e0a838
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/acl.sty"
@@ -0,0 +1,321 @@
+% This is the LaTeX style file for *ACL.
+% The official sources can be found at
+%
+% https://github.com/acl-org/acl-style-files/
+%
+% This package is activated by adding
+%
+% \usepackage{acl}
+%
+% to your LaTeX file. When submitting your paper for review, add the "review" option:
+%
+% \usepackage[review]{acl}
+
+\newif\ifacl@finalcopy
+\newif\ifacl@anonymize
+\newif\ifacl@linenumbers
+\newif\ifacl@pagenumbers
+\DeclareOption{final}{\acl@finalcopytrue\acl@anonymizefalse\acl@linenumbersfalse\acl@pagenumbersfalse}
+\DeclareOption{review}{\acl@finalcopyfalse\acl@anonymizetrue\acl@linenumberstrue\acl@pagenumberstrue}
+\DeclareOption{preprint}{\acl@finalcopytrue\acl@anonymizefalse\acl@linenumbersfalse\acl@pagenumberstrue}
+\ExecuteOptions{final} % final copy is the default
+
+% include hyperref, unless user specifies nohyperref option like this:
+% \usepackage[nohyperref]{acl}
+\newif\ifacl@hyperref
+\DeclareOption{hyperref}{\acl@hyperreftrue}
+\DeclareOption{nohyperref}{\acl@hyperreffalse}
+\ExecuteOptions{hyperref} % default is to use hyperref
+\ProcessOptions\relax
+
+\typeout{Conference Style for ACL}
+
+\usepackage{xcolor}
+
+\ifacl@linenumbers
+ % Add draft line numbering via the lineno package
+ % https://texblog.org/2012/02/08/adding-line-numbers-to-documents/
+ \usepackage[switch,mathlines]{lineno}
+
+ % Line numbers in gray Helvetica 8pt
+ \font\aclhv = phvb at 8pt
+ \renewcommand\linenumberfont{\aclhv\color{lightgray}}
+
+ % Zero-fill line numbers
+ % NUMBER with left flushed zeros \fillzeros[]
+ \newcount\cv@tmpc@ \newcount\cv@tmpc
+ \def\fillzeros[#1]#2{\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi
+ \cv@tmpc=1 %
+ \loop\ifnum\cv@tmpc@<10 \else \divide\cv@tmpc@ by 10 \advance\cv@tmpc by 1 \fi
+ \ifnum\cv@tmpc@=10\relax\cv@tmpc@=11\relax\fi \ifnum\cv@tmpc@>10 \repeat
+ \ifnum#2<0\advance\cv@tmpc1\relax-\fi
+ \loop\ifnum\cv@tmpc<#1\relax0\advance\cv@tmpc1\relax\fi \ifnum\cv@tmpc<#1 \repeat
+ \cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi \relax\the\cv@tmpc@}%
+ \renewcommand\thelinenumber{\fillzeros[3]{\arabic{linenumber}}}
+ \linenumbers
+
+ \setlength{\linenumbersep}{1.6cm}
+
+ % Bug: An equation with $$ ... $$ isn't numbered, nor is the previous line.
+
+ % Patch amsmath commands so that the previous line and the equation itself
+ % are numbered. Bug: multline has an extra line number.
+ % https://tex.stackexchange.com/questions/461186/how-to-use-lineno-with-amsmath-align
+ \usepackage{etoolbox} %% <- for \pretocmd, \apptocmd and \patchcmd
+
+ \newcommand*\linenomathpatch[1]{%
+ \expandafter\pretocmd\csname #1\endcsname {\linenomath}{}{}%
+ \expandafter\pretocmd\csname #1*\endcsname {\linenomath}{}{}%
+ \expandafter\apptocmd\csname end#1\endcsname {\endlinenomath}{}{}%
+ \expandafter\apptocmd\csname end#1*\endcsname {\endlinenomath}{}{}%
+ }
+ \newcommand*\linenomathpatchAMS[1]{%
+ \expandafter\pretocmd\csname #1\endcsname {\linenomathAMS}{}{}%
+ \expandafter\pretocmd\csname #1*\endcsname {\linenomathAMS}{}{}%
+ \expandafter\apptocmd\csname end#1\endcsname {\endlinenomath}{}{}%
+ \expandafter\apptocmd\csname end#1*\endcsname {\endlinenomath}{}{}%
+ }
+
+ %% Definition of \linenomathAMS depends on whether the mathlines option is provided
+ \expandafter\ifx\linenomath\linenomathWithnumbers
+ \let\linenomathAMS\linenomathWithnumbers
+ %% The following line gets rid of an extra line numbers at the bottom:
+ \patchcmd\linenomathAMS{\advance\postdisplaypenalty\linenopenalty}{}{}{}
+ \else
+ \let\linenomathAMS\linenomathNonumbers
+ \fi
+
+ \AtBeginDocument{%
+ \linenomathpatch{equation}%
+ \linenomathpatchAMS{gather}%
+ \linenomathpatchAMS{multline}%
+ \linenomathpatchAMS{align}%
+ \linenomathpatchAMS{alignat}%
+ \linenomathpatchAMS{flalign}%
+ }
+\else
+ % Hack to ignore these commands, which review mode puts into the .aux file.
+ \newcommand{\@LN@col}[1]{}
+ \newcommand{\@LN}[2]{}
+\fi
+
+\iffalse
+\PassOptionsToPackage{
+ a4paper,
+ top=2.21573cm,left=2.54cm,
+ textheight=704.60031pt, % 51 * \baselineskip + \topskip
+ textwidth=16.0cm,
+ headheight=0.17573cm,headsep=0cm
+}{geometry}
+\fi
+\PassOptionsToPackage{a4paper,margin=2.5cm,heightrounded=true}{geometry}
+\RequirePackage{geometry}
+
+\setlength\columnsep{0.6cm}
+\newlength\titlebox
+\setlength\titlebox{11\baselineskip}
+% \titlebox should be a multiple of \baselineskip so that
+% column height remaining fits an exact number of lines of text
+
+\flushbottom \twocolumn \sloppy
+
+% We're never going to need a table of contents, so just flush it to
+% save space --- suggested by drstrip@sandia-2
+\def\addcontentsline#1#2#3{}
+
+\ifacl@pagenumbers
+ \pagenumbering{arabic}
+\else
+ \thispagestyle{empty}
+ \pagestyle{empty}
+\fi
+
+%% Title and Authors %%
+
+\let\Thanks\thanks % \Thanks and \thanks used to be different, but keep this for backwards compatibility.
+
+\newcommand\outauthor{%
+ \begin{tabular}[t]{c}
+ \ifacl@anonymize
+ \bf Anonymous ACL submission
+ \else
+ \bf\@author
+ \fi
+ \end{tabular}}
+
+% Mostly taken from deproc.
+\AtBeginDocument{
+\def\maketitle{\par
+ \begingroup
+ \def\thefootnote{\fnsymbol{footnote}}
+ \twocolumn[\@maketitle] \@thanks
+ \endgroup
+ \setcounter{footnote}{0}
+ \let\maketitle\relax \let\@maketitle\relax
+ \gdef\@thanks{}\gdef\@author{}\gdef\@title{}\let\thanks\relax}
+\def\@maketitle{\vbox to \titlebox{\hsize\textwidth
+ \linewidth\hsize \vskip 0.125in minus 0.125in \centering
+ {\Large\bf \@title \par} \vskip 0.2in plus 1fil minus 0.1in
+ {\def\and{\unskip\enspace{\rm and}\enspace}%
+ \def\And{\end{tabular}\hss \egroup \hskip 1in plus 2fil
+ \hbox to 0pt\bgroup\hss \begin{tabular}[t]{c}\bf}%
+ \def\AND{\end{tabular}\hss\egroup \hfil\hfil\egroup
+ \vskip 0.25in plus 1fil minus 0.125in
+ \hbox to \linewidth\bgroup\large \hfil\hfil
+ \hbox to 0pt\bgroup\hss \begin{tabular}[t]{c}\bf}
+ \hbox to \linewidth\bgroup\large \hfil\hfil
+ \hbox to 0pt\bgroup\hss
+ \outauthor
+ \hss\egroup
+ \hfil\hfil\egroup}
+ \vskip 0.3in plus 2fil minus 0.1in
+}}
+}
+
+% margins and font size for abstract
+\renewenvironment{abstract}%
+ {\centerline{\large\bf Abstract}%
+ \begin{list}{}%
+ {\setlength{\rightmargin}{0.6cm}%
+ \setlength{\leftmargin}{0.6cm}}%
+ \item[]\ignorespaces%
+ \@setsize\normalsize{12pt}\xpt\@xpt
+ }%
+ {\unskip\end{list}}
+
+%\renewenvironment{abstract}{\centerline{\large\bf
+% Abstract}\vspace{0.5ex}\begin{quote}}{\par\end{quote}\vskip 1ex}
+
+% Resizing figure and table captions - SL
+% Support for interacting with the caption, subfigure, and subcaption packages - SL
+\RequirePackage{caption}
+\DeclareCaptionFont{10pt}{\fontsize{10pt}{12pt}\selectfont}
+\captionsetup{font=10pt}
+
+\RequirePackage{natbib}
+% for citation commands in the .tex, authors can use:
+% \citep, \citet, and \citeyearpar for compatibility with natbib, or
+% \cite, \newcite, and \shortcite for compatibility with older ACL .sty files
+\renewcommand\cite{\citep} % to get "(Author Year)" with natbib
+\newcommand\shortcite{\citeyearpar}% to get "(Year)" with natbib
+\newcommand\newcite{\citet} % to get "Author (Year)" with natbib
+\newcommand{\citeposs}[1]{\citeauthor{#1}'s (\citeyear{#1})} % to get "Author's (Year)"
+
+\bibliographystyle{acl_natbib}
+
+% Bibliography
+
+% Don't put a label in the bibliography at all. Just use the unlabeled format
+% instead.
+\def\thebibliography#1{\vskip\parskip%
+\vskip\baselineskip%
+\def\baselinestretch{1}%
+\ifx\@currsize\normalsize\@normalsize\else\@currsize\fi%
+\vskip-\parskip%
+\vskip-\baselineskip%
+\section*{References\@mkboth
+ {References}{References}}\list
+ {}{\setlength{\labelwidth}{0pt}\setlength{\leftmargin}{\parindent}
+ \setlength{\itemindent}{-\parindent}}
+ \def\newblock{\hskip .11em plus .33em minus -.07em}
+ \sloppy\clubpenalty4000\widowpenalty4000
+ \sfcode`\.=1000\relax}
+\let\endthebibliography=\endlist
+
+
+% Allow for a bibliography of sources of attested examples
+\def\thesourcebibliography#1{\vskip\parskip%
+\vskip\baselineskip%
+\def\baselinestretch{1}%
+\ifx\@currsize\normalsize\@normalsize\else\@currsize\fi%
+\vskip-\parskip%
+\vskip-\baselineskip%
+\section*{Sources of Attested Examples\@mkboth
+ {Sources of Attested Examples}{Sources of Attested Examples}}\list
+ {}{\setlength{\labelwidth}{0pt}\setlength{\leftmargin}{\parindent}
+ \setlength{\itemindent}{-\parindent}}
+ \def\newblock{\hskip .11em plus .33em minus -.07em}
+ \sloppy\clubpenalty4000\widowpenalty4000
+ \sfcode`\.=1000\relax}
+\let\endthesourcebibliography=\endlist
+
+% sections with less space
+\def\section{\@startsection {section}{1}{\z@}{-2.0ex plus
+ -0.5ex minus -.2ex}{1.5ex plus 0.3ex minus .2ex}{\large\bf\raggedright}}
+\def\subsection{\@startsection{subsection}{2}{\z@}{-1.8ex plus
+ -0.5ex minus -.2ex}{0.8ex plus .2ex}{\normalsize\bf\raggedright}}
+%% changed by KO to - values to get the initial parindent right
+\def\subsubsection{\@startsection{subsubsection}{3}{\z@}{-1.5ex plus
+ -0.5ex minus -.2ex}{0.5ex plus .2ex}{\normalsize\bf\raggedright}}
+\def\paragraph{\@startsection{paragraph}{4}{\z@}{1.5ex plus
+ 0.5ex minus .2ex}{-1em}{\normalsize\bf}}
+\def\subparagraph{\@startsection{subparagraph}{5}{\parindent}{1.5ex plus
+ 0.5ex minus .2ex}{-1em}{\normalsize\bf}}
+
+% Footnotes
+\footnotesep 6.65pt %
+\skip\footins 9pt plus 4pt minus 2pt
+\def\footnoterule{\kern-3pt \hrule width 5pc \kern 2.6pt }
+\setcounter{footnote}{0}
+
+% Lists and paragraphs
+\parindent 1em
+\topsep 4pt plus 1pt minus 2pt
+\partopsep 1pt plus 0.5pt minus 0.5pt
+\itemsep 2pt plus 1pt minus 0.5pt
+\parsep 2pt plus 1pt minus 0.5pt
+
+\leftmargin 2em \leftmargini\leftmargin \leftmarginii 2em
+\leftmarginiii 1.5em \leftmarginiv 1.0em \leftmarginv .5em \leftmarginvi .5em
+\labelwidth\leftmargini\advance\labelwidth-\labelsep \labelsep 5pt
+
+\def\@listi{\leftmargin\leftmargini}
+\def\@listii{\leftmargin\leftmarginii
+ \labelwidth\leftmarginii\advance\labelwidth-\labelsep
+ \topsep 2pt plus 1pt minus 0.5pt
+ \parsep 1pt plus 0.5pt minus 0.5pt
+ \itemsep \parsep}
+\def\@listiii{\leftmargin\leftmarginiii
+ \labelwidth\leftmarginiii\advance\labelwidth-\labelsep
+ \topsep 1pt plus 0.5pt minus 0.5pt
+ \parsep \z@ \partopsep 0.5pt plus 0pt minus 0.5pt
+ \itemsep \topsep}
+\def\@listiv{\leftmargin\leftmarginiv
+ \labelwidth\leftmarginiv\advance\labelwidth-\labelsep}
+\def\@listv{\leftmargin\leftmarginv
+ \labelwidth\leftmarginv\advance\labelwidth-\labelsep}
+\def\@listvi{\leftmargin\leftmarginvi
+ \labelwidth\leftmarginvi\advance\labelwidth-\labelsep}
+
+\abovedisplayskip 7pt plus2pt minus5pt%
+\belowdisplayskip \abovedisplayskip
+\abovedisplayshortskip 0pt plus3pt%
+\belowdisplayshortskip 4pt plus3pt minus3pt%
+
+% Less leading in most fonts (due to the narrow columns)
+% The choices were between 1-pt and 1.5-pt leading
+\def\@normalsize{\@setsize\normalsize{11pt}\xpt\@xpt}
+\def\small{\@setsize\small{10pt}\ixpt\@ixpt}
+\def\footnotesize{\@setsize\footnotesize{10pt}\ixpt\@ixpt}
+\def\scriptsize{\@setsize\scriptsize{8pt}\viipt\@viipt}
+\def\tiny{\@setsize\tiny{7pt}\vipt\@vipt}
+\def\large{\@setsize\large{14pt}\xiipt\@xiipt}
+\def\Large{\@setsize\Large{16pt}\xivpt\@xivpt}
+\def\LARGE{\@setsize\LARGE{20pt}\xviipt\@xviipt}
+\def\huge{\@setsize\huge{23pt}\xxpt\@xxpt}
+\def\Huge{\@setsize\Huge{28pt}\xxvpt\@xxvpt}
+
+% The hyperref manual (section 9) says hyperref should be loaded after natbib
+\ifacl@hyperref
+ \PassOptionsToPackage{breaklinks}{hyperref}
+ \RequirePackage{hyperref}
+ % make links dark blue
+ \definecolor{darkblue}{rgb}{0, 0, 0.5}
+ \hypersetup{colorlinks=true, citecolor=darkblue, linkcolor=darkblue, urlcolor=darkblue}
+\else
+ % This definition is used if the hyperref package is not loaded.
+  % It provides a backup, no-op definition of \href.
+  % This is necessary because the \href command is used in the acl_natbib.bst file.
+ \def\href#1#2{{#2}}
+ \usepackage{url}
+\fi
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/acl_latex.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/acl_latex.tex"
new file mode 100644
index 0000000000..c3e30e78d0
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/acl_latex.tex"
@@ -0,0 +1,662 @@
+% This must be in the first 5 lines to tell arXiv to use pdfLaTeX, which is strongly recommended.
+\pdfoutput=1
+% In particular, the hyperref package requires pdfLaTeX in order to break URLs across lines.
+
+\documentclass[11pt]{article}
+
+% Change "review" to "final" to generate the final (sometimes called camera-ready) version.
+% Change to "preprint" to generate a non-anonymous version with page numbers.
+\usepackage[final]{acl}
+
+% Standard package includes
+\usepackage{times}
+\usepackage{latexsym}
+% \usepackage{algorithmic}
+\usepackage{graphicx}
+\usepackage{textcomp}
+\usepackage{xcolor}
+\usepackage{physics}
+\usepackage{mathdots}
+% \usepackage{algorithm}
+\usepackage{subfigure}
+\usepackage{pgfplots}
+\usepackage{tikz}
+\usetikzlibrary{patterns}
+\usepackage{multirow}
+\usepackage{multicol}
+% \usepackage{authblk}
+\usepackage{array}
+\pgfplotsset{compat=1.18}
+
+\usepackage[ruled,linesnumbered]{algorithm2e}
+\usepackage{booktabs}
+
+% Add this line here (so \Bbbk does not clash when amssymb is loaded)
+\let\Bbbk\relax
+\usepackage{amsmath,amssymb,amsfonts}
+\DeclareMathOperator*{\argmax}{arg\,max}
+
+% For proper rendering and hyphenation of words containing Latin characters (including in bib files)
+\usepackage[T1]{fontenc}
+% For Vietnamese characters
+% \usepackage[T5]{fontenc}
+% See https://www.latex-project.org/help/documentation/encguide.pdf for other character sets
+
+% This assumes your files are encoded as UTF8
+\usepackage[utf8]{inputenc}
+
+% This is not strictly necessary, and may be commented out,
+% but it will improve the layout of the manuscript,
+% and will typically save some space.
+\usepackage{microtype}
+
+% This is also not strictly necessary, and may be commented out.
+% However, it will improve the aesthetics of text in
+% the typewriter font.
+\usepackage{inconsolata}
+
+%Including images in your LaTeX document requires adding
+%additional package(s)
+\usepackage{graphicx}
+
+% If the title and author information does not fit in the area allocated, uncomment the following
+%
+%\setlength\titlebox{}
+%
+% and set to something 5cm or larger.
+
+\title{DIDS: Domain Impact-aware Data Sampling for \\ Large Language Model Training}
+
+% Author information can be set in various styles:
+% For several authors from the same institution:
+% \author{Author 1 \and ... \and Author n \\
+% Address line \\ ... \\ Address line}
+% if the names do not fit well on one line use
+% Author 1 \\ {\bf Author 2} \\ ... \\ {\bf Author n} \\
+% For authors from different institutions:
+% \author{Author 1 \\ Address line \\ ... \\ Address line
+% \And ... \And
+% Author n \\ Address line \\ ... \\ Address line}
+% To start a separate ``row'' of authors use \AND, as in
+% \author{Author 1 \\ Address line \\ ... \\ Address line
+% \AND
+% Author 2 \\ Address line \\ ... \\ Address line \And
+% Author 3 \\ Address line \\ ... \\ Address line}
+
+\author{
+ \textbf{Weijie Shi\textsuperscript{1}\thanks{Equal contribution}\thanks{\small{
+ \textbf{Email:} \href{mailto:wshiah@connect.ust.hk}{wshiah@connect.ust.hk}
+ }}},
+ \textbf{Jipeng Zhang\textsuperscript{1}\footnotemark[1]},
+ \textbf{Yaguang Wu\textsuperscript{2}},
+ \textbf{Jingzhi Fang\textsuperscript{1}},
+\\
+ \textbf{Shibo Zhang\textsuperscript{2}},
+ \textbf{Yao Zhao\textsuperscript{3}},
+ \textbf{Hao Chen\textsuperscript{1}},
+ \textbf{Ruiyuan Zhang\textsuperscript{1}},
+\\
+ \textbf{Yue Cui\textsuperscript{1}},
+ \textbf{Jia Zhu\textsuperscript{4}},
+ \textbf{Sirui Han\textsuperscript{1}\thanks{Corresponding authors}},
+ \textbf{Jiajie Xu\textsuperscript{5}}, \textbf{Xiaofang Zhou\textsuperscript{1}\footnotemark[3]}
+\\
+\\
+ \textsuperscript{1}The Hong Kong University of Science and Technology, \textsuperscript{2}MetaX, \\
+ \textsuperscript{3}Alibaba Group,
+ \textsuperscript{4}Zhejiang Normal University, \textsuperscript{5}Soochow University
+}
+
+
+\begin{document}
+\maketitle
+\begin{abstract}
+Large language models (LLMs) are commonly trained on multi-domain datasets, where domain sampling strategies significantly impact model performance due to varying domain importance across downstream tasks. Existing approaches for optimizing domain-level sampling strategies struggle with maintaining intra-domain consistency and accurately measuring domain impact. In this paper, we present Domain Impact-aware Data Sampling (DIDS). To ensure intra-domain consistency, a gradient clustering algorithm is proposed to group training data based on their learning effects, where a proxy language model and dimensionality reduction are employed to reduce computational overhead. To accurately measure domain impact, we develop a Fisher Information Matrix (FIM) guided metric that quantifies how domain-specific parameter updates affect the model's output distributions on downstream tasks, with theoretical guarantees. Furthermore, to determine optimal sampling ratios, DIDS combines both the FIM-guided domain impact assessment and loss learning trajectories that indicate domain-specific potential, while accounting for diminishing marginal returns. Extensive experiments demonstrate that DIDS achieves 3.4\% higher average performance while maintaining comparable training efficiency. The code is available at \url{https://github.com/shiweijiezero/DIDS}.
+\end{abstract}
+
+\section{Introduction}
+Large language models (LLMs) have demonstrated remarkable capabilities across diverse tasks through training on massive multi-domain datasets, enabling robust generalization and adaptation abilities \cite{weber2024redpajama,biderman2023pythia,chen2024large,azaria2024chat,zhang2024effective}. While the composition of training data (e.g., code, scientific papers, web text) significantly shapes model performance, the relative importance of these domains varies substantially across target applications. On the one hand, some data domains contribute positively to model performance, whereas others may even impair effectiveness and waste computational resources \cite{xia2024less,zhou2024lima}. On the other hand, each domain's contribution to model learning evolves dynamically throughout the training process \cite{luo2024velocitune,kang2024autoscale}. This necessitates an approach for optimizing domain-level data sampling strategies during LLM training to maximize performance across downstream tasks while maintaining training efficiency. Unfortunately, designing such an algorithm presents several crucial challenges.
+
+\textbf{Intra-domain Consistency.} A fundamental prerequisite for effective domain-level sampling strategies is maintaining data consistency within each domain. Existing approaches either rely on data source categorization \cite{xie2024doremi,fandoge} or employ BERT semantic clustering \cite{fan2024dynamic}. However, these methods often fail to ensure that data within each domain has similar training effects, which is crucial for making domain-level sampling strategies meaningful. To address this limitation, gradient information serves as a more direct measure of training impact. Gradients inherently capture how each data point influences model parameters during training, enabling us to group samples based on their learning effects rather than superficial characteristics.
+
+\textbf{Domain Impact and Mixing Strategy.} The next key challenge lies in accurately measuring each domain's impact on downstream tasks throughout the dynamic training process. Unfortunately, existing grid search methods \cite{ye2024data,liu2024regmix,mckinzie2025mm1} are computationally intensive and cannot adapt to the dynamics of domain importance during training, while gradient similarity approaches \cite{fandoge,fan2024dynamic} only measure the instantaneous parameter update direction alignment without considering how these updates actually affect the model's predictive behavior on downstream tasks. To quantify such influence in a principled way, a natural objective is minimizing the output distributional discrepancy between how different domains' updates shift the model's predictions. Beyond measuring impact, determining optimal sampling ratios requires balancing computational resources across all downstream tasks while considering the marginal utility of domain data, as domain-specific capabilities may saturate with diminishing returns over time.
+
+In this paper, we propose \textbf{D}omain \textbf{I}mpact-aware \textbf{D}ata \textbf{S}ampling (DIDS), which dynamically optimizes domain-level sampling probabilities by measuring each domain's impact on the model's predictive behavior. To ensure intra-domain consistency, a gradient clustering algorithm is proposed to group training data, where a small proxy language model is employed instead of the full-size model to reduce computational cost, followed by gradient norm-based subsampling and Johnson-Lindenstrauss random projection for dimensionality reduction. To accurately measure domain impact, a Fisher Information Matrix (FIM) guided metric is developed to quantify the output distributional shift based on the second-order Taylor approximation of KL divergence, enabling efficient assessment of how each domain affects the model's predictive behavior on downstream tasks. We also provide theoretical foundations for the FIM-guided metric. To determine domain sampling proportions, weights are computed by combining both the FIM-guided domain impact on downstream tasks and their loss improvement trajectories indicating learning potential. Extensive experiments on Llama-3.1 across 9 downstream tasks demonstrate that DIDS achieves 3.4\% higher average performance. Our contributions are summarized as follows:
+\begin{itemize}
+ \item We present a gradient-based data clustering that leverages proxy models and dimensionality reduction to group training samples, ensuring intra-domain training consistency.
+    \item We propose a FIM-guided impact metric that measures how domain-specific parameter updates shift the model's output distributions on downstream tasks, enabling accurate assessment of domain importance during training with theoretical foundations.
+ \item We design DIDS, a domain sampling framework that dynamically adjusts mixing ratios by combining domain impact with learning trajectories, accounting for diminishing marginal returns of domain-specific performance.
+\end{itemize}
+
+\section{Related Work}
+\subsection{Instance-level Data Sampling}
+Instance-level data sampling approaches for language model training primarily focus on selecting high-quality training samples that maximize model performance. LIMA \cite{zhou2024lima} demonstrates that a small set of 1,000 carefully curated prompts can achieve strong performance, \citet{ge2024clustering} ensures both quality and diversity through BERT-based scoring and clustering, and DEITA \cite{liu2023makes} further considers instruction complexity as assessed by ChatGPT. Moreover, to align with the requirements of specific downstream tasks, DSIR \cite{xie2023data} utilizes N-gram feature-based importance resampling, while LESS \cite{xia2024less} and TracIn \cite{pruthi2020estimating} leverage gradient-based methods to identify influential training samples through gradient alignment and descent tracing. However, these approaches either lack downstream task awareness or are computationally expensive, motivating domain-level sampling strategies.
+
+\input{latex/figure/architecture}
+
+\subsection{Domain-level Data Sampling}
+Domain-level data sampling strategies can be categorized into static and online methods. Static methods determine fixed sampling ratios using proxy models before full-scale training begins. MM1 \cite{mckinzie2025mm1} employs grid search to evaluate different sampling ratios empirically, while Mixing Laws \cite{ye2024data} extends this by proposing scaling law formulas to model the relationship between mixing ratios and model performance. REGMIX \cite{liu2024regmix} introduces regression models to predict this scaling curve. Moreover, Doremi \cite{xie2024doremi} incorporates reference models to consider excess loss, and Doge \cite{fandoge} utilizes gradient alignment between training and validation sets. However, AUTOSCALE \cite{kang2024autoscale} reveals that optimal mixing ratios derived from proxy models may not transfer effectively to larger models. Thus, online methods directly adjust sampling ratios throughout the training process. DGA \cite{fan2024dynamic} extends Doge's gradient-based approach to online scenarios, while Velocitune \cite{luo2024velocitune} monitors learning velocity to adaptively adjust domain proportions.
+
+Moreover, DRPruning \cite{deng2024drpruning} employs distributionally robust optimization to iteratively shift data distribution toward underperforming domains during training, ensuring balanced recovery across all areas rather than allowing some domains to lag behind after model pruning. It shares our motivation for adaptive domain reweighting but focuses specifically on post-pruning recovery scenarios. DDK \cite{liu2024ddk} computes perplexity ratios between teacher and student models across domains and uses factor-smooth updating mechanisms to periodically adjust sampling probabilities. DDK allocates more training data to domains where the student model underperforms relative to the teacher, thereby reducing performance gaps during knowledge distillation.
+
+Yet existing methods either rely on gradient similarity alone without capturing downstream impact, or depend on computationally expensive techniques such as scaling-law fitting, limiting their practicality. This motivates our efficient, theoretically grounded approach to dynamic domain-level sampling.
+
+
+\section{Problem Formulation}
+In this part, we formalize the problem of optimizing domain-level sampling strategy for LLM training.
+
+Let $\mathcal{D}=\{D_1,...,D_k\}$ denote a training dataset comprising $k$ disjoint domains and $\mathcal{S} = \{S_1, ..., S_m\}$ denote a collection of downstream tasks. Given a large language model $f_\theta$ parameterized by $\theta$ and a computational budget of $n$ training instances, our goal is to optimize the model's performance across all tasks by adjusting the sampling probabilities across different domains during parameter training.
+
+We characterize the domain sampling strategy through a probability vector $\mathbf{p}_t = [p_{t,1}, ..., p_{t,k}]$ at each training step $t$, where $p_{t,i}$ represents the sampling probability from domain $D_i$ subject to the $(k-1)$-dimensional probability simplex $\Pi^{k-1}$:
+\begin{equation}
+    \mathbf{p}_t \in \Pi^{k-1} = \{\mathbf{p} \in \mathbb{R}^k \mid p_{i} \geq 0, \sum_{i=1}^k p_{i} = 1\}
+\end{equation}
+
+The objective of the training process follows a bi-level optimization framework to optimize both model parameters $\theta$ and sampling probabilities $\mathbf{p}$:
+\begin{equation}
+ \max_{\theta, \mathbf{p} \in \Pi^{k-1}} \sum_{j=1}^m \text{Acc}_j(f_\theta; S_j)
+\end{equation}
+where $\text{Acc}_j(f_\theta; S_j)$ measures the model’s accuracy on downstream task $S_j$.
+
+To update the model parameters, we perform standard gradient descent:
+\begin{equation}
+ \theta_{t+1} = \theta_t - \eta \nabla \ell(\theta_t, \mathcal{B}_t), \quad \mathcal{B}_t \sim \mathbf{p_t}
+\end{equation}
+where $\mathcal{B}_t$ denotes a mini-batch sampled according to the domain sampling probabilities $\mathbf{p}_t$, $\eta$ denotes the learning rate, and $\nabla \ell$ computes the loss gradients with respect to the model parameters.
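+
+For illustration, a minimal Python sketch of this sampling step is given below; the data structures and names are ours and not part of any released implementation.
+
+\begin{verbatim}
+import numpy as np
+
+def sample_batch(domains, p_t, batch_size, rng=np.random.default_rng(0)):
+    """Draw a mini-batch B_t: sample a domain per example according to
+    p_t, then draw an example uniformly from that domain."""
+    domain_ids = rng.choice(len(domains), size=batch_size, p=p_t)
+    return [domains[d][rng.integers(len(domains[d]))] for d in domain_ids]
+
+# Toy usage with three dummy domains and p_t = [0.5, 0.3, 0.2].
+domains = [["a1", "a2"], ["b1"], ["c1", "c2", "c3"]]
+batch = sample_batch(domains, np.array([0.5, 0.3, 0.2]), batch_size=8)
+\end{verbatim}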
+
+To update the domain sampling probabilities, we periodically adjust the sampling distribution every $\tau$ steps to optimize the expected model performance across all downstream tasks:
+\begin{equation}
+ \mathbf{p}_{t} = \argmax_{\mathbf{p} \in \Pi^{k-1}} \sum_{j=1}^m \text{Acc}_j(f_{\theta_{t+\tau}}; S_j)
+\end{equation}
+where $\theta_{t+\tau}$ represents the model parameters after $\tau$ steps of training using sampling distribution $\mathbf{p}_{t}$.
+
+
+\section{Methodology}
+
+\subsection{Gradient-based Domain Repartition}
+Effective domain-level sampling strategies require consistent training behavior within each domain. Traditional approaches to domain partitioning typically rely on superficial characteristics, such as data sources or semantic similarity measured through BERT embeddings. However, these methods often fail to capture how different training samples actually influence model learning. For instance, mathematical proofs and programming implementations, despite being traditionally categorized into different domains, often induce similar gradient patterns during training due to their shared logical reasoning nature. Conversely, two web documents from the same domain might trigger drastically different parameter updates. To better organize the training data, we therefore repartition domains based on gradients, which directly reflect parameter update behavior.
+
+Unfortunately, computing and clustering gradients using a full-size LLM for all samples would be computationally prohibitive. A small proxy model maintaining the same architecture but with reduced width and depth serves as an efficient alternative. For each training sample $x_i \in \mathcal{D}$, gradient computation yields the vector $g_i = \nabla \ell(\theta', x_i)$, where $\theta'$ denotes the proxy model parameters. To accelerate computation, we only keep the gradients of the final 10\% of layers. To make clustering computationally feasible, gradient norm-based subsampling retains only the top-$k$ elements with the largest magnitudes in each gradient vector. Next, dimensionality reduction is performed via Johnson-Lindenstrauss random projection \cite{park2023trak} to compress the gradient vectors from parameter-scale dimensionality (millions-level) to a clustering-manageable dimension (thousands-level):
+\begin{equation}
+\tilde{g}_i = R^T g_i, \quad R \in \mathbb{R}^{h \times s}
+\end{equation}
+where $h$ represents the original dimension and $s$ denotes the target dimension satisfying $s \ll h$. The random projection matrix $R$ is initialized as a random matrix with approximately orthogonal, unit-norm columns. The detailed Johnson-Lindenstrauss theorem and initialization method are provided in Appendix A.
+
+Building upon initial semantic categorization, k-means clustering on these reduced gradient vectors refines each domain, where the number of clusters serves as a hyperparameter. The resulting domains are denoted as $\mathcal{D} = \{D_1, ..., D_k\}$, where $k$ represents the total number of domains.
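+
+As a rough illustration, the following Python sketch outlines this repartitioning pipeline. It is a simplification of our own: it assumes the per-sample proxy-model gradients are already flattened into a matrix, uses a Gaussian projection rather than the initialization described in Appendix A, and ignores the initial semantic categorization.
+
+\begin{verbatim}
+import numpy as np
+from sklearn.cluster import KMeans
+
+def repartition_domains(grads, top_k, proj_dim=1024, n_domains=72, seed=0):
+    """grads: (n_samples, h) proxy-model gradients of the last layers.
+    Keep the top_k largest-magnitude entries per gradient, project to
+    proj_dim dimensions, and cluster the projections with k-means."""
+    rng = np.random.default_rng(seed)
+    n, h = grads.shape
+    sparse = np.zeros_like(grads)
+    for i in range(n):
+        idx = np.argpartition(np.abs(grads[i]), -top_k)[-top_k:]
+        sparse[i, idx] = grads[i, idx]
+    R = rng.standard_normal((h, proj_dim)) / np.sqrt(proj_dim)  # JL-style map
+    reduced = sparse @ R                              # \tilde{g}_i = R^T g_i
+    return KMeans(n_clusters=n_domains, random_state=seed,
+                  n_init=10).fit_predict(reduced)
+\end{verbatim}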
+
+\subsection{FIM-guided Domain Impact}
+After establishing consistent domain partitions, a key challenge is accurately measuring how each domain's training data impacts model performance on downstream tasks. Existing approaches either rely on computationally expensive grid search methods that cannot adapt to dynamic training processes, or use gradient similarity metrics. For example, DGA \cite{fan2024dynamic} measures the domain impact on specific downstream tasks as:
+\begin{equation}
+ I(D_i,\!S_j)\! =\! \mathbb{E}_{x_i\sim D_i, x_j\sim S_j} [\!\langle \nabla \!\ell(\theta^t \!, x_i) ,\! \nabla \! \ell(\theta^t \!, x_j)\rangle]
+\end{equation}
+where $I(D_i,S_j)$ measures the impact of training domain $D_i$ on downstream task $S_j$, $\langle\nabla\ell(\theta^t, x_i), \nabla\ell(\theta^t, x_j)\rangle$ represents the inner product of gradients. However, they only capture instantaneous parameter update directions without considering their actual effects on model behavior. We need a more principled approach that can efficiently quantify how domain-specific parameter updates influence the model's predictive distributions on target tasks.
+
+To this end, we propose a Fisher Information Matrix (FIM) guided metric that quantifies the output distributional changes induced by domain-specific data. The core insight is that the Kullback-Leibler (KL) divergence between the original and updated model predictions provides a natural measure of how parameter updates affect model behavior. Due to the intractable nature of direct KL divergence computation in infinite input spaces, here we employ a second-order Taylor approximation.
+
+For notational simplicity, let $p(y|\theta)$ be denoted as $p(\theta)$, $\theta_{D_i} = \theta + \nabla \ell_{D_i}$ and $\theta_{S_j} = \theta + \nabla \ell_{S_j}$ represent the parameters after updates from domain $D_i$ and task $S_j$ respectively, and $\Delta = \nabla \ell_{S_j}-\nabla \ell_{D_i}$ represent the gradient difference between downstream task updates and training domain updates. Formally, we define the domain impact metric as:
+\begin{equation}
+\begin{split}
+ I(D_i,S_j) &= \text{KL}[p(\theta_{D_i}) \parallel p(\theta_{S_j})] \\
+ &= \int p(\theta_{D_i})\log \frac{p(\theta_{D_i})}{p(\theta_{S_j})}dy \\
+ &= \mathbb{E}_{p(\theta_{D_i})}[\log p(\theta_{D_i})]\!\! -\! \! \mathbb{E}_{p(\theta_{D_i})}\![\log p(\theta_{S_j})]
+\end{split}
+\end{equation}
+When the gradient updates are small (i.e., $\nabla \ell_{D_i} \approx \nabla \ell_{S_j} \approx 0$), we can approximate it using a second-order Taylor expansion around $\theta_{D_i}$:
+\begin{equation}
+\begin{split}
+ &\text{KL}[p(\theta_{D_i}) \!\! \parallel \! p(\theta_{S_j})] \! \approx \mathbb{E}_{p(\theta_{D_i})}[\log p(\theta_{D_i})] \!-\! \mathbb{E}_{p(\theta_{D_i})}[ \\
+ &\quad \log p(\theta_{D_i}) \!+\! \nabla\!\log p(\theta_{D_i})\Delta \!+\! \frac{1}{2}\!\Delta\!^T\nabla^2\!\log p(\theta_{D_i})\Delta] \\
+ &= -\mathbb{E}_{p(\theta_{D_i})}[\nabla\log p(\theta_{D_i})\Delta] \\
+ &\quad - \mathbb{E}_{p(\theta_{D_i})}\left[\frac{1}{2}\Delta^T\nabla^2\log p(\theta_{D_i})\Delta\right]
+\end{split}
+\end{equation}
+The first term can be simplified through integration-differentiation interchange:
+\begin{equation}
+\begin{split}
+    \mathbb{E}_{p(\theta_{D_i})}[\nabla\log p(\theta_{D_i})\Delta] &= \int \frac{\nabla p(\theta_{D_i})}{p(\theta_{D_i})}\, p(\theta_{D_i})\,\Delta\, dy \\
+    &=\! \nabla\!\int p(\theta_{D_i})\, dy \cdot \Delta \\
+    &=\! \nabla(1) \cdot \Delta = 0
+\end{split}
+\end{equation}
+For the second term, the expected Hessian of the log-likelihood equals the negative of the Fisher Information Matrix:
+\begin{equation}
+\begin{split}
+    \mathbb{E}_{p(\theta_{D_i})} [\nabla^2\log p(\theta_{D_i})] &= \mathbb{E}_{p(y|\theta_{D_i})}[\text{H}_{\log p(y|\theta_{D_i})}] \\
+ &= -\text{F}
+\end{split}
+\end{equation}
+Considering that the FIM for LLMs is extremely large and cannot be computed at $\theta_{D_i}$ since the model has not been updated, we instead use a diagonal approximation at $\theta$ in practice:
+\begin{equation}
+ \text{F} \approx \mathbb{E}[\nabla\log p(\theta) \odot \nabla\log p(\theta)]
+\end{equation}
+Note that the FIM only measures the local geometry of the parameter space, and the difference between using the FIM at $\theta_{D_i}$ and at $\theta$ is negligible when the gradient updates are small. Afterward, the domain impact metric can be rewritten as:
+\begin{equation}
+\begin{split}
+ I(D_i,S_j) &= \! \text{KL}[p(\theta_{D_i}) \parallel p(\theta_{S_j})] \\
+ &= \! -\mathbb{E}_{p(\theta_{D_i})}\!\left[\frac{1}{2}\Delta^T\nabla^2\log p(\theta_{D_i})\Delta\!\right] \\
+ &= \! \frac{1}{2}\Delta^T F \Delta
+\end{split}
+\end{equation}
+This quadratic form captures how the difference in gradient updates affects the model's output distribution, weighted by the FIM which characterizes the local geometry of the parameter space. The complexity analysis is provided in Section \ref{cost}.
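+
+In code, the diagonal variant of this metric reduces to a few lines. The Python sketch below assumes that the averaged domain and task gradients, as well as the per-sample gradients used for the FIM estimate, are already available as flat arrays; the function names are ours.
+
+\begin{verbatim}
+import numpy as np
+
+def diag_fim(sample_grads):
+    """Diagonal FIM approximation F ~ E[grad * grad], estimated from
+    per-sample gradients under the current model parameters theta."""
+    return np.mean(sample_grads ** 2, axis=0)
+
+def domain_impact(grad_domain, grad_task, fim_diag):
+    """I(D_i, S_j) = 0.5 * delta^T F delta with a diagonal F,
+    where delta = grad_task - grad_domain."""
+    delta = grad_task - grad_domain
+    return 0.5 * float(np.sum(fim_diag * delta ** 2))
+\end{verbatim}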
+
+\subsection{Dynamic Domain Sampling Strategy}
+Building upon the FIM-guided domain impact measurement, a dynamic sampling strategy is proposed to optimize domain mixing ratios by considering both current learning progress and future potential. The sampling probability for each domain is updated periodically using a combination of three key components:
+
+\textbf{Current Performance Impact.} To identify valuable domains that can achieve larger performance improvements with lower sampling probabilities, we compute a utility score for each domain $D_i$ and downstream task $S_j$ that measures the domain's effectiveness in improving task performance:
+\begin{equation}
+U(D_i, S_j) = \frac{I(D_i, S_j) \cdot l_c}{p_{t-1,i}}
+\end{equation}
+where $I(D_i, S_j)$ is the normalized FIM-guided impact score, $l_c$ represents the loss improvement on task $S_j$ between consecutive updates $\Delta L(S_j)$, and $p_{t-1,i}$ is the previous sampling probability for domain $D_i$.
+
+\textbf{Future Potential Estimation.} To account for the diminishing returns in domain-specific learning and prioritize unsaturated domains, we introduce a potential factor $l_p$ that estimates future improvement opportunities. Given the loss history $\{l_1, ..., l_t\}$ for each downstream task, we fit an exponential decay model\footnote{\url{https://scikit-learn.org/stable/auto_examples/linear_model/plot_bayesian_ridge_curvefit.html}}, which is a typical pattern for learning curves:
+\begin{equation}
+l_t = ae^{-bt} + c
+\end{equation}
+where parameters $a$, $b$, and $c$ are estimated using curve fitting. The potential factor $l_p$ is then computed as the difference between current loss and predicted future loss:
+\begin{equation}
+l_p = l_t - \hat{l}_{t+\tau}
+\end{equation}
+where $\hat{l}_{t+\tau}$ is the loss predicted by the fitted curve and $\tau$ represents the prediction window size.
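+
+One possible implementation of this estimate uses a standard least-squares fit; the Python sketch below relies on \texttt{scipy.optimize.curve\_fit} purely for illustration, whereas the footnoted Bayesian-ridge curve fit would work analogously.
+
+\begin{verbatim}
+import numpy as np
+from scipy.optimize import curve_fit
+
+def potential_factor(loss_history, tau):
+    """Fit l_t = a*exp(-b*t) + c to the loss history and return
+    l_p = l_t - l_hat(t+tau), the predicted remaining improvement."""
+    t = np.arange(1, len(loss_history) + 1, dtype=float)
+    decay = lambda x, a, b, c: a * np.exp(-b * x) + c
+    p0 = (loss_history[0] - loss_history[-1], 0.1, loss_history[-1])
+    (a, b, c), _ = curve_fit(decay, t, loss_history, p0=p0, maxfev=10000)
+    return loss_history[-1] - decay(t[-1] + tau, a, b, c)
+\end{verbatim}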
+
+\textbf{Sampling Probability Update.} The final sampling probabilities are updated using an exponential moving average (EMA) to maintain stability:
+\begin{equation}
+p_{t,i} = \beta p_{t-1,i} + (1-\beta)\left(\frac{\sum_j I(D_i, S_j) \cdot (l_c + l_p)}{p_{t-1,i}}\right)
+\end{equation}
+where $\beta$ is the EMA momentum coefficient, $l_c$ represents the current loss improvement, and $l_p$ is the estimated potential factor. A softmax normalization ensures a valid probability distribution, while the division by the previous probabilities implements an importance sampling correction. The complete algorithm is summarized in Appendix \ref{appendix D}.
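+
+Putting the three components together, one possible form of the update is sketched below in Python; this is our reading of the text, and the exact ordering of the softmax and EMA steps is an implementation choice.
+
+\begin{verbatim}
+import numpy as np
+
+def update_probs(p_prev, impact, l_c, l_p, beta=0.1):
+    """p_prev: (k,) previous sampling probabilities.
+    impact: (k, m) normalized FIM-guided scores I(D_i, S_j).
+    l_c, l_p: (m,) current loss improvement and potential per task."""
+    utility = impact @ (l_c + l_p) / p_prev   # importance-sampling correction
+    target = np.exp(utility - utility.max())
+    target /= target.sum()                    # softmax normalization
+    p_new = beta * p_prev + (1.0 - beta) * target
+    return p_new / p_new.sum()                # keep a valid distribution
+\end{verbatim}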
+
+\input{figure/tab-q1}
+\section{Experiments}
+\subsection{Experimental Setup}
+\subsubsection{Datasets and Tasks}
+We utilize the Tulu-3 \cite{lambert2024t} post-training dataset containing 939,344 samples from 18 sources across web text, academic papers, code, mathematics, and books. The downstream evaluation suite comprises: BIG-Bench Hard (BBH) \cite{suzgun2022challenging} for reasoning and problem-solving, BoolQ \cite{clark2019boolq} for reading comprehension and binary question answering, GSM8K \cite{cobbe2021training} and Minerva-MathQA \cite{lewkowycz2022solving} for mathematical reasoning, IFEval \cite{zhou2023instruction} for instruction following, MMLU \cite{hendrycks2020measuring} for multitask language understanding, PIQA \cite{bisk2020piqa} for physical commonsense reasoning, PubMedQA \cite{jin2019pubmedqa} for biomedical question answering, and TruthfulQA \cite{lin2021truthfulqa} for measuring truthfulness in model responses.
+
+\subsubsection{Baselines}
+We evaluate DIDS against several domain sampling strategies: Uniform and Random sampling, Doremi \cite{xie2024doremi}, Velocitune \cite{luo2024velocitune}, Doge \cite{fandoge}, and DGA \cite{fan2024dynamic}. For all baseline implementations, we partition a small subset from the downstream task's validation set to serve as observable samples for domain reweighting. Detailed implementations are provided in Appendix \ref{appendix_C}.
+
+\subsection{Main Results}
+% Main results table: multi-task optimization across all downstream tasks
+% and single-task optimization targeting one downstream task at a time
+Table \ref{tab:main_results} presents comprehensive evaluation results across nine downstream tasks under both multi-task and single-task optimization scenarios. For reference, we include results from the base Llama-3.1-8B model and its variant trained on the full 929k samples.
+
+For multi-task optimization, DIDS with only 100k samples achieves an average score of 62.3, significantly outperforming all baseline methods while surpassing the performance of full data training at 61.2. Specifically, DIDS improves over the strongest baseline Doge by 2.1 on average, with particularly notable gains on mathematical reasoning tasks such as Minerva-MathQA improving by 2.7 points from 17.8 to 20.5. This demonstrates DIDS's effectiveness in identifying and prioritizing the most impactful training samples across diverse downstream tasks. Notably, we observe that for some tasks like MMLU and PIQA where the base model is already approaching saturation, additional training with irrelevant data can be detrimental, as evidenced by the Full Data approach's performance decline from 64.7 to 64.3 on MMLU. Furthermore, given the limited training data budget, unbalanced resource allocation across multiple tasks can lead to improved performance on some tasks at the expense of others, as demonstrated by DGA's poor performance of 42.1 on IFEval.
+
+When optimizing for individual tasks, DIDS demonstrates even stronger performance with an average score of 63.7, surpassing the second-best method DGA by 2.1. DIDS shows significant gains on knowledge-intensive tasks, with IFEval increasing from 53.2 to 57.5 and TruthfulQA improving from 38.5 to 44.8. This indicates that DIDS's FIM-guided domain impact measurement and dynamic sampling strategy are especially effective when focusing on specific downstream objectives. Notably, even with just 100k samples, roughly 10 percent of the full dataset, DIDS achieves higher average performance than training on the full 929k samples with scores of 63.7 versus 61.2.
+
+\input{latex/figure/tab-q2}
+\subsection{Ablations}
+To analyze the contribution of each component in DIDS, we conduct ablation experiments by progressively removing key components: gradient-based clustering (DIDS-GC), FIM-guided impact measurement (DIDS-FIM), and loss trajectory consideration (DIDS-LT). Results are shown in Table \ref{tab:ablation}. DIDS-GC replaces gradient-based clustering with BERT semantic clustering, leading to a 1.8-point drop in average performance from 46.9 to 45.1. DIDS-FIM removes the FIM-guided impact measurement, causing a 2.9-point decline to 44.0, most notably affecting TruthfulQA with a 4.5-point drop and IFEval with a 3.7-point decrease. DIDS-LT eliminates the loss trajectory and saturation consideration, resulting in a 2.8-point decrease to 44.1, demonstrating that dynamic adaptation to learning progress is crucial for optimal performance. These results show that each component contributes significantly to DIDS's effectiveness.
+
+\input{latex/figure/tab-q3}
+\subsection{Efficiency Analysis} \label{cost}
+% Additional computational overhead
+To comprehensively evaluate DIDS's computational overhead, we analyze the efficiency of each component: gradient-based clustering, FIM-guided impact measurement, and loss trajectory estimation. Our implementation optimizes computational costs by retaining gradients only from the final 10\% of layers, requiring complete forward passes but partial backward passes. Table \ref{tab:efficiency} presents the computational requirements in terms of TFLOPs and GPU Hours on H800.
+
+Base training of an 8B parameter model on 1B tokens requires $5.47 \times 10^4$ TFLOPs for forward and backward passes, consuming approximately 101.6 GPU hours. For the clustering component processing 1B tokens, we evaluate two approaches using 500M models. BERT semantic clustering requires only forward passes at $7.77 \times 10^2$ TFLOPs, while gradient-based clustering with dimensionality reduction necessitates both forward and partial backward computation at $1.87 \times 10^3$ TFLOPs, requiring 1.5 and 3.3 GPU hours respectively.
+
+For domain impact measurement using an 8B parameter base model with 25 mixing ratio updates, we compare FIM-guided metrics against gradient alignment. Across 72 training domains, maintaining running averages of domain-specific gradients incurs negligible overhead. Evaluating 9 downstream tasks with 200 samples per task, gradient alignment requires $9.86 \times 10^1$ TFLOPs. DIDS additionally computes FIM diagonal elements, adding negligible overhead at approximately $1.78 \times 10^2$ TFLOPs, totaling 0.2 GPU hours. The loss trajectory estimation component introduces minimal computational burden below $10^{-1}$ TFLOPs as it only involves scalar loss value curve fitting. While DIDS introduces roughly 1.9\% additional computational cost compared to DGA, this overhead is justified by substantial performance improvements and reduced training data requirements.
+
+\subsection{Parameter Analysis}
+\input{latex/figure/fig-q4}
+% We conduct extensive analysis to investigate the impact of key parameters on DIDS performance.
+
+\subsubsection{Impact of Update Frequency}
+Figure \ref{fig:parameter-analysis}a shows how the number of domain sampling probability updates during training affects model performance. When using only 5 updates throughout the entire training process, DIDS achieves an average score of 58.2, which is comparable to the random sampling baseline at 58.9. As we increase the number of updates to 25 and 45, DIDS shows substantial improvements, reaching scores of 60.1 and 61.8 respectively. The performance continues to improve with 65 updates, achieving 62.3, and peaks at 62.4 with 85 updates. However, further increasing to 95 updates leads to a slight performance decline back to 62.3. DGA exhibits a similar trend but with lower overall performance, reaching its peak of 60.1 at 65 updates. Random sampling maintains a constant performance of 58.9 regardless of update frequency, serving as a stable baseline. These results suggest that performing a limited update number during training provides optimal performance for domain sampling strategies.
+
+\subsubsection{Impact of Irrelevant Data Ratio}
+To evaluate DIDS's robustness to noise in training data, we introduce varying proportions of irrelevant financial domain data and measure the model performance. As shown in Figure \ref{fig:parameter-analysis}b, DIDS demonstrates strong resilience to irrelevant data. Starting at a baseline performance of 62.3 with no irrelevant data, DIDS maintains and even improves its performance as noise increases, reaching a peak of 63.5 at 20\% irrelevant data before showing slight decline to 63.1 at 25\%. In contrast, both comparison methods exhibit clear degradation with increased noise. DGA's performance drops from 58.5 to 57.1, showing moderate sensitivity to irrelevant data. Random sampling demonstrates the most severe degradation, falling from 58.9 to 54.2. These results highlight DIDS's robust ability to identify and leverage valuable training samples through its FIM-guided impact measurement, even in challenging scenarios with substantial noise in the training dataset.
+
+\input{latex/figure/fig-q5}
+\subsubsection{Impact of Proxy Model Size}
+We evaluate DIDS using different sizes of proxy models for gradient-based clustering: 500M, 1B, and the full 8B target model. As shown in Figure \ref{fig:model-analysis}a, the choice of proxy model size has minimal impact on final performance, with average scores of 62.3, 62.4, and 62.5 respectively. This validates our design choice of using a 500M proxy model for clustering, as it provides comparable quality while significantly reducing computational costs.
+
+\subsubsection{Impact of Domain Partition Count}
+We further examine how the granularity of domain partitioning affects model performance. Figure \ref{fig:model-analysis}b shows that increasing the number of domains from the initial 18 based on data sources leads to substantial early improvements in performance. The average score rises sharply from 61.4 to 62.0 when increasing to 36 domains, followed by moderate gains up to 62.3 with 72 domains. However, further partitioning yields only marginal returns, with performance plateauing around 62.7 even when scaling to 1,152 domains. Based on this analysis, we select 72 domains as our default configuration to balance effectiveness and computational efficiency.
+
+\section{Conclusion}
+In this paper, we proposed DIDS, a domain impact-aware data sampling framework for large language model training. To ensure consistent domain partitioning, DIDS groups training samples based on gradient patterns, which leads to more effective sampling decisions. FIM-guided metrics measure domain impact accurately, while dynamic sampling optimization combines impact assessment with learning trajectories. Experiments demonstrated that DIDS achieves superior performance across multiple tasks using only 10\% training data.
+
+\section*{Limitations}
+Our work has several limitations that should be acknowledged:
+
+First, while DIDS demonstrates strong performance with limited training data, the gradient-based domain repartitioning introduces additional computational overhead when processing large-scale datasets. Although we mitigate this through proxy models and dimensionality reduction, the clustering process still requires considerable computational resources when scaling to billions of training samples. Future work could explore more efficient methods for gradient-based domain partitioning to further reduce this overhead while maintaining clustering quality.
+
+Second, the effectiveness of our FIM-guided impact measurement depends on the accuracy of the diagonal approximation of the Fisher Information Matrix. While this approximation is computationally necessary, it may not capture all parameter interactions, potentially leading to suboptimal sampling decisions in cases where off-diagonal elements are significant. Additionally, our current approach to loss trajectory modeling assumes exponential decay patterns which may not hold for all learning scenarios.
+
+\section*{Ethics Statement}
+While DIDS improves training efficiency through selective sampling, it may inadvertently amplify existing biases in the training data by preferentially selecting certain domains based on their measured impact. This could lead to underrepresentation of minority groups or less common topics in the final model. In future applications, DIDS should be integrated with ethical auditing tools to ensure fairness in the sampling process and maintain model ethics.
+
+\section*{Acknowledgments}
+We would like to specially thank the support from the A3 project of the HKUST \& MetaX Joint Laboratory. The research work described in this paper was supported by Hong Kong Research Grants Council (grant\# 16202722, 16210625, T43-513/23-N, T22-607/24N). It was partially conducted in JC STEM Lab of Data Science Foundations funded by The Hong Kong Jockey Club Charities Trust. We acknowledge the support of Natural Science Foundation of Zhejiang Province under Grant (LY23F020010). This work is supported by the National Natural Science Foundation of China (Grant No.62272334, 6257073827).
+
+\bibliography{custom}
+
+\appendix
+
+\section{Johnson-Lindenstrauss Theorem and Random Projection Initialization}
+
+\subsection{Johnson-Lindenstrauss Lemma}
+The Johnson-Lindenstrauss lemma states that for any set $X$ of $m$ points in $\mathbb{R}^N$ and $0 < \varepsilon < 1$, there exists a linear map $f: \mathbb{R}^N \to \mathbb{R}^n$ where $n > 8\ln(m)/\varepsilon^2$ such that:
+
+$$
+(1-\varepsilon)||u-v||^2 \leq ||f(u)-f(v)||^2 \leq (1+\varepsilon)||u-v||^2
+$$
+where $u,v \in X$. This theorem guarantees that we can project high-dimensional vectors into a lower-dimensional space while approximately preserving their pairwise distances.
+
+\subsection{Gaussian Random Projection}
+For practical implementation, we utilize Gaussian random projection matrices which satisfy the following properties:
+
+1. Spherical symmetry: For any orthogonal matrix $A \in O(h)$, $AR$ has the same distribution as $R$.
+
+2. Orthogonality: The columns of $R$ are approximately orthogonal.
+
+3. Unit length: Each column of $R$ is normalized to unit length.
+
+The projection matrix $R \in \mathbb{R}^{h \times s}$ is constructed as follows:
+
+1. Generate entries $R_{ij}$ independently according to:
+   $$R_{ij} = \begin{cases}
+   +1/\sqrt{s} & \text{with probability } 1/2 \\
+   -1/\sqrt{s} & \text{with probability } 1/2
+   \end{cases}$$
+   where $s = \Omega(\log(m)/\varepsilon^2)$ is the target dimension for $m$ points and error tolerance $\varepsilon$.
+
+2. Normalize each column to ensure unit length: $\tilde{R}_j = R_j/||R_j||_2$
+
+\subsection{Application to Gradient Dimensionality Reduction}
+In the context of gradient-based domain repartitioning, given gradient vectors $g_i \in \mathbb{R}^h$, we project them to $\tilde{g}_i \in \mathbb{R}^s$ where $s \ll h$ using:
+
+$$\tilde{g}_i = R^T g_i$$
+
+The choice of target dimension $s$ balances computational efficiency with distance preservation, typically set as:
+
+$$s = O(\log(m)/\varepsilon^2)$$
+
+where $m$ is the number of gradient vectors to be clustered and $\varepsilon$ is the desired distance preservation tolerance (typically 0.1-0.3).
+
+This projection enables efficient clustering of gradient vectors while maintaining their essential geometric relationships, facilitating meaningful domain repartitioning based on training behavior patterns.
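+
+For reference, a small Python sketch of the sign-projection construction above and its application to a batch of gradient vectors (the dimensions and seeds are arbitrary):
+
+\begin{verbatim}
+import numpy as np
+
+def sign_projection_matrix(h, s, seed=0):
+    """Entries are +/- 1/sqrt(s); each column is then normalized to unit
+    length, giving approximately orthogonal, unit-norm columns."""
+    rng = np.random.default_rng(seed)
+    R = rng.choice([1.0, -1.0], size=(h, s)) / np.sqrt(s)
+    return R / np.linalg.norm(R, axis=0, keepdims=True)
+
+# Project n gradient vectors stored as rows of an (n, h) matrix;
+# each row of grads @ R equals (R^T g_i)^T.
+h, s = 10_000, 256
+R = sign_projection_matrix(h, s)
+grads = np.random.default_rng(1).standard_normal((8, h))
+reduced = grads @ R
+\end{verbatim}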
+
+\section{Implementation Details} \label{appendix_C}
+\subsection{Training Data Distribution}
+\input{latex/figure/t1}
+The training dataset consists of 939,344 samples from 18 diverse sources, covering domains including mathematics, coding, instruction following, and general dialogue. The dataset is available at \url{https://huggingface.co/datasets/allenai/tulu-3-sft-mixture}. The largest components are Tulu 3 Persona MATH with 149,960 samples focusing on mathematical reasoning, followed by Evol CodeAlpaca with 107,276 coding-related samples and FLAN v2 with 89,982 samples of general task instructions. We provide a detailed breakdown of the dataset composition in Table \ref{tab:data_dist}.
+
+\subsection{Model Architecture}
+We implement DIDS based on multiple foundation models: Llama-3.1 ($8$B and $70$B variants), Llama-2-$7$B, and Pythia-$6.9$B. For the proxy model, we utilize Qwen-2.5 ($500$M) and Llama-3.2 ($1$B).
+
+\subsection{Baseline Description}
+We compare DIDS against the following baseline methods:
+\begin{itemize}
+ \item \textbf{Uniform sampling}: A basic baseline that assigns equal probabilities to all domains throughout training.
+
+ \item \textbf{Random sampling}: Randomly selects domain data at each step without optimization.
+
+ \item \textbf{Doremi} \cite{xie2024doremi}: Trains a proxy model using group distributionally robust optimization to produce offline domain weights for resampling.
+
+ \item \textbf{Velocitune} \cite{luo2024velocitune}: Dynamically adjusts domain proportions based on learning velocity guided by scaling laws.
+
+ \item \textbf{Doge} \cite{fandoge}: Uses bi-level optimization with a proxy model to learn offline domain weights through gradient alignment.
+
+ \item \textbf{DGA} \cite{fan2024dynamic}: Employs online gradient alignment to dynamically estimate optimal pre-training data mixtures.
+\end{itemize}
+For all baselines, we use identical validation set splits from downstream tasks and tune hyperparameters on a separate development set to ensure fair comparison.
+
+To ensure fair comparison across all methods, we adapted the baseline approaches to work with observable downstream tasks. Specifically, \textbf{DoReMi} was originally designed for in-domain scenarios where test sets follow the same distribution as training data. We seamlessly transferred this algorithm to our downstream task settings by computing excess loss over downstream domains. \textbf{VelociTune} was similarly adapted to observe loss over downstream domains for adjusting training data proportions. \textbf{DOGE} and \textbf{DGA} naturally support downstream domain settings as they compute data proportions based on gradient similarity between training and validation sets (observable sets). Importantly, all baseline methods use the same gradient-based domain partitioning strategy as DIDS, ensuring that computational overhead and domain granularity are consistent across comparisons.
+
+
+\subsection{Training Details}
+The training process employs the AdamW optimizer with a learning rate of $5 \times 10^{-4}$ and linear decay scheduling based on Llama-Factory \cite{zheng2024llamafactory} \footnote{\url{https://github.com/hiyouga/LLaMA-Factory}}. We apply gradient clipping at $1.0$ and weight decay at $0.1$. The model processes sequences with a maximum length of $8,192$ tokens and uses a batch size of $128$, distributed across $8$ H800 GPUs. For DIDS-specific configurations, we set the domain update interval $\tau = 4,000$ steps and use an EMA coefficient $\beta = 0.1$. The framework utilizes $72$ domains for gradient-based clustering. Our dimensionality reduction approach first retains the top $10\%$ of elements by magnitude before projecting the gradients to $1,024$ dimensions.
+
+\subsection{Evaluation Details}
+We conduct evaluations using the lm-eval-harness platform \cite{eval-harness}\footnote{\url{https://github.com/EleutherAI/lm-evaluation-harness/}}. All tasks are evaluated under a 3-shot setting using the Vllm backend with chat templates applied. Other parameters follow the platform's default configurations.
+
+
+\section{Algorithm Description} \label{appendix D}
+The Domain Impact-aware Data Sampling (DIDS) algorithm is shown in Algorithm \ref{alg:dids}, which begins with initialization and domain repartitioning. Starting with uniform sampling probabilities across k domains, the algorithm employs a proxy model $f'$ to compute gradients for each training sample $x_i$. These gradients undergo TopK selection and Johnson-Lindenstrauss random projection for dimensionality reduction before k-means clustering establishes the k domains $\{D_1,...,D_k\}$.
+
+At intervals of $\tau$ training steps, DIDS performs domain impact assessment and probability updates. For each domain-task pair $(D_i, S_j)$, the algorithm calculates gradient differences $\Delta$ and impact scores $I(D_i,S_j)$ using the FIM-guided metric, while simultaneously fitting exponential decay curves to task loss histories to estimate future potential $L_p(S_j)$ and current improvements $\Delta L(S_j)$. The algorithm then updates sampling probabilities by computing utility scores $U(D_i)$ that combine these impact scores and loss improvements, applying softmax normalization and exponential moving average with coefficient $\beta$.
+
+Between updates, mini-batches are sampled according to current probabilities $\mathbf{p}_t$, with model parameters updated through standard optimization. This design balances theoretical foundations with practical efficiency through its use of proxy models, strategic gradient processing, and periodic updates, enabling effective domain sampling while maintaining computational feasibility.
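+
+For readers who prefer code to prose, the control flow can be summarized by the following Python sketch. It is intentionally schematic: \texttt{model} is a hypothetical trainer object, and \texttt{sample\_batch}, \texttt{domain\_impact}, \texttt{potential\_factor}, and \texttt{update\_probs} refer to the illustrative sketches in the main text; none of these names come from a released implementation.
+
+\begin{verbatim}
+import numpy as np
+
+def dids_loop(model, domains, tasks, steps, tau, p0, beta=0.1, batch_size=128):
+    """Schematic DIDS loop: train on domain-weighted batches and refresh
+    the sampling distribution every tau steps."""
+    p = np.asarray(p0, dtype=float)
+    loss_hist = [[] for _ in tasks]
+    for step in range(1, steps + 1):
+        batch = sample_batch(domains, p, batch_size)   # domain-weighted batch
+        model.train_step(batch)                        # hypothetical trainer API
+        if step % tau == 0:
+            fim = model.fim_diag()                     # diagonal FIM at theta
+            impact = np.array([[domain_impact(model.domain_grad(i),
+                                              model.task_grad(j), fim)
+                                for j in range(len(tasks))]
+                               for i in range(len(domains))])
+            for j, task in enumerate(tasks):
+                loss_hist[j].append(model.eval_loss(task))
+            l_c = np.array([h[-2] - h[-1] if len(h) > 1 else 0.0
+                            for h in loss_hist])
+            l_p = np.array([potential_factor(h, 1) if len(h) > 3 else 0.0
+                            for h in loss_hist])
+            p = update_probs(p, impact, l_c, l_p, beta)
+    return model, p
+\end{verbatim}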
+
+\section{Extended Experimental Results}
+\subsection{Experiments on Additional Models and Datasets}
+To validate the effectiveness of DIDS across different model architectures and datasets, we conducted additional experiments using Mixtral-7B\footnote{\url{https://huggingface.co/mistralai/Mistral-7B-v0.1}} alongside Llama-3.1-8B, and tested on both Tulu-v3 and the OpenHermes-2.5\footnote{\url{https://huggingface.co/datasets/teknium/openhermes}} datasets. These comprehensive evaluations strengthen our claims regarding DIDS's broad applicability.
+
+\subsubsection{Results on Mixtral-7B with Tulu-v3}
+Table \ref{tab:mixtral_tulu} presents the performance of Mixtral-7B when trained on the Tulu-v3 dataset using various sampling strategies. Similar to our findings with Llama-3.1-8B, DIDS demonstrates superior performance across most tasks, achieving an average score of 61.2 in multi-task optimization, which outperforms the full data training (60.4) despite using only 10\% of the training examples. Notably, DIDS shows substantial improvements on mathematical reasoning tasks (MathQA: 17.8 vs. 15.8 for DGA) and truthfulness (TruthfulQA: 50.5 vs. 47.2 for Doge).
+
+\input{latex/figure/mixtral_tulu}
+
+\subsubsection{Results on Llama-3.1-8B with OpenHermes-2.5}
+We further evaluated DIDS using the OpenHermes-2.5 dataset, which offers a different distribution of training data compared to Tulu-v3. Table \ref{tab:llama_openhermes} shows that DIDS consistently outperforms baseline methods across various downstream tasks, achieving an average score of 62.7 in multi-task optimization, which is comparable to training on the full dataset (62.4). In single-task optimization, DIDS achieves even better performance with a score of 64.1, demonstrating the effectiveness of our domain-aware sampling approach on different data distributions.
+
+\input{latex/figure/llama_openhermes}
+
+\subsubsection{Results on Mixtral-7B with OpenHermes-2.5}
+To further demonstrate the robustness of our approach across different model-dataset combinations, we evaluated Mixtral-7B on the OpenHermes-2.5 dataset. As shown in Table \ref{tab:mixtral_openhermes}, DIDS continues to outperform baseline methods, achieving an average score of 60.1 in multi-task optimization and 61.3 in single-task optimization. These consistent improvements across different models and datasets strongly support the generalizability of our approach.
+
+\input{latex/figure/mixtral_openhermes}
+
+\subsection{Complete Ablation Study on All Downstream Tasks}
+Table \ref{tab:full_ablation} presents a comprehensive ablation study of DIDS across all nine downstream tasks. This expanded analysis provides a more detailed understanding of how each component contributes to the overall performance gains.
+
+\input{latex/figure/full_ablation}
+
+The ablation results clearly demonstrate the contribution of each component of DIDS. Gradient-based Clustering significantly improves performance, as replacing it with BERT semantic clustering leads to a 1.2-point drop in average performance from 62.3 to 61.1. This highlights the importance of grouping data based on actual training effects rather than semantic similarity alone. The FIM-guided Impact Measurement proves crucial, with its removal resulting in a 2.1-point decline to 60.2. This component shows particularly notable benefits for TruthfulQA, PubMedQA, and MathQA tasks, confirming that measuring domain impact through output distributional changes provides a more accurate assessment than gradient similarity alone. Finally, Loss Trajectory Consideration plays a vital role, as its elimination causes a 2.0-point decrease to 60.3, with substantial performance drops in instruction following and truthfulness tasks. This demonstrates the importance of accounting for both current learning progress and future potential when determining sampling probabilities.
+
+\subsection{Domain Mixing Analysis}
+Understanding how domain weights evolve during training provides valuable insights into DIDS's operation. Table \ref{tab:domain_weights} shows the weight changes for 10 randomly selected domains (out of 256) throughout the training process for both DIDS and DGA.
+
+\input{latex/figure/domain_weights}
+
+This comparison reveals several key differences between DIDS and DGA. DIDS makes more decisive weight adjustments, with stronger amplification of valuable domains like D023 reaching 2.8 compared to DGA's 1.3, while aggressively reducing less useful domains such as D045 to 0.0 versus DGA's 0.1. This decisive resource allocation contributes significantly to DIDS's superior performance. Furthermore, domains like D078 show non-monotonic weight changes in DIDS, demonstrating its ability to adapt to the dynamic importance of domains during training, in contrast to DGA's more gradual and sometimes inconsistent adjustments. DIDS also tends to converge more quickly to stable domain weights, particularly for highly valuable or less useful domains, enabling more efficient training as the optimal sampling distribution is established earlier. Analysis of domain overlap between DIDS and DGA shows approximately 40\% consistency in domain selection, with substantial differences in the remaining 60\%, highlighting the distinct impact assessment approaches of the two methods and explaining their performance differences.
+
+\subsection{Domain Clustering Insights}
+Our gradient-based domain clustering revealed several interesting patterns in how training data is organized:
+\begin{itemize}
+ \item \textbf{Fine-grained Topic Distinction}: With sufficiently large cluster counts (over 1,000), DIDS can distinguish between closely related topics. For example, in scientific data, middle school and high school biology knowledge are clustered separately, reflecting their different training effects on the model.
+
+ \item \textbf{Format Sensitivity}: Within the same subject area (e.g., middle school biology), different question formats like multiple-choice and fill-in-the-blank are clustered into separate domains. This suggests that format significantly influences how data affects model learning, beyond just semantic content.
+
+ \item \textbf{Cross-domain Similarity}: Some seemingly distinct topics like mathematical proofs and programming implementations are clustered together due to their shared logical reasoning patterns, despite their different semantic categories in traditional domain partitioning.
+
+ \item \textbf{Instruction Pattern Recognition}: Data with similar instruction patterns tends to be clustered together regardless of content topic, highlighting the importance of task structure in determining training effects.
+\end{itemize}
+
+\section{Theoretical Analysis of FIM-guided Impact Measurement}
+The Fisher Information Matrix (FIM) plays a crucial role in DIDS by enabling accurate measurement of how domain-specific parameter updates affect model behavior on downstream tasks. Here, we provide additional theoretical analysis to justify our approach.
+
+\subsection{Relationship to Model Uncertainty}
+The FIM is inherently connected to model uncertainty through the Cramér-Rao bound, which establishes that the inverse of FIM provides a lower bound on the covariance of any unbiased estimator of the parameters. In the context of domain impact measurement, this means that parameters with high Fisher Information have a stronger influence on the model's predictive distribution and consequently on task performance.
+
+For a parameter set $\theta$, the Fisher Information Matrix is defined as:
+
+\begin{equation}
+F(\theta) = \mathbb{E}_{p(x|\theta)}\left[\nabla_\theta \log p(x|\theta) \nabla_\theta \log p(x|\theta)^T\right]
+\end{equation}
+
+When we compute the impact metric between domain $D_i$ and task $S_j$ as $I(D_i, S_j) = \frac{1}{2}\Delta^T F \Delta$, we are effectively measuring the expected change in the model's log-likelihood on task $S_j$ when updated with domain $D_i$ data, weighted by the parameter sensitivity through the FIM.
+
+\subsection{Consistency with KL Divergence}
+The KL divergence between two distributions $p(y|\theta_{D_i})$ and $p(y|\theta_{S_j})$ measures how much information is lost when using one distribution to approximate the other. Our second-order Taylor approximation of the KL divergence,
+
+\begin{equation}
+\text{KL}[p(\theta_{D_i}) \parallel p(\theta_{S_j})] \approx \frac{1}{2}\Delta^T F \Delta
+\end{equation}
+
+captures this information loss efficiently and accurately when the parameter updates are relatively small. Furthermore, the resulting quadratic form is non-negative and symmetric in $\Delta$, which makes it a suitable measure of domain impact.
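+
+As a quick numerical sanity check, the following Python snippet compares the exact KL divergence with the quadratic form on a toy three-class softmax model (unrelated to the LLM setting); for small perturbations the two quantities agree closely.
+
+\begin{verbatim}
+import numpy as np
+
+def softmax(z):
+    e = np.exp(z - z.max())
+    return e / e.sum()
+
+def kl(p, q):
+    return float(np.sum(p * np.log(p / q)))
+
+theta = np.array([0.2, -0.1, 0.4])
+delta = np.array([0.01, -0.02, 0.015])   # small parameter difference
+p = softmax(theta)
+F = np.diag(p) - np.outer(p, p)          # exact FIM of a categorical softmax
+print(kl(p, softmax(theta + delta)))     # exact KL divergence
+print(0.5 * delta @ F @ delta)           # quadratic approximation
+\end{verbatim}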
+
+\subsection{Extensions to Alternative Divergences}
+While our implementation focuses on KL divergence, the framework can be extended to other divergence measures such as Wasserstein distance or Jensen-Shannon divergence. The general form would remain similar:
+
+\begin{equation}
+D[p(\theta_{D_i})\!\! \parallel \!\! p(\theta_{S_j})] \!\! \approx \!\! (\nabla \! \ell_{S_j} \!\!- \! \nabla \! \ell_{D_i})^T M (\nabla \! \ell_{S_j} \!\! - \!\nabla \! \ell_{D_i}\!)
+\end{equation}
+
+where $M$ is a metric tensor appropriate for the chosen divergence. This flexibility allows DIDS to be adapted to different notions of distribution similarity based on specific requirements.
+
+\section{Practical Guidelines for DIDS Implementation}
+Based on our experiments and analyses, we provide the following practical guidelines for implementing DIDS effectively:
+
+\begin{itemize}
+ \item \textbf{Domain Count Selection}: Start with a medium number of domains (approximately 50-100) for gradient-based clustering. Increasing the domain count beyond 100 provides diminishing returns in most cases, while increasing computational cost.
+
+ \item \textbf{Update Frequency}: Set the domain sampling probability update interval to approximately 5-10\% of the total training steps. More frequent updates can cause instability, while less frequent updates may miss important adaptation opportunities.
+
+ \item \textbf{EMA Coefficient Tuning}: Use an EMA coefficient ($\beta$) of 0.1-0.3 for stability. Lower values allow for more rapid adaptation, which is beneficial in early training stages, while higher values provide stability in later stages.
+
+ \item \textbf{Proxy Model Selection}: A proxy model with 5-10\% the size of the target model typically provides a good balance between computational efficiency and gradient similarity. The proxy model should maintain the same architecture family as the target model for best results.
+
+ \item \textbf{Downstream Task Selection}: Include a diverse set of downstream tasks in the observation set, covering different capability areas like reasoning, knowledge, instruction following, etc. This diversity ensures balanced optimization across different aspects of model performance.
+
+ \item \textbf{FIM Computation Efficiency}: Compute the diagonal FIM approximation using a small batch size (16-32) for efficiency without significant loss in accuracy. The FIM computation only needs to be performed during domain sampling probability updates.
+\end{itemize}
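+
+As a compact summary, these guidelines translate into a starting configuration along the following lines; this is a hypothetical Python snippet whose names and values are illustrative defaults rather than released settings.
+
+\begin{verbatim}
+# Hypothetical DIDS starting configuration distilled from the guidelines above.
+dids_defaults = {
+    "num_domains": 72,             # 50-100 gradient-based clusters
+    "update_interval_frac": 0.07,  # refresh p every 5-10% of total steps
+    "ema_beta": 0.1,               # 0.1-0.3; smaller adapts faster
+    "proxy_model_frac": 0.06,      # proxy ~5-10% of target size, same family
+    "fim_batch_size": 32,          # small batch for the diagonal FIM estimate
+}
+\end{verbatim}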
+
+\section{Mixed Ratio Analysis}
+To further validate the effectiveness of different domain mixing strategies, we conducted a grid search analysis similar to that reported in previous work \cite{zhang2022opt,mckinzie2025mm1}. Table \ref{tab:grid_search} presents results for different mixing ratios of code/math domains versus general domains.
+
+\input{latex/figure/grid_search}
+
+This analysis demonstrates several key points. First, DIDS's dynamic approach outperforms all static mixing ratios across all three tasks, highlighting the limitations of fixed domain proportions throughout training. Second, different tasks show different optimal static mixing ratios - GSM8K and HumanEval benefit from higher proportions of code and math content, while MT-Bench performs better with more balanced or general-leaning distributions. Third, increasing the proportion of specialized domains like code and math significantly improves performance on related tasks such as GSM8K and HumanEval but can negatively impact general capabilities measured by MT-Bench. DIDS effectively navigates these trade-offs through dynamic adaptation.
+
+These results align with findings from industry practices in models like MM1 \cite{mckinzie2025mm1} and Llama3 \cite{touvron2023llama}, where mixed ratios are carefully tuned through extensive grid search. DIDS automates this process and improves upon static optimal ratios through dynamic adaptation.
+
+\section{Domain Partitioning Robustness Analysis}
+To evaluate DIDS's sensitivity to domain partitioning quality, we conducted destructive experiments by artificially corrupting the domain structure. Starting with our standard 72-domain gradient-based partitioning, we randomly swapped half of each domain's data with samples from other domains, creating mixed domains that violate intra-domain consistency assumptions.
+
+\begin{table*}[h]
+\centering
+\small
+\begin{tabular}{lccccccccc|c}
+\toprule
+\textbf{Method} & \textbf{BBH} & \textbf{BoolQ} & \textbf{GSM8K} & \textbf{MathQA} & \textbf{IFEval} & \textbf{MMLU} & \textbf{PIQA} & \textbf{PubMedQA} & \textbf{TruthfulQA} & \textbf{Avg} \\
+\midrule
+Random & 67.4 & 85.6 & 58.9 & 11.4 & 48.2 & 64.0 & 82.0 & 77.4 & 31.5 & 58.9 \\
+DIDS (Original) & \textbf{69.2} & \textbf{87.5} & \textbf{63.0} & \textbf{21.5} & \textbf{57.5} & \textbf{65.8} & \textbf{83.0} & \textbf{81.2} & \textbf{44.8} & \textbf{63.7} \\
+DIDS (Corrupted) & 67.4 & 86.1 & 60.6 & 14.5 & 50.6 & 65.1 & 82.6 & 79.6 & 35.4 & 60.2 \\
+\bottomrule
+\end{tabular}
+\caption{Robustness analysis under corrupted domain partitioning on Llama-3.1-8B.}
+\label{tab:domain_partition_robustness}
+\end{table*}
+
+Table \ref{tab:domain_partition_robustness} shows that although corrupted partitioning degrades DIDS's performance by 3.5 points (from 63.7 to 60.2), it still outperforms random sampling by 1.3 points. DIDS therefore degrades gracefully and remains effective even when the intra-domain consistency assumption is violated, indicating that the FIM-guided impact measurement and dynamic sampling components do not depend on a perfectly clean domain organization.
+
+\section{Detailed Computational Cost Analysis}
+We provide a comprehensive breakdown of computational costs across all baseline methods to demonstrate DIDS's efficiency advantages.
+
+\textbf{DoReMi and DoGE} employ offline reweighting strategies using 280M-parameter proxy models. DoReMi requires: (1) reference model training with uniform sampling ($2.05 \times 10^3$ TFLOPs, 3.8 GPU hours), (2) proxy model training with Group DRO ($2.05 \times 10^3$ TFLOPs, 3.8 GPU hours), and (3) excess-loss computation (0.37 TFLOPs, 0.004 GPU hours). DoGE has similar complexity but uses gradient-similarity calculations in place of the excess-loss computation.
+
+\textbf{Velocitune} employs two phases: (1) target estimation using full 8B models on 51\% subsampled data ($2.79 \times 10^4$ TFLOPs, 51.8 GPU hours), and (2) velocity-guided training with periodic updates every 150 steps ($1.2 \times 10^2$ TFLOPs, 0.2 GPU hours). The method requires significantly more resources due to full-size model training for target estimation.
+
+\textbf{DGA} uses online gradient alignment with minimal overhead for maintaining running averages of domain-specific gradients.
+
+\begin{table*}[h]
+\centering
+\small
+\begin{tabular}{lccc}
+\toprule
+\textbf{Method} & \textbf{Total TFLOPs} & \textbf{Overhead} & \textbf{Total GPU Hours} \\
+\midrule
+Base Training & $5.47 \times 10^4$ & - & 101.6 \\
+DGA & $5.56 \times 10^4$ & 1.6\% & 103.2 \\
+DIDS & $5.67 \times 10^4$ & 3.7\% & 105.2 \\
+DoReMi & $5.88 \times 10^4$ & 7.5\% & 109.2 \\
+DoGE & $5.88 \times 10^4$ & 7.5\% & 109.2 \\
+Velocitune & $8.27 \times 10^4$ & 51.2\% & 153.6 \\
+\bottomrule
+\end{tabular}
+\caption{Comprehensive computational cost comparison across all methods.}
+\label{tab:computational_cost_detailed}
+\end{table*}
+
+Table \ref{tab:computational_cost_detailed} shows that DIDS delivers its performance improvements at competitive computational cost: it adds only 3.7\% overhead, far below Velocitune's 51.2\% and comparable to lightweight methods such as DGA.
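+
+The overhead column follows directly from the TFLOPs column; the short snippet below reproduces the reported percentages relative to base training. It is a verification aid only, not part of the DIDS codebase.
+
+\begin{verbatim}
+# Reproduce the overhead percentages in the table from the TFLOPs column.
+base = 5.47e4  # base training TFLOPs
+totals = {"DGA": 5.56e4, "DIDS": 5.67e4, "DoReMi": 5.88e4,
+          "DoGE": 5.88e4, "Velocitune": 8.27e4}
+for name, tflops in totals.items():
+    print(f"{name}: {100 * (tflops - base) / base:.1f}% overhead")
+# -> DGA 1.6%, DIDS 3.7%, DoReMi/DoGE 7.5%, Velocitune 51.2%
+\end{verbatim}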
+
+\section{Out-of-Distribution Generalization Analysis}
+
+To assess potential overfitting to specific downstream tasks, we evaluate DIDS's generalization on unseen out-of-distribution (OOD) tasks under two experimental setups.
+
+For the diverse downstream task setup, we used our standard 9-task evaluation suite (BBH, BoolQ, GSM8K, MathQA, IFEval, MMLU, PIQA, PubMedQA, TruthfulQA) during training, then evaluated on four unseen OOD tasks: WMT16 English-German translation (BLEU), TLDR summarization (Win Rate vs. Llama-3.1-8B), ARC-Challenge science questions (ACC), and MBPP code generation (Pass@1).
+
+For the single-task setup, we also tested the extreme case in which a single downstream task (MathQA or TruthfulQA) guides training.
+
+
+\begin{table*}[h]
+\centering
+\small
+\begin{tabular}{lccccc}
+\toprule
+\textbf{Method} & \textbf{WMT16 EN-DE} & \textbf{TLDR} & \textbf{ARC-Challenge} & \textbf{MBPP} & \textbf{Average} \\
+\midrule
+Llama-3.1-8B (no training) & 17.1 & 50.1 & 76.4 & 55.6 & 49.8 \\
+Random (100k) & 16.8 & 47.5 & 82.7 & 57.3 & 51.0 \\
+DGA (100k) & 17.3 & 49.8 & 83.6 & 59.7 & 52.6 \\
+DIDS (100k) & \textbf{17.2} & \textbf{50.0} & \textbf{85.7} & \textbf{61.3} & \textbf{53.5} \\
+\bottomrule
+\end{tabular}
+\caption{OOD generalization performance with diverse downstream task training.}
+\label{tab:ood_diverse}
+\end{table*}
+
+\begin{table*}[h]
+\centering
+\small
+\begin{tabular}{lcccc|cccc}
+\toprule
+ & \multicolumn{4}{c|}{\textbf{Training Target: MathQA}} & \multicolumn{4}{c}{\textbf{Training Target: TruthfulQA}} \\
+\textbf{Method} & \textbf{WMT16} & \textbf{TLDR} & \textbf{ARC} & \textbf{MBPP} & \textbf{WMT16} & \textbf{TLDR} & \textbf{ARC} & \textbf{MBPP} \\
+\midrule
+Llama-3.1-8B (no training) & 17.1 & 50.1 & 76.4 & 55.6 & 17.1 & 50.1 & 76.4 & 55.6 \\
+Random & 16.8 & 47.5 & 82.7 & 57.3 & 16.8 & 47.5 & 82.7 & 57.3 \\
+DGA & 15.9 & 45.3 & 83.7 & 59.2 & 17.8 & 48.9 & 81.4 & 56.8 \\
+DIDS & \textbf{15.9} & \textbf{45.8} & \textbf{86.4} & \textbf{60.6} & \textbf{17.6} & \textbf{54.1} & \textbf{82.3} & \textbf{56.6} \\
+\bottomrule
+\end{tabular}
+\caption{OOD generalization under single-task optimization scenarios.}
+\label{tab:ood_single}
+\end{table*}
+
+The results in Table \ref{tab:ood_diverse} show that DIDS maintains competitive OOD performance when optimized for diverse downstream objectives, achieving the highest average score (53.5) across the OOD evaluation tasks. As shown in Table \ref{tab:ood_single}, even when a single downstream task guides training, DIDS exhibits meaningful cross-task transfer: mathematical reasoning training benefits scientific reasoning (ARC: 86.4) and code generation (MBPP: 60.6), while truthfulness training improves summarization (TLDR: 54.1). Although single-task optimization can cause some specialization at the expense of unrelated areas, DIDS's FIM-guided approach captures meaningful cross-task dependencies and preserves reasonable generalization.
+
+\input{latex/figure/alg}
+
+\end{document}
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/acl_natbib.bst" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/acl_natbib.bst"
new file mode 100644
index 0000000000..cad5a5e9d7
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/acl_natbib.bst"
@@ -0,0 +1,1928 @@
+%%% Modification of BibTeX style file acl_natbib_nourl.bst
+%%% ... by urlbst, version 0.9.1 (marked with "% urlbst")
+%%% See and repository
+%%% Modifications Copyright 2002–23, Norman Gray,
+%%% and distributed under the terms of the LPPL; see README for discussion.
+%%%
+%%% Added webpage entry type, and url and lastchecked fields.
+%%% Added eprint support.
+%%% Added DOI support.
+%%% Added PUBMED support.
+%%% Added hyperref support.
+%%% Original headers follow...
+
+%%
+%% This is file `acl_natbib_basic.bst',
+%% generated with the docstrip utility.
+%%
+%% The original source files were:
+%%
+%% merlin.mbs (with options: `ay,nat,pres,ed-au,keyxyr,blkyear,dt-beg,yr-per,note-yr,num-xser,pre-edn,xedn,nfss')
+%% ----------------------------------------
+%% *** Intended for ACL conferences ***
+%%
+%% Copyright 1994-2011 Patrick W Daly
+ % ===============================================================
+ % IMPORTANT NOTICE:
+ % This bibliographic style (bst) file has been generated from one or
+ % more master bibliographic style (mbs) files, listed above.
+ %
+ % This generated file can be redistributed and/or modified under the terms
+ % of the LaTeX Project Public License Distributed from CTAN
+ % archives in directory macros/latex/base/lppl.txt; either
+ % version 1 of the License, or any later version.
+ % ===============================================================
+ % Name and version information of the main mbs file:
+ % \ProvidesFile{merlin.mbs}[2011/11/18 4.33 (PWD, AO, DPC)]
+ % For use with BibTeX version 0.99a or later
+ %-------------------------------------------------------------------
+ % This bibliography style file is intended for texts in ENGLISH
+ % This is an author-year citation style bibliography. As such, it is
+ % non-standard LaTeX, and requires a special package file to function properly.
+ % Such a package is natbib.sty by Patrick W. Daly
+ % The form of the \bibitem entries is
+ % \bibitem[Jones et al.(1990)]{key}...
+ % \bibitem[Jones et al.(1990)Jones, Baker, and Smith]{key}...
+ % The essential feature is that the label (the part in brackets) consists
+ % of the author names, as they should appear in the citation, with the year
+ % in parentheses following. There must be no space before the opening
+ % parenthesis!
+ % With natbib v5.3, a full list of authors may also follow the year.
+ % In natbib.sty, it is possible to define the type of enclosures that is
+ % really wanted (brackets or parentheses), but in either case, there must
+ % be parentheses in the label.
+ % The \cite command functions as follows:
+ % \citet{key} ==>> Jones et al. (1990)
+ % \citet*{key} ==>> Jones, Baker, and Smith (1990)
+ % \citep{key} ==>> (Jones et al., 1990)
+ % \citep*{key} ==>> (Jones, Baker, and Smith, 1990)
+ % \citep[chap. 2]{key} ==>> (Jones et al., 1990, chap. 2)
+ % \citep[e.g.][]{key} ==>> (e.g. Jones et al., 1990)
+ % \citep[e.g.][p. 32]{key} ==>> (e.g. Jones et al., 1990, p. 32)
+ % \citeauthor{key} ==>> Jones et al.
+ % \citeauthor*{key} ==>> Jones, Baker, and Smith
+ % \citeyear{key} ==>> 1990
+ %---------------------------------------------------------------------
+
+ENTRY
+ { address
+ archivePrefix
+ author
+ booktitle
+ chapter
+ edition
+ editor
+ eid
+ eprint
+ eprinttype % = archivePrefix
+ howpublished
+ institution
+ journal
+ key
+ month
+ note
+ number
+ organization
+ pages
+ publisher
+ school
+ series
+ title
+ type
+ volume
+ year
+ doi % urlbst
+ pubmed % urlbst
+ url % urlbst
+ lastchecked % urlbst
+ }
+ {}
+ { label extra.label sort.label short.list }
+INTEGERS { output.state before.all mid.sentence after.sentence after.block }
+% urlbst...
+% urlbst constants and state variables
+STRINGS { urlintro
+ eprinturl eprintprefix doiprefix doiurl pubmedprefix pubmedurl
+ citedstring onlinestring linktextstring
+ openinlinelink closeinlinelink }
+INTEGERS { hrefform doiform inlinelinks makeinlinelink
+ addeprints adddoi addpubmed }
+FUNCTION {init.urlbst.variables}
+{
+ % The following constants may be adjusted by hand, if desired
+
+ % The first set allow you to enable or disable certain functionality.
+ #1 'addeprints := % 0=no eprints; 1=include eprints
+ #2 'hrefform := % 0=no crossrefs; 1=hypertex hrefs; 2=hyperref hrefs
+ #1 'inlinelinks := % 0=URLs explicit; 1=URLs attached to titles
+ #1 'adddoi := % 0=no DOI resolver; 1=include it
+ #1 'addpubmed := % 0=no PUBMED resolver; 1=include it
+ #0 'doiform := % 0=with href; 1=with \doi{}
+
+ % String constants, which you _might_ want to tweak.
+ "online" 'onlinestring := % label that a resource is online
+ "[link]" 'linktextstring := % anonymous link text
+ "http://www.ncbi.nlm.nih.gov/pubmed/" 'pubmedurl := % prefix to make URL from PUBMED
+ "https://doi.org/" 'doiurl := % prefix to make URL from DOI
+ "doi:" 'doiprefix := % printed text to introduce DOI
+ "https://arxiv.org/abs/" 'eprinturl := % prefix to make URL from eprint ref
+ "cited " 'citedstring := % label in "lastchecked" remark
+ "arXiv:" 'eprintprefix := % text prefix printed before eprint ref
+ "PMID:" 'pubmedprefix := % text prefix printed before PUBMED ref
+ "URL: " 'urlintro := % text prefix before URL
+
+ % The following are internal state variables, not configuration constants,
+ % so they shouldn't be fiddled with.
+ #0 'makeinlinelink := % state variable managed by possibly.setup.inlinelink
+ "" 'openinlinelink := % ditto
+ "" 'closeinlinelink := % ditto
+}
+INTEGERS {
+ bracket.state
+ outside.brackets
+ open.brackets
+ within.brackets
+ close.brackets
+}
+% ...urlbst to here
+FUNCTION {init.state.consts}
+{ #0 'outside.brackets := % urlbst...
+ #1 'open.brackets :=
+ #2 'within.brackets :=
+ #3 'close.brackets := % ...urlbst to here
+
+ #0 'before.all :=
+ #1 'mid.sentence :=
+ #2 'after.sentence :=
+ #3 'after.block :=
+}
+STRINGS { s t}
+% urlbst
+FUNCTION {output.nonnull.original}
+{ 's :=
+ output.state mid.sentence =
+ { ", " * write$ }
+ { output.state after.block =
+ { add.period$ write$
+ newline$
+ "\newblock " write$
+ }
+ { output.state before.all =
+ 'write$
+ { add.period$ " " * write$ }
+ if$
+ }
+ if$
+ mid.sentence 'output.state :=
+ }
+ if$
+ s
+}
+
+% urlbst...
+% Minimal DOI parsing.
+% Given a DOI on the stack, check whether it starts with 'doiurl' or not.
+% In either case, leave on the stack first a DOI with, and then a DOI without, the URL prefix.
+FUNCTION {parse.doi}
+{
+ #1 doiurl text.length$ substring$
+ doiurl =
+ { doi
+ doi doiurl text.length$ #1 + #999 substring$ }
+ { doiurl doi *
+ doi }
+ if$
+}
+% The following three functions are for handling inlinelink. They wrap
+% a block of text which is potentially output with write$ by multiple
+% other functions, so we don't know the content a priori.
+% They communicate between each other using the variables makeinlinelink
+% (which is true if a link should be made), and closeinlinelink (which holds
+% the string which should close any current link. They can be called
+% at any time, but start.inlinelink will be a no-op unless something has
+% previously set makeinlinelink true, and the two ...end.inlinelink functions
+% will only do their stuff if start.inlinelink has previously set
+% closeinlinelink to be non-empty.
+% (thanks to 'ijvm' for suggested code here)
+FUNCTION {uand}
+{ 'skip$ { pop$ #0 } if$ } % 'and' (which isn't defined at this point in the file)
+FUNCTION {possibly.setup.inlinelink}
+{ makeinlinelink hrefform #0 > uand
+ { doi empty$ adddoi uand
+ { pubmed empty$ addpubmed uand
+ { eprint empty$ addeprints uand
+ { url empty$
+ { "" }
+ { url }
+ if$ }
+ { eprinturl eprint * }
+ if$ }
+ { pubmedurl pubmed * }
+ if$ }
+% { doiurl doi * }
+ { doi empty$
+ { "XXX" }
+ { doi parse.doi pop$ }
+ if$
+ }
+ if$
+ % an appropriately-formatted URL is now on the stack
+ hrefform #1 = % hypertex
+ { "\special {html: }{" * 'openinlinelink :=
+ "\special {html:}" 'closeinlinelink := }
+ { "\href {" swap$ * "} {" * 'openinlinelink := % hrefform=#2 -- hyperref
+ % the space between "} {" matters: a URL of just the right length can cause "\% newline em"
+ "}" 'closeinlinelink := }
+ if$
+ #0 'makeinlinelink :=
+ }
+ 'skip$
+ if$ % makeinlinelink
+}
+FUNCTION {add.inlinelink}
+{ openinlinelink empty$
+ 'skip$
+ { openinlinelink swap$ * closeinlinelink *
+ "" 'openinlinelink :=
+ }
+ if$
+}
+FUNCTION {output.nonnull}
+{ % Save the thing we've been asked to output
+ 's :=
+ % If the bracket-state is close.brackets, then add a close-bracket to
+ % what is currently at the top of the stack, and set bracket.state
+ % to outside.brackets
+ bracket.state close.brackets =
+ { "]" *
+ outside.brackets 'bracket.state :=
+ }
+ 'skip$
+ if$
+ bracket.state outside.brackets =
+ { % We're outside all brackets -- this is the normal situation.
+ % Write out what's currently at the top of the stack, using the
+ % original output.nonnull function.
+ s
+ add.inlinelink
+ output.nonnull.original % invoke the original output.nonnull
+ }
+ { % Still in brackets. Add open-bracket or (continuation) comma, add the
+ % new text (in s) to the top of the stack, and move to the close-brackets
+ % state, ready for next time (unless inbrackets resets it). If we come
+ % into this branch, then output.state is carefully undisturbed.
+ bracket.state open.brackets =
+ { " [" * }
+ { ", " * } % bracket.state will be within.brackets
+ if$
+ s *
+ close.brackets 'bracket.state :=
+ }
+ if$
+}
+
+% Call this function just before adding something which should be presented in
+% brackets. bracket.state is handled specially within output.nonnull.
+FUNCTION {inbrackets}
+{ bracket.state close.brackets =
+ { within.brackets 'bracket.state := } % reset the state: not open nor closed
+ { open.brackets 'bracket.state := }
+ if$
+}
+
+FUNCTION {format.lastchecked}
+{ lastchecked empty$
+ { "" }
+ { inbrackets citedstring lastchecked * }
+ if$
+}
+% ...urlbst to here
+FUNCTION {output}
+{ duplicate$ empty$
+ 'pop$
+ 'output.nonnull
+ if$
+}
+FUNCTION {output.check}
+{ 't :=
+ duplicate$ empty$
+ { pop$ "empty " t * " in " * cite$ * warning$ }
+ 'output.nonnull
+ if$
+}
+FUNCTION {fin.entry.original} % urlbst (renamed from fin.entry, so it can be wrapped below)
+{ add.period$
+ write$
+ newline$
+}
+
+FUNCTION {new.block}
+{ output.state before.all =
+ 'skip$
+ { after.block 'output.state := }
+ if$
+}
+FUNCTION {new.sentence}
+{ output.state after.block =
+ 'skip$
+ { output.state before.all =
+ 'skip$
+ { after.sentence 'output.state := }
+ if$
+ }
+ if$
+}
+FUNCTION {add.blank}
+{ " " * before.all 'output.state :=
+}
+
+FUNCTION {date.block}
+{
+ new.block
+}
+
+FUNCTION {not}
+{ { #0 }
+ { #1 }
+ if$
+}
+FUNCTION {and}
+{ 'skip$
+ { pop$ #0 }
+ if$
+}
+FUNCTION {or}
+{ { pop$ #1 }
+ 'skip$
+ if$
+}
+FUNCTION {new.block.checkb}
+{ empty$
+ swap$ empty$
+ and
+ 'skip$
+ 'new.block
+ if$
+}
+FUNCTION {field.or.null}
+{ duplicate$ empty$
+ { pop$ "" }
+ 'skip$
+ if$
+}
+FUNCTION {emphasize}
+{ duplicate$ empty$
+ { pop$ "" }
+ { "\emph{" swap$ * "}" * }
+ if$
+}
+FUNCTION {tie.or.space.prefix} % puts ~ before the preceding part if it is of length <3
+{ duplicate$ text.length$ #3 <
+ { "~" }
+ { " " }
+ if$
+ swap$
+}
+
+FUNCTION {capitalize}
+{ "u" change.case$ "t" change.case$ }
+
+FUNCTION {space.word}
+{ " " swap$ * " " * }
+ % Here are the language-specific definitions for explicit words.
+ % Each function has a name bbl.xxx where xxx is the English word.
+ % The language selected here is ENGLISH
+FUNCTION {bbl.and}
+{ "and"}
+
+FUNCTION {bbl.etal}
+{ "et~al." }
+
+FUNCTION {bbl.editors}
+{ "editors" }
+
+FUNCTION {bbl.editor}
+{ "editor" }
+
+FUNCTION {bbl.edby}
+{ "edited by" }
+
+FUNCTION {bbl.edition}
+{ "edition" }
+
+FUNCTION {bbl.volume}
+{ "volume" }
+
+FUNCTION {bbl.of}
+{ "of" }
+
+FUNCTION {bbl.number}
+{ "number" }
+
+FUNCTION {bbl.nr}
+{ "no." }
+
+FUNCTION {bbl.in}
+{ "in" }
+
+FUNCTION {bbl.pages}
+{ "pages" }
+
+FUNCTION {bbl.page}
+{ "page" }
+
+FUNCTION {bbl.chapter}
+{ "chapter" }
+
+FUNCTION {bbl.techrep}
+{ "Technical Report" }
+
+FUNCTION {bbl.mthesis}
+{ "Master's thesis" }
+
+FUNCTION {bbl.phdthesis}
+{ "Ph.D. thesis" }
+
+MACRO {jan} {"January"}
+
+MACRO {feb} {"February"}
+
+MACRO {mar} {"March"}
+
+MACRO {apr} {"April"}
+
+MACRO {may} {"May"}
+
+MACRO {jun} {"June"}
+
+MACRO {jul} {"July"}
+
+MACRO {aug} {"August"}
+
+MACRO {sep} {"September"}
+
+MACRO {oct} {"October"}
+
+MACRO {nov} {"November"}
+
+MACRO {dec} {"December"}
+
+MACRO {acmcs} {"ACM Computing Surveys"}
+
+MACRO {acta} {"Acta Informatica"}
+
+MACRO {cacm} {"Communications of the ACM"}
+
+MACRO {ibmjrd} {"IBM Journal of Research and Development"}
+
+MACRO {ibmsj} {"IBM Systems Journal"}
+
+MACRO {ieeese} {"IEEE Transactions on Software Engineering"}
+
+MACRO {ieeetc} {"IEEE Transactions on Computers"}
+
+MACRO {ieeetcad}
+ {"IEEE Transactions on Computer-Aided Design of Integrated Circuits"}
+
+MACRO {ipl} {"Information Processing Letters"}
+
+MACRO {jacm} {"Journal of the ACM"}
+
+MACRO {jcss} {"Journal of Computer and System Sciences"}
+
+MACRO {scp} {"Science of Computer Programming"}
+
+MACRO {sicomp} {"SIAM Journal on Computing"}
+
+MACRO {tocs} {"ACM Transactions on Computer Systems"}
+
+MACRO {tods} {"ACM Transactions on Database Systems"}
+
+MACRO {tog} {"ACM Transactions on Graphics"}
+
+MACRO {toms} {"ACM Transactions on Mathematical Software"}
+
+MACRO {toois} {"ACM Transactions on Office Information Systems"}
+
+MACRO {toplas} {"ACM Transactions on Programming Languages and Systems"}
+
+MACRO {tcs} {"Theoretical Computer Science"}
+
+% bibinfo.check avoids acting on missing fields while bibinfo.warn will
+% issue a warning message if a missing field is detected. Prior to calling
+% the bibinfo functions, the user should push the field value and then its
+% name string, in that order.
+FUNCTION {bibinfo.check}
+{ swap$
+ duplicate$ missing$
+ {
+ pop$ pop$
+ ""
+ }
+ { duplicate$ empty$
+ {
+ swap$ pop$
+ }
+ { swap$
+ pop$
+ }
+ if$
+ }
+ if$
+}
+FUNCTION {bibinfo.warn}
+{ swap$
+ duplicate$ missing$
+ {
+ swap$ "missing " swap$ * " in " * cite$ * warning$ pop$
+ ""
+ }
+ { duplicate$ empty$
+ {
+ swap$ "empty " swap$ * " in " * cite$ * warning$
+ }
+ { swap$
+ pop$
+ }
+ if$
+ }
+ if$
+}
+INTEGERS { nameptr namesleft numnames }
+
+
+STRINGS { bibinfo}
+
+FUNCTION {format.names}
+{ 'bibinfo :=
+ duplicate$ empty$ 'skip$ {
+ 's :=
+ "" 't :=
+ #1 'nameptr :=
+ s num.names$ 'numnames :=
+ numnames 'namesleft :=
+ { namesleft #0 > }
+ { s nameptr
+ "{ff~}{vv~}{ll}{, jj}" % first name first for all authors
+ format.name$
+ bibinfo bibinfo.check
+ 't :=
+ nameptr #1 >
+ {
+ namesleft #1 >
+ { ", " * t * }
+ {
+ s nameptr "{ll}" format.name$ duplicate$ "others" =
+ { 't := }
+ { pop$ }
+ if$
+ numnames #2 >
+ { "," * }
+ 'skip$
+ if$
+ t "others" =
+ {
+ " " * bbl.etal *
+ }
+ {
+ bbl.and
+ space.word * t *
+ }
+ if$
+ }
+ if$
+ }
+ 't
+ if$
+ nameptr #1 + 'nameptr :=
+ namesleft #1 - 'namesleft :=
+ }
+ while$
+ } if$
+}
+FUNCTION {format.names.ed}
+{
+ format.names
+}
+FUNCTION {format.key}
+{ empty$
+ { key field.or.null }
+ { "" }
+ if$
+}
+
+FUNCTION {format.authors}
+{ author "author" format.names
+}
+FUNCTION {get.bbl.editor}
+{ editor num.names$ #1 > 'bbl.editors 'bbl.editor if$ }
+
+FUNCTION {format.editors}
+{ editor "editor" format.names duplicate$ empty$ 'skip$
+ {
+ "," *
+ " " *
+ get.bbl.editor
+ *
+ }
+ if$
+}
+FUNCTION {format.note}
+{
+ note empty$
+ { "" }
+ { note #1 #1 substring$
+ duplicate$ "{" =
+ 'skip$
+ { output.state mid.sentence =
+ { "l" }
+ { "u" }
+ if$
+ change.case$
+ }
+ if$
+ note #2 global.max$ substring$ * "note" bibinfo.check
+ }
+ if$
+}
+
+FUNCTION {format.title}
+{ title
+ duplicate$ empty$ 'skip$
+ { "t" change.case$ }
+ if$
+ "title" bibinfo.check
+}
+FUNCTION {format.full.names}
+{'s :=
+ "" 't :=
+ #1 'nameptr :=
+ s num.names$ 'numnames :=
+ numnames 'namesleft :=
+ { namesleft #0 > }
+ { s nameptr
+ "{vv~}{ll}" format.name$
+ 't :=
+ nameptr #1 >
+ {
+ namesleft #1 >
+ { ", " * t * }
+ {
+ s nameptr "{ll}" format.name$ duplicate$ "others" =
+ { 't := }
+ { pop$ }
+ if$
+ t "others" =
+ {
+ " " * bbl.etal *
+ }
+ {
+ numnames #2 >
+ { "," * }
+ 'skip$
+ if$
+ bbl.and
+ space.word * t *
+ }
+ if$
+ }
+ if$
+ }
+ 't
+ if$
+ nameptr #1 + 'nameptr :=
+ namesleft #1 - 'namesleft :=
+ }
+ while$
+}
+
+FUNCTION {author.editor.key.full}
+{ author empty$
+ { editor empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { editor format.full.names }
+ if$
+ }
+ { author format.full.names }
+ if$
+}
+
+FUNCTION {author.key.full}
+{ author empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { author format.full.names }
+ if$
+}
+
+FUNCTION {editor.key.full}
+{ editor empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { editor format.full.names }
+ if$
+}
+
+FUNCTION {make.full.names}
+{ type$ "book" =
+ type$ "inbook" =
+ or
+ 'author.editor.key.full
+ { type$ "proceedings" =
+ 'editor.key.full
+ 'author.key.full
+ if$
+ }
+ if$
+}
+
+FUNCTION {output.bibitem.original} % urlbst (renamed from output.bibitem, so it can be wrapped below)
+{ newline$
+ "\bibitem[{" write$
+ label write$
+ ")" make.full.names duplicate$ short.list =
+ { pop$ }
+ { * }
+ if$
+ "}]{" * write$
+ cite$ write$
+ "}" write$
+ newline$
+ ""
+ before.all 'output.state :=
+}
+
+FUNCTION {n.dashify}
+{
+ 't :=
+ ""
+ { t empty$ not }
+ { t #1 #1 substring$ "-" =
+ { t #1 #2 substring$ "--" = not
+ { "--" *
+ t #2 global.max$ substring$ 't :=
+ }
+ { { t #1 #1 substring$ "-" = }
+ { "-" *
+ t #2 global.max$ substring$ 't :=
+ }
+ while$
+ }
+ if$
+ }
+ { t #1 #1 substring$ *
+ t #2 global.max$ substring$ 't :=
+ }
+ if$
+ }
+ while$
+}
+
+FUNCTION {word.in}
+{ bbl.in capitalize
+ " " * }
+
+FUNCTION {format.date}
+{ year "year" bibinfo.check duplicate$ empty$
+ {
+ }
+ 'skip$
+ if$
+ extra.label *
+ before.all 'output.state :=
+ after.sentence 'output.state :=
+}
+FUNCTION {format.btitle}
+{ title "title" bibinfo.check
+ duplicate$ empty$ 'skip$
+ {
+ emphasize
+ }
+ if$
+}
+FUNCTION {either.or.check}
+{ empty$
+ 'pop$
+ { "can't use both " swap$ * " fields in " * cite$ * warning$ }
+ if$
+}
+FUNCTION {format.bvolume}
+{ volume empty$
+ { "" }
+ { bbl.volume volume tie.or.space.prefix
+ "volume" bibinfo.check * *
+ series "series" bibinfo.check
+ duplicate$ empty$ 'pop$
+ { swap$ bbl.of space.word * swap$
+ emphasize * }
+ if$
+ "volume and number" number either.or.check
+ }
+ if$
+}
+FUNCTION {format.number.series}
+{ volume empty$
+ { number empty$
+ { series field.or.null }
+ { series empty$
+ { number "number" bibinfo.check }
+ { output.state mid.sentence =
+ { bbl.number }
+ { bbl.number capitalize }
+ if$
+ number tie.or.space.prefix "number" bibinfo.check * *
+ bbl.in space.word *
+ series "series" bibinfo.check *
+ }
+ if$
+ }
+ if$
+ }
+ { "" }
+ if$
+}
+
+FUNCTION {format.edition}
+{ edition duplicate$ empty$ 'skip$
+ {
+ output.state mid.sentence =
+ { "l" }
+ { "t" }
+ if$ change.case$
+ "edition" bibinfo.check
+ " " * bbl.edition *
+ }
+ if$
+}
+INTEGERS { multiresult }
+FUNCTION {multi.page.check}
+{ 't :=
+ #0 'multiresult :=
+ { multiresult not
+ t empty$ not
+ and
+ }
+ { t #1 #1 substring$
+ duplicate$ "-" =
+ swap$ duplicate$ "," =
+ swap$ "+" =
+ or or
+ { #1 'multiresult := }
+ { t #2 global.max$ substring$ 't := }
+ if$
+ }
+ while$
+ multiresult
+}
+FUNCTION {format.pages}
+{ pages duplicate$ empty$ 'skip$
+ { duplicate$ multi.page.check
+ {
+ bbl.pages swap$
+ n.dashify
+ }
+ {
+ bbl.page swap$
+ }
+ if$
+ tie.or.space.prefix
+ "pages" bibinfo.check
+ * *
+ }
+ if$
+}
+FUNCTION {format.journal.pages}
+{ pages duplicate$ empty$ 'pop$
+ { swap$ duplicate$ empty$
+ { pop$ pop$ format.pages }
+ {
+ ":" *
+ swap$
+ n.dashify
+ "pages" bibinfo.check
+ *
+ }
+ if$
+ }
+ if$
+}
+FUNCTION {format.journal.eid}
+{ eid "eid" bibinfo.check
+ duplicate$ empty$ 'pop$
+ { swap$ duplicate$ empty$ 'skip$
+ {
+ ":" *
+ }
+ if$
+ swap$ *
+ }
+ if$
+}
+FUNCTION {format.vol.num.pages}
+{ volume field.or.null
+ duplicate$ empty$ 'skip$
+ {
+ "volume" bibinfo.check
+ }
+ if$
+ number "number" bibinfo.check duplicate$ empty$ 'skip$
+ {
+ swap$ duplicate$ empty$
+ { "there's a number but no volume in " cite$ * warning$ }
+ 'skip$
+ if$
+ swap$
+ "(" swap$ * ")" *
+ }
+ if$ *
+ eid empty$
+ { format.journal.pages }
+ { format.journal.eid }
+ if$
+}
+
+FUNCTION {format.chapter}
+{ chapter empty$
+ 'format.pages
+ { type empty$
+ { bbl.chapter }
+ { type "l" change.case$
+ "type" bibinfo.check
+ }
+ if$
+ chapter tie.or.space.prefix
+ "chapter" bibinfo.check
+ * *
+ }
+ if$
+}
+
+FUNCTION {format.chapter.pages}
+{ chapter empty$
+ 'format.pages
+ { type empty$
+ { bbl.chapter }
+ { type "l" change.case$
+ "type" bibinfo.check
+ }
+ if$
+ chapter tie.or.space.prefix
+ "chapter" bibinfo.check
+ * *
+ pages empty$
+ 'skip$
+ { ", " * format.pages * }
+ if$
+ }
+ if$
+}
+
+FUNCTION {format.booktitle}
+{
+ booktitle "booktitle" bibinfo.check
+ emphasize
+}
+FUNCTION {format.in.booktitle}
+{ format.booktitle duplicate$ empty$ 'skip$
+ {
+ word.in swap$ *
+ }
+ if$
+}
+FUNCTION {format.in.ed.booktitle}
+{ format.booktitle duplicate$ empty$ 'skip$
+ {
+ editor "editor" format.names.ed duplicate$ empty$ 'pop$
+ {
+ "," *
+ " " *
+ get.bbl.editor
+ ", " *
+ * swap$
+ * }
+ if$
+ word.in swap$ *
+ }
+ if$
+}
+FUNCTION {format.thesis.type}
+{ type duplicate$ empty$
+ 'pop$
+ { swap$ pop$
+ "t" change.case$ "type" bibinfo.check
+ }
+ if$
+}
+FUNCTION {format.tr.number}
+{ number "number" bibinfo.check
+ type duplicate$ empty$
+ { pop$ bbl.techrep }
+ 'skip$
+ if$
+ "type" bibinfo.check
+ swap$ duplicate$ empty$
+ { pop$ "t" change.case$ }
+ { tie.or.space.prefix * * }
+ if$
+}
+FUNCTION {format.article.crossref}
+{
+ word.in
+ " \cite{" * crossref * "}" *
+}
+FUNCTION {format.book.crossref}
+{ volume duplicate$ empty$
+ { "empty volume in " cite$ * "'s crossref of " * crossref * warning$
+ pop$ word.in
+ }
+ { bbl.volume
+ capitalize
+ swap$ tie.or.space.prefix "volume" bibinfo.check * * bbl.of space.word *
+ }
+ if$
+ " \cite{" * crossref * "}" *
+}
+FUNCTION {format.incoll.inproc.crossref}
+{
+ word.in
+ " \cite{" * crossref * "}" *
+}
+FUNCTION {format.org.or.pub}
+{ 't :=
+ ""
+ address empty$ t empty$ and
+ 'skip$
+ {
+ t empty$
+ { address "address" bibinfo.check *
+ }
+ { t *
+ address empty$
+ 'skip$
+ { ", " * address "address" bibinfo.check * }
+ if$
+ }
+ if$
+ }
+ if$
+}
+FUNCTION {format.publisher.address}
+{ publisher "publisher" bibinfo.warn format.org.or.pub
+}
+
+FUNCTION {format.organization.address}
+{ organization "organization" bibinfo.check format.org.or.pub
+}
+
+FUNCTION {archiveprefix.or.eprinttype} % holder for eprinttype with archiveprefix precedence
+{
+ archiveprefix empty$
+ {
+ eprinttype empty$
+ { "" } % not using 'skip$ to reduce errors like "nothing to pop from stack"
+ { eprinttype }
+ if$
+ }
+ { archiveprefix }
+ if$
+}
+
+FUNCTION {output.eprint} % this is only used with the @misc record type (common for arXiv and other preprint server bibtex records)
+{
+ eprint empty$
+ {% if eprint field is empty
+ publisher field.or.null "arXiv" = % field.or.null here helps when no publisher field in the record
+ { publisher " preprint" * } % add " preprint" to publisher with the idea that publisher is the name of the preprint server
+ { "" } % if publisher != "arXiv" then empty output
+ if$
+ emphasize % no output function after emphasize because nothing goes after this
+ }
+ {% if eprint field is not empty
+ archiveprefix.or.eprinttype empty$
+ { "" } % not using 'skip$ to reduce errors like "nothing to pop from stack"
+ {% if archiveprefix or eprinttype fields are not empty
+ journal empty$
+ { "Preprint" } % if journal field is empty: output just "Preprint" emphasized like a journal name
+ { journal } % if journal field is not empty, output it (takes precedence)
+ if$
+ emphasize output % emphasize what we formed before, setting output as a border to the subblock that follows with the comma delimiter
+ archiveprefix.or.eprinttype ":" * eprint * % subblock with eprinttype and eprint number
+ }
+ if$
+ }
+ if$
+}
+
+% urlbst...
+% Functions for making hypertext links.
+% In all cases, the stack has (link-text href-url)
+%
+% make 'null' specials
+FUNCTION {make.href.null}
+{
+ pop$
+}
+% make hypertex specials
+FUNCTION {make.href.hypertex}
+{
+ "\special {html: }" * swap$ *
+ "\special {html:}" *
+}
+% make hyperref specials
+FUNCTION {make.href.hyperref}
+{
+ "\href {" swap$ * "} {\path{" * swap$ * "}}" *
+}
+FUNCTION {make.href}
+{ hrefform #2 =
+ 'make.href.hyperref % hrefform = 2
+ { hrefform #1 =
+ 'make.href.hypertex % hrefform = 1
+ 'make.href.null % hrefform = 0 (or anything else)
+ if$
+ }
+ if$
+}
+
+% If inlinelinks is true, then format.url should be a no-op, since it's
+% (a) redundant, and (b) could end up as a link-within-a-link.
+FUNCTION {format.url}
+{ inlinelinks #1 = url empty$ or
+ { "" }
+ { hrefform #1 =
+ { % special case -- add HyperTeX specials
+ urlintro "\url{" url * "}" * url make.href.hypertex * }
+ { urlintro "\url{" * url * "}" * }
+ if$
+ }
+ if$
+}
+FUNCTION {format.eprint}
+{ eprint empty$
+ { "" }
+ { eprintprefix eprint * eprinturl eprint * make.href }
+ if$
+}
+
+FUNCTION {format.doi}
+{ doi empty$
+ { "" }
+ { doi parse.doi % leaves "https://doi.org/DOI" DOI on the stack
+ 's := 't :=
+ doiform #1 =
+ { "\doi{" s * "}" * }
+ { doiprefix s * t make.href }
+ if$
+ }
+ if$
+}
+
+FUNCTION {format.pubmed}
+{ pubmed empty$
+ { "" }
+ { pubmedprefix pubmed * pubmedurl pubmed * make.href }
+ if$
+}
+
+% Output a URL. We can't use the more normal idiom (something like
+% `format.url output'), because the `inbrackets' within
+% format.lastchecked applies to everything between calls to `output',
+% so that `format.url format.lastchecked * output' ends up with both
+% the URL and the lastchecked in brackets.
+FUNCTION {output.url}
+{ url empty$
+ 'skip$
+ { new.block
+ format.url output
+ format.lastchecked output
+ }
+ if$
+}
+
+FUNCTION {output.web.refs}
+{
+ new.block
+ inlinelinks
+ 'skip$ % links were inline -- don't repeat them
+ { % If the generated DOI will be the same as the URL,
+ % then don't print the URL (thanks to Joseph Wright
+ % for (the original version of) this code,
+ % at http://tex.stackexchange.com/questions/5660)
+ adddoi
+ doi empty$ { "X" } { doi parse.doi pop$ } if$ % DOI URL to be generated
+ url empty$ { "Y" } { url } if$ % the URL, or "Y" if empty
+ = % are the strings equal?
+ and
+ 'skip$
+ { output.url }
+ if$
+ addeprints eprint empty$ not and
+ { format.eprint output.nonnull }
+ 'skip$
+ if$
+ adddoi doi empty$ not and
+ { format.doi output.nonnull }
+ 'skip$
+ if$
+ addpubmed pubmed empty$ not and
+ { format.pubmed output.nonnull }
+ 'skip$
+ if$
+ }
+ if$
+}
+
+% Wrapper for output.bibitem.original.
+% If the URL field is not empty, set makeinlinelink to be true,
+% so that an inline link will be started at the next opportunity
+FUNCTION {output.bibitem}
+{ outside.brackets 'bracket.state :=
+ output.bibitem.original
+ inlinelinks url empty$ not doi empty$ not or pubmed empty$ not or eprint empty$ not or and
+ { #1 'makeinlinelink := }
+ { #0 'makeinlinelink := }
+ if$
+}
+
+% Wrapper for fin.entry.original
+FUNCTION {fin.entry}
+{ output.web.refs % urlbst
+ makeinlinelink % ooops, it appears we didn't have a title for inlinelink
+ { possibly.setup.inlinelink % add some artificial link text here, as a fallback
+ linktextstring output.nonnull }
+ 'skip$
+ if$
+ bracket.state close.brackets = % urlbst
+ { "]" * }
+ 'skip$
+ if$
+ fin.entry.original
+}
+
+% Webpage entry type.
+% Title and url fields required;
+% author, note, year, month, and lastchecked fields optional
+% See references
+% ISO 690-2 http://www.nlc-bnc.ca/iso/tc46sc9/standard/690-2e.htm
+% http://www.classroom.net/classroom/CitingNetResources.html
+% http://neal.ctstateu.edu/history/cite.html
+% http://www.cas.usf.edu/english/walker/mla.html
+% for citation formats for web pages.
+FUNCTION {webpage}
+{ output.bibitem
+ author empty$
+ { editor empty$
+ 'skip$ % author and editor both optional
+ { format.editors output.nonnull }
+ if$
+ }
+ { editor empty$
+ { format.authors output.nonnull }
+ { "can't use both author and editor fields in " cite$ * warning$ }
+ if$
+ }
+ if$
+ new.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$
+ format.title "title" output.check
+ inbrackets onlinestring output
+ new.block
+ year empty$
+ 'skip$
+ { format.date "year" output.check }
+ if$
+ % We don't need to output the URL details ('lastchecked' and 'url'),
+ % because fin.entry does that for us, using output.web.refs. The only
+ % reason we would want to put them here is if we were to decide that
+ % they should go in front of the rather miscellaneous information in 'note'.
+ new.block
+ note output
+ fin.entry
+}
+% ...urlbst to here
+
+
+FUNCTION {article}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ crossref missing$
+ {
+ journal
+ "journal" bibinfo.check
+ emphasize
+ "journal" output.check
+ possibly.setup.inlinelink format.vol.num.pages output% urlbst
+ }
+ { format.article.crossref output.nonnull
+ format.pages output
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {book}
+{ output.bibitem
+ author empty$
+ { format.editors "author and editor" output.check
+ editor format.key output
+ }
+ { format.authors output.nonnull
+ crossref missing$
+ { "author and editor" editor either.or.check }
+ 'skip$
+ if$
+ }
+ if$
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle "title" output.check
+ format.edition output
+ crossref missing$
+ { format.bvolume output
+ new.block
+ format.number.series output
+ new.sentence
+ format.publisher.address output
+ }
+ {
+ new.block
+ format.book.crossref output.nonnull
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {booklet}
+{ output.bibitem
+ format.authors output
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ howpublished "howpublished" bibinfo.check output
+ address "address" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {inbook}
+{ output.bibitem
+ author empty$
+ { format.editors "author and editor" output.check
+ editor format.key output
+ }
+ { format.authors output.nonnull
+ crossref missing$
+ { "author and editor" editor either.or.check }
+ 'skip$
+ if$
+ }
+ if$
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle "title" output.check
+ crossref missing$
+ {
+ format.edition output
+ format.bvolume output
+ format.chapter "chapter" output.check
+ new.block
+ format.number.series output
+ new.sentence
+ format.publisher.address output
+ }
+ {
+ format.chapter "chapter" output.check
+ new.block
+ format.book.crossref output.nonnull
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {incollection}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ crossref missing$
+ { format.in.ed.booktitle "booktitle" output.check
+ format.edition output
+ format.bvolume output
+ format.number.series output
+ format.chapter.pages output
+ new.sentence
+ format.publisher.address output
+ }
+ { format.incoll.inproc.crossref output.nonnull
+ format.chapter.pages output
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {inproceedings}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ crossref missing$
+ { format.in.booktitle "booktitle" output.check
+ format.bvolume output
+ format.number.series output
+ format.pages output
+ address "address" bibinfo.check output
+ new.sentence
+ organization "organization" bibinfo.check output
+ publisher "publisher" bibinfo.check output
+ }
+ { format.incoll.inproc.crossref output.nonnull
+ format.pages output
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {conference} { inproceedings }
+FUNCTION {manual}
+{ output.bibitem
+ format.authors output
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle "title" output.check
+ format.edition output
+ organization address new.block.checkb
+ organization "organization" bibinfo.check output
+ address "address" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {mastersthesis}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title
+ "title" output.check
+ new.block
+ bbl.mthesis format.thesis.type output.nonnull
+ school "school" bibinfo.warn output
+ address "address" bibinfo.check output
+ month "month" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {misc}
+{ output.bibitem
+ format.authors output
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title output
+ new.block
+ howpublished "howpublished" bibinfo.check output
+ new.block
+ output.eprint output
+ new.block
+ format.note output
+ fin.entry
+}
+FUNCTION {phdthesis}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle
+ "title" output.check
+ new.block
+ bbl.phdthesis format.thesis.type output.nonnull
+ school "school" bibinfo.warn output
+ address "address" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {presentation}
+{ output.bibitem
+ format.authors output
+ author format.key output
+ new.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title output
+ new.block
+ format.organization.address "organization and address" output.check
+ month "month" output.check
+ year "year" output.check
+ new.block
+ format.note output
+ new.sentence
+ type missing$ 'skip$
+ {"(" type capitalize * ")" * output}
+ if$
+ fin.entry
+}
+
+FUNCTION {proceedings}
+{ output.bibitem
+ format.editors output
+ editor format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.btitle "title" output.check
+ format.bvolume output
+ format.number.series output
+ new.sentence
+ publisher empty$
+ { format.organization.address output }
+ { organization "organization" bibinfo.check output
+ new.sentence
+ format.publisher.address output
+ }
+ if$
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {techreport}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title
+ "title" output.check
+ new.block
+ format.tr.number output.nonnull
+ institution "institution" bibinfo.warn output
+ address "address" bibinfo.check output
+ new.block
+ format.note output
+ fin.entry
+}
+
+FUNCTION {unpublished}
+{ output.bibitem
+ format.authors "author" output.check
+ author format.key output
+ format.date "year" output.check
+ date.block
+ title empty$ 'skip$ 'possibly.setup.inlinelink if$ % urlbst
+ format.title "title" output.check
+ new.block
+ format.note "note" output.check
+ fin.entry
+}
+
+FUNCTION {default.type} { misc }
+READ
+FUNCTION {sortify}
+{ purify$
+ "l" change.case$
+}
+INTEGERS { len }
+FUNCTION {chop.word}
+{ 's :=
+ 'len :=
+ s #1 len substring$ =
+ { s len #1 + global.max$ substring$ }
+ 's
+ if$
+}
+FUNCTION {format.lab.names}
+{ 's :=
+ "" 't :=
+ s #1 "{vv~}{ll}" format.name$
+ s num.names$ duplicate$
+ #2 >
+ { pop$
+ " " * bbl.etal *
+ }
+ { #2 <
+ 'skip$
+ { s #2 "{ff }{vv }{ll}{ jj}" format.name$ "others" =
+ {
+ " " * bbl.etal *
+ }
+ { bbl.and space.word * s #2 "{vv~}{ll}" format.name$
+ * }
+ if$
+ }
+ if$
+ }
+ if$
+}
+
+FUNCTION {author.key.label}
+{ author empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { author format.lab.names }
+ if$
+}
+
+FUNCTION {author.editor.key.label}
+{ author empty$
+ { editor empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { editor format.lab.names }
+ if$
+ }
+ { author format.lab.names }
+ if$
+}
+
+FUNCTION {editor.key.label}
+{ editor empty$
+ { key empty$
+ { cite$ #1 #3 substring$ }
+ 'key
+ if$
+ }
+ { editor format.lab.names }
+ if$
+}
+
+FUNCTION {calc.short.authors}
+{ type$ "book" =
+ type$ "inbook" =
+ or
+ 'author.editor.key.label
+ { type$ "proceedings" =
+ 'editor.key.label
+ 'author.key.label
+ if$
+ }
+ if$
+ 'short.list :=
+}
+
+FUNCTION {calc.label}
+{ calc.short.authors
+ short.list
+ "("
+ *
+ year duplicate$ empty$
+ short.list key field.or.null = or
+ { pop$ "" }
+ 'skip$
+ if$
+ *
+ 'label :=
+}
+
+FUNCTION {sort.format.names}
+{ 's :=
+ #1 'nameptr :=
+ ""
+ s num.names$ 'numnames :=
+ numnames 'namesleft :=
+ { namesleft #0 > }
+ { s nameptr
+ "{vv{ } }{ll{ }}{ ff{ }}{ jj{ }}"
+ format.name$ 't :=
+ nameptr #1 >
+ {
+ " " *
+ namesleft #1 = t "others" = and
+ { "zzzzz" 't := }
+ 'skip$
+ if$
+ t sortify *
+ }
+ { t sortify * }
+ if$
+ nameptr #1 + 'nameptr :=
+ namesleft #1 - 'namesleft :=
+ }
+ while$
+}
+
+FUNCTION {sort.format.title}
+{ 't :=
+ "A " #2
+ "An " #3
+ "The " #4 t chop.word
+ chop.word
+ chop.word
+ sortify
+ #1 global.max$ substring$
+}
+FUNCTION {author.sort}
+{ author empty$
+ { key empty$
+ { "to sort, need author or key in " cite$ * warning$
+ ""
+ }
+ { key sortify }
+ if$
+ }
+ { author sort.format.names }
+ if$
+}
+FUNCTION {author.editor.sort}
+{ author empty$
+ { editor empty$
+ { key empty$
+ { "to sort, need author, editor, or key in " cite$ * warning$
+ ""
+ }
+ { key sortify }
+ if$
+ }
+ { editor sort.format.names }
+ if$
+ }
+ { author sort.format.names }
+ if$
+}
+FUNCTION {editor.sort}
+{ editor empty$
+ { key empty$
+ { "to sort, need editor or key in " cite$ * warning$
+ ""
+ }
+ { key sortify }
+ if$
+ }
+ { editor sort.format.names }
+ if$
+}
+FUNCTION {presort}
+{ calc.label
+ label sortify
+ " "
+ *
+ type$ "book" =
+ type$ "inbook" =
+ or
+ 'author.editor.sort
+ { type$ "proceedings" =
+ 'editor.sort
+ 'author.sort
+ if$
+ }
+ if$
+ #1 entry.max$ substring$
+ 'sort.label :=
+ sort.label
+ *
+ " "
+ *
+ title field.or.null
+ sort.format.title
+ *
+ #1 entry.max$ substring$
+ 'sort.key$ :=
+}
+
+ITERATE {presort}
+SORT
+STRINGS { last.label next.extra }
+INTEGERS { last.extra.num last.extra.num.extended last.extra.num.blank number.label }
+FUNCTION {initialize.extra.label.stuff}
+{ #0 int.to.chr$ 'last.label :=
+ "" 'next.extra :=
+ #0 'last.extra.num :=
+ "a" chr.to.int$ #1 - 'last.extra.num.blank :=
+ last.extra.num.blank 'last.extra.num.extended :=
+ #0 'number.label :=
+}
+FUNCTION {forward.pass}
+{ last.label label =
+ { last.extra.num #1 + 'last.extra.num :=
+ last.extra.num "z" chr.to.int$ >
+ { "a" chr.to.int$ 'last.extra.num :=
+ last.extra.num.extended #1 + 'last.extra.num.extended :=
+ }
+ 'skip$
+ if$
+ last.extra.num.extended last.extra.num.blank >
+ { last.extra.num.extended int.to.chr$
+ last.extra.num int.to.chr$
+ * 'extra.label := }
+ { last.extra.num int.to.chr$ 'extra.label := }
+ if$
+ }
+ { "a" chr.to.int$ 'last.extra.num :=
+ "" 'extra.label :=
+ label 'last.label :=
+ }
+ if$
+ number.label #1 + 'number.label :=
+}
+FUNCTION {reverse.pass}
+{ next.extra "b" =
+ { "a" 'extra.label := }
+ 'skip$
+ if$
+ extra.label 'next.extra :=
+ extra.label
+ duplicate$ empty$
+ 'skip$
+ { year field.or.null #-1 #1 substring$ chr.to.int$ #65 <
+ { "{\natexlab{" swap$ * "}}" * }
+ { "{(\natexlab{" swap$ * "})}" * }
+ if$ }
+ if$
+ 'extra.label :=
+ label extra.label * 'label :=
+}
+EXECUTE {initialize.extra.label.stuff}
+ITERATE {forward.pass}
+REVERSE {reverse.pass}
+FUNCTION {bib.sort.order}
+{ sort.label
+ " "
+ *
+ year field.or.null sortify
+ *
+ " "
+ *
+ title field.or.null
+ sort.format.title
+ *
+ #1 entry.max$ substring$
+ 'sort.key$ :=
+}
+ITERATE {bib.sort.order}
+SORT
+FUNCTION {begin.bib}
+{ preamble$ empty$
+ 'skip$
+ { preamble$ write$ newline$ }
+ if$
+ "\begin{thebibliography}{" number.label int.to.str$ * "}" *
+ write$ newline$
+ "\providecommand{\natexlab}[1]{#1}"
+ write$ newline$
+}
+EXECUTE {begin.bib}
+EXECUTE {init.urlbst.variables} % urlbst
+EXECUTE {init.state.consts}
+ITERATE {call.type$}
+FUNCTION {end.bib}
+{ newline$
+ "\end{thebibliography}" write$ newline$
+}
+EXECUTE {end.bib}
+%% End of customized bst file
+%%
+%% End of file `acl_natbib_basic.bst'.
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/custom.bib" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/custom.bib"
new file mode 100644
index 0000000000..08f386159e
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/custom.bib"
@@ -0,0 +1,552 @@
+@Article{Abril07,
+ author = "Patricia S. Abril and Robert Plant",
+ title = "The patent holder's dilemma: Buy, sell, or troll?",
+ journal = "Communications of the ACM",
+ volume = "50",
+ number = "1",
+ month = jan,
+ year = "2007",
+ pages = "36--44",
+ doi = "10.1145/1188913.1188915",
+ url = "http://doi.acm.org/10.1145/1219092.1219093",
+ note = "",
+}
+
+@Article{Cohen07,
+ author = "Sarah Cohen and Werner Nutt and Yehoshua Sagic",
+ title = "Deciding equivalances among conjunctive aggregate queries",
+ journal = JACM,
+ articleno = "5",
+ numpages = "50",
+ volume = "54",
+ number = "2",
+ month = apr,
+ year = "2007",
+ doi = "10.1145/1219092.1219093",
+ url = "http://doi.acm.org/10.1145/1219092.1219093",
+ acmid = "1219093",
+ note = "",
+}
+
+
+@periodical{JCohen96,
+ key = "Cohen",
+ editor = "Jacques Cohen",
+ title = "Special issue: Digital Libraries",
+ journal = CACM,
+ volume = "39",
+ number = "11",
+ month = nov,
+ year = "1996",
+}
+
+
+@Book{Kosiur01,
+ author = "David Kosiur",
+ title = "Understanding Policy-Based Networking",
+ publisher = "Wiley",
+ year = "2001",
+ address = "USA",
+ edition = "2.",
+ editor = "",
+ volume = "",
+ number = "",
+ series = "",
+ month = "",
+ note = "",
+}
+
+
+@Book{Harel79,
+ author = "David Harel",
+ year = "1979",
+ title = "First-Order Dynamic Logic",
+ series = "Lecture Notes in Computer Science",
+ volume = "68",
+ address = "New York, NY",
+ publisher = "Springer-Verlag",
+ doi = "10.1007/3-540-09237-4",
+ url = "http://dx.doi.org/10.1007/3-540-09237-4",
+ editor = "",
+ number = "",
+ month = "",
+ note = "",
+}
+
+@inbook{Editor00,
+ author = {Peter Eston},
+ title = {The title of the work},
+ chapter = 8,
+ pages = {201-213},
+ publisher = {The name of the publisher},
+ doi = "10.1007/3-540-09237-4",
+ year = 1993,
+ volume = 4,
+ series = 5,
+ address = {The address of the publisher},
+ edition = 3,
+ month = 7,
+ note = {An optional note}
+}
+
+%
+@InBook{Editor00a,
+ author = "",
+ editor = "Ian Editor",
+ title = "The title of book two",
+ subtitle = "The book subtitle",
+ series = "The name of the series two",
+ year = "2008",
+ address = "Chicago",
+ edition = "2nd.",
+ publisher = "University of Chicago Press",
+ doi = "10.1007/3-540-09237-4",
+ url = "http://dx.doi.org/10.1007/3-540-09456-9",
+ volume = "",
+ chapter = "100",
+ pages = {201-213},
+ number = "",
+ type = "",
+ month = "",
+ note = "",
+}
+
+
+% incollection (has an editor, title, and possibly a booktitle)
+@Incollection{Spector90,
+ author = "Asad Z. Spector",
+ title = "Achieving application requirements",
+ booktitle = "Distributed Systems",
+ publisher = "ACM Press",
+ address = "New York, NY",
+ year = "1990",
+ edition = "2nd.",
+ chapter = "",
+ editor = "Sape Mullender",
+ pages = "19--33",
+ doi = "10.1145/90417.90738",
+ url = "http://doi.acm.org/10.1145/90417.90738",
+ volume = "",
+ number = "",
+ series = "",
+ type = "",
+ month = "",
+ note = "",
+}
+
+
+% incollection (has an editor, title, and possibly a booktitle)
+@Incollection{Douglass98,
+ author = "Bruce P. Douglass and David Harel and Mark B. Trakhtenbrot",
+ title = "Statecarts in use: structured analysis and object-orientation",
+ series = "Lecture Notes in Computer Science",
+ booktitle = "Lectures on Embedded Systems",
+ publisher = "Springer-Verlag",
+ address = "London",
+ volume = "1494",
+ year = "1998",
+ chapter = "",
+ editor = "Grzegorz Rozenberg and Frits W. Vaandrager",
+ pages = "368--394",
+ doi = "10.1007/3-540-65193-4_29",
+ url = "http://dx.doi.org/10.1007/3-540-65193-4_29",
+ edition = "",
+ number = "",
+ type = "",
+ month = "",
+ note = "",
+}
+
+
+@Book{Knuth97,
+ author = "Donald E. Knuth",
+ title = "The Art of Computer Programming, Vol. 1: Fundamental Algorithms (3rd. ed.)",
+ publisher = "Addison Wesley Longman Publishing Co., Inc.",
+ year = "1997",
+ address = "USA",
+ edition = "",
+ editor = "",
+ volume = "",
+ number = "",
+ series = "",
+ month = "",
+ note = "",
+}
+
+
+@Book{Knuth98,
+ author = "Donald E. Knuth",
+ year = "1998",
+ title = "The Art of Computer Programming",
+ series = "Fundamental Algorithms",
+ volume = "1",
+ edition = "3rd",
+ address = "",
+ publisher = "Addison Wesley Longman Publishing Co., Inc.",
+ doi = "",
+ url = "",
+ editor = "",
+ number = "",
+ month = "",
+ note = "(book)",
+}
+
+%Inbook{Knuth97,
+% author = "Donald E. Knuth",
+% title = "The Art of Computer Programming",
+% booktitle = "the booktitle",
+% edition = "3",
+% volume = "1",
+% year = "1997",
+% publisher = "Addison Wesley Longman Publishing Co., Inc.",
+% editor = "",
+% number = "",
+% series = "Fundamental Algorithms",
+% type = "",
+% chapter = "",
+% pages = "",
+% address = "",
+% month = "",
+% note = "(inbook)",
+%}
+
+%INBOOK{DK:73-inbook-full,
+% author = "Donald E. Knuth",
+% title = "Fundamental Algorithms (inbook w series)",
+% volume = 1,
+% series = "The Art of Computer Programming",
+% publisher = "Addison-Wesley",
+% address = "Reading, Massachusetts",
+% edition = "Second",
+% month = "10~" # jan,
+% year = "1973",
+% type = "Section",
+% chapter = "1.2",
+% pages = "10--119",
+% note = "Full INBOOK entry (w series)",
+%}
+
+%INcollection{DK:74-incoll,
+% author = "Donald E. Knuth",
+% title = "Fundamental Algorithms (incoll)",
+% volume = 1,
+% booktitle = "The Art of Computer Programming",
+% publisher = "Addison-Wesley",
+% address = "Reading, Massachusetts",
+% month = "10~" # jan,
+% year = "1974",
+% pages = "10--119",
+% editor = "Bernard Rous",
+% note = "This is a full incoll entry with an editor",
+%}
+
+%INcollection{DK:75-incollws,
+% author = "Donald E. Knuth",
+% title = "Fundamental Algorithms (incoll w series)",
+% volume = 1,
+% booktitle = "The Art of Computer Programming",
+% series = "The Art of Computer Programming",
+% publisher = "Addison-Wesley",
+% address = "Reading, Massachusetts",
+% month = "10~" # jan,
+% year = "1975",
+% pages = "10--119",
+% editor = "Bernard Rous",
+% note = "This is a full incoll entry with an editor and series",
+%}
+
+@article{fan2024dynamic,
+ title={Dynamic Gradient Alignment for Online Data Mixing},
+ author={Fan, Simin and Grangier, David and Ablin, Pierre},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{xie2024doremi,
+ title={Doremi: Optimizing data mixtures speeds up language model pretraining},
+ author={Xie, Sang Michael and Pham, Hieu and Dong, Xuanyi and Du, Nan and Liu, Hanxiao and Lu, Yifeng and Liang, Percy S and Le, Quoc V and Ma, Tengyu and Yu, Adams Wei},
+ journal={NIPS},
+ volume={36},
+ year={2024}
+}
+
+@inproceedings{fandoge,
+ title={DOGE: Domain Reweighting with Generalization Estimation},
+ author={Fan, Simin and Pagliardini, Matteo and Jaggi, Martin},
+ booktitle={ICML},
+ year={2024}
+}
+
+@article{kang2024autoscale,
+ title={Autoscale: Automatic prediction of compute-optimal data composition for training llms},
+ author={Kang, Feiyang and Sun, Yifan and Wen, Bingbing and Chen, Si and Song, Dawn and Mahmood, Rafid and Jia, Ruoxi},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{liu2024regmix,
+ title={Regmix: Data mixture as regression for language model pre-training},
+ author={Liu, Qian and Zheng, Xiaosen and Muennighoff, Niklas and Zeng, Guangtao and Dou, Longxu and Pang, Tianyu and Jiang, Jing and Lin, Min},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{ye2024data,
+ title={Data mixing laws: Optimizing data mixtures by predicting language modeling performance},
+ author={Ye, Jiasheng and Liu, Peiju and Sun, Tianxiang and Zhou, Yunhua and Zhan, Jun and Qiu, Xipeng},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{touvron2023llama,
+ title={Llama: Open and efficient foundation language models},
+ author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{weber2024redpajama,
+ title={Redpajama: an open dataset for training large language models},
+ author={Weber, Maurice and Fu, Daniel and Anthony, Quentin and Oren, Yonatan and Adams, Shane and Alexandrov, Anton and Lyu, Xiaozhong and Nguyen, Huu and Yao, Xiaozhe and Adams, Virginia and others},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{zhang2022opt,
+ title={Opt: Open pre-trained transformer language models},
+ author={Zhang, Susan and Roller, Stephen and Goyal, Naman and Artetxe, Mikel and Chen, Moya and Chen, Shuohui and Dewan, Christopher and Diab, Mona and Li, Xian and Lin, Xi Victoria and others},
+ journal={arXiv},
+ year={2022}
+}
+
+@inproceedings{biderman2023pythia,
+ title={Pythia: A suite for analyzing large language models across training and scaling},
+ author={Biderman, Stella and Schoelkopf, Hailey and Anthony, Quentin Gregory and Bradley, Herbie and O'Brien, Kyle and Hallahan, Eric and Khan, Mohammad Aflah and Purohit, Shivanshu and Prashanth, USVSN Sai and Raff, Edward and others},
+ booktitle={ICML},
+ pages={2397--2430},
+ year={2023},
+ organization={PMLR}
+}
+
+@article{xia2024less,
+ title={Less: Selecting influential data for targeted instruction tuning},
+ author={Xia, Mengzhou and Malladi, Sadhika and Gururangan, Suchin and Arora, Sanjeev and Chen, Danqi},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{zhou2024lima,
+ title={Lima: Less is more for alignment},
+ author={Zhou, Chunting and Liu, Pengfei and Xu, Puxin and Iyer, Srinivasan and Sun, Jiao and Mao, Yuning and Ma, Xuezhe and Efrat, Avia and Yu, Ping and Yu, Lili and others},
+ journal={Advances in Neural Information Processing Systems},
+ volume={36},
+ year={2024}
+}
+
+@article{luo2024velocitune,
+ title={Velocitune: A Velocity-based Dynamic Domain Reweighting Method for Continual Pre-training},
+ author={Luo, Zheheng and Zhang, Xin and Liu, Xiao and Li, Haoling and Gong, Yeyun and Qi, Chen and Cheng, Peng},
+ journal={arXiv},
+ year={2024}
+}
+
+@inproceedings{mckinzie2025mm1,
+ title={MM1: methods, analysis and insights from multimodal LLM pre-training},
+ author={McKinzie, Brandon and Gan, Zhe and Fauconnier, Jean-Philippe and Dodge, Sam and Zhang, Bowen and Dufter, Philipp and Shah, Dhruti and Du, Xianzhi and Peng, Futang and Belyi, Anton and others},
+ booktitle={ECCV},
+ pages={304--323},
+ year={2025},
+ organization={Springer}
+}
+
+@article{pruthi2020estimating,
+ title={Estimating training data influence by tracing gradient descent},
+ author={Pruthi, Garima and Liu, Frederick and Kale, Satyen and Sundararajan, Mukund},
+ journal={Advances in Neural Information Processing Systems},
+ volume={33},
+ pages={19920--19930},
+ year={2020}
+}
+
+@article{ge2024clustering,
+ title={Clustering and ranking: Diversity-preserved instruction selection through expert-aligned quality estimation},
+ author={Ge, Yuan and Liu, Yilun and Hu, Chi and Meng, Weibin and Tao, Shimin and Zhao, Xiaofeng and Ma, Hongxia and Zhang, Li and Chen, Boxing and Yang, Hao and others},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{liu2023makes,
+ title={What makes good data for alignment? a comprehensive study of automatic data selection in instruction tuning},
+ author={Liu, Wei and Zeng, Weihao and He, Keqing and Jiang, Yong and He, Junxian},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{xie2023data,
+ title={Data selection for language models via importance resampling},
+ author={Xie, Sang Michael and Santurkar, Shibani and Ma, Tengyu and Liang, Percy S},
+ journal={Advances in Neural Information Processing Systems},
+ volume={36},
+ pages={34201--34227},
+ year={2023}
+}
+
+@article{park2023trak,
+ title={Trak: Attributing model behavior at scale},
+ author={Park, Sung Min and Georgiev, Kristian and Ilyas, Andrew and Leclerc, Guillaume and Madry, Aleksander},
+ journal={arXiv},
+ year={2023}
+}
+
+@inproceedings{zhanglearning,
+ title={Learning Representation from Neural Fisher Kernel with Low-rank Approximation},
+ author={Zhang, Ruixiang and Zhai, Shuangfei and Littwin, Etai and Susskind, Joshua M},
+ booktitle={ICLR},
+ year={2022}
+}
+
+@article{lambert2024t,
+ title={T$\backslash$" ULU 3: Pushing Frontiers in Open Language Model Post-Training},
+ author={Lambert, Nathan and Morrison, Jacob and Pyatkin, Valentina and Huang, Shengyi and Ivison, Hamish and Brahman, Faeze and Miranda, Lester James V and Liu, Alisa and Dziri, Nouha and Lyu, Shane and others},
+ journal={arXiv},
+ year={2024}
+}
+
+@article{suzgun2022challenging,
+ title={Challenging big-bench tasks and whether chain-of-thought can solve them},
+ author={Suzgun, Mirac and Scales, Nathan and Sch{\"a}rli, Nathanael and Gehrmann, Sebastian and Tay, Yi and Chung, Hyung Won and Chowdhery, Aakanksha and Le, Quoc V and Chi, Ed H and Zhou, Denny and others},
+ journal={arXiv},
+ year={2022}
+}
+
+@article{clark2019boolq,
+ title={BoolQ: Exploring the surprising difficulty of natural yes/no questions},
+ author={Clark, Christopher and Lee, Kenton and Chang, Ming-Wei and Kwiatkowski, Tom and Collins, Michael and Toutanova, Kristina},
+ journal={arXiv},
+ year={2019}
+}
+
+@article{cobbe2021training,
+ title={Training verifiers to solve math word problems},
+ author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Chen, Mark and Jun, Heewoo and Kaiser, Lukasz and Plappert, Matthias and Tworek, Jerry and Hilton, Jacob and Nakano, Reiichiro and others},
+ journal={arXiv},
+ year={2021}
+}
+
+@article{zhou2023instruction,
+ title={Instruction-following evaluation for large language models},
+ author={Zhou, Jeffrey and Lu, Tianjian and Mishra, Swaroop and Brahma, Siddhartha and Basu, Sujoy and Luan, Yi and Zhou, Denny and Hou, Le},
+ journal={arXiv},
+ year={2023}
+}
+
+@article{lewkowycz2022solving,
+ title={Solving quantitative reasoning problems with language models},
+ author={Lewkowycz, Aitor and Andreassen, Anders and Dohan, David and Dyer, Ethan and Michalewski, Henryk and Ramasesh, Vinay and Slone, Ambrose and Anil, Cem and Schlag, Imanol and Gutman-Solo, Theo and others},
+ journal={Advances in Neural Information Processing Systems},
+ volume={35},
+ pages={3843--3857},
+ year={2022}
+}
+
+@article{hendrycks2020measuring,
+ title={Measuring massive multitask language understanding},
+ author={Hendrycks, Dan and Burns, Collin and Basart, Steven and Zou, Andy and Mazeika, Mantas and Song, Dawn and Steinhardt, Jacob},
+ journal={arXiv},
+ year={2020}
+}
+
+@inproceedings{bisk2020piqa,
+ title={Piqa: Reasoning about physical commonsense in natural language},
+ author={Bisk, Yonatan and Zellers, Rowan and Gao, Jianfeng and Choi, Yejin and others},
+ booktitle={AAAI},
+ volume={34},
+ pages={7432--7439},
+ year={2020}
+}
+
+@article{jin2019pubmedqa,
+ title={Pubmedqa: A dataset for biomedical research question answering},
+ author={Jin, Qiao and Dhingra, Bhuwan and Liu, Zhengping and Cohen, William W and Lu, Xinghua},
+ journal={arXiv},
+ year={2019}
+}
+
+@article{lin2021truthfulqa,
+ title={Truthfulqa: Measuring how models mimic human falsehoods},
+ author={Lin, Stephanie and Hilton, Jacob and Evans, Owain},
+ journal={arXiv},
+ year={2021}
+}
+
+@misc{eval-harness,
+ author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},
+ title = {A framework for few-shot language model evaluation},
+ month = 07,
+ year = 2024,
+ publisher = {Zenodo},
+ version = {v0.4.3},
+ doi = {10.5281/zenodo.12608602},
+ url = {https://zenodo.org/records/12608602}
+}
+
+@inproceedings{zheng2024llamafactory,
+ title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
+ author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
+ booktitle={ACL},
+ address={Bangkok, Thailand},
+ publisher={Association for Computational Linguistics},
+ year={2024},
+ url={http://arxiv.org/abs/2403.13372}
+}
+
+@article{deng2024drpruning,
+ title={DRPruning: Efficient Large Language Model Pruning through Distributionally Robust Optimization},
+ author={Deng, Hexuan and Jiao, Wenxiang and Liu, Xuebo and Li, Jing and Zhang, Min and Tu, Zhaopeng},
+ journal={arXiv preprint arXiv:2411.14055},
+ year={2024}
+}
+
+@article{liu2024ddk,
+ title={Ddk: Distilling domain knowledge for efficient large language models},
+ author={Liu, Jiaheng and Zhang, Chenchen and Guo, Jinyang and Zhang, Yuanxing and Que, Haoran and Deng, Ken and Liu, Jie and Zhang, Ge and Wu, Yanan and Liu, Congnan and others},
+ journal={Advances in Neural Information Processing Systems},
+ volume={37},
+ pages={98297--98319},
+ year={2024}
+}
+
+@article{chen2024large,
+ title={Large Knowledge Model: Perspectives and Challenges},
+ author={Chen, Huajun},
+ journal={Data Intelligence},
+ volume={6},
+ number={3},
+ pages={587--620},
+ year={2024},
+ publisher={China Science Publishing & Media Ltd.},
+ doi={10.3724/2096-7004.di.2024.0001}
+}
+
+@article{azaria2024chat,
+ title={ChatGPT is a Remarkable Tool—For Experts},
+ author={Azaria, Amos and Azoulay, Rina and Reches, Shulamit},
+ journal={Data Intelligence},
+ volume={6},
+ number={1},
+ pages={240--296},
+ year={2024},
+ publisher={China Science Publishing & Media Ltd.},
+ doi={10.1162/dint_a_00235}
+}
+
+@article{zhang2024effective,
+ title={Effective Tool Augmented Multi-Agent Framework for Data Analysis},
+ author={Zhang, Xilin and Mao, Zhixin and Chen, Ziwen and Gao, Shen},
+ journal={Data Intelligence},
+ volume={6},
+ number={4},
+ pages={923--945},
+ year={2024},
+ publisher={China Science Publishing & Media Ltd.},
+ doi={10.3724/2096-7004.di.2024.0013}
+}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/alg.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/alg.tex"
new file mode 100644
index 0000000000..1cab530525
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/alg.tex"
@@ -0,0 +1,50 @@
+\begin{algorithm*}[t]
+\caption{Domain Impact-aware Data Sampling (DIDS)}
+\label{alg:dids}
+\KwIn{Training dataset $\mathcal{D}$; Downstream tasks $\mathcal{S}$; Proxy model $f'$; Number of domains $k$; Update interval $\tau$; EMA coefficient $\beta$; Training steps $T$}
+\KwOut{Domain sampling probabilities $\mathbf{p}_t$}
+
+// Initialize sampling probabilities uniformly\;
+$\mathbf{p}_0 \leftarrow [1/k,...,1/k]$\;
+
+// Domain repartition based on gradients\;
+$G \leftarrow \emptyset$\;
+\ForEach{$x_i \in \mathcal{D}$}{
+ $g_i \leftarrow \nabla \ell(f', x_i)$ \;
+ $g_i \leftarrow \text{TopK}(g_i)$ \;
+ $\tilde{g}_i \leftarrow R^T g_i$\;
+ $G \leftarrow G \cup \{\tilde{g}_i\}$\;
+}
+$\{D_1,...,D_k\} \leftarrow \text{KMeans}(G, k)$\;
+
+\For{$t \leftarrow 1$ \KwTo $T$}{
+ \If{$t \bmod \tau = 0$}{
+ // Compute domain impact matrix\;
+ \ForEach{$D_i \in \{D_1,...,D_k\}$}{
+ \ForEach{$S_j \in \{S_1,...,S_m\}$}{
+ $\Delta \leftarrow \nabla \ell_{S_j} - \nabla \ell_{D_i}$\;
+ $F \leftarrow \mathbb{E}[\nabla\log p(\theta) \odot \nabla\log p(\theta)]$\;
+ $I(D_i,S_j) \leftarrow \frac{1}{2}\Delta^T F \Delta$\;
+ }
+ }
+
+ // Compute future potential\;
+ \ForEach{$S_j \in \{S_1,...,S_m\}$}{
+ Fit $L(t) = ae^{-bt} + c$ using loss history $\{L_1(S_j),...,L_t(S_j)\}$\;
+ $L_p(S_j) \leftarrow L_t(S_j) - L(t + \tau)$\;
+ $\Delta L(S_j) \leftarrow L_{t-1}(S_j) - L_t(S_j)$\;
+ }
+
+ // Update sampling probabilities\;
+ \ForEach{$D_i \in \{D_1,...,D_k\}$}{
+ $U(D_i) \leftarrow \sum_j \frac{I(D_i,S_j) \cdot (\Delta L(S_j) + L_p(S_j))}{p_{t-1,i}}$\;
+ $\hat{p}_{t,i} \leftarrow \text{softmax}(U(D_i))$\;
+ $p_{t,i} \leftarrow \beta p_{t-1,i} + (1-\beta)\hat{p}_{t,i}$\;
+ }
+ $\mathbf{p}_t \leftarrow \mathbf{p}_t / \sum_i p_{t,i}$\;
+ }
+ Sample batch $\mathcal{B}_t$ according to $\mathbf{p}_t$\;
+ Update model parameters $\theta$ using $\mathcal{B}_t$\;
+}
+\Return{$\mathbf{p}_t$}
+\end{algorithm*}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/architecture.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/architecture.tex"
new file mode 100644
index 0000000000..033776cd4b
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/architecture.tex"
@@ -0,0 +1,11 @@
+\begin{figure*}[t]
+\centering
+\begin{minipage}[t]{1\linewidth}
+\centering
+\includegraphics[width=1.0\textwidth]{latex/figure/src/mixture-architecture.pdf}
+\end{minipage}
+\centering
+% \caption{Overview of DIDS, which consists of three steps: gradient-based domain repartition, FIM-guided impact measurement, and dynamic sampling probability update based on both current improvement and future potential.}
+\caption{Overview of DIDS's three-step process: (1) Domain repartition using gradient-based clustering with a proxy model and dimensionality reduction through random projection, (2) Domain impact measurement using FIM-guided metrics that quantify how domain-specific parameter updates affect the model's output distributions on downstream tasks, and (3) Dynamic sampling probability updates that combine both FIM-guided domain impact assessment and loss learning trajectories to account for diminishing marginal returns.}
+\label{fig:architecture}
+\end{figure*}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/domain_weights.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/domain_weights.tex"
new file mode 100644
index 0000000000..8cf24fcdcc
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/domain_weights.tex"
@@ -0,0 +1,41 @@
+\begin{table*}[t]
+\centering
+\small
+\begin{tabular}{l|c|ccccccccccc}
+\toprule
+Domain ID & Method & 0\% & 10\% & 20\% & 30\% & 40\% & 50\% & 60\% & 70\% & 80\% & 90\% & 100\% \\
+\midrule
+D023 & DIDS & 0.4 & 0.5 & 0.7 & 0.9 & 1.2 & 1.6 & 1.9 & 2.3 & 2.5 & 2.7 & 2.8 \\
+ & DGA & 0.4 & 0.4 & 0.5 & 0.6 & 0.7 & 0.9 & 1.0 & 1.1 & 1.2 & 1.3 & 1.3 \\
+\midrule
+D045 & DIDS & 0.4 & 0.3 & 0.2 & 0.2 & 0.1 & 0.1 & 0.1 & 0.0 & 0.0 & 0.0 & 0.0 \\
+ & DGA & 0.4 & 0.3 & 0.3 & 0.2 & 0.2 & 0.2 & 0.2 & 0.1 & 0.1 & 0.1 & 0.1 \\
+\midrule
+D078 & DIDS & 0.4 & 0.6 & 0.8 & 1.0 & 1.5 & 1.2 & 1.2 & 1.3 & 1.3 & 1.4 & 1.3 \\
+ & DGA & 0.4 & 0.5 & 0.6 & 0.7 & 0.8 & 0.8 & 0.9 & 0.4 & 0.6 & 0.3 & 0.4 \\
+\midrule
+D102 & DIDS & 0.4 & 0.7 & 0.9 & 1.2 & 1.4 & 1.5 & 1.6 & 1.7 & 1.7 & 1.7 & 1.7 \\
+ & DGA & 0.4 & 0.5 & 0.6 & 0.7 & 0.8 & 0.8 & 0.9 & 1.0 & 1.0 & 1.0 & 1.0 \\
+\midrule
+D129 & DIDS & 0.4 & 0.2 & 0.1 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\
+ & DGA & 0.4 & 0.3 & 0.2 & 0.2 & 0.1 & 0.1 & 0.1 & 0.1 & 0.1 & 0.0 & 0.0 \\
+\midrule
+D147 & DIDS & 0.4 & 0.3 & 0.2 & 0.2 & 0.1 & 0.1 & 0.1 & 0.1 & 0.1 & 0.0 & 0.0 \\
+ & DGA & 0.4 & 0.4 & 0.3 & 0.3 & 0.3 & 0.2 & 0.2 & 0.2 & 0.2 & 0.2 & 0.1 \\
+\midrule
+D175 & DIDS & 0.4 & 0.5 & 0.5 & 0.6 & 0.6 & 0.7 & 0.7 & 0.7 & 0.7 & 0.8 & 0.8 \\
+ & DGA & 0.4 & 0.5 & 0.5 & 0.5 & 0.6 & 0.6 & 0.6 & 0.6 & 0.6 & 0.6 & 0.6 \\
+\midrule
+D198 & DIDS & 0.4 & 0.7 & 1.0 & 1.3 & 1.5 & 1.7 & 1.8 & 1.9 & 1.9 & 1.9 & 1.9 \\
+ & DGA & 0.4 & 0.6 & 0.7 & 0.8 & 0.9 & 0.9 & 1.0 & 1.0 & 1.0 & 1.0 & 1.0 \\
+\midrule
+D221 & DIDS & 0.4 & 0.2 & 0.1 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\
+ & DGA & 0.4 & 0.3 & 0.2 & 0.1 & 0.1 & 0.1 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\
+\midrule
+D244 & DIDS & 0.4 & 0.4 & 0.5 & 0.4 & 0.4 & 0.4 & 0.4 & 0.5 & 0.4 & 0.4 & 0.4 \\
+ & DGA & 0.4 & 0.4 & 0.4 & 0.4 & 0.4 & 0.4 & 0.4 & 0.4 & 0.4 & 0.4 & 0.4 \\
+\bottomrule
+\end{tabular}
+\caption{Comparison of domain weight evolution between DIDS and DGA across training progress (from 0\% to 100\% completion).}
+\label{tab:domain_weights}
+\end{table*}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/fig-q4.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/fig-q4.tex"
new file mode 100644
index 0000000000..29b5ac83c7
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/fig-q4.tex"
@@ -0,0 +1,92 @@
+\begin{figure}[t]
+ \subfigure[Impact of Update Frequency]{
+ \begin{minipage}[t]{0.5\linewidth}
+ \centering
+ \begin{tikzpicture}[scale=0.45]
+ \begin{axis}[
+ xlabel={Update Count},
+ ylabel={Average Score},
+ legend style={at={(0.98,0.02)}, anchor=south east},
+ ymajorgrids=true,
+ grid style=dashed,
+ ymin=55,
+ ymax=64,
+ font=\Large,
+ ]
+ \addplot[thick,mark=*,blue] coordinates {
+ (5,58.2)
+ (25,60.1)
+ (45,61.8)
+ (65,62.3)
+ (85,62.4)
+ (95,62.3)
+ };
+ \addplot[thick,mark=square*,red] coordinates {
+ (5,57.8)
+ (25,59.2)
+ (45,59.8)
+ (65,60.1)
+ (85,60.0)
+ (95,59.9)
+ };
+ \addplot[thick,mark=triangle*,brown] coordinates {
+ (5,58.9)
+ (25,58.9)
+ (45,58.9)
+ (65,58.9)
+ (85,58.9)
+ (95,58.9)
+ };
+ \legend{DIDS,DGA,Random}
+ \end{axis}
+ \end{tikzpicture}
+ \end{minipage}
+ }%
+ \subfigure[Impact of Irrelevant Data]{
+ \begin{minipage}[t]{0.5\linewidth}
+ \centering
+ \begin{tikzpicture}[scale=0.45]
+ \begin{axis}[
+ xlabel={Irrelevant Data Ratio (\%)},
+ ylabel={Average Score},
+ legend style={at={(0.98,0.02)}, anchor=south east},
+ ymajorgrids=true,
+ grid style=dashed,
+ ymin=45,
+ ymax=66,
+ font=\Large,
+ ]
+ \addplot[thick,mark=*,blue] coordinates {
+ (0,62.3)
+ (5,62.5)
+ (10,62.4)
+ (15,63.2)
+ (20,63.5)
+ (25,63.1)
+ };
+ \addplot[thick,mark=square*,red] coordinates {
+ (0,58.5)
+ (5,58.3)
+ (10,57.5)
+ (15,58.2)
+ (20,57.8)
+ (25,57.1)
+ };
+ \addplot[thick,mark=triangle*,brown] coordinates {
+ (0,58.9)
+ (5,57.8)
+ (10,57.4)
+ (15,57.0)
+ (20,56.4)
+ (25,54.2)
+ };
+ \legend{DIDS,DGA,Random}
+ \end{axis}
+ \end{tikzpicture}
+ \end{minipage}
+ }
+ \vspace{-0.3cm}
+ \caption{Effects of update frequency and irrelevant data.}
+ \label{fig:parameter-analysis}
+ \vspace{-0.5cm}
+\end{figure}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/fig-q5.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/fig-q5.tex"
new file mode 100644
index 0000000000..b8266bfeaf
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/fig-q5.tex"
@@ -0,0 +1,64 @@
+\begin{figure}[t]
+ \subfigure[Impact of Proxy Model Size]{
+ \begin{minipage}[t]{0.45\linewidth}
+ \centering
+ \begin{tikzpicture}[scale=0.41]
+ \begin{axis}[
+ xlabel={Model Size},
+ ylabel={Average Score},
+ symbolic x coords={500M,1B,8B},
+ xtick=data,
+ ymajorgrids=true,
+ grid style=dashed,
+ ymin=60,
+ ymax=64,
+ ybar,
+ bar width=20pt,
+ font=\Large,
+ nodes near coords,
+ nodes near coords align={vertical},
+ ]
+ \addplot[fill=blue!40] coordinates {
+ (500M,62.3)
+ (1B,62.4)
+ (8B,62.5)
+ };
+ \end{axis}
+ \end{tikzpicture}
+ \end{minipage}
+ }%
+ \subfigure[Impact of Domain Count]{
+ \begin{minipage}[t]{0.45\linewidth}
+ \centering
+ \begin{tikzpicture}[scale=0.41]
+ \begin{axis}[
+ xlabel={Number of Domains},
+ ylabel={Average Score},
+ legend style={at={(0.98,0.02)}, anchor=south east},
+ ymajorgrids=true,
+ grid style=dashed,
+ ymin=61,
+ ymax=63,
+ % xtick={18,36,72,144,288,576,1152},
+ font=\Large,
+ % mark options={scale=1.5}
+ ]
+ \addplot[thick,mark=*,blue] coordinates {
+ (18,61.3)
+ (36,62.0)
+ (72,62.3)
+ (144,62.4)
+ (288,62.46)
+ (576,62.52)
+ (1152,62.73)
+ (2304,62.74)
+ };
+ \end{axis}
+ \end{tikzpicture}
+ \end{minipage}
+ }
+ \vspace{-0.3cm}
+ \caption{Effects of model size and domain count.}
+ \label{fig:model-analysis}
+ \vspace{-0.5cm}
+\end{figure}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/full_ablation.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/full_ablation.tex"
new file mode 100644
index 0000000000..9e09cd0e54
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/full_ablation.tex"
@@ -0,0 +1,17 @@
+\begin{table*}[thbp]
+\centering
+\small
+\begin{tabular}{l|cc|cc|c|cc|c|c|c}
+\toprule
+\multirow{2}{*}{Variant} & \multicolumn{2}{c|}{Reasoning} & \multicolumn{2}{c|}{Mathematics} & Instruction & \multicolumn{2}{c|}{Commonsense} & Domain & Truthfulness & \multirow{2}{*}{Average} \\
+& BBH & BoolQ & GSM8K & MathQA & IFEval & MMLU & PIQA & PubMedQA & TruthfulQA & \\
+\midrule
+DIDS (100k) & \textbf{68.3} & \textbf{86.9} & \textbf{59.0} & \textbf{20.5} & \textbf{55.6} & \textbf{64.9} & \textbf{82.2} & \textbf{80.4} & \textbf{43.0} & \textbf{62.3} \\
+DIDS-GC & 67.7 & 85.7 & 58.2 & 19.7 & 53.0 & 64.4 & 81.8 & 78.9 & 40.1 & 61.1 \\
+DIDS-FIM & 67.2 & 85.0 & 57.4 & 18.6 & 51.9 & 64.1 & 81.5 & 77.2 & 38.5 & 60.2 \\
+DIDS-LT & 67.5 & 85.3 & 57.8 & 19.5 & 51.4 & 64.2 & 81.6 & 77.5 & 38.1 & 60.3 \\
+\bottomrule
+\end{tabular}
+\caption{Comprehensive ablation study of DIDS across all downstream tasks. DIDS-GC replaces gradient-based clustering with BERT semantic clustering, DIDS-FIM removes the FIM-guided impact measurement, and DIDS-LT eliminates the loss trajectory and saturation consideration.}
+\label{tab:full_ablation}
+\end{table*}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/grid_search.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/grid_search.tex"
new file mode 100644
index 0000000000..98347ac19b
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/grid_search.tex"
@@ -0,0 +1,25 @@
+\begin{table}[t]
+\centering
+\scriptsize
+\begin{tabular}{lccc}
+\toprule
+\textbf{Mixing Strategy} & \textbf{GSM8K} & \textbf{HumanEval} & \textbf{MT-Bench} \\
+\midrule
+Mix[(code,math), 1 general] & 47.53 & 14.63 & 5.76 \\
+Mix[(code,math), 1/4 general] & 48.44 & 15.85 & 5.73 \\
+Mix[(code,math), 1/16 general] & 47.99 & 15.24 & 5.27 \\
+Mix[(code,math), 1/64 general] & 47.23 & 14.63 & 5.16 \\
+Mix[(code,math), 1/256 general] & 48.52 & 16.46 & 4.69 \\
+\midrule
+Mix[1(code,math), general] & 47.53 & 14.63 & 5.76 \\
+Mix[1/4(code,math), general] & 41.31 & 10.97 & 5.81 \\
+Mix[1/16(code,math), general] & 33.20 & 11.58 & 5.76 \\
+Mix[1/64(code,math), general] & 25.17 & 12.19 & 5.84 \\
+Mix[1/256(code,math), general] & 16.52 & 9.14 & 5.82 \\
+\midrule
+DIDS (dynamic) & 52.21 & 18.05 & 5.88 \\
+\bottomrule
+\end{tabular}
+\caption{Performance with different static mixing ratios between specialized and general domains.}
+\label{tab:grid_search}
+\end{table}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/llama_openhermes.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/llama_openhermes.tex"
new file mode 100644
index 0000000000..591a098bc9
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/llama_openhermes.tex"
@@ -0,0 +1,33 @@
+\setlength{\tabcolsep}{3pt}
+\begin{table*}[thbp]
+\centering
+\small
+\begin{tabular}{l|cc|cc|c|cc|c|c|c}
+\toprule
+\multirow{2}{*}{Method} & \multicolumn{2}{c|}{Reasoning} & \multicolumn{2}{c|}{Mathematics} & Instruction & \multicolumn{2}{c|}{Commonsense} & Domain & Truthfulness & \multirow{2}{*}{Average} \\
+& BBH & BoolQ & GSM8K & MathQA & IFEval & MMLU & PIQA & PubMedQA & TruthfulQA & \\
+\midrule
+\multicolumn{11}{c}{Multi-task Optimization} \\
+\midrule
+Llama-3.1-8B & 62.5 & 81.8 & 48.9 & 15.7 & 18.5 & 64.7 & 81.1 & 75.8 & 28.5 & 53.1 \\
++ Full OH-2.5 (1000k) & 67.5 & 86.8 & 67.0 & 17.5 & 60.0 & 64.5 & 81.5 & 77.5 & 39.5 & 62.4 \\
++ Random (100k) & 66.8 & 85.0 & 60.2 & 12.5 & 49.0 & 64.2 & 82.0 & 76.8 & 34.0 & 59.0 \\
++ Uniform (100k) & 65.8 & 83.5 & 59.0 & 12.8 & 48.8 & 64.0 & 81.6 & 76.2 & 33.5 & 58.4 \\
++ Doremi (100k) & 67.2 & 85.5 & 61.5 & 17.0 & 50.5 & 64.5 & 82.0 & 77.5 & 38.0 & 60.4 \\
++ Velocitune (100k) & 67.0 & 85.2 & 60.0 & 16.8 & 50.0 & 64.3 & 81.8 & 77.2 & 37.5 & 60.0 \\
++ Doge (100k) & 67.5 & 85.8 & 61.0 & 17.5 & 52.0 & 64.6 & 82.1 & 78.0 & 39.0 & 60.8 \\
++ DGA (100k) & 66.8 & 85.2 & 61.8 & 17.8 & 46.5 & 64.7 & 81.9 & 76.5 & 36.5 & 59.7 \\
++ DIDS (100k) & \textbf{68.0} & \textbf{86.5} & \textbf{62.5} & \textbf{19.5} & \textbf{56.0} & \textbf{64.8} & \textbf{82.3} & \textbf{79.5} & \textbf{45.0} & \textbf{62.7} \\
+\midrule
+\multicolumn{11}{c}{Single-task Optimization} \\
+\midrule
++ Doremi (100k) & 68.5 & 86.0 & 63.0 & 18.0 & 52.5 & 64.8 & 82.5 & 78.2 & 39.0 & 61.4 \\
++ Velocitune (100k) & 67.8 & 85.8 & 62.5 & 17.8 & 52.0 & 64.6 & 82.2 & 78.0 & 38.6 & 61.0 \\
++ Doge (100k) & 68.0 & 86.5 & 63.0 & 18.2 & 53.0 & 64.8 & 82.3 & 78.8 & 39.5 & 61.6 \\
++ DGA (100k) & 68.4 & 86.2 & 64.0 & 19.0 & 54.5 & 65.0 & 82.5 & 78.5 & 40.5 & 62.1 \\
++ DIDS (100k) & \textbf{69.0} & \textbf{87.2} & \textbf{65.5} & \textbf{21.0} & \textbf{58.5} & \textbf{65.5} & \textbf{82.8} & \textbf{80.5} & \textbf{46.5} & \textbf{64.1} \\
+\bottomrule
+\end{tabular}
+\caption{Performance comparison of the Llama-3.1-8B model trained on the OpenHermes-2.5 dataset under different sampling strategies.}
+\label{tab:llama_openhermes}
+\end{table*}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/mixtral_openhermes.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/mixtral_openhermes.tex"
new file mode 100644
index 0000000000..1e8dada2dd
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/mixtral_openhermes.tex"
@@ -0,0 +1,33 @@
+\setlength{\tabcolsep}{3pt}
+\begin{table*}[thbp]
+\centering
+\small
+\begin{tabular}{l|cc|cc|c|cc|c|c|c}
+\toprule
+\multirow{2}{*}{Method} & \multicolumn{2}{c|}{Reasoning} & \multicolumn{2}{c|}{Mathematics} & Instruction & \multicolumn{2}{c|}{Commonsense} & Domain & Truthfulness & \multirow{2}{*}{Average} \\
+& BBH & BoolQ & GSM8K & MathQA & IFEval & MMLU & PIQA & PubMedQA & TruthfulQA & \\
+\midrule
+\multicolumn{11}{c}{Multi-task Optimization} \\
+\midrule
+Mixtral-7B & 56.0 & 84.7 & 36.9 & 13.2 & 36.6 & 61.9 & 81.6 & 77.8 & 41.3 & 54.4 \\
++ Full OH-2.5 (1000k) & 59.8 & 87.9 & 64.2 & 14.4 & 45.8 & 62.0 & 82.6 & 76.6 & 50.5 & 60.4 \\
++ Random (100k) & 58.5 & 86.5 & 54.0 & 10.5 & 42.0 & 61.8 & 82.0 & 76.0 & 46.0 & 57.5 \\
++ Uniform (100k) & 57.8 & 86.0 & 52.5 & 10.8 & 41.8 & 61.7 & 81.8 & 75.5 & 45.5 & 57.0 \\
++ Doremi (100k) & 59.0 & 87.0 & 55.5 & 13.8 & 43.5 & 62.0 & 82.2 & 76.2 & 48.0 & 58.6 \\
++ Velocitune (100k) & 58.7 & 86.8 & 54.5 & 13.5 & 43.0 & 61.9 & 82.0 & 76.0 & 47.5 & 58.2 \\
++ Doge (100k) & 59.2 & 87.2 & 55.8 & 14.0 & 44.0 & 62.1 & 82.3 & 76.8 & 48.5 & 58.9 \\
++ DGA (100k) & 58.5 & 86.8 & 56.5 & 14.2 & 41.0 & 62.2 & 82.1 & 75.8 & 47.0 & 58.2 \\
++ DIDS (100k) & \textbf{60.0} & \textbf{87.5} & \textbf{58.0} & \textbf{15.8} & \textbf{45.0} & \textbf{62.3} & \textbf{82.5} & \textbf{77.5} & \textbf{52.0} & \textbf{60.1} \\
+\midrule
+\multicolumn{11}{c}{Single-task Optimization} \\
+\midrule
++ Doremi (100k) & 60.5 & 87.5 & 58.0 & 14.5 & 44.5 & 62.2 & 82.8 & 77.0 & 49.0 & 59.6 \\
++ Velocitune (100k) & 60.0 & 87.2 & 57.5 & 14.2 & 44.0 & 62.0 & 82.5 & 76.8 & 48.5 & 59.2 \\
++ Doge (100k) & 60.2 & 87.8 & 58.2 & 14.8 & 44.8 & 62.3 & 82.7 & 77.2 & 49.5 & 59.7 \\
++ DGA (100k) & 60.8 & 87.5 & 59.5 & 15.0 & 46.0 & 62.5 & 82.9 & 77.0 & 50.0 & 60.1 \\
++ DIDS (100k) & \textbf{61.5} & \textbf{88.0} & \textbf{61.0} & \textbf{16.5} & \textbf{47.5} & \textbf{62.8} & \textbf{83.0} & \textbf{78.0} & \textbf{53.0} & \textbf{61.3} \\
+\bottomrule
+\end{tabular}
+\caption{Performance comparison of the Mixtral-7B model trained on the OpenHermes-2.5 dataset under different sampling strategies.}
+\label{tab:mixtral_openhermes}
+\end{table*}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/mixtral_tulu.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/mixtral_tulu.tex"
new file mode 100644
index 0000000000..4ebde44c62
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/mixtral_tulu.tex"
@@ -0,0 +1,33 @@
+\setlength{\tabcolsep}{3pt}
+\begin{table*}[thbp]
+\centering
+\small
+\begin{tabular}{l|cc|cc|c|cc|c|c|c}
+\toprule
+\multirow{2}{*}{Method} & \multicolumn{2}{c|}{Reasoning} & \multicolumn{2}{c|}{Mathematics} & Instruction & \multicolumn{2}{c|}{Commonsense} & Domain & Truthfulness & \multirow{2}{*}{Average} \\
+& BBH & BoolQ & GSM8K & MathQA & IFEval & MMLU & PIQA & PubMedQA & TruthfulQA & \\
+\midrule
+\multicolumn{11}{c}{Multi-task Optimization} \\
+\midrule
+Mixtral-7B & 56.0 & 84.7 & 36.9 & 13.2 & 36.6 & 61.9 & 81.6 & 77.8 & 41.3 & 54.4 \\
++ Full Data (929k) & 61.0 & 87.8 & 52.5 & 14.0 & 63.2 & 62.0 & 81.8 & 79.5 & 42.0 & 60.4 \\
++ Random (100k) & 60.2 & 86.7 & 47.8 & 9.5 & 55.3 & 61.8 & 82.2 & 78.5 & 43.0 & 58.3 \\
++ Uniform (100k) & 59.1 & 85.3 & 46.5 & 9.8 & 55.0 & 61.5 & 81.8 & 77.5 & 42.8 & 57.7 \\
++ Doremi (100k) & 60.5 & 86.5 & 48.0 & 15.0 & 56.5 & 62.0 & 82.0 & 79.0 & 46.5 & 59.5 \\
++ Velocitune (100k) & 60.2 & 86.2 & 46.0 & 14.8 & 56.0 & 61.8 & 81.9 & 78.8 & 45.8 & 59.1 \\
++ Doge (100k) & 60.8 & 86.8 & 47.0 & 15.3 & 57.8 & 62.1 & 82.3 & 79.6 & 47.2 & 59.9 \\
++ DGA (100k) & 60.0 & 86.3 & 48.0 & 15.8 & 53.5 & 62.2 & 82.0 & 77.0 & 44.5 & 58.8 \\
++ DIDS (100k) & \textbf{61.5} & \textbf{87.0} & \textbf{48.5} & \textbf{17.8} & \textbf{60.0} & \textbf{62.4} & \textbf{82.5} & \textbf{81.0} & \textbf{50.5} & \textbf{61.2} \\
+\midrule
+\multicolumn{11}{c}{Single-task Optimization} \\
+\midrule
++ Doremi (100k) & 61.8 & 86.8 & 50.0 & 15.8 & 57.5 & 62.3 & 82.8 & 79.8 & 47.0 & 60.4 \\
++ Velocitune (100k) & 61.0 & 86.5 & 49.5 & 15.5 & 57.0 & 62.0 & 82.3 & 79.5 & 46.5 & 60.0 \\
++ Doge (100k) & 61.3 & 87.0 & 50.2 & 16.0 & 57.8 & 62.4 & 82.5 & 80.0 & 47.5 & 60.5 \\
++ DGA (100k) & 61.7 & 86.8 & 51.0 & 16.7 & 58.8 & 62.6 & 82.6 & 79.8 & 48.0 & 60.9 \\
++ DIDS (100k) & \textbf{62.5} & \textbf{87.5} & \textbf{52.0} & \textbf{18.5} & \textbf{62.0} & \textbf{63.0} & \textbf{83.0} & \textbf{82.0} & \textbf{52.0} & \textbf{62.5} \\
+\bottomrule
+\end{tabular}
+\caption{Performance comparison of the Mixtral-7B model trained on the Tulu-v3 dataset under different sampling strategies.}
+\label{tab:mixtral_tulu}
+\end{table*}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/src/mixture-architecture.pdf" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/src/mixture-architecture.pdf"
new file mode 100644
index 0000000000..e914db05f3
Binary files /dev/null and "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/src/mixture-architecture.pdf" differ
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/t1.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/t1.tex"
new file mode 100644
index 0000000000..49148ed734
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/t1.tex"
@@ -0,0 +1,32 @@
+\begin{table}[t]
+\footnotesize
+\centering
+\begin{tabular}{lrr}
+\toprule
+\textbf{Dataset} & \textbf{Samples} & \textbf{Percentage (\%)} \\
+\midrule
+Tulu 3 Persona MATH & 149,960 & 16.0 \\
+Evol CodeAlpaca & 107,276 & 11.4 \\
+FLAN v2 & 89,982 & 9.6 \\
+NuminaMath-TIR & 64,312 & 6.8 \\
+Tulu 3 Persona GSM & 49,980 & 5.3 \\
+Tulu 3 WildGuardMix & 50,000 & 5.3 \\
+Tulu 3 WildJailbreak & 50,000 & 5.3 \\
+Tulu 3 Persona Python & 34,999 & 3.7 \\
+Tulu 3 Persona IF & 29,980 & 3.2 \\
+Tulu 3 Persona Algebra & 20,000 & 2.1 \\
+CoCoNot & 10,983 & 1.2 \\
+No Robots & 9,500 & 1.0 \\
+OpenAssistant Guanaco & 7,132 & 0.8 \\
+TableGPT & 5,000 & 0.5 \\
+Tulu 3 Hardcoded & 240 & 0.03 \\
+Aya & 100,000 & 10.6 \\
+WildChat GPT-4 & 100,000 & 10.6 \\
+SciRIFF & 10,000 & 1.1 \\
+\midrule
+\textbf{Total} & 939,344 & 100.0 \\
+\bottomrule
+\end{tabular}
+\caption{Distribution of training data across different sources.}
+\label{tab:data_dist}
+\end{table}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/tab-q1.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/tab-q1.tex"
new file mode 100644
index 0000000000..077c0d0dea
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/tab-q1.tex"
@@ -0,0 +1,37 @@
+\setlength{\tabcolsep}{3pt}
+\begin{table*}[t]
+\footnotesize
+\begin{center}
+\begin{tabular}{l|ccccccccc|c}
+\toprule
+\multirow{2}{*}{Method} & \multicolumn{2}{c}{Reasoning} & \multicolumn{2}{c}{Mathematics} & \multicolumn{1}{c}{Instruction} & \multicolumn{2}{c}{Commonsense} & \multicolumn{1}{c}{Domain} & \multicolumn{1}{c|}{Truthfulness} & \multirow{2}{*}{Average} \\
+\cmidrule(lr){2-3} \cmidrule(lr){4-5} \cmidrule(lr){6-6} \cmidrule(lr){7-8} \cmidrule(lr){9-9} \cmidrule(lr){10-10}
+& BBH & BoolQ & GSM8K & MathQA & IFEval & MMLU & PIQA & PubMedQA & TruthfulQA & \\
+\midrule
+\multicolumn{11}{c}{\textit{Multi-task Optimization}} \\
+\midrule
+Llama-3.1-8B & \colorbox[rgb]{1.0,0.7,0.7}{62.5} & \colorbox[rgb]{1.0,0.7,0.7}{81.8} & \colorbox[rgb]{1.0,0.7,0.7}{48.9} & \colorbox[rgb]{1.0,0.9,0.9}{15.7} & \colorbox[rgb]{1.0,0.7,0.7}{18.5} & 64.7 & 81.1 & \colorbox[rgb]{1.0,0.8,0.8}{75.8} & \colorbox[rgb]{1.0,0.7,0.7}{28.5} & \colorbox[rgb]{1.0,0.7,0.7}{53.1} \\
++ Full Data (929k) & \colorbox[rgb]{0.8,0.8,1.0}{68.0} & \colorbox[rgb]{0.8,0.8,1.0}{87.3} & \colorbox[rgb]{0.8,0.8,1.0}{65.2} & 16.2 & \colorbox[rgb]{0.8,0.8,1.0}{61.2} & \colorbox[rgb]{1.0,0.85,0.85}{64.3} & \colorbox[rgb]{1.0,0.85,0.85}{81.0} & \colorbox[rgb]{0.85,0.85,1.0}{78.0} & \colorbox[rgb]{1.0,0.8,0.8}{29.5} & \colorbox[rgb]{0.85,0.85,1.0}{61.2} \\
++ Random (100k) & 67.4 & 85.6 & \colorbox[rgb]{1.0,0.8,0.8}{58.9} & \colorbox[rgb]{1.0,0.7,0.7}{11.4} & \colorbox[rgb]{1.0,0.9,0.9}{48.2} & 64.0 & \colorbox[rgb]{0.9,0.9,1.0}{82.0} & 77.4 & 31.5 & 58.9 \\
++ Uniform (100k) & \colorbox[rgb]{1.0,0.8,0.8}{66.2} & \colorbox[rgb]{1.0,0.8,0.8}{83.2} & \colorbox[rgb]{1.0,0.8,0.8}{57.5} & \colorbox[rgb]{1.0,0.8,0.8}{11.8} & 48.2 & 64.1 & 81.5 & \colorbox[rgb]{1.0,0.8,0.8}{76.0} & 31.2 & \colorbox[rgb]{1.0,0.8,0.8}{57.7} \\
+\midrule
++ Doremi (100k) & \colorbox[rgb]{0.9,0.9,1.0}{67.5} & 85.8 & 58.8 & 17.5 & 49.8 & 64.5 & 81.9 & 77.8 & 35.8 & 59.9 \\
++ Velocitune (100k) & 67.2 & 85.5 & 56.2 & 17.2 & 49.0 & 64.4 & 81.7 & 77.5 & 35.0 & 59.3 \\
++ Doge (100k) & \colorbox[rgb]{0.85,0.85,1.0}{67.8} & \colorbox[rgb]{0.85,0.85,1.0}{86.0} & \colorbox[rgb]{0.9,0.9,1.0}{57.5} & \colorbox[rgb]{0.85,0.85,1.0}{17.8} & \colorbox[rgb]{0.85,0.85,1.0}{51.2} & 64.6 & \colorbox[rgb]{0.85,0.85,1.0}{82.0} & \colorbox[rgb]{0.9,0.9,1.0}{78.5} & \colorbox[rgb]{0.8,0.8,1.0}{37.2} & \colorbox[rgb]{0.9,0.9,1.0}{60.2} \\
++ DGA (100k) & 67.0 & 85.4 & 58.8 & \colorbox[rgb]{0.8,0.8,1.0}{18.2} & \colorbox[rgb]{1.0,0.7,0.7}{42.1} & \colorbox[rgb]{0.9,0.9,1.0}{64.8} & 81.8 & \colorbox[rgb]{1.0,0.7,0.7}{75.2} & \colorbox[rgb]{1.0,0.8,0.8}{33.4} & \colorbox[rgb]{1.0,0.9,0.9}{58.5} \\
++ DIDS (100k) & \colorbox[rgb]{0.8,0.8,1.0}{68.3} & \colorbox[rgb]{0.85,0.85,1.0}{86.9} & \colorbox[rgb]{0.85,0.85,1.0}{59.0} & \colorbox[rgb]{0.8,0.8,1.0}{20.5} & \colorbox[rgb]{0.9,0.9,1.0}{55.6} & \colorbox[rgb]{0.8,0.8,1.0}{64.9} & \colorbox[rgb]{0.8,0.8,1.0}{82.2} & \colorbox[rgb]{0.8,0.8,1.0}{80.4} & \colorbox[rgb]{0.8,0.8,1.0}{43.0} & \colorbox[rgb]{0.8,0.8,1.0}{62.3} \\
+\midrule
+\multicolumn{11}{c}{\textit{Single-task Optimization}} \\
+\midrule
++ Doremi (100k) & \colorbox[rgb]{0.9,0.9,1.0}{68.8} & 86.2 & 60.8 & 18.2 & 51.2 & 64.8 & \colorbox[rgb]{0.8,0.8,1.0}{82.6} & 78.5 & 37.2 & 60.9 \\
++ Velocitune (100k) & 68.0 & 86.0 & 60.5 & 18.0 & 50.8 & 64.5 & 82.0 & 78.2 & 36.8 & 60.5 \\
++ Doge (100k) & 68.2 & \colorbox[rgb]{0.9,0.9,1.0}{86.8} & 60.9 & 18.4 & 51.5 & 64.9 & 82.2 & \colorbox[rgb]{0.8,0.8,1.0}{79.0} & 37.5 & 61.0 \\
++ DGA (100k) & \colorbox[rgb]{0.85,0.85,1.0}{68.6} & \colorbox[rgb]{0.85,0.85,1.0}{86.5} & \colorbox[rgb]{0.85,0.85,1.0}{61.8} & \colorbox[rgb]{0.85,0.85,1.0}{19.2} & \colorbox[rgb]{0.85,0.85,1.0}{53.2} & \colorbox[rgb]{0.85,0.85,1.0}{65.2} & \colorbox[rgb]{0.85,0.85,1.0}{82.4} & \colorbox[rgb]{0.9,0.9,1.0}{78.8} & \colorbox[rgb]{0.85,0.85,1.0}{38.5} & \colorbox[rgb]{0.85,0.85,1.0}{61.6} \\
++ DIDS (100k) & \colorbox[rgb]{0.8,0.8,1.0}{69.2} & \colorbox[rgb]{0.8,0.8,1.0}{87.5} & \colorbox[rgb]{0.8,0.8,1.0}{63.0} & \colorbox[rgb]{0.8,0.8,1.0}{21.5} & \colorbox[rgb]{0.8,0.8,1.0}{57.5} & \colorbox[rgb]{0.8,0.8,1.0}{65.8} & \colorbox[rgb]{0.8,0.8,1.0}{83.0} & \colorbox[rgb]{0.8,0.8,1.0}{81.2} & \colorbox[rgb]{0.8,0.8,1.0}{44.8} & \colorbox[rgb]{0.8,0.8,1.0}{63.7} \\
+\bottomrule
+\end{tabular}
+\caption{The overall performance comparison. Cells with \colorbox[rgb]{0.9,0.9,1.0}{blue background indicate high scores}, while \colorbox[rgb]{1.0,0.9,0.9}{red background indicates low scores}. The top section shows results when optimizing for multiple downstream tasks simultaneously, while the bottom section shows results when optimizing for individual tasks.}
+\label{tab:main_results}
+\end{center}
+\vspace{-0.6cm}
+\end{table*}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/tab-q2.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/tab-q2.tex"
new file mode 100644
index 0000000000..1876c21d21
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/tab-q2.tex"
@@ -0,0 +1,18 @@
+\begin{table}[t]
+\footnotesize
+\begin{center}
+\begin{tabular}{l|cccc|c}
+\toprule
+Variant & BBH & MathQA & IFEval & TruthfulQA & Avg \\
+\midrule
+DIDS (100k) & \textbf{68.3} & \textbf{20.5} & \textbf{55.6} & \textbf{43.0} & \textbf{46.9} \\
+DIDS-GC & 67.7 & 19.7 & 53.0 & 40.1 & 45.1 \\
+DIDS-FIM & 67.2 & 18.6 & 51.9 & 38.5 & 44.0 \\
+DIDS-LT & 67.5 & 19.5 & 51.4 & 38.1 & 44.1 \\
+\bottomrule
+\end{tabular}
+\caption{Ablation results. We progressively remove key components: gradient-based clustering (DIDS-GC), FIM-guided impact measurement (DIDS-FIM), and loss trajectory consideration (DIDS-LT).}
+\label{tab:ablation}
+\end{center}
+\vspace{-0.6cm}
+\end{table}
\ No newline at end of file
diff --git "a/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/tab-q3.tex" "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/tab-q3.tex"
new file mode 100644
index 0000000000..4783306430
--- /dev/null
+++ "b/writing/\345\217\202\350\200\203\350\256\272\346\226\207/emnlp_Data_Mixture_camera_ready/latex/figure/tab-q3.tex"
@@ -0,0 +1,21 @@
+\setlength{\tabcolsep}{3pt}
+\begin{table}[t]
+\scriptsize
+\centering
+\begin{tabular}{l|cc|cc}
+\toprule
+\multirow{2}{*}{Component} & \multicolumn{2}{c|}{TFLOPs} & \multicolumn{2}{c}{GPU Hours} \\
+& DGA & DIDS & DGA & DIDS \\
+\midrule
+Base Training & $5.47 \times 10^{4}$ & $5.47 \times 10^{4}$ & 101.6 & 101.6 \\
+Cluster (BERT vs. Gradient) & $7.77 \times 10^{2}$ & $1.87 \times 10^{3}$ & 1.5 & 3.3 \\
+Impact (Gradient vs. FIM) & $9.86 \times 10^{1}$ & $1.78 \times 10^{2}$ & 0.1 & 0.2 \\
+Loss Trajectory Consideration & - & $< 10^{-1}$ & - & < 0.1 \\
+\midrule
+Total & $5.56 \times 10^{4}$ & $5.67 \times 10^{4}$ & 103.2 & 105.2 \\
+\bottomrule
+\end{tabular}
+\caption{Computational cost analysis of different components in DIDS. Base training refers to standard training of an 8B parameter model on 1B tokens.}
+\label{tab:efficiency}
+\vspace{-0.7cm}
+\end{table}
\ No newline at end of file
diff --git "a/writing/\345\256\236\351\252\214/\345\256\236\351\252\214\350\241\250\346\240\274.md" "b/writing/\345\256\236\351\252\214/\345\256\236\351\252\214\350\241\250\346\240\274.md"
new file mode 100644
index 0000000000..e69de29bb2
diff --git "a/\351\203\250\347\275\262\346\225\231\347\250\213.md" "b/\351\203\250\347\275\262\346\225\231\347\250\213.md"
new file mode 100644
index 0000000000..68e604298d
--- /dev/null
+++ "b/\351\203\250\347\275\262\346\225\231\347\250\213.md"
@@ -0,0 +1,321 @@
+# Trinity Project Deployment Guide
+
+## 1. Environment Setup
+
+### 1.1 Install Project Dependencies
+
+```bash
+# Install the project (editable/development mode)
+pip install -e .
+
+# If you do not need to update dependencies, use:
+pip install -e . --no-deps
+```
+
+### 1.2 Install Pinned Dependencies
+
+The following dependencies must be installed at specific versions:
+
+```bash
+# vLLM (recommended version)
+pip install vllm==10.0.2
+
+# Click
+pip install click==0.8.1
+
+# Flash Attention
+pip install flash-attn==2.8.1 --no-build-isolation
+```
+
+### 1.3 Configure Environment Variables
+
+Set the following environment variables:
+
+```bash
+# Hugging Face token (required)
+export HF_TOKEN=your_huggingface_token
+
+# Weights & Biases API key (only needed when using wandb monitoring)
+export WANDB_API_KEY=your_wandb_api_key
+```
+
+---
+
+## 2. Project Structure
+
+### 2.1 Directory Layout
+
+- **Config files**: `examples/R3L/`
+- **Workflows**: `trinity/common/workflows/envs/R3L/`
+
+### 2.2 Supported Environments (5)
+
+1. **ALFWORLD** - interactive text-game environment
+2. **Webshop** - e-commerce shopping environment (requires roughly 1 TB of RAM)
+3. **ScienceWorld** - science-experiment environment
+4. **DAPO** - math problem environment
+5. **CountDown** - countdown game environment
+
+### 2.3 Supported Algorithms (5)
+
+**Baselines**:
+- GRPO
+- OPMD
+- RAFT
+
+**Our methods**:
+- OPMD + Reweight Adv
+- R3L
+
+### 2.4 Experiment Scale
+
+- Number of environments: 5
+- Number of algorithms: 5
+- Model sizes: 2 (Qwen-2.5-1.5B and Qwen-2.5-7B)
+- **Total experiments**: 5 × 5 × 2 = **50**
+
+> **Note**: The Webshop environment needs roughly 1 TB of RAM; if memory is insufficient, skip it for now.
+
+---
+
+## 3. Parameter Configuration
+
+### 3.1 Common Configuration Parameters
+
+Using `examples/R3L/alfworld/grpo_1.5B.yaml` as an example (an illustrative sketch follows the table):
+
+| Parameter | Description | Example value |
+|------|------|--------|
+| `gpu_per_node` | Total number of GPUs on the node | 8 |
+| `engine_num` | Number of GPUs used for rollout | 6 (the remaining 2 GPUs are used for training) |
+| `gpu_memory_utilization` | Fraction of GPU memory used per card | 0.7 (70%, to prevent out-of-memory errors) |
+| `save_interval` | Save the model every N steps | 20 |
+| `monitor_type` | Monitoring tool | wandb / tensorboard / none |
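+
+The snippet below is a minimal, illustrative sketch of these settings, using only the parameter names from the table above; the real config file nests them under higher-level sections, so treat this as orientation and edit `examples/R3L/alfworld/grpo_1.5B.yaml` directly.
+
+```yaml
+# Illustrative sketch only -- parameter names follow the table above;
+# the actual config groups them under higher-level sections.
+gpu_per_node: 8               # total GPUs on the node
+engine_num: 6                 # GPUs used for rollout (the other 2 train)
+gpu_memory_utilization: 0.7   # per-GPU memory fraction; keep below 1.0 to avoid OOM
+save_interval: 20             # save a checkpoint every 20 steps
+monitor_type: wandb           # wandb / tensorboard / none
+```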
+
+### 3.2 Suggested GPU Allocation
+
+- **1.5B model**: 2 GPUs for training, the rest for rollout
+- **7B model**: 4 GPUs for training, the rest for rollout
+
+### 3.3 Checkpoint Behavior
+
+- Re-running the same launch command automatically resumes training from the last saved checkpoint
+- To train from scratch, delete the corresponding checkpoint folder and .db file (see Section 4.3)
+
+---
+
+## 4. Running Experiments
+
+### 4.1 Start Ray
+
+Start Ray from the project root directory (this only needs to be done once):
+
+```bash
+# Select the GPUs to use and start Ray
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ray start --head
+```
+
+> **Important**: Ray must be started from the project root so that relative paths resolve correctly.
+
+### 4.2 Run Training
+
+Basic command format:
+
+```bash
+trinity run --config <path/to/config.yaml>
+```
+
+#### Example: ALFWORLD environment (1.5B model)
+
+```bash
+# GRPO algorithm
+trinity run --config examples/R3L/alfworld/grpo_1.5B.yaml
+
+# OPMD algorithm
+trinity run --config examples/R3L/alfworld/opmd_1.5B.yaml
+
+# OPMD + R3L algorithm
+trinity run --config examples/R3L/alfworld/opmd_R3L_1.5B.yaml
+
+# OPMD + Reweight Adv algorithm
+trinity run --config examples/R3L/alfworld/opmd_reweight_adv_1.5B.yaml
+
+# RAFT algorithm
+trinity run --config examples/R3L/alfworld/RAFT_1.5B.yaml
+```
+
+#### Example: ALFWORLD environment (7B model)
+
+```bash
+trinity run --config examples/R3L/alfworld/grpo_7B.yaml
+trinity run --config examples/R3L/alfworld/opmd_7B.yaml
+trinity run --config examples/R3L/alfworld/opmd_R3L_7B.yaml
+trinity run --config examples/R3L/alfworld/opmd_reweight_adv_7B.yaml
+trinity run --config examples/R3L/alfworld/RAFT_7B.yaml
+```
+
+Config files for the other environments are located under `examples/R3L/<env_name>/` and are used the same way.
+
+### 4.3 Training from Scratch
+
+To restart an experiment from scratch, delete the corresponding checkpoint directory and database file:
+
+```bash
+# Delete the database file
+rm alfworld_grpo_baseline_1.5B.db
+
+# Delete the checkpoint directory
+rm -r checkpoints/ALFWORLD/ALFWORLD_RFT_Qwen_1.5B_GRPO_Baseline/
+```
+
+### 4.4 Stop Ray
+
+```bash
+# Normal shutdown
+ray stop
+
+# Force shutdown
+ray stop --force
+```
+
+---
+
+## 5. Data Preparation for Specific Environments
+
+### 5.1 ALFWORLD
+
+#### Step 1: Install alfworld
+
+```bash
+pip install alfworld
+```
+
+#### Step 2: Download the data
+
+```bash
+# Option 1: download automatically to ~/.cache/alfworld/
+alfworld-download
+
+# Option 2: download to a specified directory
+alfworld-download --data-dir ./alf-data
+```
+
+#### Step 3: Update the data path
+
+Edit `examples/R3L/alfworld/get_alfworld_data.py`:
+
+```python
+# Line 11: change this to your actual data path
+alfworld_data_root = "/your/local/path/alfworld/json_2.1.1"
+```
+
+> **Note**: keep the trailing `json_2.1.1` in the path.
+
+#### Step 4: Process the data
+
+```bash
+cd examples/R3L/alfworld
+python get_alfworld_data.py
+```
+
+The processed data is saved under `examples/R3L/alfworld/`.
+
+#### Step 5: Start training
+
+```bash
+trinity run --config examples/R3L/alfworld/grpo_7B.yaml
+```
+
+### 5.2 Webshop
+
+(Data preparation steps to be added.)
+
+### 5.3 ScienceWorld
+
+(Data preparation steps to be added.)
+
+---
+
+## 6. Experiment Monitoring and Result Logging
+
+### 6.1 Tips for Running Experiments
+
+- Experiments do not need to run until the progress bar finishes
+- Stop once the run has converged or the peak score has been recorded
+- If a run crashes, record the best score reached before the crash
+
+### 6.2 Metrics to Record
+
+#### All experiments
+- **Eval results** (required)
+
+#### Additional metrics for specific settings
+
+| Setting | Additional metrics |
+|------|----------|
+| DAPO | Results on each of the test sets |
+| Webshop | Success rate and average number of steps |
+| R3L algorithm | Proportion of improved samples during rollout |
+
+---
+
+## 7. Planned Features
+
+### 7.1 Algorithms to add
+- DAPO
+
+### 7.2 Environments to add
+- Sokoban
+- Gym-Card (including BlackJack, EZPoints, and NumberLine)
+
+---
+
+## 8. Quick Reference
+
+### 8.1 Full Workflow
+
+```bash
+# 1. Set environment variables
+export HF_TOKEN=your_token
+export WANDB_API_KEY=your_key
+
+# 2. Start Ray
+cd /path/to/trinity
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ray start --head
+
+# 3. Run an experiment
+trinity run --config examples/R3L/alfworld/grpo_1.5B.yaml
+
+# 4. Stop Ray once the experiments are finished
+# ray stop
+```
+
+### 8.2 FAQ
+
+**Q: How do I change which GPUs are used?**
+A: Adjust the `CUDA_VISIBLE_DEVICES` environment variable together with the `gpu_per_node` and `engine_num` settings in the config file; see the sketch below.
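+
+As a hypothetical example (not taken from any shipped config), running the 1.5B setup on 4 GPUs could combine `CUDA_VISIBLE_DEVICES=0,1,2,3` when starting Ray with roughly the following settings, following the 2-GPUs-for-training guideline from Section 3.2:
+
+```yaml
+# Hypothetical 4-GPU sketch -- adjust to your hardware and the actual config layout.
+gpu_per_node: 4   # matches the 4 GPUs exposed via CUDA_VISIBLE_DEVICES
+engine_num: 2     # rollout GPUs; the remaining 2 GPUs handle training
+```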
+
+**Q: What if I run out of GPU memory?**
+A: Lower the `gpu_memory_utilization` setting (e.g., from 0.7 to 0.5).
+
+**Q: How do I switch the monitoring tool?**
+A: Change the `monitor_type` setting in the config file (wandb / tensorboard / none).
+
+**Q: Where are the checkpoints?**
+A: Under `checkpoints/<env_name>/<experiment_name>/`.
+
+---
+
+## 9. Current List of Required Experiments
+
+### ALFWORLD environment (7B model)
+
+```bash
+trinity run --config examples/R3L/alfworld/grpo_7B.yaml
+trinity run --config examples/R3L/alfworld/opmd_7B.yaml
+trinity run --config examples/R3L/alfworld/opmd_R3L_7B.yaml
+trinity run --config examples/R3L/alfworld/opmd_reweight_adv_7B.yaml
+trinity run --config examples/R3L/alfworld/RAFT_7B.yaml
+```
+
+(Experiment lists for the other environments and model sizes follow the `examples/R3L/` directory structure.)