275 changes: 275 additions & 0 deletions tests/rl/rl_cluster_test.py
@@ -27,6 +27,8 @@
import optax
from transformers import tokenization_utils_base
from tunix.generate import mappings
# Internal placeholder for sglang_jax rollout worker stub, don't change this line.
# Internal placeholder for vllm rollout worker stub, don't change this line.
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl import utils
from tunix.rl.rollout import base_rollout
@@ -322,6 +324,269 @@ def test_generate_with_chat_template(self): # pylint: disable=g-doc-args
called_prompts = rl_cluster.rollout.generate.call_args[0][0]
self.assertEqual(called_prompts, ['formatted prompt'])

  def _create_test_rl_cluster(
      self,
      rollout_engine,  # Usually a str; some tests deliberately pass other types.
      rollout_config,  # A RolloutConfig, a per-mode dict, or None.
  ) -> rl_cluster_lib.RLCluster:
split_index = self.device_count // 2
actor_mesh = Mesh(
np.array(jax.devices()[:split_index]).reshape(split_index, 1),
('fsdp', 'tp'),
)
rollout_mesh = Mesh(
np.array(jax.devices()[split_index:]).reshape(1, split_index),
('fsdp', 'tp'),
)
cluster_config = rl_cluster_lib.ClusterConfig(
role_to_mesh={
rl_cluster_lib.Role.ACTOR: actor_mesh,
rl_cluster_lib.Role.REFERENCE: actor_mesh,
rl_cluster_lib.Role.ROLLOUT: rollout_mesh,
},
rollout_engine=rollout_engine,
offload_to_cpu=False,
training_config=rl_cluster_lib.RLTrainingConfig(
actor_optimizer=optax.sgd(1e-3),
eval_every_n_steps=1,
max_steps=10,
gradient_accumulation_steps=None,
),
rollout_config=rollout_config,
)
vocab = tc.MockVocab()
model = tc.ToyTransformer(
config=tc.ModelConfig(vocab_size=vocab.GetPieceSize()), rngs=nnx.Rngs(0)
)
return rl_cluster_lib.RLCluster(
actor=model, tokenizer=vocab, cluster_config=cluster_config
)

def test_init_cluster_invalid_engine_string(self):
with self.assertRaisesRegex(
ValueError, '`cluster_config.rollout_engine` should be one of'
):
self._create_test_rl_cluster(
'invalid_engine', base_rollout.RolloutConfig()
)

@parameterized.parameters('vanilla', 'vllm', 'sglang_jax')
def test_init_rollout_engine_missing_config_raises_error(self, engine):
with self.assertRaisesRegex(
ValueError, '`cluster_config.rollout_config` cannot be None.'
):
self._create_test_rl_cluster(engine, None)

@parameterized.parameters('vanilla', 'vllm', 'sglang_jax')
def test_init_rollout_engine_empty_dict_config_raises_error(self, engine):
with self.assertRaisesRegex(
ValueError,
'Rollout config is a dict but missing a train config.',
):
self._create_test_rl_cluster(engine, {})

@parameterized.named_parameters(
dict(
testcase_name='single_config',
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
kv_cache_size=1024,
data_type=jnp.bfloat16,
),
expected_cache_size=1024,
),
dict(
testcase_name='dict_config',
rollout_config={
rl_cluster_lib.Mode.TRAIN: base_rollout.RolloutConfig(
max_tokens_to_generate=10,
kv_cache_size=1024,
data_type=jnp.bfloat16,
),
rl_cluster_lib.Mode.EVAL: base_rollout.RolloutConfig(
max_tokens_to_generate=10,
kv_cache_size=2048,
data_type=jnp.bfloat16,
),
},
expected_cache_size=2048,
),
)
@mock.patch.object(
rl_cluster_lib.vanilla_rollout, 'VanillaRollout', autospec=True
)
def test_init_vanilla_rollout_engine(
self, mock_vanilla_cls, rollout_config, expected_cache_size
):
rl_cluster = self._create_test_rl_cluster('vanilla', rollout_config)

mock_vanilla_cls.assert_called_once()
self.assertEqual(rl_cluster.rollout, mock_vanilla_cls.return_value)
called_kwargs = mock_vanilla_cls.call_args.kwargs
self.assertIsInstance(
called_kwargs['cache_config_or_size'], base_rollout.CacheConfig
)
self.assertEqual(
called_kwargs['cache_config_or_size'].cache_size, expected_cache_size
)

def test_init_vanilla_rollout_engine_missing_model_config(self):
split_index = self.device_count // 2
actor_mesh = Mesh(
np.array(jax.devices()[:split_index]).reshape(split_index, 1),
('fsdp', 'tp'),
)
cluster_config = rl_cluster_lib.ClusterConfig(
role_to_mesh={
rl_cluster_lib.Role.ACTOR: actor_mesh,
rl_cluster_lib.Role.REFERENCE: actor_mesh,
rl_cluster_lib.Role.ROLLOUT: actor_mesh,
},
rollout_engine='vanilla',
offload_to_cpu=False,
training_config=rl_cluster_lib.RLTrainingConfig(
actor_optimizer=optax.sgd(1e-3),
eval_every_n_steps=1,
),
rollout_config=base_rollout.RolloutConfig(),
)

    # A dummy model without a `config` attribute.
class DummyModel(nnx.Module):

def __init__(self):
self.w = nnx.Param(jnp.zeros((1,)))

with self.assertRaisesRegex(
ValueError, '`self.rollout_actor` must have a config attribute.'
):
rl_cluster_lib.RLCluster(
actor=DummyModel(),
tokenizer=tc.MockVocab(),
cluster_config=cluster_config,
)

@parameterized.named_parameters(
dict(
testcase_name='single_config',
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
kv_cache_size=1024,
rollout_vllm_model_version='dummy_version',
),
expected_train_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
kv_cache_size=1024,
rollout_vllm_model_version='dummy_version',
),
expected_cache_size=1024,
),
dict(
testcase_name='dict_config',
rollout_config={
rl_cluster_lib.Mode.TRAIN: base_rollout.RolloutConfig(
max_tokens_to_generate=10,
kv_cache_size=1024,
rollout_vllm_model_version='dummy_version',
),
rl_cluster_lib.Mode.EVAL: base_rollout.RolloutConfig(
max_tokens_to_generate=20,
kv_cache_size=2048,
rollout_vllm_model_version='dummy_version',
),
},
expected_train_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
kv_cache_size=1024,
rollout_vllm_model_version='dummy_version',
),
expected_cache_size=2048,
),
)
@mock.patch.object(vllm_rollout, 'VllmRollout', autospec=True)
def test_init_vllm_rollout_engine(
self,
mock_vllm_cls,
rollout_config,
expected_train_config,
expected_cache_size,
):
rl_cluster = self._create_test_rl_cluster('vllm', rollout_config)

mock_vllm_cls.assert_called_once()
self.assertEqual(rl_cluster.rollout, mock_vllm_cls.return_value)
called_kwargs = mock_vllm_cls.call_args.kwargs
self.assertEqual(called_kwargs['rollout_config'], expected_train_config)
self.assertEqual(called_kwargs['cache_config_or_size'], expected_cache_size)
self.assertIn('mesh', called_kwargs)

def test_init_vllm_rollout_engine_missing_version_raises(self):
rollout_config = base_rollout.RolloutConfig(
rollout_vllm_model_version=None,
)
with self.assertRaisesRegex(
ValueError, 'Rollout vllm model version or path is missing!'
):
self._create_test_rl_cluster('vllm', rollout_config)

@parameterized.named_parameters(
dict(
testcase_name='single_config',
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10, kv_cache_size=1024
),
expected_train_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10, kv_cache_size=1024
),
),
dict(
testcase_name='dict_config',
rollout_config={
rl_cluster_lib.Mode.TRAIN: base_rollout.RolloutConfig(
max_tokens_to_generate=10, kv_cache_size=1024
),
rl_cluster_lib.Mode.EVAL: base_rollout.RolloutConfig(
max_tokens_to_generate=20, kv_cache_size=2048
),
},
expected_train_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10, kv_cache_size=1024
),
),
)
@mock.patch.object(sglang_jax_rollout, 'SglangJaxRollout', autospec=True)
def test_init_sglang_jax_rollout_engine(
self, mock_sglang_cls, rollout_config, expected_train_config
):
rl_cluster = self._create_test_rl_cluster('sglang_jax', rollout_config)

mock_sglang_cls.assert_called_once()
self.assertEqual(rl_cluster.rollout, mock_sglang_cls.return_value)
called_kwargs = mock_sglang_cls.call_args.kwargs
self.assertEqual(called_kwargs['rollout_config'], expected_train_config)
self.assertIn('mesh', called_kwargs)

@mock.patch.object(rl_cluster_lib.sft_utils, 'is_lora_enabled', autospec=True)
def test_init_sglang_jax_rollout_engine_lora_error(self, mock_is_lora):
mock_is_lora.return_value = True
rollout_config = base_rollout.RolloutConfig(
rollout_sglang_jax_enable_static_lora=False
)

with self.assertRaisesRegex(
ValueError, 'Rollout sglang jax lora config is missing'
):
self._create_test_rl_cluster('sglang_jax', rollout_config)

def test_init_cluster_unsupported_engine_type(self):
class InvalidEngine:
pass

with self.assertRaisesRegex(
NotImplementedError, 'Rollout engine .* not supported'
):
self._create_test_rl_cluster(InvalidEngine, base_rollout.RolloutConfig())

def test_user_defined_rollout_engine_class(self):
class CustomRolloutEngine(base_rollout.BaseRollout):

@@ -363,6 +628,13 @@ def model(self) -> nnx.Module:
def update_params(self, params, filter_types):
pass

@property
def mesh(self):
return Mesh(
np.array(jax.devices()[:1]).reshape(1, 1),
('fsdp', 'tp'),
)

split_index = self.device_count // 2

actor_mesh = Mesh(
@@ -443,6 +715,9 @@ def create_cluster_config(rollout_engine):
self.assertIsInstance(rl_cluster.rollout, CustomRolloutEngine)
self.assertEqual(rl_cluster.rollout.my_arg, 0)
self.assertEqual(rl_cluster.rollout.config, cluster_config.rollout_config)
self.assertEqual(
rl_cluster.r2m[rl_cluster_lib.Role.ROLLOUT], rl_cluster.rollout.mesh
)

@parameterized.named_parameters(
dict(
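The tests above pin down the two accepted shapes of `cluster_config.rollout_config`. As a reading aid, here is a minimal sketch of the per-mode dict form, assembled from the field names the tests use (the values are illustrative, not taken from the PR):

import jax.numpy as jnp
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.rollout import base_rollout

rollout_config = {
    # Mode.TRAIN is mandatory; RLCluster.__init__ raises ValueError without it.
    rl_cluster_lib.Mode.TRAIN: base_rollout.RolloutConfig(
        max_tokens_to_generate=10,
        kv_cache_size=1024,
        data_type=jnp.bfloat16,
    ),
    # Mode.EVAL is optional; the KV cache is sized to max(train, eval).
    rl_cluster_lib.Mode.EVAL: base_rollout.RolloutConfig(
        max_tokens_to_generate=20,
        kv_cache_size=2048,
        data_type=jnp.bfloat16,
    ),
}

When a dict is passed, `_init_cluster` selects the TRAIN entry as the active rollout config, which is what the `dict_config` test cases above assert.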
38 changes: 21 additions & 17 deletions tunix/rl/rl_cluster.py
@@ -206,6 +206,16 @@ def __init__(
self._default_memory_kind = jax.devices()[0].default_memory().kind
self.train_actor = self._load_model(actor, self.r2m[Role.ACTOR])

if self.cluster_config.rollout_config is None:
raise ValueError("`cluster_config.rollout_config` cannot be None.")
if isinstance(
self.cluster_config.rollout_config, dict
) and not self.cluster_config.rollout_config.get(Mode.TRAIN):
raise ValueError(
"Rollout config is a dict but missing a train config. Provided"
f" config: {self.cluster_config.rollout_config}"
)

if Role.ROLLOUT in self._backbone_sharing_map[Role.ACTOR]:
self.rollout_actor = self.train_actor
elif self.cluster_config.rollout_engine == "vanilla":
@@ -368,10 +378,14 @@ def _init_cluster(self):
" `'vllm'` or `'sglang_jax'`. Received:"
f" '{self.cluster_config.rollout_engine}'."
)

if isinstance(self.cluster_config.rollout_config, dict):
      # A TRAIN config is always present here; __init__ validates this.
train_cfg = self.cluster_config.rollout_config[Mode.TRAIN]
eval_cfg = self.cluster_config.rollout_config.get(Mode.EVAL)
max_kv_cache_size = max(
self.cluster_config.rollout_config[Mode.TRAIN].kv_cache_size,
self.cluster_config.rollout_config[Mode.EVAL].kv_cache_size,
train_cfg.kv_cache_size,
eval_cfg.kv_cache_size if eval_cfg is not None else 0,
)
else:
max_kv_cache_size = self.cluster_config.rollout_config.kv_cache_size
@@ -396,16 +410,10 @@ def _init_cluster(self):
elif self.cluster_config.rollout_engine == "vllm":
from tunix.rl.rollout import vllm_rollout

loaded_vllm_config = None
if isinstance(
self.cluster_config.rollout_config, base_rollout.RolloutConfig
):
loaded_vllm_config = self.cluster_config.rollout_config
elif isinstance(self.cluster_config.rollout_config, dict):
if isinstance(self.cluster_config.rollout_config, dict):
loaded_vllm_config = self.cluster_config.rollout_config[Mode.TRAIN]

if loaded_vllm_config is None:
raise ValueError("Rollout vllm model config is missing!")
else:
loaded_vllm_config = self.cluster_config.rollout_config

if loaded_vllm_config.rollout_vllm_model_version is None:
raise ValueError("Rollout vllm model version or path is missing!")
Expand All @@ -427,16 +435,12 @@ def _init_cluster(self):
elif self.cluster_config.rollout_engine == "sglang_jax":
from tunix.rl.rollout import sglang_jax_rollout

if isinstance(
self.cluster_config.rollout_config, base_rollout.RolloutConfig
):
loaded_sglang_jax_config = self.cluster_config.rollout_config
elif isinstance(self.cluster_config.rollout_config, dict):
if isinstance(self.cluster_config.rollout_config, dict):
loaded_sglang_jax_config = self.cluster_config.rollout_config[
Mode.TRAIN
]
else:
raise ValueError("Rollout sglang jax model config is missing!")
loaded_sglang_jax_config = self.cluster_config.rollout_config

if (
sft_utils.is_lora_enabled(self.rollout_actor)
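The branches added to `_init_cluster` above all normalize `rollout_config` the same way. A standalone sketch of that normalization, mirroring the diff (the helper name and signature are hypothetical, not part of the module):

# Hypothetical helper; mirrors the dict-vs-single handling in _init_cluster.
def resolve_rollout_config(rollout_config, train_mode, eval_mode):
  """Returns (train_cfg, max_kv_cache_size) for either config shape."""
  if isinstance(rollout_config, dict):
    # A TRAIN entry is guaranteed: __init__ raises ValueError otherwise.
    train_cfg = rollout_config[train_mode]
    eval_cfg = rollout_config.get(eval_mode)
    max_kv_cache_size = max(
        train_cfg.kv_cache_size,
        eval_cfg.kv_cache_size if eval_cfg is not None else 0,
    )
  else:
    train_cfg = rollout_config
    max_kv_cache_size = rollout_config.kv_cache_size
  return train_cfg, max_kv_cache_size

Note the behavior change relative to the replaced code: `Mode.EVAL` is now optional in the dict form, whereas the old branch indexed both `Mode.TRAIN` and `Mode.EVAL` unconditionally.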