
Commit f3f0846

General Multi-turn FrozenLake (#429)
1 parent: ffc6cac

26 files changed: +835 / -125 lines changed

Two binary image assets added (128 KB and 134 KB).

docs/sphinx_doc/source/tutorial/align_with_verl.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,4 +1,4 @@
-# How to align configuration with veRL
+# Align configuration with veRL
 
 This guide provides guidance for users familiar with [veRL](https://github.com/volcengine/verl) to align the parameters and metrics in Trinity-RFT with the ones in veRL.
```

docs/sphinx_doc/source_zh/tutorial/align_with_verl.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,4 +1,4 @@
-# 如何和 veRL 对齐配置
+# veRL 对齐训练配置
 
 本指南为熟悉 [veRL](https://github.com/volcengine/verl) 的用户提供了将 Trinity-RFT 与 veRL 的参数和指标对齐的方法。
```

(Chinese counterpart of the change above: the title changes from "How to align configuration with veRL" to "Aligning training configuration with veRL"; the context sentence introduces the same guide for aligning Trinity-RFT parameters and metrics with veRL.)

examples/agentscope_frozenlake/README.md

Lines changed: 22 additions & 0 deletions

# Frozen Lake Agent

This example implements a Frozen Lake agent with the AgentScope framework. The agent navigates a frozen lake environment by interpreting text observations and selecting the next action.

Data preparation and environment setup are the same as in the [GRPO Frozen Lake example](../grpo_frozen_lake/README.md); follow the instructions there to set up the environment and prepare the dataset.

## Results

The configuration file for this example is located at [`frozenlake_agent.yaml`](./frozenlake_agent.yaml). We use Qwen2.5-3B-Instruct as the base LLM for the agent.

The training and evaluation datasets are generated with the same process as in the [GRPO Frozen Lake example](../grpo_frozen_lake/README.md), using the following command:

```bash
cd examples/grpo_frozen_lake
python get_frozen_lake_data.py --test_size 50 --map_max_size 10
```

The training results below show the reward during the training and evaluation phases:

![](../../docs/sphinx_doc/assets/agentscope_frozenlake_reward_train.png)
![](../../docs/sphinx_doc/assets/agentscope_frozenlake_reward_bench.png)
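
To launch training with this example's configuration after preparing the data, point Trinity-RFT at the config file, e.g. `trinity run --config examples/agentscope_frozenlake/frozenlake_agent.yaml` (assuming the standard `trinity run` CLI; see the main Trinity-RFT README for the exact launch command).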

examples/agentscope_frozenlake/__init__.py

Whitespace-only changes.
Lines changed: 74 additions & 0 deletions

````python
import re

from agentscope.agent import ReActAgent
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel

from examples.agentscope_frozenlake.utils import SYSTEM_PROMPT, FrozenLakeAction

INVALID_ACTION = "still"
VALID_ACTIONS = {
    "left": 1,
    "down": 2,
    "right": 3,
    "up": 4,
}


class FrozenLakeAgent:
    def __init__(self, model: OpenAIChatModel, max_steps: int = 20):
        self.model = model
        self.agent = ReActAgent(
            name="frozenlake_agent",
            sys_prompt=SYSTEM_PROMPT,
            model=model,
            formatter=OpenAIChatFormatter(),
            max_iters=2,
        )
        self.response_structure = FrozenLakeAction
        self.current_step = 0
        self.last_action = None
        self.last_observation = None
        self.max_steps = max_steps

    def get_prompt(self, observation: str) -> str:
        # Build the user prompt for the current turn, adding a hint when the last move was ineffective.
        prompt = (
            f"Current Observation ({self.current_step}): \n"
            + observation
            + "\n"
            + "You have not achieved the goal, P has not reached G yet. Please give the next action."
        )
        if self.current_step > 0 and self.last_action is not None:
            if self.last_observation == observation:
                prompt += "\nYour last response is invalid. Your position didn't change at all. You may need to recheck your thinking process, action outputted, and the format of response. Remember, you should only output the NEXT ACTION at each iteration in the ``` ```. For example, if you want to move up, you should output ```Up```."

        if self.max_steps is not None and self.max_steps - self.current_step > 0:
            prompt += (
                f"\nThe maximum number of steps remaining is {self.max_steps - self.current_step}."
            )

        return prompt

    def get_action(self, msg: Msg) -> str:
        # Extract the action from the last ```...``` block of the model reply; fall back to "still".
        response: str = msg.content if isinstance(msg.content, str) else msg.content[0].get("text")
        action = INVALID_ACTION

        matches = re.findall(r"```(.*?)```", response, re.DOTALL)

        if matches:
            last_match_content = matches[-1].strip()
            action = last_match_content.lower()
            if action not in VALID_ACTIONS:
                action = INVALID_ACTION

        return action

    async def step(self, current_observation: str) -> str:
        prompt = self.get_prompt(current_observation)
        response = await self.agent.reply(Msg("user", prompt, role="user"))
        action = self.get_action(response)
        self.last_observation = current_observation
        self.last_action = action
        self.current_step += 1
        return action
````
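
The agent expects the model to wrap its chosen move in triple backticks; `get_action` takes the last such block, lowercases it, and falls back to `"still"` when the content is not a valid move. Below is a minimal standalone sketch of that parsing rule; the `parse_action` helper is illustrative only and not part of the example code:

````python
import re

VALID_ACTIONS = {"left", "down", "right", "up"}
INVALID_ACTION = "still"


def parse_action(response: str) -> str:
    # Take the last ``` ... ``` block, lowercase it, and validate it.
    matches = re.findall(r"```(.*?)```", response, re.DOTALL)
    action = matches[-1].strip().lower() if matches else INVALID_ACTION
    return action if action in VALID_ACTIONS else INVALID_ACTION


print(parse_action("The hole is to the right, so I will go down.\n```Down```"))  # -> "down"
print(parse_action("Let me try ```Jump```"))  # -> "still" (not a valid move)
````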
Lines changed: 210 additions & 0 deletions

```python
import copy
from typing import Dict, Optional, Tuple

import numpy as np

from examples.agentscope_frozenlake.utils import generate_random_map, get_goal_position
from trinity.utils.log import get_logger

try:
    from gymnasium.envs.toy_text.frozen_lake import FrozenLakeEnv as GymFrozenLakeEnv
except ImportError:
    GymFrozenLakeEnv = object


class FrozenLakeEnv(GymFrozenLakeEnv):
    # Map gym state in integer
    MAP_LOOKUP = {
        b"P": 0,
        b"F": 1,
        b"H": 2,
        b"G": 3,
    }

    # Define rules to transform to rendered text observation of the environment
    GRID_LOOKUP = {
        0: " P \t",  # player
        1: " _ \t",  # frozen
        2: " O \t",  # hole
        3: " G \t",  # goal
        4: " X \t",  # player fall into hole
        5: " √ \t",  # player on goal
    }

    ACTION_LOOKUP = {
        "still": 0,
        "left": 1,
        "down": 2,
        "right": 3,
        "up": 4,
    }

    INVALID_ACTION = 0
    PENALTY_FOR_INVALID = -1

    def __init__(
        self,
        max_steps: int = 8,
        desc: Optional[str] = None,
        is_slippery: bool = False,
        size: int = 8,
        p: float = 0.8,
        seed: int = 42,
    ):
        self.logger = get_logger()
        self.max_steps = max_steps or 8
        self.desc = desc
        self.is_slippery = is_slippery
        self.size = size
        self.p = p
        self.seed = seed
        try:
            import gymnasium as gym
            from gymnasium.envs.toy_text.frozen_lake import (
                FrozenLakeEnv as GymFrozenLakeEnv,
            )
        except ImportError as e:
            error_message = (
                f"Gymnasium is not installed. Please install gymnasium first before "
                f"running the frozen_lake workflow. Error: {str(e)}"
            )
            self.logger.error(error_message)
            raise ImportError(error_message)

        if self.desc is None:
            random_map, goal_position = generate_random_map(
                size=self.size, p=self.p, seed=self.seed, max_steps=self.max_steps
            )
        else:
            random_map = np.asarray(copy.deepcopy(self.desc), dtype="c")
            goal_position = get_goal_position(random_map)

        self.goal_position = goal_position

        GymFrozenLakeEnv.__init__(self, desc=random_map[:], is_slippery=self.is_slippery)
        self.action_space = gym.spaces.Discrete(4, start=1)

        self.map_kwargs = {
            "size": size,
            "p": p,
        }
        self.env_kwargs = {
            "is_slippery": is_slippery,
            "desc": copy.deepcopy(desc),
            "seed": seed,
        }

        self.action_map = {
            1: 0,  # left
            2: 1,  # down
            3: 2,  # right
            4: 3,  # up
        }

    def _get_player_position(self) -> Tuple[int, int]:
        return (self.s // self.ncol, self.s % self.ncol)  # (row, col)

    def step(self, action: str) -> Tuple[str, float, bool, Dict]:
        """Execute a step in the environment.

        Maps custom action to gymnasium FrozenLakeEnv action and takes the step.
        Checks if the action is effective (whether player moves in the env).

        Args:
            action: The action to take.

        Returns:
            Tuple of (observation, reward, done, info).
        """
        if self.success():
            return self.render(), 1, True, {"action_is_effective": False}

        action_id: int = self.ACTION_LOOKUP.get(action.lower(), 0)

        if not action_id:
            action_id = self.INVALID_ACTION

        if action_id == self.INVALID_ACTION or action_id not in self.action_map:
            return self.render(), 0, False, {"action_is_effective": False}

        prev_player_position = int(self.s)

        player_pos, reward, done, _, _ = GymFrozenLakeEnv.step(self, self.action_map[action_id])

        obs = self.render()
        return obs, reward, done, {"action_is_effective": prev_player_position != int(player_pos)}

    def render(self, mode="tiny_rgb_array"):
        """Render the environment.

        Args:
            mode: Rendering mode. Options: "tiny_rgb_array", "list", "state", "rgb_array", "ansi".

        Returns:
            Rendered observation based on the mode.
        """
        assert mode in ["tiny_rgb_array", "list", "state", "rgb_array", "ansi"]
        if mode in ["rgb_array", "ansi"]:
            prev_render_mode = self.render_mode
            self.render_mode = mode
            obs = GymFrozenLakeEnv.render(self)
            self.render_mode = prev_render_mode
            return obs
        room_state = copy.deepcopy(self.desc)

        # replace the position of start 'S' with 'F'
        position_S = np.where(room_state == b"S")
        room_state[position_S] = b"F"

        # replace the position of the player with 'P'
        position_P = self._get_player_position()
        room_state[position_P] = b"P"

        if mode == "state":
            # transform 'S', 'F', 'H', 'G' to numpy integer array
            room_state = np.vectorize(lambda x: self.MAP_LOOKUP[x])(room_state)
            # add player in hole or player on goal
            if self.desc[position_P] == b"H":
                room_state[position_P] = 4
            elif self.desc[position_P] == b"G":
                room_state[position_P] = 5
            return room_state

        room_state = self.render(mode="state").tolist()

        if mode == "list":

            def lookup(cell):
                return self.GRID_LOOKUP.get(cell, "?").strip("\t").strip()

            return [" ".join(lookup(cell) for cell in row) for row in room_state]

        if mode == "tiny_rgb_array":

            def lookup(cell):
                return self.GRID_LOOKUP.get(cell, "?")

            result = "\n".join("".join(lookup(cell) for cell in row) for row in room_state)
            return result

    def reset(self, task: Optional[Dict] = None):
        task = task or {}
        self.__init__(  # type: ignore [misc]
            size=task.get("size", self.map_kwargs["size"]),
            p=task.get("p", self.map_kwargs["p"]),
            seed=task.get("seed", self.env_kwargs["seed"]),
            is_slippery=task.get("is_slippery", self.env_kwargs["is_slippery"]),
        )
        GymFrozenLakeEnv.reset(self, seed=self.seed)
        return self.render(mode="tiny_rgb_array"), {}

    def finished(self) -> bool:
        player_pos = self._get_player_position()
        return self.desc[player_pos] in b"GH"  # type: ignore [index,operator]

    def success(self):
        """
        Check if the agent has reached the goal (G).
        """
        player_pos = self._get_player_position()
        return self.desc[player_pos] in b"G"
```
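
As a quick sanity check of the environment API, the sketch below builds an environment, resets it, and steps through a few moves. It assumes gymnasium and the repository's `examples.agentscope_frozenlake.utils` helpers are importable; the `examples.agentscope_frozenlake.env` import path is an assumption for illustration and is not confirmed by the diff.

```python
from examples.agentscope_frozenlake.env import FrozenLakeEnv  # assumed module path

env = FrozenLakeEnv(size=4, p=0.8, seed=7, max_steps=8)
obs, _ = env.reset()
print(obs)  # text grid: P = player, _ = frozen, O = hole, G = goal

for action in ["down", "down", "right", "right"]:
    obs, reward, done, info = env.step(action)
    print(action, reward, done, info["action_is_effective"])
    if done:  # reached the goal or fell into a hole
        break
```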
examples/agentscope_frozenlake/frozenlake_agent.yaml

Lines changed: 78 additions & 0 deletions

```yaml
project: "FrozenLake"
name: "Qwen2.5-3B-Instruct-agent"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
  algorithm_type: multi_step_grpo
  repeat_times: 16
  kl_loss_fn: "low_var_kl"
  kl_loss_fn_args:
    kl_coef: 0
  advantage_fn_args:
    epsilon: 1e-6
    std_threshold: 0.0001
    enable_step_norm: true
  optimizer:
    lr: 1e-6
model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct}
  max_response_tokens: 2048
  max_model_len: 25600
  temperature: 1.0
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 1
  batch_size: 32
  train_batch_size: 1024
  explorer_input:
    taskset:
      name: frozenlake
      storage_type: file
      path: ${oc.env:TRINITY_TASKSET_PATH}
      split: train
      workflow_args:
        env_max_steps: 8
        agent_max_steps: 10
        is_slippery: false
      default_workflow_type: 'examples.agentscope_frozenlake.workflow.FrozenLakeWorkflow'
  trainer_input:
    experience_buffer:
      name: frozenlake_experience_buffer
      storage_type: queue
      max_read_timeout: 7200
      replay_buffer:
        enable: true
        priority_fn: linear_decay
        priority_fn_args:
          decay: 0.1
explorer:
  eval_on_startup: true
  eval_interval: 20
  runner_per_model: 8
  rollout_model:
    engine_num: 6
    tensor_parallel_size: 1
    enable_chunked_prefill: true
    enforce_eager: false
    enable_openai_api: true
    enable_log_requests: true
    enable_history: true
    enable_auto_tool_choice: true
    tool_call_parser: hermes
    # reasoning_parser: deepseek_r1 # if you use Qwen3 series, uncomment this line
    enable_thinking: true
    dtype: bfloat16
    seed: 42
    gpu_memory_utilization: 0.85
trainer:
  save_interval: 100
  use_dynamic_bsz: true
  grad_clip: 1.0
  ulysses_sequence_parallel_size: 2

synchronizer:
  sync_method: nccl
  sync_style: dynamic_by_explorer
  sync_interval: 1
  sync_timeout: 1200
```
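
The `${oc.env:VAR,default}` entries are OmegaConf-style environment-variable interpolations: `TRINITY_TASKSET_PATH` has no default and must be set, while the model path and checkpoint root fall back to `Qwen/Qwen2.5-3B-Instruct` and `./checkpoints`. Here is a small sketch of how these interpolations resolve when the file is loaded with OmegaConf directly; Trinity-RFT's own config loader may differ, and the file path and key nesting follow the reconstruction above.

```python
import os
from omegaconf import OmegaConf

os.environ.setdefault("TRINITY_TASKSET_PATH", "/data/frozenlake")  # required: no default in the YAML

cfg = OmegaConf.load("examples/agentscope_frozenlake/frozenlake_agent.yaml")
print(cfg.model.model_path)                    # TRINITY_MODEL_PATH, or Qwen/Qwen2.5-3B-Instruct
print(cfg.buffer.explorer_input.taskset.path)  # resolved from TRINITY_TASKSET_PATH
print(cfg.checkpoint_root_dir)                 # TRINITY_CHECKPOINT_ROOT_DIR, or ./checkpoints
```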
