
Commit cca2aec

fix: Remove reference of tokenizer from generation backend (#75) (#82)
Signed-off-by: Parth Chadha <pchadha@nvidia.com>
1 parent: bd7e4b0

6 files changed, +91 -15 lines changed

docs/design_docs/generation.md

Lines changed: 12 additions & 2 deletions
@@ -95,20 +95,30 @@ The {py:class}`UpdatableVllmInternalWorker <nemo_reinforcer.models.generation.vl
 To use a generation backend:

 ```python
+from transformers import AutoTokenizer
+
 from nemo_reinforcer.models.generation.vllm import VllmGeneration, VllmConfig
 from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster
 from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict

 # Set up the configuration
+tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"])
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
 config = VllmConfig(
-    backend="vllm",
     model_name="Qwen/Qwen2.5-1.5B",
     max_new_tokens=100,
     temperature=0.7,
     top_p=1,
+    top_k=None,
+    stop_token_ids=[tokenizer.eos_token_id],
+    pad_token=tokenizer.pad_token_id,
+    skip_tokenizer_init=True,
     vllm_cfg={
         "tensor_parallel_size": 1,
-        "gpu_memory_utilization": 0.8
+        "gpu_memory_utilization": 0.8,
+        "max_model_len": 2048,
     }
 )
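For readers following the updated example, here is a minimal sketch of how this configuration might then be wired into a generation backend. It continues the documentation snippet above (`config` and `tokenizer` come from there) and assumes the `VllmGeneration(cluster, config)` call pattern used in this repository's unit tests and the `input_ids`/`input_lengths` fields required by `generate`; the cluster construction and `BatchedDataDict` shape are illustrative assumptions, not part of this commit.

```python
# Illustrative sketch only; RayVirtualCluster arguments and the exact
# BatchedDataDict construction are assumptions, not taken from this commit.
import torch

from nemo_reinforcer.models.generation.vllm import VllmGeneration
from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster
from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict

cluster = RayVirtualCluster(...)  # placeholder: construct per your hardware setup

# The backend no longer builds a tokenizer of its own; it relies on the
# pad_token / stop_token_ids carried inside `config`.
generation = VllmGeneration(cluster, config)

prompt = tokenizer("What is 2 + 2?", return_tensors="pt")
data = BatchedDataDict(
    {
        "input_ids": prompt["input_ids"],  # must be right-padded with config["pad_token"]
        "input_lengths": torch.tensor([prompt["input_ids"].shape[1]]),
    }
)
outputs = generation.generate(data)
```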

examples/run_grpo_math.py

Lines changed: 3 additions & 1 deletion
@@ -188,6 +188,8 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig, env_configs
         raise ValueError(f"No processor for dataset {data_config['dataset_name']}.")

     tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"])
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token

     task_data_processors = defaultdict(
         lambda: (math_task_spec, openinstructmath2_data_processor)
@@ -270,7 +272,7 @@ def main():
         checkpointer,
         grpo_state,
         master_config,
-    ) = setup(config, dataset, val_dataset)
+    ) = setup(config, tokenizer, dataset, val_dataset)
     grpo_train(
         policy,
         policy_generation,

nemo_reinforcer/algorithms/grpo.py

Lines changed: 7 additions & 0 deletions
@@ -109,6 +109,7 @@ class MasterConfig(TypedDict):

 def setup(
     master_config: MasterConfig,
+    tokenizer: AutoTokenizer,
     dataset: AllTaskProcessedDataset,
     val_dataset: Optional[AllTaskProcessedDataset],
 ) -> Tuple[
@@ -219,6 +220,12 @@ def setup(
     # vllm model loading prefers clean environment, initialize policy_generation before policy (#52 will fix this)
     backend = generation_config["backend"]
     generation_config["model_name"] = policy_config["model_name"]  # Needed for vLLM
+    generation_config["vllm_cfg"]["skip_tokenizer_init"] = True
+    # When https://github.com/NVIDIA/reinforcer/issues/57 is fixed, we should update stop_token_ids below.
+    generation_config["stop_token_ids"] = [tokenizer.eos_token_id]
+    generation_config["pad_token"] = tokenizer.pad_token_id
+    generation_config["vllm_cfg"]["load_format"] = "dummy"
+
     if backend == "hf":
         policy_generation = None
         print(f" ✓ Using HF backend for generation with {policy_config['model_name']}")

nemo_reinforcer/models/generation/interfaces.py

Lines changed: 3 additions & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import Any, TypedDict, Union, Tuple
+from typing import Any, TypedDict, Union, Tuple, List

 import torch
 from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
@@ -101,6 +101,8 @@ class GenerationConfig(TypedDict):
     top_p: float
     top_k: int
     model_name: str
+    stop_token_ids: List[int]
+    pad_token: int


 class GenerationDatumSpec(TypedDict):

nemo_reinforcer/models/generation/vllm.py

Lines changed: 24 additions & 9 deletions
@@ -39,6 +39,9 @@ class VllmSpecificArgs(TypedDict):
     tensor_parallel_size: int
     gpu_memory_utilization: float
     max_model_len: int
+    # Additional arguments for vLLM inserted by reinforcer based on the context of when vllm is used
+    skip_tokenizer_init: bool
+    load_format: str


 class VllmConfig(GenerationConfig):
@@ -110,6 +113,7 @@ def __init__(
         Only needed for the first worker in each tied worker group.
         """
         self.cfg = config
+
         self.model_name = self.cfg["model_name"]
         self.tensor_parallel_size = self.cfg["vllm_cfg"]["tensor_parallel_size"]
         self.gpu_memory_utilization = self.cfg["vllm_cfg"]["gpu_memory_utilization"]
@@ -166,23 +170,22 @@ def __init__(

         self.llm = LLM(
             model=self.model_name,
-            load_format="dummy",
-            tensor_parallel_size=self.tensor_parallel_size,
-            gpu_memory_utilization=self.gpu_memory_utilization,
+            # Training pipeline will set this to "dummy" and eval will load real weights using 'auto'
+            load_format=self.cfg["vllm_cfg"]["load_format"],
+            skip_tokenizer_init=self.cfg["vllm_cfg"]["skip_tokenizer_init"],
+            tensor_parallel_size=self.cfg["vllm_cfg"]["tensor_parallel_size"],
+            gpu_memory_utilization=self.cfg["vllm_cfg"]["gpu_memory_utilization"],
             enable_prefix_caching=True,
             dtype="auto",
             enforce_eager=True,
             max_model_len=self.cfg["vllm_cfg"]["max_model_len"],
             trust_remote_code=True,
             worker_cls=UpdatableVllmInternalWorker,
             enable_sleep_mode=True,
+            disable_log_stats=True,
             **vllm_kwargs,
         )

-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-        if self.tokenizer.pad_token is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
     def llm(self):
         return self.llm

@@ -213,7 +216,7 @@ def generate(
             f"input_ids and input_lengths must be present in the BatchedDataDict, got keys: {data.keys()}"
         )
         is_right_padded, error_msg = verify_right_padding(
-            data, pad_value=self.tokenizer.pad_token_id
+            data, pad_value=self.cfg["pad_token"]
         )
         if not is_right_padded:
             warnings.warn(
@@ -251,6 +254,7 @@ def generate(
             max_tokens=self.cfg["max_new_tokens"],
             logprobs=0,  # Return logprobs for the generated tokens
             stop=None,
+            stop_token_ids=self.cfg["stop_token_ids"],
         )

         # Generate outputs
@@ -276,7 +280,7 @@ def generate(

         # Create a new tensor with the right size and fill with padding token
         full_output = torch.full(
-            (total_length,), self.tokenizer.pad_token_id, dtype=input_ids.dtype
+            (total_length,), self.cfg["pad_token"], dtype=input_ids.dtype
         )

         # Copy original input (with padding) into the beginning
@@ -402,6 +406,17 @@ def __init__(
         """Initialize a vLLM policy with distributed workers."""
         # Store config
         self.cfg = config
+        # Ensure all required VllmConfig fields are present
+        missing_keys = [
+            key for key in VllmConfig.__annotations__ if key not in self.cfg
+        ]
+        assert not missing_keys, (
+            f"VLLM Configuration Error: Missing required keys in VllmConfig.\n"
+            f"Missing keys: {', '.join(missing_keys)}\n"
+            f"Provided keys: {', '.join(self.cfg.keys())}\n"
+            f"Please update your configuration to include all required VLLM parameters."
+        )
+
         self.tensor_parallel_size = self.cfg["vllm_cfg"]["tensor_parallel_size"]

         # Create worker builder for VllmGenerationWorker
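The new config check relies on the fact that a TypedDict exposes its declared (and inherited) fields through `__annotations__`. A tiny standalone demonstration of that mechanism; the class below is an illustrative stand-in, not a repository type.

```python
# Minimal demonstration of validating a plain dict against a TypedDict's
# declared fields, mirroring the assertion added in VllmGeneration.__init__.
from typing import TypedDict


class DemoConfig(TypedDict):  # illustrative stand-in, not a repository class
    model_name: str
    pad_token: int
    stop_token_ids: list


cfg = {"model_name": "demo-model", "pad_token": 0}  # 'stop_token_ids' intentionally omitted

# __annotations__ lists every declared field of the TypedDict, which is what
# the new assertion iterates over to find missing keys.
missing = [key for key in DemoConfig.__annotations__ if key not in cfg]
print(missing)  # ['stop_token_ids']
```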

tests/unit/models/generation/test_vllm_generation.py

Lines changed: 42 additions & 2 deletions
@@ -25,6 +25,7 @@

 # Define basic vLLM test config
 basic_vllm_test_config: VllmConfig = {
+    "backend": "vllm",
     "model_name": "meta-llama/Llama-3.2-1B",  # Small model for testing
     "dtype": "bfloat16",
     "max_new_tokens": 10,
@@ -39,6 +40,15 @@
 }


+def configure_vllm_with_tokenizer(vllm_config, tokenizer):
+    """Apply tokenizer-specific configurations to vLLM config."""
+    vllm_config["vllm_cfg"]["skip_tokenizer_init"] = True
+    vllm_config["vllm_cfg"]["load_format"] = "dummy"
+    vllm_config["pad_token"] = tokenizer.pad_token_id
+    vllm_config["stop_token_ids"] = [tokenizer.eos_token_id]
+    return vllm_config
+
+
 @pytest.fixture(scope="module")
 def check_vllm_available():
     """Skip tests if vLLM is not installed."""
@@ -74,9 +84,12 @@ def tokenizer():


 @pytest.fixture(scope="function")
-def policy(cluster, check_vllm_available):
+def policy(cluster, tokenizer, check_vllm_available):
     """Initialize the vLLM policy."""
-    policy = VllmGeneration(cluster, basic_vllm_test_config)
+    # Create separate configs for each policy
+    vllm_config = basic_vllm_test_config.copy()
+    vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer)
+    policy = VllmGeneration(cluster, vllm_config)
     yield policy

     # Ensure policy is properly shutdown
@@ -121,6 +134,30 @@ def test_input_data(tokenizer):
     )


+def test_vllm_missing_required_config_key(cluster, check_vllm_available):
+    """Test that an assertion error is raised when a required config key is missing."""
+    # Create a config missing a required key by removing 'model_name'
+    incomplete_config = basic_vllm_test_config.copy()
+    del incomplete_config["model_name"]  # Remove a required key
+
+    # Also need to ensure skip_tokenizer_init and load_format are there
+    # since these are checked in VllmConfig.__annotations__
+    incomplete_config["skip_tokenizer_init"] = True
+    incomplete_config["load_format"] = "auto"
+
+    # Attempt to initialize VllmGeneration with incomplete config - should raise AssertionError
+    with pytest.raises(AssertionError) as excinfo:
+        VllmGeneration(cluster, incomplete_config)
+
+    # Verify the error message contains information about the missing key
+    error_message = str(excinfo.value)
+    assert "Missing required keys in VllmConfig" in error_message
+    assert "model_name" in error_message, (
+        "Error should mention the missing 'model_name' key"
+    )
+    print(f"Successfully caught missing config key with error: {error_message}")
+
+
 def test_vllm_policy_generation(policy, test_input_data, tokenizer):
     """Test vLLM policy generation capabilities."""
     # Test generation
@@ -171,6 +208,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):

     # Create separate configs for each policy
     vllm_config = basic_vllm_test_config.copy()
+    vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer)

     # Create HF-specific config with required parameters
     hf_config = {
@@ -359,6 +397,7 @@ def test_vllm_policy_tensor_parallel(cluster, tokenizer):
     """Test vLLM policy with tensor parallelism > 1."""
     # Configure with tensor_parallel_size=2
     tp_config = basic_vllm_test_config.copy()
+    tp_config = configure_vllm_with_tokenizer(tp_config, tokenizer)
     tp_config["tensor_parallel_size"] = 2

     # Ensure we specify the distributed executor backend
@@ -420,6 +459,7 @@ def test_vllm_policy_weight_update(cluster, tokenizer, tensor_parallel_size):

     # Create separate configs for each policy
     vllm_config = basic_vllm_test_config.copy()
+    vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer)
     vllm_config["tensor_parallel_size"] = tensor_parallel_size

     # Add vllm_kwargs only if using tensor parallelism
