Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions nemo_rl/experience/rollouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,19 @@ def run_async_nemo_gym_rollout(
timer_prefix = "timing/rollout"
timer.start(f"{timer_prefix}/total")

# Parse global logits processor defaults so per-sample budgets can inherit
# grace_period and end_token_ids without repeating them in every datum.
_lp_defaults: dict = {}
_lp_args_str = (
(policy_generation.cfg.get("vllm_cfg") or {})
.get("logits_processor_env_vars") or {}
).get("THINKING_BUDGET_LOGITS_PROCESSOR_ARGS")
if _lp_args_str:
try:
_lp_defaults = json.loads(_lp_args_str)
except (json.JSONDecodeError, AttributeError):
pass

for rowidx, row in enumerate(nemo_gym_rows):
# We may need better handling here. The max tokens set here would be the max new generated tokens, not the total max tokens.
# Currently, we just rely on the underlying vLLM engine to do the truncation for us using the max model seq len set in the config.
Expand All @@ -1045,6 +1058,23 @@ def run_async_nemo_gym_rollout(
responses_create_params["temperature"] = generation_config["temperature"]
responses_create_params["top_p"] = generation_config["top_p"]

# Per-sample thinking budget: inject vllm_xargs so the logits processor
# picks up the per-request budget override alongside global defaults.
if "thinking_budget" in row:
vllm_xargs = {
"thinking_budget": row["thinking_budget"],
"thinking_budget_grace_period": row.get(
"thinking_budget_grace_period",
_lp_defaults.get("thinking_budget_grace_period", 30),
),
"end_token_ids": json.dumps(
row.get("end_token_ids", _lp_defaults.get("end_token_ids", []))
),
}
metadata = responses_create_params.get("metadata") or {}
metadata["extra_body"] = json.dumps({"vllm_xargs": vllm_xargs})
responses_create_params["metadata"] = metadata

# Max new tokens, like max_seq_len above, is ignored here; we rely on the underlying vLLM engine for truncation.
# generation_config["max_new_tokens"]

Expand Down
4 changes: 4 additions & 0 deletions nemo_rl/models/generation/vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class VllmSpecificArgs(TypedDict):
# Miscellaneous top level vLLM HTTP server arguments.
# A filepath that can be imported to register a vLLM tool parser
tool_parser_plugin: NotRequired[str]
# List of logits processor class paths to load (e.g., ["module.path:ClassName"])
logits_processors: NotRequired[list[str]]
# Environment variables for logits processor configuration (global defaults)
logits_processor_env_vars: NotRequired[dict[str, str]]


class VllmConfig(GenerationConfig):
Expand Down
3 changes: 3 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ def __init__(
if self.ep_size > self.tp_size:
env_vars["VLLM_DP_SIZE"] = str(self.vllm_dp_size)

if self.cfg["vllm_cfg"].get("logits_processor_env_vars"):
env_vars.update(self.cfg["vllm_cfg"]["logits_processor_env_vars"])

# Check if we need parallelism-aware worker group creation
if self.model_parallel_size > 1:
# For parallelism, create node-aware worker groups
Expand Down
43 changes: 43 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,34 @@ def __repr__(self) -> str:
"""
return f"{self.__class__.__name__}"

@staticmethod
def _load_logits_processor_classes(processor_specs: list[str]) -> list[type]:
    """Resolve "module.path:ClassName" strings into the objects they name.

    Loading is best-effort: a spec that is malformed, whose module cannot
    be imported, or whose attribute is missing is reported through
    ``warnings.warn`` and skipped, so one bad entry does not prevent the
    remaining entries from loading.

    Args:
        processor_specs: Strings of the form "module.path:ClassName".
            Only the first ":" separates the module path from the
            attribute name.

    Returns:
        The successfully resolved attributes, in the order their specs
        were given. Entries that failed to load are omitted.
    """
    import importlib
    import warnings

    resolved = []
    for entry in processor_specs:
        try:
            if ":" in entry:
                mod_name, attr_name = entry.split(":", 1)
                target_module = importlib.import_module(mod_name)
                resolved.append(getattr(target_module, attr_name))
            else:
                raise ValueError(
                    f"Invalid spec '{entry}'. Must be 'module:ClassName'"
                )
        except Exception as err:
            # Importing arbitrary user modules can raise anything, so the
            # catch is intentionally broad; the failure is surfaced as a
            # warning rather than aborting the whole load.
            warnings.warn(f"Failed to load logits processor '{entry}': {err}")

    return resolved

@staticmethod
def configure_worker(
num_gpus: int | float, bundle_indices: Optional[tuple[int, list[int]]] = None
Expand Down Expand Up @@ -113,6 +141,11 @@ def configure_worker(
# Skip vllm P2P check and rely on driver to report peer to peer capability.
env_vars["VLLM_SKIP_P2P_CHECK"] = "1"

if "cfg" in init_kwargs:
cfg = init_kwargs["cfg"]
if "vllm_cfg" in cfg and "logits_processor_env_vars" in cfg["vllm_cfg"]:
env_vars.update(cfg["vllm_cfg"]["logits_processor_env_vars"])

return resources, env_vars, init_kwargs

def __init__(
Expand Down Expand Up @@ -435,6 +468,12 @@ def _patch_vllm_speculative_decoding_post_step():
)
self.cfg["vllm_cfg"]["skip_tokenizer_init"] = False

logits_processor_classes = []
if self.cfg["vllm_cfg"].get("logits_processors"):
logits_processor_classes = self._load_logits_processor_classes(
self.cfg["vllm_cfg"]["logits_processors"]
)

llm_kwargs = dict(
model=self.model_name,
served_model_name=self.model_name,
Expand All @@ -459,6 +498,10 @@ def _patch_vllm_speculative_decoding_post_step():
**vllm_kwargs,
)

# Add logits processors if loaded
if logits_processor_classes:
llm_kwargs["logits_processors"] = logits_processor_classes

self._create_engine(llm_kwargs)

# will be initialized in post_init
Expand Down