|
48 | 48 | import collections |
49 | 49 | import grain |
50 | 50 | import jax |
| 51 | +import json |
51 | 52 | import os |
52 | 53 | import pathwaysutils |
53 | 54 | import tensorflow_datasets as tfds |
@@ -92,9 +93,18 @@ def get_maxtext_model(config, devices=None): |
92 | 93 | # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e., |
93 | 94 | # load_parameters_path=/path/to/your/output/directory/0/items |
94 | 95 | """ |
95 | | - model, mesh = model_creation_utils.create_nnx_model(config, devices=devices) |
| 96 | + model, mesh = model_creation_utils.create_nnx_model(config, devices=devices, model_mode="train") |
96 | 97 | with mesh: |
97 | | - tunix_model = TunixMaxTextAdapter(base_model=model) |
| 98 | + if "maxtext_config" in config.vllm_additional_config: |
| 99 | + use_standalone_mappings = False |
| 100 | + use_no_op_mappings = True |
| 101 | + else: |
| 102 | + use_standalone_mappings = True |
| 103 | + use_no_op_mappings = False |
| 104 | + |
| 105 | + tunix_model = TunixMaxTextAdapter( |
| 106 | + base_model=model, use_standalone_mappings=use_standalone_mappings, use_no_op_mappings=use_no_op_mappings |
| 107 | + ) |
98 | 108 | tunix_model.config = None |
99 | 109 | return tunix_model, mesh |
100 | 110 |
|
@@ -323,6 +333,21 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): |
323 | 333 | set_profile_options=False, |
324 | 334 | ) |
325 | 335 |
|
| 336 | + # Parse vllm_additional_config |
| 337 | + rollout_additional_config = None |
| 338 | + if trainer_config.vllm_additional_config: |
| 339 | + if isinstance(trainer_config.vllm_additional_config, dict): |
| 340 | + # It's already parsed into a dict |
| 341 | + rollout_additional_config = trainer_config.vllm_additional_config |
| 342 | + elif isinstance(trainer_config.vllm_additional_config, str): |
| 343 | + # It's a string, so we need to parse it |
| 344 | + try: |
| 345 | + rollout_additional_config = json.loads(trainer_config.vllm_additional_config) |
| 346 | + except json.JSONDecodeError as e: |
| 347 | + raise ValueError(f"Failed to parse additional_config JSON: {e}") from e |
| 348 | + |
| 349 | + max_logging.log(f"Parsed additional config: {rollout_additional_config}") |
| 350 | + |
326 | 351 | # RL Cluster config |
327 | 352 | # Note that we use vLLM as the rollout engine. |
328 | 353 | # and we are using Tensor Parallelism for rollout |
@@ -361,6 +386,9 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): |
361 | 386 | rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm, |
362 | 387 | rollout_vllm_tpu_backend_type="jax", |
363 | 388 | rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb, |
| 389 | + rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path, |
| 390 | + rollout_vllm_additional_config=rollout_additional_config, |
| 391 | + rollout_vllm_init_with_random_weights=False, |
364 | 392 | ), |
365 | 393 | ) |
366 | 394 | grpo_config = GrpoConfig( |
@@ -389,14 +417,14 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): |
389 | 417 | max_logging.log( |
390 | 418 | "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics." |
391 | 419 | ) |
392 | | - with nn_partitioning.axis_rules(trainer_config.logical_axis_rules): |
393 | | - rl_cluster = rl_cluster_lib.RLCluster( |
394 | | - actor=actor_model, |
395 | | - reference=reference_model, |
396 | | - tokenizer=model_tokenizer, |
397 | | - cluster_config=cluster_config, |
398 | | - **rl_cluster_kwargs, |
399 | | - ) |
| 420 | + |
| 421 | + rl_cluster = rl_cluster_lib.RLCluster( |
| 422 | + actor=actor_model, |
| 423 | + reference=reference_model, |
| 424 | + tokenizer=model_tokenizer, |
| 425 | + cluster_config=cluster_config, |
| 426 | + **rl_cluster_kwargs, |
| 427 | + ) |
400 | 428 |
|
401 | 429 | # Create RL trainer |
402 | 430 | max_logging.log("Setting up RL trainer...") |
|
0 commit comments