@@ -322,6 +322,271 @@ def test_generate_with_chat_template(self): # pylint: disable=g-doc-args
322322 called_prompts = rl_cluster .rollout .generate .call_args [0 ][0 ]
323323 self .assertEqual (called_prompts , ['formatted prompt' ])
324324
325+ def _create_test_rl_cluster (
326+ self ,
327+ rollout_engine : str ,
328+ rollout_config : base_rollout .RolloutConfig ,
329+ ) -> rl_cluster_lib .RLCluster :
330+ split_index = self .device_count // 2
331+ actor_mesh = Mesh (
332+ np .array (jax .devices ()[:split_index ]).reshape (split_index , 1 ),
333+ ('fsdp' , 'tp' ),
334+ )
335+ rollout_mesh = Mesh (
336+ np .array (jax .devices ()[split_index :]).reshape (1 , split_index ),
337+ ('fsdp' , 'tp' ),
338+ )
339+ cluster_config = rl_cluster_lib .ClusterConfig (
340+ role_to_mesh = {
341+ rl_cluster_lib .Role .ACTOR : actor_mesh ,
342+ rl_cluster_lib .Role .REFERENCE : actor_mesh ,
343+ rl_cluster_lib .Role .ROLLOUT : rollout_mesh ,
344+ },
345+ rollout_engine = rollout_engine ,
346+ offload_to_cpu = False ,
347+ training_config = rl_cluster_lib .RLTrainingConfig (
348+ actor_optimizer = optax .sgd (1e-3 ),
349+ eval_every_n_steps = 1 ,
350+ max_steps = 10 ,
351+ gradient_accumulation_steps = None ,
352+ ),
353+ rollout_config = rollout_config ,
354+ )
355+ vocab = tc .MockVocab ()
356+ model = tc .ToyTransformer (
357+ config = tc .ModelConfig (vocab_size = vocab .GetPieceSize ()), rngs = nnx .Rngs (0 )
358+ )
359+ return rl_cluster_lib .RLCluster (
360+ actor = model , tokenizer = vocab , cluster_config = cluster_config
361+ )
362+
363+ def test_init_cluster_invalid_engine_string (self ):
364+ with self .assertRaisesRegex (
365+ ValueError , '`cluster_config.rollout_engine` should be one of'
366+ ):
367+ self ._create_test_rl_cluster (
368+ 'invalid_engine' , base_rollout .RolloutConfig ()
369+ )
370+
371+ @parameterized .parameters ('vanilla' , 'vllm' , 'sglang_jax' )
372+ def test_init_rollout_engine_missing_config_raises_error (self , engine ):
373+ with self .assertRaisesRegex (
374+ ValueError , '`cluster_config.rollout_config` cannot be None.'
375+ ):
376+ self ._create_test_rl_cluster (engine , None )
377+
378+ @parameterized .parameters ('vanilla' , 'vllm' , 'sglang_jax' )
379+ def test_init_rollout_engine_empty_dict_config_raises_error (self , engine ):
380+ with self .assertRaisesRegex (
381+ ValueError ,
382+ 'Rollout config is a dict but missing a train config.' ,
383+ ):
384+ self ._create_test_rl_cluster (engine , {})
385+
386+ @parameterized .named_parameters (
387+ dict (
388+ testcase_name = 'single_config' ,
389+ rollout_config = base_rollout .RolloutConfig (
390+ max_tokens_to_generate = 10 ,
391+ kv_cache_size = 1024 ,
392+ data_type = jnp .bfloat16 ,
393+ ),
394+ expected_cache_size = 1024 ,
395+ ),
396+ dict (
397+ testcase_name = 'dict_config' ,
398+ rollout_config = {
399+ rl_cluster_lib .Mode .TRAIN : base_rollout .RolloutConfig (
400+ max_tokens_to_generate = 10 ,
401+ kv_cache_size = 1024 ,
402+ data_type = jnp .bfloat16 ,
403+ ),
404+ rl_cluster_lib .Mode .EVAL : base_rollout .RolloutConfig (
405+ max_tokens_to_generate = 10 ,
406+ kv_cache_size = 2048 ,
407+ data_type = jnp .bfloat16 ,
408+ ),
409+ },
410+ expected_cache_size = 2048 ,
411+ ),
412+ )
413+ @mock .patch .object (
414+ rl_cluster_lib .vanilla_rollout , 'VanillaRollout' , autospec = True
415+ )
416+ def test_init_vanilla_rollout_engine (
417+ self , mock_vanilla_cls , rollout_config , expected_cache_size
418+ ):
419+ rl_cluster = self ._create_test_rl_cluster ('vanilla' , rollout_config )
420+
421+ mock_vanilla_cls .assert_called_once ()
422+ self .assertEqual (rl_cluster .rollout , mock_vanilla_cls .return_value )
423+ called_kwargs = mock_vanilla_cls .call_args .kwargs
424+ self .assertIsInstance (
425+ called_kwargs ['cache_config_or_size' ], base_rollout .CacheConfig
426+ )
427+ self .assertEqual (
428+ called_kwargs ['cache_config_or_size' ].cache_size , expected_cache_size
429+ )
430+
431+ def test_init_vanilla_rollout_engine_missing_model_config (self ):
432+ split_index = self .device_count // 2
433+ actor_mesh = Mesh (
434+ np .array (jax .devices ()[:split_index ]).reshape (split_index , 1 ),
435+ ('fsdp' , 'tp' ),
436+ )
437+ cluster_config = rl_cluster_lib .ClusterConfig (
438+ role_to_mesh = {
439+ rl_cluster_lib .Role .ACTOR : actor_mesh ,
440+ rl_cluster_lib .Role .REFERENCE : actor_mesh ,
441+ rl_cluster_lib .Role .ROLLOUT : actor_mesh ,
442+ },
443+ rollout_engine = 'vanilla' ,
444+ offload_to_cpu = False ,
445+ training_config = rl_cluster_lib .RLTrainingConfig (
446+ actor_optimizer = optax .sgd (1e-3 ),
447+ eval_every_n_steps = 1 ,
448+ ),
449+ rollout_config = base_rollout .RolloutConfig (),
450+ )
451+
452+ # A dummy model without config
453+ class DummyModel (nnx .Module ):
454+
455+ def __init__ (self ):
456+ self .w = nnx .Param (jnp .zeros ((1 ,)))
457+
458+ with self .assertRaisesRegex (
459+ ValueError , '`self.rollout_actor` must have a config attribute.'
460+ ):
461+ rl_cluster_lib .RLCluster (
462+ actor = DummyModel (),
463+ tokenizer = tc .MockVocab (),
464+ cluster_config = cluster_config ,
465+ )
466+
467+ @parameterized .named_parameters (
468+ dict (
469+ testcase_name = 'single_config' ,
470+ rollout_config = base_rollout .RolloutConfig (
471+ max_tokens_to_generate = 10 ,
472+ kv_cache_size = 1024 ,
473+ rollout_vllm_model_version = 'dummy_version' ,
474+ ),
475+ expected_train_config = base_rollout .RolloutConfig (
476+ max_tokens_to_generate = 10 ,
477+ kv_cache_size = 1024 ,
478+ rollout_vllm_model_version = 'dummy_version' ,
479+ ),
480+ expected_cache_size = 1024 ,
481+ ),
482+ dict (
483+ testcase_name = 'dict_config' ,
484+ rollout_config = {
485+ rl_cluster_lib .Mode .TRAIN : base_rollout .RolloutConfig (
486+ max_tokens_to_generate = 10 ,
487+ kv_cache_size = 1024 ,
488+ rollout_vllm_model_version = 'dummy_version' ,
489+ ),
490+ rl_cluster_lib .Mode .EVAL : base_rollout .RolloutConfig (
491+ max_tokens_to_generate = 20 ,
492+ kv_cache_size = 2048 ,
493+ rollout_vllm_model_version = 'dummy_version' ,
494+ ),
495+ },
496+ expected_train_config = base_rollout .RolloutConfig (
497+ max_tokens_to_generate = 10 ,
498+ kv_cache_size = 1024 ,
499+ rollout_vllm_model_version = 'dummy_version' ,
500+ ),
501+ expected_cache_size = 2048 ,
502+ ),
503+ )
504+ @mock .patch .object (rl_cluster_lib .vllm_rollout , 'VllmRollout' , autospec = True )
505+ def test_init_vllm_rollout_engine (
506+ self ,
507+ mock_vllm_cls ,
508+ rollout_config ,
509+ expected_train_config ,
510+ expected_cache_size ,
511+ ):
512+ rl_cluster = self ._create_test_rl_cluster ('vllm' , rollout_config )
513+
514+ mock_vllm_cls .assert_called_once ()
515+ self .assertEqual (rl_cluster .rollout , mock_vllm_cls .return_value )
516+ called_kwargs = mock_vllm_cls .call_args .kwargs
517+ self .assertEqual (called_kwargs ['rollout_config' ], expected_train_config )
518+ self .assertEqual (called_kwargs ['cache_config_or_size' ], expected_cache_size )
519+ self .assertIn ('mesh' , called_kwargs )
520+
521+ def test_init_vllm_rollout_engine_missing_version_raises (self ):
522+ rollout_config = base_rollout .RolloutConfig (
523+ rollout_vllm_model_version = None ,
524+ )
525+ with self .assertRaisesRegex (
526+ ValueError , 'Rollout vllm model version or path is missing!'
527+ ):
528+ self ._create_test_rl_cluster ('vllm' , rollout_config )
529+
530+ @parameterized .named_parameters (
531+ dict (
532+ testcase_name = 'single_config' ,
533+ rollout_config = base_rollout .RolloutConfig (
534+ max_tokens_to_generate = 10 , kv_cache_size = 1024
535+ ),
536+ expected_train_config = base_rollout .RolloutConfig (
537+ max_tokens_to_generate = 10 , kv_cache_size = 1024
538+ ),
539+ ),
540+ dict (
541+ testcase_name = 'dict_config' ,
542+ rollout_config = {
543+ rl_cluster_lib .Mode .TRAIN : base_rollout .RolloutConfig (
544+ max_tokens_to_generate = 10 , kv_cache_size = 1024
545+ ),
546+ rl_cluster_lib .Mode .EVAL : base_rollout .RolloutConfig (
547+ max_tokens_to_generate = 20 , kv_cache_size = 2048
548+ ),
549+ },
550+ expected_train_config = base_rollout .RolloutConfig (
551+ max_tokens_to_generate = 10 , kv_cache_size = 1024
552+ ),
553+ ),
554+ )
555+ @mock .patch .object (
556+ rl_cluster_lib .sglang_jax_rollout , 'SglangJaxRollout' , autospec = True
557+ )
558+ def test_init_sglang_jax_rollout_engine (
559+ self , mock_sglang_cls , rollout_config , expected_train_config
560+ ):
561+ rl_cluster = self ._create_test_rl_cluster ('sglang_jax' , rollout_config )
562+
563+ mock_sglang_cls .assert_called_once ()
564+ self .assertEqual (rl_cluster .rollout , mock_sglang_cls .return_value )
565+ called_kwargs = mock_sglang_cls .call_args .kwargs
566+ self .assertEqual (called_kwargs ['rollout_config' ], expected_train_config )
567+ self .assertIn ('mesh' , called_kwargs )
568+
569+ @mock .patch .object (rl_cluster_lib .sft_utils , 'is_lora_enabled' , autospec = True )
570+ def test_init_sglang_jax_rollout_engine_lora_error (self , mock_is_lora ):
571+ mock_is_lora .return_value = True
572+ rollout_config = base_rollout .RolloutConfig (
573+ rollout_sglang_jax_enable_static_lora = False
574+ )
575+
576+ with self .assertRaisesRegex (
577+ ValueError , 'Rollout sglang jax lora config is missing'
578+ ):
579+ self ._create_test_rl_cluster ('sglang_jax' , rollout_config )
580+
581+ def test_init_cluster_unsupported_engine_type (self ):
582+ class InvalidEngine :
583+ pass
584+
585+ with self .assertRaisesRegex (
586+ NotImplementedError , 'Rollout engine .* not supported'
587+ ):
588+ self ._create_test_rl_cluster (InvalidEngine , base_rollout .RolloutConfig ())
589+
325590 def test_user_defined_rollout_engine_class (self ):
326591 class CustomRolloutEngine (base_rollout .BaseRollout ):
327592
@@ -363,6 +628,13 @@ def model(self) -> nnx.Module:
363628 def update_params (self , params , filter_types ):
364629 pass
365630
631+ @property
632+ def mesh (self ):
633+ return Mesh (
634+ np .array (jax .devices ()[:1 ]).reshape (1 , 1 ),
635+ ('fsdp' , 'tp' ),
636+ )
637+
366638 split_index = self .device_count // 2
367639
368640 actor_mesh = Mesh (
@@ -443,6 +715,9 @@ def create_cluster_config(rollout_engine):
443715 self .assertIsInstance (rl_cluster .rollout , CustomRolloutEngine )
444716 self .assertEqual (rl_cluster .rollout .my_arg , 0 )
445717 self .assertEqual (rl_cluster .rollout .config , cluster_config .rollout_config )
718+ self .assertEqual (
719+ rl_cluster .r2m [rl_cluster_lib .Role .ROLLOUT ], rl_cluster .rollout .mesh
720+ )
446721
447722 @parameterized .named_parameters (
448723 dict (
0 commit comments