
Commit 85a0c8f

s-noghabi authored and The tunix Authors committed
improve rl_cluster init corner case handling and test coverage
PiperOrigin-RevId: 889503171
1 parent bf94583 commit 85a0c8f

File tree

2 files changed: +296 -17 lines changed


tests/rl/rl_cluster_test.py

Lines changed: 275 additions & 0 deletions
@@ -322,6 +322,271 @@ def test_generate_with_chat_template(self): # pylint: disable=g-doc-args
     called_prompts = rl_cluster.rollout.generate.call_args[0][0]
     self.assertEqual(called_prompts, ['formatted prompt'])

+  def _create_test_rl_cluster(
+      self,
+      rollout_engine: str,
+      rollout_config: base_rollout.RolloutConfig,
+  ) -> rl_cluster_lib.RLCluster:
+    split_index = self.device_count // 2
+    actor_mesh = Mesh(
+        np.array(jax.devices()[:split_index]).reshape(split_index, 1),
+        ('fsdp', 'tp'),
+    )
+    rollout_mesh = Mesh(
+        np.array(jax.devices()[split_index:]).reshape(1, split_index),
+        ('fsdp', 'tp'),
+    )
+    cluster_config = rl_cluster_lib.ClusterConfig(
+        role_to_mesh={
+            rl_cluster_lib.Role.ACTOR: actor_mesh,
+            rl_cluster_lib.Role.REFERENCE: actor_mesh,
+            rl_cluster_lib.Role.ROLLOUT: rollout_mesh,
+        },
+        rollout_engine=rollout_engine,
+        offload_to_cpu=False,
+        training_config=rl_cluster_lib.RLTrainingConfig(
+            actor_optimizer=optax.sgd(1e-3),
+            eval_every_n_steps=1,
+            max_steps=10,
+            gradient_accumulation_steps=None,
+        ),
+        rollout_config=rollout_config,
+    )
+    vocab = tc.MockVocab()
+    model = tc.ToyTransformer(
+        config=tc.ModelConfig(vocab_size=vocab.GetPieceSize()), rngs=nnx.Rngs(0)
+    )
+    return rl_cluster_lib.RLCluster(
+        actor=model, tokenizer=vocab, cluster_config=cluster_config
+    )
+
+  def test_init_cluster_invalid_engine_string(self):
+    with self.assertRaisesRegex(
+        ValueError, '`cluster_config.rollout_engine` should be one of'
+    ):
+      self._create_test_rl_cluster(
+          'invalid_engine', base_rollout.RolloutConfig()
+      )
+
+  @parameterized.parameters('vanilla', 'vllm', 'sglang_jax')
+  def test_init_rollout_engine_missing_config_raises_error(self, engine):
+    with self.assertRaisesRegex(
+        ValueError, '`cluster_config.rollout_config` cannot be None.'
+    ):
+      self._create_test_rl_cluster(engine, None)
+
+  @parameterized.parameters('vanilla', 'vllm', 'sglang_jax')
+  def test_init_rollout_engine_empty_dict_config_raises_error(self, engine):
+    with self.assertRaisesRegex(
+        ValueError,
+        'Rollout config is a dict but missing a train config.',
+    ):
+      self._create_test_rl_cluster(engine, {})
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='single_config',
+          rollout_config=base_rollout.RolloutConfig(
+              max_tokens_to_generate=10,
+              kv_cache_size=1024,
+              data_type=jnp.bfloat16,
+          ),
+          expected_cache_size=1024,
+      ),
+      dict(
+          testcase_name='dict_config',
+          rollout_config={
+              rl_cluster_lib.Mode.TRAIN: base_rollout.RolloutConfig(
+                  max_tokens_to_generate=10,
+                  kv_cache_size=1024,
+                  data_type=jnp.bfloat16,
+              ),
+              rl_cluster_lib.Mode.EVAL: base_rollout.RolloutConfig(
+                  max_tokens_to_generate=10,
+                  kv_cache_size=2048,
+                  data_type=jnp.bfloat16,
+              ),
+          },
+          expected_cache_size=2048,
+      ),
+  )
+  @mock.patch.object(
+      rl_cluster_lib.vanilla_rollout, 'VanillaRollout', autospec=True
+  )
+  def test_init_vanilla_rollout_engine(
+      self, mock_vanilla_cls, rollout_config, expected_cache_size
+  ):
+    rl_cluster = self._create_test_rl_cluster('vanilla', rollout_config)
+
+    mock_vanilla_cls.assert_called_once()
+    self.assertEqual(rl_cluster.rollout, mock_vanilla_cls.return_value)
+    called_kwargs = mock_vanilla_cls.call_args.kwargs
+    self.assertIsInstance(
+        called_kwargs['cache_config_or_size'], base_rollout.CacheConfig
+    )
+    self.assertEqual(
+        called_kwargs['cache_config_or_size'].cache_size, expected_cache_size
+    )
+
+  def test_init_vanilla_rollout_engine_missing_model_config(self):
+    split_index = self.device_count // 2
+    actor_mesh = Mesh(
+        np.array(jax.devices()[:split_index]).reshape(split_index, 1),
+        ('fsdp', 'tp'),
+    )
+    cluster_config = rl_cluster_lib.ClusterConfig(
+        role_to_mesh={
+            rl_cluster_lib.Role.ACTOR: actor_mesh,
+            rl_cluster_lib.Role.REFERENCE: actor_mesh,
+            rl_cluster_lib.Role.ROLLOUT: actor_mesh,
+        },
+        rollout_engine='vanilla',
+        offload_to_cpu=False,
+        training_config=rl_cluster_lib.RLTrainingConfig(
+            actor_optimizer=optax.sgd(1e-3),
+            eval_every_n_steps=1,
+        ),
+        rollout_config=base_rollout.RolloutConfig(),
+    )
+
+    # A dummy model without config
+    class DummyModel(nnx.Module):
+
+      def __init__(self):
+        self.w = nnx.Param(jnp.zeros((1,)))
+
+    with self.assertRaisesRegex(
+        ValueError, '`self.rollout_actor` must have a config attribute.'
+    ):
+      rl_cluster_lib.RLCluster(
+          actor=DummyModel(),
+          tokenizer=tc.MockVocab(),
+          cluster_config=cluster_config,
+      )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='single_config',
+          rollout_config=base_rollout.RolloutConfig(
+              max_tokens_to_generate=10,
+              kv_cache_size=1024,
+              rollout_vllm_model_version='dummy_version',
+          ),
+          expected_train_config=base_rollout.RolloutConfig(
+              max_tokens_to_generate=10,
+              kv_cache_size=1024,
+              rollout_vllm_model_version='dummy_version',
+          ),
+          expected_cache_size=1024,
+      ),
+      dict(
+          testcase_name='dict_config',
+          rollout_config={
+              rl_cluster_lib.Mode.TRAIN: base_rollout.RolloutConfig(
+                  max_tokens_to_generate=10,
+                  kv_cache_size=1024,
+                  rollout_vllm_model_version='dummy_version',
+              ),
+              rl_cluster_lib.Mode.EVAL: base_rollout.RolloutConfig(
+                  max_tokens_to_generate=20,
+                  kv_cache_size=2048,
+                  rollout_vllm_model_version='dummy_version',
+              ),
+          },
+          expected_train_config=base_rollout.RolloutConfig(
+              max_tokens_to_generate=10,
+              kv_cache_size=1024,
+              rollout_vllm_model_version='dummy_version',
+          ),
+          expected_cache_size=2048,
+      ),
+  )
+  @mock.patch.object(rl_cluster_lib.vllm_rollout, 'VllmRollout', autospec=True)
+  def test_init_vllm_rollout_engine(
+      self,
+      mock_vllm_cls,
+      rollout_config,
+      expected_train_config,
+      expected_cache_size,
+  ):
+    rl_cluster = self._create_test_rl_cluster('vllm', rollout_config)
+
+    mock_vllm_cls.assert_called_once()
+    self.assertEqual(rl_cluster.rollout, mock_vllm_cls.return_value)
+    called_kwargs = mock_vllm_cls.call_args.kwargs
+    self.assertEqual(called_kwargs['rollout_config'], expected_train_config)
+    self.assertEqual(called_kwargs['cache_config_or_size'], expected_cache_size)
+    self.assertIn('mesh', called_kwargs)
+
+  def test_init_vllm_rollout_engine_missing_version_raises(self):
+    rollout_config = base_rollout.RolloutConfig(
+        rollout_vllm_model_version=None,
+    )
+    with self.assertRaisesRegex(
+        ValueError, 'Rollout vllm model version or path is missing!'
+    ):
+      self._create_test_rl_cluster('vllm', rollout_config)
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='single_config',
+          rollout_config=base_rollout.RolloutConfig(
+              max_tokens_to_generate=10, kv_cache_size=1024
+          ),
+          expected_train_config=base_rollout.RolloutConfig(
+              max_tokens_to_generate=10, kv_cache_size=1024
+          ),
+      ),
+      dict(
+          testcase_name='dict_config',
+          rollout_config={
+              rl_cluster_lib.Mode.TRAIN: base_rollout.RolloutConfig(
+                  max_tokens_to_generate=10, kv_cache_size=1024
+              ),
+              rl_cluster_lib.Mode.EVAL: base_rollout.RolloutConfig(
+                  max_tokens_to_generate=20, kv_cache_size=2048
+              ),
+          },
+          expected_train_config=base_rollout.RolloutConfig(
+              max_tokens_to_generate=10, kv_cache_size=1024
+          ),
+      ),
+  )
+  @mock.patch.object(
+      rl_cluster_lib.sglang_jax_rollout, 'SglangJaxRollout', autospec=True
+  )
+  def test_init_sglang_jax_rollout_engine(
+      self, mock_sglang_cls, rollout_config, expected_train_config
+  ):
+    rl_cluster = self._create_test_rl_cluster('sglang_jax', rollout_config)
+
+    mock_sglang_cls.assert_called_once()
+    self.assertEqual(rl_cluster.rollout, mock_sglang_cls.return_value)
+    called_kwargs = mock_sglang_cls.call_args.kwargs
+    self.assertEqual(called_kwargs['rollout_config'], expected_train_config)
+    self.assertIn('mesh', called_kwargs)
+
+  @mock.patch.object(rl_cluster_lib.sft_utils, 'is_lora_enabled', autospec=True)
+  def test_init_sglang_jax_rollout_engine_lora_error(self, mock_is_lora):
+    mock_is_lora.return_value = True
+    rollout_config = base_rollout.RolloutConfig(
+        rollout_sglang_jax_enable_static_lora=False
+    )
+
+    with self.assertRaisesRegex(
+        ValueError, 'Rollout sglang jax lora config is missing'
+    ):
+      self._create_test_rl_cluster('sglang_jax', rollout_config)
+
+  def test_init_cluster_unsupported_engine_type(self):
+    class InvalidEngine:
+      pass
+
+    with self.assertRaisesRegex(
+        NotImplementedError, 'Rollout engine .* not supported'
+    ):
+      self._create_test_rl_cluster(InvalidEngine, base_rollout.RolloutConfig())
+
   def test_user_defined_rollout_engine_class(self):
     class CustomRolloutEngine(base_rollout.BaseRollout):

@@ -363,6 +628,13 @@ def model(self) -> nnx.Module:
       def update_params(self, params, filter_types):
         pass

+      @property
+      def mesh(self):
+        return Mesh(
+            np.array(jax.devices()[:1]).reshape(1, 1),
+            ('fsdp', 'tp'),
+        )
+
     split_index = self.device_count // 2

     actor_mesh = Mesh(
@@ -443,6 +715,9 @@ def create_cluster_config(rollout_engine):
     self.assertIsInstance(rl_cluster.rollout, CustomRolloutEngine)
     self.assertEqual(rl_cluster.rollout.my_arg, 0)
     self.assertEqual(rl_cluster.rollout.config, cluster_config.rollout_config)
+    self.assertEqual(
+        rl_cluster.r2m[rl_cluster_lib.Role.ROLLOUT], rl_cluster.rollout.mesh
+    )

   @parameterized.named_parameters(
       dict(

tunix/rl/rl_cluster.py

Lines changed: 21 additions & 17 deletions
@@ -206,6 +206,16 @@ def __init__(
     self._default_memory_kind = jax.devices()[0].default_memory().kind
     self.train_actor = self._load_model(actor, self.r2m[Role.ACTOR])

+    if self.cluster_config.rollout_config is None:
+      raise ValueError("`cluster_config.rollout_config` cannot be None.")
+    if isinstance(
+        self.cluster_config.rollout_config, dict
+    ) and not self.cluster_config.rollout_config.get(Mode.TRAIN):
+      raise ValueError(
+          "Rollout config is a dict but missing a train config. Provided"
+          f" config: {self.cluster_config.rollout_config}"
+      )
+
     if Role.ROLLOUT in self._backbone_sharing_map[Role.ACTOR]:
       self.rollout_actor = self.train_actor
     elif self.cluster_config.rollout_engine == "vanilla":
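For illustration only, not part of this commit: a minimal sketch of the rollout_config values the new __init__ checks accept or reject. The identifiers follow the diff and tests above (base_rollout.RolloutConfig, the Mode enum); exact import paths are an assumption.

# Hypothetical usage sketch; assumes base_rollout and Mode are imported as in the files above.
ok_single = base_rollout.RolloutConfig(kv_cache_size=1024)  # accepted: single config
ok_dict = {
    Mode.TRAIN: base_rollout.RolloutConfig(kv_cache_size=1024),
    Mode.EVAL: base_rollout.RolloutConfig(kv_cache_size=2048),
}  # accepted: dict with a TRAIN entry
bad_none = None  # rejected: "`cluster_config.rollout_config` cannot be None."
bad_dict = {Mode.EVAL: base_rollout.RolloutConfig()}
# rejected: "Rollout config is a dict but missing a train config."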
@@ -368,10 +378,14 @@ def _init_cluster(self):
           " `'vllm'` or `'sglang_jax'`. Received:"
           f" '{self.cluster_config.rollout_engine}'."
       )
+
     if isinstance(self.cluster_config.rollout_config, dict):
+      # train_cfg should always be provided.
+      train_cfg = self.cluster_config.rollout_config[Mode.TRAIN]
+      eval_cfg = self.cluster_config.rollout_config.get(Mode.EVAL)
       max_kv_cache_size = max(
-          self.cluster_config.rollout_config[Mode.TRAIN].kv_cache_size,
-          self.cluster_config.rollout_config[Mode.EVAL].kv_cache_size,
+          train_cfg.kv_cache_size,
+          eval_cfg.kv_cache_size if eval_cfg is not None else 0,
       )
     else:
       max_kv_cache_size = self.cluster_config.rollout_config.kv_cache_size
@@ -396,16 +410,10 @@ def _init_cluster(self):
     elif self.cluster_config.rollout_engine == "vllm":
       from tunix.rl.rollout import vllm_rollout

-      loaded_vllm_config = None
-      if isinstance(
-          self.cluster_config.rollout_config, base_rollout.RolloutConfig
-      ):
-        loaded_vllm_config = self.cluster_config.rollout_config
-      elif isinstance(self.cluster_config.rollout_config, dict):
+      if isinstance(self.cluster_config.rollout_config, dict):
         loaded_vllm_config = self.cluster_config.rollout_config[Mode.TRAIN]
-
-      if loaded_vllm_config is None:
-        raise ValueError("Rollout vllm model config is missing!")
+      else:
+        loaded_vllm_config = self.cluster_config.rollout_config

       if loaded_vllm_config.rollout_vllm_model_version is None:
         raise ValueError("Rollout vllm model version or path is missing!")
@@ -427,16 +435,12 @@ def _init_cluster(self):
     elif self.cluster_config.rollout_engine == "sglang_jax":
       from tunix.rl.rollout import sglang_jax_rollout

-      if isinstance(
-          self.cluster_config.rollout_config, base_rollout.RolloutConfig
-      ):
-        loaded_sglang_jax_config = self.cluster_config.rollout_config
-      elif isinstance(self.cluster_config.rollout_config, dict):
+      if isinstance(self.cluster_config.rollout_config, dict):
         loaded_sglang_jax_config = self.cluster_config.rollout_config[
             Mode.TRAIN
         ]
       else:
-        raise ValueError("Rollout sglang jax model config is missing!")
+        loaded_sglang_jax_config = self.cluster_config.rollout_config

       if (
           sft_utils.is_lora_enabled(self.rollout_actor)

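Both the vllm and sglang_jax branches above now resolve the config the same way: take the TRAIN entry when a dict is supplied, otherwise use the single RolloutConfig directly. A hypothetical helper, not present in this commit and shown only to restate that shared pattern (Mode is the enum used in the diff), could read:

def _resolve_train_config(rollout_config):
  # Dict-style configs are keyed by Mode; the TRAIN entry drives engine setup.
  if isinstance(rollout_config, dict):
    return rollout_config[Mode.TRAIN]
  # A single RolloutConfig is used as-is.
  return rollout_config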