Commit bc49da1

gagika authored and NicoGrande committed
Support Custom MaxText model (with vLLM engine) in RL rollouts.
Fix formatting. Refactor model creation and error handling in RL training. Fix linting. Add no-op mappings to the Tunix adapter.
1 parent dc87cba commit bc49da1

File tree

8 files changed: +83 −17 lines

src/MaxText/configs/base.yml

Lines changed: 6 additions & 0 deletions

@@ -965,3 +965,9 @@ partial_rotary_factor: 1.0
 # Use tokamax library for gmm kernel implementation
 use_tokamax_gmm: false
 use_tokamax_splash: false
+
+# vLLM Adapter Configurations
+# Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter)
+vllm_hf_config_path: ""
+# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
+vllm_additional_config: {}
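
Since vllm_additional_config is documented as a JSON string at the CLI boundary, a quick standalone check that a candidate payload parses can save a failed launch. A minimal sketch in plain Python (the payload shown is a hypothetical example, not a required schema):

import json

payload = '{"maxtext_config": {"model_name": "default"}}'  # hypothetical example payload
parsed = json.loads(payload)  # raises json.JSONDecodeError on malformed input
assert "maxtext_config" in parsed  # the key train_rl.py checks for (see below)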

src/MaxText/configs/types.py

Lines changed: 2 additions & 0 deletions

@@ -1280,6 +1280,8 @@ class VLLM(BaseModel):
   kv_cache_buffer: int = Field(256, description="Buffer for KV cache.")
   hbm_utilization_vllm: float = Field(0.72, description="Target HBM utilization for vLLM.")
   swap_space_vllm_gb: int = Field(2, description="Swap space in GB for vLLM.")
+  vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
+  vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")


 class GRPO(BaseModel):
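
A standalone sketch of how the two new fields behave under Pydantic (a toy model for illustration, not MaxText's actual VLLM class): default_factory=dict gives every instance its own empty dict, avoiding the shared-mutable-default pitfall.

from typing import Any

from pydantic import BaseModel, Field


class VLLMSketch(BaseModel):
  vllm_additional_config: dict[str, Any] = Field(default_factory=dict)
  vllm_hf_config_path: str = Field("")


a, b = VLLMSketch(), VLLMSketch()
a.vllm_additional_config["maxtext_config"] = {}
assert b.vllm_additional_config == {}  # instances do not share the default dict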

src/MaxText/configs/vllm.yml

Lines changed: 9 additions & 1 deletion

@@ -41,6 +41,7 @@ logical_axis_rules: [
   ['activation_kv_batch_no_exp', ['data']],
   ['activation_kv_head_dim', ['model']],
   ['activation_vocab', ['model']],
+  ['activation_embed', ['model']],
   ['activation_exp', ['expert']],
   ['decode_batch', ['data', 'expert']],
   ['mlp', ['model']],
@@ -56,6 +57,13 @@ logical_axis_rules: [
   ['cache_heads', ['model']],
   ['exp', ['expert']],
   ['paged_kv_heads', ['model']],
+  ['autoregressive', ['model']],
+  ['tensor', ['model']],
+  ['tensor_transpose', ['model']],
+  ['fsdp', ['data']],
+  ['fsdp_transpose', ['data']],
+  ['sequence', ['model']],
+  ['context', ['model']],
 ]
 data_sharding: [['data', 'model', 'expert']]
-input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
+input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
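
For readers unfamiliar with logical_axis_rules: each entry maps a logical array-axis name to zero or more mesh axes, and flax resolves the names to a PartitionSpec at sharding time. A minimal sketch of that resolution using flax's helper (a simplified two-rule subset, not the full table above):

from flax.linen import partitioning as nn_partitioning

# Simplified subset of the rules above.
rules = (("fsdp", "data"), ("tensor", "model"))

# For an array whose dims are logically ("fsdp", "tensor"), dim 0 is sharded
# over the 'data' mesh axis and dim 1 over 'model'.
spec = nn_partitioning.logical_to_mesh_axes(("fsdp", "tensor"), rules)
print(spec)  # PartitionSpec('data', 'model')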

src/MaxText/integration/tunix/tunix_adapter.py

Lines changed: 14 additions & 0 deletions

@@ -37,6 +37,7 @@ def __init__(
       self,
       base_model: Transformer,
       use_standalone_mappings: bool = True,
+      use_no_op_mappings: bool = False,
   ):
     super().__init__()
     self.base = base_model
@@ -45,6 +46,7 @@ def __init__(
         HF_MODEL_CONFIGS[self.base.config.model_name].to_dict(),
         use_standalone_mappings,
     )
+    self.use_no_op_mappings = use_no_op_mappings

   # ------------------------------------------------------------------ #
   # Tunix call signature
@@ -69,13 +71,25 @@ def __call__(
     return logits, None

   def to_hf_mappings(self):
+    if self.use_no_op_mappings:
+      return {}
+
     return self._vllm_weight_mapping.to_hf_mapping()

   def to_hf_transpose_keys(self):
+    if self.use_no_op_mappings:
+      return {}
+
     return self._vllm_weight_mapping.to_hf_transpose_keys()

   def to_hf_hook_fns(self):
+    if self.use_no_op_mappings:
+      return {}
+
     return self._vllm_weight_mapping.to_hf_hook_fns()

   def lora_to_hf_mappings(self):
+    if self.use_no_op_mappings:
+      return {}
+
     return self._vllm_weight_mapping.lora_to_hf_mappings()
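
The four guards share one pattern: when use_no_op_mappings is set, every mapping accessor returns an empty container, so downstream weight-sync code has nothing to rename, transpose, or hook. A condensed standalone sketch of the pattern (toy class and hypothetical mapping entry, not the adapter itself):

class MappingSketch:
  """Returns real mappings normally, empty ones in no-op mode."""

  def __init__(self, use_no_op_mappings: bool = False):
    self.use_no_op_mappings = use_no_op_mappings

  def to_hf_mappings(self) -> dict:
    if self.use_no_op_mappings:
      return {}  # nothing to rename: weights are consumed as-is
    return {"decoder.embed_tokens": "model.embed_tokens"}  # hypothetical entry


assert MappingSketch(use_no_op_mappings=True).to_hf_mappings() == {}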

src/MaxText/integration/tunix/utils.py

Lines changed: 4 additions & 1 deletion

@@ -147,7 +147,10 @@ def to_hf_hook_fns(self):
       return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_hook_fns()

     model_family = self.model_name.split("-")[0]
-    return VLLM_HOOK_FNS[model_family]()
+    if model_family in VLLM_HOOK_FNS:
+      return VLLM_HOOK_FNS[model_family]()
+    else:
+      return {}

   def lora_to_hf_mappings(self):
     if self.use_standalone_mappings:
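
The membership check makes a missing model family a soft no-op instead of a KeyError. An equivalent, slightly more compact form using dict.get, shown as a standalone sketch (hypothetical helper name):

def to_hf_hook_fns_sketch(model_name: str, hook_fns: dict) -> dict:
  model_family = model_name.split("-")[0]  # e.g. "deepseek3-671b" -> "deepseek3"
  factory = hook_fns.get(model_family)
  return factory() if factory is not None else {}


assert to_hf_hook_fns_sketch("unknown-1b", {}) == {}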

src/MaxText/rl/evaluate_rl.py

Lines changed: 9 additions & 4 deletions

@@ -121,13 +121,18 @@ def score_responses(tmvp_config, question, responses, answer):

   # Check exact correctness
   try:
-    if float(extracted_response.strip()) == float(answer.strip()):
+    # Remove ',' and '$' then convert to float
+    val_extracted = float(extracted_response.replace(",", "").replace("$", "").strip())
+    val_answer = float(answer.replace(",", "").replace("$", "").strip())
+
+    if val_extracted == val_answer:
       is_correct = True

     # Check partial correctness (within 10%)
-    ratio = float(extracted_response.strip()) / float(answer.strip())
-    if 0.9 <= ratio <= 1.1:
-      is_partially_correct = True
+    if val_answer != 0.0:
+      ratio = val_extracted / val_answer
+      if 0.9 <= ratio <= 1.1:
+        is_partially_correct = True
   except Exception as e:
     if tmvp_config.debug["rl"]:
       max_logging.log(f"Evaluation Exception: {e}")

src/MaxText/rl/train_rl.py

Lines changed: 38 additions & 10 deletions

@@ -48,6 +48,7 @@
 import collections
 import grain
 import jax
+import json
 import os
 import pathwaysutils
 import tensorflow_datasets as tfds
@@ -92,9 +93,18 @@ def get_maxtext_model(config, devices=None):
   # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e.,
   # load_parameters_path=/path/to/your/output/directory/0/items
   """
-  model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
+  model, mesh = model_creation_utils.create_nnx_model(config, devices=devices, model_mode="train")
   with mesh:
-    tunix_model = TunixMaxTextAdapter(base_model=model)
+    if "maxtext_config" in config.vllm_additional_config:
+      use_standalone_mappings = False
+      use_no_op_mappings = True
+    else:
+      use_standalone_mappings = True
+      use_no_op_mappings = False
+
+    tunix_model = TunixMaxTextAdapter(
+        base_model=model, use_standalone_mappings=use_standalone_mappings, use_no_op_mappings=use_no_op_mappings
+    )
     tunix_model.config = None
   return tunix_model, mesh

@@ -323,6 +333,21 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
       set_profile_options=False,
   )

+  # Parse vllm_additional_config
+  rollout_additional_config = None
+  if trainer_config.vllm_additional_config:
+    if isinstance(trainer_config.vllm_additional_config, dict):
+      # It's already parsed into a dict
+      rollout_additional_config = trainer_config.vllm_additional_config
+    elif isinstance(trainer_config.vllm_additional_config, str):
+      # It's a string, so we need to parse it
+      try:
+        rollout_additional_config = json.loads(trainer_config.vllm_additional_config)
+      except json.JSONDecodeError as e:
+        raise ValueError(f"Failed to parse additional_config JSON: {e}") from e
+
+  max_logging.log(f"Parsed additional config: {rollout_additional_config}")
+
   # RL Cluster config
   # Note that we use vLLM as the rollout engine.
   # and we are using Tensor Parallelism for rollout
@@ -361,6 +386,9 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
          rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm,
          rollout_vllm_tpu_backend_type="jax",
          rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb,
+         rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path,
+         rollout_vllm_additional_config=rollout_additional_config,
+         rollout_vllm_init_with_random_weights=False,
       ),
   )
   grpo_config = GrpoConfig(
@@ -389,14 +417,14 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
     max_logging.log(
         "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
     )
-  with nn_partitioning.axis_rules(trainer_config.logical_axis_rules):
-    rl_cluster = rl_cluster_lib.RLCluster(
-        actor=actor_model,
-        reference=reference_model,
-        tokenizer=model_tokenizer,
-        cluster_config=cluster_config,
-        **rl_cluster_kwargs,
-    )
+
+  rl_cluster = rl_cluster_lib.RLCluster(
+      actor=actor_model,
+      reference=reference_model,
+      tokenizer=model_tokenizer,
+      cluster_config=cluster_config,
+      **rl_cluster_kwargs,
+  )

   # Create RL trainer
   max_logging.log("Setting up RL trainer...")

src/MaxText/utils/ckpt_conversion/utils/param_mapping.py

Lines changed: 1 addition & 1 deletion

@@ -1479,5 +1479,5 @@ def transform_query_kernel(arr):
 VLLM_HOOK_FNS = {
     "qwen3": QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN,
     "llama3.1": LLAMA31_NNX_TO_VLLM_PARAM_HOOK_FN,
-    "deepseek3-671b": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN,
+    "deepseek3": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN,
 }
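
The key rename matters because the lookup in utils.py (above) is by model family, derived by splitting the full model name on "-"; a key like "deepseek3-671b" could therefore never match its own derived family. A quick check of the derivation:

for name in ("deepseek3-671b", "llama3.1-70b", "qwen3-8b"):
  print(name.split("-")[0])  # -> deepseek3, llama3.1, qwen3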
