6 changes: 6 additions & 0 deletions src/MaxText/configs/base.yml
@@ -971,3 +971,9 @@ partial_rotary_factor: 1.0
# Use tokamax library for gmm kernel implementation
use_tokamax_gmm: false
use_tokamax_splash: false

# vLLM Adapter Configurations
# Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter)
vllm_hf_config_path: ""
# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
vllm_additional_config: {}
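
As a hedged sketch (the entry point and override values below are illustrative, not part of this change), the two new keys can be supplied like any other MaxText override when a config is initialized, assuming string overrides are accepted for dict-typed keys:

# Illustrative only: supplying the new adapter keys as pyconfig overrides.
from MaxText import pyconfig

config = pyconfig.initialize([
    "",  # argv[0] placeholder, mirroring the usage in train_rl.py below
    "src/MaxText/configs/base.yml",
    "vllm_hf_config_path=src/MaxText/integration/vllm/maxtext_vllm_adapter",
    'vllm_additional_config={"maxtext_config": {"model_name": "llama3.1-8b"}}',
])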
2 changes: 2 additions & 0 deletions src/MaxText/configs/types.py
@@ -1353,6 +1353,8 @@ class VLLM(BaseModel):
kv_cache_buffer: int = Field(256, description="Buffer for KV cache.")
hbm_utilization_vllm: float = Field(0.72, description="Target HBM utilization for vLLM.")
swap_space_vllm_gb: int = Field(2, description="Swap space in GB for vLLM.")
vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")


class GRPO(BaseModel):
13 changes: 12 additions & 1 deletion src/MaxText/configs/vllm.yml
@@ -41,6 +41,7 @@ logical_axis_rules: [
['activation_kv_batch_no_exp', []],
['activation_kv_head_dim', ['model']],
['activation_vocab', ['model']],
['activation_embed', ['model']],
['activation_exp', ['expert']],
['decode_batch', ['expert']],
['mlp', ['model']],
@@ -49,13 +50,23 @@
['heads', ['model']],
['q_heads', ['model']],
['kv_heads', ['model']],
['kv_head_dim', []],
['kv', []],
['embed', ['expert']],
['embed_no_exp', []],
['q_lora', ['expert']],
['kv_lora', ['expert']],
['norm', ['model']],
['cache_heads', ['model']],
['exp', ['expert']],
['paged_kv_heads', ['model']],
['autoregressive', ['model']],
['tensor', ['model']],
['tensor_transpose', ['model']],
['fsdp', ['data']],
['fsdp_transpose', ['data']],
['sequence', ['model']],
['context', ['model']],
]
data_sharding: [['data', 'model', 'expert']]
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
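
For illustration only (not part of the diff), these rules are consumed by Flax's logical partitioning helpers: each logical axis name on the left resolves to the physical mesh axes on the right. A minimal sketch using a few of the names above:

# Minimal sketch: resolving logical axis names to mesh axes under the rules above.
import flax.linen as nn

rules = [("embed", ("expert",)), ("heads", ("model",)), ("fsdp", ("data",))]
with nn.logical_axis_rules(rules):
    spec = nn.logical_to_mesh_axes(("fsdp", "embed", "heads"))
    # spec == PartitionSpec("data", "expert", "model")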
14 changes: 14 additions & 0 deletions src/MaxText/integration/tunix/tunix_adapter.py
@@ -37,6 +37,7 @@ def __init__(
self,
base_model: Transformer,
use_standalone_mappings: bool = True,
use_no_op_mappings: bool = False,
):
super().__init__()
self.base = base_model
@@ -45,6 +46,7 @@ def __init__(
HF_MODEL_CONFIGS[self.base.config.model_name].to_dict(),
use_standalone_mappings,
)
self.use_no_op_mappings = use_no_op_mappings

# ------------------------------------------------------------------ #
# Tunix call signature
@@ -69,13 +71,25 @@ def __call__(
return logits, None

def to_hf_mappings(self):
if self.use_no_op_mappings:
return {}

return self._vllm_weight_mapping.to_hf_mapping()

def to_hf_transpose_keys(self):
if self.use_no_op_mappings:
return {}

return self._vllm_weight_mapping.to_hf_transpose_keys()

def to_hf_hook_fns(self):
if self.use_no_op_mappings:
return {}

return self._vllm_weight_mapping.to_hf_hook_fns()

def lora_to_hf_mappings(self):
if self.use_no_op_mappings:
return {}

return self._vllm_weight_mapping.lora_to_hf_mappings()
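
A short usage sketch of the new flag, mirroring how get_maxtext_model in train_rl.py wires it up; base_model is assumed to be an already-constructed MaxText Transformer:

# Sketch: when a full maxtext_config is routed through vLLM's additional_config,
# HF-style weight remapping is skipped entirely.
adapter = TunixMaxTextAdapter(
    base_model=model,               # assumed: an existing MaxText Transformer
    use_standalone_mappings=False,
    use_no_op_mappings=True,
)
assert adapter.to_hf_mappings() == {}        # no-op mappings
assert adapter.to_hf_transpose_keys() == {}  # nothing to transpose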
5 changes: 4 additions & 1 deletion src/MaxText/integration/tunix/utils.py
@@ -147,7 +147,10 @@ def to_hf_hook_fns(self):
return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_hook_fns()

model_family = self.model_name.split("-")[0]
return VLLM_HOOK_FNS[model_family]()
if model_family in VLLM_HOOK_FNS:
return VLLM_HOOK_FNS[model_family]()
else:
return {}

def lora_to_hf_mappings(self):
if self.use_standalone_mappings:
39 changes: 14 additions & 25 deletions src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py
@@ -49,26 +49,17 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
Raises:
ValueError: If `hf_config_path` is not provided in the vLLM model config.
"""

def _path_exists(path: str) -> bool:
if not path:
return False
return epath.Path(path).exists()

if "maxtext_config" in vllm_config.additional_config:
overrides = vllm_config.additional_config["maxtext_config"]
else:
overrides = {}
load_path = None
if _path_exists(vllm_config.load.download_dir):
load_path = vllm_config.load.download_dir
elif _path_exists(vllm_config.model.model):
load_path = vllm_config.model.model

if load_path:
overrides["load_parameters_path"] = load_path
elif vllm_config.model.model:
overrides["model_name"] = vllm_config.model.model

if vllm_config.load_config.load_format == "dummy":
if overrides.get("load_parameters_path") is not None:
max_logging.log(
"Warning: load_parameters_path is set when using dummy load format. Checkpoint loading will be skipped."
)
overrides["load_parameters_path"] = None

if vllm_config.model_config.hf_config_path is None:
raise ValueError("hf_config_path must be provided when using MaxTextForCausalLM.")
@@ -110,12 +101,6 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> None

# Handle dummy weight loading during initialization
if vllm_config.load_config.load_format == "dummy":
if self.maxtext_config.load_parameters_path is not None:
max_logging.log(
"Warning: load_parameters_path is set when using dummy load format. Checkpoint loading will be skipped."
)
self.maxtext_config.load_parameters_path = None

with self.mesh:
self.load_weights(rng_key)

@@ -199,9 +184,13 @@ def load_weights(self, rng_key: jax.Array) -> None:
Args:
rng_key: A JAX random key for model initialization.
"""
self.model, _ = model_creation_utils.create_nnx_model(
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
if self.model is not None:
return

with nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
self.model, _ = model_creation_utils.create_nnx_model(
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)


class MaxTextForCausalLM(nnx.Module):
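
To make the precedence in generate_maxtext_config concrete, a plain-Python sketch with hypothetical values:

# Hypothetical inputs illustrating the override flow above.
additional_config = {"maxtext_config": {"ici_tensor_parallelism": 4}}
overrides = dict(additional_config.get("maxtext_config", {}))

# 1) An existing download_dir (or model path) becomes load_parameters_path;
#    otherwise a bare model name is forwarded as model_name.
overrides["load_parameters_path"] = "/checkpoints/example/items"  # hypothetical path

# 2) With load_format == "dummy", any checkpoint path is cleared so that
#    load_weights() builds the model with randomly initialized weights instead.
dummy_load_format = True
if dummy_load_format:
    overrides["load_parameters_path"] = None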
13 changes: 10 additions & 3 deletions src/MaxText/layers/attentions.py
@@ -423,7 +423,7 @@ def __init__(
# Module attribute names must match names previously passed to Linen for checkpointing
self.KVCache_0 = (
self.init_kv_caches(inputs_kv_shape=inputs_kv_shape)
if self.model_mode != MODEL_MODE_TRAIN and base_kv_cache
if self.model_mode != MODEL_MODE_TRAIN and base_kv_cache and config.attention != "vllm_rpa"
else None
)

@@ -909,7 +909,7 @@ def forward_serve_vllm(
try:
# pylint: disable=import-outside-toplevel
# pytype: disable=import-error
from tpu_inference.layers.jax.attention_interface import sharded_ragged_paged_attention as rpa_ops
from tpu_inference.layers.common.attention_interface import sharded_ragged_paged_attention as rpa_ops
except ImportError as e:
raise ImportError(
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
@@ -930,7 +930,8 @@

md = rpa_metadata

output, kv_cache = rpa_ops(1.0, self.mesh, attention_chunk_size, q_scale, k_scale, v_scale)(
output, kv_cache = rpa_ops(
self.mesh,
query,
key,
value,
Expand All @@ -939,6 +940,12 @@ def forward_serve_vllm(
md.block_tables,
md.query_start_loc,
md.request_distribution,
None,
1.0,
attention_chunk_size,
q_scale,
k_scale,
v_scale,
)
return kv_cache, output

13 changes: 9 additions & 4 deletions src/MaxText/rl/evaluate_rl.py
@@ -121,13 +121,18 @@ def score_responses(tmvp_config, question, responses, answer):

# Check exact correctness
try:
if float(extracted_response.strip()) == float(answer.strip()):
# Remove ',' and '$' then convert to float
val_extracted = float(extracted_response.replace(",", "").replace("$", "").strip())
val_answer = float(answer.replace(",", "").replace("$", "").strip())

if val_extracted == val_answer:
is_correct = True

# Check partial correctness (within 10%)
ratio = float(extracted_response.strip()) / float(answer.strip())
if 0.9 <= ratio <= 1.1:
is_partially_correct = True
if val_answer != 0.0:
ratio = val_extracted / val_answer
if 0.9 <= ratio <= 1.1:
is_partially_correct = True
except Exception as e:
if tmvp_config.debug["rl"]:
max_logging.log(f"Evaluation Exception: {e}")
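
A worked example of the normalization above (values illustrative): currency and thousands separators are stripped before comparison, and the ratio check is skipped when the reference answer is zero.

extracted_response, answer = "$1,234.50", "1234.5"
val_extracted = float(extracted_response.replace(",", "").replace("$", "").strip())  # 1234.5
val_answer = float(answer.replace(",", "").replace("$", "").strip())                 # 1234.5
assert val_extracted == val_answer                   # exact match
assert 0.9 <= val_extracted / val_answer <= 1.1      # and within the 10% band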
38 changes: 36 additions & 2 deletions src/MaxText/rl/train_rl.py
@@ -48,6 +48,7 @@
import collections
import grain
import jax
import json
import os
import pathwaysutils
import tensorflow_datasets as tfds
@@ -70,6 +71,7 @@

from MaxText import max_logging, max_utils, maxtext_utils, pyconfig
from MaxText import model_creation_utils
from MaxText.globals import MAXTEXT_PKG_DIR
from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
from MaxText.rl.evaluate_rl import evaluate
from MaxText.rl import utils_rl
@@ -93,7 +95,16 @@ def get_maxtext_model(config, devices=None):
"""
model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
with jax.set_mesh(mesh):
tunix_model = TunixMaxTextAdapter(base_model=model)
if "maxtext_config" in config.vllm_additional_config:
use_standalone_mappings = False
use_no_op_mappings = True
else:
use_standalone_mappings = True
use_no_op_mappings = False

tunix_model = TunixMaxTextAdapter(
base_model=model, use_standalone_mappings=use_standalone_mappings, use_no_op_mappings=use_no_op_mappings
)
tunix_model.config = None
return tunix_model, mesh

@@ -352,6 +363,21 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
set_profile_options=False,
)

# Parse vllm_additional_config
rollout_additional_config = None
if trainer_config.vllm_additional_config:
if isinstance(trainer_config.vllm_additional_config, dict):
# It's already parsed into a dict
rollout_additional_config = trainer_config.vllm_additional_config
elif isinstance(trainer_config.vllm_additional_config, str):
# It's a string, so we need to parse it
try:
rollout_additional_config = json.loads(trainer_config.vllm_additional_config)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse additional_config JSON: {e}") from e

max_logging.log(f"Parsed additional config: {rollout_additional_config}")

# RL Cluster config
# Note that we use vLLM as the rollout engine.
# and we are using Tensor Parallelism for rollout
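
A quick sketch of the two input forms handled by the parsing above (the JSON payload itself is illustrative):

import json

as_string = '{"maxtext_config": {"model_name": "llama3.1-8b"}}'
as_dict = {"maxtext_config": {"model_name": "llama3.1-8b"}}
assert json.loads(as_string) == as_dict  # either form becomes the dict forwarded below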
@@ -394,6 +420,9 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm,
rollout_vllm_tpu_backend_type="jax",
rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb,
rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path,
rollout_vllm_additional_config=rollout_additional_config,
rollout_vllm_init_with_random_weights=True,
**get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)),
),
)
@@ -423,7 +452,12 @@
max_logging.log(
"enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
)
with nn_partitioning.axis_rules(trainer_config.logical_axis_rules):

vllm_config_path = epath.Path(MAXTEXT_PKG_DIR) / "configs" / "vllm.yml"
argv_list = ["", str(vllm_config_path)]
vllm_config = pyconfig.initialize(argv_list)

with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
rl_cluster = rl_cluster_lib.RLCluster(
actor=actor_model,
reference=reference_model,
2 changes: 1 addition & 1 deletion src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
@@ -1479,5 +1479,5 @@ def transform_query_kernel(arr):
VLLM_HOOK_FNS = {
"qwen3": QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN,
"llama3.1": LLAMA31_NNX_TO_VLLM_PARAM_HOOK_FN,
"deepseek3-671b": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN,
"deepseek3": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN,
}
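
The key rename matches the family lookup in utils.py, which derives the family from the model name before the first hyphen:

# Why "deepseek3" (not "deepseek3-671b") is the right key for VLLM_HOOK_FNS.
model_name = "deepseek3-671b"
model_family = model_name.split("-")[0]      # -> "deepseek3"
hook_fns = VLLM_HOOK_FNS.get(model_family, {})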