@@ -27,23 +27,37 @@ RUN pip install keyring keyrings.google-artifactregistry-auth

RUN pip install numba==0.61.2

COPY tunix /tunix
RUN pip uninstall -y google-tunix
RUN pip install -e /tunix --no-cache-dir
RUN pip install vllm-tpu

# Clone directly into /vllm
RUN pip install vllm==0.12.0
# 1. TUNIX
# Clone directly into /tunix instead of COPYing local files
# RUN git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git
RUN pip uninstall -y tunix && git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git && cd tunix && pip install -e . && cd ..

COPY vllm /vllm
RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir
# 2. TPU-INFERENCE
# Clone directly into /tpu-inference
# RUN git clone https://github.com/vllm-project/tpu-inference.git /tpu-inference
RUN pip uninstall -y tpu-inference && git clone https://github.com/abhinavclemson/tpu-inference.git && cd tpu-inference && pip install -e . && cd ..

# Note: The repo name is 'tpu-inference' (dash), but the Python package may be 'tpu_inference'.
# pip install handles this mapping automatically.

COPY tpu-inference /tpu-inference
RUN pip install -e /tpu-inference --no-cache-dir
# 3. vLLM


# RUN git clone https://github.com/vllm-project/vllm.git /vllm
# Set the TPU target and install

# --- REPLACEMENT END ---

RUN pip install --no-deps qwix==0.1.4

RUN pip install google-metrax numpy==2.2

RUN if [ "$MODE" = "post-training-experimental" ]; then \
echo "MODE=post-training-experimental: Re-installing JAX/libtpu"; \
pip uninstall -y jax jaxlib libtpu && \
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
pip install --pre jax==0.8.0.dev20251013 jaxlib==0.8.0.dev20251013 libtpu==0.0.25.dev20251012+nightly -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
fi
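
Note: to sanity-check the experimental stack once the image builds, something like the following can be run inside the container (a minimal sketch; the version strings mirror the pins above and will drift as nightlies roll):

```python
# Verify the pinned nightly JAX stack is the one actually installed.
import jax
import jaxlib

print("jax:", jax.__version__)        # expected: 0.8.0.dev20251013 per the pin above
print("jaxlib:", jaxlib.__version__)  # expected: 0.8.0.dev20251013
print("devices:", jax.devices())      # should enumerate TPU devices on a TPU host
```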
53 changes: 53 additions & 0 deletions dependencies/dockerfiles/patch_work.sh
@@ -0,0 +1,53 @@
#!/bin/bash

# 1. Define the target directory
# Resolve the glob to a concrete path here; quoting the pattern in "cd" below would prevent expansion.
SITE_PACKAGES=$(ls -d /usr/local/lib/python*/site-packages | head -n 1)
TEMP_DIR="temp_patch_work"

# Ensure the script stops if any command fails
set -e

echo "Navigate to site-packages: $SITE_PACKAGES"
cd "$SITE_PACKAGES"

# 2. Create a temporary directory for cloning
echo "Creating temporary directory..."
# Remove it first if it exists from a previous failed run to ensure a clean slate
if [ -d "$TEMP_DIR" ]; then rm -rf "$TEMP_DIR"; fi
mkdir "$TEMP_DIR"
cd "$TEMP_DIR"

# 3. Clone the repositories
echo "Cloning repositories..."
git clone https://github.com/vllm-project/vllm.git
git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git
git clone https://github.com/vllm-project/tpu-inference.git

# Go back up to site-packages
cd ..

# 4. Copy files
# We use 'cp -rf' to force overwrite existing files recursively.
# We assume the destination folders (./tunix, ./vllm) already exist as installed packages.
# If they don't exist, we create them.

echo "Patching Tunix..."
mkdir -p ./tunix
cp -rf "$TEMP_DIR/tunix/tunix/"* ./tunix/

echo "Patching TPU-Inference..."
# Note: the repo name is 'tpu-inference' (dash), but the installed package directory is
# 'tpu_inference' (underscore), which is what we patch below.
mkdir -p ./tpu_inference
cp -rf "$TEMP_DIR/tpu-inference/tpu_inference/"* ./tpu_inference/

echo "Patching vLLM..."
mkdir -p ./vllm
cp -rf "$TEMP_DIR/vllm/vllm/"* ./vllm/

# 5. Cleanup
echo "Cleaning up temporary files..."
rm -rf "$TEMP_DIR"

echo "Done! Packages have been patched."
53 changes: 53 additions & 0 deletions dependencies/scripts/patch_work.sh
@@ -0,0 +1,53 @@
#!/bin/bash

# 1. Define the target directory
# Resolve the glob to a concrete path here; quoting the pattern in "cd" below would prevent expansion.
SITE_PACKAGES=$(ls -d /usr/local/lib/python*/site-packages | head -n 1)
TEMP_DIR="temp_patch_work"

# Ensure the script stops if any command fails
set -e

echo "Navigate to site-packages: $SITE_PACKAGES"
cd "$SITE_PACKAGES"

# 2. Create a temporary directory for cloning
echo "Creating temporary directory..."
# Remove it first if it exists from a previous failed run to ensure a clean slate
if [ -d "$TEMP_DIR" ]; then rm -rf "$TEMP_DIR"; fi
mkdir "$TEMP_DIR"
cd "$TEMP_DIR"

# 3. Clone the repositories
echo "Cloning repositories..."
git clone https://github.com/vllm-project/vllm.git
git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git
git clone https://github.com/vllm-project/tpu-inference.git

# Go back up to site-packages
cd ..

# 4. Copy files
# We use 'cp -rf' to force overwrite existing files recursively.
# We assume the destination folders (./tunix, ./vllm) already exist as installed packages.
# If they don't exist, we create them.

echo "Patching Tunix..."
mkdir -p ./tunix
cp -rf "$TEMP_DIR/tunix/tunix/"* ./tunix/

echo "Patching TPU-Inference..."
# Note: the repo name is 'tpu-inference' (dash), but the installed package directory is
# 'tpu_inference' (underscore), which is what we patch below.
mkdir -p ./tpu_inference
cp -rf "$TEMP_DIR/tpu-inference/tpu_inference/"* ./tpu_inference/

echo "Patching vLLM..."
mkdir -p ./vllm
cp -rf "$TEMP_DIR/vllm/vllm/"* ./vllm/

# 5. Cleanup
echo "Cleaning up temporary files..."
rm -rf "$TEMP_DIR"

echo "Done! Packages have been patched."
57 changes: 57 additions & 0 deletions patch_work.sh
@@ -0,0 +1,57 @@
#!/bin/bash

# Ensure the script stops if any command fails
set -e

cd ..

# 1. Define the target directory
SITE_PACKAGES=$(find . -type d -name "*site-packages*" -print -quit)

TEMP_DIR="temp_patch_work"

echo "Navigate to site-packages: $SITE_PACKAGES"
cd "$SITE_PACKAGES"

# 2. Create a temporary directory for cloning
echo "Creating temporary directory..."
# Remove it first if it exists from a previous failed run to ensure a clean slate
if [ -d "$TEMP_DIR" ]; then rm -rf "$TEMP_DIR"; fi
mkdir "$TEMP_DIR"
cd "$TEMP_DIR"

# 3. Clone the repositories
echo "Cloning repositories..."
git clone https://github.com/vllm-project/vllm.git && cd vllm && git checkout 8c363ed6663f69b97c9f34b0be0091d8135f958c && cd ..
git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git
git clone https://github.com/abhinavclemson/tpu-inference.git


# Go back up to site-packages
cd ..

# 4. Copy files
# We use 'cp -rf' to force overwrite existing files recursively.
# We assume the destination folders (./tunix, ./vllm) already exist as installed packages.
# If they don't exist, we create them.

echo "Patching Tunix..."
mkdir -p ./tunix
cp -rf "$TEMP_DIR/tunix/tunix/"* ./tunix/

echo "Patching TPU-Inference..."
# Note: the repo name is 'tpu-inference' (dash), but the installed package directory is
# 'tpu_inference' (underscore), which is what we patch below.
mkdir -p ./tpu_inference
cp -rf "$TEMP_DIR/tpu-inference/tpu_inference/"* ./tpu_inference/

echo "Patching vLLM..."
mkdir -p ./vllm
cp -rf "$TEMP_DIR/vllm/vllm/"* ./vllm/

# 5. Cleanup
echo "Cleaning up temporary files..."
rm -rf "$TEMP_DIR"

echo "Done! Packages have been patched."
10 changes: 7 additions & 3 deletions src/MaxText/configs/rl.yml
@@ -23,6 +23,10 @@ sampler_devices_fraction: 0.5
chips_per_vm: 4 # depends on hardware, for v5p this is 4
num_trainer_slices: -1
num_samplers_slices: -1
# Only specify rollout_data_parallelism when you would like to use more than one model
# replica in rollout. If not specified, rollout_tensor_parallelism will be auto-determined.
rollout_data_parallelism: -1
rollout_tensor_parallelism: -1

# ====== Reproducibility ======
data_shuffle_seed: 42
@@ -83,13 +87,13 @@ debug:
enable_tunix_perf_metrics: False

# ====== Training ======
batch_size: 1
batch_size: 8
# Increase `batch_size` and `MAX_STEPS` for better results.
# num_batches: 3738
num_batches: 4 # 200
# A batch can be split into multiple micro batches for memory management
# and/or async sampling and training.
micro_batch_size: -1
micro_batch_size: 8
# Keep `num_test_batches` low so that evaluation runs quickly. It can be
# increased to a max. of 330 (if batch size is 4).
num_test_batches: 5 # 200
@@ -130,7 +134,7 @@ eval_make_lst: False # If True, return a list of (question, answer, responses) d
max_prefill_predict_length: 256
max_target_length: 1024
kv_cache_buffer: 256
hbm_utilization_vllm: 0.72
hbm_utilization_vllm: 0.6
swap_space_vllm_gb: 2
# Generation Configuration During Training
# Important to keep a high-ish temperature for varied, diverse responses during
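
For the batch settings above, `batch_size` is carved into chunks of `micro_batch_size` for memory management or async sampling. A minimal sketch of the assumed splitting semantics, with -1 meaning "no split":

```python
def split_into_micro_batches(batch: list, micro_batch_size: int) -> list[list]:
  """Split a batch into micro batches; micro_batch_size=-1 keeps the batch whole."""
  if micro_batch_size == -1 or micro_batch_size >= len(batch):
    return [batch]
  return [batch[i : i + micro_batch_size] for i in range(0, len(batch), micro_batch_size)]

# With batch_size=8 and micro_batch_size=8 (the values set above), each step is one micro batch.
print(split_into_micro_batches(list(range(8)), 8))  # [[0, 1, 2, 3, 4, 5, 6, 7]]
```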
9 changes: 8 additions & 1 deletion src/MaxText/configs/types.py
@@ -1264,7 +1264,14 @@ class RLHardware(BaseModel):
use_pathways: bool = Field(True, description="Whether to use Pathways for multihost orchestration.")
num_trainer_slices: int = Field(-1, description="Number of slices for the trainer.")
num_samplers_slices: int = Field(-1, description="Number of slices for the samplers.")

rollout_data_parallelism: int = Field(
-1,
description="Total model replicas for rollout. It should only be specified when you would like to use more "
"than one model replica in rollout.",
)
rollout_tensor_parallelism: int = Field(
-1, description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined."
)

class VLLM(BaseModel):
"""vLLM-specific configuration for rollouts."""
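
The intent is that the two new fields multiply out to the sampler's device count: `rollout_data_parallelism` sets the number of model replicas and `rollout_tensor_parallelism` the shards per replica. A sketch of the implied resolution rule (assumed semantics; the actual logic lives in the RL rollout setup):

```python
def resolve_rollout_parallelism(num_devices: int, dp: int = -1, tp: int = -1) -> tuple[int, int]:
  """Resolve (dp, tp) so that dp * tp == num_devices, where -1 means unspecified."""
  if dp == -1 and tp == -1:
    return 1, num_devices  # default: a single replica sharded across all devices
  if tp == -1:
    assert num_devices % dp == 0, "devices must divide evenly into replicas"
    return dp, num_devices // dp  # auto-determine tensor parallelism
  if dp == -1:
    assert num_devices % tp == 0, "devices must divide evenly into shards"
    return num_devices // tp, tp
  assert dp * tp == num_devices
  return dp, tp

# e.g. 8 sampler chips with rollout_data_parallelism=2 -> two replicas, each sharded 4 ways
print(resolve_rollout_parallelism(8, dp=2))  # (2, 4)
```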
14 changes: 13 additions & 1 deletion src/MaxText/integration/tunix/utils.py
@@ -14,8 +14,10 @@

"""Utils for Tunix integration."""

import inspect
import re


import MaxText.integration.tunix.weight_mapping as weight_mapping # pylint: disable=consider-using-from-import
from MaxText.utils.ckpt_conversion.utils.param_mapping import PARAM_MAPPING
from MaxText.utils.ckpt_conversion.utils.param_mapping import VLLM_HOOK_FNS
@@ -127,7 +129,17 @@ def __init__(self, model_name, config=None, use_standalone_mappings=False):
def to_hf_mapping(self):
"""Returns a mapping from MaxText parameter names to HuggingFace parameter names."""
if self.use_standalone_mappings:
return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping()
mapping_fn = STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping
total_num_layers = self.config["num_hidden_layers"]
print(f"total_num_layers: {total_num_layers} for model: {self.model_name}")
sig = inspect.signature(mapping_fn)
if "total_num_layers" in sig.parameters:
mapping = mapping_fn(
total_num_layers=total_num_layers,
)
return mapping

return mapping_fn()

config = self.config
mapping = self.convert_hf_map_to_sharding_map(
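
The `inspect.signature` probe above keeps older mapping callables working while newer ones receive `total_num_layers`. The same dispatch pattern in isolation (a standalone sketch, not MaxText code):

```python
import inspect

def call_mapping(mapping_fn, total_num_layers: int):
  """Pass total_num_layers only to mapping functions whose signature declares it."""
  if "total_num_layers" in inspect.signature(mapping_fn).parameters:
    return mapping_fn(total_num_layers=total_num_layers)
  return mapping_fn()

def legacy():  # older mapping: takes no arguments
  return {"w": "kernel"}

def layered(total_num_layers):  # newer mapping: one entry per layer
  return {f"layers_{i}": i for i in range(total_num_layers)}

print(call_mapping(legacy, 4))   # {'w': 'kernel'}
print(call_mapping(layered, 4))  # {'layers_0': 0, 'layers_1': 1, 'layers_2': 2, 'layers_3': 3}
```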
3 changes: 3 additions & 0 deletions src/MaxText/integration/tunix/weight_mapping/__init__.py
@@ -19,6 +19,7 @@
model name. This allows for easy extension to support new models.
"""

from MaxText.integration.tunix.weight_mapping.gpt_oss import GptOssMaxTextMapping
from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING

@@ -31,6 +32,8 @@ def __getattr__(self, name):
return LLAMA3_VLLM_MAPPING
elif name.startswith("qwen3"):
return QWEN3_VLLM_MAPPING
elif name.startswith("gpt"):
Collaborator: if someone tried (weirdly) to use our old gpt3 model would this hit?

return GptOssMaxTextMapping
else:
raise ValueError(f"{name} vLLM weight mapping not found.")

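
On the collaborator's question above: yes, `name.startswith("gpt")` would also match a legacy `gpt3` model name and route it to the GPT-OSS mapping. A narrower prefix avoids that (a hypothetical tightening with illustrative model-name strings, not what this PR ships):

```python
def lookup_mapping(name: str) -> str:
  """Prefix dispatch with a prefix narrow enough to exclude legacy gpt3 names."""
  if name.startswith("llama3"):
    return "LLAMA3_VLLM_MAPPING"
  if name.startswith("qwen3"):
    return "QWEN3_VLLM_MAPPING"
  if name.startswith("gpt-oss"):  # "gpt" alone would also catch "gpt3"
    return "GptOssMaxTextMapping"
  raise ValueError(f"{name} vLLM weight mapping not found.")

print(lookup_mapping("gpt-oss-20b"))  # GptOssMaxTextMapping
# lookup_mapping("gpt3") now raises ValueError instead of silently matching gpt-oss
```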