diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index 5be8df263..ca0402789 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -958,3 +958,9 @@ partial_rotary_factor: 1.0
 # Use tokamax library for gmm kernel implementation
 use_tokamax_gmm: false
 use_tokamax_splash: false
+
+# vLLM Adapter Configurations
+# Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter)
+vllm_hf_config_path: ""
+# Additional configuration for the vLLM model, as a dict or JSON string (e.g. '{"maxtext_config": {...}}')
+vllm_additional_config: {}
diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index 4f2406534..4f9f0952c 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -1272,6 +1272,9 @@ class VLLM(BaseModel):
   kv_cache_buffer: int = Field(256, description="Buffer for KV cache.")
   hbm_utilization_vllm: float = Field(0.72, description="Target HBM utilization for vLLM.")
   swap_space_vllm_gb: int = Field(2, description="Swap space in GB for vLLM.")
+  vllm_additional_config: dict[str, Any] = Field(default_factory=dict,
+                                                 description="Additional vLLM config options.")
+  vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")


 class GRPO(BaseModel):
diff --git a/src/MaxText/configs/vllm.yml b/src/MaxText/configs/vllm.yml
index 8132681bf..569062e6b 100644
--- a/src/MaxText/configs/vllm.yml
+++ b/src/MaxText/configs/vllm.yml
@@ -41,6 +41,7 @@ logical_axis_rules: [
   ['activation_kv_batch_no_exp', ['data']],
   ['activation_kv_head_dim', ['model']],
   ['activation_vocab', ['model']],
+  ['activation_embed', ['model']],
   ['activation_exp', ['expert']],
   ['decode_batch', ['data', 'expert']],
   ['mlp', ['model']],
@@ -56,6 +57,13 @@
   ['cache_heads', ['model']],
   ['exp', ['expert']],
   ['paged_kv_heads', ['model']],
+  ['autoregressive', ['model']],
+  ['tensor', ['model']],
+  ['tensor_transpose', ['model']],
+  ['fsdp', ['data']],
+  ['fsdp_transpose', ['data']],
+  ['sequence', ['model']],
+  ['context', ['model']],
 ]
 data_sharding: [['data', 'model', 'expert']]
-input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
+input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
\ No newline at end of file
diff --git a/src/MaxText/integration/tunix/utils.py b/src/MaxText/integration/tunix/utils.py
index 2cf12c048..463cff8ee 100644
--- a/src/MaxText/integration/tunix/utils.py
+++ b/src/MaxText/integration/tunix/utils.py
@@ -147,7 +147,10 @@ def to_hf_hook_fns(self):
       return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_hook_fns()

     model_family = self.model_name.split("-")[0]
-    return VLLM_HOOK_FNS[model_family]()
+    if model_family in VLLM_HOOK_FNS:
+      return VLLM_HOOK_FNS[model_family]()
+    else:
+      return {}

   def lora_to_hf_mappings(self):
     if self.use_standalone_mappings:
diff --git a/src/MaxText/rl/evaluate_rl.py b/src/MaxText/rl/evaluate_rl.py
index 29610ef58..4711f0dfb 100644
--- a/src/MaxText/rl/evaluate_rl.py
+++ b/src/MaxText/rl/evaluate_rl.py
@@ -121,13 +121,18 @@ def score_responses(tmvp_config, question, responses, answer):

   # Check exact correctness
   try:
-    if float(extracted_response.strip()) == float(answer.strip()):
+    # Remove ',' and '$', then convert to float
+    val_extracted = float(extracted_response.replace(',', '').replace('$', '').strip())
+    val_answer = float(answer.replace(',', '').replace('$', '').strip())
+
+    if val_extracted == val_answer:
       is_correct = True

    # Check partial correctness (within 10%)
-    ratio = float(extracted_response.strip()) / float(answer.strip())
-    if 0.9 <= ratio <= 1.1:
-      is_partially_correct = True
+    if val_answer != 0.0:
+      ratio = val_extracted / val_answer
+      if 0.9 <= ratio <= 1.1:
+        is_partially_correct = True
   except Exception as e:
     if tmvp_config.debug["rl"]:
       max_logging.log(f"Evaluation Exception: {e}")
diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py
index ceab65980..0fe8944a3 100644
--- a/src/MaxText/rl/train_rl.py
+++ b/src/MaxText/rl/train_rl.py
@@ -54,6 +54,7 @@
 from flax.linen import partitioning as nn_partitioning
 import grain
 from etils import epath
+import json
 from vllm.outputs import PoolingRequestOutput  # pylint: disable=unused-import

 import jax
@@ -100,9 +101,16 @@ def get_maxtext_model(config, devices=None):
   # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e.,
   # load_parameters_path=/path/to/your/output/directory/0/items
   """
-  model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
+  model, mesh = model_creation_utils.create_nnx_model(config,
+                                                      devices=devices,
+                                                      model_mode="train")
   with mesh:
-    tunix_model = TunixMaxTextAdapter(base_model=model)
+    if "maxtext_config" in config.vllm_additional_config:
+      use_standalone_mappings = False
+    else:
+      use_standalone_mappings = True
+    tunix_model = TunixMaxTextAdapter(base_model=model,
+                                      use_standalone_mappings=use_standalone_mappings)
     tunix_model.config = None
   return tunix_model, mesh

@@ -331,6 +339,21 @@
     set_profile_options=False,
   )

+  # Parse vllm_additional_config
+  rollout_additional_config = None
+  if trainer_config.vllm_additional_config:
+    if isinstance(trainer_config.vllm_additional_config, dict):
+      # It's already parsed into a dict
+      rollout_additional_config = trainer_config.vllm_additional_config
+    elif isinstance(trainer_config.vllm_additional_config, str):
+      # It's a string, so we need to parse it
+      try:
+        rollout_additional_config = json.loads(trainer_config.vllm_additional_config)
+      except json.JSONDecodeError as e:
+        raise ValueError(f"Failed to parse additional_config JSON: {e}") from e
+
+  max_logging.log(f"Parsed additional config: {rollout_additional_config}")
+
   # RL Cluster config
   # Note that we use vLLM as the rollout engine.
   # and we are using Tensor Parallelism for rollout
@@ -369,6 +392,9 @@
           rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm,
           rollout_vllm_tpu_backend_type="jax",
           rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb,
+          rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path,
+          rollout_vllm_additional_config=rollout_additional_config,
+          rollout_vllm_init_with_random_weights=False,
       ),
   )
   grpo_config = GrpoConfig(
@@ -397,14 +423,14 @@
       max_logging.log(
           "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
      )
-  with nn_partitioning.axis_rules(trainer_config.logical_axis_rules):
-    rl_cluster = rl_cluster_lib.RLCluster(
-        actor=actor_model,
-        reference=reference_model,
-        tokenizer=model_tokenizer,
-        cluster_config=cluster_config,
-        **rl_cluster_kwargs,
-    )
+
+  rl_cluster = rl_cluster_lib.RLCluster(
+      actor=actor_model,
+      reference=reference_model,
+      tokenizer=model_tokenizer,
+      cluster_config=cluster_config,
+      **rl_cluster_kwargs,
+  )

   # Create GRPO trainer
   max_logging.log("Setting up GRPO trainer...")
diff --git a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
index bde6a62e4..3d4a194bc 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
@@ -1479,5 +1479,5 @@ def transform_query_kernel(arr):
 VLLM_HOOK_FNS = {
     "qwen3": QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN,
     "llama3.1": LLAMA31_NNX_TO_VLLM_PARAM_HOOK_FN,
-    "deepseek3-671b": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN,
+    "deepseek3": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN,
 }
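
For reference, the new config plumbing accepts `vllm_additional_config` either as a dict (set directly in YAML) or as a JSON string (passed on the command line), and `get_maxtext_model` switches off the standalone weight mappings whenever a `maxtext_config` key is present. Below is a minimal, self-contained sketch of that decision logic; the helper name `resolve_additional_config` and the `some_flag` payload are illustrative only, not part of this change:

```python
import json


def resolve_additional_config(value):
  """Normalize vllm_additional_config into a dict or None, as rl_train does."""
  if not value:
    return None
  if isinstance(value, dict):
    return value  # already parsed, e.g. set directly in YAML
  try:
    return json.loads(value)  # JSON string, e.g. passed on the command line
  except json.JSONDecodeError as e:
    raise ValueError(f"Failed to parse additional_config JSON: {e}") from e


cfg = resolve_additional_config('{"maxtext_config": {"some_flag": true}}')
# get_maxtext_model uses the presence of "maxtext_config" to decide whether
# the Tunix adapter should fall back to the standalone weight mappings.
use_standalone_mappings = "maxtext_config" not in cfg
assert use_standalone_mappings is False
```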
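
Similarly, the `score_responses` change in `evaluate_rl.py` strips currency formatting before comparing numbers and guards the partial-correctness ratio against a zero ground truth. A standalone sketch of the same normalization (the helper name `normalize_numeric` is hypothetical):

```python
def normalize_numeric(text: str) -> float:
  """Strip thousands separators and dollar signs, then parse as a float."""
  return float(text.replace(",", "").replace("$", "").strip())


# "$1,234.5" and "1234.50" now compare as equal answers, and an answer of
# "$0" no longer raises ZeroDivisionError in the partial-correctness check.
assert normalize_numeric("$1,234.5") == normalize_numeric("1234.50")
assert normalize_numeric("$0") == 0.0
```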