@@ -92,9 +92,18 @@ def get_maxtext_model(config, devices=None):
9292 # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e.,
9393 # load_parameters_path=/path/to/your/output/directory/0/items
9494 """
95- model , mesh = model_creation_utils .create_nnx_model (config , devices = devices )
95+ model , mesh = model_creation_utils .create_nnx_model (config , devices = devices , model_mode = "train" )
9696 with mesh :
97- tunix_model = TunixMaxTextAdapter (base_model = model )
97+ if "maxtext_config" in config .vllm_additional_config :
98+ use_standalone_mappings = False
99+ use_no_op_mappings = True
100+ else :
101+ use_standalone_mappings = True
102+ use_no_op_mappings = False
103+
104+ tunix_model = TunixMaxTextAdapter (
105+ base_model = model , use_standalone_mappings = use_standalone_mappings , use_no_op_mappings = use_no_op_mappings
106+ )
98107 tunix_model .config = None
99108 return tunix_model , mesh
100109
@@ -323,6 +332,21 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
323332 set_profile_options = False ,
324333 )
325334
335+ # Parse vllm_additional_config
336+ rollout_additional_config = None
337+ if trainer_config .vllm_additional_config :
338+ if isinstance (trainer_config .vllm_additional_config , dict ):
339+ # It's already parsed into a dict
340+ rollout_additional_config = trainer_config .vllm_additional_config
341+ elif isinstance (trainer_config .vllm_additional_config , str ):
342+ # It's a string, so we need to parse it
343+ try :
344+ rollout_additional_config = json .loads (trainer_config .vllm_additional_config )
345+ except json .JSONDecodeError as e :
346+ raise ValueError (f"Failed to parse additional_config JSON: { e } " ) from e
347+
348+ max_logging .log (f"Parsed additional config: { rollout_additional_config } " )
349+
326350 # RL Cluster config
327351 # Note that we use vLLM as the rollout engine.
328352 # and we are using Tensor Parallelism for rollout
@@ -361,6 +385,9 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
361385 rollout_vllm_hbm_utilization = trainer_config .hbm_utilization_vllm ,
362386 rollout_vllm_tpu_backend_type = "jax" ,
363387 rollout_vllm_swap_space_size_gb = trainer_config .swap_space_vllm_gb ,
388+ rollout_vllm_hf_config_path = trainer_config .vllm_hf_config_path ,
389+ rollout_vllm_additional_config = rollout_additional_config ,
390+ rollout_vllm_init_with_random_weights = False ,
364391 ),
365392 )
366393 grpo_config = GrpoConfig (
@@ -389,14 +416,14 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
389416 max_logging .log (
390417 "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
391418 )
392- with nn_partitioning . axis_rules ( trainer_config . logical_axis_rules ):
393- rl_cluster = rl_cluster_lib .RLCluster (
394- actor = actor_model ,
395- reference = reference_model ,
396- tokenizer = model_tokenizer ,
397- cluster_config = cluster_config ,
398- ** rl_cluster_kwargs ,
399- )
419+
420+ rl_cluster = rl_cluster_lib .RLCluster (
421+ actor = actor_model ,
422+ reference = reference_model ,
423+ tokenizer = model_tokenizer ,
424+ cluster_config = cluster_config ,
425+ ** rl_cluster_kwargs ,
426+ )
400427
401428 # Create RL trainer
402429 max_logging .log ("Setting up RL trainer..." )
0 commit comments