Skip to content

Commit 64b49ad

Browse files
Add GPT OSS vllm mapping generator.
1 parent 9204d6b commit 64b49ad

File tree

9 files changed

+416
-16
lines changed

9 files changed

+416
-16
lines changed

dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,31 @@ RUN pip install keyring keyrings.google-artifactregistry-auth
2727

2828
RUN pip install numba==0.61.2
2929

30-
COPY tunix /tunix
31-
RUN pip uninstall -y google-tunix
32-
RUN pip install -e /tunix --no-cache-dir
30+
RUN pip install vllm-tpu
3331

32+
# 1. TUNIX
33+
# Clone directly into /tunix instead of COPYing local files
34+
RUN git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git
3435

35-
COPY vllm /vllm
36-
RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir
36+
# 2. TPU-INFERENCE
37+
# Clone directly into /tpu-inference
38+
RUN git clone https://github.com/vllm-project/tpu-inference.git /tpu-inference
39+
# Note: The repo name is 'tpu-inference' (dash), but python package might be 'tpu_inference'.
40+
# pip install handles this mapping automatically.
3741

42+
# 3. vLLM
43+
# Clone directly into /vllm
44+
RUN git clone https://github.com/vllm-project/vllm.git /vllm
45+
# Set the TPU target and install
3846

39-
COPY tpu-inference /tpu-inference
40-
RUN pip install -e /tpu-inference --no-cache-dir
47+
# --- REPLACEMENT END ---
4148

4249
RUN pip install --no-deps qwix==0.1.4
4350

51+
RUN pip install google-metrax numpy==2.2
52+
4453
RUN if [ "$MODE" = "post-training-experimental" ]; then \
4554
echo "MODE=post-training-experimental: Re-installing JAX/libtpu"; \
4655
pip uninstall -y jax jaxlib libtpu && \
47-
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
48-
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
56+
pip install --pre jax==0.8.0.dev20251013 jaxlib==0.8.0.dev20251013 libtpu==0.0.25.dev20251012+nightly -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
4957
fi
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
2+
#!/bin/bash
3+
4+
# 1. Define the target directory
5+
SITE_PACKAGES="/usr/local/lib/python*/site-packages"
6+
TEMP_DIR="temp_patch_work"
7+
8+
# Ensure the script stops if any command fails
9+
set -e
10+
11+
echo "Navigate to site-packages: $SITE_PACKAGES"
12+
cd "$SITE_PACKAGES"
13+
14+
# 2. Create a temporary directory for cloning
15+
echo "Creating temporary directory..."
16+
# Remove it first if it exists from a previous failed run to ensure a clean slate
17+
if [ -d "$TEMP_DIR" ]; then rm -rf "$TEMP_DIR"; fi
18+
mkdir "$TEMP_DIR"
19+
cd "$TEMP_DIR"
20+
21+
# 3. Clone the repositories
22+
echo "Cloning repositories..."
23+
git clone https://github.com/vllm-project/vllm.git
24+
git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git
25+
git clone https://github.com/vllm-project/tpu-inference.git
26+
27+
# Go back up to site-packages
28+
cd ..
29+
30+
# 4. Copy files
31+
# We use 'cp -rf' to force overwrite existing files recursively.
32+
# We assume the destination folders (./tunix, ./vllm) already exist as installed packages.
33+
# If they don't exist, we create them.
34+
35+
echo "Patching Tunix..."
36+
mkdir -p ./tunix
37+
cp -rf "$TEMP_DIR/tunix/tunix/"* ./tunix/
38+
39+
echo "Patching TPU-Inference..."
40+
# Note: Verify if the installed package name is 'tpu_inference' (underscore) or 'tpu-inference' (dash).
41+
# Based on your prompt, we are using 'tpu-inference'.
42+
mkdir -p ./tpu_inference
43+
cp -rf "$TEMP_DIR/tpu-inference/tpu_inference/"* ./tpu_inference/
44+
45+
echo "Patching vLLM..."
46+
mkdir -p ./vllm
47+
cp -rf "$TEMP_DIR/vllm/vllm/"* ./vllm/
48+
49+
# 5. Cleanup
50+
echo "Cleaning up temporary files..."
51+
rm -rf "$TEMP_DIR"
52+
53+
echo "Done! Packages have been patched."

dependencies/scripts/patch_work.sh

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
2+
#!/bin/bash
3+
4+
# 1. Define the target directory
5+
SITE_PACKAGES="/usr/local/lib/python*/site-packages"
6+
TEMP_DIR="temp_patch_work"
7+
8+
# Ensure the script stops if any command fails
9+
set -e
10+
11+
echo "Navigate to site-packages: $SITE_PACKAGES"
12+
cd "$SITE_PACKAGES"
13+
14+
# 2. Create a temporary directory for cloning
15+
echo "Creating temporary directory..."
16+
# Remove it first if it exists from a previous failed run to ensure a clean slate
17+
if [ -d "$TEMP_DIR" ]; then rm -rf "$TEMP_DIR"; fi
18+
mkdir "$TEMP_DIR"
19+
cd "$TEMP_DIR"
20+
21+
# 3. Clone the repositories
22+
echo "Cloning repositories..."
23+
git clone https://github.com/vllm-project/vllm.git
24+
git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git
25+
git clone https://github.com/vllm-project/tpu-inference.git
26+
27+
# Go back up to site-packages
28+
cd ..
29+
30+
# 4. Copy files
31+
# We use 'cp -rf' to force overwrite existing files recursively.
32+
# We assume the destination folders (./tunix, ./vllm) already exist as installed packages.
33+
# If they don't exist, we create them.
34+
35+
echo "Patching Tunix..."
36+
mkdir -p ./tunix
37+
cp -rf "$TEMP_DIR/tunix/tunix/"* ./tunix/
38+
39+
echo "Patching TPU-Inference..."
40+
# Note: Verify if the installed package name is 'tpu_inference' (underscore) or 'tpu-inference' (dash).
41+
# Based on your prompt, we are using 'tpu-inference'.
42+
mkdir -p ./tpu_inference
43+
cp -rf "$TEMP_DIR/tpu-inference/tpu_inference/"* ./tpu_inference/
44+
45+
echo "Patching vLLM..."
46+
mkdir -p ./vllm
47+
cp -rf "$TEMP_DIR/vllm/vllm/"* ./vllm/
48+
49+
# 5. Cleanup
50+
echo "Cleaning up temporary files..."
51+
rm -rf "$TEMP_DIR"
52+
53+
echo "Done! Packages have been patched."

patch_work.sh

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
#!/bin/bash
3+
4+
# Ensure the script stops if any command fails
5+
set -e
6+
7+
cd ..
8+
9+
# 1. Define the target directory
10+
SITE_PACKAGES=$(find . -type d -name "*site-packages*" -print -quit)
11+
12+
TEMP_DIR="temp_patch_work"
13+
14+
echo "Navigate to site-packages: $SITE_PACKAGES"
15+
cd "$SITE_PACKAGES"
16+
17+
# 2. Create a temporary directory for cloning
18+
echo "Creating temporary directory..."
19+
# Remove it first if it exists from a previous failed run to ensure a clean slate
20+
if [ -d "$TEMP_DIR" ]; then rm -rf "$TEMP_DIR"; fi
21+
mkdir "$TEMP_DIR"
22+
cd "$TEMP_DIR"
23+
24+
# 3. Clone the repositories
25+
echo "Cloning repositories..."
26+
git clone https://github.com/vllm-project/vllm.git && cd vllm && git checkout 8c363ed6663f69b97c9f34b0be0091d8135f958c && cd ..
27+
git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git
28+
git clone https://github.com/abhinavclemson/tpu-inference.git
29+
30+
31+
# Go back up to site-packages
32+
cd ..
33+
34+
# 4. Copy files
35+
# We use 'cp -rf' to force overwrite existing files recursively.
36+
# We assume the destination folders (./tunix, ./vllm) already exist as installed packages.
37+
# If they don't exist, we create them.
38+
39+
echo "Patching Tunix..."
40+
mkdir -p ./tunix
41+
cp -rf "$TEMP_DIR/tunix/tunix/"* ./tunix/
42+
43+
echo "Patching TPU-Inference..."
44+
# Note: Verify if the installed package name is 'tpu_inference' (underscore) or 'tpu-inference' (dash).
45+
# Based on your prompt, we are using 'tpu-inference'.
46+
mkdir -p ./tpu_inference
47+
cp -rf "$TEMP_DIR/tpu-inference/tpu_inference/"* ./tpu_inference/
48+
49+
echo "Patching vLLM..."
50+
mkdir -p ./vllm
51+
cp -rf "$TEMP_DIR/vllm/vllm/"* ./vllm/
52+
53+
# 5. Cleanup
54+
echo "Cleaning up temporary files..."
55+
rm -rf "$TEMP_DIR"
56+
57+
echo "Done! Packages have been patched."

src/MaxText/configs/rl.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,13 @@ debug:
8383
enable_tunix_perf_metrics: False
8484

8585
# ====== Training ======
86-
batch_size: 1
86+
batch_size: 8
8787
# Increase `batch_size` and `MAX_STEPS` for better results.
8888
# num_batches: 3738
8989
num_batches: 4 # 200
9090
# A batch can be split into multiple micro batches for memory management
9191
# and/or async sampling and training.
92-
micro_batch_size: -1
92+
micro_batch_size: 8
9393
# Keep `num_test_batches` low so that evaluation runs quickly. It can be
9494
# increased to a max. of 330 (if batch size is 4).
9595
num_test_batches: 5 # 200
@@ -130,7 +130,7 @@ eval_make_lst: False # If True, return a list of (question, answer, responses) d
130130
max_prefill_predict_length: 256
131131
max_target_length: 1024
132132
kv_cache_buffer: 256
133-
hbm_utilization_vllm: 0.72
133+
hbm_utilization_vllm: 0.6
134134
swap_space_vllm_gb: 2
135135
# Generation Configuration During Training
136136
# Important to keep a high-ish temperature for varied, diverse responses during

src/MaxText/integration/tunix/utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
"""Utils for Tunix integration."""
1616

17+
import inspect
1718
import re
1819

20+
1921
import MaxText.integration.tunix.weight_mapping as weight_mapping # pylint: disable=consider-using-from-import
2022
from MaxText.utils.ckpt_conversion.utils.param_mapping import PARAM_MAPPING
2123
from MaxText.utils.ckpt_conversion.utils.param_mapping import VLLM_HOOK_FNS
@@ -127,7 +129,17 @@ def __init__(self, model_name, config=None, use_standalone_mappings=False):
127129
def to_hf_mapping(self):
128130
"""Returns a mapping from MaxText parameter names to HuggingFace parameter names."""
129131
if self.use_standalone_mappings:
130-
return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping()
132+
mapping_fn = STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping
133+
total_num_layers = self.config["num_hidden_layers"]
134+
print(f"total_num_layers: {total_num_layers} for model: {self.model_name}")
135+
sig = inspect.signature(mapping_fn)
136+
if len(sig.parameters) >= 1 and "total_num_layers" in sig.parameters:
137+
mapping = mapping_fn(
138+
total_num_layers=total_num_layers,
139+
)
140+
return mapping
141+
142+
return mapping_fn()
131143

132144
config = self.config
133145
mapping = self.convert_hf_map_to_sharding_map(

src/MaxText/integration/tunix/weight_mapping/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
model name. This allows for easy extension to support new models.
2020
"""
2121

22+
from MaxText.integration.tunix.weight_mapping.gpt_oss import GptOssMaxTextMapping
2223
from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
2324
from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING
2425

@@ -31,6 +32,8 @@ def __getattr__(self, name):
3132
return LLAMA3_VLLM_MAPPING
3233
elif name.startswith("qwen3"):
3334
return QWEN3_VLLM_MAPPING
35+
elif name.startswith("gpt"):
36+
return GptOssMaxTextMapping
3437
else:
3538
raise ValueError(f"{name} vLLM weight mapping not found.")
3639

0 commit comments

Comments
 (0)