
Commit 2aea5ad

fix: fix temperature-related issues (#935)
Signed-off-by: Zhanda <zhandazhu@gmail.com>
Signed-off-by: Zhanda Zhu <49645678+zhandaz@users.noreply.github.com>
Co-authored-by: Shang Wang <samshang.wang@mail.utoronto.ca>
1 parent 989f177 commit 2aea5ad

7 files changed (+135, -16 lines)


nemo_rl/models/generation/vllm/vllm_generation.py

Lines changed: 33 additions & 0 deletions
@@ -37,6 +37,12 @@
 )
 from nemo_rl.models.generation.vllm.config import VllmConfig
 
+# Global thresholds for top_k and top_p validation.
+# Although top_k/top_p sampling is not supported, values at or above these thresholds allow token filtering while keeping the returned logprobs compatible.
+# See https://github.com/NVIDIA-NeMo/RL/issues/69 and https://github.com/NVIDIA-NeMo/RL/issues/237 for more details.
+TOP_K_THRESHOLD = 8000  # Allow top_k >= 8000 (effectively no filtering)
+TOP_P_THRESHOLD = 0.99  # Allow top_p >= 0.99 (close to 1.0)
+
 
 class VllmGeneration(GenerationInterface):
     def __init__(
@@ -55,6 +61,33 @@ def __init__(
                 "You can enable it by adding `policy.generation.vllm_cfg.async_engine=true` to your command."
             )
 
+        # Validate sampling parameters early to avoid allocating resources for unsupported configs.
+        # The vLLM sampler patch only supports temperature scaling and does not handle top_p/top_k correctly.
+        # However, values at or above the thresholds are allowed for token filtering purposes.
+        top_k: int | None = self.cfg.get("top_k")
+        if top_k is not None and top_k != -1 and top_k < TOP_K_THRESHOLD:
+            raise ValueError(
+                (
+                    f"top_k sampling with values < {TOP_K_THRESHOLD} is not supported because the vLLM V1 engine "
+                    f"does not return logprobs after top_k filtering. Values >= {TOP_K_THRESHOLD} are allowed "
+                    "for token filtering purposes. If you understand the implications and still want to use "
+                    f"a lower top_k value, please manually comment out this check. Got top_k={top_k}. "
+                    "See https://github.com/NVIDIA-NeMo/RL/issues/69 for more details."
+                )
+            )
+
+        top_p: float = self.cfg.get("top_p", 1.0)
+        if top_p < TOP_P_THRESHOLD:
+            raise ValueError(
+                (
+                    f"top_p sampling with values < {TOP_P_THRESHOLD} is not supported because the vLLM V1 engine "
+                    f"does not return logprobs after top_p filtering. Values >= {TOP_P_THRESHOLD} are allowed "
+                    "for token filtering purposes. If you understand the implications and still want to use "
+                    f"a lower top_p value, please manually comment out this check. Got top_p={top_p}. "
+                    "See https://github.com/NVIDIA-NeMo/RL/issues/69 for more details."
+                )
+            )
+
         # Ensure all required VllmConfig fields are present
         missing_keys = [
             key for key in VllmConfig.__required_keys__ if key not in self.cfg
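The threshold logic above can be summarized in isolation. The following is a minimal sketch, not the repository's API: the helper name is hypothetical and the real check lives inline in VllmGeneration.__init__, but the pass/fail behavior for example configs is the same.

    TOP_K_THRESHOLD = 8000
    TOP_P_THRESHOLD = 0.99


    def validate_sampling_cfg(cfg: dict) -> None:
        # Hypothetical helper mirroring the inline check: None/-1 disables top_k entirely.
        top_k = cfg.get("top_k")
        if top_k is not None and top_k != -1 and top_k < TOP_K_THRESHOLD:
            raise ValueError(f"top_k={top_k} is below the supported threshold {TOP_K_THRESHOLD}")
        top_p = cfg.get("top_p", 1.0)
        if top_p < TOP_P_THRESHOLD:
            raise ValueError(f"top_p={top_p} is below the supported threshold {TOP_P_THRESHOLD}")


    validate_sampling_cfg({"top_k": 8000, "top_p": 1.0})  # passes: at/above thresholds
    validate_sampling_cfg({"top_k": -1, "top_p": 0.99})   # passes: top_k disabled, top_p at threshold
    # validate_sampling_cfg({"top_p": 0.9})               # would raise ValueError

Note that the check uses strict less-than, so values exactly at the thresholds are accepted.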

nemo_rl/models/generation/vllm/vllm_worker.py

Lines changed: 37 additions & 0 deletions
@@ -250,6 +250,43 @@ def _patch_vllm_init_workers_ray():
     _patch_vllm_init_workers_ray()
     logger.info("Successfully patched vllm _init_workers_ray.")
 
+    # Patch the vLLM sampler.py file to modify the logprobs computation with respect to temperature.
+    # This replaces `raw_logprobs = self.compute_logprobs(logits)` with temperature-applied logprobs.
+    # TODO(zhanda): This is only a temporary fix to address the issue of incorrect logprobs returned by vLLM
+    # and should be removed or improved once vLLM's new logprobs option is released. Currently, other
+    # sampling parameters such as top_p and top_k are not supported.
+    # See https://github.com/NVIDIA-NeMo/RL/issues/69 for more details.
+    def _patch_vllm_sampler():
+        try:
+            import vllm.v1.sample.sampler as sampler_module
+
+            file_to_patch = sampler_module.__file__
+
+            with open(file_to_patch, "r") as f:
+                content = f.read()
+
+            old_line = "raw_logprobs = self.compute_logprobs(logits)"
+            new_lines = "raw_logprobs = self.compute_logprobs(self.apply_temperature(logits.to(torch.float32), sampling_metadata.temperature) if sampling_metadata.temperature is not None else logits)"
+
+            if new_lines in content:
+                return
+
+            if old_line not in content:
+                return
+
+            # Replace all instances of the old line with the new lines
+            patched_content = content.replace(old_line, new_lines)
+
+            # Write back the patched content
+            with open(file_to_patch, "w") as f:
+                f.write(patched_content)
+
+        except (ImportError, FileNotFoundError, PermissionError):
+            # Allow failures gracefully
+            pass
+
+    _patch_vllm_sampler()
+
 except (ImportError, AttributeError):
     # vllm not installed or has a different structure, skipping patch.
     pass
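The patch above edits vLLM's installed sampler source in place via string replacement. The following is a minimal, generic sketch of that idempotent pattern, under the assumption that the exact line to replace is known; the helper name and arguments are placeholders rather than the worker's real code (which targets vllm.v1.sample.sampler specifically).

    import importlib


    def patch_module_source(module_name: str, old_line: str, new_line: str) -> None:
        """Rewrite one line of an installed module's source file, idempotently."""
        try:
            module = importlib.import_module(module_name)
            path = module.__file__
            if path is None:
                return
            with open(path, "r") as f:
                content = f.read()
            # Skip if already patched or if the expected line is absent.
            if new_line in content or old_line not in content:
                return
            with open(path, "w") as f:
                f.write(content.replace(old_line, new_line))
        except (ImportError, FileNotFoundError, PermissionError):
            # Fail gracefully, as the worker's patch does.
            pass

Because an already-patched file is detected up front, re-running the patch on worker restart is a no-op.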

nemo_rl/models/megatron/common.py

Lines changed: 11 additions & 0 deletions
@@ -260,6 +260,7 @@ def forward_step_arbitrary_loss(
     pad_individual_seqs_to_multiple_of: int = 1,
     pad_full_seq_to: Optional[int] = None,
     cp_normalize: bool = True,
+    policy_cfg: Optional[dict] = None,
 ):
     """Forward training step with support for packed sequences and context parallelism.
 
@@ -273,6 +274,7 @@
         pack_sequences (bool): Whether to pack sequences for efficiency
         seq_length_key (Optional[str]): Key in data_dict containing actual sequence lengths
         cp_normalize (bool): Whether to normalize the loss by the cp_size
+        policy_cfg (Optional[dict]): Policy configuration containing generation parameters
 
     Notes on packed sequences with context parallelism (CP):
     - When CP > 1, each sequence is padded to a multiple of (cp_size * 2)
@@ -342,6 +344,15 @@
         packed_seq_params=packed_seq_params,
     )
 
+    # Apply temperature scaling to logits for training.
+    # This matches the dtensor worker's _apply_temperature_scaling in the train method.
+    if (
+        policy_cfg is not None
+        and "generation" in policy_cfg
+        and policy_cfg["generation"] is not None
+    ):
+        output_tensor.div_(policy_cfg["generation"]["temperature"])
+
     # Unpack the output tensor if we did packed sequences
     if pack_sequences and packed_seq_params is not None:
         # remove padding
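For context on why the training-side logits are divided by the temperature: the patched vLLM sampler now computes its returned logprobs from temperature-scaled logits, and dividing logits by the temperature before log_softmax changes the resulting distribution whenever the temperature is not 1.0, so the training side must apply the same scaling to stay consistent. A minimal sketch with plain torch (illustrative tensors only, not the repository's code):

    import torch

    temperature = 0.7
    logits = torch.randn(2, 5)

    scaled_logprobs = torch.log_softmax(logits / temperature, dim=-1)
    raw_logprobs = torch.log_softmax(logits, dim=-1)

    # The two only agree when temperature == 1.0.
    print(torch.allclose(scaled_logprobs, raw_logprobs))  # False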

nemo_rl/models/policy/dtensor_policy_worker.py

Lines changed: 1 addition & 7 deletions
@@ -69,7 +69,6 @@
     get_handle_from_tensor,
     get_runtime_env_for_policy_worker,
     import_class_from_path,
-    is_vllm_v1_engine_enabled,
     resolve_model_class,
     sliding_window_overwrite,
 )
@@ -471,13 +470,8 @@ def create_context_parallel_ctx(
         # based on https://github.com/pytorch/torchtitan/blob/cddd7dc809f36fe0ed51cdaaea0671c084d75442/torchtitan/distributed/utils.py#L178
 
     def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
-        # Apply temperature scaling to logits if configured and not using V1 engine.
         if "generation" in self.cfg and self.cfg["generation"] is not None:
-            # The V1 engine returns raw logits before temperature scaling.
-            # The V0 engine returns scaled logits.
-            # Therefore, we only divide if we are not using the V1 engine.
-            if not is_vllm_v1_engine_enabled():
-                logits.div_(self.cfg["generation"]["temperature"])
+            logits.div_(self.cfg["generation"]["temperature"])
         return logits
 
     @staticmethod
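With the V1-engine special case removed, the helper unconditionally divides the logits in place whenever a generation config is present. A minimal standalone sketch of that behavior (names are illustrative; the real method lives on the worker class and reads self.cfg):

    import torch


    def apply_temperature_scaling(logits: torch.Tensor, cfg: dict) -> torch.Tensor:
        # Illustrative stand-in for the worker method; cfg mimics self.cfg.
        if "generation" in cfg and cfg["generation"] is not None:
            logits.div_(cfg["generation"]["temperature"])  # in-place, like the worker
        return logits


    logits = torch.tensor([[2.0, 4.0]])
    apply_temperature_scaling(logits, {"generation": {"temperature": 2.0}})
    print(logits)  # tensor([[1., 2.]]) -- the caller's tensor is modified in place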

nemo_rl/models/policy/dtensor_policy_worker_v2.py

Lines changed: 1 addition & 7 deletions
@@ -79,7 +79,6 @@
     get_handle_from_tensor,
     get_runtime_env_for_policy_worker,
     import_class_from_path,
-    is_vllm_v1_engine_enabled,
 )
 from nemo_rl.utils.native_checkpoint import (
     load_checkpoint,
@@ -420,13 +419,8 @@ def __init__(
         self._held_streamed_param_reference: Optional[dict[str, torch.Tensor]] = None
 
     def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
-        # Apply temperature scaling to logits if configured and not using V1 engine.
         if "generation" in self.cfg and self.cfg["generation"] is not None:
-            # The V1 engine returns raw logits before temperature scaling.
-            # The V0 engine returns scaled logits.
-            # Therefore, we only divide if we are not using the V1 engine.
-            if not is_vllm_v1_engine_enabled():
-                logits.div_(self.cfg["generation"]["temperature"])
+            logits.div_(self.cfg["generation"]["temperature"])
         return logits
 
     def init_collective(self, ip: str, port: int, world_size: int) -> None:

nemo_rl/models/policy/megatron_policy_worker.py

Lines changed: 8 additions & 1 deletion
@@ -820,7 +820,9 @@ def train(
                     f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}"
                 )
 
-        forward_step = partial(forward_step_arbitrary_loss, loss_fn=loss_fn)
+        forward_step = partial(
+            forward_step_arbitrary_loss, loss_fn=loss_fn, policy_cfg=self.cfg
+        )
         all_mb_metrics = []
         losses = []
         for gb_idx in range(num_global_batches):
@@ -1111,6 +1113,11 @@ def forward_step_fn(
                 packed_seq_params=packed_seq_params,
             )
 
+            # Apply temperature scaling to logits for training.
+            # This matches the dtensor worker's _apply_temperature_scaling in the train method.
+            if "generation" in self.cfg and self.cfg["generation"] is not None:
+                output_tensor.div_(self.cfg["generation"]["temperature"])
+
         def collection_fn(output_tensor):
             stc = time.time()
             tp_grp = get_tensor_model_parallel_group()
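The train-path change binds policy_cfg into the forward step once via functools.partial, so downstream code keeps calling the forward step with only its batch-level arguments. A minimal sketch of that binding pattern with hypothetical names (not the worker's real signature):

    from functools import partial

    import torch


    def forward_step(batch: torch.Tensor, loss_fn, policy_cfg: dict | None = None):
        # Hypothetical forward step: the model call is replaced by a simple cast.
        logits = batch.float()  # stand-in for a model forward pass
        if policy_cfg is not None and policy_cfg.get("generation") is not None:
            logits = logits / policy_cfg["generation"]["temperature"]
        return loss_fn(logits)


    bound_step = partial(
        forward_step,
        loss_fn=lambda logits: logits.mean(),
        policy_cfg={"generation": {"temperature": 0.5}},
    )
    print(bound_step(torch.tensor([1.0, 2.0, 3.0])))  # tensor(4.)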

tests/unit/models/generation/test_vllm_generation.py

Lines changed: 44 additions & 1 deletion
@@ -39,7 +39,13 @@
     },
     "dtype": "bfloat16",
     "max_new_tokens": 5,  # Small number of tokens for testing
-    "temperature": 0.8,
+    # Set temperature=1.0 to ensure consistent probability scaling when comparing vLLM and HF policy outputs.
+    # Note: greedy=True is only used in tests for deterministic behavior and is not used in real training.
+    # In vLLM, enabling greedy=True disables temperature scaling (temperature is overridden to None).
+    # The HF policy worker does not currently support greedy=True for get_logprobs.
+    # Using temperature=1.0 lets us meaningfully test the average multiplicative error in probabilities between the two implementations while still keeping the behavior deterministic.
+    "temperature": 1.0,
     "top_p": 1.0,
     "top_k": None,
     "stop_token_ids": None,
@@ -326,6 +332,43 @@ def test_vllm_missing_required_config_key(cluster):
     print(f"Successfully caught missing config key with error: {error_message}")
 
 
+def test_vllm_top_p_top_k_validation(cluster):
+    """Test that top_p and top_k validation works correctly with threshold-based logic."""
+    # Test that values at or above the thresholds are allowed.
+    config_above_thresholds = deepcopy(basic_vllm_test_config)
+    config_above_thresholds["top_p"] = 0.99  # At TOP_P_THRESHOLD (allowed)
+    config_above_thresholds["top_k"] = 8000  # At TOP_K_THRESHOLD (allowed)
+
+    # Should not raise an error.
+    try:
+        VllmGeneration(cluster, config_above_thresholds)
+        print("Successfully initialized with top_p=0.99 and top_k=8000")
+    except Exception as e:
+        pytest.fail(f"Should not raise error with values at or above thresholds: {e}")
+
+    # Test that top_p values below the threshold are rejected.
+    config_below_thresholds = deepcopy(basic_vllm_test_config)
+    config_below_thresholds["top_p"] = 0.9  # Below TOP_P_THRESHOLD
+
+    with pytest.raises(ValueError) as excinfo:
+        VllmGeneration(cluster, config_below_thresholds)
+
+    error_message = str(excinfo.value)
+    assert "top_p sampling with values < 0.99 is not supported" in error_message
+    print(f"Successfully caught low top_p value with error: {error_message}")
+
+    # Test that low top_k values are rejected.
+    config_low_top_k = deepcopy(basic_vllm_test_config)
+    config_low_top_k["top_k"] = 7999  # Below TOP_K_THRESHOLD
+
+    with pytest.raises(ValueError) as excinfo:
+        VllmGeneration(cluster, config_low_top_k)
+
+    error_message = str(excinfo.value)
+    assert "top_k sampling with values < 8000 is not supported" in error_message
+    print(f"Successfully caught low top_k value with error: {error_message}")
+
+
 def test_vllm_policy_generation(policy, test_input_data, tokenizer):
     """Test vLLM policy generation capabilities."""
     # Test generation
