src/MaxText/configs/base.yml (+6 -0)

```diff
@@ -958,3 +958,9 @@ partial_rotary_factor: 1.0
 # Use tokamax library for gmm kernel implementation
 use_tokamax_gmm: false
 use_tokamax_splash: false
+
+# vLLM Adapter Configurations
+# Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter)
+vllm_hf_config_path: ""
+# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
+vllm_additional_config: {}
```
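
For illustration only (the path echoes the comment above; the JSON payload is a placeholder), the new keys might be overridden like this:

```yaml
# Hypothetical override values for the new adapter keys
vllm_hf_config_path: "src/MaxText/integration/vllm/maxtext_vllm_adapter"
vllm_additional_config: '{"maxtext_config": {}}'
```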

src/MaxText/configs/types.py (+3 -0)

```diff
@@ -1272,6 +1272,9 @@ class VLLM(BaseModel):
   kv_cache_buffer: int = Field(256, description="Buffer for KV cache.")
   hbm_utilization_vllm: float = Field(0.72, description="Target HBM utilization for vLLM.")
   swap_space_vllm_gb: int = Field(2, description="Swap space in GB for vLLM.")
+  vllm_additional_config: dict[str, Any] = Field(default_factory=dict,
+                                                 description="Additional vLLM config options.")
+  vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")


 class GRPO(BaseModel):
```
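
A minimal sketch (not MaxText code) of what these defaults give a fresh instance, with the VLLM model trimmed down to just the two new fields:

```python
# Minimal sketch of the new pydantic fields and their defaults.
from typing import Any
from pydantic import BaseModel, Field

class VLLM(BaseModel):
  vllm_additional_config: dict[str, Any] = Field(default_factory=dict)
  vllm_hf_config_path: str = Field("")

cfg = VLLM()
# default_factory=dict gives every instance its own empty dict, so mutating
# one config's dict never leaks into another instance.
assert cfg.vllm_additional_config == {} and cfg.vllm_hf_config_path == ""
```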

src/MaxText/configs/vllm.yml (+9 -1)
```diff
@@ -41,6 +41,7 @@ logical_axis_rules: [
   ['activation_kv_batch_no_exp', ['data']],
   ['activation_kv_head_dim', ['model']],
   ['activation_vocab', ['model']],
+  ['activation_embed', ['model']],
   ['activation_exp', ['expert']],
   ['decode_batch', ['data', 'expert']],
   ['mlp', ['model']],
@@ -56,6 +57,13 @@
   ['cache_heads', ['model']],
   ['exp', ['expert']],
   ['paged_kv_heads', ['model']],
+  ['autoregressive', ['model']],
+  ['tensor', ['model']],
+  ['tensor_transpose', ['model']],
+  ['fsdp', ['data']],
+  ['fsdp_transpose', ['data']],
+  ['sequence', ['model']],
+  ['context', ['model']],
 ]
 data_sharding: [['data', 'model', 'expert']]
-input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
+input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
```
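
The added entries give vLLM-side logical axis names (tensor, fsdp, sequence, ...) a mapping onto the physical mesh axes. A minimal sketch of how Flax resolves such rules; the two rules below are a subset of the list above:

```python
# Minimal sketch: how Flax resolves logical axis names to physical mesh axes.
from flax.linen import partitioning as nn_partitioning

rules = [("fsdp", "data"), ("tensor", "model")]  # subset of the rules above

with nn_partitioning.axis_rules(rules):
  # An array annotated with logical axes ('fsdp', 'tensor') is sharded as
  # PartitionSpec('data', 'model') on the physical mesh.
  print(nn_partitioning.logical_to_mesh_axes(("fsdp", "tensor")))
```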

src/MaxText/integration/tunix/utils.py (+4 -1)
```diff
@@ -147,7 +147,10 @@ def to_hf_hook_fns(self):
       return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_hook_fns()

     model_family = self.model_name.split("-")[0]
-    return VLLM_HOOK_FNS[model_family]()
+    if model_family in VLLM_HOOK_FNS:
+      return VLLM_HOOK_FNS[model_family]()
+    else:
+      return {}

   def lora_to_hf_mappings(self):
     if self.use_standalone_mappings:
```

src/MaxText/rl/evaluate_rl.py (+9 -4)
```diff
@@ -121,13 +121,18 @@ def score_responses(tmvp_config, question, responses, answer):

     # Check exact correctness
     try:
-      if float(extracted_response.strip()) == float(answer.strip()):
+      # Remove ',' and '$' then convert to float
+      val_extracted = float(extracted_response.replace(',', '').replace('$', '').strip())
+      val_answer = float(answer.replace(',', '').replace('$', '').strip())
+
+      if val_extracted == val_answer:
         is_correct = True

       # Check partial correctness (within 10%)
-      ratio = float(extracted_response.strip()) / float(answer.strip())
-      if 0.9 <= ratio <= 1.1:
-        is_partially_correct = True
+      if val_answer != 0.0:
+        ratio = val_extracted / val_answer
+        if 0.9 <= ratio <= 1.1:
+          is_partially_correct = True
     except Exception as e:
       if tmvp_config.debug["rl"]:
         max_logging.log(f"Evaluation Exception: {e}")
```
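
Restated outside the diff, the new comparison amounts to the following (parse_numeric is a hypothetical helper, not a function in the PR):

```python
# Standalone sketch of the normalization the new scoring code applies
# before comparing a model response to the reference answer.
def parse_numeric(text: str) -> float:
  """Strip thousands separators and dollar signs, then parse as a float."""
  return float(text.replace(",", "").replace("$", "").strip())

assert parse_numeric(" $1,234.50 ") == 1234.5

# Partial credit is granted within 10% of the reference answer, with a
# guard against dividing by a zero answer.
extracted, answer = parse_numeric("$95"), parse_numeric("100")
is_partially_correct = answer != 0.0 and 0.9 <= extracted / answer <= 1.1
assert is_partially_correct
```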

src/MaxText/rl/train_rl.py (+36 -10)

```diff
@@ -54,6 +54,7 @@
 from flax.linen import partitioning as nn_partitioning
 import grain
 from etils import epath
+import json

 from vllm.outputs import PoolingRequestOutput # pylint: disable=unused-import
 import jax
@@ -100,9 +101,16 @@ def get_maxtext_model(config, devices=None):
   # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e.,
   # load_parameters_path=/path/to/your/output/directory/0/items
   """
-  model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
+  model, mesh = model_creation_utils.create_nnx_model(config,
+                                                      devices=devices,
+                                                      model_mode="train")
   with mesh:
-    tunix_model = TunixMaxTextAdapter(base_model=model)
+    if "maxtext_config" in config.vllm_additional_config:
+      use_standalone_mappings = False
+    else:
+      use_standalone_mappings = True
+    tunix_model = TunixMaxTextAdapter(base_model=model,
+                                      use_standalone_mappings=use_standalone_mappings)
     tunix_model.config = None
   return tunix_model, mesh
@@ -331,6 +339,21 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
       set_profile_options=False,
   )

+  # Parse vllm_additional_config
+  rollout_additional_config = None
+  if trainer_config.vllm_additional_config:
+    if isinstance(trainer_config.vllm_additional_config, dict):
+      # It's already parsed into a dict
+      rollout_additional_config = trainer_config.vllm_additional_config
+    elif isinstance(trainer_config.vllm_additional_config, str):
+      # It's a string, so we need to parse it
+      try:
+        rollout_additional_config = json.loads(trainer_config.vllm_additional_config)
+      except json.JSONDecodeError as e:
+        raise ValueError(f"Failed to parse additional_config JSON: {e}") from e
+
+  max_logging.log(f"Parsed additional config: {rollout_additional_config}")
+
   # RL Cluster config
   # Note that we use vLLM as the rollout engine.
   # and we are using Tensor Parallelism for rollout
@@ -369,6 +392,9 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
         rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm,
         rollout_vllm_tpu_backend_type="jax",
         rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb,
+        rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path,
+        rollout_vllm_additional_config=rollout_additional_config,
+        rollout_vllm_init_with_random_weights=False
     ),
   )
   grpo_config = GrpoConfig(
@@ -397,14 +423,14 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
     max_logging.log(
         "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
     )
-  with nn_partitioning.axis_rules(trainer_config.logical_axis_rules):
-    rl_cluster = rl_cluster_lib.RLCluster(
-        actor=actor_model,
-        reference=reference_model,
-        tokenizer=model_tokenizer,
-        cluster_config=cluster_config,
-        **rl_cluster_kwargs,
-    )
+
+  rl_cluster = rl_cluster_lib.RLCluster(
+      actor=actor_model,
+      reference=reference_model,
+      tokenizer=model_tokenizer,
+      cluster_config=cluster_config,
+      **rl_cluster_kwargs,
+  )

   # Create GRPO trainer
   max_logging.log("Setting up GRPO trainer...")
```
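
The vllm_additional_config handling above, restated as a standalone sketch (the helper name is hypothetical): dicts pass through, JSON strings are parsed, and malformed JSON surfaces as a ValueError.

```python
# Standalone restatement of the config-normalization logic in rl_train.
import json

def parse_additional_config(value):
  if isinstance(value, dict):
    return value  # already parsed into a dict
  if isinstance(value, str):
    try:
      return json.loads(value)  # JSON string from the command line
    except json.JSONDecodeError as e:
      raise ValueError(f"Failed to parse additional_config JSON: {e}") from e
  return None

assert parse_additional_config({"maxtext_config": {}}) == {"maxtext_config": {}}
assert parse_additional_config('{"maxtext_config": {}}') == {"maxtext_config": {}}
```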

src/MaxText/utils/ckpt_conversion/utils/param_mapping.py (+1 -1)
```diff
@@ -1479,5 +1479,5 @@ def transform_query_kernel(arr):
 VLLM_HOOK_FNS = {
     "qwen3": QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN,
     "llama3.1": LLAMA31_NNX_TO_VLLM_PARAM_HOOK_FN,
-    "deepseek3-671b": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN,
+    "deepseek3": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN,
 }
```
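
The renamed key matches how the table is consulted: to_hf_hook_fns() in src/MaxText/integration/tunix/utils.py (diffed above) derives the lookup key as model_name.split("-")[0], so the full "deepseek3-671b" name could never match.

```python
# The hook-fn table is keyed by model family, which utils.py derives by
# splitting the model name on "-" and taking the first piece.
model_name = "deepseek3-671b"
assert model_name.split("-")[0] == "deepseek3"  # matches the new key
```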