
Commit 9704198 ("Fixes")

1 parent 4e7cc85 commit 9704198

File tree

5 files changed: +25 −40 lines


docs/source/openenv.md

Lines changed: 7 additions & 28 deletions
````diff
@@ -316,41 +316,16 @@ That's it! Let's unpack how the main pieces fit together:
 
 You can run the example in either colocate mode (1 GPU) or server mode (2 GPUs):
 
-<hfoptions id="vllm_mode">
-
-<hfoption id="colocate">
-
-**Colocate mode (1 GPU, recommended)**
-
 ```bash
-python examples/scripts/openenv/echo.py --vllm-mode colocate
+python examples/scripts/openenv/echo.py
 ```
 
-This runs vLLM in the same process as training, requiring only a single GPU.
-
-</hfoption>
-
-<hfoption id="server">
-
-**Server mode (2+ GPUs, scalable)**
+You can customize the model and environment URL:
 
 ```bash
-# Terminal 1: Start vLLM inference server
-CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-0.6B --host 0.0.0.0 --port 8000
-
-# Terminal 2: Run GRPO training with OpenEnv
-CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py --vllm-mode server --vllm-server-url http://localhost:8000
+python examples/scripts/openenv/echo.py --model Qwen/Qwen3-0.6B --env-host https://openenv-echo-env.hf.space
 ```
 
-This runs vLLM as a separate server process, useful when you want to:
-- Share the inference server across multiple training jobs
-- Use multiple GPUs for the vLLM server (via `--tensor-parallel-size`)
-- Scale up training to many GPUs while sharing a single inference endpoint
-
-</hfoption>
-
-</hfoptions>
-
 Below is the reward curve from training:
 
 <iframe src="https://trl-lib-trackio.hf.space?project=openenv&metrics=train/rewards/reward_from_env/mean&runs=qgallouedec-1761202871&sidebar=hidden&navbar=hidden" style="width:600px; height:500px; border:0;"></iframe>
````
```diff
@@ -512,6 +487,9 @@ The resulting model improves its performance on the game, both by reducing the n
 
 <iframe src="https://burtenshaw-wordle-grpo.hf.space?project=group-Qwen-Qwen3-17B&metrics=reward&runs=run-2025-10-26_09-39-49,run-2025-10-26_08-04-49&sidebar=hidden&navbar=hidden" style="width:1600px; height:500px; border:0;"></iframe>
 
+> [!NOTE]
+> With `enable_thinking=False` (the default in these examples), small models like Qwen3-1.7B can learn to improve their guesses but should not be expected to consistently solve the game. For significantly better results, use larger models or enable thinking mode (`enable_thinking=True`), which allows the model to reason before making a guess at the cost of longer completions.
+
 We experimented with larger models like `gpt-oss-20b` and found that the model was able to consistently win the game. However, this requires a lot of compute to train the model. Why not try this out yourself?
 
 ## Multi-Environment Training
```
```diff
@@ -587,6 +565,7 @@ Key patterns:
 - **Lazy client initialization**: Create clients in `reset()`, not `__init__()`, to avoid unnecessary WebSocket connections.
 - **Close before reopen**: Close the previous client before creating a new one to avoid server capacity errors.
 - **`kwargs` routing**: The `"env"` column from the dataset is passed to `reset()` as a keyword argument.
+- **All tools are exposed simultaneously**: The model sees `guess`, `move`, and `stay` as available tools regardless of the active environment. If it calls the wrong tool (e.g., `move` during Wordle), the method raises a `ValueError` that the trainer catches gracefully. In practice, models learn to use the correct tools based on the system prompt.
 
 ### Per-environment reward functions
```
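The patterns listed in this hunk can be sketched as a minimal, hypothetical multi-environment wrapper. Everything here is illustrative, not TRL's actual API: `_FakeClient` stands in for a real environment client, and the class and method names are invented for the sketch.

```python
class _FakeClient:
    """Stand-in for a real environment client (illustrative only)."""

    def __init__(self, env: str):
        self.env = env
        self.closed = False

    def close(self):
        self.closed = True

    def step(self, action: str) -> str:
        return f"{self.env} observed {action!r}"


class MultiEnv:
    def __init__(self):
        # Lazy initialization: no client is created here, so constructing
        # the wrapper never opens a connection.
        self._client = None
        self._active = None

    def reset(self, env: str = "wordle", **kwargs):
        # The "env" dataset column is routed here as a keyword argument.
        # Close the previous client before opening a new one so the server
        # does not run out of capacity.
        if self._client is not None:
            self._client.close()
        self._client = _FakeClient(env)
        self._active = env

    def guess(self, word: str) -> str:
        # All tools stay exposed at all times; calling one that does not
        # match the active environment raises a ValueError that the trainer
        # can catch gracefully.
        if self._active != "wordle":
            raise ValueError(f"guess() is not available in {self._active}")
        return self._client.step(word)
```

Calling `reset(env="catch")` after a Wordle episode closes the old client first, and a subsequent `guess(...)` raises rather than hitting the wrong environment.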

examples/scripts/openenv/multi_env.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -237,13 +237,14 @@ def main() -> None:
     CATCH_URL = args.catch_url
 
     n = 500  # samples per environment
-    dataset = Dataset.from_dict({
-        "prompt": (
-            [[{"role": "user", "content": wordle_prompt}]] * n
-            + [[{"role": "user", "content": catch_prompt}]] * n
-        ),
-        "env": ["wordle"] * n + ["catch"] * n,
-    })
+    dataset = Dataset.from_dict(
+        {
+            "prompt": (
+                [[{"role": "user", "content": wordle_prompt}]] * n + [[{"role": "user", "content": catch_prompt}]] * n
+            ),
+            "env": ["wordle"] * n + ["catch"] * n,
+        }
+    )
 
     trainer = GRPOTrainer(
         model="Qwen/Qwen3-1.7B",
```
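The change above is purely cosmetic (black-style) reformatting; the structure being built — `n` Wordle prompts followed by `n` Catch prompts with a parallel `env` column — can be sketched with plain Python lists, using placeholder prompts and a small `n` instead of the script's 500:

```python
n = 2  # the script uses n = 500 per environment
wordle_prompt = "Play Wordle."  # placeholder prompts for illustration
catch_prompt = "Play Catch."

data = {
    # One single-turn chat per sample: the first n are Wordle, the next n Catch.
    "prompt": [[{"role": "user", "content": wordle_prompt}]] * n
    + [[{"role": "user", "content": catch_prompt}]] * n,
    # Parallel column telling the environment wrapper which game reset() should start.
    "env": ["wordle"] * n + ["catch"] * n,
}
```

The two columns stay index-aligned, which is what lets the trainer route row `i`'s `env` value into `reset()` for row `i`'s prompt.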

examples/scripts/openenv/sudoku.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -170,7 +170,9 @@ def parse_args() -> argparse.Namespace:
     )
 
     # LoRA / PEFT
-    parser.add_argument("--use-lora", action="store_true", default=False, help="Use LoRA for memory-efficient training")
+    parser.add_argument(
+        "--use-lora", action="store_true", default=False, help="Use LoRA for memory-efficient training"
+    )
     parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
     parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
 
@@ -499,7 +501,7 @@ def place(self, row: int, col: int, number: int) -> str:
 
         # Only check the NEW content for feedback (messages are cumulative)
         full_content = observation.messages[0].content if observation.messages else ""
-        new_content = full_content[len(self._last_full_content):]
+        new_content = full_content[len(self._last_full_content) :]
         self._last_full_content = full_content
 
         new_content_lower = new_content.lower()
```
trl/generation/vllm_generation.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -523,6 +523,7 @@ def generate(
         images: list[list | None] | None,
         num_generations: int,
         profiler: ProfilingContext | None = None,
+        tools: list | None = None,
     ) -> tuple:
         """Generate completions using vLLM.
 
```

trl/trainer/grpo_trainer.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -15,9 +15,9 @@
 import asyncio
 import atexit
 import copy
-import math
 import importlib.resources as pkg_resources
 import inspect
+import math
 import os
 import sys
 import textwrap
@@ -1249,6 +1249,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
             images=images,
             num_generations=num_generations,
             profiler=profiling_context(self, "vLLM.generate"),
+            tools=self.tools,
         )
         # vLLM returns per-token top-k logprobs; keep only the top-1 (sampled token) logprob
         logprobs = [[lp[0] for lp in seq] for seq in logprobs]
@@ -1504,8 +1505,9 @@ async def _run_async_tools(async_coros):
         for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True):
             if not tool_call:
                 continue
-            # If the environment has a _done attribute and it's True, stop calling tools for it
-            if hasattr(self.environments[idx], "_done") and self.environments[idx]._done:
+            # If the environment signals it's done, stop calling tools for it
+            env = self.environments[idx]
+            if getattr(env, "_done", False) or getattr(env, "done", False):
                 continue
             filtered_idxs.append(idx)
             filtered_tool_calls.append(tool_call)
```
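The last hunk swaps a `hasattr`-then-access check for `getattr` with a default, which also recognizes environments that expose `done` instead of `_done` and tolerates environments with neither. A standalone sketch with dummy environment classes (names invented for illustration):

```python
class LegacyEnv:
    _done = True  # the convention the old hasattr-based code checked


class NewEnv:
    done = True  # alternate attribute the new check also recognizes


class RunningEnv:
    pass  # no flag at all: getattr's default of False keeps it running


def is_done(env) -> bool:
    # getattr with a default covers both spellings and missing attributes
    # in one expression, instead of pairing hasattr with a second lookup.
    return getattr(env, "_done", False) or getattr(env, "done", False)
```

The old code would have skipped `NewEnv` entirely (no `_done` attribute), so its tool calls would keep firing after the episode ended; the `getattr` form treats both flags uniformly.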
