Commit 4db1704 (parent: eb5bb0f)

feat: Add Penguin env (#1327)

Signed-off-by: Brian Yu <bxyu@nvidia.com>
Signed-off-by: Terry Kong <terryk@nvidia.com>
Co-authored-by: Terry Kong <terryk@nvidia.com>

File tree

7 files changed (+805, -2 lines)

3rdparty/Penguin-workspace/setup.py

Lines changed: 65 additions & 2 deletions
```diff
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
 import tomllib
 from pathlib import Path
 
@@ -21,19 +22,81 @@
 
 # If the submodule is present, expose `penguin` package from the checkout
 src_dir = Path("Penguin")
-package_name = "penguin"
+
+
+CACHED_DEPENDENCIES = [
+    "openai<=1.97.1",
+    "tqdm",
+    "pydantic",
+    "pydantic_core",
+    "devtools",
+    "fastapi",
+    "uvicorn",
+    "uvloop",
+    "hydra-core",
+    "omegaconf",
+    "gradio",
+    "mlflow",
+    "tdigest>=0.5.2.2",
+    "aiohttp",
+    "yappi",
+]
 
 if src_dir.exists():
     pyproject_toml_path = src_dir / "pyproject.toml"
+    if not pyproject_toml_path.exists():
+        raise FileNotFoundError(
+            f"[Penguin][setup] {pyproject_toml_path} not found; cannot run the dependency consistency check."
+        )
     with pyproject_toml_path.open("rb") as f:
         pyproject_toml = tomllib.load(f)
 
     packages = pyproject_toml["tool"]["setuptools"]["packages"]["find"]["include"]
 
     for package in packages:
         final_packages.append(package)
         final_package_dir[package] = src_dir / package
 
+    actual_dependencies = pyproject_toml["project"]["dependencies"]
+
+    ########################################
+    # Compare cached dependencies with the submodule's pyproject
+    ########################################
+
+    missing_in_cached = set(actual_dependencies) - set(CACHED_DEPENDENCIES)
+    extra_in_cached = set(CACHED_DEPENDENCIES) - set(actual_dependencies)
+
+    if missing_in_cached or extra_in_cached:
+        print(
+            "[Penguin][setup] Dependency mismatch between Penguin-workspace/Penguin/pyproject.toml and Penguin-workspace/setup.py::CACHED_DEPENDENCIES.",
+            file=sys.stderr,
+        )
+        if missing_in_cached:
+            print(
+                " - Present in Penguin-workspace/Penguin/pyproject.toml but missing from CACHED_DEPENDENCIES:",
+                file=sys.stderr,
+            )
+            for dep in sorted(missing_in_cached):
+                print(f"   * {dep}", file=sys.stderr)
+        if extra_in_cached:
+            print(
+                " - Present in CACHED_DEPENDENCIES but not in Penguin-workspace/Penguin/pyproject.toml:",
+                file=sys.stderr,
+            )
+            for dep in sorted(extra_in_cached):
+                print(f"   * {dep}", file=sys.stderr)
+        print(
+            " Please update CACHED_DEPENDENCIES or the submodule pyproject to keep them in sync.",
+            file=sys.stderr,
+        )
+        sys.exit(1)
+    else:
+        print(
+            "[Penguin][setup] Dependency sets are consistent with the submodule pyproject.",
+            file=sys.stderr,
+        )
+
 setuptools.setup(
     name="penguin",
     version="0.0.0",
@@ -43,5 +106,5 @@
     packages=final_packages,
     package_dir=final_package_dir,
     py_modules=["is_penguin_installed"],
-    install_requires=[],
+    install_requires=CACHED_DEPENDENCIES,
 )
```
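One detail worth spelling out: `tomllib` (stdlib since Python 3.11) requires a binary file handle, and `[project].dependencies` comes back as a plain list of PEP 508 specifier strings, so the set comparison above is exact string matching; `"tqdm"` versus `"tqdm>=4"` would be flagged as a mismatch. A minimal standalone sketch of that read (hypothetical path, run from the workspace root):

```python
# Standalone sketch of the pyproject read the consistency check relies on.
# tomllib is stdlib in Python 3.11+; the path below is hypothetical.
import tomllib
from pathlib import Path

with Path("Penguin/pyproject.toml").open("rb") as f:  # binary handle required
    pyproject = tomllib.load(f)

dependencies = pyproject["project"]["dependencies"]  # list of PEP 508 strings
print(sorted(set(dependencies)))  # compared verbatim against CACHED_DEPENDENCIES
```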

nemo_rl/distributed/ray_actor_environment_registry.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -42,6 +42,7 @@
     # ReplayBuffer needs vLLM environment to handle trajectory data from VllmGenerationWorker
     "nemo_rl.algorithms.async_utils.ReplayBuffer": PY_EXECUTABLES.VLLM,
     "nemo_rl.environments.tools.retriever.RAGEnvironment": PY_EXECUTABLES.SYSTEM,
+    "nemo_rl.environments.penguin.Penguin": PY_EXECUTABLES.PENGUIN,
 }
```
nemo_rl/distributed/virtual_cluster.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -51,6 +51,9 @@ class PY_EXECUTABLES:
     # Use NeMo-RL direct dependencies and nemo-automodel.
     AUTOMODEL = "uv run --locked --extra automodel"
 
+    # Use Penguin dependencies
+    PENGUIN = "uv run --locked --extra penguin"
+
     # Megatron-core (and nemo dependencies)
     # We always run with --reinstall to avoid issues where someone runs "uv run ... --extra mcore ..."
     # but the submodules are not downloaded yet. This results in errors where it appears Megatron/Nemo
```
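Taken together with the registry entry above, these two pieces work roughly as follows (a sketch under assumed plumbing, not NeMo RL's actual code: the `runtime_env_for` helper is hypothetical, and the `py_executable` runtime-env field is an assumption about the Ray version in use):

```python
# Simplified sketch (not NeMo RL's actual plumbing) of how the registry entry
# and the PY_EXECUTABLES string fit together: the actor's fully qualified
# class name selects the `uv run` command line used to launch its Ray workers,
# so the Penguin actor starts inside the `penguin` extra's dependency set.
ACTOR_ENVIRONMENT_REGISTRY = {
    "nemo_rl.environments.penguin.Penguin": "uv run --locked --extra penguin",
}

def runtime_env_for(actor_fqn: str) -> dict:
    # Unregistered actors fall back to the plain locked environment.
    py_executable = ACTOR_ENVIRONMENT_REGISTRY.get(actor_fqn, "uv run --locked")
    # Recent Ray versions accept an (experimental) "py_executable" runtime_env
    # field for exactly this uv-run pattern; treated here as an assumption.
    return {"py_executable": py_executable}
```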

nemo_rl/environments/penguin.py

Lines changed: 202 additions & 0 deletions
New file; imports and the config schema first:

```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Any, Dict, List, TypedDict

import ray
import torch

from nemo_rl.data.interfaces import DatumSpec
from nemo_rl.distributed.virtual_cluster import _get_free_port_local, _get_node_ip_local
from nemo_rl.environments.interfaces import EnvironmentInterface


class PenguinConfig(TypedDict):
    model_name: str
    base_urls: List[str]
    initial_global_config_dict: Dict[str, Any]
```
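For orientation, here is a hypothetical `PenguinConfig` instance (all values invented; `base_urls` carries one OpenAI-compatible endpoint per data-parallel vLLM worker, which the connector-limit arithmetic in `__init__` below assumes):

```python
# Hypothetical PenguinConfig instance (values invented for illustration):
example_cfg: PenguinConfig = {
    "model_name": "my-policy-model",
    "base_urls": [
        "http://10.0.0.1:8000/v1",  # one endpoint per data-parallel vLLM worker
        "http://10.0.0.2:8000/v1",
    ],
    # Forwarded into Penguin's global config; __init__ fills in the policy and
    # head-server keys, everything else is Penguin-specific passthrough.
    "initial_global_config_dict": {},
}
```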
The actor wrapper itself:

```python
@ray.remote(max_restarts=-1, max_task_retries=-1)  # pragma: no cover
class Penguin(EnvironmentInterface):
    """Integration wrapper around Penguin, not a training environment per se.

    It hooks Penguin into the existing NeMo RL resource management via Ray, so
    that NeMo RL remains the single source of truth for resource management.
    """

    def __init__(self, cfg: PenguinConfig):
        self.cfg = cfg

        self.node_ip = _get_node_ip_local()
        self.head_server_port = _get_free_port_local()

        from omegaconf import DictConfig
        from penguin.cli import GlobalConfigDictParserConfig, RunHelper
        from penguin.rollout_collection import RolloutCollectionHelper
        from penguin.server_utils import HEAD_SERVER_KEY_NAME, BaseServerConfig

        RELATIVE_PATH = "nemo_rl/environments/penguin.py"
        assert __file__.endswith(RELATIVE_PATH)

        initial_global_config_dict = self.cfg["initial_global_config_dict"]
        # Policy information
        initial_global_config_dict["policy_model_name"] = self.cfg["model_name"]
        initial_global_config_dict["policy_api_key"] = (
            "dummy_key"  # No key is necessary for training.
        )
        initial_global_config_dict["policy_base_url"] = self.cfg["base_urls"]

        initial_global_config_dict["global_aiohttp_connector_limit_per_host"] = (
            initial_global_config_dict.get("global_aiohttp_connector_limit_per_host")
            or 1024
        )
        initial_global_config_dict["global_aiohttp_connector_limit"] = (
            initial_global_config_dict["global_aiohttp_connector_limit_per_host"]
            * len(self.cfg["base_urls"])
        )

        print(
            f"""Set `global_aiohttp_connector_limit_per_host` to a flat {initial_global_config_dict["global_aiohttp_connector_limit_per_host"]}.
Since there are {len(self.cfg["base_urls"])} data-parallel vLLM worker instances, `global_aiohttp_connector_limit` has been set to {len(self.cfg["base_urls"])} * {initial_global_config_dict["global_aiohttp_connector_limit_per_host"]} = {initial_global_config_dict["global_aiohttp_connector_limit"]}."""
        )

        # Head server
        initial_global_config_dict[HEAD_SERVER_KEY_NAME] = {
            "host": "0.0.0.0",
            "port": self.head_server_port,
        }

        self.rh = RunHelper()
        self.rh.start(
            global_config_dict_parser_config=GlobalConfigDictParserConfig(
                dotenv_path=Path(__file__.removesuffix(RELATIVE_PATH)).absolute()
                / "penguin_env.yaml",
                initial_global_config_dict=DictConfig(initial_global_config_dict),
                skip_load_from_cli=True,
            )
        )

        # Setup for rollout collection
        self.head_server_config = BaseServerConfig(
            host=self.node_ip,
            port=self.head_server_port,
        )
        self.rch = RolloutCollectionHelper()
```
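To make the connector-limit arithmetic concrete: with the per-host default of 1024 and, say, two endpoints in `base_urls`, the global limit comes out to 2048. A worked example under those assumed numbers:

```python
# Worked example of the limit computation above (numbers assumed):
base_urls = ["http://10.0.0.1:8000/v1", "http://10.0.0.2:8000/v1"]
per_host = 1024                    # flat default applied when the key is unset
total = per_host * len(base_urls)  # one pool of 1024 connections per endpoint
assert total == 2048               # becomes global_aiohttp_connector_limit
```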
Rollout collection and postprocessing:

```python
    def health_check(self) -> bool:
        return True

    async def run_rollouts(self, penguin_examples: list[dict]) -> list[dict]:
        penguin_results = await self.rch.run_examples(
            examples=penguin_examples, head_server_config=self.head_server_config
        )

        nemo_rl_results = list(
            map(self._postprocess_penguin_to_nemo_rl_result, penguin_results)
        )
        return nemo_rl_results

    def _postprocess_penguin_to_nemo_rl_result(self, penguin_result: dict) -> dict:
        nemo_rl_message_log = []
        seen_token_ids: List[int] = []
        for output_item_dict in penguin_result["response"]["output"]:
            # NeMo RL distinguishes only two kinds of messages, assistant and
            # non-assistant, since all it cares about is whether to train on a
            # message. Here we map every trainable message to "assistant" and
            # every non-trainable message to "user". Eventually we may be
            # smarter about this, but it is functional for now.

            # Note that Penguin only returns token ids on "assistant" messages,
            # not on other message types.
            if "generation_token_ids" not in output_item_dict:
                continue

            assert (
                seen_token_ids
                == output_item_dict["prompt_token_ids"][: len(seen_token_ids)]
            ), f"""Non-contiguous messages found! This may be a tokenization issue where certain tokens are combined when messages are concatenated, or it may be due to part of the chat history being truncated (e.g. a very long history is truncated or reasoning is stripped out).
Seen token IDs: {seen_token_ids}
Output prompt token IDs: {output_item_dict["prompt_token_ids"]}
"""

            nemo_rl_message_log.append(
                {
                    "role": "user",
                    "content": "",
                    "token_ids": output_item_dict["prompt_token_ids"][
                        len(seen_token_ids) :
                    ],
                }
            )
            nemo_rl_message_log.append(
                {
                    "role": "assistant",
                    "content": "",
                    "token_ids": output_item_dict["generation_token_ids"],
                    "generation_logprobs": output_item_dict["generation_log_probs"],
                }
            )

            seen_token_ids.extend(nemo_rl_message_log[-2]["token_ids"])
            seen_token_ids.extend(nemo_rl_message_log[-1]["token_ids"])

        return {
            "message_log": nemo_rl_message_log,
            "input_message_log": nemo_rl_message_log[:1],
            "full_result": penguin_result,
        }
```
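The contiguity assertion above is easiest to see on a toy trace: each output item's `prompt_token_ids` must replay everything already seen, and only the unseen suffix becomes the next "user" message. A sketch with invented token ids:

```python
# Toy trace of the splitting logic (token ids invented):
seen: list[int] = []
items = [
    {"prompt_token_ids": [1, 2, 3], "generation_token_ids": [4, 5]},
    # The next prompt replays [1..5] (prior prompt + generation), appending [6, 7].
    {"prompt_token_ids": [1, 2, 3, 4, 5, 6, 7], "generation_token_ids": [8]},
]
for item in items:
    assert seen == item["prompt_token_ids"][: len(seen)]  # contiguity check
    user_ids = item["prompt_token_ids"][len(seen) :]  # unseen suffix -> "user" msg
    seen = item["prompt_token_ids"] + item["generation_token_ids"]
print(seen)  # [1, 2, 3, 4, 5, 6, 7, 8]
```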
Shutdown, the unused interface methods, and the global config utils:

```python
    def shutdown(self) -> None:
        self.rh.shutdown()

    def step(self, message_log_batch, metadata):
        # Not used, since Penguin handles the rollouts entirely.
        raise NotImplementedError

    def global_post_process_and_metrics(self, batch):
        # Like step(), this is not used.
        raise NotImplementedError


########################################
# Global config utils
########################################


def setup_penguin_config(config, tokenizer) -> None:
    generation_config = config["policy"]["generation"]

    # Enable the HTTP server. This requires both the async engine and the
    # expose_http_server flag.
    generation_config["vllm_cfg"]["async_engine"] = True
    generation_config["vllm_cfg"]["expose_http_server"] = True

    # Stop strings and stop token ids are not supported.
    generation_config["stop_strings"] = None
    generation_config["stop_token_ids"] = None
```
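A hypothetical before/after of what `setup_penguin_config` mutates on a GRPO-style config (keys abbreviated; the `tokenizer` argument is unused by the body shown):

```python
# Hypothetical config fragment (abbreviated) passed through setup_penguin_config:
config = {
    "policy": {
        "generation": {
            "vllm_cfg": {"async_engine": False, "expose_http_server": False},
            "stop_strings": ["</answer>"],
            "stop_token_ids": [128009],
        }
    }
}
setup_penguin_config(config, tokenizer=None)
assert config["policy"]["generation"]["vllm_cfg"]["async_engine"] is True
assert config["policy"]["generation"]["stop_strings"] is None  # stops unsupported
```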
And the data-side helper:

```python
########################################
# Data utils
########################################


# Light preprocessing to make the Penguin data format compatible with the
# NeMo RL format.
def penguin_example_to_nemo_rl_datum_spec(penguin_example: dict, idx: int) -> DatumSpec:
    return DatumSpec(
        message_log=[
            {"role": "user", "content": "", "token_ids": torch.tensor([])}
        ],  # Fake message
        length=0,
        extra_env_info=penguin_example,
        loss_multiplier=1.0,  # Fixed to 1.0 to backprop on all examples
        idx=idx,
        task_name="penguin",
        stop_strings=None,
        # Extra vars
        token_ids=[],  # Empty key needed for compatibility with the current NeMo RL GRPO impl
    )
```
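Finally, a hypothetical end-to-end use of the data helper (example contents invented; real examples come from Penguin datasets such as the sanity-test JSON below):

```python
# Hypothetical usage of penguin_example_to_nemo_rl_datum_spec (contents invented):
penguin_examples = [
    {"task": "browse", "prompt": "Find the penguin species list."},
    {"task": "code", "prompt": "Fix the failing unit test."},
]
datum_specs = [
    penguin_example_to_nemo_rl_datum_spec(example, idx)
    for idx, example in enumerate(penguin_examples)
]
# Each spec carries the raw example in extra_env_info; the fake user message
# and empty token_ids exist only to satisfy the current GRPO data plumbing.
```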

tests/unit/environments/penguin_test_data/test_penguin_sanity.json

Lines changed: 1 addition & 0 deletions
(Large diff not shown.)
